🚧 create the new stype of fitting

This commit is contained in:
Qing Shuai 2022-08-21 16:07:06 +08:00
parent 5bc4b113ba
commit a0127f712a
12 changed files with 4026 additions and 0 deletions

View File

@ -0,0 +1,308 @@
# 这个脚本用于通用的多阶段的优化
import numpy as np
import torch
from ..annotator.file_utils import read_json
from ..mytools import Timer
from .lossbase import print_table
from ..config.baseconfig import load_object
from ..bodymodel.base import Params
from torch.utils.data import DataLoader
from tqdm import tqdm
def dict_of_numpy_to_tensor(body_model, body_params, *args, **kwargs):
device = body_model.device
body_params = {key:torch.Tensor(val).to(device) for key, val in body_params.items()}
return body_params
class AddExtra:
def __init__(self, vals) -> None:
self.vals = vals
def __call__(self, body_model, body_params, *args, **kwargs):
shapes = body_params['poses'].shape[:-1]
for key in self.vals:
if key in body_params.keys():
continue
if key.startswith('R_') or key.startswith('T_'):
val = np.zeros((*shapes, 3), dtype=np.float32)
body_params[key] = val
return body_params
def dict_of_tensor_to_numpy(body_params):
body_params = {key:val.detach().cpu().numpy() for key, val in body_params.items()}
return body_params
def grad_require(params, flag=False):
if isinstance(params, list):
for par in params:
par.requires_grad = flag
elif isinstance(params, dict):
for key, par in params.items():
par.requires_grad = flag
def rel_change(prev_val, curr_val):
return (prev_val - curr_val) / max([1e-5, abs(prev_val), abs(curr_val)])
def make_optimizer(opt_params, optim_type='lbfgs', max_iter=20,
lr=1e-3, betas=(0.9, 0.999), weight_decay=0.0, **kwargs):
if isinstance(opt_params, dict):
# LBFGS 不支持参数字典
opt_params = list(opt_params.values())
if optim_type == 'lbfgs':
from ..pyfitting.lbfgs import LBFGS
optimizer = LBFGS(opt_params, line_search_fn='strong_wolfe', max_iter=max_iter, **kwargs)
elif optim_type == 'adam':
optimizer = torch.optim.Adam(opt_params, lr=lr, betas=betas, weight_decay=weight_decay)
else:
raise NotImplementedError
return optimizer
def make_lossfuncs(stage, infos, device, irepeat, verbose=False):
loss_funcs, weights = {}, {}
for key, val in stage.loss.items():
loss_args = dict(val.args)
if 'infos' in val.keys():
for k in val.infos:
loss_args[k] = infos[k]
module = load_object(val.module, loss_args)
module.to(device)
if 'weights' in val.keys():
weights[key] = val.weights[irepeat]
else:
weights[key] = val.weight
if weights[key] < 0:
weights.pop(key)
else:
loss_funcs[key] = module
if verbose or True:
print('Loss functions: ')
for key, func in loss_funcs.items():
print(' - {:15s}: {}, {}'.format(key, weights[key], func))
return loss_funcs, weights
def make_before_after(before_after, body_model, body_params, infos):
modules = []
for key, val in before_after.items():
args = dict(val.args)
if 'body_model' in args.keys():
args['body_model'] = body_model
try:
module = load_object(val.module, args)
except:
print('[Fitting] Failed to load module {}'.format(key))
raise NotImplementedError
module.infos = infos
modules.append(module)
return modules
def process(start_or_end, body_model, body_params, infos):
for key, val in start_or_end.items():
if isinstance(val, dict):
module = load_object(val.module, val.args)
else:
if key == 'convert' and val == 'numpy_to_tensor':
module = dict_of_numpy_to_tensor
if key == 'add':
module = AddExtra(val)
body_params = module(body_model, body_params, infos)
return body_params
def plot_meshes(img, meshes, K, R, T):
import cv2
mesh_camera = []
for mesh in meshes:
vertices = mesh['vertices'] @ R.T + T.T
v2d = vertices @ K.T
v2d[:, :2] = v2d[:, :2] / v2d[:, 2:3]
lw=1
col=(0,0,255)
for (x, y, d) in v2d[::10]:
cv2.circle(img, (int(x+0.5), int(y+0.5)), lw*2, col, -1)
return img
class MultiStage:
def __init__(self, batch_size, optimizer, monitor, initialize, stages) -> None:
self.batch_size = batch_size
self.optimizer_args = optimizer
self.monitor = monitor
self.initialize = initialize
self.stages = stages
def make_closure(self, body_model, body_params, infos, loss_funcs, weights, optimizer, before_after_module):
def closure(debug=False, ret_kpts=False):
# 0. Prepare body parameters => new_params
optimizer.zero_grad()
new_params = body_params.copy()
for module in before_after_module:
new_params = module.before(new_params)
# 1. Compute keypoints => kpts_est
poses_full = body_model.extend_poses(**new_params)
kpts_est = body_model(return_verts=False, return_tensor=True, **new_params)
if ret_kpts:
return kpts_est
verts_est = None
# 2. Compute loss => loss_dict
loss_dict = {}
for key, loss_func in loss_funcs.items():
if key.startswith('v'):
if verts_est is None:
verts_est = body_model(return_verts=True, return_tensor=True, **new_params)
loss_dict[key] = loss_func(verts_est=verts_est, **new_params, **infos)
elif key.startswith('pf-'):
loss_dict[key] = loss_func(poses_full=poses_full, **new_params, **infos)
else:
loss_dict[key] = loss_func(kpts_est=kpts_est, **new_params, **infos)
loss = sum([loss_dict[key]*weights[key]
for key in loss_dict.keys()])
if debug:
return loss_dict
loss.backward()
return loss
return closure
def optimizer_step(self, optimizer, closure, weights):
prev_loss = None
for iter_ in range(self.monitor.maxiters):
with torch.no_grad():
loss_dict = closure(debug=True)
if self.monitor.printloss or (self.monitor.verbose and iter_ == 0):
print('{:-6d}: '.format(iter_) + ' '.join([key + ' %f'%(loss_dict[key].item()*weights[key]) for key in loss_dict.keys()]))
loss = optimizer.step(closure)
# check the loss
if torch.isnan(loss).sum() > 0:
print('[optimize] NaN loss value, stopping!')
break
if torch.isinf(loss).sum() > 0:
print('[optimize] Infinite loss value, stopping!')
break
# check the delta
if iter_ > 0 and prev_loss is not None:
loss_rel_change = rel_change(prev_loss, loss.item())
if loss_rel_change <= self.monitor.ftol:
if self.monitor.printloss or self.monitor.verbose:
print('{:-6d}: '.format(iter_) + ' '.join([key + ' %f'%(loss_dict[key].item()*weights[key]) for key in loss_dict.keys()]))
break
# log
if self.monitor.vis2d:
pass
if self.monitor.vis3d:
pass
prev_loss = loss.item()
return True
def fit_stage(self, body_model, body_params, infos, stage, irepeat):
# 单独拟合一个stage, 返回body_params
optimizer_args = stage.get('optimizer', self.optimizer_args)
dtype, device = body_model.dtype, body_model.device
body_params = process(stage.get('at_start', {'convert': 'numpy_to_tensor'}), body_model, body_params, infos)
opt_params = {}
if 'optimize' in stage.keys():
optimize_names = stage.optimize
else:
optimize_names = stage.optimizes[irepeat]
for key in optimize_names:
if key in infos.keys(): # 优化的参数
infos[key] = infos[key].to(device)
opt_params[key] = infos[key]
elif key in body_params.keys():
opt_params[key] = body_params[key]
else:
raise ValueError('{} is not in infos or body_params'.format(key))
if self.monitor.verbose:
print('[optimize] optimizing {}'.format(optimize_names))
for key, val in opt_params.items():
infos['init_'+key] = val.clone().detach().cpu()
# initialize keypoints
with torch.no_grad():
kpts_est = body_model.keypoints(body_params)
infos['init_kpts_est'] = kpts_est.clone().detach().cpu()
before_after_module = make_before_after(stage.get('before_after', {}), body_model, body_params, infos)
for module in before_after_module:
# Input to this module is tensor
body_params = module.start(body_params)
grad_require(opt_params, True)
optimizer = make_optimizer(opt_params, **optimizer_args)
loss_funcs, weights = make_lossfuncs(stage, infos, device, irepeat, self.monitor.verbose)
closure = self.make_closure(body_model, body_params, infos, loss_funcs, weights, optimizer, before_after_module)
if self.monitor.check:
new_params = body_params.copy()
for module in before_after_module:
new_params = module.before(new_params)
kpts_est = body_model.keypoints(new_params)
for key, loss in loss_funcs.items():
loss.check_at_start(kpts_est=kpts_est, **new_params)
self.optimizer_step(optimizer, closure, weights)
grad_require(opt_params, False)
if self.monitor.check:
new_params = body_params.copy()
for module in before_after_module:
new_params = module.before(new_params)
kpts_est = body_model.keypoints(new_params)
for key, loss in loss_funcs.items():
loss.check_at_end(kpts_est=kpts_est, **new_params)
for module in before_after_module:
# Input to this module is tensor
body_params = module.final(body_params)
body_params = dict_of_tensor_to_numpy(body_params)
for key, val in opt_params.items():
if key in infos.keys():
infos[key] = val.detach().cpu()
return body_params
def fit(self, body_model, dataset):
batch_size = len(dataset) if self.batch_size == -1 else self.batch_size
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=False)
if len(dataloader) > 1:
dataloader = tqdm(dataloader, desc='optimizing')
for data in dataloader:
data = dataset.reshape_data(data)
infos = data.copy()
init_params = body_model.init_params(nFrames=infos['nFrames'], nPerson=infos.get('nPerson', 1))
# first initialize the model
for name, init_func in self.initialize.items():
if 'loss' in init_func.keys():
# fitting to initialize
init_params = self.fit_stage(body_model, init_params, infos, init_func, 0)
else:
# use initialize module
init_module = load_object(init_func.module, init_func.args)
init_params = init_module(body_model, init_params, infos)
# if there are multiple initialization params
# then fit each of them
if not isinstance(init_params, list):
init_params = [init_params]
results = []
for init_param in init_params:
# check the repeat params
body_params = init_param
for stage_name, stage in self.stages.items():
for irepeat in range(stage.get('repeat', 1)):
with Timer('optimize {}'.format(stage_name), not self.monitor.timer):
body_params = self.fit_stage(body_model, body_params, infos, stage, irepeat)
results.append(body_params)
# select the best results
if len(results) > 1:
# check the result
loss = load_object(self.check.module, self.check.args, **{key:infos[key] for key in self.check.infos})
metrics = [loss(body_model.keypoints(body_params, return_tensor=True).cpu()).item() for body_params in results]
best_idx = np.argmin(metrics)
else:
best_idx = 0
if 'sync_offset' in body_params.keys():
offset = body_params.pop('sync_offset')
dataset.write_offset(offset)
body_params = Params(**results[best_idx])
if data['nFrames'] != body_params['poses'].shape[0]:
for key in body_params.keys():
if body_params[key].shape[0] == 1:continue
body_params[key] = body_params[key].reshape(data['nFrames'], -1, *body_params[key].shape[1:])
print(key, body_params[key].shape)
if 'K' in infos.keys():
camera = Params(K=infos['K'].numpy(), R=infos['Rc'].numpy(), T=infos['Tc'].numpy())
if 'mirror' in infos.keys():
camera['mirror'] = infos['mirror'].numpy()[None]
dataset.write(body_model, body_params, data, camera)
else:
# write data without camera
dataset.write(body_model, body_params, data)

