From 6aa2c46717c2afabd707bf7705e58b441163676d Mon Sep 17 00:00:00 2001 From: Daniel Ponte Date: Sun, 18 Dec 2022 21:26:34 -0500 Subject: [PATCH] prejwt --- pkg/auth/authenticator.go | 4 +- pkg/auth/flow.go | 1 - pkg/auth/provider/hass/provider.go | 36 ++++++++++++++- pkg/auth/provider/provider.go | 32 +------------- pkg/auth/provider/trustednets/trustednets.go | 18 +++++++- pkg/auth/session.go | 29 ++++++------ pkg/auth/store.go | 46 +++++++++++++++----- pkg/auth/user.go | 7 +-- 8 files changed, 107 insertions(+), 66 deletions(-) diff --git a/pkg/auth/authenticator.go b/pkg/auth/authenticator.go index b88d98a..991a9a0 100644 --- a/pkg/auth/authenticator.go +++ b/pkg/auth/authenticator.go @@ -25,7 +25,7 @@ var ( type Authenticator struct { store AuthStore flows *AuthFlowManager - authCodes authCodeStore + authCodes authCodeStore providers map[string]provider.AuthProvider } @@ -112,3 +112,5 @@ func (a *Authenticator) Check(f *LoginFlow, req *http.Request, rm map[string]int return nil, clientID, ErrInvalidAuth } + +//func (a *Authenticator) GetOrCreateCreds( diff --git a/pkg/auth/flow.go b/pkg/auth/flow.go index f2f74db..8bd8681 100644 --- a/pkg/auth/flow.go +++ b/pkg/auth/flow.go @@ -103,7 +103,6 @@ func (f *LoginFlow) progress(a *Authenticator, c echo.Context) error { if err != nil { return c.JSON(http.StatusBadRequest, f.ShowForm([]string{err.Error()})) } - user, clientID, err := a.Check(f, c.Request(), rm) switch err { diff --git a/pkg/auth/provider/hass/provider.go b/pkg/auth/provider/hass/provider.go index e6c2624..ad718ea 100644 --- a/pkg/auth/provider/hass/provider.go +++ b/pkg/auth/provider/hass/provider.go @@ -26,12 +26,23 @@ type HAUser struct { func (hau *HAUser) UserData() provider.ProviderUser { return &UserData{ // strip secret - Username: hau.Username, + Username: hau.Username, + AuthProvider: hau.AuthProvider, } } +func (hau *HAUser) Provider() provider.AuthProvider { + return hau.AuthProvider +} + +func (hau *UserData) Provider() provider.AuthProvider { + return hau.AuthProvider +} + type UserData struct { Username string `json:"username"` + + provider.AuthProvider `json:"-"` } func (ud *UserData) UserData() provider.ProviderUser { @@ -84,6 +95,28 @@ func (hap *HomeAssistantProvider) hashPass(p string) ([]byte, error) { return bcrypt.GenerateFromPassword([]byte(p), bcrypt.DefaultCost) } +func (hap *HomeAssistantProvider) EqualCreds(c1, c2 provider.ProviderUser) bool { + switch c1c := c1.(type) { + case *HAUser: + switch c2c := c2.(type) { + case *HAUser: + return c2c.Username == c1c.Username + case *UserData: + return c2c.Username == c1c.Username + } + case *UserData: + switch c2c := c2.(type) { + case *HAUser: + return c2c.Username == c1c.Username + case *UserData: + return c2c.Username == c1c.Username + } + + } + + return false +} + func (hap *HomeAssistantProvider) ValidateCreds(r *http.Request, rm map[string]interface{}) (provider.ProviderUser, bool) { usernameE, hasU := rm["username"] passwordE, hasP := rm["password"] @@ -117,6 +150,7 @@ func (hap *HomeAssistantProvider) ValidateCreds(r *http.Request, rm map[string]i err = bcrypt.CompareHashAndPassword(hash, []byte(password)) if err == nil { + found.AuthProvider = hap return found, true } diff --git a/pkg/auth/provider/provider.go b/pkg/auth/provider/provider.go index cc40216..9a1287e 100644 --- a/pkg/auth/provider/provider.go +++ b/pkg/auth/provider/provider.go @@ -17,6 +17,7 @@ type AuthProvider interface { // TODO: this should include stepping FlowSchema() flow.Schema NewCredData() interface{} ValidateCreds(r *http.Request, reqMap map[string]interface{}) (user ProviderUser, success bool) + EqualCreds(c1, c2 ProviderUser) bool Lookup(ProviderUser) ProviderUser } @@ -24,39 +25,10 @@ func Register(providerName string, f func(storage.Store) (AuthProvider, error)) Providers[providerName] = f } -type Credentials struct { - ID CredID `json:"id"` - UserID UserID `json:"user_id"` - AuthProviderType string `json:"auth_provider_type"` - AuthProviderID *string `json:"auth_provider_id"` - DataRaw *json.RawMessage `json:"data,omitempty"` - user ProviderUser `json:"-"` -} - -func (cred *Credentials) MarshalJSON() ([]byte, error) { - type CredAlias Credentials // alias so ø method set and we don't recurse - nCd := (*CredAlias)(cred) - - if cred.user != nil { - providerData := cred.user.UserData() - if providerData != nil { - b, err := json.Marshal(providerData) - if err != nil { - return nil, err - } - - dr := json.RawMessage(b) - nCd.DataRaw = &dr - } - } - - return json.Marshal(nCd) -} - type ProviderUser interface { // TODO: make sure this is sane with all the ProviderUser and UserData type stuff UserData() ProviderUser - Credentials() *Credentials + Provider() AuthProvider } type AuthProviderBase struct { diff --git a/pkg/auth/provider/trustednets/trustednets.go b/pkg/auth/provider/trustednets/trustednets.go index 93dfe05..e14c7fe 100644 --- a/pkg/auth/provider/trustednets/trustednets.go +++ b/pkg/auth/provider/trustednets/trustednets.go @@ -18,12 +18,23 @@ type User struct { func (hau *User) UserData() provider.ProviderUser { return &UserData{ - UserID: hau.UserID, + UserID: hau.UserID, + AuthProvider: hau.AuthProvider, } } +func (hau *UserData) Provider() provider.AuthProvider { + return hau.AuthProvider +} + +func (hau *User) Provider() provider.AuthProvider { + return hau.AuthProvider +} + type UserData struct { UserID string `json:"user_id"` + + provider.AuthProvider `json:"-"` } func (ud *UserData) UserData() provider.ProviderUser { @@ -36,6 +47,11 @@ type TrustedNetworksProvider struct { provider.AuthProviderBase `json:"-"` } +func (hap *TrustedNetworksProvider) EqualCreds(c1, c2 provider.ProviderUser) bool { + panic("not implemented") + return false +} + func New(s storage.Store) (provider.AuthProvider, error) { hap := &TrustedNetworksProvider{ AuthProviderBase: provider.AuthProviderBase{ diff --git a/pkg/auth/session.go b/pkg/auth/session.go index f7b9bb6..17a4fd3 100644 --- a/pkg/auth/session.go +++ b/pkg/auth/session.go @@ -10,8 +10,6 @@ import ( "dynatron.me/x/blasphem/internal/common" "dynatron.me/x/blasphem/internal/generate" "dynatron.me/x/blasphem/pkg/auth/provider" - - "github.com/rs/zerolog/log" ) type ( @@ -47,7 +45,7 @@ type authCodeStore struct { type authCodeTuple struct { ClientID ClientID - Code AuthCode + Code AuthCode } func (t *authCodeTuple) IsValid() bool { @@ -90,7 +88,6 @@ func (ss *authCodeStore) cull() { } func (ss *authCodeStore) store(clientID ClientID, cred *Credentials) string { - log.Info().Msgf("store cred is %+v", cred) ss.cull() code := generate.UUID() ss.s[authCodeTuple{clientID, AuthCode(code)}] = flowResult{Time: time.Now(), Cred: cred} @@ -112,20 +109,21 @@ func (ss *authCodeStore) verify(tr *TokenRequest, r *http.Request) (*Credentials } type Credentials struct { - ID CredID `json:"id"` - UserID UserID `json:"user_id"` - AuthProviderType string `json:"auth_provider_type"` - AuthProviderID *string `json:"auth_provider_id"` - DataRaw *json.RawMessage `json:"data,omitempty"` - user provider.ProviderUser `json:"-"` + ID CredID `json:"id"` + UserID UserID `json:"user_id"` + AuthProviderType string `json:"auth_provider_type"` + AuthProviderID *string `json:"auth_provider_id"` + DataRaw *json.RawMessage `json:"data,omitempty"` + + User provider.ProviderUser `json:"-"` } func (cred *Credentials) MarshalJSON() ([]byte, error) { type CredAlias Credentials // alias so ø method set and we don't recurse nCd := (*CredAlias)(cred) - if cred.user != nil { - providerData := cred.user.UserData() + if cred.User != nil { + providerData := cred.User.UserData() if providerData != nil { b, err := json.Marshal(providerData) if err != nil { @@ -170,14 +168,15 @@ func (c *ClientID) IsValid() bool { } type AuthCode string + func (ac *AuthCode) IsValid() bool { return *ac != "" } type TokenRequest struct { - ClientID ClientID `form:"client_id"` - Code AuthCode `form:"code"` - GrantType GrantType `form:"grant_type"` + ClientID ClientID `form:"client_id"` + Code AuthCode `form:"code"` + GrantType GrantType `form:"grant_type"` } func (a *Authenticator) TokenHandler(c echo.Context) error { diff --git a/pkg/auth/store.go b/pkg/auth/store.go index 1ed5ad8..e8d854f 100644 --- a/pkg/auth/store.go +++ b/pkg/auth/store.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" + "dynatron.me/x/blasphem/internal/generate" "dynatron.me/x/blasphem/pkg/auth/provider" "dynatron.me/x/blasphem/pkg/storage" @@ -20,19 +21,40 @@ type AuthStore interface { } type authStore struct { - Users []User `json:"users"` - Groups []Group `json:"groups"` - Credentials []Credentials `json:"credentials"` - Refresh []RefreshToken `json:"refresh_tokens"` + Users []*User `json:"users"` + Groups []*Group `json:"groups"` + Credentials []*Credentials `json:"credentials"` + Refresh []*RefreshToken `json:"refresh_tokens"` userMap map[UserID]*User providerUsers map[provider.ProviderUser]*Credentials } +func strPtrEq(n1, n2 *string) bool { + return (n1 == n2 || (n1 != nil && n2 != nil && *n1 == *n2)) +} + func (as *authStore) Credential(p provider.ProviderUser) *Credentials { - c, have := as.providerUsers[p] - if !have { - return nil + for _, cr := range as.Credentials { + if p.Provider() != nil && + strPtrEq(cr.AuthProviderID, p.Provider().ProviderID()) && + cr.AuthProviderType == p.Provider().ProviderType() && + p.Provider().EqualCreds(cr.User.UserData(), p.UserData()) { + return cr + } + } + + return nil +} + +func (as *authStore) newCredential(p provider.ProviderUser) *Credentials { + // XXX: probably broken + prov := p.Provider() + id := generate.UUID() + c := &Credentials{ + ID: CredID(id), + AuthProviderType: prov.ProviderBase().Type, + AuthProviderID: prov.ProviderBase().ID, } return c @@ -46,7 +68,7 @@ func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error) as.providerUsers = make(map[provider.ProviderUser]*Credentials) for _, u := range as.Users { - as.userMap[u.ID] = &u + as.userMap[u.ID] = u } for _, c := range as.Credentials { @@ -63,11 +85,11 @@ func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error) return nil, err } - c.user = prov.Lookup(pd.(provider.ProviderUser)) - if c.user == nil { + c.User = prov.Lookup(pd.(provider.ProviderUser)) + if c.User == nil { return nil, fmt.Errorf("cannot find user in provider %s", prov.ProviderName()) } - as.providerUsers[c.user] = &c + as.providerUsers[c.User] = c } u, hasUser := as.userMap[c.UserID] @@ -90,7 +112,7 @@ func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error) // don't leak memory for j := i; j < len(as.Refresh); j++ { - as.Refresh[j] = RefreshToken{} + as.Refresh[j] = nil } as.Refresh = as.Refresh[:i] diff --git a/pkg/auth/user.go b/pkg/auth/user.go index d7d41b9..e65b789 100644 --- a/pkg/auth/user.go +++ b/pkg/auth/user.go @@ -1,8 +1,6 @@ package auth -import ( - "github.com/rs/zerolog/log" -) +import () type UserID string type GroupID string @@ -19,7 +17,7 @@ type User struct { Data interface{} `json:"data,omitempty"` UserMetadata - Creds []Credentials `json:"-"` + Creds []*Credentials `json:"-"` } type UserMetadata struct { @@ -39,7 +37,6 @@ func (u *User) allowedToAuth() error { } func (a *Authenticator) getOrCreateUser(c *Credentials) (*User, error) { - log.Debug().Interface("userdata", c).Msg("getOrCreateUser") u := a.store.User(c.UserID) if u == nil { return nil, ErrInvalidAuth