TrafficWheel/dataloader/loader_selector.py

26 lines
1.1 KiB
Python
Executable File

from dataloader.cde_loader.cdeDataloader import get_dataloader as cde_loader
from dataloader.PeMSDdataloader import get_dataloader as normal_loader
from dataloader.DCRNNdataloader import get_dataloader as DCRNN_loader
from dataloader.EXPdataloader import get_dataloader as EXP_loader
from dataloader.cde_loader.cdeDataloader import get_dataloader as nrde_loader
from dataloader.TSloader import get_dataloader as TS_loader
def get_dataloader(config, normalizer, single):
    """Dispatch to the dataloader factory that matches the configured model.

    The model name is read from ``config["basic"]["model"]``:

    * ``"iTransformer"`` and ``"HI"`` use ``TS_loader``;
    * ``"STGNCDE"`` uses ``cde_loader`` and ``"STGNRDE"`` uses ``nrde_loader``
      (both names are imported from the same ``cdeDataloader`` module);
    * ``"DCRNN"`` uses ``DCRNN_loader`` and ``"EXP"`` uses ``EXP_loader``;
    * any other name falls back to ``normal_loader``.

    Args:
        config: Nested mapping; must contain ``config["basic"]["model"]``
            (a missing key raises ``KeyError``).
        normalizer: Forwarded unchanged to the selected loader.
        single: Forwarded unchanged to the selected loader.

    Returns:
        Whatever the selected loader returns.
    """
    selected_model = config["basic"]["model"]

    # Built per call so the module-level loader aliases are resolved when the
    # function runs, not when this module is imported.
    loader_by_model = {
        "iTransformer": TS_loader,
        "HI": TS_loader,
        "STGNCDE": cde_loader,
        "STGNRDE": nrde_loader,
        "DCRNN": DCRNN_loader,
        "EXP": EXP_loader,
    }

    chosen_loader = loader_by_model.get(selected_model, normal_loader)
    return chosen_loader(config, normalizer, single)