# EasyMocap/apps/neuralbody/train_pl.py
# Training code based on PyTorch-Lightning
import os
from os.path import join
from easymocap.mytools.debug_utils import myerror
import torch
from easymocap.config import load_object, Config
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import seed_everything
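
# Raise the soft limit on open file descriptors; multi-worker dataloaders can
# otherwise fail with "Too many open files" (see the issue linked below).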
# https://github.com/Project-MONAI/MONAI/issues/701
import resource
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))

class plwrapper(pl.LightningModule):
    def __init__(self, cfg, mode='train'):
        super().__init__()
        # load the network and wrap it with the training renderer
        self.cfg = cfg
        self.network = load_object(cfg.network_module, cfg.network_args)
        trainer_args = dict(cfg.trainer_args)
        trainer_args['net'] = self.network
        self.train_renderer = load_object(cfg.trainer_module, trainer_args)
        if mode == 'train' or mode == 'trainvis':
            self.train_dataset = load_object(cfg.data_train_module, cfg.data_train_args)
            # self.val_dataset = load_object(cfg.data_val_module, cfg.data_val_args)
        else:
            # test/demo modes: build the renderer and visualizer, preferring
            # mode-specific entries of the config when they exist
            if mode + '_renderer_module' in cfg.keys():
                module, args = cfg[mode+'_renderer_module'], cfg[mode+'_renderer_args']
            else:
                module, args = cfg.renderer_module, cfg.renderer_args
            self.test_renderer = load_object(module, args, net=self.network)
            if mode + '_visualizer_module' in cfg.keys():
                module, args = cfg[mode+'_visualizer_module'], cfg[mode+'_visualizer_args']
            else:
                module, args = cfg.visualizer_module, cfg.visualizer_args
            self.visualizer = load_object(module, args)

    def forward(self, batch):
        # in lightning, forward defines the prediction/inference actions;
        # here it renders a batch and writes the result through the visualizer
        self.network.train()
        batch['step'] = self.trainer.global_step
        batch['meta']['step'] = self.trainer.global_step
        output = self.test_renderer(batch)
        self.visualizer(output, batch)
        return 0

    def training_step(self, batch, batch_idx):
        batch['step'] = self.trainer.global_step
        batch['meta']['step'] = self.trainer.global_step
        # training_step defines the train loop. It is independent of forward
        output, loss, loss_stats, image_stats = self.train_renderer(batch)
        for key, val in loss_stats.items():
            self.log(key, val)
        return loss

    def train_dataloader(self):
        from easymocap.neuralbody.trainer.dataloader import make_data_sampler, make_batch_data_sampler, make_collator, worker_init_fn
        shuffle = True
        is_distributed = len(self.cfg.gpus) > 1
        is_train = True
        sampler = make_data_sampler(self.cfg, self.train_dataset, shuffle, is_distributed, is_train)
        batch_size = self.cfg.train.batch_size
        drop_last = False
        max_iter = self.cfg.train.ep_iter
        self.batch_sampler = make_batch_data_sampler(self.cfg, sampler, batch_size,
                                                     drop_last, max_iter, is_train)
        num_workers = self.cfg.train.num_workers
        collator = make_collator(self.cfg, is_train)
        data_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_sampler=self.batch_sampler,
            num_workers=num_workers,
            collate_fn=collator,
            worker_init_fn=worker_init_fn)
        return data_loader

    def configure_optimizers(self):
        from easymocap.neuralbody.trainer.optimizer import Optimizer
        from easymocap.neuralbody.trainer.lr_sheduler import Scheduler, set_lr_scheduler
        optimizer = Optimizer(self.network, self.cfg.optimizer)
        scheduler = Scheduler(self.cfg.scheduler, optimizer)
        return [optimizer], [scheduler]

    def on_train_epoch_end(self):
        # keep the distributed sampler in sync across epochs
        if len(self.cfg.gpus) > 1:
            self.batch_sampler.sampler.set_epoch(self.current_epoch)

def train(cfg):
    model = plwrapper(cfg)
    # resume from the last checkpoint if requested and available
    if cfg.resume and os.path.exists(join(cfg.trained_model_dir, 'last.ckpt')):
        resume_from_checkpoint = join(cfg.trained_model_dir, 'last.ckpt')
    else:
        resume_from_checkpoint = None
    if os.path.exists(cfg.recorder_args.log_dir):
        # os.removedirs(cfg.recorder_args.log_dir)
        pass
    os.makedirs(cfg.recorder_args.log_dir, exist_ok=True)
    # dump the resolved config next to the logs for reproducibility
    print(cfg, file=open(join(cfg.recorder_args.log_dir, 'exp.yml'), 'w'))
    logger = TensorBoardLogger(save_dir=cfg.recorder_args.log_dir, name=cfg.exp)
    # 'args' refers to the module-level namespace parsed in __main__
    ckpt_callback = pl.callbacks.ModelCheckpoint(
        verbose=True,
        dirpath=cfg.trained_model_dir,
        every_n_epochs=5 if not args.debug else 1,
        save_last=True,
        save_top_k=-1,
        monitor='loss',
        filename="{epoch}")
    # Log true learning rate, serves as LR-Scheduler callback
    lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step')
    extra_args = {
        # 'num_nodes': len(cfg.gpus),
        'accelerator': 'gpu',
    }
    if len(cfg.gpus) > 0:
        extra_args['strategy'] = 'ddp'
        # keep the custom batch sampler; Lightning would otherwise wrap it
        extra_args['replace_sampler_ddp'] = False
    trainer = pl.Trainer(
        gpus=len(cfg.gpus),
        logger=logger,
        resume_from_checkpoint=resume_from_checkpoint,
        callbacks=[ckpt_callback, lr_monitor],
        max_epochs=cfg.train.epoch,
        # profiler='simple',
        **extra_args
    )
    trainer.fit(model)

def load_ckpt(model, ckpt_path, model_name='network'):
    print('Load from {}'.format(ckpt_path))
    checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
    epoch = checkpoint['epoch']
    if 'state_dict' in checkpoint.keys():
        checkpoint = checkpoint['state_dict']
    # keep only the parameters of the wrapped sub-module and strip its prefix
    checkpoint_ = {}
    for k, v in checkpoint.items():
        if not k.startswith(model_name):
            continue
        k = k[len(model_name)+1:]
        # the empty list is a placeholder for key prefixes that should be skipped
        for prefix in []:
            if k.startswith(prefix):
                break
        else:
            checkpoint_[k] = v
    model.load_state_dict(checkpoint_, strict=False)
    return epoch

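# Note on the key filtering in load_ckpt: a Lightning checkpoint stores the
# wrapped parameters as '<model_name>.<param>' (here 'network.<param>'); the
# loop keeps only those entries and strips the prefix so the keys match the
# bare network's state_dict. The key name below is illustrative, not taken
# from an actual checkpoint:
#     'network.embedder.weight' -> 'embedder.weight'
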
def test(cfg):
    from glob import glob
    from os.path import join
    from tqdm import tqdm
    model = plwrapper(cfg, mode=cfg.split)
    ckptpath = join(cfg.trained_model_dir, 'last.ckpt')
    if os.path.exists(ckptpath):
        epoch = load_ckpt(model.network, ckptpath)
    else:
        myerror('{} does not exist'.format(ckptpath))
        epoch = -1
    model.step = epoch * 1000
    # decide where the visualizer writes its results
    if cfg['output'] == 'none':
        vis_out_dir = join('neuralbody', cfg.exp, cfg.split + '_{}'.format(epoch))
    else:
        vis_out_dir = join('neuralbody', cfg.exp, cfg.output)
    model.visualizer.data_dir = vis_out_dir
    model.visualizer.subs = cfg.data_val_args.subs
    # pick the dataset that matches the requested split
    if cfg.split == 'test' or cfg.split == 'eval':
        dataset = load_object(cfg.data_val_module, cfg.data_val_args)
    elif cfg.split in ['demo', 'canonical', 'novelposes']:
        dataset = load_object(cfg['data_{}_module'.format(cfg.split)], cfg['data_{}_args'.format(cfg.split)])
    elif cfg.split == 'trainvis':
        dataset = model.train_dataset
        dataset.sample_args.nrays *= 16
    ranges = cfg.get('visranges', [0, -1, 1])
    if ranges[1] == -1:
        ranges[1] = len(dataset)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, num_workers=cfg.test.num_workers)
    extra_args = {
        'accelerator': 'gpu',
    }
    if len(cfg.gpus) > 1:
        extra_args['strategy'] = 'ddp'
    trainer = pl.Trainer(
        gpus=len(cfg.gpus),
        max_epochs=cfg.train.epoch,
        **extra_args
    )
    preds = trainer.predict(model, dataloader)

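# Note: trainer.predict drives plwrapper.forward, which renders each batch with
# the test renderer and saves the result through the visualizer; forward returns
# 0, so `preds` only collects placeholders and is not used further.
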
def parse(args, cfg):
    from os.path import join
    cfg.recorder_args.local_rank = cfg.local_rank
    if not args.slurm:
        os.environ['CUDA_VISIBLE_DEVICES'] = ', '.join([str(gpu) for gpu in cfg.gpus])
    assert cfg.exp != "", "Please set the experiment name"
    cfg.trained_model_dir = join('neuralbody', cfg.exp, 'model')
    os.makedirs(cfg.trained_model_dir, exist_ok=True)
    cfg.recorder_args.log_dir = join('neuralbody', cfg.exp, 'record')
    os.makedirs(cfg.recorder_args.log_dir, exist_ok=True)
    # tag describing the validation setting (keyframe / person ids)
    exp = 'vis'
    if 'keyframe' in cfg.data_val_args:
        exp += '_{}'.format(cfg.data_val_args.keyframe)
    if 'pid' in cfg.data_val_args:
        exp += '_{}'.format(cfg.data_val_args.pid)
    if 'pids' in cfg.data_val_args:
        exp += '_{}'.format('+'.join(list(map(str, cfg.data_val_args.pids))))
    if cfg.split == 'eval':
        # evaluation expects the 'camnf' output format without concatenation
        if 'camnf' not in cfg.visualizer_args.format:
            cfg.visualizer_args.format = 'camnf'
            cfg.visualizer_args.concat = 'none'
        cfg.visualizer_args['keys'] = list(cfg.visualizer_args['keys']) + ['rgb', 'instance_map'] + ['raw_depth']
        assert len(cfg.data_val_args.subs) > 0, cfg.data_val_args.subs
        cfg.visualizer_args['subs'] = cfg.data_val_args.subs

if __name__ == "__main__":
    usage = '''This is the training script for Neuralbody'''
    args, cfg = Config.load_args(usage=usage)
    parse(args, cfg)
    if cfg.fix_random:
        seed_everything(666)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    if cfg.split == 'train':
        train(cfg)
    elif cfg.split in ['test', 'demo', 'eval', 'trainvis', 'canonical', 'novelposes']:
        test(cfg)
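
# Usage notes (derived from the dispatch above):
# - cfg.split selects the entry point: 'train' runs train(cfg); 'test', 'demo',
#   'eval', 'trainvis', 'canonical' and 'novelposes' run test(cfg).
# - Checkpoints go to neuralbody/<cfg.exp>/model and TensorBoard logs to
#   neuralbody/<cfg.exp>/record (set in parse()); rendered results go to
#   neuralbody/<cfg.exp>/<split>_<epoch> unless cfg.output names another folder.
# - The exact command-line syntax is whatever Config.load_args accepts and is
#   not documented here.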