diff --git a/model/model_selector.py b/model/model_selector.py index 0117619..091a465 100644 --- a/model/model_selector.py +++ b/model/model_selector.py @@ -14,6 +14,7 @@ from model.STSGCN.STSGCN import STSGCN from model.STGODE.STGODE import ODEGCN from model.PDG2SEQ.PDG2Seq import PDG2Seq from model.STMLP.STMLP import STMLP +from model.STIDGCN.STIDGCN import STIDGCN def model_selector(model): match model['type']: @@ -33,4 +34,5 @@ def model_selector(model): case 'STGODE': return ODEGCN(model) case 'PDG2SEQ': return PDG2Seq(model) case 'STMLP': return STMLP(model) + case 'STIDGCN': return STIDGCN(model)