"""PARE network (backbone + PareHead) and the NullSPIN / MyPARE inference wrappers."""
import os
|
|
import torch
|
|
import torch.nn as nn
|
|
from .config import update_hparams
|
|
# from .head import PareHead, SMPLHead, SMPLCamHead
|
|
from .head import PareHead
|
|
from .backbone.utils import get_backbone_info
|
|
from .backbone.hrnet import hrnet_w32
|
|
from os.path import join
|
|
from easymocap.multistage.torchgeometry import rotation_matrix_to_axis_angle
|
|
import cv2
|
|
|
|
def try_to_download():
    """Best-effort download and extraction of the PARE data archive.

    Runs ``wget`` to fetch ``pare-github-data.zip`` into the current working
    directory, creates ``<this file's dir>/../../models/pare`` if needed, and
    runs ``unzip`` to extract the archive into that directory.

    The commands are executed as argument lists (no shell), so a model
    directory containing spaces or shell metacharacters is passed through
    intact.  Failures are reported on stdout but never raised, matching the
    original best-effort behaviour.

    Returns:
        None
    """
    # Local import: the module header does not import subprocess.
    import subprocess

    def _run(argv):
        # Run one external command; report (but do not raise) any failure.
        try:
            returncode = subprocess.run(argv).returncode
        except OSError as err:
            # Raised e.g. when the executable is not installed.
            print('[pare] could not run {}: {}'.format(argv[0], err))
            return
        if returncode != 0:
            print('[pare] {} exited with status {}'.format(argv[0], returncode))

    model_dir = os.path.join(os.path.dirname(__file__), '..', '..', 'models', 'pare')

    _run(['wget', 'https://www.dropbox.com/s/aeulffqzb3zmh8x/pare-github-data.zip'])
    os.makedirs(model_dir, exist_ok=True)
    _run(['unzip', 'pare-github-data.zip', '-d', model_dir])
