From 10d19993164d19b04057d1897763282f59c9eb80 Mon Sep 17 00:00:00 2001 From: HengZhang Date: Tue, 26 Nov 2024 19:35:53 +0800 Subject: [PATCH] English --- exp/global.py | 21 ++++++++++--------- .../core/auxiliaries/dataloader_builder.py | 2 +- .../core/trainers/trainer_multi_model.py | 4 ++-- .../dataloader/traffic_dataloader.py | 6 +++--- federatedscope/trafficflow/model/DGCN.py | 2 +- federatedscope/trafficflow/model/FedDGCN.py | 2 +- .../trafficflow/splitters/trafficSplitter.py | 2 +- 7 files changed, 20 insertions(+), 19 deletions(-) diff --git a/exp/global.py b/exp/global.py index f196796..176f5c4 100644 --- a/exp/global.py +++ b/exp/global.py @@ -1,12 +1,12 @@ import re import matplotlib.pyplot as plt import numpy as np -import os +# import os from matplotlib.lines import Line2D from matplotlib.gridspec import GridSpec -from matplotlib.font_manager import FontProperties +# from matplotlib.font_manager import FontProperties -# 全局字典,包含每个指标的基准线数据和颜色 +# Global dictionary containing baseline data and colors for each metric MAE_baselines = { 'STG-NCDE': ([15.57, 19.21, 20.53, 15.45], 'darkblue'), 'DCRNN': ([17.99, 21.22, 25.22, 16.82], 'darkgreen'), @@ -29,6 +29,7 @@ baseline_dict = {'MAE': MAE_baselines, 'RMSE': RMSE_baselines, 'MAPE': MAPE_base def extract_avg_loss(log_file_path): + # Regex to extract test loss from log files avg_loss_pattern = re.compile(r"Client #(\d+).*?Round.*?(\d+).*?test_loss': (\d+\.\d+)") client_loss_data = {} @@ -60,7 +61,7 @@ def plot_avg_loss(client_loss_data, dataset, metric, ax, baselines): rounds = [r[0] for r in losses] avg_losses = [r[1] for r in losses] - # 当数据集为 D7 且指标为 MAPE 时,限制最多绘制前 20 轮 + # Limit to the first 20 rounds for dataset D7 and metric MAPE if dataset == '7' and metric == 'MAPE': rounds = [r for r in rounds if r <= 60] avg_losses = avg_losses[:len(rounds)] @@ -89,9 +90,9 @@ def main(legend_ncol=8): subfig_labels = ['(a)', '(b)', '(c)', '(d)', '(e)', '(f)', '(g)', '(h)', '(i)', '(j)', '(k)', '(l)'] - # 设置字体为 Times New Roman + # Set font to Times New Roman # times_new_roman_font = FontProperties(fname='./times.ttf') - # 遍历指标和数据集,并绘制子图 + # Iterate through metrics and datasets, and plot subfigures for i, metric in enumerate(metrics): for j, dataset in enumerate(datasets): ax_idx = i * 4 + j @@ -103,14 +104,14 @@ def main(legend_ncol=8): ax.set_ylabel(metric) ax.set_xlabel('Round') - # 在底部居中添加标题 + # Add title at the bottom center title = f'{subfig_labels[ax_idx]} {metric} on PEMSD{dataset}' # ax.text(0.5, -0.4, title, transform=ax.transAxes, # fontsize=10, fontweight='bold', va='center', ha='center', fontproperties=times_new_roman_font) ax.text(0.5, -0.4, title, transform=ax.transAxes, fontsize=10, fontweight='bold', va='center', ha='center') - # 设置统一的图例在顶部 + # Set unified legend at the top handles = [ Line2D([0], [0], color='red', linewidth=3, label='FedDGCN-Avg') ] @@ -119,12 +120,12 @@ def main(legend_ncol=8): client_handles = [Line2D([0], [0], color=plt.get_cmap('tab20')(i), lw=2) for i in range(10)] client_labels = [f'Client #{i + 1}' for i in range(10)] - # 创建图例并加黑框背景 + # Create legend with black border legend = fig.legend(handles=handles + client_handles, labels=[h.get_label() for h in handles] + client_labels, loc='upper center', bbox_to_anchor=(0.5, 0.98), ncol=legend_ncol, fontsize='small', frameon=True, edgecolor='black') - legend.get_frame().set_linewidth(1.5) # 黑框宽度 + legend.get_frame().set_linewidth(1.5) # Black border width plt.savefig('baseline.jpg', bbox_inches='tight') plt.show() diff --git a/federatedscope/core/auxiliaries/dataloader_builder.py b/federatedscope/core/auxiliaries/dataloader_builder.py index 5eea675..a6d056b 100644 --- a/federatedscope/core/auxiliaries/dataloader_builder.py +++ b/federatedscope/core/auxiliaries/dataloader_builder.py @@ -63,7 +63,7 @@ def get_dataloader(dataset, config, split='train'): from federatedscope.mf.dataloader import MFDataLoader loader_cls = MFDataLoader elif config.dataloader.type == 'trafficflow': - # 待定 + # This if is not strictly necessary but helps avoid potential bugs from torch.utils.data import DataLoader loader_cls = DataLoader else: diff --git a/federatedscope/core/trainers/trainer_multi_model.py b/federatedscope/core/trainers/trainer_multi_model.py index f86aeb4..7680cb9 100644 --- a/federatedscope/core/trainers/trainer_multi_model.py +++ b/federatedscope/core/trainers/trainer_multi_model.py @@ -80,12 +80,12 @@ class GeneralMultiModelTrainer(GeneralTorchTrainer): "and its subclasses, or " \ "`GeneralTorchTrainer` and its subclasses" # self.__dict__ = copy.deepcopy(base_trainer.__dict__) - # 逐个复制 base_trainer 的属性,跳过不可拷贝的对象 + # Copy attributes from base_trainer one by one, skipping non-copyable objects for key, value in base_trainer.__dict__.items(): try: self.__dict__[key] = copy.deepcopy(value) except TypeError: - self.__dict__[key] = value # 如果不能 deepcopy,使用浅拷贝 + self.__dict__[key] = value # If unable to deepcopy, use shallow copy assert models_interact_mode in ["sequential", "parallel"], \ f"Invalid models_interact_mode, should be `sequential` or " \ diff --git a/federatedscope/trafficflow/dataloader/traffic_dataloader.py b/federatedscope/trafficflow/dataloader/traffic_dataloader.py index bf3823f..6766432 100644 --- a/federatedscope/trafficflow/dataloader/traffic_dataloader.py +++ b/federatedscope/trafficflow/dataloader/traffic_dataloader.py @@ -135,7 +135,7 @@ def load_traffic_data(config, client_cfgs): # y_val[..., :config.model.output_dim] = scaler.transform(y_val[..., :config.model.output_dim]) # y_test[..., :config.model.output_dim] = scaler.transform(y_test[..., :config.model.output_dim]) - # 客户端分割数据集 + # Client-side dataset splitting node_num = config.data.num_nodes client_num = config.federate.client_num per_samples = node_num // client_num @@ -143,7 +143,7 @@ def load_traffic_data(config, client_cfgs): input_dim, output_dim = config.model.input_dim, config.model.output_dim for i in range(client_num): if cur_index + per_samples <= node_num: - # 正常截取 + # Normal slicing sub_array_train = x_train[:, :, cur_index:cur_index + per_samples, :] sub_array_val = x_val[:, :, cur_index:cur_index + per_samples, :] sub_array_test = x_test[:, :, cur_index:cur_index + per_samples, :] @@ -152,7 +152,7 @@ def load_traffic_data(config, client_cfgs): sub_y_val = y_val[:, :, cur_index:cur_index + per_samples, :output_dim] sub_y_test = y_test[:, :, cur_index:cur_index + per_samples, :output_dim] else: - # 不足一个per_samples,补0列 + # If there are not enough nodes to fill per_samples, pad with zero columns sub_array_train = x_train[:, :, cur_index:cur_index + per_samples, :] sub_array_val = x_val[:, :, cur_index:cur_index + per_samples, :] sub_array_test = x_test[:, :, cur_index:cur_index + per_samples, :] diff --git a/federatedscope/trafficflow/model/DGCN.py b/federatedscope/trafficflow/model/DGCN.py index 6ec5cf4..4e045b8 100644 --- a/federatedscope/trafficflow/model/DGCN.py +++ b/federatedscope/trafficflow/model/DGCN.py @@ -13,7 +13,7 @@ class DGCN(nn.Module): self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out)) self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out)) self.bias = nn.Parameter(torch.FloatTensor(dim_out)) - # 初始化参数 + # Initialize parameters nn.init.xavier_uniform_(self.weights_pool) nn.init.xavier_uniform_(self.weights) nn.init.zeros_(self.bias_pool) diff --git a/federatedscope/trafficflow/model/FedDGCN.py b/federatedscope/trafficflow/model/FedDGCN.py index eccf3af..2f59b1c 100644 --- a/federatedscope/trafficflow/model/FedDGCN.py +++ b/federatedscope/trafficflow/model/FedDGCN.py @@ -54,7 +54,7 @@ class FedDGCN(nn.Module): self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True) self.T_i_D_emb = nn.Parameter(torch.empty(288, args.embed_dim)) self.D_i_W_emb = nn.Parameter(torch.empty(7, args.embed_dim)) - # 初始化参数 + # Initialize parameters nn.init.xavier_uniform_(self.node_embeddings1) nn.init.xavier_uniform_(self.T_i_D_emb) nn.init.xavier_uniform_(self.D_i_W_emb) diff --git a/federatedscope/trafficflow/splitters/trafficSplitter.py b/federatedscope/trafficflow/splitters/trafficSplitter.py index 95aead0..6b03676 100644 --- a/federatedscope/trafficflow/splitters/trafficSplitter.py +++ b/federatedscope/trafficflow/splitters/trafficSplitter.py @@ -9,7 +9,7 @@ class TrafficSplitter(BaseSplitter): def __call__(self, dataset, *args, **kwargs): """ - 后面考虑子图标记划分 + TODO:subgraph partition Args: dataset: ndarray(timestep, num_node, channel)