from dataloader.cde_loader.cdeDataloader import get_dataloader as cde_loader
from dataloader.PeMSDdataloader import get_dataloader as normal_loader
from dataloader.DCRNNdataloader import get_dataloader as DCRNN_loader
from dataloader.EXPdataloader import get_dataloader as EXP_loader
from dataloader.cde_loader.cdeDataloader import get_dataloader as nrde_loader
from dataloader.TSloader import get_dataloader as TS_loader
from dataloader.Informer_loader import get_dataloader as Informer_loader


def get_dataloader(config, normalizer, single):
    """Dispatch to the dataloader factory that matches the configured model.

    The model name is read from ``config["basic"]["model"]`` and mapped to
    one of the loader factories imported at the top of this module:

    * ``"Informer"``                        -> ``Informer_loader``
    * ``"iTransformer"``, ``"HI"``, ``"PatchTST"`` -> ``TS_loader``
    * ``"STGNCDE"``                         -> ``cde_loader``
    * ``"STGNRDE"``                         -> ``nrde_loader``
    * ``"DCRNN"``                           -> ``DCRNN_loader``
    * ``"EXP"``                             -> ``EXP_loader``
    * anything else                         -> ``normal_loader``

    Note that ``cde_loader`` and ``nrde_loader`` are two aliases of the same
    function imported from ``dataloader.cde_loader.cdeDataloader``.

    Args:
        config: Mapping that provides ``config["basic"]["model"]``; it is
            forwarded unchanged to the selected loader.
        normalizer: Forwarded unchanged to the selected loader.
        single: Forwarded unchanged to the selected loader.

    Returns:
        Whatever the selected loader factory returns.

    Raises:
        KeyError: If ``config`` lacks the ``"basic"`` or ``"model"`` key
            (assuming ``config`` is a plain nested dict).
    """
    chosen_model = config["basic"]["model"]

    # Pick the factory first, then invoke it exactly once at the end.
    if chosen_model == "Informer":
        build_loader = Informer_loader
    elif chosen_model in ("iTransformer", "HI", "PatchTST"):
        # Models that share the time-series loader.
        build_loader = TS_loader
    elif chosen_model == "STGNCDE":
        build_loader = cde_loader
    elif chosen_model == "STGNRDE":
        build_loader = nrde_loader
    elif chosen_model == "DCRNN":
        build_loader = DCRNN_loader
    elif chosen_model == "EXP":
        build_loader = EXP_loader
    else:
        # Every unrecognised model name falls back to the default loader.
        build_loader = normal_loader

    return build_loader(config, normalizer, single)