diff --git a/aiola/clients/stt/client.py b/aiola/clients/stt/client.py index 19acbce..8c1de0d 100644 --- a/aiola/clients/stt/client.py +++ b/aiola/clients/stt/client.py @@ -17,7 +17,7 @@ AiolaValidationError, ) from ...http_client import create_async_authenticated_client, create_authenticated_client -from ...types import AiolaClientOptions, File, TasksConfig, TranscriptionResponse +from ...types import AiolaClientOptions, File, TasksConfig, TranscriptionResponse, VadConfig from .stream_client import AsyncStreamConnection, StreamConnection if TYPE_CHECKING: @@ -54,6 +54,7 @@ def _build_query_and_headers( time_zone: str | None, keywords: dict[str, str] | None, tasks_config: TasksConfig | None, + vad_config: VadConfig | None, access_token: str, ) -> tuple[dict[str, str], dict[str, str]]: """Build query parameters and headers for streaming requests.""" @@ -73,6 +74,8 @@ def _build_query_and_headers( query["keywords"] = json.dumps(keywords) if tasks_config is not None: - query["tasks_config"] = json.dumps(tasks_config) + query["tasks_config"] = json.dumps(tasks_config, default=vars) + if vad_config is not None: + query["vad_config"] = json.dumps(vad_config, default=vars) headers = { "Authorization": f"Bearer {access_token}", @@ -88,6 +91,7 @@ def _validate_stream_params( time_zone: str | None, keywords: dict[str, str] | None, tasks_config: TasksConfig | None, + vad_config: VadConfig | None, ) -> None: """Validate streaming parameters.""" if flow_id is not None and not isinstance(flow_id, str): @@ -100,8 +104,10 @@ def _validate_stream_params( raise AiolaValidationError("time_zone must be a string") if keywords is not None and not isinstance(keywords, dict): raise AiolaValidationError("keywords must be a dictionary") - if tasks_config is not None and not isinstance(tasks_config, dict): - raise AiolaValidationError("tasks_config must be a dictionary") + if tasks_config is not None and not isinstance(tasks_config, dict | TasksConfig): + raise AiolaValidationError("tasks_config must be a dictionary or a TasksConfig object") + if vad_config is not None
and not isinstance(vad_config, dict | VadConfig): + raise AiolaValidationError("vad_config must be a dictionary or a VadConfig object") class SttClient(_BaseStt): @@ -119,6 +125,7 @@ def stream( time_zone: str | None = None, keywords: dict[str, str] | None = None, tasks_config: TasksConfig | None = None, + vad_config: VadConfig | None = None, ) -> StreamConnection: """Create a streaming connection for real-time transcription. @@ -135,7 +142,9 @@ def stream( StreamConnection: A connection object for real-time streaming. """ try: - self._validate_stream_params(workflow_id, execution_id, lang_code, time_zone, keywords, tasks_config) + self._validate_stream_params( + workflow_id, execution_id, lang_code, time_zone, keywords, tasks_config, vad_config + ) # Resolve workflow_id with proper precedence resolved_workflow_id = self._resolve_workflow_id(workflow_id) @@ -149,7 +158,7 @@ def stream( # Build query parameters and headers query, headers = self._build_query_and_headers( - workflow_id, execution_id, lang_code, time_zone, keywords, tasks_config, access_token + workflow_id, execution_id, lang_code, time_zone, keywords, tasks_config, vad_config, access_token ) url = self._build_url(query) @@ -168,6 +177,7 @@ def transcribe_file( *, language: str | None = None, keywords: dict[str, str] | None = None, + vad_config: VadConfig | None = None, ) -> TranscriptionResponse: """Transcribe an audio file and return the transcription result.""" @@ -180,12 +190,16 @@ def transcribe_file( if keywords is not None and not isinstance(keywords, dict): raise AiolaValidationError("keywords must be a dictionary") + if vad_config is not None and not isinstance(vad_config, dict | VadConfig): + raise AiolaValidationError("vad_config must be a dictionary or a VadConfig object") + try: # Prepare the form data files = {"file": file} data = { "language": language or "en", "keywords": json.dumps(keywords or {}), + "vad_config": json.dumps(vad_config or {}, default=vars), } # Create authenticated HTTP client and
make request @@ -229,6 +243,7 @@ async def stream( time_zone: str | None = None, keywords: dict[str, str] | None = None, tasks_config: TasksConfig | None = None, + vad_config: VadConfig | None = None, ) -> AsyncStreamConnection: """Create an async streaming connection for real-time transcription. @@ -245,7 +260,9 @@ async def stream( AsyncStreamConnection: A connection object for real-time async streaming. """ try: - self._validate_stream_params(workflow_id, execution_id, lang_code, time_zone, keywords, tasks_config) + self._validate_stream_params( + workflow_id, execution_id, lang_code, time_zone, keywords, tasks_config, vad_config + ) # Resolve workflow_id with proper precedence resolved_workflow_id = self._resolve_workflow_id(workflow_id) @@ -259,7 +276,7 @@ async def stream( # Build query parameters and headers query, headers = self._build_query_and_headers( - workflow_id, execution_id, lang_code, time_zone, keywords, tasks_config, access_token + workflow_id, execution_id, lang_code, time_zone, keywords, tasks_config, vad_config, access_token ) url = self._build_url(query) @@ -278,6 +295,7 @@ async def transcribe_file( *, language: str | None = None, keywords: dict[str, str] | None = None, + vad_config: VadConfig | None = None, ) -> TranscriptionResponse: """Transcribe an audio file and return the transcription result.""" @@ -290,12 +308,16 @@ async def transcribe_file( if keywords is not None and not isinstance(keywords, dict): raise AiolaValidationError("keywords must be a dictionary") + if vad_config is not None and not isinstance(vad_config, dict | VadConfig): + raise AiolaValidationError("vad_config must be a dictionary or a VadConfig object") + try: # Prepare the form data files = {"file": file} data = { "language": language or "en", "keywords": json.dumps(keywords or {}), + "vad_config": json.dumps(vad_config or {}, default=vars), } # Create authenticated HTTP client and make request diff --git a/aiola/types.py b/aiola/types.py index 0978fc3..35377df 100644 ---
a/aiola/types.py +++ b/aiola/types.py @@ -140,6 +140,14 @@ class TasksConfig: TRANSLATION: TranslationPayload | None = None +@dataclass +class VadConfig: + threshold: float | None = None + min_speech_ms: float | None = None + min_silence_ms: float | None = None + max_segment_ms: float | None = None + + FileContent = Union[IO[bytes], bytes, str] File = Union[ # file (or bytes) diff --git a/tests/unit/stt/test_stt_client.py b/tests/unit/stt/test_stt_client.py index feca5de..f98747a 100644 --- a/tests/unit/stt/test_stt_client.py +++ b/tests/unit/stt/test_stt_client.py @@ -6,7 +6,7 @@ import httpx from aiola import AiolaClient, AsyncAiolaClient, AiolaError -from aiola.types import TasksConfig, LiveEvents, TranscriptionResponse +from aiola.types import TasksConfig, LiveEvents, TranscriptionResponse, VadConfig from aiola.clients.stt.client import StreamConnection, AsyncStreamConnection from tests._helpers import ( @@ -102,6 +102,11 @@ def test_stt_transcribe_file_makes_expected_http_request(dummy_stt_http): parsed_keywords = json.loads(keywords_json) assert parsed_keywords == {} + # Verify empty VAD config JSON is sent + vad_config_json = recorded["data"]["vad_config"] + parsed_vad_config = json.loads(vad_config_json) + assert parsed_vad_config == {} + def test_stt_transcribe_file_with_keywords(dummy_stt_http): """``SttClient.transcribe_file`` properly serializes keywords.""" @@ -135,6 +140,11 @@ def test_stt_transcribe_file_with_keywords(dummy_stt_http): assert parsed_keywords["hello"] == "greeting" assert parsed_keywords["world"] == "place" + # Verify empty VAD config JSON is sent + vad_config_json = recorded["data"]["vad_config"] + parsed_vad_config = json.loads(vad_config_json) + assert parsed_vad_config == {} + def test_stt_transcribe_file_with_default_parameters(dummy_stt_http): """``SttClient.transcribe_file`` uses default values when parameters are not provided.""" @@ -248,6 +258,102 @@ def test_stt_transcribe_file_with_complex_keywords(dummy_stt_http): assert 
parsed_keywords["special!@#$%"] == "symbols" +def test_stt_transcribe_file_with_vad_config_dict(dummy_stt_http): + """``SttClient.transcribe_file`` properly serializes vad_config as dictionary.""" + + client = AiolaClient(api_key="secret-key", base_url="https://speech.example") + + # Create a mock audio file + audio_file = BytesIO(b"fake audio data") + + vad_config = { + "min_speech_ms": 100, + "min_silence_ms": 200, + "max_segment_ms": 1500, + "threshold": 0.5, + } + + result = client.stt.transcribe_file( + file=audio_file, + language="en", + vad_config=vad_config, + ) + + # Check the response + assert result.transcript == "Hello, this is a test transcription." + + # Check the HTTP request + assert len(dummy_stt_http.post_calls) == 1 + recorded = dummy_stt_http.post_calls[0] + assert recorded["path"] == "/api/speech-to-text/file" + assert recorded["files"]["file"] == audio_file + assert recorded["data"]["language"] == "en" + + # Verify VAD config was properly serialized + vad_config_json = recorded["data"]["vad_config"] + parsed_vad_config = json.loads(vad_config_json) + assert parsed_vad_config["min_speech_ms"] == 100 + assert parsed_vad_config["min_silence_ms"] == 200 + assert parsed_vad_config["max_segment_ms"] == 1500 + assert parsed_vad_config["threshold"] == 0.5 + + +def test_stt_transcribe_file_with_empty_vad_config(dummy_stt_http): + """``SttClient.transcribe_file`` handles empty vad_config dict properly.""" + + client = AiolaClient(api_key="secret-key", base_url="https://speech.example") + + # Create a mock audio file + audio_file = BytesIO(b"fake audio data") + + result = client.stt.transcribe_file( + file=audio_file, + language="fr", + vad_config={}, + ) + + # Check the response + assert result.transcript == "Hello, this is a test transcription." 
+ + # Check the HTTP request + assert len(dummy_stt_http.post_calls) == 1 + recorded = dummy_stt_http.post_calls[0] + assert recorded["data"]["language"] == "fr" + + # Verify empty VAD config JSON + vad_config_json = recorded["data"]["vad_config"] + parsed_vad_config = json.loads(vad_config_json) + assert parsed_vad_config == {} + + +def test_stt_transcribe_file_with_default_vad_config(dummy_stt_http): + """``SttClient.transcribe_file`` uses default empty dict when vad_config is not provided.""" + + client = AiolaClient(api_key="secret-key", base_url="https://speech.example") + + # Create a mock audio file + audio_file = BytesIO(b"fake audio data") + + # Call without vad_config parameter + result = client.stt.transcribe_file(file=audio_file) + + # Check the response + assert isinstance(result, TranscriptionResponse) + assert result.transcript == "Hello, this is a test transcription." + + # Check the HTTP request uses defaults + assert len(dummy_stt_http.post_calls) == 1 + recorded = dummy_stt_http.post_calls[0] + assert recorded["path"] == "/api/speech-to-text/file" + assert recorded["files"]["file"] == audio_file + assert recorded["data"]["language"] == "en" # Default language + + # Verify empty VAD config JSON is sent + vad_config_json = recorded["data"]["vad_config"] + parsed_vad_config = json.loads(vad_config_json) + assert parsed_vad_config == {} + + def test_stt_stream_with_tasks_config(patch_dummy_socket): """``SttClient.stream`` properly serializes tasks_config as JSON.""" @@ -415,6 +521,103 @@ def test_stt_stream_with_all_tasks_config(patch_dummy_socket): assert "PII_REDACTION" in parsed_tasks_config +def test_stt_stream_with_vad_config(patch_dummy_socket): + """``SttClient.stream`` properly serializes vad_config as JSON.""" + + client = AiolaClient(api_key="secret-key", base_url="https://speech.example") + + vad_config = { + "min_speech_ms": 100, + "min_silence_ms": 200, + "max_segment_ms": 1500, + "threshold": 0.5, + } + + connection = client.stt.stream( 
+ workflow_id="flow-123", + vad_config=vad_config + ) + + assert isinstance(connection, StreamConnection) + assert connection.connected is False + + connection.connect() + assert connection.connected is True + + # Access the underlying socket to validate connection parameters + sio = connection._sio + assert isinstance(sio, DummySocketClient) + + # Validate vad_config is properly serialized in query parameters + kwargs = sio.connect_kwargs + url = kwargs["url"] + parsed = urllib.parse.urlparse(url) + query = urllib.parse.parse_qs(parsed.query) + + # Extract and parse the vad_config from the URL + vad_config_json = query["vad_config"][0] + parsed_vad_config = json.loads(vad_config_json) + + # Verify the vad_config was properly serialized + assert parsed_vad_config["min_speech_ms"] == 100 + assert parsed_vad_config["min_silence_ms"] == 200 + assert parsed_vad_config["max_segment_ms"] == 1500 + assert parsed_vad_config["threshold"] == 0.5 + + +def test_stt_stream_with_empty_vad_config(patch_dummy_socket): + """``SttClient.stream`` handles empty vad_config properly.""" + + client = AiolaClient(api_key="secret-key", base_url="https://speech.example") + + connection = client.stt.stream(workflow_id="flow-123", vad_config={}) + + assert isinstance(connection, StreamConnection) + assert connection.connected is False + + connection.connect() + assert connection.connected is True + + # Access the underlying socket to validate connection parameters + sio = connection._sio + + # Verify empty vad_config is serialized as empty JSON object + kwargs = sio.connect_kwargs + url = kwargs["url"] + parsed = urllib.parse.urlparse(url) + query = urllib.parse.parse_qs(parsed.query) + + vad_config_json = query["vad_config"][0] + parsed_vad_config = json.loads(vad_config_json) + assert parsed_vad_config == {} + + +def test_stt_stream_with_no_vad_config(patch_dummy_socket): + """``SttClient.stream`` handles None vad_config properly by not including it in URL.""" + + client = 
AiolaClient(api_key="secret-key", base_url="https://speech.example") + + connection = client.stt.stream(workflow_id="flow-123", vad_config=None) + + assert isinstance(connection, StreamConnection) + assert connection.connected is False + + connection.connect() + assert connection.connected is True + + # Access the underlying socket to validate connection parameters + sio = connection._sio + + # Verify None vad_config is not included in URL + kwargs = sio.connect_kwargs + url = kwargs["url"] + parsed = urllib.parse.urlparse(url) + query = urllib.parse.parse_qs(parsed.query) + + # vad_config should not be present when None + assert "vad_config" not in query + + def test_stream_connection_wrapper_functionality(patch_dummy_socket): """Test that StreamConnection wrapper provides the expected API.""" @@ -549,6 +752,48 @@ async def test_async_stt_stream_with_tasks_config(patch_dummy_async_socket): assert parsed_tasks_config["SENTIMENT_ANALYSIS"] == {} +@pytest.mark.anyio +async def test_async_stt_stream_with_vad_config(patch_dummy_async_socket): + """Async version properly handles vad_config.""" + + client = AsyncAiolaClient(api_key="tok", base_url="https://speech.example") + + vad_config = { + "min_speech_ms": 100, + "min_silence_ms": 200, + "max_segment_ms": 1500, + "threshold": 0.5, + } + + connection = await client.stt.stream( + workflow_id="f1", + vad_config=vad_config + ) + + assert isinstance(connection, AsyncStreamConnection) + assert connection.connected is False + + await connection.connect() + assert connection.connected is True + + # Access the underlying socket to validate connection parameters + sio = connection._sio + + # Verify vad_config is properly serialized + kwargs = sio.connect_kwargs + url = kwargs["url"] + parsed = urllib.parse.urlparse(url) + query = urllib.parse.parse_qs(parsed.query) + + vad_config_json = query["vad_config"][0] + parsed_vad_config = json.loads(vad_config_json) + + assert parsed_vad_config["min_speech_ms"] == 100 + assert 
parsed_vad_config["min_silence_ms"] == 200 + assert parsed_vad_config["max_segment_ms"] == 1500 + assert parsed_vad_config["threshold"] == 0.5 + + @pytest.mark.anyio async def test_async_stt_transcribe_file_makes_expected_http_request(dummy_async_stt_http): """``AsyncSttClient.transcribe_file`` should send POST /api/speech-to-text/file with file.""" @@ -585,6 +830,11 @@ async def test_async_stt_transcribe_file_makes_expected_http_request(dummy_async parsed_keywords = json.loads(keywords_json) assert parsed_keywords == {} + # Verify empty VAD config JSON is sent + vad_config_json = recorded["data"]["vad_config"] + parsed_vad_config = json.loads(vad_config_json) + assert parsed_vad_config == {} + @pytest.mark.anyio async def test_async_stt_transcribe_file_with_keywords(dummy_async_stt_http): @@ -648,6 +898,53 @@ async def test_async_stt_transcribe_file_with_default_parameters(dummy_async_stt parsed_keywords = json.loads(keywords_json) assert parsed_keywords == {} + # Verify empty VAD config JSON is sent + vad_config_json = recorded["data"]["vad_config"] + parsed_vad_config = json.loads(vad_config_json) + assert parsed_vad_config == {} + + +@pytest.mark.anyio +async def test_async_stt_transcribe_file_with_vad_config(dummy_async_stt_http): + """``AsyncSttClient.transcribe_file`` properly serializes vad_config.""" + + client = AsyncAiolaClient(api_key="secret-key", base_url="https://speech.example") + + # Create a mock audio file + audio_file = BytesIO(b"fake audio data") + + vad_config = { + "min_speech_ms": 100, + "min_silence_ms": 200, + "max_segment_ms": 1500, + "threshold": 0.5, + } + + result = await client.stt.transcribe_file( + file=audio_file, + language="en", + vad_config=vad_config, + ) + + # Check the response + assert isinstance(result, TranscriptionResponse) + assert result.transcript == "Hello, this is a test transcription." 
+ + # Check the HTTP request + assert len(dummy_async_stt_http.post_calls) == 1 + recorded = dummy_async_stt_http.post_calls[0] + assert recorded["path"] == "/api/speech-to-text/file" + assert recorded["files"]["file"] == audio_file + assert recorded["data"]["language"] == "en" + + # Verify VAD config was properly serialized + vad_config_json = recorded["data"]["vad_config"] + parsed_vad_config = json.loads(vad_config_json) + assert parsed_vad_config["min_speech_ms"] == 100 + assert parsed_vad_config["min_silence_ms"] == 200 + assert parsed_vad_config["max_segment_ms"] == 1500 + assert parsed_vad_config["threshold"] == 0.5 + @pytest.mark.anyio async def test_async_stt_transcribe_file_handles_http_error(monkeypatch): diff --git a/uv.lock b/uv.lock index 06b8b08..e374875 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.11'", @@ -8,7 +8,7 @@ resolution-markers = [ [[package]] name = "aiola" -version = "0.1.3" +version = "0.1.5" source = { editable = "." } dependencies = [ { name = "httpx" },