EasyMocap/myeasymocap/operations/optimizer.py

167 lines
6.7 KiB
Python
Raw Normal View History

2023-06-19 16:39:27 +08:00
import torch
import torch.nn as nn
from easymocap.config import Config, load_object
from easymocap.mytools.debug_utils import log
def dict_of_numpy_to_tensor(body_params, device):
    """Recursively convert the leaves of a (possibly nested) dict to tensors.

    Args:
        body_params: dict whose values are either nested dicts or values
            accepted by the ``torch.Tensor(...)`` constructor.
        device: device passed to ``Tensor.to`` for every leaf.

    Returns:
        A new dict with the same keys and nesting; every non-dict value is
        replaced by ``torch.Tensor(value).to(device)``.
    """
    return {
        name: (
            dict_of_numpy_to_tensor(value, device)
            if isinstance(value, dict)
            else torch.Tensor(value).to(device)
        )
        for name, value in body_params.items()
    }
def dict_of_tensor_to_numpy(body_params):
    """Recursively convert the tensor leaves of a (possibly nested) dict to numpy.

    Args:
        body_params: dict whose values are either nested dicts or torch tensors.

    Returns:
        A new dict with the same keys and nesting; every tensor is replaced by
        a numpy array holding its values on the CPU.
    """
    params_ = {}
    for key, val in body_params.items():
        if isinstance(val, dict):
            params_[key] = dict_of_tensor_to_numpy(val)
        else:
            # ``Tensor.numpy()`` raises for tensors that still require grad,
            # so detach from the autograd graph before converting.
            params_[key] = val.detach().cpu().numpy()
    return params_
def make_optimizer(opt_params, optim_type='lbfgs', max_iter=20,
                   lr=1e-3, betas=(0.9, 0.999), weight_decay=0.0, **kwargs):
    """Build the optimizer selected by ``optim_type``.

    Args:
        opt_params: tensors to optimize, given as a list or as a dict whose
            values are the tensors.
        optim_type: ``'lbfgs'`` or ``'adam'``.
        max_iter: forwarded to LBFGS only.
        lr: forwarded to Adam only.
        betas: forwarded to Adam only.
        weight_decay: forwarded to Adam only.
        **kwargs: extra keyword arguments forwarded to LBFGS only.

    Returns:
        The constructed optimizer.

    Raises:
        NotImplementedError: if ``optim_type`` is neither of the above.
    """
    # LBFGS does not accept a dict of parameters, so keep only the values.
    tensors = list(opt_params.values()) if isinstance(opt_params, dict) else opt_params

    if optim_type == 'adam':
        return torch.optim.Adam(tensors, lr=lr, betas=betas, weight_decay=weight_decay)
    if optim_type != 'lbfgs':
        raise NotImplementedError

    # The project's own LBFGS is used here in place of torch.optim.LBFGS.
    from easymocap.pyfitting.lbfgs import LBFGS
    return LBFGS(
        tensors,
        line_search_fn='strong_wolfe',
        max_iter=max_iter,
        tolerance_grad=0.0000001,  # float32 carries about 7 significant digits
        tolerance_change=0.0000001,
        **kwargs)
def grad_require(params, flag=False):
    """Set ``requires_grad`` to ``flag`` on every element of ``params``.

    Args:
        params: a list of objects, or a dict whose values are the objects.
            Any other container type is left untouched.
        flag: value assigned to each object's ``requires_grad`` attribute.
    """
    if isinstance(params, list):
        targets = params
    elif isinstance(params, dict):
        targets = params.values()
    else:
        # Neither a list nor a dict: nothing to do.
        return
    for tensor in targets:
        tensor.requires_grad = flag
def make_closure(optimizer, model, params, infos, loss, device):
    """Create the closure evaluated by the optimizer at every step.

    Args:
        optimizer: object whose ``zero_grad()`` is called at the start of
            every evaluation.
        model: callable that receives a shallow copy of ``params`` and
            returns a mapping of outputs.
        params: dict passed (as a copy) to ``model``.
        infos: mapping from which each loss term picks its extra inputs.
        loss: mapping ``name -> config``. Each config is indexed with
            ``'module'`` and ``'args'`` to instantiate the term and exposes
            the attributes ``weight``, ``key_from_output`` and
            ``key_from_infos``.
        device: device that ``nn.Module`` loss terms are moved to.

    Returns:
        ``closure(debug=False)``. With ``debug=True`` it returns
        ``(loss_dict, loss_weight)`` without back-propagating; otherwise it
        back-propagates the weighted sum and returns it.
    """
    # Instantiate every loss term once, up front.
    loss_func = {}
    for name, cfg in loss.items():
        term = load_object(cfg['module'], cfg['args'])
        if isinstance(term, nn.Module):
            term.to(device)
        loss_func[name] = term

    def closure(debug=False):
        optimizer.zero_grad()
        output = model(params.copy())

        loss_dict = {}
        # Weights are read from the config on every call.
        loss_weight = {name: loss[name].weight for name in loss_func}
        for name, term in loss_func.items():
            cfg = loss[name]
            selected_output = {k: output[k] for k in cfg.key_from_output}
            selected_infos = {k: infos[k] for k in cfg.key_from_infos}
            value = term(selected_output, selected_infos)
            if not isinstance(value, dict):
                loss_dict[name] = value
                continue
            # A term may return several named sub-losses; each inherits the
            # weight of its parent, whose own entry is then removed.
            for sub_name, sub_value in value.items():
                full_name = name + '_' + sub_name
                loss_dict[full_name] = sub_value
                loss_weight[full_name] = loss_weight[name]
            loss_weight.pop(name)

        weighted_terms = [value * loss_weight[name]
                          for name, value in loss_dict.items()]
        loss_sum = sum(weighted_terms)
        if debug:
            return loss_dict, loss_weight
        loss_sum.backward()
        return loss_sum

    return closure
