This commit is contained in:
Daniel Ponte 2022-12-18 21:26:34 -05:00
parent 9224b21db7
commit 6aa2c46717
8 changed files with 107 additions and 66 deletions

View file

@ -112,3 +112,5 @@ func (a *Authenticator) Check(f *LoginFlow, req *http.Request, rm map[string]int
return nil, clientID, ErrInvalidAuth return nil, clientID, ErrInvalidAuth
} }
//func (a *Authenticator) GetOrCreateCreds(

View file

@ -104,7 +104,6 @@ func (f *LoginFlow) progress(a *Authenticator, c echo.Context) error {
return c.JSON(http.StatusBadRequest, f.ShowForm([]string{err.Error()})) return c.JSON(http.StatusBadRequest, f.ShowForm([]string{err.Error()}))
} }
user, clientID, err := a.Check(f, c.Request(), rm) user, clientID, err := a.Check(f, c.Request(), rm)
switch err { switch err {
case nil: case nil:

View file

@ -27,11 +27,22 @@ type HAUser struct {
func (hau *HAUser) UserData() provider.ProviderUser { func (hau *HAUser) UserData() provider.ProviderUser {
return &UserData{ // strip secret return &UserData{ // strip secret
Username: hau.Username, Username: hau.Username,
AuthProvider: hau.AuthProvider,
} }
} }
func (hau *HAUser) Provider() provider.AuthProvider {
return hau.AuthProvider
}
func (hau *UserData) Provider() provider.AuthProvider {
return hau.AuthProvider
}
type UserData struct { type UserData struct {
Username string `json:"username"` Username string `json:"username"`
provider.AuthProvider `json:"-"`
} }
func (ud *UserData) UserData() provider.ProviderUser { func (ud *UserData) UserData() provider.ProviderUser {
@ -84,6 +95,28 @@ func (hap *HomeAssistantProvider) hashPass(p string) ([]byte, error) {
return bcrypt.GenerateFromPassword([]byte(p), bcrypt.DefaultCost) return bcrypt.GenerateFromPassword([]byte(p), bcrypt.DefaultCost)
} }
func (hap *HomeAssistantProvider) EqualCreds(c1, c2 provider.ProviderUser) bool {
switch c1c := c1.(type) {
case *HAUser:
switch c2c := c2.(type) {
case *HAUser:
return c2c.Username == c1c.Username
case *UserData:
return c2c.Username == c1c.Username
}
case *UserData:
switch c2c := c2.(type) {
case *HAUser:
return c2c.Username == c1c.Username
case *UserData:
return c2c.Username == c1c.Username
}
}
return false
}
func (hap *HomeAssistantProvider) ValidateCreds(r *http.Request, rm map[string]interface{}) (provider.ProviderUser, bool) { func (hap *HomeAssistantProvider) ValidateCreds(r *http.Request, rm map[string]interface{}) (provider.ProviderUser, bool) {
usernameE, hasU := rm["username"] usernameE, hasU := rm["username"]
passwordE, hasP := rm["password"] passwordE, hasP := rm["password"]
@ -117,6 +150,7 @@ func (hap *HomeAssistantProvider) ValidateCreds(r *http.Request, rm map[string]i
err = bcrypt.CompareHashAndPassword(hash, []byte(password)) err = bcrypt.CompareHashAndPassword(hash, []byte(password))
if err == nil { if err == nil {
found.AuthProvider = hap
return found, true return found, true
} }

View file

@ -17,6 +17,7 @@ type AuthProvider interface { // TODO: this should include stepping
FlowSchema() flow.Schema FlowSchema() flow.Schema
NewCredData() interface{} NewCredData() interface{}
ValidateCreds(r *http.Request, reqMap map[string]interface{}) (user ProviderUser, success bool) ValidateCreds(r *http.Request, reqMap map[string]interface{}) (user ProviderUser, success bool)
EqualCreds(c1, c2 ProviderUser) bool
Lookup(ProviderUser) ProviderUser Lookup(ProviderUser) ProviderUser
} }
@ -24,39 +25,10 @@ func Register(providerName string, f func(storage.Store) (AuthProvider, error))
Providers[providerName] = f Providers[providerName] = f
} }
type Credentials struct {
ID CredID `json:"id"`
UserID UserID `json:"user_id"`
AuthProviderType string `json:"auth_provider_type"`
AuthProviderID *string `json:"auth_provider_id"`
DataRaw *json.RawMessage `json:"data,omitempty"`
user ProviderUser `json:"-"`
}
func (cred *Credentials) MarshalJSON() ([]byte, error) {
type CredAlias Credentials // alias so ø method set and we don't recurse
nCd := (*CredAlias)(cred)
if cred.user != nil {
providerData := cred.user.UserData()
if providerData != nil {
b, err := json.Marshal(providerData)
if err != nil {
return nil, err
}
dr := json.RawMessage(b)
nCd.DataRaw = &dr
}
}
return json.Marshal(nCd)
}
type ProviderUser interface { type ProviderUser interface {
// TODO: make sure this is sane with all the ProviderUser and UserData type stuff // TODO: make sure this is sane with all the ProviderUser and UserData type stuff
UserData() ProviderUser UserData() ProviderUser
Credentials() *Credentials Provider() AuthProvider
} }
type AuthProviderBase struct { type AuthProviderBase struct {

View file

@ -19,11 +19,22 @@ type User struct {
func (hau *User) UserData() provider.ProviderUser { func (hau *User) UserData() provider.ProviderUser {
return &UserData{ return &UserData{
UserID: hau.UserID, UserID: hau.UserID,
AuthProvider: hau.AuthProvider,
} }
} }
func (hau *UserData) Provider() provider.AuthProvider {
return hau.AuthProvider
}
func (hau *User) Provider() provider.AuthProvider {
return hau.AuthProvider
}
type UserData struct { type UserData struct {
UserID string `json:"user_id"` UserID string `json:"user_id"`
provider.AuthProvider `json:"-"`
} }
func (ud *UserData) UserData() provider.ProviderUser { func (ud *UserData) UserData() provider.ProviderUser {
@ -36,6 +47,11 @@ type TrustedNetworksProvider struct {
provider.AuthProviderBase `json:"-"` provider.AuthProviderBase `json:"-"`
} }
func (hap *TrustedNetworksProvider) EqualCreds(c1, c2 provider.ProviderUser) bool {
panic("not implemented")
return false
}
func New(s storage.Store) (provider.AuthProvider, error) { func New(s storage.Store) (provider.AuthProvider, error) {
hap := &TrustedNetworksProvider{ hap := &TrustedNetworksProvider{
AuthProviderBase: provider.AuthProviderBase{ AuthProviderBase: provider.AuthProviderBase{

View file

@ -10,8 +10,6 @@ import (
"dynatron.me/x/blasphem/internal/common" "dynatron.me/x/blasphem/internal/common"
"dynatron.me/x/blasphem/internal/generate" "dynatron.me/x/blasphem/internal/generate"
"dynatron.me/x/blasphem/pkg/auth/provider" "dynatron.me/x/blasphem/pkg/auth/provider"
"github.com/rs/zerolog/log"
) )
type ( type (
@ -90,7 +88,6 @@ func (ss *authCodeStore) cull() {
} }
func (ss *authCodeStore) store(clientID ClientID, cred *Credentials) string { func (ss *authCodeStore) store(clientID ClientID, cred *Credentials) string {
log.Info().Msgf("store cred is %+v", cred)
ss.cull() ss.cull()
code := generate.UUID() code := generate.UUID()
ss.s[authCodeTuple{clientID, AuthCode(code)}] = flowResult{Time: time.Now(), Cred: cred} ss.s[authCodeTuple{clientID, AuthCode(code)}] = flowResult{Time: time.Now(), Cred: cred}
@ -117,15 +114,16 @@ type Credentials struct {
AuthProviderType string `json:"auth_provider_type"` AuthProviderType string `json:"auth_provider_type"`
AuthProviderID *string `json:"auth_provider_id"` AuthProviderID *string `json:"auth_provider_id"`
DataRaw *json.RawMessage `json:"data,omitempty"` DataRaw *json.RawMessage `json:"data,omitempty"`
user provider.ProviderUser `json:"-"`
User provider.ProviderUser `json:"-"`
} }
func (cred *Credentials) MarshalJSON() ([]byte, error) { func (cred *Credentials) MarshalJSON() ([]byte, error) {
type CredAlias Credentials // alias so ø method set and we don't recurse type CredAlias Credentials // alias so ø method set and we don't recurse
nCd := (*CredAlias)(cred) nCd := (*CredAlias)(cred)
if cred.user != nil { if cred.User != nil {
providerData := cred.user.UserData() providerData := cred.User.UserData()
if providerData != nil { if providerData != nil {
b, err := json.Marshal(providerData) b, err := json.Marshal(providerData)
if err != nil { if err != nil {
@ -170,6 +168,7 @@ func (c *ClientID) IsValid() bool {
} }
type AuthCode string type AuthCode string
func (ac *AuthCode) IsValid() bool { func (ac *AuthCode) IsValid() bool {
return *ac != "" return *ac != ""
} }

View file

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"dynatron.me/x/blasphem/internal/generate"
"dynatron.me/x/blasphem/pkg/auth/provider" "dynatron.me/x/blasphem/pkg/auth/provider"
"dynatron.me/x/blasphem/pkg/storage" "dynatron.me/x/blasphem/pkg/storage"
@ -20,19 +21,40 @@ type AuthStore interface {
} }
type authStore struct { type authStore struct {
Users []User `json:"users"` Users []*User `json:"users"`
Groups []Group `json:"groups"` Groups []*Group `json:"groups"`
Credentials []Credentials `json:"credentials"` Credentials []*Credentials `json:"credentials"`
Refresh []RefreshToken `json:"refresh_tokens"` Refresh []*RefreshToken `json:"refresh_tokens"`
userMap map[UserID]*User userMap map[UserID]*User
providerUsers map[provider.ProviderUser]*Credentials providerUsers map[provider.ProviderUser]*Credentials
} }
func strPtrEq(n1, n2 *string) bool {
return (n1 == n2 || (n1 != nil && n2 != nil && *n1 == *n2))
}
func (as *authStore) Credential(p provider.ProviderUser) *Credentials { func (as *authStore) Credential(p provider.ProviderUser) *Credentials {
c, have := as.providerUsers[p] for _, cr := range as.Credentials {
if !have { if p.Provider() != nil &&
strPtrEq(cr.AuthProviderID, p.Provider().ProviderID()) &&
cr.AuthProviderType == p.Provider().ProviderType() &&
p.Provider().EqualCreds(cr.User.UserData(), p.UserData()) {
return cr
}
}
return nil return nil
}
func (as *authStore) newCredential(p provider.ProviderUser) *Credentials {
// XXX: probably broken
prov := p.Provider()
id := generate.UUID()
c := &Credentials{
ID: CredID(id),
AuthProviderType: prov.ProviderBase().Type,
AuthProviderID: prov.ProviderBase().ID,
} }
return c return c
@ -46,7 +68,7 @@ func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error)
as.providerUsers = make(map[provider.ProviderUser]*Credentials) as.providerUsers = make(map[provider.ProviderUser]*Credentials)
for _, u := range as.Users { for _, u := range as.Users {
as.userMap[u.ID] = &u as.userMap[u.ID] = u
} }
for _, c := range as.Credentials { for _, c := range as.Credentials {
@ -63,11 +85,11 @@ func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error)
return nil, err return nil, err
} }
c.user = prov.Lookup(pd.(provider.ProviderUser)) c.User = prov.Lookup(pd.(provider.ProviderUser))
if c.user == nil { if c.User == nil {
return nil, fmt.Errorf("cannot find user in provider %s", prov.ProviderName()) return nil, fmt.Errorf("cannot find user in provider %s", prov.ProviderName())
} }
as.providerUsers[c.user] = &c as.providerUsers[c.User] = c
} }
u, hasUser := as.userMap[c.UserID] u, hasUser := as.userMap[c.UserID]
@ -90,7 +112,7 @@ func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error)
// don't leak memory // don't leak memory
for j := i; j < len(as.Refresh); j++ { for j := i; j < len(as.Refresh); j++ {
as.Refresh[j] = RefreshToken{} as.Refresh[j] = nil
} }
as.Refresh = as.Refresh[:i] as.Refresh = as.Refresh[:i]

View file

@ -1,8 +1,6 @@
package auth package auth
import ( import ()
"github.com/rs/zerolog/log"
)
type UserID string type UserID string
type GroupID string type GroupID string
@ -19,7 +17,7 @@ type User struct {
Data interface{} `json:"data,omitempty"` Data interface{} `json:"data,omitempty"`
UserMetadata UserMetadata
Creds []Credentials `json:"-"` Creds []*Credentials `json:"-"`
} }
type UserMetadata struct { type UserMetadata struct {
@ -39,7 +37,6 @@ func (u *User) allowedToAuth() error {
} }
func (a *Authenticator) getOrCreateUser(c *Credentials) (*User, error) { func (a *Authenticator) getOrCreateUser(c *Credentials) (*User, error) {
log.Debug().Interface("userdata", c).Msg("getOrCreateUser")
u := a.store.User(c.UserID) u := a.store.User(c.UserID)
if u == nil { if u == nil {
return nil, ErrInvalidAuth return nil, ErrInvalidAuth