EasyMocap/myeasymocap/backbone/yolo/yolo.py
2023-07-10 22:10:41 +08:00

307 lines
11 KiB
Python

import torch
import numpy as np
import os
import cv2
from os.path import join
import pickle
def check_modelpath(paths):
    """Resolve a model checkpoint path.

    Args:
        paths: a single path (str) that must exist, or a list of candidate
            paths from which the first existing one is chosen.

    Returns:
        str: the resolved, existing path.

    Raises:
        AssertionError: a single str path was given but does not exist.
        FileNotFoundError: a list was given and none of the candidates exist.
        NotImplementedError: ``paths`` is neither a str nor a list.
    """
    if isinstance(paths, str):
        assert os.path.exists(paths), paths
        return paths
    if isinstance(paths, list):
        for path in paths:
            if os.path.exists(path):
                print(f'Found model in {path}')
                return path
        print(f'No model found in {paths}!')
        # BUGFIX: was `raise FileExistsError` — the failure here is that the
        # file was NOT found, so FileNotFoundError is the correct type.
        raise FileNotFoundError(f'No model found in {paths}')
    raise NotImplementedError
class BaseYOLOv5:
    """YOLOv5-based single-class 2D detector with on-disk caching of the raw
    detections (one ``.npy`` file per input image)."""

    def __init__(self, ckpt=None, model='yolov5m', name='object2d', multiview=True) -> None:
        # Load either a custom checkpoint or a stock yolov5 variant from torch hub.
        if ckpt is None:
            print('[{}] Not given ckpt, use default yolov5'.format(self.__class__.__name__))
            self.model = torch.hub.load('ultralytics/yolov5', model)
        else:
            ckpt = check_modelpath(ckpt)
            self.model = torch.hub.load('ultralytics/yolov5', 'custom', ckpt)
        self.multiview = multiview
        self.name = name

    def check_cache(self, imgname):
        """Return (hit, cached_array_or_None, cachename) for this image.

        Cache layout: <self.output>/<self.name>/<view-dir>/<frame>.npy
        NOTE(review): relies on ``self.output`` being assigned externally —
        it is never set in ``__init__``.
        """
        basename = os.path.basename(imgname)
        ext = '.' + basename.split('.')[-1]
        view = imgname.split(os.sep)[-2]
        cachename = join(self.output, self.name, view, basename.replace(ext, '.npy'))
        os.makedirs(os.path.dirname(cachename), exist_ok=True)
        if not os.path.exists(cachename):
            return False, None, cachename
        return True, np.load(cachename, allow_pickle=True), cachename

    def check_image(self, img_or_name):
        """Accept a filename or a BGR image array; return an RGB array."""
        img = cv2.imread(img_or_name) if isinstance(img_or_name, str) else img_or_name
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    @torch.no_grad()
    def detect(self, image, imgname):
        """Run the detector on one image, using/filling the npy cache."""
        hit, cached, cachename = self.check_cache(imgname)
        if hit:
            return cached
        # NOTE(review): the in-memory `image` argument is ignored; the frame
        # is re-read from `imgname` — confirm whether this is intentional.
        rgb = self.check_image(imgname)
        results = self.model(rgb)  # model expects RGB input
        arrays = np.array(results.pandas().xyxy[0])
        np.save(cachename, arrays)
        return arrays

    @staticmethod
    def select_class(results, name):
        """Keep only rows whose class name (column 6) matches; return [x1,y1,x2,y2,conf] per row."""
        return [res[:5] for res in results if res[6] == name]

    def select_bbox(self, select, imgname):
        """Naive selection: keep the single box with the highest confidence."""
        if select.shape[0] == 0:
            return select
        best = np.argsort(select[:, -1])[::-1][:1]
        return select[best]

    def __call__(self, images, imgnames):
        # Defaults to multi-view input; single images/names are wrapped into
        # one-element lists and unwrapped ("squeezed") on return.
        squeeze = not isinstance(images, list)
        if squeeze:
            images, imgnames = [images], [imgnames]
        boxes = []
        for img, fname in zip(images, imgnames):
            raw = self.detect(img, fname)
            picked = self.select_class(raw, self.name)
            if picked:
                arr = np.stack(picked).astype(np.float32)
            else:
                arr = np.zeros((0, 5), dtype=np.float32)
            # TODO: add track here
            boxes.append(self.select_bbox(arr, fname))
        if squeeze:
            return {'bbox': boxes[0]}
        return {'bbox': boxes}
