diff --git a/src/agents/run_internal/tool_execution.py b/src/agents/run_internal/tool_execution.py index bc370ea61..6d41627ad 100644 --- a/src/agents/run_internal/tool_execution.py +++ b/src/agents/run_internal/tool_execution.py @@ -820,6 +820,7 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo context_wrapper, tool_call.call_id, tool_call=tool_call, + agent=agent, ) agent_hooks = agent.hooks if config.trace_include_sensitive_data: diff --git a/src/agents/tool_context.py b/src/agents/tool_context.py index d2d156b79..4e5da6c94 100644 --- a/src/agents/tool_context.py +++ b/src/agents/tool_context.py @@ -9,6 +9,7 @@ from .usage import Usage if TYPE_CHECKING: + from .agent import Agent from .items import TResponseInputItem from .run_context import _ApprovalRecord @@ -44,6 +45,9 @@ class ToolContext(RunContextWrapper[TContext]): tool_call: ResponseFunctionToolCall | None = None """The tool call object associated with this invocation.""" + agent: Agent[Any] | None = None + """The agent that is calling this tool, if available.""" + def __init__( self, context: TContext, @@ -53,6 +57,7 @@ def __init__( tool_arguments: str | object = _MISSING, tool_call: ResponseFunctionToolCall | None = None, *, + agent: Agent[Any] | None = None, turn_input: list[TResponseInputItem] | None = None, _approvals: dict[str, _ApprovalRecord] | None = None, tool_input: Any | None = None, @@ -80,6 +85,7 @@ def __init__( else cast(str, tool_call_id) ) self.tool_call = tool_call + self.agent = agent @classmethod def from_agent_context( @@ -87,6 +93,7 @@ def from_agent_context( context: RunContextWrapper[TContext], tool_call_id: str, tool_call: ResponseFunctionToolCall | None = None, + agent: Agent[Any] | None = None, ) -> ToolContext: """ Create a ToolContext from a RunContextWrapper. @@ -105,6 +112,7 @@ def from_agent_context( tool_call_id=tool_call_id, tool_arguments=tool_args, tool_call=tool_call, + agent=agent, **base_values, ) return tool_context diff --git a/tests/test_tool_context.py b/tests/test_tool_context.py index 4edd79522..32d183768 100644 --- a/tests/test_tool_context.py +++ b/tests/test_tool_context.py @@ -1,6 +1,7 @@ import pytest from openai.types.responses import ResponseFunctionToolCall +from agents import Agent from agents.run_context import RunContextWrapper from agents.tool_context import ToolContext from tests.utils.hitl import make_context_wrapper @@ -36,3 +37,47 @@ def test_tool_context_from_agent_context_populates_fields() -> None: assert tool_ctx.tool_name == "test_tool" assert tool_ctx.tool_call_id == "call-123" assert tool_ctx.tool_arguments == '{"a": 1}' + + +def test_tool_context_agent_none_by_default() -> None: + """Agent field defaults to None for backward compatibility.""" + tool_call = ResponseFunctionToolCall( + type="function_call", + name="test_tool", + call_id="call-1", + arguments="{}", + ) + ctx = make_context_wrapper() + tool_ctx = ToolContext.from_agent_context(ctx, tool_call_id="call-1", tool_call=tool_call) + assert tool_ctx.agent is None + + +def test_tool_context_agent_from_agent_context() -> None: + """Agent is populated when passed to from_agent_context.""" + agent = Agent(name="test-agent", instructions="do stuff") + tool_call = ResponseFunctionToolCall( + type="function_call", + name="test_tool", + call_id="call-2", + arguments="{}", + ) + ctx = make_context_wrapper() + tool_ctx = ToolContext.from_agent_context( + ctx, tool_call_id="call-2", tool_call=tool_call, agent=agent + ) + assert tool_ctx.agent is agent + assert tool_ctx.agent.name == "test-agent" + + +def test_tool_context_agent_via_constructor() -> None: + """Agent is accessible when passed directly to the ToolContext constructor.""" + agent = Agent(name="direct-agent", instructions="hi") + tool_ctx: ToolContext[dict[str, object]] = ToolContext( + context={}, + tool_name="my_tool", + tool_call_id="call-3", + tool_arguments="{}", + agent=agent, + ) + assert tool_ctx.agent is agent + assert tool_ctx.agent.name == "direct-agent"