English
This commit is contained in:
parent
1b1744b999
commit
10d1999316
|
|
@ -1,12 +1,12 @@
|
||||||
import re
|
import re
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
# import os
|
||||||
from matplotlib.lines import Line2D
|
from matplotlib.lines import Line2D
|
||||||
from matplotlib.gridspec import GridSpec
|
from matplotlib.gridspec import GridSpec
|
||||||
from matplotlib.font_manager import FontProperties
|
# from matplotlib.font_manager import FontProperties
|
||||||
|
|
||||||
# 全局字典,包含每个指标的基准线数据和颜色
|
# Global dictionary containing baseline data and colors for each metric
|
||||||
MAE_baselines = {
|
MAE_baselines = {
|
||||||
'STG-NCDE': ([15.57, 19.21, 20.53, 15.45], 'darkblue'),
|
'STG-NCDE': ([15.57, 19.21, 20.53, 15.45], 'darkblue'),
|
||||||
'DCRNN': ([17.99, 21.22, 25.22, 16.82], 'darkgreen'),
|
'DCRNN': ([17.99, 21.22, 25.22, 16.82], 'darkgreen'),
|
||||||
|
|
@ -29,6 +29,7 @@ baseline_dict = {'MAE': MAE_baselines, 'RMSE': RMSE_baselines, 'MAPE': MAPE_base
|
||||||
|
|
||||||
|
|
||||||
def extract_avg_loss(log_file_path):
|
def extract_avg_loss(log_file_path):
|
||||||
|
# Regex to extract test loss from log files
|
||||||
avg_loss_pattern = re.compile(r"Client #(\d+).*?Round.*?(\d+).*?test_loss': (\d+\.\d+)")
|
avg_loss_pattern = re.compile(r"Client #(\d+).*?Round.*?(\d+).*?test_loss': (\d+\.\d+)")
|
||||||
client_loss_data = {}
|
client_loss_data = {}
|
||||||
|
|
||||||
|
|
@ -60,7 +61,7 @@ def plot_avg_loss(client_loss_data, dataset, metric, ax, baselines):
|
||||||
rounds = [r[0] for r in losses]
|
rounds = [r[0] for r in losses]
|
||||||
avg_losses = [r[1] for r in losses]
|
avg_losses = [r[1] for r in losses]
|
||||||
|
|
||||||
# 当数据集为 D7 且指标为 MAPE 时,限制最多绘制前 20 轮
|
# Limit to the first 20 rounds for dataset D7 and metric MAPE
|
||||||
if dataset == '7' and metric == 'MAPE':
|
if dataset == '7' and metric == 'MAPE':
|
||||||
rounds = [r for r in rounds if r <= 60]
|
rounds = [r for r in rounds if r <= 60]
|
||||||
avg_losses = avg_losses[:len(rounds)]
|
avg_losses = avg_losses[:len(rounds)]
|
||||||
|
|
@ -89,9 +90,9 @@ def main(legend_ncol=8):
|
||||||
|
|
||||||
subfig_labels = ['(a)', '(b)', '(c)', '(d)', '(e)', '(f)', '(g)', '(h)', '(i)', '(j)', '(k)', '(l)']
|
subfig_labels = ['(a)', '(b)', '(c)', '(d)', '(e)', '(f)', '(g)', '(h)', '(i)', '(j)', '(k)', '(l)']
|
||||||
|
|
||||||
# 设置字体为 Times New Roman
|
# Set font to Times New Roman
|
||||||
# times_new_roman_font = FontProperties(fname='./times.ttf')
|
# times_new_roman_font = FontProperties(fname='./times.ttf')
|
||||||
# 遍历指标和数据集,并绘制子图
|
# Iterate through metrics and datasets, and plot subfigures
|
||||||
for i, metric in enumerate(metrics):
|
for i, metric in enumerate(metrics):
|
||||||
for j, dataset in enumerate(datasets):
|
for j, dataset in enumerate(datasets):
|
||||||
ax_idx = i * 4 + j
|
ax_idx = i * 4 + j
|
||||||
|
|
@ -103,14 +104,14 @@ def main(legend_ncol=8):
|
||||||
ax.set_ylabel(metric)
|
ax.set_ylabel(metric)
|
||||||
ax.set_xlabel('Round')
|
ax.set_xlabel('Round')
|
||||||
|
|
||||||
# 在底部居中添加标题
|
# Add title at the bottom center
|
||||||
title = f'{subfig_labels[ax_idx]} {metric} on PEMSD{dataset}'
|
title = f'{subfig_labels[ax_idx]} {metric} on PEMSD{dataset}'
|
||||||
# ax.text(0.5, -0.4, title, transform=ax.transAxes,
|
# ax.text(0.5, -0.4, title, transform=ax.transAxes,
|
||||||
# fontsize=10, fontweight='bold', va='center', ha='center', fontproperties=times_new_roman_font)
|
# fontsize=10, fontweight='bold', va='center', ha='center', fontproperties=times_new_roman_font)
|
||||||
ax.text(0.5, -0.4, title, transform=ax.transAxes,
|
ax.text(0.5, -0.4, title, transform=ax.transAxes,
|
||||||
fontsize=10, fontweight='bold', va='center', ha='center')
|
fontsize=10, fontweight='bold', va='center', ha='center')
|
||||||
|
|
||||||
# 设置统一的图例在顶部
|
# Set unified legend at the top
|
||||||
handles = [
|
handles = [
|
||||||
Line2D([0], [0], color='red', linewidth=3, label='FedDGCN-Avg')
|
Line2D([0], [0], color='red', linewidth=3, label='FedDGCN-Avg')
|
||||||
]
|
]
|
||||||
|
|
@ -119,12 +120,12 @@ def main(legend_ncol=8):
|
||||||
client_handles = [Line2D([0], [0], color=plt.get_cmap('tab20')(i), lw=2) for i in range(10)]
|
client_handles = [Line2D([0], [0], color=plt.get_cmap('tab20')(i), lw=2) for i in range(10)]
|
||||||
client_labels = [f'Client #{i + 1}' for i in range(10)]
|
client_labels = [f'Client #{i + 1}' for i in range(10)]
|
||||||
|
|
||||||
# 创建图例并加黑框背景
|
# Create legend with black border
|
||||||
legend = fig.legend(handles=handles + client_handles,
|
legend = fig.legend(handles=handles + client_handles,
|
||||||
labels=[h.get_label() for h in handles] + client_labels,
|
labels=[h.get_label() for h in handles] + client_labels,
|
||||||
loc='upper center', bbox_to_anchor=(0.5, 0.98),
|
loc='upper center', bbox_to_anchor=(0.5, 0.98),
|
||||||
ncol=legend_ncol, fontsize='small', frameon=True, edgecolor='black')
|
ncol=legend_ncol, fontsize='small', frameon=True, edgecolor='black')
|
||||||
legend.get_frame().set_linewidth(1.5) # 黑框宽度
|
legend.get_frame().set_linewidth(1.5) # Black border width
|
||||||
|
|
||||||
plt.savefig('baseline.jpg', bbox_inches='tight')
|
plt.savefig('baseline.jpg', bbox_inches='tight')
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
|
||||||
|
|
@ -63,7 +63,7 @@ def get_dataloader(dataset, config, split='train'):
|
||||||
from federatedscope.mf.dataloader import MFDataLoader
|
from federatedscope.mf.dataloader import MFDataLoader
|
||||||
loader_cls = MFDataLoader
|
loader_cls = MFDataLoader
|
||||||
elif config.dataloader.type == 'trafficflow':
|
elif config.dataloader.type == 'trafficflow':
|
||||||
# 待定
|
# This if is not strictly necessary but helps avoid potential bugs
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
loader_cls = DataLoader
|
loader_cls = DataLoader
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -80,12 +80,12 @@ class GeneralMultiModelTrainer(GeneralTorchTrainer):
|
||||||
"and its subclasses, or " \
|
"and its subclasses, or " \
|
||||||
"`GeneralTorchTrainer` and its subclasses"
|
"`GeneralTorchTrainer` and its subclasses"
|
||||||
# self.__dict__ = copy.deepcopy(base_trainer.__dict__)
|
# self.__dict__ = copy.deepcopy(base_trainer.__dict__)
|
||||||
# 逐个复制 base_trainer 的属性,跳过不可拷贝的对象
|
# Copy attributes from base_trainer one by one, skipping non-copyable objects
|
||||||
for key, value in base_trainer.__dict__.items():
|
for key, value in base_trainer.__dict__.items():
|
||||||
try:
|
try:
|
||||||
self.__dict__[key] = copy.deepcopy(value)
|
self.__dict__[key] = copy.deepcopy(value)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
self.__dict__[key] = value # 如果不能 deepcopy,使用浅拷贝
|
self.__dict__[key] = value # If unable to deepcopy, use shallow copy
|
||||||
|
|
||||||
assert models_interact_mode in ["sequential", "parallel"], \
|
assert models_interact_mode in ["sequential", "parallel"], \
|
||||||
f"Invalid models_interact_mode, should be `sequential` or " \
|
f"Invalid models_interact_mode, should be `sequential` or " \
|
||||||
|
|
|
||||||
|
|
@ -135,7 +135,7 @@ def load_traffic_data(config, client_cfgs):
|
||||||
# y_val[..., :config.model.output_dim] = scaler.transform(y_val[..., :config.model.output_dim])
|
# y_val[..., :config.model.output_dim] = scaler.transform(y_val[..., :config.model.output_dim])
|
||||||
# y_test[..., :config.model.output_dim] = scaler.transform(y_test[..., :config.model.output_dim])
|
# y_test[..., :config.model.output_dim] = scaler.transform(y_test[..., :config.model.output_dim])
|
||||||
|
|
||||||
# 客户端分割数据集
|
# Client-side dataset splitting
|
||||||
node_num = config.data.num_nodes
|
node_num = config.data.num_nodes
|
||||||
client_num = config.federate.client_num
|
client_num = config.federate.client_num
|
||||||
per_samples = node_num // client_num
|
per_samples = node_num // client_num
|
||||||
|
|
@ -143,7 +143,7 @@ def load_traffic_data(config, client_cfgs):
|
||||||
input_dim, output_dim = config.model.input_dim, config.model.output_dim
|
input_dim, output_dim = config.model.input_dim, config.model.output_dim
|
||||||
for i in range(client_num):
|
for i in range(client_num):
|
||||||
if cur_index + per_samples <= node_num:
|
if cur_index + per_samples <= node_num:
|
||||||
# 正常截取
|
# Normal slicing
|
||||||
sub_array_train = x_train[:, :, cur_index:cur_index + per_samples, :]
|
sub_array_train = x_train[:, :, cur_index:cur_index + per_samples, :]
|
||||||
sub_array_val = x_val[:, :, cur_index:cur_index + per_samples, :]
|
sub_array_val = x_val[:, :, cur_index:cur_index + per_samples, :]
|
||||||
sub_array_test = x_test[:, :, cur_index:cur_index + per_samples, :]
|
sub_array_test = x_test[:, :, cur_index:cur_index + per_samples, :]
|
||||||
|
|
@ -152,7 +152,7 @@ def load_traffic_data(config, client_cfgs):
|
||||||
sub_y_val = y_val[:, :, cur_index:cur_index + per_samples, :output_dim]
|
sub_y_val = y_val[:, :, cur_index:cur_index + per_samples, :output_dim]
|
||||||
sub_y_test = y_test[:, :, cur_index:cur_index + per_samples, :output_dim]
|
sub_y_test = y_test[:, :, cur_index:cur_index + per_samples, :output_dim]
|
||||||
else:
|
else:
|
||||||
# 不足一个per_samples,补0列
|
# If there are not enough nodes to fill per_samples, pad with zero columns
|
||||||
sub_array_train = x_train[:, :, cur_index:cur_index + per_samples, :]
|
sub_array_train = x_train[:, :, cur_index:cur_index + per_samples, :]
|
||||||
sub_array_val = x_val[:, :, cur_index:cur_index + per_samples, :]
|
sub_array_val = x_val[:, :, cur_index:cur_index + per_samples, :]
|
||||||
sub_array_test = x_test[:, :, cur_index:cur_index + per_samples, :]
|
sub_array_test = x_test[:, :, cur_index:cur_index + per_samples, :]
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ class DGCN(nn.Module):
|
||||||
self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out))
|
self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out))
|
||||||
self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
|
self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
|
||||||
self.bias = nn.Parameter(torch.FloatTensor(dim_out))
|
self.bias = nn.Parameter(torch.FloatTensor(dim_out))
|
||||||
# 初始化参数
|
# Initialize parameters
|
||||||
nn.init.xavier_uniform_(self.weights_pool)
|
nn.init.xavier_uniform_(self.weights_pool)
|
||||||
nn.init.xavier_uniform_(self.weights)
|
nn.init.xavier_uniform_(self.weights)
|
||||||
nn.init.zeros_(self.bias_pool)
|
nn.init.zeros_(self.bias_pool)
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,7 @@ class FedDGCN(nn.Module):
|
||||||
self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True)
|
self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True)
|
||||||
self.T_i_D_emb = nn.Parameter(torch.empty(288, args.embed_dim))
|
self.T_i_D_emb = nn.Parameter(torch.empty(288, args.embed_dim))
|
||||||
self.D_i_W_emb = nn.Parameter(torch.empty(7, args.embed_dim))
|
self.D_i_W_emb = nn.Parameter(torch.empty(7, args.embed_dim))
|
||||||
# 初始化参数
|
# Initialize parameters
|
||||||
nn.init.xavier_uniform_(self.node_embeddings1)
|
nn.init.xavier_uniform_(self.node_embeddings1)
|
||||||
nn.init.xavier_uniform_(self.T_i_D_emb)
|
nn.init.xavier_uniform_(self.T_i_D_emb)
|
||||||
nn.init.xavier_uniform_(self.D_i_W_emb)
|
nn.init.xavier_uniform_(self.D_i_W_emb)
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ class TrafficSplitter(BaseSplitter):
|
||||||
|
|
||||||
def __call__(self, dataset, *args, **kwargs):
|
def __call__(self, dataset, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
后面考虑子图标记划分
|
TODO:subgraph partition
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset: ndarray(timestep, num_node, channel)
|
dataset: ndarray(timestep, num_node, channel)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue