@@ -8,13 +8,16 @@ import (
88 "io"
99 "net"
1010 "net/http"
11+ "os"
12+ "strings"
1113 "testing"
1214 "time"
1315
1416 "github.com/stretchr/testify/assert"
1517 "github.com/stretchr/testify/require"
1618 "go.uber.org/mock/gomock"
1719
20+ "github.com/stacklok/toolhive/pkg/audit"
1821 "github.com/stacklok/toolhive/pkg/auth"
1922 "github.com/stacklok/toolhive/pkg/vmcp"
2023 "github.com/stacklok/toolhive/pkg/vmcp/aggregator"
@@ -540,3 +543,306 @@ func TestIntegration_ConflictResolutionStrategies(t *testing.T) {
540543 assert .Equal (t , "backend1" , result .Tools [0 ].BackendID )
541544 })
542545}
546+
547+ // TestIntegration_AuditLogging tests that the vMCP server logs MCP operations
548+ // when audit middleware is enabled.
549+ // Note: This test does not use t.Parallel() because subtests share the same
550+ // server instance and audit log file, and must run sequentially.
551+ //
552+ //nolint:paralleltest // Subtests must run sequentially as they share server state
553+ func TestIntegration_AuditLogging (t * testing.T ) {
554+ ctrl := gomock .NewController (t )
555+ t .Cleanup (ctrl .Finish )
556+
557+ ctx := context .Background ()
558+
559+ // Create temp file for audit logs
560+ auditLogFile , err := os .CreateTemp ("" , "vmcp-audit-test-*.log" )
561+ require .NoError (t , err )
562+ auditLogPath := auditLogFile .Name ()
563+ auditLogFile .Close ()
564+ t .Cleanup (func () {
565+ os .Remove (auditLogPath )
566+ })
567+
568+ // Create audit config that writes to temp file
569+ auditConfig := & audit.Config {
570+ Component : "vmcp-server-test" ,
571+ IncludeRequestData : true ,
572+ IncludeResponseData : false ,
573+ MaxDataSize : 2048 ,
574+ LogFile : auditLogPath ,
575+ }
576+
577+ // Create mock backend client
578+ mockBackendClient := mocks .NewMockBackendClient (ctrl )
579+
580+ // Define backend capabilities
581+ backendCapabilities := & vmcp.CapabilityList {
582+ Tools : []vmcp.Tool {
583+ {
584+ Name : "get_weather" ,
585+ Description : "Get weather information" ,
586+ InputSchema : map [string ]any {
587+ "type" : "object" ,
588+ "properties" : map [string ]any {
589+ "location" : map [string ]any {"type" : "string" },
590+ },
591+ },
592+ BackendID : "weather-service" ,
593+ },
594+ },
595+ Resources : []vmcp.Resource {
596+ {
597+ URI : "weather://current" ,
598+ Name : "Current Weather" ,
599+ Description : "Current weather data" ,
600+ MimeType : "application/json" ,
601+ BackendID : "weather-service" ,
602+ },
603+ },
604+ Prompts : []vmcp.Prompt {
605+ {
606+ Name : "weather_summary" ,
607+ Description : "Generate weather summary" ,
608+ Arguments : []vmcp.PromptArgument {},
609+ BackendID : "weather-service" ,
610+ },
611+ },
612+ }
613+
614+ // Mock backend responses
615+ mockBackendClient .EXPECT ().
616+ ListCapabilities (gomock .Any (), gomock .Any ()).
617+ Return (backendCapabilities , nil ).
618+ AnyTimes ()
619+
620+ mockBackendClient .EXPECT ().
621+ CallTool (gomock .Any (), gomock .Any (), gomock .Any (), gomock .Any ()).
622+ Return (map [string ]any {
623+ "result" : "Sunny, 72°F" ,
624+ }, nil ).
625+ AnyTimes ()
626+
627+ mockBackendClient .EXPECT ().
628+ ReadResource (gomock .Any (), gomock .Any (), gomock .Any ()).
629+ Return ([]byte (`{"temp": 72, "condition": "sunny"}` ), nil ).
630+ AnyTimes ()
631+
632+ // Create backends
633+ backends := []vmcp.Backend {
634+ {
635+ ID : "weather-service" ,
636+ Name : "Weather Service" ,
637+ },
638+ }
639+
640+ // Create router
641+ rt := router .NewDefaultRouter ()
642+
643+ // Create discovery manager
644+ mockDiscoveryMgr := discoveryMocks .NewMockManager (ctrl )
645+ mockDiscoveryMgr .EXPECT ().
646+ Discover (gomock .Any (), gomock .Any ()).
647+ DoAndReturn (func (_ context.Context , _ []vmcp.Backend ) (* aggregator.AggregatedCapabilities , error ) {
648+ resolver := aggregator .NewPrefixConflictResolver ("{workload}_" )
649+ agg := aggregator .NewDefaultAggregator (mockBackendClient , resolver , nil )
650+ return agg .AggregateCapabilities (ctx , backends )
651+ }).
652+ AnyTimes ()
653+ mockDiscoveryMgr .EXPECT ().Stop ().AnyTimes ()
654+
655+ // Helper function to read audit log file
656+ readAuditLog := func () string {
657+ data , err := os .ReadFile (auditLogPath )
658+ if err != nil {
659+ return ""
660+ }
661+ return string (data )
662+ }
663+
664+ // Create server with audit config
665+ srv , err := server .New (ctx , & server.Config {
666+ Host : "127.0.0.1" ,
667+ Port : 0 , // Random port
668+ AuditConfig : auditConfig ,
669+ }, rt , mockBackendClient , mockDiscoveryMgr , backends , nil )
670+ require .NoError (t , err )
671+
672+ // Start server
673+ serverCtx , cancelServer := context .WithCancel (ctx )
674+ t .Cleanup (cancelServer )
675+
676+ serverErrCh := make (chan error , 1 )
677+ go func () {
678+ if err := srv .Start (serverCtx ); err != nil && err != context .Canceled {
679+ serverErrCh <- err
680+ }
681+ }()
682+
683+ // Wait for server ready
684+ select {
685+ case <- srv .Ready ():
686+ case err := <- serverErrCh :
687+ t .Fatalf ("Server failed to start: %v" , err )
688+ case <- time .After (5 * time .Second ):
689+ t .Fatal ("Server timeout waiting for ready" )
690+ }
691+
692+ baseURL := "http://" + srv .Address ()
693+
694+ // Test 1: Initialize request should be logged
695+ t .Run ("initialize request is logged" , func (t * testing.T ) {
696+ initReq := map [string ]any {
697+ "method" : "initialize" ,
698+ "params" : map [string ]any {
699+ "protocolVersion" : "2024-11-05" ,
700+ "capabilities" : map [string ]any {},
701+ "clientInfo" : map [string ]any {
702+ "name" : "audit-test-client" ,
703+ "version" : "1.0.0" ,
704+ },
705+ },
706+ }
707+
708+ reqBody , err := json .Marshal (initReq )
709+ require .NoError (t , err )
710+
711+ resp , err := http .Post (baseURL + "/mcp" , "application/json" , bytes .NewReader (reqBody ))
712+ require .NoError (t , err )
713+ defer resp .Body .Close ()
714+
715+ require .Equal (t , http .StatusOK , resp .StatusCode )
716+
717+ // Wait for audit event to be written
718+ time .Sleep (500 * time .Millisecond )
719+
720+ // Verify audit log contains initialize event
721+ auditLog := readAuditLog ()
722+ assert .Contains (t , auditLog , "vmcp-server-test" , "Should contain component name" )
723+ assert .Contains (t , auditLog , "\" method\" :\" initialize\" " , "Should log initialize method in request data" )
724+ assert .Contains (t , auditLog , "audit-test-client" , "Should capture client name" )
725+ })
726+
727+ // Test 2: Tool list request should be logged
728+ t .Run ("tools/list request is logged" , func (t * testing.T ) {
729+
730+ toolsReq := map [string ]any {
731+ "method" : "tools/list" ,
732+ }
733+
734+ reqBody , err := json .Marshal (toolsReq )
735+ require .NoError (t , err )
736+
737+ req , err := http .NewRequest ("POST" , baseURL + "/mcp" , bytes .NewReader (reqBody ))
738+ require .NoError (t , err )
739+ req .Header .Set ("Content-Type" , "application/json" )
740+
741+ resp , err := http .DefaultClient .Do (req )
742+ require .NoError (t , err )
743+ defer resp .Body .Close ()
744+
745+ // Wait for audit event
746+ time .Sleep (500 * time .Millisecond )
747+
748+ auditLog := readAuditLog ()
749+ assert .Contains (t , auditLog , "\" method\" :\" tools/list\" " , "Should log tools/list method in request data" )
750+ assert .Contains (t , auditLog , "vmcp-server-test" , "Should contain component name" )
751+ })
752+
753+ // Test 3: Tool call should be logged
754+ t .Run ("tool call is logged" , func (t * testing.T ) {
755+
756+ toolCallReq := map [string ]any {
757+ "method" : "tools/call" ,
758+ "params" : map [string ]any {
759+ "name" : "get_weather" ,
760+ "arguments" : map [string ]any {
761+ "location" : "San Francisco" ,
762+ },
763+ },
764+ }
765+
766+ reqBody , err := json .Marshal (toolCallReq )
767+ require .NoError (t , err )
768+
769+ req , err := http .NewRequest ("POST" , baseURL + "/mcp" , bytes .NewReader (reqBody ))
770+ require .NoError (t , err )
771+ req .Header .Set ("Content-Type" , "application/json" )
772+
773+ resp , err := http .DefaultClient .Do (req )
774+ require .NoError (t , err )
775+ defer resp .Body .Close ()
776+
777+ // Wait for audit event
778+ time .Sleep (500 * time .Millisecond )
779+
780+ auditLog := readAuditLog ()
781+ assert .Contains (t , auditLog , "\" method\" :\" tools/call\" " , "Should log tools/call method in request data" )
782+ assert .Contains (t , auditLog , "get_weather" , "Should capture tool name in request data" )
783+ assert .Contains (t , auditLog , "San Francisco" , "Should capture tool arguments in request data" )
784+ assert .Contains (t , auditLog , "vmcp-server-test" , "Should contain component name" )
785+ })
786+
787+ // Test 4: Resource read should be logged
788+ t .Run ("resource read is logged" , func (t * testing.T ) {
789+
790+ resourceReq := map [string ]any {
791+ "method" : "resources/read" ,
792+ "params" : map [string ]any {
793+ "uri" : "weather://current" ,
794+ },
795+ }
796+
797+ reqBody , err := json .Marshal (resourceReq )
798+ require .NoError (t , err )
799+
800+ req , err := http .NewRequest ("POST" , baseURL + "/mcp" , bytes .NewReader (reqBody ))
801+ require .NoError (t , err )
802+ req .Header .Set ("Content-Type" , "application/json" )
803+
804+ resp , err := http .DefaultClient .Do (req )
805+ require .NoError (t , err )
806+ defer resp .Body .Close ()
807+
808+ // Wait for audit event
809+ time .Sleep (500 * time .Millisecond )
810+
811+ auditLog := readAuditLog ()
812+ assert .Contains (t , auditLog , "\" method\" :\" resources/read\" " , "Should log resources/read method in request data" )
813+ assert .Contains (t , auditLog , "weather://current" , "Should capture resource URI in request data" )
814+ assert .Contains (t , auditLog , "vmcp-server-test" , "Should contain component name" )
815+ })
816+
817+ // Test 5: Verify audit events have required fields
818+ t .Run ("audit events contain required fields" , func (t * testing.T ) {
819+ // Get all audit logs
820+ auditLog := readAuditLog ()
821+
822+ // Split into individual log lines
823+ lines := strings .Split (strings .TrimSpace (auditLog ), "\n " )
824+ require .Greater (t , len (lines ), 0 , "Should have at least one audit event" )
825+
826+ // Parse first audit event
827+ var auditEvent map [string ]any
828+ err := json .Unmarshal ([]byte (lines [0 ]), & auditEvent )
829+ require .NoError (t , err , "Audit log should be valid JSON" )
830+
831+ // Verify required fields
832+ assert .Contains (t , auditEvent , "audit_id" , "Should have audit_id" )
833+ assert .Contains (t , auditEvent , "type" , "Should have type" )
834+ assert .Contains (t , auditEvent , "logged_at" , "Should have logged_at" )
835+ assert .Contains (t , auditEvent , "outcome" , "Should have outcome" )
836+ assert .Contains (t , auditEvent , "component" , "Should have component" )
837+ assert .Contains (t , auditEvent , "source" , "Should have source" )
838+
839+ // Verify component value
840+ assert .Equal (t , "vmcp-server-test" , auditEvent ["component" ])
841+
842+ // Verify source has network information
843+ source , ok := auditEvent ["source" ].(map [string ]any )
844+ require .True (t , ok , "Source should be an object" )
845+ assert .Equal (t , "network" , source ["type" ])
846+ assert .Contains (t , source , "value" , "Source should have IP address" )
847+ })
848+ }
0 commit comments