From 4ee88993e77fc5b143f434998d6f0877b51a6f7f Mon Sep 17 00:00:00 2001 From: Marco Peereboom Date: Tue, 27 Feb 2024 17:34:09 +0000 Subject: [PATCH] Fix session handling and kill a race --- service/bfg/bfg.go | 30 ++++++++++-------------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/service/bfg/bfg.go b/service/bfg/bfg.go index c2e058c3..447e62ec 100644 --- a/service/bfg/bfg.go +++ b/service/bfg/bfg.go @@ -704,9 +704,6 @@ func (s *Server) handleWebsocketPrivateRead(ctx context.Context, bws *bfgWs) { if err != nil { log.Errorf("handleWebsocketRead %v %v %v: %v", bws.addr, cmd, id, err) - // XXX this needs to be handled by the caller - bws.conn.CloseStatus(websocket.StatusProtocolError, - err.Error()) return } @@ -781,9 +778,6 @@ func (s *Server) handleWebsocketPublicRead(ctx context.Context, bws *bfgWs) { if err != nil { log.Errorf("handleWebsocketRead %v %v %v: %v", bws.addr, cmd, id, err) - // XXX this needs to be handled by the caller - bws.conn.CloseStatus(websocket.StatusProtocolError, - err.Error()) return } @@ -815,21 +809,19 @@ func (s *Server) newSession(bws *bfgWs) (string, error) { } } -func (s *Server) killSession(id string, why websocket.StatusCode) { +func (s *Server) deleteSession(id string) { + log.Tracef("deleteSession") + defer log.Tracef("deleteSession exit") + s.mtx.Lock() - bws, ok := s.sessions[id] + _, ok := s.sessions[id] if ok { delete(s.sessions, id) } s.mtx.Unlock() if !ok { - log.Errorf("killSession: id not found in sessions %s", id) - } else { - if err := bws.conn.CloseStatus(why, ""); err != nil { - // XXX this is too noisy. - log.Debugf("session close %v: %v", id, err) - } + log.Errorf("deleteSession: id not found in sessions %s", id) } } @@ -848,6 +840,7 @@ func (s *Server) handleWebsocketPrivate(w http.ResponseWriter, r *http.Request) r.RemoteAddr, err) return } + defer conn.Close(websocket.StatusProtocolError, "") bws := &bfgWs{ addr: r.RemoteAddr, @@ -865,7 +858,7 @@ func (s *Server) handleWebsocketPrivate(w http.ResponseWriter, r *http.Request) } defer func() { - s.killSession(bws.sessionId, websocket.StatusNormalClosure) + s.deleteSession(bws.sessionId) }() bws.wg.Add(1) @@ -956,7 +949,7 @@ func (s *Server) handleWebsocketPublic(w http.ResponseWriter, r *http.Request) { return } defer func() { - s.killSession(bws.sessionId, websocket.StatusNormalClosure) + s.deleteSession(bws.sessionId) }() // Always ping, required by protocol. @@ -1300,8 +1293,6 @@ func (s *Server) handleAccessPublicKeys(table string, action string, payload, pa return } - // XXX this is racing with killSession but protected. We should - // create a killSessions that takes an encoded PublicKey s.mtx.Lock() for _, v := range s.sessions { // if public key does not exist on session, it's not an authenticated @@ -1314,8 +1305,7 @@ func (s *Server) handleAccessPublicKeys(table string, action string, payload, pa // encoding, ensure that the session string does for an equal comparison sessionPublicKeyEncoded := fmt.Sprintf("\\x%s", hex.EncodeToString(v.publicKey)) if sessionPublicKeyEncoded == accessPublicKey.PublicKeyEncoded { - sessionId := v.sessionId - go s.killSession(sessionId, protocol.StatusHandshakeErr) + v.conn.CloseStatus(websocket.StatusProtocolError, "killed") } } s.mtx.Unlock()