diff --git a/socketserver/SocketServerDesign.svg b/socketserver/SocketServerDesign.svg new file mode 100644 index 00000000..6a141697 --- /dev/null +++ b/socketserver/SocketServerDesign.svg @@ -0,0 +1,1127 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + image/svg+xml + + + + + + + + Client + + Client + + Client + + Client + + Client + + Client + + Client + + Client + + Client + + Client + + + Socket Server + andknuckles.frankerfacez.com + + Socket Server + catbag.frankerfacez.com + + Socket Server + tuturu.frankerfacez.com + TLS + + TLS + TLS / websocket + + + TLS + + HTTP / websocket + + + + + + TLS + TLS + TLS + TLS + TLS + Socket Backend + catbag.frankerfacez.com + + NaCl / HTTP + + + NaCl / HTTP + NaCl / HTTP + + + (out of scope) + www.frankerfacez.com + Web Server + + + + Various otherservices + www.twitter.com, bit.ly,www.speedrunslive.com + + (out of scope) + + diff --git a/socketserver/cmd/ffzsocketserver/.gitignore b/socketserver/cmd/ffzsocketserver/.gitignore new file mode 100644 index 00000000..43852e2e --- /dev/null +++ b/socketserver/cmd/ffzsocketserver/.gitignore @@ -0,0 +1,3 @@ +config.json +ffzsocketserver +uniques/ diff --git a/socketserver/cmd/ffzsocketserver/console.go b/socketserver/cmd/ffzsocketserver/console.go new file mode 100644 index 00000000..5d3063d9 --- /dev/null +++ b/socketserver/cmd/ffzsocketserver/console.go @@ -0,0 +1,136 @@ +package main + +import ( + "fmt" + "runtime" + "strconv" + "strings" + + "../../server" + "github.com/abiosoft/ishell" + "github.com/gorilla/websocket" +) + +func commandLineConsole() { + + shell := ishell.NewShell() + + shell.Register("help", func(args ...string) (string, error) { + shell.PrintCommands() + return "", nil + }) + + shell.Register("clientcount", func(args ...string) (string, error) { + server.GlobalSubscriptionLock.RLock() + count := len(server.GlobalSubscriptionInfo) + server.GlobalSubscriptionLock.RUnlock() + return fmt.Sprintln(count, "clients connected"), nil + }) + + shell.Register("globalnotice", func(args ...string) (string, error) { + msg := server.ClientMessage{ + MessageID: -1, + Command: "message", + Arguments: args[0], + } + server.PublishToAll(msg) + return "Message sent.", nil + }) + + shell.Register("publish", func(args ...string) (string, error) { + if len(args) < 4 { + return "Usage: publish [room.sirstendec | _ALL] -1 reload_ff 23", nil + } + + target := args[0] + line := strings.Join(args[1:], " ") + msg := server.ClientMessage{} + err := server.UnmarshalClientMessage([]byte(line), websocket.TextMessage, &msg) + if err != nil { + return "", err + } + + var count int + if target == "_ALL" { + count = server.PublishToAll(msg) + } else { + count = server.PublishToChannel(target, msg) + } + return fmt.Sprintf("Published to %d clients", count), nil + }) + + shell.Register("memstatsbysize", func(args ...string) (string, error) { + runtime.GC() + + m := runtime.MemStats{} + runtime.ReadMemStats(&m) + for _, val := range m.BySize { + if val.Mallocs == 0 { + continue + } + shell.Print(fmt.Sprintf("%5d: %6d outstanding (%d total)\n", val.Size, val.Mallocs-val.Frees, val.Mallocs)) + } + shell.Println(m.NumGC, "collections occurred") + return "", nil + }) + + shell.Register("authorizeeveryone", func(args ...string) (string, error) { + if len(args) == 0 { + if server.Configuration.SendAuthToNewClients { + return "All clients are recieving auth challenges upon claiming a name.", nil + } + return "All clients are not recieving auth challenges upon claiming a name.", nil + } else if args[0] == "on" { + server.Configuration.SendAuthToNewClients = true + return "All new clients will recieve auth challenges upon claiming a name.", nil + } else if args[0] == "off" { + server.Configuration.SendAuthToNewClients = false + return "All new clients will not recieve auth challenges upon claiming a name.", nil + } + return "Usage: authorizeeveryone [ on | off ]", nil + }) + + shell.Register("kickclients", func(args ...string) (string, error) { + if len(args) == 0 { + return "Please enter either a count or a fraction of clients to kick.", nil + } + input, err := strconv.ParseFloat(args[0], 64) + if err != nil { + return "Argument must be a number", err + } + var count int + if input >= 1 { + count = int(input) + } else { + server.GlobalSubscriptionLock.RLock() + count = int(float64(len(server.GlobalSubscriptionInfo)) * input) + server.GlobalSubscriptionLock.RUnlock() + } + + msg := server.ClientMessage{Arguments: &server.CloseRebalance} + server.GlobalSubscriptionLock.RLock() + defer server.GlobalSubscriptionLock.RUnlock() + + kickCount := 0 + for i, cl := range server.GlobalSubscriptionInfo { + if i >= count { + break + } + select { + case cl.MessageChannel <- msg: + case <-cl.MsgChannelIsDone: + } + kickCount++ + } + return fmt.Sprintf("Kicked %d clients", kickCount), nil + }) + + shell.Register("panic", func(args ...string) (string, error) { + go func() { + panic("requested panic") + }() + return "", nil + }) + + shell.Start() +} diff --git a/socketserver/cmd/ffzsocketserver/index.html b/socketserver/cmd/ffzsocketserver/index.html new file mode 100644 index 00000000..e28ce2ee --- /dev/null +++ b/socketserver/cmd/ffzsocketserver/index.html @@ -0,0 +1,13 @@ + +CatBag + +
+
+
+
+
+
+ A FrankerFaceZ Service + — CatBag by Wolsk +
+
diff --git a/socketserver/cmd/ffzsocketserver/socketserver.go b/socketserver/cmd/ffzsocketserver/socketserver.go index 5de7a059..1d67bddc 100644 --- a/socketserver/cmd/ffzsocketserver/socketserver.go +++ b/socketserver/cmd/ffzsocketserver/socketserver.go @@ -1,7 +1,6 @@ -package main // import "bitbucket.org/stendec/frankerfacez/socketserver/cmd/socketserver" +package main // import "bitbucket.org/stendec/frankerfacez/socketserver/cmd/ffzsocketserver" import ( - "../../internal/server" "encoding/json" "flag" "fmt" @@ -9,16 +8,23 @@ import ( "log" "net/http" "os" + + "../../server" ) -var configFilename *string = flag.String("config", "config.json", "Configuration file, including the keypairs for the NaCl crypto library, for communicating with the backend.") -var generateKeys *bool = flag.Bool("genkeys", false, "Generate NaCl keys instead of serving requests.\nArguments: [int serverId] [base64 backendPublic]\nThe backend public key can either be specified in base64 on the command line, or put in the json file later.") +import _ "net/http/pprof" + +var configFilename = flag.String("config", "config.json", "Configuration file, including the keypairs for the NaCl crypto library, for communicating with the backend.") +var flagGenerateKeys = flag.Bool("genkeys", false, "Generate NaCl keys instead of serving requests.\nArguments: [int serverId] [base64 backendPublic]\nThe backend public key can either be specified in base64 on the command line, or put in the json file later.") + +var BuildTime string = "build not stamped" +var BuildHash string = "build not stamped" func main() { flag.Parse() - if *generateKeys { - GenerateKeys(*configFilename) + if *flagGenerateKeys { + generateKeys(*configFilename) return } @@ -40,24 +46,30 @@ func main() { log.Fatal(err) } - httpServer := &http.Server{ - Addr: conf.ListenAddr, - } + // logFile, err := os.OpenFile("output.log", os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644) + // if err != nil { + // log.Fatal("Could not create logfile: ", err) + // } - server.SetupServerAndHandle(conf, httpServer.TLSConfig, nil) + server.SetupServerAndHandle(conf, http.DefaultServeMux) + server.SetBuildStamp(BuildTime, BuildHash) + + go commandLineConsole() if conf.UseSSL { - err = httpServer.ListenAndServeTLS(conf.SSLCertificateFile, conf.SSLKeyFile) - } else { - err = httpServer.ListenAndServe() + go func() { + if err := http.ListenAndServeTLS(conf.SSLListenAddr, conf.SSLCertificateFile, conf.SSLKeyFile, http.DefaultServeMux); err != nil { + log.Fatal("ListenAndServeTLS: ", err) + } + }() } - if err != nil { + if err = http.ListenAndServe(conf.ListenAddr, http.DefaultServeMux); err != nil { log.Fatal("ListenAndServe: ", err) } } -func GenerateKeys(outputFile string) { +func generateKeys(outputFile string) { if flag.NArg() < 1 { fmt.Println("Specify a numeric server ID after -genkeys") os.Exit(2) diff --git a/socketserver/cmd/mergecounts/.gitignore b/socketserver/cmd/mergecounts/.gitignore new file mode 100644 index 00000000..5b97e97f --- /dev/null +++ b/socketserver/cmd/mergecounts/.gitignore @@ -0,0 +1 @@ +mergecounts diff --git a/socketserver/cmd/mergecounts/mergecounts.go b/socketserver/cmd/mergecounts/mergecounts.go new file mode 100644 index 00000000..8ce8b307 --- /dev/null +++ b/socketserver/cmd/mergecounts/mergecounts.go @@ -0,0 +1,102 @@ +package main + +import ( + "encoding/gob" + "flag" + "fmt" + "net/http" + "os" + + "../../server" + "github.com/clarkduvall/hyperloglog" +) + +var SERVERS = []string{ + "https://catbag.frankerfacez.com", + "https://andknuckles.frankerfacez.com", + "https://tuturu.frankerfacez.com", +} + +const folderPrefix = "/hll/" + +const HELP = ` +Usage: mergecounts [filename] + +Downloads the file /hll/filename from the 3 FFZ socket servers, merges the contents, and prints the total cardinality. + +Filename should be in one of the following formats: + + daily-25-12-2015.gob + weekly-51-2015.gob + monthly-12-2015.gob +` + +var forceWrite = flag.Bool("f", false, "force servers to write out their current") + +func main() { + flag.Parse() + if flag.NArg() < 1 { + fmt.Print(HELP) + os.Exit(2) + return + } + + filename := flag.Arg(0) + hll, err := DownloadAll(filename) + if err != nil { + fmt.Println(err) + os.Exit(1) + return + } + + fmt.Println(hll.Count()) +} + +func ForceWrite() { + for _, server := range SERVERS { + resp, err := http.Get(fmt.Sprintf("%s/hll_force_write", server)) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + resp.Body.Close() + } +} + +func DownloadAll(filename string) (*hyperloglog.HyperLogLogPlus, error) { + result, _ := hyperloglog.NewPlus(server.CounterPrecision) + + for _, server := range SERVERS { + if *forceWrite { + resp, err := http.Get(fmt.Sprintf("%s/hll_force_write", server)) + if err == nil { + resp.Body.Close() + } + } + singleHLL, err := DownloadHLL(fmt.Sprintf("%s%s%s", server, folderPrefix, filename)) + if err != nil { + return nil, err + } + result.Merge(singleHLL) + } + + return result, nil +} + +func DownloadHLL(url string) (*hyperloglog.HyperLogLogPlus, error) { + result, _ := hyperloglog.NewPlus(server.CounterPrecision) + + resp, err := http.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + dec := gob.NewDecoder(resp.Body) + err = dec.Decode(result) + if err != nil { + return nil, err + } + fmt.Println(url, result.Count()) + + return result, nil +} diff --git a/socketserver/cmd/statsweb/.gitignore b/socketserver/cmd/statsweb/.gitignore new file mode 100644 index 00000000..3dddabd8 --- /dev/null +++ b/socketserver/cmd/statsweb/.gitignore @@ -0,0 +1,3 @@ +database.sqlite +gobcache/ +statsweb diff --git a/socketserver/cmd/statsweb/config.go b/socketserver/cmd/statsweb/config.go new file mode 100644 index 00000000..04591f44 --- /dev/null +++ b/socketserver/cmd/statsweb/config.go @@ -0,0 +1,74 @@ +package main + +import ( + "encoding/json" + "fmt" + "os" +) + +type ConfigFile struct { + ListenAddr string + DatabaseLocation string + GobFilesLocation string +} + +func makeConfig() { + config.ListenAddr = "localhost:3000" + home, ok := os.LookupEnv("HOME") + if ok { + config.DatabaseLocation = fmt.Sprintf("%s/.ffzstatsweb/database.sqlite", home) + config.GobFilesLocation = fmt.Sprintf("%s/.ffzstatsweb/gobcache", home) + os.MkdirAll(config.GobFilesLocation, 0755) + } else { + config.DatabaseLocation = "./database.sqlite" + config.GobFilesLocation = "./gobcache" + os.MkdirAll(config.GobFilesLocation, 0755) + } + file, err := os.Create(*configLocation) + if err != nil { + fmt.Printf("Error: could not create config file: %v\n", err) + os.Exit(ExitCodeBadConfig) + return + } + enc := json.NewEncoder(file) + err = enc.Encode(config) + if err != nil { + fmt.Printf("Error: could not write config file: %v\n", err) + os.Exit(ExitCodeBadConfig) + return + } + err = file.Close() + if err != nil { + fmt.Printf("Error: could not write config file: %v\n", err) + os.Exit(ExitCodeBadConfig) + return + } + return +} + +func loadConfig() { + file, err := os.Open(*configLocation) + if err != nil { + if os.IsNotExist(err) { + fmt.Println("You must create a config file with -genconf") + } else { + fmt.Printf("Error: could not load config file: %v", err) + } + os.Exit(ExitCodeBadConfig) + return + } + dec := json.NewDecoder(file) + err = dec.Decode(&config) + if err != nil { + fmt.Printf("Error: could not load config file: %v\n", err) + os.Exit(ExitCodeBadConfig) + return + } + err = file.Close() + if err != nil { + fmt.Printf("Error: could not load config file: %v\n", err) + os.Exit(ExitCodeBadConfig) + return + } + return +} diff --git a/socketserver/cmd/statsweb/config.json b/socketserver/cmd/statsweb/config.json new file mode 100644 index 00000000..07cfa8e3 --- /dev/null +++ b/socketserver/cmd/statsweb/config.json @@ -0,0 +1 @@ +{"ListenAddr":"localhost:3000","DatabaseLocation":"./database.sqlite","GobFilesLocation":"./gobcache"} diff --git a/socketserver/cmd/statsweb/html.go b/socketserver/cmd/statsweb/html.go new file mode 100644 index 00000000..4cbdf89c --- /dev/null +++ b/socketserver/cmd/statsweb/html.go @@ -0,0 +1,62 @@ +package main + +import ( + "html/template" + "net/http" + "time" + + "bitbucket.org/stendec/frankerfacez/socketserver/server" +) + +type CalendarData struct { + Weeks []CalWeekData +} +type CalWeekData struct { + Days []CalDayData +} +type CalDayData struct { + NoData bool + Date int + UniqUsers int +} + +type CalendarMonthInfo struct { + Year int + Month time.Month + // Ranges from -5 to +1. + // A value of +1 means the 1st of the month is a Sunday. + // A value of 0 means the 1st of the month is a Monday. + // A value of -5 means the 1st of the month is a Saturday. + FirstSundayOffset int + // True if the calendar for this month needs six sundays. + NeedSixSundays bool +} + +func GetMonthInfo(at time.Time) CalendarMonthInfo { + year, month, _ := at.Date() + monthStartWeekday := time.Date(year, month, 1, 0, 0, 0, 0, server.CounterLocation).Weekday() + // 1 (start of month) - weekday of start of month = day offset of start of week at start of mont + monthWeekStartDay := 1 - int(monthStartWeekday) + // first day on calendar + 6 weeks < end of month? + sixthSundayDay := monthWeekStartDay + 5*7 + sixthSundayDate := time.Date(year, month, sixthSundayDay, 0, 0, 0, 0, server.CounterLocation) + var needSixSundays bool = false + if sixthSundayDate.Month() == month { + needSixSundays = true + } + + return CalendarMonthInfo{ + Year: year, + Month: month, + FirstSundayOffset: monthWeekStartDay, + NeedSixSundays: needSixSundays, + } +} + +func renderCalendar(w http.ResponseWriter, at time.Time) { + layout, err := template.ParseFiles("./webroot/layout.template.html", "./webroot/cal_entry.hbs", "./webroot/calendar.hbs") + data := CalendarData{} + data.Weeks = make([]CalWeekData, 6) + _ = layout + _ = err +} diff --git a/socketserver/cmd/statsweb/servers.go b/socketserver/cmd/statsweb/servers.go new file mode 100644 index 00000000..d4818c61 --- /dev/null +++ b/socketserver/cmd/statsweb/servers.go @@ -0,0 +1,342 @@ +package main + +import ( + "encoding/gob" + "errors" + "fmt" + "io" + "net/http" + "os" + "sync" + "time" + + "bitbucket.org/stendec/frankerfacez/socketserver/server" + "github.com/clarkduvall/hyperloglog" + "github.com/hashicorp/golang-lru" +) + +type serverFilter struct { + // Mode is false for blacklist, true for whitelist + Mode bool + Special []string +} + +const serverFilterModeBlacklist = false +const serverFilterModeWhitelist = true + +func (sf *serverFilter) IsServerAllowed(server *serverInfo) bool { + name := server.subdomain + for _, v := range sf.Special { + if name == v { + return sf.Mode + } + } + return !sf.Mode +} + +func (sf *serverFilter) Remove(server string) { + if sf.Mode == serverFilterModeWhitelist { + var idx int = -1 + for i, v := range sf.Special { + if server == v { + idx = i + break + } + } + if idx != -1 { + var lenMinusOne = len(sf.Special) - 1 + sf.Special[idx] = sf.Special[lenMinusOne] + sf.Special = sf.Special[:lenMinusOne] + } + } else { + for _, v := range sf.Special { + if server == v { + return + } + } + sf.Special = append(sf.Special, server) + } +} + +func (sf *serverFilter) Add(server string) { + if sf.Mode == serverFilterModeBlacklist { + var idx int = -1 + for i, v := range sf.Special { + if server == v { + idx = i + break + } + } + if idx != -1 { + var lenMinusOne = len(sf.Special) - 1 + sf.Special[idx] = sf.Special[lenMinusOne] + sf.Special = sf.Special[:lenMinusOne] + } + } else { + for _, v := range sf.Special { + if server == v { + return + } + } + sf.Special = append(sf.Special, server) + } +} + +var serverFilterAll serverFilter = serverFilter{Mode: serverFilterModeBlacklist} +var serverFilterNone serverFilter = serverFilter{Mode: serverFilterModeWhitelist} + +func cannotCacheHLL(at time.Time) bool { + now := time.Now() + now.Add(-72 * time.Hour) + return now.Before(at) +} + +var ServerNames = []string{ + "catbag", + "andknuckles", + "tuturu", +} + +var httpClient http.Client + +const serverNameSuffix = ".frankerfacez.com" + +const failedStateThreshold = 4 + +var ErrServerInFailedState = errors.New("server has been down recently and not recovered") +var ErrServerHasNoData = errors.New("no data for specified date") + +type errServerNot200 struct { + StatusCode int + StatusText string +} + +func (e *errServerNot200) Error() string { + return fmt.Sprintf("The server responded with %d %s", e.StatusCode, e.StatusText) +} +func Not200Error(resp *http.Response) *errServerNot200 { + return &errServerNot200{ + StatusCode: resp.StatusCode, + StatusText: resp.Status, + } +} + +func getHLLCacheKey(at time.Time) string { + year, month, day := at.Date() + return fmt.Sprintf("%d-%d-%d", year, month, day) +} + +type serverInfo struct { + subdomain string + + memcache *lru.TwoQueueCache + + FailedState bool + FailureErr error + failureCount int + + lock sync.Mutex +} + +func (si *serverInfo) Setup(subdomain string) { + si.subdomain = subdomain + tq, err := lru.New2Q(60) + if err != nil { + panic(err) + } + si.memcache = tq +} + +// GetHLL gets the HLL from +func (si *serverInfo) GetHLL(at time.Time) (*hyperloglog.HyperLogLogPlus, error) { + if cannotCacheHLL(at) { + fmt.Println(at) + err := si.ForceWrite() + if err != nil { + return nil, err + } + reader, err := si.DownloadHLL(at) + if err != nil { + return nil, err + } + fmt.Printf("downloaded uncached hll %s:%s\n", si.subdomain, getHLLCacheKey(at)) + defer si.DeleteHLL(at) + return loadHLLFromStream(reader) + } + + hll, ok := si.PeekHLL(at) + if ok { + fmt.Printf("got cached hll %s:%s\n", si.subdomain, getHLLCacheKey(at)) + return hll, nil + } + + reader, err := si.OpenHLL(at) + if err != nil { + // continue to download + } else { + //fmt.Printf("opened hll %s:%s\n", si.subdomain, getHLLCacheKey(at)) + return loadHLLFromStream(reader) + } + + reader, err = si.DownloadHLL(at) + if err != nil { + if err == ErrServerHasNoData { + return hyperloglog.NewPlus(server.CounterPrecision) + } + return nil, err + } + fmt.Printf("downloaded hll %s:%s\n", si.subdomain, getHLLCacheKey(at)) + return loadHLLFromStream(reader) +} + +func loadHLLFromStream(reader io.ReadCloser) (*hyperloglog.HyperLogLogPlus, error) { + defer reader.Close() + hll, _ := hyperloglog.NewPlus(server.CounterPrecision) + dec := gob.NewDecoder(reader) + err := dec.Decode(hll) + if err != nil { + return nil, err + } + return hll, nil +} + +// PeekHLL tries to grab a HLL from the memcache without downloading it or hitting the disk. +func (si *serverInfo) PeekHLL(at time.Time) (*hyperloglog.HyperLogLogPlus, bool) { + if cannotCacheHLL(at) { + return nil, false + } + + key := getHLLCacheKey(at) + hll, ok := si.memcache.Get(key) + if ok { + return hll.(*hyperloglog.HyperLogLogPlus), true + } + + return nil, false +} + +func (si *serverInfo) DeleteHLL(at time.Time) { + year, month, day := at.Date() + filename := fmt.Sprintf("%s/%s/%d-%d-%d.gob", config.GobFilesLocation, si.subdomain, year, month, day) + err := os.Remove(filename) + if err != nil { + fmt.Println(err) + } +} + +func (si *serverInfo) OpenHLL(at time.Time) (io.ReadCloser, error) { + year, month, day := at.Date() + filename := fmt.Sprintf("%s/%s/%d-%d-%d.gob", config.GobFilesLocation, si.subdomain, year, month, day) + + file, err := os.Open(filename) + if err == nil { + return file, nil + } + // file is nil + if !os.IsNotExist(err) { + return nil, err + } + + return nil, os.ErrNotExist +} + +func (si *serverInfo) DownloadHLL(at time.Time) (io.ReadCloser, error) { + if si.FailedState { + return nil, ErrServerInFailedState + } + si.lock.Lock() + defer si.lock.Unlock() + + year, month, day := at.Date() + url := fmt.Sprintf("https://%s/hll/daily-%d-%d-%d.gob", si.Domain(), day, month, year) + resp, err := httpClient.Get(url) + if err != nil { + si.ServerFailed(err) + return nil, err + } + if resp.StatusCode == 404 { + return nil, ErrServerHasNoData + } + if resp.StatusCode != 200 { + err = Not200Error(resp) + si.ServerFailed(err) + return nil, err + } + + filename := fmt.Sprintf("%s/%s/%d-%d-%d.gob", config.GobFilesLocation, si.subdomain, year, month, day) + file, err := os.OpenFile(filename, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0644) + if os.IsNotExist(err) { + os.MkdirAll(fmt.Sprintf("%s/%s", config.GobFilesLocation, si.subdomain), 0755) + file, err = os.OpenFile(filename, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0644) + } + if err != nil { + resp.Body.Close() + return nil, fmt.Errorf("downloadhll: error opening file for writing: %v", err) + } + + return &teeReadCloser{r: resp.Body, w: file}, nil +} + +func (si *serverInfo) ForceWrite() error { + if si.FailedState { + return ErrServerInFailedState + } + + url := fmt.Sprintf("https://%s/hll_force_write", si.Domain()) + resp, err := httpClient.Get(url) + if err != nil { + si.ServerFailed(err) + return err + } + if resp.StatusCode != 200 { + err = Not200Error(resp) + si.ServerFailed(err) + return err + } + resp.Body.Close() + return nil +} + +func (si *serverInfo) Domain() string { + return fmt.Sprintf("%s%s", si.subdomain, serverNameSuffix) +} + +func (si *serverInfo) ServerFailed(err error) { + si.lock.Lock() + defer si.lock.Unlock() + si.failureCount++ + if si.failureCount > failedStateThreshold { + fmt.Printf("Server %s entering failed state\n", si.subdomain) + si.FailedState = true + si.FailureErr = err + go recoveryCheck(si) + } +} + +func recoveryCheck(si *serverInfo) { + // TODO check for server recovery +} + +type teeReadCloser struct { + r io.ReadCloser + w io.WriteCloser +} + +func (t *teeReadCloser) Read(p []byte) (n int, err error) { + n, err = t.r.Read(p) + if n > 0 { + if n, err := t.w.Write(p[:n]); err != nil { + return n, err + } + } + return +} + +func (t *teeReadCloser) Close() error { + err1 := t.r.Close() + err2 := t.w.Close() + if err1 != nil { + return err1 + } + return err2 +} diff --git a/socketserver/cmd/statsweb/statsweb.go b/socketserver/cmd/statsweb/statsweb.go new file mode 100644 index 00000000..b36cc768 --- /dev/null +++ b/socketserver/cmd/statsweb/statsweb.go @@ -0,0 +1,322 @@ +package main + +import ( + "encoding/json" + "errors" + "flag" + "fmt" + "net/http" + "net/url" + "os" + "strings" + "sync" + "time" + + "bitbucket.org/stendec/frankerfacez/socketserver/server" + "github.com/clarkduvall/hyperloglog" +) + +var _ = os.Exit + +var configLocation = flag.String("config", "./config.json", "Location of the configuration file. Defaults to ./config.json") +var genConfig = flag.Bool("genconf", false, "Generate a new configuration file.") + +var config ConfigFile + +const ExitCodeBadConfig = 2 + +var allServers []*serverInfo + +func main() { + flag.Parse() + + if *genConfig { + makeConfig() + return + } + + loadConfig() + + allServers = make([]*serverInfo, len(ServerNames)) + for i, v := range ServerNames { + allServers[i] = &serverInfo{} + allServers[i].Setup(v) + } + + //printEveryDay() + //os.Exit(0) + http.HandleFunc("/api/get", ServeAPIGet) + http.ListenAndServe(config.ListenAddr, http.DefaultServeMux) +} + +func printEveryDay() { + year := 2015 + month := 12 + day := 23 + filter := serverFilterAll + var filter1, filter2, filter3 serverFilter + filter1.Mode = serverFilterModeWhitelist + filter2.Mode = serverFilterModeWhitelist + filter3.Mode = serverFilterModeWhitelist + filter1.Add(allServers[0].subdomain) + filter2.Add(allServers[1].subdomain) + filter3.Add(allServers[2].subdomain) + stopTime := time.Now() + var at time.Time + const timeFmt = "2006-01-02" + for ; stopTime.After(at); day++ { + at = time.Date(year, time.Month(month), day, 0, 0, 0, 0, server.CounterLocation) + hll, _ := hyperloglog.NewPlus(server.CounterPrecision) + hll1, _ := hyperloglog.NewPlus(server.CounterPrecision) + hll2, _ := hyperloglog.NewPlus(server.CounterPrecision) + hll3, _ := hyperloglog.NewPlus(server.CounterPrecision) + addSingleDate(at, filter, hll) + addSingleDate(at, filter1, hll1) + addSingleDate(at, filter2, hll2) + addSingleDate(at, filter3, hll3) + fmt.Printf("%s\t%d\t%d\t%d\t%d\n", at.Format(timeFmt), hll.Count(), hll1.Count(), hll2.Count(), hll3.Count()) + } +} + +const RequestURIName = "q" +const separatorRange = "~" +const separatorAdd = " " +const separatorServer = "@" +const jsonErrMalformedRequest = `{"status":"error","error":"malformed request uri"}` +const jsonErrBlankRequest = `{"status":"error","error":"no queries given"}` +const statusError = "error" +const statusPartial = "partial" +const statusOk = "ok" + +type apiResponse struct { + Status string `json:"status"` + Responses []requestResponse `json:"resp"` +} + +type requestResponse struct { + Status string `json:"status"` + Request string `json:"req"` + Error string `json:"error,omitempty"` + Count uint64 `json:"count,omitempty"` +} + +func ServeAPIGet(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + u, err := url.ParseRequestURI(r.RequestURI) + if err != nil { + w.WriteHeader(400) + fmt.Fprint(w, jsonErrMalformedRequest) + return + } + + query := u.Query() + reqCount := len(query[RequestURIName]) + if reqCount == 0 { + w.WriteHeader(400) + fmt.Fprint(w, jsonErrBlankRequest) + return + } + + resp := apiResponse{Status: statusOk} + resp.Responses = make([]requestResponse, reqCount) + for i, v := range query[RequestURIName] { + if len(v) == 0 { + continue + } + resp.Responses[i] = ProcessSingleGetRequest(v) + } + for _, v := range resp.Responses { + if v.Status != statusOk { + resp.Status = statusPartial + break + } + } + + w.WriteHeader(200) + enc := json.NewEncoder(w) + enc.Encode(resp) +} + +var errRangeFormatIncorrect = errors.New("incorrect range format, must be yyyy-mm-dd~yyyy-mm-dd") + +// ProcessSingleGetRequest takes a request string and pulls the unique user data for the given dates and filters. +// +// The request string is in the following format: +// +// Request = AddDateRanges [ "@" ServerFilter ] . +// ServerFilter = [ "!" ] ServerName { " " ServerName } . +// ServerName = { "a" … "z" } . +// AddDateRanges = DateMaybeRange { " " DateMaybeRange } . +// DateMaybeRange = DateRange | Date . +// DateRange = Date "~" Date . +// Date = Year "-" Month "-" Day . +// Year = number number number number . +// Month = number number . +// Day = number number . +// number = "0" … "9" . +// +// Example of a well-formed request: +// +// 2016-01-04~2016-01-08 2016-01-11~2016-01-15@andknuckles tuturu +// +// Remember that spaces are urlencoded as "+", so the HTTP request to send to retrieve that data would be this: +// +// /api/get?q=2016-01-04~2016-01-08+2016-01-11~2016-01-15%40andknuckles+tuturu +// +// If a ServerFilter is specified, only users connecting to the specified servers will be included in the count. +// +// It does not matter if a date is specified multiple times, due to the data format used. +func ProcessSingleGetRequest(req string) (result requestResponse) { + fmt.Println("processing request:", req) + hll, _ := hyperloglog.NewPlus(server.CounterPrecision) + + result.Request = req + result.Status = statusOk + filter := serverFilterAll + + collectError := func(err error) bool { + if err == ErrServerInFailedState { + result.Status = statusPartial + return false + } else if err != nil { + result.Status = statusError + result.Error = err.Error() + return true + } + return false + } + + serverSplit := strings.Split(req, separatorServer) + if len(serverSplit) == 2 { + filter = serverFilterNone + serversOnly := strings.Split(serverSplit[1], separatorAdd) + for _, v := range serversOnly { + filter.Add(v) + } + } + + addSplit := strings.Split(serverSplit[0], separatorAdd) + +outerLoop: + for _, split1 := range addSplit { + if len(split1) == 0 { + continue + } + + rangeSplit := strings.Split(split1, separatorRange) + if len(rangeSplit) == 1 { + at, err := parseDateFromRequest(rangeSplit[0]) + if collectError(err) { + break outerLoop + } + + err = addSingleDate(at, filter, hll) + if collectError(err) { + break outerLoop + } + } else if len(rangeSplit) == 2 { + from, err := parseDateFromRequest(rangeSplit[0]) + if collectError(err) { + break outerLoop + } + to, err := parseDateFromRequest(rangeSplit[1]) + if collectError(err) { + break outerLoop + } + + err = addRange(from, to, filter, hll) + if collectError(err) { + break outerLoop + } + } else { + collectError(errRangeFormatIncorrect) + break outerLoop + } + } + + if result.Status == statusOk { + result.Count = hll.Count() + } + return result +} + +var errBadDate = errors.New("bad date format, must be yyyy-mm-dd") +var zeroTime = time.Unix(0, 0) + +func parseDateFromRequest(dateStr string) (time.Time, error) { + var year, month, day int + n, err := fmt.Sscanf(dateStr, "%d-%d-%d", &year, &month, &day) + if err != nil || n != 3 { + return zeroTime, errBadDate + } + return time.Date(year, time.Month(month), day, 0, 0, 0, 0, server.CounterLocation), nil +} + +type hllAndError struct { + hll *hyperloglog.HyperLogLogPlus + err error +} + +func addSingleDate(at time.Time, filter serverFilter, dest *hyperloglog.HyperLogLogPlus) error { + var partialErr error + for _, si := range allServers { + if filter.IsServerAllowed(si) { + hll, err2 := si.GetHLL(at) + if err2 == ErrServerInFailedState { + partialErr = err2 + } else if err2 != nil { + return err2 + } else { + dest.Merge(hll) + } + } + } + return partialErr +} + +func addRange(start time.Time, end time.Time, filter serverFilter, dest *hyperloglog.HyperLogLogPlus) error { + end = server.TruncateToMidnight(end) + year, month, day := start.Date() + var partialErr error + var myAllServers = make([]*serverInfo, 0, len(allServers)) + for _, si := range allServers { + if filter.IsServerAllowed(si) { + myAllServers = append(myAllServers, si) + } + } + + var ch = make(chan hllAndError) + var wg sync.WaitGroup + for current := start; current.Before(end); day = day + 1 { + current = time.Date(year, month, day, 0, 0, 0, 0, server.CounterLocation) + for _, si := range myAllServers { + wg.Add(1) + go getHLL(ch, si, current) + } + } + + go func() { + wg.Wait() + close(ch) + }() + + for pair := range ch { + wg.Done() + hll, err := pair.hll, pair.err + if err != nil { + if partialErr == nil || partialErr == ErrServerInFailedState { + partialErr = err + } + } else { + dest.Merge(hll) + } + } + + return partialErr +} + +func getHLL(ch chan hllAndError, si *serverInfo, at time.Time) { + hll, err := si.GetHLL(at) + ch <- hllAndError{hll: hll, err: err} +} diff --git a/socketserver/cmd/statsweb/webroot/cal_entry.hbs b/socketserver/cmd/statsweb/webroot/cal_entry.hbs new file mode 100644 index 00000000..cf178cf3 --- /dev/null +++ b/socketserver/cmd/statsweb/webroot/cal_entry.hbs @@ -0,0 +1,6 @@ + + {{.Date}} + {{if not .NoData}} + {{.UniqUsers}} + {{end}} + diff --git a/socketserver/cmd/statsweb/webroot/calendar.hbs b/socketserver/cmd/statsweb/webroot/calendar.hbs new file mode 100644 index 00000000..1a8070ce --- /dev/null +++ b/socketserver/cmd/statsweb/webroot/calendar.hbs @@ -0,0 +1,18 @@ + + + + + + + + + + + + {{range .Weeks}} + {{range .Days}} + {{template "cal_entry"}} + {{end}} + {{end}} + +
SundayMondayTuesdayWednesdayThursdayFridaySaturday
\ No newline at end of file diff --git a/socketserver/cmd/statsweb/webroot/layout.template.html b/socketserver/cmd/statsweb/webroot/layout.template.html new file mode 100644 index 00000000..09c24acf --- /dev/null +++ b/socketserver/cmd/statsweb/webroot/layout.template.html @@ -0,0 +1,15 @@ + + + + + Socket Server Stats Dashboard + + + +
+ {{template "content"}} +
+ + diff --git a/socketserver/internal/server/backend.go b/socketserver/internal/server/backend.go deleted file mode 100644 index 72d74b22..00000000 --- a/socketserver/internal/server/backend.go +++ /dev/null @@ -1,234 +0,0 @@ -package server - -import ( - "crypto/rand" - "encoding/base64" - "encoding/json" - "fmt" - "github.com/pmylund/go-cache" - "golang.org/x/crypto/nacl/box" - "io/ioutil" - "log" - "net/http" - "net/url" - "strconv" - "strings" - "sync" - "time" -) - -var backendHttpClient http.Client -var backendUrl string -var responseCache *cache.Cache - -var getBacklogUrl string - -var backendSharedKey [32]byte -var serverId int - -var messageBufferPool sync.Pool - -func SetupBackend(config *ConfigFile) { - backendHttpClient.Timeout = 60 * time.Second - backendUrl = config.BackendUrl - if responseCache != nil { - responseCache.Flush() - } - responseCache = cache.New(60*time.Second, 120*time.Second) - - getBacklogUrl = fmt.Sprintf("%s/backlog", backendUrl) - - messageBufferPool.New = New4KByteBuffer - - var theirPublic, ourPrivate [32]byte - copy(theirPublic[:], config.BackendPublicKey) - copy(ourPrivate[:], config.OurPrivateKey) - serverId = config.ServerId - - box.Precompute(&backendSharedKey, &theirPublic, &ourPrivate) -} - -func getCacheKey(remoteCommand, data string) string { - return fmt.Sprintf("%s/%s", remoteCommand, data) -} - -// Publish a message to clients with no caching. -// The scope must be specified because no attempt is made to recognize the command. -func HBackendPublishRequest(w http.ResponseWriter, r *http.Request) { - r.ParseForm() - formData, err := UnsealRequest(r.Form) - if err != nil { - w.WriteHeader(403) - fmt.Fprintf(w, "Error: %v", err) - return - } - - cmd := formData.Get("cmd") - json := formData.Get("args") - channel := formData.Get("channel") - scope := formData.Get("scope") - - target := MessageTargetTypeByName(scope) - - if cmd == "" { - w.WriteHeader(422) - fmt.Fprintf(w, "Error: cmd cannot be blank") - return - } - if channel == "" && (target == MsgTargetTypeChat || target == MsgTargetTypeMultichat) { - w.WriteHeader(422) - fmt.Fprintf(w, "Error: channel must be specified") - return - } - - cm := ClientMessage{MessageID: -1, Command: Command(cmd), origArguments: json} - cm.parseOrigArguments() - var count int - - switch target { - case MsgTargetTypeSingle: - // TODO - case MsgTargetTypeChat: - count = PublishToChat(channel, cm) - case MsgTargetTypeMultichat: - // TODO - case MsgTargetTypeGlobal: - count = PublishToAll(cm) - case MsgTargetTypeInvalid: - default: - w.WriteHeader(422) - fmt.Fprint(w, "Invalid 'scope'. must be single, chat, multichat, channel, or global") - return - } - fmt.Fprint(w, count) -} - -func RequestRemoteDataCached(remoteCommand, data string, auth AuthInfo) (string, error) { - cached, ok := responseCache.Get(getCacheKey(remoteCommand, data)) - if ok { - return cached.(string), nil - } - return RequestRemoteData(remoteCommand, data, auth) -} - -func RequestRemoteData(remoteCommand, data string, auth AuthInfo) (responseStr string, err error) { - destUrl := fmt.Sprintf("%s/cmd/%s", backendUrl, remoteCommand) - var authKey string - if auth.UsernameValidated { - authKey = "usernameClaimed" - } else { - authKey = "username" - } - - formData := url.Values{ - "clientData": []string{data}, - authKey: []string{auth.TwitchUsername}, - } - - sealedForm, err := SealRequest(formData) - if err != nil { - return "", err - } - - resp, err := backendHttpClient.PostForm(destUrl, sealedForm) - if err != nil { - return "", err - } - - respBytes, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return "", err - } - - responseStr = string(respBytes) - - if resp.Header.Get("FFZ-Cache") != "" { - durSecs, err := strconv.ParseInt(resp.Header.Get("FFZ-Cache"), 10, 64) - if err != nil { - return "", fmt.Errorf("The RPC server returned a non-integer cache duration: %v", err) - } - duration := time.Duration(durSecs) * time.Second - responseCache.Set(getCacheKey(remoteCommand, data), responseStr, duration) - } - - return -} - -func FetchBacklogData(chatSubs []string) ([]ClientMessage, error) { - formData := url.Values{ - "subs": chatSubs, - } - - sealedForm, err := SealRequest(formData) - if err != nil { - return nil, err - } - - resp, err := backendHttpClient.PostForm(getBacklogUrl, sealedForm) - if err != nil { - return nil, err - } - dec := json.NewDecoder(resp.Body) - var messages []ClientMessage - err = dec.Decode(messages) - if err != nil { - return nil, err - } - - return messages, nil -} - -func GenerateKeys(outputFile, serverId, theirPublicStr string) { - var err error - output := ConfigFile{ - ListenAddr: "0.0.0.0:8001", - SocketOrigin: "localhost:8001", - BackendUrl: "http://localhost:8002/ffz", - BannerHTML: ` - -CatBag - -
-
-
-
-
-
- A FrankerFaceZ Service - — CatBag by Wolsk -
-
-`, - } - - output.ServerId, err = strconv.Atoi(serverId) - if err != nil { - log.Fatal(err) - } - - ourPublic, ourPrivate, err := box.GenerateKey(rand.Reader) - if err != nil { - log.Fatal(err) - } - output.OurPublicKey, output.OurPrivateKey = ourPublic[:], ourPrivate[:] - - if theirPublicStr != "" { - reader := base64.NewDecoder(base64.StdEncoding, strings.NewReader(theirPublicStr)) - theirPublic, err := ioutil.ReadAll(reader) - if err != nil { - log.Fatal(err) - } - output.BackendPublicKey = theirPublic - } - - bytes, err := json.MarshalIndent(output, "", "\t") - if err != nil { - log.Fatal(err) - } - fmt.Println(string(bytes)) - err = ioutil.WriteFile(outputFile, bytes, 0600) - if err != nil { - log.Fatal(err) - } -} diff --git a/socketserver/internal/server/backend_test.go b/socketserver/internal/server/backend_test.go deleted file mode 100644 index 7043d9f3..00000000 --- a/socketserver/internal/server/backend_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package server - -import ( - "crypto/rand" - "golang.org/x/crypto/nacl/box" - "net/url" - "testing" -) - -func SetupRandomKeys(t testing.TB) { - _, senderPrivate, err := box.GenerateKey(rand.Reader) - if err != nil { - t.Fatal(err) - } - receiverPublic, _, err := box.GenerateKey(rand.Reader) - if err != nil { - t.Fatal(err) - } - - box.Precompute(&backendSharedKey, receiverPublic, senderPrivate) - messageBufferPool.New = New4KByteBuffer -} - -func TestSealRequest(t *testing.T) { - SetupRandomKeys(t) - - values := url.Values{ - "QuickBrownFox": []string{"LazyDog"}, - } - - sealedValues, err := SealRequest(values) - if err != nil { - t.Fatal(err) - } - // sealedValues.Encode() - // id=0&msg=KKtbng49dOLLyjeuX5AnXiEe6P0uZwgeP_7mMB5vhP-wMAAPZw%3D%3D&nonce=-wRbUnifscisWUvhm3gBEXHN5QzrfzgV - - unsealedValues, err := UnsealRequest(sealedValues) - if err != nil { - t.Fatal(err) - } - - if unsealedValues.Get("QuickBrownFox") != "LazyDog" { - t.Errorf("Failed to round-trip, got back %v", unsealedValues) - } -} diff --git a/socketserver/internal/server/backlog.go b/socketserver/internal/server/backlog.go deleted file mode 100644 index 66f09e93..00000000 --- a/socketserver/internal/server/backlog.go +++ /dev/null @@ -1,364 +0,0 @@ -package server - -import ( - "errors" - "fmt" - "net/http" - "sort" - "strconv" - "strings" - "sync" - "time" -) - -type PushCommandCacheInfo struct { - Caching BacklogCacheType - Target MessageTargetType -} - -// this value is just docs right now -var ServerInitiatedCommands = map[Command]PushCommandCacheInfo{ - /// Global updates & notices - "update_news": {CacheTypeTimestamps, MsgTargetTypeGlobal}, // timecache:global - "message": {CacheTypeTimestamps, MsgTargetTypeGlobal}, // timecache:global - "reload_ff": {CacheTypeTimestamps, MsgTargetTypeGlobal}, // timecache:global - - /// Emote updates - "reload_badges": {CacheTypeTimestamps, MsgTargetTypeGlobal}, // timecache:global - "set_badge": {CacheTypeTimestamps, MsgTargetTypeMultichat}, // timecache:multichat - "reload_set": {}, // timecache:multichat - "load_set": {}, // TODO what are the semantics of this? - - /// User auth - "do_authorize": {CacheTypeNever, MsgTargetTypeSingle}, // nocache:single - - /// Channel data - // follow_sets: extra emote sets included in the chat - // follow_buttons: extra follow buttons below the stream - "follow_sets": {CacheTypePersistent, MsgTargetTypeChat}, // mustcache:chat - "follow_buttons": {CacheTypePersistent, MsgTargetTypeChat}, // mustcache:watching - "srl_race": {CacheTypeLastOnly, MsgTargetTypeChat}, // cachelast:watching - - /// Chatter/viewer counts - "chatters": {CacheTypeLastOnly, MsgTargetTypeChat}, // cachelast:watching - "viewers": {CacheTypeLastOnly, MsgTargetTypeChat}, // cachelast:watching -} - -type BacklogCacheType int - -const ( - // This is not a cache type. - CacheTypeInvalid BacklogCacheType = iota - // This message cannot be cached. - CacheTypeNever - // Save the last 24 hours of this message. - // If a client indicates that it has reconnected, replay the messages sent after the disconnect. - // Do not replay if the client indicates that this is a firstload. - CacheTypeTimestamps - // Save only the last copy of this message, and always send it when the backlog is requested. - CacheTypeLastOnly - // Save this backlog data to disk with its timestamp. - // Send it when the backlog is requested, or after a reconnect if it was updated. - CacheTypePersistent -) - -type MessageTargetType int - -const ( - // This is not a message target. - MsgTargetTypeInvalid MessageTargetType = iota - // This message is targeted to a single TODO(user or connection) - MsgTargetTypeSingle - // This message is targeted to all users in a chat - MsgTargetTypeChat - // This message is targeted to all users in multiple chats - MsgTargetTypeMultichat - // This message is sent to all FFZ users. - MsgTargetTypeGlobal -) - -// note: see types.go for methods on these - -// Returned by BacklogCacheType.UnmarshalJSON() -var ErrorUnrecognizedCacheType = errors.New("Invalid value for cachetype") - -// Returned by MessageTargetType.UnmarshalJSON() -var ErrorUnrecognizedTargetType = errors.New("Invalid value for message target") - -type TimestampedGlobalMessage struct { - Timestamp time.Time - Command Command - Data string -} - -type TimestampedMultichatMessage struct { - Timestamp time.Time - Channels []string - Command Command - Data string -} - -type LastSavedMessage struct { - Timestamp time.Time - Data string -} - -// map is command -> channel -> data - -// CacheTypeLastOnly. Cleaned up by reaper goroutine every ~hour. -var CachedLastMessages map[Command]map[string]LastSavedMessage -var CachedLSMLock sync.RWMutex - -// CacheTypePersistent. Never cleaned. -var PersistentLastMessages map[Command]map[string]LastSavedMessage -var PersistentLSMLock sync.RWMutex - -var CachedGlobalMessages []TimestampedGlobalMessage -var CachedChannelMessages []TimestampedMultichatMessage -var CacheListsLock sync.RWMutex - -func DumpCache() { - CachedLSMLock.Lock() - CachedLastMessages = make(map[Command]map[string]LastSavedMessage) - CachedLSMLock.Unlock() - - PersistentLSMLock.Lock() - PersistentLastMessages = make(map[Command]map[string]LastSavedMessage) - // TODO delete file? - PersistentLSMLock.Unlock() - - CacheListsLock.Lock() - CachedGlobalMessages = make(tgmarray, 0) - CachedChannelMessages = make(tmmarray, 0) - CacheListsLock.Unlock() -} - -func SendBacklogForNewClient(client *ClientInfo) { - client.Mutex.Lock() // reading CurrentChannels - PersistentLSMLock.RLock() - for _, cmd := range GetCommandsOfType(PushCommandCacheInfo{CacheTypePersistent, MsgTargetTypeChat}) { - chanMap := CachedLastMessages[cmd] - if chanMap == nil { - continue - } - for _, channel := range client.CurrentChannels { - msg, ok := chanMap[channel] - if ok { - msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data} - msg.parseOrigArguments() - client.MessageChannel <- msg - } - } - } - PersistentLSMLock.RUnlock() - - CachedLSMLock.RLock() - for _, cmd := range GetCommandsOfType(PushCommandCacheInfo{CacheTypeLastOnly, MsgTargetTypeChat}) { - chanMap := CachedLastMessages[cmd] - if chanMap == nil { - continue - } - for _, channel := range client.CurrentChannels { - msg, ok := chanMap[channel] - if ok { - msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data} - msg.parseOrigArguments() - client.MessageChannel <- msg - } - } - } - CachedLSMLock.RUnlock() - client.Mutex.Unlock() -} - -func SendTimedBacklogMessages(client *ClientInfo, disconnectTime time.Time) { - client.Mutex.Lock() // reading CurrentChannels - CacheListsLock.RLock() - - globIdx := FindFirstNewMessage(tgmarray(CachedGlobalMessages), disconnectTime) - - for i := globIdx; i < len(CachedGlobalMessages); i++ { - item := CachedGlobalMessages[i] - msg := ClientMessage{MessageID: -1, Command: item.Command, origArguments: item.Data} - msg.parseOrigArguments() - client.MessageChannel <- msg - } - - chanIdx := FindFirstNewMessage(tmmarray(CachedChannelMessages), disconnectTime) - - for i := chanIdx; i < len(CachedChannelMessages); i++ { - item := CachedChannelMessages[i] - var send bool - for _, channel := range item.Channels { - for _, matchChannel := range client.CurrentChannels { - if channel == matchChannel { - send = true - break - } - } - if send { - break - } - } - if send { - msg := ClientMessage{MessageID: -1, Command: item.Command, origArguments: item.Data} - msg.parseOrigArguments() - client.MessageChannel <- msg - } - } - - CacheListsLock.RUnlock() - client.Mutex.Unlock() -} - -func InsertionSort(ary sort.Interface) { - for i := 1; i < ary.Len(); i++ { - for j := i; j > 0 && ary.Less(j, j-1); j-- { - ary.Swap(j, j-1) - } - } -} - -type TimestampArray interface { - Len() int - GetTime(int) time.Time -} - -func FindFirstNewMessage(ary TimestampArray, disconnectTime time.Time) (idx int) { - // TODO needs tests - len := ary.Len() - i := len - - // Walk backwards until we find GetTime() before disconnectTime - step := 1 - for i > 0 { - i -= step - if i < 0 { - i = 0 - } - if !ary.GetTime(i).After(disconnectTime) { - break - } - step = int(float64(step)*1.5) + 1 - } - - // Walk forwards until we find GetTime() after disconnectTime - for i < len && !ary.GetTime(i).After(disconnectTime) { - i++ - } - - if i == len { - return -1 - } - return i -} - -func SaveLastMessage(which map[Command]map[string]LastSavedMessage, locker sync.Locker, cmd Command, channel string, timestamp time.Time, data string, deleting bool) { - locker.Lock() - defer locker.Unlock() - - chanMap, ok := CachedLastMessages[cmd] - if !ok { - if deleting { - return - } - chanMap = make(map[string]LastSavedMessage) - CachedLastMessages[cmd] = chanMap - } - - if deleting { - delete(chanMap, channel) - } else { - chanMap[channel] = LastSavedMessage{timestamp, data} - } -} - -func SaveGlobalMessage(cmd Command, timestamp time.Time, data string) { - CacheListsLock.Lock() - CachedGlobalMessages = append(CachedGlobalMessages, TimestampedGlobalMessage{timestamp, cmd, data}) - InsertionSort(tgmarray(CachedGlobalMessages)) - CacheListsLock.Unlock() -} - -func SaveMultichanMessage(cmd Command, channels string, timestamp time.Time, data string) { - CacheListsLock.Lock() - CachedChannelMessages = append(CachedChannelMessages, TimestampedMultichatMessage{timestamp, strings.Split(channels, ","), cmd, data}) - InsertionSort(tmmarray(CachedChannelMessages)) - CacheListsLock.Unlock() -} - -func GetCommandsOfType(match PushCommandCacheInfo) []Command { - var ret []Command - for cmd, info := range ServerInitiatedCommands { - if info == match { - ret = append(ret, cmd) - } - } - return ret -} - -func HBackendDumpBacklog(w http.ResponseWriter, r *http.Request) { - r.ParseForm() - formData, err := UnsealRequest(r.Form) - if err != nil { - w.WriteHeader(403) - fmt.Fprintf(w, "Error: %v", err) - return - } - - confirm := formData.Get("confirm") - if confirm == "1" { - DumpCache() - } -} - -// Publish a message to clients, and update the in-server cache for the message. -// notes: -// `scope` is implicit in the command -func HBackendUpdateAndPublish(w http.ResponseWriter, r *http.Request) { - r.ParseForm() - formData, err := UnsealRequest(r.Form) - if err != nil { - w.WriteHeader(403) - fmt.Fprintf(w, "Error: %v", err) - return - } - - cmd := Command(formData.Get("cmd")) - json := formData.Get("args") - channel := formData.Get("channel") - deleteMode := formData.Get("delete") != "" - timeStr := formData.Get("time") - timestamp, err := time.Parse(time.UnixDate, timeStr) - if err != nil { - w.WriteHeader(422) - fmt.Fprintf(w, "error parsing time: %v", err) - } - - cacheinfo, ok := ServerInitiatedCommands[cmd] - if !ok { - w.WriteHeader(422) - fmt.Fprintf(w, "Caching semantics unknown for command '%s'. Post to /addcachedcommand first.") - return - } - - var count int - msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: json} - msg.parseOrigArguments() - - if cacheinfo.Caching == CacheTypeLastOnly && cacheinfo.Target == MsgTargetTypeChat { - SaveLastMessage(CachedLastMessages, &CachedLSMLock, cmd, channel, timestamp, json, deleteMode) - count = PublishToChat(channel, msg) - } else if cacheinfo.Caching == CacheTypePersistent && cacheinfo.Target == MsgTargetTypeChat { - SaveLastMessage(PersistentLastMessages, &PersistentLSMLock, cmd, channel, timestamp, json, deleteMode) - count = PublishToChat(channel, msg) - } else if cacheinfo.Caching == CacheTypeTimestamps && cacheinfo.Target == MsgTargetTypeMultichat { - SaveMultichanMessage(cmd, channel, timestamp, json) - count = PublishToMultiple(strings.Split(channel, ","), msg) - } else if cacheinfo.Caching == CacheTypeTimestamps && cacheinfo.Target == MsgTargetTypeGlobal { - SaveGlobalMessage(cmd, timestamp, json) - count = PublishToAll(msg) - } - - w.Write([]byte(strconv.Itoa(count))) -} diff --git a/socketserver/internal/server/backlog_test.go b/socketserver/internal/server/backlog_test.go deleted file mode 100644 index 68757587..00000000 --- a/socketserver/internal/server/backlog_test.go +++ /dev/null @@ -1,76 +0,0 @@ -package server - -import ( - "testing" - "time" -) - -func TestFindFirstNewMessageEmpty(t *testing.T) { - CachedGlobalMessages = []TimestampedGlobalMessage{} - i := FindFirstNewMessage(tgmarray(CachedGlobalMessages), time.Unix(10, 0)) - if i != -1 { - t.Errorf("Expected -1, got %d", i) - } -} -func TestFindFirstNewMessageOneBefore(t *testing.T) { - CachedGlobalMessages = []TimestampedGlobalMessage{ - {Timestamp: time.Unix(8, 0)}, - } - i := FindFirstNewMessage(tgmarray(CachedGlobalMessages), time.Unix(10, 0)) - if i != -1 { - t.Errorf("Expected -1, got %d", i) - } -} -func TestFindFirstNewMessageSeveralBefore(t *testing.T) { - CachedGlobalMessages = []TimestampedGlobalMessage{ - {Timestamp: time.Unix(1, 0)}, - {Timestamp: time.Unix(2, 0)}, - {Timestamp: time.Unix(3, 0)}, - {Timestamp: time.Unix(4, 0)}, - {Timestamp: time.Unix(5, 0)}, - } - i := FindFirstNewMessage(tgmarray(CachedGlobalMessages), time.Unix(10, 0)) - if i != -1 { - t.Errorf("Expected -1, got %d", i) - } -} -func TestFindFirstNewMessageInMiddle(t *testing.T) { - CachedGlobalMessages = []TimestampedGlobalMessage{ - {Timestamp: time.Unix(1, 0)}, - {Timestamp: time.Unix(2, 0)}, - {Timestamp: time.Unix(3, 0)}, - {Timestamp: time.Unix(4, 0)}, - {Timestamp: time.Unix(5, 0)}, - {Timestamp: time.Unix(11, 0)}, - {Timestamp: time.Unix(12, 0)}, - {Timestamp: time.Unix(13, 0)}, - {Timestamp: time.Unix(14, 0)}, - {Timestamp: time.Unix(15, 0)}, - } - i := FindFirstNewMessage(tgmarray(CachedGlobalMessages), time.Unix(10, 0)) - if i != 5 { - t.Errorf("Expected 5, got %d", i) - } -} -func TestFindFirstNewMessageOneAfter(t *testing.T) { - CachedGlobalMessages = []TimestampedGlobalMessage{ - {Timestamp: time.Unix(15, 0)}, - } - i := FindFirstNewMessage(tgmarray(CachedGlobalMessages), time.Unix(10, 0)) - if i != 0 { - t.Errorf("Expected 0, got %d", i) - } -} -func TestFindFirstNewMessageSeveralAfter(t *testing.T) { - CachedGlobalMessages = []TimestampedGlobalMessage{ - {Timestamp: time.Unix(11, 0)}, - {Timestamp: time.Unix(12, 0)}, - {Timestamp: time.Unix(13, 0)}, - {Timestamp: time.Unix(14, 0)}, - {Timestamp: time.Unix(15, 0)}, - } - i := FindFirstNewMessage(tgmarray(CachedGlobalMessages), time.Unix(10, 0)) - if i != 0 { - t.Errorf("Expected 0, got %d", i) - } -} diff --git a/socketserver/internal/server/commands.go b/socketserver/internal/server/commands.go deleted file mode 100644 index c947bf08..00000000 --- a/socketserver/internal/server/commands.go +++ /dev/null @@ -1,285 +0,0 @@ -package server - -import ( - "fmt" - "github.com/satori/go.uuid" - "golang.org/x/net/websocket" - "log" - "strconv" - "sync" - "time" -) - -var ResponseSuccess = ClientMessage{Command: SuccessCommand} -var ResponseFailure = ClientMessage{Command: "False"} - -const ChannelInfoDelay = 2 * time.Second - -func HandleCommand(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) { - handler, ok := CommandHandlers[msg.Command] - if !ok { - log.Println("[!] Unknown command", msg.Command, "- sent by client", client.ClientID, "@", conn.RemoteAddr()) - FFZCodec.Send(conn, ClientMessage{ - MessageID: msg.MessageID, - Command: "error", - Arguments: fmt.Sprintf("Unknown command %s", msg.Command), - }) - return - } - - response, err := CallHandler(handler, conn, client, msg) - - if err == nil { - if response.Command == AsyncResponseCommand { - // Don't send anything - // The response will be delivered over client.MessageChannel / serverMessageChan - } else { - response.MessageID = msg.MessageID - FFZCodec.Send(conn, response) - } - } else { - FFZCodec.Send(conn, ClientMessage{ - MessageID: msg.MessageID, - Command: "error", - Arguments: err.Error(), - }) - } -} - -func HandleHello(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { - version, clientId, err := msg.ArgumentsAsTwoStrings() - if err != nil { - return - } - - client.Version = version - client.ClientID = uuid.FromStringOrNil(clientId) - if client.ClientID == uuid.Nil { - client.ClientID = uuid.NewV4() - } - - SubscribeGlobal(client) - - return ClientMessage{ - Arguments: client.ClientID.String(), - }, nil -} - -func HandleReady(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { - disconnectAt, err := msg.ArgumentsAsInt() - if err != nil { - return - } - - client.Mutex.Lock() - if client.MakePendingRequests != nil { - if !client.MakePendingRequests.Stop() { - // Timer already fired, GetSubscriptionBacklog() has started - rmsg.Command = SuccessCommand - return - } - } - client.PendingSubscriptionsBacklog = nil - client.MakePendingRequests = nil - client.Mutex.Unlock() - - if disconnectAt == 0 { - // backlog only - go func() { - client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand} - SendBacklogForNewClient(client) - }() - return ClientMessage{Command: AsyncResponseCommand}, nil - } else { - // backlog and timed - go func() { - client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand} - SendBacklogForNewClient(client) - SendTimedBacklogMessages(client, time.Unix(disconnectAt, 0)) - }() - return ClientMessage{Command: AsyncResponseCommand}, nil - } -} - -func HandleSetUser(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { - username, err := msg.ArgumentsAsString() - if err != nil { - return - } - - client.Mutex.Lock() - client.TwitchUsername = username - client.UsernameValidated = false - client.Mutex.Unlock() - - return ResponseSuccess, nil -} - -func HandleSub(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { - channel, err := msg.ArgumentsAsString() - - if err != nil { - return - } - - client.Mutex.Lock() - - AddToSliceS(&client.CurrentChannels, channel) - client.PendingSubscriptionsBacklog = append(client.PendingSubscriptionsBacklog, channel) - - if client.MakePendingRequests == nil { - client.MakePendingRequests = time.AfterFunc(ChannelInfoDelay, GetSubscriptionBacklogFor(conn, client)) - } else { - if !client.MakePendingRequests.Reset(ChannelInfoDelay) { - client.MakePendingRequests = time.AfterFunc(ChannelInfoDelay, GetSubscriptionBacklogFor(conn, client)) - } - } - - client.Mutex.Unlock() - - SubscribeChat(client, channel) - - return ResponseSuccess, nil -} - -func HandleUnsub(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { - channel, err := msg.ArgumentsAsString() - - if err != nil { - return - } - - client.Mutex.Lock() - RemoveFromSliceS(&client.CurrentChannels, channel) - client.Mutex.Unlock() - - UnsubscribeSingleChat(client, channel) - - return ResponseSuccess, nil -} - -func GetSubscriptionBacklogFor(conn *websocket.Conn, client *ClientInfo) func() { - return func() { - GetSubscriptionBacklog(conn, client) - } -} - -// On goroutine -func GetSubscriptionBacklog(conn *websocket.Conn, client *ClientInfo) { - var subs []string - - // Lock, grab the data, and reset it - client.Mutex.Lock() - subs = client.PendingSubscriptionsBacklog - client.PendingSubscriptionsBacklog = nil - client.MakePendingRequests = nil - client.Mutex.Unlock() - - if len(subs) == 0 { - return - } - - if backendUrl == "" { - return // for testing runs - } - messages, err := FetchBacklogData(subs) - - if err != nil { - // Oh well. - log.Print("error in GetSubscriptionBacklog:", err) - return - } - - // Deliver to client - for _, msg := range messages { - client.MessageChannel <- msg - } -} - -type SurveySubmission struct { - User string - Json string -} - -var SurveySubmissions []SurveySubmission -var SurveySubmissionLock sync.Mutex - -func HandleSurvey(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { - SurveySubmissionLock.Lock() - SurveySubmissions = append(SurveySubmissions, SurveySubmission{client.TwitchUsername, msg.origArguments}) - SurveySubmissionLock.Unlock() - - return ResponseSuccess, nil -} - -type FollowEvent struct { - User string - Channel string - NowFollowing bool - Timestamp time.Time -} - -var FollowEvents []FollowEvent -var FollowEventsLock sync.Mutex - -func HandleTrackFollow(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { - channel, following, err := msg.ArgumentsAsStringAndBool() - if err != nil { - return - } - now := time.Now() - - FollowEventsLock.Lock() - FollowEvents = append(FollowEvents, FollowEvent{client.TwitchUsername, channel, following, now}) - FollowEventsLock.Unlock() - - return ResponseSuccess, nil -} - -var AggregateEmoteUsage map[int]map[string]int = make(map[int]map[string]int) -var AggregateEmoteUsageLock sync.Mutex - -func HandleEmoticonUses(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { - // arguments is [1]map[EmoteId]map[RoomName]float64 - - mapRoot := msg.Arguments.([]interface{})[0].(map[string]interface{}) - - AggregateEmoteUsageLock.Lock() - defer AggregateEmoteUsageLock.Unlock() - - for strEmote, val1 := range mapRoot { - var emoteId int - emoteId, err = strconv.Atoi(strEmote) - if err != nil { - return - } - - destMapInner, ok := AggregateEmoteUsage[emoteId] - if !ok { - destMapInner = make(map[string]int) - AggregateEmoteUsage[emoteId] = destMapInner - } - - mapInner := val1.(map[string]interface{}) - for roomName, val2 := range mapInner { - var count int = int(val2.(float64)) - destMapInner[roomName] += count - } - } - - return ResponseSuccess, nil -} - -func HandleRemoteCommand(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { - go func(conn *websocket.Conn, msg ClientMessage, authInfo AuthInfo) { - resp, err := RequestRemoteDataCached(string(msg.Command), msg.origArguments, authInfo) - - if err != nil { - FFZCodec.Send(conn, ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: err.Error()}) - } else { - FFZCodec.Send(conn, ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand, origArguments: resp}) - } - }(conn, msg, client.AuthInfo) - - return ClientMessage{Command: AsyncResponseCommand}, nil -} diff --git a/socketserver/internal/server/handlecore.go b/socketserver/internal/server/handlecore.go deleted file mode 100644 index f227db8b..00000000 --- a/socketserver/internal/server/handlecore.go +++ /dev/null @@ -1,434 +0,0 @@ -package server // import "bitbucket.org/stendec/frankerfacez/socketserver/internal/server" - -import ( - "crypto/tls" - "encoding/json" - "errors" - "fmt" - "golang.org/x/net/websocket" - "log" - "net/http" - "strconv" - "strings" - "sync" -) - -const MAX_PACKET_SIZE = 1024 - -// A command is how the client refers to a function on the server. It's just a string. -type Command string - -// A function that is called to respond to a Command. -type CommandHandler func(*websocket.Conn, *ClientInfo, ClientMessage) (ClientMessage, error) - -var CommandHandlers = map[Command]CommandHandler{ - HelloCommand: HandleHello, - "setuser": HandleSetUser, - "ready": HandleReady, - - "sub": HandleSub, - "unsub": HandleUnsub, - - "track_follow": HandleTrackFollow, - "emoticon_uses": HandleEmoticonUses, - "survey": HandleSurvey, - - "twitch_emote": HandleRemoteCommand, - "get_link": HandleRemoteCommand, - "get_display_name": HandleRemoteCommand, - "update_follow_buttons": HandleRemoteCommand, - "chat_history": HandleRemoteCommand, -} - -// Sent by the server in ClientMessage.Command to indicate success. -const SuccessCommand Command = "True" - -// Sent by the server in ClientMessage.Command to indicate failure. -const ErrorCommand Command = "error" - -// This must be the first command sent by the client once the connection is established. -const HelloCommand Command = "hello" - -// A handler returning a ClientMessage with this Command will prevent replying to the client. -// It signals that the work has been handed off to a background goroutine. -const AsyncResponseCommand Command = "_async" - -// A websocket.Codec that translates the protocol into ClientMessage objects. -var FFZCodec websocket.Codec = websocket.Codec{ - Marshal: MarshalClientMessage, - Unmarshal: UnmarshalClientMessage, -} - -// Errors that get returned to the client. -var ProtocolError error = errors.New("FFZ Socket protocol error.") -var ProtocolErrorNegativeID error = errors.New("FFZ Socket protocol error: negative or zero message ID.") -var ExpectedSingleString = errors.New("Error: Expected single string as arguments.") -var ExpectedSingleInt = errors.New("Error: Expected single integer as arguments.") -var ExpectedTwoStrings = errors.New("Error: Expected array of string, string as arguments.") -var ExpectedStringAndInt = errors.New("Error: Expected array of string, int as arguments.") -var ExpectedStringAndBool = errors.New("Error: Expected array of string, bool as arguments.") -var ExpectedStringAndIntGotFloat = errors.New("Error: Second argument was a float, expected an integer.") - -var gconfig *ConfigFile - -// Create a websocket.Server with the options from the provided Config. -func setupServer(config *ConfigFile, tlsConfig *tls.Config) *websocket.Server { - gconfig = config - sockConf, err := websocket.NewConfig("/", config.SocketOrigin) - if err != nil { - log.Fatal(err) - } - - SetupBackend(config) - - if config.UseSSL { - cert, err := tls.LoadX509KeyPair(config.SSLCertificateFile, config.SSLKeyFile) - if err != nil { - log.Fatal(err) - } - tlsConfig.Certificates = []tls.Certificate{cert} - tlsConfig.ServerName = config.SocketOrigin - tlsConfig.BuildNameToCertificate() - sockConf.TlsConfig = tlsConfig - - } - - sockServer := &websocket.Server{} - sockServer.Config = *sockConf - sockServer.Handler = HandleSocketConnection - - go deadChannelReaper() - - return sockServer -} - -// Set up a websocket listener and register it on /. -// (Uses http.DefaultServeMux .) -func SetupServerAndHandle(config *ConfigFile, tlsConfig *tls.Config, serveMux *http.ServeMux) { - sockServer := setupServer(config, tlsConfig) - - if serveMux == nil { - serveMux = http.DefaultServeMux - } - serveMux.HandleFunc("/", ServeWebsocketOrCatbag(sockServer.ServeHTTP)) - serveMux.HandleFunc("/pub_msg", HBackendPublishRequest) - serveMux.HandleFunc("/dump_backlog", HBackendDumpBacklog) - serveMux.HandleFunc("/update_and_pub", HBackendUpdateAndPublish) -} - -func ServeWebsocketOrCatbag(sockfunc func(http.ResponseWriter, *http.Request)) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Connection") == "Upgrade" { - sockfunc(w, r) - return - } else { - w.Write([]byte(gconfig.BannerHTML)) - } - } -} - -// Handle a new websocket connection from a FFZ client. -// This runs in a goroutine started by net/http. -func HandleSocketConnection(conn *websocket.Conn) { - // websocket.Conn is a ReadWriteCloser - - var _closer sync.Once - closer := func() { - _closer.Do(func() { - conn.Close() - }) - } - - // Close the connection when we're done. - defer closer() - - _clientChan := make(chan ClientMessage) - _serverMessageChan := make(chan ClientMessage) - _errorChan := make(chan error) - - // Launch receiver goroutine - go func(errorChan chan<- error, clientChan chan<- ClientMessage) { - var msg ClientMessage - var err error - for ; err == nil; err = FFZCodec.Receive(conn, &msg) { - if msg.MessageID == 0 { - continue - } - clientChan <- msg - } - errorChan <- err - close(errorChan) - close(clientChan) - // exit - }(_errorChan, _clientChan) - - var errorChan <-chan error = _errorChan - var clientChan <-chan ClientMessage = _clientChan - var serverMessageChan <-chan ClientMessage = _serverMessageChan - - var client ClientInfo - client.MessageChannel = _serverMessageChan - - // All set up, now enter the work loop - -RunLoop: - for { - select { - case err := <-errorChan: - FFZCodec.Send(conn, ClientMessage{ - MessageID: -1, - Command: "error", - Arguments: err.Error(), - }) // note - socket might be closed, but don't care - break RunLoop - case msg := <-clientChan: - if client.Version == "" && msg.Command != HelloCommand { - FFZCodec.Send(conn, ClientMessage{ - MessageID: msg.MessageID, - Command: "error", - Arguments: "Error - the first message sent must be a 'hello'", - }) - break RunLoop - } - - HandleCommand(conn, &client, msg) - case smsg := <-serverMessageChan: - FFZCodec.Send(conn, smsg) - } - } - - // Exit - - // Launch message draining goroutine - we aren't out of the pub/sub records - go func() { - for _ = range _serverMessageChan { - } - }() - - // Stop getting messages... - UnsubscribeAll(&client) - - // And finished. - // Close the channel so the draining goroutine can finish, too. - close(_serverMessageChan) -} - -func CallHandler(handler CommandHandler, conn *websocket.Conn, client *ClientInfo, cmsg ClientMessage) (rmsg ClientMessage, err error) { - defer func() { - if r := recover(); r != nil { - var ok bool - fmt.Print("[!] Error executing command", cmsg.Command, "--", r) - err, ok = r.(error) - if !ok { - err = fmt.Errorf("command handler: %v", r) - } - } - }() - return handler(conn, client, cmsg) -} - -// Unpack a message sent from the client into a ClientMessage. -func UnmarshalClientMessage(data []byte, payloadType byte, v interface{}) (err error) { - var spaceIdx int - - out := v.(*ClientMessage) - dataStr := string(data) - - // Message ID - spaceIdx = strings.IndexRune(dataStr, ' ') - if spaceIdx == -1 { - return ProtocolError - } - messageId, err := strconv.Atoi(dataStr[:spaceIdx]) - if messageId < -1 || messageId == 0 { - return ProtocolErrorNegativeID - } - - out.MessageID = messageId - dataStr = dataStr[spaceIdx+1:] - - spaceIdx = strings.IndexRune(dataStr, ' ') - if spaceIdx == -1 { - out.Command = Command(dataStr) - out.Arguments = nil - return nil - } else { - out.Command = Command(dataStr[:spaceIdx]) - } - dataStr = dataStr[spaceIdx+1:] - argumentsJson := dataStr - out.origArguments = argumentsJson - err = out.parseOrigArguments() - if err != nil { - return - } - return nil -} - -func (cm *ClientMessage) parseOrigArguments() error { - err := json.Unmarshal([]byte(cm.origArguments), &cm.Arguments) - if err != nil { - return err - } - return nil -} - -func MarshalClientMessage(clientMessage interface{}) (data []byte, payloadType byte, err error) { - var msg ClientMessage - var ok bool - msg, ok = clientMessage.(ClientMessage) - if !ok { - pMsg, ok := clientMessage.(*ClientMessage) - if !ok { - panic("MarshalClientMessage: argument needs to be a ClientMessage") - } - msg = *pMsg - } - var dataStr string - - if msg.Command == "" && msg.MessageID == 0 { - panic("MarshalClientMessage: attempt to send an empty ClientMessage") - } - - if msg.Command == "" { - msg.Command = SuccessCommand - } - if msg.MessageID == 0 { - msg.MessageID = -1 - } - - if msg.Arguments != nil { - argBytes, err := json.Marshal(msg.Arguments) - if err != nil { - return nil, 0, err - } - - dataStr = fmt.Sprintf("%d %s %s", msg.MessageID, msg.Command, string(argBytes)) - } else { - dataStr = fmt.Sprintf("%d %s", msg.MessageID, msg.Command) - } - - return []byte(dataStr), websocket.TextFrame, nil -} - -// Command handlers should use this to construct responses. -func NewClientMessage(arguments interface{}) ClientMessage { - return ClientMessage{ - MessageID: 0, // filled by the select loop - Command: SuccessCommand, - Arguments: arguments, - } -} - -// Convenience method: Parse the arguments of the ClientMessage as a single string. -func (cm *ClientMessage) ArgumentsAsString() (string1 string, err error) { - var ok bool - string1, ok = cm.Arguments.(string) - if !ok { - err = ExpectedSingleString - return - } else { - return string1, nil - } -} - -// Convenience method: Parse the arguments of the ClientMessage as a single int. -func (cm *ClientMessage) ArgumentsAsInt() (int1 int64, err error) { - var ok bool - var num float64 - num, ok = cm.Arguments.(float64) - if !ok { - err = ExpectedSingleInt - return - } else { - int1 = int64(num) - return int1, nil - } -} - -// Convenience method: Parse the arguments of the ClientMessage as an array of two strings. -func (cm *ClientMessage) ArgumentsAsTwoStrings() (string1, string2 string, err error) { - var ok bool - var ary []interface{} - ary, ok = cm.Arguments.([]interface{}) - if !ok { - err = ExpectedTwoStrings - return - } else { - if len(ary) != 2 { - err = ExpectedTwoStrings - return - } - string1, ok = ary[0].(string) - if !ok { - err = ExpectedTwoStrings - return - } - string2, ok = ary[1].(string) - if !ok { - err = ExpectedTwoStrings - return - } - return string1, string2, nil - } -} - -// Convenience method: Parse the arguments of the ClientMessage as an array of a string and an int. -func (cm *ClientMessage) ArgumentsAsStringAndInt() (string1 string, int int64, err error) { - var ok bool - var ary []interface{} - ary, ok = cm.Arguments.([]interface{}) - if !ok { - err = ExpectedStringAndInt - return - } else { - if len(ary) != 2 { - err = ExpectedStringAndInt - return - } - string1, ok = ary[0].(string) - if !ok { - err = ExpectedStringAndInt - return - } - var num float64 - num, ok = ary[1].(float64) - if !ok { - err = ExpectedStringAndInt - return - } - int = int64(num) - if float64(int) != num { - err = ExpectedStringAndIntGotFloat - return - } - return string1, int, nil - } -} - -// Convenience method: Parse the arguments of the ClientMessage as an array of a string and an int. -func (cm *ClientMessage) ArgumentsAsStringAndBool() (str string, flag bool, err error) { - var ok bool - var ary []interface{} - ary, ok = cm.Arguments.([]interface{}) - if !ok { - err = ExpectedStringAndBool - return - } else { - if len(ary) != 2 { - err = ExpectedStringAndBool - return - } - str, ok = ary[0].(string) - if !ok { - err = ExpectedStringAndBool - return - } - flag, ok = ary[1].(bool) - if !ok { - err = ExpectedStringAndBool - return - } - return str, flag, nil - } -} diff --git a/socketserver/internal/server/types.go b/socketserver/internal/server/types.go deleted file mode 100644 index cc9ba947..00000000 --- a/socketserver/internal/server/types.go +++ /dev/null @@ -1,232 +0,0 @@ -package server - -import ( - "encoding/json" - "github.com/satori/go.uuid" - "sync" - "time" -) - -const CryptoBoxKeyLength = 32 - -type ConfigFile struct { - // Numeric server id known to the backend - ServerId int - ListenAddr string - // Hostname of the socket server - SocketOrigin string - // URL to the backend server - BackendUrl string - // Memes go here - BannerHTML string - - // SSL/TLS - UseSSL bool - SSLCertificateFile string - SSLKeyFile string - - // Nacl keys - OurPrivateKey []byte - OurPublicKey []byte - BackendPublicKey []byte -} - -type ClientMessage struct { - // Message ID. Increments by 1 for each message sent from the client. - // When replying to a command, the message ID must be echoed. - // When sending a server-initiated message, this is -1. - MessageID int - // The command that the client wants from the server. - // When sent from the server, the literal string 'True' indicates success. - // Before sending, a blank Command will be converted into SuccessCommand. - Command Command - // Result of json.Unmarshal on the third field send from the client - Arguments interface{} - - origArguments string -} - -type AuthInfo struct { - // The client's claimed username on Twitch. - TwitchUsername string - - // Whether or not the server has validated the client's claimed username. - UsernameValidated bool -} - -type ClientInfo struct { - // The client ID. - // This must be written once by the owning goroutine before the struct is passed off to any other goroutines. - ClientID uuid.UUID - - // The client's version. - // This must be written once by the owning goroutine before the struct is passed off to any other goroutines. - Version string - - // This mutex protects writable data in this struct. - // If it seems to be a performance problem, we can split this. - Mutex sync.Mutex - - // TODO(riking) - does this need to be protected cross-thread? - AuthInfo - - // Username validation nonce. - ValidationNonce string - - // The list of chats this client is currently in. - // Protected by Mutex. - CurrentChannels []string - - // List of channels that we have not yet checked current chat-related channel info for. - // This lets us batch the backlog requests. - // Protected by Mutex. - PendingSubscriptionsBacklog []string - - // A timer that, when fired, will make the pending backlog requests. - // Usually nil. Protected by Mutex. - MakePendingRequests *time.Timer - - // Server-initiated messages should be sent here - // Never nil. - MessageChannel chan<- ClientMessage -} - -type tgmarray []TimestampedGlobalMessage -type tmmarray []TimestampedMultichatMessage - -func (ta tgmarray) Len() int { - return len(ta) -} -func (ta tgmarray) Less(i, j int) bool { - return ta[i].Timestamp.Before(ta[j].Timestamp) -} -func (ta tgmarray) Swap(i, j int) { - ta[i], ta[j] = ta[j], ta[i] -} -func (ta tgmarray) GetTime(i int) time.Time { - return ta[i].Timestamp -} -func (ta tmmarray) Len() int { - return len(ta) -} -func (ta tmmarray) Less(i, j int) bool { - return ta[i].Timestamp.Before(ta[j].Timestamp) -} -func (ta tmmarray) Swap(i, j int) { - ta[i], ta[j] = ta[j], ta[i] -} -func (ta tmmarray) GetTime(i int) time.Time { - return ta[i].Timestamp -} - -func (bct BacklogCacheType) Name() string { - switch bct { - case CacheTypeInvalid: - return "" - case CacheTypeNever: - return "never" - case CacheTypeTimestamps: - return "timed" - case CacheTypeLastOnly: - return "last" - case CacheTypePersistent: - return "persist" - } - panic("Invalid BacklogCacheType value") -} - -var CacheTypesByName = map[string]BacklogCacheType{ - "never": CacheTypeNever, - "timed": CacheTypeTimestamps, - "last": CacheTypeLastOnly, - "persist": CacheTypePersistent, -} - -func BacklogCacheTypeByName(name string) (bct BacklogCacheType) { - // CacheTypeInvalid is the zero value so it doesn't matter - bct, _ = CacheTypesByName[name] - return -} - -// Implements Stringer -func (bct BacklogCacheType) String() string { return bct.Name() } - -// Implements json.Marshaler -func (bct BacklogCacheType) MarshalJSON() ([]byte, error) { - return json.Marshal(bct.Name()) -} - -// Implements json.Unmarshaler -func (pbct *BacklogCacheType) UnmarshalJSON(data []byte) error { - var str string - err := json.Unmarshal(data, &str) - if err != nil { - return err - } - if str == "" { - *pbct = CacheTypeInvalid - return nil - } - val := BacklogCacheTypeByName(str) - if val != CacheTypeInvalid { - *pbct = val - return nil - } - return ErrorUnrecognizedCacheType -} - -func (mtt MessageTargetType) Name() string { - switch mtt { - case MsgTargetTypeInvalid: - return "" - case MsgTargetTypeSingle: - return "single" - case MsgTargetTypeChat: - return "chat" - case MsgTargetTypeMultichat: - return "multichat" - case MsgTargetTypeGlobal: - return "global" - } - panic("Invalid MessageTargetType value") -} - -var TargetTypesByName = map[string]MessageTargetType{ - "single": MsgTargetTypeSingle, - "chat": MsgTargetTypeChat, - "multichat": MsgTargetTypeMultichat, - "global": MsgTargetTypeGlobal, -} - -func MessageTargetTypeByName(name string) (mtt MessageTargetType) { - // MsgTargetTypeInvalid is the zero value so it doesn't matter - mtt, _ = TargetTypesByName[name] - return -} - -// Implements Stringer -func (mtt MessageTargetType) String() string { return mtt.Name() } - -// Implements json.Marshaler -func (mtt MessageTargetType) MarshalJSON() ([]byte, error) { - return json.Marshal(mtt.Name()) -} - -// Implements json.Unmarshaler -func (pmtt *MessageTargetType) UnmarshalJSON(data []byte) error { - var str string - err := json.Unmarshal(data, &str) - if err != nil { - return err - } - if str == "" { - *pmtt = MsgTargetTypeInvalid - return nil - } - mtt := MessageTargetTypeByName(str) - if mtt != MsgTargetTypeInvalid { - *pmtt = mtt - return nil - } - return ErrorUnrecognizedTargetType -} diff --git a/socketserver/server/backend.go b/socketserver/server/backend.go new file mode 100644 index 00000000..f41defde --- /dev/null +++ b/socketserver/server/backend.go @@ -0,0 +1,301 @@ +package server + +import ( + "crypto/rand" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "log" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "sync" + + "github.com/pmylund/go-cache" + "golang.org/x/crypto/nacl/box" +) + +const bPathAnnounceStartup = "/startup" +const bPathAddTopic = "/topics" +const bPathAggStats = "/stats" +const bPathOtherCommand = "/cmd/" + +type backendInfo struct { + HTTPClient http.Client + baseURL string + responseCache *cache.Cache + + postStatisticsURL string + addTopicURL string + announceStartupURL string + + sharedKey [32]byte + serverID int + + lastSuccess map[string]time.Time + lastSuccessLock sync.Mutex +} + +var Backend *backendInfo + +func setupBackend(config *ConfigFile) *backendInfo { + b := new(backendInfo) + Backend = b + b.serverID = config.ServerID + + b.HTTPClient.Timeout = 60 * time.Second + b.baseURL = config.BackendURL + b.responseCache = cache.New(60*time.Second, 120*time.Second) + + b.announceStartupURL = fmt.Sprintf("%s%s", b.baseURL, bPathAnnounceStartup) + b.addTopicURL = fmt.Sprintf("%s%s", b.baseURL, bPathAddTopic) + b.postStatisticsURL = fmt.Sprintf("%s%s", b.baseURL, bPathAggStats) + + epochTime := time.Unix(0, 0).UTC() + lastBackendSuccess := map[string]time.Time{ + bPathAnnounceStartup: epochTime, + bPathAddTopic: epochTime, + bPathAggStats: epochTime, + bPathOtherCommand: epochTime, + } + b.lastSuccess = lastBackendSuccess + + var theirPublic, ourPrivate [32]byte + copy(theirPublic[:], config.BackendPublicKey) + copy(ourPrivate[:], config.OurPrivateKey) + + box.Precompute(&b.sharedKey, &theirPublic, &ourPrivate) + + return b +} + +func getCacheKey(remoteCommand, data string) string { + return fmt.Sprintf("%s/%s", remoteCommand, data) +} + +// ErrForwardedFromBackend is an error returned by the backend server. +type ErrForwardedFromBackend struct { + JSONError interface{} +} + +func (bfe ErrForwardedFromBackend) Error() string { + bytes, _ := json.Marshal(bfe.JSONError) + return string(bytes) +} + +// ErrAuthorizationNeeded is emitted when the backend replies with HTTP 401. +// Indicates that an attempt to validate `ClientInfo.TwitchUsername` should be attempted. +var ErrAuthorizationNeeded = errors.New("Must authenticate Twitch username to use this command") + +// SendRemoteCommandCached performs a RPC call on the backend, but caches responses. +func (backend *backendInfo) SendRemoteCommandCached(remoteCommand, data string, auth AuthInfo) (string, error) { + cached, ok := backend.responseCache.Get(getCacheKey(remoteCommand, data)) + if ok { + return cached.(string), nil + } + return backend.SendRemoteCommand(remoteCommand, data, auth) +} + +// SendRemoteCommand performs a RPC call on the backend by POSTing to `/cmd/$remoteCommand`. +// The form data is as follows: `clientData` is the JSON in the `data` parameter +// (should be retrieved from ClientMessage.Arguments), and either `username` or +// `usernameClaimed` depending on whether AuthInfo.UsernameValidates is true is AuthInfo.TwitchUsername. +func (backend *backendInfo) SendRemoteCommand(remoteCommand, data string, auth AuthInfo) (responseStr string, err error) { + destURL := fmt.Sprintf("%s/cmd/%s", backend.baseURL, remoteCommand) + healthBucket := fmt.Sprintf("/cmd/%s", remoteCommand) + + formData := url.Values{ + "clientData": []string{data}, + "username": []string{auth.TwitchUsername}, + } + + if auth.UsernameValidated { + formData.Set("authenticated", "1") + } else { + formData.Set("authenticated", "0") + } + + sealedForm, err := backend.SealRequest(formData) + if err != nil { + return "", err + } + + resp, err := backend.HTTPClient.PostForm(destURL, sealedForm) + if err != nil { + return "", err + } + defer resp.Body.Close() + + respBytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + return "", err + } + + responseStr = string(respBytes) + + if resp.StatusCode == 401 { + return "", ErrAuthorizationNeeded + } else if resp.StatusCode != 200 { + if resp.Header.Get("Content-Type") == "application/json" { + var err2 ErrForwardedFromBackend + err := json.Unmarshal(respBytes, &err2.JSONError) + if err != nil { + return "", fmt.Errorf("error decoding json error from backend: %v | %s", err, responseStr) + } + return "", err2 + } + return "", httpError(resp.StatusCode) + } + + if resp.Header.Get("FFZ-Cache") != "" { + durSecs, err := strconv.ParseInt(resp.Header.Get("FFZ-Cache"), 10, 64) + if err != nil { + return "", fmt.Errorf("The RPC server returned a non-integer cache duration: %v", err) + } + duration := time.Duration(durSecs) * time.Second + backend.responseCache.Set(getCacheKey(remoteCommand, data), responseStr, duration) + } + + now := time.Now().UTC() + backend.lastSuccessLock.Lock() + defer backend.lastSuccessLock.Unlock() + backend.lastSuccess[bPathOtherCommand] = now + backend.lastSuccess[healthBucket] = now + + return +} + +// SendAggregatedData sends aggregated emote usage and following data to the backend server. +func (backend *backendInfo) SendAggregatedData(form url.Values) error { + sealedForm, err := backend.SealRequest(form) + if err != nil { + return err + } + + resp, err := backend.HTTPClient.PostForm(backend.postStatisticsURL, sealedForm) + if err != nil { + return err + } + if resp.StatusCode < 200 || resp.StatusCode > 299 { + resp.Body.Close() + return httpError(resp.StatusCode) + } + + backend.lastSuccessLock.Lock() + defer backend.lastSuccessLock.Unlock() + backend.lastSuccess[bPathAggStats] = time.Now().UTC() + + return resp.Body.Close() +} + +// ErrBackendNotOK indicates that the backend replied with something other than the string "ok". +type ErrBackendNotOK struct { + Response string + Code int +} + +// Error Implements the error interface. +func (noe ErrBackendNotOK) Error() string { + return fmt.Sprintf("backend returned %d: %s", noe.Code, noe.Response) +} + +// SendNewTopicNotice notifies the backend that a client has performed the first subscription to a pub/sub topic. +// POST data: +// channels=room.trihex +// added=t +func (backend *backendInfo) SendNewTopicNotice(topic string) error { + return backend.sendTopicNotice(topic, true) +} + +// SendCleanupTopicsNotice notifies the backend that pub/sub topics have no subscribers anymore. +// POST data: +// channels=room.sirstendec,room.bobross,feature.foo +// added=f +func (backend *backendInfo) SendCleanupTopicsNotice(topics []string) error { + return backend.sendTopicNotice(strings.Join(topics, ","), false) +} + +func (backend *backendInfo) sendTopicNotice(topic string, added bool) error { + formData := url.Values{} + formData.Set("channels", topic) + if added { + formData.Set("added", "t") + } else { + formData.Set("added", "f") + } + + sealedForm, err := backend.SealRequest(formData) + if err != nil { + return err + } + + resp, err := backend.HTTPClient.PostForm(backend.addTopicURL, sealedForm) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode > 299 { + respBytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + return ErrBackendNotOK{Code: resp.StatusCode, Response: fmt.Sprintf("(error reading non-2xx response): %s", err.Error())} + } + return ErrBackendNotOK{Code: resp.StatusCode, Response: string(respBytes)} + } + + backend.lastSuccessLock.Lock() + defer backend.lastSuccessLock.Unlock() + backend.lastSuccess[bPathAddTopic] = time.Now().UTC() + + return nil +} + +func httpError(statusCode int) error { + return fmt.Errorf("backend http error: %d", statusCode) +} + +// GenerateKeys generates a new NaCl keypair for the server and writes out the default configuration file. +func GenerateKeys(outputFile, serverID, theirPublicStr string) { + var err error + output := ConfigFile{ + ListenAddr: "0.0.0.0:8001", + SSLListenAddr: "0.0.0.0:443", + BackendURL: "http://localhost:8002/ffz", + MinMemoryKBytes: defaultMinMemoryKB, + } + + output.ServerID, err = strconv.Atoi(serverID) + if err != nil { + log.Fatal(err) + } + + ourPublic, ourPrivate, err := box.GenerateKey(rand.Reader) + if err != nil { + log.Fatal(err) + } + output.OurPublicKey, output.OurPrivateKey = ourPublic[:], ourPrivate[:] + + if theirPublicStr != "" { + reader := base64.NewDecoder(base64.StdEncoding, strings.NewReader(theirPublicStr)) + theirPublic, err := ioutil.ReadAll(reader) + if err != nil { + log.Fatal(err) + } + output.BackendPublicKey = theirPublic + } + + bytes, err := json.MarshalIndent(output, "", "\t") + if err != nil { + log.Fatal(err) + } + fmt.Println(string(bytes)) + err = ioutil.WriteFile(outputFile, bytes, 0600) + if err != nil { + log.Fatal(err) + } +} diff --git a/socketserver/server/backend_test.go b/socketserver/server/backend_test.go new file mode 100644 index 00000000..d1c85adb --- /dev/null +++ b/socketserver/server/backend_test.go @@ -0,0 +1,131 @@ +package server + +import ( + "net/http" + "net/url" + "testing" + + . "gopkg.in/check.v1" +) + +func Test(t *testing.T) { TestingT(t) } + +func TestSealRequest(t *testing.T) { + TSetup(SetupNoServers, nil) + b := Backend + + values := url.Values{ + "QuickBrownFox": []string{"LazyDog"}, + } + + sealedValues, err := b.SealRequest(values) + if err != nil { + t.Fatal(err) + } + // sealedValues.Encode() + // id=0&msg=KKtbng49dOLLyjeuX5AnXiEe6P0uZwgeP_7mMB5vhP-wMAAPZw%3D%3D&nonce=-wRbUnifscisWUvhm3gBEXHN5QzrfzgV + + unsealedValues, err := b.UnsealRequest(sealedValues) + if err != nil { + t.Fatal(err) + } + + if unsealedValues.Get("QuickBrownFox") != "LazyDog" { + t.Errorf("Failed to round-trip, got back %v", unsealedValues) + } +} + +type BackendSuite struct{} + +var _ = Suite(&BackendSuite{}) + +func (s *BackendSuite) TestSendRemoteCommand(c *C) { + const TestCommand1 = "somecommand" + const TestCommand2 = "other" + const PathTestCommand1 = "/cmd/" + TestCommand1 + const PathTestCommand2 = "/cmd/" + TestCommand2 + const TestData1 = "623478.32" + const TestData2 = "\"Hello, there\"" + const TestData3 = "3" + const TestUsername = "sirstendec" + const TestResponse1 = "asfdg" + const TestResponse2 = "yuiop" + const TestErrorText = "{\"err\":\"some kind of special error\"}" + + var AnonAuthInfo = AuthInfo{} + var NonValidatedAuthInfo = AuthInfo{TwitchUsername: TestUsername} + var ValidatedAuthInfo = AuthInfo{TwitchUsername: TestUsername, UsernameValidated: true} + + headersCacheTwoSeconds := http.Header{"FFZ-Cache": []string{"2"}} + headersCacheInvalid := http.Header{"FFZ-Cache": []string{"NotANumber"}} + headersApplicationJson := http.Header{"Content-Type": []string{"application/json"}} + + mockBackend := NewTBackendRequestChecker(c, + TExpectedBackendRequest{200, PathTestCommand1, &url.Values{"clientData": []string{TestData1}, "authenticated": []string{"0"}, "username": []string{""}}, TestResponse1, nil}, + TExpectedBackendRequest{200, PathTestCommand1, &url.Values{"clientData": []string{TestData1}, "authenticated": []string{"0"}, "username": []string{""}}, TestResponse2, nil}, + TExpectedBackendRequest{200, PathTestCommand1, &url.Values{"clientData": []string{TestData1}, "authenticated": []string{"0"}, "username": []string{TestUsername}}, TestResponse1, nil}, + TExpectedBackendRequest{200, PathTestCommand1, &url.Values{"clientData": []string{TestData1}, "authenticated": []string{"1"}, "username": []string{TestUsername}}, TestResponse1, nil}, + TExpectedBackendRequest{200, PathTestCommand2, &url.Values{"clientData": []string{TestData2}, "authenticated": []string{"0"}, "username": []string{TestUsername}}, TestResponse1, headersCacheTwoSeconds}, + // cached + // cached + TExpectedBackendRequest{200, PathTestCommand2, &url.Values{"clientData": []string{TestData1}, "authenticated": []string{"0"}, "username": []string{TestUsername}}, TestResponse2, headersCacheTwoSeconds}, + TExpectedBackendRequest{401, PathTestCommand1, &url.Values{"clientData": []string{TestData1}, "authenticated": []string{"0"}, "username": []string{TestUsername}}, "", nil}, + TExpectedBackendRequest{503, PathTestCommand1, &url.Values{"clientData": []string{TestData1}, "authenticated": []string{"0"}, "username": []string{TestUsername}}, "", nil}, + TExpectedBackendRequest{418, PathTestCommand1, &url.Values{"clientData": []string{TestData1}, "authenticated": []string{"0"}, "username": []string{TestUsername}}, TestErrorText, headersApplicationJson}, + TExpectedBackendRequest{200, PathTestCommand2, &url.Values{"clientData": []string{TestData3}, "authenticated": []string{"0"}, "username": []string{TestUsername}}, TestResponse1, headersCacheInvalid}, + ) + _, _, _ = TSetup(SetupWantBackendServer, mockBackend) + defer mockBackend.Close() + + var resp string + var err error + b := Backend + + resp, err = b.SendRemoteCommand(TestCommand1, TestData1, AnonAuthInfo) + c.Check(resp, Equals, TestResponse1) + c.Check(err, IsNil) + + resp, err = b.SendRemoteCommand(TestCommand1, TestData1, AnonAuthInfo) + c.Check(resp, Equals, TestResponse2) + c.Check(err, IsNil) + + resp, err = b.SendRemoteCommand(TestCommand1, TestData1, NonValidatedAuthInfo) + c.Check(resp, Equals, TestResponse1) + c.Check(err, IsNil) + + resp, err = b.SendRemoteCommand(TestCommand1, TestData1, ValidatedAuthInfo) + c.Check(resp, Equals, TestResponse1) + c.Check(err, IsNil) + // cache save + resp, err = b.SendRemoteCommandCached(TestCommand2, TestData2, NonValidatedAuthInfo) + c.Check(resp, Equals, TestResponse1) + c.Check(err, IsNil) + + resp, err = b.SendRemoteCommandCached(TestCommand2, TestData2, NonValidatedAuthInfo) // cache hit + c.Check(resp, Equals, TestResponse1) + c.Check(err, IsNil) + + resp, err = b.SendRemoteCommandCached(TestCommand2, TestData2, AnonAuthInfo) // cache hit + c.Check(resp, Equals, TestResponse1) + c.Check(err, IsNil) + // cache miss - data is different + resp, err = b.SendRemoteCommandCached(TestCommand2, TestData1, NonValidatedAuthInfo) + c.Check(resp, Equals, TestResponse2) + c.Check(err, IsNil) + + resp, err = b.SendRemoteCommand(TestCommand1, TestData1, NonValidatedAuthInfo) + c.Check(resp, Equals, "") + c.Check(err, Equals, ErrAuthorizationNeeded) + + resp, err = b.SendRemoteCommand(TestCommand1, TestData1, NonValidatedAuthInfo) + c.Check(resp, Equals, "") + c.Check(err, ErrorMatches, "backend http error: 503") + + resp, err = b.SendRemoteCommand(TestCommand1, TestData1, NonValidatedAuthInfo) + c.Check(resp, Equals, "") + c.Check(err, ErrorMatches, TestErrorText) + + resp, err = b.SendRemoteCommand(TestCommand2, TestData3, NonValidatedAuthInfo) + c.Check(resp, Equals, "") + c.Check(err, ErrorMatches, "The RPC server returned a non-integer cache duration: .*") +} diff --git a/socketserver/server/commands.go b/socketserver/server/commands.go new file mode 100644 index 00000000..3d0156a1 --- /dev/null +++ b/socketserver/server/commands.go @@ -0,0 +1,606 @@ +package server + +import ( + "encoding/json" + "errors" + "fmt" + "log" + "net/url" + "strconv" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/satori/go.uuid" +) + +// Command is a string indicating which RPC is requested. +// The Commands sent from Client -> Server and Server -> Client are disjoint sets. +type Command string + +// CommandHandler is a RPC handler associated with a Command. +type CommandHandler func(*websocket.Conn, *ClientInfo, ClientMessage) (ClientMessage, error) + +var commandHandlers = map[Command]CommandHandler{ + HelloCommand: C2SHello, + "ping": C2SPing, + SetUserCommand: C2SSetUser, + ReadyCommand: C2SReady, + + "sub": C2SSubscribe, + "unsub": C2SUnsubscribe, + + "track_follow": C2STrackFollow, + "emoticon_uses": C2SEmoticonUses, + "survey": C2SSurvey, + + "twitch_emote": C2SHandleBunchedCommand, + "get_link": C2SHandleBunchedCommand, + "get_display_name": C2SHandleBunchedCommand, + "update_follow_buttons": C2SHandleRemoteCommand, + "chat_history": C2SHandleRemoteCommand, + "user_history": C2SHandleRemoteCommand, +} + +func setupInterning() { + PubSubChannelPool = NewStringPool() + TwitchChannelPool = NewStringPool() + + CommandPool = NewStringPool() + CommandPool._Intern_Setup(string(HelloCommand)) + CommandPool._Intern_Setup("ping") + CommandPool._Intern_Setup(string(SetUserCommand)) + CommandPool._Intern_Setup(string(ReadyCommand)) + CommandPool._Intern_Setup("sub") + CommandPool._Intern_Setup("unsub") + CommandPool._Intern_Setup("track_follow") + CommandPool._Intern_Setup("emoticon_uses") + CommandPool._Intern_Setup("twitch_emote") + CommandPool._Intern_Setup("get_link") + CommandPool._Intern_Setup("get_display_name") + CommandPool._Intern_Setup("update_follow_buttons") + CommandPool._Intern_Setup("chat_history") + CommandPool._Intern_Setup("user_history") + CommandPool._Intern_Setup("adjacent_history") +} + +// DispatchC2SCommand handles a C2S Command in the provided ClientMessage. +// It calls the correct CommandHandler function, catching panics. +// It sends either the returned Reply ClientMessage, setting the correct messageID, or sends an ErrorCommand +func DispatchC2SCommand(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) { + handler, ok := commandHandlers[msg.Command] + if !ok { + handler = C2SHandleRemoteCommand + } + + CommandCounter <- msg.Command + + response, err := callHandler(handler, conn, client, msg) + + if err == nil { + if response.Command == AsyncResponseCommand { + // Don't send anything + // The response will be delivered over client.MessageChannel / serverMessageChan + } else { + response.MessageID = msg.MessageID + SendMessage(conn, response) + } + } else { + SendMessage(conn, ClientMessage{ + MessageID: msg.MessageID, + Command: ErrorCommand, + Arguments: err.Error(), + }) + } +} + +func callHandler(handler CommandHandler, conn *websocket.Conn, client *ClientInfo, cmsg ClientMessage) (rmsg ClientMessage, err error) { + defer func() { + if r := recover(); r != nil { + var ok bool + fmt.Print("[!] Error executing command", cmsg.Command, "--", r) + err, ok = r.(error) + if !ok { + err = fmt.Errorf("command handler: %v", r) + } + } + }() + return handler(conn, client, cmsg) +} + +// C2SHello implements the `hello` C2S Command. +// It calls SubscribeGlobal() and SubscribeDefaults() with the client, and fills out ClientInfo.Version and ClientInfo.ClientID. +func C2SHello(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { + ary, ok := msg.Arguments.([]interface{}) + if !ok { + err = ErrExpectedTwoStrings + return + } + if len(ary) != 2 { + err = ErrExpectedTwoStrings + return + } + version, ok := ary[0].(string) + if !ok { + err = ErrExpectedTwoStrings + return + } + + client.VersionString = copyString(version) + client.Version = VersionFromString(version) + + if clientIDStr, ok := ary[1].(string); ok { + client.ClientID = uuid.FromStringOrNil(clientIDStr) + if client.ClientID == uuid.Nil { + client.ClientID = uuid.NewV4() + } + } else if _, ok := ary[1].(bool); ok { + // opt out + client.ClientID = AnonymousClientID + } else { + err = ErrExpectedTwoStrings + return + } + + uniqueUserChannel <- client.ClientID + + SubscribeGlobal(client) + SubscribeDefaults(client) + + jsTime := float64(time.Now().UnixNano()/1000) / 1000 + return ClientMessage{ + Arguments: []interface{}{ + client.ClientID.String(), + jsTime, + }, + }, nil +} + +func C2SPing(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { + return ClientMessage{ + Arguments: float64(time.Now().UnixNano()/1000) / 1000, + }, nil +} + +func C2SSetUser(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { + username, err := msg.ArgumentsAsString() + if err != nil { + return + } + + username = copyString(username) + + client.Mutex.Lock() + client.UsernameValidated = false + client.TwitchUsername = username + client.Mutex.Unlock() + + if Configuration.SendAuthToNewClients { + client.MsgChannelKeepalive.Add(1) + go client.StartAuthorization(func(_ *ClientInfo, _ bool) { + client.MsgChannelKeepalive.Done() + }) + } + + return ResponseSuccess, nil +} + +func C2SReady(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { + // disconnectAt, err := msg.ArgumentsAsInt() + // if err != nil { + // return + // } + + client.Mutex.Lock() + client.ReadyComplete = true + client.Mutex.Unlock() + + client.MsgChannelKeepalive.Add(1) + go func() { + client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand} + SendBacklogForNewClient(client) + client.MsgChannelKeepalive.Done() + }() + return ClientMessage{Command: AsyncResponseCommand}, nil +} + +func C2SSubscribe(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { + channel, err := msg.ArgumentsAsString() + if err != nil { + return + } + + channel = PubSubChannelPool.Intern(channel) + + client.Mutex.Lock() + AddToSliceS(&client.CurrentChannels, channel) + client.Mutex.Unlock() + + SubscribeChannel(client, channel) + + if client.ReadyComplete { + client.MsgChannelKeepalive.Add(1) + go func() { + SendBacklogForChannel(client, channel) + client.MsgChannelKeepalive.Done() + }() + } + + return ResponseSuccess, nil +} + +// C2SUnsubscribe implements the `unsub` C2S Command. +// It removes the channel from ClientInfo.CurrentChannels and calls UnsubscribeSingleChat. +func C2SUnsubscribe(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { + channel, err := msg.ArgumentsAsString() + if err != nil { + return + } + + channel = PubSubChannelPool.Intern(channel) + + client.Mutex.Lock() + RemoveFromSliceS(&client.CurrentChannels, channel) + client.Mutex.Unlock() + + UnsubscribeSingleChat(client, channel) + + return ResponseSuccess, nil +} + +// C2SSurvey implements the survey C2S Command. +// Surveys are discarded.s +func C2SSurvey(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { + // Discard + return ResponseSuccess, nil +} + +type followEvent struct { + User string `json:"u"` + Channel string `json:"c"` + NowFollowing bool `json:"f"` + Timestamp time.Time `json:"t"` +} + +var followEvents []followEvent + +// followEventsLock is the lock for followEvents. +var followEventsLock sync.Mutex + +// C2STrackFollow implements the `track_follow` C2S Command. +// It adds the record to `followEvents`, which is submitted to the backend on a timer. +func C2STrackFollow(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { + channel, following, err := msg.ArgumentsAsStringAndBool() + if err != nil { + return + } + now := time.Now() + + channel = TwitchChannelPool.Intern(channel) + + followEventsLock.Lock() + followEvents = append(followEvents, followEvent{User: client.TwitchUsername, Channel: channel, NowFollowing: following, Timestamp: now}) + followEventsLock.Unlock() + + return ResponseSuccess, nil +} + +// AggregateEmoteUsage is a map from emoteID to a map from chatroom name to usage count. +var aggregateEmoteUsage = make(map[int]map[string]int) + +// AggregateEmoteUsageLock is the lock for AggregateEmoteUsage. +var aggregateEmoteUsageLock sync.Mutex + +// ErrNegativeEmoteUsage is emitted when the submitted emote usage is negative. +var ErrNegativeEmoteUsage = errors.New("Emote usage count cannot be negative") + +// C2SEmoticonUses implements the `emoticon_uses` C2S Command. +// msg.Arguments are in the JSON format of [1]map[emoteID]map[ChatroomName]float64. +func C2SEmoticonUses(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { + // if this panics, will be caught by callHandler + mapRoot := msg.Arguments.([]interface{})[0].(map[string]interface{}) + + // Validate: male suire + for strEmote, val1 := range mapRoot { + _, err = strconv.Atoi(strEmote) + if err != nil { + return + } + mapInner := val1.(map[string]interface{}) + for _, val2 := range mapInner { + var count = int(val2.(float64)) + if count <= 0 { + err = ErrNegativeEmoteUsage + return + } + } + } + + aggregateEmoteUsageLock.Lock() + defer aggregateEmoteUsageLock.Unlock() + + var total int + + for strEmote, val1 := range mapRoot { + var emoteID int + emoteID, err = strconv.Atoi(strEmote) + if err != nil { + return + } + + destMapInner, ok := aggregateEmoteUsage[emoteID] + if !ok { + destMapInner = make(map[string]int) + aggregateEmoteUsage[emoteID] = destMapInner + } + + mapInner := val1.(map[string]interface{}) + for roomName, val2 := range mapInner { + var count = int(val2.(float64)) + if count > 200 { + count = 200 + } + roomName = TwitchChannelPool.Intern(roomName) + destMapInner[roomName] += count + total += count + } + } + + Statistics.EmotesReportedTotal += uint64(total) + + return ResponseSuccess, nil +} + +// is_init_func +func aggregateDataSender() { + for { + time.Sleep(5 * time.Minute) + aggregateDataSender_do() + } +} + +func aggregateDataSender_do() { + followEventsLock.Lock() + follows := followEvents + followEvents = nil + followEventsLock.Unlock() + aggregateEmoteUsageLock.Lock() + emoteUsage := aggregateEmoteUsage + aggregateEmoteUsage = make(map[int]map[string]int) + aggregateEmoteUsageLock.Unlock() + + reportForm := url.Values{} + + followJSON, err := json.Marshal(follows) + if err != nil { + log.Println("error reporting aggregate data:", err) + } else { + reportForm.Set("follows", string(followJSON)) + } + + strEmoteUsage := make(map[string]map[string]int) + for emoteID, usageByChannel := range emoteUsage { + strEmoteID := strconv.Itoa(emoteID) + strEmoteUsage[strEmoteID] = usageByChannel + } + emoteJSON, err := json.Marshal(strEmoteUsage) + if err != nil { + log.Println("error reporting aggregate data:", err) + } else { + reportForm.Set("emotes", string(emoteJSON)) + } + + err = Backend.SendAggregatedData(reportForm) + if err != nil { + log.Println("error reporting aggregate data:", err) + return + } + + // done +} + +type bunchedRequest struct { + Command Command + Param string +} + +type cachedBunchedResponse struct { + Response string + Timestamp time.Time +} +type bunchSubscriber struct { + Client *ClientInfo + MessageID int +} + +type bunchSubscriberList struct { + sync.Mutex + Members []bunchSubscriber +} + +type cacheStatus byte + +const ( + CacheStatusNotFound = iota + CacheStatusFound + CacheStatusExpired +) + +var pendingBunchedRequests = make(map[bunchedRequest]*bunchSubscriberList) +var pendingBunchLock sync.Mutex +var bunchCache = make(map[bunchedRequest]cachedBunchedResponse) +var bunchCacheLock sync.RWMutex +var bunchCacheCleanupSignal = sync.NewCond(&bunchCacheLock) +var bunchCacheLastCleanup time.Time + +func bunchedRequestFromCM(msg *ClientMessage) bunchedRequest { + return bunchedRequest{Command: msg.Command, Param: copyString(msg.origArguments)} +} + +// is_init_func +func bunchCacheJanitor() { + go func() { + for { + time.Sleep(30 * time.Minute) + bunchCacheCleanupSignal.Signal() + } + }() + + bunchCacheLock.Lock() + for { + // Unlocks CachedBunchLock, waits for signal, re-locks + bunchCacheCleanupSignal.Wait() + + if bunchCacheLastCleanup.After(time.Now().Add(-1 * time.Second)) { + // skip if it's been less than 1 second + continue + } + + // CachedBunchLock is held here + keepIfAfter := time.Now().Add(-5 * time.Minute) + for req, resp := range bunchCache { + if !resp.Timestamp.After(keepIfAfter) { + delete(bunchCache, req) + } + } + bunchCacheLastCleanup = time.Now() + // Loop and Wait(), which re-locks + } +} + +var emptyCachedBunchedResponse cachedBunchedResponse + +func bunchGetCacheStatus(br bunchedRequest, client *ClientInfo) (cacheStatus, cachedBunchedResponse) { + bunchCacheLock.RLock() + defer bunchCacheLock.RUnlock() + cachedResponse, ok := bunchCache[br] + if ok && cachedResponse.Timestamp.After(time.Now().Add(-5*time.Minute)) { + return CacheStatusFound, cachedResponse + } else if ok { + return CacheStatusExpired, emptyCachedBunchedResponse + } + return CacheStatusNotFound, emptyCachedBunchedResponse +} + +func normalizeBunchedRequest(br bunchedRequest) bunchedRequest { + if br.Command == "get_link" { + // TODO + } + return br +} + +// C2SHandleBunchedCommand handles C2S Commands such as `get_link`. +// It makes a request to the backend server for the data, but any other requests coming in while the first is pending also get the responses from the first one. +// Additionally, results are cached. +func C2SHandleBunchedCommand(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { + // FIXME(riking): Function is too complex + + br := bunchedRequestFromCM(&msg) + br = normalizeBunchedRequest(br) + + cacheStatus, cachedResponse := bunchGetCacheStatus(br, client) + + if cacheStatus == CacheStatusFound { + var response ClientMessage + response.Command = SuccessCommand + response.MessageID = msg.MessageID + response.origArguments = cachedResponse.Response + response.parseOrigArguments() + + return response, nil + } else if cacheStatus == CacheStatusExpired { + // Wake up the lazy janitor + bunchCacheCleanupSignal.Signal() + } + + pendingBunchLock.Lock() + defer pendingBunchLock.Unlock() + list, ok := pendingBunchedRequests[br] + if ok { + list.Lock() + AddToSliceB(&list.Members, client, msg.MessageID) + list.Unlock() + + return ClientMessage{Command: AsyncResponseCommand}, nil + } + + pendingBunchedRequests[br] = &bunchSubscriberList{Members: []bunchSubscriber{{Client: client, MessageID: msg.MessageID}}} + + go func(request bunchedRequest) { + respStr, err := Backend.SendRemoteCommandCached(string(request.Command), request.Param, AuthInfo{}) + + var msg ClientMessage + if err == nil { + msg.Command = SuccessCommand + msg.origArguments = respStr + msg.parseOrigArguments() + } else { + msg.Command = ErrorCommand + msg.Arguments = err.Error() + } + + if err == nil { + bunchCacheLock.Lock() + bunchCache[request] = cachedBunchedResponse{Response: respStr, Timestamp: time.Now()} + bunchCacheLock.Unlock() + } + + pendingBunchLock.Lock() + bsl := pendingBunchedRequests[request] + delete(pendingBunchedRequests, request) + pendingBunchLock.Unlock() + + bsl.Lock() + for _, member := range bsl.Members { + msg.MessageID = member.MessageID + select { + case member.Client.MessageChannel <- msg: + case <-member.Client.MsgChannelIsDone: + } + } + bsl.Unlock() + }(br) + + return ClientMessage{Command: AsyncResponseCommand}, nil +} + +func C2SHandleRemoteCommand(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { + client.MsgChannelKeepalive.Add(1) + go doRemoteCommand(conn, msg, client) + + return ClientMessage{Command: AsyncResponseCommand}, nil +} + +const AuthorizationFailedErrorString = "Failed to verify your Twitch username." +const AuthorizationNeededError = "You must be signed in to use that command." + +func doRemoteCommand(conn *websocket.Conn, msg ClientMessage, client *ClientInfo) { + resp, err := Backend.SendRemoteCommandCached(string(msg.Command), copyString(msg.origArguments), client.AuthInfo) + + if err == ErrAuthorizationNeeded { + if client.TwitchUsername == "" { + // Not logged in + client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationNeededError} + client.MsgChannelKeepalive.Done() + return + } + client.StartAuthorization(func(_ *ClientInfo, success bool) { + if success { + doRemoteCommand(conn, msg, client) + } else { + client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationFailedErrorString} + client.MsgChannelKeepalive.Done() + } + }) + return // without keepalive.Done() + } else if bfe, ok := err.(ErrForwardedFromBackend); ok { + client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: bfe.JSONError} + } else if err != nil { + client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: err.Error()} + } else { + msg := ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand, origArguments: resp} + msg.parseOrigArguments() + client.MessageChannel <- msg + } + client.MsgChannelKeepalive.Done() +} diff --git a/socketserver/server/handlecore.go b/socketserver/server/handlecore.go new file mode 100644 index 00000000..723f4fb2 --- /dev/null +++ b/socketserver/server/handlecore.go @@ -0,0 +1,711 @@ +package server // import "bitbucket.org/stendec/frankerfacez/socketserver/server" + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "net/http" + "net/url" + "os" + "os/signal" + "regexp" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + "syscall" + "time" + "unicode/utf8" + + "github.com/gorilla/websocket" +) + +// SuccessCommand is a Reply Command to indicate success in reply to a C2S Command. +const SuccessCommand Command = "ok" + +// ErrorCommand is a Reply Command to indicate that a C2S Command failed. +const ErrorCommand Command = "error" + +// HelloCommand is a C2S Command. +// HelloCommand must be the Command of the first ClientMessage sent during a connection. +// Sending any other command will result in a CloseFirstMessageNotHello. +const HelloCommand Command = "hello" + +// ReadyCommand is a C2S Command. +// It indicates that the client is finished sending the initial 'sub' commands and the server should send the backlog. +const ReadyCommand Command = "ready" + +const SetUserCommand Command = "setuser" + +// AuthorizeCommand is a S2C Command sent as part of Twitch username validation. +const AuthorizeCommand Command = "do_authorize" + +// AsyncResponseCommand is a pseudo-Reply Command. +// It indicates that the Reply Command to the client's C2S Command will be delivered +// on a goroutine over the ClientInfo.MessageChannel and should not be delivered immediately. +const AsyncResponseCommand Command = "_async" + +const defaultMinMemoryKB = 1024 * 24 + +// DotTwitchDotTv is the .twitch.tv suffix. +const DotTwitchDotTv = ".twitch.tv" + +const dotCbenniDotCom = ".cbenni.com" + +var OriginRegexp = regexp.MustCompile("(" + DotTwitchDotTv + "|" + dotCbenniDotCom + ")" + "$") + +// ResponseSuccess is a Reply ClientMessage with the MessageID not yet filled out. +var ResponseSuccess = ClientMessage{Command: SuccessCommand} + +// Configuration is the active ConfigFile. +var Configuration *ConfigFile + +var janitorsOnce sync.Once + +var CommandPool *StringPool +var PubSubChannelPool *StringPool +var TwitchChannelPool *StringPool + +// SetupServerAndHandle starts all background goroutines and registers HTTP listeners on the given ServeMux. +// Essentially, this function completely preps the server for a http.ListenAndServe call. +// (Uses http.DefaultServeMux if `serveMux` is nil.) +func SetupServerAndHandle(config *ConfigFile, serveMux *http.ServeMux) { + Configuration = config + + if config.MinMemoryKBytes == 0 { + config.MinMemoryKBytes = defaultMinMemoryKB + } + + Backend = setupBackend(config) + + if serveMux == nil { + serveMux = http.DefaultServeMux + } + + bannerBytes, err := ioutil.ReadFile("index.html") + if err != nil { + log.Fatalln("Could not open index.html:", err) + } + BannerHTML = bannerBytes + + serveMux.HandleFunc("/", HTTPHandleRootURL) + serveMux.Handle("/.well-known/", http.FileServer(http.Dir("/tmp/letsencrypt/"))) + serveMux.HandleFunc("/healthcheck", HTTPSayOK) + serveMux.HandleFunc("/stats", HTTPShowStatistics) + serveMux.HandleFunc("/hll/", HTTPShowHLL) + serveMux.HandleFunc("/hll_force_write", HTTPWriteHLL) + + serveMux.HandleFunc("/drop_backlog", HTTPBackendDropBacklog) + serveMux.HandleFunc("/uncached_pub", HTTPBackendUncachedPublish) + serveMux.HandleFunc("/cached_pub", HTTPBackendCachedPublish) + serveMux.HandleFunc("/get_sub_count", HTTPGetSubscriberCount) + + announceForm, err := Backend.SealRequest(url.Values{ + "startup": []string{"1"}, + }) + if err != nil { + log.Fatalln("Unable to seal requests:", err) + } + resp, err := Backend.HTTPClient.PostForm(Backend.announceStartupURL, announceForm) + if err != nil { + log.Println("could not announce startup to backend:", err) + } else { + resp.Body.Close() + Backend.lastSuccessLock.Lock() + Backend.lastSuccess[bPathAnnounceStartup] = time.Now().UTC() + Backend.lastSuccessLock.Unlock() + } + + janitorsOnce.Do(startJanitors) +} + +func init() { + setupInterning() +} + +// startJanitors starts the 'is_init_func' goroutines +func startJanitors() { + loadUniqueUsers() + + go authorizationJanitor() + go aggregateDataSender() + go bunchCacheJanitor() + go cachedMessageJanitor() + go commandCounter() + go pubsubJanitor() + + go ircConnection() + go shutdownHandler() +} + +// is_init_func +func shutdownHandler() { + ch := make(chan os.Signal) + signal.Notify(ch, syscall.SIGUSR1) + signal.Notify(ch, syscall.SIGTERM) + <-ch + log.Println("Shutting down...") + + var wg sync.WaitGroup + wg.Add(1) + go func() { + writeHLL() + wg.Done() + }() + + StopAcceptingConnections = true + close(StopAcceptingConnectionsCh) + + time.Sleep(1 * time.Second) + wg.Wait() + os.Exit(0) +} + +// is_init_func +test +func dumpStackOnCtrlZ() { + ch := make(chan os.Signal) + signal.Notify(ch, syscall.SIGTSTP) + for range ch { + fmt.Println("Got ^Z") + + buf := make([]byte, 10000) + byteCnt := runtime.Stack(buf, true) + fmt.Println(string(buf[:byteCnt])) + } +} + +// HTTPSayOK replies with 200 and a body of "ok\n". +func HTTPSayOK(w http.ResponseWriter, _ *http.Request) { + w.(interface { + WriteString(string) error + }).WriteString("ok\n") +} + +// SocketUpgrader is the websocket.Upgrader currently in use. +var SocketUpgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return r.Header.Get("Origin") == "" || OriginRegexp.MatchString(r.Header.Get("Origin")) + }, +} + +// BannerHTML is the content served to web browsers viewing the socket server website. +// Memes go here. +var BannerHTML []byte + +// StopAcceptingConnectionsCh is closed while the server is shutting down. +var StopAcceptingConnectionsCh = make(chan struct{}) +var StopAcceptingConnections = false + +// HTTPHandleRootURL is the http.HandleFunc for requests on `/`. +// It either uses the SocketUpgrader or writes out the BannerHTML. +func HTTPHandleRootURL(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" { + http.NotFound(w, r) + fmt.Println(404) + return + } + + // racy, but should be ok? + if StopAcceptingConnections { + w.WriteHeader(503) + fmt.Fprint(w, "server is shutting down") + return + } + + if strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade") { + updateSysMem() + + if Statistics.SysMemFreeKB > 0 && Statistics.SysMemFreeKB < Configuration.MinMemoryKBytes { + atomic.AddUint64(&Statistics.LowMemDroppedConnections, 1) + w.WriteHeader(503) + fmt.Fprint(w, "error: low memory") + return + } + + if Configuration.MaxClientCount != 0 { + curClients := atomic.LoadUint64(&Statistics.CurrentClientCount) + if curClients >= Configuration.MaxClientCount { + w.WriteHeader(503) + fmt.Fprint(w, "error: client limit reached") + return + } + } + + conn, err := SocketUpgrader.Upgrade(w, r, nil) + if err != nil { + fmt.Fprintf(w, "error: %v", err) + return + } + RunSocketConnection(conn) + + return + } else { + w.Write(BannerHTML) + } +} + +// ErrProtocolGeneric is sent in a ErrorCommand Reply. +var ErrProtocolGeneric error = errors.New("FFZ Socket protocol error.") + +// ErrProtocolNegativeMsgID is sent in a ErrorCommand Reply when a negative MessageID is received. +var ErrProtocolNegativeMsgID error = errors.New("FFZ Socket protocol error: negative or zero message ID.") + +// ErrExpectedSingleString is sent in a ErrorCommand Reply when the Arguments are of the wrong type. +var ErrExpectedSingleString = errors.New("Error: Expected single string as arguments.") + +// ErrExpectedSingleInt is sent in a ErrorCommand Reply when the Arguments are of the wrong type. +var ErrExpectedSingleInt = errors.New("Error: Expected single integer as arguments.") + +// ErrExpectedTwoStrings is sent in a ErrorCommand Reply when the Arguments are of the wrong type. +var ErrExpectedTwoStrings = errors.New("Error: Expected array of string, string as arguments.") + +// ErrExpectedStringAndBool is sent in a ErrorCommand Reply when the Arguments are of the wrong type. +var ErrExpectedStringAndBool = errors.New("Error: Expected array of string, bool as arguments.") + +// ErrExpectedStringAndInt is sent in a ErrorCommand Reply when the Arguments are of the wrong type. +var ErrExpectedStringAndInt = errors.New("Error: Expected array of string, int as arguments.") + +// ErrExpectedStringAndIntGotFloat is sent in a ErrorCommand Reply when the Arguments are of the wrong type. +var ErrExpectedStringAndIntGotFloat = errors.New("Error: Second argument was a float, expected an integer.") + +// CloseGoingAway is sent when the server is restarting. +var CloseGoingAway = websocket.CloseError{Code: websocket.CloseGoingAway, Text: "server restarting"} + +// CloseRebalance is sent when the server has too many clients and needs to shunt some to another server. +var CloseRebalance = websocket.CloseError{Code: websocket.CloseGoingAway, Text: "kicked for rebalancing, please select a new server"} + +// CloseGotBinaryMessage is the termination reason when the client sends a binary websocket frame. +var CloseGotBinaryMessage = websocket.CloseError{Code: websocket.CloseUnsupportedData, Text: "got binary packet"} + +// CloseTimedOut is the termination reason when the client fails to send or respond to ping frames. +var CloseTimedOut = websocket.CloseError{Code: 3003, Text: "no ping replies for 5 minutes"} + +// CloseTooManyBufferedMessages is the termination reason when the sending thread buffers too many messages. +var CloseTooManyBufferedMessages = websocket.CloseError{Code: websocket.CloseMessageTooBig, Text: "too many pending messages"} + +// CloseFirstMessageNotHello is the termination reason +var CloseFirstMessageNotHello = websocket.CloseError{ + Text: "Error - the first message sent must be a 'hello'", + Code: websocket.ClosePolicyViolation, +} + +var CloseNonUTF8Data = websocket.CloseError{ + Code: websocket.CloseUnsupportedData, + Text: "Non UTF8 data recieved. Network corruption likely.", +} + +const sendMessageBufferLength = 30 +const sendMessageAbortLength = 20 + +// RunSocketConnection contains the main run loop of a websocket connection. +// +// First, it sets up the channels, the ClientInfo object, and the pong frame handler. +// It starts the reader goroutine pointing at the newly created channels. +// The function then enters the run loop (a `for{select{}}`). +// The run loop is broken when an object is received on errorChan, or if `hello` is not the first C2S Command. +// +// After the run loop stops, the function launches a goroutine to drain +// client.MessageChannel, signals the reader goroutine to stop, unsubscribes +// from all pub/sub channels, waits on MsgChannelKeepalive (remember, the +// messages are being drained), and finally closes client.MessageChannel +// (which ends the drainer goroutine). +func RunSocketConnection(conn *websocket.Conn) { + // websocket.Conn is a ReadWriteCloser + + atomic.AddUint64(&Statistics.ClientConnectsTotal, 1) + atomic.AddUint64(&Statistics.CurrentClientCount, 1) + + _clientChan := make(chan ClientMessage) + _serverMessageChan := make(chan ClientMessage, sendMessageBufferLength) + _errorChan := make(chan error) + stoppedChan := make(chan struct{}) + + var client ClientInfo + client.MessageChannel = _serverMessageChan + client.RemoteAddr = conn.RemoteAddr() + client.MsgChannelIsDone = stoppedChan + + // var report logstasher.ConnectionReport + // report.ConnectTime = time.Now() + // report.RemoteAddr = client.RemoteAddr + + conn.SetPongHandler(func(pongBody string) error { + client.Mutex.Lock() + client.pingCount = 0 + client.Mutex.Unlock() + return nil + }) + + // All set up, now enter the work loop + go runSocketReader(conn, _errorChan, _clientChan, stoppedChan) + closeReason := runSocketWriter(conn, &client, _errorChan, _clientChan, _serverMessageChan) + + // Exit + closeConnection(conn, closeReason) + // closeConnection(conn, closeReason, &report) + + // Launch message draining goroutine - we aren't out of the pub/sub records + go func() { + for _ = range _serverMessageChan { + } + }() + + // Closes client.MsgChannelIsDone and also stops the reader thread + close(stoppedChan) + + // Stop getting messages... + UnsubscribeAll(&client) + + // Wait for pending jobs to finish... + client.MsgChannelKeepalive.Wait() + client.MessageChannel = nil + + // And done. + // Close the channel so the draining goroutine can finish, too. + close(_serverMessageChan) + + if !StopAcceptingConnections { + // Don't perform high contention operations when server is closing + atomic.AddUint64(&Statistics.CurrentClientCount, NegativeOne) + atomic.AddUint64(&Statistics.ClientDisconnectsTotal, 1) + + // report.UsernameWasValidated = client.UsernameValidated + // report.TwitchUsername = client.TwitchUsername + // logstasher.Submit(&report) + } +} + +func runSocketReader(conn *websocket.Conn, errorChan chan<- error, clientChan chan<- ClientMessage, stoppedChan <-chan struct{}) { + var msg ClientMessage + var messageType int + var packet []byte + var err error + + defer close(errorChan) + defer close(clientChan) + + for ; err == nil; messageType, packet, err = conn.ReadMessage() { + if messageType == websocket.BinaryMessage { + err = &CloseGotBinaryMessage + break + } + if messageType == websocket.CloseMessage { + err = io.EOF + break + } + + UnmarshalClientMessage(packet, messageType, &msg) + if msg.MessageID == 0 { + continue + } + select { + case clientChan <- msg: + case <-stoppedChan: + return + } + } + + select { + case errorChan <- err: + case <-stoppedChan: + } + // exit goroutine +} + +func runSocketWriter(conn *websocket.Conn, client *ClientInfo, errorChan <-chan error, clientChan <-chan ClientMessage, serverMessageChan <-chan ClientMessage) websocket.CloseError { + for { + select { + case err := <-errorChan: + if err == io.EOF { + return websocket.CloseError{ + Code: websocket.CloseGoingAway, + Text: err.Error(), + } + } else if closeMsg, isClose := err.(*websocket.CloseError); isClose { + return *closeMsg + } else { + return websocket.CloseError{ + Code: websocket.CloseInternalServerErr, + Text: err.Error(), + } + } + + case msg := <-clientChan: + if client.VersionString == "" && msg.Command != HelloCommand { + return CloseFirstMessageNotHello + } + + for _, char := range msg.Command { + if char == utf8.RuneError { + return CloseNonUTF8Data + } + } + + DispatchC2SCommand(conn, client, msg) + + case msg := <-serverMessageChan: + if len(serverMessageChan) > sendMessageAbortLength { + return CloseTooManyBufferedMessages + } + if cls, ok := msg.Arguments.(*websocket.CloseError); ok { + return *cls + } + SendMessage(conn, msg) + + case <-time.After(1 * time.Minute): + client.Mutex.Lock() + client.pingCount++ + tooManyPings := client.pingCount == 5 + client.Mutex.Unlock() + if tooManyPings { + return CloseTimedOut + } else { + conn.WriteControl(websocket.PingMessage, []byte(strconv.FormatInt(time.Now().Unix(), 10)), getDeadline()) + } + + case <-StopAcceptingConnectionsCh: + return CloseGoingAway + } + } +} + +func getDeadline() time.Time { + return time.Now().Add(1 * time.Minute) +} + +func closeConnection(conn *websocket.Conn, closeMsg websocket.CloseError) { + closeTxt := closeMsg.Text + if strings.Contains(closeTxt, "read: connection reset by peer") { + closeTxt = "read: connection reset by peer" + } else if strings.Contains(closeTxt, "use of closed network connection") { + closeTxt = "read: use of closed network connection" + } else if closeMsg.Code == 1001 { + closeTxt = "clean shutdown" + } + + // report.DisconnectCode = closeMsg.Code + // report.DisconnectReason = closeTxt + // report.DisconnectTime = time.Now() + + conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(closeMsg.Code, closeMsg.Text), getDeadline()) + conn.Close() +} + +// SendMessage sends a ClientMessage over the websocket connection with a timeout. +// If marshalling the ClientMessage fails, this function will panic. +func SendMessage(conn *websocket.Conn, msg ClientMessage) { + messageType, packet, err := MarshalClientMessage(msg) + if err != nil { + panic(fmt.Sprintf("failed to marshal: %v %v", err, msg)) + } + conn.SetWriteDeadline(getDeadline()) + conn.WriteMessage(messageType, packet) + atomic.AddUint64(&Statistics.MessagesSent, 1) +} + +// UnmarshalClientMessage unpacks websocket TextMessage into a ClientMessage provided in the `v` parameter. +func UnmarshalClientMessage(data []byte, payloadType int, v interface{}) (err error) { + var spaceIdx int + + out := v.(*ClientMessage) + dataStr := string(data) + + // Message ID + spaceIdx = strings.IndexRune(dataStr, ' ') + if spaceIdx == -1 { + return ErrProtocolGeneric + } + messageID, err := strconv.Atoi(dataStr[:spaceIdx]) + if messageID < -1 || messageID == 0 { + return ErrProtocolNegativeMsgID + } + + out.MessageID = messageID + dataStr = dataStr[spaceIdx+1:] + + spaceIdx = strings.IndexRune(dataStr, ' ') + if spaceIdx == -1 { + out.Command = CommandPool.InternCommand(dataStr) + out.Arguments = nil + return nil + } else { + out.Command = CommandPool.InternCommand(dataStr[:spaceIdx]) + } + dataStr = dataStr[spaceIdx+1:] + argumentsJSON := string([]byte(dataStr)) + out.origArguments = argumentsJSON + err = out.parseOrigArguments() + if err != nil { + return + } + return nil +} + +func (cm *ClientMessage) parseOrigArguments() error { + err := json.Unmarshal([]byte(cm.origArguments), &cm.Arguments) + if err != nil { + return err + } + return nil +} + +func MarshalClientMessage(clientMessage interface{}) (payloadType int, data []byte, err error) { + var msg ClientMessage + var ok bool + msg, ok = clientMessage.(ClientMessage) + if !ok { + pMsg, ok := clientMessage.(*ClientMessage) + if !ok { + panic("MarshalClientMessage: argument needs to be a ClientMessage") + } + msg = *pMsg + } + var dataStr string + + if msg.Command == "" && msg.MessageID == 0 { + panic("MarshalClientMessage: attempt to send an empty ClientMessage") + } + + if msg.Command == "" { + msg.Command = SuccessCommand + } + if msg.MessageID == 0 { + msg.MessageID = -1 + } + + if msg.Arguments != nil { + argBytes, err := json.Marshal(msg.Arguments) + if err != nil { + return 0, nil, err + } + + dataStr = fmt.Sprintf("%d %s %s", msg.MessageID, msg.Command, string(argBytes)) + } else { + dataStr = fmt.Sprintf("%d %s", msg.MessageID, msg.Command) + } + + return websocket.TextMessage, []byte(dataStr), nil +} + +// ArgumentsAsString parses the arguments of the ClientMessage as a single string. +func (cm *ClientMessage) ArgumentsAsString() (string1 string, err error) { + var ok bool + string1, ok = cm.Arguments.(string) + if !ok { + err = ErrExpectedSingleString + return + } else { + return string1, nil + } +} + +// ArgumentsAsInt parses the arguments of the ClientMessage as a single int. +func (cm *ClientMessage) ArgumentsAsInt() (int1 int64, err error) { + var ok bool + var num float64 + num, ok = cm.Arguments.(float64) + if !ok { + err = ErrExpectedSingleInt + return + } else { + int1 = int64(num) + return int1, nil + } +} + +// ArgumentsAsTwoStrings parses the arguments of the ClientMessage as an array of two strings. +func (cm *ClientMessage) ArgumentsAsTwoStrings() (string1, string2 string, err error) { + var ok bool + var ary []interface{} + ary, ok = cm.Arguments.([]interface{}) + if !ok { + err = ErrExpectedTwoStrings + return + } else { + if len(ary) != 2 { + err = ErrExpectedTwoStrings + return + } + string1, ok = ary[0].(string) + if !ok { + err = ErrExpectedTwoStrings + return + } + // clientID can be null + if ary[1] == nil { + return string1, "", nil + } + string2, ok = ary[1].(string) + if !ok { + err = ErrExpectedTwoStrings + return + } + return string1, string2, nil + } +} + +// ArgumentsAsStringAndInt parses the arguments of the ClientMessage as an array of a string and an int. +func (cm *ClientMessage) ArgumentsAsStringAndInt() (string1 string, int int64, err error) { + var ok bool + var ary []interface{} + ary, ok = cm.Arguments.([]interface{}) + if !ok { + err = ErrExpectedStringAndInt + return + } else { + if len(ary) != 2 { + err = ErrExpectedStringAndInt + return + } + string1, ok = ary[0].(string) + if !ok { + err = ErrExpectedStringAndInt + return + } + var num float64 + num, ok = ary[1].(float64) + if !ok { + err = ErrExpectedStringAndInt + return + } + int = int64(num) + if float64(int) != num { + err = ErrExpectedStringAndIntGotFloat + return + } + return string1, int, nil + } +} + +// ArgumentsAsStringAndBool parses the arguments of the ClientMessage as an array of a string and an int. +func (cm *ClientMessage) ArgumentsAsStringAndBool() (str string, flag bool, err error) { + var ok bool + var ary []interface{} + ary, ok = cm.Arguments.([]interface{}) + if !ok { + err = ErrExpectedStringAndBool + return + } else { + if len(ary) != 2 { + err = ErrExpectedStringAndBool + return + } + str, ok = ary[0].(string) + if !ok { + err = ErrExpectedStringAndBool + return + } + flag, ok = ary[1].(bool) + if !ok { + err = ErrExpectedStringAndBool + return + } + return str, flag, nil + } +} diff --git a/socketserver/internal/server/handlecore_test.go b/socketserver/server/handlecore_test.go similarity index 79% rename from socketserver/internal/server/handlecore_test.go rename to socketserver/server/handlecore_test.go index 161b5921..8ecf848b 100644 --- a/socketserver/internal/server/handlecore_test.go +++ b/socketserver/server/handlecore_test.go @@ -2,14 +2,15 @@ package server import ( "fmt" - "golang.org/x/net/websocket" "testing" + + "github.com/gorilla/websocket" ) func ExampleUnmarshalClientMessage() { sourceData := []byte("100 hello [\"ffz_3.5.30\",\"898b5bfa-b577-47bb-afb4-252c703b67d6\"]") var cm ClientMessage - err := UnmarshalClientMessage(sourceData, websocket.TextFrame, &cm) + err := UnmarshalClientMessage(sourceData, websocket.TextMessage, &cm) fmt.Println(err) fmt.Println(cm.MessageID) fmt.Println(cm.Command) @@ -27,9 +28,9 @@ func ExampleMarshalClientMessage() { Command: "do_authorize", Arguments: "1234567890", } - data, payloadType, err := MarshalClientMessage(&cm) + payloadType, data, err := MarshalClientMessage(&cm) fmt.Println(err) - fmt.Println(payloadType == websocket.TextFrame) + fmt.Println(payloadType == websocket.TextMessage) fmt.Println(string(data)) // Output: // @@ -40,7 +41,7 @@ func ExampleMarshalClientMessage() { func TestArgumentsAsStringAndBool(t *testing.T) { sourceData := []byte("1 foo [\"string\", false]") var cm ClientMessage - err := UnmarshalClientMessage(sourceData, websocket.TextFrame, &cm) + err := UnmarshalClientMessage(sourceData, websocket.TextMessage, &cm) if err != nil { t.Fatal(err) } diff --git a/socketserver/server/intern.go b/socketserver/server/intern.go new file mode 100644 index 00000000..2f9cf416 --- /dev/null +++ b/socketserver/server/intern.go @@ -0,0 +1,42 @@ +package server + +import ( + "sync" +) + +type StringPool struct { + sync.RWMutex + lookup map[string]string +} + +func NewStringPool() *StringPool { + return &StringPool{lookup: make(map[string]string)} +} + +// doesn't lock, doesn't check for dupes. +func (p *StringPool) _Intern_Setup(s string) { + p.lookup[s] = s +} + +func (p *StringPool) InternCommand(s string) Command { + return Command(p.Intern(s)) +} + +func (p *StringPool) Intern(s string) string { + p.RLock() + ss, exists := p.lookup[s] + p.RUnlock() + if exists { + return ss + } + + p.Lock() + defer p.Unlock() + ss, exists = p.lookup[s] + if exists { + return ss + } + ss = copyString(s) + p.lookup[ss] = ss + return ss +} diff --git a/socketserver/server/irc.go b/socketserver/server/irc.go new file mode 100644 index 00000000..a5a837dc --- /dev/null +++ b/socketserver/server/irc.go @@ -0,0 +1,181 @@ +package server + +import ( + "bytes" + "crypto/rand" + "encoding/base64" + "log" + "strings" + "sync" + "time" + + irc "github.com/fluffle/goirc/client" +) + +type AuthCallback func(client *ClientInfo, successful bool) + +type PendingAuthorization struct { + Client *ClientInfo + Challenge string + Callback AuthCallback + EnteredAt time.Time +} + +var PendingAuths []PendingAuthorization +var PendingAuthLock sync.Mutex + +func AddPendingAuthorization(client *ClientInfo, challenge string, callback AuthCallback) { + PendingAuthLock.Lock() + defer PendingAuthLock.Unlock() + + PendingAuths = append(PendingAuths, PendingAuthorization{ + Client: client, + Challenge: challenge, + Callback: callback, + EnteredAt: time.Now(), + }) +} + +// is_init_func +func authorizationJanitor() { + for { + time.Sleep(5 * time.Minute) + authorizationJanitor_do() + } +} + +func authorizationJanitor_do() { + cullTime := time.Now().Add(-30 * time.Minute) + + PendingAuthLock.Lock() + defer PendingAuthLock.Unlock() + + newPendingAuths := make([]PendingAuthorization, 0, len(PendingAuths)) + + for _, v := range PendingAuths { + if !cullTime.After(v.EnteredAt) { + newPendingAuths = append(newPendingAuths, v) + } else { + go v.Callback(v.Client, false) + } + } + + PendingAuths = newPendingAuths +} + +func (client *ClientInfo) StartAuthorization(callback AuthCallback) { + if callback == nil { + return // callback must not be nil + } + var nonce [32]byte + _, err := rand.Read(nonce[:]) + if err != nil { + go callback(client, false) + return + } + buf := bytes.NewBuffer(nil) + enc := base64.NewEncoder(base64.RawURLEncoding, buf) + enc.Write(nonce[:]) + enc.Close() + challenge := buf.String() + + AddPendingAuthorization(client, challenge, callback) + + client.MessageChannel <- ClientMessage{MessageID: -1, Command: AuthorizeCommand, Arguments: challenge} +} + +const AuthChannelName = "frankerfacezauthorizer" +const AuthChannel = "#" + AuthChannelName +const AuthCommand = "AUTH" + +var authIrcConnection *irc.Conn + +// is_init_func +func ircConnection() { + c := irc.SimpleClient("justinfan123") + c.Config().Server = "irc.chat.twitch.tv" + authIrcConnection = c + + var reconnect func(conn *irc.Conn) + connect := func(conn *irc.Conn) { + err := c.Connect() + if err != nil { + log.Println("irc: failed to connect to IRC:", err) + go reconnect(conn) + } + } + + reconnect = func(conn *irc.Conn) { + time.Sleep(5 * time.Second) + log.Println("irc: Reconnecting…") + connect(conn) + } + + c.HandleFunc(irc.CONNECTED, func(conn *irc.Conn, line *irc.Line) { + conn.Join(AuthChannel) + }) + + c.HandleFunc(irc.DISCONNECTED, func(conn *irc.Conn, line *irc.Line) { + log.Println("irc: Disconnected. Reconnecting in 5 seconds.") + go reconnect(conn) + }) + + c.HandleFunc(irc.PRIVMSG, func(conn *irc.Conn, line *irc.Line) { + channel := line.Args[0] + msg := line.Args[1] + if channel != AuthChannel || !strings.HasPrefix(msg, AuthCommand) || !line.Public() { + return + } + + msgArray := strings.Split(msg, " ") + if len(msgArray) != 2 { + return + } + + submittedUser := line.Nick + submittedChallenge := msgArray[1] + + submitAuth(submittedUser, submittedChallenge) + }) + + connect(c) +} + +func submitAuth(user, challenge string) { + var auth PendingAuthorization + var idx int = -1 + + PendingAuthLock.Lock() + for i, v := range PendingAuths { + if v.Client.TwitchUsername == user && v.Challenge == challenge { + auth = v + idx = i + break + } + } + if idx != -1 { + PendingAuths = append(PendingAuths[:idx], PendingAuths[idx+1:]...) + } + PendingAuthLock.Unlock() + + if idx == -1 { + return // perhaps it was for another socket server + } + + // auth is valid, and removed from pending list + + var usernameChanged bool + auth.Client.Mutex.Lock() + if auth.Client.TwitchUsername == user { // recheck condition + auth.Client.UsernameValidated = true + } else { + usernameChanged = true + } + auth.Client.Mutex.Unlock() + + if !usernameChanged { + auth.Callback(auth.Client, true) + } else { + auth.Callback(auth.Client, false) + } +} diff --git a/socketserver/server/logstasher/elasticsearch.go b/socketserver/server/logstasher/elasticsearch.go new file mode 100644 index 00000000..c505e00e --- /dev/null +++ b/socketserver/server/logstasher/elasticsearch.go @@ -0,0 +1,209 @@ +package logstasher + +import ( + "bytes" + "crypto/rand" + "encoding/base64" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "sync" + "time" +) + +// ID is a 128-bit ID for an elasticsearch document. +// Textually, it is base64-encoded. +// The Next() method increments the ID. +type ID struct { + High uint64 + Low uint64 +} + +// Text converts the ID into a base64 string. +func (id ID) String() string { + var buf bytes.Buffer + buf.Grow(21) + enc := base64.NewEncoder(base64.StdEncoding, &buf) + var bytes [16]byte + binary.LittleEndian.PutUint64(bytes[0:8], id.High) + binary.LittleEndian.PutUint64(bytes[8:16], id.Low) + enc.Write(bytes[:]) + enc.Close() + return buf.String() +} + +// Next increments the ID and returns the prior state. +// Overflow is not checked because it's a uint64, do you really expect me to overflow that +func (id *ID) Next() ID { + ret := ID{ + High: id.High, + Low: id.Low, + } + id.Low++ + return ret +} + +var idPool = sync.Pool{New: func() interface{} { + var bytes [16]byte + n, err := rand.Reader.Read(bytes[:]) + if n != 16 || err != nil { + panic(fmt.Errorf("Short read from crypto/rand: %v", err)) + } + + return &ID{ + High: binary.LittleEndian.Uint64(bytes[0:8]), + Low: binary.LittleEndian.Uint64(bytes[8:16]), + } +}} + +func ExampleID_Next() { + id := idPool.Get().(*ID).Next() + fmt.Println(id) + idPool.Put(id) +} + +// Report is the interface presented to the Submit() function. +// FillReport() is satisfied by ReportBasic, but ReportType must always be specified. +type Report interface { + FillReport() error + ReportType() string + + GetID() string + GetTimestamp() time.Time +} + +// ReportBasic is the essential fields of any report. +type ReportBasic struct { + ID string + Timestamp time.Time + Host string +} + +// FillReport sets the Host and Timestamp fields. +func (report *ReportBasic) FillReport() error { + report.Host = hostMarker + report.Timestamp = time.Now() + id := idPool.Get().(*ID).Next() + report.ID = id.String() + idPool.Put(id) + return nil +} + +func (report *ReportBasic) GetID() string { + return report.ID +} + +func (report *ReportBasic) GetTimestamp() time.Time { + return report.Timestamp +} + +type ConnectionReport struct { + ReportBasic + + ConnectTime time.Time + DisconnectTime time.Time + // calculated + ConnectionDuration time.Duration + + DisconnectCode int + DisconnectReason string + + UsernameWasValidated bool + + RemoteAddr net.Addr `json:"-"` // not transmitted until I can figure out data minimization + TwitchUsername string `json:"-"` // also not transmitted +} + +// FillReport sets all the calculated fields, and calls esReportBasic.FillReport(). +func (report *ConnectionReport) FillReport() error { + report.ReportBasic.FillReport() + report.ConnectionDuration = report.DisconnectTime.Sub(report.ConnectTime) + return nil +} + +func (report *ConnectionReport) ReportType() string { + return "conn" +} + +var serverPresent bool +var esClient http.Client +var submitChan chan Report +var serverBase, indexPrefix, hostMarker string + +func checkServerPresent() { + if serverBase == "" { + serverBase = "http://localhost:9200" + } + if indexPrefix == "" { + indexPrefix = "sockreport" + } + + urlHealth := fmt.Sprintf("%s/_cluster/health", serverBase) + resp, err := esClient.Get(urlHealth) + if err == nil { + resp.Body.Close() + serverPresent = true + submitChan = make(chan Report, 8) + fmt.Println("elasticsearch reports enabled") + go submissionWorker() + } else { + serverPresent = false + } +} + +// Setup sets up the global variables for the package. +func Setup(ESServer, ESIndexPrefix, ESHostname string) { + serverBase = ESServer + indexPrefix = ESIndexPrefix + hostMarker = ESHostname + checkServerPresent() +} + +// Submit inserts a report into elasticsearch (this is basically a manual logstash). +func Submit(report Report) { + if !serverPresent { + return + } + + report.FillReport() + submitChan <- report +} + +func submissionWorker() { + for report := range submitChan { + time := report.GetTimestamp() + rType := report.ReportType() + + // prefix-type-date + indexName := fmt.Sprintf("%s-%s-%d-%d-%d", indexPrefix, rType, time.Year(), time.Month(), time.Day()) + // base/index/type/id + putUrl, err := url.Parse(fmt.Sprintf("%s/%s/%s/%s", serverBase, indexName, rType, report.GetID())) + if err != nil { + panic(fmt.Errorf("logstash: cannot parse url: %v", err)) + } + body, err := json.Marshal(report) + if err != nil { + panic(fmt.Errorf("logstash: cannot marshal json: %v", err)) + } + + req := &http.Request{ + Method: "PUT", + URL: putUrl, + Body: ioutil.NopCloser(bytes.NewReader(body)), + } + + resp, err := esClient.Do(req) + + if err != nil { + // ignore, the show must go on + } else { + io.Copy(ioutil.Discard, resp.Body) + resp.Body.Close() + } + } +} diff --git a/socketserver/server/publisher.go b/socketserver/server/publisher.go new file mode 100644 index 00000000..85912dd6 --- /dev/null +++ b/socketserver/server/publisher.go @@ -0,0 +1,239 @@ +package server + +import ( + "fmt" + "net/http" + "strconv" + "strings" + "sync" + "time" +) + +type LastSavedMessage struct { + Expires time.Time + Data string +} + +// map is command -> channel -> data + +// CachedLastMessages is of CacheTypeLastOnly. +// Not actually cleaned up by reaper goroutine every ~hour. +var CachedLastMessages = make(map[Command]map[string]LastSavedMessage) +var CachedLSMLock sync.RWMutex + +func cachedMessageJanitor() { + for { + time.Sleep(1*time.Hour) + cachedMessageJanitor_do() + } +} + +func cachedMessageJanitor_do() { + CachedLSMLock.Lock() + defer CachedLSMLock.Unlock() + + now := time.Now() + + for cmd, chanMap := range CachedLastMessages { + for channel, msg := range chanMap { + if !msg.Expires.IsZero() && msg.Expires.Before(now) { + delete(chanMap, channel) + } + } + if len(chanMap) == 0 { + delete(CachedLastMessages, cmd) + } + } +} + +// DumpBacklogData drops all /cached_pub data. +func DumpBacklogData() { + CachedLSMLock.Lock() + CachedLastMessages = make(map[Command]map[string]LastSavedMessage) + CachedLSMLock.Unlock() +} + +// SendBacklogForNewClient sends any backlog data relevant to a new client. +// This should be done when the client sends a `ready` message. +// This will only send data for CacheTypePersistent and CacheTypeLastOnly because those do not involve timestamps. +func SendBacklogForNewClient(client *ClientInfo) { + client.Mutex.Lock() // reading CurrentChannels + curChannels := make([]string, len(client.CurrentChannels)) + copy(curChannels, client.CurrentChannels) + client.Mutex.Unlock() + + CachedLSMLock.RLock() + for cmd, chanMap := range CachedLastMessages { + if chanMap == nil { + continue + } + for _, channel := range curChannels { + msg, ok := chanMap[channel] + if ok { + msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data} + msg.parseOrigArguments() + client.MessageChannel <- msg + } + } + } + CachedLSMLock.RUnlock() +} + +func SendBacklogForChannel(client *ClientInfo, channel string) { + CachedLSMLock.RLock() + for cmd, chanMap := range CachedLastMessages { + if chanMap == nil { + continue + } + if msg, ok := chanMap[channel]; ok { + msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data} + msg.parseOrigArguments() + client.MessageChannel <- msg + } + } + CachedLSMLock.RUnlock() +} + +type timestampArray interface { + Len() int + GetTime(int) time.Time +} + +// the CachedLSMLock must be held when calling this +func saveLastMessage(cmd Command, channel string, expires time.Time, data string, deleting bool) { + chanMap, ok := CachedLastMessages[cmd] + if !ok { + if deleting { + return + } + chanMap = make(map[string]LastSavedMessage) + CachedLastMessages[cmd] = chanMap + } + + if deleting { + delete(chanMap, channel) + } else { + chanMap[channel] = LastSavedMessage{Expires: expires, Data: data} + } +} + +func HTTPBackendDropBacklog(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + formData, err := Backend.UnsealRequest(r.Form) + if err != nil { + w.WriteHeader(403) + fmt.Fprintf(w, "Error: %v", err) + return + } + + confirm := formData.Get("confirm") + if confirm == "1" { + DumpBacklogData() + } +} + +// HTTPBackendCachedPublish handles the /cached_pub route. +// It publishes a message to clients, and then updates the in-server cache for the message. +// +// The 'channel' parameter is a comma-separated list of topics to publish the message to. +// The 'args' parameter is the JSON-encoded command data. +// If the 'delete' parameter is present, an entry is removed from the cache instead of publishing a message. +// If the 'expires' parameter is not specified, the message will not expire (though it is only kept in-memory). +func HTTPBackendCachedPublish(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + formData, err := Backend.UnsealRequest(r.Form) + if err != nil { + w.WriteHeader(403) + fmt.Fprintf(w, "Error: %v", err) + return + } + + cmd := CommandPool.InternCommand(formData.Get("cmd")) + json := formData.Get("args") + channel := formData.Get("channel") + deleteMode := formData.Get("delete") != "" + timeStr := formData.Get("expires") + var expires time.Time + if timeStr != "" { + timeNum, err := strconv.ParseInt(timeStr, 10, 64) + if err != nil { + w.WriteHeader(422) + fmt.Fprintf(w, "error parsing time: %v", err) + return + } + expires = time.Unix(timeNum, 0) + } + + var count int + msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: json} + msg.parseOrigArguments() + + channels := strings.Split(channel, ",") + CachedLSMLock.Lock() + for _, channel := range channels { + saveLastMessage(cmd, channel, expires, json, deleteMode) + } + CachedLSMLock.Unlock() + count = PublishToMultiple(channels, msg) + + w.Write([]byte(strconv.Itoa(count))) +} + +// HTTPBackendUncachedPublish handles the /uncached_pub route. +// The backend can POST here to publish a message to clients with no caching. +// The POST arguments are `cmd`, `args`, `channel`, and `scope`. +// If "scope" is "global", then "channel" is not used. +func HTTPBackendUncachedPublish(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + formData, err := Backend.UnsealRequest(r.Form) + if err != nil { + w.WriteHeader(403) + fmt.Fprintf(w, "Error: %v", err) + return + } + + cmd := formData.Get("cmd") + json := formData.Get("args") + channel := formData.Get("channel") + scope := formData.Get("scope") + + if cmd == "" { + w.WriteHeader(422) + fmt.Fprintf(w, "Error: cmd cannot be blank") + return + } + if channel == "" && scope != "global" { + w.WriteHeader(422) + fmt.Fprintf(w, "Error: channel must be specified") + return + } + + cm := ClientMessage{MessageID: -1, Command: CommandPool.InternCommand(cmd), origArguments: json} + cm.parseOrigArguments() + var count int + + switch scope { + default: + count = PublishToMultiple(strings.Split(channel, ","), cm) + case "global": + count = PublishToAll(cm) + } + fmt.Fprint(w, count) +} + +// HTTPGetSubscriberCount handles the /get_sub_count route. +// It replies with the number of clients subscribed to a pub/sub topic. +// A "global" option is not available, use fetch(/stats).CurrentClientCount instead. +func HTTPGetSubscriberCount(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + formData, err := Backend.UnsealRequest(r.Form) + if err != nil { + w.WriteHeader(403) + fmt.Fprintf(w, "Error: %v", err) + return + } + + channel := formData.Get("channel") + + fmt.Fprint(w, CountSubscriptions(strings.Split(channel, ","))) +} \ No newline at end of file diff --git a/socketserver/server/publisher_test.go b/socketserver/server/publisher_test.go new file mode 100644 index 00000000..ad667d1b --- /dev/null +++ b/socketserver/server/publisher_test.go @@ -0,0 +1,66 @@ +package server + +import ( + "testing" + "time" +) + +func TestExpiredCleanup(t *testing.T) { + const cmd = "test_command" + const channel = "trihex" + const channel2 = "twitch" + const channel3 = "360chrism" + const channel4 = "qa_partner" + + DumpBacklogData() + defer DumpBacklogData() + + var zeroTime time.Time + hourAgo := time.Now().Add(-1*time.Hour) + now := time.Now() + hourFromNow := time.Now().Add(1*time.Hour) + + saveLastMessage(cmd, channel, hourAgo, "1", false) + saveLastMessage(cmd, channel2, now, "2", false) + + if len(CachedLastMessages) != 1 { + t.Error("messages not saved") + } + if len(CachedLastMessages[cmd]) != 2{ + t.Error("messages not saved") + } + + time.Sleep(2*time.Millisecond) + + cachedMessageJanitor_do() + + if len(CachedLastMessages) != 0 { + t.Error("messages still present") + } + + saveLastMessage(cmd, channel, hourAgo, "1", false) + saveLastMessage(cmd, channel2, now, "2", false) + saveLastMessage(cmd, channel3, hourFromNow, "3", false) + saveLastMessage(cmd, channel4, zeroTime, "4", false) + + if len(CachedLastMessages[cmd]) != 4 { + t.Error("messages not saved") + } + + time.Sleep(2*time.Millisecond) + + cachedMessageJanitor_do() + + if len(CachedLastMessages) != 1 { + t.Error("messages not saved") + } + if len(CachedLastMessages[cmd]) != 2 { + t.Error("messages not saved") + } + if CachedLastMessages[cmd][channel3].Data != "3" { + t.Error("saved wrong message") + } + if CachedLastMessages[cmd][channel4].Data != "4" { + t.Error("saved wrong message") + } +} diff --git a/socketserver/server/stats.go b/socketserver/server/stats.go new file mode 100644 index 00000000..a9401248 --- /dev/null +++ b/socketserver/server/stats.go @@ -0,0 +1,216 @@ +package server + +import ( + "bytes" + "encoding/json" + "net/http" + "runtime" + "sync" + "time" + + linuxproc "github.com/c9s/goprocinfo/linux" +) + +type StatsData struct { + StatsDataVersion int + + StartTime time.Time + Uptime string + BuildTime string + BuildHash string + + CachedStatsLastUpdate time.Time + + Health struct { + IRC bool + Backend map[string]time.Time + } + + CurrentClientCount uint64 + + PubSubChannelCount int + + SysMemTotalKB uint64 + SysMemFreeKB uint64 + MemoryInUseKB uint64 + MemoryRSSKB uint64 + + LowMemDroppedConnections uint64 + + MemPerClientBytes uint64 + + CpuUsagePct float64 + + ClientConnectsTotal uint64 + ClientDisconnectsTotal uint64 + + ClientVersions map[string]uint64 + + DisconnectCodes map[string]uint64 + + CommandsIssuedTotal uint64 + CommandsIssuedMap map[Command]uint64 + + MessagesSent uint64 + + EmotesReportedTotal uint64 + + BackendVerifyFails uint64 + + // DisconnectReasons is at the bottom because it has indeterminate size + DisconnectReasons map[string]uint64 +} + +// Statistics is several variables that get incremented during normal operation of the server. +// Its structure should be versioned as it is exposed via JSON. +// +// Note as to threaded access - this is soft/fun data and not critical to data integrity. +// Fix anything that -race turns up, but otherwise it's not too much of a problem. +var Statistics = newStatsData() + +// CommandCounter is a channel for race-free counting of command usage. +var CommandCounter = make(chan Command, 10) + +// commandCounter receives from the CommandCounter channel and uses the value to increment the values in Statistics. +// is_init_func +func commandCounter() { + for cmd := range CommandCounter { + Statistics.CommandsIssuedTotal++ + Statistics.CommandsIssuedMap[cmd]++ + } +} + +// StatsDataVersion is the version of the StatsData struct. +const StatsDataVersion = 6 +const pageSize = 4096 + +var cpuUsage struct { + UserTime uint64 + SysTime uint64 +} + +func newStatsData() *StatsData { + return &StatsData{ + StartTime: time.Now(), + CommandsIssuedMap: make(map[Command]uint64), + DisconnectCodes: make(map[string]uint64), + DisconnectReasons: make(map[string]uint64), + ClientVersions: make(map[string]uint64), + StatsDataVersion: StatsDataVersion, + Health: struct { + IRC bool + Backend map[string]time.Time + }{ + Backend: make(map[string]time.Time), + }, + } +} + +// SetBuildStamp should be called from the main package to identify the git build hash and build time. +func SetBuildStamp(buildTime, buildHash string) { + Statistics.BuildTime = buildTime + Statistics.BuildHash = buildHash +} + +func updateStatsIfNeeded() { + if time.Now().Add(-2 * time.Second).After(Statistics.CachedStatsLastUpdate) { + updatePeriodicStats() + } +} + +func updatePeriodicStats() { + nowUpdate := time.Now() + timeDiff := nowUpdate.Sub(Statistics.CachedStatsLastUpdate) + Statistics.CachedStatsLastUpdate = nowUpdate + + { + m := runtime.MemStats{} + runtime.ReadMemStats(&m) + + Statistics.MemoryInUseKB = m.Alloc / 1024 + } + + { + pstat, err := linuxproc.ReadProcessStat("/proc/self/stat") + if err == nil { + userTicks := pstat.Utime - cpuUsage.UserTime + sysTicks := pstat.Stime - cpuUsage.SysTime + cpuUsage.UserTime = pstat.Utime + cpuUsage.SysTime = pstat.Stime + + Statistics.CpuUsagePct = 100 * float64(userTicks+sysTicks) / (timeDiff.Seconds() * float64(ticksPerSecond)) + Statistics.MemoryRSSKB = uint64(pstat.Rss * pageSize / 1024) + Statistics.MemPerClientBytes = (Statistics.MemoryRSSKB * 1024) / (Statistics.CurrentClientCount + 1) + } + updateSysMem() + } + + { + ChatSubscriptionLock.RLock() + Statistics.PubSubChannelCount = len(ChatSubscriptionInfo) + ChatSubscriptionLock.RUnlock() + + GlobalSubscriptionLock.RLock() + + Statistics.CurrentClientCount = uint64(len(GlobalSubscriptionInfo)) + versions := make(map[string]uint64) + for _, v := range GlobalSubscriptionInfo { + versions[v.VersionString]++ + } + Statistics.ClientVersions = versions + + GlobalSubscriptionLock.RUnlock() + } + + { + Statistics.Uptime = nowUpdate.Sub(Statistics.StartTime).String() + } + + { + Statistics.Health.IRC = authIrcConnection.Connected() + Backend.lastSuccessLock.Lock() + for k, v := range Backend.lastSuccess { + Statistics.Health.Backend[k] = v + } + Backend.lastSuccessLock.Unlock() + } +} + +var sysMemLastUpdate time.Time +var sysMemUpdateLock sync.Mutex + +// updateSysMem reads the system's available RAM. +func updateSysMem() { + if time.Now().Add(-2 * time.Second).After(sysMemLastUpdate) { + sysMemUpdateLock.Lock() + defer sysMemUpdateLock.Unlock() + if !time.Now().Add(-2 * time.Second).After(sysMemLastUpdate) { + return + } + } else { + return + } + sysMemLastUpdate = time.Now() + memInfo, err := linuxproc.ReadMemInfo("/proc/meminfo") + if err == nil { + Statistics.SysMemTotalKB = memInfo.MemTotal + Statistics.SysMemFreeKB = memInfo.MemAvailable + } + + { + writeHLL() + } +} + +// HTTPShowStatistics handles the /stats endpoint. It writes out the Statistics object as indented JSON. +func HTTPShowStatistics(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + updateStatsIfNeeded() + + jsonBytes, _ := json.Marshal(Statistics) + outBuf := bytes.NewBuffer(nil) + json.Indent(outBuf, jsonBytes, "", "\t") + + outBuf.WriteTo(w) +} diff --git a/socketserver/internal/server/publisher.go b/socketserver/server/subscriptions.go similarity index 59% rename from socketserver/internal/server/publisher.go rename to socketserver/server/subscriptions.go index d9658ac7..30bc4112 100644 --- a/socketserver/internal/server/publisher.go +++ b/socketserver/server/subscriptions.go @@ -4,6 +4,7 @@ package server // If I screwed up the locking, I won't know until it's too late. import ( + "log" "sync" "time" ) @@ -15,9 +16,43 @@ type SubscriberList struct { var ChatSubscriptionInfo map[string]*SubscriberList = make(map[string]*SubscriberList) var ChatSubscriptionLock sync.RWMutex -var GlobalSubscriptionInfo SubscriberList +var GlobalSubscriptionInfo []*ClientInfo +var GlobalSubscriptionLock sync.RWMutex -func PublishToChat(channel string, msg ClientMessage) (count int) { +func CountSubscriptions(channels []string) int { + ChatSubscriptionLock.RLock() + defer ChatSubscriptionLock.RUnlock() + + count := 0 + for _, channelName := range channels { + list := ChatSubscriptionInfo[channelName] + if list != nil { + list.RLock() + count += len(list.Members) + list.RUnlock() + } + } + + return count +} + +func SubscribeChannel(client *ClientInfo, channelName string) { + ChatSubscriptionLock.RLock() + _subscribeWhileRlocked(channelName, client.MessageChannel) + ChatSubscriptionLock.RUnlock() +} + +func SubscribeDefaults(client *ClientInfo) { + +} + +func SubscribeGlobal(client *ClientInfo) { + GlobalSubscriptionLock.Lock() + AddToSliceCl(&GlobalSubscriptionInfo, client) + GlobalSubscriptionLock.Unlock() +} + +func PublishToChannel(channel string, msg ClientMessage) (count int) { ChatSubscriptionLock.RLock() list := ChatSubscriptionInfo[channel] if list != nil { @@ -58,15 +93,99 @@ func PublishToMultiple(channels []string, msg ClientMessage) (count int) { } func PublishToAll(msg ClientMessage) (count int) { - GlobalSubscriptionInfo.RLock() - for _, msgChan := range GlobalSubscriptionInfo.Members { - msgChan <- msg + GlobalSubscriptionLock.RLock() + for _, client := range GlobalSubscriptionInfo { + select { + case client.MessageChannel <- msg: + case <-client.MsgChannelIsDone: + } count++ } - GlobalSubscriptionInfo.RUnlock() + GlobalSubscriptionLock.RUnlock() return } +func UnsubscribeSingleChat(client *ClientInfo, channelName string) { + ChatSubscriptionLock.RLock() + list := ChatSubscriptionInfo[channelName] + if list != nil { + list.Lock() + RemoveFromSliceC(&list.Members, client.MessageChannel) + list.Unlock() + } + ChatSubscriptionLock.RUnlock() +} + +// UnsubscribeAll will unsubscribe the client from all channels, +// AND clear the CurrentChannels / WatchingChannels fields. +// +// Locks: +// - read lock to top-level maps +// - write lock to SubscriptionInfos +// - write lock to ClientInfo +func UnsubscribeAll(client *ClientInfo) { + if StopAcceptingConnections { + return // no need to remove from a high-contention list when the server is closing + } + + GlobalSubscriptionLock.Lock() + RemoveFromSliceCl(&GlobalSubscriptionInfo, client) + GlobalSubscriptionLock.Unlock() + + ChatSubscriptionLock.RLock() + client.Mutex.Lock() + for _, v := range client.CurrentChannels { + list := ChatSubscriptionInfo[v] + if list != nil { + list.Lock() + RemoveFromSliceC(&list.Members, client.MessageChannel) + list.Unlock() + } + } + client.CurrentChannels = nil + client.Mutex.Unlock() + ChatSubscriptionLock.RUnlock() +} + +func unsubscribeAllClients() { + GlobalSubscriptionLock.Lock() + GlobalSubscriptionInfo = nil + GlobalSubscriptionLock.Unlock() + ChatSubscriptionLock.Lock() + ChatSubscriptionInfo = make(map[string]*SubscriberList) + ChatSubscriptionLock.Unlock() +} + +const ReapingDelay = 1 * time.Minute + +// Checks ChatSubscriptionInfo for entries with no subscribers every ReapingDelay. +// is_init_func +func pubsubJanitor() { + for { + time.Sleep(ReapingDelay) + pubsubJanitor_do() + } +} + +func pubsubJanitor_do() { + var cleanedUp = make([]string, 0, 6) + ChatSubscriptionLock.Lock() + for key, val := range ChatSubscriptionInfo { + if val == nil || len(val.Members) == 0 { + delete(ChatSubscriptionInfo, key) + cleanedUp = append(cleanedUp, key) + } + } + ChatSubscriptionLock.Unlock() + + if len(cleanedUp) != 0 { + err := Backend.SendCleanupTopicsNotice(cleanedUp) + if err != nil { + log.Println("error reporting cleaned subs:", err) + } + } +} + // Add a channel to the subscriptions while holding a read-lock to the map. // Locks: // - ALREADY HOLDING a read-lock to the 'which' top-level map via the rlocker object @@ -82,6 +201,14 @@ func _subscribeWhileRlocked(channelName string, value chan<- ClientMessage) { list.Members = []chan<- ClientMessage{value} // Create it populated, to avoid reaper ChatSubscriptionInfo[channelName] = list ChatSubscriptionLock.Unlock() + + go func(topic string) { + err := Backend.SendNewTopicNotice(topic) + if err != nil { + log.Println("error reporting new sub:", err) + } + }(channelName) + ChatSubscriptionLock.RLock() } else { list.Lock() @@ -89,80 +216,3 @@ func _subscribeWhileRlocked(channelName string, value chan<- ClientMessage) { list.Unlock() } } - -func SubscribeGlobal(client *ClientInfo) { - GlobalSubscriptionInfo.Lock() - AddToSliceC(&GlobalSubscriptionInfo.Members, client.MessageChannel) - GlobalSubscriptionInfo.Unlock() -} - -func SubscribeChat(client *ClientInfo, channelName string) { - ChatSubscriptionLock.RLock() - _subscribeWhileRlocked(channelName, client.MessageChannel) - ChatSubscriptionLock.RUnlock() -} - -func unsubscribeAllClients() { - GlobalSubscriptionInfo.Lock() - GlobalSubscriptionInfo.Members = nil - GlobalSubscriptionInfo.Unlock() - ChatSubscriptionLock.Lock() - ChatSubscriptionInfo = make(map[string]*SubscriberList) - ChatSubscriptionLock.Unlock() -} - -// Unsubscribe the client from all channels, AND clear the CurrentChannels / WatchingChannels fields. -// Locks: -// - read lock to top-level maps -// - write lock to SubscriptionInfos -// - write lock to ClientInfo -func UnsubscribeAll(client *ClientInfo) { - client.Mutex.Lock() - client.PendingSubscriptionsBacklog = nil - client.PendingSubscriptionsBacklog = nil - client.Mutex.Unlock() - - GlobalSubscriptionInfo.Lock() - RemoveFromSliceC(&GlobalSubscriptionInfo.Members, client.MessageChannel) - GlobalSubscriptionInfo.Unlock() - - ChatSubscriptionLock.RLock() - client.Mutex.Lock() - for _, v := range client.CurrentChannels { - list := ChatSubscriptionInfo[v] - if list != nil { - list.Lock() - RemoveFromSliceC(&list.Members, client.MessageChannel) - list.Unlock() - } - } - client.CurrentChannels = nil - client.Mutex.Unlock() - ChatSubscriptionLock.RUnlock() -} - -func UnsubscribeSingleChat(client *ClientInfo, channelName string) { - ChatSubscriptionLock.RLock() - list := ChatSubscriptionInfo[channelName] - list.Lock() - RemoveFromSliceC(&list.Members, client.MessageChannel) - list.Unlock() - ChatSubscriptionLock.RUnlock() -} - -const ReapingDelay = 120 * time.Minute - -// Checks ChatSubscriptionInfo for entries with no subscribers every ReapingDelay. -// Started from SetupServer(). -func deadChannelReaper() { - for { - time.Sleep(ReapingDelay) - ChatSubscriptionLock.Lock() - for key, val := range ChatSubscriptionInfo { - if len(val.Members) == 0 { - ChatSubscriptionInfo[key] = nil - } - } - ChatSubscriptionLock.Unlock() - } -} diff --git a/socketserver/internal/server/publisher_test.go b/socketserver/server/subscriptions_test.go similarity index 53% rename from socketserver/internal/server/publisher_test.go rename to socketserver/server/subscriptions_test.go index 2dc54ed6..fea4f533 100644 --- a/socketserver/internal/server/publisher_test.go +++ b/socketserver/server/subscriptions_test.go @@ -1,172 +1,21 @@ package server import ( - "encoding/json" "fmt" - "github.com/satori/go.uuid" - "golang.org/x/net/websocket" - "io/ioutil" "net/http" "net/http/httptest" "net/url" - "os" "strconv" "sync" "syscall" "testing" "time" + + "github.com/gorilla/websocket" + "github.com/satori/go.uuid" ) -func TCountOpenFDs() uint64 { - ary, _ := ioutil.ReadDir(fmt.Sprintf("/proc/%d/fd", os.Getpid())) - return uint64(len(ary)) -} - -const IgnoreReceivedArguments = 1 + 2i - -func TReceiveExpectedMessage(tb testing.TB, conn *websocket.Conn, messageId int, command Command, arguments interface{}) (ClientMessage, bool) { - var msg ClientMessage - var fail bool - err := FFZCodec.Receive(conn, &msg) - if err != nil { - tb.Error(err) - return msg, false - } - if msg.MessageID != messageId { - tb.Error("Message ID was wrong. Expected", messageId, ", got", msg.MessageID, ":", msg) - fail = true - } - if msg.Command != command { - tb.Error("Command was wrong. Expected", command, ", got", msg.Command, ":", msg) - fail = true - } - if arguments != IgnoreReceivedArguments { - if arguments == nil { - if msg.origArguments != "" { - tb.Error("Arguments are wrong. Expected", arguments, ", got", msg.Arguments, ":", msg) - } - } else { - argBytes, _ := json.Marshal(arguments) - if msg.origArguments != string(argBytes) { - tb.Error("Arguments are wrong. Expected", arguments, ", got", msg.Arguments, ":", msg) - } - } - } - return msg, !fail -} - -func TSendMessage(tb testing.TB, conn *websocket.Conn, messageId int, command Command, arguments interface{}) bool { - err := FFZCodec.Send(conn, ClientMessage{MessageID: messageId, Command: command, Arguments: arguments}) - if err != nil { - tb.Error(err) - } - return err == nil -} - -func TSealForSavePubMsg(tb testing.TB, cmd Command, channel string, arguments interface{}, deleteMode bool) (url.Values, error) { - form := url.Values{} - form.Set("cmd", string(cmd)) - argsBytes, err := json.Marshal(arguments) - if err != nil { - tb.Error(err) - return nil, err - } - form.Set("args", string(argsBytes)) - form.Set("channel", channel) - if deleteMode { - form.Set("delete", "1") - } - form.Set("time", time.Now().Format(time.UnixDate)) - - sealed, err := SealRequest(form) - if err != nil { - tb.Error(err) - return nil, err - } - return sealed, nil -} - -func TCheckResponse(tb testing.TB, resp *http.Response, expected string) bool { - var failed bool - respBytes, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - respStr := string(respBytes) - - if err != nil { - tb.Error(err) - failed = true - } - - if resp.StatusCode != 200 { - tb.Error("Publish failed: ", resp.StatusCode, respStr) - failed = true - } - - if respStr != expected { - tb.Errorf("Got wrong response from server. Expected: '%s' Got: '%s'", expected, respStr) - failed = true - } - return !failed -} - -type TURLs struct { - Websocket string - Origin string - PubMsg string - SavePubMsg string // update_and_pub -} - -func TGetUrls(testserver *httptest.Server) TURLs { - addr := testserver.Listener.Addr().String() - return TURLs{ - Websocket: fmt.Sprintf("ws://%s/", addr), - Origin: fmt.Sprintf("http://%s", addr), - PubMsg: fmt.Sprintf("http://%s/pub_msg", addr), - SavePubMsg: fmt.Sprintf("http://%s/update_and_pub", addr), - } -} - -func TSetup(testserver **httptest.Server, urls *TURLs) { - DumpCache() - - conf := &ConfigFile{ - ServerId: 20, - UseSSL: false, - SocketOrigin: "localhost:2002", - BannerHTML: ` - -CatBag - -
-
-
-
-
-
- A FrankerFaceZ Service - — CatBag by Wolsk -
-
-`, - OurPublicKey: []byte{176, 149, 72, 209, 35, 42, 110, 220, 22, 236, 212, 129, 213, 199, 1, 227, 185, 167, 150, 159, 117, 202, 164, 100, 9, 107, 45, 141, 122, 221, 155, 73}, - OurPrivateKey: []byte{247, 133, 147, 194, 70, 240, 211, 216, 223, 16, 241, 253, 120, 14, 198, 74, 237, 180, 89, 33, 146, 146, 140, 58, 88, 160, 2, 246, 112, 35, 239, 87}, - BackendPublicKey: []byte{19, 163, 37, 157, 50, 139, 193, 85, 229, 47, 166, 21, 153, 231, 31, 133, 41, 158, 8, 53, 73, 0, 113, 91, 13, 181, 131, 248, 176, 18, 1, 107}, - } - gconfig = conf - SetupBackend(conf) - - if testserver != nil { - serveMux := http.NewServeMux() - SetupServerAndHandle(conf, nil, serveMux) - - tserv := httptest.NewUnstartedServer(serveMux) - *testserver = tserv - tserv.Start() - if urls != nil { - *urls = TGetUrls(tserv) - } - } -} +const TestOrigin = "http://www.twitch.tv" func TestSubscriptionAndPublish(t *testing.T) { var doneWg sync.WaitGroup @@ -184,19 +33,30 @@ func TestSubscriptionAndPublish(t *testing.T) { const TestData3 = false var TestData4 = []interface{}{"str1", "str2", "str3"} - ServerInitiatedCommands[TestCommandChan] = PushCommandCacheInfo{CacheTypeLastOnly, MsgTargetTypeChat} - ServerInitiatedCommands[TestCommandMulti] = PushCommandCacheInfo{CacheTypeTimestamps, MsgTargetTypeMultichat} - ServerInitiatedCommands[TestCommandGlobal] = PushCommandCacheInfo{CacheTypeTimestamps, MsgTargetTypeGlobal} + t.Log("TestSubscriptionAndPublish") var server *httptest.Server var urls TURLs - TSetup(&server, &urls) + + var backendExpected = NewTBackendRequestChecker(t, + TExpectedBackendRequest{200, bPathAnnounceStartup, &url.Values{"startup": []string{"1"}}, "", nil}, + TExpectedBackendRequest{200, bPathAddTopic, &url.Values{"channels": []string{TestChannelName1}, "added": []string{"t"}}, "ok", nil}, + TExpectedBackendRequest{200, bPathAddTopic, &url.Values{"channels": []string{TestChannelName2}, "added": []string{"t"}}, "ok", nil}, + TExpectedBackendRequest{200, bPathAddTopic, &url.Values{"channels": []string{TestChannelName3}, "added": []string{"t"}}, "ok", nil}, + ) + server, _, urls = TSetup(SetupWantSocketServer|SetupWantBackendServer|SetupWantURLs, backendExpected) + defer server.CloseClientConnections() defer unsubscribeAllClients() + defer backendExpected.Close() var conn *websocket.Conn + var resp *http.Response var err error + var headers http.Header = make(http.Header) + headers.Set("Origin", TestOrigin) + // client 1: sub ch1, ch2 // client 2: sub ch1, ch3 // client 3: sub none @@ -204,15 +64,18 @@ func TestSubscriptionAndPublish(t *testing.T) { // msg 1: ch1 // msg 2: ch2, ch3 // msg 3: chEmpty - // msg 4: global + // msg 4: global uncached // Client 1 - conn, err = websocket.Dial(urls.Websocket, "", urls.Origin) + conn, resp, err = websocket.DefaultDialer.Dial(urls.Websocket, headers) if err != nil { t.Error(err) return } + // both origins need testing + headers.Set("Origin", "https://www.twitch.tv") + doneWg.Add(1) readyWg.Add(1) go func(conn *websocket.Conn) { @@ -236,13 +99,14 @@ func TestSubscriptionAndPublish(t *testing.T) { }(conn) // Client 2 - conn, err = websocket.Dial(urls.Websocket, "", urls.Origin) + conn, resp, err = websocket.DefaultDialer.Dial(urls.Websocket, headers) if err != nil { t.Error(err) return } doneWg.Add(1) + readyWg.Wait() // enforce ordering readyWg.Add(1) go func(conn *websocket.Conn) { TSendMessage(t, conn, 1, HelloCommand, []interface{}{"ffz_0.0-test", uuid.NewV4().String()}) @@ -265,13 +129,15 @@ func TestSubscriptionAndPublish(t *testing.T) { }(conn) // Client 3 - conn, err = websocket.Dial(urls.Websocket, "", urls.Origin) + conn, resp, err = websocket.DefaultDialer.Dial(urls.Websocket, headers) if err != nil { t.Error(err) return } doneWg.Add(1) + readyWg.Wait() // enforce ordering + time.Sleep(2 * time.Millisecond) readyWg.Add(1) go func(conn *websocket.Conn) { TSendMessage(t, conn, 1, HelloCommand, []interface{}{"ffz_0.0-test", uuid.NewV4().String()}) @@ -291,7 +157,6 @@ func TestSubscriptionAndPublish(t *testing.T) { readyWg.Wait() var form url.Values - var resp *http.Response // Publish message 1 - should go to clients 1, 2 @@ -300,7 +165,7 @@ func TestSubscriptionAndPublish(t *testing.T) { t.FailNow() } resp, err = http.PostForm(urls.SavePubMsg, form) - if !TCheckResponse(t, resp, strconv.Itoa(2)) { + if !TCheckResponse(t, resp, strconv.Itoa(2), "pub msg 1") { t.FailNow() } @@ -311,7 +176,7 @@ func TestSubscriptionAndPublish(t *testing.T) { t.FailNow() } resp, err = http.PostForm(urls.SavePubMsg, form) - if !TCheckResponse(t, resp, strconv.Itoa(2)) { + if !TCheckResponse(t, resp, strconv.Itoa(2), "pub msg 2") { t.FailNow() } @@ -322,23 +187,23 @@ func TestSubscriptionAndPublish(t *testing.T) { t.FailNow() } resp, err = http.PostForm(urls.SavePubMsg, form) - if !TCheckResponse(t, resp, strconv.Itoa(0)) { + if !TCheckResponse(t, resp, strconv.Itoa(0), "pub msg 3") { t.FailNow() } // Publish message 4 - should go to clients 1, 2, 3 - form, err = TSealForSavePubMsg(t, TestCommandGlobal, "", TestData4, false) + form, err = TSealForUncachedPubMsg(t, TestCommandGlobal, "", TestData4, "global", false) if err != nil { t.FailNow() } - resp, err = http.PostForm(urls.SavePubMsg, form) - if !TCheckResponse(t, resp, strconv.Itoa(3)) { + resp, err = http.PostForm(urls.UncachedPubMsg, form) + if !TCheckResponse(t, resp, strconv.Itoa(3), "pub msg 4") { t.FailNow() } // Start client 4 - conn, err = websocket.Dial(urls.Websocket, "", urls.Origin) + conn, resp, err = websocket.DefaultDialer.Dial(urls.Websocket, headers) if err != nil { t.Error(err) return @@ -367,6 +232,113 @@ func TestSubscriptionAndPublish(t *testing.T) { doneWg.Wait() server.Close() + + clientCount := readCurrentHLL() + if clientCount < 3 || clientCount > 5 { + t.Error("clientCount outside acceptable range: expected 4, got ", clientCount) + } +} + +func TestRestrictedCommands(t *testing.T) { + var doneWg sync.WaitGroup + var readyWg sync.WaitGroup + + const TestCommandNeedsAuth = "needsauth" + const TestRequestData = "123456" + const TestRequestDataJSON = "\"" + TestRequestData + "\"" + const TestReplyData = "success" + const TestUsername = "sirstendec" + + var server *httptest.Server + var urls TURLs + + t.Log("TestRestrictedCommands") + + var backendExpected = NewTBackendRequestChecker(t, + TExpectedBackendRequest{200, bPathAnnounceStartup, &url.Values{"startup": []string{"1"}}, "", nil}, + TExpectedBackendRequest{401, fmt.Sprintf("%s%s", bPathOtherCommand, TestCommandNeedsAuth), &url.Values{"authenticated": []string{"0"}, "username": []string{""}, "clientData": []string{TestRequestDataJSON}}, "", nil}, + TExpectedBackendRequest{401, fmt.Sprintf("%s%s", bPathOtherCommand, TestCommandNeedsAuth), &url.Values{"authenticated": []string{"0"}, "username": []string{TestUsername}, "clientData": []string{TestRequestDataJSON}}, "", nil}, + TExpectedBackendRequest{200, fmt.Sprintf("%s%s", bPathOtherCommand, TestCommandNeedsAuth), &url.Values{"authenticated": []string{"1"}, "username": []string{TestUsername}, "clientData": []string{TestRequestDataJSON}}, fmt.Sprintf("\"%s\"", TestReplyData), nil}, + ) + server, _, urls = TSetup(SetupWantSocketServer|SetupWantBackendServer|SetupWantURLs, backendExpected) + + defer server.CloseClientConnections() + defer unsubscribeAllClients() + defer backendExpected.Close() + + var conn *websocket.Conn + var err error + var challengeChan = make(chan string) + + var headers http.Header = make(http.Header) + headers.Set("Origin", TestOrigin) + + // Client 1 + conn, _, err = websocket.DefaultDialer.Dial(urls.Websocket, headers) + if err != nil { + t.Error(err) + return + } + + doneWg.Add(1) + readyWg.Add(1) + go func(conn *websocket.Conn) { + defer doneWg.Done() + defer conn.Close() + TSendMessage(t, conn, 1, HelloCommand, []interface{}{"ffz_0.0-test", uuid.NewV4().String()}) + TReceiveExpectedMessage(t, conn, 1, SuccessCommand, IgnoreReceivedArguments) + TSendMessage(t, conn, 2, ReadyCommand, 0) + TReceiveExpectedMessage(t, conn, 2, SuccessCommand, nil) + + // Should get immediate refusal because no username set + TSendMessage(t, conn, 3, TestCommandNeedsAuth, TestRequestData) + TReceiveExpectedMessage(t, conn, 3, ErrorCommand, AuthorizationNeededError) + + // Set a username + TSendMessage(t, conn, 4, SetUserCommand, TestUsername) + TReceiveExpectedMessage(t, conn, 4, SuccessCommand, nil) + + // Should get authorization prompt + TSendMessage(t, conn, 5, TestCommandNeedsAuth, TestRequestData) + readyWg.Done() + msg, success := TReceiveExpectedMessage(t, conn, -1, AuthorizeCommand, IgnoreReceivedArguments) + if !success { + t.Error("recieve authorize command failed, cannot continue") + return + } + challenge, err := msg.ArgumentsAsString() + if err != nil { + t.Error(err) + return + } + challengeChan <- challenge // mocked: sending challenge to IRC server, IRC server sends challenge to socket server + + TReceiveExpectedMessage(t, conn, 5, SuccessCommand, TestReplyData) + }(conn) + + readyWg.Wait() + + challenge := <-challengeChan + PendingAuthLock.Lock() + found := false + for _, v := range PendingAuths { + if conn.LocalAddr().String() == v.Client.RemoteAddr.String() { + found = true + if v.Challenge != challenge { + t.Error("Challenge in array was not what client got") + } + break + } + } + PendingAuthLock.Unlock() + if !found { + t.Fatal("Did not find authorization challenge in the pending auths array") + } + + submitAuth(TestUsername, challenge) + + doneWg.Wait() + server.Close() } func BenchmarkUserSubscriptionSinglePublish(b *testing.B) { @@ -396,12 +368,15 @@ func BenchmarkUserSubscriptionSinglePublish(b *testing.B) { var server *httptest.Server var urls TURLs - TSetup(&server, &urls) + server, _, urls = TSetup(SetupWantSocketServer|SetupWantURLs, nil) defer unsubscribeAllClients() + var headers http.Header = make(http.Header) + headers.Set("Origin", TestOrigin) + b.ResetTimer() for i := 0; i < b.N; i++ { - conn, err := websocket.Dial(urls.Websocket, "", urls.Origin) + conn, _, err := websocket.DefaultDialer.Dial(urls.Websocket, headers) if err != nil { b.Error(err) break @@ -427,7 +402,7 @@ func BenchmarkUserSubscriptionSinglePublish(b *testing.B) { readyWg.Wait() fmt.Println("publishing...") - if PublishToChat(TestChannelName, message) != b.N { + if PublishToChannel(TestChannelName, message) != b.N { b.Error("not enough sent") server.CloseClientConnections() panic("halting test instead of waiting") diff --git a/socketserver/server/testinfra_test.go b/socketserver/server/testinfra_test.go new file mode 100644 index 00000000..5d8a4570 --- /dev/null +++ b/socketserver/server/testinfra_test.go @@ -0,0 +1,357 @@ +package server + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strconv" + "sync" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +const ( + SetupWantSocketServer = 1 << iota + SetupWantBackendServer + SetupWantURLs +) +const SetupNoServers = 0 + +var signalCatch sync.Once + +func TSetup(flags int, backendChecker *TBackendRequestChecker) (socketserver *httptest.Server, backend *httptest.Server, urls TURLs) { + signalCatch.Do(func() { + go dumpStackOnCtrlZ() + }) + + DumpBacklogData() + + ioutil.WriteFile("index.html", []byte(` + +CatBag + +
+
+
+
+
+
+ A FrankerFaceZ Service + — CatBag by Wolsk +
+
`), 0644) + + conf := &ConfigFile{ + ServerID: 20, + UseSSL: false, + OurPublicKey: []byte{176, 149, 72, 209, 35, 42, 110, 220, 22, 236, 212, 129, 213, 199, 1, 227, 185, 167, 150, 159, 117, 202, 164, 100, 9, 107, 45, 141, 122, 221, 155, 73}, + OurPrivateKey: []byte{247, 133, 147, 194, 70, 240, 211, 216, 223, 16, 241, 253, 120, 14, 198, 74, 237, 180, 89, 33, 146, 146, 140, 58, 88, 160, 2, 246, 112, 35, 239, 87}, + BackendPublicKey: []byte{19, 163, 37, 157, 50, 139, 193, 85, 229, 47, 166, 21, 153, 231, 31, 133, 41, 158, 8, 53, 73, 0, 113, 91, 13, 181, 131, 248, 176, 18, 1, 107}, + } + + if flags&SetupWantBackendServer != 0 { + backend = httptest.NewServer(backendChecker) + conf.BackendURL = fmt.Sprintf("http://%s", backend.Listener.Addr().String()) + } + + Configuration = conf + setupBackend(conf) + + if flags&SetupWantSocketServer != 0 { + serveMux := http.NewServeMux() + SetupServerAndHandle(conf, serveMux) + dumpUniqueUsers() + + socketserver = httptest.NewServer(serveMux) + } + + if flags&SetupWantURLs != 0 { + urls = TGetUrls(socketserver, backend) + } + return +} + +type TBC interface { + Error(args ...interface{}) + Errorf(format string, args ...interface{}) +} + +const MethodIsPost = "POST" + +type TExpectedBackendRequest struct { + ResponseCode int + Path string + // Method string // always POST + PostForm *url.Values + Response string + ResponseHeaders http.Header +} + +func (er *TExpectedBackendRequest) String() string { + if MethodIsPost == "" { + return er.Path + } + return fmt.Sprint("%s %s: %s", MethodIsPost, er.Path, er.PostForm.Encode()) +} + +type TBackendRequestChecker struct { + ExpectedRequests []TExpectedBackendRequest + + currentRequest int + tb TBC + mutex sync.Mutex +} + +func NewTBackendRequestChecker(tb TBC, urls ...TExpectedBackendRequest) *TBackendRequestChecker { + return &TBackendRequestChecker{ExpectedRequests: urls, tb: tb, currentRequest: 0} +} + +func (backend *TBackendRequestChecker) ServeHTTP(w http.ResponseWriter, r *http.Request) { + backend.mutex.Lock() + defer backend.mutex.Unlock() + + if r.Method != MethodIsPost { + backend.tb.Errorf("Bad backend request: was not a POST. %v", r) + return + } + + r.ParseForm() + + unsealedForm, err := Backend.UnsealRequest(r.PostForm) + if err != nil { + backend.tb.Errorf("Failed to unseal backend request: %v", err) + } + + if backend.currentRequest >= len(backend.ExpectedRequests) { + backend.tb.Errorf("Unexpected backend request: %s %s: %s", r.Method, r.URL, unsealedForm) + return + } + + cur := backend.ExpectedRequests[backend.currentRequest] + backend.currentRequest++ + + headers := w.Header() + for k, v := range cur.ResponseHeaders { + if len(v) == 1 { + headers.Set(k, v[0]) + } else if len(v) == 0 { + headers.Del(k) + } else { + for _, hv := range v { + headers.Add(k, hv) + } + } + } + + defer func() { + w.WriteHeader(cur.ResponseCode) + if cur.Response != "" { + w.Write([]byte(cur.Response)) + } + }() + + if cur.Path != "" { + if r.URL.Path != cur.Path { + backend.tb.Errorf("Bad backend request. Expected %v, got %s %s", cur, r.Method, r.URL) + return + } + } + + if cur.PostForm != nil { + anyErr := TcompareForms(backend.tb, "Different form contents", *cur.PostForm, unsealedForm) + if anyErr { + backend.tb.Errorf("...in %s %s: %s", r.Method, r.URL, unsealedForm.Encode()) + } + } +} + +func (backend *TBackendRequestChecker) Close() error { + if backend.currentRequest < len(backend.ExpectedRequests) { + backend.tb.Errorf("Not all requests sent, got %d out of %d", backend.currentRequest, len(backend.ExpectedRequests)) + } + return nil +} + +func TcompareForms(tb TBC, ctx string, expectedForm, gotForm url.Values) (anyErrors bool) { + for k, expVal := range expectedForm { + gotVal, ok := gotForm[k] + if !ok { + tb.Errorf("%s: Form[%s]: Expected %v, (got nothing)", ctx, k, expVal) + anyErrors = true + continue + } + if len(expVal) != len(gotVal) { + tb.Errorf("%s: Form[%s]: Expected %d%v, Got %d%v", ctx, k, len(expVal), expVal, len(gotVal), gotVal) + anyErrors = true + continue + } + for i, el := range expVal { + if gotVal[i] != el { + tb.Errorf("%s: Form[%s][%d]: Expected %s, Got %s", ctx, k, i, el, gotVal[i]) + anyErrors = true + } + } + } + for k, gotVal := range gotForm { + _, ok := expectedForm[k] + if !ok { + tb.Errorf("%s: Form[%s]: (expected nothing), Got %v", ctx, k, gotVal) + anyErrors = true + } + } + return anyErrors +} + +func TCountOpenFDs() uint64 { + ary, _ := ioutil.ReadDir(fmt.Sprintf("/proc/%d/fd", os.Getpid())) + return uint64(len(ary)) +} + +const IgnoreReceivedArguments = 1 + 2i + +func TReceiveExpectedMessage(tb testing.TB, conn *websocket.Conn, messageID int, command Command, arguments interface{}) (ClientMessage, bool) { + var msg ClientMessage + var fail bool + messageType, packet, err := conn.ReadMessage() + if err != nil { + tb.Error(err) + return msg, false + } + if messageType != websocket.TextMessage { + tb.Error("got non-text message", packet) + return msg, false + } + + err = UnmarshalClientMessage(packet, messageType, &msg) + if err != nil { + tb.Error(err) + return msg, false + } + if msg.MessageID != messageID { + tb.Error("Message ID was wrong. Expected", messageID, ", got", msg.MessageID, ":", msg) + fail = true + } + if msg.Command != command { + tb.Error("Command was wrong. Expected", command, ", got", msg.Command, ":", msg) + fail = true + } + if arguments != IgnoreReceivedArguments { + if arguments == nil { + if msg.origArguments != "" { + tb.Error("Arguments are wrong. Expected", arguments, ", got", msg.Arguments, ":", msg) + } + } else { + argBytes, _ := json.Marshal(arguments) + if msg.origArguments != string(argBytes) { + tb.Error("Arguments are wrong. Expected", arguments, ", got", msg.Arguments, ":", msg) + } + } + } + return msg, !fail +} + +func TSendMessage(tb testing.TB, conn *websocket.Conn, messageID int, command Command, arguments interface{}) bool { + SendMessage(conn, ClientMessage{MessageID: messageID, Command: command, Arguments: arguments}) + return true +} + +func TSealForSavePubMsg(tb testing.TB, cmd Command, channel string, arguments interface{}, deleteMode bool) (url.Values, error) { + form := url.Values{} + form.Set("cmd", string(cmd)) + argsBytes, err := json.Marshal(arguments) + if err != nil { + tb.Error(err) + return nil, err + } + form.Set("args", string(argsBytes)) + form.Set("channel", channel) + if deleteMode { + form.Set("delete", "1") + } + form.Set("time", strconv.FormatInt(time.Now().Unix(), 10)) + + sealed, err := Backend.SealRequest(form) + if err != nil { + tb.Error(err) + return nil, err + } + return sealed, nil +} + +func TSealForUncachedPubMsg(tb testing.TB, cmd Command, channel string, arguments interface{}, scope string, deleteMode bool) (url.Values, error) { + form := url.Values{} + form.Set("cmd", string(cmd)) + argsBytes, err := json.Marshal(arguments) + if err != nil { + tb.Error(err) + return nil, err + } + form.Set("args", string(argsBytes)) + form.Set("channel", channel) + if deleteMode { + form.Set("delete", "1") + } + form.Set("time", time.Now().Format(time.UnixDate)) + form.Set("scope", scope) + + sealed, err := Backend.SealRequest(form) + if err != nil { + tb.Error(err) + return nil, err + } + return sealed, nil +} + +func TCheckResponse(tb testing.TB, resp *http.Response, expected string, desc string) bool { + var failed bool + respBytes, err := ioutil.ReadAll(resp.Body) + resp.Body.Close() + respStr := string(respBytes) + + if err != nil { + tb.Error(err) + failed = true + } + + if resp.StatusCode != 200 { + tb.Error("Publish failed: ", resp.StatusCode, respStr) + failed = true + } + + if respStr != expected { + tb.Errorf("Got wrong response from server. %s Expected: '%s' Got: '%s'", desc, expected, respStr) + failed = true + } + return !failed +} + +type TURLs struct { + Websocket string + Origin string + UncachedPubMsg string // uncached_pub + SavePubMsg string // cached_pub +} + +func TGetUrls(socketserver *httptest.Server, backend *httptest.Server) TURLs { + addr := socketserver.Listener.Addr().String() + return TURLs{ + Websocket: fmt.Sprintf("ws://%s/", addr), + Origin: fmt.Sprintf("http://%s", addr), + UncachedPubMsg: fmt.Sprintf("http://%s/uncached_pub", addr), + SavePubMsg: fmt.Sprintf("http://%s/cached_pub", addr), + } +} + +func TCheckHLLValue(tb testing.TB, expected uint64, actual uint64) { + high := uint64(float64(expected) * 1.05) + low := uint64(float64(expected) * 0.95) + if actual < low || actual > high { + tb.Errorf("Count outside expected range. Expected %d, Got %d", expected, actual) + } +} diff --git a/socketserver/server/tickspersecond.go b/socketserver/server/tickspersecond.go new file mode 100644 index 00000000..160eec7c --- /dev/null +++ b/socketserver/server/tickspersecond.go @@ -0,0 +1,12 @@ +package server + +// #include +// long get_ticks_per_second() { +// return sysconf(_SC_CLK_TCK); +// } +//import "C" + +// note: this seems to add 0.1s to compile time on my machine +//var ticksPerSecond = int(C.get_ticks_per_second()) + +var ticksPerSecond = 100 diff --git a/socketserver/server/types.go b/socketserver/server/types.go new file mode 100644 index 00000000..41c0d010 --- /dev/null +++ b/socketserver/server/types.go @@ -0,0 +1,153 @@ +package server + +import ( + "fmt" + "net" + "sync" + + "github.com/satori/go.uuid" +) + +const NegativeOne = ^uint64(0) + +var AnonymousClientID = uuid.FromStringOrNil("683b45e4-f853-4c45-bf96-7d799cc93e34") + +type ConfigFile struct { + // Numeric server id known to the backend + ServerID int + // Address to bind the HTTP server to on startup. + ListenAddr string + // Address to bind the TLS server to on startup. + SSLListenAddr string + // URL to the backend server + BackendURL string + + // Minimum memory to accept a new connection + MinMemoryKBytes uint64 + // Maximum # of clients that can be connected. 0 to disable. + MaxClientCount uint64 + + // SSL/TLS + // Enable the use of SSL. + UseSSL bool + // Path to certificate file. + SSLCertificateFile string + // Path to key file. + SSLKeyFile string + + UseESLogStashing bool + ESServer string + ESIndexPrefix string + ESHostName string + + // Nacl keys + OurPrivateKey []byte + OurPublicKey []byte + BackendPublicKey []byte + + // Request username validation from all new clients. + SendAuthToNewClients bool +} + +type ClientMessage struct { + // Message ID. Increments by 1 for each message sent from the client. + // When replying to a command, the message ID must be echoed. + // When sending a server-initiated message, this is -1. + MessageID int `json:"m"` + // The command that the client wants from the server. + // When sent from the server, the literal string 'True' indicates success. + // Before sending, a blank Command will be converted into SuccessCommand. + Command Command `json:"c"` + // Result of json.Unmarshal on the third field send from the client + Arguments interface{} `json:"a"` + + origArguments string +} + +type AuthInfo struct { + // The client's claimed username on Twitch. + TwitchUsername string + + // Whether or not the server has validated the client's claimed username. + UsernameValidated bool +} + +type ClientVersion struct { + Major int + Minor int + Revision int +} + +type ClientInfo struct { + // The client ID. + // This must be written once by the owning goroutine before the struct is passed off to any other goroutines. + ClientID uuid.UUID + + // The client's literal version string. + // This must be written once by the owning goroutine before the struct is passed off to any other goroutines. + VersionString string + + Version ClientVersion + + // This mutex protects writable data in this struct. + // If it seems to be a performance problem, we can split this. + Mutex sync.Mutex + + // Info about the client's username and whether or not we have verified it. + AuthInfo + + RemoteAddr net.Addr + + // Username validation nonce. + ValidationNonce string + + // The list of chats this client is currently in. + // Protected by Mutex. + CurrentChannels []string + + // True if the client has already sent the 'ready' command + ReadyComplete bool + + // Server-initiated messages should be sent here + // This field will be nil before it is closed. + MessageChannel chan<- ClientMessage + + MsgChannelIsDone <-chan struct{} + + // Take out an Add() on this during a command if you need to use the MessageChannel later. + MsgChannelKeepalive sync.WaitGroup + + // The number of pings sent without a response. + // Protected by Mutex + pingCount int +} + +func VersionFromString(v string) ClientVersion { + var cv ClientVersion + fmt.Sscanf(v, "ffz_%d.%d.%d", &cv.Major, &cv.Minor, &cv.Revision) + return cv +} + +func (cv *ClientVersion) After(cv2 *ClientVersion) bool { + if cv.Major > cv2.Major { + return true + } else if cv.Major < cv2.Major { + return false + } + if cv.Minor > cv2.Minor { + return true + } else if cv.Minor < cv2.Minor { + return false + } + if cv.Revision > cv2.Revision { + return true + } else if cv.Revision < cv2.Revision { + return false + } + + return false // equal +} + +func (cv *ClientVersion) Equal(cv2 *ClientVersion) bool { + return cv.Major == cv2.Major && cv.Minor == cv2.Minor && cv.Revision == cv2.Revision +} diff --git a/socketserver/server/usercount.go b/socketserver/server/usercount.go new file mode 100644 index 00000000..1b9cbcbb --- /dev/null +++ b/socketserver/server/usercount.go @@ -0,0 +1,244 @@ +package server + +import ( + "bytes" + "encoding/base64" + "encoding/binary" + "encoding/gob" + "fmt" + "io/ioutil" + "log" + "net/http" + "os" + "time" + + "io" + + "github.com/clarkduvall/hyperloglog" + "github.com/satori/go.uuid" +) + +// UuidHash implements a hash for uuid.UUID by XORing the random bits. +type UuidHash uuid.UUID + +func (u UuidHash) Sum64() uint64 { + var valLow, valHigh uint64 + valLow = binary.LittleEndian.Uint64(u[0:8]) + valHigh = binary.LittleEndian.Uint64(u[8:16]) + return valLow ^ valHigh +} + +type PeriodUniqueUsers struct { + Start time.Time + End time.Time + Counter *hyperloglog.HyperLogLogPlus +} + +type usageToken struct{} + +const uniqCountDir = "./uniques" +const UsersDailyFmt = "daily-%d-%d-%d.gob" // d-m-y +const CounterPrecision uint8 = 12 + +var uniqueCounter PeriodUniqueUsers +var uniqueUserChannel chan uuid.UUID +var uniqueCtrWritingToken chan usageToken + +var CounterLocation *time.Location = time.FixedZone("UTC-5", int((time.Hour*-5)/time.Second)) + +func TruncateToMidnight(at time.Time) time.Time { + year, month, day := at.Date() + return time.Date(year, month, day, 0, 0, 0, 0, CounterLocation) +} + +// GetCounterPeriod calculates the start and end timestamps for the HLL measurement period that includes the 'at' timestamp. +func GetCounterPeriod(at time.Time) (start time.Time, end time.Time) { + year, month, day := at.Date() + start = time.Date(year, month, day, 0, 0, 0, 0, CounterLocation) + end = time.Date(year, month, day+1, 0, 0, 0, 0, CounterLocation) + return start, end +} + +// GetHLLFilename returns the filename for the saved HLL whose measurement period covers the given time. +func GetHLLFilename(at time.Time) string { + var filename string + year, month, day := at.Date() + filename = fmt.Sprintf(UsersDailyFmt, day, month, year) + return fmt.Sprintf("%s/%s", uniqCountDir, filename) +} + +// loadHLL loads a HLL from disk and stores the result in dest.Counter. +// If dest.Counter is nil, it will be initialized. (This is a useful side-effect.) +// If dest is one of the uniqueCounters, the usageToken must be held. +func loadHLL(at time.Time, dest *PeriodUniqueUsers) error { + fileBytes, err := ioutil.ReadFile(GetHLLFilename(at)) + if err != nil { + return err + } + + if dest.Counter == nil { + dest.Counter, _ = hyperloglog.NewPlus(CounterPrecision) + } + + dec := gob.NewDecoder(bytes.NewReader(fileBytes)) + err = dec.Decode(dest.Counter) + if err != nil { + return err + } + return nil +} + +// writeHLL writes the indicated HLL to disk. +// The function takes the usageToken. +func writeHLL() error { + token := <-uniqueCtrWritingToken + result := writeHLL_do(&uniqueCounter) + uniqueCtrWritingToken <- token + return result +} + +// writeHLL_do writes out the HLL indicated by `which` to disk. +// The usageToken must be held when calling this function. +func writeHLL_do(hll *PeriodUniqueUsers) (err error) { + filename := GetHLLFilename(hll.Start) + file, err := os.Create(filename) + if err != nil { + return err + } + + defer func(file io.Closer) { + fileErr := file.Close() + if err == nil { + err = fileErr + } + }(file) + + enc := gob.NewEncoder(file) + return enc.Encode(hll.Counter) +} + +// readCurrentHLL reads the current value of the active HLL counter. +// The function takes the usageToken. +func readCurrentHLL() uint64 { + token := <-uniqueCtrWritingToken + result := uniqueCounter.Counter.Count() + uniqueCtrWritingToken <- token + return result +} + +var hllFileServer = http.StripPrefix("/hll", http.FileServer(http.Dir(uniqCountDir))) + +func HTTPShowHLL(w http.ResponseWriter, r *http.Request) { + hllFileServer.ServeHTTP(w, r) +} + +func HTTPWriteHLL(w http.ResponseWriter, r *http.Request) { + writeHLL() + w.WriteHeader(200) + w.Write([]byte("ok")) +} + +// loadUniqueUsers loads the previous HLLs into memory. +// is_init_func +func loadUniqueUsers() { + gob.RegisterName("hyperloglog", hyperloglog.HyperLogLogPlus{}) + err := os.MkdirAll(uniqCountDir, 0755) + if err != nil { + log.Panicln("could not make unique users data dir:", err) + } + + now := time.Now().In(CounterLocation) + uniqueCounter.Start, uniqueCounter.End = GetCounterPeriod(now) + err = loadHLL(now, &uniqueCounter) + isIgnorableError := err != nil && (false || + (os.IsNotExist(err)) || + (err == io.EOF)) + + if isIgnorableError { + // file didn't finish writing + // errors in NewPlus are bad precisions + uniqueCounter.Counter, _ = hyperloglog.NewPlus(CounterPrecision) + log.Println("failed to load unique users data:", err) + } else if err != nil { + log.Panicln("failed to load unique users data:", err) + } + + uniqueUserChannel = make(chan uuid.UUID) + uniqueCtrWritingToken = make(chan usageToken) + go processNewUsers() + go rolloverCounters() + uniqueCtrWritingToken <- usageToken{} +} + +// dumpUniqueUsers dumps all the data in uniqueCounters. +func dumpUniqueUsers() { + token := <-uniqueCtrWritingToken + + uniqueCounter.Counter.Clear() + + uniqueCtrWritingToken <- token +} + +// processNewUsers reads uniqueUserChannel, and also dispatches the writing token. +// This function is the primary writer of uniqueCounters, so it makes sense for it to hold the token. +// is_init_func +func processNewUsers() { + token := <-uniqueCtrWritingToken + + for { + select { + case u := <-uniqueUserChannel: + hashed := UuidHash(u) + uniqueCounter.Counter.Add(hashed) + case uniqueCtrWritingToken <- token: + // relinquish token. important that there is only one of this going on + // otherwise we thrash + token = <-uniqueCtrWritingToken + } + } +} + +func getNextMidnight() time.Time { + now := time.Now().In(CounterLocation) + year, month, day := now.Date() + return time.Date(year, month, day+1, 0, 0, 1, 0, CounterLocation) +} + +// is_init_func +func rolloverCounters() { + for { + duration := getNextMidnight().Sub(time.Now()) + // fmt.Println(duration) + time.Sleep(duration) + rolloverCounters_do() + } +} + +func rolloverCounters_do() { + var token usageToken + var now time.Time + + token = <-uniqueCtrWritingToken + now = time.Now().In(CounterLocation) + // Cycle for period + err := writeHLL_do(&uniqueCounter) + if err != nil { + log.Println("could not cycle unique user counter:", err) + + // Attempt to rescue the data into the log + var buf bytes.Buffer + bytes, err := uniqueCounter.Counter.GobEncode() + if err == nil { + enc := base64.NewEncoder(base64.StdEncoding, &buf) + enc.Write(bytes) + enc.Close() + log.Print("data for ", GetHLLFilename(uniqueCounter.Start), ":", buf.String()) + } + } + + uniqueCounter.Start, uniqueCounter.End = GetCounterPeriod(now) + // errors are bad precisions, so we can ignore + uniqueCounter.Counter, _ = hyperloglog.NewPlus(CounterPrecision) + + uniqueCtrWritingToken <- token +} diff --git a/socketserver/server/usercount_test.go b/socketserver/server/usercount_test.go new file mode 100644 index 00000000..2cf608a2 --- /dev/null +++ b/socketserver/server/usercount_test.go @@ -0,0 +1,68 @@ +package server + +import ( + "net/http/httptest" + "net/url" + "os" + "testing" + "time" + + "github.com/satori/go.uuid" +) + +func TestUniqueConnections(t *testing.T) { + const TestExpectedCount = 1000 + + testStart := time.Now().In(CounterLocation) + + var server *httptest.Server + var backendExpected = NewTBackendRequestChecker(t, + TExpectedBackendRequest{200, bPathAnnounceStartup, &url.Values{"startup": []string{"1"}}, "", nil}, + ) + server, _, _ = TSetup(SetupWantSocketServer|SetupWantBackendServer, backendExpected) + + defer server.CloseClientConnections() + defer unsubscribeAllClients() + defer backendExpected.Close() + + dumpUniqueUsers() + + for i := 0; i < TestExpectedCount; i++ { + uuid := uuid.NewV4() + uniqueUserChannel <- uuid + uniqueUserChannel <- uuid + } + + TCheckHLLValue(t, TestExpectedCount, readCurrentHLL()) + + token := <-uniqueCtrWritingToken + uniqueCounter.End = time.Now().In(CounterLocation).Add(-1 * time.Second) + uniqueCtrWritingToken <- token + + rolloverCounters_do() + + for i := 0; i < TestExpectedCount; i++ { + uuid := uuid.NewV4() + uniqueUserChannel <- uuid + uniqueUserChannel <- uuid + } + + TCheckHLLValue(t, TestExpectedCount, readCurrentHLL()) + + // Check: Merging the two days results in 2000 + // note: rolloverCounters_do() wrote out a file, and loadHLL() is reading it back + // TODO need to rewrite some of the test to make this work + var loadDest PeriodUniqueUsers + loadHLL(testStart, &loadDest) + + token = <-uniqueCtrWritingToken + loadDest.Counter.Merge(uniqueCounter.Counter) + uniqueCtrWritingToken <- token + + TCheckHLLValue(t, TestExpectedCount*2, loadDest.Counter.Count()) +} + +func TestUniqueUsersCleanup(t *testing.T) { + // Not a test. Removes old files. + os.RemoveAll(uniqCountDir) +} diff --git a/socketserver/internal/server/utils.go b/socketserver/server/utils.go similarity index 68% rename from socketserver/internal/server/utils.go rename to socketserver/server/utils.go index 8dbff0f4..9552e7bb 100644 --- a/socketserver/internal/server/utils.go +++ b/socketserver/server/utils.go @@ -5,11 +5,11 @@ import ( "crypto/rand" "encoding/base64" "errors" - "golang.org/x/crypto/nacl/box" - "log" "net/url" "strconv" "strings" + + "golang.org/x/crypto/nacl/box" ) func FillCryptoRandom(buf []byte) error { @@ -24,11 +24,11 @@ func FillCryptoRandom(buf []byte) error { return nil } -func New4KByteBuffer() interface{} { - return make([]byte, 0, 4096) +func copyString(s string) string { + return string([]byte(s)) } -func SealRequest(form url.Values) (url.Values, error) { +func (backend *backendInfo) SealRequest(form url.Values) (url.Values, error) { var nonce [24]byte var err error @@ -37,7 +37,7 @@ func SealRequest(form url.Values) (url.Values, error) { return nil, err } - cipherMsg := box.SealAfterPrecomputation(nil, []byte(form.Encode()), &nonce, &backendSharedKey) + cipherMsg := box.SealAfterPrecomputation(nil, []byte(form.Encode()), &nonce, &backend.sharedKey) bufMessage := new(bytes.Buffer) enc := base64.NewEncoder(base64.URLEncoding, bufMessage) @@ -54,7 +54,7 @@ func SealRequest(form url.Values) (url.Values, error) { retval := url.Values{ "nonce": []string{nonceString}, "msg": []string{cipherString}, - "id": []string{strconv.Itoa(serverId)}, + "id": []string{strconv.Itoa(Backend.serverID)}, } return retval, nil @@ -63,16 +63,18 @@ func SealRequest(form url.Values) (url.Values, error) { var ErrorShortNonce = errors.New("Nonce too short.") var ErrorInvalidSignature = errors.New("Invalid signature or contents") -func UnsealRequest(form url.Values) (url.Values, error) { +func (backend *backendInfo) UnsealRequest(form url.Values) (url.Values, error) { var nonce [24]byte nonceString := form.Get("nonce") dec := base64.NewDecoder(base64.URLEncoding, strings.NewReader(nonceString)) count, err := dec.Read(nonce[:]) if err != nil { + Statistics.BackendVerifyFails++ return nil, err } if count != 24 { + Statistics.BackendVerifyFails++ return nil, ErrorShortNonce } @@ -81,15 +83,15 @@ func UnsealRequest(form url.Values) (url.Values, error) { cipherBuffer := new(bytes.Buffer) cipherBuffer.ReadFrom(dec) - message, ok := box.OpenAfterPrecomputation(nil, cipherBuffer.Bytes(), &nonce, &backendSharedKey) + message, ok := box.OpenAfterPrecomputation(nil, cipherBuffer.Bytes(), &nonce, &backend.sharedKey) if !ok { + Statistics.BackendVerifyFails++ return nil, ErrorInvalidSignature } retValues, err := url.ParseQuery(string(message)) if err != nil { - // Assume that the signature was accidentally correct but the contents were garbage - log.Print(err) + Statistics.BackendVerifyFails++ return nil, ErrorInvalidSignature } @@ -159,3 +161,49 @@ func RemoveFromSliceC(ary *[]chan<- ClientMessage, val chan<- ClientMessage) boo *ary = slice return true } + +func AddToSliceCl(ary *[]*ClientInfo, val *ClientInfo) bool { + slice := *ary + for _, v := range slice { + if v == val { + return false + } + } + + slice = append(slice, val) + *ary = slice + return true +} + +func RemoveFromSliceCl(ary *[]*ClientInfo, val *ClientInfo) bool { + slice := *ary + var idx int = -1 + for i, v := range slice { + if v == val { + idx = i + break + } + } + if idx == -1 { + return false + } + + slice[idx] = slice[len(slice)-1] + slice = slice[:len(slice)-1] + *ary = slice + return true +} + +func AddToSliceB(ary *[]bunchSubscriber, client *ClientInfo, mid int) bool { + newSub := bunchSubscriber{Client: client, MessageID: mid} + slice := *ary + for _, v := range slice { + if v == newSub { + return false + } + } + + slice = append(slice, newSub) + *ary = slice + return true +}