EasyMocap/code/dataset/base.py
2021-01-14 21:32:09 +08:00

216 lines
8.6 KiB
Python

'''
@ Date: 2021-01-13 16:53:55
@ Author: Qing Shuai
@ LastEditors: Qing Shuai
@ LastEditTime: 2021-01-14 19:55:58
@ FilePath: /EasyMocapRelease/code/dataset/base.py
'''
import os
import json
from os.path import join
from torch.utils.data.dataset import Dataset
import cv2
import os, sys
import numpy as np
# Put the parent folder of this file's directory on sys.path so the
# `mytools` imports below (and `visualize` imported later) can be resolved.
code_path = join(os.path.dirname(__file__), '..')
sys.path.append(code_path)
from mytools.camera_utils import read_camera, undistort, write_camera
from mytools.vis_base import merge, plot_bbox, plot_keypoints
def read_json(path):
    """Load and return the JSON document stored at ``path``."""
    with open(path) as fp:
        return json.load(fp)
def save_json(file, data):
    """Serialize ``data`` as JSON (indent=4) to ``file``, creating parent folders as needed.

    Args:
        file: destination path; may also be a bare filename (current directory).
        data: a JSON-serializable object.
    """
    dirname = os.path.dirname(file)
    # dirname is '' for a bare filename and os.makedirs('') would raise;
    # exist_ok=True also removes the check-then-create race of exists()+makedirs().
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    with open(file, 'w') as f:
        json.dump(data, f, indent=4)
def read_annot(annotname, add_hand_face=False):
    """Load the ``'annots'`` list of a JSON annotation file.

    Every entry has its ``'personID'`` field renamed to ``'id'``, and whichever of
    ``bbox``/``keypoints``/``handl2d``/``handr2d``/``face2d`` are present are
    converted to numpy arrays. ``add_hand_face`` is accepted but not used here.
    """
    annots = read_json(annotname)['annots']
    array_keys = ('bbox', 'keypoints', 'handl2d', 'handr2d', 'face2d')
    for annot in annots:
        annot['id'] = annot.pop('personID')
        for name in array_keys:
            if name in annot:
                annot[name] = np.array(annot[name])
    return annots
def get_bbox_from_pose(pose_2d, img, rate = 0.1):
    """Build a padded bounding box around the visible 2D keypoints.

    Args:
        pose_2d: array whose rows are ``(x, y, confidence)``; rows with
            confidence > 0 are treated as visible.
        img: image used by ``correct_bbox`` to flag boxes outside the frame.
        rate: padding added on each side, as a fraction of the box width/height.
    Returns:
        ``[x_min, y_min, x_max, y_max, valid]``. When no keypoint is visible a
        placeholder ``[0, 0, 100, 100, 0]`` is returned.
    """
    visible = pose_2d[:, 2] > 0
    if not visible.any():
        return [0, 0, 100, 100, 0]
    xs = pose_2d[visible, 0]
    ys = pose_2d[visible, 1]
    # Extremes are truncated to int before the padding is computed.
    x_min, x_max = int(min(xs)), int(max(xs))
    y_min, y_max = int(min(ys)), int(max(ys))
    pad_x = (x_max - x_min) * rate
    pad_y = (y_max - y_min) * rate
    # TODO: append the category (and similar fields) later.
    bbox = [x_min - pad_x, y_min - pad_y, x_max + pad_x, y_max + pad_y, 1]
    correct_bbox(img, bbox)
    return bbox
def correct_bbox(img, bbox):
    """Mark ``bbox`` invalid (``bbox[4] = 0``) when it lies entirely outside ``img``.

    The box is not clipped; only its validity flag is changed, in place.

    Args:
        img: image array; ``shape[0]`` is the height and ``shape[1]`` the width.
        bbox: mutable ``[x_min, y_min, x_max, y_max, valid]``.
    Returns:
        The same ``bbox`` object.
    """
    height = img.shape[0]
    width = img.shape[1]
    outside = bbox[2] <= 0 or bbox[0] >= width or bbox[1] >= height or bbox[3] <= 0
    if outside:
        bbox[4] = 0
    return bbox
