# agent_proj/model/factory.py
# (file-viewer metadata: 28 lines, 897 B, Python)

from abc import ABC, abstractmethod
from typing import Optional
from utils.config_handler import rag_conf
from langchain_core.embeddings import Embeddings
from langchain_community.chat_models.tongyi import BaseChatModel
from langchain_community.chat_models.tongyi import ChatTongyi
from langchain_community.embeddings import DashScopeEmbeddings
class BaseModelFactory(ABC):
    """Abstract interface for objects that build a LangChain model.

    Concrete subclasses implement :meth:`generator` to construct and return
    either an embeddings model or a chat model.
    """

    @abstractmethod
    def generator(self) -> Optional[Embeddings | BaseChatModel]:
        """Construct and return a model instance.

        Subclasses must override this method; the abstract version has no
        body beyond this docstring and therefore returns ``None``.
        """
class ChatModelFactory(BaseModelFactory):
    """Factory that builds the Tongyi chat model (``ChatTongyi``)."""

    def generator(self) -> BaseChatModel:
        """Build a new ``ChatTongyi`` chat model.

        The model name is read from ``rag_conf["chat_model_name"]`` on every
        call, and each call constructs a fresh ``ChatTongyi`` object.

        Returns:
            A ``ChatTongyi`` instance. This never returns ``None``, so the
            return type is narrowed from the base class's ``Optional`` (a
            covariant, backward-compatible override).

        Raises:
            KeyError: if ``rag_conf`` lacks a ``"chat_model_name"`` entry
                (assumes ``rag_conf`` behaves like a dict — TODO confirm).
        """
        return ChatTongyi(model=rag_conf["chat_model_name"])
class EmbeddingsFactory(BaseModelFactory):
    """Factory that builds the DashScope embeddings model."""

    def generator(self) -> Embeddings:
        """Build a new ``DashScopeEmbeddings`` model.

        The model name is read from ``rag_conf["embeddings_model_name"]`` on
        every call, and each call constructs a fresh ``DashScopeEmbeddings``.

        Returns:
            A ``DashScopeEmbeddings`` instance. This never returns ``None``, so
            the return type is narrowed from the base class's ``Optional`` (a
            covariant, backward-compatible override).

        Raises:
            KeyError: if ``rag_conf`` lacks an ``"embeddings_model_name"``
                entry (assumes ``rag_conf`` behaves like a dict — TODO confirm).
        """
        return DashScopeEmbeddings(model=rag_conf["embeddings_model_name"])
# Both models are built eagerly, once, when this module is first imported:
# the chat model first, then the embeddings model (tuple elements are
# evaluated left to right before unpacking).
chat_model, embedding_model = (
    ChatModelFactory().generator(),
    EmbeddingsFactory().generator(),
)