134 lines
4.2 KiB
Python
134 lines
4.2 KiB
Python
import os
|
||
import re
|
||
|
||
# Absolute, machine-specific locations used by this script.
_PROJECT_ROOT = "/user/czzhangheng/code/TrafficWheel"

# Directory scanned recursively for *.yaml experiment config files.
CONFIG_DIR = f"{_PROJECT_ROOT}/config"

# VS Code launch file that receives the generated debug configurations.
LAUNCH_FILE = f"{_PROJECT_ROOT}/.vscode/launch.json"
|
||
|
||
# Recursively collect every YAML config file below a directory.
def find_all_yaml_files(directory, exclude_prefixes=("BJTaxi",)):
    """Return the paths of all ``.yaml`` files below *directory*, in a stable order.

    The tree is walked recursively with ``os.walk``. Files whose basename starts
    with any of *exclude_prefixes* are skipped (by default ``"BJTaxi"``, matching
    the historical behaviour of this script).

    The walk is made deterministic: sub-directories are descended into in sorted
    order and the files of each directory are emitted sorted. ``os.walk`` would
    otherwise yield entries in arbitrary filesystem order, which makes the order
    of the generated launch entries vary between runs and machines.

    Args:
        directory: Root directory to scan.
        exclude_prefixes: A prefix string or an iterable of prefix strings; files
            whose basename starts with one of them are ignored. Pass ``()`` to
            exclude nothing.

    Returns:
        A list of paths (``os.path.join(root, name)``) of the matching files.
    """
    # str.startswith needs a str or a tuple; a bare string must not be split
    # into its characters by tuple().
    if isinstance(exclude_prefixes, str):
        exclude_prefixes = (exclude_prefixes,)
    else:
        exclude_prefixes = tuple(exclude_prefixes)

    yaml_files = []
    for root, dirs, files in os.walk(directory):
        # Sorting in place controls the order in which os.walk descends (top-down walk).
        dirs.sort()
        for file in sorted(files):
            if file.endswith(".yaml") and not file.startswith(exclude_prefixes):
                yaml_files.append(os.path.join(root, file))
    return yaml_files
|
||
|
||
# Build the JSON text of VS Code launch configurations for a list of YAML files.
def generate_launch_config_string(yaml_files, config_dir=None):
    """Render one debugpy launch configuration per YAML file as JSON text.

    For each file, the model name is the first path component relative to
    *config_dir* and the dataset name is the file's basename without extension.
    A dataset name containing ``"v2_"`` is displayed under ``<model>_v2`` with
    every ``"v2_"`` removed from the dataset part. The configuration name is
    ``"<model display>: <dataset display>"``.

    The ``--config`` argument uses the file's full path relative to *config_dir*
    (with ``/`` separators), prefixed with ``./config/``. This keeps it correct
    for files nested more than one directory deep and for files directly inside
    *config_dir*. The ``./config/`` prefix assumes the debugger's working
    directory is the project root (as in the original script).

    Args:
        yaml_files: Paths of YAML files located under *config_dir*.
        config_dir: Root config directory; defaults to the module's ``CONFIG_DIR``.

    Returns:
        The serialized configurations. Each one is preceded by a newline and
        indented by 8 spaces; they are joined with ``","`` and there is no leading
        comma. Returns ``""`` for an empty input.
    """
    import json
    import textwrap

    if config_dir is None:
        config_dir = CONFIG_DIR

    config_strings = []
    for file_path in yaml_files:
        relative_path = os.path.relpath(file_path, config_dir)
        path_parts = relative_path.split(os.sep)
        model_name = path_parts[0]
        dataset_name = os.path.splitext(os.path.basename(file_path))[0]

        # "v2_" configs are grouped under a separate "<model>_v2" display name.
        if "v2_" in dataset_name:
            model_display_name = f"{model_name}_v2"
            dataset_display_name = dataset_name.replace("v2_", "")
        else:
            model_display_name = model_name
            dataset_display_name = dataset_name

        config = {
            "name": f"{model_display_name}: {dataset_display_name}",
            "type": "debugpy",
            "request": "launch",
            "program": "run.py",
            "console": "integratedTerminal",
            "args": f"--config ./config/{'/'.join(path_parts)}",
        }
        # json.dumps escapes quotes/backslashes in names and paths; ensure_ascii=False
        # keeps non-ASCII names readable in the launch file.
        body = json.dumps(config, indent=4, ensure_ascii=False)
        config_strings.append("\n" + textwrap.indent(body, " " * 8))

    return ",".join(config_strings)
