diff --git a/internal/http/response/builder.go b/internal/http/response/builder.go index db89654a..e2001a82 100644 --- a/internal/http/response/builder.go +++ b/internal/http/response/builder.go @@ -25,7 +25,7 @@ type Builder struct { statusCode int headers map[string]string enableCompression bool - body interface{} + body any } // WithStatus uses the given status code to build the response. @@ -41,7 +41,7 @@ func (b *Builder) WithHeader(key, value string) *Builder { } // WithBody uses the given body to build the response. -func (b *Builder) WithBody(body interface{}) *Builder { +func (b *Builder) WithBody(body any) *Builder { b.body = body return b } diff --git a/internal/http/response/json/json.go b/internal/http/response/json/json.go index 8e99681a..7f5d33d5 100644 --- a/internal/http/response/json/json.go +++ b/internal/http/response/json/json.go @@ -16,19 +16,31 @@ import ( const contentTypeHeader = `application/json` // OK creates a new JSON response with a 200 status code. -func OK(w http.ResponseWriter, r *http.Request, body interface{}) { +func OK(w http.ResponseWriter, r *http.Request, body any) { + responseBody, err := json.Marshal(body) + if err != nil { + ServerError(w, r, err) + return + } + builder := response.New(w, r) builder.WithHeader("Content-Type", contentTypeHeader) - builder.WithBody(toJSON(body)) + builder.WithBody(responseBody) builder.Write() } // Created sends a created response to the client. -func Created(w http.ResponseWriter, r *http.Request, body interface{}) { +func Created(w http.ResponseWriter, r *http.Request, body any) { + responseBody, err := json.Marshal(body) + if err != nil { + ServerError(w, r, err) + return + } + builder := response.New(w, r) builder.WithStatus(http.StatusCreated) builder.WithHeader("Content-Type", contentTypeHeader) - builder.WithBody(toJSON(body)) + builder.WithBody(responseBody) builder.Write() } @@ -62,10 +74,17 @@ func ServerError(w http.ResponseWriter, r *http.Request, err error) { ), ) + responseBody, jsonErr := generateJSONError(err) + if jsonErr != nil { + slog.Error("Unable to generate JSON error", slog.Any("error", jsonErr)) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + builder := response.New(w, r) builder.WithStatus(http.StatusInternalServerError) builder.WithHeader("Content-Type", contentTypeHeader) - builder.WithBody(toJSONError(err)) + builder.WithBody(responseBody) builder.Write() } @@ -84,10 +103,17 @@ func BadRequest(w http.ResponseWriter, r *http.Request, err error) { ), ) + responseBody, jsonErr := generateJSONError(err) + if jsonErr != nil { + slog.Error("Unable to generate JSON error", slog.Any("error", jsonErr)) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + builder := response.New(w, r) builder.WithStatus(http.StatusBadRequest) builder.WithHeader("Content-Type", contentTypeHeader) - builder.WithBody(toJSONError(err)) + builder.WithBody(responseBody) builder.Write() } @@ -105,10 +131,17 @@ func Unauthorized(w http.ResponseWriter, r *http.Request) { ), ) + responseBody, jsonErr := generateJSONError(errors.New("access unauthorized")) + if jsonErr != nil { + slog.Error("Unable to generate JSON error", slog.Any("error", jsonErr)) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + builder := response.New(w, r) builder.WithStatus(http.StatusUnauthorized) builder.WithHeader("Content-Type", contentTypeHeader) - builder.WithBody(toJSONError(errors.New("access unauthorized"))) + builder.WithBody(responseBody) builder.Write() } @@ -126,10 +159,17 @@ func Forbidden(w http.ResponseWriter, r *http.Request) { ), ) + responseBody, jsonErr := generateJSONError(errors.New("access forbidden")) + if jsonErr != nil { + slog.Error("Unable to generate JSON error", slog.Any("error", jsonErr)) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + builder := response.New(w, r) builder.WithStatus(http.StatusForbidden) builder.WithHeader("Content-Type", contentTypeHeader) - builder.WithBody(toJSONError(errors.New("access forbidden"))) + builder.WithBody(responseBody) builder.Write() } @@ -147,27 +187,29 @@ func NotFound(w http.ResponseWriter, r *http.Request) { ), ) + responseBody, jsonErr := generateJSONError(errors.New("resource not found")) + if jsonErr != nil { + slog.Error("Unable to generate JSON error", slog.Any("error", jsonErr)) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + builder := response.New(w, r) builder.WithStatus(http.StatusNotFound) builder.WithHeader("Content-Type", contentTypeHeader) - builder.WithBody(toJSONError(errors.New("resource not found"))) + builder.WithBody(responseBody) builder.Write() } -func toJSONError(err error) []byte { +func generateJSONError(err error) ([]byte, error) { type errorMsg struct { ErrorMessage string `json:"error_message"` } - return toJSON(errorMsg{ErrorMessage: err.Error()}) -} - -func toJSON(v interface{}) []byte { - b, err := json.Marshal(v) + encodedBody, err := json.Marshal(errorMsg{ErrorMessage: err.Error()}) if err != nil { - slog.Error("Unable to marshal JSON response", slog.Any("error", err)) - return []byte("") + return nil, err } - return b + return encodedBody, nil } diff --git a/internal/http/response/json/json_test.go b/internal/http/response/json/json_test.go index 9a1e1a03..1b6685c4 100644 --- a/internal/http/response/json/json_test.go +++ b/internal/http/response/json/json_test.go @@ -293,12 +293,12 @@ func TestBuildInvalidJSONResponse(t *testing.T) { handler.ServeHTTP(w, r) resp := w.Result() - expectedStatusCode := http.StatusOK + expectedStatusCode := http.StatusInternalServerError if resp.StatusCode != expectedStatusCode { t.Fatalf(`Unexpected status code, got %d instead of %d`, resp.StatusCode, expectedStatusCode) } - expectedBody := `` + expectedBody := `{"error_message":"json: unsupported type: chan int"}` actualBody := w.Body.String() if actualBody != expectedBody { t.Fatalf(`Unexpected body, got %s instead of %s`, actualBody, expectedBody)