173 lines
6.2 KiB
Python
173 lines
6.2 KiB
Python
import numpy as np
|
|
import cv2
|
|
from ..dataset.config import CONFIG
|
|
from ..config import load_object
|
|
from ..mytools.debug_utils import log, mywarn, myerror
|
|
import torch
|
|
from tqdm import tqdm, trange
|
|
|
|
def svd_rot(src, tgt, reflection=False, debug=False):
    """Estimate, per batch item, the rotation that best aligns ``src`` to ``tgt``.

    The rotation is obtained from the SVD of the cross-covariance
    ``src^T @ tgt``; any solution with a negative determinant is corrected
    by flipping the last singular direction, so a proper rotation
    (determinant +1) is always returned.

    Args:
        src: array indexed as (batch, points, dims); the points to rotate.
        tgt: array with the same layout as ``src``; the points to match.
        reflection: accepted for backward compatibility but not consulted;
            reflections are always removed.
        debug: when true, print the per-point residual of the fit.

    Returns:
        Array ``T`` of shape (batch, dims, dims) such that
        ``tgt[b] ~= src[b] @ T[b].T``.
    """
    # Cross-covariance between the two point sets, one matrix per batch item.
    A = np.matmul(src.transpose(0, 2, 1), tgt)
    # The singular values are not needed for the rotation itself.
    U, _, Vt = np.linalg.svd(A, full_matrices=False)
    V = Vt.transpose(0, 2, 1)
    Ut = U.transpose(0, 2, 1)
    T = np.matmul(V, Ut)
    # Does the current solution use a reflection?
    have_reflection = np.linalg.det(T) < 0

    # Flip the last singular direction of those items to get a rotation.
    V[have_reflection, :, -1] *= -1
    T = np.matmul(V, Ut)
    if debug:
        # Transpose each batch item separately: ``T.T`` would reverse all
        # three axes of the batched array instead of transposing the matrices.
        residual = tgt - np.matmul(src, T.transpose(0, 2, 1))
        err = np.linalg.norm(residual, axis=-1)
        print('[svd] ', err)
    return T
|
|
|
|
def batch_invRodrigues(rot):
    """Convert a batch of rotations with ``cv2.Rodrigues``, one item at a time.

    Args:
        rot: iterable of per-item inputs for ``cv2.Rodrigues``. The caller
            visible in this file passes a stack of rotation matrices
            (presumably 3x3 each -- confirm for other callers).

    Returns:
        Array holding, for each item, the first column of the converted
        output, i.e. shape (batch, 3) for 3x3 inputs. An empty batch yields
        an empty (0, 3) array instead of raising inside ``np.stack``.
    """
    rvecs = [cv2.Rodrigues(r)[0] for r in rot]
    if not rvecs:
        # np.stack refuses an empty list; return a well-formed empty result.
        return np.zeros((0, 3))
    # Each converted item is a column vector; drop the trailing axis.
    return np.stack(rvecs)[:, :, 0]
|
|
|
|
class BaseInit:
    """Base initialization step that leaves the body parameters untouched.

    Subclasses override ``__call__`` and may delegate back to this
    implementation to return ``body_params``.
    """

    def __init__(self) -> None:
        pass

    def __call__(self, body_model, body_params, infos):
        """Return ``body_params`` unchanged.

        Args:
            body_model: unused here.
            body_params: the parameters handed back to the caller.
            infos: unused here.
        """
        return body_params
|
|
|
|
class Remove(BaseInit):
    """Initialization step that zeroes selected entries of one ``infos`` item."""

    def __init__(self, key, index) -> None:
        """Remember which ``infos`` entry (``key``) and which index to clear."""
        super().__init__()
        self.key, self.index = key, index

    def __call__(self, body_model, body_params, infos):
        """Zero ``infos[self.key][..., self.index, :]`` in place.

        The body parameters themselves are returned through the parent
        implementation without modification.
        """
        target = infos[self.key]
        # The index addresses the second-to-last axis; every value on the
        # last axis of the selected entries is overwritten with zero.
        target[..., self.index, :] = 0
        return super().__call__(body_model, body_params, infos)
