# EasyMocap/myeasymocap/backbone/pare/head/pare_head.py
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from ..config import SMPL_MEAN_PARAMS
from ..layers.coattention import CoAttention
from ..utils.geometry import rot6d_to_rotmat, get_coord_maps
from ..utils.kp_utils import get_smpl_neighbor_triplets
from ..layers.softargmax import softargmax2d, get_heatmap_preds
from ..layers import LocallyConnected2d, KeypointAttention, interpolate
from ..layers.non_local import dot_product
from ..backbone.resnet import conv3x3, conv1x1, BasicBlock

class logger:
    """No-op stand-in for the logger used by the original PARE code.

    warning() and debug() are included because this file calls them.
    """
    @staticmethod
    def info(*args, **kwargs):
        pass

    @staticmethod
    def warning(*args, **kwargs):
        pass

    @staticmethod
    def debug(*args, **kwargs):
        pass

BN_MOMENTUM = 0.1


class PareHead(nn.Module):
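    """PARE head: regresses SMPL parameters with per-part attention.

    Sketch of the data flow as implemented below: backbone features feed two
    branches (a 2D part/keypoint branch and a 3D SMPL branch); soft part
    attention maps from the 2D branch pool per-joint features from the 3D
    branch, and per-joint MLPs decode them into SMPL pose (rot6d per joint),
    shape (betas) and a weak-perspective camera.
    """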
def __init__(
self,
num_joints,
num_input_features,
softmax_temp=1.0,
num_deconv_layers=3,
num_deconv_filters=(256, 256, 256),
num_deconv_kernels=(4, 4, 4),
num_camera_params=3,
num_features_smpl=64,
final_conv_kernel=1,
iterative_regression=False,
iter_residual=False,
num_iterations=3,
shape_input_type='feats', # 'feats.pose.shape.cam'
pose_input_type='feats', # 'feats.neighbor_pose_feats.all_pose.self_pose.neighbor_pose.shape.cam'
pose_mlp_num_layers=1,
shape_mlp_num_layers=1,
pose_mlp_hidden_size=256,
shape_mlp_hidden_size=256,
use_keypoint_features_for_smpl_regression=False,
use_heatmaps='',
use_keypoint_attention=False,
use_postconv_keypoint_attention=False,
keypoint_attention_act='softmax',
use_scale_keypoint_attention=False,
use_branch_nonlocal=None, # 'concatenation', 'dot_product', 'embedded_gaussian', 'gaussian'
use_final_nonlocal=None, # 'concatenation', 'dot_product', 'embedded_gaussian', 'gaussian'
backbone='resnet',
use_hmr_regression=False,
use_coattention=False,
num_coattention_iter=1,
coattention_conv='simple', # 'double_1', 'double_3', 'single_1', 'single_3', 'simple'
use_upsampling=False,
use_soft_attention=False, # Stefan & Otmar 3DV style attention
num_branch_iteration=0,
branch_deeper=False,
use_resnet_conv_hrnet=False,
use_position_encodings=None,
use_mean_camshape=False,
use_mean_pose=False,
init_xavier=False,
):
super(PareHead, self).__init__()
self.backbone = backbone
self.num_joints = num_joints
self.deconv_with_bias = False
self.use_heatmaps = use_heatmaps
self.num_iterations = num_iterations
self.use_final_nonlocal = use_final_nonlocal
self.use_branch_nonlocal = use_branch_nonlocal
self.use_hmr_regression = use_hmr_regression
self.use_coattention = use_coattention
self.num_coattention_iter = num_coattention_iter
self.coattention_conv = coattention_conv
self.use_soft_attention = use_soft_attention
self.num_branch_iteration = num_branch_iteration
self.iter_residual = iter_residual
self.iterative_regression = iterative_regression
self.pose_mlp_num_layers = pose_mlp_num_layers
self.shape_mlp_num_layers = shape_mlp_num_layers
self.pose_mlp_hidden_size = pose_mlp_hidden_size
self.shape_mlp_hidden_size = shape_mlp_hidden_size
self.use_keypoint_attention = use_keypoint_attention
self.use_keypoint_features_for_smpl_regression = use_keypoint_features_for_smpl_regression
self.use_position_encodings = use_position_encodings
self.use_mean_camshape = use_mean_camshape
self.use_mean_pose = use_mean_pose
self.num_input_features = num_input_features
if use_soft_attention:
# These options should be True by default when soft attention is used
self.use_keypoint_features_for_smpl_regression = True
self.use_hmr_regression = True
self.use_coattention = False
logger.warning('Coattention cannot be used together with soft attention')
logger.warning('Overriding use_coattention=False')
        if use_coattention:
            self.use_keypoint_features_for_smpl_regression = False
            logger.warning('"use_keypoint_features_for_smpl_regression" cannot be used together with co-attention')
            logger.warning('Overriding "use_keypoint_features_for_smpl_regression"=False')
        if use_hmr_regression:
            self.iterative_regression = False
            logger.warning('iterative_regression cannot be used together with HMR regression')
        if self.use_heatmaps in ['part_segm', 'attention']:
            logger.info('"Keypoint Attention" should be activated to be able to use part segmentation')
            logger.info('Overriding use_keypoint_attention')
            self.use_keypoint_attention = True
        assert num_iterations > 0, '"num_iterations" should be greater than 0.'
if use_position_encodings:
assert backbone.startswith('hrnet'), 'backbone should be hrnet to use position encodings'
# self.pos_enc = get_coord_maps(size=56)
self.register_buffer('pos_enc', get_coord_maps(size=56))
num_input_features += 2
self.num_input_features = num_input_features
if backbone.startswith('hrnet'):
if use_resnet_conv_hrnet:
logger.info('Using resnet block for keypoint and smpl conv layers...')
self.keypoint_deconv_layers = self._make_res_conv_layers(
input_channels=self.num_input_features,
num_channels=num_deconv_filters[-1],
num_basic_blocks=num_deconv_layers,
)
self.num_input_features = num_input_features
self.smpl_deconv_layers = self._make_res_conv_layers(
input_channels=self.num_input_features,
num_channels=num_deconv_filters[-1],
num_basic_blocks=num_deconv_layers,
)
else:
self.keypoint_deconv_layers = self._make_conv_layer(
num_deconv_layers,
num_deconv_filters,
(3,)*num_deconv_layers,
)
self.num_input_features = num_input_features
self.smpl_deconv_layers = self._make_conv_layer(
num_deconv_layers,
num_deconv_filters,
(3,)*num_deconv_layers,
)
else:
# part branch that estimates 2d keypoints
conv_fn = self._make_upsample_layer if use_upsampling else self._make_deconv_layer
if use_upsampling:
logger.info('Upsampling is active to increase spatial dimension')
logger.info(f'Upsampling conv kernels: {num_deconv_kernels}')
self.keypoint_deconv_layers = conv_fn(
num_deconv_layers,
num_deconv_filters,
num_deconv_kernels,
)
            # reset num_input_features (the deconv builder mutates it) back to
            # the backbone output size, e.g. 2048 for the final resnet layer
            self.num_input_features = num_input_features
self.smpl_deconv_layers = conv_fn(
num_deconv_layers,
num_deconv_filters,
num_deconv_kernels,
)
pose_mlp_inp_dim = num_deconv_filters[-1]
smpl_final_dim = num_features_smpl
shape_mlp_inp_dim = num_joints * smpl_final_dim
if self.use_soft_attention:
logger.info('Soft attention (Stefan & Otmar 3DV) is active')
self.keypoint_final_layer = nn.Sequential(
conv3x3(num_deconv_filters[-1], 256),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
conv1x1(256, num_joints+1 if self.use_heatmaps in ('part_segm', 'part_segm_pool') else num_joints),
)
soft_att_feature_size = smpl_final_dim # if use_hmr_regression else pose_mlp_inp_dim
self.smpl_final_layer = nn.Sequential(
conv3x3(num_deconv_filters[-1], 256),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
conv1x1(256, soft_att_feature_size),
)
# pose_mlp_inp_dim = soft_att_feature_size
else:
self.keypoint_final_layer = nn.Conv2d(
in_channels=num_deconv_filters[-1],
out_channels=num_joints+1 if self.use_heatmaps in ('part_segm', 'part_segm_pool') else num_joints,
kernel_size=final_conv_kernel,
stride=1,
padding=1 if final_conv_kernel == 3 else 0,
)
self.smpl_final_layer = nn.Conv2d(
in_channels=num_deconv_filters[-1],
out_channels=smpl_final_dim,
kernel_size=final_conv_kernel,
stride=1,
padding=1 if final_conv_kernel == 3 else 0,
)
# temperature for softargmax function
self.register_buffer('temperature', torch.tensor(softmax_temp))
# if self.iterative_regression or self.num_branch_iteration > 0 or self.use_coattention:
mean_params = np.load(SMPL_MEAN_PARAMS)
init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
init_shape = torch.from_numpy(mean_params['shape'][:].astype('float32')).unsqueeze(0)
init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0)
self.register_buffer('init_pose', init_pose)
self.register_buffer('init_shape', init_shape)
self.register_buffer('init_cam', init_cam)
if self.iterative_regression:
# enable iterative regression similar to HMR
# these are the features that can be used as input to final MLPs
input_type_dim = {
'feats': 0, # image features for self
'neighbor_pose_feats': 2 * 256, # image features from neighbor joints
'all_pose': 24 * 6, # rot6d of all joints from previous iter
'self_pose': 6, # rot6d of self
'neighbor_pose': 2 * 6, # rot6d of neighbor joints from previous iter
'shape': 10, # smpl betas/shape
'cam': num_camera_params, # weak perspective camera
}
            assert 'feats' in shape_input_type, '"feats" must be included in shape_input_type'
            assert 'feats' in pose_input_type, '"feats" must be included in pose_input_type'
self.shape_input_type = shape_input_type.split('.')
self.pose_input_type = pose_input_type.split('.')
pose_mlp_inp_dim = pose_mlp_inp_dim + sum([input_type_dim[x] for x in self.pose_input_type])
shape_mlp_inp_dim = shape_mlp_inp_dim + sum([input_type_dim[x] for x in self.shape_input_type])
            logger.debug(f'Shape MLP takes "{self.shape_input_type}" as input, '
                         f'input dim: {shape_mlp_inp_dim}')
            logger.debug(f'Pose MLP takes "{self.pose_input_type}" as input, '
                         f'input dim: {pose_mlp_inp_dim}')
self.pose_mlp_inp_dim = pose_mlp_inp_dim
self.shape_mlp_inp_dim = shape_mlp_inp_dim
if self.use_hmr_regression:
            logger.info('HMR regression is active...')
# enable iterative regression similar to HMR
self.fc1 = nn.Linear(num_joints * smpl_final_dim + (num_joints * 6) + 10 + num_camera_params, 1024)
self.drop1 = nn.Dropout()
self.fc2 = nn.Linear(1024, 1024)
self.drop2 = nn.Dropout()
self.decpose = nn.Linear(1024, (num_joints * 6))
self.decshape = nn.Linear(1024, 10)
self.deccam = nn.Linear(1024, num_camera_params)
nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)
else:
# here we use 2 different MLPs to estimate shape and camera
# They take a channelwise downsampled version of smpl features
self.shape_mlp = self._get_shape_mlp(output_size=10)
self.cam_mlp = self._get_shape_mlp(output_size=num_camera_params)
# for pose each joint has a separate MLP
# weights for these MLPs are not shared
# hence we use Locally Connected layers
# TODO support kernel_size > 1 to access context of other joints
self.pose_mlp = self._get_pose_mlp(num_joints=num_joints, output_size=6)
if init_xavier:
nn.init.xavier_uniform_(self.shape_mlp.weight, gain=0.01)
nn.init.xavier_uniform_(self.cam_mlp.weight, gain=0.01)
nn.init.xavier_uniform_(self.pose_mlp.weight, gain=0.01)
if self.use_branch_nonlocal:
logger.info(f'Branch nonlocal is active, type {self.use_branch_nonlocal}')
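            # NOTE: eval() resolves the non-local module by name; only
            # `dot_product` is imported at the top of this file, so the other
            # documented options ('concatenation', 'embedded_gaussian',
            # 'gaussian') would need their own imports to work here. The same
            # applies to the final non-local blocks below.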
self.branch_2d_nonlocal = eval(self.use_branch_nonlocal).NONLocalBlock2D(
in_channels=num_deconv_filters[-1],
sub_sample=False,
bn_layer=True,
)
self.branch_3d_nonlocal = eval(self.use_branch_nonlocal).NONLocalBlock2D(
in_channels=num_deconv_filters[-1],
sub_sample=False,
bn_layer=True,
)
if self.use_final_nonlocal:
logger.info(f'Final nonlocal is active, type {self.use_final_nonlocal}')
self.final_pose_nonlocal = eval(self.use_final_nonlocal).NONLocalBlock1D(
in_channels=self.pose_mlp_inp_dim,
sub_sample=False,
bn_layer=True,
)
self.final_shape_nonlocal = eval(self.use_final_nonlocal).NONLocalBlock1D(
in_channels=num_features_smpl,
sub_sample=False,
bn_layer=True,
)
if self.use_keypoint_attention:
logger.info('Keypoint attention is active')
self.keypoint_attention = KeypointAttention(
use_conv=use_postconv_keypoint_attention,
in_channels=(self.pose_mlp_inp_dim, smpl_final_dim),
out_channels=(self.pose_mlp_inp_dim, smpl_final_dim),
act=keypoint_attention_act,
use_scale=use_scale_keypoint_attention,
)
if self.use_coattention:
logger.info(f'Coattention is active, final conv type {self.coattention_conv}')
self.coattention = CoAttention(n_channel=num_deconv_filters[-1], final_conv=self.coattention_conv)
        if self.num_branch_iteration > 0:
            logger.info('Branch iteration is active')
if branch_deeper:
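                # NOTE: conv3x3 below outputs 256 channels while the non-local
                # block expects num_deconv_filters[-1]; these only match with
                # the default num_deconv_filters of (256, 256, 256).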
self.branch_iter_2d_nonlocal = nn.Sequential(
conv3x3(num_deconv_filters[-1], 256),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
dot_product.NONLocalBlock2D(
in_channels=num_deconv_filters[-1],
sub_sample=False,
bn_layer=True,
)
)
self.branch_iter_3d_nonlocal = nn.Sequential(
conv3x3(num_deconv_filters[-1], 256),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
dot_product.NONLocalBlock2D(
in_channels=num_deconv_filters[-1],
sub_sample=False,
bn_layer=True,
)
)
else:
self.branch_iter_2d_nonlocal = dot_product.NONLocalBlock2D(
in_channels=num_deconv_filters[-1],
sub_sample=False,
bn_layer=True,
)
self.branch_iter_3d_nonlocal = dot_product.NONLocalBlock2D(
in_channels=num_deconv_filters[-1],
sub_sample=False,
bn_layer=True,
)
def _get_shape_mlp(self, output_size):
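        """Build the MLP that maps flattened SMPL features to `output_size`.

        A single layer yields a plain nn.Linear; with more layers, hidden
        Linear layers of size shape_mlp_hidden_size are stacked (note that no
        activation functions are inserted between them).
        """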
if self.shape_mlp_num_layers == 1:
return nn.Linear(self.shape_mlp_inp_dim, output_size)
module_list = []
for i in range(self.shape_mlp_num_layers):
if i == 0:
module_list.append(
nn.Linear(self.shape_mlp_inp_dim, self.shape_mlp_hidden_size)
)
elif i == self.shape_mlp_num_layers - 1:
module_list.append(
nn.Linear(self.shape_mlp_hidden_size, output_size)
)
else:
module_list.append(
nn.Linear(self.shape_mlp_hidden_size, self.shape_mlp_hidden_size)
)
return nn.Sequential(*module_list)
def _get_pose_mlp(self, num_joints, output_size):
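        """Build the per-joint pose MLP from LocallyConnected2d layers.

        Each joint has its own weights (kernel_size=1 over a [num_joints, 1]
        grid), so the pose of every joint is decoded independently from its
        own feature column.
        """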
if self.pose_mlp_num_layers == 1:
return LocallyConnected2d(
in_channels=self.pose_mlp_inp_dim,
out_channels=output_size,
output_size=[num_joints, 1],
kernel_size=1,
stride=1,
)
module_list = []
for i in range(self.pose_mlp_num_layers):
if i == 0:
module_list.append(
LocallyConnected2d(
in_channels=self.pose_mlp_inp_dim,
out_channels=self.pose_mlp_hidden_size,
output_size=[num_joints, 1],
kernel_size=1,
stride=1,
)
)
elif i == self.pose_mlp_num_layers - 1:
module_list.append(
LocallyConnected2d(
in_channels=self.pose_mlp_hidden_size,
out_channels=output_size,
output_size=[num_joints, 1],
kernel_size=1,
stride=1,
)
)
else:
module_list.append(
LocallyConnected2d(
in_channels=self.pose_mlp_hidden_size,
out_channels=self.pose_mlp_hidden_size,
output_size=[num_joints, 1],
kernel_size=1,
stride=1,
)
)
return nn.Sequential(*module_list)
def _get_deconv_cfg(self, deconv_kernel):
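        """Map a deconv kernel size (2, 3 or 4) to (kernel, padding, output_padding)."""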
if deconv_kernel == 4:
padding = 1
output_padding = 0
elif deconv_kernel == 3:
padding = 1
output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(f'Unsupported deconv kernel size: {deconv_kernel}')
        return deconv_kernel, padding, output_padding
def _make_conv_layer(self, num_layers, num_filters, num_kernels):
        assert num_layers == len(num_filters), \
            'ERROR: num_conv_layers is different from len(num_conv_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_conv_layers is different from len(num_conv_kernels)'
layers = []
for i in range(num_layers):
kernel, padding, output_padding = \
self._get_deconv_cfg(num_kernels[i])
planes = num_filters[i]
layers.append(
nn.Conv2d(
in_channels=self.num_input_features,
out_channels=planes,
kernel_size=kernel,
stride=1,
padding=padding,
bias=self.deconv_with_bias))
layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
layers.append(nn.ReLU(inplace=True))
self.num_input_features = planes
return nn.Sequential(*layers)
def _make_res_conv_layers(self, input_channels, num_channels=64,
num_heads=1, num_basic_blocks=2):
head_layers = []
# kernel_sizes, strides, paddings = self._get_trans_cfg()
# for kernel_size, padding, stride in zip(kernel_sizes, paddings, strides):
head_layers.append(nn.Sequential(
nn.Conv2d(
in_channels=input_channels,
out_channels=num_channels,
kernel_size=3,
stride=1,
padding=1
),
nn.BatchNorm2d(num_channels, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True))
)
for i in range(num_heads):
layers = []
for _ in range(num_basic_blocks):
layers.append(nn.Sequential(BasicBlock(num_channels, num_channels)))
head_layers.append(nn.Sequential(*layers))
# head_layers.append(nn.Conv2d(in_channels=num_channels, out_channels=output_channels,
# kernel_size=1, stride=1, padding=0))
return nn.Sequential(*head_layers)
def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        assert num_layers == len(num_filters), \
            'ERROR: num_deconv_layers is different from len(num_deconv_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_deconv_layers is different from len(num_deconv_kernels)'
layers = []
for i in range(num_layers):
kernel, padding, output_padding = \
self._get_deconv_cfg(num_kernels[i])
planes = num_filters[i]
layers.append(
nn.ConvTranspose2d(
in_channels=self.num_input_features,
out_channels=planes,
kernel_size=kernel,
stride=2,
padding=padding,
output_padding=output_padding,
bias=self.deconv_with_bias))
layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
layers.append(nn.ReLU(inplace=True))
# if self.use_self_attention:
# layers.append(SelfAttention(planes))
self.num_input_features = planes
return nn.Sequential(*layers)
def _make_upsample_layer(self, num_layers, num_filters, num_kernels):
        assert num_layers == len(num_filters), \
            'ERROR: num_layers is different from len(num_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_layers is different from len(num_kernels)'
layers = []
for i in range(num_layers):
kernel, padding, output_padding = \
self._get_deconv_cfg(num_kernels[i])
planes = num_filters[i]
layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
layers.append(
nn.Conv2d(in_channels=self.num_input_features, out_channels=planes,
kernel_size=kernel, stride=1, padding=padding, bias=self.deconv_with_bias)
)
layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
layers.append(nn.ReLU(inplace=True))
# if self.use_self_attention:
# layers.append(SelfAttention(planes))
self.num_input_features = planes
return nn.Sequential(*layers)
def _prepare_pose_mlp_inp(self, feats, pred_pose, pred_shape, pred_cam):
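        """Concatenate the feature types listed in self.pose_input_type into
        the per-joint pose-MLP input of shape [N, C, J, 1]."""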
# feats shape: [N, 256, J, 1]
# pose shape: [N, 6, J, 1]
# cam shape: [N, 3]
# beta shape: [N, 10]
batch_size, num_joints = pred_pose.shape[0], pred_pose.shape[2]
joint_triplets = get_smpl_neighbor_triplets()
inp_list = []
for inp_type in self.pose_input_type:
if inp_type == 'feats':
# add image features
inp_list.append(feats)
if inp_type == 'neighbor_pose_feats':
# add the image features from neighboring joints
n_pose_feat = []
for jt in joint_triplets:
n_pose_feat.append(
feats[:, :, jt[1:]].reshape(batch_size, -1, 1).unsqueeze(-2)
)
n_pose_feat = torch.cat(n_pose_feat, 2)
inp_list.append(n_pose_feat)
            if inp_type == 'self_pose':
                # add the rot6d pose of this joint from the previous iteration
                inp_list.append(pred_pose)
if inp_type == 'all_pose':
                # append all of the joint angles
all_pose = pred_pose.reshape(batch_size, -1, 1)[..., None].repeat(1, 1, num_joints, 1)
inp_list.append(all_pose)
if inp_type == 'neighbor_pose':
                # append only the joint angles of the neighboring joints
n_pose = []
for jt in joint_triplets:
n_pose.append(
pred_pose[:,:,jt[1:]].reshape(batch_size, -1, 1).unsqueeze(-2)
)
n_pose = torch.cat(n_pose, 2)
inp_list.append(n_pose)
if inp_type == 'shape':
# append shape predictions
pred_shape = pred_shape[..., None, None].repeat(1, 1, num_joints, 1)
inp_list.append(pred_shape)
if inp_type == 'cam':
# append camera predictions
pred_cam = pred_cam[..., None, None].repeat(1, 1, num_joints, 1)
inp_list.append(pred_cam)
assert len(inp_list) > 0
return torch.cat(inp_list, 1)
def _prepare_shape_mlp_inp(self, feats, pred_pose, pred_shape, pred_cam):
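        """Concatenate the feature types listed in self.shape_input_type into
        the flattened shape/cam-MLP input of shape [N, C]."""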
# feats shape: [N, 256, J, 1]
# pose shape: [N, 6, J, 1]
# cam shape: [N, 3]
# beta shape: [N, 10]
        batch_size = pred_pose.shape[0]
inp_list = []
for inp_type in self.shape_input_type:
if inp_type == 'feats':
# add image features
inp_list.append(feats)
if inp_type == 'all_pose':
                # append all of the joint angles
pred_pose = pred_pose.reshape(batch_size, -1)
inp_list.append(pred_pose)
if inp_type == 'shape':
# append shape predictions
inp_list.append(pred_shape)
if inp_type == 'cam':
# append camera predictions
inp_list.append(pred_cam)
assert len(inp_list) > 0
return torch.cat(inp_list, 1)
def forward(self, features, gt_segm=None):
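        """Predict SMPL parameters from backbone features.

        Args:
            features: [N, C, H, W] backbone feature maps.
            gt_segm: optional ground-truth part segmentation labels; when
                given, they replace the predicted part attention maps.

        Returns:
            dict with 'pred_pose' ([N, 24, 3, 3] rotation matrices),
            'pred_shape' ([N, 10] betas), 'pred_cam' ([N, 3]) and, depending
            on use_heatmaps, intermediate keypoint/heatmap predictions.
        """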
batch_size = features.shape[0]
init_pose = self.init_pose.expand(batch_size, -1) # N, Jx6
init_shape = self.init_shape.expand(batch_size, -1)
init_cam = self.init_cam.expand(batch_size, -1)
if self.use_position_encodings:
features = torch.cat((features, self.pos_enc.repeat(features.shape[0], 1, 1, 1)), 1)
output = {}
############## 2D PART BRANCH FEATURES ##############
part_feats = self._get_2d_branch_feats(features)
############## GET PART ATTENTION MAP ##############
part_attention = self._get_part_attention_map(part_feats, output)
############## 3D SMPL BRANCH FEATURES ##############
smpl_feats = self._get_3d_smpl_feats(features, part_feats)
############## SAMPLE LOCAL FEATURES ##############
        if gt_segm is not None:
            # replace predicted attention with the ground-truth part segmentation
            device = gt_segm.device
            gt_segm = F.interpolate(gt_segm.unsqueeze(1).float(), scale_factor=(1/4, 1/4), mode='nearest').long().squeeze(1)
            # one-hot encode the labels and drop the background channel
            part_attention = F.one_hot(gt_segm.cpu(), num_classes=self.num_joints + 1).permute(0, 3, 1, 2).float()[:, 1:, :, :]
            part_attention = part_attention.to(device)
point_local_feat, cam_shape_feats = self._get_local_feats(smpl_feats, part_attention, output)
############## GET FINAL PREDICTIONS ##############
pred_pose, pred_shape, pred_cam = self._get_final_preds(
point_local_feat, cam_shape_feats, init_pose, init_shape, init_cam
)
if self.use_coattention:
for c in range(self.num_coattention_iter):
smpl_feats, part_feats = self.coattention(smpl_feats, part_feats)
part_attention = self._get_part_attention_map(part_feats, output)
point_local_feat, cam_shape_feats = self._get_local_feats(smpl_feats, part_attention, output)
pred_pose, pred_shape, pred_cam = self._get_final_preds(
point_local_feat, cam_shape_feats, pred_pose, pred_shape, pred_cam
)
if self.num_branch_iteration > 0:
for nbi in range(self.num_branch_iteration):
if self.use_soft_attention:
smpl_feats = self.branch_iter_3d_nonlocal(smpl_feats)
part_feats = self.branch_iter_2d_nonlocal(part_feats)
else:
smpl_feats = self.branch_iter_3d_nonlocal(smpl_feats)
part_feats = smpl_feats
part_attention = self._get_part_attention_map(part_feats, output)
point_local_feat, cam_shape_feats = self._get_local_feats(smpl_feats, part_attention, output)
pred_pose, pred_shape, pred_cam = self._get_final_preds(
point_local_feat, cam_shape_feats, pred_pose, pred_shape, pred_cam,
)
pred_rotmat = rot6d_to_rotmat(pred_pose).reshape(batch_size, 24, 3, 3)
output.update({
'pred_pose': pred_rotmat,
'pred_cam': pred_cam,
'pred_shape': pred_shape,
})
return output
def _get_local_feats(self, smpl_feats, part_attention, output):
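        """Pool per-joint features from the SMPL branch, either via soft
        keypoint attention or by sampling at the predicted 2D keypoint
        locations."""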
cam_shape_feats = self.smpl_final_layer(smpl_feats)
if self.use_keypoint_attention:
point_local_feat = self.keypoint_attention(smpl_feats, part_attention)
cam_shape_feats = self.keypoint_attention(cam_shape_feats, part_attention)
else:
point_local_feat = interpolate(smpl_feats, output['pred_kp2d'])
cam_shape_feats = interpolate(cam_shape_feats, output['pred_kp2d'])
return point_local_feat, cam_shape_feats
def _get_2d_branch_feats(self, features):
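        """Run the 2D part branch (plus optional non-local block) over the backbone features."""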
part_feats = self.keypoint_deconv_layers(features)
if self.use_branch_nonlocal:
part_feats = self.branch_2d_nonlocal(part_feats)
return part_feats
def _get_3d_smpl_feats(self, features, part_feats):
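        """Run the 3D SMPL branch, or reuse the 2D branch features when
        use_keypoint_features_for_smpl_regression is set."""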
if self.use_keypoint_features_for_smpl_regression:
smpl_feats = part_feats
else:
smpl_feats = self.smpl_deconv_layers(features)
if self.use_branch_nonlocal:
smpl_feats = self.branch_3d_nonlocal(smpl_feats)
return smpl_feats
def _get_part_attention_map(self, part_feats, output):
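        """Compute per-part attention maps (heatmaps or segmentation masks,
        depending on use_heatmaps) and stash 2D keypoint predictions in
        `output`."""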
heatmaps = self.keypoint_final_layer(part_feats)
if self.use_heatmaps == 'hm':
# returns coords between [-1,1]
pred_kp2d, confidence = get_heatmap_preds(heatmaps)
output['pred_kp2d'] = pred_kp2d
output['pred_kp2d_conf'] = confidence
output['pred_heatmaps_2d'] = heatmaps
elif self.use_heatmaps == 'hm_soft':
pred_kp2d, _ = softargmax2d(heatmaps, self.temperature)
output['pred_kp2d'] = pred_kp2d
output['pred_heatmaps_2d'] = heatmaps
elif self.use_heatmaps == 'part_segm':
output['pred_segm_mask'] = heatmaps
heatmaps = heatmaps[:,1:,:,:] # remove the first channel which encodes the background
elif self.use_heatmaps == 'part_segm_pool':
output['pred_segm_mask'] = heatmaps
heatmaps = heatmaps[:,1:,:,:] # remove the first channel which encodes the background
pred_kp2d, _ = softargmax2d(heatmaps, self.temperature) # get_heatmap_preds(heatmaps)
output['pred_kp2d'] = pred_kp2d
            for k, v in output.items():
                if torch.any(torch.isnan(v)):
                    logger.debug(f'{k} contains NaN values!')
                if torch.any(torch.isinf(v)):
                    logger.debug(f'{k} contains Inf values!')
elif self.use_heatmaps == 'attention':
output['pred_attention'] = heatmaps
else:
# returns coords between [-1,1]
pred_kp2d, _ = softargmax2d(heatmaps, self.temperature)
output['pred_kp2d'] = pred_kp2d
output['pred_heatmaps_2d'] = heatmaps
return heatmaps
def _get_final_preds(self, pose_feats, cam_shape_feats, init_pose, init_shape, init_cam):
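        """Dispatch to the HMR-style or the PARE-style parameter regressor."""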
if self.use_hmr_regression:
return self._hmr_get_final_preds(cam_shape_feats, init_pose, init_shape, init_cam)
else:
return self._pare_get_final_preds(pose_feats, cam_shape_feats, init_pose, init_shape, init_cam)
def _hmr_get_final_preds(self, cam_shape_feats, init_pose, init_shape, init_cam):
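        """HMR-style iterative error feedback: three refinement steps of a
        shared fully connected regressor, starting from the given initial
        parameters."""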
if self.use_final_nonlocal:
cam_shape_feats = self.final_shape_nonlocal(cam_shape_feats)
xf = torch.flatten(cam_shape_feats, start_dim=1)
pred_pose = init_pose
pred_shape = init_shape
pred_cam = init_cam
for i in range(3):
xc = torch.cat([xf, pred_pose, pred_shape, pred_cam], 1)
xc = self.fc1(xc)
xc = self.drop1(xc)
xc = self.fc2(xc)
xc = self.drop2(xc)
pred_pose = self.decpose(xc) + pred_pose
pred_shape = self.decshape(xc) + pred_shape
pred_cam = self.deccam(xc) + pred_cam
return pred_pose, pred_shape, pred_cam
def _pare_get_final_preds(self, pose_feats, cam_shape_feats, init_pose, init_shape, init_cam):
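        """PARE-style regression: per-joint rot6d pose via locally connected
        MLPs, shape and camera via MLPs over flattened SMPL features."""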
        pose_feats = pose_feats.unsqueeze(-1)  # [N, C, J] -> [N, C, J, 1]
if init_pose.shape[-1] == 6:
# This means init_pose comes from a previous iteration
init_pose = init_pose.transpose(2,1).unsqueeze(-1)
else:
# This means init pose comes from mean pose
init_pose = init_pose.reshape(init_pose.shape[0], 6, -1).unsqueeze(-1)
if self.iterative_regression:
shape_feats = torch.flatten(cam_shape_feats, start_dim=1)
pred_pose = init_pose # [N, 6, J, 1]
pred_cam = init_cam # [N, 3]
pred_shape = init_shape # [N, 10]
for i in range(self.num_iterations):
# pose_feats shape: [N, 256, 24, 1]
# shape_feats shape: [N, 24*64]
pose_mlp_inp = self._prepare_pose_mlp_inp(pose_feats, pred_pose, pred_shape, pred_cam)
shape_mlp_inp = self._prepare_shape_mlp_inp(shape_feats, pred_pose, pred_shape, pred_cam)
                # TODO: this does not work yet, but is left as-is since iterative regression is currently unused.
# if self.use_final_nonlocal:
# pose_mlp_inp = self.final_pose_nonlocal(pose_mlp_inp)
# shape_mlp_inp = self.final_shape_nonlocal(shape_mlp_inp)
if self.iter_residual:
pred_pose = self.pose_mlp(pose_mlp_inp) + pred_pose
pred_cam = self.cam_mlp(shape_mlp_inp) + pred_cam
pred_shape = self.shape_mlp(shape_mlp_inp) + pred_shape
else:
pred_pose = self.pose_mlp(pose_mlp_inp)
pred_cam = self.cam_mlp(shape_mlp_inp)
pred_shape = self.shape_mlp(shape_mlp_inp) + init_shape
else:
shape_feats = cam_shape_feats
if self.use_final_nonlocal:
pose_feats = self.final_pose_nonlocal(pose_feats.squeeze(-1)).unsqueeze(-1)
shape_feats = self.final_shape_nonlocal(shape_feats)
shape_feats = torch.flatten(shape_feats, start_dim=1)
pred_pose = self.pose_mlp(pose_feats)
pred_cam = self.cam_mlp(shape_feats)
pred_shape = self.shape_mlp(shape_feats)
if self.use_mean_camshape:
pred_cam = pred_cam + init_cam
pred_shape = pred_shape + init_shape
if self.use_mean_pose:
pred_pose = pred_pose + init_pose
pred_pose = pred_pose.squeeze(-1).transpose(2, 1) # N, J, 6
return pred_pose, pred_shape, pred_cam
def forward_pretraining(self, features):
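        """Forward pass for keypoint pretraining: returns heatmaps, 2D
        keypoints and intermediate branch features instead of SMPL
        parameters."""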
# TODO: implement pretraining
kp_feats = self.keypoint_deconv_layers(features)
heatmaps = self.keypoint_final_layer(kp_feats)
output = {}
if self.use_heatmaps == 'hm':
# returns coords between [-1,1]
pred_kp2d, confidence = get_heatmap_preds(heatmaps)
output['pred_kp2d'] = pred_kp2d
output['pred_kp2d_conf'] = confidence
elif self.use_heatmaps == 'hm_soft':
pred_kp2d, _ = softargmax2d(heatmaps, self.temperature)
output['pred_kp2d'] = pred_kp2d
else:
# returns coords between [-1,1]
pred_kp2d, _ = softargmax2d(heatmaps, self.temperature)
output['pred_kp2d'] = pred_kp2d
if self.use_keypoint_features_for_smpl_regression:
smpl_feats = kp_feats
else:
smpl_feats = self.smpl_deconv_layers(features)
cam_shape_feats = self.smpl_final_layer(smpl_feats)
output.update({
'kp_feats': heatmaps,
'heatmaps': heatmaps,
'smpl_feats': smpl_feats,
'cam_shape_feats': cam_shape_feats,
})
return output
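

# Example usage (a minimal sketch, not part of the original file; assumes this
# module is imported as part of the EasyMocap package so the relative imports
# resolve, and that SMPL_MEAN_PARAMS points to a valid mean-parameter file).
# The channel count and spatial size below are illustrative:
#
#     head = PareHead(
#         num_joints=24,
#         num_input_features=480,   # e.g. concatenated HRNet-W32 branches
#         backbone='hrnet_w32',
#     )
#     feats = torch.randn(2, 480, 56, 56)   # [N, C, H, W] backbone features
#     out = head(feats)
#     out['pred_pose'].shape    # torch.Size([2, 24, 3, 3]) rotation matrices
#     out['pred_shape'].shape   # torch.Size([2, 10]) SMPL betas
#     out['pred_cam'].shape     # torch.Size([2, 3]) weak-perspective camera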