import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .torch_utils import *


def yolo_forward(output, conf_thresh, num_classes, anchors, num_anchors, scale_x_y,
                 only_objectness=1, validation=False):
    """Decode a raw YOLO head tensor into corner-format boxes and class scores.

    The channel dimension of ``output`` is read as ``num_anchors`` consecutive
    groups of ``5 + num_classes`` channels laid out as
    ``[x, y, w, h, objectness, class_0, ..., class_{num_classes - 1}]``.

    Args:
        output: 4-D tensor indexed as ``[batch, channels, H, W]``. The decode is
            only meaningful when ``channels == (5 + num_classes) * num_anchors``;
            this is not checked here.
        conf_thresh: Unused by this function; kept for interface compatibility.
        num_classes: Number of class channels per anchor.
        anchors: Flat sequence ``[w0, h0, w1, h1, ...]`` holding at least
            ``2 * num_anchors`` values. They multiply ``exp(w)`` / ``exp(h)``
            before the division by ``W`` / ``H``, so they are presumably
            expressed in grid-cell units (confirm against the caller).
        num_anchors: Number of anchors encoded in ``output``.
        scale_x_y: Scale applied to the sigmoid of the centre offsets.
        only_objectness: Unused by this function; kept for interface
            compatibility.
        validation: Unused by this function; kept for interface compatibility.

    Returns:
        Tuple ``(boxes, confs)``:
            boxes: ``[batch, num_anchors * H * W, 1, 4]`` holding
                ``(x1, y1, x2, y2)``, with x divided by ``W`` and y by ``H``.
            confs: ``[batch, num_anchors * H * W, num_classes]``, the product
                of the per-class sigmoid and the objectness sigmoid.
    """
    batch = output.size(0)
    H = output.size(2)
    W = output.size(3)

    # Slice the channel dimension into per-anchor pieces:
    #   [2, 2, 1, num_classes] repeated num_anchors times
    # and regroup them by kind (xy, wh, objectness, class scores).
    bxy_list = []
    bwh_list = []
    det_confs_list = []
    cls_confs_list = []
    for i in range(num_anchors):
        begin = i * (5 + num_classes)
        end = (i + 1) * (5 + num_classes)
        bxy_list.append(output[:, begin:begin + 2])
        bwh_list.append(output[:, begin + 2:begin + 4])
        det_confs_list.append(output[:, begin + 4:begin + 5])
        cls_confs_list.append(output[:, begin + 5:end])

    # Shape: [batch, num_anchors * 2, H, W]
    bxy = torch.cat(bxy_list, dim=1)
    # Shape: [batch, num_anchors * 2, H, W]
    bwh = torch.cat(bwh_list, dim=1)

    # Shape: [batch, num_anchors, H, W] -> [batch, num_anchors * H * W]
    det_confs = torch.cat(det_confs_list, dim=1)
    det_confs = det_confs.view(batch, num_anchors * H * W)

    # Shape: [batch, num_anchors * num_classes, H, W]
    cls_confs = torch.cat(cls_confs_list, dim=1)
    # Shape: [batch, num_anchors, num_classes, H * W]
    cls_confs = cls_confs.view(batch, num_anchors, num_classes, H * W)
    # Shape: [batch, num_anchors * H * W, num_classes]
    cls_confs = cls_confs.permute(0, 1, 3, 2).reshape(batch, num_anchors * H * W, num_classes)

    # Activations: sigmoid on xy, objectness and class scores; exp on wh.
    bxy = torch.sigmoid(bxy) * scale_x_y - 0.5 * (scale_x_y - 1)
    bwh = torch.exp(bwh)
    det_confs = torch.sigmoid(det_confs)
    cls_confs = torch.sigmoid(cls_confs)

    # Cell offsets, built with NumPy. Both arrays end up shaped [1, 1, H, W]:
    # grid_x holds the column index of every cell, grid_y the row index.
    grid_x = np.expand_dims(np.expand_dims(
        np.expand_dims(np.linspace(0, W - 1, W), axis=0).repeat(H, 0), axis=0), axis=0)
    grid_y = np.expand_dims(np.expand_dims(
        np.expand_dims(np.linspace(0, H - 1, H), axis=1).repeat(W, 1), axis=0), axis=0)

    # The grids do not depend on the anchor, so convert them to tensors once,
    # on the same device as the network output.
    grid_x_t = torch.tensor(grid_x, device=output.device, dtype=torch.float32)
    grid_y_t = torch.tensor(grid_y, device=output.device, dtype=torch.float32)

    anchor_w = [anchors[i * 2] for i in range(num_anchors)]
    anchor_h = [anchors[i * 2 + 1] for i in range(num_anchors)]

    bx_list = []
    by_list = []
    bw_list = []
    bh_list = []
    for i in range(num_anchors):
        ii = i * 2
        # Each slice has shape [batch, 1, H, W].
        bx_list.append(bxy[:, ii:ii + 1] + grid_x_t)
        by_list.append(bxy[:, ii + 1:ii + 2] + grid_y_t)
        bw_list.append(bwh[:, ii:ii + 1] * anchor_w[i])
        bh_list.append(bwh[:, ii + 1:ii + 2] * anchor_h[i])

    # Shape of each: [batch, num_anchors, H, W]
    bx = torch.cat(bx_list, dim=1)
    by = torch.cat(by_list, dim=1)
    bw = torch.cat(bw_list, dim=1)
    bh = torch.cat(bh_list, dim=1)

    # Shape of each: [batch, 2 * num_anchors, H, W]
    bx_bw = torch.cat((bx, bw), dim=1)
    by_bh = torch.cat((by, bh), dim=1)

    # Express x / width as a fraction of W and y / height as a fraction of H.
    bx_bw /= W
    by_bh /= H

    # Shape of each: [batch, num_anchors * H * W, 1]
    bx = bx_bw[:, :num_anchors].view(batch, num_anchors * H * W, 1)
    by = by_bh[:, :num_anchors].view(batch, num_anchors * H * W, 1)
    bw = bx_bw[:, num_anchors:].view(batch, num_anchors * H * W, 1)
    bh = by_bh[:, num_anchors:].view(batch, num_anchors * H * W, 1)

    # Centre/size -> corner coordinates.
    bx1 = bx - bw * 0.5
    by1 = by - bh * 0.5
    bx2 = bx1 + bw
    by2 = by1 + bh

    # Shape: [batch, num_anchors * H * W, 4] -> [batch, num_anchors * H * W, 1, 4]
    boxes = torch.cat((bx1, by1, bx2, by2), dim=2).view(batch, num_anchors * H * W, 1, 4)

    # det_confs: [batch, num_anchors * H * W, 1] broadcasts over the class axis.
    det_confs = det_confs.view(batch, num_anchors * H * W, 1)
    confs = cls_confs * det_confs

    # boxes: [batch, num_anchors * H * W, 1, 4]
    # confs: [batch, num_anchors * H * W, num_classes]
    return boxes, confs


