From 0c9e7bb97d41f6e15136e36fec6f03b481fb368f Mon Sep 17 00:00:00 2001 From: Kane York Date: Sun, 25 Oct 2015 20:17:17 -0700 Subject: [PATCH] Add tests, fix some bugs uncovered by the tests --- socketserver/cmd/socketserver/socketserver.go | 2 +- socketserver/internal/server/backend.go | 15 +- socketserver/internal/server/backend_test.go | 13 +- socketserver/internal/server/commands.go | 27 +++- socketserver/internal/server/handlecore.go | 14 +- socketserver/internal/server/publisher.go | 79 ++++++++-- .../internal/server/publisher_test.go | 141 ++++++++++++++++++ 7 files changed, 261 insertions(+), 30 deletions(-) create mode 100644 socketserver/internal/server/publisher_test.go diff --git a/socketserver/cmd/socketserver/socketserver.go b/socketserver/cmd/socketserver/socketserver.go index 49db8872..463d8774 100644 --- a/socketserver/cmd/socketserver/socketserver.go +++ b/socketserver/cmd/socketserver/socketserver.go @@ -47,7 +47,7 @@ func main() { Addr: *bindAddress, } - server.SetupServerAndHandle(conf, httpServer.TLSConfig) + server.SetupServerAndHandle(conf, httpServer.TLSConfig, nil) var err error if conf.UseSSL { diff --git a/socketserver/internal/server/backend.go b/socketserver/internal/server/backend.go index af4668cf..6c95c261 100644 --- a/socketserver/internal/server/backend.go +++ b/socketserver/internal/server/backend.go @@ -86,7 +86,12 @@ func RequestRemoteData(remoteCommand, data string, auth AuthInfo) (responseStr s authKey: []string{auth.TwitchUsername}, } - resp, err := backendHttpClient.PostForm(destUrl, formData) + sealedForm, err := SealRequest(formData) + if err != nil { + return "", err + } + + resp, err := backendHttpClient.PostForm(destUrl, sealedForm) if err != nil { return "", err } @@ -117,7 +122,12 @@ func FetchBacklogData(chatSubs, channelSubs []string) ([]ClientMessage, error) { "channelSubs": channelSubs, } - resp, err := backendHttpClient.PostForm(getBacklogUrl, formData) + sealedForm, err := SealRequest(formData) + if err != nil { + return nil, err + } + + resp, err := backendHttpClient.PostForm(getBacklogUrl, sealedForm) if err != nil { return nil, err } @@ -152,7 +162,6 @@ func GenerateKeys(outputFile, serverId, theirPublicStr string) { if err != nil { log.Fatal(err) } - log.Print(theirPublic) output.TheirPublicKey = theirPublic } diff --git a/socketserver/internal/server/backend_test.go b/socketserver/internal/server/backend_test.go index 3a2d5a7b..e02d6c0e 100644 --- a/socketserver/internal/server/backend_test.go +++ b/socketserver/internal/server/backend_test.go @@ -6,29 +6,32 @@ import ( "crypto/rand" ) -func TestSealRequest(t *testing.T) { - senderPublic, senderPrivate, err := box.GenerateKey(rand.Reader) +func SetupRandomKeys(t testing.TB) { + _, senderPrivate, err := box.GenerateKey(rand.Reader) if err != nil { t.Fatal(err) } - receiverPublic, receiverPrivate, err := box.GenerateKey(rand.Reader) + receiverPublic, _, err := box.GenerateKey(rand.Reader) if err != nil { t.Fatal(err) } + box.Precompute(&backendSharedKey, receiverPublic, senderPrivate) messageBufferPool.New = New4KByteBuffer +} + +func TestSealRequest(t *testing.T) { + SetupRandomKeys(t) values := url.Values{ "QuickBrownFox": []string{"LazyDog"}, } - box.Precompute(&backendSharedKey, receiverPublic, senderPrivate) sealedValues, err := SealRequest(values) if err != nil { t.Fatal(err) } - box.Precompute(&backendSharedKey, senderPublic, receiverPrivate) unsealedValues, err := UnsealRequest(sealedValues) if err != nil { t.Fatal(err) diff --git a/socketserver/internal/server/commands.go b/socketserver/internal/server/commands.go index 52692254..e30b543f 100644 --- a/socketserver/internal/server/commands.go +++ b/socketserver/internal/server/commands.go @@ -23,7 +23,7 @@ func HandleCommand(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) return } - log.Println(conn.RemoteAddr(), msg.MessageID, msg.Command, msg.Arguments) +// log.Println(conn.RemoteAddr(), msg.MessageID, msg.Command, msg.Arguments) response, err := CallHandler(handler, conn, client, msg) @@ -76,6 +76,10 @@ func HandleSetUser(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) func HandleSub(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { channel, err := msg.ArgumentsAsString() + if err != nil { + return + } + client.Mutex.Lock() AddToSliceS(&client.CurrentChannels, channel) @@ -91,7 +95,7 @@ func HandleSub(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rms client.Mutex.Unlock() - // note - pub/sub updating happens in GetSubscriptionBacklog + SubscribeChat(client, channel) return ResponseSuccess, nil } @@ -99,6 +103,10 @@ func HandleSub(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rms func HandleUnsub(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { channel, err := msg.ArgumentsAsString() + if err != nil { + return + } + client.Mutex.Lock() RemoveFromSliceS(&client.CurrentChannels, channel) client.Mutex.Unlock() @@ -111,6 +119,10 @@ func HandleUnsub(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (r func HandleSubChannel(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { channel, err := msg.ArgumentsAsString() + if err != nil { + return + } + client.Mutex.Lock() AddToSliceS(&client.WatchingChannels, channel) @@ -126,7 +138,7 @@ func HandleSubChannel(conn *websocket.Conn, client *ClientInfo, msg ClientMessag client.Mutex.Unlock() - // note - pub/sub updating happens in GetSubscriptionBacklog + SubscribeWatching(client, channel) return ResponseSuccess, nil } @@ -134,6 +146,10 @@ func HandleSubChannel(conn *websocket.Conn, client *ClientInfo, msg ClientMessag func HandleUnsubChannel(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { channel, err := msg.ArgumentsAsString() + if err != nil { + return + } + client.Mutex.Lock() RemoveFromSliceS(&client.WatchingChannels, channel) client.Mutex.Unlock() @@ -166,8 +182,9 @@ func GetSubscriptionBacklog(conn *websocket.Conn, client *ClientInfo) { return } - SubscribeBatch(client, chatSubs, channelSubs) - + if backendUrl == "" { + return // for testing runs + } messages, err := FetchBacklogData(chatSubs, channelSubs) if err != nil { diff --git a/socketserver/internal/server/handlecore.go b/socketserver/internal/server/handlecore.go index b6161db7..7ef4e533 100644 --- a/socketserver/internal/server/handlecore.go +++ b/socketserver/internal/server/handlecore.go @@ -74,6 +74,7 @@ var FFZCodec websocket.Codec = websocket.Codec{ // Errors that get returned to the client. var ProtocolError error = errors.New("FFZ Socket protocol error.") +var ProtocolErrorNegativeID error = errors.New("FFZ Socket protocol error: negative or zero message ID.") var ExpectedSingleString = errors.New("Error: Expected single string as arguments.") var ExpectedSingleInt = errors.New("Error: Expected single integer as arguments.") var ExpectedTwoStrings = errors.New("Error: Expected array of string, string as arguments.") @@ -116,11 +117,14 @@ func setupServer(config *Config, tlsConfig *tls.Config) *websocket.Server { // Set up a websocket listener and register it on /. // (Uses http.DefaultServeMux .) -func SetupServerAndHandle(config *Config, tlsConfig *tls.Config) { +func SetupServerAndHandle(config *Config, tlsConfig *tls.Config, serveMux *http.ServeMux) { sockServer := setupServer(config, tlsConfig) - http.HandleFunc("/", sockServer.ServeHTTP) - http.HandleFunc("/pub", HandlePublishRequest) + if serveMux == nil { + serveMux = http.DefaultServeMux + } + serveMux.HandleFunc("/", sockServer.ServeHTTP) + serveMux.HandleFunc("/pub", HandlePublishRequest) } // Handle a new websocket connection from a FFZ client. @@ -235,8 +239,8 @@ func UnmarshalClientMessage(data []byte, payloadType byte, v interface{}) (err e return ProtocolError } messageId, err := strconv.Atoi(dataStr[:spaceIdx]) - if messageId <= 0 { - return ProtocolError + if messageId < -1 || messageId == 0 { + return ProtocolErrorNegativeID } out.MessageID = messageId diff --git a/socketserver/internal/server/publisher.go b/socketserver/internal/server/publisher.go index 61d521d8..eae5488e 100644 --- a/socketserver/internal/server/publisher.go +++ b/socketserver/internal/server/publisher.go @@ -7,6 +7,7 @@ import ( "sync" "time" "net/http" + "fmt" ) type SubscriberList struct { @@ -14,39 +15,65 @@ type SubscriberList struct { Members []chan <- ClientMessage } -var ChatSubscriptionInfo map[string]*SubscriberList +var ChatSubscriptionInfo map[string]*SubscriberList = make(map[string]*SubscriberList) var ChatSubscriptionLock sync.RWMutex -var WatchingSubscriptionInfo map[string]*SubscriberList +var WatchingSubscriptionInfo map[string]*SubscriberList = make(map[string]*SubscriberList) var WatchingSubscriptionLock sync.RWMutex -func PublishToChat(channel string, msg ClientMessage) { +func PublishToChat(channel string, msg ClientMessage) (count int) { ChatSubscriptionLock.RLock() list := ChatSubscriptionInfo[channel] if list != nil { list.RLock() for _, ch := range list.Members { ch <- msg + count++ } list.RUnlock() } ChatSubscriptionLock.RUnlock() + return } -func PublishToWatchers(channel string, msg ClientMessage) { +func PublishToWatchers(channel string, msg ClientMessage) (count int) { WatchingSubscriptionLock.RLock() list := WatchingSubscriptionInfo[channel] if list != nil { list.RLock() for _, ch := range list.Members { ch <- msg + count++ } list.RUnlock() } WatchingSubscriptionLock.RUnlock() + return } func HandlePublishRequest(w http.ResponseWriter, r *http.Request) { - // TODO - box.Open() + formData, err := UnsealRequest(r.Form) + if err != nil { + w.WriteHeader(403) + fmt.Fprintf(w, "Error: %v", err) + return + } + + cmd := formData.Get("cmd") + json := formData.Get("args") + chat := formData.Get("chat") + watchChannel := formData.Get("channel") + cm := ClientMessage{MessageID: -1, Command: Command(cmd), origArguments: json} + var count int + if chat != "" { + count = PublishToChat(chat, cm) + } else if watchChannel != "" { + count = PublishToWatchers(watchChannel, cm) + } else { + w.WriteHeader(400) + fmt.Fprint(w, "Need to specify either chat or channel") + return + } + fmt.Fprint(w, count) } // Add a channel to the subscriptions while holding a read-lock to the map. @@ -72,6 +99,18 @@ func _subscribeWhileRlocked(which map[string]*SubscriberList, channelName string } } +func SubscribeChat(client *ClientInfo, channelName string) { + ChatSubscriptionLock.RLock() + _subscribeWhileRlocked(ChatSubscriptionInfo, channelName, client.MessageChannel, ChatSubscriptionLock.RLocker(), &ChatSubscriptionLock) + ChatSubscriptionLock.RUnlock() +} + +func SubscribeWatching(client *ClientInfo, channelName string) { + WatchingSubscriptionLock.RLock() + _subscribeWhileRlocked(WatchingSubscriptionInfo, channelName, client.MessageChannel, WatchingSubscriptionLock.RLocker(), &WatchingSubscriptionLock) + WatchingSubscriptionLock.RUnlock() +} + // Locks: // - read lock to top-level maps // - possible write lock to top-level maps @@ -102,13 +141,20 @@ func SubscribeBatch(client *ClientInfo, chatSubs, channelSubs []string) { // - write lock to SubscriptionInfos // - write lock to ClientInfo func UnsubscribeAll(client *ClientInfo) { + client.Mutex.Lock() + client.PendingChatBacklogs = nil + client.PendingStreamBacklogs = nil + client.Mutex.Unlock() + ChatSubscriptionLock.RLock() client.Mutex.Lock() for _, v := range client.CurrentChannels { list := ChatSubscriptionInfo[v] - list.Lock() - RemoveFromSliceC(&list.Members, client.MessageChannel) - list.Unlock() + if list != nil { + list.Lock() + RemoveFromSliceC(&list.Members, client.MessageChannel) + list.Unlock() + } } client.CurrentChannels = nil client.Mutex.Unlock() @@ -118,15 +164,26 @@ func UnsubscribeAll(client *ClientInfo) { client.Mutex.Lock() for _, v := range client.WatchingChannels { list := WatchingSubscriptionInfo[v] - list.Lock() - RemoveFromSliceC(&list.Members, client.MessageChannel) - list.Unlock() + if list != nil { + list.Lock() + RemoveFromSliceC(&list.Members, client.MessageChannel) + list.Unlock() + } } client.WatchingChannels = nil client.Mutex.Unlock() WatchingSubscriptionLock.RUnlock() } +func unsubscribeAllClients() { + ChatSubscriptionLock.Lock() + ChatSubscriptionInfo = make(map[string]*SubscriberList) + ChatSubscriptionLock.Unlock() + WatchingSubscriptionLock.Lock() + WatchingSubscriptionInfo = make(map[string]*SubscriberList) + WatchingSubscriptionLock.Unlock() +} + func UnsubscribeSingleChat(client *ClientInfo, channelName string) { ChatSubscriptionLock.RLock() list := ChatSubscriptionInfo[channelName] diff --git a/socketserver/internal/server/publisher_test.go b/socketserver/internal/server/publisher_test.go new file mode 100644 index 00000000..ec4855c3 --- /dev/null +++ b/socketserver/internal/server/publisher_test.go @@ -0,0 +1,141 @@ +package server +import ( + "testing" + "net/http/httptest" + "net/http" + "sync" + "golang.org/x/net/websocket" + "github.com/satori/go.uuid" + "fmt" + "syscall" + "os" + "io/ioutil" +) + +func CountOpenFDs() uint64 { + ary, _ := ioutil.ReadDir(fmt.Sprintf("/proc/%d/fd", os.Getpid())) + return uint64(len(ary)) +} + +func BenchmarkThousandUserSubscription(b *testing.B) { + var doneWg sync.WaitGroup + var readyWg sync.WaitGroup + + const TestChannelName = "testchannel" + const TestCommand = "testdata" + + GenerateKeys("/tmp/test_naclkeys.json", "2", "+ZMqOmxhaVrCV5c0OMZ09QoSGcJHuqQtJrwzRD+JOjE=") + conf := &Config{ + UseSSL: false, + NaclKeysFile: "/tmp/test_naclkeys.json", + SocketOrigin: "localhost:2002", + } + serveMux := http.NewServeMux() + SetupServerAndHandle(conf, nil, serveMux) + + server := httptest.NewUnstartedServer(serveMux) + server.Start() + + wsUrl := fmt.Sprintf("ws://%s/", server.Listener.Addr().String()) + originUrl := fmt.Sprintf("http://%s", server.Listener.Addr().String()) + + message := ClientMessage{MessageID: -1, Command: "testdata", Arguments: "123456789"} + + fmt.Println() + fmt.Println(b.N) + + var limit syscall.Rlimit + syscall.Getrlimit(syscall.RLIMIT_NOFILE, &limit) + + limit.Cur = CountOpenFDs() + uint64(b.N) * 2 + 100 + + if limit.Cur > limit.Max { + b.Skip("Open file limit too low") + return + } + + syscall.Setrlimit(syscall.RLIMIT_NOFILE, &limit) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + conn, err := websocket.Dial(wsUrl, "", originUrl) + if err != nil { + b.Error(err) + break + } + doneWg.Add(1) + readyWg.Add(1) + go func(i int, conn *websocket.Conn) { + var err error + var msg ClientMessage + err = FFZCodec.Send(conn, ClientMessage{MessageID: 1, Command: HelloCommand, Arguments: []interface{}{"ffz_test", uuid.NewV4().String()}}) + if err != nil { + b.Error(err) + } + err = FFZCodec.Send(conn, ClientMessage{MessageID: 2, Command: "sub", Arguments: TestChannelName}) + if err != nil { + b.Error(err) + } + err = FFZCodec.Receive(conn, &msg) + if err != nil { + b.Error(err) + } + if msg.MessageID != 1 { + b.Error("Got out-of-order message ID", msg) + } + if msg.Command != SuccessCommand { + b.Error("Command was not a success", msg) + } + err = FFZCodec.Receive(conn, &msg) + if err != nil { + b.Error(err) + } + if msg.MessageID != 2 { + b.Error("Got out-of-order message ID", msg) + } + if msg.Command != SuccessCommand { + b.Error("Command was not a success", msg) + } + + fmt.Println(i, " ready") + readyWg.Done() + + err = FFZCodec.Receive(conn, &msg) + if err != nil { + b.Error(err) + } + if msg.MessageID != -1 { + fmt.Println(msg) + b.Error("Client did not get expected messageID of -1") + } + if msg.Command != TestCommand { + fmt.Println(msg) + b.Error("Client did not get expected command") + } + str, err := msg.ArgumentsAsString() + if err != nil { + b.Error(err) + } + if str != "123456789" { + fmt.Println(msg) + b.Error("Client did not get expected data") + } + conn.Close() + doneWg.Done() + }(i, conn) + } + + readyWg.Wait() + + fmt.Println("publishing...") + if PublishToChat(TestChannelName, message) != b.N { + b.Error("not enough sent") + b.FailNow() + } + doneWg.Wait() + + b.StopTimer() + server.Close() + unsubscribeAllClients() + server.CloseClientConnections() +}