🚧 re-organize bodymodel

This commit is contained in:
Qing Shuai 2022-08-21 16:02:11 +08:00
parent cd3f184f04
commit 050cb209d1
4 changed files with 1333 additions and 0 deletions

135
easymocap/bodymodel/base.py Normal file
View File

@ -0,0 +1,135 @@
'''
@ Date: 2022-03-17 19:23:59
@ Author: Qing Shuai
@ Mail: s_q@zju.edu.cn
@ LastEditors: Qing Shuai
@ LastEditTime: 2022-07-15 12:15:46
@ FilePath: /EasyMocapPublic/easymocap/bodymodel/base.py
'''
import numpy as np
import torch
from ..mytools.file_utils import myarray2string
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.name = 'custom'
def forward(self):
pass
def vertices(self, params, **kwargs):
return self.forward(return_verts=True, **kwargs, **params)
def keypoints(self, params, **kwargs):
return self.forward(return_verts=False, **kwargs, **params)
def transform(self, params, **kwargs):
raise NotImplementedError
class ComposedModel(torch.nn.Module):
def __init__(self, config_dict):
# 叠加多个模型的配置
for name, config in config_dict.items():
pass
class Params(dict):
@classmethod
def merge(self, params_list, share_shape=True, stack=np.vstack):
output = {}
for key in params_list[0].keys():
if key == 'id':continue
output[key] = stack([v[key] for v in params_list])
if share_shape:
output['shapes'] = output['shapes'].mean(axis=0, keepdims=True)
return output
def __len__(self):
return len(self['poses'])
def __getattr__(self, name):
if name in self:
return self[name]
else:
raise AttributeError(name)
def __getitem__(self, index):
if not isinstance(index, int):
return super().__getitem__(index)
if 'shapes' not in self.keys():
# arbitray data
ret = {}
for key, val in self.items():
if index >= 1 and val.shape[0] == 1:
ret[key] = val[0]
else:
ret[key] = val[index]
return Params(**ret)
ret = {'id': 0}
poses = self.poses
shapes = self.shapes
while len(shapes.shape) < len(poses.shape):
shapes = shapes[None]
if poses.shape[0] == shapes.shape[0]:
if index >= 1 and shapes.shape[0] == 1:
ret['shapes'] = shapes[0]
else:
ret['shapes'] = shapes[index]
elif shapes.shape[0] == 1:
ret['shapes'] = shapes[0]
else:
import ipdb; ipdb.set_trace()
if index >= 1 and poses.shape[0] == 1:
ret['poses'] = poses[0]
else:
ret['poses'] = poses[index]
for key, val in self.items():
if key == 'id':
ret[key] = self[key]
continue
if key in ret.keys():continue
if index >= 1 and val.shape[0] == 1:
ret[key] = val[0]
else:
ret[key] = val[index]
for key, val in ret.items():
if key == 'id': continue
if len(val.shape) == 1:
ret[key] = val[None]
return Params(**ret)
def to_multiperson(self, pids):
results = []
for i, pid in enumerate(pids):
param = self[i]
# TODO: this class just implement getattr
# param.id = pid # is wrong
param['id'] = pid
results.append(param)
return results
def __str__(self) -> str:
ret = ''
lastkey = list(self.keys())[-1]
for key, val in self.items():
if isinstance(val, np.ndarray):
ret += '"{}": '.format(key) + myarray2string(val, indent=0)
else:
ret += '"{}": '.format(key) + str(val)
if key != lastkey:
ret += ',\n'
return ret
def shape(self):
ret = ''
lastkey = list(self.keys())[-1]
for key, val in self.items():
if isinstance(val, np.ndarray):
ret += '"{}": {}'.format(key, val.shape)
else:
ret += '"{}": '.format(key) + str(val)
if key != lastkey:
ret += ',\n'
print(ret)
return ret

501
easymocap/bodymodel/lbs.py Normal file
View File

