EasyMocap/easymocap/estimator/YOLOv4/yolo_layer.py

323 lines
12 KiB
Python
Raw Normal View History

2022-08-22 00:07:46 +08:00
import torch.nn as nn
import torch.nn.functional as F
from .torch_utils import *
def yolo_forward(output, conf_thresh, num_classes, anchors, num_anchors, scale_x_y, only_objectness=1,
                 validation=False):
    """Decode a raw YOLO head tensor into corner boxes and per-class scores.

    The channel axis of ``output`` is read as ``num_anchors`` consecutive
    groups of ``5 + num_classes`` channels, each laid out as
    ``[x, y, w, h, objectness, class_0, ..., class_{C-1}]``.

    Args:
        output: 4-D tensor indexed as ``[batch, channel, H, W]``.
        conf_thresh: Not used by this function; kept so the signature stays
            compatible with existing callers.
        num_classes: Number of class channels per anchor.
        anchors: Flat sequence ``[w0, h0, w1, h1, ...]`` with one pair per
            anchor. The values are multiplied into the decoded sizes before
            the division by the grid size, i.e. they are in grid-cell units.
        num_anchors: Number of anchors encoded in ``output``.
        scale_x_y: Scale applied to the sigmoid-activated centre offsets,
            re-centred so that an offset of 0.5 stays at 0.5.
        only_objectness: Not used by this function.
        validation: Not used by this function.

    Returns:
        A tuple ``(boxes, confs)``:

        * ``boxes`` -- ``[batch, num_anchors * H * W, 1, 4]`` holding
          ``(x1, y1, x2, y2)`` divided by the grid width / height
          (values are not clamped).
        * ``confs`` -- ``[batch, num_anchors * H * W, num_classes]`` holding
          ``sigmoid(objectness) * sigmoid(class score)``.
    """
    batch = output.size(0)
    H = output.size(2)
    W = output.size(3)
    group = 5 + num_classes
    cells = num_anchors * H * W
    starts = [a * group for a in range(num_anchors)]

    # Regroup the per-anchor channel slices by role.
    # bxy / bwh: [batch, num_anchors * 2, H, W]
    bxy = torch.cat([output[:, s: s + 2] for s in starts], dim=1)
    bwh = torch.cat([output[:, s + 2: s + 4] for s in starts], dim=1)
    # det_confs: [batch, num_anchors, H, W] -> [batch, num_anchors * H * W]
    det_confs = torch.cat([output[:, s + 4: s + 5] for s in starts], dim=1)
    det_confs = det_confs.view(batch, cells)
    # cls_confs: [batch, num_anchors * num_classes, H, W]
    #         -> [batch, num_anchors * H * W, num_classes]
    cls_confs = torch.cat([output[:, s + 5: s + group] for s in starts], dim=1)
    cls_confs = cls_confs.view(batch, num_anchors, num_classes, H * W)
    cls_confs = cls_confs.permute(0, 1, 3, 2).reshape(batch, cells, num_classes)

    # Activations: sigmoid for offsets and confidences, exp for sizes.
    bxy = torch.sigmoid(bxy) * scale_x_y - 0.5 * (scale_x_y - 1)
    bwh = torch.exp(bwh)
    det_confs = torch.sigmoid(det_confs)
    cls_confs = torch.sigmoid(cls_confs)

    # Cell index grids, shape [1, 1, H, W]. They are built with numpy and
    # converted exactly once (the conversion is a host-to-device copy on
    # accelerators). Using ``output.device`` keeps the grid next to the data
    # on any backend, not only CUDA.
    grid_x = torch.tensor(
        np.linspace(0, W - 1, W)[None, None, None, :].repeat(H, axis=2),
        device=output.device, dtype=torch.float32)
    grid_y = torch.tensor(
        np.linspace(0, H - 1, H)[None, None, :, None].repeat(W, axis=3),
        device=output.device, dtype=torch.float32)

    # Per-anchor decode, each result is [batch, num_anchors, H, W].
    anchor_ids = range(num_anchors)
    bx = torch.cat([bxy[:, 2 * a: 2 * a + 1] + grid_x for a in anchor_ids], dim=1)
    by = torch.cat([bxy[:, 2 * a + 1: 2 * a + 2] + grid_y for a in anchor_ids], dim=1)
    bw = torch.cat([bwh[:, 2 * a: 2 * a + 1] * anchors[2 * a] for a in anchor_ids], dim=1)
    bh = torch.cat([bwh[:, 2 * a + 1: 2 * a + 2] * anchors[2 * a + 1] for a in anchor_ids], dim=1)

    # Express everything relative to the grid size and flatten to
    # [batch, num_anchors * H * W, 1].
    bx = (bx / W).view(batch, cells, 1)
    by = (by / H).view(batch, cells, 1)
    bw = (bw / W).view(batch, cells, 1)
    bh = (bh / H).view(batch, cells, 1)

    # Centre/size -> corner format.
    bx1 = bx - bw * 0.5
    by1 = by - bh * 0.5
    bx2 = bx1 + bw
    by2 = by1 + bh

    # boxes: [batch, num_anchors * H * W, 4] -> [batch, num_anchors * H * W, 1, 4]
    boxes = torch.cat((bx1, by1, bx2, by2), dim=2).view(batch, cells, 1, 4)

    # confs: [batch, num_anchors * H * W, num_classes]
    confs = cls_confs * det_confs.view(batch, cells, 1)

    return boxes, confs