View File

@ -0,0 +1,39 @@
'''
@ Date: 2022-08-12 20:34:15
@ Author: Qing Shuai
@ Mail: s_q@zju.edu.cn
@ LastEditors: Qing Shuai
@ LastEditTime: 2022-08-18 14:47:23
@ FilePath: /EasyMocapPublic/easymocap/multistage/base_ops.py
'''
import torch
class BeforeAfterBase:
def __init__(self, model) -> None:
pass
def start(self, body_params):
# operation before the optimization
return body_params
def before(self, body_params):
# operation in each optimization step
return body_params
def final(self, body_params):
# operation after the optimization
return body_params
class SkipPoses(BeforeAfterBase):
def __init__(self, index, nPoses) -> None:
self.index = index
self.nPoses = nPoses
self.copy_index = [i for i in range(nPoses) if i not in index]
def before(self, body_params):
poses = body_params['poses']
poses_copy = torch.zeros_like(poses)
print(poses.shape)
poses_copy[..., self.copy_index] = poses[..., self.copy_index]
body_params['poses'] = poses_copy
return body_params

View File

@ -0,0 +1,57 @@
import torch
class Remove:
def __init__(self, key, index=[], ranges=[]) -> None:
self.key = key
self.ranges = ranges
self.index = index
def before(self, body_params):
val = body_params[self.key]
if self.ranges[0] == 0:
val_zeros = torch.zeros_like(val[:, :self.ranges[1]])
val = torch.cat([val_zeros, val[:, self.ranges[1]:]], dim=1)
body_params[self.key] = val
return body_params
class RemoveHand:
def __init__(self, start=60) -> None:
pass
def before(self, body_params):
poses = body_params['poses']
val_zeros = torch.zeros_like(poses[:, 60:])
val = torch.cat([poses[:, :60], val_zeros], dim=1)
body_params['poses'] = val
return body_params
class Keep:
def __init__(self, key, ranges=[], index=[]) -> None:
self.key = key
self.ranges = ranges
self.index = index
def before(self, body_params):
val = body_params[self.key]
val_zeros = val.detach().clone()
if len(self.ranges) > 0:
val_zeros[..., self.ranges[0]:self.ranges[1]] = val[..., self.ranges[0]:self.ranges[1]]
elif len(self.index) > 0:
val_zeros[..., self.index] = val[..., self.index]
body_params[self.key] = val_zeros
return body_params
def final(self, body_params):
return body_params
class VPoser2Full:
def __init__(self, key) -> None:
pass
def __call__(self, body_model, body_params, infos):
if not 'Embedding' in body_model.__class__.__name__:
return body_params
poses = body_params['poses']
poses_full = body_model.decode(poses, add_rot=False)
body_params['poses'] = poses_full
return body_params

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,100 @@
'''
@ Date: 2022-04-26 17:54:28
@ Author: Qing Shuai
@ Mail: s_q@zju.edu.cn
@ LastEditors: Qing Shuai
@ LastEditTime: 2022-07-11 22:20:44
@ FilePath: /EasyMocapPublic/easymocap/multistage/init_cnn.py
'''
import os
import numpy as np
import cv2
from tqdm import tqdm
from os.path import join
import torch
from ..bodymodel.base import Params
from ..estimator.wrapper_base import bbox_from_keypoints
from ..mytools.writer import write_smpl
from ..mytools.reader import read_smpl
class InitSpin:
# initialize the smpl results by spin
def __init__(self, mean_params, ckpt_path, share_shape,
multi_person=False, compose_mp=False) -> None:
from ..estimator.SPIN.spin_api import SPIN
import torch
self.share_shape = share_shape
self.spin_model = SPIN(
SMPL_MEAN_PARAMS=mean_params,
checkpoint=ckpt_path,
device=torch.device('cpu'))
self.distortMap = {}
self.multi_person = multi_person
self.compose_mp = compose_mp
def undistort(self, image, K, dist, nv):
if np.linalg.norm(dist) < 0.01:
return image
if nv not in self.distortMap.keys():
h, w = image.shape[:2]
mapx, mapy = cv2.initUndistortRectifyMap(K, dist, None, K, (w,h), 5)
self.distortMap[nv] = (mapx, mapy)
mapx, mapy = self.distortMap[nv]
image = cv2.remap(image, mapx, mapy, cv2.INTER_LINEAR)
return image
def __call__(self, body_model, body_params, infos):
self.spin_model.model.to(body_model.device)
self.spin_model.device = body_model.device
params_all = []
for nf, imgname in enumerate(tqdm(infos['imgname'], desc='Run SPIN')):
# 暂时不考虑多视角情况
# TODO: 没有考虑多人的情况
basename = os.sep.join(imgname.split(os.sep)[-2:]).split('.')[0] + '.json'
sub = os.path.dirname(basename)
cache_dir = os.path.abspath(join(os.sep.join(imgname.split(os.sep)[:-3]), 'cache_spin'))
outname = join(cache_dir, basename)
if os.path.exists(outname):
params = read_smpl(outname)
if self.multi_person:
params_all.append(params)
else:
params_all.append(params[0])
continue
camera = {key: infos[key][nf].numpy() for key in ['K', 'Rc', 'Tc', 'dist']}
camera['R'] = camera['Rc']
camera['T'] = camera['Tc']
image = cv2.imread(imgname)
image = self.undistort(image, camera['K'], camera['dist'], sub)
if len(infos['keypoints2d'].shape) == 3:
k2d = infos['keypoints2d'][nf][None]
else:
k2d = infos['keypoints2d'][nf]
params_current = []
for pid in range(k2d.shape[0]):
keypoints = k2d[pid].numpy()
bbox = bbox_from_keypoints(keypoints)
nValid = (keypoints[:, -1] > 0).sum()
if nValid > 4:
result = self.spin_model(body_model, image,
bbox, keypoints, camera, ret_vertices=False)
elif len(params_all) == 0:
print('[WARN] not enough joints: {} in first frame'.format(imgname))
else:
print('[WARN] not enough joints: {}'.format(imgname))
result = {'body_params': params_all[-1][pid]}
params = result['body_params']
params['id'] = pid
params_current.append(params)
write_smpl(outname, params_current)
if self.multi_person:
params_all.append(params_current)
else:
params_all.append(params_current[0])
if not self.multi_person:
params_all = Params.merge(params_all, share_shape=self.share_shape)
params_all = body_model.encode(params_all)
elif self.compose_mp:
params_all = Params.merge([Params.merge(p_, share_shape=False) for p_ in params_all], share_shape=False, stack=np.stack)
params_all['id'] = 0
return params_all

View File

@ -0,0 +1,36 @@
'''
@ Date: 2022-04-02 13:59:50
@ Author: Qing Shuai
@ Mail: s_q@zju.edu.cn
@ LastEditors: Qing Shuai
@ LastEditTime: 2022-07-13 16:34:21
@ FilePath: /EasyMocapPublic/easymocap/multistage/init_pose.py
'''
import os
import numpy as np
import cv2
from tqdm import tqdm
from os.path import join
import torch
from ..bodymodel.base import Params
from ..estimator.wrapper_base import bbox_from_keypoints
from ..mytools.writer import write_smpl
from ..mytools.reader import read_smpl
class SmoothPoses:
def __init__(self, window_size) -> None:
self.W = window_size
def __call__(self, body_model, body_params, infos):
poses = body_params['poses']
padding_before = poses[:1].copy().repeat(self.W, 0)
padding_after = poses[-1:].copy().repeat(self.W, 0)
mean = poses.copy()
nFrames = mean.shape[0]
poses_full = np.vstack([padding_before, poses, padding_after])
for w in range(1, self.W+1):
mean += poses_full[self.W-w:self.W-w+nFrames]
mean += poses_full[self.W+w:self.W+w+nFrames]
mean /= 2*self.W + 1
body_params['poses'] = mean
return body_params

View File

