216 lines
8.6 KiB
Python
216 lines
8.6 KiB
Python
'''
  @ Date: 2021-01-13 16:53:55
  @ Author: Qing Shuai
  @ LastEditors: Qing Shuai
  @ LastEditTime: 2021-01-14 19:55:58
  @ FilePath: /EasyMocapRelease/code/dataset/base.py
'''
|
|
import os
|
|
import json
|
|
from os.path import join
|
|
from torch.utils.data.dataset import Dataset
|
|
import cv2
|
|
import os, sys
|
|
import numpy as np
|
|
code_path = join(os.path.dirname(__file__), '..')
|
|
sys.path.append(code_path)
|
|
|
|
from mytools.camera_utils import read_camera, undistort, write_camera
|
|
from mytools.vis_base import merge, plot_bbox, plot_keypoints
|
|
|
|
def read_json(path):
    """Load and return the JSON document stored at ``path``.

    The file is decoded as UTF-8 (the interchange encoding for JSON) instead
    of the platform's locale encoding, so files containing non-ASCII text
    load the same way on every system.

    Args:
        path: path of the JSON file to read.

    Returns:
        The parsed JSON value (dict, list, ...).
    """
    with open(path, encoding='utf-8') as f:
        return json.load(f)
|
|
|
|
def save_json(file, data):
    """Serialise ``data`` as indented JSON into ``file``.

    Missing parent directories are created first.

    Args:
        file: destination path; may be a bare filename in the current
            working directory.
        data: any value accepted by ``json.dump``.
    """
    out_dir = os.path.dirname(file)
    # dirname is '' for a bare filename; os.makedirs('') would raise, so only
    # create a directory when there is one to create.  exist_ok avoids a race
    # when another process creates the directory at the same time.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(file, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)
|
|
|
|
|
|
def read_annot(annotname, add_hand_face=False):
    """Read the per-person annotations of one frame.

    The ``'annots'`` entry of the JSON file is returned after two in-place
    adjustments on every record: the ``'personID'`` field is renamed to
    ``'id'``, and the fields ``bbox``, ``keypoints``, ``handl2d``,
    ``handr2d`` and ``face2d`` (when present) are converted to numpy arrays.

    Args:
        annotname: path of the annotation JSON file.
        add_hand_face: accepted for interface compatibility; not used by
            this function.

    Returns:
        The list stored under ``'annots'``, modified as described above.
    """
    array_fields = ('bbox', 'keypoints', 'handl2d', 'handr2d', 'face2d')
    annots = read_json(annotname)['annots']
    for person in annots:
        person['id'] = person.pop('personID')
        for field in array_fields:
            if field in person:
                person[field] = np.array(person[field])
    return annots
|
|
|
|
def get_bbox_from_pose(pose_2d, img, rate=0.1):
    """Derive a padded bounding box from a 2D pose.

    Only rows of ``pose_2d`` whose third column is positive are used.  The
    tight box around those points is enlarged on every side by ``rate`` times
    its width/height, then passed through ``correct_bbox`` together with
    ``img``.

    Args:
        pose_2d: array indexed as ``pose_2d[:, 0]`` (x), ``pose_2d[:, 1]``
            (y) and ``pose_2d[:, 2]`` (value that must be > 0 for the point
            to count).
        img: image forwarded to ``correct_bbox``.
        rate: relative padding applied to each side of the box.

    Returns:
        ``[x_min, y_min, x_max, y_max, flag]``; when no point is usable the
        placeholder ``[0, 0, 100, 100, 0]`` is returned instead.
    """
    usable = pose_2d[:, 2] > 0
    if not usable.any():
        return [0, 0, 100, 100, 0]

    xs = pose_2d[usable, 0]
    ys = pose_2d[usable, 1]
    left, right = int(min(xs)), int(max(xs))
    top, bottom = int(min(ys)), int(max(ys))

    pad_x = (right - left) * rate
    pad_y = (bottom - top) * rate
    # TODO (original author): append class/category information later.
    bbox = [left - pad_x, top - pad_y, right + pad_x, bottom + pad_y, 1]
    correct_bbox(img, bbox)
    return bbox
