From c424b75cf680a56aa4acf07efe5101ed3ed4f81c Mon Sep 17 00:00:00 2001 From: Daniel Ponte Date: Wed, 21 Dec 2022 13:22:18 -0500 Subject: [PATCH] concurrent ws --- pkg/auth/session.go | 10 +-- pkg/blas/core/core.go | 6 +- pkg/frontend/frontend.go | 3 +- pkg/server/websocket.go | 5 +- pkg/wsapi/api.go | 137 +++++++++++++++++++++++++++------------ pkg/wsapi/auth.go | 23 +++++-- 6 files changed, 125 insertions(+), 59 deletions(-) diff --git a/pkg/auth/session.go b/pkg/auth/session.go index 9a38c36..2bd8a7a 100644 --- a/pkg/auth/session.go +++ b/pkg/auth/session.go @@ -276,16 +276,16 @@ func (r *RefreshToken) AccessToken(req *http.Request) (string, error) { } func (a *authenticator) ValidateAccessToken(token AccessToken) *RefreshToken { - var uvIssRT *RefreshToken + var unverifiedIssRT *RefreshToken claims := &jwt.StandardClaims{} _, err := jwt.ParseWithClaims(string(token), claims, func(jt *jwt.Token) (interface{}, error) { iss := jt.Claims.(*jwt.StandardClaims).Issuer - uvIssRT = a.store.GetRefreshToken(RefreshTokenID(iss)) - if uvIssRT == nil { + unverifiedIssRT = a.store.GetRefreshToken(RefreshTokenID(iss)) + if unverifiedIssRT == nil { return nil, fmt.Errorf("bad token") } - return []byte(uvIssRT.JWTKey), nil + return []byte(unverifiedIssRT.JWTKey), nil }) if err != nil { @@ -293,7 +293,7 @@ func (a *authenticator) ValidateAccessToken(token AccessToken) *RefreshToken { return nil } - return uvIssRT + return unverifiedIssRT } func (a *authenticator) verifyAndGetCredential(tr *TokenRequest) *Credentials { diff --git a/pkg/blas/core/core.go b/pkg/blas/core/core.go index d5ef63c..4e12586 100644 --- a/pkg/blas/core/core.go +++ b/pkg/blas/core/core.go @@ -34,12 +34,12 @@ type WebSocketManager interface { } type WebSocketSession interface { - WSConn() *websocket.Conn - Go() error + Conn() *websocket.Conn + Go(context.Context) error Blas() Blas } -type Handler func(wss WebSocketSession, msgID int, cmd []string, msg interface{}) error +type Handler func(ctx context.Context, wss WebSocketSession, msgID int, cmd []string, msg interface{}) error type NewData func(cmd []string) interface{} type Shutdowner interface { diff --git a/pkg/frontend/frontend.go b/pkg/frontend/frontend.go index 47cc2ac..56f79ca 100644 --- a/pkg/frontend/frontend.go +++ b/pkg/frontend/frontend.go @@ -1,6 +1,7 @@ package frontend import ( + "context" "embed" "io" "io/fs" @@ -56,7 +57,7 @@ func newData(_ []string) interface{} { return map[string]interface{}{} } -func wsHand(wss core.WebSocketSession, msgID int, cmd []string, msg interface{}) error { +func wsHand(ctx context.Context, wss core.WebSocketSession, msgID int, cmd []string, msg interface{}) error { return nil } diff --git a/pkg/server/websocket.go b/pkg/server/websocket.go index 3bb05bf..ee9969b 100644 --- a/pkg/server/websocket.go +++ b/pkg/server/websocket.go @@ -1,6 +1,7 @@ package server import ( + "context" "dynatron.me/x/blasphem/pkg/wsapi" "github.com/gorilla/websocket" @@ -22,6 +23,8 @@ func (s *Server) wsHandler(c echo.Context) error { defer conn.Close() _ = log.Debug + ctx, cancel := context.WithCancel(c.Request().Context()) + defer cancel() - return wsapi.NewSession(s, c, conn).Go() + return wsapi.NewSession(s, c, conn).Go(ctx) } diff --git a/pkg/wsapi/api.go b/pkg/wsapi/api.go index 2277257..a15d6a4 100644 --- a/pkg/wsapi/api.go +++ b/pkg/wsapi/api.go @@ -1,9 +1,9 @@ package wsapi import ( + "context" "encoding/json" "errors" - "io" "strings" "dynatron.me/x/blasphem/pkg/auth" @@ -18,32 +18,26 @@ import ( var ( NoSuchHandlerErr = errors.New("bad websocket command") NoMessageIDErr = errors.New("no message ID") + AuthInvalidErr = errors.New("invalid auth") ) type Type string type MsgBase struct { + ID *int `json:"id,omitempty"` Type Type `json:"type"` } type ( wsSession struct { - *websocket.Conn + conn *websocket.Conn b core.Blas ec echo.Context - h phaseHandler + write chan<- interface{} user *auth.User refreshToken *auth.RefreshToken } - phaseHandler interface { - handleMsg(io.Reader) error - } - - cmdHandler struct { - *wsSession - } - wsEntry struct { dataNew core.NewData hnd core.Handler @@ -83,47 +77,100 @@ func NewManager() core.WebSocketManager { func NewSession(s core.Blas, c echo.Context, conn *websocket.Conn) core.WebSocketSession { ws := &wsSession{ - Conn: conn, + conn: conn, b: s, ec: c, } - ws.h = &authPhase{ws} - return ws } -func (ws *wsSession) WSConn() *websocket.Conn { - return ws.Conn +func (ws *wsSession) Conn() *websocket.Conn { + return ws.conn } func (ws *wsSession) Blas() core.Blas { return ws.b } -func (ws *wsSession) Go() error { +func (ws *wsSession) Go(ctx context.Context) error { + authP := &authPhase{ws} err := ws.sendAuthRequired() if err != nil { return err } + + _, rdr, err := ws.conn.NextReader() + if err != nil { + return err + } + + err = authP.handleMsg(ctx, rdr) + if err != nil { + return err + } + + // command phase + + msgChan := make(chan map[string]interface{}) + write := make(chan interface{}) + ws.write = write + defer close(write) + + cCtx, cancel := context.WithCancel(ctx) + defer cancel() + + go func(ctx context.Context, ch chan<- map[string]interface{}) { + defer close(ch) + for { + if err := ctx.Err(); err != nil { + return + } + + _, rdr, err := ws.conn.NextReader() + switch { + case err == nil: + case websocket.IsCloseError(err, websocket.CloseGoingAway): + return + case err != nil: + log.Error().Err(err).Str("remote", ws.ec.Request().RemoteAddr).Msg("websocket read fail") + return + } + + var msgMap map[string]interface{} + err = json.NewDecoder(rdr).Decode(&msgMap) + if err != nil { + ws.writeError(-1, Error{Code: "invalid_format", Message: err.Error()}) + } + + ch <- msgMap + + } + }(cCtx, msgChan) for { - _, rdr, err := ws.NextReader() - switch { - case err == nil: - case websocket.IsCloseError(err, websocket.CloseGoingAway): - return nil - case err != nil: - log.Error().Err(err).Str("remote", ws.ec.Request().RemoteAddr).Msg("websocket read fail") - return err - } + select { + case msg, ok := <-msgChan: + if !ok { + return nil + } - err = ws.h.handleMsg(rdr) - if err != nil { - log.Error().Err(err).Msg("handleMsg") + err = ws.handleMsg(ctx, msg) + if err != nil { + log.Error().Err(err).Msg("handleMsg") + } + case <-ctx.Done(): + close(msgChan) + return nil + case m := <-write: + err := ws.conn.WriteJSON(m) + if err != nil { + log.Error().Err(err).Msg("writeMsg") + } } } } -type cmdMsg struct { +func (ws *wsSession) Write(msg interface{}) { + ws.write <- msg } type MsgType string @@ -144,8 +191,8 @@ type WSError struct { Error Error `json:"error"` } -func (ws *cmdHandler) writeError(id int, err Error) error { - return ws.WriteJSON(WSError{ +func (ws *wsSession) writeError(id int, err Error) error { + return ws.conn.WriteJSON(WSError{ ID: &id, Type: ResultMsgType, Success: false, @@ -153,11 +200,9 @@ func (ws *cmdHandler) writeError(id int, err Error) error { }) } -func (ws *cmdHandler) handleMsg(r io.Reader) error { - var msgMap map[string]interface{} - err := json.NewDecoder(r).Decode(&msgMap) - if err != nil { - return err +func (ws *wsSession) handleMsg(ctx context.Context, msgMap map[string]interface{}) error { + if err := ctx.Err(); err != nil { + return nil } msgType, ok := msgMap["type"].(string) @@ -167,8 +212,7 @@ func (ws *cmdHandler) handleMsg(r io.Reader) error { idFl, ok := msgMap["id"].(float64) if !ok { - return ws.WriteJSON( - WSError{ + ws.Write(WSError{ Type: ResultMsgType, Success: false, Error: Error{ @@ -176,6 +220,7 @@ func (ws *cmdHandler) handleMsg(r io.Reader) error { Message: "command has no ID", }, }) + return nil } id := int(idFl) @@ -186,12 +231,15 @@ func (ws *cmdHandler) handleMsg(r io.Reader) error { switch err { case nil: case NoSuchHandlerErr: - return ws.writeError(id, Error{ + ws.writeError(id, Error{ Code: "invalid_type", Message: "no such command", }) + + return nil default: - return err + log.Error().Err(err).Msg("dispatch") + return nil } nd := newData(cmd) @@ -207,5 +255,10 @@ func (ws *cmdHandler) handleMsg(r io.Reader) error { } } - return hand(ws, id, cmd, nd) + err = hand(ctx, ws, id, cmd, nd) + if err != nil { + log.Error().Err(err).Msg("dispatch") + } + + return nil } diff --git a/pkg/wsapi/auth.go b/pkg/wsapi/auth.go index 5ecec80..412d690 100644 --- a/pkg/wsapi/auth.go +++ b/pkg/wsapi/auth.go @@ -1,6 +1,7 @@ package wsapi import ( + "context" "encoding/json" "io" @@ -18,10 +19,10 @@ func (ws *wsSession) sendAuthRequired() error { MsgBase Version string `json:"ha_version"` }{ - MsgBase{"auth_required"}, + MsgBase{Type: "auth_required"}, ws.b.Version(), } - return ws.WriteJSON(&authReq) + return ws.conn.WriteJSON(&authReq) } type authMsg struct { @@ -36,30 +37,33 @@ func (ap *authPhase) msgSchema() interface{} { func (ap *authPhase) finishAuth(rt *auth.RefreshToken) { ap.user = rt.User ap.refreshToken = rt - ap.h = &cmdHandler{ap.wsSession} } func (ap *authPhase) sendAuthOK() error { - return ap.WriteJSON(struct { + return ap.conn.WriteJSON(struct { Type string `json:"type"` Version string `json:"ha_version"` }{Type: "auth_ok", Version: ap.Blas().Version()}) } func (ap *authPhase) sendAuthInvalid() error { - return ap.WriteJSON(struct { + return ap.conn.WriteJSON(struct { Type string `json:"type"` Message string `json:"message"` }{Type: "auth_ok", Message: "invalid auth"}) } -func (ap *authPhase) handleMsg(r io.Reader) error { +func (ap *authPhase) handleMsg(ctx context.Context, r io.Reader) error { var authMsg authMsg err := json.NewDecoder(r).Decode(&authMsg) if err != nil { return err } + if err := ctx.Err(); err != nil { + return err + } + refreshToken := ap.b.ValidateAccessToken(authMsg.AccessToken) if refreshToken != nil { ap.finishAuth(refreshToken) @@ -68,5 +72,10 @@ func (ap *authPhase) handleMsg(r io.Reader) error { log.Error().Str("remote", ap.ec.Request().RemoteAddr).Msg("websocket auth failed") - return ap.sendAuthInvalid() + err = ap.sendAuthInvalid() + if err != nil { + return err + } + + return AuthInvalidErr }