1
0
Fork 0
mirror of https://github.com/FrankerFaceZ/FrankerFaceZ.git synced 2025-08-02 16:08:31 +00:00

Add ratelimits to publishing

This commit is contained in:
Kane York 2017-02-02 22:59:17 -08:00
parent b84bd1d4a2
commit 50e295c834
10 changed files with 260 additions and 103 deletions

View file

@ -116,9 +116,8 @@ func commandLineConsole() {
if i >= count { if i >= count {
break break
} }
select { if cl.Send(msg) {
case cl.MessageChannel <- msg: kickCount++
case <-cl.MsgChannelIsDone:
} }
kickCount++ kickCount++
} }

View file

@ -145,7 +145,6 @@ func C2SHello(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg
uniqueUserChannel <- client.ClientID uniqueUserChannel <- client.ClientID
SubscribeGlobal(client) SubscribeGlobal(client)
SubscribeDefaults(client)
jsTime := float64(time.Now().UnixNano()/1000) / 1000 jsTime := float64(time.Now().UnixNano()/1000) / 1000
return ClientMessage{ return ClientMessage{
@ -197,7 +196,7 @@ func C2SReady(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg
client.MsgChannelKeepalive.Add(1) client.MsgChannelKeepalive.Add(1)
go func() { go func() {
client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand} client.Send(ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand})
SendBacklogForNewClient(client) SendBacklogForNewClient(client)
client.MsgChannelKeepalive.Done() client.MsgChannelKeepalive.Done()
}() }()
@ -553,10 +552,7 @@ func C2SHandleBunchedCommand(conn *websocket.Conn, client *ClientInfo, msg Clien
bsl.Lock() bsl.Lock()
for _, member := range bsl.Members { for _, member := range bsl.Members {
msg.MessageID = member.MessageID msg.MessageID = member.MessageID
select { member.Client.Send(msg)
case member.Client.MessageChannel <- msg:
case <-member.Client.MsgChannelIsDone:
}
} }
bsl.Unlock() bsl.Unlock()
}(br) }(br)
@ -580,7 +576,7 @@ func doRemoteCommand(conn *websocket.Conn, msg ClientMessage, client *ClientInfo
if err == ErrAuthorizationNeeded { if err == ErrAuthorizationNeeded {
if client.TwitchUsername == "" { if client.TwitchUsername == "" {
// Not logged in // Not logged in
client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationNeededError} client.Send(ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationNeededError})
client.MsgChannelKeepalive.Done() client.MsgChannelKeepalive.Done()
return return
} }
@ -588,19 +584,19 @@ func doRemoteCommand(conn *websocket.Conn, msg ClientMessage, client *ClientInfo
if success { if success {
doRemoteCommand(conn, msg, client) doRemoteCommand(conn, msg, client)
} else { } else {
client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationFailedErrorString} client.Send(ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationFailedErrorString})
client.MsgChannelKeepalive.Done() client.MsgChannelKeepalive.Done()
} }
}) })
return // without keepalive.Done() return // without keepalive.Done()
} else if bfe, ok := err.(ErrForwardedFromBackend); ok { } else if bfe, ok := err.(ErrForwardedFromBackend); ok {
client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: bfe.JSONError} client.Send(ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: bfe.JSONError})
} else if err != nil { } else if err != nil {
client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: err.Error()} client.Send(ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: err.Error()})
} else { } else {
msg := ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand, origArguments: resp} msg := ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand, origArguments: resp}
msg.parseOrigArguments() msg.parseOrigArguments()
client.MessageChannel <- msg client.Send(msg)
} }
client.MsgChannelKeepalive.Done() client.MsgChannelKeepalive.Done()
} }

View file