@ -0,0 +1,501 @@
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import numpy as np
import torch
import torch.nn.functional as F
def rot_mat_to_euler(rot_mats):
# Calculates rotation matrix to euler angles
# Careful for extreme cases of eular angles like [0.0, pi, 0.0]
sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] +
rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
return torch.atan2(-rot_mats[:, 2, 0], sy)
def find_dynamic_lmk_idx_and_bcoords(vertices, pose, dynamic_lmk_faces_idx,
dynamic_lmk_b_coords,
neck_kin_chain, dtype=torch.float32):
''' Compute the faces, barycentric coordinates for the dynamic landmarks
To do so, we first compute the rotation of the neck around the y-axis
and then use a pre-computed look-up table to find the faces and the
barycentric coordinates that will be used.
Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de)
for providing the original TensorFlow implementation and for the LUT.
Parameters
----------
vertices: torch.tensor BxVx3, dtype = torch.float32
The tensor of input vertices
pose: torch.tensor Bx(Jx3), dtype = torch.float32
The current pose of the body model
dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long
The look-up table from neck rotation to faces
dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32
The look-up table from neck rotation to barycentric coordinates
neck_kin_chain: list
A python list that contains the indices of the joints that form the
kinematic chain of the neck.
dtype: torch.dtype, optional
Returns
-------
dyn_lmk_faces_idx: torch.tensor, dtype = torch.long
A tensor of size BxL that contains the indices of the faces that
will be used to compute the current dynamic landmarks.
dyn_lmk_b_coords: torch.tensor, dtype = torch.float32
A tensor of size BxL that contains the indices of the faces that
will be used to compute the current dynamic landmarks.
'''
batch_size = vertices.shape[0]
aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1,
neck_kin_chain)
rot_mats = batch_rodrigues(
aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3)
rel_rot_mat = torch.eye(3, device=vertices.device,
dtype=dtype).unsqueeze_(dim=0)
for idx in range(len(neck_kin_chain)):
rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)
y_rot_angle = torch.round(
torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
max=39)).to(dtype=torch.long)
neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
mask = y_rot_angle.lt(-39).to(dtype=torch.long)
neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)
y_rot_angle = (neg_mask * neg_vals +
(1 - neg_mask) * y_rot_angle)
dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx,
0, y_rot_angle)
dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords,
0, y_rot_angle)
return dyn_lmk_faces_idx, dyn_lmk_b_coords
def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords):
''' Calculates landmarks by barycentric interpolation
Parameters
----------
vertices: torch.tensor BxVx3, dtype = torch.float32
The tensor of input vertices
faces: torch.tensor Fx3, dtype = torch.long
The faces of the mesh
lmk_faces_idx: torch.tensor L, dtype = torch.long
The tensor with the indices of the faces used to calculate the
landmarks.
lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32
The tensor of barycentric coordinates that are used to interpolate
the landmarks
Returns
-------
landmarks: torch.tensor BxLx3, dtype = torch.float32
The coordinates of the landmarks for each mesh in the batch
'''
# Extract the indices of the vertices for each face
# BxLx3
batch_size, num_verts = vertices.shape[:2]
device = vertices.device
lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(
batch_size, -1, 3)
lmk_faces += torch.arange(
batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts
lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(
batch_size, -1, 3, 3)
landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords])
return landmarks
def lbs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents,
lbs_weights, pose2rot=True, dtype=torch.float32, only_shape=False,
use_shape_blending=True, use_pose_blending=True, J_shaped=None, return_vertices=True):
''' Performs Linear Blend Skinning with the given shape and pose parameters
Parameters
----------
betas : torch.tensor BxNB
The tensor of shape parameters
pose : torch.tensor Bx(J + 1) * 3
The pose parameters in axis-angle format
v_template torch.tensor BxVx3
The template mesh that will be deformed
shapedirs : torch.tensor 1xNB
The tensor of PCA shape displacements
posedirs : torch.tensor Px(V * 3)
The pose PCA coefficients
J_regressor : torch.tensor JxV
The regressor array that is used to calculate the joints from
the position of the vertices
parents: torch.tensor J
The array that describes the kinematic tree for the model
lbs_weights: torch.tensor N x V x (J + 1)
The linear blend skinning weights that represent how much the
rotation matrix of each part affects each vertex
pose2rot: bool, optional
Flag on whether to convert the input pose tensor to rotation
matrices. The default value is True. If False, then the pose tensor
should already contain rotation matrices and have a size of
Bx(J + 1)x9
dtype: torch.dtype, optional
Returns
-------
verts: torch.tensor BxVx3
The vertices of the mesh after applying the shape and pose
displacements.
joints: torch.tensor BxJx3
The joints of the model
'''
batch_size = max(betas.shape[0], pose.shape[0])
device = betas.device
# Add shape contribution
if use_shape_blending:
v_shaped = v_template + blend_shapes(betas, shapedirs)
# Get the joints
# NxJx3 array
J = vertices2joints(J_regressor, v_shaped)
else:
v_shaped = v_template.unsqueeze(0).expand(batch_size, -1, -1)
assert J_shaped is not None
J = J_shaped[None].expand(batch_size, -1, -1)
if only_shape:
return v_shaped, J, None, None
# 3. Add pose blend shapes
# N x J x 3 x 3
if pose2rot:
rot_mats = batch_rodrigues(
pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3])
else:
rot_mats = pose.view(batch_size, -1, 3, 3)
if use_pose_blending:
ident = torch.eye(3, dtype=dtype, device=device)
pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])
pose_offsets = torch.matmul(pose_feature, posedirs) \
.view(batch_size, -1, 3)
v_posed = pose_offsets + v_shaped
else:
v_posed = v_shaped
# 4. Get the global joint location
J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)
if not return_vertices:
return None, J_transformed, A, None
# 5. Do skinning:
# W is N x V x (J + 1)
W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
# (N x V x (J + 1)) x (N x (J + 1) x 16)
num_joints = J_transformed.shape[1]
T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \
.view(batch_size, -1, 4, 4)
homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1],
dtype=dtype, device=device)
v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))
verts = v_homo[:, :, :3, 0]
return verts, J_transformed, A, T
def vertices2joints(J_regressor, vertices):
''' Calculates the 3D joint locations from the vertices
Parameters
----------
J_regressor : torch.tensor JxV
The regressor array that is used to calculate the joints from the
position of the vertices
vertices : torch.tensor BxVx3
The tensor of mesh vertices
Returns
-------
torch.tensor BxJx3
The location of the joints
'''
return torch.einsum('bik,ji->bjk', [vertices, J_regressor])
def blend_shapes(betas, shape_disps):
''' Calculates the per vertex displacement due to the blend shapes
Parameters
----------
betas : torch.tensor Bx(num_betas)
Blend shape coefficients
shape_disps: torch.tensor Vx3x(num_betas)
Blend shapes
Returns
-------
torch.tensor BxVx3
The per-vertex displacement due to shape deformation
'''
# Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l]
# i.e. Multiply each shape displacement by its corresponding beta and
# then sum them.
blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps])
return blend_shape
def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
''' Calculates the rotation matrices for a batch of rotation vectors
Parameters
----------
rot_vecs: torch.tensor Nx3
array of N axis-angle vectors
Returns
-------
R: torch.tensor Nx3x3
The rotation matrices for the given axis-angle parameters
'''
if len(rot_vecs.shape) > 2:
rot_vec_ori = rot_vecs
rot_vecs = rot_vecs.view(-1, 3)
else:
rot_vec_ori = None
batch_size = rot_vecs.shape[0]
device = rot_vecs.device
angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
rot_dir = rot_vecs / angle
cos = torch.unsqueeze(torch.cos(angle), dim=1)
sin = torch.unsqueeze(torch.sin(angle), dim=1)
# Bx1 arrays
rx, ry, rz = torch.split(rot_dir, 1, dim=1)
K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
.view((batch_size, 3, 3))
ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
if rot_vec_ori is not None:
rot_mat = rot_mat.reshape(*rot_vec_ori.shape[:-1], 3, 3)
return rot_mat
def transform_mat(R, t):
''' Creates a batch of transformation matrices
Args:
- R: Bx3x3 array of a batch of rotation matrices
- t: Bx3x1 array of a batch of translation vectors
Returns:
- T: Bx4x4 Transformation matrix
'''
# No padding left or right, only add an extra row
return torch.cat([F.pad(R, [0, 0, 0, 1]),
F.pad(t, [0, 0, 0, 1], value=1)], dim=2)
def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32):
"""
Applies a batch of rigid transformations to the joints
Parameters
----------
rot_mats : torch.tensor BxNx3x3
Tensor of rotation matrices
joints : torch.tensor BxNx3
Locations of joints
parents : torch.tensor BxN
The kinematic tree of each object
dtype : torch.dtype, optional:
The data type of the created tensors, the default is torch.float32
Returns
-------
posed_joints : torch.tensor BxNx3
The locations of the joints after applying the pose rotations
rel_transforms : torch.tensor BxNx4x4
The relative (with respect to the root joint) rigid transformations
for all the joints
"""
joints = torch.unsqueeze(joints, dim=-1)
rel_joints = joints.clone()
rel_joints[:, 1:] -= joints[:, parents[1:]]
transforms_mat = transform_mat(
rot_mats.view(-1, 3, 3),
rel_joints.contiguous().view(-1, 3, 1)).view(-1, joints.shape[1], 4, 4)
transform_chain = [transforms_mat[:, 0]]
for i in range(1, parents.shape[0]):
# Subtract the joint location at the rest pose
# No need for rotation, since it's identity when at rest
curr_res = torch.matmul(transform_chain[parents[i]],
transforms_mat[:, i])
transform_chain.append(curr_res)
transforms = torch.stack(transform_chain, dim=1)
# The last column of the transformations contains the posed joints
posed_joints = transforms[:, :, :3, 3]
# The last column of the transformations contains the posed joints
posed_joints = transforms[:, :, :3, 3]
joints_homogen = F.pad(joints, [0, 0, 0, 1])
rel_transforms = transforms - F.pad(
torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0])
return posed_joints, rel_transforms
def dqs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents,
lbs_weights, pose2rot=True, dtype=torch.float32, only_shape=False,
use_shape_blending=True, use_pose_blending=True, J_shaped=None):
''' Performs Linear Blend Skinning with the given shape and pose parameters
Parameters
----------
betas : torch.tensor BxNB
The tensor of shape parameters
pose : torch.tensor Bx(J + 1) * 3
The pose parameters in axis-angle format
v_template torch.tensor BxVx3
The template mesh that will be deformed
shapedirs : torch.tensor 1xNB
The tensor of PCA shape displacements
posedirs : torch.tensor Px(V * 3)
The pose PCA coefficients
J_regressor : torch.tensor JxV
The regressor array that is used to calculate the joints from
the position of the vertices
parents: torch.tensor J
The array that describes the kinematic tree for the model
lbs_weights: torch.tensor N x V x (J + 1)
The linear blend skinning weights that represent how much the
rotation matrix of each part affects each vertex
pose2rot: bool, optional
Flag on whether to convert the input pose tensor to rotation
matrices. The default value is True. If False, then the pose tensor
should already contain rotation matrices and have a size of
Bx(J + 1)x9
dtype: torch.dtype, optional
Returns
-------
verts: torch.tensor BxVx3
The vertices of the mesh after applying the shape and pose
displacements.
joints: torch.tensor BxJx3
The joints of the model
'''
batch_size = max(betas.shape[0], pose.shape[0])
device = betas.device
# Add shape contribution
if use_shape_blending:
v_shaped = v_template + blend_shapes(betas, shapedirs)
# Get the joints
# NxJx3 array
J = vertices2joints(J_regressor, v_shaped)
else:
v_shaped = v_template.unsqueeze(0).expand(batch_size, -1, -1)
assert J_shaped is not None
J = J_shaped[None].expand(batch_size, -1, -1)
if only_shape:
return v_shaped, J
# 3. Add pose blend shapes
# N x J x 3 x 3
if pose2rot:
rot_mats = batch_rodrigues(
pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3])
else:
rot_mats = pose.view(batch_size, -1, 3, 3)
if use_pose_blending:
ident = torch.eye(3, dtype=dtype, device=device)
pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])
pose_offsets = torch.matmul(pose_feature, posedirs) \
.view(batch_size, -1, 3)
v_posed = pose_offsets + v_shaped
else:
v_posed = v_shaped
# 4. Get the global joint location
J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)
# 5. Do skinning:
# W is N x V x (J + 1)
W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
verts=batch_dqs_blending(A,W,v_posed)
return verts, J_transformed
#A: B,J,4,4 W: B,V,J
def batch_dqs_blending(A,W,Vs):
Bnum,Jnum,_,_=A.shape
_,Vnum,_=W.shape
A = A.view(Bnum*Jnum,4,4)
Rs=A[:,:3,:3]
ws=torch.sqrt(torch.clamp(Rs[:,0,0]+Rs[:,1,1]+Rs[:,2,2]+1.,min=1.e-6))/2.
xs=(Rs[:,2,1]-Rs[:,1,2])/(4.*ws)
ys=(Rs[:,0,2]-Rs[:,2,0])/(4.*ws)
zs=(Rs[:,1,0]-Rs[:,0,1])/(4.*ws)
Ts=A[:,:3,3]
vDw=-0.5*( Ts[:,0]*xs + Ts[:,1]*ys + Ts[:,2]*zs)
vDx=0.5*( Ts[:,0]*ws + Ts[:,1]*zs - Ts[:,2]*ys)
vDy=0.5*(-Ts[:,0]*zs + Ts[:,1]*ws + Ts[:,2]*xs)
vDz=0.5*( Ts[:,0]*ys - Ts[:,1]*xs + Ts[:,2]*ws)
b0=W.unsqueeze(-2)@torch.cat([ws[:,None],xs[:,None],ys[:,None],zs[:,None]],dim=-1).reshape(Bnum, 1, Jnum, 4) #B,V,J,4
be=W.unsqueeze(-2)@torch.cat([vDw[:,None],vDx[:,None],vDy[:,None],vDz[:,None]],dim=-1).reshape(Bnum, 1, Jnum, 4) #B,V,J,4
b0 = b0.reshape(-1, 4)
be = be.reshape(-1, 4)
ns=torch.norm(b0,dim=-1,keepdim=True)
be=be/ns
b0=b0/ns
Vs=Vs.view(Bnum*Vnum,3)
Vs=Vs+2.*b0[:,1:].cross(b0[:,1:].cross(Vs)+b0[:,:1]*Vs)+2.*(b0[:,:1]*be[:,1:]-be[:,:1]*b0[:,1:]+b0[:,1:].cross(be[:,1:]))
return Vs.reshape(Bnum,Vnum,3)

