# EasyMocap/easymocap/estimator/YOLOv4/region_loss.py

import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from .torch_utils import *  # bbox_iou, bbox_ious, convert2cpu are expected from here


def build_targets(pred_boxes, target, anchors, num_anchors, num_classes, nH, nW,
                  noobject_scale, object_scale, sil_thresh, seen):
    # Build YOLOv2-style regression/confidence/class targets for one batch.
    # target is (nB, 50*5): up to 50 boxes per image as (class, cx, cy, w, h),
    # normalized to [0, 1]; a zero cx marks the end of the list.
    nB = target.size(0)
    nA = num_anchors
    nC = num_classes
    # Integer division: each anchor is described by 2 values (w, h) or 4 (w, h, x, y).
    anchor_step = len(anchors) // num_anchors
    conf_mask = torch.ones(nB, nA, nH, nW) * noobject_scale
    coord_mask = torch.zeros(nB, nA, nH, nW)
    cls_mask = torch.zeros(nB, nA, nH, nW)
    tx = torch.zeros(nB, nA, nH, nW)
    ty = torch.zeros(nB, nA, nH, nW)
    tw = torch.zeros(nB, nA, nH, nW)
    th = torch.zeros(nB, nA, nH, nW)
    tconf = torch.zeros(nB, nA, nH, nW)
    tcls = torch.zeros(nB, nA, nH, nW)
    nAnchors = nA * nH * nW
    nPixels = nH * nW
    # Silence the no-object confidence penalty for predictions that already
    # overlap some ground truth by more than sil_thresh.
    for b in range(nB):
        cur_pred_boxes = pred_boxes[b * nAnchors:(b + 1) * nAnchors].t()
        cur_ious = torch.zeros(nAnchors)
        for t in range(50):
            if target[b][t * 5 + 1] == 0:
                break
            gx = target[b][t * 5 + 1] * nW
            gy = target[b][t * 5 + 2] * nH
            gw = target[b][t * 5 + 3] * nW
            gh = target[b][t * 5 + 4] * nH
            cur_gt_boxes = torch.FloatTensor([gx, gy, gw, gh]).repeat(nAnchors, 1).t()
            cur_ious = torch.max(cur_ious, bbox_ious(cur_pred_boxes, cur_gt_boxes, x1y1x2y2=False))
        conf_mask[b][cur_ious.view(nA, nH, nW) > sil_thresh] = 0
    # Early in training (first 12800 images seen), push every prediction
    # towards its cell center with anchor-sized width/height.
    if seen < 12800:
        if anchor_step == 4:
            tx = torch.FloatTensor(anchors).view(nA, anchor_step).index_select(
                1, torch.LongTensor([2])).view(1, nA, 1, 1).repeat(nB, 1, nH, nW)
            ty = torch.FloatTensor(anchors).view(num_anchors, anchor_step).index_select(
                1, torch.LongTensor([2])).view(1, nA, 1, 1).repeat(nB, 1, nH, nW)
        else:
            tx.fill_(0.5)
            ty.fill_(0.5)
        tw.zero_()
        th.zero_()
        coord_mask.fill_(1)
    nGT = 0
    nCorrect = 0
    for b in range(nB):
        for t in range(50):
            if target[b][t * 5 + 1] == 0:
                break
            nGT = nGT + 1
            best_iou = 0.0
            best_n = -1
            min_dist = 10000
            gx = target[b][t * 5 + 1] * nW
            gy = target[b][t * 5 + 2] * nH
            gi = int(gx)
            gj = int(gy)
            gw = target[b][t * 5 + 3] * nW
            gh = target[b][t * 5 + 4] * nH
            gt_box = [0, 0, gw, gh]
            # Match the ground truth to the anchor with the best width/height IoU.
            for n in range(nA):
                aw = anchors[anchor_step * n]
                ah = anchors[anchor_step * n + 1]
                anchor_box = [0, 0, aw, ah]
                iou = bbox_iou(anchor_box, gt_box, x1y1x2y2=False)
                if anchor_step == 4:
                    ax = anchors[anchor_step * n + 2]
                    ay = anchors[anchor_step * n + 3]
                    dist = pow(((gi + ax) - gx), 2) + pow(((gj + ay) - gy), 2)
                if iou > best_iou:
                    best_iou = iou
                    best_n = n
                elif anchor_step == 4 and iou == best_iou and dist < min_dist:
                    best_iou = iou
                    best_n = n
                    min_dist = dist
            # Write the targets for the responsible cell (gj, gi) and anchor best_n.
            gt_box = [gx, gy, gw, gh]
            pred_box = pred_boxes[b * nAnchors + best_n * nPixels + gj * nW + gi]
            coord_mask[b][best_n][gj][gi] = 1
            cls_mask[b][best_n][gj][gi] = 1
            conf_mask[b][best_n][gj][gi] = object_scale
            tx[b][best_n][gj][gi] = target[b][t * 5 + 1] * nW - gi
            ty[b][best_n][gj][gi] = target[b][t * 5 + 2] * nH - gj
            tw[b][best_n][gj][gi] = math.log(gw / anchors[anchor_step * best_n])
            th[b][best_n][gj][gi] = math.log(gh / anchors[anchor_step * best_n + 1])
            iou = bbox_iou(gt_box, pred_box, x1y1x2y2=False)  # best_iou
            tconf[b][best_n][gj][gi] = iou
            tcls[b][best_n][gj][gi] = target[b][t * 5]
            if iou > 0.5:
                nCorrect = nCorrect + 1
    return nGT, nCorrect, coord_mask, conf_mask, cls_mask, tx, ty, tw, th, tconf, tcls
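

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream file): how one normalized
# ground-truth box turns into the (gi, gj, best anchor, tx, ty, tw, th)
# targets written by the loop above. The 13x13 grid, the 2-anchor list and
# the example box are made-up values; _wh_iou is a hypothetical helper that
# mirrors the width/height-only IoU used for anchor matching.
def _encode_single_box_example():
    nW = nH = 13
    anchors = [1.32, 1.73, 3.19, 4.01]              # two anchors, (w, h) in grid units
    cls_id, cx, cy, w, h = 1.0, 0.5, 0.5, 0.2, 0.3  # one target row, normalized to [0, 1]

    def _wh_iou(w1, h1, w2, h2):
        # IoU of two boxes that share the same center.
        inter = min(w1, w2) * min(h1, h2)
        return inter / (w1 * h1 + w2 * h2 - inter)

    gx, gy, gw, gh = cx * nW, cy * nH, w * nW, h * nH
    gi, gj = int(gx), int(gy)                       # responsible grid cell
    best_n = max(range(2), key=lambda n: _wh_iou(anchors[2 * n], anchors[2 * n + 1], gw, gh))
    tx, ty = gx - gi, gy - gj                       # cell-relative center offsets
    tw, th = math.log(gw / anchors[2 * best_n]), math.log(gh / anchors[2 * best_n + 1])
    return cls_id, gi, gj, best_n, tx, ty, tw, th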


class RegionLoss(nn.Module):
    # YOLOv2 region loss: coordinate, objectness and class losses over an
    # anchor-based grid of predictions.
    def __init__(self, num_classes=0, anchors=[], num_anchors=1):
        super(RegionLoss, self).__init__()
        self.num_classes = num_classes
        self.anchors = anchors
        self.num_anchors = num_anchors
        self.anchor_step = len(anchors) // num_anchors
        self.coord_scale = 1
        self.noobject_scale = 1
        self.object_scale = 5
        self.class_scale = 1
        self.thresh = 0.6
        self.seen = 0

    def forward(self, output, target):
        # output : B x (num_anchors * (4 + 1 + num_classes)) x H x W
        t0 = time.time()
        nB = output.data.size(0)
        nA = self.num_anchors
        nC = self.num_classes
        nH = output.data.size(2)
        nW = output.data.size(3)
        # Split the raw output into box offsets (x, y, w, h), objectness and class scores.
        output = output.view(nB, nA, (5 + nC), nH, nW)
        x = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([0]))).view(nB, nA, nH, nW))
        y = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([1]))).view(nB, nA, nH, nW))
        w = output.index_select(2, Variable(torch.cuda.LongTensor([2]))).view(nB, nA, nH, nW)
        h = output.index_select(2, Variable(torch.cuda.LongTensor([3]))).view(nB, nA, nH, nW)
        conf = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([4]))).view(nB, nA, nH, nW))
        cls = output.index_select(2, Variable(torch.linspace(5, 5 + nC - 1, nC).long().cuda()))
        cls = cls.view(nB * nA, nC, nH * nW).transpose(1, 2).contiguous().view(nB * nA * nH * nW, nC)
        t1 = time.time()
        # Decode predicted boxes into absolute grid coordinates for IoU matching.
        pred_boxes = torch.cuda.FloatTensor(4, nB * nA * nH * nW)
        grid_x = torch.linspace(0, nW - 1, nW).repeat(nH, 1).repeat(nB * nA, 1, 1).view(nB * nA * nH * nW).cuda()
        grid_y = torch.linspace(0, nH - 1, nH).repeat(nW, 1).t().repeat(nB * nA, 1, 1).view(nB * nA * nH * nW).cuda()
        anchor_w = torch.Tensor(self.anchors).view(nA, self.anchor_step).index_select(1, torch.LongTensor([0])).cuda()
        anchor_h = torch.Tensor(self.anchors).view(nA, self.anchor_step).index_select(1, torch.LongTensor([1])).cuda()
        anchor_w = anchor_w.repeat(nB, 1).repeat(1, 1, nH * nW).view(nB * nA * nH * nW)
        anchor_h = anchor_h.repeat(nB, 1).repeat(1, 1, nH * nW).view(nB * nA * nH * nW)
        pred_boxes[0] = x.data.view(-1) + grid_x
        pred_boxes[1] = y.data.view(-1) + grid_y
        pred_boxes[2] = torch.exp(w.data).view(-1) * anchor_w
        pred_boxes[3] = torch.exp(h.data).view(-1) * anchor_h
        pred_boxes = convert2cpu(pred_boxes.transpose(0, 1).contiguous().view(-1, 4))
        t2 = time.time()
        nGT, nCorrect, coord_mask, conf_mask, cls_mask, tx, ty, tw, th, tconf, tcls = build_targets(
            pred_boxes, target.data, self.anchors, nA, nC, nH, nW,
            self.noobject_scale, self.object_scale, self.thresh, self.seen)
        cls_mask = (cls_mask == 1)
        nProposals = int((conf > 0.25).sum().item())
        tx = Variable(tx.cuda())
        ty = Variable(ty.cuda())
        tw = Variable(tw.cuda())
        th = Variable(th.cuda())
        tconf = Variable(tconf.cuda())
        tcls = Variable(tcls.view(-1)[cls_mask.view(-1)].long().cuda())
        coord_mask = Variable(coord_mask.cuda())
        conf_mask = Variable(conf_mask.cuda().sqrt())
        cls_mask = Variable(cls_mask.view(-1, 1).repeat(1, nC).cuda())
        cls = cls[cls_mask].view(-1, nC)
        t3 = time.time()
        loss_x = self.coord_scale * nn.MSELoss(reduction='sum')(x * coord_mask, tx * coord_mask) / 2.0
        loss_y = self.coord_scale * nn.MSELoss(reduction='sum')(y * coord_mask, ty * coord_mask) / 2.0
        loss_w = self.coord_scale * nn.MSELoss(reduction='sum')(w * coord_mask, tw * coord_mask) / 2.0
        loss_h = self.coord_scale * nn.MSELoss(reduction='sum')(h * coord_mask, th * coord_mask) / 2.0
        loss_conf = nn.MSELoss(reduction='sum')(conf * conf_mask, tconf * conf_mask) / 2.0
        loss_cls = self.class_scale * nn.CrossEntropyLoss(reduction='sum')(cls, tcls)
        loss = loss_x + loss_y + loss_w + loss_h + loss_conf + loss_cls
        t4 = time.time()
        if False:  # timing breakdown, disabled by default
            print('-----------------------------------')
            print('        activation : %f' % (t1 - t0))
            print(' create pred_boxes : %f' % (t2 - t1))
            print('     build targets : %f' % (t3 - t2))
            print('       create loss : %f' % (t4 - t3))
            print('             total : %f' % (t4 - t0))
        print('%d: nGT %d, recall %d, proposals %d, loss: x %f, y %f, w %f, h %f, conf %f, cls %f, total %f' % (
            self.seen, nGT, nCorrect, nProposals, loss_x.item(), loss_y.item(), loss_w.item(), loss_h.item(),
            loss_conf.item(), loss_cls.item(), loss.item()))
        return loss
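

# ---------------------------------------------------------------------------
# Usage sketch (not part of the upstream file): a minimal smoke test for the
# loss, assuming a CUDA device is available (the forward pass hard-codes
# torch.cuda tensors) and that `from .torch_utils import *` really provides
# bbox_iou, bbox_ious and convert2cpu. Anchor values, grid size and the
# ground-truth box below are made-up.
if __name__ == '__main__':
    if torch.cuda.is_available():
        anchors = [1.32, 1.73, 3.19, 4.01, 5.06, 8.10]   # three anchors, (w, h) pairs
        criterion = RegionLoss(num_classes=2, anchors=anchors, num_anchors=3)
        nB, nH, nW = 2, 13, 13
        # Raw network output: B x (num_anchors * (5 + num_classes)) x H x W.
        output = torch.randn(nB, 3 * (5 + 2), nH, nW, device='cuda')
        # Targets: B x 250, up to 50 boxes per image packed as
        # (class, cx, cy, w, h), normalized to [0, 1]; zeros pad the rest.
        target = torch.zeros(nB, 50 * 5)
        target[:, 0:5] = torch.tensor([1.0, 0.5, 0.5, 0.2, 0.3])
        loss = criterion(output, target)
        print('region loss: %.4f' % loss.item())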