'''
  @ Date: 2020-11-18 14:04:10
  @ Author: Qing Shuai
  @ LastEditors: Qing Shuai
  @ LastEditTime: 2021-08-28 16:37:55
  @ FilePath: /EasyMocap/easymocap/smplmodel/body_model.py
'''
import torch
import torch.nn as nn
from .lbs import batch_rodrigues
from .lbs import lbs, dqs
import os.path as osp
import pickle
import numpy as np
import os

def to_tensor(array, dtype=torch.float32, device=torch.device('cpu')):
    # convert array-like inputs to a tensor; existing tensors are only moved to the target device
    if not torch.is_tensor(array):
        return torch.tensor(array, dtype=dtype).to(device)
    else:
        return array.to(device)

def to_np(array, dtype=np.float32):
    if 'scipy.sparse' in str(type(array)):
        array = array.todense()
    return np.array(array, dtype=dtype)

def load_regressor(regressor_path):
    if regressor_path.endswith('.npy'):
        X_regressor = to_tensor(np.load(regressor_path))
    elif regressor_path.endswith('.txt'):
        data = np.loadtxt(regressor_path)
        with open(regressor_path, 'r') as f:
            shape = f.readline().split()[1:]
        reg = np.zeros((int(shape[0]), int(shape[1])))
        for i, j, v in data:
            reg[int(i), int(j)] = v
        X_regressor = to_tensor(reg)
    else:
        raise ValueError('Unknown regressor format: {}'.format(regressor_path))
    return X_regressor

def load_bodydata(model_type, model_path, gender):
    if osp.isdir(model_path):
        model_fn = '{}_{}.{ext}'.format(model_type.upper(), gender.upper(), ext='pkl')
        smpl_path = osp.join(model_path, model_fn)
    else:
        smpl_path = model_path
    assert osp.exists(smpl_path), 'Path {} does not exist!'.format(
        smpl_path)

    with open(smpl_path, 'rb') as smpl_file:
        data = pickle.load(smpl_file, encoding='latin1')
    return data

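# number of pose parameters per model type; the smplx count is 66 (body) + 2 x 6 (hand PCA)
# + 9 (jaw, leye, reye) = 87, see extend_pose below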
NUM_POSES = {'smpl': 72, 'smplh': 78, 'smplx': 66 + 12 + 9, 'mano': 9}
NUM_SHAPES = 10
NUM_EXPR = 10
class SMPLlayer(nn.Module):
    def __init__(self, model_path, model_type='smpl', gender='neutral', device=None,
                 regressor_path=None,
                 use_pose_blending=True, use_shape_blending=True, use_joints=True,
                 with_color=False, use_lbs=True,
                 **kwargs) -> None:
        super(SMPLlayer, self).__init__()
        dtype = torch.float32
        self.dtype = dtype
        self.use_pose_blending = use_pose_blending
        self.use_shape_blending = use_shape_blending
        self.use_joints = use_joints

        if isinstance(device, str):
            device = torch.device(device)
        self.device = device
        self.model_type = model_type
        self.NUM_POSES = NUM_POSES[model_type]
        # create the SMPL model
        if use_lbs:
            self.lbs = lbs
        else:
            self.lbs = dqs
        data = load_bodydata(model_type, model_path, gender)
        if with_color:
            self.color = data['vertex_colors']
        else:
            self.color = None
        self.faces = data['f']
        self.register_buffer('faces_tensor',
                             to_tensor(to_np(self.faces, dtype=np.int64),
                                       dtype=torch.long))
        for key in ['J_regressor', 'v_template', 'weights']:
            val = to_tensor(to_np(data[key]), dtype=dtype)
            self.register_buffer(key, val)
        # add pose blending
        if use_pose_blending:
            # Pose blend shape basis: 6890 x 3 x 207, reshaped to 6890*3 x 207
            num_pose_basis = data['posedirs'].shape[-1]
            # 207 x 20670
            posedirs = data['posedirs']
            data['posedirs'] = np.reshape(data['posedirs'], [-1, num_pose_basis]).T
            val = to_tensor(to_np(data['posedirs']), dtype=dtype)
            self.register_buffer('posedirs', val)
        else:
            self.posedirs = None
        # add shape blending
        if use_shape_blending:
            val = to_tensor(to_np(data['shapedirs']), dtype=dtype)
            self.register_buffer('shapedirs', val)
        else:
            self.shapedirs = None
        if use_shape_blending:
            self.J_shaped = None
        else:
            val = to_tensor(to_np(data['J']), dtype=dtype)
            self.register_buffer('J_shaped', val)

        self.nVertices = self.v_template.shape[0]
        # index of the parent for each joint
        parents = to_tensor(to_np(data['kintree_table'][0])).long()
        parents[0] = -1
        self.register_buffer('parents', parents)

        if self.use_shape_blending:
            if self.model_type == 'smplx':
                # shape
                self.num_expression_coeffs = 10
                self.num_shapes = 10
                self.shapedirs = self.shapedirs[:, :, :self.num_shapes+self.num_expression_coeffs]
            elif self.model_type in ['smpl', 'smplh']:
                self.shapedirs = self.shapedirs[:, :, :NUM_SHAPES]
        # joints regressor
        if regressor_path is not None and use_joints:
            X_regressor = load_regressor(regressor_path)
            X_regressor = torch.cat((self.J_regressor, X_regressor), dim=0)

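            # build extended buffers so that lbs can regress the extra keypoints directly:
            # j_v_template, j_weights, j_posedirs and j_shapedirs are the stacked regressor
            # X_regressor applied to the corresponding mesh-level quantities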
            j_J_regressor = torch.zeros(self.J_regressor.shape[0], X_regressor.shape[0], device=device)
            for i in range(self.J_regressor.shape[0]):
                j_J_regressor[i, i] = 1
            j_v_template = X_regressor @ self.v_template
            # skinning weights of the regressed keypoints, e.g. (25, 24)
            j_weights = X_regressor @ self.weights
            if self.use_pose_blending:
                j_posedirs = torch.einsum('ab, bde->ade', [X_regressor, torch.Tensor(posedirs)]).numpy()
                j_posedirs = np.reshape(j_posedirs, [-1, num_pose_basis]).T
                j_posedirs = to_tensor(j_posedirs)
                self.register_buffer('j_posedirs', j_posedirs)
            else:
                self.j_posedirs = None
            if self.use_shape_blending:
                j_shapedirs = torch.einsum('vij,kv->kij', [self.shapedirs, X_regressor])
                self.register_buffer('j_shapedirs', j_shapedirs)
            else:
                self.j_shapedirs = None
            self.register_buffer('j_weights', j_weights)
            self.register_buffer('j_v_template', j_v_template)
            self.register_buffer('j_J_regressor', j_J_regressor)
        if self.model_type == 'smplh':
            # load smplh data
            self.num_pca_comps = kwargs['num_pca_comps']
            from os.path import join
            for key in ['LEFT', 'RIGHT']:
                left_file = join(kwargs['mano_path'], 'MANO_{}.pkl'.format(key))
                with open(left_file, 'rb') as f:
                    data = pickle.load(f, encoding='latin1')
                val = to_tensor(to_np(data['hands_mean'].reshape(1, -1)), dtype=dtype)
                self.register_buffer('mHandsMean'+key[0], val)
                val = to_tensor(to_np(data['hands_components'][:self.num_pca_comps, :]), dtype=dtype)
                self.register_buffer('mHandsComponents'+key[0], val)
            self.use_pca = kwargs['use_pca']
            self.use_flat_mean = kwargs['use_flat_mean']
            if self.use_pca:
                self.NUM_POSES = 66 + self.num_pca_comps * 2
            else:
                self.NUM_POSES = 66 + 15 * 3 * 2
        elif self.model_type == 'mano':
            self.num_pca_comps = kwargs['num_pca_comps']
            self.use_pca = kwargs['use_pca']
            self.use_flat_mean = kwargs['use_flat_mean']
            if self.use_pca:
                self.NUM_POSES = self.num_pca_comps + 3
            else:
                self.NUM_POSES = 45 + 3

            val = to_tensor(to_np(data['hands_mean'].reshape(1, -1)), dtype=dtype)
            self.register_buffer('mHandsMean', val)
            val = to_tensor(to_np(data['hands_components'][:self.num_pca_comps, :]), dtype=dtype)
            self.register_buffer('mHandsComponents', val)
        elif self.model_type == 'smplx':
            # hand pose
            self.num_pca_comps = 6
            from os.path import join
            for key in ['Ll', 'Rr']:
                val = to_tensor(to_np(data['hands_mean'+key[1]].reshape(1, -1)), dtype=dtype)
                self.register_buffer('mHandsMean'+key[0], val)
                val = to_tensor(to_np(data['hands_components'+key[1]][:self.num_pca_comps, :]), dtype=dtype)
                self.register_buffer('mHandsComponents'+key[0], val)
            self.use_pca = True
            self.use_flat_mean = True
        self.to(self.device)

    @staticmethod
    def extend_hand(poses, use_pca, use_flat_mean, coeffs, mean):
        if use_pca:
            poses = poses @ coeffs
        if not use_flat_mean:
            poses = poses + mean
        return poses

    def extend_pose(self, poses):
        # skip SMPL, or poses that are already extended to the full length
        if self.model_type not in ['smplh', 'smplx', 'mano']:
            return poses
        elif self.model_type == 'smplh' and poses.shape[-1] == 156 and self.use_flat_mean:
            return poses
        elif self.model_type == 'smplx' and poses.shape[-1] == 165 and self.use_flat_mean:
            return poses
        elif self.model_type == 'mano' and poses.shape[-1] == 48 and self.use_flat_mean:
            return poses
        # mano: global orientation + hand pose only, handled separately
        if self.model_type == 'mano':
            poses_hand = self.extend_hand(poses[..., 3:], self.use_pca, self.use_flat_mean,
                                          self.mHandsComponents, self.mHandsMean)
            poses = torch.cat([poses[..., :3], poses_hand], dim=-1)
            return poses
        NUM_BODYJOINTS = 22 * 3
        if self.use_pca:
            NUM_HANDJOINTS = self.num_pca_comps
        else:
            NUM_HANDJOINTS = 15 * 3
        NUM_FACEJOINTS = 3 * 3
        poses_lh = poses[:, NUM_BODYJOINTS:NUM_BODYJOINTS + NUM_HANDJOINTS]
        poses_rh = poses[:, NUM_BODYJOINTS + NUM_HANDJOINTS:NUM_BODYJOINTS+NUM_HANDJOINTS*2]
        if self.use_pca:
            poses_lh = poses_lh @ self.mHandsComponentsL
            poses_rh = poses_rh @ self.mHandsComponentsR
        if not self.use_flat_mean:
            poses_lh = poses_lh + self.mHandsMeanL
            poses_rh = poses_rh + self.mHandsMeanR
        if self.model_type == 'smplh':
            poses = torch.cat([poses[:, :NUM_BODYJOINTS], poses_lh, poses_rh], dim=1)
        elif self.model_type == 'smplx':
            # the head part has only three joints
            # poses_head: (N, 9), jaw_pose, leye_pose, reye_pose respectively
            poses_head = poses[:, NUM_BODYJOINTS+NUM_HANDJOINTS*2:]
            # body, head, left hand, right hand
            poses = torch.cat([poses[:, :NUM_BODYJOINTS], poses_head, poses_lh, poses_rh], dim=1)
        return poses

    def get_root(self, poses, shapes, return_tensor=False):
        if 'torch' not in str(type(poses)):
            dtype, device = self.dtype, self.device
            poses = to_tensor(poses, dtype, device)
            shapes = to_tensor(shapes, dtype, device)
        vertices, joints = lbs(shapes, poses, self.v_template,
                               self.shapedirs, self.posedirs,
                               self.J_regressor, self.parents,
                               self.weights, pose2rot=True, dtype=self.dtype, only_shape=True)
        # N x 3
        j0 = joints[:, 0, :]
        if not return_tensor:
            j0 = j0.detach().cpu().numpy()
        return j0

    def convert_from_standard_smpl(self, poses, shapes, Rh=None, Th=None, expression=None):
        if 'torch' not in str(type(poses)):
            dtype, device = self.dtype, self.device
            poses = to_tensor(poses, dtype, device)
            shapes = to_tensor(shapes, dtype, device)
            Rh = to_tensor(Rh, dtype, device)
            Th = to_tensor(Th, dtype, device)
            if expression is not None:
                expression = to_tensor(expression, dtype, device)

        bn = poses.shape[0]
        # process shapes
        if shapes.shape[0] < bn:
            shapes = shapes.expand(bn, -1)
        vertices, joints = lbs(shapes, poses, self.v_template,
                               self.shapedirs, self.posedirs,
                               self.J_regressor, self.parents,
                               self.weights, pose2rot=True, dtype=self.dtype, only_shape=True)
        # N x 3
        j0 = joints[:, 0, :]
        Rh = poses[:, :3].clone()
        # N x 3 x 3
        rot = batch_rodrigues(Rh)
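        # the axis-angle in poses[:, :3] rotates the body around the root joint j0, while Rh/Th
        # rotate around the origin, so the translation has to be compensated accordingly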
        Tnew = Th + j0 - torch.einsum('bij,bj->bi', rot, j0)
        poses[:, :3] = 0
        res = dict(poses=poses.detach().cpu().numpy(),
                   shapes=shapes.detach().cpu().numpy(),
                   Rh=Rh.detach().cpu().numpy(),
                   Th=Tnew.detach().cpu().numpy()
                   )
        return res

    def full_poses(self, poses):
        if 'torch' not in str(type(poses)):
            dtype, device = self.dtype, self.device
            poses = to_tensor(poses, dtype, device)
        poses = self.extend_pose(poses)
        return poses.detach().cpu().numpy()

    def forward(self, poses, shapes, Rh=None, Th=None, expression=None,
                v_template=None,
                return_verts=True, return_tensor=True, return_smpl_joints=False,
                only_shape=False, pose2rot=True, **kwargs):
        """ Forward pass for the SMPL model

        Args:
            poses (n, 72)
            shapes (n, 10)
            Rh (n, 3): global orientation
            Th (n, 3): global translation
            return_verts (bool, optional): if True, return the vertices (6890, 3). Defaults to True.
        """
        if 'torch' not in str(type(poses)):
            dtype, device = self.dtype, self.device
            poses = to_tensor(poses, dtype, device)
            shapes = to_tensor(shapes, dtype, device)
            if Rh is not None:
                Rh = to_tensor(Rh, dtype, device)
            if Th is not None:
                Th = to_tensor(Th, dtype, device)
            if expression is not None:
                expression = to_tensor(expression, dtype, device)

        bn = poses.shape[0]
        # process Rh, Th
        if Rh is None:
            Rh = torch.zeros(bn, 3, device=poses.device)
        if Th is None:
            Th = torch.zeros(bn, 3, device=poses.device)

        if len(Rh.shape) == 2: # angle-axis
            rot = batch_rodrigues(Rh)
        else:
            rot = Rh
        transl = Th.unsqueeze(dim=1)
        # process shapes
        if shapes.shape[0] < bn:
            shapes = shapes.expand(bn, -1)
        if expression is not None and self.model_type == 'smplx':
            shapes = torch.cat([shapes, expression], dim=1)
        # process poses
        if pose2rot: # if rotation matrices are given, this expansion is not needed
            poses = self.extend_pose(poses)
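        # two evaluation paths: the full template regresses mesh vertices, while the extended
        # j_* buffers built in __init__ regress the keypoints defined by regressor_path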
        if return_verts or not self.use_joints:
            if v_template is None:
                v_template = self.v_template
            vertices, joints = self.lbs(shapes, poses, v_template,
                                        self.shapedirs, self.posedirs,
                                        self.J_regressor, self.parents,
                                        self.weights, pose2rot=pose2rot, dtype=self.dtype,
                                        use_pose_blending=self.use_pose_blending, use_shape_blending=self.use_shape_blending, J_shaped=self.J_shaped)
            if not self.use_joints and not return_verts:
                vertices = joints
        else:
            vertices, joints = self.lbs(shapes, poses, self.j_v_template,
                                        self.j_shapedirs, self.j_posedirs,
                                        self.j_J_regressor, self.parents,
                                        self.j_weights, pose2rot=pose2rot, dtype=self.dtype, only_shape=only_shape,
                                        use_pose_blending=self.use_pose_blending, use_shape_blending=self.use_shape_blending, J_shaped=self.J_shaped)
            if return_smpl_joints:
                vertices = vertices[:, :self.J_regressor.shape[0], :]
            else:
                vertices = vertices[:, self.J_regressor.shape[0]:, :]
        vertices = torch.matmul(vertices, rot.transpose(1, 2)) + transl
        if not return_tensor:
            vertices = vertices.detach().cpu().numpy()
        return vertices

    def init_params(self, nFrames=1, nShapes=1, ret_tensor=False):
        params = {
            'poses': np.zeros((nFrames, self.NUM_POSES)),
            'shapes': np.zeros((nShapes, NUM_SHAPES)),
            'Rh': np.zeros((nFrames, 3)),
            'Th': np.zeros((nFrames, 3)),
        }
        if self.model_type == 'smplx':
            params['expression'] = np.zeros((nFrames, NUM_EXPR))
        if ret_tensor:
            for key in params.keys():
                params[key] = to_tensor(params[key], self.dtype, self.device)
        return params

    def check_params(self, body_params):
        model_type = self.model_type
        nFrames = body_params['poses'].shape[0]
        if body_params['poses'].shape[1] != self.NUM_POSES:
            body_params['poses'] = np.hstack((body_params['poses'], np.zeros((nFrames, self.NUM_POSES - body_params['poses'].shape[1]))))
        if model_type == 'smplx' and 'expression' not in body_params.keys():
            body_params['expression'] = np.zeros((nFrames, NUM_EXPR))
        return body_params

    @staticmethod
    def merge_params(param_list, share_shape=True):
        output = {}
        for key in ['poses', 'shapes', 'Rh', 'Th', 'expression']:
            if key in param_list[0].keys():
                output[key] = np.vstack([v[key] for v in param_list])
        if share_shape:
            output['shapes'] = output['shapes'].mean(axis=0, keepdims=True)
        # add other keys
        for key in param_list[0].keys():
            if key in output.keys():
                continue
            output[key] = np.stack([v[key] for v in param_list])
        return output

    @staticmethod
    def select_nf(params_all, nf):
        output = {}
        for key in ['poses', 'Rh', 'Th']:
            output[key] = params_all[key][nf:nf+1, :]
        if 'expression' in params_all.keys():
            output['expression'] = params_all['expression'][nf:nf+1, :]
        if params_all['shapes'].shape[0] == 1:
            output['shapes'] = params_all['shapes']
        else:
            output['shapes'] = params_all['shapes'][nf:nf+1, :]
        return output
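
if __name__ == '__main__':
    # Minimal usage sketch, not part of the original pipeline: build a neutral SMPL layer and
    # run a zero-pose forward pass. The model path below is a hypothetical location of
    # SMPL_NEUTRAL.pkl; adjust it to your setup and run via
    # `python -m easymocap.smplmodel.body_model` (the relative imports require module mode).
    model_path = 'data/smplx/smpl'
    body_model = SMPLlayer(model_path, model_type='smpl', gender='neutral', device='cpu')
    params = body_model.init_params(nFrames=1)
    verts = body_model(return_verts=True, return_tensor=False, **params)
    print(verts.shape)  # expected (1, 6890, 3) for SMPL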