@@ -82,6 +82,18 @@ def format(self, record):
8282OPENAI_API_KEY = os .environ .get ("OPENAI_API_KEY" )
8383GEMINI_API_KEY = os .environ .get ("GEMINI_API_KEY" )
8484
85+ # Azure OpenAI configuration
86+ AZURE_OPENAI_ENDPOINT = os .environ .get ("AZURE_OPENAI_ENDPOINT" )
87+ AZURE_OPENAI_API_KEY = os .environ .get ("AZURE_OPENAI_API_KEY" )
88+ AZURE_API_VERSION = os .environ .get ("AZURE_API_VERSION" )
89+ AZURE_DEPLOYMENT_NAME = os .environ .get ("AZURE_DEPLOYMENT_NAME" )
90+
91+ # List of Azure models (deployment names)
92+ AZURE_MODELS = [
93+ # These are typically deployment names, not model names
94+ # Users configure their own deployment names in Azure OpenAI Studio
95+ ]
96+
8597# Get preferred provider (default to openai)
8698PREFERRED_PROVIDER = os .environ .get ("PREFERRED_PROVIDER" , "openai" ).lower ()
8799
@@ -112,6 +124,12 @@ def format(self, record):
112124 "gemini-2.0-flash"
113125]
114126
127+ # List of Azure models (deployment names)
128+ AZURE_MODELS = [
129+ # These are typically deployment names, not model names
130+ # Users configure their own deployment names in Azure OpenAI Studio
131+ ]
132+
115133# Helper function to clean schema for Gemini
116134def clean_gemini_schema (schema : Any ) -> Any :
117135 """Recursively removes unsupported fields from a JSON schema for Gemini."""
@@ -202,12 +220,17 @@ def validate_model_field(cls, v, info): # Renamed to avoid conflict
202220 clean_v = clean_v [7 :]
203221 elif clean_v .startswith ('gemini/' ):
204222 clean_v = clean_v [7 :]
223+ elif clean_v .startswith ('azure/' ):
224+ clean_v = clean_v [6 :]
205225
206226 # --- Mapping Logic --- START ---
207227 mapped = False
208228 # Map Haiku to SMALL_MODEL based on provider preference
209229 if 'haiku' in clean_v .lower ():
210- if PREFERRED_PROVIDER == "google" and SMALL_MODEL in GEMINI_MODELS :
230+ if PREFERRED_PROVIDER == "azure" :
231+ new_model = f"azure/{ SMALL_MODEL } "
232+ mapped = True
233+ elif PREFERRED_PROVIDER == "google" and SMALL_MODEL in GEMINI_MODELS :
211234 new_model = f"gemini/{ SMALL_MODEL } "
212235 mapped = True
213236 else :
@@ -216,7 +239,10 @@ def validate_model_field(cls, v, info): # Renamed to avoid conflict
216239
217240 # Map Sonnet to BIG_MODEL based on provider preference
218241 elif 'sonnet' in clean_v .lower ():
219- if PREFERRED_PROVIDER == "google" and BIG_MODEL in GEMINI_MODELS :
242+ if PREFERRED_PROVIDER == "azure" :
243+ new_model = f"azure/{ BIG_MODEL } "
244+ mapped = True
245+ elif PREFERRED_PROVIDER == "google" and BIG_MODEL in GEMINI_MODELS :
220246 new_model = f"gemini/{ BIG_MODEL } "
221247 mapped = True
222248 else :
@@ -237,7 +263,7 @@ def validate_model_field(cls, v, info): # Renamed to avoid conflict
237263 logger .debug (f"📌 MODEL MAPPING: '{ original_model } ' ➡️ '{ new_model } '" )
238264 else :
239265 # If no mapping occurred and no prefix exists, log warning or decide default
240- if not v .startswith (('openai/' , 'gemini/' , 'anthropic/' )):
266+ if not v .startswith (('openai/' , 'gemini/' , 'anthropic/' , 'azure/' )):
241267 logger .warning (f"⚠️ No prefix or mapping rule for model: '{ original_model } '. Using as is." )
242268 new_model = v # Ensure we return the original if no rule applied
243269
@@ -275,12 +301,17 @@ def validate_model_token_count(cls, v, info): # Renamed to avoid conflict
275301 clean_v = clean_v [7 :]
276302 elif clean_v .startswith ('gemini/' ):
277303 clean_v = clean_v [7 :]
304+ elif clean_v .startswith ('azure/' ):
305+ clean_v = clean_v [6 :]
278306
279307 # --- Mapping Logic --- START ---
280308 mapped = False
281309 # Map Haiku to SMALL_MODEL based on provider preference
282310 if 'haiku' in clean_v .lower ():
283- if PREFERRED_PROVIDER == "google" and SMALL_MODEL in GEMINI_MODELS :
311+ if PREFERRED_PROVIDER == "azure" :
312+ new_model = f"azure/{ SMALL_MODEL } "
313+ mapped = True
314+ elif PREFERRED_PROVIDER == "google" and SMALL_MODEL in GEMINI_MODELS :
284315 new_model = f"gemini/{ SMALL_MODEL } "
285316 mapped = True
286317 else :
@@ -289,7 +320,10 @@ def validate_model_token_count(cls, v, info): # Renamed to avoid conflict
289320
290321 # Map Sonnet to BIG_MODEL based on provider preference
291322 elif 'sonnet' in clean_v .lower ():
292- if PREFERRED_PROVIDER == "google" and BIG_MODEL in GEMINI_MODELS :
323+ if PREFERRED_PROVIDER == "azure" :
324+ new_model = f"azure/{ BIG_MODEL } "
325+ mapped = True
326+ elif PREFERRED_PROVIDER == "google" and BIG_MODEL in GEMINI_MODELS :
293327 new_model = f"gemini/{ BIG_MODEL } "
294328 mapped = True
295329 else :
@@ -309,7 +343,7 @@ def validate_model_token_count(cls, v, info): # Renamed to avoid conflict
309343 if mapped :
310344 logger .debug (f"📌 TOKEN COUNT MAPPING: '{ original_model } ' ➡️ '{ new_model } '" )
311345 else :
312- if not v .startswith (('openai/' , 'gemini/' , 'anthropic/' )):
346+ if not v .startswith (('openai/' , 'gemini/' , 'anthropic/' , 'azure/' )):
313347 logger .warning (f"⚠️ No prefix or mapping rule for token count model: '{ original_model } '. Using as is." )
314348 new_model = v # Ensure we return the original if no rule applied
315349
@@ -533,9 +567,9 @@ def convert_anthropic_to_litellm(anthropic_request: MessagesRequest) -> Dict[str
533567
534568 # Cap max_tokens for OpenAI models to their limit of 16384
535569 max_tokens = anthropic_request .max_tokens
536- if anthropic_request .model .startswith ("openai/" ) or anthropic_request .model .startswith ("gemini/" ):
570+ if anthropic_request .model .startswith ("openai/" ) or anthropic_request .model .startswith ("gemini/" ) or anthropic_request . model . startswith ( "azure/" ) :
537571 max_tokens = min (max_tokens , 16384 )
538- logger .debug (f"Capping max_tokens to 16384 for OpenAI/Gemini model (original value: { anthropic_request .max_tokens } )" )
572+ logger .debug (f"Capping max_tokens to 16384 for OpenAI/Gemini/Azure model (original value: { anthropic_request .max_tokens } )" )
539573
540574 # Create LiteLLM request dict
541575 litellm_request = {
@@ -1110,13 +1144,18 @@ async def create_message(
11101144 elif request .model .startswith ("gemini/" ):
11111145 litellm_request ["api_key" ] = GEMINI_API_KEY
11121146 logger .debug (f"Using Gemini API key for model: { request .model } " )
1147+ elif request .model .startswith ("azure/" ):
1148+ litellm_request ["api_key" ] = AZURE_OPENAI_API_KEY
1149+ litellm_request ["api_base" ] = AZURE_OPENAI_ENDPOINT
1150+ litellm_request ["api_version" ] = AZURE_API_VERSION
1151+ logger .debug (f"Using Azure OpenAI API key for model: { request .model } " )
11131152 else :
11141153 litellm_request ["api_key" ] = ANTHROPIC_API_KEY
11151154 logger .debug (f"Using Anthropic API key for model: { request .model } " )
11161155
11171156 # For OpenAI models - modify request format to work with limitations
1118- if "openai" in litellm_request ["model" ] and "messages" in litellm_request :
1119- logger .debug (f"Processing OpenAI model request: { litellm_request ['model' ]} " )
1157+ if ( "openai" in litellm_request ["model" ] or "azure" in litellm_request [ "model" ]) and "messages" in litellm_request :
1158+ logger .debug (f"Processing OpenAI/Azure model request: { litellm_request ['model' ]} " )
11201159
11211160 # For OpenAI models, we need to convert content blocks to simple strings
11221161 # and handle other requirements
0 commit comments