11import os
2- import requests
32import numpy as np
43from providers .base_provider import ProviderInterface
54from time import perf_counter as timer
6- import re
5+ from azure .ai .inference import ChatCompletionsClient
6+ from azure .ai .inference .models import SystemMessage , UserMessage
7+ from azure .core .credentials import AzureKeyCredential
78
89
910class Azure (ProviderInterface ):
def __init__(self):
    """Initialize the Azure provider with its endpoint, key and model map.

    Reads the endpoint and API key from the ``AZURE_AI_ENDPOINT`` and
    ``AZURE_AI_API_KEY`` environment variables. Missing values are not
    rejected here: they are stored as ``None`` and the client itself is
    left unset (``self._client = None``) until it is first needed.
    """
    super().__init__()

    # One endpoint/key pair is shared by every model of this provider.
    self.endpoint = os.getenv("AZURE_AI_ENDPOINT")
    self.api_key = os.getenv("AZURE_AI_API_KEY")

    # Map model names to Azure model IDs
    # NOTE(review): the 8B ID uses "3.1" while the 70B ID uses "3-1" --
    # confirm both match the actual Azure deployment names.
    self.model_map = {
        # "mistral-7b-instruct-v0.1": "mistral-7b-instruct-v0.1",
        "meta-llama-3.1-8b-instruct": "Meta-Llama-3.1-8B-Instruct-fyp",
        "meta-llama-3.1-70b-instruct": "Meta-Llama-3-1-70B-Instruct-fyp",
        "mistral-large": "Mistral-Large-2411-yatcd",
        "common-model": "Mistral-Large-2411-yatcd",
    }

    # Created on first use; None means "not built yet".
    self._client = None
28+
def _ensure_client(self):
    """Create the Azure chat-completions client on first use.

    Does nothing when ``self._client`` is already set. Otherwise validates
    the stored configuration and builds a ``ChatCompletionsClient`` from
    ``self.endpoint`` and ``self.api_key``.

    Raises:
        RuntimeError: if ``self.api_key`` is empty or not a string, or if
            ``self.endpoint`` is empty.
    """
    if self._client is not None:
        return

    # Fail with a message naming the missing variable instead of letting
    # the SDK fail later with a less specific error.
    if not self.api_key or not isinstance(self.api_key, str):
        raise RuntimeError(
            "Azure provider misconfigured: AZURE_AI_API_KEY is missing or not a string."
        )
    if not self.endpoint:
        raise RuntimeError(
            "Azure provider misconfigured: AZURE_AI_ENDPOINT is missing."
        )

    credential = AzureKeyCredential(self.api_key)
    self._client = ChatCompletionsClient(
        endpoint=self.endpoint,
        credential=credential,
        api_version="2024-05-01-preview",
    )
3152
def get_model_name(self, model):
    """Return the Azure model ID mapped to ``model``.

    Args:
        model: Key to look up in ``self.model_map``.

    Returns:
        The mapped model ID, or ``None`` when ``model`` is not in the map.
    """
    return self.model_map.get(model)
3556
36- def get_model_api_key (self , model ):
37- """Retrieve the API key for a specific model."""
38- api_key = self .model_api_keys .get (model )
39- if not api_key :
40- raise ValueError (
41- f"No API key found for model '{ model } '. Ensure it is set in environment variables."
42- )
43- return api_key
44-
def perform_inference(self, model, prompt, max_output=100, verbosity=True):
    """Performs non-streaming inference request to Azure.

    Args:
        model: Key into ``self.model_map`` selecting the model to call.
        prompt: User message content sent with ``self.system_prompt``.
        max_output: Value passed to the client as ``max_tokens``.
        verbosity: When true, print the token statistics and the response
            text.

    Returns:
        The response object returned by the client, or ``None`` when the
        model is unknown or any exception occurs (the error is printed,
        not raised).
    """
    try:
        self._ensure_client()
        client = self._client
        model_id = self.get_model_name(model)
        if model_id is None:
            print(f"Model {model} not available.")
            return None

        start_time = timer()
        response = client.complete(
            messages=[
                SystemMessage(content=self.system_prompt),
                UserMessage(content=prompt),
            ],
            max_tokens=max_output,
            model=model_id,
        )
        elapsed = timer() - start_time

        # A response without a usage payload must not turn a successful
        # completion into a reported failure; treat it as zero tokens.
        usage = response.get("usage") or {}
        # Only completion (generated) tokens are counted here.
        total_tokens = usage.get("completion_tokens") or 0
        # max(..., 1) avoids dividing by zero when no tokens were reported.
        tbt = elapsed / max(total_tokens, 1)
        tps = total_tokens / elapsed if elapsed > 0 else 0.0

        self.log_metrics(model, "response_times", elapsed)
        self.log_metrics(model, "totaltokens", total_tokens)
        self.log_metrics(model, "timebetweentokens", tbt)
        self.log_metrics(model, "tps", tps)

        if verbosity:
            print(f"Tokens: {total_tokens}, Avg TBT: {tbt:.4f}s, TPS: {tps:.2f}")
            print(f"Response: {response['choices'][0]['message']['content']}")
        return response

    except Exception as e:
        # Best-effort boundary: report the failure and signal it with None.
        print(f"[ERROR] Inference failed for model '{model}': {e}")
        return None
