135 lines
4.1 KiB
Python
135 lines
4.1 KiB
Python
'''
|
|
@ Date: 2022-03-17 19:23:59
|
|
@ Author: Qing Shuai
|
|
@ Mail: s_q@zju.edu.cn
|
|
@ LastEditors: Qing Shuai
|
|
@ LastEditTime: 2022-07-15 12:15:46
|
|
@ FilePath: /EasyMocapPublic/easymocap/bodymodel/base.py
|
|
'''
|
|
import numpy as np
|
|
import torch
|
|
|
|
from ..mytools.file_utils import myarray2string
|
|
|
|
class Model(torch.nn.Module):
    """Minimal base class for body models.

    ``vertices`` and ``keypoints`` are thin wrappers around ``forward``: they
    forward ``kwargs`` and the entries of ``params`` as keyword arguments and
    set ``return_verts`` to True or False respectively.
    """

    def __init__(self) -> None:
        super().__init__()
        # default model name
        self.name = 'custom'

    def forward(self):
        """Placeholder; this base implementation does nothing and returns None."""
        return None

    def vertices(self, params, **kwargs):
        """Call ``forward`` with ``return_verts=True``.

        ``kwargs`` and the entries of ``params`` are passed on as keyword
        arguments; a key present in both (or named ``return_verts``) makes the
        call raise ``TypeError``.
        """
        return self.forward(return_verts=True, **kwargs, **params)

    def keypoints(self, params, **kwargs):
        """Call ``forward`` with ``return_verts=False``.

        Arguments are passed on exactly as in ``vertices``.
        """
        return self.forward(return_verts=False, **kwargs, **params)

    def transform(self, params, **kwargs):
        """Not implemented in the base class."""
        raise NotImplementedError
|
|
|
|
class ComposedModel(torch.nn.Module):
    """Model meant to be assembled from several sub-model configurations.

    Composition is not implemented yet: the constructor only initialises the
    ``torch.nn.Module`` machinery and iterates over ``config_dict``.
    """

    def __init__(self, config_dict):
        """Create the (currently empty) composed model.

        Args:
            config_dict: mapping iterated with ``.items()``; presumably
                name -> per-model config (TODO confirm against callers).
        """
        # torch.nn.Module.__init__ must run before anything else: it creates
        # the internal parameter/buffer/submodule registries that every Module
        # method (parameters(), children(), submodule assignment) relies on.
        super().__init__()
        # Combine (stack) the configurations of multiple models.
        # TODO: building the sub-models is not implemented; this loop is a stub.
        for name, config in config_dict.items():
            pass
