修复bug,np.Inf
This commit is contained in:
parent
8372a7580c
commit
055fef5046
|
|
@ -21,10 +21,16 @@ python prepare_pems_bay.py
|
|||
|
||||
根据BasicTS仓库配置BasicTS环境,亦或是使用
|
||||
|
||||
`pip install -r requirement.txt`
|
||||
`pip install -r requirements.txt`
|
||||
|
||||
我是直接使用现有的BasicTS环境,因此没有做过测试
|
||||
|
||||
创建log文件夹到项目仓库
|
||||
|
||||
```bash
|
||||
mkdir log
|
||||
```
|
||||
|
||||
开跑
|
||||
```python
|
||||
python run.py --root_path datasets --data_path PEMS-BAY --device cuda:0 --seq_len 12 --pred_len 12
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2Model
|
|||
from transformers import GPT2Model, GPT2Config
|
||||
|
||||
from einops import rearrange
|
||||
from .reprogramming import *
|
||||
from .normalizer import *
|
||||
from models.reprogramming import *
|
||||
from models.normalizer import *
|
||||
|
||||
class repst(nn.Module):
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
torch==2.0.1
|
||||
accelerate==0.26.1
|
||||
einops==0.6.0
|
||||
matplotlib==3.8.2
|
||||
numpy==1.24.4
|
||||
pandas==2.1.4
|
||||
scikit-learn==1.3.2
|
||||
scipy==1.11.4
|
||||
tqdm==4.66.1
|
||||
transformers==4.36.2
|
||||
torch
|
||||
accelerate
|
||||
einops
|
||||
matplotlib
|
||||
numpy
|
||||
pandas
|
||||
scikit-learn
|
||||
scipy
|
||||
tqdm
|
||||
transformers
|
||||
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ class EarlyStopping:
|
|||
self.counter = 0
|
||||
self.best_score = None
|
||||
self.early_stop = False
|
||||
self.val_loss_min = np.Inf
|
||||
self.val_loss_min = np.inf
|
||||
self.delta = delta
|
||||
|
||||
def __call__(self, val_loss, model, path):
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ class EarlyStopping:
|
|||
self.counter = 0
|
||||
self.best_score = None
|
||||
self.early_stop = False
|
||||
self.val_loss_min = np.Inf
|
||||
self.val_loss_min = np.inf
|
||||
self.delta = delta
|
||||
|
||||
def __call__(self, val_loss, model, path):
|
||||
|
|
|
|||
Loading…
Reference in New Issue