diff --git a/.env.example b/.env.example
index 798ad8a..a9bd2e9 100644
--- a/.env.example
+++ b/.env.example
@@ -4,7 +4,7 @@ OPENAI_API_KEY="sk-..."
 GEMINI_API_KEY="your-google-ai-studio-key"
 
 # Optional: Provider Preference and Model Mapping
-# Controls which provider (google or openai) is preferred for mapping haiku/sonnet.
+# Controls which provider (google, openai, or azure) is preferred for mapping haiku/sonnet.
 # Defaults to openai if not set.
 PREFERRED_PROVIDER="openai"
 
@@ -18,4 +18,15 @@ PREFERRED_PROVIDER="openai"
 # Example Google mapping:
 # PREFERRED_PROVIDER="google"
 # BIG_MODEL="gemini-2.5-pro-preview-03-25"
-# SMALL_MODEL="gemini-2.0-flash"
\ No newline at end of file
+# SMALL_MODEL="gemini-2.0-flash"
+
+# If PREFERRED_PROVIDER=azure, these should match your Azure deployment names.
+#
+# Example Azure mapping:
+# PREFERRED_PROVIDER="azure"
+# BIG_MODEL="your-deployment-name"
+# SMALL_MODEL="your-deployment-name"
+# Azure-specific settings:
+# AZURE_OPENAI_ENDPOINT="https://your-resource.openai.azure.com"
+# AZURE_OPENAI_API_KEY="your-azure-openai-api-key"
+# AZURE_API_VERSION="your-api-version"
diff --git a/README.md b/README.md
index 3dbbc89..8be2ce4 100644
--- a/README.md
+++ b/README.md
@@ -141,6 +141,18 @@
 BIG_MODEL="gpt-4o" # Example specific model
 SMALL_MODEL="gpt-4o-mini" # Example specific model
 ```
+
+**Example 4: Use Azure OpenAI**
+```dotenv
+PREFERRED_PROVIDER="azure"
+BIG_MODEL="your-gpt4-deployment"
+SMALL_MODEL="your-gpt4-mini-deployment"
+
+# Azure OpenAI Configuration
+AZURE_OPENAI_ENDPOINT="https://your-resource.openai.azure.com"
+AZURE_OPENAI_API_KEY="your-azure-openai-api-key"
+AZURE_API_VERSION="2024-02-15-preview"
+```
 
 ## How It Works 🧩
 This proxy works by:
diff --git a/server.py b/server.py
index f4966b2..a7bae9b 100644
--- a/server.py
+++ b/server.py
@@ -82,6 +82,18 @@ def format(self, record):
 OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
 GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
 
+# Azure OpenAI configuration
+AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT")
+AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
+AZURE_API_VERSION = os.environ.get("AZURE_API_VERSION")
+AZURE_DEPLOYMENT_NAME = os.environ.get("AZURE_DEPLOYMENT_NAME")
+
+# List of Azure models (deployment names)
+AZURE_MODELS = [
+    # These are typically deployment names, not model names
+    # Users configure their own deployment names in Azure OpenAI Studio
+]
+
 # Get preferred provider (default to openai)
 PREFERRED_PROVIDER = os.environ.get("PREFERRED_PROVIDER", "openai").lower()
 
@@ -202,12 +214,17 @@ def validate_model_field(cls, v, info): # Renamed to avoid conflict
             clean_v = clean_v[7:]
         elif clean_v.startswith('gemini/'):
             clean_v = clean_v[7:]
+        elif clean_v.startswith('azure/'):
+            clean_v = clean_v[6:]
 
         # --- Mapping Logic --- START ---
         mapped = False
         # Map Haiku to SMALL_MODEL based on provider preference
         if 'haiku' in clean_v.lower():
-            if PREFERRED_PROVIDER == "google" and SMALL_MODEL in GEMINI_MODELS:
+            if PREFERRED_PROVIDER == "azure":
+                new_model = f"azure/{SMALL_MODEL}"
+                mapped = True
+            elif PREFERRED_PROVIDER == "google" and SMALL_MODEL in GEMINI_MODELS:
                 new_model = f"gemini/{SMALL_MODEL}"
                 mapped = True
             else:
@@ -216,7 +233,10 @@ def validate_model_field(cls, v, info): # Renamed to avoid conflict
 
         # Map Sonnet to BIG_MODEL based on provider preference
         elif 'sonnet' in clean_v.lower():
-            if PREFERRED_PROVIDER == "google" and BIG_MODEL in GEMINI_MODELS:
+            if PREFERRED_PROVIDER == "azure":
+                new_model = f"azure/{BIG_MODEL}"
+                mapped = True
+            elif PREFERRED_PROVIDER == "google" and BIG_MODEL in GEMINI_MODELS:
                 new_model = f"gemini/{BIG_MODEL}"
                 mapped = True
             else:
@@ -237,7 +257,7 @@ def validate_model_field(cls, v, info): # Renamed to avoid conflict
             logger.debug(f"📌 MODEL MAPPING: '{original_model}' ➡️ '{new_model}'")
         else:
             # If no mapping occurred and no prefix exists, log warning or decide default
-            if not v.startswith(('openai/', 'gemini/', 'anthropic/')):
+            if not v.startswith(('openai/', 'gemini/', 'anthropic/', 'azure/')):
                 logger.warning(f"⚠️ No prefix or mapping rule for model: '{original_model}'. Using as is.")
             new_model = v # Ensure we return the original if no rule applied
 
@@ -275,12 +295,17 @@ def validate_model_token_count(cls, v, info): # Renamed to avoid conflict
             clean_v = clean_v[7:]
         elif clean_v.startswith('gemini/'):
             clean_v = clean_v[7:]
+        elif clean_v.startswith('azure/'):
+            clean_v = clean_v[6:]
 
         # --- Mapping Logic --- START ---
         mapped = False
         # Map Haiku to SMALL_MODEL based on provider preference
         if 'haiku' in clean_v.lower():
-            if PREFERRED_PROVIDER == "google" and SMALL_MODEL in GEMINI_MODELS:
+            if PREFERRED_PROVIDER == "azure":
+                new_model = f"azure/{SMALL_MODEL}"
+                mapped = True
+            elif PREFERRED_PROVIDER == "google" and SMALL_MODEL in GEMINI_MODELS:
                 new_model = f"gemini/{SMALL_MODEL}"
                 mapped = True
             else:
@@ -289,7 +314,10 @@ def validate_model_token_count(cls, v, info): # Renamed to avoid conflict
 
         # Map Sonnet to BIG_MODEL based on provider preference
         elif 'sonnet' in clean_v.lower():
-            if PREFERRED_PROVIDER == "google" and BIG_MODEL in GEMINI_MODELS:
+            if PREFERRED_PROVIDER == "azure":
+                new_model = f"azure/{BIG_MODEL}"
+                mapped = True
+            elif PREFERRED_PROVIDER == "google" and BIG_MODEL in GEMINI_MODELS:
                 new_model = f"gemini/{BIG_MODEL}"
                 mapped = True
             else:
@@ -309,7 +337,7 @@ def validate_model_token_count(cls, v, info): # Renamed to avoid conflict
         if mapped:
             logger.debug(f"📌 TOKEN COUNT MAPPING: '{original_model}' ➡️ '{new_model}'")
         else:
-            if not v.startswith(('openai/', 'gemini/', 'anthropic/')):
+            if not v.startswith(('openai/', 'gemini/', 'anthropic/', 'azure/')):
                 logger.warning(f"⚠️ No prefix or mapping rule for token count model: '{original_model}'. Using as is.")
Using as is.") new_model = v # Ensure we return the original if no rule applied @@ -533,9 +567,9 @@ def convert_anthropic_to_litellm(anthropic_request: MessagesRequest) -> Dict[str # Cap max_tokens for OpenAI models to their limit of 16384 max_tokens = anthropic_request.max_tokens - if anthropic_request.model.startswith("openai/") or anthropic_request.model.startswith("gemini/"): + if anthropic_request.model.startswith("openai/") or anthropic_request.model.startswith("gemini/") or anthropic_request.model.startswith("azure/"): max_tokens = min(max_tokens, 16384) - logger.debug(f"Capping max_tokens to 16384 for OpenAI/Gemini model (original value: {anthropic_request.max_tokens})") + logger.debug(f"Capping max_tokens to 16384 for OpenAI/Gemini/Azure model (original value: {anthropic_request.max_tokens})") # Create LiteLLM request dict litellm_request = { @@ -1110,13 +1144,18 @@ async def create_message( elif request.model.startswith("gemini/"): litellm_request["api_key"] = GEMINI_API_KEY logger.debug(f"Using Gemini API key for model: {request.model}") + elif request.model.startswith("azure/"): + litellm_request["api_key"] = AZURE_OPENAI_API_KEY + litellm_request["api_base"] = AZURE_OPENAI_ENDPOINT + litellm_request["api_version"] = AZURE_API_VERSION + logger.debug(f"Using Azure OpenAI API key for model: {request.model}") else: litellm_request["api_key"] = ANTHROPIC_API_KEY logger.debug(f"Using Anthropic API key for model: {request.model}") # For OpenAI models - modify request format to work with limitations - if "openai" in litellm_request["model"] and "messages" in litellm_request: - logger.debug(f"Processing OpenAI model request: {litellm_request['model']}") + if ("openai" in litellm_request["model"] or "azure" in litellm_request["model"]) and "messages" in litellm_request: + logger.debug(f"Processing OpenAI/Azure model request: {litellm_request['model']}") # For OpenAI models, we need to convert content blocks to simple strings # and handle other requirements diff --git a/tests.py b/tests.py index 84d1b18..e4d5673 100644 --- a/tests.py +++ b/tests.py @@ -361,6 +361,83 @@ def test_request(test_name, request_data, check_tools=False): traceback.print_exc() return False +def test_azure_model_mapping(): + """Test Azure model mapping specifically.""" + print(f"\n{'='*20} RUNNING AZURE MODEL MAPPING TEST {'='*20}") + + # Test different Azure configurations + azure_tests = [ + { + "name": "direct_azure", + "model": "azure/my-deployment", + "expected_prefix": "azure/", + "env_vars": {} + } + ] + + success_count = 0 + + for test_case in azure_tests: + print(f"\n--- Testing {test_case['name']} ---") + + # Create test request + test_data = { + "model": test_case["model"], + "max_tokens": 100, + "messages": [{"role": "user", "content": "Test message"}] + } + + try: + # We'll just test that the proxy accepts the request and processes it + # without necessarily getting a real response (since we may not have Azure configured) + print(f"Testing model: {test_case['model']}") + + # Use dummy headers for Azure-only testing if no API key is set + test_headers = proxy_headers + if not ANTHROPIC_API_KEY: + test_headers = { + "x-api-key": "dummy-key-for-azure-testing", + "anthropic-version": ANTHROPIC_VERSION, + "content-type": "application/json", + } + + proxy_response = httpx.post(PROXY_API_URL, headers=test_headers, json=test_data, timeout=10) + + print(f"Status code: {proxy_response.status_code}") + + # Accept both success and certain types of failures (like missing Azure config) + # The important thing is 
+            if proxy_response.status_code == 200:
+                print("✅ Request succeeded - Azure mapping worked")
+                success_count += 1
+            elif proxy_response.status_code in [400, 401, 403, 422, 500]:
+                response_text = proxy_response.text
+                print(f"❌ Azure test failed (HTTP {proxy_response.status_code})")
+                print(f"Response: {response_text}")
+
+                # Check if it's a config issue and provide helpful message
+                if "azure" in response_text.lower() or "deployment" in response_text.lower():
+                    print("💡 This looks like an Azure configuration issue.")
+                elif proxy_response.status_code == 500:
+                    print("💡 HTTP 500 - This could be missing Azure config or incorrect deployment name.")
+
+                print("💡 To fix Azure configuration, ensure you have these in your .env:")
+                print("   AZURE_OPENAI_ENDPOINT=https://your-resource.openai.azure.com/")
+                print("   AZURE_OPENAI_API_KEY=your-azure-api-key")
+                print("   AZURE_API_VERSION=your-api-version")
+                print("   BIG_MODEL=your-deployment-name")
+                print("   SMALL_MODEL=your-deployment-name")
+                # Don't increment success_count - this is a failure
+            else:
+                print(f"❌ Unexpected status code: {proxy_response.status_code}")
+                print(f"Response: {proxy_response.text}")
+
+        except Exception as e:
+            print(f"❌ Exception during test: {e}")
+
+    print(f"\nAzure mapping tests: {success_count}/{len(azure_tests)} passed")
+    return success_count == len(azure_tests)
+
 # ================= STREAMING TESTS =================
 
 class StreamStats:
@@ -636,8 +713,18 @@ async def run_tests(args):
     # Track test results
     results = {}
 
-    # First run non-streaming tests
+    # Run Azure-specific tests first (unless we're doing streaming-only)
     if not args.streaming_only:
+        print("\n\n=========== RUNNING AZURE INTEGRATION TESTS ===========\n")
+        azure_result = test_azure_model_mapping()
+        results["azure_model_mapping"] = azure_result
+
+        # If Azure-only is specified, skip other tests
+        if args.azure_only:
+            print("\n\n=========== AZURE-ONLY MODE - SKIPPING OTHER TESTS ===========\n")
+
+    # Then run the standard non-streaming tests (skipped in Azure-only mode)
+    if not args.streaming_only and not args.azure_only:
         print("\n\n=========== RUNNING NON-STREAMING TESTS ===========\n")
         for test_name, test_data in TEST_SCENARIOS.items():
             # Skip streaming tests
@@ -658,7 +745,7 @@ async def run_tests(args):
             results[test_name] = result
 
     # Now run streaming tests
-    if not args.no_streaming:
+    if not args.no_streaming and not args.azure_only:
         print("\n\n=========== RUNNING STREAMING TESTS ===========\n")
         for test_name, test_data in TEST_SCENARIOS.items():
             # Only select streaming tests, or force streaming
@@ -695,19 +782,20 @@ async def run_tests(args):
     return False
 
 async def main():
-    # Check that API key is set
-    if not ANTHROPIC_API_KEY:
-        print("Error: ANTHROPIC_API_KEY not set in .env file")
-        return
-
-    # Parse command-line arguments
+    # Parse command-line arguments first
     parser = argparse.ArgumentParser(description="Test the Claude-on-OpenAI proxy")
     parser.add_argument("--no-streaming", action="store_true", help="Skip streaming tests")
     parser.add_argument("--streaming-only", action="store_true", help="Only run streaming tests")
     parser.add_argument("--simple", action="store_true", help="Only run simple tests (no tools)")
     parser.add_argument("--tools-only", action="store_true", help="Only run tool tests")
+    parser.add_argument("--azure-only", action="store_true", help="Only run Azure integration tests")
     args = parser.parse_args()
 
+    # Check that API key is set (unless we're only testing Azure)
+    if not ANTHROPIC_API_KEY and not args.azure_only:
+        print("Error: ANTHROPIC_API_KEY not set in .env file")
not set in .env file") + return + # Run tests success = await run_tests(args) sys.exit(0 if success else 1)