@ -0,0 +1,172 @@
import numpy as np
import cv2
from ..dataset.config import CONFIG
from ..config import load_object
from ..mytools.debug_utils import log, mywarn, myerror
import torch
from tqdm import tqdm, trange
def svd_rot(src, tgt, reflection=False, debug=False):
# optimum rotation matrix of Y
A = np.matmul(src.transpose(0, 2, 1), tgt)
U, s, Vt = np.linalg.svd(A, full_matrices=False)
V = Vt.transpose(0, 2, 1)
T = np.matmul(V, U.transpose(0, 2, 1))
# does the current solution use a reflection?
have_reflection = np.linalg.det(T) < 0
# if that's not what was specified, force another reflection
V[have_reflection, :, -1] *= -1
s[have_reflection, -1] *= -1
T = np.matmul(V, U.transpose(0, 2, 1))
if debug:
err = np.linalg.norm(tgt - src @ T.T, axis=1)
print('[svd] ', err)
return T
def batch_invRodrigues(rot):
res = []
for r in rot:
v = cv2.Rodrigues(r)[0]
res.append(v)
res = np.stack(res)
return res[:, :, 0]
class BaseInit:
def __init__(self) -> None:
pass
def __call__(self, body_model, body_params, infos):\
return body_params
class Remove(BaseInit):
def __init__(self, key, index) -> None:
super().__init__()
self.key = key
self.index = index
def __call__(self, body_model, body_params, infos):
infos[self.key][..., self.index, :] = 0
return super().__call__(body_model, body_params, infos)
class CheckKeypoints:
def __init__(self, type) -> None:
# this class is used to check if the provided keypoints3d
self.type = type
self.body_config = CONFIG[type]
self.hand_config = CONFIG['hand']
def __call__(self, body_model, body_params, infos):
for key in ['keypoints3d', 'handl3d', 'handr3d']:
if key not in infos.keys(): continue
keypoints = infos[key]
conf = keypoints[..., -1]
keypoints[conf<0.1] = 0
if key == 'keypoints3d':
continue
import ipdb;ipdb.set_trace()
# limb_length = np.linalg.norm(keypoints[:, , :3], axis=2)
return body_params
class InitRT:
def __init__(self, torso) -> None:
self.torso = torso
def __call__(self, body_model, body_params, infos):
keypoints3d = infos['keypoints3d'].detach().cpu().numpy()
temp_joints = body_model.keypoints(body_params, return_tensor=False)
torso = keypoints3d[..., self.torso, :3].copy()
torso_temp = temp_joints[..., self.torso, :3].copy()
# here use the first id of torso as the rotation center
root, root_temp = torso[..., :1, :], torso_temp[..., :1, :]
torso = torso - root
torso_temp = torso_temp - root_temp
conf = (keypoints3d[..., self.torso, 3] > 0.).all(axis=-1)
if not conf.all():
myerror("The torso in frames {} is not valid, please check the 3d keypoints".format(np.where(~conf)))
if len(torso.shape) == 3:
R = svd_rot(torso_temp, torso)
R_flat = R
T = np.matmul(- root_temp, R.transpose(0, 2, 1)) + root
else:
R_flat = svd_rot(torso_temp.reshape(-1, *torso_temp.shape[-2:]), torso.reshape(-1, *torso.shape[-2:]))
R = R_flat.reshape(*torso.shape[:2], 3, 3)
T = np.matmul(- root_temp, R.swapaxes(-1, -2)) + root
for nf in np.where(~conf)[0]:
# copy previous frames
mywarn('copy {} from {}'.format(nf, nf-1))
R[nf] = R[nf-1]
T[nf] = T[nf-1]
body_params['Th'] = T[..., 0, :]
rvec = batch_invRodrigues(R_flat)
if len(torso.shape) > 3:
rvec = rvec.reshape(*torso.shape[:2], 3)
body_params['Rh'] = rvec
return body_params
def __str__(self) -> str:
return "[Initialize] svd with torso: {}".format(self.torso)
class TriangulatorWrapper:
def __init__(self, module, args):
self.triangulator = load_object(module, args)
def __call__(self, body_model, body_params, infos):
infos['RT'] = torch.cat([infos['Rc'], infos['Tc']], dim=-1)
data = {
'RT': infos['RT'].numpy(),
}
for key in self.triangulator.keys:
if key not in infos.keys():
continue
data[key] = infos[key].numpy()
data[key+'_unproj'] = infos[key+'_unproj'].numpy()
data[key+'_distort'] = infos[key+'_distort'].numpy()
results = self.triangulator(data)[0]
for key, val in results.items():
if key == 'id': continue
infos[key] = torch.Tensor(val[None].astype(np.float32))
body_params = body_model.init_params(nFrames=1, add_scale=True)
return body_params
class CheckRT:
def __init__(self, T_thres, window):
self.T_thres = T_thres
self.window = window
def __call__(self, body_model, body_params, infos):
Th = body_params['Th']
if len(Th.shape) == 3:
for nper in range(Th.shape[1]):
for nf in trange(1, Th.shape[0], desc='Check Th of {}'.format(nper)):
if nf > self.window:
tpre = Th[nf-self.window:nf, nper]
else:
tpre = Th[:nf, nper]
tpre = tpre.mean(axis=0)
tnow = Th[nf , nper]
dist = np.linalg.norm(tnow - tpre)
if dist > self.T_thres:
mywarn('[Check Th] distance in frame {} = {} larger than {}'.format(nf, dist, self.T_thres))
Th[nf, nper] = tpre
body_params['Th'] = Th
return body_params
class Scale:
def __init__(self, keys):
self.keys = keys
def __call__(self, body_model, body_params, infos):
scale = body_params.pop('scale')[0, 0]
if scale < 1.1 and scale > 0.9:
return body_params
print('scale = ', scale)
for key in self.keys:
if key not in infos.keys():
continue
infos[key] /= scale
infos['Tc'] /= scale
infos['RT'][..., -1] *= scale
infos['scale'] = scale
return body_params

View File

