diff --git a/socketserver/internal/server/handlecore.go b/socketserver/internal/server/handlecore.go index ff3b786c..b76762e4 100644 --- a/socketserver/internal/server/handlecore.go +++ b/socketserver/internal/server/handlecore.go @@ -121,6 +121,7 @@ func ServeWebsocketOrCatbag(w http.ResponseWriter, r *http.Request) { var CloseGotBinaryMessage = websocket.CloseError{Code: websocket.CloseUnsupportedData, Text: "got binary packet"} var CloseGotMessageId0 = websocket.CloseError{Code: websocket.ClosePolicyViolation, Text: "got messageid 0"} +var CloseTimedOut = websocket.CloseError{Code: websocket.CloseNoStatusReceived, Text: "no ping replies for 5 minutes"} var CloseFirstMessageNotHello = websocket.CloseError{ Text: "Error - the first message sent must be a 'hello'", Code: websocket.ClosePolicyViolation, @@ -183,6 +184,12 @@ func HandleSocketConnection(conn *websocket.Conn) { // exit }(_errorChan, _clientChan) + conn.SetPongHandler(func(pongBody string) error { + fmt.Println("got pong") + client.pingCount = 0 + return nil + }) + var errorChan <-chan error = _errorChan var clientChan <-chan ClientMessage = _clientChan var serverMessageChan <-chan ClientMessage = _serverMessageChan @@ -215,8 +222,18 @@ RunLoop: } HandleCommand(conn, &client, msg) + case smsg := <-serverMessageChan: SendMessage(conn, smsg) + + case <- time.After(1 * time.Minute): + client.pingCount++ + if client.pingCount == 5 { + CloseConnection(conn, &CloseTimedOut) + break RunLoop + } else { + conn.WriteControl(websocket.PingMessage, []byte(strconv.FormatInt(time.Now().Unix(), 10)), getDeadline()) + } } } @@ -238,6 +255,10 @@ RunLoop: log.Println("End socket connection from", conn.RemoteAddr()) } +func getDeadline() time.Time { + return time.Now().Add(1 * time.Minute) +} + func CallHandler(handler CommandHandler, conn *websocket.Conn, client *ClientInfo, cmsg ClientMessage) (rmsg ClientMessage, err error) { defer func() { if r := recover(); r != nil { @@ -256,7 +277,7 @@ func CloseConnection(conn *websocket.Conn, closeMsg *websocket.CloseError) { if closeMsg != &CloseFirstMessageNotHello { log.Println("Terminating connection with", conn.RemoteAddr(), "-", closeMsg.Text) } - conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(closeMsg.Code, closeMsg.Text), time.Now().Add(2*time.Minute)) + conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(closeMsg.Code, closeMsg.Text), getDeadline()) conn.Close() } @@ -265,6 +286,7 @@ func SendMessage(conn *websocket.Conn, msg ClientMessage) { if err != nil { panic(fmt.Sprintf("failed to marshal: %v %v", err, msg)) } + conn.SetWriteDeadline(getDeadline()) conn.WriteMessage(messageType, packet) } diff --git a/socketserver/internal/server/types.go b/socketserver/internal/server/types.go index 76123277..d3d1224e 100644 --- a/socketserver/internal/server/types.go +++ b/socketserver/internal/server/types.go @@ -87,6 +87,9 @@ type ClientInfo struct { // Server-initiated messages should be sent here // Never nil. MessageChannel chan<- ClientMessage + + // The number of pings sent without a response + pingCount int } type tgmarray []TimestampedGlobalMessage