Просмотр исходного кода

Fix GPUAllocation.Equal to compare pointer field values (#3849)

Signed-off-by: Cliff Colvin <clifford.colvin@gmail.com>
Co-authored-by: Claude Fable 5 <noreply@anthropic.com>
Cliff Colvin 1 неделя назад
Родитель
Сommit
af0a27f5c5
2 измененных файлов с 92 добавлено и 3 удалено
  1. 16 3
      core/pkg/opencost/allocation.go
  2. 76 0
      core/pkg/opencost/allocation_test.go

+ 16 - 3
core/pkg/opencost/allocation.go

@@ -148,6 +148,19 @@ func (orig *GPUAllocation) Clone() *GPUAllocation {
 	}
 }
 
+// ptrValueEqual reports whether two pointers are both nil, or both non-nil
+// and pointing to equal values. Plain == on pointer fields compares
+// addresses, which made equal-valued GPUAllocations (e.g. binary
+// roundtrips) compare unequal (#3846). NaN values compare unequal per Go ==
+// semantics; SanitizeNaN normalizes NaN pointers to nil before comparisons
+// where that matters.
+func ptrValueEqual[T comparable](a, b *T) bool {
+	if a == nil || b == nil {
+		return a == b
+	}
+	return *a == *b
+}
+
 func (orig *GPUAllocation) Equal(that *GPUAllocation) bool {
 	if orig == nil && that == nil {
 		return true
@@ -159,9 +172,9 @@ func (orig *GPUAllocation) Equal(that *GPUAllocation) bool {
 	return orig.GPUDevice == that.GPUDevice &&
 		orig.GPUModel == that.GPUModel &&
 		orig.GPUUUID == that.GPUUUID &&
-		orig.IsGPUShared == that.IsGPUShared &&
-		orig.GPUUsageAverage == that.GPUUsageAverage &&
-		orig.GPURequestAverage == that.GPURequestAverage
+		ptrValueEqual(orig.IsGPUShared, that.IsGPUShared) &&
+		ptrValueEqual(orig.GPUUsageAverage, that.GPUUsageAverage) &&
+		ptrValueEqual(orig.GPURequestAverage, that.GPURequestAverage)
 
 }
 

+ 76 - 0
core/pkg/opencost/allocation_test.go

@@ -3956,3 +3956,79 @@ func checkAllFloat64sForNaN(t *testing.T, v reflect.Value, testCaseName string)
 		}
 	}
 }
+
+// TestGPUAllocation_Equal verifies value semantics for the pointer fields:
+// two independently constructed GPUAllocations with equal contents must be
+// equal, regardless of pointer identity. Regression test for #3846.
+func TestGPUAllocation_Equal(t *testing.T) {
+	makeGPUAllocation := func() *GPUAllocation {
+		shared := true
+		usage := 0.5
+		request := 1.0
+		return &GPUAllocation{
+			GPUDevice:         "nvidia0",
+			GPUModel:          "Tesla T4",
+			GPUUUID:           "GPU-1",
+			IsGPUShared:       &shared,
+			GPUUsageAverage:   &usage,
+			GPURequestAverage: &request,
+		}
+	}
+
+	cases := map[string]struct {
+		a, b *GPUAllocation
+		want bool
+	}{
+		"both nil": {nil, nil, true},
+		"one nil":  {makeGPUAllocation(), nil, false},
+		"identical values, distinct pointers": {
+			makeGPUAllocation(), makeGPUAllocation(), true,
+		},
+		"different usage value": {
+			makeGPUAllocation(),
+			func() *GPUAllocation { g := makeGPUAllocation(); v := 0.9; g.GPUUsageAverage = &v; return g }(),
+			false,
+		},
+		"different shared value": {
+			makeGPUAllocation(),
+			func() *GPUAllocation { g := makeGPUAllocation(); v := false; g.IsGPUShared = &v; return g }(),
+			false,
+		},
+		"nil vs set pointer field": {
+			makeGPUAllocation(),
+			func() *GPUAllocation { g := makeGPUAllocation(); g.GPURequestAverage = nil; return g }(),
+			false,
+		},
+		"different device identity": {
+			makeGPUAllocation(),
+			func() *GPUAllocation { g := makeGPUAllocation(); g.GPUUUID = "GPU-2"; return g }(),
+			false,
+		},
+	}
+
+	for name, tc := range cases {
+		t.Run(name, func(t *testing.T) {
+			if got := tc.a.Equal(tc.b); got != tc.want {
+				t.Errorf("Equal() = %v, want %v", got, tc.want)
+			}
+			if got := tc.b.Equal(tc.a); got != tc.want {
+				t.Errorf("Equal() reversed = %v, want %v", got, tc.want)
+			}
+		})
+	}
+
+	t.Run("binary roundtrip equals original", func(t *testing.T) {
+		orig := makeGPUAllocation()
+		bs, err := orig.MarshalBinary()
+		if err != nil {
+			t.Fatalf("MarshalBinary: %s", err)
+		}
+		decoded := new(GPUAllocation)
+		if err := decoded.UnmarshalBinary(bs); err != nil {
+			t.Fatalf("UnmarshalBinary: %s", err)
+		}
+		if !orig.Equal(decoded) {
+			t.Errorf("roundtrip-decoded GPUAllocation not Equal to original: %+v vs %+v", orig, decoded)
+		}
+	})
+}