2022-10-26 19:13:50 -04:00
|
|
|
package auth
|
|
|
|
|
|
|
|
import (
|
2022-11-12 15:56:17 -05:00
|
|
|
"encoding/json"
|
2022-12-19 02:42:01 -05:00
|
|
|
"fmt"
|
2022-10-26 19:13:50 -04:00
|
|
|
"net/http"
|
|
|
|
"time"
|
2022-10-26 19:43:51 -04:00
|
|
|
|
2022-12-19 02:42:01 -05:00
|
|
|
"github.com/golang-jwt/jwt"
|
2022-10-26 19:43:51 -04:00
|
|
|
"github.com/labstack/echo/v4"
|
2022-12-19 02:42:01 -05:00
|
|
|
"github.com/rs/zerolog/log"
|
2022-11-12 15:56:17 -05:00
|
|
|
|
2022-11-13 12:03:58 -05:00
|
|
|
"dynatron.me/x/blasphem/internal/common"
|
2022-11-13 11:55:10 -05:00
|
|
|
"dynatron.me/x/blasphem/internal/generate"
|
2022-11-12 15:56:17 -05:00
|
|
|
"dynatron.me/x/blasphem/pkg/auth/provider"
|
2022-10-26 19:13:50 -04:00
|
|
|
)
|
|
|
|
|
2022-12-18 09:55:08 -05:00
|
|
|
// authCodeStore holds authorization codes that have been issued but not yet
// redeemed, keyed by (client ID, code). store adds entries, verify removes the
// entry it looks up, and cull deletes entries whose flowResult is no longer
// valid.
//
// NOTE(review): nothing here is synchronized. If store/verify can be reached
// from concurrent HTTP handlers, access to s needs a mutex — confirm how
// callers serialize access.
type authCodeStore struct {
	// s must be allocated by init before store writes to it.
	s map[authCodeTuple]flowResult

	// lastCull is the timestamp cull compares against cullInterval to decide
	// whether to sweep s.
	lastCull time.Time
}
|
|
|
|
|
2022-12-18 09:55:08 -05:00
|
|
|
// authCodeTuple is the map key under which an issued authorization code is
// stored. Because the client ID is part of the key, a code is only found when
// it is redeemed with the same client ID it was issued for.
type authCodeTuple struct {
	ClientID ClientID
	Code     AuthCode
}
|
2022-10-26 19:13:50 -04:00
|
|
|
|
2022-12-18 09:55:08 -05:00
|
|
|
func (t *authCodeTuple) IsValid() bool {
|
2022-11-11 18:02:52 -05:00
|
|
|
// TODO: more validation than this
|
2022-12-18 09:55:08 -05:00
|
|
|
return t.Code != ""
|
|
|
|
}
|
|
|
|
|
|
|
|
// flowResult is the value an authorization code resolves to.
type flowResult struct {
	// Time is when the code was issued; IsValid measures expiry from it.
	Time time.Time

	// Cred is handed back when the code is redeemed. The provider user, if
	// any, travels with it in Credentials.User.
	Cred *Credentials
}
|
|
|
|
|
2022-12-18 09:55:08 -05:00
|
|
|
// authCodeExpire is how long an issued authorization code stays redeemable.
// RFC 6749 §4.1.2 recommends a maximum authorization code lifetime of
// 10 minutes.
const authCodeExpire time.Duration = 600 * time.Second
|
2022-11-11 18:02:52 -05:00
|
|
|
|
2022-12-18 09:55:08 -05:00
|
|
|
func (f *flowResult) IsValid(now time.Time) bool {
|
|
|
|
if now.After(f.Time.Add(authCodeExpire)) {
|
|
|
|
return false
|
|
|
|
}
|
|
|
|
|
|
|
|
return true
|
2022-10-26 19:13:50 -04:00
|
|
|
}
|
|
|
|
|
2022-12-18 09:55:08 -05:00
|
|
|
func (ss *authCodeStore) init() {
|
|
|
|
ss.s = make(map[authCodeTuple]flowResult)
|
2022-10-26 19:13:50 -04:00
|
|
|
}
|
|
|
|
|
|
|
|
// cullInterval is the interval authCodeStore.cull compares against the time
// since its last sweep before scanning for expired codes.
const cullInterval time.Duration = 300 * time.Second
|
|
|
|
|
2022-12-18 09:55:08 -05:00
|
|
|
func (ss *authCodeStore) cull() {
|
2022-10-26 19:13:50 -04:00
|
|
|
if now := time.Now(); now.Sub(ss.lastCull) > cullInterval {
|
|
|
|
for k, v := range ss.s {
|
2022-12-18 09:55:08 -05:00
|
|
|
if !v.IsValid(now) {
|
2022-10-26 19:13:50 -04:00
|
|
|
delete(ss.s, k)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-12-18 09:55:08 -05:00
|
|
|
func (ss *authCodeStore) store(clientID ClientID, cred *Credentials) string {
|
2022-10-26 19:13:50 -04:00
|
|
|
ss.cull()
|
2022-12-18 09:55:08 -05:00
|
|
|
code := generate.UUID()
|
|
|
|
ss.s[authCodeTuple{clientID, AuthCode(code)}] = flowResult{Time: time.Now(), Cred: cred}
|
|
|
|
|
|
|
|
return code
|
2022-10-26 19:13:50 -04:00
|
|
|
}
|
|
|
|
|
2022-12-18 09:55:08 -05:00
|
|
|
func (ss *authCodeStore) verify(tr *TokenRequest, r *http.Request) (*Credentials, bool) {
|
|
|
|
key := authCodeTuple{tr.ClientID, tr.Code}
|
|
|
|
if t, hasCode := ss.s[key]; hasCode {
|
|
|
|
defer delete(ss.s, key)
|
|
|
|
if t.IsValid(time.Now()) {
|
|
|
|
return t.Cred, true
|
2022-10-26 19:43:51 -04:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-11-11 18:02:52 -05:00
|
|
|
return nil, false
|
|
|
|
}
|
|
|
|
|
2022-12-18 09:55:08 -05:00
|
|
|
// Credentials records which auth provider (by type and optional ID) a user
// authenticated through. See MarshalJSON for how the "data" field is produced
// when it is serialized.
type Credentials struct {
	ID               CredID  `json:"id"`
	UserID           UserID  `json:"user_id"`
	AuthProviderType string  `json:"auth_provider_type"`
	AuthProviderID   *string `json:"auth_provider_id"`

	// DataRaw holds provider-specific data as raw JSON; omitted when nil.
	DataRaw *json.RawMessage `json:"data,omitempty"`

	// User is the provider's user object. It is never serialized directly;
	// MarshalJSON may derive "data" from User.UserData() instead.
	User provider.ProviderUser `json:"-"`
}
|
|
|
|
|
2022-12-18 09:55:08 -05:00
|
|
|
func (cred *Credentials) MarshalJSON() ([]byte, error) {
|
|
|
|
type CredAlias Credentials // alias so ø method set and we don't recurse
|
2022-11-13 09:05:09 -05:00
|
|
|
nCd := (*CredAlias)(cred)
|
2022-11-12 15:56:17 -05:00
|
|
|
|
2022-12-18 21:26:34 -05:00
|
|
|
if cred.User != nil {
|
|
|
|
providerData := cred.User.UserData()
|
2022-11-13 11:55:10 -05:00
|
|
|
if providerData != nil {
|
|
|
|
b, err := json.Marshal(providerData)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
2022-11-13 09:05:09 -05:00
|
|
|
|
2022-11-13 11:55:10 -05:00
|
|
|
dr := json.RawMessage(b)
|
|
|
|
nCd.DataRaw = &dr
|
|
|
|
}
|
2022-11-12 15:56:17 -05:00
|
|
|
}
|
|
|
|
|
2022-11-13 09:05:09 -05:00
|
|
|
return json.Marshal(nCd)
|
2022-11-11 18:02:52 -05:00
|
|
|
}
|
|
|
|
|
2022-12-19 02:42:01 -05:00
|
|
|
// TokenType classifies a refresh token; see the TokenType* constants.
type TokenType string

// RefreshTokenID identifies a RefreshToken. NewRefreshToken fills it with a
// generated UUID.
type RefreshTokenID string

// Known refresh token types. TokenTypeNone is the zero value, meaning no type
// has been chosen yet; NewRefreshToken replaces it with a default.
const (
	TokenTypeSystem    TokenType = "system"
	TokenTypeNormal    TokenType = "normal"
	TokenTypeLongLived TokenType = "long_lived_access_token"
	TokenTypeNone      TokenType = ""
)

// IsValid reports whether tt is one of the concrete token types. TokenTypeNone
// and unrecognized values are rejected.
func (tt TokenType) IsValid() bool {
	return tt == TokenTypeSystem || tt == TokenTypeNormal || tt == TokenTypeLongLived
}
|
|
|
|
|
|
|
|
// RefreshToken is the record from which JWT access tokens are minted (see
// AccessToken). NOTE(review): the snake_case JSON names and PyTimestamp fields
// look intended to match Home Assistant's stored auth format — confirm.
type RefreshToken struct {
	ID         RefreshTokenID `json:"id"`
	UserID     UserID         `json:"user_id"`
	ClientID   *ClientID      `json:"client_id"`   // set by WithClientID
	ClientName *string        `json:"client_name"` // set by WithClientName; required for long-lived tokens
	ClientIcon *string        `json:"client_icon"` // set by WithClientIcon
	TokenType  TokenType      `json:"token_type"`
	CreatedAt  *common.PyTimestamp `json:"created_at"`

	// AccessTokenExpiration is the access-token lifetime in seconds.
	AccessTokenExpiration json.Number `json:"access_token_expiration"`

	// Token is the refresh token string returned to the client.
	Token string `json:"token"`

	// JWTKey is the HMAC key AccessToken signs access tokens with.
	JWTKey string `json:"jwt_key"`

	LastUsedAt   *common.PyTimestamp `json:"last_used_at"`  // updated by AccessToken
	LastUsedIP   *string             `json:"last_used_ip"`  // updated by AccessToken
	CredentialID *CredID             `json:"credential_id"` // set by WithCredential
	Version      *string             `json:"version"`
}
|
|
|
|
|
|
|
|
func (rt *RefreshToken) IsValid() bool {
|
|
|
|
return rt.JWTKey != ""
|
|
|
|
}
|
|
|
|
|
|
|
|
type RefreshOption func(*RefreshToken)
|
|
|
|
|
|
|
|
func WithClientID(cid ClientID) RefreshOption {
|
|
|
|
return func(rt *RefreshToken) {
|
|
|
|
rt.ClientID = &cid
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func WithClientName(n string) RefreshOption {
|
|
|
|
return func(rt *RefreshToken) {
|
|
|
|
rt.ClientName = &n
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func WithClientIcon(n string) RefreshOption {
|
|
|
|
return func(rt *RefreshToken) {
|
|
|
|
rt.ClientIcon = &n
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func WithTokenType(t TokenType) RefreshOption {
|
|
|
|
return func(rt *RefreshToken) {
|
|
|
|
rt.TokenType = t
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func WithCredential(c *Credentials) RefreshOption {
|
|
|
|
return func(rt *RefreshToken) {
|
|
|
|
rt.CredentialID = &c.ID
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// DefaultAccessExpiration is the access-token lifetime, in seconds, that
// NewRefreshToken gives new refresh tokens. It is an untyped string constant
// so it can be assigned directly to the json.Number field
// RefreshToken.AccessTokenExpiration.
const (
	DefaultAccessExpiration = "1800"
)
|
|
|
|
|
|
|
|
func (a *Authenticator) NewRefreshToken(user *User, opts ...RefreshOption) (*RefreshToken, error) {
|
|
|
|
e := func(es string, a ...interface{}) (*RefreshToken, error) {
|
|
|
|
return nil, fmt.Errorf(es, a...)
|
|
|
|
}
|
|
|
|
|
|
|
|
now := common.PyTimestamp(time.Now())
|
|
|
|
|
|
|
|
r := &RefreshToken{
|
|
|
|
ID: RefreshTokenID(generate.UUID()),
|
|
|
|
UserID: user.ID,
|
|
|
|
Token: generate.Hex(64),
|
|
|
|
JWTKey: generate.Hex(64),
|
|
|
|
CreatedAt: &now,
|
|
|
|
AccessTokenExpiration: DefaultAccessExpiration,
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, opt := range opts {
|
|
|
|
opt(r)
|
|
|
|
}
|
|
|
|
|
|
|
|
if r.TokenType == TokenTypeNone {
|
|
|
|
if user.SystemGenerated {
|
|
|
|
r.TokenType = TokenTypeSystem
|
|
|
|
} else {
|
|
|
|
r.TokenType = TokenTypeNormal
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
switch {
|
|
|
|
case !r.TokenType.IsValid():
|
|
|
|
return e("invalid token type")
|
|
|
|
case !user.Active:
|
|
|
|
return e("user is not active")
|
|
|
|
case user.SystemGenerated && r.ClientID != nil:
|
|
|
|
return e("system generated users cannot have refresh tokens connected to a client")
|
|
|
|
case !r.TokenType.IsValid():
|
|
|
|
return e("invalid token type '%v'", r.TokenType)
|
|
|
|
case user.SystemGenerated != (r.TokenType == TokenTypeSystem):
|
|
|
|
return e("system generated user can only have system type refresh tokens")
|
|
|
|
case r.TokenType == TokenTypeNormal && r.ClientID == nil:
|
|
|
|
return e("client is required to generate a refresh token")
|
|
|
|
case r.TokenType == TokenTypeLongLived && r.ClientName == nil:
|
|
|
|
return e("client name is required for long-lived token")
|
|
|
|
}
|
|
|
|
|
|
|
|
if r.TokenType == TokenTypeLongLived {
|
|
|
|
for _, lv := range user.RefreshTokens {
|
|
|
|
if strPtrEq(lv.ClientName, r.ClientName) && lv.TokenType == TokenTypeLongLived {
|
|
|
|
return e("client name '%v' already exists", *r.ClientName)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return a.store.PutRefreshToken(r)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (r *RefreshToken) AccessToken(req *http.Request) (string, error) {
|
|
|
|
now := time.Now()
|
|
|
|
exp, err := r.AccessTokenExpiration.Int64()
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
|
|
|
pytnow := common.PyTimestamp(now)
|
|
|
|
r.LastUsedAt = &pytnow
|
|
|
|
r.LastUsedIP = &req.RemoteAddr
|
|
|
|
|
|
|
|
return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
|
|
|
"iss": r.ID,
|
|
|
|
"iat": now,
|
|
|
|
"exp": now.Add(time.Duration(exp) * time.Second),
|
|
|
|
}).SignedString([]byte(r.JWTKey))
|
|
|
|
}
|
|
|
|
|
2022-12-18 09:55:08 -05:00
|
|
|
func (a *Authenticator) verifyAndGetCredential(tr *TokenRequest, r *http.Request) *Credentials {
|
|
|
|
cred, success := a.authCodes.verify(tr, r)
|
2022-11-11 18:02:52 -05:00
|
|
|
if !success {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2022-11-13 09:05:09 -05:00
|
|
|
return cred
|
2022-11-11 18:02:52 -05:00
|
|
|
}
|
|
|
|
|
2022-11-13 11:55:10 -05:00
|
|
|
// defaultExpiration is a 15-minute duration.
// NOTE(review): its consumer is not visible in this section — confirm where it
// is used and document its meaning there.
const defaultExpiration time.Duration = 900 * time.Second
|
2022-10-26 19:43:51 -04:00
|
|
|
|
2022-12-18 09:55:08 -05:00
|
|
|
func (a *Authenticator) NewAuthCode(clientID ClientID, cred *Credentials) string {
|
|
|
|
return a.authCodes.store(clientID, cred)
|
2022-10-26 19:13:50 -04:00
|
|
|
}
|
2022-10-26 19:43:51 -04:00
|
|
|
|
2022-11-11 18:02:52 -05:00
|
|
|
// GrantType is the OAuth 2 grant_type parameter of a token request.
type GrantType string

// Grant types TokenHandler dispatches on.
const (
	GTAuthorizationCode = GrantType("authorization_code")
	GTRefreshToken      = GrantType("refresh_token")
)
|
|
|
|
|
2022-11-20 08:49:24 -05:00
|
|
|
type ClientID common.ClientID
|
2022-11-11 18:02:52 -05:00
|
|
|
|
|
|
|
func (c *ClientID) IsValid() bool {
|
|
|
|
// TODO: || !indieauth.VerifyClientID(rq.ClientID)?
|
|
|
|
return *c != ""
|
|
|
|
}
|
|
|
|
|
2022-12-18 09:55:08 -05:00
|
|
|
// AuthCode is an authorization code as produced by authCodeStore.store and
// presented back in a TokenRequest.
type AuthCode string

// IsValid reports whether the code is non-empty. No format check is done.
func (ac *AuthCode) IsValid() bool {
	return len(*ac) != 0
}
|
|
|
|
|
2022-10-26 19:43:51 -04:00
|
|
|
// TokenRequest holds the parameters of a token endpoint request. TokenHandler
// fills it with echo's Context.Bind, using the form tags below.
type TokenRequest struct {
	ClientID  ClientID  `form:"client_id"`
	Code      AuthCode  `form:"code"`
	GrantType GrantType `form:"grant_type"`
}
|
|
|
|
|
2022-12-19 02:42:01 -05:00
|
|
|
// AuthFailed is the error_description sent with the token endpoint's 403
// "access_denied" responses; the specific cause is only logged.
const (
	AuthFailed = "authentication failure"
)
|
|
|
|
|
2022-10-26 19:43:51 -04:00
|
|
|
func (a *Authenticator) TokenHandler(c echo.Context) error {
|
2022-11-11 18:02:52 -05:00
|
|
|
rq := new(TokenRequest)
|
|
|
|
err := c.Bind(rq)
|
2022-10-26 19:43:51 -04:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2022-11-11 18:02:52 -05:00
|
|
|
switch rq.GrantType {
|
|
|
|
case GTAuthorizationCode:
|
|
|
|
if !rq.ClientID.IsValid() {
|
|
|
|
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid client ID"})
|
|
|
|
}
|
|
|
|
|
|
|
|
if !rq.Code.IsValid() {
|
|
|
|
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid code"})
|
|
|
|
}
|
|
|
|
|
2022-11-13 09:05:09 -05:00
|
|
|
if cred := a.verifyAndGetCredential(rq, c.Request()); cred != nil {
|
2022-11-11 18:02:52 -05:00
|
|
|
// TODO: success
|
|
|
|
user, err := a.getOrCreateUser(cred)
|
|
|
|
if err != nil {
|
2022-12-19 02:42:01 -05:00
|
|
|
log.Error().Err(err).Msg("getOrCreateUser")
|
|
|
|
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
|
2022-11-11 18:02:52 -05:00
|
|
|
}
|
|
|
|
|
2022-12-19 02:42:01 -05:00
|
|
|
if err := user.allowedToAuth(c.Request()); err != nil {
|
|
|
|
log.Error().Err(err).Msg("allowedToAuth")
|
|
|
|
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
|
|
|
|
}
|
|
|
|
|
|
|
|
// TODO: create a refresh token, return it and refreshtoken.AccessToken()
|
|
|
|
rt, err := a.NewRefreshToken(user, WithClientID(rq.ClientID), WithCredential(cred))
|
|
|
|
if err != nil {
|
|
|
|
log.Error().Err(err).Msg("NewRefreshToken")
|
|
|
|
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
|
2022-11-11 18:02:52 -05:00
|
|
|
}
|
2022-12-19 02:42:01 -05:00
|
|
|
|
|
|
|
at, err := rt.AccessToken(c.Request())
|
|
|
|
if err != nil {
|
|
|
|
log.Error().Err(err).Msg("AccessToken")
|
|
|
|
return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed})
|
|
|
|
}
|
|
|
|
|
|
|
|
exp, _ := rt.AccessTokenExpiration.Int64()
|
|
|
|
|
|
|
|
successResp := struct {
|
|
|
|
AccessToken string `json:"access_token"`
|
|
|
|
TokenType string `json:"token_type"`
|
|
|
|
RefreshToken string `json:"refresh_token"`
|
|
|
|
ExpiresIn int64 `json:"expires_in"`
|
|
|
|
HAAuthProvider string `json:"ha_auth_provider"`
|
|
|
|
}{
|
|
|
|
AccessToken: at,
|
|
|
|
TokenType: "Bearer",
|
|
|
|
RefreshToken: rt.Token,
|
|
|
|
ExpiresIn: exp,
|
|
|
|
HAAuthProvider: cred.AuthProviderType,
|
|
|
|
}
|
|
|
|
|
|
|
|
return c.JSON(http.StatusOK, &successResp)
|
2022-11-11 18:02:52 -05:00
|
|
|
}
|
|
|
|
case GTRefreshToken:
|
|
|
|
return c.String(http.StatusNotImplemented, "not implemented")
|
2022-10-26 19:43:51 -04:00
|
|
|
}
|
|
|
|
|
|
|
|
return c.String(http.StatusUnauthorized, "token bad I guess")
|
|
|
|
}
|