diff --git a/pkg/auth/session.go b/pkg/auth/session.go index b2e37e3..bc925c1 100644 --- a/pkg/auth/session.go +++ b/pkg/auth/session.go @@ -276,24 +276,25 @@ func (r *RefreshToken) AccessToken(req *http.Request) (string, error) { } func (a *authenticator) ValidateAccessToken(token AccessToken) *RefreshToken { + var uvIssRT *RefreshToken claims := &jwt.StandardClaims{} - tok, err := jwt.ParseWithClaims(string(token), claims, func(jt *jwt.Token) (interface{}, error) { + _, err := jwt.ParseWithClaims(string(token), claims, func(jt *jwt.Token) (interface{}, error) { iss := jt.Claims.(*jwt.StandardClaims).Issuer - rt := a.store.GetRefreshToken(RefreshTokenID(iss)) - if rt == nil { + uvIssRT = a.store.GetRefreshToken(RefreshTokenID(iss)) + if uvIssRT == nil { return nil, fmt.Errorf("bad token") } - return rt.JWTKey, nil + return []byte(uvIssRT.JWTKey), nil }) if err != nil { + log.Error().Err(err).Msg("validateAccessToken") return nil } - iss := tok.Claims.(*jwt.StandardClaims).Issuer - return a.store.GetRefreshToken(RefreshTokenID(iss)) + return uvIssRT } func (a *authenticator) verifyAndGetCredential(tr *TokenRequest) *Credentials { diff --git a/pkg/auth/store.go b/pkg/auth/store.go index 9777e91..97ce28b 100644 --- a/pkg/auth/store.go +++ b/pkg/auth/store.go @@ -100,7 +100,7 @@ func (as *authStore) GetRefreshToken(tid RefreshTokenID) *RefreshToken { for _, u := range as.Users { for _, rt := range u.RefreshTokens { - if subtle.ConstantTimeCompare([]byte(tid), []byte(rt.ID)) == 1 { + if subtle.ConstantTimeCompare([]byte(tid), []byte(rt.ID.String())) == 1 { found = rt found.User = u } diff --git a/pkg/blas/blas.go b/pkg/blas/blas.go index 78b4823..297187d 100644 --- a/pkg/blas/blas.go +++ b/pkg/blas/blas.go @@ -13,6 +13,7 @@ import ( "dynatron.me/x/blasphem/pkg/components" "dynatron.me/x/blasphem/pkg/config" "dynatron.me/x/blasphem/pkg/storage" + "dynatron.me/x/blasphem/pkg/wsapi" "github.com/rs/zerolog/log" ) @@ -82,6 +83,7 @@ func New(cfg *config.Config) (b *Blas, err error) { Bus: bus.New(), Config: cfg, components: make(components.ComponentStore), + WebSocketManager: wsapi.NewManager(), } err = b.openStore() diff --git a/pkg/blas/core/core.go b/pkg/blas/core/core.go index fbafac7..416f7dd 100644 --- a/pkg/blas/core/core.go +++ b/pkg/blas/core/core.go @@ -36,7 +36,7 @@ type WebSocketSession interface { Blas() Blas } -type Handler func(wss WebSocketSession, msg interface{}) error +type Handler func(wss WebSocketSession, msgID int, msg interface{}) error type NewData func() interface{} type Shutdowner interface { diff --git a/pkg/wsapi/api.go b/pkg/wsapi/api.go index 5b6d4fd..61e5f75 100644 --- a/pkg/wsapi/api.go +++ b/pkg/wsapi/api.go @@ -15,6 +15,7 @@ import ( var ( NoSuchHandlerErr = errors.New("bad websocket command") + NoMessageIDErr = errors.New("no message ID") ) type Type string @@ -113,6 +114,30 @@ type cmdMsg struct { } type MsgType string +const ( + ResultMsgType MsgType = "result" +) + +type Error struct { + Code string `json:"code"` + Message string `json:"message"` +} + +type WSError struct { + ID *int `json:"id,omitempty"` + Type MsgType `json:"type"` + Success bool `json:"success"` + Error Error `json:"error"` +} + +func (ws *cmdHandler) writeError(id int, err Error) error { + return ws.WriteJSON(WSError{ + ID: &id, + Type: ResultMsgType, + Success: false, + Error: err, + }) +} func (ws *cmdHandler) handleMsg(r io.Reader) error { var msgMap map[string]interface{} @@ -126,11 +151,33 @@ func (ws *cmdHandler) handleMsg(r io.Reader) error { return NoSuchHandlerErr } - newData, hand, err := ws.b.WSCommandHandler(msgType) + idFl, ok := msgMap["id"].(float64) if !ok { + return ws.WriteJSON( + WSError{ + Type: ResultMsgType, + Success: false, + Error: Error{ + Code: "invalid_id", + Message: "command has no ID", + }, + }) + } + + id := int(idFl) + + newData, hand, err := ws.b.WSCommandHandler(msgType) + switch err { + case nil: + case NoSuchHandlerErr: + return ws.writeError(id, Error{ + Code: "invalid_type", + Message: "no such command", + }) + default: return err } nd := newData() - return hand(ws, nd) + return hand(ws, id, nd) } diff --git a/pkg/wsapi/auth.go b/pkg/wsapi/auth.go index 4254ecd..5ecec80 100644 --- a/pkg/wsapi/auth.go +++ b/pkg/wsapi/auth.go @@ -16,7 +16,7 @@ type authPhase struct { func (ws *wsSession) sendAuthRequired() error { authReq := &struct { MsgBase - Version string `json:"version"` + Version string `json:"ha_version"` }{ MsgBase{"auth_required"}, ws.b.Version(), @@ -37,13 +37,12 @@ func (ap *authPhase) finishAuth(rt *auth.RefreshToken) { ap.user = rt.User ap.refreshToken = rt ap.h = &cmdHandler{ap.wsSession} - ap.sendAuthOK() } func (ap *authPhase) sendAuthOK() error { return ap.WriteJSON(struct { Type string `json:"type"` - Version string `json:"version"` + Version string `json:"ha_version"` }{Type: "auth_ok", Version: ap.Blas().Version()}) }