class FileWriter:
    """Write per-frame results and visualizations below an output folder.

    Sub-folders ``keypoints3d``, ``smpl``, ``repro`` and ``keypoints`` are created
    under ``output_path`` when the writer is constructed.
    """
    def __init__(self, output_path, config=None, basenames=None, cfg=None) -> None:
        """
        Args:
            output_path: root output folder.
            config: passed as ``config`` to ``plot_keypoints`` in ``vis_detections``.
            basenames: stored as ``self.basenames``; defaults to a new empty list.
            cfg: if given, its ``str()`` is written to ``<output_path>/exp.yml``.
        """
        self.out = output_path
        keys = ['keypoints3d', 'smpl', 'repro', 'keypoints']
        output_dict = {key: join(self.out, key) for key in keys}
        for p in output_dict.values():
            os.makedirs(p, exist_ok=True)
        self.output_dict = output_dict
        # New list per instance: a mutable default list would be shared by all writers.
        self.basenames = [] if basenames is None else basenames
        if cfg is not None:
            # Context manager closes (and flushes) the file deterministically.
            with open(join(output_path, 'exp.yml'), 'w') as f:
                print(cfg, file=f)
        self.save_origin = False
        self.config = config

    def write_keypoints3d(self, results, nf):
        """Save ``results`` as JSON to ``<out>/keypoints3d/<nf:06d>.json``."""
        savename = join(self.output_dict['keypoints3d'], '{:06d}.json'.format(nf))
        save_json(savename, results)

    def vis_detections(self, images, lDetections, nf, key='keypoints', to_img=True, vis_id=True):
        """Draw the bbox and 2D keypoints of every detection, then merge all views.

        Args:
            images: one image per view (each is copied, not modified).
            lDetections: per view, a list of detection dicts holding ``'id'`` and ``key``
                and optionally ``'bbox'``.
            nf: frame index used in the output filename.
            key: detection entry holding the keypoints; also selects the output sub-folder.
            to_img: when True, write the merged image to ``<out>/<key>/<nf:06d>.jpg``.
            vis_id: forwarded to ``plot_bbox``.
        Returns:
            The merged visualization image.
        """
        images_vis = []
        for nv, image in enumerate(images):
            img = image.copy()
            for det in lDetections[nv]:
                keypoints = det[key]
                # Read the bbox instead of popping it, so the caller's detections stay
                # intact; derive one from the keypoints only when none is provided.
                if 'bbox' in det:
                    bbox = det['bbox']
                else:
                    bbox = get_bbox_from_pose(keypoints, img)
                plot_bbox(img, bbox, pid=det['id'], vis_id=vis_id)
                plot_keypoints(img, keypoints, pid=det['id'], config=self.config, use_limb_color=False, lw=2)
            images_vis.append(img)
        image_vis = merge(images_vis, resize=not self.save_origin)
        if to_img:
            savename = join(self.output_dict[key], '{:06d}.jpg'.format(nf))
            cv2.imwrite(savename, image_vis)
        return image_vis

    def write_smpl(self, results, nf):
        """Write one frame of SMPL parameters to ``<out>/smpl/<nf:06d>.json``.

        The file is a JSON array with one object per entry of ``results``, holding
        ``id``, ``Rh``, ``Th``, ``poses`` and ``shapes``. Arrays are written with
        ``np.array2string`` using 3 decimals.
        """
        format_out = {'float_kind': lambda x: "%.3f" % x}
        array_keys = ['Rh', 'Th', 'poses', 'shapes']
        filename = join(self.output_dict['smpl'], '{:06d}.json'.format(nf))
        entries = []
        for data in results:
            fields = [' "{}": {}'.format('id', data['id'])]
            for key in array_keys:
                value = np.array2string(data[key], max_line_width=1000, separator=', ', formatter=format_out)
                fields.append(' "{}": {}'.format(key, value))
            entries.append(' {\n' + ',\n'.join(fields) + '\n }')
        # Commas go only *between* fields and between entries: trailing commas
        # make the file invalid JSON that json.load rejects.
        with open(filename, 'w') as f:
            f.write('[\n')
            if entries:
                f.write(',\n'.join(entries) + '\n')
            f.write(']\n')

    def vis_smpl(self, render_data, nf, images, cameras):
        """Render ``render_data`` over ``images`` and save the merged result.

        Uses a 1024x1024 ``Renderer`` and writes ``<out>/smpl/<nf:06d>.jpg``.
        """
        from visualize.renderer import Renderer
        render = Renderer(height=1024, width=1024, faces=None)
        render_results = render.render(render_data, cameras, images)
        image_vis = merge(render_results, resize=not self.save_origin)
        savename = join(self.output_dict['smpl'], '{:06d}.jpg'.format(nf))
        cv2.imwrite(savename, image_vis)
