diff --git a/socketserver/cmd/ffzsocketserver/socketserver.go b/socketserver/cmd/ffzsocketserver/socketserver.go index 1f7f2db0..cea1bb07 100644 --- a/socketserver/cmd/ffzsocketserver/socketserver.go +++ b/socketserver/cmd/ffzsocketserver/socketserver.go @@ -81,7 +81,8 @@ func main() { Addr: conf.SSLListenAddr, Handler: http.DefaultServeMux, TLSConfig: &tls.Config{ - GetCertificate: reloader.GetCertificateFunc(), + GetCertificate: reloader.GetCertificateFunc(), + GetConfigForClient: server.TLSEarlyReject, }, } go func() { diff --git a/socketserver/server/handlecore.go b/socketserver/server/handlecore.go index 525101aa..a1383a91 100644 --- a/socketserver/server/handlecore.go +++ b/socketserver/server/handlecore.go @@ -1,6 +1,8 @@ package server // import "github.com/FrankerFaceZ/FrankerFaceZ/socketserver/server" import ( + "bytes" + "crypto/tls" "encoding/json" "errors" "fmt" @@ -238,6 +240,29 @@ var BannerHTML []byte // StopAcceptingConnectionsCh is closed while the server is shutting down. var StopAcceptingConnectionsCh = make(chan struct{}) +func shouldRejectConnection() bool { + memFreeKB := atomic.LoadUint64(&Statistics.SysMemFreeKB) + if memFreeKB > 0 && memFreeKB < Configuration.MinMemoryKBytes { + return true + } + + curClients := atomic.LoadUint64(&Statistics.CurrentClientCount) + if Configuration.MaxClientCount != 0 && curClients >= Configuration.MaxClientCount { + return true + } + + return false +} + +var errEarlyTLSReject = errors.New("over capacity") + +func TLSEarlyReject(*tls.ClientHelloInfo) (*tls.Config, error) { + if shouldRejectConnection() { + return nil, errEarlyTLSReject + } + return nil, nil +} + // HTTPHandleRootURL is the http.HandleFunc for requests on `/`. // It either uses the SocketUpgrader or writes out the BannerHTML. func HTTPHandleRootURL(w http.ResponseWriter, r *http.Request) { @@ -250,21 +275,12 @@ func HTTPHandleRootURL(w http.ResponseWriter, r *http.Request) { if strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade") { updateSysMem() - if Statistics.SysMemFreeKB > 0 && Statistics.SysMemFreeKB < Configuration.MinMemoryKBytes { + if shouldRejectConnection() { w.WriteHeader(503) - fmt.Fprint(w, "error: low memory") + fmt.Fprint(w, "connection rejected: over capacity") return } - if Configuration.MaxClientCount != 0 { - curClients := atomic.LoadUint64(&Statistics.CurrentClientCount) - if curClients >= Configuration.MaxClientCount { - w.WriteHeader(503) - fmt.Fprint(w, "error: client limit reached") - return - } - } - conn, err := SocketUpgrader.Upgrade(w, r, nil) if err != nil { fmt.Fprintf(w, "error: %v", err) @@ -365,7 +381,7 @@ func RunSocketConnection(conn *websocket.Conn) { stoppedChan := make(chan struct{}) var client ClientInfo - client.MessageChannel = _serverMessageChan + client.messageChannel = _serverMessageChan client.RemoteAddr = conn.RemoteAddr() client.MsgChannelIsDone = stoppedChan @@ -374,11 +390,7 @@ func RunSocketConnection(conn *websocket.Conn) { // report.RemoteAddr = client.RemoteAddr conn.SetPongHandler(func(pongBody string) error { - client.Mutex.Lock() - if client.HelloOK { // do not accept PONGs until hello sent - client.pingCount = 0 - } - client.Mutex.Unlock() + _clientChan <- ClientMessage{Command: "__ping"} return nil }) @@ -428,7 +440,24 @@ func runSocketReader(conn *websocket.Conn, client *ClientInfo, errorChan chan<- defer close(errorChan) defer close(clientChan) - for ; err == nil; messageType, packet, err = conn.ReadMessage() { + for { + conn.SetReadDeadline(time.Now().Add(1 * time.Minute)) + messageType, packet, err = conn.ReadMessage() + // handle ReadDeadline by sending a ping + // writer loop handles repeated ping timeouts + if tmErr, ok := err.(interface { + Timeout() bool + }); ok && tmErr.Timeout() { + select { + case <-stoppedChan: + return + case clientChan <- ClientMessage{Command: "__readTimeout"}: + } + continue // re-set deadline and wait for pong packet + } + if err != nil { + break + } if messageType == websocket.BinaryMessage { err = &CloseGotBinaryMessage break @@ -451,21 +480,26 @@ func runSocketReader(conn *websocket.Conn, client *ClientInfo, errorChan chan<- } else if msg.MessageID == 0 { continue } + select { - case clientChan <- msg: case <-stoppedChan: return + case clientChan <- msg: } } select { - case errorChan <- err: case <-stoppedChan: + case errorChan <- err: } // exit goroutine } +var pingPayload = []byte("PING") + func runSocketWriter(conn *websocket.Conn, client *ClientInfo, errorChan <-chan error, clientChan <-chan ClientMessage, serverMessageChan <-chan ClientMessage) websocket.CloseError { + lastPacket := time.Now() + for { select { case err := <-errorChan: @@ -484,9 +518,30 @@ func runSocketWriter(conn *websocket.Conn, client *ClientInfo, errorChan <-chan } case msg := <-clientChan: + if msg.Command == "__readTimeout" { + // generated on 60 seconds without a message + now := time.Now() + if lastPacket.Add(5 * time.Minute).Before(now) { + return CloseTimedOut + } + conn.WriteControl( + websocket.PingMessage, + pingPayload, + getDeadline(), + ) + continue + } + if !client.HelloOK && msg.Command != HelloCommand { return CloseFirstMessageNotHello } + lastPacket = time.Now() + + if msg.Command == "__ping" { + // generated for PONG packets + // want to branch AFTER lastPacket is set + continue + } for _, char := range msg.Command { if char == utf8.RuneError { @@ -505,17 +560,6 @@ func runSocketWriter(conn *websocket.Conn, client *ClientInfo, errorChan <-chan } SendMessage(conn, msg) - case <-time.After(1 * time.Minute): - client.Mutex.Lock() - client.pingCount++ - tooManyPings := client.pingCount == 5 - client.Mutex.Unlock() - if tooManyPings { - return CloseTimedOut - } else { - conn.WriteControl(websocket.PingMessage, []byte(strconv.FormatInt(time.Now().Unix(), 10)), getDeadline()) - } - case <-StopAcceptingConnectionsCh: return CloseGoingAway } @@ -619,7 +663,6 @@ func MarshalClientMessage(clientMessage interface{}) (int, []byte, error) { } msg = *pMsg } - var dataStr string if msg.Command == "" && msg.MessageID == 0 { panic("MarshalClientMessage: attempt to send an empty ClientMessage") @@ -632,20 +675,25 @@ func MarshalClientMessage(clientMessage interface{}) (int, []byte, error) { msg.MessageID = -1 } + // optimized from fmt.Sprintf("%d %s %s", msg.MessageID, msg.Command, ...) + var buf bytes.Buffer + fmt.Fprint(&buf, msg.MessageID) + buf.WriteByte(' ') + buf.WriteString(string(msg.Command)) + if msg.origArguments != "" { - dataStr = fmt.Sprintf("%d %s %s", msg.MessageID, msg.Command, msg.origArguments) + buf.WriteByte(' ') + buf.WriteString(msg.origArguments) } else if msg.Arguments != nil { argBytes, err := json.Marshal(msg.Arguments) if err != nil { return 0, nil, err } - - dataStr = fmt.Sprintf("%d %s %s", msg.MessageID, msg.Command, string(argBytes)) - } else { - dataStr = fmt.Sprintf("%d %s", msg.MessageID, msg.Command) + buf.WriteByte(' ') + buf.Write(argBytes) } - return websocket.TextMessage, []byte(dataStr), nil + return websocket.TextMessage, buf.Bytes(), nil } // ArgumentsAsString parses the arguments of the ClientMessage as a single string. diff --git a/socketserver/server/subscriptions.go b/socketserver/server/subscriptions.go index 136fa603..7486876b 100644 --- a/socketserver/server/subscriptions.go +++ b/socketserver/server/subscriptions.go @@ -18,9 +18,11 @@ var ChatSubscriptionLock sync.RWMutex var GlobalSubscriptionInfo []*ClientInfo var GlobalSubscriptionLock sync.RWMutex +// Send a message to the client. +// Drops if buffer is full. func (client *ClientInfo) Send(msg ClientMessage) bool { select { - case client.MessageChannel <- msg: + case client.messageChannel <- msg: return true case <-client.MsgChannelIsDone: return false diff --git a/socketserver/server/types.go b/socketserver/server/types.go index 3face4c4..b4fae2af 100644 --- a/socketserver/server/types.go +++ b/socketserver/server/types.go @@ -47,13 +47,11 @@ type ConfigFile struct { ProxyRoutes []ProxyRoute } - type ProxyRoute struct { - Route string - Server string + Route string + Server string } - type ClientMessage struct { // Message ID. Increments by 1 for each message sent from the client. // When replying to a command, the message ID must be echoed. @@ -94,12 +92,6 @@ type AuthInfo struct { UsernameValidated bool } -type ClientVersion struct { - Major int - Minor int - Revision int -} - type ClientInfo struct { // The client ID. // This must be written once by the owning goroutine before the struct is passed off to any other goroutines. @@ -134,17 +126,19 @@ type ClientInfo struct { ReadyComplete bool // Server-initiated messages should be sent via the Send() method. - MessageChannel chan<- ClientMessage + messageChannel chan<- ClientMessage // Closed when the client is shutting down. MsgChannelIsDone <-chan struct{} - // Take out an Add() on this during a command if you need to use the MessageChannel later. + // Take out an Add() on this during a command if you need to call Send() later. MsgChannelKeepalive sync.WaitGroup +} - // The number of pings sent without a response. - // Protected by Mutex - pingCount int +type ClientVersion struct { + Major int + Minor int + Revision int } func VersionFromString(v string) ClientVersion {