mirror of
https://github.com/miniflux/v2.git
synced 2025-08-01 17:38:37 +00:00
Move internal packages to an internal folder
For reference: https://go.dev/doc/go1.4#internalpackages
This commit is contained in:
parent
c234903255
commit
168a870c02
433 changed files with 1121 additions and 1123 deletions
55
internal/http/request/client_ip.go
Normal file
55
internal/http/request/client_ip.go
Normal file
|
@ -0,0 +1,55 @@
|
|||
// SPDX-FileCopyrightText: Copyright The Miniflux Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package request // import "miniflux.app/v2/internal/http/request"
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// FindClientIP returns the client real IP address based on trusted Reverse-Proxy HTTP headers.
|
||||
func FindClientIP(r *http.Request) string {
|
||||
headers := []string{"X-Forwarded-For", "X-Real-Ip"}
|
||||
for _, header := range headers {
|
||||
value := r.Header.Get(header)
|
||||
|
||||
if value != "" {
|
||||
addresses := strings.Split(value, ",")
|
||||
address := strings.TrimSpace(addresses[0])
|
||||
address = dropIPv6zone(address)
|
||||
|
||||
if net.ParseIP(address) != nil {
|
||||
return address
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to TCP/IP source IP address.
|
||||
return FindRemoteIP(r)
|
||||
}
|
||||
|
||||
// FindRemoteIP returns remote client IP address.
|
||||
func FindRemoteIP(r *http.Request) string {
|
||||
remoteIP, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
remoteIP = r.RemoteAddr
|
||||
}
|
||||
remoteIP = dropIPv6zone(remoteIP)
|
||||
|
||||
// When listening on a Unix socket, RemoteAddr is empty.
|
||||
if remoteIP == "" {
|
||||
remoteIP = "127.0.0.1"
|
||||
}
|
||||
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
func dropIPv6zone(address string) string {
|
||||
i := strings.IndexByte(address, '%')
|
||||
if i != -1 {
|
||||
address = address[:i]
|
||||
}
|
||||
return address
|
||||
}
|
125
internal/http/request/client_ip_test.go
Normal file
125
internal/http/request/client_ip_test.go
Normal file
|
@ -0,0 +1,125 @@
|
|||
// SPDX-FileCopyrightText: Copyright The Miniflux Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package request // import "miniflux.app/v2/internal/http/request"
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFindClientIPWithoutHeaders(t *testing.T) {
|
||||
r := &http.Request{RemoteAddr: "192.168.0.1:4242"}
|
||||
if ip := FindClientIP(r); ip != "192.168.0.1" {
|
||||
t.Fatalf(`Unexpected result, got: %q`, ip)
|
||||
}
|
||||
|
||||
r = &http.Request{RemoteAddr: "192.168.0.1"}
|
||||
if ip := FindClientIP(r); ip != "192.168.0.1" {
|
||||
t.Fatalf(`Unexpected result, got: %q`, ip)
|
||||
}
|
||||
|
||||
r = &http.Request{RemoteAddr: "fe80::14c2:f039:edc7:edc7"}
|
||||
if ip := FindClientIP(r); ip != "fe80::14c2:f039:edc7:edc7" {
|
||||
t.Fatalf(`Unexpected result, got: %q`, ip)
|
||||
}
|
||||
|
||||
r = &http.Request{RemoteAddr: "fe80::14c2:f039:edc7:edc7%eth0"}
|
||||
if ip := FindClientIP(r); ip != "fe80::14c2:f039:edc7:edc7" {
|
||||
t.Fatalf(`Unexpected result, got: %q`, ip)
|
||||
}
|
||||
|
||||
r = &http.Request{RemoteAddr: "[fe80::14c2:f039:edc7:edc7%eth0]:4242"}
|
||||
if ip := FindClientIP(r); ip != "fe80::14c2:f039:edc7:edc7" {
|
||||
t.Fatalf(`Unexpected result, got: %q`, ip)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindClientIPWithXFFHeader(t *testing.T) {
|
||||
// Test with multiple IPv4 addresses.
|
||||
headers := http.Header{}
|
||||
headers.Set("X-Forwarded-For", "203.0.113.195, 70.41.3.18, 150.172.238.178")
|
||||
r := &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers}
|
||||
|
||||
if ip := FindClientIP(r); ip != "203.0.113.195" {
|
||||
t.Fatalf(`Unexpected result, got: %q`, ip)
|
||||
}
|
||||
|
||||
// Test with single IPv6 address.
|
||||
headers = http.Header{}
|
||||
headers.Set("X-Forwarded-For", "2001:db8:85a3:8d3:1319:8a2e:370:7348")
|
||||
r = &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers}
|
||||
|
||||
if ip := FindClientIP(r); ip != "2001:db8:85a3:8d3:1319:8a2e:370:7348" {
|
||||
t.Fatalf(`Unexpected result, got: %q`, ip)
|
||||
}
|
||||
|
||||
// Test with single IPv6 address with zone
|
||||
headers = http.Header{}
|
||||
headers.Set("X-Forwarded-For", "fe80::14c2:f039:edc7:edc7%eth0")
|
||||
r = &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers}
|
||||
|
||||
if ip := FindClientIP(r); ip != "fe80::14c2:f039:edc7:edc7" {
|
||||
t.Fatalf(`Unexpected result, got: %q`, ip)
|
||||
}
|
||||
|
||||
// Test with single IPv4 address.
|
||||
headers = http.Header{}
|
||||
headers.Set("X-Forwarded-For", "70.41.3.18")
|
||||
r = &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers}
|
||||
|
||||
if ip := FindClientIP(r); ip != "70.41.3.18" {
|
||||
t.Fatalf(`Unexpected result, got: %q`, ip)
|
||||
}
|
||||
|
||||
// Test with invalid IP address.
|
||||
headers = http.Header{}
|
||||
headers.Set("X-Forwarded-For", "fake IP")
|
||||
r = &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers}
|
||||
|
||||
if ip := FindClientIP(r); ip != "192.168.0.1" {
|
||||
t.Fatalf(`Unexpected result, got: %q`, ip)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIPWithXRealIPHeader(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("X-Real-Ip", "192.168.122.1")
|
||||
r := &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers}
|
||||
|
||||
if ip := FindClientIP(r); ip != "192.168.122.1" {
|
||||
t.Fatalf(`Unexpected result, got: %q`, ip)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIPWithBothHeaders(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("X-Forwarded-For", "203.0.113.195, 70.41.3.18, 150.172.238.178")
|
||||
headers.Set("X-Real-Ip", "192.168.122.1")
|
||||
|
||||
r := &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers}
|
||||
|
||||
if ip := FindClientIP(r); ip != "203.0.113.195" {
|
||||
t.Fatalf(`Unexpected result, got: %q`, ip)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIPWithNoRemoteAddress(t *testing.T) {
|
||||
r := &http.Request{}
|
||||
|
||||
if ip := FindClientIP(r); ip != "127.0.0.1" {
|
||||
t.Fatalf(`Unexpected result, got: %q`, ip)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIPWithoutRemoteAddrAndBothHeaders(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("X-Forwarded-For", "203.0.113.195, 70.41.3.18, 150.172.238.178")
|
||||
headers.Set("X-Real-Ip", "192.168.122.1")
|
||||
|
||||
r := &http.Request{RemoteAddr: "", Header: headers}
|
||||
|
||||
if ip := FindClientIP(r); ip != "203.0.113.195" {
|
||||
t.Fatalf(`Unexpected result, got: %q`, ip)
|
||||
}
|
||||
}
|
154
internal/http/request/context.go
Normal file
154
internal/http/request/context.go
Normal file
|
@ -0,0 +1,154 @@
|
|||
// SPDX-FileCopyrightText: Copyright The Miniflux Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package request // import "miniflux.app/v2/internal/http/request"
|
||||
|
||||
import "net/http"
|
||||
|
||||
// ContextKey represents a context key.
|
||||
type ContextKey int
|
||||
|
||||
// List of context keys.
|
||||
const (
|
||||
UserIDContextKey ContextKey = iota
|
||||
UserTimezoneContextKey
|
||||
IsAdminUserContextKey
|
||||
IsAuthenticatedContextKey
|
||||
UserSessionTokenContextKey
|
||||
UserLanguageContextKey
|
||||
UserThemeContextKey
|
||||
SessionIDContextKey
|
||||
CSRFContextKey
|
||||
OAuth2StateContextKey
|
||||
FlashMessageContextKey
|
||||
FlashErrorMessageContextKey
|
||||
PocketRequestTokenContextKey
|
||||
ClientIPContextKey
|
||||
GoogleReaderToken
|
||||
)
|
||||
|
||||
// GoolgeReaderToken returns the google reader token if it exists.
|
||||
func GoolgeReaderToken(r *http.Request) string {
|
||||
return getContextStringValue(r, GoogleReaderToken)
|
||||
}
|
||||
|
||||
// IsAdminUser checks if the logged user is administrator.
|
||||
func IsAdminUser(r *http.Request) bool {
|
||||
return getContextBoolValue(r, IsAdminUserContextKey)
|
||||
}
|
||||
|
||||
// IsAuthenticated returns a boolean if the user is authenticated.
|
||||
func IsAuthenticated(r *http.Request) bool {
|
||||
return getContextBoolValue(r, IsAuthenticatedContextKey)
|
||||
}
|
||||
|
||||
// UserID returns the UserID of the logged user.
|
||||
func UserID(r *http.Request) int64 {
|
||||
return getContextInt64Value(r, UserIDContextKey)
|
||||
}
|
||||
|
||||
// UserTimezone returns the timezone used by the logged user.
|
||||
func UserTimezone(r *http.Request) string {
|
||||
value := getContextStringValue(r, UserTimezoneContextKey)
|
||||
if value == "" {
|
||||
value = "UTC"
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// UserLanguage get the locale used by the current logged user.
|
||||
func UserLanguage(r *http.Request) string {
|
||||
language := getContextStringValue(r, UserLanguageContextKey)
|
||||
if language == "" {
|
||||
language = "en_US"
|
||||
}
|
||||
return language
|
||||
}
|
||||
|
||||
// UserTheme get the theme used by the current logged user.
|
||||
func UserTheme(r *http.Request) string {
|
||||
theme := getContextStringValue(r, UserThemeContextKey)
|
||||
if theme == "" {
|
||||
theme = "system_serif"
|
||||
}
|
||||
return theme
|
||||
}
|
||||
|
||||
// CSRF returns the current CSRF token.
|
||||
func CSRF(r *http.Request) string {
|
||||
return getContextStringValue(r, CSRFContextKey)
|
||||
}
|
||||
|
||||
// SessionID returns the current session ID.
|
||||
func SessionID(r *http.Request) string {
|
||||
return getContextStringValue(r, SessionIDContextKey)
|
||||
}
|
||||
|
||||
// UserSessionToken returns the current user session token.
|
||||
func UserSessionToken(r *http.Request) string {
|
||||
return getContextStringValue(r, UserSessionTokenContextKey)
|
||||
}
|
||||
|
||||
// OAuth2State returns the current OAuth2 state.
|
||||
func OAuth2State(r *http.Request) string {
|
||||
return getContextStringValue(r, OAuth2StateContextKey)
|
||||
}
|
||||
|
||||
// FlashMessage returns the message message if any.
|
||||
func FlashMessage(r *http.Request) string {
|
||||
return getContextStringValue(r, FlashMessageContextKey)
|
||||
}
|
||||
|
||||
// FlashErrorMessage returns the message error message if any.
|
||||
func FlashErrorMessage(r *http.Request) string {
|
||||
return getContextStringValue(r, FlashErrorMessageContextKey)
|
||||
}
|
||||
|
||||
// PocketRequestToken returns the Pocket Request Token if any.
|
||||
func PocketRequestToken(r *http.Request) string {
|
||||
return getContextStringValue(r, PocketRequestTokenContextKey)
|
||||
}
|
||||
|
||||
// ClientIP returns the client IP address stored in the context.
|
||||
func ClientIP(r *http.Request) string {
|
||||
return getContextStringValue(r, ClientIPContextKey)
|
||||
}
|
||||
|
||||
func getContextStringValue(r *http.Request, key ContextKey) string {
|
||||
if v := r.Context().Value(key); v != nil {
|
||||
value, valid := v.(string)
|
||||
if !valid {
|
||||
return ""
|
||||
}
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func getContextBoolValue(r *http.Request, key ContextKey) bool {
|
||||
if v := r.Context().Value(key); v != nil {
|
||||
value, valid := v.(bool)
|
||||
if !valid {
|
||||
return false
|
||||
}
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func getContextInt64Value(r *http.Request, key ContextKey) int64 {
|
||||
if v := r.Context().Value(key); v != nil {
|
||||
value, valid := v.(int64)
|
||||
if !valid {
|
||||
return 0
|
||||
}
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
435
internal/http/request/context_test.go
Normal file
435
internal/http/request/context_test.go
Normal file
|
@ -0,0 +1,435 @@
|
|||
// SPDX-FileCopyrightText: Copyright The Miniflux Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package request // import "miniflux.app/v2/internal/http/request"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestContextStringValue(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, ClientIPContextKey, "IP")
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
result := getContextStringValue(r, ClientIPContextKey)
|
||||
expected := "IP"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextStringValueWithInvalidType(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, ClientIPContextKey, 0)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
result := getContextStringValue(r, ClientIPContextKey)
|
||||
expected := ""
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextStringValueWhenUnset(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
|
||||
result := getContextStringValue(r, ClientIPContextKey)
|
||||
expected := ""
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextBoolValue(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, IsAdminUserContextKey, true)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
result := getContextBoolValue(r, IsAdminUserContextKey)
|
||||
expected := true
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextBoolValueWithInvalidType(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, IsAdminUserContextKey, "invalid")
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
result := getContextBoolValue(r, IsAdminUserContextKey)
|
||||
expected := false
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextBoolValueWhenUnset(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
|
||||
result := getContextBoolValue(r, IsAdminUserContextKey)
|
||||
expected := false
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextInt64Value(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, UserIDContextKey, int64(1234))
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
result := getContextInt64Value(r, UserIDContextKey)
|
||||
expected := int64(1234)
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextInt64ValueWithInvalidType(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, UserIDContextKey, "invalid")
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
result := getContextInt64Value(r, UserIDContextKey)
|
||||
expected := int64(0)
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextInt64ValueWhenUnset(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
|
||||
result := getContextInt64Value(r, UserIDContextKey)
|
||||
expected := int64(0)
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAdmin(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
|
||||
result := IsAdminUser(r)
|
||||
expected := false
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, IsAdminUserContextKey, true)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
result = IsAdminUser(r)
|
||||
expected = true
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAuthenticated(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
|
||||
result := IsAuthenticated(r)
|
||||
expected := false
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, IsAuthenticatedContextKey, true)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
result = IsAuthenticated(r)
|
||||
expected = true
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserID(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
|
||||
result := UserID(r)
|
||||
expected := int64(0)
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, UserIDContextKey, int64(123))
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
result = UserID(r)
|
||||
expected = int64(123)
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserTimezone(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
|
||||
result := UserTimezone(r)
|
||||
expected := "UTC"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, UserTimezoneContextKey, "Europe/Paris")
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
result = UserTimezone(r)
|
||||
expected = "Europe/Paris"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserLanguage(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
|
||||
result := UserLanguage(r)
|
||||
expected := "en_US"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, UserLanguageContextKey, "fr_FR")
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
result = UserLanguage(r)
|
||||
expected = "fr_FR"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserTheme(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
|
||||
result := UserTheme(r)
|
||||
expected := "system_serif"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, UserThemeContextKey, "dark_serif")
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
result = UserTheme(r)
|
||||
expected = "dark_serif"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRF(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
|
||||
result := CSRF(r)
|
||||
expected := ""
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, CSRFContextKey, "secret")
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
result = CSRF(r)
|
||||
expected = "secret"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionID(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
|
||||
result := SessionID(r)
|
||||
expected := ""
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, SessionIDContextKey, "id")
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
result = SessionID(r)
|
||||
expected = "id"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserSessionToken(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
|
||||
result := UserSessionToken(r)
|
||||
expected := ""
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, UserSessionTokenContextKey, "token")
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
result = UserSessionToken(r)
|
||||
expected = "token"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuth2State(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
|
||||
result := OAuth2State(r)
|
||||
expected := ""
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, OAuth2StateContextKey, "state")
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
result = OAuth2State(r)
|
||||
expected = "state"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlashMessage(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
|
||||
result := FlashMessage(r)
|
||||
expected := ""
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, FlashMessageContextKey, "message")
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
result = FlashMessage(r)
|
||||
expected = "message"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlashErrorMessage(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
|
||||
result := FlashErrorMessage(r)
|
||||
expected := ""
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, FlashErrorMessageContextKey, "error message")
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
result = FlashErrorMessage(r)
|
||||
expected = "error message"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPocketRequestToken(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
|
||||
result := PocketRequestToken(r)
|
||||
expected := ""
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, PocketRequestTokenContextKey, "request token")
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
result = PocketRequestToken(r)
|
||||
expected = "request token"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIP(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
|
||||
result := ClientIP(r)
|
||||
expected := ""
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, ClientIPContextKey, "127.0.0.1")
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
result = ClientIP(r)
|
||||
expected = "127.0.0.1"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
}
|
16
internal/http/request/cookie.go
Normal file
16
internal/http/request/cookie.go
Normal file
|
@ -0,0 +1,16 @@
|
|||
// SPDX-FileCopyrightText: Copyright The Miniflux Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package request // import "miniflux.app/v2/internal/http/request"
|
||||
|
||||
import "net/http"
|
||||
|
||||
// CookieValue returns the cookie value.
|
||||
func CookieValue(r *http.Request, name string) string {
|
||||
cookie, err := r.Cookie(name)
|
||||
if err == http.ErrNoCookie {
|
||||
return ""
|
||||
}
|
||||
|
||||
return cookie.Value
|
||||
}
|
32
internal/http/request/cookie_test.go
Normal file
32
internal/http/request/cookie_test.go
Normal file
|
@ -0,0 +1,32 @@
|
|||
// SPDX-FileCopyrightText: Copyright The Miniflux Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package request // import "miniflux.app/v2/internal/http/request"
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetCookieValue(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
r.AddCookie(&http.Cookie{Value: "cookie_value", Name: "my_cookie"})
|
||||
|
||||
result := CookieValue(r, "my_cookie")
|
||||
expected := "cookie_value"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected cookie value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCookieValueWhenUnset(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||
|
||||
result := CookieValue(r, "my_cookie")
|
||||
expected := ""
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected cookie value, got %q instead of %q`, result, expected)
|
||||
}
|
||||
}
|
131
internal/http/request/params.go
Normal file
131
internal/http/request/params.go
Normal file
|
@ -0,0 +1,131 @@
|
|||
// SPDX-FileCopyrightText: Copyright The Miniflux Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package request // import "miniflux.app/v2/internal/http/request"
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
// FormInt64Value returns a form value as integer.
|
||||
func FormInt64Value(r *http.Request, param string) int64 {
|
||||
value := r.FormValue(param)
|
||||
integer, err := strconv.ParseInt(value, 10, 64)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return integer
|
||||
}
|
||||
|
||||
// RouteInt64Param returns an URL route parameter as int64.
|
||||
func RouteInt64Param(r *http.Request, param string) int64 {
|
||||
vars := mux.Vars(r)
|
||||
value, err := strconv.ParseInt(vars[param], 10, 64)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if value < 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
// RouteStringParam returns a URL route parameter as string.
|
||||
func RouteStringParam(r *http.Request, param string) string {
|
||||
vars := mux.Vars(r)
|
||||
return vars[param]
|
||||
}
|
||||
|
||||
// QueryStringParam returns a query string parameter as string.
|
||||
func QueryStringParam(r *http.Request, param, defaultValue string) string {
|
||||
value := r.URL.Query().Get(param)
|
||||
if value == "" {
|
||||
value = defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// QueryStringParamList returns all values associated to the parameter.
|
||||
func QueryStringParamList(r *http.Request, param string) []string {
|
||||
var results []string
|
||||
values := r.URL.Query()
|
||||
|
||||
if _, found := values[param]; found {
|
||||
for _, value := range values[param] {
|
||||
value = strings.TrimSpace(value)
|
||||
if value != "" {
|
||||
results = append(results, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// QueryIntParam returns a query string parameter as integer.
|
||||
func QueryIntParam(r *http.Request, param string, defaultValue int) int {
|
||||
value := r.URL.Query().Get(param)
|
||||
if value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
val, err := strconv.ParseInt(value, 10, 0)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
if val < 0 {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return int(val)
|
||||
}
|
||||
|
||||
// QueryInt64Param returns a query string parameter as int64.
|
||||
func QueryInt64Param(r *http.Request, param string, defaultValue int64) int64 {
|
||||
value := r.URL.Query().Get(param)
|
||||
if value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
val, err := strconv.ParseInt(value, 10, 64)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
if val < 0 {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return val
|
||||
}
|
||||
|
||||
// QueryBoolParam returns a query string parameter as bool.
|
||||
func QueryBoolParam(r *http.Request, param string, defaultValue bool) bool {
|
||||
value := r.URL.Query().Get(param)
|
||||
if value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
val, err := strconv.ParseBool(value)
|
||||
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return val
|
||||
}
|
||||
|
||||
// HasQueryParam checks if the query string contains the given parameter.
|
||||
func HasQueryParam(r *http.Request, param string) bool {
|
||||
values := r.URL.Query()
|
||||
_, ok := values[param]
|
||||
return ok
|
||||
}
|
214
internal/http/request/params_test.go
Normal file
214
internal/http/request/params_test.go
Normal file
|
@ -0,0 +1,214 @@
|
|||
// SPDX-FileCopyrightText: Copyright The Miniflux Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package request // import "miniflux.app/v2/internal/http/request"
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
func TestFormInt64Value(t *testing.T) {
|
||||
f := url.Values{}
|
||||
f.Set("integer value", "42")
|
||||
f.Set("invalid value", "invalid integer")
|
||||
|
||||
r := &http.Request{Form: f}
|
||||
|
||||
result := FormInt64Value(r, "integer value")
|
||||
expected := int64(42)
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||
}
|
||||
|
||||
result = FormInt64Value(r, "invalid value")
|
||||
expected = int64(0)
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||
}
|
||||
|
||||
result = FormInt64Value(r, "missing value")
|
||||
expected = int64(0)
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteStringParam(t *testing.T) {
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/route/{variable}/index", func(w http.ResponseWriter, r *http.Request) {
|
||||
result := RouteStringParam(r, "variable")
|
||||
expected := "value"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected result, got %q instead of %q`, result, expected)
|
||||
}
|
||||
|
||||
result = RouteStringParam(r, "missing variable")
|
||||
expected = ""
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected result, got %q instead of %q`, result, expected)
|
||||
}
|
||||
})
|
||||
|
||||
r, err := http.NewRequest("GET", "/route/value/index", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func TestRouteInt64Param(t *testing.T) {
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/a/{variable1}/b/{variable2}/c/{variable3}", func(w http.ResponseWriter, r *http.Request) {
|
||||
result := RouteInt64Param(r, "variable1")
|
||||
expected := int64(42)
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||
}
|
||||
|
||||
result = RouteInt64Param(r, "missing variable")
|
||||
expected = 0
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||
}
|
||||
|
||||
result = RouteInt64Param(r, "variable2")
|
||||
expected = 0
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||
}
|
||||
|
||||
result = RouteInt64Param(r, "variable3")
|
||||
expected = 0
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||
}
|
||||
})
|
||||
|
||||
r, err := http.NewRequest("GET", "/a/42/b/not-int/c/-10", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func TestQueryStringParam(t *testing.T) {
|
||||
u, _ := url.Parse("http://example.org/?key=value")
|
||||
r := &http.Request{URL: u}
|
||||
|
||||
result := QueryStringParam(r, "key", "fallback")
|
||||
expected := "value"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected result, got %q instead of %q`, result, expected)
|
||||
}
|
||||
|
||||
result = QueryStringParam(r, "missing key", "fallback")
|
||||
expected = "fallback"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected result, got %q instead of %q`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryIntParam(t *testing.T) {
|
||||
u, _ := url.Parse("http://example.org/?key=42&invalid=value&negative=-5")
|
||||
r := &http.Request{URL: u}
|
||||
|
||||
result := QueryIntParam(r, "key", 84)
|
||||
expected := 42
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||
}
|
||||
|
||||
result = QueryIntParam(r, "missing key", 84)
|
||||
expected = 84
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||
}
|
||||
|
||||
result = QueryIntParam(r, "negative", 69)
|
||||
expected = 69
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||
}
|
||||
|
||||
result = QueryIntParam(r, "invalid", 99)
|
||||
expected = 99
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryInt64Param(t *testing.T) {
|
||||
u, _ := url.Parse("http://example.org/?key=42&invalid=value&negative=-5")
|
||||
r := &http.Request{URL: u}
|
||||
|
||||
result := QueryInt64Param(r, "key", int64(84))
|
||||
expected := int64(42)
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||
}
|
||||
|
||||
result = QueryInt64Param(r, "missing key", int64(84))
|
||||
expected = int64(84)
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||
}
|
||||
|
||||
result = QueryInt64Param(r, "invalid", int64(69))
|
||||
expected = int64(69)
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||
}
|
||||
|
||||
result = QueryInt64Param(r, "invalid", int64(99))
|
||||
expected = int64(99)
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasQueryParam(t *testing.T) {
|
||||
u, _ := url.Parse("http://example.org/?key=42")
|
||||
r := &http.Request{URL: u}
|
||||
|
||||
result := HasQueryParam(r, "key")
|
||||
expected := true
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected result, got %v instead of %v`, result, expected)
|
||||
}
|
||||
|
||||
result = HasQueryParam(r, "missing key")
|
||||
expected = false
|
||||
|
||||
if result != expected {
|
||||
t.Errorf(`Unexpected result, got %v instead of %v`, result, expected)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue