589 lines
22 KiB
Python
589 lines
22 KiB
Python
|
import numpy as np
|
||
|
import torch.nn as nn
|
||
|
import torch
|
||
|
from ..bodymodel.lbs import batch_rodrigues
|
||
|
|
||
|
class LossBase(nn.Module):
    """Shared base class for the loss terms defined in this module.

    Subclasses implement ``forward``. ``check_at_start`` / ``check_at_end``
    are optional diagnostic hooks (several subclasses use them to cache a
    "before" table and print it next to "after" values); here they do nothing.
    """

    def __init__(self):
        super(LossBase, self).__init__()

    def __str__(self) -> str:
        placeholder = '# lack of comment'
        return placeholder

    def check_at_start(self, **kwargs):
        """Diagnostic hook for the "before" state; no-op in the base class."""
        return None

    def check_at_end(self, **kwargs):
        """Diagnostic hook for the "after" state; no-op in the base class."""
        return None
|
||
|
|
||
|
class GMoF(nn.Module):
    """Geman-McClure style robust penalty ``d / (d + rho**2)``.

    ``d`` is the squared norm over the last axis of ``est - gt`` (or of
    ``est`` when ``gt`` is None). The per-point penalties are averaged, either
    weighted by ``conf`` (``sum(p * conf) / (1e-5 + sum(conf))``) or
    uniformly when ``conf`` is None.

    Args:
        rho: scale parameter; penalties saturate towards 1 for ``d >> rho**2``.
    """

    def __init__(self, rho=1):
        super(GMoF, self).__init__()
        # Keep rho itself so extra_repr() can report it (previously only rho2
        # was stored and repr(module) raised AttributeError).
        self.rho = rho
        self.rho2 = rho * rho

    def extra_repr(self):
        return 'rho = {}'.format(self.rho)

    def forward(self, est, gt=None, conf=None):
        """Return the scalar (confidence-weighted) mean penalty."""
        if gt is not None:
            square_diff = torch.sum((est - gt)**2, dim=-1)
        else:
            square_diff = torch.sum(est**2, dim=-1)
        diff = torch.div(square_diff, square_diff + self.rho2)
        if conf is not None:
            # 1e-5 guards against an all-zero confidence mask.
            res = torch.sum(diff * conf)/(1e-5 + conf.sum())
        else:
            res = diff.sum()/diff.numel()
        return res
|
||
|
|
||
|
def make_loss(norm, norm_info, reduce='sum'):
    """Build a loss callable ``loss(est, gt=None, conf=None)``.

    Args:
        norm: 'l2' for a (confidence-weighted) mean of squared errors, or
            'gm' for the robust ``GMoF`` penalty.
        norm_info: ``rho`` for 'gm'; ignored for 'l2'.
        reduce: how the 'l2' loss collapses the last axis: 'sum' (default);
            any other value means mean. Ignored for 'gm'.

    Returns:
        For 'l2', a closure; for 'gm', a ``GMoF`` module. With ``conf`` the
        result is ``sum(err * conf) / (1e-5 + sum(conf))``; without it, the
        plain mean of ``err``. Here ``err`` is the per-point error reduced over
        the last axis of ``est - gt`` (or of ``est`` when ``gt`` is None).

    Raises:
        ValueError: if ``norm`` is neither 'l2' nor 'gm' (previously this fell
            through to an UnboundLocalError).
    """
    reduce_fn = torch.sum if reduce == 'sum' else torch.mean
    if norm == 'l2':
        def loss(est, gt=None, conf=None):
            diff = est if gt is None else est - gt
            square_diff = reduce_fn(diff**2, dim=-1)
            if conf is not None:
                # 1e-5 guards against an all-zero confidence mask.
                return torch.sum(square_diff * conf)/(1e-5 + conf.sum())
            return square_diff.sum()/square_diff.numel()
        return loss
    if norm == 'gm':
        return GMoF(norm_info)
    raise ValueError("Unknown norm '{}': expected 'l2' or 'gm'".format(norm))
|
||
|
|
||
|
def select(value, ranges, index, dim):
    """Sub-select entries of ``value`` by a slice or by explicit indices.

    Args:
        value: array/tensor to select from.
        ranges: ``[start, end]`` slice applied to the LAST axis regardless of
            ``dim`` (``end == -1`` means "to the end"). Takes precedence over
            ``index`` when non-empty.
        index: explicit indices applied along axis ``dim``.
        dim: axis for ``index``. Any axis is supported; negative values count
            from the end (previously only -1 and -2 worked and any other axis
            was silently ignored).

    Returns:
        The selected view/copy, or ``value`` itself when both are empty.
    """
    if len(ranges) > 0:
        start, end = ranges[0], ranges[1]
        if end == -1:
            return value[..., start:]
        return value[..., start:end]
    if len(index) > 0:
        if dim < 0:
            # e.g. dim=-1 -> value[..., index]; dim=-2 -> value[..., index, :]
            key = (Ellipsis, index) + (slice(None),) * (-dim - 1)
        else:
            key = (slice(None),) * dim + (index,)
        return value[key]
    return value
