package auth import ( "encoding/json" "fmt" "net/http" "time" "github.com/golang-jwt/jwt" "github.com/labstack/echo/v4" "github.com/rs/zerolog/log" "dynatron.me/x/blasphem/internal/common" "dynatron.me/x/blasphem/internal/generate" "dynatron.me/x/blasphem/pkg/auth/provider" ) type authCodeStore struct { s map[authCodeTuple]flowResult lastCull time.Time } type authCodeTuple struct { ClientID ClientID Code AuthCode } func (t *authCodeTuple) IsValid() bool { // TODO: more validation than this return t.Code != "" } type flowResult struct { Time time.Time Cred *Credentials } // OAuth 4.2.1 spec recommends 10 minutes const authCodeExpire = 10 * time.Minute func (f *flowResult) IsValid(now time.Time) bool { if now.After(f.Time.Add(authCodeExpire)) { return false } return true } func (ss *authCodeStore) init() { ss.s = make(map[authCodeTuple]flowResult) } const cullInterval = 5 * time.Minute func (ss *authCodeStore) cull() { if now := time.Now(); now.Sub(ss.lastCull) > cullInterval { for k, v := range ss.s { if !v.IsValid(now) { delete(ss.s, k) } } } } func (ss *authCodeStore) put(clientID ClientID, cred *Credentials) string { ss.cull() code := generate.UUID() ss.s[authCodeTuple{clientID, AuthCode(code)}] = flowResult{Time: time.Now(), Cred: cred} return code } func (ss *authCodeStore) get(tr *TokenRequest) (*Credentials, bool) { key := authCodeTuple{tr.ClientID, tr.Code} if t, hasCode := ss.s[key]; hasCode { defer delete(ss.s, key) if t.IsValid(time.Now()) { return t.Cred, true } } return nil, false } type Credentials struct { ID CredID `json:"id"` UserID UserID `json:"user_id"` AuthProviderType string `json:"auth_provider_type"` AuthProviderID *string `json:"auth_provider_id"` DataRaw *json.RawMessage `json:"data,omitempty"` User provider.ProviderUser `json:"-"` } func (cred *Credentials) MarshalJSON() ([]byte, error) { type CredAlias Credentials // alias so ø method set and we don't recurse nCd := (*CredAlias)(cred) if cred.User != nil { providerData := cred.User.UserData() if providerData != nil { b, err := json.Marshal(providerData) if err != nil { return nil, err } dr := json.RawMessage(b) nCd.DataRaw = &dr } } return json.Marshal(nCd) } type ( TokenType string RefreshTokenID string RefreshTokenToken string ) func (rti RefreshTokenID) String() string { return string(rti) } func (rti RefreshTokenToken) IsValid() bool { return rti != "" } const ( TokenTypeSystem TokenType = "system" TokenTypeNormal TokenType = "normal" TokenTypeLongLived TokenType = "long_lived_access_token" TokenTypeNone TokenType = "" ) func (tt TokenType) IsValid() bool { switch tt { case TokenTypeSystem, TokenTypeNormal, TokenTypeLongLived: return true } return false } type RefreshToken struct { ID RefreshTokenID `json:"id"` UserID UserID `json:"user_id"` ClientID *ClientID `json:"client_id"` ClientName *string `json:"client_name"` ClientIcon *string `json:"client_icon"` TokenType TokenType `json:"token_type"` CreatedAt *common.PyTimestamp `json:"created_at"` AccessTokenExpiration json.Number `json:"access_token_expiration"` Token RefreshTokenToken `json:"token"` JWTKey string `json:"jwt_key"` LastUsedAt *common.PyTimestamp `json:"last_used_at"` LastUsedIP *string `json:"last_used_ip"` CredentialID *CredID `json:"credential_id"` Version *string `json:"version"` User *User `json:"-"` } func (rt *RefreshToken) IsValid() bool { return rt.JWTKey != "" } func (rt *RefreshToken) AccessExpiration() (exp int64) { exp, err := rt.AccessTokenExpiration.Int64() if err != nil { panic(err) } return } type RefreshOption func(*RefreshToken) func WithClientID(cid ClientID) RefreshOption { return func(rt *RefreshToken) { rt.ClientID = &cid } } func WithClientName(n string) RefreshOption { return func(rt *RefreshToken) { rt.ClientName = &n } } func WithClientIcon(n string) RefreshOption { return func(rt *RefreshToken) { rt.ClientIcon = &n } } func WithTokenType(t TokenType) RefreshOption { return func(rt *RefreshToken) { rt.TokenType = t } } func WithCredential(c *Credentials) RefreshOption { return func(rt *RefreshToken) { rt.CredentialID = &c.ID } } const DefaultAccessExpiration = "1800" // json 🤮 func (a *authenticator) NewRefreshToken(user *User, opts ...RefreshOption) (*RefreshToken, error) { e := func(es string, arg ...interface{}) (*RefreshToken, error) { return nil, fmt.Errorf(es, arg...) } now := common.PyTimestamp(time.Now()) r := &RefreshToken{ ID: RefreshTokenID(generate.UUID()), UserID: user.ID, Token: RefreshTokenToken(generate.Hex(64)), JWTKey: generate.Hex(64), CreatedAt: &now, AccessTokenExpiration: DefaultAccessExpiration, User: user, } for _, opt := range opts { opt(r) } if r.TokenType == TokenTypeNone { if user.SystemGenerated { r.TokenType = TokenTypeSystem } else { r.TokenType = TokenTypeNormal } } switch { case !r.TokenType.IsValid(): return e("invalid token type") case !user.Active: return e("user is not active") case user.SystemGenerated && r.ClientID != nil: return e("system generated users cannot have refresh tokens connected to a client") case !r.TokenType.IsValid(): return e("invalid token type '%v'", r.TokenType) case user.SystemGenerated != (r.TokenType == TokenTypeSystem): return e("system generated user can only have system type refresh tokens") case r.TokenType == TokenTypeNormal && r.ClientID == nil: return e("client is required to generate a refresh token") case r.TokenType == TokenTypeLongLived && r.ClientName == nil: return e("client name is required for long-lived token") } if r.TokenType == TokenTypeLongLived { for _, lv := range user.RefreshTokens { if strPtrEq(lv.ClientName, r.ClientName) && lv.TokenType == TokenTypeLongLived { return e("client name '%v' already exists", *r.ClientName) } } } return a.store.PutRefreshToken(r) } func (r *RefreshToken) AccessToken(req *http.Request) (string, error) { now := time.Now() pytnow := common.PyTimestamp(now) r.LastUsedAt = &pytnow r.LastUsedIP = &req.RemoteAddr return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{ Issuer: r.ID.String(), IssuedAt: now.Unix(), ExpiresAt: now.Add(time.Duration(r.AccessExpiration()) * time.Second).Unix(), }).SignedString([]byte(r.JWTKey)) } func (a *authenticator) verifyAndGetCredential(tr *TokenRequest) *Credentials { cred, success := a.authCodes.get(tr) if !success { return nil } return cred } const defaultExpiration = 15 * time.Minute func (a *authenticator) NewAuthCode(clientID ClientID, cred *Credentials) string { return a.authCodes.put(clientID, cred) } type GrantType string const ( GrantAuthCode GrantType = "authorization_code" GrantRefreshToken GrantType = "refresh_token" ) type ClientID common.ClientID func (c *ClientID) IsValid() bool { // TODO: || !indieauth.VerifyClientID(rq.ClientID)? return *c != "" } type AuthCode string func (ac *AuthCode) IsValid() bool { return *ac != "" } type TokenRequest struct { ClientID ClientID `form:"client_id"` Code AuthCode `form:"code"` GrantType GrantType `form:"grant_type"` RefreshToken RefreshTokenToken `form:"refresh_token"` } const AuthFailed = "authentication failure" func (a *authenticator) TokenHandler(c echo.Context) error { a.Lock() defer a.Unlock() rq := new(TokenRequest) err := c.Bind(rq) if err != nil { return err } switch rq.GrantType { case GrantAuthCode: if !rq.ClientID.IsValid() { return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid client ID"}) } if !rq.Code.IsValid() { return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid code"}) } cred := a.verifyAndGetCredential(rq) if cred == nil { return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid code"}) } user, err := a.getOrCreateUser(cred) if err != nil { log.Error().Err(err).Msg("getOrCreateUser") return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed}) } if err := user.allowedToAuth(c.Request()); err != nil { log.Error().Err(err).Msg("allowedToAuth") return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed}) } rt, err := a.NewRefreshToken(user, WithClientID(rq.ClientID), WithCredential(cred)) if err != nil { log.Error().Err(err).Msg("NewRefreshToken") return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed}) } at, err := rt.AccessToken(c.Request()) if err != nil { log.Error().Err(err).Msg("AccessToken") return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed}) } return common.NoCache(c).JSON(http.StatusOK, &struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` RefreshToken RefreshTokenToken `json:"refresh_token"` ExpiresIn int64 `json:"expires_in"` HAAuthProvider string `json:"ha_auth_provider"` }{ AccessToken: at, TokenType: "Bearer", RefreshToken: rt.Token, ExpiresIn: rt.AccessExpiration(), HAAuthProvider: cred.AuthProviderType, }) case GrantRefreshToken: log.Debug().Interface("request", c.Request()).Interface("tokenRequest", rq).Msg("grant_type=refresh_token") if !rq.ClientID.IsValid() { return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid client ID"}) } if !rq.RefreshToken.IsValid() { return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"}) } rt := a.store.GetRefreshTokenByToken(rq.RefreshToken) if rt == nil { return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_grant"}) } if rt.ClientID == nil || *rt.ClientID != rq.ClientID { return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"}) } if err := rt.User.allowedToAuth(c.Request()); err != nil { return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: err.Error()}) } at, err := rt.AccessToken(c.Request()) if err != nil { return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: err.Error()}) } return common.NoCache(c).JSON(http.StatusOK, &struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` ExpiresIn int64 `json:"expires_in"` }{ AccessToken: at, TokenType: "Bearer", ExpiresIn: rt.AccessExpiration(), }) } return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"}) } type AccessToken string