diff --git a/socketserver/cmd/ffzsocketserver/console.go b/socketserver/cmd/ffzsocketserver/console.go index 5d3063d9..93c95987 100644 --- a/socketserver/cmd/ffzsocketserver/console.go +++ b/socketserver/cmd/ffzsocketserver/console.go @@ -116,11 +116,9 @@ func commandLineConsole() { if i >= count { break } - select { - case cl.MessageChannel <- msg: - case <-cl.MsgChannelIsDone: + if cl.Send(msg) { + kickCount++ } - kickCount++ } return fmt.Sprintf("Kicked %d clients", kickCount), nil }) diff --git a/socketserver/cmd/statsweb/servers.go b/socketserver/cmd/statsweb/servers.go index d4818c61..d17e590f 100644 --- a/socketserver/cmd/statsweb/servers.go +++ b/socketserver/cmd/statsweb/servers.go @@ -12,7 +12,6 @@ import ( "bitbucket.org/stendec/frankerfacez/socketserver/server" "github.com/clarkduvall/hyperloglog" - "github.com/hashicorp/golang-lru" ) type serverFilter struct { diff --git a/socketserver/server/backend.go b/socketserver/server/backend.go index f41defde..3b7a8c3c 100644 --- a/socketserver/server/backend.go +++ b/socketserver/server/backend.go @@ -12,11 +12,10 @@ import ( "net/url" "strconv" "strings" + "sync" "time" - "sync" - - "github.com/pmylund/go-cache" + cache "github.com/patrickmn/go-cache" "golang.org/x/crypto/nacl/box" ) diff --git a/socketserver/server/commands.go b/socketserver/server/commands.go index 3d0156a1..5278fb59 100644 --- a/socketserver/server/commands.go +++ b/socketserver/server/commands.go @@ -145,7 +145,6 @@ func C2SHello(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg uniqueUserChannel <- client.ClientID SubscribeGlobal(client) - SubscribeDefaults(client) jsTime := float64(time.Now().UnixNano()/1000) / 1000 return ClientMessage{ @@ -197,7 +196,7 @@ func C2SReady(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg client.MsgChannelKeepalive.Add(1) go func() { - client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand} + client.Send(ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand}) SendBacklogForNewClient(client) client.MsgChannelKeepalive.Done() }() @@ -553,10 +552,7 @@ func C2SHandleBunchedCommand(conn *websocket.Conn, client *ClientInfo, msg Clien bsl.Lock() for _, member := range bsl.Members { msg.MessageID = member.MessageID - select { - case member.Client.MessageChannel <- msg: - case <-member.Client.MsgChannelIsDone: - } + member.Client.Send(msg) } bsl.Unlock() }(br) @@ -580,7 +576,7 @@ func doRemoteCommand(conn *websocket.Conn, msg ClientMessage, client *ClientInfo if err == ErrAuthorizationNeeded { if client.TwitchUsername == "" { // Not logged in - client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationNeededError} + client.Send(ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationNeededError}) client.MsgChannelKeepalive.Done() return } @@ -588,19 +584,19 @@ func doRemoteCommand(conn *websocket.Conn, msg ClientMessage, client *ClientInfo if success { doRemoteCommand(conn, msg, client) } else { - client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationFailedErrorString} + client.Send(ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationFailedErrorString}) client.MsgChannelKeepalive.Done() } }) return // without keepalive.Done() } else if bfe, ok := err.(ErrForwardedFromBackend); ok { - client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: bfe.JSONError} + client.Send(ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: bfe.JSONError}) } else if err != nil { - client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: err.Error()} + client.Send(ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: err.Error()}) } else { msg := ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand, origArguments: resp} msg.parseOrigArguments() - client.MessageChannel <- msg + client.Send(msg) } client.MsgChannelKeepalive.Done() } diff --git a/socketserver/server/handlecore.go b/socketserver/server/handlecore.go index 723f4fb2..f87e61e8 100644 --- a/socketserver/server/handlecore.go +++ b/socketserver/server/handlecore.go @@ -1,4 +1,4 @@ -package server // import "bitbucket.org/stendec/frankerfacez/socketserver/server" +package server // import "github.com/FrankerFaceZ/FrankerFaceZ/socketserver/server" import ( "encoding/json" @@ -300,8 +300,8 @@ var CloseNonUTF8Data = websocket.CloseError{ Text: "Non UTF8 data recieved. Network corruption likely.", } -const sendMessageBufferLength = 30 -const sendMessageAbortLength = 20 +const sendMessageBufferLength = 5 +const sendMessageAbortLength = 5 // RunSocketConnection contains the main run loop of a websocket connection. // @@ -350,11 +350,8 @@ func RunSocketConnection(conn *websocket.Conn) { closeConnection(conn, closeReason) // closeConnection(conn, closeReason, &report) - // Launch message draining goroutine - we aren't out of the pub/sub records - go func() { - for _ = range _serverMessageChan { - } - }() + // We can just drop serverMessageChan and let it be picked up by GC, because all sends are nonblocking. + _serverMessageChan = nil // Closes client.MsgChannelIsDone and also stops the reader thread close(stoppedChan) @@ -364,11 +361,8 @@ func RunSocketConnection(conn *websocket.Conn) { // Wait for pending jobs to finish... client.MsgChannelKeepalive.Wait() - client.MessageChannel = nil // And done. - // Close the channel so the draining goroutine can finish, too. - close(_serverMessageChan) if !StopAcceptingConnections { // Don't perform high contention operations when server is closing diff --git a/socketserver/server/irc.go b/socketserver/server/irc.go index a5a837dc..0eb6eef2 100644 --- a/socketserver/server/irc.go +++ b/socketserver/server/irc.go @@ -81,7 +81,7 @@ func (client *ClientInfo) StartAuthorization(callback AuthCallback) { AddPendingAuthorization(client, challenge, callback) - client.MessageChannel <- ClientMessage{MessageID: -1, Command: AuthorizeCommand, Arguments: challenge} + client.Send(ClientMessage{MessageID: -1, Command: AuthorizeCommand, Arguments: challenge}) } const AuthChannelName = "frankerfacezauthorizer" diff --git a/socketserver/server/publisher.go b/socketserver/server/publisher.go index 85912dd6..a1815a63 100644 --- a/socketserver/server/publisher.go +++ b/socketserver/server/publisher.go @@ -7,11 +7,15 @@ import ( "strings" "sync" "time" + + "github.com/FrankerFaceZ/FrankerFaceZ/socketserver/server/rate" + "github.com/pkg/errors" ) +// LastSavedMessage contains a reply to a command along with an expiration time. type LastSavedMessage struct { Expires time.Time - Data string + Data string } // map is command -> channel -> data @@ -23,7 +27,7 @@ var CachedLSMLock sync.RWMutex func cachedMessageJanitor() { for { - time.Sleep(1*time.Hour) + time.Sleep(1 * time.Hour) cachedMessageJanitor_do() } } @@ -72,7 +76,7 @@ func SendBacklogForNewClient(client *ClientInfo) { if ok { msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data} msg.parseOrigArguments() - client.MessageChannel <- msg + client.Send(msg) } } } @@ -88,7 +92,7 @@ func SendBacklogForChannel(client *ClientInfo, channel string) { if msg, ok := chanMap[channel]; ok { msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data} msg.parseOrigArguments() - client.MessageChannel <- msg + client.Send(msg) } } CachedLSMLock.RUnlock() @@ -132,6 +136,21 @@ func HTTPBackendDropBacklog(w http.ResponseWriter, r *http.Request) { } } +func rateLimitFromRequest(r *http.Request) (rate.Limiter, error) { + if r.FormValue("rateCount") != "" { + c, err := strconv.ParseInt(r.FormValue("rateCount"), 10, 32) + if err != nil { + return nil, errors.Wrap(err, "rateCount") + } + d, err := time.ParseDuration(r.FormValue("rateTime")) + if err != nil { + return nil, errors.Wrap(err, "rateTime") + } + return rate.NewRateLimit(int(c), d), nil + } + return rate.Unlimited(), nil +} + // HTTPBackendCachedPublish handles the /cached_pub route. // It publishes a message to clients, and then updates the in-server cache for the message. // @@ -163,6 +182,12 @@ func HTTPBackendCachedPublish(w http.ResponseWriter, r *http.Request) { } expires = time.Unix(timeNum, 0) } + rl, err := rateLimitFromRequest(r) + if err != nil { + w.WriteHeader(422) + fmt.Fprintf(w, "error parsing ratelimit: %v", err) + return + } var count int msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: json} @@ -174,8 +199,25 @@ func HTTPBackendCachedPublish(w http.ResponseWriter, r *http.Request) { saveLastMessage(cmd, channel, expires, json, deleteMode) } CachedLSMLock.Unlock() - count = PublishToMultiple(channels, msg) + var wg sync.WaitGroup + wg.Add(1) + go rl.Run() + go func() { + count = PublishToMultiple(channels, msg, rl) + wg.Done() + rl.Close() + }() + ch := make(chan struct{}) + go func() { + wg.Wait() + close(ch) + }() + select { + case <-time.After(3 * time.Second): + count = -1 + case <-ch: + } w.Write([]byte(strconv.Itoa(count))) } @@ -199,26 +241,50 @@ func HTTPBackendUncachedPublish(w http.ResponseWriter, r *http.Request) { if cmd == "" { w.WriteHeader(422) - fmt.Fprintf(w, "Error: cmd cannot be blank") + fmt.Fprint(w, "Error: cmd cannot be blank") return } if channel == "" && scope != "global" { w.WriteHeader(422) - fmt.Fprintf(w, "Error: channel must be specified") + fmt.Fprint(w, "Error: channel must be specified") + return + } + rl, err := rateLimitFromRequest(r) + if err != nil { + w.WriteHeader(422) + fmt.Fprintf(w, "error parsing ratelimit: %v", err) return } cm := ClientMessage{MessageID: -1, Command: CommandPool.InternCommand(cmd), origArguments: json} cm.parseOrigArguments() - var count int - switch scope { - default: - count = PublishToMultiple(strings.Split(channel, ","), cm) - case "global": - count = PublishToAll(cm) + var count int + var wg sync.WaitGroup + wg.Add(1) + go rl.Run() + go func() { + switch scope { + default: + count = PublishToMultiple(strings.Split(channel, ","), cm, rl) + case "global": + count = PublishToAll(cm, rl) + } + wg.Done() + rl.Close() + }() + ch := make(chan struct{}) + go func() { + wg.Wait() + close(ch) + }() + select { + case <-time.After(3 * time.Second): + count = -1 + case <-ch: } - fmt.Fprint(w, count) + w.Write([]byte(strconv.Itoa(count))) + } // HTTPGetSubscriberCount handles the /get_sub_count route. @@ -236,4 +302,4 @@ func HTTPGetSubscriberCount(w http.ResponseWriter, r *http.Request) { channel := formData.Get("channel") fmt.Fprint(w, CountSubscriptions(strings.Split(channel, ","))) -} \ No newline at end of file +} diff --git a/socketserver/server/publisher_test.go b/socketserver/server/publisher_test.go index ad667d1b..c21696e5 100644 --- a/socketserver/server/publisher_test.go +++ b/socketserver/server/publisher_test.go @@ -16,9 +16,9 @@ func TestExpiredCleanup(t *testing.T) { defer DumpBacklogData() var zeroTime time.Time - hourAgo := time.Now().Add(-1*time.Hour) + hourAgo := time.Now().Add(-1 * time.Hour) now := time.Now() - hourFromNow := time.Now().Add(1*time.Hour) + hourFromNow := time.Now().Add(1 * time.Hour) saveLastMessage(cmd, channel, hourAgo, "1", false) saveLastMessage(cmd, channel2, now, "2", false) @@ -26,11 +26,11 @@ func TestExpiredCleanup(t *testing.T) { if len(CachedLastMessages) != 1 { t.Error("messages not saved") } - if len(CachedLastMessages[cmd]) != 2{ + if len(CachedLastMessages[cmd]) != 2 { t.Error("messages not saved") } - time.Sleep(2*time.Millisecond) + time.Sleep(2 * time.Millisecond) cachedMessageJanitor_do() @@ -47,7 +47,7 @@ func TestExpiredCleanup(t *testing.T) { t.Error("messages not saved") } - time.Sleep(2*time.Millisecond) + time.Sleep(2 * time.Millisecond) cachedMessageJanitor_do() diff --git a/socketserver/server/rate/ratelimit.go b/socketserver/server/rate/ratelimit.go new file mode 100644 index 00000000..5b38799f --- /dev/null +++ b/socketserver/server/rate/ratelimit.go @@ -0,0 +1,77 @@ +package rate + +import ( + "io" + "time" +) + +// A Limiter supports a constant number of Performed() calls every +// time a certain amount of time passes. +// +// Calls to Performed() when no "action tokens" are available will block +// until one is available. +type Limiter interface { + // Run begins emitting tokens for the ratelimiter. + // A call to Run must be followed by a call to Close. + Run() + // Performed consumes one token from the rate limiter. + // If no tokens are available, the call will block until one is. + Performed() + // Close stops the rate limiter. Any future calls to Performed() will block forever. + // Close never returns an error. + io.Closer +} + +type timeRateLimit struct { + count int + period time.Duration + ch chan struct{} + done chan struct{} +} + +// Construct a new Limiter with the given count and duration. +func NewRateLimit(count int, period time.Duration) Limiter { + return &timeRateLimit{ + count: count, + period: period, + ch: make(chan struct{}), + done: make(chan struct{}), + } +} + +func (r *timeRateLimit) Run() { + for { + waiter := time.After(r.period) + for i := 0; i < r.count; i++ { + select { + case r.ch <- struct{}{}: + // ok + case <-r.done: + return + } + } + <-waiter + } +} + +func (r *timeRateLimit) Performed() { + <-r.ch +} + +func (r *timeRateLimit) Close() error { + close(r.done) + return nil +} + +type unlimited struct{} + +var unlimitedInstance unlimited + +// Unlimited returns a Limiter that never blocks. The Run() and Close() calls are no-ops. +func Unlimited() Limiter { + return unlimitedInstance +} + +func (r unlimited) Run() {} +func (r unlimited) Performed() {} +func (r unlimited) Close() error { return nil } diff --git a/socketserver/server/rate/ratelimit_test.go b/socketserver/server/rate/ratelimit_test.go new file mode 100644 index 00000000..c4f59709 --- /dev/null +++ b/socketserver/server/rate/ratelimit_test.go @@ -0,0 +1,40 @@ +package rate + +import ( + "testing" + "time" +) + +var exampleData = []string{} + +func ExampleNewRateLimit() { + rl := NewRateLimit(100, 1*time.Minute) + go rl.Run() + defer rl.Close() + + for _, v := range exampleData { + rl.Performed() + // do something with v + _ = v + } +} + +func TestRateLimit(t *testing.T) { + rl := NewRateLimit(3, 100*time.Millisecond) + start := time.Now() + go rl.Run() + for i := 0; i < 4; i++ { + rl.Performed() + } + end := time.Now() + if end.Sub(start) < 100*time.Millisecond { + t.Error("ratelimiter did not wait for period to expire") + } + rl.Performed() + rl.Performed() + end2 := time.Now() + if end2.Sub(end) > 10*time.Millisecond { + t.Error("ratelimiter improperly waited when tokens were available") + } + rl.Close() +} diff --git a/socketserver/server/subscriptions.go b/socketserver/server/subscriptions.go index 30bc4112..392e08aa 100644 --- a/socketserver/server/subscriptions.go +++ b/socketserver/server/subscriptions.go @@ -1,17 +1,16 @@ package server -// This is the scariest code I've written yet for the server. -// If I screwed up the locking, I won't know until it's too late. - import ( "log" "sync" "time" + + "github.com/FrankerFaceZ/FrankerFaceZ/socketserver/server/rate" ) type SubscriberList struct { sync.RWMutex - Members []chan<- ClientMessage + Members []*ClientInfo } var ChatSubscriptionInfo map[string]*SubscriberList = make(map[string]*SubscriberList) @@ -19,6 +18,18 @@ var ChatSubscriptionLock sync.RWMutex var GlobalSubscriptionInfo []*ClientInfo var GlobalSubscriptionLock sync.RWMutex +func (client *ClientInfo) Send(msg ClientMessage) bool { + select { + case client.MessageChannel <- msg: + return true + case <-client.MsgChannelIsDone: + return false + default: + // if we can't immediately send, ignore it + return false + } +} + func CountSubscriptions(channels []string) int { ChatSubscriptionLock.RLock() defer ChatSubscriptionLock.RUnlock() @@ -38,70 +49,77 @@ func CountSubscriptions(channels []string) int { func SubscribeChannel(client *ClientInfo, channelName string) { ChatSubscriptionLock.RLock() - _subscribeWhileRlocked(channelName, client.MessageChannel) + _subscribeWhileRlocked(channelName, client) ChatSubscriptionLock.RUnlock() } -func SubscribeDefaults(client *ClientInfo) { - -} - func SubscribeGlobal(client *ClientInfo) { GlobalSubscriptionLock.Lock() AddToSliceCl(&GlobalSubscriptionInfo, client) GlobalSubscriptionLock.Unlock() } -func PublishToChannel(channel string, msg ClientMessage) (count int) { +func PublishToChannel(channel string, msg ClientMessage, rl rate.Limiter) (count int) { + var found []*ClientInfo + ChatSubscriptionLock.RLock() list := ChatSubscriptionInfo[channel] if list != nil { list.RLock() - for _, msgChan := range list.Members { - msgChan <- msg - count++ - } + found = make([]*ClientInfo, len(list.Members)) + copy(found, list.Members) list.RUnlock() } ChatSubscriptionLock.RUnlock() + + for _, cl := range found { + rl.Performed() + if cl.Send(msg) { + count++ + } + } return } -func PublishToMultiple(channels []string, msg ClientMessage) (count int) { - found := make(map[chan<- ClientMessage]struct{}) +func PublishToMultiple(channels []string, msg ClientMessage, rl rate.Limiter) (count int) { + var found []*ClientInfo ChatSubscriptionLock.RLock() - for _, channel := range channels { list := ChatSubscriptionInfo[channel] if list != nil { list.RLock() - for _, msgChan := range list.Members { - found[msgChan] = struct{}{} + for _, cl := range list.Members { + found = append(found, cl) } list.RUnlock() } } - ChatSubscriptionLock.RUnlock() - for msgChan, _ := range found { - msgChan <- msg - count++ + for _, cl := range found { + rl.Performed() + if cl.Send(msg) { + count++ + } } return } -func PublishToAll(msg ClientMessage) (count int) { +func PublishToAll(msg ClientMessage, rl rate.Limiter) (count int) { + var found []*ClientInfo + GlobalSubscriptionLock.RLock() - for _, client := range GlobalSubscriptionInfo { - select { - case client.MessageChannel <- msg: - case <-client.MsgChannelIsDone: - } - count++ - } + found = make([]*ClientInfo, len(GlobalSubscriptionInfo)) + copy(found, GlobalSubscriptionInfo) GlobalSubscriptionLock.RUnlock() + + for _, cl := range found { + rl.Performed() + if cl.Send(msg) { + count++ + } + } return } @@ -110,7 +128,7 @@ func UnsubscribeSingleChat(client *ClientInfo, channelName string) { list := ChatSubscriptionInfo[channelName] if list != nil { list.Lock() - RemoveFromSliceC(&list.Members, client.MessageChannel) + RemoveFromSliceCl(&list.Members, client) list.Unlock() } ChatSubscriptionLock.RUnlock() @@ -138,7 +156,7 @@ func UnsubscribeAll(client *ClientInfo) { list := ChatSubscriptionInfo[v] if list != nil { list.Lock() - RemoveFromSliceC(&list.Members, client.MessageChannel) + RemoveFromSliceCl(&list.Members, client) list.Unlock() } } @@ -191,14 +209,14 @@ func pubsubJanitor_do() { // - ALREADY HOLDING a read-lock to the 'which' top-level map via the rlocker object // - possible write lock to the 'which' top-level map via the wlocker object // - write lock to SubscriptionInfo (if not creating new) -func _subscribeWhileRlocked(channelName string, value chan<- ClientMessage) { +func _subscribeWhileRlocked(channelName string, value *ClientInfo) { list := ChatSubscriptionInfo[channelName] if list == nil { // Not found, so create it ChatSubscriptionLock.RUnlock() ChatSubscriptionLock.Lock() list = &SubscriberList{} - list.Members = []chan<- ClientMessage{value} // Create it populated, to avoid reaper + list.Members = []*ClientInfo{value} // Create it populated, to avoid reaper ChatSubscriptionInfo[channelName] = list ChatSubscriptionLock.Unlock() @@ -212,7 +230,7 @@ func _subscribeWhileRlocked(channelName string, value chan<- ClientMessage) { ChatSubscriptionLock.RLock() } else { list.Lock() - AddToSliceC(&list.Members, value) + AddToSliceCl(&list.Members, value) list.Unlock() } } diff --git a/socketserver/server/types.go b/socketserver/server/types.go index 41c0d010..365b91b6 100644 --- a/socketserver/server/types.go +++ b/socketserver/server/types.go @@ -108,10 +108,10 @@ type ClientInfo struct { // True if the client has already sent the 'ready' command ReadyComplete bool - // Server-initiated messages should be sent here - // This field will be nil before it is closed. + // Server-initiated messages should be sent via the Send() method. MessageChannel chan<- ClientMessage + // Closed when the client is shutting down. MsgChannelIsDone <-chan struct{} // Take out an Add() on this during a command if you need to use the MessageChannel later. diff --git a/socketserver/server/utils.go b/socketserver/server/utils.go index 9552e7bb..6657dd02 100644 --- a/socketserver/server/utils.go +++ b/socketserver/server/utils.go @@ -130,38 +130,6 @@ func RemoveFromSliceS(ary *[]string, val string) bool { return true } -func AddToSliceC(ary *[]chan<- ClientMessage, val chan<- ClientMessage) bool { - slice := *ary - for _, v := range slice { - if v == val { - return false - } - } - - slice = append(slice, val) - *ary = slice - return true -} - -func RemoveFromSliceC(ary *[]chan<- ClientMessage, val chan<- ClientMessage) bool { - slice := *ary - var idx int = -1 - for i, v := range slice { - if v == val { - idx = i - break - } - } - if idx == -1 { - return false - } - - slice[idx] = slice[len(slice)-1] - slice = slice[:len(slice)-1] - *ary = slice - return true -} - func AddToSliceCl(ary *[]*ClientInfo, val *ClientInfo) bool { slice := *ary for _, v := range slice {