EasyMocap/easymocap/pyfitting/lossfactory.py
2021-04-14 15:22:51 +08:00

419 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

'''
@ Date: 2020-11-19 17:46:04
@ Author: Qing Shuai
@ LastEditors: Qing Shuai
@ LastEditTime: 2021-04-14 11:46:56
@ FilePath: /EasyMocap/easymocap/pyfitting/lossfactory.py
'''
import numpy as np
import torch
from .operation import projection, batch_rodrigues
def funcl2(x):
    """Return the sum of squared entries of ``x`` (squared L2 norm)."""
    return torch.sum(x ** 2)


def funcl1(x):
    """Return the sum of absolute entries of ``x`` (L1 norm).

    The previous implementation took ``abs(x**2)``, which is identical to the
    squared L2 sum, so selecting the 'l1' norm silently behaved like 'l2'.
    """
    return torch.sum(torch.abs(x))
def gmof(squared_res, sigma_squared):
    """Geman-McClure robust error function.

    Computes ``sigma^2 * r^2 / (sigma^2 + r^2)`` element-wise, where
    ``squared_res`` is the already-squared residual and ``sigma_squared`` the
    squared scale parameter.
    """
    numerator = sigma_squared * squared_res
    denominator = sigma_squared + squared_res
    return numerator / denominator
def ReprojectionLoss(keypoints3d, keypoints2d, K, Rc, Tc, inv_bbox_sizes, norm='l2'):
    """Confidence-weighted reprojection error between projected 3D keypoints and 2D detections.

    Args:
        keypoints3d: 3D keypoints, forwarded unchanged to ``projection``.
        keypoints2d: 3-axis array of detections; the first two channels of the
            last axis are used as image coordinates and the last channel as
            the per-joint confidence.
        K, Rc, Tc: camera parameters, forwarded unchanged to ``projection``.
        inv_bbox_sizes: factor multiplied into the per-element error
            (presumably the inverse bounding-box size -- confirm with caller).
        norm: ``'l2'`` (squared residual) or ``'l1'`` (absolute residual).

    Returns:
        Scalar tensor: the sum of the weighted per-element errors.

    Raises:
        NotImplementedError: if ``norm`` is neither ``'l2'`` nor ``'l1'``.
    """
    # Fail fast on an unsupported norm instead of dropping into a debugger.
    if norm not in ('l2', 'l1'):
        raise NotImplementedError('unsupported norm: {}'.format(norm))
    img_points = projection(keypoints3d, K, Rc, Tc)
    residual = (img_points - keypoints2d[:, :, :2]) * keypoints2d[:, :, -1:]
    # per-element error, same shape as residual
    if norm == 'l2':
        weighted_error = (residual ** 2) * inv_bbox_sizes
    else:
        weighted_error = torch.abs(residual) * inv_bbox_sizes
    return torch.sum(weighted_error)
class LossKeypoints3D:
    """Distance between estimated 3D keypoints and a fixed set of target keypoints.

    The target array is split along its last axis into coordinates (first
    three channels) and confidence (remaining channels); the confidence
    multiplies the coordinate difference before the norm is applied.
    """

    def __init__(self, keypoints3d, cfg, norm='l2') -> None:
        self.cfg = cfg
        self.norm = norm
        target = torch.Tensor(keypoints3d).to(cfg.device)
        self.nFrames = target.shape[0]
        self.nJoints = target.shape[1]
        self.keypoints3d = target[..., :3]
        self.conf = target[..., 3:]

    def loss(self, diff_square):
        """Reduce a weighted difference with the configured norm, divided by the frame count.

        Despite its name, ``diff_square`` is the (confidence-weighted)
        difference itself; squaring happens inside the selected norm.
        """
        norm = self.norm
        if norm not in ('l2', 'l1', 'gm'):
            raise NotImplementedError
        if norm == 'gm':
            # threshold set to 0.2^2 (meters)
            total = torch.sum(gmof(diff_square**2, 0.04))
        else:
            reducer = funcl2 if norm == 'l2' else funcl1
            total = reducer(diff_square)
        return total / self.nFrames

    def body(self, kpts_est, **kwargs):
        """Loss on the body joints: at most the first 25 joints shared by both inputs."""
        count = min([kpts_est.shape[1], self.keypoints3d.shape[1], 25])
        delta = kpts_est[:, :count, :3] - self.keypoints3d[:, :count, :3]
        return self.loss(delta * self.conf[:, :count])

    def hand(self, kpts_est, **kwargs):
        """Loss on the 42 joints following the first 25 (the hand keypoints)."""
        begin, end = 25, 25 + 42
        delta = kpts_est[:, begin:end, :3] - self.keypoints3d[:, begin:end, :3]
        return self.loss(delta * self.conf[:, begin:end])

    def face(self, kpts_est, **kwargs):
        """Loss on every joint after the first 25 + 42 (the face keypoints)."""
        begin = 25 + 42
        delta = kpts_est[:, begin:, :3] - self.keypoints3d[:, begin:, :3]
        return self.loss(delta * self.conf[:, begin:])

    def __str__(self) -> str:
        return 'Loss function for keypoints3D, norm = {}'.format(self.norm)
