From 387f64efab6e951f87a997c2aa5842e27ae94de5 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Wed, 3 Sep 2025 10:18:52 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0STGODE=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=EF=BC=9A=E6=B7=BB=E5=8A=A0=E9=85=8D=E7=BD=AE=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E3=80=81=E4=BC=98=E5=8C=96=E6=A8=A1=E5=9E=8B=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E3=80=81=E6=96=B0=E5=A2=9E=E6=B5=8B=E8=AF=95=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .idea/workspace.xml | 48 +++++++++++++++++++++++++---- configs/STGODE/PEMS08.yaml | 60 +++++++++++++++++++++++++++++++++++++ models/STGODE/STGODE.py | 7 +++-- models/STGODE/adj.py | 44 ++++++++++++++------------- models/model_selector.py | 6 +++- test_semantic.npy | Bin 0 -> 928 bytes test_spatial.npy | Bin 0 -> 928 bytes 7 files changed, 134 insertions(+), 31 deletions(-) create mode 100644 configs/STGODE/PEMS08.yaml create mode 100644 test_semantic.npy create mode 100644 test_spatial.npy diff --git a/.idea/workspace.xml b/.idea/workspace.xml index d789a68..d1a09b9 100644 --- a/.idea/workspace.xml +++ b/.idea/workspace.xml @@ -5,9 +5,14 @@ + + + + + - @@ -57,10 +62,11 @@ - + - + - + @@ -114,7 +120,7 @@ - + + + + + + + + @@ -157,6 +191,7 @@ 1756727620810 + @@ -176,5 +211,6 @@ + \ No newline at end of file diff --git a/configs/STGODE/PEMS08.yaml b/configs/STGODE/PEMS08.yaml new file mode 100644 index 0000000..d9b7f29 --- /dev/null +++ b/configs/STGODE/PEMS08.yaml @@ -0,0 +1,60 @@ +basic: + device: cuda:0 + dataset: PEMS08 + model: STGODE + mode: train + seed: 2025 + +data: + dataset_dir: data/PEMS08 + val_batch_size: 32 + graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy + num_nodes: 170 + batch_size: 64 + input_dim: 1 + lag: 12 + horizon: 12 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 24 + days_per_week: 7 + +model: + input_dim: 1 + output_dim: 1 + history: 12 + horizon: 12 + num_features: 1 + rnn_units: 64 + sigma1: 0.1 + sigma2: 10 + thres1: 0.6 + thres2: 0.5 + + +train: + loss: mae + batch_size: 64 + epochs: 100 + lr_init: 0.003 + mape_thresh: 0.001 + mae_thresh: None + debug: False + output_dim: 1 + weight_decay: 0 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + log_step: 3000 + \ No newline at end of file diff --git a/models/STGODE/STGODE.py b/models/STGODE/STGODE.py index cf60e77..f978623 100755 --- a/models/STGODE/STGODE.py +++ b/models/STGODE/STGODE.py @@ -117,7 +117,7 @@ class STGCNBlock(nn.Module): class ODEGCN(nn.Module): """ the overall network framework """ - def __init__(self, args): + def __init__(self, config): """ Args: num_nodes : number of nodes in the graph @@ -129,11 +129,12 @@ class ODEGCN(nn.Module): """ super(ODEGCN, self).__init__() - num_nodes = args['num_nodes'] + args = config['model'] + num_nodes = config['data']['num_nodes'] num_features = args['num_features'] num_timesteps_input = args['history'] num_timesteps_output = args['horizon'] - A_sp_hat, A_se_hat = get_A_hat(args) + A_sp_hat, A_se_hat = get_A_hat(config) # spatial graph self.sp_blocks = nn.ModuleList( diff --git a/models/STGODE/adj.py b/models/STGODE/adj.py index e5da55d..5dca96b 100755 --- a/models/STGODE/adj.py +++ b/models/STGODE/adj.py @@ -17,7 +17,7 @@ files = { } -def get_A_hat(args): +def get_A_hat(config): """read data, generate spatial adjacency matrix and semantic adjacency matrix by dtw Args: @@ -31,12 +31,13 @@ def get_A_hat(args): dtw_matrix: array, semantic adjacency matrix sp_matrix: array, spatial adjacency matrix """ - filepath = './data/' - num_node = args['num_nodes'] - file = files[num_node] - filename = file[0][:6] + file_path = config['data']['graph_pkl_filename'] + filename = config['basic']['dataset'] + dataset_path = config['data']['dataset_dir'] + args = config['model'] - data = np.load(filepath + file[0])['data'] + data = np.load(file_path) + data = np.nan_to_num(data, nan=0.0, posinf=0.0, neginf=0.0) num_node = data.shape[1] mean_value = np.mean(data, axis=(0, 1)).reshape(1, 1, -1) std_value = np.std(data, axis=(0, 1)).reshape(1, 1, -1) @@ -72,7 +73,7 @@ def get_A_hat(args): with open(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}.txt', 'r') as f: id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))} # 建立映射列表 # 使用 pandas 读取 CSV 文件,跳过标题行 - df = pd.read_csv(filepath + file[1], skiprows=1, header=None) + df = pd.read_csv(f'{dataset_path}/{filename}.csv', skiprows=1, header=None) dist_matrix = np.zeros((num_node, num_node)) + float('inf') for _, row in df.iterrows(): start = int(id_dict[row[0]]) @@ -82,7 +83,7 @@ def get_A_hat(args): np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix) else: # 使用 pandas 读取 CSV 文件,跳过标题行 - df = pd.read_csv(filepath + file[1], skiprows=1, header=None) + df = pd.read_csv(f'{dataset_path}/{filename}.csv', skiprows=1, header=None) dist_matrix = np.zeros((num_node, num_node)) + float('inf') for _, row in df.iterrows(): start = int(row[0]) @@ -98,7 +99,8 @@ def get_A_hat(args): sp_matrix = np.exp(- dist_matrix ** 2 / sigma ** 2) sp_matrix[sp_matrix < args['thres2']] = 0 - return get_normalized_adj(dtw_matrix).to(args['device']), get_normalized_adj(sp_matrix).to(args['device']) + return (get_normalized_adj(dtw_matrix).to(config['basic']['device']), + get_normalized_adj(sp_matrix).to(config['basic']['device'])) def get_normalized_adj(A): @@ -115,16 +117,16 @@ def get_normalized_adj(A): return torch.from_numpy(A_reg.astype(np.float32)) -if __name__ == '__main__': - if __name__ == '__main__': - config = { - 'sigma1': 0.1, - 'sigma2': 10, - 'thres1': 0.6, - 'thres2': 0.5, - 'device': 'cuda:0' if torch.cuda.is_available() else 'cpu' - } - for nodes in [358, 170, 883]: - args = {'num_nodes': nodes, **config} - get_A_hat(args) +if __name__ == '__main__': + config = { + 'sigma1': 0.1, + 'sigma2': 10, + 'thres1': 0.6, + 'thres2': 0.5, + 'device': 'cuda:0' if torch.cuda.is_available() else 'cpu' + } + + for nodes in [358, 170, 883]: + args = {'num_nodes': nodes, **config} + get_A_hat(args) diff --git a/models/model_selector.py b/models/model_selector.py index ab80d63..cb07152 100644 --- a/models/model_selector.py +++ b/models/model_selector.py @@ -1,8 +1,12 @@ from models.STDEN.stden_model import STDENModel +from models.STGODE.STGODE import ODEGCN def model_selector(config): model_name = config['basic']['model'] model = None match model_name: - case 'STDEN': model = STDENModel(config) + case 'STDEN': + model = STDENModel(config) + case 'STGODE': + model = ODEGCN(config) return model \ No newline at end of file diff --git a/test_semantic.npy b/test_semantic.npy new file mode 100644 index 0000000000000000000000000000000000000000..e0d1b755d72f3a7665e3bfc2b9764502ede67fa4 GIT binary patch literal 928 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+i=qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-2099c2AVnwwF+bcE(R#j$}rWi3H@e2>+lmJjd?%p`>VxcbB=zne|BMx z!#5Rf2jgoW_ggiyIvktxYyQi8Jg?>B9R&gP+^O%z>GE z==1ZRkI%6=ESqDicO`(!LA;{qa>!0zhp&63o6aBMa1gkA&q!#sfWt1?I-L!z%nrAw zUj25snAPD!z+tt=9vluZ_rT17nd_tY-=)Y%(7{}DdXR<}k3;9=b8B5iI2{sfo=rW( z{oP)xmdUvznA73bv9^=lyaEoqQJc;@aQhDT7tB2{b71DW?PY1dewNL_BYMt8m;3w< zRh6OM+Z_ZQSUl?%79M(QU)=d}R+b}+gQ8F1)xAkC;o$@G7tB2{b71DKnEBiLY74W& z$}4L+CuIM$pZ&9?>c}EK2eVBx-rxA}&7QB-@{6V_3p|`*;REv*%snu3VCL?T)V=zd zf!%>^VbayM0yYQd?{RMyKW1_8e=_yTJY#-%e8IvQ7Ctb4!Q2Bg2WD=@#7kfAGzd9N zFrF+^9K_~uj63I^qx(a9SiHmH3l`3>@PYXY<{p?iFmt08i7`yR_RYTKoA9Q|T}<%w z1B-W9e8IvQ7Ctb4!Q2Bg2WIZ{E$ZtgFZcydhp_Ymi+5Oj!NM6HK2U$b+ygTQW-b7` CS7fFD literal 0 HcmV?d00001 diff --git a/test_spatial.npy b/test_spatial.npy new file mode 100644 index 0000000000000000000000000000000000000000..747906d5cb8801df205af4f79fde5c8803a4f582 GIT binary patch literal 928 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+i=qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-2099c2AVnwwF+bcE(Rz#e90rFanVov8y4HDrKXEI9B}T~Gsla|p|neN za?e3=2e(b~v4&Z44(GMJ1Z_UbJCx_DOBBoya+taHt@E6n3=WHurbj={5O9E*12gw> zWr?h9Ag@C~)SO!?g(42gb!Nh|jDCSrhwXh6FBi|}bAY)AW)95U$bE#w zo*hqeN+cajZf%+WbB~xq+Dx^3j;|CQngSPAUx^ohhY!qOF!#XBftmX#BBW%3wX{R} zWi7vq92tjK0himd-1r=nKl3zXNU}Seim0)1xFG=#XIS{a`~`Cl%p91xS_c-F=pB)D zm~nRAmcpsq?N>eJo;LBiltY!3{!@>vVFPM8^=D^I|k-@@MbpL@pEFHqq4=mnc@dXQKSopyF1#=I~9GJNP{c}&? literal 0 HcmV?d00001