RAG_proj/app_qa.py

56 lines
1.4 KiB
Python

import itertools
import time

import streamlit as st

import config_data as config
from rag import RAGService
# Page header
st.title("智能客服")
st.divider()

# Chat transcript shown in the UI. It lives in session_state so it survives
# Streamlit's top-to-bottom rerun on every interaction; it is seeded once with
# an assistant greeting when the session has no transcript yet.
if "message" not in st.session_state:
    st.session_state["message"] = [
        {"role": "assistant", "content": "你好有什么可以帮您"},
    ]
msg_list = st.session_state["message"]
def add_history(role, content):
    """Append one chat turn to the session transcript ``msg_list``.

    Args:
        role: Speaker label; replayed later via ``st.chat_message(role)``
            (this file uses "user" and "assistant").
        content: The turn's text.
    """
    entry = {"role": role, "content": content}
    msg_list.append(entry)
def show_history():
    """Render every stored chat turn in ``msg_list``, oldest first, as a chat bubble."""
    for entry in msg_list:
        bubble = st.chat_message(entry["role"])
        bubble.write(entry["content"])
# Replay the stored transcript on every rerun.
show_history()

# Build the RAG backend only when this Streamlit session has none yet, then
# reuse the stored instance on later reruns.
if "RAGService" not in st.session_state:
    st.session_state["RAGService"] = RAGService()
rag_service = st.session_state["RAGService"]
def get_answer(prompt : str):
    """Return the streamed reply of the RAG chain for *prompt*.

    Calls ``rag_service.chain.stream`` with ``{"input": prompt}`` and the
    module-level ``config.session_config`` and hands back the result unchanged;
    the caller iterates it chunk by chunk.
    """
    return rag_service.chain.stream({"input": prompt}, config.session_config)
def capture(generator, cache_list):
    """Yield each chunk of *generator* unchanged, also appending it to *cache_list*.

    Lets ``st.write_stream`` render the reply incrementally while a copy of
    every chunk is kept to rebuild the full reply text for the transcript.
    """
    for chunk in generator:
        cache_list.append(chunk)
        yield chunk


# Handle one user question per rerun: echo it, stream the RAG answer, and
# record both turns in the transcript.
prompt = st.chat_input()
if prompt:
    st.chat_message("user").write(prompt)
    add_history("user", prompt)

    with st.spinner("AI thinking..."):
        # Block on the first chunk while the spinner is up so it covers the
        # real wait for the model. Merely creating the stream may return
        # immediately (e.g. a lazy generator), so a fixed sleep around it
        # neither reflects the actual latency nor is needed.
        ans = iter(get_answer(prompt))
        try:
            head = [next(ans)]
        except StopIteration:
            head = []  # the chain produced no output at all

    ai_list = []
    st.chat_message("assistant").write_stream(
        capture(itertools.chain(head, ans), ai_list)
    )
    # NOTE(review): assumes the chain yields str chunks, as this join requires.
    text = "".join(ai_list)
    add_history("assistant", text)