diff --git a/socketserver/internal/server/commands.go b/socketserver/internal/server/commands.go index ee819650..db91c66a 100644 --- a/socketserver/internal/server/commands.go +++ b/socketserver/internal/server/commands.go @@ -85,22 +85,20 @@ func HandleReady(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (r client.MakePendingRequests = nil client.Mutex.Unlock() - if disconnectAt == 0 { - // backlog only - go func() { - client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand} - SendBacklogForNewClient(client) - }() - return ClientMessage{Command: AsyncResponseCommand}, nil - } else { - // backlog and timed - go func() { - client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand} - SendBacklogForNewClient(client) + go func() { + client.MsgChannelKeepalive.RLock() + if client.MessageChannel == nil { + return + } + + client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand} + SendBacklogForNewClient(client) + if disconnectAt != 0 { SendTimedBacklogMessages(client, time.Unix(disconnectAt, 0)) - }() - return ClientMessage{Command: AsyncResponseCommand}, nil - } + } + client.MsgChannelKeepalive.RUnlock() + }() + return ClientMessage{Command: AsyncResponseCommand}, nil } func HandleSetUser(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { @@ -193,9 +191,13 @@ func GetSubscriptionBacklog(conn *websocket.Conn, client *ClientInfo) { } // Deliver to client - for _, msg := range messages { - client.MessageChannel <- msg + client.MsgChannelKeepalive.RLock() + if client.MessageChannel != nil { + for _, msg := range messages { + client.MessageChannel <- msg + } } + client.MsgChannelKeepalive.RUnlock() } func HandleSurvey(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { @@ -318,11 +320,15 @@ func HandleRemoteCommand(conn *websocket.Conn, client *ClientInfo, msg ClientMes go func(conn *websocket.Conn, msg ClientMessage, authInfo AuthInfo) { resp, err := RequestRemoteDataCached(string(msg.Command), msg.origArguments, authInfo) - if err != nil { - client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: err.Error()} - } else { - client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand, origArguments: resp} + client.MsgChannelKeepalive.RLock() + if client.MessageChannel != nil { + if err != nil { + client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: err.Error()} + } else { + client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand, origArguments: resp} + } } + client.MsgChannelKeepalive.RUnlock() }(conn, msg, client.AuthInfo) return ClientMessage{Command: AsyncResponseCommand}, nil diff --git a/socketserver/internal/server/handlecore.go b/socketserver/internal/server/handlecore.go index 1a61b0ef..c2d05ce8 100644 --- a/socketserver/internal/server/handlecore.go +++ b/socketserver/internal/server/handlecore.go @@ -247,6 +247,10 @@ RunLoop: // Stop getting messages... UnsubscribeAll(&client) + client.MsgChannelKeepalive.Lock() + client.MessageChannel = nil + client.MsgChannelKeepalive.Unlock() + // And finished. // Close the channel so the draining goroutine can finish, too. close(_serverMessageChan) diff --git a/socketserver/internal/server/types.go b/socketserver/internal/server/types.go index d3d1224e..0fd79a39 100644 --- a/socketserver/internal/server/types.go +++ b/socketserver/internal/server/types.go @@ -85,9 +85,12 @@ type ClientInfo struct { MakePendingRequests *time.Timer // Server-initiated messages should be sent here - // Never nil. + // This field will be nil before it is closed. MessageChannel chan<- ClientMessage + // Take a read-lock on this before checking whether MessageChannel is nil. + MsgChannelKeepalive sync.RWMutex + // The number of pings sent without a response pingCount int }