Explorar el Código

Merge branch 'develop' into bolt/opencost-mods

Matt Bolt hace 1 año
padre
commit
e86e159d6a

+ 5 - 5
.github/workflows/integration-testing.yaml

@@ -54,7 +54,7 @@ jobs:
         runs-on: ubuntu-latest
         permissions: {}
         needs: check_actor_permissions
-        if: ${{ (always() && !cancelled()) && ( github.event_name == 'push' || github.event_name == 'merge_group' || (github.event_name == 'pull_request_target'  && needs.check_actor_permissions.outputs.ismaintainer == 'true')) }}
+        if: ${{ (always() && !cancelled()) && ( github.event.event_name == 'schedule' || github.event_name == 'push' || github.event_name == 'merge_group' || (github.event_name == 'pull_request_target'  && needs.check_actor_permissions.outputs.ismaintainer == 'true')) }}
         outputs:
             IMAGE_TAG: ${{ steps.set_image_tags.outputs.IMAGE_TAG }}
             NAMESPACE: ${{ steps.set_image_tags.outputs.NAMESPACE }}
@@ -117,7 +117,7 @@ jobs:
                 
     build-test-stack:
         needs: wait_for_image_ready
-        if: ${{ (always() && !cancelled()) && ( github.event_name == 'push' || github.event_name == 'merge_group' || (github.event_name == 'pull_request_target'  && needs.check_actor_permissions.outputs.ismaintainer == 'true')) }}
+        if: ${{ (always() && !cancelled()) && ( github.event.event_name == 'schedule'  || github.event_name == 'push' || github.event_name == 'merge_group' || (github.event_name == 'pull_request_target'  && needs.check_actor_permissions.outputs.ismaintainer == 'true')) }}
         uses: opencost/opencost-infra/.github/workflows/build-stack.yaml@main
         secrets: inherit
         with:
@@ -127,7 +127,7 @@ jobs:
     wait-for-dns:
         needs: [wait_for_image_ready, build-test-stack]
         runs-on: ubuntu-latest
-        if: ${{ (always() && !cancelled()) && ( github.event_name == 'push' || github.event_name == 'merge_group' || (github.event_name == 'pull_request_target'  && needs.check_actor_permissions.outputs.ismaintainer == 'true')) }}
+        if: ${{ (always() && !cancelled()) && ( github.event.event_name == 'schedule'  || github.event_name == 'push' || github.event_name == 'merge_group' || (github.event_name == 'pull_request_target'  && needs.check_actor_permissions.outputs.ismaintainer == 'true')) }}
         permissions: {}
         steps:
           - name: Wait for DNS to resolve
@@ -149,7 +149,7 @@ jobs:
 
     run-tests:
         needs: [wait_for_image_ready, build-test-stack, wait-for-dns]
-        if: ${{ (always() && !cancelled()) && ( github.event_name == 'push' || github.event_name == 'merge_group' || (github.event_name == 'pull_request_target'  && needs.check_actor_permissions.outputs.ismaintainer == 'true')) }}
+        if: ${{ (always() && !cancelled()) && ( github.event.event_name == 'schedule'  || github.event_name == 'push' || github.event_name == 'merge_group' || (github.event_name == 'pull_request_target'  && needs.check_actor_permissions.outputs.ismaintainer == 'true')) }}
         permissions: {}
         uses: opencost/opencost-infra/.github/workflows/test-stack.yaml@main
         secrets: inherit
@@ -159,7 +159,7 @@ jobs:
     
     teardown-test-stack:
         needs: [wait_for_image_ready, run-tests]
-        if: ${{ (always() && !cancelled()) && ( github.event_name == 'push' || github.event_name == 'merge_group' || (github.event_name == 'pull_request_target'  && needs.check_actor_permissions.outputs.ismaintainer == 'true')) }}
+        if: ${{ (always() && !cancelled()) && ( github.event.event_name == 'schedule'  || github.event_name == 'push' || github.event_name == 'merge_group' || (github.event_name == 'pull_request_target'  && needs.check_actor_permissions.outputs.ismaintainer == 'true')) }}
         uses: opencost/opencost-infra/.github/workflows/destroy-stack.yaml@main
         secrets: inherit 
         permissions: {}

+ 1 - 0
core/go.mod

