fix dataloader

This commit is contained in:
HengZhang 2024-11-28 10:43:14 +08:00
parent 0b41f04d3c
commit 9cfe3f01dc
3 changed files with 28 additions and 7 deletions

View File

@ -205,7 +205,7 @@ def split_into_mini_graphs(tensor, graph_size, dummy_value=0):
graph_num = (node_num + graph_size - 1) // graph_size # Round up division
# Initialize output tensor with dummy values
output = np.full((timestep, horizon, graph_size, graph_num, dim), dummy_value, dtype=tensor.dtype)
output = np.full((timestep, horizon, graph_num, graph_size, dim), dummy_value, dtype=tensor.dtype)
# Fill in the real data
for i in range(graph_num):
@ -213,8 +213,8 @@ def split_into_mini_graphs(tensor, graph_size, dummy_value=0):
end_idx = min(start_idx + graph_size, node_num) # Ensure we don't exceed the node number
slice_size = end_idx - start_idx
# Assign the data to the corresponding mini-graph (adjusted indexing)
output[:, :, :slice_size, i, :] = tensor[:, :, start_idx:end_idx, :]
# Assign the data to the corresponding mini-graph
output[:, :, i, :slice_size, :] = tensor[:, :, start_idx:end_idx, :]
return output

View File

@ -1,6 +1,4 @@
from torch.nn import ModuleList
from federatedscope.register import register_model
import torch
import torch.nn as nn
from federatedscope.trafficflow.model.DGCRUCell import DGCRUCell
@ -71,7 +69,7 @@ class FedDGCN(nn.Module):
self.end_conv2 = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
self.end_conv3 = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
def forward(self, source, i=2):
def forward(self, source):
node_embedding1 = self.node_embeddings1
if self.use_D:
t_i_d_data = source[..., 1]
@ -152,4 +150,27 @@ class FederatedFedDGCN(nn.Module):
# Reshape the outputs into (batchsize, horizon, subgraph_num, subgraph_size, dims)
output_tensor = torch.stack(subgraph_outputs, dim=2) # (batchsize, horizon, subgraph_num, subgraph_size, dims)
self.update_main_model()
return output_tensor
def update_main_model(self):
"""
更新 main_model 的参数为 model_list 中所有模型参数的平均值
"""
# 遍历 main_model 的参数
with torch.no_grad(): # 确保更新时不会计算梯度
for name, main_param in self.main_model.named_parameters():
# 初始化平均值的容器
avg_param = torch.zeros_like(main_param)
# 遍历 model_list 中的所有模型
for model in self.model_list:
# 加上当前模型的对应参数
avg_param += model.state_dict()[name]
# 计算平均值
avg_param /= len(self.model_list)
# 更新 main_model 的参数
main_param.copy_(avg_param)

View File

@ -45,7 +45,7 @@ model:
use_day: True
use_week: True
use_minigraph: True
minigraph_size: 3
minigraph_size: 10
train:
batch_or_epoch: 'epoch'
local_update_steps: 1