1
0
Fork 0
mirror of https://github.com/FrankerFaceZ/FrankerFaceZ.git synced 2025-08-15 18:40:54 +00:00

Merge pull request #98 from riking/publish-ratelimit

Add ratelimits to publishing
This commit is contained in:
Mike 2017-07-22 15:52:57 -04:00 committed by GitHub
commit d3cd94c262
13 changed files with 276 additions and 121 deletions

View file

@ -116,11 +116,9 @@ func commandLineConsole() {
if i >= count { if i >= count {
break break
} }
select { if cl.Send(msg) {
case cl.MessageChannel <- msg: kickCount++
case <-cl.MsgChannelIsDone:
} }
kickCount++
} }
return fmt.Sprintf("Kicked %d clients", kickCount), nil return fmt.Sprintf("Kicked %d clients", kickCount), nil
}) })

View file

@ -12,7 +12,6 @@ import (
"bitbucket.org/stendec/frankerfacez/socketserver/server" "bitbucket.org/stendec/frankerfacez/socketserver/server"
"github.com/clarkduvall/hyperloglog" "github.com/clarkduvall/hyperloglog"
"github.com/hashicorp/golang-lru"
) )
type serverFilter struct { type serverFilter struct {

View file

@ -12,11 +12,10 @@ import (
"net/url" "net/url"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"sync" cache "github.com/patrickmn/go-cache"
"github.com/pmylund/go-cache"
"golang.org/x/crypto/nacl/box" "golang.org/x/crypto/nacl/box"
) )

View file

@ -145,7 +145,6 @@ func C2SHello(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg
uniqueUserChannel <- client.ClientID uniqueUserChannel <- client.ClientID
SubscribeGlobal(client) SubscribeGlobal(client)
SubscribeDefaults(client)
jsTime := float64(time.Now().UnixNano()/1000) / 1000 jsTime := float64(time.Now().UnixNano()/1000) / 1000
return ClientMessage{ return ClientMessage{
@ -197,7 +196,7 @@ func C2SReady(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg
client.MsgChannelKeepalive.Add(1) client.MsgChannelKeepalive.Add(1)
go func() { go func() {
client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand} client.Send(ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand})
SendBacklogForNewClient(client) SendBacklogForNewClient(client)
client.MsgChannelKeepalive.Done() client.MsgChannelKeepalive.Done()
}() }()
@ -553,10 +552,7 @@ func C2SHandleBunchedCommand(conn *websocket.Conn, client *ClientInfo, msg Clien
bsl.Lock() bsl.Lock()
for _, member := range bsl.Members { for _, member := range bsl.Members {
msg.MessageID = member.MessageID msg.MessageID = member.MessageID
select { member.Client.Send(msg)
case member.Client.MessageChannel <- msg:
case <-member.Client.MsgChannelIsDone:
}
} }
bsl.Unlock() bsl.Unlock()
}(br) }(br)
@ -580,7 +576,7 @@ func doRemoteCommand(conn *websocket.Conn, msg ClientMessage, client *ClientInfo
if err == ErrAuthorizationNeeded { if err == ErrAuthorizationNeeded {
if client.TwitchUsername == "" { if client.TwitchUsername == "" {
// Not logged in // Not logged in
client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationNeededError} client.Send(ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationNeededError})
client.MsgChannelKeepalive.Done() client.MsgChannelKeepalive.Done()
return return
} }
@ -588,19 +584,19 @@ func doRemoteCommand(conn *websocket.Conn, msg ClientMessage, client *ClientInfo
if success { if success {
doRemoteCommand(conn, msg, client) doRemoteCommand(conn, msg, client)
} else { } else {
client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationFailedErrorString} client.Send(ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationFailedErrorString})
client.MsgChannelKeepalive.Done() client.MsgChannelKeepalive.Done()
} }
}) })
return // without keepalive.Done() return // without keepalive.Done()
} else if bfe, ok := err.(ErrForwardedFromBackend); ok { } else if bfe, ok := err.(ErrForwardedFromBackend); ok {
client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: bfe.JSONError} client.Send(ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: bfe.JSONError})
} else if err != nil { } else if err != nil {
client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: err.Error()} client.Send(ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: err.Error()})
} else { } else {
msg := ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand, origArguments: resp} msg := ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand, origArguments: resp}
msg.parseOrigArguments() msg.parseOrigArguments()
client.MessageChannel <- msg client.Send(msg)
} }
client.MsgChannelKeepalive.Done() client.MsgChannelKeepalive.Done()
} }

