Compare commits
46 commits
Author | SHA1 | Date | |
---|---|---|---|
f272514c65 | |||
60301c9892 | |||
14401c831a | |||
c424b75cf6 | |||
8bf7379170 | |||
2ddac95715 | |||
5570e53d99 | |||
6bd36a59c5 | |||
6443443c6b | |||
76f95a422b | |||
caa371eff1 | |||
f119f7086f | |||
95d72d2912 | |||
a1005ce6bf | |||
1b355d3cbf | |||
0da222577b | |||
0378151b9f | |||
a468f0629b | |||
2de97a6936 | |||
c91cd6efca | |||
824e54894e | |||
6aa2c46717 | |||
9224b21db7 | |||
3038750206 | |||
86abb0b618 | |||
d34814f050 | |||
f3e17e149f | |||
ae00c1534d | |||
c618197c54 | |||
794f2d8448 | |||
9aef6e5143 | |||
7f499012a6 | |||
1ea8a24224 | |||
25ed921736 | |||
9db15a648b | |||
4de1344512 | |||
68b971f65a | |||
9b38bdbca9 | |||
43682fab05 | |||
3981025fa4 | |||
3ab5b5b78a | |||
54185be835 | |||
9f39049dd6 | |||
a540f94108 | |||
25c3e9421f | |||
1aa1296732 |
32 changed files with 2009 additions and 469 deletions
7
Makefile
7
Makefile
|
@ -1,11 +1,14 @@
|
||||||
FE=pkg/frontend/frontend
|
FE=pkg/frontend/frontend
|
||||||
|
VER=$(shell git describe --always --tags)
|
||||||
|
|
||||||
|
LDFLAGS=-ldflags='-X dynatron.me/x/blasphem/internal/common.Version=${VER}'
|
||||||
|
|
||||||
all: build
|
all: build
|
||||||
build:
|
build:
|
||||||
go build -o blas ./cmd/blas/
|
go build ${LDFLAGS} -o blas ./cmd/blas/
|
||||||
|
|
||||||
serve:
|
serve:
|
||||||
go run ./cmd/blas/ serve ${BLAS_ARGS}
|
go run ${LDFLAGS} ./cmd/blas/ serve ${BLAS_ARGS}
|
||||||
|
|
||||||
# pkg/frontend/frontend/hass_frontend:
|
# pkg/frontend/frontend/hass_frontend:
|
||||||
frontend:
|
frontend:
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
"dynatron.me/x/blasphem/internal/common"
|
"dynatron.me/x/blasphem/internal/common"
|
||||||
|
"dynatron.me/x/blasphem/pkg/blas"
|
||||||
"dynatron.me/x/blasphem/pkg/cmd/serve"
|
"dynatron.me/x/blasphem/pkg/cmd/serve"
|
||||||
"dynatron.me/x/blasphem/pkg/config"
|
"dynatron.me/x/blasphem/pkg/config"
|
||||||
|
|
||||||
|
@ -25,7 +26,12 @@ func main() {
|
||||||
log.Fatal().Err(err).Msg("Config read failed")
|
log.Fatal().Err(err).Msg("Config read failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
rootCmd.AddCommand(serve.Command(config))
|
bl, err := blas.New(config)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal().Err(err).Msg("Core create failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
rootCmd.AddCommand(serve.Command(bl))
|
||||||
|
|
||||||
err = rootCmd.Execute()
|
err = rootCmd.Execute()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
1
go.mod
1
go.mod
|
@ -21,6 +21,7 @@ require (
|
||||||
github.com/labstack/gommon v0.3.1 // indirect
|
github.com/labstack/gommon v0.3.1 // indirect
|
||||||
github.com/mattn/go-colorable v0.1.12 // indirect
|
github.com/mattn/go-colorable v0.1.12 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.14 // indirect
|
github.com/mattn/go-isatty v0.0.14 // indirect
|
||||||
|
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||||
github.com/spf13/pflag v1.0.5 // indirect
|
github.com/spf13/pflag v1.0.5 // indirect
|
||||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||||
github.com/valyala/fasttemplate v1.2.1 // indirect
|
github.com/valyala/fasttemplate v1.2.1 // indirect
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -36,6 +36,8 @@ github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2y
|
||||||
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
|
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
|
||||||
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
|
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
|
||||||
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
|
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
|
||||||
|
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||||
|
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
|
|
@ -1,10 +1,16 @@
|
||||||
|
// common contains common functionality for blasphem.
|
||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// this symbol overriden by linker args
|
||||||
|
var Version = "undefined"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// AppName is the name of the application.
|
||||||
AppName = "blasphem"
|
AppName = "blasphem"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -13,6 +19,14 @@ type cmdOptions interface {
|
||||||
Execute() error
|
Execute() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func AppNamePtr() *string {
|
||||||
|
s := AppName
|
||||||
|
return &s
|
||||||
|
}
|
||||||
|
|
||||||
|
func IntPtr(i int) *int { return &i }
|
||||||
|
|
||||||
|
// RunE is a convenience function for use with cobra.
|
||||||
func RunE(c cmdOptions) func(cmd *cobra.Command, args []string) error {
|
func RunE(c cmdOptions) func(cmd *cobra.Command, args []string) error {
|
||||||
return func(cmd *cobra.Command, args []string) error {
|
return func(cmd *cobra.Command, args []string) error {
|
||||||
err := c.Options(cmd, args)
|
err := c.Options(cmd, args)
|
||||||
|
@ -29,3 +43,9 @@ func RunE(c cmdOptions) func(cmd *cobra.Command, args []string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NoCache(c echo.Context) echo.Context {
|
||||||
|
c.Response().Header().Set("Cache-Control", "no-store")
|
||||||
|
c.Response().Header().Set("Pragma", "no-cache")
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
30
internal/common/types.go
Normal file
30
internal/common/types.go
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
package common
|
||||||
|
|
||||||
|
// Convenience types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
// PyTimeStamp is a timestamp that marshals to python-style timestamp strings (long nano).
|
||||||
|
PyTimestamp time.Time
|
||||||
|
ClientID string
|
||||||
|
)
|
||||||
|
|
||||||
|
const PytTimeFormat = "2006-01-02T15:04:05.999999-07:00"
|
||||||
|
|
||||||
|
func (t *PyTimestamp) MarshalJSON() ([]byte, error) {
|
||||||
|
rv := fmt.Sprintf("%q", time.Time(*t).Format(PytTimeFormat))
|
||||||
|
return []byte(rv), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *PyTimestamp) UnmarshalJSON(b []byte) error {
|
||||||
|
s := strings.Trim(string(b), `"`)
|
||||||
|
tm, err := time.Parse(PytTimeFormat, s)
|
||||||
|
*t = PyTimestamp(tm)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
24
internal/generate/unique.go
Normal file
24
internal/generate/unique.go
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
package generate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
func UUID() string {
|
||||||
|
// must be addressable
|
||||||
|
u := uuid.New()
|
||||||
|
|
||||||
|
return hex.EncodeToString(u[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func Hex(l int) string {
|
||||||
|
b := make([]byte, l)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return hex.EncodeToString(b)
|
||||||
|
}
|
|
@ -1,17 +1,14 @@
|
||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
|
||||||
"encoding/hex"
|
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
"dynatron.me/x/blasphem/pkg/auth/provider"
|
"dynatron.me/x/blasphem/pkg/auth/provider"
|
||||||
"dynatron.me/x/blasphem/pkg/frontend"
|
|
||||||
"dynatron.me/x/blasphem/pkg/storage"
|
"dynatron.me/x/blasphem/pkg/storage"
|
||||||
|
|
||||||
// providers
|
// providers
|
||||||
|
@ -20,25 +17,33 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
ErrDisabled = errors.New("user disabled")
|
||||||
ErrInvalidAuth = errors.New("invalid auth")
|
ErrInvalidAuth = errors.New("invalid auth")
|
||||||
ErrInvalidHandler = errors.New("no such handler")
|
ErrInvalidHandler = errors.New("no such handler")
|
||||||
|
ErrInvalidIP = errors.New("invalid IP")
|
||||||
|
ErrUserAuthRemote = errors.New("user cannot authenticate remotely")
|
||||||
)
|
)
|
||||||
|
|
||||||
type Authenticator struct {
|
type authenticator struct {
|
||||||
|
sync.Mutex
|
||||||
store AuthStore
|
store AuthStore
|
||||||
flows FlowStore
|
flows *AuthFlowManager
|
||||||
sessions SessionStore
|
authCodes authCodeStore
|
||||||
providers map[string]provider.AuthProvider
|
providers map[string]provider.AuthProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Authenticator interface {
|
||||||
|
ValidateAccessToken(token AccessToken) *RefreshToken
|
||||||
|
InstallRoutes(e *echo.Echo)
|
||||||
|
}
|
||||||
|
|
||||||
type AuthError struct {
|
type AuthError struct {
|
||||||
Error string `json:"error"`
|
Error string `json:"error"`
|
||||||
Description string `json:"error_description"`
|
Description string `json:"error_description"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticator) InstallRoutes(e *echo.Echo) {
|
func (a *authenticator) InstallRoutes(e *echo.Echo) {
|
||||||
authG := e.Group("/auth")
|
authG := e.Group("/auth")
|
||||||
authG.GET("/authorize", frontend.AliasHandler("authorize.html"))
|
|
||||||
authG.GET("/providers", a.ProvidersHandler)
|
authG.GET("/providers", a.ProvidersHandler)
|
||||||
authG.POST("/token", a.TokenHandler)
|
authG.POST("/token", a.TokenHandler)
|
||||||
|
|
||||||
|
@ -49,31 +54,34 @@ func (a *Authenticator) InstallRoutes(e *echo.Echo) {
|
||||||
loginFlow.DELETE("/:flow_id", a.LoginFlowDeleteHandler)
|
loginFlow.DELETE("/:flow_id", a.LoginFlowDeleteHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticator) InitAuth(s storage.Store) error {
|
func New(s storage.Store) (Authenticator, error) {
|
||||||
a.providers = make(map[string]provider.AuthProvider)
|
a := &authenticator{
|
||||||
|
providers: make(map[string]provider.AuthProvider),
|
||||||
|
}
|
||||||
|
|
||||||
for _, pI := range provider.Providers {
|
for _, pI := range provider.Providers {
|
||||||
nProv, err := pI(s)
|
nProv, err := pI(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
a.providers[nProv.ProviderType()] = nProv
|
a.providers[nProv.ProviderType()] = nProv
|
||||||
}
|
}
|
||||||
|
|
||||||
a.flows = make(FlowStore)
|
a.flows = NewAuthFlowManager()
|
||||||
|
|
||||||
a.sessions.init()
|
a.authCodes.init()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
a.store, err = a.newAuthStore(s)
|
a.store, err = a.newAuthStore(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return a, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticator) Provider(name string) provider.AuthProvider {
|
func (a *authenticator) Provider(name string) provider.AuthProvider {
|
||||||
p, ok := a.providers[name]
|
p, ok := a.providers[name]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
|
@ -85,7 +93,7 @@ func (a *Authenticator) Provider(name string) provider.AuthProvider {
|
||||||
var HomeAssistant = "homeassistant"
|
var HomeAssistant = "homeassistant"
|
||||||
|
|
||||||
// TODO: make this configurable
|
// TODO: make this configurable
|
||||||
func (a *Authenticator) ProvidersHandler(c echo.Context) error {
|
func (a *authenticator) ProvidersHandler(c echo.Context) error {
|
||||||
providers := []provider.AuthProviderBase{
|
providers := []provider.AuthProviderBase{
|
||||||
a.Provider(HomeAssistant).ProviderBase(),
|
a.Provider(HomeAssistant).ProviderBase(),
|
||||||
}
|
}
|
||||||
|
@ -93,46 +101,24 @@ func (a *Authenticator) ProvidersHandler(c echo.Context) error {
|
||||||
return c.JSON(http.StatusOK, providers)
|
return c.JSON(http.StatusOK, providers)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticator) Check(f *Flow, rm map[string]interface{}) (provider.ProviderUser, error) {
|
func (a *authenticator) Check(f *LoginFlow, req *http.Request, rm map[string]interface{}) (user provider.ProviderUser, clientID string, err error) {
|
||||||
cID, hasCID := rm["client_id"]
|
cID, hasCID := rm["client_id"]
|
||||||
cIDStr, cidIsStr := cID.(string)
|
clientID, cidIsStr := cID.(string)
|
||||||
if !hasCID || !cidIsStr || cIDStr == "" || cIDStr != string(f.request.ClientID) {
|
if !hasCID || !cidIsStr || clientID == "" || clientID != string(f.ClientID) {
|
||||||
return nil, ErrInvalidAuth
|
return nil, clientID, ErrInvalidAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, h := range f.Handler {
|
p := a.Provider(f.Handler.String())
|
||||||
if h == nil {
|
|
||||||
return nil, ErrInvalidHandler
|
|
||||||
}
|
|
||||||
|
|
||||||
p := a.Provider(*h)
|
|
||||||
if p == nil {
|
if p == nil {
|
||||||
return nil, ErrInvalidAuth
|
return nil, clientID, ErrInvalidAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
user, success := p.ValidateCreds(rm)
|
user, success := p.ValidateCreds(req, rm)
|
||||||
|
|
||||||
if success {
|
if success {
|
||||||
log.Info().Interface("user", user.ProviderUserData()).Msg("Login success")
|
log.Info().Interface("user", user.UserData()).Msg("Login success")
|
||||||
return user, nil
|
return user, clientID, nil
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, ErrInvalidAuth
|
return nil, clientID, ErrInvalidAuth
|
||||||
}
|
|
||||||
|
|
||||||
func genUUID() string {
|
|
||||||
// must be addressable
|
|
||||||
u := uuid.New()
|
|
||||||
|
|
||||||
return hex.EncodeToString(u[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
func genHex(l int) string {
|
|
||||||
b := make([]byte, l)
|
|
||||||
if _, err := rand.Read(b); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return hex.EncodeToString(b)
|
|
||||||
}
|
}
|
||||||
|
|
225
pkg/auth/flow.go
225
pkg/auth/flow.go
|
@ -1,133 +1,100 @@
|
||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/jinzhu/copier"
|
"github.com/jinzhu/copier"
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
|
|
||||||
"dynatron.me/x/blasphem/internal/common"
|
"dynatron.me/x/blasphem/internal/common"
|
||||||
"dynatron.me/x/blasphem/pkg/auth/provider"
|
"dynatron.me/x/blasphem/pkg/auth/provider"
|
||||||
|
"dynatron.me/x/blasphem/pkg/flow"
|
||||||
)
|
)
|
||||||
|
|
||||||
type FlowStore map[FlowID]*Flow
|
type AuthFlowManager struct {
|
||||||
|
*flow.FlowManager
|
||||||
|
}
|
||||||
|
|
||||||
type FlowRequest struct {
|
type LoginFlow struct {
|
||||||
ClientID ClientID `json:"client_id"`
|
flow.FlowHandler
|
||||||
|
|
||||||
|
prov provider.AuthProvider
|
||||||
|
ClientID common.ClientID
|
||||||
|
FlowContext
|
||||||
|
}
|
||||||
|
|
||||||
|
type FlowContext struct {
|
||||||
|
IPAddr string
|
||||||
|
CredentialOnly bool
|
||||||
|
RedirectURI string
|
||||||
|
}
|
||||||
|
|
||||||
|
type LoginFlowRequest struct {
|
||||||
|
ClientID common.ClientID `json:"client_id"`
|
||||||
Handler []*string `json:"handler"`
|
Handler []*string `json:"handler"`
|
||||||
RedirectURI string `json:"redirect_uri"`
|
RedirectURI string `json:"redirect_uri"`
|
||||||
|
Type *string `json:"type"`
|
||||||
|
|
||||||
|
ip string `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type FlowType string
|
func (r *LoginFlowRequest) FlowContext() FlowContext {
|
||||||
|
return FlowContext{
|
||||||
const (
|
IPAddr: r.ip,
|
||||||
TypeForm FlowType = "form"
|
RedirectURI: r.RedirectURI,
|
||||||
TypeCreateEntry FlowType = "create_entry"
|
CredentialOnly: r.Type != nil && *r.Type == "link_user",
|
||||||
)
|
|
||||||
|
|
||||||
type FlowID string
|
|
||||||
type Step string
|
|
||||||
|
|
||||||
const (
|
|
||||||
StepInit Step = "init"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Flow struct {
|
|
||||||
Type FlowType `json:"type"`
|
|
||||||
ID FlowID `json:"flow_id"`
|
|
||||||
Handler []*string `json:"handler"`
|
|
||||||
StepID *Step `json:"step_id,omitempty"`
|
|
||||||
Schema []provider.FlowSchemaItem `json:"data_schema"`
|
|
||||||
Errors interface{} `json:"errors"`
|
|
||||||
DescPlace *string `json:"description_placeholders"`
|
|
||||||
LastStep *string `json:"last_step"`
|
|
||||||
|
|
||||||
request *FlowRequest
|
|
||||||
ctime time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Flow) touch() {
|
|
||||||
f.ctime = time.Now()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fs FlowStore) register(f *Flow) {
|
|
||||||
fs.cull()
|
|
||||||
fs[f.ID] = f
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fs FlowStore) Remove(f *Flow) {
|
|
||||||
delete(fs, f.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
const cullAge = time.Minute * 30
|
|
||||||
|
|
||||||
func (fs FlowStore) cull() {
|
|
||||||
for k, v := range fs {
|
|
||||||
if time.Now().Sub(v.ctime) > cullAge {
|
|
||||||
delete(fs, k)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fs FlowStore) Get(id FlowID) *Flow {
|
func NewAuthFlowManager() *AuthFlowManager {
|
||||||
f, ok := fs[id]
|
return &AuthFlowManager{FlowManager: flow.NewFlowManager()}
|
||||||
if ok {
|
|
||||||
return f
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticator) NewFlow(r *FlowRequest) *Flow {
|
func (afm *AuthFlowManager) NewLoginFlow(req *LoginFlowRequest, prov provider.AuthProvider) *LoginFlow {
|
||||||
var sch []provider.FlowSchemaItem
|
lf := &LoginFlow{
|
||||||
|
FlowHandler: flow.NewFlowHandlerBase(prov.ProviderType()),
|
||||||
|
ClientID: req.ClientID,
|
||||||
|
FlowContext: req.FlowContext(),
|
||||||
|
prov: prov,
|
||||||
|
}
|
||||||
|
|
||||||
|
afm.Register(lf)
|
||||||
|
|
||||||
|
return lf
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *authenticator) NewFlow(r *LoginFlowRequest) *flow.Result {
|
||||||
|
var prov provider.AuthProvider
|
||||||
|
|
||||||
for _, h := range r.Handler {
|
for _, h := range r.Handler {
|
||||||
if h == nil {
|
if h == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if hand := a.Provider(*h); hand != nil {
|
prov = a.Provider(*h)
|
||||||
sch = hand.FlowSchema()
|
if prov != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if sch == nil {
|
if prov == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
flow := &Flow{
|
lf := a.flows.NewLoginFlow(r, prov)
|
||||||
Type: TypeForm,
|
|
||||||
ID: FlowID(genUUID()),
|
|
||||||
StepID: stepPtr(StepInit),
|
|
||||||
Schema: sch,
|
|
||||||
Handler: r.Handler,
|
|
||||||
Errors: []string{},
|
|
||||||
request: r,
|
|
||||||
}
|
|
||||||
flow.touch()
|
|
||||||
|
|
||||||
a.flows.register(flow)
|
return lf.ShowForm(lf.WithSchema(prov), lf.WithStep(flow.StepInit))
|
||||||
|
|
||||||
return flow
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func stepPtr(s Step) *Step { return &s }
|
func (f *LoginFlow) redirect(c echo.Context) {
|
||||||
|
c.Request().Header.Set("Location", f.RedirectURI)
|
||||||
func (f *Flow) redirect(c echo.Context) {
|
|
||||||
c.Request().Header.Set("Location", f.request.RedirectURI)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Flow) progress(a *Authenticator, c echo.Context) error {
|
func (f *LoginFlow) progress(a *authenticator, c echo.Context) error {
|
||||||
if f.StepID == nil {
|
switch f.Step() {
|
||||||
c.Logger().Error("stepID is nil")
|
case flow.StepInit:
|
||||||
return c.String(http.StatusInternalServerError, "No Step ID")
|
|
||||||
}
|
|
||||||
|
|
||||||
switch *f.StepID {
|
|
||||||
case StepInit:
|
|
||||||
rm := make(map[string]interface{})
|
rm := make(map[string]interface{})
|
||||||
|
|
||||||
err := c.Bind(&rm)
|
err := c.Bind(&rm)
|
||||||
|
@ -135,31 +102,26 @@ func (f *Flow) progress(a *Authenticator, c echo.Context) error {
|
||||||
return c.String(http.StatusBadRequest, err.Error())
|
return c.String(http.StatusBadRequest, err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, si := range f.Schema {
|
err = f.prov.FlowSchema().CheckRequired(rm)
|
||||||
if si.Required {
|
if err != nil {
|
||||||
if _, ok := rm[si.Name]; !ok {
|
return c.JSON(http.StatusBadRequest, f.ShowForm(f.WithErrors([]string{err.Error()})))
|
||||||
return c.String(http.StatusBadRequest, "missing required param "+si.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
user, err := a.Check(f, rm)
|
|
||||||
switch err {
|
|
||||||
case nil:
|
|
||||||
var finishedFlow struct {
|
|
||||||
ID FlowID `json:"flow_id"`
|
|
||||||
Handler []*string `json:"handler"`
|
|
||||||
Result TokenID `json:"result"`
|
|
||||||
Title string `json:"title"`
|
|
||||||
Type FlowType `json:"type"`
|
|
||||||
Version int `json:"version"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
user, clientID, err := a.Check(f, c.Request(), rm)
|
||||||
|
switch err {
|
||||||
|
case nil:
|
||||||
|
creds := a.store.GetCredential(user)
|
||||||
|
if creds == nil {
|
||||||
|
return fmt.Errorf("flow progress: no such credential for %v", user.UserData())
|
||||||
|
}
|
||||||
|
|
||||||
|
finishedFlow := flow.Result{}
|
||||||
a.flows.Remove(f)
|
a.flows.Remove(f)
|
||||||
copier.Copy(&finishedFlow, f)
|
copier.Copy(&finishedFlow, f)
|
||||||
finishedFlow.Type = TypeCreateEntry
|
finishedFlow.Type = flow.TypeCreateEntry
|
||||||
finishedFlow.Title = common.AppName
|
finishedFlow.Title = common.AppNamePtr()
|
||||||
finishedFlow.Version = 1
|
finishedFlow.Version = common.IntPtr(1)
|
||||||
finishedFlow.Result = a.NewToken(c.Request(), user, f)
|
finishedFlow.Result = a.NewAuthCode(ClientID(clientID), creds)
|
||||||
|
|
||||||
f.redirect(c)
|
f.redirect(c)
|
||||||
|
|
||||||
|
@ -169,24 +131,29 @@ func (f *Flow) progress(a *Authenticator, c echo.Context) error {
|
||||||
case ErrInvalidAuth:
|
case ErrInvalidAuth:
|
||||||
fallthrough
|
fallthrough
|
||||||
default:
|
default:
|
||||||
f.Errors = map[string]interface{}{
|
return c.JSON(http.StatusOK, f.ShowForm(f.WithErrors(map[string]interface{}{
|
||||||
"base": "invalid_auth",
|
"base": "invalid_auth",
|
||||||
}
|
})))
|
||||||
return c.JSON(http.StatusOK, f)
|
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return c.String(http.StatusBadRequest, "unknown flow step")
|
return c.JSON(http.StatusOK, f.ShowForm(f.WithErrors(map[string]interface{}{
|
||||||
|
"base": "unknown_flow_step",
|
||||||
|
})))
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticator) LoginFlowDeleteHandler(c echo.Context) error {
|
func (a *authenticator) LoginFlowDeleteHandler(c echo.Context) error {
|
||||||
flowID := c.Param("flow_id")
|
a.Lock()
|
||||||
|
defer a.Unlock()
|
||||||
|
|
||||||
|
flowID := flow.FlowID(c.Param("flow_id"))
|
||||||
|
|
||||||
if flowID == "" {
|
if flowID == "" {
|
||||||
return c.String(http.StatusBadRequest, "empty flow ID")
|
return c.String(http.StatusBadRequest, "empty flow ID")
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(a.flows, FlowID(flowID))
|
a.flows.Delete(flowID)
|
||||||
|
|
||||||
return c.String(http.StatusOK, "deleted")
|
return c.String(http.StatusOK, "deleted")
|
||||||
}
|
}
|
||||||
|
@ -198,15 +165,20 @@ func setJSON(c echo.Context) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticator) BeginLoginFlowHandler(c echo.Context) error {
|
func (a *authenticator) BeginLoginFlowHandler(c echo.Context) error {
|
||||||
|
a.Lock()
|
||||||
|
defer a.Unlock()
|
||||||
|
|
||||||
setJSON(c)
|
setJSON(c)
|
||||||
|
|
||||||
var flowReq FlowRequest
|
var flowReq LoginFlowRequest
|
||||||
err := c.Bind(&flowReq)
|
err := c.Bind(&flowReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.String(http.StatusBadRequest, err.Error())
|
return c.String(http.StatusBadRequest, err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
flowReq.ip = c.Request().RemoteAddr
|
||||||
|
|
||||||
resp := a.NewFlow(&flowReq)
|
resp := a.NewFlow(&flowReq)
|
||||||
|
|
||||||
if resp == nil {
|
if resp == nil {
|
||||||
|
@ -216,21 +188,18 @@ func (a *Authenticator) BeginLoginFlowHandler(c echo.Context) error {
|
||||||
return c.JSON(http.StatusOK, resp)
|
return c.JSON(http.StatusOK, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticator) LoginFlowHandler(c echo.Context) error {
|
func (a *authenticator) LoginFlowHandler(c echo.Context) error {
|
||||||
|
a.Lock()
|
||||||
|
defer a.Unlock()
|
||||||
|
|
||||||
setJSON(c)
|
setJSON(c)
|
||||||
|
|
||||||
flowID := c.Param("flow_id")
|
flowID := c.Param("flow_id")
|
||||||
|
|
||||||
flow := a.flows.Get(FlowID(flowID))
|
flow := a.flows.Get(flow.FlowID(flowID))
|
||||||
if flow == nil {
|
if flow == nil {
|
||||||
return c.String(http.StatusNotFound, "no such flow")
|
return c.String(http.StatusNotFound, "no such flow")
|
||||||
}
|
}
|
||||||
|
|
||||||
if time.Now().Sub(flow.ctime) > cullAge {
|
return flow.(*LoginFlow).progress(a, c)
|
||||||
a.flows.Remove(flow)
|
|
||||||
|
|
||||||
return c.String(http.StatusGone, "flow timed out")
|
|
||||||
}
|
|
||||||
|
|
||||||
return flow.progress(a, c)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,12 +1,15 @@
|
||||||
package hass
|
package hass
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
"dynatron.me/x/blasphem/pkg/auth/provider"
|
"dynatron.me/x/blasphem/pkg/auth/provider"
|
||||||
|
"dynatron.me/x/blasphem/pkg/flow"
|
||||||
"dynatron.me/x/blasphem/pkg/storage"
|
"dynatron.me/x/blasphem/pkg/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -21,14 +24,29 @@ type HAUser struct {
|
||||||
provider.AuthProvider `json:"-"`
|
provider.AuthProvider `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hau *HAUser) UserData() interface{} {
|
func (hau *HAUser) UserData() provider.ProviderUser {
|
||||||
return UserData{
|
return &UserData{ // strip secret
|
||||||
Username: hau.Username,
|
Username: hau.Username,
|
||||||
|
AuthProvider: hau.AuthProvider,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (hau *HAUser) Provider() provider.AuthProvider {
|
||||||
|
return hau.AuthProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hau *UserData) Provider() provider.AuthProvider {
|
||||||
|
return hau.AuthProvider
|
||||||
|
}
|
||||||
|
|
||||||
type UserData struct {
|
type UserData struct {
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
|
|
||||||
|
provider.AuthProvider `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ud *UserData) UserData() provider.ProviderUser {
|
||||||
|
return ud
|
||||||
}
|
}
|
||||||
|
|
||||||
const HomeAssistant = "homeassistant"
|
const HomeAssistant = "homeassistant"
|
||||||
|
@ -38,6 +56,7 @@ func (h *HAUser) ProviderUserData() interface{} { return h.UserData() }
|
||||||
type HomeAssistantProvider struct {
|
type HomeAssistantProvider struct {
|
||||||
provider.AuthProviderBase `json:"-"`
|
provider.AuthProviderBase `json:"-"`
|
||||||
Users []HAUser `json:"users"`
|
Users []HAUser `json:"users"`
|
||||||
|
userMap map[string]*HAUser
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHAProvider(s storage.Store) (provider.AuthProvider, error) {
|
func NewHAProvider(s storage.Store) (provider.AuthProvider, error) {
|
||||||
|
@ -53,18 +72,52 @@ func NewHAProvider(s storage.Store) (provider.AuthProvider, error) {
|
||||||
return hap, err
|
return hap, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range hap.Users {
|
hap.userMap = make(map[string]*HAUser)
|
||||||
|
|
||||||
|
for i, u := range hap.Users {
|
||||||
hap.Users[i].AuthProvider = hap
|
hap.Users[i].AuthProvider = hap
|
||||||
|
hap.userMap[u.Username] = &hap.Users[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
return hap, nil
|
return hap, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (hap *HomeAssistantProvider) Lookup(pu provider.ProviderUser) provider.ProviderUser {
|
||||||
|
u, has := hap.userMap[pu.(*HAUser).Username]
|
||||||
|
if !has {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
func (hap *HomeAssistantProvider) hashPass(p string) ([]byte, error) {
|
func (hap *HomeAssistantProvider) hashPass(p string) ([]byte, error) {
|
||||||
return bcrypt.GenerateFromPassword([]byte(p), bcrypt.DefaultCost)
|
return bcrypt.GenerateFromPassword([]byte(p), bcrypt.DefaultCost)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hap *HomeAssistantProvider) ValidateCreds(rm map[string]interface{}) (provider.ProviderUser, bool) {
|
func (hap *HomeAssistantProvider) EqualCreds(c1, c2 provider.ProviderUser) bool {
|
||||||
|
switch c1c := c1.(type) {
|
||||||
|
case *HAUser:
|
||||||
|
switch c2c := c2.(type) {
|
||||||
|
case *HAUser:
|
||||||
|
return c2c.Username == c1c.Username
|
||||||
|
case *UserData:
|
||||||
|
return c2c.Username == c1c.Username
|
||||||
|
}
|
||||||
|
case *UserData:
|
||||||
|
switch c2c := c2.(type) {
|
||||||
|
case *HAUser:
|
||||||
|
return c2c.Username == c1c.Username
|
||||||
|
case *UserData:
|
||||||
|
return c2c.Username == c1c.Username
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hap *HomeAssistantProvider) ValidateCreds(r *http.Request, rm map[string]interface{}) (provider.ProviderUser, bool) {
|
||||||
usernameE, hasU := rm["username"]
|
usernameE, hasU := rm["username"]
|
||||||
passwordE, hasP := rm["password"]
|
passwordE, hasP := rm["password"]
|
||||||
username, unStr := usernameE.(string)
|
username, unStr := usernameE.(string)
|
||||||
|
@ -97,6 +150,7 @@ func (hap *HomeAssistantProvider) ValidateCreds(rm map[string]interface{}) (prov
|
||||||
|
|
||||||
err = bcrypt.CompareHashAndPassword(hash, []byte(password))
|
err = bcrypt.CompareHashAndPassword(hash, []byte(password))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
found.AuthProvider = hap
|
||||||
return found, true
|
return found, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -104,25 +158,13 @@ func (hap *HomeAssistantProvider) ValidateCreds(rm map[string]interface{}) (prov
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hap *HomeAssistantProvider) NewCredData() interface{} {
|
func (hap *HomeAssistantProvider) NewCredData() interface{} {
|
||||||
return &UserData{}
|
return &HAUser{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hap *HomeAssistantProvider) FlowSchema() []provider.FlowSchemaItem {
|
func (hap *HomeAssistantProvider) FlowSchema() flow.Schema {
|
||||||
return []provider.FlowSchemaItem{
|
return flow.NewSchema(flow.RequiredString("username"), flow.RequiredString("password"))
|
||||||
{
|
|
||||||
Type: "string",
|
|
||||||
Name: "username",
|
|
||||||
Required: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Type: "string",
|
|
||||||
Name: "password",
|
|
||||||
Required: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
provider.Register(HomeAssistant, NewHAProvider)
|
provider.Register(HomeAssistant, NewHAProvider)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
package provider
|
package provider
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"dynatron.me/x/blasphem/pkg/flow"
|
||||||
"dynatron.me/x/blasphem/pkg/storage"
|
"dynatron.me/x/blasphem/pkg/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -11,9 +14,11 @@ var Providers = make(map[string]Constructor)
|
||||||
type AuthProvider interface { // TODO: this should include stepping
|
type AuthProvider interface { // TODO: this should include stepping
|
||||||
AuthProviderMetadata
|
AuthProviderMetadata
|
||||||
ProviderBase() AuthProviderBase
|
ProviderBase() AuthProviderBase
|
||||||
FlowSchema() []FlowSchemaItem
|
FlowSchema() flow.Schema
|
||||||
NewCredData() interface{}
|
NewCredData() interface{}
|
||||||
ValidateCreds(reqMap map[string]interface{}) (user ProviderUser, success bool)
|
ValidateCreds(r *http.Request, reqMap map[string]interface{}) (user ProviderUser, success bool)
|
||||||
|
EqualCreds(c1, c2 ProviderUser) bool
|
||||||
|
Lookup(ProviderUser) ProviderUser
|
||||||
}
|
}
|
||||||
|
|
||||||
func Register(providerName string, f func(storage.Store) (AuthProvider, error)) {
|
func Register(providerName string, f func(storage.Store) (AuthProvider, error)) {
|
||||||
|
@ -21,8 +26,9 @@ func Register(providerName string, f func(storage.Store) (AuthProvider, error))
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProviderUser interface {
|
type ProviderUser interface {
|
||||||
AuthProviderMetadata
|
// TODO: make sure this is sane with all the ProviderUser and UserData type stuff
|
||||||
ProviderUserData() interface{}
|
UserData() ProviderUser
|
||||||
|
Provider() AuthProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthProviderBase struct {
|
type AuthProviderBase struct {
|
||||||
|
@ -41,11 +47,3 @@ func (bp *AuthProviderBase) ProviderName() string { return bp.Name }
|
||||||
func (bp *AuthProviderBase) ProviderID() *string { return bp.ID }
|
func (bp *AuthProviderBase) ProviderID() *string { return bp.ID }
|
||||||
func (bp *AuthProviderBase) ProviderType() string { return bp.Type }
|
func (bp *AuthProviderBase) ProviderType() string { return bp.Type }
|
||||||
func (bp *AuthProviderBase) ProviderBase() AuthProviderBase { return *bp }
|
func (bp *AuthProviderBase) ProviderBase() AuthProviderBase { return *bp }
|
||||||
|
|
||||||
type FlowSchemaItem struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
Required bool `json:"required"`
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,10 @@ package trustednets
|
||||||
// TODO: This doesn't work at all
|
// TODO: This doesn't work at all
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"dynatron.me/x/blasphem/pkg/auth/provider"
|
"dynatron.me/x/blasphem/pkg/auth/provider"
|
||||||
|
"dynatron.me/x/blasphem/pkg/flow"
|
||||||
"dynatron.me/x/blasphem/pkg/storage"
|
"dynatron.me/x/blasphem/pkg/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -13,24 +16,42 @@ type User struct {
|
||||||
provider.AuthProvider `json:"-"`
|
provider.AuthProvider `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hau *User) UserData() interface{} {
|
func (hau *User) UserData() provider.ProviderUser {
|
||||||
return UserData{
|
return &UserData{
|
||||||
UserID: hau.UserID,
|
UserID: hau.UserID,
|
||||||
|
AuthProvider: hau.AuthProvider,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (hau *UserData) Provider() provider.AuthProvider {
|
||||||
|
return hau.AuthProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hau *User) Provider() provider.AuthProvider {
|
||||||
|
return hau.AuthProvider
|
||||||
|
}
|
||||||
|
|
||||||
type UserData struct {
|
type UserData struct {
|
||||||
UserID string `json:"user_id"`
|
UserID string `json:"user_id"`
|
||||||
|
|
||||||
|
provider.AuthProvider `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ud *UserData) UserData() provider.ProviderUser {
|
||||||
|
return ud
|
||||||
}
|
}
|
||||||
|
|
||||||
const TrustedNetworks = "trusted_networks"
|
const TrustedNetworks = "trusted_networks"
|
||||||
|
|
||||||
func (h *User) ProviderUserData() interface{} { return h.UserData() }
|
|
||||||
|
|
||||||
type TrustedNetworksProvider struct {
|
type TrustedNetworksProvider struct {
|
||||||
provider.AuthProviderBase `json:"-"`
|
provider.AuthProviderBase `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (hap *TrustedNetworksProvider) EqualCreds(c1, c2 provider.ProviderUser) bool {
|
||||||
|
panic("not implemented")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func New(s storage.Store) (provider.AuthProvider, error) {
|
func New(s storage.Store) (provider.AuthProvider, error) {
|
||||||
hap := &TrustedNetworksProvider{
|
hap := &TrustedNetworksProvider{
|
||||||
AuthProviderBase: provider.AuthProviderBase{
|
AuthProviderBase: provider.AuthProviderBase{
|
||||||
|
@ -42,8 +63,14 @@ func New(s storage.Store) (provider.AuthProvider, error) {
|
||||||
return hap, nil
|
return hap, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: To implement this, ValidateCreds needs to be changed to accept an http.Request, or the echo context.
|
func (tnp *TrustedNetworksProvider) Lookup(pu provider.ProviderUser) provider.ProviderUser {
|
||||||
func (hap *TrustedNetworksProvider) ValidateCreds(rm map[string]interface{}) (provider.ProviderUser, bool) {
|
return pu
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hap *TrustedNetworksProvider) ValidateCreds(r *http.Request, rm map[string]interface{}) (provider.ProviderUser, bool) {
|
||||||
|
/*
|
||||||
|
if req.RemoteAddr in allowed then do the thing
|
||||||
|
*/
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -51,22 +78,10 @@ func (hap *TrustedNetworksProvider) NewCredData() interface{} {
|
||||||
return &UserData{}
|
return &UserData{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hap *TrustedNetworksProvider) FlowSchema() []provider.FlowSchemaItem {
|
func (hap *TrustedNetworksProvider) FlowSchema() flow.Schema {
|
||||||
return []provider.FlowSchemaItem{
|
return nil
|
||||||
{
|
|
||||||
Type: "string",
|
|
||||||
Name: "username",
|
|
||||||
Required: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Type: "string",
|
|
||||||
Name: "password",
|
|
||||||
Required: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
provider.Register(TrustedNetworks, New)
|
provider.Register(TrustedNetworks, New)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,143 +2,348 @@ package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
|
"dynatron.me/x/blasphem/internal/common"
|
||||||
|
"dynatron.me/x/blasphem/internal/generate"
|
||||||
"dynatron.me/x/blasphem/pkg/auth/provider"
|
"dynatron.me/x/blasphem/pkg/auth/provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SessionStore struct {
|
type authCodeStore struct {
|
||||||
s map[TokenID]*Token
|
s map[authCodeTuple]flowResult
|
||||||
lastCull time.Time
|
lastCull time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenID string
|
type authCodeTuple struct {
|
||||||
|
ClientID ClientID
|
||||||
|
Code AuthCode
|
||||||
|
}
|
||||||
|
|
||||||
func (t *TokenID) IsValid() bool {
|
func (t *authCodeTuple) IsValid() bool {
|
||||||
// TODO: more validation than this
|
// TODO: more validation than this
|
||||||
return *t != ""
|
return t.Code != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
type Token struct { // TODO: jwt bro
|
type flowResult struct {
|
||||||
ID TokenID
|
Time time.Time
|
||||||
Ctime time.Time
|
Cred *Credentials
|
||||||
Expires time.Time
|
|
||||||
Addr string
|
|
||||||
|
|
||||||
user provider.ProviderUser `json:"-"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ss *SessionStore) init() {
|
// OAuth 4.2.1 spec recommends 10 minutes
|
||||||
ss.s = make(map[TokenID]*Token)
|
const authCodeExpire = 10 * time.Minute
|
||||||
|
|
||||||
|
func (f *flowResult) IsValid(now time.Time) bool {
|
||||||
|
if now.After(f.Time.Add(authCodeExpire)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ss *authCodeStore) init() {
|
||||||
|
ss.s = make(map[authCodeTuple]flowResult)
|
||||||
}
|
}
|
||||||
|
|
||||||
const cullInterval = 5 * time.Minute
|
const cullInterval = 5 * time.Minute
|
||||||
|
|
||||||
func (ss *SessionStore) cull() {
|
func (ss *authCodeStore) cull() {
|
||||||
if now := time.Now(); now.Sub(ss.lastCull) > cullInterval {
|
if now := time.Now(); now.Sub(ss.lastCull) > cullInterval {
|
||||||
for k, v := range ss.s {
|
for k, v := range ss.s {
|
||||||
if now.After(v.Expires) {
|
if !v.IsValid(now) {
|
||||||
delete(ss.s, k)
|
delete(ss.s, k)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ss *SessionStore) register(t *Token) {
|
func (ss *authCodeStore) put(clientID ClientID, cred *Credentials) string {
|
||||||
ss.cull()
|
ss.cull()
|
||||||
ss.s[t.ID] = t
|
code := generate.UUID()
|
||||||
|
ss.s[authCodeTuple{clientID, AuthCode(code)}] = flowResult{Time: time.Now(), Cred: cred}
|
||||||
|
|
||||||
|
return code
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ss *SessionStore) verify(tr *TokenRequest, r *http.Request) (provider.ProviderUser, bool) {
|
func (ss *authCodeStore) get(tr *TokenRequest) (*Credentials, bool) {
|
||||||
if t, hasToken := ss.s[tr.Code]; hasToken {
|
key := authCodeTuple{tr.ClientID, tr.Code}
|
||||||
// TODO: JWT
|
if t, hasCode := ss.s[key]; hasCode {
|
||||||
if t.Expires.After(time.Now()) {
|
defer delete(ss.s, key)
|
||||||
return t.user, true
|
if t.IsValid(time.Now()) {
|
||||||
} else {
|
return t.Cred, true
|
||||||
delete(ss.s, t.ID)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
type Credential struct {
|
type Credentials struct {
|
||||||
ID CredID `json:"id"`
|
ID CredID `json:"id"`
|
||||||
UserID UserID `json:"user_id"`
|
UserID UserID `json:"user_id"`
|
||||||
AuthProviderType string `json:"auth_provider_type"`
|
AuthProviderType string `json:"auth_provider_type"`
|
||||||
AuthProviderID *string `json:"auth_provider_id"`
|
AuthProviderID *string `json:"auth_provider_id"`
|
||||||
DataRaw json.RawMessage `json:"data,omitempty"`
|
DataRaw *json.RawMessage `json:"data,omitempty"`
|
||||||
user provider.ProviderUser
|
|
||||||
|
User provider.ProviderUser `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cred *Credential) MarshalJSON() ([]byte, error) {
|
func (cred *Credentials) MarshalJSON() ([]byte, error) {
|
||||||
rm := map[string]interface{}{
|
type CredAlias Credentials // alias so ø method set and we don't recurse
|
||||||
"id": cred.ID,
|
nCd := (*CredAlias)(cred)
|
||||||
"user_id": cred.UserID,
|
|
||||||
"auth_provider_type": cred.user.ProviderType(),
|
|
||||||
"auth_provider_id": cred.user.ProviderID(),
|
|
||||||
}
|
|
||||||
|
|
||||||
providerData := cred.user.ProviderUserData()
|
if cred.User != nil {
|
||||||
|
providerData := cred.User.UserData()
|
||||||
if providerData != nil {
|
if providerData != nil {
|
||||||
rm["data"] = providerData
|
b, err := json.Marshal(providerData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return json.Marshal(rm)
|
dr := json.RawMessage(b)
|
||||||
|
nCd.DataRaw = &dr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Marshal(nCd)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ss *SessionStore) verifyAndGetCredential(tr *TokenRequest, r *http.Request) *Credential {
|
type (
|
||||||
user, success := ss.verify(tr, r)
|
TokenType string
|
||||||
|
RefreshTokenID string
|
||||||
|
RefreshTokenToken string
|
||||||
|
)
|
||||||
|
|
||||||
|
func (rti RefreshTokenID) String() string { return string(rti) }
|
||||||
|
|
||||||
|
func (rti RefreshTokenToken) IsValid() bool { return rti != "" }
|
||||||
|
|
||||||
|
const (
|
||||||
|
TokenTypeSystem TokenType = "system"
|
||||||
|
TokenTypeNormal TokenType = "normal"
|
||||||
|
TokenTypeLongLived TokenType = "long_lived_access_token"
|
||||||
|
TokenTypeNone TokenType = ""
|
||||||
|
)
|
||||||
|
|
||||||
|
func (tt TokenType) IsValid() bool {
|
||||||
|
switch tt {
|
||||||
|
case TokenTypeSystem, TokenTypeNormal, TokenTypeLongLived:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
type RefreshToken struct {
|
||||||
|
ID RefreshTokenID `json:"id"`
|
||||||
|
UserID UserID `json:"user_id"`
|
||||||
|
ClientID *ClientID `json:"client_id"`
|
||||||
|
ClientName *string `json:"client_name"`
|
||||||
|
ClientIcon *string `json:"client_icon"`
|
||||||
|
TokenType TokenType `json:"token_type"`
|
||||||
|
CreatedAt *common.PyTimestamp `json:"created_at"`
|
||||||
|
AccessTokenExpiration json.Number `json:"access_token_expiration"`
|
||||||
|
Token RefreshTokenToken `json:"token"`
|
||||||
|
JWTKey string `json:"jwt_key"`
|
||||||
|
LastUsedAt *common.PyTimestamp `json:"last_used_at"`
|
||||||
|
LastUsedIP *string `json:"last_used_ip"`
|
||||||
|
CredentialID *CredID `json:"credential_id"`
|
||||||
|
Version *string `json:"version"`
|
||||||
|
|
||||||
|
User *User `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rt *RefreshToken) IsValid() bool {
|
||||||
|
return rt.JWTKey != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rt *RefreshToken) AccessExpiration() (exp int64) {
|
||||||
|
exp, err := rt.AccessTokenExpiration.Int64()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type RefreshOption func(*RefreshToken)
|
||||||
|
|
||||||
|
func WithClientID(cid ClientID) RefreshOption {
|
||||||
|
return func(rt *RefreshToken) {
|
||||||
|
rt.ClientID = &cid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithClientName(n string) RefreshOption {
|
||||||
|
return func(rt *RefreshToken) {
|
||||||
|
rt.ClientName = &n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithClientIcon(n string) RefreshOption {
|
||||||
|
return func(rt *RefreshToken) {
|
||||||
|
rt.ClientIcon = &n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithTokenType(t TokenType) RefreshOption {
|
||||||
|
return func(rt *RefreshToken) {
|
||||||
|
rt.TokenType = t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithCredential(c *Credentials) RefreshOption {
|
||||||
|
return func(rt *RefreshToken) {
|
||||||
|
rt.CredentialID = &c.ID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const DefaultAccessExpiration = "1800" // json 🤮
|
||||||
|
|
||||||
|
func (a *authenticator) NewRefreshToken(user *User, opts ...RefreshOption) (*RefreshToken, error) {
|
||||||
|
e := func(es string, arg ...interface{}) (*RefreshToken, error) {
|
||||||
|
return nil, fmt.Errorf(es, arg...)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := common.PyTimestamp(time.Now())
|
||||||
|
|
||||||
|
r := &RefreshToken{
|
||||||
|
ID: RefreshTokenID(generate.UUID()),
|
||||||
|
UserID: user.ID,
|
||||||
|
Token: RefreshTokenToken(generate.Hex(64)),
|
||||||
|
JWTKey: generate.Hex(64),
|
||||||
|
CreatedAt: &now,
|
||||||
|
AccessTokenExpiration: DefaultAccessExpiration,
|
||||||
|
User: user,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.TokenType == TokenTypeNone {
|
||||||
|
if user.SystemGenerated {
|
||||||
|
r.TokenType = TokenTypeSystem
|
||||||
|
} else {
|
||||||
|
r.TokenType = TokenTypeNormal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case !r.TokenType.IsValid():
|
||||||
|
return e("invalid token type")
|
||||||
|
case !user.Active:
|
||||||
|
return e("user is not active")
|
||||||
|
case user.SystemGenerated && r.ClientID != nil:
|
||||||
|
return e("system generated users cannot have refresh tokens connected to a client")
|
||||||
|
case !r.TokenType.IsValid():
|
||||||
|
return e("invalid token type '%v'", r.TokenType)
|
||||||
|
case user.SystemGenerated != (r.TokenType == TokenTypeSystem):
|
||||||
|
return e("system generated user can only have system type refresh tokens")
|
||||||
|
case r.TokenType == TokenTypeNormal && r.ClientID == nil:
|
||||||
|
return e("client is required to generate a refresh token")
|
||||||
|
case r.TokenType == TokenTypeLongLived && r.ClientName == nil:
|
||||||
|
return e("client name is required for long-lived token")
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.TokenType == TokenTypeLongLived {
|
||||||
|
for _, lv := range user.RefreshTokens {
|
||||||
|
if strPtrEq(lv.ClientName, r.ClientName) && lv.TokenType == TokenTypeLongLived {
|
||||||
|
return e("client name '%v' already exists", *r.ClientName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return a.store.PutRefreshToken(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RefreshToken) AccessToken(req *http.Request) (string, error) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
pytnow := common.PyTimestamp(now)
|
||||||
|
r.LastUsedAt = &pytnow
|
||||||
|
r.LastUsedIP = &req.RemoteAddr
|
||||||
|
|
||||||
|
return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{
|
||||||
|
Issuer: r.ID.String(),
|
||||||
|
IssuedAt: now.Unix(),
|
||||||
|
ExpiresAt: now.Add(time.Duration(r.AccessExpiration()) * time.Second).Unix(),
|
||||||
|
}).SignedString([]byte(r.JWTKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *authenticator) ValidateAccessToken(token AccessToken) *RefreshToken {
|
||||||
|
var unverifiedIssRT *RefreshToken
|
||||||
|
claims := &jwt.StandardClaims{}
|
||||||
|
_, err := jwt.ParseWithClaims(string(token), claims, func(jt *jwt.Token) (interface{}, error) {
|
||||||
|
iss := jt.Claims.(*jwt.StandardClaims).Issuer
|
||||||
|
unverifiedIssRT = a.store.GetRefreshToken(RefreshTokenID(iss))
|
||||||
|
if unverifiedIssRT == nil {
|
||||||
|
return nil, fmt.Errorf("bad token")
|
||||||
|
}
|
||||||
|
|
||||||
|
return []byte(unverifiedIssRT.JWTKey), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("validateAccessToken")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return unverifiedIssRT
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *authenticator) verifyAndGetCredential(tr *TokenRequest) *Credentials {
|
||||||
|
cred, success := a.authCodes.get(tr)
|
||||||
if !success {
|
if !success {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Credential{user: user}
|
return cred
|
||||||
}
|
}
|
||||||
|
|
||||||
const defaultExpiration = 2 * time.Hour
|
const defaultExpiration = 15 * time.Minute
|
||||||
|
|
||||||
func (a *Authenticator) NewToken(r *http.Request, user provider.ProviderUser, f *Flow) TokenID {
|
func (a *authenticator) NewAuthCode(clientID ClientID, cred *Credentials) string {
|
||||||
id := TokenID(genUUID())
|
return a.authCodes.put(clientID, cred)
|
||||||
|
|
||||||
t := &Token{
|
|
||||||
ID: id,
|
|
||||||
Ctime: time.Now(),
|
|
||||||
Expires: time.Now().Add(defaultExpiration),
|
|
||||||
Addr: r.RemoteAddr,
|
|
||||||
|
|
||||||
user: user,
|
|
||||||
}
|
|
||||||
|
|
||||||
a.sessions.register(t)
|
|
||||||
|
|
||||||
return id
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type GrantType string
|
type GrantType string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
GTAuthorizationCode GrantType = "authorization_code"
|
GrantAuthCode GrantType = "authorization_code"
|
||||||
GTRefreshToken GrantType = "refresh_token"
|
GrantRefreshToken GrantType = "refresh_token"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ClientID string
|
type ClientID common.ClientID
|
||||||
|
|
||||||
func (c *ClientID) IsValid() bool {
|
func (c *ClientID) IsValid() bool {
|
||||||
// TODO: || !indieauth.VerifyClientID(rq.ClientID)?
|
// TODO: || !indieauth.VerifyClientID(rq.ClientID)?
|
||||||
return *c != ""
|
return *c != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenRequest struct {
|
type AuthCode string
|
||||||
ClientID ClientID `form:"client_id"`
|
|
||||||
Code TokenID `form:"code"`
|
func (ac *AuthCode) IsValid() bool {
|
||||||
GrantType GrantType `form:"grant_type"`
|
return *ac != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticator) TokenHandler(c echo.Context) error {
|
type TokenRequest struct {
|
||||||
|
ClientID ClientID `form:"client_id"`
|
||||||
|
Code AuthCode `form:"code"`
|
||||||
|
GrantType GrantType `form:"grant_type"`
|
||||||
|
RefreshToken RefreshTokenToken `form:"refresh_token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
const AuthFailed = "authentication failure"
|
||||||
|
|
||||||
|
func (a *authenticator) TokenHandler(c echo.Context) error {
|
||||||
|
a.Lock()
|
||||||
|
defer a.Unlock()
|
||||||
|
|
||||||
rq := new(TokenRequest)
|
rq := new(TokenRequest)
|
||||||
err := c.Bind(rq)
|
err := c.Bind(rq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -146,7 +351,7 @@ func (a *Authenticator) TokenHandler(c echo.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
switch rq.GrantType {
|
switch rq.GrantType {
|
||||||
case GTAuthorizationCode:
|
case GrantAuthCode:
|
||||||
if !rq.ClientID.IsValid() {
|
if !rq.ClientID.IsValid() {
|
||||||
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid client ID"})
|
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid client ID"})
|
||||||
}
|
}
|
||||||
|
@ -155,21 +360,88 @@ func (a *Authenticator) TokenHandler(c echo.Context) error {
|
||||||
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid code"})
|
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid code"})
|
||||||
}
|
}
|
||||||
|
|
||||||
if cred := a.sessions.verifyAndGetCredential(rq, c.Request()); cred != nil {
|
cred := a.verifyAndGetCredential(rq)
|
||||||
// TODO: success
|
if cred == nil {
|
||||||
|
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid code"})
|
||||||
|
}
|
||||||
|
|
||||||
user, err := a.getOrCreateUser(cred)
|
user, err := a.getOrCreateUser(cred)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.JSON(http.StatusUnauthorized, AuthError{Error: "access_denied", Description: "bad user"})
|
log.Error().Err(err).Msg("getOrCreateUser")
|
||||||
|
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := user.allowedToAuth(); err != nil {
|
if err := user.allowedToAuth(c.Request()); err != nil {
|
||||||
return c.JSON(http.StatusUnauthorized, AuthError{Error: "access_denied", Description: err.Error()})
|
log.Error().Err(err).Msg("allowedToAuth")
|
||||||
}
|
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
|
||||||
return c.String(http.StatusOK, "token good I guess")
|
|
||||||
}
|
|
||||||
case GTRefreshToken:
|
|
||||||
return c.String(http.StatusNotImplemented, "not implemented")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.String(http.StatusUnauthorized, "token bad I guess")
|
rt, err := a.NewRefreshToken(user, WithClientID(rq.ClientID), WithCredential(cred))
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("NewRefreshToken")
|
||||||
|
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
|
||||||
|
}
|
||||||
|
|
||||||
|
at, err := rt.AccessToken(c.Request())
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("AccessToken")
|
||||||
|
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
|
||||||
|
}
|
||||||
|
|
||||||
|
return common.NoCache(c).JSON(http.StatusOK, &struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
RefreshToken RefreshTokenToken `json:"refresh_token"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
HAAuthProvider string `json:"ha_auth_provider"`
|
||||||
|
}{
|
||||||
|
AccessToken: at,
|
||||||
|
TokenType: "Bearer",
|
||||||
|
RefreshToken: rt.Token,
|
||||||
|
ExpiresIn: rt.AccessExpiration(),
|
||||||
|
HAAuthProvider: cred.AuthProviderType,
|
||||||
|
})
|
||||||
|
case GrantRefreshToken:
|
||||||
|
log.Debug().Interface("request", c.Request()).Interface("tokenRequest", rq).Msg("grant_type=refresh_token")
|
||||||
|
|
||||||
|
if !rq.ClientID.IsValid() {
|
||||||
|
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid client ID"})
|
||||||
|
}
|
||||||
|
|
||||||
|
if !rq.RefreshToken.IsValid() {
|
||||||
|
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"})
|
||||||
|
}
|
||||||
|
|
||||||
|
rt := a.store.GetRefreshTokenByToken(rq.RefreshToken)
|
||||||
|
if rt == nil {
|
||||||
|
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_grant"})
|
||||||
|
}
|
||||||
|
|
||||||
|
if rt.ClientID == nil || *rt.ClientID != rq.ClientID {
|
||||||
|
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rt.User.allowedToAuth(c.Request()); err != nil {
|
||||||
|
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: err.Error()})
|
||||||
|
}
|
||||||
|
|
||||||
|
at, err := rt.AccessToken(c.Request())
|
||||||
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: err.Error()})
|
||||||
|
}
|
||||||
|
|
||||||
|
return common.NoCache(c).JSON(http.StatusOK, &struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
}{
|
||||||
|
AccessToken: at,
|
||||||
|
TokenType: "Bearer",
|
||||||
|
ExpiresIn: rt.AccessExpiration(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AccessToken string
|
||||||
|
|
|
@ -1,36 +1,144 @@
|
||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"crypto/subtle"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"dynatron.me/x/blasphem/internal/generate"
|
||||||
|
"dynatron.me/x/blasphem/pkg/auth/provider"
|
||||||
"dynatron.me/x/blasphem/pkg/storage"
|
"dynatron.me/x/blasphem/pkg/storage"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
AuthStoreKey = "auth"
|
AuthStoreKey = "auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
type AuthStore interface {
|
type AuthStore interface {
|
||||||
|
User(UserID) *User
|
||||||
|
GetCredential(provider.ProviderUser) *Credentials
|
||||||
|
PutRefreshToken(*RefreshToken) (*RefreshToken, error)
|
||||||
|
GetRefreshTokenByToken(token RefreshTokenToken) *RefreshToken
|
||||||
|
GetRefreshToken(RefreshTokenID) *RefreshToken
|
||||||
}
|
}
|
||||||
|
|
||||||
type authStore struct {
|
type authStore struct {
|
||||||
Users []User `json:"users"`
|
storage.Item `json:"-"`
|
||||||
Groups interface {} `json:"groups"`
|
|
||||||
Credentials []Credential `json:"credentials"`
|
Users []*User `json:"users"`
|
||||||
|
Groups []*Group `json:"groups"`
|
||||||
|
Credentials []*Credentials `json:"credentials"`
|
||||||
|
Refresh []*RefreshToken `json:"refresh_tokens"`
|
||||||
|
|
||||||
userMap map[UserID]*User
|
userMap map[UserID]*User
|
||||||
|
providerUsers map[provider.ProviderUser]*Credentials
|
||||||
|
store storage.Store
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error) {
|
func (as *authStore) sync() {
|
||||||
as = &authStore{}
|
err := as.store.Flush(as.ItemKey())
|
||||||
err = s.Get(AuthStoreKey, as)
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("sync authStore")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
as.userMap = make(map[UserID]*User)
|
func strPtrEq(n1, n2 *string) bool {
|
||||||
|
return (n1 == n2 || (n1 != nil && n2 != nil && *n1 == *n2))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (as *authStore) GetCredential(p provider.ProviderUser) *Credentials {
|
||||||
|
var found *Credentials
|
||||||
|
for _, cr := range as.Credentials {
|
||||||
|
if p != nil && (p == cr.User ||
|
||||||
|
(p.Provider() != nil &&
|
||||||
|
strPtrEq(cr.AuthProviderID, p.Provider().ProviderID()) &&
|
||||||
|
cr.AuthProviderType == p.Provider().ProviderType() &&
|
||||||
|
p.Provider().EqualCreds(cr.User.UserData(), p.UserData()))) {
|
||||||
|
found = cr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return found
|
||||||
|
}
|
||||||
|
|
||||||
|
func (as *authStore) PutRefreshToken(rt *RefreshToken) (*RefreshToken, error) {
|
||||||
|
e := func(es string, a ...interface{}) (*RefreshToken, error) {
|
||||||
|
return nil, fmt.Errorf(es, a...)
|
||||||
|
}
|
||||||
|
|
||||||
|
u, hasUser := as.userMap[rt.UserID]
|
||||||
|
if !hasUser {
|
||||||
|
return e("no such user %v", rt.UserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
as.Refresh = append(as.Refresh, rt)
|
||||||
|
u.RefreshTokens = append(u.RefreshTokens, rt)
|
||||||
|
|
||||||
|
as.sync()
|
||||||
|
return rt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (as *authStore) GetRefreshTokenByToken(token RefreshTokenToken) *RefreshToken {
|
||||||
|
var found *RefreshToken
|
||||||
|
|
||||||
for _, u := range as.Users {
|
for _, u := range as.Users {
|
||||||
as.userMap[u.ID] = &u
|
for _, rt := range u.RefreshTokens {
|
||||||
|
if subtle.ConstantTimeCompare([]byte(token), []byte(rt.Token)) == 1 {
|
||||||
|
found = rt
|
||||||
|
found.User = u
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return found
|
||||||
|
}
|
||||||
|
|
||||||
|
func (as *authStore) GetRefreshToken(tid RefreshTokenID) *RefreshToken {
|
||||||
|
var found *RefreshToken
|
||||||
|
|
||||||
|
for _, u := range as.Users {
|
||||||
|
for _, rt := range u.RefreshTokens {
|
||||||
|
if subtle.ConstantTimeCompare([]byte(tid), []byte(rt.ID.String())) == 1 {
|
||||||
|
found = rt
|
||||||
|
found.User = u
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return found
|
||||||
|
}
|
||||||
|
|
||||||
|
func (as *authStore) newCredential(p provider.ProviderUser) *Credentials {
|
||||||
|
// XXX: probably broken
|
||||||
|
prov := p.Provider()
|
||||||
|
id := generate.UUID()
|
||||||
|
c := &Credentials{
|
||||||
|
ID: CredID(id),
|
||||||
|
AuthProviderType: prov.ProviderBase().Type,
|
||||||
|
AuthProviderID: prov.ProviderBase().ID,
|
||||||
|
}
|
||||||
|
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *authenticator) newAuthStore(s storage.Store) (as *authStore, err error) {
|
||||||
|
as = &authStore{
|
||||||
|
store: s,
|
||||||
|
}
|
||||||
|
|
||||||
|
as.Item, err = s.GetItem(AuthStoreKey, as)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
as.userMap = make(map[UserID]*User)
|
||||||
|
as.providerUsers = make(map[provider.ProviderUser]*Credentials)
|
||||||
|
|
||||||
|
for _, u := range as.Users {
|
||||||
|
as.userMap[u.ID] = u
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, c := range as.Credentials {
|
for _, c := range as.Credentials {
|
||||||
|
@ -39,13 +147,56 @@ func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error)
|
||||||
return nil, fmt.Errorf("no such provider %s", c.AuthProviderType)
|
return nil, fmt.Errorf("no such provider %s", c.AuthProviderType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.DataRaw != nil {
|
||||||
pd := prov.NewCredData()
|
pd := prov.NewCredData()
|
||||||
|
|
||||||
err := json.Unmarshal(c.DataRaw, pd)
|
err := json.Unmarshal(*c.DataRaw, pd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.User = prov.Lookup(pd.(provider.ProviderUser))
|
||||||
|
if c.User == nil {
|
||||||
|
return nil, fmt.Errorf("cannot find user in provider %s", prov.ProviderName())
|
||||||
}
|
}
|
||||||
|
as.providerUsers[c.User] = c
|
||||||
|
}
|
||||||
|
|
||||||
|
u, hasUser := as.userMap[c.UserID]
|
||||||
|
if !hasUser {
|
||||||
|
log.Error().Str("userid", string(c.UserID)).Msg("creds no such userid in map")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
u.Creds = append(u.Creds, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove invalid RefreshTokens
|
||||||
|
i := 0
|
||||||
|
for _, rt := range as.Refresh {
|
||||||
|
if rt.IsValid() {
|
||||||
|
u, hasUser := as.userMap[rt.UserID]
|
||||||
|
if !hasUser {
|
||||||
|
log.Error().Str("userid", string(rt.UserID)).Msg("refreshtokens no such userid in map")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
as.Refresh[i] = rt
|
||||||
|
i++
|
||||||
|
u.RefreshTokens = append(u.RefreshTokens, rt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// don't leak memory
|
||||||
|
for j := i; j < len(as.Refresh); j++ {
|
||||||
|
as.Refresh[j] = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
as.Refresh = as.Refresh[:i]
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *authStore) User(uid UserID) *User {
|
||||||
|
return s.userMap[uid]
|
||||||
|
}
|
||||||
|
|
|
@ -1,42 +1,63 @@
|
||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"net"
|
||||||
|
"net/http"
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type UserID string
|
type UserID string
|
||||||
type GroupID string
|
type GroupID string
|
||||||
type CredID string
|
type CredID string
|
||||||
|
|
||||||
|
type Group struct {
|
||||||
|
ID GroupID `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
ID UserID `json:"id"`
|
ID UserID `json:"id"`
|
||||||
GroupIDs []GroupID `json:"group_ids"`
|
GroupIDs []GroupID `json:"group_ids"`
|
||||||
Data interface{} `json:"data,omitempty"`
|
Data interface{} `json:"data,omitempty"`
|
||||||
UserMetadata
|
UserMetadata
|
||||||
|
|
||||||
|
Creds []*Credentials `json:"-"`
|
||||||
|
RefreshTokens []*RefreshToken `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserMetadata struct {
|
type UserMetadata struct {
|
||||||
Active bool `json:"is_active"`
|
|
||||||
Owner bool `json:"is_owner"`
|
Owner bool `json:"is_owner"`
|
||||||
LocalOnly bool `json:"local_only"`
|
Active bool `json:"is_active"`
|
||||||
SystemGenerated bool `json:"system_generated"`
|
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
SystemGenerated bool `json:"system_generated"`
|
||||||
|
LocalOnly bool `json:"local_only"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) allowedToAuth() error {
|
func (u *User) allowedToAuth(r *http.Request) error {
|
||||||
if !u.Active {
|
if !u.Active {
|
||||||
return errors.New("user disabled")
|
return ErrDisabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !u.LocalOnly {
|
||||||
return nil
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ip := net.ParseIP(r.RemoteAddr)
|
||||||
|
if ip == nil {
|
||||||
|
return ErrInvalidIP
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return ErrUserAuthRemote
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticator) getOrCreateUser(c *Credential) (*User, error) {
|
func (a *authenticator) getOrCreateUser(c *Credentials) (*User, error) {
|
||||||
log.Debug().Interface("userdata", c.user.ProviderUserData()).Msg("getOrCreateUser")
|
u := a.store.User(c.UserID)
|
||||||
panic("not implemented")
|
if u == nil {
|
||||||
return &User{}, nil
|
return nil, ErrInvalidAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
return u, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -7,19 +7,36 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"dynatron.me/x/blasphem/internal/common"
|
"dynatron.me/x/blasphem/internal/common"
|
||||||
|
"dynatron.me/x/blasphem/pkg/auth"
|
||||||
|
"dynatron.me/x/blasphem/pkg/blas/core"
|
||||||
"dynatron.me/x/blasphem/pkg/bus"
|
"dynatron.me/x/blasphem/pkg/bus"
|
||||||
|
"dynatron.me/x/blasphem/pkg/components"
|
||||||
"dynatron.me/x/blasphem/pkg/config"
|
"dynatron.me/x/blasphem/pkg/config"
|
||||||
"dynatron.me/x/blasphem/pkg/storage"
|
"dynatron.me/x/blasphem/pkg/storage"
|
||||||
|
"dynatron.me/x/blasphem/pkg/wsapi"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Blas struct {
|
type Blas struct {
|
||||||
*bus.Bus
|
bus.Bus
|
||||||
storage.Store
|
storage.Store
|
||||||
|
auth.Authenticator
|
||||||
Config *config.Config
|
Config *config.Config
|
||||||
|
core.WebSocketManager
|
||||||
|
|
||||||
|
components components.ComponentStore
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Blas) Shutdown(ctx context.Context) error {
|
func (b *Blas) Version() string {
|
||||||
b.Bus.Shutdown()
|
return common.Version
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Blas) Conf() *config.Config { return b.Config }
|
||||||
|
|
||||||
|
func (b *Blas) ShutdownBlas(ctx context.Context) error {
|
||||||
|
b.Bus.ShutdownBus()
|
||||||
|
b.Store.ShutdownStore()
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -44,17 +61,48 @@ func (b *Blas) ConfigDir() (cd string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Blas) openStore() error {
|
func (b *Blas) openStore() error {
|
||||||
stor, err := storage.Open(os.DirFS(b.ConfigDir()))
|
// TODO: based on config, open filestore or db store
|
||||||
|
stor, err := storage.OpenFileStore(b.ConfigDir())
|
||||||
b.Store = stor
|
b.Store = stor
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(cfg *config.Config) (*Blas, error) {
|
func (b *Blas) Component(k components.ComponentKey) components.Component {
|
||||||
b := &Blas{
|
c, ok := b.components[k]
|
||||||
Bus: bus.New(),
|
if !ok {
|
||||||
Config: cfg,
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Blas) Components() components.ComponentStore { return b.components }
|
||||||
|
|
||||||
|
func New(cfg *config.Config) (b *Blas, err error) {
|
||||||
|
b = &Blas{
|
||||||
|
Bus: bus.New(),
|
||||||
|
Config: cfg,
|
||||||
|
components: make(components.ComponentStore),
|
||||||
|
WebSocketManager: wsapi.NewManager(),
|
||||||
|
}
|
||||||
|
|
||||||
|
err = b.openStore()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
b.Authenticator, err = auth.New(b.Store)
|
||||||
|
|
||||||
|
for k, v := range Registry {
|
||||||
|
log.Info().Msgf("Setting up component %s", k)
|
||||||
|
c, err := v(b)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msgf("Error setting up component %s", k)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
b.components[k] = c
|
||||||
}
|
}
|
||||||
|
|
||||||
err := b.openStore()
|
|
||||||
return b, err
|
return b, err
|
||||||
}
|
}
|
||||||
|
|
21
pkg/blas/components.go
Normal file
21
pkg/blas/components.go
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
package blas
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"dynatron.me/x/blasphem/pkg/blas/core"
|
||||||
|
"dynatron.me/x/blasphem/pkg/components"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Setup func(core.Blas) (components.Component, error)
|
||||||
|
|
||||||
|
var Registry = make(map[components.ComponentKey]Setup)
|
||||||
|
|
||||||
|
func Register(key components.ComponentKey, c Setup) {
|
||||||
|
_, already := Registry[key]
|
||||||
|
if already {
|
||||||
|
panic(fmt.Sprintf("component %s already exists", key))
|
||||||
|
}
|
||||||
|
|
||||||
|
Registry[key] = c
|
||||||
|
}
|
51
pkg/blas/core/core.go
Normal file
51
pkg/blas/core/core.go
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
package core
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"dynatron.me/x/blasphem/pkg/auth"
|
||||||
|
"dynatron.me/x/blasphem/pkg/bus"
|
||||||
|
"dynatron.me/x/blasphem/pkg/components"
|
||||||
|
"dynatron.me/x/blasphem/pkg/config"
|
||||||
|
"dynatron.me/x/blasphem/pkg/storage"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Blas interface {
|
||||||
|
auth.Authenticator
|
||||||
|
bus.Bus
|
||||||
|
storage.Store
|
||||||
|
config.Configured
|
||||||
|
components.Componenter
|
||||||
|
|
||||||
|
WebSocketManager
|
||||||
|
|
||||||
|
Shutdowner
|
||||||
|
Versioner
|
||||||
|
}
|
||||||
|
|
||||||
|
type WebSocketManager interface {
|
||||||
|
// Register registers a websocket command.
|
||||||
|
// cmd is the first part, before first slash
|
||||||
|
// dataNew is a function to create a new message datatype
|
||||||
|
RegisterWSCommand(cmd string, hnd Handler, dataNew NewData)
|
||||||
|
WSCommandHandler(cmd string, splitCmd []string) (NewData, Handler, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type WebSocketSession interface {
|
||||||
|
Conn() *websocket.Conn
|
||||||
|
Go(context.Context) error
|
||||||
|
Blas() Blas
|
||||||
|
}
|
||||||
|
|
||||||
|
type Handler func(ctx context.Context, wss WebSocketSession, msgID int, cmd []string, msg interface{}) error
|
||||||
|
type NewData func(cmd []string) interface{}
|
||||||
|
|
||||||
|
type Shutdowner interface {
|
||||||
|
ShutdownBlas(context.Context) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Versioner interface {
|
||||||
|
Version() string
|
||||||
|
}
|
|
@ -12,22 +12,28 @@ type (
|
||||||
|
|
||||||
listeners []chan<- Event
|
listeners []chan<- Event
|
||||||
|
|
||||||
Bus struct {
|
bus struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
subs map[string]listeners
|
subs map[string]listeners
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Bus interface {
|
||||||
|
Sub(topic string, ch chan<- Event)
|
||||||
|
Unsub(topic string, ch chan<- Event)
|
||||||
|
Pub(topic string, data interface{})
|
||||||
|
ShutdownBus()
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func New() *Bus {
|
func New() Bus {
|
||||||
bus := &Bus{
|
bus := &bus{
|
||||||
subs: make(map[string]listeners),
|
subs: make(map[string]listeners),
|
||||||
}
|
}
|
||||||
|
|
||||||
return bus
|
return bus
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bus) Sub(topic string, ch chan<- Event) {
|
func (b *bus) Sub(topic string, ch chan<- Event) {
|
||||||
b.Lock()
|
b.Lock()
|
||||||
defer b.Unlock()
|
defer b.Unlock()
|
||||||
|
|
||||||
|
@ -38,7 +44,7 @@ func (b *Bus) Sub(topic string, ch chan<- Event) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bus) Unsub(topic string, ch chan<- Event) {
|
func (b *bus) Unsub(topic string, ch chan<- Event) {
|
||||||
b.Lock()
|
b.Lock()
|
||||||
defer b.Unlock()
|
defer b.Unlock()
|
||||||
|
|
||||||
|
@ -51,7 +57,7 @@ func (b *Bus) Unsub(topic string, ch chan<- Event) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bus) Pub(topic string, data interface{}) {
|
func (b *bus) Pub(topic string, data interface{}) {
|
||||||
b.RLock()
|
b.RLock()
|
||||||
defer b.RUnlock()
|
defer b.RUnlock()
|
||||||
|
|
||||||
|
@ -65,7 +71,7 @@ func (b *Bus) Pub(topic string, data interface{}) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bus) Shutdown() {
|
func (b *bus) ShutdownBus() {
|
||||||
for _, v := range b.subs {
|
for _, v := range b.subs {
|
||||||
for _, c := range v {
|
for _, c := range v {
|
||||||
close(c)
|
close(c)
|
||||||
|
|
|
@ -2,18 +2,18 @@ package serve
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"dynatron.me/x/blasphem/internal/common"
|
"dynatron.me/x/blasphem/internal/common"
|
||||||
"dynatron.me/x/blasphem/pkg/config"
|
blas "dynatron.me/x/blasphem/pkg/blas/core"
|
||||||
"dynatron.me/x/blasphem/pkg/server"
|
"dynatron.me/x/blasphem/pkg/server"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ServeOptions struct {
|
type ServeOptions struct {
|
||||||
cfg *config.Config
|
core blas.Blas
|
||||||
}
|
}
|
||||||
|
|
||||||
func Command(cfg *config.Config) *cobra.Command {
|
func Command(core blas.Blas) *cobra.Command {
|
||||||
opts := makeOptions(cfg)
|
opts := makeOptions(core)
|
||||||
serveCmd := &cobra.Command{
|
serveCmd := &cobra.Command{
|
||||||
Use: "serve",
|
Use: "serve",
|
||||||
Short: "starts the " + common.AppName + " server",
|
Short: "starts the " + common.AppName + " server",
|
||||||
|
@ -23,9 +23,9 @@ func Command(cfg *config.Config) *cobra.Command {
|
||||||
return serveCmd
|
return serveCmd
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeOptions(cfg *config.Config) *ServeOptions {
|
func makeOptions(core blas.Blas) *ServeOptions {
|
||||||
return &ServeOptions{
|
return &ServeOptions{
|
||||||
cfg: cfg,
|
core: core,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ func (o *ServeOptions) Options(_ *cobra.Command, args []string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *ServeOptions) Execute() error {
|
func (o *ServeOptions) Execute() error {
|
||||||
server, err := server.New(o.cfg)
|
server, err := server.New(o.core)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
17
pkg/components/components.go
Normal file
17
pkg/components/components.go
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
package components
|
||||||
|
|
||||||
|
import ()
|
||||||
|
|
||||||
|
type Componenter interface {
|
||||||
|
Component(ComponentKey) Component
|
||||||
|
Components() ComponentStore
|
||||||
|
}
|
||||||
|
|
||||||
|
type (
|
||||||
|
ComponentStore map[ComponentKey]Component
|
||||||
|
ComponentKey string
|
||||||
|
|
||||||
|
Component interface {
|
||||||
|
Shutdown()
|
||||||
|
}
|
||||||
|
)
|
|
@ -11,6 +11,10 @@ import (
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Configured interface {
|
||||||
|
Conf() *Config
|
||||||
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
DataDir *string `yaml:"data_dir,omitempty"`
|
DataDir *string `yaml:"data_dir,omitempty"`
|
||||||
Server *server.Config `yaml:"server"`
|
Server *server.Config `yaml:"server"`
|
||||||
|
|
203
pkg/flow/flow.go
Normal file
203
pkg/flow/flow.go
Normal file
|
@ -0,0 +1,203 @@
|
||||||
|
// flow is the data entry flow.
|
||||||
|
package flow
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"dynatron.me/x/blasphem/internal/generate"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ResultType string
|
||||||
|
type FlowID string
|
||||||
|
type Step string
|
||||||
|
type HandlerKey string
|
||||||
|
type Errors interface{}
|
||||||
|
|
||||||
|
type FlowStore map[FlowID]Handler
|
||||||
|
|
||||||
|
type FlowManager struct {
|
||||||
|
flows FlowStore
|
||||||
|
}
|
||||||
|
|
||||||
|
type Result struct {
|
||||||
|
Type ResultType `json:"type"`
|
||||||
|
ID FlowID `json:"flow_id"`
|
||||||
|
Handler []*HandlerKey `json:"handler"`
|
||||||
|
Title *string `json:"title,omitempty"`
|
||||||
|
Data map[string]interface{} `json:"data,omitempty"`
|
||||||
|
StepID *Step `json:"step_id,omitempty"`
|
||||||
|
Schema []SchemaItem `json:"data_schema"`
|
||||||
|
Extra *string `json:"extra,omitempty"`
|
||||||
|
Required *bool `json:"required,omitempty"`
|
||||||
|
Errors interface{} `json:"errors"`
|
||||||
|
Description *string `json:"description,omitempty"`
|
||||||
|
DescPlace *string `json:"description_placeholders"`
|
||||||
|
URL *string `json:"url,omitempty"`
|
||||||
|
Reason *string `json:"reason,omitempty"`
|
||||||
|
Context *string `json:"context,omitempty"`
|
||||||
|
Result interface{} `json:"result,omitempty"`
|
||||||
|
LastStep *string `json:"last_step"`
|
||||||
|
Options map[string]interface{} `json:"options,omitempty"`
|
||||||
|
Version *int `json:"version,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Handler interface {
|
||||||
|
BaseHandler() FlowHandler
|
||||||
|
FlowID() FlowID
|
||||||
|
|
||||||
|
flowCtime() time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
StepInit Step = "init"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (fs Schema) CheckRequired(rm map[string]interface{}) error {
|
||||||
|
for _, si := range fs {
|
||||||
|
if si.Required {
|
||||||
|
if _, ok := rm[si.Name]; !ok {
|
||||||
|
return fmt.Errorf("missing required param %s", si.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFlowManager() *FlowManager {
|
||||||
|
return &FlowManager{
|
||||||
|
flows: make(FlowStore),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func stepPtr(s Step) *Step { return &s }
|
||||||
|
|
||||||
|
type FlowHandler struct {
|
||||||
|
ID FlowID // ID is the FlowID
|
||||||
|
Handler HandlerKey // Handler key
|
||||||
|
|
||||||
|
// curStep is the current step set by the flow manager
|
||||||
|
curStep Step
|
||||||
|
|
||||||
|
ctime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FlowHandler) Step() Step { return f.curStep }
|
||||||
|
|
||||||
|
func (f *FlowHandler) BaseHandler() FlowHandler { return *f }
|
||||||
|
|
||||||
|
func (f *FlowHandler) FlowID() FlowID {
|
||||||
|
return f.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FlowHandler) flowCtime() time.Time { return f.ctime }
|
||||||
|
|
||||||
|
func NewFlowHandlerBase(hand string) FlowHandler {
|
||||||
|
return FlowHandler{
|
||||||
|
ID: FlowID(generate.UUID()),
|
||||||
|
Handler: HandlerKey(hand),
|
||||||
|
|
||||||
|
curStep: StepInit,
|
||||||
|
ctime: time.Now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hk *HandlerKey) String() string {
|
||||||
|
return string(*hk)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fm *FlowHandler) Handlers() []*HandlerKey {
|
||||||
|
return []*HandlerKey{&fm.Handler, nil}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resultErrs(e Errors) Errors {
|
||||||
|
if e == nil {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
type FormOption func(*Result)
|
||||||
|
|
||||||
|
func (*FlowHandler) WithErrors(e Errors) FormOption {
|
||||||
|
return func(r *Result) {
|
||||||
|
r.Errors = e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*FlowHandler) WithStep(s Step) FormOption {
|
||||||
|
return func(r *Result) {
|
||||||
|
r.StepID = stepPtr(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*FlowHandler) WithSchema(sch Schemer) FormOption {
|
||||||
|
return func(r *Result) {
|
||||||
|
r.Schema = sch.FlowSchema()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fm *FlowHandler) ShowForm(opts ...FormOption) *Result {
|
||||||
|
res := &Result{
|
||||||
|
Type: TypeForm,
|
||||||
|
ID: fm.ID,
|
||||||
|
StepID: stepPtr(fm.curStep),
|
||||||
|
Handler: fm.Handlers(),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
res.Errors = resultErrs(res.Errors)
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fm *FlowManager) Delete(id FlowID) {
|
||||||
|
delete(fm.flows, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
TypeForm ResultType = "form"
|
||||||
|
TypeCreateEntry ResultType = "create_entry"
|
||||||
|
TypeAbort ResultType = "abort"
|
||||||
|
TypeExternalStep ResultType = "external"
|
||||||
|
TypeExternalStepDone ResultType = "external_done"
|
||||||
|
TypeShowProgress ResultType = "progress"
|
||||||
|
TypeShowProgressDone ResultType = "progress_done"
|
||||||
|
TypeMenu ResultType = "menu"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (f *FlowHandler) touch() {
|
||||||
|
f.ctime = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fm *FlowManager) Register(f Handler) {
|
||||||
|
fm.flows.cull()
|
||||||
|
fm.flows[f.FlowID()] = f
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *FlowManager) Remove(f Handler) {
|
||||||
|
delete(fs.flows, f.FlowID())
|
||||||
|
}
|
||||||
|
|
||||||
|
const cullAge = time.Minute * 10
|
||||||
|
|
||||||
|
func (fs FlowStore) cull() {
|
||||||
|
for k, v := range fs {
|
||||||
|
if time.Now().Sub(v.flowCtime()) > cullAge {
|
||||||
|
delete(fs, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *FlowManager) Get(id FlowID) Handler {
|
||||||
|
f, ok := fs.flows[id]
|
||||||
|
if ok {
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
40
pkg/flow/schema.go
Normal file
40
pkg/flow/schema.go
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
package flow
|
||||||
|
|
||||||
|
type Type string
|
||||||
|
|
||||||
|
const (
|
||||||
|
TypeString Type = "string"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (t Type) IsValid() bool {
|
||||||
|
switch t {
|
||||||
|
case TypeString:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
type SchemaItem struct {
|
||||||
|
Type Type `json:"type"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Required bool `json:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Schema []SchemaItem
|
||||||
|
|
||||||
|
type Schemer interface {
|
||||||
|
FlowSchema() Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSchema(items ...SchemaItem) Schema {
|
||||||
|
return items
|
||||||
|
}
|
||||||
|
|
||||||
|
func RequiredString(name string) SchemaItem {
|
||||||
|
return SchemaItem{
|
||||||
|
Type: TypeString,
|
||||||
|
Name: name,
|
||||||
|
Required: true,
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,24 +1,42 @@
|
||||||
package frontend
|
package frontend
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"embed"
|
"embed"
|
||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"dynatron.me/x/blasphem/pkg/blas"
|
||||||
|
"dynatron.me/x/blasphem/pkg/blas/core"
|
||||||
|
"dynatron.me/x/blasphem/pkg/components"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const FrontendKey = "frontend"
|
||||||
|
|
||||||
//go:embed frontend/hass_frontend
|
//go:embed frontend/hass_frontend
|
||||||
var root embed.FS
|
var root embed.FS
|
||||||
|
|
||||||
var RootFS fs.FS
|
type Frontend struct {
|
||||||
|
fsHandler echo.HandlerFunc
|
||||||
|
rootFS fs.FS
|
||||||
|
|
||||||
var FSHandler echo.HandlerFunc
|
routeInstall sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
func AliasHandler(toFile string) echo.HandlerFunc {
|
func (fe *Frontend) InstallRoutes(e *echo.Echo) {
|
||||||
|
fe.routeInstall.Do(func() {
|
||||||
|
e.GET("/*", fe.fsHandler)
|
||||||
|
e.GET("/auth/authorize", fe.AliasHandler("authorize.html"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fe *Frontend) AliasHandler(toFile string) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
file, err := RootFS.Open(toFile)
|
file, err := fe.rootFS.Open(toFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -33,13 +51,32 @@ func AliasHandler(toFile string) echo.HandlerFunc {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func (*Frontend) Shutdown() {}
|
||||||
|
|
||||||
|
func newData(_ []string) interface{} {
|
||||||
|
return map[string]interface{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func wsHand(ctx context.Context, wss core.WebSocketSession, msgID int, cmd []string, msg interface{}) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Setup(b core.Blas) (components.Component, error) {
|
||||||
|
fe := &Frontend{}
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
RootFS, err = fs.Sub(root, "frontend/hass_frontend")
|
fe.rootFS, err = fs.Sub(root, "frontend/hass_frontend")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
FSHandler = echo.StaticDirectoryHandler(RootFS, false)
|
fe.fsHandler = echo.StaticDirectoryHandler(fe.rootFS, false)
|
||||||
|
|
||||||
|
b.RegisterWSCommand("frontend", wsHand, newData)
|
||||||
|
|
||||||
|
return fe, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
blas.Register(FrontendKey, Setup)
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,40 +12,39 @@ import (
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/ziflex/lecho/v3"
|
"github.com/ziflex/lecho/v3"
|
||||||
|
|
||||||
"dynatron.me/x/blasphem/pkg/auth"
|
|
||||||
"dynatron.me/x/blasphem/pkg/blas"
|
"dynatron.me/x/blasphem/pkg/blas"
|
||||||
"dynatron.me/x/blasphem/pkg/config"
|
"dynatron.me/x/blasphem/pkg/blas/core"
|
||||||
"dynatron.me/x/blasphem/pkg/frontend"
|
"dynatron.me/x/blasphem/pkg/frontend"
|
||||||
conf "dynatron.me/x/blasphem/pkg/server/config"
|
conf "dynatron.me/x/blasphem/pkg/server/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
*blas.Blas
|
core.Blas
|
||||||
*echo.Echo
|
*echo.Echo
|
||||||
auth.Authenticator
|
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) installRoutes() {
|
type RouteHaver interface {
|
||||||
s.GET("/*", frontend.FSHandler)
|
InstallRoutes(e *echo.Echo)
|
||||||
s.GET("/api/websocket", s.wsHandler)
|
|
||||||
|
|
||||||
s.Authenticator.InstallRoutes(s.Echo)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(cfg *config.Config) (s *Server, err error) {
|
func (s *Server) installRoutes() {
|
||||||
b, err := blas.New(cfg)
|
s.GET("/api/websocket", s.wsHandler)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
s = &Server{
|
s.Component(frontend.FrontendKey).(RouteHaver).InstallRoutes(s.Echo)
|
||||||
Blas: b,
|
s.Blas.(*blas.Blas).Authenticator.InstallRoutes(s.Echo)
|
||||||
Echo: echo.New(),
|
|
||||||
|
for _, c := range s.Components() {
|
||||||
|
if rh, ok := c.(RouteHaver); ok {
|
||||||
|
rh.InstallRoutes(s.Echo)
|
||||||
}
|
}
|
||||||
err = s.InitAuth(b.Store)
|
}
|
||||||
if err != nil {
|
}
|
||||||
return s, err
|
|
||||||
|
func New(core core.Blas) (s *Server, err error) {
|
||||||
|
s = &Server{
|
||||||
|
Blas: core,
|
||||||
|
Echo: echo.New(),
|
||||||
}
|
}
|
||||||
|
|
||||||
s.Echo.Debug = true
|
s.Echo.Debug = true
|
||||||
|
@ -53,6 +52,8 @@ func New(cfg *config.Config) (s *Server, err error) {
|
||||||
logger := lecho.From(log.Logger)
|
logger := lecho.From(log.Logger)
|
||||||
s.Echo.Logger = logger
|
s.Echo.Logger = logger
|
||||||
|
|
||||||
|
cfg := s.Conf()
|
||||||
|
|
||||||
if cfg.Server.LogRequestErrors {
|
if cfg.Server.LogRequestErrors {
|
||||||
s.Echo.Use(lecho.Middleware(lecho.Config{
|
s.Echo.Use(lecho.Middleware(lecho.Config{
|
||||||
Logger: logger,
|
Logger: logger,
|
||||||
|
@ -80,7 +81,7 @@ func New(cfg *config.Config) (s *Server, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Shutdown(ctx context.Context) error {
|
func (s *Server) Shutdown(ctx context.Context) error {
|
||||||
err := s.Blas.Shutdown(ctx)
|
err := s.ShutdownBlas(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -91,8 +92,8 @@ func (s *Server) Shutdown(ctx context.Context) error {
|
||||||
func (s *Server) Go() error {
|
func (s *Server) Go() error {
|
||||||
s.wg.Add(1)
|
s.wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
log.Info().Str("bind", s.Config.Server.Bind).Msg("Server listening")
|
log.Info().Str("version", s.Version()).Str("bind", s.Conf().Server.Bind).Msg("Server listening")
|
||||||
err := s.Start(s.Config.Server.Bind)
|
err := s.Start(s.Conf().Server.Bind)
|
||||||
if err != nil && err != http.ErrServerClosed {
|
if err != nil && err != http.ErrServerClosed {
|
||||||
s.Logger.Fatal(err)
|
s.Logger.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,11 +1,12 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"context"
|
||||||
"log"
|
"dynatron.me/x/blasphem/pkg/wsapi"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
var upgrader = websocket.Upgrader{
|
var upgrader = websocket.Upgrader{
|
||||||
|
@ -14,7 +15,16 @@ var upgrader = websocket.Upgrader{
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) wsHandler(c echo.Context) error {
|
func (s *Server) wsHandler(c echo.Context) error {
|
||||||
log.Println("WebSocket")
|
conn, err := upgrader.Upgrade(c.Response(), c.Request(), nil)
|
||||||
//conn, err := upgrader.Upgrade(w, req, nil)
|
if err != nil {
|
||||||
return errors.New("not handled")
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
_ = log.Debug
|
||||||
|
ctx, cancel := context.WithCancel(c.Request().Context())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
return wsapi.NewSession(s, c, conn).Go(ctx)
|
||||||
}
|
}
|
||||||
|
|
220
pkg/storage/filesystem.go
Normal file
220
pkg/storage/filesystem.go
Normal file
|
@ -0,0 +1,220 @@
|
||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
IndentStr = strings.Repeat(" ", 2)
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
SecretMode fs.FileMode = 0600
|
||||||
|
DefaultMode fs.FileMode = 0644
|
||||||
|
)
|
||||||
|
|
||||||
|
type item struct {
|
||||||
|
sync.Mutex `json:"-"`
|
||||||
|
Version int `json:"version"`
|
||||||
|
MinorVersion *int `json:"minor_version,omitempty"`
|
||||||
|
Key string `json:"key"`
|
||||||
|
Data interface{} `json:"data"`
|
||||||
|
|
||||||
|
fmode fs.FileMode
|
||||||
|
dirty bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *item) Dirty() { i.Lock(); defer i.Unlock(); i.dirty = true }
|
||||||
|
func (i *item) IsDirty() bool { i.Lock(); defer i.Unlock(); return i.dirty }
|
||||||
|
func (i *item) GetData() interface{} { i.Lock(); defer i.Unlock(); return i.Data }
|
||||||
|
func (i *item) SetData(d interface{}) { i.Lock(); defer i.Unlock(); i.Data = d; i.dirty = true }
|
||||||
|
func (i *item) ItemKey() string { return i.Key /* key is immutable */ }
|
||||||
|
|
||||||
|
func (it *item) mode() fs.FileMode {
|
||||||
|
if it.fmode != 0 {
|
||||||
|
return it.fmode
|
||||||
|
}
|
||||||
|
|
||||||
|
return SecretMode
|
||||||
|
}
|
||||||
|
|
||||||
|
type fsStore struct {
|
||||||
|
sync.RWMutex
|
||||||
|
fs fs.FS
|
||||||
|
storeRoot string
|
||||||
|
s map[string]*item
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fsStore) get(key string) *item {
|
||||||
|
s.RLock()
|
||||||
|
defer s.RUnlock()
|
||||||
|
|
||||||
|
i, ok := s.s[key]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fsStore) put(key string, it *item) {
|
||||||
|
s.Lock()
|
||||||
|
defer s.Unlock()
|
||||||
|
|
||||||
|
s.s[key] = it
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fsStore) persist(it *item) error {
|
||||||
|
it.Lock()
|
||||||
|
defer it.Unlock()
|
||||||
|
|
||||||
|
f, err := os.OpenFile(path.Join(s.storeRoot, it.Key), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, it.mode())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
enc := json.NewEncoder(f)
|
||||||
|
enc.SetIndent("", IndentStr)
|
||||||
|
|
||||||
|
err = enc.Encode(it)
|
||||||
|
if err == nil {
|
||||||
|
it.dirty = false
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fsStore) Dirty(key string) error {
|
||||||
|
it := s.get(key)
|
||||||
|
if it == nil {
|
||||||
|
return ErrNoSuchKey
|
||||||
|
}
|
||||||
|
|
||||||
|
it.Dirty()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fsStore) Flush(key string) error {
|
||||||
|
it := s.get(key)
|
||||||
|
if it == nil {
|
||||||
|
return ErrNoSuchKey
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.persist(it)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fsStore) FlushAll() []error {
|
||||||
|
s.RLock()
|
||||||
|
defer s.RUnlock()
|
||||||
|
|
||||||
|
var errs []error
|
||||||
|
for _, it := range s.s {
|
||||||
|
err := s.persist(it)
|
||||||
|
if err != nil {
|
||||||
|
errs = append(errs, fmt.Errorf("store key %s: %w", it.Key, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return errs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fsStore) ShutdownStore() {
|
||||||
|
errs := s.FlushAll()
|
||||||
|
if errs != nil {
|
||||||
|
log.Error().Errs("errors", errs).Msg("errors persisting store")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put puts an item into the store.
|
||||||
|
// NB: Any user of a previous item with this key will now have a dangling reference that will not be persisted.
|
||||||
|
// It is up to consumers to coordinate against this case!
|
||||||
|
func (s *fsStore) Put(key string, version, minorVersion int, secretMode bool, data interface{}) (Item, error) {
|
||||||
|
var mv *int
|
||||||
|
if minorVersion != 0 {
|
||||||
|
mv = &minorVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
mode := DefaultMode
|
||||||
|
|
||||||
|
if secretMode {
|
||||||
|
mode = SecretMode
|
||||||
|
}
|
||||||
|
|
||||||
|
it := &item{
|
||||||
|
Version: version,
|
||||||
|
MinorVersion: mv,
|
||||||
|
Key: key,
|
||||||
|
Data: data,
|
||||||
|
|
||||||
|
fmode: mode,
|
||||||
|
dirty: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
s.s[key] = it
|
||||||
|
return it, s.persist(it)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fsStore) Get(key string, data interface{}) error {
|
||||||
|
_, err := s.GetItem(key, data)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fsStore) GetItem(key string, data interface{}) (Item, error) {
|
||||||
|
exists := s.get(key)
|
||||||
|
if exists != nil {
|
||||||
|
return exists, ErrKeyExists
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := s.fs.Open(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
fi, err := f.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
item := &item{
|
||||||
|
Data: data,
|
||||||
|
fmode: fi.Mode(),
|
||||||
|
}
|
||||||
|
d := json.NewDecoder(f)
|
||||||
|
err = d.Decode(item)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if item.Key != key {
|
||||||
|
return nil, fmt.Errorf("key mismatch '%s' != '%s'", item.Key, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.put(key, item)
|
||||||
|
|
||||||
|
return item, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func OpenFileStore(configRoot string) (*fsStore, error) {
|
||||||
|
storeRoot := path.Join(configRoot, ".storage")
|
||||||
|
stor := os.DirFS(storeRoot)
|
||||||
|
|
||||||
|
return &fsStore{
|
||||||
|
fs: stor,
|
||||||
|
storeRoot: storeRoot,
|
||||||
|
s: make(map[string]*item),
|
||||||
|
}, nil
|
||||||
|
}
|
|
@ -1,58 +0,0 @@
|
||||||
package storage
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io/fs"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Data interface {
|
|
||||||
}
|
|
||||||
|
|
||||||
type Item struct {
|
|
||||||
Version int `json:"version"`
|
|
||||||
MinorVersion *int `json:"minor_version,omitempty"`
|
|
||||||
Key string `json:"key"`
|
|
||||||
Data Data `json:"data"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type store struct {
|
|
||||||
fs.FS
|
|
||||||
}
|
|
||||||
|
|
||||||
type Store interface {
|
|
||||||
Get(key string, data interface{}) error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *store) Get(key string, data interface{}) error {
|
|
||||||
f, err := s.Open(key)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
item := Item{
|
|
||||||
Data: data,
|
|
||||||
}
|
|
||||||
d := json.NewDecoder(f)
|
|
||||||
err = d.Decode(&item)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if item.Key != key {
|
|
||||||
return fmt.Errorf("key mismatch '%s' != '%s'", item.Key, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func Open(dir fs.FS) (*store, error) {
|
|
||||||
stor, err := fs.Sub(dir, ".storage")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &store{stor}, nil
|
|
||||||
}
|
|
55
pkg/storage/store.go
Normal file
55
pkg/storage/store.go
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrNoSuchKey = errors.New("no such key in store")
|
||||||
|
ErrKeyExists = errors.New("key already exists")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Item is an item in a datastore.
|
||||||
|
type Item interface {
|
||||||
|
// Item is lockable if updating data item directly.
|
||||||
|
sync.Locker
|
||||||
|
|
||||||
|
// Dirty sets the dirty flag for the item so it will be flushed.
|
||||||
|
Dirty()
|
||||||
|
|
||||||
|
// IsDirty gets the dirty flag for the item.
|
||||||
|
IsDirty() bool
|
||||||
|
|
||||||
|
// GetData gets the data for the item.
|
||||||
|
GetData() interface{}
|
||||||
|
|
||||||
|
// GetData sets the data for the item.
|
||||||
|
SetData(interface{})
|
||||||
|
|
||||||
|
// ItemKey gets the key of the item.
|
||||||
|
ItemKey() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store represents a datastore.
|
||||||
|
type Store interface {
|
||||||
|
// GetItem loads the specified key from the store into data and returns the Item.
|
||||||
|
// If err is ErrKeyExists, Item will be the existing item.
|
||||||
|
GetItem(key string, data interface{}) (Item, error)
|
||||||
|
|
||||||
|
// Get is the same as GetItem, but only returns error.
|
||||||
|
Get(key string, data interface{}) error
|
||||||
|
|
||||||
|
// Put puts the specified key into the store. If the key already exists, it clobbers.
|
||||||
|
// Note that any existing items will then dangle.
|
||||||
|
Put(key string, version, minorVersion int, secretMode bool, data interface{}) (Item, error)
|
||||||
|
|
||||||
|
// FlushAll flushes the store to backing.
|
||||||
|
FlushAll() []error
|
||||||
|
|
||||||
|
// Flush flushes a single key to backing.
|
||||||
|
Flush(key string) error
|
||||||
|
|
||||||
|
// ShutdownStore is called to quiesce and shutdown the store.
|
||||||
|
ShutdownStore()
|
||||||
|
}
|
264
pkg/wsapi/api.go
Normal file
264
pkg/wsapi/api.go
Normal file
|
@ -0,0 +1,264 @@
|
||||||
|
package wsapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"dynatron.me/x/blasphem/pkg/auth"
|
||||||
|
"dynatron.me/x/blasphem/pkg/blas/core"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/mitchellh/mapstructure"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
NoSuchHandlerErr = errors.New("bad websocket command")
|
||||||
|
NoMessageIDErr = errors.New("no message ID")
|
||||||
|
AuthInvalidErr = errors.New("invalid auth")
|
||||||
|
)
|
||||||
|
|
||||||
|
type Type string
|
||||||
|
type MsgBase struct {
|
||||||
|
ID *int `json:"id,omitempty"`
|
||||||
|
Type Type `json:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type (
|
||||||
|
wsSession struct {
|
||||||
|
conn *websocket.Conn
|
||||||
|
b core.Blas
|
||||||
|
ec echo.Context
|
||||||
|
write chan<- interface{}
|
||||||
|
|
||||||
|
user *auth.User
|
||||||
|
refreshToken *auth.RefreshToken
|
||||||
|
}
|
||||||
|
|
||||||
|
wsEntry struct {
|
||||||
|
dataNew core.NewData
|
||||||
|
hnd core.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
wsRegistry map[string]wsEntry
|
||||||
|
|
||||||
|
wsManager struct {
|
||||||
|
r wsRegistry
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (wsm *wsManager) RegisterWSCommand(cmd string, hnd core.Handler, dataNew core.NewData) {
|
||||||
|
wsm.r[cmd] = wsEntry{
|
||||||
|
dataNew: dataNew,
|
||||||
|
hnd: hnd,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wsm *wsManager) WSCommandHandler(cmd string, cmdSplit []string) (core.NewData, core.Handler, error) {
|
||||||
|
if wse, ok := wsm.r[cmd]; ok {
|
||||||
|
return wse.dataNew, wse.hnd, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if wse, ok := wsm.r[cmdSplit[0]]; ok {
|
||||||
|
return wse.dataNew, wse.hnd, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil, NoSuchHandlerErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager() core.WebSocketManager {
|
||||||
|
return &wsManager{
|
||||||
|
r: make(wsRegistry),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSession(s core.Blas, c echo.Context, conn *websocket.Conn) core.WebSocketSession {
|
||||||
|
ws := &wsSession{
|
||||||
|
conn: conn,
|
||||||
|
b: s,
|
||||||
|
ec: c,
|
||||||
|
}
|
||||||
|
|
||||||
|
return ws
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *wsSession) Conn() *websocket.Conn {
|
||||||
|
return ws.conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *wsSession) Blas() core.Blas { return ws.b }
|
||||||
|
|
||||||
|
func (ws *wsSession) Go(ctx context.Context) error {
|
||||||
|
authP := &authPhase{ws}
|
||||||
|
err := ws.sendAuthRequired()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, rdr, err := ws.conn.NextReader()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = authP.handleMsg(ctx, rdr)
|
||||||
|
if err != nil || ws.refreshToken == nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// command phase
|
||||||
|
|
||||||
|
msgChan := make(chan map[string]interface{})
|
||||||
|
write := make(chan interface{})
|
||||||
|
ws.write = write
|
||||||
|
defer close(write)
|
||||||
|
|
||||||
|
cCtx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
go func(ctx context.Context, ch chan<- map[string]interface{}) {
|
||||||
|
defer close(ch)
|
||||||
|
for {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, rdr, err := ws.conn.NextReader()
|
||||||
|
switch {
|
||||||
|
case err == nil:
|
||||||
|
case websocket.IsCloseError(err, websocket.CloseGoingAway):
|
||||||
|
return
|
||||||
|
case err != nil:
|
||||||
|
log.Error().Err(err).Str("remote", ws.ec.Request().RemoteAddr).Msg("websocket read fail")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var msgMap map[string]interface{}
|
||||||
|
err = json.NewDecoder(rdr).Decode(&msgMap)
|
||||||
|
if err != nil {
|
||||||
|
ws.writeError(-1, Error{Code: "invalid_format", Message: err.Error()})
|
||||||
|
}
|
||||||
|
|
||||||
|
ch <- msgMap
|
||||||
|
|
||||||
|
}
|
||||||
|
}(cCtx, msgChan)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case msg, ok := <-msgChan:
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = ws.handleMsg(ctx, msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("handleMsg")
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
close(msgChan) // maybe remove this?
|
||||||
|
return nil
|
||||||
|
case m := <-write:
|
||||||
|
err := ws.conn.WriteJSON(m)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("writeMsg")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *wsSession) Write(msg interface{}) {
|
||||||
|
ws.write <- msg
|
||||||
|
}
|
||||||
|
|
||||||
|
type MsgType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ResultMsgType MsgType = "result"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Error struct {
|
||||||
|
Code string `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type WSError struct {
|
||||||
|
ID *int `json:"id,omitempty"`
|
||||||
|
Type MsgType `json:"type"`
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Error Error `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *wsSession) writeError(id int, err Error) error {
|
||||||
|
return ws.conn.WriteJSON(WSError{
|
||||||
|
ID: &id,
|
||||||
|
Type: ResultMsgType,
|
||||||
|
Success: false,
|
||||||
|
Error: err,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *wsSession) handleMsg(ctx context.Context, msgMap map[string]interface{}) error {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
msgType, ok := msgMap["type"].(string)
|
||||||
|
if !ok {
|
||||||
|
return NoSuchHandlerErr
|
||||||
|
}
|
||||||
|
|
||||||
|
idFl, ok := msgMap["id"].(float64)
|
||||||
|
if !ok {
|
||||||
|
ws.Write(WSError{
|
||||||
|
Type: ResultMsgType,
|
||||||
|
Success: false,
|
||||||
|
Error: Error{
|
||||||
|
Code: "invalid_id",
|
||||||
|
Message: "command has no ID",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
id := int(idFl)
|
||||||
|
|
||||||
|
cmd := strings.Split(msgType, "/")
|
||||||
|
|
||||||
|
newData, hand, err := ws.b.WSCommandHandler(cmd[0], cmd)
|
||||||
|
switch err {
|
||||||
|
case nil:
|
||||||
|
case NoSuchHandlerErr:
|
||||||
|
ws.writeError(id, Error{
|
||||||
|
Code: "invalid_type",
|
||||||
|
Message: "no such command",
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
log.Error().Err(err).Msg("dispatch")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
nd := newData(cmd)
|
||||||
|
if _, ok := nd.(map[string]interface{}); !ok {
|
||||||
|
err := mapstructure.Decode(&msgMap, &nd)
|
||||||
|
if err != nil {
|
||||||
|
ws.writeError(id, Error{
|
||||||
|
Code: "invalid_format",
|
||||||
|
Message: err.Error(),
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = hand(ctx, ws, id, cmd, nd)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("dispatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
81
pkg/wsapi/auth.go
Normal file
81
pkg/wsapi/auth.go
Normal file
|
@ -0,0 +1,81 @@
|
||||||
|
package wsapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"dynatron.me/x/blasphem/pkg/auth"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type authPhase struct {
|
||||||
|
*wsSession
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *wsSession) sendAuthRequired() error {
|
||||||
|
authReq := &struct {
|
||||||
|
MsgBase
|
||||||
|
Version string `json:"ha_version"`
|
||||||
|
}{
|
||||||
|
MsgBase{Type: "auth_required"},
|
||||||
|
ws.b.Version(),
|
||||||
|
}
|
||||||
|
return ws.conn.WriteJSON(&authReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
type authMsg struct {
|
||||||
|
MsgBase
|
||||||
|
AccessToken auth.AccessToken `json:"access_token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ap *authPhase) msgSchema() interface{} {
|
||||||
|
return &authMsg{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ap *authPhase) finishAuth(rt *auth.RefreshToken) {
|
||||||
|
ap.user = rt.User
|
||||||
|
ap.refreshToken = rt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ap *authPhase) sendAuthOK() error {
|
||||||
|
return ap.conn.WriteJSON(struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Version string `json:"ha_version"`
|
||||||
|
}{Type: "auth_ok", Version: ap.Blas().Version()})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ap *authPhase) sendAuthInvalid() error {
|
||||||
|
return ap.conn.WriteJSON(struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}{Type: "auth_ok", Message: "invalid auth"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ap *authPhase) handleMsg(ctx context.Context, r io.Reader) error {
|
||||||
|
var authMsg authMsg
|
||||||
|
err := json.NewDecoder(r).Decode(&authMsg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshToken := ap.b.ValidateAccessToken(authMsg.AccessToken)
|
||||||
|
if refreshToken != nil {
|
||||||
|
ap.finishAuth(refreshToken)
|
||||||
|
return ap.sendAuthOK()
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Error().Str("remote", ap.ec.Request().RemoteAddr).Msg("websocket auth failed")
|
||||||
|
|
||||||
|
err = ap.sendAuthInvalid()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return AuthInvalidErr
|
||||||
|
}
|
Loading…
Reference in a new issue