diff --git a/pkg/auth/authenticator.go b/pkg/auth/authenticator.go index f0ac84f..b88d98a 100644 --- a/pkg/auth/authenticator.go +++ b/pkg/auth/authenticator.go @@ -25,7 +25,7 @@ var ( type Authenticator struct { store AuthStore flows *AuthFlowManager - sessions AccessSessionStore + authCodes authCodeStore providers map[string]provider.AuthProvider } @@ -60,7 +60,7 @@ func (a *Authenticator) InitAuth(s storage.Store) error { a.flows = NewAuthFlowManager() - a.sessions.init() + a.authCodes.init() var err error a.store, err = a.newAuthStore(s) @@ -91,24 +91,24 @@ func (a *Authenticator) ProvidersHandler(c echo.Context) error { return c.JSON(http.StatusOK, providers) } -func (a *Authenticator) Check(f *LoginFlow, req *http.Request, rm map[string]interface{}) (provider.ProviderUser, error) { +func (a *Authenticator) Check(f *LoginFlow, req *http.Request, rm map[string]interface{}) (user provider.ProviderUser, clientID string, err error) { cID, hasCID := rm["client_id"] - cIDStr, cidIsStr := cID.(string) - if !hasCID || !cidIsStr || cIDStr == "" || cIDStr != string(f.ClientID) { - return nil, ErrInvalidAuth + clientID, cidIsStr := cID.(string) + if !hasCID || !cidIsStr || clientID == "" || clientID != string(f.ClientID) { + return nil, clientID, ErrInvalidAuth } p := a.Provider(f.Handler.String()) if p == nil { - return nil, ErrInvalidAuth + return nil, clientID, ErrInvalidAuth } user, success := p.ValidateCreds(req, rm) if success { log.Info().Interface("user", user.UserData()).Msg("Login success") - return user, nil + return user, clientID, nil } - return nil, ErrInvalidAuth + return nil, clientID, ErrInvalidAuth } diff --git a/pkg/auth/flow.go b/pkg/auth/flow.go index 69e5abd..f2f74db 100644 --- a/pkg/auth/flow.go +++ b/pkg/auth/flow.go @@ -103,17 +103,19 @@ func (f *LoginFlow) progress(a *Authenticator, c echo.Context) error { if err != nil { return c.JSON(http.StatusBadRequest, f.ShowForm([]string{err.Error()})) } + - user, err := a.Check(f, c.Request(), rm) + user, clientID, err := a.Check(f, c.Request(), rm) switch err { case nil: + creds := a.store.Credential(user) finishedFlow := flow.Result{} a.flows.Remove(f) copier.Copy(&finishedFlow, f) finishedFlow.Type = flow.TypeCreateEntry finishedFlow.Title = common.AppNamePtr() finishedFlow.Version = common.IntPtr(1) - finishedFlow.Result = a.NewAccessToken(c.Request(), user) + finishedFlow.Result = a.NewAuthCode(ClientID(clientID), creds) f.redirect(c) diff --git a/pkg/auth/provider/provider.go b/pkg/auth/provider/provider.go index f3aeb5b..cc40216 100644 --- a/pkg/auth/provider/provider.go +++ b/pkg/auth/provider/provider.go @@ -24,9 +24,39 @@ func Register(providerName string, f func(storage.Store) (AuthProvider, error)) Providers[providerName] = f } +type Credentials struct { + ID CredID `json:"id"` + UserID UserID `json:"user_id"` + AuthProviderType string `json:"auth_provider_type"` + AuthProviderID *string `json:"auth_provider_id"` + DataRaw *json.RawMessage `json:"data,omitempty"` + user ProviderUser `json:"-"` +} + +func (cred *Credentials) MarshalJSON() ([]byte, error) { + type CredAlias Credentials // alias so ø method set and we don't recurse + nCd := (*CredAlias)(cred) + + if cred.user != nil { + providerData := cred.user.UserData() + if providerData != nil { + b, err := json.Marshal(providerData) + if err != nil { + return nil, err + } + + dr := json.RawMessage(b) + nCd.DataRaw = &dr + } + } + + return json.Marshal(nCd) +} + type ProviderUser interface { // TODO: make sure this is sane with all the ProviderUser and UserData type stuff UserData() ProviderUser + Credentials() *Credentials } type AuthProviderBase struct { diff --git a/pkg/auth/session.go b/pkg/auth/session.go index 301b182..f7b9bb6 100644 --- a/pkg/auth/session.go +++ b/pkg/auth/session.go @@ -10,6 +10,8 @@ import ( "dynatron.me/x/blasphem/internal/common" "dynatron.me/x/blasphem/internal/generate" "dynatron.me/x/blasphem/pkg/auth/provider" + + "github.com/rs/zerolog/log" ) type ( @@ -38,62 +40,78 @@ func (rt *RefreshToken) IsValid() bool { return rt.JWTKey != "" } -type AccessSessionStore struct { - s map[AccessTokenID]*AccessToken +type authCodeStore struct { + s map[authCodeTuple]flowResult lastCull time.Time } -type AccessTokenID string +type authCodeTuple struct { + ClientID ClientID + Code AuthCode +} -func (t *AccessTokenID) IsValid() bool { +func (t *authCodeTuple) IsValid() bool { // TODO: more validation than this - return *t != "" + return t.Code != "" } -type AccessToken struct { // TODO: jwt bro - ID AccessTokenID - Ctime time.Time - Expires time.Time - Addr string - - user provider.ProviderUser `json:"-"` +type flowResult struct { + Time time.Time + Cred *Credentials + // TODO: remove this comment below \/ + //user provider.ProviderUser `json:"-"` } -func (ss *AccessSessionStore) init() { - ss.s = make(map[AccessTokenID]*AccessToken) +// OAuth 4.2.1 spec recommends 10 minutes +const authCodeExpire = 10 * time.Minute + +func (f *flowResult) IsValid(now time.Time) bool { + if now.After(f.Time.Add(authCodeExpire)) { + return false + } + + return true +} + +func (ss *authCodeStore) init() { + ss.s = make(map[authCodeTuple]flowResult) } const cullInterval = 5 * time.Minute -func (ss *AccessSessionStore) cull() { +func (ss *authCodeStore) cull() { if now := time.Now(); now.Sub(ss.lastCull) > cullInterval { for k, v := range ss.s { - if now.After(v.Expires) { + if !v.IsValid(now) { delete(ss.s, k) } } } } -func (ss *AccessSessionStore) register(t *AccessToken) { +func (ss *authCodeStore) store(clientID ClientID, cred *Credentials) string { + log.Info().Msgf("store cred is %+v", cred) ss.cull() - ss.s[t.ID] = t + code := generate.UUID() + ss.s[authCodeTuple{clientID, AuthCode(code)}] = flowResult{Time: time.Now(), Cred: cred} + + return code } -func (ss *AccessSessionStore) verify(tr *TokenRequest, r *http.Request) (provider.ProviderUser, bool) { - if t, hasToken := ss.s[tr.Code]; hasToken { +func (ss *authCodeStore) verify(tr *TokenRequest, r *http.Request) (*Credentials, bool) { + key := authCodeTuple{tr.ClientID, tr.Code} + if t, hasCode := ss.s[key]; hasCode { + defer delete(ss.s, key) // TODO: JWT - if t.Expires.After(time.Now()) { - return t.user, true - } else { - delete(ss.s, t.ID) + if t.IsValid(time.Now()) { + return t.Cred, true } } return nil, false } -type Credential struct { +type Credentials struct { ID CredID `json:"id"` UserID UserID `json:"user_id"` AuthProviderType string `json:"auth_provider_type"` @@ -102,8 +120,8 @@ type Credential struct { user provider.ProviderUser `json:"-"` } -func (cred *Credential) MarshalJSON() ([]byte, error) { - type CredAlias Credential // alias so ø method set and we don't recurse +func (cred *Credentials) MarshalJSON() ([]byte, error) { + type CredAlias Credentials // alias so ø method set and we don't recurse nCd := (*CredAlias)(cred) if cred.user != nil { @@ -122,35 +140,19 @@ func (cred *Credential) MarshalJSON() ([]byte, error) { return json.Marshal(nCd) } -func (a *Authenticator) verifyAndGetCredential(tr *TokenRequest, r *http.Request) *Credential { - user, success := a.sessions.verify(tr, r) +func (a *Authenticator) verifyAndGetCredential(tr *TokenRequest, r *http.Request) *Credentials { + cred, success := a.authCodes.verify(tr, r) if !success { return nil } - cred := a.store.Credential(user) - return cred } const defaultExpiration = 15 * time.Minute -func (a *Authenticator) NewAccessToken(r *http.Request, user provider.ProviderUser) AccessTokenID { - id := AccessTokenID(generate.UUID()) - now := time.Now() - - t := &AccessToken{ - ID: id, - Ctime: now, - Expires: now.Add(defaultExpiration), - Addr: r.RemoteAddr, - - user: user, - } - - a.sessions.register(t) - - return id +func (a *Authenticator) NewAuthCode(clientID ClientID, cred *Credentials) string { + return a.authCodes.store(clientID, cred) } type GrantType string @@ -167,9 +169,14 @@ func (c *ClientID) IsValid() bool { return *c != "" } +type AuthCode string +func (ac *AuthCode) IsValid() bool { + return *ac != "" +} + type TokenRequest struct { ClientID ClientID `form:"client_id"` - Code AccessTokenID `form:"code"` + Code AuthCode `form:"code"` GrantType GrantType `form:"grant_type"` } diff --git a/pkg/auth/store.go b/pkg/auth/store.go index 2245072..1ed5ad8 100644 --- a/pkg/auth/store.go +++ b/pkg/auth/store.go @@ -6,6 +6,8 @@ import ( "dynatron.me/x/blasphem/pkg/auth/provider" "dynatron.me/x/blasphem/pkg/storage" + + "github.com/rs/zerolog/log" ) const ( @@ -14,20 +16,20 @@ const ( type AuthStore interface { User(UserID) *User - Credential(provider.ProviderUser) *Credential + Credential(provider.ProviderUser) *Credentials } type authStore struct { Users []User `json:"users"` Groups []Group `json:"groups"` - Credentials []Credential `json:"credentials"` + Credentials []Credentials `json:"credentials"` Refresh []RefreshToken `json:"refresh_tokens"` userMap map[UserID]*User - providerUsers map[provider.ProviderUser]*Credential + providerUsers map[provider.ProviderUser]*Credentials } -func (as *authStore) Credential(p provider.ProviderUser) *Credential { +func (as *authStore) Credential(p provider.ProviderUser) *Credentials { c, have := as.providerUsers[p] if !have { return nil @@ -41,7 +43,7 @@ func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error) err = s.Get(AuthStoreKey, as) as.userMap = make(map[UserID]*User) - as.providerUsers = make(map[provider.ProviderUser]*Credential) + as.providerUsers = make(map[provider.ProviderUser]*Credentials) for _, u := range as.Users { as.userMap[u.ID] = &u @@ -67,6 +69,14 @@ func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error) } as.providerUsers[c.user] = &c } + + u, hasUser := as.userMap[c.UserID] + if !hasUser { + log.Error().Str("userid", string(c.UserID)).Msg("no such userid in map") + continue + } + + u.Creds = append(u.Creds, c) } // remove invalid RefreshTokens diff --git a/pkg/auth/user.go b/pkg/auth/user.go index a721c22..d7d41b9 100644 --- a/pkg/auth/user.go +++ b/pkg/auth/user.go @@ -18,6 +18,8 @@ type User struct { GroupIDs []GroupID `json:"group_ids"` Data interface{} `json:"data,omitempty"` UserMetadata + + Creds []Credentials `json:"-"` } type UserMetadata struct { @@ -36,7 +38,7 @@ func (u *User) allowedToAuth() error { return nil } -func (a *Authenticator) getOrCreateUser(c *Credential) (*User, error) { +func (a *Authenticator) getOrCreateUser(c *Credentials) (*User, error) { log.Debug().Interface("userdata", c).Msg("getOrCreateUser") u := a.store.User(c.UserID) if u == nil { diff --git a/pkg/flow/flow.go b/pkg/flow/flow.go index 4877fc9..0015ddc 100644 --- a/pkg/flow/flow.go +++ b/pkg/flow/flow.go @@ -181,7 +181,7 @@ func (fs *FlowManager) Remove(f Handler) { delete(fs.flows, f.FlowID()) } -const cullAge = time.Minute * 30 +const cullAge = time.Minute * 10 func (fs FlowStore) cull() { for k, v := range fs {