"""RAG service class with conversation-history memory."""
from operator import itemgetter

from langchain_community.chat_models import ChatTongyi
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough, RunnableWithMessageHistory, RunnableLambda

from file_hisroty_store import get_history
import config_data as config
from vector_stores import VectorStoreService


class RAGService:
    """Retrieval-augmented chat service whose chain keeps per-session history.

    Use ``self.chain``, a ``RunnableWithMessageHistory``. Invoke or stream it
    with a dict holding an ``"input"`` key, plus a config whose
    ``configurable.session_id`` selects the stored history.
    """

    def __init__(self):
        # Vector store used for retrieval; the embedding model name comes from config.
        self.vector_service = VectorStoreService(
            embedding=DashScopeEmbeddings(model=config.embedding_model_name)
        )
        # Prompt variables: {context} (retrieved docs), history (prior messages),
        # and {input} (the current user question).
        self.prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", "以我提供的已知参考资料为主,简洁专业回答用户问题,参考资料: {context}"),
                ("system", "用户的对话历史记录,如下"),
                MessagesPlaceholder("history"),
                ("user", "请回答用户提问: {input}"),
            ]
        )
        self.chat_model = ChatTongyi(model=config.chat_model_name)
        self.chain = self.__get_chain()

    def format_document(self, docs: list[Document]) -> str:
        """Render retrieved documents into one context string for the prompt.

        Args:
            docs: Documents returned by the retriever.

        Returns:
            Each document's content and metadata, concatenated. Returns the
            placeholder "无相关参考资料" ("no relevant reference material")
            when ``docs`` is empty.
        """
        if not docs:
            return "无相关参考资料"
        return "".join(
            f"文档片段:{doc.page_content} \n元数据 {doc.metadata} \n\n" for doc in docs
        )

    def __get_chain(self):
        """Build the final execution chain, wrapped with message history.

        Data flow:
            {"input", "history"}            # history injected by RunnableWithMessageHistory
            -> + {"context"}                # retriever run on the current question only
            -> prompt -> chat model -> plain string
        """
        retriever = self.vector_service.get_retriever()

        # assign() passes the incoming dict through unchanged and adds "context",
        # so the prompt receives input, history and context in a single dict.
        chain = (
            RunnablePassthrough.assign(
                context=itemgetter("input") | retriever | self.format_document
            )
            | self.prompt_template
            | self.chat_model
            | StrOutputParser()
        )

        # History augmentation: get_history is called with the session_id from
        # the run config. "input" is recorded as the user turn, and prior turns
        # are injected under "history".
        return RunnableWithMessageHistory(
            chain,
            get_history,
            input_messages_key="input",
            history_messages_key="history",
        )


if __name__ == '__main__':
    # session_id selects which stored conversation history the chain uses.
    session_config = {
        "configurable": {
            "session_id": "user_001",
        }
    }
    # The chain input must be a dict with an "input" key.
    res = RAGService().chain.stream({"input": "春天穿什么颜色的衣服"}, session_config)
    for chunk in res:
        print(chunk, end="", flush=True)
    # Finish the streamed line with a newline.
    print()