TrafficWheel/generate_launch_configs.py

134 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import re
# Paths this script works with. They are absolute and machine-specific; adjust
# them for your own checkout.
#   CONFIG_DIR  - root of the per-model YAML config tree that is scanned.
#   LAUNCH_FILE - the VS Code launch.json that new debug entries are appended to.
CONFIG_DIR = "/user/czzhangheng/code/TrafficWheel/config"
LAUNCH_FILE = "/user/czzhangheng/code/TrafficWheel/.vscode/launch.json"
# Walk a directory tree and collect every YAML config file in it.
def find_all_yaml_files(directory, exclude_prefixes=("BJTaxi",)):
    """Return the paths of all ``.yaml`` files under *directory*, in a stable order.

    Args:
        directory: Root directory that is walked recursively.
        exclude_prefixes: File-name prefix (str) or prefixes (iterable of str)
            to skip. The default ``("BJTaxi",)`` reproduces the historical
            hard-coded exclusion.

    Returns:
        A list of full paths (``os.path.join(root, name)``), ordered by
        directory and then by file name. Repeated runs therefore produce
        launch entries in the same order.
    """
    if isinstance(exclude_prefixes, str):
        exclude_prefixes = (exclude_prefixes,)
    # str.startswith accepts a tuple; an empty tuple matches nothing.
    exclude_prefixes = tuple(exclude_prefixes)
    yaml_files = []
    for root, dirs, files in os.walk(directory):
        # os.walk yields entries in arbitrary filesystem order. Sorting `dirs`
        # in place also fixes the order in which subdirectories are visited.
        dirs.sort()
        for name in sorted(files):
            if name.endswith(".yaml") and not name.startswith(exclude_prefixes):
                yaml_files.append(os.path.join(root, name))
    return yaml_files
# Build the launch.json text for one debug configuration per YAML config file.
def generate_launch_config_string(yaml_files, config_dir=None):
    """Render a comma-separated sequence of launch.json configuration objects.

    Each entry runs ``run.py`` under debugpy with ``--config <yaml>``. The entry
    name is ``"<model>: <dataset>"``, where <model> is the first path component
    below *config_dir* and <dataset> is the YAML file's stem. When the stem
    contains ``"v2_"``, the model gets a ``_v2`` suffix and every ``"v2_"`` is
    removed from the dataset part.

    Args:
        yaml_files: Iterable of YAML file paths located under *config_dir*.
        config_dir: Root of the config tree. Defaults to the module-level
            ``CONFIG_DIR``.

    Returns:
        The JSON objects, each preceded by a newline and joined with ``","``.
        This is suitable for splicing into the "configurations" array. An
        empty input yields ``""``.
    """
    if config_dir is None:
        config_dir = CONFIG_DIR
    entries = []
    for file_path in yaml_files:
        # Derive the model name (first folder) and the dataset name (file stem).
        relative_path = os.path.relpath(file_path, config_dir)
        model_name = relative_path.split(os.sep)[0]
        dataset_name = os.path.splitext(os.path.basename(file_path))[0]
        # Handle "v2_" variants.
        if "v2_" in dataset_name:
            model_display_name = f"{model_name}_v2"
            dataset_display_name = dataset_name.replace("v2_", "")
        else:
            model_display_name = model_name
            dataset_display_name = dataset_name
        config = {
            "name": f"{model_display_name}: {dataset_display_name}",
            "type": "debugpy",
            "request": "launch",
            "program": "run.py",
            "console": "integratedTerminal",
            # Use the full relative path (forward slashes) so files placed
            # directly in config_dir or in nested subfolders resolve correctly.
            "args": f"--config ./config/{relative_path.replace(os.sep, '/')}",
        }
        # json.dumps escapes quotes/backslashes so the result is always valid
        # JSON. ensure_ascii=False keeps non-ASCII names human-readable.
        entries.append("\n" + json.dumps(config, indent=4, ensure_ascii=False))
    return ",".join(entries)
# Read launch.json and collect the "name" value of every entry already present.
def get_existing_config_names(launch_file=None):
    """Return the set of ``"name"`` string values found in the launch file.

    A regex is used instead of ``json.load`` because launch.json is
    JSON-with-comments and may not parse with the ``json`` module. Every
    ``"name": "..."`` pair in the file is collected, wherever it appears.

    Args:
        launch_file: Path to read. Defaults to the module-level ``LAUNCH_FILE``.

    Returns:
        A set of names with JSON escapes (e.g. ``\\"``) decoded.

    Raises:
        OSError: if the file cannot be opened (e.g. it does not exist).
    """
    if launch_file is None:
        launch_file = LAUNCH_FILE
    # Explicit UTF-8: names may contain non-ASCII text, and the locale default
    # encoding is not guaranteed to be UTF-8.
    with open(launch_file, "r", encoding="utf-8") as f:
        content = f.read()
    # The value pattern accepts backslash escapes, so names containing \" are
    # not truncated.
    name_pattern = re.compile(r'"name"\s*:\s*"((?:[^"\\]|\\.)+)"')
    names = set()
    for raw in name_pattern.findall(content):
        try:
            names.add(json.loads(f'"{raw}"'))
        except json.JSONDecodeError:
            # Not a valid JSON string literal (e.g. raw control characters);
            # keep the text as written.
            names.add(raw)
    return names
