# EasyMocap/easymocap/neuralbody/trainer/net_loader.py
# 2022-10-25 20:06:04 +08:00
#
# 106 lines
# 3.2 KiB
# Python

'''
@ Date: 2021-09-05 20:12:41
@ Author: Qing Shuai
@ LastEditors: Qing Shuai
@ LastEditTime: 2021-09-05 20:12:42
  @ FilePath: /EasyMocap/easymocap/neuralbody/trainer/net_loader.py
'''
import os
import shutil

import torch
from termcolor import colored
def load_model(net,
               optim,
               scheduler,
               recorder,
               model_dir,
               resume=True,
               epoch=-1):
    """Restore a full training state (network, optimizer, scheduler, recorder).

    Checkpoints live in ``model_dir`` as ``<epoch>.pth`` files and/or
    ``latest.pth``; each is a dict with ``'net'``, ``'optim'``,
    ``'scheduler'``, ``'recorder'`` and ``'epoch'`` entries.

    Args:
        net, optim, scheduler, recorder: objects whose ``load_state_dict``
            receives the matching checkpoint entry.
        model_dir: checkpoint directory.
        resume: if False, ``model_dir`` is deleted and nothing is loaded.
        epoch: epoch to load; -1 prefers ``latest.pth`` and otherwise the
            highest-numbered ``<epoch>.pth``.

    Returns:
        The epoch to resume from (stored ``epoch`` + 1), or 0 if nothing
        was loaded.

    Raises:
        FileNotFoundError: (from ``torch.load``) if an explicitly requested
            ``epoch`` has no checkpoint file.
    """
    if not resume:
        # Start from scratch by removing old checkpoints. Done in-process
        # instead of ``rm -rf`` through a shell, which split paths on spaces
        # and executed any shell metacharacters contained in ``model_dir``.
        if os.path.islink(model_dir) or os.path.isfile(model_dir):
            os.remove(model_dir)
        elif os.path.isdir(model_dir):
            shutil.rmtree(model_dir)
    if not os.path.exists(model_dir):
        return 0
    names = os.listdir(model_dir)
    has_latest = 'latest.pth' in names
    # Only '<int>.pth' files are numbered checkpoints; any other file in the
    # directory previously made the int() parse raise ValueError.
    epochs = [int(stem) for stem, ext in map(os.path.splitext, names)
              if ext == '.pth' and stem.isdecimal()]
    if not epochs and not has_latest:
        return 0
    if epoch == -1:
        pth = 'latest' if has_latest else max(epochs)
    else:
        pth = epoch
    model_path = os.path.join(model_dir, '{}.pth'.format(pth))
    print('load model: {}'.format(model_path))
    pretrained_model = torch.load(model_path, 'cpu')
    net.load_state_dict(pretrained_model['net'])
    optim.load_state_dict(pretrained_model['optim'])
    scheduler.load_state_dict(pretrained_model['scheduler'])
    recorder.load_state_dict(pretrained_model['recorder'])
    return pretrained_model['epoch'] + 1
def save_model(net, optim, scheduler, recorder, model_dir, epoch, last=False,
               max_keep=None):
    """Save a full training state to ``model_dir``.

    Numbered checkpoints ``<epoch>.pth`` are written for every epoch up to
    20, and afterwards only when ``epoch + 1`` is a multiple of 10. With
    ``last=True`` the state is always written to ``latest.pth``.

    Args:
        net, optim, scheduler, recorder: objects exposing ``state_dict()``.
        model_dir: output directory; created if missing.
        epoch: current epoch index, stored in the checkpoint.
        last: write ``latest.pth`` instead of a numbered checkpoint.
        max_keep: if not None, after saving delete the lowest-numbered
            ``<epoch>.pth`` files until at most ``max_keep`` remain.
            Default None keeps everything, which matches the previous
            behaviour: the old pruning code sat after an unconditional
            ``return`` and never ran.

    Returns:
        0, kept for compatibility with existing callers.
    """
    os.makedirs(model_dir, exist_ok=True)
    # Thin out numbered checkpoints for late epochs to limit disk usage.
    # Checked before collecting the state dicts so skipped epochs cost nothing.
    if epoch > 20 and (epoch + 1) % 10 != 0 and not last:
        return 0
    model = {
        'net': net.state_dict(),
        'optim': optim.state_dict(),
        'scheduler': scheduler.state_dict(),
        'recorder': recorder.state_dict(),
        'epoch': epoch
    }
    filename = 'latest.pth' if last else '{}.pth'.format(epoch)
    torch.save(model, os.path.join(model_dir, filename))
    if max_keep is not None:
        # Only '<int>.pth' files are numbered checkpoints; others are ignored.
        epochs = sorted(
            int(stem)
            for stem, ext in map(os.path.splitext, os.listdir(model_dir))
            if ext == '.pth' and stem.isdecimal())
        for old in epochs[:max(len(epochs) - max_keep, 0)]:
            os.remove(os.path.join(model_dir, '{}.pth'.format(old)))
    return 0
def load_network(net, model_dir, resume=True, epoch=-1, strict=True):
    """Load only the network weights from a checkpoint.

    Args:
        net: module that receives ``checkpoint['net']`` via ``load_state_dict``.
        model_dir: either a checkpoint directory holding ``<epoch>.pth`` /
            ``latest.pth`` files, or the path of a single checkpoint file.
        resume: if False nothing is loaded and 0 is returned.
        epoch: epoch to load from a directory; -1 prefers ``latest.pth`` and
            otherwise the highest-numbered ``<epoch>.pth``.
        strict: forwarded to ``net.load_state_dict``.

    Returns:
        ``checkpoint['epoch'] + 1``, or 0 when nothing was loaded.
    """
    if not resume:
        return 0
    if not os.path.exists(model_dir):
        print(colored('pretrained model does not exist', 'red'))
        return 0
    if os.path.isdir(model_dir):
        names = os.listdir(model_dir)
        has_latest = 'latest.pth' in names
        # Only '<int>.pth' files are numbered checkpoints; any other file in
        # the directory previously made the int() parse raise ValueError.
        epochs = [int(stem) for stem, ext in map(os.path.splitext, names)
                  if ext == '.pth' and stem.isdecimal()]
        if not epochs and not has_latest:
            return 0
        if epoch == -1:
            pth = 'latest' if has_latest else max(epochs)
        else:
            pth = epoch
        model_path = os.path.join(model_dir, '{}.pth'.format(pth))
    else:
        model_path = model_dir
    print('load model: {}'.format(model_path))
    # Map to CPU so GPU-saved checkpoints also load on CPU-only machines;
    # load_state_dict copies the tensors into net's own parameters.
    pretrained_model = torch.load(model_path, map_location='cpu')
    net.load_state_dict(pretrained_model['net'], strict=strict)
    return pretrained_model['epoch'] + 1