260 lines
11 KiB
Python
260 lines
11 KiB
Python
|
import torch
|
||
|
import torch.nn as nn
|
||
|
from .base import Model
|
||
|
from .smpl import SMPLModel, SMPLLayerEmbedding, read_pickle, to_tensor
|
||
|
from os.path import join
|
||
|
import numpy as np
|
||
|
|
||
|
def read_hand(path, use_pca, use_flat_mean, num_pca_comps):
|
||
|
data = read_pickle(path)
|
||
|
mean = data['hands_mean'].reshape(1, -1).astype(np.float32)
|
||
|
mean_full = mean
|
||
|
components_full = data['hands_components'].astype(np.float32)
|
||
|
weight = np.diag(components_full @ components_full.T)
|
||
|
components = components_full[:num_pca_comps]
|
||
|
weight = weight[:num_pca_comps]
|
||
|
if use_flat_mean:
|
||
|
mean = np.zeros_like(mean)
|
||
|
return mean, components, weight, mean_full, components_full
|
||
|
|
||
|
class MANO(SMPLModel):
|
||
|
def __init__(self, cfg_hand, **kwargs):
|
||
|
super().__init__(**kwargs)
|
||
|
self.name = 'mano'
|
||
|
self.use_root_rot = False
|
||
|
mean, components, weight, mean_full, components_full = read_hand(kwargs['model_path'], **cfg_hand)
|
||
|
self.register_buffer('mean', to_tensor(mean, dtype=self.dtype))
|
||
|
self.register_buffer('components', to_tensor(components, dtype=self.dtype))
|
||
|
self.cfg_hand = cfg_hand
|
||
|
self.to(self.device)
|
||
|
if cfg_hand.use_pca:
|
||
|
self.NUM_POSES = cfg_hand.num_pca_comps
|
||
|
|
||
|
def extend_poses(self, poses, **kwargs):
|
||
|
if poses.shape[-1] == self.mean.shape[-1] + 3:
|
||
|
return poses
|
||
|
if self.cfg_hand.use_pca:
|
||
|
poses = poses @ self.components
|
||
|
if kwargs.get('pose2rot', True):
|
||
|
poses = super().extend_poses(poses+self.mean, **kwargs)
|
||
|
else:
|
||
|
poses = super().extend_poses(poses, **kwargs)
|
||
|
return poses
|
||
|
|
||
|
def jacobian_posesfull_poses(self, poses, poses_full):
|
||
|
if self.cfg_hand.use_pca:
|
||
|
jacobian = self.components.t()
|
||
|
zero_root = torch.zeros((3, poses.shape[-1]), dtype=poses.dtype, device=poses.device)
|
||
|
jacobian = torch.cat([zero_root, jacobian], dim=0)
|
||
|
else:
|
||
|
jacobian = super().jacobian_posesfull_poses(poses, poses_full)
|
||
|
return jacobian
|
||
|
class MANOLR(Model):
|
||
|
def __init__(self, model_path, regressor_path, cfg_hand, **kwargs):
|
||
|
super().__init__()
|
||
|
self.name = 'manolr'
|
||
|
keys = list(model_path.keys())
|
||
|
# stack 方式:(nframes, nhand x ndim)
|
||
|
self.keys = keys
|
||
|
modules_hand = {}
|
||
|
faces = []
|
||
|
v_template = []
|
||
|
cnt = 0
|
||
|
for key in keys:
|
||
|
modules_hand[key] = MANO(cfg_hand, model_path=model_path[key], regressor_path=regressor_path[key], **kwargs)
|
||
|
v_template.append(modules_hand[key].v_template.cpu().numpy())
|
||
|
faces.append(modules_hand[key].faces + cnt)
|
||
|
cnt += v_template[-1].shape[0]
|
||
|
self.device = modules_hand[key].device
|
||
|
self.dtype = modules_hand[key].dtype
|
||
|
if key == 'right':
|
||
|
modules_hand[key].shapedirs[:, 0] *= -1
|
||
|
modules_hand[key].j_shapedirs[:, 0] *= -1
|
||
|
self.faces = np.vstack(faces)
|
||
|
self.v_template = np.vstack(v_template)
|
||
|
self.modules_hand = nn.ModuleDict(modules_hand)
|
||
|
self.to(self.device)
|
||
|
|
||
|
def init_params(self, **kwargs):
|
||
|
param_all = {}
|
||
|
for key in self.keys:
|
||
|
param = self.modules_hand[key].init_params(**kwargs)
|
||
|
param_all[key] = param
|
||
|
if False:
|
||
|
params = {k: torch.cat([param_all[key][k] for key in self.keys], dim=-1) for k in param.keys()}
|
||
|
else:
|
||
|
params = {k: np.concatenate([param_all[key][k] for key in self.keys], axis=-1) for k in param.keys() if k != 'shapes'}
|
||
|
params['shapes'] = param_all['left']['shapes']
|
||
|
return params
|
||
|
|
||
|
def split(self, params):
|
||
|
params_split = {}
|
||
|
for imodel, model in enumerate(self.keys):
|
||
|
param_= params.copy()
|
||
|
for key in ['poses', 'shapes', 'Rh', 'Th']:
|
||
|
if key not in params.keys():continue
|
||
|
if key == 'shapes':
|
||
|
continue
|
||
|
shape = params[key].shape[-1]
|
||
|
start = shape//len(self.keys)*imodel
|
||
|
end = shape//len(self.keys)*(imodel+1)
|
||
|
param_[key] = params[key][:, start:end]
|
||
|
params_split[model] = param_
|
||
|
return params_split
|
||
|
|
||
|
def forward(self, **params):
|
||
|
params_split = self.split(params)
|
||
|
rets = []
|
||
|
for imodel, model in enumerate(self.keys):
|
||
|
ret = self.modules_hand[model](**params_split[model])
|
||
|
rets.append(ret)
|
||
|
if params.get('return_tensor', True):
|
||
|
rets = torch.cat(rets, dim=1)
|
||
|
else:
|
||
|
rets = np.concatenate(rets, axis=1)
|
||
|
return rets
|
||
|
|
||
|
def extend_poses(self, poses, **kwargs):
|
||
|
params_split = self.split({'poses': poses})
|
||
|
rets = []
|
||
|
for imodel, model in enumerate(self.keys):
|
||
|
poses = params_split[model]['poses']
|
||
|
poses = self.modules_hand[model].extend_poses(poses)
|
||
|
rets.append(poses)
|
||
|
poses = torch.cat(rets, dim=1)
|
||
|
return poses
|
||
|
|
||
|
def export_full_poses(self, poses, **kwargs):
|
||
|
params_split = self.split({'poses': poses})
|
||
|
rets = []
|
||
|
for imodel, model in enumerate(self.keys):
|
||
|
poses = torch.Tensor(params_split[model]['poses']).to(self.device)
|
||
|
poses = self.modules_hand[model].extend_poses(poses)
|
||
|
rets.append(poses)
|
||
|
poses = torch.cat(rets, dim=1)
|
||
|
return poses.detach().cpu().numpy()
|
||
|
|
||
|
class SMPLHModel(SMPLModel):
|
||
|
def __init__(self, mano_path, cfg_hand, **kwargs):
|
||
|
super().__init__(**kwargs)
|
||
|
self.NUM_POSES = self.NUM_POSES - 90
|
||
|
meanl, componentsl, weight_l, self.mean_full_l, self.components_full_l = read_hand(join(mano_path, 'MANO_LEFT.pkl'), **cfg_hand)
|
||
|
meanr, componentsr, weight_r, self.mean_full_r, self.components_full_r = read_hand(join(mano_path, 'MANO_RIGHT.pkl'), **cfg_hand)
|
||
|
self.register_buffer('weight_l', to_tensor(weight_l, dtype=self.dtype))
|
||
|
self.register_buffer('weight_r', to_tensor(weight_r, dtype=self.dtype))
|
||
|
self.register_buffer('meanl', to_tensor(meanl, dtype=self.dtype))
|
||
|
self.register_buffer('meanr', to_tensor(meanr, dtype=self.dtype))
|
||
|
self.register_buffer('componentsl', to_tensor(componentsl, dtype=self.dtype))
|
||
|
self.register_buffer('componentsr', to_tensor(componentsr, dtype=self.dtype))
|
||
|
|
||
|
self.register_buffer('jacobian_posesfull_poses_', self._jacobian_posesfull_poses())
|
||
|
self.NUM_HANDS = cfg_hand.num_pca_comps if cfg_hand.use_pca else 45
|
||
|
self.cfg_hand = cfg_hand
|
||
|
self.to(self.device)
|
||
|
|
||
|
def _jacobian_posesfull_poses(self):
|
||
|
# TODO: cache this
|
||
|
# | body_full/body | 0 | 0 |
|
||
|
# | 0 | l | 0 |
|
||
|
# | 0 | 0 | r |
|
||
|
eye_right = torch.eye(self.NUM_POSES, dtype=self.dtype)
|
||
|
#
|
||
|
jac_handl = self.componentsl.t()
|
||
|
jac_handr = self.componentsr.t()
|
||
|
output = torch.zeros((self.NUM_POSES_FULL, self.NUM_POSES+jac_handl.shape[1]*2), dtype=self.dtype)
|
||
|
if self.use_root_rot:
|
||
|
raise NotImplementedError
|
||
|
else:
|
||
|
output[3:3+self.NUM_POSES, :self.NUM_POSES] = eye_right
|
||
|
output[3+self.NUM_POSES:3+self.NUM_POSES+jac_handl.shape[0], \
|
||
|
self.NUM_POSES:self.NUM_POSES+jac_handl.shape[1]] = jac_handl
|
||
|
output[3+self.NUM_POSES+jac_handl.shape[0]:3+self.NUM_POSES+2*jac_handl.shape[0], \
|
||
|
self.NUM_POSES+jac_handl.shape[1]:self.NUM_POSES+jac_handl.shape[1]*2] = jac_handr
|
||
|
return output
|
||
|
|
||
|
def init_params(self, nFrames=1, nShapes=1, nPerson=1, ret_tensor=False, add_scale=False):
|
||
|
params = super().init_params(nFrames, nShapes, nPerson, ret_tensor, add_scale=add_scale)
|
||
|
handl = np.zeros((nFrames, self.NUM_HANDS))
|
||
|
handr = np.zeros((nFrames, self.NUM_HANDS))
|
||
|
if nPerson > 1:
|
||
|
handl = handl[:, None].repeat(nPerson, axis=1)
|
||
|
handr = handr[:, None].repeat(nPerson, axis=1)
|
||
|
if ret_tensor:
|
||
|
handl = to_tensor(handl, self.dtype, self.device)
|
||
|
handr = to_tensor(handr, self.dtype, self.device)
|
||
|
params['handl'] = handl
|
||
|
params['handr'] = handr
|
||
|
return params
|
||
|
|
||
|
def extend_poses(self, poses, handl=None, handr=None, **kwargs):
|
||
|
if poses.shape[-1] == self.NUM_POSES_FULL:
|
||
|
return poses
|
||
|
poses = super().extend_poses(poses)
|
||
|
if handl is None:
|
||
|
handl = self.meanl.clone()
|
||
|
handr = self.meanr.clone()
|
||
|
handl = handl.expand(poses.shape[0], -1)
|
||
|
handr = handr.expand(poses.shape[0], -1)
|
||
|
else:
|
||
|
if self.cfg_hand.use_pca:
|
||
|
handl = handl @ self.componentsl
|
||
|
handr = handr @ self.componentsr
|
||
|
handl = handl +self.meanl
|
||
|
handr = handr +self.meanr
|
||
|
poses = torch.cat([poses, handl, handr], dim=-1)
|
||
|
return poses
|
||
|
|
||
|
def export_full_poses(self, poses, handl, handr, **kwargs):
|
||
|
poses = torch.Tensor(poses).to(self.device)
|
||
|
handl = torch.Tensor(handl).to(self.device)
|
||
|
handr = torch.Tensor(handr).to(self.device)
|
||
|
poses = self.extend_poses(poses, handl, handr)
|
||
|
return poses.detach().cpu().numpy()
|
||
|
|
||
|
class SMPLHModelEmbedding(SMPLHModel):
|
||
|
def __init__(self, vposer_ckpt='data/body_models/vposer_v02', **kwargs):
|
||
|
super().__init__(**kwargs)
|
||
|
from human_body_prior.tools.model_loader import load_model
|
||
|
from human_body_prior.models.vposer_model import VPoser
|
||
|
vposer, _ = load_model(vposer_ckpt,
|
||
|
model_code=VPoser,
|
||
|
remove_words_in_model_weights='vp_model.',
|
||
|
disable_grad=True)
|
||
|
vposer.to(self.device)
|
||
|
self.vposer = vposer
|
||
|
self.vposer_dim = 32
|
||
|
self.NUM_POSES = self.vposer_dim
|
||
|
|
||
|
def decode(self, poses, add_rot=True):
|
||
|
if poses.shape[-1] == 66 and add_rot:
|
||
|
return poses
|
||
|
elif poses.shape[-1] == 63 and not add_rot:
|
||
|
return poses
|
||
|
assert poses.shape[-1] == self.vposer_dim, poses.shape
|
||
|
ret = self.vposer.decode(poses)
|
||
|
poses_body = ret['pose_body'].reshape(poses.shape[0], -1)
|
||
|
if add_rot:
|
||
|
zero_rot = torch.zeros((poses.shape[0], 3), dtype=poses.dtype, device=poses.device)
|
||
|
poses_body = torch.cat([zero_rot, poses_body], dim=-1)
|
||
|
return poses_body
|
||
|
|
||
|
def extend_poses(self, poses, handl, handr, **kwargs):
|
||
|
if poses.shape[-1] == self.NUM_POSES_FULL:
|
||
|
return poses
|
||
|
zero_rot = torch.zeros((poses.shape[0], 3), dtype=poses.dtype, device=poses.device)
|
||
|
poses_body = self.decode(poses, add_rot=False)
|
||
|
if self.cfg_hand.use_pca:
|
||
|
handl = handl @ self.componentsl
|
||
|
handr = handr @ self.componentsr
|
||
|
handl = handl +self.meanl
|
||
|
handr = handr +self.meanr
|
||
|
poses = torch.cat([zero_rot, poses_body, handl, handr], dim=-1)
|
||
|
return poses
|
||
|
|
||
|
def export_full_poses(self, poses, handl, handr, **kwargs):
|
||
|
poses = torch.Tensor(poses).to(self.device)
|
||
|
handl = torch.Tensor(handl).to(self.device)
|
||
|
handr = torch.Tensor(handr).to(self.device)
|
||
|
poses = self.extend_poses(poses, handl, handr)
|
||
|
return poses.detach().cpu().numpy()
|