🚧 create the new stype of fitting
This commit is contained in:
parent
5bc4b113ba
commit
a0127f712a
308
easymocap/multistage/base.py
Normal file
308
easymocap/multistage/base.py
Normal file
@ -0,0 +1,308 @@
|
||||
# 这个脚本用于通用的多阶段的优化
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..annotator.file_utils import read_json
|
||||
from ..mytools import Timer
|
||||
from .lossbase import print_table
|
||||
from ..config.baseconfig import load_object
|
||||
from ..bodymodel.base import Params
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
def dict_of_numpy_to_tensor(body_model, body_params, *args, **kwargs):
|
||||
device = body_model.device
|
||||
body_params = {key:torch.Tensor(val).to(device) for key, val in body_params.items()}
|
||||
return body_params
|
||||
|
||||
class AddExtra:
|
||||
def __init__(self, vals) -> None:
|
||||
self.vals = vals
|
||||
|
||||
def __call__(self, body_model, body_params, *args, **kwargs):
|
||||
shapes = body_params['poses'].shape[:-1]
|
||||
for key in self.vals:
|
||||
if key in body_params.keys():
|
||||
continue
|
||||
if key.startswith('R_') or key.startswith('T_'):
|
||||
val = np.zeros((*shapes, 3), dtype=np.float32)
|
||||
body_params[key] = val
|
||||
return body_params
|
||||
|
||||
def dict_of_tensor_to_numpy(body_params):
|
||||
body_params = {key:val.detach().cpu().numpy() for key, val in body_params.items()}
|
||||
return body_params
|
||||
|
||||
def grad_require(params, flag=False):
|
||||
if isinstance(params, list):
|
||||
for par in params:
|
||||
par.requires_grad = flag
|
||||
elif isinstance(params, dict):
|
||||
for key, par in params.items():
|
||||
par.requires_grad = flag
|
||||
|
||||
def rel_change(prev_val, curr_val):
|
||||
return (prev_val - curr_val) / max([1e-5, abs(prev_val), abs(curr_val)])
|
||||
|
||||
def make_optimizer(opt_params, optim_type='lbfgs', max_iter=20,
|
||||
lr=1e-3, betas=(0.9, 0.999), weight_decay=0.0, **kwargs):
|
||||
if isinstance(opt_params, dict):
|
||||
# LBFGS 不支持参数字典
|
||||
opt_params = list(opt_params.values())
|
||||
if optim_type == 'lbfgs':
|
||||
from ..pyfitting.lbfgs import LBFGS
|
||||
optimizer = LBFGS(opt_params, line_search_fn='strong_wolfe', max_iter=max_iter, **kwargs)
|
||||
elif optim_type == 'adam':
|
||||
optimizer = torch.optim.Adam(opt_params, lr=lr, betas=betas, weight_decay=weight_decay)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return optimizer
|
||||
|
||||
def make_lossfuncs(stage, infos, device, irepeat, verbose=False):
|
||||
loss_funcs, weights = {}, {}
|
||||
for key, val in stage.loss.items():
|
||||
loss_args = dict(val.args)
|
||||
if 'infos' in val.keys():
|
||||
for k in val.infos:
|
||||
loss_args[k] = infos[k]
|
||||
module = load_object(val.module, loss_args)
|
||||
module.to(device)
|
||||
if 'weights' in val.keys():
|
||||
weights[key] = val.weights[irepeat]
|
||||
else:
|
||||
weights[key] = val.weight
|
||||
if weights[key] < 0:
|
||||
weights.pop(key)
|
||||
else:
|
||||
loss_funcs[key] = module
|
||||
if verbose or True:
|
||||
print('Loss functions: ')
|
||||
for key, func in loss_funcs.items():
|
||||
print(' - {:15s}: {}, {}'.format(key, weights[key], func))
|
||||
return loss_funcs, weights
|
||||
|
||||
def make_before_after(before_after, body_model, body_params, infos):
|
||||
modules = []
|
||||
for key, val in before_after.items():
|
||||
args = dict(val.args)
|
||||
if 'body_model' in args.keys():
|
||||
args['body_model'] = body_model
|
||||
try:
|
||||
module = load_object(val.module, args)
|
||||
except:
|
||||
print('[Fitting] Failed to load module {}'.format(key))
|
||||
raise NotImplementedError
|
||||
module.infos = infos
|
||||
modules.append(module)
|
||||
return modules
|
||||
|
||||
def process(start_or_end, body_model, body_params, infos):
|
||||
for key, val in start_or_end.items():
|
||||
if isinstance(val, dict):
|
||||
module = load_object(val.module, val.args)
|
||||
else:
|
||||
if key == 'convert' and val == 'numpy_to_tensor':
|
||||
module = dict_of_numpy_to_tensor
|
||||
if key == 'add':
|
||||
module = AddExtra(val)
|
||||
body_params = module(body_model, body_params, infos)
|
||||
return body_params
|
||||
|
||||
def plot_meshes(img, meshes, K, R, T):
|
||||
import cv2
|
||||
mesh_camera = []
|
||||
for mesh in meshes:
|
||||
vertices = mesh['vertices'] @ R.T + T.T
|
||||
v2d = vertices @ K.T
|
||||
v2d[:, :2] = v2d[:, :2] / v2d[:, 2:3]
|
||||
lw=1
|
||||
col=(0,0,255)
|
||||
for (x, y, d) in v2d[::10]:
|
||||
cv2.circle(img, (int(x+0.5), int(y+0.5)), lw*2, col, -1)
|
||||
return img
|
||||
|
||||
class MultiStage:
|
||||
def __init__(self, batch_size, optimizer, monitor, initialize, stages) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.optimizer_args = optimizer
|
||||
self.monitor = monitor
|
||||
self.initialize = initialize
|
||||
self.stages = stages
|
||||
|
||||
def make_closure(self, body_model, body_params, infos, loss_funcs, weights, optimizer, before_after_module):
|
||||
def closure(debug=False, ret_kpts=False):
|
||||
# 0. Prepare body parameters => new_params
|
||||
optimizer.zero_grad()
|
||||
new_params = body_params.copy()
|
||||
for module in before_after_module:
|
||||
new_params = module.before(new_params)
|
||||
# 1. Compute keypoints => kpts_est
|
||||
poses_full = body_model.extend_poses(**new_params)
|
||||
kpts_est = body_model(return_verts=False, return_tensor=True, **new_params)
|
||||
if ret_kpts:
|
||||
return kpts_est
|
||||
verts_est = None
|
||||
# 2. Compute loss => loss_dict
|
||||
loss_dict = {}
|
||||
for key, loss_func in loss_funcs.items():
|
||||
if key.startswith('v'):
|
||||
if verts_est is None:
|
||||
verts_est = body_model(return_verts=True, return_tensor=True, **new_params)
|
||||
loss_dict[key] = loss_func(verts_est=verts_est, **new_params, **infos)
|
||||
elif key.startswith('pf-'):
|
||||
loss_dict[key] = loss_func(poses_full=poses_full, **new_params, **infos)
|
||||
else:
|
||||
loss_dict[key] = loss_func(kpts_est=kpts_est, **new_params, **infos)
|
||||
loss = sum([loss_dict[key]*weights[key]
|
||||
for key in loss_dict.keys()])
|
||||
if debug:
|
||||
return loss_dict
|
||||
loss.backward()
|
||||
return loss
|
||||
return closure
|
||||
|
||||
def optimizer_step(self, optimizer, closure, weights):
|
||||
prev_loss = None
|
||||
for iter_ in range(self.monitor.maxiters):
|
||||
with torch.no_grad():
|
||||
loss_dict = closure(debug=True)
|
||||
if self.monitor.printloss or (self.monitor.verbose and iter_ == 0):
|
||||
print('{:-6d}: '.format(iter_) + ' '.join([key + ' %f'%(loss_dict[key].item()*weights[key]) for key in loss_dict.keys()]))
|
||||
loss = optimizer.step(closure)
|
||||
# check the loss
|
||||
if torch.isnan(loss).sum() > 0:
|
||||
print('[optimize] NaN loss value, stopping!')
|
||||
break
|
||||
if torch.isinf(loss).sum() > 0:
|
||||
print('[optimize] Infinite loss value, stopping!')
|
||||
break
|
||||
# check the delta
|
||||
if iter_ > 0 and prev_loss is not None:
|
||||
loss_rel_change = rel_change(prev_loss, loss.item())
|
||||
if loss_rel_change <= self.monitor.ftol:
|
||||
if self.monitor.printloss or self.monitor.verbose:
|
||||
print('{:-6d}: '.format(iter_) + ' '.join([key + ' %f'%(loss_dict[key].item()*weights[key]) for key in loss_dict.keys()]))
|
||||
break
|
||||
# log
|
||||
if self.monitor.vis2d:
|
||||
pass
|
||||
if self.monitor.vis3d:
|
||||
pass
|
||||
prev_loss = loss.item()
|
||||
return True
|
||||
|
||||
def fit_stage(self, body_model, body_params, infos, stage, irepeat):
|
||||
# 单独拟合一个stage, 返回body_params
|
||||
optimizer_args = stage.get('optimizer', self.optimizer_args)
|
||||
dtype, device = body_model.dtype, body_model.device
|
||||
body_params = process(stage.get('at_start', {'convert': 'numpy_to_tensor'}), body_model, body_params, infos)
|
||||
opt_params = {}
|
||||
if 'optimize' in stage.keys():
|
||||
optimize_names = stage.optimize
|
||||
else:
|
||||
optimize_names = stage.optimizes[irepeat]
|
||||
for key in optimize_names:
|
||||
if key in infos.keys(): # 优化的参数
|
||||
infos[key] = infos[key].to(device)
|
||||
opt_params[key] = infos[key]
|
||||
elif key in body_params.keys():
|
||||
opt_params[key] = body_params[key]
|
||||
else:
|
||||
raise ValueError('{} is not in infos or body_params'.format(key))
|
||||
if self.monitor.verbose:
|
||||
print('[optimize] optimizing {}'.format(optimize_names))
|
||||
for key, val in opt_params.items():
|
||||
infos['init_'+key] = val.clone().detach().cpu()
|
||||
# initialize keypoints
|
||||
with torch.no_grad():
|
||||
kpts_est = body_model.keypoints(body_params)
|
||||
infos['init_kpts_est'] = kpts_est.clone().detach().cpu()
|
||||
before_after_module = make_before_after(stage.get('before_after', {}), body_model, body_params, infos)
|
||||
for module in before_after_module:
|
||||
# Input to this module is tensor
|
||||
body_params = module.start(body_params)
|
||||
grad_require(opt_params, True)
|
||||
optimizer = make_optimizer(opt_params, **optimizer_args)
|
||||
loss_funcs, weights = make_lossfuncs(stage, infos, device, irepeat, self.monitor.verbose)
|
||||
closure = self.make_closure(body_model, body_params, infos, loss_funcs, weights, optimizer, before_after_module)
|
||||
if self.monitor.check:
|
||||
new_params = body_params.copy()
|
||||
for module in before_after_module:
|
||||
new_params = module.before(new_params)
|
||||
kpts_est = body_model.keypoints(new_params)
|
||||
for key, loss in loss_funcs.items():
|
||||
loss.check_at_start(kpts_est=kpts_est, **new_params)
|
||||
self.optimizer_step(optimizer, closure, weights)
|
||||
grad_require(opt_params, False)
|
||||
if self.monitor.check:
|
||||
new_params = body_params.copy()
|
||||
for module in before_after_module:
|
||||
new_params = module.before(new_params)
|
||||
kpts_est = body_model.keypoints(new_params)
|
||||
for key, loss in loss_funcs.items():
|
||||
loss.check_at_end(kpts_est=kpts_est, **new_params)
|
||||
for module in before_after_module:
|
||||
# Input to this module is tensor
|
||||
body_params = module.final(body_params)
|
||||
body_params = dict_of_tensor_to_numpy(body_params)
|
||||
for key, val in opt_params.items():
|
||||
if key in infos.keys():
|
||||
infos[key] = val.detach().cpu()
|
||||
return body_params
|
||||
|
||||
def fit(self, body_model, dataset):
|
||||
batch_size = len(dataset) if self.batch_size == -1 else self.batch_size
|
||||
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=False)
|
||||
if len(dataloader) > 1:
|
||||
dataloader = tqdm(dataloader, desc='optimizing')
|
||||
for data in dataloader:
|
||||
data = dataset.reshape_data(data)
|
||||
infos = data.copy()
|
||||
init_params = body_model.init_params(nFrames=infos['nFrames'], nPerson=infos.get('nPerson', 1))
|
||||
# first initialize the model
|
||||
for name, init_func in self.initialize.items():
|
||||
if 'loss' in init_func.keys():
|
||||
# fitting to initialize
|
||||
init_params = self.fit_stage(body_model, init_params, infos, init_func, 0)
|
||||
else:
|
||||
# use initialize module
|
||||
init_module = load_object(init_func.module, init_func.args)
|
||||
init_params = init_module(body_model, init_params, infos)
|
||||
# if there are multiple initialization params
|
||||
# then fit each of them
|
||||
if not isinstance(init_params, list):
|
||||
init_params = [init_params]
|
||||
results = []
|
||||
for init_param in init_params:
|
||||
# check the repeat params
|
||||
body_params = init_param
|
||||
for stage_name, stage in self.stages.items():
|
||||
for irepeat in range(stage.get('repeat', 1)):
|
||||
with Timer('optimize {}'.format(stage_name), not self.monitor.timer):
|
||||
body_params = self.fit_stage(body_model, body_params, infos, stage, irepeat)
|
||||
results.append(body_params)
|
||||
# select the best results
|
||||
if len(results) > 1:
|
||||
# check the result
|
||||
loss = load_object(self.check.module, self.check.args, **{key:infos[key] for key in self.check.infos})
|
||||
metrics = [loss(body_model.keypoints(body_params, return_tensor=True).cpu()).item() for body_params in results]
|
||||
best_idx = np.argmin(metrics)
|
||||
else:
|
||||
best_idx = 0
|
||||
if 'sync_offset' in body_params.keys():
|
||||
offset = body_params.pop('sync_offset')
|
||||
dataset.write_offset(offset)
|
||||
body_params = Params(**results[best_idx])
|
||||
if data['nFrames'] != body_params['poses'].shape[0]:
|
||||
for key in body_params.keys():
|
||||
if body_params[key].shape[0] == 1:continue
|
||||
body_params[key] = body_params[key].reshape(data['nFrames'], -1, *body_params[key].shape[1:])
|
||||
print(key, body_params[key].shape)
|
||||
if 'K' in infos.keys():
|
||||
camera = Params(K=infos['K'].numpy(), R=infos['Rc'].numpy(), T=infos['Tc'].numpy())
|
||||
if 'mirror' in infos.keys():
|
||||
camera['mirror'] = infos['mirror'].numpy()[None]
|
||||
dataset.write(body_model, body_params, data, camera)
|
||||
else:
|
||||
# write data without camera
|
||||
dataset.write(body_model, body_params, data)
|
39
easymocap/multistage/base_ops.py
Normal file
39
easymocap/multistage/base_ops.py
Normal file
@ -0,0 +1,39 @@
|
||||
'''
|
||||
@ Date: 2022-08-12 20:34:15
|
||||
@ Author: Qing Shuai
|
||||
@ Mail: s_q@zju.edu.cn
|
||||
@ LastEditors: Qing Shuai
|
||||
@ LastEditTime: 2022-08-18 14:47:23
|
||||
@ FilePath: /EasyMocapPublic/easymocap/multistage/base_ops.py
|
||||
'''
|
||||
import torch
|
||||
|
||||
class BeforeAfterBase:
|
||||
def __init__(self, model) -> None:
|
||||
pass
|
||||
|
||||
def start(self, body_params):
|
||||
# operation before the optimization
|
||||
return body_params
|
||||
|
||||
def before(self, body_params):
|
||||
# operation in each optimization step
|
||||
return body_params
|
||||
|
||||
def final(self, body_params):
|
||||
# operation after the optimization
|
||||
return body_params
|
||||
|
||||
class SkipPoses(BeforeAfterBase):
|
||||
def __init__(self, index, nPoses) -> None:
|
||||
self.index = index
|
||||
self.nPoses = nPoses
|
||||
self.copy_index = [i for i in range(nPoses) if i not in index]
|
||||
|
||||
def before(self, body_params):
|
||||
poses = body_params['poses']
|
||||
poses_copy = torch.zeros_like(poses)
|
||||
print(poses.shape)
|
||||
poses_copy[..., self.copy_index] = poses[..., self.copy_index]
|
||||
body_params['poses'] = poses_copy
|
||||
return body_params
|
57
easymocap/multistage/before_after.py
Normal file
57
easymocap/multistage/before_after.py
Normal file
@ -0,0 +1,57 @@
|
||||
import torch
|
||||
|
||||
class Remove:
|
||||
def __init__(self, key, index=[], ranges=[]) -> None:
|
||||
self.key = key
|
||||
self.ranges = ranges
|
||||
self.index = index
|
||||
|
||||
def before(self, body_params):
|
||||
val = body_params[self.key]
|
||||
if self.ranges[0] == 0:
|
||||
val_zeros = torch.zeros_like(val[:, :self.ranges[1]])
|
||||
val = torch.cat([val_zeros, val[:, self.ranges[1]:]], dim=1)
|
||||
body_params[self.key] = val
|
||||
return body_params
|
||||
|
||||
class RemoveHand:
|
||||
def __init__(self, start=60) -> None:
|
||||
pass
|
||||
|
||||
def before(self, body_params):
|
||||
poses = body_params['poses']
|
||||
val_zeros = torch.zeros_like(poses[:, 60:])
|
||||
val = torch.cat([poses[:, :60], val_zeros], dim=1)
|
||||
body_params['poses'] = val
|
||||
return body_params
|
||||
|
||||
class Keep:
|
||||
def __init__(self, key, ranges=[], index=[]) -> None:
|
||||
self.key = key
|
||||
self.ranges = ranges
|
||||
self.index = index
|
||||
|
||||
def before(self, body_params):
|
||||
val = body_params[self.key]
|
||||
val_zeros = val.detach().clone()
|
||||
if len(self.ranges) > 0:
|
||||
val_zeros[..., self.ranges[0]:self.ranges[1]] = val[..., self.ranges[0]:self.ranges[1]]
|
||||
elif len(self.index) > 0:
|
||||
val_zeros[..., self.index] = val[..., self.index]
|
||||
body_params[self.key] = val_zeros
|
||||
return body_params
|
||||
|
||||
def final(self, body_params):
|
||||
return body_params
|
||||
|
||||
class VPoser2Full:
|
||||
def __init__(self, key) -> None:
|
||||
pass
|
||||
|
||||
def __call__(self, body_model, body_params, infos):
|
||||
if not 'Embedding' in body_model.__class__.__name__:
|
||||
return body_params
|
||||
poses = body_params['poses']
|
||||
poses_full = body_model.decode(poses, add_rot=False)
|
||||
body_params['poses'] = poses_full
|
||||
return body_params
|
1771
easymocap/multistage/fitting.py
Normal file
1771
easymocap/multistage/fitting.py
Normal file
File diff suppressed because it is too large
Load Diff
100
easymocap/multistage/init_cnn.py
Normal file
100
easymocap/multistage/init_cnn.py
Normal file
@ -0,0 +1,100 @@
|
||||
'''
|
||||
@ Date: 2022-04-26 17:54:28
|
||||
@ Author: Qing Shuai
|
||||
@ Mail: s_q@zju.edu.cn
|
||||
@ LastEditors: Qing Shuai
|
||||
@ LastEditTime: 2022-07-11 22:20:44
|
||||
@ FilePath: /EasyMocapPublic/easymocap/multistage/init_cnn.py
|
||||
'''
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
from tqdm import tqdm
|
||||
from os.path import join
|
||||
import torch
|
||||
from ..bodymodel.base import Params
|
||||
from ..estimator.wrapper_base import bbox_from_keypoints
|
||||
from ..mytools.writer import write_smpl
|
||||
from ..mytools.reader import read_smpl
|
||||
|
||||
class InitSpin:
|
||||
# initialize the smpl results by spin
|
||||
def __init__(self, mean_params, ckpt_path, share_shape,
|
||||
multi_person=False, compose_mp=False) -> None:
|
||||
from ..estimator.SPIN.spin_api import SPIN
|
||||
import torch
|
||||
self.share_shape = share_shape
|
||||
self.spin_model = SPIN(
|
||||
SMPL_MEAN_PARAMS=mean_params,
|
||||
checkpoint=ckpt_path,
|
||||
device=torch.device('cpu'))
|
||||
self.distortMap = {}
|
||||
self.multi_person = multi_person
|
||||
self.compose_mp = compose_mp
|
||||
|
||||
def undistort(self, image, K, dist, nv):
|
||||
if np.linalg.norm(dist) < 0.01:
|
||||
return image
|
||||
if nv not in self.distortMap.keys():
|
||||
h, w = image.shape[:2]
|
||||
mapx, mapy = cv2.initUndistortRectifyMap(K, dist, None, K, (w,h), 5)
|
||||
self.distortMap[nv] = (mapx, mapy)
|
||||
mapx, mapy = self.distortMap[nv]
|
||||
image = cv2.remap(image, mapx, mapy, cv2.INTER_LINEAR)
|
||||
return image
|
||||
|
||||
def __call__(self, body_model, body_params, infos):
|
||||
self.spin_model.model.to(body_model.device)
|
||||
self.spin_model.device = body_model.device
|
||||
params_all = []
|
||||
for nf, imgname in enumerate(tqdm(infos['imgname'], desc='Run SPIN')):
|
||||
# 暂时不考虑多视角情况
|
||||
# TODO: 没有考虑多人的情况
|
||||
basename = os.sep.join(imgname.split(os.sep)[-2:]).split('.')[0] + '.json'
|
||||
sub = os.path.dirname(basename)
|
||||
cache_dir = os.path.abspath(join(os.sep.join(imgname.split(os.sep)[:-3]), 'cache_spin'))
|
||||
outname = join(cache_dir, basename)
|
||||
if os.path.exists(outname):
|
||||
params = read_smpl(outname)
|
||||
if self.multi_person:
|
||||
params_all.append(params)
|
||||
else:
|
||||
params_all.append(params[0])
|
||||
continue
|
||||
camera = {key: infos[key][nf].numpy() for key in ['K', 'Rc', 'Tc', 'dist']}
|
||||
camera['R'] = camera['Rc']
|
||||
camera['T'] = camera['Tc']
|
||||
image = cv2.imread(imgname)
|
||||
image = self.undistort(image, camera['K'], camera['dist'], sub)
|
||||
if len(infos['keypoints2d'].shape) == 3:
|
||||
k2d = infos['keypoints2d'][nf][None]
|
||||
else:
|
||||
k2d = infos['keypoints2d'][nf]
|
||||
params_current = []
|
||||
for pid in range(k2d.shape[0]):
|
||||
keypoints = k2d[pid].numpy()
|
||||
bbox = bbox_from_keypoints(keypoints)
|
||||
nValid = (keypoints[:, -1] > 0).sum()
|
||||
if nValid > 4:
|
||||
result = self.spin_model(body_model, image,
|
||||
bbox, keypoints, camera, ret_vertices=False)
|
||||
elif len(params_all) == 0:
|
||||
print('[WARN] not enough joints: {} in first frame'.format(imgname))
|
||||
else:
|
||||
print('[WARN] not enough joints: {}'.format(imgname))
|
||||
result = {'body_params': params_all[-1][pid]}
|
||||
params = result['body_params']
|
||||
params['id'] = pid
|
||||
params_current.append(params)
|
||||
write_smpl(outname, params_current)
|
||||
if self.multi_person:
|
||||
params_all.append(params_current)
|
||||
else:
|
||||
params_all.append(params_current[0])
|
||||
if not self.multi_person:
|
||||
params_all = Params.merge(params_all, share_shape=self.share_shape)
|
||||
params_all = body_model.encode(params_all)
|
||||
elif self.compose_mp:
|
||||
params_all = Params.merge([Params.merge(p_, share_shape=False) for p_ in params_all], share_shape=False, stack=np.stack)
|
||||
params_all['id'] = 0
|
||||
return params_all
|
36
easymocap/multistage/init_pose.py
Normal file
36
easymocap/multistage/init_pose.py
Normal file
@ -0,0 +1,36 @@
|
||||
'''
|
||||
@ Date: 2022-04-02 13:59:50
|
||||
@ Author: Qing Shuai
|
||||
@ Mail: s_q@zju.edu.cn
|
||||
@ LastEditors: Qing Shuai
|
||||
@ LastEditTime: 2022-07-13 16:34:21
|
||||
@ FilePath: /EasyMocapPublic/easymocap/multistage/init_pose.py
|
||||
'''
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
from tqdm import tqdm
|
||||
from os.path import join
|
||||
import torch
|
||||
from ..bodymodel.base import Params
|
||||
from ..estimator.wrapper_base import bbox_from_keypoints
|
||||
from ..mytools.writer import write_smpl
|
||||
from ..mytools.reader import read_smpl
|
||||
|
||||
class SmoothPoses:
|
||||
def __init__(self, window_size) -> None:
|
||||
self.W = window_size
|
||||
|
||||
def __call__(self, body_model, body_params, infos):
|
||||
poses = body_params['poses']
|
||||
padding_before = poses[:1].copy().repeat(self.W, 0)
|
||||
padding_after = poses[-1:].copy().repeat(self.W, 0)
|
||||
mean = poses.copy()
|
||||
nFrames = mean.shape[0]
|
||||
poses_full = np.vstack([padding_before, poses, padding_after])
|
||||
for w in range(1, self.W+1):
|
||||
mean += poses_full[self.W-w:self.W-w+nFrames]
|
||||
mean += poses_full[self.W+w:self.W+w+nFrames]
|
||||
mean /= 2*self.W + 1
|
||||
body_params['poses'] = mean
|
||||
return body_params
|
172
easymocap/multistage/initialize.py
Normal file
172
easymocap/multistage/initialize.py
Normal file
@ -0,0 +1,172 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
from ..dataset.config import CONFIG
|
||||
from ..config import load_object
|
||||
from ..mytools.debug_utils import log, mywarn, myerror
|
||||
import torch
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
def svd_rot(src, tgt, reflection=False, debug=False):
|
||||
# optimum rotation matrix of Y
|
||||
A = np.matmul(src.transpose(0, 2, 1), tgt)
|
||||
U, s, Vt = np.linalg.svd(A, full_matrices=False)
|
||||
V = Vt.transpose(0, 2, 1)
|
||||
T = np.matmul(V, U.transpose(0, 2, 1))
|
||||
# does the current solution use a reflection?
|
||||
have_reflection = np.linalg.det(T) < 0
|
||||
|
||||
# if that's not what was specified, force another reflection
|
||||
V[have_reflection, :, -1] *= -1
|
||||
s[have_reflection, -1] *= -1
|
||||
T = np.matmul(V, U.transpose(0, 2, 1))
|
||||
if debug:
|
||||
err = np.linalg.norm(tgt - src @ T.T, axis=1)
|
||||
print('[svd] ', err)
|
||||
return T
|
||||
|
||||
def batch_invRodrigues(rot):
|
||||
res = []
|
||||
for r in rot:
|
||||
v = cv2.Rodrigues(r)[0]
|
||||
res.append(v)
|
||||
res = np.stack(res)
|
||||
return res[:, :, 0]
|
||||
|
||||
class BaseInit:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def __call__(self, body_model, body_params, infos):\
|
||||
return body_params
|
||||
|
||||
class Remove(BaseInit):
|
||||
def __init__(self, key, index) -> None:
|
||||
super().__init__()
|
||||
self.key = key
|
||||
self.index = index
|
||||
|
||||
def __call__(self, body_model, body_params, infos):
|
||||
infos[self.key][..., self.index, :] = 0
|
||||
return super().__call__(body_model, body_params, infos)
|
||||
|
||||
class CheckKeypoints:
|
||||
def __init__(self, type) -> None:
|
||||
# this class is used to check if the provided keypoints3d
|
||||
self.type = type
|
||||
self.body_config = CONFIG[type]
|
||||
self.hand_config = CONFIG['hand']
|
||||
|
||||
def __call__(self, body_model, body_params, infos):
|
||||
for key in ['keypoints3d', 'handl3d', 'handr3d']:
|
||||
if key not in infos.keys(): continue
|
||||
keypoints = infos[key]
|
||||
conf = keypoints[..., -1]
|
||||
keypoints[conf<0.1] = 0
|
||||
if key == 'keypoints3d':
|
||||
continue
|
||||
import ipdb;ipdb.set_trace()
|
||||
# limb_length = np.linalg.norm(keypoints[:, , :3], axis=2)
|
||||
return body_params
|
||||
|
||||
class InitRT:
|
||||
def __init__(self, torso) -> None:
|
||||
self.torso = torso
|
||||
|
||||
def __call__(self, body_model, body_params, infos):
|
||||
keypoints3d = infos['keypoints3d'].detach().cpu().numpy()
|
||||
temp_joints = body_model.keypoints(body_params, return_tensor=False)
|
||||
|
||||
torso = keypoints3d[..., self.torso, :3].copy()
|
||||
torso_temp = temp_joints[..., self.torso, :3].copy()
|
||||
# here use the first id of torso as the rotation center
|
||||
root, root_temp = torso[..., :1, :], torso_temp[..., :1, :]
|
||||
torso = torso - root
|
||||
torso_temp = torso_temp - root_temp
|
||||
conf = (keypoints3d[..., self.torso, 3] > 0.).all(axis=-1)
|
||||
if not conf.all():
|
||||
myerror("The torso in frames {} is not valid, please check the 3d keypoints".format(np.where(~conf)))
|
||||
if len(torso.shape) == 3:
|
||||
R = svd_rot(torso_temp, torso)
|
||||
R_flat = R
|
||||
T = np.matmul(- root_temp, R.transpose(0, 2, 1)) + root
|
||||
else:
|
||||
R_flat = svd_rot(torso_temp.reshape(-1, *torso_temp.shape[-2:]), torso.reshape(-1, *torso.shape[-2:]))
|
||||
R = R_flat.reshape(*torso.shape[:2], 3, 3)
|
||||
T = np.matmul(- root_temp, R.swapaxes(-1, -2)) + root
|
||||
for nf in np.where(~conf)[0]:
|
||||
# copy previous frames
|
||||
mywarn('copy {} from {}'.format(nf, nf-1))
|
||||
R[nf] = R[nf-1]
|
||||
T[nf] = T[nf-1]
|
||||
body_params['Th'] = T[..., 0, :]
|
||||
rvec = batch_invRodrigues(R_flat)
|
||||
if len(torso.shape) > 3:
|
||||
rvec = rvec.reshape(*torso.shape[:2], 3)
|
||||
body_params['Rh'] = rvec
|
||||
return body_params
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "[Initialize] svd with torso: {}".format(self.torso)
|
||||
|
||||
class TriangulatorWrapper:
|
||||
def __init__(self, module, args):
|
||||
self.triangulator = load_object(module, args)
|
||||
|
||||
def __call__(self, body_model, body_params, infos):
|
||||
infos['RT'] = torch.cat([infos['Rc'], infos['Tc']], dim=-1)
|
||||
data = {
|
||||
'RT': infos['RT'].numpy(),
|
||||
}
|
||||
for key in self.triangulator.keys:
|
||||
if key not in infos.keys():
|
||||
continue
|
||||
data[key] = infos[key].numpy()
|
||||
data[key+'_unproj'] = infos[key+'_unproj'].numpy()
|
||||
data[key+'_distort'] = infos[key+'_distort'].numpy()
|
||||
results = self.triangulator(data)[0]
|
||||
for key, val in results.items():
|
||||
if key == 'id': continue
|
||||
infos[key] = torch.Tensor(val[None].astype(np.float32))
|
||||
body_params = body_model.init_params(nFrames=1, add_scale=True)
|
||||
return body_params
|
||||
|
||||
class CheckRT:
|
||||
def __init__(self, T_thres, window):
|
||||
self.T_thres = T_thres
|
||||
self.window = window
|
||||
|
||||
def __call__(self, body_model, body_params, infos):
|
||||
Th = body_params['Th']
|
||||
if len(Th.shape) == 3:
|
||||
for nper in range(Th.shape[1]):
|
||||
for nf in trange(1, Th.shape[0], desc='Check Th of {}'.format(nper)):
|
||||
if nf > self.window:
|
||||
tpre = Th[nf-self.window:nf, nper]
|
||||
else:
|
||||
tpre = Th[:nf, nper]
|
||||
tpre = tpre.mean(axis=0)
|
||||
tnow = Th[nf , nper]
|
||||
dist = np.linalg.norm(tnow - tpre)
|
||||
if dist > self.T_thres:
|
||||
mywarn('[Check Th] distance in frame {} = {} larger than {}'.format(nf, dist, self.T_thres))
|
||||
Th[nf, nper] = tpre
|
||||
body_params['Th'] = Th
|
||||
return body_params
|
||||
|
||||
class Scale:
|
||||
def __init__(self, keys):
|
||||
self.keys = keys
|
||||
|
||||
def __call__(self, body_model, body_params, infos):
|
||||
scale = body_params.pop('scale')[0, 0]
|
||||
if scale < 1.1 and scale > 0.9:
|
||||
return body_params
|
||||
print('scale = ', scale)
|
||||
for key in self.keys:
|
||||
if key not in infos.keys():
|
||||
continue
|
||||
infos[key] /= scale
|
||||
infos['Tc'] /= scale
|
||||
infos['RT'][..., -1] *= scale
|
||||
infos['scale'] = scale
|
||||
return body_params
|
589
easymocap/multistage/lossbase.py
Normal file
589
easymocap/multistage/lossbase.py
Normal file
@ -0,0 +1,589 @@
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from ..bodymodel.lbs import batch_rodrigues
|
||||
|
||||
class LossBase(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return '# lack of comment'
|
||||
|
||||
def check_at_start(self, **kwargs):
|
||||
pass
|
||||
|
||||
def check_at_end(self, **kwargs):
|
||||
pass
|
||||
|
||||
class GMoF(nn.Module):
|
||||
def __init__(self, rho=1):
|
||||
super(GMoF, self).__init__()
|
||||
self.rho2 = rho * rho
|
||||
|
||||
def extra_repr(self):
|
||||
return 'rho = {}'.format(self.rho)
|
||||
|
||||
def forward(self, est, gt=None, conf=None):
|
||||
if gt is not None:
|
||||
square_diff = torch.sum((est - gt)**2, dim=-1)
|
||||
else:
|
||||
square_diff = torch.sum(est**2, dim=-1)
|
||||
diff = torch.div(square_diff, square_diff + self.rho2)
|
||||
if conf is not None:
|
||||
res = torch.sum(diff * conf)/(1e-5 + conf.sum())
|
||||
else:
|
||||
res = diff.sum()/diff.numel()
|
||||
return res
|
||||
|
||||
def make_loss(norm, norm_info, reduce='sum'):
|
||||
reduce = torch.sum if reduce=='sum' else torch.mean
|
||||
if norm == 'l2':
|
||||
def loss(est, gt=None, conf=None):
|
||||
if gt is not None:
|
||||
square_diff = reduce((est - gt)**2, dim=-1)
|
||||
else:
|
||||
square_diff = reduce(est**2, dim=-1)
|
||||
if conf is not None:
|
||||
res = torch.sum(square_diff * conf)/(1e-5 + conf.sum())
|
||||
else:
|
||||
res = square_diff.sum()/square_diff.numel()
|
||||
return res
|
||||
elif norm == 'gm':
|
||||
loss = GMoF(norm_info)
|
||||
return loss
|
||||
|
||||
def select(value, ranges, index, dim):
|
||||
if len(ranges) > 0:
|
||||
if ranges[1] == -1:
|
||||
value = value[..., ranges[0]:]
|
||||
else:
|
||||
value = value[..., ranges[0]:ranges[1]]
|
||||
return value
|
||||
if len(index) > 0:
|
||||
if dim == -1:
|
||||
value = value[..., index]
|
||||
elif dim == -2:
|
||||
value = value[..., index, :]
|
||||
return value
|
||||
return value
|
||||
|
||||
def print_table(header, contents):
|
||||
from tabulate import tabulate
|
||||
length = len(contents[0])
|
||||
tables = [[] for _ in range(length)]
|
||||
mean = ['Mean']
|
||||
for icnt, content in enumerate(contents):
|
||||
for i in range(length):
|
||||
if isinstance(content[i], float):
|
||||
tables[i].append('{:6.2f}'.format(content[i]))
|
||||
else:
|
||||
tables[i].append('{}'.format(content[i]))
|
||||
if icnt > 0:
|
||||
mean.append('{:6.2f}'.format(sum(content)/length))
|
||||
tables.append(mean)
|
||||
print(tabulate(tables, header, tablefmt='fancy_grid'))
|
||||
|
||||
class AnyReg(LossBase):
|
||||
def __init__(self, key, norm, dim=-1, reduce='sum', norm_info={}, ranges=[], index=[], **kwargs):
|
||||
super().__init__()
|
||||
self.ranges = ranges
|
||||
self.index = index
|
||||
self.key = key
|
||||
self.dim = dim
|
||||
if 'init_' + key in kwargs.keys():
|
||||
init = kwargs['init_'+key]
|
||||
self.register_buffer('init', torch.Tensor(init))
|
||||
else:
|
||||
self.init = None
|
||||
self.norm_name = norm
|
||||
self.loss = make_loss(norm, norm_info, reduce=reduce)
|
||||
|
||||
def forward(self, **kwargs):
|
||||
"""
|
||||
value: (nFrames, ..., nDims)
|
||||
"""
|
||||
value = kwargs[self.key]
|
||||
if self.init is not None:
|
||||
value = value - self.init
|
||||
value = select(value, self.ranges, self.index, self.dim)
|
||||
return self.loss(value)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return 'Loss for {}'.format(self.key, self.norm_name)
|
||||
|
||||
class RegPrior(AnyReg):
|
||||
def __init__(self, **cfg):
|
||||
super().__init__(**cfg)
|
||||
self.init = None # disable init
|
||||
infos = {
|
||||
(2, 0): '-exp',
|
||||
(2, 1): 'l2',
|
||||
(2, 2): 'l2',
|
||||
(3, 0): '-exp', # knee
|
||||
(3, 1): 'L2',
|
||||
(3, 2): 'L2',
|
||||
(4, 0): '-exp', # knee
|
||||
(4, 1): 'L2',
|
||||
(4, 2): 'L2',
|
||||
(5, 0): '-exp',
|
||||
(5, 1): 'l2',
|
||||
(5, 2): 'l2',
|
||||
(6, 1): 'L2',
|
||||
(6, 2): 'L2',
|
||||
(7, 1): 'L2',
|
||||
(7, 2): 'L2',
|
||||
(8, 0): '-exp',
|
||||
(8, 1): 'l2',
|
||||
(8, 2): 'l2',
|
||||
(9, 0): 'L2',
|
||||
(9, 1): 'L2',
|
||||
(9, 2): 'L2',
|
||||
(10, 0): 'L2',
|
||||
(10, 1): 'L2',
|
||||
(10, 2): 'L2',
|
||||
(12, 0): 'l2', # 肩关节前面
|
||||
(13, 0): 'l2',
|
||||
(17, 1): 'exp',
|
||||
(17, 2): 'L2',
|
||||
(18, 1): '-exp',
|
||||
(18, 2): 'L2',
|
||||
}
|
||||
self.l2dims = []
|
||||
self.L2dims = []
|
||||
self.expdims = []
|
||||
self.nexpdims = []
|
||||
for (nj, ndim), norm in infos.items():
|
||||
dim = nj*3 + ndim
|
||||
if norm == 'l2':
|
||||
self.l2dims.append(dim)
|
||||
elif norm == 'L2':
|
||||
self.L2dims.append(dim)
|
||||
elif norm == '-exp':
|
||||
self.nexpdims.append(dim)
|
||||
elif norm == 'exp':
|
||||
self.expdims.append(dim)
|
||||
|
||||
def forward(self, poses, **kwargs):
|
||||
"""
|
||||
poses: (..., nDims)
|
||||
"""
|
||||
alll2loss = torch.mean(poses**2)
|
||||
l2loss = torch.sum(poses[:, self.l2dims]**2)/len(self.l2dims)/poses.shape[0]
|
||||
L2loss = torch.sum(poses[:, self.L2dims]**2)/len(self.L2dims)/poses.shape[0]
|
||||
exploss = torch.sum(torch.exp(poses[:, self.expdims]))/poses.shape[0]
|
||||
nexploss = torch.sum(torch.exp(-poses[:, self.nexpdims]))/poses.shape[0]
|
||||
loss = 0.1*l2loss + L2loss + 0.0005*(exploss + nexploss)/(len(self.expdims) + len(self.nexpdims)) + 0.01*alll2loss
|
||||
return loss
|
||||
|
||||
class VPoserPrior(AnyReg):
|
||||
def __init__(self, **cfg):
|
||||
super().__init__(**cfg)
|
||||
vposer_ckpt = 'data/bodymodels/vposer_v02'
|
||||
from human_body_prior.tools.model_loader import load_model
|
||||
from human_body_prior.models.vposer_model import VPoser
|
||||
vposer, _ = load_model(vposer_ckpt,
|
||||
model_code=VPoser,
|
||||
remove_words_in_model_weights='vp_model.',
|
||||
disable_grad=True)
|
||||
vposer.eval()
|
||||
self.vposer = vposer
|
||||
self.init = None # disable init
|
||||
|
||||
def forward(self, poses, **kwargs):
|
||||
"""
|
||||
poses: (..., nDims)
|
||||
"""
|
||||
nDims = 63
|
||||
poses_body = poses[..., :nDims].reshape(-1, nDims)
|
||||
latent = self.vposer.encode(poses_body)
|
||||
if True:
|
||||
ret = self.vposer.decode(latent.sample())['pose_body'].reshape(poses.shape[0], nDims)
|
||||
return super().forward(poses=poses_body-ret)
|
||||
else:
|
||||
return super().forward(poses=latent.mean)
|
||||
|
||||
class AnySmooth(LossBase):
|
||||
def __init__(self, key, weight, norm, norm_info={}, ranges=[], index=[], dim=-1, order=1):
|
||||
super().__init__()
|
||||
self.ranges = ranges
|
||||
self.index = index
|
||||
self.dim = dim
|
||||
self.weight = weight
|
||||
self.loss = make_loss(norm, norm_info)
|
||||
self.norm_name = norm
|
||||
self.key = key
|
||||
self.order = order
|
||||
|
||||
def forward(self, **kwargs):
|
||||
loss = 0
|
||||
value = kwargs[self.key]
|
||||
value = select(value, self.ranges, self.index, self.dim)
|
||||
if value.shape[0] <= len(self.weight):
|
||||
return torch.FloatTensor([0.]).to(value.device)
|
||||
for width, weight in enumerate(self.weight, start=1):
|
||||
vel = value[width:] - value[:-width]
|
||||
if self.order == 2:
|
||||
vel = vel[1:] - vel[:-1]
|
||||
loss += weight * self.loss(vel)
|
||||
return loss
|
||||
|
||||
def check(self, value):
|
||||
vel = value[1:] - value[:-1]
|
||||
if len(vel.shape) > 2:
|
||||
vel = torch.norm(vel, dim=-1)
|
||||
else:
|
||||
vel = torch.abs(vel)
|
||||
vel = vel.detach().cpu().numpy()
|
||||
return vel
|
||||
|
||||
def get_check_name(self, value):
|
||||
name = [str(i) for i in range(value.shape[1])]
|
||||
return name
|
||||
|
||||
def check_at_start(self, **kwargs):
|
||||
value = kwargs[self.key]
|
||||
if value.shape[0] < len(self.weight):
|
||||
return 0
|
||||
header = ['Smooth '+self.key, 'mean(before)', 'max(before)', 'frame(before)']
|
||||
name = self.get_check_name(value)
|
||||
vel = self.check(value)
|
||||
contents = [name, vel.mean(axis=0).tolist(), vel.max(axis=0).tolist(), vel.argmax(axis=0).tolist()]
|
||||
self.cache_check = (header, contents)
|
||||
return super().check_at_start(**kwargs)
|
||||
|
||||
def check_at_end(self, **kwargs):
|
||||
value = kwargs[self.key]
|
||||
if value.shape[0] < len(self.weight):
|
||||
return 0
|
||||
err_after = self.check(kwargs[self.key])
|
||||
header, contents = self.cache_check
|
||||
header.extend(['mean(after)', 'max(after)', 'frame(after)'])
|
||||
contents.extend([err_after.mean(axis=0).tolist(), err_after.max(axis=0).tolist(), err_after.argmax(axis=0).tolist()])
|
||||
print_table(header, contents)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "smooth in {} frames, range={}, norm={}".format(self.weight, self.ranges, self.norm_name)
|
||||
|
||||
class SmoothRot(AnySmooth):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
from ..bodymodel.lbs import batch_rodrigues
|
||||
self.rodrigues = batch_rodrigues
|
||||
|
||||
def convert_Rh_to_R(self, Rh):
|
||||
shape = Rh.shape[1]
|
||||
ret = []
|
||||
for i in range(shape//3):
|
||||
Rot = self.rodrigues(Rh[:, 3*i:3*(i+1)])
|
||||
ret.append(Rot)
|
||||
ret = torch.cat(ret, dim=1)
|
||||
return ret
|
||||
|
||||
def forward(self, **kwargs):
|
||||
Rh = kwargs[self.key]
|
||||
if Rh.shape[-1] != 3:
|
||||
loss = 0
|
||||
for i in range(Rh.shape[-1]//3):
|
||||
Rh_sub = Rh[..., 3*i:3*i+3]
|
||||
Rot = self.convert_Rh_to_R(Rh_sub).view(*Rh_sub.shape[:-1], 3, 3)
|
||||
loss += super().forward(**{self.key: Rot})
|
||||
return loss
|
||||
else:
|
||||
Rh_flat = Rh.view(-1, 3)
|
||||
Rot = self.convert_Rh_to_R(Rh_flat).view(*Rh.shape[:-1], 3, 3)
|
||||
return super().forward(**{self.key: Rot})
|
||||
|
||||
def get_check_name(self, value):
|
||||
name = ['angle']
|
||||
return name
|
||||
|
||||
def check(self, value):
|
||||
import cv2
|
||||
# TODO: here just use first rotation
|
||||
if len(value.shape) == 3:
|
||||
value = value[:, 0]
|
||||
Rot = self.convert_Rh_to_R(value.detach())[:, :3]
|
||||
vel = torch.matmul(Rot[:-1], Rot.transpose(1,2)[1:]).cpu().numpy()
|
||||
vels = []
|
||||
for i in range(vel.shape[0]):
|
||||
angle = np.linalg.norm(cv2.Rodrigues(vel[i])[0])
|
||||
vels.append(angle)
|
||||
vel = np.array(vels).reshape(-1, 1)
|
||||
return vel
|
||||
|
||||
class BaseKeypoints(LossBase):
|
||||
@staticmethod
|
||||
def select(keypoints, index, ranges):
|
||||
if len(index) > 0:
|
||||
keypoints = keypoints[..., index, :]
|
||||
elif len(ranges) > 0:
|
||||
if ranges[1] == -1:
|
||||
keypoints = keypoints[..., ranges[0]:, :]
|
||||
else:
|
||||
keypoints = keypoints[..., ranges[0]:ranges[1], :]
|
||||
return keypoints
|
||||
|
||||
def set_gt(self, index_gt, ranges_gt):
|
||||
keypoints = self.select(self.keypoints_np, index_gt, ranges_gt)
|
||||
keypoints = torch.Tensor(keypoints)
|
||||
self.register_buffer('keypoints', keypoints[..., :-1])
|
||||
self.register_buffer('conf', keypoints[..., -1])
|
||||
|
||||
def __str__(self):
|
||||
return "keypoints: {}".format(self.keypoints.shape)
|
||||
|
||||
def __init__(self, keypoints, norm='l2', norm_info={},
|
||||
index_gt=[], ranges_gt=[],
|
||||
index_est=[], ranges_est=[]) -> None:
|
||||
super().__init__()
|
||||
# prepare ground-truth
|
||||
self.keypoints_np = keypoints
|
||||
self.set_gt(index_gt, ranges_gt)
|
||||
#
|
||||
self.index_est = index_est
|
||||
self.ranges_est = ranges_est
|
||||
|
||||
self.norm = norm
|
||||
self.loss = make_loss(norm, norm_info)
|
||||
|
||||
def forward(self, kpts_est, **kwargs):
|
||||
est = self.select(kpts_est, self.index_est, self.ranges_est)
|
||||
return self.loss(est, self.keypoints, self.conf)
|
||||
|
||||
def check(self, kpts_est, min_conf=0.3, **kwargs):
|
||||
est = self.select(kpts_est, self.index_est, self.ranges_est)
|
||||
conf = (self.conf>min_conf).float()
|
||||
norm = torch.norm(est-self.keypoints, dim=-1) * conf
|
||||
mean_joints = norm.sum(dim=0)/(1e-5 + conf.sum(dim=0)) * 1000
|
||||
return conf, mean_joints
|
||||
|
||||
def check_at_start(self, kpts_est, **kwargs):
|
||||
if len(self.index_est) > 0:
|
||||
names = [str(i) for i in self.index_est]
|
||||
elif len(self.ranges_est) > 0:
|
||||
names = [str(i) for i in range(self.ranges_est[0], self.ranges_est[1])]
|
||||
else:
|
||||
names = [str(i) for i in range(self.conf.shape[-1])]
|
||||
conf, error = self.check(kpts_est, **kwargs)
|
||||
valid = conf.sum(dim=0).detach().cpu().numpy().tolist()
|
||||
header = ['name', 'count']
|
||||
contents = [names, valid]
|
||||
header.append('before')
|
||||
contents.append(error.detach().cpu().numpy().tolist())
|
||||
self.cache_check = (header, contents)
|
||||
|
||||
def check_at_end(self, kpts_est, **kwargs):
|
||||
conf, err_after = self.check(kpts_est, **kwargs)
|
||||
header, contents = self.cache_check
|
||||
header.append('after')
|
||||
contents.append(err_after.detach().cpu().numpy().tolist())
|
||||
print_table(header, contents)
|
||||
|
||||
class Keypoints3D(BaseKeypoints):
|
||||
def __init__(self, keypoints3d, **kwargs) -> None:
|
||||
super().__init__(keypoints3d, **kwargs)
|
||||
|
||||
class AnyKeypoints3D(Keypoints3D):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
key = kwargs.pop('key')
|
||||
keypoints3d = kwargs.pop(key)
|
||||
super().__init__(keypoints3d, **kwargs)
|
||||
self.key = key
|
||||
|
||||
class AnyKeypoints3DWithRT(Keypoints3D):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
key = kwargs.pop('key')
|
||||
keypoints3d = kwargs.pop(key)
|
||||
super().__init__(keypoints3d, **kwargs)
|
||||
self.key = key
|
||||
|
||||
def forward(self, kpts_est, **kwargs):
|
||||
R = batch_rodrigues(kwargs['R_'+self.key])
|
||||
T = kwargs['T_'+self.key]
|
||||
RXT = torch.matmul(kpts_est, R.transpose(-1, -2)) + T[..., None, :]
|
||||
return super().forward(RXT)
|
||||
|
||||
def check(self, kpts_est, min_conf=0.3, **kwargs):
|
||||
R = batch_rodrigues(kwargs['R_'+self.key])
|
||||
T = kwargs['T_'+self.key]
|
||||
RXT = torch.matmul(kpts_est, R.transpose(-1, -2)) + T[..., None, :]
|
||||
kpts_est = RXT
|
||||
return super().check(kpts_est, min_conf)
|
||||
|
||||
class Handl3D(BaseKeypoints):
|
||||
def __init__(self, handl3d, **kwargs) -> None:
|
||||
handl3d = handl3d.clone()
|
||||
handl3d[..., :3] = handl3d[..., :3] - handl3d[:, :1, :3]
|
||||
super().__init__(handl3d, **kwargs)
|
||||
|
||||
def forward(self, kpts_est, **kwargs):
|
||||
est = kpts_est[:, 25:46]
|
||||
est = est - est[:, :1].detach()
|
||||
return super().forward(est, **kwargs)
|
||||
|
||||
def check(self, kpts_est, **kwargs):
|
||||
est = kpts_est[:, 25:46]
|
||||
est = est - est[:, :1].detach()
|
||||
return super().check(est, **kwargs)
|
||||
|
||||
|
||||
class LimbLength(BaseKeypoints):
|
||||
def __init__(self, kintree, key='keypoints3d', **kwargs):
|
||||
self.kintree = np.array(kintree)
|
||||
if key == 'bodyhand':
|
||||
keypoints3d = np.hstack([kwargs.pop('keypoints3d'), kwargs.pop('handl3d'), kwargs.pop('handr3d')])
|
||||
else:
|
||||
keypoints3d = kwargs.pop(key)
|
||||
super().__init__(keypoints3d, **kwargs)
|
||||
|
||||
def __str__(self):
|
||||
return "Limb of: {}".format(','.join(['[{},{}]'.format(i,j) for (i,j) in self.kintree]))
|
||||
|
||||
def set_gt(self, index_gt, ranges_gt):
|
||||
keypoints3d = self.keypoints_np
|
||||
kintree = self.kintree
|
||||
# limb_length: nFrames, nLimbs, 1
|
||||
limb_length = np.linalg.norm(keypoints3d[..., kintree[:, 1], :3] - keypoints3d[..., kintree[:, 0], :3], axis=-1, keepdims=True)
|
||||
# conf: nFrames, nLimbs, 1
|
||||
limb_conf = np.minimum(keypoints3d[..., kintree[:, 1], -1], keypoints3d[..., kintree[:, 0], -1])
|
||||
limb_length = torch.Tensor(limb_length)
|
||||
limb_conf = torch.Tensor(limb_conf)
|
||||
self.register_buffer('length', limb_length)
|
||||
self.register_buffer('conf', limb_conf)
|
||||
|
||||
def forward(self, kpts_est, **kwargs):
|
||||
src = kpts_est[..., self.kintree[:, 0], :]
|
||||
dst = kpts_est[..., self.kintree[:, 1], :]
|
||||
length_est = torch.norm(dst - src, dim=-1, keepdim=True)
|
||||
return self.loss(length_est, self.length, self.conf)
|
||||
|
||||
def check_at_start(self, kpts_est, **kwargs):
|
||||
names = [str(i) for i in self.kintree]
|
||||
conf = (self.conf>0)
|
||||
valid = conf.sum(dim=0).detach().cpu().numpy()
|
||||
if len(valid.shape) == 2:
|
||||
valid = valid.mean(axis=0)
|
||||
header = ['name', 'count']
|
||||
contents = [names, valid.tolist()]
|
||||
error, length = self.check(kpts_est)
|
||||
header.append('before')
|
||||
contents.append(error.detach().cpu().numpy().tolist())
|
||||
header.append('length')
|
||||
length = (self.length[..., 0] * self.conf).sum(dim=0)/self.conf.sum(dim=0)
|
||||
contents.append(length.detach().cpu().numpy().tolist())
|
||||
self.cache_check = (header, contents)
|
||||
|
||||
def check_at_end(self, kpts_est, **kwargs):
|
||||
err_after, length = self.check(kpts_est)
|
||||
header, contents = self.cache_check
|
||||
header.append('after')
|
||||
contents.append(err_after.detach().cpu().numpy().tolist())
|
||||
header.append('length_est')
|
||||
contents.append(length[:,:,0].mean(dim=0).detach().cpu().numpy().tolist())
|
||||
print_table(header, contents)
|
||||
|
||||
def check(self, kpts_est, **kwargs):
|
||||
src = kpts_est[..., self.kintree[:, 0], :]
|
||||
dst = kpts_est[..., self.kintree[:, 1], :]
|
||||
length_est = torch.norm(dst - src, dim=-1, keepdim=True)
|
||||
conf = (self.conf>0).float()
|
||||
norm = torch.abs(length_est-self.length)[..., 0] * conf
|
||||
mean_joints = norm.sum(dim=0)/conf.sum(dim=0) * 1000
|
||||
if len(mean_joints.shape) == 2:
|
||||
mean_joints = mean_joints.mean(dim=0)
|
||||
length_est = length_est.mean(dim=0)
|
||||
return mean_joints, length_est
|
||||
|
||||
class LimbLengthHand(LimbLength):
|
||||
def __init__(self, handl3d, handr3d, **kwargs):
|
||||
kintree = kwargs.pop('kintree')
|
||||
keypoints3d = torch.cat([handl3d, handr3d], dim=0)
|
||||
super().__init__(kintree, keypoints3d, **kwargs)
|
||||
|
||||
def forward(self, kpts_est, **kwargs):
|
||||
kpts_est = torch.cat([kpts_est[:, :21], kpts_est[:, 21:]], dim=0)
|
||||
return super().forward(kpts_est, **kwargs)
|
||||
|
||||
def check(self, kpts_est, **kwargs):
|
||||
kpts_est = torch.cat([kpts_est[:, :21], kpts_est[:, 21:]], dim=0)
|
||||
return super().check(kpts_est, **kwargs)
|
||||
|
||||
class Keypoints2D(BaseKeypoints):
|
||||
def __init__(self, keypoints2d, K, Rc, Tc, einsum='fab,fnb->fna',
|
||||
unproj=True, reshape_views=False, **kwargs) -> None:
|
||||
# convert to camera coordinate
|
||||
invKtrans = torch.inverse(K).transpose(-1, -2)
|
||||
if unproj:
|
||||
homo = torch.ones_like(keypoints2d[..., :1])
|
||||
homo = torch.cat([keypoints2d[..., :2], homo], dim=-1)
|
||||
if len(invKtrans.shape) < len(homo.shape):
|
||||
invKtrans = invKtrans.unsqueeze(-3)
|
||||
homo = torch.matmul(homo, invKtrans)
|
||||
keypoints2d = torch.cat([homo[..., :2], keypoints2d[..., 2:]], dim=-1)
|
||||
# keypoints2d: (nFrames, nViews, ..., nJoints, 3)
|
||||
super().__init__(keypoints2d, **kwargs)
|
||||
self.register_buffer('K', K)
|
||||
self.register_buffer('invKtrans', invKtrans)
|
||||
self.register_buffer('Rc', Rc)
|
||||
self.register_buffer('Tc', Tc)
|
||||
self.unproj = unproj
|
||||
self.einsum = einsum
|
||||
self.reshape_views = reshape_views
|
||||
|
||||
def project(self, kpts_est):
|
||||
kpts_est = self.select(kpts_est, self.index_est, self.ranges_est)
|
||||
kpts_homo = torch.ones_like(kpts_est[..., -1:])
|
||||
kpts_homo = torch.cat([kpts_est, kpts_homo], dim=-1)
|
||||
if self.unproj:
|
||||
P = torch.cat([self.Rc, self.Tc], dim=-1)
|
||||
else:
|
||||
P = torch.bmm(self.K, torch.cat([self.Rc, self.Tc], dim=-1))
|
||||
if self.reshape_views:
|
||||
kpts_homo = kpts_homo.reshape(self.K.shape[0], self.K.shape[1], *kpts_homo.shape[1:])
|
||||
try:
|
||||
point_cam = torch.einsum(self.einsum, P, kpts_homo)
|
||||
except:
|
||||
print('Wrong shape: {}x{} <=== {}'.format(P.shape, kpts_homo.shape, self.einsum))
|
||||
raise NotImplementedError
|
||||
img_points = point_cam[..., :2]/point_cam[..., 2:]
|
||||
return img_points
|
||||
|
||||
def forward(self, kpts_est, **kwargs):
|
||||
img_points = self.project(kpts_est)
|
||||
loss = self.loss(img_points.squeeze(), self.keypoints.squeeze(), self.conf.squeeze())
|
||||
return loss
|
||||
|
||||
def check(self, kpts_est, min_conf=0.3):
|
||||
with torch.no_grad():
|
||||
img_points = self.project(kpts_est)
|
||||
conf = (self.conf>min_conf)
|
||||
err = self.K[..., 0:1, 0].mean() * torch.norm(img_points - self.keypoints, dim=-1) * conf
|
||||
if len(err.shape) == 3:
|
||||
err = err.sum(dim=1)
|
||||
conf = conf.sum(dim=1)
|
||||
err = err.sum(dim=0)/(1e-5 + conf.sum(dim=0))
|
||||
return conf, err
|
||||
|
||||
def check_at_start(self, kpts_est, **kwargs):
|
||||
if len(self.index_est) > 0:
|
||||
names = [str(i) for i in self.index_est]
|
||||
elif len(self.ranges_est) > 0:
|
||||
names = [str(i) for i in range(self.ranges_est[0], self.ranges_est[1])]
|
||||
else:
|
||||
names = [str(i) for i in range(self.conf.shape[-1])]
|
||||
conf, error = self.check(kpts_est)
|
||||
valid = conf.sum(dim=0).detach().cpu().numpy()
|
||||
valid = valid.tolist()
|
||||
header = ['name', 'count']
|
||||
contents = [names, valid]
|
||||
header.append('before(pix)')
|
||||
contents.append(error.detach().cpu().numpy().tolist())
|
||||
self.cache_check = (header, contents)
|
||||
|
||||
def check_at_end(self, kpts_est, **kwargs):
|
||||
conf, err_after = self.check(kpts_est)
|
||||
header, contents = self.cache_check
|
||||
header.append('after(pix)')
|
||||
contents.append(err_after.detach().cpu().numpy().tolist())
|
||||
print_table(header, contents)
|
261
easymocap/multistage/mirror.py
Normal file
261
easymocap/multistage/mirror.py
Normal file
@ -0,0 +1,261 @@
|
||||
'''
|
||||
@ Date: 2022-07-12 11:55:47
|
||||
@ Author: Qing Shuai
|
||||
@ Mail: s_q@zju.edu.cn
|
||||
@ LastEditors: Qing Shuai
|
||||
@ LastEditTime: 2022-07-14 17:57:48
|
||||
@ FilePath: /EasyMocapPublic/easymocap/multistage/mirror.py
|
||||
'''
|
||||
import numpy as np
|
||||
import torch
|
||||
from ..dataset.mirror import flipPoint2D, flipSMPLPoses, flipSMPLParams
|
||||
from ..estimator.wrapper_base import bbox_from_keypoints
|
||||
from .lossbase import Keypoints2D
|
||||
|
||||
def calc_vanishpoint(keypoints2d):
|
||||
'''
|
||||
keypoints2d: (2, N, 3)
|
||||
'''
|
||||
# weight: (N, 1)
|
||||
weight = keypoints2d[:, :, 2:].mean(axis=0)
|
||||
conf = weight.mean()
|
||||
A = np.hstack([
|
||||
keypoints2d[1, :, 1:2] - keypoints2d[0, :, 1:2],
|
||||
-(keypoints2d[1, :, 0:1] - keypoints2d[0, :, 0:1])
|
||||
])
|
||||
b = -keypoints2d[0, :, 0:1]*(keypoints2d[1, :, 1:2] - keypoints2d[0, :, 1:2]) \
|
||||
+ keypoints2d[0, :, 1:2] * (keypoints2d[1, :, 0:1] - keypoints2d[0, :, 0:1])
|
||||
b = -b
|
||||
A = A * weight
|
||||
b = b * weight
|
||||
avgInsec = np.linalg.inv(A.T @ A) @ (A.T @ b)
|
||||
result = np.zeros(3)
|
||||
result[0] = avgInsec[0, 0]
|
||||
result[1] = avgInsec[1, 0]
|
||||
result[2] = 1
|
||||
return result
|
||||
|
||||
def calc_mirror_transform(m_):
|
||||
""" From mirror vector to mirror matrix
|
||||
Args:
|
||||
m (bn, 4): (a, b, c, d)
|
||||
Returns:
|
||||
M: (bn, 3, 4)
|
||||
"""
|
||||
norm = torch.norm(m_[:, :3], dim=1, keepdim=True)
|
||||
m = m_[:, :3] / norm
|
||||
d = m_[:, 3]
|
||||
coeff_mat = torch.zeros((m.shape[0], 3, 4), device=m.device)
|
||||
coeff_mat[:, 0, 0] = 1 - 2*m[:, 0]**2
|
||||
coeff_mat[:, 0, 1] = -2*m[:, 0]*m[:, 1]
|
||||
coeff_mat[:, 0, 2] = -2*m[:, 0]*m[:, 2]
|
||||
coeff_mat[:, 0, 3] = -2*m[:, 0]*d
|
||||
coeff_mat[:, 1, 0] = -2*m[:, 1]*m[:, 0]
|
||||
coeff_mat[:, 1, 1] = 1-2*m[:, 1]**2
|
||||
coeff_mat[:, 1, 2] = -2*m[:, 1]*m[:, 2]
|
||||
coeff_mat[:, 1, 3] = -2*m[:, 1]*d
|
||||
coeff_mat[:, 2, 0] = -2*m[:, 2]*m[:, 0]
|
||||
coeff_mat[:, 2, 1] = -2*m[:, 2]*m[:, 1]
|
||||
coeff_mat[:, 2, 2] = 1-2*m[:, 2]**2
|
||||
coeff_mat[:, 2, 3] = -2*m[:, 2]*d
|
||||
return coeff_mat
|
||||
|
||||
class InitNormal:
|
||||
def __init__(self, static) -> None:
|
||||
self.static = static
|
||||
|
||||
def __call__(self, body_model, body_params, infos):
|
||||
if 'normal' in infos.keys():
|
||||
print('>>> Reading normal: {}'.format(infos['normal']))
|
||||
return body_params
|
||||
kpts = infos['keypoints2d']
|
||||
kpts0 = kpts[:, 0]
|
||||
kpts1 = flipPoint2D(kpts[:, 1])
|
||||
vanish_line = torch.stack([kpts0.reshape(-1, 3), kpts1.reshape(-1, 3)], dim=1)
|
||||
MIN_THRES = 0.5
|
||||
conf = (vanish_line[:, 0, -1] > MIN_THRES) & (vanish_line[:, 1, -1] > MIN_THRES)
|
||||
vanish_line = vanish_line[conf]
|
||||
vline0 = vanish_line.numpy().transpose(1, 0, 2)
|
||||
vpoint0 = calc_vanishpoint(vline0).reshape(1, 3)
|
||||
# 计算点到线的距离进行检查
|
||||
# two points line: (x1, y1), (x2, y2) ==> (y-y1)/(x-x1) = (y2-y1)/(x2-x1)
|
||||
# A = y2 - y1
|
||||
# B = x1 - x2
|
||||
# C = x2y1 - x1y2
|
||||
# d = abs(ax + by + c)/sqrt(a^2+b^2)
|
||||
A_v0 = kpts0[:, :, 1] - vpoint0[0, 1]
|
||||
B_v0 = vpoint0[0, 0] - kpts0[:, :, 0]
|
||||
C_v0 = kpts0[:, :, 0]*vpoint0[0, 1] - vpoint0[0, 0]*kpts0[:, :, 1]
|
||||
distance01 = np.abs(A_v0 * kpts1[:, :, 0] + B_v0 * kpts1[:, :, 1] + C_v0)/np.sqrt(A_v0*A_v0 + B_v0*B_v0)
|
||||
A_v1 = kpts1[:, :, 1] - vpoint0[0, 1]
|
||||
B_v1 = vpoint0[0, 0] - kpts1[:, :, 0]
|
||||
C_v1 = kpts1[:, :, 0]*vpoint0[0, 1] - vpoint0[0, 0]*kpts1[:, :, 1]
|
||||
distance10 = np.abs(A_v1 * kpts0[:, :, 0] + B_v1 * kpts0[:, :, 1] + C_v1)/np.sqrt(A_v1*A_v1 + B_v1*B_v1)
|
||||
DIST_THRES = 0.05
|
||||
for nf in range(kpts.shape[0]):
|
||||
# 计算scale
|
||||
bbox0 = bbox_from_keypoints(kpts0[nf].cpu().numpy())
|
||||
bbox1 = bbox_from_keypoints(kpts1[nf].cpu().numpy())
|
||||
bbox_size0 = max(bbox0[2]-bbox0[0], bbox0[3]-bbox0[1])
|
||||
bbox_size1 = max(bbox1[2]-bbox1[0], bbox1[3]-bbox1[1])
|
||||
valid = (kpts0[nf, :, 2] > 0.3) & (kpts1[nf, :, 2] > 0.3)
|
||||
dist01_ = valid*distance01[nf] / bbox_size1
|
||||
dist10_ = valid*distance10[nf] / bbox_size0
|
||||
# 对于距离异常的点,阈值设定为0.1
|
||||
# 抑制掉置信度低的视角的点
|
||||
not_valid0 = np.where((dist01_ + dist10_ > DIST_THRES*2) & (kpts0[nf][:, -1] < kpts1[nf][:, -1]))[0]
|
||||
not_valid1 = np.where((dist01_ + dist10_ > DIST_THRES*2) & (kpts0[nf][:, -1] > kpts1[nf][:, -1]))[0]
|
||||
kpts0[nf, not_valid0] = 0.
|
||||
kpts1[nf, not_valid1] = 0.
|
||||
if len(not_valid0) > 0:
|
||||
print('[mirror] filter {} person 0: {}'.format(nf, not_valid0))
|
||||
if len(not_valid1) > 0:
|
||||
print('[mirror] filter {} person 1: {}'.format(nf, not_valid1))
|
||||
kpts1_ = flipPoint2D(kpts1)
|
||||
infos['keypoints2d'] = torch.stack([kpts0, kpts1_], dim=1)
|
||||
infos['vanish_point0'] = torch.Tensor(vpoint0)
|
||||
K = infos['K'][0]
|
||||
normal = np.linalg.inv(K) @ vpoint0.T
|
||||
normal = normal.T/np.linalg.norm(normal)
|
||||
print('>>> Calculating normal from keypoints: {}'.format(normal[0]))
|
||||
infos['normal'] = torch.Tensor(normal)
|
||||
mirror = torch.zeros((1, 4))
|
||||
# 计算镜子平面到相机的距离
|
||||
Th = body_params['Th']
|
||||
center = Th.mean(axis=1)
|
||||
# 相机原点到两个人中心的连线在normal上的投影
|
||||
dist = (center * normal).sum(axis=-1).mean()
|
||||
print('>>> Calculating distance from Th: {}'.format(dist))
|
||||
mirror[0, 3] = - dist # initial guess
|
||||
mirror[:, :3] = infos['normal']
|
||||
infos['mirror'] = mirror
|
||||
return body_params
|
||||
|
||||
class RemoveP1:
|
||||
def __init__(self, static) -> None:
|
||||
self.static = static
|
||||
|
||||
def __call__(self, body_model, body_params, infos):
|
||||
for key in body_params.keys():
|
||||
if key == 'id': continue
|
||||
body_params[key] = body_params[key][:, 0]
|
||||
return body_params
|
||||
|
||||
class Mirror:
|
||||
def __init__(self, key) -> None:
|
||||
self.key = key
|
||||
|
||||
def before(self, body_params):
|
||||
poses = body_params['poses'][:, 0]
|
||||
# append root
|
||||
poses = torch.cat([torch.zeros_like(poses[..., :3]), poses], dim=-1)
|
||||
poses_mirror = flipSMPLPoses(poses)
|
||||
poses = torch.cat([poses[:, None, 3:], poses_mirror[:, None, 3:]], dim=1)
|
||||
body_params['poses'] = poses
|
||||
return body_params
|
||||
|
||||
def after(self,):
|
||||
pass
|
||||
|
||||
def final(self, body_params):
|
||||
return self.before(body_params)
|
||||
|
||||
class Keypoints2DMirror(Keypoints2D):
|
||||
def __init__(self, mirror, opt_normal, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if not mirror.requires_grad:
|
||||
self.register_buffer('mirror', mirror)
|
||||
else:
|
||||
self.mirror = mirror
|
||||
self.opt_normal = opt_normal
|
||||
k2dall = kwargs['keypoints2d']
|
||||
size_all = []
|
||||
for nf in range(k2dall.shape[0]):
|
||||
for nper in range(2):
|
||||
kpts = k2dall[nf, nper]
|
||||
bbox = bbox_from_keypoints(kpts.cpu().numpy())
|
||||
bbox_size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
|
||||
size_all.append(bbox_size)
|
||||
size_all = np.array(size_all).reshape(-1, 2)
|
||||
scale = (size_all[:, 0] / size_all[:, 1]).mean()
|
||||
print('[loss] mean scale = {} from {} frames, use this to balance the two person'.format(scale, size_all.shape[0]))
|
||||
# ATTN: here we use v^2 to suppress the outlier detections
|
||||
self.conf = self.conf * self.conf
|
||||
self.conf[:, 1] *= scale*scale
|
||||
|
||||
def check(self, kpts_est, min_conf=0.3):
|
||||
with torch.no_grad():
|
||||
M = calc_mirror_transform(self.mirror)
|
||||
homo = torch.ones((*kpts_est.shape[:-1], 1), device=kpts_est.device)
|
||||
kpts_homo = torch.cat([kpts_est, homo], dim=-1)
|
||||
kpts_mirror = flipPoint2D(torch.matmul(M, kpts_homo.transpose(1, 2)).transpose(1, 2))
|
||||
kpts = torch.stack([kpts_est, kpts_mirror], dim=1)
|
||||
img_points = self.project(kpts)
|
||||
conf = (self.conf>min_conf)
|
||||
err = self.K[..., 0:1, 0].mean() * torch.norm(img_points - self.keypoints, dim=-1) * conf
|
||||
if len(err.shape) == 3:
|
||||
err = err.sum(dim=1)
|
||||
conf = conf.sum(dim=1)
|
||||
err = err.sum(dim=0)/(1e-5 + conf.sum(dim=0))
|
||||
return conf, err
|
||||
|
||||
def forward(self, kpts_est, **kwargs):
|
||||
if self.opt_normal:
|
||||
M = calc_mirror_transform(self.mirror)
|
||||
else:
|
||||
mirror = torch.cat([self.mirror[:, :3].detach(), self.mirror[:, 3:]], dim=1)
|
||||
M = calc_mirror_transform(mirror)
|
||||
homo = torch.ones((*kpts_est.shape[:-1], 1), device=kpts_est.device)
|
||||
kpts_homo = torch.cat([kpts_est, homo], dim=-1)
|
||||
kpts_mirror = flipPoint2D(torch.matmul(M, kpts_homo.transpose(1, 2)).transpose(1, 2))
|
||||
kpts = torch.stack([kpts_est, kpts_mirror], dim=1)
|
||||
return super().forward(kpts_est=kpts, **kwargs)
|
||||
|
||||
class MirrorPoses:
|
||||
def __init__(self, ref) -> None:
|
||||
self.ref = ref
|
||||
|
||||
def __call__(self, body_model, body_params, infos):
|
||||
# shapes: (nFrames, 2, nShapes)
|
||||
shapes = body_params['shapes'].mean(axis=0).mean(axis=0).reshape(1, 1, -1)
|
||||
poses = body_params['poses'][:, 0]
|
||||
# append root
|
||||
poses = np.concatenate([np.zeros([poses.shape[0], 3]), poses], axis=1)
|
||||
poses_mirror = flipSMPLPoses(poses)
|
||||
poses = np.concatenate([poses[:, None, 3:], poses_mirror[:, None, 3:]], axis=1)
|
||||
body_params['poses'] = poses
|
||||
body_params['shapes'] = shapes
|
||||
return body_params
|
||||
|
||||
class MirrorParams:
|
||||
def __init__(self, key) -> None:
|
||||
self.key = key
|
||||
|
||||
def start(self, body_params):
|
||||
if len(body_params['poses'].shape) == 2:
|
||||
return body_params
|
||||
for key in body_params.keys():
|
||||
if key == 'id': continue
|
||||
body_params[key] = body_params[key][:, 0]
|
||||
return body_params
|
||||
|
||||
def before(self, body_params):
|
||||
return body_params
|
||||
|
||||
def after(self,):
|
||||
pass
|
||||
|
||||
def final(self, body_params):
|
||||
device = body_params['poses'].device
|
||||
body_params = {key:val.detach().cpu().numpy() for key, val in body_params.items()}
|
||||
body_params['poses'] = np.hstack((np.zeros_like(body_params['poses'][:, :3]), body_params['poses']))
|
||||
params_mirror = flipSMPLParams(body_params, self.infos['mirror'].cpu().numpy())
|
||||
params = {}
|
||||
for key in params_mirror.keys():
|
||||
if key == 'shapes':
|
||||
params[key] = body_params[key][:, None]
|
||||
else:
|
||||
params[key] = np.concatenate([body_params[key][:, None], params_mirror[key][:, None]], axis=-2)
|
||||
params['poses'] = params['poses'][..., 3:]
|
||||
params = {key:torch.Tensor(val).to(device) for key, val in params.items()}
|
||||
return params
|
79
easymocap/multistage/synchronization.py
Normal file
79
easymocap/multistage/synchronization.py
Normal file
@ -0,0 +1,79 @@
|
||||
'''
|
||||
@ Date: 2022-03-11 12:13:01
|
||||
@ Author: Qing Shuai
|
||||
@ Mail: s_q@zju.edu.cn
|
||||
@ LastEditors: Qing Shuai
|
||||
@ LastEditTime: 2022-08-11 21:52:00
|
||||
@ FilePath: /EasyMocapPublic/easymocap/multistage/synchronization.py
|
||||
'''
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
class AddTime:
|
||||
def __init__(self, gt) -> None:
|
||||
self.gt = gt
|
||||
|
||||
def __call__(self, body_model, body_params, infos):
|
||||
nViews = infos['keypoints2d'].shape[1]
|
||||
offset = np.zeros((nViews,), dtype=np.float32)
|
||||
body_params['sync_offset'] = offset
|
||||
return body_params
|
||||
|
||||
class Interpolate:
|
||||
def __init__(self, actfn) -> None:
|
||||
# self.act_fn = lambda x: 2*torch.nn.functional.softsign(x)
|
||||
self.act_fn = lambda x: 2*torch.tanh(x)
|
||||
self.use0asref = False
|
||||
|
||||
def get_offset(self, time_offset):
|
||||
if self.use0asref:
|
||||
off = self.act_fn(torch.cat([torch.zeros(1, device=time_offset.device), time_offset[1:]]))
|
||||
else:
|
||||
off = self.act_fn(time_offset)
|
||||
return off
|
||||
|
||||
def start(self, body_params):
|
||||
return body_params
|
||||
|
||||
def before(self, body_params):
|
||||
off = self.get_offset(body_params['sync_offset'])
|
||||
nViews = off.shape[0]
|
||||
if len(body_params['poses'].shape) == 2:
|
||||
off = off[None, :, None]
|
||||
else:
|
||||
off = off[None, :, None, None]
|
||||
for key in body_params.keys():
|
||||
if key in ['sync_offset', 'shapes']:
|
||||
continue
|
||||
# TODO: Rh有正周期旋转的时候会有问题
|
||||
val = body_params[key]
|
||||
if key == 'Rh':
|
||||
pass
|
||||
if key in ['Th', 'poses']:
|
||||
velocity = torch.cat([val[1:2] - val[0:1], val[1:] - val[:-1]], dim=0)
|
||||
valnew = val[:, None] + off * velocity[:, None]
|
||||
# vel = velocity.detach().cpu().numpy()
|
||||
# import matplotlib.pyplot as plt
|
||||
# plt.plot(vel)
|
||||
# plt.show()
|
||||
# import ipdb;ipdb.set_trace()
|
||||
else:
|
||||
if len(val.shape) == 2:
|
||||
valnew = val[:, None].repeat(1, nViews, 1)
|
||||
elif len(val.shape) == 3:
|
||||
valnew = val[:, None].repeat(1, nViews, 1, 1)
|
||||
else:
|
||||
print('[warn] Unknown {} shape {}'.format(key, valnew.shape))
|
||||
import ipdb; ipdb.set_trace()
|
||||
valnew = valnew.reshape(-1, *val.shape[1:])
|
||||
body_params[key] = valnew
|
||||
return body_params
|
||||
|
||||
def after(self,):
|
||||
pass
|
||||
|
||||
def final(self, body_params):
|
||||
off = self.get_offset(body_params['sync_offset'])
|
||||
body_params = self.before(body_params)
|
||||
body_params['sync_offset'] = off
|
||||
return body_params
|
517
easymocap/multistage/torchgeometry.py
Normal file
517
easymocap/multistage/torchgeometry.py
Normal file
@ -0,0 +1,517 @@
|
||||
"""
|
||||
useful functions to perform conversion between rotation in different format(quaternion, rotation_matrix, euler_angle, axis_angle)
|
||||
quaternion representation: (w,x,y,z)
|
||||
code reference: torchgeometry, kornia, https://github.com/MandyMo/pytorch_HMR.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
import numpy as np
|
||||
|
||||
|
||||
# Conversions between different rotation representations, quaternion,rotation matrix,euler and axis angle.
|
||||
|
||||
def rot6d_to_rotation_matrix(rot6d):
|
||||
"""
|
||||
Convert 6D rotation representation to 3x3 rotation matrix.
|
||||
Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
|
||||
Args:
|
||||
rot6d: torch tensor of shape (batch_size, 6) of 6d rotation representations.
|
||||
Returns:
|
||||
rotation_matrix: torch tensor of shape (batch_size, 3, 3) of corresponding rotation matrices.
|
||||
"""
|
||||
x = rot6d.view(-1, 3, 2)
|
||||
a1 = x[:, :, 0]
|
||||
a2 = x[:, :, 1]
|
||||
b1 = F.normalize(a1)
|
||||
b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
|
||||
b3 = torch.cross(b1, b2)
|
||||
return torch.stack((b1, b2, b3), dim=-1)
|
||||
|
||||
|
||||
def rotation_matrix_to_rot6d(rotation_matrix):
|
||||
"""
|
||||
Convert 3x3 rotation matrix to 6D rotation representation.
|
||||
Args:
|
||||
rotation_matrix: torch tensor of shape (batch_size, 3, 3) of corresponding rotation matrices.
|
||||
Returns:
|
||||
rot6d: torch tensor of shape (batch_size, 6) of 6d rotation representations.
|
||||
"""
|
||||
v1 = rotation_matrix[:, :, 0:1]
|
||||
v2 = rotation_matrix[:, :, 1:2]
|
||||
rot6d = torch.cat([v1, v2], dim=-1).reshape(v1.shape[0], 6)
|
||||
return rot6d
|
||||
|
||||
|
||||
def quaternion_to_rotation_matrix(quaternion):
|
||||
"""
|
||||
Convert quaternion coefficients to rotation matrix.
|
||||
Args:
|
||||
quaternion: torch tensor of shape (batch_size, 4) in (w, x, y, z) representation.
|
||||
Returns:
|
||||
rotation matrix corresponding to the quaternion, torch tensor of shape (batch_size, 3, 3)
|
||||
"""
|
||||
|
||||
norm_quaternion = quaternion
|
||||
norm_quaternion = norm_quaternion / \
|
||||
norm_quaternion.norm(p=2, dim=1, keepdim=True)
|
||||
w, x, y, z = norm_quaternion[:, 0], norm_quaternion[:,
|
||||
1], norm_quaternion[:, 2], norm_quaternion[:, 3]
|
||||
|
||||
batch_size = quaternion.size(0)
|
||||
w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
|
||||
wx, wy, wz = w*x, w*y, w*z
|
||||
xy, xz, yz = x*y, x*z, y*z
|
||||
|
||||
rotation_matrix = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
|
||||
2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
|
||||
2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(batch_size, 3, 3)
|
||||
return rotation_matrix
|
||||
|
||||
|
||||
def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
|
||||
"""
|
||||
Convert rotation matrix to corresponding quaternion
|
||||
Args:
|
||||
rotation_matrix: torch tensor of shape (batch_size, 3, 3)
|
||||
Returns:
|
||||
quaternion: torch tensor of shape(batch_size, 4) in (w, x, y, z) representation.
|
||||
"""
|
||||
rmat_t = torch.transpose(rotation_matrix, 1, 2)
|
||||
|
||||
mask_d2 = rmat_t[:, 2, 2] < eps
|
||||
|
||||
mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
|
||||
mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
|
||||
|
||||
t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
|
||||
q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
|
||||
t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
|
||||
rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1)
|
||||
t0_rep = t0.repeat(4, 1).t()
|
||||
|
||||
t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
|
||||
q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
|
||||
rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
|
||||
t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1)
|
||||
t1_rep = t1.repeat(4, 1).t()
|
||||
|
||||
t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
|
||||
q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
|
||||
rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
|
||||
rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1)
|
||||
t2_rep = t2.repeat(4, 1).t()
|
||||
|
||||
t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
|
||||
q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
|
||||
rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
|
||||
rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1)
|
||||
t3_rep = t3.repeat(4, 1).t()
|
||||
|
||||
mask_c0 = mask_d2 * mask_d0_d1
|
||||
mask_c1 = mask_d2 * (~ mask_d0_d1)
|
||||
mask_c2 = (~ mask_d2) * mask_d0_nd1
|
||||
mask_c3 = (~ mask_d2) * (~ mask_d0_nd1)
|
||||
mask_c0 = mask_c0.view(-1, 1).type_as(q0)
|
||||
mask_c1 = mask_c1.view(-1, 1).type_as(q1)
|
||||
mask_c2 = mask_c2.view(-1, 1).type_as(q2)
|
||||
mask_c3 = mask_c3.view(-1, 1).type_as(q3)
|
||||
|
||||
q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
|
||||
q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa
|
||||
t2_rep * mask_c2 + t3_rep * mask_c3) # noqa
|
||||
q *= 0.5
|
||||
return q
|
||||
|
||||
|
||||
def quaternion_to_euler(quaternion, order, epsilon=0):
|
||||
"""
|
||||
Convert quaternion to euler angles.
|
||||
Args:
|
||||
quaternion: torch tensor of shape (batch_size, 4) in (w, x, y, z) representation.
|
||||
order: euler angle representation order, 'zyx' etc.
|
||||
epsilon:
|
||||
Returns:
|
||||
euler: torch tensor of shape (batch_size, 3) in order.
|
||||
"""
|
||||
assert quaternion.shape[-1] == 4
|
||||
original_shape = list(quaternion.shape)
|
||||
original_shape[-1] = 3
|
||||
q = quaternion.contiguous().view(-1, 4)
|
||||
q0 = q[:, 0]
|
||||
q1 = q[:, 1]
|
||||
q2 = q[:, 2]
|
||||
q3 = q[:, 3]
|
||||
|
||||
if order == 'xyz':
|
||||
x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2*(q1 * q1 + q2 * q2))
|
||||
y = torch.asin(torch.clamp(
|
||||
2 * (q1 * q3 + q0 * q2), -1+epsilon, 1-epsilon))
|
||||
z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2*(q2 * q2 + q3 * q3))
|
||||
elif order == 'yzx':
|
||||
x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2*(q1 * q1 + q3 * q3))
|
||||
y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2*(q2 * q2 + q3 * q3))
|
||||
z = torch.asin(torch.clamp(
|
||||
2 * (q1 * q2 + q0 * q3), -1+epsilon, 1-epsilon))
|
||||
elif order == 'zxy':
|
||||
x = torch.asin(torch.clamp(
|
||||
2 * (q0 * q1 + q2 * q3), -1+epsilon, 1-epsilon))
|
||||
y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2*(q1 * q1 + q2 * q2))
|
||||
z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2*(q1 * q1 + q3 * q3))
|
||||
elif order == 'xzy':
|
||||
x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2*(q1 * q1 + q3 * q3))
|
||||
y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2*(q2 * q2 + q3 * q3))
|
||||
z = torch.asin(torch.clamp(
|
||||
2 * (q0 * q3 - q1 * q2), -1+epsilon, 1-epsilon))
|
||||
elif order == 'yxz':
|
||||
x = torch.asin(torch.clamp(
|
||||
2 * (q0 * q1 - q2 * q3), -1+epsilon, 1-epsilon))
|
||||
y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2*(q1 * q1 + q2 * q2))
|
||||
z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2*(q1 * q1 + q3 * q3))
|
||||
elif order == 'zyx':
|
||||
x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2*(q1 * q1 + q2 * q2))
|
||||
y = torch.asin(torch.clamp(
|
||||
2 * (q0 * q2 - q1 * q3), -1+epsilon, 1-epsilon))
|
||||
z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2*(q2 * q2 + q3 * q3))
|
||||
else:
|
||||
raise Exception('unsupported euler order!')
|
||||
|
||||
return torch.stack((x, y, z), dim=1).view(original_shape)
|
||||
|
||||
|
||||
def euler_to_quaternion(euler, order):
|
||||
"""
|
||||
Convert euler angles to quaternion.
|
||||
Args:
|
||||
euler: torch tensor of shape (batch_size, 3) in order.
|
||||
order:
|
||||
Returns:
|
||||
quaternion: torch tensor of shape (batch_size, 4) in (w, x, y, z) representation.
|
||||
"""
|
||||
assert euler.shape[-1] == 3
|
||||
original_shape = list(euler.shape)
|
||||
original_shape[-1] = 4
|
||||
e = euler.reshape(-1, 3)
|
||||
|
||||
x = e[:, 0]
|
||||
y = e[:, 1]
|
||||
z = e[:, 2]
|
||||
|
||||
rx = torch.stack((torch.cos(x/2), torch.sin(x/2),
|
||||
torch.zeros_like(x), torch.zeros_like(x)), dim=1)
|
||||
ry = torch.stack((torch.cos(y/2), torch.zeros_like(y),
|
||||
torch.sin(y/2), torch.zeros_like(y)), dim=1)
|
||||
rz = torch.stack((torch.cos(z/2), torch.zeros_like(z),
|
||||
torch.zeros_like(z), torch.sin(z/2)), dim=1)
|
||||
|
||||
result = None
|
||||
for coord in order:
|
||||
if coord == 'x':
|
||||
r = rx
|
||||
elif coord == 'y':
|
||||
r = ry
|
||||
elif coord == 'z':
|
||||
r = rz
|
||||
else:
|
||||
raise Exception('unsupported euler order!')
|
||||
if result is None:
|
||||
result = r
|
||||
else:
|
||||
result = quaternion_mul(result, r)
|
||||
|
||||
# Reverse antipodal representation to have a non-negative "w"
|
||||
if order in ['xyz', 'yzx', 'zxy']:
|
||||
result *= -1
|
||||
|
||||
return result.reshape(original_shape)
|
||||
|
||||
|
||||
def quaternion_to_axis_angle(quaternion):
|
||||
"""
|
||||
Convert quaternion to axis angle.
|
||||
based on: https://github.com/facebookresearch/QuaterNet/blob/master/common/quaternion.py#L138
|
||||
Args:
|
||||
quaternion: torch tensor of shape (batch_size, 4) in (w, x, y, z) representation.
|
||||
Returns:
|
||||
axis_angle: torch tensor of shape (batch_size, 3)
|
||||
"""
|
||||
epsilon = 1.e-8
|
||||
if not torch.is_tensor(quaternion):
|
||||
raise TypeError("Input type is not a torch.Tensor. Got {}".format(
|
||||
type(quaternion)))
|
||||
|
||||
if not quaternion.shape[-1] == 4:
|
||||
raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}"
|
||||
.format(quaternion.shape))
|
||||
# unpack input and compute conversion
|
||||
q1: torch.Tensor = quaternion[..., 1]
|
||||
q2: torch.Tensor = quaternion[..., 2]
|
||||
q3: torch.Tensor = quaternion[..., 3]
|
||||
sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3
|
||||
|
||||
sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta+epsilon)
|
||||
cos_theta: torch.Tensor = quaternion[..., 0]
|
||||
two_theta: torch.Tensor = 2.0 * torch.where(
|
||||
cos_theta < 0.0,
|
||||
torch.atan2(-sin_theta, -cos_theta),
|
||||
torch.atan2(sin_theta, cos_theta))
|
||||
|
||||
k_pos: torch.Tensor = two_theta / sin_theta
|
||||
k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta)
|
||||
k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
|
||||
|
||||
angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3]
|
||||
angle_axis[..., 0] += q1 * k
|
||||
angle_axis[..., 1] += q2 * k
|
||||
angle_axis[..., 2] += q3 * k
|
||||
return angle_axis
|
||||
|
||||
|
||||
def axis_angle_to_quaternion(axis_angle):
|
||||
"""
|
||||
Convert axis angle to quaternion.
|
||||
Args:
|
||||
axis_angle: torch tensor of shape (batch_size, 3)
|
||||
Returns:
|
||||
quaternion: torch tensor of shape (batch_size, 4) in (w, x, y, z) representation.
|
||||
"""
|
||||
rotation_matrix = axis_angle_to_rotation_matrix(axis_angle)
|
||||
return rotation_matrix_to_quaternion(rotation_matrix)
|
||||
|
||||
|
||||
def axis_angle_to_rotation_matrix(axis_angle):
|
||||
"""
|
||||
Convert axis-angle representation to rotation matrix.
|
||||
Args:
|
||||
axis_angle: torch tensor of shape (batch_size, 3).
|
||||
Returns:
|
||||
rotation_matrix: torch tensor of shape (batch_size, 3, 3) of corresponding rotation matrices.
|
||||
"""
|
||||
|
||||
l1_norm = torch.norm(axis_angle+1e-8, p=2, dim=1)
|
||||
angle = torch.unsqueeze(l1_norm, dim=-1)
|
||||
normalized = torch.div(axis_angle, angle)
|
||||
angle = angle * 0.5
|
||||
v_cos = torch.cos(angle)
|
||||
v_sin = torch.sin(angle)
|
||||
quaternion = torch.cat([v_cos, v_sin*normalized], dim=1)
|
||||
return quaternion_to_rotation_matrix(quaternion)
|
||||
|
||||
|
||||
def rotation_matrix_to_axis_angle(rotation_matrix):
|
||||
quaternion = rotation_matrix_to_quaternion(rotation_matrix)
|
||||
return quaternion_to_axis_angle(quaternion)
|
||||
|
||||
|
||||
def rotation_matrix_to_euler(rotation_matrix, order):
|
||||
quaternion = rotation_matrix_to_quaternion(rotation_matrix)
|
||||
return quaternion_to_euler(quaternion, order)
|
||||
|
||||
|
||||
def euler_to_rotation_matrix(euler, order):
|
||||
quaternion = euler_to_quaternion(euler, order)
|
||||
return quaternion_to_rotation_matrix(quaternion)
|
||||
|
||||
|
||||
def axis_angle_to_euler(axis_angle, order):
|
||||
quaternion = axis_angle_to_quaternion(axis_angle)
|
||||
return quaternion_to_euler(quaternion, order)
|
||||
|
||||
|
||||
def euler_to_axis_angle(euler, order):
|
||||
quaternion = euler_to_quaternion(euler, order)
|
||||
return quaternion_to_axis_angle(quaternion)
|
||||
|
||||
# rotation operations
|
||||
|
||||
|
||||
def quaternion_mul(q, r):
|
||||
"""
|
||||
Multiply quaternion(s) q with quaternion(s) r.
|
||||
Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
|
||||
Returns q*r as a tensor of shape (*, 4).
|
||||
"""
|
||||
assert q.shape[-1] == 4
|
||||
assert r.shape[-1] == 4
|
||||
|
||||
original_shape = q.shape
|
||||
|
||||
# Compute outer product
|
||||
terms = torch.bmm(r.contiguous().view(-1, 4, 1),
|
||||
q.contiguous().view(-1, 1, 4))
|
||||
|
||||
w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
|
||||
x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
|
||||
y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
|
||||
z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
|
||||
return torch.stack((w, x, y, z), dim=1).view(original_shape)
|
||||
|
||||
|
||||
def rotate_vec_by_quaternion(v, q):
|
||||
"""
|
||||
Rotate vector(s) v about the rotation described by quaternion(s) q.
|
||||
Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
|
||||
where * denotes any number of dimensions.
|
||||
Returns a tensor of shape (*, 3).
|
||||
"""
|
||||
assert q.shape[-1] == 4
|
||||
assert v.shape[-1] == 3
|
||||
assert q.shape[:-1] == v.shape[:-1]
|
||||
|
||||
original_shape = list(v.shape)
|
||||
q = q.contiguous().view(-1, 4)
|
||||
v = v.view(-1, 3)
|
||||
|
||||
qvec = q[:, 1:]
|
||||
uv = torch.cross(qvec, v, dim=1)
|
||||
uuv = torch.cross(qvec, uv, dim=1)
|
||||
return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
|
||||
|
||||
|
||||
def quaternion_fix(quaternion):
|
||||
"""
|
||||
Enforce quaternion continuity across the time dimension by selecting
|
||||
the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
|
||||
between two consecutive frames.
|
||||
Args:
|
||||
quaternion: torch tensor of shape (batch_size, 4)
|
||||
Returns:
|
||||
quaternion: torch tensor of shape (batch_size, 4)
|
||||
"""
|
||||
quaternion_fixed = quaternion.clone()
|
||||
dot_products = torch.sum(quaternion[1:]*quaternion[:-1],dim=-1)
|
||||
mask = dot_products < 0
|
||||
mask = (torch.cumsum(mask, dim=0) % 2).bool()
|
||||
quaternion_fixed[1:][mask] *= -1
|
||||
return quaternion_fixed
|
||||
|
||||
|
||||
def quaternion_inverse(quaternion):
|
||||
q_conjugate = quaternion.clone()
|
||||
q_conjugate[::, 1:] * -1
|
||||
q_norm = quaternion[::, 1:].norm(dim=-1) + quaternion[::, 0]**2
|
||||
return q_conjugate/q_norm.unsqueeze(-1)
|
||||
|
||||
|
||||
def quaternion_lerp(q1, q2, t):
|
||||
q = (1-t)*q1 + t*q2
|
||||
q = q/q.norm(dim=-1).unsqueeze(-1)
|
||||
return q
|
||||
|
||||
def geodesic_dist(q1,q2):
|
||||
"""
|
||||
@q1: torch tensor of shape (frame, joints, 4) quaternion
|
||||
@q2: same as q1
|
||||
@output: torch tensor of shape (frame, joints)
|
||||
"""
|
||||
q1_conjugate = q1.clone()
|
||||
q1_conjugate[:,:,1:] *= -1
|
||||
q1_norm = q1[:,:,1:].norm(dim=-1) + q1[:,:,0]**2
|
||||
q1_inverse = q1_conjugate/q1_norm.unsqueeze(dim=-1)
|
||||
q_between = quaternion_mul(q1_inverse,q2)
|
||||
geodesic_dist = quaternion_to_axis_angle(q_between).norm(dim=-1)
|
||||
return geodesic_dist
|
||||
|
||||
def get_extrinsic(translation, rotation):
|
||||
batch_size = translation.shape[0]
|
||||
pose = torch.zeros((batch_size, 4, 4))
|
||||
pose[:,:3, :3] = rotation
|
||||
pose[:,:3, 3] = translation
|
||||
pose[:,3, 3] = 1
|
||||
extrinsic = torch.inverse(pose)
|
||||
return extrinsic[:,:3, 3], extrinsic[:,:3, :3]
|
||||
|
||||
def euler_fix_old(euler):
|
||||
frame_num = euler.shape[0]
|
||||
joint_num = euler.shape[1]
|
||||
for l in range(3):
|
||||
for j in range(joint_num):
|
||||
overall_add = 0.
|
||||
for i in range(1,frame_num):
|
||||
add1 = overall_add
|
||||
add2 = overall_add + 2*np.pi
|
||||
add3 = overall_add - 2*np.pi
|
||||
previous = euler[i-1,j,l]
|
||||
value1 = euler[i,j,l] + add1
|
||||
value2 = euler[i,j,l] + add2
|
||||
value3 = euler[i,j,l] + add3
|
||||
e1 = torch.abs(value1 - previous)
|
||||
e2 = torch.abs(value2 - previous)
|
||||
e3 = torch.abs(value3 - previous)
|
||||
if (e1 <= e2) and (e1 <= e3):
|
||||
euler[i,j,l] = value1
|
||||
overall_add = add1
|
||||
if (e2 <= e1) and (e2 <= e3):
|
||||
euler[i, j, l] = value2
|
||||
overall_add = add2
|
||||
if (e3 <= e1) and (e3 <= e2):
|
||||
euler[i, j, l] = value3
|
||||
overall_add = add3
|
||||
return euler
|
||||
|
||||
def euler_fix(euler,rotation_order='zyx'):
|
||||
frame_num = euler.shape[0]
|
||||
joint_num = euler.shape[1]
|
||||
euler_new = euler.clone()
|
||||
for j in range(joint_num):
|
||||
euler_new[:,j] = euler_filter(euler[:,j],rotation_order)
|
||||
return euler_new
|
||||
|
||||
'''
|
||||
euler filter from https://github.com/wesen/blender-euler-filter/blob/master/euler_filter.py.
|
||||
'''
|
||||
def euler_distance(e1, e2):
|
||||
return abs(e1[0] - e2[0]) + abs(e1[1] - e2[1]) + abs(e1[2] - e2[2])
|
||||
|
||||
|
||||
def euler_axis_index(axis):
|
||||
if axis == 'x':
|
||||
return 0
|
||||
if axis == 'y':
|
||||
return 1
|
||||
if axis == 'z':
|
||||
return 2
|
||||
return None
|
||||
|
||||
def flip_euler(euler, rotation_mode):
|
||||
ret = euler.clone()
|
||||
inner_axis = rotation_mode[0]
|
||||
outer_axis = rotation_mode[2]
|
||||
middle_axis = rotation_mode[1]
|
||||
|
||||
ret[euler_axis_index(inner_axis)] += np.pi
|
||||
ret[euler_axis_index(outer_axis)] += np.pi
|
||||
ret[euler_axis_index(middle_axis)] *= -1
|
||||
ret[euler_axis_index(middle_axis)] += np.pi
|
||||
return ret
|
||||
|
||||
def naive_flip_diff(a1, a2):
|
||||
while abs(a1 - a2) >= np.pi+1e-5:
|
||||
if a1 < a2:
|
||||
a2 -= 2 * np.pi
|
||||
else:
|
||||
a2 += 2 * np.pi
|
||||
|
||||
return a2
|
||||
|
||||
def euler_filter(euler,rotation_order):
|
||||
frame_num = euler.shape[0]
|
||||
if frame_num <= 1:
|
||||
return euler
|
||||
euler_fix = euler.clone()
|
||||
prev = euler[0]
|
||||
for i in range(1,frame_num):
|
||||
e = euler[i]
|
||||
for d in range(3):
|
||||
e[d] = naive_flip_diff(prev[d],e[d])
|
||||
fe = flip_euler(e,rotation_order)
|
||||
for d in range(3):
|
||||
fe[d] = naive_flip_diff(prev[d],fe[d])
|
||||
|
||||
de = euler_distance(prev,e)
|
||||
dfe = euler_distance(prev,fe)
|
||||
if dfe < de:
|
||||
e = fe
|
||||
prev = e
|
||||
euler_fix[i] = e
|
||||
return euler_fix
|
97
easymocap/multistage/totalfitting.py
Normal file
97
easymocap/multistage/totalfitting.py
Normal file
@ -0,0 +1,97 @@
|
||||
'''
|
||||
@ Date: 2022-07-28 14:39:23
|
||||
@ Author: Qing Shuai
|
||||
@ Mail: s_q@zju.edu.cn
|
||||
@ LastEditors: Qing Shuai
|
||||
@ LastEditTime: 2022-08-12 21:42:12
|
||||
@ FilePath: /EasyMocapPublic/easymocap/multistage/totalfitting.py
|
||||
'''
|
||||
import torch
|
||||
|
||||
from ..bodymodel.lbs import batch_rodrigues
|
||||
from .torchgeometry import rotation_matrix_to_axis_angle, rotation_matrix_to_quaternion, quaternion_to_rotation_matrix, quaternion_to_axis_angle
|
||||
import numpy as np
|
||||
from .base_ops import BeforeAfterBase
|
||||
|
||||
def compute_twist_rotation(rotation_matrix, twist_axis):
|
||||
'''
|
||||
Compute the twist component of given rotation and twist axis
|
||||
https://stackoverflow.com/questions/3684269/component-of-a-quaternion-rotation-around-an-axis
|
||||
Parameters
|
||||
----------
|
||||
rotation_matrix : Tensor (B, 3, 3,)
|
||||
The rotation to convert
|
||||
twist_axis : Tensor (B, 3,)
|
||||
The twist axis
|
||||
Returns
|
||||
-------
|
||||
Tensor (B, 3, 3)
|
||||
The twist rotation
|
||||
'''
|
||||
quaternion = rotation_matrix_to_quaternion(rotation_matrix)
|
||||
|
||||
twist_axis = twist_axis / (torch.norm(twist_axis, dim=1, keepdim=True) + 1e-9)
|
||||
|
||||
projection = torch.einsum('bi,bi->b', twist_axis, quaternion[:, 1:]).unsqueeze(-1) * twist_axis
|
||||
|
||||
twist_quaternion = torch.cat([quaternion[:, 0:1], projection], dim=1)
|
||||
twist_quaternion = twist_quaternion / (torch.norm(twist_quaternion, dim=1, keepdim=True) + 1e-9)
|
||||
|
||||
twist_rotation = quaternion_to_rotation_matrix(twist_quaternion)
|
||||
|
||||
twist_aa = quaternion_to_axis_angle(twist_quaternion)
|
||||
|
||||
twist_angle = torch.sum(twist_aa, dim=1, keepdim=True) / torch.sum(twist_axis, dim=1, keepdim=True)
|
||||
|
||||
return twist_rotation, twist_angle
|
||||
|
||||
class ClearTwist(BeforeAfterBase):
|
||||
def start(self, body_params):
|
||||
idx_elbow = [18-1, 19-1]
|
||||
for idx in idx_elbow:
|
||||
# x
|
||||
body_params['poses'][:, 3*idx] = 0.
|
||||
# z
|
||||
body_params['poses'][:, 3*idx+2] = 0.
|
||||
idx_wrist = [20-1, 21-1]
|
||||
for idx in idx_wrist:
|
||||
body_params['poses'][:, 3*idx:3*idx+3] = 0.
|
||||
return body_params
|
||||
|
||||
class SolveTwist(BeforeAfterBase):
|
||||
def __init__(self, body_model=None) -> None:
|
||||
self.body_model = body_model
|
||||
|
||||
def final(self, body_params):
|
||||
T_joints, T_vertices = self.body_model.transform(body_params)
|
||||
# This transform don't consider RT
|
||||
R = batch_rodrigues(body_params['Rh'])
|
||||
template = self.body_model.keypoints({'shapes': body_params['shapes'],
|
||||
'poses': torch.zeros_like(body_params['poses'])},
|
||||
only_shape=True, return_smpl_joints=True)
|
||||
config = {
|
||||
'left': {
|
||||
'index_smpl': 20,
|
||||
'index_elbow_smpl': 18,
|
||||
'R_global': 'R_handl3d',
|
||||
'axis': torch.Tensor([[1., 0., 0.]]).to(device=T_joints.device),
|
||||
},
|
||||
'right': {
|
||||
'index_smpl': 21,
|
||||
'index_elbow_smpl': 19,
|
||||
'R_global': 'R_handr3d',
|
||||
'axis': torch.Tensor([[-1., 0., 0.]]).to(device=T_joints.device),
|
||||
}
|
||||
}
|
||||
for key in ['left', 'right']:
|
||||
cfg = config[key]
|
||||
R_wrist_add = batch_rodrigues(body_params[cfg['R_global']])
|
||||
idx_elbow = cfg['index_elbow_smpl']
|
||||
idx_wrist = cfg['index_smpl']
|
||||
pred_parent_elbow = R @ T_joints[..., idx_elbow, :3, :3]
|
||||
pred_parent_wrist = R @ T_joints[..., idx_wrist, :3, :3]
|
||||
pred_global_wrist = torch.bmm(R_wrist_add, pred_parent_wrist)
|
||||
pred_local_wrist = torch.bmm(pred_parent_wrist.transpose(-1, -2), pred_global_wrist)
|
||||
axis = rotation_matrix_to_axis_angle(pred_local_wrist)
|
||||
body_params['poses'][..., 3*idx_wrist-3:3*idx_wrist] = axis
|
||||
return body_params
|
Loading…
Reference in New Issue
Block a user