EasyMocap/myeasymocap/backbone/pare/head/smpl_cam_head.py

# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import torch
import torch.nn as nn
from .. import config
from .smpl_head import SMPL


class SMPLCamHead(nn.Module):
    def __init__(self, img_res=224):
        super(SMPLCamHead, self).__init__()
        self.smpl = SMPL(config.SMPL_MODEL_DIR, create_transl=False)
        self.add_module('smpl', self.smpl)
        self.img_res = img_res

    def forward(self, rotmat, shape, cam, cam_rotmat, cam_intrinsics,
                bbox_scale, bbox_center, img_w, img_h, normalize_joints2d=False):
        '''
        :param rotmat: body pose as rotation matrices (N, J, 3, 3)
        :param shape: SMPL shape (betas) parameters
        :param cam: weak-perspective camera (N, 3) as [s, tx, ty] in bbox coordinates
        :param cam_rotmat: (N, 3, 3) camera rotation matrix
        :param cam_intrinsics: (N, 3, 3) camera intrinsics matrix
        :param bbox_scale: (N,) bbox height normalized by 200
        :param bbox_center: (N, 2) bbox center
        :param img_w: (N,) original image width
        :param img_h: (N,) original image height
        :param normalize_joints2d: bool, normalize 2D joints to [-1, 1] if True
        :return: dict with keys 'smpl_vertices', 'smpl_joints3d',
            'smpl_joints2d', and 'pred_cam_t'
        '''
        smpl_output = self.smpl(
            betas=shape,
            body_pose=rotmat[:, 1:].contiguous(),
            global_orient=rotmat[:, 0].unsqueeze(1).contiguous(),
            pose2rot=False,
        )

        output = {
            'smpl_vertices': smpl_output.vertices,
            'smpl_joints3d': smpl_output.joints,
        }

        joints3d = smpl_output.joints

        # Convert PARE's weak-perspective crop camera to a translation
        # in full-image camera coordinates
        cam_t = convert_pare_to_full_img_cam(
            pare_cam=cam,
            bbox_height=bbox_scale * 200.,
            bbox_center=bbox_center,
            img_w=img_w,
            img_h=img_h,
            focal_length=cam_intrinsics[:, 0, 0],
            crop_res=self.img_res,
        )

        # Project 3D joints into the full image with the camera intrinsics
        joints2d = perspective_projection(
            joints3d,
            rotation=cam_rotmat,
            translation=cam_t,
            cam_intrinsics=cam_intrinsics,
        )

        # logger.debug(f'PARE cam: {cam}')
        # logger.debug(f'FIMG cam: {cam_t}')
        # logger.debug(f'joints2d: {joints2d}')

        if normalize_joints2d:
            # Normalize keypoints to [-1, 1]
            joints2d = joints2d / (self.img_res / 2.)

        output['smpl_joints2d'] = joints2d
        output['pred_cam_t'] = cam_t

        return output
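

# Illustrative usage sketch, not part of the original file: how a caller might
# drive SMPLCamHead.forward() with batched inputs. Tensor shapes follow the
# docstring above; the focal length, image size, and bbox values are arbitrary
# placeholder assumptions, and constructing SMPLCamHead requires the SMPL model
# files pointed to by config.SMPL_MODEL_DIR (this function is never called here).
def _example_smpl_cam_head_usage():
    batch_size = 2
    head = SMPLCamHead(img_res=224)
    rotmat = torch.eye(3).repeat(batch_size, 24, 1, 1)           # (N, 24, 3, 3) identity pose
    shape = torch.zeros(batch_size, 10)                          # (N, 10) mean body shape
    cam = torch.tensor([[1.0, 0.0, 0.0]]).repeat(batch_size, 1)  # (N, 3) weak-perspective [s, tx, ty]
    cam_rotmat = torch.eye(3).repeat(batch_size, 1, 1)           # (N, 3, 3) world-to-camera rotation
    img_w = torch.full((batch_size,), 1920.)
    img_h = torch.full((batch_size,), 1080.)
    cam_intrinsics = torch.eye(3).repeat(batch_size, 1, 1)
    cam_intrinsics[:, 0, 0] = 5000.                              # fx (placeholder)
    cam_intrinsics[:, 1, 1] = 5000.                              # fy (placeholder)
    cam_intrinsics[:, 0, 2] = img_w / 2.                         # cx at the image center
    cam_intrinsics[:, 1, 2] = img_h / 2.                         # cy at the image center
    bbox_scale = torch.full((batch_size,), 1.2)                  # bbox height / 200
    bbox_center = torch.stack([img_w / 2., img_h / 2.], dim=-1)  # (N, 2)
    output = head(rotmat, shape, cam, cam_rotmat, cam_intrinsics,
                  bbox_scale, bbox_center, img_w, img_h)
    return output['smpl_joints2d']                               # 2D joints in pixel coordinates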


def perspective_projection(points, rotation, translation, cam_intrinsics):
    """
    Computes the perspective projection of a set of 3D points.
    Input:
        points (bs, N, 3): 3D points
        rotation (bs, 3, 3): Camera rotation
        translation (bs, 3): Camera translation
        cam_intrinsics (bs, 3, 3): Camera intrinsics
    """
    K = cam_intrinsics
    # Transform points into camera coordinates: X_cam = R @ X + t
    points = torch.einsum('bij,bkj->bki', rotation, points)
    points = points + translation.unsqueeze(1)
    # Apply perspective division (divide by depth)
    projected_points = points / points[:, :, -1].unsqueeze(-1)
    # Apply camera intrinsics and drop the homogeneous coordinate
    projected_points = torch.einsum('bij,bkj->bki', K, projected_points.float())
    return projected_points[:, :, :-1]
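

# Quick sanity check, illustrative and not part of the original file: with an
# identity rotation, a purely forward translation, and a simple pinhole
# intrinsics matrix (arbitrary placeholder values), a point on the optical
# axis should project exactly to the principal point.
def _example_perspective_projection():
    K = torch.tensor([[[1000., 0., 960.],
                       [0., 1000., 540.],
                       [0., 0., 1.]]])            # (1, 3, 3) pinhole intrinsics
    rotation = torch.eye(3).unsqueeze(0)          # (1, 3, 3) identity rotation
    translation = torch.tensor([[0., 0., 5.]])    # (1, 3) camera 5 units in front of the point
    points = torch.zeros(1, 1, 3)                 # (1, 1, 3) single point at the origin
    joints2d = perspective_projection(points, rotation, translation, K)
    # joints2d has shape (1, 1, 2) and equals the principal point (960, 540)
    return joints2d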


def convert_pare_to_full_img_cam(
        pare_cam, bbox_height, bbox_center,
        img_w, img_h, focal_length, crop_res=224):
    # Converts the weak-perspective camera estimated by PARE in bbox
    # coordinates to a perspective-camera translation in full-image
    # coordinates, following https://arxiv.org/pdf/2009.06549.pdf
    s, tx, ty = pare_cam[:, 0], pare_cam[:, 1], pare_cam[:, 2]
    res = crop_res  # resolution of the network input crop (224 for PARE)
    r = bbox_height / res
    tz = 2 * focal_length / (r * res * s)
    cx = 2 * (bbox_center[:, 0] - (img_w / 2.)) / (s * bbox_height)
    cy = 2 * (bbox_center[:, 1] - (img_h / 2.)) / (s * bbox_height)
    cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1)
    return cam_t
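

# Worked example, illustrative and not part of the original file: under the
# weak-perspective model the crop resolution cancels out (r * res = bbox_height),
# so the recovered depth is tz = 2 * focal_length / (s * bbox_height). The
# numbers below are arbitrary placeholder assumptions.
def _example_convert_pare_to_full_img_cam():
    pare_cam = torch.tensor([[0.9, 0.05, -0.02]])   # (1, 3) [s, tx, ty] predicted in bbox coords
    bbox_height = torch.tensor([240.])              # bbox_scale * 200
    bbox_center = torch.tensor([[980., 560.]])      # bbox center in pixels
    img_w = torch.tensor([1920.])
    img_h = torch.tensor([1080.])
    focal_length = torch.tensor([1000.])
    cam_t = convert_pare_to_full_img_cam(
        pare_cam, bbox_height, bbox_center, img_w, img_h, focal_length)
    # tz = 2 * 1000 / (0.9 * 240) ≈ 9.26; tx and ty are shifted by the
    # normalized offset of the bbox center from the image center
    return cam_t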