@ -300,8 +300,8 @@ var CloseNonUTF8Data = websocket.CloseError{
Text: "Non UTF8 data recieved. Network corruption likely.", Text: "Non UTF8 data recieved. Network corruption likely.",
} }
const sendMessageBufferLength = 30 const sendMessageBufferLength = 5
const sendMessageAbortLength = 20 const sendMessageAbortLength = 5
// RunSocketConnection contains the main run loop of a websocket connection. // RunSocketConnection contains the main run loop of a websocket connection.
// //
@ -350,11 +350,8 @@ func RunSocketConnection(conn *websocket.Conn) {
closeConnection(conn, closeReason) closeConnection(conn, closeReason)
// closeConnection(conn, closeReason, &report) // closeConnection(conn, closeReason, &report)
// Launch message draining goroutine - we aren't out of the pub/sub records // We can just drop serverMessageChan and let it be picked up by GC, because all sends are nonblocking.
go func() { _serverMessageChan = nil
for _ = range _serverMessageChan {
}
}()
// Closes client.MsgChannelIsDone and also stops the reader thread // Closes client.MsgChannelIsDone and also stops the reader thread
close(stoppedChan) close(stoppedChan)
@ -364,11 +361,8 @@ func RunSocketConnection(conn *websocket.Conn) {
// Wait for pending jobs to finish... // Wait for pending jobs to finish...
client.MsgChannelKeepalive.Wait() client.MsgChannelKeepalive.Wait()
client.MessageChannel = nil
// And done. // And done.
// Close the channel so the draining goroutine can finish, too.
close(_serverMessageChan)
if !StopAcceptingConnections { if !StopAcceptingConnections {
// Don't perform high contention operations when server is closing // Don't perform high contention operations when server is closing

View file

@ -81,7 +81,7 @@ func (client *ClientInfo) StartAuthorization(callback AuthCallback) {
AddPendingAuthorization(client, challenge, callback) AddPendingAuthorization(client, challenge, callback)
client.MessageChannel <- ClientMessage{MessageID: -1, Command: AuthorizeCommand, Arguments: challenge} client.Send(ClientMessage{MessageID: -1, Command: AuthorizeCommand, Arguments: challenge})
} }
const AuthChannelName = "frankerfacezauthorizer" const AuthChannelName = "frankerfacezauthorizer"

View file

@ -7,8 +7,10 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/pkg/errors"
) )
// LastSavedMessage contains a reply to a command along with an expiration time.
type LastSavedMessage struct { type LastSavedMessage struct {
Expires time.Time Expires time.Time
Data string Data string
@ -72,7 +74,7 @@ func SendBacklogForNewClient(client *ClientInfo) {
if ok { if ok {
msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data} msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data}
msg.parseOrigArguments() msg.parseOrigArguments()
client.MessageChannel <- msg client.Send(msg)
} }
} }
} }
@ -88,7 +90,7 @@ func SendBacklogForChannel(client *ClientInfo, channel string) {
if msg, ok := chanMap[channel]; ok { if msg, ok := chanMap[channel]; ok {
msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data} msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data}
msg.parseOrigArguments() msg.parseOrigArguments()
client.MessageChannel <- msg client.Send(msg)
} }
} }
CachedLSMLock.RUnlock() CachedLSMLock.RUnlock()
@ -132,6 +134,21 @@ func HTTPBackendDropBacklog(w http.ResponseWriter, r *http.Request) {
} }
} }
func rateLimitFromRequest(r *http.Request) (RateLimit, error) {
if r.FormValue("rateCount") != "" {
c, err := strconv.ParseInt(r.FormValue("rateCount"), 10, 32)
if err != nil {
return nil, errors.Wrap(err, "rateCount")
}
d, err := time.ParseDuration(r.FormValue("rateTime"))
if err != nil {
return nil, errors.Wrap(err, "rateTime")
}
return NewRateLimit(int(c), d), nil
}
return Unlimited(), nil
}
// HTTPBackendCachedPublish handles the /cached_pub route. // HTTPBackendCachedPublish handles the /cached_pub route.
// It publishes a message to clients, and then updates the in-server cache for the message. // It publishes a message to clients, and then updates the in-server cache for the message.
// //
@ -163,6 +180,12 @@ func HTTPBackendCachedPublish(w http.ResponseWriter, r *http.Request) {
} }
expires = time.Unix(timeNum, 0) expires = time.Unix(timeNum, 0)
} }
rl, err := rateLimitFromRequest(r)
if err != nil {
w.WriteHeader(422)
fmt.Fprintf(w, "error parsing ratelimit: %v", err)
return
}
var count int var count int
msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: json} msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: json}
@ -174,8 +197,25 @@ func HTTPBackendCachedPublish(w http.ResponseWriter, r *http.Request) {
saveLastMessage(cmd, channel, expires, json, deleteMode) saveLastMessage(cmd, channel, expires, json, deleteMode)
} }
CachedLSMLock.Unlock() CachedLSMLock.Unlock()
count = PublishToMultiple(channels, msg)
var wg sync.WaitGroup
wg.Add(1)
go rl.Run()
go func() {
count = PublishToMultiple(channels, msg, rl)
wg.Done()
rl.Close()
}()
ch := make(chan struct{})
go func() {
wg.Wait()
close(ch)
}()
select {
case time.After(3*time.Second):
count = -1
case <-ch:
}
w.Write([]byte(strconv.Itoa(count))) w.Write([]byte(strconv.Itoa(count)))
} }
@ -199,26 +239,50 @@ func HTTPBackendUncachedPublish(w http.ResponseWriter, r *http.Request) {
if cmd == "" { if cmd == "" {
w.WriteHeader(422) w.WriteHeader(422)
fmt.Fprintf(w, "Error: cmd cannot be blank") fmt.Fprint(w, "Error: cmd cannot be blank")
return return
} }
if channel == "" && scope != "global" { if channel == "" && scope != "global" {
w.WriteHeader(422) w.WriteHeader(422)
fmt.Fprintf(w, "Error: channel must be specified") fmt.Fprint(w, "Error: channel must be specified")
return
}
rl, err := rateLimitFromRequest(r)
if err != nil {
w.WriteHeader(422)
fmt.Fprintf(w, "error parsing ratelimit: %v", err)
return return
} }
cm := ClientMessage{MessageID: -1, Command: CommandPool.InternCommand(cmd), origArguments: json} cm := ClientMessage{MessageID: -1, Command: CommandPool.InternCommand(cmd), origArguments: json}
cm.parseOrigArguments() cm.parseOrigArguments()
var count int
switch scope { var count int
default: var wg sync.WaitGroup
count = PublishToMultiple(strings.Split(channel, ","), cm) wg.Add(1)
case "global": go rl.Run()
count = PublishToAll(cm) go func() {
switch scope {
default:
count = PublishToMultiple(strings.Split(channel, ","), cm, rl)
case "global":
count = PublishToAll(cm, rl)
}
wg.Done()
rl.Close()
}()
ch := make(chan struct{})
go func() {
wg.Wait()
close(ch)
}()
select {
case time.After(3*time.Second):
count = -1
case <-ch:
} }
fmt.Fprint(w, count) w.Write([]byte(strconv.Itoa(count)))
} }
// HTTPGetSubscriberCount handles the /get_sub_count route. // HTTPGetSubscriberCount handles the /get_sub_count route.