def rel_change(prev_val, curr_val):
    """Return the signed relative change from ``prev_val`` to ``curr_val``.

    The difference ``prev_val - curr_val`` is divided by the largest of
    ``1e-5``, ``|prev_val|`` and ``|curr_val|``; the floor keeps the
    denominator away from zero. The result is positive when the value
    decreased.
    """
    delta = prev_val - curr_val
    scale = max(1e-5, abs(prev_val), abs(curr_val))
    return delta / scale
class Optimizer:
    """Optimize selected parameters against a configured set of loss terms."""

    def __init__(self, optimize_keys, optimizer_args, loss) -> None:
        """Store the configuration and collect the info keys the losses need.

        Args:
            optimize_keys: names of the variables to optimize.
            optimizer_args: keyword arguments forwarded to ``make_optimizer``.
            loss: mapping ``name -> config``; each config exposes
                ``key_from_infos``, an iterable of info keys.
        """
        self.optimize_keys = optimize_keys
        self.optimizer_args = optimizer_args
        self.loss = loss
        # Union of every info key requested by any loss term (unordered).
        self.used_infos = list(
            {key for cfg in loss.values() for key in cfg.key_from_infos})

    def log_loss(self, iter_, closure, print_loss=False):
        """Print the weighted loss terms every 10th iteration, or when forced.

        Args:
            iter_: iteration index shown at the start of the line.
            closure: callable returning ``(loss_dict, loss_weight)`` when
                called with ``debug=True``.
            print_loss: print regardless of ``iter_``.
        """
        if iter_ % 10 != 0 and not print_loss:
            return
        with torch.no_grad():
            loss_dict, loss_weight = closure(debug=True)
        header = '{:-6d}: '.format(iter_)
        terms = [name + ' %7.4f' % (value.item() * loss_weight[name])
                 for name, value in loss_dict.items()]
        print(header + ' '.join(terms))

    def optimizer_step(self, optimizer, closure):
        """Run up to 999 optimizer steps, stopping early when warranted.

        The loop ends when the loss is NaN or infinite, or when the signed
        relative change with respect to the previous loss is at most 1e-7.

        Returns:
            Always ``True``.
        """
        self.log_loss(0, closure, True)
        last_value = None
        for iter_ in range(1, 1000):
            loss = optimizer.step(closure)
            # Abort on a loss that is no longer a finite number.
            if torch.isnan(loss).any():
                print('[optimize] NaN loss value, stopping!')
                break
            if torch.isinf(loss).any():
                print('[optimize] Infinite loss value, stopping!')
                break
            current_value = loss.item()
            # Stop once the loss no longer decreases by a meaningful amount.
            if last_value is not None:
                if rel_change(last_value, current_value) <= 0.0000001:
                    break
            self.log_loss(iter_, closure)
            last_value = current_value
        self.log_loss(iter_, closure, True)
        return True

    def __call__(self, params, model, **infos):
        """Optimize ``optimize_keys`` and return the results as numpy arrays.

        Each key to optimize is taken from ``infos`` when present there and
        otherwise from ``params``. Entries of ``params`` or ``infos`` that
        are not listed in ``optimize_keys`` are not optimized.

        Args:
            params: (possibly nested) dict of values convertible to tensors.
            model: callable mapping the parameter dict to the outputs that
                the loss terms consume.
            **infos: extra inputs; only the keys requested by the loss terms
                are converted and used.

        Returns:
            A dict with ``'params'`` plus one entry for every optimized key
            that came from ``infos``, all converted back to numpy.

        Raises:
            ValueError: if a key to optimize is in neither ``infos`` nor
                ``params``.
        """
        # TODO: should use the model's device, but the model may be a plain
        # function, so the device is determined on the spot for now.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        params = dict_of_numpy_to_tensor(params, device=device)
        requested = {name: infos[name] for name in self.used_infos if name in infos}
        infos_used = dict_of_numpy_to_tensor(requested, device=device)

        owner = type(self).__name__
        log('[{}] Optimize {}'.format(owner, self.optimize_keys))
        log('[{}] Loading {}'.format(owner, self.used_infos))

        opt_params = {}
        for name in self.optimize_keys:
            if name in infos:
                # The variable to optimize was supplied through infos.
                opt_params[name] = infos_used[name]
            elif name in params:
                opt_params[name] = params[name]
            else:
                raise ValueError('{} is not in infos or body_params'.format(name))

        # Snapshot the starting values under 'init_<key>' for the loss terms.
        infos_used.update(
            {'init_' + name: tensor.clone() for name, tensor in opt_params.items()})

        optimizer = make_optimizer(opt_params, **self.optimizer_args)
        closure = make_closure(optimizer, model, params, infos_used, self.loss, device)

        # Gradients are enabled only for the duration of the optimization.
        grad_require(opt_params, True)
        self.optimizer_step(optimizer, closure)
        grad_require(opt_params, False)

        ret = {'params': params}
        ret.update(
            {name: opt_params[name] for name in self.optimize_keys if name in infos})
        return dict_of_tensor_to_numpy(ret)