94 lines
2.4 KiB
Python
94 lines
2.4 KiB
Python
import numpy as np
|
|
from sklearn.metrics import mean_squared_error, mean_absolute_error
|
|
from typing import Dict, Union
|
|
|
|
|
|
def calculate_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """
    Compute point-prediction evaluation metrics.

    Args:
        y_true: Ground-truth values.
        y_pred: Predicted values; must have the same shape as ``y_true``.

    Returns:
        A dict with keys 'MSE', 'RMSE', 'MAE', 'MAPE' (percent), 'R2' and
        'SMAPE' (percent), each a Python float.

    Raises:
        ValueError: If ``y_true`` and ``y_pred`` have different shapes.
            scikit-learn's metric functions may also raise ValueError for
            input they reject (e.g. empty or NaN-containing arrays).
    """
    # Small floor that keeps the percentage-error denominators non-zero.
    eps = 1e-8

    # Work in float so that squaring integer inputs cannot overflow.
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)

    # The numpy-based metrics below would silently broadcast mismatched
    # shapes (e.g. (n,) against (n, 1) -> (n, n)), so require equal shapes.
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred must have the same shape, "
            f"got {y_true.shape} and {y_pred.shape}"
        )

    # The scikit-learn metrics run first so that their input validation
    # happens before any of the numpy-only metrics are computed.
    mse = mean_squared_error(y_true, y_pred)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(y_true, y_pred)

    errors = y_true - y_pred
    abs_errors = np.abs(errors)

    # MAPE (%): floor |y_true| at eps instead of adding eps to the signed
    # value, which could cancel to zero for negative targets (y_true == -eps).
    # MAPE is still very large whenever true values are close to zero.
    mape = np.mean(abs_errors / np.maximum(np.abs(y_true), eps)) * 100

    # R^2, pooled over all elements (a single global mean of y_true).
    ss_res = np.sum(errors ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    if ss_tot == 0 or np.max(y_true) == np.min(y_true):
        # Constant target: R^2 is undefined. Report 1.0 for a perfect fit and
        # 0.0 otherwise, rather than dividing by a tiny (rounding-noise) ss_tot.
        r2 = 1.0 if ss_res == 0 else 0.0
    else:
        r2 = 1.0 - ss_res / ss_tot

    # SMAPE (%): eps only guards the case y_true == y_pred == 0, which then
    # contributes 0.
    smape = 2.0 * np.mean(abs_errors / (np.abs(y_true) + np.abs(y_pred) + eps)) * 100

    return {
        'MSE': float(mse),
        'RMSE': float(rmse),
        'MAE': float(mae),
        'MAPE': float(mape),
        'R2': float(r2),
        'SMAPE': float(smape),
    }
|
|
|
|
|
|
def calculate_rolling_metrics(y_true: np.ndarray, y_pred: np.ndarray,
                              window: int = 10) -> Dict[str, np.ndarray]:
    """
    Compute rolling (sliding-window) evaluation metrics.

    For every run of ``window`` consecutive samples, i.e. positions
    ``[i - window, i)`` for ``i`` in ``window .. len(y_true)``, the MSE, MAE
    and MAPE (percent) of that run are computed. Each result array therefore
    has ``len(y_true) - window + 1`` entries. For multi-dimensional input the
    samples are indexed along the first axis and each window's metric is the
    mean over all elements in that window.

    Args:
        y_true: Ground-truth values.
        y_pred: Predicted values; must have the same shape as ``y_true``.
        window: Rolling window size in samples (must be >= 1).

    Returns:
        A dict with keys 'rolling_MSE', 'rolling_MAE' and 'rolling_MAPE', each
        a 1-D array with one value per window. Returns an empty dict when
        there are fewer samples than ``window``.

    Raises:
        ValueError: If ``window < 1``; or, when at least one window exists,
            if the shapes of ``y_true`` and ``y_pred`` differ or the inputs
            contain NaN or infinity.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")

    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)

    n = len(y_true)
    if n < window:
        return {}

    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred must have the same shape, "
            f"got {y_true.shape} and {y_pred.shape}"
        )
    if not (np.isfinite(y_true).all() and np.isfinite(y_pred).all()):
        raise ValueError("Input contains NaN or infinity.")

    eps = 1e-8
    errors = y_true - y_pred
    abs_errors = np.abs(errors)
    # Floor |y_true| at eps rather than adding eps to the signed value, which
    # could cancel to a zero denominator for negative targets.
    pct_errors = abs_errors / np.maximum(np.abs(y_true), eps) * 100

    ones = np.ones(window)

    def rolling_mean(values: np.ndarray) -> np.ndarray:
        # Collapse any trailing axes to one mean per sample (every sample has
        # the same element count, so the mean of these equals the mean over
        # all window elements), then average each length-`window` run.
        per_sample = values.reshape(n, -1).mean(axis=1)
        return np.convolve(per_sample, ones, mode="valid") / window

    return {
        'rolling_MSE': rolling_mean(errors ** 2),
        'rolling_MAE': rolling_mean(abs_errors),
        'rolling_MAPE': rolling_mean(pct_errors),
    }
|