This commit is contained in:
HengZhang 2024-11-26 19:35:53 +08:00
parent 1b1744b999
commit 10d1999316
7 changed files with 20 additions and 19 deletions

View File

@ -1,12 +1,12 @@
import re
import matplotlib.pyplot as plt
import numpy as np
import os
# import os
from matplotlib.lines import Line2D
from matplotlib.gridspec import GridSpec
from matplotlib.font_manager import FontProperties
# from matplotlib.font_manager import FontProperties
# 全局字典,包含每个指标的基准线数据和颜色
# Global dictionary containing baseline data and colors for each metric
MAE_baselines = {
'STG-NCDE': ([15.57, 19.21, 20.53, 15.45], 'darkblue'),
'DCRNN': ([17.99, 21.22, 25.22, 16.82], 'darkgreen'),
@ -29,6 +29,7 @@ baseline_dict = {'MAE': MAE_baselines, 'RMSE': RMSE_baselines, 'MAPE': MAPE_base
def extract_avg_loss(log_file_path):
# Regex to extract test loss from log files
avg_loss_pattern = re.compile(r"Client #(\d+).*?Round.*?(\d+).*?test_loss': (\d+\.\d+)")
client_loss_data = {}
@ -60,7 +61,7 @@ def plot_avg_loss(client_loss_data, dataset, metric, ax, baselines):
rounds = [r[0] for r in losses]
avg_losses = [r[1] for r in losses]
# 当数据集为 D7 且指标为 MAPE 时,限制最多绘制前 20 轮
# Limit to the first 20 rounds for dataset D7 and metric MAPE
if dataset == '7' and metric == 'MAPE':
rounds = [r for r in rounds if r <= 60]
avg_losses = avg_losses[:len(rounds)]
@ -89,9 +90,9 @@ def main(legend_ncol=8):
subfig_labels = ['(a)', '(b)', '(c)', '(d)', '(e)', '(f)', '(g)', '(h)', '(i)', '(j)', '(k)', '(l)']
# 设置字体为 Times New Roman
# Set font to Times New Roman
# times_new_roman_font = FontProperties(fname='./times.ttf')
# 遍历指标和数据集,并绘制子图
# Iterate through metrics and datasets, and plot subfigures
for i, metric in enumerate(metrics):
for j, dataset in enumerate(datasets):
ax_idx = i * 4 + j
@ -103,14 +104,14 @@ def main(legend_ncol=8):
ax.set_ylabel(metric)
ax.set_xlabel('Round')
# 在底部居中添加标题
# Add title at the bottom center
title = f'{subfig_labels[ax_idx]} {metric} on PEMSD{dataset}'
# ax.text(0.5, -0.4, title, transform=ax.transAxes,
# fontsize=10, fontweight='bold', va='center', ha='center', fontproperties=times_new_roman_font)
ax.text(0.5, -0.4, title, transform=ax.transAxes,
fontsize=10, fontweight='bold', va='center', ha='center')
# 设置统一的图例在顶部
# Set unified legend at the top
handles = [
Line2D([0], [0], color='red', linewidth=3, label='FedDGCN-Avg')
]
@ -119,12 +120,12 @@ def main(legend_ncol=8):
client_handles = [Line2D([0], [0], color=plt.get_cmap('tab20')(i), lw=2) for i in range(10)]
client_labels = [f'Client #{i + 1}' for i in range(10)]
# 创建图例并加黑框背景
# Create legend with black border
legend = fig.legend(handles=handles + client_handles,
labels=[h.get_label() for h in handles] + client_labels,
loc='upper center', bbox_to_anchor=(0.5, 0.98),
ncol=legend_ncol, fontsize='small', frameon=True, edgecolor='black')
legend.get_frame().set_linewidth(1.5) # 黑框宽度
legend.get_frame().set_linewidth(1.5) # Black border width
plt.savefig('baseline.jpg', bbox_inches='tight')
plt.show()

View File

@ -63,7 +63,7 @@ def get_dataloader(dataset, config, split='train'):
from federatedscope.mf.dataloader import MFDataLoader
loader_cls = MFDataLoader
elif config.dataloader.type == 'trafficflow':
# 待定
# This if is not strictly necessary but helps avoid potential bugs
from torch.utils.data import DataLoader
loader_cls = DataLoader
else:

View File

@ -80,12 +80,12 @@ class GeneralMultiModelTrainer(GeneralTorchTrainer):
"and its subclasses, or " \
"`GeneralTorchTrainer` and its subclasses"
# self.__dict__ = copy.deepcopy(base_trainer.__dict__)
# 逐个复制 base_trainer 的属性,跳过不可拷贝的对象
# Copy attributes from base_trainer one by one, skipping non-copyable objects
for key, value in base_trainer.__dict__.items():
try:
self.__dict__[key] = copy.deepcopy(value)
except TypeError:
self.__dict__[key] = value # 如果不能 deepcopy使用浅拷贝
self.__dict__[key] = value # If unable to deepcopy, use shallow copy
assert models_interact_mode in ["sequential", "parallel"], \
f"Invalid models_interact_mode, should be `sequential` or " \

View File

@ -135,7 +135,7 @@ def load_traffic_data(config, client_cfgs):
# y_val[..., :config.model.output_dim] = scaler.transform(y_val[..., :config.model.output_dim])
# y_test[..., :config.model.output_dim] = scaler.transform(y_test[..., :config.model.output_dim])
# 客户端分割数据集
# Client-side dataset splitting
node_num = config.data.num_nodes
client_num = config.federate.client_num
per_samples = node_num // client_num
@ -143,7 +143,7 @@ def load_traffic_data(config, client_cfgs):
input_dim, output_dim = config.model.input_dim, config.model.output_dim
for i in range(client_num):
if cur_index + per_samples <= node_num:
# 正常截取
# Normal slicing
sub_array_train = x_train[:, :, cur_index:cur_index + per_samples, :]
sub_array_val = x_val[:, :, cur_index:cur_index + per_samples, :]
sub_array_test = x_test[:, :, cur_index:cur_index + per_samples, :]
@ -152,7 +152,7 @@ def load_traffic_data(config, client_cfgs):
sub_y_val = y_val[:, :, cur_index:cur_index + per_samples, :output_dim]
sub_y_test = y_test[:, :, cur_index:cur_index + per_samples, :output_dim]
else:
# 不足一个per_samples补0列
# If there are not enough nodes to fill per_samples, pad with zero columns
sub_array_train = x_train[:, :, cur_index:cur_index + per_samples, :]
sub_array_val = x_val[:, :, cur_index:cur_index + per_samples, :]
sub_array_test = x_test[:, :, cur_index:cur_index + per_samples, :]

View File

@ -13,7 +13,7 @@ class DGCN(nn.Module):
self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out))
self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
self.bias = nn.Parameter(torch.FloatTensor(dim_out))
# 初始化参数
# Initialize parameters
nn.init.xavier_uniform_(self.weights_pool)
nn.init.xavier_uniform_(self.weights)
nn.init.zeros_(self.bias_pool)

View File

@ -54,7 +54,7 @@ class FedDGCN(nn.Module):
self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True)
self.T_i_D_emb = nn.Parameter(torch.empty(288, args.embed_dim))
self.D_i_W_emb = nn.Parameter(torch.empty(7, args.embed_dim))
# 初始化参数
# Initialize parameters
nn.init.xavier_uniform_(self.node_embeddings1)
nn.init.xavier_uniform_(self.T_i_D_emb)
nn.init.xavier_uniform_(self.D_i_W_emb)

View File

@ -9,7 +9,7 @@ class TrafficSplitter(BaseSplitter):
def __call__(self, dataset, *args, **kwargs):
"""
后面考虑子图标记划分
TODO:subgraph partition
Args:
dataset: ndarray(timestep, num_node, channel)