refresh token

This commit is contained in:
Daniel Ponte 2022-12-19 13:09:01 -05:00
parent 824e54894e
commit c91cd6efca
5 changed files with 161 additions and 75 deletions

View file

@ -2,6 +2,7 @@
package common package common
import ( import (
"github.com/labstack/echo/v4"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -39,3 +40,9 @@ func RunE(c cmdOptions) func(cmd *cobra.Command, args []string) error {
return err return err
} }
} }
func NoCache(c echo.Context) echo.Context {
c.Response().Header().Set("Cache-Control", "no-store")
c.Response().Header().Set("Pragma", "no-cache")
return c
}

View file

@ -3,6 +3,7 @@ package auth
import ( import (
"errors" "errors"
"net/http" "net/http"
"sync"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -25,6 +26,7 @@ var (
) )
type Authenticator struct { type Authenticator struct {
sync.Mutex
store AuthStore store AuthStore
flows *AuthFlowManager flows *AuthFlowManager
authCodes authCodeStore authCodes authCodeStore

View file

@ -137,6 +137,9 @@ func (f *LoginFlow) progress(a *Authenticator, c echo.Context) error {
} }
func (a *Authenticator) LoginFlowDeleteHandler(c echo.Context) error { func (a *Authenticator) LoginFlowDeleteHandler(c echo.Context) error {
a.Lock()
defer a.Unlock()
flowID := flow.FlowID(c.Param("flow_id")) flowID := flow.FlowID(c.Param("flow_id"))
if flowID == "" { if flowID == "" {
@ -156,6 +159,9 @@ func setJSON(c echo.Context) {
} }
func (a *Authenticator) BeginLoginFlowHandler(c echo.Context) error { func (a *Authenticator) BeginLoginFlowHandler(c echo.Context) error {
a.Lock()
defer a.Unlock()
setJSON(c) setJSON(c)
var flowReq LoginFlowRequest var flowReq LoginFlowRequest
@ -176,6 +182,9 @@ func (a *Authenticator) BeginLoginFlowHandler(c echo.Context) error {
} }
func (a *Authenticator) LoginFlowHandler(c echo.Context) error { func (a *Authenticator) LoginFlowHandler(c echo.Context) error {
a.Lock()
defer a.Unlock()
setJSON(c) setJSON(c)
flowID := c.Param("flow_id") flowID := c.Param("flow_id")

View file

@ -33,8 +33,6 @@ func (t *authCodeTuple) IsValid() bool {
type flowResult struct { type flowResult struct {
Time time.Time Time time.Time
Cred *Credentials Cred *Credentials
// TODO: remove this comment below \/
//user provider.ProviderUser `json:"-"`
} }
// OAuth 4.2.1 spec recommends 10 minutes // OAuth 4.2.1 spec recommends 10 minutes
@ -64,7 +62,7 @@ func (ss *authCodeStore) cull() {
} }
} }
func (ss *authCodeStore) store(clientID ClientID, cred *Credentials) string { func (ss *authCodeStore) put(clientID ClientID, cred *Credentials) string {
ss.cull() ss.cull()
code := generate.UUID() code := generate.UUID()
ss.s[authCodeTuple{clientID, AuthCode(code)}] = flowResult{Time: time.Now(), Cred: cred} ss.s[authCodeTuple{clientID, AuthCode(code)}] = flowResult{Time: time.Now(), Cred: cred}
@ -72,7 +70,7 @@ func (ss *authCodeStore) store(clientID ClientID, cred *Credentials) string {
return code return code
} }
func (ss *authCodeStore) verify(tr *TokenRequest, r *http.Request) (*Credentials, bool) { func (ss *authCodeStore) get(tr *TokenRequest) (*Credentials, bool) {
key := authCodeTuple{tr.ClientID, tr.Code} key := authCodeTuple{tr.ClientID, tr.Code}
if t, hasCode := ss.s[key]; hasCode { if t, hasCode := ss.s[key]; hasCode {
defer delete(ss.s, key) defer delete(ss.s, key)
@ -115,10 +113,15 @@ func (cred *Credentials) MarshalJSON() ([]byte, error) {
} }
type ( type (
TokenType string TokenType string
RefreshTokenID string RefreshTokenID string
RefreshTokenToken string
) )
func (rti RefreshTokenID) String() string { return string(rti) }
func (rti RefreshTokenToken) IsValid() bool { return rti != "" }
const ( const (
TokenTypeSystem TokenType = "system" TokenTypeSystem TokenType = "system"
TokenTypeNormal TokenType = "normal" TokenTypeNormal TokenType = "normal"
@ -144,7 +147,7 @@ type RefreshToken struct {
TokenType TokenType `json:"token_type"` TokenType TokenType `json:"token_type"`
CreatedAt *common.PyTimestamp `json:"created_at"` CreatedAt *common.PyTimestamp `json:"created_at"`
AccessTokenExpiration json.Number `json:"access_token_expiration"` AccessTokenExpiration json.Number `json:"access_token_expiration"`
Token string `json:"token"` Token RefreshTokenToken `json:"token"`
JWTKey string `json:"jwt_key"` JWTKey string `json:"jwt_key"`
LastUsedAt *common.PyTimestamp `json:"last_used_at"` LastUsedAt *common.PyTimestamp `json:"last_used_at"`
LastUsedIP *string `json:"last_used_ip"` LastUsedIP *string `json:"last_used_ip"`
@ -156,6 +159,15 @@ func (rt *RefreshToken) IsValid() bool {
return rt.JWTKey != "" return rt.JWTKey != ""
} }
func (rt *RefreshToken) AccessExpiration() (exp int64) {
exp, err := rt.AccessTokenExpiration.Int64()
if err != nil {
panic(err)
}
return
}
type RefreshOption func(*RefreshToken) type RefreshOption func(*RefreshToken)
func WithClientID(cid ClientID) RefreshOption { func WithClientID(cid ClientID) RefreshOption {
@ -191,8 +203,8 @@ func WithCredential(c *Credentials) RefreshOption {
const DefaultAccessExpiration = "1800" const DefaultAccessExpiration = "1800"
func (a *Authenticator) NewRefreshToken(user *User, opts ...RefreshOption) (*RefreshToken, error) { func (a *Authenticator) NewRefreshToken(user *User, opts ...RefreshOption) (*RefreshToken, error) {
e := func(es string, a ...interface{}) (*RefreshToken, error) { e := func(es string, arg ...interface{}) (*RefreshToken, error) {
return nil, fmt.Errorf(es, a...) return nil, fmt.Errorf(es, arg...)
} }
now := common.PyTimestamp(time.Now()) now := common.PyTimestamp(time.Now())
@ -200,7 +212,7 @@ func (a *Authenticator) NewRefreshToken(user *User, opts ...RefreshOption) (*Ref
r := &RefreshToken{ r := &RefreshToken{
ID: RefreshTokenID(generate.UUID()), ID: RefreshTokenID(generate.UUID()),
UserID: user.ID, UserID: user.ID,
Token: generate.Hex(64), Token: RefreshTokenToken(generate.Hex(64)),
JWTKey: generate.Hex(64), JWTKey: generate.Hex(64),
CreatedAt: &now, CreatedAt: &now,
AccessTokenExpiration: DefaultAccessExpiration, AccessTokenExpiration: DefaultAccessExpiration,
@ -248,24 +260,20 @@ func (a *Authenticator) NewRefreshToken(user *User, opts ...RefreshOption) (*Ref
func (r *RefreshToken) AccessToken(req *http.Request) (string, error) { func (r *RefreshToken) AccessToken(req *http.Request) (string, error) {
now := time.Now() now := time.Now()
exp, err := r.AccessTokenExpiration.Int64()
if err != nil {
return "", err
}
pytnow := common.PyTimestamp(now) pytnow := common.PyTimestamp(now)
r.LastUsedAt = &pytnow r.LastUsedAt = &pytnow
r.LastUsedIP = &req.RemoteAddr r.LastUsedIP = &req.RemoteAddr
return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{
"iss": r.ID, Issuer: r.ID.String(),
"iat": now, IssuedAt: now.Unix(),
"exp": now.Add(time.Duration(exp) * time.Second), ExpiresAt: now.Add(time.Duration(r.AccessExpiration()) * time.Second).Unix(),
}).SignedString([]byte(r.JWTKey)) }).SignedString([]byte(r.JWTKey))
} }
func (a *Authenticator) verifyAndGetCredential(tr *TokenRequest, r *http.Request) *Credentials { func (a *Authenticator) verifyAndGetCredential(tr *TokenRequest) *Credentials {
cred, success := a.authCodes.verify(tr, r) cred, success := a.authCodes.get(tr)
if !success { if !success {
return nil return nil
} }
@ -276,14 +284,14 @@ func (a *Authenticator) verifyAndGetCredential(tr *TokenRequest, r *http.Request
const defaultExpiration = 15 * time.Minute const defaultExpiration = 15 * time.Minute
func (a *Authenticator) NewAuthCode(clientID ClientID, cred *Credentials) string { func (a *Authenticator) NewAuthCode(clientID ClientID, cred *Credentials) string {
return a.authCodes.store(clientID, cred) return a.authCodes.put(clientID, cred)
} }
type GrantType string type GrantType string
const ( const (
GTAuthorizationCode GrantType = "authorization_code" GrantAuthCode GrantType = "authorization_code"
GTRefreshToken GrantType = "refresh_token" GrantRefreshToken GrantType = "refresh_token"
) )
type ClientID common.ClientID type ClientID common.ClientID
@ -300,14 +308,18 @@ func (ac *AuthCode) IsValid() bool {
} }
type TokenRequest struct { type TokenRequest struct {
ClientID ClientID `form:"client_id"` ClientID ClientID `form:"client_id"`
Code AuthCode `form:"code"` Code AuthCode `form:"code"`
GrantType GrantType `form:"grant_type"` GrantType GrantType `form:"grant_type"`
RefreshToken RefreshTokenToken `form:"refresh_token"`
} }
const AuthFailed = "authentication failure" const AuthFailed = "authentication failure"
func (a *Authenticator) TokenHandler(c echo.Context) error { func (a *Authenticator) TokenHandler(c echo.Context) error {
a.Lock()
defer a.Unlock()
rq := new(TokenRequest) rq := new(TokenRequest)
err := c.Bind(rq) err := c.Bind(rq)
if err != nil { if err != nil {
@ -315,7 +327,7 @@ func (a *Authenticator) TokenHandler(c echo.Context) error {
} }
switch rq.GrantType { switch rq.GrantType {
case GTAuthorizationCode: case GrantAuthCode:
if !rq.ClientID.IsValid() { if !rq.ClientID.IsValid() {
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid client ID"}) return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid client ID"})
} }
@ -324,53 +336,92 @@ func (a *Authenticator) TokenHandler(c echo.Context) error {
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid code"}) return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid code"})
} }
if cred := a.verifyAndGetCredential(rq, c.Request()); cred != nil { cred := a.verifyAndGetCredential(rq)
// TODO: success if cred == nil {
user, err := a.getOrCreateUser(cred) return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid code"})
if err != nil {
log.Error().Err(err).Msg("getOrCreateUser")
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
}
if err := user.allowedToAuth(c.Request()); err != nil {
log.Error().Err(err).Msg("allowedToAuth")
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
}
// TODO: create a refresh token, return it and refreshtoken.AccessToken()
rt, err := a.NewRefreshToken(user, WithClientID(rq.ClientID), WithCredential(cred))
if err != nil {
log.Error().Err(err).Msg("NewRefreshToken")
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
}
at, err := rt.AccessToken(c.Request())
if err != nil {
log.Error().Err(err).Msg("AccessToken")
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
}
exp, _ := rt.AccessTokenExpiration.Int64()
successResp := struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
HAAuthProvider string `json:"ha_auth_provider"`
}{
AccessToken: at,
TokenType: "Bearer",
RefreshToken: rt.Token,
ExpiresIn: exp,
HAAuthProvider: cred.AuthProviderType,
}
return c.JSON(http.StatusOK, &successResp)
} }
case GTRefreshToken:
return c.String(http.StatusNotImplemented, "not implemented") user, err := a.getOrCreateUser(cred)
if err != nil {
log.Error().Err(err).Msg("getOrCreateUser")
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
}
if err := user.allowedToAuth(c.Request()); err != nil {
log.Error().Err(err).Msg("allowedToAuth")
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
}
rt, err := a.NewRefreshToken(user, WithClientID(rq.ClientID), WithCredential(cred))
if err != nil {
log.Error().Err(err).Msg("NewRefreshToken")
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
}
at, err := rt.AccessToken(c.Request())
if err != nil {
log.Error().Err(err).Msg("AccessToken")
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
}
return common.NoCache(c).JSON(http.StatusOK, &struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
RefreshToken RefreshTokenToken `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
HAAuthProvider string `json:"ha_auth_provider"`
}{
AccessToken: at,
TokenType: "Bearer",
RefreshToken: rt.Token,
ExpiresIn: rt.AccessExpiration(),
HAAuthProvider: cred.AuthProviderType,
})
case GrantRefreshToken:
log.Debug().Interface("request", c.Request()).Interface("tokenRequest", rq).Msg("grant_type=refresh_token")
if !rq.ClientID.IsValid() {
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid client ID"})
}
if !rq.RefreshToken.IsValid() {
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"})
}
rt := a.store.GetRefreshTokenByToken(rq.RefreshToken)
if rt == nil {
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_grant"})
}
if rt.ClientID == nil || *rt.ClientID != rq.ClientID {
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"})
}
user := a.store.User(rt.UserID)
if user == nil {
log.Error().Str("userID", string(rt.UserID)).Msg("no such user")
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"})
}
if err := user.allowedToAuth(c.Request()); err != nil {
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: err.Error()})
}
at, err := rt.AccessToken(c.Request())
if err != nil {
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: err.Error()})
}
return common.NoCache(c).JSON(http.StatusOK, &struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
}{
AccessToken: at,
TokenType: "Bearer",
ExpiresIn: rt.AccessExpiration(),
})
} }
return c.String(http.StatusUnauthorized, "token bad I guess") return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"})
} }

View file

@ -1,6 +1,7 @@
package auth package auth
import ( import (
"crypto/subtle"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -19,6 +20,7 @@ type AuthStore interface {
User(UserID) *User User(UserID) *User
GetCredential(provider.ProviderUser) *Credentials GetCredential(provider.ProviderUser) *Credentials
PutRefreshToken(*RefreshToken) (*RefreshToken, error) PutRefreshToken(*RefreshToken) (*RefreshToken, error)
GetRefreshTokenByToken(token RefreshTokenToken) *RefreshToken
} }
type authStore struct { type authStore struct {
@ -46,17 +48,18 @@ func strPtrEq(n1, n2 *string) bool {
} }
func (as *authStore) GetCredential(p provider.ProviderUser) *Credentials { func (as *authStore) GetCredential(p provider.ProviderUser) *Credentials {
var found *Credentials
for _, cr := range as.Credentials { for _, cr := range as.Credentials {
if p != nil && (p == cr.User || if p != nil && (p == cr.User ||
(p.Provider() != nil && (p.Provider() != nil &&
strPtrEq(cr.AuthProviderID, p.Provider().ProviderID()) && strPtrEq(cr.AuthProviderID, p.Provider().ProviderID()) &&
cr.AuthProviderType == p.Provider().ProviderType() && cr.AuthProviderType == p.Provider().ProviderType() &&
p.Provider().EqualCreds(cr.User.UserData(), p.UserData()))) { p.Provider().EqualCreds(cr.User.UserData(), p.UserData()))) {
return cr found = cr
} }
} }
return nil return found
} }
func (as *authStore) PutRefreshToken(rt *RefreshToken) (*RefreshToken, error) { func (as *authStore) PutRefreshToken(rt *RefreshToken) (*RefreshToken, error) {
@ -76,6 +79,20 @@ func (as *authStore) PutRefreshToken(rt *RefreshToken) (*RefreshToken, error) {
return rt, nil return rt, nil
} }
func (as *authStore) GetRefreshTokenByToken(token RefreshTokenToken) *RefreshToken {
var found *RefreshToken
for _, u := range as.Users {
for _, rt := range u.RefreshTokens {
if subtle.ConstantTimeCompare([]byte(token), []byte(rt.Token)) == 1 {
found = rt
}
}
}
return found
}
func (as *authStore) newCredential(p provider.ProviderUser) *Credentials { func (as *authStore) newCredential(p provider.ProviderUser) *Credentials {
// XXX: probably broken // XXX: probably broken
prov := p.Provider() prov := p.Provider()