246 lines
9.5 KiB
Python
246 lines
9.5 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
|
|
class GMoF(nn.Module):
    """Robust, saturating penalty on squared residuals.

    The squared residual ``d`` (summed over the last dimension) is mapped to
    ``d / (d + rho**2)``, so each term lies in ``[0, 1)`` and large residuals
    stop dominating the loss. The class name suggests the Geman-McClure
    function; note that no ``rho**2`` scale factor is applied here.

    Args:
        rho: scale of the penalty; residuals much larger than ``rho``
            contribute close to 1.
    """

    def __init__(self, rho=1):
        super(GMoF, self).__init__()
        # Keep rho itself as well as its square: extra_repr() prints it, and
        # without this attribute repr(module) raises AttributeError.
        self.rho = rho
        self.rho2 = rho * rho

    def extra_repr(self):
        """Return the text shown inside ``GMoF(...)`` when the module is printed."""
        return 'rho = {}'.format(self.rho)

    def forward(self, est, gt=None, conf=None):
        """Compute the robust loss.

        Args:
            est: estimated values; the last dimension is summed out.
            gt: optional reference broadcastable against ``est``. When
                omitted, ``est`` itself is treated as the residual.
            conf: optional weights multiplied with the per-element penalty
                (i.e. shaped like ``est`` without its last dimension).

        Returns:
            A scalar tensor: the confidence-weighted mean when ``conf`` is
            given, otherwise the plain mean over all elements.
        """
        residual = est if gt is None else est - gt
        square_diff = torch.sum(residual ** 2, dim=-1)
        diff = torch.div(square_diff, square_diff + self.rho2)
        if conf is None:
            return diff.sum() / diff.numel()
        # The small epsilon keeps the result finite when every weight is zero.
        return torch.sum(diff * conf) / (1e-5 + conf.sum())
|
|
|
|
class BaseLoss(nn.Module):
    """Base class that builds the element-wise loss callable used by subclasses.

    The callable is stored in ``self.loss`` and has the signature
    ``loss(est, gt=None, conf=None)``.

    Args:
        norm: ``'l2'`` (squared error), ``'l1'`` (absolute error) or ``'gm'``
            (``GMoF``). Any other value leaves ``self.loss`` as ``None``.
        norm_info: extra argument for the norm. For ``'gm'`` it is passed to
            ``GMoF`` as ``rho``; ``None`` or an empty dict selects ``GMoF``'s
            default.
        reduce: ``'sum'`` sums over the last dimension, anything else
            averages over it.
    """

    def __init__(self, norm='l2', norm_info=None, reduce='sum') -> None:
        super().__init__()
        self.loss = self.make_loss(norm, norm_info, reduce)

    def make_loss(self, norm='l2', norm_info=None, reduce='sum'):
        """Build and return the loss callable for ``norm``.

        Returns:
            A function (``'l2'`` / ``'l1'``), a ``GMoF`` module (``'gm'``) or
            ``None`` for an unrecognised ``norm``.
        """
        reducer = torch.sum if reduce == 'sum' else torch.mean

        # Plain ``==`` comparisons on purpose: ``norm`` may be an unhashable
        # object (e.g. a list), which must fall through to ``None`` quietly.
        if norm == 'l2':
            def penalty(residual):
                return residual ** 2
        elif norm == 'l1':
            penalty = torch.abs
        elif norm == 'gm':
            # The historical ``{}`` default cannot be used as rho (it would
            # fail on ``rho * rho``), so fall back to GMoF's own default.
            if norm_info is None or (isinstance(norm_info, dict) and not norm_info):
                return GMoF()
            return GMoF(norm_info)
        else:
            return None

        def loss(est, gt=None, conf=None):
            # Without a reference, ``est`` is already the residual.
            residual = est if gt is None else est - gt
            per_item = reducer(penalty(residual), dim=-1)
            if conf is None:
                return per_item.sum() / per_item.numel()
            # Confidence-weighted mean; epsilon guards an all-zero ``conf``.
            return torch.sum(per_item * conf) / (1e-5 + conf.sum())

        return loss

    def forward(self, pred, target):
        """Placeholder; subclasses implement the actual loss. Returns ``None``."""
        pass
