From c8c3023afa641ed12e7e12f2b93451b4b1beb18a Mon Sep 17 00:00:00 2001 From: Daniel Ponte Date: Sat, 18 Jan 2025 15:18:40 -0500 Subject: [PATCH] RBAC --- .mockery.yaml | 6 +- internal/forms/marshal_test.go | 1 - pkg/alerting/alerting.go | 3 + pkg/alerting/simulate.go | 4 +- pkg/auth/apikey.go | 13 +- pkg/auth/auth.go | 8 +- pkg/auth/jwt.go | 20 +- pkg/calls/callstore/store.go | 112 +++++++++- pkg/database/calls.sql.go | 20 ++ pkg/database/incidents.sql.go | 11 + pkg/database/mocks/Store.go | 163 +++++++++++++++ pkg/database/partman/partman.go | 2 + pkg/database/querier.go | 3 + pkg/incidents/incstore/store.go | 46 +++- pkg/nexus/nexus.go | 2 + pkg/rbac/mocks/RBAC.go | 113 ++++++++++ pkg/rbac/rbac.go | 143 ++++++++++++- pkg/rbac/rbac_test.go | 197 ++++++++++++++++++ pkg/rest/api.go | 12 ++ pkg/server/ingest.go | 3 +- pkg/server/server.go | 12 +- pkg/share/service.go | 3 + pkg/sinks/database.go | 57 +---- pkg/sources/http.go | 21 +- pkg/store/store.go | 50 ----- pkg/talkgroups/tgstore/store.go | 81 +++++-- .../xport/radioref/radioreference_test.go | 10 +- pkg/users/store.go | 26 ++- pkg/users/user.go | 4 + sql/postgres/queries/calls.sql | 6 + sql/postgres/queries/incidents.sql | 3 + 31 files changed, 973 insertions(+), 182 deletions(-) create mode 100644 pkg/rbac/mocks/RBAC.go create mode 100644 pkg/rbac/rbac_test.go delete mode 100644 pkg/store/store.go diff --git a/.mockery.yaml b/.mockery.yaml index a0b5082..3174fe4 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -1,4 +1,4 @@ -dir: '{{ replaceAll .InterfaceDirRelative "internal" "internal_" }}/mocks' +dir: '{{.InterfaceDir}}/mocks' mockname: "{{.InterfaceName}}" outpkg: "mocks" filename: "{{.InterfaceName}}.go" @@ -9,3 +9,7 @@ packages: interfaces: Store: DBTX: + dynatron.me/x/stillbox/pkg/rbac: + config: + interfaces: + RBAC: diff --git a/internal/forms/marshal_test.go b/internal/forms/marshal_test.go index eb5eaba..b762c30 100644 --- a/internal/forms/marshal_test.go +++ b/internal/forms/marshal_test.go @@ -15,7 +15,6 @@ import ( "dynatron.me/x/stillbox/internal/common" "dynatron.me/x/stillbox/internal/forms" - "dynatron.me/x/stillbox/pkg/auth" "dynatron.me/x/stillbox/pkg/calls" "dynatron.me/x/stillbox/pkg/sources" "dynatron.me/x/stillbox/pkg/users" diff --git a/pkg/alerting/alerting.go b/pkg/alerting/alerting.go index 259d767..d78a0cf 100644 --- a/pkg/alerting/alerting.go +++ b/pkg/alerting/alerting.go @@ -14,6 +14,7 @@ import ( "dynatron.me/x/stillbox/pkg/config" "dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/notify" + "dynatron.me/x/stillbox/pkg/rbac" "dynatron.me/x/stillbox/pkg/sinks" "dynatron.me/x/stillbox/pkg/talkgroups" "dynatron.me/x/stillbox/pkg/talkgroups/tgstore" @@ -123,6 +124,8 @@ func New(cfg config.Alerting, tgCache tgstore.Store, opts ...AlertOption) Alerte // Go is the alerting loop. It does not start a goroutine. func (as *alerter) Go(ctx context.Context) { + ctx = rbac.CtxWithSubject(ctx, &rbac.SystemServiceSubject{Name: "alerter"}) + err := as.startBackfill(ctx) if err != nil { log.Error().Err(err).Msg("backfill") diff --git a/pkg/alerting/simulate.go b/pkg/alerting/simulate.go index 641ab28..1486460 100644 --- a/pkg/alerting/simulate.go +++ b/pkg/alerting/simulate.go @@ -12,6 +12,7 @@ import ( "dynatron.me/x/stillbox/internal/jsontypes" "dynatron.me/x/stillbox/internal/trending" "dynatron.me/x/stillbox/pkg/config" + "dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/talkgroups" "dynatron.me/x/stillbox/pkg/talkgroups/tgstore" @@ -59,8 +60,9 @@ func (s *Simulation) stepClock(t time.Time) { // Simulate begins the simulation using the DB handle from ctx. It returns final scores. func (s *Simulation) Simulate(ctx context.Context) (trending.Scores[talkgroups.ID], error) { + db := database.FromCtx(ctx) now := time.Now() - tgc := tgstore.NewCache() + tgc := tgstore.NewCache(db) s.Enable = true s.alerter = New(s.Alerting, tgc, WithClock(&s.clock)).(*alerter) diff --git a/pkg/auth/apikey.go b/pkg/auth/apikey.go index 4bc86f3..f88c8f9 100644 --- a/pkg/auth/apikey.go +++ b/pkg/auth/apikey.go @@ -7,7 +7,7 @@ import ( "time" "dynatron.me/x/stillbox/pkg/database" - "dynatron.me/x/stillbox/pkg/users" + "dynatron.me/x/stillbox/pkg/rbac" "github.com/google/uuid" "github.com/rs/zerolog/log" @@ -16,20 +16,19 @@ import ( type apiKeyAuth interface { // CheckAPIKey validates the provided key and returns the API owner's users.UserID. // An error is returned if validation fails for any reason. - CheckAPIKey(ctx context.Context, key string) (*users.UserID, error) + CheckAPIKey(ctx context.Context, key string) (rbac.Subject, error) } -func (a *Auth) CheckAPIKey(ctx context.Context, key string) (*users.UserID, error) { +func (a *Auth) CheckAPIKey(ctx context.Context, key string) (rbac.Subject, error) { keyUuid, err := uuid.Parse(key) if err != nil { log.Error().Str("apikey", key).Msg("cannot parse key") return nil, ErrBadRequest } - db := database.FromCtx(ctx) hash := sha256.Sum256([]byte(keyUuid.String())) b64hash := base64.StdEncoding.EncodeToString(hash[:]) - apik, err := db.GetAPIKey(ctx, b64hash) + apik, err := a.ust.GetAPIKey(ctx, b64hash) if err != nil { if database.IsNoRows(err) { log.Error().Str("apikey", keyUuid.String()).Msg("no such key") @@ -45,7 +44,5 @@ func (a *Auth) CheckAPIKey(ctx context.Context, key string) (*users.UserID, erro return nil, ErrUnauthorized } - owner := users.UserID(apik.Owner) - - return &owner, nil + return a.ust.GetUser(ctx, apik.Username) } diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 67460c9..4bd073b 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -8,6 +8,8 @@ import ( _ "embed" "dynatron.me/x/stillbox/pkg/config" + "dynatron.me/x/stillbox/pkg/rbac" + "dynatron.me/x/stillbox/pkg/users" "github.com/go-chi/chi/v5" "github.com/go-chi/httprate" "github.com/go-chi/jwtauth/v5" @@ -22,14 +24,16 @@ type Authenticator interface { type Auth struct { rl *httprate.RateLimiter jwt *jwtauth.JWTAuth + ust users.Store cfg config.Auth } // NewAuthenticator creates a new Authenticator with the provided config. -func NewAuthenticator(cfg config.Auth) *Auth { +func NewAuthenticator(cfg config.Auth, ust users.Store) *Auth { a := &Auth{ rl: httprate.NewRateLimiter(5, time.Minute), cfg: cfg, + ust: ust, } a.initJWT() @@ -51,7 +55,7 @@ var ( // ErrorResponse writes the error and appropriate HTTP response code. func ErrorResponse(w http.ResponseWriter, err error) { switch err { - case ErrLoginFailed, ErrUnauthorized: + case ErrLoginFailed, ErrUnauthorized, rbac.ErrBadSubject: http.Error(w, err.Error(), http.StatusUnauthorized) case ErrBadRequest: http.Error(w, err.Error(), http.StatusBadRequest) diff --git a/pkg/auth/jwt.go b/pkg/auth/jwt.go index 5b25084..97867b1 100644 --- a/pkg/auth/jwt.go +++ b/pkg/auth/jwt.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "net/http" - "strconv" "strings" "time" @@ -153,12 +152,12 @@ func (a *Auth) Login(ctx context.Context, username, password string) (token stri } } - return a.newToken(found.ID), nil + return a.newToken(found.Username), nil } -func (a *Auth) newToken(uid int) string { +func (a *Auth) newToken(username string) string { claims := claims{ - "sub": strconv.Itoa(int(uid)), + "sub": username, } jwtauth.SetExpiryIn(claims, time.Hour*24*30) // one month _, tokenString, err := a.jwt.Encode(claims) @@ -190,19 +189,14 @@ func (a *Auth) routeRefresh(w http.ResponseWriter, r *http.Request) { http.Error(w, "Invalid token", http.StatusBadRequest) return } - existingSubjectUID := jwToken.Subject() - if existingSubjectUID == "" { + + existingSubjectUsername := jwToken.Subject() + if existingSubjectUsername == "" { http.Error(w, "Invalid token", http.StatusBadRequest) return } - uid, err := strconv.Atoi(existingSubjectUID) - if err != nil { - log.Error().Str("sub", existingSubjectUID).Err(err).Msg("atoi uid for token refresh") - http.Error(w, "internal server error", http.StatusInternalServerError) - return - } - tok := a.newToken(uid) + tok := a.newToken(existingSubjectUsername) cookie := &http.Cookie{ Name: CookieName, diff --git a/pkg/calls/callstore/store.go b/pkg/calls/callstore/store.go index 10c1c9d..218113d 100644 --- a/pkg/calls/callstore/store.go +++ b/pkg/calls/callstore/store.go @@ -9,6 +9,9 @@ import ( "dynatron.me/x/stillbox/pkg/calls" "dynatron.me/x/stillbox/pkg/database" + "dynatron.me/x/stillbox/pkg/rbac" + "dynatron.me/x/stillbox/pkg/talkgroups/tgstore" + "dynatron.me/x/stillbox/pkg/users" "github.com/google/uuid" "github.com/jackc/pgx/v5" @@ -16,6 +19,12 @@ import ( ) type Store interface { + // AddCall adds a call to the database. + AddCall(ctx context.Context, call *calls.Call) error + + // DeleteCall deletes a call. + Delete(ctx context.Context, id uuid.UUID) error + // CallAudio returns a CallAudio struct CallAudio(ctx context.Context, id uuid.UUID) (*calls.CallAudio, error) @@ -24,10 +33,13 @@ type Store interface { } type store struct { + db database.Store } -func NewStore() *store { - return new(store) +func NewStore(db database.Store) *store { + return &store{ + db: db, + } } type storeCtxKey string @@ -41,13 +53,77 @@ func CtxWithStore(ctx context.Context, s Store) context.Context { func FromCtx(ctx context.Context) Store { s, ok := ctx.Value(StoreCtxKey).(Store) if !ok { - return NewStore() + panic("no call store in context") } return s } +func toAddCallParams(call *calls.Call) database.AddCallParams { + return database.AddCallParams{ + ID: call.ID, + Submitter: call.Submitter.Int32Ptr(), + System: call.System, + Talkgroup: call.Talkgroup, + CallDate: pgtype.Timestamptz{Time: call.DateTime, Valid: true}, + AudioName: common.NilIfZero(call.AudioName), + AudioBlob: call.Audio, + AudioType: common.NilIfZero(call.AudioType), + Duration: call.Duration.MsInt32Ptr(), + Frequency: call.Frequency, + Frequencies: call.Frequencies, + Patches: call.Patches, + TGLabel: call.TalkgroupLabel, + TGAlphaTag: call.TGAlphaTag, + TGGroup: call.TalkgroupGroup, + Source: call.Source, + } +} + +func (s *store) AddCall(ctx context.Context, call *calls.Call) error { + _, err := rbac.Check(ctx, call, rbac.WithActions(rbac.ActionCreate)) + if err != nil { + return err + } + + params := toAddCallParams(call) + db := database.FromCtx(ctx) + tgs := tgstore.FromCtx(ctx) + + err = db.InTx(ctx, func(tx database.Store) error { + err := tx.AddCall(ctx, params) + if err != nil { + return fmt.Errorf("add call: %w", err) + } + + return nil + }, pgx.TxOptions{}) + + if err != nil && database.IsTGConstraintViolation(err) { + return db.InTx(ctx, func(tx database.Store) error { + _, err := tgs.LearnTG(ctx, call) + if err != nil { + return fmt.Errorf("learn tg: %w", err) + } + + err = tx.AddCall(ctx, params) + if err != nil { + return fmt.Errorf("learn tg retry: %w", err) + } + + return nil + }, pgx.TxOptions{}) + } + + return nil +} + func (s *store) CallAudio(ctx context.Context, id uuid.UUID) (*calls.CallAudio, error) { + _, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceCall), rbac.WithActions(rbac.ActionRead)) + if err != nil { + return nil, err + } + db := database.FromCtx(ctx) dbCall, err := db.GetCallAudioByID(ctx, id) @@ -76,6 +152,11 @@ type CallsParams struct { } func (s *store) Calls(ctx context.Context, p CallsParams) (rows []database.ListCallsPRow, totalCount int, err error) { + _, err = rbac.Check(ctx, rbac.UseResource(rbac.ResourceCall), rbac.WithActions(rbac.ActionRead)) + if err != nil { + return nil, 0, err + } + db := database.FromCtx(ctx) offset, perPage := p.Pagination.OffsetPerPage(100) @@ -127,3 +208,28 @@ func (s *store) Calls(ctx context.Context, p CallsParams) (rows []database.ListC return rows, int(count), err } + +func (s *store) Delete(ctx context.Context, id uuid.UUID) error { + callOwn, err := s.getCallOwner(ctx, id) + if err != nil { + return err + } + + _, err = rbac.Check(ctx, &callOwn, rbac.WithActions(rbac.ActionDelete)) + if err != nil { + return err + } + + return database.FromCtx(ctx).DeleteCall(ctx, id) +} + +func (s *store) getCallOwner(ctx context.Context, id uuid.UUID) (calls.Call, error) { + subInt, err := database.FromCtx(ctx).GetCallSubmitter(ctx, id) + + var sub *users.UserID + + if subInt != nil { + sub = common.PtrTo(users.UserID(*subInt)) + } + return calls.Call{ID: id, Submitter: sub}, err +} diff --git a/pkg/database/calls.sql.go b/pkg/database/calls.sql.go index 22da55b..a26efe8 100644 --- a/pkg/database/calls.sql.go +++ b/pkg/database/calls.sql.go @@ -155,6 +155,15 @@ func (q *Queries) CleanupSweptCalls(ctx context.Context, rangeStart pgtype.Times return result.RowsAffected(), nil } +const deleteCall = `-- name: DeleteCall :exec +DELETE FROM calls WHERE id = $1 +` + +func (q *Queries) DeleteCall(ctx context.Context, id uuid.UUID) error { + _, err := q.db.Exec(ctx, deleteCall, id) + return err +} + const getCallAudioByID = `-- name: GetCallAudioByID :one SELECT c.call_date, @@ -192,6 +201,17 @@ func (q *Queries) GetCallAudioByID(ctx context.Context, id uuid.UUID) (GetCallAu return i, err } +const getCallSubmitter = `-- name: GetCallSubmitter :one +SELECT submitter FROM calls WHERE id = $1 +` + +func (q *Queries) GetCallSubmitter(ctx context.Context, id uuid.UUID) (*int32, error) { + row := q.db.QueryRow(ctx, getCallSubmitter, id) + var submitter *int32 + err := row.Scan(&submitter) + return submitter, err +} + const getDatabaseSize = `-- name: GetDatabaseSize :one SELECT pg_size_pretty(pg_database_size(current_database())) ` diff --git a/pkg/database/incidents.sql.go b/pkg/database/incidents.sql.go index 18f6a9d..c39e6ca 100644 --- a/pkg/database/incidents.sql.go +++ b/pkg/database/incidents.sql.go @@ -244,6 +244,17 @@ func (q *Queries) GetIncidentCalls(ctx context.Context, id uuid.UUID) ([]GetInci return items, nil } +const getIncidentOwner = `-- name: GetIncidentOwner :one +SELECT owner FROM incidents WHERE id = $1 +` + +func (q *Queries) GetIncidentOwner(ctx context.Context, id uuid.UUID) (int, error) { + row := q.db.QueryRow(ctx, getIncidentOwner, id) + var owner int + err := row.Scan(&owner) + return owner, err +} + const listIncidentsCount = `-- name: ListIncidentsCount :one SELECT COUNT(*) FROM incidents i diff --git a/pkg/database/mocks/Store.go b/pkg/database/mocks/Store.go index 2206909..11f1ae2 100644 --- a/pkg/database/mocks/Store.go +++ b/pkg/database/mocks/Store.go @@ -795,6 +795,53 @@ func (_c *Store_DeleteAPIKey_Call) RunAndReturn(run func(context.Context, string return _c } +// DeleteCall provides a mock function with given fields: ctx, id +func (_m *Store) DeleteCall(ctx context.Context, id uuid.UUID) error { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for DeleteCall") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) error); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Store_DeleteCall_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteCall' +type Store_DeleteCall_Call struct { + *mock.Call +} + +// DeleteCall is a helper method to define mock.On call +// - ctx context.Context +// - id uuid.UUID +func (_e *Store_Expecter) DeleteCall(ctx interface{}, id interface{}) *Store_DeleteCall_Call { + return &Store_DeleteCall_Call{Call: _e.mock.On("DeleteCall", ctx, id)} +} + +func (_c *Store_DeleteCall_Call) Run(run func(ctx context.Context, id uuid.UUID)) *Store_DeleteCall_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(uuid.UUID)) + }) + return _c +} + +func (_c *Store_DeleteCall_Call) Return(_a0 error) *Store_DeleteCall_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Store_DeleteCall_Call) RunAndReturn(run func(context.Context, uuid.UUID) error) *Store_DeleteCall_Call { + _c.Call.Return(run) + return _c +} + // DeleteIncident provides a mock function with given fields: ctx, id func (_m *Store) DeleteIncident(ctx context.Context, id uuid.UUID) error { ret := _m.Called(ctx, id) @@ -1358,6 +1405,65 @@ func (_c *Store_GetCallAudioByID_Call) RunAndReturn(run func(context.Context, uu return _c } +// GetCallSubmitter provides a mock function with given fields: ctx, id +func (_m *Store) GetCallSubmitter(ctx context.Context, id uuid.UUID) (*int32, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for GetCallSubmitter") + } + + var r0 *int32 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) (*int32, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) *int32); ok { + r0 = rf(ctx, id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*int32) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, uuid.UUID) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Store_GetCallSubmitter_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCallSubmitter' +type Store_GetCallSubmitter_Call struct { + *mock.Call +} + +// GetCallSubmitter is a helper method to define mock.On call +// - ctx context.Context +// - id uuid.UUID +func (_e *Store_Expecter) GetCallSubmitter(ctx interface{}, id interface{}) *Store_GetCallSubmitter_Call { + return &Store_GetCallSubmitter_Call{Call: _e.mock.On("GetCallSubmitter", ctx, id)} +} + +func (_c *Store_GetCallSubmitter_Call) Run(run func(ctx context.Context, id uuid.UUID)) *Store_GetCallSubmitter_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(uuid.UUID)) + }) + return _c +} + +func (_c *Store_GetCallSubmitter_Call) Return(_a0 *int32, _a1 error) *Store_GetCallSubmitter_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Store_GetCallSubmitter_Call) RunAndReturn(run func(context.Context, uuid.UUID) (*int32, error)) *Store_GetCallSubmitter_Call { + _c.Call.Return(run) + return _c +} + // GetDatabaseSize provides a mock function with given fields: ctx func (_m *Store) GetDatabaseSize(ctx context.Context) (string, error) { ret := _m.Called(ctx) @@ -1530,6 +1636,63 @@ func (_c *Store_GetIncidentCalls_Call) RunAndReturn(run func(context.Context, uu return _c } +// GetIncidentOwner provides a mock function with given fields: ctx, id +func (_m *Store) GetIncidentOwner(ctx context.Context, id uuid.UUID) (int, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for GetIncidentOwner") + } + + var r0 int + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) (int, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) int); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Get(0).(int) + } + + if rf, ok := ret.Get(1).(func(context.Context, uuid.UUID) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Store_GetIncidentOwner_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetIncidentOwner' +type Store_GetIncidentOwner_Call struct { + *mock.Call +} + +// GetIncidentOwner is a helper method to define mock.On call +// - ctx context.Context +// - id uuid.UUID +func (_e *Store_Expecter) GetIncidentOwner(ctx interface{}, id interface{}) *Store_GetIncidentOwner_Call { + return &Store_GetIncidentOwner_Call{Call: _e.mock.On("GetIncidentOwner", ctx, id)} +} + +func (_c *Store_GetIncidentOwner_Call) Run(run func(ctx context.Context, id uuid.UUID)) *Store_GetIncidentOwner_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(uuid.UUID)) + }) + return _c +} + +func (_c *Store_GetIncidentOwner_Call) Return(_a0 int, _a1 error) *Store_GetIncidentOwner_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Store_GetIncidentOwner_Call) RunAndReturn(run func(context.Context, uuid.UUID) (int, error)) *Store_GetIncidentOwner_Call { + _c.Call.Return(run) + return _c +} + // GetShare provides a mock function with given fields: ctx, id func (_m *Store) GetShare(ctx context.Context, id string) (database.Share, error) { ret := _m.Called(ctx, id) diff --git a/pkg/database/partman/partman.go b/pkg/database/partman/partman.go index db9f3ca..d19683b 100644 --- a/pkg/database/partman/partman.go +++ b/pkg/database/partman/partman.go @@ -13,6 +13,7 @@ import ( "dynatron.me/x/stillbox/internal/isoweek" "dynatron.me/x/stillbox/pkg/config" "dynatron.me/x/stillbox/pkg/database" + "dynatron.me/x/stillbox/pkg/rbac" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" @@ -134,6 +135,7 @@ func New(db database.Store, cfg config.Partition) (*partman, error) { var _ PartitionManager = (*partman)(nil) func (pm *partman) Go(ctx context.Context) { + ctx = rbac.CtxWithSubject(ctx, &rbac.SystemServiceSubject{Name: "partman"}) tick := time.NewTicker(CheckInterval) select { diff --git a/pkg/database/querier.go b/pkg/database/querier.go index bf2194f..d0a8553 100644 --- a/pkg/database/querier.go +++ b/pkg/database/querier.go @@ -23,6 +23,7 @@ type Querier interface { CreateSystem(ctx context.Context, iD int, name string) error CreateUser(ctx context.Context, arg CreateUserParams) (User, error) DeleteAPIKey(ctx context.Context, apiKey string) error + DeleteCall(ctx context.Context, id uuid.UUID) error DeleteIncident(ctx context.Context, id uuid.UUID) error DeleteShare(ctx context.Context, id string) error DeleteSystem(ctx context.Context, id int) error @@ -32,9 +33,11 @@ type Querier interface { GetAllTalkgroupTags(ctx context.Context) ([]string, error) GetAppPrefs(ctx context.Context, appName string, uid int) ([]byte, error) GetCallAudioByID(ctx context.Context, id uuid.UUID) (GetCallAudioByIDRow, error) + GetCallSubmitter(ctx context.Context, id uuid.UUID) (*int32, error) GetDatabaseSize(ctx context.Context) (string, error) GetIncident(ctx context.Context, id uuid.UUID) (Incident, error) GetIncidentCalls(ctx context.Context, id uuid.UUID) ([]GetIncidentCallsRow, error) + GetIncidentOwner(ctx context.Context, id uuid.UUID) (int, error) GetShare(ctx context.Context, id string) (Share, error) GetSystemName(ctx context.Context, systemID int) (string, error) GetTalkgroup(ctx context.Context, systemID int32, tGID int32) (GetTalkgroupRow, error) diff --git a/pkg/incidents/incstore/store.go b/pkg/incidents/incstore/store.go index c7c8bfa..1ba52fa 100644 --- a/pkg/incidents/incstore/store.go +++ b/pkg/incidents/incstore/store.go @@ -9,6 +9,7 @@ import ( "dynatron.me/x/stillbox/pkg/calls" "dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/incidents" + "dynatron.me/x/stillbox/pkg/rbac" "dynatron.me/x/stillbox/pkg/users" "github.com/google/uuid" "github.com/jackc/pgx/v5" @@ -72,7 +73,6 @@ func NewStore() Store { } func (s *store) CreateIncident(ctx context.Context, inc incidents.Incident) (*incidents.Incident, error) { - // TODO: replace this with a real RBAC check user, err := users.UserCheck(ctx, new(incidents.Incident), "create") if err != nil { return nil, err @@ -132,6 +132,16 @@ func (s *store) CreateIncident(ctx context.Context, inc incidents.Incident) (*in } func (s *store) AddRemoveIncidentCalls(ctx context.Context, incidentID uuid.UUID, addCallIDs []uuid.UUID, notes []byte, removeCallIDs []uuid.UUID) error { + inc, err := s.getIncidentOwner(ctx, incidentID) + if err != nil { + return err + } + + _, err = rbac.Check(ctx, &inc, rbac.WithActions(rbac.ActionUpdate)) + if err != nil { + return err + } + return database.FromCtx(ctx).InTx(ctx, func(db database.Store) error { if len(addCallIDs) > 0 { var noteAr [][]byte @@ -160,6 +170,10 @@ func (s *store) AddRemoveIncidentCalls(ctx context.Context, incidentID uuid.UUID } func (s *store) Incidents(ctx context.Context, p IncidentsParams) (incs []Incident, totalCount int, err error) { + _, err = rbac.Check(ctx, new(incidents.Incident), rbac.WithActions(rbac.ActionRead)) + if err != nil { + return nil, 0, err + } db := database.FromCtx(ctx) offset, perPage := p.Pagination.OffsetPerPage(100) @@ -261,6 +275,11 @@ func fromDBCalls(d []database.GetIncidentCallsRow) []incidents.IncidentCall { } func (s *store) Incident(ctx context.Context, id uuid.UUID) (*incidents.Incident, error) { + _, err := rbac.Check(ctx, new(incidents.Incident), rbac.WithActions(rbac.ActionRead)) + if err != nil { + return nil, err + } + var r incidents.Incident txErr := database.FromCtx(ctx).InTx(ctx, func(db database.Store) error { inc, err := db.GetIncident(ctx, id) @@ -307,6 +326,16 @@ func (uip UpdateIncidentParams) toDBUIP(id uuid.UUID) database.UpdateIncidentPar } func (s *store) UpdateIncident(ctx context.Context, id uuid.UUID, p UpdateIncidentParams) (*incidents.Incident, error) { + ckinc, err := s.getIncidentOwner(ctx, id) + if err != nil { + return nil, err + } + + _, err = rbac.Check(ctx, &ckinc, rbac.WithActions(rbac.ActionUpdate)) + if err != nil { + return nil, err + } + db := database.FromCtx(ctx) dbInc, err := db.UpdateIncident(ctx, p.toDBUIP(id)) @@ -320,9 +349,24 @@ func (s *store) UpdateIncident(ctx context.Context, id uuid.UUID, p UpdateIncide } func (s *store) DeleteIncident(ctx context.Context, id uuid.UUID) error { + inc, err := s.getIncidentOwner(ctx, id) + if err != nil { + return err + } + + _, err = rbac.Check(ctx, &inc, rbac.WithActions(rbac.ActionDelete)) + if err != nil { + return err + } + return database.FromCtx(ctx).DeleteIncident(ctx, id) } func (s *store) UpdateNotes(ctx context.Context, incidentID uuid.UUID, callID uuid.UUID, notes []byte) error { return database.FromCtx(ctx).UpdateCallIncidentNotes(ctx, notes, incidentID, callID) } + +func (s *store) getIncidentOwner(ctx context.Context, id uuid.UUID) (incidents.Incident, error) { + owner, err := database.FromCtx(ctx).GetIncidentOwner(ctx, id) + return incidents.Incident{ID: id, Owner: users.UserID(owner)}, err +} diff --git a/pkg/nexus/nexus.go b/pkg/nexus/nexus.go index 14d52ad..fcfe056 100644 --- a/pkg/nexus/nexus.go +++ b/pkg/nexus/nexus.go @@ -6,6 +6,7 @@ import ( "dynatron.me/x/stillbox/pkg/calls" "dynatron.me/x/stillbox/pkg/pb" + "dynatron.me/x/stillbox/pkg/rbac" "github.com/rs/zerolog/log" ) @@ -38,6 +39,7 @@ func New() *Nexus { } func (n *Nexus) Go(ctx context.Context) { + ctx = rbac.CtxWithSubject(ctx, &rbac.SystemServiceSubject{Name: "nexus"}) for { select { case call, ok := <-n.callCh: diff --git a/pkg/rbac/mocks/RBAC.go b/pkg/rbac/mocks/RBAC.go new file mode 100644 index 0000000..d7de98b --- /dev/null +++ b/pkg/rbac/mocks/RBAC.go @@ -0,0 +1,113 @@ +// Code generated by mockery v2.47.0. DO NOT EDIT. + +package mocks + +import ( + context "context" + + rbac "dynatron.me/x/stillbox/pkg/rbac" + mock "github.com/stretchr/testify/mock" + + restrict "github.com/el-mike/restrict/v2" +) + +// RBAC is an autogenerated mock type for the RBAC type +type RBAC struct { + mock.Mock +} + +type RBAC_Expecter struct { + mock *mock.Mock +} + +func (_m *RBAC) EXPECT() *RBAC_Expecter { + return &RBAC_Expecter{mock: &_m.Mock} +} + +// Check provides a mock function with given fields: ctx, res, opts +func (_m *RBAC) Check(ctx context.Context, res restrict.Resource, opts ...rbac.CheckOption) (rbac.Subject, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, res) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for Check") + } + + var r0 rbac.Subject + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, restrict.Resource, ...rbac.CheckOption) (rbac.Subject, error)); ok { + return rf(ctx, res, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, restrict.Resource, ...rbac.CheckOption) rbac.Subject); ok { + r0 = rf(ctx, res, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(rbac.Subject) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, restrict.Resource, ...rbac.CheckOption) error); ok { + r1 = rf(ctx, res, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RBAC_Check_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Check' +type RBAC_Check_Call struct { + *mock.Call +} + +// Check is a helper method to define mock.On call +// - ctx context.Context +// - res restrict.Resource +// - opts ...rbac.CheckOption +func (_e *RBAC_Expecter) Check(ctx interface{}, res interface{}, opts ...interface{}) *RBAC_Check_Call { + return &RBAC_Check_Call{Call: _e.mock.On("Check", + append([]interface{}{ctx, res}, opts...)...)} +} + +func (_c *RBAC_Check_Call) Run(run func(ctx context.Context, res restrict.Resource, opts ...rbac.CheckOption)) *RBAC_Check_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]rbac.CheckOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(rbac.CheckOption) + } + } + run(args[0].(context.Context), args[1].(restrict.Resource), variadicArgs...) + }) + return _c +} + +func (_c *RBAC_Check_Call) Return(_a0 rbac.Subject, _a1 error) *RBAC_Check_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *RBAC_Check_Call) RunAndReturn(run func(context.Context, restrict.Resource, ...rbac.CheckOption) (rbac.Subject, error)) *RBAC_Check_Call { + _c.Call.Return(run) + return _c +} + +// NewRBAC creates a new instance of RBAC. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewRBAC(t interface { + mock.TestingT + Cleanup(func()) +}) *RBAC { + mock := &RBAC{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/rbac/rbac.go b/pkg/rbac/rbac.go index 0b5ede1..b626f34 100644 --- a/pkg/rbac/rbac.go +++ b/pkg/rbac/rbac.go @@ -3,6 +3,8 @@ package rbac import ( "context" "errors" + "fmt" + "reflect" "github.com/el-mike/restrict/v2" "github.com/el-mike/restrict/v2/adapters" @@ -12,6 +14,7 @@ const ( RoleUser = "User" RoleSubmitter = "Submitter" RoleAdmin = "Admin" + RoleSystem = "System" RolePublic = "Public" RoleShareGuest = "ShareGuest" @@ -20,6 +23,7 @@ const ( ResourceTalkgroup = "Talkgroup" ResourceAlert = "Alert" ResourceShare = "Share" + ResourceAPIKey = "APIKey" ActionRead = "read" ActionCreate = "create" @@ -29,6 +33,9 @@ const ( PresetUpdateOwn = "updateOwn" PresetDeleteOwn = "deleteOwn" PresetReadShared = "readShared" + + PresetUpdateSubmitter = "updateSubmitter" + PresetDeleteSubmitter = "deleteSubmitter" ) var ( @@ -43,6 +50,14 @@ func CtxWithSubject(ctx context.Context, sub Subject) context.Context { return context.WithValue(ctx, SubjectCtxKey, sub) } +func ErrAccessDenied(err error) *restrict.AccessDeniedError { + if accessErr, ok := err.(*restrict.AccessDeniedError); ok { + return accessErr + } + + return nil +} + func SubjectFrom(ctx context.Context) Subject { sub, ok := ctx.Value(SubjectCtxKey).(Subject) if ok { @@ -87,8 +102,8 @@ var policy = &restrict.PolicyDefinition{ ResourceCall: { &restrict.Permission{Action: ActionRead}, &restrict.Permission{Action: ActionCreate}, - &restrict.Permission{Preset: PresetUpdateOwn}, - &restrict.Permission{Preset: PresetDeleteOwn}, + &restrict.Permission{Preset: PresetUpdateSubmitter}, + &restrict.Permission{Preset: PresetDeleteSubmitter}, }, ResourceTalkgroup: { &restrict.Permission{Action: ActionRead}, @@ -107,6 +122,11 @@ var policy = &restrict.PolicyDefinition{ ResourceCall: { &restrict.Permission{Action: ActionCreate}, }, + ResourceTalkgroup: { + // for learning TGs + &restrict.Permission{Action: ActionCreate}, + &restrict.Permission{Action: ActionUpdate}, + }, }, }, RoleShareGuest: { @@ -141,6 +161,9 @@ var policy = &restrict.PolicyDefinition{ }, }, }, + RoleSystem: { + Parents: []string{RoleSystem}, + }, RolePublic: { /* Grants: restrict.GrantsMap{ @@ -184,6 +207,38 @@ var policy = &restrict.PolicyDefinition{ }, }, }, + PresetUpdateSubmitter: &restrict.Permission{ + Action: ActionUpdate, + Conditions: restrict.Conditions{ + &SubmitterEqualCondition{ + ID: "isSubmitter", + Left: &restrict.ValueDescriptor{ + Source: restrict.ResourceField, + Field: "Submitter", + }, + Right: &restrict.ValueDescriptor{ + Source: restrict.SubjectField, + Field: "ID", + }, + }, + }, + }, + PresetDeleteSubmitter: &restrict.Permission{ + Action: ActionDelete, + Conditions: restrict.Conditions{ + &SubmitterEqualCondition{ + ID: "isSubmitter", + Left: &restrict.ValueDescriptor{ + Source: restrict.ResourceField, + Field: "Submitter", + }, + Right: &restrict.ValueDescriptor{ + Source: restrict.SubjectField, + Field: "ID", + }, + }, + }, + }, PresetReadShared: &restrict.Permission{ Action: ActionRead, Conditions: restrict.Conditions{ @@ -208,15 +263,15 @@ type checkOptions struct { context restrict.Context } -type checkOption func(*checkOptions) +type CheckOption func(*checkOptions) -func WithActions(actions ...string) checkOption { +func WithActions(actions ...string) CheckOption { return func(o *checkOptions) { o.actions = append(o.actions, actions...) } } -func WithContext(ctx restrict.Context) checkOption { +func WithContext(ctx restrict.Context) CheckOption { return func(o *checkOptions) { o.context = ctx } @@ -228,6 +283,7 @@ func UseResource(rsc string) restrict.Resource { type Subject interface { restrict.Subject + GetName() string } type Resource interface { @@ -235,7 +291,7 @@ type Resource interface { } type RBAC interface { - Check(ctx context.Context, res restrict.Resource, opts ...checkOption) (Subject, error) + Check(ctx context.Context, res restrict.Resource, opts ...CheckOption) (Subject, error) } type rbac struct { @@ -257,7 +313,12 @@ func New() (*rbac, error) { }, nil } -func (r *rbac) Check(ctx context.Context, res restrict.Resource, opts ...checkOption) (Subject, error) { +// Check is a convenience function to pull the RBAC instance out of ctx and Check. +func Check(ctx context.Context, res restrict.Resource, opts ...CheckOption) (Subject, error) { + return FromCtx(ctx).Check(ctx, res, opts...) +} + +func (r *rbac) Check(ctx context.Context, res restrict.Resource, opts ...CheckOption) (Subject, error) { sub := SubjectFrom(ctx) o := checkOptions{} @@ -279,6 +340,10 @@ type ShareLinkGuest struct { ShareID string } +func (s *ShareLinkGuest) GetName() string { + return "SHARE:" + s.ShareID +} + func (s *ShareLinkGuest) GetRoles() []string { return []string{RoleShareGuest} } @@ -287,6 +352,70 @@ type PublicSubject struct { RemoteAddr string } +func (s *PublicSubject) GetName() string { + return "PUBLIC:" + s.RemoteAddr +} + func (s *PublicSubject) GetRoles() []string { return []string{RolePublic} } + +type SystemServiceSubject struct { + Name string +} + +func (s *SystemServiceSubject) GetName() string { + return "SYSTEM:" + s.Name +} + +func (s *SystemServiceSubject) GetRoles() []string { + return []string{RoleSystem} +} + +const ( + SubmitterEqualConditionType = "SUBMITTER_EQUAL" +) + +type SubmitterEqualCondition struct { + ID string `json:"name,omitempty" yaml:"name,omitempty"` + Left *restrict.ValueDescriptor `json:"left" yaml:"left"` + Right *restrict.ValueDescriptor `json:"right" yaml:"right"` +} + +func (s *SubmitterEqualCondition) Type() string { + return SubmitterEqualConditionType +} + +func (c *SubmitterEqualCondition) Check(r *restrict.AccessRequest) error { + left, err := c.Left.GetValue(r) + if err != nil { + return err + } + + right, err := c.Right.GetValue(r) + if err != nil { + return err + } + + lVal := reflect.ValueOf(left) + rVal := reflect.ValueOf(right) + + // deref Left. this is the difference between us and EqualCondition + for lVal.Kind() == reflect.Pointer { + lVal = lVal.Elem() + } + + if !lVal.IsValid() || !reflect.DeepEqual(rVal.Interface(), lVal.Interface()) { + return restrict.NewConditionNotSatisfiedError(c, r, fmt.Errorf("values \"%v\" and \"%v\" are not equal", left, right)) + } + + return nil +} + +func SubmitterEqualConditionFactory() restrict.Condition { + return new(SubmitterEqualCondition) +} + +func init() { + restrict.RegisterConditionFactory(SubmitterEqualConditionType, SubmitterEqualConditionFactory) +} diff --git a/pkg/rbac/rbac_test.go b/pkg/rbac/rbac_test.go new file mode 100644 index 0000000..582732d --- /dev/null +++ b/pkg/rbac/rbac_test.go @@ -0,0 +1,197 @@ +package rbac_test + +import ( + "context" + "errors" + "fmt" + "testing" + + "dynatron.me/x/stillbox/internal/common" + "dynatron.me/x/stillbox/pkg/calls" + "dynatron.me/x/stillbox/pkg/incidents" + "dynatron.me/x/stillbox/pkg/rbac" + "dynatron.me/x/stillbox/pkg/talkgroups" + "dynatron.me/x/stillbox/pkg/users" + "github.com/el-mike/restrict/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRBAC(t *testing.T) { + tests := []struct { + name string + subject rbac.Subject + resource rbac.Resource + action string + expectErr error + }{ + { + name: "admin update talkgroup", + subject: &users.User{ + ID: 2, + IsAdmin: true, + }, + resource: &talkgroups.Talkgroup{}, + action: rbac.ActionUpdate, + expectErr: nil, + }, + { + name: "admin update incident", + subject: &users.User{ + ID: 2, + IsAdmin: true, + }, + resource: &incidents.Incident{ + Name: "test incident", + Owner: 4, + }, + action: rbac.ActionUpdate, + expectErr: nil, + }, + { + name: "user update incident not owner", + subject: &users.User{ + ID: 2, + }, + resource: &incidents.Incident{ + Name: "test incident", + Owner: 4, + }, + action: rbac.ActionUpdate, + expectErr: errors.New(`access denied for Action: "update" on Resource: "Incident"`), + }, + { + name: "user update incident owner", + subject: &users.User{ + ID: 2, + }, + resource: &incidents.Incident{ + Name: "test incident", + Owner: 2, + }, + action: rbac.ActionUpdate, + expectErr: nil, + }, + { + name: "user delete incident not owner", + subject: &users.User{ + ID: 2, + }, + resource: &incidents.Incident{ + Name: "test incident", + Owner: 6, + }, + action: rbac.ActionDelete, + expectErr: errors.New(`access denied for Action: "delete" on Resource: "Incident"`), + }, + { + name: "admin update call", + subject: &users.User{ + ID: 2, + IsAdmin: true, + }, + resource: &calls.Call{ + Submitter: common.PtrTo(users.UserID(4)), + }, + action: rbac.ActionUpdate, + expectErr: nil, + }, + { + name: "user update call not owner", + subject: &users.User{ + ID: 2, + }, + resource: &calls.Call{ + Submitter: common.PtrTo(users.UserID(4)), + }, + action: rbac.ActionUpdate, + expectErr: errors.New(`access denied for Action: "update" on Resource: "Call"`), + }, + { + name: "user update call owner", + subject: &users.User{ + ID: 2, + }, + resource: &calls.Call{ + Submitter: common.PtrTo(users.UserID(2)), + }, + action: rbac.ActionUpdate, + expectErr: nil, + }, + { + name: "user update call nil submitter", + subject: &users.User{ + ID: 2, + }, + resource: &calls.Call{ + Submitter: nil, + }, + action: rbac.ActionUpdate, + expectErr: errors.New(`access denied for Action: "update" on Resource: "Call"`), + }, + { + name: "user delete call not owner", + subject: &users.User{ + ID: 2, + }, + resource: &calls.Call{ + Submitter: common.PtrTo(users.UserID(6)), + }, + action: rbac.ActionDelete, + expectErr: errors.New(`access denied for Action: "delete" on Resource: "Call"`), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctx := rbac.CtxWithSubject(context.Background(), tc.subject) + rb, err := rbac.New() + require.NoError(t, err) + sub, err := rb.Check(ctx, tc.resource, rbac.WithActions(tc.action)) + if tc.expectErr != nil { + assert.Equal(t, tc.expectErr.Error(), err.Error()) + } else { + if !assert.NoError(t, err) { + accErr(err) + } + } + assert.Equal(t, tc.subject, sub) + }) + } +} + +func accErr(err error) { + if accessError, ok := err.(*restrict.AccessDeniedError); ok { + // Error() implementation. Returns a message in a form: "access denied for Action/s: ... on Resource: ..." + fmt.Println(accessError) + // Returns an AccessRequest that failed. + fmt.Println(accessError.Request) + // Returns first reason for the denied access. + // Especially helpful in fail-early mode, where there will only be one Reason. + fmt.Println(accessError.FirstReason()) + + // Reasons property will hold all errors that caused the access to be denied. + for _, permissionErr := range accessError.Reasons { + fmt.Println(permissionErr) + fmt.Println(permissionErr.Action) + fmt.Println(permissionErr.RoleName) + fmt.Println(permissionErr.ResourceName) + + // Returns first ConditionNotSatisfied error for given PermissionError, if any was returned for given PermissionError. + // Especially helpful in fail-early mode, where there will only be one failed Condition. + fmt.Println(permissionErr.FirstConditionError()) + + // ConditionErrors property will hold all ConditionNotSatisfied errors. + for _, conditionErr := range permissionErr.ConditionErrors { + fmt.Println(conditionErr) + fmt.Println(conditionErr.Reason) + + // Every ConditionNotSatisfied contains an instance of Condition that returned it, + // so it can be tested using type assertion to get more details about failed Condition. + if emptyCondition, ok := conditionErr.Condition.(*restrict.EmptyCondition); ok { + fmt.Println(emptyCondition.ID) + } + } + } + } +} diff --git a/pkg/rest/api.go b/pkg/rest/api.go index 5920133..e94194b 100644 --- a/pkg/rest/api.go +++ b/pkg/rest/api.go @@ -84,6 +84,14 @@ func unauthErrText(err error) render.Renderer { } } +func forbiddenErrText(err error) render.Renderer { + return &errResponse{ + Err: err, + Code: http.StatusForbidden, + Error: "Forbidden: " + err.Error(), + } +} + func constraintErrText(err error) render.Renderer { return &errResponse{ Err: err, @@ -147,6 +155,10 @@ func autoError(err error) render.Renderer { } } + if rbac.ErrAccessDenied(err) != nil { + return forbiddenErrText(err) + } + return internalError(err) } diff --git a/pkg/server/ingest.go b/pkg/server/ingest.go index 8f42950..a8c766e 100644 --- a/pkg/server/ingest.go +++ b/pkg/server/ingest.go @@ -7,5 +7,6 @@ import ( ) func (s *Server) Ingest(ctx context.Context, call *calls.Call) error { - return s.sinks.EmitCall(context.Background(), call) + ctx = context.WithoutCancel(ctx) + return s.sinks.EmitCall(ctx, call) } diff --git a/pkg/server/server.go b/pkg/server/server.go index 15f11d4..6c51931 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -67,14 +67,16 @@ func New(ctx context.Context, cfg *config.Configuration) (*Server, error) { r := chi.NewRouter() - authenticator := auth.NewAuthenticator(cfg.Auth) + ust := users.NewStore(db) + + authenticator := auth.NewAuthenticator(cfg.Auth, ust) notifier, err := notify.New(cfg.Notify) if err != nil { return nil, err } - tgCache := tgstore.NewCache() + tgCache := tgstore.NewCache(db) api := rest.New(cfg.BaseURL.URL()) rbacSvc, err := rbac.New() @@ -95,8 +97,8 @@ func New(ctx context.Context, cfg *config.Configuration) (*Server, error) { sinks: sinks.NewSinkManager(), rest: api, share: share.NewService(), - users: users.NewStore(), - calls: callstore.NewStore(), + users: ust, + calls: callstore.NewStore(db), incidents: incstore.NewStore(), rbac: rbacSvc, } @@ -113,7 +115,7 @@ func New(ctx context.Context, cfg *config.Configuration) (*Server, error) { } } - srv.sinks.Register("database", sinks.NewDatabaseSink(srv.db, tgCache), true) + srv.sinks.Register("database", sinks.NewDatabaseSink(db, tgCache), true) srv.sinks.Register("nexus", sinks.NewNexusSink(srv.nex), false) if srv.alerter.Enabled() { diff --git a/pkg/share/service.go b/pkg/share/service.go index 63d3141..eea1edd 100644 --- a/pkg/share/service.go +++ b/pkg/share/service.go @@ -4,6 +4,7 @@ import ( "context" "time" + "dynatron.me/x/stillbox/pkg/rbac" "github.com/rs/zerolog/log" ) @@ -26,6 +27,8 @@ func (s *service) ShareStore() Store { } func (s *service) Go(ctx context.Context) { + ctx = rbac.CtxWithSubject(ctx, &rbac.SystemServiceSubject{Name: "share"}) + tick := time.NewTicker(PruneInterval) for { diff --git a/pkg/sinks/database.go b/pkg/sinks/database.go index 68cf70b..8018117 100644 --- a/pkg/sinks/database.go +++ b/pkg/sinks/database.go @@ -2,15 +2,12 @@ package sinks import ( "context" - "fmt" - "dynatron.me/x/stillbox/internal/common" "dynatron.me/x/stillbox/pkg/calls" + "dynatron.me/x/stillbox/pkg/calls/callstore" "dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/talkgroups/tgstore" - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" "github.com/rs/zerolog/log" ) @@ -29,59 +26,9 @@ func (s *DatabaseSink) Call(ctx context.Context, call *calls.Call) error { return nil } - params := s.toAddCallParams(call) - - err := s.db.InTx(ctx, func(tx database.Store) error { - err := tx.AddCall(ctx, params) - if err != nil { - return fmt.Errorf("add call: %w", err) - } - - log.Debug().Str("id", call.ID.String()).Int("system", call.System).Int("tgid", call.Talkgroup).Msg("stored") - - return nil - }, pgx.TxOptions{}) - - if err != nil && database.IsTGConstraintViolation(err) { - return s.db.InTx(ctx, func(tx database.Store) error { - _, err := s.tgs.LearnTG(ctx, call) - if err != nil { - return fmt.Errorf("learn tg: %w", err) - } - - err = tx.AddCall(ctx, params) - if err != nil { - return fmt.Errorf("learn tg retry: %w", err) - } - - return nil - }, pgx.TxOptions{}) - } - - return err + return callstore.FromCtx(ctx).AddCall(ctx, call) } func (s *DatabaseSink) SinkType() string { return "database" } - -func (s *DatabaseSink) toAddCallParams(call *calls.Call) database.AddCallParams { - return database.AddCallParams{ - ID: call.ID, - Submitter: call.Submitter.Int32Ptr(), - System: call.System, - Talkgroup: call.Talkgroup, - CallDate: pgtype.Timestamptz{Time: call.DateTime, Valid: true}, - AudioName: common.NilIfZero(call.AudioName), - AudioBlob: call.Audio, - AudioType: common.NilIfZero(call.AudioType), - Duration: call.Duration.MsInt32Ptr(), - Frequency: call.Frequency, - Frequencies: call.Frequencies, - Patches: call.Patches, - TGLabel: call.TalkgroupLabel, - TGAlphaTag: call.TGAlphaTag, - TGGroup: call.TalkgroupGroup, - Source: call.Source, - } -} diff --git a/pkg/sources/http.go b/pkg/sources/http.go index 503881e..dfd8df5 100644 --- a/pkg/sources/http.go +++ b/pkg/sources/http.go @@ -9,6 +9,7 @@ import ( "dynatron.me/x/stillbox/internal/forms" "dynatron.me/x/stillbox/pkg/auth" "dynatron.me/x/stillbox/pkg/calls" + "dynatron.me/x/stillbox/pkg/rbac" "dynatron.me/x/stillbox/pkg/users" "github.com/go-chi/chi/v5" "github.com/rs/zerolog/log" @@ -99,7 +100,13 @@ func (h *RdioHTTP) routeCallUpload(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - submitter, err := h.auth.CheckAPIKey(ctx, r.Form.Get("key")) + submitterSub, err := h.auth.CheckAPIKey(ctx, r.Form.Get("key")) + if err != nil { + auth.ErrorResponse(w, err) + return + } + + submitter, err := users.FromSubject(submitterSub) if err != nil { auth.ErrorResponse(w, err) return @@ -118,20 +125,22 @@ func (h *RdioHTTP) routeCallUpload(w http.ResponseWriter, r *http.Request) { return } - call, err := cur.ToCall(*submitter) + call, err := cur.ToCall(submitter.ID) if err != nil { log.Error().Err(err).Msg("toCall failed") http.Error(w, err.Error(), http.StatusBadRequest) return } - err = h.ing.Ingest(ctx, call) + err = h.ing.Ingest(rbac.CtxWithSubject(ctx, submitterSub), call) if err != nil { - log.Error().Err(err).Msg("ingest failed") - http.Error(w, "Call ingest failed.", http.StatusInternalServerError) + if rbac.ErrAccessDenied(err) != nil { + log.Error().Err(err).Msg("ingest failed") + http.Error(w, "Call ingest failed.", http.StatusForbidden) + } return } - log.Info().Int("system", cur.System).Int("tgid", cur.Talkgroup).Str("duration", call.Duration.Duration().String()).Msg("ingested") + log.Info().Int("system", cur.System).Int("tgid", cur.Talkgroup).Str("duration", call.Duration.Duration().String()).Str("sub", submitter.Username).Msg("ingested") written, err := w.Write([]byte("Call imported successfully.")) if err != nil { diff --git a/pkg/store/store.go b/pkg/store/store.go deleted file mode 100644 index bb85c85..0000000 --- a/pkg/store/store.go +++ /dev/null @@ -1,50 +0,0 @@ -package store - -import ( - "context" - - "dynatron.me/x/stillbox/pkg/talkgroups/tgstore" - "dynatron.me/x/stillbox/pkg/users" -) - -type Store interface { - TG() tgstore.Store - User() users.Store -} - -type store struct { - tg tgstore.Store - user users.Store -} - -func (s *store) TG() tgstore.Store { - return s.tg -} - -func (s *store) User() users.Store { - return s.user -} - -func New() Store { - return &store{ - tg: tgstore.NewCache(), - user: users.NewStore(), - } -} - -type storeCtxKey string - -const StoreCtxKey storeCtxKey = "store" - -func CtxWithStore(ctx context.Context, s Store) context.Context { - return context.WithValue(ctx, StoreCtxKey, s) -} - -func FromCtx(ctx context.Context) Store { - s, ok := ctx.Value(StoreCtxKey).(Store) - if !ok { - return New() - } - - return s -} diff --git a/pkg/talkgroups/tgstore/store.go b/pkg/talkgroups/tgstore/store.go index f92d3af..64a010e 100644 --- a/pkg/talkgroups/tgstore/store.go +++ b/pkg/talkgroups/tgstore/store.go @@ -11,6 +11,7 @@ import ( "dynatron.me/x/stillbox/pkg/calls" "dynatron.me/x/stillbox/pkg/config" "dynatron.me/x/stillbox/pkg/database" + "dynatron.me/x/stillbox/pkg/rbac" tgsp "dynatron.me/x/stillbox/pkg/talkgroups" "dynatron.me/x/stillbox/pkg/users" @@ -176,7 +177,7 @@ func CtxWithStore(ctx context.Context, s Store) context.Context { func FromCtx(ctx context.Context) Store { s, ok := ctx.Value(StoreCtxKey).(Store) if !ok { - return NewCache() + panic("no tg store in context") } return s @@ -201,19 +202,23 @@ type cache struct { sync.RWMutex tgs tgMap systems map[int]string + db database.Store } // NewCache returns a new cache Store. -func NewCache() *cache { +func NewCache(db database.Store) *cache { tgc := &cache{ tgs: make(tgMap), systems: make(map[int]string), + db: db, } return tgc } func (t *cache) Hint(ctx context.Context, tgs []tgsp.ID) error { + // since this doesn't actually return data, we can skip rbac checks. + // This is only called by system services anyway. if len(tgs) < 1 { return nil } @@ -322,11 +327,15 @@ func addToRowList[T rowType](t *cache, tgRecords []T) []*tgsp.Talkgroup { } func (t *cache) TGs(ctx context.Context, tgs tgsp.IDs, opts ...Option) ([]*tgsp.Talkgroup, error) { - db := database.FromCtx(ctx) + _, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionRead)) + if err != nil { + return nil, err + } + + db := t.db r := make([]*tgsp.Talkgroup, 0, len(tgs)) opt := sOpt(opts) - var err error if tgs != nil { toGet := make(tgsp.IDs, 0, len(tgs)) for _, id := range tgs { @@ -394,7 +403,8 @@ func (t *cache) TGs(ctx context.Context, tgs tgsp.IDs, opts ...Option) ([]*tgsp. } func (t *cache) Load(ctx context.Context, tgs database.TGTuples) error { - tgRecords, err := database.FromCtx(ctx).GetTalkgroupsWithLearnedBySysTGID(ctx, tgs) + // No need for RBAC checks since this merely primes the cache and returns nothing. + tgRecords, err := t.db.GetTalkgroupsWithLearnedBySysTGID(ctx, tgs) if err != nil { return err } @@ -420,9 +430,13 @@ func (t *cache) Weight(ctx context.Context, id tgsp.ID, tm time.Time) float64 { } func (t *cache) SystemTGs(ctx context.Context, systemID int, opts ...Option) ([]*tgsp.Talkgroup, error) { - db := database.FromCtx(ctx) + _, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionRead)) + if err != nil { + return nil, err + } + + db := t.db opt := sOpt(opts) - var err error if opt.pagination != nil { sortDir, err := opt.pagination.SortDir() if err != nil { @@ -472,13 +486,18 @@ func (t *cache) SystemTGs(ctx context.Context, systemID int, opts ...Option) ([] } func (t *cache) TG(ctx context.Context, tg tgsp.ID) (*tgsp.Talkgroup, error) { + _, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionRead)) + if err != nil { + return nil, err + } + rec, has := t.get(tg) if has { return rec, nil } - record, err := database.FromCtx(ctx).GetTalkgroupWithLearned(ctx, int32(tg.System), int32(tg.Talkgroup)) + record, err := t.db.GetTalkgroupWithLearned(ctx, int32(tg.System), int32(tg.Talkgroup)) switch err { case nil: case pgx.ErrNoRows: @@ -494,12 +513,17 @@ func (t *cache) TG(ctx context.Context, tg tgsp.ID) (*tgsp.Talkgroup, error) { } func (t *cache) SystemName(ctx context.Context, id int) (name string, has bool) { + _, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionRead)) + if err != nil { + return "", false + } + t.RLock() n, has := t.systems[id] t.RUnlock() if !has { - sys, err := database.FromCtx(ctx).GetSystemName(ctx, id) + sys, err := t.db.GetSystemName(ctx, id) if err != nil { return "", false } @@ -525,7 +549,7 @@ func (t *cache) UpdateTG(ctx context.Context, input database.UpdateTalkgroupPara return nil, ErrNoSuchSystem } - db := database.FromCtx(ctx) + db := t.db var tg database.Talkgroup err = db.InTx(ctx, func(db database.Store) error { var oerr error @@ -563,12 +587,17 @@ func (t *cache) UpdateTG(ctx context.Context, input database.UpdateTalkgroupPara } func (t *cache) DeleteSystem(ctx context.Context, id int) error { + _, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionDelete)) + if err != nil { + return err + } + t.Lock() defer t.Unlock() t.invalidate() - err := database.FromCtx(ctx).DeleteSystem(ctx, id) + err = t.db.DeleteSystem(ctx, id) switch { case err == nil: return nil @@ -580,6 +609,11 @@ func (t *cache) DeleteSystem(ctx context.Context, id int) error { } func (t *cache) DeleteTG(ctx context.Context, id tgsp.ID) error { + _, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionDelete)) + if err != nil { + return err + } + t.Lock() defer t.Unlock() @@ -588,7 +622,7 @@ func (t *cache) DeleteTG(ctx context.Context, id tgsp.ID) error { return err } - err = database.FromCtx(ctx).InTx(ctx, func(db database.Store) error { + err = t.db.InTx(ctx, func(db database.Store) error { err := db.StoreDeletedTGVersion(ctx, common.PtrTo(int32(id.System)), common.PtrTo(int32(id.Talkgroup)), user.ID.Int32Ptr()) if err != nil { return err @@ -611,7 +645,12 @@ func (t *cache) DeleteTG(ctx context.Context, id tgsp.ID) error { } func (t *cache) LearnTG(ctx context.Context, c *calls.Call) (*tgsp.Talkgroup, error) { - db := database.FromCtx(ctx) + _, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionCreate, rbac.ActionUpdate)) + if err != nil { + return nil, err + } + + db := t.db sys, has := t.SystemName(ctx, c.System) if !has { @@ -649,7 +688,7 @@ func (t *cache) UpsertTGs(ctx context.Context, system int, input []database.Upse return nil, err } - db := database.FromCtx(ctx) + db := t.db sysName, hasSys := t.SystemName(ctx, system) if !hasSys { return nil, ErrNoSuchSystem @@ -725,14 +764,24 @@ func (t *cache) UpsertTGs(ctx context.Context, system int, input []database.Upse } func (t *cache) CreateSystem(ctx context.Context, id int, name string) error { + _, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionCreate)) + if err != nil { + return err + } + t.Lock() defer t.Unlock() t.addSysNoLock(id, name) - return database.FromCtx(ctx).CreateSystem(ctx, id, name) + return t.db.CreateSystem(ctx, id, name) } func (t *cache) Tags(ctx context.Context) ([]string, error) { - return database.FromCtx(ctx).GetAllTalkgroupTags(ctx) + _, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionRead)) + if err != nil { + return nil, err + } + + return t.db.GetAllTalkgroupTags(ctx) } diff --git a/pkg/talkgroups/xport/radioref/radioreference_test.go b/pkg/talkgroups/xport/radioref/radioreference_test.go index c453e2f..42fb340 100644 --- a/pkg/talkgroups/xport/radioref/radioreference_test.go +++ b/pkg/talkgroups/xport/radioref/radioreference_test.go @@ -14,9 +14,12 @@ import ( "dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/database/mocks" + "dynatron.me/x/stillbox/pkg/rbac" + rbacmocks "dynatron.me/x/stillbox/pkg/rbac/mocks" "dynatron.me/x/stillbox/pkg/talkgroups" "dynatron.me/x/stillbox/pkg/talkgroups/tgstore" "dynatron.me/x/stillbox/pkg/talkgroups/xport" + "dynatron.me/x/stillbox/pkg/users" ) func getFixture(fixture string) []byte { @@ -51,14 +54,19 @@ func TestRadioRef(t *testing.T) { }, } + subject := users.User{IsAdmin: true} + for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { dbMock := mocks.NewStore(t) + rbacMock := rbacmocks.NewRBAC(t) + rbacMock.EXPECT().Check(mock.AnythingOfType("*context.valueCtx"), rbac.UseResource("Talkgroup"), mock.AnythingOfType("rbac.CheckOption")).Return(&subject, nil) if tc.expectErr == nil { dbMock.EXPECT().GetSystemName(mock.AnythingOfType("*context.valueCtx"), tc.sysID).Return(tc.sysName, nil) } ctx := database.CtxWithDB(context.Background(), dbMock) - ctx = tgstore.CtxWithStore(ctx, tgstore.NewCache()) + ctx = rbac.CtxWithRBAC(ctx, rbacMock) + ctx = tgstore.CtxWithStore(ctx, tgstore.NewCache(dbMock)) ij := &xport.ImportJob{ Type: xport.Format(tc.impType), SystemID: tc.sysID, diff --git a/pkg/users/store.go b/pkg/users/store.go index 161cd31..0129181 100644 --- a/pkg/users/store.go +++ b/pkg/users/store.go @@ -22,15 +22,20 @@ type Store interface { // UpdateUser updates a user's record UpdateUser(ctx context.Context, username string, user UserUpdate) error + + // GetUserByAPIKey gets a user by API key. + GetAPIKey(ctx context.Context, key string) (database.GetAPIKeyRow, error) } type postgresStore struct { cache.Cache[string, *User] + db database.Store } -func NewStore() *postgresStore { +func NewStore(db database.Store) *postgresStore { return &postgresStore{ Cache: cache.New[string, *User](), + db: db, } } @@ -45,7 +50,7 @@ func CtxWithStore(ctx context.Context, s Store) context.Context { func FromCtx(ctx context.Context) Store { s, ok := ctx.Value(StoreCtxKey).(Store) if !ok { - return NewStore() + panic("no users store in context") } return s @@ -61,7 +66,7 @@ type UserUpdate struct { } func (s *postgresStore) UpdateUser(ctx context.Context, username string, user UserUpdate) error { - dbu, err := database.FromCtx(ctx).UpdateUser(ctx, username, user.Email, user.IsAdmin) + dbu, err := s.db.UpdateUser(ctx, username, user.Email, user.IsAdmin) if err != nil { return err } @@ -77,8 +82,7 @@ func (s *postgresStore) GetUser(ctx context.Context, username string) (*User, er return u, nil } - db := database.FromCtx(ctx) - dbu, err := db.GetUserByUsername(ctx, username) + dbu, err := s.db.GetUserByUsername(ctx, username) if err != nil { return nil, err } @@ -95,9 +99,7 @@ func (s *postgresStore) UserPrefs(ctx context.Context, username string, appName return nil, err } - db := database.FromCtx(ctx) - - prefs, err := db.GetAppPrefs(ctx, appName, int(u.ID)) + prefs, err := s.db.GetAppPrefs(ctx, appName, int(u.ID)) if err != nil { return nil, err } @@ -111,7 +113,9 @@ func (s *postgresStore) SetUserPrefs(ctx context.Context, username string, appNa return err } - db := database.FromCtx(ctx) - - return db.SetAppPrefs(ctx, appName, prefs, int(u.ID)) + return s.db.SetAppPrefs(ctx, appName, prefs, int(u.ID)) +} + +func (s *postgresStore) GetAPIKey(ctx context.Context, b64hash string) (database.GetAPIKeyRow, error) { + return s.db.GetAPIKey(ctx, b64hash) } diff --git a/pkg/users/user.go b/pkg/users/user.go index 9a6dd8a..d4904a7 100644 --- a/pkg/users/user.go +++ b/pkg/users/user.go @@ -66,6 +66,10 @@ type User struct { Prefs json.RawMessage } +func (u *User) GetName() string { + return u.Username +} + func (u *User) GetRoles() []string { r := make([]string, 1, 2) diff --git a/sql/postgres/queries/calls.sql b/sql/postgres/queries/calls.sql index 998566a..cceaea1 100644 --- a/sql/postgres/queries/calls.sql +++ b/sql/postgres/queries/calls.sql @@ -156,3 +156,9 @@ CASE WHEN sqlc.narg('tags_not')::TEXT[] IS NOT NULL THEN c.duration > @longer_than ) ELSE TRUE END) ; + +-- name: DeleteCall :exec +DELETE FROM calls WHERE id = @id; + +-- name: GetCallSubmitter :one +SELECT submitter FROM calls WHERE id = @id; diff --git a/sql/postgres/queries/incidents.sql b/sql/postgres/queries/incidents.sql index 314f64d..98b3fc8 100644 --- a/sql/postgres/queries/incidents.sql +++ b/sql/postgres/queries/incidents.sql @@ -175,3 +175,6 @@ RETURNING *; -- name: DeleteIncident :exec DELETE FROM incidents CASCADE WHERE id = @id; + +-- name: GetIncidentOwner :one +SELECT owner FROM incidents WHERE id = @id;