From a0127f712a105c13205e6e81817c3b37f5afa290 Mon Sep 17 00:00:00 2001 From: Qing Shuai Date: Sun, 21 Aug 2022 16:07:06 +0800 Subject: [PATCH] :construction: create the new stype of fitting --- easymocap/multistage/base.py | 308 ++++ easymocap/multistage/base_ops.py | 39 + easymocap/multistage/before_after.py | 57 + easymocap/multistage/fitting.py | 1771 +++++++++++++++++++++++ easymocap/multistage/init_cnn.py | 100 ++ easymocap/multistage/init_pose.py | 36 + easymocap/multistage/initialize.py | 172 +++ easymocap/multistage/lossbase.py | 589 ++++++++ easymocap/multistage/mirror.py | 261 ++++ easymocap/multistage/synchronization.py | 79 + easymocap/multistage/torchgeometry.py | 517 +++++++ easymocap/multistage/totalfitting.py | 97 ++ 12 files changed, 4026 insertions(+) create mode 100644 easymocap/multistage/base.py create mode 100644 easymocap/multistage/base_ops.py create mode 100644 easymocap/multistage/before_after.py create mode 100644 easymocap/multistage/fitting.py create mode 100644 easymocap/multistage/init_cnn.py create mode 100644 easymocap/multistage/init_pose.py create mode 100644 easymocap/multistage/initialize.py create mode 100644 easymocap/multistage/lossbase.py create mode 100644 easymocap/multistage/mirror.py create mode 100644 easymocap/multistage/synchronization.py create mode 100644 easymocap/multistage/torchgeometry.py create mode 100644 easymocap/multistage/totalfitting.py diff --git a/easymocap/multistage/base.py b/easymocap/multistage/base.py new file mode 100644 index 0000000..a54628d --- /dev/null +++ b/easymocap/multistage/base.py @@ -0,0 +1,308 @@ +# 这个脚本用于通用的多阶段的优化 +import numpy as np +import torch + +from ..annotator.file_utils import read_json +from ..mytools import Timer +from .lossbase import print_table +from ..config.baseconfig import load_object +from ..bodymodel.base import Params +from torch.utils.data import DataLoader +from tqdm import tqdm + +def dict_of_numpy_to_tensor(body_model, body_params, *args, **kwargs): + device = 
body_model.device + body_params = {key:torch.Tensor(val).to(device) for key, val in body_params.items()} + return body_params + +class AddExtra: + def __init__(self, vals) -> None: + self.vals = vals + + def __call__(self, body_model, body_params, *args, **kwargs): + shapes = body_params['poses'].shape[:-1] + for key in self.vals: + if key in body_params.keys(): + continue + if key.startswith('R_') or key.startswith('T_'): + val = np.zeros((*shapes, 3), dtype=np.float32) + body_params[key] = val + return body_params + +def dict_of_tensor_to_numpy(body_params): + body_params = {key:val.detach().cpu().numpy() for key, val in body_params.items()} + return body_params + +def grad_require(params, flag=False): + if isinstance(params, list): + for par in params: + par.requires_grad = flag + elif isinstance(params, dict): + for key, par in params.items(): + par.requires_grad = flag + +def rel_change(prev_val, curr_val): + return (prev_val - curr_val) / max([1e-5, abs(prev_val), abs(curr_val)]) + +def make_optimizer(opt_params, optim_type='lbfgs', max_iter=20, + lr=1e-3, betas=(0.9, 0.999), weight_decay=0.0, **kwargs): + if isinstance(opt_params, dict): + # LBFGS 不支持参数字典 + opt_params = list(opt_params.values()) + if optim_type == 'lbfgs': + from ..pyfitting.lbfgs import LBFGS + optimizer = LBFGS(opt_params, line_search_fn='strong_wolfe', max_iter=max_iter, **kwargs) + elif optim_type == 'adam': + optimizer = torch.optim.Adam(opt_params, lr=lr, betas=betas, weight_decay=weight_decay) + else: + raise NotImplementedError + return optimizer + +def make_lossfuncs(stage, infos, device, irepeat, verbose=False): + loss_funcs, weights = {}, {} + for key, val in stage.loss.items(): + loss_args = dict(val.args) + if 'infos' in val.keys(): + for k in val.infos: + loss_args[k] = infos[k] + module = load_object(val.module, loss_args) + module.to(device) + if 'weights' in val.keys(): + weights[key] = val.weights[irepeat] + else: + weights[key] = val.weight + if weights[key] < 0: + 
weights.pop(key) + else: + loss_funcs[key] = module + if verbose or True: + print('Loss functions: ') + for key, func in loss_funcs.items(): + print(' - {:15s}: {}, {}'.format(key, weights[key], func)) + return loss_funcs, weights + +def make_before_after(before_after, body_model, body_params, infos): + modules = [] + for key, val in before_after.items(): + args = dict(val.args) + if 'body_model' in args.keys(): + args['body_model'] = body_model + try: + module = load_object(val.module, args) + except: + print('[Fitting] Failed to load module {}'.format(key)) + raise NotImplementedError + module.infos = infos + modules.append(module) + return modules + +def process(start_or_end, body_model, body_params, infos): + for key, val in start_or_end.items(): + if isinstance(val, dict): + module = load_object(val.module, val.args) + else: + if key == 'convert' and val == 'numpy_to_tensor': + module = dict_of_numpy_to_tensor + if key == 'add': + module = AddExtra(val) + body_params = module(body_model, body_params, infos) + return body_params + +def plot_meshes(img, meshes, K, R, T): + import cv2 + mesh_camera = [] + for mesh in meshes: + vertices = mesh['vertices'] @ R.T + T.T + v2d = vertices @ K.T + v2d[:, :2] = v2d[:, :2] / v2d[:, 2:3] + lw=1 + col=(0,0,255) + for (x, y, d) in v2d[::10]: + cv2.circle(img, (int(x+0.5), int(y+0.5)), lw*2, col, -1) + return img + +class MultiStage: + def __init__(self, batch_size, optimizer, monitor, initialize, stages) -> None: + self.batch_size = batch_size + self.optimizer_args = optimizer + self.monitor = monitor + self.initialize = initialize + self.stages = stages + + def make_closure(self, body_model, body_params, infos, loss_funcs, weights, optimizer, before_after_module): + def closure(debug=False, ret_kpts=False): + # 0. Prepare body parameters => new_params + optimizer.zero_grad() + new_params = body_params.copy() + for module in before_after_module: + new_params = module.before(new_params) + # 1. 
Compute keypoints => kpts_est + poses_full = body_model.extend_poses(**new_params) + kpts_est = body_model(return_verts=False, return_tensor=True, **new_params) + if ret_kpts: + return kpts_est + verts_est = None + # 2. Compute loss => loss_dict + loss_dict = {} + for key, loss_func in loss_funcs.items(): + if key.startswith('v'): + if verts_est is None: + verts_est = body_model(return_verts=True, return_tensor=True, **new_params) + loss_dict[key] = loss_func(verts_est=verts_est, **new_params, **infos) + elif key.startswith('pf-'): + loss_dict[key] = loss_func(poses_full=poses_full, **new_params, **infos) + else: + loss_dict[key] = loss_func(kpts_est=kpts_est, **new_params, **infos) + loss = sum([loss_dict[key]*weights[key] + for key in loss_dict.keys()]) + if debug: + return loss_dict + loss.backward() + return loss + return closure + + def optimizer_step(self, optimizer, closure, weights): + prev_loss = None + for iter_ in range(self.monitor.maxiters): + with torch.no_grad(): + loss_dict = closure(debug=True) + if self.monitor.printloss or (self.monitor.verbose and iter_ == 0): + print('{:-6d}: '.format(iter_) + ' '.join([key + ' %f'%(loss_dict[key].item()*weights[key]) for key in loss_dict.keys()])) + loss = optimizer.step(closure) + # check the loss + if torch.isnan(loss).sum() > 0: + print('[optimize] NaN loss value, stopping!') + break + if torch.isinf(loss).sum() > 0: + print('[optimize] Infinite loss value, stopping!') + break + # check the delta + if iter_ > 0 and prev_loss is not None: + loss_rel_change = rel_change(prev_loss, loss.item()) + if loss_rel_change <= self.monitor.ftol: + if self.monitor.printloss or self.monitor.verbose: + print('{:-6d}: '.format(iter_) + ' '.join([key + ' %f'%(loss_dict[key].item()*weights[key]) for key in loss_dict.keys()])) + break + # log + if self.monitor.vis2d: + pass + if self.monitor.vis3d: + pass + prev_loss = loss.item() + return True + + def fit_stage(self, body_model, body_params, infos, stage, irepeat): + # 
单独拟合一个stage, 返回body_params + optimizer_args = stage.get('optimizer', self.optimizer_args) + dtype, device = body_model.dtype, body_model.device + body_params = process(stage.get('at_start', {'convert': 'numpy_to_tensor'}), body_model, body_params, infos) + opt_params = {} + if 'optimize' in stage.keys(): + optimize_names = stage.optimize + else: + optimize_names = stage.optimizes[irepeat] + for key in optimize_names: + if key in infos.keys(): # 优化的参数 + infos[key] = infos[key].to(device) + opt_params[key] = infos[key] + elif key in body_params.keys(): + opt_params[key] = body_params[key] + else: + raise ValueError('{} is not in infos or body_params'.format(key)) + if self.monitor.verbose: + print('[optimize] optimizing {}'.format(optimize_names)) + for key, val in opt_params.items(): + infos['init_'+key] = val.clone().detach().cpu() + # initialize keypoints + with torch.no_grad(): + kpts_est = body_model.keypoints(body_params) + infos['init_kpts_est'] = kpts_est.clone().detach().cpu() + before_after_module = make_before_after(stage.get('before_after', {}), body_model, body_params, infos) + for module in before_after_module: + # Input to this module is tensor + body_params = module.start(body_params) + grad_require(opt_params, True) + optimizer = make_optimizer(opt_params, **optimizer_args) + loss_funcs, weights = make_lossfuncs(stage, infos, device, irepeat, self.monitor.verbose) + closure = self.make_closure(body_model, body_params, infos, loss_funcs, weights, optimizer, before_after_module) + if self.monitor.check: + new_params = body_params.copy() + for module in before_after_module: + new_params = module.before(new_params) + kpts_est = body_model.keypoints(new_params) + for key, loss in loss_funcs.items(): + loss.check_at_start(kpts_est=kpts_est, **new_params) + self.optimizer_step(optimizer, closure, weights) + grad_require(opt_params, False) + if self.monitor.check: + new_params = body_params.copy() + for module in before_after_module: + new_params = 
module.before(new_params) + kpts_est = body_model.keypoints(new_params) + for key, loss in loss_funcs.items(): + loss.check_at_end(kpts_est=kpts_est, **new_params) + for module in before_after_module: + # Input to this module is tensor + body_params = module.final(body_params) + body_params = dict_of_tensor_to_numpy(body_params) + for key, val in opt_params.items(): + if key in infos.keys(): + infos[key] = val.detach().cpu() + return body_params + + def fit(self, body_model, dataset): + batch_size = len(dataset) if self.batch_size == -1 else self.batch_size + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=False) + if len(dataloader) > 1: + dataloader = tqdm(dataloader, desc='optimizing') + for data in dataloader: + data = dataset.reshape_data(data) + infos = data.copy() + init_params = body_model.init_params(nFrames=infos['nFrames'], nPerson=infos.get('nPerson', 1)) + # first initialize the model + for name, init_func in self.initialize.items(): + if 'loss' in init_func.keys(): + # fitting to initialize + init_params = self.fit_stage(body_model, init_params, infos, init_func, 0) + else: + # use initialize module + init_module = load_object(init_func.module, init_func.args) + init_params = init_module(body_model, init_params, infos) + # if there are multiple initialization params + # then fit each of them + if not isinstance(init_params, list): + init_params = [init_params] + results = [] + for init_param in init_params: + # check the repeat params + body_params = init_param + for stage_name, stage in self.stages.items(): + for irepeat in range(stage.get('repeat', 1)): + with Timer('optimize {}'.format(stage_name), not self.monitor.timer): + body_params = self.fit_stage(body_model, body_params, infos, stage, irepeat) + results.append(body_params) + # select the best results + if len(results) > 1: + # check the result + loss = load_object(self.check.module, self.check.args, **{key:infos[key] for key in 
self.check.infos}) + metrics = [loss(body_model.keypoints(body_params, return_tensor=True).cpu()).item() for body_params in results] + best_idx = np.argmin(metrics) + else: + best_idx = 0 + if 'sync_offset' in body_params.keys(): + offset = body_params.pop('sync_offset') + dataset.write_offset(offset) + body_params = Params(**results[best_idx]) + if data['nFrames'] != body_params['poses'].shape[0]: + for key in body_params.keys(): + if body_params[key].shape[0] == 1:continue + body_params[key] = body_params[key].reshape(data['nFrames'], -1, *body_params[key].shape[1:]) + print(key, body_params[key].shape) + if 'K' in infos.keys(): + camera = Params(K=infos['K'].numpy(), R=infos['Rc'].numpy(), T=infos['Tc'].numpy()) + if 'mirror' in infos.keys(): + camera['mirror'] = infos['mirror'].numpy()[None] + dataset.write(body_model, body_params, data, camera) + else: + # write data without camera + dataset.write(body_model, body_params, data) \ No newline at end of file diff --git a/easymocap/multistage/base_ops.py b/easymocap/multistage/base_ops.py new file mode 100644 index 0000000..f49b6db --- /dev/null +++ b/easymocap/multistage/base_ops.py @@ -0,0 +1,39 @@ +''' + @ Date: 2022-08-12 20:34:15 + @ Author: Qing Shuai + @ Mail: s_q@zju.edu.cn + @ LastEditors: Qing Shuai + @ LastEditTime: 2022-08-18 14:47:23 + @ FilePath: /EasyMocapPublic/easymocap/multistage/base_ops.py +''' +import torch + +class BeforeAfterBase: + def __init__(self, model) -> None: + pass + + def start(self, body_params): + # operation before the optimization + return body_params + + def before(self, body_params): + # operation in each optimization step + return body_params + + def final(self, body_params): + # operation after the optimization + return body_params + +class SkipPoses(BeforeAfterBase): + def __init__(self, index, nPoses) -> None: + self.index = index + self.nPoses = nPoses + self.copy_index = [i for i in range(nPoses) if i not in index] + + def before(self, body_params): + poses = 
body_params['poses'] + poses_copy = torch.zeros_like(poses) + print(poses.shape) + poses_copy[..., self.copy_index] = poses[..., self.copy_index] + body_params['poses'] = poses_copy + return body_params \ No newline at end of file diff --git a/easymocap/multistage/before_after.py b/easymocap/multistage/before_after.py new file mode 100644 index 0000000..0a12ccc --- /dev/null +++ b/easymocap/multistage/before_after.py @@ -0,0 +1,57 @@ +import torch + +class Remove: + def __init__(self, key, index=[], ranges=[]) -> None: + self.key = key + self.ranges = ranges + self.index = index + + def before(self, body_params): + val = body_params[self.key] + if self.ranges[0] == 0: + val_zeros = torch.zeros_like(val[:, :self.ranges[1]]) + val = torch.cat([val_zeros, val[:, self.ranges[1]:]], dim=1) + body_params[self.key] = val + return body_params + +class RemoveHand: + def __init__(self, start=60) -> None: + pass + + def before(self, body_params): + poses = body_params['poses'] + val_zeros = torch.zeros_like(poses[:, 60:]) + val = torch.cat([poses[:, :60], val_zeros], dim=1) + body_params['poses'] = val + return body_params + +class Keep: + def __init__(self, key, ranges=[], index=[]) -> None: + self.key = key + self.ranges = ranges + self.index = index + + def before(self, body_params): + val = body_params[self.key] + val_zeros = val.detach().clone() + if len(self.ranges) > 0: + val_zeros[..., self.ranges[0]:self.ranges[1]] = val[..., self.ranges[0]:self.ranges[1]] + elif len(self.index) > 0: + val_zeros[..., self.index] = val[..., self.index] + body_params[self.key] = val_zeros + return body_params + + def final(self, body_params): + return body_params + +class VPoser2Full: + def __init__(self, key) -> None: + pass + + def __call__(self, body_model, body_params, infos): + if not 'Embedding' in body_model.__class__.__name__: + return body_params + poses = body_params['poses'] + poses_full = body_model.decode(poses, add_rot=False) + body_params['poses'] = poses_full + return 
body_params \ No newline at end of file diff --git a/easymocap/multistage/fitting.py b/easymocap/multistage/fitting.py new file mode 100644 index 0000000..c4082c8 --- /dev/null +++ b/easymocap/multistage/fitting.py @@ -0,0 +1,1771 @@ +''' + @ Date: 2022-03-22 16:11:44 + @ Author: Qing Shuai + @ Mail: s_q@zju.edu.cn + @ LastEditors: Qing Shuai + @ LastEditTime: 2022-07-25 11:51:50 + @ FilePath: /EasyMocapPublic/easymocap/multistage/fitting.py +''' +# This function provides a realtime fitting interface +from collections import namedtuple +from time import time, sleep +import numpy as np +import cv2 +import torch +import copy + +from ..config.baseconfig import load_object_from_cmd +from ..mytools.debug_utils import log, mywarn +from ..mytools import Timer +from ..config import Config +from ..mytools.triangulator import iterative_triangulate +from ..bodymodel.base import Params +from .torchgeometry import axis_angle_to_euler, euler_to_axis_angle + +def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32): + ''' Calculates the rotation matrices for a batch of rotation vectors + Parameters + ---------- + rot_vecs: torch.tensor Nx3 + array of N axis-angle vectors + Returns + ------- + R: torch.tensor Nx3x3 + The rotation matrices for the given axis-angle parameters + ''' + + batch_size = rot_vecs.shape[0] + device = rot_vecs.device + + angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) + rot_dir = rot_vecs / angle + + cos = torch.unsqueeze(torch.cos(angle), dim=1) + sin = torch.unsqueeze(torch.sin(angle), dim=1) + + # Bx1 arrays + rx, ry, rz = torch.split(rot_dir, 1, dim=1) + K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) + + zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ + .view((batch_size, 3, 3)) + + ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) + rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) + return rot_mat + +from 
scipy.spatial.transform import Rotation +def aa2euler(aa): + aa = np.array(aa) + R = cv2.Rodrigues(aa)[0] + # res = Rotation.from_dcm(R).as_euler('XYZ', degrees=True) + res = Rotation.from_matrix(R).as_euler('XYZ', degrees=False) + return np.round(res, 2).tolist() + +def rotmat2euler(rot): + res = Rotation.from_matrix(rot).as_euler('XYZ', degrees=True) + return res + +def euler2rotmat(euler): + res = Rotation.from_euler('XYZ', euler, degrees=True) + return res.as_matrix() + +def batch_rodrigues_jacobi(rvec): + shape = rvec.shape + rvec = rvec.view(-1, 3) + device = rvec.device + dSkew = torch.zeros(3, 9, device=device) + dSkew[0, 5] = -1 + dSkew[1, 6] = -1 + dSkew[2, 1] = -1 + dSkew[0, 7] = 1 + dSkew[1, 2] = 1 + dSkew[2, 3] = 1 + dSkew = dSkew[None] + theta = torch.norm(rvec, dim=-1, keepdim=True) + 1e-5 + c = torch.cos(theta) + s = torch.sin(theta) + c1 = 1 - c + itheta = 1 / theta + r = rvec / theta + zeros = torch.zeros_like(r[:, :1]) + rx, ry, rz = torch.split(r, 1, dim=1) + rrt = torch.matmul(r[:, :, None], r[:, None, :]) + skew = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ + .view((r.shape[0], 3, 3)) + I = torch.eye(3, device=rvec.device, dtype=rvec.dtype)[None] + rot_mat = I + s[:, None] * skew + c1[:, None] * torch.bmm(skew, skew) + + drrt = torch.stack([ + rx + rx, ry, rz, ry, zeros, zeros, rz, zeros, zeros, + zeros, rx, zeros, rx, ry + ry, rz, zeros, rz, zeros, + zeros, zeros, rx, zeros, zeros, ry, rx, ry, rz + rz + ], dim=-1).view((r.shape[0], 3, 9)) + jacobi = torch.zeros((r.shape[0], 3, 9), device=rvec.device, dtype=rvec.dtype) + for i in range(3): + ri = r[:, i:i+1] + a0 = -s * ri + a1 = (s - 2*c1*itheta)*ri + a2 = c1 * itheta + a3 = (c-s*itheta)*ri + a4 = s * itheta + jaco = a0[:, None] * I + a1[:, None] * rrt + a2[:, None] * drrt[:, i].view(-1, 3, 3) + a3[:, None] * skew + a4[:, None] * dSkew[:, i].view(-1, 3, 3) + jacobi[:, i] = jaco.view(-1, 9) + rot_mat = rot_mat.view(*shape[:-1], 3, 3) + jacobi = jacobi.view(*shape[:-1], 
3, 9) + return rot_mat, jacobi + +def getJacobianOfRT(rvec, tvec, joints): + # joints: (bn, nJ, 3) + dtype, device = rvec.dtype, rvec.device + bn, nJoints = joints.shape[:2] + # jacobiToRvec: (bn, 3, 9) // tested by OpenCV and PyTorch + Rot, jacobiToRvec = batch_rodrigues_jacobi(rvec) + I3 = torch.eye(3, dtype=dtype, device=device)[None] + # jacobiJ_R: (bn, nJ, 3, 3+3+3) => (bn, nJ, 3, 9) + # // flat by column: + # // x, 0, 0 | y, 0, 0 | z, 0, 0 + # // 0, x, 0 | 0, y, 0 | 0, z, 0 + # // 0, 0, x | 0, 0, y | 0, 0, z + jacobi_J_R = torch.zeros((bn, nJoints, 3, 9), dtype=dtype, device=device) + jacobi_J_R[:, :, 0, :3] = joints + jacobi_J_R[:, :, 1, 3:6] = joints + jacobi_J_R[:, :, 2, 6:9] = joints + # jacobi_J_rvec: (bn, nJ, 3, 3) + jacobi_J_rvec = torch.matmul(jacobi_J_R, jacobiToRvec[:, None].transpose(-1, -2)) + # if True: # 测试自动梯度 + # def test_func(rvec): + # Rot = batch_rodrigues(rvec[None])[0] + # joints_new = joints[0] @ Rot.t() + # return joints_new + # jac_J_rvec = torch.autograd.functional.jacobian(test_func, rvec[0]) + # my_j = jacobi_joints_RT[0, ..., :3] + # jacobi_J_tvec: (bn, nJx3, 3) + jacobi_J_tvec = I3[None].expand(bn, nJoints, -1, -1) + jacobi_J_rt = torch.cat([jacobi_J_rvec, jacobi_J_tvec], dim=-1) + return Rot, jacobiToRvec, jacobi_J_rt + +class Model: + rootIdx = 0 + parents = [] + + +INDEX_HALF = [0,1,2,3,4,5,6,7,15,16,17,18] + +class LowPassFilter: + def __init__(self): + self.prev_raw_value = None + self.prev_filtered_value = None + + def process(self, value, alpha): + if self.prev_raw_value is None: + s = value + else: + s = alpha * value + (1.0 - alpha) * self.prev_filtered_value + self.prev_raw_value = value + self.prev_filtered_value = s + return s + +class OneEuroFilter: + def __init__(self, mincutoff=1.0, beta=0.0, dcutoff=1.0, freq=30): + self.freq = freq + self.mincutoff = mincutoff + self.beta = beta + self.dcutoff = dcutoff + self.x_filter = LowPassFilter() + self.dx_filter = LowPassFilter() + + def compute_alpha(self, cutoff): + te = 
1.0 / self.freq + tau = 1.0 / (2 * np.pi * cutoff) + return 1.0 / (1.0 + tau / te) + + def process(self, x): + prev_x = self.x_filter.prev_raw_value + dx = 0.0 if prev_x is None else (x - prev_x) * self.freq + edx = self.dx_filter.process(dx, self.compute_alpha(self.dcutoff)) + cutoff = self.mincutoff + self.beta * np.abs(edx) + return self.x_filter.process(x, self.compute_alpha(cutoff)) + +class BaseBody: + def __init__(self, cfg_triangulator, cfg_model, cfg) -> None: + self.triangulator = load_object_from_cmd(cfg_triangulator, []) + self.body_model = load_object_from_cmd(cfg_model, ['args.use_pose_blending', False, 'args.device', 'cpu']) + self.cfg = cfg + self.register_from_lbs(self.body_model) + + def register_from_lbs(self, body_model): + kintree_shape = np.array(self.cfg.shape.kintree) + self.nJoints = body_model.J_regressor.shape[0] + self.k_shapeBlend = body_model.j_shapedirs[self.nJoints:] + self.j_shapeBlend = body_model.j_shapedirs[:self.nJoints] + self.jacobian_limb_shapes = self.k_shapeBlend[kintree_shape[:, 1]] - self.k_shapeBlend[kintree_shape[:, 0]] + self.k_template = body_model.j_v_template[self.nJoints:] + self.j_template = body_model.j_v_template[:self.nJoints] + self.k_weights = body_model.j_weights[self.nJoints:] + self.j_weights = body_model.j_weights[:self.nJoints] + parents = body_model.parents[1:].cpu().numpy() + child = np.arange(1, parents.shape[0]+1, dtype=np.int64) + self.kintree = np.stack([parents, child], axis=1) + self.parents = np.zeros(parents.shape[0]+1, dtype=np.int) - 1 + self.parents[self.kintree[:, 1]] = self.kintree[:, 0] + self.rootIdx = 0 + self.time = time() + +def rotation_matrix_from_3x3(A): + U, s, Vt = np.linalg.svd(A, full_matrices=False) + V = Vt.T + T = np.dot(V, U.T) + # does the current solution use a reflection? 
+ have_reflection = np.linalg.det(T) < 0 + + # if that's not what was specified, force another reflection + if have_reflection: + V[:,-1] *= -1 + s[-1] *= -1 + T = np.dot(V, U.T) + return T + +def svd_rot(src, tgt, reflection=False, debug=True): + # optimum rotation matrix of Y + A = np.dot(src.T, tgt) + U, s, Vt = np.linalg.svd(A, full_matrices=False) + V = Vt.T + T = np.dot(V, U.T) + # does the current solution use a reflection? + have_reflection = np.linalg.det(T) < 0 + + # if that's not what was specified, force another reflection + if reflection != have_reflection: + V[:,-1] *= -1 + s[-1] *= -1 + T = np.dot(V, U.T) + if debug: + err = np.linalg.norm(tgt - src @ T.T, axis=1) + print('[svd] ', err) + return T + +def normalize(vector): + return vector/np.linalg.norm(vector) + +def rad_from_2vec(vec1, vec2): + return np.arccos((normalize(vec1)*normalize(vec2)).sum()) + +def smoothing_factor(t_e, cutoff): + r = 2 * 3.14 * cutoff * t_e + return r / (r + 1) + +def exponential_smoothing(a, x, x_prev): + return a * x + (1 - a) * x_prev + +FilterResult = namedtuple('FilterResult', ['x', 'dx', 'v', 't']) + +class MyFilter: + def __init__(self, key, filled, min_cutoff=1.0, d_cutoff=1.0, + beta=0.1) -> None: + self.key = key + self.fill_result = filled + self.min_cutoff = min_cutoff + self.d_cutoff = d_cutoff + self.beta = beta + self.init = False + self.records = [] + self.result = None + self.counter = 0 + self.data = [] + self.conf = [] + self.smooth_record = [] + + def fill(self, value, conf): + filled = conf < 0.1 + if filled.sum() >= 1: + value[filled] = self.fill_result[0][filled] + if self.key == 'Rh': + value = rotation_matrix_from_3x3(value.reshape(3, 3)) + value = cv2.Rodrigues(value)[0].reshape(3,) + if (value < 0).all(): + value = -value + return value[None] + + def __call__(self, value, conf): + self.counter += 1 + x, v = value[0], conf[0] + if self.key == 'Rh': + x = cv2.Rodrigues(x)[0].reshape(-1) + v = np.zeros((9,)) + v[0] + t = np.zeros_like(x) + 
t[v>0.1] = self.counter + self.data.append(x) + self.conf.append(v) + if len(self.smooth_record) == 0: + if self.key == 'Rh': + start = x + else: + # start = self.fill_result[0] + start = x + smoothed = FilterResult(start, np.zeros_like(x), np.zeros_like(x), t) + self.smooth_record.append(smoothed) + if len(self.data) < 3: + return self.fill(x, v) + data = np.array(self.data) + conf = np.array(self.conf) + smoothed = self.smooth_record[-1] + # 预计的速度 + dx_new = x - smoothed.x + # 滤波器可见,当前可见 + flag_vis = (smoothed.v > 0.1) & (v > 0.1) + # - 速度异常的,移除掉,认为当前帧不可见 + flag_outlier = (np.abs(smoothed.dx) > 0.05) & (np.abs(dx_new - smoothed.dx)/(1e-5 + smoothed.dx) > 2.) + if self.key != 'Rh': + v[flag_vis&flag_outlier] = 0. + # 滤波器不可见,当前可见,速度打折,认为是新增的帧 + flag_new = (smoothed.v < 0.1)&(v>0.1) + dx_new[flag_new] /= 3 + # 滤波器可见,当前不可见,速度使用滤波器的速度 + flag_unvis = (v<0.1) & (conf[-2] < 0.1) + dx_new[flag_unvis] = smoothed.dx[flag_unvis] + # 滤波器不可见,当前也不可见,速度清0 + dx_new[(v<0.1)&(smoothed.v<0.1)] = 0. + # 实际估计出来的速度,这里要去掉不可见的地方 + # 混合的权重使用 0.7, 0.3默认全部使用新的一半 + weight_dx = np.zeros_like(dx_new) + 0.7 + dx_smoothed = smoothed.dx*(1-weight_dx) + dx_new*weight_dx + smoothed_value = smoothed.x + dx_smoothed + v_new = smoothed.v.copy() + v_new = v_new * (1-weight_dx) + v*weight_dx + t_new = smoothed.t.copy() + t_new[v>0.1] = t[v>0.1] + smooth_new = FilterResult(smoothed_value, dx_smoothed, v_new, t_new) + self.smooth_record.append(smooth_new) + if self.counter == 1000: + if self.key == 'poses': + import matplotlib.pyplot as plt + xrange = np.arange(0, data.shape[0]) + smoothed = np.array([d.x for d in self.smooth_record]) + for nj in range(data.shape[1]): + valid = conf[:, nj] > 0. 
+ plt.scatter(xrange[valid], data[valid, nj]) + # yhat = savgol_filter(data[:, nj], data.shape[0], 3) + # plt.plot(yhat) + plt.plot(smoothed) + plt.show() + import ipdb;ipdb.set_trace() + # return self.fill(x, v) + return self.fill(smooth_new.x, smooth_new.v) + + def __call__0(self, value, conf): + self.counter += 1 + x, v = value[0], conf[0] + if self.key == 'Rh': + x = cv2.Rodrigues(x)[0].reshape(-1) + v = np.zeros((9,)) + v[0] + if self.result is None: + result = FilterResult(x, np.zeros_like(x), v, (v>0)*self.counter) + self.result = result + self.records.append(result) + return self.fill(result.x, result.v) + # return self.fill(x, v) + # 维护一个前一帧的,去除outlier + prev = self.result + t = prev.t.copy() + t[v>0.] = self.counter + dx = x - prev.x # 这里直接使用与之前的结果的差了,避免多帧不可见,然后速度过大 + MAX_DX = 1. + WINDOW = 31 + not_valid = ((np.abs(dx) > MAX_DX) & (prev.v > 0.1))|\ + (t-prev.t > WINDOW) + v[not_valid] = 0. + x_all = np.stack([r.x for r in self.records[-WINDOW:]]) + v_all = np.stack([r.v for r in self.records[-WINDOW:]]) + dx_all = np.stack([r.dx for r in self.records[-WINDOW:]]) + v_sum = v_all.sum(axis=0) + dx_mean = (dx_all * v_all).sum(axis=0)/(1e-5 + v_all.sum(axis=0)) + # if (x_all.shape[0] > 30) & (self.counter % 40 == 0): + if True: + x_mean = (x_all * v_all).sum(axis=0)/(1e-5 + v_all.sum(axis=0)) + x_pred = x_mean + dx_pred = np.zeros_like(x_pred) + elif x_all.shape[0] >= 5: + # 进行smooth + axt = np.zeros((2, x_all.shape[1])) + xrange = np.arange(x_all.shape[0]).reshape(-1, 1) + A0 = np.hstack([xrange, np.ones((x_all.shape[0], 1))]) + for nj in range(x_all.shape[1]): + conf = v_all[:, nj:nj+1] + if (conf>0.).sum() < 3: + continue + A = conf * A0 + b = conf * (x_all[:, nj:nj+1]) + est = np.linalg.inv(A.T @ A) @ A.T @ b + axt[:, nj] = est[:, 0] + x_all_smoothed = xrange * axt[0:1] + axt[1:] + x_pred = x_all_smoothed[x_all.shape[0]//2] + dx_pred = axt[0] + else: + x_pred = x_all[x_all.shape[0]//2] + dx_pred = dx_mean + if x_all.shape[0] == 1: + current = 
FilterResult(x, dx, v, t) + self.records.append(current) + self.result = current + else: + # dx_hat = (dx * v + dx_mean * v_mean)/(v+v_mean+1e-5) + # x_pred = x_mean + dx_hat + # current = FilterResult(x_pred, dx_hat, v, t) + current = FilterResult(x, dx, v, t) + self.records.append(current) + # 使用平均速度模型 + self.result = FilterResult(x_pred, dx_pred, v_sum, t) + return self.fill(self.result.x, self.result.v) + + def __call__2(self, value, conf): + self.counter += 1 + x, v = value[0], conf[0] + if self.result is None: + result = FilterResult(x, np.zeros_like(x), v, (v>0)*self.counter) + self.result = result + return self.fill(result.x, result.v) + prev = self.result + t = prev.t.copy() + t[v>0.] = self.counter + # update t + # dx = (x - prev.x)/(np.maximum(t-prev.t, 1)) + dx = x - prev.x # 这里直接使用与之前的结果的差了,避免多帧不可见,然后速度过大 + dx_ = dx.copy() + # 判断dx的大小 + large_dx = np.abs(dx) > 0.5 + if large_dx.sum() > 0: + v[large_dx] = 0. + t[large_dx] = prev.t[large_dx] + dx[large_dx] = 0. + missing_index = ((prev.v > 0.1) & (v < 0.1)) | (t - prev.t > 10) + if missing_index.sum() > 0: + print('missing', missing_index) + new_index = (prev.v < 0.1) & (v > 0.1) + if new_index.sum() > 0: + print('new', new_index) + dx[new_index] = 0. 
+ weight_dx = v/(1e-5+ 3*prev.v + 1*v) + weight_x = v/(1e-5+ 3*prev.v + 1*v) + # 移除速度过大的点 + dx_hat = exponential_smoothing(weight_dx, dx, prev.dx) + x_pred = prev.x + dx_hat + x_hat = exponential_smoothing(weight_x, x, x_pred) + dx_real = x_hat - prev.x + # consider the unvisible v + print_val = { + 't_pre': prev.t, + 'x_inp': x, + 'x_pre': prev.x, + 'x_new': x_hat, + 'dx_inp': dx_, + 'dx_pre': prev.dx, + 'dx_new': dx_hat, + 'dx_real': dx_real, + 'v': v, + 'v_pre': prev.v, + 'w_vel': weight_dx, + 'w_x': weight_x + } + for key in print_val.keys(): + print('{:7s}'.format(key), end=' ') + print('') + for i in range(x.shape[0]): + for key in print_val.keys(): + print('{:7.2f}'.format(print_val[key][i]), end=' ') + print('') + v[missing_index] = prev.v[missing_index] / 1.2 # 衰减系数 + result = FilterResult(x_hat, dx_hat, v, t) + self.result = result + return self.fill(result.x, result.v) + if len(self.records) < 10: + self.records.append([self.counter, value, conf]) + return self.fill(value[0], conf[0]) + if self.x is None: + time = np.vstack([x[0] for x in self.records]) + value_pre = np.vstack([x[1] for x in self.records]) + conf_pre = np.vstack([x[2] for x in self.records]) + conf_sum = conf_pre.sum(axis=0) + value_mean = (value_pre * conf_pre).sum(axis=0)/(conf_sum + 1e-5) + self.x = value_mean + self.x_conf = conf_sum + t_prev = np.zeros_like(self.x, dtype=np.int) - 1 + t_prev[conf_sum>0] = self.counter + self.t_prev = t_prev + # 零速度初始化 + self.d_x = np.zeros_like(self.x) + return self.fill(self.x, self.x_conf) + # 假设每帧都传进来的吧 + return self.fill(self.x, self.x_conf) + x_est, v_est, conf_est = self.x.copy(), self.d_x.copy(), self.x_conf.copy() + value = value[0] + conf = conf[0] + d_x = value - self.x + t_current = np.zeros_like(self.x, dtype=int) - 1 + t_current[conf>0.] 
def smooth_results(self, params=None):
    """Temporally smooth one frame of fitted parameters.

    Args:
        params: dict holding ``Rh``, ``poses``, ``handl`` and ``handr``
            together with their ``<key>_conf`` companions, or ``None`` when
            no new fit is available for this frame.

    Returns:
        A dict with ``id``, ``type``, the four filtered entries, and
        ``shapes`` / ``Th`` taken from ``self.blank_result``.  When
        ``params`` is ``None`` the latest output ``self.results_newest`` is
        returned as-is.
    """
    if params is None:
        # __call__ invokes this method without arguments for frames whose
        # keypoints were rejected; indexing ``None`` used to raise TypeError.
        return self.results_newest
    results = {'id': 0, 'type': 'smplh_half'}
    for key in ['Rh', 'poses', 'handl', 'handr']:
        # every filtered key has its own filter instance (see __init__)
        results[key] = self.filter[key](params[key], params[key + '_conf'])
    for key in ['shapes', 'Th']:
        # shape and translation are not estimated here; reuse the blank values
        results[key] = self.blank_result[key]
    return results
