TrafficWheel/dataloader/loader_selector.py

20 lines
800 B
Python
Executable File

from dataloader.cde_loader.cdeDataloader import get_dataloader as cde_loader
from dataloader.PeMSDdataloader import get_dataloader as normal_loader
from dataloader.DCRNNdataloader import get_dataloader as DCRNN_loader
from dataloader.EXPdataloader import get_dataloader as EXP_loader
from dataloader.cde_loader.cdeDataloader import get_dataloader as nrde_loader
from dataloader.TSloader import get_dataloader as TS_loader
from dataloader.Informer_loader import get_dataloader as Informer_loader
def get_dataloader(config, normalizer, single):
    """Return the dataloader(s) built by the factory registered for the model.

    The model name is read from ``config["basic"]["model"]``. Models listed
    in the dispatch table below use their dedicated loader factory; every
    other model name falls back to ``normal_loader``. The selected factory
    is called with ``(config, normalizer, single)`` and its result is
    returned unchanged.

    Args:
        config: Nested mapping; must provide ``config["basic"]["model"]``.
        normalizer: Forwarded as-is to the selected loader factory.
        single: Forwarded as-is to the selected loader factory.

    Returns:
        Whatever the selected loader factory returns.

    Raises:
        KeyError: If ``config`` lacks the ``"basic"`` or ``"model"`` key.
    """
    # NOTE(review): ``nrde_loader`` and ``cde_loader`` are both bound to
    # cdeDataloader.get_dataloader in this file's imports, and ``TS_loader`` /
    # ``Informer_loader`` are imported but not dispatched here -- confirm this
    # is intentional.
    dispatch = dict(
        STGNCDE=cde_loader,
        STGNRDE=nrde_loader,
        DCRNN=DCRNN_loader,
        EXP=EXP_loader,
    )
    model_name = config["basic"]["model"]
    factory = dispatch.get(model_name, normal_loader)
    return factory(config, normalizer, single)