RBAC #102

Merged
amigan merged 6 commits from rbac into trunk 2025-01-18 17:22:09 -05:00
20 changed files with 702 additions and 173 deletions
Showing only changes of commit aec89b5569 - Show all commits

View file

@ -1,7 +1,6 @@
package cache package cache
import ( import "sync"
)
type Cache[K comparable, V any] interface { type Cache[K comparable, V any] interface {
Get(K) (V, bool) Get(K) (V, bool)
@ -10,25 +9,38 @@ type Cache[K comparable, V any] interface {
Clear() Clear()
} }
type inMem[K comparable, V any] map[K]V type inMem[K comparable, V any] struct {
sync.RWMutex
func New[K comparable, V any]() inMem[K, V] { m map[K]V
return make(inMem[K, V])
} }
func (c inMem[K, V]) Get(key K) (V, bool) { func New[K comparable, V any]() *inMem[K, V] {
v, ok := c[key] return &inMem[K, V]{
m: make(map[K]V),
}
}
func (c *inMem[K, V]) Get(key K) (V, bool) {
c.RLock()
defer c.RUnlock()
v, ok := c.m[key]
return v, ok return v, ok
} }
func (c inMem[K, V]) Set(key K, val V) { func (c *inMem[K, V]) Set(key K, val V) {
c[key] = val c.Lock()
defer c.Unlock()
c.m[key] = val
} }
func (c inMem[K, V]) Delete(key K) { func (c *inMem[K, V]) Delete(key K) {
delete(c, key) c.Lock()
defer c.Unlock()
delete(c.m, key)
} }
func (c inMem[K, V]) Clear() { func (c *inMem[K, V]) Clear() {
clear(c) c.Lock()
defer c.Unlock()
clear(c.m)
} }

View file

@ -11,11 +11,13 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/database"
"dynatron.me/x/stillbox/pkg/rbac"
"dynatron.me/x/stillbox/pkg/users" "dynatron.me/x/stillbox/pkg/users"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/go-chi/jwtauth/v5" "github.com/go-chi/jwtauth/v5"
"github.com/go-chi/render" "github.com/go-chi/render"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@ -45,22 +47,16 @@ type jwtAuth interface {
type claims map[string]interface{} type claims map[string]interface{}
// TODO: change this to UserFrom() *users.User // UsernameFrom gets the username (just the subject from token) from ctx.
func UIDFrom(ctx context.Context) *users.UserID { func UsernameFrom(ctx context.Context) *string {
tok, _, err := jwtauth.FromContext(ctx) tok, _, err := jwtauth.FromContext(ctx)
if err != nil { if err != nil {
return nil return nil
} }
uidStr := tok.Subject() username := tok.Subject()
uidInt, err := strconv.Atoi(uidStr)
if err != nil {
return nil
}
uid := users.UserID(int32(uidInt)) return &username
return &uid
} }
func (a *Auth) Authenticated(r *http.Request) (claims, bool) { func (a *Auth) Authenticated(r *http.Request) (claims, bool) {
@ -90,7 +86,38 @@ func TokenFromCookie(r *http.Request) string {
} }
func (a *Auth) AuthMiddleware() func(http.Handler) http.Handler { func (a *Auth) AuthMiddleware() func(http.Handler) http.Handler {
return jwtauth.Authenticator(a.jwt) return func(next http.Handler) http.Handler {
hfn := func(w http.ResponseWriter, r *http.Request) {
token, _, err := jwtauth.FromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
if token != nil && jwt.Validate(token, a.jwt.ValidateOptions()...) == nil {
ctx := r.Context()
username := token.Subject()
sub, err := users.FromCtx(ctx).GetUser(ctx, username)
if err != nil {
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
ctx = rbac.CtxWithSubject(ctx, sub)
next.ServeHTTP(w, r.WithContext(ctx))
return
}
// Token is authenticated, pass it through
next.ServeHTTP(w, r)
}
return http.HandlerFunc(hfn)
}
} }
func (a *Auth) initJWT() { func (a *Auth) initJWT() {

View file

@ -8,6 +8,7 @@ import (
"dynatron.me/x/stillbox/internal/audio" "dynatron.me/x/stillbox/internal/audio"
"dynatron.me/x/stillbox/internal/jsontypes" "dynatron.me/x/stillbox/internal/jsontypes"
"dynatron.me/x/stillbox/pkg/pb" "dynatron.me/x/stillbox/pkg/pb"
"dynatron.me/x/stillbox/pkg/rbac"
"dynatron.me/x/stillbox/pkg/talkgroups" "dynatron.me/x/stillbox/pkg/talkgroups"
"dynatron.me/x/stillbox/pkg/users" "dynatron.me/x/stillbox/pkg/users"
@ -73,6 +74,10 @@ type Call struct {
shouldStore bool `json:"-"` shouldStore bool `json:"-"`
} }
func (c *Call) GetResourceName() string {
return rbac.ResourceCall
}
func (c *Call) String() string { func (c *Call) String() string {
return fmt.Sprintf("%s to %d from %d", c.AudioName, c.Talkgroup, c.Source) return fmt.Sprintf("%s to %d from %d", c.AudioName, c.Talkgroup, c.Source)
} }

View file

@ -1127,22 +1127,22 @@ func (_c *Store_DropPartition_Call) RunAndReturn(run func(context.Context, strin
} }
// GetAPIKey provides a mock function with given fields: ctx, apiKey // GetAPIKey provides a mock function with given fields: ctx, apiKey
func (_m *Store) GetAPIKey(ctx context.Context, apiKey string) (database.ApiKey, error) { func (_m *Store) GetAPIKey(ctx context.Context, apiKey string) (database.GetAPIKeyRow, error) {
ret := _m.Called(ctx, apiKey) ret := _m.Called(ctx, apiKey)
if len(ret) == 0 { if len(ret) == 0 {
panic("no return value specified for GetAPIKey") panic("no return value specified for GetAPIKey")
} }
var r0 database.ApiKey var r0 database.GetAPIKeyRow
var r1 error var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string) (database.ApiKey, error)); ok { if rf, ok := ret.Get(0).(func(context.Context, string) (database.GetAPIKeyRow, error)); ok {
return rf(ctx, apiKey) return rf(ctx, apiKey)
} }
if rf, ok := ret.Get(0).(func(context.Context, string) database.ApiKey); ok { if rf, ok := ret.Get(0).(func(context.Context, string) database.GetAPIKeyRow); ok {
r0 = rf(ctx, apiKey) r0 = rf(ctx, apiKey)
} else { } else {
r0 = ret.Get(0).(database.ApiKey) r0 = ret.Get(0).(database.GetAPIKeyRow)
} }
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
@ -1173,12 +1173,12 @@ func (_c *Store_GetAPIKey_Call) Run(run func(ctx context.Context, apiKey string)
return _c return _c
} }
func (_c *Store_GetAPIKey_Call) Return(_a0 database.ApiKey, _a1 error) *Store_GetAPIKey_Call { func (_c *Store_GetAPIKey_Call) Return(_a0 database.GetAPIKeyRow, _a1 error) *Store_GetAPIKey_Call {
_c.Call.Return(_a0, _a1) _c.Call.Return(_a0, _a1)
return _c return _c
} }
func (_c *Store_GetAPIKey_Call) RunAndReturn(run func(context.Context, string) (database.ApiKey, error)) *Store_GetAPIKey_Call { func (_c *Store_GetAPIKey_Call) RunAndReturn(run func(context.Context, string) (database.GetAPIKeyRow, error)) *Store_GetAPIKey_Call {
_c.Call.Return(run) _c.Call.Return(run)
return _c return _c
} }
@ -2584,63 +2584,6 @@ func (_c *Store_GetUserByID_Call) RunAndReturn(run func(context.Context, int) (d
return _c return _c
} }
// GetUserByUID provides a mock function with given fields: ctx, id
func (_m *Store) GetUserByUID(ctx context.Context, id int) (database.User, error) {
ret := _m.Called(ctx, id)
if len(ret) == 0 {
panic("no return value specified for GetUserByUID")
}
var r0 database.User
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, int) (database.User, error)); ok {
return rf(ctx, id)
}
if rf, ok := ret.Get(0).(func(context.Context, int) database.User); ok {
r0 = rf(ctx, id)
} else {
r0 = ret.Get(0).(database.User)
}
if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Store_GetUserByUID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetUserByUID'
type Store_GetUserByUID_Call struct {
*mock.Call
}
// GetUserByUID is a helper method to define mock.On call
// - ctx context.Context
// - id int
func (_e *Store_Expecter) GetUserByUID(ctx interface{}, id interface{}) *Store_GetUserByUID_Call {
return &Store_GetUserByUID_Call{Call: _e.mock.On("GetUserByUID", ctx, id)}
}
func (_c *Store_GetUserByUID_Call) Run(run func(ctx context.Context, id int)) *Store_GetUserByUID_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int))
})
return _c
}
func (_c *Store_GetUserByUID_Call) Return(_a0 database.User, _a1 error) *Store_GetUserByUID_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *Store_GetUserByUID_Call) RunAndReturn(run func(context.Context, int) (database.User, error)) *Store_GetUserByUID_Call {
_c.Call.Return(run)
return _c
}
// GetUserByUsername provides a mock function with given fields: ctx, username // GetUserByUsername provides a mock function with given fields: ctx, username
func (_m *Store) GetUserByUsername(ctx context.Context, username string) (database.User, error) { func (_m *Store) GetUserByUsername(ctx context.Context, username string) (database.User, error) {
ret := _m.Called(ctx, username) ret := _m.Called(ctx, username)
@ -3702,6 +3645,65 @@ func (_c *Store_UpdateTalkgroup_Call) RunAndReturn(run func(context.Context, dat
return _c return _c
} }
// UpdateUser provides a mock function with given fields: ctx, username, email, isAdmin
func (_m *Store) UpdateUser(ctx context.Context, username string, email *string, isAdmin *bool) (database.User, error) {
ret := _m.Called(ctx, username, email, isAdmin)
if len(ret) == 0 {
panic("no return value specified for UpdateUser")
}
var r0 database.User
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, *string, *bool) (database.User, error)); ok {
return rf(ctx, username, email, isAdmin)
}
if rf, ok := ret.Get(0).(func(context.Context, string, *string, *bool) database.User); ok {
r0 = rf(ctx, username, email, isAdmin)
} else {
r0 = ret.Get(0).(database.User)
}
if rf, ok := ret.Get(1).(func(context.Context, string, *string, *bool) error); ok {
r1 = rf(ctx, username, email, isAdmin)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Store_UpdateUser_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateUser'
type Store_UpdateUser_Call struct {
*mock.Call
}
// UpdateUser is a helper method to define mock.On call
// - ctx context.Context
// - username string
// - email *string
// - isAdmin *bool
func (_e *Store_Expecter) UpdateUser(ctx interface{}, username interface{}, email interface{}, isAdmin interface{}) *Store_UpdateUser_Call {
return &Store_UpdateUser_Call{Call: _e.mock.On("UpdateUser", ctx, username, email, isAdmin)}
}
func (_c *Store_UpdateUser_Call) Run(run func(ctx context.Context, username string, email *string, isAdmin *bool)) *Store_UpdateUser_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(*string), args[3].(*bool))
})
return _c
}
func (_c *Store_UpdateUser_Call) Return(_a0 database.User, _a1 error) *Store_UpdateUser_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *Store_UpdateUser_Call) RunAndReturn(run func(context.Context, string, *string, *bool) (database.User, error)) *Store_UpdateUser_Call {
_c.Call.Return(run)
return _c
}
// UpsertTalkgroup provides a mock function with given fields: ctx, arg // UpsertTalkgroup provides a mock function with given fields: ctx, arg
func (_m *Store) UpsertTalkgroup(ctx context.Context, arg []database.UpsertTalkgroupParams) *database.UpsertTalkgroupBatchResults { func (_m *Store) UpsertTalkgroup(ctx context.Context, arg []database.UpsertTalkgroupParams) *database.UpsertTalkgroupBatchResults {
ret := _m.Called(ctx, arg) ret := _m.Called(ctx, arg)

View file

@ -28,7 +28,7 @@ type Querier interface {
DeleteSystem(ctx context.Context, id int) error DeleteSystem(ctx context.Context, id int) error
DeleteTalkgroup(ctx context.Context, systemID int32, tGID int32) error DeleteTalkgroup(ctx context.Context, systemID int32, tGID int32) error
DeleteUser(ctx context.Context, username string) error DeleteUser(ctx context.Context, username string) error
GetAPIKey(ctx context.Context, apiKey string) (ApiKey, error) GetAPIKey(ctx context.Context, apiKey string) (GetAPIKeyRow, error)
GetAllTalkgroupTags(ctx context.Context) ([]string, error) GetAllTalkgroupTags(ctx context.Context) ([]string, error)
GetAppPrefs(ctx context.Context, appName string, uid int) ([]byte, error) GetAppPrefs(ctx context.Context, appName string, uid int) ([]byte, error)
GetCallAudioByID(ctx context.Context, id uuid.UUID) (GetCallAudioByIDRow, error) GetCallAudioByID(ctx context.Context, id uuid.UUID) (GetCallAudioByIDRow, error)
@ -50,7 +50,6 @@ type Querier interface {
GetTalkgroupsWithLearnedCount(ctx context.Context, filter *string) (int64, error) GetTalkgroupsWithLearnedCount(ctx context.Context, filter *string) (int64, error)
GetTalkgroupsWithLearnedP(ctx context.Context, arg GetTalkgroupsWithLearnedPParams) ([]GetTalkgroupsWithLearnedPRow, error) GetTalkgroupsWithLearnedP(ctx context.Context, arg GetTalkgroupsWithLearnedPParams) ([]GetTalkgroupsWithLearnedPRow, error)
GetUserByID(ctx context.Context, id int) (User, error) GetUserByID(ctx context.Context, id int) (User, error)
GetUserByUID(ctx context.Context, id int) (User, error)
GetUserByUsername(ctx context.Context, username string) (User, error) GetUserByUsername(ctx context.Context, username string) (User, error)
GetUsers(ctx context.Context) ([]User, error) GetUsers(ctx context.Context) ([]User, error)
ListCallsCount(ctx context.Context, arg ListCallsCountParams) (int64, error) ListCallsCount(ctx context.Context, arg ListCallsCountParams) (int64, error)
@ -71,6 +70,7 @@ type Querier interface {
UpdateIncident(ctx context.Context, arg UpdateIncidentParams) (Incident, error) UpdateIncident(ctx context.Context, arg UpdateIncidentParams) (Incident, error)
UpdatePassword(ctx context.Context, username string, password string) error UpdatePassword(ctx context.Context, username string, password string) error
UpdateTalkgroup(ctx context.Context, arg UpdateTalkgroupParams) (Talkgroup, error) UpdateTalkgroup(ctx context.Context, arg UpdateTalkgroupParams) (Talkgroup, error)
UpdateUser(ctx context.Context, username string, email *string, isAdmin *bool) (User, error)
UpsertTalkgroup(ctx context.Context, arg []UpsertTalkgroupParams) *UpsertTalkgroupBatchResults UpsertTalkgroup(ctx context.Context, arg []UpsertTalkgroupParams) *UpsertTalkgroupBatchResults
} }

View file

@ -7,6 +7,7 @@ package database
import ( import (
"context" "context"
"time"
"github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgtype"
) )
@ -91,12 +92,32 @@ func (q *Queries) DeleteUser(ctx context.Context, username string) error {
} }
const getAPIKey = `-- name: GetAPIKey :one const getAPIKey = `-- name: GetAPIKey :one
SELECT id, owner, created_at, expires, disabled, api_key FROM api_keys WHERE api_key = $1 SELECT
a.id,
a.owner,
a.created_at,
a.expires,
a.disabled,
a.api_key,
u.username
FROM api_keys a
JOIN users u ON (a.owner = u.id)
WHERE api_key = $1
` `
func (q *Queries) GetAPIKey(ctx context.Context, apiKey string) (ApiKey, error) { type GetAPIKeyRow struct {
ID int `json:"id"`
Owner int `json:"owner"`
CreatedAt time.Time `json:"created_at"`
Expires pgtype.Timestamp `json:"expires"`
Disabled *bool `json:"disabled"`
ApiKey string `json:"api_key"`
Username string `json:"username"`
}
func (q *Queries) GetAPIKey(ctx context.Context, apiKey string) (GetAPIKeyRow, error) {
row := q.db.QueryRow(ctx, getAPIKey, apiKey) row := q.db.QueryRow(ctx, getAPIKey, apiKey)
var i ApiKey var i GetAPIKeyRow
err := row.Scan( err := row.Scan(
&i.ID, &i.ID,
&i.Owner, &i.Owner,
@ -104,6 +125,7 @@ func (q *Queries) GetAPIKey(ctx context.Context, apiKey string) (ApiKey, error)
&i.Expires, &i.Expires,
&i.Disabled, &i.Disabled,
&i.ApiKey, &i.ApiKey,
&i.Username,
) )
return i, err return i, err
} }
@ -121,7 +143,7 @@ func (q *Queries) GetAppPrefs(ctx context.Context, appName string, uid int) ([]b
const getUserByID = `-- name: GetUserByID :one const getUserByID = `-- name: GetUserByID :one
SELECT id, username, password, email, is_admin, prefs FROM users SELECT id, username, password, email, is_admin, prefs FROM users
WHERE id = $1 LIMIT 1 WHERE id = $1
` `
func (q *Queries) GetUserByID(ctx context.Context, id int) (User, error) { func (q *Queries) GetUserByID(ctx context.Context, id int) (User, error) {
@ -138,28 +160,9 @@ func (q *Queries) GetUserByID(ctx context.Context, id int) (User, error) {
return i, err return i, err
} }
const getUserByUID = `-- name: GetUserByUID :one
SELECT id, username, password, email, is_admin, prefs FROM users
WHERE id = $1 LIMIT 1
`
func (q *Queries) GetUserByUID(ctx context.Context, id int) (User, error) {
row := q.db.QueryRow(ctx, getUserByUID, id)
var i User
err := row.Scan(
&i.ID,
&i.Username,
&i.Password,
&i.Email,
&i.IsAdmin,
&i.Prefs,
)
return i, err
}
const getUserByUsername = `-- name: GetUserByUsername :one const getUserByUsername = `-- name: GetUserByUsername :one
SELECT id, username, password, email, is_admin, prefs FROM users SELECT id, username, password, email, is_admin, prefs FROM users
WHERE username = $1 LIMIT 1 WHERE username = $1
` `
func (q *Queries) GetUserByUsername(ctx context.Context, username string) (User, error) { func (q *Queries) GetUserByUsername(ctx context.Context, username string) (User, error) {
@ -224,3 +227,26 @@ func (q *Queries) UpdatePassword(ctx context.Context, username string, password
_, err := q.db.Exec(ctx, updatePassword, username, password) _, err := q.db.Exec(ctx, updatePassword, username, password)
return err return err
} }
const updateUser = `-- name: UpdateUser :one
UPDATE users SET
email = COALESCE($2, email),
is_admin = COALESCE($3, is_admin)
WHERE
username = $1
RETURNING id, username, password, email, is_admin, prefs
`
func (q *Queries) UpdateUser(ctx context.Context, username string, email *string, isAdmin *bool) (User, error) {
row := q.db.QueryRow(ctx, updateUser, username, email, isAdmin)
var i User
err := row.Scan(
&i.ID,
&i.Username,
&i.Password,
&i.Email,
&i.IsAdmin,
&i.Prefs,
)
return i, err
}

View file

@ -5,11 +5,14 @@ import (
"dynatron.me/x/stillbox/internal/jsontypes" "dynatron.me/x/stillbox/internal/jsontypes"
"dynatron.me/x/stillbox/pkg/calls" "dynatron.me/x/stillbox/pkg/calls"
"dynatron.me/x/stillbox/pkg/rbac"
"dynatron.me/x/stillbox/pkg/users"
"github.com/google/uuid" "github.com/google/uuid"
) )
type Incident struct { type Incident struct {
ID uuid.UUID `json:"id"` ID uuid.UUID `json:"id"`
Owner users.UserID `json:"owner"`
Name string `json:"name"` Name string `json:"name"`
Description *string `json:"description"` Description *string `json:"description"`
StartTime *jsontypes.Time `json:"startTime"` StartTime *jsontypes.Time `json:"startTime"`
@ -19,6 +22,10 @@ type Incident struct {
Calls []IncidentCall `json:"calls"` Calls []IncidentCall `json:"calls"`
} }
func (inc *Incident) GetResourceName() string {
return rbac.ResourceIncident
}
type IncidentCall struct { type IncidentCall struct {
calls.Call calls.Call
Notes json.RawMessage `json:"notes"` Notes json.RawMessage `json:"notes"`

View file

@ -6,11 +6,9 @@ import (
"dynatron.me/x/stillbox/internal/common" "dynatron.me/x/stillbox/internal/common"
"dynatron.me/x/stillbox/internal/jsontypes" "dynatron.me/x/stillbox/internal/jsontypes"
"dynatron.me/x/stillbox/pkg/auth"
"dynatron.me/x/stillbox/pkg/calls" "dynatron.me/x/stillbox/pkg/calls"
"dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/database"
"dynatron.me/x/stillbox/pkg/incidents" "dynatron.me/x/stillbox/pkg/incidents"
"dynatron.me/x/stillbox/pkg/rbac"
"dynatron.me/x/stillbox/pkg/users" "dynatron.me/x/stillbox/pkg/users"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
@ -74,22 +72,22 @@ func NewStore() Store {
} }
func (s *store) CreateIncident(ctx context.Context, inc incidents.Incident) (*incidents.Incident, error) { func (s *store) CreateIncident(ctx context.Context, inc incidents.Incident) (*incidents.Incident, error) {
// TODO: replace this with a real RBAC check
user, err := users.UserCheck(ctx, new(incidents.Incident), "create")
if err != nil {
return nil, err
}
db := database.FromCtx(ctx) db := database.FromCtx(ctx)
var dbInc database.Incident var dbInc database.Incident
// TODO: replace this with a real RBAC check
owner := auth.UIDFrom(ctx)
if owner == nil {
return nil, rbac.ErrNotAuthorized
}
id := uuid.New() id := uuid.New()
txErr := db.InTx(ctx, func(db database.Store) error { txErr := db.InTx(ctx, func(db database.Store) error {
var err error var err error
dbInc, err = db.CreateIncident(ctx, database.CreateIncidentParams{ dbInc, err = db.CreateIncident(ctx, database.CreateIncidentParams{
ID: id, ID: id,
Owner: owner.Int(), Owner: user.ID.Int(),
Name: inc.Name, Name: inc.Name,
Description: inc.Description, Description: inc.Description,
StartTime: inc.StartTime.PGTypeTSTZ(), StartTime: inc.StartTime.PGTypeTSTZ(),
@ -205,6 +203,7 @@ func (s *store) Incidents(ctx context.Context, p IncidentsParams) (incs []Incide
func fromDBIncident(id uuid.UUID, d database.Incident) incidents.Incident { func fromDBIncident(id uuid.UUID, d database.Incident) incidents.Incident {
return incidents.Incident{ return incidents.Incident{
ID: id, ID: id,
Owner: users.UserID(d.Owner),
Name: d.Name, Name: d.Name,
Description: d.Description, Description: d.Description,
StartTime: jsontypes.TimePtrFromTSTZ(d.StartTime), StartTime: jsontypes.TimePtrFromTSTZ(d.StartTime),
@ -223,6 +222,7 @@ func fromDBListInPRow(id uuid.UUID, d database.ListIncidentsPRow) Incident {
return Incident{ return Incident{
Incident: incidents.Incident{ Incident: incidents.Incident{
ID: id, ID: id,
Owner: users.UserID(d.Owner),
Name: d.Name, Name: d.Name,
Description: d.Description, Description: d.Description,
StartTime: jsontypes.TimePtrFromTSTZ(d.StartTime), StartTime: jsontypes.TimePtrFromTSTZ(d.StartTime),

View file

@ -1,26 +1,292 @@
package rbac package rbac
import ( import (
"context"
"errors" "errors"
"github.com/el-mike/restrict/v2" "github.com/el-mike/restrict/v2"
"github.com/el-mike/restrict/v2/adapters"
) )
const (
RoleUser = "User"
RoleSubmitter = "Submitter"
RoleAdmin = "Admin"
RolePublic = "Public"
RoleShareGuest = "ShareGuest"
ResourceCall = "Call"
ResourceIncident = "Incident"
ResourceTalkgroup = "Talkgroup"
ResourceAlert = "Alert"
ResourceShare = "Share"
ActionRead = "read"
ActionCreate = "create"
ActionUpdate = "update"
ActionDelete = "delete"
PresetUpdateOwn = "updateOwn"
PresetDeleteOwn = "deleteOwn"
PresetReadShared = "readShared"
)
var (
ErrBadSubject = errors.New("bad subject in token")
)
type subjectContextKey string
const SubjectCtxKey subjectContextKey = "sub"
func CtxWithSubject(ctx context.Context, sub Subject) context.Context {
return context.WithValue(ctx, SubjectCtxKey, sub)
}
func SubjectFrom(ctx context.Context) Subject {
sub, ok := ctx.Value(SubjectCtxKey).(Subject)
if ok {
return sub
}
return new(PublicSubject)
}
type rbacCtxKey string
const RBACCtxKey rbacCtxKey = "rbac"
func FromCtx(ctx context.Context) RBAC {
rbac, ok := ctx.Value(RBACCtxKey).(RBAC)
if !ok {
panic("no RBAC in context")
}
return rbac
}
func CtxWithRBAC(ctx context.Context, rbac RBAC) context.Context {
return context.WithValue(ctx, RBACCtxKey, rbac)
}
var ( var (
ErrNotAuthorized = errors.New("not authorized") ErrNotAuthorized = errors.New("not authorized")
) )
var policy = &restrict.PolicyDefinition{ var policy = &restrict.PolicyDefinition{
Roles: restrict.Roles{ Roles: restrict.Roles{
"User": { RoleUser: {
Description: "An authenticated user",
Grants: restrict.GrantsMap{ Grants: restrict.GrantsMap{
"Conversation": { ResourceIncident: {
&restrict.Permission{Action: "read"}, &restrict.Permission{Action: ActionRead},
&restrict.Permission{Action: "create"}, &restrict.Permission{Action: ActionCreate},
&restrict.Permission{Preset: PresetUpdateOwn},
&restrict.Permission{Preset: PresetDeleteOwn},
},
ResourceCall: {
&restrict.Permission{Action: ActionRead},
&restrict.Permission{Action: ActionCreate},
&restrict.Permission{Preset: PresetUpdateOwn},
&restrict.Permission{Preset: PresetDeleteOwn},
},
ResourceTalkgroup: {
&restrict.Permission{Action: ActionRead},
},
ResourceShare: {
&restrict.Permission{Action: ActionRead},
&restrict.Permission{Action: ActionCreate},
&restrict.Permission{Preset: PresetUpdateOwn},
&restrict.Permission{Preset: PresetDeleteOwn},
},
},
},
RoleSubmitter: {
Description: "A role that can submit calls",
Grants: restrict.GrantsMap{
ResourceCall: {
&restrict.Permission{Action: ActionCreate},
},
},
},
RoleShareGuest: {
Description: "Someone who has a valid share link",
Grants: restrict.GrantsMap{
ResourceCall: {
&restrict.Permission{Preset: PresetReadShared},
},
ResourceIncident: {
&restrict.Permission{Preset: PresetReadShared},
},
ResourceTalkgroup: {
&restrict.Permission{Action: ActionRead},
},
},
},
RoleAdmin: {
Parents: []string{RoleUser},
Grants: restrict.GrantsMap{
ResourceIncident: {
&restrict.Permission{Action: ActionUpdate},
&restrict.Permission{Action: ActionDelete},
},
ResourceCall: {
&restrict.Permission{Action: ActionUpdate},
&restrict.Permission{Action: ActionDelete},
},
ResourceTalkgroup: {
&restrict.Permission{Action: ActionUpdate},
&restrict.Permission{Action: ActionCreate},
&restrict.Permission{Action: ActionDelete},
},
},
},
RolePublic: {
/*
Grants: restrict.GrantsMap{
ResourceShare: {
&restrict.Permission{Action: ActionRead},
},
},
*/
},
},
PermissionPresets: restrict.PermissionPresets{
PresetUpdateOwn: &restrict.Permission{
Action: ActionUpdate,
Conditions: restrict.Conditions{
&restrict.EqualCondition{
ID: "isOwner",
Left: &restrict.ValueDescriptor{
Source: restrict.ResourceField,
Field: "Owner",
},
Right: &restrict.ValueDescriptor{
Source: restrict.SubjectField,
Field: "ID",
},
},
},
},
PresetDeleteOwn: &restrict.Permission{
Action: ActionDelete,
Conditions: restrict.Conditions{
&restrict.EqualCondition{
ID: "isOwner",
Left: &restrict.ValueDescriptor{
Source: restrict.ResourceField,
Field: "Owner",
},
Right: &restrict.ValueDescriptor{
Source: restrict.SubjectField,
Field: "ID",
},
},
},
},
PresetReadShared: &restrict.Permission{
Action: ActionRead,
Conditions: restrict.Conditions{
&restrict.EqualCondition{
ID: "isOwner",
Left: &restrict.ValueDescriptor{
Source: restrict.ContextField,
Field: "Owner",
},
Right: &restrict.ValueDescriptor{
Source: restrict.SubjectField,
Field: "ID",
},
}, },
}, },
}, },
"Guest": {},
"Admin": {},
}, },
} }
type checkOptions struct {
actions []string
context restrict.Context
}
type checkOption func(*checkOptions)
func WithActions(actions ...string) checkOption {
return func(o *checkOptions) {
o.actions = append(o.actions, actions...)
}
}
func WithContext(ctx restrict.Context) checkOption {
return func(o *checkOptions) {
o.context = ctx
}
}
func UseResource(rsc string) restrict.Resource {
return restrict.UseResource(rsc)
}
type Subject interface {
restrict.Subject
}
type Resource interface {
restrict.Resource
}
type RBAC interface {
Check(ctx context.Context, res restrict.Resource, opts ...checkOption) (Subject, error)
}
type rbac struct {
policy *restrict.PolicyManager
access *restrict.AccessManager
}
func New() (*rbac, error) {
adapter := adapters.NewInMemoryAdapter(policy)
polMan, err := restrict.NewPolicyManager(adapter, true)
if err != nil {
return nil, err
}
accMan := restrict.NewAccessManager(polMan)
return &rbac{
policy: polMan,
access: accMan,
}, nil
}
func (r *rbac) Check(ctx context.Context, res restrict.Resource, opts ...checkOption) (Subject, error) {
sub := SubjectFrom(ctx)
o := checkOptions{}
for _, opt := range opts {
opt(&o)
}
req := &restrict.AccessRequest{
Subject: sub,
Resource: res,
Actions: o.actions,
Context: o.context,
}
return sub, r.access.Authorize(req)
}
type ShareLinkGuest struct {
ShareID string
}
func (s *ShareLinkGuest) GetRoles() []string {
return []string{RoleShareGuest}
}
type PublicSubject struct {
RemoteAddr string
}
func (s *PublicSubject) GetRoles() []string {
return []string{RolePublic}
}

View file

@ -129,7 +129,7 @@ var statusMapping = map[error]errResponder{
ErrTGIDMismatch: badRequestErrText, ErrTGIDMismatch: badRequestErrText,
ErrSysMismatch: badRequestErrText, ErrSysMismatch: badRequestErrText,
tgstore.ErrReference: constraintErrText, tgstore.ErrReference: constraintErrText,
ErrBadUID: unauthErrText, rbac.ErrBadSubject: unauthErrText,
ErrBadAppName: unauthErrText, ErrBadAppName: unauthErrText,
common.ErrPageOutOfRange: badRequestErrText, common.ErrPageOutOfRange: badRequestErrText,
rbac.ErrNotAuthorized: unauthErrText, rbac.ErrNotAuthorized: unauthErrText,

View file

@ -7,13 +7,13 @@ import (
"strings" "strings"
"dynatron.me/x/stillbox/pkg/auth" "dynatron.me/x/stillbox/pkg/auth"
"dynatron.me/x/stillbox/pkg/rbac"
"dynatron.me/x/stillbox/pkg/users" "dynatron.me/x/stillbox/pkg/users"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
) )
var ( var (
ErrBadUID = errors.New("bad UID in token")
ErrBadAppName = errors.New("bad app name") ErrBadAppName = errors.New("bad app name")
) )
@ -32,10 +32,10 @@ func (ua *usersAPI) Subrouter() http.Handler {
func (ua *usersAPI) getPrefs(w http.ResponseWriter, r *http.Request) { func (ua *usersAPI) getPrefs(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
uid := auth.UIDFrom(ctx) username := auth.UsernameFrom(ctx)
if uid == nil { if username == nil {
wErr(w, r, autoError(ErrBadUID)) wErr(w, r, autoError(rbac.ErrBadSubject))
return return
} }
@ -55,7 +55,7 @@ func (ua *usersAPI) getPrefs(w http.ResponseWriter, r *http.Request) {
} }
us := users.FromCtx(ctx) us := users.FromCtx(ctx)
prefs, err := us.UserPrefs(ctx, *uid, *p.AppName) prefs, err := us.UserPrefs(ctx, *username, *p.AppName)
if err != nil { if err != nil {
wErr(w, r, autoError(err)) wErr(w, r, autoError(err))
return return
@ -67,10 +67,10 @@ func (ua *usersAPI) getPrefs(w http.ResponseWriter, r *http.Request) {
func (ua *usersAPI) putPrefs(w http.ResponseWriter, r *http.Request) { func (ua *usersAPI) putPrefs(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
uid := auth.UIDFrom(ctx) username := auth.UsernameFrom(ctx)
if uid == nil { if username == nil {
wErr(w, r, autoError(ErrBadUID)) wErr(w, r, autoError(rbac.ErrBadSubject))
return return
} }
@ -102,7 +102,7 @@ func (ua *usersAPI) putPrefs(w http.ResponseWriter, r *http.Request) {
} }
us := users.FromCtx(ctx) us := users.FromCtx(ctx)
err = us.SetUserPrefs(ctx, *uid, *p.AppName, prefs) err = us.SetUserPrefs(ctx, *username, *p.AppName, prefs)
if err != nil { if err != nil {
wErr(w, r, autoError(err)) wErr(w, r, autoError(err))
return return

View file

@ -68,7 +68,7 @@ func (s *Server) setupRoutes() {
func (s *Server) WithCtxStores() func(next http.Handler) http.Handler { func (s *Server) WithCtxStores() func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) { fn := func(w http.ResponseWriter, r *http.Request) {
r = r.WithContext(s.addStoresTo(r.Context())) r = r.WithContext(s.fillCtx(r.Context()))
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
} }
return http.HandlerFunc(fn) return http.HandlerFunc(fn)

View file

@ -15,6 +15,7 @@ import (
"dynatron.me/x/stillbox/pkg/incidents/incstore" "dynatron.me/x/stillbox/pkg/incidents/incstore"
"dynatron.me/x/stillbox/pkg/nexus" "dynatron.me/x/stillbox/pkg/nexus"
"dynatron.me/x/stillbox/pkg/notify" "dynatron.me/x/stillbox/pkg/notify"
"dynatron.me/x/stillbox/pkg/rbac"
"dynatron.me/x/stillbox/pkg/rest" "dynatron.me/x/stillbox/pkg/rest"
"dynatron.me/x/stillbox/pkg/share" "dynatron.me/x/stillbox/pkg/share"
"dynatron.me/x/stillbox/pkg/sinks" "dynatron.me/x/stillbox/pkg/sinks"
@ -50,6 +51,7 @@ type Server struct {
calls callstore.Store calls callstore.Store
incidents incstore.Store incidents incstore.Store
share share.Service share share.Service
rbac rbac.RBAC
} }
func New(ctx context.Context, cfg *config.Configuration) (*Server, error) { func New(ctx context.Context, cfg *config.Configuration) (*Server, error) {
@ -75,6 +77,11 @@ func New(ctx context.Context, cfg *config.Configuration) (*Server, error) {
tgCache := tgstore.NewCache() tgCache := tgstore.NewCache()
api := rest.New(cfg.BaseURL.URL()) api := rest.New(cfg.BaseURL.URL())
rbacSvc, err := rbac.New()
if err != nil {
return nil, err
}
srv := &Server{ srv := &Server{
auth: authenticator, auth: authenticator,
conf: cfg, conf: cfg,
@ -91,6 +98,7 @@ func New(ctx context.Context, cfg *config.Configuration) (*Server, error) {
users: users.NewStore(), users: users.NewStore(),
calls: callstore.NewStore(), calls: callstore.NewStore(),
incidents: incstore.NewStore(), incidents: incstore.NewStore(),
rbac: rbacSvc,
} }
if cfg.DB.Partition.Enabled { if cfg.DB.Partition.Enabled {
@ -138,13 +146,14 @@ func New(ctx context.Context, cfg *config.Configuration) (*Server, error) {
return srv, nil return srv, nil
} }
func (s *Server) addStoresTo(ctx context.Context) context.Context { func (s *Server) fillCtx(ctx context.Context) context.Context {
ctx = database.CtxWithDB(ctx, s.db) ctx = database.CtxWithDB(ctx, s.db)
ctx = tgstore.CtxWithStore(ctx, s.tgs) ctx = tgstore.CtxWithStore(ctx, s.tgs)
ctx = users.CtxWithStore(ctx, s.users) ctx = users.CtxWithStore(ctx, s.users)
ctx = callstore.CtxWithStore(ctx, s.calls) ctx = callstore.CtxWithStore(ctx, s.calls)
ctx = incstore.CtxWithStore(ctx, s.incidents) ctx = incstore.CtxWithStore(ctx, s.incidents)
ctx = share.CtxWithStore(ctx, s.share.Store()) ctx = share.CtxWithStore(ctx, s.share.ShareStore())
ctx = rbac.CtxWithRBAC(ctx, s.rbac)
return ctx return ctx
} }
@ -154,7 +163,7 @@ func (s *Server) Go(ctx context.Context) error {
s.installHupHandler() s.installHupHandler()
ctx = s.addStoresTo(ctx) ctx = s.fillCtx(ctx)
httpSrv := &http.Server{ httpSrv := &http.Server{
Addr: s.conf.Listen, Addr: s.conf.Listen,

View file

@ -12,17 +12,17 @@ const (
) )
type Service interface { type Service interface {
Store() Store ShareStore() Store
Go(ctx context.Context) Go(ctx context.Context)
} }
type service struct { type service struct {
store Store Store
} }
func (s *service) Store() Store { func (s *service) ShareStore() Store {
return s.store return s.Store
} }
func (s *service) Go(ctx context.Context) { func (s *service) Go(ctx context.Context) {
@ -31,7 +31,7 @@ func (s *service) Go(ctx context.Context) {
for { for {
select { select {
case <-tick.C: case <-tick.C:
err := s.store.Prune(ctx) err := s.Prune(ctx)
if err != nil { if err != nil {
log.Error().Err(err).Msg("share prune failed") log.Error().Err(err).Msg("share prune failed")
} }
@ -44,6 +44,6 @@ func (s *service) Go(ctx context.Context) {
func NewService() *service { func NewService() *service {
return &service{ return &service{
store: NewStore(), Store: NewStore(),
} }
} }

View file

@ -9,6 +9,7 @@ import (
"strings" "strings"
"dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/database"
"dynatron.me/x/stillbox/pkg/rbac"
) )
type Talkgroup struct { type Talkgroup struct {
@ -17,6 +18,10 @@ type Talkgroup struct {
Learned bool `json:"learned"` Learned bool `json:"learned"`
} }
func (t *Talkgroup) GetResourceName() string {
return rbac.ResourceTalkgroup
}
func (t Talkgroup) String() string { func (t Talkgroup) String() string {
if t.System.Name == "" { if t.System.Name == "" {
t.System.Name = strconv.Itoa(int(t.Talkgroup.TGID)) t.System.Name = strconv.Itoa(int(t.Talkgroup.TGID))

View file

@ -8,11 +8,11 @@ import (
"time" "time"
"dynatron.me/x/stillbox/internal/common" "dynatron.me/x/stillbox/internal/common"
"dynatron.me/x/stillbox/pkg/auth"
"dynatron.me/x/stillbox/pkg/calls" "dynatron.me/x/stillbox/pkg/calls"
"dynatron.me/x/stillbox/pkg/config" "dynatron.me/x/stillbox/pkg/config"
"dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/database"
tgsp "dynatron.me/x/stillbox/pkg/talkgroups" tgsp "dynatron.me/x/stillbox/pkg/talkgroups"
"dynatron.me/x/stillbox/pkg/users"
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -515,20 +515,26 @@ func (t *cache) SystemName(ctx context.Context, id int) (name string, has bool)
} }
func (t *cache) UpdateTG(ctx context.Context, input database.UpdateTalkgroupParams) (*tgsp.Talkgroup, error) { func (t *cache) UpdateTG(ctx context.Context, input database.UpdateTalkgroupParams) (*tgsp.Talkgroup, error) {
user, err := users.UserCheck(ctx, new(tgsp.Talkgroup), "update")
if err != nil {
return nil, err
}
sysName, has := t.SystemName(ctx, int(*input.SystemID)) sysName, has := t.SystemName(ctx, int(*input.SystemID))
if !has { if !has {
return nil, ErrNoSuchSystem return nil, ErrNoSuchSystem
} }
db := database.FromCtx(ctx) db := database.FromCtx(ctx)
var tg database.Talkgroup var tg database.Talkgroup
err := db.InTx(ctx, func(db database.Store) error { err = db.InTx(ctx, func(db database.Store) error {
var oerr error var oerr error
tg, oerr = db.UpdateTalkgroup(ctx, input) tg, oerr = db.UpdateTalkgroup(ctx, input)
if oerr != nil { if oerr != nil {
return oerr return oerr
} }
versionBatch := db.StoreTGVersion(ctx, []database.StoreTGVersionParams{{ versionBatch := db.StoreTGVersion(ctx, []database.StoreTGVersionParams{{
Submitter: auth.UIDFrom(ctx).Int32Ptr(), Submitter: user.ID.Int32Ptr(),
TGID: *input.TGID, TGID: *input.TGID,
}}) }})
defer versionBatch.Close() defer versionBatch.Close()
@ -577,8 +583,13 @@ func (t *cache) DeleteTG(ctx context.Context, id tgsp.ID) error {
t.Lock() t.Lock()
defer t.Unlock() defer t.Unlock()
err := database.FromCtx(ctx).InTx(ctx, func(db database.Store) error { user, err := users.UserCheck(ctx, new(tgsp.Talkgroup), "update")
err := db.StoreDeletedTGVersion(ctx, common.PtrTo(int32(id.System)), common.PtrTo(int32(id.Talkgroup)), auth.UIDFrom(ctx).Int32Ptr()) if err != nil {
return err
}
err = database.FromCtx(ctx).InTx(ctx, func(db database.Store) error {
err := db.StoreDeletedTGVersion(ctx, common.PtrTo(int32(id.System)), common.PtrTo(int32(id.Talkgroup)), user.ID.Int32Ptr())
if err != nil { if err != nil {
return err return err
} }
@ -633,6 +644,11 @@ func (t *cache) LearnTG(ctx context.Context, c *calls.Call) (*tgsp.Talkgroup, er
} }
func (t *cache) UpsertTGs(ctx context.Context, system int, input []database.UpsertTalkgroupParams) ([]*tgsp.Talkgroup, error) { func (t *cache) UpsertTGs(ctx context.Context, system int, input []database.UpsertTalkgroupParams) ([]*tgsp.Talkgroup, error) {
user, err := users.UserCheck(ctx, new(tgsp.Talkgroup), "create+update")
if err != nil {
return nil, err
}
db := database.FromCtx(ctx) db := database.FromCtx(ctx)
sysName, hasSys := t.SystemName(ctx, system) sysName, hasSys := t.SystemName(ctx, system)
if !hasSys { if !hasSys {
@ -645,7 +661,7 @@ func (t *cache) UpsertTGs(ctx context.Context, system int, input []database.Upse
tgs := make([]*tgsp.Talkgroup, 0, len(input)) tgs := make([]*tgsp.Talkgroup, 0, len(input))
err := db.InTx(ctx, func(db database.Store) error { err = db.InTx(ctx, func(db database.Store) error {
versionParams := make([]database.StoreTGVersionParams, 0, len(input)) versionParams := make([]database.StoreTGVersionParams, 0, len(input))
for i := range input { for i := range input {
// normalize tags // normalize tags
@ -670,7 +686,7 @@ func (t *cache) UpsertTGs(ctx context.Context, system int, input []database.Upse
versionParams = append(versionParams, database.StoreTGVersionParams{ versionParams = append(versionParams, database.StoreTGVersionParams{
SystemID: int32(system), SystemID: int32(system),
TGID: r.TGID, TGID: r.TGID,
Submitter: auth.UIDFrom(ctx).Int32Ptr(), Submitter: user.ID.Int32Ptr(),
}) })
tgs = append(tgs, &tgsp.Talkgroup{ tgs = append(tgs, &tgsp.Talkgroup{
Talkgroup: r, Talkgroup: r,

21
pkg/users/guest.go Normal file
View file

@ -0,0 +1,21 @@
package users
import (
"dynatron.me/x/stillbox/pkg/rbac"
)
type ShareLinkGuest struct {
ShareID string
}
func (s *ShareLinkGuest) GetRoles() []string {
return []string{rbac.RoleShareGuest}
}
type Public struct {
RemoteAddr string
}
func (s *Public) GetRoles() []string {
return []string{rbac.RolePublic}
}

View file

@ -3,22 +3,35 @@ package users
import ( import (
"context" "context"
"dynatron.me/x/stillbox/internal/cache"
"dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/database"
) )
type Store interface { type Store interface {
// GetUser gets a user by UID.
GetUser(ctx context.Context, username string) (*User, error)
// UserPrefs gets the preferences for the specified user and app name. // UserPrefs gets the preferences for the specified user and app name.
UserPrefs(ctx context.Context, uid int32, appName string) ([]byte, error) UserPrefs(ctx context.Context, username string, appName string) ([]byte, error)
// SetUserPrefs sets the preferences for the specified user and app name. // SetUserPrefs sets the preferences for the specified user and app name.
SetUserPrefs(ctx context.Context, uid int32, appName string, prefs []byte) error SetUserPrefs(ctx context.Context, username string, appName string, prefs []byte) error
// Invalidate clears the user cache.
Invalidate()
// UpdateUser updates a user's record
UpdateUser(ctx context.Context, username string, user UserUpdate) error
} }
type postgresStore struct { type postgresStore struct {
cache.Cache[string, *User]
} }
func NewStore() *postgresStore { func NewStore() *postgresStore {
return new(postgresStore) return &postgresStore{
Cache: cache.New[string, *User](),
}
} }
type storeCtxKey string type storeCtxKey string
@ -38,10 +51,53 @@ func FromCtx(ctx context.Context) Store {
return s return s
} }
func (s *postgresStore) UserPrefs(ctx context.Context, uid int32, appName string) ([]byte, error) { func (s *postgresStore) Invalidate() {
s.Clear()
}
type UserUpdate struct {
Email *string `json:"email"`
IsAdmin *bool `json:"isAdmin"`
}
func (s *postgresStore) UpdateUser(ctx context.Context, username string, user UserUpdate) error {
dbu, err := database.FromCtx(ctx).UpdateUser(ctx, username, user.Email, user.IsAdmin)
if err != nil {
return err
}
s.Set(username, fromDBUser(dbu))
return nil
}
func (s *postgresStore) GetUser(ctx context.Context, username string) (*User, error) {
u, has := s.Get(username)
if has {
return u, nil
}
db := database.FromCtx(ctx)
dbu, err := db.GetUserByUsername(ctx, username)
if err != nil {
return nil, err
}
u = fromDBUser(dbu)
s.Set(username, u)
return u, nil
}
func (s *postgresStore) UserPrefs(ctx context.Context, username string, appName string) ([]byte, error) {
u, err := s.GetUser(ctx, username)
if err != nil {
return nil, err
}
db := database.FromCtx(ctx) db := database.FromCtx(ctx)
prefs, err := db.GetAppPrefs(ctx, appName, int(uid)) prefs, err := db.GetAppPrefs(ctx, appName, int(u.ID))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -49,10 +105,13 @@ func (s *postgresStore) UserPrefs(ctx context.Context, uid int32, appName string
return []byte(prefs), err return []byte(prefs), err
} }
func (s *postgresStore) SetUserPrefs(ctx context.Context, uid int32, appName string, prefs []byte) error { func (s *postgresStore) SetUserPrefs(ctx context.Context, username string, appName string, prefs []byte) error {
u, err := s.GetUser(ctx, username)
if err != nil {
return err
}
db := database.FromCtx(ctx) db := database.FromCtx(ctx)
return db.SetAppPrefs(ctx, appName, prefs, int(uid)) return db.SetAppPrefs(ctx, appName, prefs, int(u.ID))
} }
//func (s *postgresStore)

View file

@ -1,7 +1,12 @@
package users package users
import ( import (
"context"
"encoding/json" "encoding/json"
"strings"
"dynatron.me/x/stillbox/pkg/database"
"dynatron.me/x/stillbox/pkg/rbac"
) )
type UserID int type UserID int
@ -20,6 +25,38 @@ func (u UserID) Int() int {
return int(u) return int(u)
} }
func (u UserID) IsValid() bool {
return u > 0
}
func From(ctx context.Context) (*User, error) {
sub := rbac.SubjectFrom(ctx)
return FromSubject(sub)
}
func UserCheck(ctx context.Context, rsc rbac.Resource, actions string) (*User, error) {
acts := strings.Split(actions, "+")
subj, err := rbac.FromCtx(ctx).Check(ctx, rsc, rbac.WithActions(acts...))
if err != nil {
return nil, err
}
return FromSubject(subj)
}
func FromSubject(sub rbac.Subject) (*User, error) {
if sub == nil {
return nil, rbac.ErrBadSubject
}
user, isUser := sub.(*User)
if !isUser || user == nil || !user.ID.IsValid() {
return nil, rbac.ErrBadSubject
}
return user, nil
}
type User struct { type User struct {
ID UserID ID UserID
Username string Username string
@ -28,3 +65,26 @@ type User struct {
IsAdmin bool IsAdmin bool
Prefs json.RawMessage Prefs json.RawMessage
} }
func (u *User) GetRoles() []string {
r := make([]string, 1, 2)
r[0] = rbac.RoleUser
if u.IsAdmin {
r = append(r, rbac.RoleAdmin)
}
return r
}
func fromDBUser(dbu database.User) *User {
return &User{
ID: UserID(dbu.ID),
Username: dbu.Username,
Password: dbu.Password,
Email: dbu.Email,
IsAdmin: dbu.IsAdmin,
Prefs: dbu.Prefs,
}
}

View file

@ -1,14 +1,10 @@
-- name: GetUserByID :one -- name: GetUserByID :one
SELECT * FROM users SELECT * FROM users
WHERE id = $1 LIMIT 1; WHERE id = $1;
-- name: GetUserByUsername :one -- name: GetUserByUsername :one
SELECT * FROM users SELECT * FROM users
WHERE username = $1 LIMIT 1; WHERE username = $1;
-- name: GetUserByUID :one
SELECT * FROM users
WHERE id = $1 LIMIT 1;
-- name: GetUsers :many -- name: GetUsers :many
SELECT * FROM users; SELECT * FROM users;
@ -28,6 +24,14 @@ DELETE FROM users WHERE username = $1;
-- name: UpdatePassword :exec -- name: UpdatePassword :exec
UPDATE users SET password = $2 WHERE username = $1; UPDATE users SET password = $2 WHERE username = $1;
-- name: UpdateUser :one
UPDATE users SET
email = COALESCE(sqlc.narg('email'), email),
is_admin = COALESCE(sqlc.narg('is_admin'), is_admin)
WHERE
username = $1
RETURNING *;
-- name: CreateAPIKey :one -- name: CreateAPIKey :one
INSERT INTO api_keys( INSERT INTO api_keys(
owner, owner,
@ -42,7 +46,17 @@ RETURNING *;
DELETE FROM api_keys WHERE api_key = $1; DELETE FROM api_keys WHERE api_key = $1;
-- name: GetAPIKey :one -- name: GetAPIKey :one
SELECT * FROM api_keys WHERE api_key = $1; SELECT
a.id,
a.owner,
a.created_at,
a.expires,
a.disabled,
a.api_key,
u.username
FROM api_keys a
JOIN users u ON (a.owner = u.id)
WHERE api_key = $1;
-- name: GetAppPrefs :one -- name: GetAppPrefs :one
SELECT (prefs->>(@app_name::TEXT))::JSONB FROM users WHERE id = @uid; SELECT (prefs->>(@app_name::TEXT))::JSONB FROM users WHERE id = @uid;