def check_keypoints(self, keypoints3d):
    """Decide whether the current 3D keypoints are usable.

    The keypoints pass when more than five of them have a confidence above
    ``self.cfg.MIN_THRES`` and, once more than one record is stored, when
    the confidence-weighted mean distance to the last stored record is
    below 0.1 (same unit as the keypoints -- presumably metres, TODO
    confirm).

    Args:
        keypoints3d: array whose last axis is ``(x, y, z, confidence)``.

    Returns:
        bool: ``True`` when the keypoints are accepted.
    """
    MIN_VISIBLE = 5
    MAX_MEAN_JUMP = 0.1
    n_visible = (keypoints3d[..., -1] > self.cfg.MIN_THRES).sum()
    if not n_visible > MIN_VISIBLE:
        return False
    if len(self.records) > 1:
        k_pre = self.get_keypoints3d([self.records[-1]])
        dist = np.linalg.norm(keypoints3d[..., :3] - k_pre[..., :3], axis=-1)
        # geometric mean of both confidences: zero if either frame misses the joint
        conf = np.sqrt(keypoints3d[..., 3] * k_pre[..., 3])
        conf_sum = conf.sum()
        if not conf_sum > 0:
            # No joint is visible in both frames.  The division below used to
            # give NaN (plus a RuntimeWarning), which also meant rejection.
            return False
        dist_mean = (dist * conf).sum() / conf_sum
        return bool(dist_mean < MAX_MEAN_JUMP)
    return True
