Skip to content

Commit a0dd13b

Browse files
AIRND-154 Change TTS interface to use voice_id (#44)
* support new tts interface * update api to use voice_id instead of persona * fix error handling ןמ אאד בךןקמא * minor fix in async error handling * fix aikido gotcha
1 parent cc67e6d commit a0dd13b

File tree

9 files changed

+169
-162
lines changed

9 files changed

+169
-162
lines changed

README.md

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,7 @@ def create_file():
230230

231231
audio = client.tts.synthesize(
232232
text='Hello, how can I help you today?',
233-
voice='jess',
234-
language='en'
233+
voice_id='en_us_male'
235234
)
236235

237236
with open('./audio.wav', 'wb') as f:
@@ -263,8 +262,7 @@ def stream_tts():
263262

264263
stream = client.tts.stream(
265264
text='Hello, how can I help you today?',
266-
voice='jess',
267-
language='en'
265+
voice_id='en_us_male'
268266
)
269267

270268
audio_chunks = []
@@ -330,8 +328,7 @@ async def create_audio_file():
330328

331329
audio = client.tts.synthesize(
332330
text='Hello, how can I help you today?',
333-
voice='jess',
334-
language='en'
331+
voice_id='en_us_male'
335332
)
336333

337334
with open('./audio.wav', 'wb') as f:
@@ -365,8 +362,7 @@ async def stream_tts():
365362

366363
stream = client.tts.stream(
367364
text='Hello, how can I help you today?',
368-
voice='jess',
369-
language='en'
365+
voice_id='en_us_male'
370366
)
371367

372368
audio_chunks = []

aiola/clients/tts/client.py

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,12 @@ def __init__(self, options: AiolaClientOptions, auth: AuthClient | AsyncAuthClie
2222
def _make_headers() -> dict[str, str]:
2323
return {"Accept": "audio/*"}
2424

25-
def _validate_tts_params(self, text: str, voice: str, language: str | None) -> None:
25+
def _validate_tts_params(self, text: str, voice_id: str) -> None:
2626
"""Validate TTS parameters."""
2727
if not text or not isinstance(text, str):
2828
raise AiolaValidationError("text must be a non-empty string")
29-
if not voice or not isinstance(voice, str):
30-
raise AiolaValidationError("voice must be a non-empty string")
31-
if language is not None and not isinstance(language, str):
32-
raise AiolaValidationError("language must be a string")
29+
if not voice_id or not isinstance(voice_id, str):
30+
raise AiolaValidationError("voice_id must be a non-empty string")
3331

3432

3533
class TtsClient(BaseTts):
@@ -39,9 +37,9 @@ def __init__(self, options: AiolaClientOptions, auth: AuthClient):
3937
super().__init__(options, auth)
4038
self._auth: AuthClient = auth # Type narrowing
4139

42-
def stream(self, *, text: str, voice: str, language: str | None = None) -> Iterator[bytes]:
40+
def stream(self, *, text: str, voice_id: str) -> Iterator[bytes]:
4341
"""Stream synthesized audio in real-time."""
44-
self._validate_tts_params(text, voice, language)
42+
self._validate_tts_params(text, voice_id)
4543

4644
try:
4745
# Create authenticated HTTP client and make the streaming request
@@ -52,13 +50,17 @@ def stream(self, *, text: str, voice: str, language: str | None = None) -> Itera
5250
"/api/tts/stream",
5351
json={
5452
"text": text,
55-
"voice": voice,
56-
"language": language,
53+
"voice_id": voice_id,
5754
},
5855
headers=self._make_headers(),
5956
) as response,
6057
):
61-
response.raise_for_status()
58+
try:
59+
response.raise_for_status()
60+
except httpx.HTTPStatusError:
61+
response.read()
62+
raise
63+
6264
yield from response.iter_bytes()
6365

6466
except AiolaError:
@@ -75,9 +77,9 @@ def stream(self, *, text: str, voice: str, language: str | None = None) -> Itera
7577
except Exception as exc:
7678
raise AiolaError(f"TTS streaming failed: {str(exc)}") from exc
7779

