48 lines
1.9 KiB
Python
48 lines
1.9 KiB
Python
'''
@ Date: 2021-09-05 20:24:16
@ Author: Qing Shuai
@ LastEditors: Qing Shuai
@ LastEditTime: 2021-09-05 21:25:08
@ FilePath: /EasyMocap/easymocap/neuralbody/renderer/render_wrapper.py
'''
|
|
import torch
|
|
import torch.nn as nn
|
|
from ...config import load_object
|
|
|
|
class RenderWrapper(nn.Module):
    """Bundle a renderer with its weighted data and regularisation losses.

    The renderer is built via ``load_object(renderer_module, renderer_args)``
    with ``net`` injected as ``renderer_args['net']``. Every entry of ``loss``
    and ``loss_reg`` is built into a module and weighted by its ``weight``.
    """

    def __init__(self, net, renderer_module, renderer_args, loss, loss_reg=None):
        """Build the renderer and the loss modules.

        Args:
            net: Network handed to the renderer as its ``net`` argument.
            renderer_module: Renderer spec passed to ``load_object``.
            renderer_args: Mapping of renderer arguments; it is copied, so the
                caller's mapping is not modified when ``net`` is injected.
            loss: Mapping ``name -> cfg``. Each ``cfg`` is read both as
                ``cfg['weight']`` and as ``cfg.module`` / ``cfg.args``, so it
                presumably is an attribute-and-key config node (TODO confirm).
                Each built loss is called as ``func(batch, ret)``.
            loss_reg: Same structure as ``loss``; each built regulariser is
                called as ``func(self.renderer.net, batch, ret)``. ``None``
                (the default) means no regularisers.
        """
        super().__init__()
        if loss_reg is None:
            # Sentinel instead of a shared mutable ``{}`` default.
            loss_reg = {}
        renderer_args = dict(renderer_args)
        renderer_args['net'] = net
        self.renderer = load_object(renderer_module, renderer_args)
        # Data and regularisation weights share one dict; a name present in
        # both ``loss`` and ``loss_reg`` ends up with the ``loss_reg`` weight.
        self.weights = {key: val['weight'] for key, val in loss.items()}
        self.weights.update({key: val['weight'] for key, val in loss_reg.items()})
        # ModuleDict registers the loss modules as submodules of this wrapper.
        self.loss = nn.ModuleDict(
            {key: load_object(val.module, val.args) for key, val in loss.items()})
        self.loss_reg = nn.ModuleDict(
            {key: load_object(val.module, val.args) for key, val in loss_reg.items()})

    def forward(self, batch):
        """Render ``batch`` and compute the weighted total loss.

        Args:
            batch: Mapping passed to the renderer and to every loss module.

        Returns:
            Tuple ``(ret, loss, scalar_stats, image_stats)``:
            ``ret`` is the renderer output mapping; ``loss`` is the weighted
            sum of all loss terms (the int ``0`` if no losses are configured);
            ``scalar_stats`` holds each unweighted loss term, the means of the
            diagnostic outputs found, ``'mean_rgb'`` when ``batch`` has
            ``'rgb'``, and ``'loss'``; ``image_stats`` is always empty here.
        """
        ret = self.renderer(batch)
        loss = 0
        scalar_stats = {}
        for key, func in self.loss.items():
            val = func(batch, ret)
            scalar_stats[key] = val
            loss += self.weights[key] * val
        for key, func in self.loss_reg.items():
            # Regularisers are additionally given the renderer's network.
            val = func(self.renderer.net, batch, ret)
            scalar_stats[key] = val
            loss += self.weights[key] * val
        # Log the mean of whichever of these renderer outputs are present.
        for key in ('rgb_map', 'acc_map', 'occ_object', 'occ_back', 'human_0_occ'):
            if key in ret:
                scalar_stats[key] = ret[key].mean()
        # The value is read from ``batch``, so presence must be checked there
        # (checking ``ret`` skipped the stat or raised KeyError).
        if 'rgb' in batch:
            scalar_stats['mean_rgb'] = batch['rgb'].mean()
        scalar_stats['loss'] = loss
        image_stats = {}
        return ret, loss, scalar_stats, image_stats