diff --git a/socketserver/server/handlecore.go b/socketserver/server/handlecore.go index c3b52842..ed0f41f4 100644 --- a/socketserver/server/handlecore.go +++ b/socketserver/server/handlecore.go @@ -20,6 +20,7 @@ import ( "syscall" "time" "unicode/utf8" + "regexp" ) // SuccessCommand is a Reply Command to indicate success in reply to a C2S Command. @@ -49,9 +50,9 @@ const AsyncResponseCommand Command = "_async" const defaultMinMemoryKB = 1024 * 24 -// TwitchDotTv is the http origin for twitch.tv. -const TwitchDotTv = "http://www.twitch.tv" -const TwitchDotTvHTTPS = "https://www.twitch.tv" +// DotTwitchDotTv is the .twitch.tv suffix. +const DotTwitchDotTv = ".twitch.tv" +var OriginRegexp = regexp.MustCompile(DotTwitchDotTv + "$") // ResponseSuccess is a Reply ClientMessage with the MessageID not yet filled out. var ResponseSuccess = ClientMessage{Command: SuccessCommand} @@ -176,7 +177,7 @@ var SocketUpgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, CheckOrigin: func(r *http.Request) bool { - return r.Header.Get("Origin") == TwitchDotTv || r.Header.Get("Origin") == TwitchDotTvHTTPS + return OriginRegexp.MatchString(r.Header.Get("Origin")) }, } diff --git a/socketserver/server/subscriptions_test.go b/socketserver/server/subscriptions_test.go index c24f88e8..132dda74 100644 --- a/socketserver/server/subscriptions_test.go +++ b/socketserver/server/subscriptions_test.go @@ -14,6 +14,8 @@ import ( "time" ) +const TestOrigin = "http://www.twitch.tv" + func TestSubscriptionAndPublish(t *testing.T) { var doneWg sync.WaitGroup var readyWg sync.WaitGroup @@ -54,7 +56,7 @@ func TestSubscriptionAndPublish(t *testing.T) { var err error var headers http.Header = make(http.Header) - headers.Set("Origin", TwitchDotTv) + headers.Set("Origin", TestOrigin) // client 1: sub ch1, ch2 // client 2: sub ch1, ch3 @@ -72,6 +74,9 @@ func TestSubscriptionAndPublish(t *testing.T) { return } + // both origins need testing + headers.Set("Origin", "https://www.twitch.tv") + doneWg.Add(1) readyWg.Add(1) go func(conn *websocket.Conn) { @@ -265,7 +270,7 @@ func TestRestrictedCommands(t *testing.T) { var challengeChan = make(chan string) var headers http.Header = make(http.Header) - headers.Set("Origin", TwitchDotTv) + headers.Set("Origin", TestOrigin) // Client 1 conn, _, err = websocket.DefaultDialer.Dial(urls.Websocket, headers) @@ -366,7 +371,7 @@ func BenchmarkUserSubscriptionSinglePublish(b *testing.B) { defer unsubscribeAllClients() var headers http.Header = make(http.Header) - headers.Set("Origin", TwitchDotTv) + headers.Set("Origin", TestOrigin) b.ResetTimer() for i := 0; i < b.N; i++ {