diff --git a/socketserver/cmd/ffzsocketserver/console.go b/socketserver/cmd/ffzsocketserver/console.go index 5d3063d9..2b437d14 100644 --- a/socketserver/cmd/ffzsocketserver/console.go +++ b/socketserver/cmd/ffzsocketserver/console.go @@ -116,9 +116,8 @@ func commandLineConsole() { if i >= count { break } - select { - case cl.MessageChannel <- msg: - case <-cl.MsgChannelIsDone: + if cl.Send(msg) { + kickCount++ } kickCount++ } diff --git a/socketserver/server/commands.go b/socketserver/server/commands.go index 3d0156a1..5278fb59 100644 --- a/socketserver/server/commands.go +++ b/socketserver/server/commands.go @@ -145,7 +145,6 @@ func C2SHello(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg uniqueUserChannel <- client.ClientID SubscribeGlobal(client) - SubscribeDefaults(client) jsTime := float64(time.Now().UnixNano()/1000) / 1000 return ClientMessage{ @@ -197,7 +196,7 @@ func C2SReady(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg client.MsgChannelKeepalive.Add(1) go func() { - client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand} + client.Send(ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand}) SendBacklogForNewClient(client) client.MsgChannelKeepalive.Done() }() @@ -553,10 +552,7 @@ func C2SHandleBunchedCommand(conn *websocket.Conn, client *ClientInfo, msg Clien bsl.Lock() for _, member := range bsl.Members { msg.MessageID = member.MessageID - select { - case member.Client.MessageChannel <- msg: - case <-member.Client.MsgChannelIsDone: - } + member.Client.Send(msg) } bsl.Unlock() }(br) @@ -580,7 +576,7 @@ func doRemoteCommand(conn *websocket.Conn, msg ClientMessage, client *ClientInfo if err == ErrAuthorizationNeeded { if client.TwitchUsername == "" { // Not logged in - client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationNeededError} + client.Send(ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationNeededError}) client.MsgChannelKeepalive.Done() return } @@ -588,19 +584,19 @@ func doRemoteCommand(conn *websocket.Conn, msg ClientMessage, client *ClientInfo if success { doRemoteCommand(conn, msg, client) } else { - client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationFailedErrorString} + client.Send(ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationFailedErrorString}) client.MsgChannelKeepalive.Done() } }) return // without keepalive.Done() } else if bfe, ok := err.(ErrForwardedFromBackend); ok { - client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: bfe.JSONError} + client.Send(ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: bfe.JSONError}) } else if err != nil { - client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: err.Error()} + client.Send(ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: err.Error()}) } else { msg := ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand, origArguments: resp} msg.parseOrigArguments() - client.MessageChannel <- msg + client.Send(msg) } client.MsgChannelKeepalive.Done() } diff --git a/socketserver/server/handlecore.go b/socketserver/server/handlecore.go index 723f4fb2..0b24a0a6 100644 --- a/socketserver/server/handlecore.go +++ b/socketserver/server/handlecore.go @@ -300,8 +300,8 @@ var CloseNonUTF8Data = websocket.CloseError{ Text: "Non UTF8 data recieved. Network corruption likely.", } -const sendMessageBufferLength = 30 -const sendMessageAbortLength = 20 +const sendMessageBufferLength = 5 +const sendMessageAbortLength = 5 // RunSocketConnection contains the main run loop of a websocket connection. // @@ -350,11 +350,8 @@ func RunSocketConnection(conn *websocket.Conn) { closeConnection(conn, closeReason) // closeConnection(conn, closeReason, &report) - // Launch message draining goroutine - we aren't out of the pub/sub records - go func() { - for _ = range _serverMessageChan { - } - }() + // We can just drop serverMessageChan and let it be picked up by GC, because all sends are nonblocking. + _serverMessageChan = nil // Closes client.MsgChannelIsDone and also stops the reader thread close(stoppedChan) @@ -364,11 +361,8 @@ func RunSocketConnection(conn *websocket.Conn) { // Wait for pending jobs to finish... client.MsgChannelKeepalive.Wait() - client.MessageChannel = nil // And done. - // Close the channel so the draining goroutine can finish, too. - close(_serverMessageChan) if !StopAcceptingConnections { // Don't perform high contention operations when server is closing diff --git a/socketserver/server/irc.go b/socketserver/server/irc.go index a5a837dc..0eb6eef2 100644 --- a/socketserver/server/irc.go +++ b/socketserver/server/irc.go @@ -81,7 +81,7 @@ func (client *ClientInfo) StartAuthorization(callback AuthCallback) { AddPendingAuthorization(client, challenge, callback) - client.MessageChannel <- ClientMessage{MessageID: -1, Command: AuthorizeCommand, Arguments: challenge} + client.Send(ClientMessage{MessageID: -1, Command: AuthorizeCommand, Arguments: challenge}) } const AuthChannelName = "frankerfacezauthorizer" diff --git a/socketserver/server/publisher.go b/socketserver/server/publisher.go index 85912dd6..98ed2d19 100644 --- a/socketserver/server/publisher.go +++ b/socketserver/server/publisher.go @@ -7,8 +7,10 @@ import ( "strings" "sync" "time" + "github.com/pkg/errors" ) +// LastSavedMessage contains a reply to a command along with an expiration time. type LastSavedMessage struct { Expires time.Time Data string @@ -72,7 +74,7 @@ func SendBacklogForNewClient(client *ClientInfo) { if ok { msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data} msg.parseOrigArguments() - client.MessageChannel <- msg + client.Send(msg) } } } @@ -88,7 +90,7 @@ func SendBacklogForChannel(client *ClientInfo, channel string) { if msg, ok := chanMap[channel]; ok { msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data} msg.parseOrigArguments() - client.MessageChannel <- msg + client.Send(msg) } } CachedLSMLock.RUnlock() @@ -132,6 +134,21 @@ func HTTPBackendDropBacklog(w http.ResponseWriter, r *http.Request) { } } +func rateLimitFromRequest(r *http.Request) (RateLimit, error) { + if r.FormValue("rateCount") != "" { + c, err := strconv.ParseInt(r.FormValue("rateCount"), 10, 32) + if err != nil { + return nil, errors.Wrap(err, "rateCount") + } + d, err := time.ParseDuration(r.FormValue("rateTime")) + if err != nil { + return nil, errors.Wrap(err, "rateTime") + } + return NewRateLimit(int(c), d), nil + } + return Unlimited(), nil +} + // HTTPBackendCachedPublish handles the /cached_pub route. // It publishes a message to clients, and then updates the in-server cache for the message. // @@ -163,6 +180,12 @@ func HTTPBackendCachedPublish(w http.ResponseWriter, r *http.Request) { } expires = time.Unix(timeNum, 0) } + rl, err := rateLimitFromRequest(r) + if err != nil { + w.WriteHeader(422) + fmt.Fprintf(w, "error parsing ratelimit: %v", err) + return + } var count int msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: json} @@ -174,8 +197,25 @@ func HTTPBackendCachedPublish(w http.ResponseWriter, r *http.Request) { saveLastMessage(cmd, channel, expires, json, deleteMode) } CachedLSMLock.Unlock() - count = PublishToMultiple(channels, msg) + var wg sync.WaitGroup + wg.Add(1) + go rl.Run() + go func() { + count = PublishToMultiple(channels, msg, rl) + wg.Done() + rl.Close() + }() + ch := make(chan struct{}) + go func() { + wg.Wait() + close(ch) + }() + select { + case time.After(3*time.Second): + count = -1 + case <-ch: + } w.Write([]byte(strconv.Itoa(count))) } @@ -199,26 +239,50 @@ func HTTPBackendUncachedPublish(w http.ResponseWriter, r *http.Request) { if cmd == "" { w.WriteHeader(422) - fmt.Fprintf(w, "Error: cmd cannot be blank") + fmt.Fprint(w, "Error: cmd cannot be blank") return } if channel == "" && scope != "global" { w.WriteHeader(422) - fmt.Fprintf(w, "Error: channel must be specified") + fmt.Fprint(w, "Error: channel must be specified") + return + } + rl, err := rateLimitFromRequest(r) + if err != nil { + w.WriteHeader(422) + fmt.Fprintf(w, "error parsing ratelimit: %v", err) return } cm := ClientMessage{MessageID: -1, Command: CommandPool.InternCommand(cmd), origArguments: json} cm.parseOrigArguments() - var count int - switch scope { - default: - count = PublishToMultiple(strings.Split(channel, ","), cm) - case "global": - count = PublishToAll(cm) + var count int + var wg sync.WaitGroup + wg.Add(1) + go rl.Run() + go func() { + switch scope { + default: + count = PublishToMultiple(strings.Split(channel, ","), cm, rl) + case "global": + count = PublishToAll(cm, rl) + } + wg.Done() + rl.Close() + }() + ch := make(chan struct{}) + go func() { + wg.Wait() + close(ch) + }() + select { + case time.After(3*time.Second): + count = -1 + case <-ch: } - fmt.Fprint(w, count) + w.Write([]byte(strconv.Itoa(count))) + } // HTTPGetSubscriberCount handles the /get_sub_count route. diff --git a/socketserver/server/ratelimit.go b/socketserver/server/ratelimit.go new file mode 100644 index 00000000..ad7b4e33 --- /dev/null +++ b/socketserver/server/ratelimit.go @@ -0,0 +1,76 @@ +package server + +import ( + "time" + "io" +) + +// A RateLimit supports a constant number of Performed() calls every +// time a given unit of time passes. +// +// Calls to Performed() when no "action tokens" are available will block +// until one is available. +type RateLimit interface { + // Run begins emitting tokens for the ratelimiter. + // A call to Run must be followed by a call to Close. + Run() + // Performed consumes one token from the rate limiter. + // If no tokens are available, the call will block until one is. + Performed() + // Close stops the rate limiter. Any future calls to Performed() will block forever. + // Close never returns an error. + io.Closer +} + +type timeRateLimit struct{ + count int + period time.Duration + ch chan struct{} + done chan struct{} +} + +// Construct a new RateLimit with the given count and duration. +func NewRateLimit(count int, period time.Duration) (RateLimit) { + return &timeRateLimit{ + count: count, + period: period, + ch: make(chan struct{}), + done: make(chan struct{}), + } +} + +func (r *timeRateLimit) Run() { + for { + waiter := time.After(r.period) + for i := 0; i < r.count; i++ { + select { + case r.ch <- struct{}{}: + // ok + case <-r.done: + return + } + } + <-waiter + } +} + +func (r *timeRateLimit) Performed() { + <-r.ch +} + +func (r *timeRateLimit) Close() error { + close(r.done) + return nil +} + +type unlimited struct{} +var unlimitedInstance unlimited + +// Unlimited returns a RateLimit that never blocks. The Run() and Close() calls are no-ops. +func Unlimited() (RateLimit) { + return unlimitedInstance +} + +func (r unlimited) Run() { } +func (r unlimited) Performed() { } +func (r unlimited) Close() error { return nil } diff --git a/socketserver/server/ratelimit_test.go b/socketserver/server/ratelimit_test.go new file mode 100644 index 00000000..f8ba87f9 --- /dev/null +++ b/socketserver/server/ratelimit_test.go @@ -0,0 +1,40 @@ +package server + +import ( + "time" + "testing" +) + +var exampleData = []string{} + +func ExampleNewRateLimit() { + rl := NewRateLimit(100, 1*time.Minute) + go rl.Run() + defer rl.Close() + + for _, v := range exampleData { + rl.Performed() + // do something with v + _ = v + } +} + +func TestRateLimit(t *testing.T) { + rl := NewRateLimit(3, 100*time.Millisecond) + start := time.Now() + go rl.Run() + for i := 0; i < 4; i++ { + rl.Performed() + } + end := time.Now() + if end.Sub(start) < 100*time.Millisecond { + t.Error("ratelimiter did not wait for period to expire") + } + rl.Performed() + rl.Performed() + end2 := time.Now() + if end2.Sub(end) > 10*time.Millisecond { + t.Error("ratelimiter improperly waited when tokens were available") + } + rl.Close() +} \ No newline at end of file diff --git a/socketserver/server/subscriptions.go b/socketserver/server/subscriptions.go index 30bc4112..90eb7aa4 100644 --- a/socketserver/server/subscriptions.go +++ b/socketserver/server/subscriptions.go @@ -11,7 +11,7 @@ import ( type SubscriberList struct { sync.RWMutex - Members []chan<- ClientMessage + Members []*ClientInfo } var ChatSubscriptionInfo map[string]*SubscriberList = make(map[string]*SubscriberList) @@ -19,6 +19,18 @@ var ChatSubscriptionLock sync.RWMutex var GlobalSubscriptionInfo []*ClientInfo var GlobalSubscriptionLock sync.RWMutex +func (client *ClientInfo) Send(msg ClientMessage) bool { + select { + case client.MessageChannel <- msg: + return true + case <-client.MsgChannelIsDone: + return false + default: + // if we can't immediately send, ignore it + return false + } +} + func CountSubscriptions(channels []string) int { ChatSubscriptionLock.RLock() defer ChatSubscriptionLock.RUnlock() @@ -38,70 +50,77 @@ func CountSubscriptions(channels []string) int { func SubscribeChannel(client *ClientInfo, channelName string) { ChatSubscriptionLock.RLock() - _subscribeWhileRlocked(channelName, client.MessageChannel) + _subscribeWhileRlocked(channelName, client) ChatSubscriptionLock.RUnlock() } -func SubscribeDefaults(client *ClientInfo) { - -} - func SubscribeGlobal(client *ClientInfo) { GlobalSubscriptionLock.Lock() AddToSliceCl(&GlobalSubscriptionInfo, client) GlobalSubscriptionLock.Unlock() } -func PublishToChannel(channel string, msg ClientMessage) (count int) { +func PublishToChannel(channel string, msg ClientMessage, rl RateLimit) (count int) { + var found []*ClientInfo + ChatSubscriptionLock.RLock() list := ChatSubscriptionInfo[channel] if list != nil { list.RLock() - for _, msgChan := range list.Members { - msgChan <- msg - count++ - } + found = make([]*ClientInfo, len(list.Members)) + copy(found, list.Members) list.RUnlock() } ChatSubscriptionLock.RUnlock() + + for _, cl := range found { + rl.Performed() + if cl.Send(msg) { + count++ + } + } return } -func PublishToMultiple(channels []string, msg ClientMessage) (count int) { - found := make(map[chan<- ClientMessage]struct{}) +func PublishToMultiple(channels []string, msg ClientMessage, rl RateLimit) (count int) { + var found []*ClientInfo ChatSubscriptionLock.RLock() - for _, channel := range channels { list := ChatSubscriptionInfo[channel] if list != nil { list.RLock() - for _, msgChan := range list.Members { - found[msgChan] = struct{}{} + for _, cl := range list.Members { + found = append(found, cl) } list.RUnlock() } } - ChatSubscriptionLock.RUnlock() - for msgChan, _ := range found { - msgChan <- msg - count++ + for _, cl := range found { + rl.Performed() + if cl.Send(msg) { + count++ + } } - return + return count } -func PublishToAll(msg ClientMessage) (count int) { +func PublishToAll(msg ClientMessage, rl RateLimit) (count int) { + var found []*ClientInfo + GlobalSubscriptionLock.RLock() - for _, client := range GlobalSubscriptionInfo { - select { - case client.MessageChannel <- msg: - case <-client.MsgChannelIsDone: - } - count++ - } + found = make([]*ClientInfo, len(GlobalSubscriptionInfo)) + copy(found, GlobalSubscriptionInfo) GlobalSubscriptionLock.RUnlock() + + for _, cl := range found { + rl.Performed() + if cl.Send(msg) { + count++ + } + } return } @@ -110,7 +129,7 @@ func UnsubscribeSingleChat(client *ClientInfo, channelName string) { list := ChatSubscriptionInfo[channelName] if list != nil { list.Lock() - RemoveFromSliceC(&list.Members, client.MessageChannel) + RemoveFromSliceCl(&list.Members, client) list.Unlock() } ChatSubscriptionLock.RUnlock() @@ -138,7 +157,7 @@ func UnsubscribeAll(client *ClientInfo) { list := ChatSubscriptionInfo[v] if list != nil { list.Lock() - RemoveFromSliceC(&list.Members, client.MessageChannel) + RemoveFromSliceCl(&list.Members, client) list.Unlock() } } @@ -191,14 +210,14 @@ func pubsubJanitor_do() { // - ALREADY HOLDING a read-lock to the 'which' top-level map via the rlocker object // - possible write lock to the 'which' top-level map via the wlocker object // - write lock to SubscriptionInfo (if not creating new) -func _subscribeWhileRlocked(channelName string, value chan<- ClientMessage) { +func _subscribeWhileRlocked(channelName string, value *ClientInfo) { list := ChatSubscriptionInfo[channelName] if list == nil { // Not found, so create it ChatSubscriptionLock.RUnlock() ChatSubscriptionLock.Lock() list = &SubscriberList{} - list.Members = []chan<- ClientMessage{value} // Create it populated, to avoid reaper + list.Members = []*ClientInfo{value} // Create it populated, to avoid reaper ChatSubscriptionInfo[channelName] = list ChatSubscriptionLock.Unlock() @@ -212,7 +231,7 @@ func _subscribeWhileRlocked(channelName string, value chan<- ClientMessage) { ChatSubscriptionLock.RLock() } else { list.Lock() - AddToSliceC(&list.Members, value) + AddToSliceCl(&list.Members, value) list.Unlock() } } diff --git a/socketserver/server/types.go b/socketserver/server/types.go index 41c0d010..3fa99084 100644 --- a/socketserver/server/types.go +++ b/socketserver/server/types.go @@ -112,6 +112,7 @@ type ClientInfo struct { // This field will be nil before it is closed. MessageChannel chan<- ClientMessage + // Closed when the client is shutting down. MsgChannelIsDone <-chan struct{} // Take out an Add() on this during a command if you need to use the MessageChannel later. diff --git a/socketserver/server/utils.go b/socketserver/server/utils.go index 9552e7bb..6657dd02 100644 --- a/socketserver/server/utils.go +++ b/socketserver/server/utils.go @@ -130,38 +130,6 @@ func RemoveFromSliceS(ary *[]string, val string) bool { return true } -func AddToSliceC(ary *[]chan<- ClientMessage, val chan<- ClientMessage) bool { - slice := *ary - for _, v := range slice { - if v == val { - return false - } - } - - slice = append(slice, val) - *ary = slice - return true -} - -func RemoveFromSliceC(ary *[]chan<- ClientMessage, val chan<- ClientMessage) bool { - slice := *ary - var idx int = -1 - for i, v := range slice { - if v == val { - idx = i - break - } - } - if idx == -1 { - return false - } - - slice[idx] = slice[len(slice)-1] - slice = slice[:len(slice)-1] - *ary = slice - return true -} - func AddToSliceCl(ary *[]*ClientInfo, val *ClientInfo) bool { slice := *ary for _, v := range slice {