router_test.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. package costmodel
  2. import (
  3. "net/http"
  4. "net/http/httptest"
  5. "os"
  6. "testing"
  7. "github.com/julienschmidt/httprouter"
  8. "github.com/opencost/opencost/pkg/env"
  9. )
  10. func TestAdminAuthMiddleware(t *testing.T) {
  11. const testToken = "test-admin-token-123"
  12. nextCalled := false
  13. next := func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
  14. nextCalled = true
  15. w.WriteHeader(http.StatusOK)
  16. }
  17. tests := []struct {
  18. name string
  19. setToken string
  20. authHeader string
  21. wantStatus int
  22. wantNextCalled bool
  23. }{
  24. {
  25. name: "no admin token configured - request allowed with deduped warning",
  26. setToken: "",
  27. authHeader: "",
  28. wantStatus: http.StatusOK,
  29. wantNextCalled: true,
  30. },
  31. {
  32. name: "missing authorization header",
  33. setToken: testToken,
  34. authHeader: "",
  35. wantStatus: http.StatusUnauthorized,
  36. wantNextCalled: false,
  37. },
  38. {
  39. name: "wrong authorization scheme",
  40. setToken: testToken,
  41. authHeader: "Basic dXNlcjpwYXNz",
  42. wantStatus: http.StatusUnauthorized,
  43. wantNextCalled: false,
  44. },
  45. {
  46. name: "bearer with wrong token",
  47. setToken: testToken,
  48. authHeader: "Bearer wrong-token",
  49. wantStatus: http.StatusForbidden,
  50. wantNextCalled: false,
  51. },
  52. {
  53. name: "bearer with correct token",
  54. setToken: testToken,
  55. authHeader: "Bearer " + testToken,
  56. wantStatus: http.StatusOK,
  57. wantNextCalled: true,
  58. },
  59. {
  60. name: "bearer token with extra spaces after prefix",
  61. setToken: testToken,
  62. authHeader: "Bearer " + testToken,
  63. wantStatus: http.StatusForbidden,
  64. wantNextCalled: false,
  65. },
  66. }
  67. for _, tt := range tests {
  68. t.Run(tt.name, func(t *testing.T) {
  69. prev := os.Getenv(env.AdminTokenEnvVar)
  70. defer func() {
  71. if prev == "" {
  72. os.Unsetenv(env.AdminTokenEnvVar)
  73. } else {
  74. os.Setenv(env.AdminTokenEnvVar, prev)
  75. }
  76. }()
  77. if tt.setToken != "" {
  78. os.Setenv(env.AdminTokenEnvVar, tt.setToken)
  79. } else {
  80. os.Unsetenv(env.AdminTokenEnvVar)
  81. }
  82. nextCalled = false
  83. req := httptest.NewRequest(http.MethodPost, "/serviceKey", nil)
  84. if tt.authHeader != "" {
  85. req.Header.Set("Authorization", tt.authHeader)
  86. }
  87. rec := httptest.NewRecorder()
  88. handler := adminAuthMiddleware(next)
  89. handler(rec, req, httprouter.Params{})
  90. if rec.Code != tt.wantStatus {
  91. t.Errorf("status = %d, want %d", rec.Code, tt.wantStatus)
  92. }
  93. if nextCalled != tt.wantNextCalled {
  94. t.Errorf("nextCalled = %v, want %v", nextCalled, tt.wantNextCalled)
  95. }
  96. })
  97. }
  98. }