model subgraph
parent e95c13f0fc
commit 5fdab2b668
@@ -0,0 +1,68 @@
import numpy as np

from federatedscope.register import register_scheduler


# LR Scheduler
class LR_Scheduler(object):
    def __init__(self,
                 optimizer,
                 warmup_epochs,
                 warmup_lr,
                 num_epochs,
                 base_lr,
                 final_lr,
                 iter_per_epoch,
                 constant_predictor_lr=False):
        self.base_lr = base_lr
        self.constant_predictor_lr = constant_predictor_lr
        warmup_iter = iter_per_epoch * warmup_epochs
        warmup_lr_schedule = np.linspace(warmup_lr, base_lr, warmup_iter)
        decay_iter = iter_per_epoch * (num_epochs - warmup_epochs)
        cosine_lr_schedule = final_lr + 0.5 * (base_lr - final_lr) * (
            1 + np.cos(np.pi * np.arange(decay_iter) / decay_iter))

        self.lr_schedule = np.concatenate(
            (warmup_lr_schedule, cosine_lr_schedule))
        self.optimizer = optimizer
        self.iter = 0
        self.current_lr = 0

    def step(self):
        for param_group in self.optimizer.param_groups:

            if self.constant_predictor_lr and param_group[
                    'name'] == 'predictor':
                param_group['lr'] = self.base_lr
            else:
                lr = param_group['lr'] = self.lr_schedule[self.iter]

        self.iter += 1
        self.current_lr = lr
        return lr

    def get_lr(self):
        return self.current_lr


def get_scheduler(optimizer, type):
    try:
        import torch.optim as optim
    except ImportError:
        optim = None
    scheduler = None

    if type == 'cos_lr_scheduler':
        if optim is not None:
            # LambdaLR does not accept warmup/cosine keyword arguments, so the
            # warmup-plus-cosine schedule defined above is instantiated directly.
            scheduler = LR_Scheduler(optimizer,
                                     warmup_epochs=0,
                                     warmup_lr=0,
                                     num_epochs=50,
                                     base_lr=30,
                                     final_lr=0,
                                     iter_per_epoch=int(50000 / 512))
    return scheduler


register_scheduler('cos_lr_scheduler', get_scheduler)
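A minimal usage sketch for the scheduler registered above (the model, optimizer, and epoch/iteration counts here are illustrative assumptions, not values from this commit):

    import torch

    model = torch.nn.Linear(10, 1)                       # hypothetical model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    iter_per_epoch = 100
    scheduler = LR_Scheduler(optimizer,
                             warmup_epochs=5, warmup_lr=0.0,
                             num_epochs=50, base_lr=0.1, final_lr=0.0,
                             iter_per_epoch=iter_per_epoch)

    for epoch in range(50):
        for _ in range(iter_per_epoch):
            # forward / backward / optimizer.step() would go here
            lr = scheduler.step()   # linear warmup, then cosine decay to final_lr

Each call to step() sets the learning rate of every param_group to the next value of the precomputed warmup-plus-cosine schedule.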
@@ -2,7 +2,7 @@ from torch.nn import ModuleList
import torch
import torch.nn as nn
from federatedscope.trafficflow.model.DGCRUCell import DGCRUCell
+import time

class DGCRM(nn.Module):
    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
@@ -25,8 +25,7 @@ class DGCRM(nn.Module):
            state = init_state[i]
            inner_states = []
            for t in range(seq_length):
-                state = self.DGCRM_cells[i](current_inputs[:, t, :, :], state,
-                                            [node_embeddings[0][:, t, :, :], node_embeddings[1]])
+                state = self.DGCRM_cells[i](current_inputs[:, t, :, :], state, [node_embeddings[0][:, t, :, :], node_embeddings[1]])
                inner_states.append(state)
            output_hidden.append(state)
            current_inputs = torch.stack(inner_states, dim=1)
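For orientation, a self-contained sketch of the per-time-step slicing performed by the joined call above (all shapes are assumptions inferred from the indexing, not stated in the commit):

    import torch

    B, T, N, dim_in, embed_dim = 2, 12, 4, 1, 8           # illustrative sizes
    current_inputs = torch.randn(B, T, N, dim_in)
    node_embeddings = [torch.randn(B, T, N, embed_dim),   # time-varying embedding
                       torch.randn(N, embed_dim)]         # static node embedding

    t = 0
    step_input = current_inputs[:, t, :, :]                # (B, N, dim_in)
    step_embeddings = [node_embeddings[0][:, t, :, :],     # (B, N, embed_dim)
                       node_embeddings[1]]                 # (N, embed_dim)
    # step_input and step_embeddings are what each DGCRUCell receives per time step.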
@@ -36,8 +35,7 @@ class DGCRM(nn.Module):
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.DGCRM_cells[i].init_hidden_state(batch_size))
-        return torch.stack(init_states, dim=0) # (num_layers, B, N, hidden_dim)
+        return torch.stack(init_states, dim=0) #(num_layers, B, N, hidden_dim)


# Build you torch or tf model class here
class FedDGCN(nn.Module):
@@ -52,7 +50,7 @@ class FedDGCN(nn.Module):
        self.num_layers = args.num_layers
        self.use_D = args.use_day
        self.use_W = args.use_week
        self.dropout1 = nn.Dropout(p=args.dropout) # 0.1
        self.dropout2 = nn.Dropout(p=args.dropout)
        self.node_embeddings1 = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True)
        self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True)
@@ -75,16 +73,16 @@ class FedDGCN(nn.Module):
    def forward(self, source):
        node_embedding1 = self.node_embeddings1
        if self.use_D:
            t_i_d_data = source[..., 1]
            T_i_D_emb = self.T_i_D_emb[(t_i_d_data * 288).type(torch.LongTensor)]
            node_embedding1 = torch.mul(node_embedding1, T_i_D_emb)

        if self.use_W:
            d_i_w_data = source[..., 2]
            D_i_W_emb = self.D_i_W_emb[(d_i_w_data).type(torch.LongTensor)]
            node_embedding1 = torch.mul(node_embedding1, D_i_W_emb)

-        node_embeddings = [node_embedding1, self.node_embeddings1]
+        node_embeddings=[node_embedding1,self.node_embeddings1]

        source = source[..., 0].unsqueeze(-1)
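The time-of-day lookup above indexes an embedding table with (t_i_d_data * 288), i.e. it assumes source[..., 1] holds the time of day normalized to [0, 1) at 288 slots per day (5-minute resolution). A minimal, self-contained sketch of that indexing and the broadcasted multiplication (shapes and sizes are illustrative assumptions):

    import torch

    B, T, N, embed_dim = 2, 12, 4, 8              # illustrative sizes
    T_i_D_emb = torch.randn(288, embed_dim)       # one embedding per 5-minute slot
    node_embeddings1 = torch.randn(N, embed_dim)  # static node embedding

    t_i_d_data = torch.rand(B, T, N)                        # normalized time of day
    idx = (t_i_d_data * 288).type(torch.LongTensor)         # (B, T, N) slot indices
    tid_emb = T_i_D_emb[idx]                                # (B, T, N, embed_dim)
    node_embedding1 = torch.mul(node_embeddings1, tid_emb)  # broadcasts to (B, T, N, embed_dim)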
@@ -133,6 +131,7 @@ class FederatedFedDGCN(nn.Module):
        subgraph_outputs = []

        # Iterate through the subgraph models
+        # Parallel computation has not been realized yet, so it may be slower than normal.
        for i in range(self.subgraph_num):
            # Extract the subgraph-specific data
            subgraph_data = source[:, :, i, :, :] # (batchsize, horizon, subgraph_size, dims)
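A self-contained sketch of the per-subgraph slicing above and of the stacking back along dim=2 that follows in the next hunk (all sizes are illustrative, and the identity pass-through stands in for the per-subgraph model):

    import torch

    B, horizon, subgraph_num, subgraph_size, dims = 2, 12, 4, 10, 3
    source = torch.randn(B, horizon, subgraph_num, subgraph_size, dims)

    subgraph_outputs = []
    for i in range(subgraph_num):
        subgraph_data = source[:, :, i, :, :]      # (B, horizon, subgraph_size, dims)
        subgraph_outputs.append(subgraph_data)     # stand-in for the subgraph model output

    output_tensor = torch.stack(subgraph_outputs, dim=2)
    assert output_tensor.shape == (B, horizon, subgraph_num, subgraph_size, dims)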
@@ -143,28 +142,28 @@ class FederatedFedDGCN(nn.Module):

        # Reshape the outputs into (batchsize, horizon, subgraph_num, subgraph_size, dims)
        output_tensor = torch.stack(subgraph_outputs, dim=2) # (batchsize, horizon, subgraph_num, subgraph_size, dims)
-        # self.update_main_model()
+        self.local_aggregate()

        return output_tensor

-    def update_main_model(self):
-        """
-        Update the parameters of main_model to the average of all model parameters in model_list.
-        """
-        # Iterate over the parameters of main_model
-        with torch.no_grad():  # Make sure no gradients are computed during the update
-            for name, main_param in self.main_model.named_parameters():
-                # Initialize a container for the average value
-                avg_param = torch.zeros_like(main_param)
-
-                # Iterate over all models in model_list
-                for model in self.model_list:
-                    # Add the corresponding parameter of the current model
-                    avg_param += model.state_dict()[name]
-
-                # Compute the average
-                avg_param /= len(self.model_list)
-
-                # Update the parameter of main_model
-                main_param.copy_(avg_param)
+    def local_aggregate(self):
+        """
+        Update the parameters of each model in model_list to the average of all models' parameters.
+        """
+        with torch.no_grad():  # Ensure no gradients are calculated during the update
+            # Iterate over each model in model_list
+            for i, model in enumerate(self.model_list):
+                # Iterate over each model's parameters
+                for name, param in model.named_parameters():
+                    # Initialize a container for the average value
+                    avg_param = torch.zeros_like(param)
+
+                    # Accumulate the corresponding parameters from all other models
+                    for other_model in self.model_list:
+                        avg_param += other_model.state_dict()[name]
+
+                    # Calculate the average
+                    avg_param /= len(self.model_list)
+
+                    # Update the current model's parameter
+                    param.data.copy_(avg_param)
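A small, self-contained check of the averaging logic in local_aggregate (the Linear models below are stand-ins for the per-subgraph FedDGCN models, purely for illustration):

    import torch
    import torch.nn as nn

    model_list = nn.ModuleList([nn.Linear(3, 3) for _ in range(4)])

    def local_aggregate(model_list):
        with torch.no_grad():
            for model in model_list:
                for name, param in model.named_parameters():
                    avg_param = torch.zeros_like(param)
                    for other_model in model_list:
                        avg_param += other_model.state_dict()[name]
                    avg_param /= len(model_list)
                    param.data.copy_(avg_param)

    local_aggregate(model_list)

Note that because each model's parameters are overwritten in place while the outer loop is still running, later models average over already-updated copies of the earlier ones; snapshotting all state_dicts before copying would make every model end up with the identical mean.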
@@ -42,6 +42,8 @@ model:
  cheb_order: 2
  use_day: True
  use_week: True
+  use_minigraph: False
+  minigraph_size: 10
train:
  batch_or_epoch: 'epoch'
  local_update_steps: 1
@@ -44,7 +44,7 @@ model:
  cheb_order: 2
  use_day: True
  use_week: True
-  use_minigraph: True
+  use_minigraph: False
  minigraph_size: 10
train:
  batch_or_epoch: 'epoch'
@@ -42,6 +42,8 @@ model:
  cheb_order: 2
  use_day: True
  use_week: True
+  use_minigraph: False
+  minigraph_size: 10
train:
  batch_or_epoch: 'epoch'
  local_update_steps: 1
@@ -42,8 +42,8 @@ model:
  cheb_order: 2
  use_day: True
  use_week: True
-  use_minigraph: True
-  minigraph_size: 5
+  use_minigraph: False
+  minigraph_size: 10
train:
  batch_or_epoch: 'epoch'
  local_update_steps: 1
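The config hunks above add use_minigraph and minigraph_size under the model section of several FederatedScope YAML files. A hedged sketch of how such flags might be consumed when building the model (the attribute names mirror the YAML keys and the existing use_day / use_week pattern; the ceil-division rule for subgraph_num is an assumption, not taken from this commit):

    from types import SimpleNamespace

    args = SimpleNamespace(use_minigraph=False, minigraph_size=10)  # values from the YAML above
    num_node = 307  # illustrative graph size

    if args.use_minigraph:
        subgraph_num = (num_node + args.minigraph_size - 1) // args.minigraph_size
    else:
        subgraph_num = 1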