1
0
Fork 0
mirror of https://github.com/FrankerFaceZ/FrankerFaceZ.git synced 2025-06-27 21:05:53 +00:00

Add ratelimits to publishing

This commit is contained in:
Kane York 2017-02-02 22:59:17 -08:00
parent b84bd1d4a2
commit 50e295c834
10 changed files with 260 additions and 103 deletions

View file

@ -116,9 +116,8 @@ func commandLineConsole() {
if i >= count {
break
}
select {
case cl.MessageChannel <- msg:
case <-cl.MsgChannelIsDone:
if cl.Send(msg) {
kickCount++
}
kickCount++
}

View file

@ -145,7 +145,6 @@ func C2SHello(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg
uniqueUserChannel <- client.ClientID
SubscribeGlobal(client)
SubscribeDefaults(client)
jsTime := float64(time.Now().UnixNano()/1000) / 1000
return ClientMessage{
@ -197,7 +196,7 @@ func C2SReady(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg
client.MsgChannelKeepalive.Add(1)
go func() {
client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand}
client.Send(ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand})
SendBacklogForNewClient(client)
client.MsgChannelKeepalive.Done()
}()
@ -553,10 +552,7 @@ func C2SHandleBunchedCommand(conn *websocket.Conn, client *ClientInfo, msg Clien
bsl.Lock()
for _, member := range bsl.Members {
msg.MessageID = member.MessageID
select {
case member.Client.MessageChannel <- msg:
case <-member.Client.MsgChannelIsDone:
}
member.Client.Send(msg)
}
bsl.Unlock()
}(br)
@ -580,7 +576,7 @@ func doRemoteCommand(conn *websocket.Conn, msg ClientMessage, client *ClientInfo
if err == ErrAuthorizationNeeded {
if client.TwitchUsername == "" {
// Not logged in
client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationNeededError}
client.Send(ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationNeededError})
client.MsgChannelKeepalive.Done()
return
}
@ -588,19 +584,19 @@ func doRemoteCommand(conn *websocket.Conn, msg ClientMessage, client *ClientInfo
if success {
doRemoteCommand(conn, msg, client)
} else {
client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationFailedErrorString}
client.Send(ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationFailedErrorString})
client.MsgChannelKeepalive.Done()
}
})
return // without keepalive.Done()
} else if bfe, ok := err.(ErrForwardedFromBackend); ok {
client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: bfe.JSONError}
client.Send(ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: bfe.JSONError})
} else if err != nil {
client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: err.Error()}
client.Send(ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: err.Error()})
} else {
msg := ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand, origArguments: resp}
msg.parseOrigArguments()
client.MessageChannel <- msg
client.Send(msg)
}
client.MsgChannelKeepalive.Done()
}

View file

