"""LangChain agent middleware: tool-call monitoring, pre-model logging, and dynamic system-prompt switching."""
from langchain.agents.middleware import wrap_tool_call, before_model, dynamic_prompt, ModelRequest
|
|
from langchain.tools.tool_node import ToolCallRequest
|
|
from typing import Callable
|
|
from langchain_core.messages import ToolMessage
|
|
from langgraph.types import Command
|
|
from utils.logger_handler import logger
|
|
from langchain.agents import AgentState
|
|
from langgraph.runtime import Runtime
|
|
from utils.prompt_loader import load_system_prompts, load_report_prompts
|
|
|
|
@wrap_tool_call
def monitor_tool(
    request: ToolCallRequest,
    handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
    """Middleware that logs every tool invocation and flags report context.

    Args:
        request: The pending tool call (exposes ``tool_call['name']`` /
            ``tool_call['args']``) plus the runtime context.
        handler: The next handler in the chain that actually runs the tool.

    Returns:
        Whatever the wrapped handler returns (a ToolMessage or a Command).

    Raises:
        Exception: any error from the underlying tool is logged and re-raised.
    """
    logger.info(f"[tool monitor]执行工具:{request.tool_call['name']}")
    logger.info(f"[tool monitor]传入参数:{request.tool_call['args']}")

    try:
        result = handler(request)
        # Log tag fixed: was "[tool minitor]" — kept consistent for grepping.
        logger.info(f"[tool monitor]工具{request.tool_call['name']}调用成功")

        # After the report-context tool has run, flip the flag so the
        # dynamic-prompt middleware switches to the report prompt.
        if request.tool_call['name'] == 'fill_context_for_report':
            request.runtime.context["report"] = True

        return result
    except Exception as e:
        logger.error(f"工具{request.tool_call['name']}调用失败,原因: {str(e)}")
        # Bare raise is the idiomatic re-raise; it keeps the active exception
        # exactly as-is instead of re-raising a bound name.
        raise
|
|
|
|
|
|
@before_model
def log_before_model(
    state: AgentState,   # agent state; this hook only reads state["messages"]
    runtime: Runtime,    # execution runtime / context (fixed typo: was "runtine")
):
    """Log the message count and the latest message just before a model call.

    Purely observational: returns None and never mutates state.
    """
    # Single-quoted keys inside the f-strings: nesting the same double quote
    # inside a double-quoted f-string is a SyntaxError before Python 3.12.
    logger.info(f"[log_before_model]即将调用模型,带有{len(state['messages'])}条消息")
    last_message = state['messages'][-1]
    # NOTE(review): assumes .content is a str; for multimodal messages it can
    # be a list of content blocks and .strip() would fail — confirm upstream.
    logger.debug(f"[log_before_model]{type(last_message).__name__} | "
                 f"消息内容:{last_message.content.strip()}")
    return None
|
|
|
|
|
|
@dynamic_prompt  # evaluated before each prompt is built
def repoet_prompt_switch(request: ModelRequest):
    """Choose the system prompt for this model call.

    Serves the report prompt once the runtime context carries a truthy
    "report" flag (set by the tool-monitoring middleware); otherwise the
    default system prompt.
    """
    wants_report = request.runtime.context.get("report", False)
    return load_report_prompts() if wants_report else load_system_prompts()