fix dataloader
This commit is contained in:
parent
0b41f04d3c
commit
9cfe3f01dc
|
|
@ -205,7 +205,7 @@ def split_into_mini_graphs(tensor, graph_size, dummy_value=0):
|
||||||
graph_num = (node_num + graph_size - 1) // graph_size # Round up division
|
graph_num = (node_num + graph_size - 1) // graph_size # Round up division
|
||||||
|
|
||||||
# Initialize output tensor with dummy values
|
# Initialize output tensor with dummy values
|
||||||
output = np.full((timestep, horizon, graph_size, graph_num, dim), dummy_value, dtype=tensor.dtype)
|
output = np.full((timestep, horizon, graph_num, graph_size, dim), dummy_value, dtype=tensor.dtype)
|
||||||
|
|
||||||
# Fill in the real data
|
# Fill in the real data
|
||||||
for i in range(graph_num):
|
for i in range(graph_num):
|
||||||
|
|
@ -213,8 +213,8 @@ def split_into_mini_graphs(tensor, graph_size, dummy_value=0):
|
||||||
end_idx = min(start_idx + graph_size, node_num) # Ensure we don't exceed the node number
|
end_idx = min(start_idx + graph_size, node_num) # Ensure we don't exceed the node number
|
||||||
slice_size = end_idx - start_idx
|
slice_size = end_idx - start_idx
|
||||||
|
|
||||||
# Assign the data to the corresponding mini-graph (adjusted indexing)
|
# Assign the data to the corresponding mini-graph
|
||||||
output[:, :, :slice_size, i, :] = tensor[:, :, start_idx:end_idx, :]
|
output[:, :, i, :slice_size, :] = tensor[:, :, start_idx:end_idx, :]
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,4 @@
|
||||||
from torch.nn import ModuleList
|
from torch.nn import ModuleList
|
||||||
|
|
||||||
from federatedscope.register import register_model
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from federatedscope.trafficflow.model.DGCRUCell import DGCRUCell
|
from federatedscope.trafficflow.model.DGCRUCell import DGCRUCell
|
||||||
|
|
@ -71,7 +69,7 @@ class FedDGCN(nn.Module):
|
||||||
self.end_conv2 = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
|
self.end_conv2 = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
|
||||||
self.end_conv3 = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
|
self.end_conv3 = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
|
||||||
|
|
||||||
def forward(self, source, i=2):
|
def forward(self, source):
|
||||||
node_embedding1 = self.node_embeddings1
|
node_embedding1 = self.node_embeddings1
|
||||||
if self.use_D:
|
if self.use_D:
|
||||||
t_i_d_data = source[..., 1]
|
t_i_d_data = source[..., 1]
|
||||||
|
|
@ -152,4 +150,27 @@ class FederatedFedDGCN(nn.Module):
|
||||||
# Reshape the outputs into (batchsize, horizon, subgraph_num, subgraph_size, dims)
|
# Reshape the outputs into (batchsize, horizon, subgraph_num, subgraph_size, dims)
|
||||||
output_tensor = torch.stack(subgraph_outputs, dim=2) # (batchsize, horizon, subgraph_num, subgraph_size, dims)
|
output_tensor = torch.stack(subgraph_outputs, dim=2) # (batchsize, horizon, subgraph_num, subgraph_size, dims)
|
||||||
|
|
||||||
|
self.update_main_model()
|
||||||
|
|
||||||
return output_tensor
|
return output_tensor
|
||||||
|
|
||||||
|
def update_main_model(self):
|
||||||
|
"""
|
||||||
|
更新 main_model 的参数为 model_list 中所有模型参数的平均值。
|
||||||
|
"""
|
||||||
|
# 遍历 main_model 的参数
|
||||||
|
with torch.no_grad(): # 确保更新时不会计算梯度
|
||||||
|
for name, main_param in self.main_model.named_parameters():
|
||||||
|
# 初始化平均值的容器
|
||||||
|
avg_param = torch.zeros_like(main_param)
|
||||||
|
|
||||||
|
# 遍历 model_list 中的所有模型
|
||||||
|
for model in self.model_list:
|
||||||
|
# 加上当前模型的对应参数
|
||||||
|
avg_param += model.state_dict()[name]
|
||||||
|
|
||||||
|
# 计算平均值
|
||||||
|
avg_param /= len(self.model_list)
|
||||||
|
|
||||||
|
# 更新 main_model 的参数
|
||||||
|
main_param.copy_(avg_param)
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@ model:
|
||||||
use_day: True
|
use_day: True
|
||||||
use_week: True
|
use_week: True
|
||||||
use_minigraph: True
|
use_minigraph: True
|
||||||
minigraph_size: 3
|
minigraph_size: 10
|
||||||
train:
|
train:
|
||||||
batch_or_epoch: 'epoch'
|
batch_or_epoch: 'epoch'
|
||||||
local_update_steps: 1
|
local_update_steps: 1
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue