Compare commits

..

2 commits

Author SHA1 Message Date
f272514c65 Improve flows in preparation for MFA 2022-12-27 15:09:20 -05:00
60301c9892 check rt 2022-12-21 19:34:26 -05:00
6 changed files with 124 additions and 97 deletions

View file

@ -1,6 +1,7 @@
package auth package auth
import ( import (
"fmt"
"net/http" "net/http"
"strings" "strings"
@ -19,6 +20,7 @@ type AuthFlowManager struct {
type LoginFlow struct { type LoginFlow struct {
flow.FlowHandler flow.FlowHandler
prov provider.AuthProvider
ClientID common.ClientID ClientID common.ClientID
FlowContext FlowContext
} }
@ -52,9 +54,10 @@ func NewAuthFlowManager() *AuthFlowManager {
func (afm *AuthFlowManager) NewLoginFlow(req *LoginFlowRequest, prov provider.AuthProvider) *LoginFlow { func (afm *AuthFlowManager) NewLoginFlow(req *LoginFlowRequest, prov provider.AuthProvider) *LoginFlow {
lf := &LoginFlow{ lf := &LoginFlow{
FlowHandler: flow.NewFlowHandlerBase(prov, prov.ProviderType()), FlowHandler: flow.NewFlowHandlerBase(prov.ProviderType()),
ClientID: req.ClientID, ClientID: req.ClientID,
FlowContext: req.FlowContext(), FlowContext: req.FlowContext(),
prov: prov,
} }
afm.Register(lf) afm.Register(lf)
@ -80,9 +83,9 @@ func (a *authenticator) NewFlow(r *LoginFlowRequest) *flow.Result {
return nil return nil
} }
flow := a.flows.NewLoginFlow(r, prov) lf := a.flows.NewLoginFlow(r, prov)
return flow.ShowForm(nil) return lf.ShowForm(lf.WithSchema(prov), lf.WithStep(flow.StepInit))
} }
func (f *LoginFlow) redirect(c echo.Context) { func (f *LoginFlow) redirect(c echo.Context) {
@ -99,15 +102,19 @@ func (f *LoginFlow) progress(a *authenticator, c echo.Context) error {
return c.String(http.StatusBadRequest, err.Error()) return c.String(http.StatusBadRequest, err.Error())
} }
err = f.Schema.CheckRequired(rm) err = f.prov.FlowSchema().CheckRequired(rm)
if err != nil { if err != nil {
return c.JSON(http.StatusBadRequest, f.ShowForm([]string{err.Error()})) return c.JSON(http.StatusBadRequest, f.ShowForm(f.WithErrors([]string{err.Error()})))
} }
user, clientID, err := a.Check(f, c.Request(), rm) user, clientID, err := a.Check(f, c.Request(), rm)
switch err { switch err {
case nil: case nil:
creds := a.store.GetCredential(user) creds := a.store.GetCredential(user)
if creds == nil {
return fmt.Errorf("flow progress: no such credential for %v", user.UserData())
}
finishedFlow := flow.Result{} finishedFlow := flow.Result{}
a.flows.Remove(f) a.flows.Remove(f)
copier.Copy(&finishedFlow, f) copier.Copy(&finishedFlow, f)
@ -124,14 +131,14 @@ func (f *LoginFlow) progress(a *authenticator, c echo.Context) error {
case ErrInvalidAuth: case ErrInvalidAuth:
fallthrough fallthrough
default: default:
return c.JSON(http.StatusOK, f.ShowForm(map[string]interface{}{ return c.JSON(http.StatusOK, f.ShowForm(f.WithErrors(map[string]interface{}{
"base": "invalid_auth", "base": "invalid_auth",
})) })))
} }
default: default:
return c.JSON(http.StatusOK, f.ShowForm(map[string]interface{}{ return c.JSON(http.StatusOK, f.ShowForm(f.WithErrors(map[string]interface{}{
"base": "unknown_flow_step", "base": "unknown_flow_step",
})) })))
} }
} }

View file

@ -162,18 +162,7 @@ func (hap *HomeAssistantProvider) NewCredData() interface{} {
} }
func (hap *HomeAssistantProvider) FlowSchema() flow.Schema { func (hap *HomeAssistantProvider) FlowSchema() flow.Schema {
return []flow.SchemaItem{ return flow.NewSchema(flow.RequiredString("username"), flow.RequiredString("password"))
{
Type: "string",
Name: "username",
Required: true,
},
{
Type: "string",
Name: "password",
Required: true,
},
}
} }
func init() { func init() {

View file

@ -79,18 +79,7 @@ func (hap *TrustedNetworksProvider) NewCredData() interface{} {
} }
func (hap *TrustedNetworksProvider) FlowSchema() flow.Schema { func (hap *TrustedNetworksProvider) FlowSchema() flow.Schema {
return []flow.SchemaItem{ return nil
{
Type: "string",
Name: "username",
Required: true,
},
{
Type: "string",
Name: "password",
Required: true,
},
}
} }
func init() { func init() {

View file

@ -8,22 +8,19 @@ import (
"dynatron.me/x/blasphem/internal/generate" "dynatron.me/x/blasphem/internal/generate"
) )
type ( type ResultType string
ResultType string type FlowID string
FlowID string type Step string
Step string type HandlerKey string
HandlerKey string type Errors interface{}
Errors interface{}
Context interface{} type FlowStore map[FlowID]Handler
FlowStore map[FlowID]Handler type FlowManager struct {
FlowManager struct {
flows FlowStore flows FlowStore
} }
Result struct { type Result struct {
Type ResultType `json:"type"` Type ResultType `json:"type"`
ID FlowID `json:"flow_id"` ID FlowID `json:"flow_id"`
Handler []*HandlerKey `json:"handler"` Handler []*HandlerKey `json:"handler"`
@ -43,36 +40,21 @@ type (
LastStep *string `json:"last_step"` LastStep *string `json:"last_step"`
Options map[string]interface{} `json:"options,omitempty"` Options map[string]interface{} `json:"options,omitempty"`
Version *int `json:"version,omitempty"` Version *int `json:"version,omitempty"`
} }
SchemaItem struct { type Handler interface {
Type string `json:"type"`
Name string `json:"name"`
Required bool `json:"required"`
}
Schema []SchemaItem
)
type (
Schemer interface {
FlowSchema() Schema
}
Handler interface {
BaseHandler() FlowHandler BaseHandler() FlowHandler
FlowID() FlowID FlowID() FlowID
flowCtime() time.Time flowCtime() time.Time
} }
)
const ( const (
StepInit Step = "init" StepInit Step = "init"
) )
func (fs *Schema) CheckRequired(rm map[string]interface{}) error { func (fs Schema) CheckRequired(rm map[string]interface{}) error {
for _, si := range *fs { for _, si := range fs {
if si.Required { if si.Required {
if _, ok := rm[si.Name]; !ok { if _, ok := rm[si.Name]; !ok {
return fmt.Errorf("missing required param %s", si.Name) return fmt.Errorf("missing required param %s", si.Name)
@ -94,8 +76,6 @@ func stepPtr(s Step) *Step { return &s }
type FlowHandler struct { type FlowHandler struct {
ID FlowID // ID is the FlowID ID FlowID // ID is the FlowID
Handler HandlerKey // Handler key Handler HandlerKey // Handler key
Context Context // flow Context
Schema Schema
// curStep is the current step set by the flow manager // curStep is the current step set by the flow manager
curStep Step curStep Step
@ -113,11 +93,10 @@ func (f *FlowHandler) FlowID() FlowID {
func (f *FlowHandler) flowCtime() time.Time { return f.ctime } func (f *FlowHandler) flowCtime() time.Time { return f.ctime }
func NewFlowHandlerBase(sch Schemer, hand string) FlowHandler { func NewFlowHandlerBase(hand string) FlowHandler {
return FlowHandler{ return FlowHandler{
ID: FlowID(generate.UUID()), ID: FlowID(generate.UUID()),
Handler: HandlerKey(hand), Handler: HandlerKey(hand),
Schema: sch.FlowSchema(),
curStep: StepInit, curStep: StepInit,
ctime: time.Now(), ctime: time.Now(),
@ -140,16 +119,39 @@ func resultErrs(e Errors) Errors {
return e return e
} }
func (fm *FlowHandler) ShowForm(errs Errors) *Result { type FormOption func(*Result)
func (*FlowHandler) WithErrors(e Errors) FormOption {
return func(r *Result) {
r.Errors = e
}
}
func (*FlowHandler) WithStep(s Step) FormOption {
return func(r *Result) {
r.StepID = stepPtr(s)
}
}
func (*FlowHandler) WithSchema(sch Schemer) FormOption {
return func(r *Result) {
r.Schema = sch.FlowSchema()
}
}
func (fm *FlowHandler) ShowForm(opts ...FormOption) *Result {
res := &Result{ res := &Result{
Type: TypeForm, Type: TypeForm,
ID: fm.ID, ID: fm.ID,
StepID: stepPtr(fm.curStep), StepID: stepPtr(fm.curStep),
Schema: fm.Schema,
Handler: fm.Handlers(), Handler: fm.Handlers(),
Errors: resultErrs(errs),
} }
for _, opt := range opts {
opt(res)
}
res.Errors = resultErrs(res.Errors)
return res return res
} }

40
pkg/flow/schema.go Normal file
View file

@ -0,0 +1,40 @@
package flow
type Type string
const (
TypeString Type = "string"
)
func (t Type) IsValid() bool {
switch t {
case TypeString:
return true
}
return false
}
type SchemaItem struct {
Type Type `json:"type"`
Name string `json:"name"`
Required bool `json:"required"`
}
type Schema []SchemaItem
type Schemer interface {
FlowSchema() Schema
}
func NewSchema(items ...SchemaItem) Schema {
return items
}
func RequiredString(name string) SchemaItem {
return SchemaItem{
Type: TypeString,
Name: name,
Required: true,
}
}

View file

@ -104,7 +104,7 @@ func (ws *wsSession) Go(ctx context.Context) error {
} }
err = authP.handleMsg(ctx, rdr) err = authP.handleMsg(ctx, rdr)
if err != nil { if err != nil || ws.refreshToken == nil {
return err return err
} }
@ -158,7 +158,7 @@ func (ws *wsSession) Go(ctx context.Context) error {
log.Error().Err(err).Msg("handleMsg") log.Error().Err(err).Msg("handleMsg")
} }
case <-ctx.Done(): case <-ctx.Done():
close(msgChan) close(msgChan) // maybe remove this?
return nil return nil
case m := <-write: case m := <-write:
err := ws.conn.WriteJSON(m) err := ws.conn.WriteJSON(m)