'''
  @ Date: 2021-07-20 12:32:29
  @ Author: Qing Shuai
  @ LastEditors: Qing Shuai
  @ LastEditTime: 2021-09-05 20:19:11
  @ FilePath: /EasyMocap/easymocap/neuralbody/trainer/dataloader.py
'''
from easymocap.config.baseconfig import load_object
import torch


def make_data_sampler(cfg, dataset, shuffle, is_distributed, is_train):
    # At test time, optionally sub-sample frames with the dedicated FrameSampler.
    if not is_train and cfg.test.sampler == 'FrameSampler':
        from .samplers import FrameSampler
        sampler = FrameSampler(dataset)
        return sampler
    # Distributed training: each rank draws its own shard of the dataset.
    if is_distributed:
        from .samplers import DistributedSampler
        return DistributedSampler(dataset, shuffle=shuffle)
    if shuffle:
        sampler = torch.utils.data.sampler.RandomSampler(dataset)
    else:
        sampler = torch.utils.data.sampler.SequentialSampler(dataset)
    return sampler


def make_batch_data_sampler(cfg, sampler, batch_size, drop_last, max_iter,
                            is_train):
    if is_train:
        batch_sampler = cfg.train.batch_sampler
    else:
        batch_sampler = cfg.test.batch_sampler

    if batch_sampler == 'default':
        batch_sampler = torch.utils.data.sampler.BatchSampler(
            sampler, batch_size, drop_last)
    elif batch_sampler == 'image_size':
        raise NotImplementedError

    # Wrap the batch sampler so that one epoch yields exactly max_iter batches.
    if max_iter != -1:
        from .samplers import IterationBasedBatchSampler
        batch_sampler = IterationBasedBatchSampler(
            batch_sampler, max_iter)
    return batch_sampler


def worker_init_fn(worker_id):
    # Per-worker random seeding is currently disabled; the function is kept as
    # a no-op so the DataLoader call below stays unchanged.
    import numpy as np
    import time
    # np.random.seed(worker_id + (int(round(time.time() * 1000) % (2**16))))


def make_collator(cfg, is_train):
    # Registry of custom collate functions; fall back to PyTorch's default_collate.
    _collators = {
    }
    from torch.utils.data.dataloader import default_collate
    collator = cfg.train.collator if is_train else cfg.test.collator
    if collator in _collators:
        return _collators[collator]
    else:
        return default_collate


def Dataloader(cfg, split='train', is_train=True, start=0):
    is_distributed = cfg.distributed
    # Batching options: training follows cfg.train, everything else cfg.test.
    if split == 'train' and is_train:
        batch_size = cfg.train.batch_size
        max_iter = cfg.train.ep_iter
        # shuffle = True
        shuffle = cfg.train.shuffle
        drop_last = False
    else:
        batch_size = cfg.test.batch_size
        shuffle = True if is_distributed else False
        drop_last = False
        max_iter = -1
    # Build the dataset for the requested split.
    if split == 'train' and is_train:
        dataset = load_object(cfg.data_train_module, cfg.data_train_args)
    elif split == 'train' and not is_train:
        # Evaluation on the training data: reuse the train dataset config in test mode.
        cfg.data_train_args.split = 'test'
        dataset = load_object(cfg.data_train_module, cfg.data_train_args)
    elif split in ['test', 'val']:
        dataset = load_object(cfg.data_val_module, cfg.data_val_args)
    elif split == 'demo':
        dataset = load_object(cfg.data_demo_module, cfg.data_demo_args)
    elif split == 'mesh':
        dataset = load_object(cfg.data_mesh_module, cfg.data_mesh_args)
    else:
        raise NotImplementedError
    is_train = (split == 'train') and is_train
    sampler = make_data_sampler(cfg, dataset, shuffle, is_distributed, is_train)
    batch_sampler = make_batch_data_sampler(cfg, sampler, batch_size,
                                            drop_last, max_iter, is_train)
    num_workers = cfg.train.num_workers if is_train else cfg.test.num_workers
    collator = make_collator(cfg, is_train)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_sampler=batch_sampler,
                                              num_workers=num_workers,
                                              collate_fn=collator,
                                              worker_init_fn=worker_init_fn)

    return data_loader
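
# ---------------------------------------------------------------------------
# Illustrative usage (a sketch, not part of the original file): `Dataloader`
# is typically called from the trainer with an experiment config that exposes
# `distributed`, the `train`/`test` option groups, and the `data_*_module` /
# `data_*_args` entries referenced above. The `cfg` below is a placeholder.
#
#     cfg = ...  # loaded experiment config (see easymocap.config.baseconfig)
#     train_loader = Dataloader(cfg, split='train', is_train=True)
#     val_loader = Dataloader(cfg, split='val', is_train=False)
#     for batch in train_loader:
#         ...  # one training iteration per batch
# ---------------------------------------------------------------------------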