from dataloader.cde_loader.cdeDataloader import get_dataloader as cde_loader
from dataloader.PeMSDdataloader import get_dataloader as normal_loader
from dataloader.DCRNNdataloader import get_dataloader as DCRNN_loader
from dataloader.EXPdataloader import get_dataloader as EXP_loader
from dataloader.cde_loader.cdeDataloader import get_dataloader as nrde_loader
from dataloader.TSloader import get_dataloader as TS_loader


def get_dataloader(config, normalizer, single):
    """Dispatch to the model-specific dataloader factory.

    The model name is read from ``config["basic"]["model"]`` and selects one
    of the imported ``get_dataloader`` implementations. All three arguments
    are forwarded unchanged, and the selected factory's return value is
    returned as-is.

    Args:
        config: Nested mapping; must contain ``config["basic"]["model"]``.
        normalizer: Passed through to the chosen loader (its meaning is
            defined by the loaders, not here).
        single: Passed through to the chosen loader.

    Returns:
        Whatever the selected loader returns.

    Raises:
        KeyError: If ``config`` lacks ``"basic"`` or ``"basic"`` lacks
            ``"model"`` (assuming ``config`` is a dict).
    """
    chosen = config["basic"]["model"]

    if chosen in ("iTransformer", "HI"):
        # Time-series models share the TS loader.
        factory = TS_loader
    elif chosen == "STGNCDE":
        factory = cde_loader
    elif chosen == "STGNRDE":
        # Note: nrde_loader is imported from the same module/function as
        # cde_loader, so STGNRDE currently shares the CDE loading path.
        factory = nrde_loader
    elif chosen == "DCRNN":
        factory = DCRNN_loader
    elif chosen == "EXP":
        factory = EXP_loader
    else:
        # Every other model name falls back to the PeMSD loader.
        factory = normal_loader

    return factory(config, normalizer, single)