🚀 add model check point
This commit is contained in:
parent
e7800a1356
commit
11f13d6953
@ -26,7 +26,8 @@ args:
|
||||
key_from_previous: [bbox]
|
||||
args:
|
||||
# ckpt: /nas/public/EasyMocapModels/hrnetv2_w18_coco_wholebody_hand_256x256-1c028db7_20210908.pth
|
||||
ckpt: /nas/public/EasyMocapModels/hand/resnet_kp2d_clean.pt
|
||||
ckpt: models/hand_resnet_kp2d_clean.pt
|
||||
url: 1LTK7e9oAS6B3drmQyXwTZild6k87fEZa
|
||||
mode: resnet
|
||||
vis2d:
|
||||
module: myeasymocap.io.vis.Vis2D
|
||||
@ -42,13 +43,14 @@ args:
|
||||
key_from_previous: [bbox]
|
||||
key_keep: [meta, cameras, imgnames] # 将这些参数都保留到最后的输出中
|
||||
args:
|
||||
ckpt: models/manol_pca45_noflat.ckpt
|
||||
ckpt: models/hand_manol_pca45_noflat.ckpt
|
||||
url: '1KTi_oJ_udLRK3WZ3xyHzBUd6vKAApfT8'
|
||||
# TODO: add visualize for Init MANO
|
||||
at_final:
|
||||
load_hand_model: # 载入身体模型
|
||||
module: myeasymocap.io.model.MANOLoader
|
||||
args:
|
||||
cfg_path: config/model/mano.yml
|
||||
cfg_path: config/model/manol.yml
|
||||
model_path: models/manov1.2/MANO_LEFT.pkl #models/handmesh/data/MANO_RIGHT.pkl # load mano model
|
||||
regressor_path: models/manov1.2/J_regressor_mano_LEFT.txt #models/handmesh/data/J_regressor_mano_RIGHT.txt
|
||||
num_pca_comps: 45
|
||||
|
@ -26,7 +26,9 @@ args:
|
||||
key_from_previous: [bbox]
|
||||
args:
|
||||
# ckpt: /nas/public/EasyMocapModels/hrnetv2_w18_coco_wholebody_hand_256x256-1c028db7_20210908.pth
|
||||
ckpt: /nas/public/EasyMocapModels/hand/resnet_kp2d_clean.pt
|
||||
# ckpt: /nas/public/EasyMocapModels/hand/resnet_kp2d_clean.pt
|
||||
ckpt: models/hand_resnet_kp2d_clean.pt
|
||||
url: 1LTK7e9oAS6B3drmQyXwTZild6k87fEZa
|
||||
mode: resnet
|
||||
vis2d:
|
||||
module: myeasymocap.io.vis.Vis2D
|
||||
@ -42,13 +44,14 @@ args:
|
||||
key_from_previous: [bbox]
|
||||
key_keep: [meta, cameras, imgnames] # 将这些参数都保留到最后的输出中
|
||||
args:
|
||||
ckpt: models/manol_pca45_noflat.ckpt
|
||||
ckpt: models/hand_manol_pca45_noflat.ckpt
|
||||
url: '1KTi_oJ_udLRK3WZ3xyHzBUd6vKAApfT8'
|
||||
# TODO: add visualize for Init MANO
|
||||
at_final:
|
||||
load_hand_model: # 载入身体模型
|
||||
module: myeasymocap.io.model.MANOLoader
|
||||
args:
|
||||
cfg_path: config/model/mano.yml
|
||||
cfg_path: config/model/manol.yml
|
||||
model_path: models/manov1.2/MANO_LEFT.pkl #models/handmesh/data/MANO_RIGHT.pkl # load mano model
|
||||
regressor_path: models/manov1.2/J_regressor_mano_LEFT.txt #models/handmesh/data/J_regressor_mano_RIGHT.txt
|
||||
num_pca_comps: 45
|
||||
|
@ -33,11 +33,13 @@ class FileStorage(object):
|
||||
self._write(' rows: {}'.format(value.shape[0]))
|
||||
self._write(' cols: {}'.format(value.shape[1]))
|
||||
self._write(' dt: d')
|
||||
self._write(' data: [{}]'.format(', '.join(['{:.3f}'.format(i) for i in value.reshape(-1)])))
|
||||
self._write(' data: [{}]'.format(', '.join(['{:.6f}'.format(i) for i in value.reshape(-1)])))
|
||||
elif dt == 'list':
|
||||
self._write('{}:'.format(key))
|
||||
for elem in value:
|
||||
self._write(' - "{}"'.format(elem))
|
||||
elif dt == 'int':
|
||||
self._write('{}: {}'.format(key, value))
|
||||
|
||||
def read(self, key, dt='mat'):
|
||||
if dt == 'mat':
|
||||
@ -52,6 +54,8 @@ class FileStorage(object):
|
||||
if val != 'none':
|
||||
results.append(val)
|
||||
output = results
|
||||
elif dt == 'int':
|
||||
output = int(self.fs.getNode(key).real())
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return output
|
||||
@ -114,6 +118,13 @@ def read_camera(intri_name, extri_name, cam_names=[]):
|
||||
cams[cam] = {}
|
||||
cams[cam]['K'] = intri.read('K_{}'.format( cam))
|
||||
cams[cam]['invK'] = np.linalg.inv(cams[cam]['K'])
|
||||
H = intri.read('H_{}'.format(cam), dt='int')
|
||||
W = intri.read('W_{}'.format(cam), dt='int')
|
||||
if H is None or W is None:
|
||||
print('[camera] no H or W for {}'.format(cam))
|
||||
H, W = -1, -1
|
||||
cams[cam]['H'] = H
|
||||
cams[cam]['W'] = W
|
||||
Rvec = extri.read('R_{}'.format(cam))
|
||||
Tvec = extri.read('T_{}'.format(cam))
|
||||
assert Rvec is not None, cam
|
||||
@ -129,6 +140,10 @@ def read_camera(intri_name, extri_name, cam_names=[]):
|
||||
cams[cam]['P'] = P[cam]
|
||||
|
||||
cams[cam]['dist'] = intri.read('dist_{}'.format(cam))
|
||||
if cams[cam]['dist'] is None:
|
||||
cams[cam]['dist'] = intri.read('D_{}'.format(cam))
|
||||
if cams[cam]['dist'] is None:
|
||||
print('[camera] no dist for {}'.format(cam))
|
||||
cams['basenames'] = cam_names
|
||||
return cams
|
||||
|
||||
@ -155,6 +170,9 @@ def write_camera(camera, path):
|
||||
key = key_.split('.')[0]
|
||||
intri.write('K_{}'.format(key), val['K'])
|
||||
intri.write('dist_{}'.format(key), val['dist'])
|
||||
if 'H' in val.keys() and 'W' in val.keys():
|
||||
intri.write('H_{}'.format(key), val['H'], dt='int')
|
||||
intri.write('W_{}'.format(key), val['W'], dt='int')
|
||||
if 'Rvec' not in val.keys():
|
||||
val['Rvec'] = cv2.Rodrigues(val['R'])[0]
|
||||
extri.write('R_{}'.format(key), val['Rvec'])
|
||||
@ -174,7 +192,7 @@ def camera_from_img(img):
|
||||
class Undistort:
|
||||
distortMap = {}
|
||||
@classmethod
|
||||
def image(cls, frame, K, dist, sub=None):
|
||||
def image(cls, frame, K, dist, sub=None, interp=cv2.INTER_NEAREST):
|
||||
if sub is None:
|
||||
return cv2.undistort(frame, K, dist, None)
|
||||
else:
|
||||
@ -183,7 +201,7 @@ class Undistort:
|
||||
mapx, mapy = cv2.initUndistortRectifyMap(K, dist, None, K, (w,h), 5)
|
||||
cls.distortMap[sub] = (mapx, mapy)
|
||||
mapx, mapy = cls.distortMap[sub]
|
||||
img = cv2.remap(frame, mapx, mapy, cv2.INTER_NEAREST)
|
||||
img = cv2.remap(frame, mapx, mapy, interp)
|
||||
return img
|
||||
|
||||
@staticmethod
|
||||
@ -203,6 +221,21 @@ class Undistort:
|
||||
bbox = np.array([kpts[0, 0], kpts[0, 1], kpts[1, 0], kpts[1, 1], bbox[4]])
|
||||
return bbox
|
||||
|
||||
class Distort:
|
||||
@staticmethod
|
||||
def points(keypoints, K, dist):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def bbox(bbox, K, dist):
|
||||
keypoints = np.array([[bbox[0], bbox[1]], [bbox[2], bbox[3]]], dtype=np.float32)
|
||||
k3d = cv2.convertPointsToHomogeneous(keypoints)
|
||||
k3d = (np.linalg.inv(K) @ k3d[:, 0].T).T[:, None]
|
||||
k2d, _ = cv2.projectPoints(k3d, np.zeros((3,)), np.zeros((3,)), K, dist)
|
||||
k2d = k2d[:, 0]
|
||||
bbox = np.array([k2d[0,0], k2d[0,1], k2d[1, 0], k2d[1, 1], bbox[-1]])
|
||||
return bbox
|
||||
|
||||
def unproj(kpts, invK):
|
||||
homo = np.hstack([kpts[:, :2], np.ones_like(kpts[:, :1])])
|
||||
homo = homo @ invK.T
|
||||
|
@ -2,14 +2,15 @@
|
||||
@ Date: 2020-11-28 17:23:04
|
||||
@ Author: Qing Shuai
|
||||
@ LastEditors: Qing Shuai
|
||||
@ LastEditTime: 2022-08-12 21:50:56
|
||||
@ LastEditTime: 2022-10-27 15:13:56
|
||||
@ FilePath: /EasyMocapPublic/easymocap/mytools/vis_base.py
|
||||
'''
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
|
||||
def generate_colorbar(N = 20, cmap = 'jet', rand=True):
|
||||
def generate_colorbar(N = 20, cmap = 'jet', rand=True,
|
||||
ret_float=False, ret_array=False, ret_rgb=False):
|
||||
bar = ((np.arange(N)/(N-1))*255).astype(np.uint8).reshape(-1, 1)
|
||||
colorbar = cv2.applyColorMap(bar, cv2.COLORMAP_JET).squeeze()
|
||||
if False:
|
||||
@ -22,7 +23,12 @@ def generate_colorbar(N = 20, cmap = 'jet', rand=True):
|
||||
rgb = colorbar[index, :]
|
||||
else:
|
||||
rgb = colorbar
|
||||
rgb = rgb.tolist()
|
||||
if ret_rgb:
|
||||
rgb = rgb[:, ::-1]
|
||||
if ret_float:
|
||||
rgb = rgb/255.
|
||||
if not ret_array:
|
||||
rgb = rgb.tolist()
|
||||
return rgb
|
||||
|
||||
# colors_bar_rgb = generate_colorbar(cmap='hsv')
|
||||
@ -69,9 +75,11 @@ def get_rgb(index):
|
||||
# elif index == 0:
|
||||
# return (245, 150, 150)
|
||||
col = list(colors_bar_rgb[index%len(colors_bar_rgb)])[::-1]
|
||||
else:
|
||||
elif isinstance(index, str):
|
||||
col = colors_table.get(index, (1, 0, 0))
|
||||
col = tuple([int(c*255) for c in col[::-1]])
|
||||
else:
|
||||
raise TypeError('index should be int or str')
|
||||
return col
|
||||
|
||||
def get_rgb_01(index):
|
||||
@ -150,14 +158,16 @@ def plot_keypoints(img, points, pid, config, vis_conf=False, use_limb_color=True
|
||||
cv2.putText(img, '{:.1f}'.format(c), (int(x), int(y)),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, text_size, col, 2)
|
||||
|
||||
def plot_keypoints_auto(img, points, pid, vis_conf=False, use_limb_color=True, scale=1, lw=-1):
|
||||
def plot_keypoints_auto(img, points, pid, vis_conf=False, use_limb_color=True, scale=1, lw=-1, config_name=None, lw_factor=1):
|
||||
from ..dataset.config import CONFIG
|
||||
config_name = {25: 'body25', 21: 'hand', 42:'handlr', 17: 'coco', 1:'points', 67:'bodyhand', 137: 'total', 79:'up'}[len(points)]
|
||||
if config_name is None:
|
||||
config_name = {25: 'body25', 15: 'body15', 21: 'hand', 42:'handlr', 17: 'coco', 1:'points', 67:'bodyhand', 137: 'total', 79:'up',
|
||||
19:'ochuman'}[len(points)]
|
||||
config = CONFIG[config_name]
|
||||
if lw == -1:
|
||||
lw = img.shape[0]//200
|
||||
if config_name == 'hand':
|
||||
lw = img.shape[0]//1000
|
||||
lw = img.shape[0]//100
|
||||
lw = max(lw, 1)
|
||||
for ii, (i, j) in enumerate(config['kintree']):
|
||||
if i >= len(points) or j >= len(points):
|
||||
@ -169,9 +179,9 @@ def plot_keypoints_auto(img, points, pid, vis_conf=False, use_limb_color=True, s
|
||||
col = get_rgb(config['colors'][ii])
|
||||
else:
|
||||
col = get_rgb(pid)
|
||||
if pt1[0] < 0 or pt1[1] < 0 or pt1[0] > 10000 or pt1[1] > 10000:
|
||||
if pt1[0] < -10000 or pt1[1] < -10000 or pt1[0] > 10000 or pt1[1] > 10000:
|
||||
continue
|
||||
if pt2[0] < 0 or pt2[1] < 0 or pt2[0] > 10000 or pt2[1] > 10000:
|
||||
if pt2[0] < -10000 or pt2[1] < -10000 or pt2[0] > 10000 or pt2[1] > 10000:
|
||||
continue
|
||||
if pt1[-1] > 0.01 and pt2[-1] > 0.01:
|
||||
image = cv2.line(
|
||||
@ -191,12 +201,13 @@ def plot_keypoints_auto(img, points, pid, vis_conf=False, use_limb_color=True, s
|
||||
if c > 0.01:
|
||||
col = get_rgb(pid)
|
||||
if len(points) == 1:
|
||||
cv2.circle(img, (int(x+0.5), int(y+0.5)), lw*10, col, lw*2)
|
||||
plot_cross(img, int(x+0.5), int(y+0.5), width=lw*5, col=col, lw=lw*2)
|
||||
_lw = max(0, int(lw * lw_factor))
|
||||
cv2.circle(img, (int(x+0.5), int(y+0.5)), _lw*2, col, lw*2)
|
||||
plot_cross(img, int(x+0.5), int(y+0.5), width=_lw, col=col, lw=lw*2)
|
||||
else:
|
||||
cv2.circle(img, (int(x+0.5), int(y+0.5)), lw*2, col, -1)
|
||||
if vis_conf:
|
||||
cv2.putText(img, '{:.1f}'.format(c), (int(x), int(y)), cv2.FONT_HERSHEY_SIMPLEX, 1, col, 2)
|
||||
cv2.putText(img, '{:.1f}'.format(c), (int(x), int(y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, col, 2)
|
||||
|
||||
def plot_keypoints_total(img, annots, scale, pid_offset=0):
|
||||
_lw = img.shape[0] // 150
|
||||
|
@ -239,3 +239,10 @@ def get_preds_from_heatmaps(batch_heatmaps):
|
||||
coords = coords.astype(np.float32) * 4
|
||||
pred = np.dstack((coords, maxvals))
|
||||
return pred
|
||||
|
||||
def gdown_models(ckpt, url):
|
||||
print('Try to download model from {} to {}'.format(url, ckpt))
|
||||
os.makedirs(os.path.dirname(ckpt), exist_ok=True)
|
||||
cmd = 'gdown "{}" -O {}'.format(url, ckpt)
|
||||
print('\n', cmd, '\n')
|
||||
os.system(cmd)
|
0
myeasymocap/backbone/hand2d/__init__.py
Normal file
0
myeasymocap/backbone/hand2d/__init__.py
Normal file
89
myeasymocap/backbone/hand2d/hand2d.py
Normal file
89
myeasymocap/backbone/hand2d/hand2d.py
Normal file
@ -0,0 +1,89 @@
|
||||
import os
|
||||
import cv2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import math
|
||||
# https://download.openmmlab.com/mmpose/hand/hrnetv2/hrnetv2_w18_rhd2d_256x256-95b20dd8_20210330.pth
|
||||
# https://download.openmmlab.com/mmpose/hand/dark/hrnetv2_w18_onehand10k_256x256_dark-a2f80c64_20210330.pth
|
||||
from .hrnet import PoseHighResolutionNet
|
||||
from ..basetopdown import BaseTopDownModelCache, get_preds_from_heatmaps, gdown_models
|
||||
|
||||
class TopDownAsMMPose(nn.Module):
|
||||
def __init__(self, backbone, head):
|
||||
super().__init__()
|
||||
self.bacbone = backbone
|
||||
self.head = head
|
||||
|
||||
def forward(self, x):
|
||||
feat_list = self.bacbone(x)
|
||||
size = feat_list[0].shape[-2:]
|
||||
resized_inputs = [
|
||||
nn.functional.interpolate(feat, size, mode='bilinear', align_corners=False) \
|
||||
for feat in feat_list
|
||||
]
|
||||
resized_inputs = torch.cat(resized_inputs, 1)
|
||||
out = self.head(resized_inputs)
|
||||
pred = get_preds_from_heatmaps(out.detach().cpu().numpy())
|
||||
return {'keypoints': pred}
|
||||
|
||||
class MyHand2D(BaseTopDownModelCache):
|
||||
def __init__(self, ckpt, url=None, mode='hrnet'):
|
||||
if mode == 'hrnet':
|
||||
super().__init__(name='hand2d', bbox_scale=1.1, res_input=256)
|
||||
backbone = PoseHighResolutionNet(inp_ch=3, out_ch=21, W=18, multi_scale_final=True, add_final_layer=False)
|
||||
checkpoint = torch.load(ckpt, map_location='cpu')['state_dict']
|
||||
self.load_checkpoint(backbone, checkpoint, prefix='backbone.', strict=True)
|
||||
head = nn.Sequential(
|
||||
nn.Conv2d(270, 270, kernel_size=1),
|
||||
nn.BatchNorm2d(270),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(270, 21, kernel_size=1)
|
||||
)
|
||||
self.load_checkpoint(head, checkpoint, prefix='keypoint_head.final_layer.', strict=True)
|
||||
# self.model = nn.Sequential(backbone, head)
|
||||
self.model = TopDownAsMMPose(backbone, head)
|
||||
elif mode == 'resnet':
|
||||
super().__init__(name='hand2d', bbox_scale=1.1, res_input=256, mean=[0., 0., 0.], std=[1., 1., 1.])
|
||||
from .resnet import ResNet_Deconv
|
||||
if not os.path.exists(ckpt) and url is not None:
|
||||
gdown_models(ckpt, url)
|
||||
assert os.path.exists(ckpt), f'{ckpt} not exists'
|
||||
checkpoint = torch.load(ckpt, map_location='cpu')['state_dict']
|
||||
model = ResNet_Deconv()
|
||||
self.load_checkpoint(model, checkpoint, prefix='model.', strict=True)
|
||||
self.model = model
|
||||
self.model.eval()
|
||||
self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
self.model.to(self.device)
|
||||
|
||||
def __call__(self, bbox, images, imgnames):
|
||||
squeeze = False
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
imgnames = [imgnames]
|
||||
bbox = [bbox]
|
||||
squeeze = True
|
||||
nViews = len(images)
|
||||
kpts_all = []
|
||||
for nv in range(nViews):
|
||||
if bbox[nv].shape[0] == 0:
|
||||
kpts_all.append(np.zeros((21, 3)))
|
||||
continue
|
||||
_bbox = bbox[nv]
|
||||
if len(_bbox.shape) == 1:
|
||||
_bbox = _bbox[None]
|
||||
output = super().__call__(_bbox, images[nv], imgnames[nv])
|
||||
kpts = output['params']['keypoints']
|
||||
conf = kpts[..., -1:]
|
||||
kpts = self.batch_affine_transform(kpts, output['params']['inv_trans'])
|
||||
kpts = np.concatenate([kpts, conf], axis=-1)
|
||||
if len(kpts.shape) == 3:
|
||||
kpts = kpts[0]
|
||||
kpts_all.append(kpts)
|
||||
kpts_all = np.stack(kpts_all)
|
||||
if squeeze:
|
||||
kpts_all = kpts_all[0]
|
||||
return {
|
||||
'keypoints': kpts_all
|
||||
}
|
161
myeasymocap/backbone/hand2d/resnet.py
Normal file
161
myeasymocap/backbone/hand2d/resnet.py
Normal file
@ -0,0 +1,161 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision.models.resnet import BasicBlock, Bottleneck
|
||||
from torchvision.models.resnet import model_urls
|
||||
from ..basetopdown import get_preds_from_heatmaps
|
||||
|
||||
def make_conv_layers(feat_dims, kernel=3, stride=1, padding=1, bnrelu_final=True):
|
||||
layers = []
|
||||
for i in range(len(feat_dims)-1):
|
||||
layers.append(
|
||||
nn.Conv2d(
|
||||
in_channels=feat_dims[i],
|
||||
out_channels=feat_dims[i+1],
|
||||
kernel_size=kernel,
|
||||
stride=stride,
|
||||
padding=padding
|
||||
))
|
||||
# Do not use BN and ReLU for final estimation
|
||||
if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and bnrelu_final):
|
||||
layers.append(nn.BatchNorm2d(feat_dims[i+1]))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def make_deconv_layers(feat_dims, bnrelu_final=True):
|
||||
layers = []
|
||||
for i in range(len(feat_dims)-1):
|
||||
layers.append(
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=feat_dims[i],
|
||||
out_channels=feat_dims[i+1],
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
padding=1,
|
||||
output_padding=0,
|
||||
bias=False))
|
||||
|
||||
# Do not use BN and ReLU for final estimation
|
||||
if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and bnrelu_final):
|
||||
layers.append(nn.BatchNorm2d(feat_dims[i+1]))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class ResNetBackbone(nn.Module):
|
||||
|
||||
def __init__(self, resnet_type):
|
||||
|
||||
resnet_spec = {18: (BasicBlock, [2, 2, 2, 2], [64, 64, 128, 256, 512], 'resnet18'),
|
||||
34: (BasicBlock, [3, 4, 6, 3], [64, 64, 128, 256, 512], 'resnet34'),
|
||||
50: (Bottleneck, [3, 4, 6, 3], [64, 256, 512, 1024, 2048], 'resnet50'),
|
||||
101: (Bottleneck, [3, 4, 23, 3], [64, 256, 512, 1024, 2048], 'resnet101'),
|
||||
152: (Bottleneck, [3, 8, 36, 3], [64, 256, 512, 1024, 2048], 'resnet152')}
|
||||
block, layers, channels, name = resnet_spec[resnet_type]
|
||||
|
||||
self.name = name
|
||||
self.inplanes = 64
|
||||
super(ResNetBackbone, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
|
||||
bias=False) # RGB
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
nn.init.normal_(m.weight, mean=0, std=0.001)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.inplanes, planes * block.expansion,
|
||||
kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
return x
|
||||
|
||||
def init_weights(self):
|
||||
org_resnet = torch.utils.model_zoo.load_url(model_urls[self.name])
|
||||
# drop orginal resnet fc layer, add 'None' in case of no fc layer, that will raise error
|
||||
org_resnet.pop('fc.weight', None)
|
||||
org_resnet.pop('fc.bias', None)
|
||||
|
||||
self.load_state_dict(org_resnet)
|
||||
print("Initialize resnet from model zoo")
|
||||
|
||||
class ResNet_Deconv(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.hm2d_size = 64
|
||||
|
||||
self.resnet = ResNetBackbone(50)
|
||||
self.deconv = make_deconv_layers([2048, 256, 256, 256])
|
||||
self.conv_hm2d = make_conv_layers([256, 21],kernel=1,stride=1,padding=0,bnrelu_final=False)
|
||||
|
||||
self.resnet.init_weights()
|
||||
self.deconv.apply(self.init_weights)
|
||||
self.conv_hm2d.apply(self.init_weights)
|
||||
|
||||
@staticmethod
|
||||
def init_weights(m):
|
||||
if type(m) == nn.ConvTranspose2d:
|
||||
nn.init.normal_(m.weight,std=0.001)
|
||||
elif type(m) == nn.Conv2d:
|
||||
nn.init.normal_(m.weight,std=0.001)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif type(m) == nn.BatchNorm2d:
|
||||
nn.init.constant_(m.weight,1)
|
||||
nn.init.constant_(m.bias,0)
|
||||
elif type(m) == nn.Linear:
|
||||
nn.init.normal_(m.weight,std=0.01)
|
||||
nn.init.constant_(m.bias,0)
|
||||
|
||||
def forward(self, img):
|
||||
x_feat = self.resnet(img)
|
||||
x_feat = self.deconv(x_feat)
|
||||
|
||||
x_hm2d = self.conv_hm2d(x_feat)
|
||||
pred = get_preds_from_heatmaps(x_hm2d.detach().cpu().numpy())
|
||||
return {
|
||||
'keypoints': pred
|
||||
}
|
0
myeasymocap/backbone/hmr/__init__.py
Normal file
0
myeasymocap/backbone/hmr/__init__.py
Normal file
35
myeasymocap/backbone/hmr/hmr.py
Normal file
35
myeasymocap/backbone/hmr/hmr.py
Normal file
@ -0,0 +1,35 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
from ..basetopdown import BaseTopDownModelCache, gdown_models
|
||||
import pickle
|
||||
from .models import hmr
|
||||
|
||||
class MyHMR(BaseTopDownModelCache):
|
||||
def __init__(self, ckpt, url=None):
|
||||
super().__init__('handhmr', bbox_scale=1., res_input=224)
|
||||
self.model = hmr()
|
||||
self.model.eval()
|
||||
if not os.path.exists(ckpt) and url is not None:
|
||||
gdown_models(ckpt, url)
|
||||
assert os.path.exists(ckpt), f'{ckpt} not exists'
|
||||
checkpoint = torch.load(ckpt)
|
||||
self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
state_dict = checkpoint['state_dict']
|
||||
prefix = 'model.'
|
||||
self.load_checkpoint(self.model, state_dict, prefix, strict=True)
|
||||
self.model.to(self.device)
|
||||
|
||||
def __call__(self, bbox, images, imgnames):
|
||||
output = super().__call__(bbox, images, imgnames)
|
||||
Rh = output['params']['poses'][:3].copy()
|
||||
poses = output['params']['poses'][3:]
|
||||
Th = np.zeros_like(Rh)
|
||||
Th[2] = 1.
|
||||
output['params'] = {
|
||||
'Rh': Rh,
|
||||
'Th': Th,
|
||||
'poses': poses,
|
||||
'shapes': output['params']['shapes'],
|
||||
}
|
||||
return output
|
256
myeasymocap/backbone/hmr/hmr_api.py
Normal file
256
myeasymocap/backbone/hmr/hmr_api.py
Normal file
@ -0,0 +1,256 @@
|
||||
'''
|
||||
Date: 2021-10-25 11:51:37 am
|
||||
Author: dihuangdh
|
||||
Descriptions:
|
||||
-----
|
||||
LastEditTime: 2021-10-25 1:50:40 pm
|
||||
LastEditors: dihuangdh
|
||||
'''
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import Normalize
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from .models import hmr
|
||||
|
||||
|
||||
class constants:
|
||||
FOCAL_LENGTH = 5000.
|
||||
IMG_RES = 224
|
||||
|
||||
# Mean and standard deviation for normalizing input image
|
||||
IMG_NORM_MEAN = [0.485, 0.456, 0.406]
|
||||
IMG_NORM_STD = [0.229, 0.224, 0.225]
|
||||
|
||||
def get_transform(center, scale, res, rot=0):
|
||||
"""Generate transformation matrix."""
|
||||
h = 200 * scale
|
||||
t = np.zeros((3, 3))
|
||||
t[0, 0] = float(res[1]) / h
|
||||
t[1, 1] = float(res[0]) / h
|
||||
t[0, 2] = res[1] * (-float(center[0]) / h + .5)
|
||||
t[1, 2] = res[0] * (-float(center[1]) / h + .5)
|
||||
t[2, 2] = 1
|
||||
if not rot == 0:
|
||||
rot = -rot # To match direction of rotation from cropping
|
||||
rot_mat = np.zeros((3,3))
|
||||
rot_rad = rot * np.pi / 180
|
||||
sn,cs = np.sin(rot_rad), np.cos(rot_rad)
|
||||
rot_mat[0,:2] = [cs, -sn]
|
||||
rot_mat[1,:2] = [sn, cs]
|
||||
rot_mat[2,2] = 1
|
||||
# Need to rotate around center
|
||||
t_mat = np.eye(3)
|
||||
t_mat[0,2] = -res[1]/2
|
||||
t_mat[1,2] = -res[0]/2
|
||||
t_inv = t_mat.copy()
|
||||
t_inv[:2,2] *= -1
|
||||
t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t)))
|
||||
return t
|
||||
|
||||
def transform(pt, center, scale, res, invert=0, rot=0):
|
||||
"""Transform pixel location to different reference."""
|
||||
t = get_transform(center, scale, res, rot=rot)
|
||||
if invert:
|
||||
t = np.linalg.inv(t)
|
||||
new_pt = np.array([pt[0]-1, pt[1]-1, 1.]).T
|
||||
new_pt = np.dot(t, new_pt)
|
||||
return new_pt[:2].astype(int)+1
|
||||
|
||||
def crop(img, center, scale, res, rot=0, bias=0):
|
||||
"""Crop image according to the supplied bounding box."""
|
||||
# Upper left point
|
||||
ul = np.array(transform([1, 1], center, scale, res, invert=1))-1
|
||||
# Bottom right point
|
||||
br = np.array(transform([res[0]+1,
|
||||
res[1]+1], center, scale, res, invert=1))-1
|
||||
|
||||
# Padding so that when rotated proper amount of context is included
|
||||
pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
|
||||
if not rot == 0:
|
||||
ul -= pad
|
||||
br += pad
|
||||
|
||||
new_shape = [br[1] - ul[1], br[0] - ul[0]]
|
||||
if len(img.shape) > 2:
|
||||
new_shape += [img.shape[2]]
|
||||
new_img = np.zeros(new_shape) + bias
|
||||
|
||||
# Range to fill new array
|
||||
new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
|
||||
new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
|
||||
# Range to sample from original image
|
||||
old_x = max(0, ul[0]), min(len(img[0]), br[0])
|
||||
old_y = max(0, ul[1]), min(len(img), br[1])
|
||||
new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
|
||||
old_x[0]:old_x[1]]
|
||||
|
||||
if not rot == 0:
|
||||
# Remove padding
|
||||
new_img = scipy.misc.imrotate(new_img, rot)
|
||||
new_img = new_img[pad:-pad, pad:-pad]
|
||||
new_img = cv2.resize(new_img, (res[0], res[1]))
|
||||
return new_img
|
||||
|
||||
def process_image(img, bbox, input_res=224):
|
||||
"""Read image, do preprocessing and possibly crop it according to the bounding box.
|
||||
If there are bounding box annotations, use them to crop the image.
|
||||
If no bounding box is specified but openpose detections are available, use them to get the bounding box.
|
||||
"""
|
||||
img = img[:, :, ::-1].copy()
|
||||
normalize_img = Normalize(mean=constants.IMG_NORM_MEAN, std=constants.IMG_NORM_STD)
|
||||
l, t, r, b = bbox[:4]
|
||||
center = [(l+r)/2, (t+b)/2]
|
||||
width = max(r-l, b-t)
|
||||
scale = width/200.0
|
||||
img = crop(img, center, scale, (input_res, input_res))
|
||||
img = img.astype(np.float32) / 255.
|
||||
img = torch.from_numpy(img).permute(2,0,1)
|
||||
norm_img = normalize_img(img.clone())[None]
|
||||
return img, norm_img
|
||||
|
||||
def solve_translation(X, x, K):
|
||||
A = np.zeros((2*X.shape[0], 3))
|
||||
b = np.zeros((2*X.shape[0], 1))
|
||||
fx, fy = K[0, 0], K[1, 1]
|
||||
cx, cy = K[0, 2], K[1, 2]
|
||||
for nj in range(X.shape[0]):
|
||||
A[2*nj, 0] = 1
|
||||
A[2*nj + 1, 1] = 1
|
||||
A[2*nj, 2] = -(x[nj, 0] - cx)/fx
|
||||
A[2*nj+1, 2] = -(x[nj, 1] - cy)/fy
|
||||
b[2*nj, 0] = X[nj, 2]*(x[nj, 0] - cx)/fx - X[nj, 0]
|
||||
b[2*nj+1, 0] = X[nj, 2]*(x[nj, 1] - cy)/fy - X[nj, 1]
|
||||
A[2*nj:2*nj+2, :] *= x[nj, 2]
|
||||
b[2*nj:2*nj+2, :] *= x[nj, 2]
|
||||
trans = np.linalg.inv(A.T @ A) @ A.T @ b
|
||||
return trans.T[0]
|
||||
|
||||
def estimate_translation_np(S, joints_2d, joints_conf, K):
|
||||
"""Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d.
|
||||
Input:
|
||||
S: (25, 3) 3D joint locations
|
||||
joints: (25, 3) 2D joint locations and confidence
|
||||
Returns:
|
||||
(3,) camera translation vector
|
||||
"""
|
||||
num_joints = S.shape[0]
|
||||
# focal length
|
||||
f = np.array([K[0, 0], K[1, 1]])
|
||||
# optical center
|
||||
center = np.array([K[0, 2], K[1, 2]])
|
||||
|
||||
# transformations
|
||||
Z = np.reshape(np.tile(S[:,2],(2,1)).T,-1)
|
||||
XY = np.reshape(S[:,0:2],-1)
|
||||
O = np.tile(center,num_joints)
|
||||
F = np.tile(f,num_joints)
|
||||
weight2 = np.reshape(np.tile(np.sqrt(joints_conf),(2,1)).T,-1)
|
||||
|
||||
# least squares
|
||||
Q = np.array([F*np.tile(np.array([1,0]),num_joints), F*np.tile(np.array([0,1]),num_joints), O-np.reshape(joints_2d,-1)]).T
|
||||
c = (np.reshape(joints_2d,-1)-O)*Z - F*XY
|
||||
|
||||
# weighted least squares
|
||||
W = np.diagflat(weight2)
|
||||
Q = np.dot(W,Q)
|
||||
c = np.dot(W,c)
|
||||
|
||||
# square matrix
|
||||
A = np.dot(Q.T,Q)
|
||||
b = np.dot(Q.T,c)
|
||||
|
||||
# solution
|
||||
trans = np.linalg.solve(A, b)
|
||||
|
||||
return trans
|
||||
|
||||
|
||||
class HMR:
|
||||
def __init__(self, checkpoint, device) -> None:
|
||||
model = hmr().to(device)
|
||||
checkpoint = torch.load(checkpoint)
|
||||
state_dict = checkpoint['state_dict']
|
||||
# update state_dict, remove 'model.'
|
||||
for key in list(state_dict.keys()):
|
||||
state_dict[key[6:]] = state_dict.pop(key)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
# Load SMPL model
|
||||
model.eval()
|
||||
self.model = model
|
||||
self.device = device
|
||||
|
||||
def forward(self, img, bbox, use_rh_th=True):
|
||||
# Preprocess input image and generate predictions
|
||||
img, norm_img = process_image(img, bbox, input_res=constants.IMG_RES)
|
||||
with torch.no_grad():
|
||||
pred_rotmat, pred_betas, pred_camera = self.model(norm_img.to(self.device))
|
||||
results = {
|
||||
'shapes': pred_betas.detach().cpu().numpy()
|
||||
}
|
||||
results['poses'] = pred_rotmat.detach().cpu().numpy()
|
||||
if use_rh_th:
|
||||
body_params = {
|
||||
'poses': results['poses'],
|
||||
'shapes': results['shapes'],
|
||||
'Rh': results['poses'][:, :3].copy(),
|
||||
'Th': np.zeros((1, 3)),
|
||||
}
|
||||
body_params['Th'][0, 2] = 5
|
||||
body_params['poses'][:, :3] = 0
|
||||
results = body_params
|
||||
return results
|
||||
|
||||
def __call__(self, body_model, img, bbox, kpts, camera, ret_vertices=True):
|
||||
body_params = self.forward(img.copy(), bbox)
|
||||
body_params = body_model.check_params(body_params)
|
||||
# only use body joints to estimation translation
|
||||
nJoints = 21
|
||||
keypoints3d = body_model(return_verts=False, return_tensor=False, **body_params)[0]
|
||||
trans = solve_translation(keypoints3d[:nJoints], kpts[:nJoints], camera['K'])
|
||||
body_params['Th'] += trans[None, :]
|
||||
if body_params['Th'][0, 2] < 0:
|
||||
body_params['Th'] = -body_params['Th']
|
||||
Rhold = cv2.Rodrigues(body_params['Rh'])[0]
|
||||
rotx = cv2.Rodrigues(np.pi*np.array([1., 0, 0]))[0]
|
||||
Rhold = rotx @ Rhold
|
||||
body_params['Rh'] = cv2.Rodrigues(Rhold)[0].reshape(1, 3)
|
||||
# convert to world coordinate
|
||||
Rhold = cv2.Rodrigues(body_params['Rh'])[0]
|
||||
Thold = body_params['Th']
|
||||
Rh = camera['R'].T @ Rhold
|
||||
Th = (camera['R'].T @ (Thold.T - camera['T'])).T
|
||||
body_params['Th'] = Th
|
||||
body_params['Rh'] = cv2.Rodrigues(Rh)[0].reshape(1, 3)
|
||||
keypoints3d = body_model(return_verts=False, return_tensor=False, **body_params)[0]
|
||||
results = {'body_params': body_params, 'keypoints3d': keypoints3d}
|
||||
if ret_vertices:
|
||||
vertices = body_model(return_verts=True, return_tensor=False, **body_params)[0]
|
||||
results['vertices'] = vertices
|
||||
return results
|
||||
|
||||
def init_with_hmr(body_model, spin_model, img, bbox, kpts, camera):
|
||||
body_params = spin_model.forward(img.copy(), bbox)
|
||||
body_params = body_model.check_params(body_params)
|
||||
# only use body joints to estimation translation
|
||||
nJoints = 15
|
||||
keypoints3d = body_model(return_verts=False, return_tensor=False, **body_params)[0]
|
||||
trans = estimate_translation_np(keypoints3d[:nJoints], kpts[:nJoints, :2], kpts[:nJoints, 2], camera['K'])
|
||||
body_params['Th'] += trans[None, :]
|
||||
# convert to world coordinate
|
||||
Rhold = cv2.Rodrigues(body_params['Rh'])[0]
|
||||
Thold = body_params['Th']
|
||||
Rh = camera['R'].T @ Rhold
|
||||
Th = (camera['R'].T @ (Thold.T - camera['T'])).T
|
||||
body_params['Th'] = Th
|
||||
body_params['Rh'] = cv2.Rodrigues(Rh)[0].reshape(1, 3)
|
||||
vertices = body_model(return_verts=True, return_tensor=False, **body_params)[0]
|
||||
keypoints3d = body_model(return_verts=False, return_tensor=False, **body_params)[0]
|
||||
results = {'body_params': body_params, 'vertices': vertices, 'keypoints3d': keypoints3d}
|
||||
return results
|
||||
|
||||
class
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
174
myeasymocap/backbone/hmr/models.py
Normal file
174
myeasymocap/backbone/hmr/models.py
Normal file
@ -0,0 +1,174 @@
|
||||
'''
|
||||
Date: 2021-10-25 11:51:29 am
|
||||
Author: dihuangdh
|
||||
Descriptions:
|
||||
-----
|
||||
LastEditTime: 2021-10-25 11:51:58 am
|
||||
LastEditors: dihuangdh
|
||||
'''
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.models.resnet as resnet
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
""" Redefinition of Bottleneck residual block
|
||||
Adapted from the official PyTorch implementation
|
||||
"""
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
||||
padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * 4)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
class HMR(nn.Module):
|
||||
""" SMPL Iterative Regressor with ResNet50 backbone
|
||||
"""
|
||||
|
||||
def __init__(self, block, layers):
|
||||
self.inplanes = 64
|
||||
super(HMR, self).__init__()
|
||||
npose = 3 + 45
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
|
||||
bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
||||
self.avgpool = nn.AvgPool2d(7, stride=1)
|
||||
self.fc1 = nn.Linear(512 * block.expansion + npose + 13, 1024)
|
||||
self.drop1 = nn.Dropout()
|
||||
self.fc2 = nn.Linear(1024, 1024)
|
||||
self.drop2 = nn.Dropout()
|
||||
self.decpose = nn.Linear(1024, npose)
|
||||
self.decshape = nn.Linear(1024, 10)
|
||||
self.deccam = nn.Linear(1024, 3)
|
||||
nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
|
||||
nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
|
||||
nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
|
||||
init_pose = torch.zeros(npose).unsqueeze(0)
|
||||
init_shape = torch.zeros(10).unsqueeze(0)
|
||||
init_cam = torch.zeros(3).unsqueeze(0)
|
||||
self.register_buffer('init_pose', init_pose)
|
||||
self.register_buffer('init_shape', init_shape)
|
||||
self.register_buffer('init_cam', init_cam)
|
||||
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.inplanes, planes * block.expansion,
|
||||
kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def forward(self, x, init_pose=None, init_shape=None, init_cam=None, n_iter=3):
|
||||
|
||||
batch_size = x.shape[0]
|
||||
|
||||
if init_pose is None:
|
||||
init_pose = self.init_pose.expand(batch_size, -1)
|
||||
if init_shape is None:
|
||||
init_shape = self.init_shape.expand(batch_size, -1)
|
||||
if init_cam is None:
|
||||
init_cam = self.init_cam.expand(batch_size, -1)
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x1 = self.layer1(x)
|
||||
x2 = self.layer2(x1)
|
||||
x3 = self.layer3(x2)
|
||||
x4 = self.layer4(x3)
|
||||
|
||||
xf = self.avgpool(x4)
|
||||
xf = xf.view(xf.size(0), -1)
|
||||
|
||||
pred_pose = init_pose
|
||||
pred_shape = init_shape
|
||||
pred_cam = init_cam
|
||||
for i in range(n_iter):
|
||||
xc = torch.cat([xf, pred_pose, pred_shape, pred_cam],1)
|
||||
xc = self.fc1(xc)
|
||||
xc = self.drop1(xc)
|
||||
xc = self.fc2(xc)
|
||||
xc = self.drop2(xc)
|
||||
pred_pose = self.decpose(xc) + pred_pose
|
||||
pred_shape = self.decshape(xc) + pred_shape
|
||||
pred_cam = self.deccam(xc) + pred_cam
|
||||
|
||||
# pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3)
|
||||
return {
|
||||
'poses': pred_pose,
|
||||
'shapes': pred_shape,
|
||||
'cam': pred_cam
|
||||
}
|
||||
|
||||
def hmr(pretrained=True, **kwargs):
|
||||
""" Constructs an HMR model with ResNet50 backbone.
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
"""
|
||||
model = HMR(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||
if pretrained:
|
||||
resnet_imagenet = resnet.resnet50(pretrained=True)
|
||||
model.load_state_dict(resnet_imagenet.state_dict(),strict=False)
|
||||
return model
|
Loading…
Reference in New Issue
Block a user