|
||
|
||
# Read the launch file and collect the configuration names it already contains.
def get_existing_config_names(launch_file=None):
    """Return the set of all ``"name"`` string values found in the launch file.

    The file is scanned as text with a regular expression instead of being
    parsed as JSON (presumably because launch.json may contain comments or
    trailing commas that ``json.load`` rejects). Consequently every
    ``"name": "..."`` pair is matched wherever it appears, including inside
    commented-out entries or other sections of the file.

    JSON escape sequences inside a name (e.g. ``\\"``) are decoded. A literal
    that cannot be decoded is kept verbatim without its surrounding quotes.

    Args:
        launch_file: Path to launch.json; defaults to the module's ``LAUNCH_FILE``.

    Returns:
        A set of name strings.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
    """
    import json

    if launch_file is None:
        launch_file = LAUNCH_FILE

    # Explicit UTF-8: the file may contain non-ASCII text, and the locale
    # default encoding is not guaranteed to handle it.
    with open(launch_file, "r", encoding="utf-8") as f:
        content = f.read()

    # Capture the whole quoted literal, allowing backslash escapes such as \".
    name_pattern = re.compile(r'"name"\s*:\s*("(?:[^"\\]|\\.)*")')

    names = set()
    for literal in name_pattern.findall(content):
        try:
            names.add(json.loads(literal))
        except ValueError:
            names.add(literal[1:-1])
    return names
|
||
|
||
# Select the YAML files that do not yet have a launch configuration.
def generate_new_configs(yaml_files, existing_names, config_dir=None):
    """Return the subset of *yaml_files* whose configuration name is not yet taken.

    The configuration name is derived exactly as in the launch-entry generator:
    ``"<model display>: <dataset display>"``. The model is the first path
    component relative to *config_dir*; a dataset name containing ``"v2_"`` is
    shown under ``<model>_v2`` with ``"v2_"`` removed.

    Names are also de-duplicated within this call. If two files map to the same
    name (e.g. ``STGCN/a/PEMS04.yaml`` and ``STGCN/b/PEMS04.yaml``), only the first
    one is selected, so the launch file never receives duplicate names.

    Args:
        yaml_files: Paths of YAML files located under *config_dir*.
        existing_names: Iterable of configuration names already present.
            It is not modified.
        config_dir: Root config directory; defaults to the module's ``CONFIG_DIR``.

    Returns:
        A list of the selected file paths, in input order.
    """
    if config_dir is None:
        config_dir = CONFIG_DIR

    # Work on a copy so names selected here do not leak into the caller's set.
    taken_names = set(existing_names)
    new_configs = []

    for file_path in yaml_files:
        relative_path = os.path.relpath(file_path, config_dir)
        model_name = relative_path.split(os.sep)[0]
        dataset_name = os.path.splitext(os.path.basename(file_path))[0]

        # "v2_" configs are grouped under a separate "<model>_v2" display name.
        if "v2_" in dataset_name:
            model_display_name = f"{model_name}_v2"
            dataset_display_name = dataset_name.replace("v2_", "")
        else:
            model_display_name = model_name
            dataset_display_name = dataset_name

        config_name = f"{model_display_name}: {dataset_display_name}"

        if config_name not in taken_names:
            taken_names.add(config_name)
            new_configs.append(file_path)

    return new_configs