78-
def synthesize(self, *, text: str, voice: str, language: str | None = None) -> Iterator[bytes]:
80+
def synthesize(self, *, text: str, voice_id: str) -> Iterator[bytes]:
7981
"""Synthesize audio and return as iterator of bytes."""
80-
self._validate_tts_params(text, voice, language)
82+
self._validate_tts_params(text, voice_id)
8183

8284
try:
8385
# Create authenticated HTTP client and make the streaming request
@@ -88,13 +90,17 @@ def synthesize(self, *, text: str, voice: str, language: str | None = None) -> I
8890
"/api/tts/synthesize",
8991
json={
9092
"text": text,
91-
"voice": voice,
92-
"language": language,
93+
"voice_id": voice_id,
9394
},
9495
headers=self._make_headers(),
9596
) as response,
9697
):
97-
response.raise_for_status()
98+
try:
99+
response.raise_for_status()
100+
except httpx.HTTPStatusError:
101+
response.read()
102+
raise
103+
98104
yield from response.iter_bytes()
99105

100106
except AiolaError:
@@ -119,9 +125,9 @@ def __init__(self, options: AiolaClientOptions, auth: AsyncAuthClient):
119125
super().__init__(options, auth)
120126
self._auth: AsyncAuthClient = auth # Type narrowing
121127

122-
async def stream(self, *, text: str, voice: str, language: str | None = None) -> AsyncIterator[bytes]:
128+
async def stream(self, *, text: str, voice_id: str) -> AsyncIterator[bytes]:
123129
"""Stream synthesized audio in real-time (async)."""
124-
self._validate_tts_params(text, voice, language)
130+
self._validate_tts_params(text, voice_id)
125131

126132
try:
127133
# Create authenticated HTTP client and make the streaming request
@@ -133,13 +139,17 @@ async def stream(self, *, text: str, voice: str, language: str | None = None) ->
133139
"/api/tts/stream",
134140
json={
135141
"text": text,
136-
"voice": voice,
137-
"language": language,
142+
"voice_id": voice_id,
138143
},
139144
headers=self._make_headers(),
140145
) as response,
141146
):
142-
response.raise_for_status()
147+
try:
148+
response.raise_for_status()
149+
except httpx.HTTPStatusError:
150+
await response.aread()
151+
raise
152+
143153
async for chunk in response.aiter_bytes():
144154
yield chunk
145155

@@ -157,9 +167,9 @@ async def stream(self, *, text: str, voice: str, language: str | None = None) ->
157167
except Exception as exc:
158168
raise AiolaError(f"Async TTS streaming failed: {str(exc)}") from exc
159169

160-
async def synthesize(self, *, text: str, voice: str, language: str | None = None) -> AsyncIterator[bytes]:
170+
async def synthesize(self, *, text: str, voice_id: str) -> AsyncIterator[bytes]:
161171
"""Synthesize audio and return as async iterator of bytes."""
162-
self._validate_tts_params(text, voice, language)
172+
self._validate_tts_params(text, voice_id)
163173

164174
try:
165175
# Create authenticated HTTP client and make the streaming request
@@ -171,13 +181,17 @@ async def synthesize(self, *, text: str, voice: str, language: str | None = None
171181
"/api/tts/synthesize",
172182
json={
173183
"text": text,
174-
"voice": voice,
175-
"language": language,
184+
"voice_id": voice_id,
176185
},
177186
headers=self._make_headers(),
178187
) as response,
179188
):
180-
response.raise_for_status()
189+
try:
190+
response.raise_for_status()
191+
except httpx.HTTPStatusError:
192+
await response.aread()
193+
raise
194+
181195
async for chunk in response.aiter_bytes():
182196
yield chunk
183197

