Compare commits
2 commits
3981025fa4
...
9b38bdbca9
Author | SHA1 | Date | |
---|---|---|---|
9b38bdbca9 | |||
43682fab05 |
11 changed files with 255 additions and 114 deletions
|
@ -1,3 +1,4 @@
|
||||||
|
// common contains common functionality for blasphem.
|
||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -5,6 +6,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// AppName is the name of the application.
|
||||||
AppName = "blasphem"
|
AppName = "blasphem"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -13,6 +15,7 @@ type cmdOptions interface {
|
||||||
Execute() error
|
Execute() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RunE is a convenience function for use with cobra.
|
||||||
func RunE(c cmdOptions) func(cmd *cobra.Command, args []string) error {
|
func RunE(c cmdOptions) func(cmd *cobra.Command, args []string) error {
|
||||||
return func(cmd *cobra.Command, args []string) error {
|
return func(cmd *cobra.Command, args []string) error {
|
||||||
err := c.Options(cmd, args)
|
err := c.Options(cmd, args)
|
||||||
|
|
|
@ -1,12 +1,9 @@
|
||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
|
||||||
"encoding/hex"
|
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
|
@ -28,7 +25,7 @@ var (
|
||||||
type Authenticator struct {
|
type Authenticator struct {
|
||||||
store AuthStore
|
store AuthStore
|
||||||
flows FlowStore
|
flows FlowStore
|
||||||
sessions SessionStore
|
sessions AccessSessionStore
|
||||||
providers map[string]provider.AuthProvider
|
providers map[string]provider.AuthProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -114,26 +111,10 @@ func (a *Authenticator) Check(f *Flow, req *http.Request, rm map[string]interfac
|
||||||
user, success := p.ValidateCreds(req, rm)
|
user, success := p.ValidateCreds(req, rm)
|
||||||
|
|
||||||
if success {
|
if success {
|
||||||
log.Info().Interface("user", user.ProviderUserData()).Msg("Login success")
|
log.Info().Interface("user", user.UserData()).Msg("Login success")
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, ErrInvalidAuth
|
return nil, ErrInvalidAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
func genUUID() string {
|
|
||||||
// must be addressable
|
|
||||||
u := uuid.New()
|
|
||||||
|
|
||||||
return hex.EncodeToString(u[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
func genHex(l int) string {
|
|
||||||
b := make([]byte, l)
|
|
||||||
if _, err := rand.Read(b); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return hex.EncodeToString(b)
|
|
||||||
}
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
|
|
||||||
"dynatron.me/x/blasphem/internal/common"
|
"dynatron.me/x/blasphem/internal/common"
|
||||||
|
"dynatron.me/x/blasphem/internal/generate"
|
||||||
"dynatron.me/x/blasphem/pkg/auth/provider"
|
"dynatron.me/x/blasphem/pkg/auth/provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -100,7 +101,7 @@ func (a *Authenticator) NewFlow(r *FlowRequest) *Flow {
|
||||||
|
|
||||||
flow := &Flow{
|
flow := &Flow{
|
||||||
Type: TypeForm,
|
Type: TypeForm,
|
||||||
ID: FlowID(genUUID()),
|
ID: FlowID(generate.UUID()),
|
||||||
StepID: stepPtr(StepInit),
|
StepID: stepPtr(StepInit),
|
||||||
Schema: sch,
|
Schema: sch,
|
||||||
Handler: r.Handler,
|
Handler: r.Handler,
|
||||||
|
@ -146,12 +147,12 @@ func (f *Flow) progress(a *Authenticator, c echo.Context) error {
|
||||||
switch err {
|
switch err {
|
||||||
case nil:
|
case nil:
|
||||||
var finishedFlow struct {
|
var finishedFlow struct {
|
||||||
ID FlowID `json:"flow_id"`
|
ID FlowID `json:"flow_id"`
|
||||||
Handler []*string `json:"handler"`
|
Handler []*string `json:"handler"`
|
||||||
Result TokenID `json:"result"`
|
Result AccessTokenID `json:"result"`
|
||||||
Title string `json:"title"`
|
Title string `json:"title"`
|
||||||
Type FlowType `json:"type"`
|
Type FlowType `json:"type"`
|
||||||
Version int `json:"version"`
|
Version int `json:"version"`
|
||||||
}
|
}
|
||||||
|
|
||||||
a.flows.Remove(f)
|
a.flows.Remove(f)
|
||||||
|
@ -159,7 +160,7 @@ func (f *Flow) progress(a *Authenticator, c echo.Context) error {
|
||||||
finishedFlow.Type = TypeCreateEntry
|
finishedFlow.Type = TypeCreateEntry
|
||||||
finishedFlow.Title = common.AppName
|
finishedFlow.Title = common.AppName
|
||||||
finishedFlow.Version = 1
|
finishedFlow.Version = 1
|
||||||
finishedFlow.Result = a.NewToken(c.Request(), user, f)
|
finishedFlow.Result = a.NewAccessToken(c.Request(), user, f)
|
||||||
|
|
||||||
f.redirect(c)
|
f.redirect(c)
|
||||||
|
|
||||||
|
|
|
@ -23,8 +23,8 @@ type HAUser struct {
|
||||||
provider.AuthProvider `json:"-"`
|
provider.AuthProvider `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hau *HAUser) UserData() interface{} {
|
func (hau *HAUser) UserData() provider.ProviderUser {
|
||||||
return UserData{
|
return &UserData{
|
||||||
Username: hau.Username,
|
Username: hau.Username,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -33,6 +33,10 @@ type UserData struct {
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ud *UserData) UserData() provider.ProviderUser {
|
||||||
|
return ud
|
||||||
|
}
|
||||||
|
|
||||||
const HomeAssistant = "homeassistant"
|
const HomeAssistant = "homeassistant"
|
||||||
|
|
||||||
func (h *HAUser) ProviderUserData() interface{} { return h.UserData() }
|
func (h *HAUser) ProviderUserData() interface{} { return h.UserData() }
|
||||||
|
|
|
@ -23,8 +23,7 @@ func Register(providerName string, f func(storage.Store) (AuthProvider, error))
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProviderUser interface {
|
type ProviderUser interface {
|
||||||
AuthProviderMetadata
|
UserData() ProviderUser
|
||||||
ProviderUserData() interface{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthProviderBase struct {
|
type AuthProviderBase struct {
|
||||||
|
|
|
@ -15,8 +15,8 @@ type User struct {
|
||||||
provider.AuthProvider `json:"-"`
|
provider.AuthProvider `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hau *User) UserData() interface{} {
|
func (hau *User) UserData() provider.ProviderUser {
|
||||||
return UserData{
|
return &UserData{
|
||||||
UserID: hau.UserID,
|
UserID: hau.UserID,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -25,9 +25,11 @@ type UserData struct {
|
||||||
UserID string `json:"user_id"`
|
UserID string `json:"user_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
const TrustedNetworks = "trusted_networks"
|
func (ud *UserData) UserData() provider.ProviderUser {
|
||||||
|
return ud
|
||||||
|
}
|
||||||
|
|
||||||
func (h *User) ProviderUserData() interface{} { return h.UserData() }
|
const TrustedNetworks = "trusted_networks"
|
||||||
|
|
||||||
type TrustedNetworksProvider struct {
|
type TrustedNetworksProvider struct {
|
||||||
provider.AuthProviderBase `json:"-"`
|
provider.AuthProviderBase `json:"-"`
|
||||||
|
|
|
@ -2,28 +2,97 @@ package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
|
|
||||||
|
"dynatron.me/x/blasphem/internal/generate"
|
||||||
"dynatron.me/x/blasphem/pkg/auth/provider"
|
"dynatron.me/x/blasphem/pkg/auth/provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SessionStore struct {
|
/*
|
||||||
s map[TokenID]*Token
|
{
|
||||||
|
"id": "18912b07b66e48558a44c058bf90f1d4",
|
||||||
|
"user_id": "1bd642b0ed9f410280f622d1d358102b",
|
||||||
|
"client_id": "https://oauth-redirect.googleusercontent.com/",
|
||||||
|
"client_name": null,
|
||||||
|
"client_icon": null,
|
||||||
|
"token_type": "normal",
|
||||||
|
"created_at": "2021-12-13T04:34:47.169033+00:00",
|
||||||
|
"access_token_expiration": 1800.0,
|
||||||
|
"token": "f93f87557ca616508c675f05b85921d07fdf764efc34c74a81daeebd71ab899a6e8cd5dec94d0ae36499ff281f2efcf715c763ad73eabadd2f586b827057043d",
|
||||||
|
"jwt_key": "f4507c01fe19b1d99c4b628cadc8208f7838dbdda0a9a051fd029cefdf6f619b0728d6b0a4d3d96ee68614b035054952faa14c48ae36bd212e28d602864f6d1c",
|
||||||
|
"last_used_at": "2021-12-13T04:34:47.169738+00:00",
|
||||||
|
"last_used_ip": "108.177.68.93",
|
||||||
|
"credential_id": "5daeb186d2a943328cbb984f135974cb",
|
||||||
|
"version": "2021.12.0"
|
||||||
|
},
|
||||||
|
*/
|
||||||
|
|
||||||
|
type (
|
||||||
|
TokenType string
|
||||||
|
TokenTimestamp time.Time
|
||||||
|
RefreshTokenID string
|
||||||
|
ExpSeconds float64
|
||||||
|
)
|
||||||
|
|
||||||
|
func (f ExpSeconds) MarshalJSON() ([]byte, error) {
|
||||||
|
if float64(f) == float64(int(f)) {
|
||||||
|
return []byte(strconv.FormatFloat(float64(f), 'f', 1, 32)), nil
|
||||||
|
}
|
||||||
|
return []byte(strconv.FormatFloat(float64(f), 'f', -1, 32)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const PytTimeFormat = "2006-01-02T15:04:05.999999-07:00"
|
||||||
|
|
||||||
|
func (t *TokenTimestamp) MarshalJSON() ([]byte, error) {
|
||||||
|
rv := fmt.Sprintf("%q", time.Time(*t).Format(PytTimeFormat))
|
||||||
|
return []byte(rv), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TokenTimestamp) UnmarshalJSON(b []byte) error {
|
||||||
|
s := strings.Trim(string(b), `"`)
|
||||||
|
tm, err := time.Parse(PytTimeFormat, s)
|
||||||
|
*t = TokenTimestamp(tm)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
type RefreshToken struct {
|
||||||
|
ID RefreshTokenID `json:"id"`
|
||||||
|
UserID UserID `json:"user_id"`
|
||||||
|
ClientID *ClientID `json:"client_id"`
|
||||||
|
ClientName *string `json:"client_name"`
|
||||||
|
ClientIcon *string `json:"client_icon"`
|
||||||
|
TokenType TokenType `json:"token_type"`
|
||||||
|
CreatedAt *TokenTimestamp `json:"created_at"`
|
||||||
|
AccessTokenExpiration ExpSeconds `json:"access_token_expiration"`
|
||||||
|
Token string `json:"token"`
|
||||||
|
JWTKey string `json:"jwt_key"`
|
||||||
|
LastUsedAt *TokenTimestamp `json:"last_used_at"`
|
||||||
|
LastUsedIP *string `json:"last_used_ip"`
|
||||||
|
CredentialID *CredID `json:"credential_id"`
|
||||||
|
Version *string `json:"version"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AccessSessionStore struct {
|
||||||
|
s map[AccessTokenID]*AccessToken
|
||||||
lastCull time.Time
|
lastCull time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenID string
|
type AccessTokenID string
|
||||||
|
|
||||||
func (t *TokenID) IsValid() bool {
|
func (t *AccessTokenID) IsValid() bool {
|
||||||
// TODO: more validation than this
|
// TODO: more validation than this
|
||||||
return *t != ""
|
return *t != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
type Token struct { // TODO: jwt bro
|
type AccessToken struct { // TODO: jwt bro
|
||||||
ID TokenID
|
ID AccessTokenID
|
||||||
Ctime time.Time
|
Ctime time.Time
|
||||||
Expires time.Time
|
Expires time.Time
|
||||||
Addr string
|
Addr string
|
||||||
|
@ -31,13 +100,13 @@ type Token struct { // TODO: jwt bro
|
||||||
user provider.ProviderUser `json:"-"`
|
user provider.ProviderUser `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ss *SessionStore) init() {
|
func (ss *AccessSessionStore) init() {
|
||||||
ss.s = make(map[TokenID]*Token)
|
ss.s = make(map[AccessTokenID]*AccessToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
const cullInterval = 5 * time.Minute
|
const cullInterval = 5 * time.Minute
|
||||||
|
|
||||||
func (ss *SessionStore) cull() {
|
func (ss *AccessSessionStore) cull() {
|
||||||
if now := time.Now(); now.Sub(ss.lastCull) > cullInterval {
|
if now := time.Now(); now.Sub(ss.lastCull) > cullInterval {
|
||||||
for k, v := range ss.s {
|
for k, v := range ss.s {
|
||||||
if now.After(v.Expires) {
|
if now.After(v.Expires) {
|
||||||
|
@ -47,12 +116,12 @@ func (ss *SessionStore) cull() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ss *SessionStore) register(t *Token) {
|
func (ss *AccessSessionStore) register(t *AccessToken) {
|
||||||
ss.cull()
|
ss.cull()
|
||||||
ss.s[t.ID] = t
|
ss.s[t.ID] = t
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ss *SessionStore) verify(tr *TokenRequest, r *http.Request) (provider.ProviderUser, bool) {
|
func (ss *AccessSessionStore) verify(tr *TokenRequest, r *http.Request) (provider.ProviderUser, bool) {
|
||||||
if t, hasToken := ss.s[tr.Code]; hasToken {
|
if t, hasToken := ss.s[tr.Code]; hasToken {
|
||||||
// TODO: JWT
|
// TODO: JWT
|
||||||
if t.Expires.After(time.Now()) {
|
if t.Expires.After(time.Now()) {
|
||||||
|
@ -75,18 +144,20 @@ type Credential struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cred *Credential) MarshalJSON() ([]byte, error) {
|
func (cred *Credential) MarshalJSON() ([]byte, error) {
|
||||||
type CredAlias Credential
|
type CredAlias Credential // alias so ø method set and we don't recurse
|
||||||
nCd := (*CredAlias)(cred)
|
nCd := (*CredAlias)(cred)
|
||||||
|
|
||||||
providerData := cred.user.ProviderUserData()
|
if cred.user != nil {
|
||||||
if providerData != nil {
|
providerData := cred.user.UserData()
|
||||||
b, err := json.Marshal(providerData)
|
if providerData != nil {
|
||||||
if err != nil {
|
b, err := json.Marshal(providerData)
|
||||||
return nil, err
|
if err != nil {
|
||||||
}
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
dr := json.RawMessage(b)
|
dr := json.RawMessage(b)
|
||||||
nCd.DataRaw = &dr
|
nCd.DataRaw = &dr
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return json.Marshal(nCd)
|
return json.Marshal(nCd)
|
||||||
|
@ -105,13 +176,13 @@ func (a *Authenticator) verifyAndGetCredential(tr *TokenRequest, r *http.Request
|
||||||
return cred
|
return cred
|
||||||
}
|
}
|
||||||
|
|
||||||
const defaultExpiration = 2 * time.Hour
|
const defaultExpiration = 15 * time.Minute
|
||||||
|
|
||||||
func (a *Authenticator) NewToken(r *http.Request, user provider.ProviderUser, f *Flow) TokenID {
|
func (a *Authenticator) NewAccessToken(r *http.Request, user provider.ProviderUser, f *Flow) AccessTokenID {
|
||||||
id := TokenID(genUUID())
|
id := AccessTokenID(generate.UUID())
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
t := &Token{
|
t := &AccessToken{
|
||||||
ID: id,
|
ID: id,
|
||||||
Ctime: now,
|
Ctime: now,
|
||||||
Expires: now.Add(defaultExpiration),
|
Expires: now.Add(defaultExpiration),
|
||||||
|
@ -140,9 +211,9 @@ func (c *ClientID) IsValid() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenRequest struct {
|
type TokenRequest struct {
|
||||||
ClientID ClientID `form:"client_id"`
|
ClientID ClientID `form:"client_id"`
|
||||||
Code TokenID `form:"code"`
|
Code AccessTokenID `form:"code"`
|
||||||
GrantType GrantType `form:"grant_type"`
|
GrantType GrantType `form:"grant_type"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticator) TokenHandler(c echo.Context) error {
|
func (a *Authenticator) TokenHandler(c echo.Context) error {
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
|
"dynatron.me/x/blasphem/pkg/auth/provider"
|
||||||
"dynatron.me/x/blasphem/pkg/storage"
|
"dynatron.me/x/blasphem/pkg/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -18,9 +19,10 @@ type AuthStore interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type authStore struct {
|
type authStore struct {
|
||||||
Users []User `json:"users"`
|
Users []User `json:"users"`
|
||||||
Groups interface{} `json:"groups"`
|
Groups []Group `json:"groups"`
|
||||||
Credentials []Credential `json:"credentials"`
|
Credentials []Credential `json:"credentials"`
|
||||||
|
Refresh []RefreshToken `json:"refresh_tokens"`
|
||||||
|
|
||||||
userMap map[UserID]*User
|
userMap map[UserID]*User
|
||||||
}
|
}
|
||||||
|
@ -42,13 +44,15 @@ func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error)
|
||||||
return nil, fmt.Errorf("no such provider %s", c.AuthProviderType)
|
return nil, fmt.Errorf("no such provider %s", c.AuthProviderType)
|
||||||
}
|
}
|
||||||
|
|
||||||
pd := prov.NewCredData()
|
|
||||||
|
|
||||||
if c.DataRaw != nil {
|
if c.DataRaw != nil {
|
||||||
|
pd := prov.NewCredData()
|
||||||
|
|
||||||
err := json.Unmarshal(*c.DataRaw, pd)
|
err := json.Unmarshal(*c.DataRaw, pd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.user = pd.(provider.ProviderUser)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,11 @@ type UserID string
|
||||||
type GroupID string
|
type GroupID string
|
||||||
type CredID string
|
type CredID string
|
||||||
|
|
||||||
|
type Group struct {
|
||||||
|
ID GroupID `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
ID UserID `json:"id"`
|
ID UserID `json:"id"`
|
||||||
GroupIDs []GroupID `json:"group_ids"`
|
GroupIDs []GroupID `json:"group_ids"`
|
||||||
|
@ -16,11 +21,11 @@ type User struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserMetadata struct {
|
type UserMetadata struct {
|
||||||
Active bool `json:"is_active"`
|
|
||||||
Owner bool `json:"is_owner"`
|
Owner bool `json:"is_owner"`
|
||||||
LocalOnly bool `json:"local_only"`
|
Active bool `json:"is_active"`
|
||||||
SystemGenerated bool `json:"system_generated"`
|
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
SystemGenerated bool `json:"system_generated"`
|
||||||
|
LocalOnly bool `json:"local_only"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) allowedToAuth() error {
|
func (u *User) allowedToAuth() error {
|
||||||
|
|
55
pkg/storage/fs.go
Normal file
55
pkg/storage/fs.go
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrNoSuchKey = errors.New("no such key in store")
|
||||||
|
ErrKeyExists = errors.New("key already exists")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Item is an item in a datastore.
|
||||||
|
type Item interface {
|
||||||
|
// Item is lockable if updating data item directly.
|
||||||
|
sync.Locker
|
||||||
|
|
||||||
|
// Dirty sets the dirty flag for the item so it will be flushed.
|
||||||
|
Dirty()
|
||||||
|
|
||||||
|
// IsDirty gets the dirty flag for the item.
|
||||||
|
IsDirty() bool
|
||||||
|
|
||||||
|
// GetData gets the data for the item.
|
||||||
|
GetData() interface{}
|
||||||
|
|
||||||
|
// GetData sets the data for the item.
|
||||||
|
SetData(interface{})
|
||||||
|
|
||||||
|
// ItemKey gets the key of the item.
|
||||||
|
ItemKey() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store represents a datastore.
|
||||||
|
type Store interface {
|
||||||
|
// GetItem loads the specified key from the store into data and returns the Item.
|
||||||
|
// If err is ErrKeyExists, Item will be the existing item.
|
||||||
|
GetItem(key string, data interface{}) (Item, error)
|
||||||
|
|
||||||
|
// Get is the same as GetItem, but only returns error.
|
||||||
|
Get(key string, data interface{}) error
|
||||||
|
|
||||||
|
// Put puts the specified key into the store. If the key already exists, it clobbers.
|
||||||
|
// Note that any existing items will then dangle.
|
||||||
|
Put(key string, version, minorVersion int, secretMode bool, data interface{}) (Item, error)
|
||||||
|
|
||||||
|
// FlushAll flushes the store to backing.
|
||||||
|
FlushAll() []error
|
||||||
|
|
||||||
|
// Flush flushes a single key to backing.
|
||||||
|
Flush(key string) error
|
||||||
|
|
||||||
|
// Shutdown is called to quiesce and shutdown the store.
|
||||||
|
Shutdown()
|
||||||
|
}
|
|
@ -2,55 +2,43 @@ package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
IndentStr = strings.Repeat(" ", 4)
|
IndentStr = strings.Repeat(" ", 2)
|
||||||
|
|
||||||
ErrNoSuchKey = errors.New("no such key in store")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
SecretMode os.FileMode = 0600
|
SecretMode fs.FileMode = 0600
|
||||||
DefaultMode os.FileMode = 0644
|
DefaultMode fs.FileMode = 0644
|
||||||
)
|
)
|
||||||
|
|
||||||
type Data interface {
|
|
||||||
}
|
|
||||||
|
|
||||||
type item struct {
|
type item struct {
|
||||||
|
sync.Mutex `json:"-"`
|
||||||
Version int `json:"version"`
|
Version int `json:"version"`
|
||||||
MinorVersion *int `json:"minor_version,omitempty"`
|
MinorVersion *int `json:"minor_version,omitempty"`
|
||||||
Key string `json:"key"`
|
Key string `json:"key"`
|
||||||
Data interface{} `json:"data"`
|
Data interface{} `json:"data"`
|
||||||
|
|
||||||
fmode os.FileMode
|
fmode fs.FileMode
|
||||||
dirty bool
|
dirty bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type Item interface {
|
func (i *item) Dirty() { i.Lock(); defer i.Unlock(); i.dirty = true }
|
||||||
Dirty()
|
func (i *item) IsDirty() bool { i.Lock(); defer i.Unlock(); return i.dirty }
|
||||||
IsDirty() bool
|
func (i *item) GetData() interface{} { i.Lock(); defer i.Unlock(); return i.Data }
|
||||||
GetData() interface{}
|
func (i *item) SetData(d interface{}) { i.Lock(); defer i.Unlock(); i.Data = d; i.dirty = true }
|
||||||
SetData(interface{})
|
func (i *item) ItemKey() string { return i.Key /* key is immutable */ }
|
||||||
ItemKey() string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *item) Dirty() { i.dirty = true }
|
func (it *item) mode() fs.FileMode {
|
||||||
func (i *item) IsDirty() bool { return i.dirty }
|
|
||||||
func (i *item) GetData() interface{} { return i.Data }
|
|
||||||
func (i *item) SetData(d interface{}) { i.Data = d; i.Dirty() }
|
|
||||||
func (i *item) ItemKey() string { return i.Key }
|
|
||||||
|
|
||||||
func (it *item) mode() os.FileMode {
|
|
||||||
if it.fmode != 0 {
|
if it.fmode != 0 {
|
||||||
return it.fmode
|
return it.fmode
|
||||||
}
|
}
|
||||||
|
@ -59,22 +47,36 @@ func (it *item) mode() os.FileMode {
|
||||||
}
|
}
|
||||||
|
|
||||||
type fsStore struct {
|
type fsStore struct {
|
||||||
fs.FS
|
sync.RWMutex
|
||||||
|
fs fs.FS
|
||||||
storeRoot string
|
storeRoot string
|
||||||
s map[string]*item
|
s map[string]*item
|
||||||
}
|
}
|
||||||
|
|
||||||
type Store interface {
|
func (s *fsStore) get(key string) *item {
|
||||||
GetItem(key string, data interface{}) (Item, error)
|
s.RLock()
|
||||||
Get(key string, data interface{}) error
|
defer s.RUnlock()
|
||||||
Put(key string, version, minorVersion int, secretMode bool, data interface{}) (Item, error)
|
|
||||||
FlushAll() []error
|
i, ok := s.s[key]
|
||||||
Flush(key string) error
|
if !ok {
|
||||||
Shutdown()
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fsStore) put(key string, it *item) {
|
||||||
|
s.Lock()
|
||||||
|
defer s.Unlock()
|
||||||
|
|
||||||
|
s.s[key] = it
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *fsStore) persist(it *item) error {
|
func (s *fsStore) persist(it *item) error {
|
||||||
f, err := os.OpenFile(path.Join(s.storeRoot, it.Key), os.O_WRONLY|os.O_CREATE, it.mode())
|
it.Lock()
|
||||||
|
defer it.Unlock()
|
||||||
|
|
||||||
|
f, err := os.OpenFile(path.Join(s.storeRoot, it.Key), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, it.mode())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -93,19 +95,19 @@ func (s *fsStore) persist(it *item) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *fsStore) Dirty(key string) error {
|
func (s *fsStore) Dirty(key string) error {
|
||||||
it, has := s.s[key]
|
it := s.get(key)
|
||||||
if !has {
|
if it == nil {
|
||||||
return ErrNoSuchKey
|
return ErrNoSuchKey
|
||||||
}
|
}
|
||||||
|
|
||||||
it.dirty = true
|
it.Dirty()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *fsStore) Flush(key string) error {
|
func (s *fsStore) Flush(key string) error {
|
||||||
it, exists := s.s[key]
|
it := s.get(key)
|
||||||
if !exists {
|
if it == nil {
|
||||||
return ErrNoSuchKey
|
return ErrNoSuchKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,6 +115,9 @@ func (s *fsStore) Flush(key string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *fsStore) FlushAll() []error {
|
func (s *fsStore) FlushAll() []error {
|
||||||
|
s.RLock()
|
||||||
|
defer s.RUnlock()
|
||||||
|
|
||||||
var errs []error
|
var errs []error
|
||||||
for _, it := range s.s {
|
for _, it := range s.s {
|
||||||
err := s.persist(it)
|
err := s.persist(it)
|
||||||
|
@ -167,15 +172,26 @@ func (s *fsStore) Get(key string, data interface{}) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *fsStore) GetItem(key string, data interface{}) (Item, error) {
|
func (s *fsStore) GetItem(key string, data interface{}) (Item, error) {
|
||||||
f, err := s.Open(key)
|
exists := s.get(key)
|
||||||
|
if exists != nil {
|
||||||
|
return exists, ErrKeyExists
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := s.fs.Open(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
|
fi, err := f.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
item := &item{
|
item := &item{
|
||||||
Data: data,
|
Data: data,
|
||||||
|
fmode: fi.Mode(),
|
||||||
}
|
}
|
||||||
d := json.NewDecoder(f)
|
d := json.NewDecoder(f)
|
||||||
err = d.Decode(item)
|
err = d.Decode(item)
|
||||||
|
@ -187,7 +203,7 @@ func (s *fsStore) GetItem(key string, data interface{}) (Item, error) {
|
||||||
return nil, fmt.Errorf("key mismatch '%s' != '%s'", item.Key, key)
|
return nil, fmt.Errorf("key mismatch '%s' != '%s'", item.Key, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.s[key] = item
|
s.put(key, item)
|
||||||
|
|
||||||
return item, nil
|
return item, nil
|
||||||
}
|
}
|
||||||
|
@ -197,7 +213,7 @@ func OpenFileStore(configRoot string) (*fsStore, error) {
|
||||||
stor := os.DirFS(storeRoot)
|
stor := os.DirFS(storeRoot)
|
||||||
|
|
||||||
return &fsStore{
|
return &fsStore{
|
||||||
FS: stor,
|
fs: stor,
|
||||||
storeRoot: storeRoot,
|
storeRoot: storeRoot,
|
||||||
s: make(map[string]*item),
|
s: make(map[string]*item),
|
||||||
}, nil
|
}, nil
|
||||||
|
|
Loading…
Reference in a new issue