diff --git a/pkg/auth/authenticator.go b/pkg/auth/authenticator.go index 991a9a0..2f06622 100644 --- a/pkg/auth/authenticator.go +++ b/pkg/auth/authenticator.go @@ -20,6 +20,8 @@ var ( ErrDisabled = errors.New("user disabled") ErrInvalidAuth = errors.New("invalid auth") ErrInvalidHandler = errors.New("no such handler") + ErrInvalidIP = errors.New("invalid IP") + ErrUserAuthRemote = errors.New("user cannot authenticate remotely") ) type Authenticator struct { diff --git a/pkg/auth/flow.go b/pkg/auth/flow.go index 8bd8681..f4dee95 100644 --- a/pkg/auth/flow.go +++ b/pkg/auth/flow.go @@ -107,7 +107,7 @@ func (f *LoginFlow) progress(a *Authenticator, c echo.Context) error { user, clientID, err := a.Check(f, c.Request(), rm) switch err { case nil: - creds := a.store.Credential(user) + creds := a.store.GetCredential(user) finishedFlow := flow.Result{} a.flows.Remove(f) copier.Copy(&finishedFlow, f) diff --git a/pkg/auth/session.go b/pkg/auth/session.go index 17a4fd3..f64c632 100644 --- a/pkg/auth/session.go +++ b/pkg/auth/session.go @@ -2,42 +2,19 @@ package auth import ( "encoding/json" + "fmt" "net/http" "time" + "github.com/golang-jwt/jwt" "github.com/labstack/echo/v4" + "github.com/rs/zerolog/log" "dynatron.me/x/blasphem/internal/common" "dynatron.me/x/blasphem/internal/generate" "dynatron.me/x/blasphem/pkg/auth/provider" ) -type ( - TokenType string - RefreshTokenID string -) - -type RefreshToken struct { - ID RefreshTokenID `json:"id"` - UserID UserID `json:"user_id"` - ClientID *ClientID `json:"client_id"` - ClientName *string `json:"client_name"` - ClientIcon *string `json:"client_icon"` - TokenType TokenType `json:"token_type"` - CreatedAt *common.PyTimestamp `json:"created_at"` - AccessTokenExpiration json.Number `json:"access_token_expiration"` - Token string `json:"token"` - JWTKey string `json:"jwt_key"` - LastUsedAt *common.PyTimestamp `json:"last_used_at"` - LastUsedIP *string `json:"last_used_ip"` - CredentialID *CredID `json:"credential_id"` - Version *string `json:"version"` -} - -func (rt *RefreshToken) IsValid() bool { - return rt.JWTKey != "" -} - type authCodeStore struct { s map[authCodeTuple]flowResult lastCull time.Time @@ -99,7 +76,6 @@ func (ss *authCodeStore) verify(tr *TokenRequest, r *http.Request) (*Credentials key := authCodeTuple{tr.ClientID, tr.Code} if t, hasCode := ss.s[key]; hasCode { defer delete(ss.s, key) - // TODO: JWT if t.IsValid(time.Now()) { return t.Cred, true } @@ -138,6 +114,156 @@ func (cred *Credentials) MarshalJSON() ([]byte, error) { return json.Marshal(nCd) } +type ( + TokenType string + RefreshTokenID string +) + +const ( + TokenTypeSystem TokenType = "system" + TokenTypeNormal TokenType = "normal" + TokenTypeLongLived TokenType = "long_lived_access_token" + TokenTypeNone TokenType = "" +) + +func (tt TokenType) IsValid() bool { + switch tt { + case TokenTypeSystem, TokenTypeNormal, TokenTypeLongLived: + return true + } + + return false +} + +type RefreshToken struct { + ID RefreshTokenID `json:"id"` + UserID UserID `json:"user_id"` + ClientID *ClientID `json:"client_id"` + ClientName *string `json:"client_name"` + ClientIcon *string `json:"client_icon"` + TokenType TokenType `json:"token_type"` + CreatedAt *common.PyTimestamp `json:"created_at"` + AccessTokenExpiration json.Number `json:"access_token_expiration"` + Token string `json:"token"` + JWTKey string `json:"jwt_key"` + LastUsedAt *common.PyTimestamp `json:"last_used_at"` + LastUsedIP *string `json:"last_used_ip"` + CredentialID *CredID `json:"credential_id"` + Version *string `json:"version"` +} + +func (rt *RefreshToken) IsValid() bool { + return rt.JWTKey != "" +} + +type RefreshOption func(*RefreshToken) + +func WithClientID(cid ClientID) RefreshOption { + return func(rt *RefreshToken) { + rt.ClientID = &cid + } +} + +func WithClientName(n string) RefreshOption { + return func(rt *RefreshToken) { + rt.ClientName = &n + } +} + +func WithClientIcon(n string) RefreshOption { + return func(rt *RefreshToken) { + rt.ClientIcon = &n + } +} + +func WithTokenType(t TokenType) RefreshOption { + return func(rt *RefreshToken) { + rt.TokenType = t + } +} + +func WithCredential(c *Credentials) RefreshOption { + return func(rt *RefreshToken) { + rt.CredentialID = &c.ID + } +} + +const DefaultAccessExpiration = "1800" + +func (a *Authenticator) NewRefreshToken(user *User, opts ...RefreshOption) (*RefreshToken, error) { + e := func(es string, a ...interface{}) (*RefreshToken, error) { + return nil, fmt.Errorf(es, a...) + } + + now := common.PyTimestamp(time.Now()) + + r := &RefreshToken{ + ID: RefreshTokenID(generate.UUID()), + UserID: user.ID, + Token: generate.Hex(64), + JWTKey: generate.Hex(64), + CreatedAt: &now, + AccessTokenExpiration: DefaultAccessExpiration, + } + + for _, opt := range opts { + opt(r) + } + + if r.TokenType == TokenTypeNone { + if user.SystemGenerated { + r.TokenType = TokenTypeSystem + } else { + r.TokenType = TokenTypeNormal + } + } + + switch { + case !r.TokenType.IsValid(): + return e("invalid token type") + case !user.Active: + return e("user is not active") + case user.SystemGenerated && r.ClientID != nil: + return e("system generated users cannot have refresh tokens connected to a client") + case !r.TokenType.IsValid(): + return e("invalid token type '%v'", r.TokenType) + case user.SystemGenerated != (r.TokenType == TokenTypeSystem): + return e("system generated user can only have system type refresh tokens") + case r.TokenType == TokenTypeNormal && r.ClientID == nil: + return e("client is required to generate a refresh token") + case r.TokenType == TokenTypeLongLived && r.ClientName == nil: + return e("client name is required for long-lived token") + } + + if r.TokenType == TokenTypeLongLived { + for _, lv := range user.RefreshTokens { + if strPtrEq(lv.ClientName, r.ClientName) && lv.TokenType == TokenTypeLongLived { + return e("client name '%v' already exists", *r.ClientName) + } + } + } + + return a.store.PutRefreshToken(r) +} + +func (r *RefreshToken) AccessToken(req *http.Request) (string, error) { + now := time.Now() + exp, err := r.AccessTokenExpiration.Int64() + if err != nil { + return "", err + } + + pytnow := common.PyTimestamp(now) + r.LastUsedAt = &pytnow + r.LastUsedIP = &req.RemoteAddr + + return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "iss": r.ID, + "iat": now, + "exp": now.Add(time.Duration(exp) * time.Second), + }).SignedString([]byte(r.JWTKey)) +} + func (a *Authenticator) verifyAndGetCredential(tr *TokenRequest, r *http.Request) *Credentials { cred, success := a.authCodes.verify(tr, r) if !success { @@ -179,6 +305,8 @@ type TokenRequest struct { GrantType GrantType `form:"grant_type"` } +const AuthFailed = "authentication failure" + func (a *Authenticator) TokenHandler(c echo.Context) error { rq := new(TokenRequest) err := c.Bind(rq) @@ -200,13 +328,45 @@ func (a *Authenticator) TokenHandler(c echo.Context) error { // TODO: success user, err := a.getOrCreateUser(cred) if err != nil { - return c.JSON(http.StatusUnauthorized, AuthError{Error: "access_denied", Description: err.Error()}) + log.Error().Err(err).Msg("getOrCreateUser") + return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed}) } - if err := user.allowedToAuth(); err != nil { - return c.JSON(http.StatusUnauthorized, AuthError{Error: "access_denied", Description: err.Error()}) + if err := user.allowedToAuth(c.Request()); err != nil { + log.Error().Err(err).Msg("allowedToAuth") + return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed}) } - return c.String(http.StatusOK, "token good I guess") + + // TODO: create a refresh token, return it and refreshtoken.AccessToken() + rt, err := a.NewRefreshToken(user, WithClientID(rq.ClientID), WithCredential(cred)) + if err != nil { + log.Error().Err(err).Msg("NewRefreshToken") + return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed}) + } + + at, err := rt.AccessToken(c.Request()) + if err != nil { + log.Error().Err(err).Msg("AccessToken") + return c.JSON(http.StatusForbidden, AuthError{Error: "access_denied", Description: AuthFailed}) + } + + exp, _ := rt.AccessTokenExpiration.Int64() + + successResp := struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + HAAuthProvider string `json:"ha_auth_provider"` + }{ + AccessToken: at, + TokenType: "Bearer", + RefreshToken: rt.Token, + ExpiresIn: exp, + HAAuthProvider: cred.AuthProviderType, + } + + return c.JSON(http.StatusOK, &successResp) } case GTRefreshToken: return c.String(http.StatusNotImplemented, "not implemented") diff --git a/pkg/auth/store.go b/pkg/auth/store.go index e8d854f..2b13df3 100644 --- a/pkg/auth/store.go +++ b/pkg/auth/store.go @@ -17,10 +17,13 @@ const ( type AuthStore interface { User(UserID) *User - Credential(provider.ProviderUser) *Credentials + GetCredential(provider.ProviderUser) *Credentials + PutRefreshToken(*RefreshToken) (*RefreshToken, error) } type authStore struct { + storage.Item `json:"-"` + Users []*User `json:"users"` Groups []*Group `json:"groups"` Credentials []*Credentials `json:"credentials"` @@ -28,18 +31,27 @@ type authStore struct { userMap map[UserID]*User providerUsers map[provider.ProviderUser]*Credentials + store storage.Store +} + +func (as *authStore) sync() { + err := as.store.Flush(as.ItemKey()) + if err != nil { + log.Error().Err(err).Msg("sync authStore") + } } func strPtrEq(n1, n2 *string) bool { return (n1 == n2 || (n1 != nil && n2 != nil && *n1 == *n2)) } -func (as *authStore) Credential(p provider.ProviderUser) *Credentials { +func (as *authStore) GetCredential(p provider.ProviderUser) *Credentials { for _, cr := range as.Credentials { - if p.Provider() != nil && - strPtrEq(cr.AuthProviderID, p.Provider().ProviderID()) && - cr.AuthProviderType == p.Provider().ProviderType() && - p.Provider().EqualCreds(cr.User.UserData(), p.UserData()) { + if p != nil && (p == cr.User || + (p.Provider() != nil && + strPtrEq(cr.AuthProviderID, p.Provider().ProviderID()) && + cr.AuthProviderType == p.Provider().ProviderType() && + p.Provider().EqualCreds(cr.User.UserData(), p.UserData()))) { return cr } } @@ -47,6 +59,23 @@ func (as *authStore) Credential(p provider.ProviderUser) *Credentials { return nil } +func (as *authStore) PutRefreshToken(rt *RefreshToken) (*RefreshToken, error) { + e := func(es string, a ...interface{}) (*RefreshToken, error) { + return nil, fmt.Errorf(es, a...) + } + + u, hasUser := as.userMap[rt.UserID] + if !hasUser { + return e("no such user %v", rt.UserID) + } + + as.Refresh = append(as.Refresh, rt) + u.RefreshTokens = append(u.RefreshTokens, rt) + + as.sync() + return rt, nil +} + func (as *authStore) newCredential(p provider.ProviderUser) *Credentials { // XXX: probably broken prov := p.Provider() @@ -61,8 +90,14 @@ func (as *authStore) newCredential(p provider.ProviderUser) *Credentials { } func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error) { - as = &authStore{} - err = s.Get(AuthStoreKey, as) + as = &authStore{ + store: s, + } + + as.Item, err = s.GetItem(AuthStoreKey, as) + if err != nil { + return + } as.userMap = make(map[UserID]*User) as.providerUsers = make(map[provider.ProviderUser]*Credentials) @@ -94,7 +129,7 @@ func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error) u, hasUser := as.userMap[c.UserID] if !hasUser { - log.Error().Str("userid", string(c.UserID)).Msg("no such userid in map") + log.Error().Str("userid", string(c.UserID)).Msg("creds no such userid in map") continue } @@ -105,8 +140,15 @@ func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error) i := 0 for _, rt := range as.Refresh { if rt.IsValid() { + u, hasUser := as.userMap[rt.UserID] + if !hasUser { + log.Error().Str("userid", string(rt.UserID)).Msg("refreshtokens no such userid in map") + continue + } + as.Refresh[i] = rt i++ + u.RefreshTokens = append(u.RefreshTokens, rt) } } diff --git a/pkg/auth/user.go b/pkg/auth/user.go index e65b789..19c1180 100644 --- a/pkg/auth/user.go +++ b/pkg/auth/user.go @@ -1,6 +1,9 @@ package auth -import () +import ( + "net" + "net/http" +) type UserID string type GroupID string @@ -17,7 +20,8 @@ type User struct { Data interface{} `json:"data,omitempty"` UserMetadata - Creds []*Credentials `json:"-"` + Creds []*Credentials `json:"-"` + RefreshTokens []*RefreshToken `json:"-"` } type UserMetadata struct { @@ -28,12 +32,25 @@ type UserMetadata struct { LocalOnly bool `json:"local_only"` } -func (u *User) allowedToAuth() error { +func (u *User) allowedToAuth(r *http.Request) error { if !u.Active { return ErrDisabled } - return nil + if !u.LocalOnly { + return nil + } + + ip := net.ParseIP(r.RemoteAddr) + if ip == nil { + return ErrInvalidIP + } + + if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() { + return nil + } + + return ErrUserAuthRemote } func (a *Authenticator) getOrCreateUser(c *Credentials) (*User, error) {