Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/agents/run_internal/tool_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,7 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo
context_wrapper,
tool_call.call_id,
tool_call=tool_call,
agent=agent,
)
Comment on lines 820 to 824

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Propagate agent in realtime tool invocations too

This adds agent to ToolContext for the standard tool execution path, but realtime tool calls still build ToolContext without an agent (see src/agents/realtime/session.py in _handle_tool_call, where ToolContext(...) is constructed from context/usage only). As a result, tools invoked via RealtimeSession will still see tool_ctx.agent is None, so the new feature does not work in that flow. If the intent is to let tools determine the calling agent, please pass the agent there as well.

Useful? React with 👍 / 👎.

agent_hooks = agent.hooks
if config.trace_include_sensitive_data:
Expand Down
8 changes: 8 additions & 0 deletions src/agents/tool_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .usage import Usage

if TYPE_CHECKING:
from .agent import Agent
from .items import TResponseInputItem
from .run_context import _ApprovalRecord

Expand Down Expand Up @@ -44,6 +45,9 @@ class ToolContext(RunContextWrapper[TContext]):
tool_call: ResponseFunctionToolCall | None = None
"""The tool call object associated with this invocation."""

agent: Agent[Any] | None = None
"""The agent that is calling this tool, if available."""

def __init__(
self,
context: TContext,
Expand All @@ -53,6 +57,7 @@ def __init__(
tool_arguments: str | object = _MISSING,
tool_call: ResponseFunctionToolCall | None = None,
*,
agent: Agent[Any] | None = None,
turn_input: list[TResponseInputItem] | None = None,
_approvals: dict[str, _ApprovalRecord] | None = None,
tool_input: Any | None = None,
Expand Down Expand Up @@ -80,13 +85,15 @@ def __init__(
else cast(str, tool_call_id)
)
self.tool_call = tool_call
self.agent = agent

@classmethod
def from_agent_context(
cls,
context: RunContextWrapper[TContext],
tool_call_id: str,
tool_call: ResponseFunctionToolCall | None = None,
agent: Agent[Any] | None = None,
) -> ToolContext:
"""
Create a ToolContext from a RunContextWrapper.
Expand All @@ -105,6 +112,7 @@ def from_agent_context(
tool_call_id=tool_call_id,
tool_arguments=tool_args,
tool_call=tool_call,
agent=agent,
**base_values,
)
return tool_context
45 changes: 45 additions & 0 deletions tests/test_tool_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from openai.types.responses import ResponseFunctionToolCall

from agents import Agent
from agents.run_context import RunContextWrapper
from agents.tool_context import ToolContext
from tests.utils.hitl import make_context_wrapper
Expand Down Expand Up @@ -36,3 +37,47 @@ def test_tool_context_from_agent_context_populates_fields() -> None:
assert tool_ctx.tool_name == "test_tool"
assert tool_ctx.tool_call_id == "call-123"
assert tool_ctx.tool_arguments == '{"a": 1}'


def test_tool_context_agent_none_by_default() -> None:
"""Agent field defaults to None for backward compatibility."""
tool_call = ResponseFunctionToolCall(
type="function_call",
name="test_tool",
call_id="call-1",
arguments="{}",
)
ctx = make_context_wrapper()
tool_ctx = ToolContext.from_agent_context(ctx, tool_call_id="call-1", tool_call=tool_call)
assert tool_ctx.agent is None


def test_tool_context_agent_from_agent_context() -> None:
"""Agent is populated when passed to from_agent_context."""
agent = Agent(name="test-agent", instructions="do stuff")
tool_call = ResponseFunctionToolCall(
type="function_call",
name="test_tool",
call_id="call-2",
arguments="{}",
)
ctx = make_context_wrapper()
tool_ctx = ToolContext.from_agent_context(
ctx, tool_call_id="call-2", tool_call=tool_call, agent=agent
)
assert tool_ctx.agent is agent
assert tool_ctx.agent.name == "test-agent"


def test_tool_context_agent_via_constructor() -> None:
"""Agent is accessible when passed directly to the ToolContext constructor."""
agent = Agent(name="direct-agent", instructions="hi")
tool_ctx: ToolContext[dict[str, object]] = ToolContext(
context={},
tool_name="my_tool",
tool_call_id="call-3",
tool_arguments="{}",
agent=agent,
)
assert tool_ctx.agent is agent
assert tool_ctx.agent.name == "direct-agent"