diff --git a/exp/MAPE/drawResult.py b/exp/MAPE/drawResult.py deleted file mode 100644 index f336bad..0000000 --- a/exp/MAPE/drawResult.py +++ /dev/null @@ -1,85 +0,0 @@ -import re -import matplotlib.pyplot as plt -import numpy as np -from matplotlib import cm - -# 全局字典,包含模型在不同数据集上的效果值和颜色 -MAPE_baselines = { - 'STG-NCDE': ([15.57, 19.21, 20.53, 15.45], 'darkblue'), - 'DCRNN': ([17.99, 21.22, 25.22, 16.82], 'darkgreen'), - 'AGCRN': ([15.98, 19.83, 22.37, 15.95], 'darkorange'), - 'STGODE': ([16.50, 20.84, 22.59, 16.81], 'purple') -} - -def extract_avg_loss(log_file_path): - avg_loss_pattern = re.compile(r"Client #(\d+).*?Round.*?(\d+).*?test_loss': (\d+\.\d+)") - client_loss_data = {} - - with open(log_file_path, 'r') as f: - log_content = f.read() - matches = avg_loss_pattern.findall(log_content) - - for match in matches: - if int(match[1]) <= 20: - client_id = int(match[0]) - round_id = int(match[1]) - avg_loss = float(match[2]) - else: - continue - - if client_id not in client_loss_data: - client_loss_data[client_id] = [] - client_loss_data[client_id].append((round_id, avg_loss)) - - return client_loss_data - - -def plot_avg_loss(client_loss_data, dataset): - plt.figure(figsize=(6, 4)) - all_clients_avg_losses = [] - handles = [] - labels = [] - - sorted_clients = sorted(client_loss_data.keys()) - colors = plt.get_cmap('tab20').colors # 使用 tab20 调色板 - - for idx, client_id in enumerate(sorted_clients): - losses = client_loss_data[client_id] - rounds = [r[0] for r in losses] - avg_losses = [r[1] for r in losses] - color = colors[idx % len(colors)] # 为每个 client 分配不同颜色 - line, = plt.plot(rounds, avg_losses, label=f'Client #{client_id}', color=color) - handles.append(line) - labels.append(f'Client #{client_id}') - all_clients_avg_losses.append(avg_losses) - - mean_avg_loss_per_round = np.mean(np.array(all_clients_avg_losses), axis=0) - mean_line, = plt.plot(rounds, mean_avg_loss_per_round, label='Mean Test Loss', color='red', linewidth=3) - handles.append(mean_line) - labels.append('Mean Avg Loss') - - # 添加模型基准线 - dataset_index = ['3', '4', '7', '8'].index(dataset) - for model_name, (baseline_values, color) in model_baselines.items(): - baseline_value = baseline_values[dataset_index] - line = plt.axhline(y=baseline_value, color=color, linestyle='--', label=model_name) - handles.append(line) - labels.append(f'{model_name}') - - plt.xlabel('Round') - plt.ylabel('Test Loss') - plt.title(f'Client Test Loss over Rounds in PeMSD{dataset}') - plt.grid(True) - - plt.legend(handles=handles, labels=labels, loc='center left', bbox_to_anchor=(1, 0.5)) - plt.savefig(f'D{dataset}_MAPE.png', bbox_inches='tight') - plt.show() - - -if __name__ == '__main__': - datasets = ['4', '7'] - - for dataset in datasets: - log_file_path = f'./D{dataset}_MAPE/exp_print.log' - client_loss_data = extract_avg_loss(log_file_path) - plot_avg_loss(client_loss_data, dataset) diff --git a/exp/RMSE/D3_RMSE.png b/exp/RMSE/D3_RMSE.png deleted file mode 100644 index 051e38d..0000000 Binary files a/exp/RMSE/D3_RMSE.png and /dev/null differ diff --git a/exp/RMSE/D4_RMSE.png b/exp/RMSE/D4_RMSE.png deleted file mode 100644 index 6755d1c..0000000 Binary files a/exp/RMSE/D4_RMSE.png and /dev/null differ diff --git a/exp/RMSE/D7_RMSE.png b/exp/RMSE/D7_RMSE.png deleted file mode 100644 index b148ef0..0000000 Binary files a/exp/RMSE/D7_RMSE.png and /dev/null differ diff --git a/exp/RMSE/D8_RMSE.png b/exp/RMSE/D8_RMSE.png deleted file mode 100644 index 75f35f5..0000000 Binary files a/exp/RMSE/D8_RMSE.png and /dev/null differ diff --git a/exp/RMSE/drawResult.py b/exp/RMSE/drawResult.py deleted file mode 100644 index a3e9070..0000000 --- a/exp/RMSE/drawResult.py +++ /dev/null @@ -1,85 +0,0 @@ -import re -import matplotlib.pyplot as plt -import numpy as np -from matplotlib import cm - -# 全局字典,包含模型在不同数据集上的效果值和颜色 -RMSE_baselines = { - 'STG-NCDE': ([27.09, 31.09, 33.84, 24.81], 'darkblue'), - 'DCRNN': ([30.31, 33.44, 38.61, 26.36], 'darkgreen'), - 'AGCRN': ([28.25, 32.26, 36.55, 25.22], 'darkorange'), - 'STGODE': ([27.84, 32.82, 37.54, 25.97], 'purple') -} - -def extract_avg_loss(log_file_path): - avg_loss_pattern = re.compile(r"Client #(\d+).*?Round.*?(\d+).*?test_loss': (\d+\.\d+)") - client_loss_data = {} - - with open(log_file_path, 'r') as f: - log_content = f.read() - matches = avg_loss_pattern.findall(log_content) - - for match in matches: - if int(match[1]) <= 59: - client_id = int(match[0]) - round_id = int(match[1]) - avg_loss = float(match[2]) - else: - continue - - if client_id not in client_loss_data: - client_loss_data[client_id] = [] - client_loss_data[client_id].append((round_id, avg_loss)) - - return client_loss_data - - -def plot_avg_loss(client_loss_data, dataset): - plt.figure(figsize=(6, 4)) - all_clients_avg_losses = [] - handles = [] - labels = [] - - sorted_clients = sorted(client_loss_data.keys()) - colors = plt.get_cmap('tab20').colors # 使用 tab20 调色板 - - for idx, client_id in enumerate(sorted_clients): - losses = client_loss_data[client_id] - rounds = [r[0] for r in losses] - avg_losses = [r[1] for r in losses] - color = colors[idx % len(colors)] # 为每个 client 分配不同颜色 - line, = plt.plot(rounds, avg_losses, label=f'Client #{client_id}', color=color) - handles.append(line) - labels.append(f'Client #{client_id}') - all_clients_avg_losses.append(avg_losses) - - mean_avg_loss_per_round = np.mean(np.array(all_clients_avg_losses), axis=0) - mean_line, = plt.plot(rounds, mean_avg_loss_per_round, label='Mean Test Loss', color='red', linewidth=3) - handles.append(mean_line) - labels.append('Mean Avg Loss') - - # 添加模型基准线 - dataset_index = ['3', '4', '7', '8'].index(dataset) - for model_name, (baseline_values, color) in model_baselines.items(): - baseline_value = baseline_values[dataset_index] - line = plt.axhline(y=baseline_value, color=color, linestyle='--', label=model_name) - handles.append(line) - labels.append(f'{model_name}') - - plt.xlabel('Round') - plt.ylabel('Test Loss') - plt.title(f'Client Test Loss over Rounds in PeMSD{dataset}') - plt.grid(True) - - plt.legend(handles=handles, labels=labels, loc='center left', bbox_to_anchor=(1, 0.5)) - plt.savefig(f'D{dataset}_RMSE.png', bbox_inches='tight') - plt.show() - - -if __name__ == '__main__': - datasets = ['4', '3', '8', '7'] - - for dataset in datasets: - log_file_path = f'./D{dataset}/exp_print.log' - client_loss_data = extract_avg_loss(log_file_path) - plot_avg_loss(client_loss_data, dataset)