|
|
|
|
class BaseKeypoints(BaseLoss):
    """Shared helpers for keypoint losses: joint selection and a weighted L2.

    Args:
        index_est: explicit joint indices to pick from the estimate.
        index_gt: explicit joint indices to pick from the target.
        ranges_est: ``[start, stop]`` joint range for the estimate, used only
            when ``index_est`` is empty; ``stop == -1`` means "to the end".
        ranges_gt: same as ``ranges_est`` but for the target.
        **kwargs: forwarded to ``BaseLoss``.
    """

    @staticmethod
    def select(keypoints, index, ranges):
        """Pick joints along the second-to-last axis.

        ``index`` takes precedence over ``ranges``; when both are empty the
        input is returned unchanged.
        """
        # ``len()`` rather than truthiness: ``index`` may be an array.
        if len(index) > 0:
            return keypoints[..., index, :]
        if len(ranges) > 0:
            # -1 is the "open end" marker, not "all but the last joint".
            stop = None if ranges[1] == -1 else ranges[1]
            return keypoints[..., ranges[0]:stop, :]
        return keypoints

    def __init__(self, index_est=None, index_gt=None,
                 ranges_est=None, ranges_gt=None, **kwargs):
        super().__init__(**kwargs)
        # Fresh lists per instance, so that one instance's selection can
        # never leak into another through a shared default object.
        self.index_est = [] if index_est is None else index_est
        self.index_gt = [] if index_gt is None else index_gt
        self.ranges_est = [] if ranges_est is None else ranges_est
        self.ranges_gt = [] if ranges_gt is None else ranges_gt

    def forward(self, pred, target):
        """Delegate to the parent class' ``forward``."""
        return super().forward(pred, target)

    def loss_keypoints(self, pred, target, conf):
        """Confidence-weighted mean of squared distances.

        Per the original shape notes: ``pred`` and ``target`` are
        ``(..., dim)`` and ``conf`` is ``(..., 1)``.
        """
        dist = torch.sum((pred - target) ** 2, dim=-1, keepdim=True)
        # Clamp the normaliser so that an all-zero confidence gives 0 instead
        # of NaN; for any ordinary confidence sum the value is unchanged.
        total_conf = torch.clamp(torch.sum(conf), min=1e-5)
        return torch.sum(dist * conf) / total_conf
|
|
|
|
class Keypoints2D(BaseKeypoints):
    """Reprojection loss between projected 3D keypoints and 2D detections.

    The detections are mapped through the inverse intrinsics and compared
    against the predicted joints projected with ``[R | T]``.
    """

    def forward(self, pred, target):
        """Return the confidence-weighted reprojection loss.

        ``pred['keypoints']`` is expected as (nFrames, nJoints, 3) per the
        original note. The last channel of ``target['keypoints']`` is used as
        the confidence; ``target['cameras']`` must provide ``R``, ``T``, ``K``.
        """
        joints3d = self.select(pred['keypoints'], self.index_est, self.ranges_est)
        detections = self.select(target['keypoints'], self.index_gt, self.ranges_gt)
        cams = target['cameras']

        # Extrinsics [R | T] and the transposed inverse intrinsics.
        extrinsics = torch.cat([cams['R'], cams['T']], dim=-1)
        k_inv_t = torch.inverse(cams['K']).transpose(-1, -2)

        # Detections: swap everything after (x, y) for ones, then undo K.
        ones_2d = torch.ones_like(detections[..., 2:])
        detections_h = torch.cat([detections[..., :2], ones_2d], dim=-1)
        target_points = torch.matmul(detections_h, k_inv_t)[..., :2]

        # Predictions: homogeneous coordinates, camera transform, then
        # perspective division by the third component.
        ones_3d = torch.ones_like(joints3d[..., :1])
        joints_h = torch.cat([joints3d, ones_3d], dim=-1)
        self.einsum = 'fab,fjb->fja'
        in_camera = torch.einsum(self.einsum, extrinsics, joints_h)
        img_points = in_camera[..., :2] / in_camera[..., 2:]

        return self.loss(est=img_points, gt=target_points, conf=detections[..., -1])
|
|
|
|
class Keypoints3D(BaseKeypoints):
    """Loss between predicted 3D keypoints and 3D targets with confidence."""

    def forward(self, pred, target):
        """Return the confidence-weighted 3D keypoint loss.

        ``pred['keypoints']`` is expected as (nFrames, nJoints, 3) per the
        original note; ``target['keypoints3d']`` must have 4 channels, the
        last one being the confidence.
        """
        estimate = self.select(pred['keypoints'], self.index_est, self.ranges_est)
        reference = self.select(target['keypoints3d'], self.index_gt, self.ranges_gt)
        assert reference.shape[-1] == 4, 'Target keypoints {} must have confidence '.format(reference.shape)
        # Split the target into coordinates and per-joint confidence.
        coords = reference[..., :3]
        weights = reference[..., -1]
        return self.loss(est=estimate, gt=coords, conf=weights)