def _ik_shoulder(self, keypoints3d, params):
    """Estimate the global body rotation from the two shoulder keypoints.

    The shoulder vector (keypoint 5 minus keypoint 2) is projected onto the
    plane orthogonal to the configured up axis.  The rotation matrix is
    built with columns ``[shoulder, up, shoulder x up]`` and stored both as
    a matrix (``params['R']``) and as an axis-angle vector
    (``params['Rh']``).

    Args:
        keypoints3d: ``(nJoints, 4)`` array of ``(x, y, z, confidence)``.
        params: parameter dict; ``R``, ``Rh`` and ``Rh_conf`` are written.

    Returns:
        ``(True, params)`` on success.  ``(False, params)`` with a zero
        ``Rh_conf`` when the projected shoulder vector has no length, in
        which case no orientation can be derived.

    Raises:
        ValueError: if ``self.up_vector`` is not one of ``'x'``, ``'y'``,
            ``'z'`` (previously an unbound-local error further down).
    """
    SHOULDER_IDX = [2, 5]
    UP_AXIS = {'x': 0, 'y': 1, 'z': 2}
    if self.up_vector not in UP_AXIS:
        raise ValueError('unsupported up_vector: {!r}'.format(self.up_vector))
    axis = UP_AXIS[self.up_vector]
    shoulder = keypoints3d[SHOULDER_IDX[1], :3] - keypoints3d[SHOULDER_IDX[0], :3]
    # drop the component along the up axis so the body stays upright
    shoulder[..., axis] = 0.
    up_vector = np.zeros(3, dtype=np.float32)
    up_vector[axis] = 1.
    length = np.linalg.norm(shoulder)
    if not length > 1e-8:
        # both shoulders coincide (e.g. missing detections): dividing by the
        # norm would fill R and Rh with NaN
        params['Rh_conf'] = np.zeros((1, 3))
        return False, params
    shoulder = shoulder / length
    front = np.cross(shoulder, up_vector)
    front = front / np.linalg.norm(front, keepdims=True)
    R = np.stack([shoulder, up_vector, front]).T
    Rh = cv2.Rodrigues(R)[0].reshape(1, 3)
    log('Shoulder:{}'.format(Rh))
    params['R'] = R
    params['Rh'] = Rh
    # the orientation is only as reliable as the weaker shoulder detection
    params['Rh_conf'] = np.zeros((1, 3)) + keypoints3d[SHOULDER_IDX, 3].min()
    return True, params
@staticmethod
def _rad_from_twovec(keypoints3d, start, mid, end, MIN_THRES):
    """Angle between the bones ``start -> mid`` and ``mid -> end``.

    Args:
        keypoints3d: ``(nJoints, 4)`` array of ``(x, y, z, confidence)``.
        start, mid: keypoint indices of the first bone.
        end: keypoint index of the far end of the second bone, or a
            list/array of indices whose confidence-weighted mean is used.
        MIN_THRES: minimum confidence for a keypoint to be trusted.

    Returns:
        ``(conf, rad)``: the summed confidence of the three points and the
        angle returned by ``rad_from_2vec``; ``(0, 0.)`` when any point is
        not confident enough.
    """
    start = keypoints3d[start]
    mid = keypoints3d[mid]
    end = keypoints3d[end]
    # Indexing with a list of indices yields a 2-D array, never a ``list``,
    # so the original ``isinstance(end, list)`` test could not succeed and
    # the averaging branch was unreachable.
    if np.ndim(end) == 2:
        if (end[:, 3] > MIN_THRES).sum() < 2:
            return 0, 0.
        # use the confidence-weighted mean to represent the points
        end = np.sum(end * end[:, 3:], axis=0) / (end[:, 3:].sum())
    conf = [start[3], mid[3], end[3]]
    if not min(conf) > MIN_THRES:
        return 0, 0.
    conf = sum(conf)
    dir_src = normalize(mid[:3] - start[:3])
    dir_dst = normalize(end[:3] - mid[:3])
    rad = rad_from_2vec(dir_src, dir_dst)
    return conf, rad
def _ik_palm(self, keypoints3d, params):
    """Solve the wrist bend and the forearm twist from the palm keypoints.

    For every entry of ``self.cfg.PALM`` the palm normal is computed from
    three child keypoints.  Its angle to the parent bone gives the wrist
    rotation about ``info['axis']``.  Its direction in the parent joint's
    frame gives an extra rotation about the local x axis, which is composed
    onto the elbow pose.

    Args:
        keypoints3d: ``(nJoints, 4)`` array of ``(x, y, z, confidence)``.
        params: parameter dict; reads ``poses``, ``shapes`` and ``R``,
            updates ``poses`` in place.

    Returns:
        The updated ``params``.
    """
    T_joints, _ = self.body_model.transform(
        {'poses': params['poses'], 'shapes': params['shapes']}, return_vertices=False)
    T_joints = T_joints[0].cpu().numpy()
    # loop-invariant views of the observations
    est_points = keypoints3d[:, :3]
    est_conf = keypoints3d[:, 3]
    for name, info in self.cfg.PALM.items():
        if est_conf[info.children].min() < self.cfg.MIN_THRES:
            continue
        # palm orientation: normal of the plane spanned by the child keypoints
        origin = est_points[info.children[0]]
        dir0 = normalize(est_points[info.children[1]] - origin)
        dir1 = normalize(est_points[info.children[-1]] - origin)
        normal = normalize(np.cross(dir0, dir1))
        dir_parent = normalize(est_points[info.parent[1]] - est_points[info.parent[0]])
        # Angle between palm normal and parent bone.  The dot product of two
        # unit vectors can leave [-1, 1] by rounding error, which would make
        # arccos return NaN and poison the pose vector.
        cos_angle = np.clip((normal * dir_parent).sum(), -1., 1.)
        rad = np.arccos(cos_angle) - np.pi / 2
        rad = np.clip(rad, *info['ranges'])
        rot = rad * np.array(info['axis']).reshape(1, 3)
        params['poses'][:, 3 * info['index']:3 * (info['index'] + 1)] = rot
        # Forearm twist: a rotation about the forearm direction is still
        # missing.  It is measured in the frame of the parent joint, i.e.
        # before the wrist rotation is applied.
        R_parents = params['R'] @ T_joints[info.index, :3, :3]
        normal_canonical = R_parents.T @ normal.reshape(3, 1)
        normal_canonical[0, 0] = 0
        normal_canonical = normalize(normal_canonical)
        # trick: the sine of the twist angle equals the z coordinate
        sin_twist = np.clip(normal_canonical[2, 0], -1., 1.)
        rad = np.arcsin(sin_twist)
        rot_x = np.array([-rad, 0., 0.])
        R_x = cv2.Rodrigues(rot_x)[0]
        elbow = slice(3 * info.index_elbow, 3 * info.index_elbow + 3)
        R_elbow = cv2.Rodrigues(params['poses'][:, elbow])[0]
        R_elbow = R_elbow @ R_x
        params['poses'][:, elbow] = cv2.Rodrigues(R_elbow)[0].reshape(1, 3)
    return params
def fitting(self, keypoints3d, results_pre):
    """Run the analytic upper-body IK for one frame.

    The body orientation comes from the shoulders, followed by head,
    elbows, arms and palms.  Both hands are solved by the cached MANO
    fitters (``self.lefthand`` / ``self.righthand``), whose PCA
    coefficients are expanded to full 45-D hand poses.

    Args:
        keypoints3d: ``(nJoints, 4)`` array or ``(1, nJoints, 4)`` batch of
            ``(x, y, z, confidence)``; body joints first, then 21 left-hand
            and 21 right-hand joints.
        results_pre: previous result; currently unused.

    Returns:
        ``(ok, params)``; ``ok`` is ``False`` when the shoulder step fails
        or reports a confidence of at most 0.01.
    """
    N_BODY, N_HAND = 12, 21
    if len(keypoints3d.shape) == 3:
        keypoints3d = keypoints3d[0]
    params = self.body_model.init_params(1, ret_tensor=False)
    params['poses_conf'] = np.zeros_like(params['poses'])
    params['handl_conf'] = np.zeros_like(params['handl'])
    params['handr_conf'] = np.zeros_like(params['handr'])
    # an array rather than a float, so the confidence test below works even
    # if the shoulder step returns without writing Rh_conf
    params['Rh_conf'] = np.zeros((1, 3))
    params['id'] = 0
    # body orientation from the shoulders
    flag, params = self._ik_shoulder(keypoints3d, params)
    if not flag or np.all(np.asarray(params['Rh_conf']) <= 0.01):
        return False, params
    params = self._ik_head(keypoints3d, params)
    params = self._ik_elbow(keypoints3d, params)
    params = self._ik_arm(keypoints3d, params)
    params = self._ik_palm(keypoints3d, params)
    # hands: expand the fitters' PCA coefficients into full poses
    params_l = self.lefthand(keypoints3d[N_BODY:N_BODY + N_HAND])[0]
    params_r = self.righthand(keypoints3d[N_BODY + N_HAND:N_BODY + 2 * N_HAND])[0]
    ncomp_l = params_l['poses'].shape[1]
    A_full = self.body_model.components_full_l[:ncomp_l].T
    params['handl'] = (A_full @ params_l['poses'].T).T + self.body_model.mean_full_l
    # use the right hand's own component count instead of reusing the left one
    ncomp_r = params_r['poses'].shape[1]
    A_full = self.body_model.components_full_r[:ncomp_r].T
    params['handr'] = (A_full @ params_r['poses'].T).T + self.body_model.mean_full_r
    params['handl_conf'] = np.ones((1, 45))
    params['handr_conf'] = np.ones((1, 45))
    return True, params
def fitShape(self, keypoints3d, params, weight, option):
    """Fit the shape coefficients to the observed limb lengths (Gauss-Newton).

    Args:
        keypoints3d: ``(nFrames, nJoints, 4)`` tensor of
            ``(x, y, z, confidence)``.
        params: parameter dict; ``params['shapes']`` has shape
            ``(1, nShapes)`` and is updated in place.
        weight: object with ``init_shape``, the weight keeping the shapes
            close to their initial value.
        option: object with ``max_iter``, ``gtol`` and ``ftol``.

    Returns:
        ``(params, keyShaped)``; ``keyShaped`` is the shaped keypoint
        template evaluated at the start of the last iteration (or at the
        initial shapes when no iteration ran).
    """
    kintree = np.array(self.cfg.shape.kintree)
    parent, child = kintree[:, 0], kintree[:, 1]
    nShapes = params['shapes'].shape[-1]
    # (nShapes, 1) transposed view: in-place updates also change params['shapes']
    shapes = params['shapes'].T
    shapes0 = shapes.clone()
    device, dtype = shapes.device, shapes.dtype
    lengths3d_est = torch.norm(
        keypoints3d[:, child, :3] - keypoints3d[:, parent, :3], dim=-1)
    conf = torch.sqrt(keypoints3d[:, child, 3:] * keypoints3d[:, parent, 3:])
    conf = conf.repeat(1, 1, 3).reshape(-1, 1)
    nFrames = keypoints3d.shape[0]
    # jacobian: (nFrames, nLimbs, 3, nShapes), flattened to rows
    jacob_limb_shapes = self.jacobian_limb_shapes[None].repeat(nFrames, 1, 1, 1)
    jacob_limb_shapes = jacob_limb_shapes.reshape(-1, nShapes)
    # note: sqrt(conf) belongs on each Jacobian factor; both are merged here
    JTJ_limb_shapes = jacob_limb_shapes.t() @ (jacob_limb_shapes * conf)
    lossnorm = 0
    keyShaped = None
    self.time = time()
    for iter_ in range(option.max_iter):
        # perform shape blending
        shape_offset = self.k_shapeBlend @ shapes
        keyShaped = self.k_template + shape_offset[..., 0]
        # Work on a copy: add_any_reg adds to JTJ in place, and aliasing the
        # precomputed matrix made the regulariser pile up every iteration.
        JTJ = JTJ_limb_shapes.clone()
        JTr = torch.zeros((nShapes, 1), device=device, dtype=dtype)
        # all limbs at once
        limb = keyShaped[child] - keyShaped[parent]
        limb_dir = limb / torch.norm(limb, dim=-1, keepdim=True)
        # res: (nFrames, nLimbs, 3)
        res = lengths3d_est[..., None] * limb_dir[None] - limb[None]
        res = conf * res.reshape(-1, 1)
        JTr += jacob_limb_shapes.t() @ res
        self.add_any_reg(shapes, shapes0, JTJ, JTr, w=weight.init_shape)
        delta = torch.linalg.solve(JTJ, JTr)
        shapes += delta
        norm_d, norm_f = self.log('shape', iter_, delta, res)
        if torch.norm(delta) < option.gtol:
            break
        # guard the relative test against a zero residual (perfect fit)
        if iter_ > 0 and abs(norm_f - lossnorm) / max(norm_f, 1e-12) < option.ftol:
            break
        lossnorm = norm_f
    if keyShaped is None:
        # max_iter == 0: still hand back a valid template
        keyShaped = self.k_template + (self.k_shapeBlend @ shapes)[..., 0]
    shapes = shapes.t()
    params['shapes'] = shapes
    return params, keyShaped
