dcgm_test.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. package kubemodel
  2. import (
  3. "testing"
  4. "time"
  5. "github.com/stretchr/testify/require"
  6. )
  7. func TestValidateDCGMDevice(t *testing.T) {
  8. start := time.Now().UTC().Truncate(time.Hour)
  9. end := start.Add(time.Hour)
  10. window := Window{Start: start, End: end}
  11. tests := []struct {
  12. name string
  13. device *DCGMDevice
  14. wantErr string
  15. }{
  16. {
  17. name: "empty UUID",
  18. device: &DCGMDevice{Device: "GPU-0", Start: start, End: end},
  19. wantErr: "UUID is missing for DCGMDevice with device 'GPU-0'",
  20. },
  21. {
  22. name: "outside window",
  23. device: &DCGMDevice{UUID: "gpu-uuid", Device: "GPU-0", Start: start.Add(-time.Hour), End: end},
  24. wantErr: checkWindow(window, start.Add(-time.Hour), end).Error(),
  25. },
  26. {
  27. name: "valid",
  28. device: &DCGMDevice{UUID: "gpu-uuid", Device: "GPU-0", Start: start, End: end},
  29. },
  30. }
  31. for _, tt := range tests {
  32. t.Run(tt.name, func(t *testing.T) {
  33. err := tt.device.ValidateDCGMDevice(window)
  34. if tt.wantErr != "" {
  35. require.EqualError(t, err, tt.wantErr)
  36. } else {
  37. require.NoError(t, err)
  38. }
  39. })
  40. }
  41. }
  42. func TestRegisterDCGMDevice(t *testing.T) {
  43. start := time.Now().UTC().Truncate(time.Hour)
  44. end := start.Add(time.Hour)
  45. newDevice := func(uuid, device string) *DCGMDevice {
  46. return &DCGMDevice{UUID: uuid, Device: device, Start: start, End: end}
  47. }
  48. withCluster := func(kms *KubeModelSet) {
  49. kms.RegisterCluster(&Cluster{UID: "cluster-uid", Start: start, End: end})
  50. }
  51. tests := []struct {
  52. name string
  53. setup func(*KubeModelSet)
  54. device *DCGMDevice
  55. wantErr string
  56. want *KubeModelSet
  57. }{
  58. {
  59. name: "validation failure",
  60. device: &DCGMDevice{UUID: "", Device: "GPU-0", Start: start, End: end},
  61. wantErr: "RegisterDCGMDevice: invalid dcgm device: UUID is missing for DCGMDevice with device 'GPU-0'",
  62. want: func() *KubeModelSet {
  63. kms := NewKubeModelSet(start, end)
  64. kms.Metadata.Diagnostics = []Diagnostic{
  65. {Level: DiagnosticLevelError, Message: "RegisterDCGMDevice: invalid dcgm device: UUID is missing for DCGMDevice with device 'GPU-0'"},
  66. }
  67. return kms
  68. }(),
  69. },
  70. {
  71. name: "warns when cluster is nil",
  72. device: newDevice("gpu-uuid", "GPU-0"),
  73. want: func() *KubeModelSet {
  74. kms := NewKubeModelSet(start, end)
  75. kms.DCGMDevices["gpu-uuid"] = newDevice("gpu-uuid", "GPU-0")
  76. kms.Metadata.ObjectCount = 1
  77. kms.Metadata.Diagnostics = []Diagnostic{
  78. {Level: DiagnosticLevelWarning, Message: "RegisterDCGMDevice: Cluster is nil"},
  79. }
  80. return kms
  81. }(),
  82. },
  83. {
  84. name: "registers device with cluster",
  85. setup: withCluster,
  86. device: newDevice("gpu-uuid", "GPU-0"),
  87. want: func() *KubeModelSet {
  88. kms := NewKubeModelSet(start, end)
  89. withCluster(kms)
  90. kms.DCGMDevices["gpu-uuid"] = newDevice("gpu-uuid", "GPU-0")
  91. kms.Metadata.ObjectCount = 1
  92. return kms
  93. }(),
  94. },
  95. {
  96. name: "duplicate registration is a no-op",
  97. setup: func(kms *KubeModelSet) {
  98. withCluster(kms)
  99. kms.RegisterDCGMDevice(newDevice("gpu-uuid", "GPU-0"))
  100. },
  101. device: newDevice("gpu-uuid", "GPU-1"),
  102. want: func() *KubeModelSet {
  103. kms := NewKubeModelSet(start, end)
  104. withCluster(kms)
  105. kms.DCGMDevices["gpu-uuid"] = newDevice("gpu-uuid", "GPU-0")
  106. kms.Metadata.ObjectCount = 1
  107. return kms
  108. }(),
  109. },
  110. }
  111. for _, tt := range tests {
  112. t.Run(tt.name, func(t *testing.T) {
  113. kms := NewKubeModelSet(start, end)
  114. if tt.setup != nil {
  115. tt.setup(kms)
  116. }
  117. err := kms.RegisterDCGMDevice(tt.device)
  118. if tt.wantErr != "" {
  119. require.EqualError(t, err, tt.wantErr)
  120. } else {
  121. require.NoError(t, err)
  122. }
  123. KubeModelSetEquals(t, tt.want, kms)
  124. })
  125. }
  126. }