class MVBase(Dataset):
    """ Dataset for multiview data

    Layout read from ``root``:
        <image_root>/<cam>/...   images of each camera
        <annot_root>/<cam>/...   annotation files of each camera (matching basenames)
        intri.yml, extri.yml     optional camera parameters
    ``dataset[i]`` returns ``(images, annots)`` with one entry per camera.
    """
    def __init__(self, root, cams=None, out=None, config=None,
                 image_root='images', annot_root='annots',
                 add_hand_face=True,
                 undis=True, no_img=False) -> None:
        """
        Args:
            root: dataset root folder.
            cams: camera sub-folder names; None/empty means every sub-folder of the
                image root, sorted.
            out: output folder for the ``FileWriter``; defaults to ``<root>/output``.
            config: forwarded to ``FileWriter``; defaults to a new ``{}``.
            image_root, annot_root: image / annotation folder names below ``root``.
            add_hand_face: forwarded to ``read_annot``.
            undis: undistort images and detections in ``__getitem__``.
            no_img: skip reading images (``images`` is then an empty list).
        Raises:
            ValueError: if no camera folder is available.
        """
        self.root = root
        self.image_root = join(root, image_root)
        self.annot_root = join(root, annot_root)
        self.add_hand_face = add_hand_face
        self.undis = undis
        self.no_img = no_img
        # None sentinel instead of a mutable default dict shared by every instance.
        if config is None:
            config = {}
        self.config = config
        if out is None:
            out = join(root, 'output')
        self.out = out
        self.writer = FileWriter(self.out, config=config)
        if cams is None or len(cams) == 0:
            cams = sorted(name for name in os.listdir(self.image_root)
                          if os.path.isdir(join(self.image_root, name)))
        if len(cams) == 0:
            raise ValueError('no camera folder found in {}'.format(self.image_root))
        self.cams = cams
        self.imagelist = {}
        self.annotlist = {}
        for cam in cams:  # TODO: support start/end frame ranges
            self.imagelist[cam] = sorted(os.listdir(join(self.image_root, cam)))
            self.annotlist[cam] = sorted(os.listdir(join(self.annot_root, cam)))
        # __getitem__ indexes both lists, so a frame is only usable when every
        # camera has both an image and an annotation at that index.
        self.nFrames = min(min(len(self.imagelist[cam]), len(self.annotlist[cam])) for cam in cams)
        self.nViews = len(cams)
        self.read_camera()

    def read_camera(self):
        """Load camera parameters from ``<root>/intri.yml`` and ``<root>/extri.yml``.

        On success sets ``self.cameras`` (keyed by camera name), plus
        ``self.cameras_for_affinity`` (``[invK, R, T]`` per camera) and ``self.Pall``
        (``P`` per camera), both ordered like ``self.cams``. If either file is
        missing, prints a warning and sets ``self.cameras`` to None.
        """
        path = self.root
        # Read the camera parameters.
        intri_name = join(path, 'intri.yml')
        extri_name = join(path, 'extri.yml')
        if os.path.exists(intri_name) and os.path.exists(extri_name):
            # Module-level read_camera helper (not this method).
            self.cameras = read_camera(intri_name, extri_name, self.cams)
            # presumably bookkeeping added by the helper rather than a camera — verify
            self.cameras.pop('basenames')
            self.cameras_for_affinity = [[cam['invK'], cam['R'], cam['T']] for cam in [self.cameras[name] for name in self.cams]]
            self.Pall = [self.cameras[cam]['P'] for cam in self.cams]
        else:
            print('!!!there is no camera parameters, maybe bug', intri_name, extri_name)
            self.cameras = None

    def undistort(self, images):
        """Undistort one image per view with ``cv2.undistort``.

        Returns ``images`` unchanged when no cameras are loaded or ``images`` is empty.
        """
        if self.cameras is not None and len(images) > 0:
            images_ = []
            for nv in range(self.nViews):
                mtx = self.cameras[self.cams[nv]]['K']
                dist = self.cameras[self.cams[nv]]['dist']
                frame = cv2.undistort(images[nv], mtx, dist, None)
                images_.append(frame)
        else:
            images_ = images
        return images_

    def undis_det(self, lDetections):
        """Undistort the bbox and keypoints of every detection, in place, per view.

        Like ``undistort``, returns the input unchanged when no camera parameters
        were loaded (previously this raised on ``None[...]``).
        """
        if self.cameras is None:
            return lDetections
        for nv, dets in enumerate(lDetections):
            camera = self.cameras[self.cams[nv]]
            for det in dets:
                # Module-level undistort helper (not the method of the same name).
                det['bbox'] = undistort(camera, bbox=det['bbox'])
                keypoints = det['keypoints']
                det['keypoints'] = undistort(camera, keypoints=keypoints[None, :, :])[1][0]
        return lDetections

    def __getitem__(self, index: int):
        """Return ``(images, annots)`` for frame ``index``, one entry per camera."""
        images, annots = [], []
        for cam in self.cams:
            imgname = join(self.image_root, cam, self.imagelist[cam][index])
            annname = join(self.annot_root, cam, self.annotlist[cam][index])
            assert os.path.exists(imgname), imgname
            assert os.path.exists(annname), annname
            # Image and annotation must share the basename (text before the first '.').
            assert self.imagelist[cam][index].split('.')[0] == self.annotlist[cam][index].split('.')[0], (imgname, annname)
            if not self.no_img:
                img = cv2.imread(imgname)
                images.append(img)
            # TODO: index 0 is taken directly here
            annot = read_annot(annname, self.add_hand_face)
            annots.append(annot)
        if self.undis:
            images = self.undistort(images)
            annots = self.undis_det(annots)
        return images, annots

    def __len__(self) -> int:
        """Number of frames available in every camera (see ``__init__``)."""
        return self.nFrames