|
|
|
|
def correct_bbox(img, bbox):
    """Flag a bounding box that lies completely outside the image.

    ``bbox`` is ``[x_min, y_min, x_max, y_max, flag]``.  When the box does
    not overlap the image at all, ``bbox[4]`` is set to 0 in place; the
    coordinates themselves are never changed.

    Args:
        img: object whose ``shape[0]`` / ``shape[1]`` give the number of
            rows / columns of the image.
        bbox: mutable sequence of at least five entries.

    Returns:
        The same ``bbox`` object that was passed in.
    """
    n_rows = img.shape[0]
    n_cols = img.shape[1]
    entirely_outside = (
        bbox[2] <= 0 or bbox[0] >= n_cols or bbox[1] >= n_rows or bbox[3] <= 0
    )
    if entirely_outside:
        bbox[4] = 0
    return bbox
|
|
class FileWriter:
    """Write results and visualisations below a single output directory.

    On construction the sub-directories ``keypoints3d``, ``smpl``, ``repro``
    and ``keypoints`` are created under ``output_path``; their paths are kept
    in ``self.output_dict``.
    """

    def __init__(self, output_path, config=None, basenames=None, cfg=None) -> None:
        """Create the output folders and remember the drawing configuration.

        Args:
            output_path: root directory for everything this writer produces.
            config: forwarded to ``plot_keypoints`` when drawing detections.
            basenames: stored as ``self.basenames``; defaults to a fresh
                empty list.
            cfg: when not ``None``, its ``print`` representation is written
                to ``<output_path>/exp.yml``.
        """
        self.out = output_path
        keys = ['keypoints3d', 'smpl', 'repro', 'keypoints']
        output_dict = {key: join(self.out, key) for key in keys}
        for p in output_dict.values():
            os.makedirs(p, exist_ok=True)
        self.output_dict = output_dict

        # A fresh list per instance: a shared ``[]`` default would leak state
        # between writers.
        self.basenames = [] if basenames is None else basenames
        if cfg is not None:
            # Use a context manager so the file is flushed and closed.
            with open(join(output_path, 'exp.yml'), 'w') as f:
                print(cfg, file=f)
        self.save_origin = False
        self.config = config

    def write_keypoints3d(self, results, nf):
        """Save ``results`` as ``keypoints3d/<nf:06d>.json`` via ``save_json``."""
        savename = join(self.output_dict['keypoints3d'], '{:06d}.json'.format(nf))
        save_json(savename, results)

    def vis_detections(self, images, lDetections, nf, key='keypoints', to_img=True, vis_id=True):
        """Draw the detections of every view and merge the views into one image.

        Args:
            images: one image per view; each is copied before drawing.
            lDetections: per view, a list of detection dicts providing
                ``det[key]`` and ``det['id']`` and optionally ``det['bbox']``.
            nf: frame number used for the output file name.
            key: detection field holding the keypoints to draw; also selects
                the output sub-directory, so it must be a key of
                ``self.output_dict`` when ``to_img`` is true.
            to_img: when true, the merged image is written as
                ``<key>/<nf:06d>.jpg``.
            vis_id: forwarded to ``plot_bbox``.

        Returns:
            The merged visualisation returned by ``merge``.
        """
        images_vis = []
        for nv, image in enumerate(images):
            img = image.copy()
            for det in lDetections[nv]:
                keypoints = det[key]
                # Read the stored box without removing it from the caller's
                # detection, and only derive one from the pose when needed.
                if 'bbox' in det:
                    bbox = det['bbox']
                else:
                    bbox = get_bbox_from_pose(keypoints, img)
                plot_bbox(img, bbox, pid=det['id'], vis_id=vis_id)
                plot_keypoints(img, keypoints, pid=det['id'], config=self.config, use_limb_color=False, lw=2)
            images_vis.append(img)
        image_vis = merge(images_vis, resize=not self.save_origin)
        if to_img:
            savename = join(self.output_dict[key], '{:06d}.jpg'.format(nf))
            cv2.imwrite(savename, image_vis)
        return image_vis

    def write_smpl(self, results, nf):
        """Write body-model parameters as ``smpl/<nf:06d>.json``.

        The file is a JSON list with one object per entry of ``results``.
        No trailing commas are emitted, so the file can be read back with
        ``json.load``.

        Args:
            results: iterable of dicts with the keys ``'id'``, ``'Rh'``,
                ``'Th'``, ``'poses'`` and ``'shapes'``; the last four are
                formatted with ``np.array2string`` (floats with three
                decimals).
            nf: frame number used for the output file name.
        """
        format_out = {'float_kind': lambda x: "%.3f" % x}
        array_keys = ['Rh', 'Th', 'poses', 'shapes']
        filename = join(self.output_dict['smpl'], '{:06d}.json'.format(nf))

        entries = []
        for data in results:
            fields = [('id', data['id'])]
            for key in array_keys:
                text = np.array2string(data[key], max_line_width=1000,
                                       separator=', ', formatter=format_out)
                fields.append((key, text))
            body = ',\n'.join('        "{}": {}'.format(key, val) for key, val in fields)
            entries.append('    {\n' + body + '\n    }')

        with open(filename, 'w') as f:
            f.write('[\n')
            if entries:
                f.write(',\n'.join(entries) + '\n')
            f.write(']\n')

    def vis_smpl(self, render_data, nf, images, cameras):
        """Render ``render_data`` onto ``images`` and save ``smpl/<nf:06d>.jpg``.

        The project renderer is imported lazily, so it is only required when
        this method is actually called.
        """
        from visualize.renderer import Renderer
        render = Renderer(height=1024, width=1024, faces=None)
        render_results = render.render(render_data, cameras, images)
        image_vis = merge(render_results, resize=not self.save_origin)
        savename = join(self.output_dict['smpl'], '{:06d}.jpg'.format(nf))
        cv2.imwrite(savename, image_vis)
