Skip to content

Commit 292a45d

Browse files
yroblataskbot
andauthored
Add backend routing capture to vMCP audit logs (#2983)
* Add backend routing capture to vMCP audit logs Enhance audit logging to capture which backend workload handled each MCP request (tools/call, resources/read, prompts/get). Backend routing information is now recorded in audit events as backend_name in the metadata.extra field. * changes from review --------- Co-authored-by: taskbot <[email protected]>
1 parent 6b82030 commit 292a45d

File tree

7 files changed

+998
-5
lines changed

7 files changed

+998
-5
lines changed

pkg/audit/auditor.go

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package audit
33

44
import (
55
"bytes"
6+
"context"
67
"encoding/json"
78
"io"
89
"log/slog"
@@ -21,6 +22,30 @@ import (
2122
// LevelAudit is a custom audit log level - between Info and Warn
2223
const LevelAudit = slog.Level(2)
2324

25+
// contextKey is an unexported type for context keys to avoid collisions
26+
type contextKey struct{}
27+
28+
// backendInfoKey is the context key for storing backend routing information
29+
var backendInfoKey = contextKey{}
30+
31+
// BackendInfo stores backend routing information that can be mutated by handlers.
32+
// This allows handlers deep in the call stack to provide backend info to the audit middleware.
33+
type BackendInfo struct {
34+
BackendName string
35+
}
36+
37+
// WithBackendInfo returns a new context with BackendInfo attached.
38+
func WithBackendInfo(ctx context.Context, info *BackendInfo) context.Context {
39+
return context.WithValue(ctx, backendInfoKey, info)
40+
}
41+
42+
// BackendInfoFromContext retrieves BackendInfo from the context.
43+
// Returns (nil, false) if BackendInfo is not found in the context.
44+
func BackendInfoFromContext(ctx context.Context) (*BackendInfo, bool) {
45+
info, ok := ctx.Value(backendInfoKey).(*BackendInfo)
46+
return info, ok
47+
}
48+
2449
// NewAuditLogger creates a new structured audit logger that writes to the specified writer.
2550
func NewAuditLogger(w io.Writer) *slog.Logger {
2651
if w == nil {
@@ -129,6 +154,14 @@ func (a *Auditor) Middleware(next http.Handler) http.Handler {
129154

130155
startTime := time.Now()
131156

157+
// Add BackendInfo to context if not already present
158+
// (backend enrichment middleware may have already added it)
159+
if _, ok := BackendInfoFromContext(r.Context()); !ok {
160+
backendInfo := &BackendInfo{}
161+
ctx := WithBackendInfo(r.Context(), backendInfo)
162+
r = r.WithContext(ctx)
163+
}
164+
132165
// Capture request data if configured
133166
var requestData []byte
134167
if a.config.IncludeRequestData && r.Body != nil {
@@ -194,7 +227,7 @@ func (a *Auditor) logAuditEvent(r *http.Request, rw *responseWriter, requestData
194227
}
195228

196229
// Add metadata
197-
a.addMetadata(event, duration, rw)
230+
a.addMetadata(event, r, duration, rw)
198231

199232
// Add request/response data if configured
200233
a.addEventData(event, r, rw, requestData)
@@ -398,7 +431,7 @@ func (*Auditor) extractTarget(r *http.Request, eventType string) map[string]stri
398431
}
399432

400433
// addMetadata adds metadata to the audit event.
401-
func (a *Auditor) addMetadata(event *AuditEvent, duration time.Duration, rw *responseWriter) {
434+
func (a *Auditor) addMetadata(event *AuditEvent, r *http.Request, duration time.Duration, rw *responseWriter) {
402435
if event.Metadata.Extra == nil {
403436
event.Metadata.Extra = make(map[string]any)
404437
}
@@ -417,6 +450,12 @@ func (a *Auditor) addMetadata(event *AuditEvent, duration time.Duration, rw *res
417450
if rw.body != nil {
418451
event.Metadata.Extra[MetadataExtraKeyResponseSize] = rw.body.Len()
419452
}
453+
454+
// Add backend routing information from context if available
455+
// Backend info is populated by the backend enrichment middleware
456+
if backendInfo, ok := BackendInfoFromContext(r.Context()); ok && backendInfo != nil && backendInfo.BackendName != "" {
457+
event.Metadata.Extra["backend_name"] = backendInfo.BackendName
458+
}
420459
}
421460

422461
// addEventData adds request/response data to the audit event if configured.

pkg/audit/auditor_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -517,8 +517,9 @@ func TestAddMetadata(t *testing.T) {
517517
ResponseWriter: httptest.NewRecorder(),
518518
body: bytes.NewBufferString("test response"),
519519
}
520+
req := httptest.NewRequest("GET", "/test", nil)
520521

521-
auditor.addMetadata(event, duration, rw)
522+
auditor.addMetadata(event, req, duration, rw)
522523

523524
require.NotNil(t, event.Metadata.Extra)
524525
assert.Equal(t, int64(150), event.Metadata.Extra[MetadataExtraKeyDuration])

pkg/audit/backend_info_test.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package audit
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestBackendInfoContext(t *testing.T) {
12+
t.Parallel()
13+
14+
t.Run("BackendInfo can be added and retrieved from context", func(t *testing.T) {
15+
t.Parallel()
16+
17+
// Create a BackendInfo
18+
info := &BackendInfo{
19+
BackendName: "test-backend",
20+
}
21+
22+
// Add it to context
23+
ctx := WithBackendInfo(context.Background(), info)
24+
25+
// Retrieve it
26+
retrieved, ok := BackendInfoFromContext(ctx)
27+
require.True(t, ok, "BackendInfo should be in context")
28+
require.NotNil(t, retrieved, "BackendInfo should not be nil")
29+
assert.Equal(t, "test-backend", retrieved.BackendName)
30+
31+
// Verify it's the same pointer
32+
assert.Same(t, info, retrieved, "Should be the same BackendInfo pointer")
33+
})
34+
35+
t.Run("BackendInfo can be mutated through context", func(t *testing.T) {
36+
t.Parallel()
37+
38+
// Create empty BackendInfo
39+
info := &BackendInfo{}
40+
41+
// Add to context
42+
ctx := WithBackendInfo(context.Background(), info)
43+
44+
// Retrieve and mutate
45+
retrieved, ok := BackendInfoFromContext(ctx)
46+
require.True(t, ok)
47+
retrieved.BackendName = "mutated-backend"
48+
49+
// Verify original was mutated
50+
assert.Equal(t, "mutated-backend", info.BackendName)
51+
})
52+
53+
t.Run("Missing BackendInfo returns false", func(t *testing.T) {
54+
t.Parallel()
55+
56+
ctx := context.Background()
57+
58+
retrieved, ok := BackendInfoFromContext(ctx)
59+
assert.False(t, ok, "Should return false when not in context")
60+
assert.Nil(t, retrieved, "Should return nil when not in context")
61+
})
62+
63+
t.Run("BackendInfo survives context derivation", func(t *testing.T) {
64+
t.Parallel()
65+
66+
// Create BackendInfo and add to context
67+
info := &BackendInfo{BackendName: "original"}
68+
ctx := WithBackendInfo(context.Background(), info)
69+
70+
// Derive a new context with additional value
71+
type key struct{}
72+
derivedCtx := context.WithValue(ctx, key{}, "some-value")
73+
74+
// BackendInfo should still be accessible
75+
retrieved, ok := BackendInfoFromContext(derivedCtx)
76+
require.True(t, ok, "BackendInfo should survive context derivation")
77+
assert.Equal(t, "original", retrieved.BackendName)
78+
})
79+
}
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package server
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"io"
7+
"net/http"
8+
9+
"github.com/stacklok/toolhive/pkg/audit"
10+
"github.com/stacklok/toolhive/pkg/logger"
11+
"github.com/stacklok/toolhive/pkg/vmcp"
12+
"github.com/stacklok/toolhive/pkg/vmcp/discovery"
13+
)
14+
15+
// backendEnrichmentMiddleware wraps an HTTP handler to add backend routing information
16+
// to audit events by parsing MCP requests and looking up backends in the routing table.
17+
func (*Server) backendEnrichmentMiddleware(next http.Handler) http.Handler {
18+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
19+
// Read and parse the request body to extract MCP method and parameters
20+
var requestBody []byte
21+
if r.Body != nil {
22+
var err error
23+
requestBody, err = io.ReadAll(r.Body)
24+
// Always restore body for next handler, even on error
25+
if err != nil {
26+
// Log the error and restore an empty body to ensure consistent behavior
27+
logger.Warnw("failed to read request body in backend enrichment middleware",
28+
"error", err)
29+
r.Body = io.NopCloser(bytes.NewReader([]byte{}))
30+
} else {
31+
// Restore body with the read content
32+
r.Body = io.NopCloser(bytes.NewReader(requestBody))
33+
}
34+
}
35+
36+
// Parse MCP request to extract tool/resource name
37+
var mcpRequest struct {
38+
Method string `json:"method"`
39+
Params map[string]any `json:"params"`
40+
}
41+
42+
if len(requestBody) > 0 && json.Unmarshal(requestBody, &mcpRequest) == nil {
43+
// Get routing table from discovered capabilities in context
44+
caps, ok := discovery.DiscoveredCapabilitiesFromContext(r.Context())
45+
if ok && caps != nil && caps.RoutingTable != nil {
46+
backendName := lookupBackendName(mcpRequest.Method, mcpRequest.Params, caps.RoutingTable)
47+
48+
// Mutate the existing BackendInfo from audit middleware
49+
if backendName != "" {
50+
if backendInfo, ok := audit.BackendInfoFromContext(r.Context()); ok && backendInfo != nil {
51+
backendInfo.BackendName = backendName
52+
}
53+
}
54+
}
55+
}
56+
57+
// Call next handler
58+
next.ServeHTTP(w, r)
59+
})
60+
}
61+
62+
// lookupBackendName looks up which backend handles a given MCP request.
63+
func lookupBackendName(method string, params map[string]any, routingTable *vmcp.RoutingTable) string {
64+
switch method {
65+
case "tools/call":
66+
if toolName, ok := params["name"].(string); ok {
67+
if target, exists := routingTable.Tools[toolName]; exists {
68+
return target.WorkloadName
69+
}
70+
}
71+
case "resources/read":
72+
if uri, ok := params["uri"].(string); ok {
73+
if target, exists := routingTable.Resources[uri]; exists {
74+
return target.WorkloadName
75+
}
76+
}
77+
case "prompts/get":
78+
if promptName, ok := params["name"].(string); ok {
79+
if target, exists := routingTable.Prompts[promptName]; exists {
80+
return target.WorkloadName
81+
}
82+
}
83+
}
84+
return ""
85+
}

0 commit comments

Comments
 (0)