model subgraph
parent e95c13f0fc
commit 5fdab2b668
@@ -0,0 +1,68 @@
import numpy as np
from federatedscope.register import register_scheduler


# LR Scheduler: linear warmup followed by cosine decay, stepped once per iteration
class LR_Scheduler(object):
    def __init__(self,
                 optimizer,
                 warmup_epochs,
                 warmup_lr,
                 num_epochs,
                 base_lr,
                 final_lr,
                 iter_per_epoch,
                 constant_predictor_lr=False):
        self.base_lr = base_lr
        self.constant_predictor_lr = constant_predictor_lr
        # Linear warmup from warmup_lr to base_lr
        warmup_iter = iter_per_epoch * warmup_epochs
        warmup_lr_schedule = np.linspace(warmup_lr, base_lr, warmup_iter)
        # Cosine decay from base_lr to final_lr over the remaining iterations
        decay_iter = iter_per_epoch * (num_epochs - warmup_epochs)
        cosine_lr_schedule = final_lr + 0.5 * (base_lr - final_lr) * (
            1 + np.cos(np.pi * np.arange(decay_iter) / decay_iter))

        self.lr_schedule = np.concatenate(
            (warmup_lr_schedule, cosine_lr_schedule))
        self.optimizer = optimizer
        self.iter = 0
        self.current_lr = 0

    def step(self):
        # Assign the learning rate for the current iteration to every parameter group
        for param_group in self.optimizer.param_groups:

            if self.constant_predictor_lr and param_group[
                    'name'] == 'predictor':
                param_group['lr'] = self.base_lr
            else:
                lr = param_group['lr'] = self.lr_schedule[self.iter]

        self.iter += 1
        self.current_lr = lr
        return lr

    def get_lr(self):
        return self.current_lr


def get_scheduler(optimizer, type):
    try:
        import torch.optim as optim
    except ImportError:
        optim = None
    scheduler = None

    if type == 'cos_lr_scheduler':
        if optim is not None:
            # Build the warmup + cosine LR_Scheduler defined above
            scheduler = LR_Scheduler(optimizer,
                                     warmup_epochs=0,
                                     warmup_lr=0,
                                     num_epochs=50,
                                     base_lr=30,
                                     final_lr=0,
                                     iter_per_epoch=int(50000 / 512))
    return scheduler


register_scheduler('cos_lr_scheduler', get_scheduler)

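A minimal usage sketch for the scheduler registered above; the toy parameter, SGD optimizer, and iteration count below are illustrative assumptions, not part of this commit:

import torch
import torch.nn as nn

# Toy parameter and optimizer; lr=30 mirrors the base_lr hard-coded in get_scheduler
param = nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=30)

scheduler = get_scheduler(optimizer, 'cos_lr_scheduler')
for it in range(5):
    optimizer.step()
    lr = scheduler.step()  # warmup_epochs=0 here, so this starts directly on the cosine curve
    print(it, float(lr))
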
@@ -2,7 +2,7 @@ from torch.nn import ModuleList
import torch
import torch.nn as nn
from federatedscope.trafficflow.model.DGCRUCell import DGCRUCell

import time

class DGCRM(nn.Module):
    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):

@@ -25,8 +25,7 @@ class DGCRM(nn.Module):
            state = init_state[i]
            inner_states = []
            for t in range(seq_length):
                state = self.DGCRM_cells[i](current_inputs[:, t, :, :], state,
                                            [node_embeddings[0][:, t, :, :], node_embeddings[1]])
                state = self.DGCRM_cells[i](current_inputs[:, t, :, :], state, [node_embeddings[0][:, t, :, :], node_embeddings[1]])
                inner_states.append(state)
            output_hidden.append(state)
            current_inputs = torch.stack(inner_states, dim=1)

@@ -36,8 +35,7 @@ class DGCRM(nn.Module):
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.DGCRM_cells[i].init_hidden_state(batch_size))
        return torch.stack(init_states, dim=0)  # (num_layers, B, N, hidden_dim)

        return torch.stack(init_states, dim=0)  #(num_layers, B, N, hidden_dim)

# Build your torch or tf model class here
class FedDGCN(nn.Module):

@@ -84,7 +82,7 @@ class FedDGCN(nn.Module):
            D_i_W_emb = self.D_i_W_emb[(d_i_w_data).type(torch.LongTensor)]
            node_embedding1 = torch.mul(node_embedding1, D_i_W_emb)

        node_embeddings = [node_embedding1, self.node_embeddings1]
        node_embeddings=[node_embedding1,self.node_embeddings1]

        source = source[..., 0].unsqueeze(-1)

@@ -133,6 +131,7 @@ class FederatedFedDGCN(nn.Module):
        subgraph_outputs = []

        # Iterate through the subgraph models
        # Parallel computation has not been implemented yet, so this may be slower than normal
        for i in range(self.subgraph_num):
            # Extract the subgraph-specific data
            subgraph_data = source[:, :, i, :, :]  # (batchsize, horizon, subgraph_size, dims)

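A minimal shape check for the per-subgraph slicing above and the dim=2 stacking in the next hunk; the tensor sizes are made up for illustration:

import torch

batchsize, horizon, subgraph_num, subgraph_size, dims = 2, 12, 4, 10, 1
source = torch.randn(batchsize, horizon, subgraph_num, subgraph_size, dims)

subgraph_outputs = []
for i in range(subgraph_num):
    subgraph_data = source[:, :, i, :, :]   # (batchsize, horizon, subgraph_size, dims)
    subgraph_outputs.append(subgraph_data)  # stand-in for the i-th subgraph model's output

output_tensor = torch.stack(subgraph_outputs, dim=2)
print(output_tensor.shape)  # torch.Size([2, 12, 4, 10, 1])
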
@@ -143,28 +142,28 @@ class FederatedFedDGCN(nn.Module):

        # Reshape the outputs into (batchsize, horizon, subgraph_num, subgraph_size, dims)
        output_tensor = torch.stack(subgraph_outputs, dim=2)  # (batchsize, horizon, subgraph_num, subgraph_size, dims)

        # self.update_main_model()

        self.local_aggregate()
        return output_tensor

    def update_main_model(self):
    def local_aggregate(self):
        """
        Update the parameters of main_model to the average of all model parameters in model_list.
        Update the parameters of each model in model_list to the average of all models' parameters.
        """
        # Iterate over main_model's parameters
        with torch.no_grad():  # make sure no gradients are computed during the update
            for name, main_param in self.main_model.named_parameters():
                # Initialize a container for the average value
                avg_param = torch.zeros_like(main_param)
        with torch.no_grad():  # Ensure no gradients are calculated during the update
            # Iterate over each model in model_list
            for i, model in enumerate(self.model_list):
                # Iterate over each model's parameters
                for name, param in model.named_parameters():
                    # Initialize a container for the average value
                    avg_param = torch.zeros_like(param)

                # Iterate over all models in model_list
                for model in self.model_list:
                    # Add in the corresponding parameter of the current model
                    avg_param += model.state_dict()[name]
                    # Accumulate the corresponding parameters from all other models
                    for other_model in self.model_list:
                        avg_param += other_model.state_dict()[name]

                # Calculate the average
                    # Calculate the average
                avg_param /= len(self.model_list)

                # Update main_model's parameters
                main_param.copy_(avg_param)
                    # Update the current model's parameter
                    param.data.copy_(avg_param)


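A small self-contained sketch of the parameter-averaging idea behind local_aggregate, using two toy Linear layers in place of the subgraph models (the layer sizes are illustrative assumptions):

import torch
import torch.nn as nn

# Two toy models standing in for model_list
model_list = [nn.Linear(4, 4, bias=False) for _ in range(2)]

with torch.no_grad():  # no gradients are needed for the aggregation
    for name, param in model_list[0].named_parameters():
        # Average the corresponding parameter over every model in the list ...
        avg_param = torch.zeros_like(param)
        for other_model in model_list:
            avg_param += other_model.state_dict()[name]
        avg_param /= len(model_list)
        # ... and write the average back into every model
        for other_model in model_list:
            other_model.state_dict()[name].copy_(avg_param)

# After aggregation, both models hold identical weights
print(torch.equal(model_list[0].weight, model_list[1].weight))  # True
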
@@ -42,6 +42,8 @@ model:
  cheb_order: 2
  use_day: True
  use_week: True
  use_minigraph: False
  minigraph_size: 10
train:
  batch_or_epoch: 'epoch'
  local_update_steps: 1

@@ -44,7 +44,7 @@ model:
  cheb_order: 2
  use_day: True
  use_week: True
  use_minigraph: True
  use_minigraph: False
  minigraph_size: 10
train:
  batch_or_epoch: 'epoch'

@@ -42,6 +42,8 @@ model:
  cheb_order: 2
  use_day: True
  use_week: True
  use_minigraph: False
  minigraph_size: 10
train:
  batch_or_epoch: 'epoch'
  local_update_steps: 1

@@ -42,8 +42,8 @@ model:
  cheb_order: 2
  use_day: True
  use_week: True
  use_minigraph: True
  minigraph_size: 5
  use_minigraph: False
  minigraph_size: 10
train:
  batch_or_epoch: 'epoch'
  local_update_steps: 1

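The four config hunks above flip the minigraph options in the example configs. A quick sanity-check sketch with PyYAML, independent of how FederatedScope itself loads these files (the file name is hypothetical):

import yaml

# Hypothetical file containing the model/train blocks shown above
with open('feddgcn_example.yaml') as f:
    cfg = yaml.safe_load(f)

assert cfg['model']['use_minigraph'] is False
assert cfg['model']['minigraph_size'] == 10
print(cfg['train']['batch_or_epoch'], cfg['train']['local_update_steps'])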