189 lines
7.0 KiB
Python
189 lines
7.0 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from .nerf import Nerf, EmbedMLP, MultiLinear
|
|
from os.path import join
|
|
from ...mytools.file_utils import read_json
|
|
import numpy as np
|
|
|
|
def create_dynamic_embedding(mode, embed):
    """Build the module that maps a time/frame value to a latent code.

    Args:
        mode: ``'dense'`` builds an ``nn.Embedding`` lookup table, ``'mlp'``
            builds an ``EmbedMLP``.
        embed: embedding config exposing ``shape``; ``shape[0]`` is used as
            the number of table entries (``'dense'``) or as ``bounds``
            (``'mlp'``), and ``shape[1]`` is the output width. In ``'mlp'``
            mode the config must also support ``keys()``; when it contains a
            ``'D'`` entry, ``embed.W`` and ``embed.D`` replace the default
            width and depth.

    Returns:
        The constructed embedding module.

    Raises:
        NotImplementedError: if ``mode`` is neither ``'dense'`` nor ``'mlp'``.
    """
    if mode == 'dense':
        return nn.Embedding(embed.shape[0], embed.shape[1])

    if mode == 'mlp':
        # Defaults used when the config does not specify the MLP size.
        width, depth = 128, 2
        if 'D' in embed.keys():
            width, depth = embed.W, embed.D
        return EmbedMLP(
            input_ch=1,
            multi_res=32,
            W=width,
            D=depth,
            bounds=embed.shape[0],
            output_ch=embed.shape[1],
        )

    raise NotImplementedError(
        'Unknown embedding mode {!r}; expected "dense" or "mlp"'.format(mode))
|
|
|
|
class NeRFT(Nerf):
    """NeRF variant conditioned on a per-frame ``'time'`` latent code.

    The latent is produced by the module returned from
    ``create_dynamic_embedding`` and is cached in ``before`` so that
    ``calculate_density_color`` can hand it to the base network.
    """

    def __init__(self, embed, nerf):
        """Create the base network and the time embedding.

        Args:
            embed: embedding config; ``embed.mode`` selects the embedding
                type and ``embed.shape[1]`` is the latent width.
            nerf: keyword arguments forwarded to ``Nerf.__init__``. Any
                existing ``'latent'`` entry is overridden.
        """
        # Work on a private copy so that building the model does not write
        # the 'latent' entry back into the caller's configuration.
        nerf_args = dict(nerf)
        nerf_args['latent'] = {'time': embed.shape[1]}
        super().__init__(**nerf_args)
        self.mode = embed.mode
        self.embedding = create_dynamic_embedding(self.mode, embed)
        self.cache = {}

    def clear_cache(self):
        """Drop everything stored by ``before``."""
        self.cache = {}

    def before(self, batch, name):
        """Run the base hook and cache the time latent for this batch.

        The embedding input is ``batch['meta']['time'][0]``. When ``name``
        contains ``'frame'`` it is shifted by
        ``batch[name + '_frame'] - batch['meta']['nframe']``.

        Returns:
            Whatever ``Nerf.before`` returned.
        """
        data = super().before(batch, name)
        frame = batch['meta']['time'][0]
        if 'frame' in name:
            frame = frame + batch[name + '_frame'] - batch['meta']['nframe']
        self.cache['embed'] = self.embedding(frame)
        return data

    def calculate_density_color(self, wpts, viewdir, **kwargs):
        """Query the base network with the cached time latent.

        ``before`` must have been called first so that
        ``self.cache['embed']`` exists.
        """
        latents = {'time': self.cache['embed']}
        return super().calculate_density_color(wpts, viewdir, latents, **kwargs)