def yolo_forward_dynamic(output, conf_thresh, num_classes, anchors, num_anchors, scale_x_y,
                         only_objectness=1, validation=False):
    """Decode a raw YOLO head tensor, reading sizes from ``output`` at each use.

    Performs the same computation as ``yolo_forward`` but never stores
    ``output.size(...)`` in a local variable; every size is queried from the
    tensor where it is needed.
    NOTE(review): this is presumably done so that shapes stay dynamic when the
    model is traced/exported - confirm before caching the sizes.

    The channel dimension of ``output`` is read as ``num_anchors`` consecutive
    groups of ``5 + num_classes`` channels laid out as
    ``[x, y, w, h, objectness, class_0, ..., class_{num_classes - 1}]``.

    Args:
        output: 4-D tensor indexed as ``[batch, channels, H, W]``. The decode is
            only meaningful when ``channels == (5 + num_classes) * num_anchors``;
            this is not checked here.
        conf_thresh: Unused by this function; kept for interface compatibility.
        num_classes: Number of class channels per anchor.
        anchors: Flat sequence ``[w0, h0, w1, h1, ...]`` holding at least
            ``2 * num_anchors`` values.
        num_anchors: Number of anchors encoded in ``output``.
        scale_x_y: Scale applied to the sigmoid of the centre offsets.
        only_objectness: Unused by this function; kept for interface
            compatibility.
        validation: Unused by this function; kept for interface compatibility.

    Returns:
        Tuple ``(boxes, confs)``:
            boxes: ``[batch, num_anchors * H * W, 1, 4]`` holding
                ``(x1, y1, x2, y2)``, with x divided by ``W`` and y by ``H``.
            confs: ``[batch, num_anchors * H * W, num_classes]``, the product
                of the per-class sigmoid and the objectness sigmoid.
    """
    # Slice the channel dimension into per-anchor pieces:
    #   [2, 2, 1, num_classes] repeated num_anchors times
    # and regroup them by kind (xy, wh, objectness, class scores).
    bxy_list = []
    bwh_list = []
    det_confs_list = []
    cls_confs_list = []
    for i in range(num_anchors):
        begin = i * (5 + num_classes)
        end = (i + 1) * (5 + num_classes)
        bxy_list.append(output[:, begin:begin + 2])
        bwh_list.append(output[:, begin + 2:begin + 4])
        det_confs_list.append(output[:, begin + 4:begin + 5])
        cls_confs_list.append(output[:, begin + 5:end])

    # Shape: [batch, num_anchors * 2, H, W]
    bxy = torch.cat(bxy_list, dim=1)
    # Shape: [batch, num_anchors * 2, H, W]
    bwh = torch.cat(bwh_list, dim=1)

    # Shape: [batch, num_anchors, H, W] -> [batch, num_anchors * H * W]
    det_confs = torch.cat(det_confs_list, dim=1)
    det_confs = det_confs.view(output.size(0), num_anchors * output.size(2) * output.size(3))

    # Shape: [batch, num_anchors * num_classes, H, W]
    cls_confs = torch.cat(cls_confs_list, dim=1)
    # Shape: [batch, num_anchors, num_classes, H * W]
    cls_confs = cls_confs.view(output.size(0), num_anchors, num_classes,
                               output.size(2) * output.size(3))
    # Shape: [batch, num_anchors * H * W, num_classes]
    cls_confs = cls_confs.permute(0, 1, 3, 2).reshape(
        output.size(0), num_anchors * output.size(2) * output.size(3), num_classes)

    # Activations: sigmoid on xy, objectness and class scores; exp on wh.
    bxy = torch.sigmoid(bxy) * scale_x_y - 0.5 * (scale_x_y - 1)
    bwh = torch.exp(bwh)
    det_confs = torch.sigmoid(det_confs)
    cls_confs = torch.sigmoid(cls_confs)

    # Cell offsets, built with NumPy. Both arrays end up shaped [1, 1, H, W]:
    # grid_x holds the column index of every cell, grid_y the row index.
    grid_x = np.expand_dims(np.expand_dims(
        np.expand_dims(np.linspace(0, output.size(3) - 1, output.size(3)), axis=0)
        .repeat(output.size(2), 0), axis=0), axis=0)
    grid_y = np.expand_dims(np.expand_dims(
        np.expand_dims(np.linspace(0, output.size(2) - 1, output.size(2)), axis=1)
        .repeat(output.size(3), 1), axis=0), axis=0)

    # The grids do not depend on the anchor, so convert them to tensors once,
    # on the same device as the network output.
    grid_x_t = torch.tensor(grid_x, device=output.device, dtype=torch.float32)
    grid_y_t = torch.tensor(grid_y, device=output.device, dtype=torch.float32)

    anchor_w = [anchors[i * 2] for i in range(num_anchors)]
    anchor_h = [anchors[i * 2 + 1] for i in range(num_anchors)]

    bx_list = []
    by_list = []
    bw_list = []
    bh_list = []
    for i in range(num_anchors):
        ii = i * 2
        # Each slice has shape [batch, 1, H, W].
        bx_list.append(bxy[:, ii:ii + 1] + grid_x_t)
        by_list.append(bxy[:, ii + 1:ii + 2] + grid_y_t)
        bw_list.append(bwh[:, ii:ii + 1] * anchor_w[i])
        bh_list.append(bwh[:, ii + 1:ii + 2] * anchor_h[i])

    # Shape of each: [batch, num_anchors, H, W]
    bx = torch.cat(bx_list, dim=1)
    by = torch.cat(by_list, dim=1)
    bw = torch.cat(bw_list, dim=1)
    bh = torch.cat(bh_list, dim=1)

    # Shape of each: [batch, 2 * num_anchors, H, W]
    bx_bw = torch.cat((bx, bw), dim=1)
    by_bh = torch.cat((by, bh), dim=1)

    # Express x / width as a fraction of W and y / height as a fraction of H.
    bx_bw /= output.size(3)
    by_bh /= output.size(2)

    # Shape of each: [batch, num_anchors * H * W, 1]
    bx = bx_bw[:, :num_anchors].view(
        output.size(0), num_anchors * output.size(2) * output.size(3), 1)
    by = by_bh[:, :num_anchors].view(
        output.size(0), num_anchors * output.size(2) * output.size(3), 1)
    bw = bx_bw[:, num_anchors:].view(
        output.size(0), num_anchors * output.size(2) * output.size(3), 1)
    bh = by_bh[:, num_anchors:].view(
        output.size(0), num_anchors * output.size(2) * output.size(3), 1)

    # Centre/size -> corner coordinates.
    bx1 = bx - bw * 0.5
    by1 = by - bh * 0.5
    bx2 = bx1 + bw
    by2 = by1 + bh

    # Shape: [batch, num_anchors * H * W, 4] -> [batch, num_anchors * H * W, 1, 4]
    boxes = torch.cat((bx1, by1, bx2, by2), dim=2).view(
        output.size(0), num_anchors * output.size(2) * output.size(3), 1, 4)

    # det_confs: [batch, num_anchors * H * W, 1] broadcasts over the class axis.
    det_confs = det_confs.view(
        output.size(0), num_anchors * output.size(2) * output.size(3), 1)
    confs = cls_confs * det_confs

    # boxes: [batch, num_anchors * H * W, 1, 4]
    # confs: [batch, num_anchors * H * W, num_classes]
    return boxes, confs


