1
0
Fork 0
mirror of https://github.com/FrankerFaceZ/FrankerFaceZ.git synced 2025-07-03 01:28:30 +00:00

Use a regexp instead to match origin

This commit is contained in:
Kane York 2016-02-03 22:06:46 -08:00
parent cddd13ba16
commit bba5d8f344
2 changed files with 13 additions and 7 deletions

View file

@ -20,6 +20,7 @@ import (
"syscall" "syscall"
"time" "time"
"unicode/utf8" "unicode/utf8"
"regexp"
) )
// SuccessCommand is a Reply Command to indicate success in reply to a C2S Command. // SuccessCommand is a Reply Command to indicate success in reply to a C2S Command.
@ -49,9 +50,9 @@ const AsyncResponseCommand Command = "_async"
const defaultMinMemoryKB = 1024 * 24 const defaultMinMemoryKB = 1024 * 24
// TwitchDotTv is the http origin for twitch.tv. // DotTwitchDotTv is the .twitch.tv suffix.
const TwitchDotTv = "http://www.twitch.tv" const DotTwitchDotTv = ".twitch.tv"
const TwitchDotTvHTTPS = "https://www.twitch.tv" var OriginRegexp = regexp.MustCompile(DotTwitchDotTv + "$")
// ResponseSuccess is a Reply ClientMessage with the MessageID not yet filled out. // ResponseSuccess is a Reply ClientMessage with the MessageID not yet filled out.
var ResponseSuccess = ClientMessage{Command: SuccessCommand} var ResponseSuccess = ClientMessage{Command: SuccessCommand}
@ -176,7 +177,7 @@ var SocketUpgrader = websocket.Upgrader{
ReadBufferSize: 1024, ReadBufferSize: 1024,
WriteBufferSize: 1024, WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool { CheckOrigin: func(r *http.Request) bool {
return r.Header.Get("Origin") == TwitchDotTv || r.Header.Get("Origin") == TwitchDotTvHTTPS return OriginRegexp.MatchString(r.Header.Get("Origin"))
}, },
} }

View file

@ -14,6 +14,8 @@ import (
"time" "time"
) )
const TestOrigin = "http://www.twitch.tv"
func TestSubscriptionAndPublish(t *testing.T) { func TestSubscriptionAndPublish(t *testing.T) {
var doneWg sync.WaitGroup var doneWg sync.WaitGroup
var readyWg sync.WaitGroup var readyWg sync.WaitGroup
@ -54,7 +56,7 @@ func TestSubscriptionAndPublish(t *testing.T) {
var err error var err error
var headers http.Header = make(http.Header) var headers http.Header = make(http.Header)
headers.Set("Origin", TwitchDotTv) headers.Set("Origin", TestOrigin)
// client 1: sub ch1, ch2 // client 1: sub ch1, ch2
// client 2: sub ch1, ch3 // client 2: sub ch1, ch3
@ -72,6 +74,9 @@ func TestSubscriptionAndPublish(t *testing.T) {
return return
} }
// both origins need testing
headers.Set("Origin", "https://www.twitch.tv")
doneWg.Add(1) doneWg.Add(1)
readyWg.Add(1) readyWg.Add(1)
go func(conn *websocket.Conn) { go func(conn *websocket.Conn) {
@ -265,7 +270,7 @@ func TestRestrictedCommands(t *testing.T) {
var challengeChan = make(chan string) var challengeChan = make(chan string)
var headers http.Header = make(http.Header) var headers http.Header = make(http.Header)
headers.Set("Origin", TwitchDotTv) headers.Set("Origin", TestOrigin)
// Client 1 // Client 1
conn, _, err = websocket.DefaultDialer.Dial(urls.Websocket, headers) conn, _, err = websocket.DefaultDialer.Dial(urls.Websocket, headers)
@ -366,7 +371,7 @@ func BenchmarkUserSubscriptionSinglePublish(b *testing.B) {
defer unsubscribeAllClients() defer unsubscribeAllClients()
var headers http.Header = make(http.Header) var headers http.Header = make(http.Header)
headers.Set("Origin", TwitchDotTv) headers.Set("Origin", TestOrigin)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {