# 134 lines · 5.1 KiB · Python
import re
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import os
|
|
from matplotlib.lines import Line2D
|
|
from matplotlib.gridspec import GridSpec
|
|
from matplotlib.font_manager import FontProperties
|
|
|
|
# Baseline results drawn as dashed reference lines, one table per metric.
# Each table maps a model name to ``(values, color)``: ``values`` holds one
# number per dataset in the order PEMSD3, PEMSD4, PEMSD7, PEMSD8 (the same
# order as the ['3', '4', '7', '8'] list used to index it when plotting), and
# ``color`` is the matplotlib color of that model's line.
_BASELINE_COLORS = {
    'STG-NCDE': 'darkblue',
    'DCRNN': 'darkgreen',
    'AGCRN': 'darkorange',
    'STGODE': 'purple',
}


def _attach_colors(values_by_model):
    """Return ``{model: (values, color)}`` using the shared per-model colors."""
    return {model: (values, _BASELINE_COLORS[model])
            for model, values in values_by_model.items()}


MAE_baselines = _attach_colors({
    'STG-NCDE': [15.57, 19.21, 20.53, 15.45],
    'DCRNN': [17.99, 21.22, 25.22, 16.82],
    'AGCRN': [15.98, 19.83, 22.37, 15.95],
    'STGODE': [16.50, 20.84, 22.59, 16.81],
})

RMSE_baselines = _attach_colors({
    'STG-NCDE': [27.09, 31.09, 33.84, 24.81],
    'DCRNN': [30.31, 33.44, 38.61, 26.36],
    'AGCRN': [28.25, 32.26, 36.55, 25.22],
    'STGODE': [27.84, 32.82, 37.54, 25.97],
})

MAPE_baselines = _attach_colors({
    'STG-NCDE': [15.06, 12.76, 8.80, 9.92],
    'DCRNN': [18.34, 14.17, 11.82, 10.92],
    'AGCRN': [15.23, 12.97, 9.12, 10.09],
    'STGODE': [16.69, 13.77, 10.14, 10.62],
})

# Metric name -> baseline table for that metric.
baseline_dict = {'MAE': MAE_baselines, 'RMSE': RMSE_baselines, 'MAPE': MAPE_baselines}


