EasyMocap/easymocap/neuralbody/model/base.py
2022-10-25 20:06:04 +08:00

207 lines
7.6 KiB
Python

'''
@ Date: 2021-09-03 16:56:14
@ Author: Qing Shuai
@ LastEditors: Qing Shuai
@ LastEditTime: 2021-09-05 21:17:25
@ FilePath: /EasyMocap/easymocap/neuralbody/model/base.py
'''
import torch
import torch.nn as nn
from torch import searchsorted
def augment_z_vals(z_vals, perturb=1):
# get intervals between samples
mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
upper = torch.cat([mids, z_vals[..., -1:]], -1)
lower = torch.cat([z_vals[..., :1], mids], -1)
# stratified samples in those intervals
perturb_rand = perturb * torch.rand(z_vals.shape, device=z_vals.device)
z_vals = lower + (upper - lower) * perturb_rand
return z_vals
# Hierarchical sampling (section 5.2)
def sample_pdf(bins, weights, N_samples, det=False):
# Get pdf
weights = weights + 1e-5 # prevent nans
pdf = weights / torch.sum(weights, -1, keepdim=True)
cdf = torch.cumsum(pdf, -1)
cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1) # (batch, len(bins))
# Take uniform samples
if det:
u = torch.linspace(0., 1., steps=N_samples)
u = u.expand(list(cdf.shape[:-1]) + [N_samples])
else:
u = torch.rand(list(cdf.shape[:-1]) + [N_samples])
# Invert CDF
u = u.contiguous().to(cdf.device)
inds = searchsorted(cdf, u, right=True)
below = torch.max(torch.zeros_like(inds-1), inds-1)
above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
# cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
# bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
denom = (cdf_g[...,1]-cdf_g[...,0])
denom = torch.where(denom<1e-5, torch.ones_like(denom), denom)
t = (u-cdf_g[...,0])/denom
samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])
if False:
import matplotlib as mpl
mpl.use('TkAgg')
import matplotlib.pyplot as plt
pdf0 = pdf[0:1].detach().cpu().numpy()
cdf0 = cdf[0:1].detach().cpu().numpy()
plt.figure(1)
plt.title('pdf')
plt.plot(bins[0].detach().cpu().numpy()[:-1], pdf0.T)
plt.vlines(samples[0].detach().cpu().numpy(), ymin=0, ymax=1, colors='r')
plt.figure(2)
plt.title('cdf')
plt.plot(cdf0.T)
plt.show()
import ipdb;ipdb.set_trace()
return samples
def get_near_far(ray_o, ray_d, bounds):
""" get near and far
Args:
ray_o (np):
ray_d ([type]): [description]
bounds ([type]): [description]
Returns:
near, far, mask_at_box
这里的near是实际物理空间中的深度
"""
norm_d = torch.norm(ray_d, dim=-1, keepdim=True)
viewdir = ray_d/norm_d
viewdir[(viewdir<1e-10)&(viewdir>-1e-10)] = 1e-10
viewdir[(viewdir>-1e-10)&(viewdir<1e-10)] = -1e-10
inv_dir = 1.0/viewdir
tmin = (bounds[:1] - ray_o[:1])*inv_dir
tmax = (bounds[1:2] - ray_o[:1])*inv_dir
# 限定时间是增加的
t1 = torch.minimum(tmin, tmax)
t2 = torch.maximum(tmin, tmax)
near, _ = torch.max(t1, dim=-1)
far, _ = torch.min(t2, dim=-1)
mask_at_box = near < far
# pts = ray_o[mask_at_box] + far[mask_at_box, None]/norm_d[mask_at_box] * ray_d[mask_at_box]
# if (pts.max(0)[0] > bounds[1]).any():
# print(pts.max(0)[0] - bounds[1])
# import ipdb;ipdb.set_trace()
return near, far, mask_at_box
def get_near_far_RTBBox(ray_o, ray_d, bounds, R, T):
# sample the near far in canonical coordinate
ray_o_rt = (ray_o - T) @ R
ray_d_rt = ray_d @ R
near, far, mask_at_box = get_near_far(ray_o_rt, ray_d_rt, bounds)
return near, far, mask_at_box
def concat(retlist, dim=0, unsqueeze=True, mask=None):
res = {}
if len(retlist) == 0:
return res
for key in retlist[0].keys():
val = torch.cat([r[key] for r in retlist], dim=dim)
if mask is not None and val.shape[0] != mask.shape[0]:
val_ = torch.zeros((mask.shape[0], *val.shape[1:]), device=val.device, dtype=val.dtype)
val_[mask] = val
val = val_
if unsqueeze:
val = val.unsqueeze(0)
res[key] = val
return res
class Base(nn.Module):
def __init__(self, sample_args) -> None:
super().__init__()
self.cache = {}
self.N_samples = sample_args.N_samples
self.sample_method = sample_args.method
self.sample_args = sample_args
z_steps = torch.linspace(0, 1, self.N_samples).reshape(1, -1)
self.register_buffer('z_steps', z_steps)
def model(self, key):
self.current = key
return self
def clear_cache(self):
self.cache = {}
def before(self, batch, name):
""" The operation before each step
- neuralbody: encode sparse voxel
- aninerf: encode blending weight
"""
datas = {key.replace(name+'_', ''):val for key,val in batch.items() if key.startswith(name)}
return datas
def calculate_density(self, wpts):
raise NotImplementedError
@staticmethod
def sample(near, far, ray_o, ray_d, z_steps, split):
# This function provides a uniform sample strategy
# Override this function to implement more sampling, e.g. importance_sample, unisurf
z_vals = near * (1 - z_steps) + far * z_steps
z_vals = z_vals.unsqueeze(2)
if split == 'train':
z_vals = augment_z_vals(z_vals)
pts = ray_o + ray_d * z_vals
return pts, z_vals
def sample_pdf(self, near, far, ray_o, ray_d, split):
pts, z_vals = self.sample(near, far, ray_o, ray_d, self.z_steps, split)
# forward
viewdirs = ray_d / torch.norm(ray_d, keepdim=True, dim=-1)
viewdirs_ = viewdirs.expand(-1, pts.shape[1], -1)
raw_output = self.calculate_density_color(pts, viewdirs_)
# resample
alpha = raw_output['occupancy'][..., 0]
weights = alpha * torch.cumprod(
torch.cat(
[torch.ones((alpha.shape[0], 1)).to(alpha), 1. - alpha + 1e-10],
-1), -1)[:, :-1]
z_vals_mid = .5 * (z_vals[...,1:, 0] + z_vals[...,:-1, 0])
z_samples = sample_pdf(z_vals_mid, weights[:, 1:-1], self.sample_args.N_importance, det=(split!='train'))
z_samples = z_samples.detach()
if split == 'train':
z_vals, _ = torch.sort(torch.cat([z_vals, z_samples[..., None]], -2), -2)
else:
z_vals = z_samples[..., None]
pts = ray_o + ray_d * z_vals
return pts, z_vals
def calculate_density_color(self, wpts, viewdir):
raise NotImplementedError
def calculate_density_color_from_ray(self, ray_o, ray_d, near, far, split):
if self.sample_method == 'uniform':
pts, z_vals = self.sample(near, far, ray_o, ray_d, self.z_steps, split)
elif self.sample_method == 'importance':
pts, z_vals = self.sample_pdf(near, far, ray_o, ray_d, split)
elif self.sample_method == 'raymarching':
import ipdb;ipdb.set_trace()
else:
print('Please check the sample :', self.sample_method)
raise NotImplementedError
viewdirs = ray_d / torch.norm(ray_d, keepdim=True, dim=-1)
viewdirs_ = viewdirs.expand(-1, pts.shape[1], -1)
raw_output = self.calculate_density_color(pts, viewdirs_)
return z_vals, pts, raw_output
def after(self, batch):
pass