def yolo_forward_dynamic(output, conf_thresh, num_classes, anchors, num_anchors, scale_x_y, only_objectness=1,
                         validation=False):
    """Decode a raw YOLO head tensor, reading sizes from ``output`` on each use.

    Computes the same quantities as a decode with cached ``batch``/``H``/``W``
    but never stores those sizes in locals: every shape argument is taken
    from ``output.size(k)`` at the point of use.
    NOTE(review): presumably this keeps shapes symbolic when the model is
    traced/exported -- confirm against the export path before changing it.

    The channel axis of ``output`` is read as ``num_anchors`` consecutive
    groups of ``5 + num_classes`` channels, each laid out as
    ``[x, y, w, h, objectness, class_0, ..., class_{C-1}]``.

    Args:
        output: 4-D tensor indexed as ``[batch, channel, H, W]``.
        conf_thresh: Not used by this function; kept so the signature stays
            compatible with existing callers.
        num_classes: Number of class channels per anchor.
        anchors: Flat sequence ``[w0, h0, w1, h1, ...]`` with one pair per
            anchor. The values are multiplied into the decoded sizes before
            the division by the grid size, i.e. they are in grid-cell units.
        num_anchors: Number of anchors encoded in ``output``.
        scale_x_y: Scale applied to the sigmoid-activated centre offsets,
            re-centred so that an offset of 0.5 stays at 0.5.
        only_objectness: Not used by this function.
        validation: Not used by this function.

    Returns:
        A tuple ``(boxes, confs)``:

        * ``boxes`` -- ``[batch, num_anchors * H * W, 1, 4]`` holding
          ``(x1, y1, x2, y2)`` divided by the grid width / height
          (values are not clamped).
        * ``confs`` -- ``[batch, num_anchors * H * W, num_classes]`` holding
          ``sigmoid(objectness) * sigmoid(class score)``.
    """
    group = 5 + num_classes
    starts = [a * group for a in range(num_anchors)]

    # Regroup the per-anchor channel slices by role.
    # bxy / bwh: [batch, num_anchors * 2, H, W]
    bxy = torch.cat([output[:, s: s + 2] for s in starts], dim=1)
    bwh = torch.cat([output[:, s + 2: s + 4] for s in starts], dim=1)
    # det_confs: [batch, num_anchors, H, W] -> [batch, num_anchors * H * W]
    det_confs = torch.cat([output[:, s + 4: s + 5] for s in starts], dim=1)
    det_confs = det_confs.view(output.size(0), num_anchors * output.size(2) * output.size(3))
    # cls_confs: [batch, num_anchors * num_classes, H, W]
    #         -> [batch, num_anchors * H * W, num_classes]
    cls_confs = torch.cat([output[:, s + 5: s + group] for s in starts], dim=1)
    cls_confs = cls_confs.view(output.size(0), num_anchors, num_classes, output.size(2) * output.size(3))
    cls_confs = cls_confs.permute(0, 1, 3, 2).reshape(
        output.size(0), num_anchors * output.size(2) * output.size(3), num_classes)

    # Activations: sigmoid for offsets and confidences, exp for sizes.
    bxy = torch.sigmoid(bxy) * scale_x_y - 0.5 * (scale_x_y - 1)
    bwh = torch.exp(bwh)
    det_confs = torch.sigmoid(det_confs)
    cls_confs = torch.sigmoid(cls_confs)

    # Cell index grids, shape [1, 1, H, W]. They are built with numpy and
    # converted exactly once (the conversion is a host-to-device copy on
    # accelerators). Using ``output.device`` keeps the grid next to the data
    # on any backend, not only CUDA.
    grid_x = torch.tensor(
        np.linspace(0, output.size(3) - 1, output.size(3))[None, None, None, :].repeat(output.size(2), axis=2),
        device=output.device, dtype=torch.float32)
    grid_y = torch.tensor(
        np.linspace(0, output.size(2) - 1, output.size(2))[None, None, :, None].repeat(output.size(3), axis=3),
        device=output.device, dtype=torch.float32)

    # Per-anchor decode, each result is [batch, num_anchors, H, W].
    anchor_ids = range(num_anchors)
    bx = torch.cat([bxy[:, 2 * a: 2 * a + 1] + grid_x for a in anchor_ids], dim=1)
    by = torch.cat([bxy[:, 2 * a + 1: 2 * a + 2] + grid_y for a in anchor_ids], dim=1)
    bw = torch.cat([bwh[:, 2 * a: 2 * a + 1] * anchors[2 * a] for a in anchor_ids], dim=1)
    bh = torch.cat([bwh[:, 2 * a + 1: 2 * a + 2] * anchors[2 * a + 1] for a in anchor_ids], dim=1)

    # Express everything relative to the grid size and flatten to
    # [batch, num_anchors * H * W, 1].
    bx = (bx / output.size(3)).view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
    by = (by / output.size(2)).view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
    bw = (bw / output.size(3)).view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
    bh = (bh / output.size(2)).view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)

    # Centre/size -> corner format.
    bx1 = bx - bw * 0.5
    by1 = by - bh * 0.5
    bx2 = bx1 + bw
    by2 = by1 + bh

    # boxes: [batch, num_anchors * H * W, 4] -> [batch, num_anchors * H * W, 1, 4]
    boxes = torch.cat((bx1, by1, bx2, by2), dim=2).view(
        output.size(0), num_anchors * output.size(2) * output.size(3), 1, 4)

    # confs: [batch, num_anchors * H * W, num_classes]
    det_confs = det_confs.view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
    confs = cls_confs * det_confs

    return boxes, confs
