EasyMocap/myeasymocap/operations/loss.py
2023-06-19 16:39:27 +08:00

246 lines
9.5 KiB
Python

import torch
import torch.nn as nn
import numpy as np
class GMoF(nn.Module):
def __init__(self, rho=1):
super(GMoF, self).__init__()
self.rho2 = rho * rho
def extra_repr(self):
return 'rho = {}'.format(self.rho)
def forward(self, est, gt=None, conf=None):
if gt is not None:
square_diff = torch.sum((est - gt)**2, dim=-1)
else:
square_diff = torch.sum(est**2, dim=-1)
diff = torch.div(square_diff, square_diff + self.rho2)
if conf is not None:
res = torch.sum(diff * conf)/(1e-5 + conf.sum())
else:
res = diff.sum()/diff.numel()
return res
class BaseLoss(nn.Module):
def __init__(self, norm='l2', norm_info={}, reduce='sum') -> None:
super().__init__()
self.loss = self.make_loss(norm, norm_info, reduce)
def make_loss(self, norm='l2', norm_info={}, reduce='sum'):
reduce = torch.sum if reduce=='sum' else torch.mean
if norm == 'l2':
def loss(est, gt=None, conf=None):
if gt is not None:
square_diff = reduce((est - gt)**2, dim=-1)
else:
square_diff = reduce(est**2, dim=-1)
if conf is not None:
res = torch.sum(square_diff * conf)/(1e-5 + conf.sum())
else:
res = square_diff.sum()/square_diff.numel()
return res
elif norm == 'l1':
def loss(est, gt=None, conf=None):
if gt is not None:
square_diff = reduce(torch.abs(est - gt), dim=-1)
else:
square_diff = reduce(torch.abs(est), dim=-1)
if conf is not None:
res = torch.sum(square_diff * conf)/(1e-5 + conf.sum())
else:
res = square_diff.sum()/square_diff.numel()
return res
elif norm == 'gm':
loss = GMoF(norm_info)
else:
loss = None
return loss
def forward(self, pred, target):
pass
class BaseKeypoints(BaseLoss):
@staticmethod
def select(keypoints, index, ranges):
if len(index) > 0:
keypoints = keypoints[..., index, :]
elif len(ranges) > 0:
if ranges[1] == -1:
keypoints = keypoints[..., ranges[0]:, :]
else:
keypoints = keypoints[..., ranges[0]:ranges[1], :]
return keypoints
def __init__(self, index_est=[], index_gt=[],
ranges_est=[], ranges_gt=[], **kwargs):
super().__init__(**kwargs)
self.index_est = index_est
self.index_gt = index_gt
self.ranges_est = ranges_est
self.ranges_gt = ranges_gt
def forward(self, pred, target):
return super().forward(pred, target)
def loss_keypoints(self, pred, target, conf):
# pred: (..., dim)
# target: (..., dim)
# conf: (..., 1)
dist = torch.sum((pred - target)**2, dim=-1, keepdim=True)
loss = torch.sum(dist * conf) / torch.sum(conf)
return loss
class Keypoints2D(BaseKeypoints):
def forward(self, pred, target):
# (nFrames, nJoints, 3)
pred_kpts3d = self.select(pred['keypoints'] , self.index_est, self.ranges_est)
target_kpts2d = self.select(target['keypoints'], self.index_gt, self.ranges_gt)
cameras = target['cameras']
P = torch.cat([cameras['R'], cameras['T']], dim=-1)
invKtrans = torch.inverse(cameras['K']).transpose(-1, -2)
homo = torch.cat([target_kpts2d[..., :2], torch.ones_like(target_kpts2d[..., 2:])], dim=-1)
target_points = torch.matmul(homo, invKtrans)[..., :2]
pred_homo = torch.cat([pred_kpts3d, torch.ones_like(pred_kpts3d[..., :1])], dim=-1)
self.einsum = 'fab,fjb->fja'
point_cam = torch.einsum(self.einsum, P, pred_homo)
img_points = point_cam[..., :2]/point_cam[..., 2:]
loss = self.loss(est=img_points, gt=target_points, conf=target_kpts2d[..., -1])
return loss
class Keypoints3D(BaseKeypoints):
def forward(self, pred, target):
# (nFrames, nJoints, 3)
# breakpoint()
pred_kpts3d = self.select(pred['keypoints'] , self.index_est, self.ranges_est)
target_kpts3d = self.select(target['keypoints3d'], self.index_gt, self.ranges_gt)
assert target_kpts3d.shape[-1] == 4, 'Target keypoints {} must have confidence '.format(target_kpts3d.shape)
loss = self.loss(est=pred_kpts3d, gt=target_kpts3d[...,:3], conf=target_kpts3d[..., -1])
return loss
class LimbLength(BaseKeypoints):
def __init__(self, kintree, key='keypoints3d', **kwargs):
self.kintree = np.array(kintree)
super().__init__(**kwargs)
def __str__(self):
return "Limb of: {}".format(','.join(['[{},{}]'.format(i,j) for (i,j) in self.kintree]))
def forward(self, pred, target):
pred_kpts3d = pred['keypoints']
target_kpts3d = target['keypoints3d']
# 用kin tree来进行选择
pred = torch.norm(pred_kpts3d[..., self.kintree[:, 1], :] - pred_kpts3d[..., self.kintree[:, 0], :], dim=-1, keepdim=True)
target = torch.norm(target_kpts3d[..., self.kintree[:, 1], :] - target_kpts3d[..., self.kintree[:, 0], :], dim=-1, keepdim=True)
target_conf = torch.minimum(target_kpts3d[..., self.kintree[:, 1], -1], target_kpts3d[..., self.kintree[:, 0], -1])
loss = self.loss(est=pred, gt=target, conf=target_conf)
return loss
class Smooth(BaseLoss):
def __init__(self, keys, smooth_type, order, norm, weights, window_weight) -> None:
super().__init__(norm)
self.loss = {}
for i in range(len(keys)):
new_key = keys[i] + '_' + smooth_type[i]
self.loss[new_key] = {
'func': self.make_loss(norm='l2', norm_info={}, reduce='sum'),
'key': keys[i],
'weight': weights[i],
'norm': norm[i],
'order': order[i],
'type': smooth_type[i],
}
self.window_weight = window_weight
def convert_Rh_to_R(self, Rh):
from ..bodymodels.geometry import batch_rodrigues
# Rh: (..., nRot x 3)
nRot = Rh.shape[-1] // 3
Rh_flat = Rh.reshape(-1, nRot, 3)
Rh_flat = Rh_flat.reshape(-1, 3)
Rot = batch_rodrigues(Rh_flat)
Rot_0 = Rot.reshape(-1, nRot, 3, 3)
Rot = Rot_0.reshape(*Rh.shape[:-1], 3, 3)
Rot = Rot.reshape(*Rh.shape[:-1], 9)
return Rot
def forward(self, pred, target):
ret = {}
for key, cfg in self.loss.items():
value = pred[cfg['key']]
loss = 0
for width, weight in enumerate(self.window_weight, start=1):
if cfg['type'] == 'Linear':
vel = value[width:] - value[:-width]
elif cfg['type'] == 'Rot':
_value = self.convert_Rh_to_R(value)
vel = _value[width:] - _value[:-width]
elif cfg['type'] == 'Depth':
# TODO: 考虑相机的RT
if 'cameras' in target.keys():
R = target['cameras']['R']
_value = torch.bmm(value[..., None, :], R.transpose(-1, -2))
_value = _value[..., 0, :]
_value = _value[..., [2]] # 只使用深度
vel = _value[width:] - _value[:-width]
if cfg['order'] == 2:
vel = vel[1:] - vel[:-1]
loss += weight * cfg['func'](est=vel)
ret[key] = loss * cfg['weight']
return ret
class AnySmooth(BaseLoss):
def __init__(self, key, weight, norm, norm_info={}, dim=-1, order=1):
super().__init__()
self.dim = dim
self.weight = weight
self.loss = self.make_loss(norm, norm_info)
self.norm_name = norm
self.key = key
self.order = order
def forward(self, pred, target):
loss = 0
value = pred[self.key]
# value = select(value, self.ranges, self.index, self.dim)
if value.shape[0] <= len(self.weight):
return torch.FloatTensor([0.]).to(value.device)
for width, weight in enumerate(self.weight, start=1):
vel = value[width:] - value[:-width]
if self.order == 2:
vel = vel[1:] - vel[:-1]
loss += weight * self.loss(vel)
return loss
class Init(BaseLoss):
def __init__(self, keys, weights, norm) -> None:
super().__init__(norm)
self.keys = keys
self.weights = weights
def forward(self, pred, target):
ret = {}
for key in self.keys:
ret[key] = torch.mean((pred[key] - target['init_'+key])**2)
return ret
from easymocap.multistage.lossbase import AnyReg
class RegLoss(AnyReg):
def __init__(self, key, norm) -> None:
super().__init__(key, norm)
def __call__(self, pred, target):
return self.forward(**{self.key: pred[self.key]})
class Init_pose(Init):
def __init__(self, keys, weights, norm) -> None:
super().__init__(keys, weights, norm)
self.norm = norm
def forward(self, pred, target):
ret = {}
for key in self.keys:
if self.norm == 'l2':
ret[key] = torch.sum((pred[key] - target['target_'+key])**2)
elif self.norm == 'l1':
ret[key] = torch.sum(torch.abs(pred[key] - target['target_'+key]))
return ret