# Source: EasyMocap/easymocap/dataset/filter.py
# (scraped from a repository web view: 111 lines, 3.9 KiB, Python; page timestamp 2021-06-25 15:54:54 +08:00)
import numpy as np
class BaseCrit:
    """Base class for the per-detection filter criteria in this module.

    A criterion is a callable object: it receives one detection and returns
    ``True`` to keep it or ``False`` to drop it. This base implementation
    keeps every detection.
    """

    def __init__(self, log, **kwargs) -> None:
        """Store the logging flag.

        Args:
            log: Stored as ``self.log``; criteria that print diagnostics
                check this flag before printing.
            **kwargs: Accepted and ignored, so a factory can pass the same
                keyword set to every criterion.
        """
        self.log = log

    def __call__(self, keypoints, bbox, **kwargs) -> bool:
        """Accept every detection regardless of its keypoints or bbox."""
        return True

    def __str__(self) -> str:
        """Return a short human-readable description of the criterion."""
        return "default filter"
class CritMinJoints(BaseCrit):
    """Keep a detection only if enough of its joints were detected."""

    def __init__(self, min_joints, log, **kwargs):
        """Configure the joint-count threshold.

        Args:
            min_joints: A detection is kept only when it has strictly more
                than this many joints with positive confidence.
            log: Logging flag forwarded to ``BaseCrit``.
            **kwargs: Accepted and ignored.
        """
        super().__init__(log)
        self.min_joints = min_joints

    def __call__(self, keypoints, **kwargs) -> bool:
        """Return True when more than ``min_joints`` joints have confidence > 0.

        ``keypoints`` is indexed as ``keypoints[:, 2]``, i.e. a 2-D array
        whose third column holds the per-joint confidence.
        """
        detected = (keypoints[:, 2] > 0.).sum()
        # bool() converts the NumPy comparison result to a plain Python bool.
        return bool(detected > self.min_joints)

    def __str__(self) -> str:
        """Return a short human-readable description of the criterion."""
        return "remove the detections less than {} joints".format(self.min_joints)
class CritWithTorso(BaseCrit):
    """Keep a detection only if all of its torso joints are confident."""

    def __init__(self, torso_idx, min_conf, log, **kwargs) -> None:
        """Configure which joints are required and how confident they must be.

        Args:
            torso_idx: Index (or indices) of the required joints; used to
                index the first axis of ``keypoints``.
            min_conf: Each required joint must have confidence strictly
                greater than this value.
            log: Logging flag forwarded to ``BaseCrit``.
            **kwargs: Accepted and ignored.
        """
        super().__init__(log)
        self.idx = torso_idx
        self.min_conf = min_conf

    def __call__(self, keypoints, bbox, **kwargs) -> bool:
        """Return True when every joint in ``self.idx`` exceeds ``min_conf``.

        ``bbox`` is accepted for interface compatibility and not used.
        """
        torso_conf = keypoints[self.idx, 2]
        # bool() makes the declared ``-> bool`` return type accurate.
        return bool((torso_conf > self.min_conf).all())

    def __str__(self) -> str:
        """Return a short human-readable description of the criterion."""
        return "remove the human without torso {}".format(self.idx)