|
|
|
|
class CheckKeypoints:
    """Sanity-check step for the provided 3D keypoints.

    Currently it only clears keypoints whose confidence is too low.
    """

    def __init__(self, type) -> None:
        """Look up the configuration for ``type`` and for ``'hand'`` in CONFIG."""
        self.type = type
        self.body_config = CONFIG[type]
        self.hand_config = CONFIG['hand']

    def __call__(self, body_model, body_params, infos):
        """Zero low-confidence entries of the body and hand keypoints in place.

        For each of ``'keypoints3d'``, ``'handl3d'`` and ``'handr3d'`` that is
        present in ``infos``, the last channel is read as the confidence and
        every keypoint with confidence below 0.1 is set to zero (including
        its confidence).

        Returns:
            ``body_params`` unchanged.
        """
        for key in ('keypoints3d', 'handl3d', 'handr3d'):
            if key not in infos.keys():
                continue
            keypoints = infos[key]
            conf = keypoints[..., -1]
            keypoints[conf < 0.1] = 0
        # TODO: a limb-length check for the hand keypoints was sketched here
        # but never implemented; the debugger breakpoint that stood in for it
        # has been removed so that hand data no longer halts the pipeline.
        return body_params
|
|
|
|
class InitRT:
    """Initialize the global rotation ``Rh`` and translation ``Th`` from the torso.

    The torso joints of the model are rigidly aligned (via ``svd_rot``) to
    the torso joints of the observed 3D keypoints.
    """

    def __init__(self, torso) -> None:
        """Store ``torso``, the joint indices used for the rigid alignment."""
        self.torso = torso

    def __call__(self, body_model, body_params, infos):
        """Fill ``body_params['Rh']`` and ``body_params['Th']`` and return it.

        ``infos['keypoints3d']`` is read as a tensor whose last axis holds
        x, y, z and a confidence at index 3. Inputs with three axes and with
        four axes are handled (presumably (frames, joints, 4) and
        (frames, persons, joints, 4) -- confirm against the caller).
        """
        observed = infos['keypoints3d'].detach().cpu().numpy()
        template = body_model.keypoints(body_params, return_tensor=False)

        torso = observed[..., self.torso, :3].copy()
        torso_temp = template[..., self.torso, :3].copy()
        # The first torso joint serves as the rotation center.
        root = torso[..., :1, :]
        root_temp = torso_temp[..., :1, :]
        torso = torso - root
        torso_temp = torso_temp - root_temp

        # An item is usable only if every torso joint has positive confidence.
        conf = (observed[..., self.torso, 3] > 0.).all(axis=-1)
        if not conf.all():
            myerror("The torso in frames {} is not valid, please check the 3d keypoints".format(np.where(~conf)))

        if torso.ndim == 3:
            R_flat = svd_rot(torso_temp, torso)
            R = R_flat
        else:
            # Collapse the leading axes into one batch axis for the solver,
            # then restore them for the per-item bookkeeping below.
            flat_temp = torso_temp.reshape(-1, *torso_temp.shape[-2:])
            flat_obs = torso.reshape(-1, *torso.shape[-2:])
            R_flat = svd_rot(flat_temp, flat_obs)
            R = R_flat.reshape(*torso.shape[:2], 3, 3)
        # Translation that carries the rotated model root onto the observed root.
        T = np.matmul(-root_temp, R.swapaxes(-1, -2)) + root

        for nf in np.where(~conf)[0]:
            # Copy the result of the previous frame over the invalid one.
            # NOTE(review): for nf == 0 the index nf-1 wraps to the last frame,
            # and with a person axis the whole frame (all persons) is copied.
            mywarn('copy {} from {}'.format(nf, nf-1))
            R[nf] = R[nf-1]
            T[nf] = T[nf-1]

        body_params['Th'] = T[..., 0, :]
        # NOTE(review): in the multi-person branch this relies on R being a
        # view of R_flat so that the copies above reach R_flat -- confirm.
        rvec = batch_invRodrigues(R_flat)
        if torso.ndim > 3:
            rvec = rvec.reshape(*torso.shape[:2], 3)
        body_params['Rh'] = rvec
        return body_params

    def __str__(self) -> str:
        """Describe this step together with the torso indices it uses."""
        description = "[Initialize] svd with torso: {}".format(self.torso)
        return description