aiola/errors.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@ def __init__(
1616
self,
1717
message: str,
1818
*,
19+
reason: str | None = None,
1920
status: int | None = None,
2021
code: str | None = None,
2122
details: Any | None = None,
2223
) -> None:
2324
super().__init__(message)
2425
self.message: str = message # Keep an explicit attribute – ``Exception`` drops it under ``__str__``
26+
self.reason: str | None = reason
2527
self.status: int | None = status
2628
self.code: str | None = code
2729
self.details: Any | None = details
@@ -38,27 +40,29 @@ def from_response(cls, response: httpx.Response) -> AiolaError:
3840
"""
3941

4042
message: str = f"Request failed with status {response.status_code}"
43+
reason: str | None = None
4144
code: str | None = None
4245
details: Any | None = None
4346

4447
try:
4548
payload = response.json()
4649
if isinstance(payload, dict):
47-
err_payload = payload.get("error", payload)
48-
if isinstance(err_payload, dict):
49-
message = err_payload.get("message", message)
50-
code = err_payload.get("code")
51-
details = err_payload.get("details", err_payload)
50+
reason = payload.get("message")
51+
code = payload.get("code")
52+
details = payload.get("details", payload)
5253
except ValueError:
5354
# Not JSON – try plain text
54-
text = response.text
55-
if text:
56-
message = text
55+
reason = response.text
5756

58-
return cls(message, status=response.status_code, code=code, details=details)
57+
return cls(message, reason=reason, status=response.status_code, code=code, details=details)
5958

6059
def __str__(self) -> str:
61-
return self.message
60+
parts = [self.message]
61+
62+
if self.reason is not None:
63+
parts.append(f"Reason: {self.reason}")
64+
65+
return " | ".join(parts)
6266

6367

6468
class AiolaConnectionError(AiolaError):

examples/tts/README.md

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ def synthesize_to_file():
2121
# Step 3: Synthesize audio to a file
2222
audio_stream = client.tts.synthesize(
2323
text="Hello, how can I help you today?",
24-
voice="jess",
25-
language="en"
24+
voice_id="en_us_male"
2625
)
2726

2827
# Step 4: Save to file
@@ -67,8 +66,7 @@ def main():
6766
def synthesize_to_file():
6867
audio_stream = client.tts.synthesize(
6968
text="Hello, how can I help you today?",
70-
voice="jess",
71-
language="en"
69+
voice_id="en_us_male"
7270
)
7371

7472
# Save to file
@@ -83,8 +81,7 @@ def main():
8381
def stream_tts():
8482
stream = client.tts.stream(
8583
text="Hello, this is a streaming example of text-to-speech synthesis.",
86-
voice="jess",
87-
language="en"
84+
voice_id="en_us_male"
8885
)
8986

9087
# Collect audio chunks
@@ -122,7 +119,7 @@ async def async_tts_example():
122119
result = await AsyncAiolaClient.grant_token(api_key=os.getenv("AIOLA_API_KEY"))
123120
client = AsyncAiolaClient(access_token=result.access_token)
124121

125-
response = await client.tts.synthesize(text="Hello world", voice="jess", language="en")
122+
response = await client.tts.synthesize(text="Hello world", voice_id="en_us_male")
126123

127124
async for chunk in response:
128125
# Process audio chunk

examples/tts/async_tts.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ async def create_audio_file():
1717
# Step 3: Generate audio
1818
audio = client.tts.synthesize(
1919
text='Hello, how can I help you today?',
20-
voice='jess',
21-
language='en'
20+
voice_id='en_us_male'
2221
)
2322

2423
file_path = os.path.join(os.path.dirname(__file__), "async_audio.wav")
@@ -48,8 +47,7 @@ async def stream_tts():
4847
# Step 3: Stream audio
4948
stream = client.tts.stream(
5049
text='Hello, how can I help you today?',
51-
voice='jess',
52-
language='en'
50+
voice_id='en_us_male'
5351
)
5452

5553
audio_chunks = []

examples/tts/tts_file.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ def create_file():
1616
# Step 3: Generate audio
1717
audio = client.tts.synthesize(
1818
text='Hello, how can I help you today?',
19-
voice='jess',
20-
language='en'
19+
voice_id='en_us_male'
2120
)
2221

2322
output_path = os.path.join(os.path.dirname(__file__), "output_audio.wav")

examples/tts/tts_stream.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ def stream_tts():
1616
# Step 3: Stream audio
1717
stream = client.tts.stream(
1818
text='Hello, how can I help you today?',
19-
voice='jess',
20-
language='en'
19+
voice_id='en_us_male'
2120
)
2221

2322
audio_chunks = []

tests/unit/tts/test_tts_client.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,21 @@ def test_tts_stream_makes_expected_http_request(dummy_http):
1414
"""``TtsClient.stream`` should send POST /synthesize/stream and yield audio chunks."""
1515

1616
client = AiolaClient(api_key="k", base_url="https://tts.example")
17-
chunks = list(client.tts.stream(text="Hello", voice="voiceA"))
17+
chunks = list(client.tts.stream(text="Hello", voice_id="en_us_male"))
1818

1919
assert chunks == [b"chunk1", b"chunk2"]
2020

2121
recorded = dummy_http.stream_calls.pop()
2222
assert recorded["method"] == "POST"
2323
assert recorded["path"] == "/api/tts/stream"
24-
assert recorded["json"] == {"text": "Hello", "voice": "voiceA", "language": None}
24+
assert recorded["json"] == {"text": "Hello", "voice_id": "en_us_male"}
2525

2626

2727
def test_tts_synthesize_makes_expected_http_request(dummy_http):
2828
"""``TtsClient.synthesize`` must hit POST /synthesize (non-stream variant)."""
2929

3030
client = AiolaClient(api_key="k")
31-
list(client.tts.synthesize(text="Hi", voice="B")) # exhaust generator
31+
list(client.tts.synthesize(text="Hi", voice_id="de_female")) # exhaust generator
3232

3333
recorded = dummy_http.stream_calls.pop()
3434
assert recorded["path"] == "/api/tts/synthesize"
@@ -44,7 +44,7 @@ async def test_async_tts_stream(dummy_async_http):
4444
"""``AsyncTtsClient.stream`` should work similarly using awaitables."""
4545

4646
client = AsyncAiolaClient(api_key="k")
47-
chunks = [c async for c in client.tts.stream(text="Async", voice="v")] # exhaust
47+
chunks = [c async for c in client.tts.stream(text="Async", voice_id="en_uk_female")] # exhaust
4848

4949
assert chunks == [b"chunk1", b"chunk2"]
5050

@@ -57,7 +57,7 @@ async def test_async_tts_synthesize(dummy_async_http):
5757
"""``AsyncTtsClient.synthesize`` POSTs to /synthesize endpoint."""
5858

5959
client = AsyncAiolaClient(api_key="k")
60-
_ = [c async for c in client.tts.synthesize(text="Async", voice="v")]
60+
_ = [c async for c in client.tts.synthesize(text="Async", voice_id="de_female")]
6161

6262
recorded = dummy_async_http.stream_calls.pop()
6363
assert recorded["path"] == "/api/tts/synthesize"
@@ -95,7 +95,7 @@ def mock_create_authenticated_client(*args, **kwargs):
9595

9696
# Now wrapped in AiolaError instead of raw RuntimeError
9797
with pytest.raises(AiolaError, match="TTS streaming failed"):
98-
list(client.tts.stream(text="x", voice="v"))
98+
list(client.tts.stream(text="x", voice_id="en_us_male"))
9999

100100

101101
@pytest.mark.anyio
@@ -129,5 +129,5 @@ async def mock_create_async_authenticated_client(*args, **kwargs):
129129

130130
# Now wrapped in AiolaError instead of raw RuntimeError
131131
with pytest.raises(AiolaError, match="Async TTS streaming failed"):
132-
async for _ in client.tts.stream(text="fail", voice="v"):
132+
async for _ in client.tts.stream(text="fail", voice_id="en_us_male"):
133133
pass

0 commit comments

Comments
 (0)