@@ -87,72 +97,46 @@ def perform_inference_streaming(
8797 self , model , prompt , max_output = 100 , verbosity = True
8898 ):
8999 """Performs streaming inference request to Azure."""
100+ self ._ensure_client ()
101+ client = self ._client
90102 model_id = self .get_model_name (model )
91- api_key = self .get_model_api_key (model )
92103 if model_id is None :
93104 print (f"Model { model } not available." )
94105 return None
95106
96107 inter_token_latencies = []
97- endpoint = f"https://{ model_id } .eastus.models.ai.azure.com/chat/completions"
98108 start_time = timer ()
99109 try :
100- response = requests .post (
101- f"{ endpoint } " ,
102- headers = {
103- "Authorization" : f"Bearer { api_key } " ,
104- "Content-Type" : "application/json" ,
105- },
106- json = {
107- "messages" : [
108- # {"role": "system", "content": self.system_prompt + "\nThe number appended at the end is not important."},
109- # {"role": "user", "content": prompt + " " + str(timer())},
110- {"role" : "system" , "content" : self .system_prompt },
111- {"role" : "user" , "content" : prompt },
112- ],
113- "max_tokens" : max_output ,
114- "stream" : True ,
115- },
116- stream = True ,
117- timeout = 500 ,
118- )
119-
120110 first_token_time = None
121- for line in response .iter_lines ():
122- if line :
123- # print(line)
124- if first_token_time is None :
125- # print(line)
126- first_token_time = timer ()
127- ttft = first_token_time - start_time
128- prev_token_time = first_token_time
129- if verbosity :
111+ with client .complete (
112+ stream = True ,
113+ messages = [
114+ SystemMessage (content = self .system_prompt ),
115+ UserMessage (content = prompt ),
116+ ],
117+ max_tokens = max_output ,
118+ model = model_id
119+ ) as response :
120+ for event in response :
121+ if not event .choices or not event .choices [0 ].delta :
122+ continue
123+
124+ delta = event .choices [0 ].delta
125+ if delta .content :
126+ if first_token_time is None :
127+ first_token_time = timer ()
128+ ttft = first_token_time - start_time
129+ prev_token_time = first_token_time
130130 print (f"##### Time to First Token (TTFT): { ttft :.4f} seconds\n " )
131131
132- line_str = line .decode ("utf-8" ).strip ()
133-
134- if line_str == "data: [DONE]" :
135- # print(line_str)
136- # print("here")
137- total_time = timer () - start_time
138- break
139-
140- # Capture token timing
141- time_to_next_token = timer ()
142- inter_token_latency = time_to_next_token - prev_token_time
143- prev_token_time = time_to_next_token
144- inter_token_latencies .append (inter_token_latency )
145-
146- # Display token if verbosity is enabled
147- match = re .search (r'"content"\s*:\s*"(.*?)"' , line_str )
148- if match :
149- print (match .group (1 ), end = "" )
150- # if verbosity:
151- # if len(inter_token_latencies) < 20:
152- # print(line_str[19:].split('"')[5], end="")
153- # elif len(inter_token_latencies) == 20:
154- # print("...")
132+ time_to_next_token = timer ()
133+ inter_token_latency = time_to_next_token - prev_token_time
134+ prev_token_time = time_to_next_token
135+ inter_token_latencies .append (inter_token_latency )
136+
137+ print (delta .content , end = "" , flush = True )
155138
139+ total_time = timer () - start_time
156140 # Calculate total metrics
157141
158142 if verbosity :
0 commit comments