diff --git a/pkg/auth/provider/hass/provider.go b/pkg/auth/provider/hass/provider.go index 04b3d3e..a5cc06e 100644 --- a/pkg/auth/provider/hass/provider.go +++ b/pkg/auth/provider/hass/provider.go @@ -24,7 +24,7 @@ type HAUser struct { } func (hau *HAUser) UserData() provider.ProviderUser { - return &UserData{ + return &UserData{ // strip secret Username: hau.Username, } } @@ -44,6 +44,7 @@ func (h *HAUser) ProviderUserData() interface{} { return h.UserData() } type HomeAssistantProvider struct { provider.AuthProviderBase `json:"-"` Users []HAUser `json:"users"` + userMap map[string]*HAUser } func NewHAProvider(s storage.Store) (provider.AuthProvider, error) { @@ -59,13 +60,25 @@ func NewHAProvider(s storage.Store) (provider.AuthProvider, error) { return hap, err } - for i := range hap.Users { + hap.userMap = make(map[string]*HAUser) + + for i, u := range hap.Users { hap.Users[i].AuthProvider = hap + hap.userMap[u.Username] = &hap.Users[i] } return hap, nil } +func (hap *HomeAssistantProvider) Lookup(pu provider.ProviderUser) provider.ProviderUser { + u, has := hap.userMap[pu.(*HAUser).Username] + if !has { + return nil + } + + return u +} + func (hap *HomeAssistantProvider) hashPass(p string) ([]byte, error) { return bcrypt.GenerateFromPassword([]byte(p), bcrypt.DefaultCost) } @@ -110,7 +123,7 @@ func (hap *HomeAssistantProvider) ValidateCreds(r *http.Request, rm map[string]i } func (hap *HomeAssistantProvider) NewCredData() interface{} { - return &UserData{} + return &HAUser{} } func (hap *HomeAssistantProvider) FlowSchema() []provider.FlowSchemaItem { diff --git a/pkg/auth/provider/provider.go b/pkg/auth/provider/provider.go index b8f3bde..a34bed0 100644 --- a/pkg/auth/provider/provider.go +++ b/pkg/auth/provider/provider.go @@ -16,6 +16,7 @@ type AuthProvider interface { // TODO: this should include stepping FlowSchema() []FlowSchemaItem NewCredData() interface{} ValidateCreds(r *http.Request, reqMap map[string]interface{}) (user ProviderUser, success bool) + Lookup(ProviderUser) ProviderUser } func Register(providerName string, f func(storage.Store) (AuthProvider, error)) { @@ -23,6 +24,7 @@ func Register(providerName string, f func(storage.Store) (AuthProvider, error)) } type ProviderUser interface { + // TODO: make sure this is sane with all the ProviderUser and UserData type stuf UserData() ProviderUser } diff --git a/pkg/auth/provider/trustednets/trustednets.go b/pkg/auth/provider/trustednets/trustednets.go index d3ae9f3..5c3dcb8 100644 --- a/pkg/auth/provider/trustednets/trustednets.go +++ b/pkg/auth/provider/trustednets/trustednets.go @@ -46,6 +46,10 @@ func New(s storage.Store) (provider.AuthProvider, error) { return hap, nil } +func (tnp *TrustedNetworksProvider) Lookup(pu provider.ProviderUser) provider.ProviderUser { + return pu +} + func (hap *TrustedNetworksProvider) ValidateCreds(r *http.Request, rm map[string]interface{}) (provider.ProviderUser, bool) { /* if req.RemoteAddr in allowed then do the thing diff --git a/pkg/auth/session.go b/pkg/auth/session.go index 05ee3c3..7c2fb73 100644 --- a/pkg/auth/session.go +++ b/pkg/auth/session.go @@ -124,9 +124,7 @@ func (a *Authenticator) verifyAndGetCredential(tr *TokenRequest, r *http.Request return nil } - cred := &Credential{ - user: user, - } + cred := a.store.Credential(user) return cred } diff --git a/pkg/auth/store.go b/pkg/auth/store.go index ece416a..aa3cc40 100644 --- a/pkg/auth/store.go +++ b/pkg/auth/store.go @@ -14,6 +14,7 @@ const ( type AuthStore interface { User(UserID) *User + Credential(provider.ProviderUser) *Credential } type authStore struct { @@ -23,6 +24,16 @@ type authStore struct { Refresh []RefreshToken `json:"refresh_tokens"` userMap map[UserID]*User + providerUsers map[provider.ProviderUser]*Credential +} + +func (as *authStore) Credential(p provider.ProviderUser) *Credential { + c, have := as.providerUsers[p] + if !have { + return nil + } + + return c } func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error) { @@ -30,6 +41,7 @@ func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error) err = s.Get(AuthStoreKey, as) as.userMap = make(map[UserID]*User) + as.providerUsers = make(map[provider.ProviderUser]*Credential) for _, u := range as.Users { as.userMap[u.ID] = &u @@ -49,7 +61,11 @@ func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error) return nil, err } - c.user = pd.(provider.ProviderUser) + c.user = prov.Lookup(pd.(provider.ProviderUser)) + if c.user == nil { + return nil, fmt.Errorf("cannot find user in provider %s", prov.ProviderName()) + } + as.providerUsers[c.user] = &c } }