From c91cd6efcad91cace237b173e572c6a4558eae9a Mon Sep 17 00:00:00 2001 From: Daniel Ponte Date: Mon, 19 Dec 2022 13:09:01 -0500 Subject: [PATCH] refresh token --- internal/common/common.go | 7 ++ pkg/auth/authenticator.go | 2 + pkg/auth/flow.go | 9 ++ pkg/auth/session.go | 197 ++++++++++++++++++++++++-------------- pkg/auth/store.go | 21 +++- 5 files changed, 161 insertions(+), 75 deletions(-) diff --git a/internal/common/common.go b/internal/common/common.go index a8e2b20..c8a0768 100644 --- a/internal/common/common.go +++ b/internal/common/common.go @@ -2,6 +2,7 @@ package common import ( + "github.com/labstack/echo/v4" "github.com/spf13/cobra" ) @@ -39,3 +40,9 @@ func RunE(c cmdOptions) func(cmd *cobra.Command, args []string) error { return err } } + +func NoCache(c echo.Context) echo.Context { + c.Response().Header().Set("Cache-Control", "no-store") + c.Response().Header().Set("Pragma", "no-cache") + return c +} diff --git a/pkg/auth/authenticator.go b/pkg/auth/authenticator.go index 2f06622..8de219a 100644 --- a/pkg/auth/authenticator.go +++ b/pkg/auth/authenticator.go @@ -3,6 +3,7 @@ package auth import ( "errors" "net/http" + "sync" "github.com/labstack/echo/v4" "github.com/rs/zerolog/log" @@ -25,6 +26,7 @@ var ( ) type Authenticator struct { + sync.Mutex store AuthStore flows *AuthFlowManager authCodes authCodeStore diff --git a/pkg/auth/flow.go b/pkg/auth/flow.go index f4dee95..c0ef7c7 100644 --- a/pkg/auth/flow.go +++ b/pkg/auth/flow.go @@ -137,6 +137,9 @@ func (f *LoginFlow) progress(a *Authenticator, c echo.Context) error { } func (a *Authenticator) LoginFlowDeleteHandler(c echo.Context) error { + a.Lock() + defer a.Unlock() + flowID := flow.FlowID(c.Param("flow_id")) if flowID == "" { @@ -156,6 +159,9 @@ func setJSON(c echo.Context) { } func (a *Authenticator) BeginLoginFlowHandler(c echo.Context) error { + a.Lock() + defer a.Unlock() + setJSON(c) var flowReq LoginFlowRequest @@ -176,6 +182,9 @@ func (a *Authenticator) BeginLoginFlowHandler(c echo.Context) error { } func (a *Authenticator) LoginFlowHandler(c echo.Context) error { + a.Lock() + defer a.Unlock() + setJSON(c) flowID := c.Param("flow_id") diff --git a/pkg/auth/session.go b/pkg/auth/session.go index f64c632..d832e49 100644 --- a/pkg/auth/session.go +++ b/pkg/auth/session.go @@ -33,8 +33,6 @@ func (t *authCodeTuple) IsValid() bool { type flowResult struct { Time time.Time Cred *Credentials - // TODO: remove this comment below \/ - //user provider.ProviderUser `json:"-"` } // OAuth 4.2.1 spec recommends 10 minutes @@ -64,7 +62,7 @@ func (ss *authCodeStore) cull() { } } -func (ss *authCodeStore) store(clientID ClientID, cred *Credentials) string { +func (ss *authCodeStore) put(clientID ClientID, cred *Credentials) string { ss.cull() code := generate.UUID() ss.s[authCodeTuple{clientID, AuthCode(code)}] = flowResult{Time: time.Now(), Cred: cred} @@ -72,7 +70,7 @@ func (ss *authCodeStore) store(clientID ClientID, cred *Credentials) string { return code } -func (ss *authCodeStore) verify(tr *TokenRequest, r *http.Request) (*Credentials, bool) { +func (ss *authCodeStore) get(tr *TokenRequest) (*Credentials, bool) { key := authCodeTuple{tr.ClientID, tr.Code} if t, hasCode := ss.s[key]; hasCode { defer delete(ss.s, key) @@ -115,10 +113,15 @@ func (cred *Credentials) MarshalJSON() ([]byte, error) { } type ( - TokenType string - RefreshTokenID string + TokenType string + RefreshTokenID string + RefreshTokenToken string ) +func (rti RefreshTokenID) String() string { return string(rti) } + +func (rti RefreshTokenToken) IsValid() bool { return rti != "" } + const ( TokenTypeSystem TokenType = "system" TokenTypeNormal TokenType = "normal" @@ -144,7 +147,7 @@ type RefreshToken struct { TokenType TokenType `json:"token_type"` CreatedAt *common.PyTimestamp `json:"created_at"` AccessTokenExpiration json.Number `json:"access_token_expiration"` - Token string `json:"token"` + Token RefreshTokenToken `json:"token"` JWTKey string `json:"jwt_key"` LastUsedAt *common.PyTimestamp `json:"last_used_at"` LastUsedIP *string `json:"last_used_ip"` @@ -156,6 +159,15 @@ func (rt *RefreshToken) IsValid() bool { return rt.JWTKey != "" } +func (rt *RefreshToken) AccessExpiration() (exp int64) { + exp, err := rt.AccessTokenExpiration.Int64() + if err != nil { + panic(err) + } + + return +} + type RefreshOption func(*RefreshToken) func WithClientID(cid ClientID) RefreshOption { @@ -191,8 +203,8 @@ func WithCredential(c *Credentials) RefreshOption { const DefaultAccessExpiration = "1800" func (a *Authenticator) NewRefreshToken(user *User, opts ...RefreshOption) (*RefreshToken, error) { - e := func(es string, a ...interface{}) (*RefreshToken, error) { - return nil, fmt.Errorf(es, a...) + e := func(es string, arg ...interface{}) (*RefreshToken, error) { + return nil, fmt.Errorf(es, arg...) } now := common.PyTimestamp(time.Now()) @@ -200,7 +212,7 @@ func (a *Authenticator) NewRefreshToken(user *User, opts ...RefreshOption) (*Ref r := &RefreshToken{ ID: RefreshTokenID(generate.UUID()), UserID: user.ID, - Token: generate.Hex(64), + Token: RefreshTokenToken(generate.Hex(64)), JWTKey: generate.Hex(64), CreatedAt: &now, AccessTokenExpiration: DefaultAccessExpiration, @@ -248,24 +260,20 @@ func (a *Authenticator) NewRefreshToken(user *User, opts ...RefreshOption) (*Ref func (r *RefreshToken) AccessToken(req *http.Request) (string, error) { now := time.Now() - exp, err := r.AccessTokenExpiration.Int64() - if err != nil { - return "", err - } pytnow := common.PyTimestamp(now) r.LastUsedAt = &pytnow r.LastUsedIP = &req.RemoteAddr - return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "iss": r.ID, - "iat": now, - "exp": now.Add(time.Duration(exp) * time.Second), + return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{ + Issuer: r.ID.String(), + IssuedAt: now.Unix(), + ExpiresAt: now.Add(time.Duration(r.AccessExpiration()) * time.Second).Unix(), }).SignedString([]byte(r.JWTKey)) } -func (a *Authenticator) verifyAndGetCredential(tr *TokenRequest, r *http.Request) *Credentials { - cred, success := a.authCodes.verify(tr, r) +func (a *Authenticator) verifyAndGetCredential(tr *TokenRequest) *Credentials { + cred, success := a.authCodes.get(tr) if !success { return nil } @@ -276,14 +284,14 @@ func (a *Authenticator) verifyAndGetCredential(tr *TokenRequest, r *http.Request const defaultExpiration = 15 * time.Minute func (a *Authenticator) NewAuthCode(clientID ClientID, cred *Credentials) string { - return a.authCodes.store(clientID, cred) + return a.authCodes.put(clientID, cred) } type GrantType string const ( - GTAuthorizationCode GrantType = "authorization_code" - GTRefreshToken GrantType = "refresh_token" + GrantAuthCode GrantType = "authorization_code" + GrantRefreshToken GrantType = "refresh_token" ) type ClientID common.ClientID @@ -300,14 +308,18 @@ func (ac *AuthCode) IsValid() bool { } type TokenRequest struct { - ClientID ClientID `form:"client_id"` - Code AuthCode `form:"code"` - GrantType GrantType `form:"grant_type"` + ClientID ClientID `form:"client_id"` + Code AuthCode `form:"code"` + GrantType GrantType `form:"grant_type"` + RefreshToken RefreshTokenToken `form:"refresh_token"` } const AuthFailed = "authentication failure" func (a *Authenticator) TokenHandler(c echo.Context) error { + a.Lock() + defer a.Unlock() + rq := new(TokenRequest) err := c.Bind(rq) if err != nil { @@ -315,7 +327,7 @@ func (a *Authenticator) TokenHandler(c echo.Context) error { } switch rq.GrantType { - case GTAuthorizationCode: + case GrantAuthCode: if !rq.ClientID.IsValid() { return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid client ID"}) } @@ -324,53 +336,92 @@ func (a *Authenticator) TokenHandler(c echo.Context) error { return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid code"}) } - if cred := a.verifyAndGetCredential(rq, c.Request()); cred != nil { - // TODO: success - user, err := a.getOrCreateUser(cred) - if err != nil { - log.Error().Err(err).Msg("getOrCreateUser") - return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed}) - } - - if err := user.allowedToAuth(c.Request()); err != nil { - log.Error().Err(err).Msg("allowedToAuth") - return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed}) - } - - // TODO: create a refresh token, return it and refreshtoken.AccessToken() - rt, err := a.NewRefreshToken(user, WithClientID(rq.ClientID), WithCredential(cred)) - if err != nil { - log.Error().Err(err).Msg("NewRefreshToken") - return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed}) - } - - at, err := rt.AccessToken(c.Request()) - if err != nil { - log.Error().Err(err).Msg("AccessToken") - return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed}) - } - - exp, _ := rt.AccessTokenExpiration.Int64() - - successResp := struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - RefreshToken string `json:"refresh_token"` - ExpiresIn int64 `json:"expires_in"` - HAAuthProvider string `json:"ha_auth_provider"` - }{ - AccessToken: at, - TokenType: "Bearer", - RefreshToken: rt.Token, - ExpiresIn: exp, - HAAuthProvider: cred.AuthProviderType, - } - - return c.JSON(http.StatusOK, &successResp) + cred := a.verifyAndGetCredential(rq) + if cred == nil { + return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid code"}) } - case GTRefreshToken: - return c.String(http.StatusNotImplemented, "not implemented") + + user, err := a.getOrCreateUser(cred) + if err != nil { + log.Error().Err(err).Msg("getOrCreateUser") + return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed}) + } + + if err := user.allowedToAuth(c.Request()); err != nil { + log.Error().Err(err).Msg("allowedToAuth") + return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed}) + } + + rt, err := a.NewRefreshToken(user, WithClientID(rq.ClientID), WithCredential(cred)) + if err != nil { + log.Error().Err(err).Msg("NewRefreshToken") + return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed}) + } + + at, err := rt.AccessToken(c.Request()) + if err != nil { + log.Error().Err(err).Msg("AccessToken") + return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed}) + } + + return common.NoCache(c).JSON(http.StatusOK, &struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + RefreshToken RefreshTokenToken `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + HAAuthProvider string `json:"ha_auth_provider"` + }{ + AccessToken: at, + TokenType: "Bearer", + RefreshToken: rt.Token, + ExpiresIn: rt.AccessExpiration(), + HAAuthProvider: cred.AuthProviderType, + }) + case GrantRefreshToken: + log.Debug().Interface("request", c.Request()).Interface("tokenRequest", rq).Msg("grant_type=refresh_token") + + if !rq.ClientID.IsValid() { + return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid client ID"}) + } + + if !rq.RefreshToken.IsValid() { + return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"}) + } + + rt := a.store.GetRefreshTokenByToken(rq.RefreshToken) + if rt == nil { + return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_grant"}) + } + + if rt.ClientID == nil || *rt.ClientID != rq.ClientID { + return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"}) + } + + user := a.store.User(rt.UserID) + if user == nil { + log.Error().Str("userID", string(rt.UserID)).Msg("no such user") + return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"}) + } + + if err := user.allowedToAuth(c.Request()); err != nil { + return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: err.Error()}) + } + + at, err := rt.AccessToken(c.Request()) + if err != nil { + return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: err.Error()}) + } + + return common.NoCache(c).JSON(http.StatusOK, &struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + }{ + AccessToken: at, + TokenType: "Bearer", + ExpiresIn: rt.AccessExpiration(), + }) } - return c.String(http.StatusUnauthorized, "token bad I guess") + return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"}) } diff --git a/pkg/auth/store.go b/pkg/auth/store.go index 2b13df3..da62d10 100644 --- a/pkg/auth/store.go +++ b/pkg/auth/store.go @@ -1,6 +1,7 @@ package auth import ( + "crypto/subtle" "encoding/json" "fmt" @@ -19,6 +20,7 @@ type AuthStore interface { User(UserID) *User GetCredential(provider.ProviderUser) *Credentials PutRefreshToken(*RefreshToken) (*RefreshToken, error) + GetRefreshTokenByToken(token RefreshTokenToken) *RefreshToken } type authStore struct { @@ -46,17 +48,18 @@ func strPtrEq(n1, n2 *string) bool { } func (as *authStore) GetCredential(p provider.ProviderUser) *Credentials { + var found *Credentials for _, cr := range as.Credentials { if p != nil && (p == cr.User || (p.Provider() != nil && strPtrEq(cr.AuthProviderID, p.Provider().ProviderID()) && cr.AuthProviderType == p.Provider().ProviderType() && p.Provider().EqualCreds(cr.User.UserData(), p.UserData()))) { - return cr + found = cr } } - return nil + return found } func (as *authStore) PutRefreshToken(rt *RefreshToken) (*RefreshToken, error) { @@ -76,6 +79,20 @@ func (as *authStore) PutRefreshToken(rt *RefreshToken) (*RefreshToken, error) { return rt, nil } +func (as *authStore) GetRefreshTokenByToken(token RefreshTokenToken) *RefreshToken { + var found *RefreshToken + + for _, u := range as.Users { + for _, rt := range u.RefreshTokens { + if subtle.ConstantTimeCompare([]byte(token), []byte(rt.Token)) == 1 { + found = rt + } + } + } + + return found +} + func (as *authStore) newCredential(p provider.ProviderUser) *Credentials { // XXX: probably broken prov := p.Provider()