device = kpts.dtype, kpts.device + w_joints = 1./keypoints3d.shape[-2] * weight.joints + self.time = time() + for iter_ in range(option.max_iter): + Rh, Th = params_dict['Rh'], params_dict['Th'] + rot, jacobi_R_rvec, jacobi_joints_RT = getJacobianOfRT(Rh, Th, kpts) + kpts_rt = torch.matmul(kpts, rot.transpose(-1, -2)) + Th[:, None] + # // loss: J_obs - (R @ jest + T) => -dR/drvec - dT/dtvec - Rx(djest/dtheta) + jacobi_keypoints = -jacobi_joints_RT + if kpts_index is not None: + jacobi_keypoints = jacobi_keypoints[:, kpts_index] + kpts_rt = kpts_rt[:, kpts_index] + jacobi_keypoints_flat = jacobi_keypoints.view(bn, -1, jacobi_keypoints.shape[-1]) + JTJ_keypoints = jacobi_keypoints_flat.transpose(-1, -2) @ (jacobi_keypoints_flat * conf[..., None]) + res = conf[..., None] * ((keypoints3d[..., :3] - kpts_rt).view(bn, -1, 1)) + JTr_keypoints = jacobi_keypoints_flat.transpose(-1, -2) @ res + # + JTJ = torch.eye(bn*NUM_FRAME, device=device, dtype=dtype) * 1e-4 + JTr = torch.zeros((bn*NUM_FRAME, 1), device=device, dtype=dtype) + # accumulate loss + for nf in range(bn): + JTJ[nf*NUM_FRAME:(nf+1)*NUM_FRAME, nf*NUM_FRAME:(nf+1)*NUM_FRAME] += w_joints * JTJ_keypoints[nf] + # add regularization for each parameter + for nf in range(bn): + for key in keys_optimized: + start, end = nf*NUM_FRAME + keys_range[key][0], nf*NUM_FRAME + keys_range[key][1] + if key == 'Rh': + # 增加初始化的loss + res_init = rot[nf] - init_dict['Rot'][nf] + JTJ[start:end, start:end] += weight['init_'+key] * jacobi_R_rvec[nf] @ jacobi_R_rvec[nf].T + JTr[start:end] += -weight.init_Rh * jacobi_R_rvec[nf] @ res_init.reshape(-1, 1) + else: + res_init = Th[nf] - init_dict['Th'][nf] + JTJ[start:end, start:end] += weight['init_'+key] * torch.eye(3) + JTr[start:end] += -weight['init_'+key] * res_init.reshape(-1, 1) + JTr += - w_joints * JTr_keypoints.reshape(-1, 1) + # solve + delta = torch.linalg.solve(JTJ, JTr) + norm_d, norm_f = self.log('pose', iter_, delta, res) + if torch.norm(delta) < option.gtol: + break + if 
def jacobi_GlobalTrans_theta(self, poses, j_shaped, rootIdx, kintree,
                             device, dtype):
    """Global joint transforms and their derivatives w.r.t. the joint poses.

    Args:
        poses: ``(bn, 3 * nJoints)`` axis-angle poses.
        j_shaped: shaped joint positions, forwarded to ``localTransform``.
        rootIdx: index of the root joint.
        kintree: iterable of ``(parent, child)`` pairs; the propagation
            below relies on parents being listed before their children.
        device, dtype: device and dtype of the allocated tensors.

    Returns:
        ``(globalTrans, jacobian)`` where ``globalTrans`` is
        ``(bn, nJoints, 4, 4)`` and ``jacobian`` is
        ``(bn, nJoints, 4, 4, 3 * nJoints)``.  Columns of joint 0 and of the
        joints/axes listed in ``cfg.IGNORE_JOINTS`` / ``cfg.IGNORE_AXIS``
        stay zero.
    """
    parents = self.parents
    localTrans = self.localTransform(j_shaped, poses, rootIdx, kintree)
    globalTrans = self.globalTransform(localTrans, rootIdx, kintree)
    poses_flat = poses.view(poses.shape[0], -1, 3)
    # jacobi_R_rvec: (bn, nJ, 3, 9)
    _, jacobi_R_rvec = batch_rodrigues_jacobi(poses_flat)

    bn, nJoints = localTrans.shape[:2]
    dGlobalTrans_template = torch.zeros((bn, nJoints, 4, 4), device=device, dtype=dtype)
    # derivative of every global transform w.r.t. every pose entry
    jacobian = torch.zeros((bn, nJoints, 4, 4, nJoints * 3), device=device, dtype=dtype)
    for djIdx in range(1, nJoints):
        if djIdx in self.cfg.IGNORE_JOINTS:
            continue
        ignored_axes = self.cfg.IGNORE_AXIS.get(str(djIdx), [])
        # dAxis: component of the axis-angle vector of joint djIdx
        for dAxis in range(3):
            if dAxis in ignored_axes:
                continue
            dGlobalTrans = dGlobalTrans_template.clone()
            # start from the derivative of the local rotation
            dGlobalTrans[:, djIdx, :3, :3] = jacobi_R_rvec[:, djIdx, dAxis].view(bn, 3, 3)
            if djIdx != rootIdx:
                # d(R0 @ R1)/dt1 = R0 @ dR1/dt1; R0 accumulates all
                # ancestors, hence the global transform of the parent
                parent = parents[djIdx]
                dGlobalTrans[:, djIdx] = globalTrans[:, parent] @ dGlobalTrans[:, djIdx]
            # builtin bool: the ``np.bool`` alias was removed in NumPy 1.24
            valid = np.zeros(nJoints, dtype=bool)
            valid[djIdx] = True
            # walk the kinematic tree and push the derivative to descendants
            for (src, dst) in kintree:
                # the differentiated joint itself is already filled in
                if dst == djIdx:
                    continue
                valid[dst] = valid[src]
                if valid[src]:
                    # d(R0 @ R1)/dt0 = dR0/dt0 @ R1; R1 is the child's
                    # local transform
                    dGlobalTrans[:, dst] = dGlobalTrans[:, src] @ localTrans[:, dst]
            jacobian[..., 3 * djIdx + dAxis] = dGlobalTrans
    return globalTrans, jacobian
zero_dict = {key:torch.zeros_like(val) for key, val in init_dict.items()} + keys_optimized = ['Rh', 'Th', 'poses'] + keys_range = {} + for ikey, key in enumerate(keys_optimized): + if ikey == 0: + keys_range[key] = [0, init_dict[key].shape[-1]] + else: + start = keys_range[keys_optimized[ikey-1]][1] + keys_range[key] = [start, start+init_dict[key].shape[-1]] + NUM_FRAME = keys_range[keys_optimized[-1]][1] + # calculate J_shaped + shapes_t = params['shapes'].t() + shape_offset = self.j_shapeBlend @ shapes_t + # jshaped: (nJoints, 3) + j_shaped = self.j_template + shape_offset[..., 0] + shape_offset = self.k_shapeBlend @ shapes_t + # kshaped: (nJoints, 3) + k_shaped = self.k_template + shape_offset[..., 0] + # optimize parameters + nJoints = j_shaped.shape[0] + dtype, device = poses.dtype, poses.device + lossnorm = 0 + self.time = time() + for iter_ in range(option.max_iter): + # forward the model + # 0. poses => full poses + def tic(name): + print('[{:20}] {:.3f}ms'.format(name, 1000*(time()-self.time))) + if 'handl' in params.keys(): + poses_full = self.body_model.extend_poses(poses, params['handl'], params['handr']) + jacobi_posesful_poses = self.body_model.jacobian_posesfull_poses_ + else: + poses_full = self.body_model.extend_poses(poses) + jacobi_posesful_poses = self.body_model.jacobian_posesfull_poses(poses, poses_full) + # tic('jacobian poses full') + # 1. poses => local transformation => global transformation(bn, nJ, 4, 4) + globalTrans, jacobi_GlobalTrans_theta = self.jacobi_GlobalTrans_theta(poses_full, j_shaped, self.rootIdx, self.kintree, device, dtype) + # tic('global transform') + # 2. 
global transformation => relative transformation => final transformation + relTrans = globalTrans.clone() + relTrans[..., :3, 3:] -= torch.matmul(globalTrans[..., :3, :3], j_shaped[None, :, :, None]) + + relTrans_weight = torch.einsum('kj,fjab->fkab', self.k_weights, relTrans) + jacobi_relTrans_theta = jacobi_GlobalTrans_theta.clone() + # // consider topRight: T - RT0: add -dRT0/dt + # rot: (bn, nJoints, 3, 3, nThetas) @ (bn, nJoints, 1, 3, 1) => (bn, nJoints, 3, nThetas) + rot = jacobi_GlobalTrans_theta[..., :3, :3, :] + j0 = j_shaped.reshape(1, nJoints, 1, 3, 1).expand(bn, -1, -1, -1, -1) + jacobi_relTrans_theta[..., :3, 3, :] -= torch.sum(rot*j0, dim=-2) + jacobi_blendtrans_theta = torch.einsum('kj,fjabt->fkabt', self.k_weights, jacobi_relTrans_theta) + kpts = torch.einsum('fkab,kb->fka', relTrans_weight[..., :3, :3], k_shaped) + relTrans_weight[..., :3, 3] + # d(RJ0 + J1)/dtheta = d(R)/dtheta @ J0 + dJ1/dtheta + rot = jacobi_blendtrans_theta[..., :3, :3, :] + k0 = k_shaped.reshape(1, k_shaped.shape[0], 1, 3, 1).expand(bn, -1, -1, -1, -1) + # jacobi_keypoints_theta: (bn, nKeypoints, 3, nThetas) + jacobi_keypoints_theta = torch.sum(rot*k0, dim=-2) + jacobi_blendtrans_theta[..., :3, 3, :] + # tic('keypoints') + # // compute the jacobian of R T + # // loss: J_obs - (R @ jest + T) + rot, jacobi_R_rvec, jacobi_joints_RT = getJacobianOfRT(Rh, Th, kpts) + kpts_rt = torch.matmul(kpts, rot.transpose(-1, -2)) + Th[:, None] + rot_nk = rot[:, None].expand(-1, k_shaped.shape[0], -1, -1) + jacobi_keypoints_theta = torch.matmul(rot_nk, jacobi_keypoints_theta) + # compute jacobian of poses + shape_jacobi = jacobi_keypoints_theta.shape[:-1] + NUM_THETAS = jacobi_posesful_poses.shape[0] + jacobi_keypoints_poses = (jacobi_keypoints_theta[..., :NUM_THETAS].view(-1, NUM_THETAS) @ jacobi_posesful_poses).reshape(*shape_jacobi, -1) + # // loss: J_obs - (R @ jest + T) => -dR/drvec - dT/dtvec - Rx(djest/dtheta) + jacobi_keypoints = torch.cat([-jacobi_joints_RT, -jacobi_keypoints_poses], 
dim=-1) + if kpts_index is not None: + jacobi_keypoints = jacobi_keypoints[:, kpts_index] + kpts_rt = kpts_rt[:, kpts_index] + jacobi_keypoints_flat = jacobi_keypoints.view(bn, -1, jacobi_keypoints.shape[-1]) + # tic('jacobian keypoints') + JTJ_keypoints = jacobi_keypoints_flat.transpose(-1, -2) @ (jacobi_keypoints_flat * conf[..., None]) + res = conf[..., None] * ((keypoints3d[..., :3] - kpts_rt).view(bn, -1, 1)) + JTr_keypoints = jacobi_keypoints_flat.transpose(-1, -2) @ res + cache_dict = { + 'Th': Th, + 'Rh': Rh, + 'poses': poses, + } + # 计算loss + # JTJ = torch.zeros((bn*NUM_FRAME, bn*NUM_FRAME), device=device, dtype=dtype) + JTJ = torch.eye(bn*NUM_FRAME, device=device, dtype=dtype) * 1e-4 + JTr = torch.zeros((bn*NUM_FRAME, 1), device=device, dtype=dtype) + # add regularization for each parameter + for nf in range(bn): + for key in keys_optimized: + start, end = nf*NUM_FRAME + keys_range[key][0], nf*NUM_FRAME + keys_range[key][1] + JTJ[start:end, start:end] += weight['reg_{}'.format(key)] * torch.eye(end-start) + JTr[start:end] += -weight['reg_{}'.format(key)] * cache_dict[key][nf].view(-1, 1) + # add init for Rh + if key == 'Rh': + # 增加初始化的loss + res_init = rot[nf] - init_dict['Rot'][nf] + JTJ[start:end, start:end] += weight['init_'+key] * jacobi_R_rvec[nf] @ jacobi_R_rvec[nf].T + JTr[start:end] += -weight.init_Rh * jacobi_R_rvec[nf] @ res_init.reshape(-1, 1) + else: + res_init = cache_dict[key][nf] - init_dict[key][nf] + JTJ[start:end, start:end] += weight['init_'+key] * torch.eye(end-start) + JTr[start:end] += -weight['init_'+key] * res_init.reshape(-1, 1) + # add keypoints loss + for nf in range(bn): + JTJ[nf*NUM_FRAME:(nf+1)*NUM_FRAME, nf*NUM_FRAME:(nf+1)*NUM_FRAME] += w_joints * JTJ_keypoints[nf] + JTr += - w_joints * JTr_keypoints.reshape(-1, 1) + # tic('add loss') + delta = torch.linalg.solve(JTJ, JTr) + # tic('solve') + norm_d, norm_f = self.log('pose', iter_, delta, res, keys_range) + if torch.norm(delta) < option.gtol: + break + if iter_ > 0 and 
abs(norm_f - lossnorm)/norm_f < option.ftol: + break + delta = delta.view(bn, NUM_FRAME) + lossnorm = norm_f + for key, _range in keys_range.items(): + if key not in cache_dict.keys():continue + cache_dict[key] += delta[:, _range[0]:_range[1]] + res = { + 'id': params['id'], + 'poses': poses, + 'shapes': params['shapes'], + 'Rh': Rh, + 'Th': Th, + } + for key, val in params.items(): + if key not in res.keys(): + res[key] = val + return res + + def try_to_init(self, records): + if self.init: + return copy.deepcopy(self.params_newest) + mywarn('>> Initialize') + keypoints3d = self.get_keypoints3d(records) + params, keypoints_template = self.fitShape(keypoints3d, self.get_init_params(keypoints3d.shape[0]), self.cfg.shape.weight, self.cfg.shape.option) + Rot = batch_rodrigues(params['Rh']) + keypoints_template = torch.matmul(keypoints_template, Rot[0].t()) + if self.cfg.initRT.mean_T: + conf = keypoints3d[..., 3:] + T = ((keypoints3d[..., :3] - keypoints_template[None])*conf).sum(dim=-2)/(conf.sum(dim=-2)) + params['Th'] = T + params = self.fitRT(keypoints3d, params, self.cfg.initRT.weight, self.cfg.initRT.option, + kpts_index=self.cfg.TORSO_INDEX) + params = self.fitPose(keypoints3d, params, self.cfg.init_pose.weight, self.cfg.init_pose.option, + kpts_index=self.cfg.TORSO_INDEX) + params = self.fitPose(keypoints3d, params, self.cfg.init_pose.weight, self.cfg.init_pose.option, + kpts_index=self.cfg.BODY_INDEX) + mywarn('>> Initialize Rh = {}, Th = {}'.format(params['Rh'][0], params['Th'][0])) + params = Params(params)[-1] + self.init = True + return params + + def fitting(self, params_init, records): + keypoints3d = self.get_keypoints3d(records[-self.WINDOW_SIZE:]) + params = params_init + params = self.fitRT(keypoints3d[-self.FITTING_SIZE:], params, self.cfg.RT.weight, self.cfg.RT.option) + params = self.fitPose(keypoints3d[-self.FITTING_SIZE:], params, self.cfg.pose.weight, self.cfg.pose.option) + return params + + def filter_result(self, result): + poses = 
class FitterCPPCache:
    """Frame-by-frame fitter that caches the newest parameters.

    Subclasses are expected to assign ``self.handmodel``; this class calls
    its ``init_params``, ``fit3DShape``, ``init3DRT`` and ``fit3DPose``
    methods.
    """

    def __init__(self) -> None:
        self.init = False
        self.records = []
        self.frame_index = 0
        self.frame_latest = 0
        # Declared here so the attributes always exist; subclasses assign the
        # model and ``__call__`` stores the newest fitted parameters.
        self.handmodel = None
        self.params_newest = None

    def try_to_init(self, keypoints3d):
        """Return starting parameters, running the full initialization once.

        Returns a deep copy of ``self.params_newest`` when already
        initialized.  If any keypoint has a non-positive confidence the
        model's initial parameters are returned and ``self.init`` stays False.
        """
        if self.init:
            return copy.deepcopy(self.params_newest)
        mywarn('>> Initialize')
        params = self.handmodel.init_params(1)
        # not every keypoint is visible (e.g. only half of the points)
        if not (keypoints3d[..., -1] > 0.).all():
            return params
        params = self.handmodel.fit3DShape(keypoints3d, params)
        params = self.handmodel.init3DRT(keypoints3d[-1:], params)
        params = self.handmodel.fit3DPose(keypoints3d[-1:], params)
        mywarn('>> Initialize Rh = {}, Th = {}'.format(params['Rh'][0], params['Th'][0]))
        self.init = True
        return params

    def check_keypoints(self, keypoints3d):
        """Return True when the keypoints are usable for fitting.

        More than 5 keypoints must have confidence above 0.3.  When more than
        one record is stored, the confidence-weighted mean distance to the
        latest record must also be below 0.1.
        """
        flag = (keypoints3d[..., -1] > 0.3).sum() > 5
        if len(self.records) > 1:
            k_pre = self.records[-1][None]
            dist = np.linalg.norm(keypoints3d[..., :3] - k_pre[..., :3], axis=-1)
            conf = np.sqrt(keypoints3d[..., 3] * k_pre[..., 3])
            dist_mean = (dist * conf).sum() / (1e-5 + conf.sum())
            flag = flag and dist_mean < 0.1
        return bool(flag)

    def smooth_results(self, params=None):
        """Wrap parameters into the one-element result list.

        Uses a shallow copy of ``self.params_newest`` when ``params`` is not
        given.  The first three pose columns are dropped and ``id`` is set to
        0.  A ``params`` dict passed in explicitly is modified in place.
        """
        if params is None:
            params = self.params_newest.copy()
        params['poses'] = params['poses'][:, 3:]
        params['id'] = 0
        return [params]

    def __call__(self, keypoints3d):
        """Fit one frame of keypoints and return the result list."""
        self.frame_index += 1
        if not self.check_keypoints(keypoints3d):
            mywarn('Missing keypoints {} [{}->{}]'.format(keypoints3d[..., -1].sum(), self.frame_latest, self.frame_index))
            if self.frame_index - self.frame_latest > 10:
                mywarn('Missing keypoints, resetting...')
                self.init = False
                self.records = []
            return self.smooth_results(self.handmodel.init_params(1))
        self.records.append(keypoints3d)
        with Timer('fitting'):
            params = self.try_to_init(keypoints3d[None])
            params = self.handmodel.fit3DPose(keypoints3d[None], params)
        self.params_newest = params
        self.frame_latest = self.frame_index
        return self.smooth_results()
class ManoFitterCPP:
    """Streaming hand fitter: triangulate each frame, then fit the hand model.

    Args:
        cfg_triangulator: config passed to ``load_object_from_cmd`` to build
            the triangulator.
        key: name of the record entry holding the 3D keypoints.  Falls back
            to ``'handl3d'`` when empty.
    """

    def __init__(self, cfg_triangulator, key) -> None:
        self.handmodel = load_object_from_cmd('config/model/mano_full_cpp.yml', [])
        self.triangulator = load_object_from_cmd(cfg_triangulator, [])
        self.time = 0
        self.frame_latest = 0
        self.frame_index = 0
        # Honour the requested key; it used to be ignored in favour of a
        # hard-coded 'handl3d', which remains the fallback.
        self.key = key if key else 'handl3d'
        self.init = False
        self.params_newest = None
        self.records, self.results = [], []
        self.INIT_SIZE = 10

    def get_keypoints3d(self, records, key=None):
        """Stack ``record[key]`` over all records (``self.key`` by default)."""
        name = self.key if key is None else key
        return np.stack([r[name] for r in records])

    def try_to_init(self, records):
        """Return starting parameters, running the full initialization once.

        Returns a deep copy of ``self.params_newest`` when already
        initialized; otherwise fits the shape on all records and the global
        transform and pose on the last record.
        """
        if self.init:
            return copy.deepcopy(self.params_newest)
        mywarn('>> Initialize')
        keypoints3d = self.get_keypoints3d(records)
        params = self.handmodel.init_params(1)
        params = self.handmodel.fit3DShape(keypoints3d, params)
        params = self.handmodel.init3DRT(keypoints3d[-1:], params)
        params = self.handmodel.fit3DPose(keypoints3d[-1:], params)
        mywarn('>> Initialize Rh = {}, Th = {}'.format(params['Rh'][0], params['Th'][0]))
        self.init = True
        return params

    def smooth_results(self):
        """Return ``[params]`` from a shallow copy of the newest parameters.

        The first three pose columns are dropped and ``id`` is set to 0.
        """
        params = self.params_newest.copy()
        params['poses'] = params['poses'][:, 3:]
        params['id'] = 0
        return [params]

    def __call__(self, data):
        """Process one frame.

        Returns -1 until ``INIT_SIZE`` records have been collected, and the
        one-element result list afterwards.
        """
        self.frame_index += 1
        k3d = self.triangulator(data)[0]
        keypoints3d = self.get_keypoints3d([k3d])
        # NOTE: the keypoint validity check is disabled for this fitter (the
        # original code forced the flag to True), so every frame is recorded.
        self.records.append(k3d)
        if len(self.records) < self.INIT_SIZE:
            return -1
        with Timer('fitting'):
            params = self.try_to_init(self.records)
            # ``keypoints3d`` is the stack of the record appended just above,
            # i.e. the same data as ``self.records[-1:]``.
            params = self.handmodel.fit3DPose(keypoints3d, params)
        self.results.append(params)
        self.params_newest = params
        self.frame_latest = self.frame_index
        return self.smooth_results()
