EasyMocap/easymocap/bodymodel/smplx.py

260 lines
11 KiB
Python
Raw Normal View History

2022-08-21 16:02:11 +08:00
import torch
import torch.nn as nn
from .base import Model
from .smpl import SMPLModel, SMPLLayerEmbedding, read_pickle, to_tensor
from os.path import join
import numpy as np
def read_hand(path, use_pca, use_flat_mean, num_pca_comps):
data = read_pickle(path)
mean = data['hands_mean'].reshape(1, -1).astype(np.float32)
mean_full = mean
components_full = data['hands_components'].astype(np.float32)
weight = np.diag(components_full @ components_full.T)
components = components_full[:num_pca_comps]
weight = weight[:num_pca_comps]
if use_flat_mean:
mean = np.zeros_like(mean)
return mean, components, weight, mean_full, components_full
class MANO(SMPLModel):
def __init__(self, cfg_hand, **kwargs):
super().__init__(**kwargs)
self.name = 'mano'
self.use_root_rot = False
mean, components, weight, mean_full, components_full = read_hand(kwargs['model_path'], **cfg_hand)
self.register_buffer('mean', to_tensor(mean, dtype=self.dtype))
self.register_buffer('components', to_tensor(components, dtype=self.dtype))
self.cfg_hand = cfg_hand
self.to(self.device)
if cfg_hand.use_pca:
self.NUM_POSES = cfg_hand.num_pca_comps
def extend_poses(self, poses, **kwargs):
if poses.shape[-1] == self.mean.shape[-1] + 3:
return poses
if self.cfg_hand.use_pca:
poses = poses @ self.components
if kwargs.get('pose2rot', True):
poses = super().extend_poses(poses+self.mean, **kwargs)
else:
poses = super().extend_poses(poses, **kwargs)
return poses
def jacobian_posesfull_poses(self, poses, poses_full):
if self.cfg_hand.use_pca:
jacobian = self.components.t()
zero_root = torch.zeros((3, poses.shape[-1]), dtype=poses.dtype, device=poses.device)
jacobian = torch.cat([zero_root, jacobian], dim=0)
else:
jacobian = super().jacobian_posesfull_poses(poses, poses_full)
return jacobian
class MANOLR(Model):
def __init__(self, model_path, regressor_path, cfg_hand, **kwargs):
super().__init__()
self.name = 'manolr'
keys = list(model_path.keys())
# stack 方式:(nframes, nhand x ndim)
self.keys = keys
modules_hand = {}
faces = []
v_template = []
cnt = 0
for key in keys:
modules_hand[key] = MANO(cfg_hand, model_path=model_path[key], regressor_path=regressor_path[key], **kwargs)
v_template.append(modules_hand[key].v_template.cpu().numpy())
faces.append(modules_hand[key].faces + cnt)
cnt += v_template[-1].shape[0]
self.device = modules_hand[key].device
self.dtype = modules_hand[key].dtype
if key == 'right':
modules_hand[key].shapedirs[:, 0] *= -1
modules_hand[key].j_shapedirs[:, 0] *= -1
self.faces = np.vstack(faces)
self.v_template = np.vstack(v_template)
self.modules_hand = nn.ModuleDict(modules_hand)
self.to(self.device)
def init_params(self, **kwargs):
param_all = {}
for key in self.keys:
param = self.modules_hand[key].init_params(**kwargs)
param_all[key] = param
if False:
params = {k: torch.cat([param_all[key][k] for key in self.keys], dim=-1) for k in param.keys()}
else:
params = {k: np.concatenate([param_all[key][k] for key in self.keys], axis=-1) for k in param.keys() if k != 'shapes'}
params['shapes'] = param_all['left']['shapes']
return params
def split(self, params):
params_split = {}
for imodel, model in enumerate(self.keys):
param_= params.copy()
for key in ['poses', 'shapes', 'Rh', 'Th']:
if key not in params.keys():continue
if key == 'shapes':
continue
shape = params[key].shape[-1]
start = shape//len(self.keys)*imodel
end = shape//len(self.keys)*(imodel+1)
param_[key] = params[key][:, start:end]
params_split[model] = param_
return params_split
def forward(self, **params):
params_split = self.split(params)
rets = []
for imodel, model in enumerate(self.keys):
ret = self.modules_hand[model](**params_split[model])
rets.append(ret)
if params.get('return_tensor', True):
rets = torch.cat(rets, dim=1)
else:
rets = np.concatenate(rets, axis=1)
return rets
def extend_poses(self, poses, **kwargs):
params_split = self.split({'poses': poses})
rets = []
for imodel, model in enumerate(self.keys):
poses = params_split[model]['poses']
poses = self.modules_hand[model].extend_poses(poses)
rets.append(poses)
poses = torch.cat(rets, dim=1)
return poses
def export_full_poses(self, poses, **kwargs):
params_split = self.split({'poses': poses})
rets = []
for imodel, model in enumerate(self.keys):
poses = torch.Tensor(params_split[model]['poses']).to(self.device)
poses = self.modules_hand[model].extend_poses(poses)
rets.append(poses)
poses = torch.cat(rets, dim=1)
return poses.detach().cpu().numpy()
class SMPLHModel(SMPLModel):
def __init__(self, mano_path, cfg_hand, **kwargs):
super().__init__(**kwargs)
self.NUM_POSES = self.NUM_POSES - 90
meanl, componentsl, weight_l, self.mean_full_l, self.components_full_l = read_hand(join(mano_path, 'MANO_LEFT.pkl'), **cfg_hand)
meanr, componentsr, weight_r, self.mean_full_r, self.components_full_r = read_hand(join(mano_path, 'MANO_RIGHT.pkl'), **cfg_hand)
self.register_buffer('weight_l', to_tensor(weight_l, dtype=self.dtype))
self.register_buffer('weight_r', to_tensor(weight_r, dtype=self.dtype))
self.register_buffer('meanl', to_tensor(meanl, dtype=self.dtype))
self.register_buffer('meanr', to_tensor(meanr, dtype=self.dtype))
self.register_buffer('componentsl', to_tensor(componentsl, dtype=self.dtype))
self.register_buffer('componentsr', to_tensor(componentsr, dtype=self.dtype))
self.register_buffer('jacobian_posesfull_poses_', self._jacobian_posesfull_poses())
self.NUM_HANDS = cfg_hand.num_pca_comps if cfg_hand.use_pca else 45
self.cfg_hand = cfg_hand
self.to(self.device)
def _jacobian_posesfull_poses(self):
# TODO: cache this
# | body_full/body | 0 | 0 |
# | 0 | l | 0 |
# | 0 | 0 | r |
eye_right = torch.eye(self.NUM_POSES, dtype=self.dtype)
#
jac_handl = self.componentsl.t()
jac_handr = self.componentsr.t()
output = torch.zeros((self.NUM_POSES_FULL, self.NUM_POSES+jac_handl.shape[1]*2), dtype=self.dtype)
if self.use_root_rot:
raise NotImplementedError
else:
output[3:3+self.NUM_POSES, :self.NUM_POSES] = eye_right
output[3+self.NUM_POSES:3+self.NUM_POSES+jac_handl.shape[0], \
self.NUM_POSES:self.NUM_POSES+jac_handl.shape[1]] = jac_handl
output[3+self.NUM_POSES+jac_handl.shape[0]:3+self.NUM_POSES+2*jac_handl.shape[0], \
self.NUM_POSES+jac_handl.shape[1]:self.NUM_POSES+jac_handl.shape[1]*2] = jac_handr
return output
def init_params(self, nFrames=1, nShapes=1, nPerson=1, ret_tensor=False, add_scale=False):
params = super().init_params(nFrames, nShapes, nPerson, ret_tensor, add_scale=add_scale)
handl = np.zeros((nFrames, self.NUM_HANDS))
handr = np.zeros((nFrames, self.NUM_HANDS))
if nPerson > 1:
handl = handl[:, None].repeat(nPerson, axis=1)
handr = handr[:, None].repeat(nPerson, axis=1)
if ret_tensor:
handl = to_tensor(handl, self.dtype, self.device)
handr = to_tensor(handr, self.dtype, self.device)
params['handl'] = handl
params['handr'] = handr
return params
def extend_poses(self, poses, handl=None, handr=None, **kwargs):
if poses.shape[-1] == self.NUM_POSES_FULL:
return poses
poses = super().extend_poses(poses)
if handl is None:
handl = self.meanl.clone()
handr = self.meanr.clone()
handl = handl.expand(poses.shape[0], -1)
handr = handr.expand(poses.shape[0], -1)
else:
if self.cfg_hand.use_pca:
handl = handl @ self.componentsl
handr = handr @ self.componentsr
handl = handl +self.meanl
handr = handr +self.meanr
poses = torch.cat([poses, handl, handr], dim=-1)
return poses
def export_full_poses(self, poses, handl, handr, **kwargs):
poses = torch.Tensor(poses).to(self.device)
handl = torch.Tensor(handl).to(self.device)
handr = torch.Tensor(handr).to(self.device)
poses = self.extend_poses(poses, handl, handr)
return poses.detach().cpu().numpy()
class SMPLHModelEmbedding(SMPLHModel):
def __init__(self, vposer_ckpt='data/body_models/vposer_v02', **kwargs):
super().__init__(**kwargs)
from human_body_prior.tools.model_loader import load_model
from human_body_prior.models.vposer_model import VPoser
vposer, _ = load_model(vposer_ckpt,
model_code=VPoser,
remove_words_in_model_weights='vp_model.',
disable_grad=True)
vposer.to(self.device)
self.vposer = vposer
self.vposer_dim = 32
self.NUM_POSES = self.vposer_dim
def decode(self, poses, add_rot=True):
if poses.shape[-1] == 66 and add_rot:
return poses
elif poses.shape[-1] == 63 and not add_rot:
return poses
assert poses.shape[-1] == self.vposer_dim, poses.shape
ret = self.vposer.decode(poses)
poses_body = ret['pose_body'].reshape(poses.shape[0], -1)
if add_rot:
zero_rot = torch.zeros((poses.shape[0], 3), dtype=poses.dtype, device=poses.device)
poses_body = torch.cat([zero_rot, poses_body], dim=-1)
return poses_body
def extend_poses(self, poses, handl, handr, **kwargs):
if poses.shape[-1] == self.NUM_POSES_FULL:
return poses
zero_rot = torch.zeros((poses.shape[0], 3), dtype=poses.dtype, device=poses.device)
poses_body = self.decode(poses, add_rot=False)
if self.cfg_hand.use_pca:
handl = handl @ self.componentsl
handr = handr @ self.componentsr
handl = handl +self.meanl
handr = handr +self.meanr
poses = torch.cat([zero_rot, poses_body, handl, handr], dim=-1)
return poses
def export_full_poses(self, poses, handl, handr, **kwargs):
poses = torch.Tensor(poses).to(self.device)
handl = torch.Tensor(handl).to(self.device)
handr = torch.Tensor(handr).to(self.device)
poses = self.extend_poses(poses, handl, handr)
return poses.detach().cpu().numpy()