|
|
|
|
class TriangulatorWrapper:
    """Run a triangulator on the tensors in ``infos`` and store its results."""

    def __init__(self, module, args):
        """Build the triangulator with ``load_object(module, args)``."""
        self.triangulator = load_object(module, args)

    @staticmethod
    def _to_numpy(tensor):
        """Return ``tensor`` as a numpy array, detached and moved to the CPU."""
        # A bare ``.numpy()`` raises for tensors that require grad or that
        # live on an accelerator; detach and move to the CPU first.
        return tensor.detach().cpu().numpy()

    def __call__(self, body_model, body_params, infos):
        """Triangulate, write the results into ``infos`` and reset the params.

        ``infos['RT']`` is set to ``infos['Rc']`` and ``infos['Tc']``
        concatenated along the last axis. For every key listed in
        ``self.triangulator.keys`` that is present in ``infos``, the entry
        itself and its ``_unproj`` and ``_distort`` companions are handed to
        the triangulator as numpy arrays.

        Every field of the first triangulation result except ``'id'`` is
        stored back in ``infos`` as a float32 tensor with a new leading axis.

        Returns:
            Fresh parameters from
            ``body_model.init_params(nFrames=1, add_scale=True)``; the
            incoming ``body_params`` is not used.
        """
        infos['RT'] = torch.cat([infos['Rc'], infos['Tc']], dim=-1)
        data = {'RT': self._to_numpy(infos['RT'])}
        for key in self.triangulator.keys:
            if key not in infos.keys():
                continue
            for suffix in ('', '_unproj', '_distort'):
                data[key + suffix] = self._to_numpy(infos[key + suffix])
        # Only the first element of the triangulator output is consumed.
        results = self.triangulator(data)[0]
        for key, val in results.items():
            if key == 'id':
                continue
            infos[key] = torch.Tensor(val[None].astype(np.float32))
        body_params = body_model.init_params(nFrames=1, add_scale=True)
        return body_params
|
|
|
|
class CheckRT:
    """Replace translations that jump too far from the recent average."""

    def __init__(self, T_thres, window):
        """Store the distance threshold and the number of frames to average."""
        self.T_thres, self.window = T_thres, window

    def __call__(self, body_model, body_params, infos):
        """Check ``body_params['Th']`` frame by frame and patch outliers in place.

        Only a three-axis ``Th`` is processed (indexed as [frame, person]);
        any other rank is passed through untouched. For every frame after
        the first, the translation is compared with the mean of up to
        ``window`` preceding frames and overwritten by that mean when the
        distance exceeds ``T_thres``.
        """
        Th = body_params['Th']
        if Th.ndim == 3:
            n_frames, n_persons = Th.shape[0], Th.shape[1]
            for nper in range(n_persons):
                progress = trange(1, n_frames, desc='Check Th of {}'.format(nper))
                for nf in progress:
                    # Average over the last `window` frames, or over all
                    # earlier frames while fewer than that are available.
                    first = max(0, nf - self.window)
                    tpre = Th[first:nf, nper].mean(axis=0)
                    dist = np.linalg.norm(Th[nf, nper] - tpre)
                    if dist > self.T_thres:
                        mywarn('[Check Th] distance in frame {} = {} larger than {}'.format(nf, dist, self.T_thres))
                        Th[nf, nper] = tpre
        body_params['Th'] = Th
        return body_params
|
|
|
|
class Scale:
    """Apply the estimated global scale to the observations and cameras."""

    def __init__(self, keys):
        """Store ``keys``, the names of the ``infos`` entries to rescale."""
        self.keys = keys

    def __call__(self, body_model, body_params, infos):
        """Remove ``'scale'`` from ``body_params`` and rescale ``infos`` in place.

        The scale is read from position [0, 0] of ``body_params['scale']``.
        When it lies strictly between 0.9 and 1.1 nothing else is changed.
        Otherwise each listed entry present in ``infos`` and ``infos['Tc']``
        are divided by the scale, the last column of ``infos['RT']`` is
        multiplied by it, and the scale is recorded in ``infos['scale']``.

        Returns:
            ``body_params`` without its ``'scale'`` entry.
        """
        scale = body_params.pop('scale')[0, 0]
        if 0.9 < scale < 1.1:
            # Close enough to unit scale: leave the data untouched.
            return body_params

        print('scale = ', scale)
        for key in self.keys:
            if key in infos.keys():
                infos[key] /= scale
        infos['Tc'] /= scale
        # NOTE(review): 'Tc' is divided by the scale while the last column of
        # 'RT' is multiplied by it -- confirm that this asymmetry is intended.
        infos['RT'][..., -1] *= scale
        infos['scale'] = scale
        return body_params
|