class YoloWithTrack(BaseYOLOv5):
    """Detector with naive single-target tracking: after the first frame of a
    sequence, the box with the highest IoU against the previously selected
    box is kept."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Maps sequence name -> {'frame': [int, ...], 'bbox': [array, ...]}
        self.track_cache = {}

    @staticmethod
    def calculate_iou(bbox_pre, bbox_now):
        """IoU between broadcastable (N, 4+) box arrays in xyxy format."""
        area_now = (bbox_now[:, 2] - bbox_now[:, 0]) * (bbox_now[:, 3] - bbox_now[:, 1])
        area_pre = (bbox_pre[:, 2] - bbox_pre[:, 0]) * (bbox_pre[:, 3] - bbox_pre[:, 1])
        # Intersection rectangle: max of left/top edges, min of right/bottom.
        left = np.maximum(bbox_now[:, 0], bbox_pre[:, 0])
        top = np.maximum(bbox_now[:, 1], bbox_pre[:, 1])
        right = np.minimum(bbox_now[:, 2], bbox_pre[:, 2])
        bottom = np.minimum(bbox_now[:, 3], bbox_pre[:, 3])
        inter = np.maximum(0, right - left) * np.maximum(0, bottom - top)
        return inter / (area_pre + area_now - inter)

    def select_bbox(self, select, imgname):
        """Pick one box per frame; IoU-match against the previous selection."""
        if select.shape[0] == 0:
            return select
        # Sequence id = parent directory name; frame id = numeric file stem.
        sub = os.path.basename(os.path.dirname(imgname))
        frame = int(os.path.basename(imgname).split('.')[0])
        if sub not in self.track_cache:
            # First frame of this sequence: fall back to score-based selection.
            picked = super().select_bbox(select, imgname)
            self.track_cache[sub] = {'frame': [frame], 'bbox': [picked]}
            return picked
        prev = self.track_cache[sub]['bbox'][-1]
        best = self.calculate_iou(prev, select).argmax()
        picked = select[best:best + 1]
        self.track_cache[sub]['frame'].append(frame)
        self.track_cache[sub]['bbox'].append(picked)
        return picked
class MultiPerson(BaseYOLOv5):
    """Multi-person variant: keep every box whose size (sqrt of its area, a
    length scale in pixels) lies strictly inside (min_length, max_length)."""

    def __init__(self, min_length, max_length, **kwargs):
        super().__init__(**kwargs)
        self.min_length = min_length
        self.max_length = max_length
        print('[{}] Only keep the bbox in [{}, {}]'.format(self.__class__.__name__, min_length, max_length))

    def select_bbox(self, select, imgname):
        """Filter boxes by sqrt-area; empty input passes through unchanged."""
        if select.shape[0] == 0:
            return select
        side = np.sqrt((select[:, 2] - select[:, 0]) * (select[:, 3] - select[:, 1]))
        keep = (side > self.min_length) & (side < self.max_length)
        return select[keep]
class DetectToPelvis:
    """Turn per-view bounding boxes into 'pelvis' pseudo-keypoints: the box
    center plus the detection confidence (last column of each box row)."""

    def __init__(self, key) -> None:
        # Name of the kwargs entry holding the per-view list of (N, 5) boxes.
        self.key = key
        self.multiview = True

    def __call__(self, **kwargs):
        per_view_boxes = kwargs[self.key]
        pelvis = []
        for bbox in per_view_boxes:
            cx = (bbox[:, 0] + bbox[:, 2]) / 2
            cy = (bbox[:, 1] + bbox[:, 3]) / 2
            pelvis.append(np.stack([cx, cy, bbox[:, -1]], axis=-1))
        return {'pelvis': pelvis}
class Yolo_model:
    """Wrapper around a custom YOLOv5 checkpoint that (judging by the output
    keys) detects a body box (class 0) and left/right hand boxes (classes
    1/2), with per-image pickle caching of the results."""

    def __init__(self, mode, yolo_ckpt, multiview, repo_or_dir='ultralytics/yolov5', source='github') -> None:
        yolo_ckpt = check_modelpath(yolo_ckpt)
        self.model = torch.hub.load(repo_or_dir, 'custom', yolo_ckpt, source=source)
        self.min_detect_thres = 0.3   # drop detections below this confidence
        self.mode = mode              # 'fullimg' or 'bboxcrop'
        self.output = 'output'
        self.name = 'yolo'
        self.multiview = multiview

    @torch.no_grad()
    def det_step(self, img_or_name, imgname, bbox=None):
        """Detect on one image (optionally cropped to ``bbox`` first).

        Args:
            img_or_name: BGR image array or a filename readable by cv2.
            imgname: image path; used to build the cache filename.
            bbox: [x1, y1, x2, y2, ...] crop region, required when
                ``self.mode == 'bboxcrop'``. (Was a mutable default ``[]``
                that the method clamped in place — now left untouched.)

        Returns:
            dict with keys 'bbox', 'bbox_handl', 'bbox_handr', 'pelvis',
            'pelvis_l', 'pelvis_r'; 'bbox' is an (N, 5) array, hand/pelvis
            entries are lists padded with zero placeholders when empty.
        """
        if bbox is None:
            bbox = []
        basename = os.path.basename(imgname)
        if self.multiview:
            # Consistent with BaseYOLOv5.check_cache: view dir from os.sep split.
            nv = imgname.split(os.sep)[-2]
            cachename = join(self.output, self.name, nv, basename.replace('.jpg', '.pkl'))
        else:
            cachename = join(self.output, self.name, basename.replace('.jpg', '.pkl'))
        os.makedirs(os.path.dirname(cachename), exist_ok=True)
        if os.path.exists(cachename):
            with open(cachename, 'rb') as f:
                return pickle.load(f)
        images = cv2.imread(img_or_name) if isinstance(img_or_name, str) else img_or_name
        if self.mode == 'bboxcrop':
            # Clamp the crop origin to the image (local copies: do not mutate
            # the caller's bbox list, unlike the original in-place max()).
            x0 = max(0, bbox[0])
            y0 = max(0, bbox[1])
            crop = images[int(y0):int(bbox[3]), int(x0):int(bbox[2]), ::-1]  # BGR -> RGB
        else:
            crop = images[:, :, ::-1]  # BGR -> RGB
        results = self.model(crop)
        arrays = np.array(results.pandas().xyxy[0])
        bboxes = {
            'bbox': [],
            'bbox_handl': [],
            'bbox_handr': [],
            'pelvis': [],      # kept for interface compatibility; never filled here
            'pelvis_l': [],
            'pelvis_r': []
        }
        for res in arrays:
            classid = res[5]
            box = res[:5]
            if self.mode == 'bboxcrop':
                # Map crop-local coordinates back into the full image frame.
                box[0] += x0
                box[2] += x0
                box[1] += y0
                box[3] += y0
            if box[4] < self.min_detect_thres:
                continue
            if classid == 0:
                bboxes['bbox'].append(box)
            elif classid == 1:
                bboxes['bbox_handl'].append(box)
                bboxes['pelvis_l'].append([(box[0]+box[2])/2, (box[1]+box[3])/2, box[-1]])
            elif classid == 2:
                bboxes['bbox_handr'].append(box)
                bboxes['pelvis_r'].append([(box[0]+box[2])/2, (box[1]+box[3])/2, box[-1]])
        # Pad empty detections with zero placeholders so downstream code
        # always finds at least one entry per key.
        if len(bboxes['bbox_handl']) == 0:
            bboxes['bbox_handl'].append(np.zeros((5)))
            bboxes['pelvis_l'].append(np.zeros((3)))
        if len(bboxes['bbox_handr']) == 0:
            bboxes['bbox_handr'].append(np.zeros((5)))
            bboxes['pelvis_r'].append(np.zeros((3)))
        if len(bboxes['bbox']) == 0:
            bboxes['bbox'].append(np.zeros((5)))
        bboxes['bbox'] = np.array(bboxes['bbox'])
        if isinstance(imgname, str):
            with open(cachename, 'wb') as f:
                pickle.dump(bboxes, f)
        return bboxes

    def __call__(self, images, imgname, bbox=None):
        return self.det_step(images, imgname, bbox)
class Yolo_model_hand_mvmp(Yolo_model):
    """Multi-view multi-person variant: runs the detector per view, and — in
    'bboxcrop' mode — once per person crop, gathering hand boxes and the
    matching pseudo-pelvis points."""

    @torch.no_grad()
    def __call__(self, bbox, images, imgnames):
        collected = {
            'pelvis_l': [],
            'pelvis_r': [],
            'bbox_handl': [],
            'bbox_handr': [],
        }
        for view in range(len(images)):
            img = images[view]
            imgname = imgnames[view]
            if self.mode == 'bboxcrop':
                # One detector pass per person box; accumulate per-view lists.
                per_view = {key: [] for key in
                            ('bbox', 'bbox_handl', 'bbox_handr', 'pelvis_l', 'pelvis_r')}
                for pid in range(len(bbox[view])):
                    det = self.det_step(img, imgname, bbox[view][pid])
                    for key in per_view:
                        per_view[key].append(det[key])
            else:
                per_view = self.det_step(img, imgname)
            for key in collected:
                collected[key].append(np.array(per_view[key]))
        return collected