TrafficWheel/model/model_selector.py

58 lines
2.0 KiB
Python
Executable File

import os
import json
import importlib
import sys
from pathlib import Path
class ModelRegistry:
    """Registry mapping model names to lazily imported model entry points.

    On construction, every ``model_config.json`` file found (recursively)
    under ``model_dir`` is read. Each file holds a JSON list of entries (a
    single JSON object is also accepted as a one-entry list). Each entry
    must provide:

    * ``"name"``   -- the unique model name used for lookup,
    * ``"module"`` -- a module path passed to ``importlib.import_module``,
    * ``"entry"``  -- the attribute of that module to return.

    A model's module is imported only the first time the model is requested
    via ``get_model``; the resolved attribute is then cached in ``models``.
    """

    def __init__(self, model_dir=None):
        """Create the registry and load every config under *model_dir*.

        Args:
            model_dir: Directory to scan for ``model_config.json`` files.
                Defaults to the directory containing this source file.

        Raises:
            ValueError: If two config entries declare the same model name.
        """
        # Resolved entry points (class or function), filled lazily by _load_model().
        self.models = {}
        # Raw config entries keyed by model name.
        self.model_configs = {}
        self.model_dir = Path(__file__).parent if model_dir is None else Path(model_dir)
        self._load_model_configs()

    def _load_model_configs(self):
        """Read every ``model_config.json`` under ``self.model_dir`` into ``model_configs``.

        Raises:
            ValueError: If a model name appears in more than one entry.
        """
        # Sort so the scan order (and therefore which file is reported on a
        # name conflict) does not depend on filesystem enumeration order.
        for config_path in sorted(self.model_dir.rglob("model_config.json")):
            # JSON is UTF-8 by spec; do not rely on the platform's locale default.
            with open(config_path, "r", encoding="utf-8") as f:
                configs = json.load(f)
            # Accept a single JSON object as shorthand for a one-element list;
            # iterating a dict would otherwise walk its keys and crash.
            if isinstance(configs, dict):
                configs = [configs]
            for config in configs:
                model_name = config["name"]
                # Explicit raise rather than `assert`, so the check still runs under `python -O`.
                if model_name in self.model_configs:
                    raise ValueError(f"模型名冲突: {model_name} 已存在,冲突文件: {config_path}")
                self.model_configs[model_name] = config

    def _load_model(self, model_name):
        """Import the module for *model_name* and cache its entry attribute.

        Raises:
            ValueError: If *model_name* is not present in any loaded config.
            ImportError / AttributeError: Propagated from importing the
                configured module or looking up the configured entry.
        """
        if model_name not in self.model_configs:
            raise ValueError(f"模型 {model_name} 未注册")
        config = self.model_configs[model_name]
        module = importlib.import_module(config["module"])
        self.models[model_name] = getattr(module, config["entry"])

    def get_model(self, model_name):
        """Return the model class or function registered as *model_name*.

        The first call imports the configured module; later calls return the
        cached object.

        Raises:
            ValueError: If *model_name* is not registered.
        """
        if model_name not in self.models:
            self._load_model(model_name)
        return self.models[model_name]
# Build the shared registry once, at import time: this scans for every
# model_config.json under this file's directory before model_selector() runs.
model_registry: ModelRegistry = ModelRegistry()
def model_selector(config):
    """Instantiate the model named in ``config["basic"]["model"]``.

    The name is resolved through the module-level ``model_registry``. The
    resulting class/function is called with ``config["model"]`` as its only
    argument. The chosen model name and its ``module:entry`` location are
    then printed to stdout.

    Args:
        config: Nested mapping that must contain ``config["basic"]["model"]``
            (a registered model name) and ``config["model"]`` (passed as-is to
            the model constructor).

    Returns:
        Whatever the registered entry returns when called with
        ``config["model"]`` (presumably a model instance).
    """
    name = config["basic"]["model"]
    model_section = config["model"]
    factory = model_registry.get_model(name)
    instance = factory(model_section)
    print(f"选择的模型: {name}")
    spec = model_registry.model_configs[name]
    print(f"模型入口: {spec['module']}:{spec['entry']}")
    return instance