修复bug,np.Inf
This commit is contained in:
parent
8372a7580c
commit
055fef5046
|
|
@ -21,10 +21,16 @@ python prepare_pems_bay.py
|
||||||
|
|
||||||
根据BasicTS仓库配置BasicTS环境,亦或是使用
|
根据BasicTS仓库配置BasicTS环境,亦或是使用
|
||||||
|
|
||||||
`pip install -r requirement.txt`
|
`pip install -r requirements.txt`
|
||||||
|
|
||||||
我是直接使用现有的BasicTS环境,因此没有做过测试
|
我是直接使用现有的BasicTS环境,因此没有做过测试
|
||||||
|
|
||||||
|
创建log文件夹到项目仓库
|
||||||
|
|
||||||
|
```bash
|
||||||
|
mkdir log
|
||||||
|
```
|
||||||
|
|
||||||
开跑
|
开跑
|
||||||
```python
|
```python
|
||||||
python run.py --root_path datasets --data_path PEMS-BAY --device cuda:0 --seq_len 12 --pred_len 12
|
python run.py --root_path datasets --data_path PEMS-BAY --device cuda:0 --seq_len 12 --pred_len 12
|
||||||
|
|
|
||||||
|
|
@ -7,8 +7,8 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2Model
|
||||||
from transformers import GPT2Model, GPT2Config
|
from transformers import GPT2Model, GPT2Config
|
||||||
|
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from .reprogramming import *
|
from models.reprogramming import *
|
||||||
from .normalizer import *
|
from models.normalizer import *
|
||||||
|
|
||||||
class repst(nn.Module):
|
class repst(nn.Module):
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
torch==2.0.1
|
torch
|
||||||
accelerate==0.26.1
|
accelerate
|
||||||
einops==0.6.0
|
einops
|
||||||
matplotlib==3.8.2
|
matplotlib
|
||||||
numpy==1.24.4
|
numpy
|
||||||
pandas==2.1.4
|
pandas
|
||||||
scikit-learn==1.3.2
|
scikit-learn
|
||||||
scipy==1.11.4
|
scipy
|
||||||
tqdm==4.66.1
|
tqdm
|
||||||
transformers==4.36.2
|
transformers
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,7 @@ class EarlyStopping:
|
||||||
self.counter = 0
|
self.counter = 0
|
||||||
self.best_score = None
|
self.best_score = None
|
||||||
self.early_stop = False
|
self.early_stop = False
|
||||||
self.val_loss_min = np.Inf
|
self.val_loss_min = np.inf
|
||||||
self.delta = delta
|
self.delta = delta
|
||||||
|
|
||||||
def __call__(self, val_loss, model, path):
|
def __call__(self, val_loss, model, path):
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,7 @@ class EarlyStopping:
|
||||||
self.counter = 0
|
self.counter = 0
|
||||||
self.best_score = None
|
self.best_score = None
|
||||||
self.early_stop = False
|
self.early_stop = False
|
||||||
self.val_loss_min = np.Inf
|
self.val_loss_min = np.inf
|
||||||
self.delta = delta
|
self.delta = delta
|
||||||
|
|
||||||
def __call__(self, val_loss, model, path):
|
def __call__(self, val_loss, model, path):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue