diff --git a/.idea/workspace.xml b/.idea/workspace.xml
index d789a68..d1a09b9 100644
--- a/.idea/workspace.xml
+++ b/.idea/workspace.xml
@@ -5,9 +5,14 @@
+
+
+
+
+
@@ -40,7 +45,7 @@
-
+
@@ -57,10 +62,11 @@
-
+
-
+
-
+
@@ -114,7 +120,7 @@
-
+
@@ -139,6 +145,34 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -157,6 +191,7 @@
1756727620810
+
@@ -176,5 +211,6 @@
+
\ No newline at end of file
diff --git a/configs/STGODE/PEMS08.yaml b/configs/STGODE/PEMS08.yaml
new file mode 100644
index 0000000..d9b7f29
--- /dev/null
+++ b/configs/STGODE/PEMS08.yaml
@@ -0,0 +1,60 @@
+basic:
+ device: cuda:0
+ dataset: PEMS08
+ model: STGODE
+ mode: train
+ seed: 2025
+
+data:
+ dataset_dir: data/PEMS08
+ val_batch_size: 32
+ graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
+ num_nodes: 170
+ batch_size: 64
+ input_dim: 1
+ lag: 12
+ horizon: 12
+ val_ratio: 0.2
+ test_ratio: 0.2
+ tod: False
+ normalizer: std
+ column_wise: False
+ default_graph: True
+ add_time_in_day: True
+ add_day_in_week: True
+ steps_per_day: 24
+ days_per_week: 7
+
+model:
+ input_dim: 1
+ output_dim: 1
+ history: 12
+ horizon: 12
+ num_features: 1
+ rnn_units: 64
+ sigma1: 0.1
+ sigma2: 10
+ thres1: 0.6
+ thres2: 0.5
+
+
+train:
+ loss: mae
+ batch_size: 64
+ epochs: 100
+ lr_init: 0.003
+ mape_thresh: 0.001
+ mae_thresh: None
+ debug: False
+ output_dim: 1
+ weight_decay: 0
+ lr_decay: False
+ lr_decay_rate: 0.3
+ lr_decay_step: "5,20,40,70"
+ early_stop: True
+ early_stop_patience: 15
+ grad_norm: False
+ max_grad_norm: 5
+ real_value: True
+ log_step: 3000
+
\ No newline at end of file
diff --git a/models/STGODE/STGODE.py b/models/STGODE/STGODE.py
index cf60e77..f978623 100755
--- a/models/STGODE/STGODE.py
+++ b/models/STGODE/STGODE.py
@@ -117,7 +117,7 @@ class STGCNBlock(nn.Module):
class ODEGCN(nn.Module):
""" the overall network framework """
- def __init__(self, args):
+ def __init__(self, config):
"""
Args:
num_nodes : number of nodes in the graph
@@ -129,11 +129,12 @@ class ODEGCN(nn.Module):
"""
super(ODEGCN, self).__init__()
- num_nodes = args['num_nodes']
+ args = config['model']
+ num_nodes = config['data']['num_nodes']
num_features = args['num_features']
num_timesteps_input = args['history']
num_timesteps_output = args['horizon']
- A_sp_hat, A_se_hat = get_A_hat(args)
+ A_sp_hat, A_se_hat = get_A_hat(config)
# spatial graph
self.sp_blocks = nn.ModuleList(
diff --git a/models/STGODE/adj.py b/models/STGODE/adj.py
index e5da55d..5dca96b 100755
--- a/models/STGODE/adj.py
+++ b/models/STGODE/adj.py
@@ -17,7 +17,7 @@ files = {
}
-def get_A_hat(args):
+def get_A_hat(config):
"""read data, generate spatial adjacency matrix and semantic adjacency matrix by dtw
Args:
@@ -31,12 +31,13 @@ def get_A_hat(args):
dtw_matrix: array, semantic adjacency matrix
sp_matrix: array, spatial adjacency matrix
"""
- filepath = './data/'
- num_node = args['num_nodes']
- file = files[num_node]
- filename = file[0][:6]
+ file_path = config['data']['graph_pkl_filename']
+ filename = config['basic']['dataset']
+ dataset_path = config['data']['dataset_dir']
+ args = config['model']
- data = np.load(filepath + file[0])['data']
+ data = np.load(file_path)
+ data = np.nan_to_num(data, nan=0.0, posinf=0.0, neginf=0.0)
num_node = data.shape[1]
mean_value = np.mean(data, axis=(0, 1)).reshape(1, 1, -1)
std_value = np.std(data, axis=(0, 1)).reshape(1, 1, -1)
@@ -72,7 +73,7 @@ def get_A_hat(args):
with open(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}.txt', 'r') as f:
id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))} # 建立映射列表
# 使用 pandas 读取 CSV 文件,跳过标题行
- df = pd.read_csv(filepath + file[1], skiprows=1, header=None)
+ df = pd.read_csv(f'{dataset_path}/{filename}.csv', skiprows=1, header=None)
dist_matrix = np.zeros((num_node, num_node)) + float('inf')
for _, row in df.iterrows():
start = int(id_dict[row[0]])
@@ -82,7 +83,7 @@ def get_A_hat(args):
np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix)
else:
# 使用 pandas 读取 CSV 文件,跳过标题行
- df = pd.read_csv(filepath + file[1], skiprows=1, header=None)
+ df = pd.read_csv(f'{dataset_path}/{filename}.csv', skiprows=1, header=None)
dist_matrix = np.zeros((num_node, num_node)) + float('inf')
for _, row in df.iterrows():
start = int(row[0])
@@ -98,7 +99,8 @@ def get_A_hat(args):
sp_matrix = np.exp(- dist_matrix ** 2 / sigma ** 2)
sp_matrix[sp_matrix < args['thres2']] = 0
- return get_normalized_adj(dtw_matrix).to(args['device']), get_normalized_adj(sp_matrix).to(args['device'])
+ return (get_normalized_adj(dtw_matrix).to(config['basic']['device']),
+ get_normalized_adj(sp_matrix).to(config['basic']['device']))
def get_normalized_adj(A):
@@ -115,16 +117,16 @@ def get_normalized_adj(A):
return torch.from_numpy(A_reg.astype(np.float32))
-if __name__ == '__main__':
- if __name__ == '__main__':
- config = {
- 'sigma1': 0.1,
- 'sigma2': 10,
- 'thres1': 0.6,
- 'thres2': 0.5,
- 'device': 'cuda:0' if torch.cuda.is_available() else 'cpu'
- }
- for nodes in [358, 170, 883]:
- args = {'num_nodes': nodes, **config}
- get_A_hat(args)
+if __name__ == '__main__':
+ config = {
+ 'sigma1': 0.1,
+ 'sigma2': 10,
+ 'thres1': 0.6,
+ 'thres2': 0.5,
+ 'device': 'cuda:0' if torch.cuda.is_available() else 'cpu'
+ }
+
+ for nodes in [358, 170, 883]:
+ args = {'num_nodes': nodes, **config}
+ get_A_hat(args)
diff --git a/models/model_selector.py b/models/model_selector.py
index ab80d63..cb07152 100644
--- a/models/model_selector.py
+++ b/models/model_selector.py
@@ -1,8 +1,12 @@
from models.STDEN.stden_model import STDENModel
+from models.STGODE.STGODE import ODEGCN
def model_selector(config):
model_name = config['basic']['model']
model = None
match model_name:
- case 'STDEN': model = STDENModel(config)
+ case 'STDEN':
+ model = STDENModel(config)
+ case 'STGODE':
+ model = ODEGCN(config)
return model
\ No newline at end of file
diff --git a/test_semantic.npy b/test_semantic.npy
new file mode 100644
index 0000000..e0d1b75
Binary files /dev/null and b/test_semantic.npy differ
diff --git a/test_spatial.npy b/test_spatial.npy
new file mode 100644
index 0000000..747906d
Binary files /dev/null and b/test_spatial.npy differ