Pārlūkot izejas kodu

add middleware for auth (#3651)

Alex Meijer 2 mēneši atpakaļ
vecāks
revīzija
69430e7f62
3 mainītis faili ar 147 papildinājumiem un 8 dzēšanām
  1. 34 8
      pkg/costmodel/router.go
  2. 106 0
      pkg/costmodel/router_test.go
  3. 7 0
      pkg/env/costmodel.go

+ 34 - 8
pkg/costmodel/router.go

@@ -2,6 +2,7 @@ package costmodel
 
 import (
 	"context"
+	"crypto/subtle"
 	"encoding/base64"
 	"fmt"
 	"net/http"
@@ -127,6 +128,32 @@ func ParsePercentString(percentStr string) (float64, error) {
 	return discount, nil
 }
 
+// adminAuthMiddleware wraps a handler and requires a Bearer token matching ADMIN_TOKEN env var when set.
+// When ADMIN_TOKEN is not set, logs a deduped warning and allows the request through.
+// When ADMIN_TOKEN is set, returns 401 if the Bearer token is missing or 403 if it does not match.
+func adminAuthMiddleware(next httprouter.Handle) httprouter.Handle {
+	return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
+		adminToken := env.GetAdminToken()
+		if adminToken == "" {
+			log.DedupedWarningf(5, "Admin token (ADMIN_TOKEN) not configured; write operations are unauthenticated")
+			next(w, r, ps)
+			return
+		}
+		authHeader := r.Header.Get("Authorization")
+		const prefix = "Bearer "
+		if !strings.HasPrefix(authHeader, prefix) {
+			http.Error(w, "Missing or invalid authorization", http.StatusUnauthorized)
+			return
+		}
+		bearerToken := strings.TrimPrefix(authHeader, prefix)
+		if subtle.ConstantTimeCompare([]byte(bearerToken), []byte(adminToken)) != 1 {
+			http.Error(w, "Missing or invalid authorization", http.StatusForbidden)
+			return
+		}
+		next(w, r, ps)
+	}
+}
+
 func WriteData(w http.ResponseWriter, data interface{}, err error) {
 	if err != nil {
 		proto.WriteError(w, proto.InternalServerError(err.Error()))
@@ -541,7 +568,7 @@ func Initialize(router *httprouter.Router, additionalConfigWatchers ...*watcher.
 	router.GET("/orphanedPods", a.GetOrphanedPods)
 	router.GET("/installNamespace", a.GetInstallNamespace)
 	router.GET("/installInfo", a.GetInstallInfo)
-	router.POST("/serviceKey", a.AddServiceKey)
+	router.POST("/serviceKey", adminAuthMiddleware(a.AddServiceKey))
 	router.GET("/helmValues", a.GetHelmValues)
 
 	return a
@@ -589,19 +616,18 @@ func InitializeCloudCost(router *httprouter.Router, providerConfig models.Provid
 	repoQuerier := cloudcost.NewRepositoryQuerier(repo)
 	cloudCostQueryService := cloudcost.NewQueryService(repoQuerier, repoQuerier)
 
-	router.GET("/cloud/config/export", cloudConfigController.GetExportConfigHandler())
-	router.GET("/cloud/config/enable", cloudConfigController.GetEnableConfigHandler())
-	router.GET("/cloud/config/disable", cloudConfigController.GetDisableConfigHandler())
-	router.GET("/cloud/config/delete", cloudConfigController.GetDeleteConfigHandler())
-
 	router.GET("/cloudCost", cloudCostQueryService.GetCloudCostHandler())
 	router.GET("/cloudCost/view/graph", cloudCostQueryService.GetCloudCostViewGraphHandler())
 	router.GET("/cloudCost/view/totals", cloudCostQueryService.GetCloudCostViewTotalsHandler())
 	router.GET("/cloudCost/view/table", cloudCostQueryService.GetCloudCostViewTableHandler(nil))
 
 	router.GET("/cloudCost/status", cloudCostPipelineService.GetCloudCostStatusHandler())
-	router.GET("/cloudCost/rebuild", cloudCostPipelineService.GetCloudCostRebuildHandler())
-	router.GET("/cloudCost/repair", cloudCostPipelineService.GetCloudCostRepairHandler())
+	router.GET("/cloudCost/rebuild", adminAuthMiddleware(cloudCostPipelineService.GetCloudCostRebuildHandler()))
+	router.GET("/cloudCost/repair", adminAuthMiddleware(cloudCostPipelineService.GetCloudCostRepairHandler()))
+	router.GET("/cloud/config/export", adminAuthMiddleware(cloudConfigController.GetExportConfigHandler()))
+	router.GET("/cloud/config/enable", adminAuthMiddleware(cloudConfigController.GetEnableConfigHandler()))
+	router.GET("/cloud/config/disable", adminAuthMiddleware(cloudConfigController.GetDisableConfigHandler()))
+	router.GET("/cloud/config/delete", adminAuthMiddleware(cloudConfigController.GetDeleteConfigHandler()))
 
 	return cloudCostPipelineService
 }

+ 106 - 0
pkg/costmodel/router_test.go

@@ -0,0 +1,106 @@
+package costmodel
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"testing"
+
+	"github.com/julienschmidt/httprouter"
+	"github.com/opencost/opencost/pkg/env"
+)
+
+func TestAdminAuthMiddleware(t *testing.T) {
+	const testToken = "test-admin-token-123"
+
+	nextCalled := false
+	next := func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
+		nextCalled = true
+		w.WriteHeader(http.StatusOK)
+	}
+
+	tests := []struct {
+		name           string
+		setToken       string
+		authHeader     string
+		wantStatus     int
+		wantNextCalled bool
+	}{
+		{
+			name:           "no admin token configured - request allowed with deduped warning",
+			setToken:       "",
+			authHeader:     "",
+			wantStatus:     http.StatusOK,
+			wantNextCalled: true,
+		},
+		{
+			name:           "missing authorization header",
+			setToken:       testToken,
+			authHeader:     "",
+			wantStatus:     http.StatusUnauthorized,
+			wantNextCalled: false,
+		},
+		{
+			name:           "wrong authorization scheme",
+			setToken:       testToken,
+			authHeader:     "Basic dXNlcjpwYXNz",
+			wantStatus:     http.StatusUnauthorized,
+			wantNextCalled: false,
+		},
+		{
+			name:           "bearer with wrong token",
+			setToken:       testToken,
+			authHeader:     "Bearer wrong-token",
+			wantStatus:     http.StatusForbidden,
+			wantNextCalled: false,
+		},
+		{
+			name:           "bearer with correct token",
+			setToken:       testToken,
+			authHeader:     "Bearer " + testToken,
+			wantStatus:     http.StatusOK,
+			wantNextCalled: true,
+		},
+		{
+			name:           "bearer token with extra spaces after prefix",
+			setToken:       testToken,
+			authHeader:     "Bearer  " + testToken,
+			wantStatus:     http.StatusForbidden,
+			wantNextCalled: false,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			prev := os.Getenv(env.AdminTokenEnvVar)
+			defer func() {
+				if prev == "" {
+					os.Unsetenv(env.AdminTokenEnvVar)
+				} else {
+					os.Setenv(env.AdminTokenEnvVar, prev)
+				}
+			}()
+			if tt.setToken != "" {
+				os.Setenv(env.AdminTokenEnvVar, tt.setToken)
+			} else {
+				os.Unsetenv(env.AdminTokenEnvVar)
+			}
+
+			nextCalled = false
+			req := httptest.NewRequest(http.MethodPost, "/serviceKey", nil)
+			if tt.authHeader != "" {
+				req.Header.Set("Authorization", tt.authHeader)
+			}
+			rec := httptest.NewRecorder()
+
+			handler := adminAuthMiddleware(next)
+			handler(rec, req, httprouter.Params{})
+
+			if rec.Code != tt.wantStatus {
+				t.Errorf("status = %d, want %d", rec.Code, tt.wantStatus)
+			}
+			if nextCalled != tt.wantNextCalled {
+				t.Errorf("nextCalled = %v, want %v", nextCalled, tt.wantNextCalled)
+			}
+		})
+	}
+}

+ 7 - 0
pkg/env/costmodel.go

@@ -104,6 +104,9 @@ const (
 	MCPServerEnabledEnvVar = "MCP_SERVER_ENABLED"
 	MCPHTTPPortEnvVar      = "MCP_HTTP_PORT"
 
+	// Admin write operations (e.g. serviceKey, cloud config)
+	AdminTokenEnvVar = "ADMIN_TOKEN"
+
 	// Metrics Emitter
 	MetricsEmitterQueryWindowEnvVar = "METRICS_EMITTER_QUERY_WINDOW"
 )
@@ -112,6 +115,10 @@ func GetGCPAuthSecretFilePath() string {
 	return env.GetPathFromConfig(GCPAuthSecretFile)
 }
 
+func GetAdminToken() string {
+	return env.Get(AdminTokenEnvVar, "")
+}
+
 func GetExportCSVFile() string {
 	return env.Get(ExportCSVFile, "")
 }