81 lines
1.5 KiB
Go
81 lines
1.5 KiB
Go
package wsapi
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"io"
|
|
|
|
"dynatron.me/x/blasphem/pkg/auth"
|
|
|
|
"github.com/rs/zerolog/log"
|
|
)
|
|
|
|
type authPhase struct {
|
|
*wsSession
|
|
}
|
|
|
|
func (ws *wsSession) sendAuthRequired() error {
|
|
authReq := &struct {
|
|
MsgBase
|
|
Version string `json:"ha_version"`
|
|
}{
|
|
MsgBase{Type: "auth_required"},
|
|
ws.b.Version(),
|
|
}
|
|
return ws.conn.WriteJSON(&authReq)
|
|
}
|
|
|
|
type authMsg struct {
|
|
MsgBase
|
|
AccessToken auth.AccessToken `json:"access_token"`
|
|
}
|
|
|
|
func (ap *authPhase) msgSchema() interface{} {
|
|
return &authMsg{}
|
|
}
|
|
|
|
func (ap *authPhase) finishAuth(rt *auth.RefreshToken) {
|
|
ap.user = rt.User
|
|
ap.refreshToken = rt
|
|
}
|
|
|
|
func (ap *authPhase) sendAuthOK() error {
|
|
return ap.conn.WriteJSON(struct {
|
|
Type string `json:"type"`
|
|
Version string `json:"ha_version"`
|
|
}{Type: "auth_ok", Version: ap.Blas().Version()})
|
|
}
|
|
|
|
func (ap *authPhase) sendAuthInvalid() error {
|
|
return ap.conn.WriteJSON(struct {
|
|
Type string `json:"type"`
|
|
Message string `json:"message"`
|
|
}{Type: "auth_ok", Message: "invalid auth"})
|
|
}
|
|
|
|
func (ap *authPhase) handleMsg(ctx context.Context, r io.Reader) error {
|
|
var authMsg authMsg
|
|
err := json.NewDecoder(r).Decode(&authMsg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
refreshToken := ap.b.ValidateAccessToken(authMsg.AccessToken)
|
|
if refreshToken != nil {
|
|
ap.finishAuth(refreshToken)
|
|
return ap.sendAuthOK()
|
|
}
|
|
|
|
log.Error().Str("remote", ap.ec.Request().RemoteAddr).Msg("websocket auth failed")
|
|
|
|
err = ap.sendAuthInvalid()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return AuthInvalidErr
|
|
}
|