@ -0,0 +1,589 @@
import numpy as np
import torch.nn as nn
import torch
from ..bodymodel.lbs import batch_rodrigues
class LossBase(nn.Module):
def __init__(self):
super().__init__()
def __str__(self) -> str:
return '# lack of comment'
def check_at_start(self, **kwargs):
pass
def check_at_end(self, **kwargs):
pass
class GMoF(nn.Module):
def __init__(self, rho=1):
super(GMoF, self).__init__()
self.rho2 = rho * rho
def extra_repr(self):
return 'rho = {}'.format(self.rho)
def forward(self, est, gt=None, conf=None):
if gt is not None:
square_diff = torch.sum((est - gt)**2, dim=-1)
else:
square_diff = torch.sum(est**2, dim=-1)
diff = torch.div(square_diff, square_diff + self.rho2)
if conf is not None:
res = torch.sum(diff * conf)/(1e-5 + conf.sum())
else:
res = diff.sum()/diff.numel()
return res
def make_loss(norm, norm_info, reduce='sum'):
reduce = torch.sum if reduce=='sum' else torch.mean
if norm == 'l2':
def loss(est, gt=None, conf=None):
if gt is not None:
square_diff = reduce((est - gt)**2, dim=-1)
else:
square_diff = reduce(est**2, dim=-1)
if conf is not None:
res = torch.sum(square_diff * conf)/(1e-5 + conf.sum())
else:
res = square_diff.sum()/square_diff.numel()
return res
elif norm == 'gm':
loss = GMoF(norm_info)
return loss
def select(value, ranges, index, dim):
if len(ranges) > 0:
if ranges[1] == -1:
value = value[..., ranges[0]:]
else:
value = value[..., ranges[0]:ranges[1]]
return value
if len(index) > 0:
if dim == -1:
value = value[..., index]
elif dim == -2:
value = value[..., index, :]
return value
return value
def print_table(header, contents):
from tabulate import tabulate
length = len(contents[0])
tables = [[] for _ in range(length)]
mean = ['Mean']
for icnt, content in enumerate(contents):
for i in range(length):
if isinstance(content[i], float):
tables[i].append('{:6.2f}'.format(content[i]))
else:
tables[i].append('{}'.format(content[i]))
if icnt > 0:
mean.append('{:6.2f}'.format(sum(content)/length))
tables.append(mean)
print(tabulate(tables, header, tablefmt='fancy_grid'))
class AnyReg(LossBase):
def __init__(self, key, norm, dim=-1, reduce='sum', norm_info={}, ranges=[], index=[], **kwargs):
super().__init__()
self.ranges = ranges
self.index = index
self.key = key
self.dim = dim
if 'init_' + key in kwargs.keys():
init = kwargs['init_'+key]
self.register_buffer('init', torch.Tensor(init))
else:
self.init = None
self.norm_name = norm
self.loss = make_loss(norm, norm_info, reduce=reduce)
def forward(self, **kwargs):
"""
value: (nFrames, ..., nDims)
"""
value = kwargs[self.key]
if self.init is not None:
value = value - self.init
value = select(value, self.ranges, self.index, self.dim)
return self.loss(value)
def __str__(self) -> str:
return 'Loss for {}'.format(self.key, self.norm_name)
class RegPrior(AnyReg):
def __init__(self, **cfg):
super().__init__(**cfg)
self.init = None # disable init
infos = {
(2, 0): '-exp',
(2, 1): 'l2',
(2, 2): 'l2',
(3, 0): '-exp', # knee
(3, 1): 'L2',
(3, 2): 'L2',
(4, 0): '-exp', # knee
(4, 1): 'L2',
(4, 2): 'L2',
(5, 0): '-exp',
(5, 1): 'l2',
(5, 2): 'l2',
(6, 1): 'L2',
(6, 2): 'L2',
(7, 1): 'L2',
(7, 2): 'L2',
(8, 0): '-exp',
(8, 1): 'l2',
(8, 2): 'l2',
(9, 0): 'L2',
(9, 1): 'L2',
(9, 2): 'L2',
(10, 0): 'L2',
(10, 1): 'L2',
(10, 2): 'L2',
(12, 0): 'l2', # 肩关节前面
(13, 0): 'l2',
(17, 1): 'exp',
(17, 2): 'L2',
(18, 1): '-exp',
(18, 2): 'L2',
}
self.l2dims = []
self.L2dims = []
self.expdims = []
self.nexpdims = []
for (nj, ndim), norm in infos.items():
dim = nj*3 + ndim
if norm == 'l2':
self.l2dims.append(dim)
elif norm == 'L2':
self.L2dims.append(dim)
elif norm == '-exp':
self.nexpdims.append(dim)
elif norm == 'exp':
self.expdims.append(dim)
def forward(self, poses, **kwargs):
"""
poses: (..., nDims)
"""
alll2loss = torch.mean(poses**2)
l2loss = torch.sum(poses[:, self.l2dims]**2)/len(self.l2dims)/poses.shape[0]
L2loss = torch.sum(poses[:, self.L2dims]**2)/len(self.L2dims)/poses.shape[0]
exploss = torch.sum(torch.exp(poses[:, self.expdims]))/poses.shape[0]
nexploss = torch.sum(torch.exp(-poses[:, self.nexpdims]))/poses.shape[0]
loss = 0.1*l2loss + L2loss + 0.0005*(exploss + nexploss)/(len(self.expdims) + len(self.nexpdims)) + 0.01*alll2loss
return loss
class VPoserPrior(AnyReg):
def __init__(self, **cfg):
super().__init__(**cfg)
vposer_ckpt = 'data/bodymodels/vposer_v02'
from human_body_prior.tools.model_loader import load_model
from human_body_prior.models.vposer_model import VPoser
vposer, _ = load_model(vposer_ckpt,
model_code=VPoser,
remove_words_in_model_weights='vp_model.',
disable_grad=True)
vposer.eval()
self.vposer = vposer
self.init = None # disable init
def forward(self, poses, **kwargs):
"""
poses: (..., nDims)
"""
nDims = 63
poses_body = poses[..., :nDims].reshape(-1, nDims)
latent = self.vposer.encode(poses_body)
if True:
ret = self.vposer.decode(latent.sample())['pose_body'].reshape(poses.shape[0], nDims)
return super().forward(poses=poses_body-ret)
else:
return super().forward(poses=latent.mean)
class AnySmooth(LossBase):
def __init__(self, key, weight, norm, norm_info={}, ranges=[], index=[], dim=-1, order=1):
super().__init__()
self.ranges = ranges
self.index = index
self.dim = dim
self.weight = weight
self.loss = make_loss(norm, norm_info)
self.norm_name = norm
self.key = key
self.order = order
def forward(self, **kwargs):
loss = 0
value = kwargs[self.key]
value = select(value, self.ranges, self.index, self.dim)
if value.shape[0] <= len(self.weight):
return torch.FloatTensor([0.]).to(value.device)
for width, weight in enumerate(self.weight, start=1):
vel = value[width:] - value[:-width]
if self.order == 2:
vel = vel[1:] - vel[:-1]
loss += weight * self.loss(vel)
return loss
def check(self, value):
vel = value[1:] - value[:-1]
if len(vel.shape) > 2:
vel = torch.norm(vel, dim=-1)
else:
vel = torch.abs(vel)
vel = vel.detach().cpu().numpy()
return vel
def get_check_name(self, value):
name = [str(i) for i in range(value.shape[1])]
return name
def check_at_start(self, **kwargs):
value = kwargs[self.key]
if value.shape[0] < len(self.weight):
return 0
header = ['Smooth '+self.key, 'mean(before)', 'max(before)', 'frame(before)']
name = self.get_check_name(value)
vel = self.check(value)
contents = [name, vel.mean(axis=0).tolist(), vel.max(axis=0).tolist(), vel.argmax(axis=0).tolist()]
self.cache_check = (header, contents)
return super().check_at_start(**kwargs)
def check_at_end(self, **kwargs):
value = kwargs[self.key]
if value.shape[0] < len(self.weight):
return 0
err_after = self.check(kwargs[self.key])
header, contents = self.cache_check
header.extend(['mean(after)', 'max(after)', 'frame(after)'])
contents.extend([err_after.mean(axis=0).tolist(), err_after.max(axis=0).tolist(), err_after.argmax(axis=0).tolist()])
print_table(header, contents)
def __str__(self) -> str:
return "smooth in {} frames, range={}, norm={}".format(self.weight, self.ranges, self.norm_name)
class SmoothRot(AnySmooth):
def __init__(self, **kwargs):
super().__init__(**kwargs)
from ..bodymodel.lbs import batch_rodrigues
self.rodrigues = batch_rodrigues
def convert_Rh_to_R(self, Rh):
shape = Rh.shape[1]
ret = []
for i in range(shape//3):
Rot = self.rodrigues(Rh[:, 3*i:3*(i+1)])
ret.append(Rot)
ret = torch.cat(ret, dim=1)
return ret
def forward(self, **kwargs):
Rh = kwargs[self.key]
if Rh.shape[-1] != 3:
loss = 0
for i in range(Rh.shape[-1]//3):
Rh_sub = Rh[..., 3*i:3*i+3]
Rot = self.convert_Rh_to_R(Rh_sub).view(*Rh_sub.shape[:-1], 3, 3)
loss += super().forward(**{self.key: Rot})
return loss
else:
Rh_flat = Rh.view(-1, 3)
Rot = self.convert_Rh_to_R(Rh_flat).view(*Rh.shape[:-1], 3, 3)
return super().forward(**{self.key: Rot})
def get_check_name(self, value):
name = ['angle']
return name
def check(self, value):
import cv2
# TODO: here just use first rotation
if len(value.shape) == 3:
value = value[:, 0]
Rot = self.convert_Rh_to_R(value.detach())[:, :3]
vel = torch.matmul(Rot[:-1], Rot.transpose(1,2)[1:]).cpu().numpy()
vels = []
for i in range(vel.shape[0]):
angle = np.linalg.norm(cv2.Rodrigues(vel[i])[0])
vels.append(angle)
vel = np.array(vels).reshape(-1, 1)
return vel
class BaseKeypoints(LossBase):
@staticmethod
def select(keypoints, index, ranges):
if len(index) > 0:
keypoints = keypoints[..., index, :]
elif len(ranges) > 0:
if ranges[1] == -1:
keypoints = keypoints[..., ranges[0]:, :]
else:
keypoints = keypoints[..., ranges[0]:ranges[1], :]
return keypoints
def set_gt(self, index_gt, ranges_gt):
keypoints = self.select(self.keypoints_np, index_gt, ranges_gt)
keypoints = torch.Tensor(keypoints)
self.register_buffer('keypoints', keypoints[..., :-1])
self.register_buffer('conf', keypoints[..., -1])
def __str__(self):
return "keypoints: {}".format(self.keypoints.shape)
def __init__(self, keypoints, norm='l2', norm_info={},
index_gt=[], ranges_gt=[],
index_est=[], ranges_est=[]) -> None:
super().__init__()
# prepare ground-truth
self.keypoints_np = keypoints
self.set_gt(index_gt, ranges_gt)
#
self.index_est = index_est
self.ranges_est = ranges_est
self.norm = norm
self.loss = make_loss(norm, norm_info)
def forward(self, kpts_est, **kwargs):
est = self.select(kpts_est, self.index_est, self.ranges_est)
return self.loss(est, self.keypoints, self.conf)
def check(self, kpts_est, min_conf=0.3, **kwargs):
est = self.select(kpts_est, self.index_est, self.ranges_est)
conf = (self.conf>min_conf).float()
norm = torch.norm(est-self.keypoints, dim=-1) * conf
mean_joints = norm.sum(dim=0)/(1e-5 + conf.sum(dim=0)) * 1000
return conf, mean_joints
def check_at_start(self, kpts_est, **kwargs):
if len(self.index_est) > 0:
names = [str(i) for i in self.index_est]
elif len(self.ranges_est) > 0:
names = [str(i) for i in range(self.ranges_est[0], self.ranges_est[1])]
else:
names = [str(i) for i in range(self.conf.shape[-1])]
conf, error = self.check(kpts_est, **kwargs)
valid = conf.sum(dim=0).detach().cpu().numpy().tolist()
header = ['name', 'count']
contents = [names, valid]
header.append('before')
contents.append(error.detach().cpu().numpy().tolist())
self.cache_check = (header, contents)
def check_at_end(self, kpts_est, **kwargs):
conf, err_after = self.check(kpts_est, **kwargs)
header, contents = self.cache_check
header.append('after')
contents.append(err_after.detach().cpu().numpy().tolist())
print_table(header, contents)
class Keypoints3D(BaseKeypoints):
def __init__(self, keypoints3d, **kwargs) -> None:
super().__init__(keypoints3d, **kwargs)
class AnyKeypoints3D(Keypoints3D):
def __init__(self, **kwargs) -> None:
key = kwargs.pop('key')
keypoints3d = kwargs.pop(key)
super().__init__(keypoints3d, **kwargs)
self.key = key
class AnyKeypoints3DWithRT(Keypoints3D):
def __init__(self, **kwargs) -> None:
key = kwargs.pop('key')
keypoints3d = kwargs.pop(key)
super().__init__(keypoints3d, **kwargs)
self.key = key
def forward(self, kpts_est, **kwargs):
R = batch_rodrigues(kwargs['R_'+self.key])
T = kwargs['T_'+self.key]
RXT = torch.matmul(kpts_est, R.transpose(-1, -2)) + T[..., None, :]
return super().forward(RXT)
def check(self, kpts_est, min_conf=0.3, **kwargs):
R = batch_rodrigues(kwargs['R_'+self.key])
T = kwargs['T_'+self.key]
RXT = torch.matmul(kpts_est, R.transpose(-1, -2)) + T[..., None, :]
kpts_est = RXT
return super().check(kpts_est, min_conf)
class Handl3D(BaseKeypoints):
def __init__(self, handl3d, **kwargs) -> None:
handl3d = handl3d.clone()
handl3d[..., :3] = handl3d[..., :3] - handl3d[:, :1, :3]
super().__init__(handl3d, **kwargs)
def forward(self, kpts_est, **kwargs):
est = kpts_est[:, 25:46]
est = est - est[:, :1].detach()
return super().forward(est, **kwargs)
def check(self, kpts_est, **kwargs):
est = kpts_est[:, 25:46]
est = est - est[:, :1].detach()
return super().check(est, **kwargs)
class LimbLength(BaseKeypoints):
def __init__(self, kintree, key='keypoints3d', **kwargs):
self.kintree = np.array(kintree)
if key == 'bodyhand':
keypoints3d = np.hstack([kwargs.pop('keypoints3d'), kwargs.pop('handl3d'), kwargs.pop('handr3d')])
else:
keypoints3d = kwargs.pop(key)
super().__init__(keypoints3d, **kwargs)
def __str__(self):
return "Limb of: {}".format(','.join(['[{},{}]'.format(i,j) for (i,j) in self.kintree]))
def set_gt(self, index_gt, ranges_gt):
keypoints3d = self.keypoints_np
kintree = self.kintree
# limb_length: nFrames, nLimbs, 1
limb_length = np.linalg.norm(keypoints3d[..., kintree[:, 1], :3] - keypoints3d[..., kintree[:, 0], :3], axis=-1, keepdims=True)
# conf: nFrames, nLimbs, 1
limb_conf = np.minimum(keypoints3d[..., kintree[:, 1], -1], keypoints3d[..., kintree[:, 0], -1])
limb_length = torch.Tensor(limb_length)
limb_conf = torch.Tensor(limb_conf)
self.register_buffer('length', limb_length)
self.register_buffer('conf', limb_conf)
def forward(self, kpts_est, **kwargs):
src = kpts_est[..., self.kintree[:, 0], :]
dst = kpts_est[..., self.kintree[:, 1], :]
length_est = torch.norm(dst - src, dim=-1, keepdim=True)
return self.loss(length_est, self.length, self.conf)
def check_at_start(self, kpts_est, **kwargs):
names = [str(i) for i in self.kintree]
conf = (self.conf>0)
valid = conf.sum(dim=0).detach().cpu().numpy()
if len(valid.shape) == 2:
valid = valid.mean(axis=0)
header = ['name', 'count']
contents = [names, valid.tolist()]
error, length = self.check(kpts_est)
header.append('before')
contents.append(error.detach().cpu().numpy().tolist())
header.append('length')
length = (self.length[..., 0] * self.conf).sum(dim=0)/self.conf.sum(dim=0)
contents.append(length.detach().cpu().numpy().tolist())
self.cache_check = (header, contents)
def check_at_end(self, kpts_est, **kwargs):
err_after, length = self.check(kpts_est)
header, contents = self.cache_check
header.append('after')
contents.append(err_after.detach().cpu().numpy().tolist())
header.append('length_est')
contents.append(length[:,:,0].mean(dim=0).detach().cpu().numpy().tolist())
print_table(header, contents)
def check(self, kpts_est, **kwargs):
src = kpts_est[..., self.kintree[:, 0], :]
dst = kpts_est[..., self.kintree[:, 1], :]
length_est = torch.norm(dst - src, dim=-1, keepdim=True)
conf = (self.conf>0).float()
norm = torch.abs(length_est-self.length)[..., 0] * conf
mean_joints = norm.sum(dim=0)/conf.sum(dim=0) * 1000
if len(mean_joints.shape) == 2:
mean_joints = mean_joints.mean(dim=0)
length_est = length_est.mean(dim=0)
return mean_joints, length_est
class LimbLengthHand(LimbLength):
def __init__(self, handl3d, handr3d, **kwargs):
kintree = kwargs.pop('kintree')
keypoints3d = torch.cat([handl3d, handr3d], dim=0)
super().__init__(kintree, keypoints3d, **kwargs)
def forward(self, kpts_est, **kwargs):
kpts_est = torch.cat([kpts_est[:, :21], kpts_est[:, 21:]], dim=0)
return super().forward(kpts_est, **kwargs)
def check(self, kpts_est, **kwargs):
kpts_est = torch.cat([kpts_est[:, :21], kpts_est[:, 21:]], dim=0)
return super().check(kpts_est, **kwargs)
class Keypoints2D(BaseKeypoints):
def __init__(self, keypoints2d, K, Rc, Tc, einsum='fab,fnb->fna',
unproj=True, reshape_views=False, **kwargs) -> None:
# convert to camera coordinate
invKtrans = torch.inverse(K).transpose(-1, -2)
if unproj:
homo = torch.ones_like(keypoints2d[..., :1])
homo = torch.cat([keypoints2d[..., :2], homo], dim=-1)
if len(invKtrans.shape) < len(homo.shape):
invKtrans = invKtrans.unsqueeze(-3)
homo = torch.matmul(homo, invKtrans)
keypoints2d = torch.cat([homo[..., :2], keypoints2d[..., 2:]], dim=-1)
# keypoints2d: (nFrames, nViews, ..., nJoints, 3)
super().__init__(keypoints2d, **kwargs)
self.register_buffer('K', K)
self.register_buffer('invKtrans', invKtrans)
self.register_buffer('Rc', Rc)
self.register_buffer('Tc', Tc)
self.unproj = unproj
self.einsum = einsum
self.reshape_views = reshape_views
def project(self, kpts_est):
kpts_est = self.select(kpts_est, self.index_est, self.ranges_est)
kpts_homo = torch.ones_like(kpts_est[..., -1:])
kpts_homo = torch.cat([kpts_est, kpts_homo], dim=-1)
if self.unproj:
P = torch.cat([self.Rc, self.Tc], dim=-1)
else:
P = torch.bmm(self.K, torch.cat([self.Rc, self.Tc], dim=-1))
if self.reshape_views:
kpts_homo = kpts_homo.reshape(self.K.shape[0], self.K.shape[1], *kpts_homo.shape[1:])
try:
point_cam = torch.einsum(self.einsum, P, kpts_homo)
except:
print('Wrong shape: {}x{} <=== {}'.format(P.shape, kpts_homo.shape, self.einsum))
raise NotImplementedError
img_points = point_cam[..., :2]/point_cam[..., 2:]
return img_points
def forward(self, kpts_est, **kwargs):
img_points = self.project(kpts_est)
loss = self.loss(img_points.squeeze(), self.keypoints.squeeze(), self.conf.squeeze())
return loss
def check(self, kpts_est, min_conf=0.3):
with torch.no_grad():
img_points = self.project(kpts_est)
conf = (self.conf>min_conf)
err = self.K[..., 0:1, 0].mean() * torch.norm(img_points - self.keypoints, dim=-1) * conf
if len(err.shape) == 3:
err = err.sum(dim=1)
conf = conf.sum(dim=1)
err = err.sum(dim=0)/(1e-5 + conf.sum(dim=0))
return conf, err
def check_at_start(self, kpts_est, **kwargs):
if len(self.index_est) > 0:
names = [str(i) for i in self.index_est]
elif len(self.ranges_est) > 0:
names = [str(i) for i in range(self.ranges_est[0], self.ranges_est[1])]
else:
names = [str(i) for i in range(self.conf.shape[-1])]
conf, error = self.check(kpts_est)
valid = conf.sum(dim=0).detach().cpu().numpy()
valid = valid.tolist()
header = ['name', 'count']
contents = [names, valid]
header.append('before(pix)')
contents.append(error.detach().cpu().numpy().tolist())
self.cache_check = (header, contents)
def check_at_end(self, kpts_est, **kwargs):
conf, err_after = self.check(kpts_est)
header, contents = self.cache_check
header.append('after(pix)')
contents.append(err_after.detach().cpu().numpy().tolist())
print_table(header, contents)

View File

@ -0,0 +1,261 @@
'''
@ Date: 2022-07-12 11:55:47
@ Author: Qing Shuai
@ Mail: s_q@zju.edu.cn
@ LastEditors: Qing Shuai
@ LastEditTime: 2022-07-14 17:57:48
@ FilePath: /EasyMocapPublic/easymocap/multistage/mirror.py
'''
import numpy as np
import torch
from ..dataset.mirror import flipPoint2D, flipSMPLPoses, flipSMPLParams
from ..estimator.wrapper_base import bbox_from_keypoints
from .lossbase import Keypoints2D
def calc_vanishpoint(keypoints2d):
'''
keypoints2d: (2, N, 3)
'''
# weight: (N, 1)
weight = keypoints2d[:, :, 2:].mean(axis=0)
conf = weight.mean()
A = np.hstack([
keypoints2d[1, :, 1:2] - keypoints2d[0, :, 1:2],
-(keypoints2d[1, :, 0:1] - keypoints2d[0, :, 0:1])
])
b = -keypoints2d[0, :, 0:1]*(keypoints2d[1, :, 1:2] - keypoints2d[0, :, 1:2]) \
+ keypoints2d[0, :, 1:2] * (keypoints2d[1, :, 0:1] - keypoints2d[0, :, 0:1])
b = -b
A = A * weight
b = b * weight
avgInsec = np.linalg.inv(A.T @ A) @ (A.T @ b)
result = np.zeros(3)
result[0] = avgInsec[0, 0]
result[1] = avgInsec[1, 0]
result[2] = 1
return result
def calc_mirror_transform(m_):
""" From mirror vector to mirror matrix
Args:
m (bn, 4): (a, b, c, d)
Returns:
M: (bn, 3, 4)
"""
norm = torch.norm(m_[:, :3], dim=1, keepdim=True)
m = m_[:, :3] / norm
d = m_[:, 3]
coeff_mat = torch.zeros((m.shape[0], 3, 4), device=m.device)
coeff_mat[:, 0, 0] = 1 - 2*m[:, 0]**2
coeff_mat[:, 0, 1] = -2*m[:, 0]*m[:, 1]
coeff_mat[:, 0, 2] = -2*m[:, 0]*m[:, 2]
coeff_mat[:, 0, 3] = -2*m[:, 0]*d
coeff_mat[:, 1, 0] = -2*m[:, 1]*m[:, 0]
coeff_mat[:, 1, 1] = 1-2*m[:, 1]**2
coeff_mat[:, 1, 2] = -2*m[:, 1]*m[:, 2]
coeff_mat[:, 1, 3] = -2*m[:, 1]*d
coeff_mat[:, 2, 0] = -2*m[:, 2]*m[:, 0]
coeff_mat[:, 2, 1] = -2*m[:, 2]*m[:, 1]
coeff_mat[:, 2, 2] = 1-2*m[:, 2]**2
coeff_mat[:, 2, 3] = -2*m[:, 2]*d
return coeff_mat
class InitNormal:
def __init__(self, static) -> None:
self.static = static
def __call__(self, body_model, body_params, infos):
if 'normal' in infos.keys():
print('>>> Reading normal: {}'.format(infos['normal']))
return body_params
kpts = infos['keypoints2d']
kpts0 = kpts[:, 0]
kpts1 = flipPoint2D(kpts[:, 1])
vanish_line = torch.stack([kpts0.reshape(-1, 3), kpts1.reshape(-1, 3)], dim=1)
MIN_THRES = 0.5
conf = (vanish_line[:, 0, -1] > MIN_THRES) & (vanish_line[:, 1, -1] > MIN_THRES)
vanish_line = vanish_line[conf]
vline0 = vanish_line.numpy().transpose(1, 0, 2)
vpoint0 = calc_vanishpoint(vline0).reshape(1, 3)
# 计算点到线的距离进行检查
# two points line: (x1, y1), (x2, y2) ==> (y-y1)/(x-x1) = (y2-y1)/(x2-x1)
# A = y2 - y1
# B = x1 - x2
# C = x2y1 - x1y2
# d = abs(ax + by + c)/sqrt(a^2+b^2)
A_v0 = kpts0[:, :, 1] - vpoint0[0, 1]
B_v0 = vpoint0[0, 0] - kpts0[:, :, 0]
C_v0 = kpts0[:, :, 0]*vpoint0[0, 1] - vpoint0[0, 0]*kpts0[:, :, 1]
distance01 = np.abs(A_v0 * kpts1[:, :, 0] + B_v0 * kpts1[:, :, 1] + C_v0)/np.sqrt(A_v0*A_v0 + B_v0*B_v0)
A_v1 = kpts1[:, :, 1] - vpoint0[0, 1]
B_v1 = vpoint0[0, 0] - kpts1[:, :, 0]
C_v1 = kpts1[:, :, 0]*vpoint0[0, 1] - vpoint0[0, 0]*kpts1[:, :, 1]
distance10 = np.abs(A_v1 * kpts0[:, :, 0] + B_v1 * kpts0[:, :, 1] + C_v1)/np.sqrt(A_v1*A_v1 + B_v1*B_v1)
DIST_THRES = 0.05
for nf in range(kpts.shape[0]):
# 计算scale
bbox0 = bbox_from_keypoints(kpts0[nf].cpu().numpy())
bbox1 = bbox_from_keypoints(kpts1[nf].cpu().numpy())
bbox_size0 = max(bbox0[2]-bbox0[0], bbox0[3]-bbox0[1])
bbox_size1 = max(bbox1[2]-bbox1[0], bbox1[3]-bbox1[1])
valid = (kpts0[nf, :, 2] > 0.3) & (kpts1[nf, :, 2] > 0.3)
dist01_ = valid*distance01[nf] / bbox_size1
dist10_ = valid*distance10[nf] / bbox_size0
# 对于距离异常的点阈值设定为0.1
# 抑制掉置信度低的视角的点
not_valid0 = np.where((dist01_ + dist10_ > DIST_THRES*2) & (kpts0[nf][:, -1] < kpts1[nf][:, -1]))[0]
not_valid1 = np.where((dist01_ + dist10_ > DIST_THRES*2) & (kpts0[nf][:, -1] > kpts1[nf][:, -1]))[0]
kpts0[nf, not_valid0] = 0.
kpts1[nf, not_valid1] = 0.
if len(not_valid0) > 0:
print('[mirror] filter {} person 0: {}'.format(nf, not_valid0))
if len(not_valid1) > 0:
print('[mirror] filter {} person 1: {}'.format(nf, not_valid1))
kpts1_ = flipPoint2D(kpts1)
infos['keypoints2d'] = torch.stack([kpts0, kpts1_], dim=1)
infos['vanish_point0'] = torch.Tensor(vpoint0)
K = infos['K'][0]
normal = np.linalg.inv(K) @ vpoint0.T
normal = normal.T/np.linalg.norm(normal)
print('>>> Calculating normal from keypoints: {}'.format(normal[0]))
infos['normal'] = torch.Tensor(normal)
mirror = torch.zeros((1, 4))
# 计算镜子平面到相机的距离
Th = body_params['Th']
center = Th.mean(axis=1)
# 相机原点到两个人中心的连线在normal上的投影
dist = (center * normal).sum(axis=-1).mean()
print('>>> Calculating distance from Th: {}'.format(dist))
mirror[0, 3] = - dist # initial guess
mirror[:, :3] = infos['normal']
infos['mirror'] = mirror
return body_params
class RemoveP1:
def __init__(self, static) -> None:
self.static = static
def __call__(self, body_model, body_params, infos):
for key in body_params.keys():
if key == 'id': continue
body_params[key] = body_params[key][:, 0]
return body_params
class Mirror:
def __init__(self, key) -> None:
self.key = key
def before(self, body_params):
poses = body_params['poses'][:, 0]
# append root
poses = torch.cat([torch.zeros_like(poses[..., :3]), poses], dim=-1)
poses_mirror = flipSMPLPoses(poses)
poses = torch.cat([poses[:, None, 3:], poses_mirror[:, None, 3:]], dim=1)
body_params['poses'] = poses
return body_params
def after(self,):
pass
def final(self, body_params):
return self.before(body_params)
class Keypoints2DMirror(Keypoints2D):
def __init__(self, mirror, opt_normal, **kwargs):
super().__init__(**kwargs)
if not mirror.requires_grad:
self.register_buffer('mirror', mirror)
else:
self.mirror = mirror
self.opt_normal = opt_normal
k2dall = kwargs['keypoints2d']
size_all = []
for nf in range(k2dall.shape[0]):
for nper in range(2):
kpts = k2dall[nf, nper]
bbox = bbox_from_keypoints(kpts.cpu().numpy())
bbox_size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
size_all.append(bbox_size)
size_all = np.array(size_all).reshape(-1, 2)
scale = (size_all[:, 0] / size_all[:, 1]).mean()
print('[loss] mean scale = {} from {} frames, use this to balance the two person'.format(scale, size_all.shape[0]))
# ATTN: here we use v^2 to suppress the outlier detections
self.conf = self.conf * self.conf
self.conf[:, 1] *= scale*scale
def check(self, kpts_est, min_conf=0.3):
with torch.no_grad():
M = calc_mirror_transform(self.mirror)
homo = torch.ones((*kpts_est.shape[:-1], 1), device=kpts_est.device)
kpts_homo = torch.cat([kpts_est, homo], dim=-1)
kpts_mirror = flipPoint2D(torch.matmul(M, kpts_homo.transpose(1, 2)).transpose(1, 2))
kpts = torch.stack([kpts_est, kpts_mirror], dim=1)
img_points = self.project(kpts)
conf = (self.conf>min_conf)
err = self.K[..., 0:1, 0].mean() * torch.norm(img_points - self.keypoints, dim=-1) * conf
if len(err.shape) == 3:
err = err.sum(dim=1)
conf = conf.sum(dim=1)
err = err.sum(dim=0)/(1e-5 + conf.sum(dim=0))
return conf, err
def forward(self, kpts_est, **kwargs):
if self.opt_normal:
M = calc_mirror_transform(self.mirror)
else:
mirror = torch.cat([self.mirror[:, :3].detach(), self.mirror[:, 3:]], dim=1)
M = calc_mirror_transform(mirror)
homo = torch.ones((*kpts_est.shape[:-1], 1), device=kpts_est.device)
kpts_homo = torch.cat([kpts_est, homo], dim=-1)
kpts_mirror = flipPoint2D(torch.matmul(M, kpts_homo.transpose(1, 2)).transpose(1, 2))
kpts = torch.stack([kpts_est, kpts_mirror], dim=1)
return super().forward(kpts_est=kpts, **kwargs)
class MirrorPoses:
def __init__(self, ref) -> None:
self.ref = ref
def __call__(self, body_model, body_params, infos):
# shapes: (nFrames, 2, nShapes)
shapes = body_params['shapes'].mean(axis=0).mean(axis=0).reshape(1, 1, -1)
poses = body_params['poses'][:, 0]
# append root
poses = np.concatenate([np.zeros([poses.shape[0], 3]), poses], axis=1)
poses_mirror = flipSMPLPoses(poses)
poses = np.concatenate([poses[:, None, 3:], poses_mirror[:, None, 3:]], axis=1)
body_params['poses'] = poses
body_params['shapes'] = shapes
return body_params
class MirrorParams:
def __init__(self, key) -> None:
self.key = key
def start(self, body_params):
if len(body_params['poses'].shape) == 2:
return body_params
for key in body_params.keys():
if key == 'id': continue
body_params[key] = body_params[key][:, 0]
return body_params
def before(self, body_params):
return body_params
def after(self,):
pass
def final(self, body_params):
device = body_params['poses'].device
body_params = {key:val.detach().cpu().numpy() for key, val in body_params.items()}
body_params['poses'] = np.hstack((np.zeros_like(body_params['poses'][:, :3]), body_params['poses']))
params_mirror = flipSMPLParams(body_params, self.infos['mirror'].cpu().numpy())
params = {}
for key in params_mirror.keys():
if key == 'shapes':
params[key] = body_params[key][:, None]
else:
params[key] = np.concatenate([body_params[key][:, None], params_mirror[key][:, None]], axis=-2)
params['poses'] = params['poses'][..., 3:]
params = {key:torch.Tensor(val).to(device) for key, val in params.items()}
return params

View File

@ -0,0 +1,79 @@
'''
@ Date: 2022-03-11 12:13:01
@ Author: Qing Shuai
@ Mail: s_q@zju.edu.cn
@ LastEditors: Qing Shuai
@ LastEditTime: 2022-08-11 21:52:00
@ FilePath: /EasyMocapPublic/easymocap/multistage/synchronization.py
'''
import numpy as np
import torch
class AddTime:
def __init__(self, gt) -> None:
self.gt = gt
def __call__(self, body_model, body_params, infos):
nViews = infos['keypoints2d'].shape[1]
offset = np.zeros((nViews,), dtype=np.float32)
body_params['sync_offset'] = offset
return body_params
class Interpolate:
def __init__(self, actfn) -> None:
# self.act_fn = lambda x: 2*torch.nn.functional.softsign(x)
self.act_fn = lambda x: 2*torch.tanh(x)
self.use0asref = False
def get_offset(self, time_offset):
if self.use0asref:
off = self.act_fn(torch.cat([torch.zeros(1, device=time_offset.device), time_offset[1:]]))
else:
off = self.act_fn(time_offset)
return off
def start(self, body_params):
return body_params
def before(self, body_params):
off = self.get_offset(body_params['sync_offset'])
nViews = off.shape[0]
if len(body_params['poses'].shape) == 2:
off = off[None, :, None]
else:
off = off[None, :, None, None]
for key in body_params.keys():
if key in ['sync_offset', 'shapes']:
continue
# TODO: Rh有正周期旋转的时候会有问题
val = body_params[key]
if key == 'Rh':
pass
if key in ['Th', 'poses']:
velocity = torch.cat([val[1:2] - val[0:1], val[1:] - val[:-1]], dim=0)
valnew = val[:, None] + off * velocity[:, None]
# vel = velocity.detach().cpu().numpy()
# import matplotlib.pyplot as plt
# plt.plot(vel)
# plt.show()
# import ipdb;ipdb.set_trace()
else:
if len(val.shape) == 2:
valnew = val[:, None].repeat(1, nViews, 1)
elif len(val.shape) == 3:
valnew = val[:, None].repeat(1, nViews, 1, 1)
else:
print('[warn] Unknown {} shape {}'.format(key, valnew.shape))
import ipdb; ipdb.set_trace()
valnew = valnew.reshape(-1, *val.shape[1:])
body_params[key] = valnew
return body_params
def after(self,):
pass
def final(self, body_params):
off = self.get_offset(body_params['sync_offset'])
body_params = self.before(body_params)
body_params['sync_offset'] = off
return body_params

View File

@ -0,0 +1,517 @@
"""
useful functions to perform conversion between rotation in different format(quaternion, rotation_matrix, euler_angle, axis_angle)
quaternion representation: (w,x,y,z)
code reference: torchgeometry, kornia, https://github.com/MandyMo/pytorch_HMR.
"""
import torch
from torch.nn import functional as F
import numpy as np
# Conversions between different rotation representations, quaternion,rotation matrix,euler and axis angle.
def rot6d_to_rotation_matrix(rot6d):
"""
Convert 6D rotation representation to 3x3 rotation matrix.
Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
Args:
rot6d: torch tensor of shape (batch_size, 6) of 6d rotation representations.
Returns:
rotation_matrix: torch tensor of shape (batch_size, 3, 3) of corresponding rotation matrices.
"""
x = rot6d.view(-1, 3, 2)
a1 = x[:, :, 0]
a2 = x[:, :, 1]
b1 = F.normalize(a1)
b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
b3 = torch.cross(b1, b2)
return torch.stack((b1, b2, b3), dim=-1)
def rotation_matrix_to_rot6d(rotation_matrix):
"""
Convert 3x3 rotation matrix to 6D rotation representation.
Args:
rotation_matrix: torch tensor of shape (batch_size, 3, 3) of corresponding rotation matrices.
Returns:
rot6d: torch tensor of shape (batch_size, 6) of 6d rotation representations.
"""
v1 = rotation_matrix[:, :, 0:1]
v2 = rotation_matrix[:, :, 1:2]
rot6d = torch.cat([v1, v2], dim=-1).reshape(v1.shape[0], 6)
return rot6d
def quaternion_to_rotation_matrix(quaternion):
"""
Convert quaternion coefficients to rotation matrix.
Args:
quaternion: torch tensor of shape (batch_size, 4) in (w, x, y, z) representation.
Returns:
rotation matrix corresponding to the quaternion, torch tensor of shape (batch_size, 3, 3)
"""
norm_quaternion = quaternion
norm_quaternion = norm_quaternion / \
norm_quaternion.norm(p=2, dim=1, keepdim=True)
w, x, y, z = norm_quaternion[:, 0], norm_quaternion[:,
1], norm_quaternion[:, 2], norm_quaternion[:, 3]
batch_size = quaternion.size(0)
w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
wx, wy, wz = w*x, w*y, w*z
xy, xz, yz = x*y, x*z, y*z
rotation_matrix = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(batch_size, 3, 3)
return rotation_matrix
def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
"""
Convert rotation matrix to corresponding quaternion
Args:
rotation_matrix: torch tensor of shape (batch_size, 3, 3)
Returns:
quaternion: torch tensor of shape(batch_size, 4) in (w, x, y, z) representation.
"""
rmat_t = torch.transpose(rotation_matrix, 1, 2)
mask_d2 = rmat_t[:, 2, 2] < eps
mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1)
t0_rep = t0.repeat(4, 1).t()
t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1)
t1_rep = t1.repeat(4, 1).t()
t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1)
t2_rep = t2.repeat(4, 1).t()
t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1)
t3_rep = t3.repeat(4, 1).t()
mask_c0 = mask_d2 * mask_d0_d1
mask_c1 = mask_d2 * (~ mask_d0_d1)
mask_c2 = (~ mask_d2) * mask_d0_nd1
mask_c3 = (~ mask_d2) * (~ mask_d0_nd1)
mask_c0 = mask_c0.view(-1, 1).type_as(q0)
mask_c1 = mask_c1.view(-1, 1).type_as(q1)
mask_c2 = mask_c2.view(-1, 1).type_as(q2)
mask_c3 = mask_c3.view(-1, 1).type_as(q3)
q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa
t2_rep * mask_c2 + t3_rep * mask_c3) # noqa
q *= 0.5
return q
def quaternion_to_euler(quaternion, order, epsilon=0):
"""
Convert quaternion to euler angles.
Args:
quaternion: torch tensor of shape (batch_size, 4) in (w, x, y, z) representation.
order: euler angle representation order, 'zyx' etc.
epsilon:
Returns:
euler: torch tensor of shape (batch_size, 3) in order.
"""
assert quaternion.shape[-1] == 4
original_shape = list(quaternion.shape)
original_shape[-1] = 3
q = quaternion.contiguous().view(-1, 4)
q0 = q[:, 0]
q1 = q[:, 1]
q2 = q[:, 2]
q3 = q[:, 3]
if order == 'xyz':
x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2*(q1 * q1 + q2 * q2))
y = torch.asin(torch.clamp(
2 * (q1 * q3 + q0 * q2), -1+epsilon, 1-epsilon))
z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2*(q2 * q2 + q3 * q3))
elif order == 'yzx':
x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2*(q1 * q1 + q3 * q3))
y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2*(q2 * q2 + q3 * q3))
z = torch.asin(torch.clamp(
2 * (q1 * q2 + q0 * q3), -1+epsilon, 1-epsilon))
elif order == 'zxy':
x = torch.asin(torch.clamp(
2 * (q0 * q1 + q2 * q3), -1+epsilon, 1-epsilon))
y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2*(q1 * q1 + q2 * q2))
z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2*(q1 * q1 + q3 * q3))
elif order == 'xzy':
x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2*(q1 * q1 + q3 * q3))
y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2*(q2 * q2 + q3 * q3))
z = torch.asin(torch.clamp(
2 * (q0 * q3 - q1 * q2), -1+epsilon, 1-epsilon))
elif order == 'yxz':
x = torch.asin(torch.clamp(
2 * (q0 * q1 - q2 * q3), -1+epsilon, 1-epsilon))
y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2*(q1 * q1 + q2 * q2))
z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2*(q1 * q1 + q3 * q3))
elif order == 'zyx':
x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2*(q1 * q1 + q2 * q2))
y = torch.asin(torch.clamp(
2 * (q0 * q2 - q1 * q3), -1+epsilon, 1-epsilon))
z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2*(q2 * q2 + q3 * q3))
else:
raise Exception('unsupported euler order!')
return torch.stack((x, y, z), dim=1).view(original_shape)
def euler_to_quaternion(euler, order):
"""
Convert euler angles to quaternion.
Args:
euler: torch tensor of shape (batch_size, 3) in order.
order:
Returns:
quaternion: torch tensor of shape (batch_size, 4) in (w, x, y, z) representation.
"""
assert euler.shape[-1] == 3
original_shape = list(euler.shape)
original_shape[-1] = 4
e = euler.reshape(-1, 3)
x = e[:, 0]
y = e[:, 1]
z = e[:, 2]
rx = torch.stack((torch.cos(x/2), torch.sin(x/2),
torch.zeros_like(x), torch.zeros_like(x)), dim=1)
ry = torch.stack((torch.cos(y/2), torch.zeros_like(y),
torch.sin(y/2), torch.zeros_like(y)), dim=1)
rz = torch.stack((torch.cos(z/2), torch.zeros_like(z),
torch.zeros_like(z), torch.sin(z/2)), dim=1)
result = None
for coord in order:
if coord == 'x':
r = rx
elif coord == 'y':
r = ry
elif coord == 'z':
r = rz
else:
raise Exception('unsupported euler order!')
if result is None:
result = r
else:
result = quaternion_mul(result, r)
# Reverse antipodal representation to have a non-negative "w"
if order in ['xyz', 'yzx', 'zxy']:
result *= -1
return result.reshape(original_shape)
def quaternion_to_axis_angle(quaternion):
"""
Convert quaternion to axis angle.
based on: https://github.com/facebookresearch/QuaterNet/blob/master/common/quaternion.py#L138
Args:
quaternion: torch tensor of shape (batch_size, 4) in (w, x, y, z) representation.
Returns:
axis_angle: torch tensor of shape (batch_size, 3)
"""
epsilon = 1.e-8
if not torch.is_tensor(quaternion):
raise TypeError("Input type is not a torch.Tensor. Got {}".format(
type(quaternion)))
if not quaternion.shape[-1] == 4:
raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}"
.format(quaternion.shape))
# unpack input and compute conversion
q1: torch.Tensor = quaternion[..., 1]
q2: torch.Tensor = quaternion[..., 2]
q3: torch.Tensor = quaternion[..., 3]
sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3
sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta+epsilon)
cos_theta: torch.Tensor = quaternion[..., 0]
two_theta: torch.Tensor = 2.0 * torch.where(
cos_theta < 0.0,
torch.atan2(-sin_theta, -cos_theta),
torch.atan2(sin_theta, cos_theta))
k_pos: torch.Tensor = two_theta / sin_theta
k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta)
k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3]
angle_axis[..., 0] += q1 * k
angle_axis[..., 1] += q2 * k
angle_axis[..., 2] += q3 * k
return angle_axis
def axis_angle_to_quaternion(axis_angle):
"""
Convert axis angle to quaternion.
Args:
axis_angle: torch tensor of shape (batch_size, 3)
Returns:
quaternion: torch tensor of shape (batch_size, 4) in (w, x, y, z) representation.
"""
rotation_matrix = axis_angle_to_rotation_matrix(axis_angle)
return rotation_matrix_to_quaternion(rotation_matrix)
def axis_angle_to_rotation_matrix(axis_angle):
"""
Convert axis-angle representation to rotation matrix.
Args:
axis_angle: torch tensor of shape (batch_size, 3).
Returns:
rotation_matrix: torch tensor of shape (batch_size, 3, 3) of corresponding rotation matrices.
"""
l1_norm = torch.norm(axis_angle+1e-8, p=2, dim=1)
angle = torch.unsqueeze(l1_norm, dim=-1)
normalized = torch.div(axis_angle, angle)
angle = angle * 0.5
v_cos = torch.cos(angle)
v_sin = torch.sin(angle)
quaternion = torch.cat([v_cos, v_sin*normalized], dim=1)
return quaternion_to_rotation_matrix(quaternion)
def rotation_matrix_to_axis_angle(rotation_matrix):
quaternion = rotation_matrix_to_quaternion(rotation_matrix)
return quaternion_to_axis_angle(quaternion)
def rotation_matrix_to_euler(rotation_matrix, order):
quaternion = rotation_matrix_to_quaternion(rotation_matrix)
return quaternion_to_euler(quaternion, order)
def euler_to_rotation_matrix(euler, order):
quaternion = euler_to_quaternion(euler, order)
return quaternion_to_rotation_matrix(quaternion)
def axis_angle_to_euler(axis_angle, order):
quaternion = axis_angle_to_quaternion(axis_angle)
return quaternion_to_euler(quaternion, order)
def euler_to_axis_angle(euler, order):
quaternion = euler_to_quaternion(euler, order)
return quaternion_to_axis_angle(quaternion)
# rotation operations
def quaternion_mul(q, r):
"""
Multiply quaternion(s) q with quaternion(s) r.
Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
Returns q*r as a tensor of shape (*, 4).
"""
assert q.shape[-1] == 4
assert r.shape[-1] == 4
original_shape = q.shape
# Compute outer product
terms = torch.bmm(r.contiguous().view(-1, 4, 1),
q.contiguous().view(-1, 1, 4))
w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
return torch.stack((w, x, y, z), dim=1).view(original_shape)
def rotate_vec_by_quaternion(v, q):
"""
Rotate vector(s) v about the rotation described by quaternion(s) q.
Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
where * denotes any number of dimensions.
Returns a tensor of shape (*, 3).
"""
assert q.shape[-1] == 4
assert v.shape[-1] == 3
assert q.shape[:-1] == v.shape[:-1]
original_shape = list(v.shape)
q = q.contiguous().view(-1, 4)
v = v.view(-1, 3)
qvec = q[:, 1:]
uv = torch.cross(qvec, v, dim=1)
uuv = torch.cross(qvec, uv, dim=1)
return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
def quaternion_fix(quaternion):
"""
Enforce quaternion continuity across the time dimension by selecting
the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
between two consecutive frames.
Args:
quaternion: torch tensor of shape (batch_size, 4)
Returns:
quaternion: torch tensor of shape (batch_size, 4)
"""
quaternion_fixed = quaternion.clone()
dot_products = torch.sum(quaternion[1:]*quaternion[:-1],dim=-1)
mask = dot_products < 0
mask = (torch.cumsum(mask, dim=0) % 2).bool()
quaternion_fixed[1:][mask] *= -1
return quaternion_fixed
def quaternion_inverse(quaternion):
q_conjugate = quaternion.clone()
q_conjugate[::, 1:] * -1
q_norm = quaternion[::, 1:].norm(dim=-1) + quaternion[::, 0]**2
return q_conjugate/q_norm.unsqueeze(-1)
def quaternion_lerp(q1, q2, t):
q = (1-t)*q1 + t*q2
q = q/q.norm(dim=-1).unsqueeze(-1)
return q
def geodesic_dist(q1,q2):
"""
@q1: torch tensor of shape (frame, joints, 4) quaternion
@q2: same as q1
@output: torch tensor of shape (frame, joints)
"""
q1_conjugate = q1.clone()
q1_conjugate[:,:,1:] *= -1
q1_norm = q1[:,:,1:].norm(dim=-1) + q1[:,:,0]**2
q1_inverse = q1_conjugate/q1_norm.unsqueeze(dim=-1)
q_between = quaternion_mul(q1_inverse,q2)
geodesic_dist = quaternion_to_axis_angle(q_between).norm(dim=-1)
return geodesic_dist
def get_extrinsic(translation, rotation):
batch_size = translation.shape[0]
pose = torch.zeros((batch_size, 4, 4))
pose[:,:3, :3] = rotation
pose[:,:3, 3] = translation
pose[:,3, 3] = 1
extrinsic = torch.inverse(pose)
return extrinsic[:,:3, 3], extrinsic[:,:3, :3]
def euler_fix_old(euler):
frame_num = euler.shape[0]
joint_num = euler.shape[1]
for l in range(3):
for j in range(joint_num):
overall_add = 0.
for i in range(1,frame_num):
add1 = overall_add
add2 = overall_add + 2*np.pi
add3 = overall_add - 2*np.pi
previous = euler[i-1,j,l]
value1 = euler[i,j,l] + add1
value2 = euler[i,j,l] + add2
value3 = euler[i,j,l] + add3
e1 = torch.abs(value1 - previous)
e2 = torch.abs(value2 - previous)
e3 = torch.abs(value3 - previous)
if (e1 <= e2) and (e1 <= e3):
euler[i,j,l] = value1
overall_add = add1
if (e2 <= e1) and (e2 <= e3):
euler[i, j, l] = value2
overall_add = add2
if (e3 <= e1) and (e3 <= e2):
euler[i, j, l] = value3
overall_add = add3
return euler
def euler_fix(euler,rotation_order='zyx'):
frame_num = euler.shape[0]
joint_num = euler.shape[1]
euler_new = euler.clone()
for j in range(joint_num):
euler_new[:,j] = euler_filter(euler[:,j],rotation_order)
return euler_new
'''
euler filter from https://github.com/wesen/blender-euler-filter/blob/master/euler_filter.py.
'''
def euler_distance(e1, e2):
return abs(e1[0] - e2[0]) + abs(e1[1] - e2[1]) + abs(e1[2] - e2[2])
def euler_axis_index(axis):
if axis == 'x':
return 0
if axis == 'y':
return 1
if axis == 'z':
return 2
return None
def flip_euler(euler, rotation_mode):
ret = euler.clone()
inner_axis = rotation_mode[0]
outer_axis = rotation_mode[2]
middle_axis = rotation_mode[1]
ret[euler_axis_index(inner_axis)] += np.pi
ret[euler_axis_index(outer_axis)] += np.pi
ret[euler_axis_index(middle_axis)] *= -1
ret[euler_axis_index(middle_axis)] += np.pi
return ret
def naive_flip_diff(a1, a2):
while abs(a1 - a2) >= np.pi+1e-5:
if a1 < a2:
a2 -= 2 * np.pi
else:
a2 += 2 * np.pi
return a2
def euler_filter(euler,rotation_order):
frame_num = euler.shape[0]
if frame_num <= 1:
return euler
euler_fix = euler.clone()
prev = euler[0]
for i in range(1,frame_num):
e = euler[i]
for d in range(3):
e[d] = naive_flip_diff(prev[d],e[d])
fe = flip_euler(e,rotation_order)
for d in range(3):
fe[d] = naive_flip_diff(prev[d],fe[d])
de = euler_distance(prev,e)
dfe = euler_distance(prev,fe)
if dfe < de:
e = fe
prev = e
euler_fix[i] = e
return euler_fix

