concurrent ws

This commit is contained in:
Daniel Ponte 2022-12-21 13:22:18 -05:00
parent 8bf7379170
commit c424b75cf6
6 changed files with 125 additions and 59 deletions

View file

@ -276,16 +276,16 @@ func (r *RefreshToken) AccessToken(req *http.Request) (string, error) {
} }
func (a *authenticator) ValidateAccessToken(token AccessToken) *RefreshToken { func (a *authenticator) ValidateAccessToken(token AccessToken) *RefreshToken {
var uvIssRT *RefreshToken var unverifiedIssRT *RefreshToken
claims := &jwt.StandardClaims{} claims := &jwt.StandardClaims{}
_, err := jwt.ParseWithClaims(string(token), claims, func(jt *jwt.Token) (interface{}, error) { _, err := jwt.ParseWithClaims(string(token), claims, func(jt *jwt.Token) (interface{}, error) {
iss := jt.Claims.(*jwt.StandardClaims).Issuer iss := jt.Claims.(*jwt.StandardClaims).Issuer
uvIssRT = a.store.GetRefreshToken(RefreshTokenID(iss)) unverifiedIssRT = a.store.GetRefreshToken(RefreshTokenID(iss))
if uvIssRT == nil { if unverifiedIssRT == nil {
return nil, fmt.Errorf("bad token") return nil, fmt.Errorf("bad token")
} }
return []byte(uvIssRT.JWTKey), nil return []byte(unverifiedIssRT.JWTKey), nil
}) })
if err != nil { if err != nil {
@ -293,7 +293,7 @@ func (a *authenticator) ValidateAccessToken(token AccessToken) *RefreshToken {
return nil return nil
} }
return uvIssRT return unverifiedIssRT
} }
func (a *authenticator) verifyAndGetCredential(tr *TokenRequest) *Credentials { func (a *authenticator) verifyAndGetCredential(tr *TokenRequest) *Credentials {

View file

@ -34,12 +34,12 @@ type WebSocketManager interface {
} }
type WebSocketSession interface { type WebSocketSession interface {
WSConn() *websocket.Conn Conn() *websocket.Conn
Go() error Go(context.Context) error
Blas() Blas Blas() Blas
} }
type Handler func(wss WebSocketSession, msgID int, cmd []string, msg interface{}) error type Handler func(ctx context.Context, wss WebSocketSession, msgID int, cmd []string, msg interface{}) error
type NewData func(cmd []string) interface{} type NewData func(cmd []string) interface{}
type Shutdowner interface { type Shutdowner interface {

View file

@ -1,6 +1,7 @@
package frontend package frontend
import ( import (
"context"
"embed" "embed"
"io" "io"
"io/fs" "io/fs"
@ -56,7 +57,7 @@ func newData(_ []string) interface{} {
return map[string]interface{}{} return map[string]interface{}{}
} }
func wsHand(wss core.WebSocketSession, msgID int, cmd []string, msg interface{}) error { func wsHand(ctx context.Context, wss core.WebSocketSession, msgID int, cmd []string, msg interface{}) error {
return nil return nil
} }

View file

@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"dynatron.me/x/blasphem/pkg/wsapi" "dynatron.me/x/blasphem/pkg/wsapi"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -22,6 +23,8 @@ func (s *Server) wsHandler(c echo.Context) error {
defer conn.Close() defer conn.Close()
_ = log.Debug _ = log.Debug
ctx, cancel := context.WithCancel(c.Request().Context())
defer cancel()
return wsapi.NewSession(s, c, conn).Go() return wsapi.NewSession(s, c, conn).Go(ctx)
} }

View file

@ -1,9 +1,9 @@
package wsapi package wsapi
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"io"
"strings" "strings"
"dynatron.me/x/blasphem/pkg/auth" "dynatron.me/x/blasphem/pkg/auth"
@ -18,32 +18,26 @@ import (
var ( var (
NoSuchHandlerErr = errors.New("bad websocket command") NoSuchHandlerErr = errors.New("bad websocket command")
NoMessageIDErr = errors.New("no message ID") NoMessageIDErr = errors.New("no message ID")
AuthInvalidErr = errors.New("invalid auth")
) )
type Type string type Type string
type MsgBase struct { type MsgBase struct {
ID *int `json:"id,omitempty"`
Type Type `json:"type"` Type Type `json:"type"`
} }
type ( type (
wsSession struct { wsSession struct {
*websocket.Conn conn *websocket.Conn
b core.Blas b core.Blas
ec echo.Context ec echo.Context
h phaseHandler write chan<- interface{}
user *auth.User user *auth.User
refreshToken *auth.RefreshToken refreshToken *auth.RefreshToken
} }
phaseHandler interface {
handleMsg(io.Reader) error
}
cmdHandler struct {
*wsSession
}
wsEntry struct { wsEntry struct {
dataNew core.NewData dataNew core.NewData
hnd core.Handler hnd core.Handler
@ -83,47 +77,100 @@ func NewManager() core.WebSocketManager {
func NewSession(s core.Blas, c echo.Context, conn *websocket.Conn) core.WebSocketSession { func NewSession(s core.Blas, c echo.Context, conn *websocket.Conn) core.WebSocketSession {
ws := &wsSession{ ws := &wsSession{
Conn: conn, conn: conn,
b: s, b: s,
ec: c, ec: c,
} }
ws.h = &authPhase{ws}
return ws return ws
} }
func (ws *wsSession) WSConn() *websocket.Conn { func (ws *wsSession) Conn() *websocket.Conn {
return ws.Conn return ws.conn
} }
func (ws *wsSession) Blas() core.Blas { return ws.b } func (ws *wsSession) Blas() core.Blas { return ws.b }
func (ws *wsSession) Go() error { func (ws *wsSession) Go(ctx context.Context) error {
authP := &authPhase{ws}
err := ws.sendAuthRequired() err := ws.sendAuthRequired()
if err != nil { if err != nil {
return err return err
} }
for { _, rdr, err := ws.conn.NextReader()
_, rdr, err := ws.NextReader() if err != nil {
switch { return err
case err == nil: }
case websocket.IsCloseError(err, websocket.CloseGoingAway):
return nil
case err != nil:
log.Error().Err(err).Str("remote", ws.ec.Request().RemoteAddr).Msg("websocket read fail")
return err
}
err = ws.h.handleMsg(rdr) err = authP.handleMsg(ctx, rdr)
if err != nil { if err != nil {
log.Error().Err(err).Msg("handleMsg") return err
}
// command phase
msgChan := make(chan map[string]interface{})
write := make(chan interface{})
ws.write = write
defer close(write)
cCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func(ctx context.Context, ch chan<- map[string]interface{}) {
defer close(ch)
for {
if err := ctx.Err(); err != nil {
return
}
_, rdr, err := ws.conn.NextReader()
switch {
case err == nil:
case websocket.IsCloseError(err, websocket.CloseGoingAway):
return
case err != nil:
log.Error().Err(err).Str("remote", ws.ec.Request().RemoteAddr).Msg("websocket read fail")
return
}
var msgMap map[string]interface{}
err = json.NewDecoder(rdr).Decode(&msgMap)
if err != nil {
ws.writeError(-1, Error{Code: "invalid_format", Message: err.Error()})
}
ch <- msgMap
}
}(cCtx, msgChan)
for {
select {
case msg, ok := <-msgChan:
if !ok {
return nil
}
err = ws.handleMsg(ctx, msg)
if err != nil {
log.Error().Err(err).Msg("handleMsg")
}
case <-ctx.Done():
close(msgChan)
return nil
case m := <-write:
err := ws.conn.WriteJSON(m)
if err != nil {
log.Error().Err(err).Msg("writeMsg")
}
} }
} }
} }
type cmdMsg struct { func (ws *wsSession) Write(msg interface{}) {
ws.write <- msg
} }
type MsgType string type MsgType string
@ -144,8 +191,8 @@ type WSError struct {
Error Error `json:"error"` Error Error `json:"error"`
} }
func (ws *cmdHandler) writeError(id int, err Error) error { func (ws *wsSession) writeError(id int, err Error) error {
return ws.WriteJSON(WSError{ return ws.conn.WriteJSON(WSError{
ID: &id, ID: &id,
Type: ResultMsgType, Type: ResultMsgType,
Success: false, Success: false,
@ -153,11 +200,9 @@ func (ws *cmdHandler) writeError(id int, err Error) error {
}) })
} }
func (ws *cmdHandler) handleMsg(r io.Reader) error { func (ws *wsSession) handleMsg(ctx context.Context, msgMap map[string]interface{}) error {
var msgMap map[string]interface{} if err := ctx.Err(); err != nil {
err := json.NewDecoder(r).Decode(&msgMap) return nil
if err != nil {
return err
} }
msgType, ok := msgMap["type"].(string) msgType, ok := msgMap["type"].(string)
@ -167,8 +212,7 @@ func (ws *cmdHandler) handleMsg(r io.Reader) error {
idFl, ok := msgMap["id"].(float64) idFl, ok := msgMap["id"].(float64)
if !ok { if !ok {
return ws.WriteJSON( ws.Write(WSError{
WSError{
Type: ResultMsgType, Type: ResultMsgType,
Success: false, Success: false,
Error: Error{ Error: Error{
@ -176,6 +220,7 @@ func (ws *cmdHandler) handleMsg(r io.Reader) error {
Message: "command has no ID", Message: "command has no ID",
}, },
}) })
return nil
} }
id := int(idFl) id := int(idFl)
@ -186,12 +231,15 @@ func (ws *cmdHandler) handleMsg(r io.Reader) error {
switch err { switch err {
case nil: case nil:
case NoSuchHandlerErr: case NoSuchHandlerErr:
return ws.writeError(id, Error{ ws.writeError(id, Error{
Code: "invalid_type", Code: "invalid_type",
Message: "no such command", Message: "no such command",
}) })
return nil
default: default:
return err log.Error().Err(err).Msg("dispatch")
return nil
} }
nd := newData(cmd) nd := newData(cmd)
@ -207,5 +255,10 @@ func (ws *cmdHandler) handleMsg(r io.Reader) error {
} }
} }
return hand(ws, id, cmd, nd) err = hand(ctx, ws, id, cmd, nd)
if err != nil {
log.Error().Err(err).Msg("dispatch")
}
return nil
} }

View file

@ -1,6 +1,7 @@
package wsapi package wsapi
import ( import (
"context"
"encoding/json" "encoding/json"
"io" "io"
@ -18,10 +19,10 @@ func (ws *wsSession) sendAuthRequired() error {
MsgBase MsgBase
Version string `json:"ha_version"` Version string `json:"ha_version"`
}{ }{
MsgBase{"auth_required"}, MsgBase{Type: "auth_required"},
ws.b.Version(), ws.b.Version(),
} }
return ws.WriteJSON(&authReq) return ws.conn.WriteJSON(&authReq)
} }
type authMsg struct { type authMsg struct {
@ -36,30 +37,33 @@ func (ap *authPhase) msgSchema() interface{} {
func (ap *authPhase) finishAuth(rt *auth.RefreshToken) { func (ap *authPhase) finishAuth(rt *auth.RefreshToken) {
ap.user = rt.User ap.user = rt.User
ap.refreshToken = rt ap.refreshToken = rt
ap.h = &cmdHandler{ap.wsSession}
} }
func (ap *authPhase) sendAuthOK() error { func (ap *authPhase) sendAuthOK() error {
return ap.WriteJSON(struct { return ap.conn.WriteJSON(struct {
Type string `json:"type"` Type string `json:"type"`
Version string `json:"ha_version"` Version string `json:"ha_version"`
}{Type: "auth_ok", Version: ap.Blas().Version()}) }{Type: "auth_ok", Version: ap.Blas().Version()})
} }
func (ap *authPhase) sendAuthInvalid() error { func (ap *authPhase) sendAuthInvalid() error {
return ap.WriteJSON(struct { return ap.conn.WriteJSON(struct {
Type string `json:"type"` Type string `json:"type"`
Message string `json:"message"` Message string `json:"message"`
}{Type: "auth_ok", Message: "invalid auth"}) }{Type: "auth_ok", Message: "invalid auth"})
} }
func (ap *authPhase) handleMsg(r io.Reader) error { func (ap *authPhase) handleMsg(ctx context.Context, r io.Reader) error {
var authMsg authMsg var authMsg authMsg
err := json.NewDecoder(r).Decode(&authMsg) err := json.NewDecoder(r).Decode(&authMsg)
if err != nil { if err != nil {
return err return err
} }
if err := ctx.Err(); err != nil {
return err
}
refreshToken := ap.b.ValidateAccessToken(authMsg.AccessToken) refreshToken := ap.b.ValidateAccessToken(authMsg.AccessToken)
if refreshToken != nil { if refreshToken != nil {
ap.finishAuth(refreshToken) ap.finishAuth(refreshToken)
@ -68,5 +72,10 @@ func (ap *authPhase) handleMsg(r io.Reader) error {
log.Error().Str("remote", ap.ec.Request().RemoteAddr).Msg("websocket auth failed") log.Error().Str("remote", ap.ec.Request().RemoteAddr).Msg("websocket auth failed")
return ap.sendAuthInvalid() err = ap.sendAuthInvalid()
if err != nil {
return err
}
return AuthInvalidErr
} }