EasyMocap/easymocap/multistage/initialize.py
2022-08-21 16:07:06 +08:00

173 lines
6.2 KiB
Python

import numpy as np
import cv2
from ..dataset.config import CONFIG
from ..config import load_object
from ..mytools.debug_utils import log, mywarn, myerror
import torch
from tqdm import tqdm, trange
def svd_rot(src, tgt, reflection=False, debug=False):
    """Solve the batched orthogonal Procrustes problem (Kabsch algorithm).

    For every batch element ``b`` find the orthogonal matrix ``R[b]`` that
    best maps ``src[b]`` onto ``tgt[b]`` in the sense ``tgt ~= src @ R.T``.
    No centering is done here; the caller must express both point sets
    relative to a common origin.

    Args:
        src: (B, N, D) source point sets.
        tgt: (B, N, D) target point sets, in correspondence with ``src``.
        reflection: if False (default) every returned matrix is a proper
            rotation (det = +1); if True every returned matrix is forced to
            be improper (det = -1).
        debug: if True, print the per-point residual of the fit.

    Returns:
        (B, D, D) array of orthogonal matrices.
    """
    # Cross-covariance of the two point sets, one matrix per batch element.
    A = np.matmul(src.transpose(0, 2, 1), tgt)
    U, _, Vt = np.linalg.svd(A, full_matrices=False)
    V = Vt.transpose(0, 2, 1)
    R = np.matmul(V, U.transpose(0, 2, 1))
    # Where the handedness of the solution differs from the requested one,
    # flip the singular vector of the smallest singular value.
    flip = (np.linalg.det(R) < 0) != reflection
    V[flip, :, -1] *= -1
    R = np.matmul(V, U.transpose(0, 2, 1))
    if debug:
        # Transpose only the matrix axes; a plain ``.T`` would also move
        # the batch axis.
        err = np.linalg.norm(tgt - np.matmul(src, R.transpose(0, 2, 1)), axis=-1)
        print('[svd] ', err)
    return R
def batch_invRodrigues(rot):
    """Convert each matrix of ``rot`` to its Rodrigues (axis-angle) vector.

    Args:
        rot: iterable of 3x3 rotation matrices, e.g. an (N, 3, 3) array.

    Returns:
        (N, 3) array whose row ``i`` is ``cv2.Rodrigues(rot[i])[0]`` flattened.
    """
    # cv2.Rodrigues returns (dst, jacobian); dst is a single column vector.
    return np.stack([cv2.Rodrigues(mat)[0][:, 0] for mat in rot])
class BaseInit:
    """Base class for initialization stages.

    A stage is a callable ``stage(body_model, body_params, infos)`` that
    returns the (possibly updated) ``body_params``. The base implementation
    returns ``body_params`` unchanged.
    """

    def __init__(self) -> None:
        pass

    def __call__(self, body_model, body_params, infos):
        return body_params
class Remove(BaseInit):
    """Stage that zeroes selected entries of ``infos[key]`` in place.

    Entries are selected along the second-to-last axis
    (``infos[key][..., index, :]``); presumably these are joint indices.
    """

    def __init__(self, key, index) -> None:
        super().__init__()
        self.key = key      # which entry of ``infos`` to modify
        self.index = index  # index (or indices) on the second-to-last axis

    def __call__(self, body_model, body_params, infos):
        target = infos[self.key]
        target[..., self.index, :] = 0
        return super().__call__(body_model, body_params, infos)
class CheckKeypoints:
    """Sanitize the 3D keypoints stored in ``infos`` before fitting.

    For each of ``'keypoints3d'``, ``'handl3d'`` and ``'handr3d'`` that is
    present in ``infos``, every point whose confidence (last channel) is
    below 0.1 is zeroed in place. ``body_params`` is returned unchanged.
    """

    def __init__(self, type) -> None:
        # ``type`` selects the body skeleton in CONFIG. The parameter name is
        # kept for caller compatibility even though it shadows the builtin.
        self.type = type
        self.body_config = CONFIG[type]
        self.hand_config = CONFIG['hand']

    def __call__(self, body_model, body_params, infos):
        for key in ['keypoints3d', 'handl3d', 'handr3d']:
            if key not in infos:
                continue
            keypoints = infos[key]
            # Zero the whole point (coordinates and confidence) when unreliable.
            keypoints[keypoints[..., -1] < 0.1] = 0
            # TODO: add extra validation for the hand keypoints (e.g. limb
            # lengths, presumably against self.hand_config); not implemented.
        return body_params
class InitRT:
    """Initialize the global rotation ``Rh`` and translation ``Th``.

    The body model's template torso joints are rigidly aligned to the observed
    3D torso joints with ``svd_rot``. The first torso joint is the rotation
    center.
    """

    def __init__(self, torso) -> None:
        # Indices of the joints used for the alignment; torso[0] is the root.
        self.torso = torso

    def __call__(self, body_model, body_params, infos):
        """Write ``Th`` and ``Rh`` (axis-angle) into ``body_params`` and return it.

        ``infos['keypoints3d']`` is a torch tensor of rank 3 or 4 whose last
        axis holds xyz in channels 0-2 and a confidence in channel 3. The
        leading axis is presumably frames, followed by persons in the rank-4
        case (TODO confirm). The template joints from
        ``body_model.keypoints(..., return_tensor=False)`` must share those
        leading axes, or broadcast against them.
        """
        keypoints3d = infos['keypoints3d'].detach().cpu().numpy()
        temp_joints = body_model.keypoints(body_params, return_tensor=False)
        torso = keypoints3d[..., self.torso, :3].copy()
        torso_temp = temp_joints[..., self.torso, :3].copy()
        # here use the first id of torso as the rotation center
        root, root_temp = torso[..., :1, :], torso_temp[..., :1, :]
        torso = torso - root
        torso_temp = torso_temp - root_temp
        conf = (keypoints3d[..., self.torso, 3] > 0.).all(axis=-1)
        if not conf.all():
            myerror("The torso in frames {} is not valid, please check the 3d keypoints".format(np.where(~conf)))
        # Solve every alignment as one flat batch, then restore the leading axes.
        batch_shape = torso.shape[:-2]
        n_joints = torso.shape[-2]
        R = svd_rot(torso_temp.reshape(-1, n_joints, 3),
                    torso.reshape(-1, n_joints, 3)).reshape(*batch_shape, 3, 3)
        # Choose T so that the rotated template root lands on the observed root.
        T = np.matmul(- root_temp, R.swapaxes(-1, -2)) + root
        # A frame is repaired when any of its torsos is invalid. Frames are
        # visited in increasing order, so frame nf-1 has already been repaired.
        # Frame 0 has no predecessor and takes the first fully valid frame
        # instead of wrapping around to the last one.
        frame_valid = conf.reshape(conf.shape[0], -1).all(axis=-1)
        for nf in np.where(~frame_valid)[0]:
            if nf > 0:
                src_frame = nf - 1
            elif frame_valid.any():
                src_frame = int(np.argmax(frame_valid))
            else:
                # No valid frame to copy from; keep the raw estimate.
                continue
            mywarn('copy {} from {}'.format(nf, src_frame))
            R[nf] = R[src_frame]
            T[nf] = T[src_frame]
        body_params['Th'] = T[..., 0, :]
        # Convert the (repaired) rotations explicitly from R.
        rvec = batch_invRodrigues(R.reshape(-1, 3, 3))
        body_params['Rh'] = rvec.reshape(*batch_shape, 3)
        return body_params

    def __str__(self) -> str:
        return "[Initialize] svd with torso: {}".format(self.torso)
class TriangulatorWrapper:
    """Stage that runs a configurable triangulator on the observations in
    ``infos`` and stores its outputs back into ``infos`` as float32 tensors.
    """

    def __init__(self, module, args):
        # The triangulator is built from ``module``/``args`` via load_object.
        self.triangulator = load_object(module, args)

    def __call__(self, body_model, body_params, infos):
        """Triangulate and return fresh parameters.

        The return value comes from ``body_model.init_params(nFrames=1,
        add_scale=True)``; the incoming ``body_params`` are discarded.
        ``infos`` is modified in place: ``'RT'`` is added, and so is every
        output key of the triangulator except ``'id'``.
        """
        def to_numpy(tensor):
            # detach/cpu also let tensors on GPU or with autograd history convert.
            return tensor.detach().cpu().numpy()

        # Concatenate Rc and Tc along the last axis into a single RT matrix.
        infos['RT'] = torch.cat([infos['Rc'], infos['Tc']], dim=-1)
        data = {
            'RT': to_numpy(infos['RT']),
        }
        for key in self.triangulator.keys:
            if key not in infos:
                continue
            # Each observed key must come with its '_unproj' and '_distort' variants.
            for suffix in ('', '_unproj', '_distort'):
                data[key + suffix] = to_numpy(infos[key + suffix])
        # Only the first element returned by the triangulator is used.
        results = self.triangulator(data)[0]
        for key, val in results.items():
            if key == 'id':
                continue
            # Add a leading axis of size 1 and store as a float32 tensor.
            infos[key] = torch.Tensor(val[None].astype(np.float32))
        return body_model.init_params(nFrames=1, add_scale=True)
class CheckRT:
    """Suppress sudden jumps of the global translation ``Th`` over time.

    For each person, each frame's ``Th`` is compared with the mean of up to
    ``window`` preceding frames. When it is farther than ``T_thres``, it is
    replaced by that mean.
    """

    def __init__(self, T_thres, window):
        self.T_thres = T_thres  # maximum allowed distance from the local mean
        self.window = window    # number of preceding frames averaged

    def __call__(self, body_model, body_params, infos):
        Th = body_params['Th']
        # NOTE(review): only rank-3 ``Th`` (indexed [frame, person, :]) is
        # checked; any other rank is returned untouched.
        if len(Th.shape) != 3:
            body_params['Th'] = Th
            return body_params
        n_frames, n_people = Th.shape[0], Th.shape[1]
        for nper in range(n_people):
            for nf in trange(1, n_frames, desc='Check Th of {}'.format(nper)):
                start = max(0, nf - self.window)
                mean_prev = Th[start:nf, nper].mean(axis=0)
                dist = np.linalg.norm(Th[nf, nper] - mean_prev)
                if dist > self.T_thres:
                    mywarn('[Check Th] distance in frame {} = {} larger than {}'.format(nf, dist, self.T_thres))
                    # Edits are in place, so later means see corrected values.
                    Th[nf, nper] = mean_prev
        body_params['Th'] = Th
        return body_params
class Scale:
    """Stage that normalizes observations and cameras by an estimated scale.

    It pops ``body_params['scale']``, presumably produced by
    ``init_params(..., add_scale=True)``. If the scale is within 10% of 1,
    nothing else happens. Otherwise the configured ``infos`` entries and the
    camera translation (both ``Tc`` and the last column of ``RT``) are divided
    by the scale, and the scale is recorded in ``infos['scale']``.
    """

    def __init__(self, keys):
        # Entries of ``infos`` to divide by the scale (missing ones are skipped).
        self.keys = keys

    def __call__(self, body_model, body_params, infos):
        scale = body_params.pop('scale')[0, 0]
        if 0.9 < scale < 1.1:
            return body_params
        print('scale = ', scale)
        for key in self.keys:
            if key not in infos:
                continue
            # NOTE(review): the whole array is divided, including any trailing
            # confidence channel the entry might carry; confirm this is intended.
            infos[key] /= scale
        # Dividing the world points by ``scale`` means the camera translation
        # must be divided by the same factor for projections to stay unchanged.
        # Tc and the last column of RT are the same translation, so both are
        # divided (RT was previously multiplied, which left the two inconsistent).
        infos['Tc'] /= scale
        infos['RT'][..., -1] /= scale
        infos['scale'] = scale
        return body_params