EasyMocap/easymocap/pyfitting/lossfactory.py

419 lines
16 KiB
Python
Raw Normal View History

2021-04-14 15:22:51 +08:00
'''
@ Date: 2020-11-19 17:46:04
@ Author: Qing Shuai
@ LastEditors: Qing Shuai
@ LastEditTime: 2021-04-14 11:46:56
@ FilePath: /EasyMocap/easymocap/pyfitting/lossfactory.py
'''
import numpy as np
import torch
from .operation import projection, batch_rodrigues
funcl2 = lambda x: torch.sum(x**2)
funcl1 = lambda x: torch.sum(torch.abs(x**2))
def gmof(squared_res, sigma_squared):
"""
Geman-McClure error function
"""
return (sigma_squared * squared_res) / (sigma_squared + squared_res)
def ReprojectionLoss(keypoints3d, keypoints2d, K, Rc, Tc, inv_bbox_sizes, norm='l2'):
img_points = projection(keypoints3d, K, Rc, Tc)
residual = (img_points - keypoints2d[:, :, :2]) * keypoints2d[:, :, -1:]
# squared_res: (nFrames, nJoints, 2)
if norm == 'l2':
squared_res = (residual ** 2) * inv_bbox_sizes
elif norm == 'l1':
squared_res = torch.abs(residual) * inv_bbox_sizes
else:
import ipdb; ipdb.set_trace()
return torch.sum(squared_res)
class LossKeypoints3D:
def __init__(self, keypoints3d, cfg, norm='l2') -> None:
self.cfg = cfg
keypoints3d = torch.Tensor(keypoints3d).to(cfg.device)
self.nJoints = keypoints3d.shape[1]
self.keypoints3d = keypoints3d[..., :3]
self.conf = keypoints3d[..., 3:]
self.nFrames = keypoints3d.shape[0]
self.norm = norm
def loss(self, diff_square):
if self.norm == 'l2':
loss_3d = funcl2(diff_square)
elif self.norm == 'l1':
loss_3d = funcl1(diff_square)
elif self.norm == 'gm':
# 阈值设为0.2^2米
loss_3d = torch.sum(gmof(diff_square**2, 0.04))
else:
raise NotImplementedError
return loss_3d/self.nFrames
def body(self, kpts_est, **kwargs):
"distance of keypoints3d"
nJoints = min([kpts_est.shape[1], self.keypoints3d.shape[1], 25])
diff_square = (kpts_est[:, :nJoints, :3] - self.keypoints3d[:, :nJoints, :3])*self.conf[:, :nJoints]
return self.loss(diff_square)
def hand(self, kpts_est, **kwargs):
"distance of 3d hand keypoints"
diff_square = (kpts_est[:, 25:25+42, :3] - self.keypoints3d[:, 25:25+42, :3])*self.conf[:, 25:25+42]
return self.loss(diff_square)
def face(self, kpts_est, **kwargs):
"distance of 3d face keypoints"
diff_square = (kpts_est[:, 25+42:, :3] - self.keypoints3d[:, 25+42:, :3])*self.conf[:, 25+42:]
return self.loss(diff_square)
def __str__(self) -> str:
return 'Loss function for keypoints3D, norm = {}'.format(self.norm)
class LossRegPoses:
def __init__(self, cfg) -> None:
self.cfg = cfg
def reg_hand(self, poses, **kwargs):
"regulizer for hand pose"
assert self.cfg.model in ['smplh', 'smplx']
hand_poses = poses[:, 66:78]
loss = funcl2(hand_poses)
return loss/poses.shape[0]
def reg_head(self, poses, **kwargs):
"regulizer for head pose"
assert self.cfg.model in ['smplx']
poses = poses[:, 78:]
loss = funcl2(poses)
return loss/poses.shape[0]
def reg_expr(self, expression, **kwargs):
"regulizer for expression"
assert self.cfg.model in ['smplh', 'smplx']
return torch.sum(expression**2)
def reg_body(self, poses, **kwargs):
"regulizer for body poses"
if self.cfg.model in ['smplh', 'smplx']:
poses = poses[:, :66]
loss = funcl2(poses)
return loss/poses.shape[0]
def __str__(self) -> str:
return 'Loss function for Regulizer of Poses'
class LossRegPosesZero:
def __init__(self, keypoints, cfg) -> None:
model_type = cfg.model
if keypoints.shape[-2] <= 15:
use_feet = False
use_head = False
else:
use_feet = keypoints[..., [19, 20, 21, 22, 23, 24], -1].sum() > 0.1
use_head = keypoints[..., [15, 16, 17, 18], -1].sum() > 0.1
if model_type == 'smpl':
SMPL_JOINT_ZERO_IDX = [3, 6, 9, 10, 11, 13, 14, 20, 21, 22, 23]
elif model_type == 'smplh':
SMPL_JOINT_ZERO_IDX = [3, 6, 9, 10, 11, 13, 14]
elif model_type == 'smplx':
SMPL_JOINT_ZERO_IDX = [3, 6, 9, 10, 11, 13, 14]
else:
raise NotImplementedError
if not use_feet:
SMPL_JOINT_ZERO_IDX.extend([7, 8])
if not use_head:
SMPL_JOINT_ZERO_IDX.extend([12, 15])
SMPL_POSES_ZERO_IDX = [[j for j in range(3*i, 3*i+3)] for i in SMPL_JOINT_ZERO_IDX]
SMPL_POSES_ZERO_IDX = sum(SMPL_POSES_ZERO_IDX, [])
# SMPL_POSES_ZERO_IDX.extend([36, 37, 38, 45, 46, 47])
self.idx = SMPL_POSES_ZERO_IDX
def __call__(self, poses, **kwargs):
"regulizer for zero joints"
return torch.sum(torch.abs(poses[:, self.idx]))/poses.shape[0]
def __str__(self) -> str:
return 'Loss function for Regulizer of Poses'
class LossSmoothBody:
def __init__(self, cfg) -> None:
self.norm = 'l2'
def __call__(self, kpts_est, **kwargs):
N_BODY = min(25, kpts_est.shape[1])
assert kpts_est.shape[0] > 1, 'If you use smooth loss, it must be more than 1 frames'
if self.norm == 'l2':
loss = funcl2(kpts_est[:-1, :N_BODY] - kpts_est[1:, :N_BODY])
else:
loss = funcl1(kpts_est[:-1, :N_BODY] - kpts_est[1:, :N_BODY])
return loss/kpts_est.shape[0]
def __str__(self) -> str:
return 'Loss function for Smooth of Body'
class LossSmoothBodyMean:
def __init__(self, cfg) -> None:
self.cfg = cfg
def smooth(self, kpts_est, **kwargs):
"smooth body"
kpts_interp = kpts_est.clone().detach()
kpts_interp[1:-1] = (kpts_interp[:-2] + kpts_interp[2:])/2
loss = funcl2(kpts_est[1:-1] - kpts_interp[1:-1])
return loss/(kpts_est.shape[0] - 2)
def body(self, kpts_est, **kwargs):
"smooth body"
return self.smooth(kpts_est[:, :25])
def hand(self, kpts_est, **kwargs):
"smooth body"
return self.smooth(kpts_est[:, 25:25+42])
def __str__(self) -> str:
return 'Loss function for Smooth of Body'
class LossSmoothPoses:
def __init__(self, nViews, nFrames, cfg=None) -> None:
self.nViews = nViews
self.nFrames = nFrames
self.norm = 'l2'
self.cfg = cfg
def _poses(self, poses):
"smooth poses"
loss = 0
for nv in range(self.nViews):
poses_ = poses[nv*self.nFrames:(nv+1)*self.nFrames, ]
# 计算poses插值
poses_interp = poses_.clone().detach()
poses_interp[1:-1] = (poses_interp[1:-1] + poses_interp[:-2] + poses_interp[2:])/3
loss += funcl2(poses_[1:-1] - poses_interp[1:-1])
return loss/(self.nFrames-2)/self.nViews
def poses(self, poses, **kwargs):
"smooth body poses"
if self.cfg.model in ['smplh', 'smplx']:
poses = poses[:, :66]
return self._poses(poses)
def hands(self, poses, **kwargs):
"smooth hand poses"
if self.cfg.model in ['smplh', 'smplx']:
poses = poses[:, 66:66+12]
else:
raise NotImplementedError
return self._poses(poses)
def head(self, poses, **kwargs):
"smooth head poses"
if self.cfg.model == 'smplx':
poses = poses[:, 66+12:]
else:
raise NotImplementedError
return self._poses(poses)
def __str__(self) -> str:
return 'Loss function for Smooth of Body'
class LossSmoothBodyMulti(LossSmoothBody):
def __init__(self, dimGroups, cfg) -> None:
super().__init__(cfg)
self.cfg = cfg
self.dimGroups = dimGroups
def __call__(self, kpts_est, **kwargs):
"Smooth body"
assert kpts_est.shape[0] > 1, 'If you use smooth loss, it must be more than 1 frames'
loss = 0
for nv in range(len(self.dimGroups) - 1):
kpts = kpts_est[self.dimGroups[nv]:self.dimGroups[nv+1]]
loss += super().__call__(kpts_est=kpts)
return loss/(len(self.dimGroups) - 1)
def __str__(self) -> str:
return 'Loss function for Multi Smooth of Body'
class LossSmoothPosesMulti:
def __init__(self, dimGroups, cfg) -> None:
self.dimGroups = dimGroups
self.norm = 'l2'
def __call__(self, poses, **kwargs):
"Smooth poses"
loss = 0
for nv in range(len(self.dimGroups) - 1):
poses_ = poses[self.dimGroups[nv]:self.dimGroups[nv+1]]
poses_interp = poses_.clone().detach()
poses_interp[1:-1] = (poses_interp[1:-1] + poses_interp[:-2] + poses_interp[2:])/3
loss += funcl2(poses_[1:-1] - poses_interp[1:-1])/(poses_.shape[0] - 2)
return loss/(len(self.dimGroups) - 1)
def __str__(self) -> str:
return 'Loss function for Multi Smooth of Poses'
class LossRepro:
def __init__(self, bboxes, keypoints2d, cfg) -> None:
device = cfg.device
bbox_sizes = np.maximum(bboxes[..., 2] - bboxes[..., 0], bboxes[..., 3] - bboxes[..., 1])
# 这里的valid不是一维的因为不清楚总共有多少维所以不能遍历去做
bbox_conf = bboxes[..., 4]
bbox_mean_axis = -1
bbox_sizes = (bbox_sizes * bbox_conf).sum(axis=bbox_mean_axis)/(1e-3 + bbox_conf.sum(axis=bbox_mean_axis))
bbox_sizes = bbox_sizes[..., None, None, None]
# 抑制掉完全不可见的视角将其置信度设成0
bbox_sizes[bbox_sizes < 10] = 1e6
inv_bbox_sizes = torch.Tensor(1./bbox_sizes).to(device)
keypoints2d = torch.Tensor(keypoints2d).to(device)
self.keypoints2d = keypoints2d[..., :2]
self.conf = keypoints2d[..., 2:] * inv_bbox_sizes * 100
self.norm = 'gm'
def __call__(self, img_points):
residual = (img_points - self.keypoints2d) * self.conf
# squared_res: (nFrames, nJoints, 2)
if self.norm == 'l2':
squared_res = residual ** 2
elif self.norm == 'l1':
squared_res = torch.abs(residual)
elif self.norm == 'gm':
squared_res = gmof(residual**2, 200)
else:
import ipdb; ipdb.set_trace()
return torch.sum(squared_res)
class LossInit:
def __init__(self, params, cfg) -> None:
self.norm = 'l2'
self.poses = torch.Tensor(params['poses']).to(cfg.device)
self.shapes = torch.Tensor(params['shapes']).to(cfg.device)
def init_poses(self, poses, **kwargs):
"distance to poses_0"
if self.norm == 'l2':
return torch.sum((poses - self.poses)**2)/poses.shape[0]
def init_shapes(self, shapes, **kwargs):
"distance to shapes_0"
if self.norm == 'l2':
return torch.sum((shapes - self.shapes)**2)/shapes.shape[0]
class LossKeypointsMV2D(LossRepro):
def __init__(self, keypoints2d, bboxes, Pall, cfg) -> None:
"""
Args:
keypoints2d (ndarray): (nViews, nFrames, nJoints, 3)
bboxes (ndarray): (nViews, nFrames, 5)
"""
super().__init__(bboxes, keypoints2d, cfg)
assert Pall.shape[0] == keypoints2d.shape[0] and Pall.shape[0] == bboxes.shape[0], \
'check you P shape: {} and keypoints2d shape: {}'.format(Pall.shape, keypoints2d.shape)
device = cfg.device
self.Pall = torch.Tensor(Pall).to(device)
self.nViews, self.nFrames, self.nJoints = keypoints2d.shape[:3]
self.kpt_homo = torch.ones((self.nFrames, self.nJoints, 1), device=device)
def __call__(self, kpts_est, **kwargs):
"reprojection loss for multiple views"
# kpts_est: (nFrames, nJoints, 3+1), P: (nViews, 3, 4)
# => projection: (nViews, nFrames, nJoints, 3)
kpts_homo = torch.cat([kpts_est[..., :self.nJoints, :], self.kpt_homo], dim=2)
point_cam = torch.einsum('vab,fnb->vfna', self.Pall, kpts_homo)
img_points = point_cam[..., :2]/point_cam[..., 2:]
return super().__call__(img_points)/self.nViews/self.nFrames
def __str__(self) -> str:
return 'Loss function for Reprojection error'
class SMPLAngleLoss:
def __init__(self, keypoints, model_type='smpl'):
if keypoints.shape[1] <= 15:
use_feet = False
use_head = False
else:
use_feet = keypoints[:, [19, 20, 21, 22, 23, 24], -1].sum() > 0.1
use_head = keypoints[:, [15, 16, 17, 18], -1].sum() > 0.1
if model_type == 'smpl':
SMPL_JOINT_ZERO_IDX = [3, 6, 9, 10, 11, 13, 14, 20, 21, 22, 23]
elif model_type == 'smplh':
SMPL_JOINT_ZERO_IDX = [3, 6, 9, 10, 11, 13, 14]
elif model_type == 'smplx':
SMPL_JOINT_ZERO_IDX = [3, 6, 9, 10, 11, 13, 14]
else:
raise NotImplementedError
if not use_feet:
SMPL_JOINT_ZERO_IDX.extend([7, 8])
if not use_head:
SMPL_JOINT_ZERO_IDX.extend([12, 15])
SMPL_POSES_ZERO_IDX = [[j for j in range(3*i, 3*i+3)] for i in SMPL_JOINT_ZERO_IDX]
SMPL_POSES_ZERO_IDX = sum(SMPL_POSES_ZERO_IDX, [])
# SMPL_POSES_ZERO_IDX.extend([36, 37, 38, 45, 46, 47])
self.idx = SMPL_POSES_ZERO_IDX
def loss(self, poses):
return torch.sum(torch.abs(poses[:, self.idx]))
def SmoothLoss(body_params, keys, weight_loss, span=4, model_type='smpl'):
spans = [i for i in range(1, span)]
span_weights = {i:1/i for i in range(1, span)}
span_weights = {key: i/sum(span_weights) for key, i in span_weights.items()}
loss_dict = {}
nFrames = body_params['poses'].shape[0]
nPoses = body_params['poses'].shape[1]
if model_type == 'smplh' or model_type == 'smplx':
nPoses = 66
for key in ['poses', 'Th', 'poses_hand', 'expression']:
if key not in keys:
continue
k = 'smooth_' + key
if k in weight_loss.keys() and weight_loss[k] > 0.:
loss_dict[k] = 0.
for span in spans:
if key == 'poses_hand':
val = torch.sum((body_params['poses'][span:, 66:] - body_params['poses'][:nFrames-span, 66:])**2)
else:
val = torch.sum((body_params[key][span:, :nPoses] - body_params[key][:nFrames-span, :nPoses])**2)
loss_dict[k] += span_weights[span] * val
k = 'smooth_' + key + '_l1'
if k in weight_loss.keys() and weight_loss[k] > 0.:
loss_dict[k] = 0.
for span in spans:
if key == 'poses_hand':
val = torch.sum((body_params['poses'][span:, 66:] - body_params['poses'][:nFrames-span, 66:]).abs())
else:
val = torch.sum((body_params[key][span:, :nPoses] - body_params[key][:nFrames-span, :nPoses]).abs())
loss_dict[k] += span_weights[span] * val
# smooth rotation
rot = batch_rodrigues(body_params['Rh'])
key, k = 'Rh', 'smooth_Rh'
if key in keys and k in weight_loss.keys() and weight_loss[k] > 0.:
loss_dict[k] = 0.
for span in spans:
val = torch.sum((rot[span:, :] - rot[:nFrames-span, :])**2)
loss_dict[k] += span_weights[span] * val
return loss_dict
def RegularizationLoss(body_params, body_params_init, weight_loss):
loss_dict = {}
for key in ['poses', 'shapes', 'Th', 'hands', 'head', 'expression']:
if 'init_'+key in weight_loss.keys() and weight_loss['init_'+key] > 0.:
if key == 'poses':
loss_dict['init_'+key] = torch.sum((body_params[key][:, :66] - body_params_init[key][:, :66])**2)
elif key == 'hands':
loss_dict['init_'+key] = torch.sum((body_params['poses'][: , 66:66+12] - body_params_init['poses'][:, 66:66+12])**2)
elif key == 'head':
loss_dict['init_'+key] = torch.sum((body_params['poses'][: , 78:78+9] - body_params_init['poses'][:, 78:78+9])**2)
elif key in body_params.keys():
loss_dict['init_'+key] = torch.sum((body_params[key] - body_params_init[key])**2)
for key in ['poses', 'shapes', 'hands', 'head', 'expression']:
if 'reg_'+key in weight_loss.keys() and weight_loss['reg_'+key] > 0.:
if key == 'poses':
loss_dict['reg_'+key] = torch.sum((body_params[key][:, :66])**2)
elif key == 'hands':
loss_dict['reg_'+key] = torch.sum((body_params['poses'][: , 66:66+12])**2)
elif key == 'head':
loss_dict['reg_'+key] = torch.sum((body_params['poses'][: , 78:78+9])**2)
elif key in body_params.keys():
loss_dict['reg_'+key] = torch.sum((body_params[key])**2)
return loss_dict