class LossRegPoses:
    """Squared-magnitude regularizers on slices of the pose vector and on the expression."""

    def __init__(self, cfg) -> None:
        self.cfg = cfg

    def reg_hand(self, poses, **kwargs):
        """Regularize pose columns 66..77 (hand pose); requires an smplh/smplx model."""
        assert self.cfg.model in ['smplh', 'smplx']
        nFrames = poses.shape[0]
        return funcl2(poses[:, 66:78]) / nFrames

    def reg_head(self, poses, **kwargs):
        """Regularize pose columns from 78 onward (head pose); requires an smplx model."""
        assert self.cfg.model in ['smplx']
        head_poses = poses[:, 78:]
        return funcl2(head_poses) / head_poses.shape[0]

    def reg_expr(self, expression, **kwargs):
        """Regularize the expression: plain sum of squares, not divided by the frame count."""
        assert self.cfg.model in ['smplh', 'smplx']
        return torch.sum(expression**2)

    def reg_body(self, poses, **kwargs):
        """Regularize body poses; for smplh/smplx only the first 66 columns are used."""
        body_poses = poses[:, :66] if self.cfg.model in ['smplh', 'smplx'] else poses
        return funcl2(body_poses) / body_poses.shape[0]

    def __str__(self) -> str:
        return 'Loss function for Regulizer of Poses'
class LossRegPosesZero:
    """L1 penalty pulling a fixed subset of pose parameters toward zero.

    The subset depends on the model type and on whether the supplied
    keypoints carry any confidence for the feet (indices 19-24) and the head
    (indices 15-18).
    """

    def __init__(self, keypoints, cfg) -> None:
        model_type = cfg.model
        # With 15 or fewer joints there are no feet/head keypoints to look at.
        has_extremities = keypoints.shape[-2] > 15
        if has_extremities:
            use_feet = keypoints[..., [19, 20, 21, 22, 23, 24], -1].sum() > 0.1
            use_head = keypoints[..., [15, 16, 17, 18], -1].sum() > 0.1
        else:
            use_feet = False
            use_head = False

        if model_type == 'smpl':
            zero_joints = [3, 6, 9, 10, 11, 13, 14, 20, 21, 22, 23]
        elif model_type in ('smplh', 'smplx'):
            zero_joints = [3, 6, 9, 10, 11, 13, 14]
        else:
            raise NotImplementedError

        if not use_feet:
            zero_joints += [7, 8]
        if not use_head:
            zero_joints += [12, 15]

        # Each joint owns three consecutive pose parameters.
        self.idx = [3 * joint + axis for joint in zero_joints for axis in range(3)]

    def __call__(self, poses, **kwargs):
        """Sum of absolute values of the selected pose columns, divided by the frame count."""
        selected = poses[:, self.idx]
        return torch.sum(torch.abs(selected)) / poses.shape[0]

    def __str__(self) -> str:
        return 'Loss function for Regulizer of Poses'
class LossSmoothBody:
    """Temporal smoothness loss on the body keypoints (at most the first 25 joints)."""

    def __init__(self, cfg) -> None:
        # cfg is accepted for interface parity but not used here.
        self.norm = 'l2'

    def __call__(self, kpts_est, **kwargs):
        """Penalize differences between consecutive frames, divided by the frame count."""
        nBody = min(25, kpts_est.shape[1])
        assert kpts_est.shape[0] > 1, 'If you use smooth loss, it must be more than 1 frames'
        frame_delta = kpts_est[:-1, :nBody] - kpts_est[1:, :nBody]
        reducer = funcl2 if self.norm == 'l2' else funcl1
        return reducer(frame_delta) / kpts_est.shape[0]

    def __str__(self) -> str:
        return 'Loss function for Smooth of Body'
