diff --git a/myeasymocap/backbone/basetopdown.py b/myeasymocap/backbone/basetopdown.py index e85e2a8..4fc1503 100644 --- a/myeasymocap/backbone/basetopdown.py +++ b/myeasymocap/backbone/basetopdown.py @@ -49,6 +49,37 @@ def gen_trans_from_patch_cv(c_x, c_y, src_width, src_height, dst_width, dst_heig return trans, inv_trans +# TODO: add UDP +def get_warp_matrix(theta, size_input, size_dst, size_target): + """Calculate the transformation matrix under the constraint of unbiased. + Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased + Data Processing for Human Pose Estimation (CVPR 2020). + + Args: + theta (float): Rotation angle in degrees. + size_input (np.ndarray): Size of input image [w, h]. + size_dst (np.ndarray): Size of output image [w, h]. + size_target (np.ndarray): Size of ROI in input plane [w, h]. + + Returns: + np.ndarray: A matrix for transformation. + """ + theta = np.deg2rad(theta) + matrix = np.zeros((2, 3), dtype=np.float32) + scale_x = size_dst[0] / size_target[0] + scale_y = size_dst[1] / size_target[1] + matrix[0, 0] = math.cos(theta) * scale_x + matrix[0, 1] = -math.sin(theta) * scale_x + matrix[0, 2] = scale_x * (-0.5 * size_input[0] * math.cos(theta) + + 0.5 * size_input[1] * math.sin(theta) + + 0.5 * size_target[0]) + matrix[1, 0] = math.sin(theta) * scale_y + matrix[1, 1] = math.cos(theta) * scale_y + matrix[1, 2] = scale_y * (-0.5 * size_input[0] * math.sin(theta) - + 0.5 * size_input[1] * math.cos(theta) + + 0.5 * size_target[1]) + return matrix + def generate_patch_image_cv(cvimg, c_x, c_y, bb_width, bb_height, patch_width, patch_height, do_flip, scale, rot): trans, inv_trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot, inv=False) @@ -75,8 +106,8 @@ def get_single_image_crop_demo(image, bbox, scale=1.2, crop_size=224, ) if fliplr: crop_image = cv2.flip(crop_image, 1) - # cv2.imwrite('debug_crop.jpg', crop_image) - # import ipdb; ipdb.set_trace() + # cv2.imwrite('debug_crop.jpg', crop_image[:,:,::-1]) + # cv2.imwrite('debug_crop_full.jpg', image[:,:,::-1]) crop_image = crop_image.transpose(2,0,1) mean1=np.array(mean, dtype=np.float32).reshape(3,1,1) std1= np.array(std, dtype=np.float32).reshape(3,1,1) @@ -123,6 +154,14 @@ class BaseTopDownModel(nn.Module): squeeze = True # TODO: 兼容多张图片的 bbox = xyxy2ccwh(bbox) + # convert the bbox to the aspect of input bbox + aspect_ratio = self.crop_size[1] / self.crop_size[0] + w, h = bbox[:, 2], bbox[:, 3] + # 如果height大于w*ratio,那么增大w + flag = h > aspect_ratio * w + bbox[flag, 2] = h[flag] / aspect_ratio + # 否则增大h + bbox[~flag, 3] = w[~flag] * aspect_ratio inputs = [] inv_trans_ = [] for i in range(bbox.shape[0]): @@ -141,6 +180,15 @@ class BaseTopDownModel(nn.Module): ) inputs.append(norm_img) inv_trans_.append(inv_trans) + if False: + vis = np.hstack(inputs) + mean, std = np.array(self.mean), np.array(self.std) + mean = mean.reshape(3, 1, 1) + std = std.reshape(3, 1, 1) + vis = (vis * std) + mean + vis = vis.transpose(1, 2, 0) + vis = (vis[:, :, ::-1] * 255).astype(np.uint8) + cv2.imwrite('debug_crop.jpg', vis) inputs = np.stack(inputs) inv_trans_ = np.stack(inv_trans_) inputs = torch.FloatTensor(inputs).to(self.device) @@ -168,17 +216,30 @@ class BaseTopDownModelCache(BaseTopDownModel): super().__init__(**kwargs) self.name = name - def __call__(self, bbox, images, imgname, flips=None): + def cachename(self, imgname): basename = os.sep.join(imgname.split(os.sep)[-2:]) cachename = join(self.output, self.name, basename.replace('.jpg', '.pkl')) + return cachename + + def dump(self, cachename, output): os.makedirs(os.path.dirname(cachename), exist_ok=True) + with open(cachename, 'wb') as f: + pickle.dump(output, f) + return output + + def load(self, cachename): + with open(cachename, 'rb') as f: + output = pickle.load(f) + return output + + def __call__(self, bbox, images, imgname, flips=None): + cachename = self.cachename(imgname) if os.path.exists(cachename): - with open(cachename, 'rb') as f: - output = pickle.load(f) + output = self.load(cachename) else: output = self.infer(images, bbox, to_numpy=True, flips=flips) - with open(cachename, 'wb') as f: - pickle.dump(output, f) + output = self.dump(cachename, output) + ret = { 'params': output } diff --git a/myeasymocap/backbone/hrnet/myhrnet.py b/myeasymocap/backbone/hrnet/myhrnet.py index 178db28..05ad4f5 100644 --- a/myeasymocap/backbone/hrnet/myhrnet.py +++ b/myeasymocap/backbone/hrnet/myhrnet.py @@ -51,9 +51,12 @@ def coco17tobody25(points2d): return kpts class MyHRNet(BaseTopDownModelCache): - def __init__(self, ckpt): - super().__init__(name='hand2d', bbox_scale=1.25, res_input=[288, 384]) - model = HRNet(48, 17, 0.1) + def __init__(self, ckpt, single_person=True, num_joints=17, name='keypoints2d'): + super().__init__(name, bbox_scale=1.25, res_input=[288, 384]) + # 如果启用,那么将每个视角最多保留一个,并且squeeze and stack + self.single_person = single_person + model = HRNet(48, num_joints, 0.1) + self.num_joints = num_joints if not os.path.exists(ckpt) and ckpt.endswith('pose_hrnet_w48_384x288.pth'): url = "11ezQ6a_MxIRtj26WqhH3V3-xPI3XqYAw" text = '''Download `models/pytorch/pose_coco/pose_hrnet_w48_384x288.pth` from (OneDrive)[https://1drv.ms/f/s!AhIXJn_J-blW231MH2krnmLq5kkQ], @@ -109,20 +112,23 @@ class MyHRNet(BaseTopDownModelCache): for nv in range(nViews): _bbox = bbox[nv] if _bbox.shape[0] == 0: - kpts_all.append(np.zeros((17, 3))) - continue - img = images[nv] - # TODO: add flip test - out = super().__call__(_bbox, img, imgnames[nv]) - output = out['params']['output'] - kpts = self.get_max_preds(output) - kpts_ori = self.batch_affine_transform(kpts, out['params']['inv_trans']) - kpts = np.concatenate([kpts_ori, kpts[..., -1:]], axis=-1) + if self.single_person: + kpts = np.zeros((1, self.num_joints, 3)) + else: + kpts = np.zeros((_bbox.shape[0], self.num_joints, 3)) + else: + img = images[nv] + # TODO: add flip test + out = super().__call__(_bbox, img, imgnames[nv]) + output = out['params']['output'] + kpts = self.get_max_preds(output) + kpts_ori = self.batch_affine_transform(kpts, out['params']['inv_trans']) + kpts = np.concatenate([kpts_ori, kpts[..., -1:]], axis=-1) kpts = coco17tobody25(kpts) - if len(kpts.shape) == 3: - kpts = kpts[0] kpts_all.append(kpts) - kpts_all = np.stack(kpts_all) + if self.single_person: + kpts_all = [k[0] for k in kpts_all] + kpts_all = np.stack(kpts_all) if squeeze: kpts_all = kpts_all[0] return { diff --git a/myeasymocap/backbone/topdown_keypoints.py b/myeasymocap/backbone/topdown_keypoints.py new file mode 100644 index 0000000..99a83b6 --- /dev/null +++ b/myeasymocap/backbone/topdown_keypoints.py @@ -0,0 +1,94 @@ +import math +import numpy as np + +def get_max_preds(batch_heatmaps): + ''' + get predictions from score maps + heatmaps: numpy.ndarray([batch_size, num_joints, height, width]) + ''' + assert isinstance(batch_heatmaps, np.ndarray), \ + 'batch_heatmaps should be numpy.ndarray' + assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim: {}'.format(batch_heatmaps.shape) + + batch_size = batch_heatmaps.shape[0] + num_joints = batch_heatmaps.shape[1] + width = batch_heatmaps.shape[3] + heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1)) + idx = np.argmax(heatmaps_reshaped, 2) + maxvals = np.amax(heatmaps_reshaped, 2) + + maxvals = maxvals.reshape((batch_size, num_joints, 1)) + idx = idx.reshape((batch_size, num_joints, 1)) + + preds = np.tile(idx, (1, 1, 2)).astype(np.float32) + + preds[:, :, 0] = (preds[:, :, 0]) % width + preds[:, :, 1] = np.floor((preds[:, :, 1]) / width) + + pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2)) + pred_mask = pred_mask.astype(np.float32) + + preds *= pred_mask + return preds, maxvals + +COCO17_IN_BODY25 = [0,16,15,18,17,5,2,6,3,7,4,12,9,13,10,14,11] +pairs = [[1, 8], [1, 2], [1, 5], [2, 3], [3, 4], [5, 6], [6, 7], [8, 9], [9, 10], [10, 11], [8, 12], [12, 13], [13, 14], [1, 0], [0,15], [15,17], [0,16], [16,18], [14,19], [19,20], [14,21], [11,22], [22,23], [11,24]] +def coco17tobody25(points2d): + kpts = np.zeros((points2d.shape[0], 25, 3)) + kpts[:, COCO17_IN_BODY25, :2] = points2d[:, :, :2] + kpts[:, COCO17_IN_BODY25, 2:3] = points2d[:, :, 2:3] + kpts[:, 8, :2] = kpts[:, [9, 12], :2].mean(axis=1) + kpts[:, 8, 2] = kpts[:, [9, 12], 2].min(axis=1) + kpts[:, 1, :2] = kpts[:, [2, 5], :2].mean(axis=1) + kpts[:, 1, 2] = kpts[:, [2, 5], 2].min(axis=1) + # 需要交换一下 + # kpts = kpts[:, :, [1,0,2]] + return kpts + +def coco23tobody25(points2d): + kpts = coco17tobody25(points2d[:, :17]) + kpts[:, [19, 20, 21, 22, 23, 24]] = points2d[:, [17, 18, 19, 20, 21, 22]] + return kpts + +class BaseKeypoints(): + @staticmethod + def get_max_preds(batch_heatmaps): + coords, maxvals = get_max_preds(batch_heatmaps) + + heatmap_height = batch_heatmaps.shape[2] + heatmap_width = batch_heatmaps.shape[3] + + # post-processing + if True: + for n in range(coords.shape[0]): + for p in range(coords.shape[1]): + hm = batch_heatmaps[n][p] + px = int(math.floor(coords[n][p][0] + 0.5)) + py = int(math.floor(coords[n][p][1] + 0.5)) + if 1 < px < heatmap_width-1 and 1 < py < heatmap_height-1: + diff = np.array( + [ + hm[py][px+1] - hm[py][px-1], + hm[py+1][px]-hm[py-1][px] + ] + ) + coords[n][p] += np.sign(diff) * .25 + coords = coords.astype(np.float32) * 4 + pred = np.dstack((coords, maxvals)) + return pred + + @staticmethod + def batch_affine_transform(points, trans): + # points: (Bn, J, 2), trans: (Bn, 2, 3) + points = np.dstack((points[..., :2], np.ones((*points.shape[:-1], 1)))) + out = np.matmul(points, trans.swapaxes(-1, -2)) + return out + + @staticmethod + def coco17tobody25(points2d): + return coco17tobody25(points2d) + + @staticmethod + def coco23tobody25(points2d): + return coco23tobody25(points2d) + \ No newline at end of file diff --git a/myeasymocap/backbone/vitpose/layers.py b/myeasymocap/backbone/vitpose/layers.py new file mode 100644 index 0000000..a1daefe --- /dev/null +++ b/myeasymocap/backbone/vitpose/layers.py @@ -0,0 +1,98 @@ +import torch +import math +import collections.abc +from itertools import repeat +import warnings +import torch.nn as nn + +def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are + applied while sampling the normal with mean/std applied, therefore a, b args + should be adjusted to match the range of mean, std args. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + with torch.no_grad(): + return _trunc_normal_(tensor, mean, std, a, b) \ No newline at end of file diff --git a/myeasymocap/backbone/vitpose/vit_moe.py b/myeasymocap/backbone/vitpose/vit_moe.py new file mode 100644 index 0000000..0e8f80c --- /dev/null +++ b/myeasymocap/backbone/vitpose/vit_moe.py @@ -0,0 +1,607 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import numpy as np +import torch +from functools import partial +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from .layers import drop_path, to_2tuple, trunc_normal_ + +def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + +def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True): + """ + Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token + dimension for the original embeddings. + Args: + abs_pos (Tensor): absolute positional embeddings with (1, num_position, C). + has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token. + hw (Tuple): size of input image tokens. + + Returns: + Absolute positional embeddings after processing with shape (1, H, W, C) + """ + cls_token = None + B, L, C = abs_pos.shape + if has_cls_token: + cls_token = abs_pos[:, 0:1] + abs_pos = abs_pos[:, 1:] + + if ori_h != h or ori_w != w: + new_abs_pos = F.interpolate( + abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2), + size=(h, w), + mode="bicubic", + align_corners=False, + ).permute(0, 2, 3, 1).reshape(B, -1, C) + + else: + new_abs_pos = abs_pos + + if cls_token is not None: + new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1) + return new_abs_pos + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self): + return 'p={}'.format(self.drop_prob) + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + x = self.drop(x) + return x + +class MoEMlp(nn.Module): + def __init__(self, num_expert=1, in_features=1024, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., part_features=256): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.part_features = part_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features - part_features) + self.drop = nn.Dropout(drop) + + self.num_expert = num_expert + experts = [] + + for i in range(num_expert): + experts.append( + nn.Linear(hidden_features, part_features) + ) + self.experts = nn.ModuleList(experts) + + def forward(self, x, indices): + + expert_x = torch.zeros_like(x[:, :, -self.part_features:], device=x.device, dtype=x.dtype) + + x = self.fc1(x) + x = self.act(x) + shared_x = self.fc2(x) + indices = indices.view(-1, 1, 1) + + # to support ddp training + for i in range(self.num_expert): + selectedIndex = (indices == i) + current_x = self.experts[i](x) * selectedIndex + expert_x = expert_x + current_x + + x = torch.cat([shared_x, expert_x], dim=-1) + + return x + +class Attention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0., attn_head_dim=None,): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.dim = dim + + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + + return x + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, + norm_layer=nn.LayerNorm, attn_head_dim=None, num_expert=1, part_features=None + ): + super().__init__() + + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim + ) + + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MoEMlp(num_expert=num_expert, in_features=dim, hidden_features=mlp_hidden_dim, + act_layer=act_layer, drop=drop, part_features=part_features) + + def forward(self, x, indices=None): + + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x), indices)) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2) + self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio)) + self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1])) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1)) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + x = self.proj(x) + Hp, Wp = x.shape[2], x.shape[3] + + x = x.flatten(2).transpose(1, 2) + return x, (Hp, Wp) + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. + """ + def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + feature_dim = self.backbone.feature_info.channels()[-1] + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Linear(feature_dim, embed_dim) + + def forward(self, x): + x = self.backbone(x)[-1] + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + +class ViTMoE(nn.Module): + def __init__(self, + img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False, + frozen_stages=-1, ratio=1, last_norm=True, + patch_padding='pad', freeze_attn=False, freeze_ffn=False, + num_expert=1, part_features=None + ): + # Protect mutable default arguments + super(ViTMoE, self).__init__() + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.frozen_stages = frozen_stages + self.use_checkpoint = use_checkpoint + self.patch_padding = patch_padding + self.freeze_attn = freeze_attn + self.freeze_ffn = freeze_ffn + self.depth = depth + + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio) + num_patches = self.patch_embed.num_patches + + self.part_features = part_features + + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + num_expert=num_expert, part_features=part_features + ) + for i in range(depth)]) + + self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity() + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + + self._freeze_stages() + + def _freeze_stages(self): + """Freeze parameters.""" + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = self.blocks[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + if self.freeze_attn: + for i in range(0, self.depth): + m = self.blocks[i] + m.attn.eval() + m.norm1.eval() + for param in m.attn.parameters(): + param.requires_grad = False + for param in m.norm1.parameters(): + param.requires_grad = False + + if self.freeze_ffn: + self.pos_embed.requires_grad = False + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + for i in range(0, self.depth): + m = self.blocks[i] + m.mlp.eval() + m.norm2.eval() + for param in m.mlp.parameters(): + param.requires_grad = False + for param in m.norm2.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + super().init_weights(pretrained, patch_padding=self.patch_padding, part_features=self.part_features) + + if pretrained is None: + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + + def get_num_layers(self): + return len(self.blocks) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def forward_features(self, x, dataset_source=None): + B, C, H, W = x.shape + x, (Hp, Wp) = self.patch_embed(x) + + if self.pos_embed is not None: + # fit for multiple GPU training + # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference + x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1] + + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, dataset_source) + else: + x = blk(x, dataset_source) + + x = self.last_norm(x) + + xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous() + + return xp + + def forward(self, x, dataset_source=None): + x = self.forward_features(x, dataset_source) + return x + + def train(self, mode=True): + """Convert the model into training mode.""" + super().train(mode) + self._freeze_stages() + +class Head(nn.Module): + def __init__(self, in_channels, + out_channels, + num_deconv_layers=3, + num_deconv_filters=(256, 256, 256), + num_deconv_kernels=(4, 4, 4),): + super().__init__() + self.in_channels = in_channels + self.deconv_layers = self._make_deconv_layer(num_deconv_layers, num_deconv_filters, num_deconv_kernels) + self.final_layer = nn.Conv2d(in_channels=num_deconv_filters[-1], out_channels=out_channels, + kernel_size=1, stride=1, padding=0) + + def _make_deconv_layer(self, num_layers, num_filters, num_kernels): + layers = [] + for i in range(num_layers): + kernel, padding, output_padding = \ + self._get_deconv_cfg(num_kernels[i]) + + planes = num_filters[i] + layers.append( + nn.ConvTranspose2d( + in_channels=self.in_channels, + out_channels=planes, + kernel_size=kernel, + stride=2, + padding=padding, + output_padding=output_padding, + bias=False)) + layers.append(nn.BatchNorm2d(planes)) + layers.append(nn.ReLU(inplace=True)) + self.in_channels = planes + + return nn.Sequential(*layers) + + @staticmethod + def _get_deconv_cfg(deconv_kernel): + """Get configurations for deconv layers.""" + if deconv_kernel == 4: + padding = 1 + output_padding = 0 + elif deconv_kernel == 3: + padding = 1 + output_padding = 1 + elif deconv_kernel == 2: + padding = 0 + output_padding = 0 + else: + raise ValueError(f'Not supported num_kernels ({deconv_kernel}).') + + return deconv_kernel, padding, output_padding + + def forward(self, x): + """Forward function.""" + x = self.deconv_layers(x) + x = self.final_layer(x) + return x + + +class ComposeVit(nn.Module): + def __init__(self): + super().__init__() + cfg_backbone = dict( + img_size=(256, 192), + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + ratio=1, + use_checkpoint=False, + mlp_ratio=4, + qkv_bias=True, + drop_path_rate=0.3, + num_expert=6, + part_features=192 + ) + cfg_head = dict( + in_channels=768, + out_channels=17, + num_deconv_layers=2, + num_deconv_filters=(256, 256), + num_deconv_kernels=(4, 4), + ) + cfg_head_133 = dict( + in_channels=768, + out_channels=133, + num_deconv_layers=2, + num_deconv_filters=(256, 256), + num_deconv_kernels=(4, 4), + ) + self.backbone = ViTMoE(**cfg_backbone) + self.keypoint_head = Head(**cfg_head) + self.associate_head = Head(**cfg_head_133) + + def forward(self, x): + indices = torch.zeros((x.shape[0]), dtype=torch.long, device=x.device) + back_out = self.backbone(x, indices) + out = self.keypoint_head(back_out) + if True: + indices += 5 # 最后一个是whole body dataset + back_133 = self.backbone(x, indices) + out_133 = self.associate_head(back_133) + out_foot = out_133[:, 17:23] + out = torch.cat([out, out_foot], dim=1) + if False: + import cv2 + vis = x[0].permute(1, 2, 0).cpu().numpy() + mean= np.array([0.485, 0.456, 0.406]).reshape(1, 1, 3) + std=np.array([0.229, 0.224, 0.225]).reshape(1, 1 ,3) + vis = np.clip(vis * std + mean, 0., 1.) + vis = (vis[:,:,::-1] * 255).astype(np.uint8) + value = out_133[0].detach().cpu().numpy() + vis_all = [] + for i in range(value.shape[0]): + _val = np.clip(value[i], 0., 1.) + _val = (_val * 255).astype(np.uint8) + _val = cv2.resize(_val, None, fx=4, fy=4) + _val = cv2.applyColorMap(_val, cv2.COLORMAP_JET) + _vis = cv2.addWeighted(vis, 0.5, _val, 0.5, 0) + vis_all.append(_vis) + from easymocap.mytools.vis_base import merge + cv2.imwrite('debug.jpg', merge(vis_all)) + + import ipdb; ipdb.set_trace() + return { + 'output': out + } + +from ..basetopdown import BaseTopDownModelCache +from ..topdown_keypoints import BaseKeypoints + +class MyViT(BaseTopDownModelCache, BaseKeypoints): + def __init__(self, ckpt='data/models/vitpose+_base.pth', single_person=True, url='https://1drv.ms/u/s!AimBgYV7JjTlgcckRZk1bIAuRa_E1w?e=ylDB2G', **kwargs): + super().__init__(name='myvit', bbox_scale=1.25, + res_input=[192, 256], **kwargs) + self.single_person = single_person + model = ComposeVit() + if not os.path.exists(ckpt): + print('') + print('{} not exists, please download it from {} and place it to {}'.format(ckpt, url, ckpt)) + print('') + raise FileNotFoundError + ckpt = torch.load(ckpt, map_location='cpu')['state_dict'] + ckpt_backbone = {key:val for key, val in ckpt.items() if key.startswith('backbone.')} + ckpt_head = {key:val for key, val in ckpt.items() if key.startswith('keypoint_head.')} + key_whole = 'associate_keypoint_heads.4.' + ckpt_head_133 = {key.replace(key_whole, 'associate_head.'):val for key, val in ckpt.items() if key.startswith(key_whole)} + ckpt_backbone.update(ckpt_head) + ckpt_backbone.update(ckpt_head_133) + state_dict = ckpt_backbone + self.load_checkpoint(model, state_dict, prefix='', strict=True) + model.eval() + self.model = model + self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + self.model.to(self.device) + + def dump(self, cachename, output): + _output = output['output'] + kpts = self.get_max_preds(_output) + kpts_ori = self.batch_affine_transform(kpts, output['inv_trans']) + kpts = np.concatenate([kpts_ori, kpts[..., -1:]], axis=-1) + output = {'keypoints': kpts} + super().dump(cachename, output) + return output + + def estimate_keypoints(self, bbox, images, imgnames): + squeeze = False + if not isinstance(images, list): + images = [images] + imgnames = [imgnames] + bbox = [bbox] + squeeze = True + nViews = len(images) + kpts_all = [] + for nv in range(nViews): + _bbox = bbox[nv] + if _bbox.shape[0] == 0: + if self.single_person: + kpts = np.zeros((1, self.num_joints, 3)) + else: + kpts = np.zeros((_bbox.shape[0], self.num_joints, 3)) + else: + img = images[nv] + # TODO: add flip test + out = super().__call__(_bbox, img, imgnames[nv]) + kpts = out['params']['keypoints'] + if kpts.shape[-2] == 23: + kpts = self.coco23tobody25(kpts) + elif kpts.shape[-2] == 17: + kpts = self.coco17tobody25(kpts) + else: + raise NotImplementedError + kpts_all.append(kpts) + if self.single_person: + kpts_all = [k[0] for k in kpts_all] + kpts_all = np.stack(kpts_all) + if squeeze: + kpts_all = kpts_all[0] + return { + 'keypoints': kpts_all + } + + def __call__(self, bbox, images, imgnames): + return self.estimate_keypoints(bbox, images, imgnames) + +if __name__ == '__main__': + # Load checkpoint + rand_input = torch.rand(1, 3, 256, 192) + model = MyViT() diff --git a/myeasymocap/backbone/yolo/yolo.py b/myeasymocap/backbone/yolo/yolo.py index f32bc44..954fdff 100644 --- a/myeasymocap/backbone/yolo/yolo.py +++ b/myeasymocap/backbone/yolo/yolo.py @@ -145,6 +145,21 @@ class YoloWithTrack(BaseYOLOv5): self.track_cache[sub]['bbox'].append(select) return select +class MultiPerson(BaseYOLOv5): + def __init__(self, min_length, max_length, **kwargs): + super().__init__(**kwargs) + self.min_length = min_length + self.max_length = max_length + print('[{}] Only keep the bbox in [{}, {}]'.format(self.__class__.__name__, min_length, max_length)) + + def select_bbox(self, select, imgname): + if select.shape[0] == 0: + return select + # 判断一下面积 + area = np.sqrt((select[:, 2] - select[:, 0])*(select[:, 3]-select[:, 1])) + valid = (area > self.min_length) & (area < self.max_length) + return select[valid] + class DetectToPelvis: def __init__(self, key) -> None: self.key = key