+ keypoints3d = np.hstack([keypoints3d, handl, handr]) + conf = keypoints3d[..., 3:] + keypoints3d = np.hstack([(keypoints3d[..., :3] * conf).sum(axis=0)/(1e-5 + conf.sum(axis=0)), conf.min(axis=0)]) + keypoints3d = keypoints3d[None] + # if (keypoints3d.shape[0] == 10): + return torch.Tensor(keypoints3d) + + def filter_result(self, result): + result = super().filter_result(result) + # 限定一下关节旋转 + # 手肘 + # result['poses'][:, 5*3+1] = np.clip(result['poses'][:, 5*3+1], -2.5, 0.1) + # result['poses'][:, 6*3+1] = np.clip(result['poses'][:, 5*3+1], -0.1, 2.5) + # 手腕 + return result + +class HalfHandFitter(HalfFitter): + def __init__(self, cfg_handl, cfg_handr, **kwargs) -> None: + super().__init__(**kwargs) + self.handl = load_object_from_cmd(cfg_handl, []) + self.handr = load_object_from_cmd(cfg_handr, []) + + def get_init_params(self, nFrames): + params = super().get_init_params(nFrames) + params_ = self.handl.get_init_params(nFrames) + params['shapes_handl'] = params_['shapes'] + params['shapes_handr'] = params_['shapes'].clone() + params['Rh_handl'] = torch.zeros((nFrames, 3)) + params['Rh_handr'] = torch.zeros((nFrames, 3)) + params['Th_handl'] = torch.zeros((nFrames, 3)) + params['Th_handr'] = torch.zeros((nFrames, 3)) + return params + + def fitPose(self, keypoints3d, params, weight, option): + keypoints = { + 'handl': keypoints3d[:, -21-21:-21, :], + 'handr': keypoints3d[:, -21:, :] + } + for key in ['handl', 'handr']: + kpts = keypoints[key] + params_ = { + 'id': 0, + 'Rh': params['Rh_'+key], + 'Th': params['Th_'+key], + 'shapes': params['shapes_'+key], + 'poses': params[key], + } + if key == 'handl': + params_ = self.handl.fitPose(kpts, params_, self.handl.cfg.pose.weight, self.handl.cfg.pose.option) + else: + params_ = self.handr.fitPose(kpts, params_, self.handr.cfg.pose.weight, self.handr.cfg.pose.option) + params['Rh_'+key] = params_['Rh'] + params['Th_'+key] = params_['Th'] + params['shapes_'+key] = params_['shapes'] + params[key] = params_['poses'] + 
return super().fitPose(keypoints3d, params, weight, option, + kpts_index=[0,1,2,3,4,5,6,7,8,9,10,11, + 12, 17, 21, 25, 29, + 24, 29, 33, 37, 41]) + + def try_to_init(self, records): + if self.init: + return self.params_newest + params = super().try_to_init(records) + self.handl.init = False + self.handr.init = False + key = 'handl' + params_ = self.handl.try_to_init(records) + params['handl'] = params_['poses'] + params['Rh_'+key] = params_['Rh'] + params['Th_'+key] = params_['Th'] + params['shapes_'+key] = params_['shapes'] + key = 'handr' + params_ = self.handr.try_to_init(records) + params[key] = params_['poses'] + params['Rh_'+key] = params_['Rh'] + params['Th_'+key] = params_['Th'] + params['shapes_'+key] = params_['shapes'] + return params + +if __name__ == '__main__': + from glob import glob + from os.path import join + from tqdm import tqdm + from ..mytools.file_utils import read_json + from ..config.baseconfig import load_object_from_cmd + + data= '/nas/datasets/EasyMocap/302' + # data = '/home/qing/Dataset/handl' + mode = 'half' + # data = '/home/qing/DGPU/home/shuaiqing/zju-mocap-mp/female-jump' + data = '/home/qing/Dataset/desktop/0402/test3' + k3dnames = sorted(glob(join(data, 'output-keypoints3d', 'keypoints3d', '*.json'))) + if mode == 'handl': + fitter = load_object_from_cmd('config/recon/fit_manol.yml', []) + elif mode == 'half': + fitter = load_object_from_cmd('config/recon/fit_half.yml', []) + elif mode == 'smpl': + fitter = load_object_from_cmd('config/recon/fit_smpl.yml', []) + from easymocap.socket.base_client import BaseSocketClient + client = BaseSocketClient('0.0.0.0', 9999) + + for k3dname in tqdm(k3dnames): + k3ds = read_json(k3dname) + if mode == 'handl': + k3ds = np.array(k3ds[0]['handl3d']) + data = {fitter.key3d: k3ds} + elif mode == 'half': + data = { + 'keypoints3d': np.array(k3ds[0]['keypoints3d']), + 'handl3d': np.array(k3ds[0]['handl3d']), + 'handr3d': np.array(k3ds[0]['handr3d']) + } + elif mode == 'smpl': + k3ds = 
class InitSpin:
    """Initialize the SMPL parameters of every frame with SPIN.

    Per-frame results are cached as JSON in a ``cache_spin`` directory three
    path levels above each image and are re-read on later runs.
    """

    def __init__(self, mean_params, ckpt_path, share_shape,
                 multi_person=False, compose_mp=False) -> None:
        from ..estimator.SPIN.spin_api import SPIN
        import torch
        self.share_shape = share_shape
        self.spin_model = SPIN(
            SMPL_MEAN_PARAMS=mean_params,
            checkpoint=ckpt_path,
            device=torch.device('cpu'))
        # undistortion maps, cached per view name
        self.distortMap = {}
        self.multi_person = multi_person
        self.compose_mp = compose_mp

    def undistort(self, image, K, dist, nv):
        """Undistort ``image``; return it untouched if distortion is negligible.

        The rectification maps are computed once per ``nv`` and cached in
        ``self.distortMap``.
        """
        if np.linalg.norm(dist) < 0.01:
            return image
        if nv not in self.distortMap:
            h, w = image.shape[:2]
            mapx, mapy = cv2.initUndistortRectifyMap(K, dist, None, K, (w, h), 5)
            self.distortMap[nv] = (mapx, mapy)
        mapx, mapy = self.distortMap[nv]
        return cv2.remap(image, mapx, mapy, cv2.INTER_LINEAR)

    def _previous_params(self, params_all, pid, imgname):
        """Return person ``pid``'s parameters from the latest processed frame.

        Raises:
            ValueError: if no frame has been processed yet, or (multi-person
                mode) the previous frame has no entry for ``pid``.
        """
        if not params_all:
            raise ValueError(
                '[InitSpin] not enough joints in the first frame: {}'.format(imgname))
        previous = params_all[-1]
        if not self.multi_person:
            # single-person mode stores the parameter dict itself per frame
            return previous
        if pid >= len(previous):
            raise ValueError(
                '[InitSpin] no previous result for person {}: {}'.format(pid, imgname))
        return previous[pid]

    def __call__(self, body_model, body_params, infos):
        """Run SPIN (or load its cached output) for every image in ``infos``.

        Returns the merged and encoded parameters in single-person mode, the
        composed parameters when ``compose_mp`` is set, and otherwise the
        list of per-frame parameter lists.
        """
        self.spin_model.model.to(body_model.device)
        self.spin_model.device = body_model.device
        params_all = []
        for nf, imgname in enumerate(tqdm(infos['imgname'], desc='Run SPIN')):
            # multi-view input is not considered for now
            # TODO: the multi-person case is not fully considered
            basename = os.sep.join(imgname.split(os.sep)[-2:]).split('.')[0] + '.json'
            sub = os.path.dirname(basename)
            cache_dir = os.path.abspath(join(os.sep.join(imgname.split(os.sep)[:-3]), 'cache_spin'))
            outname = join(cache_dir, basename)
            if os.path.exists(outname):
                params = read_smpl(outname)
                params_all.append(params if self.multi_person else params[0])
                continue
            camera = {key: infos[key][nf].numpy() for key in ['K', 'Rc', 'Tc', 'dist']}
            camera['R'] = camera['Rc']
            camera['T'] = camera['Tc']
            image = cv2.imread(imgname)
            image = self.undistort(image, camera['K'], camera['dist'], sub)
            if len(infos['keypoints2d'].shape) == 3:
                k2d = infos['keypoints2d'][nf][None]
            else:
                k2d = infos['keypoints2d'][nf]
            params_current = []
            for pid in range(k2d.shape[0]):
                keypoints = k2d[pid].numpy()
                bbox = bbox_from_keypoints(keypoints)
                nValid = (keypoints[:, -1] > 0).sum()
                if nValid > 4:
                    result = self.spin_model(body_model, image,
                        bbox, keypoints, camera, ret_vertices=False)
                    params = result['body_params']
                else:
                    if len(params_all) == 0:
                        print('[WARN] not enough joints: {} in first frame'.format(imgname))
                    else:
                        print('[WARN] not enough joints: {}'.format(imgname))
                    # Reuse the previous frame explicitly.  The original code
                    # left `result` unbound or stale here and indexed a dict
                    # with `pid` in single-person mode.
                    params = self._previous_params(params_all, pid, imgname)
                params['id'] = pid
                params_current.append(params)
            write_smpl(outname, params_current)
            params_all.append(params_current if self.multi_person else params_current[0])
        if not self.multi_person:
            params_all = Params.merge(params_all, share_shape=self.share_shape)
            params_all = body_model.encode(params_all)
        elif self.compose_mp:
            params_all = Params.merge([Params.merge(p_, share_shape=False) for p_ in params_all], share_shape=False, stack=np.stack)
            params_all['id'] = 0
        return params_all
def svd_rot(src, tgt, reflection=False, debug=False):
    """Batched least-squares rotation aligning ``src`` to ``tgt`` via SVD.

    For each batch item the returned matrix ``T`` is a proper rotation
    (determinant +1) such that ``tgt ~= src @ T^T``.

    Args:
        src: array of shape (N, K, 3), source points.
        tgt: array of shape (N, K, 3), target points.
        reflection: accepted for interface compatibility but currently
            unused; reflections are always removed.
        debug: when True, print the per-point residual norms.

    Returns:
        Array of shape (N, 3, 3) with one rotation matrix per batch item.
    """
    A = np.matmul(src.transpose(0, 2, 1), tgt)
    U, _, Vt = np.linalg.svd(A, full_matrices=False)
    Ut = U.transpose(0, 2, 1)
    V = Vt.transpose(0, 2, 1)
    T = np.matmul(V, Ut)
    # does the current solution use a reflection?
    have_reflection = np.linalg.det(T) < 0
    # flip the last singular direction of those items so that T is a rotation
    V[have_reflection, :, -1] *= -1
    T = np.matmul(V, Ut)
    if debug:
        # Transpose only the two matrix axes: `T.T` on a batched array
        # reverses every axis and breaks the product.
        residual = tgt - np.matmul(src, T.transpose(0, 2, 1))
        err = np.linalg.norm(residual, axis=-1)
        print('[svd] ', err)
    return T
root_temp + conf = (keypoints3d[..., self.torso, 3] > 0.).all(axis=-1) + if not conf.all(): + myerror("The torso in frames {} is not valid, please check the 3d keypoints".format(np.where(~conf))) + if len(torso.shape) == 3: + R = svd_rot(torso_temp, torso) + R_flat = R + T = np.matmul(- root_temp, R.transpose(0, 2, 1)) + root + else: + R_flat = svd_rot(torso_temp.reshape(-1, *torso_temp.shape[-2:]), torso.reshape(-1, *torso.shape[-2:])) + R = R_flat.reshape(*torso.shape[:2], 3, 3) + T = np.matmul(- root_temp, R.swapaxes(-1, -2)) + root + for nf in np.where(~conf)[0]: + # copy previous frames + mywarn('copy {} from {}'.format(nf, nf-1)) + R[nf] = R[nf-1] + T[nf] = T[nf-1] + body_params['Th'] = T[..., 0, :] + rvec = batch_invRodrigues(R_flat) + if len(torso.shape) > 3: + rvec = rvec.reshape(*torso.shape[:2], 3) + body_params['Rh'] = rvec + return body_params + + def __str__(self) -> str: + return "[Initialize] svd with torso: {}".format(self.torso) + +class TriangulatorWrapper: + def __init__(self, module, args): + self.triangulator = load_object(module, args) + + def __call__(self, body_model, body_params, infos): + infos['RT'] = torch.cat([infos['Rc'], infos['Tc']], dim=-1) + data = { + 'RT': infos['RT'].numpy(), + } + for key in self.triangulator.keys: + if key not in infos.keys(): + continue + data[key] = infos[key].numpy() + data[key+'_unproj'] = infos[key+'_unproj'].numpy() + data[key+'_distort'] = infos[key+'_distort'].numpy() + results = self.triangulator(data)[0] + for key, val in results.items(): + if key == 'id': continue + infos[key] = torch.Tensor(val[None].astype(np.float32)) + body_params = body_model.init_params(nFrames=1, add_scale=True) + return body_params + +class CheckRT: + def __init__(self, T_thres, window): + self.T_thres = T_thres + self.window = window + + def __call__(self, body_model, body_params, infos): + Th = body_params['Th'] + if len(Th.shape) == 3: + for nper in range(Th.shape[1]): + for nf in trange(1, Th.shape[0], desc='Check Th of 
class GMoF(nn.Module):
    """Robust loss: squared error ``d`` mapped to ``d / (d + rho^2)``.

    Args:
        rho: scale of the robustifier; the per-element value saturates
            towards 1 once the squared error is large compared to ``rho^2``.
    """

    def __init__(self, rho=1):
        super(GMoF, self).__init__()
        # keep rho itself: extra_repr() reports it, and reading an attribute
        # that was never set raises AttributeError on an nn.Module
        self.rho = rho
        self.rho2 = rho * rho

    def extra_repr(self):
        return 'rho = {}'.format(self.rho)

    def forward(self, est, gt=None, conf=None):
        """Return the scalar robust loss.

        Args:
            est: estimated values; the last dimension is summed over.
            gt: optional target with a shape broadcastable to ``est``; when
                omitted ``est`` itself is treated as the residual.
            conf: optional weights matching ``est`` without its last
                dimension.  When given, the result is the weighted sum
                divided by ``1e-5 + conf.sum()``; otherwise it is the mean.
        """
        residual = est if gt is None else est - gt
        square_diff = torch.sum(residual**2, dim=-1)
        diff = torch.div(square_diff, square_diff + self.rho2)
        if conf is not None:
            return torch.sum(diff * conf)/(1e-5 + conf.sum())
        return diff.sum()/diff.numel()
def select(value, ranges, index, dim):
    """Select a sub-range or a set of indices from ``value``.

    Args:
        value: array or tensor to slice.
        ranges: empty, or ``[start, end]`` applied to the LAST dimension;
            ``end == -1`` means "up to the end".  Takes precedence over
            ``index``, and ``dim`` is not used for it.
        index: empty, or a list of indices taken along ``dim``.
        dim: dimension used for ``index``; negative values count from the
            end.  Any valid dimension is supported (previously every value
            other than -1 and -2 was silently ignored).

    Returns:
        The selected view/copy, or ``value`` itself when both ``ranges`` and
        ``index`` are empty.
    """
    if len(ranges) > 0:
        start, end = ranges[0], ranges[1]
        if end == -1:
            return value[..., start:]
        return value[..., start:end]
    if len(index) > 0:
        axis = dim if dim >= 0 else value.ndim + dim
        indexer = [slice(None)] * value.ndim
        indexer[axis] = index
        return value[tuple(indexer)]
    return value
class VPoserPrior(AnyReg):
    """Pose prior that penalises the VPoser reconstruction residual.

    The body pose is encoded with VPoser, a latent sample is decoded again
    and the difference between input and reconstruction is handed to the
    parent regulariser.
    """

    def __init__(self, **cfg):
        """Load the VPoser model.

        Args:
            **cfg: forwarded to ``AnyReg``. An optional ``vposer_ckpt`` entry
                overrides the checkpoint directory (default
                ``'data/bodymodels/vposer_v02'``).
        """
        vposer_ckpt = cfg.pop('vposer_ckpt', 'data/bodymodels/vposer_v02')
        super().__init__(**cfg)
        # Imported lazily so the dependency is only needed when this prior
        # is actually used.
        from human_body_prior.tools.model_loader import load_model
        from human_body_prior.models.vposer_model import VPoser
        vposer, _ = load_model(vposer_ckpt,
                               model_code=VPoser,
                               remove_words_in_model_weights='vp_model.',
                               disable_grad=True)
        vposer.eval()
        self.vposer = vposer
        self.init = None  # disable init: the residual must not be re-centred

    def forward(self, poses, **kwargs):
        """Return the parent loss of ``poses_body - decode(encode(poses_body))``.

        Args:
            poses: (..., nDims); only the first 63 entries of the last axis
                are fed to VPoser.
        """
        nDims = 63
        poses_body = poses[..., :nDims].reshape(-1, nDims)
        latent = self.vposer.encode(poses_body)
        # NOTE(review): latent.sample() looks stochastic; latent.mean was the
        # alternative in a branch that could never run -- confirm the intent.
        recon = self.vposer.decode(latent.sample())['pose_body']
        # Flatten exactly like poses_body; reshaping with poses.shape[0]
        # only worked for 2-D input.
        recon = recon.reshape(-1, nDims)
        # The parent reads kwargs[self.key], so pass the residual under that
        # key instead of a hard-coded 'poses'.
        return super().forward(**{self.key: poses_body - recon})
