From 5fdab2b66895426171808cf2ab1f2cf30bca3578 Mon Sep 17 00:00:00 2001
From: HengZhang
Date: Thu, 28 Nov 2024 11:46:32 +0800
Subject: [PATCH] model subgraph

---
 .../cl/lr_scheduler/LR_Scheduler.py           | 68 +++++++++++++++++++
 federatedscope/trafficflow/model/FedDGCNv2.py | 53 +++++++--------
 scripts/trafficflow_exp_scripts/D3.yaml       |  2 +
 scripts/trafficflow_exp_scripts/D4.yaml       |  2 +-
 scripts/trafficflow_exp_scripts/D7.yaml       |  2 +
 scripts/trafficflow_exp_scripts/D8.yaml       |  4 +-
 6 files changed, 101 insertions(+), 30 deletions(-)
 create mode 100644 federatedscope/cl/lr_scheduler/LR_Scheduler.py

diff --git a/federatedscope/cl/lr_scheduler/LR_Scheduler.py b/federatedscope/cl/lr_scheduler/LR_Scheduler.py
new file mode 100644
index 0000000..90b82a5
--- /dev/null
+++ b/federatedscope/cl/lr_scheduler/LR_Scheduler.py
@@ -0,0 +1,68 @@
+import numpy as np
+from federatedscope.register import register_scheduler
+
+
+# LR Scheduler: linear warmup followed by cosine decay
+class LR_Scheduler(object):
+    def __init__(self,
+                 optimizer,
+                 warmup_epochs,
+                 warmup_lr,
+                 num_epochs,
+                 base_lr,
+                 final_lr,
+                 iter_per_epoch,
+                 constant_predictor_lr=False):
+        self.base_lr = base_lr
+        self.constant_predictor_lr = constant_predictor_lr
+        warmup_iter = iter_per_epoch * warmup_epochs
+        warmup_lr_schedule = np.linspace(warmup_lr, base_lr, warmup_iter)
+        decay_iter = iter_per_epoch * (num_epochs - warmup_epochs)
+        cosine_lr_schedule = final_lr + 0.5 * (base_lr - final_lr) * (
+            1 + np.cos(np.pi * np.arange(decay_iter) / decay_iter))
+
+        self.lr_schedule = np.concatenate(
+            (warmup_lr_schedule, cosine_lr_schedule))
+        self.optimizer = optimizer
+        self.iter = 0
+        self.current_lr = 0
+
+    def step(self):
+        for param_group in self.optimizer.param_groups:
+
+            if self.constant_predictor_lr and param_group[
+                    'name'] == 'predictor':
+                param_group['lr'] = self.base_lr
+            else:
+                lr = param_group['lr'] = self.lr_schedule[self.iter]
+
+        self.iter += 1
+        self.current_lr = lr
+        return lr
+
+    def get_lr(self):
+        return self.current_lr
+
+
+def get_scheduler(optimizer, type):
+    try:
+        import torch.optim as optim
+    except ImportError:
+        optim = None
+    scheduler = None
+
+    if type == 'cos_lr_scheduler':
+        if optim is not None:
+            # Warmup + cosine decay via the LR_Scheduler class above
+            scheduler = LR_Scheduler(optimizer,
+                                     warmup_epochs=0,
+                                     warmup_lr=0,
+                                     num_epochs=50,
+                                     base_lr=30,
+                                     final_lr=0,
+                                     iter_per_epoch=int(50000 /
+                                                        512))
+    return scheduler
+
+
+register_scheduler('cos_lr_scheduler', get_scheduler)
diff --git a/federatedscope/trafficflow/model/FedDGCNv2.py b/federatedscope/trafficflow/model/FedDGCNv2.py
index 1b92e5d..542cb9a 100644
--- a/federatedscope/trafficflow/model/FedDGCNv2.py
+++ b/federatedscope/trafficflow/model/FedDGCNv2.py
@@ -2,7 +2,7 @@ from torch.nn import ModuleList
 import torch
 import torch.nn as nn
 from federatedscope.trafficflow.model.DGCRUCell import DGCRUCell
-
+import time
 
 class DGCRM(nn.Module):
     def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
@@ -25,8 +25,7 @@ class DGCRM(nn.Module):
             state = init_state[i]
             inner_states = []
             for t in range(seq_length):
-                state = self.DGCRM_cells[i](current_inputs[:, t, :, :], state,
-                                            [node_embeddings[0][:, t, :, :], node_embeddings[1]])
+                state = self.DGCRM_cells[i](current_inputs[:, t, :, :], state, [node_embeddings[0][:, t, :, :], node_embeddings[1]])
                 inner_states.append(state)
             output_hidden.append(state)
             current_inputs = torch.stack(inner_states, dim=1)
@@ -36,8 +35,7 @@ class DGCRM(nn.Module):
         init_states = []
         for i in range(self.num_layers):
             init_states.append(self.DGCRM_cells[i].init_hidden_state(batch_size))
-        return torch.stack(init_states, dim=0)  # (num_layers, B, N, hidden_dim)
-
+        return torch.stack(init_states, dim=0) #(num_layers, B, N, hidden_dim)
 
 # Build you torch or tf model class here
 class FedDGCN(nn.Module):
@@ -52,7 +50,7 @@ class FedDGCN(nn.Module):
         self.num_layers = args.num_layers
         self.use_D = args.use_day
         self.use_W = args.use_week
