Compare commits

..

19 commits

Author SHA1 Message Date
f272514c65 Improve flows in preparation for MFA 2022-12-27 15:09:20 -05:00
60301c9892 check rt 2022-12-21 19:34:26 -05:00
14401c831a gofmt 2022-12-21 13:22:26 -05:00
c424b75cf6 concurrent ws 2022-12-21 13:22:18 -05:00
8bf7379170 don't use map structure if we're already a map 2022-12-20 21:27:07 -05:00
2ddac95715 reorder 2022-12-20 21:24:14 -05:00
5570e53d99 handle close 2022-12-20 21:22:00 -05:00
6bd36a59c5 gofmt 2022-12-20 20:34:53 -05:00
6443443c6b allow unsplit command registrations 2022-12-20 20:34:45 -05:00
76f95a422b newdata can switch on command 2022-12-20 20:30:58 -05:00
caa371eff1 commands working ish 2022-12-20 20:11:11 -05:00
f119f7086f wip: auth still broken 2022-12-20 19:31:46 -05:00
95d72d2912 auth progress 2022-12-20 19:05:45 -05:00
a1005ce6bf works 2022-12-20 16:26:04 -05:00
1b355d3cbf also compiles 2022-12-20 13:54:49 -05:00
0da222577b it compiles 2022-12-20 13:31:31 -05:00
0378151b9f compiles 2022-12-20 13:16:30 -05:00
a468f0629b WIP: pre dep injection 2022-12-20 11:34:25 -05:00
2de97a6936 WIP: websocket 2022-12-19 19:24:01 -05:00
28 changed files with 816 additions and 194 deletions

View file

@ -1,11 +1,14 @@
FE=pkg/frontend/frontend FE=pkg/frontend/frontend
VER=$(shell git describe --always --tags)
LDFLAGS=-ldflags='-X dynatron.me/x/blasphem/internal/common.Version=${VER}'
all: build all: build
build: build:
go build -o blas ./cmd/blas/ go build ${LDFLAGS} -o blas ./cmd/blas/
serve: serve:
go run ./cmd/blas/ serve ${BLAS_ARGS} go run ${LDFLAGS} ./cmd/blas/ serve ${BLAS_ARGS}
# pkg/frontend/frontend/hass_frontend: # pkg/frontend/frontend/hass_frontend:
frontend: frontend:

View file

@ -7,6 +7,7 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"dynatron.me/x/blasphem/internal/common" "dynatron.me/x/blasphem/internal/common"
"dynatron.me/x/blasphem/pkg/blas"
"dynatron.me/x/blasphem/pkg/cmd/serve" "dynatron.me/x/blasphem/pkg/cmd/serve"
"dynatron.me/x/blasphem/pkg/config" "dynatron.me/x/blasphem/pkg/config"
@ -25,7 +26,12 @@ func main() {
log.Fatal().Err(err).Msg("Config read failed") log.Fatal().Err(err).Msg("Config read failed")
} }
rootCmd.AddCommand(serve.Command(config)) bl, err := blas.New(config)
if err != nil {
log.Fatal().Err(err).Msg("Core create failed")
}
rootCmd.AddCommand(serve.Command(bl))
err = rootCmd.Execute() err = rootCmd.Execute()
if err != nil { if err != nil {

1
go.mod
View file

@ -21,6 +21,7 @@ require (
github.com/labstack/gommon v0.3.1 // indirect github.com/labstack/gommon v0.3.1 // indirect
github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-colorable v0.1.12 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect github.com/mattn/go-isatty v0.0.14 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasttemplate v1.2.1 // indirect github.com/valyala/fasttemplate v1.2.1 // indirect

2
go.sum
View file

@ -36,6 +36,8 @@ github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2y
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=

View file

@ -6,6 +6,9 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
// this symbol overriden by linker args
var Version = "undefined"
const ( const (
// AppName is the name of the application. // AppName is the name of the application.
AppName = "blasphem" AppName = "blasphem"

View file

@ -9,7 +9,6 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"dynatron.me/x/blasphem/pkg/auth/provider" "dynatron.me/x/blasphem/pkg/auth/provider"
"dynatron.me/x/blasphem/pkg/frontend"
"dynatron.me/x/blasphem/pkg/storage" "dynatron.me/x/blasphem/pkg/storage"
// providers // providers
@ -25,7 +24,7 @@ var (
ErrUserAuthRemote = errors.New("user cannot authenticate remotely") ErrUserAuthRemote = errors.New("user cannot authenticate remotely")
) )
type Authenticator struct { type authenticator struct {
sync.Mutex sync.Mutex
store AuthStore store AuthStore
flows *AuthFlowManager flows *AuthFlowManager
@ -33,14 +32,18 @@ type Authenticator struct {
providers map[string]provider.AuthProvider providers map[string]provider.AuthProvider
} }
type Authenticator interface {
ValidateAccessToken(token AccessToken) *RefreshToken
InstallRoutes(e *echo.Echo)
}
type AuthError struct { type AuthError struct {
Error string `json:"error"` Error string `json:"error"`
Description string `json:"error_description"` Description string `json:"error_description"`
} }
func (a *Authenticator) InstallRoutes(e *echo.Echo) { func (a *authenticator) InstallRoutes(e *echo.Echo) {
authG := e.Group("/auth") authG := e.Group("/auth")
authG.GET("/authorize", frontend.AliasHandler("authorize.html"))
authG.GET("/providers", a.ProvidersHandler) authG.GET("/providers", a.ProvidersHandler)
authG.POST("/token", a.TokenHandler) authG.POST("/token", a.TokenHandler)
@ -51,12 +54,15 @@ func (a *Authenticator) InstallRoutes(e *echo.Echo) {
loginFlow.DELETE("/:flow_id", a.LoginFlowDeleteHandler) loginFlow.DELETE("/:flow_id", a.LoginFlowDeleteHandler)
} }
func (a *Authenticator) InitAuth(s storage.Store) error { func New(s storage.Store) (Authenticator, error) {
a.providers = make(map[string]provider.AuthProvider) a := &authenticator{
providers: make(map[string]provider.AuthProvider),
}
for _, pI := range provider.Providers { for _, pI := range provider.Providers {
nProv, err := pI(s) nProv, err := pI(s)
if err != nil { if err != nil {
return err return nil, err
} }
a.providers[nProv.ProviderType()] = nProv a.providers[nProv.ProviderType()] = nProv
@ -69,13 +75,13 @@ func (a *Authenticator) InitAuth(s storage.Store) error {
var err error var err error
a.store, err = a.newAuthStore(s) a.store, err = a.newAuthStore(s)
if err != nil { if err != nil {
return err return nil, err
} }
return nil return a, nil
} }
func (a *Authenticator) Provider(name string) provider.AuthProvider { func (a *authenticator) Provider(name string) provider.AuthProvider {
p, ok := a.providers[name] p, ok := a.providers[name]
if !ok { if !ok {
return nil return nil
@ -87,7 +93,7 @@ func (a *Authenticator) Provider(name string) provider.AuthProvider {
var HomeAssistant = "homeassistant" var HomeAssistant = "homeassistant"
// TODO: make this configurable // TODO: make this configurable
func (a *Authenticator) ProvidersHandler(c echo.Context) error { func (a *authenticator) ProvidersHandler(c echo.Context) error {
providers := []provider.AuthProviderBase{ providers := []provider.AuthProviderBase{
a.Provider(HomeAssistant).ProviderBase(), a.Provider(HomeAssistant).ProviderBase(),
} }
@ -95,7 +101,7 @@ func (a *Authenticator) ProvidersHandler(c echo.Context) error {
return c.JSON(http.StatusOK, providers) return c.JSON(http.StatusOK, providers)
} }
func (a *Authenticator) Check(f *LoginFlow, req *http.Request, rm map[string]interface{}) (user provider.ProviderUser, clientID string, err error) { func (a *authenticator) Check(f *LoginFlow, req *http.Request, rm map[string]interface{}) (user provider.ProviderUser, clientID string, err error) {
cID, hasCID := rm["client_id"] cID, hasCID := rm["client_id"]
clientID, cidIsStr := cID.(string) clientID, cidIsStr := cID.(string)
if !hasCID || !cidIsStr || clientID == "" || clientID != string(f.ClientID) { if !hasCID || !cidIsStr || clientID == "" || clientID != string(f.ClientID) {
@ -116,5 +122,3 @@ func (a *Authenticator) Check(f *LoginFlow, req *http.Request, rm map[string]int
return nil, clientID, ErrInvalidAuth return nil, clientID, ErrInvalidAuth
} }
//func (a *Authenticator) GetOrCreateCreds(

View file

@ -1,6 +1,7 @@
package auth package auth
import ( import (
"fmt"
"net/http" "net/http"
"strings" "strings"
@ -19,6 +20,7 @@ type AuthFlowManager struct {
type LoginFlow struct { type LoginFlow struct {
flow.FlowHandler flow.FlowHandler
prov provider.AuthProvider
ClientID common.ClientID ClientID common.ClientID
FlowContext FlowContext
} }
@ -52,9 +54,10 @@ func NewAuthFlowManager() *AuthFlowManager {
func (afm *AuthFlowManager) NewLoginFlow(req *LoginFlowRequest, prov provider.AuthProvider) *LoginFlow { func (afm *AuthFlowManager) NewLoginFlow(req *LoginFlowRequest, prov provider.AuthProvider) *LoginFlow {
lf := &LoginFlow{ lf := &LoginFlow{
FlowHandler: flow.NewFlowHandlerBase(prov, prov.ProviderType()), FlowHandler: flow.NewFlowHandlerBase(prov.ProviderType()),
ClientID: req.ClientID, ClientID: req.ClientID,
FlowContext: req.FlowContext(), FlowContext: req.FlowContext(),
prov: prov,
} }
afm.Register(lf) afm.Register(lf)
@ -62,7 +65,7 @@ func (afm *AuthFlowManager) NewLoginFlow(req *LoginFlowRequest, prov provider.Au
return lf return lf
} }
func (a *Authenticator) NewFlow(r *LoginFlowRequest) *flow.Result { func (a *authenticator) NewFlow(r *LoginFlowRequest) *flow.Result {
var prov provider.AuthProvider var prov provider.AuthProvider
for _, h := range r.Handler { for _, h := range r.Handler {
@ -80,16 +83,16 @@ func (a *Authenticator) NewFlow(r *LoginFlowRequest) *flow.Result {
return nil return nil
} }
flow := a.flows.NewLoginFlow(r, prov) lf := a.flows.NewLoginFlow(r, prov)
return flow.ShowForm(nil) return lf.ShowForm(lf.WithSchema(prov), lf.WithStep(flow.StepInit))
} }
func (f *LoginFlow) redirect(c echo.Context) { func (f *LoginFlow) redirect(c echo.Context) {
c.Request().Header.Set("Location", f.RedirectURI) c.Request().Header.Set("Location", f.RedirectURI)
} }
func (f *LoginFlow) progress(a *Authenticator, c echo.Context) error { func (f *LoginFlow) progress(a *authenticator, c echo.Context) error {
switch f.Step() { switch f.Step() {
case flow.StepInit: case flow.StepInit:
rm := make(map[string]interface{}) rm := make(map[string]interface{})
@ -99,15 +102,19 @@ func (f *LoginFlow) progress(a *Authenticator, c echo.Context) error {
return c.String(http.StatusBadRequest, err.Error()) return c.String(http.StatusBadRequest, err.Error())
} }
err = f.Schema.CheckRequired(rm) err = f.prov.FlowSchema().CheckRequired(rm)
if err != nil { if err != nil {
return c.JSON(http.StatusBadRequest, f.ShowForm([]string{err.Error()})) return c.JSON(http.StatusBadRequest, f.ShowForm(f.WithErrors([]string{err.Error()})))
} }
user, clientID, err := a.Check(f, c.Request(), rm) user, clientID, err := a.Check(f, c.Request(), rm)
switch err { switch err {
case nil: case nil:
creds := a.store.GetCredential(user) creds := a.store.GetCredential(user)
if creds == nil {
return fmt.Errorf("flow progress: no such credential for %v", user.UserData())
}
finishedFlow := flow.Result{} finishedFlow := flow.Result{}
a.flows.Remove(f) a.flows.Remove(f)
copier.Copy(&finishedFlow, f) copier.Copy(&finishedFlow, f)
@ -124,19 +131,19 @@ func (f *LoginFlow) progress(a *Authenticator, c echo.Context) error {
case ErrInvalidAuth: case ErrInvalidAuth:
fallthrough fallthrough
default: default:
return c.JSON(http.StatusOK, f.ShowForm(map[string]interface{}{ return c.JSON(http.StatusOK, f.ShowForm(f.WithErrors(map[string]interface{}{
"base": "invalid_auth", "base": "invalid_auth",
})) })))
} }
default: default:
return c.JSON(http.StatusOK, f.ShowForm(map[string]interface{}{ return c.JSON(http.StatusOK, f.ShowForm(f.WithErrors(map[string]interface{}{
"base": "unknown_flow_step", "base": "unknown_flow_step",
})) })))
} }
} }
func (a *Authenticator) LoginFlowDeleteHandler(c echo.Context) error { func (a *authenticator) LoginFlowDeleteHandler(c echo.Context) error {
a.Lock() a.Lock()
defer a.Unlock() defer a.Unlock()
@ -158,7 +165,7 @@ func setJSON(c echo.Context) {
} }
} }
func (a *Authenticator) BeginLoginFlowHandler(c echo.Context) error { func (a *authenticator) BeginLoginFlowHandler(c echo.Context) error {
a.Lock() a.Lock()
defer a.Unlock() defer a.Unlock()
@ -181,7 +188,7 @@ func (a *Authenticator) BeginLoginFlowHandler(c echo.Context) error {
return c.JSON(http.StatusOK, resp) return c.JSON(http.StatusOK, resp)
} }
func (a *Authenticator) LoginFlowHandler(c echo.Context) error { func (a *authenticator) LoginFlowHandler(c echo.Context) error {
a.Lock() a.Lock()
defer a.Unlock() defer a.Unlock()

View file

@ -162,18 +162,7 @@ func (hap *HomeAssistantProvider) NewCredData() interface{} {
} }
func (hap *HomeAssistantProvider) FlowSchema() flow.Schema { func (hap *HomeAssistantProvider) FlowSchema() flow.Schema {
return []flow.SchemaItem{ return flow.NewSchema(flow.RequiredString("username"), flow.RequiredString("password"))
{
Type: "string",
Name: "username",
Required: true,
},
{
Type: "string",
Name: "password",
Required: true,
},
}
} }
func init() { func init() {

View file

@ -79,18 +79,7 @@ func (hap *TrustedNetworksProvider) NewCredData() interface{} {
} }
func (hap *TrustedNetworksProvider) FlowSchema() flow.Schema { func (hap *TrustedNetworksProvider) FlowSchema() flow.Schema {
return []flow.SchemaItem{ return nil
{
Type: "string",
Name: "username",
Required: true,
},
{
Type: "string",
Name: "password",
Required: true,
},
}
} }
func init() { func init() {

View file

@ -153,6 +153,8 @@ type RefreshToken struct {
LastUsedIP *string `json:"last_used_ip"` LastUsedIP *string `json:"last_used_ip"`
CredentialID *CredID `json:"credential_id"` CredentialID *CredID `json:"credential_id"`
Version *string `json:"version"` Version *string `json:"version"`
User *User `json:"-"`
} }
func (rt *RefreshToken) IsValid() bool { func (rt *RefreshToken) IsValid() bool {
@ -200,9 +202,9 @@ func WithCredential(c *Credentials) RefreshOption {
} }
} }
const DefaultAccessExpiration = "1800" const DefaultAccessExpiration = "1800" // json 🤮
func (a *Authenticator) NewRefreshToken(user *User, opts ...RefreshOption) (*RefreshToken, error) { func (a *authenticator) NewRefreshToken(user *User, opts ...RefreshOption) (*RefreshToken, error) {
e := func(es string, arg ...interface{}) (*RefreshToken, error) { e := func(es string, arg ...interface{}) (*RefreshToken, error) {
return nil, fmt.Errorf(es, arg...) return nil, fmt.Errorf(es, arg...)
} }
@ -216,6 +218,7 @@ func (a *Authenticator) NewRefreshToken(user *User, opts ...RefreshOption) (*Ref
JWTKey: generate.Hex(64), JWTKey: generate.Hex(64),
CreatedAt: &now, CreatedAt: &now,
AccessTokenExpiration: DefaultAccessExpiration, AccessTokenExpiration: DefaultAccessExpiration,
User: user,
} }
for _, opt := range opts { for _, opt := range opts {
@ -272,7 +275,28 @@ func (r *RefreshToken) AccessToken(req *http.Request) (string, error) {
}).SignedString([]byte(r.JWTKey)) }).SignedString([]byte(r.JWTKey))
} }
func (a *Authenticator) verifyAndGetCredential(tr *TokenRequest) *Credentials { func (a *authenticator) ValidateAccessToken(token AccessToken) *RefreshToken {
var unverifiedIssRT *RefreshToken
claims := &jwt.StandardClaims{}
_, err := jwt.ParseWithClaims(string(token), claims, func(jt *jwt.Token) (interface{}, error) {
iss := jt.Claims.(*jwt.StandardClaims).Issuer
unverifiedIssRT = a.store.GetRefreshToken(RefreshTokenID(iss))
if unverifiedIssRT == nil {
return nil, fmt.Errorf("bad token")
}
return []byte(unverifiedIssRT.JWTKey), nil
})
if err != nil {
log.Error().Err(err).Msg("validateAccessToken")
return nil
}
return unverifiedIssRT
}
func (a *authenticator) verifyAndGetCredential(tr *TokenRequest) *Credentials {
cred, success := a.authCodes.get(tr) cred, success := a.authCodes.get(tr)
if !success { if !success {
return nil return nil
@ -283,7 +307,7 @@ func (a *Authenticator) verifyAndGetCredential(tr *TokenRequest) *Credentials {
const defaultExpiration = 15 * time.Minute const defaultExpiration = 15 * time.Minute
func (a *Authenticator) NewAuthCode(clientID ClientID, cred *Credentials) string { func (a *authenticator) NewAuthCode(clientID ClientID, cred *Credentials) string {
return a.authCodes.put(clientID, cred) return a.authCodes.put(clientID, cred)
} }
@ -316,7 +340,7 @@ type TokenRequest struct {
const AuthFailed = "authentication failure" const AuthFailed = "authentication failure"
func (a *Authenticator) TokenHandler(c echo.Context) error { func (a *authenticator) TokenHandler(c echo.Context) error {
a.Lock() a.Lock()
defer a.Unlock() defer a.Unlock()
@ -397,13 +421,7 @@ func (a *Authenticator) TokenHandler(c echo.Context) error {
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"}) return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"})
} }
user := a.store.User(rt.UserID) if err := rt.User.allowedToAuth(c.Request()); err != nil {
if user == nil {
log.Error().Str("userID", string(rt.UserID)).Msg("no such user")
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"})
}
if err := user.allowedToAuth(c.Request()); err != nil {
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: err.Error()}) return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: err.Error()})
} }
@ -425,3 +443,5 @@ func (a *Authenticator) TokenHandler(c echo.Context) error {
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"}) return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"})
} }
type AccessToken string

View file

@ -21,6 +21,7 @@ type AuthStore interface {
GetCredential(provider.ProviderUser) *Credentials GetCredential(provider.ProviderUser) *Credentials
PutRefreshToken(*RefreshToken) (*RefreshToken, error) PutRefreshToken(*RefreshToken) (*RefreshToken, error)
GetRefreshTokenByToken(token RefreshTokenToken) *RefreshToken GetRefreshTokenByToken(token RefreshTokenToken) *RefreshToken
GetRefreshToken(RefreshTokenID) *RefreshToken
} }
type authStore struct { type authStore struct {
@ -86,6 +87,7 @@ func (as *authStore) GetRefreshTokenByToken(token RefreshTokenToken) *RefreshTok
for _, rt := range u.RefreshTokens { for _, rt := range u.RefreshTokens {
if subtle.ConstantTimeCompare([]byte(token), []byte(rt.Token)) == 1 { if subtle.ConstantTimeCompare([]byte(token), []byte(rt.Token)) == 1 {
found = rt found = rt
found.User = u
} }
} }
} }
@ -93,6 +95,22 @@ func (as *authStore) GetRefreshTokenByToken(token RefreshTokenToken) *RefreshTok
return found return found
} }
func (as *authStore) GetRefreshToken(tid RefreshTokenID) *RefreshToken {
var found *RefreshToken
for _, u := range as.Users {
for _, rt := range u.RefreshTokens {
if subtle.ConstantTimeCompare([]byte(tid), []byte(rt.ID.String())) == 1 {
found = rt
found.User = u
}
}
}
return found
}
func (as *authStore) newCredential(p provider.ProviderUser) *Credentials { func (as *authStore) newCredential(p provider.ProviderUser) *Credentials {
// XXX: probably broken // XXX: probably broken
prov := p.Provider() prov := p.Provider()
@ -106,7 +124,7 @@ func (as *authStore) newCredential(p provider.ProviderUser) *Credentials {
return c return c
} }
func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error) { func (a *authenticator) newAuthStore(s storage.Store) (as *authStore, err error) {
as = &authStore{ as = &authStore{
store: s, store: s,
} }

View file

@ -53,7 +53,7 @@ func (u *User) allowedToAuth(r *http.Request) error {
return ErrUserAuthRemote return ErrUserAuthRemote
} }
func (a *Authenticator) getOrCreateUser(c *Credentials) (*User, error) { func (a *authenticator) getOrCreateUser(c *Credentials) (*User, error) {
u := a.store.User(c.UserID) u := a.store.User(c.UserID)
if u == nil { if u == nil {
return nil, ErrInvalidAuth return nil, ErrInvalidAuth

View file

@ -7,20 +7,36 @@ import (
"strings" "strings"
"dynatron.me/x/blasphem/internal/common" "dynatron.me/x/blasphem/internal/common"
"dynatron.me/x/blasphem/pkg/auth"
"dynatron.me/x/blasphem/pkg/blas/core"
"dynatron.me/x/blasphem/pkg/bus" "dynatron.me/x/blasphem/pkg/bus"
"dynatron.me/x/blasphem/pkg/components"
"dynatron.me/x/blasphem/pkg/config" "dynatron.me/x/blasphem/pkg/config"
"dynatron.me/x/blasphem/pkg/storage" "dynatron.me/x/blasphem/pkg/storage"
"dynatron.me/x/blasphem/pkg/wsapi"
"github.com/rs/zerolog/log"
) )
type Blas struct { type Blas struct {
*bus.Bus bus.Bus
storage.Store storage.Store
auth.Authenticator
Config *config.Config Config *config.Config
core.WebSocketManager
components components.ComponentStore
} }
func (b *Blas) Shutdown(ctx context.Context) error { func (b *Blas) Version() string {
b.Bus.Shutdown() return common.Version
b.Store.Shutdown() }
func (b *Blas) Conf() *config.Config { return b.Config }
func (b *Blas) ShutdownBlas(ctx context.Context) error {
b.Bus.ShutdownBus()
b.Store.ShutdownStore()
return ctx.Err() return ctx.Err()
} }
@ -51,12 +67,42 @@ func (b *Blas) openStore() error {
return err return err
} }
func New(cfg *config.Config) (*Blas, error) { func (b *Blas) Component(k components.ComponentKey) components.Component {
b := &Blas{ c, ok := b.components[k]
Bus: bus.New(), if !ok {
Config: cfg, return nil
}
return c
}
func (b *Blas) Components() components.ComponentStore { return b.components }
func New(cfg *config.Config) (b *Blas, err error) {
b = &Blas{
Bus: bus.New(),
Config: cfg,
components: make(components.ComponentStore),
WebSocketManager: wsapi.NewManager(),
}
err = b.openStore()
if err != nil {
return nil, err
}
b.Authenticator, err = auth.New(b.Store)
for k, v := range Registry {
log.Info().Msgf("Setting up component %s", k)
c, err := v(b)
if err != nil {
log.Error().Err(err).Msgf("Error setting up component %s", k)
continue
}
b.components[k] = c
} }
err := b.openStore()
return b, err return b, err
} }

21
pkg/blas/components.go Normal file
View file

@ -0,0 +1,21 @@
package blas
import (
"fmt"
"dynatron.me/x/blasphem/pkg/blas/core"
"dynatron.me/x/blasphem/pkg/components"
)
type Setup func(core.Blas) (components.Component, error)
var Registry = make(map[components.ComponentKey]Setup)
func Register(key components.ComponentKey, c Setup) {
_, already := Registry[key]
if already {
panic(fmt.Sprintf("component %s already exists", key))
}
Registry[key] = c
}

51
pkg/blas/core/core.go Normal file
View file

@ -0,0 +1,51 @@
package core
import (
"context"
"dynatron.me/x/blasphem/pkg/auth"
"dynatron.me/x/blasphem/pkg/bus"
"dynatron.me/x/blasphem/pkg/components"
"dynatron.me/x/blasphem/pkg/config"
"dynatron.me/x/blasphem/pkg/storage"
"github.com/gorilla/websocket"
)
type Blas interface {
auth.Authenticator
bus.Bus
storage.Store
config.Configured
components.Componenter
WebSocketManager
Shutdowner
Versioner
}
type WebSocketManager interface {
// Register registers a websocket command.
// cmd is the first part, before first slash
// dataNew is a function to create a new message datatype
RegisterWSCommand(cmd string, hnd Handler, dataNew NewData)
WSCommandHandler(cmd string, splitCmd []string) (NewData, Handler, error)
}
type WebSocketSession interface {
Conn() *websocket.Conn
Go(context.Context) error
Blas() Blas
}
type Handler func(ctx context.Context, wss WebSocketSession, msgID int, cmd []string, msg interface{}) error
type NewData func(cmd []string) interface{}
type Shutdowner interface {
ShutdownBlas(context.Context) error
}
type Versioner interface {
Version() string
}

View file

@ -12,22 +12,28 @@ type (
listeners []chan<- Event listeners []chan<- Event
Bus struct { bus struct {
sync.RWMutex sync.RWMutex
subs map[string]listeners subs map[string]listeners
} }
Bus interface {
Sub(topic string, ch chan<- Event)
Unsub(topic string, ch chan<- Event)
Pub(topic string, data interface{})
ShutdownBus()
}
) )
func New() *Bus { func New() Bus {
bus := &Bus{ bus := &bus{
subs: make(map[string]listeners), subs: make(map[string]listeners),
} }
return bus return bus
} }
func (b *Bus) Sub(topic string, ch chan<- Event) { func (b *bus) Sub(topic string, ch chan<- Event) {
b.Lock() b.Lock()
defer b.Unlock() defer b.Unlock()
@ -38,7 +44,7 @@ func (b *Bus) Sub(topic string, ch chan<- Event) {
} }
} }
func (b *Bus) Unsub(topic string, ch chan<- Event) { func (b *bus) Unsub(topic string, ch chan<- Event) {
b.Lock() b.Lock()
defer b.Unlock() defer b.Unlock()
@ -51,7 +57,7 @@ func (b *Bus) Unsub(topic string, ch chan<- Event) {
} }
} }
func (b *Bus) Pub(topic string, data interface{}) { func (b *bus) Pub(topic string, data interface{}) {
b.RLock() b.RLock()
defer b.RUnlock() defer b.RUnlock()
@ -65,7 +71,7 @@ func (b *Bus) Pub(topic string, data interface{}) {
} }
} }
func (b *Bus) Shutdown() { func (b *bus) ShutdownBus() {
for _, v := range b.subs { for _, v := range b.subs {
for _, c := range v { for _, c := range v {
close(c) close(c)

View file

@ -2,18 +2,18 @@ package serve
import ( import (
"dynatron.me/x/blasphem/internal/common" "dynatron.me/x/blasphem/internal/common"
"dynatron.me/x/blasphem/pkg/config" blas "dynatron.me/x/blasphem/pkg/blas/core"
"dynatron.me/x/blasphem/pkg/server" "dynatron.me/x/blasphem/pkg/server"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
type ServeOptions struct { type ServeOptions struct {
cfg *config.Config core blas.Blas
} }
func Command(cfg *config.Config) *cobra.Command { func Command(core blas.Blas) *cobra.Command {
opts := makeOptions(cfg) opts := makeOptions(core)
serveCmd := &cobra.Command{ serveCmd := &cobra.Command{
Use: "serve", Use: "serve",
Short: "starts the " + common.AppName + " server", Short: "starts the " + common.AppName + " server",
@ -23,9 +23,9 @@ func Command(cfg *config.Config) *cobra.Command {
return serveCmd return serveCmd
} }
func makeOptions(cfg *config.Config) *ServeOptions { func makeOptions(core blas.Blas) *ServeOptions {
return &ServeOptions{ return &ServeOptions{
cfg: cfg, core: core,
} }
} }
@ -34,7 +34,7 @@ func (o *ServeOptions) Options(_ *cobra.Command, args []string) error {
} }
func (o *ServeOptions) Execute() error { func (o *ServeOptions) Execute() error {
server, err := server.New(o.cfg) server, err := server.New(o.core)
if err != nil { if err != nil {
return err return err
} }

View file

@ -0,0 +1,17 @@
package components
import ()
type Componenter interface {
Component(ComponentKey) Component
Components() ComponentStore
}
type (
ComponentStore map[ComponentKey]Component
ComponentKey string
Component interface {
Shutdown()
}
)

View file

@ -11,6 +11,10 @@ import (
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
type Configured interface {
Conf() *Config
}
type Config struct { type Config struct {
DataDir *string `yaml:"data_dir,omitempty"` DataDir *string `yaml:"data_dir,omitempty"`
Server *server.Config `yaml:"server"` Server *server.Config `yaml:"server"`

View file

@ -8,22 +8,19 @@ import (
"dynatron.me/x/blasphem/internal/generate" "dynatron.me/x/blasphem/internal/generate"
) )
type ( type ResultType string
ResultType string type FlowID string
FlowID string type Step string
Step string type HandlerKey string
HandlerKey string type Errors interface{}
Errors interface{}
Context interface{} type FlowStore map[FlowID]Handler
FlowStore map[FlowID]Handler type FlowManager struct {
FlowManager struct {
flows FlowStore flows FlowStore
} }
Result struct { type Result struct {
Type ResultType `json:"type"` Type ResultType `json:"type"`
ID FlowID `json:"flow_id"` ID FlowID `json:"flow_id"`
Handler []*HandlerKey `json:"handler"` Handler []*HandlerKey `json:"handler"`
@ -43,36 +40,21 @@ type (
LastStep *string `json:"last_step"` LastStep *string `json:"last_step"`
Options map[string]interface{} `json:"options,omitempty"` Options map[string]interface{} `json:"options,omitempty"`
Version *int `json:"version,omitempty"` Version *int `json:"version,omitempty"`
} }
SchemaItem struct { type Handler interface {
Type string `json:"type"`
Name string `json:"name"`
Required bool `json:"required"`
}
Schema []SchemaItem
)
type (
Schemer interface {
FlowSchema() Schema
}
Handler interface {
BaseHandler() FlowHandler BaseHandler() FlowHandler
FlowID() FlowID FlowID() FlowID
flowCtime() time.Time flowCtime() time.Time
} }
)
const ( const (
StepInit Step = "init" StepInit Step = "init"
) )
func (fs *Schema) CheckRequired(rm map[string]interface{}) error { func (fs Schema) CheckRequired(rm map[string]interface{}) error {
for _, si := range *fs { for _, si := range fs {
if si.Required { if si.Required {
if _, ok := rm[si.Name]; !ok { if _, ok := rm[si.Name]; !ok {
return fmt.Errorf("missing required param %s", si.Name) return fmt.Errorf("missing required param %s", si.Name)
@ -94,8 +76,6 @@ func stepPtr(s Step) *Step { return &s }
type FlowHandler struct { type FlowHandler struct {
ID FlowID // ID is the FlowID ID FlowID // ID is the FlowID
Handler HandlerKey // Handler key Handler HandlerKey // Handler key
Context Context // flow Context
Schema Schema
// curStep is the current step set by the flow manager // curStep is the current step set by the flow manager
curStep Step curStep Step
@ -113,11 +93,10 @@ func (f *FlowHandler) FlowID() FlowID {
func (f *FlowHandler) flowCtime() time.Time { return f.ctime } func (f *FlowHandler) flowCtime() time.Time { return f.ctime }
func NewFlowHandlerBase(sch Schemer, hand string) FlowHandler { func NewFlowHandlerBase(hand string) FlowHandler {
return FlowHandler{ return FlowHandler{
ID: FlowID(generate.UUID()), ID: FlowID(generate.UUID()),
Handler: HandlerKey(hand), Handler: HandlerKey(hand),
Schema: sch.FlowSchema(),
curStep: StepInit, curStep: StepInit,
ctime: time.Now(), ctime: time.Now(),
@ -140,16 +119,39 @@ func resultErrs(e Errors) Errors {
return e return e
} }
func (fm *FlowHandler) ShowForm(errs Errors) *Result { type FormOption func(*Result)
func (*FlowHandler) WithErrors(e Errors) FormOption {
return func(r *Result) {
r.Errors = e
}
}
func (*FlowHandler) WithStep(s Step) FormOption {
return func(r *Result) {
r.StepID = stepPtr(s)
}
}
func (*FlowHandler) WithSchema(sch Schemer) FormOption {
return func(r *Result) {
r.Schema = sch.FlowSchema()
}
}
func (fm *FlowHandler) ShowForm(opts ...FormOption) *Result {
res := &Result{ res := &Result{
Type: TypeForm, Type: TypeForm,
ID: fm.ID, ID: fm.ID,
StepID: stepPtr(fm.curStep), StepID: stepPtr(fm.curStep),
Schema: fm.Schema,
Handler: fm.Handlers(), Handler: fm.Handlers(),
Errors: resultErrs(errs),
} }
for _, opt := range opts {
opt(res)
}
res.Errors = resultErrs(res.Errors)
return res return res
} }

40
pkg/flow/schema.go Normal file
View file

@ -0,0 +1,40 @@
package flow
type Type string
const (
TypeString Type = "string"
)
func (t Type) IsValid() bool {
switch t {
case TypeString:
return true
}
return false
}
type SchemaItem struct {
Type Type `json:"type"`
Name string `json:"name"`
Required bool `json:"required"`
}
type Schema []SchemaItem
type Schemer interface {
FlowSchema() Schema
}
func NewSchema(items ...SchemaItem) Schema {
return items
}
func RequiredString(name string) SchemaItem {
return SchemaItem{
Type: TypeString,
Name: name,
Required: true,
}
}

View file

@ -1,24 +1,42 @@
package frontend package frontend
import ( import (
"context"
"embed" "embed"
"io" "io"
"io/fs" "io/fs"
"net/http" "net/http"
"sync"
"dynatron.me/x/blasphem/pkg/blas"
"dynatron.me/x/blasphem/pkg/blas/core"
"dynatron.me/x/blasphem/pkg/components"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
const FrontendKey = "frontend"
//go:embed frontend/hass_frontend //go:embed frontend/hass_frontend
var root embed.FS var root embed.FS
var RootFS fs.FS type Frontend struct {
fsHandler echo.HandlerFunc
rootFS fs.FS
var FSHandler echo.HandlerFunc routeInstall sync.Once
}
func AliasHandler(toFile string) echo.HandlerFunc { func (fe *Frontend) InstallRoutes(e *echo.Echo) {
fe.routeInstall.Do(func() {
e.GET("/*", fe.fsHandler)
e.GET("/auth/authorize", fe.AliasHandler("authorize.html"))
})
}
func (fe *Frontend) AliasHandler(toFile string) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
file, err := RootFS.Open(toFile) file, err := fe.rootFS.Open(toFile)
if err != nil { if err != nil {
return err return err
} }
@ -33,13 +51,32 @@ func AliasHandler(toFile string) echo.HandlerFunc {
} }
} }
func init() { func (*Frontend) Shutdown() {}
func newData(_ []string) interface{} {
return map[string]interface{}{}
}
func wsHand(ctx context.Context, wss core.WebSocketSession, msgID int, cmd []string, msg interface{}) error {
return nil
}
func Setup(b core.Blas) (components.Component, error) {
fe := &Frontend{}
var err error var err error
RootFS, err = fs.Sub(root, "frontend/hass_frontend") fe.rootFS, err = fs.Sub(root, "frontend/hass_frontend")
if err != nil { if err != nil {
panic(err) return nil, err
} }
FSHandler = echo.StaticDirectoryHandler(RootFS, false) fe.fsHandler = echo.StaticDirectoryHandler(fe.rootFS, false)
b.RegisterWSCommand("frontend", wsHand, newData)
return fe, nil
}
func init() {
blas.Register(FrontendKey, Setup)
} }

View file

@ -12,40 +12,39 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/ziflex/lecho/v3" "github.com/ziflex/lecho/v3"
"dynatron.me/x/blasphem/pkg/auth"
"dynatron.me/x/blasphem/pkg/blas" "dynatron.me/x/blasphem/pkg/blas"
"dynatron.me/x/blasphem/pkg/config" "dynatron.me/x/blasphem/pkg/blas/core"
"dynatron.me/x/blasphem/pkg/frontend" "dynatron.me/x/blasphem/pkg/frontend"
conf "dynatron.me/x/blasphem/pkg/server/config" conf "dynatron.me/x/blasphem/pkg/server/config"
) )
type Server struct { type Server struct {
*blas.Blas core.Blas
*echo.Echo *echo.Echo
auth.Authenticator
wg sync.WaitGroup wg sync.WaitGroup
} }
func (s *Server) installRoutes() { type RouteHaver interface {
s.GET("/*", frontend.FSHandler) InstallRoutes(e *echo.Echo)
s.GET("/api/websocket", s.wsHandler)
s.Authenticator.InstallRoutes(s.Echo)
} }
func New(cfg *config.Config) (s *Server, err error) { func (s *Server) installRoutes() {
b, err := blas.New(cfg) s.GET("/api/websocket", s.wsHandler)
if err != nil {
return nil, err
}
s = &Server{ s.Component(frontend.FrontendKey).(RouteHaver).InstallRoutes(s.Echo)
Blas: b, s.Blas.(*blas.Blas).Authenticator.InstallRoutes(s.Echo)
Echo: echo.New(),
for _, c := range s.Components() {
if rh, ok := c.(RouteHaver); ok {
rh.InstallRoutes(s.Echo)
} }
err = s.InitAuth(b.Store) }
if err != nil { }
return s, err
func New(core core.Blas) (s *Server, err error) {
s = &Server{
Blas: core,
Echo: echo.New(),
} }
s.Echo.Debug = true s.Echo.Debug = true
@ -53,6 +52,8 @@ func New(cfg *config.Config) (s *Server, err error) {
logger := lecho.From(log.Logger) logger := lecho.From(log.Logger)
s.Echo.Logger = logger s.Echo.Logger = logger
cfg := s.Conf()
if cfg.Server.LogRequestErrors { if cfg.Server.LogRequestErrors {
s.Echo.Use(lecho.Middleware(lecho.Config{ s.Echo.Use(lecho.Middleware(lecho.Config{
Logger: logger, Logger: logger,
@ -80,7 +81,7 @@ func New(cfg *config.Config) (s *Server, err error) {
} }
func (s *Server) Shutdown(ctx context.Context) error { func (s *Server) Shutdown(ctx context.Context) error {
err := s.Blas.Shutdown(ctx) err := s.ShutdownBlas(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -91,8 +92,8 @@ func (s *Server) Shutdown(ctx context.Context) error {
func (s *Server) Go() error { func (s *Server) Go() error {
s.wg.Add(1) s.wg.Add(1)
go func() { go func() {
log.Info().Str("bind", s.Config.Server.Bind).Msg("Server listening") log.Info().Str("version", s.Version()).Str("bind", s.Conf().Server.Bind).Msg("Server listening")
err := s.Start(s.Config.Server.Bind) err := s.Start(s.Conf().Server.Bind)
if err != nil && err != http.ErrServerClosed { if err != nil && err != http.ErrServerClosed {
s.Logger.Fatal(err) s.Logger.Fatal(err)
} }

View file

@ -1,11 +1,12 @@
package server package server
import ( import (
"errors" "context"
"log" "dynatron.me/x/blasphem/pkg/wsapi"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/rs/zerolog/log"
) )
var upgrader = websocket.Upgrader{ var upgrader = websocket.Upgrader{
@ -14,7 +15,16 @@ var upgrader = websocket.Upgrader{
} }
func (s *Server) wsHandler(c echo.Context) error { func (s *Server) wsHandler(c echo.Context) error {
log.Println("WebSocket") conn, err := upgrader.Upgrade(c.Response(), c.Request(), nil)
//conn, err := upgrader.Upgrade(w, req, nil) if err != nil {
return errors.New("not handled") return err
}
defer conn.Close()
_ = log.Debug
ctx, cancel := context.WithCancel(c.Request().Context())
defer cancel()
return wsapi.NewSession(s, c, conn).Go(ctx)
} }

View file

@ -129,7 +129,7 @@ func (s *fsStore) FlushAll() []error {
return errs return errs
} }
func (s *fsStore) Shutdown() { func (s *fsStore) ShutdownStore() {
errs := s.FlushAll() errs := s.FlushAll()
if errs != nil { if errs != nil {
log.Error().Errs("errors", errs).Msg("errors persisting store") log.Error().Errs("errors", errs).Msg("errors persisting store")

View file

@ -50,6 +50,6 @@ type Store interface {
// Flush flushes a single key to backing. // Flush flushes a single key to backing.
Flush(key string) error Flush(key string) error
// Shutdown is called to quiesce and shutdown the store. // ShutdownStore is called to quiesce and shutdown the store.
Shutdown() ShutdownStore()
} }

264
pkg/wsapi/api.go Normal file
View file

@ -0,0 +1,264 @@
package wsapi
import (
"context"
"encoding/json"
"errors"
"strings"
"dynatron.me/x/blasphem/pkg/auth"
"dynatron.me/x/blasphem/pkg/blas/core"
"github.com/gorilla/websocket"
"github.com/labstack/echo/v4"
"github.com/mitchellh/mapstructure"
"github.com/rs/zerolog/log"
)
var (
NoSuchHandlerErr = errors.New("bad websocket command")
NoMessageIDErr = errors.New("no message ID")
AuthInvalidErr = errors.New("invalid auth")
)
type Type string
type MsgBase struct {
ID *int `json:"id,omitempty"`
Type Type `json:"type"`
}
type (
wsSession struct {
conn *websocket.Conn
b core.Blas
ec echo.Context
write chan<- interface{}
user *auth.User
refreshToken *auth.RefreshToken
}
wsEntry struct {
dataNew core.NewData
hnd core.Handler
}
wsRegistry map[string]wsEntry
wsManager struct {
r wsRegistry
}
)
func (wsm *wsManager) RegisterWSCommand(cmd string, hnd core.Handler, dataNew core.NewData) {
wsm.r[cmd] = wsEntry{
dataNew: dataNew,
hnd: hnd,
}
}
func (wsm *wsManager) WSCommandHandler(cmd string, cmdSplit []string) (core.NewData, core.Handler, error) {
if wse, ok := wsm.r[cmd]; ok {
return wse.dataNew, wse.hnd, nil
}
if wse, ok := wsm.r[cmdSplit[0]]; ok {
return wse.dataNew, wse.hnd, nil
}
return nil, nil, NoSuchHandlerErr
}
func NewManager() core.WebSocketManager {
return &wsManager{
r: make(wsRegistry),
}
}
func NewSession(s core.Blas, c echo.Context, conn *websocket.Conn) core.WebSocketSession {
ws := &wsSession{
conn: conn,
b: s,
ec: c,
}
return ws
}
func (ws *wsSession) Conn() *websocket.Conn {
return ws.conn
}
func (ws *wsSession) Blas() core.Blas { return ws.b }
func (ws *wsSession) Go(ctx context.Context) error {
authP := &authPhase{ws}
err := ws.sendAuthRequired()
if err != nil {
return err
}
_, rdr, err := ws.conn.NextReader()
if err != nil {
return err
}
err = authP.handleMsg(ctx, rdr)
if err != nil || ws.refreshToken == nil {
return err
}
// command phase
msgChan := make(chan map[string]interface{})
write := make(chan interface{})
ws.write = write
defer close(write)
cCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func(ctx context.Context, ch chan<- map[string]interface{}) {
defer close(ch)
for {
if err := ctx.Err(); err != nil {
return
}
_, rdr, err := ws.conn.NextReader()
switch {
case err == nil:
case websocket.IsCloseError(err, websocket.CloseGoingAway):
return
case err != nil:
log.Error().Err(err).Str("remote", ws.ec.Request().RemoteAddr).Msg("websocket read fail")
return
}
var msgMap map[string]interface{}
err = json.NewDecoder(rdr).Decode(&msgMap)
if err != nil {
ws.writeError(-1, Error{Code: "invalid_format", Message: err.Error()})
}
ch <- msgMap
}
}(cCtx, msgChan)
for {
select {
case msg, ok := <-msgChan:
if !ok {
return nil
}
err = ws.handleMsg(ctx, msg)
if err != nil {
log.Error().Err(err).Msg("handleMsg")
}
case <-ctx.Done():
close(msgChan) // maybe remove this?
return nil
case m := <-write:
err := ws.conn.WriteJSON(m)
if err != nil {
log.Error().Err(err).Msg("writeMsg")
}
}
}
}
func (ws *wsSession) Write(msg interface{}) {
ws.write <- msg
}
type MsgType string
const (
ResultMsgType MsgType = "result"
)
type Error struct {
Code string `json:"code"`
Message string `json:"message"`
}
type WSError struct {
ID *int `json:"id,omitempty"`
Type MsgType `json:"type"`
Success bool `json:"success"`
Error Error `json:"error"`
}
func (ws *wsSession) writeError(id int, err Error) error {
return ws.conn.WriteJSON(WSError{
ID: &id,
Type: ResultMsgType,
Success: false,
Error: err,
})
}
func (ws *wsSession) handleMsg(ctx context.Context, msgMap map[string]interface{}) error {
if err := ctx.Err(); err != nil {
return nil
}
msgType, ok := msgMap["type"].(string)
if !ok {
return NoSuchHandlerErr
}
idFl, ok := msgMap["id"].(float64)
if !ok {
ws.Write(WSError{
Type: ResultMsgType,
Success: false,
Error: Error{
Code: "invalid_id",
Message: "command has no ID",
},
})
return nil
}
id := int(idFl)
cmd := strings.Split(msgType, "/")
newData, hand, err := ws.b.WSCommandHandler(cmd[0], cmd)
switch err {
case nil:
case NoSuchHandlerErr:
ws.writeError(id, Error{
Code: "invalid_type",
Message: "no such command",
})
return nil
default:
log.Error().Err(err).Msg("dispatch")
return nil
}
nd := newData(cmd)
if _, ok := nd.(map[string]interface{}); !ok {
err := mapstructure.Decode(&msgMap, &nd)
if err != nil {
ws.writeError(id, Error{
Code: "invalid_format",
Message: err.Error(),
})
return nil
}
}
err = hand(ctx, ws, id, cmd, nd)
if err != nil {
log.Error().Err(err).Msg("dispatch")
}
return nil
}

81
pkg/wsapi/auth.go Normal file
View file

@ -0,0 +1,81 @@
package wsapi
import (
"context"
"encoding/json"
"io"
"dynatron.me/x/blasphem/pkg/auth"
"github.com/rs/zerolog/log"
)
type authPhase struct {
*wsSession
}
func (ws *wsSession) sendAuthRequired() error {
authReq := &struct {
MsgBase
Version string `json:"ha_version"`
}{
MsgBase{Type: "auth_required"},
ws.b.Version(),
}
return ws.conn.WriteJSON(&authReq)
}
type authMsg struct {
MsgBase
AccessToken auth.AccessToken `json:"access_token"`
}
func (ap *authPhase) msgSchema() interface{} {
return &authMsg{}
}
func (ap *authPhase) finishAuth(rt *auth.RefreshToken) {
ap.user = rt.User
ap.refreshToken = rt
}
func (ap *authPhase) sendAuthOK() error {
return ap.conn.WriteJSON(struct {
Type string `json:"type"`
Version string `json:"ha_version"`
}{Type: "auth_ok", Version: ap.Blas().Version()})
}
func (ap *authPhase) sendAuthInvalid() error {
return ap.conn.WriteJSON(struct {
Type string `json:"type"`
Message string `json:"message"`
}{Type: "auth_ok", Message: "invalid auth"})
}
func (ap *authPhase) handleMsg(ctx context.Context, r io.Reader) error {
var authMsg authMsg
err := json.NewDecoder(r).Decode(&authMsg)
if err != nil {
return err
}
if err := ctx.Err(); err != nil {
return err
}
refreshToken := ap.b.ValidateAccessToken(authMsg.AccessToken)
if refreshToken != nil {
ap.finishAuth(refreshToken)
return ap.sendAuthOK()
}
log.Error().Str("remote", ap.ec.Request().RemoteAddr).Msg("websocket auth failed")
err = ap.sendAuthInvalid()
if err != nil {
return err
}
return AuthInvalidErr
}