438
easymocap/bodymodel/smpl.py Normal file
View File

@ -0,0 +1,438 @@
from .base import Model, Params
from .lbs import lbs, batch_rodrigues
import os
import numpy as np
import torch
def to_tensor(array, dtype=torch.float32, device=torch.device('cpu')):
if 'torch.tensor' not in str(type(array)):
return torch.tensor(array, dtype=dtype).to(device)
else:
return array.to(device)
def to_np(array, dtype=np.float32):
if 'scipy.sparse' in str(type(array)):
array = array.todense()
return np.array(array, dtype=dtype)
def read_pickle(name):
import pickle
with open(name, 'rb') as f:
data = pickle.load(f, encoding='latin1')
return data
def load_model_data(model_path):
model_path = os.path.abspath(model_path)
assert os.path.exists(model_path), 'Path {} does not exist!'.format(
model_path)
if model_path.endswith('.npz'):
data = np.load(model_path)
data = dict(data)
elif model_path.endswith('.pkl'):
data = read_pickle(model_path)
return data
def load_regressor(regressor_path):
if regressor_path.endswith('.npy'):
X_regressor = to_tensor(np.load(regressor_path))
elif regressor_path.endswith('.txt'):
data = np.loadtxt(regressor_path)
with open(regressor_path, 'r') as f:
shape = f.readline().split()[1:]
reg = np.zeros((int(shape[0]), int(shape[1])))
for i, j, v in data:
reg[int(i), int(j)] = v
X_regressor = to_tensor(reg)
else:
import ipdb; ipdb.set_trace()
return X_regressor
def save_regressor(fname, data):
with open(fname, 'w') as f:
f.writelines('{} {} {}\r\n'.format('#', data.shape[0], data.shape[1]))
for i in range(data.shape[0]):
for j in range(data.shape[1]):
if(data[i, j] > 0):
f.writelines('{} {} {}\r\n'.format(i, j, data[i, j]))
class SMPLModel(Model):
def __init__(self, model_path, regressor_path=None,
device='cpu',
use_pose_blending=True, use_shape_blending=True, use_joints=True,
NUM_SHAPES=-1, NUM_POSES=-1,
use_lbs=True,
use_root_rot=False,
**kwargs) -> None:
super().__init__()
self.name = 'lbs'
self.dtype = torch.float32 # not support fp16 now
self.use_pose_blending = use_pose_blending
self.use_shape_blending = use_shape_blending
self.use_root_rot = use_root_rot
self.NUM_SHAPES = NUM_SHAPES
self.NUM_POSES = NUM_POSES
self.NUM_POSES_FULL = NUM_POSES
self.use_joints = use_joints
if isinstance(device, str):
device = torch.device(device)
if not torch.torch.cuda.is_available():
device = torch.device('cpu')
self.device = device
self.model_type = 'smpl'
# create the SMPL model
self.lbs = lbs
self.data = load_model_data(model_path)
self.register_any_lbs(self.data)
# keypoints regressor
if regressor_path is not None and use_joints:
X_regressor = load_regressor(regressor_path)
X_regressor = torch.cat((self.J_regressor, X_regressor), dim=0)
self.register_any_keypoints(X_regressor)
if not self.use_root_rot:
self.NUM_POSES -= 3 # remove first 3 dims
self.to(self.device)
def register_any_lbs(self, data):
self.faces = to_np(self.data['f'], dtype=np.int64)
self.register_buffer('faces_tensor',
to_tensor(self.faces, dtype=torch.long))
for key in ['J_regressor', 'v_template', 'weights']:
if key not in data.keys():
print('Warning: {} not in data'.format(key))
self.__setattr__(key, None)
continue
val = to_tensor(to_np(data[key]), dtype=self.dtype)
self.register_buffer(key, val)
self.NUM_POSES = self.weights.shape[-1] * 3
self.NUM_POSES_FULL = self.NUM_POSES
# add poseblending
if self.use_pose_blending:
# Pose blend shape basis: 6890 x 3 x 207, reshaped to 6890*3 x 207
num_pose_basis = data['posedirs'].shape[-1]
# 207 x 20670
data['posedirs_origin'] = data['posedirs']
data['posedirs'] = np.reshape(data['posedirs'], [-1, num_pose_basis]).T
val = to_tensor(to_np(data['posedirs']), dtype=self.dtype)
self.register_buffer('posedirs', val)
else:
self.posedirs = None
# add shape blending
if self.use_shape_blending:
val = to_tensor(to_np(data['shapedirs']), dtype=self.dtype)
if self.NUM_SHAPES != -1:
val = val[..., :self.NUM_SHAPES]
self.register_buffer('shapedirs', val)
self.NUM_SHAPES = val.shape[-1]
else:
self.shapedirs = None
if self.use_shape_blending:
self.J_shaped = None
else:
val = to_tensor(to_np(data['J']), dtype=self.dtype)
self.register_buffer('J_shaped', val)
self.nVertices = self.v_template.shape[0]
# indices of parents for each joints
kintree_table = data['kintree_table']
if len(kintree_table.shape) == 2:
kintree_table = kintree_table[0]
parents = to_tensor(to_np(kintree_table)).long()
parents[0] = -1
self.register_buffer('parents', parents)
def register_any_keypoints(self, X_regressor):
# set the parameter of keypoints level
j_J_regressor = torch.zeros(self.J_regressor.shape[0], X_regressor.shape[0], device=self.device)
for i in range(self.J_regressor.shape[0]):
j_J_regressor[i, i] = 1
j_v_template = X_regressor @ self.v_template
j_weights = X_regressor @ self.weights
if self.use_pose_blending:
posedirs = self.data['posedirs_origin']
j_posedirs = torch.einsum('ab, bde->ade', X_regressor, torch.Tensor(posedirs)).numpy()
j_posedirs = np.reshape(j_posedirs, [-1, posedirs.shape[-1]]).T
j_posedirs = to_tensor(j_posedirs)
self.register_buffer('j_posedirs', j_posedirs)
else:
self.j_posedirs = None
if self.use_shape_blending:
j_shapedirs = torch.einsum('vij,kv->kij', [self.shapedirs, X_regressor])
self.register_buffer('j_shapedirs', j_shapedirs)
else:
self.j_shapedirs = None
self.register_buffer('j_weights', j_weights)
self.register_buffer('j_v_template', j_v_template)
self.register_buffer('j_J_regressor', j_J_regressor)
def forward(self, return_verts=True, return_tensor=True,
return_smpl_joints=False,
only_shape=False, pose2rot=True, **params):
params = self.check_params(params)
poses, shapes = params['poses'], params['shapes']
poses = self.extend_poses(pose2rot=pose2rot, **params)
Rh, Th = params['Rh'], params['Th']
# check if there are multiple person
if len(shapes.shape) == 3:
reshape = poses.shape[:2]
Rh = Rh.reshape(-1, *Rh.shape[2:])
Th = Th.reshape(-1, *Th.shape[2:])
poses = poses.reshape(-1, *poses.shape[2:])
shapes = shapes.reshape(-1, *shapes.shape[2:])
else:
reshape = None
if len(Rh.shape) == 2: # angle-axis
Rh = batch_rodrigues(Rh)
Th = Th.unsqueeze(dim=1)
if return_verts or not self.use_joints:
v_template = self.v_template
if 'scale' in params.keys():
v_template = v_template * params['scale'][0]
vertices, joints, T_joints, T_vertices = self.lbs(shapes, poses, v_template,
self.shapedirs, self.posedirs,
self.J_regressor, self.parents,
self.weights, pose2rot=pose2rot, dtype=self.dtype, only_shape=only_shape,
use_pose_blending=self.use_pose_blending, use_shape_blending=self.use_shape_blending, J_shaped=self.J_shaped)
if not self.use_joints and not return_verts:
vertices = joints
else:
# only forward joints
v_template = self.j_v_template
if 'scale' in params.keys():
v_template = v_template * params['scale'][0]
vertices, joints, _, _ = self.lbs(shapes, poses, v_template,
self.j_shapedirs, self.j_posedirs,
self.j_J_regressor, self.parents,
self.j_weights, pose2rot=pose2rot, dtype=self.dtype, only_shape=only_shape,
use_pose_blending=self.use_pose_blending, use_shape_blending=self.use_shape_blending, J_shaped=self.J_shaped)
if return_smpl_joints:
# vertices = vertices[:, :self.J_regressor.shape[0], :]
vertices = joints
else:
vertices = vertices[:, self.J_regressor.shape[0]:, :]
vertices = torch.matmul(vertices, Rh.transpose(1, 2)) + Th
if not return_tensor:
vertices = vertices.detach().cpu().numpy()
if reshape is not None:
vertices = vertices.reshape(*reshape, *vertices.shape[1:])
return vertices
def transform(self, params, pose2rot=True, return_vertices=True):
v_template = self.v_template
params = self.check_params(params)
shapes = params['shapes']
poses = self.extend_poses(**params)
vertices, joints, T_joints, T_vertices = self.lbs(shapes, poses, v_template,
self.shapedirs, self.posedirs,
self.J_regressor, self.parents,
self.weights, pose2rot=pose2rot, dtype=self.dtype,
use_pose_blending=self.use_pose_blending, use_shape_blending=self.use_shape_blending, J_shaped=self.J_shaped,
return_vertices=return_vertices)
return T_joints, T_vertices
def merge_params(self, params, **kwargs):
return Params.merge(params, **kwargs)
def convert_from_standard_smpl(self, params):
params = self.check_params(params)
poses, shapes = params['poses'], params['shapes']
Th = params['Th']
vertices, joints, _, _ = lbs(shapes, poses, self.v_template,
self.shapedirs, self.posedirs,
self.J_regressor, self.parents,
self.weights, pose2rot=True, dtype=self.dtype, only_shape=True)
# N x 3
j0 = joints[:, 0, :]
Rh = poses[:, :3].clone()
# N x 3 x 3
rot = batch_rodrigues(Rh)
Tnew = Th + j0 - torch.einsum('bij,bj->bi', rot, j0)
poses[:, :3] = 0
res = dict(poses=poses.detach().cpu().numpy(),
shapes=shapes.detach().cpu().numpy(),
Rh=Rh.detach().cpu().numpy(),
Th=Tnew.detach().cpu().numpy()
)
return res
def init_params(self, nFrames=1, nShapes=1, nPerson=1, ret_tensor=False, add_scale=False):
params = {
'poses': np.zeros((nFrames, self.NUM_POSES)),
'shapes': np.zeros((nShapes, self.NUM_SHAPES)),
'Rh': np.zeros((nFrames, 3)),
'Th': np.zeros((nFrames, 3)),
}
if add_scale:
params['scale'] = np.ones((1, 1))
if nPerson > 1:
for key in params.keys():
params[key] = params[key][:, None].repeat(nPerson, axis=1)
if ret_tensor:
for key in params.keys():
params[key] = to_tensor(params[key], self.dtype, self.device)
return params
def check_params(self, body_params):
# 预先拷贝一下,不要修改到原始数据了
body_params = body_params.copy()
poses = body_params['poses']
nFrames = poses.shape[0]
# convert to torch
if 'torch' not in str(type(poses)):
dtype, device = self.dtype, self.device
for key in ['poses', 'handl', 'handr', 'shapes', 'expression', 'Rh', 'Th', 'scale']:
if key not in body_params.keys():
continue
body_params[key] = to_tensor(body_params[key], dtype, device)
poses = body_params['poses']
# check Rh and Th
for key in ['Rh', 'Th']:
if key not in body_params.keys():
body_params[key] = torch.zeros((nFrames, 3), dtype=poses.dtype, device=poses.device)
# process shapes
for key in ['shapes']:
if body_params[key].shape[0] < nFrames and len(body_params[key].shape) == 2:
body_params[key] = body_params[key].expand(nFrames, -1)
elif body_params[key].shape[0] < nFrames and len(body_params[key].shape) == 3:
body_params[key] = body_params[key].expand(*body_params['poses'].shape[:2], -1)
return body_params
def __str__(self) -> str:
res = '- Model: {}\n'.format(self.model_type)
res += ' poses: {}\n'.format(self.NUM_POSES)
res += ' shapes: {}\n'.format(self.NUM_SHAPES)
res += ' vertices: {}\n'.format(self.v_template.shape)
res += ' faces: {}\n'.format(self.faces.shape)
res += ' posedirs: {}\n'.format(self.posedirs.shape)
res += ' shapedirs: {}\n'.format(self.shapedirs.shape)
return res
def extend_poses(self, poses, **kwargs):
if poses.shape[-1] == self.NUM_POSES_FULL:
return poses
if not self.use_root_rot:
if kwargs.get('pose2rot', True):
zero_rot = torch.zeros((*poses.shape[:-1], 3), dtype=poses.dtype, device=poses.device)
poses = torch.cat([zero_rot, poses], dim=-1)
elif poses.shape[-3] != self.NUM_POSES_FULL // 3:
# insert a blank rotation
zero_rot = torch.zeros((*poses.shape[:-3], 1, 3), dtype=poses.dtype, device=poses.device)
zero_rot = batch_rodrigues(zero_rot)
poses = torch.cat([zero_rot, poses], dim=-3)
return poses
def jacobian_posesfull_poses(self, poses, poses_full):
# TODO: cache this
if self.use_root_rot:
jacobian = torch.eye(poses.shape[-1], dtype=poses.dtype, device=poses.device)
else:
zero_root = torch.zeros((3, poses.shape[-1]), dtype=poses.dtype, device=poses.device)
eye_right = torch.eye(poses.shape[-1], dtype=poses.dtype, device=poses.device)
jacobian = torch.cat([zero_root, eye_right], dim=0)
return jacobian
def export_full_poses(self, poses, **kwargs):
if not self.use_root_rot:
poses = np.hstack([np.zeros((poses.shape[0], 3)), poses])
return poses
def encode(self, body_params):
# This function provide standard SMPL parameters to this model
poses = body_params['poses']
if 'Rh' not in body_params.keys():
body_params['Rh'] = poses[:, :3].copy()
if 'Th' not in body_params.keys():
if 'trans' in body_params.keys():
body_params['Th'] = body_params.pop('trans')
else:
body_params['Th'] = np.zeros((poses.shape[0], 3), dtype=poses.dtype)
if not self.use_root_rot and poses.shape[1] == 72:
body_params['poses'] = poses[:, 3:].copy()
return body_params
class SMPLLayerEmbedding(SMPLModel):
def __init__(self, vposer_ckpt='data/body_models/vposer_v02', **kwargs):
super().__init__(**kwargs)
from human_body_prior.tools.model_loader import load_model
from human_body_prior.models.vposer_model import VPoser
vposer, _ = load_model(vposer_ckpt,
model_code=VPoser,
remove_words_in_model_weights='vp_model.',
disable_grad=True)
vposer.eval()
vposer.to(self.device)
self.vposer = vposer
self.vposer_dim = 32
self.NUM_POSES = self.vposer_dim
def encode(self, body_params):
# This function provide standard SMPL parameters to this model
poses = body_params['poses']
if poses.shape[1] == self.vposer_dim:
return body_params
poses_tensor = torch.Tensor(poses).to(self.device)
ret = self.vposer.encode(poses_tensor[:, :63]).mean
body_params = super().encode(body_params)
body_params['poses'] = ret.detach().cpu().numpy()
return body_params
def extend_poses(self, poses, **kwargs):
if poses.shape[-1] == self.vposer_dim:
ret = self.vposer.decode(poses)
poses_body = ret['pose_body'].reshape(poses.shape[0], -1)
elif poses.shape[-1] == self.NUM_POSES_FULL:
return poses
elif poses.shape[-1] == self.NUM_POSES_FULL - 3:
poses_zero = torch.zeros((poses.shape[0], 3), dtype=poses.dtype, device=poses.device)
poses = torch.cat([poses_zero, poses], dim=-1)
return poses
poses_zero = torch.zeros((poses_body.shape[0], 3), dtype=poses_body.dtype, device=poses_body.device)
poses = torch.cat([poses_zero, poses_body, poses_zero, poses_zero], dim=1)
return poses
def export_full_poses(self, poses, **kwargs):
poses = torch.Tensor(poses).to(self.device)
poses = self.extend_poses(poses)
return poses.detach().cpu().numpy()
if __name__ == '__main__':
vis = True
test_config = {
'smpl':{
'model_path': 'data/bodymodels/SMPL_python_v.1.1.0/smpl/models/basicmodel_m_lbs_10_207_0_v1.1.0.pkl',
'regressor_path': 'data/smplx/J_regressor_body25.npy',
},
'smplh':{
'model_path': 'data/bodymodels/smplhv1.2/male/model.npz',
'regressor_path': None,
},
'mano':{
'model_path': 'data/bodymodels/manov1.2/MANO_LEFT.pkl',
'regressor_path': None,
},
'flame':{
'model_path': 'data/bodymodels/FLAME2020/FLAME_MALE.pkl',
'regressor_path': None,
}
}
for name, cfg in test_config.items():
print('Testing {}...'.format(name))
model = SMPLModel(**cfg)
print(model)
params = model.init_params()
for key in params.keys():
params[key] = (np.random.rand(*params[key].shape) - 0.5)*0.5
vertices = model.vertices(params, return_tensor=True)[0]
if cfg['regressor_path'] is not None:
keypoints = model.keypoints(params, return_tensor=True)[0]
print(keypoints.shape)
if vis:
import open3d as o3d
mesh = o3d.geometry.TriangleMesh()
mesh.vertices = o3d.utility.Vector3dVector(vertices.reshape(-1, 3))
mesh.triangles = o3d.utility.Vector3iVector(model.faces.reshape(-1, 3))
mesh.compute_vertex_normals()
o3d.visualization.draw_geometries([mesh])

