# Streamlit chat front-end for the RAG-backed customer-service assistant.
import time

import streamlit as st

import config_data as config
from rag import RAGService
# Page header.
st.title("智能客服")
st.divider()

# Initialise the conversation history once per browser session.  Streamlit
# reruns this script on every interaction, so the greeting is seeded only
# when the key is absent — otherwise it would be appended on each rerun.
# (Original code had a redundant else-branch duplicating the assignment and
# built the seeded list with list() + append; a literal is equivalent.)
if "message" not in st.session_state:
    st.session_state["message"] = [
        {"role": "assistant", "content": "你好有什么可以帮您"}
    ]
msg_list = st.session_state["message"]
def add_history(role, content):
    """Persist one chat turn into the session-backed history list."""
    entry = {"role": role, "content": content}
    msg_list.append(entry)
def show_history():
    """Replay every stored turn into the chat UI.

    Needed because Streamlit re-executes the whole script on each
    interaction, so past messages must be re-rendered from history.
    """
    for entry in msg_list:
        st.chat_message(entry["role"]).write(entry["content"])


show_history()
# Create the RAG backend once per session and cache it in session_state;
# re-instantiating it on every rerun would rebuild the chain needlessly.
# (Original code assigned rag_service identically in both branches —
# the else-branch was redundant.)
if "RAGService" not in st.session_state:
    st.session_state["RAGService"] = RAGService()
rag_service = st.session_state["RAGService"]
def get_answer(prompt: str):
    """Return a streaming generator of answer chunks for *prompt*.

    Delegates to the session-cached RAG chain; the chunks are produced
    lazily as the caller iterates.
    """
    return rag_service.chain.stream({"input": prompt}, config.session_config)
# Main interaction: Streamlit triggers one rerun per submitted message.
if prompt := st.chat_input():
    # Echo the user's message and persist it to history.
    st.chat_message("user").write(prompt)
    add_history("user", prompt)

    with st.spinner("AI thinking..."):
        time.sleep(0.5)
        answer_stream = get_answer(prompt)

    captured_chunks = []

    def capture(generator, cache_list):
        """Yield chunks unchanged while recording each one in cache_list."""
        for piece in generator:
            cache_list.append(piece)
            yield piece

    # Stream the answer into the UI; the tee above keeps a copy so the
    # full text can be stored in history afterwards.
    st.chat_message("assistant").write_stream(capture(answer_stream, captured_chunks))
    add_history("assistant", "".join(captured_chunks))