 import os
-import requests
 import numpy as np
 from providers.base_provider import ProviderInterface
 from time import perf_counter as timer
-import re
+from azure.ai.inference import ChatCompletionsClient
+from azure.ai.inference.models import SystemMessage, UserMessage
+from azure.core.credentials import AzureKeyCredential


 class Azure(ProviderInterface):
@@ -17,51 +18,63 @@ def __init__(self):
         # Map model names to Azure model IDs
         self.model_map = {
             # "mistral-7b-instruct-v0.1": "mistral-7b-instruct-v0.1",
-            "meta-llama-3.1-8b-instruct": "Meta-Llama-3-1-8B-Instruct-fyp",
+            "meta-llama-3.1-8b-instruct": "Meta-Llama-3.1-8B-Instruct-fyp",
             "meta-llama-3.1-70b-instruct": "Meta-Llama-3-1-70B-Instruct-fyp",
             "mistral-large": "Mistral-Large-2411-yatcd",
             "common-model": "Mistral-Large-2411-yatcd",
         }

+        self._client = None
+
+    def _ensure_client(self):
+        """
+        Create the Azure client only when first used.
+        Raise a clear error if env vars are missing.
+        """
+        if self._client is not None:
+            return
+
+        if not self.api_key or not isinstance(self.api_key, str):
+            raise RuntimeError(
+                "Azure provider misconfigured: AZURE_AI_API_KEY is missing or not a string."
+            )
+        if not self.endpoint:
+            raise RuntimeError(
+                "Azure provider misconfigured: AZURE_AI_ENDPOINT is missing."
+            )
+
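+        # Key-based auth; the (preview) API version is pinned explicitly
+        # below rather than left to the SDK default.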
+        credential = AzureKeyCredential(self.api_key)
+        self._client = ChatCompletionsClient(
+            endpoint=self.endpoint,
+            credential=credential,
+            api_version="2024-05-01-preview",
+        )
+
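+    # Configuration sketch (assumption: __init__, elided from this diff,
+    # reads the env vars named in the errors above; the endpoint shape
+    # mirrors the host hard-coded in the removed requests-based code):
+    #   export AZURE_AI_API_KEY="<key>"
+    #   export AZURE_AI_ENDPOINT="https://<deployment>.eastus.models.ai.azure.com"
+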
     def get_model_name(self, model):
         """Retrieve the model name based on the input key."""
         return self.model_map.get(model, None)

     def perform_inference(self, model, prompt, max_output=100, verbosity=True):
         """Performs a non-streaming inference request to Azure."""
         try:
+            self._ensure_client()
+            client = self._client
             model_id = self.get_model_name(model)
             if model_id is None:
                 print(f"Model {model} not available.")
                 return None
-            endpoint = (self.endpoint or "https://example.invalid").rstrip("/")
-            api_key = self.api_key
             start_time = timer()
-            response = requests.post(
-                f"{endpoint}",
-                headers={
-                    "Authorization": f"Bearer {api_key}",
-                    "Content-Type": "application/json",
-                },
-                json={
-                    "model": model_id,
-                    "messages": [
-                        {"role": "system", "content": self.system_prompt},
-                        {"role": "user", "content": prompt},
-                    ],
-                    "max_tokens": max_output,
-                },
-                timeout=500,
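+            # Same request, now issued through the azure-ai-inference SDK client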
+            response = client.complete(
+                messages=[
+                    SystemMessage(content=self.system_prompt),
+                    UserMessage(content=prompt),
+                ],
+                max_tokens=max_output,
+                model=model_id,
             )
             elapsed = timer() - start_time
-            if response.status_code != 200:
-                print(f"Error: {response.status_code} - {response.text}")
-                return None

-            # Parse and display response
-            inference = response.json()
-
-            usage = inference.get("usage")
+            usage = response.get("usage")
             total_tokens = usage.get("completion_tokens") or 0
             tbt = elapsed / max(total_tokens, 1)
             tps = total_tokens / elapsed
@@ -73,8 +86,8 @@ def perform_inference(self, model, prompt, max_output=100, verbosity=True):

             if verbosity:
                 print(f"Tokens: {total_tokens}, Avg TBT: {tbt:.4f}s, TPS: {tps:.2f}")
-                print(f"Response: {inference['choices'][0]['message']['content']}")
-            return inference
+                print(f"Response: {response['choices'][0]['message']['content']}")
+            return response

         except Exception as e:
             print(f"[ERROR] Inference failed for model '{model}': {e}")
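+
+    # Usage sketch (hypothetical caller, not part of this module):
+    #   provider = Azure()
+    #   result = provider.perform_inference("mistral-large", "Say hi", max_output=64)
+    #   # On success, result["choices"][0]["message"]["content"] holds the text.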
@@ -84,72 +97,46 @@ def perform_inference_streaming(
         self, model, prompt, max_output=100, verbosity=True
     ):
         """Performs a streaming inference request to Azure."""
+        self._ensure_client()
+        client = self._client
         model_id = self.get_model_name(model)
-        api_key = self.get_model_api_key(model)
         if model_id is None:
             print(f"Model {model} not available.")
             return None

         inter_token_latencies = []
-        endpoint = f"https://{model_id}.eastus.models.ai.azure.com/chat/completions"
         start_time = timer()
         try:
-            response = requests.post(
-                f"{endpoint}",
-                headers={
-                    "Authorization": f"Bearer {api_key}",
-                    "Content-Type": "application/json",
-                },
-                json={
-                    "messages": [
-                        # {"role": "system", "content": self.system_prompt + "\nThe number appended at the end is not important."},
-                        # {"role": "user", "content": prompt + " " + str(timer())},
-                        {"role": "system", "content": self.system_prompt},
-                        {"role": "user", "content": prompt},
-                    ],
-                    "max_tokens": max_output,
-                    "stream": True,
-                },
-                stream=True,
-                timeout=500,
-            )
-
             first_token_time = None
-            for line in response.iter_lines():
-                if line:
-                    # print(line)
-                    if first_token_time is None:
-                        # print(line)
-                        first_token_time = timer()
-                        ttft = first_token_time - start_time
-                        prev_token_time = first_token_time
-                        if verbosity:
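+            # Stream deltas via the SDK; the context manager closes the
+            # underlying HTTP response when the loop exits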
+            with client.complete(
+                stream=True,
+                messages=[
+                    SystemMessage(content=self.system_prompt),
+                    UserMessage(content=prompt),
+                ],
+                max_tokens=max_output,
+                model=model_id,
+            ) as response:
+                for event in response:
+                    if not event.choices or not event.choices[0].delta:
+                        continue
+
+                    delta = event.choices[0].delta
+                    if delta.content:
+                        if first_token_time is None:
+                            first_token_time = timer()
+                            ttft = first_token_time - start_time
+                            prev_token_time = first_token_time
                             print(f"##### Time to First Token (TTFT): {ttft:.4f} seconds\n")

-                    line_str = line.decode("utf-8").strip()
-
-                    if line_str == "data: [DONE]":
-                        # print(line_str)
-                        # print("here")
-                        total_time = timer() - start_time
-                        break
-
-                    # Capture token timing
-                    time_to_next_token = timer()
-                    inter_token_latency = time_to_next_token - prev_token_time
-                    prev_token_time = time_to_next_token
-                    inter_token_latencies.append(inter_token_latency)
-
-                    # Display token if verbosity is enabled
-                    match = re.search(r'"content"\s*:\s*"(.*?)"', line_str)
-                    if match:
-                        print(match.group(1), end="")
-                    # if verbosity:
-                    #     if len(inter_token_latencies) < 20:
-                    #         print(line_str[19:].split('"')[5], end="")
-                    #     elif len(inter_token_latencies) == 20:
-                    #         print("...")
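+                        # Gap since the previous chunk = inter-token latency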
+                        time_to_next_token = timer()
+                        inter_token_latency = time_to_next_token - prev_token_time
+                        prev_token_time = time_to_next_token
+                        inter_token_latencies.append(inter_token_latency)
+
+                        print(delta.content, end="", flush=True)

+            total_time = timer() - start_time
             # Calculate total metrics

             if verbosity:
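
# A minimal sketch (an assumption, not part of this diff) of the kind of
# summary the `if verbosity:` branch could print from the collected data;
# numpy is already imported at the top of the file:
#
#   avg_itl = float(np.mean(inter_token_latencies)) if inter_token_latencies else 0.0
#   tps = (len(inter_token_latencies) / total_time) if total_time > 0 else 0.0
#   print(f"\nAvg ITL: {avg_itl:.4f}s, TPS: {tps:.2f}, Total time: {total_time:.2f}s")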