Skip to content

Commit 6e11d5c

Browse files
Boris Devclaude
authored andcommitted
feat: Add comprehensive Azure OpenAI support
- Add Azure OpenAI provider support with full configuration - Update model mapping to support azure/ prefix - Add Azure configuration examples to .env.example - Improve Azure test error messages and validation - Add proper Azure error handling and helpful config guidance 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent e9c8cf8 commit 6e11d5c

File tree

4 files changed

+170
-20
lines changed

4 files changed

+170
-20
lines changed

.env.example

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ OPENAI_API_KEY="sk-..."
44
GEMINI_API_KEY="your-google-ai-studio-key"
55

66
# Optional: Provider Preference and Model Mapping
7-
# Controls which provider (google or openai) is preferred for mapping haiku/sonnet.
7+
# Controls which provider (google, openai, or azure) is preferred for mapping haiku/sonnet.
88
# Defaults to openai if not set.
99
PREFERRED_PROVIDER="openai"
1010

@@ -18,4 +18,15 @@ PREFERRED_PROVIDER="openai"
1818
# Example Google mapping:
1919
# PREFERRED_PROVIDER="google"
2020
# BIG_MODEL="gemini-2.5-pro-preview-03-25"
21-
# SMALL_MODEL="gemini-2.0-flash"
21+
# SMALL_MODEL="gemini-2.0-flash"
22+
23+
# If PREFERRED_PROVIDER=azure, these should match your Azure deployment names.
24+
#
25+
# Example Azure mapping:
26+
# PREFERRED_PROVIDER="azure"
27+
# BIG_MODEL="your-deployment-name"
28+
# SMALL_MODEL="your-deployment-name"
29+
# Azure-specific settings:
30+
# AZURE_OPENAI_ENDPOINT="https://your-resource.openai.azure.com"
31+
# AZURE_OPENAI_API_KEY="your-azure-openai-api-key"
32+
# AZURE_API_VERSION="your-api-version"

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,18 @@ BIG_MODEL="gpt-4o" # Example specific model
141141
SMALL_MODEL="gpt-4o-mini" # Example specific model
142142
```
143143

144+
145+
**Example 4: Use Azure OpenAI**
146+
```dotenv
147+
PREFERRED_PROVIDER="azure"
148+
BIG_MODEL="your-gpt4-deployment"
149+
SMALL_MODEL="your-gpt4-mini-deployment"
150+
151+
# Azure OpenAI Configuration
152+
AZURE_OPENAI_ENDPOINT="https://your-resource.openai.azure.com"
153+
AZURE_OPENAI_API_KEY="your-azure-openai-api-key"
154+
AZURE_API_VERSION="2024-02-15-preview"
155+
```
144156
## How It Works 🧩
145157

146158
This proxy works by:

server.py

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,18 @@ def format(self, record):
8282
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
8383
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
8484

85+
# Azure OpenAI configuration
86+
AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT")
87+
AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
88+
AZURE_API_VERSION = os.environ.get("AZURE_API_VERSION")
89+
AZURE_DEPLOYMENT_NAME = os.environ.get("AZURE_DEPLOYMENT_NAME")
90+
91+
# List of Azure models (deployment names)
92+
AZURE_MODELS = [
93+
# These are typically deployment names, not model names
94+
# Users configure their own deployment names in Azure OpenAI Studio
95+
]
96+
8597
# Get preferred provider (default to openai)
8698
PREFERRED_PROVIDER = os.environ.get("PREFERRED_PROVIDER", "openai").lower()
8799

@@ -112,6 +124,12 @@ def format(self, record):
112124
"gemini-2.0-flash"
113125
]
114126

127+
# List of Azure models (deployment names)
128+
AZURE_MODELS = [
129+
# These are typically deployment names, not model names
130+
# Users configure their own deployment names in Azure OpenAI Studio
131+
]
132+
115133
# Helper function to clean schema for Gemini
116134
def clean_gemini_schema(schema: Any) -> Any:
117135
"""Recursively removes unsupported fields from a JSON schema for Gemini."""
@@ -202,12 +220,17 @@ def validate_model_field(cls, v, info): # Renamed to avoid conflict
202220
clean_v = clean_v[7:]
203221
elif clean_v.startswith('gemini/'):
204222
clean_v = clean_v[7:]
223+
elif clean_v.startswith('azure/'):
224+
clean_v = clean_v[6:]
205225

206226
# --- Mapping Logic --- START ---
207227
mapped = False
208228
# Map Haiku to SMALL_MODEL based on provider preference
209229
if 'haiku' in clean_v.lower():
210-
if PREFERRED_PROVIDER == "google" and SMALL_MODEL in GEMINI_MODELS:
230+
if PREFERRED_PROVIDER == "azure":
231+
new_model = f"azure/{SMALL_MODEL}"
232+
mapped = True
233+
elif PREFERRED_PROVIDER == "google" and SMALL_MODEL in GEMINI_MODELS:
211234
new_model = f"gemini/{SMALL_MODEL}"
212235
mapped = True
213236
else:
@@ -216,7 +239,10 @@ def validate_model_field(cls, v, info): # Renamed to avoid conflict
216239

217240
# Map Sonnet to BIG_MODEL based on provider preference
218241
elif 'sonnet' in clean_v.lower():
219-
if PREFERRED_PROVIDER == "google" and BIG_MODEL in GEMINI_MODELS:
242+
if PREFERRED_PROVIDER == "azure":
243+
new_model = f"azure/{BIG_MODEL}"
244+
mapped = True
245+
elif PREFERRED_PROVIDER == "google" and BIG_MODEL in GEMINI_MODELS:
220246
new_model = f"gemini/{BIG_MODEL}"
221247
mapped = True
222248
else:
@@ -237,7 +263,7 @@ def validate_model_field(cls, v, info): # Renamed to avoid conflict
237263
logger.debug(f"📌 MODEL MAPPING: '{original_model}' ➡️ '{new_model}'")
238264
else:
239265
# If no mapping occurred and no prefix exists, log warning or decide default
240-
if not v.startswith(('openai/', 'gemini/', 'anthropic/')):
266+
if not v.startswith(('openai/', 'gemini/', 'anthropic/', 'azure/')):
241267
logger.warning(f"⚠️ No prefix or mapping rule for model: '{original_model}'. Using as is.")
242268
new_model = v # Ensure we return the original if no rule applied
243269

@@ -275,12 +301,17 @@ def validate_model_token_count(cls, v, info): # Renamed to avoid conflict
275301
clean_v = clean_v[7:]
276302
elif clean_v.startswith('gemini/'):
277303
clean_v = clean_v[7:]
304+
elif clean_v.startswith('azure/'):
305+
clean_v = clean_v[6:]
278306

279307
# --- Mapping Logic --- START ---
280308
mapped = False
281309
# Map Haiku to SMALL_MODEL based on provider preference
282310
if 'haiku' in clean_v.lower():
283-
if PREFERRED_PROVIDER == "google" and SMALL_MODEL in GEMINI_MODELS:
311+
if PREFERRED_PROVIDER == "azure":
312+
new_model = f"azure/{SMALL_MODEL}"
313+
mapped = True
314+
elif PREFERRED_PROVIDER == "google" and SMALL_MODEL in GEMINI_MODELS:
284315
new_model = f"gemini/{SMALL_MODEL}"
285316
mapped = True
286317
else:
@@ -289,7 +320,10 @@ def validate_model_token_count(cls, v, info): # Renamed to avoid conflict
289320

290321
# Map Sonnet to BIG_MODEL based on provider preference
291322
elif 'sonnet' in clean_v.lower():
292-
if PREFERRED_PROVIDER == "google" and BIG_MODEL in GEMINI_MODELS:
323+
if PREFERRED_PROVIDER == "azure":
324+
new_model = f"azure/{BIG_MODEL}"
325+
mapped = True
326+
elif PREFERRED_PROVIDER == "google" and BIG_MODEL in GEMINI_MODELS:
293327
new_model = f"gemini/{BIG_MODEL}"
294328
mapped = True
295329
else:
@@ -309,7 +343,7 @@ def validate_model_token_count(cls, v, info): # Renamed to avoid conflict
309343
if mapped:
310344
logger.debug(f"📌 TOKEN COUNT MAPPING: '{original_model}' ➡️ '{new_model}'")
311345
else:
312-
if not v.startswith(('openai/', 'gemini/', 'anthropic/')):
346+
if not v.startswith(('openai/', 'gemini/', 'anthropic/', 'azure/')):
313347
logger.warning(f"⚠️ No prefix or mapping rule for token count model: '{original_model}'. Using as is.")
314348
new_model = v # Ensure we return the original if no rule applied
315349

@@ -533,9 +567,9 @@ def convert_anthropic_to_litellm(anthropic_request: MessagesRequest) -> Dict[str
533567

534568
# Cap max_tokens for OpenAI models to their limit of 16384
535569
max_tokens = anthropic_request.max_tokens
536-
if anthropic_request.model.startswith("openai/") or anthropic_request.model.startswith("gemini/"):
570+
if anthropic_request.model.startswith("openai/") or anthropic_request.model.startswith("gemini/") or anthropic_request.model.startswith("azure/"):
537571
max_tokens = min(max_tokens, 16384)
538-
logger.debug(f"Capping max_tokens to 16384 for OpenAI/Gemini model (original value: {anthropic_request.max_tokens})")
572+
logger.debug(f"Capping max_tokens to 16384 for OpenAI/Gemini/Azure model (original value: {anthropic_request.max_tokens})")
539573

540574
# Create LiteLLM request dict
541575
litellm_request = {
@@ -1110,13 +1144,18 @@ async def create_message(
11101144
elif request.model.startswith("gemini/"):
11111145
litellm_request["api_key"] = GEMINI_API_KEY
11121146
logger.debug(f"Using Gemini API key for model: {request.model}")
1147+
elif request.model.startswith("azure/"):
1148+
litellm_request["api_key"] = AZURE_OPENAI_API_KEY
1149+
litellm_request["api_base"] = AZURE_OPENAI_ENDPOINT
1150+
litellm_request["api_version"] = AZURE_API_VERSION
1151+
logger.debug(f"Using Azure OpenAI API key for model: {request.model}")
11131152
else:
11141153
litellm_request["api_key"] = ANTHROPIC_API_KEY
11151154
logger.debug(f"Using Anthropic API key for model: {request.model}")
11161155

11171156
# For OpenAI models - modify request format to work with limitations
1118-
if "openai" in litellm_request["model"] and "messages" in litellm_request:
1119-
logger.debug(f"Processing OpenAI model request: {litellm_request['model']}")
1157+
if ("openai" in litellm_request["model"] or "azure" in litellm_request["model"]) and "messages" in litellm_request:
1158+
logger.debug(f"Processing OpenAI/Azure model request: {litellm_request['model']}")
11201159

11211160
# For OpenAI models, we need to convert content blocks to simple strings
11221161
# and handle other requirements

tests.py

Lines changed: 96 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,83 @@ def test_request(test_name, request_data, check_tools=False):
361361
traceback.print_exc()
362362
return False
363363

364+
def test_azure_model_mapping():
365+
"""Test Azure model mapping specifically."""
366+
print(f"\n{'='*20} RUNNING AZURE MODEL MAPPING TEST {'='*20}")
367+
368+
# Test different Azure configurations
369+
azure_tests = [
370+
{
371+
"name": "direct_azure",
372+
"model": "azure/my-deployment",
373+
"expected_prefix": "azure/",
374+
"env_vars": {}
375+
}
376+
]
377+
378+
success_count = 0
379+
380+
for test_case in azure_tests:
381+
print(f"\n--- Testing {test_case['name']} ---")
382+
383+
# Create test request
384+
test_data = {
385+
"model": test_case["model"],
386+
"max_tokens": 100,
387+
"messages": [{"role": "user", "content": "Test message"}]
388+
}
389+
390+
try:
391+
# We'll just test that the proxy accepts the request and processes it
392+
# without necessarily getting a real response (since we may not have Azure configured)
393+
print(f"Testing model: {test_case['model']}")
394+
395+
# Use dummy headers for Azure-only testing if no API key is set
396+
test_headers = proxy_headers
397+
if not ANTHROPIC_API_KEY:
398+
test_headers = {
399+
"x-api-key": "dummy-key-for-azure-testing",
400+
"anthropic-version": ANTHROPIC_VERSION,
401+
"content-type": "application/json",
402+
}
403+
404+
proxy_response = httpx.post(PROXY_API_URL, headers=test_headers, json=test_data, timeout=10)
405+
406+
print(f"Status code: {proxy_response.status_code}")
407+
408+
# Accept both success and certain types of failures (like missing Azure config)
409+
# The important thing is that the model mapping logic works
410+
if proxy_response.status_code == 200:
411+
print("✅ Request succeeded - Azure mapping worked")
412+
success_count += 1
413+
elif proxy_response.status_code in [400, 401, 403, 422, 500]:
414+
response_text = proxy_response.text
415+
print(f"❌ Azure test failed (HTTP {proxy_response.status_code})")
416+
print(f"Response: {response_text}")
417+
418+
# Check if it's a config issue and provide helpful message
419+
if "azure" in response_text.lower() or "deployment" in response_text.lower():
420+
print("💡 This looks like an Azure configuration issue.")
421+
elif proxy_response.status_code == 500:
422+
print("💡 HTTP 500 - This could be missing Azure config or incorrect deployment name.")
423+
424+
print("💡 To fix Azure configuration, ensure you have these in your .env:")
425+
print(" AZURE_OPENAI_ENDPOINT=https://your-resource.openai.azure.com/")
426+
print(" AZURE_OPENAI_API_KEY=your-azure-api-key")
427+
print(" AZURE_API_VERSION=your-api-version")
428+
print(" BIG_MODEL=your-deployment-name")
429+
print(" SMALL_MODEL=your-deployment-name")
430+
# Don't increment success_count - this is a failure
431+
else:
432+
print(f"❌ Unexpected status code: {proxy_response.status_code}")
433+
print(f"Response: {proxy_response.text}")
434+
435+
except Exception as e:
436+
print(f"❌ Exception during test: {e}")
437+
438+
print(f"\nAzure mapping tests: {success_count}/{len(azure_tests)} passed")
439+
return success_count == len(azure_tests)
440+
364441
# ================= STREAMING TESTS =================
365442

366443
class StreamStats:
@@ -636,8 +713,18 @@ async def run_tests(args):
636713
# Track test results
637714
results = {}
638715

639-
# First run non-streaming tests
716+
# Run Azure-specific tests first (unless we're doing streaming-only)
640717
if not args.streaming_only:
718+
print("\n\n=========== RUNNING AZURE INTEGRATION TESTS ===========\n")
719+
azure_result = test_azure_model_mapping()
720+
results["azure_model_mapping"] = azure_result
721+
722+
# If Azure-only is specified, skip other tests
723+
if args.azure_only:
724+
print("\n\n=========== AZURE-ONLY MODE - SKIPPING OTHER TESTS ===========\n")
725+
726+
# First run non-streaming tests
727+
elif not args.streaming_only:
641728
print("\n\n=========== RUNNING NON-STREAMING TESTS ===========\n")
642729
for test_name, test_data in TEST_SCENARIOS.items():
643730
# Skip streaming tests
@@ -658,7 +745,7 @@ async def run_tests(args):
658745
results[test_name] = result
659746

660747
# Now run streaming tests
661-
if not args.no_streaming:
748+
if not args.no_streaming and not args.azure_only:
662749
print("\n\n=========== RUNNING STREAMING TESTS ===========\n")
663750
for test_name, test_data in TEST_SCENARIOS.items():
664751
# Only select streaming tests, or force streaming
@@ -695,19 +782,20 @@ async def run_tests(args):
695782
return False
696783

697784
async def main():
698-
# Check that API key is set
699-
if not ANTHROPIC_API_KEY:
700-
print("Error: ANTHROPIC_API_KEY not set in .env file")
701-
return
702-
703-
# Parse command-line arguments
785+
# Parse command-line arguments first
704786
parser = argparse.ArgumentParser(description="Test the Claude-on-OpenAI proxy")
705787
parser.add_argument("--no-streaming", action="store_true", help="Skip streaming tests")
706788
parser.add_argument("--streaming-only", action="store_true", help="Only run streaming tests")
707789
parser.add_argument("--simple", action="store_true", help="Only run simple tests (no tools)")
708790
parser.add_argument("--tools-only", action="store_true", help="Only run tool tests")
791+
parser.add_argument("--azure-only", action="store_true", help="Only run Azure integration tests")
709792
args = parser.parse_args()
710793

794+
# Check that API key is set (unless we're only testing Azure)
795+
if not ANTHROPIC_API_KEY and not args.azure_only:
796+
print("Error: ANTHROPIC_API_KEY not set in .env file")
797+
return
798+
711799
# Run tests
712800
success = await run_tests(args)
713801
sys.exit(0 if success else 1)

0 commit comments

Comments
 (0)