class CritNoBorder(BaseCrit):
    """Reject detections at the left/right image border with collapsed limbs.

    When the bounding box reaches into the left or right border margin, the
    distance between paired joints (``leftidx[i]`` vs ``rightidx[i]``) is
    measured relative to the box height. If any pair is closer than 1% of
    the box height, the detection is rejected.
    """

    def __init__(self, rate, height, width, log) -> None:
        """Configure the border margin.

        Args:
            rate: Fraction of the longer image side used as border margin.
            height: Image height.
            width: Image width.
            log: When truthy, diagnostics are printed for border detections.
        """
        super().__init__(log)
        self.height = height
        self.width = width
        self.border = rate * max(self.height, self.width)
        # Paired joint indices compared element-wise in __call__.
        # NOTE(review): presumably left/right limb joints of the skeleton
        # layout used by the caller - confirm against the keypoint format.
        self.leftidx = [3, 4, 10, 11]
        self.rightidx = [6, 7, 13, 14]

    def __call__(self, keypoints, bbox, **kwargs) -> bool:
        """Return False for a side-border detection whose paired joints coincide.

        Args:
            keypoints: 2-D array; columns 0-1 are used as joint coordinates.
            bbox: Sequence whose first four values are left, top, right,
                bottom. Any further values (e.g. a score) are ignored.
        """
        left, top, right, bottom = bbox[:4]
        # A box truncated at the top edge is deliberately not treated as an
        # error; only the left/right borders trigger the limb check.
        if left < self.border or right > self.width - self.border:
            if self.log:
                print('[Crit2d]: {}'.format(' '.join(['%8.3f'%(i) for i in bbox])))
                print('[Error] Left or right')
            pair_dist = np.linalg.norm(
                keypoints[self.leftidx, :2] - keypoints[self.rightidx, :2], axis=1)
            bbox_size = bottom - top
            # A zero-height box yields inf/nan ratios; both compare False
            # against the threshold below, so such a detection is kept.
            # errstate only silences the divide-by-zero warnings.
            with np.errstate(divide='ignore', invalid='ignore'):
                ratio = pair_dist / bbox_size
            return not bool(ratio.min() < 1e-2)
        if bottom > self.height:
            # Reported for diagnostics only; the detection is still kept.
            if self.log:
                print('[Error] bottom')
        return True

    def __str__(self) -> str:
        """Return a short human-readable description of the criterion."""
        return "remove the human in the border"
class ComposedFilter:
    """Clean up low-confidence joints, then apply a list of criteria in order."""

    def __init__(self, filters, min_conf) -> None:
        """Store the criteria and the confidence threshold.

        Args:
            filters: Iterable of callables invoked as
                ``filt(keypoints=..., **kwargs)``; a falsy result rejects
                the detection.
            min_conf: Joints with confidence below this value are zeroed.
        """
        self.filters = filters
        self.min_conf = min_conf

    def __call__(self, keypoints, **kwargs) -> bool:
        """Return True when every criterion accepts the detection.

        ``keypoints`` is modified in place: confidences below ``min_conf``
        are set to 0 and the coordinates of those joints are replaced by the
        mean position of the joints whose confidence exceeds ``min_conf``.
        When no joint exceeds ``min_conf`` the coordinates are left untouched
        (there is no valid mean; taking it would fill them with NaN).

        Args:
            keypoints: 2-D array; columns 0-1 are coordinates, column 2 is
                the confidence.
            **kwargs: Forwarded unchanged to every criterion.
        """
        conf = keypoints[:, 2]  # view: writes go through to ``keypoints``
        conf[conf < self.min_conf] = 0
        valid = conf > self.min_conf
        if valid.any():
            center = keypoints[valid, :2].mean(axis=0, keepdims=True)
            keypoints[conf < self.min_conf, :2] = center
        # all() stops at the first rejecting criterion, like an early return.
        return all(filt(keypoints=keypoints, **kwargs) for filt in self.filters)

    def nms(self, annots):
        """Return ``annots`` unchanged (placeholder, no suppression is done).

        An earlier draft computed pairwise mean joint distances normalised
        by bbox size here but discarded the result; that dead computation
        was removed because it could emit divide-by-zero warnings without
        affecting the output.
        """
        return annots

    def __str__(self) -> str:
        """Return an indented, one-line-per-criterion summary."""
        indent = ' ' * 4
        lines = [indent + 'Composed Filters: \n']
        lines.extend(
            indent + indent + '{:15s}'.format(filt.__class__.__name__) + ': ' + str(filt) + '\n'
            for filt in self.filters
        )
        return ''.join(lines)
def make_filter(param):
    """Build, print and return a ``ComposedFilter`` from a configuration object.

    Args:
        param: Object exposing ``filter`` (mapping of criterion class name
            to its keyword options), ``log``, ``width``, ``height`` and
            ``min_conf``. Each class name is looked up among this module's
            globals and instantiated with ``log``, ``width``, ``height``
            plus its own options.

    Returns:
        The composed filter; its summary is also printed.

    Raises:
        KeyError: If a configured name is not defined in this module.
    """
    registry = globals()
    filters = []
    for key, val in param.filter.items():
        try:
            crit_cls = registry[key]
        except KeyError:
            raise KeyError(
                "unknown filter '{}': no such name in {}".format(key, __name__)
            ) from None
        # A criterion configured without options may arrive as None.
        options = val or {}
        filters.append(
            crit_cls(log=param.log, width=param.width, height=param.height, **options))
    comp = ComposedFilter(filters, param.min_conf)
    print(comp)
    return comp