|
|
|
|
class NeRFGroundShadow(Nerf):
    """NeRF whose colour is attenuated by a time-dependent shadow factor.

    The base network is queried without a time latent; the time embedding is
    only fed to the extra ``shadow`` MLP, whose sigmoid output multiplies the
    base ``rgb``.
    """

    def __init__(self, embed, shadow, nerf):
        """Create the base network, the shadow MLP and the time embedding.

        Args:
            embed: embedding config; ``embed.mode`` selects the embedding
                type and ``embed.shape[1]`` is the latent width.
            shadow: config providing depth ``D`` and width ``W`` of the
                shadow MLP.
            nerf: keyword arguments forwarded to ``Nerf.__init__``.
        """
        super().__init__(**nerf)
        self.shadow = MultiLinear(
            D=shadow.D,
            W=shadow.W,
            input_ch=self.ch_pts + embed.shape[1],
            output_ch=1,  # scalar shadow value
            # NOTE(review): presumably a large positive bias so the sigmoid
            # starts near 1 (no shadow) — confirm against MultiLinear.
            init_bias=5,
            act_fn='none',
            skips=[],
        )
        self.mode = embed.mode
        self.embedding = create_dynamic_embedding(self.mode, embed)
        self.cache = {}

    def clear_cache(self):
        """Drop everything stored by ``before``."""
        self.cache = {}

    def before(self, batch, name):
        """Run the base hook and cache the time latent for this batch.

        The embedding input is ``batch['meta']['time'][0]``. When ``name``
        contains ``'frame'`` it is shifted by
        ``batch[name + '_frame'] - batch['meta']['nframe']``.

        Returns:
            Whatever ``Nerf.before`` returned.
        """
        data = super().before(batch, name)
        frame = batch['meta']['time'][0]
        if 'frame' in name:
            frame = frame + batch[name + '_frame'] - batch['meta']['nframe']
        self.cache['embed'] = self.embedding(frame)
        return data

    def calculate_density_color(self, wpts, viewdir, **kwargs):
        """Return the base output with ``rgb`` scaled by the shadow factor.

        ``before`` must have been called first so that
        ``self.cache['embed']`` exists.
        """
        latents = self.cache['embed'][None]
        raw_output = super().calculate_density_color(wpts, viewdir, **kwargs)
        pts_embed = self.embed_pts(wpts)
        # Broadcast the single cached latent over the first two dimensions
        # of the embedded points so it can be concatenated per point.
        latents = latents.expand(pts_embed.shape[0], pts_embed.shape[1], -1)
        shadow = self.shadow(torch.cat([pts_embed, latents], dim=-1))
        shadow = torch.sigmoid(shadow)
        raw_output['rgb'] = shadow * raw_output['rgb']
        return raw_output
|
|
|
|
class NeRFT_pretrain(Nerf):
    """Frozen pretrained NeRF plus a trainable time-dependent colour offset.

    The base weights are loaded from a checkpoint and frozen. The time
    embedding and the ``delta_color`` MLP are created afterwards, so they
    stay trainable.
    """

    def __init__(self, nerf, embed_time, pretrain, dcolor):
        """Load the pretrained base network and add the colour-offset branch.

        Args:
            nerf: keyword arguments forwarded to ``Nerf.__init__``.
            embed_time: embedding config; ``embed_time.mode`` selects the
                embedding type and ``embed_time.shape[1]`` is the latent
                width.
            pretrain: path of a checkpoint containing a ``'state_dict'``
                entry.
            dcolor: extra keyword arguments for the ``delta_color``
                ``MultiLinear``.
        """
        super().__init__(**nerf)
        # `before` stores the latent in self.cache; create it here unless the
        # base class already provides one.
        if not hasattr(self, 'cache'):
            self.cache = {}

        # NOTE(security): torch.load unpickles the file, so only trusted
        # checkpoints must be passed as `pretrain`.
        state_dict = torch.load(pretrain, map_location='cpu')['state_dict']
        prefix = 'network.'
        state_dict_new = {}
        for key, val in state_dict.items():
            if key.startswith('train_renderer.'):
                continue
            # Strip only the leading wrapper prefix; replacing every
            # occurrence would corrupt keys such as 'network.color_network.w'.
            if key.startswith(prefix):
                key = key[len(prefix):]
            state_dict_new[key] = val
        self.load_state_dict(state_dict_new)

        # Freeze everything that exists so far (the pretrained base network).
        for p in self.parameters():
            p.requires_grad = False

        self.mode = embed_time.mode
        self.embedding = create_dynamic_embedding(self.mode, embed_time)
        # Dynamic colour layers:
        # input: embedding, pts, viewdirs => delta_color
        self.delta_color = MultiLinear(
            input_ch=self.ch_pts + self.ch_dir + embed_time.shape[1],
            output_ch=3,
            init_bias=0.,
            act_fn='none',
            **dcolor
        )

    def before(self, batch, name):
        """Run the base hook and cache the time latent for this batch.

        The embedding input is ``batch['meta']['time'][0]``. When ``name``
        contains ``'frame'`` it is shifted by
        ``batch[name + '_frame'] - batch['meta']['nframe']``.

        Returns:
            Whatever ``Nerf.before`` returned.
        """
        data = super().before(batch, name)
        frame = batch['meta']['time'][0]
        if 'frame' in name:
            frame = frame + batch[name + '_frame'] - batch['meta']['nframe']
        self.cache['embed'] = self.embedding(frame)
        return data

    def calculate_density_color(self, wpts, viewdir, **kwargs):
        """Return the base output with ``rgb`` replaced by the adjusted colour.

        The new colour is ``sigmoid(raw_rgb + 0.1 * delta_color)``, where
        ``raw_rgb`` is read from the base output. ``before`` must have been
        called first so that ``self.cache['embed']`` exists.
        """
        raw_output = super().calculate_density_color(wpts, viewdir, **kwargs)
        wpts = self.embed_pts(wpts)
        input_views = self.embed_dir(viewdir)
        # Broadcast the single cached latent over the first two dimensions
        # of the embedded points.
        embed = self.cache['embed'][None].expand(wpts.shape[0], wpts.shape[1], -1)
        inputs = torch.cat([wpts, input_views, embed], dim=-1)
        delta_color = self.delta_color(inputs) * 0.1  # avoid too much delta
        color = torch.sigmoid(delta_color + raw_output['raw_rgb'])
        raw_output['rgb'] = color
        return raw_output
