"""RAG (retrieval-augmented generation) service with per-session conversation history."""
# Third-party: LangChain integrations and core runnables.
from langchain_community.chat_models import ChatTongyi
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import (
    RunnableLambda,
    RunnablePassthrough,
    RunnableWithMessageHistory,
)

# Local project modules.
# NOTE(review): "hisroty" looks like a typo for "history"; kept as-is because it
# must match the actual module filename — confirm and rename both together.
from file_hisroty_store import get_history
import config_data as config
from vector_stores import VectorStoreService
class RAGService(object):
    """Retrieval-augmented generation service with per-session chat history.

    Wires a vector-store retriever, a chat prompt, and a Tongyi chat model
    into a single LCEL chain wrapped with persisted message history.
    """

    def __init__(self):
        # Vector store backed by DashScope embeddings (model name from config).
        self.vector_service = VectorStoreService(
            embedding=DashScopeEmbeddings(model=config.embedding_model_name)
        )
        # Prompt: retrieved context + prior conversation + the current question.
        self.prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", "以我提供的已知参考资料为主,简介专业回答用户问题,参考资料: {context}"),
                ("system", "用户的对话历史记录,如下"),
                MessagesPlaceholder("history"),
                ("user", "请回答用户提问: {input}"),
            ]
        )
        self.chat_model = ChatTongyi(model=config.chat_model_name)
        # Final executable chain (history-aware); built once at construction.
        self.chain = self.__get_chain()

    def format_document(self, docs: "list[Document]") -> str:
        """Render retrieved documents into a single context string.

        Each document contributes its page content and metadata; returns a
        fixed fallback notice when no documents were retrieved.
        """
        if not docs:
            return "无相关参考资料"
        # str.join avoids quadratic string concatenation over many documents.
        return "".join(
            f"文档片段:{doc.page_content} \n元数据 {doc.metadata} \n\n"
            for doc in docs
        )

    def __get_chain(self):
        """Build the full RAG chain wrapped with message-history support."""
        retriever = self.vector_service.get_retriever()

        def format_for_retriever(value):
            # The retriever only needs the raw question text.
            return value["input"]

        def format_for_prompt_template(value):
            # "input" here is the original payload passed through unchanged by
            # RunnablePassthrough, so question and history sit one level deeper.
            return {
                "input": value["input"]["input"],
                "context": value["context"],
                "history": value["input"]["history"],
            }

        chain = (
            {
                "input": RunnablePassthrough(),
                "context": RunnableLambda(format_for_retriever)
                | retriever
                | self.format_document,
            }
            | RunnableLambda(format_for_prompt_template)
            | self.prompt_template
            | self.chat_model
            | StrOutputParser()
        )

        # Wrap so each session_id gets its own persisted message history.
        conversation_chain = RunnableWithMessageHistory(
            chain,
            get_history,
            input_messages_key="input",
            history_messages_key="history",
        )

        return conversation_chain
if __name__ == '__main__':
    # Per-session configuration: session_id selects which stored message
    # history the chain reads from and appends to.
    session_config = {
        "configurable": {
            "session_id": "user_001",
        }
    }
    # The chain expects a dict payload with an "input" key; stream the
    # model's answer token-by-token to stdout.
    service = RAGService()
    token_stream = service.chain.stream({"input": "春天穿什么颜色的衣服"}, session_config)
    for token in token_stream:
        print(token, end="", flush=True)