diff --git a/socketserver/server/commands.go b/socketserver/server/commands.go
index 665256c9..0e8004c9 100644
--- a/socketserver/server/commands.go
+++ b/socketserver/server/commands.go
@@ -103,6 +103,8 @@ func C2SHello(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg
 		client.ClientID = uuid.NewV4()
 	}
 
+	uniqueUserChannel <- client.ClientID
+
 	SubscribeGlobal(client)
 	SubscribeDefaults(client)
 
diff --git a/socketserver/server/handlecore.go b/socketserver/server/handlecore.go
index c61dfa32..451b9223 100644
--- a/socketserver/server/handlecore.go
+++ b/socketserver/server/handlecore.go
@@ -85,8 +85,9 @@ func SetupServerAndHandle(config *ConfigFile, serveMux *http.ServeMux) {
 	BannerHTML = bannerBytes
 
 	serveMux.HandleFunc("/", HTTPHandleRootURL)
-	serveMux.Handle("/.well-known/", http.FileServer(http.FileSystem(http.Dir("/tmp/letsencrypt/"))))
+	serveMux.Handle("/.well-known/", http.FileServer(http.Dir("/tmp/letsencrypt/")))
 	serveMux.HandleFunc("/stats", HTTPShowStatistics)
+	serveMux.HandleFunc("/hll/", HTTPShowHLL)
 	serveMux.HandleFunc("/drop_backlog", HTTPBackendDropBacklog)
 	serveMux.HandleFunc("/uncached_pub", HTTPBackendUncachedPublish)
 
@@ -130,13 +131,22 @@ func startJanitors() {
 
 func shutdownHandler() {
 	ch := make(chan os.Signal)
 	signal.Notify(ch, syscall.SIGUSR1)
+	signal.Notify(ch, syscall.SIGTERM)
 	<-ch
 	log.Println("Shutting down...")
 
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		writeAllHLLs()
+		wg.Done()
+	}()
+
 	StopAcceptingConnections = true
 	close(StopAcceptingConnectionsCh)
 	time.Sleep(1 * time.Second)
+	wg.Wait()
 	os.Exit(0)
 }
 
diff --git a/socketserver/server/stats.go b/socketserver/server/stats.go
index 115cfefb..0275d4d7 100644
--- a/socketserver/server/stats.go
+++ b/socketserver/server/stats.go
@@ -165,6 +165,10 @@ func updateSysMem() {
 	if err == nil {
 		Statistics.SysMemTotalKB = memInfo.MemTotal
 		Statistics.SysMemFreeKB = memInfo.MemAvailable
+
+		if memInfo.MemAvailable > 0 && memInfo.MemAvailable < Configuration.MinMemoryKBytes {
+			writeAllHLLs()
+		}
 	}
 }
 
diff --git a/socketserver/server/subscriptions_test.go b/socketserver/server/subscriptions_test.go
index 27ed6fe3..d2e561d2 100644
--- a/socketserver/server/subscriptions_test.go
+++ b/socketserver/server/subscriptions_test.go
@@ -228,6 +228,13 @@ func TestSubscriptionAndPublish(t *testing.T) {
 
 	doneWg.Wait()
 	server.Close()
+
+	for _, period := range periods {
+		clientCount := readHLL(period)
+		if clientCount < 3 || clientCount > 5 {
+			t.Error("clientCount outside acceptable range: expected 4, got ", clientCount)
+		}
+	}
 }
 
 func TestRestrictedCommands(t *testing.T) {
diff --git a/socketserver/server/testinfra_test.go b/socketserver/server/testinfra_test.go
index 07fb063c..9c03fa78 100644
--- a/socketserver/server/testinfra_test.go
+++ b/socketserver/server/testinfra_test.go
@@ -64,6 +64,7 @@ func TSetup(flags int, backendChecker *TBackendRequestChecker) (socketserver *ht
 	if flags&SetupWantSocketServer != 0 {
 		serveMux := http.NewServeMux()
 		SetupServerAndHandle(conf, serveMux)
+		dumpUniqueUsers()
 		socketserver = httptest.NewServer(serveMux)
 	}
 
@@ -344,3 +345,11 @@ func TGetUrls(socketserver *httptest.Server, backend *httptest.Server) TURLs {
 		SavePubMsg: fmt.Sprintf("http://%s/cached_pub", addr),
 	}
 }
+
+func TCheckHLLValue(tb testing.TB, expected uint64, actual uint64) {
+	high := uint64(float64(expected) * 1.05)
+	low := uint64(float64(expected) * 0.95)
+	if actual < low || actual > high {
+		tb.Errorf("Count outside expected range. Expected %d, Got %d", expected, actual)
+	}
+}
diff --git a/socketserver/server/usercount.go b/socketserver/server/usercount.go
new file mode 100644
index 00000000..9ac614fa
--- /dev/null
+++ b/socketserver/server/usercount.go
@@ -0,0 +1,271 @@
+package server
+
+import (
+	"bytes"
+	"encoding/base64"
+	"encoding/binary"
+	"encoding/gob"
+	"fmt"
+	"io/ioutil"
+	"log"
+	"net/http"
+	"os"
+	"time"
+
+	"github.com/clarkduvall/hyperloglog"
+	"github.com/satori/go.uuid"
+)
+
+// uuidHash implements a hash for uuid.UUID by XORing the random bits.
+type uuidHash uuid.UUID
+
+func (u uuidHash) Sum64() uint64 {
+	var valLow, valHigh uint64
+	valLow = binary.LittleEndian.Uint64(u[0:8])
+	valHigh = binary.LittleEndian.Uint64(u[8:16])
+	return valLow ^ valHigh
+}
+
+type PeriodUniqueUsers struct {
+	Start   time.Time
+	End     time.Time
+	Counter *hyperloglog.HyperLogLogPlus
+}
+
+type usageToken struct{}
+
+const (
+	periodDaily = iota
+	periodWeekly
+	periodMonthly
+)
+
+var periods [3]int = [3]int{periodDaily, periodWeekly, periodMonthly}
+
+const uniqCountDir = "./uniques"
+const usersDailyFmt = "daily-%d-%d-%d.gob" // d-m-y
+const usersWeeklyFmt = "weekly-%d-%d.gob"  // w-y
+const usersMonthlyFmt = "monthly-%d-%d.gob" // m-y
+const counterPrecision uint8 = 12
+
+var uniqueCounters [3]PeriodUniqueUsers
+var uniqueUserChannel chan uuid.UUID
+var uniqueCtrWritingToken chan usageToken
+
+var counterLocation *time.Location = time.FixedZone("UTC-5", int((time.Hour*-5)/time.Second))
+
+// getCounterPeriod calculates the start and end timestamps for the HLL measurement period that includes the 'at' timestamp.
+func getCounterPeriod(which int, at time.Time) (start time.Time, end time.Time) {
+	year, month, day := at.Date()
+
+	switch which {
+	case periodDaily:
+		start = time.Date(year, month, day, 0, 0, 0, 0, counterLocation)
+		end = time.Date(year, month, day+1, 0, 0, 0, 0, counterLocation)
+	case periodWeekly:
+		dayOffset := at.Weekday() - time.Sunday
+		start = time.Date(year, month, day-int(dayOffset), 0, 0, 0, 0, counterLocation)
+		end = time.Date(year, month, day-int(dayOffset)+7, 0, 0, 0, 0, counterLocation)
+	case periodMonthly:
+		start = time.Date(year, month, 1, 0, 0, 0, 0, counterLocation)
+		end = time.Date(year, month+1, 1, 0, 0, 0, 0, counterLocation)
+	}
+	return start, end
+}
+
+// getHLLFilename returns the filename for the saved HLL whose measurement period covers the given time.
+func getHLLFilename(which int, at time.Time) string {
+	var filename string
+	switch which {
+	case periodDaily:
+		year, month, day := at.Date()
+		filename = fmt.Sprintf(usersDailyFmt, day, month, year)
+	case periodWeekly:
+		year, week := at.ISOWeek()
+		filename = fmt.Sprintf(usersWeeklyFmt, week, year)
+	case periodMonthly:
+		year, month, _ := at.Date()
+		filename = fmt.Sprintf(usersMonthlyFmt, month, year)
+	}
+	return fmt.Sprintf("%s/%s", uniqCountDir, filename)
+}
+
+// loadHLL loads an HLL from disk and stores the result in dest.Counter.
+// If dest.Counter is nil, it will be initialized. (This is a useful side-effect.)
+// If dest is one of the uniqueCounters, the usageToken must be held.
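+// A missing file is reported through the returned error; callers check os.IsNotExist and start a fresh counter instead.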
+func loadHLL(which int, at time.Time, dest *PeriodUniqueUsers) error {
+	fileBytes, err := ioutil.ReadFile(getHLLFilename(which, at))
+	if err != nil {
+		return err
+	}
+
+	if dest.Counter == nil {
+		dest.Counter, _ = hyperloglog.NewPlus(counterPrecision)
+	}
+
+	dec := gob.NewDecoder(bytes.NewReader(fileBytes))
+	err = dec.Decode(dest.Counter)
+	if err != nil {
+		log.Panicln(err)
+		return err
+	}
+	return nil
+}
+
+// writeHLL writes the indicated HLL to disk.
+// The function takes the usageToken.
+func writeHLL(which int) error {
+	token := <-uniqueCtrWritingToken
+	result := writeHLL_do(which)
+	uniqueCtrWritingToken <- token
+	return result
+}
+
+// writeHLL_do writes out the HLL indicated by `which` to disk.
+// The usageToken must be held when calling this function.
+func writeHLL_do(which int) error {
+	counter := uniqueCounters[which]
+	filename := getHLLFilename(which, counter.Start)
+	file, err := os.Create(filename)
+	if err != nil {
+		return err
+	}
+	enc := gob.NewEncoder(file)
+	enc.Encode(counter.Counter)
+	return file.Close()
+}
+
+// readHLL reads the current value of the indicated HLL counter.
+// The function takes the usageToken.
+func readHLL(which int) uint64 {
+	token := <-uniqueCtrWritingToken
+	result := uniqueCounters[which].Counter.Count()
+	uniqueCtrWritingToken <- token
+	return result
+}
+
+// writeAllHLLs writes out all in-memory HLLs to disk.
+// The function takes the usageToken.
+func writeAllHLLs() error {
+	var err, err2 error
+	token := <-uniqueCtrWritingToken
+	for _, period := range periods {
+		err2 = writeHLL_do(period)
+		if err == nil {
+			err = err2
+		}
+	}
+	uniqueCtrWritingToken <- token
+	return err
+}
+
+var hllFileServer = http.StripPrefix("/hll", http.FileServer(http.Dir(uniqCountDir)))
+
+// HTTPShowHLL serves the saved HLL files over HTTP.
+func HTTPShowHLL(w http.ResponseWriter, r *http.Request) {
+	hllFileServer.ServeHTTP(w, r)
+}
+
+// loadUniqueUsers loads the previous HLLs into memory.
+// is_init_func
+func loadUniqueUsers() {
+	err := os.MkdirAll(uniqCountDir, 0755)
+	if err != nil {
+		log.Panicln("could not make unique users data dir:", err)
+	}
+
+	now := time.Now().In(counterLocation)
+	for _, period := range periods {
+		uniqueCounters[period].Start, uniqueCounters[period].End = getCounterPeriod(period, now)
+		err := loadHLL(period, now, &uniqueCounters[period])
+		if err != nil && os.IsNotExist(err) {
+			// NewPlus only errors on a bad precision, so the error can be ignored
+			uniqueCounters[period].Counter, _ = hyperloglog.NewPlus(counterPrecision)
+		} else if err != nil && !os.IsNotExist(err) {
+			log.Panicln("failed to load unique users data:", err)
+		}
+	}
+
+	uniqueUserChannel = make(chan uuid.UUID)
+	uniqueCtrWritingToken = make(chan usageToken)
+	go processNewUsers()
+	go rolloverCounters()
+	uniqueCtrWritingToken <- usageToken{}
+}
+
+// dumpUniqueUsers discards all the data in uniqueCounters.
+func dumpUniqueUsers() {
+	token := <-uniqueCtrWritingToken
+
+	for _, period := range periods {
+		uniqueCounters[period].Counter.Clear()
+	}
+
+	uniqueCtrWritingToken <- token
+}
+
+// processNewUsers reads uniqueUserChannel, and also dispatches the writing token.
+// This function is the primary writer of uniqueCounters, so it makes sense for it to hold the token.
+// is_init_func
+func processNewUsers() {
+	token := <-uniqueCtrWritingToken
+
+	for {
+		select {
+		case u := <-uniqueUserChannel:
+			hashed := uuidHash(u)
+			for _, period := range periods {
+				uniqueCounters[period].Counter.Add(hashed)
+			}
+		case uniqueCtrWritingToken <- token:
+			// relinquish token. important that there is only one of this going on
+			// otherwise we thrash
+			token = <-uniqueCtrWritingToken
+		}
+	}
+}
+
+func getNextMidnight() time.Time {
+	now := time.Now().In(counterLocation)
+	year, month, day := now.Date()
+	return time.Date(year, month, day+1, 0, 0, 1, 0, counterLocation)
+}
+
+// is_init_func
+func rolloverCounters() {
+	for {
+		time.Sleep(getNextMidnight().Sub(time.Now()))
+		rolloverCounters_do()
+	}
+}
+
+func rolloverCounters_do() {
+	var token usageToken
+	var now time.Time
+
+	token = <-uniqueCtrWritingToken
+	now = time.Now().In(counterLocation)
+	for _, period := range periods {
+		if now.After(uniqueCounters[period].End) {
+			// Cycle for period
+			err := writeHLL_do(period)
+			if err != nil {
+				log.Println("could not cycle unique user counter:", err)
+
+				// Attempt to rescue the data into the log
+				var buf bytes.Buffer
+				bytes, err := uniqueCounters[period].Counter.GobEncode()
+				if err == nil {
+					enc := base64.NewEncoder(base64.StdEncoding, &buf)
+					enc.Write(bytes)
+					enc.Close()
+					log.Print("data for ", getHLLFilename(period, now), ":", buf.String())
+				}
+			}
+
+			uniqueCounters[period].Start, uniqueCounters[period].End = getCounterPeriod(period, now)
+			// NewPlus only errors on a bad precision, so we can ignore it
+			uniqueCounters[period].Counter, _ = hyperloglog.NewPlus(counterPrecision)
+		}
+	}
+
+	uniqueCtrWritingToken <- token
+}
diff --git a/socketserver/server/usercount_test.go b/socketserver/server/usercount_test.go
new file mode 100644
index 00000000..4b436d49
--- /dev/null
+++ b/socketserver/server/usercount_test.go
@@ -0,0 +1,70 @@
+package server
+
+import (
+	"github.com/satori/go.uuid"
+	"net/http/httptest"
+	"net/url"
+	"os"
+	"testing"
+	"time"
+)
+
+func TestUniqueConnections(t *testing.T) {
+	const TestExpectedCount = 1000
+
+	testStart := time.Now().In(counterLocation)
+
+	var server *httptest.Server
+	var backendExpected = NewTBackendRequestChecker(t,
+		TExpectedBackendRequest{200, bPathAnnounceStartup, &url.Values{"startup": []string{"1"}}, "", nil},
+	)
+	server, _, _ = TSetup(SetupWantSocketServer|SetupWantBackendServer, backendExpected)
+
+	defer server.CloseClientConnections()
+	defer unsubscribeAllClients()
+	defer backendExpected.Close()
+
+	dumpUniqueUsers()
+
+	for i := 0; i < TestExpectedCount; i++ {
+		uuid := uuid.NewV4()
+		uniqueUserChannel <- uuid
+		uniqueUserChannel <- uuid
+	}
+
+	TCheckHLLValue(t, TestExpectedCount, readHLL(periodDaily))
+	TCheckHLLValue(t, TestExpectedCount, readHLL(periodWeekly))
+	TCheckHLLValue(t, TestExpectedCount, readHLL(periodMonthly))
+
+	token := <-uniqueCtrWritingToken
+	uniqueCounters[periodDaily].End = time.Now().In(counterLocation).Add(-1 * time.Second)
+	uniqueCtrWritingToken <- token
+
+	rolloverCounters_do()
+
+	for i := 0; i < TestExpectedCount; i++ {
+		uuid := uuid.NewV4()
+		uniqueUserChannel <- uuid
+		uniqueUserChannel <- uuid
+	}
+
+	TCheckHLLValue(t, TestExpectedCount, readHLL(periodDaily))
+	TCheckHLLValue(t, TestExpectedCount*2, readHLL(periodWeekly))
+	TCheckHLLValue(t, TestExpectedCount*2, readHLL(periodMonthly))
+
+	// Check: Merging the two days results in 2000
+	// note: rolloverCounters_do() wrote out a file, and loadHLL() is reading it back
+	var loadDest PeriodUniqueUsers
+	loadHLL(periodDaily, testStart, &loadDest)
+
+	token = <-uniqueCtrWritingToken
+	loadDest.Counter.Merge(uniqueCounters[periodDaily].Counter)
+	uniqueCtrWritingToken <- token
+
+	TCheckHLLValue(t, TestExpectedCount*2, loadDest.Counter.Count())
+}
+
+func TestUniqueUsersCleanup(t *testing.T) {
+	// Not a test. Removes old files.
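+	// Declared last so it runs after TestUniqueConnections and cleans up the uniques directory.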
+	os.RemoveAll(uniqCountDir)
+}