TrafficWheel/run_tests.sh

95 lines
2.5 KiB
Bash
Executable File

#!/bin/bash
# 设置默认模型名和数据集列表
MODEL_NAME="STAEFormer"
DATASETS=(
"METR-LA"
"PEMS-BAY"
"NYCBike-InFlow"
"NYCBike-OutFlow"
"AirQuality"
"SolarEnergy"
)
# 初始化统计变量
success_count=0
failure_count=0
missing_count=0
total_count=0
success_datasets=()
failure_datasets=()
missing_datasets=()
# 检查是否有参数传入来覆盖默认值
if [ $# -gt 0 ]; then
MODEL_NAME=$1
# 如果传入了更多参数,使用它们作为数据集列表
if [ $# -gt 1 ]; then
DATASETS=(${@:2})
fi
fi
echo "使用模型: $MODEL_NAME"
echo "数据集列表: ${DATASETS[*]}"
echo "开始测试..."
echo ""
# 循环测试每个数据集
for dataset in "${DATASETS[@]}"; do
total_count=$((total_count + 1))
# 构建配置文件路径
CONFIG_PATH="config/${MODEL_NAME}/${dataset}.yaml"
echo "测试数据集: $dataset"
echo "使用配置文件: $CONFIG_PATH"
# 检查配置文件是否存在
if [ ! -f "$CONFIG_PATH" ]; then
echo "错误: 配置文件 $CONFIG_PATH 不存在!"
missing_count=$((missing_count + 1))
missing_datasets+=("$dataset")
echo "----------------------------------------"
continue
fi
# 执行测试命令,同时捕获输出并显示在控制台上
echo "执行: python run.py --config $CONFIG_PATH"
output=$(python run.py --config "$CONFIG_PATH" 2>&1 | tee /dev/tty)
# 如果没有找到明确的标记,回退到检查退出码
if [ $? -eq 0 ]; then
echo "数据集 $dataset 测试成功! (基于退出码)"
success_count=$((success_count + 1))
success_datasets+=("$dataset")
else
echo "数据集 $dataset 测试失败! (基于退出码)"
failure_count=$((failure_count + 1))
failure_datasets+=("$dataset")
fi
echo "----------------------------------------"
done
# 输出总结
echo "======================================="
echo "测试总结"
echo "======================================="
echo "总数据集数量: $total_count"
echo "成功数量: $success_count"
echo "失败数量: $failure_count"
echo "缺失配置文件数量: $missing_count"
if [ ${#success_datasets[@]} -gt 0 ]; then
echo "成功的数据集: ${success_datasets[*]}"
fi
if [ ${#failure_datasets[@]} -gt 0 ]; then
echo "失败的数据集: ${failure_datasets[*]}"
fi
if [ ${#missing_datasets[@]} -gt 0 ]; then
echo "缺失配置的数据集: ${missing_datasets[*]}"
fi
echo "======================================="
echo "所有测试完成!"