|
|
|
|
class MVBase(Dataset):
    """Dataset for multi-view data.

    Images are read from ``<root>/<image_root>/<cam>/`` and annotations from
    ``<root>/<annot_root>/<cam>/``.  Item ``i`` pairs the i-th sorted image
    with the i-th sorted annotation file of every camera.
    """

    def __init__(self, root, cams=None, out=None, config=None,
                 image_root='images', annot_root='annots',
                 add_hand_face=True,
                 undis=True, no_img=False) -> None:
        """Index the image/annotation files and load the camera parameters.

        Args:
            root: dataset root directory.
            cams: camera (sub-directory) names; when empty or ``None`` every
                directory below the image root is used, sorted by name.
            out: output directory; defaults to ``<root>/output``.
            config: stored as ``self.config`` and passed to the
                ``FileWriter``; defaults to a fresh empty dict.
            image_root: name of the image folder below ``root``.
            annot_root: name of the annotation folder below ``root``.
            add_hand_face: forwarded to ``read_annot``.
            undis: when true, ``__getitem__`` undistorts images and
                detections.
            no_img: when true, ``__getitem__`` does not load images.
        """
        self.root = root
        self.image_root = join(root, image_root)
        self.annot_root = join(root, annot_root)
        self.add_hand_face = add_hand_face
        self.undis = undis
        self.no_img = no_img
        # Fresh dict per instance instead of a shared mutable default.
        self.config = {} if config is None else config

        if out is None:
            out = join(root, 'output')
        self.out = out
        self.writer = FileWriter(self.out, config=self.config)

        if cams is None or len(cams) == 0:
            cams = sorted([i for i in os.listdir(self.image_root) if os.path.isdir(join(self.image_root, i))])
        self.cams = cams
        self.imagelist = {}
        self.annotlist = {}
        for cam in cams:  # TODO: add start/end frame selection
            imgnames = sorted(os.listdir(join(self.image_root, cam)))
            self.imagelist[cam] = imgnames
            self.annotlist[cam] = sorted(os.listdir(join(self.annot_root, cam)))
        # The dataset length is the shortest image list over all cameras.
        nFrames = min([len(val) for val in self.imagelist.values()])
        self.nFrames = nFrames
        self.nViews = len(cams)
        self.read_camera()

    def read_camera(self):
        """Load ``intri.yml`` / ``extri.yml`` from the dataset root.

        When both files exist, ``self.cameras``, ``self.cameras_for_affinity``
        and ``self.Pall`` are set from them.  Otherwise a warning is printed
        and ``self.cameras`` is set to ``None``.
        """
        path = self.root
        # Read the camera parameters.
        intri_name = join(path, 'intri.yml')
        extri_name = join(path, 'extri.yml')
        if os.path.exists(intri_name) and os.path.exists(extri_name):
            self.cameras = read_camera(intri_name, extri_name, self.cams)
            self.cameras.pop('basenames')
            self.cameras_for_affinity = [[cam['invK'], cam['R'], cam['T']] for cam in [self.cameras[name] for name in self.cams]]
            self.Pall = [self.cameras[cam]['P'] for cam in self.cams]
        else:
            print('!!!there is no camera parameters, maybe bug', intri_name, extri_name)
            self.cameras = None

    def undistort(self, images):
        """Undistort one image per view with that view's ``K`` and ``dist``.

        ``images`` is returned unchanged when no cameras are loaded or the
        list is empty.
        """
        if self.cameras is not None and len(images) > 0:
            images_ = []
            for nv in range(self.nViews):
                mtx = self.cameras[self.cams[nv]]['K']
                dist = self.cameras[self.cams[nv]]['dist']
                frame = cv2.undistort(images[nv], mtx, dist, None)
                images_.append(frame)
        else:
            images_ = images
        return images_

    def undis_det(self, lDetections):
        """Undistort the ``bbox`` and ``keypoints`` of every detection in place.

        Args:
            lDetections: per view, a list of detection dicts with ``'bbox'``
                and ``'keypoints'`` entries.

        Returns:
            ``lDetections`` itself; unchanged when no cameras are loaded.
        """
        # Same guard as ``undistort``: without camera parameters there is
        # nothing to undistort with, so leave the detections untouched
        # instead of failing on ``None[...]``.
        if self.cameras is None:
            return lDetections
        for nv in range(len(lDetections)):
            camera = self.cameras[self.cams[nv]]
            for det in lDetections[nv]:
                det['bbox'] = undistort(camera, bbox=det['bbox'])
                keypoints = det['keypoints']
                det['keypoints'] = undistort(camera, keypoints=keypoints[None, :, :])[1][0]
        return lDetections

    def __getitem__(self, index: int):
        """Return ``(images, annots)`` for frame ``index``, one entry per camera.

        ``images`` is empty when ``no_img`` is set.  When ``undis`` is set,
        both images and detections are undistorted before being returned.
        """
        images, annots = [], []
        for cam in self.cams:
            imgname = join(self.image_root, cam, self.imagelist[cam][index])
            annname = join(self.annot_root, cam, self.annotlist[cam][index])
            assert os.path.exists(imgname), imgname
            assert os.path.exists(annname), annname
            # Image and annotation must refer to the same frame.
            assert self.imagelist[cam][index].split('.')[0] == self.annotlist[cam][index].split('.')[0]
            if not self.no_img:
                img = cv2.imread(imgname)
                images.append(img)
            # TODO (original author): "index 0 is taken directly here" -
            # intent not visible in this file.
            annot = read_annot(annname, self.add_hand_face)
            annots.append(annot)
        if self.undis:
            images = self.undistort(images)
            annots = self.undis_det(annots)
        return images, annots

    def __len__(self) -> int:
        """Number of frames available in every camera's image list."""
        return self.nFrames