Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 5 additions & 59 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,11 @@ bidi = [
"prompt_toolkit>=3.0.0,<4.0.0",
"pyaudio>=0.2.13,<1.0.0",
"smithy-aws-core>=0.0.1; python_version>='3.12'",
"strands-agents[gemini]",
"websockets>=15.0.0,<16.0.0",
]
bidi-gemini = ["google-genai>=1.32.0,<2.0.0"]
bidi-openai = ["websockets>=15.0.0,<16.0.0"]

all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"]
bidi-all = ["strands-agents[a2a,bidi,bidi-gemini,bidi-openai,docs,otel]"]
all = ["strands-agents[a2a,anthropic,bidi,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"]

dev = [
"commitizen>=4.4.0,<5.0.0",
Expand Down Expand Up @@ -130,7 +129,7 @@ format-fix = [
]
lint-check = [
"ruff check",
"mypy ./src"
"mypy -p src"
]
lint-fix = [
"ruff check --fix"
Expand Down Expand Up @@ -204,16 +203,10 @@ warn_no_return = true
warn_unreachable = true
follow_untyped_imports = true
ignore_missing_imports = false
exclude = ["src/strands/experimental/bidi"]

[[tool.mypy.overrides]]
module = ["strands.experimental.bidi.*"]
follow_imports = "skip"

[tool.ruff]
line-length = 120
include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests_integ/**/*.py"]
exclude = ["src/strands/experimental/bidi/**/*.py", "tests/strands/experimental/bidi/**/*.py", "tests_integ/bidi/**/*.py"]

[tool.ruff.lint]
select = [
Expand All @@ -236,16 +229,14 @@ convention = "google"
[tool.pytest.ini_options]
testpaths = ["tests"]
asyncio_default_fixture_loop_scope = "function"
addopts = "--ignore=tests/strands/experimental/bidi --ignore=tests_integ/bidi"

addopts = "--ignore=tests_integ/bidi" # TODO: install portaudio for github workflow

[tool.coverage.run]
branch = true
source = ["src"]
context = "thread"
parallel = true
concurrency = ["thread", "multiprocessing"]
omit = ["src/strands/experimental/bidi/*"]

[tool.coverage.report]
show_missing = true
Expand Down Expand Up @@ -275,48 +266,3 @@ style = [
["text", ""],
["disabled", "fg:#858585 italic"]
]

# =========================
# Bidi development configs
# =========================

[tool.hatch.envs.bidi]
dev-mode = true
features = ["dev", "bidi-all"]
installer = "uv"

[tool.hatch.envs.bidi.scripts]
prepare = [
"hatch run bidi-lint:format-fix",
"hatch run bidi-lint:quality-fix",
"hatch run bidi-lint:type-check",
"hatch run bidi-test:test-cov",
]

[tools.hatch.envs.bidi-lint]
template = "bidi"

[tool.hatch.envs.bidi-lint.scripts]
format-check = "format-fix --check"
format-fix = "ruff format {args} --target-version py312 ./src/strands/experimental/bidi/**/*.py"
quality-check = "ruff check {args} --target-version py312 ./src/strands/experimental/bidi/**/*.py"
quality-fix = "quality-check --fix"
type-check = "mypy {args} --python-version 3.12 ./src/strands/experimental/bidi/**/*.py"

[tool.hatch.envs.bidi-test]
template = "bidi"

[tool.hatch.envs.bidi-test.scripts]
test = "pytest {args} tests/strands/experimental/bidi"
test-cov = """
test \
--cov=strands.experimental.bidi \
--cov-config= \
--cov-branch \
--cov-report=term-missing \
--cov-report=xml:build/coverage/bidi-coverage.xml \
--cov-report=html:build/coverage/bidi-html
"""

[[tool.hatch.envs.bidi-test.matrix]]
python = ["3.13", "3.12"]
8 changes: 0 additions & 8 deletions src/strands/experimental/bidi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
"""Bidirectional streaming package."""

import sys

if sys.version_info < (3, 12):
raise ImportError("bidi only supported for >= Python 3.12")

# Main components - Primary user interface
# Re-export standard agent events for tool handling
from ...types._events import (
Expand All @@ -19,7 +14,6 @@

# Model interface (for custom implementations)
from .models.model import BidiModel
from .models.nova_sonic import BidiNovaSonicModel

# Built-in tools
from .tools import stop_conversation
Expand Down Expand Up @@ -48,8 +42,6 @@
"BidiAgent",
# IO channels
"BidiAudioIO",
# Model providers
"BidiNovaSonicModel",
# Built-in tools
"stop_conversation",
# Input Event types
Expand Down
15 changes: 7 additions & 8 deletions src/strands/experimental/bidi/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from ...tools import ToolProvider
from .._async import _TaskGroup, stop_all
from ..models.model import BidiModel
from ..models.nova_sonic import BidiNovaSonicModel
from ..types.agent import BidiAgentInput
from ..types.events import (
BidiAudioInputEvent,
Expand Down Expand Up @@ -100,13 +99,13 @@
ValueError: If model configuration is invalid or state is invalid type.
TypeError: If model type is unsupported.
"""
self.model = (
BidiNovaSonicModel()
if not model
else BidiNovaSonicModel(model_id=model)
if isinstance(model, str)
else model
)
if isinstance(model, BidiModel):
self.model = model

Check warning on line 103 in src/strands/experimental/bidi/agent/agent.py

View workflow job for this annotation

GitHub Actions / check-api

BidiAgent.model

Attribute value was changed: `BidiNovaSonicModel() if not model else BidiNovaSonicModel(model_id=model) if isinstance(model, str) else model` -> `model`
else:
from ..models.nova_sonic import BidiNovaSonicModel

self.model = BidiNovaSonicModel(model_id=model) if isinstance(model, str) else BidiNovaSonicModel()

self.system_prompt = system_prompt
self.messages = messages or []

Expand Down
15 changes: 7 additions & 8 deletions src/strands/experimental/bidi/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
"""Bidirectional model interfaces and implementations."""

from typing import Any

from .model import BidiModel, BidiModelTimeoutError
from .nova_sonic import BidiNovaSonicModel

__all__ = [
"BidiModel",
"BidiModelTimeoutError",
"BidiNovaSonicModel",
]


def __getattr__(name: str) -> Any:
"""
Lazy load bidi model implementations only when accessed.

This defers the import of optional dependencies until actually needed:
- BidiGeminiLiveModel requires google-generativeai (lazy loaded)
- BidiOpenAIRealtimeModel requires openai (lazy loaded)
"""Lazy load bidi model implementations only when accessed.

This defers the import of optional dependencies until actually needed.
"""
if name == "BidiGeminiLiveModel":
from .gemini_live import BidiGeminiLiveModel

return BidiGeminiLiveModel
if name == "BidiNovaSonicModel":
from .nova_sonic import BidiNovaSonicModel

return BidiNovaSonicModel
if name == "BidiOpenAIRealtimeModel":
from .openai_realtime import BidiOpenAIRealtimeModel

Expand Down
3 changes: 2 additions & 1 deletion src/strands/experimental/bidi/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""

import logging
from typing import Any, AsyncIterable, Protocol
from typing import Any, AsyncIterable, Protocol, runtime_checkable

from ....types._events import ToolResultEvent
from ....types.content import Messages
Expand All @@ -27,6 +27,7 @@
logger = logging.getLogger(__name__)


@runtime_checkable
class BidiModel(Protocol):
"""Protocol for bidirectional streaming models.
Expand Down
30 changes: 23 additions & 7 deletions src/strands/experimental/bidi/models/nova_sonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,41 @@
- Tool execution with content containers and identifier tracking
- 8-minute connection limits with proper cleanup sequences
- Interruption detection through stopReason events

Note, BidiNovaSonicModel is only supported for Python 3.12+
"""

import asyncio
import sys

if sys.version_info < (3, 12):
raise ImportError("BidiNovaSonicModel is only supported for Python 3.12+")

import asyncio # type: ignore[unreachable]
import base64
import json
import logging
import uuid
from typing import Any, AsyncGenerator, cast

import boto3
from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput
from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme
from aws_sdk_bedrock_runtime.models import (
from aws_sdk_bedrock_runtime.client import ( # type: ignore[import-not-found]
BedrockRuntimeClient,
InvokeModelWithBidirectionalStreamOperationInput,
)
from aws_sdk_bedrock_runtime.config import ( # type: ignore[import-not-found]
Config,
HTTPAuthSchemeResolver,
SigV4AuthScheme,
)
from aws_sdk_bedrock_runtime.models import ( # type: ignore[import-not-found]
BidirectionalInputPayloadPart,
InvokeModelWithBidirectionalStreamInputChunk,
ModelTimeoutException,
ValidationException,
)
from smithy_aws_core.identity.static import StaticCredentialsResolver
from smithy_core.aio.eventstream import DuplexEventStream
from smithy_core.shapes import ShapeID
from smithy_aws_core.identity.static import StaticCredentialsResolver # type: ignore[import-not-found]
from smithy_core.aio.eventstream import DuplexEventStream # type: ignore[import-not-found]
from smithy_core.shapes import ShapeID # type: ignore[import-not-found]

from ....types._events import ToolResultEvent, ToolUseStreamEvent
from ....types.content import Messages
Expand Down Expand Up @@ -93,6 +107,8 @@ class BidiNovaSonicModel(BidiModel):
Manages Nova Sonic's complex event sequencing, audio format conversion, and
tool execution patterns while providing the standard BidiModel interface.

Note, BidiNovaSonicModel is only supported for Python 3.12+.

Attributes:
_stream: open bedrock stream to nova sonic.
"""
Expand Down
3 changes: 3 additions & 0 deletions tests/strands/experimental/bidi/_async/test_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ async def test_task_group__aexit__():
@pytest.mark.asyncio
async def test_task_group__aexit__task_exception():
wait_event = asyncio.Event()

async def wait():
await wait_event.wait()

Expand Down Expand Up @@ -49,12 +50,14 @@ async def wait():
@pytest.mark.asyncio
async def test_task_group__aexit__context_cancelled():
wait_event = asyncio.Event()

async def wait():
await wait_event.wait()

tasks = []

run_event = asyncio.Event()

async def run():
async with _TaskGroup() as task_group:
tasks.append(task_group.create_task(wait()))
Expand Down
20 changes: 12 additions & 8 deletions tests/strands/experimental/bidi/agent/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Unit tests for BidiAgent."""

import asyncio
import sys
import unittest.mock
from uuid import uuid4

import pytest

from strands.experimental.bidi.agent.agent import BidiAgent
from strands.experimental.bidi.models.nova_sonic import BidiNovaSonicModel
from strands.experimental.bidi.types.events import (
BidiAudioInputEvent,
BidiAudioStreamEvent,
Expand Down Expand Up @@ -125,20 +125,24 @@ def test_bidi_agent_init_with_various_configurations():
assert agent_with_config.system_prompt == system_prompt
assert agent_with_config.agent_id == "test_agent"

# Test with string model ID
model_id = "amazon.nova-sonic-v1:0"
agent_with_string = BidiAgent(model=model_id)

assert isinstance(agent_with_string.model, BidiNovaSonicModel)
assert agent_with_string.model.model_id == model_id

# Test model config access
config = agent.model.config
assert config["audio"]["input_rate"] == 16000
assert config["audio"]["output_rate"] == 24000
assert config["audio"]["channels"] == 1


@pytest.mark.skipif(sys.version_info < (3, 12), reason="BidiNovaSonicModel is only supported for Python 3.12+")
def test_bidi_agent_init_with_model_id():
from strands.experimental.bidi.models.nova_sonic import BidiNovaSonicModel

model_id = "amazon.nova-sonic-v1:0"
agent = BidiAgent(model=model_id)

assert isinstance(agent.model, BidiNovaSonicModel)
assert agent.model.model_id == model_id


@pytest.mark.asyncio
async def test_bidi_agent_start_stop_lifecycle(agent):
"""Test agent start/stop lifecycle and state management."""
Expand Down
4 changes: 2 additions & 2 deletions tests/strands/experimental/bidi/agent/test_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from strands import tool
from strands.experimental.bidi import BidiAgent
from strands.experimental.bidi.models import BidiModelTimeoutError
from strands.experimental.bidi.models import BidiModel, BidiModelTimeoutError
from strands.experimental.bidi.types.events import BidiConnectionRestartEvent, BidiTextInputEvent
from strands.types._events import ToolResultEvent, ToolResultMessageEvent, ToolUseStreamEvent

Expand All @@ -21,7 +21,7 @@ async def func():

@pytest.fixture
def agent(time_tool):
return BidiAgent(model=unittest.mock.AsyncMock(), tools=[time_tool])
return BidiAgent(model=unittest.mock.AsyncMock(spec=BidiModel), tools=[time_tool])


@pytest_asyncio.fixture
Expand Down
Loading
Loading