|
|
|
|
class LimbLength(BaseKeypoints):
    """Penalise differences between predicted and target limb lengths.

    A limb is one row of ``kintree``: a pair of joint indices whose distance
    is measured in both the prediction and the target.

    Args:
        kintree: sequence of ``[joint_a, joint_b]`` index pairs.
        key: accepted for interface compatibility; not read by this class.
        **kwargs: forwarded to ``BaseKeypoints``.
    """

    def __init__(self, kintree, key='keypoints3d', **kwargs):
        self.kintree = np.array(kintree)
        super().__init__(**kwargs)

    def __str__(self):
        return "Limb of: {}".format(','.join(['[{},{}]'.format(i, j) for (i, j) in self.kintree]))

    def forward(self, pred, target):
        """Return the confidence-weighted limb-length loss.

        The last channel of ``target['keypoints3d']`` is read as per-joint
        confidence; a limb's confidence is the smaller of its two joints'.
        """
        pred_kpts3d = pred['keypoints']
        target_kpts3d = target['keypoints3d']
        # Select both limb endpoints through the kinematic tree.
        first, second = self.kintree[:, 0], self.kintree[:, 1]

        # Measure target lengths on the coordinate channels only. The target
        # carries a trailing confidence channel, and letting it into the norm
        # would distort the length whenever the two joints' confidences differ.
        n_coord = pred_kpts3d.shape[-1]
        target_xyz = target_kpts3d[..., :n_coord]

        pred_len = torch.norm(
            pred_kpts3d[..., second, :] - pred_kpts3d[..., first, :],
            dim=-1, keepdim=True)
        target_len = torch.norm(
            target_xyz[..., second, :] - target_xyz[..., first, :],
            dim=-1, keepdim=True)
        target_conf = torch.minimum(
            target_kpts3d[..., second, -1], target_kpts3d[..., first, -1])
        return self.loss(est=pred_len, gt=target_len, conf=target_conf)
|
|
|
|
class Smooth(BaseLoss):
    """Temporal smoothness penalties over several predicted quantities.

    One loss term is created per entry of ``keys``; all per-term lists
    (``smooth_type``, ``order``, ``norm``, ``weights``) are indexed in
    parallel with ``keys``.

    Args:
        keys: names of the entries of ``pred`` to smooth.
        smooth_type: per key, one of ``'Linear'``, ``'Rot'`` or ``'Depth'``.
        order: per key, ``2`` penalises the second difference, anything else
            the first difference.
        norm: per key norm name (stored in the term's config).
        weights: per key multiplier applied to the term's total.
        window_weight: weights for frame offsets ``1, 2, ...``.
    """

    def __init__(self, keys, smooth_type, order, norm, weights, window_weight) -> None:
        super().__init__(norm)
        # Replace the single callable set by BaseLoss with a table of terms.
        self.loss = {}
        for i, key in enumerate(keys):
            new_key = key + '_' + smooth_type[i]
            self.loss[new_key] = {
                # NOTE(review): the penalty is always summed squared error;
                # norm[i] is only recorded. Confirm this is intended.
                'func': self.make_loss(norm='l2', norm_info={}, reduce='sum'),
                'key': key,
                'weight': weights[i],
                'norm': norm[i],
                'order': order[i],
                'type': smooth_type[i],
            }
        self.window_weight = window_weight

    def convert_Rh_to_R(self, Rh):
        """Convert packed axis-angle vectors to flattened rotation matrices.

        Args:
            Rh: tensor of shape ``(..., nRot * 3)``.

        Returns:
            Tensor of shape ``(..., nRot * 9)``.
        """
        from ..bodymodels.geometry import batch_rodrigues
        nRot = Rh.shape[-1] // 3
        # Assumes batch_rodrigues yields one 3x3 matrix per axis-angle row
        # (TODO confirm against its definition).
        Rot = batch_rodrigues(Rh.reshape(-1, 3))
        # Keep every rotation of a frame: nine values per rotation. Reshaping
        # to (..., 3, 3) only worked when nRot == 1.
        return Rot.reshape(*Rh.shape[:-1], nRot * 9)

    def _signal_for(self, cfg, value, target):
        """Return the per-frame quantity whose differences are penalised."""
        kind = cfg['type']
        if kind == 'Linear':
            return value
        if kind == 'Rot':
            return self.convert_Rh_to_R(value)
        if kind == 'Depth':
            # TODO: take the camera R/T into account.
            if 'cameras' in target.keys():
                R = target['cameras']['R']
                value = torch.bmm(value[..., None, :], R.transpose(-1, -2))
                value = value[..., 0, :]
            # NOTE(review): without cameras the unrotated value is used; this
            # case previously failed on an unbound variable. Confirm the
            # fallback is the desired behaviour.
            return value[..., [2]]  # depth component only
        # An unknown type used to fail on an unbound name or silently reuse
        # the previous term's velocities.
        raise ValueError('Unknown smooth type: {}'.format(kind))

    def forward(self, pred, target):
        """Return a dict mapping ``'<key>_<type>'`` to its weighted loss."""
        ret = {}
        for key, cfg in self.loss.items():
            # The conversion does not depend on the window width: do it once.
            signal = self._signal_for(cfg, pred[cfg['key']], target)
            second_order = cfg['order'] == 2
            loss = 0
            for width, weight in enumerate(self.window_weight, start=1):
                # Difference between frames ``width`` steps apart.
                vel = signal[width:] - signal[:-width]
                if second_order:
                    vel = vel[1:] - vel[:-1]
                loss += weight * cfg['func'](est=vel)
            ret[key] = loss * cfg['weight']
        return ret
