Compare commits
13 commits
54185be835
...
c618197c54
Author | SHA1 | Date | |
---|---|---|---|
c618197c54 | |||
794f2d8448 | |||
9aef6e5143 | |||
7f499012a6 | |||
1ea8a24224 | |||
25ed921736 | |||
9db15a648b | |||
4de1344512 | |||
68b971f65a | |||
9b38bdbca9 | |||
43682fab05 | |||
3981025fa4 | |||
3ab5b5b78a |
15 changed files with 451 additions and 148 deletions
|
@ -1,3 +1,4 @@
|
|||
// common contains common functionality for blasphem.
|
||||
package common
|
||||
|
||||
import (
|
||||
|
@ -5,6 +6,7 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
// AppName is the name of the application.
|
||||
AppName = "blasphem"
|
||||
)
|
||||
|
||||
|
@ -13,6 +15,7 @@ type cmdOptions interface {
|
|||
Execute() error
|
||||
}
|
||||
|
||||
// RunE is a convenience function for use with cobra.
|
||||
func RunE(c cmdOptions) func(cmd *cobra.Command, args []string) error {
|
||||
return func(cmd *cobra.Command, args []string) error {
|
||||
err := c.Options(cmd, args)
|
||||
|
|
29
internal/common/types.go
Normal file
29
internal/common/types.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
package common
|
||||
|
||||
// Convenience types
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type (
|
||||
// PyTimeStamp is a timestamp that marshals to python-style timestamp strings (long nano).
|
||||
PyTimestamp time.Time
|
||||
)
|
||||
|
||||
const PytTimeFormat = "2006-01-02T15:04:05.999999-07:00"
|
||||
|
||||
func (t *PyTimestamp) MarshalJSON() ([]byte, error) {
|
||||
rv := fmt.Sprintf("%q", time.Time(*t).Format(PytTimeFormat))
|
||||
return []byte(rv), nil
|
||||
}
|
||||
|
||||
func (t *PyTimestamp) UnmarshalJSON(b []byte) error {
|
||||
s := strings.Trim(string(b), `"`)
|
||||
tm, err := time.Parse(PytTimeFormat, s)
|
||||
*t = PyTimestamp(tm)
|
||||
|
||||
return err
|
||||
}
|
24
internal/generate/unique.go
Normal file
24
internal/generate/unique.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
package generate
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func UUID() string {
|
||||
// must be addressable
|
||||
u := uuid.New()
|
||||
|
||||
return hex.EncodeToString(u[:])
|
||||
}
|
||||
|
||||
func Hex(l int) string {
|
||||
b := make([]byte, l)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return hex.EncodeToString(b)
|
||||
}
|
|
@ -1,12 +1,9 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
|
@ -28,7 +25,7 @@ var (
|
|||
type Authenticator struct {
|
||||
store AuthStore
|
||||
flows FlowStore
|
||||
sessions SessionStore
|
||||
sessions AccessSessionStore
|
||||
providers map[string]provider.AuthProvider
|
||||
}
|
||||
|
||||
|
@ -114,26 +111,10 @@ func (a *Authenticator) Check(f *Flow, req *http.Request, rm map[string]interfac
|
|||
user, success := p.ValidateCreds(req, rm)
|
||||
|
||||
if success {
|
||||
log.Info().Interface("user", user.ProviderUserData()).Msg("Login success")
|
||||
log.Info().Interface("user", user.UserData()).Msg("Login success")
|
||||
return user, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ErrInvalidAuth
|
||||
}
|
||||
|
||||
func genUUID() string {
|
||||
// must be addressable
|
||||
u := uuid.New()
|
||||
|
||||
return hex.EncodeToString(u[:])
|
||||
}
|
||||
|
||||
func genHex(l int) string {
|
||||
b := make([]byte, l)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"github.com/labstack/echo/v4"
|
||||
|
||||
"dynatron.me/x/blasphem/internal/common"
|
||||
"dynatron.me/x/blasphem/internal/generate"
|
||||
"dynatron.me/x/blasphem/pkg/auth/provider"
|
||||
)
|
||||
|
||||
|
@ -100,7 +101,7 @@ func (a *Authenticator) NewFlow(r *FlowRequest) *Flow {
|
|||
|
||||
flow := &Flow{
|
||||
Type: TypeForm,
|
||||
ID: FlowID(genUUID()),
|
||||
ID: FlowID(generate.UUID()),
|
||||
StepID: stepPtr(StepInit),
|
||||
Schema: sch,
|
||||
Handler: r.Handler,
|
||||
|
@ -146,12 +147,12 @@ func (f *Flow) progress(a *Authenticator, c echo.Context) error {
|
|||
switch err {
|
||||
case nil:
|
||||
var finishedFlow struct {
|
||||
ID FlowID `json:"flow_id"`
|
||||
Handler []*string `json:"handler"`
|
||||
Result TokenID `json:"result"`
|
||||
Title string `json:"title"`
|
||||
Type FlowType `json:"type"`
|
||||
Version int `json:"version"`
|
||||
ID FlowID `json:"flow_id"`
|
||||
Handler []*string `json:"handler"`
|
||||
Result AccessTokenID `json:"result"`
|
||||
Title string `json:"title"`
|
||||
Type FlowType `json:"type"`
|
||||
Version int `json:"version"`
|
||||
}
|
||||
|
||||
a.flows.Remove(f)
|
||||
|
@ -159,7 +160,7 @@ func (f *Flow) progress(a *Authenticator, c echo.Context) error {
|
|||
finishedFlow.Type = TypeCreateEntry
|
||||
finishedFlow.Title = common.AppName
|
||||
finishedFlow.Version = 1
|
||||
finishedFlow.Result = a.NewToken(c.Request(), user, f)
|
||||
finishedFlow.Result = a.NewAccessToken(c.Request(), user, f)
|
||||
|
||||
f.redirect(c)
|
||||
|
||||
|
|
|
@ -23,8 +23,8 @@ type HAUser struct {
|
|||
provider.AuthProvider `json:"-"`
|
||||
}
|
||||
|
||||
func (hau *HAUser) UserData() interface{} {
|
||||
return UserData{
|
||||
func (hau *HAUser) UserData() provider.ProviderUser {
|
||||
return &UserData{
|
||||
Username: hau.Username,
|
||||
}
|
||||
}
|
||||
|
@ -33,6 +33,10 @@ type UserData struct {
|
|||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
func (ud *UserData) UserData() provider.ProviderUser {
|
||||
return ud
|
||||
}
|
||||
|
||||
const HomeAssistant = "homeassistant"
|
||||
|
||||
func (h *HAUser) ProviderUserData() interface{} { return h.UserData() }
|
||||
|
|
|
@ -23,8 +23,7 @@ func Register(providerName string, f func(storage.Store) (AuthProvider, error))
|
|||
}
|
||||
|
||||
type ProviderUser interface {
|
||||
AuthProviderMetadata
|
||||
ProviderUserData() interface{}
|
||||
UserData() ProviderUser
|
||||
}
|
||||
|
||||
type AuthProviderBase struct {
|
||||
|
|
|
@ -15,8 +15,8 @@ type User struct {
|
|||
provider.AuthProvider `json:"-"`
|
||||
}
|
||||
|
||||
func (hau *User) UserData() interface{} {
|
||||
return UserData{
|
||||
func (hau *User) UserData() provider.ProviderUser {
|
||||
return &UserData{
|
||||
UserID: hau.UserID,
|
||||
}
|
||||
}
|
||||
|
@ -25,9 +25,11 @@ type UserData struct {
|
|||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
const TrustedNetworks = "trusted_networks"
|
||||
func (ud *UserData) UserData() provider.ProviderUser {
|
||||
return ud
|
||||
}
|
||||
|
||||
func (h *User) ProviderUserData() interface{} { return h.UserData() }
|
||||
const TrustedNetworks = "trusted_networks"
|
||||
|
||||
type TrustedNetworksProvider struct {
|
||||
provider.AuthProviderBase `json:"-"`
|
||||
|
|
|
@ -7,23 +7,47 @@ import (
|
|||
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"dynatron.me/x/blasphem/internal/common"
|
||||
"dynatron.me/x/blasphem/internal/generate"
|
||||
"dynatron.me/x/blasphem/pkg/auth/provider"
|
||||
)
|
||||
|
||||
type SessionStore struct {
|
||||
s map[TokenID]*Token
|
||||
type (
|
||||
TokenType string
|
||||
RefreshTokenID string
|
||||
)
|
||||
|
||||
type RefreshToken struct {
|
||||
ID RefreshTokenID `json:"id"`
|
||||
UserID UserID `json:"user_id"`
|
||||
ClientID *ClientID `json:"client_id"`
|
||||
ClientName *string `json:"client_name"`
|
||||
ClientIcon *string `json:"client_icon"`
|
||||
TokenType TokenType `json:"token_type"`
|
||||
CreatedAt *common.PyTimestamp `json:"created_at"`
|
||||
AccessTokenExpiration json.Number `json:"access_token_expiration"`
|
||||
Token string `json:"token"`
|
||||
JWTKey string `json:"jwt_key"`
|
||||
LastUsedAt *common.PyTimestamp `json:"last_used_at"`
|
||||
LastUsedIP *string `json:"last_used_ip"`
|
||||
CredentialID *CredID `json:"credential_id"`
|
||||
Version *string `json:"version"`
|
||||
}
|
||||
|
||||
type AccessSessionStore struct {
|
||||
s map[AccessTokenID]*AccessToken
|
||||
lastCull time.Time
|
||||
}
|
||||
|
||||
type TokenID string
|
||||
type AccessTokenID string
|
||||
|
||||
func (t *TokenID) IsValid() bool {
|
||||
func (t *AccessTokenID) IsValid() bool {
|
||||
// TODO: more validation than this
|
||||
return *t != ""
|
||||
}
|
||||
|
||||
type Token struct { // TODO: jwt bro
|
||||
ID TokenID
|
||||
type AccessToken struct { // TODO: jwt bro
|
||||
ID AccessTokenID
|
||||
Ctime time.Time
|
||||
Expires time.Time
|
||||
Addr string
|
||||
|
@ -31,13 +55,13 @@ type Token struct { // TODO: jwt bro
|
|||
user provider.ProviderUser `json:"-"`
|
||||
}
|
||||
|
||||
func (ss *SessionStore) init() {
|
||||
ss.s = make(map[TokenID]*Token)
|
||||
func (ss *AccessSessionStore) init() {
|
||||
ss.s = make(map[AccessTokenID]*AccessToken)
|
||||
}
|
||||
|
||||
const cullInterval = 5 * time.Minute
|
||||
|
||||
func (ss *SessionStore) cull() {
|
||||
func (ss *AccessSessionStore) cull() {
|
||||
if now := time.Now(); now.Sub(ss.lastCull) > cullInterval {
|
||||
for k, v := range ss.s {
|
||||
if now.After(v.Expires) {
|
||||
|
@ -47,12 +71,12 @@ func (ss *SessionStore) cull() {
|
|||
}
|
||||
}
|
||||
|
||||
func (ss *SessionStore) register(t *Token) {
|
||||
func (ss *AccessSessionStore) register(t *AccessToken) {
|
||||
ss.cull()
|
||||
ss.s[t.ID] = t
|
||||
}
|
||||
|
||||
func (ss *SessionStore) verify(tr *TokenRequest, r *http.Request) (provider.ProviderUser, bool) {
|
||||
func (ss *AccessSessionStore) verify(tr *TokenRequest, r *http.Request) (provider.ProviderUser, bool) {
|
||||
if t, hasToken := ss.s[tr.Code]; hasToken {
|
||||
// TODO: JWT
|
||||
if t.Expires.After(time.Now()) {
|
||||
|
@ -66,46 +90,54 @@ func (ss *SessionStore) verify(tr *TokenRequest, r *http.Request) (provider.Prov
|
|||
}
|
||||
|
||||
type Credential struct {
|
||||
ID CredID `json:"id"`
|
||||
UserID UserID `json:"user_id"`
|
||||
AuthProviderType string `json:"auth_provider_type"`
|
||||
AuthProviderID *string `json:"auth_provider_id"`
|
||||
DataRaw json.RawMessage `json:"data,omitempty"`
|
||||
user provider.ProviderUser
|
||||
ID CredID `json:"id"`
|
||||
UserID UserID `json:"user_id"`
|
||||
AuthProviderType string `json:"auth_provider_type"`
|
||||
AuthProviderID *string `json:"auth_provider_id"`
|
||||
DataRaw *json.RawMessage `json:"data,omitempty"`
|
||||
user provider.ProviderUser `json:"-"`
|
||||
}
|
||||
|
||||
func (cred *Credential) MarshalJSON() ([]byte, error) {
|
||||
rm := map[string]interface{}{
|
||||
"id": cred.ID,
|
||||
"user_id": cred.UserID,
|
||||
"auth_provider_type": cred.user.ProviderType(),
|
||||
"auth_provider_id": cred.user.ProviderID(),
|
||||
type CredAlias Credential // alias so ø method set and we don't recurse
|
||||
nCd := (*CredAlias)(cred)
|
||||
|
||||
if cred.user != nil {
|
||||
providerData := cred.user.UserData()
|
||||
if providerData != nil {
|
||||
b, err := json.Marshal(providerData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dr := json.RawMessage(b)
|
||||
nCd.DataRaw = &dr
|
||||
}
|
||||
}
|
||||
|
||||
providerData := cred.user.ProviderUserData()
|
||||
if providerData != nil {
|
||||
rm["data"] = providerData
|
||||
}
|
||||
|
||||
return json.Marshal(rm)
|
||||
return json.Marshal(nCd)
|
||||
}
|
||||
|
||||
func (ss *SessionStore) verifyAndGetCredential(tr *TokenRequest, r *http.Request) *Credential {
|
||||
user, success := ss.verify(tr, r)
|
||||
func (a *Authenticator) verifyAndGetCredential(tr *TokenRequest, r *http.Request) *Credential {
|
||||
user, success := a.sessions.verify(tr, r)
|
||||
if !success {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &Credential{user: user}
|
||||
cred := &Credential{
|
||||
user: user,
|
||||
}
|
||||
|
||||
return cred
|
||||
}
|
||||
|
||||
const defaultExpiration = 2 * time.Hour
|
||||
const defaultExpiration = 15 * time.Minute
|
||||
|
||||
func (a *Authenticator) NewToken(r *http.Request, user provider.ProviderUser, f *Flow) TokenID {
|
||||
id := TokenID(genUUID())
|
||||
func (a *Authenticator) NewAccessToken(r *http.Request, user provider.ProviderUser, f *Flow) AccessTokenID {
|
||||
id := AccessTokenID(generate.UUID())
|
||||
now := time.Now()
|
||||
|
||||
t := &Token{
|
||||
t := &AccessToken{
|
||||
ID: id,
|
||||
Ctime: now,
|
||||
Expires: now.Add(defaultExpiration),
|
||||
|
@ -134,9 +166,9 @@ func (c *ClientID) IsValid() bool {
|
|||
}
|
||||
|
||||
type TokenRequest struct {
|
||||
ClientID ClientID `form:"client_id"`
|
||||
Code TokenID `form:"code"`
|
||||
GrantType GrantType `form:"grant_type"`
|
||||
ClientID ClientID `form:"client_id"`
|
||||
Code AccessTokenID `form:"code"`
|
||||
GrantType GrantType `form:"grant_type"`
|
||||
}
|
||||
|
||||
func (a *Authenticator) TokenHandler(c echo.Context) error {
|
||||
|
@ -156,7 +188,7 @@ func (a *Authenticator) TokenHandler(c echo.Context) error {
|
|||
return c.JSON(http.StatusBadRequest, AuthError{Error: "invalid_request", Description: "invalid code"})
|
||||
}
|
||||
|
||||
if cred := a.sessions.verifyAndGetCredential(rq, c.Request()); cred != nil {
|
||||
if cred := a.verifyAndGetCredential(rq, c.Request()); cred != nil {
|
||||
// TODO: success
|
||||
user, err := a.getOrCreateUser(cred)
|
||||
if err != nil {
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"dynatron.me/x/blasphem/pkg/auth/provider"
|
||||
"dynatron.me/x/blasphem/pkg/storage"
|
||||
)
|
||||
|
||||
|
@ -16,9 +17,10 @@ type AuthStore interface {
|
|||
}
|
||||
|
||||
type authStore struct {
|
||||
Users []User `json:"users"`
|
||||
Groups interface{} `json:"groups"`
|
||||
Credentials []Credential `json:"credentials"`
|
||||
Users []User `json:"users"`
|
||||
Groups []Group `json:"groups"`
|
||||
Credentials []Credential `json:"credentials"`
|
||||
Refresh []RefreshToken `json:"refresh_tokens"`
|
||||
|
||||
userMap map[UserID]*User
|
||||
}
|
||||
|
@ -39,11 +41,15 @@ func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error)
|
|||
return nil, fmt.Errorf("no such provider %s", c.AuthProviderType)
|
||||
}
|
||||
|
||||
pd := prov.NewCredData()
|
||||
if c.DataRaw != nil {
|
||||
pd := prov.NewCredData()
|
||||
|
||||
err := json.Unmarshal(c.DataRaw, pd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
err := json.Unmarshal(*c.DataRaw, pd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.user = pd.(provider.ProviderUser)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
|
@ -10,6 +8,11 @@ type UserID string
|
|||
type GroupID string
|
||||
type CredID string
|
||||
|
||||
type Group struct {
|
||||
ID GroupID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID UserID `json:"id"`
|
||||
GroupIDs []GroupID `json:"group_ids"`
|
||||
|
@ -18,11 +21,11 @@ type User struct {
|
|||
}
|
||||
|
||||
type UserMetadata struct {
|
||||
Active bool `json:"is_active"`
|
||||
Owner bool `json:"is_owner"`
|
||||
LocalOnly bool `json:"local_only"`
|
||||
SystemGenerated bool `json:"system_generated"`
|
||||
Active bool `json:"is_active"`
|
||||
Name string `json:"name"`
|
||||
SystemGenerated bool `json:"system_generated"`
|
||||
LocalOnly bool `json:"local_only"`
|
||||
}
|
||||
|
||||
func (u *User) allowedToAuth() error {
|
||||
|
@ -34,7 +37,7 @@ func (u *User) allowedToAuth() error {
|
|||
}
|
||||
|
||||
func (a *Authenticator) getOrCreateUser(c *Credential) (*User, error) {
|
||||
log.Debug().Interface("userdata", c.user.ProviderUserData()).Msg("getOrCreateUser")
|
||||
log.Debug().Interface("userdata", c).Msg("getOrCreateUser")
|
||||
u := a.store.User(c.UserID)
|
||||
if u == nil {
|
||||
return nil, ErrInvalidAuth
|
||||
|
|
|
@ -20,6 +20,7 @@ type Blas struct {
|
|||
|
||||
func (b *Blas) Shutdown(ctx context.Context) error {
|
||||
b.Bus.Shutdown()
|
||||
b.Store.Shutdown()
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
|
@ -44,7 +45,8 @@ func (b *Blas) ConfigDir() (cd string) {
|
|||
}
|
||||
|
||||
func (b *Blas) openStore() error {
|
||||
stor, err := storage.Open(os.DirFS(b.ConfigDir()))
|
||||
// TODO: based on config, open filestore or db store
|
||||
stor, err := storage.OpenFileStore(b.ConfigDir())
|
||||
b.Store = stor
|
||||
return err
|
||||
}
|
||||
|
|
220
pkg/storage/filesystem.go
Normal file
220
pkg/storage/filesystem.go
Normal file
|
@ -0,0 +1,220 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
var (
|
||||
IndentStr = strings.Repeat(" ", 2)
|
||||
)
|
||||
|
||||
const (
|
||||
SecretMode fs.FileMode = 0600
|
||||
DefaultMode fs.FileMode = 0644
|
||||
)
|
||||
|
||||
type item struct {
|
||||
sync.Mutex `json:"-"`
|
||||
Version int `json:"version"`
|
||||
MinorVersion *int `json:"minor_version,omitempty"`
|
||||
Key string `json:"key"`
|
||||
Data interface{} `json:"data"`
|
||||
|
||||
fmode fs.FileMode
|
||||
dirty bool
|
||||
}
|
||||
|
||||
func (i *item) Dirty() { i.Lock(); defer i.Unlock(); i.dirty = true }
|
||||
func (i *item) IsDirty() bool { i.Lock(); defer i.Unlock(); return i.dirty }
|
||||
func (i *item) GetData() interface{} { i.Lock(); defer i.Unlock(); return i.Data }
|
||||
func (i *item) SetData(d interface{}) { i.Lock(); defer i.Unlock(); i.Data = d; i.dirty = true }
|
||||
func (i *item) ItemKey() string { return i.Key /* key is immutable */ }
|
||||
|
||||
func (it *item) mode() fs.FileMode {
|
||||
if it.fmode != 0 {
|
||||
return it.fmode
|
||||
}
|
||||
|
||||
return SecretMode
|
||||
}
|
||||
|
||||
type fsStore struct {
|
||||
sync.RWMutex
|
||||
fs fs.FS
|
||||
storeRoot string
|
||||
s map[string]*item
|
||||
}
|
||||
|
||||
func (s *fsStore) get(key string) *item {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
|
||||
i, ok := s.s[key]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return i
|
||||
}
|
||||
|
||||
func (s *fsStore) put(key string, it *item) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
s.s[key] = it
|
||||
}
|
||||
|
||||
func (s *fsStore) persist(it *item) error {
|
||||
it.Lock()
|
||||
defer it.Unlock()
|
||||
|
||||
f, err := os.OpenFile(path.Join(s.storeRoot, it.Key), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, it.mode())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer f.Close()
|
||||
|
||||
enc := json.NewEncoder(f)
|
||||
enc.SetIndent("", IndentStr)
|
||||
|
||||
err = enc.Encode(it)
|
||||
if err == nil {
|
||||
it.dirty = false
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *fsStore) Dirty(key string) error {
|
||||
it := s.get(key)
|
||||
if it == nil {
|
||||
return ErrNoSuchKey
|
||||
}
|
||||
|
||||
it.Dirty()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fsStore) Flush(key string) error {
|
||||
it := s.get(key)
|
||||
if it == nil {
|
||||
return ErrNoSuchKey
|
||||
}
|
||||
|
||||
return s.persist(it)
|
||||
}
|
||||
|
||||
func (s *fsStore) FlushAll() []error {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
|
||||
var errs []error
|
||||
for _, it := range s.s {
|
||||
err := s.persist(it)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("store key %s: %w", it.Key, err))
|
||||
}
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
|
||||
func (s *fsStore) Shutdown() {
|
||||
errs := s.FlushAll()
|
||||
if errs != nil {
|
||||
log.Error().Errs("errors", errs).Msg("errors persisting store")
|
||||
}
|
||||
}
|
||||
|
||||
// Put puts an item into the store.
|
||||
// NB: Any user of a previous item with this key will now have a dangling reference that will not be persisted.
|
||||
// It is up to consumers to coordinate against this case!
|
||||
func (s *fsStore) Put(key string, version, minorVersion int, secretMode bool, data interface{}) (Item, error) {
|
||||
var mv *int
|
||||
if minorVersion != 0 {
|
||||
mv = &minorVersion
|
||||
}
|
||||
|
||||
mode := DefaultMode
|
||||
|
||||
if secretMode {
|
||||
mode = SecretMode
|
||||
}
|
||||
|
||||
it := &item{
|
||||
Version: version,
|
||||
MinorVersion: mv,
|
||||
Key: key,
|
||||
Data: data,
|
||||
|
||||
fmode: mode,
|
||||
dirty: true,
|
||||
}
|
||||
|
||||
s.s[key] = it
|
||||
return it, s.persist(it)
|
||||
}
|
||||
|
||||
func (s *fsStore) Get(key string, data interface{}) error {
|
||||
_, err := s.GetItem(key, data)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *fsStore) GetItem(key string, data interface{}) (Item, error) {
|
||||
exists := s.get(key)
|
||||
if exists != nil {
|
||||
return exists, ErrKeyExists
|
||||
}
|
||||
|
||||
f, err := s.fs.Open(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer f.Close()
|
||||
|
||||
fi, err := f.Stat()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
item := &item{
|
||||
Data: data,
|
||||
fmode: fi.Mode(),
|
||||
}
|
||||
d := json.NewDecoder(f)
|
||||
err = d.Decode(item)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if item.Key != key {
|
||||
return nil, fmt.Errorf("key mismatch '%s' != '%s'", item.Key, key)
|
||||
}
|
||||
|
||||
s.put(key, item)
|
||||
|
||||
return item, nil
|
||||
}
|
||||
|
||||
func OpenFileStore(configRoot string) (*fsStore, error) {
|
||||
storeRoot := path.Join(configRoot, ".storage")
|
||||
stor := os.DirFS(storeRoot)
|
||||
|
||||
return &fsStore{
|
||||
fs: stor,
|
||||
storeRoot: storeRoot,
|
||||
s: make(map[string]*item),
|
||||
}, nil
|
||||
}
|
|
@ -1,58 +0,0 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
)
|
||||
|
||||
type Data interface {
|
||||
}
|
||||
|
||||
type Item struct {
|
||||
Version int `json:"version"`
|
||||
MinorVersion *int `json:"minor_version,omitempty"`
|
||||
Key string `json:"key"`
|
||||
Data Data `json:"data"`
|
||||
}
|
||||
|
||||
type store struct {
|
||||
fs.FS
|
||||
}
|
||||
|
||||
type Store interface {
|
||||
Get(key string, data interface{}) error
|
||||
}
|
||||
|
||||
func (s *store) Get(key string, data interface{}) error {
|
||||
f, err := s.Open(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer f.Close()
|
||||
|
||||
item := Item{
|
||||
Data: data,
|
||||
}
|
||||
d := json.NewDecoder(f)
|
||||
err = d.Decode(&item)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if item.Key != key {
|
||||
return fmt.Errorf("key mismatch '%s' != '%s'", item.Key, key)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Open(dir fs.FS) (*store, error) {
|
||||
stor, err := fs.Sub(dir, ".storage")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &store{stor}, nil
|
||||
}
|
55
pkg/storage/store.go
Normal file
55
pkg/storage/store.go
Normal file
|
@ -0,0 +1,55 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNoSuchKey = errors.New("no such key in store")
|
||||
ErrKeyExists = errors.New("key already exists")
|
||||
)
|
||||
|
||||
// Item is an item in a datastore.
|
||||
type Item interface {
|
||||
// Item is lockable if updating data item directly.
|
||||
sync.Locker
|
||||
|
||||
// Dirty sets the dirty flag for the item so it will be flushed.
|
||||
Dirty()
|
||||
|
||||
// IsDirty gets the dirty flag for the item.
|
||||
IsDirty() bool
|
||||
|
||||
// GetData gets the data for the item.
|
||||
GetData() interface{}
|
||||
|
||||
// GetData sets the data for the item.
|
||||
SetData(interface{})
|
||||
|
||||
// ItemKey gets the key of the item.
|
||||
ItemKey() string
|
||||
}
|
||||
|
||||
// Store represents a datastore.
|
||||
type Store interface {
|
||||
// GetItem loads the specified key from the store into data and returns the Item.
|
||||
// If err is ErrKeyExists, Item will be the existing item.
|
||||
GetItem(key string, data interface{}) (Item, error)
|
||||
|
||||
// Get is the same as GetItem, but only returns error.
|
||||
Get(key string, data interface{}) error
|
||||
|
||||
// Put puts the specified key into the store. If the key already exists, it clobbers.
|
||||
// Note that any existing items will then dangle.
|
||||
Put(key string, version, minorVersion int, secretMode bool, data interface{}) (Item, error)
|
||||
|
||||
// FlushAll flushes the store to backing.
|
||||
FlushAll() []error
|
||||
|
||||
// Flush flushes a single key to backing.
|
||||
Flush(key string) error
|
||||
|
||||
// Shutdown is called to quiesce and shutdown the store.
|
||||
Shutdown()
|
||||
}
|
Loading…
Reference in a new issue