97 lines
3.7 KiB
Python
97 lines
3.7 KiB
Python
|
'''
|
||
|
@ Date: 2022-07-28 14:39:23
|
||
|
@ Author: Qing Shuai
|
||
|
@ Mail: s_q@zju.edu.cn
|
||
|
@ LastEditors: Qing Shuai
|
||
|
@ LastEditTime: 2022-08-12 21:42:12
|
||
|
@ FilePath: /EasyMocapPublic/easymocap/multistage/totalfitting.py
|
||
|
'''
|
||
|
import torch
|
||
|
|
||
|
from ..bodymodel.lbs import batch_rodrigues
|
||
|
from .torchgeometry import rotation_matrix_to_axis_angle, rotation_matrix_to_quaternion, quaternion_to_rotation_matrix, quaternion_to_axis_angle
|
||
|
import numpy as np
|
||
|
from .base_ops import BeforeAfterBase
|
||
|
|
||
|
def compute_twist_rotation(rotation_matrix, twist_axis):
    '''
    Compute the twist component of a rotation about a given twist axis
    (the "twist" half of a swing-twist decomposition).
    https://stackoverflow.com/questions/3684269/component-of-a-quaternion-rotation-around-an-axis

    Parameters
    ----------
    rotation_matrix : Tensor (B, 3, 3,)
        The rotation to convert
    twist_axis : Tensor (B, 3,)
        The twist axis. It is normalised inside this function, so it does
        not have to be unit length.

    Returns
    -------
    twist_rotation : Tensor (B, 3, 3)
        The twist rotation
    twist_angle : Tensor (B, 1)
        The signed rotation angle of the twist about the normalised axis
    '''
    # Component 0 of the quaternion is used as the scalar part and
    # components 1: as the vector part.
    quaternion = rotation_matrix_to_quaternion(rotation_matrix)

    # The epsilon keeps the division finite for a zero-length axis.
    twist_axis = twist_axis / (torch.norm(twist_axis, dim=1, keepdim=True) + 1e-9)

    # Project the vector part of the quaternion onto the twist axis.
    projection = torch.einsum('bi,bi->b', twist_axis, quaternion[:, 1:]).unsqueeze(-1) * twist_axis

    # Scalar part + projected vector part, re-normalised, is the twist quaternion.
    twist_quaternion = torch.cat([quaternion[:, 0:1], projection], dim=1)
    twist_quaternion = twist_quaternion / (torch.norm(twist_quaternion, dim=1, keepdim=True) + 1e-9)

    twist_rotation = quaternion_to_rotation_matrix(twist_quaternion)

    twist_aa = quaternion_to_axis_angle(twist_quaternion)

    # twist_aa is parallel to the unit twist axis, so the signed angle is their
    # dot product. Dividing sum(twist_aa) by sum(twist_axis) instead yields
    # inf/NaN whenever the axis components sum to zero, e.g. (1, -1, 0).
    twist_angle = torch.sum(twist_aa * twist_axis, dim=1, keepdim=True)

    return twist_rotation, twist_angle
|
||
|
|
||
|
class ClearTwist(BeforeAfterBase):
    '''Zero selected elbow and wrist entries of the pose vector.'''

    def start(self, body_params):
        '''
        Reset elbow x/z components and the full wrist rotations in
        ``body_params['poses']`` (modified in place).

        Each joint occupies three consecutive columns of ``poses``; joint
        numbers are shifted by one before indexing.
        # NOTE(review): the -1 shift presumably skips a root joint that is
        # stored outside ``poses`` - confirm against the body model.

        Returns the same ``body_params`` object.
        '''
        poses = body_params['poses']

        # Elbows (joints 18 and 19): keep only the middle component.
        for joint in (18 - 1, 19 - 1):
            offset = 3 * joint
            # x
            poses[:, offset] = 0.
            # z
            poses[:, offset + 2] = 0.

        # Wrists (joints 20 and 21): clear all three components.
        for joint in (20 - 1, 21 - 1):
            offset = 3 * joint
            poses[:, offset:offset + 3] = 0.

        return body_params
|
||
|
|
||
|
class SolveTwist(BeforeAfterBase):
    '''Write wrist poses derived from the hand rotations into the pose vector.'''

    # (wrist joint index, key of the hand rotation in body_params)
    _WRISTS = (
        (20, 'R_handl3d'),
        (21, 'R_handr3d'),
    )

    def __init__(self, body_model=None) -> None:
        # body_model must provide ``transform(body_params)``; it is only
        # used in ``final``.
        self.body_model = body_model

    def final(self, body_params):
        '''
        Fill the left/right wrist entries of ``body_params['poses']``.

        For each wrist, the rotation stored under ``R_handl3d`` /
        ``R_handr3d`` (converted with ``batch_rodrigues``) is applied on top
        of the wrist's current rotation, the result is expressed relative to
        that current rotation and stored as axis-angle in ``poses``.

        ``body_params`` is modified in place and also returned.
        '''
        T_joints, _ = self.body_model.transform(body_params)
        # transform() does not take the global RT into account, so the
        # rotation from 'Rh' is applied here.
        R = batch_rodrigues(body_params['Rh'])

        for idx_wrist, key_rotation in self._WRISTS:
            R_wrist_add = batch_rodrigues(body_params[key_rotation])
            # Current wrist rotation with the 'Rh' rotation applied.
            pred_parent_wrist = R @ T_joints[..., idx_wrist, :3, :3]
            pred_global_wrist = torch.bmm(R_wrist_add, pred_parent_wrist)
            # Express the target rotation relative to the current wrist rotation.
            pred_local_wrist = torch.bmm(pred_parent_wrist.transpose(-1, -2), pred_global_wrist)
            axis = rotation_matrix_to_axis_angle(pred_local_wrist)
            # Joint i occupies columns 3*(i-1) .. 3*i of ``poses``.
            body_params['poses'][..., 3*idx_wrist-3:3*idx_wrist] = axis
        return body_params
|