diff --git a/.gitignore b/.gitignore index 96f9db49..a524bdc7 100644 --- a/.gitignore +++ b/.gitignore @@ -3,5 +3,8 @@ npm-debug.log build Extension Building .idea +*.iml script.js -script.min.js \ No newline at end of file +script.min.js + +/socketserver/cmd/socketserver/socketserver diff --git a/socketserver/cmd/ffzsocketserver/socketserver.go b/socketserver/cmd/ffzsocketserver/socketserver.go new file mode 100644 index 00000000..5de7a059 --- /dev/null +++ b/socketserver/cmd/ffzsocketserver/socketserver.go @@ -0,0 +1,71 @@ +package main // import "bitbucket.org/stendec/frankerfacez/socketserver/cmd/socketserver" + +import ( + "../../internal/server" + "encoding/json" + "flag" + "fmt" + "io/ioutil" + "log" + "net/http" + "os" +) + +var configFilename *string = flag.String("config", "config.json", "Configuration file, including the keypairs for the NaCl crypto library, for communicating with the backend.") +var generateKeys *bool = flag.Bool("genkeys", false, "Generate NaCl keys instead of serving requests.\nArguments: [int serverId] [base64 backendPublic]\nThe backend public key can either be specified in base64 on the command line, or put in the json file later.") + +func main() { + flag.Parse() + + if *generateKeys { + GenerateKeys(*configFilename) + return + } + + confFile, err := os.Open(*configFilename) + if os.IsNotExist(err) { + fmt.Println("Error: No config file. Run with -genkeys and edit config.json") + os.Exit(3) + } + if err != nil { + log.Fatal(err) + } + conf := &server.ConfigFile{} + confBytes, err := ioutil.ReadAll(confFile) + if err != nil { + log.Fatal(err) + } + err = json.Unmarshal(confBytes, &conf) + if err != nil { + log.Fatal(err) + } + + httpServer := &http.Server{ + Addr: conf.ListenAddr, + } + + server.SetupServerAndHandle(conf, httpServer.TLSConfig, nil) + + if conf.UseSSL { + err = httpServer.ListenAndServeTLS(conf.SSLCertificateFile, conf.SSLKeyFile) + } else { + err = httpServer.ListenAndServe() + } + + if err != nil { + log.Fatal("ListenAndServe: ", err) + } +} + +func GenerateKeys(outputFile string) { + if flag.NArg() < 1 { + fmt.Println("Specify a numeric server ID after -genkeys") + os.Exit(2) + } + if flag.NArg() >= 2 { + server.GenerateKeys(outputFile, flag.Arg(0), flag.Arg(1)) + } else { + server.GenerateKeys(outputFile, flag.Arg(0), "") + } + fmt.Println("Keys generated. Now edit config.json") +} diff --git a/socketserver/internal/server/backend.go b/socketserver/internal/server/backend.go new file mode 100644 index 00000000..72d74b22 --- /dev/null +++ b/socketserver/internal/server/backend.go @@ -0,0 +1,234 @@ +package server + +import ( + "crypto/rand" + "encoding/base64" + "encoding/json" + "fmt" + "github.com/pmylund/go-cache" + "golang.org/x/crypto/nacl/box" + "io/ioutil" + "log" + "net/http" + "net/url" + "strconv" + "strings" + "sync" + "time" +) + +var backendHttpClient http.Client +var backendUrl string +var responseCache *cache.Cache + +var getBacklogUrl string + +var backendSharedKey [32]byte +var serverId int + +var messageBufferPool sync.Pool + +func SetupBackend(config *ConfigFile) { + backendHttpClient.Timeout = 60 * time.Second + backendUrl = config.BackendUrl + if responseCache != nil { + responseCache.Flush() + } + responseCache = cache.New(60*time.Second, 120*time.Second) + + getBacklogUrl = fmt.Sprintf("%s/backlog", backendUrl) + + messageBufferPool.New = New4KByteBuffer + + var theirPublic, ourPrivate [32]byte + copy(theirPublic[:], config.BackendPublicKey) + copy(ourPrivate[:], config.OurPrivateKey) + serverId = config.ServerId + + box.Precompute(&backendSharedKey, &theirPublic, &ourPrivate) +} + +func getCacheKey(remoteCommand, data string) string { + return fmt.Sprintf("%s/%s", remoteCommand, data) +} + +// Publish a message to clients with no caching. +// The scope must be specified because no attempt is made to recognize the command. +func HBackendPublishRequest(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + formData, err := UnsealRequest(r.Form) + if err != nil { + w.WriteHeader(403) + fmt.Fprintf(w, "Error: %v", err) + return + } + + cmd := formData.Get("cmd") + json := formData.Get("args") + channel := formData.Get("channel") + scope := formData.Get("scope") + + target := MessageTargetTypeByName(scope) + + if cmd == "" { + w.WriteHeader(422) + fmt.Fprintf(w, "Error: cmd cannot be blank") + return + } + if channel == "" && (target == MsgTargetTypeChat || target == MsgTargetTypeMultichat) { + w.WriteHeader(422) + fmt.Fprintf(w, "Error: channel must be specified") + return + } + + cm := ClientMessage{MessageID: -1, Command: Command(cmd), origArguments: json} + cm.parseOrigArguments() + var count int + + switch target { + case MsgTargetTypeSingle: + // TODO + case MsgTargetTypeChat: + count = PublishToChat(channel, cm) + case MsgTargetTypeMultichat: + // TODO + case MsgTargetTypeGlobal: + count = PublishToAll(cm) + case MsgTargetTypeInvalid: + default: + w.WriteHeader(422) + fmt.Fprint(w, "Invalid 'scope'. must be single, chat, multichat, channel, or global") + return + } + fmt.Fprint(w, count) +} + +func RequestRemoteDataCached(remoteCommand, data string, auth AuthInfo) (string, error) { + cached, ok := responseCache.Get(getCacheKey(remoteCommand, data)) + if ok { + return cached.(string), nil + } + return RequestRemoteData(remoteCommand, data, auth) +} + +func RequestRemoteData(remoteCommand, data string, auth AuthInfo) (responseStr string, err error) { + destUrl := fmt.Sprintf("%s/cmd/%s", backendUrl, remoteCommand) + var authKey string + if auth.UsernameValidated { + authKey = "usernameClaimed" + } else { + authKey = "username" + } + + formData := url.Values{ + "clientData": []string{data}, + authKey: []string{auth.TwitchUsername}, + } + + sealedForm, err := SealRequest(formData) + if err != nil { + return "", err + } + + resp, err := backendHttpClient.PostForm(destUrl, sealedForm) + if err != nil { + return "", err + } + + respBytes, err := ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return "", err + } + + responseStr = string(respBytes) + + if resp.Header.Get("FFZ-Cache") != "" { + durSecs, err := strconv.ParseInt(resp.Header.Get("FFZ-Cache"), 10, 64) + if err != nil { + return "", fmt.Errorf("The RPC server returned a non-integer cache duration: %v", err) + } + duration := time.Duration(durSecs) * time.Second + responseCache.Set(getCacheKey(remoteCommand, data), responseStr, duration) + } + + return +} + +func FetchBacklogData(chatSubs []string) ([]ClientMessage, error) { + formData := url.Values{ + "subs": chatSubs, + } + + sealedForm, err := SealRequest(formData) + if err != nil { + return nil, err + } + + resp, err := backendHttpClient.PostForm(getBacklogUrl, sealedForm) + if err != nil { + return nil, err + } + dec := json.NewDecoder(resp.Body) + var messages []ClientMessage + err = dec.Decode(messages) + if err != nil { + return nil, err + } + + return messages, nil +} + +func GenerateKeys(outputFile, serverId, theirPublicStr string) { + var err error + output := ConfigFile{ + ListenAddr: "0.0.0.0:8001", + SocketOrigin: "localhost:8001", + BackendUrl: "http://localhost:8002/ffz", + BannerHTML: ` + +CatBag + +
+
+
+
+
+
+ A FrankerFaceZ Service + — CatBag by Wolsk +
+
+`, + } + + output.ServerId, err = strconv.Atoi(serverId) + if err != nil { + log.Fatal(err) + } + + ourPublic, ourPrivate, err := box.GenerateKey(rand.Reader) + if err != nil { + log.Fatal(err) + } + output.OurPublicKey, output.OurPrivateKey = ourPublic[:], ourPrivate[:] + + if theirPublicStr != "" { + reader := base64.NewDecoder(base64.StdEncoding, strings.NewReader(theirPublicStr)) + theirPublic, err := ioutil.ReadAll(reader) + if err != nil { + log.Fatal(err) + } + output.BackendPublicKey = theirPublic + } + + bytes, err := json.MarshalIndent(output, "", "\t") + if err != nil { + log.Fatal(err) + } + fmt.Println(string(bytes)) + err = ioutil.WriteFile(outputFile, bytes, 0600) + if err != nil { + log.Fatal(err) + } +} diff --git a/socketserver/internal/server/backend_test.go b/socketserver/internal/server/backend_test.go new file mode 100644 index 00000000..7043d9f3 --- /dev/null +++ b/socketserver/internal/server/backend_test.go @@ -0,0 +1,46 @@ +package server + +import ( + "crypto/rand" + "golang.org/x/crypto/nacl/box" + "net/url" + "testing" +) + +func SetupRandomKeys(t testing.TB) { + _, senderPrivate, err := box.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + receiverPublic, _, err := box.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + + box.Precompute(&backendSharedKey, receiverPublic, senderPrivate) + messageBufferPool.New = New4KByteBuffer +} + +func TestSealRequest(t *testing.T) { + SetupRandomKeys(t) + + values := url.Values{ + "QuickBrownFox": []string{"LazyDog"}, + } + + sealedValues, err := SealRequest(values) + if err != nil { + t.Fatal(err) + } + // sealedValues.Encode() + // id=0&msg=KKtbng49dOLLyjeuX5AnXiEe6P0uZwgeP_7mMB5vhP-wMAAPZw%3D%3D&nonce=-wRbUnifscisWUvhm3gBEXHN5QzrfzgV + + unsealedValues, err := UnsealRequest(sealedValues) + if err != nil { + t.Fatal(err) + } + + if unsealedValues.Get("QuickBrownFox") != "LazyDog" { + t.Errorf("Failed to round-trip, got back %v", unsealedValues) + } +} diff --git a/socketserver/internal/server/backlog.go b/socketserver/internal/server/backlog.go new file mode 100644 index 00000000..66f09e93 --- /dev/null +++ b/socketserver/internal/server/backlog.go @@ -0,0 +1,364 @@ +package server + +import ( + "errors" + "fmt" + "net/http" + "sort" + "strconv" + "strings" + "sync" + "time" +) + +type PushCommandCacheInfo struct { + Caching BacklogCacheType + Target MessageTargetType +} + +// this value is just docs right now +var ServerInitiatedCommands = map[Command]PushCommandCacheInfo{ + /// Global updates & notices + "update_news": {CacheTypeTimestamps, MsgTargetTypeGlobal}, // timecache:global + "message": {CacheTypeTimestamps, MsgTargetTypeGlobal}, // timecache:global + "reload_ff": {CacheTypeTimestamps, MsgTargetTypeGlobal}, // timecache:global + + /// Emote updates + "reload_badges": {CacheTypeTimestamps, MsgTargetTypeGlobal}, // timecache:global + "set_badge": {CacheTypeTimestamps, MsgTargetTypeMultichat}, // timecache:multichat + "reload_set": {}, // timecache:multichat + "load_set": {}, // TODO what are the semantics of this? + + /// User auth + "do_authorize": {CacheTypeNever, MsgTargetTypeSingle}, // nocache:single + + /// Channel data + // follow_sets: extra emote sets included in the chat + // follow_buttons: extra follow buttons below the stream + "follow_sets": {CacheTypePersistent, MsgTargetTypeChat}, // mustcache:chat + "follow_buttons": {CacheTypePersistent, MsgTargetTypeChat}, // mustcache:watching + "srl_race": {CacheTypeLastOnly, MsgTargetTypeChat}, // cachelast:watching + + /// Chatter/viewer counts + "chatters": {CacheTypeLastOnly, MsgTargetTypeChat}, // cachelast:watching + "viewers": {CacheTypeLastOnly, MsgTargetTypeChat}, // cachelast:watching +} + +type BacklogCacheType int + +const ( + // This is not a cache type. + CacheTypeInvalid BacklogCacheType = iota + // This message cannot be cached. + CacheTypeNever + // Save the last 24 hours of this message. + // If a client indicates that it has reconnected, replay the messages sent after the disconnect. + // Do not replay if the client indicates that this is a firstload. + CacheTypeTimestamps + // Save only the last copy of this message, and always send it when the backlog is requested. + CacheTypeLastOnly + // Save this backlog data to disk with its timestamp. + // Send it when the backlog is requested, or after a reconnect if it was updated. + CacheTypePersistent +) + +type MessageTargetType int + +const ( + // This is not a message target. + MsgTargetTypeInvalid MessageTargetType = iota + // This message is targeted to a single TODO(user or connection) + MsgTargetTypeSingle + // This message is targeted to all users in a chat + MsgTargetTypeChat + // This message is targeted to all users in multiple chats + MsgTargetTypeMultichat + // This message is sent to all FFZ users. + MsgTargetTypeGlobal +) + +// note: see types.go for methods on these + +// Returned by BacklogCacheType.UnmarshalJSON() +var ErrorUnrecognizedCacheType = errors.New("Invalid value for cachetype") + +// Returned by MessageTargetType.UnmarshalJSON() +var ErrorUnrecognizedTargetType = errors.New("Invalid value for message target") + +type TimestampedGlobalMessage struct { + Timestamp time.Time + Command Command + Data string +} + +type TimestampedMultichatMessage struct { + Timestamp time.Time + Channels []string + Command Command + Data string +} + +type LastSavedMessage struct { + Timestamp time.Time + Data string +} + +// map is command -> channel -> data + +// CacheTypeLastOnly. Cleaned up by reaper goroutine every ~hour. +var CachedLastMessages map[Command]map[string]LastSavedMessage +var CachedLSMLock sync.RWMutex + +// CacheTypePersistent. Never cleaned. +var PersistentLastMessages map[Command]map[string]LastSavedMessage +var PersistentLSMLock sync.RWMutex + +var CachedGlobalMessages []TimestampedGlobalMessage +var CachedChannelMessages []TimestampedMultichatMessage +var CacheListsLock sync.RWMutex + +func DumpCache() { + CachedLSMLock.Lock() + CachedLastMessages = make(map[Command]map[string]LastSavedMessage) + CachedLSMLock.Unlock() + + PersistentLSMLock.Lock() + PersistentLastMessages = make(map[Command]map[string]LastSavedMessage) + // TODO delete file? + PersistentLSMLock.Unlock() + + CacheListsLock.Lock() + CachedGlobalMessages = make(tgmarray, 0) + CachedChannelMessages = make(tmmarray, 0) + CacheListsLock.Unlock() +} + +func SendBacklogForNewClient(client *ClientInfo) { + client.Mutex.Lock() // reading CurrentChannels + PersistentLSMLock.RLock() + for _, cmd := range GetCommandsOfType(PushCommandCacheInfo{CacheTypePersistent, MsgTargetTypeChat}) { + chanMap := CachedLastMessages[cmd] + if chanMap == nil { + continue + } + for _, channel := range client.CurrentChannels { + msg, ok := chanMap[channel] + if ok { + msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data} + msg.parseOrigArguments() + client.MessageChannel <- msg + } + } + } + PersistentLSMLock.RUnlock() + + CachedLSMLock.RLock() + for _, cmd := range GetCommandsOfType(PushCommandCacheInfo{CacheTypeLastOnly, MsgTargetTypeChat}) { + chanMap := CachedLastMessages[cmd] + if chanMap == nil { + continue + } + for _, channel := range client.CurrentChannels { + msg, ok := chanMap[channel] + if ok { + msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data} + msg.parseOrigArguments() + client.MessageChannel <- msg + } + } + } + CachedLSMLock.RUnlock() + client.Mutex.Unlock() +} + +func SendTimedBacklogMessages(client *ClientInfo, disconnectTime time.Time) { + client.Mutex.Lock() // reading CurrentChannels + CacheListsLock.RLock() + + globIdx := FindFirstNewMessage(tgmarray(CachedGlobalMessages), disconnectTime) + + for i := globIdx; i < len(CachedGlobalMessages); i++ { + item := CachedGlobalMessages[i] + msg := ClientMessage{MessageID: -1, Command: item.Command, origArguments: item.Data} + msg.parseOrigArguments() + client.MessageChannel <- msg + } + + chanIdx := FindFirstNewMessage(tmmarray(CachedChannelMessages), disconnectTime) + + for i := chanIdx; i < len(CachedChannelMessages); i++ { + item := CachedChannelMessages[i] + var send bool + for _, channel := range item.Channels { + for _, matchChannel := range client.CurrentChannels { + if channel == matchChannel { + send = true + break + } + } + if send { + break + } + } + if send { + msg := ClientMessage{MessageID: -1, Command: item.Command, origArguments: item.Data} + msg.parseOrigArguments() + client.MessageChannel <- msg + } + } + + CacheListsLock.RUnlock() + client.Mutex.Unlock() +} + +func InsertionSort(ary sort.Interface) { + for i := 1; i < ary.Len(); i++ { + for j := i; j > 0 && ary.Less(j, j-1); j-- { + ary.Swap(j, j-1) + } + } +} + +type TimestampArray interface { + Len() int + GetTime(int) time.Time +} + +func FindFirstNewMessage(ary TimestampArray, disconnectTime time.Time) (idx int) { + // TODO needs tests + len := ary.Len() + i := len + + // Walk backwards until we find GetTime() before disconnectTime + step := 1 + for i > 0 { + i -= step + if i < 0 { + i = 0 + } + if !ary.GetTime(i).After(disconnectTime) { + break + } + step = int(float64(step)*1.5) + 1 + } + + // Walk forwards until we find GetTime() after disconnectTime + for i < len && !ary.GetTime(i).After(disconnectTime) { + i++ + } + + if i == len { + return -1 + } + return i +} + +func SaveLastMessage(which map[Command]map[string]LastSavedMessage, locker sync.Locker, cmd Command, channel string, timestamp time.Time, data string, deleting bool) { + locker.Lock() + defer locker.Unlock() + + chanMap, ok := CachedLastMessages[cmd] + if !ok { + if deleting { + return + } + chanMap = make(map[string]LastSavedMessage) + CachedLastMessages[cmd] = chanMap + } + + if deleting { + delete(chanMap, channel) + } else { + chanMap[channel] = LastSavedMessage{timestamp, data} + } +} + +func SaveGlobalMessage(cmd Command, timestamp time.Time, data string) { + CacheListsLock.Lock() + CachedGlobalMessages = append(CachedGlobalMessages, TimestampedGlobalMessage{timestamp, cmd, data}) + InsertionSort(tgmarray(CachedGlobalMessages)) + CacheListsLock.Unlock() +} + +func SaveMultichanMessage(cmd Command, channels string, timestamp time.Time, data string) { + CacheListsLock.Lock() + CachedChannelMessages = append(CachedChannelMessages, TimestampedMultichatMessage{timestamp, strings.Split(channels, ","), cmd, data}) + InsertionSort(tmmarray(CachedChannelMessages)) + CacheListsLock.Unlock() +} + +func GetCommandsOfType(match PushCommandCacheInfo) []Command { + var ret []Command + for cmd, info := range ServerInitiatedCommands { + if info == match { + ret = append(ret, cmd) + } + } + return ret +} + +func HBackendDumpBacklog(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + formData, err := UnsealRequest(r.Form) + if err != nil { + w.WriteHeader(403) + fmt.Fprintf(w, "Error: %v", err) + return + } + + confirm := formData.Get("confirm") + if confirm == "1" { + DumpCache() + } +} + +// Publish a message to clients, and update the in-server cache for the message. +// notes: +// `scope` is implicit in the command +func HBackendUpdateAndPublish(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + formData, err := UnsealRequest(r.Form) + if err != nil { + w.WriteHeader(403) + fmt.Fprintf(w, "Error: %v", err) + return + } + + cmd := Command(formData.Get("cmd")) + json := formData.Get("args") + channel := formData.Get("channel") + deleteMode := formData.Get("delete") != "" + timeStr := formData.Get("time") + timestamp, err := time.Parse(time.UnixDate, timeStr) + if err != nil { + w.WriteHeader(422) + fmt.Fprintf(w, "error parsing time: %v", err) + } + + cacheinfo, ok := ServerInitiatedCommands[cmd] + if !ok { + w.WriteHeader(422) + fmt.Fprintf(w, "Caching semantics unknown for command '%s'. Post to /addcachedcommand first.") + return + } + + var count int + msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: json} + msg.parseOrigArguments() + + if cacheinfo.Caching == CacheTypeLastOnly && cacheinfo.Target == MsgTargetTypeChat { + SaveLastMessage(CachedLastMessages, &CachedLSMLock, cmd, channel, timestamp, json, deleteMode) + count = PublishToChat(channel, msg) + } else if cacheinfo.Caching == CacheTypePersistent && cacheinfo.Target == MsgTargetTypeChat { + SaveLastMessage(PersistentLastMessages, &PersistentLSMLock, cmd, channel, timestamp, json, deleteMode) + count = PublishToChat(channel, msg) + } else if cacheinfo.Caching == CacheTypeTimestamps && cacheinfo.Target == MsgTargetTypeMultichat { + SaveMultichanMessage(cmd, channel, timestamp, json) + count = PublishToMultiple(strings.Split(channel, ","), msg) + } else if cacheinfo.Caching == CacheTypeTimestamps && cacheinfo.Target == MsgTargetTypeGlobal { + SaveGlobalMessage(cmd, timestamp, json) + count = PublishToAll(msg) + } + + w.Write([]byte(strconv.Itoa(count))) +} diff --git a/socketserver/internal/server/backlog_test.go b/socketserver/internal/server/backlog_test.go new file mode 100644 index 00000000..68757587 --- /dev/null +++ b/socketserver/internal/server/backlog_test.go @@ -0,0 +1,76 @@ +package server + +import ( + "testing" + "time" +) + +func TestFindFirstNewMessageEmpty(t *testing.T) { + CachedGlobalMessages = []TimestampedGlobalMessage{} + i := FindFirstNewMessage(tgmarray(CachedGlobalMessages), time.Unix(10, 0)) + if i != -1 { + t.Errorf("Expected -1, got %d", i) + } +} +func TestFindFirstNewMessageOneBefore(t *testing.T) { + CachedGlobalMessages = []TimestampedGlobalMessage{ + {Timestamp: time.Unix(8, 0)}, + } + i := FindFirstNewMessage(tgmarray(CachedGlobalMessages), time.Unix(10, 0)) + if i != -1 { + t.Errorf("Expected -1, got %d", i) + } +} +func TestFindFirstNewMessageSeveralBefore(t *testing.T) { + CachedGlobalMessages = []TimestampedGlobalMessage{ + {Timestamp: time.Unix(1, 0)}, + {Timestamp: time.Unix(2, 0)}, + {Timestamp: time.Unix(3, 0)}, + {Timestamp: time.Unix(4, 0)}, + {Timestamp: time.Unix(5, 0)}, + } + i := FindFirstNewMessage(tgmarray(CachedGlobalMessages), time.Unix(10, 0)) + if i != -1 { + t.Errorf("Expected -1, got %d", i) + } +} +func TestFindFirstNewMessageInMiddle(t *testing.T) { + CachedGlobalMessages = []TimestampedGlobalMessage{ + {Timestamp: time.Unix(1, 0)}, + {Timestamp: time.Unix(2, 0)}, + {Timestamp: time.Unix(3, 0)}, + {Timestamp: time.Unix(4, 0)}, + {Timestamp: time.Unix(5, 0)}, + {Timestamp: time.Unix(11, 0)}, + {Timestamp: time.Unix(12, 0)}, + {Timestamp: time.Unix(13, 0)}, + {Timestamp: time.Unix(14, 0)}, + {Timestamp: time.Unix(15, 0)}, + } + i := FindFirstNewMessage(tgmarray(CachedGlobalMessages), time.Unix(10, 0)) + if i != 5 { + t.Errorf("Expected 5, got %d", i) + } +} +func TestFindFirstNewMessageOneAfter(t *testing.T) { + CachedGlobalMessages = []TimestampedGlobalMessage{ + {Timestamp: time.Unix(15, 0)}, + } + i := FindFirstNewMessage(tgmarray(CachedGlobalMessages), time.Unix(10, 0)) + if i != 0 { + t.Errorf("Expected 0, got %d", i) + } +} +func TestFindFirstNewMessageSeveralAfter(t *testing.T) { + CachedGlobalMessages = []TimestampedGlobalMessage{ + {Timestamp: time.Unix(11, 0)}, + {Timestamp: time.Unix(12, 0)}, + {Timestamp: time.Unix(13, 0)}, + {Timestamp: time.Unix(14, 0)}, + {Timestamp: time.Unix(15, 0)}, + } + i := FindFirstNewMessage(tgmarray(CachedGlobalMessages), time.Unix(10, 0)) + if i != 0 { + t.Errorf("Expected 0, got %d", i) + } +} diff --git a/socketserver/internal/server/commands.go b/socketserver/internal/server/commands.go new file mode 100644 index 00000000..c947bf08 --- /dev/null +++ b/socketserver/internal/server/commands.go @@ -0,0 +1,285 @@ +package server + +import ( + "fmt" + "github.com/satori/go.uuid" + "golang.org/x/net/websocket" + "log" + "strconv" + "sync" + "time" +) + +var ResponseSuccess = ClientMessage{Command: SuccessCommand} +var ResponseFailure = ClientMessage{Command: "False"} + +const ChannelInfoDelay = 2 * time.Second + +func HandleCommand(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) { + handler, ok := CommandHandlers[msg.Command] + if !ok { + log.Println("[!] Unknown command", msg.Command, "- sent by client", client.ClientID, "@", conn.RemoteAddr()) + FFZCodec.Send(conn, ClientMessage{ + MessageID: msg.MessageID, + Command: "error", + Arguments: fmt.Sprintf("Unknown command %s", msg.Command), + }) + return + } + + response, err := CallHandler(handler, conn, client, msg) + + if err == nil { + if response.Command == AsyncResponseCommand { + // Don't send anything + // The response will be delivered over client.MessageChannel / serverMessageChan + } else { + response.MessageID = msg.MessageID + FFZCodec.Send(conn, response) + } + } else { + FFZCodec.Send(conn, ClientMessage{ + MessageID: msg.MessageID, + Command: "error", + Arguments: err.Error(), + }) + } +} + +func HandleHello(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { + version, clientId, err := msg.ArgumentsAsTwoStrings() + if err != nil { + return + } + + client.Version = version + client.ClientID = uuid.FromStringOrNil(clientId) + if client.ClientID == uuid.Nil { + client.ClientID = uuid.NewV4() + } + + SubscribeGlobal(client) + + return ClientMessage{ + Arguments: client.ClientID.String(), + }, nil +} + +func HandleReady(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { + disconnectAt, err := msg.ArgumentsAsInt() + if err != nil { + return + } + + client.Mutex.Lock() + if client.MakePendingRequests != nil { + if !client.MakePendingRequests.Stop() { + // Timer already fired, GetSubscriptionBacklog() has started + rmsg.Command = SuccessCommand + return + } + } + client.PendingSubscriptionsBacklog = nil + client.MakePendingRequests = nil + client.Mutex.Unlock() + + if disconnectAt == 0 { + // backlog only + go func() { + client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand} + SendBacklogForNewClient(client) + }() + return ClientMessage{Command: AsyncResponseCommand}, nil + } else { + // backlog and timed + go func() { + client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand} + SendBacklogForNewClient(client) + SendTimedBacklogMessages(client, time.Unix(disconnectAt, 0)) + }() + return ClientMessage{Command: AsyncResponseCommand}, nil + } +} + +func HandleSetUser(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { + username, err := msg.ArgumentsAsString() + if err != nil { + return + } + + client.Mutex.Lock() + client.TwitchUsername = username + client.UsernameValidated = false + client.Mutex.Unlock() + + return ResponseSuccess, nil +} + +func HandleSub(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { + channel, err := msg.ArgumentsAsString() + + if err != nil { + return + } + + client.Mutex.Lock() + + AddToSliceS(&client.CurrentChannels, channel) + client.PendingSubscriptionsBacklog = append(client.PendingSubscriptionsBacklog, channel) + + if client.MakePendingRequests == nil { + client.MakePendingRequests = time.AfterFunc(ChannelInfoDelay, GetSubscriptionBacklogFor(conn, client)) + } else { + if !client.MakePendingRequests.Reset(ChannelInfoDelay) { + client.MakePendingRequests = time.AfterFunc(ChannelInfoDelay, GetSubscriptionBacklogFor(conn, client)) + } + } + + client.Mutex.Unlock() + + SubscribeChat(client, channel) + + return ResponseSuccess, nil +} + +func HandleUnsub(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { + channel, err := msg.ArgumentsAsString() + + if err != nil { + return + } + + client.Mutex.Lock() + RemoveFromSliceS(&client.CurrentChannels, channel) + client.Mutex.Unlock() + + UnsubscribeSingleChat(client, channel) + + return ResponseSuccess, nil +} + +func GetSubscriptionBacklogFor(conn *websocket.Conn, client *ClientInfo) func() { + return func() { + GetSubscriptionBacklog(conn, client) + } +} + +// On goroutine +func GetSubscriptionBacklog(conn *websocket.Conn, client *ClientInfo) { + var subs []string + + // Lock, grab the data, and reset it + client.Mutex.Lock() + subs = client.PendingSubscriptionsBacklog + client.PendingSubscriptionsBacklog = nil + client.MakePendingRequests = nil + client.Mutex.Unlock() + + if len(subs) == 0 { + return + } + + if backendUrl == "" { + return // for testing runs + } + messages, err := FetchBacklogData(subs) + + if err != nil { + // Oh well. + log.Print("error in GetSubscriptionBacklog:", err) + return + } + + // Deliver to client + for _, msg := range messages { + client.MessageChannel <- msg + } +} + +type SurveySubmission struct { + User string + Json string +} + +var SurveySubmissions []SurveySubmission +var SurveySubmissionLock sync.Mutex + +func HandleSurvey(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { + SurveySubmissionLock.Lock() + SurveySubmissions = append(SurveySubmissions, SurveySubmission{client.TwitchUsername, msg.origArguments}) + SurveySubmissionLock.Unlock() + + return ResponseSuccess, nil +} + +type FollowEvent struct { + User string + Channel string + NowFollowing bool + Timestamp time.Time +} + +var FollowEvents []FollowEvent +var FollowEventsLock sync.Mutex + +func HandleTrackFollow(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { + channel, following, err := msg.ArgumentsAsStringAndBool() + if err != nil { + return + } + now := time.Now() + + FollowEventsLock.Lock() + FollowEvents = append(FollowEvents, FollowEvent{client.TwitchUsername, channel, following, now}) + FollowEventsLock.Unlock() + + return ResponseSuccess, nil +} + +var AggregateEmoteUsage map[int]map[string]int = make(map[int]map[string]int) +var AggregateEmoteUsageLock sync.Mutex + +func HandleEmoticonUses(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { + // arguments is [1]map[EmoteId]map[RoomName]float64 + + mapRoot := msg.Arguments.([]interface{})[0].(map[string]interface{}) + + AggregateEmoteUsageLock.Lock() + defer AggregateEmoteUsageLock.Unlock() + + for strEmote, val1 := range mapRoot { + var emoteId int + emoteId, err = strconv.Atoi(strEmote) + if err != nil { + return + } + + destMapInner, ok := AggregateEmoteUsage[emoteId] + if !ok { + destMapInner = make(map[string]int) + AggregateEmoteUsage[emoteId] = destMapInner + } + + mapInner := val1.(map[string]interface{}) + for roomName, val2 := range mapInner { + var count int = int(val2.(float64)) + destMapInner[roomName] += count + } + } + + return ResponseSuccess, nil +} + +func HandleRemoteCommand(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { + go func(conn *websocket.Conn, msg ClientMessage, authInfo AuthInfo) { + resp, err := RequestRemoteDataCached(string(msg.Command), msg.origArguments, authInfo) + + if err != nil { + FFZCodec.Send(conn, ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: err.Error()}) + } else { + FFZCodec.Send(conn, ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand, origArguments: resp}) + } + }(conn, msg, client.AuthInfo) + + return ClientMessage{Command: AsyncResponseCommand}, nil +} diff --git a/socketserver/internal/server/handlecore.go b/socketserver/internal/server/handlecore.go new file mode 100644 index 00000000..f227db8b --- /dev/null +++ b/socketserver/internal/server/handlecore.go @@ -0,0 +1,434 @@ +package server // import "bitbucket.org/stendec/frankerfacez/socketserver/internal/server" + +import ( + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "golang.org/x/net/websocket" + "log" + "net/http" + "strconv" + "strings" + "sync" +) + +const MAX_PACKET_SIZE = 1024 + +// A command is how the client refers to a function on the server. It's just a string. +type Command string + +// A function that is called to respond to a Command. +type CommandHandler func(*websocket.Conn, *ClientInfo, ClientMessage) (ClientMessage, error) + +var CommandHandlers = map[Command]CommandHandler{ + HelloCommand: HandleHello, + "setuser": HandleSetUser, + "ready": HandleReady, + + "sub": HandleSub, + "unsub": HandleUnsub, + + "track_follow": HandleTrackFollow, + "emoticon_uses": HandleEmoticonUses, + "survey": HandleSurvey, + + "twitch_emote": HandleRemoteCommand, + "get_link": HandleRemoteCommand, + "get_display_name": HandleRemoteCommand, + "update_follow_buttons": HandleRemoteCommand, + "chat_history": HandleRemoteCommand, +} + +// Sent by the server in ClientMessage.Command to indicate success. +const SuccessCommand Command = "True" + +// Sent by the server in ClientMessage.Command to indicate failure. +const ErrorCommand Command = "error" + +// This must be the first command sent by the client once the connection is established. +const HelloCommand Command = "hello" + +// A handler returning a ClientMessage with this Command will prevent replying to the client. +// It signals that the work has been handed off to a background goroutine. +const AsyncResponseCommand Command = "_async" + +// A websocket.Codec that translates the protocol into ClientMessage objects. +var FFZCodec websocket.Codec = websocket.Codec{ + Marshal: MarshalClientMessage, + Unmarshal: UnmarshalClientMessage, +} + +// Errors that get returned to the client. +var ProtocolError error = errors.New("FFZ Socket protocol error.") +var ProtocolErrorNegativeID error = errors.New("FFZ Socket protocol error: negative or zero message ID.") +var ExpectedSingleString = errors.New("Error: Expected single string as arguments.") +var ExpectedSingleInt = errors.New("Error: Expected single integer as arguments.") +var ExpectedTwoStrings = errors.New("Error: Expected array of string, string as arguments.") +var ExpectedStringAndInt = errors.New("Error: Expected array of string, int as arguments.") +var ExpectedStringAndBool = errors.New("Error: Expected array of string, bool as arguments.") +var ExpectedStringAndIntGotFloat = errors.New("Error: Second argument was a float, expected an integer.") + +var gconfig *ConfigFile + +// Create a websocket.Server with the options from the provided Config. +func setupServer(config *ConfigFile, tlsConfig *tls.Config) *websocket.Server { + gconfig = config + sockConf, err := websocket.NewConfig("/", config.SocketOrigin) + if err != nil { + log.Fatal(err) + } + + SetupBackend(config) + + if config.UseSSL { + cert, err := tls.LoadX509KeyPair(config.SSLCertificateFile, config.SSLKeyFile) + if err != nil { + log.Fatal(err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + tlsConfig.ServerName = config.SocketOrigin + tlsConfig.BuildNameToCertificate() + sockConf.TlsConfig = tlsConfig + + } + + sockServer := &websocket.Server{} + sockServer.Config = *sockConf + sockServer.Handler = HandleSocketConnection + + go deadChannelReaper() + + return sockServer +} + +// Set up a websocket listener and register it on /. +// (Uses http.DefaultServeMux .) +func SetupServerAndHandle(config *ConfigFile, tlsConfig *tls.Config, serveMux *http.ServeMux) { + sockServer := setupServer(config, tlsConfig) + + if serveMux == nil { + serveMux = http.DefaultServeMux + } + serveMux.HandleFunc("/", ServeWebsocketOrCatbag(sockServer.ServeHTTP)) + serveMux.HandleFunc("/pub_msg", HBackendPublishRequest) + serveMux.HandleFunc("/dump_backlog", HBackendDumpBacklog) + serveMux.HandleFunc("/update_and_pub", HBackendUpdateAndPublish) +} + +func ServeWebsocketOrCatbag(sockfunc func(http.ResponseWriter, *http.Request)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Connection") == "Upgrade" { + sockfunc(w, r) + return + } else { + w.Write([]byte(gconfig.BannerHTML)) + } + } +} + +// Handle a new websocket connection from a FFZ client. +// This runs in a goroutine started by net/http. +func HandleSocketConnection(conn *websocket.Conn) { + // websocket.Conn is a ReadWriteCloser + + var _closer sync.Once + closer := func() { + _closer.Do(func() { + conn.Close() + }) + } + + // Close the connection when we're done. + defer closer() + + _clientChan := make(chan ClientMessage) + _serverMessageChan := make(chan ClientMessage) + _errorChan := make(chan error) + + // Launch receiver goroutine + go func(errorChan chan<- error, clientChan chan<- ClientMessage) { + var msg ClientMessage + var err error + for ; err == nil; err = FFZCodec.Receive(conn, &msg) { + if msg.MessageID == 0 { + continue + } + clientChan <- msg + } + errorChan <- err + close(errorChan) + close(clientChan) + // exit + }(_errorChan, _clientChan) + + var errorChan <-chan error = _errorChan + var clientChan <-chan ClientMessage = _clientChan + var serverMessageChan <-chan ClientMessage = _serverMessageChan + + var client ClientInfo + client.MessageChannel = _serverMessageChan + + // All set up, now enter the work loop + +RunLoop: + for { + select { + case err := <-errorChan: + FFZCodec.Send(conn, ClientMessage{ + MessageID: -1, + Command: "error", + Arguments: err.Error(), + }) // note - socket might be closed, but don't care + break RunLoop + case msg := <-clientChan: + if client.Version == "" && msg.Command != HelloCommand { + FFZCodec.Send(conn, ClientMessage{ + MessageID: msg.MessageID, + Command: "error", + Arguments: "Error - the first message sent must be a 'hello'", + }) + break RunLoop + } + + HandleCommand(conn, &client, msg) + case smsg := <-serverMessageChan: + FFZCodec.Send(conn, smsg) + } + } + + // Exit + + // Launch message draining goroutine - we aren't out of the pub/sub records + go func() { + for _ = range _serverMessageChan { + } + }() + + // Stop getting messages... + UnsubscribeAll(&client) + + // And finished. + // Close the channel so the draining goroutine can finish, too. + close(_serverMessageChan) +} + +func CallHandler(handler CommandHandler, conn *websocket.Conn, client *ClientInfo, cmsg ClientMessage) (rmsg ClientMessage, err error) { + defer func() { + if r := recover(); r != nil { + var ok bool + fmt.Print("[!] Error executing command", cmsg.Command, "--", r) + err, ok = r.(error) + if !ok { + err = fmt.Errorf("command handler: %v", r) + } + } + }() + return handler(conn, client, cmsg) +} + +// Unpack a message sent from the client into a ClientMessage. +func UnmarshalClientMessage(data []byte, payloadType byte, v interface{}) (err error) { + var spaceIdx int + + out := v.(*ClientMessage) + dataStr := string(data) + + // Message ID + spaceIdx = strings.IndexRune(dataStr, ' ') + if spaceIdx == -1 { + return ProtocolError + } + messageId, err := strconv.Atoi(dataStr[:spaceIdx]) + if messageId < -1 || messageId == 0 { + return ProtocolErrorNegativeID + } + + out.MessageID = messageId + dataStr = dataStr[spaceIdx+1:] + + spaceIdx = strings.IndexRune(dataStr, ' ') + if spaceIdx == -1 { + out.Command = Command(dataStr) + out.Arguments = nil + return nil + } else { + out.Command = Command(dataStr[:spaceIdx]) + } + dataStr = dataStr[spaceIdx+1:] + argumentsJson := dataStr + out.origArguments = argumentsJson + err = out.parseOrigArguments() + if err != nil { + return + } + return nil +} + +func (cm *ClientMessage) parseOrigArguments() error { + err := json.Unmarshal([]byte(cm.origArguments), &cm.Arguments) + if err != nil { + return err + } + return nil +} + +func MarshalClientMessage(clientMessage interface{}) (data []byte, payloadType byte, err error) { + var msg ClientMessage + var ok bool + msg, ok = clientMessage.(ClientMessage) + if !ok { + pMsg, ok := clientMessage.(*ClientMessage) + if !ok { + panic("MarshalClientMessage: argument needs to be a ClientMessage") + } + msg = *pMsg + } + var dataStr string + + if msg.Command == "" && msg.MessageID == 0 { + panic("MarshalClientMessage: attempt to send an empty ClientMessage") + } + + if msg.Command == "" { + msg.Command = SuccessCommand + } + if msg.MessageID == 0 { + msg.MessageID = -1 + } + + if msg.Arguments != nil { + argBytes, err := json.Marshal(msg.Arguments) + if err != nil { + return nil, 0, err + } + + dataStr = fmt.Sprintf("%d %s %s", msg.MessageID, msg.Command, string(argBytes)) + } else { + dataStr = fmt.Sprintf("%d %s", msg.MessageID, msg.Command) + } + + return []byte(dataStr), websocket.TextFrame, nil +} + +// Command handlers should use this to construct responses. +func NewClientMessage(arguments interface{}) ClientMessage { + return ClientMessage{ + MessageID: 0, // filled by the select loop + Command: SuccessCommand, + Arguments: arguments, + } +} + +// Convenience method: Parse the arguments of the ClientMessage as a single string. +func (cm *ClientMessage) ArgumentsAsString() (string1 string, err error) { + var ok bool + string1, ok = cm.Arguments.(string) + if !ok { + err = ExpectedSingleString + return + } else { + return string1, nil + } +} + +// Convenience method: Parse the arguments of the ClientMessage as a single int. +func (cm *ClientMessage) ArgumentsAsInt() (int1 int64, err error) { + var ok bool + var num float64 + num, ok = cm.Arguments.(float64) + if !ok { + err = ExpectedSingleInt + return + } else { + int1 = int64(num) + return int1, nil + } +} + +// Convenience method: Parse the arguments of the ClientMessage as an array of two strings. +func (cm *ClientMessage) ArgumentsAsTwoStrings() (string1, string2 string, err error) { + var ok bool + var ary []interface{} + ary, ok = cm.Arguments.([]interface{}) + if !ok { + err = ExpectedTwoStrings + return + } else { + if len(ary) != 2 { + err = ExpectedTwoStrings + return + } + string1, ok = ary[0].(string) + if !ok { + err = ExpectedTwoStrings + return + } + string2, ok = ary[1].(string) + if !ok { + err = ExpectedTwoStrings + return + } + return string1, string2, nil + } +} + +// Convenience method: Parse the arguments of the ClientMessage as an array of a string and an int. +func (cm *ClientMessage) ArgumentsAsStringAndInt() (string1 string, int int64, err error) { + var ok bool + var ary []interface{} + ary, ok = cm.Arguments.([]interface{}) + if !ok { + err = ExpectedStringAndInt + return + } else { + if len(ary) != 2 { + err = ExpectedStringAndInt + return + } + string1, ok = ary[0].(string) + if !ok { + err = ExpectedStringAndInt + return + } + var num float64 + num, ok = ary[1].(float64) + if !ok { + err = ExpectedStringAndInt + return + } + int = int64(num) + if float64(int) != num { + err = ExpectedStringAndIntGotFloat + return + } + return string1, int, nil + } +} + +// Convenience method: Parse the arguments of the ClientMessage as an array of a string and an int. +func (cm *ClientMessage) ArgumentsAsStringAndBool() (str string, flag bool, err error) { + var ok bool + var ary []interface{} + ary, ok = cm.Arguments.([]interface{}) + if !ok { + err = ExpectedStringAndBool + return + } else { + if len(ary) != 2 { + err = ExpectedStringAndBool + return + } + str, ok = ary[0].(string) + if !ok { + err = ExpectedStringAndBool + return + } + flag, ok = ary[1].(bool) + if !ok { + err = ExpectedStringAndBool + return + } + return str, flag, nil + } +} diff --git a/socketserver/internal/server/handlecore_test.go b/socketserver/internal/server/handlecore_test.go new file mode 100644 index 00000000..161b5921 --- /dev/null +++ b/socketserver/internal/server/handlecore_test.go @@ -0,0 +1,57 @@ +package server + +import ( + "fmt" + "golang.org/x/net/websocket" + "testing" +) + +func ExampleUnmarshalClientMessage() { + sourceData := []byte("100 hello [\"ffz_3.5.30\",\"898b5bfa-b577-47bb-afb4-252c703b67d6\"]") + var cm ClientMessage + err := UnmarshalClientMessage(sourceData, websocket.TextFrame, &cm) + fmt.Println(err) + fmt.Println(cm.MessageID) + fmt.Println(cm.Command) + fmt.Println(cm.Arguments) + // Output: + // + // 100 + // hello + // [ffz_3.5.30 898b5bfa-b577-47bb-afb4-252c703b67d6] +} + +func ExampleMarshalClientMessage() { + var cm ClientMessage = ClientMessage{ + MessageID: -1, + Command: "do_authorize", + Arguments: "1234567890", + } + data, payloadType, err := MarshalClientMessage(&cm) + fmt.Println(err) + fmt.Println(payloadType == websocket.TextFrame) + fmt.Println(string(data)) + // Output: + // + // true + // -1 do_authorize "1234567890" +} + +func TestArgumentsAsStringAndBool(t *testing.T) { + sourceData := []byte("1 foo [\"string\", false]") + var cm ClientMessage + err := UnmarshalClientMessage(sourceData, websocket.TextFrame, &cm) + if err != nil { + t.Fatal(err) + } + str, boolean, err := cm.ArgumentsAsStringAndBool() + if err != nil { + t.Fatal(err) + } + if str != "string" { + t.Error("Expected first array item to be 'string', got", str) + } + if boolean != false { + t.Error("Expected second array item to be false, got", boolean) + } +} diff --git a/socketserver/internal/server/publisher.go b/socketserver/internal/server/publisher.go new file mode 100644 index 00000000..d9658ac7 --- /dev/null +++ b/socketserver/internal/server/publisher.go @@ -0,0 +1,168 @@ +package server + +// This is the scariest code I've written yet for the server. +// If I screwed up the locking, I won't know until it's too late. + +import ( + "sync" + "time" +) + +type SubscriberList struct { + sync.RWMutex + Members []chan<- ClientMessage +} + +var ChatSubscriptionInfo map[string]*SubscriberList = make(map[string]*SubscriberList) +var ChatSubscriptionLock sync.RWMutex +var GlobalSubscriptionInfo SubscriberList + +func PublishToChat(channel string, msg ClientMessage) (count int) { + ChatSubscriptionLock.RLock() + list := ChatSubscriptionInfo[channel] + if list != nil { + list.RLock() + for _, msgChan := range list.Members { + msgChan <- msg + count++ + } + list.RUnlock() + } + ChatSubscriptionLock.RUnlock() + return +} + +func PublishToMultiple(channels []string, msg ClientMessage) (count int) { + found := make(map[chan<- ClientMessage]struct{}) + + ChatSubscriptionLock.RLock() + + for _, channel := range channels { + list := ChatSubscriptionInfo[channel] + if list != nil { + list.RLock() + for _, msgChan := range list.Members { + found[msgChan] = struct{}{} + } + list.RUnlock() + } + } + + ChatSubscriptionLock.RUnlock() + + for msgChan, _ := range found { + msgChan <- msg + count++ + } + return +} + +func PublishToAll(msg ClientMessage) (count int) { + GlobalSubscriptionInfo.RLock() + for _, msgChan := range GlobalSubscriptionInfo.Members { + msgChan <- msg + count++ + } + GlobalSubscriptionInfo.RUnlock() + return +} + +// Add a channel to the subscriptions while holding a read-lock to the map. +// Locks: +// - ALREADY HOLDING a read-lock to the 'which' top-level map via the rlocker object +// - possible write lock to the 'which' top-level map via the wlocker object +// - write lock to SubscriptionInfo (if not creating new) +func _subscribeWhileRlocked(channelName string, value chan<- ClientMessage) { + list := ChatSubscriptionInfo[channelName] + if list == nil { + // Not found, so create it + ChatSubscriptionLock.RUnlock() + ChatSubscriptionLock.Lock() + list = &SubscriberList{} + list.Members = []chan<- ClientMessage{value} // Create it populated, to avoid reaper + ChatSubscriptionInfo[channelName] = list + ChatSubscriptionLock.Unlock() + ChatSubscriptionLock.RLock() + } else { + list.Lock() + AddToSliceC(&list.Members, value) + list.Unlock() + } +} + +func SubscribeGlobal(client *ClientInfo) { + GlobalSubscriptionInfo.Lock() + AddToSliceC(&GlobalSubscriptionInfo.Members, client.MessageChannel) + GlobalSubscriptionInfo.Unlock() +} + +func SubscribeChat(client *ClientInfo, channelName string) { + ChatSubscriptionLock.RLock() + _subscribeWhileRlocked(channelName, client.MessageChannel) + ChatSubscriptionLock.RUnlock() +} + +func unsubscribeAllClients() { + GlobalSubscriptionInfo.Lock() + GlobalSubscriptionInfo.Members = nil + GlobalSubscriptionInfo.Unlock() + ChatSubscriptionLock.Lock() + ChatSubscriptionInfo = make(map[string]*SubscriberList) + ChatSubscriptionLock.Unlock() +} + +// Unsubscribe the client from all channels, AND clear the CurrentChannels / WatchingChannels fields. +// Locks: +// - read lock to top-level maps +// - write lock to SubscriptionInfos +// - write lock to ClientInfo +func UnsubscribeAll(client *ClientInfo) { + client.Mutex.Lock() + client.PendingSubscriptionsBacklog = nil + client.PendingSubscriptionsBacklog = nil + client.Mutex.Unlock() + + GlobalSubscriptionInfo.Lock() + RemoveFromSliceC(&GlobalSubscriptionInfo.Members, client.MessageChannel) + GlobalSubscriptionInfo.Unlock() + + ChatSubscriptionLock.RLock() + client.Mutex.Lock() + for _, v := range client.CurrentChannels { + list := ChatSubscriptionInfo[v] + if list != nil { + list.Lock() + RemoveFromSliceC(&list.Members, client.MessageChannel) + list.Unlock() + } + } + client.CurrentChannels = nil + client.Mutex.Unlock() + ChatSubscriptionLock.RUnlock() +} + +func UnsubscribeSingleChat(client *ClientInfo, channelName string) { + ChatSubscriptionLock.RLock() + list := ChatSubscriptionInfo[channelName] + list.Lock() + RemoveFromSliceC(&list.Members, client.MessageChannel) + list.Unlock() + ChatSubscriptionLock.RUnlock() +} + +const ReapingDelay = 120 * time.Minute + +// Checks ChatSubscriptionInfo for entries with no subscribers every ReapingDelay. +// Started from SetupServer(). +func deadChannelReaper() { + for { + time.Sleep(ReapingDelay) + ChatSubscriptionLock.Lock() + for key, val := range ChatSubscriptionInfo { + if len(val.Members) == 0 { + ChatSubscriptionInfo[key] = nil + } + } + ChatSubscriptionLock.Unlock() + } +} diff --git a/socketserver/internal/server/publisher_test.go b/socketserver/internal/server/publisher_test.go new file mode 100644 index 00000000..2dc54ed6 --- /dev/null +++ b/socketserver/internal/server/publisher_test.go @@ -0,0 +1,441 @@ +package server + +import ( + "encoding/json" + "fmt" + "github.com/satori/go.uuid" + "golang.org/x/net/websocket" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strconv" + "sync" + "syscall" + "testing" + "time" +) + +func TCountOpenFDs() uint64 { + ary, _ := ioutil.ReadDir(fmt.Sprintf("/proc/%d/fd", os.Getpid())) + return uint64(len(ary)) +} + +const IgnoreReceivedArguments = 1 + 2i + +func TReceiveExpectedMessage(tb testing.TB, conn *websocket.Conn, messageId int, command Command, arguments interface{}) (ClientMessage, bool) { + var msg ClientMessage + var fail bool + err := FFZCodec.Receive(conn, &msg) + if err != nil { + tb.Error(err) + return msg, false + } + if msg.MessageID != messageId { + tb.Error("Message ID was wrong. Expected", messageId, ", got", msg.MessageID, ":", msg) + fail = true + } + if msg.Command != command { + tb.Error("Command was wrong. Expected", command, ", got", msg.Command, ":", msg) + fail = true + } + if arguments != IgnoreReceivedArguments { + if arguments == nil { + if msg.origArguments != "" { + tb.Error("Arguments are wrong. Expected", arguments, ", got", msg.Arguments, ":", msg) + } + } else { + argBytes, _ := json.Marshal(arguments) + if msg.origArguments != string(argBytes) { + tb.Error("Arguments are wrong. Expected", arguments, ", got", msg.Arguments, ":", msg) + } + } + } + return msg, !fail +} + +func TSendMessage(tb testing.TB, conn *websocket.Conn, messageId int, command Command, arguments interface{}) bool { + err := FFZCodec.Send(conn, ClientMessage{MessageID: messageId, Command: command, Arguments: arguments}) + if err != nil { + tb.Error(err) + } + return err == nil +} + +func TSealForSavePubMsg(tb testing.TB, cmd Command, channel string, arguments interface{}, deleteMode bool) (url.Values, error) { + form := url.Values{} + form.Set("cmd", string(cmd)) + argsBytes, err := json.Marshal(arguments) + if err != nil { + tb.Error(err) + return nil, err + } + form.Set("args", string(argsBytes)) + form.Set("channel", channel) + if deleteMode { + form.Set("delete", "1") + } + form.Set("time", time.Now().Format(time.UnixDate)) + + sealed, err := SealRequest(form) + if err != nil { + tb.Error(err) + return nil, err + } + return sealed, nil +} + +func TCheckResponse(tb testing.TB, resp *http.Response, expected string) bool { + var failed bool + respBytes, err := ioutil.ReadAll(resp.Body) + resp.Body.Close() + respStr := string(respBytes) + + if err != nil { + tb.Error(err) + failed = true + } + + if resp.StatusCode != 200 { + tb.Error("Publish failed: ", resp.StatusCode, respStr) + failed = true + } + + if respStr != expected { + tb.Errorf("Got wrong response from server. Expected: '%s' Got: '%s'", expected, respStr) + failed = true + } + return !failed +} + +type TURLs struct { + Websocket string + Origin string + PubMsg string + SavePubMsg string // update_and_pub +} + +func TGetUrls(testserver *httptest.Server) TURLs { + addr := testserver.Listener.Addr().String() + return TURLs{ + Websocket: fmt.Sprintf("ws://%s/", addr), + Origin: fmt.Sprintf("http://%s", addr), + PubMsg: fmt.Sprintf("http://%s/pub_msg", addr), + SavePubMsg: fmt.Sprintf("http://%s/update_and_pub", addr), + } +} + +func TSetup(testserver **httptest.Server, urls *TURLs) { + DumpCache() + + conf := &ConfigFile{ + ServerId: 20, + UseSSL: false, + SocketOrigin: "localhost:2002", + BannerHTML: ` + +CatBag + +
+
+
+
+
+
+ A FrankerFaceZ Service + — CatBag by Wolsk +
+
+`, + OurPublicKey: []byte{176, 149, 72, 209, 35, 42, 110, 220, 22, 236, 212, 129, 213, 199, 1, 227, 185, 167, 150, 159, 117, 202, 164, 100, 9, 107, 45, 141, 122, 221, 155, 73}, + OurPrivateKey: []byte{247, 133, 147, 194, 70, 240, 211, 216, 223, 16, 241, 253, 120, 14, 198, 74, 237, 180, 89, 33, 146, 146, 140, 58, 88, 160, 2, 246, 112, 35, 239, 87}, + BackendPublicKey: []byte{19, 163, 37, 157, 50, 139, 193, 85, 229, 47, 166, 21, 153, 231, 31, 133, 41, 158, 8, 53, 73, 0, 113, 91, 13, 181, 131, 248, 176, 18, 1, 107}, + } + gconfig = conf + SetupBackend(conf) + + if testserver != nil { + serveMux := http.NewServeMux() + SetupServerAndHandle(conf, nil, serveMux) + + tserv := httptest.NewUnstartedServer(serveMux) + *testserver = tserv + tserv.Start() + if urls != nil { + *urls = TGetUrls(tserv) + } + } +} + +func TestSubscriptionAndPublish(t *testing.T) { + var doneWg sync.WaitGroup + var readyWg sync.WaitGroup + + const TestChannelName1 = "room.testchannel" + const TestChannelName2 = "room.chan2" + const TestChannelName3 = "room.chan3" + const TestChannelNameUnused = "room.empty" + const TestCommandChan = "testdata_single" + const TestCommandMulti = "testdata_multi" + const TestCommandGlobal = "testdata_global" + const TestData1 = "123456789" + const TestData2 = 42 + const TestData3 = false + var TestData4 = []interface{}{"str1", "str2", "str3"} + + ServerInitiatedCommands[TestCommandChan] = PushCommandCacheInfo{CacheTypeLastOnly, MsgTargetTypeChat} + ServerInitiatedCommands[TestCommandMulti] = PushCommandCacheInfo{CacheTypeTimestamps, MsgTargetTypeMultichat} + ServerInitiatedCommands[TestCommandGlobal] = PushCommandCacheInfo{CacheTypeTimestamps, MsgTargetTypeGlobal} + + var server *httptest.Server + var urls TURLs + TSetup(&server, &urls) + defer server.CloseClientConnections() + defer unsubscribeAllClients() + + var conn *websocket.Conn + var err error + + // client 1: sub ch1, ch2 + // client 2: sub ch1, ch3 + // client 3: sub none + // client 4: delayed sub ch1 + // msg 1: ch1 + // msg 2: ch2, ch3 + // msg 3: chEmpty + // msg 4: global + + // Client 1 + conn, err = websocket.Dial(urls.Websocket, "", urls.Origin) + if err != nil { + t.Error(err) + return + } + + doneWg.Add(1) + readyWg.Add(1) + go func(conn *websocket.Conn) { + TSendMessage(t, conn, 1, HelloCommand, []interface{}{"ffz_0.0-test", uuid.NewV4().String()}) + TReceiveExpectedMessage(t, conn, 1, SuccessCommand, IgnoreReceivedArguments) + TSendMessage(t, conn, 2, "sub", TestChannelName1) + TReceiveExpectedMessage(t, conn, 2, SuccessCommand, nil) + TSendMessage(t, conn, 3, "sub", TestChannelName2) // 2 + TReceiveExpectedMessage(t, conn, 3, SuccessCommand, nil) + TSendMessage(t, conn, 4, "ready", 0) + TReceiveExpectedMessage(t, conn, 4, SuccessCommand, nil) + + readyWg.Done() + + TReceiveExpectedMessage(t, conn, -1, TestCommandChan, TestData1) + TReceiveExpectedMessage(t, conn, -1, TestCommandMulti, TestData2) + TReceiveExpectedMessage(t, conn, -1, TestCommandGlobal, TestData4) + + conn.Close() + doneWg.Done() + }(conn) + + // Client 2 + conn, err = websocket.Dial(urls.Websocket, "", urls.Origin) + if err != nil { + t.Error(err) + return + } + + doneWg.Add(1) + readyWg.Add(1) + go func(conn *websocket.Conn) { + TSendMessage(t, conn, 1, HelloCommand, []interface{}{"ffz_0.0-test", uuid.NewV4().String()}) + TReceiveExpectedMessage(t, conn, 1, SuccessCommand, IgnoreReceivedArguments) + TSendMessage(t, conn, 2, "sub", TestChannelName1) + TReceiveExpectedMessage(t, conn, 2, SuccessCommand, nil) + TSendMessage(t, conn, 3, "sub", TestChannelName3) // 3 + TReceiveExpectedMessage(t, conn, 3, SuccessCommand, nil) + TSendMessage(t, conn, 4, "ready", 0) + TReceiveExpectedMessage(t, conn, 4, SuccessCommand, nil) + + readyWg.Done() + + TReceiveExpectedMessage(t, conn, -1, TestCommandChan, TestData1) + TReceiveExpectedMessage(t, conn, -1, TestCommandMulti, TestData2) + TReceiveExpectedMessage(t, conn, -1, TestCommandGlobal, TestData4) + + conn.Close() + doneWg.Done() + }(conn) + + // Client 3 + conn, err = websocket.Dial(urls.Websocket, "", urls.Origin) + if err != nil { + t.Error(err) + return + } + + doneWg.Add(1) + readyWg.Add(1) + go func(conn *websocket.Conn) { + TSendMessage(t, conn, 1, HelloCommand, []interface{}{"ffz_0.0-test", uuid.NewV4().String()}) + TReceiveExpectedMessage(t, conn, 1, SuccessCommand, IgnoreReceivedArguments) + TSendMessage(t, conn, 2, "ready", 0) + TReceiveExpectedMessage(t, conn, 2, SuccessCommand, nil) + + readyWg.Done() + + TReceiveExpectedMessage(t, conn, -1, TestCommandGlobal, TestData4) + + conn.Close() + doneWg.Done() + }(conn) + + // Wait for clients 1-3 + readyWg.Wait() + + var form url.Values + var resp *http.Response + + // Publish message 1 - should go to clients 1, 2 + + form, err = TSealForSavePubMsg(t, TestCommandChan, TestChannelName1, TestData1, false) + if err != nil { + t.FailNow() + } + resp, err = http.PostForm(urls.SavePubMsg, form) + if !TCheckResponse(t, resp, strconv.Itoa(2)) { + t.FailNow() + } + + // Publish message 2 - should go to clients 1, 2 + + form, err = TSealForSavePubMsg(t, TestCommandMulti, TestChannelName2+","+TestChannelName3, TestData2, false) + if err != nil { + t.FailNow() + } + resp, err = http.PostForm(urls.SavePubMsg, form) + if !TCheckResponse(t, resp, strconv.Itoa(2)) { + t.FailNow() + } + + // Publish message 3 - should go to no clients + + form, err = TSealForSavePubMsg(t, TestCommandChan, TestChannelNameUnused, TestData3, false) + if err != nil { + t.FailNow() + } + resp, err = http.PostForm(urls.SavePubMsg, form) + if !TCheckResponse(t, resp, strconv.Itoa(0)) { + t.FailNow() + } + + // Publish message 4 - should go to clients 1, 2, 3 + + form, err = TSealForSavePubMsg(t, TestCommandGlobal, "", TestData4, false) + if err != nil { + t.FailNow() + } + resp, err = http.PostForm(urls.SavePubMsg, form) + if !TCheckResponse(t, resp, strconv.Itoa(3)) { + t.FailNow() + } + + // Start client 4 + conn, err = websocket.Dial(urls.Websocket, "", urls.Origin) + if err != nil { + t.Error(err) + return + } + + doneWg.Add(1) + readyWg.Add(1) + go func(conn *websocket.Conn) { + TSendMessage(t, conn, 1, HelloCommand, []interface{}{"ffz_0.0-test", uuid.NewV4().String()}) + TReceiveExpectedMessage(t, conn, 1, SuccessCommand, IgnoreReceivedArguments) + TSendMessage(t, conn, 2, "sub", TestChannelName1) + TReceiveExpectedMessage(t, conn, 2, SuccessCommand, nil) + TSendMessage(t, conn, 3, "ready", 0) + TReceiveExpectedMessage(t, conn, 3, SuccessCommand, nil) + + // backlog message + TReceiveExpectedMessage(t, conn, -1, TestCommandChan, TestData1) + + readyWg.Done() + + conn.Close() + doneWg.Done() + }(conn) + + readyWg.Wait() + + doneWg.Wait() + server.Close() +} + +func BenchmarkUserSubscriptionSinglePublish(b *testing.B) { + var doneWg sync.WaitGroup + var readyWg sync.WaitGroup + + const TestChannelName = "room.testchannel" + const TestCommand = "testdata" + const TestData = "123456789" + + message := ClientMessage{MessageID: -1, Command: "testdata", Arguments: TestData} + + fmt.Println() + fmt.Println(b.N) + + var limit syscall.Rlimit + syscall.Getrlimit(syscall.RLIMIT_NOFILE, &limit) + + limit.Cur = TCountOpenFDs() + uint64(b.N)*2 + 100 + + if limit.Cur > limit.Max { + b.Skip("Open file limit too low") + return + } + + syscall.Setrlimit(syscall.RLIMIT_NOFILE, &limit) + + var server *httptest.Server + var urls TURLs + TSetup(&server, &urls) + defer unsubscribeAllClients() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + conn, err := websocket.Dial(urls.Websocket, "", urls.Origin) + if err != nil { + b.Error(err) + break + } + doneWg.Add(1) + readyWg.Add(1) + go func(i int, conn *websocket.Conn) { + TSendMessage(b, conn, 1, HelloCommand, []interface{}{"ffz_0.0-test", uuid.NewV4().String()}) + TSendMessage(b, conn, 2, "sub", TestChannelName) + + TReceiveExpectedMessage(b, conn, 1, SuccessCommand, IgnoreReceivedArguments) + TReceiveExpectedMessage(b, conn, 2, SuccessCommand, nil) + + readyWg.Done() + + TReceiveExpectedMessage(b, conn, -1, TestCommand, TestData) + + conn.Close() + doneWg.Done() + }(i, conn) + } + + readyWg.Wait() + + fmt.Println("publishing...") + if PublishToChat(TestChannelName, message) != b.N { + b.Error("not enough sent") + server.CloseClientConnections() + panic("halting test instead of waiting") + } + doneWg.Wait() + fmt.Println("...done.") + + b.StopTimer() + server.Close() + server.CloseClientConnections() +} diff --git a/socketserver/internal/server/types.go b/socketserver/internal/server/types.go new file mode 100644 index 00000000..cc9ba947 --- /dev/null +++ b/socketserver/internal/server/types.go @@ -0,0 +1,232 @@ +package server + +import ( + "encoding/json" + "github.com/satori/go.uuid" + "sync" + "time" +) + +const CryptoBoxKeyLength = 32 + +type ConfigFile struct { + // Numeric server id known to the backend + ServerId int + ListenAddr string + // Hostname of the socket server + SocketOrigin string + // URL to the backend server + BackendUrl string + // Memes go here + BannerHTML string + + // SSL/TLS + UseSSL bool + SSLCertificateFile string + SSLKeyFile string + + // Nacl keys + OurPrivateKey []byte + OurPublicKey []byte + BackendPublicKey []byte +} + +type ClientMessage struct { + // Message ID. Increments by 1 for each message sent from the client. + // When replying to a command, the message ID must be echoed. + // When sending a server-initiated message, this is -1. + MessageID int + // The command that the client wants from the server. + // When sent from the server, the literal string 'True' indicates success. + // Before sending, a blank Command will be converted into SuccessCommand. + Command Command + // Result of json.Unmarshal on the third field send from the client + Arguments interface{} + + origArguments string +} + +type AuthInfo struct { + // The client's claimed username on Twitch. + TwitchUsername string + + // Whether or not the server has validated the client's claimed username. + UsernameValidated bool +} + +type ClientInfo struct { + // The client ID. + // This must be written once by the owning goroutine before the struct is passed off to any other goroutines. + ClientID uuid.UUID + + // The client's version. + // This must be written once by the owning goroutine before the struct is passed off to any other goroutines. + Version string + + // This mutex protects writable data in this struct. + // If it seems to be a performance problem, we can split this. + Mutex sync.Mutex + + // TODO(riking) - does this need to be protected cross-thread? + AuthInfo + + // Username validation nonce. + ValidationNonce string + + // The list of chats this client is currently in. + // Protected by Mutex. + CurrentChannels []string + + // List of channels that we have not yet checked current chat-related channel info for. + // This lets us batch the backlog requests. + // Protected by Mutex. + PendingSubscriptionsBacklog []string + + // A timer that, when fired, will make the pending backlog requests. + // Usually nil. Protected by Mutex. + MakePendingRequests *time.Timer + + // Server-initiated messages should be sent here + // Never nil. + MessageChannel chan<- ClientMessage +} + +type tgmarray []TimestampedGlobalMessage +type tmmarray []TimestampedMultichatMessage + +func (ta tgmarray) Len() int { + return len(ta) +} +func (ta tgmarray) Less(i, j int) bool { + return ta[i].Timestamp.Before(ta[j].Timestamp) +} +func (ta tgmarray) Swap(i, j int) { + ta[i], ta[j] = ta[j], ta[i] +} +func (ta tgmarray) GetTime(i int) time.Time { + return ta[i].Timestamp +} +func (ta tmmarray) Len() int { + return len(ta) +} +func (ta tmmarray) Less(i, j int) bool { + return ta[i].Timestamp.Before(ta[j].Timestamp) +} +func (ta tmmarray) Swap(i, j int) { + ta[i], ta[j] = ta[j], ta[i] +} +func (ta tmmarray) GetTime(i int) time.Time { + return ta[i].Timestamp +} + +func (bct BacklogCacheType) Name() string { + switch bct { + case CacheTypeInvalid: + return "" + case CacheTypeNever: + return "never" + case CacheTypeTimestamps: + return "timed" + case CacheTypeLastOnly: + return "last" + case CacheTypePersistent: + return "persist" + } + panic("Invalid BacklogCacheType value") +} + +var CacheTypesByName = map[string]BacklogCacheType{ + "never": CacheTypeNever, + "timed": CacheTypeTimestamps, + "last": CacheTypeLastOnly, + "persist": CacheTypePersistent, +} + +func BacklogCacheTypeByName(name string) (bct BacklogCacheType) { + // CacheTypeInvalid is the zero value so it doesn't matter + bct, _ = CacheTypesByName[name] + return +} + +// Implements Stringer +func (bct BacklogCacheType) String() string { return bct.Name() } + +// Implements json.Marshaler +func (bct BacklogCacheType) MarshalJSON() ([]byte, error) { + return json.Marshal(bct.Name()) +} + +// Implements json.Unmarshaler +func (pbct *BacklogCacheType) UnmarshalJSON(data []byte) error { + var str string + err := json.Unmarshal(data, &str) + if err != nil { + return err + } + if str == "" { + *pbct = CacheTypeInvalid + return nil + } + val := BacklogCacheTypeByName(str) + if val != CacheTypeInvalid { + *pbct = val + return nil + } + return ErrorUnrecognizedCacheType +} + +func (mtt MessageTargetType) Name() string { + switch mtt { + case MsgTargetTypeInvalid: + return "" + case MsgTargetTypeSingle: + return "single" + case MsgTargetTypeChat: + return "chat" + case MsgTargetTypeMultichat: + return "multichat" + case MsgTargetTypeGlobal: + return "global" + } + panic("Invalid MessageTargetType value") +} + +var TargetTypesByName = map[string]MessageTargetType{ + "single": MsgTargetTypeSingle, + "chat": MsgTargetTypeChat, + "multichat": MsgTargetTypeMultichat, + "global": MsgTargetTypeGlobal, +} + +func MessageTargetTypeByName(name string) (mtt MessageTargetType) { + // MsgTargetTypeInvalid is the zero value so it doesn't matter + mtt, _ = TargetTypesByName[name] + return +} + +// Implements Stringer +func (mtt MessageTargetType) String() string { return mtt.Name() } + +// Implements json.Marshaler +func (mtt MessageTargetType) MarshalJSON() ([]byte, error) { + return json.Marshal(mtt.Name()) +} + +// Implements json.Unmarshaler +func (pmtt *MessageTargetType) UnmarshalJSON(data []byte) error { + var str string + err := json.Unmarshal(data, &str) + if err != nil { + return err + } + if str == "" { + *pmtt = MsgTargetTypeInvalid + return nil + } + mtt := MessageTargetTypeByName(str) + if mtt != MsgTargetTypeInvalid { + *pmtt = mtt + return nil + } + return ErrorUnrecognizedTargetType +} diff --git a/socketserver/internal/server/utils.go b/socketserver/internal/server/utils.go new file mode 100644 index 00000000..8dbff0f4 --- /dev/null +++ b/socketserver/internal/server/utils.go @@ -0,0 +1,161 @@ +package server + +import ( + "bytes" + "crypto/rand" + "encoding/base64" + "errors" + "golang.org/x/crypto/nacl/box" + "log" + "net/url" + "strconv" + "strings" +) + +func FillCryptoRandom(buf []byte) error { + remaining := len(buf) + for remaining > 0 { + count, err := rand.Read(buf) + if err != nil { + return err + } + remaining -= count + } + return nil +} + +func New4KByteBuffer() interface{} { + return make([]byte, 0, 4096) +} + +func SealRequest(form url.Values) (url.Values, error) { + var nonce [24]byte + var err error + + err = FillCryptoRandom(nonce[:]) + if err != nil { + return nil, err + } + + cipherMsg := box.SealAfterPrecomputation(nil, []byte(form.Encode()), &nonce, &backendSharedKey) + + bufMessage := new(bytes.Buffer) + enc := base64.NewEncoder(base64.URLEncoding, bufMessage) + enc.Write(cipherMsg) + enc.Close() + cipherString := bufMessage.String() + + bufNonce := new(bytes.Buffer) + enc = base64.NewEncoder(base64.URLEncoding, bufNonce) + enc.Write(nonce[:]) + enc.Close() + nonceString := bufNonce.String() + + retval := url.Values{ + "nonce": []string{nonceString}, + "msg": []string{cipherString}, + "id": []string{strconv.Itoa(serverId)}, + } + + return retval, nil +} + +var ErrorShortNonce = errors.New("Nonce too short.") +var ErrorInvalidSignature = errors.New("Invalid signature or contents") + +func UnsealRequest(form url.Values) (url.Values, error) { + var nonce [24]byte + + nonceString := form.Get("nonce") + dec := base64.NewDecoder(base64.URLEncoding, strings.NewReader(nonceString)) + count, err := dec.Read(nonce[:]) + if err != nil { + return nil, err + } + if count != 24 { + return nil, ErrorShortNonce + } + + cipherString := form.Get("msg") + dec = base64.NewDecoder(base64.URLEncoding, strings.NewReader(cipherString)) + cipherBuffer := new(bytes.Buffer) + cipherBuffer.ReadFrom(dec) + + message, ok := box.OpenAfterPrecomputation(nil, cipherBuffer.Bytes(), &nonce, &backendSharedKey) + if !ok { + return nil, ErrorInvalidSignature + } + + retValues, err := url.ParseQuery(string(message)) + if err != nil { + // Assume that the signature was accidentally correct but the contents were garbage + log.Print(err) + return nil, ErrorInvalidSignature + } + + return retValues, nil +} + +func AddToSliceS(ary *[]string, val string) bool { + slice := *ary + for _, v := range slice { + if v == val { + return false + } + } + + slice = append(slice, val) + *ary = slice + return true +} + +func RemoveFromSliceS(ary *[]string, val string) bool { + slice := *ary + var idx int = -1 + for i, v := range slice { + if v == val { + idx = i + break + } + } + if idx == -1 { + return false + } + + slice[idx] = slice[len(slice)-1] + slice = slice[:len(slice)-1] + *ary = slice + return true +} + +func AddToSliceC(ary *[]chan<- ClientMessage, val chan<- ClientMessage) bool { + slice := *ary + for _, v := range slice { + if v == val { + return false + } + } + + slice = append(slice, val) + *ary = slice + return true +} + +func RemoveFromSliceC(ary *[]chan<- ClientMessage, val chan<- ClientMessage) bool { + slice := *ary + var idx int = -1 + for i, v := range slice { + if v == val { + idx = i + break + } + } + if idx == -1 { + return false + } + + slice[idx] = slice[len(slice)-1] + slice = slice[:len(slice)-1] + *ary = slice + return true +} diff --git a/src/constants.js b/src/constants.js index 836d1bdd..0a1da39b 100644 --- a/src/constants.js +++ b/src/constants.js @@ -1,10 +1,12 @@ var SVGPATH = '', DEBUG = localStorage.ffzDebugMode == "true" && document.body.classList.contains('ffz-dev'), + WS_SERVERS = DEBUG ? ["localhost:8001", "catbag.frankerfacez.com"] : ["catbag.frankerfacez.com"], SERVER = DEBUG ? "//localhost:8000/" : "//cdn.frankerfacez.com/"; module.exports = { DEBUG: DEBUG, SERVER: SERVER, + WS_SERVERS: WS_SERVERS, API_SERVER: "//api.frankerfacez.com/", API_SERVER_2: "//direct-api.frankerfacez.com/", diff --git a/src/ember/channel.js b/src/ember/channel.js index a113db0a..a7adb41b 100644 --- a/src/ember/channel.js +++ b/src/ember/channel.js @@ -153,10 +153,10 @@ FFZ.prototype.setup_channel = function() { if ( id !== f.__old_host_target ) { if ( f.__old_host_target ) - f.ws_send("unsub_channel", f.__old_host_target); + f.ws_send("unsub", "channel." + f.__old_host_target); if ( id ) { - f.ws_send("sub_channel", id); + f.ws_send("sub", "channel." + id); f.__old_host_target = id; } else delete f.__old_host_target; @@ -208,7 +208,7 @@ FFZ.prototype._modify_cindex = function(view) { el = this.get('element'); f._cindex = this; - f.ws_send("sub_channel", id); + f.ws_send("sub", "channel." + id); el.setAttribute('data-channel', id); el.classList.add('ffz-channel'); @@ -620,7 +620,7 @@ FFZ.prototype._modify_cindex = function(view) { ffzTeardown: function() { var id = this.get('controller.id'); if ( id ) - f.ws_send("unsub_channel", id); + f.ws_send("unsub", "channel." + id); this.get('element').setAttribute('data-channel', ''); f._cindex = undefined; diff --git a/src/ember/room.js b/src/ember/room.js index c0184363..0978d6c5 100644 --- a/src/ember/room.js +++ b/src/ember/room.js @@ -591,7 +591,7 @@ FFZ.prototype.add_room = function(id, room) { } // Let the server know where we are. - this.ws_send("sub", id); + this.ws_send("sub", "room." + id); // See if we need history? if ( ! this.has_bttv && this.settings.chat_history && room && (room.get('messages.length') || 0) < 10 ) { @@ -625,7 +625,7 @@ FFZ.prototype.remove_room = function(id) { utils.update_css(this._room_style, id, null); // Let the server know we're gone and delete our data for this room. - this.ws_send("unsub", id); + this.ws_send("unsub", "room." + id); delete this.rooms[id]; // Clean up sets we aren't using any longer. diff --git a/src/socket.js b/src/socket.js index 7739a086..99847bfc 100644 --- a/src/socket.js +++ b/src/socket.js @@ -1,8 +1,13 @@ -var FFZ = window.FrankerFaceZ; +var FFZ = window.FrankerFaceZ, + constants = require('./constants'); FFZ.prototype._ws_open = false; FFZ.prototype._ws_delay = 0; FFZ.prototype._ws_last_iframe = 0; +FFZ.prototype._ws_host_idx = Math.floor(Math.random() * constants.WS_SERVERS.length) + 1; +if (constants.DEBUG) { + FFZ.prototype._ws_host_idx = 0; +} FFZ.ws_commands = {}; FFZ.ws_on_close = []; @@ -12,6 +17,8 @@ FFZ.ws_on_close = []; // Socket Creation // ---------------- +// Attempt to authenticate to the socket server as a real browser by loading the root page. +// e.g. cloudflare ddos check FFZ.prototype.ws_iframe = function() { this._ws_last_iframe = Date.now(); var ifr = document.createElement('iframe'), @@ -39,7 +46,7 @@ FFZ.prototype.ws_create = function() { this._ws_pending = this._ws_pending || []; try { - ws = this._ws_sock = new WebSocket("ws://catbag.frankerfacez.com/"); + ws = this._ws_sock = new WebSocket("ws://" + constants.WS_SERVERS[this._ws_host_idx] + "/"); } catch(err) { this._ws_exists = false; return this.log("Error Creating WebSocket: " + err); @@ -53,17 +60,7 @@ FFZ.prototype.ws_create = function() { f._ws_last_iframe = Date.now(); f.log("Socket connected."); - // Check for incognito. We don't want to do a hello in incognito mode. - var fs = window.RequestFileSystem || window.webkitRequestFileSystem; - if (!fs) - // Assume not. - f.ws_send("hello", ["ffz_" + FFZ.version_info, localStorage.ffzClientId], f._ws_on_hello.bind(f)); - - else - fs(window.TEMPORARY, 100, - f.ws_send.bind(f, "hello", ["ffz_" + FFZ.version_info, localStorage.ffzClientId], f._ws_on_hello.bind(f)), - f.log.bind(f, "Operating in Incognito Mode.")); - + f.ws_send("hello", ["ffz_" + FFZ.version_info, localStorage.ffzClientId], f._ws_on_hello.bind(f)); var user = f.get_user(); if ( user ) @@ -73,8 +70,8 @@ FFZ.prototype.ws_create = function() { if ( f.is_dashboard ) { var match = location.pathname.match(/\/([^\/]+)/); if ( match ) { - f.ws_send("sub", match[1]); - f.ws_send("sub_channel", match[1]); + f.ws_send("sub", "room." + match[1]); + f.ws_send("sub", "channel." + match[1]); } } @@ -83,7 +80,7 @@ FFZ.prototype.ws_create = function() { if ( ! f.rooms.hasOwnProperty(room_id) || ! f.rooms[room_id] ) continue; - f.ws_send("sub", room_id); + f.ws_send("sub", "room." + room_id); if ( f.rooms[room_id].needs_history ) { f.rooms[room_id].needs_history = false; @@ -98,10 +95,10 @@ FFZ.prototype.ws_create = function() { hosted_id = f._cindex.get('controller.hostModeTarget.id'); if ( channel_id ) - f.ws_send("sub_channel", channel_id); + f.ws_send("sub", "channel." + channel_id); if ( hosted_id ) - f.ws_send("sub_channel", hosted_id); + f.ws_send("sub", "channel." + hosted_id); } // Send any pending commands. @@ -112,14 +109,35 @@ FFZ.prototype.ws_create = function() { var d = pending[i]; f.ws_send(d[0], d[1], d[2]); } + + // If reconnecting, get the backlog that we missed. + if ( f._ws_offline_time ) { + var timestamp = f._ws_offline_time; + delete f._ws_offline_time; + f.ws_send("ready", timestamp); + } else { + f.ws_send("ready", 0); + } + } + + ws.onerror = function() { + if ( ! f._ws_offline_time ) { + f._ws_offline_time = new Date().getTime(); + } + + // Cycle selected server + f._ws_host_idx = (f._ws_host_idx + 1) % constants.WS_SERVERS.length; } ws.onclose = function(e) { f.log("Socket closed. (Code: " + e.code + ", Reason: " + e.reason + ")"); f._ws_open = false; + if ( ! f._ws_offline_time ) { + f._ws_offline_time = new Date().getTime(); + } // When the connection closes, run our callbacks. - for(var i=0; i < FFZ.ws_on_close.length; i++) { + for (var i=0; i < FFZ.ws_on_close.length; i++) { try { FFZ.ws_on_close[i].bind(f)(); } catch(err) { @@ -127,6 +145,9 @@ FFZ.prototype.ws_create = function() { } } + // Cycle selected server + f._ws_host_idx = (f._ws_host_idx + 1) % constants.WS_SERVERS.length; + if ( f._ws_delay > 10000 ) { var ua = navigator.userAgent.toLowerCase(); if ( Date.now() - f._ws_last_iframe > 1800000 && !(ua.indexOf('chrome') === -1 && ua.indexOf('safari') !== -1) ) @@ -166,6 +187,11 @@ FFZ.prototype.ws_create = function() { else f.log("Invalid command: " + cmd, data, false, true); + } else if ( cmd === "error" ) { + f.log("Socket server reported error: " + data); + if (f._ws_callbacks[request] ) + delete f._ws_callbacks[request]; + } else { var success = cmd === 'True', has_callback = typeof f._ws_callbacks[request] === "function"; @@ -180,7 +206,7 @@ FFZ.prototype.ws_create = function() { f.error("Callback for " + request + ": " + err); } - f._ws_callbacks[request] = undefined; + delete f._ws_callbacks[request]; } } } diff --git a/src/tokenize.js b/src/tokenize.js index 12871234..ce514300 100644 --- a/src/tokenize.js +++ b/src/tokenize.js @@ -201,7 +201,7 @@ var FFZ = window.FrankerFaceZ, return; this._link_data[href] = data; - data.unsafe = false; + //data.unsafe = false; var tooltip = build_link_tooltip.bind(this)(href), links, no_trail = href.charAt(href.length-1) == "/" ? href.substr(0, href.length-1) : null;