import numpy as np
import cv2
import torch.nn as nn
import torch
import time
import json
from ..model.base import augment_z_vals, concat

_time_ = 0

def tic():
    global _time_
    _time_ = time.time()

def toc(name):
    global _time_
    print('{:15s}: {:.1f}'.format(name, 1000*(time.time() - _time_)))
    _time_ = time.time()

def raw2acc(raw):
    # alpha compositing: weight each sample by its alpha times the
    # transmittance accumulated over the previous samples, then sum along the ray
    alpha = raw[..., -1]
    weights = alpha * torch.cumprod(
        torch.cat(
            [torch.ones((alpha.shape[0], 1)).to(alpha), 1. - alpha + 1e-10],
            -1), -1)[:, :-1]
    acc_map = torch.sum(weights, -1)
    return acc_map
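
# A minimal sketch (illustration only, not used by the renderer) of what
# raw2acc computes: with a constant alpha of 0.5 over three samples, the
# per-sample weights form the geometric sequence 0.5, 0.25, 0.125, and the
# accumulated opacity is their sum.
def _demo_raw2acc():
    raw = torch.zeros(1, 3, 4)               # (N_rays, N_samples, rgb + alpha)
    raw[..., -1] = torch.tensor([0.5, 0.5, 0.5])
    acc = raw2acc(raw)
    assert torch.allclose(acc, torch.tensor([0.875]), atol=1e-6)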
def raw2outputs(outputs, z_vals, rays_d, bkgd=None):
    """Transforms model's predictions to semantically meaningful values.

    Args:
        outputs: dict of model predictions; must contain 'occupancy' or
            'density' of shape [num_rays, num_samples, 1], plus any per-sample
            feature maps (e.g. 'rgb') of shape [num_rays, num_samples, N].
        z_vals: [num_rays, num_samples]. Integration time.
        rays_d: [num_rays, 3]. Direction of each ray.
        bkgd: optional background color, blended in with the residual alpha.

    Returns:
        dict with:
            acc_map: [num_rays]. Sum of weights along each ray.
            depth_map: [num_rays]. Estimated distance to object.
            <key>_map: [num_rays, N]. Weighted sum of each feature in outputs.
    """
    if 'occupancy' in outputs.keys():
        alpha = outputs['occupancy'][..., 0]
    elif 'density' in outputs.keys():
        dists = z_vals[..., 1:] - z_vals[..., :-1]
        dists = torch.cat(
            [dists, torch.Tensor([1e10]).expand(dists[..., :1].shape).to(dists)],
            -1)  # [N_rays, N_samples]
        dists = dists * torch.norm(rays_d, dim=-1)
        noise = 0.
        # alpha = raw2alpha(raw[..., -1] + noise, dists)  # [N_rays, N_samples]
        alpha = 1 - torch.exp(-dists*torch.relu(outputs['density'][..., 0] + noise))  # (N_rays, N_samples)
    else:
        raise NotImplementedError
    weights = alpha * torch.cumprod(
        torch.cat(
            [torch.ones((alpha.shape[0], 1)).to(alpha), 1. - alpha + 1e-10],
            -1), -1)[:, :-1]
    acc_map = torch.sum(weights, -1)
    # ATTN: the depth here must be divided by ||ray_d||
    depth_map = torch.sum(weights * z_vals, -1)/(1e-10 + acc_map)/torch.norm(rays_d, dim=-1).squeeze()
    results = {
        'acc_map': acc_map,
        'depth_map': depth_map,
    }
    for key, val in outputs.items():
        if key == 'occupancy':
            continue
        results[key+'_map'] = torch.sum(weights[..., None] * val, -2)  # [N_rays, 3]
    if bkgd is not None:
        results['rgb_map'] = results['rgb_map'] + bkgd[0] * (1 - acc_map[..., None])
    return results
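
# A minimal sketch of calling raw2outputs with an occupancy-style prediction;
# the shapes follow the docstring above (illustration only, not used by the
# renderer).
def _demo_raw2outputs():
    N_rays, N_samples = 2, 8
    outputs = {
        'occupancy': torch.rand(N_rays, N_samples, 1),
        'rgb': torch.rand(N_rays, N_samples, 3),
    }
    z_vals = torch.linspace(0.1, 1.0, N_samples).expand(N_rays, N_samples)
    rays_d = torch.ones(N_rays, 3)
    results = raw2outputs(outputs, z_vals, rays_d)
    # per-sample features are composited into per-ray maps
    assert results['rgb_map'].shape == (N_rays, 3)
    assert results['acc_map'].shape == (N_rays,)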
class BaseRenderer(nn.Module):
    def __init__(self, net, chunk, white_bkgd, use_occupancy, N_samples, split,
                 render_layer=False, return_raw=False, return_extra=False,
                 use_canonical=False):
        super().__init__()
        self.net = net
        self.chunk = chunk
        self.white_bkgd = white_bkgd
        self.use_occupancy = use_occupancy
        self.N_samples = N_samples
        self.split = split
        self.return_raw = return_raw
        self.return_extra = return_extra
        self.use_canonical = use_canonical
        self.render_layer = render_layer
        if use_canonical:
            self.net.use_canonical = use_canonical

    def forward_any(self, net, data, meta, bkgd):
        # given a network and its data, return the rendered outputs chunk by chunk
        raw, z_val_ = [], []
        ray_o = data['ray_o'][0].unsqueeze(1)
        ray_d = data['ray_d'][0].unsqueeze(1)
        # sample depth points
        z_steps = torch.linspace(0, 1, self.N_samples, device=ray_o.device).reshape(1, -1)
        for bn in range(0, ray_o.shape[0], self.chunk):
            start, end = bn, bn + self.chunk
            # first sample points
            near, far = [data[key][0, start:end][:, None] for key in ['near', 'far']]
            if False:  # disabled manual sampling path, kept for reference
                # z_vals: (nrays, N_samples)
                z_vals = near * (1 - z_steps) + far * z_steps
                z_vals = z_vals.unsqueeze(2)
                if self.split == 'train':
                    z_vals = augment_z_vals(z_vals)
                pts = ray_o[start:end] + ray_d[start:end] * z_vals
                viewdir = data['viewdirs'][0, start:end, None].expand(-1, pts.shape[1], -1)
                raw_output = net.calculate_density_color(pts, viewdir)
            else:
                z_vals, pts, raw_output = net.calculate_density_color_from_ray(
                    ray_o[start:end], ray_d[start:end], near, far, self.split)
            # directly render
            if bkgd.shape[1] != 1:
                bkgd_ = bkgd[:, start:end]
            else:
                bkgd_ = bkgd
            results = raw2outputs(
                raw_output, z_vals[..., 0], ray_d[start:end], bkgd_)
            raw.append(results)
        return raw

    def compose(self, retlist, mask=None, bkgd=None):
        res = {}
        for key in retlist[0].keys():
            val = torch.cat([r[key] for r in retlist])
            if mask is not None and val.shape[0] != mask.shape[0]:
                # scatter the valid rays back into a full-size tensor
                val_ = torch.zeros((mask.shape[0], *val.shape[1:]),
                                   device=val.device, dtype=val.dtype)
                if key == 'rgb_map':
                    # fill the masked-out rays with the background color
                    if bkgd is None:
                        import ipdb; ipdb.set_trace()
                    elif bkgd.shape[0] > 1:
                        val_[~mask] = bkgd[~mask]
                    else:
                        val_[~mask] = bkgd[0]
                val_[mask] = val
                val = val_.unsqueeze(0)
            else:
                val = val.unsqueeze(0)
            res[key] = val
        return res

    def forward_single(self, batch, bkgd):
        keys = [d[0] for d in batch['meta']['keys']]
        assert len(keys) == 1, 'Only support one key'
        key = keys[0]
        model = self.net.model(key)
        model.clear_cache()
        data = model.before(batch, key)
        # get the background
        bkgd_ = bkgd
        if bkgd is not None and bkgd.shape[0] > 1:
            bkgd_ = bkgd[data['mask'][0]][None]  # (1, nValid, 3)
        retlist = self.forward_any(model, data, batch['meta'], bkgd_)
        res = self.compose(retlist, data['mask'][0], bkgd)
        res['keys'] = keys
        return res

    def batch_forward(self, batch, viewdir, start, end, bkgd):
        ray_o = batch['ray_o'][0, start:end, None]
        ray_d = batch['ray_d'][0, start:end, None]
        viewdirs = batch['viewdirs'][0, start:end, None].expand(-1, 1, -1)
        keys_all = self.net.keys.copy()
        object_keys = [d[0] for d in batch['meta']['object_keys']]
        if len(object_keys) > len(keys_all) or True:  # always rebuild the key mapping
            mapkeys = {}
            operation = {}
            keys_all = object_keys
            for key in object_keys:
                mapkeys[key] = key.split('_@')[0]
                if '_@' in key:
                    params = json.loads(key.split('_@')[1].replace("'", '"'))
                    operation[key] = params
        else:
            mapkeys = {key: key for key in keys_all}
        # keys_all.sort(key=lambda x: 0 if x == 'back' else int(x.split('_@')[0].replace('human_', '')) + 1 if x.startswith('human') else 9999)
        # print('render keys: ', keys_all)
        ret_all = []
        dimGroups = [0]
        for key in object_keys:
            if '@' in key:
                model = self.net.model(mapkeys[key])
                model.current = key
            else:
                model = self.net.model(key)
                # set the key manually: in non-share mode it is not overwritten automatically
                model.current = key
            mask = batch[key + '_mask'][0]
            # map the chunk bounds [start, end) over the full ray array into the
            # packed array that only stores this object's valid rays
            start_ = mask[:start].sum()
            end_ = mask[:end].sum()
            near, far = [batch[key+'_'+nearfar][0, start_:end_][:, None]
                         for nearfar in ['near', 'far']]
            mask = mask[start:end]
            if mask.sum() < 1:
                # print('Skip {} [{}, {}]'.format(key, start, end))
                continue
            if False:  # disabled manual sampling path, kept for reference
                if key in self.net.N_samples.keys():
                    N_samples = self.net.N_samples[key]
                else:
                    N_samples = self.net.N_samples['default']
                dimGroups.append(dimGroups[-1] + N_samples)
                z_steps = torch.linspace(0, 1, N_samples, device=ray_d.device).reshape(1, -1)
                # z_vals: (nrays, N_samples)
                z_vals = near * (1 - z_steps) + far * z_steps
                z_vals = z_vals.unsqueeze(2)
                if self.split == 'train':
                    z_vals = augment_z_vals(z_vals)
                pts = ray_o[mask] + ray_d[mask] * z_vals
                viewdir = viewdirs[mask].expand(-1, pts.shape[1], -1)
                raw_output = model.calculate_density_color(pts, viewdir)
            else:
                z_vals, pts, raw_output = model.calculate_density_color_from_ray(
                    ray_o[mask], ray_d[mask], near, far, self.split)
                dimGroups.append(dimGroups[-1] + z_vals.shape[-2])
            if not self.use_occupancy:
                # set the density of the last point to zero
                raw_output['density'][:, -1] = 0.
            if '_@' in key:
                if 'scale_occ' in operation[key].keys():
                    raw_output['occupancy'] *= operation[key]['scale_occ']
            # TODO: remove bounds
            # if key.startswith('human') or key.startswith('ball'):
            #     notInBound = pts[..., -1] < 0.02
            #     raw_output['occupancy'][notInBound] = 0.
            raw_output['z_vals'] = z_vals[..., 0]
            # add a one-hot instance label per sample
            instance_ = torch.zeros((*pts.shape[:-1], len(keys_all)),
                                    dtype=pts.dtype, device=pts.device)
            instance_[..., keys_all.index(key)] = 1.
            raw_output['instance'] = instance_
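            # Each model only predicts values for its own masked rays; the
            # padding loop below scatters those predictions back to the full
            # chunk size so that outputs from different objects can later be
            # composited along a shared ray dimension.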
            raw_padding = {}
            for key_out, val in raw_output.items():
                if len(val.shape) == 1:
                    # for traj
                    raw_padding[key_out] = val
                    continue
                padding = torch.zeros([mask.shape[0], *val.shape[1:]],
                                      dtype=val.dtype, device=val.device)
                padding[mask] = val
                raw_padding[key_out] = padding
            ret_all.append(raw_padding)
            # toc(key)
            # if key.startswith('back') and (self.radius_max is not None or self.ranges is not None):
            #     if self.radius_max is not None:
            #         notInBound = torch.norm(wpts, dim=-1) > self.radius_max
            #     elif self.ranges is not None:
            #         bound_l = torch.FloatTensor(self.ranges[0]).to(wpts.device).reshape(1, 1, 3)
            #         bound_u = torch.FloatTensor(self.ranges[1]).to(wpts.device).reshape(1, 1, 3)
            #         self.bound_l = bound_l
            #         self.bound_u = bound_u
            #         notInBound = ((wpts < bound_l)|(wpts > bound_u)).any(dim=-1)
            #     mask_valid = mask & (~notInBound.all(dim=-1))
            #     # NOTE: this is only correct for the background; for other objects the data sizes differ
            #     wpts_valid = wpts[mask_valid]
            #     # print('[back] forward {} points of {}'.format(mask_valid.sum(), mask.sum()))
            #     z_val_valid = z_val[mask_valid]
            #     viewdir_valid = viewdir_[mask_valid]
            #     viewdir_valid = viewdir_valid.expand(-1, wpts.shape[1], -1)
            #     raw_ = torch.zeros([wpts.shape[0], wpts.shape[1], 4], device=wpts.device, dtype=wpts.dtype)
            #     if mask_valid.sum() > 0:
            #         raw_valid = model.calculate_density_color(wpts_valid, viewdir_valid)
            #         raw_[mask_valid] = raw_valid
            # else:
            #     viewdir_ = viewdir_.expand(-1, wpts.shape[1], -1)
            #     raw_ = model.calculate_density_color(wpts, viewdir_)
            # if key+'_scale_occ' in batch.keys():
            #     raw_[..., -1] *= batch[key+'_scale_occ']
            # if self.remove_fog:
            #     acc_ = raw2acc(raw_)
            #     raw_[acc_
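
# A minimal standalone sketch (not part of the renderer) of the chunk-bound
# remapping used in batch_forward: per-object near/far are stored only for the
# rays inside the object's mask, so chunk bounds [start, end) over the full
# ray array are mapped into the packed array via cumulative sums of the mask.
def _demo_chunk_mask_indexing():
    mask = torch.tensor([True, False, True, True, False, True])
    near_packed = torch.arange(int(mask.sum()))  # one entry per valid ray
    start, end = 2, 5                            # chunk over the full ray array
    start_, end_ = mask[:start].sum(), mask[:end].sum()
    # rays 2 and 3 are the valid rays inside this chunk
    assert near_packed[start_:end_].tolist() == [1, 2]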