EasyMocap/easymocap/neuralbody/model/nerfplus.py
2022-10-25 20:06:04 +08:00

189 lines
7.0 KiB
Python

import torch
import torch.nn as nn
from .nerf import Nerf, EmbedMLP, MultiLinear
from os.path import join
from ...mytools.file_utils import read_json
import numpy as np
def create_dynamic_embedding(mode, embed):
if mode == 'dense':
embedding = nn.Embedding(embed.shape[0], embed.shape[1])
elif mode == 'mlp':
if 'D' not in embed.keys():
embedding = EmbedMLP(
input_ch=1,
multi_res=32,
W=128,
D=2,
bounds=embed.shape[0],
output_ch=embed.shape[1])
else:
embedding = EmbedMLP(
input_ch=1,
multi_res=32,
W=embed.W,
D=embed.D,
bounds=embed.shape[0],
output_ch=embed.shape[1])
else:
raise NotImplementedError
return embedding
class NeRFT(Nerf):
def __init__(self, embed, nerf):
nerf['latent'] = {'time': embed.shape[1]}
super().__init__(**nerf)
self.mode = embed.mode
self.embedding = create_dynamic_embedding(self.mode, embed)
self.cache = {}
def clear_cache(self):
self.cache = {}
def before(self, batch, name):
data = super().before(batch, name)
nf, nv = batch['meta']['time'][0], batch['meta']['nview'][0]
if 'frame' in name:
nf = nf + batch[name+'_frame'] - batch['meta']['nframe']
self.cache['embed'] = self.embedding(nf)
return data
def calculate_density_color(self, wpts, viewdir, **kwargs):
latents = {'time': self.cache['embed']}
return super().calculate_density_color(wpts, viewdir, latents, **kwargs)
class NeRFGroundShadow(Nerf):
def __init__(self, embed, shadow, nerf):
super().__init__(**nerf)
self.shadow = MultiLinear(
D=shadow.D,
W=shadow.W,
input_ch=self.ch_pts + embed.shape[1],
output_ch=1, # 输出一维阴影
init_bias=5,
act_fn='none',
skips=[]
)
nerf['latent'] = {'time': embed.shape[1]}
self.mode = embed.mode
self.embedding = create_dynamic_embedding(self.mode, embed)
self.cache = {}
def clear_cache(self):
self.cache = {}
def before(self, batch, name):
data = super().before(batch, name)
nf, nv = batch['meta']['time'][0], batch['meta']['nview'][0]
if 'frame' in name:
nf = nf + batch[name+'_frame'] - batch['meta']['nframe']
self.cache['embed'] = self.embedding(nf)
return data
def calculate_density_color(self, wpts, viewdir, **kwargs):
latents = self.cache['embed'][None]
raw_output = super().calculate_density_color(wpts, viewdir, **kwargs)
pts_embed = self.embed_pts(wpts)
latents = latents.expand(pts_embed.shape[0], pts_embed.shape[1], -1)
shadow = self.shadow(torch.cat([pts_embed, latents], dim=-1))
shadow = torch.sigmoid(shadow)
raw_output['rgb'] = shadow * raw_output['rgb']
return raw_output
class NeRFT_pretrain(Nerf):
def __init__(self, nerf, embed_time, pretrain, dcolor):
super().__init__(**nerf)
state_dict = torch.load(pretrain, map_location='cpu')['state_dict']
state_dict_new = {}
for key, val in state_dict.items():
if key.startswith('train_renderer.'): continue
state_dict_new[key.replace('network.', '')] = val
self.load_state_dict(state_dict_new)
for p in self.parameters():
p.requires_grad = False
self.mode = embed_time.mode
self.embedding = create_dynamic_embedding(self.mode, embed_time)
# create dynamic color layers:
# input: embeding, pts, viewdirs => delta_color
self.delta_color = MultiLinear(
input_ch=self.ch_pts+self.ch_dir+embed_time.shape[1],
output_ch=3,
init_bias=0.,
act_fn='none',
**dcolor
)
def before(self, batch, name):
data = super().before(batch, name)
nf, nv = batch['meta']['time'][0], batch['meta']['nview'][0]
if 'frame' in name:
nf = nf + batch[name+'_frame'] - batch['meta']['nframe']
self.cache['embed'] = self.embedding(nf)
return data
def calculate_density_color(self, wpts, viewdir, **kwargs):
raw_output = super().calculate_density_color(wpts, viewdir, **kwargs)
wpts = self.embed_pts(wpts)
input_views = self.embed_dir(viewdir)
embed = self.cache['embed'][None].expand(wpts.shape[0], wpts.shape[1], -1)
inputs = torch.cat([wpts, input_views, embed], dim=-1)
delta_color = self.delta_color(inputs) * 0.1 # avoid too much delta
color = torch.sigmoid(delta_color + raw_output['raw_rgb'])
raw_output['rgb'] = color
return raw_output
class DynamicColorNerf(NeRFT):
def __init__(self, pid, traj, embed, nerf, opt_traj_step, share_view):
super().__init__(embed, nerf)
trajs = []
for nf in range(*traj.ranges):
annname = join(traj.path, '{:06d}.json'.format(nf))
annots = read_json(annname)
annots = [a for a in annots if a['id'] == pid][0]
center = annots['keypoints3d'][0][:3]
trajs.append(center)
if share_view:
traj.nViews = 1
trajs = np.array(trajs, dtype=np.float32)[None].repeat(traj.nViews, 0)
trajs = torch.Tensor(trajs)
self.register_buffer('init_t', trajs.clone())
self.traj = nn.Parameter(trajs)
self.opt_traj_step = opt_traj_step
self.share_view = share_view
def before(self, batch, name):
data = super().before(batch, name)
if False:
results = []
for nf in range(200):
t = torch.tensor([nf/200*300], dtype=torch.float32).to(data['rgb'].device)
embed = self.embedding(t)
results.append(embed.detach().cpu().numpy())
import numpy as np
results = np.vstack(results)
import matplotlib.pyplot as plt
plt.imshow(results)
plt.show()
import ipdb;ipdb.set_trace()
nf, nv = batch['meta']['nframe'][0], batch['meta']['nview'][0]
if batch['meta']['sub'][0].startswith('novel'):
nv = 0 # use view 0 for novel view
if batch['meta']['split'][0] != 'train':
nv = 0
if self.share_view:
nv = 0
self.cache['T'] = self.traj[nv, nf]
self.cache['init'] = self.init_t[nv, nf]
if batch['step'] < self.opt_traj_step:
self.cache['T'] = self.cache['T'].detach()
reg_t = self.cache['T'] - self.cache['init']
self.cache['reg_t'] = reg_t
def calculate_density_color(self, wpts, viewdir, **kwargs):
# unpose to canonical space
wpts = wpts - self.cache['T'][None, None]
output = super().calculate_density_color(wpts, viewdir, **kwargs)
return output
if __name__ == '__main__':
embedding = MultiResEmbedding()