40 lines
1.5 KiB
Python
40 lines
1.5 KiB
Python
from langchain.agents import create_agent
|
||
from model.factory import chat_model
|
||
from utils.prompt_loader import load_system_prompts
|
||
from agent.tools.agent_tools import (rag_summarize, get_weather, get_user_id,
|
||
get_user_location, get_current_month,
|
||
fill_context_for_report, fetch_external_data)
|
||
from agent.tools.middleware import monitor_tool, log_before_model, repoet_prompt_switch
|
||
|
||
|
||
class ReactAgent:
    """Thin wrapper that builds a tool-using agent via ``create_agent``.

    The agent is configured with the shared ``chat_model``, the system prompt
    returned by ``load_system_prompts()``, the agent tools and the middleware
    stack imported by this module, and exposes a streaming query helper.
    """

    def __init__(self):
        """Create the underlying agent from the project's model, prompt, tools and middleware."""
        self.agent = create_agent(
            model=chat_model,
            system_prompt=load_system_prompts(),
            tools=[
                rag_summarize,
                get_weather,
                get_user_location,
                get_user_id,
                get_current_month,
                fetch_external_data,
                fill_context_for_report,
            ],
            middleware=[monitor_tool, log_before_model, repoet_prompt_switch],
        )

    @staticmethod
    def _message_text(message) -> str:
        """Return the text of ``message.content`` with surrounding whitespace stripped.

        ``content`` may be a plain string or a list of content blocks. For a
        list, string items and the ``"text"`` field of ``{"type": "text"}``
        dict blocks are concatenated; any other block (e.g. a tool-call block)
        is ignored, so a message with no text yields ``""``.
        """
        content = message.content
        if isinstance(content, str):
            return content.strip()
        parts = []
        for block in content or ():
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and block.get("type") == "text":
                parts.append(block.get("text") or "")
        return "".join(parts).strip()

    def excute_stream(self, query: str):
        """Run *query* through the agent and yield the text of each streamed state.

        Args:
            query: The user's question, sent as a single user message.

        Yields:
            The stripped text of the last message in each streamed chunk,
            followed by a newline. Chunks with no messages, or whose last
            message has no non-whitespace text, are skipped.
        """
        input_dict = {
            "messages": [
                {"role": "user", "content": query},
            ]
        }

        # Runtime context passed to the agent: "report" is a flag used to
        # switch system prompts (presumably read by the prompt-switch
        # middleware — confirm against agent.tools.middleware).
        response = self.agent.stream(input_dict, stream_mode="values", context={"report": False})

        for chunk in response:
            messages = chunk.get("messages") or ()
            if not messages:
                continue
            # NOTE(review): with stream_mode="values" each chunk is a state
            # snapshot, so its last message is not necessarily model output
            # (it may be the user's query or a tool result) — confirm that
            # echoing those is intended.
            text = self._message_text(messages[-1])
            if text:
                yield text + "\n"

    # Correctly spelled alias; ``excute_stream`` is kept for existing callers.
    execute_stream = excute_stream
|
||
|
||
|
||
if __name__ == '__main__':
    # Manual smoke test: stream the agent's answer to a sample query
    # ("How should a robot vacuum be maintained given the temperature where I am?").
    agent = ReactAgent()
    stream = agent.excute_stream("扫地机器人在我所在地的气温下如何保养")
    for chunk in stream:
        # Flush after each chunk so partial output appears immediately.
        print(chunk, end="", flush=True)
|