class BaseKeypoints(LossBase):
    """Confidence-weighted keypoint loss with before/after error reporting.

    The ground truth carries the confidence in its last channel; it is split
    into the ``keypoints`` and ``conf`` buffers by ``set_gt``.
    """

    def __init__(self, keypoints, norm='l2', norm_info=None,
                 index_gt=None, ranges_gt=None,
                 index_est=None, ranges_est=None) -> None:
        """
        Args:
            keypoints: ground truth, (..., nJoints, nDims + 1); the last
                channel is the confidence.
            norm, norm_info: forwarded to ``make_loss``.
            index_gt / ranges_gt: joint selection applied to the ground truth.
            index_est / ranges_est: joint selection applied to the estimate.
                A range is ``[start, stop]`` with ``stop == -1`` meaning
                "up to the last joint".
        """
        super().__init__()
        # None stands in for the former shared mutable [] / {} defaults.
        norm_info = {} if norm_info is None else norm_info
        index_gt = [] if index_gt is None else index_gt
        ranges_gt = [] if ranges_gt is None else ranges_gt
        index_est = [] if index_est is None else index_est
        ranges_est = [] if ranges_est is None else ranges_est
        # prepare ground-truth
        self.keypoints_np = keypoints
        self.set_gt(index_gt, ranges_gt)
        self.index_est = index_est
        self.ranges_est = ranges_est

        self.norm = norm
        self.loss = make_loss(norm, norm_info)

    @staticmethod
    def select(keypoints, index, ranges):
        """Select joints on the second-to-last axis; ``index`` wins over ``ranges``."""
        if len(index) > 0:
            keypoints = keypoints[..., index, :]
        elif len(ranges) > 0:
            if ranges[1] == -1:
                keypoints = keypoints[..., ranges[0]:, :]
            else:
                keypoints = keypoints[..., ranges[0]:ranges[1], :]
        return keypoints

    def set_gt(self, index_gt, ranges_gt):
        """Register the selected ground truth as ``keypoints`` / ``conf`` buffers."""
        keypoints = self.select(self.keypoints_np, index_gt, ranges_gt)
        keypoints = torch.Tensor(keypoints)
        self.register_buffer('keypoints', keypoints[..., :-1])
        self.register_buffer('conf', keypoints[..., -1])

    def __str__(self):
        return "keypoints: {}".format(self.keypoints.shape)

    def forward(self, kpts_est, **kwargs):
        """Loss between the selected estimate and the ground truth."""
        est = self.select(kpts_est, self.index_est, self.ranges_est)
        return self.loss(est, self.keypoints, self.conf)

    def check(self, kpts_est, min_conf=0.3, **kwargs):
        """Per-joint mean distance (scaled by 1000) over confident detections.

        Returns:
            (conf, mean_joints): the 0/1 validity mask and the per-joint error.
        """
        est = self.select(kpts_est, self.index_est, self.ranges_est)
        conf = (self.conf > min_conf).float()
        norm = torch.norm(est - self.keypoints, dim=-1) * conf
        mean_joints = norm.sum(dim=0) / (1e-5 + conf.sum(dim=0)) * 1000
        return conf, mean_joints

    def _joint_names(self, kpts_est):
        """Row labels of the report, matching what ``select`` keeps."""
        if len(self.index_est) > 0:
            return [str(i) for i in self.index_est]
        if len(self.ranges_est) > 0:
            start, stop = self.ranges_est[0], self.ranges_est[1]
            if stop == -1:
                # -1 means "to the end" in select(); range(start, -1) would
                # be empty and leave the table without rows.
                stop = kpts_est.shape[-2]
            return [str(i) for i in range(start, stop)]
        return [str(i) for i in range(self.conf.shape[-1])]

    def check_at_start(self, kpts_est, **kwargs):
        """Cache the 'before' column of the error table."""
        names = self._joint_names(kpts_est)
        conf, error = self.check(kpts_est, **kwargs)
        valid = conf.sum(dim=0).detach().cpu().numpy().tolist()
        header = ['name', 'count', 'before']
        contents = [names, valid, error.detach().cpu().numpy().tolist()]
        self.cache_check = (header, contents)

    def check_at_end(self, kpts_est, **kwargs):
        """Append the 'after' column to the cached table and print it."""
        conf, err_after = self.check(kpts_est, **kwargs)
        header, contents = self.cache_check
        header.append('after')
        contents.append(err_after.detach().cpu().numpy().tolist())
        print_table(header, contents)
class Keypoints3D(BaseKeypoints):
    """3D keypoint loss; only renames the ground-truth argument."""

    def __init__(self, keypoints3d, **kwargs) -> None:
        super().__init__(keypoints3d, **kwargs)


class AnyKeypoints3D(Keypoints3D):
    """3D keypoint loss whose ground truth is looked up by name.

    ``key`` names the entry of ``kwargs`` that holds the ground truth.
    """

    def __init__(self, **kwargs) -> None:
        name = kwargs.pop('key')
        target = kwargs.pop(name)
        super().__init__(target, **kwargs)
        self.key = name


class AnyKeypoints3DWithRT(Keypoints3D):
    """Like ``AnyKeypoints3D`` but the estimate is first moved by a rigid
    transform read from ``R_<key>`` / ``T_<key>``."""

    def __init__(self, **kwargs) -> None:
        name = kwargs.pop('key')
        target = kwargs.pop(name)
        super().__init__(target, **kwargs)
        self.key = name

    def _apply_rt(self, kpts_est, params):
        # X' = X R^T + T, with T broadcast over the joint axis
        rot = batch_rodrigues(params['R_' + self.key])
        trans = params['T_' + self.key]
        return torch.matmul(kpts_est, rot.transpose(-1, -2)) + trans[..., None, :]

    def forward(self, kpts_est, **kwargs):
        return super().forward(self._apply_rt(kpts_est, kwargs))

    def check(self, kpts_est, min_conf=0.3, **kwargs):
        return super().check(self._apply_rt(kpts_est, kwargs), min_conf)


class Handl3D(BaseKeypoints):
    """Keypoint loss on joints 25..45 of the estimate, expressed relative to
    the first joint of that slice."""

    def __init__(self, handl3d, **kwargs) -> None:
        centred = handl3d.clone()
        # subtract the first joint from the coordinates only; the confidence
        # channel is left untouched
        centred[..., :3] = centred[..., :3] - centred[:, :1, :3]
        super().__init__(centred, **kwargs)

    @staticmethod
    def _root_relative(kpts_est):
        hand = kpts_est[:, 25:46]
        # the root is detached so it receives no gradient from this term
        return hand - hand[:, :1].detach()

    def forward(self, kpts_est, **kwargs):
        return super().forward(self._root_relative(kpts_est), **kwargs)

    def check(self, kpts_est, **kwargs):
        return super().check(self._root_relative(kpts_est), **kwargs)
self.keypoints_np + kintree = self.kintree + # limb_length: nFrames, nLimbs, 1 + limb_length = np.linalg.norm(keypoints3d[..., kintree[:, 1], :3] - keypoints3d[..., kintree[:, 0], :3], axis=-1, keepdims=True) + # conf: nFrames, nLimbs, 1 + limb_conf = np.minimum(keypoints3d[..., kintree[:, 1], -1], keypoints3d[..., kintree[:, 0], -1]) + limb_length = torch.Tensor(limb_length) + limb_conf = torch.Tensor(limb_conf) + self.register_buffer('length', limb_length) + self.register_buffer('conf', limb_conf) + + def forward(self, kpts_est, **kwargs): + src = kpts_est[..., self.kintree[:, 0], :] + dst = kpts_est[..., self.kintree[:, 1], :] + length_est = torch.norm(dst - src, dim=-1, keepdim=True) + return self.loss(length_est, self.length, self.conf) + + def check_at_start(self, kpts_est, **kwargs): + names = [str(i) for i in self.kintree] + conf = (self.conf>0) + valid = conf.sum(dim=0).detach().cpu().numpy() + if len(valid.shape) == 2: + valid = valid.mean(axis=0) + header = ['name', 'count'] + contents = [names, valid.tolist()] + error, length = self.check(kpts_est) + header.append('before') + contents.append(error.detach().cpu().numpy().tolist()) + header.append('length') + length = (self.length[..., 0] * self.conf).sum(dim=0)/self.conf.sum(dim=0) + contents.append(length.detach().cpu().numpy().tolist()) + self.cache_check = (header, contents) + + def check_at_end(self, kpts_est, **kwargs): + err_after, length = self.check(kpts_est) + header, contents = self.cache_check + header.append('after') + contents.append(err_after.detach().cpu().numpy().tolist()) + header.append('length_est') + contents.append(length[:,:,0].mean(dim=0).detach().cpu().numpy().tolist()) + print_table(header, contents) + + def check(self, kpts_est, **kwargs): + src = kpts_est[..., self.kintree[:, 0], :] + dst = kpts_est[..., self.kintree[:, 1], :] + length_est = torch.norm(dst - src, dim=-1, keepdim=True) + conf = (self.conf>0).float() + norm = torch.abs(length_est-self.length)[..., 0] * conf + 
class LimbLengthHand(LimbLength):
    """Limb-length loss over both hands, stacked along the first axis."""

    def __init__(self, handl3d, handr3d, **kwargs):
        """
        Args:
            handl3d, handr3d: per-hand ground truth tensors, concatenated on
                dim 0.
            **kwargs: must contain ``kintree``; the rest goes to ``LimbLength``.
        """
        kintree = kwargs.pop('kintree')
        # LimbLength looks the ground truth up as kwargs[key]. Passing the
        # tensor positionally bound it to ``key`` and the lookup failed.
        kwargs.pop('key', None)
        kwargs['keypoints3d'] = torch.cat([handl3d, handr3d], dim=0)
        super().__init__(kintree, key='keypoints3d', **kwargs)

    @staticmethod
    def _stack_hands(kpts_est):
        # joints 0..20 are one hand, the rest the other; stack like the GT
        return torch.cat([kpts_est[:, :21], kpts_est[:, 21:]], dim=0)

    def forward(self, kpts_est, **kwargs):
        return super().forward(self._stack_hands(kpts_est), **kwargs)

    def check(self, kpts_est, **kwargs):
        return super().check(self._stack_hands(kpts_est), **kwargs)


class Keypoints2D(BaseKeypoints):
    """Reprojection loss of 3D keypoints against 2D detections."""

    def __init__(self, keypoints2d, K, Rc, Tc, einsum='fab,fnb->fna',
                 unproj=True, reshape_views=False, **kwargs) -> None:
        """
        Args:
            keypoints2d: detections, last channel is the confidence.
            K, Rc, Tc: camera intrinsics, rotation and translation.
            einsum: equation used to apply the projection to the keypoints.
            unproj: when True the detections are multiplied by inv(K) once
                here, so ``project`` only needs ``[Rc | Tc]``.
            reshape_views: when True the estimate is reshaped to
                (K.shape[0], K.shape[1], ...) before projecting.
        """
        # convert to camera coordinate
        invKtrans = torch.inverse(K).transpose(-1, -2)
        if unproj:
            homo = torch.ones_like(keypoints2d[..., :1])
            homo = torch.cat([keypoints2d[..., :2], homo], dim=-1)
            if len(invKtrans.shape) < len(homo.shape):
                invKtrans = invKtrans.unsqueeze(-3)
            homo = torch.matmul(homo, invKtrans)
            keypoints2d = torch.cat([homo[..., :2], keypoints2d[..., 2:]], dim=-1)
        # keypoints2d: (nFrames, nViews, ..., nJoints, 3)
        super().__init__(keypoints2d, **kwargs)
        self.register_buffer('K', K)
        self.register_buffer('invKtrans', invKtrans)
        self.register_buffer('Rc', Rc)
        self.register_buffer('Tc', Tc)
        self.unproj = unproj
        self.einsum = einsum
        self.reshape_views = reshape_views

    def project(self, kpts_est):
        """Project the selected 3D keypoints; returns the 2D points."""
        kpts_est = self.select(kpts_est, self.index_est, self.ranges_est)
        kpts_homo = torch.ones_like(kpts_est[..., -1:])
        kpts_homo = torch.cat([kpts_est, kpts_homo], dim=-1)
        P = torch.cat([self.Rc, self.Tc], dim=-1)
        if not self.unproj:
            # matmul equals bmm for 3-D input and also accepts extra
            # leading batch axes (e.g. frames x views)
            P = torch.matmul(self.K, P)
        if self.reshape_views:
            kpts_homo = kpts_homo.reshape(self.K.shape[0], self.K.shape[1], *kpts_homo.shape[1:])
        try:
            point_cam = torch.einsum(self.einsum, P, kpts_homo)
        except (RuntimeError, ValueError, IndexError) as err:
            # Keep the historical NotImplementedError, but chain the cause
            # and no longer swallow KeyboardInterrupt/SystemExit.
            print('Wrong shape: {}x{} <=== {}'.format(P.shape, kpts_homo.shape, self.einsum))
            raise NotImplementedError from err
        img_points = point_cam[..., :2] / point_cam[..., 2:]
        return img_points

    def forward(self, kpts_est, **kwargs):
        img_points = self.project(kpts_est)
        loss = self.loss(img_points.squeeze(), self.keypoints.squeeze(), self.conf.squeeze())
        return loss

    def check(self, kpts_est, min_conf=0.3):
        """Per-joint mean reprojection error, scaled by the mean of K[..., 0, 0].

        Returns:
            (conf, err): validity mask/counts and the per-joint error.
        """
        with torch.no_grad():
            img_points = self.project(kpts_est)
            conf = (self.conf > min_conf)
            err = self.K[..., 0:1, 0].mean() * torch.norm(img_points - self.keypoints, dim=-1) * conf
            if len(err.shape) == 3:
                # fold the view axis into the statistics
                err = err.sum(dim=1)
                conf = conf.sum(dim=1)
            err = err.sum(dim=0) / (1e-5 + conf.sum(dim=0))
        return conf, err

    def check_at_start(self, kpts_est, **kwargs):
        """Cache the 'before(pix)' column of the error table."""
        if len(self.index_est) > 0:
            names = [str(i) for i in self.index_est]
        elif len(self.ranges_est) > 0:
            start, stop = self.ranges_est[0], self.ranges_est[1]
            if stop == -1:
                # -1 means "to the end" in select(); range(start, -1) is empty
                stop = kpts_est.shape[-2]
            names = [str(i) for i in range(start, stop)]
        else:
            names = [str(i) for i in range(self.conf.shape[-1])]
        conf, error = self.check(kpts_est)
        valid = conf.sum(dim=0).detach().cpu().numpy().tolist()
        header = ['name', 'count', 'before(pix)']
        contents = [names, valid, error.detach().cpu().numpy().tolist()]
        self.cache_check = (header, contents)

    def check_at_end(self, kpts_est, **kwargs):
        """Append the 'after(pix)' column and print the table."""
        conf, err_after = self.check(kpts_est)
        header, contents = self.cache_check
        header.append('after(pix)')
        contents.append(err_after.detach().cpu().numpy().tolist())
        print_table(header, contents)
def calc_vanishpoint(keypoints2d):
    '''Least-squares intersection of the lines joining paired 2D points.

    Args:
        keypoints2d: numpy array (2, N, 3); ``keypoints2d[0, i]`` and
            ``keypoints2d[1, i]`` are two points (x, y, confidence) on line i.

    Returns:
        numpy array (3,): the intersection as a homogeneous point (x, y, 1).

    Raises:
        numpy.linalg.LinAlgError: if the normal equations are singular
            (e.g. all lines parallel or all weights zero).
    '''
    # weight: (N, 1), mean confidence of the two endpoints
    weight = keypoints2d[:, :, 2:].mean(axis=0)
    x0, y0 = keypoints2d[0, :, 0:1], keypoints2d[0, :, 1:2]
    dx = keypoints2d[1, :, 0:1] - x0
    dy = keypoints2d[1, :, 1:2] - y0
    # A point (x, y) on line i satisfies  dy*x - dx*y = x0*dy - y0*dx
    A = np.hstack([dy, -dx]) * weight
    b = (x0 * dy - y0 * dx) * weight
    # Solve the normal equations directly instead of forming an inverse.
    avgInsec = np.linalg.solve(A.T @ A, A.T @ b)
    result = np.zeros(3)
    result[0] = avgInsec[0, 0]
    result[1] = avgInsec[1, 0]
    result[2] = 1
    return result


