"""Plot per-client federated test-loss curves against baseline models.

For every (metric, dataset) pair this script reads
``./{metric}/D{dataset}/exp_print.log``, extracts each client's per-round
``test_loss``, and draws:

* one line per client,
* a bold red line for the across-client mean,
* one dashed horizontal line per baseline model.

The resulting 3x4 grid is saved to ``baseline.jpg`` and shown.
"""
import os
import re
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.font_manager import FontProperties
from matplotlib.gridspec import GridSpec
from matplotlib.lines import Line2D

# Dataset identifiers (PEMSD3, PEMSD4, PEMSD7, PEMSD8). The order is shared by
# the figure columns and by the value lists in the baseline dictionaries below.
DATASETS = ['3', '4', '7', '8']

# Evaluation metrics, one figure row each.
METRICS = ['MAE', 'RMSE', 'MAPE']

# Log entries whose round id exceeds this are ignored when parsing.
MAX_ROUND = 60

# Per-(dataset, metric) cap on the rounds drawn for each client.
# NOTE(review): an earlier comment said "at most the first 20 rounds" for
# D7/MAPE, but the code used 60; 60 is kept here. Since MAX_ROUND is also 60,
# this cap is currently a no-op -- lower it if a 20-round limit was intended.
ROUND_LIMITS = {('7', 'MAPE'): 60}

# Baseline results per metric: model name -> (values ordered like DATASETS,
# line color).
MAE_baselines = {
    'STG-NCDE': ([15.57, 19.21, 20.53, 15.45], 'darkblue'),
    'DCRNN': ([17.99, 21.22, 25.22, 16.82], 'darkgreen'),
    'AGCRN': ([15.98, 19.83, 22.37, 15.95], 'darkorange'),
    'STGODE': ([16.50, 20.84, 22.59, 16.81], 'purple'),
}

RMSE_baselines = {
    'STG-NCDE': ([27.09, 31.09, 33.84, 24.81], 'darkblue'),
    'DCRNN': ([30.31, 33.44, 38.61, 26.36], 'darkgreen'),
    'AGCRN': ([28.25, 32.26, 36.55, 25.22], 'darkorange'),
    'STGODE': ([27.84, 32.82, 37.54, 25.97], 'purple'),
}

MAPE_baselines = {
    'STG-NCDE': ([15.06, 12.76, 8.80, 9.92], 'darkblue'),
    'DCRNN': ([18.34, 14.17, 11.82, 10.92], 'darkgreen'),
    'AGCRN': ([15.23, 12.97, 9.12, 10.09], 'darkorange'),
    'STGODE': ([16.69, 13.77, 10.14, 10.62], 'purple'),
}

baseline_dict = {'MAE': MAE_baselines, 'RMSE': RMSE_baselines, 'MAPE': MAPE_baselines}

# Captures (client id, round id, test loss) from a single log line. Without
# re.DOTALL the lazy ``.*?`` gaps cannot cross line breaks.
_AVG_LOSS_PATTERN = re.compile(r"Client #(\d+).*?Round.*?(\d+).*?test_loss': (\d+\.\d+)")


def extract_avg_loss(log_file_path, max_round=MAX_ROUND):
    """Parse per-client test losses from a training log.

    Args:
        log_file_path: Path to an ``exp_print.log`` file.
        max_round: Entries whose round id is greater than this are dropped
            (default ``MAX_ROUND`` = 60, the original hard-coded limit).

    Returns:
        dict mapping client id (int) to a list of ``(round_id, test_loss)``
        tuples, in the order they appear in the log.

    Raises:
        OSError: if the log file cannot be opened.
    """
    client_loss_data = defaultdict(list)
    # Only ASCII text is matched, so undecodable bytes (e.g. a log written in
    # another locale encoding) are replaced rather than aborting the parse.
    with open(log_file_path, 'r', encoding='utf-8', errors='replace') as f:
        log_content = f.read()

    for client_str, round_str, loss_str in _AVG_LOSS_PATTERN.findall(log_content):
        round_id = int(round_str)
        if round_id <= max_round:
            client_loss_data[int(client_str)].append((round_id, float(loss_str)))
    return dict(client_loss_data)


