58 lines
2.0 KiB
Python
Executable File
58 lines
2.0 KiB
Python
Executable File
import os
|
|
import json
|
|
import importlib
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
class ModelRegistry:
    """Registry mapping model names to lazily imported model entry points.

    On construction, every ``model_config.json`` file found (recursively)
    under ``model_dir`` is read. Each file must contain a JSON list of
    objects; each object needs at least ``"name"`` (the registry key),
    ``"module"`` (a module path passed to ``importlib.import_module``) and
    ``"entry"`` (the attribute of that module to return). The module is only
    imported the first time the model is requested via :meth:`get_model`.

    Attributes:
        models: cache of already-resolved entry points, keyed by model name.
        model_configs: the raw config object for every registered model name.
        model_dir: root directory that was scanned for config files.
    """

    def __init__(self, model_dir=None):
        """Scan ``model_dir`` for model configs and register them.

        Args:
            model_dir: directory to search recursively for
                ``model_config.json`` files. Defaults to the directory that
                contains this source file (the original behaviour).

        Raises:
            ValueError: if two configs declare the same model name.
            TypeError: if a config file's top-level JSON value is not a list.
        """
        self.models = {}
        self.model_configs = {}
        # Which file first declared each model, so a conflict can name both.
        self._config_sources = {}
        self.model_dir = Path(model_dir) if model_dir is not None else Path(__file__).parent
        self._load_model_configs()

    def _load_model_configs(self):
        """Load every ``model_config.json`` under ``self.model_dir``."""
        # rglob order depends on the filesystem; sort so registration order and
        # any conflict error are reproducible across machines.
        for config_path in sorted(self.model_dir.rglob("model_config.json")):
            # Explicit UTF-8: the platform default encoding (e.g. GBK/cp1252 on
            # Windows) would fail on non-ASCII text in the config files.
            with open(config_path, "r", encoding="utf-8") as f:
                configs = json.load(f)

            if not isinstance(configs, list):
                raise TypeError(
                    f"{config_path}: expected a JSON list of model configs, "
                    f"got {type(configs).__name__}"
                )

            for config in configs:
                model_name = config["name"]
                # Explicit raise instead of `assert`: asserts are stripped under
                # `python -O`, which would let a later config silently overwrite
                # an earlier one with the same name.
                if model_name in self.model_configs:
                    first_source = self._config_sources.get(model_name)
                    raise ValueError(
                        f"模型名冲突: {model_name} 已存在,冲突文件: {config_path}"
                        f" (首次定义于: {first_source})"
                    )
                self.model_configs[model_name] = config
                self._config_sources[model_name] = config_path

    def _load_model(self, model_name):
        """Import and cache the entry point registered under ``model_name``.

        Raises:
            ValueError: if ``model_name`` was not found in any config file.
            ImportError: propagated if the configured module cannot be imported.
            AttributeError: propagated if the module lacks the configured entry.
        """
        if model_name not in self.model_configs:
            raise ValueError(f"模型 {model_name} 未注册")

        config = self.model_configs[model_name]
        module = importlib.import_module(config["module"])
        self.models[model_name] = getattr(module, config["entry"])

    def get_model(self, model_name):
        """Return the class or function registered as ``model_name``.

        The entry point is imported on the first request and served from the
        ``models`` cache afterwards.

        Raises:
            ValueError: if ``model_name`` is not registered.
        """
        if model_name not in self.models:
            self._load_model(model_name)
        return self.models[model_name]
|
|
|
|
# Initialize the shared model registry used by model_selector() below.
# This runs at import time: constructing ModelRegistry scans for
# model_config.json files (see ModelRegistry.__init__ / _load_model_configs),
# so config problems surface as soon as this module is imported.
model_registry: ModelRegistry = ModelRegistry()
|
|
|
|
def model_selector(config):
    """Build the model named in ``config["basic"]["model"]``.

    Resolves the registered entry point for that name through the
    module-level ``model_registry``, calls it with ``config["model"]``,
    prints which model and entry point were chosen, and returns the result.

    Args:
        config: mapping with at least ``config["basic"]["model"]`` (the
            registered model name) and ``config["model"]`` (passed as the
            single positional argument to the model's entry point).

    Returns:
        Whatever the registered entry point returns for ``config["model"]``.

    Raises:
        KeyError: if ``config`` lacks the required keys.
        ValueError: if the model name is not registered (raised by the registry).
    """
    selected = config["basic"]["model"]
    ctor_args = config["model"]

    factory = model_registry.get_model(selected)
    instance = factory(ctor_args)

    print(f"选择的模型: {selected}")
    entry_info = model_registry.model_configs[selected]
    print(f"模型入口: {entry_info['module']}:{entry_info['entry']}")
    return instance
|