def calc_mirror_transform(m_):
    """ From mirror vector to mirror matrix

    The first three components are normalised to a unit normal ``n``; the
    fourth component ``d`` is used as given (it is NOT divided by the norm).
    The result is ``[I - 2 n n^T | -2 d n]``.

    Args:
        m_ (bn, 4): (a, b, c, d)
    Returns:
        M: (bn, 3, 4), allocated with the default dtype on ``m_``'s device
    """
    norm = torch.norm(m_[:, :3], dim=1, keepdim=True)
    m = m_[:, :3] / norm
    d = m_[:, 3]
    coeff_mat = torch.zeros((m.shape[0], 3, 4), device=m.device)
    for row in range(3):
        for col in range(3):
            if row == col:
                coeff_mat[:, row, col] = 1 - 2*m[:, row]**2
            else:
                coeff_mat[:, row, col] = -2*m[:, row]*m[:, col]
        coeff_mat[:, row, 3] = -2*m[:, row]*d
    return coeff_mat
{}'.format(infos['normal'])) + return body_params + kpts = infos['keypoints2d'] + kpts0 = kpts[:, 0] + kpts1 = flipPoint2D(kpts[:, 1]) + vanish_line = torch.stack([kpts0.reshape(-1, 3), kpts1.reshape(-1, 3)], dim=1) + MIN_THRES = 0.5 + conf = (vanish_line[:, 0, -1] > MIN_THRES) & (vanish_line[:, 1, -1] > MIN_THRES) + vanish_line = vanish_line[conf] + vline0 = vanish_line.numpy().transpose(1, 0, 2) + vpoint0 = calc_vanishpoint(vline0).reshape(1, 3) + # 计算点到线的距离进行检查 + # two points line: (x1, y1), (x2, y2) ==> (y-y1)/(x-x1) = (y2-y1)/(x2-x1) + # A = y2 - y1 + # B = x1 - x2 + # C = x2y1 - x1y2 + # d = abs(ax + by + c)/sqrt(a^2+b^2) + A_v0 = kpts0[:, :, 1] - vpoint0[0, 1] + B_v0 = vpoint0[0, 0] - kpts0[:, :, 0] + C_v0 = kpts0[:, :, 0]*vpoint0[0, 1] - vpoint0[0, 0]*kpts0[:, :, 1] + distance01 = np.abs(A_v0 * kpts1[:, :, 0] + B_v0 * kpts1[:, :, 1] + C_v0)/np.sqrt(A_v0*A_v0 + B_v0*B_v0) + A_v1 = kpts1[:, :, 1] - vpoint0[0, 1] + B_v1 = vpoint0[0, 0] - kpts1[:, :, 0] + C_v1 = kpts1[:, :, 0]*vpoint0[0, 1] - vpoint0[0, 0]*kpts1[:, :, 1] + distance10 = np.abs(A_v1 * kpts0[:, :, 0] + B_v1 * kpts0[:, :, 1] + C_v1)/np.sqrt(A_v1*A_v1 + B_v1*B_v1) + DIST_THRES = 0.05 + for nf in range(kpts.shape[0]): + # 计算scale + bbox0 = bbox_from_keypoints(kpts0[nf].cpu().numpy()) + bbox1 = bbox_from_keypoints(kpts1[nf].cpu().numpy()) + bbox_size0 = max(bbox0[2]-bbox0[0], bbox0[3]-bbox0[1]) + bbox_size1 = max(bbox1[2]-bbox1[0], bbox1[3]-bbox1[1]) + valid = (kpts0[nf, :, 2] > 0.3) & (kpts1[nf, :, 2] > 0.3) + dist01_ = valid*distance01[nf] / bbox_size1 + dist10_ = valid*distance10[nf] / bbox_size0 + # 对于距离异常的点,阈值设定为0.1 + # 抑制掉置信度低的视角的点 + not_valid0 = np.where((dist01_ + dist10_ > DIST_THRES*2) & (kpts0[nf][:, -1] < kpts1[nf][:, -1]))[0] + not_valid1 = np.where((dist01_ + dist10_ > DIST_THRES*2) & (kpts0[nf][:, -1] > kpts1[nf][:, -1]))[0] + kpts0[nf, not_valid0] = 0. + kpts1[nf, not_valid1] = 0. 
class Keypoints2DMirror(Keypoints2D):
    """Reprojection loss for a person and the mirrored copy of that person.

    The estimate is reflected by the mirror transform, stacked with the
    original along dim 1 and handed to ``Keypoints2D``.
    """

    def __init__(self, mirror, opt_normal, **kwargs):
        """
        Args:
            mirror: (bn, 4) mirror vector; kept as a buffer unless it
                requires grad.
            opt_normal: when False the first three mirror components are
                detached in ``forward``.
            **kwargs: forwarded to ``Keypoints2D``; ``keypoints2d`` is also
                read here to balance the two persons.
        """
        super().__init__(**kwargs)
        if mirror.requires_grad:
            self.mirror = mirror
        else:
            self.register_buffer('mirror', mirror)
        self.opt_normal = opt_normal
        # bounding-box size of each of the two persons in every frame
        sizes = []
        for frame_kpts in kwargs['keypoints2d']:
            for person in range(2):
                box = bbox_from_keypoints(frame_kpts[person].cpu().numpy())
                sizes.append(max(box[2] - box[0], box[3] - box[1]))
        size_all = np.array(sizes).reshape(-1, 2)
        scale = (size_all[:, 0] / size_all[:, 1]).mean()
        print('[loss] mean scale = {} from {} frames, use this to balance the two person'.format(scale, size_all.shape[0]))
        # ATTN: here we use v^2 to suppress the outlier detections
        self.conf = self.conf * self.conf
        self.conf[:, 1] *= scale*scale

    def _with_mirror(self, kpts_est, M):
        # reflect the homogeneous keypoints and stack (original, mirrored)
        ones = torch.ones((*kpts_est.shape[:-1], 1), device=kpts_est.device)
        homo = torch.cat([kpts_est, ones], dim=-1)
        reflected = torch.matmul(M, homo.transpose(1, 2)).transpose(1, 2)
        return torch.stack([kpts_est, flipPoint2D(reflected)], dim=1)

    def check(self, kpts_est, min_conf=0.3):
        with torch.no_grad():
            kpts = self._with_mirror(kpts_est, calc_mirror_transform(self.mirror))
            # the remaining statistics are exactly the parent's
            return super().check(kpts, min_conf)

    def forward(self, kpts_est, **kwargs):
        mirror = self.mirror
        if not self.opt_normal:
            # freeze the first three components, keep the fourth trainable
            mirror = torch.cat([mirror[:, :3].detach(), mirror[:, 3:]], dim=1)
        kpts = self._with_mirror(kpts_est, calc_mirror_transform(mirror))
        return super().forward(kpts_est=kpts, **kwargs)
class MirrorParams:
    """Stage hook that optimises one person and rebuilds the mirrored one at
    the end."""

    def __init__(self, key) -> None:
        self.key = key

    def start(self, body_params):
        """Keep entry 0 of axis 1 for every parameter except ``'id'``.

        Nothing is done when ``poses`` is already 2-D.
        """
        if len(body_params['poses'].shape) != 2:
            for name in body_params.keys():
                if name != 'id':
                    body_params[name] = body_params[name][:, 0]
        return body_params

    def before(self, body_params):
        return body_params

    def after(self,):
        pass

    def final(self, body_params):
        """Return parameters for (person, mirrored person) stacked on axis -2.

        NOTE(review): reads ``self.infos['mirror']``, which is not assigned
        anywhere in this class -- presumably injected by the caller; confirm.
        """
        device = body_params['poses'].device
        arrays = {}
        for name, tensor in body_params.items():
            arrays[name] = tensor.detach().cpu().numpy()
        # prepend three zeros to poses before flipping
        root = np.zeros_like(arrays['poses'][:, :3])
        arrays['poses'] = np.hstack((root, arrays['poses']))
        flipped = flipSMPLParams(arrays, self.infos['mirror'].cpu().numpy())
        params = {}
        for name in flipped.keys():
            own = arrays[name][:, None]
            if name == 'shapes':
                params[name] = own
                continue
            params[name] = np.concatenate([own, flipped[name][:, None]], axis=-2)
        # drop the three prepended entries again
        params['poses'] = params['poses'][..., 3:]
        return {name: torch.Tensor(arr).to(device) for name, arr in params.items()}
class Interpolate:
    """Expand per-frame parameters to per-view parameters shifted in time.

    Every view gets a time offset ``2 * tanh(sync_offset)``. ``Th`` and
    ``poses`` are moved along their finite-difference velocity by that
    offset; every other parameter is simply repeated for each view.
    """

    def __init__(self, actfn) -> None:
        # NOTE: ``actfn`` is accepted for configuration compatibility but is
        # not used; the activation is fixed to 2 * tanh.
        self.act_fn = lambda x: 2*torch.tanh(x)
        self.use0asref = False

    def get_offset(self, time_offset):
        """Map raw offsets to time offsets (view 0 pinned to raw 0 if use0asref)."""
        if self.use0asref:
            time_offset = torch.cat([torch.zeros(1, device=time_offset.device), time_offset[1:]])
        return self.act_fn(time_offset)

    def start(self, body_params):
        return body_params

    def before(self, body_params):
        """Return ``body_params`` with every entry expanded to nFrames*nViews rows.

        ``sync_offset`` and ``shapes`` are left untouched. The dict is
        modified in place and returned.

        Raises:
            ValueError: if a parameter has no frame axis (0-dim tensor).
        """
        off = self.get_offset(body_params['sync_offset'])
        nViews = off.shape[0]
        for key in list(body_params.keys()):
            if key in ['sync_offset', 'shapes']:
                continue
            val = body_params[key]
            if val.dim() == 0:
                raise ValueError('[Interpolate] {} has no frame axis: shape {}'.format(key, tuple(val.shape)))
            trailing = [1] * (val.dim() - 1)
            # TODO: Rh is only repeated, not interpolated; interpolating it
            # is problematic when the rotation wraps around a full period.
            if key in ['Th', 'poses']:
                if val.shape[0] > 1:
                    # forward difference for frame 0, backward for the rest
                    velocity = torch.cat([val[1:2] - val[0:1], val[1:] - val[:-1]], dim=0)
                else:
                    # a single frame carries no motion
                    velocity = torch.zeros_like(val)
                # shape the offsets for THIS value, not for the rank of poses
                off_b = off.reshape(1, nViews, *trailing)
                valnew = val[:, None] + off_b * velocity[:, None]
            else:
                # Generic repeat over the new view axis for any rank >= 1.
                # The old code handled rank 2 and 3 only and otherwise
                # dropped into a debugger with an undefined ``valnew``.
                valnew = val[:, None].repeat(1, nViews, *trailing)
            body_params[key] = valnew.reshape(-1, *val.shape[1:])
        return body_params

    def after(self,):
        pass

    def final(self, body_params):
        """Like ``before`` but also replaces ``sync_offset`` by the activated offsets."""
        off = self.get_offset(body_params['sync_offset'])
        body_params = self.before(body_params)
        body_params['sync_offset'] = off
        return body_params
def rot6d_to_rotation_matrix(rot6d):
    """
    Convert 6D rotation representation to 3x3 rotation matrix.
    Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019

    The six numbers are read as a (3, 2) matrix whose columns are the first
    two (not necessarily orthonormal) columns of the rotation; they are
    orthonormalised with Gram-Schmidt and completed by a cross product.

    Args:
        rot6d: torch tensor of shape (batch_size, 6) of 6d rotation representations.
    Returns:
        rotation_matrix: torch tensor of shape (batch_size, 3, 3) of corresponding rotation matrices.
    """
    # reshape (not view) so non-contiguous inputs are accepted too
    x = rot6d.reshape(-1, 3, 2)
    a1 = x[:, :, 0]
    a2 = x[:, :, 1]
    b1 = F.normalize(a1, dim=1)
    b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1, dim=1)
    # dim must be explicit: without it torch.cross uses the first axis of
    # size 3, which is the batch axis whenever batch_size == 3.
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-1)


def rotation_matrix_to_rot6d(rotation_matrix):
    """
    Convert 3x3 rotation matrix to 6D rotation representation.

    Keeps the first two columns, flattened row by row, which is the layout
    ``rot6d_to_rotation_matrix`` expects.

    Args:
        rotation_matrix: torch tensor of shape (batch_size, 3, 3) of corresponding rotation matrices.
    Returns:
        rot6d: torch tensor of shape (batch_size, 6) of 6d rotation representations.
    """
    v1 = rotation_matrix[:, :, 0:1]
    v2 = rotation_matrix[:, :, 1:2]
    rot6d = torch.cat([v1, v2], dim=-1).reshape(v1.shape[0], 6)
    return rot6d
def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
    """
    Convert rotation matrix to corresponding quaternion

    Four candidate quaternions are formed from the transposed matrix; per
    sample, exactly one is selected from the signs and ordering of the
    diagonal entries, then normalised by ``0.5 / sqrt(t)`` of its own trace
    term ``t``.

    Args:
        rotation_matrix: torch tensor of shape (batch_size, 3, 3)
        eps: threshold below which the (2, 2) diagonal entry counts as small.
    Returns:
        quaternion: torch tensor of shape(batch_size, 4) in (w, x, y, z) representation.
    """
    rt = torch.transpose(rotation_matrix, 1, 2)
    d0, d1, d2 = rt[:, 0, 0], rt[:, 1, 1], rt[:, 2, 2]

    # one trace-like term per candidate
    t0 = 1 + d0 - d1 - d2
    t1 = 1 - d0 + d1 - d2
    t2 = 1 - d0 - d1 + d2
    t3 = 1 + d0 + d1 + d2

    candidates = (
        torch.stack([rt[:, 1, 2] - rt[:, 2, 1], t0,
                     rt[:, 0, 1] + rt[:, 1, 0], rt[:, 2, 0] + rt[:, 0, 2]], -1),
        torch.stack([rt[:, 2, 0] - rt[:, 0, 2], rt[:, 0, 1] + rt[:, 1, 0],
                     t1, rt[:, 1, 2] + rt[:, 2, 1]], -1),
        torch.stack([rt[:, 0, 1] - rt[:, 1, 0], rt[:, 2, 0] + rt[:, 0, 2],
                     rt[:, 1, 2] + rt[:, 2, 1], t2], -1),
        torch.stack([t3, rt[:, 1, 2] - rt[:, 2, 1],
                     rt[:, 2, 0] - rt[:, 0, 2], rt[:, 0, 1] - rt[:, 1, 0]], -1),
    )

    d2_small = d2 < eps
    d0_above_d1 = d0 > d1
    d0_below_neg_d1 = d0 < -d1
    # the four masks are mutually exclusive and cover every sample
    selectors = (
        d2_small & d0_above_d1,
        d2_small & ~d0_above_d1,
        ~d2_small & d0_below_neg_d1,
        ~d2_small & ~d0_below_neg_d1,
    )

    quat = None
    scale = None
    for cand, trace, chosen in zip(candidates, (t0, t1, t2, t3), selectors):
        weight = chosen.view(-1, 1).type_as(cand)
        q_part = cand * weight
        t_part = trace.unsqueeze(-1) * weight
        quat = q_part if quat is None else quat + q_part
        scale = t_part if scale is None else scale + t_part

    quat = quat / torch.sqrt(scale)
    return quat * 0.5
def quaternion_to_euler(quaternion, order, epsilon=0):
    """
    Convert quaternion to euler angles.

    Args:
        quaternion: torch tensor of shape (*, 4) in (w, x, y, z) representation.
        order: euler angle representation order, one of
            'xyz', 'yzx', 'zxy', 'xzy', 'yxz', 'zyx'.
        epsilon: margin subtracted from the [-1, 1] clamp applied before asin.
    Returns:
        euler: torch tensor of shape (*, 3) holding the (x, y, z) angles.
    """
    assert quaternion.shape[-1] == 4
    out_shape = list(quaternion.shape)
    out_shape[-1] = 3
    q0, q1, q2, q3 = quaternion.contiguous().view(-1, 4).unbind(dim=1)

    def clamped_asin(value):
        # Keep the argument inside asin's domain (optionally with a margin).
        return torch.asin(torch.clamp(value, -1+epsilon, 1-epsilon))

    # One lazily-evaluated (x, y, z) recipe per supported rotation order.
    recipes = {
        'xyz': lambda: (
            torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2*(q1 * q1 + q2 * q2)),
            clamped_asin(2 * (q1 * q3 + q0 * q2)),
            torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2*(q2 * q2 + q3 * q3)),
        ),
        'yzx': lambda: (
            torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2*(q1 * q1 + q3 * q3)),
            torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2*(q2 * q2 + q3 * q3)),
            clamped_asin(2 * (q1 * q2 + q0 * q3)),
        ),
        'zxy': lambda: (
            clamped_asin(2 * (q0 * q1 + q2 * q3)),
            torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2*(q1 * q1 + q2 * q2)),
            torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2*(q1 * q1 + q3 * q3)),
        ),
        'xzy': lambda: (
            torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2*(q1 * q1 + q3 * q3)),
            torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2*(q2 * q2 + q3 * q3)),
            clamped_asin(2 * (q0 * q3 - q1 * q2)),
        ),
        'yxz': lambda: (
            clamped_asin(2 * (q0 * q1 - q2 * q3)),
            torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2*(q1 * q1 + q2 * q2)),
            torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2*(q1 * q1 + q3 * q3)),
        ),
        'zyx': lambda: (
            torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2*(q1 * q1 + q2 * q2)),
            clamped_asin(2 * (q0 * q2 - q1 * q3)),
            torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2*(q2 * q2 + q3 * q3)),
        ),
    }
    if order not in recipes:
        raise Exception('unsupported euler order!')

    x, y, z = recipes[order]()
    return torch.stack((x, y, z), dim=1).view(out_shape)
def euler_to_quaternion(euler, order):
    """
    Convert euler angles to quaternion.

    Args:
        euler: torch tensor of shape (*, 3) holding the (x, y, z) angles.
        order: string of axis letters ('x', 'y', 'z'); the per-axis
            quaternions are multiplied together in this order.
    Returns:
        quaternion: torch tensor of shape (*, 4) in (w, x, y, z) representation.
    """
    assert euler.shape[-1] == 3
    out_shape = list(euler.shape)
    out_shape[-1] = 4
    angles = euler.reshape(-1, 3)

    def axis_quaternion(angle, slot):
        # Quaternion for a rotation of `angle` about a single coordinate axis;
        # `slot` is the position (1, 2 or 3) of that axis in (w, x, y, z).
        parts = [torch.cos(angle/2)] + [torch.zeros_like(angle) for _ in range(3)]
        parts[slot] = torch.sin(angle/2)
        return torch.stack(parts, dim=1)

    per_axis = {
        'x': axis_quaternion(angles[:, 0], 1),
        'y': axis_quaternion(angles[:, 1], 2),
        'z': axis_quaternion(angles[:, 2], 3),
    }

    result = None
    for coord in order:
        if coord not in per_axis:
            raise Exception('unsupported euler order!')
        r = per_axis[coord]
        result = r if result is None else quaternion_mul(result, r)

    # Reverse antipodal representation to have a non-negative "w"
    # NOTE(review): the flip is unconditional for these three orders, so it
    # does not by itself guarantee w >= 0 — confirm the intended convention.
    if order in ['xyz', 'yzx', 'zxy']:
        result *= -1

    return result.reshape(out_shape)
def quaternion_to_axis_angle(quaternion):
    """
    Convert quaternion to axis angle.
    based on: https://github.com/facebookresearch/QuaterNet/blob/master/common/quaternion.py#L138

    Args:
        quaternion: torch tensor of shape (*, 4) in (w, x, y, z) representation.
    Returns:
        axis_angle: torch tensor of shape (*, 3)
    Raises:
        TypeError: if the input is not a torch tensor.
        ValueError: if the last dimension of the input is not 4.
    """
    epsilon = 1.e-8
    if not torch.is_tensor(quaternion):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(quaternion)))

    if not quaternion.shape[-1] == 4:
        raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}"
                         .format(quaternion.shape))

    cos_theta = quaternion[..., 0]
    qx = quaternion[..., 1]
    qy = quaternion[..., 2]
    qz = quaternion[..., 3]
    sin_squared_theta = qx * qx + qy * qy + qz * qz
    # epsilon keeps the square root (and the division below) finite at zero.
    sin_theta = torch.sqrt(sin_squared_theta+epsilon)

    # q and -q encode the same rotation; use the half-plane with w >= 0.
    half_angle = torch.where(
        cos_theta < 0.0,
        torch.atan2(-sin_theta, -cos_theta),
        torch.atan2(sin_theta, cos_theta))
    two_theta = 2.0 * half_angle

    # For a zero vector part fall back to the constant factor 2.
    scale = torch.where(sin_squared_theta > 0.0,
                        two_theta / sin_theta,
                        2.0 * torch.ones_like(sin_theta))

    return torch.stack((qx * scale, qy * scale, qz * scale), dim=-1)