|
|
|
|
# Locations of the PARE hyper-parameter file and the matching checkpoint.
# Both are relative paths (no leading slash), so they are resolved against
# the process's current working directory when opened.
CFG = 'models/pare/data/pare/checkpoints/pare_w_3dpw_config.yaml'
CKPT = 'models/pare/data/pare/checkpoints/pare_w_3dpw_checkpoint.ckpt'
|
|
|
|
class PARE(nn.Module):
    """Backbone + ``PareHead`` network that outputs axis-angle SMPL parameters.

    The constructor builds a backbone by name and a ``PareHead`` configured
    from the keyword arguments.  ``forward`` runs both and converts the
    head's rotation matrices to axis-angle vectors.

    Args:
        backbone: Name of a backbone factory that is callable from this
            module's namespace.  Names starting with ``'hrnet'`` must carry
            exactly one ``-<mode>`` suffix (e.g. ``'hrnet_w32-conv'`` or
            ``'hrnet_w32-interp'``); the factory is then called with
            ``use_conv=(mode == 'conv')``.  This file only imports
            ``hrnet_w32``.
        pretrained: If not ``None``, passed to ``self.load_pretrained``.
        focal_length, img_res: Accepted but not used by this class.
        use_cam: Stored on ``self.use_cam``; not otherwise used here.
        (all remaining arguments): Forwarded to ``PareHead`` under the same
            names, except that ``num_deconv_filters`` and
            ``deconv_conv_kernel_size`` are repeated ``num_deconv_layers``
            times to form per-layer lists.

    Raises:
        ValueError: If ``backbone`` is a malformed hrnet name or does not
            name a callable available in this module.
    """

    def __init__(
        self,
        num_joints=24,
        softmax_temp=1.0,
        num_features_smpl=64,
        backbone='resnet50',
        focal_length=5000.,
        img_res=224,
        pretrained=None,
        iterative_regression=False,
        iter_residual=False,
        num_iterations=3,
        shape_input_type='feats',  # 'feats.all_pose.shape.cam',
        pose_input_type='feats',  # 'feats.neighbor_pose_feats.all_pose.self_pose.neighbor_pose.shape.cam'
        pose_mlp_num_layers=1,
        shape_mlp_num_layers=1,
        pose_mlp_hidden_size=256,
        shape_mlp_hidden_size=256,
        use_keypoint_features_for_smpl_regression=False,
        use_heatmaps='',
        use_keypoint_attention=False,
        keypoint_attention_act='softmax',
        use_postconv_keypoint_attention=False,
        use_scale_keypoint_attention=False,
        use_final_nonlocal=None,
        use_branch_nonlocal=None,
        use_hmr_regression=False,
        use_coattention=False,
        num_coattention_iter=1,
        coattention_conv='simple',
        deconv_conv_kernel_size=4,
        use_upsampling=False,
        use_soft_attention=False,
        num_branch_iteration=0,
        branch_deeper=False,
        num_deconv_layers=3,
        num_deconv_filters=256,
        use_resnet_conv_hrnet=False,
        use_position_encodings=None,
        use_mean_camshape=False,
        use_mean_pose=False,
        init_xavier=False,
        use_cam=False,
    ):
        super(PARE, self).__init__()

        if backbone.startswith('hrnet'):
            # Expected forms: 'hrnet_w32-conv', 'hrnet_w32-interp'.
            parts = backbone.split('-')
            if len(parts) != 2:
                raise ValueError(
                    "hrnet backbone must look like '<name>-conv' or "
                    "'<name>-interp', got {!r}".format(backbone)
                )
            # From here on ``backbone`` is the bare factory name; the head and
            # get_backbone_info below receive the name without the suffix.
            backbone, use_conv = parts
            self.backbone = self._resolve_backbone(backbone)(
                pretrained=True,
                downsample=False,
                use_conv=(use_conv == 'conv'),
            )
        else:
            self.backbone = self._resolve_backbone(backbone)(pretrained=True)

        self.head = PareHead(
            num_joints=num_joints,
            num_input_features=get_backbone_info(backbone)['n_output_channels'],
            softmax_temp=softmax_temp,
            num_deconv_layers=num_deconv_layers,
            num_deconv_filters=[num_deconv_filters] * num_deconv_layers,
            num_deconv_kernels=[deconv_conv_kernel_size] * num_deconv_layers,
            num_features_smpl=num_features_smpl,
            final_conv_kernel=1,
            iterative_regression=iterative_regression,
            iter_residual=iter_residual,
            num_iterations=num_iterations,
            shape_input_type=shape_input_type,
            pose_input_type=pose_input_type,
            pose_mlp_num_layers=pose_mlp_num_layers,
            shape_mlp_num_layers=shape_mlp_num_layers,
            pose_mlp_hidden_size=pose_mlp_hidden_size,
            shape_mlp_hidden_size=shape_mlp_hidden_size,
            use_keypoint_features_for_smpl_regression=use_keypoint_features_for_smpl_regression,
            use_heatmaps=use_heatmaps,
            use_keypoint_attention=use_keypoint_attention,
            use_postconv_keypoint_attention=use_postconv_keypoint_attention,
            keypoint_attention_act=keypoint_attention_act,
            use_scale_keypoint_attention=use_scale_keypoint_attention,
            use_branch_nonlocal=use_branch_nonlocal,  # 'concatenation', 'dot_product', 'embedded_gaussian', 'gaussian'
            use_final_nonlocal=use_final_nonlocal,  # 'concatenation', 'dot_product', 'embedded_gaussian', 'gaussian'
            backbone=backbone,
            use_hmr_regression=use_hmr_regression,
            use_coattention=use_coattention,
            num_coattention_iter=num_coattention_iter,
            coattention_conv=coattention_conv,
            use_upsampling=use_upsampling,
            use_soft_attention=use_soft_attention,
            num_branch_iteration=num_branch_iteration,
            branch_deeper=branch_deeper,
            use_resnet_conv_hrnet=use_resnet_conv_hrnet,
            use_position_encodings=use_position_encodings,
            use_mean_camshape=use_mean_camshape,
            use_mean_pose=use_mean_pose,
            init_xavier=init_xavier,
        )

        # No SMPL head is attached here (SMPLHead / SMPLCamHead are not
        # imported by this file); the flag is only recorded.
        self.use_cam = use_cam

        if pretrained is not None:
            # NOTE(review): this file defines no ``load_pretrained`` method on
            # PARE, so a non-None ``pretrained`` looks like it will raise
            # AttributeError -- confirm whether the method was dropped.
            self.load_pretrained(pretrained)

    @staticmethod
    def _resolve_backbone(name):
        """Return the backbone factory called ``name`` from this module.

        Replaces a former ``eval(name)``: the name comes from a config file,
        so only a plain lookup of a callable in the module globals is allowed
        instead of evaluating an arbitrary expression.
        """
        factory = globals().get(name)
        if not callable(factory):
            raise ValueError(
                'unknown backbone {!r}: no such callable is available in '
                'this module'.format(name)
            )
        return factory

    def forward(
        self,
        images,
        gt_segm=None,
    ):
        """Run backbone and head, returning axis-angle SMPL parameters.

        Args:
            images: Input passed directly to the backbone.
            gt_segm: Optional value forwarded to the head as ``gt_segm``.

        Returns:
            dict with keys
            ``'Rh'``     -- first three axis-angle values per sample,
            ``'Th'``     -- zeros with the same shape as ``'Rh'``,
            ``'poses'``  -- the remaining axis-angle values,
            ``'shapes'`` -- the head's ``'pred_shape'`` output, unchanged.
        """
        features = self.backbone(images)
        hmr_output = self.head(features, gt_segm=gt_segm)
        rotmat = hmr_output['pred_pose']
        shape = hmr_output['pred_shape']

        # Treat the two trailing dims of pred_pose as 3x3 matrices, convert
        # each one, then restore the leading dims with a trailing 3.
        # (presumably pred_pose is (batch, num_joints, 3, 3) -- TODO confirm)
        rotmat_flat = rotmat.reshape(-1, 3, 3)
        rvec_flat = rotation_matrix_to_axis_angle(rotmat_flat)
        rvec = rvec_flat.reshape(*rotmat.shape[:-2], 3)
        # Merge the last two dims into one flat vector per sample.
        rvec = rvec.reshape(*rvec.shape[:-2], -1)

        return {
            'Rh': rvec[..., :3],
            'Th': torch.zeros_like(rvec[..., :3]),
            'poses': rvec[..., 3:],
            'shapes': shape,
        }