class YoloLayer(nn.Module):
    """YOLO detection head wrapper.

    In training mode the raw head tensor is passed through untouched; in
    eval mode it is decoded with ``yolo_forward_dynamic`` using the anchors
    selected by ``anchor_mask``.

    Args:
        anchor_mask: Indices of the anchors (within ``anchors``) handled by
            this layer. Defaults to a fresh empty list.
        num_classes: Number of classes predicted per anchor.
        anchors: Flat list of anchor values; it is split into
            ``num_anchors`` groups of ``len(anchors) // num_anchors`` values.
            Defaults to a fresh empty list.
        num_anchors: Total number of anchors described by ``anchors``.
        stride: Divisor applied to the selected anchor values before decoding.
        model_out: Stored on the instance only; nothing in this class reads
            it. Per the original note: during inference, whether
            post-processing happens inside or outside the model
            (True: outside).
    """

    def __init__(self, anchor_mask=None, num_classes=0, anchors=None, num_anchors=1, stride=32, model_out=False):
        super(YoloLayer, self).__init__()
        # ``None`` sentinels instead of ``[]`` defaults: a literal list default
        # is created once and would be shared by every instance that relies
        # on it. Lists passed in by the caller are stored as-is (not copied).
        self.anchor_mask = [] if anchor_mask is None else anchor_mask
        self.num_classes = num_classes
        self.anchors = [] if anchors is None else anchors
        self.num_anchors = num_anchors
        # Number of consecutive entries in ``anchors`` that belong to one anchor.
        self.anchor_step = len(self.anchors) // num_anchors
        self.coord_scale = 1
        self.noobject_scale = 1
        self.object_scale = 5
        self.class_scale = 1
        # Forwarded as ``conf_thresh`` to the decoder in ``forward``.
        self.thresh = 0.6
        self.stride = stride
        self.seen = 0
        self.scale_x_y = 1
        self.model_out = model_out

    def forward(self, output, target=None):
        """Return ``output`` unchanged when training, decoded boxes otherwise.

        Args:
            output: Raw head tensor handed to ``yolo_forward_dynamic``.
            target: Not used by this method.

        Returns:
            ``output`` itself in training mode; otherwise whatever
            ``yolo_forward_dynamic`` returns for the masked, stride-scaled
            anchors.
        """
        if self.training:
            return output

        # Collect the anchor values selected by the mask, then rescale them
        # by the layer stride.
        selected = []
        for m in self.anchor_mask:
            selected.extend(self.anchors[m * self.anchor_step:(m + 1) * self.anchor_step])
        masked_anchors = [anchor / self.stride for anchor in selected]

        return yolo_forward_dynamic(output, self.thresh, self.num_classes, masked_anchors,
                                    len(self.anchor_mask), scale_x_y=self.scale_x_y)