
Commit 29ce11a

restructured azure API calls to match Azure AI Foundry docs
1 parent 3df277c commit 29ce11a

File tree

2 files changed: +72 -84 lines changed


providers/azure_provider.py

Lines changed: 70 additions & 83 deletions
@@ -1,9 +1,10 @@
 import os
-import requests
 import numpy as np
 from providers.base_provider import ProviderInterface
 from time import perf_counter as timer
-import re
+from azure.ai.inference import ChatCompletionsClient
+from azure.ai.inference.models import SystemMessage, UserMessage
+from azure.core.credentials import AzureKeyCredential


 class Azure(ProviderInterface):
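
Taken together, the new imports support key-based client construction along the lines of the minimal sketch below. The AZURE_AI_ENDPOINT and AZURE_AI_API_KEY names mirror the env vars validated in _ensure_client later in this diff, and the api_version matches the one this commit pins; the `client` variable name is illustrative.

import os
from azure.ai.inference import ChatCompletionsClient
from azure.core.credentials import AzureKeyCredential

# Key-based client, mirroring what _ensure_client() builds lazily below.
client = ChatCompletionsClient(
    endpoint=os.environ["AZURE_AI_ENDPOINT"],
    credential=AzureKeyCredential(os.environ["AZURE_AI_API_KEY"]),
    api_version="2024-05-01-preview",
)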
@@ -17,51 +18,63 @@ def __init__(self):
         # Map model names to Azure model IDs
         self.model_map = {
             # "mistral-7b-instruct-v0.1": "mistral-7b-instruct-v0.1",
-            "meta-llama-3.1-8b-instruct": "Meta-Llama-3-1-8B-Instruct-fyp",
+            "meta-llama-3.1-8b-instruct": "Meta-Llama-3.1-8B-Instruct-fyp",
             "meta-llama-3.1-70b-instruct": "Meta-Llama-3-1-70B-Instruct-fyp",
             "mistral-large": "Mistral-Large-2411-yatcd",
             "common-model": "Mistral-Large-2411-yatcd",
         }

+        self._client = None
+
+    def _ensure_client(self):
+        """
+        Create the Azure client only when first used.
+        Raise a clear error if env vars are missing.
+        """
+        if self._client is not None:
+            return
+
+        if not self.api_key or not isinstance(self.api_key, str):
+            raise RuntimeError(
+                "Azure provider misconfigured: AZURE_AI_API_KEY is missing or not a string."
+            )
+        if not self.endpoint:
+            raise RuntimeError(
+                "Azure provider misconfigured: AZURE_AI_ENDPOINT is missing."
+            )
+
+        credential = AzureKeyCredential(self.api_key)
+        self._client = ChatCompletionsClient(
+            endpoint=self.endpoint,
+            credential=credential,
+            api_version="2024-05-01-preview",
+        )
+
     def get_model_name(self, model):
         """Retrieve the model name based on the input key."""
         return self.model_map.get(model, None)

     def perform_inference(self, model, prompt, max_output=100, verbosity=True):
         """Performs non-streaming inference request to Azure."""
         try:
+            self._ensure_client()
+            client = self._client
             model_id = self.get_model_name(model)
             if model_id is None:
                 print(f"Model {model} not available.")
                 return None
-            endpoint = (self.endpoint or "https://example.invalid").rstrip("/")
-            api_key = self.api_key
             start_time = timer()
-            response = requests.post(
-                f"{endpoint}",
-                headers={
-                    "Authorization": f"Bearer {api_key}",
-                    "Content-Type": "application/json",
-                },
-                json={
-                    "model": model_id,
-                    "messages": [
-                        {"role": "system", "content": self.system_prompt},
-                        {"role": "user", "content": prompt},
-                    ],
-                    "max_tokens": max_output,
-                },
-                timeout=500,
+            response = client.complete(
+                messages=[
+                    SystemMessage(content=self.system_prompt),
+                    UserMessage(content=prompt),
+                ],
+                max_tokens=max_output,
+                model=model_id
             )
             elapsed = timer() - start_time
-            if response.status_code != 200:
-                print(f"Error: {response.status_code} - {response.text}")
-                return None

-            # Parse and display response
-            inference = response.json()
-
-            usage = inference.get("usage")
+            usage = response.get("usage")
             total_tokens = usage.get("completion_tokens") or 0
             tbt = elapsed / max(total_tokens, 1)
             tps = (total_tokens / elapsed)
@@ -73,8 +86,8 @@ def perform_inference(self, model, prompt, max_output=100, verbosity=True):

             if verbosity:
                 print(f"Tokens: {total_tokens}, Avg TBT: {tbt:.4f}s, TPS: {tps:.2f}")
-                print(f"Response: {inference['choices'][0]['message']['content']}")
-            return inference
+                print(f"Response: {response['choices'][0]['message']['content']}")
+            return response

         except Exception as e:
             print(f"[ERROR] Inference failed for model '{model}': {e}")
@@ -84,72 +97,46 @@ def perform_inference_streaming(
         self, model, prompt, max_output=100, verbosity=True
     ):
         """Performs streaming inference request to Azure."""
+        self._ensure_client()
+        client = self._client
         model_id = self.get_model_name(model)
-        api_key = self.get_model_api_key(model)
         if model_id is None:
             print(f"Model {model} not available.")
             return None

         inter_token_latencies = []
-        endpoint = f"https://{model_id}.eastus.models.ai.azure.com/chat/completions"
         start_time = timer()
         try:
-            response = requests.post(
-                f"{endpoint}",
-                headers={
-                    "Authorization": f"Bearer {api_key}",
-                    "Content-Type": "application/json",
-                },
-                json={
-                    "messages": [
-                        # {"role": "system", "content": self.system_prompt + "\nThe number appended at the end is not important."},
-                        # {"role": "user", "content": prompt + " " + str(timer())},
-                        {"role": "system", "content": self.system_prompt},
-                        {"role": "user", "content": prompt},
-                    ],
-                    "max_tokens": max_output,
-                    "stream": True,
-                },
-                stream=True,
-                timeout=500,
-            )
-
             first_token_time = None
-            for line in response.iter_lines():
-                if line:
-                    # print(line)
-                    if first_token_time is None:
-                        # print(line)
-                        first_token_time = timer()
-                        ttft = first_token_time - start_time
-                        prev_token_time = first_token_time
-                        if verbosity:
+            with client.complete(
+                stream=True,
+                messages=[
+                    SystemMessage(content=self.system_prompt),
+                    UserMessage(content=prompt),
+                ],
+                max_tokens=max_output,
+                model=model_id
+            ) as response:
+                for event in response:
+                    if not event.choices or not event.choices[0].delta:
+                        continue
+
+                    delta = event.choices[0].delta
+                    if delta.content:
+                        if first_token_time is None:
+                            first_token_time = timer()
+                            ttft = first_token_time - start_time
+                            prev_token_time = first_token_time
                             print(f"##### Time to First Token (TTFT): {ttft:.4f} seconds\n")

-                    line_str = line.decode("utf-8").strip()
-
-                    if line_str == "data: [DONE]":
-                        # print(line_str)
-                        # print("here")
-                        total_time = timer() - start_time
-                        break
-
-                    # Capture token timing
-                    time_to_next_token = timer()
-                    inter_token_latency = time_to_next_token - prev_token_time
-                    prev_token_time = time_to_next_token
-                    inter_token_latencies.append(inter_token_latency)
-
-                    # Display token if verbosity is enabled
-                    match = re.search(r'"content"\s*:\s*"(.*?)"', line_str)
-                    if match:
-                        print(match.group(1), end="")
-                    # if verbosity:
-                    #     if len(inter_token_latencies) < 20:
-                    #         print(line_str[19:].split('"')[5], end="")
-                    #     elif len(inter_token_latencies) == 20:
-                    #         print("...")
+                        time_to_next_token = timer()
+                        inter_token_latency = time_to_next_token - prev_token_time
+                        prev_token_time = time_to_next_token
+                        inter_token_latencies.append(inter_token_latency)
+
+                        print(delta.content, end="", flush=True)

+            total_time = timer() - start_time
             # Calculate total metrics

             if verbosity:
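
Likewise, a standalone sketch of the streaming path, showing how TTFT and inter-token latency fall out of the event loop. It reuses the illustrative `client` from above; the with-statement and delta fields follow the event shape the diff itself iterates over, and the model ID and prompt are again illustrative.

from time import perf_counter as timer
from azure.ai.inference.models import SystemMessage, UserMessage

start = timer()
first_token_time = None
prev = None
inter_token_latencies = []

# Streaming completion: measure time-to-first-token (TTFT) and the gaps
# between consecutive content chunks (inter-token latency).
with client.complete(
    stream=True,
    messages=[
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="Count from one to ten."),
    ],
    max_tokens=100,
    model="Mistral-Large-2411-yatcd",
) as stream:
    for event in stream:
        if not event.choices or not event.choices[0].delta:
            continue
        content = event.choices[0].delta.content
        if not content:
            continue
        now = timer()
        if first_token_time is None:
            first_token_time = now
            print(f"TTFT: {now - start:.4f}s")
        else:
            inter_token_latencies.append(now - prev)
        prev = now
        print(content, end="", flush=True)

if inter_token_latencies:
    avg_itl = sum(inter_token_latencies) / len(inter_token_latencies)
    print(f"\nAvg inter-token latency: {avg_itl:.4f}s")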

requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -18,4 +18,5 @@ groq==0.13.0
 google-generativeai==0.8.3
 fastapi==0.115.6
 uvicorn==0.32.1
-pytest-asyncio==0.25.0
+pytest-asyncio==0.25.0
+azure-ai-inference==1.0.0b9
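
A quick post-install check that the newly pinned SDK resolves after `pip install -r requirements.txt` (assuming azure-ai-inference exposes __version__, as Azure SDK packages conventionally do):

# Sanity check for the new azure-ai-inference pin.
import azure.ai.inference

print(azure.ai.inference.__version__)  # expect "1.0.0b9" per the pin above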