def axis_angle_to_quaternion(axis_angle):
    """
    Convert axis angle to quaternion.

    Goes through the rotation-matrix representation.

    Args:
        axis_angle: torch tensor of shape (batch_size, 3)
    Returns:
        quaternion: torch tensor of shape (batch_size, 4) in (w, x, y, z) representation.
    """
    return rotation_matrix_to_quaternion(
        axis_angle_to_rotation_matrix(axis_angle))
def axis_angle_to_rotation_matrix(axis_angle):
    """
    Convert axis-angle representation to rotation matrix.

    Args:
        axis_angle: torch tensor of shape (batch_size, 3).
    Returns:
        rotation_matrix: torch tensor of shape (batch_size, 3, 3) of corresponding rotation matrices.
    """
    # Euclidean length of the (slightly offset) vector is the rotation angle;
    # the 1e-8 offset keeps the division finite for an all-zero input.
    angle = (axis_angle + 1e-8).norm(p=2, dim=1).unsqueeze(-1)
    axis = axis_angle / angle
    half_angle = angle * 0.5
    quaternion = torch.cat(
        [torch.cos(half_angle), torch.sin(half_angle) * axis], dim=1)
    return quaternion_to_rotation_matrix(quaternion)


# The converters below all route through the quaternion representation.

def rotation_matrix_to_axis_angle(rotation_matrix):
    """Convert rotation matrices to axis-angle vectors via quaternions."""
    return quaternion_to_axis_angle(
        rotation_matrix_to_quaternion(rotation_matrix))


def rotation_matrix_to_euler(rotation_matrix, order):
    """Convert rotation matrices to euler angles in `order` via quaternions."""
    return quaternion_to_euler(
        rotation_matrix_to_quaternion(rotation_matrix), order)


def euler_to_rotation_matrix(euler, order):
    """Convert euler angles given in `order` to rotation matrices via quaternions."""
    return quaternion_to_rotation_matrix(euler_to_quaternion(euler, order))


def axis_angle_to_euler(axis_angle, order):
    """Convert axis-angle vectors to euler angles in `order` via quaternions."""
    return quaternion_to_euler(axis_angle_to_quaternion(axis_angle), order)


def euler_to_axis_angle(euler, order):
    """Convert euler angles given in `order` to axis-angle vectors via quaternions."""
    return quaternion_to_axis_angle(euler_to_quaternion(euler, order))
# rotation operations


def quaternion_mul(q, r):
    """
    Multiply quaternion(s) q with quaternion(s) r.
    Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
    Returns q*r as a tensor of shape (*, 4).
    """
    assert q.shape[-1] == 4
    assert r.shape[-1] == 4

    original_shape = q.shape

    # Outer product: terms[:, i, j] == r[:, i] * q[:, j]
    terms = torch.bmm(r.reshape(-1, 4, 1), q.reshape(-1, 1, 4))

    w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
    x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
    y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
    z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
    return torch.stack((w, x, y, z), dim=1).view(original_shape)


def rotate_vec_by_quaternion(v, q):
    """
    Rotate vector(s) v about the rotation described by quaternion(s) q.
    Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
    where * denotes any number of dimensions.
    Returns a tensor of shape (*, 3).
    """
    assert q.shape[-1] == 4
    assert v.shape[-1] == 3
    assert q.shape[:-1] == v.shape[:-1]

    original_shape = list(v.shape)
    # reshape (not view) so non-contiguous inputs, e.g. transposed tensors,
    # are flattened correctly instead of raising.
    q = q.reshape(-1, 4)
    v = v.reshape(-1, 3)

    qvec = q[:, 1:]
    uv = torch.cross(qvec, v, dim=1)
    uuv = torch.cross(qvec, uv, dim=1)
    # v' = v + 2 * (w * (qvec x v) + qvec x (qvec x v))
    return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
def quaternion_fix(quaternion):
    """
    Enforce quaternion continuity across the time dimension by selecting
    the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
    between two consecutive frames.

    Args:
        quaternion: torch tensor of shape (frames, ..., 4); the first
            dimension is treated as time.
    Returns:
        quaternion: torch tensor of the same shape; the input is not modified.
    """
    quaternion_fixed = quaternion.clone()
    dot_products = torch.sum(quaternion[1:] * quaternion[:-1], dim=-1)
    # A negative dot product marks a sign change between neighbours; a frame
    # must be negated when an odd number of changes precede it.
    flip = (torch.cumsum(dot_products < 0, dim=0) % 2).bool()
    quaternion_fixed[1:][flip] *= -1
    return quaternion_fixed


def quaternion_inverse(quaternion):
    """
    Return the multiplicative inverse conj(q) / |q|^2 of each quaternion.

    Args:
        quaternion: torch tensor of shape (*, 4) in (w, x, y, z) representation.
    Returns:
        torch tensor of shape (*, 4); the input is not modified.
    """
    q_conjugate = quaternion.clone()
    q_conjugate[..., 1:] *= -1
    # Squared norm over all four components (w^2 + x^2 + y^2 + z^2).
    q_norm_sq = quaternion.pow(2).sum(dim=-1, keepdim=True)
    return q_conjugate / q_norm_sq


def quaternion_lerp(q1, q2, t):
    """
    Linearly interpolate between q1 and q2 with weight t, then renormalise.

    Args:
        q1, q2: torch tensors of shape (*, 4).
        t: interpolation weight; 0 returns normalised q1, 1 normalised q2.
    Returns:
        torch tensor of shape (*, 4) with unit norm along the last dimension.
    """
    q = (1-t)*q1 + t*q2
    return q / q.norm(dim=-1).unsqueeze(-1)


def geodesic_dist(q1, q2):
    """
    Rotation angle of the relative rotation inverse(q1) * q2.

    @q1: torch tensor of shape (frame, joints, 4) quaternion
    @q2: same as q1
    @output: torch tensor of shape (frame, joints)
    """
    q_between = quaternion_mul(quaternion_inverse(q1), q2)
    return quaternion_to_axis_angle(q_between).norm(dim=-1)


def get_extrinsic(translation, rotation):
    """
    Invert a batch of rigid transforms built from (rotation, translation).

    Args:
        translation: torch tensor of shape (batch_size, 3).
        rotation: torch tensor of shape (batch_size, 3, 3).
    Returns:
        (translation, rotation) of the inverse 4x4 transform, with shapes
        (batch_size, 3) and (batch_size, 3, 3).
    """
    batch_size = translation.shape[0]
    # Allocate next to the inputs so GPU / float64 inputs are not silently
    # moved to CPU float32; torch.inverse needs a floating dtype.
    dtype = (translation.dtype if translation.is_floating_point()
             else torch.get_default_dtype())
    pose = torch.zeros((batch_size, 4, 4), dtype=dtype,
                       device=translation.device)
    pose[:, :3, :3] = rotation
    pose[:, :3, 3] = translation
    pose[:, 3, 3] = 1
    extrinsic = torch.inverse(pose)
    return extrinsic[:, :3, 3], extrinsic[:, :3, :3]
def euler_fix_old(euler):
    """
    Unwrap euler angles over time by adding multiples of 2*pi per frame.

    For every joint and axis, each frame is shifted by the running offset, or
    that offset +/- 2*pi, whichever lands closest to the previous frame.

    Args:
        euler: torch tensor indexed as (frame, joint, axis) with 3 axes.
    Returns:
        The same tensor object; NOTE it is modified in place.
    """
    frame_num = euler.shape[0]
    joint_num = euler.shape[1]
    for l in range(3):
        for j in range(joint_num):
            overall_add = 0.
            for i in range(1, frame_num):
                previous = euler[i-1, j, l]
                offsets = (overall_add,
                           overall_add + 2*np.pi,
                           overall_add - 2*np.pi)
                values = [euler[i, j, l] + offset for offset in offsets]
                errors = [torch.abs(value - previous) for value in values]
                # Every candidate that attains the minimum is applied in
                # turn, so on exact ties the later candidate wins.
                for k in range(3):
                    if all(errors[k] <= errors[m] for m in range(3) if m != k):
                        euler[i, j, l] = values[k]
                        overall_add = offsets[k]
    return euler


def euler_fix(euler, rotation_order='zyx'):
    """
    Apply `euler_filter` independently to every joint.

    Args:
        euler: torch tensor indexed as (frame, joint, 3).
        rotation_order: three-letter axis order passed on to `euler_filter`.
    Returns:
        A new tensor of the same shape holding the filtered angles.
    """
    euler_new = euler.clone()
    for j in range(euler.shape[1]):
        euler_new[:, j] = euler_filter(euler[:, j], rotation_order)
    return euler_new


'''
euler filter from https://github.com/wesen/blender-euler-filter/blob/master/euler_filter.py.
'''


def euler_distance(e1, e2):
    """Return the L1 distance between two 3-component euler triples."""
    return abs(e1[0] - e2[0]) + abs(e1[1] - e2[1]) + abs(e1[2] - e2[2])


def euler_axis_index(axis):
    """Map 'x' / 'y' / 'z' to 0 / 1 / 2; any other value gives None."""
    if axis == 'x':
        return 0
    if axis == 'y':
        return 1
    if axis == 'z':
        return 2
    return None


def flip_euler(euler, rotation_mode):
    """
    Return the alternative euler triple used by the euler filter.

    The first and last axes of `rotation_mode` are shifted by pi and the
    middle axis is negated and shifted by pi. The input is not modified.

    Args:
        euler: torch tensor with 3 entries ordered (x, y, z).
        rotation_mode: three-letter axis order such as 'zyx'.
    """
    ret = euler.clone()
    inner = euler_axis_index(rotation_mode[0])
    middle = euler_axis_index(rotation_mode[1])
    outer = euler_axis_index(rotation_mode[2])

    ret[inner] += np.pi
    ret[outer] += np.pi
    ret[middle] *= -1
    ret[middle] += np.pi
    return ret


def naive_flip_diff(a1, a2):
    """
    Shift a2 by multiples of 2*pi until it is within about pi of a1.

    Works on Python floats and 0-dim tensors. The shifts are done out of
    place, so a tensor (or tensor view) passed as a2 is never modified.
    """
    while abs(a1 - a2) >= np.pi+1e-5:
        if a1 < a2:
            a2 = a2 - 2 * np.pi
        else:
            a2 = a2 + 2 * np.pi

    return a2


def euler_filter(euler, rotation_order):
    """
    Make a sequence of euler triples continuous over time.

    For each frame, both the 2*pi-unwrapped triple and its flipped
    alternative (see `flip_euler`) are compared against the previous output
    frame and the closer one (L1 distance) is kept.

    Args:
        euler: torch tensor indexed as (frame, 3).
        rotation_order: three-letter axis order such as 'zyx'.
    Returns:
        A new tensor with the filtered sequence. The input is not modified;
        it is returned as-is when it has at most one frame.
    """
    frame_num = euler.shape[0]
    if frame_num <= 1:
        return euler
    filtered = euler.clone()
    prev = euler[0]
    for i in range(1, frame_num):
        # Work on a copy: euler[i] is a view and must not be written to.
        e = euler[i].clone()
        for d in range(3):
            e[d] = naive_flip_diff(prev[d], e[d])
        fe = flip_euler(e, rotation_order)
        for d in range(3):
            fe[d] = naive_flip_diff(prev[d], fe[d])

        if euler_distance(prev, fe) < euler_distance(prev, e):
            e = fe
        prev = e
        filtered[i] = e
    return filtered
naive_flip_diff(prev[d],fe[d]) + + de = euler_distance(prev,e) + dfe = euler_distance(prev,fe) + if dfe < de: + e = fe + prev = e + euler_fix[i] = e + return euler_fix \ No newline at end of file diff --git a/easymocap/multistage/totalfitting.py b/easymocap/multistage/totalfitting.py new file mode 100644 index 0000000..b66a953 --- /dev/null +++ b/easymocap/multistage/totalfitting.py @@ -0,0 +1,97 @@ +''' + @ Date: 2022-07-28 14:39:23 + @ Author: Qing Shuai + @ Mail: s_q@zju.edu.cn + @ LastEditors: Qing Shuai + @ LastEditTime: 2022-08-12 21:42:12 + @ FilePath: /EasyMocapPublic/easymocap/multistage/totalfitting.py +''' +import torch + +from ..bodymodel.lbs import batch_rodrigues +from .torchgeometry import rotation_matrix_to_axis_angle, rotation_matrix_to_quaternion, quaternion_to_rotation_matrix, quaternion_to_axis_angle +import numpy as np +from .base_ops import BeforeAfterBase + +def compute_twist_rotation(rotation_matrix, twist_axis): + ''' + Compute the twist component of given rotation and twist axis + https://stackoverflow.com/questions/3684269/component-of-a-quaternion-rotation-around-an-axis + Parameters + ---------- + rotation_matrix : Tensor (B, 3, 3,) + The rotation to convert + twist_axis : Tensor (B, 3,) + The twist axis + Returns + ------- + Tensor (B, 3, 3) + The twist rotation + ''' + quaternion = rotation_matrix_to_quaternion(rotation_matrix) + + twist_axis = twist_axis / (torch.norm(twist_axis, dim=1, keepdim=True) + 1e-9) + + projection = torch.einsum('bi,bi->b', twist_axis, quaternion[:, 1:]).unsqueeze(-1) * twist_axis + + twist_quaternion = torch.cat([quaternion[:, 0:1], projection], dim=1) + twist_quaternion = twist_quaternion / (torch.norm(twist_quaternion, dim=1, keepdim=True) + 1e-9) + + twist_rotation = quaternion_to_rotation_matrix(twist_quaternion) + + twist_aa = quaternion_to_axis_angle(twist_quaternion) + + twist_angle = torch.sum(twist_aa, dim=1, keepdim=True) / torch.sum(twist_axis, dim=1, keepdim=True) + + return twist_rotation, 
def compute_twist_rotation(rotation_matrix, twist_axis):
    '''
    Compute the twist component of given rotation and twist axis
    https://stackoverflow.com/questions/3684269/component-of-a-quaternion-rotation-around-an-axis
    Parameters
    ----------
    rotation_matrix : Tensor (B, 3, 3,)
        The rotation to convert
    twist_axis : Tensor (B, 3,)
        The twist axis
    Returns
    -------
    Tensor (B, 3, 3)
        The twist rotation
    Tensor (B, 1)
        The signed twist angle about the (normalised) twist axis
    '''
    quaternion = rotation_matrix_to_quaternion(rotation_matrix)

    twist_axis = twist_axis / (torch.norm(twist_axis, dim=1, keepdim=True) + 1e-9)

    # Project the vector part of the quaternion onto the twist axis.
    projection = torch.einsum('bi,bi->b', twist_axis, quaternion[:, 1:]).unsqueeze(-1) * twist_axis

    twist_quaternion = torch.cat([quaternion[:, 0:1], projection], dim=1)
    twist_quaternion = twist_quaternion / (torch.norm(twist_quaternion, dim=1, keepdim=True) + 1e-9)

    twist_rotation = quaternion_to_rotation_matrix(twist_quaternion)

    twist_aa = quaternion_to_axis_angle(twist_quaternion)

    # twist_aa is parallel to the unit twist axis, so the signed angle is its
    # projection onto that axis. Dividing component sums instead breaks down
    # whenever the axis components sum to zero, e.g. (1, -1, 0).
    twist_angle = torch.sum(twist_aa * twist_axis, dim=1, keepdim=True)

    return twist_rotation, twist_angle


class ClearTwist(BeforeAfterBase):
    """Zero selected pose entries in `start`, before the stage runs."""

    def start(self, body_params):
        """
        Zero part of the elbow entries and all of the wrist entries.

        `body_params['poses']` is modified in place and the same dict is
        returned. Pose columns are addressed in groups of three per joint.
        NOTE(review): indices look like SMPL joint ids shifted by one because
        the root is stored elsewhere — confirm against the body model.
        """
        idx_elbow = [18-1, 19-1]
        for idx in idx_elbow:
            # x
            body_params['poses'][:, 3*idx] = 0.
            # z
            body_params['poses'][:, 3*idx+2] = 0.
        idx_wrist = [20-1, 21-1]
        for idx in idx_wrist:
            body_params['poses'][:, 3*idx:3*idx+3] = 0.
        return body_params


class SolveTwist(BeforeAfterBase):
    """Write wrist poses derived from the global hand rotations in `final`."""

    def __init__(self, body_model=None) -> None:
        self.body_model = body_model

    def final(self, body_params):
        """
        Overwrite both wrist pose entries and return `body_params`.

        For each hand, the rotation stored under 'R_handl3d' / 'R_handr3d'
        is applied on top of the predicted wrist frame and then expressed
        relative to that frame; the resulting axis-angle vector is written
        into `body_params['poses']` in place.
        """
        T_joints, T_vertices = self.body_model.transform(body_params)
        # This transform don't consider RT
        R = batch_rodrigues(body_params['Rh'])
        # NOTE(review): `template` is never read below; the call is kept in
        # case `keypoints` has side effects — confirm and drop if it is pure.
        template = self.body_model.keypoints({'shapes': body_params['shapes'],
            'poses': torch.zeros_like(body_params['poses'])},
            only_shape=True, return_smpl_joints=True)
        config = {
            'left': {
                'index_smpl': 20,
                'index_elbow_smpl': 18,
                'R_global': 'R_handl3d',
                'axis': torch.Tensor([[1., 0., 0.]]).to(device=T_joints.device),
            },
            'right': {
                'index_smpl': 21,
                'index_elbow_smpl': 19,
                'R_global': 'R_handr3d',
                'axis': torch.Tensor([[-1., 0., 0.]]).to(device=T_joints.device),
            }
        }
        for key in ['left', 'right']:
            cfg = config[key]
            R_wrist_add = batch_rodrigues(body_params[cfg['R_global']])
            idx_wrist = cfg['index_smpl']
            pred_parent_wrist = R @ T_joints[..., idx_wrist, :3, :3]
            pred_global_wrist = torch.bmm(R_wrist_add, pred_parent_wrist)
            # Express the corrected global rotation in the predicted wrist frame.
            pred_local_wrist = torch.bmm(pred_parent_wrist.transpose(-1, -2), pred_global_wrist)
            axis = rotation_matrix_to_axis_angle(pred_local_wrist)
            # The pose columns start one joint later than index_smpl, hence -3.
            body_params['poses'][..., 3*idx_wrist-3:3*idx_wrist] = axis
        return body_params