View file

@ -0,0 +1,76 @@
package server
import (
"time"
"io"
)
// A RateLimit supports a constant number of Performed() calls every
// time a given unit of time passes.
//
// Calls to Performed() when no "action tokens" are available will block
// until one is available.
type RateLimit interface {
// Run begins emitting tokens for the ratelimiter.
// A call to Run must be followed by a call to Close.
Run()
// Performed consumes one token from the rate limiter.
// If no tokens are available, the call will block until one is.
Performed()
// Close stops the rate limiter. Any future calls to Performed() will block forever.
// Close never returns an error.
io.Closer
}
type timeRateLimit struct{
count int
period time.Duration
ch chan struct{}
done chan struct{}
}
// Construct a new RateLimit with the given count and duration.
func NewRateLimit(count int, period time.Duration) (RateLimit) {
return &timeRateLimit{
count: count,
period: period,
ch: make(chan struct{}),
done: make(chan struct{}),
}
}
func (r *timeRateLimit) Run() {
for {
waiter := time.After(r.period)
for i := 0; i < r.count; i++ {
select {
case r.ch <- struct{}{}:
// ok
case <-r.done:
return
}
}
<-waiter
}
}
func (r *timeRateLimit) Performed() {
<-r.ch
}
func (r *timeRateLimit) Close() error {
close(r.done)
return nil
}
type unlimited struct{}
var unlimitedInstance unlimited
// Unlimited returns a RateLimit that never blocks. The Run() and Close() calls are no-ops.
func Unlimited() (RateLimit) {
return unlimitedInstance
}
func (r unlimited) Run() { }
func (r unlimited) Performed() { }
func (r unlimited) Close() error { return nil }

View file

@ -0,0 +1,40 @@
package server
import (
"time"
"testing"
)
var exampleData = []string{}
func ExampleNewRateLimit() {
rl := NewRateLimit(100, 1*time.Minute)
go rl.Run()
defer rl.Close()
for _, v := range exampleData {
rl.Performed()
// do something with v
_ = v
}
}
func TestRateLimit(t *testing.T) {
rl := NewRateLimit(3, 100*time.Millisecond)
start := time.Now()
go rl.Run()
for i := 0; i < 4; i++ {
rl.Performed()
}
end := time.Now()
if end.Sub(start) < 100*time.Millisecond {
t.Error("ratelimiter did not wait for period to expire")
}
rl.Performed()
rl.Performed()
end2 := time.Now()
if end2.Sub(end) > 10*time.Millisecond {
t.Error("ratelimiter improperly waited when tokens were available")
}
rl.Close()
}

