EasyMocap/easymocap/dataset/base.py

609 lines
25 KiB
Python
Raw Normal View History

2021-04-14 15:22:51 +08:00
'''
@ Date: 2021-01-13 16:53:55
@ Author: Qing Shuai
@ LastEditors: Qing Shuai
@ LastEditTime: 2021-04-13 15:59:35
@ FilePath: /EasyMocap/easymocap/dataset/base.py
'''
import os
import json
from os.path import join
from glob import glob
import cv2
import os, sys
import numpy as np
from ..mytools.camera_utils import read_camera, get_fundamental_matrix, Undistort
from ..mytools import FileWriter, read_annot, getFileList
from ..mytools.reader import read_keypoints3d, read_json
# from ..mytools.writer import FileWriter
# from ..mytools.camera_utils import read_camera, undistort, write_camera, get_fundamental_matrix
# from ..mytools.vis_base import merge, plot_bbox, plot_keypoints
# from ..mytools.file_utils import read_json, save_json, read_annot, read_smpl, write_smpl, get_bbox_from_pose
# from ..mytools.file_utils import merge_params, select_nf, getFileList
def crop_image(img, annot, vis_2d=False, config={}, crop_square=True):
for det in annot:
bbox = det['bbox']
l, t, r, b = det['bbox'][:4]
if crop_square:
if b - t > r - l:
diff = (b - t) - (r - l)
l -= diff//2
r += diff//2
else:
diff = (r - l) - (b - t)
t -= diff//2
b += diff//2
l = max(0, int(l+0.5))
t = max(0, int(t+0.5))
r = min(img.shape[1], int(r+0.5))
b = min(img.shape[0], int(b+0.5))
det['bbox'][:4] = [l, t, r, b]
if vis_2d:
crop_img = img.copy()
plot_keypoints(crop_img, det['keypoints'], pid=det['id'],
config=config, use_limb_color=True, lw=2)
else:
crop_img = img
crop_img = crop_img[t:b, l:r, :]
if crop_square:
crop_img = cv2.resize(crop_img, (256, 256))
else:
crop_img = cv2.resize(crop_img, (128, 256))
det['crop'] = crop_img
det['img'] = img
return 0
class ImageFolder:
"""Dataset for image folders"""
def __init__(self, root, subs=[], out=None, image_root='images', annot_root='annots',
kpts_type='body15', config={}, no_img=False) -> None:
self.root = root
self.image_root = join(root, image_root)
self.annot_root = join(root, annot_root)
self.kpts_type = kpts_type
self.no_img = no_img
if len(subs) == 0:
self.imagelist = getFileList(self.image_root, '.jpg')
self.annotlist = getFileList(self.annot_root, '.json')
else:
self.imagelist, self.annotlist = [], []
for sub in subs:
images = sorted([join(sub, i) for i in os.listdir(join(self.image_root, sub))])
self.imagelist.extend(images)
annots = sorted([join(sub, i) for i in os.listdir(join(self.annot_root, sub))])
self.annotlist.extend(annots)
# output
assert out is not None
self.out = out
self.writer = FileWriter(self.out, config=config)
self.gtK, self.gtRT = False, False
def load_gt_cameras(self):
cameras = load_cameras(self.root)
gtCameras = []
for i, name in enumerate(self.annotlist):
cam = os.path.dirname(name)
gtcams = {key:cameras[cam][key].copy() for key in ['K', 'R', 'T', 'dist']}
gtCameras.append(gtcams)
self.gtCameras = gtCameras
def __len__(self) -> int:
return len(self.imagelist)
def __getitem__(self, index: int):
imgname = join(self.image_root, self.imagelist[index])
annname = join(self.annot_root, self.annotlist[index])
assert os.path.exists(imgname) and os.path.exists(annname), (imgname, annname)
assert os.path.basename(imgname).split('.')[0] == os.path.basename(annname).split('.')[0], '{}, {}'.format(imgname, annname)
if not self.no_img:
img = cv2.imread(imgname)
else:
img = None
annot = read_annot(annname, self.kpts_type)
return img, annot
def camera(self, index=0, annname=None):
if annname is None:
annname = join(self.annot_root, self.annotlist[index])
data = read_json(annname)
if 'K' not in data.keys():
height, width = data['height'], data['width']
# focal = 1.2*max(height, width) # as colmap
focal = 1.2*min(height, width) # as colmap
K = np.array([focal, 0., width/2, 0., focal, height/2, 0. ,0., 1.]).reshape(3, 3)
else:
K = np.array(data['K']).reshape(3, 3)
camera = {'K':K ,'R': np.eye(3), 'T': np.zeros((3, 1)), 'dist': np.zeros((1, 5))}
if self.gtK:
camera['K'] = self.gtCameras[index]['K']
if self.gtRT:
camera['R'] = self.gtCameras[index]['R']
camera['T'] = self.gtCameras[index]['T']
# camera['T'][2, 0] = 5. # guess to 5 meters
camera['RT'] = np.hstack((camera['R'], camera['T']))
camera['P'] = camera['K'] @ np.hstack((camera['R'], camera['T']))
return camera
def basename(self, nf):
return self.annotlist[nf].replace('.json', '')
def write_keypoints3d(self, results, nf):
outname = join(self.out, 'keypoints3d', '{}.json'.format(self.basename(nf)))
self.writer.write_keypoints3d(results, outname)
def write_smpl(self, results, nf):
outname = join(self.out, 'smpl', '{}.json'.format(self.basename(nf)))
self.writer.write_smpl(results, outname)
def vis_smpl(self, render_data, image, camera, nf):
outname = join(self.out, 'smpl', '{}.jpg'.format(self.basename(nf)))
images = [image]
for key in camera.keys():
camera[key] = camera[key][None, :, :]
self.writer.vis_smpl(render_data, images, camera, outname, add_back=True)
class VideoFolder(ImageFolder):
"一段视频的图片的文件夹"
def __init__(self, root, name, out=None,
image_root='images', annot_root='annots',
kpts_type='body15', config={}, no_img=False) -> None:
self.root = root
self.image_root = join(root, image_root, name)
self.annot_root = join(root, annot_root, name)
self.name = name
self.kpts_type = kpts_type
self.no_img = no_img
self.imagelist = sorted(os.listdir(self.image_root))
self.annotlist = sorted(os.listdir(self.annot_root))
self.ret_crop = False
def load_annot_all(self, path):
# 这个不使用personID只是单纯的罗列一下
assert os.path.exists(path), '{} not exists!'.format(path)
results = []
annnames = sorted(glob(join(path, '*.json')))
for annname in annnames:
datas = read_annot(annname, self.kpts_type)
if self.ret_crop:
# TODO:修改imgname
basename = os.path.basename(annname)
imgname = annname\
.replace('annots-cpn', 'images')\
.replace('annots', 'images')\
.replace('.json', '.jpg')
assert os.path.exists(imgname), imgname
img = cv2.imread(imgname)
crop_image(img, datas)
results.append(datas)
return results
def load_annot(self, path, pids=[]):
# 这个根据人的ID预先存一下
assert os.path.exists(path), '{} not exists!'.format(path)
results = {}
annnames = sorted(glob(join(path, '*.json')))
for annname in annnames:
nf = int(os.path.basename(annname).replace('.json', ''))
datas = read_annot(annname, self.kpts_type)
for data in datas:
pid = data['id']
if len(pids) > 0 and pid not in pids:
continue
# 注意 这里没有考虑从哪开始的
if pid not in results.keys():
results[pid] = {'bboxes': [], 'keypoints2d': []}
results[pid]['bboxes'].append(data['bbox'])
results[pid]['keypoints2d'].append(data['keypoints'])
for pid, val in results.items():
for key in val.keys():
val[key] = np.stack(val[key])
return results
def load_smpl(self, path, pids=[]):
""" load SMPL parameters from files
Args:
path (str): root path of smpl
pids (list, optional): used person ids. Defaults to [], loading all person.
"""
assert os.path.exists(path), '{} not exists!'.format(path)
results = {}
smplnames = sorted(glob(join(path, '*.json')))
for smplname in smplnames:
nf = int(os.path.basename(smplname).replace('.json', ''))
datas = read_smpl(smplname)
for data in datas:
pid = data['id']
if len(pids) > 0 and pid not in pids:
continue
# 注意 这里没有考虑从哪开始的
if pid not in results.keys():
results[pid] = {'body_params': [], 'frames': []}
results[pid]['body_params'].append(data)
results[pid]['frames'].append(nf)
for pid, val in results.items():
val['body_params'] = merge_params(val['body_params'])
return results
class _VideoBase:
"""Dataset for single sequence data
"""
def __init__(self, image_root, annot_root, out=None, config={}, kpts_type='body15', no_img=False) -> None:
self.image_root = image_root
self.annot_root = annot_root
self.kpts_type = kpts_type
self.no_img = no_img
self.config = config
assert out is not None
self.out = out
self.writer = FileWriter(self.out, config=config)
imgnames = sorted(os.listdir(self.image_root))
self.imagelist = imgnames
self.annotlist = sorted(os.listdir(self.annot_root))
self.nFrames = len(self.imagelist)
self.undis = False
self.read_camera()
def read_camera(self):
# 读入相机参数
annname = join(self.annot_root, self.annotlist[0])
data = read_json(annname)
if 'K' not in data.keys():
height, width = data['height'], data['width']
focal = 1.2*max(height, width)
K = np.array([focal, 0., width/2, 0., focal, height/2, 0. ,0., 1.]).reshape(3, 3)
else:
K = np.array(data['K']).reshape(3, 3)
self.camera = {'K':K ,'R': np.eye(3), 'T': np.zeros((3, 1))}
def __getitem__(self, index: int):
imgname = join(self.image_root, self.imagelist[index])
annname = join(self.annot_root, self.annotlist[index])
assert os.path.exists(imgname) and os.path.exists(annname)
assert os.path.basename(imgname).split('.')[0] == os.path.basename(annname).split('.')[0], '{}, {}'.format(imgname, annname)
if not self.no_img:
img = cv2.imread(imgname)
else:
img = None
annot = read_annot(annname, self.kpts_type)
return img, annot
def __len__(self) -> int:
return self.nFrames
def write_smpl(self, peopleDict, nf):
results = []
for pid, people in peopleDict.items():
result = {'id': pid}
result.update(people.body_params)
results.append(result)
self.writer.write_smpl(results, nf)
def vis_detections(self, image, detections, nf, to_img=True):
return self.writer.vis_detections([image], [detections], nf,
key='keypoints', to_img=to_img, vis_id=True)
def vis_repro(self, peopleDict, image, annots, nf):
# 可视化重投影的关键点与输入的关键点
detections = []
for pid, data in peopleDict.items():
keypoints3d = (data.keypoints3d @ self.camera['R'].T + self.camera['T'].T) @ self.camera['K'].T
keypoints3d[:, :2] /= keypoints3d[:, 2:]
keypoints3d = np.hstack([keypoints3d, data.keypoints3d[:, -1:]])
det = {
'id': pid,
'repro': keypoints3d
}
detections.append(det)
return self.writer.vis_detections([image], [detections], nf, key='repro',
to_img=True, vis_id=False)
def vis_smpl(self, peopleDict, faces, image, nf, sub_vis=[],
mode='smpl', extra_data=[], add_back=True,
axis=np.array([1., 0., 0.]), degree=0., fix_center=None):
# 为了统一接口,旋转视角的在此处实现,只在单视角的数据中使用
# 通过修改相机参数实现
# 相机参数的修正可以通过计算点的中心来获得
# render the smpl to each view
render_data = {}
for pid, data in peopleDict.items():
render_data[pid] = {
'vertices': data.vertices, 'faces': faces,
'vid': pid, 'name': 'human_{}_{}'.format(nf, pid)}
for iid, extra in enumerate(extra_data):
render_data[10000+iid] = {
'vertices': extra['vertices'],
'faces': extra['faces'],
'colors': extra['colors'],
'name': extra['name']
}
camera = {}
for key in self.camera.keys():
camera[key] = self.camera[key][None, :, :]
# render another view point
if np.abs(degree) > 1e-3:
vertices_all = np.vstack([data.vertices for data in peopleDict.values()])
if fix_center is None:
center = np.mean(vertices_all, axis=0, keepdims=True)
new_center = center.copy()
new_center[:, 0:2] = 0
else:
center = fix_center.copy()
new_center = fix_center.copy()
new_center[:, 2] *= 1.5
direc = np.array(axis)
rot, _ = cv2.Rodrigues(direc*degree/90*np.pi/2)
# If we rorate the data, it is like:
# V = Rnew @ (V0 - center) + new_center
# = Rnew @ V0 - Rnew @ center + new_center
# combine with the camera
# VV = Rc(Rnew @ V0 - Rnew @ center + new_center) + Tc
# = Rc@Rnew @ V0 + Rc @ (new_center - Rnew@center) + Tc
blank = np.zeros_like(image, dtype=np.uint8) + 255
images = [image, blank]
Rnew = camera['R'][0] @ rot
Tnew = camera['R'][0] @ (new_center.T - rot @ center.T) + camera['T'][0]
camera['K'] = np.vstack([camera['K'], camera['K']])
camera['R'] = np.vstack([camera['R'], Rnew[None, :, :]])
camera['T'] = np.vstack([camera['T'], Tnew[None, :, :]])
else:
images = [image]
self.writer.vis_smpl(render_data, nf, images, camera, mode, add_back=add_back)
def load_cameras(path):
# 读入相机参数
intri_name = join(path, 'intri.yml')
extri_name = join(path, 'extri.yml')
if os.path.exists(intri_name) and os.path.exists(extri_name):
cameras = read_camera(intri_name, extri_name)
cams = cameras.pop('basenames')
else:
print('\n\n!!!there is no camera parameters, maybe bug: \n', intri_name, extri_name, '\n')
cameras = None
return cameras
class MVBase:
""" Dataset for multiview data
"""
def __init__(self, root, cams=[], out=None, config={},
image_root='images', annot_root='annots',
kpts_type='body15',
undis=True, no_img=False) -> None:
self.root = root
self.image_root = join(root, image_root)
self.annot_root = join(root, annot_root)
self.kpts_type = kpts_type
self.undis = undis
self.no_img = no_img
# use when debug
self.ret_crop = False
self.config = config
# results path
# the results store keypoints3d
self.skel_path = None
self.out = out
self.writer = FileWriter(self.out, config=config)
self.cams = cams
self.imagelist = {}
self.annotlist = {}
for cam in cams: #TODO: 增加start,end
# ATTN: when image name's frame number is not continuous,
imgnames = sorted(os.listdir(join(self.image_root, cam)))
self.imagelist[cam] = imgnames
if os.path.exists(self.annot_root):
self.annotlist[cam] = sorted(os.listdir(join(self.annot_root, cam)))
self.has2d = True
else:
self.has2d = False
nFrames = min([len(val) for key, val in self.imagelist.items()])
self.nFrames = nFrames
self.nViews = len(cams)
self.read_camera(self.root)
def read_camera(self, path):
# 读入相机参数
intri_name = join(path, 'intri.yml')
extri_name = join(path, 'extri.yml')
if os.path.exists(intri_name) and os.path.exists(extri_name):
self.cameras = read_camera(intri_name, extri_name)
self.cameras.pop('basenames')
# 注意:这里的相机参数一定要用定义的,不然只用一部分相机的时候会出错
cams = self.cams
self.cameras_for_affinity = [[cam['invK'], cam['R'], cam['T']] for cam in [self.cameras[name] for name in cams]]
self.Pall = np.stack([self.cameras[cam]['P'] for cam in cams])
self.Fall = get_fundamental_matrix(self.cameras, cams)
else:
print('\n!!!\n!!!there is no camera parameters, maybe bug: \n', intri_name, extri_name, '\n')
self.cameras = None
def undistort(self, images):
if self.cameras is not None and len(images) > 0:
images_ = []
for nv in range(self.nViews):
mtx = self.cameras[self.cams[nv]]['K']
dist = self.cameras[self.cams[nv]]['dist']
if images[nv] is not None:
frame = cv2.undistort(images[nv], mtx, dist, None)
else:
frame = None
images_.append(frame)
else:
images_ = images
return images_
def undis_det(self, lDetections):
for nv in range(len(lDetections)):
camera = self.cameras[self.cams[nv]]
for det in lDetections[nv]:
det['bbox'] = Undistort.bbox(det['bbox'], K=camera['K'], dist=camera['dist'])
keypoints = det['keypoints']
det['keypoints'] = Undistort.points(keypoints=keypoints, K=camera['K'], dist=camera['dist'])
return lDetections
def select_person(self, annots_all, index, pid):
annots = {'bbox': [], 'keypoints': []}
for nv, cam in enumerate(self.cams):
data = [d for d in annots_all[nv] if d['id'] == pid]
if len(data) == 1:
data = data[0]
bbox = data['bbox']
keypoints = data['keypoints']
else:
if self.verbose:print('not found pid {} in frame {}, view {}'.format(self.pid, index, nv))
keypoints = np.zeros((self.config['nJoints'], 3))
bbox = np.array([0, 0, 100., 100., 0.])
annots['bbox'].append(bbox)
annots['keypoints'].append(keypoints)
for key in ['bbox', 'keypoints']:
annots[key] = np.stack(annots[key])
return annots
def __getitem__(self, index: int):
images, annots = [], []
for cam in self.cams:
imgname = join(self.image_root, cam, self.imagelist[cam][index])
assert os.path.exists(imgname), imgname
if self.has2d:
annname = join(self.annot_root, cam, self.annotlist[cam][index])
assert os.path.exists(annname), annname
assert self.imagelist[cam][index].split('.')[0] == self.annotlist[cam][index].split('.')[0]
annot = read_annot(annname, self.kpts_type)
else:
annot = []
if not self.no_img:
img = cv2.imread(imgname)
images.append(img)
else:
img = None
images.append(None)
# TODO:这里直接取了0
if self.ret_crop:
crop_image(img, annot, True, self.config)
annots.append(annot)
if self.undis:
images = self.undistort(images)
annots = self.undis_det(annots)
return images, annots
def __len__(self) -> int:
return self.nFrames
def vis_detections(self, images, lDetections, nf, mode='detec', to_img=True, sub_vis=[]):
outname = join(self.out, mode, '{:06d}.jpg'.format(nf))
if len(sub_vis) != 0:
valid_idx = [self.cams.index(i) for i in sub_vis]
images = [images[i] for i in valid_idx]
lDetections = [lDetections[i] for i in valid_idx]
return self.writer.vis_keypoints2d_mv(images, lDetections, outname=outname, vis_id=False)
def vis_match(self, images, lDetections, nf, to_img=True, sub_vis=[]):
if len(sub_vis) != 0:
valid_idx = [self.cams.index(i) for i in sub_vis]
images = [images[i] for i in valid_idx]
lDetections = [lDetections[i] for i in valid_idx]
return self.writer.vis_detections(images, lDetections, nf,
key='match', to_img=to_img, vis_id=True)
def basename(self, nf):
return '{:06d}'.format(nf)
def write_keypoints3d(self, results, nf):
outname = join(self.out, 'keypoints3d', self.basename(nf)+'.json')
self.writer.write_keypoints3d(results, outname)
def write_smpl(self, results, nf):
outname = join(self.out, 'smpl', self.basename(nf)+'.json')
self.writer.write_smpl(results, outname)
def vis_smpl(self, peopleDict, faces, images, nf, sub_vis=[],
mode='smpl', extra_data=[], extra_mesh=[],
add_back=True, camera_scale=1, cameras=None):
# render the smpl to each view
render_data = {}
for pid, data in peopleDict.items():
render_data[pid] = {
'vertices': data.vertices, 'faces': faces,
'vid': pid, 'name': 'human_{}_{}'.format(nf, pid)}
for iid, extra in enumerate(extra_data):
render_data[10000+iid] = {
'vertices': extra['vertices'],
'faces': extra['faces'],
'name': extra['name']
}
if 'colors' in extra.keys():
render_data[10000+iid]['colors'] = extra['colors']
elif 'vid' in extra.keys():
render_data[10000+iid]['vid'] = extra['vid']
if len(sub_vis) == 0:
sub_vis = self.cams
images = [images[self.cams.index(cam)] for cam in sub_vis]
if cameras is None:
cameras = {'K': [], 'R':[], 'T':[]}
for key in cameras.keys():
cameras[key] = [self.cameras[cam][key] for cam in sub_vis]
for key in cameras.keys():
cameras[key] = np.stack([self.cameras[cam][key] for cam in sub_vis])
# 根据camera_back参数控制相机向后退的距离
# 相机的光心的位置: -R.T @ T
if False:
R = cameras['R']
T = cameras['T']
cam_center = np.einsum('bij,bjk->bik', -R.transpose(0, 2, 1), T)
# 相机的朝向: R @ [0, 0, 1]
zdir = np.array([0., 0., 1.]).reshape(-1, 3, 1)
direction = np.einsum('bij,bjk->bik', R, zdir)
cam_center = cam_center - direction * 1
# 更新过后的相机的T: - R @ C
Tnew = - np.einsum('bij,bjk->bik', R, cam_center)
cameras['T'] = Tnew
else:
cameras['K'][:, 0, 0] /= camera_scale
cameras['K'][:, 1, 1] /= camera_scale
return self.writer.vis_smpl(render_data, nf, images, cameras, mode, add_back=add_back, extra_mesh=extra_mesh)
def read_skeleton(self, start, end):
keypoints3ds = []
for nf in range(start, end):
skelname = join(self.out, 'keypoints3d', '{:06d}.json'.format(nf))
skeletons = read_keypoints3d(skelname)
skeleton = [i for i in skeletons if i['id'] == self.pid]
assert len(skeleton) == 1, 'There must be only 1 keypoints3d, id = {} in {}'.format(self.pid, skelname)
keypoints3ds.append(skeleton[0]['keypoints3d'])
keypoints3ds = np.stack(keypoints3ds)
return keypoints3ds
def read_skel(self, nf, path=None, mode='none'):
if path is None:
path = self.skel_path
assert path is not None, 'please set the skeleton path'
if mode == 'a4d':
outname = join(path, '{}.txt'.format(nf))
assert os.path.exists(outname), outname
skels = readReasultsTxt(outname)
elif mode == 'none':
outname = join(path, '{:06d}.json'.format(nf))
assert os.path.exists(outname), outname
skels = readResultsJson(outname)
else:
import ipdb; ipdb.set_trace()
return skels
def read_smpl(self, nf, path=None):
if path is None:
path = self.skel_path
assert path is not None, 'please set the skeleton path'
outname = join(path, '{:06d}.json'.format(nf))
assert os.path.exists(outname), outname
datas = read_json(outname)
outputs = []
for data in datas:
for key in ['Rh', 'Th', 'poses', 'shapes']:
data[key] = np.array(data[key])
outputs.append(data)
return outputs