# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de

import torch
import torch.nn as nn

from .. import config
from .smpl_head import SMPL


class SMPLCamHead(nn.Module):
    def __init__(self, img_res=224):
        super(SMPLCamHead, self).__init__()
        self.smpl = SMPL(config.SMPL_MODEL_DIR, create_transl=False)
        self.add_module('smpl', self.smpl)

        self.img_res = img_res

    def forward(self, rotmat, shape, cam, cam_rotmat, cam_intrinsics,
                bbox_scale, bbox_center, img_w, img_h, normalize_joints2d=False):
        '''
        :param rotmat: rotation matrices (N, J, 3, 3); index 0 is the global orientation
        :param shape: SMPL betas (N, 10)
        :param cam: weak-perspective camera parameters (N, 3) as (s, tx, ty)
        :param cam_rotmat: camera rotation matrix (N, 3, 3)
        :param cam_intrinsics: camera intrinsics matrix (N, 3, 3)
        :param bbox_scale: bbox height normalized by 200 (N,)
        :param bbox_center: bbox center (N, 2)
        :param img_w: original image width (N,)
        :param img_h: original image height (N,)
        :param normalize_joints2d: bool, if True normalize 2D joints to [-1, 1]
        :return: dict with keys 'smpl_vertices', 'smpl_joints3d', 'smpl_joints2d', 'pred_cam_t'
        '''
        smpl_output = self.smpl(
            betas=shape,
            body_pose=rotmat[:, 1:].contiguous(),
            global_orient=rotmat[:, 0].unsqueeze(1).contiguous(),
            pose2rot=False,
        )

        output = {
            'smpl_vertices': smpl_output.vertices,
            'smpl_joints3d': smpl_output.joints,
        }

        joints3d = smpl_output.joints

        cam_t = convert_pare_to_full_img_cam(
            pare_cam=cam,
            bbox_height=bbox_scale * 200.,
            bbox_center=bbox_center,
            img_w=img_w,
            img_h=img_h,
            focal_length=cam_intrinsics[:, 0, 0],
            crop_res=self.img_res,
        )

        joints2d = perspective_projection(
            joints3d,
            rotation=cam_rotmat,
            translation=cam_t,
            cam_intrinsics=cam_intrinsics,
        )

        # logger.debug(f'PARE cam: {cam}')
        # logger.debug(f'FIMG cam: {cam_t}')
        # logger.debug(f'joints2d: {joints2d}')

        if normalize_joints2d:
            # Normalize keypoints to [-1,1]
            joints2d = joints2d / (self.img_res / 2.)

        output['smpl_joints2d'] = joints2d
        output['pred_cam_t'] = cam_t

        return output
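

# Usage sketch (illustrative, not part of the original module): SMPLCamHead
# needs valid SMPL model files at config.SMPL_MODEL_DIR, so the call below is
# shown as a comment rather than executable code. Shapes follow the docstring
# of `forward`; all tensor values are placeholders.
#
#   head = SMPLCamHead(img_res=224)
#   N = 2
#   out = head(
#       rotmat=torch.eye(3).expand(N, 24, 3, 3),            # identity pose, 24 SMPL joints
#       shape=torch.zeros(N, 10),                            # mean body shape
#       cam=torch.tensor([[0.9, 0.0, 0.0]]).expand(N, 3),    # weak-perspective (s, tx, ty)
#       cam_rotmat=torch.eye(3).expand(N, 3, 3),
#       cam_intrinsics=torch.tensor([[[1000., 0., 640.],
#                                     [0., 1000., 360.],
#                                     [0., 0., 1.]]]).expand(N, 3, 3),
#       bbox_scale=torch.ones(N) * 1.5,                      # bbox height / 200
#       bbox_center=torch.tensor([[640., 360.]]).expand(N, 2),
#       img_w=torch.full((N,), 1280.),
#       img_h=torch.full((N,), 720.),
#       normalize_joints2d=False,
#   )
#   # out['smpl_vertices'] -> (N, 6890, 3), out['smpl_joints2d'] -> (N, J, 2)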


def perspective_projection(points, rotation, translation, cam_intrinsics):
    """
    This function computes the perspective projection of a set of points.
    Input:
        points (bs, N, 3): 3D points
        rotation (bs, 3, 3): Camera rotation
        translation (bs, 3): Camera translation
        cam_intrinsics (bs, 3, 3): Camera intrinsics
    """
    K = cam_intrinsics

    # Transform points
    points = torch.einsum('bij,bkj->bki', rotation, points)
    points = points + translation.unsqueeze(1)

    # Apply perspective distortion
    projected_points = points / points[:, :, -1].unsqueeze(-1)

    # Apply camera intrinsics
    projected_points = torch.einsum('bij,bkj->bki', K, projected_points.float())

    return projected_points[:, :, :-1]
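

# Worked example for perspective_projection (hypothetical numbers): with
# identity rotation, translation t = (0, 0, 5), focal length 1000 and
# principal point (640, 360), a point p = (0.5, 0, 0) maps to
#   p_cam  = R @ p + t         = (0.5, 0, 5)
#   p_norm = p_cam / p_cam.z   = (0.1, 0, 1)
#   p_px   = (K @ p_norm)[:2]  = (0.1 * 1000 + 640, 0 * 1000 + 360) = (740, 360)
# The einsum 'bij,bkj->bki' is simply a batched matrix-vector product applied
# to every point k in batch element b.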


def convert_pare_to_full_img_cam(
        pare_cam, bbox_height, bbox_center,
        img_w, img_h, focal_length, crop_res=224):
    # Converts the weak-perspective camera estimated by PARE in bbox (crop)
    # coordinates to a perspective camera translation in full-image
    # coordinates, following https://arxiv.org/pdf/2009.06549.pdf
    s, tx, ty = pare_cam[:, 0], pare_cam[:, 1], pare_cam[:, 2]
    res = crop_res
    # Note: res cancels out below (r * res == bbox_height), so tz reduces to
    # 2 * focal_length / (s * bbox_height).
    r = bbox_height / res
    tz = 2 * focal_length / (r * res * s)

    cx = 2 * (bbox_center[:, 0] - (img_w / 2.)) / (s * bbox_height)
    cy = 2 * (bbox_center[:, 1] - (img_h / 2.)) / (s * bbox_height)

    cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1)

    return cam_t
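

if __name__ == '__main__':
    # Illustrative smoke test (not part of the original module). It only
    # exercises the two camera helpers above with dummy tensors; because of
    # the relative imports at the top of this file it has to be run as a
    # module from within its package.
    N = 2
    pare_cam = torch.tensor([[0.9, 0.05, -0.02]]).expand(N, 3)   # (s, tx, ty)
    bbox_height = torch.full((N,), 300.)                          # bbox_scale * 200
    bbox_center = torch.tensor([[640., 360.]]).expand(N, 2)
    img_w = torch.full((N,), 1280.)
    img_h = torch.full((N,), 720.)
    focal_length = torch.full((N,), 1000.)

    cam_t = convert_pare_to_full_img_cam(
        pare_cam, bbox_height, bbox_center, img_w, img_h, focal_length)

    K = torch.tensor([[[1000., 0., 640.],
                       [0., 1000., 360.],
                       [0., 0., 1.]]]).expand(N, 3, 3)
    points = torch.randn(N, 49, 3)                                # e.g. 49 regressed joints
    joints2d = perspective_projection(
        points,
        rotation=torch.eye(3).expand(N, 3, 3),
        translation=cam_t,
        cam_intrinsics=K,
    )
    print('cam_t:', tuple(cam_t.shape))        # (2, 3)
    print('joints2d:', tuple(joints2d.shape))  # (2, 49, 2)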