diff --git a/easymocap/bodymodel/base.py b/easymocap/bodymodel/base.py new file mode 100644 index 0000000..3916e4a --- /dev/null +++ b/easymocap/bodymodel/base.py @@ -0,0 +1,135 @@ +''' + @ Date: 2022-03-17 19:23:59 + @ Author: Qing Shuai + @ Mail: s_q@zju.edu.cn + @ LastEditors: Qing Shuai + @ LastEditTime: 2022-07-15 12:15:46 + @ FilePath: /EasyMocapPublic/easymocap/bodymodel/base.py +''' +import numpy as np +import torch + +from ..mytools.file_utils import myarray2string + +class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.name = 'custom' + + def forward(self): + pass + + def vertices(self, params, **kwargs): + return self.forward(return_verts=True, **kwargs, **params) + + def keypoints(self, params, **kwargs): + return self.forward(return_verts=False, **kwargs, **params) + + def transform(self, params, **kwargs): + raise NotImplementedError + +class ComposedModel(torch.nn.Module): + def __init__(self, config_dict): + # 叠加多个模型的配置 + for name, config in config_dict.items(): + pass + +class Params(dict): + @classmethod + def merge(self, params_list, share_shape=True, stack=np.vstack): + output = {} + for key in params_list[0].keys(): + if key == 'id':continue + output[key] = stack([v[key] for v in params_list]) + if share_shape: + output['shapes'] = output['shapes'].mean(axis=0, keepdims=True) + return output + + def __len__(self): + return len(self['poses']) + + def __getattr__(self, name): + if name in self: + return self[name] + else: + raise AttributeError(name) + + def __getitem__(self, index): + if not isinstance(index, int): + return super().__getitem__(index) + if 'shapes' not in self.keys(): + # arbitray data + ret = {} + for key, val in self.items(): + if index >= 1 and val.shape[0] == 1: + ret[key] = val[0] + else: + ret[key] = val[index] + return Params(**ret) + ret = {'id': 0} + poses = self.poses + shapes = self.shapes + while len(shapes.shape) < len(poses.shape): + shapes = shapes[None] + if poses.shape[0] == shapes.shape[0]: + if index >= 1 and shapes.shape[0] == 1: + ret['shapes'] = shapes[0] + else: + ret['shapes'] = shapes[index] + elif shapes.shape[0] == 1: + ret['shapes'] = shapes[0] + else: + import ipdb; ipdb.set_trace() + if index >= 1 and poses.shape[0] == 1: + ret['poses'] = poses[0] + else: + ret['poses'] = poses[index] + for key, val in self.items(): + if key == 'id': + ret[key] = self[key] + continue + if key in ret.keys():continue + if index >= 1 and val.shape[0] == 1: + ret[key] = val[0] + else: + ret[key] = val[index] + for key, val in ret.items(): + if key == 'id': continue + if len(val.shape) == 1: + ret[key] = val[None] + return Params(**ret) + + def to_multiperson(self, pids): + results = [] + for i, pid in enumerate(pids): + param = self[i] + # TODO: this class just implement getattr + # param.id = pid # is wrong + param['id'] = pid + results.append(param) + return results + + def __str__(self) -> str: + ret = '' + lastkey = list(self.keys())[-1] + for key, val in self.items(): + if isinstance(val, np.ndarray): + ret += '"{}": '.format(key) + myarray2string(val, indent=0) + else: + ret += '"{}": '.format(key) + str(val) + if key != lastkey: + ret += ',\n' + return ret + + def shape(self): + ret = '' + lastkey = list(self.keys())[-1] + for key, val in self.items(): + if isinstance(val, np.ndarray): + ret += '"{}": {}'.format(key, val.shape) + else: + ret += '"{}": '.format(key) + str(val) + if key != lastkey: + ret += ',\n' + print(ret) + return ret \ No newline at end of file diff --git a/easymocap/bodymodel/lbs.py b/easymocap/bodymodel/lbs.py new file mode 100644 index 0000000..9de7842 --- /dev/null +++ b/easymocap/bodymodel/lbs.py @@ -0,0 +1,501 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import numpy as np + +import torch +import torch.nn.functional as F + + +def rot_mat_to_euler(rot_mats): + # Calculates rotation matrix to euler angles + # Careful for extreme cases of eular angles like [0.0, pi, 0.0] + + sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + + rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) + return torch.atan2(-rot_mats[:, 2, 0], sy) + + +def find_dynamic_lmk_idx_and_bcoords(vertices, pose, dynamic_lmk_faces_idx, + dynamic_lmk_b_coords, + neck_kin_chain, dtype=torch.float32): + ''' Compute the faces, barycentric coordinates for the dynamic landmarks + + + To do so, we first compute the rotation of the neck around the y-axis + and then use a pre-computed look-up table to find the faces and the + barycentric coordinates that will be used. + + Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de) + for providing the original TensorFlow implementation and for the LUT. + + Parameters + ---------- + vertices: torch.tensor BxVx3, dtype = torch.float32 + The tensor of input vertices + pose: torch.tensor Bx(Jx3), dtype = torch.float32 + The current pose of the body model + dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long + The look-up table from neck rotation to faces + dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32 + The look-up table from neck rotation to barycentric coordinates + neck_kin_chain: list + A python list that contains the indices of the joints that form the + kinematic chain of the neck. + dtype: torch.dtype, optional + + Returns + ------- + dyn_lmk_faces_idx: torch.tensor, dtype = torch.long + A tensor of size BxL that contains the indices of the faces that + will be used to compute the current dynamic landmarks. + dyn_lmk_b_coords: torch.tensor, dtype = torch.float32 + A tensor of size BxL that contains the indices of the faces that + will be used to compute the current dynamic landmarks. + ''' + + batch_size = vertices.shape[0] + + aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1, + neck_kin_chain) + rot_mats = batch_rodrigues( + aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3) + + rel_rot_mat = torch.eye(3, device=vertices.device, + dtype=dtype).unsqueeze_(dim=0) + for idx in range(len(neck_kin_chain)): + rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat) + + y_rot_angle = torch.round( + torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, + max=39)).to(dtype=torch.long) + neg_mask = y_rot_angle.lt(0).to(dtype=torch.long) + mask = y_rot_angle.lt(-39).to(dtype=torch.long) + neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle) + y_rot_angle = (neg_mask * neg_vals + + (1 - neg_mask) * y_rot_angle) + + dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, + 0, y_rot_angle) + dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, + 0, y_rot_angle) + + return dyn_lmk_faces_idx, dyn_lmk_b_coords + + +def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords): + ''' Calculates landmarks by barycentric interpolation + + Parameters + ---------- + vertices: torch.tensor BxVx3, dtype = torch.float32 + The tensor of input vertices + faces: torch.tensor Fx3, dtype = torch.long + The faces of the mesh + lmk_faces_idx: torch.tensor L, dtype = torch.long + The tensor with the indices of the faces used to calculate the + landmarks. + lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32 + The tensor of barycentric coordinates that are used to interpolate + the landmarks + + Returns + ------- + landmarks: torch.tensor BxLx3, dtype = torch.float32 + The coordinates of the landmarks for each mesh in the batch + ''' + # Extract the indices of the vertices for each face + # BxLx3 + batch_size, num_verts = vertices.shape[:2] + device = vertices.device + + lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view( + batch_size, -1, 3) + + lmk_faces += torch.arange( + batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts + + lmk_vertices = vertices.view(-1, 3)[lmk_faces].view( + batch_size, -1, 3, 3) + + landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords]) + return landmarks + + +def lbs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents, + lbs_weights, pose2rot=True, dtype=torch.float32, only_shape=False, + use_shape_blending=True, use_pose_blending=True, J_shaped=None, return_vertices=True): + ''' Performs Linear Blend Skinning with the given shape and pose parameters + + Parameters + ---------- + betas : torch.tensor BxNB + The tensor of shape parameters + pose : torch.tensor Bx(J + 1) * 3 + The pose parameters in axis-angle format + v_template torch.tensor BxVx3 + The template mesh that will be deformed + shapedirs : torch.tensor 1xNB + The tensor of PCA shape displacements + posedirs : torch.tensor Px(V * 3) + The pose PCA coefficients + J_regressor : torch.tensor JxV + The regressor array that is used to calculate the joints from + the position of the vertices + parents: torch.tensor J + The array that describes the kinematic tree for the model + lbs_weights: torch.tensor N x V x (J + 1) + The linear blend skinning weights that represent how much the + rotation matrix of each part affects each vertex + pose2rot: bool, optional + Flag on whether to convert the input pose tensor to rotation + matrices. The default value is True. If False, then the pose tensor + should already contain rotation matrices and have a size of + Bx(J + 1)x9 + dtype: torch.dtype, optional + + Returns + ------- + verts: torch.tensor BxVx3 + The vertices of the mesh after applying the shape and pose + displacements. + joints: torch.tensor BxJx3 + The joints of the model + ''' + + batch_size = max(betas.shape[0], pose.shape[0]) + device = betas.device + + # Add shape contribution + if use_shape_blending: + v_shaped = v_template + blend_shapes(betas, shapedirs) + # Get the joints + # NxJx3 array + J = vertices2joints(J_regressor, v_shaped) + else: + v_shaped = v_template.unsqueeze(0).expand(batch_size, -1, -1) + assert J_shaped is not None + J = J_shaped[None].expand(batch_size, -1, -1) + + if only_shape: + return v_shaped, J, None, None + # 3. Add pose blend shapes + # N x J x 3 x 3 + if pose2rot: + rot_mats = batch_rodrigues( + pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3]) + else: + rot_mats = pose.view(batch_size, -1, 3, 3) + + if use_pose_blending: + ident = torch.eye(3, dtype=dtype, device=device) + pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1]) + pose_offsets = torch.matmul(pose_feature, posedirs) \ + .view(batch_size, -1, 3) + v_posed = pose_offsets + v_shaped + else: + v_posed = v_shaped + # 4. Get the global joint location + J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype) + + if not return_vertices: + return None, J_transformed, A, None + # 5. Do skinning: + # W is N x V x (J + 1) + W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1]) + # (N x V x (J + 1)) x (N x (J + 1) x 16) + num_joints = J_transformed.shape[1] + T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \ + .view(batch_size, -1, 4, 4) + + homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1], + dtype=dtype, device=device) + v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2) + v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1)) + + verts = v_homo[:, :, :3, 0] + + return verts, J_transformed, A, T + + +def vertices2joints(J_regressor, vertices): + ''' Calculates the 3D joint locations from the vertices + + Parameters + ---------- + J_regressor : torch.tensor JxV + The regressor array that is used to calculate the joints from the + position of the vertices + vertices : torch.tensor BxVx3 + The tensor of mesh vertices + + Returns + ------- + torch.tensor BxJx3 + The location of the joints + ''' + + return torch.einsum('bik,ji->bjk', [vertices, J_regressor]) + + +def blend_shapes(betas, shape_disps): + ''' Calculates the per vertex displacement due to the blend shapes + + + Parameters + ---------- + betas : torch.tensor Bx(num_betas) + Blend shape coefficients + shape_disps: torch.tensor Vx3x(num_betas) + Blend shapes + + Returns + ------- + torch.tensor BxVx3 + The per-vertex displacement due to shape deformation + ''' + + # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l] + # i.e. Multiply each shape displacement by its corresponding beta and + # then sum them. + blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps]) + return blend_shape + + +def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32): + ''' Calculates the rotation matrices for a batch of rotation vectors + Parameters + ---------- + rot_vecs: torch.tensor Nx3 + array of N axis-angle vectors + Returns + ------- + R: torch.tensor Nx3x3 + The rotation matrices for the given axis-angle parameters + ''' + if len(rot_vecs.shape) > 2: + rot_vec_ori = rot_vecs + rot_vecs = rot_vecs.view(-1, 3) + else: + rot_vec_ori = None + batch_size = rot_vecs.shape[0] + device = rot_vecs.device + + angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) + rot_dir = rot_vecs / angle + + cos = torch.unsqueeze(torch.cos(angle), dim=1) + sin = torch.unsqueeze(torch.sin(angle), dim=1) + + # Bx1 arrays + rx, ry, rz = torch.split(rot_dir, 1, dim=1) + K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) + + zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ + .view((batch_size, 3, 3)) + + ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) + rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) + if rot_vec_ori is not None: + rot_mat = rot_mat.reshape(*rot_vec_ori.shape[:-1], 3, 3) + return rot_mat + + +def transform_mat(R, t): + ''' Creates a batch of transformation matrices + Args: + - R: Bx3x3 array of a batch of rotation matrices + - t: Bx3x1 array of a batch of translation vectors + Returns: + - T: Bx4x4 Transformation matrix + ''' + # No padding left or right, only add an extra row + return torch.cat([F.pad(R, [0, 0, 0, 1]), + F.pad(t, [0, 0, 0, 1], value=1)], dim=2) + + +def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32): + """ + Applies a batch of rigid transformations to the joints + + Parameters + ---------- + rot_mats : torch.tensor BxNx3x3 + Tensor of rotation matrices + joints : torch.tensor BxNx3 + Locations of joints + parents : torch.tensor BxN + The kinematic tree of each object + dtype : torch.dtype, optional: + The data type of the created tensors, the default is torch.float32 + + Returns + ------- + posed_joints : torch.tensor BxNx3 + The locations of the joints after applying the pose rotations + rel_transforms : torch.tensor BxNx4x4 + The relative (with respect to the root joint) rigid transformations + for all the joints + """ + + joints = torch.unsqueeze(joints, dim=-1) + + rel_joints = joints.clone() + rel_joints[:, 1:] -= joints[:, parents[1:]] + + transforms_mat = transform_mat( + rot_mats.view(-1, 3, 3), + rel_joints.contiguous().view(-1, 3, 1)).view(-1, joints.shape[1], 4, 4) + + transform_chain = [transforms_mat[:, 0]] + for i in range(1, parents.shape[0]): + # Subtract the joint location at the rest pose + # No need for rotation, since it's identity when at rest + curr_res = torch.matmul(transform_chain[parents[i]], + transforms_mat[:, i]) + transform_chain.append(curr_res) + + transforms = torch.stack(transform_chain, dim=1) + + # The last column of the transformations contains the posed joints + posed_joints = transforms[:, :, :3, 3] + + # The last column of the transformations contains the posed joints + posed_joints = transforms[:, :, :3, 3] + + joints_homogen = F.pad(joints, [0, 0, 0, 1]) + + rel_transforms = transforms - F.pad( + torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]) + + return posed_joints, rel_transforms + +def dqs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents, + lbs_weights, pose2rot=True, dtype=torch.float32, only_shape=False, + use_shape_blending=True, use_pose_blending=True, J_shaped=None): + ''' Performs Linear Blend Skinning with the given shape and pose parameters + + Parameters + ---------- + betas : torch.tensor BxNB + The tensor of shape parameters + pose : torch.tensor Bx(J + 1) * 3 + The pose parameters in axis-angle format + v_template torch.tensor BxVx3 + The template mesh that will be deformed + shapedirs : torch.tensor 1xNB + The tensor of PCA shape displacements + posedirs : torch.tensor Px(V * 3) + The pose PCA coefficients + J_regressor : torch.tensor JxV + The regressor array that is used to calculate the joints from + the position of the vertices + parents: torch.tensor J + The array that describes the kinematic tree for the model + lbs_weights: torch.tensor N x V x (J + 1) + The linear blend skinning weights that represent how much the + rotation matrix of each part affects each vertex + pose2rot: bool, optional + Flag on whether to convert the input pose tensor to rotation + matrices. The default value is True. If False, then the pose tensor + should already contain rotation matrices and have a size of + Bx(J + 1)x9 + dtype: torch.dtype, optional + + Returns + ------- + verts: torch.tensor BxVx3 + The vertices of the mesh after applying the shape and pose + displacements. + joints: torch.tensor BxJx3 + The joints of the model + ''' + + batch_size = max(betas.shape[0], pose.shape[0]) + device = betas.device + + # Add shape contribution + if use_shape_blending: + v_shaped = v_template + blend_shapes(betas, shapedirs) + # Get the joints + # NxJx3 array + J = vertices2joints(J_regressor, v_shaped) + else: + v_shaped = v_template.unsqueeze(0).expand(batch_size, -1, -1) + assert J_shaped is not None + J = J_shaped[None].expand(batch_size, -1, -1) + + if only_shape: + return v_shaped, J + # 3. Add pose blend shapes + # N x J x 3 x 3 + if pose2rot: + rot_mats = batch_rodrigues( + pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3]) + else: + rot_mats = pose.view(batch_size, -1, 3, 3) + + if use_pose_blending: + ident = torch.eye(3, dtype=dtype, device=device) + pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1]) + pose_offsets = torch.matmul(pose_feature, posedirs) \ + .view(batch_size, -1, 3) + v_posed = pose_offsets + v_shaped + else: + v_posed = v_shaped + # 4. Get the global joint location + J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype) + + + # 5. Do skinning: + # W is N x V x (J + 1) + W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1]) + + verts=batch_dqs_blending(A,W,v_posed) + + return verts, J_transformed + +#A: B,J,4,4 W: B,V,J +def batch_dqs_blending(A,W,Vs): + Bnum,Jnum,_,_=A.shape + _,Vnum,_=W.shape + A = A.view(Bnum*Jnum,4,4) + Rs=A[:,:3,:3] + ws=torch.sqrt(torch.clamp(Rs[:,0,0]+Rs[:,1,1]+Rs[:,2,2]+1.,min=1.e-6))/2. + xs=(Rs[:,2,1]-Rs[:,1,2])/(4.*ws) + ys=(Rs[:,0,2]-Rs[:,2,0])/(4.*ws) + zs=(Rs[:,1,0]-Rs[:,0,1])/(4.*ws) + Ts=A[:,:3,3] + vDw=-0.5*( Ts[:,0]*xs + Ts[:,1]*ys + Ts[:,2]*zs) + vDx=0.5*( Ts[:,0]*ws + Ts[:,1]*zs - Ts[:,2]*ys) + vDy=0.5*(-Ts[:,0]*zs + Ts[:,1]*ws + Ts[:,2]*xs) + vDz=0.5*( Ts[:,0]*ys - Ts[:,1]*xs + Ts[:,2]*ws) + b0=W.unsqueeze(-2)@torch.cat([ws[:,None],xs[:,None],ys[:,None],zs[:,None]],dim=-1).reshape(Bnum, 1, Jnum, 4) #B,V,J,4 + be=W.unsqueeze(-2)@torch.cat([vDw[:,None],vDx[:,None],vDy[:,None],vDz[:,None]],dim=-1).reshape(Bnum, 1, Jnum, 4) #B,V,J,4 + b0 = b0.reshape(-1, 4) + be = be.reshape(-1, 4) + + ns=torch.norm(b0,dim=-1,keepdim=True) + be=be/ns + b0=b0/ns + Vs=Vs.view(Bnum*Vnum,3) + Vs=Vs+2.*b0[:,1:].cross(b0[:,1:].cross(Vs)+b0[:,:1]*Vs)+2.*(b0[:,:1]*be[:,1:]-be[:,:1]*b0[:,1:]+b0[:,1:].cross(be[:,1:])) + return Vs.reshape(Bnum,Vnum,3) \ No newline at end of file diff --git a/easymocap/bodymodel/smpl.py b/easymocap/bodymodel/smpl.py new file mode 100644 index 0000000..a2ba1d3 --- /dev/null +++ b/easymocap/bodymodel/smpl.py @@ -0,0 +1,438 @@ +from .base import Model, Params +from .lbs import lbs, batch_rodrigues +import os +import numpy as np +import torch + +def to_tensor(array, dtype=torch.float32, device=torch.device('cpu')): + if 'torch.tensor' not in str(type(array)): + return torch.tensor(array, dtype=dtype).to(device) + else: + return array.to(device) + +def to_np(array, dtype=np.float32): + if 'scipy.sparse' in str(type(array)): + array = array.todense() + return np.array(array, dtype=dtype) + +def read_pickle(name): + import pickle + with open(name, 'rb') as f: + data = pickle.load(f, encoding='latin1') + return data + +def load_model_data(model_path): + model_path = os.path.abspath(model_path) + assert os.path.exists(model_path), 'Path {} does not exist!'.format( + model_path) + if model_path.endswith('.npz'): + data = np.load(model_path) + data = dict(data) + elif model_path.endswith('.pkl'): + data = read_pickle(model_path) + return data + +def load_regressor(regressor_path): + if regressor_path.endswith('.npy'): + X_regressor = to_tensor(np.load(regressor_path)) + elif regressor_path.endswith('.txt'): + data = np.loadtxt(regressor_path) + with open(regressor_path, 'r') as f: + shape = f.readline().split()[1:] + reg = np.zeros((int(shape[0]), int(shape[1]))) + for i, j, v in data: + reg[int(i), int(j)] = v + X_regressor = to_tensor(reg) + else: + import ipdb; ipdb.set_trace() + return X_regressor + +def save_regressor(fname, data): + with open(fname, 'w') as f: + f.writelines('{} {} {}\r\n'.format('#', data.shape[0], data.shape[1])) + for i in range(data.shape[0]): + for j in range(data.shape[1]): + if(data[i, j] > 0): + f.writelines('{} {} {}\r\n'.format(i, j, data[i, j])) + + +class SMPLModel(Model): + def __init__(self, model_path, regressor_path=None, + device='cpu', + use_pose_blending=True, use_shape_blending=True, use_joints=True, + NUM_SHAPES=-1, NUM_POSES=-1, + use_lbs=True, + use_root_rot=False, + **kwargs) -> None: + super().__init__() + self.name = 'lbs' + self.dtype = torch.float32 # not support fp16 now + self.use_pose_blending = use_pose_blending + self.use_shape_blending = use_shape_blending + self.use_root_rot = use_root_rot + self.NUM_SHAPES = NUM_SHAPES + self.NUM_POSES = NUM_POSES + self.NUM_POSES_FULL = NUM_POSES + self.use_joints = use_joints + + if isinstance(device, str): + device = torch.device(device) + if not torch.torch.cuda.is_available(): + device = torch.device('cpu') + self.device = device + self.model_type = 'smpl' + # create the SMPL model + self.lbs = lbs + self.data = load_model_data(model_path) + self.register_any_lbs(self.data) + + # keypoints regressor + if regressor_path is not None and use_joints: + X_regressor = load_regressor(regressor_path) + X_regressor = torch.cat((self.J_regressor, X_regressor), dim=0) + self.register_any_keypoints(X_regressor) + if not self.use_root_rot: + self.NUM_POSES -= 3 # remove first 3 dims + self.to(self.device) + + def register_any_lbs(self, data): + self.faces = to_np(self.data['f'], dtype=np.int64) + self.register_buffer('faces_tensor', + to_tensor(self.faces, dtype=torch.long)) + for key in ['J_regressor', 'v_template', 'weights']: + if key not in data.keys(): + print('Warning: {} not in data'.format(key)) + self.__setattr__(key, None) + continue + val = to_tensor(to_np(data[key]), dtype=self.dtype) + self.register_buffer(key, val) + self.NUM_POSES = self.weights.shape[-1] * 3 + self.NUM_POSES_FULL = self.NUM_POSES + # add poseblending + if self.use_pose_blending: + # Pose blend shape basis: 6890 x 3 x 207, reshaped to 6890*3 x 207 + num_pose_basis = data['posedirs'].shape[-1] + # 207 x 20670 + data['posedirs_origin'] = data['posedirs'] + data['posedirs'] = np.reshape(data['posedirs'], [-1, num_pose_basis]).T + val = to_tensor(to_np(data['posedirs']), dtype=self.dtype) + self.register_buffer('posedirs', val) + else: + self.posedirs = None + # add shape blending + if self.use_shape_blending: + val = to_tensor(to_np(data['shapedirs']), dtype=self.dtype) + if self.NUM_SHAPES != -1: + val = val[..., :self.NUM_SHAPES] + self.register_buffer('shapedirs', val) + self.NUM_SHAPES = val.shape[-1] + else: + self.shapedirs = None + if self.use_shape_blending: + self.J_shaped = None + else: + val = to_tensor(to_np(data['J']), dtype=self.dtype) + self.register_buffer('J_shaped', val) + + self.nVertices = self.v_template.shape[0] + # indices of parents for each joints + kintree_table = data['kintree_table'] + if len(kintree_table.shape) == 2: + kintree_table = kintree_table[0] + parents = to_tensor(to_np(kintree_table)).long() + parents[0] = -1 + self.register_buffer('parents', parents) + + def register_any_keypoints(self, X_regressor): + # set the parameter of keypoints level + j_J_regressor = torch.zeros(self.J_regressor.shape[0], X_regressor.shape[0], device=self.device) + for i in range(self.J_regressor.shape[0]): + j_J_regressor[i, i] = 1 + j_v_template = X_regressor @ self.v_template + j_weights = X_regressor @ self.weights + if self.use_pose_blending: + posedirs = self.data['posedirs_origin'] + j_posedirs = torch.einsum('ab, bde->ade', X_regressor, torch.Tensor(posedirs)).numpy() + j_posedirs = np.reshape(j_posedirs, [-1, posedirs.shape[-1]]).T + j_posedirs = to_tensor(j_posedirs) + self.register_buffer('j_posedirs', j_posedirs) + else: + self.j_posedirs = None + if self.use_shape_blending: + j_shapedirs = torch.einsum('vij,kv->kij', [self.shapedirs, X_regressor]) + self.register_buffer('j_shapedirs', j_shapedirs) + else: + self.j_shapedirs = None + self.register_buffer('j_weights', j_weights) + self.register_buffer('j_v_template', j_v_template) + self.register_buffer('j_J_regressor', j_J_regressor) + + def forward(self, return_verts=True, return_tensor=True, + return_smpl_joints=False, + only_shape=False, pose2rot=True, **params): + params = self.check_params(params) + poses, shapes = params['poses'], params['shapes'] + poses = self.extend_poses(pose2rot=pose2rot, **params) + Rh, Th = params['Rh'], params['Th'] + # check if there are multiple person + if len(shapes.shape) == 3: + reshape = poses.shape[:2] + Rh = Rh.reshape(-1, *Rh.shape[2:]) + Th = Th.reshape(-1, *Th.shape[2:]) + poses = poses.reshape(-1, *poses.shape[2:]) + shapes = shapes.reshape(-1, *shapes.shape[2:]) + else: + reshape = None + if len(Rh.shape) == 2: # angle-axis + Rh = batch_rodrigues(Rh) + Th = Th.unsqueeze(dim=1) + if return_verts or not self.use_joints: + v_template = self.v_template + if 'scale' in params.keys(): + v_template = v_template * params['scale'][0] + vertices, joints, T_joints, T_vertices = self.lbs(shapes, poses, v_template, + self.shapedirs, self.posedirs, + self.J_regressor, self.parents, + self.weights, pose2rot=pose2rot, dtype=self.dtype, only_shape=only_shape, + use_pose_blending=self.use_pose_blending, use_shape_blending=self.use_shape_blending, J_shaped=self.J_shaped) + if not self.use_joints and not return_verts: + vertices = joints + else: + # only forward joints + v_template = self.j_v_template + if 'scale' in params.keys(): + v_template = v_template * params['scale'][0] + vertices, joints, _, _ = self.lbs(shapes, poses, v_template, + self.j_shapedirs, self.j_posedirs, + self.j_J_regressor, self.parents, + self.j_weights, pose2rot=pose2rot, dtype=self.dtype, only_shape=only_shape, + use_pose_blending=self.use_pose_blending, use_shape_blending=self.use_shape_blending, J_shaped=self.J_shaped) + if return_smpl_joints: + # vertices = vertices[:, :self.J_regressor.shape[0], :] + vertices = joints + else: + vertices = vertices[:, self.J_regressor.shape[0]:, :] + vertices = torch.matmul(vertices, Rh.transpose(1, 2)) + Th + if not return_tensor: + vertices = vertices.detach().cpu().numpy() + if reshape is not None: + vertices = vertices.reshape(*reshape, *vertices.shape[1:]) + return vertices + + def transform(self, params, pose2rot=True, return_vertices=True): + v_template = self.v_template + params = self.check_params(params) + shapes = params['shapes'] + poses = self.extend_poses(**params) + vertices, joints, T_joints, T_vertices = self.lbs(shapes, poses, v_template, + self.shapedirs, self.posedirs, + self.J_regressor, self.parents, + self.weights, pose2rot=pose2rot, dtype=self.dtype, + use_pose_blending=self.use_pose_blending, use_shape_blending=self.use_shape_blending, J_shaped=self.J_shaped, + return_vertices=return_vertices) + return T_joints, T_vertices + + def merge_params(self, params, **kwargs): + return Params.merge(params, **kwargs) + + + def convert_from_standard_smpl(self, params): + params = self.check_params(params) + poses, shapes = params['poses'], params['shapes'] + Th = params['Th'] + vertices, joints, _, _ = lbs(shapes, poses, self.v_template, + self.shapedirs, self.posedirs, + self.J_regressor, self.parents, + self.weights, pose2rot=True, dtype=self.dtype, only_shape=True) + # N x 3 + j0 = joints[:, 0, :] + Rh = poses[:, :3].clone() + # N x 3 x 3 + rot = batch_rodrigues(Rh) + Tnew = Th + j0 - torch.einsum('bij,bj->bi', rot, j0) + poses[:, :3] = 0 + res = dict(poses=poses.detach().cpu().numpy(), + shapes=shapes.detach().cpu().numpy(), + Rh=Rh.detach().cpu().numpy(), + Th=Tnew.detach().cpu().numpy() + ) + return res + + def init_params(self, nFrames=1, nShapes=1, nPerson=1, ret_tensor=False, add_scale=False): + params = { + 'poses': np.zeros((nFrames, self.NUM_POSES)), + 'shapes': np.zeros((nShapes, self.NUM_SHAPES)), + 'Rh': np.zeros((nFrames, 3)), + 'Th': np.zeros((nFrames, 3)), + } + if add_scale: + params['scale'] = np.ones((1, 1)) + if nPerson > 1: + for key in params.keys(): + params[key] = params[key][:, None].repeat(nPerson, axis=1) + if ret_tensor: + for key in params.keys(): + params[key] = to_tensor(params[key], self.dtype, self.device) + return params + + def check_params(self, body_params): + # 预先拷贝一下,不要修改到原始数据了 + body_params = body_params.copy() + poses = body_params['poses'] + nFrames = poses.shape[0] + # convert to torch + if 'torch' not in str(type(poses)): + dtype, device = self.dtype, self.device + for key in ['poses', 'handl', 'handr', 'shapes', 'expression', 'Rh', 'Th', 'scale']: + if key not in body_params.keys(): + continue + body_params[key] = to_tensor(body_params[key], dtype, device) + poses = body_params['poses'] + # check Rh and Th + for key in ['Rh', 'Th']: + if key not in body_params.keys(): + body_params[key] = torch.zeros((nFrames, 3), dtype=poses.dtype, device=poses.device) + # process shapes + for key in ['shapes']: + if body_params[key].shape[0] < nFrames and len(body_params[key].shape) == 2: + body_params[key] = body_params[key].expand(nFrames, -1) + elif body_params[key].shape[0] < nFrames and len(body_params[key].shape) == 3: + body_params[key] = body_params[key].expand(*body_params['poses'].shape[:2], -1) + return body_params + + def __str__(self) -> str: + res = '- Model: {}\n'.format(self.model_type) + res += ' poses: {}\n'.format(self.NUM_POSES) + res += ' shapes: {}\n'.format(self.NUM_SHAPES) + res += ' vertices: {}\n'.format(self.v_template.shape) + res += ' faces: {}\n'.format(self.faces.shape) + res += ' posedirs: {}\n'.format(self.posedirs.shape) + res += ' shapedirs: {}\n'.format(self.shapedirs.shape) + return res + + def extend_poses(self, poses, **kwargs): + if poses.shape[-1] == self.NUM_POSES_FULL: + return poses + if not self.use_root_rot: + if kwargs.get('pose2rot', True): + zero_rot = torch.zeros((*poses.shape[:-1], 3), dtype=poses.dtype, device=poses.device) + poses = torch.cat([zero_rot, poses], dim=-1) + elif poses.shape[-3] != self.NUM_POSES_FULL // 3: + # insert a blank rotation + zero_rot = torch.zeros((*poses.shape[:-3], 1, 3), dtype=poses.dtype, device=poses.device) + zero_rot = batch_rodrigues(zero_rot) + poses = torch.cat([zero_rot, poses], dim=-3) + return poses + + def jacobian_posesfull_poses(self, poses, poses_full): + # TODO: cache this + if self.use_root_rot: + jacobian = torch.eye(poses.shape[-1], dtype=poses.dtype, device=poses.device) + else: + zero_root = torch.zeros((3, poses.shape[-1]), dtype=poses.dtype, device=poses.device) + eye_right = torch.eye(poses.shape[-1], dtype=poses.dtype, device=poses.device) + jacobian = torch.cat([zero_root, eye_right], dim=0) + return jacobian + + def export_full_poses(self, poses, **kwargs): + if not self.use_root_rot: + poses = np.hstack([np.zeros((poses.shape[0], 3)), poses]) + return poses + + def encode(self, body_params): + # This function provide standard SMPL parameters to this model + poses = body_params['poses'] + if 'Rh' not in body_params.keys(): + body_params['Rh'] = poses[:, :3].copy() + if 'Th' not in body_params.keys(): + if 'trans' in body_params.keys(): + body_params['Th'] = body_params.pop('trans') + else: + body_params['Th'] = np.zeros((poses.shape[0], 3), dtype=poses.dtype) + if not self.use_root_rot and poses.shape[1] == 72: + body_params['poses'] = poses[:, 3:].copy() + return body_params + +class SMPLLayerEmbedding(SMPLModel): + def __init__(self, vposer_ckpt='data/body_models/vposer_v02', **kwargs): + super().__init__(**kwargs) + from human_body_prior.tools.model_loader import load_model + from human_body_prior.models.vposer_model import VPoser + vposer, _ = load_model(vposer_ckpt, + model_code=VPoser, + remove_words_in_model_weights='vp_model.', + disable_grad=True) + vposer.eval() + vposer.to(self.device) + self.vposer = vposer + self.vposer_dim = 32 + self.NUM_POSES = self.vposer_dim + + def encode(self, body_params): + # This function provide standard SMPL parameters to this model + poses = body_params['poses'] + if poses.shape[1] == self.vposer_dim: + return body_params + poses_tensor = torch.Tensor(poses).to(self.device) + ret = self.vposer.encode(poses_tensor[:, :63]).mean + body_params = super().encode(body_params) + body_params['poses'] = ret.detach().cpu().numpy() + return body_params + + def extend_poses(self, poses, **kwargs): + if poses.shape[-1] == self.vposer_dim: + ret = self.vposer.decode(poses) + poses_body = ret['pose_body'].reshape(poses.shape[0], -1) + elif poses.shape[-1] == self.NUM_POSES_FULL: + return poses + elif poses.shape[-1] == self.NUM_POSES_FULL - 3: + poses_zero = torch.zeros((poses.shape[0], 3), dtype=poses.dtype, device=poses.device) + poses = torch.cat([poses_zero, poses], dim=-1) + return poses + poses_zero = torch.zeros((poses_body.shape[0], 3), dtype=poses_body.dtype, device=poses_body.device) + poses = torch.cat([poses_zero, poses_body, poses_zero, poses_zero], dim=1) + return poses + + def export_full_poses(self, poses, **kwargs): + poses = torch.Tensor(poses).to(self.device) + poses = self.extend_poses(poses) + return poses.detach().cpu().numpy() + +if __name__ == '__main__': + vis = True + test_config = { + 'smpl':{ + 'model_path': 'data/bodymodels/SMPL_python_v.1.1.0/smpl/models/basicmodel_m_lbs_10_207_0_v1.1.0.pkl', + 'regressor_path': 'data/smplx/J_regressor_body25.npy', + }, + 'smplh':{ + 'model_path': 'data/bodymodels/smplhv1.2/male/model.npz', + 'regressor_path': None, + }, + 'mano':{ + 'model_path': 'data/bodymodels/manov1.2/MANO_LEFT.pkl', + 'regressor_path': None, + }, + 'flame':{ + 'model_path': 'data/bodymodels/FLAME2020/FLAME_MALE.pkl', + 'regressor_path': None, + } + } + for name, cfg in test_config.items(): + print('Testing {}...'.format(name)) + model = SMPLModel(**cfg) + print(model) + params = model.init_params() + for key in params.keys(): + params[key] = (np.random.rand(*params[key].shape) - 0.5)*0.5 + vertices = model.vertices(params, return_tensor=True)[0] + if cfg['regressor_path'] is not None: + keypoints = model.keypoints(params, return_tensor=True)[0] + print(keypoints.shape) + if vis: + import open3d as o3d + mesh = o3d.geometry.TriangleMesh() + mesh.vertices = o3d.utility.Vector3dVector(vertices.reshape(-1, 3)) + mesh.triangles = o3d.utility.Vector3iVector(model.faces.reshape(-1, 3)) + mesh.compute_vertex_normals() + o3d.visualization.draw_geometries([mesh]) diff --git a/easymocap/bodymodel/smplx.py b/easymocap/bodymodel/smplx.py new file mode 100644 index 0000000..878ba94 --- /dev/null +++ b/easymocap/bodymodel/smplx.py @@ -0,0 +1,259 @@ +import torch +import torch.nn as nn +from .base import Model +from .smpl import SMPLModel, SMPLLayerEmbedding, read_pickle, to_tensor +from os.path import join +import numpy as np + +def read_hand(path, use_pca, use_flat_mean, num_pca_comps): + data = read_pickle(path) + mean = data['hands_mean'].reshape(1, -1).astype(np.float32) + mean_full = mean + components_full = data['hands_components'].astype(np.float32) + weight = np.diag(components_full @ components_full.T) + components = components_full[:num_pca_comps] + weight = weight[:num_pca_comps] + if use_flat_mean: + mean = np.zeros_like(mean) + return mean, components, weight, mean_full, components_full + +class MANO(SMPLModel): + def __init__(self, cfg_hand, **kwargs): + super().__init__(**kwargs) + self.name = 'mano' + self.use_root_rot = False + mean, components, weight, mean_full, components_full = read_hand(kwargs['model_path'], **cfg_hand) + self.register_buffer('mean', to_tensor(mean, dtype=self.dtype)) + self.register_buffer('components', to_tensor(components, dtype=self.dtype)) + self.cfg_hand = cfg_hand + self.to(self.device) + if cfg_hand.use_pca: + self.NUM_POSES = cfg_hand.num_pca_comps + + def extend_poses(self, poses, **kwargs): + if poses.shape[-1] == self.mean.shape[-1] + 3: + return poses + if self.cfg_hand.use_pca: + poses = poses @ self.components + if kwargs.get('pose2rot', True): + poses = super().extend_poses(poses+self.mean, **kwargs) + else: + poses = super().extend_poses(poses, **kwargs) + return poses + + def jacobian_posesfull_poses(self, poses, poses_full): + if self.cfg_hand.use_pca: + jacobian = self.components.t() + zero_root = torch.zeros((3, poses.shape[-1]), dtype=poses.dtype, device=poses.device) + jacobian = torch.cat([zero_root, jacobian], dim=0) + else: + jacobian = super().jacobian_posesfull_poses(poses, poses_full) + return jacobian +class MANOLR(Model): + def __init__(self, model_path, regressor_path, cfg_hand, **kwargs): + super().__init__() + self.name = 'manolr' + keys = list(model_path.keys()) + # stack 方式:(nframes, nhand x ndim) + self.keys = keys + modules_hand = {} + faces = [] + v_template = [] + cnt = 0 + for key in keys: + modules_hand[key] = MANO(cfg_hand, model_path=model_path[key], regressor_path=regressor_path[key], **kwargs) + v_template.append(modules_hand[key].v_template.cpu().numpy()) + faces.append(modules_hand[key].faces + cnt) + cnt += v_template[-1].shape[0] + self.device = modules_hand[key].device + self.dtype = modules_hand[key].dtype + if key == 'right': + modules_hand[key].shapedirs[:, 0] *= -1 + modules_hand[key].j_shapedirs[:, 0] *= -1 + self.faces = np.vstack(faces) + self.v_template = np.vstack(v_template) + self.modules_hand = nn.ModuleDict(modules_hand) + self.to(self.device) + + def init_params(self, **kwargs): + param_all = {} + for key in self.keys: + param = self.modules_hand[key].init_params(**kwargs) + param_all[key] = param + if False: + params = {k: torch.cat([param_all[key][k] for key in self.keys], dim=-1) for k in param.keys()} + else: + params = {k: np.concatenate([param_all[key][k] for key in self.keys], axis=-1) for k in param.keys() if k != 'shapes'} + params['shapes'] = param_all['left']['shapes'] + return params + + def split(self, params): + params_split = {} + for imodel, model in enumerate(self.keys): + param_= params.copy() + for key in ['poses', 'shapes', 'Rh', 'Th']: + if key not in params.keys():continue + if key == 'shapes': + continue + shape = params[key].shape[-1] + start = shape//len(self.keys)*imodel + end = shape//len(self.keys)*(imodel+1) + param_[key] = params[key][:, start:end] + params_split[model] = param_ + return params_split + + def forward(self, **params): + params_split = self.split(params) + rets = [] + for imodel, model in enumerate(self.keys): + ret = self.modules_hand[model](**params_split[model]) + rets.append(ret) + if params.get('return_tensor', True): + rets = torch.cat(rets, dim=1) + else: + rets = np.concatenate(rets, axis=1) + return rets + + def extend_poses(self, poses, **kwargs): + params_split = self.split({'poses': poses}) + rets = [] + for imodel, model in enumerate(self.keys): + poses = params_split[model]['poses'] + poses = self.modules_hand[model].extend_poses(poses) + rets.append(poses) + poses = torch.cat(rets, dim=1) + return poses + + def export_full_poses(self, poses, **kwargs): + params_split = self.split({'poses': poses}) + rets = [] + for imodel, model in enumerate(self.keys): + poses = torch.Tensor(params_split[model]['poses']).to(self.device) + poses = self.modules_hand[model].extend_poses(poses) + rets.append(poses) + poses = torch.cat(rets, dim=1) + return poses.detach().cpu().numpy() + +class SMPLHModel(SMPLModel): + def __init__(self, mano_path, cfg_hand, **kwargs): + super().__init__(**kwargs) + self.NUM_POSES = self.NUM_POSES - 90 + meanl, componentsl, weight_l, self.mean_full_l, self.components_full_l = read_hand(join(mano_path, 'MANO_LEFT.pkl'), **cfg_hand) + meanr, componentsr, weight_r, self.mean_full_r, self.components_full_r = read_hand(join(mano_path, 'MANO_RIGHT.pkl'), **cfg_hand) + self.register_buffer('weight_l', to_tensor(weight_l, dtype=self.dtype)) + self.register_buffer('weight_r', to_tensor(weight_r, dtype=self.dtype)) + self.register_buffer('meanl', to_tensor(meanl, dtype=self.dtype)) + self.register_buffer('meanr', to_tensor(meanr, dtype=self.dtype)) + self.register_buffer('componentsl', to_tensor(componentsl, dtype=self.dtype)) + self.register_buffer('componentsr', to_tensor(componentsr, dtype=self.dtype)) + + self.register_buffer('jacobian_posesfull_poses_', self._jacobian_posesfull_poses()) + self.NUM_HANDS = cfg_hand.num_pca_comps if cfg_hand.use_pca else 45 + self.cfg_hand = cfg_hand + self.to(self.device) + + def _jacobian_posesfull_poses(self): + # TODO: cache this + # | body_full/body | 0 | 0 | + # | 0 | l | 0 | + # | 0 | 0 | r | + eye_right = torch.eye(self.NUM_POSES, dtype=self.dtype) + # + jac_handl = self.componentsl.t() + jac_handr = self.componentsr.t() + output = torch.zeros((self.NUM_POSES_FULL, self.NUM_POSES+jac_handl.shape[1]*2), dtype=self.dtype) + if self.use_root_rot: + raise NotImplementedError + else: + output[3:3+self.NUM_POSES, :self.NUM_POSES] = eye_right + output[3+self.NUM_POSES:3+self.NUM_POSES+jac_handl.shape[0], \ + self.NUM_POSES:self.NUM_POSES+jac_handl.shape[1]] = jac_handl + output[3+self.NUM_POSES+jac_handl.shape[0]:3+self.NUM_POSES+2*jac_handl.shape[0], \ + self.NUM_POSES+jac_handl.shape[1]:self.NUM_POSES+jac_handl.shape[1]*2] = jac_handr + return output + + def init_params(self, nFrames=1, nShapes=1, nPerson=1, ret_tensor=False, add_scale=False): + params = super().init_params(nFrames, nShapes, nPerson, ret_tensor, add_scale=add_scale) + handl = np.zeros((nFrames, self.NUM_HANDS)) + handr = np.zeros((nFrames, self.NUM_HANDS)) + if nPerson > 1: + handl = handl[:, None].repeat(nPerson, axis=1) + handr = handr[:, None].repeat(nPerson, axis=1) + if ret_tensor: + handl = to_tensor(handl, self.dtype, self.device) + handr = to_tensor(handr, self.dtype, self.device) + params['handl'] = handl + params['handr'] = handr + return params + + def extend_poses(self, poses, handl=None, handr=None, **kwargs): + if poses.shape[-1] == self.NUM_POSES_FULL: + return poses + poses = super().extend_poses(poses) + if handl is None: + handl = self.meanl.clone() + handr = self.meanr.clone() + handl = handl.expand(poses.shape[0], -1) + handr = handr.expand(poses.shape[0], -1) + else: + if self.cfg_hand.use_pca: + handl = handl @ self.componentsl + handr = handr @ self.componentsr + handl = handl +self.meanl + handr = handr +self.meanr + poses = torch.cat([poses, handl, handr], dim=-1) + return poses + + def export_full_poses(self, poses, handl, handr, **kwargs): + poses = torch.Tensor(poses).to(self.device) + handl = torch.Tensor(handl).to(self.device) + handr = torch.Tensor(handr).to(self.device) + poses = self.extend_poses(poses, handl, handr) + return poses.detach().cpu().numpy() + +class SMPLHModelEmbedding(SMPLHModel): + def __init__(self, vposer_ckpt='data/body_models/vposer_v02', **kwargs): + super().__init__(**kwargs) + from human_body_prior.tools.model_loader import load_model + from human_body_prior.models.vposer_model import VPoser + vposer, _ = load_model(vposer_ckpt, + model_code=VPoser, + remove_words_in_model_weights='vp_model.', + disable_grad=True) + vposer.to(self.device) + self.vposer = vposer + self.vposer_dim = 32 + self.NUM_POSES = self.vposer_dim + + def decode(self, poses, add_rot=True): + if poses.shape[-1] == 66 and add_rot: + return poses + elif poses.shape[-1] == 63 and not add_rot: + return poses + assert poses.shape[-1] == self.vposer_dim, poses.shape + ret = self.vposer.decode(poses) + poses_body = ret['pose_body'].reshape(poses.shape[0], -1) + if add_rot: + zero_rot = torch.zeros((poses.shape[0], 3), dtype=poses.dtype, device=poses.device) + poses_body = torch.cat([zero_rot, poses_body], dim=-1) + return poses_body + + def extend_poses(self, poses, handl, handr, **kwargs): + if poses.shape[-1] == self.NUM_POSES_FULL: + return poses + zero_rot = torch.zeros((poses.shape[0], 3), dtype=poses.dtype, device=poses.device) + poses_body = self.decode(poses, add_rot=False) + if self.cfg_hand.use_pca: + handl = handl @ self.componentsl + handr = handr @ self.componentsr + handl = handl +self.meanl + handr = handr +self.meanr + poses = torch.cat([zero_rot, poses_body, handl, handr], dim=-1) + return poses + + def export_full_poses(self, poses, handl, handr, **kwargs): + poses = torch.Tensor(poses).to(self.device) + handl = torch.Tensor(handl).to(self.device) + handr = torch.Tensor(handr).to(self.device) + poses = self.extend_poses(poses, handl, handr) + return poses.detach().cpu().numpy()