Commit 83b74ef

Merge pull request #82 from hyscale-lab/feat/tbt-calculation
TBT calculation for non-streaming inferences
2 parents ad34240 + 29ce11a commit 83b74ef
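
A non-streaming call returns no per-token timestamps, so this PR approximates time between tokens (TBT) as total latency divided by the number of generated tokens, with tokens per second (TPS) as the reciprocal view. A minimal sketch of the pattern each provider now applies (the function name and values are illustrative, not the repository's API):

```python
def tbt_and_tps(elapsed_seconds, completion_tokens):
    """Approximate per-token metrics from one non-streaming inference.

    TBT here is an average (total latency spread evenly across tokens),
    not a measured inter-token distribution; max(..., 1) avoids a zero
    division when the provider reports no tokens.
    """
    total_tokens = completion_tokens or 0          # treat None/missing as 0
    tbt = elapsed_seconds / max(total_tokens, 1)   # avg seconds per token
    tps = total_tokens / elapsed_seconds           # tokens per second
    return tbt, tps

# Illustrative values: 48 generated tokens in 1.5 seconds.
tbt, tps = tbt_and_tps(1.5, 48)
print(f"Avg TBT: {tbt:.4f}s, TPS: {tps:.2f}")      # Avg TBT: 0.0312s, TPS: 32.00
```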

File tree: 11 files changed (+158, -104 lines)


README.md

Lines changed: 2 additions & 2 deletions
````diff
@@ -46,13 +46,13 @@ PERPLEXITY_AI_API="your-perplexity-ai-api-key"
 HYPERBOLIC_API="your-hyperbolic-api-key"
 GROQ_API_KEY="your-groq-api-key"
 GEMINI_API_KEY="your-gemini-api-key"
-AZURE_LLAMA_8B_API="your-azure-llama-8b-api-key"
-AZURE_LLAMA_3.1_70B_API="your-azure-llama-70b-api-key"
 MISTRAL_LARGE_API="your-mistral-large-api-key"
 AWS_BEDROCK_ACCESS_KEY_ID="your-aws-bedrock-access-key-id"
 AWS_BEDROCK_SECRET_ACCESS_KEY="your-aws-bedrock-secret-key"
 AWS_BEDROCK_REGION="your-aws-bedrock-region"
 DYNAMODB_ENDPOINT_URL="your-dynamodb-endpoint-url"
+AZURE_AI_ENDPOINT="your-azure-ai-endpoint"
+AZURE_AI_API_KEY="your-azure-ai-api-key"
 ```

 ## **Usage**
````

providers/anthropic_provider.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -68,9 +68,20 @@ def perform_inference(self, model, prompt, max_output=100, verbosity=True):
             timeout=500,
         )
         elapsed = timer() - start
+
+        usage = getattr(response, "usage", None)
+        total_tokens = (getattr(usage, "output_tokens", 0) or 0) if usage else 0
+
+        tbt = elapsed / max(total_tokens, 1)
+        tps = (total_tokens / elapsed)
+
         self.log_metrics(model, "response_times", elapsed)
+        self.log_metrics(model, "totaltokens", total_tokens)
+        self.log_metrics(model, "timebetweentokens", tbt)
+        self.log_metrics(model, "tps", tps)
         # Process and display the response
         if verbosity:
+            print(f"Tokens: {total_tokens}, Avg TBT: {tbt:.4f}s, TPS: {tps:.2f}")
             self.display_response(response, elapsed)
         return elapsed
```

providers/aws_provider.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -72,8 +72,18 @@ def perform_inference(self, model, prompt, max_output=100, verbosity=True):
         model_response = json.loads(response["body"].read())
         generated_text = model_response.get("generation", "")

+        total_tokens = model_response.get("generation_token_count") or 0
+
+        tbt = total_time / max(total_tokens, 1)
+        tps = (total_tokens / total_time)
+
+        self.log_metrics(model, "totaltokens", total_tokens)
+        self.log_metrics(model, "timebetweentokens", tbt)
+        self.log_metrics(model, "tps", tps)
+
         if verbosity:
             print(f"[INFO] Total response time: {total_time:.4f} seconds")
+            print(f"[INFO] Tokens: {total_tokens}, Avg TBT: {tbt:.4f}s, TPS: {tps:.2f}")
             print("[INFO] Generated response:")
             print(generated_text)
```

providers/azure_provider.py

Lines changed: 81 additions & 97 deletions
```diff
@@ -1,83 +1,93 @@
 import os
-import requests
 import numpy as np
 from providers.base_provider import ProviderInterface
 from time import perf_counter as timer
-import re
+from azure.ai.inference import ChatCompletionsClient
+from azure.ai.inference.models import SystemMessage, UserMessage
+from azure.core.credentials import AzureKeyCredential


 class Azure(ProviderInterface):
     def __init__(self):
         """Initialize AzureProvider with required API information."""
         super().__init__()

+        self.endpoint = os.getenv("AZURE_AI_ENDPOINT")
+        self.api_key = os.getenv("AZURE_AI_API_KEY")
+
         # Map model names to Azure model IDs
         self.model_map = {
             # "mistral-7b-instruct-v0.1": "mistral-7b-instruct-v0.1",
-            "meta-llama-3.1-8b-instruct": "Meta-Llama-3-1-8B-Instruct-fyp",
+            "meta-llama-3.1-8b-instruct": "Meta-Llama-3.1-8B-Instruct-fyp",
             "meta-llama-3.1-70b-instruct": "Meta-Llama-3-1-70B-Instruct-fyp",
             "mistral-large": "Mistral-Large-2411-yatcd",
             "common-model": "Mistral-Large-2411-yatcd",
         }

-        # Define API keys for each model
-        self.model_api_keys = {
-            # "mistral-7b-instruct-v0.1": os.environ.get("MISTRAL_API_KEY"),
-            "meta-llama-3.1-8b-instruct": os.environ.get("AZURE_LLAMA_8B_API"),
-            "meta-llama-3.1-70b-instruct": os.environ.get("AZURE_LLAMA_3.1_70B_API"),
-            "mistral-large": os.environ.get("MISTRAL_LARGE_API"),
-            "common-model": os.environ.get("MISTRAL_LARGE_API")
-        }
+        self._client = None
+
+    def _ensure_client(self):
+        """
+        Create the Azure client only when first used.
+        Raise a clear error if env vars are missing.
+        """
+        if self._client is not None:
+            return
+
+        if not self.api_key or not isinstance(self.api_key, str):
+            raise RuntimeError(
+                "Azure provider misconfigured: AZURE_AI_API_KEY is missing or not a string."
+            )
+        if not self.endpoint:
+            raise RuntimeError(
+                "Azure provider misconfigured: AZURE_AI_ENDPOINT is missing."
+            )
+
+        credential = AzureKeyCredential(self.api_key)
+        self._client = ChatCompletionsClient(
+            endpoint=self.endpoint,
+            credential=credential,
+            api_version="2024-05-01-preview",
+        )

     def get_model_name(self, model):
         """Retrieve the model name based on the input key."""
         return self.model_map.get(model, None)

-    def get_model_api_key(self, model):
-        """Retrieve the API key for a specific model."""
-        api_key = self.model_api_keys.get(model)
-        if not api_key:
-            raise ValueError(
-                f"No API key found for model '{model}'. Ensure it is set in environment variables."
-            )
-        return api_key
-
     def perform_inference(self, model, prompt, max_output=100, verbosity=True):
         """Performs non-streaming inference request to Azure."""
         try:
+            self._ensure_client()
+            client = self._client
             model_id = self.get_model_name(model)
-            api_key = self.get_model_api_key(model)
             if model_id is None:
                 print(f"Model {model} not available.")
                 return None
             start_time = timer()
-            endpoint = f"https://{model_id}.eastus.models.ai.azure.com/chat/completions"
-            response = requests.post(
-                f"{endpoint}",
-                headers={
-                    "Authorization": f"Bearer {api_key}",
-                    "Content-Type": "application/json",
-                },
-                json={
-                    "messages": [
-                        {"role": "system", "content": self.system_prompt},
-                        {"role": "user", "content": prompt},
-                    ],
-                    "max_tokens": max_output,
-                },
-                timeout=500,
+            response = client.complete(
+                messages=[
+                    SystemMessage(content=self.system_prompt),
+                    UserMessage(content=prompt),
+                ],
+                max_tokens=max_output,
+                model=model_id
             )
             elapsed = timer() - start_time
-            if response.status_code != 200:
-                print(f"Error: {response.status_code} - {response.text}")
-                return None

-            # Parse and display response
-            inference = response.json()
+            usage = response.get("usage")
+            total_tokens = usage.get("completion_tokens") or 0
+            tbt = elapsed / max(total_tokens, 1)
+            tps = (total_tokens / elapsed)
+
             self.log_metrics(model, "response_times", elapsed)
+            self.log_metrics(model, "totaltokens", total_tokens)
+            self.log_metrics(model, "timebetweentokens", tbt)
+            self.log_metrics(model, "tps", tps)
+
             if verbosity:
-                print(f"Response: {inference['choices'][0]['message']['content']}")
-                return inference
+                print(f"Tokens: {total_tokens}, Avg TBT: {tbt:.4f}s, TPS: {tps:.2f}")
+                print(f"Response: {response['choices'][0]['message']['content']}")
+            return response

         except Exception as e:
             print(f"[ERROR] Inference failed for model '{model}': {e}")
@@ -87,72 +97,46 @@ def perform_inference_streaming(
         self, model, prompt, max_output=100, verbosity=True
     ):
         """Performs streaming inference request to Azure."""
+        self._ensure_client()
+        client = self._client
         model_id = self.get_model_name(model)
-        api_key = self.get_model_api_key(model)
         if model_id is None:
             print(f"Model {model} not available.")
             return None

         inter_token_latencies = []
-        endpoint = f"https://{model_id}.eastus.models.ai.azure.com/chat/completions"
         start_time = timer()
         try:
-            response = requests.post(
-                f"{endpoint}",
-                headers={
-                    "Authorization": f"Bearer {api_key}",
-                    "Content-Type": "application/json",
-                },
-                json={
-                    "messages": [
-                        # {"role": "system", "content": self.system_prompt + "\nThe number appended at the end is not important."},
-                        # {"role": "user", "content": prompt + " " + str(timer())},
-                        {"role": "system", "content": self.system_prompt},
-                        {"role": "user", "content": prompt},
-                    ],
-                    "max_tokens": max_output,
-                    "stream": True,
-                },
-                stream=True,
-                timeout=500,
-            )
-
             first_token_time = None
-            for line in response.iter_lines():
-                if line:
-                    # print(line)
-                    if first_token_time is None:
-                        # print(line)
-                        first_token_time = timer()
-                        ttft = first_token_time - start_time
-                        prev_token_time = first_token_time
-                        if verbosity:
+            with client.complete(
+                stream=True,
+                messages=[
+                    SystemMessage(content=self.system_prompt),
+                    UserMessage(content=prompt),
+                ],
+                max_tokens=max_output,
+                model=model_id
+            ) as response:
+                for event in response:
+                    if not event.choices or not event.choices[0].delta:
+                        continue
+
+                    delta = event.choices[0].delta
+                    if delta.content:
+                        if first_token_time is None:
+                            first_token_time = timer()
+                            ttft = first_token_time - start_time
+                            prev_token_time = first_token_time
                             print(f"##### Time to First Token (TTFT): {ttft:.4f} seconds\n")

-                    line_str = line.decode("utf-8").strip()
-
-                    if line_str == "data: [DONE]":
-                        # print(line_str)
-                        # print("here")
-                        total_time = timer() - start_time
-                        break
-
-                    # Capture token timing
-                    time_to_next_token = timer()
-                    inter_token_latency = time_to_next_token - prev_token_time
-                    prev_token_time = time_to_next_token
-                    inter_token_latencies.append(inter_token_latency)
-
-                    # Display token if verbosity is enabled
-                    match = re.search(r'"content"\s*:\s*"(.*?)"', line_str)
-                    if match:
-                        print(match.group(1), end="")
-                    # if verbosity:
-                    #     if len(inter_token_latencies) < 20:
-                    #         print(line_str[19:].split('"')[5], end="")
-                    #     elif len(inter_token_latencies) == 20:
-                    #         print("...")
+                        time_to_next_token = timer()
+                        inter_token_latency = time_to_next_token - prev_token_time
+                        prev_token_time = time_to_next_token
+                        inter_token_latencies.append(inter_token_latency)
+
+                        print(delta.content, end="", flush=True)

+            total_time = timer() - start_time
             # Calculate total metrics

             if verbosity:
```
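
For reference, the rewritten provider now goes through the azure-ai-inference SDK (pinned in requirements.txt below). A minimal standalone sketch of that client surface, assuming the AZURE_AI_ENDPOINT and AZURE_AI_API_KEY variables from the README and a deployment name taken from the model map:

```python
import os

from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import SystemMessage, UserMessage
from azure.core.credentials import AzureKeyCredential

# Build the client once and reuse it, mirroring _ensure_client above.
client = ChatCompletionsClient(
    endpoint=os.environ["AZURE_AI_ENDPOINT"],
    credential=AzureKeyCredential(os.environ["AZURE_AI_API_KEY"]),
    api_version="2024-05-01-preview",
)

# Non-streaming completion against one of the mapped deployments.
response = client.complete(
    messages=[
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="Say hello."),
    ],
    max_tokens=100,
    model="Mistral-Large-2411-yatcd",
)
print(response.choices[0].message.content)
print(response.usage.completion_tokens)
```

Constructing the client lazily in `_ensure_client` keeps instantiation cheap and raises a clear RuntimeError only when an Azure call is actually attempted with missing configuration.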

providers/base_provider.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -37,8 +37,22 @@ def perform_inference(self, model, prompt, max_output=100, verbosity=True):
             timeout=(1, 2)
         )
         elapsed = timer() - start
+
+        usage = getattr(response, "usage", None)
+        total_tokens = 0
+        if usage:
+            total_tokens = getattr(usage, "completion_tokens", None) or getattr(usage, "output_tokens", None) or 0
+
+        tbt = elapsed / max(total_tokens, 1)
+        tps = (total_tokens / elapsed)
+
+        self.log_metrics(model, "totaltokens", total_tokens)
+        self.log_metrics(model, "timebetweentokens", tbt)
+        self.log_metrics(model, "tps", tps)
         self.log_metrics(model, "response_times", elapsed)
+
        if verbosity:
+            print(f"Tokens: {total_tokens}, Avg TBT: {tbt:.4f}s, TPS: {tps:.2f}")
             self.display_response(response, elapsed)
         return elapsed
```
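
The chained getattr fallback exists because SDKs disagree on the field name: OpenAI-compatible clients report usage.completion_tokens, while the Anthropic SDK reports usage.output_tokens. The same logic as a standalone helper (the function name is illustrative):

```python
def completion_token_count(response):
    """Extract the generated-token count from an SDK response object,
    trying the OpenAI-style field first, then the Anthropic-style one."""
    usage = getattr(response, "usage", None)
    if not usage:
        return 0
    return (
        getattr(usage, "completion_tokens", None)  # OpenAI-compatible SDKs
        or getattr(usage, "output_tokens", None)   # Anthropic SDK
        or 0                                       # field present but None/0
    )
```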

providers/cloudflare_provider.py

Lines changed: 14 additions & 3 deletions
```diff
@@ -61,14 +61,25 @@ def perform_inference(self, model, prompt, max_output=100, verbosity=True):
         )

         elapsed = timer() - start_time
-        # print("request sucess")
-        # log response times metric
-        self.log_metrics(model, "response_times", elapsed)

         inference = response.json()
+
+        meta = inference.get("result", {})
+        usage = meta.get("usage", {})
+        total_tokens = usage.get("completion_tokens") or 0
+
+        tbt = elapsed / max(total_tokens, 1)
+        tps = (total_tokens / elapsed)
+
+        self.log_metrics(model, "response_times", elapsed)
+        self.log_metrics(model, "totaltokens", total_tokens)
+        self.log_metrics(model, "timebetweentokens", tbt)
+        self.log_metrics(model, "tps", tps)
+
         print(inference)
         # logging.debug(inference["result"]["response"])
         if verbosity:
+            print(f"Tokens: {total_tokens}, Avg TBT: {tbt:.4f}s, TPS: {tps:.2f}")
             print(inference["result"]["response"][:50])

         print(f"#### _Generated in *{elapsed:.2f}* seconds_")
```

providers/google_provider.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -60,8 +60,19 @@ def perform_inference(self, model, prompt, max_output=100, verbosity=True):
         )
         elapsed = timer() - start_time

+        usage = getattr(response, "usage_metadata", None)
+        total_tokens = (getattr(usage, "candidates_token_count", 0) or 0) if usage else 0
+
+        tbt = elapsed / max(total_tokens, 1)
+        tps = (total_tokens / elapsed)
+
         self.log_metrics(model, "response_times", elapsed)
+        self.log_metrics(model, "totaltokens", total_tokens)
+        self.log_metrics(model, "timebetweentokens", tbt)
+        self.log_metrics(model, "tps", tps)
+
         if verbosity:
+            print(f"Tokens: {total_tokens}, Avg TBT: {tbt:.4f}s, TPS: {tps:.2f}")
             print(response.text)
             print(f"\nGenerated in {elapsed:.2f} seconds")
         return elapsed
```

providers/vllm_provider.py

Lines changed: 11 additions & 1 deletion
```diff
@@ -49,10 +49,20 @@ def perform_inference(self, model, prompt, vllm_ip, max_output=100, verbosity=True):
         )
         elapsed = timer() - start_time

-        # Log response times metric
+        data = response.json()
+        usage = data.get("usage") or {}
+        total_tokens = usage.get("completion_tokens") or 0
+
+        tbt = elapsed / max(total_tokens, 1)
+        tps = (total_tokens / elapsed)
+
+        self.log_metrics(model, "totaltokens", total_tokens)
+        self.log_metrics(model, "timebetweentokens", tbt)
+        self.log_metrics(model, "tps", tps)
         self.log_metrics(model, "response_times", elapsed)

         if verbosity:
+            print(f"Tokens: {total_tokens}, Avg TBT: {tbt:.4f}s, TPS: {tps:.2f}")
             print(f"#### _Generated in *{elapsed:.2f}* seconds_")

         print(response)
```
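
Unlike the SDK-based providers, vLLM's OpenAI-compatible endpoint returns plain JSON, so the count comes from dict lookups; a missing usage block or field yields None, which the `or` fallbacks normalize before the division guard. A self-contained sketch with an assumed payload:

```python
# Assumed OpenAI-compatible payload shape; the values are illustrative.
data = {"usage": {"prompt_tokens": 12, "completion_tokens": 48, "total_tokens": 60}}
elapsed = 1.5  # seconds, assumed

usage = data.get("usage") or {}                     # {} if "usage" is absent
total_tokens = usage.get("completion_tokens") or 0  # 0 if the field is absent
tbt = elapsed / max(total_tokens, 1)                # 0.03125 s per token
tps = total_tokens / elapsed                        # 32.0 tokens per second
```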

requirements.txt

Lines changed: 2 additions & 1 deletion
```diff
@@ -18,4 +18,5 @@ groq==0.13.0
 google-generativeai==0.8.3
 fastapi==0.115.6
 uvicorn==0.32.1
-pytest-asyncio==0.25.0
+pytest-asyncio==0.25.0
+azure-ai-inference==1.0.0b9
```
