修复bug,np.Inf

This commit is contained in:
czzhangheng 2025-11-09 09:33:22 +08:00
parent 8372a7580c
commit 055fef5046
5 changed files with 21 additions and 15 deletions

View File

@ -21,10 +21,16 @@ python prepare_pems_bay.py
根据BasicTS仓库配置BasicTS环境亦或是使用
`pip install -r requirement.txt`
`pip install -r requirements.txt`
我是直接使用现有的BasicTS环境因此没有做过测试
创建log文件夹到项目仓库
```bash
mkdir log
```
开跑
```python
python run.py --root_path datasets --data_path PEMS-BAY --device cuda:0 --seq_len 12 --pred_len 12

View File

@ -7,8 +7,8 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from transformers import GPT2Model, GPT2Config
from einops import rearrange
from .reprogramming import *
from .normalizer import *
from models.reprogramming import *
from models.normalizer import *
class repst(nn.Module):

View File

@ -1,11 +1,11 @@
torch==2.0.1
accelerate==0.26.1
einops==0.6.0
matplotlib==3.8.2
numpy==1.24.4
pandas==2.1.4
scikit-learn==1.3.2
scipy==1.11.4
tqdm==4.66.1
transformers==4.36.2
torch
accelerate
einops
matplotlib
numpy
pandas
scikit-learn
scipy
tqdm
transformers

View File

@ -47,7 +47,7 @@ class EarlyStopping:
self.counter = 0
self.best_score = None
self.early_stop = False
self.val_loss_min = np.Inf
self.val_loss_min = np.inf
self.delta = delta
def __call__(self, val_loss, model, path):

View File

@ -47,7 +47,7 @@ class EarlyStopping:
self.counter = 0
self.best_score = None
self.early_stop = False
self.val_loss_min = np.Inf
self.val_loss_min = np.inf
self.delta = delta
def __call__(self, val_loss, model, path):