1
0
Fork 0
mirror of https://github.com/FrankerFaceZ/FrankerFaceZ.git synced 2025-06-28 05:15:54 +00:00

Merge branch 'master' of github.com:FrankerFaceZ/FrankerFaceZ

# Conflicts:
#	socketserver/server/utils.go
This commit is contained in:
SirStendec 2017-09-26 17:02:35 -04:00
commit 7816a33e5d
14 changed files with 458 additions and 315 deletions

View file

@ -0,0 +1,72 @@
// Copyright 2016 Michael Stapelberg, BSD-3
//
// https://stackoverflow.com/a/40883377/1210278
package certreloader
import (
"crypto/tls"
"log"
"os"
"os/signal"
"sync"
)
type CertSource struct {
certMu sync.RWMutex
cert *tls.Certificate
certPath string
keyPath string
}
// Create a CertSource
func New(certPath, keyPath string) (*CertSource, error) {
result := &CertSource{
certPath: certPath,
keyPath: keyPath,
}
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, err
}
result.cert = &cert
return result, nil
}
// Automatically reload certificate on the provided signal
func (kpr *CertSource) AutoCheck(sig os.Signal) {
go func() {
c := make(chan os.Signal, 1)
signal.Notify(c, sig)
for range c {
log.Printf("Received %v, reloading TLS certificate and key from %q and %q", sig, kpr.certPath, kpr.keyPath)
if err := kpr.maybeReload(); err != nil {
log.Printf("Keeping old TLS certificate because the new one could not be loaded: %v", err)
}
}
}()
}
// Check() can be called manually to reload the certificate
func (kpr *CertSource) Check() error {
return kpr.maybeReload()
}
func (kpr *CertSource) maybeReload() error {
newCert, err := tls.LoadX509KeyPair(kpr.certPath, kpr.keyPath)
if err != nil {
return err
}
kpr.certMu.Lock()
defer kpr.certMu.Unlock()
kpr.cert = &newCert
return nil
}
// Returns a tls.Config.GetCertificate function.
func (kpr *CertSource) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
kpr.certMu.RLock()
defer kpr.certMu.RUnlock()
return kpr.cert, nil
}
}

View file

