1
0
Fork 0
mirror of https://github.com/miniflux/v2.git synced 2025-09-15 18:57:04 +00:00

Improve request package and add more unit tests

This commit is contained in:
Frédéric Guillot 2018-09-23 21:02:26 -07:00
parent 844680e573
commit 9d08139f43
49 changed files with 916 additions and 400 deletions

38
http/request/client_ip.go Normal file
View file

@ -0,0 +1,38 @@
// Copyright 2018 Frédéric Guillot. All rights reserved.
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.
package request // import "miniflux.app/http/request"
import (
"net"
"net/http"
"strings"
)
// FindClientIP returns client real IP address.
func FindClientIP(r *http.Request) string {
headers := []string{"X-Forwarded-For", "X-Real-Ip"}
for _, header := range headers {
value := r.Header.Get(header)
if value != "" {
addresses := strings.Split(value, ",")
address := strings.TrimSpace(addresses[0])
if net.ParseIP(address) != nil {
return address
}
}
}
// Fallback to TCP/IP source IP address.
var remoteIP string
if strings.ContainsRune(r.RemoteAddr, ':') {
remoteIP, _, _ = net.SplitHostPort(r.RemoteAddr)
} else {
remoteIP = r.RemoteAddr
}
return remoteIP
}

View file

@ -9,7 +9,7 @@ import (
"testing"
)
func TestRealIPWithoutHeaders(t *testing.T) {
func TestFindClientIPWithoutHeaders(t *testing.T) {
r := &http.Request{RemoteAddr: "192.168.0.1:4242"}
if ip := FindClientIP(r); ip != "192.168.0.1" {
t.Fatalf(`Unexpected result, got: %q`, ip)
@ -21,7 +21,7 @@ func TestRealIPWithoutHeaders(t *testing.T) {
}
}
func TestRealIPWithXFFHeader(t *testing.T) {
func TestFindClientIPWithXFFHeader(t *testing.T) {
// Test with multiple IPv4 addresses.
headers := http.Header{}
headers.Set("X-Forwarded-For", "203.0.113.195, 70.41.3.18, 150.172.238.178")
@ -59,7 +59,7 @@ func TestRealIPWithXFFHeader(t *testing.T) {
}
}
func TestRealIPWithXRealIPHeader(t *testing.T) {
func TestClientIPWithXRealIPHeader(t *testing.T) {
headers := http.Header{}
headers.Set("X-Real-Ip", "192.168.122.1")
r := &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers}
@ -69,7 +69,7 @@ func TestRealIPWithXRealIPHeader(t *testing.T) {
}
}
func TestRealIPWithBothHeaders(t *testing.T) {
func TestClientIPWithBothHeaders(t *testing.T) {
headers := http.Header{}
headers.Set("X-Forwarded-For", "203.0.113.195, 70.41.3.18, 150.172.238.178")
headers.Set("X-Real-Ip", "192.168.122.1")

View file

@ -111,7 +111,12 @@ func ClientIP(r *http.Request) string {
func getContextStringValue(r *http.Request, key ContextKey) string {
if v := r.Context().Value(key); v != nil {
return v.(string)
value, valid := v.(string)
if !valid {
return ""
}
return value
}
return ""
@ -119,7 +124,12 @@ func getContextStringValue(r *http.Request, key ContextKey) string {
func getContextBoolValue(r *http.Request, key ContextKey) bool {
if v := r.Context().Value(key); v != nil {
return v.(bool)
value, valid := v.(bool)
if !valid {
return false
}
return value
}
return false
@ -127,7 +137,12 @@ func getContextBoolValue(r *http.Request, key ContextKey) bool {
func getContextInt64Value(r *http.Request, key ContextKey) int64 {
if v := r.Context().Value(key); v != nil {
return v.(int64)
value, valid := v.(int64)
if !valid {
return 0
}
return value
}
return 0

View file

@ -0,0 +1,436 @@
// Copyright 2018 Frédéric Guillot. All rights reserved.
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.
package request // import "miniflux.app/http/request"
import (
"context"
"net/http"
"testing"
)
func TestContextStringValue(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
ctx := r.Context()
ctx = context.WithValue(ctx, ClientIPContextKey, "IP")
r = r.WithContext(ctx)
result := getContextStringValue(r, ClientIPContextKey)
expected := "IP"
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
}
func TestContextStringValueWithInvalidType(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
ctx := r.Context()
ctx = context.WithValue(ctx, ClientIPContextKey, 0)
r = r.WithContext(ctx)
result := getContextStringValue(r, ClientIPContextKey)
expected := ""
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
}
func TestContextStringValueWhenUnset(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
result := getContextStringValue(r, ClientIPContextKey)
expected := ""
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
}
func TestContextBoolValue(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
ctx := r.Context()
ctx = context.WithValue(ctx, IsAdminUserContextKey, true)
r = r.WithContext(ctx)
result := getContextBoolValue(r, IsAdminUserContextKey)
expected := true
if result != expected {
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
}
}
func TestContextBoolValueWithInvalidType(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
ctx := r.Context()
ctx = context.WithValue(ctx, IsAdminUserContextKey, "invalid")
r = r.WithContext(ctx)
result := getContextBoolValue(r, IsAdminUserContextKey)
expected := false
if result != expected {
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
}
}
func TestContextBoolValueWhenUnset(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
result := getContextBoolValue(r, IsAdminUserContextKey)
expected := false
if result != expected {
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
}
}
func TestContextInt64Value(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
ctx := r.Context()
ctx = context.WithValue(ctx, UserIDContextKey, int64(1234))
r = r.WithContext(ctx)
result := getContextInt64Value(r, UserIDContextKey)
expected := int64(1234)
if result != expected {
t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected)
}
}
func TestContextInt64ValueWithInvalidType(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
ctx := r.Context()
ctx = context.WithValue(ctx, UserIDContextKey, "invalid")
r = r.WithContext(ctx)
result := getContextInt64Value(r, UserIDContextKey)
expected := int64(0)
if result != expected {
t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected)
}
}
func TestContextInt64ValueWhenUnset(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
result := getContextInt64Value(r, UserIDContextKey)
expected := int64(0)
if result != expected {
t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected)
}
}
func TestIsAdmin(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
result := IsAdminUser(r)
expected := false
if result != expected {
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
}
ctx := r.Context()
ctx = context.WithValue(ctx, IsAdminUserContextKey, true)
r = r.WithContext(ctx)
result = IsAdminUser(r)
expected = true
if result != expected {
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
}
}
func TestIsAuthenticated(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
result := IsAuthenticated(r)
expected := false
if result != expected {
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
}
ctx := r.Context()
ctx = context.WithValue(ctx, IsAuthenticatedContextKey, true)
r = r.WithContext(ctx)
result = IsAuthenticated(r)
expected = true
if result != expected {
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
}
}
func TestUserID(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
result := UserID(r)
expected := int64(0)
if result != expected {
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
}
ctx := r.Context()
ctx = context.WithValue(ctx, UserIDContextKey, int64(123))
r = r.WithContext(ctx)
result = UserID(r)
expected = int64(123)
if result != expected {
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
}
}
func TestUserTimezone(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
result := UserTimezone(r)
expected := "UTC"
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
ctx := r.Context()
ctx = context.WithValue(ctx, UserTimezoneContextKey, "Europe/Paris")
r = r.WithContext(ctx)
result = UserTimezone(r)
expected = "Europe/Paris"
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
}
func TestUserLanguage(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
result := UserLanguage(r)
expected := "en_US"
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
ctx := r.Context()
ctx = context.WithValue(ctx, UserLanguageContextKey, "fr_FR")
r = r.WithContext(ctx)
result = UserLanguage(r)
expected = "fr_FR"
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
}
func TestUserTheme(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
result := UserTheme(r)
expected := "default"
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
ctx := r.Context()
ctx = context.WithValue(ctx, UserThemeContextKey, "black")
r = r.WithContext(ctx)
result = UserTheme(r)
expected = "black"
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
}
func TestCSRF(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
result := CSRF(r)
expected := ""
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
ctx := r.Context()
ctx = context.WithValue(ctx, CSRFContextKey, "secret")
r = r.WithContext(ctx)
result = CSRF(r)
expected = "secret"
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
}
func TestSessionID(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
result := SessionID(r)
expected := ""
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
ctx := r.Context()
ctx = context.WithValue(ctx, SessionIDContextKey, "id")
r = r.WithContext(ctx)
result = SessionID(r)
expected = "id"
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
}
func TestUserSessionToken(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
result := UserSessionToken(r)
expected := ""
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
ctx := r.Context()
ctx = context.WithValue(ctx, UserSessionTokenContextKey, "token")
r = r.WithContext(ctx)
result = UserSessionToken(r)
expected = "token"
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
}
func TestOAuth2State(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
result := OAuth2State(r)
expected := ""
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
ctx := r.Context()
ctx = context.WithValue(ctx, OAuth2StateContextKey, "state")
r = r.WithContext(ctx)
result = OAuth2State(r)
expected = "state"
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
}
func TestFlashMessage(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
result := FlashMessage(r)
expected := ""
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
ctx := r.Context()
ctx = context.WithValue(ctx, FlashMessageContextKey, "message")
r = r.WithContext(ctx)
result = FlashMessage(r)
expected = "message"
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
}
func TestFlashErrorMessage(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
result := FlashErrorMessage(r)
expected := ""
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
ctx := r.Context()
ctx = context.WithValue(ctx, FlashErrorMessageContextKey, "error message")
r = r.WithContext(ctx)
result = FlashErrorMessage(r)
expected = "error message"
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
}
func TestPocketRequestToken(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
result := PocketRequestToken(r)
expected := ""
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
ctx := r.Context()
ctx = context.WithValue(ctx, PocketRequestTokenContextKey, "request token")
r = r.WithContext(ctx)
result = PocketRequestToken(r)
expected = "request token"
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
}
func TestClientIP(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
result := ClientIP(r)
expected := ""
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
ctx := r.Context()
ctx = context.WithValue(ctx, ClientIPContextKey, "127.0.0.1")
r = r.WithContext(ctx)
result = ClientIP(r)
expected = "127.0.0.1"
if result != expected {
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
}
}

17
http/request/cookie.go Normal file
View file

@ -0,0 +1,17 @@
// Copyright 2018 Frédéric Guillot. All rights reserved.
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.
package request // import "miniflux.app/http/request"
import "net/http"
// CookieValue returns the cookie value.
func CookieValue(r *http.Request, name string) string {
cookie, err := r.Cookie(name)
if err == http.ErrNoCookie {
return ""
}
return cookie.Value
}

View file

@ -0,0 +1,33 @@
// Copyright 2018 Frédéric Guillot. All rights reserved.
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.
package request // import "miniflux.app/http/request"
import (
"net/http"
"testing"
)
func TestGetCookieValue(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
r.AddCookie(&http.Cookie{Value: "cookie_value", Name: "my_cookie"})
result := CookieValue(r, "my_cookie")
expected := "cookie_value"
if result != expected {
t.Errorf(`Unexpected cookie value, got %q instead of %q`, result, expected)
}
}
func TestGetCookieValueWhenUnset(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.org", nil)
result := CookieValue(r, "my_cookie")
expected := ""
if result != expected {
t.Errorf(`Unexpected cookie value, got %q instead of %q`, result, expected)
}
}

84
http/request/params.go Normal file
View file

@ -0,0 +1,84 @@
// Copyright 2018 Frédéric Guillot. All rights reserved.
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.
package request // import "miniflux.app/http/request"
import (
"net/http"
"strconv"
"github.com/gorilla/mux"
)
// FormInt64Value returns a form value as integer.
func FormInt64Value(r *http.Request, param string) int64 {
value := r.FormValue(param)
integer, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return 0
}
return integer
}
// RouteInt64Param returns an URL route parameter as int64.
func RouteInt64Param(r *http.Request, param string) int64 {
vars := mux.Vars(r)
value, err := strconv.ParseInt(vars[param], 10, 64)
if err != nil {
return 0
}
if value < 0 {
return 0
}
return value
}
// RouteStringParam returns a URL route parameter as string.
func RouteStringParam(r *http.Request, param string) string {
vars := mux.Vars(r)
return vars[param]
}
// QueryStringParam returns a query string parameter as string.
func QueryStringParam(r *http.Request, param, defaultValue string) string {
value := r.URL.Query().Get(param)
if value == "" {
value = defaultValue
}
return value
}
// QueryIntParam returns a query string parameter as integer.
func QueryIntParam(r *http.Request, param string, defaultValue int) int {
return int(QueryInt64Param(r, param, int64(defaultValue)))
}
// QueryInt64Param returns a query string parameter as int64.
func QueryInt64Param(r *http.Request, param string, defaultValue int64) int64 {
value := r.URL.Query().Get(param)
if value == "" {
return defaultValue
}
val, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return defaultValue
}
if val < 0 {
return defaultValue
}
return val
}
// HasQueryParam checks if the query string contains the given parameter.
func HasQueryParam(r *http.Request, param string) bool {
values := r.URL.Query()
_, ok := values[param]
return ok
}

215
http/request/params_test.go Normal file
View file

@ -0,0 +1,215 @@
// Copyright 2018 Frédéric Guillot. All rights reserved.
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.
package request // import "miniflux.app/http/request"
import (
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/gorilla/mux"
)
func TestFormInt64Value(t *testing.T) {
f := url.Values{}
f.Set("integer value", "42")
f.Set("invalid value", "invalid integer")
r := &http.Request{Form: f}
result := FormInt64Value(r, "integer value")
expected := int64(42)
if result != expected {
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
}
result = FormInt64Value(r, "invalid value")
expected = int64(0)
if result != expected {
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
}
result = FormInt64Value(r, "missing value")
expected = int64(0)
if result != expected {
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
}
}
func TestRouteStringParam(t *testing.T) {
router := mux.NewRouter()
router.HandleFunc("/route/{variable}/index", func(w http.ResponseWriter, r *http.Request) {
result := RouteStringParam(r, "variable")
expected := "value"
if result != expected {
t.Errorf(`Unexpected result, got %q instead of %q`, result, expected)
}
result = RouteStringParam(r, "missing variable")
expected = ""
if result != expected {
t.Errorf(`Unexpected result, got %q instead of %q`, result, expected)
}
})
r, err := http.NewRequest("GET", "/route/value/index", nil)
if err != nil {
t.Fatal(err)
}
w := httptest.NewRecorder()
router.ServeHTTP(w, r)
}
func TestRouteInt64Param(t *testing.T) {
router := mux.NewRouter()
router.HandleFunc("/a/{variable1}/b/{variable2}/c/{variable3}", func(w http.ResponseWriter, r *http.Request) {
result := RouteInt64Param(r, "variable1")
expected := int64(42)
if result != expected {
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
}
result = RouteInt64Param(r, "missing variable")
expected = 0
if result != expected {
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
}
result = RouteInt64Param(r, "variable2")
expected = 0
if result != expected {
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
}
result = RouteInt64Param(r, "variable3")
expected = 0
if result != expected {
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
}
})
r, err := http.NewRequest("GET", "/a/42/b/not-int/c/-10", nil)
if err != nil {
t.Fatal(err)
}
w := httptest.NewRecorder()
router.ServeHTTP(w, r)
}
func TestQueryStringParam(t *testing.T) {
u, _ := url.Parse("http://example.org/?key=value")
r := &http.Request{URL: u}
result := QueryStringParam(r, "key", "fallback")
expected := "value"
if result != expected {
t.Errorf(`Unexpected result, got %q instead of %q`, result, expected)
}
result = QueryStringParam(r, "missing key", "fallback")
expected = "fallback"
if result != expected {
t.Errorf(`Unexpected result, got %q instead of %q`, result, expected)
}
}
func TestQueryIntParam(t *testing.T) {
u, _ := url.Parse("http://example.org/?key=42&invalid=value&negative=-5")
r := &http.Request{URL: u}
result := QueryIntParam(r, "key", 84)
expected := 42
if result != expected {
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
}
result = QueryIntParam(r, "missing key", 84)
expected = 84
if result != expected {
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
}
result = QueryIntParam(r, "negative", 69)
expected = 69
if result != expected {
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
}
result = QueryIntParam(r, "invalid", 99)
expected = 99
if result != expected {
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
}
}
func TestQueryInt64Param(t *testing.T) {
u, _ := url.Parse("http://example.org/?key=42&invalid=value&negative=-5")
r := &http.Request{URL: u}
result := QueryInt64Param(r, "key", int64(84))
expected := int64(42)
if result != expected {
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
}
result = QueryInt64Param(r, "missing key", int64(84))
expected = int64(84)
if result != expected {
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
}
result = QueryInt64Param(r, "invalid", int64(69))
expected = int64(69)
if result != expected {
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
}
result = QueryInt64Param(r, "invalid", int64(99))
expected = int64(99)
if result != expected {
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
}
}
func TestHasQueryParam(t *testing.T) {
u, _ := url.Parse("http://example.org/?key=42")
r := &http.Request{URL: u}
result := HasQueryParam(r, "key")
expected := true
if result != expected {
t.Errorf(`Unexpected result, got %v instead of %v`, result, expected)
}
result = HasQueryParam(r, "missing key")
expected = false
if result != expected {
t.Errorf(`Unexpected result, got %v instead of %v`, result, expected)
}
}

View file

@ -1,128 +0,0 @@
// Copyright 2018 Frédéric Guillot. All rights reserved.
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.
package request // import "miniflux.app/http/request"
import (
"fmt"
"net"
"net/http"
"strconv"
"strings"
"github.com/gorilla/mux"
)
// Cookie returns the cookie value.
func Cookie(r *http.Request, name string) string {
cookie, err := r.Cookie(name)
if err == http.ErrNoCookie {
return ""
}
return cookie.Value
}
// FormInt64Value returns a form value as integer.
func FormInt64Value(r *http.Request, param string) int64 {
value := r.FormValue(param)
integer, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return 0
}
return integer
}
// IntParam returns an URL route parameter as integer.
func IntParam(r *http.Request, param string) (int64, error) {
vars := mux.Vars(r)
value, err := strconv.Atoi(vars[param])
if err != nil {
return 0, fmt.Errorf("request: %s parameter is not an integer", param)
}
if value < 0 {
return 0, nil
}
return int64(value), nil
}
// Param returns an URL route parameter as string.
func Param(r *http.Request, param, defaultValue string) string {
vars := mux.Vars(r)
value := vars[param]
if value == "" {
value = defaultValue
}
return value
}
// QueryParam returns a querystring parameter as string.
func QueryParam(r *http.Request, param, defaultValue string) string {
value := r.URL.Query().Get(param)
if value == "" {
value = defaultValue
}
return value
}
// QueryIntParam returns a querystring parameter as integer.
func QueryIntParam(r *http.Request, param string, defaultValue int) int {
return int(QueryInt64Param(r, param, int64(defaultValue)))
}
// QueryInt64Param returns a querystring parameter as int64.
func QueryInt64Param(r *http.Request, param string, defaultValue int64) int64 {
value := r.URL.Query().Get(param)
if value == "" {
return defaultValue
}
val, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return defaultValue
}
if val < 0 {
return defaultValue
}
return val
}
// HasQueryParam checks if the query string contains the given parameter.
func HasQueryParam(r *http.Request, param string) bool {
values := r.URL.Query()
_, ok := values[param]
return ok
}
// FindClientIP returns client's real IP address.
func FindClientIP(r *http.Request) string {
headers := []string{"X-Forwarded-For", "X-Real-Ip"}
for _, header := range headers {
value := r.Header.Get(header)
if value != "" {
addresses := strings.Split(value, ",")
address := strings.TrimSpace(addresses[0])
if net.ParseIP(address) != nil {
return address
}
}
}
// Fallback to TCP/IP source IP address.
var remoteIP string
if strings.ContainsRune(r.RemoteAddr, ':') {
remoteIP, _, _ = net.SplitHostPort(r.RemoteAddr)
} else {
remoteIP = r.RemoteAddr
}
return remoteIP
}