refresh token
This commit is contained in:
parent
824e54894e
commit
c91cd6efca
5 changed files with 161 additions and 75 deletions
|
@ -2,6 +2,7 @@
|
||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -39,3 +40,9 @@ func RunE(c cmdOptions) func(cmd *cobra.Command, args []string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NoCache(c echo.Context) echo.Context {
|
||||||
|
c.Response().Header().Set("Cache-Control", "no-store")
|
||||||
|
c.Response().Header().Set("Pragma", "no-cache")
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package auth
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
@ -25,6 +26,7 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Authenticator struct {
|
type Authenticator struct {
|
||||||
|
sync.Mutex
|
||||||
store AuthStore
|
store AuthStore
|
||||||
flows *AuthFlowManager
|
flows *AuthFlowManager
|
||||||
authCodes authCodeStore
|
authCodes authCodeStore
|
||||||
|
|
|
@ -137,6 +137,9 @@ func (f *LoginFlow) progress(a *Authenticator, c echo.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticator) LoginFlowDeleteHandler(c echo.Context) error {
|
func (a *Authenticator) LoginFlowDeleteHandler(c echo.Context) error {
|
||||||
|
a.Lock()
|
||||||
|
defer a.Unlock()
|
||||||
|
|
||||||
flowID := flow.FlowID(c.Param("flow_id"))
|
flowID := flow.FlowID(c.Param("flow_id"))
|
||||||
|
|
||||||
if flowID == "" {
|
if flowID == "" {
|
||||||
|
@ -156,6 +159,9 @@ func setJSON(c echo.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticator) BeginLoginFlowHandler(c echo.Context) error {
|
func (a *Authenticator) BeginLoginFlowHandler(c echo.Context) error {
|
||||||
|
a.Lock()
|
||||||
|
defer a.Unlock()
|
||||||
|
|
||||||
setJSON(c)
|
setJSON(c)
|
||||||
|
|
||||||
var flowReq LoginFlowRequest
|
var flowReq LoginFlowRequest
|
||||||
|
@ -176,6 +182,9 @@ func (a *Authenticator) BeginLoginFlowHandler(c echo.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticator) LoginFlowHandler(c echo.Context) error {
|
func (a *Authenticator) LoginFlowHandler(c echo.Context) error {
|
||||||
|
a.Lock()
|
||||||
|
defer a.Unlock()
|
||||||
|
|
||||||
setJSON(c)
|
setJSON(c)
|
||||||
|
|
||||||
flowID := c.Param("flow_id")
|
flowID := c.Param("flow_id")
|
||||||
|
|
|
@ -33,8 +33,6 @@ func (t *authCodeTuple) IsValid() bool {
|
||||||
type flowResult struct {
|
type flowResult struct {
|
||||||
Time time.Time
|
Time time.Time
|
||||||
Cred *Credentials
|
Cred *Credentials
|
||||||
// TODO: remove this comment below \/
|
|
||||||
//user provider.ProviderUser `json:"-"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// OAuth 4.2.1 spec recommends 10 minutes
|
// OAuth 4.2.1 spec recommends 10 minutes
|
||||||
|
@ -64,7 +62,7 @@ func (ss *authCodeStore) cull() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ss *authCodeStore) store(clientID ClientID, cred *Credentials) string {
|
func (ss *authCodeStore) put(clientID ClientID, cred *Credentials) string {
|
||||||
ss.cull()
|
ss.cull()
|
||||||
code := generate.UUID()
|
code := generate.UUID()
|
||||||
ss.s[authCodeTuple{clientID, AuthCode(code)}] = flowResult{Time: time.Now(), Cred: cred}
|
ss.s[authCodeTuple{clientID, AuthCode(code)}] = flowResult{Time: time.Now(), Cred: cred}
|
||||||
|
@ -72,7 +70,7 @@ func (ss *authCodeStore) store(clientID ClientID, cred *Credentials) string {
|
||||||
return code
|
return code
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ss *authCodeStore) verify(tr *TokenRequest, r *http.Request) (*Credentials, bool) {
|
func (ss *authCodeStore) get(tr *TokenRequest) (*Credentials, bool) {
|
||||||
key := authCodeTuple{tr.ClientID, tr.Code}
|
key := authCodeTuple{tr.ClientID, tr.Code}
|
||||||
if t, hasCode := ss.s[key]; hasCode {
|
if t, hasCode := ss.s[key]; hasCode {
|
||||||
defer delete(ss.s, key)
|
defer delete(ss.s, key)
|
||||||
|
@ -115,10 +113,15 @@ func (cred *Credentials) MarshalJSON() ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
type (
|
type (
|
||||||
TokenType string
|
TokenType string
|
||||||
RefreshTokenID string
|
RefreshTokenID string
|
||||||
|
RefreshTokenToken string
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (rti RefreshTokenID) String() string { return string(rti) }
|
||||||
|
|
||||||
|
func (rti RefreshTokenToken) IsValid() bool { return rti != "" }
|
||||||
|
|
||||||
const (
|
const (
|
||||||
TokenTypeSystem TokenType = "system"
|
TokenTypeSystem TokenType = "system"
|
||||||
TokenTypeNormal TokenType = "normal"
|
TokenTypeNormal TokenType = "normal"
|
||||||
|
@ -144,7 +147,7 @@ type RefreshToken struct {
|
||||||
TokenType TokenType `json:"token_type"`
|
TokenType TokenType `json:"token_type"`
|
||||||
CreatedAt *common.PyTimestamp `json:"created_at"`
|
CreatedAt *common.PyTimestamp `json:"created_at"`
|
||||||
AccessTokenExpiration json.Number `json:"access_token_expiration"`
|
AccessTokenExpiration json.Number `json:"access_token_expiration"`
|
||||||
Token string `json:"token"`
|
Token RefreshTokenToken `json:"token"`
|
||||||
JWTKey string `json:"jwt_key"`
|
JWTKey string `json:"jwt_key"`
|
||||||
LastUsedAt *common.PyTimestamp `json:"last_used_at"`
|
LastUsedAt *common.PyTimestamp `json:"last_used_at"`
|
||||||
LastUsedIP *string `json:"last_used_ip"`
|
LastUsedIP *string `json:"last_used_ip"`
|
||||||
|
@ -156,6 +159,15 @@ func (rt *RefreshToken) IsValid() bool {
|
||||||
return rt.JWTKey != ""
|
return rt.JWTKey != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rt *RefreshToken) AccessExpiration() (exp int64) {
|
||||||
|
exp, err := rt.AccessTokenExpiration.Int64()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
type RefreshOption func(*RefreshToken)
|
type RefreshOption func(*RefreshToken)
|
||||||
|
|
||||||
func WithClientID(cid ClientID) RefreshOption {
|
func WithClientID(cid ClientID) RefreshOption {
|
||||||
|
@ -191,8 +203,8 @@ func WithCredential(c *Credentials) RefreshOption {
|
||||||
const DefaultAccessExpiration = "1800"
|
const DefaultAccessExpiration = "1800"
|
||||||
|
|
||||||
func (a *Authenticator) NewRefreshToken(user *User, opts ...RefreshOption) (*RefreshToken, error) {
|
func (a *Authenticator) NewRefreshToken(user *User, opts ...RefreshOption) (*RefreshToken, error) {
|
||||||
e := func(es string, a ...interface{}) (*RefreshToken, error) {
|
e := func(es string, arg ...interface{}) (*RefreshToken, error) {
|
||||||
return nil, fmt.Errorf(es, a...)
|
return nil, fmt.Errorf(es, arg...)
|
||||||
}
|
}
|
||||||
|
|
||||||
now := common.PyTimestamp(time.Now())
|
now := common.PyTimestamp(time.Now())
|
||||||
|
@ -200,7 +212,7 @@ func (a *Authenticator) NewRefreshToken(user *User, opts ...RefreshOption) (*Ref
|
||||||
r := &RefreshToken{
|
r := &RefreshToken{
|
||||||
ID: RefreshTokenID(generate.UUID()),
|
ID: RefreshTokenID(generate.UUID()),
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
Token: generate.Hex(64),
|
Token: RefreshTokenToken(generate.Hex(64)),
|
||||||
JWTKey: generate.Hex(64),
|
JWTKey: generate.Hex(64),
|
||||||
CreatedAt: &now,
|
CreatedAt: &now,
|
||||||
AccessTokenExpiration: DefaultAccessExpiration,
|
AccessTokenExpiration: DefaultAccessExpiration,
|
||||||
|
@ -248,24 +260,20 @@ func (a *Authenticator) NewRefreshToken(user *User, opts ...RefreshOption) (*Ref
|
||||||
|
|
||||||
func (r *RefreshToken) AccessToken(req *http.Request) (string, error) {
|
func (r *RefreshToken) AccessToken(req *http.Request) (string, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
exp, err := r.AccessTokenExpiration.Int64()
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
pytnow := common.PyTimestamp(now)
|
pytnow := common.PyTimestamp(now)
|
||||||
r.LastUsedAt = &pytnow
|
r.LastUsedAt = &pytnow
|
||||||
r.LastUsedIP = &req.RemoteAddr
|
r.LastUsedIP = &req.RemoteAddr
|
||||||
|
|
||||||
return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{
|
||||||
"iss": r.ID,
|
Issuer: r.ID.String(),
|
||||||
"iat": now,
|
IssuedAt: now.Unix(),
|
||||||
"exp": now.Add(time.Duration(exp) * time.Second),
|
ExpiresAt: now.Add(time.Duration(r.AccessExpiration()) * time.Second).Unix(),
|
||||||
}).SignedString([]byte(r.JWTKey))
|
}).SignedString([]byte(r.JWTKey))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticator) verifyAndGetCredential(tr *TokenRequest, r *http.Request) *Credentials {
|
func (a *Authenticator) verifyAndGetCredential(tr *TokenRequest) *Credentials {
|
||||||
cred, success := a.authCodes.verify(tr, r)
|
cred, success := a.authCodes.get(tr)
|
||||||
if !success {
|
if !success {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -276,14 +284,14 @@ func (a *Authenticator) verifyAndGetCredential(tr *TokenRequest, r *http.Request
|
||||||
const defaultExpiration = 15 * time.Minute
|
const defaultExpiration = 15 * time.Minute
|
||||||
|
|
||||||
func (a *Authenticator) NewAuthCode(clientID ClientID, cred *Credentials) string {
|
func (a *Authenticator) NewAuthCode(clientID ClientID, cred *Credentials) string {
|
||||||
return a.authCodes.store(clientID, cred)
|
return a.authCodes.put(clientID, cred)
|
||||||
}
|
}
|
||||||
|
|
||||||
type GrantType string
|
type GrantType string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
GTAuthorizationCode GrantType = "authorization_code"
|
GrantAuthCode GrantType = "authorization_code"
|
||||||
GTRefreshToken GrantType = "refresh_token"
|
GrantRefreshToken GrantType = "refresh_token"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ClientID common.ClientID
|
type ClientID common.ClientID
|
||||||
|
@ -300,14 +308,18 @@ func (ac *AuthCode) IsValid() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenRequest struct {
|
type TokenRequest struct {
|
||||||
ClientID ClientID `form:"client_id"`
|
ClientID ClientID `form:"client_id"`
|
||||||
Code AuthCode `form:"code"`
|
Code AuthCode `form:"code"`
|
||||||
GrantType GrantType `form:"grant_type"`
|
GrantType GrantType `form:"grant_type"`
|
||||||
|
RefreshToken RefreshTokenToken `form:"refresh_token"`
|
||||||
}
|
}
|
||||||
|
|
||||||
const AuthFailed = "authentication failure"
|
const AuthFailed = "authentication failure"
|
||||||
|
|
||||||
func (a *Authenticator) TokenHandler(c echo.Context) error {
|
func (a *Authenticator) TokenHandler(c echo.Context) error {
|
||||||
|
a.Lock()
|
||||||
|
defer a.Unlock()
|
||||||
|
|
||||||
rq := new(TokenRequest)
|
rq := new(TokenRequest)
|
||||||
err := c.Bind(rq)
|
err := c.Bind(rq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -315,7 +327,7 @@ func (a *Authenticator) TokenHandler(c echo.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
switch rq.GrantType {
|
switch rq.GrantType {
|
||||||
case GTAuthorizationCode:
|
case GrantAuthCode:
|
||||||
if !rq.ClientID.IsValid() {
|
if !rq.ClientID.IsValid() {
|
||||||
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid client ID"})
|
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid client ID"})
|
||||||
}
|
}
|
||||||
|
@ -324,53 +336,92 @@ func (a *Authenticator) TokenHandler(c echo.Context) error {
|
||||||
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid code"})
|
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid code"})
|
||||||
}
|
}
|
||||||
|
|
||||||
if cred := a.verifyAndGetCredential(rq, c.Request()); cred != nil {
|
cred := a.verifyAndGetCredential(rq)
|
||||||
// TODO: success
|
if cred == nil {
|
||||||
user, err := a.getOrCreateUser(cred)
|
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid code"})
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("getOrCreateUser")
|
|
||||||
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := user.allowedToAuth(c.Request()); err != nil {
|
|
||||||
log.Error().Err(err).Msg("allowedToAuth")
|
|
||||||
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: create a refresh token, return it and refreshtoken.AccessToken()
|
|
||||||
rt, err := a.NewRefreshToken(user, WithClientID(rq.ClientID), WithCredential(cred))
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("NewRefreshToken")
|
|
||||||
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
|
|
||||||
}
|
|
||||||
|
|
||||||
at, err := rt.AccessToken(c.Request())
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("AccessToken")
|
|
||||||
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
|
|
||||||
}
|
|
||||||
|
|
||||||
exp, _ := rt.AccessTokenExpiration.Int64()
|
|
||||||
|
|
||||||
successResp := struct {
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
TokenType string `json:"token_type"`
|
|
||||||
RefreshToken string `json:"refresh_token"`
|
|
||||||
ExpiresIn int64 `json:"expires_in"`
|
|
||||||
HAAuthProvider string `json:"ha_auth_provider"`
|
|
||||||
}{
|
|
||||||
AccessToken: at,
|
|
||||||
TokenType: "Bearer",
|
|
||||||
RefreshToken: rt.Token,
|
|
||||||
ExpiresIn: exp,
|
|
||||||
HAAuthProvider: cred.AuthProviderType,
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, &successResp)
|
|
||||||
}
|
}
|
||||||
case GTRefreshToken:
|
|
||||||
return c.String(http.StatusNotImplemented, "not implemented")
|
user, err := a.getOrCreateUser(cred)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("getOrCreateUser")
|
||||||
|
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := user.allowedToAuth(c.Request()); err != nil {
|
||||||
|
log.Error().Err(err).Msg("allowedToAuth")
|
||||||
|
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
|
||||||
|
}
|
||||||
|
|
||||||
|
rt, err := a.NewRefreshToken(user, WithClientID(rq.ClientID), WithCredential(cred))
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("NewRefreshToken")
|
||||||
|
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
|
||||||
|
}
|
||||||
|
|
||||||
|
at, err := rt.AccessToken(c.Request())
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("AccessToken")
|
||||||
|
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
|
||||||
|
}
|
||||||
|
|
||||||
|
return common.NoCache(c).JSON(http.StatusOK, &struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
RefreshToken RefreshTokenToken `json:"refresh_token"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
HAAuthProvider string `json:"ha_auth_provider"`
|
||||||
|
}{
|
||||||
|
AccessToken: at,
|
||||||
|
TokenType: "Bearer",
|
||||||
|
RefreshToken: rt.Token,
|
||||||
|
ExpiresIn: rt.AccessExpiration(),
|
||||||
|
HAAuthProvider: cred.AuthProviderType,
|
||||||
|
})
|
||||||
|
case GrantRefreshToken:
|
||||||
|
log.Debug().Interface("request", c.Request()).Interface("tokenRequest", rq).Msg("grant_type=refresh_token")
|
||||||
|
|
||||||
|
if !rq.ClientID.IsValid() {
|
||||||
|
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid client ID"})
|
||||||
|
}
|
||||||
|
|
||||||
|
if !rq.RefreshToken.IsValid() {
|
||||||
|
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"})
|
||||||
|
}
|
||||||
|
|
||||||
|
rt := a.store.GetRefreshTokenByToken(rq.RefreshToken)
|
||||||
|
if rt == nil {
|
||||||
|
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_grant"})
|
||||||
|
}
|
||||||
|
|
||||||
|
if rt.ClientID == nil || *rt.ClientID != rq.ClientID {
|
||||||
|
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"})
|
||||||
|
}
|
||||||
|
|
||||||
|
user := a.store.User(rt.UserID)
|
||||||
|
if user == nil {
|
||||||
|
log.Error().Str("userID", string(rt.UserID)).Msg("no such user")
|
||||||
|
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := user.allowedToAuth(c.Request()); err != nil {
|
||||||
|
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: err.Error()})
|
||||||
|
}
|
||||||
|
|
||||||
|
at, err := rt.AccessToken(c.Request())
|
||||||
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: err.Error()})
|
||||||
|
}
|
||||||
|
|
||||||
|
return common.NoCache(c).JSON(http.StatusOK, &struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
}{
|
||||||
|
AccessToken: at,
|
||||||
|
TokenType: "Bearer",
|
||||||
|
ExpiresIn: rt.AccessExpiration(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.String(http.StatusUnauthorized, "token bad I guess")
|
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"})
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/subtle"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
@ -19,6 +20,7 @@ type AuthStore interface {
|
||||||
User(UserID) *User
|
User(UserID) *User
|
||||||
GetCredential(provider.ProviderUser) *Credentials
|
GetCredential(provider.ProviderUser) *Credentials
|
||||||
PutRefreshToken(*RefreshToken) (*RefreshToken, error)
|
PutRefreshToken(*RefreshToken) (*RefreshToken, error)
|
||||||
|
GetRefreshTokenByToken(token RefreshTokenToken) *RefreshToken
|
||||||
}
|
}
|
||||||
|
|
||||||
type authStore struct {
|
type authStore struct {
|
||||||
|
@ -46,17 +48,18 @@ func strPtrEq(n1, n2 *string) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (as *authStore) GetCredential(p provider.ProviderUser) *Credentials {
|
func (as *authStore) GetCredential(p provider.ProviderUser) *Credentials {
|
||||||
|
var found *Credentials
|
||||||
for _, cr := range as.Credentials {
|
for _, cr := range as.Credentials {
|
||||||
if p != nil && (p == cr.User ||
|
if p != nil && (p == cr.User ||
|
||||||
(p.Provider() != nil &&
|
(p.Provider() != nil &&
|
||||||
strPtrEq(cr.AuthProviderID, p.Provider().ProviderID()) &&
|
strPtrEq(cr.AuthProviderID, p.Provider().ProviderID()) &&
|
||||||
cr.AuthProviderType == p.Provider().ProviderType() &&
|
cr.AuthProviderType == p.Provider().ProviderType() &&
|
||||||
p.Provider().EqualCreds(cr.User.UserData(), p.UserData()))) {
|
p.Provider().EqualCreds(cr.User.UserData(), p.UserData()))) {
|
||||||
return cr
|
found = cr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return found
|
||||||
}
|
}
|
||||||
|
|
||||||
func (as *authStore) PutRefreshToken(rt *RefreshToken) (*RefreshToken, error) {
|
func (as *authStore) PutRefreshToken(rt *RefreshToken) (*RefreshToken, error) {
|
||||||
|
@ -76,6 +79,20 @@ func (as *authStore) PutRefreshToken(rt *RefreshToken) (*RefreshToken, error) {
|
||||||
return rt, nil
|
return rt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (as *authStore) GetRefreshTokenByToken(token RefreshTokenToken) *RefreshToken {
|
||||||
|
var found *RefreshToken
|
||||||
|
|
||||||
|
for _, u := range as.Users {
|
||||||
|
for _, rt := range u.RefreshTokens {
|
||||||
|
if subtle.ConstantTimeCompare([]byte(token), []byte(rt.Token)) == 1 {
|
||||||
|
found = rt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return found
|
||||||
|
}
|
||||||
|
|
||||||
func (as *authStore) newCredential(p provider.ProviderUser) *Credentials {
|
func (as *authStore) newCredential(p provider.ProviderUser) *Credentials {
|
||||||
// XXX: probably broken
|
// XXX: probably broken
|
||||||
prov := p.Provider()
|
prov := p.Provider()
|
||||||
|
|
Loading…
Reference in a new issue