'''
    @ Date: 2021-09-05 20:24:24
    @ Author: Qing Shuai
    @ LastEditors: Qing Shuai
    @ LastEditTime: 2021-09-05 21:34:16
    @ FilePath: /EasyMocap/easymocap/neuralbody/renderer/render_loss.py
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

from ...config import load_object

class LossRGB(nn.Module):
    def __init__(self, norm) -> None:
        super().__init__()
        self.norm = norm

    def forward(self, inp, out):
        diff = inp['rgb'] - out['rgb_map']
        if self.norm == 'l2':
            loss = torch.mean(diff ** 2)
        elif self.norm == 'l1':
            loss = torch.mean(torch.abs(diff))
        else:
            raise NotImplementedError(self.norm)
        return loss

class LossDepth(nn.Module):
    def __init__(self, norm) -> None:
        super().__init__()
        self.norm = norm
        # maximum depth (in scene units) used to normalize the residual
        self.depth_max = 15.

    def forward(self, inp, out):
        loss_sum = 0
        for key in out['keys']:
            depth_gt = inp[key+'_depth']
            depth_est = out['depth_map']
            # only supervise rays with a valid ground-truth depth
            valid = depth_gt > 0.
            depth_diff = (depth_gt[valid] - depth_est[valid])/self.depth_max
            loss = torch.sum((depth_diff**2)/(1e-5 + valid.sum()))
            loss_sum += loss
        return loss_sum

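# Expected inputs for LossDepth (a sketch; exact shapes are assumptions):
#   inp['<key>_depth']: ground-truth depth per sampled ray, 0 where invalid
#   out['depth_map']:   rendered depth for the same rays
#   out['keys']:        rendered entity keys, e.g. ['human_0']
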
class AnyReg(nn.Module):
    def __init__(self, key, norm) -> None:
        super().__init__()
        self.key = key
        self.norm = norm

    def forward(self, inp, out):
        if self.key not in out.keys():
            return torch.tensor(0.).to(out['rgb_map'].device)
        diff = out[self.key]
        if self.norm == 'l2':
            loss = torch.mean(diff ** 2)
        elif self.norm == 'norm':
            loss = torch.norm(diff)
        else:
            raise NotImplementedError
        return loss

class LossMask(nn.Module):
    def __init__(self, norm='l1', key='human_0'):
        super().__init__()
        self.norm = norm
        self.key = key

    def forward(self, inp, out):
        pred = out['acc_map']
        gt = inp['{}_coord_mask'.format(self.key)]
        if self.norm == 'l1':
            loss_fore = torch.mean(torch.abs(1 - pred[gt]))
            loss_back = torch.mean(torch.abs(pred[~gt]))
            loss = loss_fore + loss_back
        elif self.norm == 'bce':
            target = gt.float()
            loss_fore = F.binary_cross_entropy(pred[gt].clip(1e-5, 1.0 - 1e-5), target[gt])
            loss_back = F.binary_cross_entropy(pred[~gt].clip(1e-5, 1.0 - 1e-5), target[~gt])
            loss = loss_fore + loss_back
        else:
            raise NotImplementedError(self.norm)
        return loss

class LossStepWrapper(nn.Module):
    '''Apply a step-dependent weight to a wrapped loss.

    `weights` is a list of (start, end, weight) intervals; end == -1 means open-ended.'''
    def __init__(self, weights, module, args):
        super().__init__()
        self.loss = load_object(module, args)
        self.weights = weights

    def forward(self, inp, out):
        step = inp['step']
        # find the interval containing the current step;
        # default to 0 so the loss is disabled outside all intervals
        weight = 0.
        for (start, end, w) in self.weights:
            if step >= start and (end == -1 or step < end):
                weight = w
                break
        if weight == 0.:
            loss = torch.tensor(0.).to(out['rgb_map'].device)
        else:
            loss = weight * self.loss(inp, out)
        return loss

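# Usage sketch for LossStepWrapper (illustrative only; the module path and
# schedule values are assumptions based on how `load_object` is called above):
#   wrapped = LossStepWrapper(
#       weights=[(0, 10000, 1.0), (10000, -1, 0.1)],
#       module='easymocap.neuralbody.renderer.render_loss.LossRGB',
#       args={'norm': 'l2'})
#   # inp must provide inp['step']; out must provide out['rgb_map']
#   loss = wrapped(inp, out)
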
class LossSemantic(nn.Module):
    def __init__(self, norm, start, end) -> None:
        super().__init__()
        # start/end are accepted for config compatibility but are not used here
        self.norm = norm
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, inp, out):
        semantic = out['semantic_map']
        # the remaining weight is treated as the background class (class 0)
        background = 1 - torch.sum(semantic, dim=-1, keepdim=True)
        semantic = torch.cat([background, semantic], dim=-1)
        loss = self.loss(semantic[0], inp['human_0_semantic'][0])
        return loss

class LossAcc(nn.Module):
    def __init__(self, norm) -> None:
        super().__init__()

    def forward(self, inp, out):
        # TODO: currently only handles the single-person case
        diff = inp['human_0_acc'] - out['fore_acc_map']
        loss = torch.mean(diff**2)
        return loss

class LossCollision(nn.Module):
    def __init__(self, norm, start) -> None:
        super().__init__()
        self.start = start

    def forward(self, model, batch, output):
        if len(batch['meta']['collision']) == 0 or batch['step'] < self.start:
            loss = torch.tensor(0.).to(output['rgb_map'].device)
            return loss
        key0s, key1s, ptss = batch['meta']['collision'][0]
        loss_all = []
        for key0, key1, pts in zip(key0s, key1s, ptss):
            pts = pts[None].to(output['rgb_map'].device)
            # occ: (nPoints, 1)
            occ0 = model.model(key0).calculate_density(pts)[0]
            occ1 = model.model(key1).calculate_density(pts)[0]
            occ01 = (occ0 + occ1 + 1e-5)
            occ0_ = torch.clamp(occ0/occ01, min=1e-5)
            occ1_ = torch.clamp(occ1/occ01, min=1e-5)
            loss = -occ01 * (occ0_ * torch.log(occ0_) + occ1_ * torch.log(occ1_))
            loss_all.append(loss)
        loss_all = torch.cat(loss_all, dim=0)
        loss = loss_all.mean()
        return loss

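# Note: with p_i = occ_i/(occ_0 + occ_1), each collision term above equals
# (occ_0 + occ_1) * H(p), i.e. the binary entropy scaled by the total density.
# H(p) peaks when both models claim a point equally, so minimizing the loss
# lets at most one model occupy each contested point.
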
class LossNormal(nn.Module):
    def __init__(self, norm, perturb_surface_pt=0.01) -> None:
        super().__init__()
        self.perturb_surface_pt = perturb_surface_pt

    @staticmethod
    def get_sampling_points(bounds):
        sh = bounds.shape
        min_xyz = bounds[:, 0]
        max_xyz = bounds[:, 1]
        N_samples = 1024 * 32
        x_vals = torch.rand([sh[0], N_samples], device=bounds.device)
        y_vals = torch.rand([sh[0], N_samples], device=bounds.device)
        z_vals = torch.rand([sh[0], N_samples], device=bounds.device)
        vals = torch.stack([x_vals, y_vals, z_vals], dim=2)
        pts = (max_xyz - min_xyz)[:, None] * vals + min_xyz[:, None]
        return pts

    def forward(self, model, batch, output):
        # TODO: currently only handles the single-person case
        key = 'human_0'
        model = model.model('human_0')
        # bounds: (1, 2, 3)
        bounds = batch[key+'_bounds']
        # random sampling inside the bounds is disabled; use the batch points instead
        if False:
            pts = self.get_sampling_points(bounds)
        else:
            pts = batch[key+'_pts'].reshape(1, -1, 3)
            # sample a subset of the points
            N_sample = 1024*32
            idx = torch.randint(0, pts.shape[1], (N_sample,))
            pts = pts[:, idx]
        # jitter each point uniformly within +/- perturb_surface_pt
        pts_neighbor = pts \
            + (torch.rand(pts.shape, device=pts.device) - 0.5) * 2. \
            * self.perturb_surface_pt
        _, gradients = model.gradient(pts)
        _, gradients_nei = model.gradient(pts_neighbor)
        # encourage normals (normalized density gradients) to vary smoothly
        loss = F.mse_loss(F.normalize(gradients, dim=-1), F.normalize(gradients_nei, dim=-1))
        return loss

class LossOcc(nn.Module):
    def __init__(self, norm) -> None:
        super().__init__()

    def forward(self, inp, out):
        loss = 0
        for key in out.keys():
            if not key.endswith('occ'):
                continue
            diff = inp[key] - out[key]
            loss += torch.mean(diff ** 2)
        return loss

class SmoothT(nn.Module):
    def __init__(self, norm) -> None:
        super().__init__()

    def forward(self, model, batch, output):
        value = model.models['basketball'].tvec
        nframe = batch['meta']['nframe'].item()
        loss0, loss1 = 0, 0
        cnt = 0
        # optimizing all frames jointly makes them collapse onto a single point,
        # so only pull the current frame towards its (detached) neighbors
        if nframe > 0:
            cnt += 1
            loss0 = torch.mean((value[nframe] - value[nframe-1].detach())**2)
        if nframe < value.shape[0]-1:
            cnt += 1
            loss1 = torch.mean((value[nframe] - value[nframe+1].detach())**2)
        # max(cnt, 1) guards against the degenerate single-frame case
        loss = (loss0 + loss1)/max(cnt, 1)
        return loss

class LossDensity(nn.Module):
    def __init__(self, norm) -> None:
        super().__init__()

    def forward(self, inp, out):
        flag = inp['flag'][0]
        inpd = inp['density'][0]
        outd = out['density'][0]
        # mean squared error, computed separately for flagged and unflagged points
        diff0 = torch.mean((inpd[flag] - outd[flag]) ** 2)
        diff1 = torch.mean((inpd[~flag] - outd[~flag]) ** 2)
        loss = diff0 + diff1
        return loss

class LossGround(nn.Module):
    def __init__(self, norm) -> None:
        super().__init__()

    def forward(self, inp, out):
        pts = inp['back_pts'][0]
        mask = pts[..., 2] < 1e-5 # under the ground
        occ_dens = out['occ_back'][mask]
        loss = (1. - occ_dens).mean()
        return loss

class LossEntropy(nn.Module):
    def __init__(self, norm, start) -> None:
        super().__init__()
        self.start = start

    def forward(self, inp, out):
        occ = out['occ_object']
        if inp['step'] < self.start:
            loss = torch.tensor(0.).to(occ.device)
            return loss
        entropy = -occ * torch.log(torch.clamp(occ, min=1e-5))
        loss = entropy.mean()
        return loss

class LossEntropyInstance(nn.Module):
    def __init__(self, norm, start, end) -> None:
        super().__init__()
        self.norm = norm
        self.start = start
        self.end = end

    def forward(self, inp, out):
        instance = out['instance_map'][0]
        loss = torch.tensor(0.).to(instance.device)
        if inp['step'] < self.start or inp['step'] > self.end:
            return loss
        for ikey, key in enumerate(out['keys']):
            if key+'_label' in inp.keys():
                label = inp[key+'_label'][0]
                msk = (inp[key+'_mask'] & (label > 0))[0]
                if msk.sum() > 0:
                    # weighted penalty for labelled pixels not assigned to this instance
                    loss_ = label[msk] * (1 - instance[msk, ikey])
                    loss += loss_.sum()/label[msk].sum()
        return loss

class LossACC(nn.Module):
    def __init__(self, norm) -> None:
        super().__init__()

    def forward(self, inp, out):
        # encourage the accumulated opacity to saturate to 1
        diff = 1. - out['acc_map']
        loss = torch.mean(diff ** 2)
        return loss

class LossSparseEntropy(nn.Module):
    def __init__(self, norm, start, end) -> None:
        super().__init__()
        self.start = start
        self.end = end

    def forward(self, inp, out):
        instance = out['instance_map']
        if inp['step'] < self.start or inp['step'] > self.end:
            loss = torch.tensor(0.).to(instance.device)
            return loss
        entropy = -instance * torch.log(torch.clamp(instance, min=1e-5))
        return entropy.sum(dim=-1).mean()

class LossSemantic1(nn.Module):
    def __init__(self, norm) -> None:
        super().__init__()
        # note: unused; forward computes a custom per-pixel loss instead
        self.loss = nn.CrossEntropyLoss(reduction='none')

    def forward(self, inp, out):
        label_origin = inp['label'][0] # (N, 2): (weight, class) per pixel
        label_valid = label_origin[:, 0] != -1
        semantic_map = out['feat_map'][0][label_valid] # (N, nFeat)
        if semantic_map.shape[0] == 0:
            return torch.tensor(0.).to(semantic_map.device)
        weight = label_origin[:, 0][label_valid]
        label = label_origin[:, 1][label_valid].long()
        index0 = torch.arange(0, semantic_map.shape[0])
        est = semantic_map[index0, label]
        # log penalty once the estimate is confident, linear penalty otherwise
        loss = torch.where(est > 0.5,
            -torch.log(torch.clamp(est, min=1e-5)),
            1. - est) * weight
        loss = loss.mean() / weight.sum()
        return loss

class LossLayer(nn.Module):
    def __init__(self, norm) -> None:
        super().__init__()

    def forward(self, inp, out):
        # step -> weight schedule; the weight of the last threshold passed wins
        # (relies on dict insertion order). An earlier, faster schedule used
        # {0: 0.1, 1000: 0.05, 2000: 0.01, 3000: 0.}.
        weights = {
            0: 0.1,
            5000: 0.05,
            10000: 0.01,
            15000: 0.
        }
        weight = 0
        for key, val in weights.items():
            if inp['step'] > key:
                weight = val
        if weight == 0.:
            loss = torch.tensor(0.).to(out['rgb_map'].device)
            return loss
        loss = 0.
        cnt = 0.
        for key in out['keys']:
            if key + '_label' not in inp.keys():
                continue
            label = inp[key+'_label']
            acc = out[key+'_acc_map']
            # penalize labelled pixels whose per-layer accumulated opacity is low
            loss_ = ((label > 0)*(1 - acc)).sum()
            loss += loss_
            cnt += (label > 0).sum()
        # guard against batches with no labelled pixels
        if cnt == 0:
            return torch.tensor(0.).to(out['rgb_map'].device)
        loss = weight * loss / cnt
        return loss
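
if __name__ == '__main__':
    # Minimal smoke test (not part of the training pipeline; the tensor shapes
    # below are assumptions chosen only to exercise the loss interfaces).
    inp = {'rgb': torch.rand(1, 1024, 3)}
    out = {'rgb_map': torch.rand(1, 1024, 3)}
    print('rgb  l2:', LossRGB(norm='l2')(inp, out).item())
    print('rgb  l1:', LossRGB(norm='l1')(inp, out).item())
    inp = {'human_0_coord_mask': torch.rand(1024) > 0.5}
    out = {'acc_map': torch.rand(1024)}
    print('mask l1:', LossMask(norm='l1')(inp, out).item())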