106 lines
3.2 KiB
Python
106 lines
3.2 KiB
Python
'''
|
|
@ Date: 2021-09-05 20:12:41
|
|
@ Author: Qing Shuai
|
|
@ LastEditors: Qing Shuai
|
|
@ LastEditTime: 2021-09-05 20:12:42
|
|
@ FilePath: /EasyMocap/easymocap/neuralbody/trainer/net_load.py
|
|
'''
|
|
import os
import shutil

import torch
from termcolor import colored
|
|
|
|
def load_model(net,
               optim,
               scheduler,
               recorder,
               model_dir,
               resume=True,
               epoch=-1):
    """Restore a full training state (network, optimizer, scheduler, recorder).

    Checkpoints live in ``model_dir`` as ``<epoch>.pth`` files plus an
    optional ``latest.pth``. Each one is a dict with the keys ``'net'``,
    ``'optim'``, ``'scheduler'``, ``'recorder'`` and ``'epoch'``.

    Args:
        net, optim, scheduler, recorder: objects exposing ``load_state_dict``;
            each receives the matching entry of the checkpoint dict.
        model_dir: directory holding the checkpoints.
        resume: when False, ``model_dir`` is deleted and 0 is returned so
            training starts from scratch.
        epoch: checkpoint to load. ``-1`` picks ``latest.pth`` if present,
            otherwise the highest-numbered ``<epoch>.pth``.

    Returns:
        The epoch to continue from (stored ``'epoch'`` + 1), or 0 when
        nothing was loaded.
    """
    if not resume:
        # Drop previous checkpoints in-process instead of via ``rm -rf`` so
        # paths containing spaces/shell metacharacters are handled safely and
        # a failed removal raises instead of silently resuming.
        if os.path.islink(model_dir) or os.path.isfile(model_dir):
            os.remove(model_dir)
        elif os.path.isdir(model_dir):
            shutil.rmtree(model_dir)
        return 0

    if not os.path.exists(model_dir):
        return 0

    entries = os.listdir(model_dir)
    has_latest = 'latest.pth' in entries
    # Only ``<int>.pth`` files are numbered checkpoints; anything else in the
    # directory (hidden files, logs, ...) is ignored rather than crashing int().
    stems = [os.path.splitext(name) for name in entries]
    epochs = [int(stem) for stem, ext in stems
              if ext == '.pth' and stem.isdecimal()]
    if not epochs and not has_latest:
        return 0

    if epoch == -1:
        pth = 'latest' if has_latest else max(epochs)
    else:
        pth = epoch
    model_path = os.path.join(model_dir, '{}.pth'.format(pth))
    print('load model: {}'.format(model_path))
    # Tensors are mapped to the CPU regardless of the device they were saved on.
    pretrained_model = torch.load(model_path, map_location='cpu')
    net.load_state_dict(pretrained_model['net'])
    optim.load_state_dict(pretrained_model['optim'])
    scheduler.load_state_dict(pretrained_model['scheduler'])
    recorder.load_state_dict(pretrained_model['recorder'])
    return pretrained_model['epoch'] + 1
|
|
|
|
|
|
def save_model(net, optim, scheduler, recorder, model_dir, epoch, last=False,
               max_keep=None):
    """Write a training checkpoint into ``model_dir``.

    Epochs up to 20 are always saved. After that, only epochs where
    ``epoch + 1`` is a multiple of 10 are saved, unless ``last`` is set.
    With ``last`` the state goes to ``latest.pth``; otherwise it goes to
    ``<epoch>.pth``.

    Args:
        net, optim, scheduler, recorder: objects exposing ``state_dict``.
        model_dir: output directory; created (with parents) if missing.
        epoch: current epoch, stored under the ``'epoch'`` key.
        last: write ``latest.pth`` instead of a numbered checkpoint and
            bypass the epoch filter.
        max_keep: if not None, after saving, delete the oldest
            ``<epoch>.pth`` files so that at most ``max_keep`` numbered
            checkpoints remain (``latest.pth`` is never deleted). The default
            ``None`` keeps every checkpoint.

    Returns:
        0
    """
    # os.makedirs instead of ``mkdir -p`` via the shell: safe for any path.
    os.makedirs(model_dir, exist_ok=True)
    # Past epoch 20 only every 10th epoch is written (presumably to bound
    # disk usage). Checked before collecting state_dicts to avoid that work.
    if epoch > 20 and (epoch + 1) % 10 != 0 and not last:
        return 0

    model = {
        'net': net.state_dict(),
        'optim': optim.state_dict(),
        'scheduler': scheduler.state_dict(),
        'recorder': recorder.state_dict(),
        'epoch': epoch
    }
    filename = 'latest.pth' if last else '{}.pth'.format(epoch)
    torch.save(model, os.path.join(model_dir, filename))

    if max_keep is not None:
        # Only ``<int>.pth`` files count as numbered checkpoints.
        stems = [os.path.splitext(name) for name in os.listdir(model_dir)]
        epochs = sorted(int(stem) for stem, ext in stems
                        if ext == '.pth' and stem.isdecimal())
        for old in epochs[:max(len(epochs) - max_keep, 0)]:
            os.remove(os.path.join(model_dir, '{}.pth'.format(old)))
    return 0
|
|
|
|
def load_network(net, model_dir, resume=True, epoch=-1, strict=True):
    """Load only the network weights from a checkpoint.

    Args:
        net: object exposing ``load_state_dict(state, strict=...)``; receives
            the checkpoint's ``'net'`` entry.
        model_dir: either a directory of ``<epoch>.pth`` / ``latest.pth``
            checkpoints or the path of a single checkpoint file.
        resume: when False nothing is loaded and 0 is returned.
        epoch: checkpoint to pick when ``model_dir`` is a directory. ``-1``
            selects ``latest.pth`` if present, otherwise the highest-numbered
            ``<epoch>.pth``.
        strict: forwarded to ``net.load_state_dict``.

    Returns:
        The stored ``'epoch'`` + 1, or 0 when nothing was loaded.
    """
    if not resume:
        return 0

    if not os.path.exists(model_dir):
        print(colored('pretrained model does not exist', 'red'))
        return 0

    if os.path.isdir(model_dir):
        entries = os.listdir(model_dir)
        has_latest = 'latest.pth' in entries
        # Only ``<int>.pth`` files are numbered checkpoints; other files in
        # the directory are ignored rather than crashing int().
        stems = [os.path.splitext(name) for name in entries]
        epochs = [int(stem) for stem, ext in stems
                  if ext == '.pth' and stem.isdecimal()]
        if not epochs and not has_latest:
            return 0
        if epoch == -1:
            pth = 'latest' if has_latest else max(epochs)
        else:
            # explicit epoch; values below -1 are clamped to -1
            pth = max(epoch, -1)
        model_path = os.path.join(model_dir, '{}.pth'.format(pth))
    else:
        model_path = model_dir

    print('load model: {}'.format(model_path))
    # Map to the CPU (as load_model does) so GPU-saved checkpoints also load
    # on machines without that GPU.
    pretrained_model = torch.load(model_path, map_location='cpu')
    net.load_state_dict(pretrained_model['net'], strict=strict)
    return pretrained_model['epoch'] + 1
|
|
|