class YoloLayer(nn.Module):
    """YOLO detection layer.

    In training mode ``forward`` returns its input unchanged. Otherwise it
    selects the anchors named by ``anchor_mask``, divides them by ``stride``
    and decodes the input with ``yolo_forward_dynamic``.

    Args:
        anchor_mask: Indices of the anchors (out of ``num_anchors``) that this
            layer uses. Defaults to an empty list.
        num_classes: Number of object classes.
        anchors: Flat sequence of anchor values for all ``num_anchors``
            anchors. Defaults to an empty list.
        num_anchors: Total number of anchors described by ``anchors``; used to
            derive ``anchor_step = len(anchors) // num_anchors``.
        stride: Divisor applied to the selected anchor values before decoding.
        model_out: Stored on the instance; not read by the code in this file.
            Per the original note: whether inference post-processing happens
            inside or outside the model (True: outside).
    """

    def __init__(self, anchor_mask=None, num_classes=0, anchors=None, num_anchors=1,
                 stride=32, model_out=False):
        super().__init__()
        # None stands in for "empty" so instances never share one default list.
        self.anchor_mask = [] if anchor_mask is None else anchor_mask
        self.num_classes = num_classes
        self.anchors = [] if anchors is None else anchors
        self.num_anchors = num_anchors
        # Number of values stored per anchor in the flat ``anchors`` sequence.
        self.anchor_step = len(self.anchors) // num_anchors
        self.coord_scale = 1
        self.noobject_scale = 1
        self.object_scale = 5
        self.class_scale = 1
        self.thresh = 0.6
        self.stride = stride
        self.seen = 0
        self.scale_x_y = 1

        self.model_out = model_out

    def forward(self, output, target=None):
        """Return ``output`` as-is when training, else decoded ``(boxes, confs)``.

        Args:
            output: Raw head tensor passed through to ``yolo_forward_dynamic``.
            target: Unused.
        """
        if self.training:
            return output

        # Gather the anchor values selected by the mask, then divide by stride.
        masked_anchors = []
        for m in self.anchor_mask:
            masked_anchors += self.anchors[m * self.anchor_step:(m + 1) * self.anchor_step]
        masked_anchors = [anchor / self.stride for anchor in masked_anchors]

        return yolo_forward_dynamic(output, self.thresh, self.num_classes, masked_anchors,
                                    len(self.anchor_mask), scale_x_y=self.scale_x_y)