274 lines
10 KiB
Python
274 lines
10 KiB
Python
'''
|
|
@ Date: 2021-09-03 17:12:29
|
|
@ Author: Qing Shuai
|
|
@ LastEditors: Qing Shuai
|
|
@ LastEditTime: 2021-09-05 19:58:54
|
|
@ FilePath: /EasyMocap/easymocap/neuralbody/model/nerf.py
|
|
'''
|
|
from .base import Base
|
|
from .embedder import get_embedder
|
|
import torch.nn.functional as F
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
class Nerf(Base):
    """NeRF-style MLP producing occupancy/density and RGB for sampled points.

    The network has a density trunk (``pts_linears`` followed by
    ``alpha_linear``) and a colour head (``feature_linear`` ->
    optional ``latent_linear`` -> ``views_linears`` -> ``rgb_linear``).
    Layers are either ``nn.Linear`` (features on the last axis) or 1x1
    ``nn.Conv1d`` (features on axis 1), selected by ``linear_func``.
    """

    def __init__(self, D, W, skips, init_bias=0.693,
                 D_rgb=1, W_rgb=256,
                 xyz_res=10, dim_pts=3,
                 view_res=4, dim_dir=3,  # embed
                 ch_pts_extra=0, ch_dir_extra=0,  # extra channels
                 latent=None,  # latent code
                 pts_to_density=True, pts_to_rgb=False, latent_to_density=False,
                 use_viewdirs=True, use_occupancy=True,  # option
                 act_fn='expsoftplus', linear_func='Linear',
                 relu_fn='relu',
                 sample_args=None,
                 embed_pts='none', embed_dir='none',
                 density_bias=True,
                 ) -> None:
        """Build the embedders and all layers.

        Args:
            D: number of layers in the density trunk.
            W: hidden width of the trunk and of the feature layers.
            skips: trunk layer indices after which the trunk input is
                concatenated back onto the hidden features.
            init_bias: value written into the bias of ``alpha_linear`` when
                ``use_occupancy`` is set.
            D_rgb, W_rgb: depth and hidden width of the colour head.
            xyz_res, dim_pts: arguments forwarded to the point embedder.
            view_res, dim_dir: arguments forwarded to the direction embedder.
            ch_pts_extra: extra channels fed to the trunk (see the
                ``extra_density`` keyword of ``calculate_density_color``).
            ch_dir_extra: accepted for compatibility; not used in this class.
            latent: mapping ``name -> channel count`` of latent codes.
                ``None`` (the default) means no latent codes.
            pts_to_density: feed embedded points to the trunk.
            pts_to_rgb: also feed embedded points to the colour head (only
                takes effect together with ``use_viewdirs``).
            latent_to_density: feed the latent codes to the trunk.
            use_viewdirs: feed embedded view directions to the colour head.
            use_occupancy: pass the raw alpha through ``act_fn``.
            act_fn: ``'exprelu'``, ``'expsoftplus'`` or ``'sigmoid'``.
            linear_func: ``'Linear'`` or ``'Conv1d'``.
            relu_fn: ``'relu'``, ``'relu6'`` or ``'LeakyRelu'``.
            sample_args: forwarded to ``Base.__init__``.
            embed_pts, embed_dir: ``'hash'`` selects the embedder from
                ``.hashnerf``; anything else selects the one from ``.embedder``.
            density_bias: whether the trunk layers have a bias term.

        Raises:
            NotImplementedError: for an unknown ``linear_func`` or
                ``relu_fn``, or an unknown ``act_fn`` while
                ``use_occupancy`` is enabled.
        """
        super().__init__(sample_args=sample_args)
        # Avoid a shared mutable default for the latent specification.
        latent = {} if latent is None else latent

        # set the embed
        self.embed_pts_name = embed_pts
        self.embed_dir_name = embed_dir
        if embed_pts == 'hash':
            from .hashnerf import get_embedder
        else:
            from .embedder import get_embedder
        self.embed_pts, ch_pts = get_embedder(xyz_res, dim_pts)
        if embed_dir == 'hash':
            from .hashnerf import get_embedder
        else:
            from .embedder import get_embedder
        self.embed_dir, ch_dir = get_embedder(view_res, dim_dir)

        self.latent_keys = list(latent.keys())
        # sum() of an empty sequence is 0, which covers the "no latent" case.
        latent_dim = sum(latent[key] for key in self.latent_keys)
        self.ch_pts = ch_pts
        self.ch_dir = ch_dir

        # set the input channels
        ch_pts_inp = ch_pts_extra
        if pts_to_density:
            ch_pts_inp += ch_pts
        if latent_to_density:
            ch_pts_inp += latent_dim
        self.pts_to_density = pts_to_density
        self.latent_to_density = latent_to_density
        self.pts_to_rgb = pts_to_rgb
        self.skips = skips
        self.use_viewdirs = use_viewdirs
        self.use_occupancy = use_occupancy

        if linear_func == 'Linear':
            def make_layer(input_w, output_w, bias=True):
                return nn.Linear(input_w, output_w, bias=bias)
            cat_dim = -1
        elif linear_func == 'Conv1d':
            def make_layer(input_w, output_w, bias=True):
                return nn.Conv1d(input_w, output_w, 1, bias=bias)
            cat_dim = 1
        else:
            raise NotImplementedError(
                'unknown linear_func: {!r}'.format(linear_func))

        if relu_fn == 'relu':
            self.relu = F.relu
        elif relu_fn == 'relu6':
            self.relu = F.relu6
        elif relu_fn == 'LeakyRelu':
            self.relu = nn.LeakyReLU(0.1)
        else:
            # Fail here rather than with an AttributeError in the forward pass.
            raise NotImplementedError(
                'unknown relu_fn: {!r}'.format(relu_fn))
        self.cat_dim = cat_dim

        # Trunk: the layer that follows a skip index receives the hidden
        # features concatenated with the trunk input.
        self.pts_linears = nn.ModuleList(
            [make_layer(ch_pts_inp, W, density_bias)] +
            [make_layer(W + ch_pts_inp, W, density_bias) if i in self.skips
             else make_layer(W, W, density_bias)
             for i in range(D - 1)])

        self.alpha_linear = make_layer(W, 1)
        # following neuralbody structure
        # 1. net feature: (256) => feature_fc => (256, 256)
        self.feature_linear = make_layer(W, W)
        # 2. concat other feature (256+other)
        # 3. latent_fc: (feature+other) => 384 => 256 (feature)
        # 4. concat viewdir, light_pts (256 + viewdir /+ pts)
        # 5. view_fc: (346, 128)
        # 6. rgb_fc: (128, 3)

        self.latent_linear = make_layer(W + latent_dim, W)
        if self.use_viewdirs and not self.pts_to_rgb:
            # Implementation according to the official code release
            # (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
            view_dim = ch_dir + W
        elif self.use_viewdirs and self.pts_to_rgb:
            view_dim = ch_dir + ch_pts + W
        else:
            view_dim = W

        if D_rgb == 1:
            self.views_linears = nn.ModuleList(
                [make_layer(view_dim, W // 2)])
        else:
            self.views_linears = nn.ModuleList(
                [make_layer(view_dim, W_rgb)] +
                [make_layer(W_rgb, W_rgb) for _ in range(D_rgb - 2)] +
                [make_layer(W_rgb, W // 2)]
            )
        self.rgb_linear = make_layer(W // 2, 3)

        if self.use_occupancy:
            # initial value for (e^x - 1) / e^x to give 0.5
            self.alpha_linear.bias.data.fill_(init_bias)
        if act_fn == 'exprelu':
            self.act_alpha = lambda x: 1 - torch.exp(-torch.relu(x))
        elif act_fn == 'expsoftplus':
            self.act_alpha = lambda x: 1 - torch.exp(-F.softplus(x - 1))
        elif act_fn == 'sigmoid':
            self.act_alpha = torch.sigmoid
        elif self.use_occupancy:
            # act_alpha is only needed when occupancy is requested, so an
            # unknown name is an error only in that case.
            raise NotImplementedError(
                'unknown act_fn: {!r}'.format(act_fn))

    def calculate_density_color(self, wpts, viewdir, latents=None, **kwargs):
        """Evaluate occupancy and colour for the given points.

        Args:
            wpts: raw point coordinates; they are passed through
                ``self.embed_pts``. In Linear mode the original comments
                describe this as ``(..., 3) => (..., 63)``.
            viewdir: raw view directions; embedded with ``self.embed_dir``
                when ``use_viewdirs`` is set.
            latents: mapping holding one tensor per entry of
                ``self.latent_keys``. ``None`` is treated as an empty mapping.
            **kwargs: ``extra_density`` (optional tensor) is concatenated
                to the trunk input when present.

        Returns:
            dict with keys ``'occupancy'`` (alpha after ``act_alpha`` when
            ``use_occupancy`` is set, otherwise the raw alpha),
            ``'raw_alpha'``, ``'rgb'`` (sigmoid of ``'raw_rgb'``) and
            ``'raw_rgb'``.
        """
        latents = {} if latents is None else latents
        if self.embed_pts_name == 'hash':
            # NOTE(review): datas_cache is not set in this class;
            # presumably provided by Base -- confirm.
            self.embed_pts.bound = self.datas_cache['bounds']

        # prepare latents: broadcast every code to the layout of wpts
        latent_embeding = []
        for key in self.latent_keys:
            code = latents[key]
            if self.cat_dim == 1:  # with Conv1d
                code = code[..., None].expand(*code.shape, wpts.shape[-1])
            else:
                code = code[None].expand(*wpts.shape[:-1], code.shape[-1])
            latent_embeding.append(code)
        has_latent = len(latent_embeding) > 0
        if has_latent:
            latent_embeding = torch.cat(latent_embeding, dim=self.cat_dim)

        wpts = self.embed_pts(wpts)
        extra_density = kwargs.get('extra_density', None)
        trunk_inputs = []
        if self.pts_to_density:
            trunk_inputs.append(wpts)
        if extra_density is not None:
            trunk_inputs.append(extra_density)
        # Without latent codes latent_embeding is an empty Python list, which
        # torch.cat cannot consume, and it contributes zero channels anyway.
        if self.latent_to_density and has_latent:
            trunk_inputs.append(latent_embeding)
        inp = torch.cat(trunk_inputs, dim=self.cat_dim)

        h = inp
        for i, layer in enumerate(self.pts_linears):
            h = self.relu(layer(h))
            if i in self.skips:
                h = torch.cat([inp, h], self.cat_dim)
        alpha = self.alpha_linear(h)
        raw_alpha = alpha
        if self.use_occupancy:
            alpha = self.act_alpha(alpha)

        # rgb part:
        feature = self.feature_linear(h)
        # latent:
        if len(self.latent_keys) > 0:
            features = torch.cat([feature, latent_embeding], dim=self.cat_dim)
            feature = self.latent_linear(features)
        # append viewdir
        if self.use_viewdirs:
            input_views = self.embed_dir(viewdir)
            if self.pts_to_rgb:
                feature = torch.cat([feature, input_views, wpts], self.cat_dim)
            else:
                feature = torch.cat([feature, input_views], self.cat_dim)
        for layer in self.views_linears:
            feature = self.relu(layer(feature))
        rgb = self.rgb_linear(feature)
        rgb_01 = torch.sigmoid(rgb)
        outputs = {
            'occupancy': alpha,
            'raw_alpha': raw_alpha,
            'rgb': rgb_01,
            'raw_rgb': rgb
        }
        return outputs
|
|
|
|
class MultiLinear(nn.Module):
    """MLP with ReLU hidden layers, optional skip connections and an
    optional output activation.

    Layers are ``nn.Linear`` (features on the last axis) or 1x1
    ``nn.Conv1d`` (features on axis 1), selected by ``linear_func``.
    """

    def __init__(self, D, W, input_ch, output_ch, skips,
                 init_bias=0.693,
                 act_fn='none', linear_func='Linear',
                 **kwargs) -> None:
        """Build the layers.

        Args:
            D: number of hidden layers. With ``D <= 0`` the module is a
                single layer mapping ``input_ch`` to ``output_ch``.
            W: hidden width.
            input_ch: number of input channels.
            output_ch: number of output channels.
            skips: hidden layer indices after which the network input is
                concatenated back onto the hidden features.
            init_bias: value written into the bias of the output layer.
            act_fn: ``'softplus'``, ``'tanh'``, ``'none'``, ``None`` or a
                callable applied to the output.
            linear_func: ``'Linear'`` or ``'Conv1d'``.
            **kwargs: accepted and ignored.

        Raises:
            NotImplementedError: for an unknown ``linear_func`` or an
                unknown ``act_fn`` name.
        """
        super().__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.output_ch = output_ch
        self.skips = skips
        if linear_func == 'Linear':
            def make_layer(input_w, output_w):
                return nn.Linear(input_w, output_w)
            cat_dim = -1
        elif linear_func == 'Conv1d':
            def make_layer(input_w, output_w):
                return nn.Conv1d(input_w, output_w, 1)
            cat_dim = 1
        else:
            raise NotImplementedError(
                'unknown linear_func: {!r}'.format(linear_func))
        self.cat_dim = cat_dim

        if D > 0:
            # The layer that follows a skip index receives the hidden
            # features concatenated with the network input.
            self.linears = nn.ModuleList(
                [make_layer(input_ch, W)] +
                [make_layer(W + input_ch, W) if i in self.skips
                 else make_layer(W, W)
                 for i in range(D - 1)])
            self.output_linear = make_layer(W, output_ch)
        else:
            # Empty ModuleList keeps the attribute type consistent; it holds
            # no parameters, so the state dict is unaffected.
            self.linears = nn.ModuleList()
            self.output_linear = make_layer(input_ch, output_ch)

        if isinstance(act_fn, str):
            activations = {
                'softplus': torch.nn.functional.softplus,
                'none': None,
                'tanh': torch.tanh,
            }
            if act_fn not in activations:
                # Previously the unknown string was stored and only failed
                # in forward() with "'str' object is not callable".
                raise NotImplementedError(
                    'unknown act_fn: {!r}'.format(act_fn))
            act_fn = activations[act_fn]
        # Callables and None are stored unchanged.
        self.act_fn = act_fn
        # start the output bias at init_bias
        self.output_linear.bias.data.fill_(init_bias)

    def forward(self, inputs):
        """Run the MLP on ``inputs`` and return the (activated) output."""
        h = inputs
        for i, layer in enumerate(self.linears):
            h = F.relu(layer(h))
            if i in self.skips:
                h = torch.cat([inputs, h], self.cat_dim)
        output = self.output_linear(h)
        if self.act_fn is not None:
            output = self.act_fn(output)
        return output
|
|
|
|
class EmbedMLP(nn.Module):
    """Embed a scalar-like input and map the embedding through a MultiLinear."""

    def __init__(self, input_ch, output_ch, multi_res, W, D, bounds) -> None:
        """Create the embedder and the MLP that consumes its output.

        ``bounds`` is the divisor applied to the input before embedding.
        """
        super().__init__()
        embedder, embed_ch = get_embedder(multi_res, input_ch)
        self.embed = embedder
        self.bounds = bounds
        # No skip connections, zero output bias, no output activation.
        self.linear = MultiLinear(
            D=D,
            W=W,
            input_ch=embed_ch,
            output_ch=output_ch,
            skips=[],
            init_bias=0,
            act_fn='none')

    def forward(self, time):
        """Flatten ``time`` to a single row, scale by ``bounds``, embed, and run the MLP."""
        scaled = time.reshape(1, -1).float() / self.bounds
        return self.linear(self.embed(scaled))
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build a default Linear-mode network and print the shape of
    # the density it returns for a random batch of points.
    cfg = dict(
        D=8,
        W=256,
        dim_pts=3,
        dim_dir=3,
        xyz_res=10,
        view_res=4,
        skips=[4],
        use_viewdirs=True,
        use_occupancy=True,
        linear_func='Linear',
    )
    model = Nerf(**cfg)
    sample_pts = torch.rand((1, 4, 4, 64, 3))
    # NOTE(review): calculate_density is not defined on Nerf in this file;
    # presumably inherited from Base -- confirm.
    density = model.calculate_density(sample_pts)
    print(density.shape)