package auth

import (
	"crypto/subtle"
	"encoding/json"
	"fmt"

	"dynatron.me/x/blasphem/internal/generate"
	"dynatron.me/x/blasphem/pkg/auth/provider"
	"dynatron.me/x/blasphem/pkg/storage"

	"github.com/rs/zerolog/log"
)

const (
	// AuthStoreKey is the storage key newAuthStore passes to Store.GetItem
	// when loading the auth store.
	AuthStoreKey = "auth"
)

// AuthStore is the set of user, credential and refresh-token operations
// implemented by authStore.
type AuthStore interface {
	// User returns the user with the given ID, or nil if it is unknown.
	User(UserID) *User
	// GetCredential returns the stored credentials matching the given
	// provider user, or nil if none match.
	GetCredential(provider.ProviderUser) *Credentials
	// PutRefreshToken records a refresh token for the user named by its
	// UserID and returns the token, or an error if it cannot be recorded.
	PutRefreshToken(*RefreshToken) (*RefreshToken, error)
	// GetRefreshTokenByToken returns the refresh token with the given token
	// value, or nil if there is none.
	GetRefreshTokenByToken(token RefreshTokenToken) *RefreshToken
	// GetRefreshToken returns the refresh token with the given ID, or nil if
	// there is none.
	GetRefreshToken(RefreshTokenID) *RefreshToken
}

// authStore is the AuthStore implementation built by newAuthStore.
//
// The exported slice fields carry JSON tags and are what gets serialized; the
// embedded storage.Item is excluded with `json:"-"`, and the unexported fields
// are in-memory state filled in by newAuthStore.
type authStore struct {
	storage.Item `json:"-"`

	Users       []*User         `json:"users"`
	Groups      []*Group        `json:"groups"`
	Credentials []*Credentials  `json:"credentials"`
	Refresh     []*RefreshToken `json:"refresh_tokens"`

	userMap       map[UserID]*User
	providerUsers map[provider.ProviderUser]*Credentials
	store         storage.Store
}

// sync flushes this store's item through the backing storage.Store. A flush
// failure is logged and not returned to the caller.
func (as *authStore) sync() {
	if err := as.store.Flush(as.ItemKey()); err != nil {
		log.Error().Err(err).Msg("sync authStore")
	}
}

// strPtrEq reports whether n1 and n2 are the same pointer (which includes both
// being nil) or are both non-nil and point at equal strings.
func strPtrEq(n1, n2 *string) bool {
	if n1 == n2 {
		return true
	}
	return n1 != nil && n2 != nil && *n1 == *n2
}

// GetCredential returns the stored credentials that belong to p, or nil when p
// is nil or nothing matches.
//
// A record matches when its User is p itself, or when its provider ID and
// provider type equal those of p's provider and that provider's EqualCreds
// accepts the two users' data. The scan does not stop at the first hit, so if
// several records match, the last one in as.Credentials is returned.
func (as *authStore) GetCredential(p provider.ProviderUser) *Credentials {
	if p == nil {
		return nil
	}
	// Looked up once; every iteration compares against the same provider.
	prov := p.Provider()

	var found *Credentials
	for _, cr := range as.Credentials {
		if p == cr.User {
			found = cr
			continue
		}
		// newAuthStore leaves User unset for records that carry no raw
		// credential data; there is nothing to hand to EqualCreds for those,
		// and calling UserData on the nil interface would panic.
		if prov == nil || cr.User == nil {
			continue
		}
		if strPtrEq(cr.AuthProviderID, prov.ProviderID()) &&
			cr.AuthProviderType == prov.ProviderType() &&
			prov.EqualCreds(cr.User.UserData(), p.UserData()) {
			found = cr
		}
	}
	return found
}

// PutRefreshToken appends rt to the store-wide token list and to the token
// list of the user named by rt.UserID, then syncs the store.
//
// It returns rt on success. It returns an error, and records nothing, when rt
// is nil or rt.UserID is not a known user. A failed sync is only logged (see
// sync), so it does not surface here.
func (as *authStore) PutRefreshToken(rt *RefreshToken) (*RefreshToken, error) {
	if rt == nil {
		return nil, fmt.Errorf("nil refresh token")
	}

	u, hasUser := as.userMap[rt.UserID]
	if !hasUser {
		return nil, fmt.Errorf("no such user %v", rt.UserID)
	}

	as.Refresh = append(as.Refresh, rt)
	u.RefreshTokens = append(u.RefreshTokens, rt)
	as.sync()

	return rt, nil
}

// GetRefreshTokenByToken returns the refresh token whose Token equals token,
// or nil if no user holds such a token.
//
// Candidates are compared with subtle.ConstantTimeCompare and every token of
// every user is visited even after a hit, so the last match wins. As a side
// effect, the User field of each matching token is set to the user that owns
// it.
func (as *authStore) GetRefreshTokenByToken(token RefreshTokenToken) *RefreshToken {
	want := []byte(token) // converted once; it does not change per candidate

	var found *RefreshToken
	for _, u := range as.Users {
		for _, rt := range u.RefreshTokens {
			if subtle.ConstantTimeCompare(want, []byte(rt.Token)) == 1 {
				found = rt
				found.User = u
			}
		}
	}
	return found
}