View file

@ -1,4 +1,4 @@
package server // import "bitbucket.org/stendec/frankerfacez/socketserver/server" package server // import "github.com/FrankerFaceZ/FrankerFaceZ/socketserver/server"
import ( import (
"encoding/json" "encoding/json"
@ -300,8 +300,8 @@ var CloseNonUTF8Data = websocket.CloseError{
Text: "Non UTF8 data recieved. Network corruption likely.", Text: "Non UTF8 data recieved. Network corruption likely.",
} }
const sendMessageBufferLength = 30 const sendMessageBufferLength = 5
const sendMessageAbortLength = 20 const sendMessageAbortLength = 5
// RunSocketConnection contains the main run loop of a websocket connection. // RunSocketConnection contains the main run loop of a websocket connection.
// //
@ -350,11 +350,8 @@ func RunSocketConnection(conn *websocket.Conn) {
closeConnection(conn, closeReason) closeConnection(conn, closeReason)
// closeConnection(conn, closeReason, &report) // closeConnection(conn, closeReason, &report)
// Launch message draining goroutine - we aren't out of the pub/sub records // We can just drop serverMessageChan and let it be picked up by GC, because all sends are nonblocking.
go func() { _serverMessageChan = nil
for _ = range _serverMessageChan {
}
}()
// Closes client.MsgChannelIsDone and also stops the reader thread // Closes client.MsgChannelIsDone and also stops the reader thread
close(stoppedChan) close(stoppedChan)
@ -364,11 +361,8 @@ func RunSocketConnection(conn *websocket.Conn) {
// Wait for pending jobs to finish... // Wait for pending jobs to finish...
client.MsgChannelKeepalive.Wait() client.MsgChannelKeepalive.Wait()
client.MessageChannel = nil
// And done. // And done.
// Close the channel so the draining goroutine can finish, too.
close(_serverMessageChan)
if !StopAcceptingConnections { if !StopAcceptingConnections {
// Don't perform high contention operations when server is closing // Don't perform high contention operations when server is closing

View file

@ -81,7 +81,7 @@ func (client *ClientInfo) StartAuthorization(callback AuthCallback) {
AddPendingAuthorization(client, challenge, callback) AddPendingAuthorization(client, challenge, callback)
client.MessageChannel <- ClientMessage{MessageID: -1, Command: AuthorizeCommand, Arguments: challenge} client.Send(ClientMessage{MessageID: -1, Command: AuthorizeCommand, Arguments: challenge})
} }
const AuthChannelName = "frankerfacezauthorizer" const AuthChannelName = "frankerfacezauthorizer"

View file

@ -7,11 +7,15 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/FrankerFaceZ/FrankerFaceZ/socketserver/server/rate"
"github.com/pkg/errors"
) )
// LastSavedMessage contains a reply to a command along with an expiration time.
type LastSavedMessage struct { type LastSavedMessage struct {
Expires time.Time Expires time.Time
Data string Data string
} }
// map is command -> channel -> data // map is command -> channel -> data
@ -23,7 +27,7 @@ var CachedLSMLock sync.RWMutex
func cachedMessageJanitor() { func cachedMessageJanitor() {
for { for {
time.Sleep(1*time.Hour) time.Sleep(1 * time.Hour)
cachedMessageJanitor_do() cachedMessageJanitor_do()
} }
} }
@ -72,7 +76,7 @@ func SendBacklogForNewClient(client *ClientInfo) {
if ok { if ok {
msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data} msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data}
msg.parseOrigArguments() msg.parseOrigArguments()
client.MessageChannel <- msg client.Send(msg)
} }
} }
} }
@ -88,7 +92,7 @@ func SendBacklogForChannel(client *ClientInfo, channel string) {
if msg, ok := chanMap[channel]; ok { if msg, ok := chanMap[channel]; ok {
msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data} msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data}
msg.parseOrigArguments() msg.parseOrigArguments()
client.MessageChannel <- msg client.Send(msg)
} }
} }
CachedLSMLock.RUnlock() CachedLSMLock.RUnlock()
@ -132,6 +136,21 @@ func HTTPBackendDropBacklog(w http.ResponseWriter, r *http.Request) {
} }
} }
func rateLimitFromRequest(r *http.Request) (rate.Limiter, error) {
if r.FormValue("rateCount") != "" {
c, err := strconv.ParseInt(r.FormValue("rateCount"), 10, 32)
if err != nil {
return nil, errors.Wrap(err, "rateCount")
}
d, err := time.ParseDuration(r.FormValue("rateTime"))
if err != nil {
return nil, errors.Wrap(err, "rateTime")
}
return rate.NewRateLimit(int(c), d), nil
}
return rate.Unlimited(), nil
}
// HTTPBackendCachedPublish handles the /cached_pub route. // HTTPBackendCachedPublish handles the /cached_pub route.
// It publishes a message to clients, and then updates the in-server cache for the message. // It publishes a message to clients, and then updates the in-server cache for the message.
// //
@ -163,6 +182,12 @@ func HTTPBackendCachedPublish(w http.ResponseWriter, r *http.Request) {
} }
expires = time.Unix(timeNum, 0) expires = time.Unix(timeNum, 0)
} }
rl, err := rateLimitFromRequest(r)
if err != nil {
w.WriteHeader(422)
fmt.Fprintf(w, "error parsing ratelimit: %v", err)
return
}
var count int var count int
msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: json} msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: json}
@ -174,8 +199,25 @@ func HTTPBackendCachedPublish(w http.ResponseWriter, r *http.Request) {
saveLastMessage(cmd, channel, expires, json, deleteMode) saveLastMessage(cmd, channel, expires, json, deleteMode)
} }
CachedLSMLock.Unlock() CachedLSMLock.Unlock()
count = PublishToMultiple(channels, msg)
var wg sync.WaitGroup
wg.Add(1)
go rl.Run()
go func() {
count = PublishToMultiple(channels, msg, rl)
wg.Done()
rl.Close()
}()
ch := make(chan struct{})
go func() {
wg.Wait()
close(ch)
}()
select {
case <-time.After(3 * time.Second):
count = -1
case <-ch:
}
w.Write([]byte(strconv.Itoa(count))) w.Write([]byte(strconv.Itoa(count)))
} }
@ -199,26 +241,50 @@ func HTTPBackendUncachedPublish(w http.ResponseWriter, r *http.Request) {
if cmd == "" { if cmd == "" {
w.WriteHeader(422) w.WriteHeader(422)
fmt.Fprintf(w, "Error: cmd cannot be blank") fmt.Fprint(w, "Error: cmd cannot be blank")
return return
} }
if channel == "" && scope != "global" { if channel == "" && scope != "global" {
w.WriteHeader(422) w.WriteHeader(422)
fmt.Fprintf(w, "Error: channel must be specified") fmt.Fprint(w, "Error: channel must be specified")
return
}
rl, err := rateLimitFromRequest(r)
if err != nil {
w.WriteHeader(422)
fmt.Fprintf(w, "error parsing ratelimit: %v", err)
return return
} }
cm := ClientMessage{MessageID: -1, Command: CommandPool.InternCommand(cmd), origArguments: json} cm := ClientMessage{MessageID: -1, Command: CommandPool.InternCommand(cmd), origArguments: json}
cm.parseOrigArguments() cm.parseOrigArguments()
var count int
switch scope { var count int
default: var wg sync.WaitGroup
count = PublishToMultiple(strings.Split(channel, ","), cm) wg.Add(1)
case "global": go rl.Run()
count = PublishToAll(cm) go func() {
switch scope {
default:
count = PublishToMultiple(strings.Split(channel, ","), cm, rl)
case "global":
count = PublishToAll(cm, rl)
}
wg.Done()
rl.Close()
}()
ch := make(chan struct{})
go func() {
wg.Wait()
close(ch)
}()
select {
case <-time.After(3 * time.Second):
count = -1
case <-ch:
} }
fmt.Fprint(w, count) w.Write([]byte(strconv.Itoa(count)))
} }
// HTTPGetSubscriberCount handles the /get_sub_count route. // HTTPGetSubscriberCount handles the /get_sub_count route.
@ -236,4 +302,4 @@ func HTTPGetSubscriberCount(w http.ResponseWriter, r *http.Request) {
channel := formData.Get("channel") channel := formData.Get("channel")
fmt.Fprint(w, CountSubscriptions(strings.Split(channel, ","))) fmt.Fprint(w, CountSubscriptions(strings.Split(channel, ",")))
} }

