shuaiqing 2023-06-24 22:39:33 +08:00
parent e30a28bff0
commit ad0791fac6
25 changed files with 5768 additions and 1 deletion

View File

@@ -14,7 +14,7 @@ args:
key_from_previous: [bbox]
key_keep: []
args:
-  ckpt: /nas/home/shuaiqing/Code/EasyMocapPublic/data/models/pose_hrnet_w48_384x288.pth
+  ckpt: data/models/pose_hrnet_w48_384x288.pth
vis2d:
module: myeasymocap.io.vis.Vis2D
skip: False

View File

@@ -0,0 +1,3 @@
# from .hrnet_pare import *
from .resnet import *
from .mobilenet import *

View File

@@ -0,0 +1,631 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# ------------------------------------------------------------------------------
import os
import torch
import torch.nn as nn
# from loguru import logger
import logging
logger = logging.getLogger(__name__)  # logger is referenced in _check_branches and init_weights below
import torch.nn.functional as F
from yacs.config import CfgNode as CN
models = [
'hrnet_w32',
'hrnet_w48',
]
BN_MOMENTUM = 0.1
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion,
momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class HighResolutionModule(nn.Module):
def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
num_channels, fuse_method, multi_scale_output=True):
super(HighResolutionModule, self).__init__()
self._check_branches(
num_branches, blocks, num_blocks, num_inchannels, num_channels)
self.num_inchannels = num_inchannels
self.fuse_method = fuse_method
self.num_branches = num_branches
self.multi_scale_output = multi_scale_output
self.branches = self._make_branches(
num_branches, blocks, num_blocks, num_channels)
self.fuse_layers = self._make_fuse_layers()
self.relu = nn.ReLU(True)
def _check_branches(self, num_branches, blocks, num_blocks,
num_inchannels, num_channels):
if num_branches != len(num_blocks):
error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
num_branches, len(num_blocks))
logger.error(error_msg)
raise ValueError(error_msg)
if num_branches != len(num_channels):
error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
num_branches, len(num_channels))
logger.error(error_msg)
raise ValueError(error_msg)
if num_branches != len(num_inchannels):
error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
num_branches, len(num_inchannels))
logger.error(error_msg)
raise ValueError(error_msg)
def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
stride=1):
downsample = None
if stride != 1 or \
self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(
self.num_inchannels[branch_index],
num_channels[branch_index] * block.expansion,
kernel_size=1, stride=stride, bias=False
),
nn.BatchNorm2d(
num_channels[branch_index] * block.expansion,
momentum=BN_MOMENTUM
),
)
layers = []
layers.append(
block(
self.num_inchannels[branch_index],
num_channels[branch_index],
stride,
downsample
)
)
self.num_inchannels[branch_index] = \
num_channels[branch_index] * block.expansion
for i in range(1, num_blocks[branch_index]):
layers.append(
block(
self.num_inchannels[branch_index],
num_channels[branch_index]
)
)
return nn.Sequential(*layers)
def _make_branches(self, num_branches, block, num_blocks, num_channels):
branches = []
for i in range(num_branches):
branches.append(
self._make_one_branch(i, block, num_blocks, num_channels)
)
return nn.ModuleList(branches)
def _make_fuse_layers(self):
if self.num_branches == 1:
return None
num_branches = self.num_branches
num_inchannels = self.num_inchannels
fuse_layers = []
for i in range(num_branches if self.multi_scale_output else 1):
fuse_layer = []
for j in range(num_branches):
if j > i:
fuse_layer.append(
nn.Sequential(
nn.Conv2d(
num_inchannels[j],
num_inchannels[i],
1, 1, 0, bias=False
),
nn.BatchNorm2d(num_inchannels[i]),
nn.Upsample(scale_factor=2**(j-i), mode='nearest')
)
)
elif j == i:
fuse_layer.append(None)
else:
conv3x3s = []
for k in range(i-j):
if k == i - j - 1:
num_outchannels_conv3x3 = num_inchannels[i]
conv3x3s.append(
nn.Sequential(
nn.Conv2d(
num_inchannels[j],
num_outchannels_conv3x3,
3, 2, 1, bias=False
),
nn.BatchNorm2d(num_outchannels_conv3x3)
)
)
else:
num_outchannels_conv3x3 = num_inchannels[j]
conv3x3s.append(
nn.Sequential(
nn.Conv2d(
num_inchannels[j],
num_outchannels_conv3x3,
3, 2, 1, bias=False
),
nn.BatchNorm2d(num_outchannels_conv3x3),
nn.ReLU(True)
)
)
fuse_layer.append(nn.Sequential(*conv3x3s))
fuse_layers.append(nn.ModuleList(fuse_layer))
return nn.ModuleList(fuse_layers)
def get_num_inchannels(self):
return self.num_inchannels
def forward(self, x):
if self.num_branches == 1:
return [self.branches[0](x[0])]
for i in range(self.num_branches):
x[i] = self.branches[i](x[i])
x_fuse = []
for i in range(len(self.fuse_layers)):
y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
for j in range(1, self.num_branches):
if i == j:
y = y + x[j]
else:
y = y + self.fuse_layers[i][j](x[j])
x_fuse.append(self.relu(y))
return x_fuse
blocks_dict = {
'BASIC': BasicBlock,
'BOTTLENECK': Bottleneck
}
class PoseHighResolutionNet(nn.Module):
def __init__(self, cfg):
self.inplanes = 64
extra = cfg['MODEL']['EXTRA']
super(PoseHighResolutionNet, self).__init__()
self.cfg = extra
# stem net
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
bias=False)
self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
bias=False)
self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(Bottleneck, 64, 4)
self.stage2_cfg = extra['STAGE2']
num_channels = self.stage2_cfg['NUM_CHANNELS']
block = blocks_dict[self.stage2_cfg['BLOCK']]
num_channels = [
num_channels[i] * block.expansion for i in range(len(num_channels))
]
self.transition1 = self._make_transition_layer([256], num_channels)
self.stage2, pre_stage_channels = self._make_stage(
self.stage2_cfg, num_channels)
self.stage3_cfg = extra['STAGE3']
num_channels = self.stage3_cfg['NUM_CHANNELS']
block = blocks_dict[self.stage3_cfg['BLOCK']]
num_channels = [
num_channels[i] * block.expansion for i in range(len(num_channels))
]
self.transition2 = self._make_transition_layer(
pre_stage_channels, num_channels)
self.stage3, pre_stage_channels = self._make_stage(
self.stage3_cfg, num_channels)
self.stage4_cfg = extra['STAGE4']
num_channels = self.stage4_cfg['NUM_CHANNELS']
block = blocks_dict[self.stage4_cfg['BLOCK']]
num_channels = [
num_channels[i] * block.expansion for i in range(len(num_channels))
]
self.transition3 = self._make_transition_layer(
pre_stage_channels, num_channels)
self.stage4, pre_stage_channels = self._make_stage(
self.stage4_cfg, num_channels, multi_scale_output=True)
self.final_layer = nn.Conv2d(
in_channels=pre_stage_channels[0],
out_channels=cfg['MODEL']['NUM_JOINTS'],
kernel_size=extra['FINAL_CONV_KERNEL'],
stride=1,
padding=1 if extra['FINAL_CONV_KERNEL'] == 3 else 0
)
self.pretrained_layers = extra['PRETRAINED_LAYERS']
if extra.DOWNSAMPLE and extra.USE_CONV:
self.downsample_stage_1 = self._make_downsample_layer(3, num_channel=self.stage2_cfg['NUM_CHANNELS'][0])
self.downsample_stage_2 = self._make_downsample_layer(2, num_channel=self.stage2_cfg['NUM_CHANNELS'][-1])
self.downsample_stage_3 = self._make_downsample_layer(1, num_channel=self.stage3_cfg['NUM_CHANNELS'][-1])
elif not extra.DOWNSAMPLE and extra.USE_CONV:
self.upsample_stage_2 = self._make_upsample_layer(1, num_channel=self.stage2_cfg['NUM_CHANNELS'][-1])
self.upsample_stage_3 = self._make_upsample_layer(2, num_channel=self.stage3_cfg['NUM_CHANNELS'][-1])
self.upsample_stage_4 = self._make_upsample_layer(3, num_channel=self.stage4_cfg['NUM_CHANNELS'][-1])
def _make_transition_layer(
self, num_channels_pre_layer, num_channels_cur_layer):
num_branches_cur = len(num_channels_cur_layer)
num_branches_pre = len(num_channels_pre_layer)
transition_layers = []
for i in range(num_branches_cur):
if i < num_branches_pre:
if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
transition_layers.append(
nn.Sequential(
nn.Conv2d(
num_channels_pre_layer[i],
num_channels_cur_layer[i],
3, 1, 1, bias=False
),
nn.BatchNorm2d(num_channels_cur_layer[i]),
nn.ReLU(inplace=True)
)
)
else:
transition_layers.append(None)
else:
conv3x3s = []
for j in range(i+1-num_branches_pre):
inchannels = num_channels_pre_layer[-1]
outchannels = num_channels_cur_layer[i] \
if j == i-num_branches_pre else inchannels
conv3x3s.append(
nn.Sequential(
nn.Conv2d(
inchannels, outchannels, 3, 2, 1, bias=False
),
nn.BatchNorm2d(outchannels),
nn.ReLU(inplace=True)
)
)
transition_layers.append(nn.Sequential(*conv3x3s))
return nn.ModuleList(transition_layers)
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(
self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False
),
nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def _make_stage(self, layer_config, num_inchannels,
multi_scale_output=True):
num_modules = layer_config['NUM_MODULES']
num_branches = layer_config['NUM_BRANCHES']
num_blocks = layer_config['NUM_BLOCKS']
num_channels = layer_config['NUM_CHANNELS']
block = blocks_dict[layer_config['BLOCK']]
fuse_method = layer_config['FUSE_METHOD']
modules = []
for i in range(num_modules):
# multi_scale_output is only used by the last module
if not multi_scale_output and i == num_modules - 1:
reset_multi_scale_output = False
else:
reset_multi_scale_output = True
modules.append(
HighResolutionModule(
num_branches,
block,
num_blocks,
num_inchannels,
num_channels,
fuse_method,
reset_multi_scale_output
)
)
num_inchannels = modules[-1].get_num_inchannels()
return nn.Sequential(*modules), num_inchannels
def _make_upsample_layer(self, num_layers, num_channel, kernel_size=3):
layers = []
for i in range(num_layers):
layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
layers.append(
nn.Conv2d(
in_channels=num_channel, out_channels=num_channel,
kernel_size=kernel_size, stride=1, padding=1, bias=False,
)
)
layers.append(nn.BatchNorm2d(num_channel, momentum=BN_MOMENTUM))
layers.append(nn.ReLU(inplace=True))
return nn.Sequential(*layers)
def _make_downsample_layer(self, num_layers, num_channel, kernel_size=3):
layers = []
for i in range(num_layers):
layers.append(
nn.Conv2d(
in_channels=num_channel, out_channels=num_channel,
kernel_size=kernel_size, stride=2, padding=1, bias=False,
)
)
layers.append(nn.BatchNorm2d(num_channel, momentum=BN_MOMENTUM))
layers.append(nn.ReLU(inplace=True))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.layer1(x)
x_list = []
for i in range(self.stage2_cfg['NUM_BRANCHES']):
if self.transition1[i] is not None:
x_list.append(self.transition1[i](x))
else:
x_list.append(x)
y_list = self.stage2(x_list)
x_list = []
for i in range(self.stage3_cfg['NUM_BRANCHES']):
if self.transition2[i] is not None:
x_list.append(self.transition2[i](y_list[-1]))
else:
x_list.append(y_list[i])
y_list = self.stage3(x_list)
x_list = []
for i in range(self.stage4_cfg['NUM_BRANCHES']):
if self.transition3[i] is not None:
x_list.append(self.transition3[i](y_list[-1]))
else:
x_list.append(y_list[i])
x = self.stage4(x_list)
if self.cfg.DOWNSAMPLE:
if self.cfg.USE_CONV:
# Downsampling with strided convolutions
x1 = self.downsample_stage_1(x[0])
x2 = self.downsample_stage_2(x[1])
x3 = self.downsample_stage_3(x[2])
x = torch.cat([x1, x2, x3, x[3]], 1)
else:
# Downsampling with interpolation
x0_h, x0_w = x[3].size(2), x[3].size(3)
x1 = F.interpolate(x[0], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
x2 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
x3 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
x = torch.cat([x1, x2, x3, x[3]], 1)
else:
if self.cfg.USE_CONV:
# Upsampling with interpolations + convolutions
x1 = self.upsample_stage_2(x[1])
x2 = self.upsample_stage_3(x[2])
x3 = self.upsample_stage_4(x[3])
x = torch.cat([x[0], x1, x2, x3], 1)
else:
# Upsampling with interpolation
x0_h, x0_w = x[0].size(2), x[0].size(3)
x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
x = torch.cat([x[0], x1, x2, x3], 1)
return x
def init_weights(self, pretrained=''):
# logger.info('=> init weights from normal distribution')
for m in self.modules():
if isinstance(m, nn.Conv2d):
# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
nn.init.normal_(m.weight, std=0.001)
for name, _ in m.named_parameters():
if name in ['bias']:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.ConvTranspose2d):
nn.init.normal_(m.weight, std=0.001)
for name, _ in m.named_parameters():
if name in ['bias']:
nn.init.constant_(m.bias, 0)
if os.path.isfile(pretrained):
pretrained_state_dict = torch.load(pretrained)
logger.info('=> loading pretrained model {}'.format(pretrained))
need_init_state_dict = {}
for name, m in pretrained_state_dict.items():
if name.split('.')[0] in self.pretrained_layers \
or self.pretrained_layers[0] == '*':
need_init_state_dict[name] = m
self.load_state_dict(need_init_state_dict, strict=False)
elif pretrained:
# logger.warning('IMPORTANT WARNING!! Please download pre-trained models if you are in TRAINING mode!')
# raise ValueError('{} is not exist!'.format(pretrained))
pass
def get_pose_net(cfg, is_train):
model = PoseHighResolutionNet(cfg)
if is_train and cfg['MODEL']['INIT_WEIGHTS']:
model.init_weights(cfg['MODEL']['PRETRAINED'])
return model
def get_cfg_defaults(pretrained, width=32, downsample=False, use_conv=False):
# pose_multi_resoluton_net related params
HRNET = CN()
HRNET.PRETRAINED_LAYERS = [
'conv1', 'bn1', 'conv2', 'bn2', 'layer1', 'transition1',
'stage2', 'transition2', 'stage3', 'transition3', 'stage4',
]
HRNET.STEM_INPLANES = 64
HRNET.FINAL_CONV_KERNEL = 1
HRNET.STAGE2 = CN()
HRNET.STAGE2.NUM_MODULES = 1
HRNET.STAGE2.NUM_BRANCHES = 2
HRNET.STAGE2.NUM_BLOCKS = [4, 4]
HRNET.STAGE2.NUM_CHANNELS = [width, width*2]
HRNET.STAGE2.BLOCK = 'BASIC'
HRNET.STAGE2.FUSE_METHOD = 'SUM'
HRNET.STAGE3 = CN()
HRNET.STAGE3.NUM_MODULES = 4
HRNET.STAGE3.NUM_BRANCHES = 3
HRNET.STAGE3.NUM_BLOCKS = [4, 4, 4]
HRNET.STAGE3.NUM_CHANNELS = [width, width*2, width*4]
HRNET.STAGE3.BLOCK = 'BASIC'
HRNET.STAGE3.FUSE_METHOD = 'SUM'
HRNET.STAGE4 = CN()
HRNET.STAGE4.NUM_MODULES = 3
HRNET.STAGE4.NUM_BRANCHES = 4
HRNET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
HRNET.STAGE4.NUM_CHANNELS = [width, width*2, width*4, width*8]
HRNET.STAGE4.BLOCK = 'BASIC'
HRNET.STAGE4.FUSE_METHOD = 'SUM'
HRNET.DOWNSAMPLE = downsample
HRNET.USE_CONV = use_conv
cfg = CN()
cfg.MODEL = CN()
cfg.MODEL.INIT_WEIGHTS = True
cfg.MODEL.PRETRAINED = pretrained # 'data/pretrained_models/hrnet_w32-36af842e.pth'
cfg.MODEL.EXTRA = HRNET
cfg.MODEL.NUM_JOINTS = 24
return cfg
def hrnet_w32(
pretrained=True,
pretrained_ckpt='data/pretrained_models/pose_coco/pose_hrnet_w32_256x192.pth',
downsample=False,
use_conv=False,
):
cfg = get_cfg_defaults(pretrained_ckpt, width=32, downsample=downsample, use_conv=use_conv)
return get_pose_net(cfg, is_train=True)
def hrnet_w48(
pretrained=True,
pretrained_ckpt='data/pretrained_models/pose_coco/pose_hrnet_w48_256x192.pth',
downsample=False,
use_conv=False,
):
cfg = get_cfg_defaults(pretrained_ckpt, width=48, downsample=downsample, use_conv=use_conv)
return get_pose_net(cfg, is_train=True)

View File

@@ -0,0 +1,191 @@
from torch import nn
try:
from torch.hub import load_state_dict_from_url
except ImportError:
from torchvision.models.utils import load_state_dict_from_url
__all__ = ['MobileNetV2', 'mobilenet_v2']
model_urls = {
'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
}
def _make_divisible(v, divisor, min_value=None):
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v:
:param divisor:
:param min_value:
:return:
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
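# Worked examples (illustrative): channel counts are rounded to a multiple of the divisor
# and never reduced by more than 10%:
#   _make_divisible(32 * 1.4, 8)   -> 48   (44.8 rounds to the nearest multiple of 8)
#   _make_divisible(32 * 0.25, 8)  -> 8    (clamped to the divisor as a minimum)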
class ConvBNReLU(nn.Sequential):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None):
padding = (kernel_size - 1) // 2
if norm_layer is None:
norm_layer = nn.BatchNorm2d
super(ConvBNReLU, self).__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
norm_layer(out_planes),
nn.ReLU6(inplace=True)
)
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]
if norm_layer is None:
norm_layer = nn.BatchNorm2d
hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = self.stride == 1 and inp == oup
layers = []
if expand_ratio != 1:
# pw
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
layers.extend([
# dw
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
norm_layer(oup),
])
self.conv = nn.Sequential(*layers)
def forward(self, x):
if self.use_res_connect:
return x + self.conv(x)
else:
return self.conv(x)
class MobileNetV2(nn.Module):
def __init__(self,
num_classes=1000,
width_mult=1.0,
inverted_residual_setting=None,
round_nearest=8,
block=None,
norm_layer=None):
"""
MobileNet V2 main class
Args:
num_classes (int): Number of classes
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
inverted_residual_setting: Network structure
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
Set to 1 to turn off rounding
block: Module specifying inverted residual building block for mobilenet
norm_layer: Module specifying the normalization layer to use
"""
super(MobileNetV2, self).__init__()
if block is None:
block = InvertedResidual
if norm_layer is None:
norm_layer = nn.BatchNorm2d
input_channel = 32
last_channel = 1280
if inverted_residual_setting is None:
inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
# only check the first element, assuming user knows t,c,n,s are required
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
raise ValueError("inverted_residual_setting should be non-empty "
"or a 4-element list, got {}".format(inverted_residual_setting))
# building first layer
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
# building inverted residual blocks
for t, c, n, s in inverted_residual_setting:
output_channel = _make_divisible(c * width_mult, round_nearest)
for i in range(n):
stride = s if i == 0 else 1
features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
input_channel = output_channel
# building last several layers
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
# make it nn.Sequential
self.features = nn.Sequential(*features)
# building classifier
# self.classifier = nn.Sequential(
# nn.Dropout(0.2),
# nn.Linear(self.last_channel, num_classes),
# )
# weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)
def _forward_impl(self, x):
# This exists since TorchScript doesn't support inheritance, so the superclass method
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
x = self.features(x)
# Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
# x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
# x = self.classifier(x)
return x
def forward(self, x):
return self._forward_impl(x)
def mobilenet_v2(pretrained=False, progress=True, **kwargs):
"""
Constructs a MobileNetV2 architecture from
`"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = MobileNetV2(**kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
progress=progress)
model.load_state_dict(state_dict, strict=False)
return model
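# Illustrative usage sketch (added example, not part of the original file): with the
# classifier commented out above, the network acts as a pure feature extractor and
# returns the last convolutional feature map.
if __name__ == '__main__':
    import torch  # this module itself only imports torch.nn
    net = mobilenet_v2(pretrained=False)
    print(net(torch.zeros(1, 3, 224, 224)).shape)  # torch.Size([1, 1280, 7, 7])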

View File

@@ -0,0 +1,355 @@
import torch
import torch.nn as nn
try:
from torch.hub import load_state_dict_from_url
except ImportError:
from torchvision.models.utils import load_state_dict_from_url
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
'wide_resnet50_2', 'wide_resnet101_2']
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
# self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
# self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def _forward_impl(self, x):
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
# x = self.avgpool(x)
# x = torch.flatten(x, 1)
# x = self.fc(x)
return x
def forward(self, x):
return self._forward_impl(x)
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
model = ResNet(block, layers, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
progress=progress)
model.load_state_dict(state_dict, strict=False)
return model
def resnet18(pretrained=False, progress=True, **kwargs):
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
**kwargs)
def resnet34(pretrained=False, progress=True, **kwargs):
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
**kwargs)
def resnet50(pretrained=False, progress=True, **kwargs):
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
**kwargs)
def resnet101(pretrained=False, progress=True, **kwargs):
r"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
**kwargs)
def resnet152(pretrained=False, progress=True, **kwargs):
r"""ResNet-152 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
**kwargs)
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
r"""ResNeXt-50 32x4d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 4
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
pretrained, progress, **kwargs)
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
r"""ResNeXt-101 32x8d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 8
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
pretrained, progress, **kwargs)
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
r"""Wide ResNet-50-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
pretrained, progress, **kwargs)
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
r"""Wide ResNet-101-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
pretrained, progress, **kwargs)
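# Illustrative usage sketch (added example, not part of the original file): avgpool/fc are
# commented out above, so the model returns the layer4 feature map with
# 512 * Bottleneck.expansion = 2048 channels at 1/32 of the input resolution.
if __name__ == '__main__':
    net = resnet50(pretrained=False)
    print(net(torch.zeros(1, 3, 224, 224)).shape)  # torch.Size([1, 2048, 7, 7])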

View File

@@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
def get_backbone_info(backbone):
info = {
'resnet18': {'n_output_channels': 512, 'downsample_rate': 4},
'resnet34': {'n_output_channels': 512, 'downsample_rate': 4},
'resnet50': {'n_output_channels': 2048, 'downsample_rate': 4},
'resnet50_adf_dropout': {'n_output_channels': 2048, 'downsample_rate': 4},
'resnet50_dropout': {'n_output_channels': 2048, 'downsample_rate': 4},
'resnet101': {'n_output_channels': 2048, 'downsample_rate': 4},
'resnet152': {'n_output_channels': 2048, 'downsample_rate': 4},
'resnext50_32x4d': {'n_output_channels': 2048, 'downsample_rate': 4},
'resnext101_32x8d': {'n_output_channels': 2048, 'downsample_rate': 4},
'wide_resnet50_2': {'n_output_channels': 2048, 'downsample_rate': 4},
'wide_resnet101_2': {'n_output_channels': 2048, 'downsample_rate': 4},
'mobilenet_v2': {'n_output_channels': 1280, 'downsample_rate': 4},
'hrnet_w32': {'n_output_channels': 480, 'downsample_rate': 4},
'hrnet_w48': {'n_output_channels': 720, 'downsample_rate': 4},
# 'hrnet_w64': {'n_output_channels': 2048, 'downsample_rate': 4},
'dla34': {'n_output_channels': 512, 'downsample_rate': 4},
}
return info[backbone]
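# Example lookup (illustrative):
#   get_backbone_info('hrnet_w48')  # -> {'n_output_channels': 720, 'downsample_rate': 4}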

View File

@@ -0,0 +1,239 @@
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import os
import time
import yaml
import shutil
import argparse
import operator
import itertools
from os.path import join
from functools import reduce
from yacs.config import CfgNode as CN
from typing import Dict, List, Union, Any
# from ..utils.cluster import execute_task_on_cluster
##### CONSTANTS #####
DATASET_NPZ_PATH = 'data/dataset_extras'
DATASET_LMDB_PATH = 'data/lmdb'
MMPOSE_PATH = '/is/cluster/work/mkocabas/projects/mmpose'
MMDET_PATH = '/is/cluster/work/mkocabas/projects/mmdetection'
MMPOSE_CFG = os.path.join(MMPOSE_PATH, 'configs/top_down/hrnet/coco-wholebody/hrnet_w48_coco_wholebody_256x192.py')
MMPOSE_CKPT = os.path.join(MMPOSE_PATH, 'checkpoints/hrnet_w48_coco_wholebody_256x192-643e18cb_20200922.pth')
MMDET_CFG = os.path.join(MMDET_PATH, 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py')
MMDET_CKPT = os.path.join(MMDET_PATH, 'checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth')
PW3D_ROOT = 'data/dataset_folders/3dpw'
OH3D_ROOT = 'data/dataset_folders/3doh'
JOINT_REGRESSOR_TRAIN_EXTRA = 'models/pare/data/J_regressor_extra.npy'
JOINT_REGRESSOR_H36M = 'models/pare/data/J_regressor_h36m.npy'
SMPL_MEAN_PARAMS = 'models/pare/data/smpl_mean_params.npz'
SMPL_MODEL_DIR = 'models/pare/data/body_models/smpl'
COCO_OCCLUDERS_FILE = 'data/occlusion_augmentation/coco_train2014_occluders.pkl'
PASCAL_OCCLUDERS_FILE = 'data/occlusion_augmentation/pascal_occluders.pkl'
DATASET_FOLDERS = {
'3dpw': PW3D_ROOT,
'3dpw-val': PW3D_ROOT,
'3dpw-val-cam': PW3D_ROOT,
'3dpw-test-cam': PW3D_ROOT,
'3dpw-train-cam': PW3D_ROOT,
'3dpw-cam': PW3D_ROOT,
'3dpw-all': PW3D_ROOT,
'3doh': OH3D_ROOT,
}
DATASET_FILES = [
# Evaluation / test splits
{
'3dpw-all': join(DATASET_NPZ_PATH, '3dpw_all_test_with_mmpose.npz'),
'3doh': join(DATASET_NPZ_PATH, '3doh_test.npz'),
},
# Training splits
{
'3doh': join(DATASET_NPZ_PATH, '3doh_train.npz'),
'3dpw': join(DATASET_NPZ_PATH, '3dpw_train.npz'),
}
]
EVAL_MESH_DATASETS = ['3dpw', '3dpw-val', '3dpw-all', '3doh']
##### CONFIGS #####
hparams = CN()
# General settings
hparams.LOG_DIR = 'logs/experiments'
hparams.METHOD = 'pare'
hparams.EXP_NAME = 'default'
hparams.RUN_TEST = False
hparams.PROJECT_NAME = 'pare'
hparams.SEED_VALUE = -1
hparams.SYSTEM = CN()
hparams.SYSTEM.GPU = ''
hparams.SYSTEM.CLUSTER_NODE = 0.0
# Dataset hparams
hparams.DATASET = CN()
hparams.DATASET.LOAD_TYPE = 'Base'
hparams.DATASET.NOISE_FACTOR = 0.4
hparams.DATASET.ROT_FACTOR = 30
hparams.DATASET.SCALE_FACTOR = 0.25
hparams.DATASET.FLIP_PROB = 0.5
hparams.DATASET.CROP_PROB = 0.0
hparams.DATASET.CROP_FACTOR = 0.0
hparams.DATASET.BATCH_SIZE = 64
hparams.DATASET.NUM_WORKERS = 8
hparams.DATASET.PIN_MEMORY = True
hparams.DATASET.SHUFFLE_TRAIN = True
hparams.DATASET.TRAIN_DS = 'all'
hparams.DATASET.VAL_DS = '3dpw_3doh'
hparams.DATASET.NUM_IMAGES = -1
hparams.DATASET.TRAIN_NUM_IMAGES = -1
hparams.DATASET.TEST_NUM_IMAGES = -1
hparams.DATASET.IMG_RES = 224
hparams.DATASET.USE_HEATMAPS = '' # 'hm', 'hm_soft', 'part_segm', 'attention'
hparams.DATASET.RENDER_RES = 480
hparams.DATASET.MESH_COLOR = 'pinkish'
hparams.DATASET.FOCAL_LENGTH = 5000.
hparams.DATASET.IGNORE_3D = False
hparams.DATASET.USE_SYNTHETIC_OCCLUSION = False
hparams.DATASET.OCC_AUG_DATASET = 'pascal'
hparams.DATASET.USE_3D_CONF = False
hparams.DATASET.USE_GENDER = False
# Note: for the in-the-wild datasets the mixing ratios should be the same;
# otherwise the dataset-mixing code becomes more verbose.
hparams.DATASET.DATASETS_AND_RATIOS = 'h36m_mpii_lspet_coco_mpi-inf-3dhp_0.3_0.6_0.6_0.6_0.1'
hparams.DATASET.STAGE_DATASETS = '0+h36m_coco_0.2_0.8 2+h36m_coco_0.4_0.6'
# enable non parametric representation
hparams.DATASET.NONPARAMETRIC = False
# optimizer config
hparams.OPTIMIZER = CN()
hparams.OPTIMIZER.TYPE = 'adam'
hparams.OPTIMIZER.LR = 0.0001 # 0.00003
hparams.OPTIMIZER.WD = 0.0
# Training process hparams
hparams.TRAINING = CN()
hparams.TRAINING.RESUME = None
hparams.TRAINING.PRETRAINED = None
hparams.TRAINING.PRETRAINED_LIT = None
hparams.TRAINING.MAX_EPOCHS = 100
hparams.TRAINING.LOG_SAVE_INTERVAL = 50
hparams.TRAINING.LOG_FREQ_TB_IMAGES = 500
hparams.TRAINING.CHECK_VAL_EVERY_N_EPOCH = 1
hparams.TRAINING.RELOAD_DATALOADERS_EVERY_EPOCH = True
hparams.TRAINING.NUM_SMPLIFY_ITERS = 100 # 50
hparams.TRAINING.RUN_SMPLIFY = False
hparams.TRAINING.SMPLIFY_THRESHOLD = 100
hparams.TRAINING.DROPOUT_P = 0.2
hparams.TRAINING.TEST_BEFORE_TRAINING = False
hparams.TRAINING.SAVE_IMAGES = False
hparams.TRAINING.USE_PART_SEGM_LOSS = False
hparams.TRAINING.USE_AMP = False
# Training process hparams
hparams.TESTING = CN()
hparams.TESTING.SAVE_IMAGES = False
hparams.TESTING.SAVE_FREQ = 1
hparams.TESTING.SAVE_RESULTS = True
hparams.TESTING.SAVE_MESHES = False
hparams.TESTING.SIDEVIEW = True
hparams.TESTING.TEST_ON_TRAIN_END = True
hparams.TESTING.MULTI_SIDEVIEW = False
hparams.TESTING.USE_GT_CAM = False
# PARE method hparams
hparams.PARE = CN()
hparams.PARE.BACKBONE = 'resnet50' # hrnet_w32-conv, hrnet_w32-interp
hparams.PARE.NUM_JOINTS = 24
hparams.PARE.SOFTMAX_TEMP = 1.
hparams.PARE.NUM_FEATURES_SMPL = 64
hparams.PARE.USE_ATTENTION = False
hparams.PARE.USE_SELF_ATTENTION = False
hparams.PARE.USE_KEYPOINT_ATTENTION = False
hparams.PARE.USE_KEYPOINT_FEATURES_FOR_SMPL_REGRESSION = False
hparams.PARE.USE_POSTCONV_KEYPOINT_ATTENTION = False
hparams.PARE.KEYPOINT_ATTENTION_ACT = 'softmax'
hparams.PARE.USE_SCALE_KEYPOINT_ATTENTION = False
hparams.PARE.USE_FINAL_NONLOCAL = None
hparams.PARE.USE_BRANCH_NONLOCAL = None
hparams.PARE.USE_HMR_REGRESSION = False
hparams.PARE.USE_COATTENTION = False
hparams.PARE.NUM_COATTENTION_ITER = 1
hparams.PARE.COATTENTION_CONV = 'simple' # 'double_1', 'double_3', 'single_1', 'single_3', 'simple'
hparams.PARE.USE_UPSAMPLING = False
hparams.PARE.DECONV_CONV_KERNEL_SIZE = 4
hparams.PARE.USE_SOFT_ATTENTION = False
hparams.PARE.NUM_BRANCH_ITERATION = 0
hparams.PARE.BRANCH_DEEPER = False
hparams.PARE.NUM_DECONV_LAYERS = 3
hparams.PARE.NUM_DECONV_FILTERS = 256
hparams.PARE.USE_RESNET_CONV_HRNET = False
hparams.PARE.USE_POS_ENC = False
hparams.PARE.ITERATIVE_REGRESSION = False
hparams.PARE.ITER_RESIDUAL = False
hparams.PARE.NUM_ITERATIONS = 3
hparams.PARE.SHAPE_INPUT_TYPE = 'feats.all_pose.shape.cam'
hparams.PARE.POSE_INPUT_TYPE = 'feats.neighbor_pose_feats.all_pose.self_pose.neighbor_pose.shape.cam'
hparams.PARE.POSE_MLP_NUM_LAYERS = 1
hparams.PARE.SHAPE_MLP_NUM_LAYERS = 1
hparams.PARE.POSE_MLP_HIDDEN_SIZE = 256
hparams.PARE.SHAPE_MLP_HIDDEN_SIZE = 256
hparams.PARE.SHAPE_LOSS_WEIGHT = 0
hparams.PARE.KEYPOINT_LOSS_WEIGHT = 5.
hparams.PARE.KEYPOINT_NATIVE_LOSS_WEIGHT = 5.
hparams.PARE.HEATMAPS_LOSS_WEIGHT = 5.
hparams.PARE.SMPL_PART_LOSS_WEIGHT = 1.
hparams.PARE.PART_SEGM_LOSS_WEIGHT = 1.
hparams.PARE.POSE_LOSS_WEIGHT = 1.
hparams.PARE.BETA_LOSS_WEIGHT = 0.001
hparams.PARE.OPENPOSE_TRAIN_WEIGHT = 0.
hparams.PARE.GT_TRAIN_WEIGHT = 1.
hparams.PARE.LOSS_WEIGHT = 60.
hparams.PARE.USE_SHAPE_REG = False
hparams.PARE.USE_MEAN_CAMSHAPE = False
hparams.PARE.USE_MEAN_POSE = False
hparams.PARE.INIT_XAVIER = False
def get_hparams_defaults():
"""Get a yacs hparamsNode object with default values for my_project."""
# Return a clone so that the defaults will not be altered
# This is for the "local variable" use pattern
return hparams.clone()
def update_hparams(hparams_file):
hparams = get_hparams_defaults()
hparams.merge_from_file(hparams_file)
return hparams.clone()
def update_hparams_from_dict(cfg_dict):
hparams = get_hparams_defaults()
cfg = hparams.load_cfg(str(cfg_dict))
hparams.merge_from_other_cfg(cfg)
return hparams.clone()
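# Illustrative usage sketch (added example, not part of the original file): both helpers
# return a cloned CfgNode, so the module-level defaults stay untouched.
if __name__ == '__main__':
    cfg = update_hparams_from_dict({'PARE': {'BACKBONE': 'hrnet_w32'}, 'DATASET': {'IMG_RES': 256}})
    print(cfg.PARE.BACKBONE, cfg.DATASET.IMG_RES)  # hrnet_w32 256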

View File

@@ -0,0 +1,195 @@
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import numpy as np
# Mean and standard deviation for normalizing input image
IMG_NORM_MEAN = [0.485, 0.456, 0.406]
IMG_NORM_STD = [0.229, 0.224, 0.225]
"""
We create a superset of joints containing the OpenPose joints together with the ones that each dataset provides.
We keep a superset of 24 joints such that we include all joints from every dataset.
If a dataset doesn't provide annotations for a specific joint, we simply ignore it.
The joints used here are the following:
"""
JOINT_NAMES = [
# 25 OpenPose joints (in the order provided by OpenPose)
'OP Nose',
'OP Neck',
'OP RShoulder',
'OP RElbow',
'OP RWrist',
'OP LShoulder',
'OP LElbow',
'OP LWrist',
'OP MidHip',
'OP RHip',
'OP RKnee',
'OP RAnkle',
'OP LHip',
'OP LKnee',
'OP LAnkle',
'OP REye',
'OP LEye',
'OP REar',
'OP LEar',
'OP LBigToe',
'OP LSmallToe',
'OP LHeel',
'OP RBigToe',
'OP RSmallToe',
'OP RHeel',
# 24 Ground Truth joints (superset of joints from different datasets)
'Right Ankle',
'Right Knee',
'Right Hip',
'Left Hip',
'Left Knee',
'Left Ankle',
'Right Wrist',
'Right Elbow',
'Right Shoulder',
'Left Shoulder',
'Left Elbow',
'Left Wrist',
'Neck (LSP)',
'Top of Head (LSP)',
'Pelvis (MPII)',
'Thorax (MPII)',
'Spine (H36M)',
'Jaw (H36M)',
'Head (H36M)',
'Nose',
'Left Eye',
'Right Eye',
'Left Ear',
'Right Ear'
]
# Dict containing the joints in numerical order
JOINT_IDS = {JOINT_NAMES[i]: i for i in range(len(JOINT_NAMES))}
# Map joints to SMPL joints
JOINT_MAP = {
'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17,
'OP RElbow': 19, 'OP RWrist': 21, 'OP LShoulder': 16,
'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0,
'OP RHip': 2, 'OP RKnee': 5, 'OP RAnkle': 8,
'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7,
'OP REye': 25, 'OP LEye': 26, 'OP REar': 27,
'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30,
'OP LHeel': 31, 'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34,
'Right Ankle': 8, 'Right Knee': 5, 'Right Hip': 45,
'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7,
'Right Wrist': 21, 'Right Elbow': 19, 'Right Shoulder': 17,
'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist': 20,
'Neck (LSP)': 47, 'Top of Head (LSP)': 48,
'Pelvis (MPII)': 49, 'Thorax (MPII)': 50,
'Spine (H36M)': 51, 'Jaw (H36M)': 52,
'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26,
'Right Eye': 25, 'Left Ear': 28, 'Right Ear': 27
}
# Joint selectors
# Indices to get the 14 LSP joints from the 17 H36M joints
H36M_TO_J17 = [6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9]
H36M_TO_J14 = H36M_TO_J17[:14]
# Indices to get the 14 LSP joints from the ground truth joints
J24_TO_J17 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18, 14, 16, 17]
J24_TO_J14 = J24_TO_J17[:14]
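# Example (illustrative): given H36M-format joints of shape (N, 17, 3), the 14 LSP-style
# evaluation joints can be selected as joints_h36m[:, H36M_TO_J14, :].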
# Permutation of SMPL pose parameters when flipping the shape
SMPL_JOINTS_FLIP_PERM = [0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21, 20, 23, 22]
SMPL_POSE_FLIP_PERM = []
for i in SMPL_JOINTS_FLIP_PERM:
SMPL_POSE_FLIP_PERM.append(3*i)
SMPL_POSE_FLIP_PERM.append(3*i+1)
SMPL_POSE_FLIP_PERM.append(3*i+2)
# Permutation indices for the 24 ground truth joints
J24_FLIP_PERM = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18, 19, 21, 20, 23, 22]
# Permutation indices for the full set of 49 joints
J49_FLIP_PERM = [0, 1, 5, 6, 7, 2, 3, 4, 8, 12, 13, 14, 9, 10, 11, 16, 15, 18, 17, 22, 23, 24, 19, 20, 21]\
+ [25+i for i in J24_FLIP_PERM]
SMPLH_TO_SMPL = np.arange(0, 156).reshape((-1, 3))[
np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 37])
].reshape(-1)
pw3d_occluded_sequences = [
'courtyard_backpack',
'courtyard_basketball',
'courtyard_bodyScannerMotions',
'courtyard_box',
'courtyard_golf',
'courtyard_jacket',
'courtyard_laceShoe',
'downtown_stairs',
'flat_guitar',
'flat_packBags',
'outdoors_climbing',
'outdoors_crosscountry',
'outdoors_fencing',
'outdoors_freestyle',
'outdoors_golf',
'outdoors_parcours',
'outdoors_slalom',
]
pw3d_test_sequences = [
'flat_packBags_00',
'downtown_weeklyMarket_00',
'outdoors_fencing_01',
'downtown_walkBridge_01',
'downtown_enterShop_00',
'downtown_rampAndStairs_00',
'downtown_bar_00',
'downtown_runForBus_01',
'downtown_cafe_00',
'flat_guitar_01',
'downtown_runForBus_00',
'downtown_sitOnStairs_00',
'downtown_bus_00',
'downtown_arguing_00',
'downtown_crossStreets_00',
'downtown_walkUphill_00',
'downtown_walking_00',
'downtown_car_00',
'downtown_warmWelcome_00',
'downtown_upstairs_00',
'downtown_stairs_00',
'downtown_windowShopping_00',
'office_phoneCall_00',
'downtown_downstairs_00'
]
pw3d_cam_sequences = [
# TEST
'downtown_downstairs_00',
'downtown_stairs_00',
'downtown_rampAndStairs_00',
'flat_packBags_00',
'flat_guitar_01',
'downtown_warmWelcome_00',
'downtown_walkUphill_00',
# VALIDATION
'outdoors_parcours_01',
'outdoors_crosscountry_00',
'outdoors_freestyle_01',
'downtown_walkDownhill_00',
'outdoors_parcours_00',
]

View File

@@ -0,0 +1,4 @@
from .pare_head import PareHead
from .hmr_head import HMRHead
# from .smpl_head import SMPLHead
# from .smpl_cam_head import SMPLCamHead

View File

@@ -0,0 +1,203 @@
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F  # needed for the optional uncertainty activation in forward()
from ..config import SMPL_MEAN_PARAMS
from ..utils.geometry import rot6d_to_rotmat, rotmat_to_rot6d
BN_MOMENTUM = 0.1
class HMRHead(nn.Module):
def __init__(
self,
num_input_features,
smpl_mean_params=SMPL_MEAN_PARAMS,
estimate_var=False,
use_separate_var_branch=False,
uncertainty_activation='',
backbone='resnet50',
use_cam_feats=False,
):
super(HMRHead, self).__init__()
npose = 24 * 6
self.npose = npose
self.estimate_var = estimate_var
self.use_separate_var_branch = use_separate_var_branch
self.uncertainty_activation = uncertainty_activation
self.backbone = backbone
self.num_input_features = num_input_features
self.use_cam_feats = use_cam_feats
if use_cam_feats:
num_input_features += 7 # 6d rotmat + vfov
self.avgpool = nn.AdaptiveAvgPool2d(1) # nn.AvgPool2d(7, stride=1)
self.fc1 = nn.Linear(num_input_features + npose + 13, 1024)
self.drop1 = nn.Dropout()
self.fc2 = nn.Linear(1024, 1024)
self.drop2 = nn.Dropout()
if self.estimate_var:
# estimate variance for pose and shape parameters
if self.use_separate_var_branch:
# Decouple var estimation layer using separate linear layers
self.decpose = nn.Linear(1024, npose)
self.decshape = nn.Linear(1024, 10)
self.deccam = nn.Linear(1024, 3)
self.decpose_var = nn.Linear(1024, npose)
self.decshape_var = nn.Linear(1024, 10)
nn.init.xavier_uniform_(self.decpose_var.weight, gain=0.01)
nn.init.xavier_uniform_(self.decshape_var.weight, gain=0.01)
else:
# double the output sizes to estimate var
self.decpose = nn.Linear(1024, npose * 2)
self.decshape = nn.Linear(1024, 10 * 2)
self.deccam = nn.Linear(1024, 3)
else:
self.decpose = nn.Linear(1024, npose)
self.decshape = nn.Linear(1024, 10)
self.deccam = nn.Linear(1024, 3)
nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)
if self.backbone.startswith('hrnet'):
self.downsample_module = self._make_head()
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
mean_params = np.load(smpl_mean_params)
init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
init_shape = torch.from_numpy(mean_params['shape'][:].astype('float32')).unsqueeze(0)
init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0)
self.register_buffer('init_pose', init_pose)
self.register_buffer('init_shape', init_shape)
self.register_buffer('init_cam', init_cam)
def _make_head(self):
# downsampling modules
downsamp_modules = []
for i in range(3):
in_channels = self.num_input_features
out_channels = self.num_input_features
downsamp_module = nn.Sequential(
nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=2,
padding=1),
nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True)
)
downsamp_modules.append(downsamp_module)
downsamp_modules = nn.Sequential(*downsamp_modules)
return downsamp_modules
def forward(
self,
features,
init_pose=None,
init_shape=None,
init_cam=None,
cam_rotmat=None,
cam_vfov=None,
n_iter=3
):
# if self.backbone.startswith('hrnet'):
# features = self.downsample_module(features)
batch_size = features.shape[0]
if init_pose is None:
init_pose = self.init_pose.expand(batch_size, -1)
if init_shape is None:
init_shape = self.init_shape.expand(batch_size, -1)
if init_cam is None:
init_cam = self.init_cam.expand(batch_size, -1)
xf = self.avgpool(features)
xf = xf.view(xf.size(0), -1)
pred_pose = init_pose
pred_shape = init_shape
pred_cam = init_cam
for i in range(n_iter):
if self.use_cam_feats:
xc = torch.cat([xf, pred_pose, pred_shape, pred_cam,
rotmat_to_rot6d(cam_rotmat), cam_vfov.unsqueeze(-1)], 1)
else:
xc = torch.cat([xf, pred_pose, pred_shape, pred_cam], 1)
xc = self.fc1(xc)
xc = self.drop1(xc)
xc = self.fc2(xc)
xc = self.drop2(xc)
if self.estimate_var:
pred_pose = self.decpose(xc)[:,:self.npose] + pred_pose
pred_shape = self.decshape(xc)[:,:10] + pred_shape
pred_cam = self.deccam(xc) + pred_cam
if self.use_separate_var_branch:
pred_pose_var = self.decpose_var(xc)
pred_shape_var = self.decshape_var(xc)
else:
pred_pose_var = self.decpose(xc)[:,self.npose:]
pred_shape_var = self.decshape(xc)[:,10:]
if self.uncertainty_activation != '':
# Apply the configured torch.nn.functional activation (e.g. softplus) to the variance outputs
pred_pose_var = getattr(F, self.uncertainty_activation)(pred_pose_var)
pred_shape_var = getattr(F, self.uncertainty_activation)(pred_shape_var)
else:
pred_pose = self.decpose(xc) + pred_pose
pred_shape = self.decshape(xc) + pred_shape
pred_cam = self.deccam(xc) + pred_cam
pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3)
output = {
'pred_pose': pred_rotmat,
'pred_cam': pred_cam,
'pred_shape': pred_shape,
'pred_pose_6d': pred_pose,
}
if self.estimate_var:
output.update({
'pred_pose_var': torch.cat([pred_pose, pred_pose_var], dim=1),
'pred_shape_var': torch.cat([pred_shape, pred_shape_var], dim=1),
})
return output
def keep_variance(x, min_variance):
return x + min_variance
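# Illustrative usage sketch (added example, not part of the original file; assumes the
# SMPL_MEAN_PARAMS file exists on disk). The head average-pools the backbone feature map
# and refines pose/shape/camera for n_iter steps starting from the SMPL mean parameters.
if __name__ == '__main__':
    head = HMRHead(num_input_features=2048, backbone='resnet50')
    out = head(torch.zeros(2, 2048, 7, 7))
    print(out['pred_pose'].shape, out['pred_cam'].shape)  # (2, 24, 3, 3) and (2, 3)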

View File

@@ -0,0 +1,926 @@
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from ..config import SMPL_MEAN_PARAMS
from ..layers.coattention import CoAttention
from ..utils.geometry import rot6d_to_rotmat, get_coord_maps
from ..utils.kp_utils import get_smpl_neighbor_triplets
from ..layers.softargmax import softargmax2d, get_heatmap_preds
from ..layers import LocallyConnected2d, KeypointAttention, interpolate
from ..layers.non_local import dot_product
from ..backbone.resnet import conv3x3, conv1x1, BasicBlock
class logger:
    # Minimal stand-in for loguru's logger so this module carries no logging
    # dependency; it must absorb every level used below (info/warning/debug).
    @staticmethod
    def info(*args, **kwargs):
        pass
    @staticmethod
    def warning(*args, **kwargs):
        pass
    @staticmethod
    def debug(*args, **kwargs):
        pass
BN_MOMENTUM = 0.1
class PareHead(nn.Module):
def __init__(
self,
num_joints,
num_input_features,
softmax_temp=1.0,
num_deconv_layers=3,
num_deconv_filters=(256, 256, 256),
num_deconv_kernels=(4, 4, 4),
num_camera_params=3,
num_features_smpl=64,
final_conv_kernel=1,
iterative_regression=False,
iter_residual=False,
num_iterations=3,
shape_input_type='feats', # 'feats.pose.shape.cam'
pose_input_type='feats', # 'feats.neighbor_pose_feats.all_pose.self_pose.neighbor_pose.shape.cam'
pose_mlp_num_layers=1,
shape_mlp_num_layers=1,
pose_mlp_hidden_size=256,
shape_mlp_hidden_size=256,
use_keypoint_features_for_smpl_regression=False,
use_heatmaps='',
use_keypoint_attention=False,
use_postconv_keypoint_attention=False,
keypoint_attention_act='softmax',
use_scale_keypoint_attention=False,
use_branch_nonlocal=None, # 'concatenation', 'dot_product', 'embedded_gaussian', 'gaussian'
use_final_nonlocal=None, # 'concatenation', 'dot_product', 'embedded_gaussian', 'gaussian'
backbone='resnet',
use_hmr_regression=False,
use_coattention=False,
num_coattention_iter=1,
coattention_conv='simple', # 'double_1', 'double_3', 'single_1', 'single_3', 'simple'
use_upsampling=False,
use_soft_attention=False, # Stefan & Otmar 3DV style attention
num_branch_iteration=0,
branch_deeper=False,
use_resnet_conv_hrnet=False,
use_position_encodings=None,
use_mean_camshape=False,
use_mean_pose=False,
init_xavier=False,
):
super(PareHead, self).__init__()
self.backbone = backbone
self.num_joints = num_joints
self.deconv_with_bias = False
self.use_heatmaps = use_heatmaps
self.num_iterations = num_iterations
self.use_final_nonlocal = use_final_nonlocal
self.use_branch_nonlocal = use_branch_nonlocal
self.use_hmr_regression = use_hmr_regression
self.use_coattention = use_coattention
self.num_coattention_iter = num_coattention_iter
self.coattention_conv = coattention_conv
self.use_soft_attention = use_soft_attention
self.num_branch_iteration = num_branch_iteration
self.iter_residual = iter_residual
self.iterative_regression = iterative_regression
self.pose_mlp_num_layers = pose_mlp_num_layers
self.shape_mlp_num_layers = shape_mlp_num_layers
self.pose_mlp_hidden_size = pose_mlp_hidden_size
self.shape_mlp_hidden_size = shape_mlp_hidden_size
self.use_keypoint_attention = use_keypoint_attention
self.use_keypoint_features_for_smpl_regression = use_keypoint_features_for_smpl_regression
self.use_position_encodings = use_position_encodings
self.use_mean_camshape = use_mean_camshape
self.use_mean_pose = use_mean_pose
self.num_input_features = num_input_features
if use_soft_attention:
# These options should be True by default when soft attention is used
self.use_keypoint_features_for_smpl_regression = True
self.use_hmr_regression = True
self.use_coattention = False
logger.warning('Coattention cannot be used together with soft attention')
logger.warning('Overriding use_coattention=False')
if use_coattention:
self.use_keypoint_features_for_smpl_regression = False
logger.warning('\"use_keypoint_features_for_smpl_regression\" cannot be used together with co-attention')
logger.warning('Overriding \"use_keypoint_features_for_smpl_regression\"=False')
if use_hmr_regression:
self.iterative_regression = False
logger.warning('iterative_regression cannot be used together with hmr regression')
if self.use_heatmaps in ['part_segm', 'attention']:
logger.info('\"Keypoint Attention\" should be activated to be able to use part segmentation')
logger.info('Overriding use_keypoint_attention')
self.use_keypoint_attention = True
assert num_iterations > 0, '\"num_iterations\" should be greater than 0.'
if use_position_encodings:
assert backbone.startswith('hrnet'), 'backbone should be hrnet to use position encodings'
# self.pos_enc = get_coord_maps(size=56)
self.register_buffer('pos_enc', get_coord_maps(size=56))
num_input_features += 2
self.num_input_features = num_input_features
if backbone.startswith('hrnet'):
if use_resnet_conv_hrnet:
logger.info('Using resnet block for keypoint and smpl conv layers...')
self.keypoint_deconv_layers = self._make_res_conv_layers(
input_channels=self.num_input_features,
num_channels=num_deconv_filters[-1],
num_basic_blocks=num_deconv_layers,
)
self.num_input_features = num_input_features
self.smpl_deconv_layers = self._make_res_conv_layers(
input_channels=self.num_input_features,
num_channels=num_deconv_filters[-1],
num_basic_blocks=num_deconv_layers,
)
else:
self.keypoint_deconv_layers = self._make_conv_layer(
num_deconv_layers,
num_deconv_filters,
(3,)*num_deconv_layers,
)
self.num_input_features = num_input_features
self.smpl_deconv_layers = self._make_conv_layer(
num_deconv_layers,
num_deconv_filters,
(3,)*num_deconv_layers,
)
else:
# part branch that estimates 2d keypoints
conv_fn = self._make_upsample_layer if use_upsampling else self._make_deconv_layer
if use_upsampling:
logger.info('Upsampling is active to increase spatial dimension')
logger.info(f'Upsampling conv kernels: {num_deconv_kernels}')
self.keypoint_deconv_layers = conv_fn(
num_deconv_layers,
num_deconv_filters,
num_deconv_kernels,
)
# reset inplanes to 2048 -> final resnet layer
self.num_input_features = num_input_features
self.smpl_deconv_layers = conv_fn(
num_deconv_layers,
num_deconv_filters,
num_deconv_kernels,
)
pose_mlp_inp_dim = num_deconv_filters[-1]
smpl_final_dim = num_features_smpl
shape_mlp_inp_dim = num_joints * smpl_final_dim
if self.use_soft_attention:
logger.info('Soft attention (Stefan & Otmar 3DV) is active')
self.keypoint_final_layer = nn.Sequential(
conv3x3(num_deconv_filters[-1], 256),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
conv1x1(256, num_joints+1 if self.use_heatmaps in ('part_segm', 'part_segm_pool') else num_joints),
)
soft_att_feature_size = smpl_final_dim # if use_hmr_regression else pose_mlp_inp_dim
self.smpl_final_layer = nn.Sequential(
conv3x3(num_deconv_filters[-1], 256),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
conv1x1(256, soft_att_feature_size),
)
# pose_mlp_inp_dim = soft_att_feature_size
else:
self.keypoint_final_layer = nn.Conv2d(
in_channels=num_deconv_filters[-1],
out_channels=num_joints+1 if self.use_heatmaps in ('part_segm', 'part_segm_pool') else num_joints,
kernel_size=final_conv_kernel,
stride=1,
padding=1 if final_conv_kernel == 3 else 0,
)
self.smpl_final_layer = nn.Conv2d(
in_channels=num_deconv_filters[-1],
out_channels=smpl_final_dim,
kernel_size=final_conv_kernel,
stride=1,
padding=1 if final_conv_kernel == 3 else 0,
)
# temperature for softargmax function
self.register_buffer('temperature', torch.tensor(softmax_temp))
# if self.iterative_regression or self.num_branch_iteration > 0 or self.use_coattention:
mean_params = np.load(SMPL_MEAN_PARAMS)
init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
init_shape = torch.from_numpy(mean_params['shape'][:].astype('float32')).unsqueeze(0)
init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0)
self.register_buffer('init_pose', init_pose)
self.register_buffer('init_shape', init_shape)
self.register_buffer('init_cam', init_cam)
if self.iterative_regression:
# enable iterative regression similar to HMR
# these are the features that can be used as input to final MLPs
input_type_dim = {
'feats': 0, # image features for self
'neighbor_pose_feats': 2 * 256, # image features from neighbor joints
'all_pose': 24 * 6, # rot6d of all joints from previous iter
'self_pose': 6, # rot6d of self
'neighbor_pose': 2 * 6, # rot6d of neighbor joints from previous iter
'shape': 10, # smpl betas/shape
'cam': num_camera_params, # weak perspective camera
}
assert 'feats' in shape_input_type, '\"feats\" should be the default value'
assert 'feats' in pose_input_type, '\"feats\" should be the default value'
self.shape_input_type = shape_input_type.split('.')
self.pose_input_type = pose_input_type.split('.')
pose_mlp_inp_dim = pose_mlp_inp_dim + sum([input_type_dim[x] for x in self.pose_input_type])
shape_mlp_inp_dim = shape_mlp_inp_dim + sum([input_type_dim[x] for x in self.shape_input_type])
logger.debug(f'Shape MLP takes \"{self.shape_input_type}\" as input, '
f'input dim: {shape_mlp_inp_dim}')
logger.debug(f'Pose MLP takes \"{self.pose_input_type}\" as input, '
f'input dim: {pose_mlp_inp_dim}')
self.pose_mlp_inp_dim = pose_mlp_inp_dim
self.shape_mlp_inp_dim = shape_mlp_inp_dim
if self.use_hmr_regression:
logger.info(f'HMR regression is active...')
# enable iterative regression similar to HMR
self.fc1 = nn.Linear(num_joints * smpl_final_dim + (num_joints * 6) + 10 + num_camera_params, 1024)
self.drop1 = nn.Dropout()
self.fc2 = nn.Linear(1024, 1024)
self.drop2 = nn.Dropout()
self.decpose = nn.Linear(1024, (num_joints * 6))
self.decshape = nn.Linear(1024, 10)
self.deccam = nn.Linear(1024, num_camera_params)
nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)
else:
# here we use 2 different MLPs to estimate shape and camera
# They take a channelwise downsampled version of smpl features
self.shape_mlp = self._get_shape_mlp(output_size=10)
self.cam_mlp = self._get_shape_mlp(output_size=num_camera_params)
# for pose each joint has a separate MLP
# weights for these MLPs are not shared
# hence we use Locally Connected layers
# TODO support kernel_size > 1 to access context of other joints
self.pose_mlp = self._get_pose_mlp(num_joints=num_joints, output_size=6)
if init_xavier:
nn.init.xavier_uniform_(self.shape_mlp.weight, gain=0.01)
nn.init.xavier_uniform_(self.cam_mlp.weight, gain=0.01)
nn.init.xavier_uniform_(self.pose_mlp.weight, gain=0.01)
if self.use_branch_nonlocal:
logger.info(f'Branch nonlocal is active, type {self.use_branch_nonlocal}')
self.branch_2d_nonlocal = eval(self.use_branch_nonlocal).NONLocalBlock2D(
in_channels=num_deconv_filters[-1],
sub_sample=False,
bn_layer=True,
)
self.branch_3d_nonlocal = eval(self.use_branch_nonlocal).NONLocalBlock2D(
in_channels=num_deconv_filters[-1],
sub_sample=False,
bn_layer=True,
)
if self.use_final_nonlocal:
logger.info(f'Final nonlocal is active, type {self.use_final_nonlocal}')
self.final_pose_nonlocal = eval(self.use_final_nonlocal).NONLocalBlock1D(
in_channels=self.pose_mlp_inp_dim,
sub_sample=False,
bn_layer=True,
)
self.final_shape_nonlocal = eval(self.use_final_nonlocal).NONLocalBlock1D(
in_channels=num_features_smpl,
sub_sample=False,
bn_layer=True,
)
if self.use_keypoint_attention:
logger.info('Keypoint attention is active')
self.keypoint_attention = KeypointAttention(
use_conv=use_postconv_keypoint_attention,
in_channels=(self.pose_mlp_inp_dim, smpl_final_dim),
out_channels=(self.pose_mlp_inp_dim, smpl_final_dim),
act=keypoint_attention_act,
use_scale=use_scale_keypoint_attention,
)
if self.use_coattention:
logger.info(f'Coattention is active, final conv type {self.coattention_conv}')
self.coattention = CoAttention(n_channel=num_deconv_filters[-1], final_conv=self.coattention_conv)
if self.num_branch_iteration > 0:
logger.info(f'Branch iteration is active')
if branch_deeper:
self.branch_iter_2d_nonlocal = nn.Sequential(
conv3x3(num_deconv_filters[-1], 256),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
dot_product.NONLocalBlock2D(
in_channels=num_deconv_filters[-1],
sub_sample=False,
bn_layer=True,
)
)
self.branch_iter_3d_nonlocal = nn.Sequential(
conv3x3(num_deconv_filters[-1], 256),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
dot_product.NONLocalBlock2D(
in_channels=num_deconv_filters[-1],
sub_sample=False,
bn_layer=True,
)
)
else:
self.branch_iter_2d_nonlocal = dot_product.NONLocalBlock2D(
in_channels=num_deconv_filters[-1],
sub_sample=False,
bn_layer=True,
)
self.branch_iter_3d_nonlocal = dot_product.NONLocalBlock2D(
in_channels=num_deconv_filters[-1],
sub_sample=False,
bn_layer=True,
)
def _get_shape_mlp(self, output_size):
if self.shape_mlp_num_layers == 1:
return nn.Linear(self.shape_mlp_inp_dim, output_size)
module_list = []
for i in range(self.shape_mlp_num_layers):
if i == 0:
module_list.append(
nn.Linear(self.shape_mlp_inp_dim, self.shape_mlp_hidden_size)
)
elif i == self.shape_mlp_num_layers - 1:
module_list.append(
nn.Linear(self.shape_mlp_hidden_size, output_size)
)
else:
module_list.append(
nn.Linear(self.shape_mlp_hidden_size, self.shape_mlp_hidden_size)
)
return nn.Sequential(*module_list)
def _get_pose_mlp(self, num_joints, output_size):
if self.pose_mlp_num_layers == 1:
return LocallyConnected2d(
in_channels=self.pose_mlp_inp_dim,
out_channels=output_size,
output_size=[num_joints, 1],
kernel_size=1,
stride=1,
)
module_list = []
for i in range(self.pose_mlp_num_layers):
if i == 0:
module_list.append(
LocallyConnected2d(
in_channels=self.pose_mlp_inp_dim,
out_channels=self.pose_mlp_hidden_size,
output_size=[num_joints, 1],
kernel_size=1,
stride=1,
)
)
elif i == self.pose_mlp_num_layers - 1:
module_list.append(
LocallyConnected2d(
in_channels=self.pose_mlp_hidden_size,
out_channels=output_size,
output_size=[num_joints, 1],
kernel_size=1,
stride=1,
)
)
else:
module_list.append(
LocallyConnected2d(
in_channels=self.pose_mlp_hidden_size,
out_channels=self.pose_mlp_hidden_size,
output_size=[num_joints, 1],
kernel_size=1,
stride=1,
)
)
return nn.Sequential(*module_list)
    def _get_deconv_cfg(self, deconv_kernel):
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(f'Unsupported deconv kernel size: {deconv_kernel}')
        return deconv_kernel, padding, output_padding
def _make_conv_layer(self, num_layers, num_filters, num_kernels):
        assert num_layers == len(num_filters), \
            'ERROR: num_conv_layers is different from len(num_conv_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_conv_layers is different from len(num_conv_kernels)'
layers = []
for i in range(num_layers):
kernel, padding, output_padding = \
self._get_deconv_cfg(num_kernels[i])
planes = num_filters[i]
layers.append(
nn.Conv2d(
in_channels=self.num_input_features,
out_channels=planes,
kernel_size=kernel,
stride=1,
padding=padding,
bias=self.deconv_with_bias))
layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
layers.append(nn.ReLU(inplace=True))
self.num_input_features = planes
return nn.Sequential(*layers)
def _make_res_conv_layers(self, input_channels, num_channels=64,
num_heads=1, num_basic_blocks=2):
head_layers = []
# kernel_sizes, strides, paddings = self._get_trans_cfg()
# for kernel_size, padding, stride in zip(kernel_sizes, paddings, strides):
head_layers.append(nn.Sequential(
nn.Conv2d(
in_channels=input_channels,
out_channels=num_channels,
kernel_size=3,
stride=1,
padding=1
),
nn.BatchNorm2d(num_channels, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True))
)
for i in range(num_heads):
layers = []
for _ in range(num_basic_blocks):
layers.append(nn.Sequential(BasicBlock(num_channels, num_channels)))
head_layers.append(nn.Sequential(*layers))
# head_layers.append(nn.Conv2d(in_channels=num_channels, out_channels=output_channels,
# kernel_size=1, stride=1, padding=0))
return nn.Sequential(*head_layers)
def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        assert num_layers == len(num_filters), \
            'ERROR: num_deconv_layers is different from len(num_deconv_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_deconv_layers is different from len(num_deconv_kernels)'
layers = []
for i in range(num_layers):
kernel, padding, output_padding = \
self._get_deconv_cfg(num_kernels[i])
planes = num_filters[i]
layers.append(
nn.ConvTranspose2d(
in_channels=self.num_input_features,
out_channels=planes,
kernel_size=kernel,
stride=2,
padding=padding,
output_padding=output_padding,
bias=self.deconv_with_bias))
layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
layers.append(nn.ReLU(inplace=True))
# if self.use_self_attention:
# layers.append(SelfAttention(planes))
self.num_input_features = planes
return nn.Sequential(*layers)
def _make_upsample_layer(self, num_layers, num_filters, num_kernels):
        assert num_layers == len(num_filters), \
            'ERROR: num_layers is different from len(num_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_layers is different from len(num_kernels)'
layers = []
for i in range(num_layers):
kernel, padding, output_padding = \
self._get_deconv_cfg(num_kernels[i])
planes = num_filters[i]
layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
layers.append(
nn.Conv2d(in_channels=self.num_input_features, out_channels=planes,
kernel_size=kernel, stride=1, padding=padding, bias=self.deconv_with_bias)
)
layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
layers.append(nn.ReLU(inplace=True))
# if self.use_self_attention:
# layers.append(SelfAttention(planes))
self.num_input_features = planes
return nn.Sequential(*layers)
def _prepare_pose_mlp_inp(self, feats, pred_pose, pred_shape, pred_cam):
# feats shape: [N, 256, J, 1]
# pose shape: [N, 6, J, 1]
# cam shape: [N, 3]
# beta shape: [N, 10]
batch_size, num_joints = pred_pose.shape[0], pred_pose.shape[2]
joint_triplets = get_smpl_neighbor_triplets()
inp_list = []
for inp_type in self.pose_input_type:
if inp_type == 'feats':
# add image features
inp_list.append(feats)
if inp_type == 'neighbor_pose_feats':
# add the image features from neighboring joints
n_pose_feat = []
for jt in joint_triplets:
n_pose_feat.append(
feats[:, :, jt[1:]].reshape(batch_size, -1, 1).unsqueeze(-2)
)
n_pose_feat = torch.cat(n_pose_feat, 2)
inp_list.append(n_pose_feat)
if inp_type == 'self_pose':
# add image features
inp_list.append(pred_pose)
if inp_type == 'all_pose':
                # append all of the joint angles from the previous iteration
all_pose = pred_pose.reshape(batch_size, -1, 1)[..., None].repeat(1, 1, num_joints, 1)
inp_list.append(all_pose)
if inp_type == 'neighbor_pose':
# append only the joint angles of neighboring ones
n_pose = []
for jt in joint_triplets:
n_pose.append(
pred_pose[:,:,jt[1:]].reshape(batch_size, -1, 1).unsqueeze(-2)
)
n_pose = torch.cat(n_pose, 2)
inp_list.append(n_pose)
if inp_type == 'shape':
# append shape predictions
pred_shape = pred_shape[..., None, None].repeat(1, 1, num_joints, 1)
inp_list.append(pred_shape)
if inp_type == 'cam':
# append camera predictions
pred_cam = pred_cam[..., None, None].repeat(1, 1, num_joints, 1)
inp_list.append(pred_cam)
assert len(inp_list) > 0
# for i,inp in enumerate(inp_list):
# print(i, inp.shape)
return torch.cat(inp_list, 1)
def _prepare_shape_mlp_inp(self, feats, pred_pose, pred_shape, pred_cam):
# feats shape: [N, 256, J, 1]
# pose shape: [N, 6, J, 1]
# cam shape: [N, 3]
# beta shape: [N, 10]
batch_size, num_joints = pred_pose.shape[:2]
inp_list = []
for inp_type in self.shape_input_type:
if inp_type == 'feats':
# add image features
inp_list.append(feats)
if inp_type == 'all_pose':
                # append all of the joint angles from the previous iteration
pred_pose = pred_pose.reshape(batch_size, -1)
inp_list.append(pred_pose)
if inp_type == 'shape':
# append shape predictions
inp_list.append(pred_shape)
if inp_type == 'cam':
# append camera predictions
inp_list.append(pred_cam)
assert len(inp_list) > 0
return torch.cat(inp_list, 1)
def forward(self, features, gt_segm=None):
batch_size = features.shape[0]
init_pose = self.init_pose.expand(batch_size, -1) # N, Jx6
init_shape = self.init_shape.expand(batch_size, -1)
init_cam = self.init_cam.expand(batch_size, -1)
if self.use_position_encodings:
features = torch.cat((features, self.pos_enc.repeat(features.shape[0], 1, 1, 1)), 1)
output = {}
############## 2D PART BRANCH FEATURES ##############
part_feats = self._get_2d_branch_feats(features)
############## GET PART ATTENTION MAP ##############
part_attention = self._get_part_attention_map(part_feats, output)
############## 3D SMPL BRANCH FEATURES ##############
smpl_feats = self._get_3d_smpl_feats(features, part_feats)
############## SAMPLE LOCAL FEATURES ##############
if gt_segm is not None:
# logger.debug(gt_segm.shape)
# import IPython; IPython.embed(); exit()
gt_segm = F.interpolate(gt_segm.unsqueeze(1).float(), scale_factor=(1/4, 1/4), mode='nearest').long().squeeze(1)
part_attention = F.one_hot(gt_segm.to('cpu'), num_classes=self.num_joints + 1).permute(0,3,1,2).float()[:,1:,:,:]
part_attention = part_attention.to('cuda')
# part_attention = F.interpolate(part_attention, scale_factor=1/4, mode='bilinear', align_corners=True)
# import IPython; IPython.embed(); exit()
point_local_feat, cam_shape_feats = self._get_local_feats(smpl_feats, part_attention, output)
############## GET FINAL PREDICTIONS ##############
pred_pose, pred_shape, pred_cam = self._get_final_preds(
point_local_feat, cam_shape_feats, init_pose, init_shape, init_cam
)
if self.use_coattention:
for c in range(self.num_coattention_iter):
smpl_feats, part_feats = self.coattention(smpl_feats, part_feats)
part_attention = self._get_part_attention_map(part_feats, output)
point_local_feat, cam_shape_feats = self._get_local_feats(smpl_feats, part_attention, output)
pred_pose, pred_shape, pred_cam = self._get_final_preds(
point_local_feat, cam_shape_feats, pred_pose, pred_shape, pred_cam
)
if self.num_branch_iteration > 0:
for nbi in range(self.num_branch_iteration):
if self.use_soft_attention:
smpl_feats = self.branch_iter_3d_nonlocal(smpl_feats)
part_feats = self.branch_iter_2d_nonlocal(part_feats)
else:
smpl_feats = self.branch_iter_3d_nonlocal(smpl_feats)
part_feats = smpl_feats
part_attention = self._get_part_attention_map(part_feats, output)
point_local_feat, cam_shape_feats = self._get_local_feats(smpl_feats, part_attention, output)
pred_pose, pred_shape, pred_cam = self._get_final_preds(
point_local_feat, cam_shape_feats, pred_pose, pred_shape, pred_cam,
)
pred_rotmat = rot6d_to_rotmat(pred_pose).reshape(batch_size, 24, 3, 3)
output.update({
'pred_pose': pred_rotmat,
'pred_cam': pred_cam,
'pred_shape': pred_shape,
})
return output
def _get_local_feats(self, smpl_feats, part_attention, output):
cam_shape_feats = self.smpl_final_layer(smpl_feats)
if self.use_keypoint_attention:
point_local_feat = self.keypoint_attention(smpl_feats, part_attention)
cam_shape_feats = self.keypoint_attention(cam_shape_feats, part_attention)
else:
point_local_feat = interpolate(smpl_feats, output['pred_kp2d'])
cam_shape_feats = interpolate(cam_shape_feats, output['pred_kp2d'])
return point_local_feat, cam_shape_feats
def _get_2d_branch_feats(self, features):
part_feats = self.keypoint_deconv_layers(features)
if self.use_branch_nonlocal:
part_feats = self.branch_2d_nonlocal(part_feats)
return part_feats
def _get_3d_smpl_feats(self, features, part_feats):
if self.use_keypoint_features_for_smpl_regression:
smpl_feats = part_feats
else:
smpl_feats = self.smpl_deconv_layers(features)
if self.use_branch_nonlocal:
smpl_feats = self.branch_3d_nonlocal(smpl_feats)
return smpl_feats
def _get_part_attention_map(self, part_feats, output):
heatmaps = self.keypoint_final_layer(part_feats)
if self.use_heatmaps == 'hm':
# returns coords between [-1,1]
pred_kp2d, confidence = get_heatmap_preds(heatmaps)
output['pred_kp2d'] = pred_kp2d
output['pred_kp2d_conf'] = confidence
output['pred_heatmaps_2d'] = heatmaps
elif self.use_heatmaps == 'hm_soft':
pred_kp2d, _ = softargmax2d(heatmaps, self.temperature)
output['pred_kp2d'] = pred_kp2d
output['pred_heatmaps_2d'] = heatmaps
elif self.use_heatmaps == 'part_segm':
output['pred_segm_mask'] = heatmaps
heatmaps = heatmaps[:,1:,:,:] # remove the first channel which encodes the background
elif self.use_heatmaps == 'part_segm_pool':
output['pred_segm_mask'] = heatmaps
heatmaps = heatmaps[:,1:,:,:] # remove the first channel which encodes the background
pred_kp2d, _ = softargmax2d(heatmaps, self.temperature) # get_heatmap_preds(heatmaps)
output['pred_kp2d'] = pred_kp2d
for k, v in output.items():
if torch.any(torch.isnan(v)):
logger.debug(f'{k} is Nan!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
if torch.any(torch.isinf(v)):
logger.debug(f'{k} is Inf!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
# if torch.any(torch.isnan(pred_kp2d)):
# print('pred_kp2d nan', pred_kp2d.min(), pred_kp2d.max())
# if torch.any(torch.isnan(heatmaps)):
# print('heatmap nan', heatmaps.min(), heatmaps.max())
#
# if torch.any(torch.isinf(pred_kp2d)):
# print('pred_kp2d inf', pred_kp2d.min(), pred_kp2d.max())
# if torch.any(torch.isinf(heatmaps)):
# print('heatmap inf', heatmaps.min(), heatmaps.max())
elif self.use_heatmaps == 'attention':
output['pred_attention'] = heatmaps
else:
# returns coords between [-1,1]
pred_kp2d, _ = softargmax2d(heatmaps, self.temperature)
output['pred_kp2d'] = pred_kp2d
output['pred_heatmaps_2d'] = heatmaps
return heatmaps
def _get_final_preds(self, pose_feats, cam_shape_feats, init_pose, init_shape, init_cam):
if self.use_hmr_regression:
return self._hmr_get_final_preds(cam_shape_feats, init_pose, init_shape, init_cam)
else:
return self._pare_get_final_preds(pose_feats, cam_shape_feats, init_pose, init_shape, init_cam)
def _hmr_get_final_preds(self, cam_shape_feats, init_pose, init_shape, init_cam):
if self.use_final_nonlocal:
cam_shape_feats = self.final_shape_nonlocal(cam_shape_feats)
xf = torch.flatten(cam_shape_feats, start_dim=1)
pred_pose = init_pose
pred_shape = init_shape
pred_cam = init_cam
for i in range(3):
xc = torch.cat([xf, pred_pose, pred_shape, pred_cam], 1)
xc = self.fc1(xc)
xc = self.drop1(xc)
xc = self.fc2(xc)
xc = self.drop2(xc)
pred_pose = self.decpose(xc) + pred_pose
pred_shape = self.decshape(xc) + pred_shape
pred_cam = self.deccam(xc) + pred_cam
return pred_pose, pred_shape, pred_cam
def _pare_get_final_preds(self, pose_feats, cam_shape_feats, init_pose, init_shape, init_cam):
        pose_feats = pose_feats.unsqueeze(-1)  # [N, C, J] -> [N, C, J, 1]
if init_pose.shape[-1] == 6:
# This means init_pose comes from a previous iteration
init_pose = init_pose.transpose(2,1).unsqueeze(-1)
else:
# This means init pose comes from mean pose
init_pose = init_pose.reshape(init_pose.shape[0], 6, -1).unsqueeze(-1)
if self.iterative_regression:
shape_feats = torch.flatten(cam_shape_feats, start_dim=1)
pred_pose = init_pose # [N, 6, J, 1]
pred_cam = init_cam # [N, 3]
pred_shape = init_shape # [N, 10]
# import IPython; IPython.embed(); exit(1)
for i in range(self.num_iterations):
# pose_feats shape: [N, 256, 24, 1]
# shape_feats shape: [N, 24*64]
pose_mlp_inp = self._prepare_pose_mlp_inp(pose_feats, pred_pose, pred_shape, pred_cam)
shape_mlp_inp = self._prepare_shape_mlp_inp(shape_feats, pred_pose, pred_shape, pred_cam)
# print('pose_mlp_inp', pose_mlp_inp.shape)
# print('shape_mlp_inp', shape_mlp_inp.shape)
# TODO: this does not work but let it go since we dont use iterative regression for now.
# if self.use_final_nonlocal:
# pose_mlp_inp = self.final_pose_nonlocal(pose_mlp_inp)
# shape_mlp_inp = self.final_shape_nonlocal(shape_mlp_inp)
if self.iter_residual:
pred_pose = self.pose_mlp(pose_mlp_inp) + pred_pose
pred_cam = self.cam_mlp(shape_mlp_inp) + pred_cam
pred_shape = self.shape_mlp(shape_mlp_inp) + pred_shape
else:
pred_pose = self.pose_mlp(pose_mlp_inp)
pred_cam = self.cam_mlp(shape_mlp_inp)
pred_shape = self.shape_mlp(shape_mlp_inp) + init_shape
else:
shape_feats = cam_shape_feats
if self.use_final_nonlocal:
pose_feats = self.final_pose_nonlocal(pose_feats.squeeze(-1)).unsqueeze(-1)
shape_feats = self.final_shape_nonlocal(shape_feats)
shape_feats = torch.flatten(shape_feats, start_dim=1)
pred_pose = self.pose_mlp(pose_feats)
pred_cam = self.cam_mlp(shape_feats)
pred_shape = self.shape_mlp(shape_feats)
if self.use_mean_camshape:
pred_cam = pred_cam + init_cam
pred_shape = pred_shape + init_shape
if self.use_mean_pose:
pred_pose = pred_pose + init_pose
pred_pose = pred_pose.squeeze(-1).transpose(2, 1) # N, J, 6
return pred_pose, pred_shape, pred_cam
def forward_pretraining(self, features):
# TODO: implement pretraining
kp_feats = self.keypoint_deconv_layers(features)
heatmaps = self.keypoint_final_layer(kp_feats)
output = {}
if self.use_heatmaps == 'hm':
# returns coords between [-1,1]
pred_kp2d, confidence = get_heatmap_preds(heatmaps)
output['pred_kp2d'] = pred_kp2d
output['pred_kp2d_conf'] = confidence
elif self.use_heatmaps == 'hm_soft':
pred_kp2d, _ = softargmax2d(heatmaps, self.temperature)
output['pred_kp2d'] = pred_kp2d
else:
# returns coords between [-1,1]
pred_kp2d, _ = softargmax2d(heatmaps, self.temperature)
output['pred_kp2d'] = pred_kp2d
if self.use_keypoint_features_for_smpl_regression:
smpl_feats = kp_feats
else:
smpl_feats = self.smpl_deconv_layers(features)
cam_shape_feats = self.smpl_final_layer(smpl_feats)
output.update({
'kp_feats': heatmaps,
'heatmaps': heatmaps,
'smpl_feats': smpl_feats,
'cam_shape_feats': cam_shape_feats,
})
return output
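
if __name__ == '__main__':
    # Hypothetical smoke test for PareHead with its default (ResNet-style)
    # configuration. It assumes the SMPL mean-parameter file referenced by
    # SMPL_MEAN_PARAMS is available on disk; the feature-map shape
    # [N, 2048, 7, 7] mimics a ResNet-50 backbone output and is an assumption.
    head = PareHead(num_joints=24, num_input_features=2048, backbone='resnet')
    feats = torch.randn(2, 2048, 7, 7)
    out = head(feats)
    for k, v in out.items():
        print(k, v.shape)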

View File

@ -0,0 +1,133 @@
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import torch
import torch.nn as nn
from .. import config
from .smpl_head import SMPL
class SMPLCamHead(nn.Module):
def __init__(self, img_res=224):
super(SMPLCamHead, self).__init__()
self.smpl = SMPL(config.SMPL_MODEL_DIR, create_transl=False)
self.add_module('smpl', self.smpl)
self.img_res = img_res
def forward(self, rotmat, shape, cam, cam_rotmat, cam_intrinsics,
bbox_scale, bbox_center, img_w, img_h, normalize_joints2d=False):
        '''
        :param rotmat: rotation matrices (N,J,3,3)
        :param shape: smpl betas
        :param cam: weak perspective camera
        :param normalize_joints2d: bool, normalize joints between -1, 1 if true
        :param cam_rotmat: (N,3,3) camera rotation matrix
        :param cam_intrinsics: (N,3,3) camera intrinsics matrix
        :param bbox_scale: (N,) bbox height normalized by 200
        :param bbox_center: (N,2) bbox center
        :param img_w: (N,) original image width
        :param img_h: (N,) original image height
        :return: dict with keys 'smpl_vertices', 'smpl_joints3d', 'smpl_joints2d' and 'pred_cam_t'
        '''
smpl_output = self.smpl(
betas=shape,
body_pose=rotmat[:, 1:].contiguous(),
global_orient=rotmat[:, 0].unsqueeze(1).contiguous(),
pose2rot=False,
)
output = {
'smpl_vertices': smpl_output.vertices,
'smpl_joints3d': smpl_output.joints,
}
joints3d = smpl_output.joints
cam_t = convert_pare_to_full_img_cam(
pare_cam=cam,
bbox_height=bbox_scale * 200.,
bbox_center=bbox_center,
img_w=img_w,
img_h=img_h,
focal_length=cam_intrinsics[:, 0, 0],
crop_res=self.img_res,
)
joints2d = perspective_projection(
joints3d,
rotation=cam_rotmat,
translation=cam_t,
cam_intrinsics=cam_intrinsics,
)
# logger.debug(f'PARE cam: {cam}')
# logger.debug(f'FIMG cam: {cam_t}')
# logger.debug(f'joints2d: {joints2d}')
if normalize_joints2d:
# Normalize keypoints to [-1,1]
joints2d = joints2d / (self.img_res / 2.)
output['smpl_joints2d'] = joints2d
output['pred_cam_t'] = cam_t
return output
def perspective_projection(points, rotation, translation, cam_intrinsics):
"""
This function computes the perspective projection of a set of points.
Input:
points (bs, N, 3): 3D points
rotation (bs, 3, 3): Camera rotation
translation (bs, 3): Camera translation
cam_intrinsics (bs, 3, 3): Camera intrinsics
"""
K = cam_intrinsics
# Transform points
points = torch.einsum('bij,bkj->bki', rotation, points)
points = points + translation.unsqueeze(1)
# Apply perspective distortion
projected_points = points / points[:,:,-1].unsqueeze(-1)
# Apply camera intrinsics
projected_points = torch.einsum('bij,bkj->bki', K, projected_points.float())
return projected_points[:, :, :-1]
def convert_pare_to_full_img_cam(
pare_cam, bbox_height, bbox_center,
img_w, img_h, focal_length, crop_res=224):
# Converts weak perspective camera estimated by PARE in
# bbox coords to perspective camera in full image coordinates
# from https://arxiv.org/pdf/2009.06549.pdf
s, tx, ty = pare_cam[:, 0], pare_cam[:, 1], pare_cam[:, 2]
res = 224
r = bbox_height / res
tz = 2 * focal_length / (r * res * s)
cx = 2 * (bbox_center[:, 0] - (img_w / 2.)) / (s * bbox_height)
cy = 2 * (bbox_center[:, 1] - (img_h / 2.)) / (s * bbox_height)
cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1)
return cam_t
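
if __name__ == '__main__':
    # Hypothetical example of lifting a PARE weak-perspective camera to a
    # full-image perspective camera and projecting 3D joints with it. All
    # numbers (focal length, bbox, image size, joints) are made up.
    batch = 2
    pare_cam = torch.tensor([[0.9, 0.05, -0.02]]).repeat(batch, 1)    # (s, tx, ty)
    bbox_center = torch.tensor([[960., 540.]]).repeat(batch, 1)
    bbox_scale = torch.full((batch,), 1.2)                            # bbox height / 200
    img_w = torch.full((batch,), 1920.)
    img_h = torch.full((batch,), 1080.)
    focal = torch.full((batch,), 1500.)

    cam_t = convert_pare_to_full_img_cam(
        pare_cam, bbox_height=bbox_scale * 200., bbox_center=bbox_center,
        img_w=img_w, img_h=img_h, focal_length=focal)

    K = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1)
    K[:, 0, 0], K[:, 1, 1] = focal, focal
    K[:, 0, 2], K[:, 1, 2] = img_w / 2., img_h / 2.
    joints3d = torch.randn(batch, 49, 3) * 0.5                        # dummy joints
    joints2d = perspective_projection(
        joints3d, rotation=torch.eye(3).unsqueeze(0).repeat(batch, 1, 1),
        translation=cam_t, cam_intrinsics=K)
    print(cam_t.shape, joints2d.shape)                                # (2, 3), (2, 49, 2)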

View File

@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import torch
import numpy as np
import torch.nn as nn
from smplx import SMPL as _SMPL
from smplx.utils import SMPLOutput
from smplx.lbs import vertices2joints
from .. import config, constants
from ..utils.geometry import perspective_projection, convert_weak_perspective_to_perspective
class SMPL(_SMPL):
""" Extension of the official SMPL implementation to support more joints """
def __init__(self, *args, **kwargs):
super(SMPL, self).__init__(*args, **kwargs)
joints = [constants.JOINT_MAP[i] for i in constants.JOINT_NAMES]
J_regressor_extra = np.load(config.JOINT_REGRESSOR_TRAIN_EXTRA)
self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32))
self.joint_map = torch.tensor(joints, dtype=torch.long)
def forward(self, *args, **kwargs):
kwargs['get_skin'] = True
smpl_output = super(SMPL, self).forward(*args, **kwargs)
extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices)
joints = torch.cat([smpl_output.joints, extra_joints], dim=1)
joints = joints[:, self.joint_map, :]
output = SMPLOutput(vertices=smpl_output.vertices,
global_orient=smpl_output.global_orient,
body_pose=smpl_output.body_pose,
joints=joints,
betas=smpl_output.betas,
full_pose=smpl_output.full_pose)
return output
class SMPLHead(nn.Module):
def __init__(self, focal_length=5000., img_res=224):
super(SMPLHead, self).__init__()
self.smpl = SMPL(config.SMPL_MODEL_DIR, create_transl=False)
self.add_module('smpl', self.smpl)
self.focal_length = focal_length
self.img_res = img_res
def forward(self, rotmat, shape, cam=None, normalize_joints2d=False):
        '''
        :param rotmat: rotation matrices (N,J,3,3)
        :param shape: smpl betas
        :param cam: weak perspective camera
        :param normalize_joints2d: bool, normalize joints between -1, 1 if true
        :return: dict with keys 'smpl_vertices' and 'smpl_joints3d', plus
                 'smpl_joints2d' and 'pred_cam_t' when cam is given
        '''
smpl_output = self.smpl(
betas=shape,
body_pose=rotmat[:, 1:].contiguous(),
global_orient=rotmat[:, 0].unsqueeze(1).contiguous(),
pose2rot=False,
)
output = {
'smpl_vertices': smpl_output.vertices,
'smpl_joints3d': smpl_output.joints,
}
if cam is not None:
joints3d = smpl_output.joints
batch_size = joints3d.shape[0]
device = joints3d.device
cam_t = convert_weak_perspective_to_perspective(
cam,
focal_length=self.focal_length,
img_res=self.img_res,
)
joints2d = perspective_projection(
joints3d,
rotation=torch.eye(3, device=device).unsqueeze(0).expand(batch_size, -1, -1),
translation=cam_t,
focal_length=self.focal_length,
camera_center=torch.zeros(batch_size, 2, device=device)
)
if normalize_joints2d:
# Normalize keypoints to [-1,1]
joints2d = joints2d / (self.img_res / 2.)
output['smpl_joints2d'] = joints2d
output['pred_cam_t'] = cam_t
return output
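
if __name__ == '__main__':
    # Hypothetical smoke test; it assumes the SMPL model files referenced by
    # config.SMPL_MODEL_DIR are available locally. Identity rotations and zero
    # betas correspond to the mean SMPL body.
    head = SMPLHead(focal_length=5000., img_res=224)
    batch = 2
    rotmat = torch.eye(3).reshape(1, 1, 3, 3).repeat(batch, 24, 1, 1)   # identity pose
    betas = torch.zeros(batch, 10)
    cam = torch.tensor([[0.9, 0., 0.]]).repeat(batch, 1)                # weak perspective (s, tx, ty)
    out = head(rotmat, betas, cam=cam, normalize_joints2d=True)
    print(out['smpl_vertices'].shape, out['smpl_joints3d'].shape, out['smpl_joints2d'].shape)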

View File

@ -0,0 +1,4 @@
from .locallyconnected2d import LocallyConnected2d
from .interpolate import interpolate
from .nonlocalattention import NonLocalAttention
from .keypoint_attention import KeypointAttention

View File

@ -0,0 +1,126 @@
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..backbone.resnet import conv1x1, conv3x3
class CoAttention(nn.Module):
def __init__(
self,
n_channel,
final_conv='simple', # 'double_1', 'double_3', 'single_1', 'single_3', 'simple'
):
super(CoAttention, self).__init__()
self.linear_e = nn.Linear(n_channel, n_channel, bias=False)
self.channel = n_channel
# self.dim = all_dim
self.gate = nn.Conv2d(n_channel, 1, kernel_size=1, bias=False)
self.gate_s = nn.Sigmoid()
self.softmax = nn.Sigmoid()
if final_conv.startswith('double'):
kernel_size = int(final_conv[-1])
conv = conv1x1 if kernel_size == 1 else conv3x3
self.final_conv_1 = nn.Sequential(
conv(n_channel * 2, n_channel),
nn.BatchNorm2d(n_channel),
nn.ReLU(inplace=True),
conv(n_channel, n_channel),
nn.BatchNorm2d(n_channel),
nn.ReLU(inplace=True),
)
self.final_conv_2 = nn.Sequential(
conv(n_channel * 2, n_channel),
nn.BatchNorm2d(n_channel),
nn.ReLU(inplace=True),
conv(n_channel, n_channel),
nn.BatchNorm2d(n_channel),
nn.ReLU(inplace=True),
)
elif final_conv.startswith('single'):
kernel_size = int(final_conv[-1])
conv = conv1x1 if kernel_size == 1 else conv3x3
self.final_conv_1 = nn.Sequential(
conv(n_channel*2, n_channel),
nn.BatchNorm2d(n_channel),
nn.ReLU(inplace=True),
)
self.final_conv_2 = nn.Sequential(
conv(n_channel*2, n_channel),
nn.BatchNorm2d(n_channel),
nn.ReLU(inplace=True),
)
elif final_conv == 'simple':
self.final_conv_1 = conv1x1(n_channel * 2, n_channel)
self.final_conv_2 = conv1x1(n_channel * 2, n_channel)
for m in self.modules():
if isinstance(m, nn.Conv2d):
# n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, 0.01)
# init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
# init.xavier_normal(m.weight.data)
# m.bias.data.fill_(0)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def forward(self, input_1, input_2):
'''
input_1: [N, C, H, W]
input_2: [N, C, H, W]
'''
b, c, h, w = input_1.shape
exemplar, query = input_1, input_2
exemplar_flat = exemplar.reshape(-1, c, h*w) # N,C,H*W
query_flat = query.reshape(-1, c, h*w)
# Compute coattention scores, S in the paper
exemplar_t = torch.transpose(exemplar_flat, 1, 2).contiguous() # batch size x dim x num
exemplar_corr = self.linear_e(exemplar_t)
A = torch.bmm(exemplar_corr, query_flat)
A1 = F.softmax(A.clone(), dim=1)
B = F.softmax(torch.transpose(A, 1, 2), dim=1)
query_att = torch.bmm(exemplar_flat, A1)
exemplar_att = torch.bmm(query_flat, B)
input1_att = exemplar_att.reshape(-1, c, h, w)
input2_att = query_att.reshape(-1, c, h, w)
# Apply gating on S, section gated coattention
input1_mask = self.gate(input1_att)
input2_mask = self.gate(input2_att)
input1_mask = self.gate_s(input1_mask)
input2_mask = self.gate_s(input2_mask)
input1_att = input1_att * input1_mask
input2_att = input2_att * input2_mask
# Concatenate inputs with their attended version
input1_att = torch.cat([input1_att, exemplar], 1)
input2_att = torch.cat([input2_att, query], 1)
input1 = self.final_conv_1(input1_att)
input2 = self.final_conv_2(input2_att)
return input1, input2
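
if __name__ == '__main__':
    # Illustrative smoke test (channel count and spatial size are assumptions):
    # co-attention exchanges information between two feature maps, e.g. the 2D
    # part branch and the 3D SMPL branch. Small maps keep the NxN affinity cheap.
    coatt = CoAttention(n_channel=256, final_conv='simple')
    feat_2d = torch.rand(2, 256, 14, 14)
    feat_3d = torch.rand(2, 256, 14, 14)
    out_2d, out_3d = coatt(feat_2d, feat_3d)
    print(out_2d.shape, out_3d.shape)   # both [2, 256, 14, 14]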

View File

@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import torch
def interpolate(feat, uv):
'''
:param feat: [B, C, H, W] image features
:param uv: [B, 2, N] uv coordinates in the image plane, range [-1, 1]
:return: [B, C, N] image features at the uv coordinates
'''
if uv.shape[-1] != 2:
uv = uv.transpose(1, 2) # [B, N, 2]
uv = uv.unsqueeze(2) # [B, N, 1, 2]
    # NOTE: for newer PyTorch the default behaviour of F.grid_sample changed
    # (align_corners now defaults to False), so pass align_corners=True explicitly;
    # for old versions simply call it without the argument. Compare the full
    # (major, minor) version so the check also works on PyTorch 2.x.
    major, minor = (int(v) for v in torch.__version__.split('.')[:2])
    if (major, minor) < (1, 4):
        samples = torch.nn.functional.grid_sample(feat, uv)  # [B, C, N, 1]
    else:
        samples = torch.nn.functional.grid_sample(feat, uv, align_corners=True)  # [B, C, N, 1]
return samples[:, :, :, 0] # [B, C, N]
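
if __name__ == '__main__':
    # Illustrative example (shapes are assumptions): sample per-joint features
    # from a feature map at normalised keypoint locations in [-1, 1].
    feat = torch.randn(2, 64, 56, 56)          # [B, C, H, W]
    uv = torch.rand(2, 24, 2) * 2 - 1          # [B, N, 2] keypoints in [-1, 1]
    sampled = interpolate(feat, uv)
    print(sampled.shape)                       # [2, 64, 24]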

View File

@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
class KeypointAttention(nn.Module):
def __init__(self, use_conv=False, in_channels=(256, 64), out_channels=(256, 64), act='softmax', use_scale=False):
super(KeypointAttention, self).__init__()
self.use_conv = use_conv
self.in_channels = in_channels
self.out_channels = out_channels
self.act = act
self.use_scale = use_scale
if use_conv:
self.conv1x1_pose = nn.Conv1d(in_channels[0], out_channels[0], kernel_size=1)
self.conv1x1_shape_cam = nn.Conv1d(in_channels[1], out_channels[1], kernel_size=1)
def forward(self, features, heatmaps):
batch_size, num_joints, height, width = heatmaps.shape
if self.use_scale:
scale = 1.0 / np.sqrt(height * width)
heatmaps = heatmaps * scale
if self.act == 'softmax':
normalized_heatmap = F.softmax(heatmaps.reshape(batch_size, num_joints, -1), dim=-1)
elif self.act == 'sigmoid':
normalized_heatmap = torch.sigmoid(heatmaps.reshape(batch_size, num_joints, -1))
features = features.reshape(batch_size, -1, height*width)
attended_features = torch.matmul(normalized_heatmap, features.transpose(2,1))
attended_features = attended_features.transpose(2,1)
if self.use_conv:
if attended_features.shape[1] == self.in_channels[0]:
attended_features = self.conv1x1_pose(attended_features)
else:
attended_features = self.conv1x1_shape_cam(attended_features)
return attended_features
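
if __name__ == '__main__':
    # Illustrative example (shapes are assumptions): pool a feature map into
    # one feature vector per joint, weighted by each joint's normalised
    # attention map.
    att = KeypointAttention(use_conv=False, act='softmax')
    features = torch.randn(2, 256, 56, 56)     # [B, C, H, W]
    heatmaps = torch.randn(2, 24, 56, 56)      # [B, J, H, W] part attention logits
    pooled = att(features, heatmaps)
    print(pooled.shape)                        # [2, 256, 24]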

View File

@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair
class LocallyConnected2d(nn.Module):
def __init__(self, in_channels, out_channels, output_size, kernel_size, stride, bias=False):
super(LocallyConnected2d, self).__init__()
output_size = _pair(output_size)
self.weight = nn.Parameter(
torch.randn(1, out_channels, in_channels, output_size[0], output_size[1], kernel_size ** 2),
requires_grad=True,
)
if bias:
self.bias = nn.Parameter(
torch.randn(1, out_channels, output_size[0], output_size[1]), requires_grad=True
)
else:
self.register_parameter('bias', None)
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
def forward(self, x):
_, c, h, w = x.size()
kh, kw = self.kernel_size
dh, dw = self.stride
x = x.unfold(2, kh, dh).unfold(3, kw, dw)
x = x.contiguous().view(*x.size()[:-2], -1)
# Sum in in_channel and kernel_size dims
out = (x.unsqueeze(1) * self.weight).sum([2, -1])
if self.bias is not None:
out += self.bias
return out
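
if __name__ == '__main__':
    # Illustrative example (sizes are assumptions): one independent linear map
    # per joint location, as used for the per-joint pose MLPs above.
    lc = LocallyConnected2d(in_channels=256, out_channels=6,
                            output_size=[24, 1], kernel_size=1, stride=1)
    x = torch.randn(2, 256, 24, 1)             # per-joint features [B, C, J, 1]
    y = lc(x)
    print(y.shape)                             # [2, 6, 24, 1]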

View File

@ -0,0 +1,152 @@
import torch
from torch import nn
from torch.nn import functional as F
class _NonLocalBlockND(nn.Module):
def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
super(_NonLocalBlockND, self).__init__()
assert dimension in [1, 2, 3]
self.dimension = dimension
self.sub_sample = sub_sample
self.in_channels = in_channels
self.inter_channels = inter_channels
if self.inter_channels is None:
self.inter_channels = in_channels // 2
if self.inter_channels == 0:
self.inter_channels = 1
if dimension == 3:
conv_nd = nn.Conv3d
max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
bn = nn.BatchNorm3d
elif dimension == 2:
conv_nd = nn.Conv2d
max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
bn = nn.BatchNorm2d
else:
conv_nd = nn.Conv1d
max_pool_layer = nn.MaxPool1d(kernel_size=(2))
bn = nn.BatchNorm1d
self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0)
if bn_layer:
self.W = nn.Sequential(
conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
kernel_size=1, stride=1, padding=0),
bn(self.in_channels)
)
nn.init.constant_(self.W[1].weight, 0)
nn.init.constant_(self.W[1].bias, 0)
else:
self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
kernel_size=1, stride=1, padding=0)
nn.init.constant_(self.W.weight, 0)
nn.init.constant_(self.W.bias, 0)
self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0)
self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0)
if sub_sample:
self.g = nn.Sequential(self.g, max_pool_layer)
self.phi = nn.Sequential(self.phi, max_pool_layer)
def forward(self, x, return_nl_map=False):
"""
:param x: (b, c, t, h, w)
:param return_nl_map: if True return z, nl_map, else only return z.
:return:
"""
batch_size = x.size(0)
g_x = self.g(x).view(batch_size, self.inter_channels, -1)
g_x = g_x.permute(0, 2, 1)
theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
theta_x = theta_x.permute(0, 2, 1)
phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
f = torch.matmul(theta_x, phi_x)
N = f.size(-1)
f_div_C = f / N
y = torch.matmul(f_div_C, g_x)
y = y.permute(0, 2, 1).contiguous()
y = y.view(batch_size, self.inter_channels, *x.size()[2:])
W_y = self.W(y)
z = W_y + x
if return_nl_map:
return z, f_div_C
return z
class NONLocalBlock1D(_NonLocalBlockND):
def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
super(NONLocalBlock1D, self).__init__(in_channels,
inter_channels=inter_channels,
dimension=1, sub_sample=sub_sample,
bn_layer=bn_layer)
class NONLocalBlock2D(_NonLocalBlockND):
def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
super(NONLocalBlock2D, self).__init__(in_channels,
inter_channels=inter_channels,
dimension=2, sub_sample=sub_sample,
bn_layer=bn_layer)
class NONLocalBlock3D(_NonLocalBlockND):
def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
super(NONLocalBlock3D, self).__init__(in_channels,
inter_channels=inter_channels,
dimension=3, sub_sample=sub_sample,
bn_layer=bn_layer)
if __name__ == '__main__':
import torch
img = torch.zeros(2, 256, 24)
net = NONLocalBlock1D(
in_channels=256, inter_channels=None, sub_sample=False, bn_layer=True
)
out = net(img)
print(out.size())
img = torch.zeros(2, 256, 56, 56)
net = NONLocalBlock2D(
in_channels=256, inter_channels=None, sub_sample=False, bn_layer=True
)
out = net(img)
print(out.size())
# for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
# img = torch.zeros(2, 256, 24)
# net = NONLocalBlock1D(256, inter_channels=24, sub_sample=sub_sample_, bn_layer=bn_layer_)
# out = net(img)
# print(out.size())
#
# img = torch.zeros(2, 3, 20, 20)
# net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
# out = net(img)
# print(out.size())
#
# img = torch.randn(2, 3, 8, 20, 20)
# net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
# out = net(img)
# print(out.size())

View File

@ -0,0 +1,57 @@
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import torch
import torch.nn as nn
import torch.nn.functional as F
class NonLocalAttention(nn.Module):
def __init__(
self,
in_channels=256,
out_channels=256,
):
super(NonLocalAttention, self).__init__()
self.conv1x1 = nn.Conv1d(in_channels, out_channels, kernel_size=1)
def forward(self, input):
'''
input [N, Feats, J, 1]
output [N, Feats, J, 1]
'''
batch_size, n_feats, n_joints, _ = input.shape
input = input.squeeze(-1)
# Compute attention weights
attention = torch.matmul(input.transpose(2, 1), input)
norm_attention = F.softmax(attention, dim=-1)
# Compute final dot product
out = torch.matmul(input, norm_attention)
out = self.conv1x1(out)
out = out.unsqueeze(-1) # [N, F, J, 1]
return out
if __name__ == '__main__':
nla = NonLocalAttention()
inp = torch.rand(32, 256, 24, 1)
out = nla(inp)
print(out.shape)

View File

@ -0,0 +1,154 @@
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import torch
import torch.nn.functional as F
def _softmax(tensor, temperature, dim=-1):
return F.softmax(tensor * temperature, dim=dim)
def softargmax1d(
heatmaps,
temperature=None,
normalize_keypoints=True,
):
dtype, device = heatmaps.dtype, heatmaps.device
if temperature is None:
temperature = torch.tensor(1.0, dtype=dtype, device=device)
batch_size, num_channels, dim = heatmaps.shape
points = torch.arange(0, dim, device=device, dtype=dtype).reshape(1, 1, dim).expand(batch_size, -1, -1)
# y = torch.arange(0, height, device=device, dtype=dtype).reshape(1, 1, height, 1).expand(batch_size, -1, -1, width)
# Should be Bx2xHxW
# points = torch.cat([x, y], dim=1)
normalized_heatmap = _softmax(
heatmaps.reshape(batch_size, num_channels, -1),
temperature=temperature.reshape(1, -1, 1),
dim=-1)
# Should be BxJx2
keypoints = (normalized_heatmap.reshape(batch_size, -1, dim) * points).sum(dim=-1)
if normalize_keypoints:
# Normalize keypoints to [-1, 1]
keypoints = (keypoints / (dim - 1) * 2 - 1)
return keypoints, normalized_heatmap.reshape(
batch_size, -1, dim)
def softargmax2d(
heatmaps,
temperature=None,
normalize_keypoints=True,
):
dtype, device = heatmaps.dtype, heatmaps.device
if temperature is None:
temperature = torch.tensor(1.0, dtype=dtype, device=device)
batch_size, num_channels, height, width = heatmaps.shape
x = torch.arange(0, width, device=device, dtype=dtype).reshape(1, 1, 1, width).expand(batch_size, -1, height, -1)
y = torch.arange(0, height, device=device, dtype=dtype).reshape(1, 1, height, 1).expand(batch_size, -1, -1, width)
# Should be Bx2xHxW
points = torch.cat([x, y], dim=1)
normalized_heatmap = _softmax(
heatmaps.reshape(batch_size, num_channels, -1),
temperature=temperature.reshape(1, -1, 1),
dim=-1)
# Should be BxJx2
keypoints = (
normalized_heatmap.reshape(batch_size, -1, 1, height * width) *
points.reshape(batch_size, 1, 2, -1)).sum(dim=-1)
if normalize_keypoints:
# Normalize keypoints to [-1, 1]
keypoints[:, :, 0] = (keypoints[:, :, 0] / (width - 1) * 2 - 1)
keypoints[:, :, 1] = (keypoints[:, :, 1] / (height - 1) * 2 - 1)
return keypoints, normalized_heatmap.reshape(
batch_size, -1, height, width)
def softargmax3d(
heatmaps,
temperature=None,
normalize_keypoints=True,
):
dtype, device = heatmaps.dtype, heatmaps.device
if temperature is None:
temperature = torch.tensor(1.0, dtype=dtype, device=device)
batch_size, num_channels, height, width, depth = heatmaps.shape
x = torch.arange(0, width, device=device, dtype=dtype).reshape(1, 1, 1, width, 1).expand(batch_size, -1, height, -1, depth)
y = torch.arange(0, height, device=device, dtype=dtype).reshape(1, 1, height, 1, 1).expand(batch_size, -1, -1, width, depth)
z = torch.arange(0, depth, device=device, dtype=dtype).reshape(1, 1, 1, 1, depth).expand(batch_size, -1, height, width, -1)
    # Should be Bx3xHxWxD
points = torch.cat([x, y, z], dim=1)
normalized_heatmap = _softmax(
heatmaps.reshape(batch_size, num_channels, -1),
temperature=temperature.reshape(1, -1, 1),
dim=-1)
# Should be BxJx3
keypoints = (
normalized_heatmap.reshape(batch_size, -1, 1, height * width * depth) *
points.reshape(batch_size, 1, 3, -1)).sum(dim=-1)
if normalize_keypoints:
# Normalize keypoints to [-1, 1]
keypoints[:, :, 0] = (keypoints[:, :, 0] / (width - 1) * 2 - 1)
keypoints[:, :, 1] = (keypoints[:, :, 1] / (height - 1) * 2 - 1)
keypoints[:, :, 2] = (keypoints[:, :, 2] / (depth - 1) * 2 - 1)
return keypoints, normalized_heatmap.reshape(
batch_size, -1, height, width, depth)
def get_heatmap_preds(batch_heatmaps, normalize_keypoints=True):
    '''
    get predictions from score maps
    batch_heatmaps: torch.Tensor of shape [batch_size, num_joints, height, width]
    '''
    assert batch_heatmaps.ndim == 4, 'batch_heatmaps should be 4-ndim'
batch_size = batch_heatmaps.shape[0]
num_joints = batch_heatmaps.shape[1]
height = batch_heatmaps.shape[2]
width = batch_heatmaps.shape[3]
heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))
maxvals, idx = torch.max(heatmaps_reshaped, 2)
maxvals = maxvals.reshape((batch_size, num_joints, 1))
idx = idx.reshape((batch_size, num_joints, 1))
preds = idx.repeat(1, 1, 2).float()
preds[:, :, 0] = (preds[:, :, 0]) % width
preds[:, :, 1] = torch.floor((preds[:, :, 1]) / width)
pred_mask = torch.gt(maxvals, 0.0).repeat(1, 1, 2)
pred_mask = pred_mask.float()
preds *= pred_mask
if normalize_keypoints:
# Normalize keypoints to [-1, 1]
preds[:, :, 0] = (preds[:, :, 0] / (width - 1) * 2 - 1)
preds[:, :, 1] = (preds[:, :, 1] / (height - 1) * 2 - 1)
return preds, maxvals
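# --- Hedged usage sketch (editor addition, not part of the original PARE utilities) ---
# Minimal check of the two decoding paths above: softargmax2d gives differentiable,
# sub-pixel keypoints, while get_heatmap_preds takes the hard argmax. For a sharp
# peak both should land at (roughly) the same normalized location.
def _demo_heatmap_decoding():
    import torch
    heatmaps = torch.zeros(1, 2, 64, 48)               # (B, J, H, W)
    heatmaps[0, :, 10, 20] = 20.                       # sharp peak at x=20, y=10
    soft_kpts, _ = softargmax2d(heatmaps)              # (1, 2, 2) in [-1, 1]
    hard_kpts, maxvals = get_heatmap_preds(heatmaps)   # (1, 2, 2) in [-1, 1]
    return soft_kpts, hard_kpts, maxvals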

View File

@ -0,0 +1,262 @@
import os
import torch
import torch.nn as nn
from .config import update_hparams
# from .head import PareHead, SMPLHead, SMPLCamHead
from .head import PareHead
from .backbone.utils import get_backbone_info
from .backbone.hrnet import hrnet_w32
from os.path import join
from easymocap.multistage.torchgeometry import rotation_matrix_to_axis_angle
import cv2
def try_to_download():
model_dir = os.path.join(os.path.dirname(__file__), '..', '..', 'models', 'pare')
cmd = 'wget https://www.dropbox.com/s/aeulffqzb3zmh8x/pare-github-data.zip'
os.system(cmd)
os.makedirs(model_dir, exist_ok=True)
cmd = 'unzip pare-github-data.zip -d {}'.format(model_dir)
os.system(cmd)
CFG = 'models/pare/data/pare/checkpoints/pare_w_3dpw_config.yaml'
CKPT = 'models/pare/data/pare/checkpoints/pare_w_3dpw_checkpoint.ckpt'
class PARE(nn.Module):
def __init__(
self,
num_joints=24,
softmax_temp=1.0,
num_features_smpl=64,
backbone='resnet50',
focal_length=5000.,
img_res=224,
pretrained=None,
iterative_regression=False,
iter_residual=False,
num_iterations=3,
shape_input_type='feats', # 'feats.all_pose.shape.cam',
pose_input_type='feats', # 'feats.neighbor_pose_feats.all_pose.self_pose.neighbor_pose.shape.cam'
pose_mlp_num_layers=1,
shape_mlp_num_layers=1,
pose_mlp_hidden_size=256,
shape_mlp_hidden_size=256,
use_keypoint_features_for_smpl_regression=False,
use_heatmaps='',
use_keypoint_attention=False,
keypoint_attention_act='softmax',
use_postconv_keypoint_attention=False,
use_scale_keypoint_attention=False,
use_final_nonlocal=None,
use_branch_nonlocal=None,
use_hmr_regression=False,
use_coattention=False,
num_coattention_iter=1,
coattention_conv='simple',
deconv_conv_kernel_size=4,
use_upsampling=False,
use_soft_attention=False,
num_branch_iteration=0,
branch_deeper=False,
num_deconv_layers=3,
num_deconv_filters=256,
use_resnet_conv_hrnet=False,
use_position_encodings=None,
use_mean_camshape=False,
use_mean_pose=False,
init_xavier=False,
use_cam=False,
):
super(PARE, self).__init__()
if backbone.startswith('hrnet'):
backbone, use_conv = backbone.split('-')
# hrnet_w32-conv, hrnet_w32-interp
self.backbone = eval(backbone)(
pretrained=True,
downsample=False,
use_conv=(use_conv == 'conv')
)
else:
self.backbone = eval(backbone)(pretrained=True)
# self.backbone = eval(backbone)(pretrained=True)
self.head = PareHead(
num_joints=num_joints,
num_input_features=get_backbone_info(backbone)['n_output_channels'],
softmax_temp=softmax_temp,
num_deconv_layers=num_deconv_layers,
num_deconv_filters=[num_deconv_filters] * num_deconv_layers,
num_deconv_kernels=[deconv_conv_kernel_size] * num_deconv_layers,
num_features_smpl=num_features_smpl,
final_conv_kernel=1,
iterative_regression=iterative_regression,
iter_residual=iter_residual,
num_iterations=num_iterations,
shape_input_type=shape_input_type,
pose_input_type=pose_input_type,
pose_mlp_num_layers=pose_mlp_num_layers,
shape_mlp_num_layers=shape_mlp_num_layers,
pose_mlp_hidden_size=pose_mlp_hidden_size,
shape_mlp_hidden_size=shape_mlp_hidden_size,
use_keypoint_features_for_smpl_regression=use_keypoint_features_for_smpl_regression,
use_heatmaps=use_heatmaps,
use_keypoint_attention=use_keypoint_attention,
use_postconv_keypoint_attention=use_postconv_keypoint_attention,
keypoint_attention_act=keypoint_attention_act,
use_scale_keypoint_attention=use_scale_keypoint_attention,
use_branch_nonlocal=use_branch_nonlocal, # 'concatenation', 'dot_product', 'embedded_gaussian', 'gaussian'
use_final_nonlocal=use_final_nonlocal, # 'concatenation', 'dot_product', 'embedded_gaussian', 'gaussian'
backbone=backbone,
use_hmr_regression=use_hmr_regression,
use_coattention=use_coattention,
num_coattention_iter=num_coattention_iter,
coattention_conv=coattention_conv,
use_upsampling=use_upsampling,
use_soft_attention=use_soft_attention,
num_branch_iteration=num_branch_iteration,
branch_deeper=branch_deeper,
use_resnet_conv_hrnet=use_resnet_conv_hrnet,
use_position_encodings=use_position_encodings,
use_mean_camshape=use_mean_camshape,
use_mean_pose=use_mean_pose,
init_xavier=init_xavier,
)
self.use_cam = use_cam
# if self.use_cam:
# self.smpl = SMPLCamHead(
# img_res=img_res,
# )
# else:
# self.smpl = SMPLHead(
# focal_length=focal_length,
# img_res=img_res
# )
if pretrained is not None:
self.load_pretrained(pretrained)
def forward(
self,
images,
gt_segm=None,
):
features = self.backbone(images)
hmr_output = self.head(features, gt_segm=gt_segm)
rotmat = hmr_output['pred_pose']
shape = hmr_output['pred_shape']
rotmat_flat = rotmat.reshape(-1, 3, 3)
rvec_flat = rotation_matrix_to_axis_angle(rotmat_flat)
rvec = rvec_flat.reshape(*rotmat.shape[:-2], 3)
rvec = rvec.reshape(*rvec.shape[:-2], -1)
return {
'Rh': rvec[..., :3],
'Th': torch.zeros_like(rvec[..., :3]),
'poses': rvec[..., 3:],
'shapes': shape,
}
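# --- Hedged sketch (editor addition): how PARE.forward packs its outputs ---
# Assuming rotation_matrix_to_axis_angle maps (N, 3, 3) rotation matrices to (N, 3)
# axis-angle vectors, a (B, 24, 3, 3) SMPL pose becomes a (B, 72) vector that is then
# split into the global orientation Rh (B, 3) and the body pose 'poses' (B, 69).
def _demo_pare_output_packing(batch_size=2, num_joints=24):
    rotmat = torch.eye(3).expand(batch_size, num_joints, 3, 3)   # dummy rotations
    rvec = rotation_matrix_to_axis_angle(rotmat.reshape(-1, 3, 3))
    rvec = rvec.reshape(batch_size, num_joints * 3)
    return {'Rh': rvec[..., :3], 'poses': rvec[..., 3:]}         # (B, 3), (B, 69)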
from ..basetopdown import BaseTopDownModelCache
import pickle
class NullSPIN:
def __init__(self, ckpt) -> None:
self.name = 'spin'
def __call__(self, bbox, images, imgname):
from easymocap.mytools.reader import read_smpl
basename = os.path.basename(imgname)
cachename = join(self.output, self.name, basename.replace('.jpg', '.json'))
if os.path.exists(cachename):
params = read_smpl(cachename)
params = params[0]
params = {key:val[0] for key, val in params.items() if key != 'id'}
ret = {
'params': params
}
return ret
else:
import ipdb; ipdb.set_trace()
class MyPARE(BaseTopDownModelCache):
def __init__(self, ckpt) -> None:
super().__init__('pare', bbox_scale=1.1, res_input=224)
if not os.path.exists(CFG):
from ...io.model import try_to_download_SMPL
try_to_download_SMPL('models/pare')
self.model_cfg = update_hparams(CFG)
self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
self.model = self._build_model()
self._load_pretrained_model(CKPT)
self.model.eval()
self.model.to(self.device)
def __call__(self, bbox, images, imgnames):
return super().__call__(bbox[0], images, imgnames)
def _build_model(self):
# ========= Define PARE model ========= #
model_cfg = self.model_cfg
if model_cfg.METHOD == 'pare':
model = PARE(
backbone=model_cfg.PARE.BACKBONE,
num_joints=model_cfg.PARE.NUM_JOINTS,
softmax_temp=model_cfg.PARE.SOFTMAX_TEMP,
num_features_smpl=model_cfg.PARE.NUM_FEATURES_SMPL,
focal_length=model_cfg.DATASET.FOCAL_LENGTH,
img_res=model_cfg.DATASET.IMG_RES,
pretrained=model_cfg.TRAINING.PRETRAINED,
iterative_regression=model_cfg.PARE.ITERATIVE_REGRESSION,
num_iterations=model_cfg.PARE.NUM_ITERATIONS,
iter_residual=model_cfg.PARE.ITER_RESIDUAL,
shape_input_type=model_cfg.PARE.SHAPE_INPUT_TYPE,
pose_input_type=model_cfg.PARE.POSE_INPUT_TYPE,
pose_mlp_num_layers=model_cfg.PARE.POSE_MLP_NUM_LAYERS,
shape_mlp_num_layers=model_cfg.PARE.SHAPE_MLP_NUM_LAYERS,
pose_mlp_hidden_size=model_cfg.PARE.POSE_MLP_HIDDEN_SIZE,
shape_mlp_hidden_size=model_cfg.PARE.SHAPE_MLP_HIDDEN_SIZE,
use_keypoint_features_for_smpl_regression=model_cfg.PARE.USE_KEYPOINT_FEATURES_FOR_SMPL_REGRESSION,
use_heatmaps=model_cfg.DATASET.USE_HEATMAPS,
use_keypoint_attention=model_cfg.PARE.USE_KEYPOINT_ATTENTION,
use_postconv_keypoint_attention=model_cfg.PARE.USE_POSTCONV_KEYPOINT_ATTENTION,
use_scale_keypoint_attention=model_cfg.PARE.USE_SCALE_KEYPOINT_ATTENTION,
keypoint_attention_act=model_cfg.PARE.KEYPOINT_ATTENTION_ACT,
use_final_nonlocal=model_cfg.PARE.USE_FINAL_NONLOCAL,
use_branch_nonlocal=model_cfg.PARE.USE_BRANCH_NONLOCAL,
use_hmr_regression=model_cfg.PARE.USE_HMR_REGRESSION,
use_coattention=model_cfg.PARE.USE_COATTENTION,
num_coattention_iter=model_cfg.PARE.NUM_COATTENTION_ITER,
coattention_conv=model_cfg.PARE.COATTENTION_CONV,
use_upsampling=model_cfg.PARE.USE_UPSAMPLING,
deconv_conv_kernel_size=model_cfg.PARE.DECONV_CONV_KERNEL_SIZE,
use_soft_attention=model_cfg.PARE.USE_SOFT_ATTENTION,
num_branch_iteration=model_cfg.PARE.NUM_BRANCH_ITERATION,
branch_deeper=model_cfg.PARE.BRANCH_DEEPER,
num_deconv_layers=model_cfg.PARE.NUM_DECONV_LAYERS,
num_deconv_filters=model_cfg.PARE.NUM_DECONV_FILTERS,
use_resnet_conv_hrnet=model_cfg.PARE.USE_RESNET_CONV_HRNET,
use_position_encodings=model_cfg.PARE.USE_POS_ENC,
use_mean_camshape=model_cfg.PARE.USE_MEAN_CAMSHAPE,
use_mean_pose=model_cfg.PARE.USE_MEAN_POSE,
init_xavier=model_cfg.PARE.INIT_XAVIER,
).to(self.device)
        else:
            raise ValueError('Unsupported METHOD: {}'.format(model_cfg.METHOD))
return model
def _load_pretrained_model(self, ckpt):
# ========= Load pretrained weights ========= #
state_dict = torch.load(ckpt, map_location='cpu')['state_dict']
pretrained_keys = state_dict.keys()
new_state_dict = {}
for pk in pretrained_keys:
if pk.startswith('model.'):
new_state_dict[pk.replace('model.', '')] = state_dict[pk]
else:
new_state_dict[pk] = state_dict[pk]
self.model.load_state_dict(new_state_dict, strict=False)
if __name__ == '__main__':
pass
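# --- Hedged usage note (editor addition) ---
# MyPARE is meant to be driven by the EasyMocap pipeline: it is built once (the
# checkpoint path argument is currently unused; the CFG/CKPT constants above are
# loaded instead) and then called per frame with a bbox array, the loaded images and
# the image names, with results cached by BaseTopDownModelCache. The call below is
# only an illustrative assumption about the argument shapes, not taken from the
# original code:
#
#   estimator = MyPARE(ckpt=CKPT)
#   out = estimator(bbox, images, imgnames)   # SMPL params per detected person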

View File

@ -0,0 +1,722 @@
import math  # needed by euler_angles_from_rotmat below
import torch
import numpy as np
from torch.nn import functional as F
"""
Useful geometric operations, e.g. Perspective projection and a differentiable Rodrigues formula
Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR
"""
def batch_rot2aa(Rs):
"""
Rs is B x 3 x 3
void cMathUtil::RotMatToAxisAngle(const tMatrix& mat, tVector& out_axis,
double& out_theta)
{
double c = 0.5 * (mat(0, 0) + mat(1, 1) + mat(2, 2) - 1);
c = cMathUtil::Clamp(c, -1.0, 1.0);
out_theta = std::acos(c);
if (std::abs(out_theta) < 0.00001)
{
out_axis = tVector(0, 0, 1, 0);
}
else
{
double m21 = mat(2, 1) - mat(1, 2);
double m02 = mat(0, 2) - mat(2, 0);
double m10 = mat(1, 0) - mat(0, 1);
double denom = std::sqrt(m21 * m21 + m02 * m02 + m10 * m10);
out_axis[0] = m21 / denom;
out_axis[1] = m02 / denom;
out_axis[2] = m10 / denom;
out_axis[3] = 0;
}
}
"""
cos = 0.5 * (torch.stack([torch.trace(x) for x in Rs]) - 1)
cos = torch.clamp(cos, -1, 1)
theta = torch.acos(cos)
m21 = Rs[:, 2, 1] - Rs[:, 1, 2]
m02 = Rs[:, 0, 2] - Rs[:, 2, 0]
m10 = Rs[:, 1, 0] - Rs[:, 0, 1]
denom = torch.sqrt(m21 * m21 + m02 * m02 + m10 * m10)
axis0 = torch.where(torch.abs(theta) < 0.00001, m21, m21 / denom)
axis1 = torch.where(torch.abs(theta) < 0.00001, m02, m02 / denom)
axis2 = torch.where(torch.abs(theta) < 0.00001, m10, m10 / denom)
return theta.unsqueeze(1) * torch.stack([axis0, axis1, axis2], 1)
def batch_rodrigues(theta):
"""Convert axis-angle representation to rotation matrix.
Args:
theta: size = [B, 3]
Returns:
Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
"""
l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
angle = torch.unsqueeze(l1norm, -1)
normalized = torch.div(theta, angle)
angle = angle * 0.5
v_cos = torch.cos(angle)
v_sin = torch.sin(angle)
quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
return quat_to_rotmat(quat)
def quat_to_rotmat(quat):
"""Convert quaternion coefficients to rotation matrix.
Args:
quat: size = [B, 4] 4 <===>(w, x, y, z)
Returns:
Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
"""
norm_quat = quat
norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
B = quat.size(0)
w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
wx, wy, wz = w*x, w*y, w*z
xy, xz, yz = x*y, x*z, y*z
rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
return rotMat
def rot6d_to_rotmat(x):
"""Convert 6D rotation representation to 3x3 rotation matrix.
Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
Input:
(B,6) Batch of 6-D rotation representations
Output:
(B,3,3) Batch of corresponding rotation matrices
"""
x = x.reshape(-1,3,2)
a1 = x[:, :, 0]
a2 = x[:, :, 1]
b1 = F.normalize(a1)
b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
b3 = torch.cross(b1, b2)
return torch.stack((b1, b2, b3), dim=-1)
def rotmat_to_rot6d(x):
rotmat = x.reshape(-1, 3, 3)
rot6d = rotmat[:, :, :2].reshape(x.shape[0], -1)
return rot6d
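# --- Hedged sketch (editor addition): 6D <-> rotation-matrix round trip ---
# For a valid rotation matrix, dropping the last column (rotmat_to_rot6d) and
# re-orthonormalizing (rot6d_to_rotmat) should reproduce the original matrix.
def _demo_rot6d_roundtrip():
    R = torch.eye(3).unsqueeze(0)          # (1, 3, 3) identity rotation
    rot6d = rotmat_to_rot6d(R)             # (1, 6): first two columns, row-major
    R_rec = rot6d_to_rotmat(rot6d)         # (1, 3, 3)
    assert torch.allclose(R, R_rec, atol=1e-6)
    return R_rec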
def rotation_matrix_to_angle_axis(rotation_matrix):
"""
This function is borrowed from https://github.com/kornia/kornia
Convert 3x4 rotation matrix to Rodrigues vector
Args:
rotation_matrix (Tensor): rotation matrix.
Returns:
Tensor: Rodrigues vector transformation.
Shape:
- Input: :math:`(N, 3, 4)`
- Output: :math:`(N, 3)`
Example:
        >>> input = torch.rand(2, 3, 4)  # Nx3x4
>>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3
"""
if rotation_matrix.shape[1:] == (3,3):
rot_mat = rotation_matrix.reshape(-1, 3, 3)
hom = torch.tensor([0, 0, 1], dtype=torch.float32,
device=rotation_matrix.device).reshape(1, 3, 1).expand(rot_mat.shape[0], -1, -1)
rotation_matrix = torch.cat([rot_mat, hom], dim=-1)
quaternion = rotation_matrix_to_quaternion(rotation_matrix)
aa = quaternion_to_angle_axis(quaternion)
aa[torch.isnan(aa)] = 0.0
return aa
def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:
"""
This function is borrowed from https://github.com/kornia/kornia
Convert quaternion vector to angle axis of rotation.
Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
Args:
quaternion (torch.Tensor): tensor with quaternions.
Return:
torch.Tensor: tensor with angle axis of rotation.
Shape:
- Input: :math:`(*, 4)` where `*` means, any number of dimensions
- Output: :math:`(*, 3)`
Example:
>>> quaternion = torch.rand(2, 4) # Nx4
>>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3
"""
if not torch.is_tensor(quaternion):
raise TypeError("Input type is not a torch.Tensor. Got {}".format(
type(quaternion)))
if not quaternion.shape[-1] == 4:
raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}"
.format(quaternion.shape))
# unpack input and compute conversion
q1: torch.Tensor = quaternion[..., 1]
q2: torch.Tensor = quaternion[..., 2]
q3: torch.Tensor = quaternion[..., 3]
sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3
sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta)
cos_theta: torch.Tensor = quaternion[..., 0]
two_theta: torch.Tensor = 2.0 * torch.where(
cos_theta < 0.0,
torch.atan2(-sin_theta, -cos_theta),
torch.atan2(sin_theta, cos_theta))
k_pos: torch.Tensor = two_theta / sin_theta
k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta)
k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3]
angle_axis[..., 0] += q1 * k
angle_axis[..., 1] += q2 * k
angle_axis[..., 2] += q3 * k
return angle_axis
def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
"""
This function is borrowed from https://github.com/kornia/kornia
Convert 3x4 rotation matrix to 4d quaternion vector
This algorithm is based on algorithm described in
https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
Args:
rotation_matrix (Tensor): the rotation matrix to convert.
Return:
Tensor: the rotation in quaternion
Shape:
- Input: :math:`(N, 3, 4)`
- Output: :math:`(N, 4)`
Example:
>>> input = torch.rand(4, 3, 4) # Nx3x4
>>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4
"""
if not torch.is_tensor(rotation_matrix):
raise TypeError("Input type is not a torch.Tensor. Got {}".format(
type(rotation_matrix)))
if len(rotation_matrix.shape) > 3:
raise ValueError(
"Input size must be a three dimensional tensor. Got {}".format(
rotation_matrix.shape))
if not rotation_matrix.shape[-2:] == (3, 4):
raise ValueError(
"Input size must be a N x 3 x 4 tensor. Got {}".format(
rotation_matrix.shape))
rmat_t = torch.transpose(rotation_matrix, 1, 2)
mask_d2 = rmat_t[:, 2, 2] < eps
mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1)
t0_rep = t0.repeat(4, 1).t()
t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1)
t1_rep = t1.repeat(4, 1).t()
t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1)
t2_rep = t2.repeat(4, 1).t()
t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1)
t3_rep = t3.repeat(4, 1).t()
mask_c0 = mask_d2 * mask_d0_d1
mask_c1 = mask_d2 * ~mask_d0_d1
mask_c2 = ~mask_d2 * mask_d0_nd1
mask_c3 = ~mask_d2 * ~mask_d0_nd1
mask_c0 = mask_c0.view(-1, 1).type_as(q0)
mask_c1 = mask_c1.view(-1, 1).type_as(q1)
mask_c2 = mask_c2.view(-1, 1).type_as(q2)
mask_c3 = mask_c3.view(-1, 1).type_as(q3)
q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa
t2_rep * mask_c2 + t3_rep * mask_c3) # noqa
q *= 0.5
return q
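# --- Hedged sketch (editor addition): axis-angle round trip ---
# batch_rodrigues builds the rotation matrix from an axis-angle vector, and
# rotation_matrix_to_angle_axis (via the quaternion conversion above) should recover it.
def _demo_angle_axis_roundtrip():
    aa = torch.tensor([[0.1, 0.2, 0.3]])       # (1, 3) axis-angle
    R = batch_rodrigues(aa)                    # (1, 3, 3)
    aa_rec = rotation_matrix_to_angle_axis(R)  # (1, 3)
    assert torch.allclose(aa, aa_rec, atol=1e-4)
    return aa_rec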
def convert_perspective_to_weak_perspective(
perspective_camera,
focal_length=5000.,
img_res=224,
):
    # Convert the camera translation [tx, ty, tz] used in a full perspective projection
    # into a weak-perspective camera [s, tx, ty] given the bounding box size
    # (the inverse of convert_weak_perspective_to_perspective below)
# if isinstance(focal_length, torch.Tensor):
# focal_length = focal_length[:, 0]
weak_perspective_camera = torch.stack(
[
2 * focal_length / (img_res * perspective_camera[:, 2] + 1e-9),
perspective_camera[:, 0],
perspective_camera[:, 1],
],
dim=-1
)
return weak_perspective_camera
def convert_weak_perspective_to_perspective(
weak_perspective_camera,
focal_length=5000.,
img_res=224,
):
# Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz]
# in 3D given the bounding box size
# This camera translation can be used in a full perspective projection
# if isinstance(focal_length, torch.Tensor):
# focal_length = focal_length[:, 0]
perspective_camera = torch.stack(
[
weak_perspective_camera[:, 1],
weak_perspective_camera[:, 2],
2 * focal_length / (img_res * weak_perspective_camera[:, 0] + 1e-9)
],
dim=-1
)
return perspective_camera
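# --- Hedged sketch (editor addition): the two camera conversions above are inverses
# of each other (up to the 1e-9 guard), mapping [s, tx, ty] <-> [tx, ty, tz] with
# s = 2 * focal_length / (img_res * tz).
def _demo_camera_roundtrip():
    weak_cam = torch.tensor([[1.0, 0.1, -0.2]])                # [s, tx, ty]
    cam_t = convert_weak_perspective_to_perspective(weak_cam)  # [tx, ty, tz]
    weak_rec = convert_perspective_to_weak_perspective(cam_t)
    assert torch.allclose(weak_cam, weak_rec, atol=1e-4)
    return cam_t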
def perspective_projection(points, rotation, translation,
focal_length, camera_center):
"""
This function computes the perspective projection of a set of points.
Input:
points (bs, N, 3): 3D points
rotation (bs, 3, 3): Camera rotation
translation (bs, 3): Camera translation
focal_length (bs,) or scalar: Focal length
camera_center (bs, 2): Camera center
"""
batch_size = points.shape[0]
K = torch.zeros([batch_size, 3, 3], device=points.device)
K[:,0,0] = focal_length
K[:,1,1] = focal_length
K[:,2,2] = 1.
K[:,:-1, -1] = camera_center
# Transform points
points = torch.einsum('bij,bkj->bki', rotation, points)
points = points + translation.unsqueeze(1)
# Apply perspective distortion
projected_points = points / points[:,:,-1].unsqueeze(-1)
# Apply camera intrinsics
projected_points = torch.einsum('bij,bkj->bki', K, projected_points)
return projected_points[:, :, :-1]
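# --- Hedged sketch (editor addition): a point on the optical axis projects to the
# camera center under perspective_projection.
def _demo_perspective_projection():
    points = torch.zeros(1, 1, 3)                    # one 3D point at the origin
    rotation = torch.eye(3).unsqueeze(0)             # identity camera rotation
    translation = torch.tensor([[0., 0., 5.]])       # move it 5 units in front
    camera_center = torch.tensor([[112., 112.]])     # e.g. img_res / 2
    proj = perspective_projection(points, rotation, translation,
                                  focal_length=5000., camera_center=camera_center)
    assert torch.allclose(proj, camera_center.unsqueeze(1))
    return proj                                      # (1, 1, 2) == (112, 112)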
def weak_perspective_projection(points, rotation, weak_cam_params, focal_length, camera_center, img_res):
"""
    This function computes the weak-perspective projection of a set of points:
    the weak camera [s, tx, ty] is first converted to a perspective translation.
    Input:
        points (bs, N, 3): 3D points
        rotation (bs, 3, 3): Camera rotation
        weak_cam_params (bs, 3): Weak-perspective camera [s, tx, ty]
        focal_length (bs,) or scalar: Focal length
        camera_center (bs, 2): Camera center
        img_res (scalar): Input image resolution used for the conversion
"""
batch_size = points.shape[0]
K = torch.zeros([batch_size, 3, 3], device=points.device)
K[:,0,0] = focal_length
K[:,1,1] = focal_length
K[:,2,2] = 1.
K[:,:-1, -1] = camera_center
translation = convert_weak_perspective_to_perspective(weak_cam_params, focal_length, img_res)
# Transform points
points = torch.einsum('bij,bkj->bki', rotation, points)
points = points + translation.unsqueeze(1)
# Apply perspective distortion
projected_points = points / points[:,:,-1].unsqueeze(-1)
# Apply camera intrinsics
projected_points = torch.einsum('bij,bkj->bki', K, projected_points)
return projected_points[:, :, :-1]
def estimate_translation_np(S, joints_2d, joints_conf, focal_length=5000., img_size=224.):
"""Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d.
Input:
S: (25, 3) 3D joint locations
joints: (25, 3) 2D joint locations and confidence
Returns:
(3,) camera translation vector
"""
num_joints = S.shape[0]
# focal length
f = np.array([focal_length,focal_length])
# optical center
center = np.array([img_size/2., img_size/2.])
# transformations
Z = np.reshape(np.tile(S[:,2],(2,1)).T,-1)
XY = np.reshape(S[:,0:2],-1)
O = np.tile(center,num_joints)
F = np.tile(f,num_joints)
weight2 = np.reshape(np.tile(np.sqrt(joints_conf),(2,1)).T,-1)
# least squares
Q = np.array([F*np.tile(np.array([1,0]),num_joints), F*np.tile(np.array([0,1]),num_joints), O-np.reshape(joints_2d,-1)]).T
c = (np.reshape(joints_2d,-1)-O)*Z - F*XY
# weighted least squares
W = np.diagflat(weight2)
Q = np.dot(W,Q)
c = np.dot(W,c)
# square matrix
A = np.dot(Q.T,Q)
b = np.dot(Q.T,c)
# solution
trans = np.linalg.solve(A, b)
return trans
def estimate_translation(S, joints_2d, focal_length=5000., img_size=224., use_all_joints=False, rotation=None):
"""Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d.
Input:
S: (B, 49, 3) 3D joint locations
joints: (B, 49, 3) 2D joint locations and confidence
Returns:
(B, 3) camera translation vectors
"""
device = S.device
if rotation is not None:
S = torch.einsum('bij,bkj->bki', rotation, S)
# Use only joints 25:49 (GT joints)
if use_all_joints:
S = S.cpu().numpy()
joints_2d = joints_2d.cpu().numpy()
else:
S = S[:, 25:, :].cpu().numpy()
joints_2d = joints_2d[:, 25:, :].cpu().numpy()
joints_conf = joints_2d[:, :, -1]
joints_2d = joints_2d[:, :, :-1]
trans = np.zeros((S.shape[0], 3), dtype=np.float32)
# Find the translation for each example in the batch
for i in range(S.shape[0]):
S_i = S[i]
joints_i = joints_2d[i]
conf_i = joints_conf[i]
trans[i] = estimate_translation_np(S_i, joints_i, conf_i, focal_length=focal_length, img_size=img_size)
return torch.from_numpy(trans).to(device)
def estimate_translation_cam(S, joints_2d, focal_length=(5000., 5000.), img_size=(224., 224.),
use_all_joints=False, rotation=None):
"""Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d.
Input:
S: (B, 49, 3) 3D joint locations
joints: (B, 49, 3) 2D joint locations and confidence
Returns:
(B, 3) camera translation vectors
"""
def estimate_translation_np(S, joints_2d, joints_conf, focal_length=(5000., 5000.), img_size=(224., 224.)):
"""Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d.
Input:
S: (25, 3) 3D joint locations
joints: (25, 3) 2D joint locations and confidence
Returns:
(3,) camera translation vector
"""
num_joints = S.shape[0]
# focal length
f = np.array([focal_length[0], focal_length[1]])
# optical center
center = np.array([img_size[0] / 2., img_size[1] / 2.])
# transformations
Z = np.reshape(np.tile(S[:, 2], (2, 1)).T, -1)
XY = np.reshape(S[:, 0:2], -1)
O = np.tile(center, num_joints)
F = np.tile(f, num_joints)
weight2 = np.reshape(np.tile(np.sqrt(joints_conf), (2, 1)).T, -1)
# least squares
Q = np.array([F * np.tile(np.array([1, 0]), num_joints), F * np.tile(np.array([0, 1]), num_joints),
O - np.reshape(joints_2d, -1)]).T
c = (np.reshape(joints_2d, -1) - O) * Z - F * XY
# weighted least squares
W = np.diagflat(weight2)
Q = np.dot(W, Q)
c = np.dot(W, c)
# square matrix
A = np.dot(Q.T, Q)
b = np.dot(Q.T, c)
# solution
trans = np.linalg.solve(A, b)
return trans
device = S.device
if rotation is not None:
S = torch.einsum('bij,bkj->bki', rotation, S)
# Use only joints 25:49 (GT joints)
if use_all_joints:
S = S.cpu().numpy()
joints_2d = joints_2d.cpu().numpy()
else:
S = S[:, 25:, :].cpu().numpy()
joints_2d = joints_2d[:, 25:, :].cpu().numpy()
joints_conf = joints_2d[:, :, -1]
joints_2d = joints_2d[:, :, :-1]
trans = np.zeros((S.shape[0], 3), dtype=np.float32)
# Find the translation for each example in the batch
for i in range(S.shape[0]):
S_i = S[i]
joints_i = joints_2d[i]
conf_i = joints_conf[i]
trans[i] = estimate_translation_np(S_i, joints_i, conf_i, focal_length=focal_length, img_size=img_size)
return torch.from_numpy(trans).to(device)
def get_coord_maps(size=56):
xx_ones = torch.ones([1, size], dtype=torch.int32)
xx_ones = xx_ones.unsqueeze(-1)
xx_range = torch.arange(size, dtype=torch.int32).unsqueeze(0)
xx_range = xx_range.unsqueeze(1)
xx_channel = torch.matmul(xx_ones, xx_range)
xx_channel = xx_channel.unsqueeze(-1)
yy_ones = torch.ones([1, size], dtype=torch.int32)
yy_ones = yy_ones.unsqueeze(1)
yy_range = torch.arange(size, dtype=torch.int32).unsqueeze(0)
yy_range = yy_range.unsqueeze(-1)
yy_channel = torch.matmul(yy_range, yy_ones)
yy_channel = yy_channel.unsqueeze(-1)
xx_channel = xx_channel.permute(0, 3, 1, 2)
yy_channel = yy_channel.permute(0, 3, 1, 2)
xx_channel = xx_channel.float() / (size - 1)
yy_channel = yy_channel.float() / (size - 1)
xx_channel = xx_channel * 2 - 1
yy_channel = yy_channel * 2 - 1
out = torch.cat([xx_channel, yy_channel], dim=1)
return out
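# --- Hedged sketch (editor addition): get_coord_maps builds a CoordConv-style grid
# with one x channel and one y channel, each normalized to [-1, 1].
def _demo_coord_maps():
    out = get_coord_maps(size=56)     # (1, 2, 56, 56)
    assert out.shape == (1, 2, 56, 56)
    assert float(out.min()) == -1. and float(out.max()) == 1.
    return out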
def look_at(eye, at=np.array([0, 0, 0]), up=np.array([0, 0, 1]), eps=1e-5):
at = at.astype(float).reshape(1, 3)
up = up.astype(float).reshape(1, 3)
eye = eye.reshape(-1, 3)
up = up.repeat(eye.shape[0] // up.shape[0], axis=0)
eps = np.array([eps]).reshape(1, 1).repeat(up.shape[0], axis=0)
z_axis = eye - at
z_axis /= np.max(np.stack([np.linalg.norm(z_axis, axis=1, keepdims=True), eps]))
x_axis = np.cross(up, z_axis)
x_axis /= np.max(np.stack([np.linalg.norm(x_axis, axis=1, keepdims=True), eps]))
y_axis = np.cross(z_axis, x_axis)
y_axis /= np.max(np.stack([np.linalg.norm(y_axis, axis=1, keepdims=True), eps]))
r_mat = np.concatenate((x_axis.reshape(-1, 3, 1), y_axis.reshape(-1, 3, 1), z_axis.reshape(-1, 3, 1)), axis=2)
return r_mat
def to_sphere(u, v):
theta = 2 * np.pi * u
phi = np.arccos(1 - 2 * v)
cx = np.sin(phi) * np.cos(theta)
cy = np.sin(phi) * np.sin(theta)
cz = np.cos(phi)
s = np.stack([cx, cy, cz])
return s
def sample_on_sphere(range_u=(0, 1), range_v=(0, 1)):
u = np.random.uniform(*range_u)
v = np.random.uniform(*range_v)
return to_sphere(u, v)
def sample_pose_on_sphere(range_v=(0,1), range_u=(0,1), radius=1, up=[0,1,0]):
# sample location on unit sphere
loc = sample_on_sphere(range_u, range_v)
# sample radius if necessary
if isinstance(radius, tuple):
radius = np.random.uniform(*radius)
loc = loc * radius
R = look_at(loc, up=np.array(up))[0]
RT = np.concatenate([R, loc.reshape(3, 1)], axis=1)
RT = torch.Tensor(RT.astype(np.float32))
return RT
def rectify_pose(camera_r, body_aa, rotate_x=False):
body_r = batch_rodrigues(body_aa).reshape(-1,3,3)
if rotate_x:
rotate_x = torch.tensor([[[1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, -1.0]]])
body_r = body_r @ rotate_x
final_r = camera_r @ body_r
body_aa = batch_rot2aa(final_r)
return body_aa
def batch_euler2matrix(r):
return quaternion_to_rotation_matrix(euler_to_quaternion(r))
def euler_to_quaternion(r):
x = r[..., 0]
y = r[..., 1]
z = r[..., 2]
z = z/2.0
y = y/2.0
x = x/2.0
cz = torch.cos(z)
sz = torch.sin(z)
cy = torch.cos(y)
sy = torch.sin(y)
cx = torch.cos(x)
sx = torch.sin(x)
quaternion = torch.zeros_like(r.repeat(1,2))[..., :4].to(r.device)
quaternion[..., 0] += cx*cy*cz - sx*sy*sz
quaternion[..., 1] += cx*sy*sz + cy*cz*sx
quaternion[..., 2] += cx*cz*sy - sx*cy*sz
quaternion[..., 3] += cx*cy*sz + sx*cz*sy
return quaternion
def quaternion_to_rotation_matrix(quat):
"""Convert quaternion coefficients to rotation matrix.
Args:
quat: size = [B, 4] 4 <===>(w, x, y, z)
Returns:
Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
"""
norm_quat = quat
norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]
B = quat.size(0)
w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
wx, wy, wz = w * x, w * y, w * z
xy, xz, yz = x * y, x * z, y * z
rotMat = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz,
2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx,
2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
return rotMat
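# --- Hedged sketch (editor addition): batch_euler2matrix chains euler_to_quaternion
# and quaternion_to_rotation_matrix; zero Euler angles give the identity rotation.
def _demo_euler2matrix():
    r = torch.zeros(1, 3)                            # (x, y, z) angles in radians
    R = batch_euler2matrix(r)                        # (1, 3, 3)
    assert torch.allclose(R, torch.eye(3).unsqueeze(0))
    return R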
def euler_angles_from_rotmat(R):
"""
    compute Euler angles for rotations around the x, y, z axes
    from a rotation matrix
R: 4x4 rotation matrix
https://www.gregslabaugh.net/publications/euler.pdf
"""
r21 = np.round(R[:, 2, 0].item(), 4)
if abs(r21) != 1:
y_angle1 = -1 * torch.asin(R[:, 2, 0])
y_angle2 = math.pi + torch.asin(R[:, 2, 0])
cy1, cy2 = torch.cos(y_angle1), torch.cos(y_angle2)
x_angle1 = torch.atan2(R[:, 2, 1] / cy1, R[:, 2, 2] / cy1)
x_angle2 = torch.atan2(R[:, 2, 1] / cy2, R[:, 2, 2] / cy2)
z_angle1 = torch.atan2(R[:, 1, 0] / cy1, R[:, 0, 0] / cy1)
z_angle2 = torch.atan2(R[:, 1, 0] / cy2, R[:, 0, 0] / cy2)
s1 = (x_angle1, y_angle1, z_angle1)
s2 = (x_angle2, y_angle2, z_angle2)
s = (s1, s2)
else:
z_angle = torch.tensor([0], device=R.device).float()
if r21 == -1:
y_angle = torch.tensor([math.pi / 2], device=R.device).float()
x_angle = z_angle + torch.atan2(R[:, 0, 1], R[:, 0, 2])
else:
y_angle = -torch.tensor([math.pi / 2], device=R.device).float()
x_angle = -z_angle + torch.atan2(-R[:, 0, 1], R[:, 0, 2])
s = ((x_angle, y_angle, z_angle),)
return s

File diff suppressed because it is too large