agent_proj/lesson/04_middle_ware.py

71 lines
2.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
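Middleware hook demo (中间件钩子演示):
agent 前后 -- before_agent / after_agent: run before / after the whole agent invocation
model 前后 -- before_model / after_model: run before / after each model call
工具 中 -- wrap_tool_call: wraps each tool call
模型 中 -- wrap_model_call: wraps each model call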
"""
from langchain.agents import create_agent, AgentState
from langchain.agents.middleware import before_agent, after_agent, before_model, after_model, wrap_model_call, \
wrap_tool_call
from langchain_community.chat_models.tongyi import ChatTongyi
from langchain_core.tools import tool
from langgraph.runtime import Runtime
@tool(description="查询天气, 传入城市名称字符串,返回字符串天气信息")
def get_weather(city: str) -> str:
    # Demo stub: no weather service is queried; every city receives the same
    # fixed "sunny" (晴天) answer, formatted as "<city> : 晴天".
    report = f"{city} : 晴天"
    return report
@before_agent
def log_before_agent(state: AgentState, runtime: Runtime) -> None:
    """Print the number of messages in the state before the agent runs.

    Args:
        state: Agent state; only ``state['messages']`` is read.
        runtime: Runtime handle supplied by the framework (unused here).

    Returns:
        None.
    """
    # Inner key uses single quotes: reusing the f-string's own double quote
    # inside the braces is a SyntaxError before Python 3.12 (PEP 701).
    print(f"before agent: info_num: {len(state['messages'])}")
@after_agent
def log_after_agent(state: AgentState, runtime: Runtime) -> None:
    """Print the number of messages in the state after the agent finishes.

    Args:
        state: Agent state; only ``state['messages']`` is read.
        runtime: Runtime handle supplied by the framework (unused here).

    Returns:
        None.
    """
    # Inner key uses single quotes: reusing the f-string's own double quote
    # inside the braces is a SyntaxError before Python 3.12 (PEP 701).
    print(f"after agent: info_num: {len(state['messages'])}")
@before_model
def log_before_model(state: AgentState, runtime: Runtime) -> None:
    """Print the number of messages in the state before each model call.

    Args:
        state: Agent state; only ``state['messages']`` is read.
        runtime: Runtime handle supplied by the framework (unused here).

    Returns:
        None.
    """
    # Inner key uses single quotes: reusing the f-string's own double quote
    # inside the braces is a SyntaxError before Python 3.12 (PEP 701).
    print(f"before model: info_num: {len(state['messages'])}")
@after_model
def log_after_model(state: AgentState, runtime: Runtime) -> None:
    """Print the number of messages in the state after each model call.

    Args:
        state: Agent state; only ``state['messages']`` is read.
        runtime: Runtime handle supplied by the framework (unused here).

    Returns:
        None.
    """
    # Inner key uses single quotes: reusing the f-string's own double quote
    # inside the braces is a SyntaxError before Python 3.12 (PEP 701).
    print(f"after model: info_num: {len(state['messages'])}")
@wrap_model_call
def model_call_hook(request, handler):
    """Print the outgoing model request, then delegate to ``handler`` unchanged.

    Args:
        request: The model request; printed in full via its string form.
        handler: Callable that performs the actual model call on ``request``.

    Returns:
        Exactly what ``handler(request)`` returns.
    """
    print(f"model call: {request}")
    response = handler(request)
    return response
@wrap_tool_call
def model_tool_hook(request, handler):
    """Print the tool name and arguments of a tool call, then run it unchanged.

    Args:
        request: Tool-call request; its ``tool_call`` mapping is read for the
            ``'name'`` and ``'args'`` keys.
        handler: Callable that executes the tool call for ``request``.

    Returns:
        Exactly what ``handler(request)`` returns.
    """
    tool_call = request.tool_call
    print(f"model tool: {tool_call['name']}")
    print(f"args: {tool_call['args']}")
    return handler(request)
# Build the demo agent: a Tongyi chat model, the single weather tool, and the
# six logging middleware defined above.
agent = create_agent(
    model=ChatTongyi(model="qwen3-max"),
    tools=[get_weather],
    middleware=[model_call_hook, model_tool_hook, log_before_model,
                log_after_model, log_before_agent, log_after_agent],
    # ReAct-style prompt: one tool call per round, and the model must report its
    # reasoning as thought / action / observation. Built from explicit lines so
    # no stray newline splits a phrase ("行动、观察"). "单次" means "a single
    # call"; the earlier "单词" ("a single word") garbled the instruction.
    system_prompt=(
        "你是严格遵循ReAct框架的智能体，必须按[思考,行动,观察,再思考]的流程解决问题。\n"
        "每轮仅能思考并调用1个工具，禁止单次调用多个工具。\n"
        "并告知我你的思考过程和工具调用的原因，按思考、行动、观察三个结构告知我。"
    ),
)
# Run the agent on one user question. Each streamed chunk is read as a state
# dict with a "messages" list; print the content of the most recent message.
res = agent.stream(
    {"messages": [{"role": "user", "content": "查询北京的天气"}]},
    stream_mode="values",
)
for snapshot in res:
    newest = snapshot["messages"][-1]
    print(newest.content)