@ -300,8 +300,8 @@ var CloseNonUTF8Data = websocket.CloseError{
Text: "Non UTF8 data recieved. Network corruption likely.",
}
const sendMessageBufferLength = 30
const sendMessageAbortLength = 20
const sendMessageBufferLength = 5
const sendMessageAbortLength = 5
// RunSocketConnection contains the main run loop of a websocket connection.
//
@ -350,11 +350,8 @@ func RunSocketConnection(conn *websocket.Conn) {
closeConnection(conn, closeReason)
// closeConnection(conn, closeReason, &report)
// Launch message draining goroutine - we aren't out of the pub/sub records
go func() {
for _ = range _serverMessageChan {
}
}()
// We can just drop serverMessageChan and let it be picked up by GC, because all sends are nonblocking.
_serverMessageChan = nil
// Closes client.MsgChannelIsDone and also stops the reader thread
close(stoppedChan)
@ -364,11 +361,8 @@ func RunSocketConnection(conn *websocket.Conn) {
// Wait for pending jobs to finish...
client.MsgChannelKeepalive.Wait()
client.MessageChannel = nil
// And done.
// Close the channel so the draining goroutine can finish, too.
close(_serverMessageChan)
if !StopAcceptingConnections {
// Don't perform high contention operations when server is closing

View file

@ -81,7 +81,7 @@ func (client *ClientInfo) StartAuthorization(callback AuthCallback) {
AddPendingAuthorization(client, challenge, callback)
client.MessageChannel <- ClientMessage{MessageID: -1, Command: AuthorizeCommand, Arguments: challenge}
client.Send(ClientMessage{MessageID: -1, Command: AuthorizeCommand, Arguments: challenge})
}
const AuthChannelName = "frankerfacezauthorizer"

View file

@ -7,8 +7,10 @@ import (
"strings"
"sync"
"time"
"github.com/pkg/errors"
)
// LastSavedMessage contains a reply to a command along with an expiration time.
type LastSavedMessage struct {
Expires time.Time
Data string
@ -72,7 +74,7 @@ func SendBacklogForNewClient(client *ClientInfo) {
if ok {
msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data}
msg.parseOrigArguments()
client.MessageChannel <- msg
client.Send(msg)
}
}
}
@ -88,7 +90,7 @@ func SendBacklogForChannel(client *ClientInfo, channel string) {
if msg, ok := chanMap[channel]; ok {
msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data}
msg.parseOrigArguments()
client.MessageChannel <- msg
client.Send(msg)
}
}
CachedLSMLock.RUnlock()
@ -132,6 +134,21 @@ func HTTPBackendDropBacklog(w http.ResponseWriter, r *http.Request) {
}
}
func rateLimitFromRequest(r *http.Request) (RateLimit, error) {
if r.FormValue("rateCount") != "" {
c, err := strconv.ParseInt(r.FormValue("rateCount"), 10, 32)
if err != nil {
return nil, errors.Wrap(err, "rateCount")
}
d, err := time.ParseDuration(r.FormValue("rateTime"))
if err != nil {
return nil, errors.Wrap(err, "rateTime")
}
return NewRateLimit(int(c), d), nil
}
return Unlimited(), nil
}
// HTTPBackendCachedPublish handles the /cached_pub route.
// It publishes a message to clients, and then updates the in-server cache for the message.
//
@ -163,6 +180,12 @@ func HTTPBackendCachedPublish(w http.ResponseWriter, r *http.Request) {
}
expires = time.Unix(timeNum, 0)
}
rl, err := rateLimitFromRequest(r)
if err != nil {
w.WriteHeader(422)
fmt.Fprintf(w, "error parsing ratelimit: %v", err)
return
}
var count int
msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: json}
@ -174,8 +197,25 @@ func HTTPBackendCachedPublish(w http.ResponseWriter, r *http.Request) {
saveLastMessage(cmd, channel, expires, json, deleteMode)
}
CachedLSMLock.Unlock()
count = PublishToMultiple(channels, msg)
var wg sync.WaitGroup
wg.Add(1)
go rl.Run()
go func() {
count = PublishToMultiple(channels, msg, rl)
wg.Done()
rl.Close()
}()
ch := make(chan struct{})
go func() {
wg.Wait()
close(ch)
}()
select {
case time.After(3*time.Second):
count = -1
case <-ch:
}
w.Write([]byte(strconv.Itoa(count)))
}
@ -199,26 +239,50 @@ func HTTPBackendUncachedPublish(w http.ResponseWriter, r *http.Request) {
if cmd == "" {
w.WriteHeader(422)
fmt.Fprintf(w, "Error: cmd cannot be blank")
fmt.Fprint(w, "Error: cmd cannot be blank")
return
}
if channel == "" && scope != "global" {
w.WriteHeader(422)
fmt.Fprintf(w, "Error: channel must be specified")
fmt.Fprint(w, "Error: channel must be specified")
return
}
rl, err := rateLimitFromRequest(r)
if err != nil {
w.WriteHeader(422)
fmt.Fprintf(w, "error parsing ratelimit: %v", err)
return
}
cm := ClientMessage{MessageID: -1, Command: CommandPool.InternCommand(cmd), origArguments: json}
cm.parseOrigArguments()
var count int
switch scope {
default:
count = PublishToMultiple(strings.Split(channel, ","), cm)
case "global":
count = PublishToAll(cm)
var count int
var wg sync.WaitGroup
wg.Add(1)
go rl.Run()
go func() {
switch scope {
default:
count = PublishToMultiple(strings.Split(channel, ","), cm, rl)
case "global":
count = PublishToAll(cm, rl)
}
wg.Done()
rl.Close()
}()
ch := make(chan struct{})
go func() {
wg.Wait()
close(ch)
}()
select {
case time.After(3*time.Second):
count = -1
case <-ch:
}
fmt.Fprint(w, count)
w.Write([]byte(strconv.Itoa(count)))
}
// HTTPGetSubscriberCount handles the /get_sub_count route.

View file

@ -0,0 +1,76 @@
package server
import (
"time"
"io"
)
// A RateLimit supports a constant number of Performed() calls every
// time a given unit of time passes.
//
// Calls to Performed() when no "action tokens" are available will block
// until one is available.
type RateLimit interface {
// Run begins emitting tokens for the ratelimiter.
// A call to Run must be followed by a call to Close.
Run()
// Performed consumes one token from the rate limiter.
// If no tokens are available, the call will block until one is.
Performed()
// Close stops the rate limiter. Any future calls to Performed() will block forever.
// Close never returns an error.
io.Closer
}
type timeRateLimit struct{
count int
period time.Duration
ch chan struct{}
done chan struct{}
}
// Construct a new RateLimit with the given count and duration.
func NewRateLimit(count int, period time.Duration) (RateLimit) {
return &timeRateLimit{
count: count,
period: period,
ch: make(chan struct{}),
done: make(chan struct{}),
}
}
func (r *timeRateLimit) Run() {
for {
waiter := time.After(r.period)
for i := 0; i < r.count; i++ {
select {
case r.ch <- struct{}{}:
// ok
case <-r.done:
return
}
}
<-waiter
}
}
func (r *timeRateLimit) Performed() {
<-r.ch
}
func (r *timeRateLimit) Close() error {
close(r.done)
return nil
}
type unlimited struct{}
var unlimitedInstance unlimited
// Unlimited returns a RateLimit that never blocks. The Run() and Close() calls are no-ops.
func Unlimited() (RateLimit) {
return unlimitedInstance
}
func (r unlimited) Run() { }
func (r unlimited) Performed() { }
func (r unlimited) Close() error { return nil }

View file

@ -0,0 +1,40 @@
package server
import (
"time"
"testing"
)
var exampleData = []string{}
func ExampleNewRateLimit() {
rl := NewRateLimit(100, 1*time.Minute)
go rl.Run()
defer rl.Close()
for _, v := range exampleData {
rl.Performed()
// do something with v
_ = v
}
}
func TestRateLimit(t *testing.T) {
rl := NewRateLimit(3, 100*time.Millisecond)
start := time.Now()
go rl.Run()
for i := 0; i < 4; i++ {
rl.Performed()
}
end := time.Now()
if end.Sub(start) < 100*time.Millisecond {
t.Error("ratelimiter did not wait for period to expire")
}
rl.Performed()
rl.Performed()
end2 := time.Now()
if end2.Sub(end) > 10*time.Millisecond {
t.Error("ratelimiter improperly waited when tokens were available")
}
rl.Close()
}

View file

@ -11,7 +11,7 @@ import (
type SubscriberList struct {
sync.RWMutex
Members []chan<- ClientMessage
Members []*ClientInfo
}
var ChatSubscriptionInfo map[string]*SubscriberList = make(map[string]*SubscriberList)
@ -19,6 +19,18 @@ var ChatSubscriptionLock sync.RWMutex
var GlobalSubscriptionInfo []*ClientInfo
var GlobalSubscriptionLock sync.RWMutex
func (client *ClientInfo) Send(msg ClientMessage) bool {
select {
case client.MessageChannel <- msg:
return true
case <-client.MsgChannelIsDone:
return false
default:
// if we can't immediately send, ignore it
return false
}
}
func CountSubscriptions(channels []string) int {
ChatSubscriptionLock.RLock()
defer ChatSubscriptionLock.RUnlock()
@ -38,70 +50,77 @@ func CountSubscriptions(channels []string) int {
func SubscribeChannel(client *ClientInfo, channelName string) {
ChatSubscriptionLock.RLock()
_subscribeWhileRlocked(channelName, client.MessageChannel)
_subscribeWhileRlocked(channelName, client)
ChatSubscriptionLock.RUnlock()
}
func SubscribeDefaults(client *ClientInfo) {
}
func SubscribeGlobal(client *ClientInfo) {
GlobalSubscriptionLock.Lock()
AddToSliceCl(&GlobalSubscriptionInfo, client)
GlobalSubscriptionLock.Unlock()
}
func PublishToChannel(channel string, msg ClientMessage) (count int) {
func PublishToChannel(channel string, msg ClientMessage, rl RateLimit) (count int) {
var found []*ClientInfo
ChatSubscriptionLock.RLock()
list := ChatSubscriptionInfo[channel]
if list != nil {
list.RLock()
for _, msgChan := range list.Members {
msgChan <- msg
count++
}
found = make([]*ClientInfo, len(list.Members))
copy(found, list.Members)
list.RUnlock()
}
ChatSubscriptionLock.RUnlock()
for _, cl := range found {
rl.Performed()
if cl.Send(msg) {
count++
}
}
return
}
func PublishToMultiple(channels []string, msg ClientMessage) (count int) {
found := make(map[chan<- ClientMessage]struct{})
func PublishToMultiple(channels []string, msg ClientMessage, rl RateLimit) (count int) {
var found []*ClientInfo
ChatSubscriptionLock.RLock()
for _, channel := range channels {
list := ChatSubscriptionInfo[channel]
if list != nil {
list.RLock()
for _, msgChan := range list.Members {
found[msgChan] = struct{}{}
for _, cl := range list.Members {
found = append(found, cl)
}
list.RUnlock()
}
}
ChatSubscriptionLock.RUnlock()
for msgChan, _ := range found {
msgChan <- msg
count++
for _, cl := range found {
rl.Performed()
if cl.Send(msg) {
count++
}
}
return
return count
}
func PublishToAll(msg ClientMessage) (count int) {
func PublishToAll(msg ClientMessage, rl RateLimit) (count int) {
var found []*ClientInfo
GlobalSubscriptionLock.RLock()
for _, client := range GlobalSubscriptionInfo {
select {
case client.MessageChannel <- msg:
case <-client.MsgChannelIsDone:
}
count++
}
found = make([]*ClientInfo, len(GlobalSubscriptionInfo))
copy(found, GlobalSubscriptionInfo)
GlobalSubscriptionLock.RUnlock()
for _, cl := range found {
rl.Performed()
if cl.Send(msg) {
count++
}
}
return
}
@ -110,7 +129,7 @@ func UnsubscribeSingleChat(client *ClientInfo, channelName string) {
list := ChatSubscriptionInfo[channelName]
if list != nil {
list.Lock()
RemoveFromSliceC(&list.Members, client.MessageChannel)
RemoveFromSliceCl(&list.Members, client)
list.Unlock()
}
ChatSubscriptionLock.RUnlock()
@ -138,7 +157,7 @@ func UnsubscribeAll(client *ClientInfo) {
list := ChatSubscriptionInfo[v]
if list != nil {
list.Lock()
RemoveFromSliceC(&list.Members, client.MessageChannel)
RemoveFromSliceCl(&list.Members, client)
list.Unlock()
}
}
@ -191,14 +210,14 @@ func pubsubJanitor_do() {
// - ALREADY HOLDING a read-lock to the 'which' top-level map via the rlocker object
// - possible write lock to the 'which' top-level map via the wlocker object
// - write lock to SubscriptionInfo (if not creating new)
func _subscribeWhileRlocked(channelName string, value chan<- ClientMessage) {
func _subscribeWhileRlocked(channelName string, value *ClientInfo) {
list := ChatSubscriptionInfo[channelName]
if list == nil {
// Not found, so create it
ChatSubscriptionLock.RUnlock()
ChatSubscriptionLock.Lock()
list = &SubscriberList{}
list.Members = []chan<- ClientMessage{value} // Create it populated, to avoid reaper
list.Members = []*ClientInfo{value} // Create it populated, to avoid reaper
ChatSubscriptionInfo[channelName] = list
ChatSubscriptionLock.Unlock()
@ -212,7 +231,7 @@ func _subscribeWhileRlocked(channelName string, value chan<- ClientMessage) {
ChatSubscriptionLock.RLock()
} else {
list.Lock()
AddToSliceC(&list.Members, value)
AddToSliceCl(&list.Members, value)
list.Unlock()
}
}

View file

@ -112,6 +112,7 @@ type ClientInfo struct {
// This field will be nil before it is closed.
MessageChannel chan<- ClientMessage
// Closed when the client is shutting down.
MsgChannelIsDone <-chan struct{}
// Take out an Add() on this during a command if you need to use the MessageChannel later.

View file

@ -130,38 +130,6 @@ func RemoveFromSliceS(ary *[]string, val string) bool {
return true
}
func AddToSliceC(ary *[]chan<- ClientMessage, val chan<- ClientMessage) bool {
slice := *ary
for _, v := range slice {
if v == val {
return false
}
}
slice = append(slice, val)
*ary = slice
return true
}
func RemoveFromSliceC(ary *[]chan<- ClientMessage, val chan<- ClientMessage) bool {
slice := *ary
var idx int = -1
for i, v := range slice {
if v == val {
idx = i
break
}
}
if idx == -1 {
return false
}
slice[idx] = slice[len(slice)-1]
slice = slice[:len(slice)-1]
*ary = slice
return true
}
func AddToSliceCl(ary *[]*ClientInfo, val *ClientInfo) bool {
slice := *ary
for _, v := range slice {