agent_proj/rag/rag_service.py

48 lines
1.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

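"""RAG summarization service.

Flow: the user asks a question -> reference material is retrieved -> the question
plus the references are given to the chat model -> the model summarizes a reply.
(rag总结服务：用户提问，搜索参考资料，将提问+参考资料提供给模型，让模型总结回复)
"""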
from langchain_core.output_parsers import StrOutputParser
from rag.vector_store import VectorStoreSerivce
from utils.prompt_loader import load_rag_prompts
from langchain_core.prompts import PromptTemplate
from model.factory import chat_model
from langchain_community.docstore.document import Document
from typing import List
class RagSummarizeService:
    """Answer a question by retrieving reference documents and letting the chat model summarize them.

    Pipeline: retriever -> numbered context string -> prompt template -> chat model
    -> plain string (via ``StrOutputParser``).
    """

    def __init__(self):
        """Create the vector-store retriever, compile the RAG prompt and build the chain."""
        self.vector_store = VectorStoreSerivce()
        self.retriever = self.vector_store.get_retriever()
        # ``prompt_template`` holds the raw template text returned by the loader;
        # ``prompt_text`` holds the compiled PromptTemplate. The names are kept as-is
        # because other code may read these attributes.
        self.prompt_template = load_rag_prompts()
        self.prompt_text = PromptTemplate.from_template(self.prompt_template)
        self.model = chat_model
        self.chain = self._init_chain()

    def _init_chain(self):
        """Compose prompt -> chat model -> string output parser into a single runnable."""
        return self.prompt_text | self.model | StrOutputParser()

    def retriever_docs(self, query: str) -> List[Document]:
        """Return the documents the retriever considers relevant to ``query``."""
        return self.retriever.invoke(query)

    @staticmethod
    def _format_context(docs: List[Document]) -> str:
        """Render retrieved documents as numbered reference lines (1-based) for the prompt.

        Each line reads "[Reference N]: content | source: metadata" (the labels
        themselves are Chinese). An empty ``docs`` yields an empty string.
        """
        # enumerate(start=1) replaces the manual counter, and join avoids repeated
        # string concatenation; the produced text is identical to the old loop's.
        return "".join(
            f"[参考资料{index}]: 参考资料:{doc.page_content} | 参考源: {doc.metadata} \n"
            for index, doc in enumerate(docs, start=1)
        )

    def rag_summarize(self, query: str) -> str:
        """Retrieve references for ``query`` and return the model's summarized answer.

        The prompt is invoked with the keys ``input`` (the raw query) and ``context``
        (the formatted references); presumably the loaded template references both
        variables — verify against the prompt file.
        """
        context_docs = self.retriever_docs(query)
        return self.chain.invoke(
            {
                "input": query,
                "context": self._format_context(context_docs),
            }
        )
if __name__ == '__main__':
    # Manual smoke test: builds the real vector store and calls the chat model.
    rag = RagSummarizeService()
    # Query: "Which robot vacuums are suitable for a small apartment?"
    # ("哪些" = "which"; the previous text had the typo "那些" = "those").
    response = rag.rag_summarize("查询小户型适合哪些扫地机器人")
    print(response)