@@ -23,6 +23,7 @@ require (
 	github.com/prometheus/common v0.63.0
 	github.com/rs/zerolog v1.26.1
 	github.com/spf13/viper v1.8.1
+	github.com/stretchr/testify v1.9.0
 	golang.org/x/exp v0.0.0-20221031165847-c99f073a8326
 	golang.org/x/oauth2 v0.25.0
 	golang.org/x/sync v0.12.0

+ 26 - 5
core/pkg/protocol/http.go

@@ -12,6 +12,8 @@ import (
 // HTTPProtocol is a struct used as a selector for request/response protocol utility methods
 type HTTPProtocol struct{}
 
+const internalServerErrorJSON = `{"code":500,"message":"Internal Server Error"}`
+
 // HTTPError represents an http error response
 type HTTPError struct {
 	StatusCode int
@@ -123,14 +125,14 @@ func (hp HTTPProtocol) WriteJSONData(w http.ResponseWriter, data interface{}) {
 	w.WriteHeader(status)
 	if err := json.NewEncoder(w).Encode(data); err != nil {
 		log.Error("Failed to encode JSON response: " + err.Error())
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(internalServerErrorJSON))
 	}
 }
 
 // WriteRawError uses json content-type and outputs raw error message for backwards compatibility to existing
 // frontend expectations.
 func (hp HTTPProtocol) WriteRawError(w http.ResponseWriter, httpStatusCode int, err string) {
-	// I know this isn't json, but its what we've done and don't want to break frontned while we fix CWE
-	w.Header().Set("Content-Type", "application/json")
 	http.Error(w, err, httpStatusCode)
 }
 
@@ -140,6 +142,8 @@ func (hp HTTPProtocol) WriteEncodedError(w http.ResponseWriter, httpStatusCode i
 	w.WriteHeader(httpStatusCode)
 	if err := json.NewEncoder(w).Encode(errorResponse); err != nil {
 		log.Error("Failed to encode error response: " + err.Error())
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(internalServerErrorJSON))
 	}
 }
 
@@ -149,8 +153,15 @@ func (hp HTTPProtocol) WriteData(w http.ResponseWriter, data interface{}) {
 	w.Header().Set("Content-Type", "application/json")
 	status := http.StatusOK
 	w.WriteHeader(status)
-	if err := json.NewEncoder(w).Encode(data); err != nil {
+
+	resp := &HTTPResponse{
+		Code: status,
+		Data: data,
+	}
+	if err := json.NewEncoder(w).Encode(resp); err != nil {
 		log.Error("Failed to encode response: " + err.Error())
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(internalServerErrorJSON))
 	}
 }
 
@@ -166,6 +177,8 @@ func (hp HTTPProtocol) WriteDataWithWarning(w http.ResponseWriter, data interfac
 	w.WriteHeader(status)
 	if err := json.NewEncoder(w).Encode(resp); err != nil {
 		log.Error("Failed to encode response with warning: " + err.Error())
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(internalServerErrorJSON))
 	}
 }
 
@@ -181,6 +194,8 @@ func (hp HTTPProtocol) WriteDataWithMessage(w http.ResponseWriter, data interfac
 	w.WriteHeader(status)
 	if err := json.NewEncoder(w).Encode(resp); err != nil {
 		log.Error("Failed to encode response with message: " + err.Error())
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(internalServerErrorJSON))
 	}
 }
 
@@ -218,6 +233,8 @@ func (hp HTTPProtocol) WriteDataWithMessageAndWarning(w http.ResponseWriter, dat
 	w.WriteHeader(status)
 	if err := json.NewEncoder(w).Encode(resp); err != nil {
 		log.Error("Failed to encode response with message and warning: " + err.Error())
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(internalServerErrorJSON))
 	}
 }
 
@@ -230,12 +247,14 @@ func (hp HTTPProtocol) WriteError(w http.ResponseWriter, err HTTPError) {
 	}
 	w.WriteHeader(status)
 
-	resp, _ := json.Marshal(&HTTPResponse{
+	resp := &HTTPResponse{
 		Code:    status,
 		Message: err.Body,
-	})
+	}
 	if err := json.NewEncoder(w).Encode(resp); err != nil {
 		log.Error("Failed to encode error response: " + err.Error())
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(internalServerErrorJSON))
 	}
 }
 
