Sfoglia il codice sorgente

fix(mcp): Implement graceful shutdown for MCP HTTP server (#3494)

Signed-off-by: Tushar Verma <tusharmyself06@gmail.com>
Co-authored-by: Alex Meijer <ameijer@users.noreply.github.com>
Tushar-Verma 4 mesi fa
parent
commit
296dd00b37
2 ha cambiato i file con 101 aggiunte e 1 eliminazioni
  1. 21 1
      pkg/cmd/costmodel/costmodel.go
  2. 80 0
      pkg/cmd/costmodel/costmodel_test.go

+ 21 - 1
pkg/cmd/costmodel/costmodel.go

@@ -4,6 +4,9 @@ import (
 	"context"
 	"fmt"
 	"net/http"
+	"os"
+	"os/signal"
+	"syscall"
 	"time"
 
 	"github.com/julienschmidt/httprouter"
@@ -33,6 +36,10 @@ func Execute(conf *Config) error {
 	}
 	conf.log()
 
+	// Create cancellable context for graceful shutdown
+	ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
+	defer cancel()
+
 	router := httprouter.New()
 	var a *costmodel.Accesses
 	var cp models.Provider
@@ -81,7 +88,7 @@ func Execute(conf *Config) error {
 			cloudCostQuerier = cloudCostPipelineService.GetCloudCostQuerier()
 		}
 
-		err := StartMCPServer(context.Background(), a, cloudCostQuerier)
+		err := StartMCPServer(ctx, a, cloudCostQuerier)
 		if err != nil {
 			log.Errorf("Failed to start MCP server: %v", err)
 		}
@@ -309,6 +316,19 @@ func StartMCPServer(ctx context.Context, accesses *costmodel.Accesses, cloudCost
 		}
 	}()
 
+	// Graceful shutdown goroutine
+	go func() {
+		<-ctx.Done()
+		log.Info("Shutting down MCP server...")
+		shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+		defer cancel()
+		if err := server.Shutdown(shutdownCtx); err != nil {
+			log.Errorf("MCP server shutdown error: %v", err)
+		} else {
+			log.Info("MCP server shut down successfully")
+		}
+	}()
+
 	log.Info("MCP server started successfully")
 	return nil
 }

+ 80 - 0
pkg/cmd/costmodel/costmodel_test.go

@@ -0,0 +1,80 @@
+package costmodel
+
+import (
+	"context"
+	"fmt"
+	"net/http"
+	"testing"
+	"time"
+
+	"github.com/opencost/opencost/pkg/costmodel"
+	"github.com/opencost/opencost/pkg/env"
+)
+
+func TestMCPServerGracefulShutdown(t *testing.T) {
+	// Test that MCP server responds to context cancellation and shuts down gracefully
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	accesses := &costmodel.Accesses{}
+	port := env.GetMCPHTTPPort()
+
+	// Channel to signal when server is ready
+	serverReady := make(chan error, 1)
+
+	// Start MCP server
+	go func() {
+		err := StartMCPServer(ctx, accesses, nil)
+		serverReady <- err
+	}()
+
+	// Wait for server to be ready by attempting to connect
+	serverUp := false
+	for i := 0; i < 10; i++ {
+		time.Sleep(100 * time.Millisecond)
+		client := &http.Client{Timeout: 1 * time.Second}
+		resp, err := client.Get(fmt.Sprintf("http://localhost:%d/", port))
+		if err == nil {
+			resp.Body.Close()
+			serverUp = true
+			break
+		}
+	}
+
+	if !serverUp {
+		t.Skip("MCP server did not start (may be expected in test environment)")
+	}
+
+	// Trigger shutdown by cancelling context
+	shutdownStart := time.Now()
+	cancel()
+
+	// Wait for shutdown to complete (with reasonable timeout)
+	shutdownDone := make(chan bool, 1)
+	go func() {
+		time.Sleep(15 * time.Second)
+		shutdownDone <- false
+	}()
+
+	// Give shutdown goroutine time to execute
+	time.Sleep(1 * time.Second)
+
+	// Verify server is no longer accepting connections
+	client := &http.Client{Timeout: 1 * time.Second}
+	_, err := client.Get(fmt.Sprintf("http://localhost:%d/", port))
+	if err == nil {
+		t.Error("Server still accepting connections after shutdown")
+	}
+
+	shutdownDone <- true
+	<-shutdownDone
+
+	shutdownDuration := time.Since(shutdownStart)
+	t.Logf("Graceful shutdown completed in %v", shutdownDuration)
+
+	// Verify shutdown completed in reasonable time (should be much less than 12s)
+	if shutdownDuration > 12*time.Second {
+		t.Errorf("Shutdown took too long: %v (expected < 12s)", shutdownDuration)
+	}
+}