fix bugs
This commit is contained in:
parent
f35bfbf75d
commit
0a9ac1a025
|
|
@ -95,5 +95,7 @@ git push origin dev
|
|||
|
||||
目前,实测以下模型性能与原报告相比指标偏高:ARIMA、TCN、DCRNN
|
||||
|
||||
STGCN在载入图时会有未知warning
|
||||
|
||||
以下模型由于没有源码暂未实现:HA、VAR、FC-LSTM、GRU-ED
|
||||
|
||||
|
|
|
|||
|
|
@ -1,47 +0,0 @@
|
|||
data:
|
||||
num_nodes: 13
|
||||
lag: 12
|
||||
horizon: 12
|
||||
val_ratio: 0.2
|
||||
test_ratio: 0.2
|
||||
tod: false
|
||||
normalizer: std
|
||||
column_wise: false
|
||||
default_graph: true
|
||||
add_time_in_day: true
|
||||
add_day_in_week: true
|
||||
steps_per_day: 288
|
||||
days_per_week: 7
|
||||
batch_size: 64
|
||||
|
||||
model:
|
||||
input_dim: 1
|
||||
output_dim: 1
|
||||
|
||||
train:
|
||||
loss_func: mae
|
||||
seed: 10
|
||||
batch_size: 64
|
||||
epochs: 50
|
||||
lr_init: 0.003
|
||||
weight_decay: 0
|
||||
lr_decay: false
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step:
|
||||
- 5
|
||||
- 20
|
||||
- 40
|
||||
- 70
|
||||
early_stop: true
|
||||
early_stop_patience: 15
|
||||
grad_norm: false
|
||||
max_grad_norm: 5
|
||||
real_value: true
|
||||
|
||||
test:
|
||||
mae_thresh: null
|
||||
mape_thresh: 0.0
|
||||
|
||||
log:
|
||||
log_step: 10000
|
||||
plot: false
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
data:
|
||||
add_day_in_week: true
|
||||
add_time_in_day: true
|
||||
column_wise: false
|
||||
days_per_week: 5
|
||||
default_graph: true
|
||||
horizon: 12
|
||||
lag: 12
|
||||
normalizer: std
|
||||
num_nodes: 1026
|
||||
steps_per_day: 288
|
||||
test_ratio: 0.2
|
||||
tod: false
|
||||
val_ratio: 0.2
|
||||
log:
|
||||
log_step: 1000
|
||||
plot: true
|
||||
model:
|
||||
cheb_order: 2
|
||||
embed_dim: 12
|
||||
input_dim: 1
|
||||
num_layers: 1
|
||||
output_dim: 1
|
||||
rnn_units: 64
|
||||
use_day: true
|
||||
use_week: true
|
||||
test:
|
||||
mae_thresh: None
|
||||
mape_thresh: 0.0
|
||||
train:
|
||||
batch_size: 12
|
||||
early_stop: true
|
||||
early_stop_patience: 10
|
||||
epochs: 300
|
||||
grad_norm: false
|
||||
loss_func: mae
|
||||
lr_decay: false
|
||||
lr_decay_rate: 0.1
|
||||
lr_decay_step:
|
||||
- '5'
|
||||
- '20'
|
||||
- '40'
|
||||
- '70'
|
||||
lr_init: 0.0005625
|
||||
max_grad_norm: 5
|
||||
real_value: true
|
||||
seed: 12
|
||||
weight_decay: 0
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
data:
|
||||
add_day_in_week: true
|
||||
add_time_in_day: true
|
||||
column_wise: false
|
||||
days_per_week: 5
|
||||
default_graph: true
|
||||
horizon: 12
|
||||
lag: 12
|
||||
normalizer: std
|
||||
num_nodes: 228
|
||||
steps_per_day: 288
|
||||
test_ratio: 0.2
|
||||
tod: false
|
||||
val_ratio: 0.2
|
||||
log:
|
||||
log_step: 1000
|
||||
plot: true
|
||||
model:
|
||||
cheb_order: 2
|
||||
embed_dim: 8
|
||||
input_dim: 1
|
||||
num_layers: 1
|
||||
output_dim: 1
|
||||
rnn_units: 64
|
||||
use_day: true
|
||||
use_week: true
|
||||
test:
|
||||
mae_thresh: None
|
||||
mape_thresh: 0.0
|
||||
train:
|
||||
batch_size: 64
|
||||
early_stop: true
|
||||
early_stop_patience: 15
|
||||
epochs: 300
|
||||
grad_norm: false
|
||||
loss_func: mae
|
||||
lr_decay: false
|
||||
lr_decay_rate: 0.1
|
||||
lr_decay_step:
|
||||
- '5'
|
||||
- '20'
|
||||
- '40'
|
||||
- '70'
|
||||
lr_init: 0.003
|
||||
max_grad_norm: 5
|
||||
real_value: true
|
||||
seed: 12
|
||||
weight_decay: 0
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
data:
|
||||
add_day_in_week: true
|
||||
add_time_in_day: true
|
||||
column_wise: false
|
||||
days_per_week: 7
|
||||
default_graph: true
|
||||
horizon: 12
|
||||
lag: 12
|
||||
normalizer: std
|
||||
num_nodes: 883
|
||||
steps_per_day: 288
|
||||
test_ratio: 0.2
|
||||
tod: false
|
||||
val_ratio: 0.2
|
||||
log:
|
||||
log_step: 3000
|
||||
plot: false
|
||||
model:
|
||||
cheb_order: 2
|
||||
embed_dim: 12
|
||||
input_dim: 1
|
||||
num_layers: 1
|
||||
output_dim: 1
|
||||
rnn_units: 64
|
||||
use_day: true
|
||||
use_week: true
|
||||
test:
|
||||
mae_thresh: None
|
||||
mape_thresh: 0.0
|
||||
train:
|
||||
batch_size: 16
|
||||
early_stop: true
|
||||
early_stop_patience: 10
|
||||
epochs: 200
|
||||
grad_norm: false
|
||||
loss_func: mae
|
||||
lr_decay: false
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step:
|
||||
- '5'
|
||||
- '20'
|
||||
- '40'
|
||||
- '70'
|
||||
lr_init: 0.00075
|
||||
max_grad_norm: 5
|
||||
real_value: true
|
||||
seed: 10
|
||||
weight_decay: 0
|
||||
|
|
@ -1,40 +0,0 @@
|
|||
data:
|
||||
add_day_in_week: true
|
||||
add_time_in_day: true
|
||||
column_wise: false
|
||||
days_per_week: 7
|
||||
default_graph: true
|
||||
horizon: 12
|
||||
lag: 12
|
||||
normalizer: std
|
||||
num_nodes: 170
|
||||
steps_per_day: 288
|
||||
test_ratio: 0.2
|
||||
tod: false
|
||||
val_ratio: 0.2
|
||||
log:
|
||||
log_step: 2000
|
||||
plot: false
|
||||
model:
|
||||
cheb_order: 2
|
||||
embed_dim: 5
|
||||
input_dim: 1
|
||||
num_layers: 1
|
||||
output_dim: 1
|
||||
rnn_units: 64
|
||||
use_day: true
|
||||
use_week: true
|
||||
test:
|
||||
mae_thresh: None
|
||||
mape_thresh: 0.001
|
||||
train:
|
||||
batch_size: 64
|
||||
early_stop: true
|
||||
early_stop_patience: 15
|
||||
epochs: 300
|
||||
grad_norm: false
|
||||
loss_func: mae
|
||||
max_grad_norm: 5
|
||||
real_value: true
|
||||
seed: 12
|
||||
weight_decay: 0
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
data:
|
||||
num_nodes: 307
|
||||
num_nodes: 358
|
||||
lag: 12
|
||||
horizon: 12
|
||||
val_ratio: 0.2
|
||||
|
|
@ -8,20 +8,25 @@ data:
|
|||
normalizer: std
|
||||
column_wise: False
|
||||
default_graph: True
|
||||
add_time_in_day: True
|
||||
add_day_in_week: True
|
||||
add_time_in_day: False
|
||||
add_day_in_week: False
|
||||
steps_per_day: 288
|
||||
days_per_week: 7
|
||||
|
||||
model:
|
||||
input_dim: 1
|
||||
output_dim: 1
|
||||
input_window: 12
|
||||
output_window: 12
|
||||
embed_dim: 10
|
||||
rnn_units: 64
|
||||
num_layers: 1
|
||||
cheb_order: 2
|
||||
use_day: True
|
||||
use_week: True
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
natt_hops: 4
|
||||
nfc: 256
|
||||
max_up_len: 80
|
||||
feature_dim: 1
|
||||
use_day: False
|
||||
use_week: False
|
||||
|
||||
train:
|
||||
loss_func: mae
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
data:
|
||||
num_nodes: 358
|
||||
num_nodes: 883
|
||||
lag: 12
|
||||
horizon: 12
|
||||
val_ratio: 0.2
|
||||
|
|
@ -8,26 +8,31 @@ data:
|
|||
normalizer: std
|
||||
column_wise: False
|
||||
default_graph: True
|
||||
add_time_in_day: True
|
||||
add_day_in_week: True
|
||||
add_time_in_day: False
|
||||
add_day_in_week: False
|
||||
steps_per_day: 288
|
||||
days_per_week: 7
|
||||
|
||||
model:
|
||||
input_dim: 1
|
||||
output_dim: 1
|
||||
embed_dim: 12
|
||||
rnn_units: 64
|
||||
num_layers: 1
|
||||
cheb_order: 2
|
||||
use_day: True
|
||||
input_window: 12
|
||||
output_window: 12
|
||||
embed_dim: 10
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
natt_hops: 4
|
||||
nfc: 256
|
||||
max_up_len: 80
|
||||
feature_dim: 1
|
||||
use_day: False
|
||||
use_week: False
|
||||
|
||||
train:
|
||||
loss_func: mae
|
||||
seed: 10
|
||||
batch_size: 64
|
||||
epochs: 50
|
||||
batch_size: 32
|
||||
epochs: 300
|
||||
lr_init: 0.003
|
||||
weight_decay: 0
|
||||
lr_decay: False
|
||||
|
|
@ -44,5 +49,5 @@ test:
|
|||
mape_thresh: 0.0
|
||||
|
||||
log:
|
||||
log_step: 10000
|
||||
log_step: 200
|
||||
plot: False
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
data:
|
||||
num_nodes: 358
|
||||
num_nodes: 170
|
||||
lag: 12
|
||||
horizon: 12
|
||||
val_ratio: 0.2
|
||||
|
|
@ -8,26 +8,31 @@ data:
|
|||
normalizer: std
|
||||
column_wise: False
|
||||
default_graph: True
|
||||
add_time_in_day: True
|
||||
add_day_in_week: True
|
||||
add_time_in_day: False
|
||||
add_day_in_week: False
|
||||
steps_per_day: 288
|
||||
days_per_week: 7
|
||||
|
||||
model:
|
||||
input_dim: 1
|
||||
output_dim: 1
|
||||
embed_dim: 12
|
||||
rnn_units: 64
|
||||
num_layers: 1
|
||||
cheb_order: 2
|
||||
use_day: True
|
||||
input_window: 12
|
||||
output_window: 12
|
||||
embed_dim: 10
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
natt_hops: 4
|
||||
nfc: 256
|
||||
max_up_len: 80
|
||||
feature_dim: 1
|
||||
use_day: False
|
||||
use_week: False
|
||||
|
||||
train:
|
||||
loss_func: mae
|
||||
seed: 10
|
||||
batch_size: 64
|
||||
epochs: 50
|
||||
epochs: 300
|
||||
lr_init: 0.003
|
||||
weight_decay: 0
|
||||
lr_decay: False
|
||||
|
|
@ -44,5 +49,5 @@ test:
|
|||
mape_thresh: 0.0
|
||||
|
||||
log:
|
||||
log_step: 10000
|
||||
log_step: 200
|
||||
plot: False
|
||||
|
|
@ -1,71 +0,0 @@
|
|||
data:
|
||||
num_nodes: 307
|
||||
lag: 12
|
||||
horizon: 12
|
||||
val_ratio: 0.2
|
||||
test_ratio: 0.2
|
||||
tod: False
|
||||
normalizer: std
|
||||
column_wise: False
|
||||
default_graph: True
|
||||
add_time_in_day: True
|
||||
add_day_in_week: True
|
||||
steps_per_day: 288
|
||||
days_per_week: 7
|
||||
|
||||
model:
|
||||
embed_dim: 64
|
||||
skip_dim: 256
|
||||
lape_dim: 8
|
||||
geo_num_heads: 4
|
||||
sem_num_heads: 2
|
||||
t_num_heads: 2
|
||||
mlp_ratio: 4
|
||||
qkv_bias: True
|
||||
drop: 0.
|
||||
attn_drop: 0.
|
||||
drop_path: 0.3
|
||||
s_attn_size: 3
|
||||
t_attn_size: 3
|
||||
enc_depth: 6
|
||||
type_ln: pre
|
||||
type_short_path: hop
|
||||
input_dim: 3
|
||||
output_dim: 1
|
||||
input_window: 12
|
||||
output_window: 12
|
||||
add_time_in_day: True
|
||||
add_day_in_week: True
|
||||
world_size: 1
|
||||
huber_delta: 1
|
||||
quan_delta: 0.25
|
||||
far_mask_delta: 5
|
||||
dtw_delta: 5
|
||||
use_curriculum_learning: True
|
||||
step_size: 2500
|
||||
max_epoch: 200
|
||||
task_level: 0
|
||||
|
||||
train:
|
||||
loss_func: mae
|
||||
seed: 10
|
||||
batch_size: 64
|
||||
epochs: 300
|
||||
lr_init: 0.003
|
||||
weight_decay: 0
|
||||
lr_decay: False
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step: "5,20,40,70"
|
||||
early_stop: True
|
||||
early_stop_patience: 15
|
||||
grad_norm: False
|
||||
max_grad_norm: 5
|
||||
real_value: True
|
||||
|
||||
test:
|
||||
mae_thresh: null
|
||||
mape_thresh: 0.0
|
||||
|
||||
log:
|
||||
log_step: 200
|
||||
plot: False
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
data:
|
||||
add_day_in_week: true
|
||||
add_time_in_day: true
|
||||
column_wise: false
|
||||
days_per_week: 7
|
||||
default_graph: true
|
||||
horizon: 12
|
||||
lag: 12
|
||||
normalizer: std
|
||||
num_nodes: 883
|
||||
steps_per_day: 288
|
||||
test_ratio: 0.2
|
||||
tod: false
|
||||
val_ratio: 0.2
|
||||
log:
|
||||
log_step: 3000
|
||||
plot: false
|
||||
model:
|
||||
cheb_order: 2
|
||||
embed_dim: 12
|
||||
input_dim: 1
|
||||
num_layers: 1
|
||||
output_dim: 1
|
||||
rnn_units: 64
|
||||
use_day: true
|
||||
use_week: true
|
||||
test:
|
||||
mae_thresh: None
|
||||
mape_thresh: 0.0
|
||||
train:
|
||||
batch_size: 16
|
||||
early_stop: true
|
||||
early_stop_patience: 10
|
||||
epochs: 200
|
||||
grad_norm: false
|
||||
loss_func: mae
|
||||
lr_decay: false
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step:
|
||||
- '5'
|
||||
- '20'
|
||||
- '40'
|
||||
- '70'
|
||||
lr_init: 0.00075
|
||||
max_grad_norm: 5
|
||||
real_value: true
|
||||
seed: 10
|
||||
weight_decay: 0
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
data:
|
||||
add_day_in_week: true
|
||||
add_time_in_day: true
|
||||
column_wise: false
|
||||
days_per_week: 7
|
||||
default_graph: true
|
||||
horizon: 12
|
||||
lag: 12
|
||||
normalizer: std
|
||||
num_nodes: 170
|
||||
steps_per_day: 288
|
||||
test_ratio: 0.2
|
||||
tod: false
|
||||
val_ratio: 0.2
|
||||
log:
|
||||
log_step: 2000
|
||||
plot: false
|
||||
model:
|
||||
cheb_order: 2
|
||||
embed_dim: 5
|
||||
input_dim: 1
|
||||
num_layers: 1
|
||||
output_dim: 1
|
||||
rnn_units: 64
|
||||
use_day: true
|
||||
use_week: true
|
||||
test:
|
||||
mae_thresh: None
|
||||
mape_thresh: 0.001
|
||||
train:
|
||||
batch_size: 64
|
||||
early_stop: true
|
||||
early_stop_patience: 15
|
||||
epochs: 300
|
||||
grad_norm: false
|
||||
loss_func: mae
|
||||
lr_decay: true
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step: "5,20,40,70"
|
||||
lr_init: 0.003
|
||||
max_grad_norm: 5
|
||||
real_value: true
|
||||
seed: 12
|
||||
weight_decay: 0
|
||||
|
|
@ -128,7 +128,7 @@ def load_st_dataset(dataset, sample):
|
|||
if len(data.shape) == 2:
|
||||
data = np.expand_dims(data, axis=-1)
|
||||
|
||||
print('Load %s Dataset shaped: ' % dataset, data.shape, data.max(), data.min(), data.mean(), np.median(data))
|
||||
print('加载 %s 数据集中... ' % dataset)
|
||||
return data[::sample]
|
||||
|
||||
def split_data_by_days(data, val_days, test_days, interval=30):
|
||||
|
|
|
|||
|
|
@ -128,7 +128,7 @@ def load_st_dataset(dataset, sample):
|
|||
if len(data.shape) == 2:
|
||||
data = np.expand_dims(data, axis=-1)
|
||||
|
||||
print('Load %s Dataset shaped: ' % dataset, data.shape, data.max(), data.min(), data.mean(), np.median(data))
|
||||
print('加载 %s 数据集中... ' % dataset)
|
||||
return data[::sample]
|
||||
|
||||
def split_data_by_days(data, val_days, test_days, interval=30):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import model.STGNCDE.controldiffeq
|
||||
from lib.normalization import normalize_dataset
|
||||
import numpy as np
|
||||
import gc
|
||||
|
|
@ -6,6 +5,8 @@ import os
|
|||
import torch
|
||||
import h5py
|
||||
|
||||
from model.STGNCDE import controldiffeq
|
||||
|
||||
|
||||
def get_dataloader(args, normalizer='std', single=True):
|
||||
data = load_st_dataset(args['type']) # 加载数据
|
||||
|
|
@ -129,22 +130,22 @@ def load_st_dataset(dataset):
|
|||
# output B, N, D
|
||||
match dataset:
|
||||
case 'PEMSD3':
|
||||
data_path = os.path.join('./data/PeMS03/PEMS03.npz')
|
||||
data_path = os.path.join('./data/PEMS03/PEMS03.npz')
|
||||
data = np.load(data_path)['data'][:, :, 0] # only the first dimension, traffic flow data
|
||||
case 'PEMSD4':
|
||||
data_path = os.path.join('./data/PeMS04/PEMS04.npz')
|
||||
data_path = os.path.join('./data/PEMS04/PEMS04.npz')
|
||||
data = np.load(data_path)['data'][:, :, 0] # only the first dimension, traffic flow data
|
||||
case 'PEMSD7':
|
||||
data_path = os.path.join('./data/PeMS07/PEMS07.npz')
|
||||
data_path = os.path.join('./data/PEMS07/PEMS07.npz')
|
||||
data = np.load(data_path)['data'][:, :, 0] # only the first dimension, traffic flow data
|
||||
case 'PEMSD8':
|
||||
data_path = os.path.join('./data/PeMS08/PeMS08.npz')
|
||||
data_path = os.path.join('./data/PEMS08/PEMS08.npz')
|
||||
data = np.load(data_path)['data'][:, :, 0] # only the first dimension, traffic flow data
|
||||
case 'PEMSD7(L)':
|
||||
data_path = os.path.join('./data/PeMS07(L)/PEMS07L.npz')
|
||||
data_path = os.path.join('./data/PEMS07(L)/PEMS07L.npz')
|
||||
data = np.load(data_path)['data'][:, :, 0] # only the first dimension, traffic flow data
|
||||
case 'PEMSD7(M)':
|
||||
data_path = os.path.join('./data/PeMS07(M)/V_228.csv')
|
||||
data_path = os.path.join('./data/PEMS07(M)/V_228.csv')
|
||||
data = np.genfromtxt(data_path, delimiter=',') # Read CSV directly with numpy
|
||||
case 'METR-LA':
|
||||
data_path = os.path.join('./data/METR-LA/METR.h5')
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ def normalize_dataset(data, normalizer, column_wise=False):
|
|||
maximum = data.max()
|
||||
scaler = MinMax01Scaler(minimum, maximum)
|
||||
# data = scaler.transform(data)
|
||||
print('Normalize the dataset by MinMax01 Normalization')
|
||||
# print('Normalize the dataset by MinMax01 Normalization')
|
||||
elif normalizer == 'max11':
|
||||
if column_wise:
|
||||
minimum = data.min(axis=0, keepdims=True)
|
||||
|
|
@ -130,7 +130,7 @@ def normalize_dataset(data, normalizer, column_wise=False):
|
|||
maximum = data.max()
|
||||
scaler = MinMax11Scaler(minimum, maximum)
|
||||
# data = scaler.transform(data)
|
||||
print('Normalize the dataset by MinMax11 Normalization')
|
||||
# print('Normalize the dataset by MinMax11 Normalization')
|
||||
elif normalizer == 'std':
|
||||
if column_wise:
|
||||
mean = data.mean(axis=0, keepdims=True)
|
||||
|
|
@ -140,15 +140,15 @@ def normalize_dataset(data, normalizer, column_wise=False):
|
|||
std = data.std()
|
||||
scaler = StandardScaler(mean, std)
|
||||
# data = scaler.transform(data)
|
||||
print('Normalize the dataset by Standard Normalization')
|
||||
# print('Normalize the dataset by Standard Normalization')
|
||||
elif normalizer == 'None':
|
||||
scaler = NScaler()
|
||||
# data = scaler.transform(data)
|
||||
print('Does not normalize the dataset')
|
||||
# print('Does not normalize the dataset')
|
||||
elif normalizer == 'cmax':
|
||||
scaler = ColumnMinMaxScaler(data.min(axis=0), data.max(axis=0))
|
||||
# data = scaler.transform(data)
|
||||
print('Normalize the dataset by Column Min-Max Normalization')
|
||||
# print('Normalize the dataset by Column Min-Max Normalization')
|
||||
else:
|
||||
raise ValueError(f"Unsupported normalizer type: {normalizer}")
|
||||
return scaler
|
||||
|
|
|
|||
|
|
@ -9,20 +9,20 @@ def get_adj(args):
|
|||
dataset_path = './data'
|
||||
match args['num_nodes']:
|
||||
case 358:
|
||||
dataset_name = 'PeMS03'
|
||||
dataset_name = 'PEMS03'
|
||||
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS03.csv')
|
||||
id = os.path.join(dataset_path, dataset_name, 'PEMS03.txt')
|
||||
A = get_adjacency_matrix(adj_path, args['num_nodes'], args['construct_type'], id_filename=id)
|
||||
case 307:
|
||||
dataset_name = 'PeMS04'
|
||||
dataset_name = 'PEMS04'
|
||||
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS04.csv')
|
||||
A = get_adjacency_matrix(adj_path, args['num_nodes'], args['construct_type'])
|
||||
case 883:
|
||||
dataset_name = 'PeMS07'
|
||||
dataset_name = 'PEMS07'
|
||||
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS07.csv')
|
||||
A = get_adjacency_matrix(adj_path, args['num_nodes'], args['construct_type'])
|
||||
case 170:
|
||||
dataset_name = 'PeMS08'
|
||||
dataset_name = 'PEMS08'
|
||||
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS08.csv')
|
||||
A = get_adjacency_matrix(adj_path, args['num_nodes'], args['construct_type'])
|
||||
|
||||
|
|
|
|||
10
run.py
10
run.py
|
|
@ -68,10 +68,6 @@ def main():
|
|||
with open(destination_path, 'w') as f:
|
||||
f.write(config_content)
|
||||
|
||||
# 拷贝配置文件到日志文件夹
|
||||
# destination_path = os.path.join(args['train']['log_dir'], config_filename)
|
||||
# shutil.copyfile(config_path, destination_path)
|
||||
|
||||
# Start training or testing
|
||||
trainer = select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
|
||||
lr_scheduler, extra_data)
|
||||
|
|
@ -83,12 +79,8 @@ def main():
|
|||
model.load_state_dict(torch.load(
|
||||
f"./pre-trained/{args['model']['type']}/{args['data']['type']}.pth",
|
||||
map_location=args['device'], weights_only=True))
|
||||
print(f"Loaded saved model on {args['device']}")
|
||||
# start_time = time.time()
|
||||
# print(f"Loaded saved model on {args['device']}")
|
||||
trainer.test(model.to(args['device']), trainer.args, test_loader, scaler, trainer.logger)
|
||||
# end_time = time.time()
|
||||
# elapsed_time = end_time - start_time
|
||||
# print(f"执行时间:{elapsed_time:.4f} 秒")
|
||||
case _:
|
||||
raise ValueError(f"Unsupported mode: {args['mode']}")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue