blasphem/pkg/auth/session.go

432 lines
11 KiB
Go
Raw Normal View History

2022-10-26 19:13:50 -04:00
package auth
import (
	"encoding/json"
	"fmt"
	"net"
	"net/http"
	"time"

	"github.com/golang-jwt/jwt"
	"github.com/labstack/echo/v4"
	"github.com/rs/zerolog/log"

	"dynatron.me/x/blasphem/internal/common"
	"dynatron.me/x/blasphem/internal/generate"
	"dynatron.me/x/blasphem/pkg/auth/provider"
)
2022-12-18 09:55:08 -05:00
type authCodeStore struct {
s map[authCodeTuple]flowResult
2022-10-26 19:13:50 -04:00
lastCull time.Time
}
2022-12-18 09:55:08 -05:00
type authCodeTuple struct {
ClientID ClientID
2022-12-18 21:26:34 -05:00
Code AuthCode
2022-12-18 09:55:08 -05:00
}
2022-10-26 19:13:50 -04:00
2022-12-18 09:55:08 -05:00
func (t *authCodeTuple) IsValid() bool {
2022-11-11 18:02:52 -05:00
// TODO: more validation than this
2022-12-18 09:55:08 -05:00
return t.Code != ""
}
type flowResult struct {
Time time.Time
Cred *Credentials
2022-11-11 18:02:52 -05:00
}
2022-12-18 09:55:08 -05:00
// authCodeExpire is how long an authorization code stays redeemable.
// RFC 6749 §4.1.2 recommends a maximum lifetime of 10 minutes.
const authCodeExpire time.Duration = 10 * time.Minute
2022-11-11 18:02:52 -05:00
2022-12-18 09:55:08 -05:00
func (f *flowResult) IsValid(now time.Time) bool {
if now.After(f.Time.Add(authCodeExpire)) {
return false
}
return true
2022-10-26 19:13:50 -04:00
}
2022-12-18 09:55:08 -05:00
func (ss *authCodeStore) init() {
ss.s = make(map[authCodeTuple]flowResult)
2022-10-26 19:13:50 -04:00
}
// cullInterval is the minimum time between sweeps of expired auth codes.
const cullInterval time.Duration = 5 * time.Minute
2022-12-18 09:55:08 -05:00
func (ss *authCodeStore) cull() {
2022-10-26 19:13:50 -04:00
if now := time.Now(); now.Sub(ss.lastCull) > cullInterval {
for k, v := range ss.s {
2022-12-18 09:55:08 -05:00
if !v.IsValid(now) {
2022-10-26 19:13:50 -04:00
delete(ss.s, k)
}
}
}
}
2022-12-19 13:09:01 -05:00
func (ss *authCodeStore) put(clientID ClientID, cred *Credentials) string {
2022-10-26 19:13:50 -04:00
ss.cull()
2022-12-18 09:55:08 -05:00
code := generate.UUID()
ss.s[authCodeTuple{clientID, AuthCode(code)}] = flowResult{Time: time.Now(), Cred: cred}
return code
2022-10-26 19:13:50 -04:00
}
2022-12-19 13:09:01 -05:00
func (ss *authCodeStore) get(tr *TokenRequest) (*Credentials, bool) {
2022-12-18 09:55:08 -05:00
key := authCodeTuple{tr.ClientID, tr.Code}
if t, hasCode := ss.s[key]; hasCode {
defer delete(ss.s, key)
if t.IsValid(time.Now()) {
return t.Cred, true
2022-10-26 19:43:51 -04:00
}
}
2022-11-11 18:02:52 -05:00
return nil, false
}
2022-12-18 09:55:08 -05:00
type Credentials struct {
2022-12-18 21:26:34 -05:00
ID CredID `json:"id"`
UserID UserID `json:"user_id"`
AuthProviderType string `json:"auth_provider_type"`
AuthProviderID *string `json:"auth_provider_id"`
DataRaw *json.RawMessage `json:"data,omitempty"`
User provider.ProviderUser `json:"-"`
2022-11-12 15:56:17 -05:00
}
2022-12-18 09:55:08 -05:00
func (cred *Credentials) MarshalJSON() ([]byte, error) {
type CredAlias Credentials // alias so ø method set and we don't recurse
2022-11-13 09:05:09 -05:00
nCd := (*CredAlias)(cred)
2022-11-12 15:56:17 -05:00
2022-12-18 21:26:34 -05:00
if cred.User != nil {
providerData := cred.User.UserData()
2022-11-13 11:55:10 -05:00
if providerData != nil {
b, err := json.Marshal(providerData)
if err != nil {
return nil, err
}
2022-11-13 09:05:09 -05:00
2022-11-13 11:55:10 -05:00
dr := json.RawMessage(b)
nCd.DataRaw = &dr
}
2022-11-12 15:56:17 -05:00
}
2022-11-13 09:05:09 -05:00
return json.Marshal(nCd)
2022-11-11 18:02:52 -05:00
}
2022-12-19 02:42:01 -05:00
type (
	// TokenType classifies a refresh token; see the TokenType* constants.
	TokenType string
	// RefreshTokenID identifies a refresh token.
	RefreshTokenID string
	// RefreshTokenToken is the token value a client presents in a
	// refresh_token grant.
	RefreshTokenToken string
)

// String returns the ID as a plain string.
func (rti RefreshTokenID) String() string { return string(rti) }

// IsValid reports whether the token value is non-empty.
func (rtt RefreshTokenToken) IsValid() bool { return len(rtt) > 0 }

// Refresh token types. TokenTypeNone means "not chosen yet" and is not valid
// on a finished token.
const (
	TokenTypeSystem    TokenType = "system"
	TokenTypeNormal    TokenType = "normal"
	TokenTypeLongLived TokenType = "long_lived_access_token"
	TokenTypeNone      TokenType = ""
)

// IsValid reports whether tt is one of the concrete token types.
func (tt TokenType) IsValid() bool {
	return tt == TokenTypeSystem || tt == TokenTypeNormal || tt == TokenTypeLongLived
}
type RefreshToken struct {
ID RefreshTokenID `json:"id"`
UserID UserID `json:"user_id"`
ClientID *ClientID `json:"client_id"`
ClientName *string `json:"client_name"`
ClientIcon *string `json:"client_icon"`
TokenType TokenType `json:"token_type"`
CreatedAt *common.PyTimestamp `json:"created_at"`
AccessTokenExpiration json.Number `json:"access_token_expiration"`
2022-12-19 13:09:01 -05:00
Token RefreshTokenToken `json:"token"`
2022-12-19 02:42:01 -05:00
JWTKey string `json:"jwt_key"`
LastUsedAt *common.PyTimestamp `json:"last_used_at"`
LastUsedIP *string `json:"last_used_ip"`
CredentialID *CredID `json:"credential_id"`
Version *string `json:"version"`
2022-12-19 19:24:01 -05:00
User *User `json:"-"`
2022-12-19 02:42:01 -05:00
}
// IsValid reports whether the token has a JWT signing key.
func (rt *RefreshToken) IsValid() bool {
	return len(rt.JWTKey) > 0
}
2022-12-19 13:09:01 -05:00
// AccessExpiration returns the lifetime, in seconds, of access tokens minted
// from this refresh token.
//
// AccessTokenExpiration may hold an integer ("1800") or a float ("1800.0");
// NOTE(review): stores written by Python's timedelta.total_seconds() presumably
// use the float form, which Int64 alone rejects. Fractional seconds are
// truncated. It panics, with context, if the value is not numeric at all.
func (rt *RefreshToken) AccessExpiration() int64 {
	if exp, err := rt.AccessTokenExpiration.Int64(); err == nil {
		return exp
	}

	f, err := rt.AccessTokenExpiration.Float64()
	if err != nil {
		panic(fmt.Errorf("refresh token %s: invalid access_token_expiration %q: %w",
			rt.ID, string(rt.AccessTokenExpiration), err))
	}
	return int64(f)
}
2022-12-19 02:42:01 -05:00
// RefreshOption customizes a RefreshToken while NewRefreshToken builds it.
type RefreshOption func(*RefreshToken)
// WithClientID sets the client the refresh token is issued to. Each token
// receives its own copy of the ID, so reusing one option for several tokens
// does not make them share a pointer.
func WithClientID(cid ClientID) RefreshOption {
	return func(rt *RefreshToken) {
		id := cid
		rt.ClientID = &id
	}
}
// WithClientName sets the token's client name (required for long-lived
// tokens). Each token receives its own copy of the string pointer target.
func WithClientName(n string) RefreshOption {
	return func(rt *RefreshToken) {
		name := n
		rt.ClientName = &name
	}
}
// WithClientIcon sets the token's client icon. Each token receives its own
// copy of the string pointer target.
func WithClientIcon(n string) RefreshOption {
	return func(rt *RefreshToken) {
		icon := n
		rt.ClientIcon = &icon
	}
}
// WithTokenType overrides the token type NewRefreshToken would otherwise
// derive from the user.
func WithTokenType(t TokenType) RefreshOption {
	setType := func(rt *RefreshToken) {
		rt.TokenType = t
	}
	return setType
}
// WithCredential records c's ID as the token's credential. The ID is copied
// so the token does not alias a field of c. A nil c leaves CredentialID unset.
func WithCredential(c *Credentials) RefreshOption {
	return func(rt *RefreshToken) {
		if c == nil {
			return
		}
		id := c.ID
		rt.CredentialID = &id
	}
}
2022-12-19 19:24:01 -05:00
// DefaultAccessExpiration is the default access token lifetime in seconds
// (30 minutes). It is an untyped string constant because it is assigned to
// RefreshToken.AccessTokenExpiration, a json.Number.
const DefaultAccessExpiration = "1800"
2022-12-19 02:42:01 -05:00
2022-12-19 19:24:01 -05:00
func (a *authenticator) NewRefreshToken(user *User, opts ...RefreshOption) (*RefreshToken, error) {
2022-12-19 13:09:01 -05:00
e := func(es string, arg ...interface{}) (*RefreshToken, error) {
return nil, fmt.Errorf(es, arg...)
2022-12-19 02:42:01 -05:00
}
now := common.PyTimestamp(time.Now())
r := &RefreshToken{
ID: RefreshTokenID(generate.UUID()),
UserID: user.ID,
2022-12-19 13:09:01 -05:00
Token: RefreshTokenToken(generate.Hex(64)),
2022-12-19 02:42:01 -05:00
JWTKey: generate.Hex(64),
CreatedAt: &now,
AccessTokenExpiration: DefaultAccessExpiration,
2022-12-20 13:16:30 -05:00
User: user,
2022-12-19 02:42:01 -05:00
}
for _, opt := range opts {
opt(r)
}
if r.TokenType == TokenTypeNone {
if user.SystemGenerated {
r.TokenType = TokenTypeSystem
} else {
r.TokenType = TokenTypeNormal
}
}
switch {
case !r.TokenType.IsValid():
return e("invalid token type")
case !user.Active:
return e("user is not active")
case user.SystemGenerated && r.ClientID != nil:
return e("system generated users cannot have refresh tokens connected to a client")
case !r.TokenType.IsValid():
return e("invalid token type '%v'", r.TokenType)
case user.SystemGenerated != (r.TokenType == TokenTypeSystem):
return e("system generated user can only have system type refresh tokens")
case r.TokenType == TokenTypeNormal && r.ClientID == nil:
return e("client is required to generate a refresh token")
case r.TokenType == TokenTypeLongLived && r.ClientName == nil:
return e("client name is required for long-lived token")
}
if r.TokenType == TokenTypeLongLived {
for _, lv := range user.RefreshTokens {
if strPtrEq(lv.ClientName, r.ClientName) && lv.TokenType == TokenTypeLongLived {
return e("client name '%v' already exists", *r.ClientName)
}
}
}
return a.store.PutRefreshToken(r)
}
func (r *RefreshToken) AccessToken(req *http.Request) (string, error) {
now := time.Now()
pytnow := common.PyTimestamp(now)
r.LastUsedAt = &pytnow
r.LastUsedIP = &req.RemoteAddr
2022-12-19 13:09:01 -05:00
return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{
Issuer: r.ID.String(),
IssuedAt: now.Unix(),
ExpiresAt: now.Add(time.Duration(r.AccessExpiration()) * time.Second).Unix(),
2022-12-19 02:42:01 -05:00
}).SignedString([]byte(r.JWTKey))
}
2022-12-20 13:16:30 -05:00
// ValidateAccessToken is not implemented yet: it always panics.
// NOTE(review): judging by the signature it is presumably meant to resolve an
// access token to the RefreshToken that issued it — confirm when implementing.
func (a *authenticator) ValidateAccessToken(token AccessToken) *RefreshToken {
	panic("not implemented")
}
2022-12-19 19:24:01 -05:00
func (a *authenticator) verifyAndGetCredential(tr *TokenRequest) *Credentials {
2022-12-19 13:09:01 -05:00
cred, success := a.authCodes.get(tr)
2022-11-11 18:02:52 -05:00
if !success {
return nil
}
2022-11-13 09:05:09 -05:00
return cred
2022-11-11 18:02:52 -05:00
}
2022-11-13 11:55:10 -05:00
// defaultExpiration is a 15-minute duration.
// NOTE(review): not referenced in this section of the file — confirm it is used elsewhere.
const defaultExpiration time.Duration = 15 * time.Minute
2022-10-26 19:43:51 -04:00
2022-12-19 19:24:01 -05:00
func (a *authenticator) NewAuthCode(clientID ClientID, cred *Credentials) string {
2022-12-19 13:09:01 -05:00
return a.authCodes.put(clientID, cred)
2022-10-26 19:13:50 -04:00
}
2022-10-26 19:43:51 -04:00
2022-11-11 18:02:52 -05:00
// GrantType is the OAuth2 grant_type of a token request.
type GrantType string

// Grant types handled by TokenHandler.
const (
	// GrantAuthCode exchanges an authorization code for tokens.
	GrantAuthCode GrantType = "authorization_code"
	// GrantRefreshToken exchanges a refresh token for a new access token.
	GrantRefreshToken GrantType = "refresh_token"
)
2022-11-20 08:49:24 -05:00
// ClientID identifies the OAuth client a code or refresh token belongs to.
type ClientID common.ClientID
2022-11-11 18:02:52 -05:00
// IsValid reports whether the client ID is non-empty.
func (c *ClientID) IsValid() bool {
	// TODO: || !indieauth.VerifyClientID(rq.ClientID)?
	return len(*c) != 0
}
2022-12-18 09:55:08 -05:00
// AuthCode is an authorization code as issued by NewAuthCode and presented
// back in a token request.
type AuthCode string

// IsValid reports whether the code is non-empty.
func (ac *AuthCode) IsValid() bool {
	return len(*ac) > 0
}
2022-10-26 19:43:51 -04:00
type TokenRequest struct {
2022-12-19 13:09:01 -05:00
ClientID ClientID `form:"client_id"`
Code AuthCode `form:"code"`
GrantType GrantType `form:"grant_type"`
RefreshToken RefreshTokenToken `form:"refresh_token"`
2022-10-26 19:43:51 -04:00
}
2022-12-19 02:42:01 -05:00
// AuthFailed is the generic error description sent to clients when
// authentication fails for a reason that is only logged server-side.
const AuthFailed = "authentication failure"
2022-12-19 19:24:01 -05:00
func (a *authenticator) TokenHandler(c echo.Context) error {
2022-12-19 13:09:01 -05:00
a.Lock()
defer a.Unlock()
2022-11-11 18:02:52 -05:00
rq := new(TokenRequest)
err := c.Bind(rq)
2022-10-26 19:43:51 -04:00
if err != nil {
return err
}
2022-11-11 18:02:52 -05:00
switch rq.GrantType {
2022-12-19 13:09:01 -05:00
case GrantAuthCode:
2022-11-11 18:02:52 -05:00
if !rq.ClientID.IsValid() {
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid client ID"})
}
if !rq.Code.IsValid() {
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid code"})
}
2022-12-19 13:09:01 -05:00
cred := a.verifyAndGetCredential(rq)
if cred == nil {
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid code"})
}
2022-11-11 18:02:52 -05:00
2022-12-19 13:09:01 -05:00
user, err := a.getOrCreateUser(cred)
if err != nil {
log.Error().Err(err).Msg("getOrCreateUser")
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
}
2022-12-19 02:42:01 -05:00
2022-12-19 13:09:01 -05:00
if err := user.allowedToAuth(c.Request()); err != nil {
log.Error().Err(err).Msg("allowedToAuth")
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
}
2022-12-19 02:42:01 -05:00
2022-12-19 13:09:01 -05:00
rt, err := a.NewRefreshToken(user, WithClientID(rq.ClientID), WithCredential(cred))
if err != nil {
log.Error().Err(err).Msg("NewRefreshToken")
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
}
2022-12-19 02:42:01 -05:00
2022-12-19 13:09:01 -05:00
at, err := rt.AccessToken(c.Request())
if err != nil {
log.Error().Err(err).Msg("AccessToken")
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
}
return common.NoCache(c).JSON(http.StatusOK, &struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
RefreshToken RefreshTokenToken `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
HAAuthProvider string `json:"ha_auth_provider"`
}{
AccessToken: at,
TokenType: "Bearer",
RefreshToken: rt.Token,
ExpiresIn: rt.AccessExpiration(),
HAAuthProvider: cred.AuthProviderType,
})
case GrantRefreshToken:
log.Debug().Interface("request", c.Request()).Interface("tokenRequest", rq).Msg("grant_type=refresh_token")
if !rq.ClientID.IsValid() {
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid client ID"})
}
if !rq.RefreshToken.IsValid() {
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"})
}
rt := a.store.GetRefreshTokenByToken(rq.RefreshToken)
if rt == nil {
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_grant"})
}
if rt.ClientID == nil || *rt.ClientID != rq.ClientID {
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"})
}
2022-12-19 02:42:01 -05:00
2022-12-19 19:24:01 -05:00
if err := rt.User.allowedToAuth(c.Request()); err != nil {
2022-12-19 13:09:01 -05:00
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: err.Error()})
}
at, err := rt.AccessToken(c.Request())
if err != nil {
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: err.Error()})
}
return common.NoCache(c).JSON(http.StatusOK, &struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
}{
AccessToken: at,
TokenType: "Bearer",
ExpiresIn: rt.AccessExpiration(),
})
2022-10-26 19:43:51 -04:00
}
2022-12-19 13:09:01 -05:00
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request"})
2022-10-26 19:43:51 -04:00
}
2022-12-19 19:24:01 -05:00
// AccessToken is a serialized access token, as accepted by
// ValidateAccessToken.
type AccessToken string