From 9cfe3f01dc96f641d251c70f1d4f549e06690300 Mon Sep 17 00:00:00 2001 From: HengZhang Date: Thu, 28 Nov 2024 10:43:14 +0800 Subject: [PATCH] fix dataloader --- .../dataloader/traffic_dataloader_v2.py | 6 ++--- federatedscope/trafficflow/model/FedDGCNv2.py | 27 ++++++++++++++++--- scripts/trafficflow_exp_scripts/D4.yaml | 2 +- 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/federatedscope/trafficflow/dataloader/traffic_dataloader_v2.py b/federatedscope/trafficflow/dataloader/traffic_dataloader_v2.py index 98c5f56..f8d5229 100644 --- a/federatedscope/trafficflow/dataloader/traffic_dataloader_v2.py +++ b/federatedscope/trafficflow/dataloader/traffic_dataloader_v2.py @@ -205,7 +205,7 @@ def split_into_mini_graphs(tensor, graph_size, dummy_value=0): graph_num = (node_num + graph_size - 1) // graph_size # Round up division # Initialize output tensor with dummy values - output = np.full((timestep, horizon, graph_size, graph_num, dim), dummy_value, dtype=tensor.dtype) + output = np.full((timestep, horizon, graph_num, graph_size, dim), dummy_value, dtype=tensor.dtype) # Fill in the real data for i in range(graph_num): @@ -213,8 +213,8 @@ def split_into_mini_graphs(tensor, graph_size, dummy_value=0): end_idx = min(start_idx + graph_size, node_num) # Ensure we don't exceed the node number slice_size = end_idx - start_idx - # Assign the data to the corresponding mini-graph (adjusted indexing) - output[:, :, :slice_size, i, :] = tensor[:, :, start_idx:end_idx, :] + # Assign the data to the corresponding mini-graph + output[:, :, i, :slice_size, :] = tensor[:, :, start_idx:end_idx, :] return output diff --git a/federatedscope/trafficflow/model/FedDGCNv2.py b/federatedscope/trafficflow/model/FedDGCNv2.py index ddd2f92..f565131 100644 --- a/federatedscope/trafficflow/model/FedDGCNv2.py +++ b/federatedscope/trafficflow/model/FedDGCNv2.py @@ -1,6 +1,4 @@ from torch.nn import ModuleList - -from federatedscope.register import register_model import torch import torch.nn as nn from federatedscope.trafficflow.model.DGCRUCell import DGCRUCell @@ -71,7 +69,7 @@ class FedDGCN(nn.Module): self.end_conv2 = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True) self.end_conv3 = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True) - def forward(self, source, i=2): + def forward(self, source): node_embedding1 = self.node_embeddings1 if self.use_D: t_i_d_data = source[..., 1] @@ -152,4 +150,27 @@ class FederatedFedDGCN(nn.Module): # Reshape the outputs into (batchsize, horizon, subgraph_num, subgraph_size, dims) output_tensor = torch.stack(subgraph_outputs, dim=2) # (batchsize, horizon, subgraph_num, subgraph_size, dims) + self.update_main_model() + return output_tensor + + def update_main_model(self): + """ + 更新 main_model 的参数为 model_list 中所有模型参数的平均值。 + """ + # 遍历 main_model 的参数 + with torch.no_grad(): # 确保更新时不会计算梯度 + for name, main_param in self.main_model.named_parameters(): + # 初始化平均值的容器 + avg_param = torch.zeros_like(main_param) + + # 遍历 model_list 中的所有模型 + for model in self.model_list: + # 加上当前模型的对应参数 + avg_param += model.state_dict()[name] + + # 计算平均值 + avg_param /= len(self.model_list) + + # 更新 main_model 的参数 + main_param.copy_(avg_param) diff --git a/scripts/trafficflow_exp_scripts/D4.yaml b/scripts/trafficflow_exp_scripts/D4.yaml index 73d9f70..1dc3aa6 100644 --- a/scripts/trafficflow_exp_scripts/D4.yaml +++ b/scripts/trafficflow_exp_scripts/D4.yaml @@ -45,7 +45,7 @@ model: use_day: True use_week: True use_minigraph: True - minigraph_size: 3 + minigraph_size: 10 train: batch_or_epoch: 'epoch' local_update_steps: 1