concurrent ws
This commit is contained in:
parent
8bf7379170
commit
c424b75cf6
6 changed files with 125 additions and 59 deletions
|
@ -276,16 +276,16 @@ func (r *RefreshToken) AccessToken(req *http.Request) (string, error) {
|
|||
}
|
||||
|
||||
func (a *authenticator) ValidateAccessToken(token AccessToken) *RefreshToken {
|
||||
var uvIssRT *RefreshToken
|
||||
var unverifiedIssRT *RefreshToken
|
||||
claims := &jwt.StandardClaims{}
|
||||
_, err := jwt.ParseWithClaims(string(token), claims, func(jt *jwt.Token) (interface{}, error) {
|
||||
iss := jt.Claims.(*jwt.StandardClaims).Issuer
|
||||
uvIssRT = a.store.GetRefreshToken(RefreshTokenID(iss))
|
||||
if uvIssRT == nil {
|
||||
unverifiedIssRT = a.store.GetRefreshToken(RefreshTokenID(iss))
|
||||
if unverifiedIssRT == nil {
|
||||
return nil, fmt.Errorf("bad token")
|
||||
}
|
||||
|
||||
return []byte(uvIssRT.JWTKey), nil
|
||||
return []byte(unverifiedIssRT.JWTKey), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
|
@ -293,7 +293,7 @@ func (a *authenticator) ValidateAccessToken(token AccessToken) *RefreshToken {
|
|||
return nil
|
||||
}
|
||||
|
||||
return uvIssRT
|
||||
return unverifiedIssRT
|
||||
}
|
||||
|
||||
func (a *authenticator) verifyAndGetCredential(tr *TokenRequest) *Credentials {
|
||||
|
|
|
@ -34,12 +34,12 @@ type WebSocketManager interface {
|
|||
}
|
||||
|
||||
type WebSocketSession interface {
|
||||
WSConn() *websocket.Conn
|
||||
Go() error
|
||||
Conn() *websocket.Conn
|
||||
Go(context.Context) error
|
||||
Blas() Blas
|
||||
}
|
||||
|
||||
type Handler func(wss WebSocketSession, msgID int, cmd []string, msg interface{}) error
|
||||
type Handler func(ctx context.Context, wss WebSocketSession, msgID int, cmd []string, msg interface{}) error
|
||||
type NewData func(cmd []string) interface{}
|
||||
|
||||
type Shutdowner interface {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package frontend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"io"
|
||||
"io/fs"
|
||||
|
@ -56,7 +57,7 @@ func newData(_ []string) interface{} {
|
|||
return map[string]interface{}{}
|
||||
}
|
||||
|
||||
func wsHand(wss core.WebSocketSession, msgID int, cmd []string, msg interface{}) error {
|
||||
func wsHand(ctx context.Context, wss core.WebSocketSession, msgID int, cmd []string, msg interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"dynatron.me/x/blasphem/pkg/wsapi"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
@ -22,6 +23,8 @@ func (s *Server) wsHandler(c echo.Context) error {
|
|||
defer conn.Close()
|
||||
|
||||
_ = log.Debug
|
||||
ctx, cancel := context.WithCancel(c.Request().Context())
|
||||
defer cancel()
|
||||
|
||||
return wsapi.NewSession(s, c, conn).Go()
|
||||
return wsapi.NewSession(s, c, conn).Go(ctx)
|
||||
}
|
||||
|
|
139
pkg/wsapi/api.go
139
pkg/wsapi/api.go
|
@ -1,9 +1,9 @@
|
|||
package wsapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"dynatron.me/x/blasphem/pkg/auth"
|
||||
|
@ -18,32 +18,26 @@ import (
|
|||
var (
|
||||
NoSuchHandlerErr = errors.New("bad websocket command")
|
||||
NoMessageIDErr = errors.New("no message ID")
|
||||
AuthInvalidErr = errors.New("invalid auth")
|
||||
)
|
||||
|
||||
type Type string
|
||||
type MsgBase struct {
|
||||
ID *int `json:"id,omitempty"`
|
||||
Type Type `json:"type"`
|
||||
}
|
||||
|
||||
type (
|
||||
wsSession struct {
|
||||
*websocket.Conn
|
||||
conn *websocket.Conn
|
||||
b core.Blas
|
||||
ec echo.Context
|
||||
h phaseHandler
|
||||
write chan<- interface{}
|
||||
|
||||
user *auth.User
|
||||
refreshToken *auth.RefreshToken
|
||||
}
|
||||
|
||||
phaseHandler interface {
|
||||
handleMsg(io.Reader) error
|
||||
}
|
||||
|
||||
cmdHandler struct {
|
||||
*wsSession
|
||||
}
|
||||
|
||||
wsEntry struct {
|
||||
dataNew core.NewData
|
||||
hnd core.Handler
|
||||
|
@ -83,47 +77,100 @@ func NewManager() core.WebSocketManager {
|
|||
|
||||
func NewSession(s core.Blas, c echo.Context, conn *websocket.Conn) core.WebSocketSession {
|
||||
ws := &wsSession{
|
||||
Conn: conn,
|
||||
conn: conn,
|
||||
b: s,
|
||||
ec: c,
|
||||
}
|
||||
|
||||
ws.h = &authPhase{ws}
|
||||
|
||||
return ws
|
||||
}
|
||||
|
||||
func (ws *wsSession) WSConn() *websocket.Conn {
|
||||
return ws.Conn
|
||||
func (ws *wsSession) Conn() *websocket.Conn {
|
||||
return ws.conn
|
||||
}
|
||||
|
||||
func (ws *wsSession) Blas() core.Blas { return ws.b }
|
||||
|
||||
func (ws *wsSession) Go() error {
|
||||
func (ws *wsSession) Go(ctx context.Context) error {
|
||||
authP := &authPhase{ws}
|
||||
err := ws.sendAuthRequired()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
_, rdr, err := ws.NextReader()
|
||||
switch {
|
||||
case err == nil:
|
||||
case websocket.IsCloseError(err, websocket.CloseGoingAway):
|
||||
return nil
|
||||
case err != nil:
|
||||
log.Error().Err(err).Str("remote", ws.ec.Request().RemoteAddr).Msg("websocket read fail")
|
||||
return err
|
||||
}
|
||||
_, rdr, err := ws.conn.NextReader()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = ws.h.handleMsg(rdr)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("handleMsg")
|
||||
err = authP.handleMsg(ctx, rdr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// command phase
|
||||
|
||||
msgChan := make(chan map[string]interface{})
|
||||
write := make(chan interface{})
|
||||
ws.write = write
|
||||
defer close(write)
|
||||
|
||||
cCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go func(ctx context.Context, ch chan<- map[string]interface{}) {
|
||||
defer close(ch)
|
||||
for {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, rdr, err := ws.conn.NextReader()
|
||||
switch {
|
||||
case err == nil:
|
||||
case websocket.IsCloseError(err, websocket.CloseGoingAway):
|
||||
return
|
||||
case err != nil:
|
||||
log.Error().Err(err).Str("remote", ws.ec.Request().RemoteAddr).Msg("websocket read fail")
|
||||
return
|
||||
}
|
||||
|
||||
var msgMap map[string]interface{}
|
||||
err = json.NewDecoder(rdr).Decode(&msgMap)
|
||||
if err != nil {
|
||||
ws.writeError(-1, Error{Code: "invalid_format", Message: err.Error()})
|
||||
}
|
||||
|
||||
ch <- msgMap
|
||||
|
||||
}
|
||||
}(cCtx, msgChan)
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg, ok := <-msgChan:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = ws.handleMsg(ctx, msg)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("handleMsg")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
close(msgChan)
|
||||
return nil
|
||||
case m := <-write:
|
||||
err := ws.conn.WriteJSON(m)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("writeMsg")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type cmdMsg struct {
|
||||
func (ws *wsSession) Write(msg interface{}) {
|
||||
ws.write <- msg
|
||||
}
|
||||
|
||||
type MsgType string
|
||||
|
@ -144,8 +191,8 @@ type WSError struct {
|
|||
Error Error `json:"error"`
|
||||
}
|
||||
|
||||
func (ws *cmdHandler) writeError(id int, err Error) error {
|
||||
return ws.WriteJSON(WSError{
|
||||
func (ws *wsSession) writeError(id int, err Error) error {
|
||||
return ws.conn.WriteJSON(WSError{
|
||||
ID: &id,
|
||||
Type: ResultMsgType,
|
||||
Success: false,
|
||||
|
@ -153,11 +200,9 @@ func (ws *cmdHandler) writeError(id int, err Error) error {
|
|||
})
|
||||
}
|
||||
|
||||
func (ws *cmdHandler) handleMsg(r io.Reader) error {
|
||||
var msgMap map[string]interface{}
|
||||
err := json.NewDecoder(r).Decode(&msgMap)
|
||||
if err != nil {
|
||||
return err
|
||||
func (ws *wsSession) handleMsg(ctx context.Context, msgMap map[string]interface{}) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
msgType, ok := msgMap["type"].(string)
|
||||
|
@ -167,8 +212,7 @@ func (ws *cmdHandler) handleMsg(r io.Reader) error {
|
|||
|
||||
idFl, ok := msgMap["id"].(float64)
|
||||
if !ok {
|
||||
return ws.WriteJSON(
|
||||
WSError{
|
||||
ws.Write(WSError{
|
||||
Type: ResultMsgType,
|
||||
Success: false,
|
||||
Error: Error{
|
||||
|
@ -176,6 +220,7 @@ func (ws *cmdHandler) handleMsg(r io.Reader) error {
|
|||
Message: "command has no ID",
|
||||
},
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
id := int(idFl)
|
||||
|
@ -186,12 +231,15 @@ func (ws *cmdHandler) handleMsg(r io.Reader) error {
|
|||
switch err {
|
||||
case nil:
|
||||
case NoSuchHandlerErr:
|
||||
return ws.writeError(id, Error{
|
||||
ws.writeError(id, Error{
|
||||
Code: "invalid_type",
|
||||
Message: "no such command",
|
||||
})
|
||||
|
||||
return nil
|
||||
default:
|
||||
return err
|
||||
log.Error().Err(err).Msg("dispatch")
|
||||
return nil
|
||||
}
|
||||
|
||||
nd := newData(cmd)
|
||||
|
@ -207,5 +255,10 @@ func (ws *cmdHandler) handleMsg(r io.Reader) error {
|
|||
}
|
||||
}
|
||||
|
||||
return hand(ws, id, cmd, nd)
|
||||
err = hand(ctx, ws, id, cmd, nd)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("dispatch")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package wsapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
|
||||
|
@ -18,10 +19,10 @@ func (ws *wsSession) sendAuthRequired() error {
|
|||
MsgBase
|
||||
Version string `json:"ha_version"`
|
||||
}{
|
||||
MsgBase{"auth_required"},
|
||||
MsgBase{Type: "auth_required"},
|
||||
ws.b.Version(),
|
||||
}
|
||||
return ws.WriteJSON(&authReq)
|
||||
return ws.conn.WriteJSON(&authReq)
|
||||
}
|
||||
|
||||
type authMsg struct {
|
||||
|
@ -36,30 +37,33 @@ func (ap *authPhase) msgSchema() interface{} {
|
|||
func (ap *authPhase) finishAuth(rt *auth.RefreshToken) {
|
||||
ap.user = rt.User
|
||||
ap.refreshToken = rt
|
||||
ap.h = &cmdHandler{ap.wsSession}
|
||||
}
|
||||
|
||||
func (ap *authPhase) sendAuthOK() error {
|
||||
return ap.WriteJSON(struct {
|
||||
return ap.conn.WriteJSON(struct {
|
||||
Type string `json:"type"`
|
||||
Version string `json:"ha_version"`
|
||||
}{Type: "auth_ok", Version: ap.Blas().Version()})
|
||||
}
|
||||
|
||||
func (ap *authPhase) sendAuthInvalid() error {
|
||||
return ap.WriteJSON(struct {
|
||||
return ap.conn.WriteJSON(struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}{Type: "auth_ok", Message: "invalid auth"})
|
||||
}
|
||||
|
||||
func (ap *authPhase) handleMsg(r io.Reader) error {
|
||||
func (ap *authPhase) handleMsg(ctx context.Context, r io.Reader) error {
|
||||
var authMsg authMsg
|
||||
err := json.NewDecoder(r).Decode(&authMsg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
refreshToken := ap.b.ValidateAccessToken(authMsg.AccessToken)
|
||||
if refreshToken != nil {
|
||||
ap.finishAuth(refreshToken)
|
||||
|
@ -68,5 +72,10 @@ func (ap *authPhase) handleMsg(r io.Reader) error {
|
|||
|
||||
log.Error().Str("remote", ap.ec.Request().RemoteAddr).Msg("websocket auth failed")
|
||||
|
||||
return ap.sendAuthInvalid()
|
||||
err = ap.sendAuthInvalid()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return AuthInvalidErr
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue