Skip to content

Commit 1976761

Browse files
yroblataskbot
andauthored
Add audit middleware to vMCP server (#2981)
* Add audit middleware to vMCP server Integrate the existing ToolHive audit middleware into the Virtual MCP Server to enable audit logging of MCP operations for security, compliance, and debugging purposes. Changes: - Add AuditConfig field to vMCP server Config struct - Add audit middleware initialization in Start() method - Position audit middleware in chain: Auth → Audit → Discovery → Telemetry - Use "streamable-http" transport type for vMCP audit events - Add unit tests for audit configuration handling The audit middleware is optional (enabled when AuditConfig is provided) and uses the existing pkg/audit infrastructure with MCP-specific event types (mcp_tool_call, mcp_initialize, etc.). Partially-closes: #2980 * changes from review --------- Co-authored-by: taskbot <[email protected]>
1 parent fec40b2 commit 1976761

File tree

3 files changed

+388
-2
lines changed

3 files changed

+388
-2
lines changed

pkg/vmcp/server/integration_test.go

Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,16 @@ import (
88
"io"
99
"net"
1010
"net/http"
11+
"os"
12+
"strings"
1113
"testing"
1214
"time"
1315

1416
"github.com/stretchr/testify/assert"
1517
"github.com/stretchr/testify/require"
1618
"go.uber.org/mock/gomock"
1719

20+
"github.com/stacklok/toolhive/pkg/audit"
1821
"github.com/stacklok/toolhive/pkg/auth"
1922
"github.com/stacklok/toolhive/pkg/vmcp"
2023
"github.com/stacklok/toolhive/pkg/vmcp/aggregator"
@@ -540,3 +543,306 @@ func TestIntegration_ConflictResolutionStrategies(t *testing.T) {
540543
assert.Equal(t, "backend1", result.Tools[0].BackendID)
541544
})
542545
}
546+
547+
// TestIntegration_AuditLogging tests that the vMCP server logs MCP operations
548+
// when audit middleware is enabled.
549+
// Note: This test does not use t.Parallel() because subtests share the same
550+
// server instance and audit log file, and must run sequentially.
551+
//
552+
//nolint:paralleltest // Subtests must run sequentially as they share server state
553+
func TestIntegration_AuditLogging(t *testing.T) {
554+
ctrl := gomock.NewController(t)
555+
t.Cleanup(ctrl.Finish)
556+
557+
ctx := context.Background()
558+
559+
// Create temp file for audit logs
560+
auditLogFile, err := os.CreateTemp("", "vmcp-audit-test-*.log")
561+
require.NoError(t, err)
562+
auditLogPath := auditLogFile.Name()
563+
auditLogFile.Close()
564+
t.Cleanup(func() {
565+
os.Remove(auditLogPath)
566+
})
567+
568+
// Create audit config that writes to temp file
569+
auditConfig := &audit.Config{
570+
Component: "vmcp-server-test",
571+
IncludeRequestData: true,
572+
IncludeResponseData: false,
573+
MaxDataSize: 2048,
574+
LogFile: auditLogPath,
575+
}
576+
577+
// Create mock backend client
578+
mockBackendClient := mocks.NewMockBackendClient(ctrl)
579+
580+
// Define backend capabilities
581+
backendCapabilities := &vmcp.CapabilityList{
582+
Tools: []vmcp.Tool{
583+
{
584+
Name: "get_weather",
585+
Description: "Get weather information",
586+
InputSchema: map[string]any{
587+
"type": "object",
588+
"properties": map[string]any{
589+
"location": map[string]any{"type": "string"},
590+
},
591+
},
592+
BackendID: "weather-service",
593+
},
594+
},
595+
Resources: []vmcp.Resource{
596+
{
597+
URI: "weather://current",
598+
Name: "Current Weather",
599+
Description: "Current weather data",
600+
MimeType: "application/json",
601+
BackendID: "weather-service",
602+
},
603+
},
604+
Prompts: []vmcp.Prompt{
605+
{
606+
Name: "weather_summary",
607+
Description: "Generate weather summary",
608+
Arguments: []vmcp.PromptArgument{},
609+
BackendID: "weather-service",
610+
},
611+
},
612+
}
613+
614+
// Mock backend responses
615+
mockBackendClient.EXPECT().
616+
ListCapabilities(gomock.Any(), gomock.Any()).
617+
Return(backendCapabilities, nil).
618+
AnyTimes()
619+
620+
mockBackendClient.EXPECT().
621+
CallTool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
622+
Return(map[string]any{
623+
"result": "Sunny, 72°F",
624+
}, nil).
625+
AnyTimes()
626+
627+
mockBackendClient.EXPECT().
628+
ReadResource(gomock.Any(), gomock.Any(), gomock.Any()).
629+
Return([]byte(`{"temp": 72, "condition": "sunny"}`), nil).
630+
AnyTimes()
631+
632+
// Create backends
633+
backends := []vmcp.Backend{
634+
{
635+
ID: "weather-service",
636+
Name: "Weather Service",
637+
},
638+
}
639+
640+
// Create router
641+
rt := router.NewDefaultRouter()
642+
643+
// Create discovery manager
644+
mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl)
645+
mockDiscoveryMgr.EXPECT().
646+
Discover(gomock.Any(), gomock.Any()).
647+
DoAndReturn(func(_ context.Context, _ []vmcp.Backend) (*aggregator.AggregatedCapabilities, error) {
648+
resolver := aggregator.NewPrefixConflictResolver("{workload}_")
649+
agg := aggregator.NewDefaultAggregator(mockBackendClient, resolver, nil)
650+
return agg.AggregateCapabilities(ctx, backends)
651+
}).
652+
AnyTimes()
653+
mockDiscoveryMgr.EXPECT().Stop().AnyTimes()
654+
655+
// Helper function to read audit log file
656+
readAuditLog := func() string {
657+
data, err := os.ReadFile(auditLogPath)
658+
if err != nil {
659+
return ""
660+
}
661+
return string(data)
662+
}
663+
664+
// Create server with audit config
665+
srv, err := server.New(ctx, &server.Config{
666+
Host: "127.0.0.1",
667+
Port: 0, // Random port
668+
AuditConfig: auditConfig,
669+
}, rt, mockBackendClient, mockDiscoveryMgr, backends, nil)
670+
require.NoError(t, err)
671+
672+
// Start server
673+
serverCtx, cancelServer := context.WithCancel(ctx)
674+
t.Cleanup(cancelServer)
675+
676+
serverErrCh := make(chan error, 1)
677+
go func() {
678+
if err := srv.Start(serverCtx); err != nil && err != context.Canceled {
679+
serverErrCh <- err
680+
}
681+
}()
682+
683+
// Wait for server ready
684+
select {
685+
case <-srv.Ready():
686+
case err := <-serverErrCh:
687+
t.Fatalf("Server failed to start: %v", err)
688+
case <-time.After(5 * time.Second):
689+
t.Fatal("Server timeout waiting for ready")
690+
}
691+
692+
baseURL := "http://" + srv.Address()
693+
694+
// Test 1: Initialize request should be logged
695+
t.Run("initialize request is logged", func(t *testing.T) {
696+
initReq := map[string]any{
697+
"method": "initialize",
698+
"params": map[string]any{
699+
"protocolVersion": "2024-11-05",
700+
"capabilities": map[string]any{},
701+
"clientInfo": map[string]any{
702+
"name": "audit-test-client",
703+
"version": "1.0.0",
704+
},
705+
},
706+
}
707+
708+
reqBody, err := json.Marshal(initReq)
709+
require.NoError(t, err)
710+
711+
resp, err := http.Post(baseURL+"/mcp", "application/json", bytes.NewReader(reqBody))
712+
require.NoError(t, err)
713+
defer resp.Body.Close()
714+
715+
require.Equal(t, http.StatusOK, resp.StatusCode)
716+
717+
// Wait for audit event to be written
718+
time.Sleep(500 * time.Millisecond)
719+
720+
// Verify audit log contains initialize event
721+
auditLog := readAuditLog()
722+
assert.Contains(t, auditLog, "vmcp-server-test", "Should contain component name")
723+
assert.Contains(t, auditLog, "\"method\":\"initialize\"", "Should log initialize method in request data")
724+
assert.Contains(t, auditLog, "audit-test-client", "Should capture client name")
725+
})
726+
727+
// Test 2: Tool list request should be logged
728+
t.Run("tools/list request is logged", func(t *testing.T) {
729+
730+
toolsReq := map[string]any{
731+
"method": "tools/list",
732+
}
733+
734+
reqBody, err := json.Marshal(toolsReq)
735+
require.NoError(t, err)
736+
737+
req, err := http.NewRequest("POST", baseURL+"/mcp", bytes.NewReader(reqBody))
738+
require.NoError(t, err)
739+
req.Header.Set("Content-Type", "application/json")
740+
741+
resp, err := http.DefaultClient.Do(req)
742+
require.NoError(t, err)
743+
defer resp.Body.Close()
744+
745+
// Wait for audit event
746+
time.Sleep(500 * time.Millisecond)
747+
748+
auditLog := readAuditLog()
749+
assert.Contains(t, auditLog, "\"method\":\"tools/list\"", "Should log tools/list method in request data")
750+
assert.Contains(t, auditLog, "vmcp-server-test", "Should contain component name")
751+
})
752+
753+
// Test 3: Tool call should be logged
754+
t.Run("tool call is logged", func(t *testing.T) {
755+
756+
toolCallReq := map[string]any{
757+
"method": "tools/call",
758+
"params": map[string]any{
759+
"name": "get_weather",
760+
"arguments": map[string]any{
761+
"location": "San Francisco",
762+
},
763+
},
764+
}
765+
766+
reqBody, err := json.Marshal(toolCallReq)
767+
require.NoError(t, err)
768+
769+
req, err := http.NewRequest("POST", baseURL+"/mcp", bytes.NewReader(reqBody))
770+
require.NoError(t, err)
771+
req.Header.Set("Content-Type", "application/json")
772+
773+
resp, err := http.DefaultClient.Do(req)
774+
require.NoError(t, err)
775+
defer resp.Body.Close()
776+
777+
// Wait for audit event
778+
time.Sleep(500 * time.Millisecond)
779+
780+
auditLog := readAuditLog()
781+
assert.Contains(t, auditLog, "\"method\":\"tools/call\"", "Should log tools/call method in request data")
782+
assert.Contains(t, auditLog, "get_weather", "Should capture tool name in request data")
783+
assert.Contains(t, auditLog, "San Francisco", "Should capture tool arguments in request data")
784+
assert.Contains(t, auditLog, "vmcp-server-test", "Should contain component name")
785+
})
786+
787+
// Test 4: Resource read should be logged
788+
t.Run("resource read is logged", func(t *testing.T) {
789+
790+
resourceReq := map[string]any{
791+
"method": "resources/read",
792+
"params": map[string]any{
793+
"uri": "weather://current",
794+
},
795+
}
796+
797+
reqBody, err := json.Marshal(resourceReq)
798+
require.NoError(t, err)
799+
800+
req, err := http.NewRequest("POST", baseURL+"/mcp", bytes.NewReader(reqBody))
801+
require.NoError(t, err)
802+
req.Header.Set("Content-Type", "application/json")
803+
804+
resp, err := http.DefaultClient.Do(req)
805+
require.NoError(t, err)
806+
defer resp.Body.Close()
807+
808+
// Wait for audit event
809+
time.Sleep(500 * time.Millisecond)
810+
811+
auditLog := readAuditLog()
812+
assert.Contains(t, auditLog, "\"method\":\"resources/read\"", "Should log resources/read method in request data")
813+
assert.Contains(t, auditLog, "weather://current", "Should capture resource URI in request data")
814+
assert.Contains(t, auditLog, "vmcp-server-test", "Should contain component name")
815+
})
816+
817+
// Test 5: Verify audit events have required fields
818+
t.Run("audit events contain required fields", func(t *testing.T) {
819+
// Get all audit logs
820+
auditLog := readAuditLog()
821+
822+
// Split into individual log lines
823+
lines := strings.Split(strings.TrimSpace(auditLog), "\n")
824+
require.Greater(t, len(lines), 0, "Should have at least one audit event")
825+
826+
// Parse first audit event
827+
var auditEvent map[string]any
828+
err := json.Unmarshal([]byte(lines[0]), &auditEvent)
829+
require.NoError(t, err, "Audit log should be valid JSON")
830+
831+
// Verify required fields
832+
assert.Contains(t, auditEvent, "audit_id", "Should have audit_id")
833+
assert.Contains(t, auditEvent, "type", "Should have type")
834+
assert.Contains(t, auditEvent, "logged_at", "Should have logged_at")
835+
assert.Contains(t, auditEvent, "outcome", "Should have outcome")
836+
assert.Contains(t, auditEvent, "component", "Should have component")
837+
assert.Contains(t, auditEvent, "source", "Should have source")
838+
839+
// Verify component value
840+
assert.Equal(t, "vmcp-server-test", auditEvent["component"])
841+
842+
// Verify source has network information
843+
source, ok := auditEvent["source"].(map[string]any)
844+
require.True(t, ok, "Source should be an object")
845+
assert.Equal(t, "network", source["type"])
846+
assert.Contains(t, source, "value", "Source should have IP address")
847+
})
848+
}

