model subgraph

This commit is contained in:
HengZhang 2024-11-28 11:46:32 +08:00
parent e95c13f0fc
commit 5fdab2b668
6 changed files with 101 additions and 30 deletions

View File

@ -0,0 +1,68 @@
import numpy as np
from federatedscope.register import register_scheduler
# LR Scheduler
class LR_Scheduler(object):
def __init__(self,
optimizer,
warmup_epochs,
warmup_lr,
num_epochs,
base_lr,
final_lr,
iter_per_epoch,
constant_predictor_lr=False):
self.base_lr = base_lr
self.constant_predictor_lr = constant_predictor_lr
warmup_iter = iter_per_epoch * warmup_epochs
warmup_lr_schedule = np.linspace(warmup_lr, base_lr, warmup_iter)
decay_iter = iter_per_epoch * (num_epochs - warmup_epochs)
cosine_lr_schedule = final_lr + 0.5 * (base_lr - final_lr) * (
1 + np.cos(np.pi * np.arange(decay_iter) / decay_iter))
self.lr_schedule = np.concatenate(
(warmup_lr_schedule, cosine_lr_schedule))
self.optimizer = optimizer
self.iter = 0
self.current_lr = 0
def step(self):
for param_group in self.optimizer.param_groups:
if self.constant_predictor_lr and param_group[
'name'] == 'predictor':
param_group['lr'] = self.base_lr
else:
lr = param_group['lr'] = self.lr_schedule[self.iter]
self.iter += 1
self.current_lr = lr
return lr
def get_lr(self):
return self.current_lr
def get_scheduler(optimizer, type):
try:
import torch.optim as optim
except ImportError:
optim = None
scheduler = None
if type == 'cos_lr_scheduler':
if optim is not None:
lr_lambda = [lambda epoch: epoch // 30]
scheduler = optim.lr_scheduler.LambdaLR(optimizer,
warmup_epochs=0,
warmup_lr=0,
num_epochs=50,
base_lr=30,
final_lr=0,
iter_per_epoch=int(50000 /
512))
return scheduler
register_scheduler('cos_lr_scheduler', get_scheduler)

View File

@ -2,7 +2,7 @@ from torch.nn import ModuleList
import torch
import torch.nn as nn
from federatedscope.trafficflow.model.DGCRUCell import DGCRUCell
import time
class DGCRM(nn.Module):
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
@ -25,8 +25,7 @@ class DGCRM(nn.Module):
state = init_state[i]
inner_states = []
for t in range(seq_length):
state = self.DGCRM_cells[i](current_inputs[:, t, :, :], state,
[node_embeddings[0][:, t, :, :], node_embeddings[1]])
state = self.DGCRM_cells[i](current_inputs[:, t, :, :], state, [node_embeddings[0][:, t, :, :], node_embeddings[1]])
inner_states.append(state)
output_hidden.append(state)
current_inputs = torch.stack(inner_states, dim=1)
@ -36,8 +35,7 @@ class DGCRM(nn.Module):
init_states = []
for i in range(self.num_layers):
init_states.append(self.DGCRM_cells[i].init_hidden_state(batch_size))
return torch.stack(init_states, dim=0) # (num_layers, B, N, hidden_dim)
return torch.stack(init_states, dim=0) #(num_layers, B, N, hidden_dim)
# Build you torch or tf model class here
class FedDGCN(nn.Module):
@ -84,7 +82,7 @@ class FedDGCN(nn.Module):
D_i_W_emb = self.D_i_W_emb[(d_i_w_data).type(torch.LongTensor)]
node_embedding1 = torch.mul(node_embedding1, D_i_W_emb)
node_embeddings = [node_embedding1, self.node_embeddings1]
node_embeddings=[node_embedding1,self.node_embeddings1]
source = source[..., 0].unsqueeze(-1)
@ -133,6 +131,7 @@ class FederatedFedDGCN(nn.Module):
subgraph_outputs = []
# Iterate through the subgraph models
# Parallel computation has not been realized yet, so it may slower than normal.
for i in range(self.subgraph_num):
# Extract the subgraph-specific data
subgraph_data = source[:, :, i, :, :] # (batchsize, horizon, subgraph_size, dims)
@ -143,28 +142,28 @@ class FederatedFedDGCN(nn.Module):
# Reshape the outputs into (batchsize, horizon, subgraph_num, subgraph_size, dims)
output_tensor = torch.stack(subgraph_outputs, dim=2) # (batchsize, horizon, subgraph_num, subgraph_size, dims)
# self.update_main_model()
self.local_aggregate()
return output_tensor
def update_main_model(self):
def local_aggregate(self):
"""
更新 main_model 的参数为 model_list 中所有模型参数的平均值
Update the parameters of each model in model_list to the average of all models' parameters.
"""
# 遍历 main_model 的参数
with torch.no_grad(): # 确保更新时不会计算梯度
for name, main_param in self.main_model.named_parameters():
# 初始化平均值的容器
avg_param = torch.zeros_like(main_param)
with torch.no_grad(): # Ensure no gradients are calculated during the update
# Iterate over each model in model_list
for i, model in enumerate(self.model_list):
# Iterate over each model's parameters
for name, param in model.named_parameters():
# Initialize a container for the average value
avg_param = torch.zeros_like(param)
# 遍历 model_list 中的所有模型
for model in self.model_list:
# 加上当前模型的对应参数
avg_param += model.state_dict()[name]
# Accumulate the corresponding parameters from all other models
for other_model in self.model_list:
avg_param += other_model.state_dict()[name]
# 计算平均值
# Calculate the average
avg_param /= len(self.model_list)
# 更新 main_model 的参数
main_param.copy_(avg_param)
# Update the current model's parameter
param.data.copy_(avg_param)

View File

@ -42,6 +42,8 @@ model:
cheb_order: 2
use_day: True
use_week: True
use_minigraph: False
minigraph_size: 10
train:
batch_or_epoch: 'epoch'
local_update_steps: 1

View File

@ -44,7 +44,7 @@ model:
cheb_order: 2
use_day: True
use_week: True
use_minigraph: True
use_minigraph: False
minigraph_size: 10
train:
batch_or_epoch: 'epoch'

View File

@ -42,6 +42,8 @@ model:
cheb_order: 2
use_day: True
use_week: True
use_minigraph: False
minigraph_size: 10
train:
batch_or_epoch: 'epoch'
local_update_steps: 1

View File

@ -42,8 +42,8 @@ model:
cheb_order: 2
use_day: True
use_week: True
use_minigraph: True
minigraph_size: 5
use_minigraph: False
minigraph_size: 10
train:
batch_or_epoch: 'epoch'
local_update_steps: 1