diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 0ce70d8..a364406 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -1,7 +1,6 @@ package cache -import ( -) +import "sync" type Cache[K comparable, V any] interface { Get(K) (V, bool) @@ -10,25 +9,38 @@ type Cache[K comparable, V any] interface { Clear() } -type inMem[K comparable, V any] map[K]V - -func New[K comparable, V any]() inMem[K, V] { - return make(inMem[K, V]) +type inMem[K comparable, V any] struct { + sync.RWMutex + m map[K]V } -func (c inMem[K, V]) Get(key K) (V, bool) { - v, ok := c[key] +func New[K comparable, V any]() *inMem[K, V] { + return &inMem[K, V]{ + m: make(map[K]V), + } +} + +func (c *inMem[K, V]) Get(key K) (V, bool) { + c.RLock() + defer c.RUnlock() + v, ok := c.m[key] return v, ok } -func (c inMem[K, V]) Set(key K, val V) { - c[key] = val +func (c *inMem[K, V]) Set(key K, val V) { + c.Lock() + defer c.Unlock() + c.m[key] = val } -func (c inMem[K, V]) Delete(key K) { - delete(c, key) +func (c *inMem[K, V]) Delete(key K) { + c.Lock() + defer c.Unlock() + delete(c.m, key) } -func (c inMem[K, V]) Clear() { - clear(c) +func (c *inMem[K, V]) Clear() { + c.Lock() + defer c.Unlock() + clear(c.m) } diff --git a/pkg/auth/jwt.go b/pkg/auth/jwt.go index 289120e..5b25084 100644 --- a/pkg/auth/jwt.go +++ b/pkg/auth/jwt.go @@ -11,11 +11,13 @@ import ( "golang.org/x/crypto/bcrypt" "dynatron.me/x/stillbox/pkg/database" + "dynatron.me/x/stillbox/pkg/rbac" "dynatron.me/x/stillbox/pkg/users" "github.com/go-chi/chi/v5" "github.com/go-chi/jwtauth/v5" "github.com/go-chi/render" + "github.com/lestrrat-go/jwx/v2/jwt" "github.com/rs/zerolog/log" ) @@ -45,22 +47,16 @@ type jwtAuth interface { type claims map[string]interface{} -// TODO: change this to UserFrom() *users.User -func UIDFrom(ctx context.Context) *users.UserID { +// UsernameFrom gets the username (just the subject from token) from ctx. +func UsernameFrom(ctx context.Context) *string { tok, _, err := jwtauth.FromContext(ctx) if err != nil { return nil } - uidStr := tok.Subject() - uidInt, err := strconv.Atoi(uidStr) - if err != nil { - return nil - } + username := tok.Subject() - uid := users.UserID(int32(uidInt)) - - return &uid + return &username } func (a *Auth) Authenticated(r *http.Request) (claims, bool) { @@ -90,7 +86,38 @@ func TokenFromCookie(r *http.Request) string { } func (a *Auth) AuthMiddleware() func(http.Handler) http.Handler { - return jwtauth.Authenticator(a.jwt) + return func(next http.Handler) http.Handler { + hfn := func(w http.ResponseWriter, r *http.Request) { + token, _, err := jwtauth.FromContext(r.Context()) + + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + + if token != nil && jwt.Validate(token, a.jwt.ValidateOptions()...) == nil { + ctx := r.Context() + username := token.Subject() + + sub, err := users.FromCtx(ctx).GetUser(ctx, username) + if err != nil { + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + + ctx = rbac.CtxWithSubject(ctx, sub) + + next.ServeHTTP(w, r.WithContext(ctx)) + + return + } + + // Token is authenticated, pass it through + next.ServeHTTP(w, r) + } + return http.HandlerFunc(hfn) + } + } func (a *Auth) initJWT() { diff --git a/pkg/calls/call.go b/pkg/calls/call.go index 3d46fa6..090fa4f 100644 --- a/pkg/calls/call.go +++ b/pkg/calls/call.go @@ -8,6 +8,7 @@ import ( "dynatron.me/x/stillbox/internal/audio" "dynatron.me/x/stillbox/internal/jsontypes" "dynatron.me/x/stillbox/pkg/pb" + "dynatron.me/x/stillbox/pkg/rbac" "dynatron.me/x/stillbox/pkg/talkgroups" "dynatron.me/x/stillbox/pkg/users" @@ -73,6 +74,10 @@ type Call struct { shouldStore bool `json:"-"` } +func (c *Call) GetResourceName() string { + return rbac.ResourceCall +} + func (c *Call) String() string { return fmt.Sprintf("%s to %d from %d", c.AudioName, c.Talkgroup, c.Source) } diff --git a/pkg/database/mocks/Store.go b/pkg/database/mocks/Store.go index 86f8601..2206909 100644 --- a/pkg/database/mocks/Store.go +++ b/pkg/database/mocks/Store.go @@ -1127,22 +1127,22 @@ func (_c *Store_DropPartition_Call) RunAndReturn(run func(context.Context, strin } // GetAPIKey provides a mock function with given fields: ctx, apiKey -func (_m *Store) GetAPIKey(ctx context.Context, apiKey string) (database.ApiKey, error) { +func (_m *Store) GetAPIKey(ctx context.Context, apiKey string) (database.GetAPIKeyRow, error) { ret := _m.Called(ctx, apiKey) if len(ret) == 0 { panic("no return value specified for GetAPIKey") } - var r0 database.ApiKey + var r0 database.GetAPIKeyRow var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) (database.ApiKey, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, string) (database.GetAPIKeyRow, error)); ok { return rf(ctx, apiKey) } - if rf, ok := ret.Get(0).(func(context.Context, string) database.ApiKey); ok { + if rf, ok := ret.Get(0).(func(context.Context, string) database.GetAPIKeyRow); ok { r0 = rf(ctx, apiKey) } else { - r0 = ret.Get(0).(database.ApiKey) + r0 = ret.Get(0).(database.GetAPIKeyRow) } if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { @@ -1173,12 +1173,12 @@ func (_c *Store_GetAPIKey_Call) Run(run func(ctx context.Context, apiKey string) return _c } -func (_c *Store_GetAPIKey_Call) Return(_a0 database.ApiKey, _a1 error) *Store_GetAPIKey_Call { +func (_c *Store_GetAPIKey_Call) Return(_a0 database.GetAPIKeyRow, _a1 error) *Store_GetAPIKey_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *Store_GetAPIKey_Call) RunAndReturn(run func(context.Context, string) (database.ApiKey, error)) *Store_GetAPIKey_Call { +func (_c *Store_GetAPIKey_Call) RunAndReturn(run func(context.Context, string) (database.GetAPIKeyRow, error)) *Store_GetAPIKey_Call { _c.Call.Return(run) return _c } @@ -2584,63 +2584,6 @@ func (_c *Store_GetUserByID_Call) RunAndReturn(run func(context.Context, int) (d return _c } -// GetUserByUID provides a mock function with given fields: ctx, id -func (_m *Store) GetUserByUID(ctx context.Context, id int) (database.User, error) { - ret := _m.Called(ctx, id) - - if len(ret) == 0 { - panic("no return value specified for GetUserByUID") - } - - var r0 database.User - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, int) (database.User, error)); ok { - return rf(ctx, id) - } - if rf, ok := ret.Get(0).(func(context.Context, int) database.User); ok { - r0 = rf(ctx, id) - } else { - r0 = ret.Get(0).(database.User) - } - - if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { - r1 = rf(ctx, id) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// Store_GetUserByUID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetUserByUID' -type Store_GetUserByUID_Call struct { - *mock.Call -} - -// GetUserByUID is a helper method to define mock.On call -// - ctx context.Context -// - id int -func (_e *Store_Expecter) GetUserByUID(ctx interface{}, id interface{}) *Store_GetUserByUID_Call { - return &Store_GetUserByUID_Call{Call: _e.mock.On("GetUserByUID", ctx, id)} -} - -func (_c *Store_GetUserByUID_Call) Run(run func(ctx context.Context, id int)) *Store_GetUserByUID_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(int)) - }) - return _c -} - -func (_c *Store_GetUserByUID_Call) Return(_a0 database.User, _a1 error) *Store_GetUserByUID_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *Store_GetUserByUID_Call) RunAndReturn(run func(context.Context, int) (database.User, error)) *Store_GetUserByUID_Call { - _c.Call.Return(run) - return _c -} - // GetUserByUsername provides a mock function with given fields: ctx, username func (_m *Store) GetUserByUsername(ctx context.Context, username string) (database.User, error) { ret := _m.Called(ctx, username) @@ -3702,6 +3645,65 @@ func (_c *Store_UpdateTalkgroup_Call) RunAndReturn(run func(context.Context, dat return _c } +// UpdateUser provides a mock function with given fields: ctx, username, email, isAdmin +func (_m *Store) UpdateUser(ctx context.Context, username string, email *string, isAdmin *bool) (database.User, error) { + ret := _m.Called(ctx, username, email, isAdmin) + + if len(ret) == 0 { + panic("no return value specified for UpdateUser") + } + + var r0 database.User + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, *string, *bool) (database.User, error)); ok { + return rf(ctx, username, email, isAdmin) + } + if rf, ok := ret.Get(0).(func(context.Context, string, *string, *bool) database.User); ok { + r0 = rf(ctx, username, email, isAdmin) + } else { + r0 = ret.Get(0).(database.User) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, *string, *bool) error); ok { + r1 = rf(ctx, username, email, isAdmin) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Store_UpdateUser_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateUser' +type Store_UpdateUser_Call struct { + *mock.Call +} + +// UpdateUser is a helper method to define mock.On call +// - ctx context.Context +// - username string +// - email *string +// - isAdmin *bool +func (_e *Store_Expecter) UpdateUser(ctx interface{}, username interface{}, email interface{}, isAdmin interface{}) *Store_UpdateUser_Call { + return &Store_UpdateUser_Call{Call: _e.mock.On("UpdateUser", ctx, username, email, isAdmin)} +} + +func (_c *Store_UpdateUser_Call) Run(run func(ctx context.Context, username string, email *string, isAdmin *bool)) *Store_UpdateUser_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(*string), args[3].(*bool)) + }) + return _c +} + +func (_c *Store_UpdateUser_Call) Return(_a0 database.User, _a1 error) *Store_UpdateUser_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Store_UpdateUser_Call) RunAndReturn(run func(context.Context, string, *string, *bool) (database.User, error)) *Store_UpdateUser_Call { + _c.Call.Return(run) + return _c +} + // UpsertTalkgroup provides a mock function with given fields: ctx, arg func (_m *Store) UpsertTalkgroup(ctx context.Context, arg []database.UpsertTalkgroupParams) *database.UpsertTalkgroupBatchResults { ret := _m.Called(ctx, arg) diff --git a/pkg/database/querier.go b/pkg/database/querier.go index eac7252..bf2194f 100644 --- a/pkg/database/querier.go +++ b/pkg/database/querier.go @@ -28,7 +28,7 @@ type Querier interface { DeleteSystem(ctx context.Context, id int) error DeleteTalkgroup(ctx context.Context, systemID int32, tGID int32) error DeleteUser(ctx context.Context, username string) error - GetAPIKey(ctx context.Context, apiKey string) (ApiKey, error) + GetAPIKey(ctx context.Context, apiKey string) (GetAPIKeyRow, error) GetAllTalkgroupTags(ctx context.Context) ([]string, error) GetAppPrefs(ctx context.Context, appName string, uid int) ([]byte, error) GetCallAudioByID(ctx context.Context, id uuid.UUID) (GetCallAudioByIDRow, error) @@ -50,7 +50,6 @@ type Querier interface { GetTalkgroupsWithLearnedCount(ctx context.Context, filter *string) (int64, error) GetTalkgroupsWithLearnedP(ctx context.Context, arg GetTalkgroupsWithLearnedPParams) ([]GetTalkgroupsWithLearnedPRow, error) GetUserByID(ctx context.Context, id int) (User, error) - GetUserByUID(ctx context.Context, id int) (User, error) GetUserByUsername(ctx context.Context, username string) (User, error) GetUsers(ctx context.Context) ([]User, error) ListCallsCount(ctx context.Context, arg ListCallsCountParams) (int64, error) @@ -71,6 +70,7 @@ type Querier interface { UpdateIncident(ctx context.Context, arg UpdateIncidentParams) (Incident, error) UpdatePassword(ctx context.Context, username string, password string) error UpdateTalkgroup(ctx context.Context, arg UpdateTalkgroupParams) (Talkgroup, error) + UpdateUser(ctx context.Context, username string, email *string, isAdmin *bool) (User, error) UpsertTalkgroup(ctx context.Context, arg []UpsertTalkgroupParams) *UpsertTalkgroupBatchResults } diff --git a/pkg/database/users.sql.go b/pkg/database/users.sql.go index 94d144f..74f2b43 100644 --- a/pkg/database/users.sql.go +++ b/pkg/database/users.sql.go @@ -7,6 +7,7 @@ package database import ( "context" + "time" "github.com/jackc/pgx/v5/pgtype" ) @@ -91,12 +92,32 @@ func (q *Queries) DeleteUser(ctx context.Context, username string) error { } const getAPIKey = `-- name: GetAPIKey :one -SELECT id, owner, created_at, expires, disabled, api_key FROM api_keys WHERE api_key = $1 +SELECT + a.id, + a.owner, + a.created_at, + a.expires, + a.disabled, + a.api_key, + u.username +FROM api_keys a +JOIN users u ON (a.owner = u.id) +WHERE api_key = $1 ` -func (q *Queries) GetAPIKey(ctx context.Context, apiKey string) (ApiKey, error) { +type GetAPIKeyRow struct { + ID int `json:"id"` + Owner int `json:"owner"` + CreatedAt time.Time `json:"created_at"` + Expires pgtype.Timestamp `json:"expires"` + Disabled *bool `json:"disabled"` + ApiKey string `json:"api_key"` + Username string `json:"username"` +} + +func (q *Queries) GetAPIKey(ctx context.Context, apiKey string) (GetAPIKeyRow, error) { row := q.db.QueryRow(ctx, getAPIKey, apiKey) - var i ApiKey + var i GetAPIKeyRow err := row.Scan( &i.ID, &i.Owner, @@ -104,6 +125,7 @@ func (q *Queries) GetAPIKey(ctx context.Context, apiKey string) (ApiKey, error) &i.Expires, &i.Disabled, &i.ApiKey, + &i.Username, ) return i, err } @@ -121,7 +143,7 @@ func (q *Queries) GetAppPrefs(ctx context.Context, appName string, uid int) ([]b const getUserByID = `-- name: GetUserByID :one SELECT id, username, password, email, is_admin, prefs FROM users -WHERE id = $1 LIMIT 1 +WHERE id = $1 ` func (q *Queries) GetUserByID(ctx context.Context, id int) (User, error) { @@ -138,28 +160,9 @@ func (q *Queries) GetUserByID(ctx context.Context, id int) (User, error) { return i, err } -const getUserByUID = `-- name: GetUserByUID :one -SELECT id, username, password, email, is_admin, prefs FROM users -WHERE id = $1 LIMIT 1 -` - -func (q *Queries) GetUserByUID(ctx context.Context, id int) (User, error) { - row := q.db.QueryRow(ctx, getUserByUID, id) - var i User - err := row.Scan( - &i.ID, - &i.Username, - &i.Password, - &i.Email, - &i.IsAdmin, - &i.Prefs, - ) - return i, err -} - const getUserByUsername = `-- name: GetUserByUsername :one SELECT id, username, password, email, is_admin, prefs FROM users -WHERE username = $1 LIMIT 1 +WHERE username = $1 ` func (q *Queries) GetUserByUsername(ctx context.Context, username string) (User, error) { @@ -224,3 +227,26 @@ func (q *Queries) UpdatePassword(ctx context.Context, username string, password _, err := q.db.Exec(ctx, updatePassword, username, password) return err } + +const updateUser = `-- name: UpdateUser :one +UPDATE users SET + email = COALESCE($2, email), + is_admin = COALESCE($3, is_admin) +WHERE + username = $1 +RETURNING id, username, password, email, is_admin, prefs +` + +func (q *Queries) UpdateUser(ctx context.Context, username string, email *string, isAdmin *bool) (User, error) { + row := q.db.QueryRow(ctx, updateUser, username, email, isAdmin) + var i User + err := row.Scan( + &i.ID, + &i.Username, + &i.Password, + &i.Email, + &i.IsAdmin, + &i.Prefs, + ) + return i, err +} diff --git a/pkg/incidents/incident.go b/pkg/incidents/incident.go index c2ee068..b48f152 100644 --- a/pkg/incidents/incident.go +++ b/pkg/incidents/incident.go @@ -5,11 +5,14 @@ import ( "dynatron.me/x/stillbox/internal/jsontypes" "dynatron.me/x/stillbox/pkg/calls" + "dynatron.me/x/stillbox/pkg/rbac" + "dynatron.me/x/stillbox/pkg/users" "github.com/google/uuid" ) type Incident struct { ID uuid.UUID `json:"id"` + Owner users.UserID `json:"owner"` Name string `json:"name"` Description *string `json:"description"` StartTime *jsontypes.Time `json:"startTime"` @@ -19,6 +22,10 @@ type Incident struct { Calls []IncidentCall `json:"calls"` } +func (inc *Incident) GetResourceName() string { + return rbac.ResourceIncident +} + type IncidentCall struct { calls.Call Notes json.RawMessage `json:"notes"` diff --git a/pkg/incidents/incstore/store.go b/pkg/incidents/incstore/store.go index 019f59c..c7c8bfa 100644 --- a/pkg/incidents/incstore/store.go +++ b/pkg/incidents/incstore/store.go @@ -6,11 +6,9 @@ import ( "dynatron.me/x/stillbox/internal/common" "dynatron.me/x/stillbox/internal/jsontypes" - "dynatron.me/x/stillbox/pkg/auth" "dynatron.me/x/stillbox/pkg/calls" "dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/incidents" - "dynatron.me/x/stillbox/pkg/rbac" "dynatron.me/x/stillbox/pkg/users" "github.com/google/uuid" "github.com/jackc/pgx/v5" @@ -74,22 +72,22 @@ func NewStore() Store { } func (s *store) CreateIncident(ctx context.Context, inc incidents.Incident) (*incidents.Incident, error) { + // TODO: replace this with a real RBAC check + user, err := users.UserCheck(ctx, new(incidents.Incident), "create") + if err != nil { + return nil, err + } + db := database.FromCtx(ctx) var dbInc database.Incident - // TODO: replace this with a real RBAC check - owner := auth.UIDFrom(ctx) - if owner == nil { - return nil, rbac.ErrNotAuthorized - } - id := uuid.New() txErr := db.InTx(ctx, func(db database.Store) error { var err error dbInc, err = db.CreateIncident(ctx, database.CreateIncidentParams{ ID: id, - Owner: owner.Int(), + Owner: user.ID.Int(), Name: inc.Name, Description: inc.Description, StartTime: inc.StartTime.PGTypeTSTZ(), @@ -205,6 +203,7 @@ func (s *store) Incidents(ctx context.Context, p IncidentsParams) (incs []Incide func fromDBIncident(id uuid.UUID, d database.Incident) incidents.Incident { return incidents.Incident{ ID: id, + Owner: users.UserID(d.Owner), Name: d.Name, Description: d.Description, StartTime: jsontypes.TimePtrFromTSTZ(d.StartTime), @@ -223,6 +222,7 @@ func fromDBListInPRow(id uuid.UUID, d database.ListIncidentsPRow) Incident { return Incident{ Incident: incidents.Incident{ ID: id, + Owner: users.UserID(d.Owner), Name: d.Name, Description: d.Description, StartTime: jsontypes.TimePtrFromTSTZ(d.StartTime), diff --git a/pkg/rbac/rbac.go b/pkg/rbac/rbac.go index 5c75598..0b5ede1 100644 --- a/pkg/rbac/rbac.go +++ b/pkg/rbac/rbac.go @@ -1,26 +1,292 @@ package rbac import ( + "context" "errors" "github.com/el-mike/restrict/v2" + "github.com/el-mike/restrict/v2/adapters" ) +const ( + RoleUser = "User" + RoleSubmitter = "Submitter" + RoleAdmin = "Admin" + RolePublic = "Public" + RoleShareGuest = "ShareGuest" + + ResourceCall = "Call" + ResourceIncident = "Incident" + ResourceTalkgroup = "Talkgroup" + ResourceAlert = "Alert" + ResourceShare = "Share" + + ActionRead = "read" + ActionCreate = "create" + ActionUpdate = "update" + ActionDelete = "delete" + + PresetUpdateOwn = "updateOwn" + PresetDeleteOwn = "deleteOwn" + PresetReadShared = "readShared" +) + +var ( + ErrBadSubject = errors.New("bad subject in token") +) + +type subjectContextKey string + +const SubjectCtxKey subjectContextKey = "sub" + +func CtxWithSubject(ctx context.Context, sub Subject) context.Context { + return context.WithValue(ctx, SubjectCtxKey, sub) +} + +func SubjectFrom(ctx context.Context) Subject { + sub, ok := ctx.Value(SubjectCtxKey).(Subject) + if ok { + return sub + } + + return new(PublicSubject) +} + +type rbacCtxKey string + +const RBACCtxKey rbacCtxKey = "rbac" + +func FromCtx(ctx context.Context) RBAC { + rbac, ok := ctx.Value(RBACCtxKey).(RBAC) + if !ok { + panic("no RBAC in context") + } + + return rbac +} + +func CtxWithRBAC(ctx context.Context, rbac RBAC) context.Context { + return context.WithValue(ctx, RBACCtxKey, rbac) +} + var ( ErrNotAuthorized = errors.New("not authorized") ) var policy = &restrict.PolicyDefinition{ Roles: restrict.Roles{ - "User": { + RoleUser: { + Description: "An authenticated user", Grants: restrict.GrantsMap{ - "Conversation": { - &restrict.Permission{Action: "read"}, - &restrict.Permission{Action: "create"}, + ResourceIncident: { + &restrict.Permission{Action: ActionRead}, + &restrict.Permission{Action: ActionCreate}, + &restrict.Permission{Preset: PresetUpdateOwn}, + &restrict.Permission{Preset: PresetDeleteOwn}, + }, + ResourceCall: { + &restrict.Permission{Action: ActionRead}, + &restrict.Permission{Action: ActionCreate}, + &restrict.Permission{Preset: PresetUpdateOwn}, + &restrict.Permission{Preset: PresetDeleteOwn}, + }, + ResourceTalkgroup: { + &restrict.Permission{Action: ActionRead}, + }, + ResourceShare: { + &restrict.Permission{Action: ActionRead}, + &restrict.Permission{Action: ActionCreate}, + &restrict.Permission{Preset: PresetUpdateOwn}, + &restrict.Permission{Preset: PresetDeleteOwn}, + }, + }, + }, + RoleSubmitter: { + Description: "A role that can submit calls", + Grants: restrict.GrantsMap{ + ResourceCall: { + &restrict.Permission{Action: ActionCreate}, + }, + }, + }, + RoleShareGuest: { + Description: "Someone who has a valid share link", + Grants: restrict.GrantsMap{ + ResourceCall: { + &restrict.Permission{Preset: PresetReadShared}, + }, + ResourceIncident: { + &restrict.Permission{Preset: PresetReadShared}, + }, + ResourceTalkgroup: { + &restrict.Permission{Action: ActionRead}, + }, + }, + }, + RoleAdmin: { + Parents: []string{RoleUser}, + Grants: restrict.GrantsMap{ + ResourceIncident: { + &restrict.Permission{Action: ActionUpdate}, + &restrict.Permission{Action: ActionDelete}, + }, + ResourceCall: { + &restrict.Permission{Action: ActionUpdate}, + &restrict.Permission{Action: ActionDelete}, + }, + ResourceTalkgroup: { + &restrict.Permission{Action: ActionUpdate}, + &restrict.Permission{Action: ActionCreate}, + &restrict.Permission{Action: ActionDelete}, + }, + }, + }, + RolePublic: { + /* + Grants: restrict.GrantsMap{ + ResourceShare: { + &restrict.Permission{Action: ActionRead}, + }, + }, + */ + }, + }, + PermissionPresets: restrict.PermissionPresets{ + PresetUpdateOwn: &restrict.Permission{ + Action: ActionUpdate, + Conditions: restrict.Conditions{ + &restrict.EqualCondition{ + ID: "isOwner", + Left: &restrict.ValueDescriptor{ + Source: restrict.ResourceField, + Field: "Owner", + }, + Right: &restrict.ValueDescriptor{ + Source: restrict.SubjectField, + Field: "ID", + }, + }, + }, + }, + PresetDeleteOwn: &restrict.Permission{ + Action: ActionDelete, + Conditions: restrict.Conditions{ + &restrict.EqualCondition{ + ID: "isOwner", + Left: &restrict.ValueDescriptor{ + Source: restrict.ResourceField, + Field: "Owner", + }, + Right: &restrict.ValueDescriptor{ + Source: restrict.SubjectField, + Field: "ID", + }, + }, + }, + }, + PresetReadShared: &restrict.Permission{ + Action: ActionRead, + Conditions: restrict.Conditions{ + &restrict.EqualCondition{ + ID: "isOwner", + Left: &restrict.ValueDescriptor{ + Source: restrict.ContextField, + Field: "Owner", + }, + Right: &restrict.ValueDescriptor{ + Source: restrict.SubjectField, + Field: "ID", + }, }, }, }, - "Guest": {}, - "Admin": {}, }, } + +type checkOptions struct { + actions []string + context restrict.Context +} + +type checkOption func(*checkOptions) + +func WithActions(actions ...string) checkOption { + return func(o *checkOptions) { + o.actions = append(o.actions, actions...) + } +} + +func WithContext(ctx restrict.Context) checkOption { + return func(o *checkOptions) { + o.context = ctx + } +} + +func UseResource(rsc string) restrict.Resource { + return restrict.UseResource(rsc) +} + +type Subject interface { + restrict.Subject +} + +type Resource interface { + restrict.Resource +} + +type RBAC interface { + Check(ctx context.Context, res restrict.Resource, opts ...checkOption) (Subject, error) +} + +type rbac struct { + policy *restrict.PolicyManager + access *restrict.AccessManager +} + +func New() (*rbac, error) { + adapter := adapters.NewInMemoryAdapter(policy) + polMan, err := restrict.NewPolicyManager(adapter, true) + if err != nil { + return nil, err + } + + accMan := restrict.NewAccessManager(polMan) + return &rbac{ + policy: polMan, + access: accMan, + }, nil +} + +func (r *rbac) Check(ctx context.Context, res restrict.Resource, opts ...checkOption) (Subject, error) { + sub := SubjectFrom(ctx) + o := checkOptions{} + + for _, opt := range opts { + opt(&o) + } + + req := &restrict.AccessRequest{ + Subject: sub, + Resource: res, + Actions: o.actions, + Context: o.context, + } + + return sub, r.access.Authorize(req) +} + +type ShareLinkGuest struct { + ShareID string +} + +func (s *ShareLinkGuest) GetRoles() []string { + return []string{RoleShareGuest} +} + +type PublicSubject struct { + RemoteAddr string +} + +func (s *PublicSubject) GetRoles() []string { + return []string{RolePublic} +} diff --git a/pkg/rest/api.go b/pkg/rest/api.go index fbd2bae..5920133 100644 --- a/pkg/rest/api.go +++ b/pkg/rest/api.go @@ -129,7 +129,7 @@ var statusMapping = map[error]errResponder{ ErrTGIDMismatch: badRequestErrText, ErrSysMismatch: badRequestErrText, tgstore.ErrReference: constraintErrText, - ErrBadUID: unauthErrText, + rbac.ErrBadSubject: unauthErrText, ErrBadAppName: unauthErrText, common.ErrPageOutOfRange: badRequestErrText, rbac.ErrNotAuthorized: unauthErrText, diff --git a/pkg/rest/users.go b/pkg/rest/users.go index 704088b..1fa6703 100644 --- a/pkg/rest/users.go +++ b/pkg/rest/users.go @@ -7,13 +7,13 @@ import ( "strings" "dynatron.me/x/stillbox/pkg/auth" + "dynatron.me/x/stillbox/pkg/rbac" "dynatron.me/x/stillbox/pkg/users" "github.com/go-chi/chi/v5" ) var ( - ErrBadUID = errors.New("bad UID in token") ErrBadAppName = errors.New("bad app name") ) @@ -32,10 +32,10 @@ func (ua *usersAPI) Subrouter() http.Handler { func (ua *usersAPI) getPrefs(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - uid := auth.UIDFrom(ctx) + username := auth.UsernameFrom(ctx) - if uid == nil { - wErr(w, r, autoError(ErrBadUID)) + if username == nil { + wErr(w, r, autoError(rbac.ErrBadSubject)) return } @@ -55,7 +55,7 @@ func (ua *usersAPI) getPrefs(w http.ResponseWriter, r *http.Request) { } us := users.FromCtx(ctx) - prefs, err := us.UserPrefs(ctx, *uid, *p.AppName) + prefs, err := us.UserPrefs(ctx, *username, *p.AppName) if err != nil { wErr(w, r, autoError(err)) return @@ -67,10 +67,10 @@ func (ua *usersAPI) getPrefs(w http.ResponseWriter, r *http.Request) { func (ua *usersAPI) putPrefs(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - uid := auth.UIDFrom(ctx) + username := auth.UsernameFrom(ctx) - if uid == nil { - wErr(w, r, autoError(ErrBadUID)) + if username == nil { + wErr(w, r, autoError(rbac.ErrBadSubject)) return } @@ -102,7 +102,7 @@ func (ua *usersAPI) putPrefs(w http.ResponseWriter, r *http.Request) { } us := users.FromCtx(ctx) - err = us.SetUserPrefs(ctx, *uid, *p.AppName, prefs) + err = us.SetUserPrefs(ctx, *username, *p.AppName, prefs) if err != nil { wErr(w, r, autoError(err)) return diff --git a/pkg/server/routes.go b/pkg/server/routes.go index 2300c04..cb5f26e 100644 --- a/pkg/server/routes.go +++ b/pkg/server/routes.go @@ -68,7 +68,7 @@ func (s *Server) setupRoutes() { func (s *Server) WithCtxStores() func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { - r = r.WithContext(s.addStoresTo(r.Context())) + r = r.WithContext(s.fillCtx(r.Context())) next.ServeHTTP(w, r) } return http.HandlerFunc(fn) diff --git a/pkg/server/server.go b/pkg/server/server.go index 963151e..15f11d4 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -15,6 +15,7 @@ import ( "dynatron.me/x/stillbox/pkg/incidents/incstore" "dynatron.me/x/stillbox/pkg/nexus" "dynatron.me/x/stillbox/pkg/notify" + "dynatron.me/x/stillbox/pkg/rbac" "dynatron.me/x/stillbox/pkg/rest" "dynatron.me/x/stillbox/pkg/share" "dynatron.me/x/stillbox/pkg/sinks" @@ -50,6 +51,7 @@ type Server struct { calls callstore.Store incidents incstore.Store share share.Service + rbac rbac.RBAC } func New(ctx context.Context, cfg *config.Configuration) (*Server, error) { @@ -75,6 +77,11 @@ func New(ctx context.Context, cfg *config.Configuration) (*Server, error) { tgCache := tgstore.NewCache() api := rest.New(cfg.BaseURL.URL()) + rbacSvc, err := rbac.New() + if err != nil { + return nil, err + } + srv := &Server{ auth: authenticator, conf: cfg, @@ -91,6 +98,7 @@ func New(ctx context.Context, cfg *config.Configuration) (*Server, error) { users: users.NewStore(), calls: callstore.NewStore(), incidents: incstore.NewStore(), + rbac: rbacSvc, } if cfg.DB.Partition.Enabled { @@ -138,13 +146,14 @@ func New(ctx context.Context, cfg *config.Configuration) (*Server, error) { return srv, nil } -func (s *Server) addStoresTo(ctx context.Context) context.Context { +func (s *Server) fillCtx(ctx context.Context) context.Context { ctx = database.CtxWithDB(ctx, s.db) ctx = tgstore.CtxWithStore(ctx, s.tgs) ctx = users.CtxWithStore(ctx, s.users) ctx = callstore.CtxWithStore(ctx, s.calls) ctx = incstore.CtxWithStore(ctx, s.incidents) - ctx = share.CtxWithStore(ctx, s.share.Store()) + ctx = share.CtxWithStore(ctx, s.share.ShareStore()) + ctx = rbac.CtxWithRBAC(ctx, s.rbac) return ctx } @@ -154,7 +163,7 @@ func (s *Server) Go(ctx context.Context) error { s.installHupHandler() - ctx = s.addStoresTo(ctx) + ctx = s.fillCtx(ctx) httpSrv := &http.Server{ Addr: s.conf.Listen, diff --git a/pkg/share/service.go b/pkg/share/service.go index b4a8c87..63d3141 100644 --- a/pkg/share/service.go +++ b/pkg/share/service.go @@ -12,17 +12,17 @@ const ( ) type Service interface { - Store() Store + ShareStore() Store Go(ctx context.Context) } type service struct { - store Store + Store } -func (s *service) Store() Store { - return s.store +func (s *service) ShareStore() Store { + return s.Store } func (s *service) Go(ctx context.Context) { @@ -31,7 +31,7 @@ func (s *service) Go(ctx context.Context) { for { select { case <-tick.C: - err := s.store.Prune(ctx) + err := s.Prune(ctx) if err != nil { log.Error().Err(err).Msg("share prune failed") } @@ -44,6 +44,6 @@ func (s *service) Go(ctx context.Context) { func NewService() *service { return &service{ - store: NewStore(), + Store: NewStore(), } } diff --git a/pkg/talkgroups/talkgroup.go b/pkg/talkgroups/talkgroup.go index ded46fc..7965e98 100644 --- a/pkg/talkgroups/talkgroup.go +++ b/pkg/talkgroups/talkgroup.go @@ -9,6 +9,7 @@ import ( "strings" "dynatron.me/x/stillbox/pkg/database" + "dynatron.me/x/stillbox/pkg/rbac" ) type Talkgroup struct { @@ -17,6 +18,10 @@ type Talkgroup struct { Learned bool `json:"learned"` } +func (t *Talkgroup) GetResourceName() string { + return rbac.ResourceTalkgroup +} + func (t Talkgroup) String() string { if t.System.Name == "" { t.System.Name = strconv.Itoa(int(t.Talkgroup.TGID)) diff --git a/pkg/talkgroups/tgstore/store.go b/pkg/talkgroups/tgstore/store.go index 896cb3f..f92d3af 100644 --- a/pkg/talkgroups/tgstore/store.go +++ b/pkg/talkgroups/tgstore/store.go @@ -8,11 +8,11 @@ import ( "time" "dynatron.me/x/stillbox/internal/common" - "dynatron.me/x/stillbox/pkg/auth" "dynatron.me/x/stillbox/pkg/calls" "dynatron.me/x/stillbox/pkg/config" "dynatron.me/x/stillbox/pkg/database" tgsp "dynatron.me/x/stillbox/pkg/talkgroups" + "dynatron.me/x/stillbox/pkg/users" "github.com/jackc/pgx/v5" "github.com/rs/zerolog/log" @@ -515,20 +515,26 @@ func (t *cache) SystemName(ctx context.Context, id int) (name string, has bool) } func (t *cache) UpdateTG(ctx context.Context, input database.UpdateTalkgroupParams) (*tgsp.Talkgroup, error) { + user, err := users.UserCheck(ctx, new(tgsp.Talkgroup), "update") + if err != nil { + return nil, err + } + sysName, has := t.SystemName(ctx, int(*input.SystemID)) if !has { return nil, ErrNoSuchSystem } + db := database.FromCtx(ctx) var tg database.Talkgroup - err := db.InTx(ctx, func(db database.Store) error { + err = db.InTx(ctx, func(db database.Store) error { var oerr error tg, oerr = db.UpdateTalkgroup(ctx, input) if oerr != nil { return oerr } versionBatch := db.StoreTGVersion(ctx, []database.StoreTGVersionParams{{ - Submitter: auth.UIDFrom(ctx).Int32Ptr(), + Submitter: user.ID.Int32Ptr(), TGID: *input.TGID, }}) defer versionBatch.Close() @@ -577,8 +583,13 @@ func (t *cache) DeleteTG(ctx context.Context, id tgsp.ID) error { t.Lock() defer t.Unlock() - err := database.FromCtx(ctx).InTx(ctx, func(db database.Store) error { - err := db.StoreDeletedTGVersion(ctx, common.PtrTo(int32(id.System)), common.PtrTo(int32(id.Talkgroup)), auth.UIDFrom(ctx).Int32Ptr()) + user, err := users.UserCheck(ctx, new(tgsp.Talkgroup), "update") + if err != nil { + return err + } + + err = database.FromCtx(ctx).InTx(ctx, func(db database.Store) error { + err := db.StoreDeletedTGVersion(ctx, common.PtrTo(int32(id.System)), common.PtrTo(int32(id.Talkgroup)), user.ID.Int32Ptr()) if err != nil { return err } @@ -633,6 +644,11 @@ func (t *cache) LearnTG(ctx context.Context, c *calls.Call) (*tgsp.Talkgroup, er } func (t *cache) UpsertTGs(ctx context.Context, system int, input []database.UpsertTalkgroupParams) ([]*tgsp.Talkgroup, error) { + user, err := users.UserCheck(ctx, new(tgsp.Talkgroup), "create+update") + if err != nil { + return nil, err + } + db := database.FromCtx(ctx) sysName, hasSys := t.SystemName(ctx, system) if !hasSys { @@ -645,7 +661,7 @@ func (t *cache) UpsertTGs(ctx context.Context, system int, input []database.Upse tgs := make([]*tgsp.Talkgroup, 0, len(input)) - err := db.InTx(ctx, func(db database.Store) error { + err = db.InTx(ctx, func(db database.Store) error { versionParams := make([]database.StoreTGVersionParams, 0, len(input)) for i := range input { // normalize tags @@ -670,7 +686,7 @@ func (t *cache) UpsertTGs(ctx context.Context, system int, input []database.Upse versionParams = append(versionParams, database.StoreTGVersionParams{ SystemID: int32(system), TGID: r.TGID, - Submitter: auth.UIDFrom(ctx).Int32Ptr(), + Submitter: user.ID.Int32Ptr(), }) tgs = append(tgs, &tgsp.Talkgroup{ Talkgroup: r, diff --git a/pkg/users/guest.go b/pkg/users/guest.go new file mode 100644 index 0000000..c38d2dc --- /dev/null +++ b/pkg/users/guest.go @@ -0,0 +1,21 @@ +package users + +import ( + "dynatron.me/x/stillbox/pkg/rbac" +) + +type ShareLinkGuest struct { + ShareID string +} + +func (s *ShareLinkGuest) GetRoles() []string { + return []string{rbac.RoleShareGuest} +} + +type Public struct { + RemoteAddr string +} + +func (s *Public) GetRoles() []string { + return []string{rbac.RolePublic} +} diff --git a/pkg/users/store.go b/pkg/users/store.go index bfdfde1..161cd31 100644 --- a/pkg/users/store.go +++ b/pkg/users/store.go @@ -3,22 +3,35 @@ package users import ( "context" + "dynatron.me/x/stillbox/internal/cache" "dynatron.me/x/stillbox/pkg/database" ) type Store interface { + // GetUser gets a user by UID. + GetUser(ctx context.Context, username string) (*User, error) + // UserPrefs gets the preferences for the specified user and app name. - UserPrefs(ctx context.Context, uid int32, appName string) ([]byte, error) + UserPrefs(ctx context.Context, username string, appName string) ([]byte, error) // SetUserPrefs sets the preferences for the specified user and app name. - SetUserPrefs(ctx context.Context, uid int32, appName string, prefs []byte) error + SetUserPrefs(ctx context.Context, username string, appName string, prefs []byte) error + + // Invalidate clears the user cache. + Invalidate() + + // UpdateUser updates a user's record + UpdateUser(ctx context.Context, username string, user UserUpdate) error } type postgresStore struct { + cache.Cache[string, *User] } func NewStore() *postgresStore { - return new(postgresStore) + return &postgresStore{ + Cache: cache.New[string, *User](), + } } type storeCtxKey string @@ -38,10 +51,53 @@ func FromCtx(ctx context.Context) Store { return s } -func (s *postgresStore) UserPrefs(ctx context.Context, uid int32, appName string) ([]byte, error) { +func (s *postgresStore) Invalidate() { + s.Clear() +} + +type UserUpdate struct { + Email *string `json:"email"` + IsAdmin *bool `json:"isAdmin"` +} + +func (s *postgresStore) UpdateUser(ctx context.Context, username string, user UserUpdate) error { + dbu, err := database.FromCtx(ctx).UpdateUser(ctx, username, user.Email, user.IsAdmin) + if err != nil { + return err + } + + s.Set(username, fromDBUser(dbu)) + + return nil +} + +func (s *postgresStore) GetUser(ctx context.Context, username string) (*User, error) { + u, has := s.Get(username) + if has { + return u, nil + } + + db := database.FromCtx(ctx) + dbu, err := db.GetUserByUsername(ctx, username) + if err != nil { + return nil, err + } + + u = fromDBUser(dbu) + s.Set(username, u) + + return u, nil +} + +func (s *postgresStore) UserPrefs(ctx context.Context, username string, appName string) ([]byte, error) { + u, err := s.GetUser(ctx, username) + if err != nil { + return nil, err + } + db := database.FromCtx(ctx) - prefs, err := db.GetAppPrefs(ctx, appName, int(uid)) + prefs, err := db.GetAppPrefs(ctx, appName, int(u.ID)) if err != nil { return nil, err } @@ -49,10 +105,13 @@ func (s *postgresStore) UserPrefs(ctx context.Context, uid int32, appName string return []byte(prefs), err } -func (s *postgresStore) SetUserPrefs(ctx context.Context, uid int32, appName string, prefs []byte) error { +func (s *postgresStore) SetUserPrefs(ctx context.Context, username string, appName string, prefs []byte) error { + u, err := s.GetUser(ctx, username) + if err != nil { + return err + } + db := database.FromCtx(ctx) - return db.SetAppPrefs(ctx, appName, prefs, int(uid)) + return db.SetAppPrefs(ctx, appName, prefs, int(u.ID)) } - -//func (s *postgresStore) diff --git a/pkg/users/user.go b/pkg/users/user.go index 82c1fb3..9a6dd8a 100644 --- a/pkg/users/user.go +++ b/pkg/users/user.go @@ -1,7 +1,12 @@ package users import ( + "context" "encoding/json" + "strings" + + "dynatron.me/x/stillbox/pkg/database" + "dynatron.me/x/stillbox/pkg/rbac" ) type UserID int @@ -20,6 +25,38 @@ func (u UserID) Int() int { return int(u) } +func (u UserID) IsValid() bool { + return u > 0 +} + +func From(ctx context.Context) (*User, error) { + sub := rbac.SubjectFrom(ctx) + return FromSubject(sub) +} + +func UserCheck(ctx context.Context, rsc rbac.Resource, actions string) (*User, error) { + acts := strings.Split(actions, "+") + subj, err := rbac.FromCtx(ctx).Check(ctx, rsc, rbac.WithActions(acts...)) + if err != nil { + return nil, err + } + + return FromSubject(subj) +} + +func FromSubject(sub rbac.Subject) (*User, error) { + if sub == nil { + return nil, rbac.ErrBadSubject + } + + user, isUser := sub.(*User) + if !isUser || user == nil || !user.ID.IsValid() { + return nil, rbac.ErrBadSubject + } + + return user, nil +} + type User struct { ID UserID Username string @@ -28,3 +65,26 @@ type User struct { IsAdmin bool Prefs json.RawMessage } + +func (u *User) GetRoles() []string { + r := make([]string, 1, 2) + + r[0] = rbac.RoleUser + + if u.IsAdmin { + r = append(r, rbac.RoleAdmin) + } + + return r +} + +func fromDBUser(dbu database.User) *User { + return &User{ + ID: UserID(dbu.ID), + Username: dbu.Username, + Password: dbu.Password, + Email: dbu.Email, + IsAdmin: dbu.IsAdmin, + Prefs: dbu.Prefs, + } +} diff --git a/sql/postgres/queries/users.sql b/sql/postgres/queries/users.sql index 36924b9..dfce324 100644 --- a/sql/postgres/queries/users.sql +++ b/sql/postgres/queries/users.sql @@ -1,14 +1,10 @@ -- name: GetUserByID :one SELECT * FROM users -WHERE id = $1 LIMIT 1; +WHERE id = $1; -- name: GetUserByUsername :one SELECT * FROM users -WHERE username = $1 LIMIT 1; - --- name: GetUserByUID :one -SELECT * FROM users -WHERE id = $1 LIMIT 1; +WHERE username = $1; -- name: GetUsers :many SELECT * FROM users; @@ -28,6 +24,14 @@ DELETE FROM users WHERE username = $1; -- name: UpdatePassword :exec UPDATE users SET password = $2 WHERE username = $1; +-- name: UpdateUser :one +UPDATE users SET + email = COALESCE(sqlc.narg('email'), email), + is_admin = COALESCE(sqlc.narg('is_admin'), is_admin) +WHERE + username = $1 +RETURNING *; + -- name: CreateAPIKey :one INSERT INTO api_keys( owner, @@ -42,7 +46,17 @@ RETURNING *; DELETE FROM api_keys WHERE api_key = $1; -- name: GetAPIKey :one -SELECT * FROM api_keys WHERE api_key = $1; +SELECT + a.id, + a.owner, + a.created_at, + a.expires, + a.disabled, + a.api_key, + u.username +FROM api_keys a +JOIN users u ON (a.owner = u.id) +WHERE api_key = $1; -- name: GetAppPrefs :one SELECT (prefs->>(@app_name::TEXT))::JSONB FROM users WHERE id = @uid;