|
||
|
|
||
|
def print_table(header, contents):
    """Print column-major statistics as a table with a trailing 'Mean' row.

    Args:
        header: column titles passed to ``tabulate``.
        contents: list of columns. ``contents[0]`` holds the row labels; every
            other column holds numeric values. Floats are shown with
            ``{:6.2f}``, everything else with ``str``. The mean row averages
            each numeric column over the number of label rows.
    """
    from tabulate import tabulate

    def _fmt(cell):
        return '{:6.2f}'.format(cell) if isinstance(cell, float) else '{}'.format(cell)

    nrows = len(contents[0])
    rows = [[_fmt(column[r]) for column in contents] for r in range(nrows)]
    rows.append(['Mean'] + ['{:6.2f}'.format(sum(column) / nrows) for column in contents[1:]])
    print(tabulate(rows, header, tablefmt='fancy_grid'))
|
||
|
|
||
|
class AnyReg(LossBase):
    """Regularization loss on one named input, optionally relative to an initial value.

    Args:
        key: name of the keyword argument that ``forward`` reads.
        norm: loss type handed to ``make_loss`` ('l2' or 'gm').
        dim: axis used by ``select`` when ``index`` is given.
        reduce: 'sum' (default) or any other value for a mean over the last
            axis; forwarded to ``make_loss``.
        norm_info: extra parameter for ``make_loss`` (``rho`` for 'gm');
            ``None`` is treated as ``{}``.
        ranges: optional ``[start, end]`` slice of the last axis (``end == -1``
            means "to the end"); ``None`` means no slicing.
        index: optional explicit indices along ``dim``; ``None`` means none.
        **kwargs: if ``'init_' + key`` is present, it is stored as the buffer
            ``init`` and subtracted from the input before the loss.
    """

    def __init__(self, key, norm, dim=-1, reduce='sum', norm_info=None, ranges=None, index=None, **kwargs):
        super().__init__()
        # None defaults avoid sharing one mutable list/dict across instances.
        self.ranges = [] if ranges is None else ranges
        self.index = [] if index is None else index
        self.key = key
        self.dim = dim
        init_name = 'init_' + key
        if init_name in kwargs:
            self.register_buffer('init', torch.Tensor(kwargs[init_name]))
        else:
            self.init = None
        self.norm_name = norm
        self.loss = make_loss(norm, {} if norm_info is None else norm_info, reduce=reduce)

    def forward(self, **kwargs):
        """Return the loss of ``kwargs[self.key]`` (minus ``init`` when set).

        value: (nFrames, ..., nDims)
        """
        value = kwargs[self.key]
        if self.init is not None:
            value = value - self.init
        value = select(value, self.ranges, self.index, self.dim)
        return self.loss(value)

    def __str__(self) -> str:
        # The format string previously had one placeholder for two arguments,
        # so the norm name was silently dropped.
        return 'Loss for {}, norm={}'.format(self.key, self.norm_name)
|
||
|
|
||
|
class RegPrior(AnyReg):
    """Hand-tuned prior on a flat pose vector with per-entry penalty types.

    Entry ``3 * joint + axis`` of ``poses`` is penalized according to
    ``_PENALTIES``:
      - 'l2'   : mean squared value, weight 0.1
      - 'L2'   : mean squared value, weight 1
      - 'exp'  : exp(value)
      - '-exp' : exp(-value)
    The two exponential terms share a weight of 0.0005 and are normalized by
    their combined number of entries. A 0.01-weighted mean square over all
    entries is added. The ``init`` offset from ``AnyReg`` is disabled.
    """

    # (joint index, axis) -> penalty type. The exponential penalties are
    # one-sided: 'exp' grows for positive values, '-exp' for negative ones.
    _PENALTIES = {
        (2, 0): '-exp',
        (2, 1): 'l2',
        (2, 2): 'l2',
        (3, 0): '-exp',  # knee
        (3, 1): 'L2',
        (3, 2): 'L2',
        (4, 0): '-exp',  # knee
        (4, 1): 'L2',
        (4, 2): 'L2',
        (5, 0): '-exp',
        (5, 1): 'l2',
        (5, 2): 'l2',
        (6, 1): 'L2',
        (6, 2): 'L2',
        (7, 1): 'L2',
        (7, 2): 'L2',
        (8, 0): '-exp',
        (8, 1): 'l2',
        (8, 2): 'l2',
        (9, 0): 'L2',
        (9, 1): 'L2',
        (9, 2): 'L2',
        (10, 0): 'L2',
        (10, 1): 'L2',
        (10, 2): 'L2',
        (12, 0): 'l2',  # front of the shoulder joint
        (13, 0): 'l2',
        (17, 1): 'exp',
        (17, 2): 'L2',
        (18, 1): '-exp',
        (18, 2): 'L2',
    }

    def __init__(self, **cfg):
        super().__init__(**cfg)
        self.init = None  # disable init
        self.l2dims = []
        self.L2dims = []
        self.expdims = []
        self.nexpdims = []
        buckets = {'l2': self.l2dims, 'L2': self.L2dims,
                   '-exp': self.nexpdims, 'exp': self.expdims}
        for (nj, ndim), penalty in self._PENALTIES.items():
            buckets[penalty].append(nj * 3 + ndim)

    def forward(self, poses, **kwargs):
        """Return the scalar prior energy.

        poses: (..., nDims). Leading axes are flattened, so a 2D
        (nFrames, nDims) input behaves exactly as before.
        """
        # The terms below index with poses[:, dims]; flatten first so inputs
        # with more (or fewer) leading axes are handled as documented.
        poses = poses.reshape(-1, poses.shape[-1])
        n_frames = poses.shape[0]
        alll2loss = torch.mean(poses**2)
        l2loss = torch.sum(poses[:, self.l2dims]**2)/len(self.l2dims)/n_frames
        L2loss = torch.sum(poses[:, self.L2dims]**2)/len(self.L2dims)/n_frames
        exploss = torch.sum(torch.exp(poses[:, self.expdims]))/n_frames
        nexploss = torch.sum(torch.exp(-poses[:, self.nexpdims]))/n_frames
        n_exp = len(self.expdims) + len(self.nexpdims)
        return 0.1*l2loss + L2loss + 0.0005*(exploss + nexploss)/n_exp + 0.01*alll2loss
