diff --git a/README.md b/README.md index a502bba..5e54f9b 100644 --- a/README.md +++ b/README.md @@ -21,10 +21,16 @@ python prepare_pems_bay.py 根据BasicTS仓库配置BasicTS环境,亦或是使用 -`pip install -r requirement.txt` +`pip install -r requirements.txt` 我是直接使用现有的BasicTS环境,因此没有做过测试 +创建log文件夹到项目仓库 + +```bash +mkdir log +``` + 开跑 ```python python run.py --root_path datasets --data_path PEMS-BAY --device cuda:0 --seq_len 12 --pred_len 12 diff --git a/models/repst.py b/models/repst.py index adf4908..a915e9c 100644 --- a/models/repst.py +++ b/models/repst.py @@ -7,8 +7,8 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2Model from transformers import GPT2Model, GPT2Config from einops import rearrange -from .reprogramming import * -from .normalizer import * +from models.reprogramming import * +from models.normalizer import * class repst(nn.Module): diff --git a/requirements.txt b/requirements.txt index f1a59c5..0834779 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ -torch==2.0.1 -accelerate==0.26.1 -einops==0.6.0 -matplotlib==3.8.2 -numpy==1.24.4 -pandas==2.1.4 -scikit-learn==1.3.2 -scipy==1.11.4 -tqdm==4.66.1 -transformers==4.36.2 +torch +accelerate +einops +matplotlib +numpy +pandas +scikit-learn +scipy +tqdm +transformers diff --git a/utils/former_tools.py b/utils/former_tools.py index 34f4415..fa2e8b7 100644 --- a/utils/former_tools.py +++ b/utils/former_tools.py @@ -47,7 +47,7 @@ class EarlyStopping: self.counter = 0 self.best_score = None self.early_stop = False - self.val_loss_min = np.Inf + self.val_loss_min = np.inf self.delta = delta def __call__(self, val_loss, model, path): diff --git a/utils/tools.py b/utils/tools.py index a911d8d..0a2c592 100644 --- a/utils/tools.py +++ b/utils/tools.py @@ -47,7 +47,7 @@ class EarlyStopping: self.counter = 0 self.best_score = None self.early_stop = False - self.val_loss_min = np.Inf + self.val_loss_min = np.inf self.delta = delta def __call__(self, val_loss, model, path):