This commit is contained in:
Daniel 2022-11-13 19:06:53 -05:00
parent c618197c54
commit ae00c1534d
5 changed files with 40 additions and 7 deletions

View file

@ -24,7 +24,7 @@ type HAUser struct {
} }
func (hau *HAUser) UserData() provider.ProviderUser { func (hau *HAUser) UserData() provider.ProviderUser {
return &UserData{ return &UserData{ // strip secret
Username: hau.Username, Username: hau.Username,
} }
} }
@ -44,6 +44,7 @@ func (h *HAUser) ProviderUserData() interface{} { return h.UserData() }
type HomeAssistantProvider struct { type HomeAssistantProvider struct {
provider.AuthProviderBase `json:"-"` provider.AuthProviderBase `json:"-"`
Users []HAUser `json:"users"` Users []HAUser `json:"users"`
userMap map[string]*HAUser
} }
func NewHAProvider(s storage.Store) (provider.AuthProvider, error) { func NewHAProvider(s storage.Store) (provider.AuthProvider, error) {
@ -59,13 +60,25 @@ func NewHAProvider(s storage.Store) (provider.AuthProvider, error) {
return hap, err return hap, err
} }
for i := range hap.Users { hap.userMap = make(map[string]*HAUser)
for i, u := range hap.Users {
hap.Users[i].AuthProvider = hap hap.Users[i].AuthProvider = hap
hap.userMap[u.Username] = &hap.Users[i]
} }
return hap, nil return hap, nil
} }
func (hap *HomeAssistantProvider) Lookup(pu provider.ProviderUser) provider.ProviderUser {
u, has := hap.userMap[pu.(*HAUser).Username]
if !has {
return nil
}
return u
}
func (hap *HomeAssistantProvider) hashPass(p string) ([]byte, error) { func (hap *HomeAssistantProvider) hashPass(p string) ([]byte, error) {
return bcrypt.GenerateFromPassword([]byte(p), bcrypt.DefaultCost) return bcrypt.GenerateFromPassword([]byte(p), bcrypt.DefaultCost)
} }
@ -110,7 +123,7 @@ func (hap *HomeAssistantProvider) ValidateCreds(r *http.Request, rm map[string]i
} }
func (hap *HomeAssistantProvider) NewCredData() interface{} { func (hap *HomeAssistantProvider) NewCredData() interface{} {
return &UserData{} return &HAUser{}
} }
func (hap *HomeAssistantProvider) FlowSchema() []provider.FlowSchemaItem { func (hap *HomeAssistantProvider) FlowSchema() []provider.FlowSchemaItem {

View file

@ -16,6 +16,7 @@ type AuthProvider interface { // TODO: this should include stepping
FlowSchema() []FlowSchemaItem FlowSchema() []FlowSchemaItem
NewCredData() interface{} NewCredData() interface{}
ValidateCreds(r *http.Request, reqMap map[string]interface{}) (user ProviderUser, success bool) ValidateCreds(r *http.Request, reqMap map[string]interface{}) (user ProviderUser, success bool)
Lookup(ProviderUser) ProviderUser
} }
func Register(providerName string, f func(storage.Store) (AuthProvider, error)) { func Register(providerName string, f func(storage.Store) (AuthProvider, error)) {
@ -23,6 +24,7 @@ func Register(providerName string, f func(storage.Store) (AuthProvider, error))
} }
type ProviderUser interface { type ProviderUser interface {
// TODO: make sure this is sane with all the ProviderUser and UserData type stuf
UserData() ProviderUser UserData() ProviderUser
} }

View file

@ -46,6 +46,10 @@ func New(s storage.Store) (provider.AuthProvider, error) {
return hap, nil return hap, nil
} }
func (tnp *TrustedNetworksProvider) Lookup(pu provider.ProviderUser) provider.ProviderUser {
return pu
}
func (hap *TrustedNetworksProvider) ValidateCreds(r *http.Request, rm map[string]interface{}) (provider.ProviderUser, bool) { func (hap *TrustedNetworksProvider) ValidateCreds(r *http.Request, rm map[string]interface{}) (provider.ProviderUser, bool) {
/* /*
if req.RemoteAddr in allowed then do the thing if req.RemoteAddr in allowed then do the thing

View file

@ -124,9 +124,7 @@ func (a *Authenticator) verifyAndGetCredential(tr *TokenRequest, r *http.Request
return nil return nil
} }
cred := &Credential{ cred := a.store.Credential(user)
user: user,
}
return cred return cred
} }

View file

@ -14,6 +14,7 @@ const (
type AuthStore interface { type AuthStore interface {
User(UserID) *User User(UserID) *User
Credential(provider.ProviderUser) *Credential
} }
type authStore struct { type authStore struct {
@ -23,6 +24,16 @@ type authStore struct {
Refresh []RefreshToken `json:"refresh_tokens"` Refresh []RefreshToken `json:"refresh_tokens"`
userMap map[UserID]*User userMap map[UserID]*User
providerUsers map[provider.ProviderUser]*Credential
}
func (as *authStore) Credential(p provider.ProviderUser) *Credential {
c, have := as.providerUsers[p]
if !have {
return nil
}
return c
} }
func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error) { func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error) {
@ -30,6 +41,7 @@ func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error)
err = s.Get(AuthStoreKey, as) err = s.Get(AuthStoreKey, as)
as.userMap = make(map[UserID]*User) as.userMap = make(map[UserID]*User)
as.providerUsers = make(map[provider.ProviderUser]*Credential)
for _, u := range as.Users { for _, u := range as.Users {
as.userMap[u.ID] = &u as.userMap[u.ID] = &u
@ -49,7 +61,11 @@ func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error)
return nil, err return nil, err
} }
c.user = pd.(provider.ProviderUser) c.user = prov.Lookup(pd.(provider.ProviderUser))
if c.user == nil {
return nil, fmt.Errorf("cannot find user in provider %s", prov.ProviderName())
}
as.providerUsers[c.user] = &c
} }
} }