diff --git a/socketserver/internal/server/commands.go b/socketserver/internal/server/commands.go index 4fbb8688..2b13eb09 100644 --- a/socketserver/internal/server/commands.go +++ b/socketserver/internal/server/commands.go @@ -86,18 +86,14 @@ func HandleReady(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (r client.MakePendingRequests = nil client.Mutex.Unlock() + client.MsgChannelKeepalive.Add(1) go func() { - client.MsgChannelKeepalive.RLock() - if client.MessageChannel == nil { - return - } - client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: SuccessCommand} SendBacklogForNewClient(client) if disconnectAt != 0 { SendTimedBacklogMessages(client, time.Unix(disconnectAt, 0)) } - client.MsgChannelKeepalive.RUnlock() + client.MsgChannelKeepalive.Done() }() return ClientMessage{Command: AsyncResponseCommand}, nil } @@ -192,13 +188,13 @@ func GetSubscriptionBacklog(conn *websocket.Conn, client *ClientInfo) { } // Deliver to client - client.MsgChannelKeepalive.RLock() + client.MsgChannelKeepalive.Add(1) if client.MessageChannel != nil { for _, msg := range messages { client.MessageChannel <- msg } } - client.MsgChannelKeepalive.RUnlock() + client.MsgChannelKeepalive.Done() } func HandleSurvey(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { @@ -401,8 +397,7 @@ func HandleBunchedRemotecommand(conn *websocket.Conn, client *ClientInfo, msg Cl CompletedBunchLock.RUnlock() } - // !!! unlocked on reply - client.MsgChannelKeepalive.RLock() + client.MsgChannelKeepalive.Add(1) PendingBunchLock.RLock() list, ok := PendingBunchedRequests[br] @@ -454,7 +449,7 @@ func HandleBunchedRemotecommand(conn *websocket.Conn, client *ClientInfo, msg Cl for _, member := range bsl.Members { msg.MessageID = member.MessageID member.Client.MessageChannel <- msg - member.Client.MsgChannelKeepalive.RUnlock() + member.Client.MsgChannelKeepalive.Done() } bsl.Unlock() @@ -469,20 +464,18 @@ func HandleBunchedRemotecommand(conn *websocket.Conn, client *ClientInfo, msg Cl } func HandleRemoteCommand(conn *websocket.Conn, client *ClientInfo, msg ClientMessage) (rmsg ClientMessage, err error) { + client.MsgChannelKeepalive.Add(1) go func(conn *websocket.Conn, msg ClientMessage, authInfo AuthInfo) { resp, err := RequestRemoteDataCached(string(msg.Command), msg.origArguments, authInfo) - client.MsgChannelKeepalive.RLock() - if client.MessageChannel != nil { - if err != nil { - client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: err.Error()} - } else { - msg := SuccessMessageFromString(resp) - msg.MessageID = msg.MessageID - client.MessageChannel <- msg - } + if err != nil { + client.MessageChannel <- ClientMessage{MessageID: msg.MessageID, Command: ErrorCommand, Arguments: err.Error()} + } else { + msg := SuccessMessageFromString(resp) + msg.MessageID = msg.MessageID + client.MessageChannel <- msg } - client.MsgChannelKeepalive.RUnlock() + client.MsgChannelKeepalive.Done() }(conn, msg, client.AuthInfo) return ClientMessage{Command: AsyncResponseCommand}, nil diff --git a/socketserver/internal/server/handlecore.go b/socketserver/internal/server/handlecore.go index 06ae7a93..35215577 100644 --- a/socketserver/internal/server/handlecore.go +++ b/socketserver/internal/server/handlecore.go @@ -257,11 +257,11 @@ RunLoop: // Stop getting messages... UnsubscribeAll(&client) - client.MsgChannelKeepalive.Lock() + // Wait for pending jobs to finish... + client.MsgChannelKeepalive.Wait() client.MessageChannel = nil - client.MsgChannelKeepalive.Unlock() - // And finished. + // And done. // Close the channel so the draining goroutine can finish, too. close(_serverMessageChan) diff --git a/socketserver/internal/server/types.go b/socketserver/internal/server/types.go index 1420401d..97dae35a 100644 --- a/socketserver/internal/server/types.go +++ b/socketserver/internal/server/types.go @@ -91,8 +91,8 @@ type ClientInfo struct { // This field will be nil before it is closed. MessageChannel chan<- ClientMessage - // Take a read-lock on this before checking whether MessageChannel is nil. - MsgChannelKeepalive sync.RWMutex + // Take out an Add() on this during a command if you need to use the MessageChannel later. + MsgChannelKeepalive sync.WaitGroup // The number of pings sent without a response pingCount int