concurrent ws
This commit is contained in:
parent
8bf7379170
commit
c424b75cf6
6 changed files with 125 additions and 59 deletions
|
@ -276,16 +276,16 @@ func (r *RefreshToken) AccessToken(req *http.Request) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *authenticator) ValidateAccessToken(token AccessToken) *RefreshToken {
|
func (a *authenticator) ValidateAccessToken(token AccessToken) *RefreshToken {
|
||||||
var uvIssRT *RefreshToken
|
var unverifiedIssRT *RefreshToken
|
||||||
claims := &jwt.StandardClaims{}
|
claims := &jwt.StandardClaims{}
|
||||||
_, err := jwt.ParseWithClaims(string(token), claims, func(jt *jwt.Token) (interface{}, error) {
|
_, err := jwt.ParseWithClaims(string(token), claims, func(jt *jwt.Token) (interface{}, error) {
|
||||||
iss := jt.Claims.(*jwt.StandardClaims).Issuer
|
iss := jt.Claims.(*jwt.StandardClaims).Issuer
|
||||||
uvIssRT = a.store.GetRefreshToken(RefreshTokenID(iss))
|
unverifiedIssRT = a.store.GetRefreshToken(RefreshTokenID(iss))
|
||||||
if uvIssRT == nil {
|
if unverifiedIssRT == nil {
|
||||||
return nil, fmt.Errorf("bad token")
|
return nil, fmt.Errorf("bad token")
|
||||||
}
|
}
|
||||||
|
|
||||||
return []byte(uvIssRT.JWTKey), nil
|
return []byte(unverifiedIssRT.JWTKey), nil
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -293,7 +293,7 @@ func (a *authenticator) ValidateAccessToken(token AccessToken) *RefreshToken {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return uvIssRT
|
return unverifiedIssRT
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *authenticator) verifyAndGetCredential(tr *TokenRequest) *Credentials {
|
func (a *authenticator) verifyAndGetCredential(tr *TokenRequest) *Credentials {
|
||||||
|
|
|
@ -34,12 +34,12 @@ type WebSocketManager interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type WebSocketSession interface {
|
type WebSocketSession interface {
|
||||||
WSConn() *websocket.Conn
|
Conn() *websocket.Conn
|
||||||
Go() error
|
Go(context.Context) error
|
||||||
Blas() Blas
|
Blas() Blas
|
||||||
}
|
}
|
||||||
|
|
||||||
type Handler func(wss WebSocketSession, msgID int, cmd []string, msg interface{}) error
|
type Handler func(ctx context.Context, wss WebSocketSession, msgID int, cmd []string, msg interface{}) error
|
||||||
type NewData func(cmd []string) interface{}
|
type NewData func(cmd []string) interface{}
|
||||||
|
|
||||||
type Shutdowner interface {
|
type Shutdowner interface {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package frontend
|
package frontend
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"embed"
|
"embed"
|
||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
|
@ -56,7 +57,7 @@ func newData(_ []string) interface{} {
|
||||||
return map[string]interface{}{}
|
return map[string]interface{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func wsHand(wss core.WebSocketSession, msgID int, cmd []string, msg interface{}) error {
|
func wsHand(ctx context.Context, wss core.WebSocketSession, msgID int, cmd []string, msg interface{}) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"dynatron.me/x/blasphem/pkg/wsapi"
|
"dynatron.me/x/blasphem/pkg/wsapi"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
|
@ -22,6 +23,8 @@ func (s *Server) wsHandler(c echo.Context) error {
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
_ = log.Debug
|
_ = log.Debug
|
||||||
|
ctx, cancel := context.WithCancel(c.Request().Context())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
return wsapi.NewSession(s, c, conn).Go()
|
return wsapi.NewSession(s, c, conn).Go(ctx)
|
||||||
}
|
}
|
||||||
|
|
137
pkg/wsapi/api.go
137
pkg/wsapi/api.go
|
@ -1,9 +1,9 @@
|
||||||
package wsapi
|
package wsapi
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"dynatron.me/x/blasphem/pkg/auth"
|
"dynatron.me/x/blasphem/pkg/auth"
|
||||||
|
@ -18,32 +18,26 @@ import (
|
||||||
var (
|
var (
|
||||||
NoSuchHandlerErr = errors.New("bad websocket command")
|
NoSuchHandlerErr = errors.New("bad websocket command")
|
||||||
NoMessageIDErr = errors.New("no message ID")
|
NoMessageIDErr = errors.New("no message ID")
|
||||||
|
AuthInvalidErr = errors.New("invalid auth")
|
||||||
)
|
)
|
||||||
|
|
||||||
type Type string
|
type Type string
|
||||||
type MsgBase struct {
|
type MsgBase struct {
|
||||||
|
ID *int `json:"id,omitempty"`
|
||||||
Type Type `json:"type"`
|
Type Type `json:"type"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type (
|
type (
|
||||||
wsSession struct {
|
wsSession struct {
|
||||||
*websocket.Conn
|
conn *websocket.Conn
|
||||||
b core.Blas
|
b core.Blas
|
||||||
ec echo.Context
|
ec echo.Context
|
||||||
h phaseHandler
|
write chan<- interface{}
|
||||||
|
|
||||||
user *auth.User
|
user *auth.User
|
||||||
refreshToken *auth.RefreshToken
|
refreshToken *auth.RefreshToken
|
||||||
}
|
}
|
||||||
|
|
||||||
phaseHandler interface {
|
|
||||||
handleMsg(io.Reader) error
|
|
||||||
}
|
|
||||||
|
|
||||||
cmdHandler struct {
|
|
||||||
*wsSession
|
|
||||||
}
|
|
||||||
|
|
||||||
wsEntry struct {
|
wsEntry struct {
|
||||||
dataNew core.NewData
|
dataNew core.NewData
|
||||||
hnd core.Handler
|
hnd core.Handler
|
||||||
|
@ -83,47 +77,100 @@ func NewManager() core.WebSocketManager {
|
||||||
|
|
||||||
func NewSession(s core.Blas, c echo.Context, conn *websocket.Conn) core.WebSocketSession {
|
func NewSession(s core.Blas, c echo.Context, conn *websocket.Conn) core.WebSocketSession {
|
||||||
ws := &wsSession{
|
ws := &wsSession{
|
||||||
Conn: conn,
|
conn: conn,
|
||||||
b: s,
|
b: s,
|
||||||
ec: c,
|
ec: c,
|
||||||
}
|
}
|
||||||
|
|
||||||
ws.h = &authPhase{ws}
|
|
||||||
|
|
||||||
return ws
|
return ws
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *wsSession) WSConn() *websocket.Conn {
|
func (ws *wsSession) Conn() *websocket.Conn {
|
||||||
return ws.Conn
|
return ws.conn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *wsSession) Blas() core.Blas { return ws.b }
|
func (ws *wsSession) Blas() core.Blas { return ws.b }
|
||||||
|
|
||||||
func (ws *wsSession) Go() error {
|
func (ws *wsSession) Go(ctx context.Context) error {
|
||||||
|
authP := &authPhase{ws}
|
||||||
err := ws.sendAuthRequired()
|
err := ws.sendAuthRequired()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, rdr, err := ws.conn.NextReader()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = authP.handleMsg(ctx, rdr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// command phase
|
||||||
|
|
||||||
|
msgChan := make(chan map[string]interface{})
|
||||||
|
write := make(chan interface{})
|
||||||
|
ws.write = write
|
||||||
|
defer close(write)
|
||||||
|
|
||||||
|
cCtx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
go func(ctx context.Context, ch chan<- map[string]interface{}) {
|
||||||
|
defer close(ch)
|
||||||
|
for {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, rdr, err := ws.conn.NextReader()
|
||||||
|
switch {
|
||||||
|
case err == nil:
|
||||||
|
case websocket.IsCloseError(err, websocket.CloseGoingAway):
|
||||||
|
return
|
||||||
|
case err != nil:
|
||||||
|
log.Error().Err(err).Str("remote", ws.ec.Request().RemoteAddr).Msg("websocket read fail")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var msgMap map[string]interface{}
|
||||||
|
err = json.NewDecoder(rdr).Decode(&msgMap)
|
||||||
|
if err != nil {
|
||||||
|
ws.writeError(-1, Error{Code: "invalid_format", Message: err.Error()})
|
||||||
|
}
|
||||||
|
|
||||||
|
ch <- msgMap
|
||||||
|
|
||||||
|
}
|
||||||
|
}(cCtx, msgChan)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
_, rdr, err := ws.NextReader()
|
select {
|
||||||
switch {
|
case msg, ok := <-msgChan:
|
||||||
case err == nil:
|
if !ok {
|
||||||
case websocket.IsCloseError(err, websocket.CloseGoingAway):
|
return nil
|
||||||
return nil
|
}
|
||||||
case err != nil:
|
|
||||||
log.Error().Err(err).Str("remote", ws.ec.Request().RemoteAddr).Msg("websocket read fail")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = ws.h.handleMsg(rdr)
|
err = ws.handleMsg(ctx, msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("handleMsg")
|
log.Error().Err(err).Msg("handleMsg")
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
close(msgChan)
|
||||||
|
return nil
|
||||||
|
case m := <-write:
|
||||||
|
err := ws.conn.WriteJSON(m)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("writeMsg")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type cmdMsg struct {
|
func (ws *wsSession) Write(msg interface{}) {
|
||||||
|
ws.write <- msg
|
||||||
}
|
}
|
||||||
|
|
||||||
type MsgType string
|
type MsgType string
|
||||||
|
@ -144,8 +191,8 @@ type WSError struct {
|
||||||
Error Error `json:"error"`
|
Error Error `json:"error"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *cmdHandler) writeError(id int, err Error) error {
|
func (ws *wsSession) writeError(id int, err Error) error {
|
||||||
return ws.WriteJSON(WSError{
|
return ws.conn.WriteJSON(WSError{
|
||||||
ID: &id,
|
ID: &id,
|
||||||
Type: ResultMsgType,
|
Type: ResultMsgType,
|
||||||
Success: false,
|
Success: false,
|
||||||
|
@ -153,11 +200,9 @@ func (ws *cmdHandler) writeError(id int, err Error) error {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *cmdHandler) handleMsg(r io.Reader) error {
|
func (ws *wsSession) handleMsg(ctx context.Context, msgMap map[string]interface{}) error {
|
||||||
var msgMap map[string]interface{}
|
if err := ctx.Err(); err != nil {
|
||||||
err := json.NewDecoder(r).Decode(&msgMap)
|
return nil
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
msgType, ok := msgMap["type"].(string)
|
msgType, ok := msgMap["type"].(string)
|
||||||
|
@ -167,8 +212,7 @@ func (ws *cmdHandler) handleMsg(r io.Reader) error {
|
||||||
|
|
||||||
idFl, ok := msgMap["id"].(float64)
|
idFl, ok := msgMap["id"].(float64)
|
||||||
if !ok {
|
if !ok {
|
||||||
return ws.WriteJSON(
|
ws.Write(WSError{
|
||||||
WSError{
|
|
||||||
Type: ResultMsgType,
|
Type: ResultMsgType,
|
||||||
Success: false,
|
Success: false,
|
||||||
Error: Error{
|
Error: Error{
|
||||||
|
@ -176,6 +220,7 @@ func (ws *cmdHandler) handleMsg(r io.Reader) error {
|
||||||
Message: "command has no ID",
|
Message: "command has no ID",
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
id := int(idFl)
|
id := int(idFl)
|
||||||
|
@ -186,12 +231,15 @@ func (ws *cmdHandler) handleMsg(r io.Reader) error {
|
||||||
switch err {
|
switch err {
|
||||||
case nil:
|
case nil:
|
||||||
case NoSuchHandlerErr:
|
case NoSuchHandlerErr:
|
||||||
return ws.writeError(id, Error{
|
ws.writeError(id, Error{
|
||||||
Code: "invalid_type",
|
Code: "invalid_type",
|
||||||
Message: "no such command",
|
Message: "no such command",
|
||||||
})
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
default:
|
default:
|
||||||
return err
|
log.Error().Err(err).Msg("dispatch")
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
nd := newData(cmd)
|
nd := newData(cmd)
|
||||||
|
@ -207,5 +255,10 @@ func (ws *cmdHandler) handleMsg(r io.Reader) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return hand(ws, id, cmd, nd)
|
err = hand(ctx, ws, id, cmd, nd)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("dispatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package wsapi
|
package wsapi
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
@ -18,10 +19,10 @@ func (ws *wsSession) sendAuthRequired() error {
|
||||||
MsgBase
|
MsgBase
|
||||||
Version string `json:"ha_version"`
|
Version string `json:"ha_version"`
|
||||||
}{
|
}{
|
||||||
MsgBase{"auth_required"},
|
MsgBase{Type: "auth_required"},
|
||||||
ws.b.Version(),
|
ws.b.Version(),
|
||||||
}
|
}
|
||||||
return ws.WriteJSON(&authReq)
|
return ws.conn.WriteJSON(&authReq)
|
||||||
}
|
}
|
||||||
|
|
||||||
type authMsg struct {
|
type authMsg struct {
|
||||||
|
@ -36,30 +37,33 @@ func (ap *authPhase) msgSchema() interface{} {
|
||||||
func (ap *authPhase) finishAuth(rt *auth.RefreshToken) {
|
func (ap *authPhase) finishAuth(rt *auth.RefreshToken) {
|
||||||
ap.user = rt.User
|
ap.user = rt.User
|
||||||
ap.refreshToken = rt
|
ap.refreshToken = rt
|
||||||
ap.h = &cmdHandler{ap.wsSession}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ap *authPhase) sendAuthOK() error {
|
func (ap *authPhase) sendAuthOK() error {
|
||||||
return ap.WriteJSON(struct {
|
return ap.conn.WriteJSON(struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Version string `json:"ha_version"`
|
Version string `json:"ha_version"`
|
||||||
}{Type: "auth_ok", Version: ap.Blas().Version()})
|
}{Type: "auth_ok", Version: ap.Blas().Version()})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ap *authPhase) sendAuthInvalid() error {
|
func (ap *authPhase) sendAuthInvalid() error {
|
||||||
return ap.WriteJSON(struct {
|
return ap.conn.WriteJSON(struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
}{Type: "auth_ok", Message: "invalid auth"})
|
}{Type: "auth_ok", Message: "invalid auth"})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ap *authPhase) handleMsg(r io.Reader) error {
|
func (ap *authPhase) handleMsg(ctx context.Context, r io.Reader) error {
|
||||||
var authMsg authMsg
|
var authMsg authMsg
|
||||||
err := json.NewDecoder(r).Decode(&authMsg)
|
err := json.NewDecoder(r).Decode(&authMsg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
refreshToken := ap.b.ValidateAccessToken(authMsg.AccessToken)
|
refreshToken := ap.b.ValidateAccessToken(authMsg.AccessToken)
|
||||||
if refreshToken != nil {
|
if refreshToken != nil {
|
||||||
ap.finishAuth(refreshToken)
|
ap.finishAuth(refreshToken)
|
||||||
|
@ -68,5 +72,10 @@ func (ap *authPhase) handleMsg(r io.Reader) error {
|
||||||
|
|
||||||
log.Error().Str("remote", ap.ec.Request().RemoteAddr).Msg("websocket auth failed")
|
log.Error().Str("remote", ap.ec.Request().RemoteAddr).Msg("websocket auth failed")
|
||||||
|
|
||||||
return ap.sendAuthInvalid()
|
err = ap.sendAuthInvalid()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return AuthInvalidErr
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue