pre-model
This commit is contained in:
parent
3038750206
commit
9224b21db7
7 changed files with 117 additions and 66 deletions
|
@ -25,7 +25,7 @@ var (
|
|||
type Authenticator struct {
|
||||
store AuthStore
|
||||
flows *AuthFlowManager
|
||||
sessions AccessSessionStore
|
||||
authCodes authCodeStore
|
||||
providers map[string]provider.AuthProvider
|
||||
}
|
||||
|
||||
|
@ -60,7 +60,7 @@ func (a *Authenticator) InitAuth(s storage.Store) error {
|
|||
|
||||
a.flows = NewAuthFlowManager()
|
||||
|
||||
a.sessions.init()
|
||||
a.authCodes.init()
|
||||
|
||||
var err error
|
||||
a.store, err = a.newAuthStore(s)
|
||||
|
@ -91,24 +91,24 @@ func (a *Authenticator) ProvidersHandler(c echo.Context) error {
|
|||
return c.JSON(http.StatusOK, providers)
|
||||
}
|
||||
|
||||
func (a *Authenticator) Check(f *LoginFlow, req *http.Request, rm map[string]interface{}) (provider.ProviderUser, error) {
|
||||
func (a *Authenticator) Check(f *LoginFlow, req *http.Request, rm map[string]interface{}) (user provider.ProviderUser, clientID string, err error) {
|
||||
cID, hasCID := rm["client_id"]
|
||||
cIDStr, cidIsStr := cID.(string)
|
||||
if !hasCID || !cidIsStr || cIDStr == "" || cIDStr != string(f.ClientID) {
|
||||
return nil, ErrInvalidAuth
|
||||
clientID, cidIsStr := cID.(string)
|
||||
if !hasCID || !cidIsStr || clientID == "" || clientID != string(f.ClientID) {
|
||||
return nil, clientID, ErrInvalidAuth
|
||||
}
|
||||
|
||||
p := a.Provider(f.Handler.String())
|
||||
if p == nil {
|
||||
return nil, ErrInvalidAuth
|
||||
return nil, clientID, ErrInvalidAuth
|
||||
}
|
||||
|
||||
user, success := p.ValidateCreds(req, rm)
|
||||
|
||||
if success {
|
||||
log.Info().Interface("user", user.UserData()).Msg("Login success")
|
||||
return user, nil
|
||||
return user, clientID, nil
|
||||
}
|
||||
|
||||
return nil, ErrInvalidAuth
|
||||
return nil, clientID, ErrInvalidAuth
|
||||
}
|
||||
|
|
|
@ -104,16 +104,18 @@ func (f *LoginFlow) progress(a *Authenticator, c echo.Context) error {
|
|||
return c.JSON(http.StatusBadRequest, f.ShowForm([]string{err.Error()}))
|
||||
}
|
||||
|
||||
user, err := a.Check(f, c.Request(), rm)
|
||||
|
||||
user, clientID, err := a.Check(f, c.Request(), rm)
|
||||
switch err {
|
||||
case nil:
|
||||
creds := a.store.Credential(user)
|
||||
finishedFlow := flow.Result{}
|
||||
a.flows.Remove(f)
|
||||
copier.Copy(&finishedFlow, f)
|
||||
finishedFlow.Type = flow.TypeCreateEntry
|
||||
finishedFlow.Title = common.AppNamePtr()
|
||||
finishedFlow.Version = common.IntPtr(1)
|
||||
finishedFlow.Result = a.NewAccessToken(c.Request(), user)
|
||||
finishedFlow.Result = a.NewAuthCode(ClientID(clientID), creds)
|
||||
|
||||
f.redirect(c)
|
||||
|
||||
|
|
|
@ -24,9 +24,39 @@ func Register(providerName string, f func(storage.Store) (AuthProvider, error))
|
|||
Providers[providerName] = f
|
||||
}
|
||||
|
||||
type Credentials struct {
|
||||
ID CredID `json:"id"`
|
||||
UserID UserID `json:"user_id"`
|
||||
AuthProviderType string `json:"auth_provider_type"`
|
||||
AuthProviderID *string `json:"auth_provider_id"`
|
||||
DataRaw *json.RawMessage `json:"data,omitempty"`
|
||||
user ProviderUser `json:"-"`
|
||||
}
|
||||
|
||||
func (cred *Credentials) MarshalJSON() ([]byte, error) {
|
||||
type CredAlias Credentials // alias so ø method set and we don't recurse
|
||||
nCd := (*CredAlias)(cred)
|
||||
|
||||
if cred.user != nil {
|
||||
providerData := cred.user.UserData()
|
||||
if providerData != nil {
|
||||
b, err := json.Marshal(providerData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dr := json.RawMessage(b)
|
||||
nCd.DataRaw = &dr
|
||||
}
|
||||
}
|
||||
|
||||
return json.Marshal(nCd)
|
||||
}
|
||||
|
||||
type ProviderUser interface {
|
||||
// TODO: make sure this is sane with all the ProviderUser and UserData type stuff
|
||||
UserData() ProviderUser
|
||||
Credentials() *Credentials
|
||||
}
|
||||
|
||||
type AuthProviderBase struct {
|
||||
|
|
|
@ -10,6 +10,8 @@ import (
|
|||
"dynatron.me/x/blasphem/internal/common"
|
||||
"dynatron.me/x/blasphem/internal/generate"
|
||||
"dynatron.me/x/blasphem/pkg/auth/provider"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type (
|
||||
|
@ -38,62 +40,78 @@ func (rt *RefreshToken) IsValid() bool {
|
|||
return rt.JWTKey != ""
|
||||
}
|
||||
|
||||
type AccessSessionStore struct {
|
||||
s map[AccessTokenID]*AccessToken
|
||||
type authCodeStore struct {
|
||||
s map[authCodeTuple]flowResult
|
||||
lastCull time.Time
|
||||
}
|
||||
|
||||
type AccessTokenID string
|
||||
type authCodeTuple struct {
|
||||
ClientID ClientID
|
||||
Code AuthCode
|
||||
}
|
||||
|
||||
func (t *AccessTokenID) IsValid() bool {
|
||||
func (t *authCodeTuple) IsValid() bool {
|
||||
// TODO: more validation than this
|
||||
return *t != ""
|
||||
return t.Code != ""
|
||||
}
|
||||
|
||||
type AccessToken struct { // TODO: jwt bro
|
||||
ID AccessTokenID
|
||||
Ctime time.Time
|
||||
Expires time.Time
|
||||
Addr string
|
||||
|
||||
user provider.ProviderUser `json:"-"`
|
||||
type flowResult struct {
|
||||
Time time.Time
|
||||
Cred *Credentials
|
||||
// TODO: remove this comment below \/
|
||||
//user provider.ProviderUser `json:"-"`
|
||||
}
|
||||
|
||||
func (ss *AccessSessionStore) init() {
|
||||
ss.s = make(map[AccessTokenID]*AccessToken)
|
||||
// OAuth 4.2.1 spec recommends 10 minutes
|
||||
const authCodeExpire = 10 * time.Minute
|
||||
|
||||
func (f *flowResult) IsValid(now time.Time) bool {
|
||||
if now.After(f.Time.Add(authCodeExpire)) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (ss *authCodeStore) init() {
|
||||
ss.s = make(map[authCodeTuple]flowResult)
|
||||
}
|
||||
|
||||
const cullInterval = 5 * time.Minute
|
||||
|
||||
func (ss *AccessSessionStore) cull() {
|
||||
func (ss *authCodeStore) cull() {
|
||||
if now := time.Now(); now.Sub(ss.lastCull) > cullInterval {
|
||||
for k, v := range ss.s {
|
||||
if now.After(v.Expires) {
|
||||
if !v.IsValid(now) {
|
||||
delete(ss.s, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ss *AccessSessionStore) register(t *AccessToken) {
|
||||
func (ss *authCodeStore) store(clientID ClientID, cred *Credentials) string {
|
||||
log.Info().Msgf("store cred is %+v", cred)
|
||||
ss.cull()
|
||||
ss.s[t.ID] = t
|
||||
code := generate.UUID()
|
||||
ss.s[authCodeTuple{clientID, AuthCode(code)}] = flowResult{Time: time.Now(), Cred: cred}
|
||||
|
||||
return code
|
||||
}
|
||||
|
||||
func (ss *AccessSessionStore) verify(tr *TokenRequest, r *http.Request) (provider.ProviderUser, bool) {
|
||||
if t, hasToken := ss.s[tr.Code]; hasToken {
|
||||
func (ss *authCodeStore) verify(tr *TokenRequest, r *http.Request) (*Credentials, bool) {
|
||||
key := authCodeTuple{tr.ClientID, tr.Code}
|
||||
if t, hasCode := ss.s[key]; hasCode {
|
||||
defer delete(ss.s, key)
|
||||
// TODO: JWT
|
||||
if t.Expires.After(time.Now()) {
|
||||
return t.user, true
|
||||
} else {
|
||||
delete(ss.s, t.ID)
|
||||
if t.IsValid(time.Now()) {
|
||||
return t.Cred, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
type Credential struct {
|
||||
type Credentials struct {
|
||||
ID CredID `json:"id"`
|
||||
UserID UserID `json:"user_id"`
|
||||
AuthProviderType string `json:"auth_provider_type"`
|
||||
|
@ -102,8 +120,8 @@ type Credential struct {
|
|||
user provider.ProviderUser `json:"-"`
|
||||
}
|
||||
|
||||
func (cred *Credential) MarshalJSON() ([]byte, error) {
|
||||
type CredAlias Credential // alias so ø method set and we don't recurse
|
||||
func (cred *Credentials) MarshalJSON() ([]byte, error) {
|
||||
type CredAlias Credentials // alias so ø method set and we don't recurse
|
||||
nCd := (*CredAlias)(cred)
|
||||
|
||||
if cred.user != nil {
|
||||
|
@ -122,35 +140,19 @@ func (cred *Credential) MarshalJSON() ([]byte, error) {
|
|||
return json.Marshal(nCd)
|
||||
}
|
||||
|
||||
func (a *Authenticator) verifyAndGetCredential(tr *TokenRequest, r *http.Request) *Credential {
|
||||
user, success := a.sessions.verify(tr, r)
|
||||
func (a *Authenticator) verifyAndGetCredential(tr *TokenRequest, r *http.Request) *Credentials {
|
||||
cred, success := a.authCodes.verify(tr, r)
|
||||
if !success {
|
||||
return nil
|
||||
}
|
||||
|
||||
cred := a.store.Credential(user)
|
||||
|
||||
return cred
|
||||
}
|
||||
|
||||
const defaultExpiration = 15 * time.Minute
|
||||
|
||||
func (a *Authenticator) NewAccessToken(r *http.Request, user provider.ProviderUser) AccessTokenID {
|
||||
id := AccessTokenID(generate.UUID())
|
||||
now := time.Now()
|
||||
|
||||
t := &AccessToken{
|
||||
ID: id,
|
||||
Ctime: now,
|
||||
Expires: now.Add(defaultExpiration),
|
||||
Addr: r.RemoteAddr,
|
||||
|
||||
user: user,
|
||||
}
|
||||
|
||||
a.sessions.register(t)
|
||||
|
||||
return id
|
||||
func (a *Authenticator) NewAuthCode(clientID ClientID, cred *Credentials) string {
|
||||
return a.authCodes.store(clientID, cred)
|
||||
}
|
||||
|
||||
type GrantType string
|
||||
|
@ -167,9 +169,14 @@ func (c *ClientID) IsValid() bool {
|
|||
return *c != ""
|
||||
}
|
||||
|
||||
type AuthCode string
|
||||
func (ac *AuthCode) IsValid() bool {
|
||||
return *ac != ""
|
||||
}
|
||||
|
||||
type TokenRequest struct {
|
||||
ClientID ClientID `form:"client_id"`
|
||||
Code AccessTokenID `form:"code"`
|
||||
Code AuthCode `form:"code"`
|
||||
GrantType GrantType `form:"grant_type"`
|
||||
}
|
||||
|
||||
|
|
|
@ -6,6 +6,8 @@ import (
|
|||
|
||||
"dynatron.me/x/blasphem/pkg/auth/provider"
|
||||
"dynatron.me/x/blasphem/pkg/storage"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -14,20 +16,20 @@ const (
|
|||
|
||||
type AuthStore interface {
|
||||
User(UserID) *User
|
||||
Credential(provider.ProviderUser) *Credential
|
||||
Credential(provider.ProviderUser) *Credentials
|
||||
}
|
||||
|
||||
type authStore struct {
|
||||
Users []User `json:"users"`
|
||||
Groups []Group `json:"groups"`
|
||||
Credentials []Credential `json:"credentials"`
|
||||
Credentials []Credentials `json:"credentials"`
|
||||
Refresh []RefreshToken `json:"refresh_tokens"`
|
||||
|
||||
userMap map[UserID]*User
|
||||
providerUsers map[provider.ProviderUser]*Credential
|
||||
providerUsers map[provider.ProviderUser]*Credentials
|
||||
}
|
||||
|
||||
func (as *authStore) Credential(p provider.ProviderUser) *Credential {
|
||||
func (as *authStore) Credential(p provider.ProviderUser) *Credentials {
|
||||
c, have := as.providerUsers[p]
|
||||
if !have {
|
||||
return nil
|
||||
|
@ -41,7 +43,7 @@ func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error)
|
|||
err = s.Get(AuthStoreKey, as)
|
||||
|
||||
as.userMap = make(map[UserID]*User)
|
||||
as.providerUsers = make(map[provider.ProviderUser]*Credential)
|
||||
as.providerUsers = make(map[provider.ProviderUser]*Credentials)
|
||||
|
||||
for _, u := range as.Users {
|
||||
as.userMap[u.ID] = &u
|
||||
|
@ -67,6 +69,14 @@ func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error)
|
|||
}
|
||||
as.providerUsers[c.user] = &c
|
||||
}
|
||||
|
||||
u, hasUser := as.userMap[c.UserID]
|
||||
if !hasUser {
|
||||
log.Error().Str("userid", string(c.UserID)).Msg("no such userid in map")
|
||||
continue
|
||||
}
|
||||
|
||||
u.Creds = append(u.Creds, c)
|
||||
}
|
||||
|
||||
// remove invalid RefreshTokens
|
||||
|
|
|
@ -18,6 +18,8 @@ type User struct {
|
|||
GroupIDs []GroupID `json:"group_ids"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
UserMetadata
|
||||
|
||||
Creds []Credentials `json:"-"`
|
||||
}
|
||||
|
||||
type UserMetadata struct {
|
||||
|
@ -36,7 +38,7 @@ func (u *User) allowedToAuth() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a *Authenticator) getOrCreateUser(c *Credential) (*User, error) {
|
||||
func (a *Authenticator) getOrCreateUser(c *Credentials) (*User, error) {
|
||||
log.Debug().Interface("userdata", c).Msg("getOrCreateUser")
|
||||
u := a.store.User(c.UserID)
|
||||
if u == nil {
|
||||
|
|
|
@ -181,7 +181,7 @@ func (fs *FlowManager) Remove(f Handler) {
|
|||
delete(fs.flows, f.FlowID())
|
||||
}
|
||||
|
||||
const cullAge = time.Minute * 30
|
||||
const cullAge = time.Minute * 10
|
||||
|
||||
func (fs FlowStore) cull() {
|
||||
for k, v := range fs {
|
||||
|
|
Loading…
Reference in a new issue