import os
import json
import importlib
import sys
from pathlib import Path


class ModelRegistry:
    """Registry of models declared in ``model_config.json`` files.

    Every ``model_config.json`` found recursively under ``model_dir`` is
    parsed when the registry is constructed. A file holds a JSON list of
    config objects (a single JSON object is also accepted). Each config must
    provide ``"name"``; ``"module"`` and ``"entry"`` are used to import the
    model lazily the first time it is requested.
    """

    def __init__(self, model_dir=None):
        """Build the registry and eagerly read all model configs.

        Args:
            model_dir: Directory to scan for ``model_config.json`` files.
                Defaults to the directory containing this source file.

        Raises:
            ValueError: if two configs declare the same model name.
        """
        self.models = {}  # model name -> imported class/function (lazy cache)
        self.model_configs = {}  # model name -> raw config dict from JSON
        self.model_dir = (
            Path(__file__).parent if model_dir is None else Path(model_dir)
        )
        self._load_model_configs()

    def _load_model_configs(self):
        """Read every model_config.json under ``self.model_dir``."""
        # Sorted so the scan order (and thus which file is reported on a
        # name conflict) does not depend on filesystem iteration order.
        for config_path in sorted(self.model_dir.rglob("model_config.json")):
            # JSON is UTF-8 by spec; don't rely on the platform default encoding.
            with open(config_path, "r", encoding="utf-8") as f:
                configs = json.load(f)
            # Allow a file to declare a single model as a bare JSON object.
            if isinstance(configs, dict):
                configs = [configs]
            for config in configs:
                model_name = config["name"]
                # Explicit raise instead of `assert`: asserts are stripped
                # under `python -O`, which would let duplicates overwrite.
                if model_name in self.model_configs:
                    raise ValueError(
                        f"模型名冲突: {model_name} 已存在,冲突文件: {config_path}"
                    )
                self.model_configs[model_name] = config

    def _load_model(self, model_name):
        """Import the entry point registered as ``model_name`` and cache it.

        Raises:
            ValueError: if ``model_name`` has no registered config.
        """
        if model_name not in self.model_configs:
            raise ValueError(f"模型 {model_name} 未注册")
        config = self.model_configs[model_name]
        module = importlib.import_module(config["module"])
        self.models[model_name] = getattr(module, config["entry"])

    def get_model(self, model_name):
        """Return the model class or function registered as ``model_name``.

        The entry point is imported on first use and cached afterwards.

        Raises:
            ValueError: if ``model_name`` has no registered config.
        """
        if model_name not in self.models:
            self._load_model(model_name)
        return self.models[model_name]


# Module-level registry; scans this file's directory at import time.
model_registry = ModelRegistry()


def model_selector(config):
    """Instantiate the model selected by ``config``.

    Args:
        config: Mapping where ``config["basic"]["model"]`` names a registered
            model and ``config["model"]`` is passed as the sole argument to
            that model's entry point.

    Returns:
        The result of calling the registered entry with ``config["model"]``.
    """
    model_name = config["basic"]["model"]
    model_config = config["model"]
    model_cls = model_registry.get_model(model_name)
    model = model_cls(model_config)
    entry_config = model_registry.model_configs[model_name]
    print(f"选择的模型: {model_name}")
    print(f"模型入口: {entry_config['module']}:{entry_config['entry']}")
    return model