diff --git a/.idea/Project-I.iml b/.idea/Project-I.iml index 91f2557..c2a517e 100644 --- a/.idea/Project-I.iml +++ b/.idea/Project-I.iml @@ -2,7 +2,7 @@ - + diff --git a/.idea/misc.xml b/.idea/misc.xml index 2f14dc7..d5c3d09 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -3,5 +3,5 @@ - + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml deleted file mode 100644 index e3f6bd5..0000000 --- a/.idea/vcs.xml +++ /dev/null @@ -1,7 +0,0 @@ - - - - - - - \ No newline at end of file diff --git a/.idea/workspace.xml b/.idea/workspace.xml index 196422d..ede70ce 100644 --- a/.idea/workspace.xml +++ b/.idea/workspace.xml @@ -5,10 +5,10 @@ - - - - + + + + - + + + + - @@ -189,6 +222,8 @@ + + @@ -207,8 +242,9 @@ - + + \ No newline at end of file diff --git a/README.md b/README.md index c843f1b..a6ee965 100644 --- a/README.md +++ b/README.md @@ -2,23 +2,34 @@ Secret Projct +``` mkdir -p models/gpt2 +``` + +## Prepare Env. +``` +pip install -r requirement.txt +``` ## Download dataset +``` python utils/download.py +``` ## Download gpt weight -mkdir -p models/gpt2 +`mkdir -p models/gpt2` Download config.json & pytorch_model.bin from https://huggingface.co/openai-community/gpt2/tree/main +```bash wget https://huggingface.co/openai-community/gpt2/resolve/main/config.json?download=true -O ./models/gpt2/config.json - wget https://huggingface.co/openai-community/gpt2/resolve/main/pytorch_model.bin?download=true -O ./models/gpt2/config.json - +``` Use pytorch >= 2.6 to load model. ## Run -Run: `python.py --config configs/STGODE_LLM_GPT2/PEMS08.yaml +``` +`python main.py --config configs/STGODE_LLM_GPT2/PEMS08.yaml +``` \ No newline at end of file diff --git a/models/STGODE_LLM_GPT2/STGODE_LLM_GPT2.py b/models/STGODE_LLM_GPT2/STGODE_LLM_GPT2.py index af65de7..c4a3bbe 100644 --- a/models/STGODE_LLM_GPT2/STGODE_LLM_GPT2.py +++ b/models/STGODE_LLM_GPT2/STGODE_LLM_GPT2.py @@ -66,11 +66,11 @@ class GPT2BackboneHF(nn.Module): super().__init__() from transformers import GPT2Model if local_dir is not None and len(local_dir) > 0: - self.model = GPT2Model.from_pretrained(local_dir, local_files_only=True) + self.model = GPT2Model.from_pretrained(local_dir, local_files_only=True, use_cache=False) else: if model_name is None: model_name = 'gpt2' - self.model = GPT2Model.from_pretrained(model_name) + self.model = GPT2Model.from_pretrained(model_name, use_cache=False) if gradient_checkpointing: self.model.gradient_checkpointing_enable() self.hidden_size = self.model.config.hidden_size diff --git a/models/model_selector.py b/models/model_selector.py index cb07152..4a03c16 100644 --- a/models/model_selector.py +++ b/models/model_selector.py @@ -1,12 +1,16 @@ from models.STDEN.stden_model import STDENModel from models.STGODE.STGODE import ODEGCN +from models.STGODE_LLM_GPT2.STGODE_LLM_GPT2 import ODEGCN_LLM_GPT2 + def model_selector(config): model_name = config['basic']['model'] model = None match model_name: - case 'STDEN': + case 'STDEN': model = STDENModel(config) case 'STGODE': model = ODEGCN(config) - return model \ No newline at end of file + case 'STGODE-LLM-GPT2': + model = ODEGCN_LLM_GPT2(config) + return model