|
||
|
||
# Locate the insertion point at the end of the "configurations" array.
def _find_configurations_insert_point(content):
    """Find where new entries belong inside the ``"configurations"`` array.

    Scans from the ``"configurations": [`` opener to its matching ``]``. String
    literals (with backslash escapes) and ``//`` / ``/* */`` comments are skipped,
    so brackets inside them are ignored.

    Returns:
        ``(insert_pos, needs_comma)``, or ``None`` if no ``"configurations"``
        array is found or its closing bracket is missing.

        - ``insert_pos`` is the index just past the last significant character
          inside the array: the opening ``[``, a trailing ``,``, or the end of
          the last element.
        - ``needs_comma`` is True when that character ends an element, so a
          separating comma must be written before new entries.
    """
    opener = re.search(r'"configurations"\s*:\s*\[', content)
    if opener is None:
        return None

    length = len(content)
    pos = opener.end()
    depth = 1                   # we start just inside the configurations array
    last_end = pos              # index just past the last significant character
    last_char = "["

    while pos < length:
        ch = content[pos]
        if ch == '"':
            # Skip a string literal, stepping over escaped characters.
            pos += 1
            while pos < length and content[pos] != '"':
                pos += 2 if content[pos] == "\\" else 1
            pos += 1
            last_end, last_char = pos, '"'
            continue
        if content.startswith("//", pos):
            newline = content.find("\n", pos)
            pos = length if newline == -1 else newline + 1
            continue
        if content.startswith("/*", pos):
            close = content.find("*/", pos + 2)
            pos = length if close == -1 else close + 2
            continue
        if ch in "[{":
            depth += 1
        elif ch in "]}":
            depth -= 1
            if depth == 0:
                return last_end, last_char not in "[,"
        if not ch.isspace():
            last_end, last_char = pos + 1, ch
        pos += 1
    return None


# Append new configuration objects to the "configurations" array of launch.json.
def update_launch_json(new_configs_string, launch_file=None):
    """Insert *new_configs_string* at the end of the launch file's configurations array.

    The file is edited as text rather than re-serialized, so existing comments
    and formatting are preserved. The array's end is found by bracket matching
    rather than by searching for the first ``]}``, because such a sequence can
    also occur inside an entry (e.g. an ``args`` list) or a string. A separating
    comma is written first when the array already ends with an element.

    Args:
        new_configs_string: One or more JSON objects joined by commas, as
            produced by ``generate_launch_config_string``.
        launch_file: Path to launch.json; defaults to the module's ``LAUNCH_FILE``.

    Returns:
        True if the file was updated (or there was nothing to insert). False if
        the configurations array could not be located, in which case the file is
        left unchanged.

    Raises:
        OSError: If the file cannot be read or written.
    """
    if launch_file is None:
        launch_file = LAUNCH_FILE

    with open(launch_file, "r", encoding="utf-8") as f:
        content = f.read()

    location = _find_configurations_insert_point(content)
    if location is None:
        return False
    insert_pos, needs_comma = location

    if not new_configs_string.strip():
        # Nothing to add; writing a lone comma would corrupt the file.
        return True

    if needs_comma and not new_configs_string.lstrip().startswith(","):
        new_configs_string = "," + new_configs_string

    new_content = content[:insert_pos] + new_configs_string + content[insert_pos:]

    with open(launch_file, "w", encoding="utf-8") as f:
        f.write(new_content)

    return True
|
||
|
||
# Script entry point: add launch entries for YAML configs that lack one.
def main():
    """Append launch configurations for YAML configs not yet listed in the launch file.

    Steps:
      1. Collect the YAML files under ``CONFIG_DIR``.
      2. Read the configuration names already present in the launch file.
      3. Keep only the files whose configuration name is new.
      4. Append entries for those files to the launch file.
      5. Print a short summary, or a message when there is nothing to do or
         the update fails.
    """
    yaml_files = find_all_yaml_files(CONFIG_DIR)
    existing_names = get_existing_config_names()
    pending_files = generate_new_configs(yaml_files, existing_names)

    # Guard clause: nothing new to write.
    if not pending_files:
        print("No new configurations to add")
        return

    snippet = generate_launch_config_string(pending_files)

    if not update_launch_json(snippet):
        print("Failed to update launch.json")
        return

    added_count = len(pending_files)
    print(f"Added {added_count} new launch configurations")
    print(f"Total configurations: {len(existing_names) + added_count}")


if __name__ == "__main__":
    main()