# Keep only the YAML files whose launch entry name is not already taken.
def generate_new_configs(yaml_files, existing_names, config_dir=None):
    """Filter *yaml_files* down to those that would add a new launch entry name.

    The name is derived as ``"<model>: <dataset>"``: <model> is the first path
    component below *config_dir*, and <dataset> is the file stem. A stem
    containing ``"v2_"`` becomes ``"<model>_v2: <stem without v2_>"``.

    A file is skipped when its name is already in *existing_names* or was
    produced by an earlier file in the same call. This prevents duplicate
    launch names, e.g. ``M/a/X.yaml`` and ``M/b/X.yaml`` both mapping to
    ``"M: X"``.

    Args:
        yaml_files: Iterable of YAML file paths under *config_dir*.
        existing_names: Collection of names already present. It is not
            modified.
        config_dir: Root of the config tree. Defaults to the module-level
            ``CONFIG_DIR``.

    Returns:
        The accepted file paths, in input order.
    """
    if config_dir is None:
        config_dir = CONFIG_DIR
    seen = set(existing_names)  # copy so the caller's collection is untouched
    new_configs = []
    for file_path in yaml_files:
        relative_path = os.path.relpath(file_path, config_dir)
        model_name = relative_path.split(os.sep)[0]
        dataset_name = os.path.splitext(os.path.basename(file_path))[0]
        if "v2_" in dataset_name:
            config_name = f"{model_name}_v2: {dataset_name.replace('v2_', '')}"
        else:
            config_name = f"{model_name}: {dataset_name}"
        if config_name in seen:
            continue
        seen.add(config_name)
        new_configs.append(file_path)
    return new_configs
# Find where new entries belong at the end of the "configurations" array.
def _locate_configurations_tail(content):
    """Return ``(insert_pos, needs_comma)`` for appending to "configurations".

    Scans from the array's opening ``[`` to its matching ``]`` while skipping
    string literals and ``//`` / ``/* */`` comments, so brackets inside
    strings, comments or nested objects are not mistaken for the array's end.
    *insert_pos* is just after the last significant character inside the
    array. *needs_comma* is True when that character is neither ``[`` (empty
    array) nor ``,`` (trailing comma already present). Returns None when the
    array cannot be found or is unterminated.
    """
    opener = re.search(r'"configurations"\s*:\s*\[', content)
    if opener is None:
        return None
    last_sig = opener.end() - 1  # index of the opening '['
    depth = 1
    in_string = False
    i = opener.end()
    n = len(content)
    while i < n:
        ch = content[i]
        if in_string:
            if ch == "\\":
                i += 2  # skip the escaped character
                continue
            if ch == '"':
                in_string = False
                last_sig = i
            i += 1
            continue
        if content.startswith("//", i):
            newline = content.find("\n", i)
            if newline == -1:
                return None
            i = newline
            continue
        if content.startswith("/*", i):
            close = content.find("*/", i + 2)
            if close == -1:
                return None
            i = close + 2
            continue
        if ch == '"':
            in_string = True
        elif ch in "[{":
            depth += 1
        elif ch in "]}":
            depth -= 1
            if depth == 0:
                return last_sig + 1, content[last_sig] not in "[,"
        if not ch.isspace():
            last_sig = i
        i += 1
    return None


# Splice new configuration objects into launch.json.
def update_launch_json(new_configs_string, launch_file=None):
    """Append *new_configs_string* to the end of the "configurations" array.

    A separating comma is inserted when the array already has entries and no
    trailing comma. The file is read and written as UTF-8.

    Args:
        new_configs_string: Comma-joined JSON object text to insert.
        launch_file: Path to update. Defaults to the module-level ``LAUNCH_FILE``.

    Returns:
        True if the file was rewritten. False if no "configurations" array
        could be located; in that case the file is left unchanged.
    """
    if launch_file is None:
        launch_file = LAUNCH_FILE
    with open(launch_file, "r", encoding="utf-8") as f:
        content = f.read()
    located = _locate_configurations_tail(content)
    if located is None:
        return False
    insert_pos, needs_comma = located
    # Without this comma the previous last entry and the first new one would
    # be adjacent objects, which is invalid JSON.
    separator = "," if needs_comma and new_configs_string.strip() else ""
    new_content = content[:insert_pos] + separator + new_configs_string + content[insert_pos:]
    with open(launch_file, "w", encoding="utf-8") as f:
        f.write(new_content)
    return True
# Entry point: sync launch.json with the YAML configs on disk.
def main():
    """Append a launch entry for every config YAML not yet named in launch.json.

    Prints one of three outcomes: nothing to add; how many entries were added
    plus the resulting total; or that launch.json could not be updated.
    """
    candidates = find_all_yaml_files(CONFIG_DIR)
    known_names = get_existing_config_names()
    to_add = generate_new_configs(candidates, known_names)

    if not to_add:
        print("No new configurations to add")
        return

    snippet = generate_launch_config_string(to_add)
    if not update_launch_json(snippet):
        print("Failed to update launch.json")
        return

    added = len(to_add)
    print(f"Added {added} new launch configurations")
    print(f"Total configurations: {len(known_names) + added}")
# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()