pkg/vmcp/server/server.go

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717

1818
"github.com/mark3labs/mcp-go/server"
1919

20+
"github.com/stacklok/toolhive/pkg/audit"
2021
"github.com/stacklok/toolhive/pkg/auth"
2122
"github.com/stacklok/toolhive/pkg/logger"
2223
"github.com/stacklok/toolhive/pkg/telemetry"
@@ -87,6 +88,11 @@ type Config struct {
8788
// TelemetryProvider is the optional telemetry provider.
8889
// If nil, no telemetry is recorded.
8990
TelemetryProvider *telemetry.Provider
91+
92+
// AuditConfig is the optional audit configuration.
93+
// If nil, no audit logging is performed.
94+
// Component should be set to "vmcp-server" to distinguish vMCP audit logs.
95+
AuditConfig *audit.Config
9096
}
9197

9298
// Server is the Virtual MCP Server that aggregates multiple backends.
@@ -386,20 +392,33 @@ func (s *Server) Start(ctx context.Context) error {
386392
logger.Info("RFC 9728 OAuth discovery endpoints enabled at /.well-known/")
387393
}
388394

389-
// MCP endpoint - apply middleware chain: auth → discovery → telemetry
395+
// MCP endpoint - apply middleware chain: auth → audit → discovery → telemetry
390396
var mcpHandler http.Handler = streamableServer
391397

392398
if s.config.TelemetryProvider != nil {
393399
mcpHandler = s.config.TelemetryProvider.Middleware(s.config.Name, "streamable-http")(mcpHandler)
394400
logger.Info("Telemetry middleware enabled for MCP endpoints")
395401
}
396402

397-
// Apply discovery middleware (runs after auth middleware)
403+
// Apply discovery middleware (runs after audit/auth middleware)
398404
// Discovery middleware performs per-request capability aggregation with user context
399405
// Pass sessionManager to enable session-based capability retrieval for subsequent requests
400406
mcpHandler = discovery.Middleware(s.discoveryMgr, s.backends, s.sessionManager)(mcpHandler)
401407
logger.Info("Discovery middleware enabled for lazy per-user capability discovery")
402408

409+
// Apply audit middleware if configured (runs after auth, before discovery)
410+
if s.config.AuditConfig != nil {
411+
auditor, err := audit.NewAuditorWithTransport(
412+
s.config.AuditConfig,
413+
"streamable-http", // vMCP uses streamable HTTP transport
414+
)
415+
if err != nil {
416+
return fmt.Errorf("failed to create auditor: %w", err)
417+
}
418+
mcpHandler = auditor.Middleware(mcpHandler)
419+
logger.Info("Audit middleware enabled for MCP endpoints")
420+
}
421+
403422
// Apply authentication middleware if configured (runs first in chain)
404423
if s.config.AuthMiddleware != nil {
405424
mcpHandler = s.config.AuthMiddleware(mcpHandler)

0 commit comments

Comments
 (0)