# RAG_proj/rag.py
"""
rag service类
带历史记忆
"""
from langchain_community.chat_models import ChatTongyi
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough, RunnableWithMessageHistory, RunnableLambda
from file_hisroty_store import get_history
import config_data as config
from vector_stores import VectorStoreService
class RAGService(object):
    """RAG service backed by a vector-store retriever, with per-session chat history.

    Exposes ``self.chain``, a RunnableWithMessageHistory: invoke/stream it with
    ``{"input": <question>}`` plus a config carrying ``configurable.session_id``.
    """

    def __init__(self):
        # Model names come from the shared config module.
        self.vector_service = VectorStoreService(
            embedding=DashScopeEmbeddings(model=config.embedding_model_name)
        )
        self.prompt_template = ChatPromptTemplate.from_messages(
            [
                # NOTE(review): fixed typo 简介 -> 简洁 ("concise") in the prompt text.
                ("system", "以我提供的已知参考资料为主,简洁专业回答用户问题,参考资料: {context}"),
                ("system", "用户的对话历史记录,如下"),
                MessagesPlaceholder("history"),
                ("user", "请回答用户提问: {input}"),
            ]
        )
        self.chat_model = ChatTongyi(model=config.chat_model_name)
        self.chain = self.__get_chain()

    def format_document(self, docs: list[Document]) -> str:
        """Render retrieved documents into one context string for the prompt.

        Returns a fallback message when nothing was retrieved. Uses
        ``str.join`` instead of ``+=`` to avoid quadratic concatenation.
        """
        if not docs:
            return "无相关参考资料"
        return "".join(
            f"文档片段:{doc.page_content} \n元数据 {doc.metadata} \n\n" for doc in docs
        )

    def __get_chain(self):
        """Build the final runnable chain, wrapped with message-history support."""
        retriever = self.vector_service.get_retriever()

        def format_for_retriever(value: dict) -> str:
            # The retriever expects the raw query string, not the whole input dict.
            return value["input"]

        def format_for_prompt_template(value: dict) -> dict:
            # Flatten {"input": {"input": ..., "history": ...}, "context": ...}
            # into the flat keys the prompt template expects.
            return {
                "input": value["input"]["input"],
                "context": value["context"],
                "history": value["input"]["history"],
            }

        chain = (
            {
                # Pass the full {"input", "history"} dict through; the retriever
                # branch extracts just the query text before searching.
                "input": RunnablePassthrough(),
                "context": RunnableLambda(format_for_retriever)
                | retriever
                | self.format_document,
            }
            | RunnableLambda(format_for_prompt_template)
            | self.prompt_template
            | self.chat_model
            | StrOutputParser()
        )
        # Wrap the chain so each session_id gets its own persisted history.
        return RunnableWithMessageHistory(
            chain,
            get_history,
            input_messages_key="input",
            history_messages_key="history",
        )
if __name__ == '__main__':
    # History is keyed by session_id in the runnable config.
    session_config = {"configurable": {"session_id": "user_001"}}
    # The chain expects a dict input; stream() yields answer chunks incrementally.
    service = RAGService()
    for chunk in service.chain.stream({"input": "春天穿什么颜色的衣服"}, session_config):
        print(chunk, end="", flush=True)