gpusaturationquerier_test.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. package prom
  2. import (
  3. "strings"
  4. "testing"
  5. "time"
  6. "github.com/rs/zerolog"
  7. zerologger "github.com/rs/zerolog/log"
  8. )
  9. func TestBuildGPUThrottleViolationQuery(t *testing.T) {
  10. query := buildGPUThrottleViolationQuery(`cluster_id="c1"`, "1h", "cluster_id", 3600)
  11. branches := strings.Split(query, " or ")
  12. if len(branches) != 4 {
  13. t.Fatalf("expected 4 violation branches, got %d: %s", len(branches), query)
  14. }
  15. // one hour is 3.6e9 microseconds: each branch must normalize by it
  16. wantBranch := `label_replace(avg(increase(DCGM_FI_DEV_POWER_VIOLATION{container!="",cluster_id="c1"}[1h])) by (container, pod, namespace, device, modelName, UUID, GPU_I_PROFILE, GPU_I_ID, pod_uid, cluster_id) / 3.6e+09, "reason", "power", "", "")`
  17. if branches[0] != wantBranch {
  18. t.Errorf("violation branch mismatch:\n got %s\nwant %s", branches[0], wantBranch)
  19. }
  20. for metric, reason := range map[string]string{
  21. "DCGM_FI_DEV_POWER_VIOLATION": "power",
  22. "DCGM_FI_DEV_THERMAL_VIOLATION": "thermal",
  23. "DCGM_FI_DEV_SYNC_BOOST_VIOLATION": "sync_boost",
  24. "DCGM_FI_DEV_BOARD_LIMIT_VIOLATION": "board_limit",
  25. } {
  26. if !strings.Contains(query, metric) {
  27. t.Errorf("expected query to reference %s", metric)
  28. }
  29. if !strings.Contains(query, `"reason", "`+reason+`"`) {
  30. t.Errorf("expected query to tag reason %q", reason)
  31. }
  32. }
  33. }
  34. func TestBuildGPUThrottleReasonQuery(t *testing.T) {
  35. query := buildGPUThrottleReasonQuery(`cluster_id="c1"`, "1h", "cluster_id", 5)
  36. // one label_replace-wrapped branch per saturation-relevant reason bit
  37. if got := strings.Count(query, "label_replace"); got != 6 {
  38. t.Fatalf("expected 6 reason branches, got %d: %s", got, query)
  39. }
  40. // the first branch tests the sw_power_cap bit (0x4 == 4) per sample at
  41. // the subquery resolution, then averages the 0/1 results over the window
  42. wantBranch := `label_replace(avg(avg_over_time(((floor((DCGM_FI_DEV_CLOCK_THROTTLE_REASONS{container!="",cluster_id="c1"} or DCGM_FI_DEV_CLOCKS_EVENT_REASONS{container!="",cluster_id="c1"}) / 4)) % 2)[1h:5m])) by (container, pod, namespace, device, modelName, UUID, GPU_I_PROFILE, GPU_I_ID, pod_uid, cluster_id), "reason", "sw_power_cap", "", "")`
  43. if !strings.HasPrefix(query, wantBranch+" or ") {
  44. t.Errorf("reason query does not start with expected sw_power_cap branch:\n got %s\nwant prefix %s", query, wantBranch)
  45. }
  46. for reason, bit := range map[string]string{
  47. "sw_power_cap": "/ 4",
  48. "hw_slowdown": "/ 8",
  49. "sync_boost": "/ 16",
  50. "sw_thermal": "/ 32",
  51. "hw_thermal": "/ 64",
  52. "hw_power_brake": "/ 128",
  53. } {
  54. if !strings.Contains(query, `"reason", "`+reason+`"`) {
  55. t.Errorf("expected query to tag reason %q", reason)
  56. }
  57. if !strings.Contains(query, bit) {
  58. t.Errorf("expected query to test bit via %q for reason %q", bit, reason)
  59. }
  60. }
  61. }
  62. // TestGPUSaturationQueries runs every saturation query against the no-op
  63. // client and asserts the logged query references the expected DCGM source
  64. // metric, carries the cluster filter, and groups by the saturation label
  65. // set.
  66. func TestGPUSaturationQueries(t *testing.T) {
  67. initLogging(t, "debug", false)
  68. logWriter := new(SingleLogWriter)
  69. zerologger.Logger = zerologger.Output(zerolog.ConsoleWriter{
  70. Out: logWriter,
  71. TimeFormat: "",
  72. NoColor: true,
  73. PartsExclude: []string{
  74. zerolog.TimestampFieldName,
  75. zerolog.LevelFieldName,
  76. zerolog.CallerFieldName,
  77. },
  78. })
  79. defer initLogging(t, "debug", false)
  80. t.Setenv("PROMETHEUS_SERVER_ENDPOINT", "nowhere")
  81. t.Setenv("CURRENT_CLUSTER_ID_FILTER_ENABLED", "true")
  82. t.Setenv("CLUSTER_ID", "test-cluster")
  83. t.Setenv("GPU_MEMORY_SATURATION_THRESHOLD", "0.8")
  84. config, err := NewOpenCostPrometheusConfigFromEnv()
  85. if err != nil {
  86. t.Fatalf("Failed to create OpenCost Prometheus config: %v", err)
  87. }
  88. mock := new(NoOpPromClient)
  89. contextFactory := NewContextFactory(mock, config)
  90. querier := newPrometheusMetricsQuerier(config, mock, contextFactory)
  91. queryEnd := time.Now().UTC().Truncate(time.Hour).Add(time.Hour)
  92. queryStart := queryEnd.Add(-24 * time.Hour)
  93. tests := map[string]struct {
  94. query func(time.Time, time.Time)
  95. wantMetric string
  96. wantExtra string
  97. }{
  98. "QueryGPUThrottleViolationRatio": {
  99. query: func(s, e time.Time) { querier.QueryGPUThrottleViolationRatio(s, e) },
  100. wantMetric: "DCGM_FI_DEV_POWER_VIOLATION",
  101. },
  102. "QueryGPUThrottleReasonRatio": {
  103. query: func(s, e time.Time) { querier.QueryGPUThrottleReasonRatio(s, e) },
  104. wantMetric: "DCGM_FI_DEV_CLOCK_THROTTLE_REASONS",
  105. wantExtra: "DCGM_FI_DEV_CLOCKS_EVENT_REASONS",
  106. },
  107. "QueryGPUMemoryUsedRatioAvg": {
  108. query: func(s, e time.Time) { querier.QueryGPUMemoryUsedRatioAvg(s, e) },
  109. wantMetric: "DCGM_FI_DEV_FB_USED",
  110. wantExtra: "DCGM_FI_DEV_FB_FREE",
  111. },
  112. "QueryGPUMemoryUsedRatioMax": {
  113. query: func(s, e time.Time) { querier.QueryGPUMemoryUsedRatioMax(s, e) },
  114. wantMetric: "DCGM_FI_DEV_FB_USED",
  115. wantExtra: "max_over_time",
  116. },
  117. "QueryGPUMemoryPressureRatio": {
  118. query: func(s, e time.Time) { querier.QueryGPUMemoryPressureRatio(s, e) },
  119. wantMetric: "DCGM_FI_DEV_FB_USED",
  120. wantExtra: ">= bool 0.8",
  121. },
  122. "QueryGPUXIDErrorCount": {
  123. query: func(s, e time.Time) { querier.QueryGPUXIDErrorCount(s, e) },
  124. wantMetric: "DCGM_FI_DEV_XID_ERRORS",
  125. wantExtra: "changes(",
  126. },
  127. "QueryGPUDRAMActiveAvg": {
  128. query: func(s, e time.Time) { querier.QueryGPUDRAMActiveAvg(s, e) },
  129. wantMetric: "DCGM_FI_PROF_DRAM_ACTIVE",
  130. wantExtra: "avg_over_time",
  131. },
  132. "QueryGPUDRAMActiveMax": {
  133. query: func(s, e time.Time) { querier.QueryGPUDRAMActiveMax(s, e) },
  134. wantMetric: "DCGM_FI_PROF_DRAM_ACTIVE",
  135. wantExtra: "max_over_time",
  136. },
  137. "QueryGPUSMActiveAvg": {
  138. query: func(s, e time.Time) { querier.QueryGPUSMActiveAvg(s, e) },
  139. wantMetric: "DCGM_FI_PROF_SM_ACTIVE",
  140. },
  141. "QueryGPUSMOccupancyAvg": {
  142. query: func(s, e time.Time) { querier.QueryGPUSMOccupancyAvg(s, e) },
  143. wantMetric: "DCGM_FI_PROF_SM_OCCUPANCY",
  144. },
  145. "QueryGPUPCIeTxBytesAvg": {
  146. query: func(s, e time.Time) { querier.QueryGPUPCIeTxBytesAvg(s, e) },
  147. wantMetric: "DCGM_FI_PROF_PCIE_TX_BYTES",
  148. wantExtra: "rate(",
  149. },
  150. "QueryGPUPCIeRxBytesAvg": {
  151. query: func(s, e time.Time) { querier.QueryGPUPCIeRxBytesAvg(s, e) },
  152. wantMetric: "DCGM_FI_PROF_PCIE_RX_BYTES",
  153. },
  154. "QueryGPUNVLinkTxBytesAvg": {
  155. query: func(s, e time.Time) { querier.QueryGPUNVLinkTxBytesAvg(s, e) },
  156. wantMetric: "DCGM_FI_PROF_NVLINK_TX_BYTES",
  157. },
  158. "QueryGPUNVLinkRxBytesAvg": {
  159. query: func(s, e time.Time) { querier.QueryGPUNVLinkRxBytesAvg(s, e) },
  160. wantMetric: "DCGM_FI_PROF_NVLINK_RX_BYTES",
  161. },
  162. }
  163. deviceTests := map[string]struct {
  164. query func(time.Time, time.Time)
  165. wantMetric string
  166. }{
  167. "QueryGPUDevicePowerAvg": {func(s, e time.Time) { querier.QueryGPUDevicePowerAvg(s, e) }, "DCGM_FI_DEV_POWER_USAGE"},
  168. "QueryGPUDeviceTempAvg": {func(s, e time.Time) { querier.QueryGPUDeviceTempAvg(s, e) }, "DCGM_FI_DEV_GPU_TEMP"},
  169. "QueryGPUDeviceUsageAvg": {func(s, e time.Time) { querier.QueryGPUDeviceUsageAvg(s, e) }, "DCGM_FI_PROF_GR_ENGINE_ACTIVE"},
  170. "QueryGPUDeviceUsageMax": {func(s, e time.Time) { querier.QueryGPUDeviceUsageMax(s, e) }, "DCGM_FI_PROF_GR_ENGINE_ACTIVE"},
  171. "QueryGPUDeviceMemoryUsedAvg": {func(s, e time.Time) { querier.QueryGPUDeviceMemoryUsedAvg(s, e) }, "DCGM_FI_DEV_FB_USED"},
  172. "QueryGPUDeviceMemoryUsedMax": {func(s, e time.Time) { querier.QueryGPUDeviceMemoryUsedMax(s, e) }, "DCGM_FI_DEV_FB_USED"},
  173. }
  174. const wantDeviceFilter = `cluster_id="test-cluster"`
  175. for testName, tc := range deviceTests {
  176. t.Run(testName, func(t *testing.T) {
  177. tc.query(queryStart, queryEnd)
  178. logged := logWriter.Log
  179. if !strings.Contains(logged, tc.wantMetric) {
  180. t.Errorf("expected query to reference %q, got: %s", tc.wantMetric, logged)
  181. }
  182. if !strings.Contains(logged, wantDeviceFilter) {
  183. t.Errorf("expected query to contain cluster filter %q, got: %s", wantDeviceFilter, logged)
  184. }
  185. // device-level grouping: no container attribution
  186. if !strings.Contains(logged, gpuDeviceByLabels) || strings.Contains(logged, "container,") {
  187. t.Errorf("expected device-level grouping %q without container, got: %s", gpuDeviceByLabels, logged)
  188. }
  189. })
  190. }
  191. const wantFilter = `cluster_id="test-cluster"`
  192. for testName, tc := range tests {
  193. t.Run(testName, func(t *testing.T) {
  194. tc.query(queryStart, queryEnd)
  195. logged := logWriter.Log
  196. if !strings.Contains(logged, testName) {
  197. t.Errorf("expected log to contain query name %q, got: %s", testName, logged)
  198. }
  199. if !strings.Contains(logged, tc.wantMetric) {
  200. t.Errorf("expected query to reference %q, got: %s", tc.wantMetric, logged)
  201. }
  202. if tc.wantExtra != "" && !strings.Contains(logged, tc.wantExtra) {
  203. t.Errorf("expected query to contain %q, got: %s", tc.wantExtra, logged)
  204. }
  205. if !strings.Contains(logged, wantFilter) {
  206. t.Errorf("expected query to contain cluster filter %q, got: %s", wantFilter, logged)
  207. }
  208. if !strings.Contains(logged, gpuSaturationByLabels) {
  209. t.Errorf("expected query to group by %q, got: %s", gpuSaturationByLabels, logged)
  210. }
  211. })
  212. }
  213. }