Преглед на файлове

fix(mcp): Propagate request context to cloud cost queries (#3518)

Signed-off-by: Tushar Verma <tusharmyself06@gmail.com>
Tushar-Verma преди 4 месеца
родител
ревизия
65d5e272f5
променени са 5 файла, в които са добавени 160 реда и са изтрити 19 реда
  1. 4 4
      pkg/cmd/costmodel/costmodel.go
  2. 13 1
      pkg/env/opencost.go
  3. 61 0
      pkg/env/opencost_test.go
  4. 14 7
      pkg/mcp/server.go
  5. 68 7
      pkg/mcp/server_test.go

+ 4 - 4
pkg/cmd/costmodel/costmodel.go

@@ -188,7 +188,7 @@ func StartMCPServer(ctx context.Context, accesses *costmodel.Accesses, cloudCost
 			Query: queryRequest,
 		}
 
-		mcpResp, err := mcpServer.ProcessMCPRequest(mcpReq)
+		mcpResp, err := mcpServer.ProcessMCPRequest(ctx, mcpReq)
 		if err != nil {
 			return nil, nil, fmt.Errorf("failed to process allocation request: %w", err)
 		}
@@ -207,7 +207,7 @@ func StartMCPServer(ctx context.Context, accesses *costmodel.Accesses, cloudCost
 			Query: queryRequest,
 		}
 
-		mcpResp, err := mcpServer.ProcessMCPRequest(mcpReq)
+		mcpResp, err := mcpServer.ProcessMCPRequest(ctx, mcpReq)
 		if err != nil {
 			return nil, nil, fmt.Errorf("failed to process asset request: %w", err)
 		}
@@ -235,7 +235,7 @@ func StartMCPServer(ctx context.Context, accesses *costmodel.Accesses, cloudCost
 			Query: queryRequest,
 		}
 
-		mcpResp, err := mcpServer.ProcessMCPRequest(mcpReq)
+		mcpResp, err := mcpServer.ProcessMCPRequest(ctx, mcpReq)
 		if err != nil {
 			return nil, nil, fmt.Errorf("failed to process cloud cost request: %w", err)
 		}
@@ -258,7 +258,7 @@ func StartMCPServer(ctx context.Context, accesses *costmodel.Accesses, cloudCost
 			Query: queryRequest,
 		}
 
-		mcpResp, err := mcpServer.ProcessMCPRequest(mcpReq)
+		mcpResp, err := mcpServer.ProcessMCPRequest(ctx, mcpReq)
 		if err != nil {
 			return nil, nil, fmt.Errorf("failed to process efficiency request: %w", err)
 		}

+ 13 - 1
pkg/env/opencost.go

@@ -15,7 +15,8 @@ const (
 )
 
 const (
-	UTCOffsetEnvVar = "UTC_OFFSET"
+	UTCOffsetEnvVar              = "UTC_OFFSET"
+	MCPQueryTimeoutSecondsEnvVar = "MCP_QUERY_TIMEOUT_SECONDS"
 )
 
 func GetOpencostAPIPort() int {
@@ -42,3 +43,14 @@ func GetParsedUTCOffset() time.Duration {
 	}
 	return offset
 }
+
+// GetMCPQueryTimeout returns the configured timeout for MCP query operations.
+// Default is 60 seconds, but can be configured via MCP_QUERY_TIMEOUT_SECONDS environment variable.
+// Minimum timeout is 1 second to prevent immediate timeouts.
+func GetMCPQueryTimeout() time.Duration {
+	seconds := env.GetInt(MCPQueryTimeoutSecondsEnvVar, 60)
+	if seconds <= 0 {
+		seconds = 1
+	}
+	return time.Duration(seconds) * time.Second
+}

+ 61 - 0
pkg/env/opencost_test.go

@@ -4,8 +4,11 @@ import (
 	"fmt"
 	"os"
 	"testing"
+	"time"
 
 	"github.com/opencost/opencost/core/pkg/env"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func TestGetAPIPort(t *testing.T) {
@@ -45,3 +48,61 @@ func TestGetAPIPort(t *testing.T) {
 	}
 
 }
+
+func TestGetMCPQueryTimeout_Default(t *testing.T) {
+	// Ensure env var is not set
+	os.Unsetenv(MCPQueryTimeoutSecondsEnvVar)
+
+	timeout := GetMCPQueryTimeout()
+	assert.Equal(t, 60*time.Second, timeout, "Default timeout should be 60 seconds")
+}
+
+func TestGetMCPQueryTimeout_CustomValue(t *testing.T) {
+	// Set custom timeout
+	err := os.Setenv(MCPQueryTimeoutSecondsEnvVar, "120")
+	require.NoError(t, err)
+	defer os.Unsetenv(MCPQueryTimeoutSecondsEnvVar)
+
+	timeout := GetMCPQueryTimeout()
+	assert.Equal(t, 120*time.Second, timeout, "Custom timeout should be 120 seconds")
+}
+
+func TestGetMCPQueryTimeout_InvalidValue(t *testing.T) {
+	// Set invalid value (should fall back to default)
+	err := os.Setenv(MCPQueryTimeoutSecondsEnvVar, "invalid")
+	require.NoError(t, err)
+	defer os.Unsetenv(MCPQueryTimeoutSecondsEnvVar)
+
+	timeout := GetMCPQueryTimeout()
+	assert.Equal(t, 60*time.Second, timeout, "Invalid value should fall back to default 60 seconds")
+}
+
+func TestGetMCPQueryTimeout_ZeroValue(t *testing.T) {
+	// Set zero value - should fall back to minimum of 1 second
+	err := os.Setenv(MCPQueryTimeoutSecondsEnvVar, "0")
+	require.NoError(t, err)
+	defer os.Unsetenv(MCPQueryTimeoutSecondsEnvVar)
+
+	timeout := GetMCPQueryTimeout()
+	assert.Equal(t, 1*time.Second, timeout, "Zero value should use minimum of 1 second")
+}
+
+func TestGetMCPQueryTimeout_NegativeValue(t *testing.T) {
+	// Set negative value - should fall back to minimum of 1 second
+	err := os.Setenv(MCPQueryTimeoutSecondsEnvVar, "-10")
+	require.NoError(t, err)
+	defer os.Unsetenv(MCPQueryTimeoutSecondsEnvVar)
+
+	timeout := GetMCPQueryTimeout()
+	assert.Equal(t, 1*time.Second, timeout, "Negative value should use minimum of 1 second")
+}
+
+func TestGetMCPQueryTimeout_LargeValue(t *testing.T) {
+	// Set large timeout value
+	err := os.Setenv(MCPQueryTimeoutSecondsEnvVar, "3600")
+	require.NoError(t, err)
+	defer os.Unsetenv(MCPQueryTimeoutSecondsEnvVar)
+
+	timeout := GetMCPQueryTimeout()
+	assert.Equal(t, 3600*time.Second, timeout, "Large timeout should be accepted (1 hour)")
+}

+ 14 - 7
pkg/mcp/server.go

@@ -19,6 +19,7 @@ import (
 	models "github.com/opencost/opencost/pkg/cloud/models"
 	"github.com/opencost/opencost/pkg/cloudcost"
 	"github.com/opencost/opencost/pkg/costmodel"
+	"github.com/opencost/opencost/pkg/env"
 )
 
 // QueryType defines the type of query to be executed.
@@ -377,8 +378,8 @@ func NewMCPServer(costModel *costmodel.CostModel, provider models.Provider, clou
 }
 
 // ProcessMCPRequest processes an MCP request and returns an MCP response.
-
-func (s *MCPServer) ProcessMCPRequest(request *MCPRequest) (*MCPResponse, error) {
+// It accepts a context for proper timeout handling and cancellation.
+func (s *MCPServer) ProcessMCPRequest(ctx context.Context, request *MCPRequest) (*MCPResponse, error) {
 	// 1. Validate Request
 	if err := validate.Struct(request); err != nil {
 		return nil, fmt.Errorf("validation failed: %w", err)
@@ -396,7 +397,7 @@ func (s *MCPServer) ProcessMCPRequest(request *MCPRequest) (*MCPResponse, error)
 	case AssetQueryType:
 		data, err = s.QueryAssets(request.Query)
 	case CloudCostQueryType:
-		data, err = s.QueryCloudCosts(request.Query)
+		data, err = s.QueryCloudCosts(ctx, request.Query)
 	case EfficiencyQueryType:
 		data, err = s.QueryEfficiency(request.Query)
 	default:
@@ -714,7 +715,8 @@ func transformAssetSet(assetSet *opencost.AssetSet) *AssetResponse {
 }
 
 // QueryCloudCosts translates an MCP query into a CloudCost repository query and transforms the result.
-func (s *MCPServer) QueryCloudCosts(query *OpenCostQueryRequest) (*CloudCostResponse, error) {
+// The ctx parameter is used for timeout and cancellation handling of the cloud cost query.
+func (s *MCPServer) QueryCloudCosts(ctx context.Context, query *OpenCostQueryRequest) (*CloudCostResponse, error) {
 	// 1. Check if cloud cost querier is available
 	if s.cloudQuerier == nil {
 		return nil, fmt.Errorf("cloud cost querier not configured - check cloud-integration.json file")
@@ -738,13 +740,18 @@ func (s *MCPServer) QueryCloudCosts(query *OpenCostQueryRequest) (*CloudCostResp
 		request = s.buildCloudCostQueryRequest(request, query.CloudCostParams)
 	}
 
-	// 5. Query the repository (this handles multiple cloud providers automatically)
-	ccsr, err := s.cloudQuerier.Query(context.TODO(), request)
+	// 5. Create a timeout context for the query with configured timeout
+	queryTimeout := env.GetMCPQueryTimeout()
+	queryCtx, cancel := context.WithTimeout(ctx, queryTimeout)
+	defer cancel()
+
+	// 6. Query the repository (this handles multiple cloud providers automatically)
+	ccsr, err := s.cloudQuerier.Query(queryCtx, request)
 	if err != nil {
 		return nil, fmt.Errorf("failed to query cloud costs: %w", err)
 	}
 
-	// 6. Transform Response
+	// 7. Transform Response
 	return transformCloudCostSetRange(ccsr), nil
 }
 

+ 68 - 7
pkg/mcp/server_test.go

@@ -608,7 +608,7 @@ func TestQueryCloudCosts_QuerierCapture(t *testing.T) {
 		},
 	}
 
-	_, err := s.QueryCloudCosts(req)
+	_, err := s.QueryCloudCosts(context.Background(), req)
 	require.NoError(t, err)
 
 	assert.Equal(t, []string{"provider", "service"}, dq.last.AggregateBy)
@@ -633,7 +633,7 @@ func TestProcessMCPRequest_CloudCostDispatch(t *testing.T) {
 		},
 	}
 
-	resp, err := s.ProcessMCPRequest(req)
+	resp, err := s.ProcessMCPRequest(context.Background(), req)
 	require.NoError(t, err)
 	require.NotNil(t, resp)
 	require.NotNil(t, resp.Data)
@@ -648,7 +648,7 @@ func TestProcessMCPRequest_UnsupportedType(t *testing.T) {
 			Window:    "1d",
 		},
 	}
-	_, err := s.ProcessMCPRequest(req)
+	_, err := s.ProcessMCPRequest(context.Background(), req)
 	require.Error(t, err)
 }
 
@@ -661,7 +661,7 @@ func TestProcessMCPRequest_ValidationError(t *testing.T) {
 			Window:    "",
 		},
 	}
-	_, err := s.ProcessMCPRequest(req)
+	_, err := s.ProcessMCPRequest(context.Background(), req)
 	require.Error(t, err)
 }
 
@@ -829,7 +829,7 @@ func TestQueryCloudCosts_NilCloudQuerier(t *testing.T) {
 		Window:    "24h",
 	}
 
-	_, err := s.QueryCloudCosts(req)
+	_, err := s.QueryCloudCosts(context.Background(), req)
 	require.Error(t, err)
 	assert.Contains(t, err.Error(), "cloud cost querier not configured")
 }
