English
This commit is contained in:
parent
1b1744b999
commit
10d1999316
|
|
@ -1,12 +1,12 @@
|
|||
import re
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import os
|
||||
# import os
|
||||
from matplotlib.lines import Line2D
|
||||
from matplotlib.gridspec import GridSpec
|
||||
from matplotlib.font_manager import FontProperties
|
||||
# from matplotlib.font_manager import FontProperties
|
||||
|
||||
# 全局字典,包含每个指标的基准线数据和颜色
|
||||
# Global dictionary containing baseline data and colors for each metric
|
||||
MAE_baselines = {
|
||||
'STG-NCDE': ([15.57, 19.21, 20.53, 15.45], 'darkblue'),
|
||||
'DCRNN': ([17.99, 21.22, 25.22, 16.82], 'darkgreen'),
|
||||
|
|
@ -29,6 +29,7 @@ baseline_dict = {'MAE': MAE_baselines, 'RMSE': RMSE_baselines, 'MAPE': MAPE_base
|
|||
|
||||
|
||||
def extract_avg_loss(log_file_path):
|
||||
# Regex to extract test loss from log files
|
||||
avg_loss_pattern = re.compile(r"Client #(\d+).*?Round.*?(\d+).*?test_loss': (\d+\.\d+)")
|
||||
client_loss_data = {}
|
||||
|
||||
|
|
@ -60,7 +61,7 @@ def plot_avg_loss(client_loss_data, dataset, metric, ax, baselines):
|
|||
rounds = [r[0] for r in losses]
|
||||
avg_losses = [r[1] for r in losses]
|
||||
|
||||
# 当数据集为 D7 且指标为 MAPE 时,限制最多绘制前 20 轮
|
||||
# Limit to the first 20 rounds for dataset D7 and metric MAPE
|
||||
if dataset == '7' and metric == 'MAPE':
|
||||
rounds = [r for r in rounds if r <= 60]
|
||||
avg_losses = avg_losses[:len(rounds)]
|
||||
|
|
@ -89,9 +90,9 @@ def main(legend_ncol=8):
|
|||
|
||||
subfig_labels = ['(a)', '(b)', '(c)', '(d)', '(e)', '(f)', '(g)', '(h)', '(i)', '(j)', '(k)', '(l)']
|
||||
|
||||
# 设置字体为 Times New Roman
|
||||
# Set font to Times New Roman
|
||||
# times_new_roman_font = FontProperties(fname='./times.ttf')
|
||||
# 遍历指标和数据集,并绘制子图
|
||||
# Iterate through metrics and datasets, and plot subfigures
|
||||
for i, metric in enumerate(metrics):
|
||||
for j, dataset in enumerate(datasets):
|
||||
ax_idx = i * 4 + j
|
||||
|
|
@ -103,14 +104,14 @@ def main(legend_ncol=8):
|
|||
ax.set_ylabel(metric)
|
||||
ax.set_xlabel('Round')
|
||||
|
||||
# 在底部居中添加标题
|
||||
# Add title at the bottom center
|
||||
title = f'{subfig_labels[ax_idx]} {metric} on PEMSD{dataset}'
|
||||
# ax.text(0.5, -0.4, title, transform=ax.transAxes,
|
||||
# fontsize=10, fontweight='bold', va='center', ha='center', fontproperties=times_new_roman_font)
|
||||
ax.text(0.5, -0.4, title, transform=ax.transAxes,
|
||||
fontsize=10, fontweight='bold', va='center', ha='center')
|
||||
|
||||
# 设置统一的图例在顶部
|
||||
# Set unified legend at the top
|
||||
handles = [
|
||||
Line2D([0], [0], color='red', linewidth=3, label='FedDGCN-Avg')
|
||||
]
|
||||
|
|
@ -119,12 +120,12 @@ def main(legend_ncol=8):
|
|||
client_handles = [Line2D([0], [0], color=plt.get_cmap('tab20')(i), lw=2) for i in range(10)]
|
||||
client_labels = [f'Client #{i + 1}' for i in range(10)]
|
||||
|
||||
# 创建图例并加黑框背景
|
||||
# Create legend with black border
|
||||
legend = fig.legend(handles=handles + client_handles,
|
||||
labels=[h.get_label() for h in handles] + client_labels,
|
||||
loc='upper center', bbox_to_anchor=(0.5, 0.98),
|
||||
ncol=legend_ncol, fontsize='small', frameon=True, edgecolor='black')
|
||||
legend.get_frame().set_linewidth(1.5) # 黑框宽度
|
||||
legend.get_frame().set_linewidth(1.5) # Black border width
|
||||
|
||||
plt.savefig('baseline.jpg', bbox_inches='tight')
|
||||
plt.show()
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ def get_dataloader(dataset, config, split='train'):
|
|||
from federatedscope.mf.dataloader import MFDataLoader
|
||||
loader_cls = MFDataLoader
|
||||
elif config.dataloader.type == 'trafficflow':
|
||||
# 待定
|
||||
# This if is not strictly necessary but helps avoid potential bugs
|
||||
from torch.utils.data import DataLoader
|
||||
loader_cls = DataLoader
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -80,12 +80,12 @@ class GeneralMultiModelTrainer(GeneralTorchTrainer):
|
|||
"and its subclasses, or " \
|
||||
"`GeneralTorchTrainer` and its subclasses"
|
||||
# self.__dict__ = copy.deepcopy(base_trainer.__dict__)
|
||||
# 逐个复制 base_trainer 的属性,跳过不可拷贝的对象
|
||||
# Copy attributes from base_trainer one by one, skipping non-copyable objects
|
||||
for key, value in base_trainer.__dict__.items():
|
||||
try:
|
||||
self.__dict__[key] = copy.deepcopy(value)
|
||||
except TypeError:
|
||||
self.__dict__[key] = value # 如果不能 deepcopy,使用浅拷贝
|
||||
self.__dict__[key] = value # If unable to deepcopy, use shallow copy
|
||||
|
||||
assert models_interact_mode in ["sequential", "parallel"], \
|
||||
f"Invalid models_interact_mode, should be `sequential` or " \
|
||||
|
|
|
|||
|
|
@ -135,7 +135,7 @@ def load_traffic_data(config, client_cfgs):
|
|||
# y_val[..., :config.model.output_dim] = scaler.transform(y_val[..., :config.model.output_dim])
|
||||
# y_test[..., :config.model.output_dim] = scaler.transform(y_test[..., :config.model.output_dim])
|
||||
|
||||
# 客户端分割数据集
|
||||
# Client-side dataset splitting
|
||||
node_num = config.data.num_nodes
|
||||
client_num = config.federate.client_num
|
||||
per_samples = node_num // client_num
|
||||
|
|
@ -143,7 +143,7 @@ def load_traffic_data(config, client_cfgs):
|
|||
input_dim, output_dim = config.model.input_dim, config.model.output_dim
|
||||
for i in range(client_num):
|
||||
if cur_index + per_samples <= node_num:
|
||||
# 正常截取
|
||||
# Normal slicing
|
||||
sub_array_train = x_train[:, :, cur_index:cur_index + per_samples, :]
|
||||
sub_array_val = x_val[:, :, cur_index:cur_index + per_samples, :]
|
||||
sub_array_test = x_test[:, :, cur_index:cur_index + per_samples, :]
|
||||
|
|
@ -152,7 +152,7 @@ def load_traffic_data(config, client_cfgs):
|
|||
sub_y_val = y_val[:, :, cur_index:cur_index + per_samples, :output_dim]
|
||||
sub_y_test = y_test[:, :, cur_index:cur_index + per_samples, :output_dim]
|
||||
else:
|
||||
# 不足一个per_samples,补0列
|
||||
# If there are not enough nodes to fill per_samples, pad with zero columns
|
||||
sub_array_train = x_train[:, :, cur_index:cur_index + per_samples, :]
|
||||
sub_array_val = x_val[:, :, cur_index:cur_index + per_samples, :]
|
||||
sub_array_test = x_test[:, :, cur_index:cur_index + per_samples, :]
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ class DGCN(nn.Module):
|
|||
self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out))
|
||||
self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
|
||||
self.bias = nn.Parameter(torch.FloatTensor(dim_out))
|
||||
# 初始化参数
|
||||
# Initialize parameters
|
||||
nn.init.xavier_uniform_(self.weights_pool)
|
||||
nn.init.xavier_uniform_(self.weights)
|
||||
nn.init.zeros_(self.bias_pool)
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ class FedDGCN(nn.Module):
|
|||
self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True)
|
||||
self.T_i_D_emb = nn.Parameter(torch.empty(288, args.embed_dim))
|
||||
self.D_i_W_emb = nn.Parameter(torch.empty(7, args.embed_dim))
|
||||
# 初始化参数
|
||||
# Initialize parameters
|
||||
nn.init.xavier_uniform_(self.node_embeddings1)
|
||||
nn.init.xavier_uniform_(self.T_i_D_emb)
|
||||
nn.init.xavier_uniform_(self.D_i_W_emb)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ class TrafficSplitter(BaseSplitter):
|
|||
|
||||
def __call__(self, dataset, *args, **kwargs):
|
||||
"""
|
||||
后面考虑子图标记划分
|
||||
TODO:subgraph partition
|
||||
|
||||
Args:
|
||||
dataset: ndarray(timestep, num_node, channel)
|
||||
|
|
|
|||
Loading…
Reference in New Issue