|
|
|
|
class Params(dict):
    """Dictionary of (batched) model parameters with per-frame indexing.

    Keys can be read as attributes (``params.poses``). Indexing with a string
    behaves like a normal ``dict``; indexing with an integer selects one row
    along the leading axis of the values and returns a new ``Params`` (see
    ``__getitem__``).
    """

    @staticmethod
    def _take(val, index):
        # Select row `index` of `val`; a single-row value is reused (row 0)
        # for any positive index, i.e. broadcast across all rows.
        if index >= 1 and val.shape[0] == 1:
            return val[0]
        return val[index]

    @classmethod
    def merge(cls, params_list, share_shape=True, stack=np.vstack):
        """Stack a list of parameter dicts key-by-key into one dict.

        Args:
            params_list: non-empty sequence of mappings; the keys of the first
                entry decide which keys are merged. The 'id' key is skipped.
            share_shape: if True and 'shapes' was merged, replace it by its
                mean over axis 0 (keeping that axis) so all entries share one
                shape vector.
            stack: function combining the list of per-entry values
                (default ``np.vstack``).

        Returns:
            A plain ``dict`` of stacked values.

        Raises:
            IndexError: if ``params_list`` is empty.
        """
        output = {}
        for key in params_list[0].keys():
            if key == 'id':
                continue
            output[key] = stack([v[key] for v in params_list])
        # Only average when shapes are present: data without 'shapes' (the
        # "arbitrary data" case of __getitem__) must be mergeable as well.
        if share_shape and 'shapes' in output:
            output['shapes'] = output['shapes'].mean(axis=0, keepdims=True)
        return output

    def __len__(self):
        # NOTE: overrides dict.__len__, so len() and truthiness (`if params:`)
        # report the number of rows of 'poses' and raise KeyError without it.
        return len(self['poses'])

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails: exposes keys as
        # read-only attributes. Assigning `params.x = ...` does NOT create a key.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __getitem__(self, index):
        """Return the value for a non-integer key, or one row for an integer.

        For an integer ``index`` (Python or NumPy integer):

        * without a 'shapes' key, every value that has a ``.shape`` is sliced
          along its leading axis (single-row values are reused for positive
          indices); values without ``.shape`` (e.g. an int 'id') are copied;
        * with 'shapes', the result always has an 'id' (copied from self,
          default 0), 'shapes' is matched against 'poses' (a single shared
          row is allowed), every other value is sliced, and 1-D results get
          a leading axis of size 1.

        Raises:
            ValueError: if 'shapes' has neither one row nor as many rows as
                'poses'.
        """
        if not isinstance(index, (int, np.integer)):
            return super().__getitem__(index)
        if 'shapes' not in self:
            # arbitrary data: slice array-like values, pass the rest through
            ret = {key: self._take(val, index) if hasattr(val, 'shape') else val
                   for key, val in self.items()}
            return Params(**ret)
        ret = {'id': 0}
        poses = self.poses
        shapes = self.shapes
        # give shapes at least as many dims as poses so axis 0 lines up
        while len(shapes.shape) < len(poses.shape):
            shapes = shapes[None]
        if poses.shape[0] == shapes.shape[0]:
            ret['shapes'] = self._take(shapes, index)
        elif shapes.shape[0] == 1:
            # one shape vector shared by every row
            ret['shapes'] = shapes[0]
        else:
            raise ValueError(
                'shapes with {} rows cannot be matched to poses with {} rows'.format(
                    shapes.shape[0], poses.shape[0]))
        ret['poses'] = self._take(poses, index)
        for key, val in self.items():
            if key == 'id':
                ret[key] = val
            elif key not in ret:
                ret[key] = self._take(val, index)
        # restore a leading batch axis on 1-D results
        for key, val in ret.items():
            if key != 'id' and len(val.shape) == 1:
                ret[key] = val[None]
        return Params(**ret)

    def to_multiperson(self, pids):
        """Split into one ``Params`` per person, tagged with the given ids.

        Row ``i`` is taken with ``self[i]`` and its 'id' set to the i-th pid.
        """
        results = []
        for i, pid in enumerate(pids):
            param = self[i]
            # Must use item assignment: Params only implements __getattr__, so
            # `param.id = pid` would set an instance attribute, not the key.
            param['id'] = pid
            results.append(param)
        return results

    def __str__(self) -> str:
        """Render as ``"key": value`` entries joined by a comma and newline.

        numpy arrays are formatted with ``myarray2string(val, indent=0)``,
        other values with ``str``; an empty ``Params`` renders as ''.
        """
        def fmt(key, val):
            if isinstance(val, np.ndarray):
                return '"{}": '.format(key) + myarray2string(val, indent=0)
            return '"{}": '.format(key) + str(val)
        return ',\n'.join(fmt(key, val) for key, val in self.items())

    def shape(self):
        """Print and return a per-key summary of the stored values.

        numpy arrays are shown by their ``.shape``, other values via ``str``;
        entries are joined by a comma and newline ('' when empty).
        """
        ret = ',\n'.join(
            '"{}": {}'.format(key, val.shape) if isinstance(val, np.ndarray)
            else '"{}": '.format(key) + str(val)
            for key, val in self.items())
        print(ret)
        return ret