48 lines
1.6 KiB
Python
48 lines
1.6 KiB
Python
"""rag总结服务:用户提问,搜索参考资料,提问+参考资料提供模型,让模型总结回复"""
|
||
from langchain_core.output_parsers import StrOutputParser
|
||
|
||
from rag.vector_store import VectorStoreSerivce
|
||
from utils.prompt_loader import load_rag_prompts
|
||
from langchain_core.prompts import PromptTemplate
|
||
from model.factory import chat_model
|
||
from langchain_community.docstore.document import Document
|
||
from typing import List
|
||
|
||
class RagSummarizeService:
    """RAG summarisation service.

    Retrieves reference documents for a user query, renders them into the
    prompt together with the query, and returns the chat model's reply as a
    plain string.
    """

    def __init__(self):
        # Retriever obtained from the project's vector store wrapper.
        self.vector_store = VectorStoreSerivce()
        self.retriever = self.vector_store.get_retriever()

        # Raw prompt text returned by the loader ...
        self.prompt_template = load_rag_prompts()
        # ... compiled into a LangChain PromptTemplate. Assumes the text declares
        # `{input}` and `{context}` placeholders (the keys rag_summarize passes)
        # — TODO confirm against the prompt file.
        self.prompt_text = PromptTemplate.from_template(self.prompt_template)

        # Chat model instance imported from model.factory.
        self.model = chat_model
        self.chain = self._init_chain()

    def _init_chain(self):
        """Compose prompt -> chat model -> string output parser into one runnable."""
        return self.prompt_text | self.model | StrOutputParser()

    def retriever_docs(self, query: str) -> "List[Document]":
        """Return the documents the retriever finds for *query*."""
        return self.retriever.invoke(query)

    @staticmethod
    def _format_context(docs: "List[Document]") -> str:
        """Render retrieved documents as numbered reference lines for the prompt.

        Each document becomes one newline-terminated line, numbered from 1,
        containing its ``page_content`` and ``metadata``. An empty sequence
        yields ``""``.
        """
        # join + enumerate instead of a manual counter and repeated `+=`.
        return "".join(
            f"[参考资料{index}]: 参考资料:{doc.page_content} | 参考源: {doc.metadata} \n"
            for index, doc in enumerate(docs, start=1)
        )

    def rag_summarize(self, query: str) -> str:
        """Answer *query* using retrieved reference material.

        Args:
            query: The user's question. It is used for retrieval and is also
                passed as the prompt's ``input`` variable.

        Returns:
            The model's reply, converted to a string by ``StrOutputParser``.
        """
        # Go through retriever_docs so subclasses overriding it still apply.
        context = self._format_context(self.retriever_docs(query))
        return self.chain.invoke({"input": query, "context": context})
|
||
|
||
|
||
if __name__ == '__main__':
    # Manual smoke test: build the service (vector store + chat model) and ask
    # one sample question ("Which robot vacuums suit a small apartment?").
    rag = RagSummarizeService()
    response = rag.rag_summarize("查询小户型适合哪些扫地机器人")
    print(response)
|