|
||
|
|
||
|
class VPoserPrior(AnyReg):
    """Pose prior based on a VPoser encode/decode round trip.

    The first 63 pose entries are encoded with VPoser. A latent is obtained
    via ``.sample()`` (so the loss is presumably stochastic — confirm against
    human_body_prior), decoded, and the residual between input and
    reconstruction is scored with the loss configured in ``AnyReg``. The
    ``init`` offset from ``AnyReg`` is disabled.

    Extra optional config key:
        vposer_ckpt: checkpoint directory passed to ``load_model``; defaults
            to 'data/bodymodels/vposer_v02'.
    """

    def __init__(self, **cfg):
        vposer_ckpt = cfg.pop('vposer_ckpt', 'data/bodymodels/vposer_v02')
        super().__init__(**cfg)
        # Imported here rather than at module level, presumably so the
        # dependency is only required when this prior is actually used.
        from human_body_prior.tools.model_loader import load_model
        from human_body_prior.models.vposer_model import VPoser
        vposer, _ = load_model(vposer_ckpt,
                               model_code=VPoser,
                               remove_words_in_model_weights='vp_model.',
                               disable_grad=True)
        vposer.eval()
        self.vposer = vposer
        self.init = None  # disable init

    def forward(self, poses, **kwargs):
        """Return the loss of ``poses_body - decode(sample(encode(poses_body)))``.

        poses: (..., nDims) with nDims >= 63; only the first 63 entries are used.
        """
        n_dims = 63
        poses_body = poses[..., :n_dims].reshape(-1, n_dims)
        latent = self.vposer.encode(poses_body)
        # Reshape to poses_body's own (flattened) shape; using poses.shape[0]
        # broke for inputs with more than one leading axis. (A previously
        # unreachable branch scored latent.mean instead of this residual.)
        recon = self.vposer.decode(latent.sample())['pose_body'].reshape(poses_body.shape)
        # AnyReg.forward reads kwargs[self.key]; pass the residual under it.
        return super().forward(**{self.key: poses_body - recon})