@@ -842,7 +842,7 @@ func TestQueryCloudCosts_InvalidWindow(t *testing.T) {
 		Window:    "invalid-window",
 	}
 
-	_, err := s.QueryCloudCosts(req)
+	_, err := s.QueryCloudCosts(context.Background(), req)
 	require.Error(t, err)
 	assert.Contains(t, err.Error(), "failed to parse window")
 }
@@ -884,7 +884,7 @@ func TestProcessMCPRequest_ResponseMetadata(t *testing.T) {
 		},
 	}
 
-	resp, err := s.ProcessMCPRequest(req)
+	resp, err := s.ProcessMCPRequest(context.Background(), req)
 	require.NoError(t, err)
 	require.NotNil(t, resp)
 
@@ -1462,3 +1462,64 @@ func TestTransformCloudCostSetRange_NilPointerHandling(t *testing.T) {
 		})
 	}
 }
+
+// contextAwareQuerier is a mock querier that checks for context cancellation
+type contextAwareQuerier struct {
+	contextWasCancelled bool
+}
+
+func (caq *contextAwareQuerier) Query(ctx context.Context, req cloudcost.QueryRequest) (*opencost.CloudCostSetRange, error) {
+	// Check if context is already cancelled
+	select {
+	case <-ctx.Done():
+		caq.contextWasCancelled = true
+		return nil, ctx.Err()
+	default:
+		// Return empty set range
+		ccsr, _ := opencost.NewCloudCostSetRange(time.Now().Add(-24*time.Hour), time.Now(), opencost.AccumulateOptionDay, "")
+		return ccsr, nil
+	}
+}
+
+func TestQueryCloudCosts_ContextCancellation(t *testing.T) {
+	// Create a context that is already cancelled
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel() // Cancel immediately
+
+	// Create a context-aware mock querier
+	caq := &contextAwareQuerier{}
+	s := &MCPServer{cloudQuerier: caq}
+
+	req := &OpenCostQueryRequest{
+		QueryType: CloudCostQueryType,
+		Window:    "1d",
+	}
+
+	// Query should fail with context cancelled error
+	_, err := s.QueryCloudCosts(ctx, req)
+	
+	// Verify context cancellation was detected
+	assert.Error(t, err)
+	assert.True(t, caq.contextWasCancelled, "Context cancellation should be detected by querier")
+	assert.ErrorIs(t, err, context.Canceled, "Error should be context.Canceled")
+}
+
+func TestProcessMCPRequest_ContextPropagation(t *testing.T) {
+	// Test that context is properly propagated through ProcessMCPRequest
+	ctx := context.Background()
+	dq := &dummyQuerier{}
+	s := &MCPServer{cloudQuerier: dq}
+
+	req := &MCPRequest{
+		Query: &OpenCostQueryRequest{
+			QueryType: CloudCostQueryType,
+			Window:    "1d",
+		},
+	}
+
+	resp, err := s.ProcessMCPRequest(ctx, req)
+	require.NoError(t, err)
+	require.NotNil(t, resp)
+	// Verify that the querier was called (context was propagated)
+	assert.NotNil(t, dq.last)
+}