|
|
|
|
class DynamicColorNerf(NeRFT):
    """Time-conditioned NeRF that follows a learnable per-frame translation.

    A 3D centre is read for every frame from annotation files and stored
    twice: as the buffer ``init_t`` (fixed) and as the parameter ``traj``
    (optimisable). Query points are shifted by the current translation
    before being passed to ``NeRFT``.
    """

    def __init__(self, pid, traj, embed, nerf, opt_traj_step, share_view):
        """Load the initial trajectory and set up the parent network.

        Args:
            pid: annotation ``'id'`` of the object to follow.
            traj: trajectory config providing ``ranges`` (arguments for
                ``range``), ``path`` (directory of ``{:06d}.json`` files)
                and ``nViews``.
            embed: embedding config forwarded to ``NeRFT``.
            nerf: keyword arguments forwarded to ``NeRFT``.
            opt_traj_step: training step from which the trajectory receives
                gradients.
            share_view: if true, a single trajectory is shared by all views.

        Raises:
            IndexError: if a frame has no annotation with id ``pid``.
        """
        super().__init__(embed, nerf)
        centers = []
        for nf in range(*traj.ranges):
            annname = join(traj.path, '{:06d}.json'.format(nf))
            matches = [a for a in read_json(annname) if a['id'] == pid]
            if not matches:
                raise IndexError(
                    'no annotation with id {} in {}'.format(pid, annname))
            # First keypoint, first three components.
            centers.append(matches[0]['keypoints3d'][0][:3])

        # Use a local view count instead of overwriting `traj.nViews`, so the
        # caller's configuration is left untouched.
        n_views = 1 if share_view else traj.nViews
        # Resulting layout: (view, frame, 3).
        trajs = np.array(centers, dtype=np.float32)[None].repeat(n_views, 0)
        trajs = torch.from_numpy(trajs)
        self.register_buffer('init_t', trajs.clone())
        self.traj = nn.Parameter(trajs)
        self.opt_traj_step = opt_traj_step
        self.share_view = share_view

    def before(self, batch, name):
        """Run the parent hook and cache the translation for this batch.

        Stores in ``self.cache``: ``'T'`` (current translation, detached
        while ``batch['step'] < opt_traj_step``), ``'init'`` (initial
        translation) and ``'reg_t'`` (their difference).

        Returns:
            Whatever ``NeRFT.before`` returned.
        """
        data = super().before(batch, name)
        # NOTE(review): nframe is used directly as an index into the
        # trajectory — presumably frames are numbered from traj.ranges[0] == 0;
        # verify against the dataset.
        nf, nv = batch['meta']['nframe'][0], batch['meta']['nview'][0]
        # Novel views, non-training splits and shared trajectories all use
        # the trajectory of view 0.
        if (batch['meta']['sub'][0].startswith('novel')
                or batch['meta']['split'][0] != 'train'
                or self.share_view):
            nv = 0

        translation = self.traj[nv, nf]
        if batch['step'] < self.opt_traj_step:
            # Keep the trajectory fixed until optimisation is enabled.
            translation = translation.detach()
        init = self.init_t[nv, nf]

        self.cache['T'] = translation
        self.cache['init'] = init
        self.cache['reg_t'] = translation - init
        return data

    def calculate_density_color(self, wpts, viewdir, **kwargs):
        """Shift the points by the cached translation and query ``NeRFT``."""
        # unpose to canonical space
        wpts = wpts - self.cache['T'][None, None]
        return super().calculate_density_color(wpts, viewdir, **kwargs)
|
|
|
|
if __name__ == '__main__':
    # NOTE(review): `MultiResEmbedding` is neither defined nor imported in the
    # visible part of this file, so this line would raise NameError when the
    # module is run as a script — confirm where the name should come from.
    embedding = MultiResEmbedding()