class LossSmoothBodyMean:
    """Smoothness loss pulling each interior frame toward the mean of its two neighbours."""

    def __init__(self, cfg) -> None:
        self.cfg = cfg

    def smooth(self, kpts_est, **kwargs):
        """Squared distance of interior frames to the neighbour mean, per interior frame.

        The neighbour mean is detached, so it acts as a constant target.
        """
        neighbour_mean = ((kpts_est[:-2] + kpts_est[2:]) / 2).detach()
        total = funcl2(kpts_est[1:-1] - neighbour_mean)
        return total / (kpts_est.shape[0] - 2)

    def body(self, kpts_est, **kwargs):
        """Smooth the first 25 joints (body)."""
        return self.smooth(kpts_est[:, :25])

    def hand(self, kpts_est, **kwargs):
        """Smooth the 42 joints after the first 25 (hands)."""
        return self.smooth(kpts_est[:, 25:25+42])

    def __str__(self) -> str:
        return 'Loss function for Smooth of Body'
class LossSmoothPoses:
    """Temporal smoothness loss on pose parameters stacked view after view.

    ``poses`` is expected to hold ``nViews`` consecutive chunks of ``nFrames``
    rows each. ``cfg`` is optional; when it is ``None`` no model type is
    assumed, so ``poses()`` smooths every column and ``hands()`` / ``head()``
    raise ``NotImplementedError`` (previously they crashed with
    ``AttributeError`` on ``None.model``).
    """

    def __init__(self, nViews, nFrames, cfg=None) -> None:
        self.nViews = nViews
        self.nFrames = nFrames
        self.norm = 'l2'
        self.cfg = cfg

    def _model_type(self):
        """Return ``cfg.model``, or ``None`` when no cfg was supplied."""
        return self.cfg.model if self.cfg is not None else None

    def _poses(self, poses):
        """Squared distance of interior frames to a detached 3-frame moving average.

        Accumulated over views, then divided by the number of interior frames
        and by the number of views.
        """
        loss = 0
        for nv in range(self.nViews):
            poses_ = poses[nv*self.nFrames:(nv+1)*self.nFrames, ]
            # interpolated poses: mean of previous, current and next frame,
            # detached so it acts as a constant target
            poses_interp = ((poses_[1:-1] + poses_[:-2] + poses_[2:]) / 3).detach()
            loss += funcl2(poses_[1:-1] - poses_interp)
        return loss/(self.nFrames-2)/self.nViews

    def poses(self, poses, **kwargs):
        """Smooth body poses; for smplh/smplx only the first 66 columns are used."""
        if self._model_type() in ['smplh', 'smplx']:
            poses = poses[:, :66]
        return self._poses(poses)

    def hands(self, poses, **kwargs):
        """Smooth hand poses (columns 66..77); only defined for smplh/smplx.

        Raises:
            NotImplementedError: for any other (or missing) model type.
        """
        if self._model_type() in ['smplh', 'smplx']:
            poses = poses[:, 66:66+12]
        else:
            raise NotImplementedError
        return self._poses(poses)

    def head(self, poses, **kwargs):
        """Smooth head poses (columns from 78 onward); only defined for smplx.

        Raises:
            NotImplementedError: for any other (or missing) model type.
        """
        if self._model_type() == 'smplx':
            poses = poses[:, 66+12:]
        else:
            raise NotImplementedError
        return self._poses(poses)

    def __str__(self) -> str:
        return 'Loss function for Smooth of Body'
class LossSmoothBodyMulti(LossSmoothBody):
    """Body smoothness loss averaged over several frame groups.

    ``dimGroups`` holds the group boundaries: group ``i`` covers rows
    ``dimGroups[i]:dimGroups[i+1]`` of the keypoint tensor.
    """

    def __init__(self, dimGroups, cfg) -> None:
        super().__init__(cfg)
        self.cfg = cfg
        self.dimGroups = dimGroups

    def __call__(self, kpts_est, **kwargs):
        """Apply the parent smoothness loss to every group and return the mean."""
        assert kpts_est.shape[0] > 1, 'If you use smooth loss, it must be more than 1 frames'
        single_group_loss = super().__call__
        nGroups = len(self.dimGroups) - 1
        total = 0
        for group in range(nGroups):
            begin, end = self.dimGroups[group], self.dimGroups[group + 1]
            total += single_group_loss(kpts_est=kpts_est[begin:end])
        return total / nGroups

    def __str__(self) -> str:
        return 'Loss function for Multi Smooth of Body'
