EasyMocap/myeasymocap/backbone/hrnet/myhrnet.py

136 lines
5.5 KiB
Python
Raw Normal View History

2023-06-19 16:39:27 +08:00
import os
import numpy as np
import math
import cv2
import torch
from ..basetopdown import BaseTopDownModelCache
from .hrnet import HRNet
def get_max_preds(batch_heatmaps):
'''
get predictions from score maps
heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
'''
assert isinstance(batch_heatmaps, np.ndarray), \
'batch_heatmaps should be numpy.ndarray'
assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim: {}'.format(batch_heatmaps.shape)
batch_size = batch_heatmaps.shape[0]
num_joints = batch_heatmaps.shape[1]
width = batch_heatmaps.shape[3]
heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))
idx = np.argmax(heatmaps_reshaped, 2)
maxvals = np.amax(heatmaps_reshaped, 2)
maxvals = maxvals.reshape((batch_size, num_joints, 1))
idx = idx.reshape((batch_size, num_joints, 1))
preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
preds[:, :, 0] = (preds[:, :, 0]) % width
preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)
pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))
pred_mask = pred_mask.astype(np.float32)
preds *= pred_mask
return preds, maxvals
COCO17_IN_BODY25 = [0,16,15,18,17,5,2,6,3,7,4,12,9,13,10,14,11]
pairs = [[1, 8], [1, 2], [1, 5], [2, 3], [3, 4], [5, 6], [6, 7], [8, 9], [9, 10], [10, 11], [8, 12], [12, 13], [13, 14], [1, 0], [0,15], [15,17], [0,16], [16,18], [14,19], [19,20], [14,21], [11,22], [22,23], [11,24]]
def coco17tobody25(points2d):
kpts = np.zeros((points2d.shape[0], 25, 3))
kpts[:, COCO17_IN_BODY25, :2] = points2d[:, :, :2]
kpts[:, COCO17_IN_BODY25, 2:3] = points2d[:, :, 2:3]
kpts[:, 8, :2] = kpts[:, [9, 12], :2].mean(axis=1)
kpts[:, 8, 2] = kpts[:, [9, 12], 2].min(axis=1)
kpts[:, 1, :2] = kpts[:, [2, 5], :2].mean(axis=1)
kpts[:, 1, 2] = kpts[:, [2, 5], 2].min(axis=1)
# 需要交换一下
# kpts = kpts[:, :, [1,0,2]]
return kpts
class MyHRNet(BaseTopDownModelCache):
2023-07-10 22:10:41 +08:00
def __init__(self, ckpt, single_person=True, num_joints=17, name='keypoints2d'):
super().__init__(name, bbox_scale=1.25, res_input=[288, 384])
# 如果启用那么将每个视角最多保留一个并且squeeze and stack
self.single_person = single_person
model = HRNet(48, num_joints, 0.1)
self.num_joints = num_joints
2023-06-19 16:39:27 +08:00
if not os.path.exists(ckpt) and ckpt.endswith('pose_hrnet_w48_384x288.pth'):
url = "11ezQ6a_MxIRtj26WqhH3V3-xPI3XqYAw"
text = '''Download `models/pytorch/pose_coco/pose_hrnet_w48_384x288.pth` from (OneDrive)[https://1drv.ms/f/s!AhIXJn_J-blW231MH2krnmLq5kkQ],
And place it into {}'''.format(os.path.dirname(ckpt))
print(text)
os.makedirs(os.path.dirname(ckpt), exist_ok=True)
cmd = 'gdown "{}" -O {}'.format(url, ckpt)
print('\n', cmd, '\n')
os.system(cmd)
assert os.path.exists(ckpt), f'{ckpt} not exists'
checkpoint = torch.load(ckpt, map_location='cpu')
model.load_state_dict(checkpoint)
model.eval()
self.model = model
self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
self.model.to(self.device)
@staticmethod
def get_max_preds(batch_heatmaps):
coords, maxvals = get_max_preds(batch_heatmaps)
heatmap_height = batch_heatmaps.shape[2]
heatmap_width = batch_heatmaps.shape[3]
# post-processing
if True:
for n in range(coords.shape[0]):
for p in range(coords.shape[1]):
hm = batch_heatmaps[n][p]
px = int(math.floor(coords[n][p][0] + 0.5))
py = int(math.floor(coords[n][p][1] + 0.5))
if 1 < px < heatmap_width-1 and 1 < py < heatmap_height-1:
diff = np.array(
[
hm[py][px+1] - hm[py][px-1],
hm[py+1][px]-hm[py-1][px]
]
)
coords[n][p] += np.sign(diff) * .25
coords = coords.astype(np.float32) * 4
pred = np.dstack((coords, maxvals))
return pred
def __call__(self, bbox, images, imgnames):
squeeze = False
if not isinstance(images, list):
images = [images]
imgnames = [imgnames]
bbox = [bbox]
squeeze = True
nViews = len(images)
kpts_all = []
for nv in range(nViews):
_bbox = bbox[nv]
if _bbox.shape[0] == 0:
2023-07-10 22:10:41 +08:00
if self.single_person:
kpts = np.zeros((1, self.num_joints, 3))
else:
kpts = np.zeros((_bbox.shape[0], self.num_joints, 3))
else:
img = images[nv]
# TODO: add flip test
out = super().__call__(_bbox, img, imgnames[nv])
output = out['params']['output']
kpts = self.get_max_preds(output)
kpts_ori = self.batch_affine_transform(kpts, out['params']['inv_trans'])
kpts = np.concatenate([kpts_ori, kpts[..., -1:]], axis=-1)
2023-06-19 16:39:27 +08:00
kpts = coco17tobody25(kpts)
kpts_all.append(kpts)
2023-07-10 22:10:41 +08:00
if self.single_person:
kpts_all = [k[0] for k in kpts_all]
kpts_all = np.stack(kpts_all)
2023-06-19 16:39:27 +08:00
if squeeze:
kpts_all = kpts_all[0]
return {
'keypoints': kpts_all
}