# EasyMocap/myeasymocap/backbone/hand2d/hand2d.py

import os
import cv2
import torch
import torch.nn as nn
import numpy as np
# Pretrained MMPose 2D hand checkpoints (HRNetv2-W18 at 256x256 input):
# https://download.openmmlab.com/mmpose/hand/hrnetv2/hrnetv2_w18_rhd2d_256x256-95b20dd8_20210330.pth
# https://download.openmmlab.com/mmpose/hand/dark/hrnetv2_w18_onehand10k_256x256_dark-a2f80c64_20210330.pth
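
# A sketch for fetching one of the checkpoints above (the local path is an
# assumption, not something this module prescribes):
#
#   import os, urllib.request
#   url = ('https://download.openmmlab.com/mmpose/hand/dark/'
#          'hrnetv2_w18_onehand10k_256x256_dark-a2f80c64_20210330.pth')
#   os.makedirs('models', exist_ok=True)
#   urllib.request.urlretrieve(url, 'models/hrnetv2_w18_onehand10k_256x256_dark.pth')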
from ..basetopdown import BaseTopDownModelCache, get_preds_from_heatmaps, gdown_models

class TopDownAsMMPose(nn.Module):
    """Wrap an HRNet backbone and a 1x1-conv head to mimic MMPose's top-down model."""
    def __init__(self, backbone, head):
        super().__init__()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        # The HRNet backbone returns one feature map per branch, at decreasing resolutions.
        feat_list = self.backbone(x)
        size = feat_list[0].shape[-2:]
        # Upsample every branch to the highest resolution and concatenate along the
        # channel axis, as HRNetV2 does before its prediction head.
        resized_inputs = [
            nn.functional.interpolate(feat, size, mode='bilinear', align_corners=False)
            for feat in feat_list
        ]
        resized_inputs = torch.cat(resized_inputs, 1)
        out = self.head(resized_inputs)
        # Decode the heatmaps into (x, y, confidence) triplets on the CPU.
        pred = get_preds_from_heatmaps(out.detach().cpu().numpy())
        return {'keypoints': pred}
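
# A minimal shape check for the wrapper above (a sketch: weights are randomly
# initialized here, so the predictions are meaningless but the shapes are real).
# HRNetV2-W18 has branch widths 18, 36, 72 and 144, so the fused feature has
# 18 + 36 + 72 + 144 = 270 channels, which is why the head built in MyHand2D
# below takes 270 input channels:
#
#   from .hrnet import PoseHighResolutionNet
#   backbone = PoseHighResolutionNet(inp_ch=3, out_ch=21, W=18,
#                                    multi_scale_final=True, add_final_layer=False)
#   head = nn.Conv2d(270, 21, kernel_size=1)
#   model = TopDownAsMMPose(backbone, head).eval()
#   with torch.no_grad():
#       out = model(torch.randn(1, 3, 256, 256))
#   out['keypoints'].shape  # (1, 21, 3): x, y, confidence per joint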

class MyHand2D(BaseTopDownModelCache):
    """Top-down 2D hand keypoint detector (21 joints) with cached per-image results."""
    def __init__(self, ckpt, url=None, mode='hrnet'):
        if mode == 'hrnet':
            super().__init__(name='hand2d', bbox_scale=1.1, res_input=256)
            from .hrnet import PoseHighResolutionNet
            backbone = PoseHighResolutionNet(inp_ch=3, out_ch=21, W=18,
                                             multi_scale_final=True, add_final_layer=False)
            checkpoint = torch.load(ckpt, map_location='cpu')['state_dict']
            self.load_checkpoint(backbone, checkpoint, prefix='backbone.', strict=True)
            # Rebuild the MMPose keypoint head: 270 fused HRNetV2-W18 channels
            # -> 21 joint heatmaps; its weights live under `keypoint_head.final_layer.`.
            head = nn.Sequential(
                nn.Conv2d(270, 270, kernel_size=1),
                nn.BatchNorm2d(270),
                nn.ReLU(inplace=True),
                nn.Conv2d(270, 21, kernel_size=1)
            )
            self.load_checkpoint(head, checkpoint, prefix='keypoint_head.final_layer.', strict=True)
            self.model = TopDownAsMMPose(backbone, head)
        elif mode == 'resnet':
            super().__init__(name='hand2d', bbox_scale=1.1, res_input=256,
                             mean=[0., 0., 0.], std=[1., 1., 1.])
            from .resnet import ResNet_Deconv
            # Download the checkpoint on first use if a URL is given.
            if not os.path.exists(ckpt) and url is not None:
                gdown_models(ckpt, url)
            assert os.path.exists(ckpt), f'{ckpt} does not exist'
            checkpoint = torch.load(ckpt, map_location='cpu')['state_dict']
            model = ResNet_Deconv()
            self.load_checkpoint(model, checkpoint, prefix='model.', strict=True)
            self.model = model
        self.model.eval()
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model.to(self.device)
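
    # Construction sketch for both modes (the paths and URL below are assumptions
    # for illustration; see the checkpoint links at the top of this file):
    #
    #   hand2d = MyHand2D(ckpt='models/hrnetv2_w18_onehand10k_256x256_dark.pth',
    #                     mode='hrnet')
    #   hand2d = MyHand2D(ckpt='models/hand2d_resnet.pth',
    #                     url='<google-drive-url>', mode='resnet')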

    def __call__(self, bbox, images, imgnames):
        # Accept a single view or a list of views; remember whether to squeeze back.
        squeeze = False
        if not isinstance(images, list):
            images = [images]
            imgnames = [imgnames]
            bbox = [bbox]
            squeeze = True
        nViews = len(images)
        kpts_all = []
        for nv in range(nViews):
            if bbox[nv].shape[0] == 0:
                # No detection in this view: fall back to zero-confidence keypoints.
                kpts_all.append(np.zeros((21, 3)))
                continue
            _bbox = bbox[nv]
            if len(_bbox.shape) == 1:
                _bbox = _bbox[None]
            output = super().__call__(_bbox, images[nv], imgnames[nv])
            kpts = output['params']['keypoints']
            conf = kpts[..., -1:]
            # Map keypoints from the cropped patch back into original image coordinates.
            kpts = self.batch_affine_transform(kpts, output['params']['inv_trans'])
            kpts = np.concatenate([kpts, conf], axis=-1)
            if len(kpts.shape) == 3:
                kpts = kpts[0]
            kpts_all.append(kpts)
        kpts_all = np.stack(kpts_all)
        if squeeze:
            kpts_all = kpts_all[0]
        return {
            'keypoints': kpts_all
        }
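
# A minimal usage sketch, guarded so importing this module stays side-effect free.
# The image path, checkpoint path, and bounding box are assumptions for
# illustration; the bbox is a (left, top, right, bottom, score) array as used
# elsewhere in EasyMocap.
if __name__ == '__main__':
    img = cv2.imread('demo/hand.jpg')
    bbox = np.array([[100., 100., 300., 300., 1.]])
    hand2d = MyHand2D(ckpt='models/hrnetv2_w18_onehand10k_256x256_dark.pth', mode='hrnet')
    out = hand2d(bbox, img, 'demo/hand.jpg')
    print(out['keypoints'].shape)  # (21, 3): x, y, confidence in image coordinates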