class LossSmoothPosesMulti:
    """Pose smoothness loss averaged over several frame groups.

    ``dimGroups`` holds the group boundaries: group ``i`` covers rows
    ``dimGroups[i]:dimGroups[i+1]`` of the pose tensor.
    """

    def __init__(self, dimGroups, cfg) -> None:
        # cfg is accepted for interface parity but not stored.
        self.dimGroups = dimGroups
        self.norm = 'l2'

    def __call__(self, poses, **kwargs):
        """Mean over groups of the per-interior-frame distance to a 3-frame moving average."""
        nGroups = len(self.dimGroups) - 1
        total = 0
        for group in range(nGroups):
            chunk = poses[self.dimGroups[group]:self.dimGroups[group + 1]]
            # detached moving average of previous, current and next frame
            moving_avg = ((chunk[1:-1] + chunk[:-2] + chunk[2:]) / 3).detach()
            total += funcl2(chunk[1:-1] - moving_avg) / (chunk.shape[0] - 2)
        return total / nGroups

    def __str__(self) -> str:
        return 'Loss function for Multi Smooth of Poses'
class LossRepro:
    """Reprojection loss against 2D detections, weighted by confidence and bounding-box size.

    Args:
        bboxes: array whose last axis holds ``(x0, y0, x1, y1, conf)``, as
            implied by the indices read below.
        keypoints2d: array whose last axis holds ``(x, y, conf)``.
        cfg: configuration object; only ``cfg.device`` is read.
    """

    def __init__(self, bboxes, keypoints2d, cfg) -> None:
        device = cfg.device
        # bbox size = the longer of width and height
        bbox_sizes = np.maximum(bboxes[..., 2] - bboxes[..., 0], bboxes[..., 3] - bboxes[..., 1])
        # The validity mask is not one-dimensional; since the total number of
        # dimensions is unknown we cannot iterate over it, hence the
        # confidence-weighted mean over the last axis.
        bbox_conf = bboxes[..., 4]
        bbox_mean_axis = -1
        bbox_sizes = (bbox_sizes * bbox_conf).sum(axis=bbox_mean_axis)/(1e-3 + bbox_conf.sum(axis=bbox_mean_axis))
        bbox_sizes = bbox_sizes[..., None, None, None]
        # Suppress completely invisible views: a huge size makes the inverse
        # size, and hence the confidence, effectively zero.
        bbox_sizes[bbox_sizes < 10] = 1e6
        inv_bbox_sizes = torch.Tensor(1./bbox_sizes).to(device)
        keypoints2d = torch.Tensor(keypoints2d).to(device)
        self.keypoints2d = keypoints2d[..., :2]
        self.conf = keypoints2d[..., 2:] * inv_bbox_sizes * 100
        self.norm = 'gm'

    def __call__(self, img_points):
        """Return the summed robust error between ``img_points`` and the stored detections.

        Raises:
            NotImplementedError: if ``self.norm`` is not one of 'l2', 'l1', 'gm'
                (previously this dropped into an ``ipdb`` debugger session).
        """
        residual = (img_points - self.keypoints2d) * self.conf
        # per-element error, same shape as residual
        if self.norm == 'l2':
            squared_res = residual ** 2
        elif self.norm == 'l1':
            squared_res = torch.abs(residual)
        elif self.norm == 'gm':
            squared_res = gmof(residual**2, 200)
        else:
            raise NotImplementedError('unsupported norm: {}'.format(self.norm))
        return torch.sum(squared_res)
class LossInit:
    """Squared distance of poses / shapes to the initial values stored at construction."""

    def __init__(self, params, cfg) -> None:
        self.norm = 'l2'
        device = cfg.device
        self.poses = torch.Tensor(params['poses']).to(device)
        self.shapes = torch.Tensor(params['shapes']).to(device)

    def init_poses(self, poses, **kwargs):
        """Squared distance to the initial poses, divided by the first dimension of ``poses``.

        Returns ``None`` when the norm is anything other than 'l2'.
        """
        if self.norm != 'l2':
            return None
        offset = poses - self.poses
        return torch.sum(offset**2) / poses.shape[0]

    def init_shapes(self, shapes, **kwargs):
        """Squared distance to the initial shapes, divided by the first dimension of ``shapes``.

        Returns ``None`` when the norm is anything other than 'l2'.
        """
        if self.norm != 'l2':
            return None
        offset = shapes - self.shapes
        return torch.sum(offset**2) / shapes.shape[0]