@@ -246,5 +265,7 @@ func (hp HTTPProtocol) WriteResponse(w http.ResponseWriter, r *HTTPResponse) {
 	w.WriteHeader(status)
 	if err := json.NewEncoder(w).Encode(r); err != nil {
 		log.Error("Failed to encode response: " + err.Error())
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(internalServerErrorJSON))
 	}
 }

+ 186 - 0
core/pkg/protocol/http_test.go

@@ -0,0 +1,186 @@
+package protocol
+
+import (
+	"errors"
+	"log"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestHTTPError_Error(t *testing.T) {
+	err := HTTPError{StatusCode: 400, Body: "bad request"}
+	assert.Equal(t, "bad request", err.Error())
+}
+
+func TestHTTPProtocol_BadRequest(t *testing.T) {
+	hp := HTTPProtocol{}
+	err := hp.BadRequest("bad req")
+	assert.Equal(t, http.StatusBadRequest, err.StatusCode)
+	assert.Equal(t, "bad req", err.Body)
+}
+
+func TestHTTPProtocol_UnprocessableEntity(t *testing.T) {
+	hp := HTTPProtocol{}
+	err := hp.UnprocessableEntity("")
+	assert.Equal(t, http.StatusUnprocessableEntity, err.StatusCode)
+	assert.Equal(t, "Unprocessable Entity", err.Body)
+	err2 := hp.UnprocessableEntity("custom")
+	assert.Equal(t, "custom", err2.Body)
+}
+
+func TestHTTPProtocol_InternalServerError(t *testing.T) {
+	hp := HTTPProtocol{}
+	err := hp.InternalServerError("")
+	assert.Equal(t, http.StatusInternalServerError, err.StatusCode)
+	assert.Equal(t, "Internal Server Error", err.Body)
+	err2 := hp.InternalServerError("custom")
+	assert.Equal(t, "custom", err2.Body)
+}
+
+func TestHTTPProtocol_NotImplemented(t *testing.T) {
+	hp := HTTPProtocol{}
+	err := hp.NotImplemented("")
+	assert.Equal(t, http.StatusNotImplemented, err.StatusCode)
+	assert.Equal(t, "Not Implemented", err.Body)
+	err2 := hp.NotImplemented("custom")
+	assert.Equal(t, "custom", err2.Body)
+}
+
+func TestHTTPProtocol_Forbidden(t *testing.T) {
+	hp := HTTPProtocol{}
+	err := hp.Forbidden("")
+	assert.Equal(t, http.StatusForbidden, err.StatusCode)
+	assert.Equal(t, "Forbidden", err.Body)
+	err2 := hp.Forbidden("custom")
+	assert.Equal(t, "custom", err2.Body)
+}
+
+func TestHTTPProtocol_NotFound(t *testing.T) {
+	hp := HTTPProtocol{}
+	err := hp.NotFound()
+	assert.Equal(t, http.StatusNotFound, err.StatusCode)
+	assert.Equal(t, "Not Found", err.Body)
+}
+
+func TestHTTPProtocol_ToResponse(t *testing.T) {
+	hp := HTTPProtocol{}
+	resp := hp.ToResponse("data", nil)
+	assert.Equal(t, http.StatusOK, resp.Code)
+	assert.Equal(t, "data", resp.Data)
+	resp2 := hp.ToResponse("data", errors.New("fail"))
+	assert.Equal(t, http.StatusInternalServerError, resp2.Code)
+	assert.Equal(t, "fail", resp2.Message)
+}
+
+func TestHTTPProtocol_WriteRawOK(t *testing.T) {
+	hp := HTTPProtocol{}
+	rw := httptest.NewRecorder()
+	hp.WriteRawOK(rw)
+	assert.Equal(t, http.StatusOK, rw.Code)
+	assert.Equal(t, "application/json", rw.Header().Get("Content-Type"))
+	assert.Equal(t, "0", rw.Header().Get("Content-Length"))
+}
+
+func TestHTTPProtocol_WriteRawNoContent(t *testing.T) {
+	hp := HTTPProtocol{}
+	rw := httptest.NewRecorder()
+	hp.WriteRawNoContent(rw)
+	assert.Equal(t, http.StatusNoContent, rw.Code)
+	assert.Equal(t, "application/json", rw.Header().Get("Content-Type"))
+}
+
+func TestHTTPProtocol_WriteJSONData(t *testing.T) {
+	hp := HTTPProtocol{}
+	rw := httptest.NewRecorder()
+	hp.WriteJSONData(rw, map[string]string{"foo": "bar"})
+	assert.Equal(t, http.StatusOK, rw.Code)
+	assert.Contains(t, rw.Body.String(), "foo")
+}
+
+func TestHTTPProtocol_WriteRawError(t *testing.T) {
+	hp := HTTPProtocol{}
+	rw := httptest.NewRecorder()
+	hp.WriteRawError(rw, http.StatusBadRequest, "bad")
+	assert.Equal(t, http.StatusBadRequest, rw.Code)
+	assert.Equal(t, "text/plain; charset=utf-8", rw.Header().Get("Content-Type"))
+	assert.Contains(t, rw.Body.String(), "bad")
+}
+
+func TestHTTPProtocol_WriteEncodedError(t *testing.T) {
+	hp := HTTPProtocol{}
+	rw := httptest.NewRecorder()
+	hp.WriteEncodedError(rw, http.StatusBadRequest, map[string]string{"err": "bad"})
+	assert.Equal(t, http.StatusBadRequest, rw.Code)
+	assert.Contains(t, rw.Body.String(), "bad")
+}
+
+func TestHTTPProtocol_WriteData(t *testing.T) {
+	hp := HTTPProtocol{}
+	rw := httptest.NewRecorder()
+	hp.WriteData(rw, map[string]string{"foo": "bar"})
+	assert.Equal(t, http.StatusOK, rw.Code)
+	assert.Contains(t, rw.Body.String(), "foo")
+}
+
+func TestHTTPProtocol_WriteDataWithWarning(t *testing.T) {
+	hp := HTTPProtocol{}
+	rw := httptest.NewRecorder()
+	hp.WriteDataWithWarning(rw, map[string]string{"foo": "bar"}, "warn")
+	assert.Equal(t, http.StatusOK, rw.Code)
+	assert.Contains(t, rw.Body.String(), "warn")
+}
+
+func TestHTTPProtocol_WriteDataWithMessage(t *testing.T) {
+	hp := HTTPProtocol{}
+	rw := httptest.NewRecorder()
+	hp.WriteDataWithMessage(rw, map[string]string{"foo": "bar"}, "msg")
+	assert.Equal(t, http.StatusOK, rw.Code)
+	assert.Contains(t, rw.Body.String(), "msg")
+}
+
+func TestHTTPProtocol_WriteDataWithMessageAndWarning(t *testing.T) {
+	hp := HTTPProtocol{}
+	rw := httptest.NewRecorder()
+	hp.WriteDataWithMessageAndWarning(rw, map[string]string{"foo": "bar"}, "msg", "warn")
+	assert.Equal(t, http.StatusOK, rw.Code)
+	assert.Contains(t, rw.Body.String(), "msg")
+	assert.Contains(t, rw.Body.String(), "warn")
+}
+
+func TestHTTPProtocol_WriteError(t *testing.T) {
+	hp := HTTPProtocol{}
+	rw := httptest.NewRecorder()
+	hp.WriteError(rw, HTTPError{StatusCode: 400, Body: "fail"})
+	assert.Equal(t, 400, rw.Code)
+	body := rw.Body.String()
+	log.Println("body: " + body)
+	assert.Contains(t, body, "fail")
+}
+
+func TestHTTPProtocol_WriteResponse(t *testing.T) {
+	hp := HTTPProtocol{}
+	rw := httptest.NewRecorder()
+	resp := &HTTPResponse{Code: 200, Data: "foo"}
+	hp.WriteResponse(rw, resp)
+	assert.Equal(t, 200, rw.Code)
+	assert.Contains(t, rw.Body.String(), "foo")
+}
+
+func TestHTTPProtocol_WriteData_Structure(t *testing.T) {
+	hp := HTTPProtocol{}
+	rw := httptest.NewRecorder()
+	data := map[string]string{"foo": "bar"}
+	hp.WriteData(rw, data)
+	assert.Equal(t, http.StatusOK, rw.Code)
+	assert.Equal(t, "application/json", rw.Header().Get("Content-Type"))
+
+	// Check the structure of the JSON response
+	body := rw.Body.String()
+	assert.Contains(t, body, "\"code\":200")
+	assert.Contains(t, body, "\"data\":{\"foo\":\"bar\"}")
+	assert.NotContains(t, body, "message")
+	assert.NotContains(t, body, "warning")
+}