|
|
|
|
class AnySmooth(BaseLoss):
    """Temporal smoothness penalty on a single entry of ``pred``.

    ``weight`` holds one weight per frame offset (1, 2, ...); ``order == 2``
    penalises the second difference instead of the first.
    """

    def __init__(self, key, weight, norm, norm_info={}, dim=-1, order=1):
        super().__init__()
        self.key = key
        self.weight = weight
        self.order = order
        self.dim = dim
        self.norm_name = norm
        self.loss = self.make_loss(norm, norm_info)

    def forward(self, pred, target):
        """Return the weighted sum of the per-offset smoothness losses."""
        sequence = pred[self.key]
        # Too few frames for the widest offset: contribute a zero loss.
        if sequence.shape[0] <= len(self.weight):
            return torch.FloatTensor([0.]).to(sequence.device)

        total = 0
        for offset, w in enumerate(self.weight, start=1):
            delta = sequence[offset:] - sequence[:-offset]
            if self.order == 2:
                delta = delta[1:] - delta[:-1]
            total += w * self.loss(delta)
        return total
|
|
|
|
class Init(BaseLoss):
    """Keep selected parameters close to their initial values.

    For every key, ``pred[key]`` is compared with ``target['init_' + key]``.
    ``weights`` is stored but not applied inside ``forward``.
    """

    def __init__(self, keys, weights, norm) -> None:
        super().__init__(norm)
        self.weights = weights
        self.keys = keys

    def forward(self, pred, target):
        """Return a dict of mean squared differences, one entry per key."""
        return {
            name: torch.mean((pred[name] - target['init_' + name]) ** 2)
            for name in self.keys
        }
|
|
|
|
from easymocap.multistage.lossbase import AnyReg
|
|
class RegLoss(AnyReg):
    """Adapter that lets ``AnyReg`` be called with ``(pred, target)``."""

    def __init__(self, key, norm) -> None:
        super().__init__(key, norm)

    def __call__(self, pred, target):
        """Forward only ``pred[self.key]``, passed as a keyword named after the key."""
        selected = {self.key: pred[self.key]}
        return self.forward(**selected)
|
|
|
|
class Init_pose(Init):
    """Pull selected parameters towards ``target['target_' + key]``.

    Uses a summed squared error for ``norm == 'l2'`` and a summed absolute
    error for ``norm == 'l1'``; any other norm produces no entries.
    """

    def __init__(self, keys, weights, norm) -> None:
        super().__init__(keys, weights, norm)
        self.norm = norm

    def forward(self, pred, target):
        """Return a dict with one summed penalty per key."""
        losses = {}
        for name in self.keys:
            if self.norm == 'l1':
                gap = pred[name] - target['target_' + name]
                losses[name] = torch.sum(torch.abs(gap))
            elif self.norm == 'l2':
                gap = pred[name] - target['target_' + name]
                losses[name] = torch.sum(gap ** 2)
        return losses