|
|
|
|
from ..basetopdown import BaseTopDownModelCache
|
|
import pickle
|
|
|
|
class NullSPIN:
    """Callable that returns SMPL parameters read from an on-disk cache.

    No network is evaluated.  For an image name, the file
    ``<self.output>/<self.name>/<basename with '.jpg' replaced by '.json'>``
    is read with ``easymocap.mytools.reader.read_smpl``.

    ``self.output`` is not assigned by ``__init__``; it must be set on the
    instance before the object is called.
    """

    def __init__(self, ckpt) -> None:
        # ``ckpt`` is accepted but not used by this class.
        self.name = 'spin'

    def __call__(self, bbox, images, imgname):
        """Return the cached parameters for ``imgname``.

        Args:
            bbox: Unused.
            images: Unused.
            imgname: Path of the image; only its basename is used.

        Returns:
            ``{'params': params}`` where ``params`` is built from the first
            record returned by ``read_smpl``: every key except ``'id'`` is
            kept, mapped to the first element of its value.

        Raises:
            AttributeError: If ``self.output`` has not been set.
            FileNotFoundError: If the cache file does not exist.  (This
                replaces a leftover ``ipdb.set_trace()`` breakpoint.)
        """
        if not hasattr(self, 'output'):
            raise AttributeError(
                "NullSPIN.output must be set to the cache root directory "
                "before calling the instance"
            )

        basename = os.path.basename(imgname)
        cachename = join(self.output, self.name, basename.replace('.jpg', '.json'))
        if not os.path.exists(cachename):
            raise FileNotFoundError(
                'no cached {} result for {!r}: {} does not exist'.format(
                    self.name, imgname, cachename)
            )

        # Imported lazily so the dependency is only needed on a cache hit.
        from easymocap.mytools.reader import read_smpl

        first_record = read_smpl(cachename)[0]
        params = {key: val[0] for key, val in first_record.items() if key != 'id'}
        return {'params': params}