View File

@ -0,0 +1,259 @@
import torch
import torch.nn as nn
from .base import Model
from .smpl import SMPLModel, SMPLLayerEmbedding, read_pickle, to_tensor
from os.path import join
import numpy as np
def read_hand(path, use_pca, use_flat_mean, num_pca_comps):
data = read_pickle(path)
mean = data['hands_mean'].reshape(1, -1).astype(np.float32)
mean_full = mean
components_full = data['hands_components'].astype(np.float32)
weight = np.diag(components_full @ components_full.T)
components = components_full[:num_pca_comps]
weight = weight[:num_pca_comps]
if use_flat_mean:
mean = np.zeros_like(mean)
return mean, components, weight, mean_full, components_full
class MANO(SMPLModel):
def __init__(self, cfg_hand, **kwargs):
super().__init__(**kwargs)
self.name = 'mano'
self.use_root_rot = False
mean, components, weight, mean_full, components_full = read_hand(kwargs['model_path'], **cfg_hand)
self.register_buffer('mean', to_tensor(mean, dtype=self.dtype))
self.register_buffer('components', to_tensor(components, dtype=self.dtype))
self.cfg_hand = cfg_hand
self.to(self.device)
if cfg_hand.use_pca:
self.NUM_POSES = cfg_hand.num_pca_comps
def extend_poses(self, poses, **kwargs):
if poses.shape[-1] == self.mean.shape[-1] + 3:
return poses
if self.cfg_hand.use_pca:
poses = poses @ self.components
if kwargs.get('pose2rot', True):
poses = super().extend_poses(poses+self.mean, **kwargs)
else:
poses = super().extend_poses(poses, **kwargs)
return poses
def jacobian_posesfull_poses(self, poses, poses_full):
if self.cfg_hand.use_pca:
jacobian = self.components.t()
zero_root = torch.zeros((3, poses.shape[-1]), dtype=poses.dtype, device=poses.device)
jacobian = torch.cat([zero_root, jacobian], dim=0)
else:
jacobian = super().jacobian_posesfull_poses(poses, poses_full)
return jacobian
class MANOLR(Model):
def __init__(self, model_path, regressor_path, cfg_hand, **kwargs):
super().__init__()
self.name = 'manolr'
keys = list(model_path.keys())
# stack 方式:(nframes, nhand x ndim)
self.keys = keys
modules_hand = {}
faces = []
v_template = []
cnt = 0
for key in keys:
modules_hand[key] = MANO(cfg_hand, model_path=model_path[key], regressor_path=regressor_path[key], **kwargs)
v_template.append(modules_hand[key].v_template.cpu().numpy())
faces.append(modules_hand[key].faces + cnt)
cnt += v_template[-1].shape[0]
self.device = modules_hand[key].device
self.dtype = modules_hand[key].dtype
if key == 'right':
modules_hand[key].shapedirs[:, 0] *= -1
modules_hand[key].j_shapedirs[:, 0] *= -1
self.faces = np.vstack(faces)
self.v_template = np.vstack(v_template)
self.modules_hand = nn.ModuleDict(modules_hand)
self.to(self.device)
def init_params(self, **kwargs):
param_all = {}
for key in self.keys:
param = self.modules_hand[key].init_params(**kwargs)
param_all[key] = param
if False:
params = {k: torch.cat([param_all[key][k] for key in self.keys], dim=-1) for k in param.keys()}
else:
params = {k: np.concatenate([param_all[key][k] for key in self.keys], axis=-1) for k in param.keys() if k != 'shapes'}
params['shapes'] = param_all['left']['shapes']
return params
def split(self, params):
params_split = {}
for imodel, model in enumerate(self.keys):
param_= params.copy()
for key in ['poses', 'shapes', 'Rh', 'Th']:
if key not in params.keys():continue
if key == 'shapes':
continue
shape = params[key].shape[-1]
start = shape//len(self.keys)*imodel
end = shape//len(self.keys)*(imodel+1)
param_[key] = params[key][:, start:end]
params_split[model] = param_
return params_split
def forward(self, **params):
params_split = self.split(params)
rets = []
for imodel, model in enumerate(self.keys):
ret = self.modules_hand[model](**params_split[model])
rets.append(ret)
if params.get('return_tensor', True):
rets = torch.cat(rets, dim=1)
else:
rets = np.concatenate(rets, axis=1)
return rets
def extend_poses(self, poses, **kwargs):
params_split = self.split({'poses': poses})
rets = []
for imodel, model in enumerate(self.keys):
poses = params_split[model]['poses']
poses = self.modules_hand[model].extend_poses(poses)
rets.append(poses)
poses = torch.cat(rets, dim=1)
return poses
def export_full_poses(self, poses, **kwargs):
params_split = self.split({'poses': poses})
rets = []
for imodel, model in enumerate(self.keys):
poses = torch.Tensor(params_split[model]['poses']).to(self.device)
poses = self.modules_hand[model].extend_poses(poses)
rets.append(poses)
poses = torch.cat(rets, dim=1)
return poses.detach().cpu().numpy()
class SMPLHModel(SMPLModel):
def __init__(self, mano_path, cfg_hand, **kwargs):
super().__init__(**kwargs)
self.NUM_POSES = self.NUM_POSES - 90
meanl, componentsl, weight_l, self.mean_full_l, self.components_full_l = read_hand(join(mano_path, 'MANO_LEFT.pkl'), **cfg_hand)
meanr, componentsr, weight_r, self.mean_full_r, self.components_full_r = read_hand(join(mano_path, 'MANO_RIGHT.pkl'), **cfg_hand)
self.register_buffer('weight_l', to_tensor(weight_l, dtype=self.dtype))
self.register_buffer('weight_r', to_tensor(weight_r, dtype=self.dtype))
self.register_buffer('meanl', to_tensor(meanl, dtype=self.dtype))
self.register_buffer('meanr', to_tensor(meanr, dtype=self.dtype))
self.register_buffer('componentsl', to_tensor(componentsl, dtype=self.dtype))
self.register_buffer('componentsr', to_tensor(componentsr, dtype=self.dtype))
self.register_buffer('jacobian_posesfull_poses_', self._jacobian_posesfull_poses())
self.NUM_HANDS = cfg_hand.num_pca_comps if cfg_hand.use_pca else 45
self.cfg_hand = cfg_hand
self.to(self.device)
def _jacobian_posesfull_poses(self):
# TODO: cache this
# | body_full/body | 0 | 0 |
# | 0 | l | 0 |
# | 0 | 0 | r |
eye_right = torch.eye(self.NUM_POSES, dtype=self.dtype)
#
jac_handl = self.componentsl.t()
jac_handr = self.componentsr.t()
output = torch.zeros((self.NUM_POSES_FULL, self.NUM_POSES+jac_handl.shape[1]*2), dtype=self.dtype)
if self.use_root_rot:
raise NotImplementedError
else:
output[3:3+self.NUM_POSES, :self.NUM_POSES] = eye_right
output[3+self.NUM_POSES:3+self.NUM_POSES+jac_handl.shape[0], \
self.NUM_POSES:self.NUM_POSES+jac_handl.shape[1]] = jac_handl
output[3+self.NUM_POSES+jac_handl.shape[0]:3+self.NUM_POSES+2*jac_handl.shape[0], \
self.NUM_POSES+jac_handl.shape[1]:self.NUM_POSES+jac_handl.shape[1]*2] = jac_handr
return output
def init_params(self, nFrames=1, nShapes=1, nPerson=1, ret_tensor=False, add_scale=False):
params = super().init_params(nFrames, nShapes, nPerson, ret_tensor, add_scale=add_scale)
handl = np.zeros((nFrames, self.NUM_HANDS))
handr = np.zeros((nFrames, self.NUM_HANDS))
if nPerson > 1:
handl = handl[:, None].repeat(nPerson, axis=1)
handr = handr[:, None].repeat(nPerson, axis=1)
if ret_tensor:
handl = to_tensor(handl, self.dtype, self.device)
handr = to_tensor(handr, self.dtype, self.device)
params['handl'] = handl
params['handr'] = handr
return params
def extend_poses(self, poses, handl=None, handr=None, **kwargs):
if poses.shape[-1] == self.NUM_POSES_FULL:
return poses
poses = super().extend_poses(poses)
if handl is None:
handl = self.meanl.clone()
handr = self.meanr.clone()
handl = handl.expand(poses.shape[0], -1)
handr = handr.expand(poses.shape[0], -1)
else:
if self.cfg_hand.use_pca:
handl = handl @ self.componentsl
handr = handr @ self.componentsr
handl = handl +self.meanl
handr = handr +self.meanr
poses = torch.cat([poses, handl, handr], dim=-1)
return poses
def export_full_poses(self, poses, handl, handr, **kwargs):
poses = torch.Tensor(poses).to(self.device)
handl = torch.Tensor(handl).to(self.device)
handr = torch.Tensor(handr).to(self.device)
poses = self.extend_poses(poses, handl, handr)
return poses.detach().cpu().numpy()
class SMPLHModelEmbedding(SMPLHModel):
def __init__(self, vposer_ckpt='data/body_models/vposer_v02', **kwargs):
super().__init__(**kwargs)
from human_body_prior.tools.model_loader import load_model
from human_body_prior.models.vposer_model import VPoser
vposer, _ = load_model(vposer_ckpt,
model_code=VPoser,
remove_words_in_model_weights='vp_model.',
disable_grad=True)
vposer.to(self.device)
self.vposer = vposer
self.vposer_dim = 32
self.NUM_POSES = self.vposer_dim
def decode(self, poses, add_rot=True):
if poses.shape[-1] == 66 and add_rot:
return poses
elif poses.shape[-1] == 63 and not add_rot:
return poses
assert poses.shape[-1] == self.vposer_dim, poses.shape
ret = self.vposer.decode(poses)
poses_body = ret['pose_body'].reshape(poses.shape[0], -1)
if add_rot:
zero_rot = torch.zeros((poses.shape[0], 3), dtype=poses.dtype, device=poses.device)
poses_body = torch.cat([zero_rot, poses_body], dim=-1)
return poses_body
def extend_poses(self, poses, handl, handr, **kwargs):
if poses.shape[-1] == self.NUM_POSES_FULL:
return poses
zero_rot = torch.zeros((poses.shape[0], 3), dtype=poses.dtype, device=poses.device)
poses_body = self.decode(poses, add_rot=False)
if self.cfg_hand.use_pca:
handl = handl @ self.componentsl
handr = handr @ self.componentsr
handl = handl +self.meanl
handr = handr +self.meanr
poses = torch.cat([zero_rot, poses_body, handl, handr], dim=-1)
return poses
def export_full_poses(self, poses, handl, handr, **kwargs):
poses = torch.Tensor(poses).to(self.device)
handl = torch.Tensor(handl).to(self.device)
handr = torch.Tensor(handr).to(self.device)
poses = self.extend_poses(poses, handl, handr)
return poses.detach().cpu().numpy()