Fix(plugins): Kill plugin processes on ingestor shutdown (#3543)

Signed-off-by: Tushar Verma <tusharmyself06@gmail.com>
Tushar-Verma committed 2 months ago · commit ca375d48d0

+ 38 - 1
pkg/cmd/costmodel/costmodel.go

@@ -29,6 +29,8 @@ import (
 	"github.com/opencost/opencost/pkg/metrics"
 )
 
+const shutdownTimeout = 30 * time.Second
+
 func Execute(conf *Config) error {
 	log.Infof("Starting cost-model version %s", version.FriendlyVersion())
 	if conf == nil {
@@ -104,7 +106,42 @@ func Execute(conf *Config) error {
 	telemetryHandler := metrics.ResponseMetricMiddleware(rootMux)
 	handler := cors.AllowAll().Handler(telemetryHandler)
 
-	return http.ListenAndServe(fmt.Sprint(":", conf.Port), errors.PanicHandlerMiddleware(handler))
+	server := &http.Server{
+		Addr:    fmt.Sprint(":", conf.Port),
+		Handler: errors.PanicHandlerMiddleware(handler),
+	}
+
+	serverErrors := make(chan error, 1)
+	go func() {
+		log.Infof("HTTP server starting on port %d", conf.Port)
+		serverErrors <- server.ListenAndServe()
+	}()
+
+	select {
+	case err := <-serverErrors:
+		if err != nil && err != http.ErrServerClosed {
+			return err
+		}
+		return nil
+	case <-ctx.Done():
+		log.Infof("Shutdown signal received, starting graceful shutdown...")
+
+		if customCostPipelineService != nil {
+			customCostPipelineService.Stop()
+		}
+
+		shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), shutdownTimeout)
+		defer shutdownCancel()
+
+		if err := server.Shutdown(shutdownCtx); err != nil {
+			log.Errorf("Error during server shutdown: %v", err)
+			server.Close()
+			return err
+		}
+
+		log.Infof("Graceful shutdown completed")
+		return nil
+	}
 }
 
 func StartExportWorker(ctx context.Context, model costmodel.AllocationModel) error {

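For reference, a minimal standalone sketch of the ListenAndServe/Shutdown pattern the hunk above introduces. The signal wiring, port, and handler here are illustrative stand-ins for the real ctx, config, and middleware:

	package main

	import (
		"context"
		"log"
		"net/http"
		"os/signal"
		"syscall"
		"time"
	)

	func main() {
		// Cancelled on SIGINT/SIGTERM; stands in for the ctx used above.
		ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
		defer stop()

		server := &http.Server{Addr: ":8080", Handler: http.DefaultServeMux}

		// Buffered so the serving goroutine never blocks on send.
		serverErrors := make(chan error, 1)
		go func() { serverErrors <- server.ListenAndServe() }()

		select {
		case err := <-serverErrors:
			if err != nil && err != http.ErrServerClosed {
				log.Fatal(err)
			}
		case <-ctx.Done():
			// Stop dependent services first (the commit stops the custom-cost
			// pipeline here), then drain in-flight requests with a deadline.
			shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
			defer cancel()
			if err := server.Shutdown(shutdownCtx); err != nil {
				_ = server.Close() // force-close if draining exceeded the timeout
			}
		}
	}
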
+ 17 - 30
pkg/cmd/costmodel/costmodel_test.go

@@ -12,24 +12,18 @@ import (
 )
 
 func TestMCPServerGracefulShutdown(t *testing.T) {
-	// Test that MCP server responds to context cancellation and shuts down gracefully
-
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
 	accesses := &costmodel.Accesses{}
 	port := env.GetMCPHTTPPort()
 
-	// Channel to signal when server is ready
-	serverReady := make(chan error, 1)
-
 	// Start MCP server
 	go func() {
-		err := StartMCPServer(ctx, accesses, nil)
-		serverReady <- err
+		_ = StartMCPServer(ctx, accesses, nil)
 	}()
 
-	// Wait for server to be ready by attempting to connect
+	// Wait for server to be ready
 	serverUp := false
 	for i := 0; i < 10; i++ {
 		time.Sleep(100 * time.Millisecond)
@@ -43,38 +37,31 @@ func TestMCPServerGracefulShutdown(t *testing.T) {
 	}
 
 	if !serverUp {
-		t.Skip("MCP server did not start (may be expected in test environment)")
+		t.Skip("MCP server did not start")
 	}
 
-	// Trigger shutdown by cancelling context
-	shutdownStart := time.Now()
+	// Trigger shutdown
 	cancel()
-
-	// Wait for shutdown to complete (with reasonable timeout)
-	shutdownDone := make(chan bool, 1)
-	go func() {
-		time.Sleep(15 * time.Second)
-		shutdownDone <- false
-	}()
-
-	// Give shutdown goroutine time to execute
-	time.Sleep(1 * time.Second)
+	time.Sleep(500 * time.Millisecond)
 
 	// Verify server is no longer accepting connections
-	client := &http.Client{Timeout: 1 * time.Second}
+	client := &http.Client{Timeout: 500 * time.Millisecond}
 	_, err := client.Get(fmt.Sprintf("http://localhost:%d/", port))
 	if err == nil {
 		t.Error("Server still accepting connections after shutdown")
 	}
+}
 
-	shutdownDone <- true
-	<-shutdownDone
-
-	shutdownDuration := time.Since(shutdownStart)
-	t.Logf("Graceful shutdown completed in %v", shutdownDuration)
+// TestShutdownTimeoutConstant pins the shutdown timeout constant at the expected 30s
+func TestShutdownTimeoutConstant(t *testing.T) {
+	if shutdownTimeout != 30*time.Second {
+		t.Errorf("Expected shutdown timeout of 30s, got %v", shutdownTimeout)
+	}
+}
 
-	// Verify shutdown completed in reasonable time (should be much less than 12s)
-	if shutdownDuration > 12*time.Second {
-		t.Errorf("Shutdown took too long: %v (expected < 12s)", shutdownDuration)
+// TestGracefulShutdownConfiguration guards against a shutdown timeout too short for a graceful drain
+func TestGracefulShutdownConfiguration(t *testing.T) {
+	if shutdownTimeout < 5*time.Second {
+		t.Error("Shutdown timeout is too short for graceful shutdown")
 	}
 }

+ 20 - 2
pkg/customcost/ingestor.go

@@ -66,6 +66,7 @@ type CustomCostIngestor struct {
 	exitBuildCh  chan string
 	exitRunCh    chan string
 	plugins      map[string]*plugin.Client
+	pluginsLock  sync.RWMutex
 	resolution   time.Duration
 	refreshRate  time.Duration
 }
@@ -128,6 +129,7 @@ func (ing *CustomCostIngestor) LoadWindow(start, end time.Time) {
 
 	for _, window := range targets {
 		allPluginsHave := true
+		ing.pluginsLock.RLock()
 		for domain := range ing.plugins {
 			has, err2 := ing.repo.Has(*window.Start(), domain)
 			if err2 != nil {
@@ -138,12 +140,15 @@ func (ing *CustomCostIngestor) LoadWindow(start, end time.Time) {
 				break
 			}
 		}
+		ing.pluginsLock.RUnlock()
 		if !allPluginsHave {
 			ing.BuildWindow(*window.Start(), *window.End())
 		} else {
+			ing.pluginsLock.RLock()
 			for domain := range ing.plugins {
 				ing.expandCoverage(window, domain)
 			}
+			ing.pluginsLock.RUnlock()
 			log.Debugf("CustomCost[%s]: ingestor: skipping build for window %s, coverage already exists", ing.key, window.String())
 		}
 	}
@@ -152,9 +157,11 @@ func (ing *CustomCostIngestor) LoadWindow(start, end time.Time) {
 
 func (ing *CustomCostIngestor) BuildWindow(start, end time.Time) {
 
+	ing.pluginsLock.RLock()
 	for domain := range ing.plugins {
 		ing.buildSingleDomain(start, end, domain)
 	}
+	ing.pluginsLock.RUnlock()
 }
 
 func (ing *CustomCostIngestor) buildSingleDomain(start, end time.Time, domain string) {
@@ -165,7 +172,9 @@ func (ing *CustomCostIngestor) buildSingleDomain(start, end time.Time, domain st
 	}
 	log.Infof("ingestor: building window %s for plugin %s", opencost.NewWindow(&start, &end), domain)
 	// make RPC call via plugin
+	ing.pluginsLock.RLock()
 	pluginClient, found := ing.plugins[domain]
+	ing.pluginsLock.RUnlock()
 	if !found {
 		log.Errorf("could not find plugin client for plugin %s. Did you initialize the plugin correctly?", domain)
 		return
@@ -262,8 +271,17 @@ func (ing *CustomCostIngestor) Stop() {
 
 	wg.Wait()
 
-	// Declare that the store is officially no longer running. This allows
-	// Start to be called again, restarting the store from scratch.
+	// Kill all plugin client processes before returning
+	ing.pluginsLock.Lock()
+	for name, client := range ing.plugins {
+		if client != nil {
+			log.Debugf("CustomCost[%s]: ingestor: killing plugin process: %s", ing.key, name)
+			client.Kill()
+		}
+	}
+	ing.pluginsLock.Unlock()
+
+	// Mark as no longer running so Start() can be called again if needed
 	ing.isRunning.Store(false)
 	ing.isStopping.Store(false)
 }
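
A reduced sketch of the locking discipline these hunks adopt: readers iterate the plugin map under RLock, while Stop takes the write lock to kill the clients. The names below are illustrative, not the ingestor's actual fields:

	package main

	import (
		"sync"

		"github.com/hashicorp/go-plugin"
	)

	// pluginSet is an illustrative stand-in for the ingestor's plugins map
	// plus its new pluginsLock; it is not the actual CustomCostIngestor API.
	type pluginSet struct {
		mu      sync.RWMutex
		clients map[string]*plugin.Client
	}

	// forEach reads under RLock, matching LoadWindow/BuildWindow/buildSingleDomain.
	func (s *pluginSet) forEach(fn func(name string, c *plugin.Client)) {
		s.mu.RLock()
		defer s.mu.RUnlock()
		for name, c := range s.clients {
			fn(name, c)
		}
	}

	// killAll takes the write lock, matching the new Stop() path; go-plugin's
	// Client.Kill blocks until the subprocess exits and is safe to call again.
	func (s *pluginSet) killAll() {
		s.mu.Lock()
		defer s.mu.Unlock()
		for _, c := range s.clients {
			if c != nil {
				c.Kill()
			}
		}
	}

	func main() {
		s := &pluginSet{clients: map[string]*plugin.Client{}}
		s.forEach(func(name string, c *plugin.Client) { /* query each plugin */ })
		s.killAll() // zero iterations on an empty map, mirroring the empty-map test
	}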

+ 189 - 0
pkg/customcost/ingestor_test.go

@@ -0,0 +1,189 @@
+package customcost
+
+import (
+	"os/exec"
+	"runtime"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/hashicorp/go-plugin"
+)
+
+func TestIngestor_Stop_KillsPluginProcesses(t *testing.T) {
+	cmd := exec.Command("sleep", "60")
+	client := plugin.NewClient(&plugin.ClientConfig{
+		HandshakeConfig: plugin.HandshakeConfig{
+			ProtocolVersion:  1,
+			MagicCookieKey:   "test",
+			MagicCookieValue: "test",
+		},
+		Cmd:          cmd,
+		StartTimeout: 2 * time.Second,
+	})
+	// Start the process (handshake will fail but process runs)
+	_, _ = client.Client()
+
+	ingestor := &CustomCostIngestor{
+		plugins: map[string]*plugin.Client{
+			"test-plugin": client,
+		},
+	}
+	ingestor.Stop()
+
+	if !client.Exited() {
+		t.Error("Expected plugin client process to be terminated after Stop()")
+	}
+}
+
+func TestIngestor_Stop_MultiplePlugins(t *testing.T) {
+	clients := make(map[string]*plugin.Client)
+	for _, name := range []string{"plugin-a", "plugin-b", "plugin-c"} {
+		cmd := exec.Command("sleep", "60")
+		client := plugin.NewClient(&plugin.ClientConfig{
+			HandshakeConfig: plugin.HandshakeConfig{
+				ProtocolVersion:  1,
+				MagicCookieKey:   "test",
+				MagicCookieValue: name,
+			},
+			Cmd:          cmd,
+			StartTimeout: 2 * time.Second,
+		})
+		_, _ = client.Client()
+		clients[name] = client
+	}
+
+	ingestor := &CustomCostIngestor{plugins: clients}
+	ingestor.Stop()
+
+	for name, client := range clients {
+		if !client.Exited() {
+			t.Errorf("Expected plugin %s to be terminated after Stop()", name)
+		}
+	}
+}
+
+func TestIngestor_Stop_EmptyPluginsMap(t *testing.T) {
+	ingestor := &CustomCostIngestor{
+		plugins: map[string]*plugin.Client{},
+	}
+	ingestor.Stop() // covers lock path with 0 iterations
+}
+
+func TestIngestor_Stop_NilPluginsMap(t *testing.T) {
+	ingestor := &CustomCostIngestor{}
+	ingestor.Stop() // should not panic
+}
+
+func TestIngestor_Stop_AlreadyStopping(t *testing.T) {
+	ingestor := &CustomCostIngestor{
+		plugins: map[string]*plugin.Client{},
+	}
+	ingestor.isStopping.Store(true) // simulate an in-progress stop
+	ingestor.Stop()                 // should return immediately
+}
+
+func TestIngestor_Stop_ConcurrentCalls(t *testing.T) {
+	ingestor := &CustomCostIngestor{
+		plugins: map[string]*plugin.Client{},
+	}
+
+	var wg sync.WaitGroup
+	for i := 0; i < 10; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			ingestor.Stop()
+		}()
+	}
+
+	done := make(chan struct{})
+	go func() {
+		wg.Wait()
+		close(done)
+	}()
+
+	select {
+	case <-done:
+		// success
+	case <-time.After(5 * time.Second):
+		t.Fatal("Concurrent Stop() calls deadlocked")
+	}
+}
+
+func TestIngestor_Stop_WithStartedIngestor(t *testing.T) {
+	repo := NewMemoryRepository()
+	config := &CustomCostIngestorConfig{
+		DailyDuration:     7 * 24 * time.Hour,
+		HourlyDuration:    16 * time.Hour,
+		DailyQueryWindow:  24 * time.Hour,
+		HourlyQueryWindow: time.Hour,
+	}
+
+	ingestor, err := NewCustomCostIngestor(config, repo, map[string]*plugin.Client{}, time.Hour)
+	if err != nil {
+		t.Fatalf("Failed to create ingestor: %v", err)
+	}
+
+	ingestor.Start(false)
+	time.Sleep(100 * time.Millisecond)
+
+	done := make(chan struct{})
+	go func() {
+		ingestor.Stop()
+		close(done)
+	}()
+
+	select {
+	case <-done:
+		// success
+	case <-time.After(5 * time.Second):
+		t.Fatal("Stop() on started ingestor timed out")
+	}
+
+	if ingestor.isRunning.Load() {
+		t.Error("Expected isRunning to be false after Stop()")
+	}
+	if ingestor.isStopping.Load() {
+		t.Error("Expected isStopping to be false after Stop()")
+	}
+}
+
+// TestIngestor_BuildWindow_WithPlugin covers pluginsLock paths inside buildSingleDomain.
+// Using a command that exits immediately causes client.Client() to fail fast, exercising
+// the RLock/RUnlock calls and the error-return path without hanging.
+func TestIngestor_BuildWindow_WithPlugin(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		t.Skip("requires Unix false command")
+	}
+
+	cmd := exec.Command("false") // exits immediately with failure
+	client := plugin.NewClient(&plugin.ClientConfig{
+		HandshakeConfig: plugin.HandshakeConfig{
+			ProtocolVersion:  1,
+			MagicCookieKey:   "test",
+			MagicCookieValue: "test",
+		},
+		Cmd:          cmd,
+		StartTimeout: 2 * time.Second,
+	})
+	t.Cleanup(func() { client.Kill() })
+
+	repo := NewMemoryRepository()
+	config := &CustomCostIngestorConfig{
+		DailyDuration:     24 * time.Hour,
+		HourlyDuration:    time.Hour,
+		DailyQueryWindow:  24 * time.Hour,
+		HourlyQueryWindow: time.Hour,
+	}
+
+	ingestor, err := NewCustomCostIngestor(config, repo, map[string]*plugin.Client{"test-plugin": client}, 24*time.Hour)
+	if err != nil {
+		t.Fatalf("Failed to create ingestor: %v", err)
+	}
+
+	now := time.Now().UTC()
+	// BuildWindow iterates the plugins map, exercising pluginsLock in both
+	// BuildWindow and buildSingleDomain; `false` exits immediately, so client.Client() fails fast
+	ingestor.BuildWindow(now.Add(-time.Hour), now)
+}

+ 20 - 0
pkg/customcost/pipelineservice.go

@@ -147,6 +147,26 @@ func NewPipelineService(hourlyrepo, dailyrepo Repository, ingConf CustomCostInge
 	}, nil
 }
 
+// Stop gracefully shuts down both hourly and daily ingestors.
+// Both ingestors may reference the same plugin clients, so Kill() may be invoked
+// multiple times per plugin, which is safe per the go-plugin library.
+func (ps *PipelineService) Stop() {
+	if ps == nil {
+		return
+	}
+	log.Infof("Shutting down CustomCost Pipeline Service")
+
+	if ps.hourlyIngestor != nil {
+		ps.hourlyIngestor.Stop()
+	}
+
+	if ps.dailyIngestor != nil {
+		ps.dailyIngestor.Stop()
+	}
+
+	log.Infof("CustomCost Pipeline Service stopped successfully")
+}
+
 // Status gives a combined view of the state of configs and the ingestor status
 func (dp *PipelineService) Status() Status {
 
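The "Kill() may be invoked multiple times" claim in the new doc comment can be sanity-checked in isolation. A hedged sketch, mirroring this commit's test setup (dummy handshake config, plain `sleep` subprocess):

	package main

	import (
		"os/exec"
		"time"

		"github.com/hashicorp/go-plugin"
	)

	func main() {
		client := plugin.NewClient(&plugin.ClientConfig{
			HandshakeConfig: plugin.HandshakeConfig{
				ProtocolVersion:  1,
				MagicCookieKey:   "k",
				MagicCookieValue: "v",
			},
			Cmd:          exec.Command("sleep", "60"),
			StartTimeout: 2 * time.Second,
		})
		_, _ = client.Client() // launch; handshake fails but the process runs

		client.Kill() // reaps the subprocess
		client.Kill() // second call is a no-op: the library tracks exit state
	}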

+ 74 - 0
pkg/customcost/pipelineservice_test.go

@@ -10,6 +10,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/hashicorp/go-plugin"
 	"github.com/opencost/opencost/core/pkg/log"
 	"github.com/opencost/opencost/core/pkg/util/timeutil"
 )
@@ -255,3 +256,76 @@
 		t.Fatalf("could not write file: %v", err)
 	}
 }
+
+// TestPipelineService_Stop_Nil ensures nil PipelineService is safe
+func TestPipelineService_Stop_Nil(t *testing.T) {
+	var ps *PipelineService
+	ps.Stop()
+	t.Log("Nil PipelineService handled safely")
+}
+
+// TestPipelineService_Stop_WithNilIngestors ensures nil ingestors are handled
+func TestPipelineService_Stop_WithNilIngestors(t *testing.T) {
+	ps := &PipelineService{
+		hourlyIngestor: nil,
+		dailyIngestor:  nil,
+		domains:        []string{},
+	}
+
+	ps.Stop()
+	t.Log("Nil ingestors handled safely")
+}
+
+// TestPipelineService_Stop_PartialNilIngestors ensures partial nil is handled
+func TestPipelineService_Stop_PartialNilIngestors(t *testing.T) {
+	hourly := &CustomCostIngestor{
+		key:     "hourly",
+		plugins: make(map[string]*plugin.Client),
+	}
+
+	ps := &PipelineService{
+		hourlyIngestor: hourly,
+		dailyIngestor:  nil,
+		domains:        []string{},
+	}
+
+	ps.Stop()
+	t.Log("Partial nil ingestors handled safely")
+}
+
+// TestPipelineService_Stop_ShutdownLogging exercises the shutdown logging path with both ingestors present
+func TestPipelineService_Stop_ShutdownLogging(t *testing.T) {
+	ps := &PipelineService{
+		hourlyIngestor: &CustomCostIngestor{
+			key:     "hourly",
+			plugins: make(map[string]*plugin.Client),
+		},
+		dailyIngestor: &CustomCostIngestor{
+			key:     "daily",
+			plugins: make(map[string]*plugin.Client),
+		},
+		domains: []string{},
+	}
+
+	ps.Stop()
+	time.Sleep(50 * time.Millisecond) // brief pause so shutdown log lines land before the test exits
+
+	t.Log("Shutdown logging path exercised without panic")
+}
+
+func TestPipelineService_Stop_WithIngestors(t *testing.T) {
+	hourly := &CustomCostIngestor{plugins: map[string]*plugin.Client{}}
+	daily := &CustomCostIngestor{plugins: map[string]*plugin.Client{}}
+	ps := &PipelineService{
+		hourlyIngestor: hourly,
+		dailyIngestor:  daily,
+	}
+	ps.Stop()
+}
+
+func TestPipelineService_Stop_OnlyDailyIngestor(t *testing.T) {
+	ps := &PipelineService{
+		dailyIngestor: &CustomCostIngestor{plugins: map[string]*plugin.Client{}},
+	}
+	ps.Stop() // should not panic when hourlyIngestor is nil
+}