# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de

import torch
|
|
import numpy as np
|
|
import torch.nn as nn
|
|
|
|
from smplx import SMPL as _SMPL
|
|
from smplx.utils import SMPLOutput
|
|
from smplx.lbs import vertices2joints
|
|
|
|
from .. import config, constants
|
|
from ..utils.geometry import perspective_projection, convert_weak_perspective_to_perspective
|
|
|
|
|
|
class SMPL(_SMPL):
    """Extension of the official SMPL implementation to support more joints.

    Besides the joints produced by ``smplx.SMPL``, additional joints are
    regressed from the mesh vertices with ``J_regressor_extra``. The combined
    set ``[SMPL joints, extra joints]`` is then re-ordered with ``joint_map``
    so the output follows ``constants.JOINT_NAMES``.
    """

    def __init__(self, *args, **kwargs):
        """Build the base SMPL model, then load the extra joint regressor.

        All arguments are forwarded unchanged to ``smplx.SMPL``.
        """
        super().__init__(*args, **kwargs)

        # Position of every requested joint (in constants.JOINT_NAMES order)
        # inside the concatenation [SMPL joints, extra joints].
        joint_indices = [constants.JOINT_MAP[name] for name in constants.JOINT_NAMES]

        # Registered as a buffer so it is saved with, and moved alongside,
        # the module.
        extra_regressor = np.load(config.JOINT_REGRESSOR_TRAIN_EXTRA)
        self.register_buffer(
            'J_regressor_extra',
            torch.tensor(extra_regressor, dtype=torch.float32),
        )

        # NOTE: plain attribute, not a buffer -- it is created on the CPU and
        # is not part of the state dict.
        self.joint_map = torch.tensor(joint_indices, dtype=torch.long)

    def forward(self, *args, **kwargs):
        """Run SMPL and return an ``SMPLOutput`` carrying the extended joints.

        Every argument is forwarded to ``smplx.SMPL.forward``, except that
        ``get_skin`` is always forced to ``True``. The vertices, global
        orientation, body pose, betas and full pose are passed through
        unchanged. ``joints`` is replaced by the base joints concatenated
        with the regressed extra joints, indexed by ``self.joint_map``.
        """
        kwargs['get_skin'] = True
        base = super().forward(*args, **kwargs)

        verts = base.vertices
        extra = vertices2joints(self.J_regressor_extra, verts)
        combined = torch.cat([base.joints, extra], dim=1)

        return SMPLOutput(
            vertices=verts,
            global_orient=base.global_orient,
            body_pose=base.body_pose,
            joints=combined[:, self.joint_map, :],
            betas=base.betas,
            full_pose=base.full_pose,
        )
|
|
|
|
|
|
class SMPLHead(nn.Module):
    """Head that turns predicted SMPL rotations and betas into a mesh.

    It runs the extended ``SMPL`` model. When a camera is supplied, it also
    converts that weak-perspective camera to a translation and projects the
    3D joints with ``perspective_projection``.

    :param focal_length: focal length passed to the weak-to-perspective
        conversion and to the projection.
    :param img_res: image resolution passed to the weak-to-perspective
        conversion. It is also used to normalise the 2D joints when
        requested.
    """

    def __init__(self, focal_length=5000., img_res=224):
        super(SMPLHead, self).__init__()
        # Assigning an nn.Module attribute already registers it as the
        # submodule 'smpl', so no explicit add_module call is needed.
        self.smpl = SMPL(config.SMPL_MODEL_DIR, create_transl=False)
        self.focal_length = focal_length
        self.img_res = img_res

    def forward(self, rotmat, shape, cam=None, normalize_joints2d=False):
        """Run SMPL on predicted rotations and betas; optionally project joints.

        :param rotmat: per-joint rotation matrices, shape (N, J, 3, 3) -- NOT
            Euler angles. ``rotmat[:, 0]`` is used as the global orientation
            and ``rotmat[:, 1:]`` as the body pose. Both are passed to SMPL
            with ``pose2rot=False``.
        :param shape: SMPL betas, presumably (N, num_betas) -- TODO confirm.
        :param cam: weak-perspective camera, or ``None`` to skip the 2D
            projection. It is converted by
            ``convert_weak_perspective_to_perspective``. Its layout is
            presumably (N, 3) ``[s, tx, ty]`` -- confirm against that helper.
        :param normalize_joints2d: if True, divide the projected 2D joints by
            ``img_res / 2``.
        :return: dict with keys ``'smpl_vertices'`` and ``'smpl_joints3d'``.
            When ``cam`` is not None, it also contains ``'smpl_joints2d'``
            (projected joints) and ``'pred_cam_t'`` (the camera translation
            used for the projection).
        """
        smpl_output = self.smpl(
            betas=shape,
            body_pose=rotmat[:, 1:].contiguous(),
            global_orient=rotmat[:, 0].unsqueeze(1).contiguous(),
            pose2rot=False,
        )

        output = {
            'smpl_vertices': smpl_output.vertices,
            'smpl_joints3d': smpl_output.joints,
        }

        if cam is not None:
            joints3d = smpl_output.joints
            batch_size = joints3d.shape[0]
            device = joints3d.device

            cam_t = convert_weak_perspective_to_perspective(
                cam,
                focal_length=self.focal_length,
                img_res=self.img_res,
            )

            # Identity rotation and a zero camera centre: the projected joints
            # are presumably expressed relative to the image centre (depends
            # on perspective_projection -- confirm).
            joints2d = perspective_projection(
                joints3d,
                rotation=torch.eye(3, device=device).unsqueeze(0).expand(batch_size, -1, -1),
                translation=cam_t,
                focal_length=self.focal_length,
                camera_center=torch.zeros(batch_size, 2, device=device),
            )

            if normalize_joints2d:
                # Divide by half the image size. If the points are centred
                # pixel coordinates (see note above), in-image joints land in
                # [-1, 1]; points outside the image fall outside that range.
                joints2d = joints2d / (self.img_res / 2.)

            output['smpl_joints2d'] = joints2d
            output['pred_cam_t'] = cam_t

        return output
|