View File

@ -0,0 +1,97 @@
'''
@ Date: 2022-07-28 14:39:23
@ Author: Qing Shuai
@ Mail: s_q@zju.edu.cn
@ LastEditors: Qing Shuai
@ LastEditTime: 2022-08-12 21:42:12
@ FilePath: /EasyMocapPublic/easymocap/multistage/totalfitting.py
'''
import torch
from ..bodymodel.lbs import batch_rodrigues
from .torchgeometry import rotation_matrix_to_axis_angle, rotation_matrix_to_quaternion, quaternion_to_rotation_matrix, quaternion_to_axis_angle
import numpy as np
from .base_ops import BeforeAfterBase
def compute_twist_rotation(rotation_matrix, twist_axis):
'''
Compute the twist component of given rotation and twist axis
https://stackoverflow.com/questions/3684269/component-of-a-quaternion-rotation-around-an-axis
Parameters
----------
rotation_matrix : Tensor (B, 3, 3,)
The rotation to convert
twist_axis : Tensor (B, 3,)
The twist axis
Returns
-------
Tensor (B, 3, 3)
The twist rotation
'''
quaternion = rotation_matrix_to_quaternion(rotation_matrix)
twist_axis = twist_axis / (torch.norm(twist_axis, dim=1, keepdim=True) + 1e-9)
projection = torch.einsum('bi,bi->b', twist_axis, quaternion[:, 1:]).unsqueeze(-1) * twist_axis
twist_quaternion = torch.cat([quaternion[:, 0:1], projection], dim=1)
twist_quaternion = twist_quaternion / (torch.norm(twist_quaternion, dim=1, keepdim=True) + 1e-9)
twist_rotation = quaternion_to_rotation_matrix(twist_quaternion)
twist_aa = quaternion_to_axis_angle(twist_quaternion)
twist_angle = torch.sum(twist_aa, dim=1, keepdim=True) / torch.sum(twist_axis, dim=1, keepdim=True)
return twist_rotation, twist_angle
class ClearTwist(BeforeAfterBase):
def start(self, body_params):
idx_elbow = [18-1, 19-1]
for idx in idx_elbow:
# x
body_params['poses'][:, 3*idx] = 0.
# z
body_params['poses'][:, 3*idx+2] = 0.
idx_wrist = [20-1, 21-1]
for idx in idx_wrist:
body_params['poses'][:, 3*idx:3*idx+3] = 0.
return body_params
class SolveTwist(BeforeAfterBase):
def __init__(self, body_model=None) -> None:
self.body_model = body_model
def final(self, body_params):
T_joints, T_vertices = self.body_model.transform(body_params)
# This transform don't consider RT
R = batch_rodrigues(body_params['Rh'])
template = self.body_model.keypoints({'shapes': body_params['shapes'],
'poses': torch.zeros_like(body_params['poses'])},
only_shape=True, return_smpl_joints=True)
config = {
'left': {
'index_smpl': 20,
'index_elbow_smpl': 18,
'R_global': 'R_handl3d',
'axis': torch.Tensor([[1., 0., 0.]]).to(device=T_joints.device),
},
'right': {
'index_smpl': 21,
'index_elbow_smpl': 19,
'R_global': 'R_handr3d',
'axis': torch.Tensor([[-1., 0., 0.]]).to(device=T_joints.device),
}
}
for key in ['left', 'right']:
cfg = config[key]
R_wrist_add = batch_rodrigues(body_params[cfg['R_global']])
idx_elbow = cfg['index_elbow_smpl']
idx_wrist = cfg['index_smpl']
pred_parent_elbow = R @ T_joints[..., idx_elbow, :3, :3]
pred_parent_wrist = R @ T_joints[..., idx_wrist, :3, :3]
pred_global_wrist = torch.bmm(R_wrist_add, pred_parent_wrist)
pred_local_wrist = torch.bmm(pred_parent_wrist.transpose(-1, -2), pred_global_wrist)
axis = rotation_matrix_to_axis_angle(pred_local_wrist)
body_params['poses'][..., 3*idx_wrist-3:3*idx_wrist] = axis
return body_params