diff --git a/socketserver/SocketServerDesign.svg b/socketserver/SocketServerDesign.svg
new file mode 100644
index 00000000..6a141697
--- /dev/null
+++ b/socketserver/SocketServerDesign.svg
@@ -0,0 +1,1127 @@
+
+
+
+
diff --git a/socketserver/cmd/ffzsocketserver/.gitignore b/socketserver/cmd/ffzsocketserver/.gitignore
new file mode 100644
index 00000000..43852e2e
--- /dev/null
+++ b/socketserver/cmd/ffzsocketserver/.gitignore
@@ -0,0 +1,3 @@
+config.json
+ffzsocketserver
+uniques/
diff --git a/socketserver/cmd/ffzsocketserver/console.go b/socketserver/cmd/ffzsocketserver/console.go
new file mode 100644
index 00000000..5d3063d9
--- /dev/null
+++ b/socketserver/cmd/ffzsocketserver/console.go
@@ -0,0 +1,136 @@
+package main
+
+import (
+ "fmt"
+ "runtime"
+ "strconv"
+ "strings"
+
+ "../../server"
+ "github.com/abiosoft/ishell"
+ "github.com/gorilla/websocket"
+)
+
+func commandLineConsole() {
+
+ shell := ishell.NewShell()
+
+ shell.Register("help", func(args ...string) (string, error) {
+ shell.PrintCommands()
+ return "", nil
+ })
+
+ shell.Register("clientcount", func(args ...string) (string, error) {
+ server.GlobalSubscriptionLock.RLock()
+ count := len(server.GlobalSubscriptionInfo)
+ server.GlobalSubscriptionLock.RUnlock()
+ return fmt.Sprintln(count, "clients connected"), nil
+ })
+
+ shell.Register("globalnotice", func(args ...string) (string, error) {
+ msg := server.ClientMessage{
+ MessageID: -1,
+ Command: "message",
+ Arguments: args[0],
+ }
+ server.PublishToAll(msg)
+ return "Message sent.", nil
+ })
+
+ shell.Register("publish", func(args ...string) (string, error) {
+ if len(args) < 4 {
+ return "Usage: publish [room.sirstendec | _ALL] -1 reload_ff 23", nil
+ }
+
+ target := args[0]
+ line := strings.Join(args[1:], " ")
+ msg := server.ClientMessage{}
+ err := server.UnmarshalClientMessage([]byte(line), websocket.TextMessage, &msg)
+ if err != nil {
+ return "", err
+ }
+
+ var count int
+ if target == "_ALL" {
+ count = server.PublishToAll(msg)
+ } else {
+ count = server.PublishToChannel(target, msg)
+ }
+ return fmt.Sprintf("Published to %d clients", count), nil
+ })
+
+ shell.Register("memstatsbysize", func(args ...string) (string, error) {
+ runtime.GC()
+
+ m := runtime.MemStats{}
+ runtime.ReadMemStats(&m)
+ for _, val := range m.BySize {
+ if val.Mallocs == 0 {
+ continue
+ }
+ shell.Print(fmt.Sprintf("%5d: %6d outstanding (%d total)\n", val.Size, val.Mallocs-val.Frees, val.Mallocs))
+ }
+ shell.Println(m.NumGC, "collections occurred")
+ return "", nil
+ })
+
+ shell.Register("authorizeeveryone", func(args ...string) (string, error) {
+ if len(args) == 0 {
+ if server.Configuration.SendAuthToNewClients {
+ return "All clients are recieving auth challenges upon claiming a name.", nil
+ }
+ return "All clients are not recieving auth challenges upon claiming a name.", nil
+ } else if args[0] == "on" {
+ server.Configuration.SendAuthToNewClients = true
+ return "All new clients will recieve auth challenges upon claiming a name.", nil
+ } else if args[0] == "off" {
+ server.Configuration.SendAuthToNewClients = false
+ return "All new clients will not recieve auth challenges upon claiming a name.", nil
+ }
+ return "Usage: authorizeeveryone [ on | off ]", nil
+ })
+
+ shell.Register("kickclients", func(args ...string) (string, error) {
+ if len(args) == 0 {
+ return "Please enter either a count or a fraction of clients to kick.", nil
+ }
+ input, err := strconv.ParseFloat(args[0], 64)
+ if err != nil {
+ return "Argument must be a number", err
+ }
+ var count int
+ if input >= 1 {
+ count = int(input)
+ } else {
+ server.GlobalSubscriptionLock.RLock()
+ count = int(float64(len(server.GlobalSubscriptionInfo)) * input)
+ server.GlobalSubscriptionLock.RUnlock()
+ }
+
+ msg := server.ClientMessage{Arguments: &server.CloseRebalance}
+ server.GlobalSubscriptionLock.RLock()
+ defer server.GlobalSubscriptionLock.RUnlock()
+
+ kickCount := 0
+ for i, cl := range server.GlobalSubscriptionInfo {
+ if i >= count {
+ break
+ }
+ select {
+ case cl.MessageChannel <- msg:
+ case <-cl.MsgChannelIsDone:
+ }
+ kickCount++
+ }
+ return fmt.Sprintf("Kicked %d clients", kickCount), nil
+ })
+
+ shell.Register("panic", func(args ...string) (string, error) {
+ go func() {
+ panic("requested panic")
+ }()
+ return "", nil
+ })
+
+ shell.Start()
+}
diff --git a/socketserver/cmd/ffzsocketserver/index.html b/socketserver/cmd/ffzsocketserver/index.html
new file mode 100644
index 00000000..e28ce2ee
--- /dev/null
+++ b/socketserver/cmd/ffzsocketserver/index.html
@@ -0,0 +1,13 @@
+
+
CatBag
+
+
diff --git a/socketserver/cmd/ffzsocketserver/socketserver.go b/socketserver/cmd/ffzsocketserver/socketserver.go
index 5de7a059..1d67bddc 100644
--- a/socketserver/cmd/ffzsocketserver/socketserver.go
+++ b/socketserver/cmd/ffzsocketserver/socketserver.go
@@ -1,7 +1,6 @@
-package main // import "bitbucket.org/stendec/frankerfacez/socketserver/cmd/socketserver"
+package main // import "bitbucket.org/stendec/frankerfacez/socketserver/cmd/ffzsocketserver"
import (
- "../../internal/server"
"encoding/json"
"flag"
"fmt"
@@ -9,16 +8,23 @@ import (
"log"
"net/http"
"os"
+
+ "../../server"
)
-var configFilename *string = flag.String("config", "config.json", "Configuration file, including the keypairs for the NaCl crypto library, for communicating with the backend.")
-var generateKeys *bool = flag.Bool("genkeys", false, "Generate NaCl keys instead of serving requests.\nArguments: [int serverId] [base64 backendPublic]\nThe backend public key can either be specified in base64 on the command line, or put in the json file later.")
+import _ "net/http/pprof"
+
+var configFilename = flag.String("config", "config.json", "Configuration file, including the keypairs for the NaCl crypto library, for communicating with the backend.")
+var flagGenerateKeys = flag.Bool("genkeys", false, "Generate NaCl keys instead of serving requests.\nArguments: [int serverId] [base64 backendPublic]\nThe backend public key can either be specified in base64 on the command line, or put in the json file later.")
+
+var BuildTime string = "build not stamped"
+var BuildHash string = "build not stamped"
func main() {
flag.Parse()
- if *generateKeys {
- GenerateKeys(*configFilename)
+ if *flagGenerateKeys {
+ generateKeys(*configFilename)
return
}
@@ -40,24 +46,30 @@ func main() {
log.Fatal(err)
}
- httpServer := &http.Server{
- Addr: conf.ListenAddr,
- }
+ // logFile, err := os.OpenFile("output.log", os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644)
+ // if err != nil {
+ // log.Fatal("Could not create logfile: ", err)
+ // }
- server.SetupServerAndHandle(conf, httpServer.TLSConfig, nil)
+ server.SetupServerAndHandle(conf, http.DefaultServeMux)
+ server.SetBuildStamp(BuildTime, BuildHash)
+
+ go commandLineConsole()
if conf.UseSSL {
- err = httpServer.ListenAndServeTLS(conf.SSLCertificateFile, conf.SSLKeyFile)
- } else {
- err = httpServer.ListenAndServe()
+ go func() {
+ if err := http.ListenAndServeTLS(conf.SSLListenAddr, conf.SSLCertificateFile, conf.SSLKeyFile, http.DefaultServeMux); err != nil {
+ log.Fatal("ListenAndServeTLS: ", err)
+ }
+ }()
}
- if err != nil {
+ if err = http.ListenAndServe(conf.ListenAddr, http.DefaultServeMux); err != nil {
log.Fatal("ListenAndServe: ", err)
}
}
-func GenerateKeys(outputFile string) {
+func generateKeys(outputFile string) {
if flag.NArg() < 1 {
fmt.Println("Specify a numeric server ID after -genkeys")
os.Exit(2)
diff --git a/socketserver/cmd/mergecounts/.gitignore b/socketserver/cmd/mergecounts/.gitignore
new file mode 100644
index 00000000..5b97e97f
--- /dev/null
+++ b/socketserver/cmd/mergecounts/.gitignore
@@ -0,0 +1 @@
+mergecounts
diff --git a/socketserver/cmd/mergecounts/mergecounts.go b/socketserver/cmd/mergecounts/mergecounts.go
new file mode 100644
index 00000000..8ce8b307
--- /dev/null
+++ b/socketserver/cmd/mergecounts/mergecounts.go
@@ -0,0 +1,102 @@
+package main
+
+import (
+ "encoding/gob"
+ "flag"
+ "fmt"
+ "net/http"
+ "os"
+
+ "../../server"
+ "github.com/clarkduvall/hyperloglog"
+)
+
+var SERVERS = []string{
+ "https://catbag.frankerfacez.com",
+ "https://andknuckles.frankerfacez.com",
+ "https://tuturu.frankerfacez.com",
+}
+
+const folderPrefix = "/hll/"
+
+const HELP = `
+Usage: mergecounts [filename]
+
+Downloads the file /hll/filename from the 3 FFZ socket servers, merges the contents, and prints the total cardinality.
+
+Filename should be in one of the following formats:
+
+ daily-25-12-2015.gob
+ weekly-51-2015.gob
+ monthly-12-2015.gob
+`
+
+var forceWrite = flag.Bool("f", false, "force servers to write out their current")
+
+func main() {
+ flag.Parse()
+ if flag.NArg() < 1 {
+ fmt.Print(HELP)
+ os.Exit(2)
+ return
+ }
+
+ filename := flag.Arg(0)
+ hll, err := DownloadAll(filename)
+ if err != nil {
+ fmt.Println(err)
+ os.Exit(1)
+ return
+ }
+
+ fmt.Println(hll.Count())
+}
+
+func ForceWrite() {
+ for _, server := range SERVERS {
+ resp, err := http.Get(fmt.Sprintf("%s/hll_force_write", server))
+ if err != nil {
+ fmt.Println(err)
+ os.Exit(1)
+ }
+ resp.Body.Close()
+ }
+}
+
+func DownloadAll(filename string) (*hyperloglog.HyperLogLogPlus, error) {
+ result, _ := hyperloglog.NewPlus(server.CounterPrecision)
+
+ for _, server := range SERVERS {
+ if *forceWrite {
+ resp, err := http.Get(fmt.Sprintf("%s/hll_force_write", server))
+ if err == nil {
+ resp.Body.Close()
+ }
+ }
+ singleHLL, err := DownloadHLL(fmt.Sprintf("%s%s%s", server, folderPrefix, filename))
+ if err != nil {
+ return nil, err
+ }
+ result.Merge(singleHLL)
+ }
+
+ return result, nil
+}
+
+func DownloadHLL(url string) (*hyperloglog.HyperLogLogPlus, error) {
+ result, _ := hyperloglog.NewPlus(server.CounterPrecision)
+
+ resp, err := http.Get(url)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+ dec := gob.NewDecoder(resp.Body)
+ err = dec.Decode(result)
+ if err != nil {
+ return nil, err
+ }
+ fmt.Println(url, result.Count())
+
+ return result, nil
+}
diff --git a/socketserver/cmd/statsweb/.gitignore b/socketserver/cmd/statsweb/.gitignore
new file mode 100644
index 00000000..3dddabd8
--- /dev/null
+++ b/socketserver/cmd/statsweb/.gitignore
@@ -0,0 +1,3 @@
+database.sqlite
+gobcache/
+statsweb
diff --git a/socketserver/cmd/statsweb/config.go b/socketserver/cmd/statsweb/config.go
new file mode 100644
index 00000000..04591f44
--- /dev/null
+++ b/socketserver/cmd/statsweb/config.go
@@ -0,0 +1,74 @@
+package main
+
+import (
+ "encoding/json"
+ "fmt"
+ "os"
+)
+
+type ConfigFile struct {
+ ListenAddr string
+ DatabaseLocation string
+ GobFilesLocation string
+}
+
+func makeConfig() {
+ config.ListenAddr = "localhost:3000"
+ home, ok := os.LookupEnv("HOME")
+ if ok {
+ config.DatabaseLocation = fmt.Sprintf("%s/.ffzstatsweb/database.sqlite", home)
+ config.GobFilesLocation = fmt.Sprintf("%s/.ffzstatsweb/gobcache", home)
+ os.MkdirAll(config.GobFilesLocation, 0755)
+ } else {
+ config.DatabaseLocation = "./database.sqlite"
+ config.GobFilesLocation = "./gobcache"
+ os.MkdirAll(config.GobFilesLocation, 0755)
+ }
+ file, err := os.Create(*configLocation)
+ if err != nil {
+ fmt.Printf("Error: could not create config file: %v\n", err)
+ os.Exit(ExitCodeBadConfig)
+ return
+ }
+ enc := json.NewEncoder(file)
+ err = enc.Encode(config)
+ if err != nil {
+ fmt.Printf("Error: could not write config file: %v\n", err)
+ os.Exit(ExitCodeBadConfig)
+ return
+ }
+ err = file.Close()
+ if err != nil {
+ fmt.Printf("Error: could not write config file: %v\n", err)
+ os.Exit(ExitCodeBadConfig)
+ return
+ }
+ return
+}
+
+func loadConfig() {
+ file, err := os.Open(*configLocation)
+ if err != nil {
+ if os.IsNotExist(err) {
+ fmt.Println("You must create a config file with -genconf")
+ } else {
+ fmt.Printf("Error: could not load config file: %v", err)
+ }
+ os.Exit(ExitCodeBadConfig)
+ return
+ }
+ dec := json.NewDecoder(file)
+ err = dec.Decode(&config)
+ if err != nil {
+ fmt.Printf("Error: could not load config file: %v\n", err)
+ os.Exit(ExitCodeBadConfig)
+ return
+ }
+ err = file.Close()
+ if err != nil {
+ fmt.Printf("Error: could not load config file: %v\n", err)
+ os.Exit(ExitCodeBadConfig)
+ return
+ }
+ return
+}
diff --git a/socketserver/cmd/statsweb/config.json b/socketserver/cmd/statsweb/config.json
new file mode 100644
index 00000000..07cfa8e3
--- /dev/null
+++ b/socketserver/cmd/statsweb/config.json
@@ -0,0 +1 @@
+{"ListenAddr":"localhost:3000","DatabaseLocation":"./database.sqlite","GobFilesLocation":"./gobcache"}
diff --git a/socketserver/cmd/statsweb/html.go b/socketserver/cmd/statsweb/html.go
new file mode 100644
index 00000000..4cbdf89c
--- /dev/null
+++ b/socketserver/cmd/statsweb/html.go
@@ -0,0 +1,62 @@
+package main
+
+import (
+ "html/template"
+ "net/http"
+ "time"
+
+ "bitbucket.org/stendec/frankerfacez/socketserver/server"
+)
+
+type CalendarData struct {
+ Weeks []CalWeekData
+}
+type CalWeekData struct {
+ Days []CalDayData
+}
+type CalDayData struct {
+ NoData bool
+ Date int
+ UniqUsers int
+}
+
+type CalendarMonthInfo struct {
+ Year int
+ Month time.Month
+ // Ranges from -5 to +1.
+ // A value of +1 means the 1st of the month is a Sunday.
+ // A value of 0 means the 1st of the month is a Monday.
+ // A value of -5 means the 1st of the month is a Saturday.
+ FirstSundayOffset int
+ // True if the calendar for this month needs six sundays.
+ NeedSixSundays bool
+}
+
+func GetMonthInfo(at time.Time) CalendarMonthInfo {
+ year, month, _ := at.Date()
+ monthStartWeekday := time.Date(year, month, 1, 0, 0, 0, 0, server.CounterLocation).Weekday()
+ // 1 (start of month) - weekday of start of month = day offset of start of week at start of mont
+ monthWeekStartDay := 1 - int(monthStartWeekday)
+ // first day on calendar + 6 weeks < end of month?
+ sixthSundayDay := monthWeekStartDay + 5*7
+ sixthSundayDate := time.Date(year, month, sixthSundayDay, 0, 0, 0, 0, server.CounterLocation)
+ var needSixSundays bool = false
+ if sixthSundayDate.Month() == month {
+ needSixSundays = true
+ }
+
+ return CalendarMonthInfo{
+ Year: year,
+ Month: month,
+ FirstSundayOffset: monthWeekStartDay,
+ NeedSixSundays: needSixSundays,
+ }
+}
+
+func renderCalendar(w http.ResponseWriter, at time.Time) {
+ layout, err := template.ParseFiles("./webroot/layout.template.html", "./webroot/cal_entry.hbs", "./webroot/calendar.hbs")
+ data := CalendarData{}
+ data.Weeks = make([]CalWeekData, 6)
+ _ = layout
+ _ = err
+}
diff --git a/socketserver/cmd/statsweb/servers.go b/socketserver/cmd/statsweb/servers.go
new file mode 100644
index 00000000..d4818c61
--- /dev/null
+++ b/socketserver/cmd/statsweb/servers.go
@@ -0,0 +1,342 @@
+package main
+
+import (
+ "encoding/gob"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "sync"
+ "time"
+
+ "bitbucket.org/stendec/frankerfacez/socketserver/server"
+ "github.com/clarkduvall/hyperloglog"
+ "github.com/hashicorp/golang-lru"
+)
+
+type serverFilter struct {
+ // Mode is false for blacklist, true for whitelist
+ Mode bool
+ Special []string
+}
+
+const serverFilterModeBlacklist = false
+const serverFilterModeWhitelist = true
+
+func (sf *serverFilter) IsServerAllowed(server *serverInfo) bool {
+ name := server.subdomain
+ for _, v := range sf.Special {
+ if name == v {
+ return sf.Mode
+ }
+ }
+ return !sf.Mode
+}
+
+func (sf *serverFilter) Remove(server string) {
+ if sf.Mode == serverFilterModeWhitelist {
+ var idx int = -1
+ for i, v := range sf.Special {
+ if server == v {
+ idx = i
+ break
+ }
+ }
+ if idx != -1 {
+ var lenMinusOne = len(sf.Special) - 1
+ sf.Special[idx] = sf.Special[lenMinusOne]
+ sf.Special = sf.Special[:lenMinusOne]
+ }
+ } else {
+ for _, v := range sf.Special {
+ if server == v {
+ return
+ }
+ }
+ sf.Special = append(sf.Special, server)
+ }
+}
+
+func (sf *serverFilter) Add(server string) {
+ if sf.Mode == serverFilterModeBlacklist {
+ var idx int = -1
+ for i, v := range sf.Special {
+ if server == v {
+ idx = i
+ break
+ }
+ }
+ if idx != -1 {
+ var lenMinusOne = len(sf.Special) - 1
+ sf.Special[idx] = sf.Special[lenMinusOne]
+ sf.Special = sf.Special[:lenMinusOne]
+ }
+ } else {
+ for _, v := range sf.Special {
+ if server == v {
+ return
+ }
+ }
+ sf.Special = append(sf.Special, server)
+ }
+}
+
+var serverFilterAll serverFilter = serverFilter{Mode: serverFilterModeBlacklist}
+var serverFilterNone serverFilter = serverFilter{Mode: serverFilterModeWhitelist}
+
+func cannotCacheHLL(at time.Time) bool {
+ now := time.Now()
+ now.Add(-72 * time.Hour)
+ return now.Before(at)
+}
+
+var ServerNames = []string{
+ "catbag",
+ "andknuckles",
+ "tuturu",
+}
+
+var httpClient http.Client
+
+const serverNameSuffix = ".frankerfacez.com"
+
+const failedStateThreshold = 4
+
+var ErrServerInFailedState = errors.New("server has been down recently and not recovered")
+var ErrServerHasNoData = errors.New("no data for specified date")
+
+type errServerNot200 struct {
+ StatusCode int
+ StatusText string
+}
+
+func (e *errServerNot200) Error() string {
+ return fmt.Sprintf("The server responded with %d %s", e.StatusCode, e.StatusText)
+}
+func Not200Error(resp *http.Response) *errServerNot200 {
+ return &errServerNot200{
+ StatusCode: resp.StatusCode,
+ StatusText: resp.Status,
+ }
+}
+
+func getHLLCacheKey(at time.Time) string {
+ year, month, day := at.Date()
+ return fmt.Sprintf("%d-%d-%d", year, month, day)
+}
+
+type serverInfo struct {
+ subdomain string
+
+ memcache *lru.TwoQueueCache
+
+ FailedState bool
+ FailureErr error
+ failureCount int
+
+ lock sync.Mutex
+}
+
+func (si *serverInfo) Setup(subdomain string) {
+ si.subdomain = subdomain
+ tq, err := lru.New2Q(60)
+ if err != nil {
+ panic(err)
+ }
+ si.memcache = tq
+}
+
+// GetHLL gets the HLL from
+func (si *serverInfo) GetHLL(at time.Time) (*hyperloglog.HyperLogLogPlus, error) {
+ if cannotCacheHLL(at) {
+ fmt.Println(at)
+ err := si.ForceWrite()
+ if err != nil {
+ return nil, err
+ }
+ reader, err := si.DownloadHLL(at)
+ if err != nil {
+ return nil, err
+ }
+ fmt.Printf("downloaded uncached hll %s:%s\n", si.subdomain, getHLLCacheKey(at))
+ defer si.DeleteHLL(at)
+ return loadHLLFromStream(reader)
+ }
+
+ hll, ok := si.PeekHLL(at)
+ if ok {
+ fmt.Printf("got cached hll %s:%s\n", si.subdomain, getHLLCacheKey(at))
+ return hll, nil
+ }
+
+ reader, err := si.OpenHLL(at)
+ if err != nil {
+ // continue to download
+ } else {
+ //fmt.Printf("opened hll %s:%s\n", si.subdomain, getHLLCacheKey(at))
+ return loadHLLFromStream(reader)
+ }
+
+ reader, err = si.DownloadHLL(at)
+ if err != nil {
+ if err == ErrServerHasNoData {
+ return hyperloglog.NewPlus(server.CounterPrecision)
+ }
+ return nil, err
+ }
+ fmt.Printf("downloaded hll %s:%s\n", si.subdomain, getHLLCacheKey(at))
+ return loadHLLFromStream(reader)
+}
+
+func loadHLLFromStream(reader io.ReadCloser) (*hyperloglog.HyperLogLogPlus, error) {
+ defer reader.Close()
+ hll, _ := hyperloglog.NewPlus(server.CounterPrecision)
+ dec := gob.NewDecoder(reader)
+ err := dec.Decode(hll)
+ if err != nil {
+ return nil, err
+ }
+ return hll, nil
+}
+
+// PeekHLL tries to grab a HLL from the memcache without downloading it or hitting the disk.
+func (si *serverInfo) PeekHLL(at time.Time) (*hyperloglog.HyperLogLogPlus, bool) {
+ if cannotCacheHLL(at) {
+ return nil, false
+ }
+
+ key := getHLLCacheKey(at)
+ hll, ok := si.memcache.Get(key)
+ if ok {
+ return hll.(*hyperloglog.HyperLogLogPlus), true
+ }
+
+ return nil, false
+}
+
+func (si *serverInfo) DeleteHLL(at time.Time) {
+ year, month, day := at.Date()
+ filename := fmt.Sprintf("%s/%s/%d-%d-%d.gob", config.GobFilesLocation, si.subdomain, year, month, day)
+ err := os.Remove(filename)
+ if err != nil {
+ fmt.Println(err)
+ }
+}
+
+func (si *serverInfo) OpenHLL(at time.Time) (io.ReadCloser, error) {
+ year, month, day := at.Date()
+ filename := fmt.Sprintf("%s/%s/%d-%d-%d.gob", config.GobFilesLocation, si.subdomain, year, month, day)
+
+ file, err := os.Open(filename)
+ if err == nil {
+ return file, nil
+ }
+ // file is nil
+ if !os.IsNotExist(err) {
+ return nil, err
+ }
+
+ return nil, os.ErrNotExist
+}
+
+func (si *serverInfo) DownloadHLL(at time.Time) (io.ReadCloser, error) {
+ if si.FailedState {
+ return nil, ErrServerInFailedState
+ }
+ si.lock.Lock()
+ defer si.lock.Unlock()
+
+ year, month, day := at.Date()
+ url := fmt.Sprintf("https://%s/hll/daily-%d-%d-%d.gob", si.Domain(), day, month, year)
+ resp, err := httpClient.Get(url)
+ if err != nil {
+ si.ServerFailed(err)
+ return nil, err
+ }
+ if resp.StatusCode == 404 {
+ return nil, ErrServerHasNoData
+ }
+ if resp.StatusCode != 200 {
+ err = Not200Error(resp)
+ si.ServerFailed(err)
+ return nil, err
+ }
+
+ filename := fmt.Sprintf("%s/%s/%d-%d-%d.gob", config.GobFilesLocation, si.subdomain, year, month, day)
+ file, err := os.OpenFile(filename, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0644)
+ if os.IsNotExist(err) {
+ os.MkdirAll(fmt.Sprintf("%s/%s", config.GobFilesLocation, si.subdomain), 0755)
+ file, err = os.OpenFile(filename, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0644)
+ }
+ if err != nil {
+ resp.Body.Close()
+ return nil, fmt.Errorf("downloadhll: error opening file for writing: %v", err)
+ }
+
+ return &teeReadCloser{r: resp.Body, w: file}, nil
+}
+
+func (si *serverInfo) ForceWrite() error {
+ if si.FailedState {
+ return ErrServerInFailedState
+ }
+
+ url := fmt.Sprintf("https://%s/hll_force_write", si.Domain())
+ resp, err := httpClient.Get(url)
+ if err != nil {
+ si.ServerFailed(err)
+ return err
+ }
+ if resp.StatusCode != 200 {
+ err = Not200Error(resp)
+ si.ServerFailed(err)
+ return err
+ }
+ resp.Body.Close()
+ return nil
+}
+
+func (si *serverInfo) Domain() string {
+ return fmt.Sprintf("%s%s", si.subdomain, serverNameSuffix)
+}
+
+func (si *serverInfo) ServerFailed(err error) {
+ si.lock.Lock()
+ defer si.lock.Unlock()
+ si.failureCount++
+ if si.failureCount > failedStateThreshold {
+ fmt.Printf("Server %s entering failed state\n", si.subdomain)
+ si.FailedState = true
+ si.FailureErr = err
+ go recoveryCheck(si)
+ }
+}
+
+func recoveryCheck(si *serverInfo) {
+ // TODO check for server recovery
+}
+
+type teeReadCloser struct {
+ r io.ReadCloser
+ w io.WriteCloser
+}
+
+func (t *teeReadCloser) Read(p []byte) (n int, err error) {
+ n, err = t.r.Read(p)
+ if n > 0 {
+ if n, err := t.w.Write(p[:n]); err != nil {
+ return n, err
+ }
+ }
+ return
+}
+
+func (t *teeReadCloser) Close() error {
+ err1 := t.r.Close()
+ err2 := t.w.Close()
+ if err1 != nil {
+ return err1
+ }
+ return err2
+}
diff --git a/socketserver/cmd/statsweb/statsweb.go b/socketserver/cmd/statsweb/statsweb.go
new file mode 100644
index 00000000..b36cc768
--- /dev/null
+++ b/socketserver/cmd/statsweb/statsweb.go
@@ -0,0 +1,322 @@
+package main
+
+import (
+ "encoding/json"
+ "errors"
+ "flag"
+ "fmt"
+ "net/http"
+ "net/url"
+ "os"
+ "strings"
+ "sync"
+ "time"
+
+ "bitbucket.org/stendec/frankerfacez/socketserver/server"
+ "github.com/clarkduvall/hyperloglog"
+)
+
+var _ = os.Exit
+
+var configLocation = flag.String("config", "./config.json", "Location of the configuration file. Defaults to ./config.json")
+var genConfig = flag.Bool("genconf", false, "Generate a new configuration file.")
+
+var config ConfigFile
+
+const ExitCodeBadConfig = 2
+
+var allServers []*serverInfo
+
+func main() {
+ flag.Parse()
+
+ if *genConfig {
+ makeConfig()
+ return
+ }
+
+ loadConfig()
+
+ allServers = make([]*serverInfo, len(ServerNames))
+ for i, v := range ServerNames {
+ allServers[i] = &serverInfo{}
+ allServers[i].Setup(v)
+ }
+
+ //printEveryDay()
+ //os.Exit(0)
+ http.HandleFunc("/api/get", ServeAPIGet)
+ http.ListenAndServe(config.ListenAddr, http.DefaultServeMux)
+}
+
+func printEveryDay() {
+ year := 2015
+ month := 12
+ day := 23
+ filter := serverFilterAll
+ var filter1, filter2, filter3 serverFilter
+ filter1.Mode = serverFilterModeWhitelist
+ filter2.Mode = serverFilterModeWhitelist
+ filter3.Mode = serverFilterModeWhitelist
+ filter1.Add(allServers[0].subdomain)
+ filter2.Add(allServers[1].subdomain)
+ filter3.Add(allServers[2].subdomain)
+ stopTime := time.Now()
+ var at time.Time
+ const timeFmt = "2006-01-02"
+ for ; stopTime.After(at); day++ {
+ at = time.Date(year, time.Month(month), day, 0, 0, 0, 0, server.CounterLocation)
+ hll, _ := hyperloglog.NewPlus(server.CounterPrecision)
+ hll1, _ := hyperloglog.NewPlus(server.CounterPrecision)
+ hll2, _ := hyperloglog.NewPlus(server.CounterPrecision)
+ hll3, _ := hyperloglog.NewPlus(server.CounterPrecision)
+ addSingleDate(at, filter, hll)
+ addSingleDate(at, filter1, hll1)
+ addSingleDate(at, filter2, hll2)
+ addSingleDate(at, filter3, hll3)
+ fmt.Printf("%s\t%d\t%d\t%d\t%d\n", at.Format(timeFmt), hll.Count(), hll1.Count(), hll2.Count(), hll3.Count())
+ }
+}
+
+const RequestURIName = "q"
+const separatorRange = "~"
+const separatorAdd = " "
+const separatorServer = "@"
+const jsonErrMalformedRequest = `{"status":"error","error":"malformed request uri"}`
+const jsonErrBlankRequest = `{"status":"error","error":"no queries given"}`
+const statusError = "error"
+const statusPartial = "partial"
+const statusOk = "ok"
+
+type apiResponse struct {
+ Status string `json:"status"`
+ Responses []requestResponse `json:"resp"`
+}
+
+type requestResponse struct {
+ Status string `json:"status"`
+ Request string `json:"req"`
+ Error string `json:"error,omitempty"`
+ Count uint64 `json:"count,omitempty"`
+}
+
+func ServeAPIGet(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+
+ u, err := url.ParseRequestURI(r.RequestURI)
+ if err != nil {
+ w.WriteHeader(400)
+ fmt.Fprint(w, jsonErrMalformedRequest)
+ return
+ }
+
+ query := u.Query()
+ reqCount := len(query[RequestURIName])
+ if reqCount == 0 {
+ w.WriteHeader(400)
+ fmt.Fprint(w, jsonErrBlankRequest)
+ return
+ }
+
+ resp := apiResponse{Status: statusOk}
+ resp.Responses = make([]requestResponse, reqCount)
+ for i, v := range query[RequestURIName] {
+ if len(v) == 0 {
+ continue
+ }
+ resp.Responses[i] = ProcessSingleGetRequest(v)
+ }
+ for _, v := range resp.Responses {
+ if v.Status != statusOk {
+ resp.Status = statusPartial
+ break
+ }
+ }
+
+ w.WriteHeader(200)
+ enc := json.NewEncoder(w)
+ enc.Encode(resp)
+}
+
+var errRangeFormatIncorrect = errors.New("incorrect range format, must be yyyy-mm-dd~yyyy-mm-dd")
+
+// ProcessSingleGetRequest takes a request string and pulls the unique user data for the given dates and filters.
+//
+// The request string is in the following format:
+//
+// Request = AddDateRanges [ "@" ServerFilter ] .
+// ServerFilter = [ "!" ] ServerName { " " ServerName } .
+// ServerName = { "a" … "z" } .
+// AddDateRanges = DateMaybeRange { " " DateMaybeRange } .
+// DateMaybeRange = DateRange | Date .
+// DateRange = Date "~" Date .
+// Date = Year "-" Month "-" Day .
+// Year = number number number number .
+// Month = number number .
+// Day = number number .
+// number = "0" … "9" .
+//
+// Example of a well-formed request:
+//
+// 2016-01-04~2016-01-08 2016-01-11~2016-01-15@andknuckles tuturu
+//
+// Remember that spaces are urlencoded as "+", so the HTTP request to send to retrieve that data would be this:
+//
+// /api/get?q=2016-01-04~2016-01-08+2016-01-11~2016-01-15%40andknuckles+tuturu
+//
+// If a ServerFilter is specified, only users connecting to the specified servers will be included in the count.
+//
+// It does not matter if a date is specified multiple times, due to the data format used.
+func ProcessSingleGetRequest(req string) (result requestResponse) {
+ fmt.Println("processing request:", req)
+ hll, _ := hyperloglog.NewPlus(server.CounterPrecision)
+
+ result.Request = req
+ result.Status = statusOk
+ filter := serverFilterAll
+
+ collectError := func(err error) bool {
+ if err == ErrServerInFailedState {
+ result.Status = statusPartial
+ return false
+ } else if err != nil {
+ result.Status = statusError
+ result.Error = err.Error()
+ return true
+ }
+ return false
+ }
+
+ serverSplit := strings.Split(req, separatorServer)
+ if len(serverSplit) == 2 {
+ filter = serverFilterNone
+ serversOnly := strings.Split(serverSplit[1], separatorAdd)
+ for _, v := range serversOnly {
+ filter.Add(v)
+ }
+ }
+
+ addSplit := strings.Split(serverSplit[0], separatorAdd)
+
+outerLoop:
+ for _, split1 := range addSplit {
+ if len(split1) == 0 {
+ continue
+ }
+
+ rangeSplit := strings.Split(split1, separatorRange)
+ if len(rangeSplit) == 1 {
+ at, err := parseDateFromRequest(rangeSplit[0])
+ if collectError(err) {
+ break outerLoop
+ }
+
+ err = addSingleDate(at, filter, hll)
+ if collectError(err) {
+ break outerLoop
+ }
+ } else if len(rangeSplit) == 2 {
+ from, err := parseDateFromRequest(rangeSplit[0])
+ if collectError(err) {
+ break outerLoop
+ }
+ to, err := parseDateFromRequest(rangeSplit[1])
+ if collectError(err) {
+ break outerLoop
+ }
+
+ err = addRange(from, to, filter, hll)
+ if collectError(err) {
+ break outerLoop
+ }
+ } else {
+ collectError(errRangeFormatIncorrect)
+ break outerLoop
+ }
+ }
+
+ if result.Status == statusOk {
+ result.Count = hll.Count()
+ }
+ return result
+}
+
+var errBadDate = errors.New("bad date format, must be yyyy-mm-dd")
+var zeroTime = time.Unix(0, 0)
+
+func parseDateFromRequest(dateStr string) (time.Time, error) {
+ var year, month, day int
+ n, err := fmt.Sscanf(dateStr, "%d-%d-%d", &year, &month, &day)
+ if err != nil || n != 3 {
+ return zeroTime, errBadDate
+ }
+ return time.Date(year, time.Month(month), day, 0, 0, 0, 0, server.CounterLocation), nil
+}
+
+type hllAndError struct {
+ hll *hyperloglog.HyperLogLogPlus
+ err error
+}
+
+func addSingleDate(at time.Time, filter serverFilter, dest *hyperloglog.HyperLogLogPlus) error {
+ var partialErr error
+ for _, si := range allServers {
+ if filter.IsServerAllowed(si) {
+ hll, err2 := si.GetHLL(at)
+ if err2 == ErrServerInFailedState {
+ partialErr = err2
+ } else if err2 != nil {
+ return err2
+ } else {
+ dest.Merge(hll)
+ }
+ }
+ }
+ return partialErr
+}
+
+func addRange(start time.Time, end time.Time, filter serverFilter, dest *hyperloglog.HyperLogLogPlus) error {
+ end = server.TruncateToMidnight(end)
+ year, month, day := start.Date()
+ var partialErr error
+ var myAllServers = make([]*serverInfo, 0, len(allServers))
+ for _, si := range allServers {
+ if filter.IsServerAllowed(si) {
+ myAllServers = append(myAllServers, si)
+ }
+ }
+
+ var ch = make(chan hllAndError)
+ var wg sync.WaitGroup
+ for current := start; current.Before(end); day = day + 1 {
+ current = time.Date(year, month, day, 0, 0, 0, 0, server.CounterLocation)
+ for _, si := range myAllServers {
+ wg.Add(1)
+ go getHLL(ch, si, current)
+ }
+ }
+
+ go func() {
+ wg.Wait()
+ close(ch)
+ }()
+
+ for pair := range ch {
+ wg.Done()
+ hll, err := pair.hll, pair.err
+ if err != nil {
+ if partialErr == nil || partialErr == ErrServerInFailedState {
+ partialErr = err
+ }
+ } else {
+ dest.Merge(hll)
+ }
+ }
+
+ return partialErr
+}
+
+func getHLL(ch chan hllAndError, si *serverInfo, at time.Time) {
+ hll, err := si.GetHLL(at)
+ ch <- hllAndError{hll: hll, err: err}
+}
diff --git a/socketserver/cmd/statsweb/webroot/cal_entry.hbs b/socketserver/cmd/statsweb/webroot/cal_entry.hbs
new file mode 100644
index 00000000..cf178cf3
--- /dev/null
+++ b/socketserver/cmd/statsweb/webroot/cal_entry.hbs
@@ -0,0 +1,6 @@
+
+ {{.Date}}
+ {{if not .NoData}}
+ {{.UniqUsers}}
+ {{end}}
+ |
diff --git a/socketserver/cmd/statsweb/webroot/calendar.hbs b/socketserver/cmd/statsweb/webroot/calendar.hbs
new file mode 100644
index 00000000..1a8070ce
--- /dev/null
+++ b/socketserver/cmd/statsweb/webroot/calendar.hbs
@@ -0,0 +1,18 @@
+
+
+ Sunday |
+ Monday |
+ Tuesday |
+ Wednesday |
+ Thursday |
+ Friday |
+ Saturday |
+
+
+ {{range .Weeks}}
+ {{range .Days}}
+ {{template "cal_entry"}}
+ {{end}}
+ {{end}}
+
+
\ No newline at end of file
diff --git a/socketserver/cmd/statsweb/webroot/layout.template.html b/socketserver/cmd/statsweb/webroot/layout.template.html
new file mode 100644
index 00000000..09c24acf
--- /dev/null
+++ b/socketserver/cmd/statsweb/webroot/layout.template.html
@@ -0,0 +1,15 @@
+
+
+
+
+ Socket Server Stats Dashboard
+
+
+
+
+ {{template "content"}}
+
+
+
diff --git a/socketserver/internal/server/backend.go b/socketserver/internal/server/backend.go
deleted file mode 100644
index 72d74b22..00000000
--- a/socketserver/internal/server/backend.go
+++ /dev/null
@@ -1,234 +0,0 @@
-package server
-
-import (
- "crypto/rand"
- "encoding/base64"
- "encoding/json"
- "fmt"
- "github.com/pmylund/go-cache"
- "golang.org/x/crypto/nacl/box"
- "io/ioutil"
- "log"
- "net/http"
- "net/url"
- "strconv"
- "strings"
- "sync"
- "time"
-)
-
-var backendHttpClient http.Client
-var backendUrl string
-var responseCache *cache.Cache
-
-var getBacklogUrl string
-
-var backendSharedKey [32]byte
-var serverId int
-
-var messageBufferPool sync.Pool
-
-func SetupBackend(config *ConfigFile) {
- backendHttpClient.Timeout = 60 * time.Second
- backendUrl = config.BackendUrl
- if responseCache != nil {
- responseCache.Flush()
- }
- responseCache = cache.New(60*time.Second, 120*time.Second)
-
- getBacklogUrl = fmt.Sprintf("%s/backlog", backendUrl)
-
- messageBufferPool.New = New4KByteBuffer
-
- var theirPublic, ourPrivate [32]byte
- copy(theirPublic[:], config.BackendPublicKey)
- copy(ourPrivate[:], config.OurPrivateKey)
- serverId = config.ServerId
-
- box.Precompute(&backendSharedKey, &theirPublic, &ourPrivate)
-}
-
-func getCacheKey(remoteCommand, data string) string {
- return fmt.Sprintf("%s/%s", remoteCommand, data)
-}
-
-// Publish a message to clients with no caching.
-// The scope must be specified because no attempt is made to recognize the command.
-func HBackendPublishRequest(w http.ResponseWriter, r *http.Request) {
- r.ParseForm()
- formData, err := UnsealRequest(r.Form)
- if err != nil {
- w.WriteHeader(403)
- fmt.Fprintf(w, "Error: %v", err)
- return
- }
-
- cmd := formData.Get("cmd")
- json := formData.Get("args")
- channel := formData.Get("channel")
- scope := formData.Get("scope")
-
- target := MessageTargetTypeByName(scope)
-
- if cmd == "" {
- w.WriteHeader(422)
- fmt.Fprintf(w, "Error: cmd cannot be blank")
- return
- }
- if channel == "" && (target == MsgTargetTypeChat || target == MsgTargetTypeMultichat) {
- w.WriteHeader(422)
- fmt.Fprintf(w, "Error: channel must be specified")
- return
- }
-
- cm := ClientMessage{MessageID: -1, Command: Command(cmd), origArguments: json}
- cm.parseOrigArguments()
- var count int
-
- switch target {
- case MsgTargetTypeSingle:
- // TODO
- case MsgTargetTypeChat:
- count = PublishToChat(channel, cm)
- case MsgTargetTypeMultichat:
- // TODO
- case MsgTargetTypeGlobal:
- count = PublishToAll(cm)
- case MsgTargetTypeInvalid:
- default:
- w.WriteHeader(422)
- fmt.Fprint(w, "Invalid 'scope'. must be single, chat, multichat, channel, or global")
- return
- }
- fmt.Fprint(w, count)
-}
-
-func RequestRemoteDataCached(remoteCommand, data string, auth AuthInfo) (string, error) {
- cached, ok := responseCache.Get(getCacheKey(remoteCommand, data))
- if ok {
- return cached.(string), nil
- }
- return RequestRemoteData(remoteCommand, data, auth)
-}
-
-func RequestRemoteData(remoteCommand, data string, auth AuthInfo) (responseStr string, err error) {
- destUrl := fmt.Sprintf("%s/cmd/%s", backendUrl, remoteCommand)
- var authKey string
- if auth.UsernameValidated {
- authKey = "usernameClaimed"
- } else {
- authKey = "username"
- }
-
- formData := url.Values{
- "clientData": []string{data},
- authKey: []string{auth.TwitchUsername},
- }
-
- sealedForm, err := SealRequest(formData)
- if err != nil {
- return "", err
- }
-
- resp, err := backendHttpClient.PostForm(destUrl, sealedForm)
- if err != nil {
- return "", err
- }
-
- respBytes, err := ioutil.ReadAll(resp.Body)
- resp.Body.Close()
- if err != nil {
- return "", err
- }
-
- responseStr = string(respBytes)
-
- if resp.Header.Get("FFZ-Cache") != "" {
- durSecs, err := strconv.ParseInt(resp.Header.Get("FFZ-Cache"), 10, 64)
- if err != nil {
- return "", fmt.Errorf("The RPC server returned a non-integer cache duration: %v", err)
- }
- duration := time.Duration(durSecs) * time.Second
- responseCache.Set(getCacheKey(remoteCommand, data), responseStr, duration)
- }
-
- return
-}
-
-func FetchBacklogData(chatSubs []string) ([]ClientMessage, error) {
- formData := url.Values{
- "subs": chatSubs,
- }
-
- sealedForm, err := SealRequest(formData)
- if err != nil {
- return nil, err
- }
-
- resp, err := backendHttpClient.PostForm(getBacklogUrl, sealedForm)
- if err != nil {
- return nil, err
- }
- dec := json.NewDecoder(resp.Body)
- var messages []ClientMessage
- err = dec.Decode(messages)
- if err != nil {
- return nil, err
- }
-
- return messages, nil
-}
-
-func GenerateKeys(outputFile, serverId, theirPublicStr string) {
- var err error
- output := ConfigFile{
- ListenAddr: "0.0.0.0:8001",
- SocketOrigin: "localhost:8001",
- BackendUrl: "http://localhost:8002/ffz",
- BannerHTML: `
-
-CatBag
-
-
-`,
- }
-
- output.ServerId, err = strconv.Atoi(serverId)
- if err != nil {
- log.Fatal(err)
- }
-
- ourPublic, ourPrivate, err := box.GenerateKey(rand.Reader)
- if err != nil {
- log.Fatal(err)
- }
- output.OurPublicKey, output.OurPrivateKey = ourPublic[:], ourPrivate[:]
-
- if theirPublicStr != "" {
- reader := base64.NewDecoder(base64.StdEncoding, strings.NewReader(theirPublicStr))
- theirPublic, err := ioutil.ReadAll(reader)
- if err != nil {
- log.Fatal(err)
- }
- output.BackendPublicKey = theirPublic
- }
-
- bytes, err := json.MarshalIndent(output, "", "\t")
- if err != nil {
- log.Fatal(err)
- }
- fmt.Println(string(bytes))
- err = ioutil.WriteFile(outputFile, bytes, 0600)
- if err != nil {
- log.Fatal(err)
- }
-}
diff --git a/socketserver/internal/server/backend_test.go b/socketserver/internal/server/backend_test.go
deleted file mode 100644
index 7043d9f3..00000000
--- a/socketserver/internal/server/backend_test.go
+++ /dev/null
@@ -1,46 +0,0 @@
-package server
-
-import (
- "crypto/rand"
- "golang.org/x/crypto/nacl/box"
- "net/url"
- "testing"
-)
-
-func SetupRandomKeys(t testing.TB) {
- _, senderPrivate, err := box.GenerateKey(rand.Reader)
- if err != nil {
- t.Fatal(err)
- }
- receiverPublic, _, err := box.GenerateKey(rand.Reader)
- if err != nil {
- t.Fatal(err)
- }
-
- box.Precompute(&backendSharedKey, receiverPublic, senderPrivate)
- messageBufferPool.New = New4KByteBuffer
-}
-
-func TestSealRequest(t *testing.T) {
- SetupRandomKeys(t)
-
- values := url.Values{
- "QuickBrownFox": []string{"LazyDog"},
- }
-
- sealedValues, err := SealRequest(values)
- if err != nil {
- t.Fatal(err)
- }
- // sealedValues.Encode()
- // id=0&msg=KKtbng49dOLLyjeuX5AnXiEe6P0uZwgeP_7mMB5vhP-wMAAPZw%3D%3D&nonce=-wRbUnifscisWUvhm3gBEXHN5QzrfzgV
-
- unsealedValues, err := UnsealRequest(sealedValues)
- if err != nil {
- t.Fatal(err)
- }
-
- if unsealedValues.Get("QuickBrownFox") != "LazyDog" {
- t.Errorf("Failed to round-trip, got back %v", unsealedValues)
- }
-}
diff --git a/socketserver/internal/server/backlog.go b/socketserver/internal/server/backlog.go
deleted file mode 100644
index 66f09e93..00000000
--- a/socketserver/internal/server/backlog.go
+++ /dev/null
@@ -1,364 +0,0 @@
-package server
-
-import (
- "errors"
- "fmt"
- "net/http"
- "sort"
- "strconv"
- "strings"
- "sync"
- "time"
-)
-
-type PushCommandCacheInfo struct {
- Caching BacklogCacheType
- Target MessageTargetType
-}
-
-// this value is just docs right now
-var ServerInitiatedCommands = map[Command]PushCommandCacheInfo{
- /// Global updates & notices
- "update_news": {CacheTypeTimestamps, MsgTargetTypeGlobal}, // timecache:global
- "message": {CacheTypeTimestamps, MsgTargetTypeGlobal}, // timecache:global
- "reload_ff": {CacheTypeTimestamps, MsgTargetTypeGlobal}, // timecache:global
-
- /// Emote updates
- "reload_badges": {CacheTypeTimestamps, MsgTargetTypeGlobal}, // timecache:global
- "set_badge": {CacheTypeTimestamps, MsgTargetTypeMultichat}, // timecache:multichat
- "reload_set": {}, // timecache:multichat
- "load_set": {}, // TODO what are the semantics of this?
-
- /// User auth
- "do_authorize": {CacheTypeNever, MsgTargetTypeSingle}, // nocache:single
-
- /// Channel data
- // follow_sets: extra emote sets included in the chat
- // follow_buttons: extra follow buttons below the stream
- "follow_sets": {CacheTypePersistent, MsgTargetTypeChat}, // mustcache:chat
- "follow_buttons": {CacheTypePersistent, MsgTargetTypeChat}, // mustcache:watching
- "srl_race": {CacheTypeLastOnly, MsgTargetTypeChat}, // cachelast:watching
-
- /// Chatter/viewer counts
- "chatters": {CacheTypeLastOnly, MsgTargetTypeChat}, // cachelast:watching
- "viewers": {CacheTypeLastOnly, MsgTargetTypeChat}, // cachelast:watching
-}
-
-type BacklogCacheType int
-
-const (
- // This is not a cache type.
- CacheTypeInvalid BacklogCacheType = iota
- // This message cannot be cached.
- CacheTypeNever
- // Save the last 24 hours of this message.
- // If a client indicates that it has reconnected, replay the messages sent after the disconnect.
- // Do not replay if the client indicates that this is a firstload.
- CacheTypeTimestamps
- // Save only the last copy of this message, and always send it when the backlog is requested.
- CacheTypeLastOnly
- // Save this backlog data to disk with its timestamp.
- // Send it when the backlog is requested, or after a reconnect if it was updated.
- CacheTypePersistent
-)
-
-type MessageTargetType int
-
-const (
- // This is not a message target.
- MsgTargetTypeInvalid MessageTargetType = iota
- // This message is targeted to a single TODO(user or connection)
- MsgTargetTypeSingle
- // This message is targeted to all users in a chat
- MsgTargetTypeChat
- // This message is targeted to all users in multiple chats
- MsgTargetTypeMultichat
- // This message is sent to all FFZ users.
- MsgTargetTypeGlobal
-)
-
-// note: see types.go for methods on these
-
-// Returned by BacklogCacheType.UnmarshalJSON()
-var ErrorUnrecognizedCacheType = errors.New("Invalid value for cachetype")
-
-// Returned by MessageTargetType.UnmarshalJSON()
-var ErrorUnrecognizedTargetType = errors.New("Invalid value for message target")
-
-type TimestampedGlobalMessage struct {
- Timestamp time.Time
- Command Command
- Data string
-}
-
-type TimestampedMultichatMessage struct {
- Timestamp time.Time
- Channels []string
- Command Command
- Data string
-}
-
-type LastSavedMessage struct {
- Timestamp time.Time
- Data string
-}
-
-// map is command -> channel -> data
-
-// CacheTypeLastOnly. Cleaned up by reaper goroutine every ~hour.
-var CachedLastMessages map[Command]map[string]LastSavedMessage
-var CachedLSMLock sync.RWMutex
-
-// CacheTypePersistent. Never cleaned.
-var PersistentLastMessages map[Command]map[string]LastSavedMessage
-var PersistentLSMLock sync.RWMutex
-
-var CachedGlobalMessages []TimestampedGlobalMessage
-var CachedChannelMessages []TimestampedMultichatMessage
-var CacheListsLock sync.RWMutex
-
-func DumpCache() {
- CachedLSMLock.Lock()
- CachedLastMessages = make(map[Command]map[string]LastSavedMessage)
- CachedLSMLock.Unlock()
-
- PersistentLSMLock.Lock()
- PersistentLastMessages = make(map[Command]map[string]LastSavedMessage)
- // TODO delete file?
- PersistentLSMLock.Unlock()
-
- CacheListsLock.Lock()
- CachedGlobalMessages = make(tgmarray, 0)
- CachedChannelMessages = make(tmmarray, 0)
- CacheListsLock.Unlock()
-}
-
-func SendBacklogForNewClient(client *ClientInfo) {
- client.Mutex.Lock() // reading CurrentChannels
- PersistentLSMLock.RLock()
- for _, cmd := range GetCommandsOfType(PushCommandCacheInfo{CacheTypePersistent, MsgTargetTypeChat}) {
- chanMap := CachedLastMessages[cmd]
- if chanMap == nil {
- continue
- }
- for _, channel := range client.CurrentChannels {
- msg, ok := chanMap[channel]
- if ok {
- msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data}
- msg.parseOrigArguments()
- client.MessageChannel <- msg
- }
- }
- }
- PersistentLSMLock.RUnlock()
-
- CachedLSMLock.RLock()
- for _, cmd := range GetCommandsOfType(PushCommandCacheInfo{CacheTypeLastOnly, MsgTargetTypeChat}) {
- chanMap := CachedLastMessages[cmd]
- if chanMap == nil {
- continue
- }
- for _, channel := range client.CurrentChannels {
- msg, ok := chanMap[channel]
- if ok {
- msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data}
- msg.parseOrigArguments()
- client.MessageChannel <- msg
- }
- }
- }
- CachedLSMLock.RUnlock()
- client.Mutex.Unlock()
-}
-
-func SendTimedBacklogMessages(client *ClientInfo, disconnectTime time.Time) {
- client.Mutex.Lock() // reading CurrentChannels
- CacheListsLock.RLock()
-
- globIdx := FindFirstNewMessage(tgmarray(CachedGlobalMessages), disconnectTime)
-
- for i := globIdx; i < len(CachedGlobalMessages); i++ {
- item := CachedGlobalMessages[i]
- msg := ClientMessage{MessageID: -1, Command: item.Command, origArguments: item.Data}
- msg.parseOrigArguments()
- client.MessageChannel <- msg
- }
-
- chanIdx := FindFirstNewMessage(tmmarray(CachedChannelMessages), disconnectTime)
-
- for i := chanIdx; i < len(CachedChannelMessages); i++ {
- item := CachedChannelMessages[i]
- var send bool
- for _, channel := range item.Channels {
- for _, matchChannel := range client.CurrentChannels {
- if channel == matchChannel {
- send = true
- break
- }
- }
- if send {
- break
- }
- }
- if send {
- msg := ClientMessage{MessageID: -1, Command: item.Command, origArguments: item.Data}
- msg.parseOrigArguments()
- client.MessageChannel <- msg
- }
- }
-
- CacheListsLock.RUnlock()
- client.Mutex.Unlock()
-}
-
-func InsertionSort(ary sort.Interface) {
- for i := 1; i < ary.Len(); i++ {
- for j := i; j > 0 && ary.Less(j, j-1); j-- {
- ary.Swap(j, j-1)
- }
- }
-}
-
-type TimestampArray interface {
- Len() int
- GetTime(int) time.Time
-}
-
-func FindFirstNewMessage(ary TimestampArray, disconnectTime time.Time) (idx int) {
- // TODO needs tests
- len := ary.Len()
- i := len
-
- // Walk backwards until we find GetTime() before disconnectTime
- step := 1
- for i > 0 {
- i -= step
- if i < 0 {
- i = 0
- }
- if !ary.GetTime(i).After(disconnectTime) {
- break
- }
- step = int(float64(step)*1.5) + 1
- }
-
- // Walk forwards until we find GetTime() after disconnectTime
- for i < len && !ary.GetTime(i).After(disconnectTime) {
- i++
- }
-
- if i == len {
- return -1
- }
- return i
-}
-
-func SaveLastMessage(which map[Command]map[string]LastSavedMessage, locker sync.Locker, cmd Command, channel string, timestamp time.Time, data string, deleting bool) {
- locker.Lock()
- defer locker.Unlock()
-
- chanMap, ok := CachedLastMessages[cmd]
- if !ok {
- if deleting {
- return
- }
- chanMap = make(map[string]LastSavedMessage)
- CachedLastMessages[cmd] = chanMap
- }
-
- if deleting {
- delete(chanMap, channel)
- } else {
- chanMap[channel] = LastSavedMessage{timestamp, data}
- }
-}
-
-func SaveGlobalMessage(cmd Command, timestamp time.Time, data string) {
- CacheListsLock.Lock()
- CachedGlobalMessages = append(CachedGlobalMessages, TimestampedGlobalMessage{timestamp, cmd, data})
- InsertionSort(tgmarray(CachedGlobalMessages))
- CacheListsLock.Unlock()
-}
-
-func SaveMultichanMessage(cmd Command, channels string, timestamp time.Time, data string) {
- CacheListsLock.Lock()
- CachedChannelMessages = append(CachedChannelMessages, TimestampedMultichatMessage{timestamp, strings.Split(channels, ","), cmd, data})
- InsertionSort(tmmarray(CachedChannelMessages))
- CacheListsLock.Unlock()
-}
-
-func GetCommandsOfType(match PushCommandCacheInfo) []Command {
- var ret []Command
- for cmd, info := range ServerInitiatedCommands {
- if info == match {
- ret = append(ret, cmd)
- }
- }
- return ret
-}
-
-func HBackendDumpBacklog(w http.ResponseWriter, r *http.Request) {
- r.ParseForm()
- formData, err := UnsealRequest(r.Form)
- if err != nil {
- w.WriteHeader(403)
- fmt.Fprintf(w, "Error: %v", err)
- return
- }
-
- confirm := formData.Get("confirm")
- if confirm == "1" {
- DumpCache()
- }
-}
-
-// Publish a message to clients, and update the in-server cache for the message.
-// notes:
-// `scope` is implicit in the command
-func HBackendUpdateAndPublish(w http.ResponseWriter, r *http.Request) {
- r.ParseForm()
- formData, err := UnsealRequest(r.Form)
- if err != nil {
- w.WriteHeader(403)
- fmt.Fprintf(w, "Error: %v", err)
- return
- }
-
- cmd := Command(formData.Get("cmd"))
- json := formData.Get("args")
- channel := formData.Get("channel")
- deleteMode := formData.Get("delete") != ""
- timeStr := formData.Get("time")
- timestamp, err := time.Parse(time.UnixDate, timeStr)
- if err != nil {
- w.WriteHeader(422)
- fmt.Fprintf(w, "error parsing time: %v", err)
- }
-
- cacheinfo, ok := ServerInitiatedCommands[cmd]
- if !ok {
- w.WriteHeader(422)
- fmt.Fprintf(w, "Caching semantics unknown for command '%s'. Post to /addcachedcommand first.")
- return
- }
-
- var count int
- msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: json}
- msg.parseOrigArguments()
-
- if cacheinfo.Caching == CacheTypeLastOnly && cacheinfo.Target == MsgTargetTypeChat {
- SaveLastMessage(CachedLastMessages, &CachedLSMLock, cmd, channel, timestamp, json, deleteMode)
- count = PublishToChat(channel, msg)
- } else if cacheinfo.Caching == CacheTypePersistent && cacheinfo.Target == MsgTargetTypeChat {
- SaveLastMessage(PersistentLastMessages, &PersistentLSMLock, cmd, channel, timestamp, json, deleteMode)
- count = PublishToChat(channel, msg)
- } else if cacheinfo.Caching == CacheTypeTimestamps && cacheinfo.Target == MsgTargetTypeMultichat {
- SaveMultichanMessage(cmd, channel, timestamp, json)
- count = PublishToMultiple(strings.Split(channel, ","), msg)
- } else if cacheinfo.Caching == CacheTypeTimestamps && cacheinfo.Target == MsgTargetTypeGlobal {
- SaveGlobalMessage(cmd, timestamp, json)
- count = PublishToAll(msg)
- }
-
- w.Write([]byte(strconv.Itoa(count)))
-}
diff --git a/socketserver/internal/server/backlog_test.go b/socketserver/internal/server/backlog_test.go
deleted file mode 100644
index 68757587..00000000
--- a/socketserver/internal/server/backlog_test.go
+++ /dev/null
@@ -1,76 +0,0 @@
-package server
-
-import (
- "testing"
- "time"
-)
-
-func TestFindFirstNewMessageEmpty(t *testing.T) {
- CachedGlobalMessages = []TimestampedGlobalMessage{}
- i := FindFirstNewMessage(tgmarray(CachedGlobalMessages), time.Unix(10, 0))
- if i != -1 {
- t.Errorf("Expected -1, got %d", i)
- }
-}
-func TestFindFirstNewMessageOneBefore(t *testing.T) {
- CachedGlobalMessages = []TimestampedGlobalMessage{
- {Timestamp: time.Unix(8, 0)},
- }
- i := FindFirstNewMessage(tgmarray(CachedGlobalMessages), time.Unix(10, 0))
- if i != -1 {
- t.Errorf("Expected -1, got %d", i)
- }
-}
-func TestFindFirstNewMessageSeveralBefore(t *testing.T) {
- CachedGlobalMessages = []TimestampedGlobalMessage{
- {Timestamp: time.Unix(1, 0)},
- {Timestamp: time.Unix(2, 0)},
- {Timestamp: time.Unix(3, 0)},
- {Timestamp: time.Unix(4, 0)},
- {Timestamp: time.Unix(5, 0)},
- }
- i := FindFirstNewMessage(tgmarray(CachedGlobalMessages), time.Unix(10, 0))
- if i != -1 {
- t.Errorf("Expected -1, got %d", i)
- }
-}
-func TestFindFirstNewMessageInMiddle(t *testing.T) {
- CachedGlobalMessages = []TimestampedGlobalMessage{
- {Timestamp: time.Unix(1, 0)},
- {Timestamp: time.Unix(2, 0)},
- {Timestamp: time.Unix(3, 0)},
- {Timestamp: time.Unix(4, 0)},
- {Timestamp: time.Unix(5, 0)},
- {Timestamp: time.Unix(11, 0)},
- {Timestamp: time.Unix(12, 0)},
- {Timestamp: time.Unix(13, 0)},
- {Timestamp: time.Unix(14, 0)},
- {Timestamp: time.Unix(15, 0)},
- }
- i := FindFirstNewMessage(tgmarray(CachedGlobalMessages), time.Unix(10, 0))
- if i != 5 {
- t.Errorf("Expected 5, got %d", i)
- }
-}
-func TestFindFirstNewMessageOneAfter(t *testing.T) {
- CachedGlobalMessages = []TimestampedGlobalMessage{
- {Timestamp: time.Unix(15, 0)},
- }
- i := FindFirstNewMessage(tgmarray(CachedGlobalMessages), time.Unix(10, 0))
- if i != 0 {
- t.Errorf("Expected 0, got %d", i)
- }
-}
-func TestFindFirstNewMessageSeveralAfter(t *testing.T) {
- CachedGlobalMessages = []TimestampedGlobalMessage{
- {Timestamp: time.Unix(11, 0)},
- {Timestamp: time.Unix(12, 0)},
- {Timestamp: time.Unix(13, 0)},
- {Timestamp: time.Unix(14, 0)},
- {Timestamp: time.Unix(15, 0)},
- }
- i := FindFirstNewMessage(tgmarray(CachedGlobalMessages), time.Unix(10, 0))
- if i != 0 {
- t.Errorf("Expected 0, got %d", i)
- }
-}
diff --git a/socketserver/internal/server/commands.go b/socketserver/internal/server/commands.go
deleted file mode 100644
index c947bf08..00000000
--- a/socketserver/internal/server/commands.go
+++ /dev/null
@@ -1,285 +0,0 @@
-package server
-
-import (
- "fmt"
- "github.com/satori/go.uuid"
- "golang.org/x/net/websocket"
- "log"
- "strconv"
- "sync"
- "time"
-)
-
-var ResponseSuccess = ClientMessage{Command: SuccessCommand}
-var ResponseFailure = ClientMessage{Command: "False"}
-
-const ChannelInfoDelay = 2 * time.Second
-
-func HandleCommand(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) {
- handler, ok := CommandHandlers[msg.Command]
- if !ok {
- log.Println("[!] Unknown command", msg.Command, "- sent by client", client.ClientID, "@", conn.RemoteAddr())
- FFZCodec.Send(conn, ClientMessage{
- MessageID: msg.MessageID,
- Command: "error",
- Arguments: fmt.Sprintf("Unknown command %s", msg.Command),
- })
- return
- }
-
- response, err := CallHandler(handler, conn, client, msg)
-
- if err == nil {
- if response.Command == AsyncResponseCommand {
- // Don't send anything
- // The response will be delivered over client.MessageChannel / serverMessageChan
- } else {
- response.MessageID = msg.MessageID
- FFZCodec.Send(conn, response)
- }
- } else {
- FFZCodec.Send(conn, ClientMessage{
- MessageID: msg.MessageID,
- Command: "error",
- Arguments: err.Error(),
- })
- }
-}
-
-func HandleHello(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) {
- version, clientId, err := msg.ArgumentsAsTwoStrings()
- if err != nil {
- return
- }
-
- client.Version = version
- client.ClientID = uuid.FromStringOrNil(clientId)
- if client.ClientID == uuid.Nil {
- client.ClientID = uuid.NewV4()
- }
-
- SubscribeGlobal(client)
-
- return ClientMessage{
- Arguments: client.ClientID.String(),
- }, nil
-}
-
-func HandleReady(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) {
- disconnectAt, err := msg.ArgumentsAsInt()
- if err != nil {
- return
- }
-
- client.Mutex.Lock()
- if client.MakePendingRequests != nil {
- if !client.MakePendingRequests.Stop() {
- // Timer already fired, GetSubscriptionBacklog() has started
- rmsg.Command = SuccessCommand
- return
- }
- }
- client.PendingSubscriptionsBacklog = nil
- client.MakePendingRequests = nil
- client.Mutex.Unlock()
-
- if disconnectAt == 0 {
- // backlog only
- go func() {
- client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand}
- SendBacklogForNewClient(client)
- }()
- return ClientMessage{Command: AsyncResponseCommand}, nil
- } else {
- // backlog and timed
- go func() {
- client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand}
- SendBacklogForNewClient(client)
- SendTimedBacklogMessages(client, time.Unix(disconnectAt, 0))
- }()
- return ClientMessage{Command: AsyncResponseCommand}, nil
- }
-}
-
-func HandleSetUser(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) {
- username, err := msg.ArgumentsAsString()
- if err != nil {
- return
- }
-
- client.Mutex.Lock()
- client.TwitchUsername = username
- client.UsernameValidated = false
- client.Mutex.Unlock()
-
- return ResponseSuccess, nil
-}
-
-func HandleSub(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) {
- channel, err := msg.ArgumentsAsString()
-
- if err != nil {
- return
- }
-
- client.Mutex.Lock()
-
- AddToSliceS(&client.CurrentChannels, channel)
- client.PendingSubscriptionsBacklog = append(client.PendingSubscriptionsBacklog, channel)
-
- if client.MakePendingRequests == nil {
- client.MakePendingRequests = time.AfterFunc(ChannelInfoDelay, GetSubscriptionBacklogFor(conn, client))
- } else {
- if !client.MakePendingRequests.Reset(ChannelInfoDelay) {
- client.MakePendingRequests = time.AfterFunc(ChannelInfoDelay, GetSubscriptionBacklogFor(conn, client))
- }
- }
-
- client.Mutex.Unlock()
-
- SubscribeChat(client, channel)
-
- return ResponseSuccess, nil
-}
-
-func HandleUnsub(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) {
- channel, err := msg.ArgumentsAsString()
-
- if err != nil {
- return
- }
-
- client.Mutex.Lock()
- RemoveFromSliceS(&client.CurrentChannels, channel)
- client.Mutex.Unlock()
-
- UnsubscribeSingleChat(client, channel)
-
- return ResponseSuccess, nil
-}
-
-func GetSubscriptionBacklogFor(conn *websocket.Conn, client *ClientInfo) func() {
- return func() {
- GetSubscriptionBacklog(conn, client)
- }
-}
-
-// On goroutine
-func GetSubscriptionBacklog(conn *websocket.Conn, client *ClientInfo) {
- var subs []string
-
- // Lock, grab the data, and reset it
- client.Mutex.Lock()
- subs = client.PendingSubscriptionsBacklog
- client.PendingSubscriptionsBacklog = nil
- client.MakePendingRequests = nil
- client.Mutex.Unlock()
-
- if len(subs) == 0 {
- return
- }
-
- if backendUrl == "" {
- return // for testing runs
- }
- messages, err := FetchBacklogData(subs)
-
- if err != nil {
- // Oh well.
- log.Print("error in GetSubscriptionBacklog:", err)
- return
- }
-
- // Deliver to client
- for _, msg := range messages {
- client.MessageChannel <- msg
- }
-}
-
-type SurveySubmission struct {
- User string
- Json string
-}
-
-var SurveySubmissions []SurveySubmission
-var SurveySubmissionLock sync.Mutex
-
-func HandleSurvey(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) {
- SurveySubmissionLock.Lock()
- SurveySubmissions = append(SurveySubmissions, SurveySubmission{client.TwitchUsername, msg.origArguments})
- SurveySubmissionLock.Unlock()
-
- return ResponseSuccess, nil
-}
-
-type FollowEvent struct {
- User string
- Channel string
- NowFollowing bool
- Timestamp time.Time
-}
-
-var FollowEvents []FollowEvent
-var FollowEventsLock sync.Mutex
-
-func HandleTrackFollow(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) {
- channel, following, err := msg.ArgumentsAsStringAndBool()
- if err != nil {
- return
- }
- now := time.Now()
-
- FollowEventsLock.Lock()
- FollowEvents = append(FollowEvents, FollowEvent{client.TwitchUsername, channel, following, now})
- FollowEventsLock.Unlock()
-
- return ResponseSuccess, nil
-}
-
-var AggregateEmoteUsage map[int]map[string]int = make(map[int]map[string]int)
-var AggregateEmoteUsageLock sync.Mutex
-
-func HandleEmoticonUses(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) {
- // arguments is [1]map[EmoteId]map[RoomName]float64
-
- mapRoot := msg.Arguments.([]interface{})[0].(map[string]interface{})
-
- AggregateEmoteUsageLock.Lock()
- defer AggregateEmoteUsageLock.Unlock()
-
- for strEmote, val1 := range mapRoot {
- var emoteId int
- emoteId, err = strconv.Atoi(strEmote)
- if err != nil {
- return
- }
-
- destMapInner, ok := AggregateEmoteUsage[emoteId]
- if !ok {
- destMapInner = make(map[string]int)
- AggregateEmoteUsage[emoteId] = destMapInner
- }
-
- mapInner := val1.(map[string]interface{})
- for roomName, val2 := range mapInner {
- var count int = int(val2.(float64))
- destMapInner[roomName] += count
- }
- }
-
- return ResponseSuccess, nil
-}
-
-func HandleRemoteCommand(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) {
- go func(conn *websocket.Conn, msg ClientMessage, authInfo AuthInfo) {
- resp, err := RequestRemoteDataCached(string(msg.Command), msg.origArguments, authInfo)
-
- if err != nil {
- FFZCodec.Send(conn, ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: err.Error()})
- } else {
- FFZCodec.Send(conn, ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand, origArguments: resp})
- }
- }(conn, msg, client.AuthInfo)
-
- return ClientMessage{Command: AsyncResponseCommand}, nil
-}
diff --git a/socketserver/internal/server/handlecore.go b/socketserver/internal/server/handlecore.go
deleted file mode 100644
index f227db8b..00000000
--- a/socketserver/internal/server/handlecore.go
+++ /dev/null
@@ -1,434 +0,0 @@
-package server // import "bitbucket.org/stendec/frankerfacez/socketserver/internal/server"
-
-import (
- "crypto/tls"
- "encoding/json"
- "errors"
- "fmt"
- "golang.org/x/net/websocket"
- "log"
- "net/http"
- "strconv"
- "strings"
- "sync"
-)
-
-const MAX_PACKET_SIZE = 1024
-
-// A command is how the client refers to a function on the server. It's just a string.
-type Command string
-
-// A function that is called to respond to a Command.
-type CommandHandler func(*websocket.Conn, *ClientInfo, ClientMessage) (ClientMessage, error)
-
-var CommandHandlers = map[Command]CommandHandler{
- HelloCommand: HandleHello,
- "setuser": HandleSetUser,
- "ready": HandleReady,
-
- "sub": HandleSub,
- "unsub": HandleUnsub,
-
- "track_follow": HandleTrackFollow,
- "emoticon_uses": HandleEmoticonUses,
- "survey": HandleSurvey,
-
- "twitch_emote": HandleRemoteCommand,
- "get_link": HandleRemoteCommand,
- "get_display_name": HandleRemoteCommand,
- "update_follow_buttons": HandleRemoteCommand,
- "chat_history": HandleRemoteCommand,
-}
-
-// Sent by the server in ClientMessage.Command to indicate success.
-const SuccessCommand Command = "True"
-
-// Sent by the server in ClientMessage.Command to indicate failure.
-const ErrorCommand Command = "error"
-
-// This must be the first command sent by the client once the connection is established.
-const HelloCommand Command = "hello"
-
-// A handler returning a ClientMessage with this Command will prevent replying to the client.
-// It signals that the work has been handed off to a background goroutine.
-const AsyncResponseCommand Command = "_async"
-
-// A websocket.Codec that translates the protocol into ClientMessage objects.
-var FFZCodec websocket.Codec = websocket.Codec{
- Marshal: MarshalClientMessage,
- Unmarshal: UnmarshalClientMessage,
-}
-
-// Errors that get returned to the client.
-var ProtocolError error = errors.New("FFZ Socket protocol error.")
-var ProtocolErrorNegativeID error = errors.New("FFZ Socket protocol error: negative or zero message ID.")
-var ExpectedSingleString = errors.New("Error: Expected single string as arguments.")
-var ExpectedSingleInt = errors.New("Error: Expected single integer as arguments.")
-var ExpectedTwoStrings = errors.New("Error: Expected array of string, string as arguments.")
-var ExpectedStringAndInt = errors.New("Error: Expected array of string, int as arguments.")
-var ExpectedStringAndBool = errors.New("Error: Expected array of string, bool as arguments.")
-var ExpectedStringAndIntGotFloat = errors.New("Error: Second argument was a float, expected an integer.")
-
-var gconfig *ConfigFile
-
-// Create a websocket.Server with the options from the provided Config.
-func setupServer(config *ConfigFile, tlsConfig *tls.Config) *websocket.Server {
- gconfig = config
- sockConf, err := websocket.NewConfig("/", config.SocketOrigin)
- if err != nil {
- log.Fatal(err)
- }
-
- SetupBackend(config)
-
- if config.UseSSL {
- cert, err := tls.LoadX509KeyPair(config.SSLCertificateFile, config.SSLKeyFile)
- if err != nil {
- log.Fatal(err)
- }
- tlsConfig.Certificates = []tls.Certificate{cert}
- tlsConfig.ServerName = config.SocketOrigin
- tlsConfig.BuildNameToCertificate()
- sockConf.TlsConfig = tlsConfig
-
- }
-
- sockServer := &websocket.Server{}
- sockServer.Config = *sockConf
- sockServer.Handler = HandleSocketConnection
-
- go deadChannelReaper()
-
- return sockServer
-}
-
-// Set up a websocket listener and register it on /.
-// (Uses http.DefaultServeMux .)
-func SetupServerAndHandle(config *ConfigFile, tlsConfig *tls.Config, serveMux *http.ServeMux) {
- sockServer := setupServer(config, tlsConfig)
-
- if serveMux == nil {
- serveMux = http.DefaultServeMux
- }
- serveMux.HandleFunc("/", ServeWebsocketOrCatbag(sockServer.ServeHTTP))
- serveMux.HandleFunc("/pub_msg", HBackendPublishRequest)
- serveMux.HandleFunc("/dump_backlog", HBackendDumpBacklog)
- serveMux.HandleFunc("/update_and_pub", HBackendUpdateAndPublish)
-}
-
-func ServeWebsocketOrCatbag(sockfunc func(http.ResponseWriter, *http.Request)) http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- if r.Header.Get("Connection") == "Upgrade" {
- sockfunc(w, r)
- return
- } else {
- w.Write([]byte(gconfig.BannerHTML))
- }
- }
-}
-
-// Handle a new websocket connection from a FFZ client.
-// This runs in a goroutine started by net/http.
-func HandleSocketConnection(conn *websocket.Conn) {
- // websocket.Conn is a ReadWriteCloser
-
- var _closer sync.Once
- closer := func() {
- _closer.Do(func() {
- conn.Close()
- })
- }
-
- // Close the connection when we're done.
- defer closer()
-
- _clientChan := make(chan ClientMessage)
- _serverMessageChan := make(chan ClientMessage)
- _errorChan := make(chan error)
-
- // Launch receiver goroutine
- go func(errorChan chan<- error, clientChan chan<- ClientMessage) {
- var msg ClientMessage
- var err error
- for ; err == nil; err = FFZCodec.Receive(conn, &msg) {
- if msg.MessageID == 0 {
- continue
- }
- clientChan <- msg
- }
- errorChan <- err
- close(errorChan)
- close(clientChan)
- // exit
- }(_errorChan, _clientChan)
-
- var errorChan <-chan error = _errorChan
- var clientChan <-chan ClientMessage = _clientChan
- var serverMessageChan <-chan ClientMessage = _serverMessageChan
-
- var client ClientInfo
- client.MessageChannel = _serverMessageChan
-
- // All set up, now enter the work loop
-
-RunLoop:
- for {
- select {
- case err := <-errorChan:
- FFZCodec.Send(conn, ClientMessage{
- MessageID: -1,
- Command: "error",
- Arguments: err.Error(),
- }) // note - socket might be closed, but don't care
- break RunLoop
- case msg := <-clientChan:
- if client.Version == "" && msg.Command != HelloCommand {
- FFZCodec.Send(conn, ClientMessage{
- MessageID: msg.MessageID,
- Command: "error",
- Arguments: "Error - the first message sent must be a 'hello'",
- })
- break RunLoop
- }
-
- HandleCommand(conn, &client, msg)
- case smsg := <-serverMessageChan:
- FFZCodec.Send(conn, smsg)
- }
- }
-
- // Exit
-
- // Launch message draining goroutine - we aren't out of the pub/sub records
- go func() {
- for _ = range _serverMessageChan {
- }
- }()
-
- // Stop getting messages...
- UnsubscribeAll(&client)
-
- // And finished.
- // Close the channel so the draining goroutine can finish, too.
- close(_serverMessageChan)
-}
-
-func CallHandler(handler CommandHandler, conn *websocket.Conn, client *ClientInfo, cmsg ClientMessage) (rmsg ClientMessage, err error) {
- defer func() {
- if r := recover(); r != nil {
- var ok bool
- fmt.Print("[!] Error executing command", cmsg.Command, "--", r)
- err, ok = r.(error)
- if !ok {
- err = fmt.Errorf("command handler: %v", r)
- }
- }
- }()
- return handler(conn, client, cmsg)
-}
-
-// Unpack a message sent from the client into a ClientMessage.
-func UnmarshalClientMessage(data []byte, payloadType byte, v interface{}) (err error) {
- var spaceIdx int
-
- out := v.(*ClientMessage)
- dataStr := string(data)
-
- // Message ID
- spaceIdx = strings.IndexRune(dataStr, ' ')
- if spaceIdx == -1 {
- return ProtocolError
- }
- messageId, err := strconv.Atoi(dataStr[:spaceIdx])
- if messageId < -1 || messageId == 0 {
- return ProtocolErrorNegativeID
- }
-
- out.MessageID = messageId
- dataStr = dataStr[spaceIdx+1:]
-
- spaceIdx = strings.IndexRune(dataStr, ' ')
- if spaceIdx == -1 {
- out.Command = Command(dataStr)
- out.Arguments = nil
- return nil
- } else {
- out.Command = Command(dataStr[:spaceIdx])
- }
- dataStr = dataStr[spaceIdx+1:]
- argumentsJson := dataStr
- out.origArguments = argumentsJson
- err = out.parseOrigArguments()
- if err != nil {
- return
- }
- return nil
-}
-
-func (cm *ClientMessage) parseOrigArguments() error {
- err := json.Unmarshal([]byte(cm.origArguments), &cm.Arguments)
- if err != nil {
- return err
- }
- return nil
-}
-
-func MarshalClientMessage(clientMessage interface{}) (data []byte, payloadType byte, err error) {
- var msg ClientMessage
- var ok bool
- msg, ok = clientMessage.(ClientMessage)
- if !ok {
- pMsg, ok := clientMessage.(*ClientMessage)
- if !ok {
- panic("MarshalClientMessage: argument needs to be a ClientMessage")
- }
- msg = *pMsg
- }
- var dataStr string
-
- if msg.Command == "" && msg.MessageID == 0 {
- panic("MarshalClientMessage: attempt to send an empty ClientMessage")
- }
-
- if msg.Command == "" {
- msg.Command = SuccessCommand
- }
- if msg.MessageID == 0 {
- msg.MessageID = -1
- }
-
- if msg.Arguments != nil {
- argBytes, err := json.Marshal(msg.Arguments)
- if err != nil {
- return nil, 0, err
- }
-
- dataStr = fmt.Sprintf("%d %s %s", msg.MessageID, msg.Command, string(argBytes))
- } else {
- dataStr = fmt.Sprintf("%d %s", msg.MessageID, msg.Command)
- }
-
- return []byte(dataStr), websocket.TextFrame, nil
-}
-
-// Command handlers should use this to construct responses.
-func NewClientMessage(arguments interface{}) ClientMessage {
- return ClientMessage{
- MessageID: 0, // filled by the select loop
- Command: SuccessCommand,
- Arguments: arguments,
- }
-}
-
-// Convenience method: Parse the arguments of the ClientMessage as a single string.
-func (cm *ClientMessage) ArgumentsAsString() (string1 string, err error) {
- var ok bool
- string1, ok = cm.Arguments.(string)
- if !ok {
- err = ExpectedSingleString
- return
- } else {
- return string1, nil
- }
-}
-
-// Convenience method: Parse the arguments of the ClientMessage as a single int.
-func (cm *ClientMessage) ArgumentsAsInt() (int1 int64, err error) {
- var ok bool
- var num float64
- num, ok = cm.Arguments.(float64)
- if !ok {
- err = ExpectedSingleInt
- return
- } else {
- int1 = int64(num)
- return int1, nil
- }
-}
-
-// Convenience method: Parse the arguments of the ClientMessage as an array of two strings.
-func (cm *ClientMessage) ArgumentsAsTwoStrings() (string1, string2 string, err error) {
- var ok bool
- var ary []interface{}
- ary, ok = cm.Arguments.([]interface{})
- if !ok {
- err = ExpectedTwoStrings
- return
- } else {
- if len(ary) != 2 {
- err = ExpectedTwoStrings
- return
- }
- string1, ok = ary[0].(string)
- if !ok {
- err = ExpectedTwoStrings
- return
- }
- string2, ok = ary[1].(string)
- if !ok {
- err = ExpectedTwoStrings
- return
- }
- return string1, string2, nil
- }
-}
-
-// Convenience method: Parse the arguments of the ClientMessage as an array of a string and an int.
-func (cm *ClientMessage) ArgumentsAsStringAndInt() (string1 string, int int64, err error) {
- var ok bool
- var ary []interface{}
- ary, ok = cm.Arguments.([]interface{})
- if !ok {
- err = ExpectedStringAndInt
- return
- } else {
- if len(ary) != 2 {
- err = ExpectedStringAndInt
- return
- }
- string1, ok = ary[0].(string)
- if !ok {
- err = ExpectedStringAndInt
- return
- }
- var num float64
- num, ok = ary[1].(float64)
- if !ok {
- err = ExpectedStringAndInt
- return
- }
- int = int64(num)
- if float64(int) != num {
- err = ExpectedStringAndIntGotFloat
- return
- }
- return string1, int, nil
- }
-}
-
-// Convenience method: Parse the arguments of the ClientMessage as an array of a string and an int.
-func (cm *ClientMessage) ArgumentsAsStringAndBool() (str string, flag bool, err error) {
- var ok bool
- var ary []interface{}
- ary, ok = cm.Arguments.([]interface{})
- if !ok {
- err = ExpectedStringAndBool
- return
- } else {
- if len(ary) != 2 {
- err = ExpectedStringAndBool
- return
- }
- str, ok = ary[0].(string)
- if !ok {
- err = ExpectedStringAndBool
- return
- }
- flag, ok = ary[1].(bool)
- if !ok {
- err = ExpectedStringAndBool
- return
- }
- return str, flag, nil
- }
-}
diff --git a/socketserver/internal/server/types.go b/socketserver/internal/server/types.go
deleted file mode 100644
index cc9ba947..00000000
--- a/socketserver/internal/server/types.go
+++ /dev/null
@@ -1,232 +0,0 @@
-package server
-
-import (
- "encoding/json"
- "github.com/satori/go.uuid"
- "sync"
- "time"
-)
-
-const CryptoBoxKeyLength = 32
-
-type ConfigFile struct {
- // Numeric server id known to the backend
- ServerId int
- ListenAddr string
- // Hostname of the socket server
- SocketOrigin string
- // URL to the backend server
- BackendUrl string
- // Memes go here
- BannerHTML string
-
- // SSL/TLS
- UseSSL bool
- SSLCertificateFile string
- SSLKeyFile string
-
- // Nacl keys
- OurPrivateKey []byte
- OurPublicKey []byte
- BackendPublicKey []byte
-}
-
-type ClientMessage struct {
- // Message ID. Increments by 1 for each message sent from the client.
- // When replying to a command, the message ID must be echoed.
- // When sending a server-initiated message, this is -1.
- MessageID int
- // The command that the client wants from the server.
- // When sent from the server, the literal string 'True' indicates success.
- // Before sending, a blank Command will be converted into SuccessCommand.
- Command Command
- // Result of json.Unmarshal on the third field send from the client
- Arguments interface{}
-
- origArguments string
-}
-
-type AuthInfo struct {
- // The client's claimed username on Twitch.
- TwitchUsername string
-
- // Whether or not the server has validated the client's claimed username.
- UsernameValidated bool
-}
-
-type ClientInfo struct {
- // The client ID.
- // This must be written once by the owning goroutine before the struct is passed off to any other goroutines.
- ClientID uuid.UUID
-
- // The client's version.
- // This must be written once by the owning goroutine before the struct is passed off to any other goroutines.
- Version string
-
- // This mutex protects writable data in this struct.
- // If it seems to be a performance problem, we can split this.
- Mutex sync.Mutex
-
- // TODO(riking) - does this need to be protected cross-thread?
- AuthInfo
-
- // Username validation nonce.
- ValidationNonce string
-
- // The list of chats this client is currently in.
- // Protected by Mutex.
- CurrentChannels []string
-
- // List of channels that we have not yet checked current chat-related channel info for.
- // This lets us batch the backlog requests.
- // Protected by Mutex.
- PendingSubscriptionsBacklog []string
-
- // A timer that, when fired, will make the pending backlog requests.
- // Usually nil. Protected by Mutex.
- MakePendingRequests *time.Timer
-
- // Server-initiated messages should be sent here
- // Never nil.
- MessageChannel chan<- ClientMessage
-}
-
-type tgmarray []TimestampedGlobalMessage
-type tmmarray []TimestampedMultichatMessage
-
-func (ta tgmarray) Len() int {
- return len(ta)
-}
-func (ta tgmarray) Less(i, j int) bool {
- return ta[i].Timestamp.Before(ta[j].Timestamp)
-}
-func (ta tgmarray) Swap(i, j int) {
- ta[i], ta[j] = ta[j], ta[i]
-}
-func (ta tgmarray) GetTime(i int) time.Time {
- return ta[i].Timestamp
-}
-func (ta tmmarray) Len() int {
- return len(ta)
-}
-func (ta tmmarray) Less(i, j int) bool {
- return ta[i].Timestamp.Before(ta[j].Timestamp)
-}
-func (ta tmmarray) Swap(i, j int) {
- ta[i], ta[j] = ta[j], ta[i]
-}
-func (ta tmmarray) GetTime(i int) time.Time {
- return ta[i].Timestamp
-}
-
-func (bct BacklogCacheType) Name() string {
- switch bct {
- case CacheTypeInvalid:
- return ""
- case CacheTypeNever:
- return "never"
- case CacheTypeTimestamps:
- return "timed"
- case CacheTypeLastOnly:
- return "last"
- case CacheTypePersistent:
- return "persist"
- }
- panic("Invalid BacklogCacheType value")
-}
-
-var CacheTypesByName = map[string]BacklogCacheType{
- "never": CacheTypeNever,
- "timed": CacheTypeTimestamps,
- "last": CacheTypeLastOnly,
- "persist": CacheTypePersistent,
-}
-
-func BacklogCacheTypeByName(name string) (bct BacklogCacheType) {
- // CacheTypeInvalid is the zero value so it doesn't matter
- bct, _ = CacheTypesByName[name]
- return
-}
-
-// Implements Stringer
-func (bct BacklogCacheType) String() string { return bct.Name() }
-
-// Implements json.Marshaler
-func (bct BacklogCacheType) MarshalJSON() ([]byte, error) {
- return json.Marshal(bct.Name())
-}
-
-// Implements json.Unmarshaler
-func (pbct *BacklogCacheType) UnmarshalJSON(data []byte) error {
- var str string
- err := json.Unmarshal(data, &str)
- if err != nil {
- return err
- }
- if str == "" {
- *pbct = CacheTypeInvalid
- return nil
- }
- val := BacklogCacheTypeByName(str)
- if val != CacheTypeInvalid {
- *pbct = val
- return nil
- }
- return ErrorUnrecognizedCacheType
-}
-
-func (mtt MessageTargetType) Name() string {
- switch mtt {
- case MsgTargetTypeInvalid:
- return ""
- case MsgTargetTypeSingle:
- return "single"
- case MsgTargetTypeChat:
- return "chat"
- case MsgTargetTypeMultichat:
- return "multichat"
- case MsgTargetTypeGlobal:
- return "global"
- }
- panic("Invalid MessageTargetType value")
-}
-
-var TargetTypesByName = map[string]MessageTargetType{
- "single": MsgTargetTypeSingle,
- "chat": MsgTargetTypeChat,
- "multichat": MsgTargetTypeMultichat,
- "global": MsgTargetTypeGlobal,
-}
-
-func MessageTargetTypeByName(name string) (mtt MessageTargetType) {
- // MsgTargetTypeInvalid is the zero value so it doesn't matter
- mtt, _ = TargetTypesByName[name]
- return
-}
-
-// Implements Stringer
-func (mtt MessageTargetType) String() string { return mtt.Name() }
-
-// Implements json.Marshaler
-func (mtt MessageTargetType) MarshalJSON() ([]byte, error) {
- return json.Marshal(mtt.Name())
-}
-
-// Implements json.Unmarshaler
-func (pmtt *MessageTargetType) UnmarshalJSON(data []byte) error {
- var str string
- err := json.Unmarshal(data, &str)
- if err != nil {
- return err
- }
- if str == "" {
- *pmtt = MsgTargetTypeInvalid
- return nil
- }
- mtt := MessageTargetTypeByName(str)
- if mtt != MsgTargetTypeInvalid {
- *pmtt = mtt
- return nil
- }
- return ErrorUnrecognizedTargetType
-}
diff --git a/socketserver/server/backend.go b/socketserver/server/backend.go
new file mode 100644
index 00000000..f41defde
--- /dev/null
+++ b/socketserver/server/backend.go
@@ -0,0 +1,301 @@
+package server
+
+import (
+ "crypto/rand"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "net/http"
+ "net/url"
+ "strconv"
+ "strings"
+ "time"
+
+ "sync"
+
+ "github.com/pmylund/go-cache"
+ "golang.org/x/crypto/nacl/box"
+)
+
+const bPathAnnounceStartup = "/startup"
+const bPathAddTopic = "/topics"
+const bPathAggStats = "/stats"
+const bPathOtherCommand = "/cmd/"
+
+type backendInfo struct {
+ HTTPClient http.Client
+ baseURL string
+ responseCache *cache.Cache
+
+ postStatisticsURL string
+ addTopicURL string
+ announceStartupURL string
+
+ sharedKey [32]byte
+ serverID int
+
+ lastSuccess map[string]time.Time
+ lastSuccessLock sync.Mutex
+}
+
+var Backend *backendInfo
+
+func setupBackend(config *ConfigFile) *backendInfo {
+ b := new(backendInfo)
+ Backend = b
+ b.serverID = config.ServerID
+
+ b.HTTPClient.Timeout = 60 * time.Second
+ b.baseURL = config.BackendURL
+ b.responseCache = cache.New(60*time.Second, 120*time.Second)
+
+ b.announceStartupURL = fmt.Sprintf("%s%s", b.baseURL, bPathAnnounceStartup)
+ b.addTopicURL = fmt.Sprintf("%s%s", b.baseURL, bPathAddTopic)
+ b.postStatisticsURL = fmt.Sprintf("%s%s", b.baseURL, bPathAggStats)
+
+ epochTime := time.Unix(0, 0).UTC()
+ lastBackendSuccess := map[string]time.Time{
+ bPathAnnounceStartup: epochTime,
+ bPathAddTopic: epochTime,
+ bPathAggStats: epochTime,
+ bPathOtherCommand: epochTime,
+ }
+ b.lastSuccess = lastBackendSuccess
+
+ var theirPublic, ourPrivate [32]byte
+ copy(theirPublic[:], config.BackendPublicKey)
+ copy(ourPrivate[:], config.OurPrivateKey)
+
+ box.Precompute(&b.sharedKey, &theirPublic, &ourPrivate)
+
+ return b
+}
+
+func getCacheKey(remoteCommand, data string) string {
+ return fmt.Sprintf("%s/%s", remoteCommand, data)
+}
+
+// ErrForwardedFromBackend is an error returned by the backend server.
+type ErrForwardedFromBackend struct {
+ JSONError interface{}
+}
+
+func (bfe ErrForwardedFromBackend) Error() string {
+ bytes, _ := json.Marshal(bfe.JSONError)
+ return string(bytes)
+}
+
+// ErrAuthorizationNeeded is emitted when the backend replies with HTTP 401.
+// Indicates that an attempt to validate `ClientInfo.TwitchUsername` should be attempted.
+var ErrAuthorizationNeeded = errors.New("Must authenticate Twitch username to use this command")
+
+// SendRemoteCommandCached performs a RPC call on the backend, but caches responses.
+func (backend *backendInfo) SendRemoteCommandCached(remoteCommand, data string, auth AuthInfo) (string, error) {
+ cached, ok := backend.responseCache.Get(getCacheKey(remoteCommand, data))
+ if ok {
+ return cached.(string), nil
+ }
+ return backend.SendRemoteCommand(remoteCommand, data, auth)
+}
+
+// SendRemoteCommand performs a RPC call on the backend by POSTing to `/cmd/$remoteCommand`.
+// The form data is as follows: `clientData` is the JSON in the `data` parameter
+// (should be retrieved from ClientMessage.Arguments), and either `username` or
+// `usernameClaimed` depending on whether AuthInfo.UsernameValidates is true is AuthInfo.TwitchUsername.
+func (backend *backendInfo) SendRemoteCommand(remoteCommand, data string, auth AuthInfo) (responseStr string, err error) {
+ destURL := fmt.Sprintf("%s/cmd/%s", backend.baseURL, remoteCommand)
+ healthBucket := fmt.Sprintf("/cmd/%s", remoteCommand)
+
+ formData := url.Values{
+ "clientData": []string{data},
+ "username": []string{auth.TwitchUsername},
+ }
+
+ if auth.UsernameValidated {
+ formData.Set("authenticated", "1")
+ } else {
+ formData.Set("authenticated", "0")
+ }
+
+ sealedForm, err := backend.SealRequest(formData)
+ if err != nil {
+ return "", err
+ }
+
+ resp, err := backend.HTTPClient.PostForm(destURL, sealedForm)
+ if err != nil {
+ return "", err
+ }
+ defer resp.Body.Close()
+
+ respBytes, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ return "", err
+ }
+
+ responseStr = string(respBytes)
+
+ if resp.StatusCode == 401 {
+ return "", ErrAuthorizationNeeded
+ } else if resp.StatusCode != 200 {
+ if resp.Header.Get("Content-Type") == "application/json" {
+ var err2 ErrForwardedFromBackend
+ err := json.Unmarshal(respBytes, &err2.JSONError)
+ if err != nil {
+ return "", fmt.Errorf("error decoding json error from backend: %v | %s", err, responseStr)
+ }
+ return "", err2
+ }
+ return "", httpError(resp.StatusCode)
+ }
+
+ if resp.Header.Get("FFZ-Cache") != "" {
+ durSecs, err := strconv.ParseInt(resp.Header.Get("FFZ-Cache"), 10, 64)
+ if err != nil {
+ return "", fmt.Errorf("The RPC server returned a non-integer cache duration: %v", err)
+ }
+ duration := time.Duration(durSecs) * time.Second
+ backend.responseCache.Set(getCacheKey(remoteCommand, data), responseStr, duration)
+ }
+
+ now := time.Now().UTC()
+ backend.lastSuccessLock.Lock()
+ defer backend.lastSuccessLock.Unlock()
+ backend.lastSuccess[bPathOtherCommand] = now
+ backend.lastSuccess[healthBucket] = now
+
+ return
+}
+
+// SendAggregatedData sends aggregated emote usage and following data to the backend server.
+func (backend *backendInfo) SendAggregatedData(form url.Values) error {
+ sealedForm, err := backend.SealRequest(form)
+ if err != nil {
+ return err
+ }
+
+ resp, err := backend.HTTPClient.PostForm(backend.postStatisticsURL, sealedForm)
+ if err != nil {
+ return err
+ }
+ if resp.StatusCode < 200 || resp.StatusCode > 299 {
+ resp.Body.Close()
+ return httpError(resp.StatusCode)
+ }
+
+ backend.lastSuccessLock.Lock()
+ defer backend.lastSuccessLock.Unlock()
+ backend.lastSuccess[bPathAggStats] = time.Now().UTC()
+
+ return resp.Body.Close()
+}
+
+// ErrBackendNotOK indicates that the backend replied with something other than the string "ok".
+type ErrBackendNotOK struct {
+ Response string
+ Code int
+}
+
+// Error Implements the error interface.
+func (noe ErrBackendNotOK) Error() string {
+ return fmt.Sprintf("backend returned %d: %s", noe.Code, noe.Response)
+}
+
+// SendNewTopicNotice notifies the backend that a client has performed the first subscription to a pub/sub topic.
+// POST data:
+// channels=room.trihex
+// added=t
+func (backend *backendInfo) SendNewTopicNotice(topic string) error {
+ return backend.sendTopicNotice(topic, true)
+}
+
+// SendCleanupTopicsNotice notifies the backend that pub/sub topics have no subscribers anymore.
+// POST data:
+// channels=room.sirstendec,room.bobross,feature.foo
+// added=f
+func (backend *backendInfo) SendCleanupTopicsNotice(topics []string) error {
+ return backend.sendTopicNotice(strings.Join(topics, ","), false)
+}
+
+func (backend *backendInfo) sendTopicNotice(topic string, added bool) error {
+ formData := url.Values{}
+ formData.Set("channels", topic)
+ if added {
+ formData.Set("added", "t")
+ } else {
+ formData.Set("added", "f")
+ }
+
+ sealedForm, err := backend.SealRequest(formData)
+ if err != nil {
+ return err
+ }
+
+ resp, err := backend.HTTPClient.PostForm(backend.addTopicURL, sealedForm)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode < 200 || resp.StatusCode > 299 {
+ respBytes, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ return ErrBackendNotOK{Code: resp.StatusCode, Response: fmt.Sprintf("(error reading non-2xx response): %s", err.Error())}
+ }
+ return ErrBackendNotOK{Code: resp.StatusCode, Response: string(respBytes)}
+ }
+
+ backend.lastSuccessLock.Lock()
+ defer backend.lastSuccessLock.Unlock()
+ backend.lastSuccess[bPathAddTopic] = time.Now().UTC()
+
+ return nil
+}
+
+func httpError(statusCode int) error {
+ return fmt.Errorf("backend http error: %d", statusCode)
+}
+
+// GenerateKeys generates a new NaCl keypair for the server and writes out the default configuration file.
+func GenerateKeys(outputFile, serverID, theirPublicStr string) {
+ var err error
+ output := ConfigFile{
+ ListenAddr: "0.0.0.0:8001",
+ SSLListenAddr: "0.0.0.0:443",
+ BackendURL: "http://localhost:8002/ffz",
+ MinMemoryKBytes: defaultMinMemoryKB,
+ }
+
+ output.ServerID, err = strconv.Atoi(serverID)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ ourPublic, ourPrivate, err := box.GenerateKey(rand.Reader)
+ if err != nil {
+ log.Fatal(err)
+ }
+ output.OurPublicKey, output.OurPrivateKey = ourPublic[:], ourPrivate[:]
+
+ if theirPublicStr != "" {
+ reader := base64.NewDecoder(base64.StdEncoding, strings.NewReader(theirPublicStr))
+ theirPublic, err := ioutil.ReadAll(reader)
+ if err != nil {
+ log.Fatal(err)
+ }
+ output.BackendPublicKey = theirPublic
+ }
+
+ bytes, err := json.MarshalIndent(output, "", "\t")
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Println(string(bytes))
+ err = ioutil.WriteFile(outputFile, bytes, 0600)
+ if err != nil {
+ log.Fatal(err)
+ }
+}
diff --git a/socketserver/server/backend_test.go b/socketserver/server/backend_test.go
new file mode 100644
index 00000000..d1c85adb
--- /dev/null
+++ b/socketserver/server/backend_test.go
@@ -0,0 +1,131 @@
+package server
+
+import (
+ "net/http"
+ "net/url"
+ "testing"
+
+ . "gopkg.in/check.v1"
+)
+
+func Test(t *testing.T) { TestingT(t) }
+
+func TestSealRequest(t *testing.T) {
+ TSetup(SetupNoServers, nil)
+ b := Backend
+
+ values := url.Values{
+ "QuickBrownFox": []string{"LazyDog"},
+ }
+
+ sealedValues, err := b.SealRequest(values)
+ if err != nil {
+ t.Fatal(err)
+ }
+ // sealedValues.Encode()
+ // id=0&msg=KKtbng49dOLLyjeuX5AnXiEe6P0uZwgeP_7mMB5vhP-wMAAPZw%3D%3D&nonce=-wRbUnifscisWUvhm3gBEXHN5QzrfzgV
+
+ unsealedValues, err := b.UnsealRequest(sealedValues)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if unsealedValues.Get("QuickBrownFox") != "LazyDog" {
+ t.Errorf("Failed to round-trip, got back %v", unsealedValues)
+ }
+}
+
+type BackendSuite struct{}
+
+var _ = Suite(&BackendSuite{})
+
+func (s *BackendSuite) TestSendRemoteCommand(c *C) {
+ const TestCommand1 = "somecommand"
+ const TestCommand2 = "other"
+ const PathTestCommand1 = "/cmd/" + TestCommand1
+ const PathTestCommand2 = "/cmd/" + TestCommand2
+ const TestData1 = "623478.32"
+ const TestData2 = "\"Hello, there\""
+ const TestData3 = "3"
+ const TestUsername = "sirstendec"
+ const TestResponse1 = "asfdg"
+ const TestResponse2 = "yuiop"
+ const TestErrorText = "{\"err\":\"some kind of special error\"}"
+
+ var AnonAuthInfo = AuthInfo{}
+ var NonValidatedAuthInfo = AuthInfo{TwitchUsername: TestUsername}
+ var ValidatedAuthInfo = AuthInfo{TwitchUsername: TestUsername, UsernameValidated: true}
+
+ headersCacheTwoSeconds := http.Header{"FFZ-Cache": []string{"2"}}
+ headersCacheInvalid := http.Header{"FFZ-Cache": []string{"NotANumber"}}
+ headersApplicationJson := http.Header{"Content-Type": []string{"application/json"}}
+
+ mockBackend := NewTBackendRequestChecker(c,
+ TExpectedBackendRequest{200, PathTestCommand1, &url.Values{"clientData": []string{TestData1}, "authenticated": []string{"0"}, "username": []string{""}}, TestResponse1, nil},
+ TExpectedBackendRequest{200, PathTestCommand1, &url.Values{"clientData": []string{TestData1}, "authenticated": []string{"0"}, "username": []string{""}}, TestResponse2, nil},
+ TExpectedBackendRequest{200, PathTestCommand1, &url.Values{"clientData": []string{TestData1}, "authenticated": []string{"0"}, "username": []string{TestUsername}}, TestResponse1, nil},
+ TExpectedBackendRequest{200, PathTestCommand1, &url.Values{"clientData": []string{TestData1}, "authenticated": []string{"1"}, "username": []string{TestUsername}}, TestResponse1, nil},
+ TExpectedBackendRequest{200, PathTestCommand2, &url.Values{"clientData": []string{TestData2}, "authenticated": []string{"0"}, "username": []string{TestUsername}}, TestResponse1, headersCacheTwoSeconds},
+ // cached
+ // cached
+ TExpectedBackendRequest{200, PathTestCommand2, &url.Values{"clientData": []string{TestData1}, "authenticated": []string{"0"}, "username": []string{TestUsername}}, TestResponse2, headersCacheTwoSeconds},
+ TExpectedBackendRequest{401, PathTestCommand1, &url.Values{"clientData": []string{TestData1}, "authenticated": []string{"0"}, "username": []string{TestUsername}}, "", nil},
+ TExpectedBackendRequest{503, PathTestCommand1, &url.Values{"clientData": []string{TestData1}, "authenticated": []string{"0"}, "username": []string{TestUsername}}, "", nil},
+ TExpectedBackendRequest{418, PathTestCommand1, &url.Values{"clientData": []string{TestData1}, "authenticated": []string{"0"}, "username": []string{TestUsername}}, TestErrorText, headersApplicationJson},
+ TExpectedBackendRequest{200, PathTestCommand2, &url.Values{"clientData": []string{TestData3}, "authenticated": []string{"0"}, "username": []string{TestUsername}}, TestResponse1, headersCacheInvalid},
+ )
+ _, _, _ = TSetup(SetupWantBackendServer, mockBackend)
+ defer mockBackend.Close()
+
+ var resp string
+ var err error
+ b := Backend
+
+ resp, err = b.SendRemoteCommand(TestCommand1, TestData1, AnonAuthInfo)
+ c.Check(resp, Equals, TestResponse1)
+ c.Check(err, IsNil)
+
+ resp, err = b.SendRemoteCommand(TestCommand1, TestData1, AnonAuthInfo)
+ c.Check(resp, Equals, TestResponse2)
+ c.Check(err, IsNil)
+
+ resp, err = b.SendRemoteCommand(TestCommand1, TestData1, NonValidatedAuthInfo)
+ c.Check(resp, Equals, TestResponse1)
+ c.Check(err, IsNil)
+
+ resp, err = b.SendRemoteCommand(TestCommand1, TestData1, ValidatedAuthInfo)
+ c.Check(resp, Equals, TestResponse1)
+ c.Check(err, IsNil)
+ // cache save
+ resp, err = b.SendRemoteCommandCached(TestCommand2, TestData2, NonValidatedAuthInfo)
+ c.Check(resp, Equals, TestResponse1)
+ c.Check(err, IsNil)
+
+ resp, err = b.SendRemoteCommandCached(TestCommand2, TestData2, NonValidatedAuthInfo) // cache hit
+ c.Check(resp, Equals, TestResponse1)
+ c.Check(err, IsNil)
+
+ resp, err = b.SendRemoteCommandCached(TestCommand2, TestData2, AnonAuthInfo) // cache hit
+ c.Check(resp, Equals, TestResponse1)
+ c.Check(err, IsNil)
+ // cache miss - data is different
+ resp, err = b.SendRemoteCommandCached(TestCommand2, TestData1, NonValidatedAuthInfo)
+ c.Check(resp, Equals, TestResponse2)
+ c.Check(err, IsNil)
+
+ resp, err = b.SendRemoteCommand(TestCommand1, TestData1, NonValidatedAuthInfo)
+ c.Check(resp, Equals, "")
+ c.Check(err, Equals, ErrAuthorizationNeeded)
+
+ resp, err = b.SendRemoteCommand(TestCommand1, TestData1, NonValidatedAuthInfo)
+ c.Check(resp, Equals, "")
+ c.Check(err, ErrorMatches, "backend http error: 503")
+
+ resp, err = b.SendRemoteCommand(TestCommand1, TestData1, NonValidatedAuthInfo)
+ c.Check(resp, Equals, "")
+ c.Check(err, ErrorMatches, TestErrorText)
+
+ resp, err = b.SendRemoteCommand(TestCommand2, TestData3, NonValidatedAuthInfo)
+ c.Check(resp, Equals, "")
+ c.Check(err, ErrorMatches, "The RPC server returned a non-integer cache duration: .*")
+}
diff --git a/socketserver/server/commands.go b/socketserver/server/commands.go
new file mode 100644
index 00000000..3d0156a1
--- /dev/null
+++ b/socketserver/server/commands.go
@@ -0,0 +1,606 @@
+package server
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "log"
+ "net/url"
+ "strconv"
+ "sync"
+ "time"
+
+ "github.com/gorilla/websocket"
+ "github.com/satori/go.uuid"
+)
+
+// Command is a string indicating which RPC is requested.
+// The Commands sent from Client -> Server and Server -> Client are disjoint sets.
+type Command string
+
+// CommandHandler is a RPC handler associated with a Command.
+type CommandHandler func(*websocket.Conn, *ClientInfo, ClientMessage) (ClientMessage, error)
+
+var commandHandlers = map[Command]CommandHandler{
+ HelloCommand: C2SHello,
+ "ping": C2SPing,
+ SetUserCommand: C2SSetUser,
+ ReadyCommand: C2SReady,
+
+ "sub": C2SSubscribe,
+ "unsub": C2SUnsubscribe,
+
+ "track_follow": C2STrackFollow,
+ "emoticon_uses": C2SEmoticonUses,
+ "survey": C2SSurvey,
+
+ "twitch_emote": C2SHandleBunchedCommand,
+ "get_link": C2SHandleBunchedCommand,
+ "get_display_name": C2SHandleBunchedCommand,
+ "update_follow_buttons": C2SHandleRemoteCommand,
+ "chat_history": C2SHandleRemoteCommand,
+ "user_history": C2SHandleRemoteCommand,
+}
+
+func setupInterning() {
+ PubSubChannelPool = NewStringPool()
+ TwitchChannelPool = NewStringPool()
+
+ CommandPool = NewStringPool()
+ CommandPool._Intern_Setup(string(HelloCommand))
+ CommandPool._Intern_Setup("ping")
+ CommandPool._Intern_Setup(string(SetUserCommand))
+ CommandPool._Intern_Setup(string(ReadyCommand))
+ CommandPool._Intern_Setup("sub")
+ CommandPool._Intern_Setup("unsub")
+ CommandPool._Intern_Setup("track_follow")
+ CommandPool._Intern_Setup("emoticon_uses")
+ CommandPool._Intern_Setup("twitch_emote")
+ CommandPool._Intern_Setup("get_link")
+ CommandPool._Intern_Setup("get_display_name")
+ CommandPool._Intern_Setup("update_follow_buttons")
+ CommandPool._Intern_Setup("chat_history")
+ CommandPool._Intern_Setup("user_history")
+ CommandPool._Intern_Setup("adjacent_history")
+}
+
+// DispatchC2SCommand handles a C2S Command in the provided ClientMessage.
+// It calls the correct CommandHandler function, catching panics.
+// It sends either the returned Reply ClientMessage, setting the correct messageID, or sends an ErrorCommand
+func DispatchC2SCommand(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) {
+ handler, ok := commandHandlers[msg.Command]
+ if !ok {
+ handler = C2SHandleRemoteCommand
+ }
+
+ CommandCounter <- msg.Command
+
+ response, err := callHandler(handler, conn, client, msg)
+
+ if err == nil {
+ if response.Command == AsyncResponseCommand {
+ // Don't send anything
+ // The response will be delivered over client.MessageChannel / serverMessageChan
+ } else {
+ response.MessageID = msg.MessageID
+ SendMessage(conn, response)
+ }
+ } else {
+ SendMessage(conn, ClientMessage{
+ MessageID: msg.MessageID,
+ Command: ErrorCommand,
+ Arguments: err.Error(),
+ })
+ }
+}
+
+func callHandler(handler CommandHandler, conn *websocket.Conn, client *ClientInfo, cmsg ClientMessage) (rmsg ClientMessage, err error) {
+ defer func() {
+ if r := recover(); r != nil {
+ var ok bool
+ fmt.Print("[!] Error executing command", cmsg.Command, "--", r)
+ err, ok = r.(error)
+ if !ok {
+ err = fmt.Errorf("command handler: %v", r)
+ }
+ }
+ }()
+ return handler(conn, client, cmsg)
+}
+
+// C2SHello implements the `hello` C2S Command.
+// It calls SubscribeGlobal() and SubscribeDefaults() with the client, and fills out ClientInfo.Version and ClientInfo.ClientID.
+func C2SHello(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) {
+ ary, ok := msg.Arguments.([]interface{})
+ if !ok {
+ err = ErrExpectedTwoStrings
+ return
+ }
+ if len(ary) != 2 {
+ err = ErrExpectedTwoStrings
+ return
+ }
+ version, ok := ary[0].(string)
+ if !ok {
+ err = ErrExpectedTwoStrings
+ return
+ }
+
+ client.VersionString = copyString(version)
+ client.Version = VersionFromString(version)
+
+ if clientIDStr, ok := ary[1].(string); ok {
+ client.ClientID = uuid.FromStringOrNil(clientIDStr)
+ if client.ClientID == uuid.Nil {
+ client.ClientID = uuid.NewV4()
+ }
+ } else if _, ok := ary[1].(bool); ok {
+ // opt out
+ client.ClientID = AnonymousClientID
+ } else {
+ err = ErrExpectedTwoStrings
+ return
+ }
+
+ uniqueUserChannel <- client.ClientID
+
+ SubscribeGlobal(client)
+ SubscribeDefaults(client)
+
+ jsTime := float64(time.Now().UnixNano()/1000) / 1000
+ return ClientMessage{
+ Arguments: []interface{}{
+ client.ClientID.String(),
+ jsTime,
+ },
+ }, nil
+}
+
+func C2SPing(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) {
+ return ClientMessage{
+ Arguments: float64(time.Now().UnixNano()/1000) / 1000,
+ }, nil
+}
+
+func C2SSetUser(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) {
+ username, err := msg.ArgumentsAsString()
+ if err != nil {
+ return
+ }
+
+ username = copyString(username)
+
+ client.Mutex.Lock()
+ client.UsernameValidated = false
+ client.TwitchUsername = username
+ client.Mutex.Unlock()
+
+ if Configuration.SendAuthToNewClients {
+ client.MsgChannelKeepalive.Add(1)
+ go client.StartAuthorization(func(_ *ClientInfo, _ bool) {
+ client.MsgChannelKeepalive.Done()
+ })
+ }
+
+ return ResponseSuccess, nil
+}
+
+func C2SReady(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) {
+ // disconnectAt, err := msg.ArgumentsAsInt()
+ // if err != nil {
+ // return
+ // }
+
+ client.Mutex.Lock()
+ client.ReadyComplete = true
+ client.Mutex.Unlock()
+
+ client.MsgChannelKeepalive.Add(1)
+ go func() {
+ client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand}
+ SendBacklogForNewClient(client)
+ client.MsgChannelKeepalive.Done()
+ }()
+ return ClientMessage{Command: AsyncResponseCommand}, nil
+}
+
+func C2SSubscribe(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) {
+ channel, err := msg.ArgumentsAsString()
+ if err != nil {
+ return
+ }
+
+ channel = PubSubChannelPool.Intern(channel)
+
+ client.Mutex.Lock()
+ AddToSliceS(&client.CurrentChannels, channel)
+ client.Mutex.Unlock()
+
+ SubscribeChannel(client, channel)
+
+ if client.ReadyComplete {
+ client.MsgChannelKeepalive.Add(1)
+ go func() {
+ SendBacklogForChannel(client, channel)
+ client.MsgChannelKeepalive.Done()
+ }()
+ }
+
+ return ResponseSuccess, nil
+}
+
+// C2SUnsubscribe implements the `unsub` C2S Command.
+// It removes the channel from ClientInfo.CurrentChannels and calls UnsubscribeSingleChat.
+func C2SUnsubscribe(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) {
+ channel, err := msg.ArgumentsAsString()
+ if err != nil {
+ return
+ }
+
+ channel = PubSubChannelPool.Intern(channel)
+
+ client.Mutex.Lock()
+ RemoveFromSliceS(&client.CurrentChannels, channel)
+ client.Mutex.Unlock()
+
+ UnsubscribeSingleChat(client, channel)
+
+ return ResponseSuccess, nil
+}
+
+// C2SSurvey implements the survey C2S Command.
+// Surveys are discarded.s
+func C2SSurvey(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) {
+ // Discard
+ return ResponseSuccess, nil
+}
+
+type followEvent struct {
+ User string `json:"u"`
+ Channel string `json:"c"`
+ NowFollowing bool `json:"f"`
+ Timestamp time.Time `json:"t"`
+}
+
+var followEvents []followEvent
+
+// followEventsLock is the lock for followEvents.
+var followEventsLock sync.Mutex
+
+// C2STrackFollow implements the `track_follow` C2S Command.
+// It adds the record to `followEvents`, which is submitted to the backend on a timer.
+func C2STrackFollow(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) {
+ channel, following, err := msg.ArgumentsAsStringAndBool()
+ if err != nil {
+ return
+ }
+ now := time.Now()
+
+ channel = TwitchChannelPool.Intern(channel)
+
+ followEventsLock.Lock()
+ followEvents = append(followEvents, followEvent{User: client.TwitchUsername, Channel: channel, NowFollowing: following, Timestamp: now})
+ followEventsLock.Unlock()
+
+ return ResponseSuccess, nil
+}
+
+// AggregateEmoteUsage is a map from emoteID to a map from chatroom name to usage count.
+var aggregateEmoteUsage = make(map[int]map[string]int)
+
+// AggregateEmoteUsageLock is the lock for AggregateEmoteUsage.
+var aggregateEmoteUsageLock sync.Mutex
+
+// ErrNegativeEmoteUsage is emitted when the submitted emote usage is negative.
+var ErrNegativeEmoteUsage = errors.New("Emote usage count cannot be negative")
+
+// C2SEmoticonUses implements the `emoticon_uses` C2S Command.
+// msg.Arguments are in the JSON format of [1]map[emoteID]map[ChatroomName]float64.
+func C2SEmoticonUses(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) {
+ // if this panics, will be caught by callHandler
+ mapRoot := msg.Arguments.([]interface{})[0].(map[string]interface{})
+
+ // Validate: male suire
+ for strEmote, val1 := range mapRoot {
+ _, err = strconv.Atoi(strEmote)
+ if err != nil {
+ return
+ }
+ mapInner := val1.(map[string]interface{})
+ for _, val2 := range mapInner {
+ var count = int(val2.(float64))
+ if count <= 0 {
+ err = ErrNegativeEmoteUsage
+ return
+ }
+ }
+ }
+
+ aggregateEmoteUsageLock.Lock()
+ defer aggregateEmoteUsageLock.Unlock()
+
+ var total int
+
+ for strEmote, val1 := range mapRoot {
+ var emoteID int
+ emoteID, err = strconv.Atoi(strEmote)
+ if err != nil {
+ return
+ }
+
+ destMapInner, ok := aggregateEmoteUsage[emoteID]
+ if !ok {
+ destMapInner = make(map[string]int)
+ aggregateEmoteUsage[emoteID] = destMapInner
+ }
+
+ mapInner := val1.(map[string]interface{})
+ for roomName, val2 := range mapInner {
+ var count = int(val2.(float64))
+ if count > 200 {
+ count = 200
+ }
+ roomName = TwitchChannelPool.Intern(roomName)
+ destMapInner[roomName] += count
+ total += count
+ }
+ }
+
+ Statistics.EmotesReportedTotal += uint64(total)
+
+ return ResponseSuccess, nil
+}
+
+// is_init_func
+func aggregateDataSender() {
+ for {
+ time.Sleep(5 * time.Minute)
+ aggregateDataSender_do()
+ }
+}
+
+func aggregateDataSender_do() {
+ followEventsLock.Lock()
+ follows := followEvents
+ followEvents = nil
+ followEventsLock.Unlock()
+ aggregateEmoteUsageLock.Lock()
+ emoteUsage := aggregateEmoteUsage
+ aggregateEmoteUsage = make(map[int]map[string]int)
+ aggregateEmoteUsageLock.Unlock()
+
+ reportForm := url.Values{}
+
+ followJSON, err := json.Marshal(follows)
+ if err != nil {
+ log.Println("error reporting aggregate data:", err)
+ } else {
+ reportForm.Set("follows", string(followJSON))
+ }
+
+ strEmoteUsage := make(map[string]map[string]int)
+ for emoteID, usageByChannel := range emoteUsage {
+ strEmoteID := strconv.Itoa(emoteID)
+ strEmoteUsage[strEmoteID] = usageByChannel
+ }
+ emoteJSON, err := json.Marshal(strEmoteUsage)
+ if err != nil {
+ log.Println("error reporting aggregate data:", err)
+ } else {
+ reportForm.Set("emotes", string(emoteJSON))
+ }
+
+ err = Backend.SendAggregatedData(reportForm)
+ if err != nil {
+ log.Println("error reporting aggregate data:", err)
+ return
+ }
+
+ // done
+}
+
+type bunchedRequest struct {
+ Command Command
+ Param string
+}
+
+type cachedBunchedResponse struct {
+ Response string
+ Timestamp time.Time
+}
+type bunchSubscriber struct {
+ Client *ClientInfo
+ MessageID int
+}
+
+type bunchSubscriberList struct {
+ sync.Mutex
+ Members []bunchSubscriber
+}
+
+type cacheStatus byte
+
+const (
+ CacheStatusNotFound = iota
+ CacheStatusFound
+ CacheStatusExpired
+)
+
+var pendingBunchedRequests = make(map[bunchedRequest]*bunchSubscriberList)
+var pendingBunchLock sync.Mutex
+var bunchCache = make(map[bunchedRequest]cachedBunchedResponse)
+var bunchCacheLock sync.RWMutex
+var bunchCacheCleanupSignal = sync.NewCond(&bunchCacheLock)
+var bunchCacheLastCleanup time.Time
+
+func bunchedRequestFromCM(msg *ClientMessage) bunchedRequest {
+ return bunchedRequest{Command: msg.Command, Param: copyString(msg.origArguments)}
+}
+
+// is_init_func
+func bunchCacheJanitor() {
+ go func() {
+ for {
+ time.Sleep(30 * time.Minute)
+ bunchCacheCleanupSignal.Signal()
+ }
+ }()
+
+ bunchCacheLock.Lock()
+ for {
+ // Unlocks CachedBunchLock, waits for signal, re-locks
+ bunchCacheCleanupSignal.Wait()
+
+ if bunchCacheLastCleanup.After(time.Now().Add(-1 * time.Second)) {
+ // skip if it's been less than 1 second
+ continue
+ }
+
+ // CachedBunchLock is held here
+ keepIfAfter := time.Now().Add(-5 * time.Minute)
+ for req, resp := range bunchCache {
+ if !resp.Timestamp.After(keepIfAfter) {
+ delete(bunchCache, req)
+ }
+ }
+ bunchCacheLastCleanup = time.Now()
+ // Loop and Wait(), which re-locks
+ }
+}
+
+var emptyCachedBunchedResponse cachedBunchedResponse
+
+func bunchGetCacheStatus(br bunchedRequest, client *ClientInfo) (cacheStatus, cachedBunchedResponse) {
+ bunchCacheLock.RLock()
+ defer bunchCacheLock.RUnlock()
+ cachedResponse, ok := bunchCache[br]
+ if ok && cachedResponse.Timestamp.After(time.Now().Add(-5*time.Minute)) {
+ return CacheStatusFound, cachedResponse
+ } else if ok {
+ return CacheStatusExpired, emptyCachedBunchedResponse
+ }
+ return CacheStatusNotFound, emptyCachedBunchedResponse
+}
+
+func normalizeBunchedRequest(br bunchedRequest) bunchedRequest {
+ if br.Command == "get_link" {
+ // TODO
+ }
+ return br
+}
+
+// C2SHandleBunchedCommand handles C2S Commands such as `get_link`.
+// It makes a request to the backend server for the data, but any other requests coming in while the first is pending also get the responses from the first one.
+// Additionally, results are cached.
+func C2SHandleBunchedCommand(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) {
+ // FIXME(riking): Function is too complex
+
+ br := bunchedRequestFromCM(&msg)
+ br = normalizeBunchedRequest(br)
+
+ cacheStatus, cachedResponse := bunchGetCacheStatus(br, client)
+
+ if cacheStatus == CacheStatusFound {
+ var response ClientMessage
+ response.Command = SuccessCommand
+ response.MessageID = msg.MessageID
+ response.origArguments = cachedResponse.Response
+ response.parseOrigArguments()
+
+ return response, nil
+ } else if cacheStatus == CacheStatusExpired {
+ // Wake up the lazy janitor
+ bunchCacheCleanupSignal.Signal()
+ }
+
+ pendingBunchLock.Lock()
+ defer pendingBunchLock.Unlock()
+ list, ok := pendingBunchedRequests[br]
+ if ok {
+ list.Lock()
+ AddToSliceB(&list.Members, client, msg.MessageID)
+ list.Unlock()
+
+ return ClientMessage{Command: AsyncResponseCommand}, nil
+ }
+
+ pendingBunchedRequests[br] = &bunchSubscriberList{Members: []bunchSubscriber{{Client: client, MessageID: msg.MessageID}}}
+
+ go func(request bunchedRequest) {
+ respStr, err := Backend.SendRemoteCommandCached(string(request.Command), request.Param, AuthInfo{})
+
+ var msg ClientMessage
+ if err == nil {
+ msg.Command = SuccessCommand
+ msg.origArguments = respStr
+ msg.parseOrigArguments()
+ } else {
+ msg.Command = ErrorCommand
+ msg.Arguments = err.Error()
+ }
+
+ if err == nil {
+ bunchCacheLock.Lock()
+ bunchCache[request] = cachedBunchedResponse{Response: respStr, Timestamp: time.Now()}
+ bunchCacheLock.Unlock()
+ }
+
+ pendingBunchLock.Lock()
+ bsl := pendingBunchedRequests[request]
+ delete(pendingBunchedRequests, request)
+ pendingBunchLock.Unlock()
+
+ bsl.Lock()
+ for _, member := range bsl.Members {
+ msg.MessageID = member.MessageID
+ select {
+ case member.Client.MessageChannel <- msg:
+ case <-member.Client.MsgChannelIsDone:
+ }
+ }
+ bsl.Unlock()
+ }(br)
+
+ return ClientMessage{Command: AsyncResponseCommand}, nil
+}
+
+func C2SHandleRemoteCommand(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) {
+ client.MsgChannelKeepalive.Add(1)
+ go doRemoteCommand(conn, msg, client)
+
+ return ClientMessage{Command: AsyncResponseCommand}, nil
+}
+
+const AuthorizationFailedErrorString = "Failed to verify your Twitch username."
+const AuthorizationNeededError = "You must be signed in to use that command."
+
+func doRemoteCommand(conn *websocket.Conn, msg ClientMessage, client *ClientInfo) {
+ resp, err := Backend.SendRemoteCommandCached(string(msg.Command), copyString(msg.origArguments), client.AuthInfo)
+
+ if err == ErrAuthorizationNeeded {
+ if client.TwitchUsername == "" {
+ // Not logged in
+ client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationNeededError}
+ client.MsgChannelKeepalive.Done()
+ return
+ }
+ client.StartAuthorization(func(_ *ClientInfo, success bool) {
+ if success {
+ doRemoteCommand(conn, msg, client)
+ } else {
+ client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationFailedErrorString}
+ client.MsgChannelKeepalive.Done()
+ }
+ })
+ return // without keepalive.Done()
+ } else if bfe, ok := err.(ErrForwardedFromBackend); ok {
+ client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: bfe.JSONError}
+ } else if err != nil {
+ client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: err.Error()}
+ } else {
+ msg := ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand, origArguments: resp}
+ msg.parseOrigArguments()
+ client.MessageChannel <- msg
+ }
+ client.MsgChannelKeepalive.Done()
+}
diff --git a/socketserver/server/handlecore.go b/socketserver/server/handlecore.go
new file mode 100644
index 00000000..723f4fb2
--- /dev/null
+++ b/socketserver/server/handlecore.go
@@ -0,0 +1,711 @@
+package server // import "bitbucket.org/stendec/frankerfacez/socketserver/server"
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "net/http"
+ "net/url"
+ "os"
+ "os/signal"
+ "regexp"
+ "runtime"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "syscall"
+ "time"
+ "unicode/utf8"
+
+ "github.com/gorilla/websocket"
+)
+
+// SuccessCommand is a Reply Command to indicate success in reply to a C2S Command.
+const SuccessCommand Command = "ok"
+
+// ErrorCommand is a Reply Command to indicate that a C2S Command failed.
+const ErrorCommand Command = "error"
+
+// HelloCommand is a C2S Command.
+// HelloCommand must be the Command of the first ClientMessage sent during a connection.
+// Sending any other command will result in a CloseFirstMessageNotHello.
+const HelloCommand Command = "hello"
+
+// ReadyCommand is a C2S Command.
+// It indicates that the client is finished sending the initial 'sub' commands and the server should send the backlog.
+const ReadyCommand Command = "ready"
+
+const SetUserCommand Command = "setuser"
+
+// AuthorizeCommand is a S2C Command sent as part of Twitch username validation.
+const AuthorizeCommand Command = "do_authorize"
+
+// AsyncResponseCommand is a pseudo-Reply Command.
+// It indicates that the Reply Command to the client's C2S Command will be delivered
+// on a goroutine over the ClientInfo.MessageChannel and should not be delivered immediately.
+const AsyncResponseCommand Command = "_async"
+
+const defaultMinMemoryKB = 1024 * 24
+
+// DotTwitchDotTv is the .twitch.tv suffix.
+const DotTwitchDotTv = ".twitch.tv"
+
+const dotCbenniDotCom = ".cbenni.com"
+
+var OriginRegexp = regexp.MustCompile("(" + DotTwitchDotTv + "|" + dotCbenniDotCom + ")" + "$")
+
+// ResponseSuccess is a Reply ClientMessage with the MessageID not yet filled out.
+var ResponseSuccess = ClientMessage{Command: SuccessCommand}
+
+// Configuration is the active ConfigFile.
+var Configuration *ConfigFile
+
+var janitorsOnce sync.Once
+
+var CommandPool *StringPool
+var PubSubChannelPool *StringPool
+var TwitchChannelPool *StringPool
+
+// SetupServerAndHandle starts all background goroutines and registers HTTP listeners on the given ServeMux.
+// Essentially, this function completely preps the server for a http.ListenAndServe call.
+// (Uses http.DefaultServeMux if `serveMux` is nil.)
+func SetupServerAndHandle(config *ConfigFile, serveMux *http.ServeMux) {
+ Configuration = config
+
+ if config.MinMemoryKBytes == 0 {
+ config.MinMemoryKBytes = defaultMinMemoryKB
+ }
+
+ Backend = setupBackend(config)
+
+ if serveMux == nil {
+ serveMux = http.DefaultServeMux
+ }
+
+ bannerBytes, err := ioutil.ReadFile("index.html")
+ if err != nil {
+ log.Fatalln("Could not open index.html:", err)
+ }
+ BannerHTML = bannerBytes
+
+ serveMux.HandleFunc("/", HTTPHandleRootURL)
+ serveMux.Handle("/.well-known/", http.FileServer(http.Dir("/tmp/letsencrypt/")))
+ serveMux.HandleFunc("/healthcheck", HTTPSayOK)
+ serveMux.HandleFunc("/stats", HTTPShowStatistics)
+ serveMux.HandleFunc("/hll/", HTTPShowHLL)
+ serveMux.HandleFunc("/hll_force_write", HTTPWriteHLL)
+
+ serveMux.HandleFunc("/drop_backlog", HTTPBackendDropBacklog)
+ serveMux.HandleFunc("/uncached_pub", HTTPBackendUncachedPublish)
+ serveMux.HandleFunc("/cached_pub", HTTPBackendCachedPublish)
+ serveMux.HandleFunc("/get_sub_count", HTTPGetSubscriberCount)
+
+ announceForm, err := Backend.SealRequest(url.Values{
+ "startup": []string{"1"},
+ })
+ if err != nil {
+ log.Fatalln("Unable to seal requests:", err)
+ }
+ resp, err := Backend.HTTPClient.PostForm(Backend.announceStartupURL, announceForm)
+ if err != nil {
+ log.Println("could not announce startup to backend:", err)
+ } else {
+ resp.Body.Close()
+ Backend.lastSuccessLock.Lock()
+ Backend.lastSuccess[bPathAnnounceStartup] = time.Now().UTC()
+ Backend.lastSuccessLock.Unlock()
+ }
+
+ janitorsOnce.Do(startJanitors)
+}
+
+func init() {
+ setupInterning()
+}
+
+// startJanitors starts the 'is_init_func' goroutines
+func startJanitors() {
+ loadUniqueUsers()
+
+ go authorizationJanitor()
+ go aggregateDataSender()
+ go bunchCacheJanitor()
+ go cachedMessageJanitor()
+ go commandCounter()
+ go pubsubJanitor()
+
+ go ircConnection()
+ go shutdownHandler()
+}
+
+// is_init_func
+func shutdownHandler() {
+ ch := make(chan os.Signal)
+ signal.Notify(ch, syscall.SIGUSR1)
+ signal.Notify(ch, syscall.SIGTERM)
+ <-ch
+ log.Println("Shutting down...")
+
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ writeHLL()
+ wg.Done()
+ }()
+
+ StopAcceptingConnections = true
+ close(StopAcceptingConnectionsCh)
+
+ time.Sleep(1 * time.Second)
+ wg.Wait()
+ os.Exit(0)
+}
+
+// is_init_func +test
+func dumpStackOnCtrlZ() {
+ ch := make(chan os.Signal)
+ signal.Notify(ch, syscall.SIGTSTP)
+ for range ch {
+ fmt.Println("Got ^Z")
+
+ buf := make([]byte, 10000)
+ byteCnt := runtime.Stack(buf, true)
+ fmt.Println(string(buf[:byteCnt]))
+ }
+}
+
+// HTTPSayOK replies with 200 and a body of "ok\n".
+func HTTPSayOK(w http.ResponseWriter, _ *http.Request) {
+ w.(interface {
+ WriteString(string) error
+ }).WriteString("ok\n")
+}
+
+// SocketUpgrader is the websocket.Upgrader currently in use.
+var SocketUpgrader = websocket.Upgrader{
+ ReadBufferSize: 1024,
+ WriteBufferSize: 1024,
+ CheckOrigin: func(r *http.Request) bool {
+ return r.Header.Get("Origin") == "" || OriginRegexp.MatchString(r.Header.Get("Origin"))
+ },
+}
+
+// BannerHTML is the content served to web browsers viewing the socket server website.
+// Memes go here.
+var BannerHTML []byte
+
+// StopAcceptingConnectionsCh is closed while the server is shutting down.
+var StopAcceptingConnectionsCh = make(chan struct{})
+var StopAcceptingConnections = false
+
+// HTTPHandleRootURL is the http.HandleFunc for requests on `/`.
+// It either uses the SocketUpgrader or writes out the BannerHTML.
+func HTTPHandleRootURL(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/" {
+ http.NotFound(w, r)
+ fmt.Println(404)
+ return
+ }
+
+ // racy, but should be ok?
+ if StopAcceptingConnections {
+ w.WriteHeader(503)
+ fmt.Fprint(w, "server is shutting down")
+ return
+ }
+
+ if strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade") {
+ updateSysMem()
+
+ if Statistics.SysMemFreeKB > 0 && Statistics.SysMemFreeKB < Configuration.MinMemoryKBytes {
+ atomic.AddUint64(&Statistics.LowMemDroppedConnections, 1)
+ w.WriteHeader(503)
+ fmt.Fprint(w, "error: low memory")
+ return
+ }
+
+ if Configuration.MaxClientCount != 0 {
+ curClients := atomic.LoadUint64(&Statistics.CurrentClientCount)
+ if curClients >= Configuration.MaxClientCount {
+ w.WriteHeader(503)
+ fmt.Fprint(w, "error: client limit reached")
+ return
+ }
+ }
+
+ conn, err := SocketUpgrader.Upgrade(w, r, nil)
+ if err != nil {
+ fmt.Fprintf(w, "error: %v", err)
+ return
+ }
+ RunSocketConnection(conn)
+
+ return
+ } else {
+ w.Write(BannerHTML)
+ }
+}
+
+// ErrProtocolGeneric is sent in a ErrorCommand Reply.
+var ErrProtocolGeneric error = errors.New("FFZ Socket protocol error.")
+
+// ErrProtocolNegativeMsgID is sent in a ErrorCommand Reply when a negative MessageID is received.
+var ErrProtocolNegativeMsgID error = errors.New("FFZ Socket protocol error: negative or zero message ID.")
+
+// ErrExpectedSingleString is sent in a ErrorCommand Reply when the Arguments are of the wrong type.
+var ErrExpectedSingleString = errors.New("Error: Expected single string as arguments.")
+
+// ErrExpectedSingleInt is sent in a ErrorCommand Reply when the Arguments are of the wrong type.
+var ErrExpectedSingleInt = errors.New("Error: Expected single integer as arguments.")
+
+// ErrExpectedTwoStrings is sent in a ErrorCommand Reply when the Arguments are of the wrong type.
+var ErrExpectedTwoStrings = errors.New("Error: Expected array of string, string as arguments.")
+
+// ErrExpectedStringAndBool is sent in a ErrorCommand Reply when the Arguments are of the wrong type.
+var ErrExpectedStringAndBool = errors.New("Error: Expected array of string, bool as arguments.")
+
+// ErrExpectedStringAndInt is sent in a ErrorCommand Reply when the Arguments are of the wrong type.
+var ErrExpectedStringAndInt = errors.New("Error: Expected array of string, int as arguments.")
+
+// ErrExpectedStringAndIntGotFloat is sent in a ErrorCommand Reply when the Arguments are of the wrong type.
+var ErrExpectedStringAndIntGotFloat = errors.New("Error: Second argument was a float, expected an integer.")
+
+// CloseGoingAway is sent when the server is restarting.
+var CloseGoingAway = websocket.CloseError{Code: websocket.CloseGoingAway, Text: "server restarting"}
+
+// CloseRebalance is sent when the server has too many clients and needs to shunt some to another server.
+var CloseRebalance = websocket.CloseError{Code: websocket.CloseGoingAway, Text: "kicked for rebalancing, please select a new server"}
+
+// CloseGotBinaryMessage is the termination reason when the client sends a binary websocket frame.
+var CloseGotBinaryMessage = websocket.CloseError{Code: websocket.CloseUnsupportedData, Text: "got binary packet"}
+
+// CloseTimedOut is the termination reason when the client fails to send or respond to ping frames.
+var CloseTimedOut = websocket.CloseError{Code: 3003, Text: "no ping replies for 5 minutes"}
+
+// CloseTooManyBufferedMessages is the termination reason when the sending thread buffers too many messages.
+var CloseTooManyBufferedMessages = websocket.CloseError{Code: websocket.CloseMessageTooBig, Text: "too many pending messages"}
+
+// CloseFirstMessageNotHello is the termination reason
+var CloseFirstMessageNotHello = websocket.CloseError{
+ Text: "Error - the first message sent must be a 'hello'",
+ Code: websocket.ClosePolicyViolation,
+}
+
+var CloseNonUTF8Data = websocket.CloseError{
+ Code: websocket.CloseUnsupportedData,
+ Text: "Non UTF8 data recieved. Network corruption likely.",
+}
+
+const sendMessageBufferLength = 30
+const sendMessageAbortLength = 20
+
+// RunSocketConnection contains the main run loop of a websocket connection.
+//
+// First, it sets up the channels, the ClientInfo object, and the pong frame handler.
+// It starts the reader goroutine pointing at the newly created channels.
+// The function then enters the run loop (a `for{select{}}`).
+// The run loop is broken when an object is received on errorChan, or if `hello` is not the first C2S Command.
+//
+// After the run loop stops, the function launches a goroutine to drain
+// client.MessageChannel, signals the reader goroutine to stop, unsubscribes
+// from all pub/sub channels, waits on MsgChannelKeepalive (remember, the
+// messages are being drained), and finally closes client.MessageChannel
+// (which ends the drainer goroutine).
+func RunSocketConnection(conn *websocket.Conn) {
+ // websocket.Conn is a ReadWriteCloser
+
+ atomic.AddUint64(&Statistics.ClientConnectsTotal, 1)
+ atomic.AddUint64(&Statistics.CurrentClientCount, 1)
+
+ _clientChan := make(chan ClientMessage)
+ _serverMessageChan := make(chan ClientMessage, sendMessageBufferLength)
+ _errorChan := make(chan error)
+ stoppedChan := make(chan struct{})
+
+ var client ClientInfo
+ client.MessageChannel = _serverMessageChan
+ client.RemoteAddr = conn.RemoteAddr()
+ client.MsgChannelIsDone = stoppedChan
+
+ // var report logstasher.ConnectionReport
+ // report.ConnectTime = time.Now()
+ // report.RemoteAddr = client.RemoteAddr
+
+ conn.SetPongHandler(func(pongBody string) error {
+ client.Mutex.Lock()
+ client.pingCount = 0
+ client.Mutex.Unlock()
+ return nil
+ })
+
+ // All set up, now enter the work loop
+ go runSocketReader(conn, _errorChan, _clientChan, stoppedChan)
+ closeReason := runSocketWriter(conn, &client, _errorChan, _clientChan, _serverMessageChan)
+
+ // Exit
+ closeConnection(conn, closeReason)
+ // closeConnection(conn, closeReason, &report)
+
+ // Launch message draining goroutine - we aren't out of the pub/sub records
+ go func() {
+ for _ = range _serverMessageChan {
+ }
+ }()
+
+ // Closes client.MsgChannelIsDone and also stops the reader thread
+ close(stoppedChan)
+
+ // Stop getting messages...
+ UnsubscribeAll(&client)
+
+ // Wait for pending jobs to finish...
+ client.MsgChannelKeepalive.Wait()
+ client.MessageChannel = nil
+
+ // And done.
+ // Close the channel so the draining goroutine can finish, too.
+ close(_serverMessageChan)
+
+ if !StopAcceptingConnections {
+ // Don't perform high contention operations when server is closing
+ atomic.AddUint64(&Statistics.CurrentClientCount, NegativeOne)
+ atomic.AddUint64(&Statistics.ClientDisconnectsTotal, 1)
+
+ // report.UsernameWasValidated = client.UsernameValidated
+ // report.TwitchUsername = client.TwitchUsername
+ // logstasher.Submit(&report)
+ }
+}
+
+func runSocketReader(conn *websocket.Conn, errorChan chan<- error, clientChan chan<- ClientMessage, stoppedChan <-chan struct{}) {
+ var msg ClientMessage
+ var messageType int
+ var packet []byte
+ var err error
+
+ defer close(errorChan)
+ defer close(clientChan)
+
+ for ; err == nil; messageType, packet, err = conn.ReadMessage() {
+ if messageType == websocket.BinaryMessage {
+ err = &CloseGotBinaryMessage
+ break
+ }
+ if messageType == websocket.CloseMessage {
+ err = io.EOF
+ break
+ }
+
+ UnmarshalClientMessage(packet, messageType, &msg)
+ if msg.MessageID == 0 {
+ continue
+ }
+ select {
+ case clientChan <- msg:
+ case <-stoppedChan:
+ return
+ }
+ }
+
+ select {
+ case errorChan <- err:
+ case <-stoppedChan:
+ }
+ // exit goroutine
+}
+
+func runSocketWriter(conn *websocket.Conn, client *ClientInfo, errorChan <-chan error, clientChan <-chan ClientMessage, serverMessageChan <-chan ClientMessage) websocket.CloseError {
+ for {
+ select {
+ case err := <-errorChan:
+ if err == io.EOF {
+ return websocket.CloseError{
+ Code: websocket.CloseGoingAway,
+ Text: err.Error(),
+ }
+ } else if closeMsg, isClose := err.(*websocket.CloseError); isClose {
+ return *closeMsg
+ } else {
+ return websocket.CloseError{
+ Code: websocket.CloseInternalServerErr,
+ Text: err.Error(),
+ }
+ }
+
+ case msg := <-clientChan:
+ if client.VersionString == "" && msg.Command != HelloCommand {
+ return CloseFirstMessageNotHello
+ }
+
+ for _, char := range msg.Command {
+ if char == utf8.RuneError {
+ return CloseNonUTF8Data
+ }
+ }
+
+ DispatchC2SCommand(conn, client, msg)
+
+ case msg := <-serverMessageChan:
+ if len(serverMessageChan) > sendMessageAbortLength {
+ return CloseTooManyBufferedMessages
+ }
+ if cls, ok := msg.Arguments.(*websocket.CloseError); ok {
+ return *cls
+ }
+ SendMessage(conn, msg)
+
+ case <-time.After(1 * time.Minute):
+ client.Mutex.Lock()
+ client.pingCount++
+ tooManyPings := client.pingCount == 5
+ client.Mutex.Unlock()
+ if tooManyPings {
+ return CloseTimedOut
+ } else {
+ conn.WriteControl(websocket.PingMessage, []byte(strconv.FormatInt(time.Now().Unix(), 10)), getDeadline())
+ }
+
+ case <-StopAcceptingConnectionsCh:
+ return CloseGoingAway
+ }
+ }
+}
+
+func getDeadline() time.Time {
+ return time.Now().Add(1 * time.Minute)
+}
+
+func closeConnection(conn *websocket.Conn, closeMsg websocket.CloseError) {
+ closeTxt := closeMsg.Text
+ if strings.Contains(closeTxt, "read: connection reset by peer") {
+ closeTxt = "read: connection reset by peer"
+ } else if strings.Contains(closeTxt, "use of closed network connection") {
+ closeTxt = "read: use of closed network connection"
+ } else if closeMsg.Code == 1001 {
+ closeTxt = "clean shutdown"
+ }
+
+ // report.DisconnectCode = closeMsg.Code
+ // report.DisconnectReason = closeTxt
+ // report.DisconnectTime = time.Now()
+
+ conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(closeMsg.Code, closeMsg.Text), getDeadline())
+ conn.Close()
+}
+
+// SendMessage sends a ClientMessage over the websocket connection with a timeout.
+// If marshalling the ClientMessage fails, this function will panic.
+func SendMessage(conn *websocket.Conn, msg ClientMessage) {
+ messageType, packet, err := MarshalClientMessage(msg)
+ if err != nil {
+ panic(fmt.Sprintf("failed to marshal: %v %v", err, msg))
+ }
+ conn.SetWriteDeadline(getDeadline())
+ conn.WriteMessage(messageType, packet)
+ atomic.AddUint64(&Statistics.MessagesSent, 1)
+}
+
+// UnmarshalClientMessage unpacks websocket TextMessage into a ClientMessage provided in the `v` parameter.
+func UnmarshalClientMessage(data []byte, payloadType int, v interface{}) (err error) {
+ var spaceIdx int
+
+ out := v.(*ClientMessage)
+ dataStr := string(data)
+
+ // Message ID
+ spaceIdx = strings.IndexRune(dataStr, ' ')
+ if spaceIdx == -1 {
+ return ErrProtocolGeneric
+ }
+ messageID, err := strconv.Atoi(dataStr[:spaceIdx])
+ if messageID < -1 || messageID == 0 {
+ return ErrProtocolNegativeMsgID
+ }
+
+ out.MessageID = messageID
+ dataStr = dataStr[spaceIdx+1:]
+
+ spaceIdx = strings.IndexRune(dataStr, ' ')
+ if spaceIdx == -1 {
+ out.Command = CommandPool.InternCommand(dataStr)
+ out.Arguments = nil
+ return nil
+ } else {
+ out.Command = CommandPool.InternCommand(dataStr[:spaceIdx])
+ }
+ dataStr = dataStr[spaceIdx+1:]
+ argumentsJSON := string([]byte(dataStr))
+ out.origArguments = argumentsJSON
+ err = out.parseOrigArguments()
+ if err != nil {
+ return
+ }
+ return nil
+}
+
+func (cm *ClientMessage) parseOrigArguments() error {
+ err := json.Unmarshal([]byte(cm.origArguments), &cm.Arguments)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func MarshalClientMessage(clientMessage interface{}) (payloadType int, data []byte, err error) {
+ var msg ClientMessage
+ var ok bool
+ msg, ok = clientMessage.(ClientMessage)
+ if !ok {
+ pMsg, ok := clientMessage.(*ClientMessage)
+ if !ok {
+ panic("MarshalClientMessage: argument needs to be a ClientMessage")
+ }
+ msg = *pMsg
+ }
+ var dataStr string
+
+ if msg.Command == "" && msg.MessageID == 0 {
+ panic("MarshalClientMessage: attempt to send an empty ClientMessage")
+ }
+
+ if msg.Command == "" {
+ msg.Command = SuccessCommand
+ }
+ if msg.MessageID == 0 {
+ msg.MessageID = -1
+ }
+
+ if msg.Arguments != nil {
+ argBytes, err := json.Marshal(msg.Arguments)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ dataStr = fmt.Sprintf("%d %s %s", msg.MessageID, msg.Command, string(argBytes))
+ } else {
+ dataStr = fmt.Sprintf("%d %s", msg.MessageID, msg.Command)
+ }
+
+ return websocket.TextMessage, []byte(dataStr), nil
+}
+
+// ArgumentsAsString parses the arguments of the ClientMessage as a single string.
+func (cm *ClientMessage) ArgumentsAsString() (string1 string, err error) {
+ var ok bool
+ string1, ok = cm.Arguments.(string)
+ if !ok {
+ err = ErrExpectedSingleString
+ return
+ } else {
+ return string1, nil
+ }
+}
+
+// ArgumentsAsInt parses the arguments of the ClientMessage as a single int.
+func (cm *ClientMessage) ArgumentsAsInt() (int1 int64, err error) {
+ var ok bool
+ var num float64
+ num, ok = cm.Arguments.(float64)
+ if !ok {
+ err = ErrExpectedSingleInt
+ return
+ } else {
+ int1 = int64(num)
+ return int1, nil
+ }
+}
+
+// ArgumentsAsTwoStrings parses the arguments of the ClientMessage as an array of two strings.
+func (cm *ClientMessage) ArgumentsAsTwoStrings() (string1, string2 string, err error) {
+ var ok bool
+ var ary []interface{}
+ ary, ok = cm.Arguments.([]interface{})
+ if !ok {
+ err = ErrExpectedTwoStrings
+ return
+ } else {
+ if len(ary) != 2 {
+ err = ErrExpectedTwoStrings
+ return
+ }
+ string1, ok = ary[0].(string)
+ if !ok {
+ err = ErrExpectedTwoStrings
+ return
+ }
+ // clientID can be null
+ if ary[1] == nil {
+ return string1, "", nil
+ }
+ string2, ok = ary[1].(string)
+ if !ok {
+ err = ErrExpectedTwoStrings
+ return
+ }
+ return string1, string2, nil
+ }
+}
+
+// ArgumentsAsStringAndInt parses the arguments of the ClientMessage as an array of a string and an int.
+func (cm *ClientMessage) ArgumentsAsStringAndInt() (string1 string, int int64, err error) {
+ var ok bool
+ var ary []interface{}
+ ary, ok = cm.Arguments.([]interface{})
+ if !ok {
+ err = ErrExpectedStringAndInt
+ return
+ } else {
+ if len(ary) != 2 {
+ err = ErrExpectedStringAndInt
+ return
+ }
+ string1, ok = ary[0].(string)
+ if !ok {
+ err = ErrExpectedStringAndInt
+ return
+ }
+ var num float64
+ num, ok = ary[1].(float64)
+ if !ok {
+ err = ErrExpectedStringAndInt
+ return
+ }
+ int = int64(num)
+ if float64(int) != num {
+ err = ErrExpectedStringAndIntGotFloat
+ return
+ }
+ return string1, int, nil
+ }
+}
+
+// ArgumentsAsStringAndBool parses the arguments of the ClientMessage as an array of a string and an int.
+func (cm *ClientMessage) ArgumentsAsStringAndBool() (str string, flag bool, err error) {
+ var ok bool
+ var ary []interface{}
+ ary, ok = cm.Arguments.([]interface{})
+ if !ok {
+ err = ErrExpectedStringAndBool
+ return
+ } else {
+ if len(ary) != 2 {
+ err = ErrExpectedStringAndBool
+ return
+ }
+ str, ok = ary[0].(string)
+ if !ok {
+ err = ErrExpectedStringAndBool
+ return
+ }
+ flag, ok = ary[1].(bool)
+ if !ok {
+ err = ErrExpectedStringAndBool
+ return
+ }
+ return str, flag, nil
+ }
+}
diff --git a/socketserver/internal/server/handlecore_test.go b/socketserver/server/handlecore_test.go
similarity index 79%
rename from socketserver/internal/server/handlecore_test.go
rename to socketserver/server/handlecore_test.go
index 161b5921..8ecf848b 100644
--- a/socketserver/internal/server/handlecore_test.go
+++ b/socketserver/server/handlecore_test.go
@@ -2,14 +2,15 @@ package server
import (
"fmt"
- "golang.org/x/net/websocket"
"testing"
+
+ "github.com/gorilla/websocket"
)
func ExampleUnmarshalClientMessage() {
sourceData := []byte("100 hello [\"ffz_3.5.30\",\"898b5bfa-b577-47bb-afb4-252c703b67d6\"]")
var cm ClientMessage
- err := UnmarshalClientMessage(sourceData, websocket.TextFrame, &cm)
+ err := UnmarshalClientMessage(sourceData, websocket.TextMessage, &cm)
fmt.Println(err)
fmt.Println(cm.MessageID)
fmt.Println(cm.Command)
@@ -27,9 +28,9 @@ func ExampleMarshalClientMessage() {
Command: "do_authorize",
Arguments: "1234567890",
}
- data, payloadType, err := MarshalClientMessage(&cm)
+ payloadType, data, err := MarshalClientMessage(&cm)
fmt.Println(err)
- fmt.Println(payloadType == websocket.TextFrame)
+ fmt.Println(payloadType == websocket.TextMessage)
fmt.Println(string(data))
// Output:
//
@@ -40,7 +41,7 @@ func ExampleMarshalClientMessage() {
func TestArgumentsAsStringAndBool(t *testing.T) {
sourceData := []byte("1 foo [\"string\", false]")
var cm ClientMessage
- err := UnmarshalClientMessage(sourceData, websocket.TextFrame, &cm)
+ err := UnmarshalClientMessage(sourceData, websocket.TextMessage, &cm)
if err != nil {
t.Fatal(err)
}
diff --git a/socketserver/server/intern.go b/socketserver/server/intern.go
new file mode 100644
index 00000000..2f9cf416
--- /dev/null
+++ b/socketserver/server/intern.go
@@ -0,0 +1,42 @@
+package server
+
+import (
+ "sync"
+)
+
+type StringPool struct {
+ sync.RWMutex
+ lookup map[string]string
+}
+
+func NewStringPool() *StringPool {
+ return &StringPool{lookup: make(map[string]string)}
+}
+
+// doesn't lock, doesn't check for dupes.
+func (p *StringPool) _Intern_Setup(s string) {
+ p.lookup[s] = s
+}
+
+func (p *StringPool) InternCommand(s string) Command {
+ return Command(p.Intern(s))
+}
+
+func (p *StringPool) Intern(s string) string {
+ p.RLock()
+ ss, exists := p.lookup[s]
+ p.RUnlock()
+ if exists {
+ return ss
+ }
+
+ p.Lock()
+ defer p.Unlock()
+ ss, exists = p.lookup[s]
+ if exists {
+ return ss
+ }
+ ss = copyString(s)
+ p.lookup[ss] = ss
+ return ss
+}
diff --git a/socketserver/server/irc.go b/socketserver/server/irc.go
new file mode 100644
index 00000000..a5a837dc
--- /dev/null
+++ b/socketserver/server/irc.go
@@ -0,0 +1,181 @@
+package server
+
+import (
+ "bytes"
+ "crypto/rand"
+ "encoding/base64"
+ "log"
+ "strings"
+ "sync"
+ "time"
+
+ irc "github.com/fluffle/goirc/client"
+)
+
+type AuthCallback func(client *ClientInfo, successful bool)
+
+type PendingAuthorization struct {
+ Client *ClientInfo
+ Challenge string
+ Callback AuthCallback
+ EnteredAt time.Time
+}
+
+var PendingAuths []PendingAuthorization
+var PendingAuthLock sync.Mutex
+
+func AddPendingAuthorization(client *ClientInfo, challenge string, callback AuthCallback) {
+ PendingAuthLock.Lock()
+ defer PendingAuthLock.Unlock()
+
+ PendingAuths = append(PendingAuths, PendingAuthorization{
+ Client: client,
+ Challenge: challenge,
+ Callback: callback,
+ EnteredAt: time.Now(),
+ })
+}
+
+// is_init_func
+func authorizationJanitor() {
+ for {
+ time.Sleep(5 * time.Minute)
+ authorizationJanitor_do()
+ }
+}
+
+func authorizationJanitor_do() {
+ cullTime := time.Now().Add(-30 * time.Minute)
+
+ PendingAuthLock.Lock()
+ defer PendingAuthLock.Unlock()
+
+ newPendingAuths := make([]PendingAuthorization, 0, len(PendingAuths))
+
+ for _, v := range PendingAuths {
+ if !cullTime.After(v.EnteredAt) {
+ newPendingAuths = append(newPendingAuths, v)
+ } else {
+ go v.Callback(v.Client, false)
+ }
+ }
+
+ PendingAuths = newPendingAuths
+}
+
+func (client *ClientInfo) StartAuthorization(callback AuthCallback) {
+ if callback == nil {
+ return // callback must not be nil
+ }
+ var nonce [32]byte
+ _, err := rand.Read(nonce[:])
+ if err != nil {
+ go callback(client, false)
+ return
+ }
+ buf := bytes.NewBuffer(nil)
+ enc := base64.NewEncoder(base64.RawURLEncoding, buf)
+ enc.Write(nonce[:])
+ enc.Close()
+ challenge := buf.String()
+
+ AddPendingAuthorization(client, challenge, callback)
+
+ client.MessageChannel <- ClientMessage{MessageID: -1, Command: AuthorizeCommand, Arguments: challenge}
+}
+
+const AuthChannelName = "frankerfacezauthorizer"
+const AuthChannel = "#" + AuthChannelName
+const AuthCommand = "AUTH"
+
+var authIrcConnection *irc.Conn
+
+// is_init_func
+func ircConnection() {
+ c := irc.SimpleClient("justinfan123")
+ c.Config().Server = "irc.chat.twitch.tv"
+ authIrcConnection = c
+
+ var reconnect func(conn *irc.Conn)
+ connect := func(conn *irc.Conn) {
+ err := c.Connect()
+ if err != nil {
+ log.Println("irc: failed to connect to IRC:", err)
+ go reconnect(conn)
+ }
+ }
+
+ reconnect = func(conn *irc.Conn) {
+ time.Sleep(5 * time.Second)
+ log.Println("irc: Reconnecting…")
+ connect(conn)
+ }
+
+ c.HandleFunc(irc.CONNECTED, func(conn *irc.Conn, line *irc.Line) {
+ conn.Join(AuthChannel)
+ })
+
+ c.HandleFunc(irc.DISCONNECTED, func(conn *irc.Conn, line *irc.Line) {
+ log.Println("irc: Disconnected. Reconnecting in 5 seconds.")
+ go reconnect(conn)
+ })
+
+ c.HandleFunc(irc.PRIVMSG, func(conn *irc.Conn, line *irc.Line) {
+ channel := line.Args[0]
+ msg := line.Args[1]
+ if channel != AuthChannel || !strings.HasPrefix(msg, AuthCommand) || !line.Public() {
+ return
+ }
+
+ msgArray := strings.Split(msg, " ")
+ if len(msgArray) != 2 {
+ return
+ }
+
+ submittedUser := line.Nick
+ submittedChallenge := msgArray[1]
+
+ submitAuth(submittedUser, submittedChallenge)
+ })
+
+ connect(c)
+}
+
+func submitAuth(user, challenge string) {
+ var auth PendingAuthorization
+ var idx int = -1
+
+ PendingAuthLock.Lock()
+ for i, v := range PendingAuths {
+ if v.Client.TwitchUsername == user && v.Challenge == challenge {
+ auth = v
+ idx = i
+ break
+ }
+ }
+ if idx != -1 {
+ PendingAuths = append(PendingAuths[:idx], PendingAuths[idx+1:]...)
+ }
+ PendingAuthLock.Unlock()
+
+ if idx == -1 {
+ return // perhaps it was for another socket server
+ }
+
+ // auth is valid, and removed from pending list
+
+ var usernameChanged bool
+ auth.Client.Mutex.Lock()
+ if auth.Client.TwitchUsername == user { // recheck condition
+ auth.Client.UsernameValidated = true
+ } else {
+ usernameChanged = true
+ }
+ auth.Client.Mutex.Unlock()
+
+ if !usernameChanged {
+ auth.Callback(auth.Client, true)
+ } else {
+ auth.Callback(auth.Client, false)
+ }
+}
diff --git a/socketserver/server/logstasher/elasticsearch.go b/socketserver/server/logstasher/elasticsearch.go
new file mode 100644
index 00000000..c505e00e
--- /dev/null
+++ b/socketserver/server/logstasher/elasticsearch.go
@@ -0,0 +1,209 @@
+package logstasher
+
+import (
+ "bytes"
+ "crypto/rand"
+ "encoding/base64"
+ "encoding/binary"
+ "encoding/json"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net"
+ "net/http"
+ "net/url"
+ "sync"
+ "time"
+)
+
+// ID is a 128-bit ID for an elasticsearch document.
+// Textually, it is base64-encoded.
+// The Next() method increments the ID.
+type ID struct {
+ High uint64
+ Low uint64
+}
+
+// Text converts the ID into a base64 string.
+func (id ID) String() string {
+ var buf bytes.Buffer
+ buf.Grow(21)
+ enc := base64.NewEncoder(base64.StdEncoding, &buf)
+ var bytes [16]byte
+ binary.LittleEndian.PutUint64(bytes[0:8], id.High)
+ binary.LittleEndian.PutUint64(bytes[8:16], id.Low)
+ enc.Write(bytes[:])
+ enc.Close()
+ return buf.String()
+}
+
+// Next increments the ID and returns the prior state.
+// Overflow is not checked because it's a uint64, do you really expect me to overflow that
+func (id *ID) Next() ID {
+ ret := ID{
+ High: id.High,
+ Low: id.Low,
+ }
+ id.Low++
+ return ret
+}
+
+var idPool = sync.Pool{New: func() interface{} {
+ var bytes [16]byte
+ n, err := rand.Reader.Read(bytes[:])
+ if n != 16 || err != nil {
+ panic(fmt.Errorf("Short read from crypto/rand: %v", err))
+ }
+
+ return &ID{
+ High: binary.LittleEndian.Uint64(bytes[0:8]),
+ Low: binary.LittleEndian.Uint64(bytes[8:16]),
+ }
+}}
+
+func ExampleID_Next() {
+ id := idPool.Get().(*ID).Next()
+ fmt.Println(id)
+ idPool.Put(id)
+}
+
+// Report is the interface presented to the Submit() function.
+// FillReport() is satisfied by ReportBasic, but ReportType must always be specified.
+type Report interface {
+ FillReport() error
+ ReportType() string
+
+ GetID() string
+ GetTimestamp() time.Time
+}
+
+// ReportBasic is the essential fields of any report.
+type ReportBasic struct {
+ ID string
+ Timestamp time.Time
+ Host string
+}
+
+// FillReport sets the Host and Timestamp fields.
+func (report *ReportBasic) FillReport() error {
+ report.Host = hostMarker
+ report.Timestamp = time.Now()
+ id := idPool.Get().(*ID).Next()
+ report.ID = id.String()
+ idPool.Put(id)
+ return nil
+}
+
+func (report *ReportBasic) GetID() string {
+ return report.ID
+}
+
+func (report *ReportBasic) GetTimestamp() time.Time {
+ return report.Timestamp
+}
+
+type ConnectionReport struct {
+ ReportBasic
+
+ ConnectTime time.Time
+ DisconnectTime time.Time
+ // calculated
+ ConnectionDuration time.Duration
+
+ DisconnectCode int
+ DisconnectReason string
+
+ UsernameWasValidated bool
+
+ RemoteAddr net.Addr `json:"-"` // not transmitted until I can figure out data minimization
+ TwitchUsername string `json:"-"` // also not transmitted
+}
+
+// FillReport sets all the calculated fields, and calls esReportBasic.FillReport().
+func (report *ConnectionReport) FillReport() error {
+ report.ReportBasic.FillReport()
+ report.ConnectionDuration = report.DisconnectTime.Sub(report.ConnectTime)
+ return nil
+}
+
+func (report *ConnectionReport) ReportType() string {
+ return "conn"
+}
+
+var serverPresent bool
+var esClient http.Client
+var submitChan chan Report
+var serverBase, indexPrefix, hostMarker string
+
+func checkServerPresent() {
+ if serverBase == "" {
+ serverBase = "http://localhost:9200"
+ }
+ if indexPrefix == "" {
+ indexPrefix = "sockreport"
+ }
+
+ urlHealth := fmt.Sprintf("%s/_cluster/health", serverBase)
+ resp, err := esClient.Get(urlHealth)
+ if err == nil {
+ resp.Body.Close()
+ serverPresent = true
+ submitChan = make(chan Report, 8)
+ fmt.Println("elasticsearch reports enabled")
+ go submissionWorker()
+ } else {
+ serverPresent = false
+ }
+}
+
+// Setup sets up the global variables for the package.
+func Setup(ESServer, ESIndexPrefix, ESHostname string) {
+ serverBase = ESServer
+ indexPrefix = ESIndexPrefix
+ hostMarker = ESHostname
+ checkServerPresent()
+}
+
+// Submit inserts a report into elasticsearch (this is basically a manual logstash).
+func Submit(report Report) {
+ if !serverPresent {
+ return
+ }
+
+ report.FillReport()
+ submitChan <- report
+}
+
+func submissionWorker() {
+ for report := range submitChan {
+ time := report.GetTimestamp()
+ rType := report.ReportType()
+
+ // prefix-type-date
+ indexName := fmt.Sprintf("%s-%s-%d-%d-%d", indexPrefix, rType, time.Year(), time.Month(), time.Day())
+ // base/index/type/id
+ putUrl, err := url.Parse(fmt.Sprintf("%s/%s/%s/%s", serverBase, indexName, rType, report.GetID()))
+ if err != nil {
+ panic(fmt.Errorf("logstash: cannot parse url: %v", err))
+ }
+ body, err := json.Marshal(report)
+ if err != nil {
+ panic(fmt.Errorf("logstash: cannot marshal json: %v", err))
+ }
+
+ req := &http.Request{
+ Method: "PUT",
+ URL: putUrl,
+ Body: ioutil.NopCloser(bytes.NewReader(body)),
+ }
+
+ resp, err := esClient.Do(req)
+
+ if err != nil {
+ // ignore, the show must go on
+ } else {
+ io.Copy(ioutil.Discard, resp.Body)
+ resp.Body.Close()
+ }
+ }
+}
diff --git a/socketserver/server/publisher.go b/socketserver/server/publisher.go
new file mode 100644
index 00000000..85912dd6
--- /dev/null
+++ b/socketserver/server/publisher.go
@@ -0,0 +1,239 @@
+package server
+
+import (
+ "fmt"
+ "net/http"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+)
+
+type LastSavedMessage struct {
+ Expires time.Time
+ Data string
+}
+
+// map is command -> channel -> data
+
+// CachedLastMessages is of CacheTypeLastOnly.
+// Not actually cleaned up by reaper goroutine every ~hour.
+var CachedLastMessages = make(map[Command]map[string]LastSavedMessage)
+var CachedLSMLock sync.RWMutex
+
+func cachedMessageJanitor() {
+ for {
+ time.Sleep(1*time.Hour)
+ cachedMessageJanitor_do()
+ }
+}
+
+func cachedMessageJanitor_do() {
+ CachedLSMLock.Lock()
+ defer CachedLSMLock.Unlock()
+
+ now := time.Now()
+
+ for cmd, chanMap := range CachedLastMessages {
+ for channel, msg := range chanMap {
+ if !msg.Expires.IsZero() && msg.Expires.Before(now) {
+ delete(chanMap, channel)
+ }
+ }
+ if len(chanMap) == 0 {
+ delete(CachedLastMessages, cmd)
+ }
+ }
+}
+
+// DumpBacklogData drops all /cached_pub data.
+func DumpBacklogData() {
+ CachedLSMLock.Lock()
+ CachedLastMessages = make(map[Command]map[string]LastSavedMessage)
+ CachedLSMLock.Unlock()
+}
+
+// SendBacklogForNewClient sends any backlog data relevant to a new client.
+// This should be done when the client sends a `ready` message.
+// This will only send data for CacheTypePersistent and CacheTypeLastOnly because those do not involve timestamps.
+func SendBacklogForNewClient(client *ClientInfo) {
+ client.Mutex.Lock() // reading CurrentChannels
+ curChannels := make([]string, len(client.CurrentChannels))
+ copy(curChannels, client.CurrentChannels)
+ client.Mutex.Unlock()
+
+ CachedLSMLock.RLock()
+ for cmd, chanMap := range CachedLastMessages {
+ if chanMap == nil {
+ continue
+ }
+ for _, channel := range curChannels {
+ msg, ok := chanMap[channel]
+ if ok {
+ msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data}
+ msg.parseOrigArguments()
+ client.MessageChannel <- msg
+ }
+ }
+ }
+ CachedLSMLock.RUnlock()
+}
+
+func SendBacklogForChannel(client *ClientInfo, channel string) {
+ CachedLSMLock.RLock()
+ for cmd, chanMap := range CachedLastMessages {
+ if chanMap == nil {
+ continue
+ }
+ if msg, ok := chanMap[channel]; ok {
+ msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data}
+ msg.parseOrigArguments()
+ client.MessageChannel <- msg
+ }
+ }
+ CachedLSMLock.RUnlock()
+}
+
+type timestampArray interface {
+ Len() int
+ GetTime(int) time.Time
+}
+
+// the CachedLSMLock must be held when calling this
+func saveLastMessage(cmd Command, channel string, expires time.Time, data string, deleting bool) {
+ chanMap, ok := CachedLastMessages[cmd]
+ if !ok {
+ if deleting {
+ return
+ }
+ chanMap = make(map[string]LastSavedMessage)
+ CachedLastMessages[cmd] = chanMap
+ }
+
+ if deleting {
+ delete(chanMap, channel)
+ } else {
+ chanMap[channel] = LastSavedMessage{Expires: expires, Data: data}
+ }
+}
+
+func HTTPBackendDropBacklog(w http.ResponseWriter, r *http.Request) {
+ r.ParseForm()
+ formData, err := Backend.UnsealRequest(r.Form)
+ if err != nil {
+ w.WriteHeader(403)
+ fmt.Fprintf(w, "Error: %v", err)
+ return
+ }
+
+ confirm := formData.Get("confirm")
+ if confirm == "1" {
+ DumpBacklogData()
+ }
+}
+
+// HTTPBackendCachedPublish handles the /cached_pub route.
+// It publishes a message to clients, and then updates the in-server cache for the message.
+//
+// The 'channel' parameter is a comma-separated list of topics to publish the message to.
+// The 'args' parameter is the JSON-encoded command data.
+// If the 'delete' parameter is present, an entry is removed from the cache instead of publishing a message.
+// If the 'expires' parameter is not specified, the message will not expire (though it is only kept in-memory).
+func HTTPBackendCachedPublish(w http.ResponseWriter, r *http.Request) {
+ r.ParseForm()
+ formData, err := Backend.UnsealRequest(r.Form)
+ if err != nil {
+ w.WriteHeader(403)
+ fmt.Fprintf(w, "Error: %v", err)
+ return
+ }
+
+ cmd := CommandPool.InternCommand(formData.Get("cmd"))
+ json := formData.Get("args")
+ channel := formData.Get("channel")
+ deleteMode := formData.Get("delete") != ""
+ timeStr := formData.Get("expires")
+ var expires time.Time
+ if timeStr != "" {
+ timeNum, err := strconv.ParseInt(timeStr, 10, 64)
+ if err != nil {
+ w.WriteHeader(422)
+ fmt.Fprintf(w, "error parsing time: %v", err)
+ return
+ }
+ expires = time.Unix(timeNum, 0)
+ }
+
+ var count int
+ msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: json}
+ msg.parseOrigArguments()
+
+ channels := strings.Split(channel, ",")
+ CachedLSMLock.Lock()
+ for _, channel := range channels {
+ saveLastMessage(cmd, channel, expires, json, deleteMode)
+ }
+ CachedLSMLock.Unlock()
+ count = PublishToMultiple(channels, msg)
+
+ w.Write([]byte(strconv.Itoa(count)))
+}
+
+// HTTPBackendUncachedPublish handles the /uncached_pub route.
+// The backend can POST here to publish a message to clients with no caching.
+// The POST arguments are `cmd`, `args`, `channel`, and `scope`.
+// If "scope" is "global", then "channel" is not used.
+func HTTPBackendUncachedPublish(w http.ResponseWriter, r *http.Request) {
+ r.ParseForm()
+ formData, err := Backend.UnsealRequest(r.Form)
+ if err != nil {
+ w.WriteHeader(403)
+ fmt.Fprintf(w, "Error: %v", err)
+ return
+ }
+
+ cmd := formData.Get("cmd")
+ json := formData.Get("args")
+ channel := formData.Get("channel")
+ scope := formData.Get("scope")
+
+ if cmd == "" {
+ w.WriteHeader(422)
+ fmt.Fprintf(w, "Error: cmd cannot be blank")
+ return
+ }
+ if channel == "" && scope != "global" {
+ w.WriteHeader(422)
+ fmt.Fprintf(w, "Error: channel must be specified")
+ return
+ }
+
+ cm := ClientMessage{MessageID: -1, Command: CommandPool.InternCommand(cmd), origArguments: json}
+ cm.parseOrigArguments()
+ var count int
+
+ switch scope {
+ default:
+ count = PublishToMultiple(strings.Split(channel, ","), cm)
+ case "global":
+ count = PublishToAll(cm)
+ }
+ fmt.Fprint(w, count)
+}
+
+// HTTPGetSubscriberCount handles the /get_sub_count route.
+// It replies with the number of clients subscribed to a pub/sub topic.
+// A "global" option is not available, use fetch(/stats).CurrentClientCount instead.
+func HTTPGetSubscriberCount(w http.ResponseWriter, r *http.Request) {
+ r.ParseForm()
+ formData, err := Backend.UnsealRequest(r.Form)
+ if err != nil {
+ w.WriteHeader(403)
+ fmt.Fprintf(w, "Error: %v", err)
+ return
+ }
+
+ channel := formData.Get("channel")
+
+ fmt.Fprint(w, CountSubscriptions(strings.Split(channel, ",")))
+}
\ No newline at end of file
diff --git a/socketserver/server/publisher_test.go b/socketserver/server/publisher_test.go
new file mode 100644
index 00000000..ad667d1b
--- /dev/null
+++ b/socketserver/server/publisher_test.go
@@ -0,0 +1,66 @@
+package server
+
+import (
+ "testing"
+ "time"
+)
+
+func TestExpiredCleanup(t *testing.T) {
+ const cmd = "test_command"
+ const channel = "trihex"
+ const channel2 = "twitch"
+ const channel3 = "360chrism"
+ const channel4 = "qa_partner"
+
+ DumpBacklogData()
+ defer DumpBacklogData()
+
+ var zeroTime time.Time
+ hourAgo := time.Now().Add(-1*time.Hour)
+ now := time.Now()
+ hourFromNow := time.Now().Add(1*time.Hour)
+
+ saveLastMessage(cmd, channel, hourAgo, "1", false)
+ saveLastMessage(cmd, channel2, now, "2", false)
+
+ if len(CachedLastMessages) != 1 {
+ t.Error("messages not saved")
+ }
+ if len(CachedLastMessages[cmd]) != 2{
+ t.Error("messages not saved")
+ }
+
+ time.Sleep(2*time.Millisecond)
+
+ cachedMessageJanitor_do()
+
+ if len(CachedLastMessages) != 0 {
+ t.Error("messages still present")
+ }
+
+ saveLastMessage(cmd, channel, hourAgo, "1", false)
+ saveLastMessage(cmd, channel2, now, "2", false)
+ saveLastMessage(cmd, channel3, hourFromNow, "3", false)
+ saveLastMessage(cmd, channel4, zeroTime, "4", false)
+
+ if len(CachedLastMessages[cmd]) != 4 {
+ t.Error("messages not saved")
+ }
+
+ time.Sleep(2*time.Millisecond)
+
+ cachedMessageJanitor_do()
+
+ if len(CachedLastMessages) != 1 {
+ t.Error("messages not saved")
+ }
+ if len(CachedLastMessages[cmd]) != 2 {
+ t.Error("messages not saved")
+ }
+ if CachedLastMessages[cmd][channel3].Data != "3" {
+ t.Error("saved wrong message")
+ }
+ if CachedLastMessages[cmd][channel4].Data != "4" {
+ t.Error("saved wrong message")
+ }
+}
diff --git a/socketserver/server/stats.go b/socketserver/server/stats.go
new file mode 100644
index 00000000..a9401248
--- /dev/null
+++ b/socketserver/server/stats.go
@@ -0,0 +1,216 @@
+package server
+
+import (
+ "bytes"
+ "encoding/json"
+ "net/http"
+ "runtime"
+ "sync"
+ "time"
+
+ linuxproc "github.com/c9s/goprocinfo/linux"
+)
+
+type StatsData struct {
+ StatsDataVersion int
+
+ StartTime time.Time
+ Uptime string
+ BuildTime string
+ BuildHash string
+
+ CachedStatsLastUpdate time.Time
+
+ Health struct {
+ IRC bool
+ Backend map[string]time.Time
+ }
+
+ CurrentClientCount uint64
+
+ PubSubChannelCount int
+
+ SysMemTotalKB uint64
+ SysMemFreeKB uint64
+ MemoryInUseKB uint64
+ MemoryRSSKB uint64
+
+ LowMemDroppedConnections uint64
+
+ MemPerClientBytes uint64
+
+ CpuUsagePct float64
+
+ ClientConnectsTotal uint64
+ ClientDisconnectsTotal uint64
+
+ ClientVersions map[string]uint64
+
+ DisconnectCodes map[string]uint64
+
+ CommandsIssuedTotal uint64
+ CommandsIssuedMap map[Command]uint64
+
+ MessagesSent uint64
+
+ EmotesReportedTotal uint64
+
+ BackendVerifyFails uint64
+
+ // DisconnectReasons is at the bottom because it has indeterminate size
+ DisconnectReasons map[string]uint64
+}
+
+// Statistics is several variables that get incremented during normal operation of the server.
+// Its structure should be versioned as it is exposed via JSON.
+//
+// Note as to threaded access - this is soft/fun data and not critical to data integrity.
+// Fix anything that -race turns up, but otherwise it's not too much of a problem.
+var Statistics = newStatsData()
+
+// CommandCounter is a channel for race-free counting of command usage.
+var CommandCounter = make(chan Command, 10)
+
+// commandCounter receives from the CommandCounter channel and uses the value to increment the values in Statistics.
+// is_init_func
+func commandCounter() {
+ for cmd := range CommandCounter {
+ Statistics.CommandsIssuedTotal++
+ Statistics.CommandsIssuedMap[cmd]++
+ }
+}
+
+// StatsDataVersion is the version of the StatsData struct.
+const StatsDataVersion = 6
+const pageSize = 4096
+
+var cpuUsage struct {
+ UserTime uint64
+ SysTime uint64
+}
+
+func newStatsData() *StatsData {
+ return &StatsData{
+ StartTime: time.Now(),
+ CommandsIssuedMap: make(map[Command]uint64),
+ DisconnectCodes: make(map[string]uint64),
+ DisconnectReasons: make(map[string]uint64),
+ ClientVersions: make(map[string]uint64),
+ StatsDataVersion: StatsDataVersion,
+ Health: struct {
+ IRC bool
+ Backend map[string]time.Time
+ }{
+ Backend: make(map[string]time.Time),
+ },
+ }
+}
+
+// SetBuildStamp should be called from the main package to identify the git build hash and build time.
+func SetBuildStamp(buildTime, buildHash string) {
+ Statistics.BuildTime = buildTime
+ Statistics.BuildHash = buildHash
+}
+
+func updateStatsIfNeeded() {
+ if time.Now().Add(-2 * time.Second).After(Statistics.CachedStatsLastUpdate) {
+ updatePeriodicStats()
+ }
+}
+
+func updatePeriodicStats() {
+ nowUpdate := time.Now()
+ timeDiff := nowUpdate.Sub(Statistics.CachedStatsLastUpdate)
+ Statistics.CachedStatsLastUpdate = nowUpdate
+
+ {
+ m := runtime.MemStats{}
+ runtime.ReadMemStats(&m)
+
+ Statistics.MemoryInUseKB = m.Alloc / 1024
+ }
+
+ {
+ pstat, err := linuxproc.ReadProcessStat("/proc/self/stat")
+ if err == nil {
+ userTicks := pstat.Utime - cpuUsage.UserTime
+ sysTicks := pstat.Stime - cpuUsage.SysTime
+ cpuUsage.UserTime = pstat.Utime
+ cpuUsage.SysTime = pstat.Stime
+
+ Statistics.CpuUsagePct = 100 * float64(userTicks+sysTicks) / (timeDiff.Seconds() * float64(ticksPerSecond))
+ Statistics.MemoryRSSKB = uint64(pstat.Rss * pageSize / 1024)
+ Statistics.MemPerClientBytes = (Statistics.MemoryRSSKB * 1024) / (Statistics.CurrentClientCount + 1)
+ }
+ updateSysMem()
+ }
+
+ {
+ ChatSubscriptionLock.RLock()
+ Statistics.PubSubChannelCount = len(ChatSubscriptionInfo)
+ ChatSubscriptionLock.RUnlock()
+
+ GlobalSubscriptionLock.RLock()
+
+ Statistics.CurrentClientCount = uint64(len(GlobalSubscriptionInfo))
+ versions := make(map[string]uint64)
+ for _, v := range GlobalSubscriptionInfo {
+ versions[v.VersionString]++
+ }
+ Statistics.ClientVersions = versions
+
+ GlobalSubscriptionLock.RUnlock()
+ }
+
+ {
+ Statistics.Uptime = nowUpdate.Sub(Statistics.StartTime).String()
+ }
+
+ {
+ Statistics.Health.IRC = authIrcConnection.Connected()
+ Backend.lastSuccessLock.Lock()
+ for k, v := range Backend.lastSuccess {
+ Statistics.Health.Backend[k] = v
+ }
+ Backend.lastSuccessLock.Unlock()
+ }
+}
+
+var sysMemLastUpdate time.Time
+var sysMemUpdateLock sync.Mutex
+
+// updateSysMem reads the system's available RAM.
+func updateSysMem() {
+ if time.Now().Add(-2 * time.Second).After(sysMemLastUpdate) {
+ sysMemUpdateLock.Lock()
+ defer sysMemUpdateLock.Unlock()
+ if !time.Now().Add(-2 * time.Second).After(sysMemLastUpdate) {
+ return
+ }
+ } else {
+ return
+ }
+ sysMemLastUpdate = time.Now()
+ memInfo, err := linuxproc.ReadMemInfo("/proc/meminfo")
+ if err == nil {
+ Statistics.SysMemTotalKB = memInfo.MemTotal
+ Statistics.SysMemFreeKB = memInfo.MemAvailable
+ }
+
+ {
+ writeHLL()
+ }
+}
+
+// HTTPShowStatistics handles the /stats endpoint. It writes out the Statistics object as indented JSON.
+func HTTPShowStatistics(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+
+ updateStatsIfNeeded()
+
+ jsonBytes, _ := json.Marshal(Statistics)
+ outBuf := bytes.NewBuffer(nil)
+ json.Indent(outBuf, jsonBytes, "", "\t")
+
+ outBuf.WriteTo(w)
+}
diff --git a/socketserver/internal/server/publisher.go b/socketserver/server/subscriptions.go
similarity index 59%
rename from socketserver/internal/server/publisher.go
rename to socketserver/server/subscriptions.go
index d9658ac7..30bc4112 100644
--- a/socketserver/internal/server/publisher.go
+++ b/socketserver/server/subscriptions.go
@@ -4,6 +4,7 @@ package server
// If I screwed up the locking, I won't know until it's too late.
import (
+ "log"
"sync"
"time"
)
@@ -15,9 +16,43 @@ type SubscriberList struct {
var ChatSubscriptionInfo map[string]*SubscriberList = make(map[string]*SubscriberList)
var ChatSubscriptionLock sync.RWMutex
-var GlobalSubscriptionInfo SubscriberList
+var GlobalSubscriptionInfo []*ClientInfo
+var GlobalSubscriptionLock sync.RWMutex
-func PublishToChat(channel string, msg ClientMessage) (count int) {
+func CountSubscriptions(channels []string) int {
+ ChatSubscriptionLock.RLock()
+ defer ChatSubscriptionLock.RUnlock()
+
+ count := 0
+ for _, channelName := range channels {
+ list := ChatSubscriptionInfo[channelName]
+ if list != nil {
+ list.RLock()
+ count += len(list.Members)
+ list.RUnlock()
+ }
+ }
+
+ return count
+}
+
+func SubscribeChannel(client *ClientInfo, channelName string) {
+ ChatSubscriptionLock.RLock()
+ _subscribeWhileRlocked(channelName, client.MessageChannel)
+ ChatSubscriptionLock.RUnlock()
+}
+
+func SubscribeDefaults(client *ClientInfo) {
+
+}
+
+func SubscribeGlobal(client *ClientInfo) {
+ GlobalSubscriptionLock.Lock()
+ AddToSliceCl(&GlobalSubscriptionInfo, client)
+ GlobalSubscriptionLock.Unlock()
+}
+
+func PublishToChannel(channel string, msg ClientMessage) (count int) {
ChatSubscriptionLock.RLock()
list := ChatSubscriptionInfo[channel]
if list != nil {
@@ -58,15 +93,99 @@ func PublishToMultiple(channels []string, msg ClientMessage) (count int) {
}
func PublishToAll(msg ClientMessage) (count int) {
- GlobalSubscriptionInfo.RLock()
- for _, msgChan := range GlobalSubscriptionInfo.Members {
- msgChan <- msg
+ GlobalSubscriptionLock.RLock()
+ for _, client := range GlobalSubscriptionInfo {
+ select {
+ case client.MessageChannel <- msg:
+ case <-client.MsgChannelIsDone:
+ }
count++
}
- GlobalSubscriptionInfo.RUnlock()
+ GlobalSubscriptionLock.RUnlock()
return
}
+func UnsubscribeSingleChat(client *ClientInfo, channelName string) {
+ ChatSubscriptionLock.RLock()
+ list := ChatSubscriptionInfo[channelName]
+ if list != nil {
+ list.Lock()
+ RemoveFromSliceC(&list.Members, client.MessageChannel)
+ list.Unlock()
+ }
+ ChatSubscriptionLock.RUnlock()
+}
+
+// UnsubscribeAll will unsubscribe the client from all channels,
+// AND clear the CurrentChannels / WatchingChannels fields.
+//
+// Locks:
+// - read lock to top-level maps
+// - write lock to SubscriptionInfos
+// - write lock to ClientInfo
+func UnsubscribeAll(client *ClientInfo) {
+ if StopAcceptingConnections {
+ return // no need to remove from a high-contention list when the server is closing
+ }
+
+ GlobalSubscriptionLock.Lock()
+ RemoveFromSliceCl(&GlobalSubscriptionInfo, client)
+ GlobalSubscriptionLock.Unlock()
+
+ ChatSubscriptionLock.RLock()
+ client.Mutex.Lock()
+ for _, v := range client.CurrentChannels {
+ list := ChatSubscriptionInfo[v]
+ if list != nil {
+ list.Lock()
+ RemoveFromSliceC(&list.Members, client.MessageChannel)
+ list.Unlock()
+ }
+ }
+ client.CurrentChannels = nil
+ client.Mutex.Unlock()
+ ChatSubscriptionLock.RUnlock()
+}
+
+func unsubscribeAllClients() {
+ GlobalSubscriptionLock.Lock()
+ GlobalSubscriptionInfo = nil
+ GlobalSubscriptionLock.Unlock()
+ ChatSubscriptionLock.Lock()
+ ChatSubscriptionInfo = make(map[string]*SubscriberList)
+ ChatSubscriptionLock.Unlock()
+}
+
+const ReapingDelay = 1 * time.Minute
+
+// Checks ChatSubscriptionInfo for entries with no subscribers every ReapingDelay.
+// is_init_func
+func pubsubJanitor() {
+ for {
+ time.Sleep(ReapingDelay)
+ pubsubJanitor_do()
+ }
+}
+
+func pubsubJanitor_do() {
+ var cleanedUp = make([]string, 0, 6)
+ ChatSubscriptionLock.Lock()
+ for key, val := range ChatSubscriptionInfo {
+ if val == nil || len(val.Members) == 0 {
+ delete(ChatSubscriptionInfo, key)
+ cleanedUp = append(cleanedUp, key)
+ }
+ }
+ ChatSubscriptionLock.Unlock()
+
+ if len(cleanedUp) != 0 {
+ err := Backend.SendCleanupTopicsNotice(cleanedUp)
+ if err != nil {
+ log.Println("error reporting cleaned subs:", err)
+ }
+ }
+}
+
// Add a channel to the subscriptions while holding a read-lock to the map.
// Locks:
// - ALREADY HOLDING a read-lock to the 'which' top-level map via the rlocker object
@@ -82,6 +201,14 @@ func _subscribeWhileRlocked(channelName string, value chan<- ClientMessage) {
list.Members = []chan<- ClientMessage{value} // Create it populated, to avoid reaper
ChatSubscriptionInfo[channelName] = list
ChatSubscriptionLock.Unlock()
+
+ go func(topic string) {
+ err := Backend.SendNewTopicNotice(topic)
+ if err != nil {
+ log.Println("error reporting new sub:", err)
+ }
+ }(channelName)
+
ChatSubscriptionLock.RLock()
} else {
list.Lock()
@@ -89,80 +216,3 @@ func _subscribeWhileRlocked(channelName string, value chan<- ClientMessage) {
list.Unlock()
}
}
-
-func SubscribeGlobal(client *ClientInfo) {
- GlobalSubscriptionInfo.Lock()
- AddToSliceC(&GlobalSubscriptionInfo.Members, client.MessageChannel)
- GlobalSubscriptionInfo.Unlock()
-}
-
-func SubscribeChat(client *ClientInfo, channelName string) {
- ChatSubscriptionLock.RLock()
- _subscribeWhileRlocked(channelName, client.MessageChannel)
- ChatSubscriptionLock.RUnlock()
-}
-
-func unsubscribeAllClients() {
- GlobalSubscriptionInfo.Lock()
- GlobalSubscriptionInfo.Members = nil
- GlobalSubscriptionInfo.Unlock()
- ChatSubscriptionLock.Lock()
- ChatSubscriptionInfo = make(map[string]*SubscriberList)
- ChatSubscriptionLock.Unlock()
-}
-
-// Unsubscribe the client from all channels, AND clear the CurrentChannels / WatchingChannels fields.
-// Locks:
-// - read lock to top-level maps
-// - write lock to SubscriptionInfos
-// - write lock to ClientInfo
-func UnsubscribeAll(client *ClientInfo) {
- client.Mutex.Lock()
- client.PendingSubscriptionsBacklog = nil
- client.PendingSubscriptionsBacklog = nil
- client.Mutex.Unlock()
-
- GlobalSubscriptionInfo.Lock()
- RemoveFromSliceC(&GlobalSubscriptionInfo.Members, client.MessageChannel)
- GlobalSubscriptionInfo.Unlock()
-
- ChatSubscriptionLock.RLock()
- client.Mutex.Lock()
- for _, v := range client.CurrentChannels {
- list := ChatSubscriptionInfo[v]
- if list != nil {
- list.Lock()
- RemoveFromSliceC(&list.Members, client.MessageChannel)
- list.Unlock()
- }
- }
- client.CurrentChannels = nil
- client.Mutex.Unlock()
- ChatSubscriptionLock.RUnlock()
-}
-
-func UnsubscribeSingleChat(client *ClientInfo, channelName string) {
- ChatSubscriptionLock.RLock()
- list := ChatSubscriptionInfo[channelName]
- list.Lock()
- RemoveFromSliceC(&list.Members, client.MessageChannel)
- list.Unlock()
- ChatSubscriptionLock.RUnlock()
-}
-
-const ReapingDelay = 120 * time.Minute
-
-// Checks ChatSubscriptionInfo for entries with no subscribers every ReapingDelay.
-// Started from SetupServer().
-func deadChannelReaper() {
- for {
- time.Sleep(ReapingDelay)
- ChatSubscriptionLock.Lock()
- for key, val := range ChatSubscriptionInfo {
- if len(val.Members) == 0 {
- ChatSubscriptionInfo[key] = nil
- }
- }
- ChatSubscriptionLock.Unlock()
- }
-}
diff --git a/socketserver/internal/server/publisher_test.go b/socketserver/server/subscriptions_test.go
similarity index 53%
rename from socketserver/internal/server/publisher_test.go
rename to socketserver/server/subscriptions_test.go
index 2dc54ed6..fea4f533 100644
--- a/socketserver/internal/server/publisher_test.go
+++ b/socketserver/server/subscriptions_test.go
@@ -1,172 +1,21 @@
package server
import (
- "encoding/json"
"fmt"
- "github.com/satori/go.uuid"
- "golang.org/x/net/websocket"
- "io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
- "os"
"strconv"
"sync"
"syscall"
"testing"
"time"
+
+ "github.com/gorilla/websocket"
+ "github.com/satori/go.uuid"
)
-func TCountOpenFDs() uint64 {
- ary, _ := ioutil.ReadDir(fmt.Sprintf("/proc/%d/fd", os.Getpid()))
- return uint64(len(ary))
-}
-
-const IgnoreReceivedArguments = 1 + 2i
-
-func TReceiveExpectedMessage(tb testing.TB, conn *websocket.Conn, messageId int, command Command, arguments interface{}) (ClientMessage, bool) {
- var msg ClientMessage
- var fail bool
- err := FFZCodec.Receive(conn, &msg)
- if err != nil {
- tb.Error(err)
- return msg, false
- }
- if msg.MessageID != messageId {
- tb.Error("Message ID was wrong. Expected", messageId, ", got", msg.MessageID, ":", msg)
- fail = true
- }
- if msg.Command != command {
- tb.Error("Command was wrong. Expected", command, ", got", msg.Command, ":", msg)
- fail = true
- }
- if arguments != IgnoreReceivedArguments {
- if arguments == nil {
- if msg.origArguments != "" {
- tb.Error("Arguments are wrong. Expected", arguments, ", got", msg.Arguments, ":", msg)
- }
- } else {
- argBytes, _ := json.Marshal(arguments)
- if msg.origArguments != string(argBytes) {
- tb.Error("Arguments are wrong. Expected", arguments, ", got", msg.Arguments, ":", msg)
- }
- }
- }
- return msg, !fail
-}
-
-func TSendMessage(tb testing.TB, conn *websocket.Conn, messageId int, command Command, arguments interface{}) bool {
- err := FFZCodec.Send(conn, ClientMessage{MessageID: messageId, Command: command, Arguments: arguments})
- if err != nil {
- tb.Error(err)
- }
- return err == nil
-}
-
-func TSealForSavePubMsg(tb testing.TB, cmd Command, channel string, arguments interface{}, deleteMode bool) (url.Values, error) {
- form := url.Values{}
- form.Set("cmd", string(cmd))
- argsBytes, err := json.Marshal(arguments)
- if err != nil {
- tb.Error(err)
- return nil, err
- }
- form.Set("args", string(argsBytes))
- form.Set("channel", channel)
- if deleteMode {
- form.Set("delete", "1")
- }
- form.Set("time", time.Now().Format(time.UnixDate))
-
- sealed, err := SealRequest(form)
- if err != nil {
- tb.Error(err)
- return nil, err
- }
- return sealed, nil
-}
-
-func TCheckResponse(tb testing.TB, resp *http.Response, expected string) bool {
- var failed bool
- respBytes, err := ioutil.ReadAll(resp.Body)
- resp.Body.Close()
- respStr := string(respBytes)
-
- if err != nil {
- tb.Error(err)
- failed = true
- }
-
- if resp.StatusCode != 200 {
- tb.Error("Publish failed: ", resp.StatusCode, respStr)
- failed = true
- }
-
- if respStr != expected {
- tb.Errorf("Got wrong response from server. Expected: '%s' Got: '%s'", expected, respStr)
- failed = true
- }
- return !failed
-}
-
-type TURLs struct {
- Websocket string
- Origin string
- PubMsg string
- SavePubMsg string // update_and_pub
-}
-
-func TGetUrls(testserver *httptest.Server) TURLs {
- addr := testserver.Listener.Addr().String()
- return TURLs{
- Websocket: fmt.Sprintf("ws://%s/", addr),
- Origin: fmt.Sprintf("http://%s", addr),
- PubMsg: fmt.Sprintf("http://%s/pub_msg", addr),
- SavePubMsg: fmt.Sprintf("http://%s/update_and_pub", addr),
- }
-}
-
-func TSetup(testserver **httptest.Server, urls *TURLs) {
- DumpCache()
-
- conf := &ConfigFile{
- ServerId: 20,
- UseSSL: false,
- SocketOrigin: "localhost:2002",
- BannerHTML: `
-
-CatBag
-
-
-`,
- OurPublicKey: []byte{176, 149, 72, 209, 35, 42, 110, 220, 22, 236, 212, 129, 213, 199, 1, 227, 185, 167, 150, 159, 117, 202, 164, 100, 9, 107, 45, 141, 122, 221, 155, 73},
- OurPrivateKey: []byte{247, 133, 147, 194, 70, 240, 211, 216, 223, 16, 241, 253, 120, 14, 198, 74, 237, 180, 89, 33, 146, 146, 140, 58, 88, 160, 2, 246, 112, 35, 239, 87},
- BackendPublicKey: []byte{19, 163, 37, 157, 50, 139, 193, 85, 229, 47, 166, 21, 153, 231, 31, 133, 41, 158, 8, 53, 73, 0, 113, 91, 13, 181, 131, 248, 176, 18, 1, 107},
- }
- gconfig = conf
- SetupBackend(conf)
-
- if testserver != nil {
- serveMux := http.NewServeMux()
- SetupServerAndHandle(conf, nil, serveMux)
-
- tserv := httptest.NewUnstartedServer(serveMux)
- *testserver = tserv
- tserv.Start()
- if urls != nil {
- *urls = TGetUrls(tserv)
- }
- }
-}
+const TestOrigin = "http://www.twitch.tv"
func TestSubscriptionAndPublish(t *testing.T) {
var doneWg sync.WaitGroup
@@ -184,19 +33,30 @@ func TestSubscriptionAndPublish(t *testing.T) {
const TestData3 = false
var TestData4 = []interface{}{"str1", "str2", "str3"}
- ServerInitiatedCommands[TestCommandChan] = PushCommandCacheInfo{CacheTypeLastOnly, MsgTargetTypeChat}
- ServerInitiatedCommands[TestCommandMulti] = PushCommandCacheInfo{CacheTypeTimestamps, MsgTargetTypeMultichat}
- ServerInitiatedCommands[TestCommandGlobal] = PushCommandCacheInfo{CacheTypeTimestamps, MsgTargetTypeGlobal}
+ t.Log("TestSubscriptionAndPublish")
var server *httptest.Server
var urls TURLs
- TSetup(&server, &urls)
+
+ var backendExpected = NewTBackendRequestChecker(t,
+ TExpectedBackendRequest{200, bPathAnnounceStartup, &url.Values{"startup": []string{"1"}}, "", nil},
+ TExpectedBackendRequest{200, bPathAddTopic, &url.Values{"channels": []string{TestChannelName1}, "added": []string{"t"}}, "ok", nil},
+ TExpectedBackendRequest{200, bPathAddTopic, &url.Values{"channels": []string{TestChannelName2}, "added": []string{"t"}}, "ok", nil},
+ TExpectedBackendRequest{200, bPathAddTopic, &url.Values{"channels": []string{TestChannelName3}, "added": []string{"t"}}, "ok", nil},
+ )
+ server, _, urls = TSetup(SetupWantSocketServer|SetupWantBackendServer|SetupWantURLs, backendExpected)
+
defer server.CloseClientConnections()
defer unsubscribeAllClients()
+ defer backendExpected.Close()
var conn *websocket.Conn
+ var resp *http.Response
var err error
+ var headers http.Header = make(http.Header)
+ headers.Set("Origin", TestOrigin)
+
// client 1: sub ch1, ch2
// client 2: sub ch1, ch3
// client 3: sub none
@@ -204,15 +64,18 @@ func TestSubscriptionAndPublish(t *testing.T) {
// msg 1: ch1
// msg 2: ch2, ch3
// msg 3: chEmpty
- // msg 4: global
+ // msg 4: global uncached
// Client 1
- conn, err = websocket.Dial(urls.Websocket, "", urls.Origin)
+ conn, resp, err = websocket.DefaultDialer.Dial(urls.Websocket, headers)
if err != nil {
t.Error(err)
return
}
+ // both origins need testing
+ headers.Set("Origin", "https://www.twitch.tv")
+
doneWg.Add(1)
readyWg.Add(1)
go func(conn *websocket.Conn) {
@@ -236,13 +99,14 @@ func TestSubscriptionAndPublish(t *testing.T) {
}(conn)
// Client 2
- conn, err = websocket.Dial(urls.Websocket, "", urls.Origin)
+ conn, resp, err = websocket.DefaultDialer.Dial(urls.Websocket, headers)
if err != nil {
t.Error(err)
return
}
doneWg.Add(1)
+ readyWg.Wait() // enforce ordering
readyWg.Add(1)
go func(conn *websocket.Conn) {
TSendMessage(t, conn, 1, HelloCommand, []interface{}{"ffz_0.0-test", uuid.NewV4().String()})
@@ -265,13 +129,15 @@ func TestSubscriptionAndPublish(t *testing.T) {
}(conn)
// Client 3
- conn, err = websocket.Dial(urls.Websocket, "", urls.Origin)
+ conn, resp, err = websocket.DefaultDialer.Dial(urls.Websocket, headers)
if err != nil {
t.Error(err)
return
}
doneWg.Add(1)
+ readyWg.Wait() // enforce ordering
+ time.Sleep(2 * time.Millisecond)
readyWg.Add(1)
go func(conn *websocket.Conn) {
TSendMessage(t, conn, 1, HelloCommand, []interface{}{"ffz_0.0-test", uuid.NewV4().String()})
@@ -291,7 +157,6 @@ func TestSubscriptionAndPublish(t *testing.T) {
readyWg.Wait()
var form url.Values
- var resp *http.Response
// Publish message 1 - should go to clients 1, 2
@@ -300,7 +165,7 @@ func TestSubscriptionAndPublish(t *testing.T) {
t.FailNow()
}
resp, err = http.PostForm(urls.SavePubMsg, form)
- if !TCheckResponse(t, resp, strconv.Itoa(2)) {
+ if !TCheckResponse(t, resp, strconv.Itoa(2), "pub msg 1") {
t.FailNow()
}
@@ -311,7 +176,7 @@ func TestSubscriptionAndPublish(t *testing.T) {
t.FailNow()
}
resp, err = http.PostForm(urls.SavePubMsg, form)
- if !TCheckResponse(t, resp, strconv.Itoa(2)) {
+ if !TCheckResponse(t, resp, strconv.Itoa(2), "pub msg 2") {
t.FailNow()
}
@@ -322,23 +187,23 @@ func TestSubscriptionAndPublish(t *testing.T) {
t.FailNow()
}
resp, err = http.PostForm(urls.SavePubMsg, form)
- if !TCheckResponse(t, resp, strconv.Itoa(0)) {
+ if !TCheckResponse(t, resp, strconv.Itoa(0), "pub msg 3") {
t.FailNow()
}
// Publish message 4 - should go to clients 1, 2, 3
- form, err = TSealForSavePubMsg(t, TestCommandGlobal, "", TestData4, false)
+ form, err = TSealForUncachedPubMsg(t, TestCommandGlobal, "", TestData4, "global", false)
if err != nil {
t.FailNow()
}
- resp, err = http.PostForm(urls.SavePubMsg, form)
- if !TCheckResponse(t, resp, strconv.Itoa(3)) {
+ resp, err = http.PostForm(urls.UncachedPubMsg, form)
+ if !TCheckResponse(t, resp, strconv.Itoa(3), "pub msg 4") {
t.FailNow()
}
// Start client 4
- conn, err = websocket.Dial(urls.Websocket, "", urls.Origin)
+ conn, resp, err = websocket.DefaultDialer.Dial(urls.Websocket, headers)
if err != nil {
t.Error(err)
return
@@ -367,6 +232,113 @@ func TestSubscriptionAndPublish(t *testing.T) {
doneWg.Wait()
server.Close()
+
+ clientCount := readCurrentHLL()
+ if clientCount < 3 || clientCount > 5 {
+ t.Error("clientCount outside acceptable range: expected 4, got ", clientCount)
+ }
+}
+
+func TestRestrictedCommands(t *testing.T) {
+ var doneWg sync.WaitGroup
+ var readyWg sync.WaitGroup
+
+ const TestCommandNeedsAuth = "needsauth"
+ const TestRequestData = "123456"
+ const TestRequestDataJSON = "\"" + TestRequestData + "\""
+ const TestReplyData = "success"
+ const TestUsername = "sirstendec"
+
+ var server *httptest.Server
+ var urls TURLs
+
+ t.Log("TestRestrictedCommands")
+
+ var backendExpected = NewTBackendRequestChecker(t,
+ TExpectedBackendRequest{200, bPathAnnounceStartup, &url.Values{"startup": []string{"1"}}, "", nil},
+ TExpectedBackendRequest{401, fmt.Sprintf("%s%s", bPathOtherCommand, TestCommandNeedsAuth), &url.Values{"authenticated": []string{"0"}, "username": []string{""}, "clientData": []string{TestRequestDataJSON}}, "", nil},
+ TExpectedBackendRequest{401, fmt.Sprintf("%s%s", bPathOtherCommand, TestCommandNeedsAuth), &url.Values{"authenticated": []string{"0"}, "username": []string{TestUsername}, "clientData": []string{TestRequestDataJSON}}, "", nil},
+ TExpectedBackendRequest{200, fmt.Sprintf("%s%s", bPathOtherCommand, TestCommandNeedsAuth), &url.Values{"authenticated": []string{"1"}, "username": []string{TestUsername}, "clientData": []string{TestRequestDataJSON}}, fmt.Sprintf("\"%s\"", TestReplyData), nil},
+ )
+ server, _, urls = TSetup(SetupWantSocketServer|SetupWantBackendServer|SetupWantURLs, backendExpected)
+
+ defer server.CloseClientConnections()
+ defer unsubscribeAllClients()
+ defer backendExpected.Close()
+
+ var conn *websocket.Conn
+ var err error
+ var challengeChan = make(chan string)
+
+ var headers http.Header = make(http.Header)
+ headers.Set("Origin", TestOrigin)
+
+ // Client 1
+ conn, _, err = websocket.DefaultDialer.Dial(urls.Websocket, headers)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ doneWg.Add(1)
+ readyWg.Add(1)
+ go func(conn *websocket.Conn) {
+ defer doneWg.Done()
+ defer conn.Close()
+ TSendMessage(t, conn, 1, HelloCommand, []interface{}{"ffz_0.0-test", uuid.NewV4().String()})
+ TReceiveExpectedMessage(t, conn, 1, SuccessCommand, IgnoreReceivedArguments)
+ TSendMessage(t, conn, 2, ReadyCommand, 0)
+ TReceiveExpectedMessage(t, conn, 2, SuccessCommand, nil)
+
+ // Should get immediate refusal because no username set
+ TSendMessage(t, conn, 3, TestCommandNeedsAuth, TestRequestData)
+ TReceiveExpectedMessage(t, conn, 3, ErrorCommand, AuthorizationNeededError)
+
+ // Set a username
+ TSendMessage(t, conn, 4, SetUserCommand, TestUsername)
+ TReceiveExpectedMessage(t, conn, 4, SuccessCommand, nil)
+
+ // Should get authorization prompt
+ TSendMessage(t, conn, 5, TestCommandNeedsAuth, TestRequestData)
+ readyWg.Done()
+ msg, success := TReceiveExpectedMessage(t, conn, -1, AuthorizeCommand, IgnoreReceivedArguments)
+ if !success {
+ t.Error("recieve authorize command failed, cannot continue")
+ return
+ }
+ challenge, err := msg.ArgumentsAsString()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ challengeChan <- challenge // mocked: sending challenge to IRC server, IRC server sends challenge to socket server
+
+ TReceiveExpectedMessage(t, conn, 5, SuccessCommand, TestReplyData)
+ }(conn)
+
+ readyWg.Wait()
+
+ challenge := <-challengeChan
+ PendingAuthLock.Lock()
+ found := false
+ for _, v := range PendingAuths {
+ if conn.LocalAddr().String() == v.Client.RemoteAddr.String() {
+ found = true
+ if v.Challenge != challenge {
+ t.Error("Challenge in array was not what client got")
+ }
+ break
+ }
+ }
+ PendingAuthLock.Unlock()
+ if !found {
+ t.Fatal("Did not find authorization challenge in the pending auths array")
+ }
+
+ submitAuth(TestUsername, challenge)
+
+ doneWg.Wait()
+ server.Close()
}
func BenchmarkUserSubscriptionSinglePublish(b *testing.B) {
@@ -396,12 +368,15 @@ func BenchmarkUserSubscriptionSinglePublish(b *testing.B) {
var server *httptest.Server
var urls TURLs
- TSetup(&server, &urls)
+ server, _, urls = TSetup(SetupWantSocketServer|SetupWantURLs, nil)
defer unsubscribeAllClients()
+ var headers http.Header = make(http.Header)
+ headers.Set("Origin", TestOrigin)
+
b.ResetTimer()
for i := 0; i < b.N; i++ {
- conn, err := websocket.Dial(urls.Websocket, "", urls.Origin)
+ conn, _, err := websocket.DefaultDialer.Dial(urls.Websocket, headers)
if err != nil {
b.Error(err)
break
@@ -427,7 +402,7 @@ func BenchmarkUserSubscriptionSinglePublish(b *testing.B) {
readyWg.Wait()
fmt.Println("publishing...")
- if PublishToChat(TestChannelName, message) != b.N {
+ if PublishToChannel(TestChannelName, message) != b.N {
b.Error("not enough sent")
server.CloseClientConnections()
panic("halting test instead of waiting")
diff --git a/socketserver/server/testinfra_test.go b/socketserver/server/testinfra_test.go
new file mode 100644
index 00000000..5d8a4570
--- /dev/null
+++ b/socketserver/server/testinfra_test.go
@@ -0,0 +1,357 @@
+package server
+
+import (
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "strconv"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/gorilla/websocket"
+)
+
+const (
+ SetupWantSocketServer = 1 << iota
+ SetupWantBackendServer
+ SetupWantURLs
+)
+const SetupNoServers = 0
+
+var signalCatch sync.Once
+
+func TSetup(flags int, backendChecker *TBackendRequestChecker) (socketserver *httptest.Server, backend *httptest.Server, urls TURLs) {
+ signalCatch.Do(func() {
+ go dumpStackOnCtrlZ()
+ })
+
+ DumpBacklogData()
+
+ ioutil.WriteFile("index.html", []byte(`
+
+CatBag
+
+`), 0644)
+
+ conf := &ConfigFile{
+ ServerID: 20,
+ UseSSL: false,
+ OurPublicKey: []byte{176, 149, 72, 209, 35, 42, 110, 220, 22, 236, 212, 129, 213, 199, 1, 227, 185, 167, 150, 159, 117, 202, 164, 100, 9, 107, 45, 141, 122, 221, 155, 73},
+ OurPrivateKey: []byte{247, 133, 147, 194, 70, 240, 211, 216, 223, 16, 241, 253, 120, 14, 198, 74, 237, 180, 89, 33, 146, 146, 140, 58, 88, 160, 2, 246, 112, 35, 239, 87},
+ BackendPublicKey: []byte{19, 163, 37, 157, 50, 139, 193, 85, 229, 47, 166, 21, 153, 231, 31, 133, 41, 158, 8, 53, 73, 0, 113, 91, 13, 181, 131, 248, 176, 18, 1, 107},
+ }
+
+ if flags&SetupWantBackendServer != 0 {
+ backend = httptest.NewServer(backendChecker)
+ conf.BackendURL = fmt.Sprintf("http://%s", backend.Listener.Addr().String())
+ }
+
+ Configuration = conf
+ setupBackend(conf)
+
+ if flags&SetupWantSocketServer != 0 {
+ serveMux := http.NewServeMux()
+ SetupServerAndHandle(conf, serveMux)
+ dumpUniqueUsers()
+
+ socketserver = httptest.NewServer(serveMux)
+ }
+
+ if flags&SetupWantURLs != 0 {
+ urls = TGetUrls(socketserver, backend)
+ }
+ return
+}
+
+type TBC interface {
+ Error(args ...interface{})
+ Errorf(format string, args ...interface{})
+}
+
+const MethodIsPost = "POST"
+
+type TExpectedBackendRequest struct {
+ ResponseCode int
+ Path string
+ // Method string // always POST
+ PostForm *url.Values
+ Response string
+ ResponseHeaders http.Header
+}
+
+func (er *TExpectedBackendRequest) String() string {
+ if MethodIsPost == "" {
+ return er.Path
+ }
+ return fmt.Sprint("%s %s: %s", MethodIsPost, er.Path, er.PostForm.Encode())
+}
+
+type TBackendRequestChecker struct {
+ ExpectedRequests []TExpectedBackendRequest
+
+ currentRequest int
+ tb TBC
+ mutex sync.Mutex
+}
+
+func NewTBackendRequestChecker(tb TBC, urls ...TExpectedBackendRequest) *TBackendRequestChecker {
+ return &TBackendRequestChecker{ExpectedRequests: urls, tb: tb, currentRequest: 0}
+}
+
+func (backend *TBackendRequestChecker) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ backend.mutex.Lock()
+ defer backend.mutex.Unlock()
+
+ if r.Method != MethodIsPost {
+ backend.tb.Errorf("Bad backend request: was not a POST. %v", r)
+ return
+ }
+
+ r.ParseForm()
+
+ unsealedForm, err := Backend.UnsealRequest(r.PostForm)
+ if err != nil {
+ backend.tb.Errorf("Failed to unseal backend request: %v", err)
+ }
+
+ if backend.currentRequest >= len(backend.ExpectedRequests) {
+ backend.tb.Errorf("Unexpected backend request: %s %s: %s", r.Method, r.URL, unsealedForm)
+ return
+ }
+
+ cur := backend.ExpectedRequests[backend.currentRequest]
+ backend.currentRequest++
+
+ headers := w.Header()
+ for k, v := range cur.ResponseHeaders {
+ if len(v) == 1 {
+ headers.Set(k, v[0])
+ } else if len(v) == 0 {
+ headers.Del(k)
+ } else {
+ for _, hv := range v {
+ headers.Add(k, hv)
+ }
+ }
+ }
+
+ defer func() {
+ w.WriteHeader(cur.ResponseCode)
+ if cur.Response != "" {
+ w.Write([]byte(cur.Response))
+ }
+ }()
+
+ if cur.Path != "" {
+ if r.URL.Path != cur.Path {
+ backend.tb.Errorf("Bad backend request. Expected %v, got %s %s", cur, r.Method, r.URL)
+ return
+ }
+ }
+
+ if cur.PostForm != nil {
+ anyErr := TcompareForms(backend.tb, "Different form contents", *cur.PostForm, unsealedForm)
+ if anyErr {
+ backend.tb.Errorf("...in %s %s: %s", r.Method, r.URL, unsealedForm.Encode())
+ }
+ }
+}
+
+func (backend *TBackendRequestChecker) Close() error {
+ if backend.currentRequest < len(backend.ExpectedRequests) {
+ backend.tb.Errorf("Not all requests sent, got %d out of %d", backend.currentRequest, len(backend.ExpectedRequests))
+ }
+ return nil
+}
+
+func TcompareForms(tb TBC, ctx string, expectedForm, gotForm url.Values) (anyErrors bool) {
+ for k, expVal := range expectedForm {
+ gotVal, ok := gotForm[k]
+ if !ok {
+ tb.Errorf("%s: Form[%s]: Expected %v, (got nothing)", ctx, k, expVal)
+ anyErrors = true
+ continue
+ }
+ if len(expVal) != len(gotVal) {
+ tb.Errorf("%s: Form[%s]: Expected %d%v, Got %d%v", ctx, k, len(expVal), expVal, len(gotVal), gotVal)
+ anyErrors = true
+ continue
+ }
+ for i, el := range expVal {
+ if gotVal[i] != el {
+ tb.Errorf("%s: Form[%s][%d]: Expected %s, Got %s", ctx, k, i, el, gotVal[i])
+ anyErrors = true
+ }
+ }
+ }
+ for k, gotVal := range gotForm {
+ _, ok := expectedForm[k]
+ if !ok {
+ tb.Errorf("%s: Form[%s]: (expected nothing), Got %v", ctx, k, gotVal)
+ anyErrors = true
+ }
+ }
+ return anyErrors
+}
+
+func TCountOpenFDs() uint64 {
+ ary, _ := ioutil.ReadDir(fmt.Sprintf("/proc/%d/fd", os.Getpid()))
+ return uint64(len(ary))
+}
+
+const IgnoreReceivedArguments = 1 + 2i
+
+func TReceiveExpectedMessage(tb testing.TB, conn *websocket.Conn, messageID int, command Command, arguments interface{}) (ClientMessage, bool) {
+ var msg ClientMessage
+ var fail bool
+ messageType, packet, err := conn.ReadMessage()
+ if err != nil {
+ tb.Error(err)
+ return msg, false
+ }
+ if messageType != websocket.TextMessage {
+ tb.Error("got non-text message", packet)
+ return msg, false
+ }
+
+ err = UnmarshalClientMessage(packet, messageType, &msg)
+ if err != nil {
+ tb.Error(err)
+ return msg, false
+ }
+ if msg.MessageID != messageID {
+ tb.Error("Message ID was wrong. Expected", messageID, ", got", msg.MessageID, ":", msg)
+ fail = true
+ }
+ if msg.Command != command {
+ tb.Error("Command was wrong. Expected", command, ", got", msg.Command, ":", msg)
+ fail = true
+ }
+ if arguments != IgnoreReceivedArguments {
+ if arguments == nil {
+ if msg.origArguments != "" {
+ tb.Error("Arguments are wrong. Expected", arguments, ", got", msg.Arguments, ":", msg)
+ }
+ } else {
+ argBytes, _ := json.Marshal(arguments)
+ if msg.origArguments != string(argBytes) {
+ tb.Error("Arguments are wrong. Expected", arguments, ", got", msg.Arguments, ":", msg)
+ }
+ }
+ }
+ return msg, !fail
+}
+
+func TSendMessage(tb testing.TB, conn *websocket.Conn, messageID int, command Command, arguments interface{}) bool {
+ SendMessage(conn, ClientMessage{MessageID: messageID, Command: command, Arguments: arguments})
+ return true
+}
+
+func TSealForSavePubMsg(tb testing.TB, cmd Command, channel string, arguments interface{}, deleteMode bool) (url.Values, error) {
+ form := url.Values{}
+ form.Set("cmd", string(cmd))
+ argsBytes, err := json.Marshal(arguments)
+ if err != nil {
+ tb.Error(err)
+ return nil, err
+ }
+ form.Set("args", string(argsBytes))
+ form.Set("channel", channel)
+ if deleteMode {
+ form.Set("delete", "1")
+ }
+ form.Set("time", strconv.FormatInt(time.Now().Unix(), 10))
+
+ sealed, err := Backend.SealRequest(form)
+ if err != nil {
+ tb.Error(err)
+ return nil, err
+ }
+ return sealed, nil
+}
+
+func TSealForUncachedPubMsg(tb testing.TB, cmd Command, channel string, arguments interface{}, scope string, deleteMode bool) (url.Values, error) {
+ form := url.Values{}
+ form.Set("cmd", string(cmd))
+ argsBytes, err := json.Marshal(arguments)
+ if err != nil {
+ tb.Error(err)
+ return nil, err
+ }
+ form.Set("args", string(argsBytes))
+ form.Set("channel", channel)
+ if deleteMode {
+ form.Set("delete", "1")
+ }
+ form.Set("time", time.Now().Format(time.UnixDate))
+ form.Set("scope", scope)
+
+ sealed, err := Backend.SealRequest(form)
+ if err != nil {
+ tb.Error(err)
+ return nil, err
+ }
+ return sealed, nil
+}
+
+func TCheckResponse(tb testing.TB, resp *http.Response, expected string, desc string) bool {
+ var failed bool
+ respBytes, err := ioutil.ReadAll(resp.Body)
+ resp.Body.Close()
+ respStr := string(respBytes)
+
+ if err != nil {
+ tb.Error(err)
+ failed = true
+ }
+
+ if resp.StatusCode != 200 {
+ tb.Error("Publish failed: ", resp.StatusCode, respStr)
+ failed = true
+ }
+
+ if respStr != expected {
+ tb.Errorf("Got wrong response from server. %s Expected: '%s' Got: '%s'", desc, expected, respStr)
+ failed = true
+ }
+ return !failed
+}
+
+type TURLs struct {
+ Websocket string
+ Origin string
+ UncachedPubMsg string // uncached_pub
+ SavePubMsg string // cached_pub
+}
+
+func TGetUrls(socketserver *httptest.Server, backend *httptest.Server) TURLs {
+ addr := socketserver.Listener.Addr().String()
+ return TURLs{
+ Websocket: fmt.Sprintf("ws://%s/", addr),
+ Origin: fmt.Sprintf("http://%s", addr),
+ UncachedPubMsg: fmt.Sprintf("http://%s/uncached_pub", addr),
+ SavePubMsg: fmt.Sprintf("http://%s/cached_pub", addr),
+ }
+}
+
+func TCheckHLLValue(tb testing.TB, expected uint64, actual uint64) {
+ high := uint64(float64(expected) * 1.05)
+ low := uint64(float64(expected) * 0.95)
+ if actual < low || actual > high {
+ tb.Errorf("Count outside expected range. Expected %d, Got %d", expected, actual)
+ }
+}
diff --git a/socketserver/server/tickspersecond.go b/socketserver/server/tickspersecond.go
new file mode 100644
index 00000000..160eec7c
--- /dev/null
+++ b/socketserver/server/tickspersecond.go
@@ -0,0 +1,12 @@
+package server
+
+// #include
+// long get_ticks_per_second() {
+// return sysconf(_SC_CLK_TCK);
+// }
+//import "C"
+
+// note: this seems to add 0.1s to compile time on my machine
+//var ticksPerSecond = int(C.get_ticks_per_second())
+
+var ticksPerSecond = 100
diff --git a/socketserver/server/types.go b/socketserver/server/types.go
new file mode 100644
index 00000000..41c0d010
--- /dev/null
+++ b/socketserver/server/types.go
@@ -0,0 +1,153 @@
+package server
+
+import (
+ "fmt"
+ "net"
+ "sync"
+
+ "github.com/satori/go.uuid"
+)
+
+const NegativeOne = ^uint64(0)
+
+var AnonymousClientID = uuid.FromStringOrNil("683b45e4-f853-4c45-bf96-7d799cc93e34")
+
+type ConfigFile struct {
+ // Numeric server id known to the backend
+ ServerID int
+ // Address to bind the HTTP server to on startup.
+ ListenAddr string
+ // Address to bind the TLS server to on startup.
+ SSLListenAddr string
+ // URL to the backend server
+ BackendURL string
+
+ // Minimum memory to accept a new connection
+ MinMemoryKBytes uint64
+ // Maximum # of clients that can be connected. 0 to disable.
+ MaxClientCount uint64
+
+ // SSL/TLS
+ // Enable the use of SSL.
+ UseSSL bool
+ // Path to certificate file.
+ SSLCertificateFile string
+ // Path to key file.
+ SSLKeyFile string
+
+ UseESLogStashing bool
+ ESServer string
+ ESIndexPrefix string
+ ESHostName string
+
+ // Nacl keys
+ OurPrivateKey []byte
+ OurPublicKey []byte
+ BackendPublicKey []byte
+
+ // Request username validation from all new clients.
+ SendAuthToNewClients bool
+}
+
+type ClientMessage struct {
+ // Message ID. Increments by 1 for each message sent from the client.
+ // When replying to a command, the message ID must be echoed.
+ // When sending a server-initiated message, this is -1.
+ MessageID int `json:"m"`
+ // The command that the client wants from the server.
+ // When sent from the server, the literal string 'True' indicates success.
+ // Before sending, a blank Command will be converted into SuccessCommand.
+ Command Command `json:"c"`
+ // Result of json.Unmarshal on the third field send from the client
+ Arguments interface{} `json:"a"`
+
+ origArguments string
+}
+
+type AuthInfo struct {
+ // The client's claimed username on Twitch.
+ TwitchUsername string
+
+ // Whether or not the server has validated the client's claimed username.
+ UsernameValidated bool
+}
+
+type ClientVersion struct {
+ Major int
+ Minor int
+ Revision int
+}
+
+type ClientInfo struct {
+ // The client ID.
+ // This must be written once by the owning goroutine before the struct is passed off to any other goroutines.
+ ClientID uuid.UUID
+
+ // The client's literal version string.
+ // This must be written once by the owning goroutine before the struct is passed off to any other goroutines.
+ VersionString string
+
+ Version ClientVersion
+
+ // This mutex protects writable data in this struct.
+ // If it seems to be a performance problem, we can split this.
+ Mutex sync.Mutex
+
+ // Info about the client's username and whether or not we have verified it.
+ AuthInfo
+
+ RemoteAddr net.Addr
+
+ // Username validation nonce.
+ ValidationNonce string
+
+ // The list of chats this client is currently in.
+ // Protected by Mutex.
+ CurrentChannels []string
+
+ // True if the client has already sent the 'ready' command
+ ReadyComplete bool
+
+ // Server-initiated messages should be sent here
+ // This field will be nil before it is closed.
+ MessageChannel chan<- ClientMessage
+
+ MsgChannelIsDone <-chan struct{}
+
+ // Take out an Add() on this during a command if you need to use the MessageChannel later.
+ MsgChannelKeepalive sync.WaitGroup
+
+ // The number of pings sent without a response.
+ // Protected by Mutex
+ pingCount int
+}
+
+func VersionFromString(v string) ClientVersion {
+ var cv ClientVersion
+ fmt.Sscanf(v, "ffz_%d.%d.%d", &cv.Major, &cv.Minor, &cv.Revision)
+ return cv
+}
+
+func (cv *ClientVersion) After(cv2 *ClientVersion) bool {
+ if cv.Major > cv2.Major {
+ return true
+ } else if cv.Major < cv2.Major {
+ return false
+ }
+ if cv.Minor > cv2.Minor {
+ return true
+ } else if cv.Minor < cv2.Minor {
+ return false
+ }
+ if cv.Revision > cv2.Revision {
+ return true
+ } else if cv.Revision < cv2.Revision {
+ return false
+ }
+
+ return false // equal
+}
+
+func (cv *ClientVersion) Equal(cv2 *ClientVersion) bool {
+ return cv.Major == cv2.Major && cv.Minor == cv2.Minor && cv.Revision == cv2.Revision
+}
diff --git a/socketserver/server/usercount.go b/socketserver/server/usercount.go
new file mode 100644
index 00000000..1b9cbcbb
--- /dev/null
+++ b/socketserver/server/usercount.go
@@ -0,0 +1,244 @@
+package server
+
+import (
+ "bytes"
+ "encoding/base64"
+ "encoding/binary"
+ "encoding/gob"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "net/http"
+ "os"
+ "time"
+
+ "io"
+
+ "github.com/clarkduvall/hyperloglog"
+ "github.com/satori/go.uuid"
+)
+
+// UuidHash implements a hash for uuid.UUID by XORing the random bits.
+type UuidHash uuid.UUID
+
+func (u UuidHash) Sum64() uint64 {
+ var valLow, valHigh uint64
+ valLow = binary.LittleEndian.Uint64(u[0:8])
+ valHigh = binary.LittleEndian.Uint64(u[8:16])
+ return valLow ^ valHigh
+}
+
+type PeriodUniqueUsers struct {
+ Start time.Time
+ End time.Time
+ Counter *hyperloglog.HyperLogLogPlus
+}
+
+type usageToken struct{}
+
+const uniqCountDir = "./uniques"
+const UsersDailyFmt = "daily-%d-%d-%d.gob" // d-m-y
+const CounterPrecision uint8 = 12
+
+var uniqueCounter PeriodUniqueUsers
+var uniqueUserChannel chan uuid.UUID
+var uniqueCtrWritingToken chan usageToken
+
+var CounterLocation *time.Location = time.FixedZone("UTC-5", int((time.Hour*-5)/time.Second))
+
+func TruncateToMidnight(at time.Time) time.Time {
+ year, month, day := at.Date()
+ return time.Date(year, month, day, 0, 0, 0, 0, CounterLocation)
+}
+
+// GetCounterPeriod calculates the start and end timestamps for the HLL measurement period that includes the 'at' timestamp.
+func GetCounterPeriod(at time.Time) (start time.Time, end time.Time) {
+ year, month, day := at.Date()
+ start = time.Date(year, month, day, 0, 0, 0, 0, CounterLocation)
+ end = time.Date(year, month, day+1, 0, 0, 0, 0, CounterLocation)
+ return start, end
+}
+
+// GetHLLFilename returns the filename for the saved HLL whose measurement period covers the given time.
+func GetHLLFilename(at time.Time) string {
+ var filename string
+ year, month, day := at.Date()
+ filename = fmt.Sprintf(UsersDailyFmt, day, month, year)
+ return fmt.Sprintf("%s/%s", uniqCountDir, filename)
+}
+
+// loadHLL loads a HLL from disk and stores the result in dest.Counter.
+// If dest.Counter is nil, it will be initialized. (This is a useful side-effect.)
+// If dest is one of the uniqueCounters, the usageToken must be held.
+func loadHLL(at time.Time, dest *PeriodUniqueUsers) error {
+ fileBytes, err := ioutil.ReadFile(GetHLLFilename(at))
+ if err != nil {
+ return err
+ }
+
+ if dest.Counter == nil {
+ dest.Counter, _ = hyperloglog.NewPlus(CounterPrecision)
+ }
+
+ dec := gob.NewDecoder(bytes.NewReader(fileBytes))
+ err = dec.Decode(dest.Counter)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+// writeHLL writes the indicated HLL to disk.
+// The function takes the usageToken.
+func writeHLL() error {
+ token := <-uniqueCtrWritingToken
+ result := writeHLL_do(&uniqueCounter)
+ uniqueCtrWritingToken <- token
+ return result
+}
+
+// writeHLL_do writes out the HLL indicated by `which` to disk.
+// The usageToken must be held when calling this function.
+func writeHLL_do(hll *PeriodUniqueUsers) (err error) {
+ filename := GetHLLFilename(hll.Start)
+ file, err := os.Create(filename)
+ if err != nil {
+ return err
+ }
+
+ defer func(file io.Closer) {
+ fileErr := file.Close()
+ if err == nil {
+ err = fileErr
+ }
+ }(file)
+
+ enc := gob.NewEncoder(file)
+ return enc.Encode(hll.Counter)
+}
+
+// readCurrentHLL reads the current value of the active HLL counter.
+// The function takes the usageToken.
+func readCurrentHLL() uint64 {
+ token := <-uniqueCtrWritingToken
+ result := uniqueCounter.Counter.Count()
+ uniqueCtrWritingToken <- token
+ return result
+}
+
+var hllFileServer = http.StripPrefix("/hll", http.FileServer(http.Dir(uniqCountDir)))
+
+func HTTPShowHLL(w http.ResponseWriter, r *http.Request) {
+ hllFileServer.ServeHTTP(w, r)
+}
+
+func HTTPWriteHLL(w http.ResponseWriter, r *http.Request) {
+ writeHLL()
+ w.WriteHeader(200)
+ w.Write([]byte("ok"))
+}
+
+// loadUniqueUsers loads the previous HLLs into memory.
+// is_init_func
+func loadUniqueUsers() {
+ gob.RegisterName("hyperloglog", hyperloglog.HyperLogLogPlus{})
+ err := os.MkdirAll(uniqCountDir, 0755)
+ if err != nil {
+ log.Panicln("could not make unique users data dir:", err)
+ }
+
+ now := time.Now().In(CounterLocation)
+ uniqueCounter.Start, uniqueCounter.End = GetCounterPeriod(now)
+ err = loadHLL(now, &uniqueCounter)
+ isIgnorableError := err != nil && (false ||
+ (os.IsNotExist(err)) ||
+ (err == io.EOF))
+
+ if isIgnorableError {
+ // file didn't finish writing
+ // errors in NewPlus are bad precisions
+ uniqueCounter.Counter, _ = hyperloglog.NewPlus(CounterPrecision)
+ log.Println("failed to load unique users data:", err)
+ } else if err != nil {
+ log.Panicln("failed to load unique users data:", err)
+ }
+
+ uniqueUserChannel = make(chan uuid.UUID)
+ uniqueCtrWritingToken = make(chan usageToken)
+ go processNewUsers()
+ go rolloverCounters()
+ uniqueCtrWritingToken <- usageToken{}
+}
+
+// dumpUniqueUsers dumps all the data in uniqueCounters.
+func dumpUniqueUsers() {
+ token := <-uniqueCtrWritingToken
+
+ uniqueCounter.Counter.Clear()
+
+ uniqueCtrWritingToken <- token
+}
+
+// processNewUsers reads uniqueUserChannel, and also dispatches the writing token.
+// This function is the primary writer of uniqueCounters, so it makes sense for it to hold the token.
+// is_init_func
+func processNewUsers() {
+ token := <-uniqueCtrWritingToken
+
+ for {
+ select {
+ case u := <-uniqueUserChannel:
+ hashed := UuidHash(u)
+ uniqueCounter.Counter.Add(hashed)
+ case uniqueCtrWritingToken <- token:
+ // relinquish token. important that there is only one of this going on
+ // otherwise we thrash
+ token = <-uniqueCtrWritingToken
+ }
+ }
+}
+
+func getNextMidnight() time.Time {
+ now := time.Now().In(CounterLocation)
+ year, month, day := now.Date()
+ return time.Date(year, month, day+1, 0, 0, 1, 0, CounterLocation)
+}
+
+// is_init_func
+func rolloverCounters() {
+ for {
+ duration := getNextMidnight().Sub(time.Now())
+ // fmt.Println(duration)
+ time.Sleep(duration)
+ rolloverCounters_do()
+ }
+}
+
+func rolloverCounters_do() {
+ var token usageToken
+ var now time.Time
+
+ token = <-uniqueCtrWritingToken
+ now = time.Now().In(CounterLocation)
+ // Cycle for period
+ err := writeHLL_do(&uniqueCounter)
+ if err != nil {
+ log.Println("could not cycle unique user counter:", err)
+
+ // Attempt to rescue the data into the log
+ var buf bytes.Buffer
+ bytes, err := uniqueCounter.Counter.GobEncode()
+ if err == nil {
+ enc := base64.NewEncoder(base64.StdEncoding, &buf)
+ enc.Write(bytes)
+ enc.Close()
+ log.Print("data for ", GetHLLFilename(uniqueCounter.Start), ":", buf.String())
+ }
+ }
+
+ uniqueCounter.Start, uniqueCounter.End = GetCounterPeriod(now)
+ // errors are bad precisions, so we can ignore
+ uniqueCounter.Counter, _ = hyperloglog.NewPlus(CounterPrecision)
+
+ uniqueCtrWritingToken <- token
+}
diff --git a/socketserver/server/usercount_test.go b/socketserver/server/usercount_test.go
new file mode 100644
index 00000000..2cf608a2
--- /dev/null
+++ b/socketserver/server/usercount_test.go
@@ -0,0 +1,68 @@
+package server
+
+import (
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/satori/go.uuid"
+)
+
+func TestUniqueConnections(t *testing.T) {
+ const TestExpectedCount = 1000
+
+ testStart := time.Now().In(CounterLocation)
+
+ var server *httptest.Server
+ var backendExpected = NewTBackendRequestChecker(t,
+ TExpectedBackendRequest{200, bPathAnnounceStartup, &url.Values{"startup": []string{"1"}}, "", nil},
+ )
+ server, _, _ = TSetup(SetupWantSocketServer|SetupWantBackendServer, backendExpected)
+
+ defer server.CloseClientConnections()
+ defer unsubscribeAllClients()
+ defer backendExpected.Close()
+
+ dumpUniqueUsers()
+
+ for i := 0; i < TestExpectedCount; i++ {
+ uuid := uuid.NewV4()
+ uniqueUserChannel <- uuid
+ uniqueUserChannel <- uuid
+ }
+
+ TCheckHLLValue(t, TestExpectedCount, readCurrentHLL())
+
+ token := <-uniqueCtrWritingToken
+ uniqueCounter.End = time.Now().In(CounterLocation).Add(-1 * time.Second)
+ uniqueCtrWritingToken <- token
+
+ rolloverCounters_do()
+
+ for i := 0; i < TestExpectedCount; i++ {
+ uuid := uuid.NewV4()
+ uniqueUserChannel <- uuid
+ uniqueUserChannel <- uuid
+ }
+
+ TCheckHLLValue(t, TestExpectedCount, readCurrentHLL())
+
+ // Check: Merging the two days results in 2000
+ // note: rolloverCounters_do() wrote out a file, and loadHLL() is reading it back
+ // TODO need to rewrite some of the test to make this work
+ var loadDest PeriodUniqueUsers
+ loadHLL(testStart, &loadDest)
+
+ token = <-uniqueCtrWritingToken
+ loadDest.Counter.Merge(uniqueCounter.Counter)
+ uniqueCtrWritingToken <- token
+
+ TCheckHLLValue(t, TestExpectedCount*2, loadDest.Counter.Count())
+}
+
+func TestUniqueUsersCleanup(t *testing.T) {
+ // Not a test. Removes old files.
+ os.RemoveAll(uniqCountDir)
+}
diff --git a/socketserver/internal/server/utils.go b/socketserver/server/utils.go
similarity index 68%
rename from socketserver/internal/server/utils.go
rename to socketserver/server/utils.go
index 8dbff0f4..9552e7bb 100644
--- a/socketserver/internal/server/utils.go
+++ b/socketserver/server/utils.go
@@ -5,11 +5,11 @@ import (
"crypto/rand"
"encoding/base64"
"errors"
- "golang.org/x/crypto/nacl/box"
- "log"
"net/url"
"strconv"
"strings"
+
+ "golang.org/x/crypto/nacl/box"
)
func FillCryptoRandom(buf []byte) error {
@@ -24,11 +24,11 @@ func FillCryptoRandom(buf []byte) error {
return nil
}
-func New4KByteBuffer() interface{} {
- return make([]byte, 0, 4096)
+func copyString(s string) string {
+ return string([]byte(s))
}
-func SealRequest(form url.Values) (url.Values, error) {
+func (backend *backendInfo) SealRequest(form url.Values) (url.Values, error) {
var nonce [24]byte
var err error
@@ -37,7 +37,7 @@ func SealRequest(form url.Values) (url.Values, error) {
return nil, err
}
- cipherMsg := box.SealAfterPrecomputation(nil, []byte(form.Encode()), &nonce, &backendSharedKey)
+ cipherMsg := box.SealAfterPrecomputation(nil, []byte(form.Encode()), &nonce, &backend.sharedKey)
bufMessage := new(bytes.Buffer)
enc := base64.NewEncoder(base64.URLEncoding, bufMessage)
@@ -54,7 +54,7 @@ func SealRequest(form url.Values) (url.Values, error) {
retval := url.Values{
"nonce": []string{nonceString},
"msg": []string{cipherString},
- "id": []string{strconv.Itoa(serverId)},
+ "id": []string{strconv.Itoa(Backend.serverID)},
}
return retval, nil
@@ -63,16 +63,18 @@ func SealRequest(form url.Values) (url.Values, error) {
var ErrorShortNonce = errors.New("Nonce too short.")
var ErrorInvalidSignature = errors.New("Invalid signature or contents")
-func UnsealRequest(form url.Values) (url.Values, error) {
+func (backend *backendInfo) UnsealRequest(form url.Values) (url.Values, error) {
var nonce [24]byte
nonceString := form.Get("nonce")
dec := base64.NewDecoder(base64.URLEncoding, strings.NewReader(nonceString))
count, err := dec.Read(nonce[:])
if err != nil {
+ Statistics.BackendVerifyFails++
return nil, err
}
if count != 24 {
+ Statistics.BackendVerifyFails++
return nil, ErrorShortNonce
}
@@ -81,15 +83,15 @@ func UnsealRequest(form url.Values) (url.Values, error) {
cipherBuffer := new(bytes.Buffer)
cipherBuffer.ReadFrom(dec)
- message, ok := box.OpenAfterPrecomputation(nil, cipherBuffer.Bytes(), &nonce, &backendSharedKey)
+ message, ok := box.OpenAfterPrecomputation(nil, cipherBuffer.Bytes(), &nonce, &backend.sharedKey)
if !ok {
+ Statistics.BackendVerifyFails++
return nil, ErrorInvalidSignature
}
retValues, err := url.ParseQuery(string(message))
if err != nil {
- // Assume that the signature was accidentally correct but the contents were garbage
- log.Print(err)
+ Statistics.BackendVerifyFails++
return nil, ErrorInvalidSignature
}
@@ -159,3 +161,49 @@ func RemoveFromSliceC(ary *[]chan<- ClientMessage, val chan<- ClientMessage) boo
*ary = slice
return true
}
+
+func AddToSliceCl(ary *[]*ClientInfo, val *ClientInfo) bool {
+ slice := *ary
+ for _, v := range slice {
+ if v == val {
+ return false
+ }
+ }
+
+ slice = append(slice, val)
+ *ary = slice
+ return true
+}
+
+func RemoveFromSliceCl(ary *[]*ClientInfo, val *ClientInfo) bool {
+ slice := *ary
+ var idx int = -1
+ for i, v := range slice {
+ if v == val {
+ idx = i
+ break
+ }
+ }
+ if idx == -1 {
+ return false
+ }
+
+ slice[idx] = slice[len(slice)-1]
+ slice = slice[:len(slice)-1]
+ *ary = slice
+ return true
+}
+
+func AddToSliceB(ary *[]bunchSubscriber, client *ClientInfo, mid int) bool {
+ newSub := bunchSubscriber{Client: client, MessageID: mid}
+ slice := *ary
+ for _, v := range slice {
+ if v == newSub {
+ return false
+ }
+ }
+
+ slice = append(slice, newSub)
+ *ary = slice
+ return true
+}