fix dataloader
parent 0b41f04d3c
commit 9cfe3f01dc
@@ -205,7 +205,7 @@ def split_into_mini_graphs(tensor, graph_size, dummy_value=0):
     graph_num = (node_num + graph_size - 1) // graph_size  # Round up division

     # Initialize output tensor with dummy values
-    output = np.full((timestep, horizon, graph_size, graph_num, dim), dummy_value, dtype=tensor.dtype)
+    output = np.full((timestep, horizon, graph_num, graph_size, dim), dummy_value, dtype=tensor.dtype)

     # Fill in the real data
     for i in range(graph_num):
@@ -213,8 +213,8 @@ def split_into_mini_graphs(tensor, graph_size, dummy_value=0):
         end_idx = min(start_idx + graph_size, node_num)  # Ensure we don't exceed the node number
         slice_size = end_idx - start_idx

-        # Assign the data to the corresponding mini-graph (adjusted indexing)
-        output[:, :, :slice_size, i, :] = tensor[:, :, start_idx:end_idx, :]
+        # Assign the data to the corresponding mini-graph
+        output[:, :, i, :slice_size, :] = tensor[:, :, start_idx:end_idx, :]

     return output
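For context on the fix above: the old layout put the per-graph node axis before the mini-graph axis, so the fill loop wrote each slice into the wrong dimension. A minimal runnable sketch of the corrected behavior, assuming an input layout of (timestep, horizon, node_num, dim) and inferring start_idx = i * graph_size from the loop structure (both are assumptions, not shown in the hunks):

import numpy as np

def split_into_mini_graphs(tensor, graph_size, dummy_value=0):
    # tensor: (timestep, horizon, node_num, dim)
    timestep, horizon, node_num, dim = tensor.shape
    graph_num = (node_num + graph_size - 1) // graph_size  # Round up division
    # Mini-graph axis now comes before the per-graph node axis
    output = np.full((timestep, horizon, graph_num, graph_size, dim),
                     dummy_value, dtype=tensor.dtype)
    for i in range(graph_num):
        start_idx = i * graph_size  # assumed; not visible in the hunks
        end_idx = min(start_idx + graph_size, node_num)
        slice_size = end_idx - start_idx
        output[:, :, i, :slice_size, :] = tensor[:, :, start_idx:end_idx, :]
    return output

x = np.arange(2 * 1 * 7 * 1).reshape(2, 1, 7, 1)  # 7 nodes, 2 timesteps
y = split_into_mini_graphs(x, graph_size=3)
print(y.shape)  # (2, 1, 3, 3, 1): 3 mini-graphs of 3 slots, last one padded with dummies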
@@ -1,6 +1,4 @@
-from torch.nn import ModuleList
-
 from federatedscope.register import register_model
 import torch
 import torch.nn as nn
 from federatedscope.trafficflow.model.DGCRUCell import DGCRUCell
@@ -71,7 +69,7 @@ class FedDGCN(nn.Module):
         self.end_conv2 = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
         self.end_conv3 = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)

-    def forward(self, source, i=2):
+    def forward(self, source):
         node_embedding1 = self.node_embeddings1
         if self.use_D:
             t_i_d_data = source[..., 1]
@@ -152,4 +150,27 @@ class FederatedFedDGCN(nn.Module):
         # Reshape the outputs into (batchsize, horizon, subgraph_num, subgraph_size, dims)
         output_tensor = torch.stack(subgraph_outputs, dim=2)  # (batchsize, horizon, subgraph_num, subgraph_size, dims)

+        self.update_main_model()
+
         return output_tensor
+
+    def update_main_model(self):
+        """
+        Set the parameters of main_model to the average of the corresponding parameters of all models in model_list.
+        """
+        # Iterate over the parameters of main_model
+        with torch.no_grad():  # Make sure the update does not track gradients
+            for name, main_param in self.main_model.named_parameters():
+                # Container for the running sum, initialized to zeros
+                avg_param = torch.zeros_like(main_param)
+
+                # Iterate over all models in model_list
+                for model in self.model_list:
+                    # Add the corresponding parameter of the current model
+                    avg_param += model.state_dict()[name]
+
+                # Compute the average
+                avg_param /= len(self.model_list)
+
+                # Update the parameter of main_model
+                main_param.copy_(avg_param)
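The added method is a FedAvg-style parameter average over the sub-models. A self-contained sanity check of that averaging logic, using two hypothetical nn.Linear sub-models in place of the repo's model_list (names and shapes here are illustrative only, not from the repo):

import torch
import torch.nn as nn

model_list = [nn.Linear(4, 2) for _ in range(2)]  # stand-ins for the sub-models
main_model = nn.Linear(4, 2)                      # stand-in for main_model

with torch.no_grad():
    for name, main_param in main_model.named_parameters():
        avg_param = torch.zeros_like(main_param)
        for model in model_list:
            avg_param += model.state_dict()[name]
        avg_param /= len(model_list)
        main_param.copy_(avg_param)

# The result equals the element-wise mean of the sub-model weights
expected = (model_list[0].weight + model_list[1].weight) / 2
assert torch.allclose(main_model.weight, expected)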
@@ -45,7 +45,7 @@ model:
   use_day: True
   use_week: True
   use_minigraph: True
-  minigraph_size: 3
+  minigraph_size: 10
 train:
   batch_or_epoch: 'epoch'
   local_update_steps: 1
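The effect of the minigraph_size bump follows directly from the round-up division in split_into_mini_graphs; a quick check for a hypothetical 307-node graph (the size of PEMS04, used here purely as an example):

# graph_num = ceil(node_num / minigraph_size); dummy slots = graph_num * size - node_num
for size in (3, 10):
    graph_num = (307 + size - 1) // size
    dummy = graph_num * size - 307
    print(f"minigraph_size={size}: {graph_num} mini-graphs, {dummy} dummy node slots")
# minigraph_size=3: 103 mini-graphs, 2 dummy node slots
# minigraph_size=10: 31 mini-graphs, 3 dummy node slots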