From 055fef5046e4f39eefd391a50be7ad34d7068a6e Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Sun, 9 Nov 2025 09:33:22 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dbug=EF=BC=8Cnp.Inf?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 8 +++++++- models/repst.py | 4 ++-- requirements.txt | 20 ++++++++++---------- utils/former_tools.py | 2 +- utils/tools.py | 2 +- 5 files changed, 21 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index a502bba..5e54f9b 100644 --- a/README.md +++ b/README.md @@ -21,10 +21,16 @@ python prepare_pems_bay.py 根据BasicTS仓库配置BasicTS环境,亦或是使用 -`pip install -r requirement.txt` +`pip install -r requirements.txt` 我是直接使用现有的BasicTS环境,因此没有做过测试 +创建log文件夹到项目仓库 + +```bash +mkdir log +``` + 开跑 ```python python run.py --root_path datasets --data_path PEMS-BAY --device cuda:0 --seq_len 12 --pred_len 12 diff --git a/models/repst.py b/models/repst.py index adf4908..a915e9c 100644 --- a/models/repst.py +++ b/models/repst.py @@ -7,8 +7,8 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2Model from transformers import GPT2Model, GPT2Config from einops import rearrange -from .reprogramming import * -from .normalizer import * +from models.reprogramming import * +from models.normalizer import * class repst(nn.Module): diff --git a/requirements.txt b/requirements.txt index f1a59c5..0834779 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ -torch==2.0.1 -accelerate==0.26.1 -einops==0.6.0 -matplotlib==3.8.2 -numpy==1.24.4 -pandas==2.1.4 -scikit-learn==1.3.2 -scipy==1.11.4 -tqdm==4.66.1 -transformers==4.36.2 +torch +accelerate +einops +matplotlib +numpy +pandas +scikit-learn +scipy +tqdm +transformers diff --git a/utils/former_tools.py b/utils/former_tools.py index 34f4415..fa2e8b7 100644 --- a/utils/former_tools.py +++ b/utils/former_tools.py @@ -47,7 +47,7 @@ class EarlyStopping: self.counter = 0 self.best_score = None self.early_stop = False - self.val_loss_min = np.Inf + self.val_loss_min = np.inf self.delta = delta def __call__(self, val_loss, model, path): diff --git a/utils/tools.py b/utils/tools.py index a911d8d..0a2c592 100644 --- a/utils/tools.py +++ b/utils/tools.py @@ -47,7 +47,7 @@ class EarlyStopping: self.counter = 0 self.best_score = None self.early_stop = False - self.val_loss_min = np.Inf + self.val_loss_min = np.inf self.delta = delta def __call__(self, val_loss, model, path):