-        self.dropout1 = nn.Dropout(p=args.dropout)  # 0.1
+        self.dropout1 = nn.Dropout(p=args.dropout) # 0.1
         self.dropout2 = nn.Dropout(p=args.dropout)
         self.node_embeddings1 = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True)
         self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True)
@@ -75,16 +73,16 @@ class FedDGCN(nn.Module):
     def forward(self, source):
         node_embedding1 = self.node_embeddings1
         if self.use_D:
-            t_i_d_data = source[..., 1]
+            t_i_d_data = source[..., 1]
             T_i_D_emb = self.T_i_D_emb[(t_i_d_data * 288).type(torch.LongTensor)]
             node_embedding1 = torch.mul(node_embedding1, T_i_D_emb)
         if self.use_W:
-            d_i_w_data = source[..., 2]
+            d_i_w_data = source[..., 2]
             D_i_W_emb = self.D_i_W_emb[(d_i_w_data).type(torch.LongTensor)]
             node_embedding1 = torch.mul(node_embedding1, D_i_W_emb)
 
-        node_embeddings = [node_embedding1, self.node_embeddings1]
+        node_embeddings=[node_embedding1,self.node_embeddings1]
 
         source = source[..., 0].unsqueeze(-1)
 
@@ -133,6 +131,7 @@ class FederatedFedDGCN(nn.Module):
         subgraph_outputs = []
 
         # Iterate through the subgraph models
+        # Parallel computation has not been implemented yet, so this loop may be slower than necessary.
         for i in range(self.subgraph_num):
             # Extract the subgraph-specific data
             subgraph_data = source[:, :, i, :, :]  # (batchsize, horizon, subgraph_size, dims)
@@ -143,28 +142,28 @@
         # Reshape the outputs into (batchsize, horizon, subgraph_num, subgraph_size, dims)
         output_tensor = torch.stack(subgraph_outputs, dim=2)  # (batchsize, horizon, subgraph_num, subgraph_size, dims)
-
-        # self.update_main_model()
-
+        self.local_aggregate()
         return output_tensor
 
-    def update_main_model(self):
+    def local_aggregate(self):
         """
-        更新 main_model 的参数为 model_list 中所有模型参数的平均值。
+        Update the parameters of each model in model_list to the average of all models' parameters.
""" - # 遍历 main_model 的参数 - with torch.no_grad(): # 确保更新时不会计算梯度 - for name, main_param in self.main_model.named_parameters(): - # 初始化平均值的容器 - avg_param = torch.zeros_like(main_param) + with torch.no_grad(): # Ensure no gradients are calculated during the update + # Iterate over each model in model_list + for i, model in enumerate(self.model_list): + # Iterate over each model's parameters + for name, param in model.named_parameters(): + # Initialize a container for the average value + avg_param = torch.zeros_like(param) - # 遍历 model_list 中的所有模型 - for model in self.model_list: - # 加上当前模型的对应参数 - avg_param += model.state_dict()[name] + # Accumulate the corresponding parameters from all other models + for other_model in self.model_list: + avg_param += other_model.state_dict()[name] - # 计算平均值 - avg_param /= len(self.model_list) + # Calculate the average + avg_param /= len(self.model_list) + + # Update the current model's parameter + param.data.copy_(avg_param) - # 更新 main_model 的参数 - main_param.copy_(avg_param) diff --git a/scripts/trafficflow_exp_scripts/D3.yaml b/scripts/trafficflow_exp_scripts/D3.yaml index 2f19f56..1f6d0d8 100644 --- a/scripts/trafficflow_exp_scripts/D3.yaml +++ b/scripts/trafficflow_exp_scripts/D3.yaml @@ -42,6 +42,8 @@ model: cheb_order: 2 use_day: True use_week: True + use_minigraph: False + minigraph_size: 10 train: batch_or_epoch: 'epoch' local_update_steps: 1 diff --git a/scripts/trafficflow_exp_scripts/D4.yaml b/scripts/trafficflow_exp_scripts/D4.yaml index 1dc3aa6..97d85b7 100644 --- a/scripts/trafficflow_exp_scripts/D4.yaml +++ b/scripts/trafficflow_exp_scripts/D4.yaml @@ -44,7 +44,7 @@ model: cheb_order: 2 use_day: True use_week: True - use_minigraph: True + use_minigraph: False minigraph_size: 10 train: batch_or_epoch: 'epoch' diff --git a/scripts/trafficflow_exp_scripts/D7.yaml b/scripts/trafficflow_exp_scripts/D7.yaml index 518d52d..7dd6070 100644 --- a/scripts/trafficflow_exp_scripts/D7.yaml +++ b/scripts/trafficflow_exp_scripts/D7.yaml @@ -42,6 +42,8 @@ model: cheb_order: 2 use_day: True use_week: True + use_minigraph: False + minigraph_size: 10 train: batch_or_epoch: 'epoch' local_update_steps: 1 diff --git a/scripts/trafficflow_exp_scripts/D8.yaml b/scripts/trafficflow_exp_scripts/D8.yaml index 97614aa..498b71e 100644 --- a/scripts/trafficflow_exp_scripts/D8.yaml +++ b/scripts/trafficflow_exp_scripts/D8.yaml @@ -42,8 +42,8 @@ model: cheb_order: 2 use_day: True use_week: True - use_minigraph: True - minigraph_size: 5 + use_minigraph: False + minigraph_size: 10 train: batch_or_epoch: 'epoch' local_update_steps: 1