WIP: websocket

This commit is contained in:
Daniel Ponte 2022-12-19 19:24:01 -05:00
parent c91cd6efca
commit 2de97a6936
16 changed files with 261 additions and 56 deletions

View file

@ -1,11 +1,14 @@
FE=pkg/frontend/frontend
VER=$(shell git describe --always --tags)
LDFLAGS=-ldflags='-X dynatron.me/x/blasphem/internal/common.Version=${VER}'
all: build
build:
go build -o blas ./cmd/blas/
go build ${LDFLAGS} -o blas ./cmd/blas/
serve:
go run ./cmd/blas/ serve ${BLAS_ARGS}
go run ${LDFLAGS} ./cmd/blas/ serve ${BLAS_ARGS}
# pkg/frontend/frontend/hass_frontend:
frontend:

View file

@ -6,6 +6,9 @@ import (
"github.com/spf13/cobra"
)
// this symbol overriden by linker args
var Version = "undefined"
const (
// AppName is the name of the application.
AppName = "blasphem"

View file

@ -25,7 +25,7 @@ var (
ErrUserAuthRemote = errors.New("user cannot authenticate remotely")
)
type Authenticator struct {
type authenticator struct {
sync.Mutex
store AuthStore
flows *AuthFlowManager
@ -33,12 +33,16 @@ type Authenticator struct {
providers map[string]provider.AuthProvider
}
type Authenticator interface {
ValidateAccessToken(token AccessToken) *RefreshToken
}
type AuthError struct {
Error string `json:"error"`
Description string `json:"error_description"`
}
func (a *Authenticator) InstallRoutes(e *echo.Echo) {
func (a *authenticator) installRoutes(e *echo.Echo) {
authG := e.Group("/auth")
authG.GET("/authorize", frontend.AliasHandler("authorize.html"))
authG.GET("/providers", a.ProvidersHandler)
@ -51,12 +55,15 @@ func (a *Authenticator) InstallRoutes(e *echo.Echo) {
loginFlow.DELETE("/:flow_id", a.LoginFlowDeleteHandler)
}
func (a *Authenticator) InitAuth(s storage.Store) error {
a.providers = make(map[string]provider.AuthProvider)
func New(e *echo.Echo, s storage.Store) (Authenticator, error) {
a := &authenticator{
providers: make(map[string]provider.AuthProvider),
}
for _, pI := range provider.Providers {
nProv, err := pI(s)
if err != nil {
return err
return nil, err
}
a.providers[nProv.ProviderType()] = nProv
@ -69,13 +76,15 @@ func (a *Authenticator) InitAuth(s storage.Store) error {
var err error
a.store, err = a.newAuthStore(s)
if err != nil {
return err
return nil, err
}
return nil
a.installRoutes(e)
return a, nil
}
func (a *Authenticator) Provider(name string) provider.AuthProvider {
func (a *authenticator) Provider(name string) provider.AuthProvider {
p, ok := a.providers[name]
if !ok {
return nil
@ -87,7 +96,7 @@ func (a *Authenticator) Provider(name string) provider.AuthProvider {
var HomeAssistant = "homeassistant"
// TODO: make this configurable
func (a *Authenticator) ProvidersHandler(c echo.Context) error {
func (a *authenticator) ProvidersHandler(c echo.Context) error {
providers := []provider.AuthProviderBase{
a.Provider(HomeAssistant).ProviderBase(),
}
@ -95,7 +104,7 @@ func (a *Authenticator) ProvidersHandler(c echo.Context) error {
return c.JSON(http.StatusOK, providers)
}
func (a *Authenticator) Check(f *LoginFlow, req *http.Request, rm map[string]interface{}) (user provider.ProviderUser, clientID string, err error) {
func (a *authenticator) Check(f *LoginFlow, req *http.Request, rm map[string]interface{}) (user provider.ProviderUser, clientID string, err error) {
cID, hasCID := rm["client_id"]
clientID, cidIsStr := cID.(string)
if !hasCID || !cidIsStr || clientID == "" || clientID != string(f.ClientID) {

View file

@ -62,7 +62,7 @@ func (afm *AuthFlowManager) NewLoginFlow(req *LoginFlowRequest, prov provider.Au
return lf
}
func (a *Authenticator) NewFlow(r *LoginFlowRequest) *flow.Result {
func (a *authenticator) NewFlow(r *LoginFlowRequest) *flow.Result {
var prov provider.AuthProvider
for _, h := range r.Handler {
@ -89,7 +89,7 @@ func (f *LoginFlow) redirect(c echo.Context) {
c.Request().Header.Set("Location", f.RedirectURI)
}
func (f *LoginFlow) progress(a *Authenticator, c echo.Context) error {
func (f *LoginFlow) progress(a *authenticator, c echo.Context) error {
switch f.Step() {
case flow.StepInit:
rm := make(map[string]interface{})
@ -136,7 +136,7 @@ func (f *LoginFlow) progress(a *Authenticator, c echo.Context) error {
}
}
func (a *Authenticator) LoginFlowDeleteHandler(c echo.Context) error {
func (a *authenticator) LoginFlowDeleteHandler(c echo.Context) error {
a.Lock()
defer a.Unlock()
@ -158,7 +158,7 @@ func setJSON(c echo.Context) {
}
}
func (a *Authenticator) BeginLoginFlowHandler(c echo.Context) error {
func (a *authenticator) BeginLoginFlowHandler(c echo.Context) error {
a.Lock()
defer a.Unlock()
@ -181,7 +181,7 @@ func (a *Authenticator) BeginLoginFlowHandler(c echo.Context) error {
return c.JSON(http.StatusOK, resp)
}
func (a *Authenticator) LoginFlowHandler(c echo.Context) error {
func (a *authenticator) LoginFlowHandler(c echo.Context) error {
a.Lock()
defer a.Unlock()

View file

@ -153,6 +153,8 @@ type RefreshToken struct {
LastUsedIP *string `json:"last_used_ip"`
CredentialID *CredID `json:"credential_id"`
Version *string `json:"version"`
User *User `json:"-"`
}
func (rt *RefreshToken) IsValid() bool {
@ -200,9 +202,9 @@ func WithCredential(c *Credentials) RefreshOption {
}
}
const DefaultAccessExpiration = "1800"
const DefaultAccessExpiration = "1800" // json 🤮
func (a *Authenticator) NewRefreshToken(user *User, opts ...RefreshOption) (*RefreshToken, error) {
func (a *authenticator) NewRefreshToken(user *User, opts ...RefreshOption) (*RefreshToken, error) {
e := func(es string, arg ...interface{}) (*RefreshToken, error) {
return nil, fmt.Errorf(es, arg...)
}
@ -216,6 +218,7 @@ func (a *Authenticator) NewRefreshToken(user *User, opts ...RefreshOption) (*Ref
JWTKey: generate.Hex(64),
CreatedAt: &now,
AccessTokenExpiration: DefaultAccessExpiration,
User: user,
}
for _, opt := range opts {
@ -272,7 +275,7 @@ func (r *RefreshToken) AccessToken(req *http.Request) (string, error) {
}).SignedString([]byte(r.JWTKey))
}
func (a *Authenticator) verifyAndGetCredential(tr *TokenRequest) *Credentials {
func (a *authenticator) verifyAndGetCredential(tr *TokenRequest) *Credentials {
cred, success := a.authCodes.get(tr)
if !success {
return nil
@ -283,7 +286,7 @@ func (a *Authenticator) verifyAndGetCredential(tr *TokenRequest) *Credentials {
const defaultExpiration = 15 * time.Minute
func (a *Authenticator) NewAuthCode(clientID ClientID, cred *Credentials) string {
func (a *authenticator) NewAuthCode(clientID ClientID, cred *Credentials) string {
return a.authCodes.put(clientID, cred)
}
@ -316,7 +319,7 @@ type TokenRequest struct {
const AuthFailed = "authentication failure"
func (a *Authenticator) TokenHandler(c echo.Context) error {
func (a *authenticator) TokenHandler(c echo.Context) error {
a.Lock()
defer a.Unlock()
@ -397,13 +400,7 @@ func (a *Authenticator) TokenHandler(c echo.Context) error {
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"})
}
user := a.store.User(rt.UserID)
if user == nil {
log.Error().Str("userID", string(rt.UserID)).Msg("no such user")
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"})
}
if err := user.allowedToAuth(c.Request()); err != nil {
if err := rt.User.allowedToAuth(c.Request()); err != nil {
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: err.Error()})
}
@ -425,3 +422,5 @@ func (a *Authenticator) TokenHandler(c echo.Context) error {
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"})
}
type AccessToken string

View file

@ -86,6 +86,7 @@ func (as *authStore) GetRefreshTokenByToken(token RefreshTokenToken) *RefreshTok
for _, rt := range u.RefreshTokens {
if subtle.ConstantTimeCompare([]byte(token), []byte(rt.Token)) == 1 {
found = rt
found.User = u
}
}
}
@ -106,7 +107,7 @@ func (as *authStore) newCredential(p provider.ProviderUser) *Credentials {
return c
}
func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error) {
func (a *authenticator) newAuthStore(s storage.Store) (as *authStore, err error) {
as = &authStore{
store: s,
}

View file

@ -53,7 +53,7 @@ func (u *User) allowedToAuth(r *http.Request) error {
return ErrUserAuthRemote
}
func (a *Authenticator) getOrCreateUser(c *Credentials) (*User, error) {
func (a *authenticator) getOrCreateUser(c *Credentials) (*User, error) {
u := a.store.User(c.UserID)
if u == nil {
return nil, ErrInvalidAuth

View file

@ -7,20 +7,44 @@ import (
"strings"
"dynatron.me/x/blasphem/internal/common"
"dynatron.me/x/blasphem/pkg/auth"
"dynatron.me/x/blasphem/pkg/bus"
"dynatron.me/x/blasphem/pkg/config"
"dynatron.me/x/blasphem/pkg/storage"
)
type Core interface {
auth.Authenticator
bus.Bus
storage.Store
config.Configured
Shutdowner
Versioner
}
type Shutdowner interface {
Shutdown(context.Context) error
}
type Versioner interface {
Version() string
}
type Blas struct {
*bus.Bus
bus.Bus
storage.Store
Config *config.Config
}
func (b *Blas) Shutdown(ctx context.Context) error {
b.Bus.Shutdown()
b.Store.Shutdown()
func (b *Blas) Version() string {
return common.Version
}
func (b *Blas) Conf() *config.Config { return b.Config }
func (b *Blas) ShutdownBlas(ctx context.Context) error {
b.Bus.ShutdownBus()
b.Store.ShutdownStore()
return ctx.Err()
}

View file

@ -12,22 +12,28 @@ type (
listeners []chan<- Event
Bus struct {
bus struct {
sync.RWMutex
subs map[string]listeners
}
Bus interface {
Sub(topic string, ch chan<- Event)
Unsub(topic string, ch chan<- Event)
Pub(topic string, data interface{})
ShutdownBus()
}
)
func New() *Bus {
bus := &Bus{
func New() Bus {
bus := &bus{
subs: make(map[string]listeners),
}
return bus
}
func (b *Bus) Sub(topic string, ch chan<- Event) {
func (b *bus) Sub(topic string, ch chan<- Event) {
b.Lock()
defer b.Unlock()
@ -38,7 +44,7 @@ func (b *Bus) Sub(topic string, ch chan<- Event) {
}
}
func (b *Bus) Unsub(topic string, ch chan<- Event) {
func (b *bus) Unsub(topic string, ch chan<- Event) {
b.Lock()
defer b.Unlock()
@ -51,7 +57,7 @@ func (b *Bus) Unsub(topic string, ch chan<- Event) {
}
}
func (b *Bus) Pub(topic string, data interface{}) {
func (b *bus) Pub(topic string, data interface{}) {
b.RLock()
defer b.RUnlock()
@ -65,7 +71,7 @@ func (b *Bus) Pub(topic string, data interface{}) {
}
}
func (b *Bus) Shutdown() {
func (b *bus) ShutdownBus() {
for _, v := range b.subs {
for _, c := range v {
close(c)

View file

@ -11,6 +11,10 @@ import (
"gopkg.in/yaml.v3"
)
type Configured interface {
Conf() *Config
}
type Config struct {
DataDir *string `yaml:"data_dir,omitempty"`
Server *server.Config `yaml:"server"`

View file

@ -29,8 +29,6 @@ type Server struct {
func (s *Server) installRoutes() {
s.GET("/*", frontend.FSHandler)
s.GET("/api/websocket", s.wsHandler)
s.Authenticator.InstallRoutes(s.Echo)
}
func New(cfg *config.Config) (s *Server, err error) {
@ -43,7 +41,8 @@ func New(cfg *config.Config) (s *Server, err error) {
Blas: b,
Echo: echo.New(),
}
err = s.InitAuth(b.Store)
s.Authenticator, err = auth.New(s.Echo, b.Store)
if err != nil {
return s, err
}
@ -80,7 +79,7 @@ func New(cfg *config.Config) (s *Server, err error) {
}
func (s *Server) Shutdown(ctx context.Context) error {
err := s.Blas.Shutdown(ctx)
err := s.Blas.ShutdownBlas(ctx)
if err != nil {
return err
}
@ -91,7 +90,7 @@ func (s *Server) Shutdown(ctx context.Context) error {
func (s *Server) Go() error {
s.wg.Add(1)
go func() {
log.Info().Str("bind", s.Config.Server.Bind).Msg("Server listening")
log.Info().Str("version", s.Version()).Str("bind", s.Config.Server.Bind).Msg("Server listening")
err := s.Start(s.Config.Server.Bind)
if err != nil && err != http.ErrServerClosed {
s.Logger.Fatal(err)

View file

@ -1,11 +1,11 @@
package server
import (
"errors"
"log"
"dynatron.me/x/blasphem/pkg/wsapi"
"github.com/gorilla/websocket"
"github.com/labstack/echo/v4"
"github.com/rs/zerolog/log"
)
var upgrader = websocket.Upgrader{
@ -14,7 +14,14 @@ var upgrader = websocket.Upgrader{
}
func (s *Server) wsHandler(c echo.Context) error {
log.Println("WebSocket")
//conn, err := upgrader.Upgrade(w, req, nil)
return errors.New("not handled")
conn, err := upgrader.Upgrade(c.Response(), c.Request(), nil)
if err != nil {
return err
}
defer conn.Close()
log.Debug().Str("remote", c.Request().RemoteAddr).Msg("WS")
return wsapi.New(s, c, conn).Handle()
}

View file

@ -129,7 +129,7 @@ func (s *fsStore) FlushAll() []error {
return errs
}
func (s *fsStore) Shutdown() {
func (s *fsStore) ShutdownStore() {
errs := s.FlushAll()
if errs != nil {
log.Error().Errs("errors", errs).Msg("errors persisting store")

View file

@ -50,6 +50,6 @@ type Store interface {
// Flush flushes a single key to backing.
Flush(key string) error
// Shutdown is called to quiesce and shutdown the store.
Shutdown()
// ShutdownStore is called to quiesce and shutdown the store.
ShutdownStore()
}

99
pkg/wsapi/api.go Normal file
View file

@ -0,0 +1,99 @@
package wsapi
import (
"encoding/json"
"dynatron.me/x/blasphem/pkg/auth"
"dynatron.me/x/blasphem/pkg/blas"
"github.com/gorilla/websocket"
"github.com/labstack/echo/v4"
"github.com/rs/zerolog/log"
)
type Type string
type MsgBase struct {
Type Type `json:"type"`
}
type (
wsSession struct {
*websocket.Conn
b blas.Core
ec echo.Context
h phaseHandler
user *auth.User
refreshToken *auth.RefreshToken
}
WS interface {
Handle() error
}
phaseHandler interface {
handleMsg(msg interface{}) error
msgSchema() interface{}
}
cmdHandler struct {
*wsSession
}
)
func New(s blas.Core, c echo.Context, conn *websocket.Conn) WS {
ws := &wsSession{
Conn: conn,
b: s,
ec: c,
}
return ws
}
func (ws *wsSession) Handle() error {
err := ws.sendAuthRequired()
if err != nil {
return err
}
for {
msg := ws.h.msgSchema()
err := ws.ReadJSON(msg)
if err != nil {
log.Error().Err(err).Str("remote", ws.ec.Request().RemoteAddr).Msg("websocket read fail")
}
err = ws.h.handleMsg(msg)
if err != nil {
return err
}
}
}
type cmdMsg struct {
}
type MsgType string
func (cm *cmdMsg) UnmarshalJSON(b []byte) error {
var t typeType struct {
Type MsgType `json:"type"`
}
err := json.Unmarshal(b, &t)
if err != nil {
return err
}
}
func (ws *cmdHandler) msgSchema() interface{} {
return &cmdMsg{}
}
func (ws *cmdHandler) handleMsg(msg interface{}) error {
return nil
}

51
pkg/wsapi/auth.go Normal file
View file

@ -0,0 +1,51 @@
package wsapi
import (
"dynatron.me/x/blasphem/pkg/auth"
"github.com/rs/zerolog/log"
)
type authPhase struct {
*wsSession
}
func (ws *wsSession) sendAuthRequired() error {
authReq := &struct{
MsgBase
Version string `json:"version"`
}{
MsgBase{"auth_required"},
ws.b.Version(),
}
return ws.WriteJSON(&authReq)
}
type authMsg struct {
MsgBase
AccessToken auth.AccessToken `json:"access_token"`
}
func (ap *authPhase) msgSchema() interface{} {
return &authMsg{}
}
func (ap *authPhase) finishAuth(rt *auth.RefreshToken) {
ap.user = rt.User
ap.refreshToken = rt
ap.h = &cmdHandler{ap.wsSession}
}
func (ap *authPhase) handleMsg(msg interface{}) error {
authMsg := msg.(*authMsg)
refreshToken := ap.b.ValidateAccessToken(authMsg.AccessToken)
if refreshToken != nil {
ap.finishAuth(refreshToken)
}
log.Error().Str("remote", ap.ec.Request().RemoteAddr).Msg("websocket auth failed")
return auth.ErrInvalidAuth
}