def plot_avg_loss(client_loss_data, dataset, metric, ax, baselines):
    """Draw client loss curves, their mean, and baseline reference lines on *ax*.

    Args:
        client_loss_data: Output of ``extract_avg_loss``: client id -> list of
            ``(round_id, loss)`` tuples.
        dataset: Dataset id, one of ``DATASETS``; selects which baseline value
            is drawn for each model.
        metric: Metric name; with *dataset* it is looked up in ``ROUND_LIMITS``.
        ax: Matplotlib Axes to draw on.
        baselines: Model name -> (per-dataset values, color), e.g.
            ``MAE_baselines``.

    Raises:
        ValueError: if *dataset* is not in ``DATASETS``.
    """
    colors = plt.get_cmap('tab20').colors
    round_limit = ROUND_LIMITS.get((dataset, metric))
    # Losses of all clients grouped by round id, used for the mean curve.
    losses_by_round = defaultdict(list)

    for idx, client_id in enumerate(sorted(client_loss_data)):
        points = sorted(client_loss_data[client_id], key=lambda p: p[0])
        if round_limit is not None:
            # Filter (round, loss) pairs together so the two never misalign.
            points = [p for p in points if p[0] <= round_limit]
        if not points:
            continue
        rounds, losses = zip(*points)
        ax.plot(rounds, losses, color=colors[idx % len(colors)])
        for round_id, loss in points:
            losses_by_round[round_id].append(loss)

    if losses_by_round:
        # Average per round id rather than per list position, so clients with
        # missing or extra rounds neither build a ragged array nor shift the mean.
        mean_rounds = sorted(losses_by_round)
        mean_losses = [np.mean(losses_by_round[r]) for r in mean_rounds]
        ax.plot(mean_rounds, mean_losses, color='red', linewidth=3)

    dataset_index = DATASETS.index(dataset)
    for baseline_values, color in baselines.values():
        ax.axhline(y=baseline_values[dataset_index], color=color, linestyle='--')
    ax.grid(True)


def main(legend_ncol=8):
    """Build the metric x dataset grid of loss plots, add a shared legend, save and show.

    Reads one log per (metric, dataset) from ``./{metric}/D{dataset}/exp_print.log``
    and writes the figure to ``baseline.jpg``.

    Args:
        legend_ncol: Number of columns in the shared legend at the top.
    """
    fig = plt.figure(figsize=(16, 8), constrained_layout=False)
    gs = GridSpec(len(METRICS), len(DATASETS), figure=fig, wspace=0.3, hspace=0.6)
    # '(a)', '(b)', ... one label per subplot, row-major.
    subfig_labels = [f'({chr(ord("a") + k)})' for k in range(len(METRICS) * len(DATASETS))]

    for i, metric in enumerate(METRICS):
        for j, dataset in enumerate(DATASETS):
            ax_idx = i * len(DATASETS) + j
            ax = fig.add_subplot(gs[i, j])

            log_file_path = f'./{metric}/D{dataset}/exp_print.log'
            client_loss_data = extract_avg_loss(log_file_path)
            plot_avg_loss(client_loss_data, dataset, metric, ax, baseline_dict[metric])

            ax.set_ylabel(metric)
            ax.set_xlabel('Round')

            # Caption centered below the axes, e.g. "(a) MAE on PEMSD3".
            # For Times New Roman, also pass
            # fontproperties=FontProperties(fname='./times.ttf').
            title = f'{subfig_labels[ax_idx]} {metric} on PEMSD{dataset}'
            ax.text(0.5, -0.4, title, transform=ax.transAxes,
                    fontsize=10, fontweight='bold', va='center', ha='center')

    # Shared legend at the top: mean curve, baseline models, then clients.
    handles = [Line2D([0], [0], color='red', linewidth=3, label='FedDGCN-Avg')]
    for model_name, (_, color) in RMSE_baselines.items():
        handles.append(Line2D([0], [0], color=color, linestyle='--', label=model_name))

    # NOTE(review): assumes ten clients whose sorted ids map to "Client #1".."#10",
    # matching the tab20 color order used in plot_avg_loss.
    cmap = plt.get_cmap('tab20')
    client_handles = [Line2D([0], [0], color=cmap(i), lw=2) for i in range(10)]
    client_labels = [f'Client #{i + 1}' for i in range(10)]

    legend = fig.legend(handles=handles + client_handles,
                        labels=[h.get_label() for h in handles] + client_labels,
                        loc='upper center', bbox_to_anchor=(0.5, 0.98), ncol=legend_ncol,
                        fontsize='small', frameon=True, edgecolor='black')
    legend.get_frame().set_linewidth(1.5)  # thicker black frame

    plt.savefig('baseline.jpg', bbox_inches='tight')
    plt.show()


if __name__ == '__main__':
    main()