|
||
|
|
||
|
class AnySmooth(LossBase):
    """Temporal smoothness loss on ``kwargs[key]``, with frames along axis 0.

    For every ``width`` in ``1..len(weight)`` the difference
    ``value[t + width] - value[t]`` is penalized (or, when ``order == 2``, the
    difference of consecutive such differences), scaled by ``weight[width - 1]``.

    Args:
        key: name of the keyword argument to smooth.
        weight: per-width weights; its length sets the widest window.
        norm, norm_info: forwarded to ``make_loss``; ``norm_info=None`` means ``{}``.
        ranges, index, dim: sub-selection forwarded to ``select``; ``None``
            for ranges/index means no selection.
        order: 2 for second differences; any other value gives first differences.
    """

    def __init__(self, key, weight, norm, norm_info=None, ranges=None, index=None, dim=-1, order=1):
        super().__init__()
        # None defaults avoid sharing one mutable list/dict across instances.
        self.ranges = [] if ranges is None else ranges
        self.index = [] if index is None else index
        self.dim = dim
        self.weight = weight
        self.loss = make_loss(norm, {} if norm_info is None else norm_info)
        self.norm_name = norm
        self.key = key
        self.order = order

    def forward(self, **kwargs):
        """Return the weighted sum of smoothness losses over all window widths."""
        value = kwargs[self.key]
        value = select(value, self.ranges, self.index, self.dim)
        # The widest window needs len(weight) + 1 frames for a first-order
        # difference and one more for a second-order one. With fewer frames
        # the difference tensor is empty and the loss evaluates 0/0 (NaN),
        # so return zero instead.
        extra_frames = 1 if self.order == 2 else 0
        if value.shape[0] <= len(self.weight) + extra_frames:
            return torch.FloatTensor([0.]).to(value.device)
        loss = 0
        for width, weight in enumerate(self.weight, start=1):
            vel = value[width:] - value[:-width]
            if self.order == 2:
                vel = vel[1:] - vel[:-1]
            loss += weight * self.loss(vel)
        return loss

    def check(self, value):
        """Per-frame step size between consecutive frames, as a numpy array.

        Uses the norm over the last axis when the difference has more than two
        axes, the absolute value otherwise.
        """
        vel = value[1:] - value[:-1]
        if len(vel.shape) > 2:
            vel = torch.norm(vel, dim=-1)
        else:
            vel = torch.abs(vel)
        vel = vel.detach().cpu().numpy()
        return vel

    def get_check_name(self, value):
        """Row labels for the diagnostic table: one per entry of axis 1."""
        name = [str(i) for i in range(value.shape[1])]
        return name

    def check_at_start(self, **kwargs):
        """Cache mean/max/argmax step statistics of the initial value."""
        value = kwargs[self.key]
        if value.shape[0] < len(self.weight):
            return 0
        header = ['Smooth '+self.key, 'mean(before)', 'max(before)', 'frame(before)']
        name = self.get_check_name(value)
        vel = self.check(value)
        contents = [name, vel.mean(axis=0).tolist(), vel.max(axis=0).tolist(), vel.argmax(axis=0).tolist()]
        self.cache_check = (header, contents)
        return super().check_at_start(**kwargs)

    def check_at_end(self, **kwargs):
        """Print the cached 'before' statistics next to the current ones."""
        value = kwargs[self.key]
        if value.shape[0] < len(self.weight):
            return 0
        err_after = self.check(kwargs[self.key])
        header, contents = self.cache_check
        header.extend(['mean(after)', 'max(after)', 'frame(after)'])
        contents.extend([err_after.mean(axis=0).tolist(), err_after.max(axis=0).tolist(), err_after.argmax(axis=0).tolist()])
        print_table(header, contents)

    def __str__(self) -> str:
        return "smooth in {} frames, range={}, norm={}".format(self.weight, self.ranges, self.norm_name)
