Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions aiola/clients/stt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
AiolaValidationError,
)
from ...http_client import create_async_authenticated_client, create_authenticated_client
from ...types import AiolaClientOptions, File, TasksConfig, TranscriptionResponse
from ...types import AiolaClientOptions, File, TasksConfig, TranscriptionResponse, VadConfig
from .stream_client import AsyncStreamConnection, StreamConnection

if TYPE_CHECKING:
Expand Down Expand Up @@ -54,6 +54,7 @@ def _build_query_and_headers(
time_zone: str | None,
keywords: dict[str, str] | None,
tasks_config: TasksConfig | None,
vad_config: VadConfig | None,
access_token: str,
) -> tuple[dict[str, str], dict[str, str]]:
"""Build query parameters and headers for streaming requests."""
Expand All @@ -73,6 +74,8 @@ def _build_query_and_headers(
query["keywords"] = json.dumps(keywords)
if tasks_config is not None:
query["tasks_config"] = json.dumps(tasks_config)
if vad_config is not None:
query["vad_config"] = json.dumps(vad_config)

headers = {
"Authorization": f"Bearer {access_token}",
Expand All @@ -88,6 +91,7 @@ def _validate_stream_params(
time_zone: str | None,
keywords: dict[str, str] | None,
tasks_config: TasksConfig | None,
vad_config: VadConfig | None,
) -> None:
"""Validate streaming parameters."""
if flow_id is not None and not isinstance(flow_id, str):
Expand All @@ -100,8 +104,10 @@ def _validate_stream_params(
raise AiolaValidationError("time_zone must be a string")
if keywords is not None and not isinstance(keywords, dict):
raise AiolaValidationError("keywords must be a dictionary")
if tasks_config is not None and not isinstance(tasks_config, dict):
raise AiolaValidationError("tasks_config must be a dictionary")
if tasks_config is not None and not isinstance(tasks_config, dict | TasksConfig):
raise AiolaValidationError("tasks_config must be a dictionary or a TasksConfig object")
if vad_config is not None and not isinstance(vad_config, dict | VadConfig):
raise AiolaValidationError("vad_config must be a dictionary or a VadConfig object")


class SttClient(_BaseStt):
Expand All @@ -119,6 +125,7 @@ def stream(
time_zone: str | None = None,
keywords: dict[str, str] | None = None,
tasks_config: TasksConfig | None = None,
vad_config: VadConfig | None = None,
) -> StreamConnection:
"""Create a streaming connection for real-time transcription.

Expand All @@ -135,7 +142,9 @@ def stream(
StreamConnection: A connection object for real-time streaming.
"""
try:
self._validate_stream_params(workflow_id, execution_id, lang_code, time_zone, keywords, tasks_config)
self._validate_stream_params(
workflow_id, execution_id, lang_code, time_zone, keywords, tasks_config, vad_config
)

# Resolve workflow_id with proper precedence
resolved_workflow_id = self._resolve_workflow_id(workflow_id)
Expand All @@ -149,7 +158,7 @@ def stream(

# Build query parameters and headers
query, headers = self._build_query_and_headers(
workflow_id, execution_id, lang_code, time_zone, keywords, tasks_config, access_token
workflow_id, execution_id, lang_code, time_zone, keywords, tasks_config, vad_config, access_token
)

url = self._build_url(query)
Expand All @@ -168,6 +177,7 @@ def transcribe_file(
*,
language: str | None = None,
keywords: dict[str, str] | None = None,
vad_config: VadConfig | None = None,
) -> TranscriptionResponse:
"""Transcribe an audio file and return the transcription result."""

Expand All @@ -180,12 +190,16 @@ def transcribe_file(
if keywords is not None and not isinstance(keywords, dict):
raise AiolaValidationError("keywords must be a dictionary")

if vad_config is not None and not isinstance(vad_config, dict | VadConfig):
raise AiolaValidationError("vad_config must be a dictionary or a VadConfig object")

try:
# Prepare the form data
files = {"file": file}
data = {
"language": language or "en",
"keywords": json.dumps(keywords or {}),
"vad_config": json.dumps(vad_config or {}),
}

# Create authenticated HTTP client and make request
Expand Down Expand Up @@ -229,6 +243,7 @@ async def stream(
time_zone: str | None = None,
keywords: dict[str, str] | None = None,
tasks_config: TasksConfig | None = None,
vad_config: VadConfig | None = None,
) -> AsyncStreamConnection:
"""Create an async streaming connection for real-time transcription.

Expand All @@ -245,7 +260,9 @@ async def stream(
AsyncStreamConnection: A connection object for real-time async streaming.
"""
try:
self._validate_stream_params(workflow_id, execution_id, lang_code, time_zone, keywords, tasks_config)
self._validate_stream_params(
workflow_id, execution_id, lang_code, time_zone, keywords, tasks_config, vad_config
)

# Resolve workflow_id with proper precedence
resolved_workflow_id = self._resolve_workflow_id(workflow_id)
Expand All @@ -259,7 +276,7 @@ async def stream(

# Build query parameters and headers
query, headers = self._build_query_and_headers(
workflow_id, execution_id, lang_code, time_zone, keywords, tasks_config, access_token
workflow_id, execution_id, lang_code, time_zone, keywords, tasks_config, vad_config, access_token
)

url = self._build_url(query)
Expand All @@ -278,6 +295,7 @@ async def transcribe_file(
*,
language: str | None = None,
keywords: dict[str, str] | None = None,
vad_config: VadConfig | None = None,
) -> TranscriptionResponse:
"""Transcribe an audio file and return the transcription result."""

Expand All @@ -290,12 +308,16 @@ async def transcribe_file(
if keywords is not None and not isinstance(keywords, dict):
raise AiolaValidationError("keywords must be a dictionary")

if vad_config is not None and not isinstance(vad_config, dict | VadConfig):
raise AiolaValidationError("vad_config must be a dictionary or a VadConfig object")

try:
# Prepare the form data
files = {"file": file}
data = {
"language": language or "en",
"keywords": json.dumps(keywords or {}),
"vad_config": json.dumps(vad_config or {}),
}

# Create authenticated HTTP client and make request
Expand Down
8 changes: 8 additions & 0 deletions aiola/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,14 @@ class TasksConfig:
TRANSLATION: TranslationPayload | None = None


@dataclass
class VadConfig:
threshold: float | None = None
min_speech_ms: float | None = None
min_silence_ms: float | None = None
max_segment_ms: float | None = None


FileContent = Union[IO[bytes], bytes, str]
File = Union[
# file (or bytes)
Expand Down
Loading