测试所有配置是否正常运行
This commit is contained in:
parent
d4ee8e309e
commit
77a3210475
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,134 @@
|
|||
import os
|
||||
import re
|
||||
|
||||
# 配置路径
|
||||
CONFIG_DIR = "/user/czzhangheng/code/TrafficWheel/config"
|
||||
LAUNCH_FILE = "/user/czzhangheng/code/TrafficWheel/.vscode/launch.json"
|
||||
|
||||
# 遍历所有yaml文件
|
||||
def find_all_yaml_files(directory):
|
||||
yaml_files = []
|
||||
for root, dirs, files in os.walk(directory):
|
||||
for file in files:
|
||||
if file.endswith(".yaml") and not file.startswith("BJTaxi"):
|
||||
yaml_files.append(os.path.join(root, file))
|
||||
return yaml_files
|
||||
|
||||
# 生成launch配置字符串
|
||||
def generate_launch_config_string(yaml_files):
|
||||
config_strings = []
|
||||
|
||||
for file_path in yaml_files:
|
||||
# 提取模型名和数据集名
|
||||
relative_path = os.path.relpath(file_path, CONFIG_DIR)
|
||||
model_name = relative_path.split(os.sep)[0]
|
||||
dataset_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
|
||||
# 处理v2版本
|
||||
if "v2_" in dataset_name:
|
||||
model_display_name = f"{model_name}_v2"
|
||||
dataset_display_name = dataset_name.replace("v2_", "")
|
||||
else:
|
||||
model_display_name = model_name
|
||||
dataset_display_name = dataset_name
|
||||
|
||||
# 生成配置字符串
|
||||
config_string = f'''
|
||||
{{
|
||||
"name": "{model_display_name}: {dataset_display_name}",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "run.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": "--config ./config/{model_name}/{os.path.basename(file_path)}"
|
||||
}}'''
|
||||
|
||||
config_strings.append(config_string)
|
||||
|
||||
return ",".join(config_strings)
|
||||
|
||||
# 读取现有的launch.json文件,提取配置名称
|
||||
def get_existing_config_names():
|
||||
with open(LAUNCH_FILE, 'r') as f:
|
||||
content = f.read()
|
||||
|
||||
# 提取所有配置名称
|
||||
name_pattern = re.compile(r'"name"\s*:\s*"([^"]+)"')
|
||||
matches = name_pattern.findall(content)
|
||||
|
||||
return set(matches)
|
||||
|
||||
# 生成新的配置,过滤掉已存在的
|
||||
def generate_new_configs(yaml_files, existing_names):
|
||||
new_configs = []
|
||||
|
||||
for file_path in yaml_files:
|
||||
# 提取模型名和数据集名
|
||||
relative_path = os.path.relpath(file_path, CONFIG_DIR)
|
||||
model_name = relative_path.split(os.sep)[0]
|
||||
dataset_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
|
||||
# 处理v2版本
|
||||
if "v2_" in dataset_name:
|
||||
model_display_name = f"{model_name}_v2"
|
||||
dataset_display_name = dataset_name.replace("v2_", "")
|
||||
else:
|
||||
model_display_name = model_name
|
||||
dataset_display_name = dataset_name
|
||||
|
||||
# 生成配置名称
|
||||
config_name = f"{model_display_name}: {dataset_display_name}"
|
||||
|
||||
# 如果配置不存在,则添加
|
||||
if config_name not in existing_names:
|
||||
new_configs.append(file_path)
|
||||
|
||||
return new_configs
|
||||
|
||||
# 更新launch.json文件
|
||||
def update_launch_json(new_configs_string):
|
||||
with open(LAUNCH_FILE, 'r') as f:
|
||||
content = f.read()
|
||||
|
||||
# 找到configurations数组的结束位置
|
||||
configs_end_match = re.search(r'\s*\]\s*\}', content)
|
||||
if not configs_end_match:
|
||||
return False
|
||||
|
||||
# 插入新的配置
|
||||
insert_pos = configs_end_match.start()
|
||||
new_content = content[:insert_pos] + new_configs_string + content[insert_pos:]
|
||||
|
||||
# 保存文件
|
||||
with open(LAUNCH_FILE, 'w') as f:
|
||||
f.write(new_content)
|
||||
|
||||
return True
|
||||
|
||||
# 主函数
|
||||
def main():
|
||||
# 查找所有yaml文件
|
||||
yaml_files = find_all_yaml_files(CONFIG_DIR)
|
||||
|
||||
# 获取现有配置名称
|
||||
existing_names = get_existing_config_names()
|
||||
|
||||
# 生成新的配置,过滤掉已存在的
|
||||
new_config_files = generate_new_configs(yaml_files, existing_names)
|
||||
|
||||
if not new_config_files:
|
||||
print("No new configurations to add")
|
||||
return
|
||||
|
||||
# 生成新的配置字符串
|
||||
new_configs_string = generate_launch_config_string(new_config_files)
|
||||
|
||||
# 更新launch.json文件
|
||||
if update_launch_json(new_configs_string):
|
||||
print(f"Added {len(new_config_files)} new launch configurations")
|
||||
print(f"Total configurations: {len(existing_names) + len(new_config_files)}")
|
||||
else:
|
||||
print("Failed to update launch.json")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,124 @@
|
|||
import os
|
||||
import subprocess
|
||||
import yaml
|
||||
import time
|
||||
|
||||
# 配置路径
|
||||
CONFIG_DIR = "/user/czzhangheng/code/TrafficWheel/config"
|
||||
RUN_SCRIPT = "/user/czzhangheng/code/TrafficWheel/run.py"
|
||||
RESULTS_FILE = "/user/czzhangheng/code/TrafficWheel/test_results.txt"
|
||||
|
||||
# 记录测试结果的字典
|
||||
results = {
|
||||
"passed": [],
|
||||
"failed": [],
|
||||
"error": []
|
||||
}
|
||||
|
||||
# 遍历所有yaml文件
|
||||
def find_all_yaml_files(directory):
|
||||
yaml_files = []
|
||||
for root, dirs, files in os.walk(directory):
|
||||
for file in files:
|
||||
if file.endswith(".yaml") and not file.startswith("BJTaxi"):
|
||||
yaml_files.append(os.path.join(root, file))
|
||||
return yaml_files
|
||||
|
||||
# 测试单个yaml文件
|
||||
def test_yaml_file(yaml_path):
|
||||
print(f"\n=== Testing {yaml_path} ===")
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(yaml_path):
|
||||
print(f"File not found: {yaml_path}")
|
||||
return "error", f"File not found: {yaml_path}"
|
||||
|
||||
# 运行测试命令
|
||||
command = ["python", RUN_SCRIPT, "--config", yaml_path]
|
||||
try:
|
||||
result = subprocess.run(
|
||||
command,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=600 # 10分钟超时
|
||||
)
|
||||
|
||||
# 分析结果
|
||||
if result.returncode == 0:
|
||||
if "Test passed" in result.stdout:
|
||||
print(f"✓ PASSED: {yaml_path}")
|
||||
return "passed", result.stdout.strip()
|
||||
else:
|
||||
print(f"✗ FAILED: {yaml_path}")
|
||||
return "failed", result.stdout.strip() + "\n" + result.stderr.strip()
|
||||
else:
|
||||
print(f"✗ ERROR: {yaml_path}")
|
||||
return "error", result.stdout.strip() + "\n" + result.stderr.strip()
|
||||
except subprocess.TimeoutExpired:
|
||||
print(f"✗ TIMEOUT: {yaml_path}")
|
||||
return "error", "Timeout after 10 minutes"
|
||||
except Exception as e:
|
||||
print(f"✗ EXCEPTION: {yaml_path}")
|
||||
return "error", str(e)
|
||||
|
||||
# 生成测试报告
|
||||
def generate_report(results):
|
||||
total = len(results["passed"]) + len(results["failed"]) + len(results["error"])
|
||||
|
||||
report = f"""# 测试报告
|
||||
|
||||
## 测试概述
|
||||
- 测试时间: {time.strftime('%Y-%m-%d %H:%M:%S')}
|
||||
- 总测试文件数: {total}
|
||||
- 通过: {len(results['passed'])}
|
||||
- 失败: {len(results['failed'])}
|
||||
- 错误: {len(results['error'])}
|
||||
|
||||
## 通过的配置文件
|
||||
"""
|
||||
|
||||
for file_path, output in results["passed"]:
|
||||
report += f"- ✅ {file_path}\n"
|
||||
|
||||
report += "\n## 失败的配置文件\n"
|
||||
for file_path, output in results["failed"]:
|
||||
report += f"- ❌ {file_path}\n"
|
||||
|
||||
report += "\n## 出错的配置文件\n"
|
||||
for file_path, output in results["error"]:
|
||||
report += f"- ⚠️ {file_path}\n"
|
||||
|
||||
report += "\n## 详细输出\n"
|
||||
|
||||
for status, files in results.items():
|
||||
report += f"\n### {status.upper()}\n\n"
|
||||
for file_path, output in files:
|
||||
report += f"#### {file_path}\n\n```\n{output}\n```\n\n"
|
||||
|
||||
return report
|
||||
|
||||
# 主函数
|
||||
def main():
|
||||
# 找到所有符合条件的yaml文件
|
||||
yaml_files = find_all_yaml_files(CONFIG_DIR)
|
||||
print(f"Found {len(yaml_files)} yaml files to test")
|
||||
|
||||
# 测试每个文件
|
||||
for yaml_file in yaml_files:
|
||||
status, output = test_yaml_file(yaml_file)
|
||||
results[status].append((yaml_file, output))
|
||||
|
||||
# 生成并保存报告
|
||||
report = generate_report(results)
|
||||
with open(RESULTS_FILE, "w") as f:
|
||||
f.write(report)
|
||||
|
||||
print(f"\n=== Test Results ===")
|
||||
print(f"Total: {len(yaml_files)}")
|
||||
print(f"Passed: {len(results['passed'])}")
|
||||
print(f"Failed: {len(results['failed'])}")
|
||||
print(f"Error: {len(results['error'])}")
|
||||
print(f"Report saved to: {RESULTS_FILE}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -103,6 +103,16 @@ class Trainer:
|
|||
output = self.model(data, labels=label.clone()).to(self.device)
|
||||
loss = self.loss(output, label)
|
||||
|
||||
# 检查output和label的shape是否一致
|
||||
if output.shape == label.shape:
|
||||
print(f"✓ Test passed: output shape {output.shape} matches label shape {label.shape}")
|
||||
import sys
|
||||
sys.exit(0)
|
||||
else:
|
||||
print(f"✗ Test failed: output shape {output.shape} does not match label shape {label.shape}")
|
||||
import sys
|
||||
sys.exit(1)
|
||||
|
||||
# 反归一化
|
||||
d_output = self.scaler.inverse_transform(output)
|
||||
d_label = self.scaler.inverse_transform(label)
|
||||
|
|
|
|||
|
|
@ -108,6 +108,16 @@ class Trainer:
|
|||
label = target[..., : self.args["output_dim"]]
|
||||
loss = self.loss(output, label)
|
||||
|
||||
# 检查output和label的shape是否一致
|
||||
if output.shape == label.shape:
|
||||
print(f"✓ Test passed: output shape {output.shape} matches label shape {label.shape}")
|
||||
import sys
|
||||
sys.exit(0)
|
||||
else:
|
||||
print(f"✗ Test failed: output shape {output.shape} does not match label shape {label.shape}")
|
||||
import sys
|
||||
sys.exit(1)
|
||||
|
||||
# 反归一化
|
||||
d_output = self.scaler.inverse_transform(output)
|
||||
d_label = self.scaler.inverse_transform(label)
|
||||
|
|
|
|||
|
|
@ -107,6 +107,16 @@ class Trainer:
|
|||
# 计算原始loss
|
||||
loss = self.loss(output, label)
|
||||
|
||||
# 检查output和label的shape是否一致
|
||||
if output.shape == label.shape:
|
||||
print(f"✓ Test passed: output shape {output.shape} matches label shape {label.shape}")
|
||||
import sys
|
||||
sys.exit(0)
|
||||
else:
|
||||
print(f"✗ Test failed: output shape {output.shape} does not match label shape {label.shape}")
|
||||
import sys
|
||||
sys.exit(1)
|
||||
|
||||
# 反归一化
|
||||
d_output = self.scaler.inverse_transform(output)
|
||||
d_label = self.scaler.inverse_transform(label)
|
||||
|
|
|
|||
|
|
@ -137,11 +137,31 @@ class Trainer:
|
|||
|
||||
# 总loss
|
||||
loss = loss1 + 10 * tkloss + 1 * scl
|
||||
|
||||
# 检查output和label的shape是否一致
|
||||
if output.shape == label.shape:
|
||||
print(f"✓ Test passed: output shape {output.shape} matches label shape {label.shape}")
|
||||
import sys
|
||||
sys.exit(0)
|
||||
else:
|
||||
print(f"✗ Test failed: output shape {output.shape} does not match label shape {label.shape}")
|
||||
import sys
|
||||
sys.exit(1)
|
||||
else:
|
||||
# 普通训练模式
|
||||
output, out_, _ = self.model(data)
|
||||
loss = self.loss(output, label)
|
||||
|
||||
# 检查output和label的shape是否一致
|
||||
if output.shape == label.shape:
|
||||
print(f"✓ Test passed: output shape {output.shape} matches label shape {label.shape}")
|
||||
import sys
|
||||
sys.exit(0)
|
||||
else:
|
||||
print(f"✗ Test failed: output shape {output.shape} does not match label shape {label.shape}")
|
||||
import sys
|
||||
sys.exit(1)
|
||||
|
||||
# 反归一化
|
||||
d_output = self.scaler.inverse_transform(output)
|
||||
d_label = self.scaler.inverse_transform(label)
|
||||
|
|
|
|||
|
|
@ -110,6 +110,16 @@ class Trainer:
|
|||
# 计算原始loss
|
||||
loss = self.loss(output, label)
|
||||
|
||||
# 检查output和label的shape是否一致
|
||||
if output.shape == label.shape:
|
||||
print(f"✓ Test passed: output shape {output.shape} matches label shape {label.shape}")
|
||||
import sys
|
||||
sys.exit(0)
|
||||
else:
|
||||
print(f"✗ Test failed: output shape {output.shape} does not match label shape {label.shape}")
|
||||
import sys
|
||||
sys.exit(1)
|
||||
|
||||
# 反归一化
|
||||
d_output = self.scaler.inverse_transform(output)
|
||||
d_label = self.scaler.inverse_transform(label)
|
||||
|
|
|
|||
Loading…
Reference in New Issue