View file

@ -16,9 +16,9 @@ func TestExpiredCleanup(t *testing.T) {
defer DumpBacklogData() defer DumpBacklogData()
var zeroTime time.Time var zeroTime time.Time
hourAgo := time.Now().Add(-1*time.Hour) hourAgo := time.Now().Add(-1 * time.Hour)
now := time.Now() now := time.Now()
hourFromNow := time.Now().Add(1*time.Hour) hourFromNow := time.Now().Add(1 * time.Hour)
saveLastMessage(cmd, channel, hourAgo, "1", false) saveLastMessage(cmd, channel, hourAgo, "1", false)
saveLastMessage(cmd, channel2, now, "2", false) saveLastMessage(cmd, channel2, now, "2", false)
@ -26,11 +26,11 @@ func TestExpiredCleanup(t *testing.T) {
if len(CachedLastMessages) != 1 { if len(CachedLastMessages) != 1 {
t.Error("messages not saved") t.Error("messages not saved")
} }
if len(CachedLastMessages[cmd]) != 2{ if len(CachedLastMessages[cmd]) != 2 {
t.Error("messages not saved") t.Error("messages not saved")
} }
time.Sleep(2*time.Millisecond) time.Sleep(2 * time.Millisecond)
cachedMessageJanitor_do() cachedMessageJanitor_do()
@ -47,7 +47,7 @@ func TestExpiredCleanup(t *testing.T) {
t.Error("messages not saved") t.Error("messages not saved")
} }
time.Sleep(2*time.Millisecond) time.Sleep(2 * time.Millisecond)
cachedMessageJanitor_do() cachedMessageJanitor_do()

View file

@ -0,0 +1,77 @@
package rate
import (
"io"
"time"
)
// A Limiter supports a constant number of Performed() calls every
// time a certain amount of time passes.
//
// Calls to Performed() when no "action tokens" are available will block
// until one is available.
type Limiter interface {
// Run begins emitting tokens for the ratelimiter.
// A call to Run must be followed by a call to Close.
Run()
// Performed consumes one token from the rate limiter.
// If no tokens are available, the call will block until one is.
Performed()
// Close stops the rate limiter. Any future calls to Performed() will block forever.
// Close never returns an error.
io.Closer
}
type timeRateLimit struct {
count int
period time.Duration
ch chan struct{}
done chan struct{}
}
// Construct a new Limiter with the given count and duration.
func NewRateLimit(count int, period time.Duration) Limiter {
return &timeRateLimit{
count: count,
period: period,
ch: make(chan struct{}),
done: make(chan struct{}),
}
}
func (r *timeRateLimit) Run() {
for {
waiter := time.After(r.period)
for i := 0; i < r.count; i++ {
select {
case r.ch <- struct{}{}:
// ok
case <-r.done:
return
}
}
<-waiter
}
}
func (r *timeRateLimit) Performed() {
<-r.ch
}
func (r *timeRateLimit) Close() error {
close(r.done)
return nil
}
type unlimited struct{}
var unlimitedInstance unlimited
// Unlimited returns a Limiter that never blocks. The Run() and Close() calls are no-ops.
func Unlimited() Limiter {
return unlimitedInstance
}
func (r unlimited) Run() {}
func (r unlimited) Performed() {}
func (r unlimited) Close() error { return nil }

View file

@ -0,0 +1,40 @@
package rate
import (
"testing"
"time"
)
var exampleData = []string{}
func ExampleNewRateLimit() {
rl := NewRateLimit(100, 1*time.Minute)
go rl.Run()
defer rl.Close()
for _, v := range exampleData {
rl.Performed()
// do something with v
_ = v
}
}
func TestRateLimit(t *testing.T) {
rl := NewRateLimit(3, 100*time.Millisecond)
start := time.Now()
go rl.Run()
for i := 0; i < 4; i++ {
rl.Performed()
}
end := time.Now()
if end.Sub(start) < 100*time.Millisecond {
t.Error("ratelimiter did not wait for period to expire")
}
rl.Performed()
rl.Performed()
end2 := time.Now()
if end2.Sub(end) > 10*time.Millisecond {
t.Error("ratelimiter improperly waited when tokens were available")
}
rl.Close()
}

View file

@ -1,17 +1,16 @@
package server package server
// This is the scariest code I've written yet for the server.
// If I screwed up the locking, I won't know until it's too late.
import ( import (
"log" "log"
"sync" "sync"
"time" "time"
"github.com/FrankerFaceZ/FrankerFaceZ/socketserver/server/rate"
) )
type SubscriberList struct { type SubscriberList struct {
sync.RWMutex sync.RWMutex
Members []chan<- ClientMessage Members []*ClientInfo
} }
var ChatSubscriptionInfo map[string]*SubscriberList = make(map[string]*SubscriberList) var ChatSubscriptionInfo map[string]*SubscriberList = make(map[string]*SubscriberList)
@ -19,6 +18,18 @@ var ChatSubscriptionLock sync.RWMutex
var GlobalSubscriptionInfo []*ClientInfo var GlobalSubscriptionInfo []*ClientInfo
var GlobalSubscriptionLock sync.RWMutex var GlobalSubscriptionLock sync.RWMutex
func (client *ClientInfo) Send(msg ClientMessage) bool {
select {
case client.MessageChannel <- msg:
return true
case <-client.MsgChannelIsDone:
return false
default:
// if we can't immediately send, ignore it
return false
}
}
func CountSubscriptions(channels []string) int { func CountSubscriptions(channels []string) int {
ChatSubscriptionLock.RLock() ChatSubscriptionLock.RLock()
defer ChatSubscriptionLock.RUnlock() defer ChatSubscriptionLock.RUnlock()
@ -38,70 +49,77 @@ func CountSubscriptions(channels []string) int {
func SubscribeChannel(client *ClientInfo, channelName string) { func SubscribeChannel(client *ClientInfo, channelName string) {
ChatSubscriptionLock.RLock() ChatSubscriptionLock.RLock()
_subscribeWhileRlocked(channelName, client.MessageChannel) _subscribeWhileRlocked(channelName, client)
ChatSubscriptionLock.RUnlock() ChatSubscriptionLock.RUnlock()
} }
func SubscribeDefaults(client *ClientInfo) {
}
func SubscribeGlobal(client *ClientInfo) { func SubscribeGlobal(client *ClientInfo) {
GlobalSubscriptionLock.Lock() GlobalSubscriptionLock.Lock()
AddToSliceCl(&GlobalSubscriptionInfo, client) AddToSliceCl(&GlobalSubscriptionInfo, client)
GlobalSubscriptionLock.Unlock() GlobalSubscriptionLock.Unlock()
} }
func PublishToChannel(channel string, msg ClientMessage) (count int) { func PublishToChannel(channel string, msg ClientMessage, rl rate.Limiter) (count int) {
var found []*ClientInfo
ChatSubscriptionLock.RLock() ChatSubscriptionLock.RLock()
list := ChatSubscriptionInfo[channel] list := ChatSubscriptionInfo[channel]
if list != nil { if list != nil {
list.RLock() list.RLock()
for _, msgChan := range list.Members { found = make([]*ClientInfo, len(list.Members))
msgChan <- msg copy(found, list.Members)
count++
}
list.RUnlock() list.RUnlock()
} }
ChatSubscriptionLock.RUnlock() ChatSubscriptionLock.RUnlock()
for _, cl := range found {
rl.Performed()
if cl.Send(msg) {
count++
}
}
return return
} }
func PublishToMultiple(channels []string, msg ClientMessage) (count int) { func PublishToMultiple(channels []string, msg ClientMessage, rl rate.Limiter) (count int) {
found := make(map[chan<- ClientMessage]struct{}) var found []*ClientInfo
ChatSubscriptionLock.RLock() ChatSubscriptionLock.RLock()
for _, channel := range channels { for _, channel := range channels {
list := ChatSubscriptionInfo[channel] list := ChatSubscriptionInfo[channel]
if list != nil { if list != nil {
list.RLock() list.RLock()
for _, msgChan := range list.Members { for _, cl := range list.Members {
found[msgChan] = struct{}{} found = append(found, cl)
} }
list.RUnlock() list.RUnlock()
} }
} }
ChatSubscriptionLock.RUnlock() ChatSubscriptionLock.RUnlock()
for msgChan, _ := range found { for _, cl := range found {
msgChan <- msg rl.Performed()
count++ if cl.Send(msg) {
count++
}
} }
return return
} }
func PublishToAll(msg ClientMessage) (count int) { func PublishToAll(msg ClientMessage, rl rate.Limiter) (count int) {
var found []*ClientInfo
GlobalSubscriptionLock.RLock() GlobalSubscriptionLock.RLock()
for _, client := range GlobalSubscriptionInfo { found = make([]*ClientInfo, len(GlobalSubscriptionInfo))
select { copy(found, GlobalSubscriptionInfo)
case client.MessageChannel <- msg:
case <-client.MsgChannelIsDone:
}
count++
}
GlobalSubscriptionLock.RUnlock() GlobalSubscriptionLock.RUnlock()
for _, cl := range found {
rl.Performed()
if cl.Send(msg) {
count++
}
}
return return
} }
@ -110,7 +128,7 @@ func UnsubscribeSingleChat(client *ClientInfo, channelName string) {
list := ChatSubscriptionInfo[channelName] list := ChatSubscriptionInfo[channelName]
if list != nil { if list != nil {
list.Lock() list.Lock()
RemoveFromSliceC(&list.Members, client.MessageChannel) RemoveFromSliceCl(&list.Members, client)
list.Unlock() list.Unlock()
} }
ChatSubscriptionLock.RUnlock() ChatSubscriptionLock.RUnlock()
@ -138,7 +156,7 @@ func UnsubscribeAll(client *ClientInfo) {
list := ChatSubscriptionInfo[v] list := ChatSubscriptionInfo[v]
if list != nil { if list != nil {
list.Lock() list.Lock()
RemoveFromSliceC(&list.Members, client.MessageChannel) RemoveFromSliceCl(&list.Members, client)
list.Unlock() list.Unlock()
} }
} }
@ -191,14 +209,14 @@ func pubsubJanitor_do() {
// - ALREADY HOLDING a read-lock to the 'which' top-level map via the rlocker object // - ALREADY HOLDING a read-lock to the 'which' top-level map via the rlocker object
// - possible write lock to the 'which' top-level map via the wlocker object // - possible write lock to the 'which' top-level map via the wlocker object
// - write lock to SubscriptionInfo (if not creating new) // - write lock to SubscriptionInfo (if not creating new)
func _subscribeWhileRlocked(channelName string, value chan<- ClientMessage) { func _subscribeWhileRlocked(channelName string, value *ClientInfo) {
list := ChatSubscriptionInfo[channelName] list := ChatSubscriptionInfo[channelName]
if list == nil { if list == nil {
// Not found, so create it // Not found, so create it
ChatSubscriptionLock.RUnlock() ChatSubscriptionLock.RUnlock()
ChatSubscriptionLock.Lock() ChatSubscriptionLock.Lock()
list = &SubscriberList{} list = &SubscriberList{}
list.Members = []chan<- ClientMessage{value} // Create it populated, to avoid reaper list.Members = []*ClientInfo{value} // Create it populated, to avoid reaper
ChatSubscriptionInfo[channelName] = list ChatSubscriptionInfo[channelName] = list
ChatSubscriptionLock.Unlock() ChatSubscriptionLock.Unlock()
@ -212,7 +230,7 @@ func _subscribeWhileRlocked(channelName string, value chan<- ClientMessage) {
ChatSubscriptionLock.RLock() ChatSubscriptionLock.RLock()
} else { } else {
list.Lock() list.Lock()
AddToSliceC(&list.Members, value) AddToSliceCl(&list.Members, value)
list.Unlock() list.Unlock()
} }
} }

View file

@ -108,10 +108,10 @@ type ClientInfo struct {
// True if the client has already sent the 'ready' command // True if the client has already sent the 'ready' command
ReadyComplete bool ReadyComplete bool
// Server-initiated messages should be sent here // Server-initiated messages should be sent via the Send() method.
// This field will be nil before it is closed.
MessageChannel chan<- ClientMessage MessageChannel chan<- ClientMessage
// Closed when the client is shutting down.
MsgChannelIsDone <-chan struct{} MsgChannelIsDone <-chan struct{}
// Take out an Add() on this during a command if you need to use the MessageChannel later. // Take out an Add() on this during a command if you need to use the MessageChannel later.

View file

@ -130,38 +130,6 @@ func RemoveFromSliceS(ary *[]string, val string) bool {
return true return true
} }
func AddToSliceC(ary *[]chan<- ClientMessage, val chan<- ClientMessage) bool {
slice := *ary
for _, v := range slice {
if v == val {
return false
}
}
slice = append(slice, val)
*ary = slice
return true
}
func RemoveFromSliceC(ary *[]chan<- ClientMessage, val chan<- ClientMessage) bool {
slice := *ary
var idx int = -1
for i, v := range slice {
if v == val {
idx = i
break
}
}
if idx == -1 {
return false
}
slice[idx] = slice[len(slice)-1]
slice = slice[:len(slice)-1]
*ary = slice
return true
}
func AddToSliceCl(ary *[]*ClientInfo, val *ClientInfo) bool { func AddToSliceCl(ary *[]*ClientInfo, val *ClientInfo) bool {
slice := *ary slice := *ary
for _, v := range slice { for _, v := range slice {