TrafficWheel/dataloader/loader_selector.py

14 lines
673 B
Python
Executable File

from dataloader.cde_loader.cdeDataloader import get_dataloader as cde_loader
from dataloader.PeMSDdataloader import get_dataloader as normal_loader
from dataloader.DCRNNdataloader import get_dataloader as DCRNN_loader
from dataloader.EXPdataloader import get_dataloader as EXP_loader
def get_dataloader(config, normalizer, single):
    """Choose the dataloader factory for the configured model and call it.

    The model type is read from ``config['model']['type']``:

    * ``'STGNCDE'`` -> ``cde_loader``
    * ``'DCRNN'``   -> ``DCRNN_loader``
    * ``'EXP'``     -> ``EXP_loader``
    * anything else -> ``normal_loader``

    The selected factory is called as
    ``loader(config['data'], normalizer, single)``, and its result is
    returned unchanged. ``normalizer`` and ``single`` are forwarded as-is;
    their meaning is defined by the individual loaders, not here.

    Assumes ``config`` is a nested mapping with ``'model'`` -> ``'type'``
    and ``'data'`` keys; a missing key propagates whatever error the
    mapping's ``__getitem__`` raises (``KeyError`` for a plain dict).
    """
    model_type = config['model']['type']

    # Ordered (model type, factory) pairs. Entries are compared with ``==``,
    # as the literal patterns of the original ``match`` statement were.
    special_loaders = (
        ('STGNCDE', cde_loader),
        ('DCRNN', DCRNN_loader),
        ('EXP', EXP_loader),
    )
    loader = next(
        (factory for name, factory in special_loaders if model_type == name),
        normal_loader,  # default for every other model type
    )
    return loader(config['data'], normalizer, single)