|
|
|
|
class MyPARE(BaseTopDownModelCache):
    """Top-down wrapper that builds a ``PARE`` model and loads its weights.

    The configuration is always read from the module-level ``CFG`` path and
    the weights from the module-level ``CKPT`` path.
    """

    def __init__(self, ckpt) -> None:
        """Build the model, load the checkpoint and switch to eval mode.

        Args:
            ckpt: Accepted for interface compatibility but not used; the
                checkpoint actually loaded is ``CKPT``.
        """
        super().__init__('pare', bbox_scale=1.1, res_input=224)
        if not os.path.exists(CFG):
            # Config missing: ask the project helper to fetch the model data.
            from ...io.model import try_to_download_SMPL
            try_to_download_SMPL('models/pare')
        self.model_cfg = update_hparams(CFG)
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model = self._build_model()
        self._load_pretrained_model(CKPT)
        self.model.eval()
        self.model.to(self.device)

    def __call__(self, bbox, images, imgnames):
        """Delegate to the base class using only the first entry of ``bbox``."""
        return super().__call__(bbox[0], images, imgnames)

    def _build_model(self):
        """Instantiate ``PARE`` from ``self.model_cfg`` on ``self.device``.

        Returns:
            The constructed ``PARE`` module.

        Raises:
            ValueError: If ``model_cfg.METHOD`` is not ``'pare'``.  (The
                previous code called ``exit()``, terminating the interpreter
                without any message.)
        """
        # ========= Define PARE model ========= #
        model_cfg = self.model_cfg

        if model_cfg.METHOD != 'pare':
            raise ValueError(
                "unsupported METHOD {!r} in {}; only 'pare' is supported".format(
                    model_cfg.METHOD, CFG)
            )

        model = PARE(
            backbone=model_cfg.PARE.BACKBONE,
            num_joints=model_cfg.PARE.NUM_JOINTS,
            softmax_temp=model_cfg.PARE.SOFTMAX_TEMP,
            num_features_smpl=model_cfg.PARE.NUM_FEATURES_SMPL,
            focal_length=model_cfg.DATASET.FOCAL_LENGTH,
            img_res=model_cfg.DATASET.IMG_RES,
            pretrained=model_cfg.TRAINING.PRETRAINED,
            iterative_regression=model_cfg.PARE.ITERATIVE_REGRESSION,
            num_iterations=model_cfg.PARE.NUM_ITERATIONS,
            iter_residual=model_cfg.PARE.ITER_RESIDUAL,
            shape_input_type=model_cfg.PARE.SHAPE_INPUT_TYPE,
            pose_input_type=model_cfg.PARE.POSE_INPUT_TYPE,
            pose_mlp_num_layers=model_cfg.PARE.POSE_MLP_NUM_LAYERS,
            shape_mlp_num_layers=model_cfg.PARE.SHAPE_MLP_NUM_LAYERS,
            pose_mlp_hidden_size=model_cfg.PARE.POSE_MLP_HIDDEN_SIZE,
            shape_mlp_hidden_size=model_cfg.PARE.SHAPE_MLP_HIDDEN_SIZE,
            use_keypoint_features_for_smpl_regression=model_cfg.PARE.USE_KEYPOINT_FEATURES_FOR_SMPL_REGRESSION,
            use_heatmaps=model_cfg.DATASET.USE_HEATMAPS,
            use_keypoint_attention=model_cfg.PARE.USE_KEYPOINT_ATTENTION,
            use_postconv_keypoint_attention=model_cfg.PARE.USE_POSTCONV_KEYPOINT_ATTENTION,
            use_scale_keypoint_attention=model_cfg.PARE.USE_SCALE_KEYPOINT_ATTENTION,
            keypoint_attention_act=model_cfg.PARE.KEYPOINT_ATTENTION_ACT,
            use_final_nonlocal=model_cfg.PARE.USE_FINAL_NONLOCAL,
            use_branch_nonlocal=model_cfg.PARE.USE_BRANCH_NONLOCAL,
            use_hmr_regression=model_cfg.PARE.USE_HMR_REGRESSION,
            use_coattention=model_cfg.PARE.USE_COATTENTION,
            num_coattention_iter=model_cfg.PARE.NUM_COATTENTION_ITER,
            coattention_conv=model_cfg.PARE.COATTENTION_CONV,
            use_upsampling=model_cfg.PARE.USE_UPSAMPLING,
            deconv_conv_kernel_size=model_cfg.PARE.DECONV_CONV_KERNEL_SIZE,
            use_soft_attention=model_cfg.PARE.USE_SOFT_ATTENTION,
            num_branch_iteration=model_cfg.PARE.NUM_BRANCH_ITERATION,
            branch_deeper=model_cfg.PARE.BRANCH_DEEPER,
            num_deconv_layers=model_cfg.PARE.NUM_DECONV_LAYERS,
            num_deconv_filters=model_cfg.PARE.NUM_DECONV_FILTERS,
            use_resnet_conv_hrnet=model_cfg.PARE.USE_RESNET_CONV_HRNET,
            use_position_encodings=model_cfg.PARE.USE_POS_ENC,
            use_mean_camshape=model_cfg.PARE.USE_MEAN_CAMSHAPE,
            use_mean_pose=model_cfg.PARE.USE_MEAN_POSE,
            init_xavier=model_cfg.PARE.INIT_XAVIER,
        ).to(self.device)

        return model

    def _load_pretrained_model(self, ckpt):
        """Load weights from ``ckpt`` into ``self.model``.

        The checkpoint is read onto the CPU and its ``'state_dict'`` entry is
        used.  A leading ``'model.'`` is removed from each key; other keys
        are kept as they are.  Loading is non-strict, so keys that do not
        match the model are ignored by ``load_state_dict``.

        Args:
            ckpt: Path (or file-like object) accepted by ``torch.load``.
        """
        # ========= Load pretrained weights ========= #
        state_dict = torch.load(ckpt, map_location='cpu')['state_dict']

        # Strip only the leading prefix.  The earlier ``str.replace`` removed
        # every occurrence of 'model.' inside a key, and with strict=False a
        # key mangled that way would be dropped silently.
        prefix = 'model.'
        new_state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

        self.model.load_state_dict(new_state_dict, strict=False)
|
|
|
|
if __name__ == '__main__':
    # Running this file as a script does nothing.
    pass
|