// GetRefreshToken returns the refresh token whose ID, rendered with String,
// equals tid, or nil if no user holds such a token.
//
// Candidates are compared with subtle.ConstantTimeCompare and every token of
// every user is visited even after a hit, so the last match wins. As a side
// effect, the User field of each matching token is set to the user that owns
// it.
func (as *authStore) GetRefreshToken(tid RefreshTokenID) *RefreshToken {
	want := []byte(tid) // converted once; it does not change per candidate

	var found *RefreshToken
	for _, u := range as.Users {
		for _, rt := range u.RefreshTokens {
			if subtle.ConstantTimeCompare(want, []byte(rt.ID.String())) == 1 {
				found = rt
				found.User = u
			}
		}
	}
	return found
}

// newCredential builds a Credentials value with a freshly generated ID and the
// type and ID taken from the ProviderBase of p's provider.
//
// Only those three fields are populated, and the result is not added to the
// store.
func (as *authStore) newCredential(p provider.ProviderUser) *Credentials {
	// XXX: probably broken
	// NOTE(review): nothing here guards against p being nil — confirm what
	// callers pass before relying on this.
	prov := p.Provider()
	id := generate.UUID()
	base := prov.ProviderBase()

	return &Credentials{
		ID:               CredID(id),
		AuthProviderType: base.Type,
		AuthProviderID:   base.ID,
	}
}

// newAuthStore loads the auth store stored under AuthStoreKey in s and rebuilds
// its in-memory state:
//
//   - userMap indexes as.Users by ID;
//   - every credential is matched to its provider via a.Provider; if it carries
//     raw data, that data is decoded and looked up in the provider to fill in
//     the credential's User and the providerUsers index;
//   - every credential is attached to its user's Creds;
//   - refresh tokens that are not valid, or whose user is unknown, are dropped
//     from as.Refresh; the rest are attached to their user's RefreshTokens.
//
// It fails when the item cannot be loaded, when a credential names an unknown
// provider, when its raw data cannot be decoded or is not a
// provider.ProviderUser, or when the provider cannot find the user. A
// credential or token that references an unknown user ID is logged and
// skipped rather than treated as fatal.
func (a *authenticator) newAuthStore(s storage.Store) (as *authStore, err error) {
	as = &authStore{store: s}

	as.Item, err = s.GetItem(AuthStoreKey, as)
	if err != nil {
		// NOTE(review): unlike the failure paths below, this one hands back
		// the non-nil store together with the error; kept in case a caller
		// depends on it.
		return as, err
	}

	as.userMap = make(map[UserID]*User, len(as.Users))
	as.providerUsers = make(map[provider.ProviderUser]*Credentials)

	for _, u := range as.Users {
		as.userMap[u.ID] = u
	}

	for _, c := range as.Credentials {
		prov := a.Provider(c.AuthProviderType)
		if prov == nil {
			return nil, fmt.Errorf("no such provider %s", c.AuthProviderType)
		}

		if c.DataRaw != nil {
			pd := prov.NewCredData()
			if decErr := json.Unmarshal(*c.DataRaw, pd); decErr != nil {
				return nil, decErr
			}

			// Checked assertion: a provider whose credential data does not
			// implement ProviderUser yields an error instead of a panic.
			pu, ok := pd.(provider.ProviderUser)
			if !ok {
				return nil, fmt.Errorf("credential data of provider %s is not a provider user", prov.ProviderName())
			}

			c.User = prov.Lookup(pu)
			if c.User == nil {
				return nil, fmt.Errorf("cannot find user in provider %s", prov.ProviderName())
			}
			as.providerUsers[c.User] = c
		}

		u, hasUser := as.userMap[c.UserID]
		if !hasUser {
			log.Error().Str("userid", string(c.UserID)).Msg("creds no such userid in map")
			continue
		}
		u.Creds = append(u.Creds, c)
	}

	// Filter as.Refresh in place, keeping only valid tokens whose user exists.
	kept := 0
	for _, rt := range as.Refresh {
		if !rt.IsValid() {
			continue
		}

		u, hasUser := as.userMap[rt.UserID]
		if !hasUser {
			log.Error().Str("userid", string(rt.UserID)).Msg("refreshtokens no such userid in map")
			continue
		}

		as.Refresh[kept] = rt
		kept++
		u.RefreshTokens = append(u.RefreshTokens, rt)
	}

	// Clear the now-unused tail so the dropped tokens are not kept reachable
	// through the slice's backing array.
	for j := kept; j < len(as.Refresh); j++ {
		as.Refresh[j] = nil
	}
	as.Refresh = as.Refresh[:kept]

	return as, nil
}

// User returns the user with the given ID, or nil if no such user is indexed.
func (as *authStore) User(uid UserID) *User {
	return as.userMap[uid]
}