Skip to content

Commit b404eef

Browse files
authored
fix(go/plugins/compat_oai): add support for custom providers (#3617)
1 parent 7fff791 commit b404eef

File tree

4 files changed

+78
-9
lines changed

4 files changed

+78
-9
lines changed

go/plugins/compat_oai/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# OpenAI-Compatible Plugin Package
22

3-
This directory contains a package for building plugins that are compatible with the OpenAI API specification, along with plugins built on top of this package.
3+
This directory contains a package for building plugins that are compatible with the OpenAI API specification, along with plugins built on top of this package.
44

55
## Package Overview
66

@@ -74,4 +74,4 @@ go test -v ./openai
7474
go test -v ./anthropic
7575
```
7676

77-
Note: Tests will be skipped if the required API keys are not set.
77+
Note: Tests will be skipped if the required API keys are not set.

go/plugins/compat_oai/compat_oai.go

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ package compat_oai
1717
import (
1818
"context"
1919
"fmt"
20-
"strings"
2120
"sync"
2221

2322
"github.com/firebase/genkit/go/ai"
@@ -68,6 +67,14 @@ type OpenAICompatible struct {
6867
// This will be used as a prefix for model names (e.g., "myprovider/model-name").
6968
// Should be lowercase and match the plugin's Name() method.
7069
Provider string
70+
71+
// API key to use with the desired plugin.
72+
APIKey string
73+
74+
// Base URL to use for custom endpoints.
75+
// This should be used if you are running through a proxy or
76+
// using a non-official endpoint
77+
BaseURL string
7178
}
7279

7380
// Init implements genkit.Plugin.
@@ -78,6 +85,14 @@ func (o *OpenAICompatible) Init(ctx context.Context) []api.Action {
7885
panic("compat_oai.Init already called")
7986
}
8087

88+
if o.APIKey != "" {
89+
o.Opts = append([]option.RequestOption{option.WithAPIKey(o.APIKey)}, o.Opts...)
90+
}
91+
92+
if o.BaseURL != "" {
93+
o.Opts = append([]option.RequestOption{option.WithBaseURL(o.BaseURL)}, o.Opts...)
94+
}
95+
8196
// create client
8297
client := openai.NewClient(o.Opts...)
8398
o.client = &client
@@ -99,16 +114,13 @@ func (o *OpenAICompatible) DefineModel(provider, id string, opts ai.ModelOptions
99114
panic("OpenAICompatible.Init not called")
100115
}
101116

102-
// Strip provider prefix if present to check against supportedModels
103-
modelName := strings.TrimPrefix(id, provider+"/")
104-
105117
return ai.NewModel(api.NewName(provider, id), &opts, func(
106118
ctx context.Context,
107119
input *ai.ModelRequest,
108120
cb func(context.Context, *ai.ModelResponseChunk) error,
109121
) (*ai.ModelResponse, error) {
110122
// Configure the response generator with input
111-
generator := NewModelGenerator(o.client, modelName).WithMessages(input.Messages).WithConfig(input.Config).WithTools(input.Tools)
123+
generator := NewModelGenerator(o.client, id).WithMessages(input.Messages).WithConfig(input.Config).WithTools(input.Tools)
112124

113125
// Generate response
114126
resp, err := generator.Generate(ctx, input, cb)
@@ -197,7 +209,7 @@ func (o *OpenAICompatible) ListActions(ctx context.Context) []api.ActionDesc {
197209
"systemRole": true,
198210
"tools": true,
199211
"toolChoice": true,
200-
"constrained": true,
212+
"constrained": "all",
201213
},
202214
},
203215
"versions": []string{},

go/plugins/compat_oai/generate.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ func (g *ModelGenerator) WithConfig(config any) *ModelGenerator {
164164
openaiConfig = *cfg
165165
case map[string]any:
166166
if err := mapToStruct(cfg, &openaiConfig); err != nil {
167-
g.err = fmt.Errorf("failed to convert config to OpenAIConfig: %w", err)
167+
g.err = fmt.Errorf("failed to convert config to openai.ChatCompletionNewParams: %w", err)
168168
return g
169169
}
170170
default:
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package main
16+
17+
import (
18+
"context"
19+
"log"
20+
"os"
21+
22+
"github.com/firebase/genkit/go/ai"
23+
"github.com/firebase/genkit/go/genkit"
24+
25+
oai "github.com/firebase/genkit/go/plugins/compat_oai"
26+
"github.com/openai/openai-go"
27+
)
28+
29+
func main() {
30+
ctx := context.Background()
31+
apiKey := os.Getenv("OPENROUTER_API_KEY")
32+
if apiKey == "" {
33+
log.Fatalf("OPENROUTER_API_KEY environment variable not set")
34+
}
35+
36+
g := genkit.Init(ctx, genkit.WithPlugins(&oai.OpenAICompatible{
37+
Provider: "openrouter",
38+
APIKey: apiKey,
39+
BaseURL: "https://openrouter.ai/api/v1",
40+
}),
41+
genkit.WithDefaultModel("openrouter/tngtech/deepseek-r1t2-chimera:free"))
42+
43+
prompt := "tell me a joke"
44+
config := &openai.ChatCompletionNewParams{
45+
Temperature: openai.Float(0.7),
46+
MaxTokens: openai.Int(1000),
47+
TopP: openai.Float(0.9),
48+
}
49+
50+
resp, err := genkit.Generate(context.Background(), g,
51+
ai.WithConfig(config),
52+
ai.WithPrompt(prompt))
53+
if err != nil {
54+
log.Fatalf("failed to generate contents: %v", err)
55+
}
56+
log.Println("Joke:", resp.Text())
57+
}

0 commit comments

Comments
 (0)