concurrent ws

This commit is contained in:
Daniel Ponte 2022-12-21 13:22:18 -05:00
parent 8bf7379170
commit c424b75cf6
6 changed files with 125 additions and 59 deletions

View file

@ -276,16 +276,16 @@ func (r *RefreshToken) AccessToken(req *http.Request) (string, error) {
}
func (a *authenticator) ValidateAccessToken(token AccessToken) *RefreshToken {
var uvIssRT *RefreshToken
var unverifiedIssRT *RefreshToken
claims := &jwt.StandardClaims{}
_, err := jwt.ParseWithClaims(string(token), claims, func(jt *jwt.Token) (interface{}, error) {
iss := jt.Claims.(*jwt.StandardClaims).Issuer
uvIssRT = a.store.GetRefreshToken(RefreshTokenID(iss))
if uvIssRT == nil {
unverifiedIssRT = a.store.GetRefreshToken(RefreshTokenID(iss))
if unverifiedIssRT == nil {
return nil, fmt.Errorf("bad token")
}
return []byte(uvIssRT.JWTKey), nil
return []byte(unverifiedIssRT.JWTKey), nil
})
if err != nil {
@ -293,7 +293,7 @@ func (a *authenticator) ValidateAccessToken(token AccessToken) *RefreshToken {
return nil
}
return uvIssRT
return unverifiedIssRT
}
func (a *authenticator) verifyAndGetCredential(tr *TokenRequest) *Credentials {

View file

@ -34,12 +34,12 @@ type WebSocketManager interface {
}
type WebSocketSession interface {
WSConn() *websocket.Conn
Go() error
Conn() *websocket.Conn
Go(context.Context) error
Blas() Blas
}
type Handler func(wss WebSocketSession, msgID int, cmd []string, msg interface{}) error
type Handler func(ctx context.Context, wss WebSocketSession, msgID int, cmd []string, msg interface{}) error
type NewData func(cmd []string) interface{}
type Shutdowner interface {

View file

@ -1,6 +1,7 @@
package frontend
import (
"context"
"embed"
"io"
"io/fs"
@ -56,7 +57,7 @@ func newData(_ []string) interface{} {
return map[string]interface{}{}
}
func wsHand(wss core.WebSocketSession, msgID int, cmd []string, msg interface{}) error {
func wsHand(ctx context.Context, wss core.WebSocketSession, msgID int, cmd []string, msg interface{}) error {
return nil
}

View file

@ -1,6 +1,7 @@
package server
import (
"context"
"dynatron.me/x/blasphem/pkg/wsapi"
"github.com/gorilla/websocket"
@ -22,6 +23,8 @@ func (s *Server) wsHandler(c echo.Context) error {
defer conn.Close()
_ = log.Debug
ctx, cancel := context.WithCancel(c.Request().Context())
defer cancel()
return wsapi.NewSession(s, c, conn).Go()
return wsapi.NewSession(s, c, conn).Go(ctx)
}

View file

@ -1,9 +1,9 @@
package wsapi
import (
"context"
"encoding/json"
"errors"
"io"
"strings"
"dynatron.me/x/blasphem/pkg/auth"
@ -18,32 +18,26 @@ import (
var (
NoSuchHandlerErr = errors.New("bad websocket command")
NoMessageIDErr = errors.New("no message ID")
AuthInvalidErr = errors.New("invalid auth")
)
type Type string
type MsgBase struct {
ID *int `json:"id,omitempty"`
Type Type `json:"type"`
}
type (
wsSession struct {
*websocket.Conn
conn *websocket.Conn
b core.Blas
ec echo.Context
h phaseHandler
write chan<- interface{}
user *auth.User
refreshToken *auth.RefreshToken
}
phaseHandler interface {
handleMsg(io.Reader) error
}
cmdHandler struct {
*wsSession
}
wsEntry struct {
dataNew core.NewData
hnd core.Handler
@ -83,47 +77,100 @@ func NewManager() core.WebSocketManager {
func NewSession(s core.Blas, c echo.Context, conn *websocket.Conn) core.WebSocketSession {
ws := &wsSession{
Conn: conn,
conn: conn,
b: s,
ec: c,
}
ws.h = &authPhase{ws}
return ws
}
func (ws *wsSession) WSConn() *websocket.Conn {
return ws.Conn
func (ws *wsSession) Conn() *websocket.Conn {
return ws.conn
}
func (ws *wsSession) Blas() core.Blas { return ws.b }
func (ws *wsSession) Go() error {
func (ws *wsSession) Go(ctx context.Context) error {
authP := &authPhase{ws}
err := ws.sendAuthRequired()
if err != nil {
return err
}
_, rdr, err := ws.conn.NextReader()
if err != nil {
return err
}
err = authP.handleMsg(ctx, rdr)
if err != nil {
return err
}
// command phase
msgChan := make(chan map[string]interface{})
write := make(chan interface{})
ws.write = write
defer close(write)
cCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func(ctx context.Context, ch chan<- map[string]interface{}) {
defer close(ch)
for {
if err := ctx.Err(); err != nil {
return
}
_, rdr, err := ws.conn.NextReader()
switch {
case err == nil:
case websocket.IsCloseError(err, websocket.CloseGoingAway):
return
case err != nil:
log.Error().Err(err).Str("remote", ws.ec.Request().RemoteAddr).Msg("websocket read fail")
return
}
var msgMap map[string]interface{}
err = json.NewDecoder(rdr).Decode(&msgMap)
if err != nil {
ws.writeError(-1, Error{Code: "invalid_format", Message: err.Error()})
}
ch <- msgMap
}
}(cCtx, msgChan)
for {
_, rdr, err := ws.NextReader()
switch {
case err == nil:
case websocket.IsCloseError(err, websocket.CloseGoingAway):
return nil
case err != nil:
log.Error().Err(err).Str("remote", ws.ec.Request().RemoteAddr).Msg("websocket read fail")
return err
}
select {
case msg, ok := <-msgChan:
if !ok {
return nil
}
err = ws.h.handleMsg(rdr)
if err != nil {
log.Error().Err(err).Msg("handleMsg")
err = ws.handleMsg(ctx, msg)
if err != nil {
log.Error().Err(err).Msg("handleMsg")
}
case <-ctx.Done():
close(msgChan)
return nil
case m := <-write:
err := ws.conn.WriteJSON(m)
if err != nil {
log.Error().Err(err).Msg("writeMsg")
}
}
}
}
type cmdMsg struct {
func (ws *wsSession) Write(msg interface{}) {
ws.write <- msg
}
type MsgType string
@ -144,8 +191,8 @@ type WSError struct {
Error Error `json:"error"`
}
func (ws *cmdHandler) writeError(id int, err Error) error {
return ws.WriteJSON(WSError{
func (ws *wsSession) writeError(id int, err Error) error {
return ws.conn.WriteJSON(WSError{
ID: &id,
Type: ResultMsgType,
Success: false,
@ -153,11 +200,9 @@ func (ws *cmdHandler) writeError(id int, err Error) error {
})
}
func (ws *cmdHandler) handleMsg(r io.Reader) error {
var msgMap map[string]interface{}
err := json.NewDecoder(r).Decode(&msgMap)
if err != nil {
return err
func (ws *wsSession) handleMsg(ctx context.Context, msgMap map[string]interface{}) error {
if err := ctx.Err(); err != nil {
return nil
}
msgType, ok := msgMap["type"].(string)
@ -167,8 +212,7 @@ func (ws *cmdHandler) handleMsg(r io.Reader) error {
idFl, ok := msgMap["id"].(float64)
if !ok {
return ws.WriteJSON(
WSError{
ws.Write(WSError{
Type: ResultMsgType,
Success: false,
Error: Error{
@ -176,6 +220,7 @@ func (ws *cmdHandler) handleMsg(r io.Reader) error {
Message: "command has no ID",
},
})
return nil
}
id := int(idFl)
@ -186,12 +231,15 @@ func (ws *cmdHandler) handleMsg(r io.Reader) error {
switch err {
case nil:
case NoSuchHandlerErr:
return ws.writeError(id, Error{
ws.writeError(id, Error{
Code: "invalid_type",
Message: "no such command",
})
return nil
default:
return err
log.Error().Err(err).Msg("dispatch")
return nil
}
nd := newData(cmd)
@ -207,5 +255,10 @@ func (ws *cmdHandler) handleMsg(r io.Reader) error {
}
}
return hand(ws, id, cmd, nd)
err = hand(ctx, ws, id, cmd, nd)
if err != nil {
log.Error().Err(err).Msg("dispatch")
}
return nil
}

View file

@ -1,6 +1,7 @@
package wsapi
import (
"context"
"encoding/json"
"io"
@ -18,10 +19,10 @@ func (ws *wsSession) sendAuthRequired() error {
MsgBase
Version string `json:"ha_version"`
}{
MsgBase{"auth_required"},
MsgBase{Type: "auth_required"},
ws.b.Version(),
}
return ws.WriteJSON(&authReq)
return ws.conn.WriteJSON(&authReq)
}
type authMsg struct {
@ -36,30 +37,33 @@ func (ap *authPhase) msgSchema() interface{} {
func (ap *authPhase) finishAuth(rt *auth.RefreshToken) {
ap.user = rt.User
ap.refreshToken = rt
ap.h = &cmdHandler{ap.wsSession}
}
func (ap *authPhase) sendAuthOK() error {
return ap.WriteJSON(struct {
return ap.conn.WriteJSON(struct {
Type string `json:"type"`
Version string `json:"ha_version"`
}{Type: "auth_ok", Version: ap.Blas().Version()})
}
func (ap *authPhase) sendAuthInvalid() error {
return ap.WriteJSON(struct {
return ap.conn.WriteJSON(struct {
Type string `json:"type"`
Message string `json:"message"`
}{Type: "auth_ok", Message: "invalid auth"})
}
func (ap *authPhase) handleMsg(r io.Reader) error {
func (ap *authPhase) handleMsg(ctx context.Context, r io.Reader) error {
var authMsg authMsg
err := json.NewDecoder(r).Decode(&authMsg)
if err != nil {
return err
}
if err := ctx.Err(); err != nil {
return err
}
refreshToken := ap.b.ValidateAccessToken(authMsg.AccessToken)
if refreshToken != nil {
ap.finishAuth(refreshToken)
@ -68,5 +72,10 @@ func (ap *authPhase) handleMsg(r io.Reader) error {
log.Error().Str("remote", ap.ec.Request().RemoteAddr).Msg("websocket auth failed")
return ap.sendAuthInvalid()
err = ap.sendAuthInvalid()
if err != nil {
return err
}
return AuthInvalidErr
}