model subgraph

HengZhang 2024-11-28 11:46:32 +08:00
parent e95c13f0fc
commit 5fdab2b668
6 changed files with 101 additions and 30 deletions

View File

@@ -0,0 +1,68 @@
import numpy as np

from federatedscope.register import register_scheduler


# LR Scheduler: linear warmup from warmup_lr to base_lr, then cosine decay to final_lr
class LR_Scheduler(object):
    def __init__(self,
                 optimizer,
                 warmup_epochs,
                 warmup_lr,
                 num_epochs,
                 base_lr,
                 final_lr,
                 iter_per_epoch,
                 constant_predictor_lr=False):
        self.base_lr = base_lr
        self.constant_predictor_lr = constant_predictor_lr
        warmup_iter = iter_per_epoch * warmup_epochs
        warmup_lr_schedule = np.linspace(warmup_lr, base_lr, warmup_iter)
        decay_iter = iter_per_epoch * (num_epochs - warmup_epochs)
        cosine_lr_schedule = final_lr + 0.5 * (base_lr - final_lr) * (
            1 + np.cos(np.pi * np.arange(decay_iter) / decay_iter))
        # Pre-compute the per-iteration schedule: warmup followed by cosine decay
        self.lr_schedule = np.concatenate(
            (warmup_lr_schedule, cosine_lr_schedule))
        self.optimizer = optimizer
        self.iter = 0
        self.current_lr = 0

    def step(self):
        # Apply the scheduled learning rate to every param group, then advance one iteration
        for param_group in self.optimizer.param_groups:
            if self.constant_predictor_lr and param_group[
                    'name'] == 'predictor':
                param_group['lr'] = self.base_lr
            else:
                lr = param_group['lr'] = self.lr_schedule[self.iter]
        self.iter += 1
        self.current_lr = lr
        return lr

    def get_lr(self):
        return self.current_lr
def get_scheduler(optimizer, type):
    try:
        import torch.optim as optim
    except ImportError:
        optim = None
    scheduler = None
    if type == 'cos_lr_scheduler':
        if optim is not None:
            # Use the warmup + cosine schedule implemented by LR_Scheduler above
            scheduler = LR_Scheduler(optimizer,
                                     warmup_epochs=0,
                                     warmup_lr=0,
                                     num_epochs=50,
                                     base_lr=30,
                                     final_lr=0,
                                     iter_per_epoch=int(50000 / 512))
    return scheduler


register_scheduler('cos_lr_scheduler', get_scheduler)
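For reference, a minimal sketch (not part of the commit) of how the LR_Scheduler above can be driven once this file is loaded; the toy linear model, SGD optimizer, and hyperparameter values are placeholders chosen for illustration, and LR_Scheduler is assumed to be in scope:

import torch

# Placeholder model and optimizer purely for illustration.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

scheduler = LR_Scheduler(optimizer,
                         warmup_epochs=5,
                         warmup_lr=0.0,
                         num_epochs=50,
                         base_lr=0.1,
                         final_lr=0.0,
                         iter_per_epoch=100)

for _ in range(10):
    optimizer.step()
    lr = scheduler.step()  # advance one iteration of linear warmup + cosine decay
print(scheduler.get_lr())  # current learning rate after 10 iterations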

View File

@@ -2,7 +2,7 @@ from torch.nn import ModuleList
import torch
import torch.nn as nn
from federatedscope.trafficflow.model.DGCRUCell import DGCRUCell
import time

class DGCRM(nn.Module):
    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
@@ -25,8 +25,7 @@ class DGCRM(nn.Module):
            state = init_state[i]
            inner_states = []
            for t in range(seq_length):
-                state = self.DGCRM_cells[i](current_inputs[:, t, :, :], state,
-                                            [node_embeddings[0][:, t, :, :], node_embeddings[1]])
+                state = self.DGCRM_cells[i](current_inputs[:, t, :, :], state, [node_embeddings[0][:, t, :, :], node_embeddings[1]])
                inner_states.append(state)
            output_hidden.append(state)
            current_inputs = torch.stack(inner_states, dim=1)
@@ -36,8 +35,7 @@ class DGCRM(nn.Module):
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.DGCRM_cells[i].init_hidden_state(batch_size))
        return torch.stack(init_states, dim=0)  # (num_layers, B, N, hidden_dim)

# Build your torch or tf model class here
class FedDGCN(nn.Module):
@@ -52,7 +50,7 @@ class FedDGCN(nn.Module):
        self.num_layers = args.num_layers
        self.use_D = args.use_day
        self.use_W = args.use_week
        self.dropout1 = nn.Dropout(p=args.dropout)  # 0.1
        self.dropout2 = nn.Dropout(p=args.dropout)
        self.node_embeddings1 = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True)
        self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True)
@@ -75,16 +73,16 @@ class FedDGCN(nn.Module):
    def forward(self, source):
        node_embedding1 = self.node_embeddings1
        if self.use_D:
            t_i_d_data = source[..., 1]
            T_i_D_emb = self.T_i_D_emb[(t_i_d_data * 288).type(torch.LongTensor)]
            node_embedding1 = torch.mul(node_embedding1, T_i_D_emb)

        if self.use_W:
            d_i_w_data = source[..., 2]
            D_i_W_emb = self.D_i_W_emb[(d_i_w_data).type(torch.LongTensor)]
            node_embedding1 = torch.mul(node_embedding1, D_i_W_emb)

        node_embeddings = [node_embedding1, self.node_embeddings1]

        source = source[..., 0].unsqueeze(-1)
@@ -133,6 +131,7 @@ class FederatedFedDGCN(nn.Module):
        subgraph_outputs = []

        # Iterate through the subgraph models
+        # Parallel computation has not been implemented yet, so this may be slower than normal.
        for i in range(self.subgraph_num):
            # Extract the subgraph-specific data
            subgraph_data = source[:, :, i, :, :]  # (batchsize, horizon, subgraph_size, dims)
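To make the loop above concrete, here is a hedged sketch (an assumption, not code from this commit) of how the per-subgraph forward pass and the later stacking fit together; the function name federated_forward is illustrative, and each entry of model_list is assumed to map a (batchsize, horizon, subgraph_size, dims) slice to an output of the same leading shape:

import torch

def federated_forward(model_list, source):
    # source: (batchsize, horizon, subgraph_num, subgraph_size, dims)
    subgraph_outputs = []
    for i, model in enumerate(model_list):
        subgraph_data = source[:, :, i, :, :]  # slice out the data for subgraph i
        subgraph_outputs.append(model(subgraph_data))
    # Restack along dim=2 into (batchsize, horizon, subgraph_num, subgraph_size, dims)
    return torch.stack(subgraph_outputs, dim=2)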
@@ -143,28 +142,28 @@
        # Reshape the outputs into (batchsize, horizon, subgraph_num, subgraph_size, dims)
        output_tensor = torch.stack(subgraph_outputs, dim=2)  # (batchsize, horizon, subgraph_num, subgraph_size, dims)
+        self.local_aggregate()
+        # self.update_main_model()
        return output_tensor

-    def update_main_model(self):
-        """
-        Update the parameters of main_model to the average of all model parameters in model_list.
-        """
-        # Iterate over the parameters of main_model
-        with torch.no_grad():  # Ensure no gradients are computed during the update
-            for name, main_param in self.main_model.named_parameters():
-                # Initialize a container for the average value
-                avg_param = torch.zeros_like(main_param)
-                # Iterate over all models in model_list
-                for model in self.model_list:
-                    # Add the corresponding parameter of the current model
-                    avg_param += model.state_dict()[name]
-                # Calculate the average
-                avg_param /= len(self.model_list)
-                # Update the parameter of main_model
-                main_param.copy_(avg_param)
+    def local_aggregate(self):
+        """
+        Update the parameters of each model in model_list to the average of all models' parameters.
+        """
+        with torch.no_grad():  # Ensure no gradients are calculated during the update
+            # Iterate over each model in model_list
+            for i, model in enumerate(self.model_list):
+                # Iterate over each model's parameters
+                for name, param in model.named_parameters():
+                    # Initialize a container for the average value
+                    avg_param = torch.zeros_like(param)
+                    # Accumulate the corresponding parameters from all other models
+                    for other_model in self.model_list:
+                        avg_param += other_model.state_dict()[name]
+                    # Calculate the average
+                    avg_param /= len(self.model_list)
+                    # Update the current model's parameter
+                    param.data.copy_(avg_param)
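For context, a minimal standalone sketch (an assumption, not code from this commit) of the parameter-averaging idea behind local_aggregate; the helper name average_models and the toy Linear models are illustrative only:

import torch
import torch.nn as nn


def average_models(model_list):
    """Overwrite every model in model_list with the element-wise mean of all models' parameters."""
    with torch.no_grad():
        state_dicts = [m.state_dict() for m in model_list]
        # Compute the mean of each parameter tensor once across all models.
        avg_state = {name: torch.stack([sd[name] for sd in state_dicts]).mean(dim=0)
                     for name in state_dicts[0]}
        # Copy the averaged parameters back into every model.
        for m in model_list:
            m.load_state_dict(avg_state)


models = [nn.Linear(3, 2) for _ in range(4)]
average_models(models)

Computing each mean once is equivalent to the per-model accumulation in the diff above and avoids recomputing the same average for every model in the list.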

View File

@@ -42,6 +42,8 @@ model:
  cheb_order: 2
  use_day: True
  use_week: True
+  use_minigraph: False
+  minigraph_size: 10
train:
  batch_or_epoch: 'epoch'
  local_update_steps: 1

View File

@@ -44,7 +44,7 @@ model:
  cheb_order: 2
  use_day: True
  use_week: True
-  use_minigraph: True
+  use_minigraph: False
  minigraph_size: 10
train:
  batch_or_epoch: 'epoch'

View File

@@ -42,6 +42,8 @@ model:
  cheb_order: 2
  use_day: True
  use_week: True
+  use_minigraph: False
+  minigraph_size: 10
train:
  batch_or_epoch: 'epoch'
  local_update_steps: 1

View File

@@ -42,8 +42,8 @@ model:
  cheb_order: 2
  use_day: True
  use_week: True
-  use_minigraph: True
-  minigraph_size: 5
+  use_minigraph: False
+  minigraph_size: 10
train:
  batch_or_epoch: 'epoch'
  local_update_steps: 1