@ -1,6 +1,10 @@
package main // import "github.com/FrankerFaceZ/FrankerFaceZ/socketserver/cmd/ffzsocketserver" package main // import "github.com/FrankerFaceZ/FrankerFaceZ/socketserver/cmd/ffzsocketserver"
import _ "net/http/pprof"
import ( import (
"context"
"crypto/tls"
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
@ -8,12 +12,15 @@ import (
"log" "log"
"net/http" "net/http"
"os" "os"
"os/signal"
"sync"
"syscall"
"time"
"github.com/FrankerFaceZ/FrankerFaceZ/socketserver/certreloader"
"github.com/FrankerFaceZ/FrankerFaceZ/socketserver/server" "github.com/FrankerFaceZ/FrankerFaceZ/socketserver/server"
) )
import _ "net/http/pprof"
var configFilename = flag.String("config", "config.json", "Configuration file, including the keypairs for the NaCl crypto library, for communicating with the backend.") var configFilename = flag.String("config", "config.json", "Configuration file, including the keypairs for the NaCl crypto library, for communicating with the backend.")
var flagGenerateKeys = flag.Bool("genkeys", false, "Generate NaCl keys instead of serving requests.\nArguments: [int serverId] [base64 backendPublic]\nThe backend public key can either be specified in base64 on the command line, or put in the json file later.") var flagGenerateKeys = flag.Bool("genkeys", false, "Generate NaCl keys instead of serving requests.\nArguments: [int serverId] [base64 backendPublic]\nThe backend public key can either be specified in base64 on the command line, or put in the json file later.")
@ -56,17 +63,72 @@ func main() {
go commandLineConsole() go commandLineConsole()
var server1, server2 *http.Server
stopSig := make(chan os.Signal, 3)
signal.Notify(stopSig, os.Interrupt)
signal.Notify(stopSig, syscall.SIGUSR1)
signal.Notify(stopSig, syscall.SIGTERM)
if conf.UseSSL { if conf.UseSSL {
reloader, err := certreloader.New(conf.SSLCertificateFile, conf.SSLKeyFile)
if err != nil {
log.Fatalln("Could not load TLS certificate:", err)
}
reloader.AutoCheck(syscall.SIGHUP)
server1 = &http.Server{
Addr: conf.SSLListenAddr,
Handler: http.DefaultServeMux,
TLSConfig: &tls.Config{
GetCertificate: reloader.GetCertificateFunc(),
},
}
go func() { go func() {
if err := http.ListenAndServeTLS(conf.SSLListenAddr, conf.SSLCertificateFile, conf.SSLKeyFile, http.DefaultServeMux); err != nil { if err := server1.ListenAndServeTLS("", ""); err != nil {
log.Fatal("ListenAndServeTLS: ", err) log.Println("ListenAndServeTLS:", err)
stopSig <- os.Interrupt
} }
}() }()
} }
if err = http.ListenAndServe(conf.ListenAddr, http.DefaultServeMux); err != nil { if true {
log.Fatal("ListenAndServe: ", err) server2 = &http.Server{
Addr: conf.ListenAddr,
Handler: http.DefaultServeMux,
}
go func() {
if err := server2.ListenAndServe(); err != nil {
log.Println("ListenAndServe: ", err)
stopSig <- os.Interrupt
}
}()
} }
<-stopSig
log.Println("Shutting down...")
var wg sync.WaitGroup
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
wg.Add(1)
go func() {
defer wg.Done()
if conf.UseSSL {
server1.Shutdown(ctx)
}
}()
wg.Add(1)
go func() {
defer wg.Done()
server2.Shutdown(ctx)
}()
server.Shutdown(&wg)
time.Sleep(1 * time.Second)
wg.Wait()
} }
func generateKeys(outputFile string) { func generateKeys(outputFile string) {

View file

@ -15,8 +15,10 @@ import (
"sync" "sync"
"time" "time"
"github.com/FrankerFaceZ/FrankerFaceZ/socketserver/server/naclform"
cache "github.com/patrickmn/go-cache" cache "github.com/patrickmn/go-cache"
"golang.org/x/crypto/nacl/box" "golang.org/x/crypto/nacl/box"
"golang.org/x/sync/singleflight"
) )
const bPathAnnounceStartup = "/startup" const bPathAnnounceStartup = "/startup"
@ -28,13 +30,13 @@ type backendInfo struct {
HTTPClient http.Client HTTPClient http.Client
baseURL string baseURL string
responseCache *cache.Cache responseCache *cache.Cache
reloadGroup singleflight.Group
postStatisticsURL string postStatisticsURL string
addTopicURL string addTopicURL string
announceStartupURL string announceStartupURL string
sharedKey [32]byte secureForm naclform.ServerInfo
serverID int
lastSuccess map[string]time.Time lastSuccess map[string]time.Time
lastSuccessLock sync.Mutex lastSuccessLock sync.Mutex
@ -45,11 +47,12 @@ var Backend *backendInfo
func setupBackend(config *ConfigFile) *backendInfo { func setupBackend(config *ConfigFile) *backendInfo {
b := new(backendInfo) b := new(backendInfo)
Backend = b Backend = b
b.serverID = config.ServerID b.secureForm.ServerID = config.ServerID
b.HTTPClient.Timeout = 60 * time.Second b.HTTPClient.Timeout = 60 * time.Second
b.baseURL = config.BackendURL b.baseURL = config.BackendURL
b.responseCache = cache.New(60*time.Second, 120*time.Second) // size in bytes of string payload
b.responseCache = cache.New(60*time.Second, 10*time.Minute)
b.announceStartupURL = fmt.Sprintf("%s%s", b.baseURL, bPathAnnounceStartup) b.announceStartupURL = fmt.Sprintf("%s%s", b.baseURL, bPathAnnounceStartup)
b.addTopicURL = fmt.Sprintf("%s%s", b.baseURL, bPathAddTopic) b.addTopicURL = fmt.Sprintf("%s%s", b.baseURL, bPathAddTopic)
@ -68,7 +71,7 @@ func setupBackend(config *ConfigFile) *backendInfo {
copy(theirPublic[:], config.BackendPublicKey) copy(theirPublic[:], config.BackendPublicKey)
copy(ourPrivate[:], config.OurPrivateKey) copy(ourPrivate[:], config.OurPrivateKey)
box.Precompute(&b.sharedKey, &theirPublic, &ourPrivate) box.Precompute(&b.secureForm.SharedKey, &theirPublic, &ourPrivate)
return b return b
} }
@ -88,12 +91,18 @@ func (bfe ErrForwardedFromBackend) Error() string {
} }
// ErrAuthorizationNeeded is emitted when the backend replies with HTTP 401. // ErrAuthorizationNeeded is emitted when the backend replies with HTTP 401.
//
// Indicates that an attempt to validate `ClientInfo.TwitchUsername` should be attempted. // Indicates that an attempt to validate `ClientInfo.TwitchUsername` should be attempted.
var ErrAuthorizationNeeded = errors.New("Must authenticate Twitch username to use this command") var ErrAuthorizationNeeded = errors.New("Must authenticate Twitch username to use this command")
// SendRemoteCommandCached performs a RPC call on the backend, but caches responses. // SendRemoteCommandCached performs a RPC call on the backend, checking for a
// cached response first.
//
// If a cached, but expired, response is found, the existing value is returned
// and the cache is updated in the background.
func (backend *backendInfo) SendRemoteCommandCached(remoteCommand, data string, auth AuthInfo) (string, error) { func (backend *backendInfo) SendRemoteCommandCached(remoteCommand, data string, auth AuthInfo) (string, error) {
cached, ok := backend.responseCache.Get(getCacheKey(remoteCommand, data)) cacheKey := getCacheKey(remoteCommand, data)
cached, ok := backend.responseCache.Get(cacheKey)
if ok { if ok {
return cached.(string), nil return cached.(string), nil
} }
@ -101,9 +110,21 @@ func (backend *backendInfo) SendRemoteCommandCached(remoteCommand, data string,
} }
// SendRemoteCommand performs a RPC call on the backend by POSTing to `/cmd/$remoteCommand`. // SendRemoteCommand performs a RPC call on the backend by POSTing to `/cmd/$remoteCommand`.
//
// The form data is as follows: `clientData` is the JSON in the `data` parameter // The form data is as follows: `clientData` is the JSON in the `data` parameter
// (should be retrieved from ClientMessage.Arguments), and either `username` or // (should be retrieved from ClientMessage.Arguments), `username` is AuthInfo.TwitchUsername,
// `usernameClaimed` depending on whether AuthInfo.UsernameValidates is true is AuthInfo.TwitchUsername. // and `authenticated` is 1 or 0 depending on AuthInfo.UsernameValidated.
//
// 401 responses return an ErrAuthorizationNeeded.
//
// Non-2xx responses return the response body as an error to the client (application/json
// responses are sent as-is, non-json are sent as a JSON string).
//
// If a 2xx response has the FFZ-Cache header, its value is used as a minimum number of
// seconds to cache the response for. (Responses may be cached for longer, see
// SendRemoteCommandCached and the cache implementation.)
//
// A successful response updates the Statistics.Health.Backend map.
func (backend *backendInfo) SendRemoteCommand(remoteCommand, data string, auth AuthInfo) (responseStr string, err error) { func (backend *backendInfo) SendRemoteCommand(remoteCommand, data string, auth AuthInfo) (responseStr string, err error) {
destURL := fmt.Sprintf("%s/cmd/%s", backend.baseURL, remoteCommand) destURL := fmt.Sprintf("%s/cmd/%s", backend.baseURL, remoteCommand)
healthBucket := fmt.Sprintf("/cmd/%s", remoteCommand) healthBucket := fmt.Sprintf("/cmd/%s", remoteCommand)
@ -119,7 +140,7 @@ func (backend *backendInfo) SendRemoteCommand(remoteCommand, data string, auth A
formData.Set("authenticated", "0") formData.Set("authenticated", "0")
} }
sealedForm, err := backend.SealRequest(formData) sealedForm, err := backend.secureForm.Seal(formData)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -141,9 +162,11 @@ func (backend *backendInfo) SendRemoteCommand(remoteCommand, data string, auth A
return "", ErrAuthorizationNeeded return "", ErrAuthorizationNeeded
} else if resp.StatusCode < 200 || resp.StatusCode > 299 { // any non-2xx } else if resp.StatusCode < 200 || resp.StatusCode > 299 { // any non-2xx
// If the Content-Type header includes a charset, ignore it. // If the Content-Type header includes a charset, ignore it.
// typeStr, _, _ = mime.ParseMediaType(resp.Header.Get("Content-Type"))
// inline the part of the function we care about
typeStr := resp.Header.Get("Content-Type") typeStr := resp.Header.Get("Content-Type")
splitIdx := strings.IndexRune(typeStr, ';') splitIdx := strings.IndexRune(typeStr, ';')
if ( splitIdx != -1 ) { if splitIdx != -1 {
typeStr = strings.TrimSpace(typeStr[0:splitIdx]) typeStr = strings.TrimSpace(typeStr[0:splitIdx])
} }
@ -164,7 +187,11 @@ func (backend *backendInfo) SendRemoteCommand(remoteCommand, data string, auth A
return "", fmt.Errorf("The RPC server returned a non-integer cache duration: %v", err) return "", fmt.Errorf("The RPC server returned a non-integer cache duration: %v", err)
} }
duration := time.Duration(durSecs) * time.Second duration := time.Duration(durSecs) * time.Second
backend.responseCache.Set(getCacheKey(remoteCommand, data), responseStr, duration) backend.responseCache.Set(
getCacheKey(remoteCommand, data),
responseStr,
duration,
)
} }
now := time.Now().UTC() now := time.Now().UTC()
@ -178,7 +205,7 @@ func (backend *backendInfo) SendRemoteCommand(remoteCommand, data string, auth A
// SendAggregatedData sends aggregated emote usage and following data to the backend server. // SendAggregatedData sends aggregated emote usage and following data to the backend server.
func (backend *backendInfo) SendAggregatedData(form url.Values) error { func (backend *backendInfo) SendAggregatedData(form url.Values) error {
sealedForm, err := backend.SealRequest(form) sealedForm, err := backend.secureForm.Seal(form)
if err != nil { if err != nil {
return err return err
} }
@ -235,7 +262,7 @@ func (backend *backendInfo) sendTopicNotice(topic string, added bool) error {
formData.Set("added", "f") formData.Set("added", "f")
} }
sealedForm, err := backend.SealRequest(formData) sealedForm, err := backend.secureForm.Seal(formData)
if err != nil { if err != nil {
return err return err
} }

View file

@ -18,14 +18,14 @@ func TestSealRequest(t *testing.T) {
"QuickBrownFox": []string{"LazyDog"}, "QuickBrownFox": []string{"LazyDog"},
} }
sealedValues, err := b.SealRequest(values) sealedValues, err := b.secureForm.Seal(values)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// sealedValues.Encode() // sealedValues.Encode()
// id=0&msg=KKtbng49dOLLyjeuX5AnXiEe6P0uZwgeP_7mMB5vhP-wMAAPZw%3D%3D&nonce=-wRbUnifscisWUvhm3gBEXHN5QzrfzgV // id=0&msg=KKtbng49dOLLyjeuX5AnXiEe6P0uZwgeP_7mMB5vhP-wMAAPZw%3D%3D&nonce=-wRbUnifscisWUvhm3gBEXHN5QzrfzgV
unsealedValues, err := b.UnsealRequest(sealedValues) unsealedValues, err := b.secureForm.Unseal(sealedValues)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -2,11 +2,9 @@ package server
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"log" "log"
"net/url" "net/url"
"strconv"
"sync" "sync"
"time" "time"
@ -34,14 +32,17 @@ var commandHandlers = map[Command]CommandHandler{
"track_follow": C2STrackFollow, "track_follow": C2STrackFollow,
"emoticon_uses": C2SEmoticonUses, "emoticon_uses": C2SEmoticonUses,
"survey": C2SSurvey, "survey": C2SSurvey,
}
"twitch_emote": C2SHandleBunchedCommand, var bunchedCommands = []Command{
"get_link": C2SHandleBunchedCommand, "get_display_name",
"get_display_name": C2SHandleBunchedCommand, "get_emote",
"get_emote": C2SHandleBunchedCommand, "get_emote_set",
"get_emote_set": C2SHandleBunchedCommand, "get_link",
"has_logs": C2SHandleBunchedCommand, "get_itad_plain",
"update_follow_buttons": C2SHandleRemoteCommand, "get_itad_prices",
"get_name_history",
"has_logs",
} }
func setupInterning() { func setupInterning() {
@ -73,6 +74,12 @@ func DispatchC2SCommand(conn *websocket.Conn, client *ClientInfo, msg ClientMess
handler, ok := commandHandlers[msg.Command] handler, ok := commandHandlers[msg.Command]
if !ok { if !ok {
handler = C2SHandleRemoteCommand handler = C2SHandleRemoteCommand
for _, v := range bunchedCommands {
if msg.Command == v {
handler = C2SHandleBunchedCommand
}
}
} }
CommandCounter <- msg.Command CommandCounter <- msg.Command
@ -96,7 +103,7 @@ func DispatchC2SCommand(conn *websocket.Conn, client *ClientInfo, msg ClientMess
} }
} }
func callHandler(handler CommandHandler, conn *websocket.Conn, client *ClientInfo, cmsg ClientMessage) (rmsg ClientMessage, err error) { func callHandler(handler CommandHandler, conn *websocket.Conn, client *ClientInfo, cmsg ClientMessage) (_ ClientMessage, err error) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
var ok bool var ok bool
@ -112,7 +119,7 @@ func callHandler(handler CommandHandler, conn *websocket.Conn, client *ClientInf
// C2SHello implements the `hello` C2S Command. // C2SHello implements the `hello` C2S Command.
// It calls SubscribeGlobal() and SubscribeDefaults() with the client, and fills out ClientInfo.Version and ClientInfo.ClientID. // It calls SubscribeGlobal() and SubscribeDefaults() with the client, and fills out ClientInfo.Version and ClientInfo.ClientID.
func C2SHello(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { func C2SHello(_ *websocket.Conn, client *ClientInfo, msg ClientMessage) (_ ClientMessage, err error) {
ary, ok := msg.Arguments.([]interface{}) ary, ok := msg.Arguments.([]interface{})
if !ok { if !ok {
err = ErrExpectedTwoStrings err = ErrExpectedTwoStrings
@ -163,16 +170,16 @@ func C2SHello(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg
}, nil }, nil
} }
func C2SPing(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { func C2SPing(*websocket.Conn, *ClientInfo, ClientMessage) (ClientMessage, error) {
return ClientMessage{ return ClientMessage{
Arguments: float64(time.Now().UnixNano()/1000) / 1000, Arguments: float64(time.Now().UnixNano()/1000) / 1000,
}, nil }, nil
} }
func C2SSetUser(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { func C2SSetUser(_ *websocket.Conn, client *ClientInfo, msg ClientMessage) (ClientMessage, error) {
username, err := msg.ArgumentsAsString() username, err := msg.ArgumentsAsString()
if err != nil { if err != nil {
return return ClientMessage{}, err
} }
username = copyString(username) username = copyString(username)
@ -192,29 +199,24 @@ func C2SSetUser(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rm
return ResponseSuccess, nil return ResponseSuccess, nil
} }
func C2SReady(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { func C2SReady(_ *websocket.Conn, client *ClientInfo, msg ClientMessage) (ClientMessage, error) {
// disconnectAt, err := msg.ArgumentsAsInt()
// if err != nil {
// return
// }
client.Mutex.Lock() client.Mutex.Lock()
client.ReadyComplete = true client.ReadyComplete = true
client.Mutex.Unlock() client.Mutex.Unlock()
client.MsgChannelKeepalive.Add(1) client.MsgChannelKeepalive.Add(1)
go func() { go func() {
client.Send(ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand}) client.Send(msg.Reply(SuccessCommand, nil))
SendBacklogForNewClient(client) SendBacklogForNewClient(client)
client.MsgChannelKeepalive.Done() client.MsgChannelKeepalive.Done()
}() }()
return ClientMessage{Command: AsyncResponseCommand}, nil return ClientMessage{Command: AsyncResponseCommand}, nil
} }
func C2SSubscribe(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { func C2SSubscribe(_ *websocket.Conn, client *ClientInfo, msg ClientMessage) (ClientMessage, error) {
channel, err := msg.ArgumentsAsString() channel, err := msg.ArgumentsAsString()
if err != nil { if err != nil {
return return ClientMessage{}, err
} }
channel = PubSubChannelPool.Intern(channel) channel = PubSubChannelPool.Intern(channel)
@ -238,10 +240,10 @@ func C2SSubscribe(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (
// C2SUnsubscribe implements the `unsub` C2S Command. // C2SUnsubscribe implements the `unsub` C2S Command.
// It removes the channel from ClientInfo.CurrentChannels and calls UnsubscribeSingleChat. // It removes the channel from ClientInfo.CurrentChannels and calls UnsubscribeSingleChat.
func C2SUnsubscribe(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { func C2SUnsubscribe(_ *websocket.Conn, client *ClientInfo, msg ClientMessage) (ClientMessage, error) {
channel, err := msg.ArgumentsAsString() channel, err := msg.ArgumentsAsString()
if err != nil { if err != nil {
return return ClientMessage{}, err
} }
channel = PubSubChannelPool.Intern(channel) channel = PubSubChannelPool.Intern(channel)
@ -256,9 +258,8 @@ func C2SUnsubscribe(conn *websocket.Conn, client *ClientInfo, msg ClientMessage)
} }
// C2SSurvey implements the survey C2S Command. // C2SSurvey implements the survey C2S Command.
// Surveys are discarded.s func C2SSurvey(*websocket.Conn, *ClientInfo, ClientMessage) (ClientMessage, error) {
func C2SSurvey(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { // Surveys are not collected.
// Discard
return ResponseSuccess, nil return ResponseSuccess, nil
} }
@ -276,7 +277,7 @@ var followEventsLock sync.Mutex
// C2STrackFollow implements the `track_follow` C2S Command. // C2STrackFollow implements the `track_follow` C2S Command.
// It adds the record to `followEvents`, which is submitted to the backend on a timer. // It adds the record to `followEvents`, which is submitted to the backend on a timer.
func C2STrackFollow(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { func C2STrackFollow(_ *websocket.Conn, client *ClientInfo, msg ClientMessage) (_ ClientMessage, err error) {
channel, following, err := msg.ArgumentsAsStringAndBool() channel, following, err := msg.ArgumentsAsStringAndBool()
if err != nil { if err != nil {
return return
@ -293,68 +294,18 @@ func C2STrackFollow(conn *websocket.Conn, client *ClientInfo, msg ClientMessage)
} }
// AggregateEmoteUsage is a map from emoteID to a map from chatroom name to usage count. // AggregateEmoteUsage is a map from emoteID to a map from chatroom name to usage count.
var aggregateEmoteUsage = make(map[int]map[string]int) //var aggregateEmoteUsage = make(map[int]map[string]int)
// AggregateEmoteUsageLock is the lock for AggregateEmoteUsage. // AggregateEmoteUsageLock is the lock for AggregateEmoteUsage.
var aggregateEmoteUsageLock sync.Mutex //var aggregateEmoteUsageLock sync.Mutex
// ErrNegativeEmoteUsage is emitted when the submitted emote usage is negative. // ErrNegativeEmoteUsage is emitted when the submitted emote usage is negative.
var ErrNegativeEmoteUsage = errors.New("Emote usage count cannot be negative") //var ErrNegativeEmoteUsage = errors.New("Emote usage count cannot be negative")
// C2SEmoticonUses implements the `emoticon_uses` C2S Command. // C2SEmoticonUses implements the `emoticon_uses` C2S Command.
// msg.Arguments are in the JSON format of [1]map[emoteID]map[ChatroomName]float64. // msg.Arguments are in the JSON format of [1]map[emoteID]map[ChatroomName]float64.
func C2SEmoticonUses(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { func C2SEmoticonUses(*websocket.Conn, *ClientInfo, ClientMessage) (ClientMessage, error) {
// if this panics, will be caught by callHandler // We do not collect emote usage data
mapRoot := msg.Arguments.([]interface{})[0].(map[string]interface{})
// Validate: male suire
for strEmote, val1 := range mapRoot {
_, err = strconv.Atoi(strEmote)
if err != nil {
return
}
mapInner := val1.(map[string]interface{})
for _, val2 := range mapInner {
var count = int(val2.(float64))
if count <= 0 {
err = ErrNegativeEmoteUsage
return
}
}
}
aggregateEmoteUsageLock.Lock()
defer aggregateEmoteUsageLock.Unlock()
var total int
for strEmote, val1 := range mapRoot {
var emoteID int
emoteID, err = strconv.Atoi(strEmote)
if err != nil {
return
}
destMapInner, ok := aggregateEmoteUsage[emoteID]
if !ok {
destMapInner = make(map[string]int)
aggregateEmoteUsage[emoteID] = destMapInner
}
mapInner := val1.(map[string]interface{})
for roomName, val2 := range mapInner {
var count = int(val2.(float64))
if count > 200 {
count = 200
}
roomName = TwitchChannelPool.Intern(roomName)
destMapInner[roomName] += count
total += count
}
}
Statistics.EmotesReportedTotal += uint64(total)
return ResponseSuccess, nil return ResponseSuccess, nil
} }
@ -371,10 +322,10 @@ func aggregateDataSender_do() {
follows := followEvents follows := followEvents
followEvents = nil followEvents = nil
followEventsLock.Unlock() followEventsLock.Unlock()
aggregateEmoteUsageLock.Lock() //aggregateEmoteUsageLock.Lock()
emoteUsage := aggregateEmoteUsage //emoteUsage := aggregateEmoteUsage
aggregateEmoteUsage = make(map[int]map[string]int) //aggregateEmoteUsage = make(map[int]map[string]int)
aggregateEmoteUsageLock.Unlock() //aggregateEmoteUsageLock.Unlock()
reportForm := url.Values{} reportForm := url.Values{}
@ -386,10 +337,10 @@ func aggregateDataSender_do() {
} }
strEmoteUsage := make(map[string]map[string]int) strEmoteUsage := make(map[string]map[string]int)
for emoteID, usageByChannel := range emoteUsage { //for emoteID, usageByChannel := range emoteUsage {
strEmoteID := strconv.Itoa(emoteID) // strEmoteID := strconv.Itoa(emoteID)
strEmoteUsage[strEmoteID] = usageByChannel // strEmoteUsage[strEmoteID] = usageByChannel
} //}
emoteJSON, err := json.Marshal(strEmoteUsage) emoteJSON, err := json.Marshal(strEmoteUsage)
if err != nil { if err != nil {
log.Println("error reporting aggregate data:", err) log.Println("error reporting aggregate data:", err)
@ -429,7 +380,7 @@ var bunchGroup singleflight.Group
// C2SHandleBunchedCommand handles C2S Commands such as `get_link`. // C2SHandleBunchedCommand handles C2S Commands such as `get_link`.
// It makes a request to the backend server for the data, but any other requests coming in while the first is pending also get the responses from the first one. // It makes a request to the backend server for the data, but any other requests coming in while the first is pending also get the responses from the first one.
func C2SHandleBunchedCommand(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (ClientMessage, error) { func C2SHandleBunchedCommand(_ *websocket.Conn, client *ClientInfo, msg ClientMessage) (ClientMessage, error) {
key := fmt.Sprintf("%s:%s", msg.Command, msg.origArguments) key := fmt.Sprintf("%s:%s", msg.Command, msg.origArguments)
resultCh := bunchGroup.DoChan(key, func() (interface{}, error) { resultCh := bunchGroup.DoChan(key, func() (interface{}, error) {
@ -439,29 +390,21 @@ func C2SHandleBunchedCommand(conn *websocket.Conn, client *ClientInfo, msg Clien
client.MsgChannelKeepalive.Add(1) client.MsgChannelKeepalive.Add(1)
go func() { go func() {
result := <-resultCh result := <-resultCh
var reply ClientMessage if efb, ok := result.Err.(ErrForwardedFromBackend); ok {
reply.MessageID = msg.MessageID client.Send(msg.Reply(ErrorCommand, efb.JSONError))
if result.Err != nil { } else if result.Err != nil {
reply.Command = ErrorCommand client.Send(msg.Reply(ErrorCommand, result.Err.Error()))
if efb, ok := result.Err.(ErrForwardedFromBackend); ok {
reply.Arguments = efb.JSONError
} else {
reply.Arguments = result.Err.Error()
}
} else { } else {
reply.Command = SuccessCommand client.Send(msg.ReplyJSON(SuccessCommand, result.Val.(string)))
reply.origArguments = result.Val.(string)
reply.parseOrigArguments()
} }
client.Send(reply)
client.MsgChannelKeepalive.Done() client.MsgChannelKeepalive.Done()
}() }()
return ClientMessage{Command: AsyncResponseCommand}, nil return ClientMessage{Command: AsyncResponseCommand}, nil
} }
func C2SHandleRemoteCommand(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { func C2SHandleRemoteCommand(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (ClientMessage, error) {
client.MsgChannelKeepalive.Add(1) client.MsgChannelKeepalive.Add(1)
go doRemoteCommand(conn, msg, client) go doRemoteCommand(conn, msg, client)
@ -477,7 +420,7 @@ func doRemoteCommand(conn *websocket.Conn, msg ClientMessage, client *ClientInfo
if err == ErrAuthorizationNeeded { if err == ErrAuthorizationNeeded {
if client.TwitchUsername == "" { if client.TwitchUsername == "" {
// Not logged in // Not logged in
client.Send(ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationNeededError}) client.Send(msg.Reply(ErrorCommand, AuthorizationNeededError))
client.MsgChannelKeepalive.Done() client.MsgChannelKeepalive.Done()
return return
} }
@ -485,19 +428,17 @@ func doRemoteCommand(conn *websocket.Conn, msg ClientMessage, client *ClientInfo
if success { if success {
doRemoteCommand(conn, msg, client) doRemoteCommand(conn, msg, client)
} else { } else {
client.Send(ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: AuthorizationFailedErrorString}) client.Send(msg.Reply(ErrorCommand, AuthorizationFailedErrorString))
client.MsgChannelKeepalive.Done() client.MsgChannelKeepalive.Done()
} }
}) })
return // without keepalive.Done() return // without keepalive.Done()
} else if bfe, ok := err.(ErrForwardedFromBackend); ok { } else if bfe, ok := err.(ErrForwardedFromBackend); ok {
client.Send(ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: bfe.JSONError}) client.Send(msg.Reply(ErrorCommand, bfe.JSONError))
} else if err != nil { } else if err != nil {
client.Send(ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: err.Error()}) client.Send(msg.Reply(ErrorCommand, err.Error()))
} else { } else {
msg := ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand, origArguments: resp} client.Send(msg.ReplyJSON(SuccessCommand, resp))
msg.parseOrigArguments()
client.Send(msg)
} }
client.MsgChannelKeepalive.Done() client.MsgChannelKeepalive.Done()
} }

View file

@ -103,8 +103,9 @@ func SetupServerAndHandle(config *ConfigFile, serveMux *http.ServeMux) {
serveMux.HandleFunc("/uncached_pub", HTTPBackendUncachedPublish) serveMux.HandleFunc("/uncached_pub", HTTPBackendUncachedPublish)
serveMux.HandleFunc("/cached_pub", HTTPBackendCachedPublish) serveMux.HandleFunc("/cached_pub", HTTPBackendCachedPublish)
serveMux.HandleFunc("/get_sub_count", HTTPGetSubscriberCount) serveMux.HandleFunc("/get_sub_count", HTTPGetSubscriberCount)
serveMux.HandleFunc("/all_topics", HTTPListAllTopics)
announceForm, err := Backend.SealRequest(url.Values{ announceForm, err := Backend.secureForm.Seal(url.Values{
"startup": []string{"1"}, "startup": []string{"1"},
}) })
if err != nil { if err != nil {
@ -138,30 +139,21 @@ func startJanitors() {
go pubsubJanitor() go pubsubJanitor()
go ircConnection() go ircConnection()
go shutdownHandler()
} }
// is_init_func // Shutdown disconnects all clients.
func shutdownHandler() { func Shutdown(wg *sync.WaitGroup) {
ch := make(chan os.Signal)
signal.Notify(ch, syscall.SIGUSR1)
signal.Notify(ch, syscall.SIGTERM)
<-ch
log.Println("Shutting down...")
var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done()
writeHLL() writeHLL()
wg.Done()
}() }()
wg.Add(1)
StopAcceptingConnections = true go func() {
close(StopAcceptingConnectionsCh) defer wg.Done()
close(StopAcceptingConnectionsCh)
time.Sleep(1 * time.Second) time.Sleep(2 * time.Second)
wg.Wait() }()
os.Exit(0)
} }
// is_init_func +test // is_init_func +test
@ -199,7 +191,6 @@ var BannerHTML []byte
// StopAcceptingConnectionsCh is closed while the server is shutting down. // StopAcceptingConnectionsCh is closed while the server is shutting down.
var StopAcceptingConnectionsCh = make(chan struct{}) var StopAcceptingConnectionsCh = make(chan struct{})
var StopAcceptingConnections = false
// HTTPHandleRootURL is the http.HandleFunc for requests on `/`. // HTTPHandleRootURL is the http.HandleFunc for requests on `/`.
// It either uses the SocketUpgrader or writes out the BannerHTML. // It either uses the SocketUpgrader or writes out the BannerHTML.
@ -210,18 +201,10 @@ func HTTPHandleRootURL(w http.ResponseWriter, r *http.Request) {
return return
} }
// racy, but should be ok?
if StopAcceptingConnections {
w.WriteHeader(503)
fmt.Fprint(w, "server is shutting down")
return
}
if strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade") { if strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade") {
updateSysMem() updateSysMem()
if Statistics.SysMemFreeKB > 0 && Statistics.SysMemFreeKB < Configuration.MinMemoryKBytes { if Statistics.SysMemFreeKB > 0 && Statistics.SysMemFreeKB < Configuration.MinMemoryKBytes {
atomic.AddUint64(&Statistics.LowMemDroppedConnections, 1)
w.WriteHeader(503) w.WriteHeader(503)
fmt.Fprint(w, "error: low memory") fmt.Fprint(w, "error: low memory")
return return
@ -249,11 +232,21 @@ func HTTPHandleRootURL(w http.ResponseWriter, r *http.Request) {
} }
} }
type fatalDecodeError string
func (e fatalDecodeError) Error() string {
return string(e)
}
func (e fatalDecodeError) IsFatal() bool {
return true
}
// ErrProtocolGeneric is sent in a ErrorCommand Reply. // ErrProtocolGeneric is sent in a ErrorCommand Reply.
var ErrProtocolGeneric error = errors.New("FFZ Socket protocol error.") var ErrProtocolGeneric error = fatalDecodeError("FFZ Socket protocol error.")
// ErrProtocolNegativeMsgID is sent in a ErrorCommand Reply when a negative MessageID is received. // ErrProtocolNegativeMsgID is sent in a ErrorCommand Reply when a negative MessageID is received.
var ErrProtocolNegativeMsgID error = errors.New("FFZ Socket protocol error: negative or zero message ID.") var ErrProtocolNegativeMsgID error = fatalDecodeError("FFZ Socket protocol error: negative or zero message ID.")
// ErrExpectedSingleString is sent in a ErrorCommand Reply when the Arguments are of the wrong type. // ErrExpectedSingleString is sent in a ErrorCommand Reply when the Arguments are of the wrong type.
var ErrExpectedSingleString = errors.New("Error: Expected single string as arguments.") var ErrExpectedSingleString = errors.New("Error: Expected single string as arguments.")
@ -344,7 +337,7 @@ func RunSocketConnection(conn *websocket.Conn) {
}) })
// All set up, now enter the work loop // All set up, now enter the work loop
go runSocketReader(conn, _errorChan, _clientChan, stoppedChan) go runSocketReader(conn, &client, _errorChan, _clientChan)
closeReason := runSocketWriter(conn, &client, _errorChan, _clientChan, _serverMessageChan) closeReason := runSocketWriter(conn, &client, _errorChan, _clientChan, _serverMessageChan)
// Exit // Exit
@ -365,8 +358,10 @@ func RunSocketConnection(conn *websocket.Conn) {
// And done. // And done.
if !StopAcceptingConnections { select {
case <-StopAcceptingConnectionsCh:
// Don't perform high contention operations when server is closing // Don't perform high contention operations when server is closing
default:
atomic.AddUint64(&Statistics.CurrentClientCount, NegativeOne) atomic.AddUint64(&Statistics.CurrentClientCount, NegativeOne)
atomic.AddUint64(&Statistics.ClientDisconnectsTotal, 1) atomic.AddUint64(&Statistics.ClientDisconnectsTotal, 1)
@ -376,12 +371,14 @@ func RunSocketConnection(conn *websocket.Conn) {
} }
} }
func runSocketReader(conn *websocket.Conn, errorChan chan<- error, clientChan chan<- ClientMessage, stoppedChan <-chan struct{}) { func runSocketReader(conn *websocket.Conn, client *ClientInfo, errorChan chan<- error, clientChan chan<- ClientMessage) {
var msg ClientMessage var msg ClientMessage
var messageType int var messageType int
var packet []byte var packet []byte
var err error var err error
stoppedChan := client.MsgChannelIsDone
defer close(errorChan) defer close(errorChan)
defer close(clientChan) defer close(clientChan)
@ -395,8 +392,17 @@ func runSocketReader(conn *websocket.Conn, errorChan chan<- error, clientChan ch
break break
} }
UnmarshalClientMessage(packet, messageType, &msg) msg = ClientMessage{}
if msg.MessageID == 0 { msgErr := UnmarshalClientMessage(packet, messageType, &msg)
if _, ok := msgErr.(interface {
IsFatal() bool
}); ok {
errorChan <- msgErr
continue
} else if msgErr != nil {
client.Send(msg.Reply(ErrorCommand, msgErr.Error()))
continue
} else if msg.MessageID == 0 {
continue continue
} }
select { select {
@ -505,20 +511,24 @@ func SendMessage(conn *websocket.Conn, msg ClientMessage) {
} }
// UnmarshalClientMessage unpacks websocket TextMessage into a ClientMessage provided in the `v` parameter. // UnmarshalClientMessage unpacks websocket TextMessage into a ClientMessage provided in the `v` parameter.
func UnmarshalClientMessage(data []byte, payloadType int, v interface{}) (err error) { func UnmarshalClientMessage(data []byte, _ int, v interface{}) (err error) {
var spaceIdx int var spaceIdx int
out := v.(*ClientMessage) out := v.(*ClientMessage)
dataStr := string(data) dataStr := string(data)
if len(dataStr) == 0 {
out.MessageID = 0
return nil // test: ignore empty frames
}
// Message ID // Message ID
spaceIdx = strings.IndexRune(dataStr, ' ') spaceIdx = strings.IndexRune(dataStr, ' ')
if spaceIdx == -1 { if spaceIdx == -1 {
return ErrProtocolGeneric return ErrProtocolGeneric // fatal error
} }
messageID, err := strconv.Atoi(dataStr[:spaceIdx]) messageID, err := strconv.Atoi(dataStr[:spaceIdx])
if messageID < -1 || messageID == 0 { if messageID < -1 || messageID == 0 {
return ErrProtocolNegativeMsgID return ErrProtocolNegativeMsgID // fatal error
} }
out.MessageID = messageID out.MessageID = messageID
@ -551,7 +561,8 @@ func (cm *ClientMessage) parseOrigArguments() error {
return nil return nil
} }
func MarshalClientMessage(clientMessage interface{}) (payloadType int, data []byte, err error) { // returns payloadType, data, err
func MarshalClientMessage(clientMessage interface{}) (int, []byte, error) {
var msg ClientMessage var msg ClientMessage
var ok bool var ok bool
msg, ok = clientMessage.(ClientMessage) msg, ok = clientMessage.(ClientMessage)

View file

@ -0,0 +1,96 @@
package naclform
import (
"bytes"
"crypto/rand"
"encoding/base64"
"errors"
"net/url"
"strconv"
"strings"
"golang.org/x/crypto/nacl/box"
)
var ErrorShortNonce = errors.New("Nonce too short.")
var ErrorInvalidSignature = errors.New("Invalid signature or contents")
type ServerInfo struct {
SharedKey [32]byte
ServerID int
}
func fillCryptoRandom(buf []byte) error {
remaining := len(buf)
for remaining > 0 {
count, err := rand.Read(buf)
if err != nil {
return err
}
remaining -= count
}
return nil
}
func (i *ServerInfo) Seal(form url.Values) (url.Values, error) {
var nonce [24]byte
var err error
err = fillCryptoRandom(nonce[:])
if err != nil {
return nil, err
}
cipherMsg := box.SealAfterPrecomputation(nil, []byte(form.Encode()), &nonce, &i.SharedKey)
bufMessage := new(bytes.Buffer)
enc := base64.NewEncoder(base64.URLEncoding, bufMessage)
enc.Write(cipherMsg)
enc.Close()
cipherString := bufMessage.String()
bufNonce := new(bytes.Buffer)
enc = base64.NewEncoder(base64.URLEncoding, bufNonce)
enc.Write(nonce[:])
enc.Close()
nonceString := bufNonce.String()
retval := url.Values{
"nonce": []string{nonceString},
"msg": []string{cipherString},
"id": []string{strconv.Itoa(i.ServerID)},
}
return retval, nil
}
func (i *ServerInfo) Unseal(form url.Values) (url.Values, error) {
var nonce [24]byte
nonceString := form.Get("nonce")
dec := base64.NewDecoder(base64.URLEncoding, strings.NewReader(nonceString))
count, err := dec.Read(nonce[:])
if err != nil {
return nil, err
}
if count != 24 {
return nil, ErrorShortNonce
}
cipherString := form.Get("msg")
dec = base64.NewDecoder(base64.URLEncoding, strings.NewReader(cipherString))
cipherBuffer := new(bytes.Buffer)
cipherBuffer.ReadFrom(dec)
message, ok := box.OpenAfterPrecomputation(nil, cipherBuffer.Bytes(), &nonce, &i.SharedKey)
if !ok {
return nil, ErrorInvalidSignature
}
retValues, err := url.ParseQuery(string(message))
if err != nil {
return nil, ErrorInvalidSignature
}
return retValues, nil
}

View file

@ -1,6 +1,7 @@
package server package server
import ( import (
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"strconv" "strconv"
@ -10,6 +11,7 @@ import (
"github.com/FrankerFaceZ/FrankerFaceZ/socketserver/server/rate" "github.com/FrankerFaceZ/FrankerFaceZ/socketserver/server/rate"
"github.com/pkg/errors" "github.com/pkg/errors"
"golang.org/x/sync/singleflight"
) )
// LastSavedMessage contains a reply to a command along with an expiration time. // LastSavedMessage contains a reply to a command along with an expiration time.
@ -25,6 +27,8 @@ type LastSavedMessage struct {
var CachedLastMessages = make(map[Command]map[string]LastSavedMessage) var CachedLastMessages = make(map[Command]map[string]LastSavedMessage)
var CachedLSMLock sync.RWMutex var CachedLSMLock sync.RWMutex
var singleFlighter singleflight.Group
func cachedMessageJanitor() { func cachedMessageJanitor() {
for { for {
time.Sleep(1 * time.Hour) time.Sleep(1 * time.Hour)
@ -123,7 +127,7 @@ func saveLastMessage(cmd Command, channel string, expires time.Time, data string
func HTTPBackendDropBacklog(w http.ResponseWriter, r *http.Request) { func HTTPBackendDropBacklog(w http.ResponseWriter, r *http.Request) {
r.ParseForm() r.ParseForm()
formData, err := Backend.UnsealRequest(r.Form) formData, err := Backend.secureForm.Unseal(r.Form)
if err != nil { if err != nil {
w.WriteHeader(403) w.WriteHeader(403)
fmt.Fprintf(w, "Error: %v", err) fmt.Fprintf(w, "Error: %v", err)
@ -160,7 +164,7 @@ func rateLimitFromRequest(r *http.Request) (rate.Limiter, error) {
// If the 'expires' parameter is not specified, the message will not expire (though it is only kept in-memory). // If the 'expires' parameter is not specified, the message will not expire (though it is only kept in-memory).
func HTTPBackendCachedPublish(w http.ResponseWriter, r *http.Request) { func HTTPBackendCachedPublish(w http.ResponseWriter, r *http.Request) {
r.ParseForm() r.ParseForm()
formData, err := Backend.UnsealRequest(r.Form) formData, err := Backend.secureForm.Unseal(r.Form)
if err != nil { if err != nil {
w.WriteHeader(403) w.WriteHeader(403)
fmt.Fprintf(w, "Error: %v", err) fmt.Fprintf(w, "Error: %v", err)
@ -227,7 +231,7 @@ func HTTPBackendCachedPublish(w http.ResponseWriter, r *http.Request) {
// If "scope" is "global", then "channel" is not used. // If "scope" is "global", then "channel" is not used.
func HTTPBackendUncachedPublish(w http.ResponseWriter, r *http.Request) { func HTTPBackendUncachedPublish(w http.ResponseWriter, r *http.Request) {
r.ParseForm() r.ParseForm()
formData, err := Backend.UnsealRequest(r.Form) formData, err := Backend.secureForm.Unseal(r.Form)
if err != nil { if err != nil {
w.WriteHeader(403) w.WriteHeader(403)
fmt.Fprintf(w, "Error: %v", err) fmt.Fprintf(w, "Error: %v", err)
@ -292,7 +296,7 @@ func HTTPBackendUncachedPublish(w http.ResponseWriter, r *http.Request) {
// A "global" option is not available, use fetch(/stats).CurrentClientCount instead. // A "global" option is not available, use fetch(/stats).CurrentClientCount instead.
func HTTPGetSubscriberCount(w http.ResponseWriter, r *http.Request) { func HTTPGetSubscriberCount(w http.ResponseWriter, r *http.Request) {
r.ParseForm() r.ParseForm()
formData, err := Backend.UnsealRequest(r.Form) formData, err := Backend.secureForm.Unseal(r.Form)
if err != nil { if err != nil {
w.WriteHeader(403) w.WriteHeader(403)
fmt.Fprintf(w, "Error: %v", err) fmt.Fprintf(w, "Error: %v", err)
@ -303,3 +307,19 @@ func HTTPGetSubscriberCount(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, CountSubscriptions(strings.Split(channel, ","))) fmt.Fprint(w, CountSubscriptions(strings.Split(channel, ",")))
} }
func HTTPListAllTopics(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
_, err := Backend.secureForm.Unseal(r.Form)
if err != nil {
//w.WriteHeader(403)
//fmt.Fprintf(w, "Error: %v", err)
//return
}
topicList, _, _ := singleFlighter.Do("/all_topics", func() (interface{}, error) {
return GetAllTopics(), nil
})
w.WriteHeader(200)
json.NewEncoder(w).Encode(topicList)
}

View file

@ -38,9 +38,8 @@ type StatsData struct {
MemoryInUseKB uint64 MemoryInUseKB uint64
MemoryRSSKB uint64 MemoryRSSKB uint64
LowMemDroppedConnections uint64 ResponseCacheItems int
MemPerClientBytes uint64
MemPerClientBytes uint64
CpuUsagePct float64 CpuUsagePct float64
@ -84,7 +83,7 @@ func commandCounter() {
} }
// StatsDataVersion is the version of the StatsData struct. // StatsDataVersion is the version of the StatsData struct.
const StatsDataVersion = 7 const StatsDataVersion = 8
const pageSize = 4096 const pageSize = 4096
var cpuUsage struct { var cpuUsage struct {
@ -170,6 +169,7 @@ func updatePeriodicStats() {
{ {
Statistics.Uptime = nowUpdate.Sub(Statistics.StartTime).String() Statistics.Uptime = nowUpdate.Sub(Statistics.StartTime).String()
Statistics.ResponseCacheItems = Backend.responseCache.ItemCount()
} }
{ {
@ -209,7 +209,7 @@ func updateSysMem() {
} }
// HTTPShowStatistics handles the /stats endpoint. It writes out the Statistics object as indented JSON. // HTTPShowStatistics handles the /stats endpoint. It writes out the Statistics object as indented JSON.
func HTTPShowStatistics(w http.ResponseWriter, r *http.Request) { func HTTPShowStatistics(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
updateStatsIfNeeded() updateStatsIfNeeded()

View file

@ -47,6 +47,20 @@ func CountSubscriptions(channels []string) int {
return count return count
} }
func GetAllTopics() []string {
ChatSubscriptionLock.RLock()
defer ChatSubscriptionLock.RUnlock()
count := len(ChatSubscriptionInfo)
list := make([]string, count)
i := 0
for topicName := range ChatSubscriptionInfo {
list[i] = topicName
i++
}
return list
}
func SubscribeChannel(client *ClientInfo, channelName string) { func SubscribeChannel(client *ClientInfo, channelName string) {
ChatSubscriptionLock.RLock() ChatSubscriptionLock.RLock()
_subscribeWhileRlocked(channelName, client) _subscribeWhileRlocked(channelName, client)
@ -142,8 +156,11 @@ func UnsubscribeSingleChat(client *ClientInfo, channelName string) {
// - write lock to SubscriptionInfos // - write lock to SubscriptionInfos
// - write lock to ClientInfo // - write lock to ClientInfo
func UnsubscribeAll(client *ClientInfo) { func UnsubscribeAll(client *ClientInfo) {
if StopAcceptingConnections { select {
return // no need to remove from a high-contention list when the server is closing case <-StopAcceptingConnectionsCh:
// Skip high-contention client removal operations while server shutting down
return
default:
} }
GlobalSubscriptionLock.Lock() GlobalSubscriptionLock.Lock()

View file

@ -97,7 +97,7 @@ func (er *TExpectedBackendRequest) String() string {
if MethodIsPost == "" { if MethodIsPost == "" {
return er.Path return er.Path
} }
return fmt.Sprint("%s %s: %s", MethodIsPost, er.Path, er.PostForm.Encode()) return fmt.Sprintf("%s %s: %s", MethodIsPost, er.Path, er.PostForm.Encode())
} }
type TBackendRequestChecker struct { type TBackendRequestChecker struct {
@ -123,7 +123,7 @@ func (backend *TBackendRequestChecker) ServeHTTP(w http.ResponseWriter, r *http.
r.ParseForm() r.ParseForm()
unsealedForm, err := Backend.UnsealRequest(r.PostForm) unsealedForm, err := Backend.secureForm.Unseal(r.PostForm)
if err != nil { if err != nil {
backend.tb.Errorf("Failed to unseal backend request: %v", err) backend.tb.Errorf("Failed to unseal backend request: %v", err)
} }
@ -276,7 +276,7 @@ func TSealForSavePubMsg(tb testing.TB, cmd Command, channel string, arguments in
} }
form.Set("time", strconv.FormatInt(time.Now().Unix(), 10)) form.Set("time", strconv.FormatInt(time.Now().Unix(), 10))
sealed, err := Backend.SealRequest(form) sealed, err := Backend.secureForm.Seal(form)
if err != nil { if err != nil {
tb.Error(err) tb.Error(err)
return nil, err return nil, err
@ -300,7 +300,7 @@ func TSealForUncachedPubMsg(tb testing.TB, cmd Command, channel string, argument
form.Set("time", time.Now().Format(time.UnixDate)) form.Set("time", time.Now().Format(time.UnixDate))
form.Set("scope", scope) form.Set("scope", scope)
sealed, err := Backend.SealRequest(form) sealed, err := Backend.secureForm.Seal(form)
if err != nil { if err != nil {
tb.Error(err) tb.Error(err)
return nil, err return nil, err

View file

@ -35,11 +35,6 @@ type ConfigFile struct {
// Path to key file. // Path to key file.
SSLKeyFile string SSLKeyFile string
UseESLogStashing bool
ESServer string
ESIndexPrefix string
ESHostName string
// Nacl keys // Nacl keys
OurPrivateKey []byte OurPrivateKey []byte
OurPublicKey []byte OurPublicKey []byte
@ -64,6 +59,24 @@ type ClientMessage struct {
origArguments string origArguments string
} }
func (cm ClientMessage) Reply(cmd Command, args interface{}) ClientMessage {
return ClientMessage{
MessageID: cm.MessageID,
Command: cmd,
Arguments: args,
}
}
func (cm ClientMessage) ReplyJSON(cmd Command, argsJSON string) ClientMessage {
n := ClientMessage{
MessageID: cm.MessageID,
Command: cmd,
origArguments: argsJSON,
}
n.parseOrigArguments()
return n
}
type AuthInfo struct { type AuthInfo struct {
// The client's claimed username on Twitch. // The client's claimed username on Twitch.
TwitchUsername string TwitchUsername string

View file

@ -46,11 +46,6 @@ var uniqueCtrWritingToken chan usageToken
var CounterLocation *time.Location = time.FixedZone("UTC-5", int((time.Hour*-5)/time.Second)) var CounterLocation *time.Location = time.FixedZone("UTC-5", int((time.Hour*-5)/time.Second))
func TruncateToMidnight(at time.Time) time.Time {
year, month, day := at.Date()
return time.Date(year, month, day, 0, 0, 0, 0, CounterLocation)
}
// GetCounterPeriod calculates the start and end timestamps for the HLL measurement period that includes the 'at' timestamp. // GetCounterPeriod calculates the start and end timestamps for the HLL measurement period that includes the 'at' timestamp.
func GetCounterPeriod(at time.Time) (start time.Time, end time.Time) { func GetCounterPeriod(at time.Time) (start time.Time, end time.Time) {
year, month, day := at.Date() year, month, day := at.Date()
@ -132,7 +127,7 @@ func HTTPShowHLL(w http.ResponseWriter, r *http.Request) {
hllFileServer.ServeHTTP(w, r) hllFileServer.ServeHTTP(w, r)
} }
func HTTPWriteHLL(w http.ResponseWriter, r *http.Request) { func HTTPWriteHLL(w http.ResponseWriter, _ *http.Request) {
writeHLL() writeHLL()
w.WriteHeader(200) w.WriteHeader(200)
w.Write([]byte("ok")) w.Write([]byte("ok"))
@ -150,9 +145,7 @@ func loadUniqueUsers() {
now := time.Now().In(CounterLocation) now := time.Now().In(CounterLocation)
uniqueCounter.Start, uniqueCounter.End = GetCounterPeriod(now) uniqueCounter.Start, uniqueCounter.End = GetCounterPeriod(now)
err = loadHLL(now, &uniqueCounter) err = loadHLL(now, &uniqueCounter)
isIgnorableError := err != nil && (false || isIgnorableError := err != nil && (os.IsNotExist(err) || err == io.EOF)
(os.IsNotExist(err)) ||
(err == io.EOF))
if isIgnorableError { if isIgnorableError {
// file didn't finish writing // file didn't finish writing
@ -227,10 +220,10 @@ func rolloverCounters_do() {
// Attempt to rescue the data into the log // Attempt to rescue the data into the log
var buf bytes.Buffer var buf bytes.Buffer
bytes, err := uniqueCounter.Counter.GobEncode() by, err := uniqueCounter.Counter.GobEncode()
if err == nil { if err == nil {
enc := base64.NewEncoder(base64.StdEncoding, &buf) enc := base64.NewEncoder(base64.StdEncoding, &buf)
enc.Write(bytes) enc.Write(by)
enc.Close() enc.Close()
log.Print("data for ", GetHLLFilename(uniqueCounter.Start), ":", buf.String()) log.Print("data for ", GetHLLFilename(uniqueCounter.Start), ":", buf.String())
} }

View file

@ -1,104 +1,9 @@
package server package server
import (
"bytes"
"crypto/rand"
"encoding/base64"
"errors"
"net/url"
"strconv"
"strings"
"golang.org/x/crypto/nacl/box"
)
func FillCryptoRandom(buf []byte) error {
remaining := len(buf)
for remaining > 0 {
count, err := rand.Read(buf)
if err != nil {
return err
}
remaining -= count
}
return nil
}
func copyString(s string) string { func copyString(s string) string {
return string([]byte(s)) return string([]byte(s))
} }
func (backend *backendInfo) SealRequest(form url.Values) (url.Values, error) {
var nonce [24]byte
var err error
err = FillCryptoRandom(nonce[:])
if err != nil {
return nil, err
}
cipherMsg := box.SealAfterPrecomputation(nil, []byte(form.Encode()), &nonce, &backend.sharedKey)
bufMessage := new(bytes.Buffer)
enc := base64.NewEncoder(base64.URLEncoding, bufMessage)
enc.Write(cipherMsg)
enc.Close()
cipherString := bufMessage.String()
bufNonce := new(bytes.Buffer)
enc = base64.NewEncoder(base64.URLEncoding, bufNonce)
enc.Write(nonce[:])
enc.Close()
nonceString := bufNonce.String()
retval := url.Values{
"nonce": []string{nonceString},
"msg": []string{cipherString},
"id": []string{strconv.Itoa(Backend.serverID)},
}
return retval, nil
}
var ErrorShortNonce = errors.New("Nonce too short.")
var ErrorInvalidSignature = errors.New("Invalid signature or contents")
func (backend *backendInfo) UnsealRequest(form url.Values) (url.Values, error) {
var nonce [24]byte
nonceString := form.Get("nonce")
dec := base64.NewDecoder(base64.URLEncoding, strings.NewReader(nonceString))
count, err := dec.Read(nonce[:])
if err != nil {
fmt.Println("Error reading nonce");
Statistics.BackendVerifyFails++
return nil, err
}
if count != 24 {
Statistics.BackendVerifyFails++
return nil, ErrorShortNonce
}
cipherString := form.Get("msg")
dec = base64.NewDecoder(base64.URLEncoding, strings.NewReader(cipherString))
cipherBuffer := new(bytes.Buffer)
cipherBuffer.ReadFrom(dec)
message, ok := box.OpenAfterPrecomputation(nil, cipherBuffer.Bytes(), &nonce, &backend.sharedKey)
if !ok {
Statistics.BackendVerifyFails++
return nil, ErrorInvalidSignature
}
retValues, err := url.ParseQuery(string(message))
if err != nil {
Statistics.BackendVerifyFails++
return nil, ErrorInvalidSignature
}
return retValues, nil
}
func AddToSliceS(ary *[]string, val string) bool { func AddToSliceS(ary *[]string, val string) bool {
slice := *ary slice := *ary
for _, v := range slice { for _, v := range slice {
@ -162,17 +67,3 @@ func RemoveFromSliceCl(ary *[]*ClientInfo, val *ClientInfo) bool {
*ary = slice *ary = slice
return true return true
} }
func AddToSliceB(ary *[]bunchSubscriber, client *ClientInfo, mid int) bool {
newSub := bunchSubscriber{Client: client, MessageID: mid}
slice := *ary
for _, v := range slice {
if v == newSub {
return false
}
}
slice = append(slice, newSub)
*ary = slice
return true
}