# EasyMocap/easymocap/neuralbody/renderer/render_base.py
import numpy as np
import cv2
import torch.nn as nn
import torch
import time
import json
from ..model.base import augment_z_vals, concat

_time_ = 0

def tic():
    # reset the global timer
    global _time_
    _time_ = time.time()

def toc(name):
    # print the elapsed milliseconds since the last tic()/toc() call, then reset
    global _time_
    print('{:15s}: {:.1f}'.format(name, 1000 * (time.time() - _time_)))
    _time_ = time.time()

def raw2acc(raw):
    # raw: (N_rays, N_samples, C) with alpha in the last channel
    alpha = raw[..., -1]
    weights = alpha * torch.cumprod(
        torch.cat(
            [torch.ones((alpha.shape[0], 1)).to(alpha), 1. - alpha + 1e-10],
            -1), -1)[:, :-1]
    acc_map = torch.sum(weights, -1)
    return acc_map
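
# The compositing above follows the standard volume-rendering weights:
#   w_i = alpha_i * prod_{j<i} (1 - alpha_j + 1e-10)
#   acc = sum_i w_i          (total opacity along the ray)
# The 1e-10 term keeps the cumulative product numerically stable when some
# alpha_j is exactly 1.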

def raw2outputs(outputs, z_vals, rays_d, bkgd=None):
    """Transform the model's raw predictions into semantically meaningful maps.

    Args:
        outputs: dict of per-sample predictions, each [num_rays, num_samples, C];
            must contain either 'occupancy' or 'density'.
        z_vals: [num_rays, num_samples]. Integration time.
        rays_d: [num_rays, 3]. Direction of each ray.
        bkgd: optional background color blended behind the rendered content.
    Returns:
        dict with:
            acc_map: [num_rays]. Sum of weights along each ray.
            depth_map: [num_rays]. Estimated distance to the object.
            <key>_map: [num_rays, C] for every other key in outputs,
                e.g. rgb_map: [num_rays, 3], the estimated color of a ray.
    """
    if 'occupancy' in outputs.keys():
        alpha = outputs['occupancy'][..., 0]
    elif 'density' in outputs.keys():
        # convert density to alpha via the distances between adjacent samples
        dists = z_vals[..., 1:] - z_vals[..., :-1]
        dists = torch.cat(
            [dists,
             torch.Tensor([1e10]).expand(dists[..., :1].shape).to(dists)],
            -1)  # [N_rays, N_samples]
        dists = dists * torch.norm(rays_d, dim=-1)
        noise = 0.
        alpha = 1 - torch.exp(-dists * torch.relu(outputs['density'][..., 0] + noise))  # (N_rays, N_samples)
    else:
        raise NotImplementedError
    weights = alpha * torch.cumprod(
        torch.cat(
            [torch.ones((alpha.shape[0], 1)).to(alpha), 1. - alpha + 1e-10],
            -1), -1)[:, :-1]
    acc_map = torch.sum(weights, -1)
    # ATTN: the depth must be divided by ||ray_d||, since z_vals are measured
    # along the (possibly unnormalized) ray direction
    depth_map = torch.sum(weights * z_vals, -1) / (1e-10 + acc_map) / torch.norm(rays_d, dim=-1).squeeze()
    results = {
        'acc_map': acc_map,
        'depth_map': depth_map,
    }
    for key, val in outputs.items():
        if key == 'occupancy':
            continue
        results[key + '_map'] = torch.sum(weights[..., None] * val, -2)  # [N_rays, C]
    if bkgd is not None:
        results['rgb_map'] = results['rgb_map'] + bkgd[0] * (1 - acc_map[..., None])
    return results
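
# A minimal usage sketch of raw2outputs (illustrative shapes only, not taken
# from a real batch):
#   outputs = {'occupancy': torch.rand(8, 64, 1), 'rgb': torch.rand(8, 64, 3)}
#   z_vals = torch.linspace(2., 6., 64).expand(8, 64)
#   rays_d = torch.randn(8, 3)
#   maps = raw2outputs(outputs, z_vals, rays_d)
#   maps['rgb_map'].shape   # -> (8, 3)
#   maps['acc_map'].shape   # -> (8,)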

class BaseRenderer(nn.Module):
    def __init__(self, net, chunk, white_bkgd, use_occupancy, N_samples, split,
                 render_layer=False,
                 return_raw=False, return_extra=False, use_canonical=False):
        super().__init__()
        self.net = net
        self.chunk = chunk
        self.white_bkgd = white_bkgd
        self.use_occupancy = use_occupancy
        self.N_samples = N_samples
        self.split = split
        self.return_extra = return_extra
        self.use_canonical = use_canonical
        self.render_layer = render_layer
        if use_canonical:
            self.net.use_canonical = use_canonical

    def forward_any(self, net, data, meta, bkgd):
        # given a network and its data, render its rays chunk by chunk
        raw = []
        ray_o = data['ray_o'][0].unsqueeze(1)
        ray_d = data['ray_d'][0].unsqueeze(1)
        for bn in range(0, ray_o.shape[0], self.chunk):
            start, end = bn, bn + self.chunk
            # sample points along the rays of this chunk
            near, far = [data[key][0, start:end][:, None] for key in ['near', 'far']]
            z_vals, pts, raw_output = net.calculate_density_color_from_ray(
                ray_o[start:end], ray_d[start:end], near, far, self.split)
            # composite the chunk directly
            if bkgd is not None and bkgd.shape[1] != 1:
                bkgd_ = bkgd[:, start:end]
            else:
                bkgd_ = bkgd
            results = raw2outputs(
                raw_output, z_vals[..., 0], ray_d[start:end], bkgd_)
            raw.append(results)
        return raw

    def compose(self, retlist, mask=None, bkgd=None):
        # merge the per-chunk results; if a ray mask is given, scatter the
        # rendered values back to the full set of rays
        res = {}
        for key in retlist[0].keys():
            val = torch.cat([r[key] for r in retlist])
            if mask is not None and val.shape[0] != mask.shape[0]:
                val_ = torch.zeros((mask.shape[0], *val.shape[1:]), device=val.device, dtype=val.dtype)
                if key == 'rgb_map' and bkgd is not None:
                    # fill the rays outside the mask with the background color
                    if bkgd.shape[0] > 1:
                        val_[~mask] = bkgd[~mask]
                    else:
                        val_[~mask] = bkgd[0]
                val_[mask] = val
                val = val_.unsqueeze(0)
            else:
                val = val.unsqueeze(0)
            res[key] = val
        return res

    def forward_single(self, batch, bkgd):
        keys = [d[0] for d in batch['meta']['keys']]
        assert len(keys) == 1, 'Only support one key'
        key = keys[0]
        model = self.net.model(key)
        model.clear_cache()
        data = model.before(batch, key)
        # restrict the background to the valid rays
        bkgd_ = bkgd
        if bkgd is not None and bkgd.shape[0] > 1:
            bkgd_ = bkgd[data['mask'][0]][None]  # (1, nValid, 3)
        retlist = self.forward_any(model, data, batch['meta'], bkgd_)
        res = self.compose(retlist, data['mask'][0], bkgd)
        res['keys'] = keys
        return res
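
    # Sketch of the data flow above (the key name 'human_0' is illustrative,
    # following the 'human_*' naming used elsewhere in this file):
    #   key   = batch['meta']['keys'][0][0]      # e.g. 'human_0'
    #   model = self.net.model(key)              # look up the per-object network
    #   data  = model.before(batch, key)         # rays, near/far and a valid-ray mask
    #   forward_any(...)                         # chunked rendering of the masked rays
    #   compose(...)                             # scatter the results back to all rays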

    def batch_forward(self, batch, viewdir, start, end, bkgd):
        ray_o = batch['ray_o'][0, start:end, None]
        ray_d = batch['ray_d'][0, start:end, None]
        viewdirs = batch['viewdirs'][0, start:end, None].expand(-1, 1, -1)
        object_keys = [d[0] for d in batch['meta']['object_keys']]
        # map every instance key to its model key; keys of the form
        # '<model>_@{...}' carry per-instance parameters as a JSON dict
        mapkeys = {}
        operation = {}
        keys_all = object_keys
        for key in object_keys:
            mapkeys[key] = key.split('_@')[0]
            if '_@' in key:
                params = json.loads(key.split('_@')[1].replace("'", '"'))
                operation[key] = params
        ret_all = []
        dimGroups = [0]
        for key in object_keys:
            if '@' in key:
                model = self.net.model(mapkeys[key])
                model.current = key
            else:
                model = self.net.model(key)
                # set the key manually: in non-shared mode it is not
                # overwritten automatically
                model.current = key
            mask = batch[key + '_mask'][0]
            # near/far are stored compactly for the masked rays only, so count
            # how many masked rays fall before this chunk
            start_ = mask[:start].sum()
            end_ = mask[:end].sum()
            near, far = [batch[key + '_' + nearfar][0, start_:end_][:, None] for nearfar in ['near', 'far']]
            mask = mask[start:end]
            if mask.sum() < 1:
                # no ray of this chunk hits this object
                continue
            z_vals, pts, raw_output = model.calculate_density_color_from_ray(
                ray_o[mask], ray_d[mask], near, far, self.split)
            dimGroups.append(dimGroups[-1] + z_vals.shape[-2])
            if not self.use_occupancy:
                # set the density of the last point to zero
                raw_output['density'][:, -1] = 0.
            if '_@' in key:
                if 'scale_occ' in operation[key].keys():
                    raw_output['occupancy'] *= operation[key]['scale_occ']
            raw_output['z_vals'] = z_vals[..., 0]
            # attach a one-hot instance label to every sample of this object
            instance_ = torch.zeros((*pts.shape[:-1], len(keys_all)),
                                    dtype=pts.dtype, device=pts.device)
            instance_[..., keys_all.index(key)] = 1.
            raw_output['instance'] = instance_
            # scatter the masked rays back to the full chunk with zero padding
            raw_padding = {}
            for key_out, val in raw_output.items():
                if len(val.shape) == 1:  # for traj
                    raw_padding[key_out] = val
                    continue
                padding = torch.zeros([mask.shape[0], *val.shape[1:]], dtype=val.dtype, device=val.device)
                padding[mask] = val
                raw_padding[key_out] = padding
            ret_all.append(raw_padding)
        if len(ret_all) == 0:
            # no object is visible in this chunk: fill with blank samples so
            # that only the background is rendered
            occupancy = torch.zeros([ray_d.shape[0], 1, 1], device=ray_d.device)
            color = torch.zeros([ray_d.shape[0], 1, 3], device=ray_d.device)
            instance = torch.zeros([ray_d.shape[0], 1, len(object_keys)], device=ray_d.device)
            z_vals_blank = torch.zeros([ray_d.shape[0], 1], device=ray_d.device)
            blank_output = {'occupancy': occupancy, 'rgb': color, 'instance': instance,
                            'raw_alpha': occupancy}
            blank_output['raw_rgb'] = blank_output['rgb']
            ret = raw2outputs(blank_output, z_vals_blank, ray_d, bkgd)
            return ret
        # concatenate the samples of all objects along the ray dimension and
        # sort them by depth, so that one compositing pass handles all objects
        raw_concat = concat(ret_all, dim=1, unsqueeze=False)
        z_vals = raw_concat.pop('z_vals')
        z_vals_sorted, indices = torch.sort(z_vals, dim=-1)
        ind_0 = torch.zeros_like(indices, device=indices.device)
        ind_0 = ind_0 + torch.arange(0, indices.shape[0], device=indices.device).reshape(-1, 1)
        raw_sorted = {}
        for key, val in raw_concat.items():
            val_sorted = val[ind_0, indices]
            raw_sorted[key] = val_sorted
        ret = raw2outputs(raw_sorted, z_vals_sorted, ray_d, bkgd)
        if self.render_layer:
            # additionally render each object alone from its own samples; note
            # that this assumes no object was skipped above, otherwise
            # dimGroups and object_keys are misaligned
            for ikey, key in enumerate(object_keys):
                raw_key = {k: v[:, dimGroups[ikey]:dimGroups[ikey + 1]] for k, v in raw_concat.items()}
                layer = raw2outputs(raw_key, z_vals[:, dimGroups[ikey]:dimGroups[ikey + 1]], ray_d, bkgd)
                for k in ['acc_map', 'rgb_map']:
                    ret[key + '_' + k] = layer[k]
        return ret
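
    # Multi-object composition above works purely on sample order: every object
    # contributes its own samples (zero-padded to the full chunk, so rays that
    # miss the object add nothing), the z values of all objects are sorted
    # jointly, and a single raw2outputs call alpha-composites the interleaved
    # samples front to back. Objects therefore occlude each other correctly
    # without any explicit depth test.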

    def forward_multi(self, batch, bkgd):
        keys = [d[0] for d in batch['meta']['keys']]
        # prepare each model
        res_cache = {}
        for key in self.net.keys:
            model = self.net.model(key)
            model.clear_cache()
        for key in keys:
            if '@' in key:
                key0 = key.split('_@')[0]
                model = self.net.model(key0)
                model.current = key
            else:
                model = self.net.model(key)
            model.before(batch, key)
            if key in model.cache.keys():
                res_cache[key + '_cache'] = model.cache[key]
        viewdir = batch['viewdirs'][0].unsqueeze(1)
        retlist = []
        for bn in range(0, viewdir.shape[0], self.chunk):
            start, end = bn, min(bn + self.chunk, viewdir.shape[0])
            ret = self.batch_forward(batch, viewdir, start, end, bkgd)
            if ret is not None:
                retlist.append(ret)
        res = self.compose(retlist)
        # expose the cached per-model results alongside the rendered maps
        res.update(res_cache)
        res['keys'] = keys
        return res

    def forward(self, batch):
        keys = [d[0] for d in batch['meta']['keys']]
        rand_bkgd = None
        device = batch['rgb'].device
        if self.split == 'train':
            # composite onto a random background color during training
            rand_bkgd = torch.rand(3, device=device).reshape(1, 1, 3)
        elif self.white_bkgd:
            rand_bkgd = torch.ones(3, device=device).reshape(1, 1, 3)
        if len(keys) == 1:
            results = self.forward_single(batch, rand_bkgd)
        else:
            results = self.forward_multi(batch, rand_bkgd)
        # write the same background into the target: pixels with rgb < 0 mark
        # the masked-out region
        if self.split == 'train':
            idx = torch.nonzero(batch['rgb'][0, :, 0] < 0)
            if rand_bkgd is not None:
                batch['rgb'][0, idx] = rand_bkgd
            else:
                batch['rgb'][0, idx] = 0.
        return results
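
    # Compositing onto a fresh random background color each training step (and
    # writing the same color into the target for the masked-out pixels) keeps
    # the network from baking a fixed background into its color prediction:
    # only a correct acc_map can match the randomly changing target.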


class RendererWithBkgd(BaseRenderer):
    # composite the rendering on top of a background provided in the batch
    def forward(self, batch):
        keys = [d[0] for d in batch['meta']['keys']]
        device = batch['rgb'].device
        if 'background' not in batch.keys():
            bkgd = torch.zeros((1, 1, 3), device=device)
            if self.white_bkgd:
                bkgd += 1.
        else:
            bkgd = batch['background'][0]
        if len(keys) == 1:
            results = self.forward_single(batch, bkgd)
        else:
            results = self.forward_multi(batch, bkgd)
        return results


class BackgroundRenderer(BaseRenderer):
    # query a learned background model at the pixel coordinates of the rays
    def forward(self, batch):
        keys = [d[0] for d in batch['meta']['keys']]
        coord = batch['coord'][0]
        background = self.net.model('background')
        background.before(batch)
        bkgd = background(coord, batch['meta'])
        if len(keys) == 1:
            results = self.forward_single(batch, bkgd)
        else:
            results = self.forward_multi(batch, bkgd)
        return results


class MirrorDemoRenderer(BackgroundRenderer):
    def forward(self, _batch):
        # blend the renderings of the left and the (mirrored) right view into a
        # copy of the learned background image
        background = self.net.model('background').background
        background = background.detach().cpu().numpy().copy()
        accmap = np.zeros_like(background[:, :, 0])
        for mirror_key in ['left', 'right']:
            batch = _batch[mirror_key]
            batch['meta']['index'] = batch['meta']['index'] // 2
            H, W = int(batch['meta']['H'][0]), int(batch['meta']['W'][0])
            keys = [d[0] for d in batch['meta']['keys']]
            coord = batch['coord'][0].cpu().numpy()
            if mirror_key == 'right':
                # flip the x coordinate for the mirrored view
                coord[:, 1] = W - 1 - coord[:, 1]
            bkgd = torch.zeros((1, 1, 3), device=batch['coord'].device)
            results = self.forward_single(batch, bkgd)
            rgb_map = results['rgb_map'][0].detach().cpu().numpy()
            acc_map = np.clip(results['acc_map'][0].detach().cpu().numpy(), 0., 1.).reshape(-1, 1)
            accmap[coord[:, 0], coord[:, 1]] = acc_map[..., 0]
            # alpha-blend the rendering into the background at these pixels
            background[coord[:, 0], coord[:, 1]] = rgb_map * acc_map + background[coord[:, 0], coord[:, 1]] * (1 - acc_map)
        return {
            'rgb_map': background,
            'acc_map': accmap,
            'meta': batch['meta']}
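

if __name__ == '__main__':
    # Minimal smoke test of raw2outputs: a sketch with illustrative random
    # inputs (8 rays, 64 samples per ray), not data from a real batch.
    N_rays, N_samples = 8, 64
    outputs = {
        'occupancy': torch.rand(N_rays, N_samples, 1),
        'rgb': torch.rand(N_rays, N_samples, 3),
    }
    z_vals = torch.linspace(2., 6., N_samples).expand(N_rays, N_samples)
    rays_d = torch.randn(N_rays, 3)
    maps = raw2outputs(outputs, z_vals, rays_d, bkgd=torch.ones(1, 1, 3))
    for name, val in maps.items():
        print('{:10s}: {}'.format(name, tuple(val.shape)))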