View file

@ -11,7 +11,7 @@ import (
type SubscriberList struct { type SubscriberList struct {
sync.RWMutex sync.RWMutex
Members []chan<- ClientMessage Members []*ClientInfo
} }
var ChatSubscriptionInfo map[string]*SubscriberList = make(map[string]*SubscriberList) var ChatSubscriptionInfo map[string]*SubscriberList = make(map[string]*SubscriberList)
@ -19,6 +19,18 @@ var ChatSubscriptionLock sync.RWMutex
var GlobalSubscriptionInfo []*ClientInfo var GlobalSubscriptionInfo []*ClientInfo
var GlobalSubscriptionLock sync.RWMutex var GlobalSubscriptionLock sync.RWMutex
func (client *ClientInfo) Send(msg ClientMessage) bool {
select {
case client.MessageChannel <- msg:
return true
case <-client.MsgChannelIsDone:
return false
default:
// if we can't immediately send, ignore it
return false
}
}
func CountSubscriptions(channels []string) int { func CountSubscriptions(channels []string) int {
ChatSubscriptionLock.RLock() ChatSubscriptionLock.RLock()
defer ChatSubscriptionLock.RUnlock() defer ChatSubscriptionLock.RUnlock()
@ -38,70 +50,77 @@ func CountSubscriptions(channels []string) int {
func SubscribeChannel(client *ClientInfo, channelName string) { func SubscribeChannel(client *ClientInfo, channelName string) {
ChatSubscriptionLock.RLock() ChatSubscriptionLock.RLock()
_subscribeWhileRlocked(channelName, client.MessageChannel) _subscribeWhileRlocked(channelName, client)
ChatSubscriptionLock.RUnlock() ChatSubscriptionLock.RUnlock()
} }
func SubscribeDefaults(client *ClientInfo) {
}
func SubscribeGlobal(client *ClientInfo) { func SubscribeGlobal(client *ClientInfo) {
GlobalSubscriptionLock.Lock() GlobalSubscriptionLock.Lock()
AddToSliceCl(&GlobalSubscriptionInfo, client) AddToSliceCl(&GlobalSubscriptionInfo, client)
GlobalSubscriptionLock.Unlock() GlobalSubscriptionLock.Unlock()
} }
func PublishToChannel(channel string, msg ClientMessage) (count int) { func PublishToChannel(channel string, msg ClientMessage, rl RateLimit) (count int) {
var found []*ClientInfo
ChatSubscriptionLock.RLock() ChatSubscriptionLock.RLock()
list := ChatSubscriptionInfo[channel] list := ChatSubscriptionInfo[channel]
if list != nil { if list != nil {
list.RLock() list.RLock()
for _, msgChan := range list.Members { found = make([]*ClientInfo, len(list.Members))
msgChan <- msg copy(found, list.Members)
count++
}
list.RUnlock() list.RUnlock()
} }
ChatSubscriptionLock.RUnlock() ChatSubscriptionLock.RUnlock()
for _, cl := range found {
rl.Performed()
if cl.Send(msg) {
count++
}
}
return return
} }
func PublishToMultiple(channels []string, msg ClientMessage) (count int) { func PublishToMultiple(channels []string, msg ClientMessage, rl RateLimit) (count int) {
found := make(map[chan<- ClientMessage]struct{}) var found []*ClientInfo
ChatSubscriptionLock.RLock() ChatSubscriptionLock.RLock()
for _, channel := range channels { for _, channel := range channels {
list := ChatSubscriptionInfo[channel] list := ChatSubscriptionInfo[channel]
if list != nil { if list != nil {
list.RLock() list.RLock()
for _, msgChan := range list.Members { for _, cl := range list.Members {
found[msgChan] = struct{}{} found = append(found, cl)
} }
list.RUnlock() list.RUnlock()
} }
} }
ChatSubscriptionLock.RUnlock() ChatSubscriptionLock.RUnlock()
for msgChan, _ := range found { for _, cl := range found {
msgChan <- msg rl.Performed()
count++ if cl.Send(msg) {
count++
}
} }
return return count
} }
func PublishToAll(msg ClientMessage) (count int) { func PublishToAll(msg ClientMessage, rl RateLimit) (count int) {
var found []*ClientInfo
GlobalSubscriptionLock.RLock() GlobalSubscriptionLock.RLock()
for _, client := range GlobalSubscriptionInfo { found = make([]*ClientInfo, len(GlobalSubscriptionInfo))
select { copy(found, GlobalSubscriptionInfo)
case client.MessageChannel <- msg:
case <-client.MsgChannelIsDone:
}
count++
}
GlobalSubscriptionLock.RUnlock() GlobalSubscriptionLock.RUnlock()
for _, cl := range found {
rl.Performed()
if cl.Send(msg) {
count++
}
}
return return
} }
@ -110,7 +129,7 @@ func UnsubscribeSingleChat(client *ClientInfo, channelName string) {
list := ChatSubscriptionInfo[channelName] list := ChatSubscriptionInfo[channelName]
if list != nil { if list != nil {
list.Lock() list.Lock()
RemoveFromSliceC(&list.Members, client.MessageChannel) RemoveFromSliceCl(&list.Members, client)
list.Unlock() list.Unlock()
} }
ChatSubscriptionLock.RUnlock() ChatSubscriptionLock.RUnlock()
@ -138,7 +157,7 @@ func UnsubscribeAll(client *ClientInfo) {
list := ChatSubscriptionInfo[v] list := ChatSubscriptionInfo[v]
if list != nil { if list != nil {
list.Lock() list.Lock()
RemoveFromSliceC(&list.Members, client.MessageChannel) RemoveFromSliceCl(&list.Members, client)
list.Unlock() list.Unlock()
} }
} }
@ -191,14 +210,14 @@ func pubsubJanitor_do() {
// - ALREADY HOLDING a read-lock to the 'which' top-level map via the rlocker object // - ALREADY HOLDING a read-lock to the 'which' top-level map via the rlocker object
// - possible write lock to the 'which' top-level map via the wlocker object // - possible write lock to the 'which' top-level map via the wlocker object
// - write lock to SubscriptionInfo (if not creating new) // - write lock to SubscriptionInfo (if not creating new)
func _subscribeWhileRlocked(channelName string, value chan<- ClientMessage) { func _subscribeWhileRlocked(channelName string, value *ClientInfo) {
list := ChatSubscriptionInfo[channelName] list := ChatSubscriptionInfo[channelName]
if list == nil { if list == nil {
// Not found, so create it // Not found, so create it
ChatSubscriptionLock.RUnlock() ChatSubscriptionLock.RUnlock()
ChatSubscriptionLock.Lock() ChatSubscriptionLock.Lock()
list = &SubscriberList{} list = &SubscriberList{}
list.Members = []chan<- ClientMessage{value} // Create it populated, to avoid reaper list.Members = []*ClientInfo{value} // Create it populated, to avoid reaper
ChatSubscriptionInfo[channelName] = list ChatSubscriptionInfo[channelName] = list
ChatSubscriptionLock.Unlock() ChatSubscriptionLock.Unlock()
@ -212,7 +231,7 @@ func _subscribeWhileRlocked(channelName string, value chan<- ClientMessage) {
ChatSubscriptionLock.RLock() ChatSubscriptionLock.RLock()
} else { } else {
list.Lock() list.Lock()
AddToSliceC(&list.Members, value) AddToSliceCl(&list.Members, value)
list.Unlock() list.Unlock()
} }
} }

View file

@ -112,6 +112,7 @@ type ClientInfo struct {
// This field will be nil before it is closed. // This field will be nil before it is closed.
MessageChannel chan<- ClientMessage MessageChannel chan<- ClientMessage
// Closed when the client is shutting down.
MsgChannelIsDone <-chan struct{} MsgChannelIsDone <-chan struct{}
// Take out an Add() on this during a command if you need to use the MessageChannel later. // Take out an Add() on this during a command if you need to use the MessageChannel later.

View file

@ -130,38 +130,6 @@ func RemoveFromSliceS(ary *[]string, val string) bool {
return true return true
} }
func AddToSliceC(ary *[]chan<- ClientMessage, val chan<- ClientMessage) bool {
slice := *ary
for _, v := range slice {
if v == val {
return false
}
}
slice = append(slice, val)
*ary = slice
return true
}
func RemoveFromSliceC(ary *[]chan<- ClientMessage, val chan<- ClientMessage) bool {
slice := *ary
var idx int = -1
for i, v := range slice {
if v == val {
idx = i
break
}
}
if idx == -1 {
return false
}
slice[idx] = slice[len(slice)-1]
slice = slice[:len(slice)-1]
*ary = slice
return true
}
func AddToSliceCl(ary *[]*ClientInfo, val *ClientInfo) bool { func AddToSliceCl(ary *[]*ClientInfo, val *ClientInfo) bool {
slice := *ary slice := *ary
for _, v := range slice { for _, v := range slice {