|
|
@@ -608,7 +608,7 @@ func TestQueryCloudCosts_QuerierCapture(t *testing.T) {
|
|
|
},
|
|
|
}
|
|
|
|
|
|
- _, err := s.QueryCloudCosts(req)
|
|
|
+ _, err := s.QueryCloudCosts(context.Background(), req)
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
assert.Equal(t, []string{"provider", "service"}, dq.last.AggregateBy)
|
|
|
@@ -633,7 +633,7 @@ func TestProcessMCPRequest_CloudCostDispatch(t *testing.T) {
|
|
|
},
|
|
|
}
|
|
|
|
|
|
- resp, err := s.ProcessMCPRequest(req)
|
|
|
+ resp, err := s.ProcessMCPRequest(context.Background(), req)
|
|
|
require.NoError(t, err)
|
|
|
require.NotNil(t, resp)
|
|
|
require.NotNil(t, resp.Data)
|
|
|
@@ -648,7 +648,7 @@ func TestProcessMCPRequest_UnsupportedType(t *testing.T) {
|
|
|
Window: "1d",
|
|
|
},
|
|
|
}
|
|
|
- _, err := s.ProcessMCPRequest(req)
|
|
|
+ _, err := s.ProcessMCPRequest(context.Background(), req)
|
|
|
require.Error(t, err)
|
|
|
}
|
|
|
|
|
|
@@ -661,7 +661,7 @@ func TestProcessMCPRequest_ValidationError(t *testing.T) {
|
|
|
Window: "",
|
|
|
},
|
|
|
}
|
|
|
- _, err := s.ProcessMCPRequest(req)
|
|
|
+ _, err := s.ProcessMCPRequest(context.Background(), req)
|
|
|
require.Error(t, err)
|
|
|
}
|
|
|
|
|
|
@@ -829,7 +829,7 @@ func TestQueryCloudCosts_NilCloudQuerier(t *testing.T) {
|
|
|
Window: "24h",
|
|
|
}
|
|
|
|
|
|
- _, err := s.QueryCloudCosts(req)
|
|
|
+ _, err := s.QueryCloudCosts(context.Background(), req)
|
|
|
require.Error(t, err)
|
|
|
assert.Contains(t, err.Error(), "cloud cost querier not configured")
|
|
|
}
|
|
|
@@ -842,7 +842,7 @@ func TestQueryCloudCosts_InvalidWindow(t *testing.T) {
|
|
|
Window: "invalid-window",
|
|
|
}
|
|
|
|
|
|
- _, err := s.QueryCloudCosts(req)
|
|
|
+ _, err := s.QueryCloudCosts(context.Background(), req)
|
|
|
require.Error(t, err)
|
|
|
assert.Contains(t, err.Error(), "failed to parse window")
|
|
|
}
|
|
|
@@ -884,7 +884,7 @@ func TestProcessMCPRequest_ResponseMetadata(t *testing.T) {
|
|
|
},
|
|
|
}
|
|
|
|
|
|
- resp, err := s.ProcessMCPRequest(req)
|
|
|
+ resp, err := s.ProcessMCPRequest(context.Background(), req)
|
|
|
require.NoError(t, err)
|
|
|
require.NotNil(t, resp)
|
|
|
|
|
|
@@ -1462,3 +1462,64 @@ func TestTransformCloudCostSetRange_NilPointerHandling(t *testing.T) {
|
|
|
})
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+// contextAwareQuerier is a mock querier that checks for context cancellation
|
|
|
+type contextAwareQuerier struct {
|
|
|
+ contextWasCancelled bool
|
|
|
+}
|
|
|
+
|
|
|
+func (caq *contextAwareQuerier) Query(ctx context.Context, req cloudcost.QueryRequest) (*opencost.CloudCostSetRange, error) {
|
|
|
+ // Check if context is already cancelled
|
|
|
+ select {
|
|
|
+ case <-ctx.Done():
|
|
|
+ caq.contextWasCancelled = true
|
|
|
+ return nil, ctx.Err()
|
|
|
+ default:
|
|
|
+ // Return empty set range
|
|
|
+ ccsr, _ := opencost.NewCloudCostSetRange(time.Now().Add(-24*time.Hour), time.Now(), opencost.AccumulateOptionDay, "")
|
|
|
+ return ccsr, nil
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestQueryCloudCosts_ContextCancellation(t *testing.T) {
|
|
|
+ // Create a context that is already cancelled
|
|
|
+ ctx, cancel := context.WithCancel(context.Background())
|
|
|
+ cancel() // Cancel immediately
|
|
|
+
|
|
|
+ // Create a context-aware mock querier
|
|
|
+ caq := &contextAwareQuerier{}
|
|
|
+ s := &MCPServer{cloudQuerier: caq}
|
|
|
+
|
|
|
+ req := &OpenCostQueryRequest{
|
|
|
+ QueryType: CloudCostQueryType,
|
|
|
+ Window: "1d",
|
|
|
+ }
|
|
|
+
|
|
|
+ // Query should fail with context cancelled error
|
|
|
+ _, err := s.QueryCloudCosts(ctx, req)
|
|
|
+
|
|
|
+ // Verify context cancellation was detected
|
|
|
+ assert.Error(t, err)
|
|
|
+ assert.True(t, caq.contextWasCancelled, "Context cancellation should be detected by querier")
|
|
|
+ assert.ErrorIs(t, err, context.Canceled, "Error should be context.Canceled")
|
|
|
+}
|
|
|
+
|
|
|
+func TestProcessMCPRequest_ContextPropagation(t *testing.T) {
|
|
|
+ // Test that context is properly propagated through ProcessMCPRequest
|
|
|
+ ctx := context.Background()
|
|
|
+ dq := &dummyQuerier{}
|
|
|
+ s := &MCPServer{cloudQuerier: dq}
|
|
|
+
|
|
|
+ req := &MCPRequest{
|
|
|
+ Query: &OpenCostQueryRequest{
|
|
|
+ QueryType: CloudCostQueryType,
|
|
|
+ Window: "1d",
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ resp, err := s.ProcessMCPRequest(ctx, req)
|
|
|
+ require.NoError(t, err)
|
|
|
+ require.NotNil(t, resp)
|
|
|
+ // Verify that the querier was called (context was propagated)
|
|
|
+ assert.NotNil(t, dq.last)
|
|
|
+}
|