"""RAG summarization service.

The user asks a question. Reference material is retrieved for it, and the
question plus the retrieved material are handed to the chat model, which
summarizes an answer.
"""
from langchain_core.output_parsers import StrOutputParser
from rag.vector_store import VectorStoreSerivce
from utils.prompt_loader import load_rag_prompts
from langchain_core.prompts import PromptTemplate
from model.factory import chat_model
from langchain_community.docstore.document import Document
from typing import List
from utils.logger_handler import logger


class RagSummarizeService:
    """Answer a query by retrieving reference documents and asking the chat
    model to summarize a reply from them."""

    def __init__(self):
        self.vector_store = VectorStoreSerivce()
        self.retriever = self.vector_store.get_retriever()
        # Raw prompt text. ``rag_summarize`` fills it with the keys "input"
        # and "context", so the template is expected to reference those
        # placeholders.
        self.prompt_template = load_rag_prompts()
        self.prompt_text = PromptTemplate.from_template(self.prompt_template)
        self.model = chat_model
        self.chain = self._init_chain()

    def _init_chain(self):
        """Build the prompt -> chat model -> string-output pipeline."""
        chain = self.prompt_text | self.model | StrOutputParser()
        return chain

    def retriever_docs(self, query: str) -> List[Document]:
        """Return the documents the retriever recalls for *query*."""
        return self.retriever.invoke(query)

    @staticmethod
    def _format_context(docs: List[Document]) -> str:
        """Render retrieved documents into one numbered context string.

        Each document becomes one line, numbered from 1, holding its page
        content and its metadata. An empty ``docs`` yields ``""``.
        """
        # str.join avoids the quadratic cost of repeated ``+=`` concatenation.
        return "".join(
            f"[参考资料{index}]: 参考资料:{doc.page_content} | 参考源: {doc.metadata} \n"
            for index, doc in enumerate(docs, start=1)
        )

    def rag_summarize(self, query: str) -> str:
        """Retrieve reference material for *query* and return the model's summary.

        :param query: the user's question.
        :return: the chain output, parsed to a string by ``StrOutputParser``.
        """
        context_docs = self.retriever_docs(query)
        logger.info(f"[rag_summarize]: 召回了{len(context_docs)}条参考资料")
        context = self._format_context(context_docs)
        return self.chain.invoke(
            {
                "input": query,
                "context": context,
            }
        )


if __name__ == '__main__':
    rag = RagSummarizeService()
    # NOTE(review): "那些" in the demo query is likely a typo for "哪些" ("which").
    response = rag.rag_summarize("查询小户型适合那些扫地机器人")
    print(response)