def extract_avg_loss(log_file_path, max_round=60):
    """Parse per-client test losses from a training log.

    Every match of ``Client #<id> ... Round ... <round> ... test_loss': <float>``
    within a single log line yields one ``(round_id, loss)`` sample for that
    client. Samples whose round id exceeds ``max_round`` are dropped.

    Args:
        log_file_path: Path of the log file to read.
        max_round: Largest round id to keep (inclusive). Defaults to 60, the
            cut-off previously hard-coded here; ``None`` keeps every round.

    Returns:
        dict mapping client id (int) to a list of ``(round_id, loss)`` tuples
        in the order they appear in the log.
    """
    # '.' never matches a newline, so every match is confined to one line;
    # scanning line by line therefore finds exactly the same matches as
    # scanning the whole file at once.
    avg_loss_pattern = re.compile(r"Client #(\d+).*?Round.*?(\d+).*?test_loss': (\d+\.\d+)")

    client_loss_data = {}

    # Decode leniently: the pattern only matches ASCII text, so an
    # undecodable byte elsewhere in the log must not abort the whole parse.
    with open(log_file_path, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:
            for client_str, round_str, loss_str in avg_loss_pattern.findall(line):
                round_id = int(round_str)
                if max_round is not None and round_id > max_round:
                    continue
                client_loss_data.setdefault(int(client_str), []).append(
                    (round_id, float(loss_str)))

    return client_loss_data


def plot_avg_loss(client_loss_data, dataset, metric, ax, baselines):
    """Draw per-client loss curves, their per-round mean and baseline lines.

    Args:
        client_loss_data: dict mapping client id to a list of
            ``(round_id, loss)`` tuples (the structure ``extract_avg_loss``
            returns).
        dataset: Dataset key, one of '3', '4', '7', '8'; selects which entry
            of each baseline's values is drawn.
        metric: Metric name; for ``dataset == '7'`` and ``metric == 'MAPE'``
            only rounds <= 60 are drawn.
        ax: Matplotlib axes to draw on.
        baselines: dict mapping model name to ``(values, color)``; one dashed
            horizontal line is drawn per model.

    Clients are colored with the ``tab20`` colormap in ascending id order.
    The thick red curve is, for each round, the mean loss over all clients
    that reported that round, so clients with different numbers of rounds are
    handled (a positional mean over a ragged array would fail). Nothing is
    drawn for the mean when ``client_loss_data`` is empty.

    Raises:
        ValueError: if ``dataset`` is not one of '3', '4', '7', '8'.
    """
    colors = plt.get_cmap('tab20').colors
    losses_by_round = {}  # round id -> losses reported by any client that round

    for idx, client_id in enumerate(sorted(client_loss_data)):
        points = client_loss_data[client_id]

        # For D7 with MAPE only rounds <= 60 are drawn. Filter the
        # (round, loss) pairs together so each loss keeps its own round.
        if dataset == '7' and metric == 'MAPE':
            points = [(r, loss) for r, loss in points if r <= 60]

        rounds = [r for r, _ in points]
        avg_losses = [loss for _, loss in points]
        ax.plot(rounds, avg_losses, color=colors[idx % len(colors)])

        for r, loss in points:
            losses_by_round.setdefault(r, []).append(loss)

    # Mean across clients, aligned on round id rather than list position.
    if losses_by_round:
        mean_rounds = sorted(losses_by_round)
        mean_losses = [np.mean(losses_by_round[r]) for r in mean_rounds]
        ax.plot(mean_rounds, mean_losses, color='red', linewidth=3)

    # Each baseline's values list is ordered like this dataset list.
    dataset_index = ['3', '4', '7', '8'].index(dataset)
    for baseline_values, color in baselines.values():
        ax.axhline(y=baseline_values[dataset_index], color=color, linestyle='--')

    ax.grid(True)


def main(legend_ncol=8, *, log_root='.', output_path='baseline.jpg',
         num_clients=10, font_path=None):
    """Draw the metric x dataset grid of loss curves, save it and show it.

    For every metric in ['MAE', 'RMSE', 'MAPE'] and dataset in
    ['3', '4', '7', '8'], reads ``<log_root>/<metric>/D<dataset>/exp_print.log``
    with ``extract_avg_loss``, draws it with ``plot_avg_loss`` together with
    that metric's baselines, captions the subplot "(a) MAE on PEMSD3" and so
    on, adds one shared legend at the top, saves the figure and shows it.

    Args:
        legend_ncol: Number of columns of the shared legend.
        log_root: Directory that contains the ``<metric>/D<dataset>`` folders.
        output_path: File the figure is saved to.
        num_clients: Number of "Client #k" legend entries. Entry k uses the
            k-th ``tab20`` color (wrapping after 20), which is the color the
            k-th client in ascending id order gets in each subplot; the labels
            therefore assume client ids 1..num_clients.
        font_path: Optional font file (e.g. a Times New Roman .ttf) used for
            the subplot captions; the default font is used when ``None``.
    """
    metrics = ['MAE', 'RMSE', 'MAPE']
    datasets = ['3', '4', '7', '8']

    fig = plt.figure(figsize=(16, 8), constrained_layout=False)
    # Grid shape follows the lists above instead of a hard-coded 3 x 4.
    gs = GridSpec(len(metrics), len(datasets), figure=fig, wspace=0.3, hspace=0.6)

    if font_path is None:
        caption_style = {'fontsize': 10, 'fontweight': 'bold'}
    else:
        caption_style = {'fontproperties': FontProperties(fname=font_path, size=10, weight='bold')}

    # One subplot per (metric, dataset) pair, row-major.
    for i, metric in enumerate(metrics):
        for j, dataset in enumerate(datasets):
            ax_idx = i * len(datasets) + j
            ax = fig.add_subplot(gs[i, j])
            log_file_path = os.path.join(log_root, metric, f'D{dataset}', 'exp_print.log')
            client_loss_data = extract_avg_loss(log_file_path)
            plot_avg_loss(client_loss_data, dataset, metric, ax, baseline_dict[metric])

            ax.set_ylabel(metric)
            ax.set_xlabel('Round')

            # Caption centred below the axes: "(a)", "(b)", ... in subplot order.
            subfig_label = f'({chr(ord("a") + ax_idx)})'
            title = f'{subfig_label} {metric} on PEMSD{dataset}'
            ax.text(0.5, -0.4, title, transform=ax.transAxes,
                    va='center', ha='center', **caption_style)

    # Shared legend at the top: mean curve, baselines, then the clients.
    handles = [Line2D([0], [0], color='red', linewidth=3, label='FedDGCN-Avg')]
    for model_name, (_, color) in RMSE_baselines.items():
        handles.append(Line2D([0], [0], color=color, linestyle='--', label=model_name))
    # Same tab20 colors, with the same wrap-around, as the per-client curves.
    client_colors = plt.get_cmap('tab20').colors
    client_handles = [Line2D([0], [0], color=client_colors[k % len(client_colors)], lw=2)
                      for k in range(num_clients)]
    client_labels = [f'Client #{k + 1}' for k in range(num_clients)]

    # Legend with a black frame.
    legend = fig.legend(handles=handles + client_handles,
                        labels=[h.get_label() for h in handles] + client_labels,
                        loc='upper center', bbox_to_anchor=(0.5, 0.98),
                        ncol=legend_ncol, fontsize='small', frameon=True, edgecolor='black')
    legend.get_frame().set_linewidth(1.5)  # frame line width

    fig.savefig(output_path, bbox_inches='tight')
    plt.show()


# Run main() only when executed as a script, not when imported.
if __name__ == '__main__':
    main()