🚧 re-organize bodymodel
This commit is contained in:
parent
cd3f184f04
commit
050cb209d1
135
easymocap/bodymodel/base.py
Normal file
135
easymocap/bodymodel/base.py
Normal file
@ -0,0 +1,135 @@
|
||||
'''
|
||||
@ Date: 2022-03-17 19:23:59
|
||||
@ Author: Qing Shuai
|
||||
@ Mail: s_q@zju.edu.cn
|
||||
@ LastEditors: Qing Shuai
|
||||
@ LastEditTime: 2022-07-15 12:15:46
|
||||
@ FilePath: /EasyMocapPublic/easymocap/bodymodel/base.py
|
||||
'''
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..mytools.file_utils import myarray2string
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.name = 'custom'
|
||||
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
def vertices(self, params, **kwargs):
|
||||
return self.forward(return_verts=True, **kwargs, **params)
|
||||
|
||||
def keypoints(self, params, **kwargs):
|
||||
return self.forward(return_verts=False, **kwargs, **params)
|
||||
|
||||
def transform(self, params, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
class ComposedModel(torch.nn.Module):
|
||||
def __init__(self, config_dict):
|
||||
# 叠加多个模型的配置
|
||||
for name, config in config_dict.items():
|
||||
pass
|
||||
|
||||
class Params(dict):
|
||||
@classmethod
|
||||
def merge(self, params_list, share_shape=True, stack=np.vstack):
|
||||
output = {}
|
||||
for key in params_list[0].keys():
|
||||
if key == 'id':continue
|
||||
output[key] = stack([v[key] for v in params_list])
|
||||
if share_shape:
|
||||
output['shapes'] = output['shapes'].mean(axis=0, keepdims=True)
|
||||
return output
|
||||
|
||||
def __len__(self):
|
||||
return len(self['poses'])
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name in self:
|
||||
return self[name]
|
||||
else:
|
||||
raise AttributeError(name)
|
||||
|
||||
def __getitem__(self, index):
|
||||
if not isinstance(index, int):
|
||||
return super().__getitem__(index)
|
||||
if 'shapes' not in self.keys():
|
||||
# arbitray data
|
||||
ret = {}
|
||||
for key, val in self.items():
|
||||
if index >= 1 and val.shape[0] == 1:
|
||||
ret[key] = val[0]
|
||||
else:
|
||||
ret[key] = val[index]
|
||||
return Params(**ret)
|
||||
ret = {'id': 0}
|
||||
poses = self.poses
|
||||
shapes = self.shapes
|
||||
while len(shapes.shape) < len(poses.shape):
|
||||
shapes = shapes[None]
|
||||
if poses.shape[0] == shapes.shape[0]:
|
||||
if index >= 1 and shapes.shape[0] == 1:
|
||||
ret['shapes'] = shapes[0]
|
||||
else:
|
||||
ret['shapes'] = shapes[index]
|
||||
elif shapes.shape[0] == 1:
|
||||
ret['shapes'] = shapes[0]
|
||||
else:
|
||||
import ipdb; ipdb.set_trace()
|
||||
if index >= 1 and poses.shape[0] == 1:
|
||||
ret['poses'] = poses[0]
|
||||
else:
|
||||
ret['poses'] = poses[index]
|
||||
for key, val in self.items():
|
||||
if key == 'id':
|
||||
ret[key] = self[key]
|
||||
continue
|
||||
if key in ret.keys():continue
|
||||
if index >= 1 and val.shape[0] == 1:
|
||||
ret[key] = val[0]
|
||||
else:
|
||||
ret[key] = val[index]
|
||||
for key, val in ret.items():
|
||||
if key == 'id': continue
|
||||
if len(val.shape) == 1:
|
||||
ret[key] = val[None]
|
||||
return Params(**ret)
|
||||
|
||||
def to_multiperson(self, pids):
|
||||
results = []
|
||||
for i, pid in enumerate(pids):
|
||||
param = self[i]
|
||||
# TODO: this class just implement getattr
|
||||
# param.id = pid # is wrong
|
||||
param['id'] = pid
|
||||
results.append(param)
|
||||
return results
|
||||
|
||||
def __str__(self) -> str:
|
||||
ret = ''
|
||||
lastkey = list(self.keys())[-1]
|
||||
for key, val in self.items():
|
||||
if isinstance(val, np.ndarray):
|
||||
ret += '"{}": '.format(key) + myarray2string(val, indent=0)
|
||||
else:
|
||||
ret += '"{}": '.format(key) + str(val)
|
||||
if key != lastkey:
|
||||
ret += ',\n'
|
||||
return ret
|
||||
|
||||
def shape(self):
|
||||
ret = ''
|
||||
lastkey = list(self.keys())[-1]
|
||||
for key, val in self.items():
|
||||
if isinstance(val, np.ndarray):
|
||||
ret += '"{}": {}'.format(key, val.shape)
|
||||
else:
|
||||
ret += '"{}": '.format(key) + str(val)
|
||||
if key != lastkey:
|
||||
ret += ',\n'
|
||||
print(ret)
|
||||
return ret
|
501
easymocap/bodymodel/lbs.py
Normal file
501
easymocap/bodymodel/lbs.py
Normal file
@ -0,0 +1,501 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
||||
# holder of all proprietary rights on this computer program.
|
||||
# You can only use this computer program if you have closed
|
||||
# a license agreement with MPG or you get the right to use the computer
|
||||
# program from someone who is authorized to grant you that right.
|
||||
# Any use of the computer program without a valid license is prohibited and
|
||||
# liable to prosecution.
|
||||
#
|
||||
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
||||
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
||||
# for Intelligent Systems. All rights reserved.
|
||||
#
|
||||
# Contact: ps-license@tuebingen.mpg.de
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import print_function
|
||||
from __future__ import division
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def rot_mat_to_euler(rot_mats):
|
||||
# Calculates rotation matrix to euler angles
|
||||
# Careful for extreme cases of eular angles like [0.0, pi, 0.0]
|
||||
|
||||
sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] +
|
||||
rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
|
||||
return torch.atan2(-rot_mats[:, 2, 0], sy)
|
||||
|
||||
|
||||
def find_dynamic_lmk_idx_and_bcoords(vertices, pose, dynamic_lmk_faces_idx,
|
||||
dynamic_lmk_b_coords,
|
||||
neck_kin_chain, dtype=torch.float32):
|
||||
''' Compute the faces, barycentric coordinates for the dynamic landmarks
|
||||
|
||||
|
||||
To do so, we first compute the rotation of the neck around the y-axis
|
||||
and then use a pre-computed look-up table to find the faces and the
|
||||
barycentric coordinates that will be used.
|
||||
|
||||
Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de)
|
||||
for providing the original TensorFlow implementation and for the LUT.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
vertices: torch.tensor BxVx3, dtype = torch.float32
|
||||
The tensor of input vertices
|
||||
pose: torch.tensor Bx(Jx3), dtype = torch.float32
|
||||
The current pose of the body model
|
||||
dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long
|
||||
The look-up table from neck rotation to faces
|
||||
dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32
|
||||
The look-up table from neck rotation to barycentric coordinates
|
||||
neck_kin_chain: list
|
||||
A python list that contains the indices of the joints that form the
|
||||
kinematic chain of the neck.
|
||||
dtype: torch.dtype, optional
|
||||
|
||||
Returns
|
||||
-------
|
||||
dyn_lmk_faces_idx: torch.tensor, dtype = torch.long
|
||||
A tensor of size BxL that contains the indices of the faces that
|
||||
will be used to compute the current dynamic landmarks.
|
||||
dyn_lmk_b_coords: torch.tensor, dtype = torch.float32
|
||||
A tensor of size BxL that contains the indices of the faces that
|
||||
will be used to compute the current dynamic landmarks.
|
||||
'''
|
||||
|
||||
batch_size = vertices.shape[0]
|
||||
|
||||
aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1,
|
||||
neck_kin_chain)
|
||||
rot_mats = batch_rodrigues(
|
||||
aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3)
|
||||
|
||||
rel_rot_mat = torch.eye(3, device=vertices.device,
|
||||
dtype=dtype).unsqueeze_(dim=0)
|
||||
for idx in range(len(neck_kin_chain)):
|
||||
rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)
|
||||
|
||||
y_rot_angle = torch.round(
|
||||
torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
|
||||
max=39)).to(dtype=torch.long)
|
||||
neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
|
||||
mask = y_rot_angle.lt(-39).to(dtype=torch.long)
|
||||
neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)
|
||||
y_rot_angle = (neg_mask * neg_vals +
|
||||
(1 - neg_mask) * y_rot_angle)
|
||||
|
||||
dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx,
|
||||
0, y_rot_angle)
|
||||
dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords,
|
||||
0, y_rot_angle)
|
||||
|
||||
return dyn_lmk_faces_idx, dyn_lmk_b_coords
|
||||
|
||||
|
||||
def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords):
|
||||
''' Calculates landmarks by barycentric interpolation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
vertices: torch.tensor BxVx3, dtype = torch.float32
|
||||
The tensor of input vertices
|
||||
faces: torch.tensor Fx3, dtype = torch.long
|
||||
The faces of the mesh
|
||||
lmk_faces_idx: torch.tensor L, dtype = torch.long
|
||||
The tensor with the indices of the faces used to calculate the
|
||||
landmarks.
|
||||
lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32
|
||||
The tensor of barycentric coordinates that are used to interpolate
|
||||
the landmarks
|
||||
|
||||
Returns
|
||||
-------
|
||||
landmarks: torch.tensor BxLx3, dtype = torch.float32
|
||||
The coordinates of the landmarks for each mesh in the batch
|
||||
'''
|
||||
# Extract the indices of the vertices for each face
|
||||
# BxLx3
|
||||
batch_size, num_verts = vertices.shape[:2]
|
||||
device = vertices.device
|
||||
|
||||
lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(
|
||||
batch_size, -1, 3)
|
||||
|
||||
lmk_faces += torch.arange(
|
||||
batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts
|
||||
|
||||
lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(
|
||||
batch_size, -1, 3, 3)
|
||||
|
||||
landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords])
|
||||
return landmarks
|
||||
|
||||
|
||||
def lbs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents,
|
||||
lbs_weights, pose2rot=True, dtype=torch.float32, only_shape=False,
|
||||
use_shape_blending=True, use_pose_blending=True, J_shaped=None, return_vertices=True):
|
||||
''' Performs Linear Blend Skinning with the given shape and pose parameters
|
||||
|
||||
Parameters
|
||||
----------
|
||||
betas : torch.tensor BxNB
|
||||
The tensor of shape parameters
|
||||
pose : torch.tensor Bx(J + 1) * 3
|
||||
The pose parameters in axis-angle format
|
||||
v_template torch.tensor BxVx3
|
||||
The template mesh that will be deformed
|
||||
shapedirs : torch.tensor 1xNB
|
||||
The tensor of PCA shape displacements
|
||||
posedirs : torch.tensor Px(V * 3)
|
||||
The pose PCA coefficients
|
||||
J_regressor : torch.tensor JxV
|
||||
The regressor array that is used to calculate the joints from
|
||||
the position of the vertices
|
||||
parents: torch.tensor J
|
||||
The array that describes the kinematic tree for the model
|
||||
lbs_weights: torch.tensor N x V x (J + 1)
|
||||
The linear blend skinning weights that represent how much the
|
||||
rotation matrix of each part affects each vertex
|
||||
pose2rot: bool, optional
|
||||
Flag on whether to convert the input pose tensor to rotation
|
||||
matrices. The default value is True. If False, then the pose tensor
|
||||
should already contain rotation matrices and have a size of
|
||||
Bx(J + 1)x9
|
||||
dtype: torch.dtype, optional
|
||||
|
||||
Returns
|
||||
-------
|
||||
verts: torch.tensor BxVx3
|
||||
The vertices of the mesh after applying the shape and pose
|
||||
displacements.
|
||||
joints: torch.tensor BxJx3
|
||||
The joints of the model
|
||||
'''
|
||||
|
||||
batch_size = max(betas.shape[0], pose.shape[0])
|
||||
device = betas.device
|
||||
|
||||
# Add shape contribution
|
||||
if use_shape_blending:
|
||||
v_shaped = v_template + blend_shapes(betas, shapedirs)
|
||||
# Get the joints
|
||||
# NxJx3 array
|
||||
J = vertices2joints(J_regressor, v_shaped)
|
||||
else:
|
||||
v_shaped = v_template.unsqueeze(0).expand(batch_size, -1, -1)
|
||||
assert J_shaped is not None
|
||||
J = J_shaped[None].expand(batch_size, -1, -1)
|
||||
|
||||
if only_shape:
|
||||
return v_shaped, J, None, None
|
||||
# 3. Add pose blend shapes
|
||||
# N x J x 3 x 3
|
||||
if pose2rot:
|
||||
rot_mats = batch_rodrigues(
|
||||
pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3])
|
||||
else:
|
||||
rot_mats = pose.view(batch_size, -1, 3, 3)
|
||||
|
||||
if use_pose_blending:
|
||||
ident = torch.eye(3, dtype=dtype, device=device)
|
||||
pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])
|
||||
pose_offsets = torch.matmul(pose_feature, posedirs) \
|
||||
.view(batch_size, -1, 3)
|
||||
v_posed = pose_offsets + v_shaped
|
||||
else:
|
||||
v_posed = v_shaped
|
||||
# 4. Get the global joint location
|
||||
J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)
|
||||
|
||||
if not return_vertices:
|
||||
return None, J_transformed, A, None
|
||||
# 5. Do skinning:
|
||||
# W is N x V x (J + 1)
|
||||
W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
|
||||
# (N x V x (J + 1)) x (N x (J + 1) x 16)
|
||||
num_joints = J_transformed.shape[1]
|
||||
T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \
|
||||
.view(batch_size, -1, 4, 4)
|
||||
|
||||
homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1],
|
||||
dtype=dtype, device=device)
|
||||
v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
|
||||
v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))
|
||||
|
||||
verts = v_homo[:, :, :3, 0]
|
||||
|
||||
return verts, J_transformed, A, T
|
||||
|
||||
|
||||
def vertices2joints(J_regressor, vertices):
|
||||
''' Calculates the 3D joint locations from the vertices
|
||||
|
||||
Parameters
|
||||
----------
|
||||
J_regressor : torch.tensor JxV
|
||||
The regressor array that is used to calculate the joints from the
|
||||
position of the vertices
|
||||
vertices : torch.tensor BxVx3
|
||||
The tensor of mesh vertices
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.tensor BxJx3
|
||||
The location of the joints
|
||||
'''
|
||||
|
||||
return torch.einsum('bik,ji->bjk', [vertices, J_regressor])
|
||||
|
||||
|
||||
def blend_shapes(betas, shape_disps):
|
||||
''' Calculates the per vertex displacement due to the blend shapes
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
betas : torch.tensor Bx(num_betas)
|
||||
Blend shape coefficients
|
||||
shape_disps: torch.tensor Vx3x(num_betas)
|
||||
Blend shapes
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.tensor BxVx3
|
||||
The per-vertex displacement due to shape deformation
|
||||
'''
|
||||
|
||||
# Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l]
|
||||
# i.e. Multiply each shape displacement by its corresponding beta and
|
||||
# then sum them.
|
||||
blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps])
|
||||
return blend_shape
|
||||
|
||||
|
||||
def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
|
||||
''' Calculates the rotation matrices for a batch of rotation vectors
|
||||
Parameters
|
||||
----------
|
||||
rot_vecs: torch.tensor Nx3
|
||||
array of N axis-angle vectors
|
||||
Returns
|
||||
-------
|
||||
R: torch.tensor Nx3x3
|
||||
The rotation matrices for the given axis-angle parameters
|
||||
'''
|
||||
if len(rot_vecs.shape) > 2:
|
||||
rot_vec_ori = rot_vecs
|
||||
rot_vecs = rot_vecs.view(-1, 3)
|
||||
else:
|
||||
rot_vec_ori = None
|
||||
batch_size = rot_vecs.shape[0]
|
||||
device = rot_vecs.device
|
||||
|
||||
angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
|
||||
rot_dir = rot_vecs / angle
|
||||
|
||||
cos = torch.unsqueeze(torch.cos(angle), dim=1)
|
||||
sin = torch.unsqueeze(torch.sin(angle), dim=1)
|
||||
|
||||
# Bx1 arrays
|
||||
rx, ry, rz = torch.split(rot_dir, 1, dim=1)
|
||||
K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
|
||||
|
||||
zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
|
||||
K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
|
||||
.view((batch_size, 3, 3))
|
||||
|
||||
ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
|
||||
rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
|
||||
if rot_vec_ori is not None:
|
||||
rot_mat = rot_mat.reshape(*rot_vec_ori.shape[:-1], 3, 3)
|
||||
return rot_mat
|
||||
|
||||
|
||||
def transform_mat(R, t):
|
||||
''' Creates a batch of transformation matrices
|
||||
Args:
|
||||
- R: Bx3x3 array of a batch of rotation matrices
|
||||
- t: Bx3x1 array of a batch of translation vectors
|
||||
Returns:
|
||||
- T: Bx4x4 Transformation matrix
|
||||
'''
|
||||
# No padding left or right, only add an extra row
|
||||
return torch.cat([F.pad(R, [0, 0, 0, 1]),
|
||||
F.pad(t, [0, 0, 0, 1], value=1)], dim=2)
|
||||
|
||||
|
||||
def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32):
|
||||
"""
|
||||
Applies a batch of rigid transformations to the joints
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rot_mats : torch.tensor BxNx3x3
|
||||
Tensor of rotation matrices
|
||||
joints : torch.tensor BxNx3
|
||||
Locations of joints
|
||||
parents : torch.tensor BxN
|
||||
The kinematic tree of each object
|
||||
dtype : torch.dtype, optional:
|
||||
The data type of the created tensors, the default is torch.float32
|
||||
|
||||
Returns
|
||||
-------
|
||||
posed_joints : torch.tensor BxNx3
|
||||
The locations of the joints after applying the pose rotations
|
||||
rel_transforms : torch.tensor BxNx4x4
|
||||
The relative (with respect to the root joint) rigid transformations
|
||||
for all the joints
|
||||
"""
|
||||
|
||||
joints = torch.unsqueeze(joints, dim=-1)
|
||||
|
||||
rel_joints = joints.clone()
|
||||
rel_joints[:, 1:] -= joints[:, parents[1:]]
|
||||
|
||||
transforms_mat = transform_mat(
|
||||
rot_mats.view(-1, 3, 3),
|
||||
rel_joints.contiguous().view(-1, 3, 1)).view(-1, joints.shape[1], 4, 4)
|
||||
|
||||
transform_chain = [transforms_mat[:, 0]]
|
||||
for i in range(1, parents.shape[0]):
|
||||
# Subtract the joint location at the rest pose
|
||||
# No need for rotation, since it's identity when at rest
|
||||
curr_res = torch.matmul(transform_chain[parents[i]],
|
||||
transforms_mat[:, i])
|
||||
transform_chain.append(curr_res)
|
||||
|
||||
transforms = torch.stack(transform_chain, dim=1)
|
||||
|
||||
# The last column of the transformations contains the posed joints
|
||||
posed_joints = transforms[:, :, :3, 3]
|
||||
|
||||
# The last column of the transformations contains the posed joints
|
||||
posed_joints = transforms[:, :, :3, 3]
|
||||
|
||||
joints_homogen = F.pad(joints, [0, 0, 0, 1])
|
||||
|
||||
rel_transforms = transforms - F.pad(
|
||||
torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0])
|
||||
|
||||
return posed_joints, rel_transforms
|
||||
|
||||
def dqs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents,
|
||||
lbs_weights, pose2rot=True, dtype=torch.float32, only_shape=False,
|
||||
use_shape_blending=True, use_pose_blending=True, J_shaped=None):
|
||||
''' Performs Linear Blend Skinning with the given shape and pose parameters
|
||||
|
||||
Parameters
|
||||
----------
|
||||
betas : torch.tensor BxNB
|
||||
The tensor of shape parameters
|
||||
pose : torch.tensor Bx(J + 1) * 3
|
||||
The pose parameters in axis-angle format
|
||||
v_template torch.tensor BxVx3
|
||||
The template mesh that will be deformed
|
||||
shapedirs : torch.tensor 1xNB
|
||||
The tensor of PCA shape displacements
|
||||
posedirs : torch.tensor Px(V * 3)
|
||||
The pose PCA coefficients
|
||||
J_regressor : torch.tensor JxV
|
||||
The regressor array that is used to calculate the joints from
|
||||
the position of the vertices
|
||||
parents: torch.tensor J
|
||||
The array that describes the kinematic tree for the model
|
||||
lbs_weights: torch.tensor N x V x (J + 1)
|
||||
The linear blend skinning weights that represent how much the
|
||||
rotation matrix of each part affects each vertex
|
||||
pose2rot: bool, optional
|
||||
Flag on whether to convert the input pose tensor to rotation
|
||||
matrices. The default value is True. If False, then the pose tensor
|
||||
should already contain rotation matrices and have a size of
|
||||
Bx(J + 1)x9
|
||||
dtype: torch.dtype, optional
|
||||
|
||||
Returns
|
||||
-------
|
||||
verts: torch.tensor BxVx3
|
||||
The vertices of the mesh after applying the shape and pose
|
||||
displacements.
|
||||
joints: torch.tensor BxJx3
|
||||
The joints of the model
|
||||
'''
|
||||
|
||||
batch_size = max(betas.shape[0], pose.shape[0])
|
||||
device = betas.device
|
||||
|
||||
# Add shape contribution
|
||||
if use_shape_blending:
|
||||
v_shaped = v_template + blend_shapes(betas, shapedirs)
|
||||
# Get the joints
|
||||
# NxJx3 array
|
||||
J = vertices2joints(J_regressor, v_shaped)
|
||||
else:
|
||||
v_shaped = v_template.unsqueeze(0).expand(batch_size, -1, -1)
|
||||
assert J_shaped is not None
|
||||
J = J_shaped[None].expand(batch_size, -1, -1)
|
||||
|
||||
if only_shape:
|
||||
return v_shaped, J
|
||||
# 3. Add pose blend shapes
|
||||
# N x J x 3 x 3
|
||||
if pose2rot:
|
||||
rot_mats = batch_rodrigues(
|
||||
pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3])
|
||||
else:
|
||||
rot_mats = pose.view(batch_size, -1, 3, 3)
|
||||
|
||||
if use_pose_blending:
|
||||
ident = torch.eye(3, dtype=dtype, device=device)
|
||||
pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])
|
||||
pose_offsets = torch.matmul(pose_feature, posedirs) \
|
||||
.view(batch_size, -1, 3)
|
||||
v_posed = pose_offsets + v_shaped
|
||||
else:
|
||||
v_posed = v_shaped
|
||||
# 4. Get the global joint location
|
||||
J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)
|
||||
|
||||
|
||||
# 5. Do skinning:
|
||||
# W is N x V x (J + 1)
|
||||
W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
|
||||
|
||||
verts=batch_dqs_blending(A,W,v_posed)
|
||||
|
||||
return verts, J_transformed
|
||||
|
||||
#A: B,J,4,4 W: B,V,J
|
||||
def batch_dqs_blending(A,W,Vs):
|
||||
Bnum,Jnum,_,_=A.shape
|
||||
_,Vnum,_=W.shape
|
||||
A = A.view(Bnum*Jnum,4,4)
|
||||
Rs=A[:,:3,:3]
|
||||
ws=torch.sqrt(torch.clamp(Rs[:,0,0]+Rs[:,1,1]+Rs[:,2,2]+1.,min=1.e-6))/2.
|
||||
xs=(Rs[:,2,1]-Rs[:,1,2])/(4.*ws)
|
||||
ys=(Rs[:,0,2]-Rs[:,2,0])/(4.*ws)
|
||||
zs=(Rs[:,1,0]-Rs[:,0,1])/(4.*ws)
|
||||
Ts=A[:,:3,3]
|
||||
vDw=-0.5*( Ts[:,0]*xs + Ts[:,1]*ys + Ts[:,2]*zs)
|
||||
vDx=0.5*( Ts[:,0]*ws + Ts[:,1]*zs - Ts[:,2]*ys)
|
||||
vDy=0.5*(-Ts[:,0]*zs + Ts[:,1]*ws + Ts[:,2]*xs)
|
||||
vDz=0.5*( Ts[:,0]*ys - Ts[:,1]*xs + Ts[:,2]*ws)
|
||||
b0=W.unsqueeze(-2)@torch.cat([ws[:,None],xs[:,None],ys[:,None],zs[:,None]],dim=-1).reshape(Bnum, 1, Jnum, 4) #B,V,J,4
|
||||
be=W.unsqueeze(-2)@torch.cat([vDw[:,None],vDx[:,None],vDy[:,None],vDz[:,None]],dim=-1).reshape(Bnum, 1, Jnum, 4) #B,V,J,4
|
||||
b0 = b0.reshape(-1, 4)
|
||||
be = be.reshape(-1, 4)
|
||||
|
||||
ns=torch.norm(b0,dim=-1,keepdim=True)
|
||||
be=be/ns
|
||||
b0=b0/ns
|
||||
Vs=Vs.view(Bnum*Vnum,3)
|
||||
Vs=Vs+2.*b0[:,1:].cross(b0[:,1:].cross(Vs)+b0[:,:1]*Vs)+2.*(b0[:,:1]*be[:,1:]-be[:,:1]*b0[:,1:]+b0[:,1:].cross(be[:,1:]))
|
||||
return Vs.reshape(Bnum,Vnum,3)
|
438
easymocap/bodymodel/smpl.py
Normal file
438
easymocap/bodymodel/smpl.py
Normal file
@ -0,0 +1,438 @@
|
||||
from .base import Model, Params
|
||||
from .lbs import lbs, batch_rodrigues
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
def to_tensor(array, dtype=torch.float32, device=torch.device('cpu')):
|
||||
if 'torch.tensor' not in str(type(array)):
|
||||
return torch.tensor(array, dtype=dtype).to(device)
|
||||
else:
|
||||
return array.to(device)
|
||||
|
||||
def to_np(array, dtype=np.float32):
|
||||
if 'scipy.sparse' in str(type(array)):
|
||||
array = array.todense()
|
||||
return np.array(array, dtype=dtype)
|
||||
|
||||
def read_pickle(name):
|
||||
import pickle
|
||||
with open(name, 'rb') as f:
|
||||
data = pickle.load(f, encoding='latin1')
|
||||
return data
|
||||
|
||||
def load_model_data(model_path):
|
||||
model_path = os.path.abspath(model_path)
|
||||
assert os.path.exists(model_path), 'Path {} does not exist!'.format(
|
||||
model_path)
|
||||
if model_path.endswith('.npz'):
|
||||
data = np.load(model_path)
|
||||
data = dict(data)
|
||||
elif model_path.endswith('.pkl'):
|
||||
data = read_pickle(model_path)
|
||||
return data
|
||||
|
||||
def load_regressor(regressor_path):
|
||||
if regressor_path.endswith('.npy'):
|
||||
X_regressor = to_tensor(np.load(regressor_path))
|
||||
elif regressor_path.endswith('.txt'):
|
||||
data = np.loadtxt(regressor_path)
|
||||
with open(regressor_path, 'r') as f:
|
||||
shape = f.readline().split()[1:]
|
||||
reg = np.zeros((int(shape[0]), int(shape[1])))
|
||||
for i, j, v in data:
|
||||
reg[int(i), int(j)] = v
|
||||
X_regressor = to_tensor(reg)
|
||||
else:
|
||||
import ipdb; ipdb.set_trace()
|
||||
return X_regressor
|
||||
|
||||
def save_regressor(fname, data):
|
||||
with open(fname, 'w') as f:
|
||||
f.writelines('{} {} {}\r\n'.format('#', data.shape[0], data.shape[1]))
|
||||
for i in range(data.shape[0]):
|
||||
for j in range(data.shape[1]):
|
||||
if(data[i, j] > 0):
|
||||
f.writelines('{} {} {}\r\n'.format(i, j, data[i, j]))
|
||||
|
||||
|
||||
class SMPLModel(Model):
|
||||
def __init__(self, model_path, regressor_path=None,
|
||||
device='cpu',
|
||||
use_pose_blending=True, use_shape_blending=True, use_joints=True,
|
||||
NUM_SHAPES=-1, NUM_POSES=-1,
|
||||
use_lbs=True,
|
||||
use_root_rot=False,
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
self.name = 'lbs'
|
||||
self.dtype = torch.float32 # not support fp16 now
|
||||
self.use_pose_blending = use_pose_blending
|
||||
self.use_shape_blending = use_shape_blending
|
||||
self.use_root_rot = use_root_rot
|
||||
self.NUM_SHAPES = NUM_SHAPES
|
||||
self.NUM_POSES = NUM_POSES
|
||||
self.NUM_POSES_FULL = NUM_POSES
|
||||
self.use_joints = use_joints
|
||||
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
if not torch.torch.cuda.is_available():
|
||||
device = torch.device('cpu')
|
||||
self.device = device
|
||||
self.model_type = 'smpl'
|
||||
# create the SMPL model
|
||||
self.lbs = lbs
|
||||
self.data = load_model_data(model_path)
|
||||
self.register_any_lbs(self.data)
|
||||
|
||||
# keypoints regressor
|
||||
if regressor_path is not None and use_joints:
|
||||
X_regressor = load_regressor(regressor_path)
|
||||
X_regressor = torch.cat((self.J_regressor, X_regressor), dim=0)
|
||||
self.register_any_keypoints(X_regressor)
|
||||
if not self.use_root_rot:
|
||||
self.NUM_POSES -= 3 # remove first 3 dims
|
||||
self.to(self.device)
|
||||
|
||||
def register_any_lbs(self, data):
|
||||
self.faces = to_np(self.data['f'], dtype=np.int64)
|
||||
self.register_buffer('faces_tensor',
|
||||
to_tensor(self.faces, dtype=torch.long))
|
||||
for key in ['J_regressor', 'v_template', 'weights']:
|
||||
if key not in data.keys():
|
||||
print('Warning: {} not in data'.format(key))
|
||||
self.__setattr__(key, None)
|
||||
continue
|
||||
val = to_tensor(to_np(data[key]), dtype=self.dtype)
|
||||
self.register_buffer(key, val)
|
||||
self.NUM_POSES = self.weights.shape[-1] * 3
|
||||
self.NUM_POSES_FULL = self.NUM_POSES
|
||||
# add poseblending
|
||||
if self.use_pose_blending:
|
||||
# Pose blend shape basis: 6890 x 3 x 207, reshaped to 6890*3 x 207
|
||||
num_pose_basis = data['posedirs'].shape[-1]
|
||||
# 207 x 20670
|
||||
data['posedirs_origin'] = data['posedirs']
|
||||
data['posedirs'] = np.reshape(data['posedirs'], [-1, num_pose_basis]).T
|
||||
val = to_tensor(to_np(data['posedirs']), dtype=self.dtype)
|
||||
self.register_buffer('posedirs', val)
|
||||
else:
|
||||
self.posedirs = None
|
||||
# add shape blending
|
||||
if self.use_shape_blending:
|
||||
val = to_tensor(to_np(data['shapedirs']), dtype=self.dtype)
|
||||
if self.NUM_SHAPES != -1:
|
||||
val = val[..., :self.NUM_SHAPES]
|
||||
self.register_buffer('shapedirs', val)
|
||||
self.NUM_SHAPES = val.shape[-1]
|
||||
else:
|
||||
self.shapedirs = None
|
||||
if self.use_shape_blending:
|
||||
self.J_shaped = None
|
||||
else:
|
||||
val = to_tensor(to_np(data['J']), dtype=self.dtype)
|
||||
self.register_buffer('J_shaped', val)
|
||||
|
||||
self.nVertices = self.v_template.shape[0]
|
||||
# indices of parents for each joints
|
||||
kintree_table = data['kintree_table']
|
||||
if len(kintree_table.shape) == 2:
|
||||
kintree_table = kintree_table[0]
|
||||
parents = to_tensor(to_np(kintree_table)).long()
|
||||
parents[0] = -1
|
||||
self.register_buffer('parents', parents)
|
||||
|
||||
def register_any_keypoints(self, X_regressor):
|
||||
# set the parameter of keypoints level
|
||||
j_J_regressor = torch.zeros(self.J_regressor.shape[0], X_regressor.shape[0], device=self.device)
|
||||
for i in range(self.J_regressor.shape[0]):
|
||||
j_J_regressor[i, i] = 1
|
||||
j_v_template = X_regressor @ self.v_template
|
||||
j_weights = X_regressor @ self.weights
|
||||
if self.use_pose_blending:
|
||||
posedirs = self.data['posedirs_origin']
|
||||
j_posedirs = torch.einsum('ab, bde->ade', X_regressor, torch.Tensor(posedirs)).numpy()
|
||||
j_posedirs = np.reshape(j_posedirs, [-1, posedirs.shape[-1]]).T
|
||||
j_posedirs = to_tensor(j_posedirs)
|
||||
self.register_buffer('j_posedirs', j_posedirs)
|
||||
else:
|
||||
self.j_posedirs = None
|
||||
if self.use_shape_blending:
|
||||
j_shapedirs = torch.einsum('vij,kv->kij', [self.shapedirs, X_regressor])
|
||||
self.register_buffer('j_shapedirs', j_shapedirs)
|
||||
else:
|
||||
self.j_shapedirs = None
|
||||
self.register_buffer('j_weights', j_weights)
|
||||
self.register_buffer('j_v_template', j_v_template)
|
||||
self.register_buffer('j_J_regressor', j_J_regressor)
|
||||
|
||||
def forward(self, return_verts=True, return_tensor=True,
|
||||
return_smpl_joints=False,
|
||||
only_shape=False, pose2rot=True, **params):
|
||||
params = self.check_params(params)
|
||||
poses, shapes = params['poses'], params['shapes']
|
||||
poses = self.extend_poses(pose2rot=pose2rot, **params)
|
||||
Rh, Th = params['Rh'], params['Th']
|
||||
# check if there are multiple person
|
||||
if len(shapes.shape) == 3:
|
||||
reshape = poses.shape[:2]
|
||||
Rh = Rh.reshape(-1, *Rh.shape[2:])
|
||||
Th = Th.reshape(-1, *Th.shape[2:])
|
||||
poses = poses.reshape(-1, *poses.shape[2:])
|
||||
shapes = shapes.reshape(-1, *shapes.shape[2:])
|
||||
else:
|
||||
reshape = None
|
||||
if len(Rh.shape) == 2: # angle-axis
|
||||
Rh = batch_rodrigues(Rh)
|
||||
Th = Th.unsqueeze(dim=1)
|
||||
if return_verts or not self.use_joints:
|
||||
v_template = self.v_template
|
||||
if 'scale' in params.keys():
|
||||
v_template = v_template * params['scale'][0]
|
||||
vertices, joints, T_joints, T_vertices = self.lbs(shapes, poses, v_template,
|
||||
self.shapedirs, self.posedirs,
|
||||
self.J_regressor, self.parents,
|
||||
self.weights, pose2rot=pose2rot, dtype=self.dtype, only_shape=only_shape,
|
||||
use_pose_blending=self.use_pose_blending, use_shape_blending=self.use_shape_blending, J_shaped=self.J_shaped)
|
||||
if not self.use_joints and not return_verts:
|
||||
vertices = joints
|
||||
else:
|
||||
# only forward joints
|
||||
v_template = self.j_v_template
|
||||
if 'scale' in params.keys():
|
||||
v_template = v_template * params['scale'][0]
|
||||
vertices, joints, _, _ = self.lbs(shapes, poses, v_template,
|
||||
self.j_shapedirs, self.j_posedirs,
|
||||
self.j_J_regressor, self.parents,
|
||||
self.j_weights, pose2rot=pose2rot, dtype=self.dtype, only_shape=only_shape,
|
||||
use_pose_blending=self.use_pose_blending, use_shape_blending=self.use_shape_blending, J_shaped=self.J_shaped)
|
||||
if return_smpl_joints:
|
||||
# vertices = vertices[:, :self.J_regressor.shape[0], :]
|
||||
vertices = joints
|
||||
else:
|
||||
vertices = vertices[:, self.J_regressor.shape[0]:, :]
|
||||
vertices = torch.matmul(vertices, Rh.transpose(1, 2)) + Th
|
||||
if not return_tensor:
|
||||
vertices = vertices.detach().cpu().numpy()
|
||||
if reshape is not None:
|
||||
vertices = vertices.reshape(*reshape, *vertices.shape[1:])
|
||||
return vertices
|
||||
|
||||
def transform(self, params, pose2rot=True, return_vertices=True):
|
||||
v_template = self.v_template
|
||||
params = self.check_params(params)
|
||||
shapes = params['shapes']
|
||||
poses = self.extend_poses(**params)
|
||||
vertices, joints, T_joints, T_vertices = self.lbs(shapes, poses, v_template,
|
||||
self.shapedirs, self.posedirs,
|
||||
self.J_regressor, self.parents,
|
||||
self.weights, pose2rot=pose2rot, dtype=self.dtype,
|
||||
use_pose_blending=self.use_pose_blending, use_shape_blending=self.use_shape_blending, J_shaped=self.J_shaped,
|
||||
return_vertices=return_vertices)
|
||||
return T_joints, T_vertices
|
||||
|
||||
def merge_params(self, params, **kwargs):
|
||||
return Params.merge(params, **kwargs)
|
||||
|
||||
|
||||
def convert_from_standard_smpl(self, params):
|
||||
params = self.check_params(params)
|
||||
poses, shapes = params['poses'], params['shapes']
|
||||
Th = params['Th']
|
||||
vertices, joints, _, _ = lbs(shapes, poses, self.v_template,
|
||||
self.shapedirs, self.posedirs,
|
||||
self.J_regressor, self.parents,
|
||||
self.weights, pose2rot=True, dtype=self.dtype, only_shape=True)
|
||||
# N x 3
|
||||
j0 = joints[:, 0, :]
|
||||
Rh = poses[:, :3].clone()
|
||||
# N x 3 x 3
|
||||
rot = batch_rodrigues(Rh)
|
||||
Tnew = Th + j0 - torch.einsum('bij,bj->bi', rot, j0)
|
||||
poses[:, :3] = 0
|
||||
res = dict(poses=poses.detach().cpu().numpy(),
|
||||
shapes=shapes.detach().cpu().numpy(),
|
||||
Rh=Rh.detach().cpu().numpy(),
|
||||
Th=Tnew.detach().cpu().numpy()
|
||||
)
|
||||
return res
|
||||
|
||||
def init_params(self, nFrames=1, nShapes=1, nPerson=1, ret_tensor=False, add_scale=False):
|
||||
params = {
|
||||
'poses': np.zeros((nFrames, self.NUM_POSES)),
|
||||
'shapes': np.zeros((nShapes, self.NUM_SHAPES)),
|
||||
'Rh': np.zeros((nFrames, 3)),
|
||||
'Th': np.zeros((nFrames, 3)),
|
||||
}
|
||||
if add_scale:
|
||||
params['scale'] = np.ones((1, 1))
|
||||
if nPerson > 1:
|
||||
for key in params.keys():
|
||||
params[key] = params[key][:, None].repeat(nPerson, axis=1)
|
||||
if ret_tensor:
|
||||
for key in params.keys():
|
||||
params[key] = to_tensor(params[key], self.dtype, self.device)
|
||||
return params
|
||||
|
||||
def check_params(self, body_params):
|
||||
# 预先拷贝一下,不要修改到原始数据了
|
||||
body_params = body_params.copy()
|
||||
poses = body_params['poses']
|
||||
nFrames = poses.shape[0]
|
||||
# convert to torch
|
||||
if 'torch' not in str(type(poses)):
|
||||
dtype, device = self.dtype, self.device
|
||||
for key in ['poses', 'handl', 'handr', 'shapes', 'expression', 'Rh', 'Th', 'scale']:
|
||||
if key not in body_params.keys():
|
||||
continue
|
||||
body_params[key] = to_tensor(body_params[key], dtype, device)
|
||||
poses = body_params['poses']
|
||||
# check Rh and Th
|
||||
for key in ['Rh', 'Th']:
|
||||
if key not in body_params.keys():
|
||||
body_params[key] = torch.zeros((nFrames, 3), dtype=poses.dtype, device=poses.device)
|
||||
# process shapes
|
||||
for key in ['shapes']:
|
||||
if body_params[key].shape[0] < nFrames and len(body_params[key].shape) == 2:
|
||||
body_params[key] = body_params[key].expand(nFrames, -1)
|
||||
elif body_params[key].shape[0] < nFrames and len(body_params[key].shape) == 3:
|
||||
body_params[key] = body_params[key].expand(*body_params['poses'].shape[:2], -1)
|
||||
return body_params
|
||||
|
||||
def __str__(self) -> str:
|
||||
res = '- Model: {}\n'.format(self.model_type)
|
||||
res += ' poses: {}\n'.format(self.NUM_POSES)
|
||||
res += ' shapes: {}\n'.format(self.NUM_SHAPES)
|
||||
res += ' vertices: {}\n'.format(self.v_template.shape)
|
||||
res += ' faces: {}\n'.format(self.faces.shape)
|
||||
res += ' posedirs: {}\n'.format(self.posedirs.shape)
|
||||
res += ' shapedirs: {}\n'.format(self.shapedirs.shape)
|
||||
return res
|
||||
|
||||
def extend_poses(self, poses, **kwargs):
|
||||
if poses.shape[-1] == self.NUM_POSES_FULL:
|
||||
return poses
|
||||
if not self.use_root_rot:
|
||||
if kwargs.get('pose2rot', True):
|
||||
zero_rot = torch.zeros((*poses.shape[:-1], 3), dtype=poses.dtype, device=poses.device)
|
||||
poses = torch.cat([zero_rot, poses], dim=-1)
|
||||
elif poses.shape[-3] != self.NUM_POSES_FULL // 3:
|
||||
# insert a blank rotation
|
||||
zero_rot = torch.zeros((*poses.shape[:-3], 1, 3), dtype=poses.dtype, device=poses.device)
|
||||
zero_rot = batch_rodrigues(zero_rot)
|
||||
poses = torch.cat([zero_rot, poses], dim=-3)
|
||||
return poses
|
||||
|
||||
def jacobian_posesfull_poses(self, poses, poses_full):
|
||||
# TODO: cache this
|
||||
if self.use_root_rot:
|
||||
jacobian = torch.eye(poses.shape[-1], dtype=poses.dtype, device=poses.device)
|
||||
else:
|
||||
zero_root = torch.zeros((3, poses.shape[-1]), dtype=poses.dtype, device=poses.device)
|
||||
eye_right = torch.eye(poses.shape[-1], dtype=poses.dtype, device=poses.device)
|
||||
jacobian = torch.cat([zero_root, eye_right], dim=0)
|
||||
return jacobian
|
||||
|
||||
def export_full_poses(self, poses, **kwargs):
|
||||
if not self.use_root_rot:
|
||||
poses = np.hstack([np.zeros((poses.shape[0], 3)), poses])
|
||||
return poses
|
||||
|
||||
def encode(self, body_params):
|
||||
# This function provide standard SMPL parameters to this model
|
||||
poses = body_params['poses']
|
||||
if 'Rh' not in body_params.keys():
|
||||
body_params['Rh'] = poses[:, :3].copy()
|
||||
if 'Th' not in body_params.keys():
|
||||
if 'trans' in body_params.keys():
|
||||
body_params['Th'] = body_params.pop('trans')
|
||||
else:
|
||||
body_params['Th'] = np.zeros((poses.shape[0], 3), dtype=poses.dtype)
|
||||
if not self.use_root_rot and poses.shape[1] == 72:
|
||||
body_params['poses'] = poses[:, 3:].copy()
|
||||
return body_params
|
||||
|
||||
class SMPLLayerEmbedding(SMPLModel):
|
||||
def __init__(self, vposer_ckpt='data/body_models/vposer_v02', **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
from human_body_prior.tools.model_loader import load_model
|
||||
from human_body_prior.models.vposer_model import VPoser
|
||||
vposer, _ = load_model(vposer_ckpt,
|
||||
model_code=VPoser,
|
||||
remove_words_in_model_weights='vp_model.',
|
||||
disable_grad=True)
|
||||
vposer.eval()
|
||||
vposer.to(self.device)
|
||||
self.vposer = vposer
|
||||
self.vposer_dim = 32
|
||||
self.NUM_POSES = self.vposer_dim
|
||||
|
||||
def encode(self, body_params):
|
||||
# This function provide standard SMPL parameters to this model
|
||||
poses = body_params['poses']
|
||||
if poses.shape[1] == self.vposer_dim:
|
||||
return body_params
|
||||
poses_tensor = torch.Tensor(poses).to(self.device)
|
||||
ret = self.vposer.encode(poses_tensor[:, :63]).mean
|
||||
body_params = super().encode(body_params)
|
||||
body_params['poses'] = ret.detach().cpu().numpy()
|
||||
return body_params
|
||||
|
||||
def extend_poses(self, poses, **kwargs):
|
||||
if poses.shape[-1] == self.vposer_dim:
|
||||
ret = self.vposer.decode(poses)
|
||||
poses_body = ret['pose_body'].reshape(poses.shape[0], -1)
|
||||
elif poses.shape[-1] == self.NUM_POSES_FULL:
|
||||
return poses
|
||||
elif poses.shape[-1] == self.NUM_POSES_FULL - 3:
|
||||
poses_zero = torch.zeros((poses.shape[0], 3), dtype=poses.dtype, device=poses.device)
|
||||
poses = torch.cat([poses_zero, poses], dim=-1)
|
||||
return poses
|
||||
poses_zero = torch.zeros((poses_body.shape[0], 3), dtype=poses_body.dtype, device=poses_body.device)
|
||||
poses = torch.cat([poses_zero, poses_body, poses_zero, poses_zero], dim=1)
|
||||
return poses
|
||||
|
||||
def export_full_poses(self, poses, **kwargs):
|
||||
poses = torch.Tensor(poses).to(self.device)
|
||||
poses = self.extend_poses(poses)
|
||||
return poses.detach().cpu().numpy()
|
||||
|
||||
if __name__ == '__main__':
|
||||
vis = True
|
||||
test_config = {
|
||||
'smpl':{
|
||||
'model_path': 'data/bodymodels/SMPL_python_v.1.1.0/smpl/models/basicmodel_m_lbs_10_207_0_v1.1.0.pkl',
|
||||
'regressor_path': 'data/smplx/J_regressor_body25.npy',
|
||||
},
|
||||
'smplh':{
|
||||
'model_path': 'data/bodymodels/smplhv1.2/male/model.npz',
|
||||
'regressor_path': None,
|
||||
},
|
||||
'mano':{
|
||||
'model_path': 'data/bodymodels/manov1.2/MANO_LEFT.pkl',
|
||||
'regressor_path': None,
|
||||
},
|
||||
'flame':{
|
||||
'model_path': 'data/bodymodels/FLAME2020/FLAME_MALE.pkl',
|
||||
'regressor_path': None,
|
||||
}
|
||||
}
|
||||
for name, cfg in test_config.items():
|
||||
print('Testing {}...'.format(name))
|
||||
model = SMPLModel(**cfg)
|
||||
print(model)
|
||||
params = model.init_params()
|
||||
for key in params.keys():
|
||||
params[key] = (np.random.rand(*params[key].shape) - 0.5)*0.5
|
||||
vertices = model.vertices(params, return_tensor=True)[0]
|
||||
if cfg['regressor_path'] is not None:
|
||||
keypoints = model.keypoints(params, return_tensor=True)[0]
|
||||
print(keypoints.shape)
|
||||
if vis:
|
||||
import open3d as o3d
|
||||
mesh = o3d.geometry.TriangleMesh()
|
||||
mesh.vertices = o3d.utility.Vector3dVector(vertices.reshape(-1, 3))
|
||||
mesh.triangles = o3d.utility.Vector3iVector(model.faces.reshape(-1, 3))
|
||||
mesh.compute_vertex_normals()
|
||||
o3d.visualization.draw_geometries([mesh])
|
259
easymocap/bodymodel/smplx.py
Normal file
259
easymocap/bodymodel/smplx.py
Normal file
@ -0,0 +1,259 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .base import Model
|
||||
from .smpl import SMPLModel, SMPLLayerEmbedding, read_pickle, to_tensor
|
||||
from os.path import join
|
||||
import numpy as np
|
||||
|
||||
def read_hand(path, use_pca, use_flat_mean, num_pca_comps):
|
||||
data = read_pickle(path)
|
||||
mean = data['hands_mean'].reshape(1, -1).astype(np.float32)
|
||||
mean_full = mean
|
||||
components_full = data['hands_components'].astype(np.float32)
|
||||
weight = np.diag(components_full @ components_full.T)
|
||||
components = components_full[:num_pca_comps]
|
||||
weight = weight[:num_pca_comps]
|
||||
if use_flat_mean:
|
||||
mean = np.zeros_like(mean)
|
||||
return mean, components, weight, mean_full, components_full
|
||||
|
||||
class MANO(SMPLModel):
|
||||
def __init__(self, cfg_hand, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.name = 'mano'
|
||||
self.use_root_rot = False
|
||||
mean, components, weight, mean_full, components_full = read_hand(kwargs['model_path'], **cfg_hand)
|
||||
self.register_buffer('mean', to_tensor(mean, dtype=self.dtype))
|
||||
self.register_buffer('components', to_tensor(components, dtype=self.dtype))
|
||||
self.cfg_hand = cfg_hand
|
||||
self.to(self.device)
|
||||
if cfg_hand.use_pca:
|
||||
self.NUM_POSES = cfg_hand.num_pca_comps
|
||||
|
||||
def extend_poses(self, poses, **kwargs):
|
||||
if poses.shape[-1] == self.mean.shape[-1] + 3:
|
||||
return poses
|
||||
if self.cfg_hand.use_pca:
|
||||
poses = poses @ self.components
|
||||
if kwargs.get('pose2rot', True):
|
||||
poses = super().extend_poses(poses+self.mean, **kwargs)
|
||||
else:
|
||||
poses = super().extend_poses(poses, **kwargs)
|
||||
return poses
|
||||
|
||||
def jacobian_posesfull_poses(self, poses, poses_full):
|
||||
if self.cfg_hand.use_pca:
|
||||
jacobian = self.components.t()
|
||||
zero_root = torch.zeros((3, poses.shape[-1]), dtype=poses.dtype, device=poses.device)
|
||||
jacobian = torch.cat([zero_root, jacobian], dim=0)
|
||||
else:
|
||||
jacobian = super().jacobian_posesfull_poses(poses, poses_full)
|
||||
return jacobian
|
||||
class MANOLR(Model):
|
||||
def __init__(self, model_path, regressor_path, cfg_hand, **kwargs):
|
||||
super().__init__()
|
||||
self.name = 'manolr'
|
||||
keys = list(model_path.keys())
|
||||
# stack 方式:(nframes, nhand x ndim)
|
||||
self.keys = keys
|
||||
modules_hand = {}
|
||||
faces = []
|
||||
v_template = []
|
||||
cnt = 0
|
||||
for key in keys:
|
||||
modules_hand[key] = MANO(cfg_hand, model_path=model_path[key], regressor_path=regressor_path[key], **kwargs)
|
||||
v_template.append(modules_hand[key].v_template.cpu().numpy())
|
||||
faces.append(modules_hand[key].faces + cnt)
|
||||
cnt += v_template[-1].shape[0]
|
||||
self.device = modules_hand[key].device
|
||||
self.dtype = modules_hand[key].dtype
|
||||
if key == 'right':
|
||||
modules_hand[key].shapedirs[:, 0] *= -1
|
||||
modules_hand[key].j_shapedirs[:, 0] *= -1
|
||||
self.faces = np.vstack(faces)
|
||||
self.v_template = np.vstack(v_template)
|
||||
self.modules_hand = nn.ModuleDict(modules_hand)
|
||||
self.to(self.device)
|
||||
|
||||
def init_params(self, **kwargs):
|
||||
param_all = {}
|
||||
for key in self.keys:
|
||||
param = self.modules_hand[key].init_params(**kwargs)
|
||||
param_all[key] = param
|
||||
if False:
|
||||
params = {k: torch.cat([param_all[key][k] for key in self.keys], dim=-1) for k in param.keys()}
|
||||
else:
|
||||
params = {k: np.concatenate([param_all[key][k] for key in self.keys], axis=-1) for k in param.keys() if k != 'shapes'}
|
||||
params['shapes'] = param_all['left']['shapes']
|
||||
return params
|
||||
|
||||
def split(self, params):
|
||||
params_split = {}
|
||||
for imodel, model in enumerate(self.keys):
|
||||
param_= params.copy()
|
||||
for key in ['poses', 'shapes', 'Rh', 'Th']:
|
||||
if key not in params.keys():continue
|
||||
if key == 'shapes':
|
||||
continue
|
||||
shape = params[key].shape[-1]
|
||||
start = shape//len(self.keys)*imodel
|
||||
end = shape//len(self.keys)*(imodel+1)
|
||||
param_[key] = params[key][:, start:end]
|
||||
params_split[model] = param_
|
||||
return params_split
|
||||
|
||||
def forward(self, **params):
|
||||
params_split = self.split(params)
|
||||
rets = []
|
||||
for imodel, model in enumerate(self.keys):
|
||||
ret = self.modules_hand[model](**params_split[model])
|
||||
rets.append(ret)
|
||||
if params.get('return_tensor', True):
|
||||
rets = torch.cat(rets, dim=1)
|
||||
else:
|
||||
rets = np.concatenate(rets, axis=1)
|
||||
return rets
|
||||
|
||||
def extend_poses(self, poses, **kwargs):
|
||||
params_split = self.split({'poses': poses})
|
||||
rets = []
|
||||
for imodel, model in enumerate(self.keys):
|
||||
poses = params_split[model]['poses']
|
||||
poses = self.modules_hand[model].extend_poses(poses)
|
||||
rets.append(poses)
|
||||
poses = torch.cat(rets, dim=1)
|
||||
return poses
|
||||
|
||||
def export_full_poses(self, poses, **kwargs):
|
||||
params_split = self.split({'poses': poses})
|
||||
rets = []
|
||||
for imodel, model in enumerate(self.keys):
|
||||
poses = torch.Tensor(params_split[model]['poses']).to(self.device)
|
||||
poses = self.modules_hand[model].extend_poses(poses)
|
||||
rets.append(poses)
|
||||
poses = torch.cat(rets, dim=1)
|
||||
return poses.detach().cpu().numpy()
|
||||
|
||||
class SMPLHModel(SMPLModel):
|
||||
def __init__(self, mano_path, cfg_hand, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.NUM_POSES = self.NUM_POSES - 90
|
||||
meanl, componentsl, weight_l, self.mean_full_l, self.components_full_l = read_hand(join(mano_path, 'MANO_LEFT.pkl'), **cfg_hand)
|
||||
meanr, componentsr, weight_r, self.mean_full_r, self.components_full_r = read_hand(join(mano_path, 'MANO_RIGHT.pkl'), **cfg_hand)
|
||||
self.register_buffer('weight_l', to_tensor(weight_l, dtype=self.dtype))
|
||||
self.register_buffer('weight_r', to_tensor(weight_r, dtype=self.dtype))
|
||||
self.register_buffer('meanl', to_tensor(meanl, dtype=self.dtype))
|
||||
self.register_buffer('meanr', to_tensor(meanr, dtype=self.dtype))
|
||||
self.register_buffer('componentsl', to_tensor(componentsl, dtype=self.dtype))
|
||||
self.register_buffer('componentsr', to_tensor(componentsr, dtype=self.dtype))
|
||||
|
||||
self.register_buffer('jacobian_posesfull_poses_', self._jacobian_posesfull_poses())
|
||||
self.NUM_HANDS = cfg_hand.num_pca_comps if cfg_hand.use_pca else 45
|
||||
self.cfg_hand = cfg_hand
|
||||
self.to(self.device)
|
||||
|
||||
def _jacobian_posesfull_poses(self):
|
||||
# TODO: cache this
|
||||
# | body_full/body | 0 | 0 |
|
||||
# | 0 | l | 0 |
|
||||
# | 0 | 0 | r |
|
||||
eye_right = torch.eye(self.NUM_POSES, dtype=self.dtype)
|
||||
#
|
||||
jac_handl = self.componentsl.t()
|
||||
jac_handr = self.componentsr.t()
|
||||
output = torch.zeros((self.NUM_POSES_FULL, self.NUM_POSES+jac_handl.shape[1]*2), dtype=self.dtype)
|
||||
if self.use_root_rot:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
output[3:3+self.NUM_POSES, :self.NUM_POSES] = eye_right
|
||||
output[3+self.NUM_POSES:3+self.NUM_POSES+jac_handl.shape[0], \
|
||||
self.NUM_POSES:self.NUM_POSES+jac_handl.shape[1]] = jac_handl
|
||||
output[3+self.NUM_POSES+jac_handl.shape[0]:3+self.NUM_POSES+2*jac_handl.shape[0], \
|
||||
self.NUM_POSES+jac_handl.shape[1]:self.NUM_POSES+jac_handl.shape[1]*2] = jac_handr
|
||||
return output
|
||||
|
||||
def init_params(self, nFrames=1, nShapes=1, nPerson=1, ret_tensor=False, add_scale=False):
|
||||
params = super().init_params(nFrames, nShapes, nPerson, ret_tensor, add_scale=add_scale)
|
||||
handl = np.zeros((nFrames, self.NUM_HANDS))
|
||||
handr = np.zeros((nFrames, self.NUM_HANDS))
|
||||
if nPerson > 1:
|
||||
handl = handl[:, None].repeat(nPerson, axis=1)
|
||||
handr = handr[:, None].repeat(nPerson, axis=1)
|
||||
if ret_tensor:
|
||||
handl = to_tensor(handl, self.dtype, self.device)
|
||||
handr = to_tensor(handr, self.dtype, self.device)
|
||||
params['handl'] = handl
|
||||
params['handr'] = handr
|
||||
return params
|
||||
|
||||
def extend_poses(self, poses, handl=None, handr=None, **kwargs):
|
||||
if poses.shape[-1] == self.NUM_POSES_FULL:
|
||||
return poses
|
||||
poses = super().extend_poses(poses)
|
||||
if handl is None:
|
||||
handl = self.meanl.clone()
|
||||
handr = self.meanr.clone()
|
||||
handl = handl.expand(poses.shape[0], -1)
|
||||
handr = handr.expand(poses.shape[0], -1)
|
||||
else:
|
||||
if self.cfg_hand.use_pca:
|
||||
handl = handl @ self.componentsl
|
||||
handr = handr @ self.componentsr
|
||||
handl = handl +self.meanl
|
||||
handr = handr +self.meanr
|
||||
poses = torch.cat([poses, handl, handr], dim=-1)
|
||||
return poses
|
||||
|
||||
def export_full_poses(self, poses, handl, handr, **kwargs):
|
||||
poses = torch.Tensor(poses).to(self.device)
|
||||
handl = torch.Tensor(handl).to(self.device)
|
||||
handr = torch.Tensor(handr).to(self.device)
|
||||
poses = self.extend_poses(poses, handl, handr)
|
||||
return poses.detach().cpu().numpy()
|
||||
|
||||
class SMPLHModelEmbedding(SMPLHModel):
|
||||
def __init__(self, vposer_ckpt='data/body_models/vposer_v02', **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
from human_body_prior.tools.model_loader import load_model
|
||||
from human_body_prior.models.vposer_model import VPoser
|
||||
vposer, _ = load_model(vposer_ckpt,
|
||||
model_code=VPoser,
|
||||
remove_words_in_model_weights='vp_model.',
|
||||
disable_grad=True)
|
||||
vposer.to(self.device)
|
||||
self.vposer = vposer
|
||||
self.vposer_dim = 32
|
||||
self.NUM_POSES = self.vposer_dim
|
||||
|
||||
def decode(self, poses, add_rot=True):
|
||||
if poses.shape[-1] == 66 and add_rot:
|
||||
return poses
|
||||
elif poses.shape[-1] == 63 and not add_rot:
|
||||
return poses
|
||||
assert poses.shape[-1] == self.vposer_dim, poses.shape
|
||||
ret = self.vposer.decode(poses)
|
||||
poses_body = ret['pose_body'].reshape(poses.shape[0], -1)
|
||||
if add_rot:
|
||||
zero_rot = torch.zeros((poses.shape[0], 3), dtype=poses.dtype, device=poses.device)
|
||||
poses_body = torch.cat([zero_rot, poses_body], dim=-1)
|
||||
return poses_body
|
||||
|
||||
def extend_poses(self, poses, handl, handr, **kwargs):
|
||||
if poses.shape[-1] == self.NUM_POSES_FULL:
|
||||
return poses
|
||||
zero_rot = torch.zeros((poses.shape[0], 3), dtype=poses.dtype, device=poses.device)
|
||||
poses_body = self.decode(poses, add_rot=False)
|
||||
if self.cfg_hand.use_pca:
|
||||
handl = handl @ self.componentsl
|
||||
handr = handr @ self.componentsr
|
||||
handl = handl +self.meanl
|
||||
handr = handr +self.meanr
|
||||
poses = torch.cat([zero_rot, poses_body, handl, handr], dim=-1)
|
||||
return poses
|
||||
|
||||
def export_full_poses(self, poses, handl, handr, **kwargs):
|
||||
poses = torch.Tensor(poses).to(self.device)
|
||||
handl = torch.Tensor(handl).to(self.device)
|
||||
handr = torch.Tensor(handr).to(self.device)
|
||||
poses = self.extend_poses(poses, handl, handr)
|
||||
return poses.detach().cpu().numpy()
|
Loading…
Reference in New Issue
Block a user