# EasyMocap/easymocap/bodymodel/base.py
# (web-viewer listing: 135 lines, 4.1 KiB, Python; revision dated 2022-08-21 16:02:11 +08:00)
'''
@ Date: 2022-03-17 19:23:59
@ Author: Qing Shuai
@ Mail: s_q@zju.edu.cn
@ LastEditors: Qing Shuai
@ LastEditTime: 2022-07-15 12:15:46
@ FilePath: /EasyMocapPublic/easymocap/bodymodel/base.py
'''
import numpy as np
import torch
from ..mytools.file_utils import myarray2string
class Model(torch.nn.Module):
    """Base class for body models.

    The base ``forward`` is a no-argument placeholder, while ``vertices`` and
    ``keypoints`` call ``forward`` with keyword arguments; a subclass therefore
    has to override ``forward`` with a compatible signature before those two
    helpers can be used.
    """

    def __init__(self) -> None:
        super().__init__()
        # Default model name set by the base class.
        self.name = 'custom'

    def forward(self):
        """Placeholder implementation: accepts no arguments, returns ``None``."""
        return None

    def vertices(self, params, **kwargs):
        """Run ``forward`` with ``return_verts=True``.

        ``params`` is unpacked as keyword arguments after ``kwargs``.
        """
        result = self.forward(return_verts=True, **kwargs, **params)
        return result

    def keypoints(self, params, **kwargs):
        """Run ``forward`` with ``return_verts=False``.

        ``params`` is unpacked as keyword arguments after ``kwargs``.
        """
        result = self.forward(return_verts=False, **kwargs, **params)
        return result

    def transform(self, params, **kwargs):
        """Not provided by the base class; always raises ``NotImplementedError``."""
        raise NotImplementedError
class ComposedModel(torch.nn.Module):
    """Placeholder for a model assembled from several per-model configs.

    Only the constructor skeleton exists so far: it walks ``config_dict`` but
    does not build anything yet.
    """

    def __init__(self, config_dict):
        """Initialise the module.

        Args:
            config_dict: mapping of ``name -> config``; it must provide
                ``.items()``. The entries are currently ignored.
        """
        # torch.nn.Module must be initialised before the instance is used
        # (parameter/sub-module registration, ``.parameters()``, ...).
        super().__init__()
        # Stack the configurations of multiple models (not implemented yet).
        for name, config in config_dict.items():
            pass
class Params(dict):
    """Dictionary of model parameters with attribute access and frame indexing.

    String keys behave as in a normal ``dict``; an ``int`` key selects one
    entry along the first axis of every stored value instead.  Values are
    expected to expose ``.shape`` and support indexing (NumPy arrays are what
    ``__str__``/``shape`` special-case; other array types are presumably
    usable too -- TODO confirm).
    """

    @classmethod
    def merge(cls, params_list, share_shape=True, stack=np.vstack):
        """Stack the parameters of several dictionaries key by key.

        Args:
            params_list: non-empty sequence of mappings that share the keys of
                the first element.  The ``'id'`` key is skipped.
            share_shape: if true and a ``'shapes'`` entry exists, replace it by
                its mean over the first axis (kept as a leading axis of 1).
            stack: callable that combines the list of per-item values.

        Returns:
            A plain ``dict`` holding the stacked values.
        """
        output = {
            key: stack([params[key] for params in params_list])
            for key in params_list[0].keys()
            if key != 'id'
        }
        # Data without a 'shapes' entry has nothing to share; skip instead of
        # failing with KeyError.
        if share_shape and 'shapes' in output:
            output['shapes'] = output['shapes'].mean(axis=0, keepdims=True)
        return output

    def __len__(self):
        """Return the length of the ``'poses'`` entry (not the number of keys)."""
        return len(self['poses'])

    def __getattr__(self, name):
        """Expose dictionary entries as attributes (read-only)."""
        # Only reached when normal attribute lookup fails.
        try:
            return super().__getitem__(name)
        except KeyError:
            raise AttributeError(name) from None

    @staticmethod
    def _pick_frame(val, index):
        """Return ``val[index]``, reusing entry 0 when ``val`` has a single entry."""
        if index >= 1 and val.shape[0] == 1:
            return val[0]
        return val[index]

    def __getitem__(self, index):
        """Look up a key, or select one entry of every value when ``index`` is an int.

        For an ``int`` index a new ``Params`` is returned.  Values whose first
        axis has length 1 are shared across all indices.  When a ``'shapes'``
        entry is present, the result always carries an ``'id'`` (default 0)
        and every 1-D result gets a new leading axis.

        Raises:
            ValueError: if ``shapes`` has more than one entry and its count
                differs from the number of poses.
        """
        if not isinstance(index, int):
            return super().__getitem__(index)
        pick = self._pick_frame
        if 'shapes' not in self.keys():
            # Arbitrary data: index every value the same way.
            return Params(**{key: pick(val, index) for key, val in self.items()})
        ret = {'id': 0}
        poses = self.poses
        shapes = self.shapes
        # Give shapes as many axes as poses so that their first axes compare.
        while len(shapes.shape) < len(poses.shape):
            shapes = shapes[None]
        if poses.shape[0] == shapes.shape[0]:
            ret['shapes'] = pick(shapes, index)
        elif shapes.shape[0] == 1:
            # One shape shared by all poses.
            ret['shapes'] = shapes[0]
        else:
            # This used to drop into an interactive debugger; fail loudly instead.
            raise ValueError(
                'Cannot index Params: {} poses but {} shapes '
                '(expected equal counts or a single shared shape)'.format(
                    poses.shape[0], shapes.shape[0]))
        ret['poses'] = pick(poses, index)
        for key, val in self.items():
            if key == 'id':
                # The id is copied as is, never indexed.
                ret[key] = val
                continue
            if key in ret:
                continue
            ret[key] = pick(val, index)
        # Restore a leading axis on values that became 1-D by indexing.
        for key, val in ret.items():
            if key == 'id':
                continue
            if len(val.shape) == 1:
                ret[key] = val[None]
        return Params(**ret)

    def to_multiperson(self, pids):
        """Split into one ``Params`` per person id.

        The i-th element of ``pids`` is paired with ``self[i]`` and stored
        under the ``'id'`` key of that result.

        Returns:
            A list of ``Params``, one per entry of ``pids``.
        """
        results = []
        for i, pid in enumerate(pids):
            param = self[i]
            # The class only implements __getattr__, so ``param.id = pid``
            # would set an instance attribute rather than the dictionary
            # entry; assign through the key instead.
            param['id'] = pid
            results.append(param)
        return results

    def __str__(self) -> str:
        """Format every entry as ``"key": value``, joined by ``',\\n'``.

        NumPy arrays are rendered with ``myarray2string``; other values with
        ``str``.  An empty mapping yields an empty string.
        """
        entries = []
        for key, val in self.items():
            if isinstance(val, np.ndarray):
                text = myarray2string(val, indent=0)
            else:
                text = str(val)
            entries.append('"{}": '.format(key) + text)
        return ',\n'.join(entries)

    def shape(self):
        """Print and return a summary with the shape of every NumPy array entry.

        Non-array values are shown via ``str``.  An empty mapping yields an
        empty string.
        """
        entries = []
        for key, val in self.items():
            if isinstance(val, np.ndarray):
                entries.append('"{}": {}'.format(key, val.shape))
            else:
                entries.append('"{}": '.format(key) + str(val))
        ret = ',\n'.join(entries)
        print(ret)
        return ret