|
||
|
|
||
|
class SmoothRot(AnySmooth):
    """Temporal smoothness on rotations given as stacked 3-vectors.

    Each 3-vector is turned into a 3x3 matrix with ``batch_rodrigues`` before
    the ``AnySmooth`` differences are taken. Inputs whose last axis is wider
    than 3 are handled one 3-vector group at a time and the losses are summed.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # batch_rodrigues is already imported at module level.
        self.rodrigues = batch_rodrigues

    def convert_Rh_to_R(self, Rh):
        """Convert each 3-column group of a 2D ``Rh`` and concatenate along axis 1.

        Assumes ``Rh`` is (N, 3k) and ``batch_rodrigues`` returns (N, 3, 3),
        giving (N, 3k, 3) — TODO confirm against bodymodel.lbs.
        """
        starts = range(0, 3 * (Rh.shape[1] // 3), 3)
        mats = [self.rodrigues(Rh[:, s:s + 3]) for s in starts]
        return torch.cat(mats, dim=1)

    def forward(self, **kwargs):
        Rh = kwargs[self.key]
        if Rh.shape[-1] == 3:
            mats = self.convert_Rh_to_R(Rh.view(-1, 3)).view(*Rh.shape[:-1], 3, 3)
            return super().forward(**{self.key: mats})
        total = 0
        for s in range(0, 3 * (Rh.shape[-1] // 3), 3):
            group = Rh[..., s:s + 3]
            mats = self.convert_Rh_to_R(group).view(*group.shape[:-1], 3, 3)
            total += super().forward(**{self.key: mats})
        return total

    def get_check_name(self, value):
        return ['angle']

    def check(self, value):
        """Angle magnitude (via cv2.Rodrigues) of the relative rotation between consecutive frames."""
        import cv2
        # TODO: only the first rotation is inspected
        if len(value.shape) == 3:
            value = value[:, 0]
        rots = self.convert_Rh_to_R(value.detach())[:, :3]
        relative = torch.matmul(rots[:-1], rots.transpose(1, 2)[1:]).cpu().numpy()
        angles = [np.linalg.norm(cv2.Rodrigues(rel)[0]) for rel in relative]
        return np.array(angles).reshape(-1, 1)
|
||
|
|
||
|
class BaseKeypoints(LossBase):
    """Loss between estimated keypoints and fixed ground-truth keypoints.

    The ground truth carries its confidence in the last channel: ``set_gt``
    stores ``keypoints[..., :-1]`` as buffer ``keypoints`` and
    ``keypoints[..., -1]`` as buffer ``conf``.

    Args:
        keypoints: ground-truth array/tensor, last channel = confidence.
        norm, norm_info: forwarded to ``make_loss``; ``norm_info=None`` means ``{}``.
        index_gt, ranges_gt: joint (axis -2) sub-selection for the ground truth.
        index_est, ranges_est: joint (axis -2) sub-selection for the estimate.
            For all four, ``None`` means no selection; ``index`` wins over
            ``ranges``, and a range end of -1 means "to the end".
    """

    def __init__(self, keypoints, norm='l2', norm_info=None,
                 index_gt=None, ranges_gt=None,
                 index_est=None, ranges_est=None) -> None:
        super().__init__()
        # prepare ground-truth
        self.keypoints_np = keypoints
        self.set_gt([] if index_gt is None else index_gt,
                    [] if ranges_gt is None else ranges_gt)
        # selection applied to the estimate on every call
        self.index_est = [] if index_est is None else index_est
        self.ranges_est = [] if ranges_est is None else ranges_est

        self.norm = norm
        self.loss = make_loss(norm, {} if norm_info is None else norm_info)

    @staticmethod
    def select(keypoints, index, ranges):
        """Select joints on axis -2 by explicit ``index`` or ``[start, end]`` ``ranges``."""
        if len(index) > 0:
            keypoints = keypoints[..., index, :]
        elif len(ranges) > 0:
            if ranges[1] == -1:
                keypoints = keypoints[..., ranges[0]:, :]
            else:
                keypoints = keypoints[..., ranges[0]:ranges[1], :]
        return keypoints

    def set_gt(self, index_gt, ranges_gt):
        """Register the selected ground truth as buffers ``keypoints`` and ``conf``."""
        keypoints = self.select(self.keypoints_np, index_gt, ranges_gt)
        keypoints = torch.Tensor(keypoints)
        self.register_buffer('keypoints', keypoints[..., :-1])
        self.register_buffer('conf', keypoints[..., -1])

    def __str__(self):
        return "keypoints: {}".format(self.keypoints.shape)

    def forward(self, kpts_est, **kwargs):
        est = self.select(kpts_est, self.index_est, self.ranges_est)
        return self.loss(est, self.keypoints, self.conf)

    def check(self, kpts_est, min_conf=0.3, **kwargs):
        """Per-joint mean distance over entries with ``conf > min_conf``.

        The distance is multiplied by 1000 (presumably m -> mm; confirm units).

        Returns:
            ``(conf_mask, mean_error)``.
        """
        est = self.select(kpts_est, self.index_est, self.ranges_est)
        conf = (self.conf>min_conf).float()
        norm = torch.norm(est-self.keypoints, dim=-1) * conf
        mean_joints = norm.sum(dim=0)/(1e-5 + conf.sum(dim=0)) * 1000
        return conf, mean_joints

    def _est_names(self):
        """Row labels for the selected estimated joints."""
        if len(self.index_est) > 0:
            return [str(i) for i in self.index_est]
        if len(self.ranges_est) > 0:
            start, end = self.ranges_est[0], self.ranges_est[1]
            if end == -1:
                # select() reads -1 as "to the end"; range(start, -1) would be
                # empty, so derive the end from the number of selected joints.
                end = start + self.conf.shape[-1]
            return [str(i) for i in range(start, end)]
        return [str(i) for i in range(self.conf.shape[-1])]

    def check_at_start(self, kpts_est, **kwargs):
        """Cache per-joint counts and 'before' errors for ``check_at_end``."""
        names = self._est_names()
        conf, error = self.check(kpts_est, **kwargs)
        valid = conf.sum(dim=0).detach().cpu().numpy().tolist()
        header = ['name', 'count']
        contents = [names, valid]
        header.append('before')
        contents.append(error.detach().cpu().numpy().tolist())
        self.cache_check = (header, contents)

    def check_at_end(self, kpts_est, **kwargs):
        """Print the cached table with an added 'after' column."""
        conf, err_after = self.check(kpts_est, **kwargs)
        header, contents = self.cache_check
        header.append('after')
        contents.append(err_after.detach().cpu().numpy().tolist())
        print_table(header, contents)
|
||
|
|
||
|
class Keypoints3D(BaseKeypoints):
    """3D keypoint loss; ``keypoints3d`` is the ground truth (last channel = confidence).

    All other options are forwarded unchanged to ``BaseKeypoints``.
    """

    def __init__(self, keypoints3d, **kwargs) -> None:
        super(Keypoints3D, self).__init__(keypoints3d, **kwargs)
|
||
|
|
||
|
class AnyKeypoints3D(Keypoints3D):
    """3D keypoint loss whose ground truth is looked up by name.

    ``kwargs['key']`` names the entry of ``kwargs`` that holds the ground
    truth. Both entries are consumed; the rest go to ``Keypoints3D``.
    """

    def __init__(self, **kwargs) -> None:
        name = kwargs.pop('key')
        target = kwargs.pop(name)
        super(AnyKeypoints3D, self).__init__(target, **kwargs)
        self.key = name
|
||
|
|
||
|
class AnyKeypoints3DWithRT(Keypoints3D):
    """Like ``AnyKeypoints3D``, but rigidly transforms the estimate first.

    ``forward`` and ``check`` read ``R_<key>`` (converted to rotation matrices
    with ``batch_rodrigues``) and ``T_<key>`` from their keyword arguments and
    compare ``kpts_est @ R^T + T`` against the ground truth.
    """

    def __init__(self, **kwargs) -> None:
        name = kwargs.pop('key')
        target = kwargs.pop(name)
        super(AnyKeypoints3DWithRT, self).__init__(target, **kwargs)
        self.key = name

    def _apply_rt(self, kpts_est, params):
        """Return ``kpts_est @ R^T + T`` using the R_/T_ entries of ``params``."""
        rot = batch_rodrigues(params['R_' + self.key])
        trans = params['T_' + self.key]
        # T gains a joint axis so it broadcasts over all keypoints.
        return torch.matmul(kpts_est, rot.transpose(-1, -2)) + trans[..., None, :]

    def forward(self, kpts_est, **kwargs):
        return super().forward(self._apply_rt(kpts_est, kwargs))

    def check(self, kpts_est, min_conf=0.3, **kwargs):
        return super().check(self._apply_rt(kpts_est, kwargs), min_conf)
|
||
|
|
||
|
class Handl3D(BaseKeypoints):
    """Root-relative 3D loss for one hand.

    Ground truth and estimate are both made relative to their first joint.
    The estimate is taken from joints 25:46 of ``kpts_est`` (hard-coded;
    presumably the left hand after 25 body joints — confirm the layout).
    """

    def __init__(self, handl3d, **kwargs) -> None:
        # Root taken from the untouched input, so it is unaffected by the write below.
        root = handl3d[:, :1, :3]
        gt = handl3d.clone()
        gt[..., :3] = gt[..., :3] - root
        super().__init__(gt, **kwargs)

    @staticmethod
    def _hand_root_relative(kpts_est):
        """Slice the hand joints and subtract the (gradient-detached) first joint."""
        hand = kpts_est[:, 25:46]
        return hand - hand[:, :1].detach()

    def forward(self, kpts_est, **kwargs):
        return super().forward(self._hand_root_relative(kpts_est), **kwargs)

    def check(self, kpts_est, **kwargs):
        return super().check(self._hand_root_relative(kpts_est), **kwargs)
|
||
|
|
||
|
|
||
|
class LimbLength(BaseKeypoints):
    """Loss on limb lengths instead of joint positions.

    Each row ``(i, j)`` of ``kintree`` defines a limb between joints i and j.
    The ground-truth length is ``||x_j - x_i||`` on the first three channels.
    Its confidence is the minimum of the two endpoint confidences.

    Args:
        kintree: sequence of ``(parent, child)`` joint index pairs.
        key: name of the ground-truth entry in ``kwargs``. The special value
            'bodyhand' stacks 'keypoints3d', 'handl3d' and 'handr3d' with
            ``np.hstack``.
        **kwargs: ground-truth arrays plus options for ``BaseKeypoints``.
    """

    def __init__(self, kintree, key='keypoints3d', **kwargs):
        # Must be set before BaseKeypoints.__init__, which calls set_gt().
        self.kintree = np.array(kintree)
        if key == 'bodyhand':
            keypoints3d = np.hstack([kwargs.pop('keypoints3d'), kwargs.pop('handl3d'), kwargs.pop('handr3d')])
        else:
            keypoints3d = kwargs.pop(key)
        super().__init__(keypoints3d, **kwargs)

    def __str__(self):
        return "Limb of: {}".format(','.join(['[{},{}]'.format(i,j) for (i,j) in self.kintree]))

    def set_gt(self, index_gt, ranges_gt):
        """Register ground-truth limb lengths (``length``) and confidences (``conf``)."""
        keypoints3d = self.keypoints_np
        kintree = self.kintree
        # limb_length: (..., nLimbs, 1)
        limb_length = np.linalg.norm(keypoints3d[..., kintree[:, 1], :3] - keypoints3d[..., kintree[:, 0], :3], axis=-1, keepdims=True)
        # limb_conf: (..., nLimbs) -- the weaker of the two endpoint confidences
        limb_conf = np.minimum(keypoints3d[..., kintree[:, 1], -1], keypoints3d[..., kintree[:, 0], -1])
        limb_length = torch.Tensor(limb_length)
        limb_conf = torch.Tensor(limb_conf)
        self.register_buffer('length', limb_length)
        self.register_buffer('conf', limb_conf)

    def forward(self, kpts_est, **kwargs):
        src = kpts_est[..., self.kintree[:, 0], :]
        dst = kpts_est[..., self.kintree[:, 1], :]
        length_est = torch.norm(dst - src, dim=-1, keepdim=True)
        return self.loss(length_est, self.length, self.conf)

    def check_at_start(self, kpts_est, **kwargs):
        """Cache per-limb counts, 'before' errors and reference lengths."""
        names = [str(i) for i in self.kintree]
        conf = (self.conf>0)
        valid = conf.sum(dim=0).detach().cpu().numpy()
        if len(valid.shape) == 2:
            valid = valid.mean(axis=0)
        header = ['name', 'count']
        contents = [names, valid.tolist()]
        error, _ = self.check(kpts_est)
        header.append('before')
        contents.append(error.detach().cpu().numpy().tolist())
        header.append('length')
        # Confidence-weighted reference length; 1e-5 keeps never-observed
        # limbs at 0 instead of NaN.
        length = (self.length[..., 0] * self.conf).sum(dim=0)/(1e-5 + self.conf.sum(dim=0))
        contents.append(length.detach().cpu().numpy().tolist())
        self.cache_check = (header, contents)

    def check_at_end(self, kpts_est, **kwargs):
        """Print the cached table with 'after' errors and estimated lengths."""
        err_after, length = self.check(kpts_est)
        header, contents = self.cache_check
        header.append('after')
        contents.append(err_after.detach().cpu().numpy().tolist())
        header.append('length_est')
        contents.append(length[:,:,0].mean(dim=0).detach().cpu().numpy().tolist())
        print_table(header, contents)

    def check(self, kpts_est, **kwargs):
        """Per-limb mean absolute length error (x1000) over limbs with conf > 0.

        Returns:
            ``(mean_error, length_est averaged over axis 0)``.
        """
        src = kpts_est[..., self.kintree[:, 0], :]
        dst = kpts_est[..., self.kintree[:, 1], :]
        length_est = torch.norm(dst - src, dim=-1, keepdim=True)
        conf = (self.conf>0).float()
        norm = torch.abs(length_est-self.length)[..., 0] * conf
        # 1e-5 matches BaseKeypoints.check: unobserved limbs report 0, not NaN/inf.
        mean_joints = norm.sum(dim=0)/(1e-5 + conf.sum(dim=0)) * 1000
        if len(mean_joints.shape) == 2:
            mean_joints = mean_joints.mean(dim=0)
        length_est = length_est.mean(dim=0)
        return mean_joints, length_est
|
||
|
|
||
|
class LimbLengthHand(LimbLength):
    """Limb-length loss for two hands sharing one per-hand kinematic tree.

    The two ground-truth hands are stacked along axis 0. The estimate is
    split into its first 21 joints and the remaining joints, which are
    stacked the same way, so the same ``kintree`` applies to both hands.
    """

    def __init__(self, handl3d, handr3d, **kwargs):
        kintree = kwargs.pop('kintree')
        if torch.is_tensor(handl3d) and torch.is_tensor(handr3d):
            keypoints3d = torch.cat([handl3d, handr3d], dim=0)
        else:
            # LimbLength.set_gt works with numpy; accept numpy inputs too.
            keypoints3d = np.concatenate([handl3d, handr3d], axis=0)
        # LimbLength takes the ground truth by name through ``key``. The array
        # was previously passed positionally *as* ``key``, which made
        # ``kwargs.pop(key)`` fail.
        kwargs['keypoints3d'] = keypoints3d
        super().__init__(kintree, key='keypoints3d', **kwargs)

    @staticmethod
    def _stack_hands(kpts_est):
        """Split joints [:21] and [21:] and stack them along axis 0."""
        return torch.cat([kpts_est[:, :21], kpts_est[:, 21:]], dim=0)

    def forward(self, kpts_est, **kwargs):
        return super().forward(self._stack_hands(kpts_est), **kwargs)

    def check(self, kpts_est, **kwargs):
        return super().check(self._stack_hands(kpts_est), **kwargs)
|
||
|
|
||
|
class Keypoints2D(BaseKeypoints):
    """Reprojection loss between estimated 3D keypoints and 2D detections.

    Args:
        keypoints2d: detections with 2D coordinates in the first two channels
            and the confidence in the last channel.
        K, Rc, Tc: camera intrinsics, rotation and translation (stored as buffers).
        einsum: equation used in ``project`` to apply ``[R | T]`` (or
            ``K [R | T]``) to the homogeneous points.
        unproj: if True, the detections are mapped through ``K^-1`` once here
            and projection skips ``K``; otherwise projection applies ``K``.
        reshape_views: if True, the estimate's first axis is split into
            ``(K.shape[0], K.shape[1])`` before projecting.
        **kwargs: forwarded to ``BaseKeypoints``.
    """

    def __init__(self, keypoints2d, K, Rc, Tc, einsum='fab,fnb->fna',
                 unproj=True, reshape_views=False, **kwargs) -> None:
        # convert to camera coordinate
        invKtrans = torch.inverse(K).transpose(-1, -2)
        if unproj:
            # row vectors [x, y, 1] @ K^-T == (K^-1 [x, y, 1]^T)^T
            homo = torch.ones_like(keypoints2d[..., :1])
            homo = torch.cat([keypoints2d[..., :2], homo], dim=-1)
            if len(invKtrans.shape) < len(homo.shape):
                # add an axis so the ranks match for matmul broadcasting
                invKtrans = invKtrans.unsqueeze(-3)
            homo = torch.matmul(homo, invKtrans)
            keypoints2d = torch.cat([homo[..., :2], keypoints2d[..., 2:]], dim=-1)
        # keypoints2d: (nFrames, nViews, ..., nJoints, 3)
        super().__init__(keypoints2d, **kwargs)
        self.register_buffer('K', K)
        self.register_buffer('invKtrans', invKtrans)
        self.register_buffer('Rc', Rc)
        self.register_buffer('Tc', Tc)
        self.unproj = unproj
        self.einsum = einsum
        self.reshape_views = reshape_views

    def project(self, kpts_est):
        """Project the selected estimated 3D points to 2D (normalized if ``unproj``)."""
        kpts_est = self.select(kpts_est, self.index_est, self.ranges_est)
        kpts_homo = torch.ones_like(kpts_est[..., -1:])
        kpts_homo = torch.cat([kpts_est, kpts_homo], dim=-1)
        if self.unproj:
            P = torch.cat([self.Rc, self.Tc], dim=-1)
        else:
            # matmul (not bmm) so K may carry more than one batch axis;
            # identical to bmm for 3D inputs.
            P = torch.matmul(self.K, torch.cat([self.Rc, self.Tc], dim=-1))
        if self.reshape_views:
            kpts_homo = kpts_homo.reshape(self.K.shape[0], self.K.shape[1], *kpts_homo.shape[1:])
        try:
            point_cam = torch.einsum(self.einsum, P, kpts_homo)
        except (RuntimeError, ValueError) as err:
            print('Wrong shape: {}x{} <=== {}'.format(P.shape, kpts_homo.shape, self.einsum))
            raise NotImplementedError from err
        img_points = point_cam[..., :2]/point_cam[..., 2:]
        return img_points

    def forward(self, kpts_est, **kwargs):
        img_points = self.project(kpts_est)
        # NOTE(review): squeeze() drops every singleton axis, including a
        # single frame or joint; est, gt and conf are squeezed alike.
        loss = self.loss(img_points.squeeze(), self.keypoints.squeeze(), self.conf.squeeze())
        return loss

    def check(self, kpts_est, min_conf=0.3):
        """Per-joint mean reprojection error over entries with ``conf > min_conf``.

        Returns:
            ``(confident-entry counts, mean error)``; the error is in pixels.
        """
        with torch.no_grad():
            img_points = self.project(kpts_est)
            conf = (self.conf>min_conf)
            if self.unproj:
                # Normalized camera coordinates: rescale by the mean fx to pixels.
                scale = self.K[..., 0:1, 0].mean()
            else:
                # Both sides are already in pixels; scaling by fx was wrong here.
                scale = 1.0
            err = scale * torch.norm(img_points - self.keypoints, dim=-1) * conf
            if len(err.shape) == 3:
                err = err.sum(dim=1)
                conf = conf.sum(dim=1)
            err = err.sum(dim=0)/(1e-5 + conf.sum(dim=0))
        return conf, err

    def check_at_start(self, kpts_est, **kwargs):
        """Cache per-joint counts and 'before' pixel errors."""
        if len(self.index_est) > 0:
            names = [str(i) for i in self.index_est]
        elif len(self.ranges_est) > 0:
            start, end = self.ranges_est[0], self.ranges_est[1]
            if end == -1:
                # select() reads -1 as "to the end"; range(start, -1) would be empty.
                end = start + self.conf.shape[-1]
            names = [str(i) for i in range(start, end)]
        else:
            names = [str(i) for i in range(self.conf.shape[-1])]
        conf, error = self.check(kpts_est)
        valid = conf.sum(dim=0).detach().cpu().numpy()
        valid = valid.tolist()
        header = ['name', 'count']
        contents = [names, valid]
        header.append('before(pix)')
        contents.append(error.detach().cpu().numpy().tolist())
        self.cache_check = (header, contents)

    def check_at_end(self, kpts_est, **kwargs):
        """Print the cached table with an added 'after(pix)' column."""
        conf, err_after = self.check(kpts_est)
        header, contents = self.cache_check
        header.append('after(pix)')
        contents.append(err_after.detach().cpu().numpy().tolist())
        print_table(header, contents)
|