our experiment log
This commit is contained in:
parent
d236ba5eae
commit
069ca21618
|
|
@ -1,85 +0,0 @@
|
||||||
import re
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
from matplotlib import cm
|
|
||||||
|
|
||||||
# 全局字典,包含模型在不同数据集上的效果值和颜色
|
|
||||||
MAPE_baselines = {
|
|
||||||
'STG-NCDE': ([15.57, 19.21, 20.53, 15.45], 'darkblue'),
|
|
||||||
'DCRNN': ([17.99, 21.22, 25.22, 16.82], 'darkgreen'),
|
|
||||||
'AGCRN': ([15.98, 19.83, 22.37, 15.95], 'darkorange'),
|
|
||||||
'STGODE': ([16.50, 20.84, 22.59, 16.81], 'purple')
|
|
||||||
}
|
|
||||||
|
|
||||||
def extract_avg_loss(log_file_path):
|
|
||||||
avg_loss_pattern = re.compile(r"Client #(\d+).*?Round.*?(\d+).*?test_loss': (\d+\.\d+)")
|
|
||||||
client_loss_data = {}
|
|
||||||
|
|
||||||
with open(log_file_path, 'r') as f:
|
|
||||||
log_content = f.read()
|
|
||||||
matches = avg_loss_pattern.findall(log_content)
|
|
||||||
|
|
||||||
for match in matches:
|
|
||||||
if int(match[1]) <= 20:
|
|
||||||
client_id = int(match[0])
|
|
||||||
round_id = int(match[1])
|
|
||||||
avg_loss = float(match[2])
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if client_id not in client_loss_data:
|
|
||||||
client_loss_data[client_id] = []
|
|
||||||
client_loss_data[client_id].append((round_id, avg_loss))
|
|
||||||
|
|
||||||
return client_loss_data
|
|
||||||
|
|
||||||
|
|
||||||
def plot_avg_loss(client_loss_data, dataset):
|
|
||||||
plt.figure(figsize=(6, 4))
|
|
||||||
all_clients_avg_losses = []
|
|
||||||
handles = []
|
|
||||||
labels = []
|
|
||||||
|
|
||||||
sorted_clients = sorted(client_loss_data.keys())
|
|
||||||
colors = plt.get_cmap('tab20').colors # 使用 tab20 调色板
|
|
||||||
|
|
||||||
for idx, client_id in enumerate(sorted_clients):
|
|
||||||
losses = client_loss_data[client_id]
|
|
||||||
rounds = [r[0] for r in losses]
|
|
||||||
avg_losses = [r[1] for r in losses]
|
|
||||||
color = colors[idx % len(colors)] # 为每个 client 分配不同颜色
|
|
||||||
line, = plt.plot(rounds, avg_losses, label=f'Client #{client_id}', color=color)
|
|
||||||
handles.append(line)
|
|
||||||
labels.append(f'Client #{client_id}')
|
|
||||||
all_clients_avg_losses.append(avg_losses)
|
|
||||||
|
|
||||||
mean_avg_loss_per_round = np.mean(np.array(all_clients_avg_losses), axis=0)
|
|
||||||
mean_line, = plt.plot(rounds, mean_avg_loss_per_round, label='Mean Test Loss', color='red', linewidth=3)
|
|
||||||
handles.append(mean_line)
|
|
||||||
labels.append('Mean Avg Loss')
|
|
||||||
|
|
||||||
# 添加模型基准线
|
|
||||||
dataset_index = ['3', '4', '7', '8'].index(dataset)
|
|
||||||
for model_name, (baseline_values, color) in model_baselines.items():
|
|
||||||
baseline_value = baseline_values[dataset_index]
|
|
||||||
line = plt.axhline(y=baseline_value, color=color, linestyle='--', label=model_name)
|
|
||||||
handles.append(line)
|
|
||||||
labels.append(f'{model_name}')
|
|
||||||
|
|
||||||
plt.xlabel('Round')
|
|
||||||
plt.ylabel('Test Loss')
|
|
||||||
plt.title(f'Client Test Loss over Rounds in PeMSD{dataset}')
|
|
||||||
plt.grid(True)
|
|
||||||
|
|
||||||
plt.legend(handles=handles, labels=labels, loc='center left', bbox_to_anchor=(1, 0.5))
|
|
||||||
plt.savefig(f'D{dataset}_MAPE.png', bbox_inches='tight')
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
datasets = ['4', '7']
|
|
||||||
|
|
||||||
for dataset in datasets:
|
|
||||||
log_file_path = f'./D{dataset}_MAPE/exp_print.log'
|
|
||||||
client_loss_data = extract_avg_loss(log_file_path)
|
|
||||||
plot_avg_loss(client_loss_data, dataset)
|
|
||||||
Binary file not shown.
|
Before Width: | Height: | Size: 85 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 88 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 106 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 103 KiB |
|
|
@ -1,85 +0,0 @@
|
||||||
import re
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
from matplotlib import cm
|
|
||||||
|
|
||||||
# 全局字典,包含模型在不同数据集上的效果值和颜色
|
|
||||||
RMSE_baselines = {
|
|
||||||
'STG-NCDE': ([27.09, 31.09, 33.84, 24.81], 'darkblue'),
|
|
||||||
'DCRNN': ([30.31, 33.44, 38.61, 26.36], 'darkgreen'),
|
|
||||||
'AGCRN': ([28.25, 32.26, 36.55, 25.22], 'darkorange'),
|
|
||||||
'STGODE': ([27.84, 32.82, 37.54, 25.97], 'purple')
|
|
||||||
}
|
|
||||||
|
|
||||||
def extract_avg_loss(log_file_path):
|
|
||||||
avg_loss_pattern = re.compile(r"Client #(\d+).*?Round.*?(\d+).*?test_loss': (\d+\.\d+)")
|
|
||||||
client_loss_data = {}
|
|
||||||
|
|
||||||
with open(log_file_path, 'r') as f:
|
|
||||||
log_content = f.read()
|
|
||||||
matches = avg_loss_pattern.findall(log_content)
|
|
||||||
|
|
||||||
for match in matches:
|
|
||||||
if int(match[1]) <= 59:
|
|
||||||
client_id = int(match[0])
|
|
||||||
round_id = int(match[1])
|
|
||||||
avg_loss = float(match[2])
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if client_id not in client_loss_data:
|
|
||||||
client_loss_data[client_id] = []
|
|
||||||
client_loss_data[client_id].append((round_id, avg_loss))
|
|
||||||
|
|
||||||
return client_loss_data
|
|
||||||
|
|
||||||
|
|
||||||
def plot_avg_loss(client_loss_data, dataset):
|
|
||||||
plt.figure(figsize=(6, 4))
|
|
||||||
all_clients_avg_losses = []
|
|
||||||
handles = []
|
|
||||||
labels = []
|
|
||||||
|
|
||||||
sorted_clients = sorted(client_loss_data.keys())
|
|
||||||
colors = plt.get_cmap('tab20').colors # 使用 tab20 调色板
|
|
||||||
|
|
||||||
for idx, client_id in enumerate(sorted_clients):
|
|
||||||
losses = client_loss_data[client_id]
|
|
||||||
rounds = [r[0] for r in losses]
|
|
||||||
avg_losses = [r[1] for r in losses]
|
|
||||||
color = colors[idx % len(colors)] # 为每个 client 分配不同颜色
|
|
||||||
line, = plt.plot(rounds, avg_losses, label=f'Client #{client_id}', color=color)
|
|
||||||
handles.append(line)
|
|
||||||
labels.append(f'Client #{client_id}')
|
|
||||||
all_clients_avg_losses.append(avg_losses)
|
|
||||||
|
|
||||||
mean_avg_loss_per_round = np.mean(np.array(all_clients_avg_losses), axis=0)
|
|
||||||
mean_line, = plt.plot(rounds, mean_avg_loss_per_round, label='Mean Test Loss', color='red', linewidth=3)
|
|
||||||
handles.append(mean_line)
|
|
||||||
labels.append('Mean Avg Loss')
|
|
||||||
|
|
||||||
# 添加模型基准线
|
|
||||||
dataset_index = ['3', '4', '7', '8'].index(dataset)
|
|
||||||
for model_name, (baseline_values, color) in model_baselines.items():
|
|
||||||
baseline_value = baseline_values[dataset_index]
|
|
||||||
line = plt.axhline(y=baseline_value, color=color, linestyle='--', label=model_name)
|
|
||||||
handles.append(line)
|
|
||||||
labels.append(f'{model_name}')
|
|
||||||
|
|
||||||
plt.xlabel('Round')
|
|
||||||
plt.ylabel('Test Loss')
|
|
||||||
plt.title(f'Client Test Loss over Rounds in PeMSD{dataset}')
|
|
||||||
plt.grid(True)
|
|
||||||
|
|
||||||
plt.legend(handles=handles, labels=labels, loc='center left', bbox_to_anchor=(1, 0.5))
|
|
||||||
plt.savefig(f'D{dataset}_RMSE.png', bbox_inches='tight')
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
datasets = ['4', '3', '8', '7']
|
|
||||||
|
|
||||||
for dataset in datasets:
|
|
||||||
log_file_path = f'./D{dataset}/exp_print.log'
|
|
||||||
client_loss_data = extract_avg_loss(log_file_path)
|
|
||||||
plot_avg_loss(client_loss_data, dataset)
|
|
||||||
Loading…
Reference in New Issue