class LossKeypointsMV2D(LossRepro):
    """Multi-view reprojection loss: project 3D keypoints with every camera and compare to 2D."""

    def __init__(self, keypoints2d, bboxes, Pall, cfg) -> None:
        """
        Args:
            keypoints2d (ndarray): (nViews, nFrames, nJoints, 3)
            bboxes (ndarray): (nViews, nFrames, 5)
            Pall: per-view projection matrices, (nViews, 3, 4)
            cfg: configuration object; ``cfg.device`` is read.
        """
        super().__init__(bboxes, keypoints2d, cfg)
        nViewsP = Pall.shape[0]
        assert nViewsP == keypoints2d.shape[0] and nViewsP == bboxes.shape[0], \
            'check you P shape: {} and keypoints2d shape: {}'.format(Pall.shape, keypoints2d.shape)
        device = cfg.device
        self.Pall = torch.Tensor(Pall).to(device)
        nViews, nFrames, nJoints = keypoints2d.shape[:3]
        self.nViews = nViews
        self.nFrames = nFrames
        self.nJoints = nJoints
        # Column of ones appended to the 3D points to make them homogeneous.
        self.kpt_homo = torch.ones((nFrames, nJoints, 1), device=device)

    def __call__(self, kpts_est, **kwargs):
        """Reprojection loss over all views, divided by the number of views and frames."""
        # kpts_est: (nFrames, nJoints, 3), P: (nViews, 3, 4)
        # => projection: (nViews, nFrames, nJoints, 3)
        joints = kpts_est[..., :self.nJoints, :]
        homogeneous = torch.cat([joints, self.kpt_homo], dim=2)
        cam_points = torch.einsum('vab,fnb->vfna', self.Pall, homogeneous)
        # perspective division
        pixels = cam_points[..., :2] / cam_points[..., 2:]
        total = super().__call__(pixels)
        return total / self.nViews / self.nFrames

    def __str__(self) -> str:
        return 'Loss function for Reprojection error'
