import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

from .torch_utils import *

# Maximum number of 5-value ground-truth records read from one target row.
_MAX_BOXES = 50


def _read_boxes(row, nH, nW):
    """Parse one flat target row into grid-scaled ground-truth boxes.

    ``row`` is a flat list of 5-value records.  Field 0 is carried through
    unchanged (it is later stored as the class target); fields 1-4 are
    multiplied by ``nW``, ``nH``, ``nW`` and ``nH`` respectively.  Reading
    stops at the first record whose field 1 equals 0, after ``_MAX_BOXES``
    records, or when the row runs out of complete records.

    Returns a list of ``(cls_id, gx, gy, gw, gh)`` tuples of Python floats.
    """
    boxes = []
    for t in range(min(_MAX_BOXES, len(row) // 5)):
        cls_id, x, y, w, h = row[t * 5:t * 5 + 5]
        if x == 0:
            break
        boxes.append((cls_id, x * nW, y * nH, w * nW, h * nH))
    return boxes


def _best_anchor(anchors, anchor_step, nA, gx, gy, gw, gh, gi, gj):
    """Return the index of the anchor whose shape best overlaps the box.

    Anchor and box are both placed at the origin so only width/height count.
    With 4-value anchors, an exact IoU tie is broken by the squared distance
    between the box centre and the anchor's offset inside cell ``(gi, gj)``.
    Returns -1 when no anchor has an IoU greater than 0.
    """
    best_iou = 0.0
    best_n = -1
    min_dist = 10000
    gt_shape = [0, 0, gw, gh]
    for n in range(nA):
        aw = anchors[anchor_step * n]
        ah = anchors[anchor_step * n + 1]
        iou = bbox_iou([0, 0, aw, ah], gt_shape, x1y1x2y2=False)
        dist = None
        if anchor_step == 4:
            ax = anchors[anchor_step * n + 2]
            ay = anchors[anchor_step * n + 3]
            dist = (gi + ax - gx) ** 2 + (gj + ay - gy) ** 2
        if iou > best_iou:
            # NOTE(review): min_dist is not refreshed here, so a later tie is
            # compared against a stale distance -- confirm this is intended.
            best_iou = iou
            best_n = n
        elif anchor_step == 4 and iou == best_iou and dist < min_dist:
            best_n = n
            min_dist = dist
    return best_n


def build_targets(pred_boxes, target, anchors, num_anchors, num_classes, nH, nW,
                  noobject_scale, object_scale, sil_thresh, seen):
    """Build the per-cell masks and regression targets for the region loss.

    Args:
        pred_boxes: CPU tensor with one 4-value box per row, indexed as
            ``b * nA * nH * nW + n * nH * nW + j * nW + i``.  Rows are passed
            to ``bbox_iou``/``bbox_ious`` with ``x1y1x2y2=False``.
        target: 2-D tensor, one flat row of 5-value records per image
            (see ``_read_boxes`` for how a row is consumed).
        anchors: flat list of anchor values, ``len(anchors) // num_anchors``
            values per anchor; values 0 and 1 are used as width and height,
            and with four values per anchor, values 2 and 3 as x/y offsets.
        num_anchors: number of anchors per grid cell.
        num_classes: accepted for interface compatibility; not used here.
        nH, nW: grid height and width.
        noobject_scale: initial value of every ``conf_mask`` entry.
        object_scale: ``conf_mask`` value written at each matched cell.
        sil_thresh: predictions whose best IoU with any ground-truth box of
            the image exceeds this get ``conf_mask`` 0.
        seen: while ``seen < 12800`` every cell gets a coordinate target
            (0.5 or the anchor offset for x/y, 0 for w/h) and
            ``coord_mask`` 1.

    Returns:
        ``(nGT, nCorrect, coord_mask, conf_mask, cls_mask, tx, ty, tw, th,
        tconf, tcls)``.  ``nGT`` counts ground-truth boxes, ``nCorrect``
        those whose matched prediction has IoU > 0.5; every tensor has shape
        ``(nB, nA, nH, nW)``.
    """
    nB = target.size(0)
    nA = num_anchors
    # Integer division: the result is used as a view size and a list index.
    anchor_step = len(anchors) // num_anchors
    conf_mask = torch.ones(nB, nA, nH, nW) * noobject_scale
    coord_mask = torch.zeros(nB, nA, nH, nW)
    cls_mask = torch.zeros(nB, nA, nH, nW)
    tx = torch.zeros(nB, nA, nH, nW)
    ty = torch.zeros(nB, nA, nH, nW)
    tw = torch.zeros(nB, nA, nH, nW)
    th = torch.zeros(nB, nA, nH, nW)
    tconf = torch.zeros(nB, nA, nH, nW)
    tcls = torch.zeros(nB, nA, nH, nW)

    nAnchors = nA * nH * nW
    nPixels = nH * nW

    # Read the targets once as plain Python floats; both passes reuse them.
    boxes_per_image = [_read_boxes(row, nH, nW)
                       for row in target.detach().cpu().tolist()]

    # Pass 1: silence the no-object penalty for predictions that already
    # overlap some ground-truth box by more than sil_thresh.
    for b, boxes in enumerate(boxes_per_image):
        cur_pred_boxes = pred_boxes[b * nAnchors:(b + 1) * nAnchors].t()
        cur_ious = torch.zeros(nAnchors)
        for _, gx, gy, gw, gh in boxes:
            cur_gt_boxes = torch.FloatTensor([gx, gy, gw, gh]).repeat(nAnchors, 1).t()
            cur_ious = torch.max(cur_ious, bbox_ious(cur_pred_boxes, cur_gt_boxes, x1y1x2y2=False))
        # The mask must have the same shape as the tensor it indexes.
        conf_mask[b][cur_ious.view(nA, nH, nW) > sil_thresh] = 0

    if seen < 12800:
        if anchor_step == 4:
            anchor_table = torch.FloatTensor(anchors).view(nA, anchor_step)
            # Column 2 is the x offset and column 3 the y offset, matching
            # how _best_anchor reads ax / ay.
            tx = anchor_table[:, 2].reshape(1, nA, 1, 1).repeat(nB, 1, nH, nW)
            ty = anchor_table[:, 3].reshape(1, nA, 1, 1).repeat(nB, 1, nH, nW)
        else:
            tx.fill_(0.5)
            ty.fill_(0.5)
        tw.zero_()
        th.zero_()
        coord_mask.fill_(1)

    # Pass 2: assign every ground-truth box to one anchor of one cell.
    nGT = 0
    nCorrect = 0
    for b, boxes in enumerate(boxes_per_image):
        for cls_id, gx, gy, gw, gh in boxes:
            nGT += 1
            # Clamp so a centre lying exactly on (or outside) the border
            # still maps to a valid cell instead of indexing out of range.
            gi = min(max(int(gx), 0), nW - 1)
            gj = min(max(int(gy), 0), nH - 1)
            best_n = _best_anchor(anchors, anchor_step, nA, gx, gy, gw, gh, gi, gj)

            pred_box = pred_boxes[b * nAnchors + best_n * nPixels + gj * nW + gi].tolist()
            coord_mask[b, best_n, gj, gi] = 1
            cls_mask[b, best_n, gj, gi] = 1
            conf_mask[b, best_n, gj, gi] = object_scale
            tx[b, best_n, gj, gi] = gx - gi
            ty[b, best_n, gj, gi] = gy - gj
            tw[b, best_n, gj, gi] = math.log(gw / anchors[anchor_step * best_n])
            th[b, best_n, gj, gi] = math.log(gh / anchors[anchor_step * best_n + 1])
            # The confidence target is the IoU of the current prediction.
            iou = bbox_iou([gx, gy, gw, gh], pred_box, x1y1x2y2=False)
            tconf[b, best_n, gj, gi] = iou
            tcls[b, best_n, gj, gi] = cls_id
            if iou > 0.5:
                nCorrect += 1

    return nGT, nCorrect, coord_mask, conf_mask, cls_mask, tx, ty, tw, th, tconf, tcls


class RegionLoss(nn.Module):
    """Region loss over a ``(B, A * (5 + num_classes), H, W)`` output map.

    The loss is the sum of masked squared-error terms for x, y, w, h and
    confidence plus a summed cross-entropy term over the matched cells.
    """

    # Set to True (on the class or an instance) to print per-stage timings.
    print_timing = False

    def __init__(self, num_classes=0, anchors=None, num_anchors=1):
        super(RegionLoss, self).__init__()
        self.num_classes = num_classes
        # None stands for "no anchors"; avoids a shared mutable default.
        self.anchors = [] if anchors is None else anchors
        self.num_anchors = num_anchors
        # Integer division: used as a view size and a list stride.
        self.anchor_step = len(self.anchors) // num_anchors
        self.coord_scale = 1
        self.noobject_scale = 1
        self.object_scale = 5
        self.class_scale = 1
        self.thresh = 0.6
        self.seen = 0

    def forward(self, output, target):
        """Return the scalar region loss for ``output`` against ``target``.

        Args:
            output: tensor of shape ``(B, A * (5 + num_classes), H, W)``.
            target: 2-D tensor of flat 5-value records per image, passed to
                ``build_targets``.

        All intermediate tensors follow ``output``'s device and dtype.
        """
        # output : BxAs*(4+1+num_classes)*H*W
        t0 = time.time()
        nB = output.size(0)
        nA = self.num_anchors
        nC = self.num_classes
        nH = output.size(2)
        nW = output.size(3)
        device = output.device
        dtype = output.dtype

        output = output.view(nB, nA, 5 + nC, nH, nW)
        x = torch.sigmoid(output[:, :, 0])
        y = torch.sigmoid(output[:, :, 1])
        w = output[:, :, 2]
        h = output[:, :, 3]
        conf = torch.sigmoid(output[:, :, 4])
        # One row of class scores per (batch, anchor, cell).
        cls = output[:, :, 5:].reshape(nB * nA, nC, nH * nW)
        cls = cls.transpose(1, 2).reshape(nB * nA * nH * nW, nC)
        t1 = time.time()

        # Decode predictions in grid units; detached because they are only
        # used to build targets.
        grid_x = torch.arange(nW, dtype=dtype, device=device).view(1, 1, 1, nW)
        grid_y = torch.arange(nH, dtype=dtype, device=device).view(1, 1, nH, 1)
        anchor_table = torch.tensor(self.anchors, dtype=dtype, device=device).view(nA, self.anchor_step)
        anchor_w = anchor_table[:, 0].reshape(1, nA, 1, 1)
        anchor_h = anchor_table[:, 1].reshape(1, nA, 1, 1)
        pred_boxes = torch.stack(
            [
                x.detach() + grid_x,
                y.detach() + grid_y,
                torch.exp(w.detach()) * anchor_w,
                torch.exp(h.detach()) * anchor_h,
            ],
            dim=-1,
        ).view(-1, 4)
        pred_boxes = convert2cpu(pred_boxes)
        t2 = time.time()

        (nGT, nCorrect, coord_mask, conf_mask, cls_mask,
         tx, ty, tw, th, tconf, tcls) = build_targets(
            pred_boxes, target.detach(), self.anchors, nA, nC, nH, nW,
            self.noobject_scale, self.object_scale, self.thresh, self.seen)

        cls_mask = cls_mask == 1
        nProposals = int((conf > 0.25).sum().item())

        # Class targets / scores only for the cells that own a ground truth.
        tcls = tcls[cls_mask].long().to(device)
        cls = cls[cls_mask.view(-1).to(device)]

        tx = tx.to(device=device, dtype=dtype)
        ty = ty.to(device=device, dtype=dtype)
        tw = tw.to(device=device, dtype=dtype)
        th = th.to(device=device, dtype=dtype)
        tconf = tconf.to(device=device, dtype=dtype)
        coord_mask = coord_mask.to(device=device, dtype=dtype)
        # sqrt because the mask multiplies both operands of a squared error.
        conf_mask = conf_mask.to(device=device, dtype=dtype).sqrt()
        t3 = time.time()

        loss_x = self.coord_scale * F.mse_loss(x * coord_mask, tx * coord_mask, reduction='sum') / 2.0
        loss_y = self.coord_scale * F.mse_loss(y * coord_mask, ty * coord_mask, reduction='sum') / 2.0
        loss_w = self.coord_scale * F.mse_loss(w * coord_mask, tw * coord_mask, reduction='sum') / 2.0
        loss_h = self.coord_scale * F.mse_loss(h * coord_mask, th * coord_mask, reduction='sum') / 2.0
        loss_conf = F.mse_loss(conf * conf_mask, tconf * conf_mask, reduction='sum') / 2.0
        loss_cls = self.class_scale * F.cross_entropy(cls, tcls, reduction='sum')
        loss = loss_x + loss_y + loss_w + loss_h + loss_conf + loss_cls
        t4 = time.time()

        if self.print_timing:
            print('-----------------------------------')
            print(' activation : %f' % (t1 - t0))
            print(' create pred_boxes : %f' % (t2 - t1))
            print(' build targets : %f' % (t3 - t2))
            print(' create loss : %f' % (t4 - t3))
            print(' total : %f' % (t4 - t0))
        # NOTE(review): the file's indentation was lost; this summary line is
        # assumed to be printed on every call -- confirm.
        print('%d: nGT %d, recall %d, proposals %d, loss: x %f, y %f, w %f, h %f, conf %f, cls %f, total %f' % (
            self.seen, nGT, nCorrect, nProposals, loss_x.item(), loss_y.item(), loss_w.item(), loss_h.item(),
            loss_conf.item(), loss_cls.item(), loss.item()))
        return loss