class SMPLAngleLoss:
def __init__(self, keypoints, model_type='smpl'):
if keypoints.shape[1] <= 15:
use_feet = False
use_head = False
else:
use_feet = keypoints[:, [19, 20, 21, 22, 23, 24], -1].sum() > 0.1
use_head = keypoints[:, [15, 16, 17, 18], -1].sum() > 0.1
if model_type == 'smpl':
SMPL_JOINT_ZERO_IDX = [3, 6, 9, 10, 11, 13, 14, 20, 21, 22, 23]
elif model_type == 'smplh':
SMPL_JOINT_ZERO_IDX = [3, 6, 9, 10, 11, 13, 14]
elif model_type == 'smplx':
SMPL_JOINT_ZERO_IDX = [3, 6, 9, 10, 11, 13, 14]
else:
raise NotImplementedError
if not use_feet:
SMPL_JOINT_ZERO_IDX.extend([7, 8])
if not use_head:
SMPL_JOINT_ZERO_IDX.extend([12, 15])
SMPL_POSES_ZERO_IDX = [[j for j in range(3*i, 3*i+3)] for i in SMPL_JOINT_ZERO_IDX]
SMPL_POSES_ZERO_IDX = sum(SMPL_POSES_ZERO_IDX, [])
# SMPL_POSES_ZERO_IDX.extend([36, 37, 38, 45, 46, 47])
self.idx = SMPL_POSES_ZERO_IDX
def loss(self, poses):
return torch.sum(torch.abs(poses[:, self.idx]))
def SmoothLoss(body_params, keys, weight_loss, span=4, model_type='smpl'):
    """Multi-span temporal smoothness losses on body parameters.

    For every frame offset ``1 .. span-1`` the difference between frames that
    far apart is penalized, with closer offsets weighted more strongly.

    Args:
        body_params: dict of per-frame tensors; ``'poses'`` is always read,
            other entries (``'Th'``, ``'expression'``, ``'Rh'``) only when
            their loss is requested.
        keys: names of the parameters to smooth.
        weight_loss: dict of loss weights; a loss ``smooth_<key>`` (squared)
            or ``smooth_<key>_l1`` (absolute) is computed only if present
            with a weight > 0.
        span: offsets ``1 .. span-1`` are used.
        model_type: for 'smplh' / 'smplx' only the first 66 columns are
            smoothed for the non-hand keys.

    Returns:
        dict mapping loss name to its (unweighted) value.
    """
    spans = list(range(1, span))
    # NOTE: each offset gets weight (1/offset) divided by the sum of the
    # offsets themselves (1+2+...), not by the sum of the 1/offset weights.
    # That is what the original normalisation evaluated to; it is kept so
    # existing loss weights keep their scale.
    offset_total = sum(spans)
    span_weights = {offset: (1/offset)/offset_total for offset in spans}
    loss_dict = {}
    nFrames = body_params['poses'].shape[0]
    nPoses = body_params['poses'].shape[1]
    if model_type == 'smplh' or model_type == 'smplx':
        nPoses = 66

    def frame_diff(key, offset):
        # Difference between frames `offset` apart for the requested parameter.
        if key == 'poses_hand':
            later = body_params['poses'][offset:, 66:]
            earlier = body_params['poses'][:nFrames-offset, 66:]
        else:
            later = body_params[key][offset:, :nPoses]
            earlier = body_params[key][:nFrames-offset, :nPoses]
        return later - earlier

    def enabled(name):
        return name in weight_loss.keys() and weight_loss[name] > 0.

    for key in ['poses', 'Th', 'poses_hand', 'expression']:
        if key not in keys:
            continue
        k = 'smooth_' + key
        if enabled(k):
            loss_dict[k] = 0.
            for offset in spans:
                val = torch.sum(frame_diff(key, offset)**2)
                loss_dict[k] += span_weights[offset] * val
        k = 'smooth_' + key + '_l1'
        if enabled(k):
            loss_dict[k] = 0.
            for offset in spans:
                val = torch.sum(frame_diff(key, offset).abs())
                loss_dict[k] += span_weights[offset] * val
    # smooth rotation
    key, k = 'Rh', 'smooth_Rh'
    if key in keys and enabled(k):
        # Convert only when the loss is requested, so callers without 'Rh'
        # (or with the loss disabled) neither fail nor pay for the conversion.
        rot = batch_rodrigues(body_params['Rh'])
        loss_dict[k] = 0.
        for offset in spans:
            val = torch.sum((rot[offset:, :] - rot[:nFrames-offset, :])**2)
            loss_dict[k] += span_weights[offset] * val
    return loss_dict
def RegularizationLoss(body_params, body_params_init, weight_loss):
    """Squared-error regularizers toward the initial parameters and toward zero.

    For every name with a weight > 0 in ``weight_loss``:
      * ``init_<key>``: squared distance between current and initial values;
      * ``reg_<key>``: squared magnitude of the current values.

    ``poses``, ``hands`` and ``head`` are column ranges 0:66, 66:78 and 78:87
    of ``body_params['poses']``; any other key is used whole and silently
    skipped when absent from ``body_params``.

    Returns:
        dict mapping loss name to its (unweighted) value.
    """
    pose_columns = {
        'poses': slice(0, 66),
        'hands': slice(66, 66+12),
        'head': slice(78, 78+9),
    }
    loss_dict = {}

    def enabled(name):
        return name in weight_loss.keys() and weight_loss[name] > 0.

    for key in ['poses', 'shapes', 'Th', 'hands', 'head', 'expression']:
        name = 'init_'+key
        if not enabled(name):
            continue
        if key in pose_columns:
            cols = pose_columns[key]
            delta = body_params['poses'][:, cols] - body_params_init['poses'][:, cols]
        elif key in body_params.keys():
            delta = body_params[key] - body_params_init[key]
        else:
            continue
        loss_dict[name] = torch.sum(delta**2)

    for key in ['poses', 'shapes', 'hands', 'head', 'expression']:
        name = 'reg_'+key
        if not enabled(name):
            continue
        if key in pose_columns:
            values = body_params['poses'][:, pose_columns[key]]
        elif key in body_params.keys():
            values = body_params[key]
        else:
            continue
        loss_dict[name] = torch.sum(values**2)
    return loss_dict