This commit is contained in:
Daniel Ponte 2025-01-18 15:18:40 -05:00
parent aec89b5569
commit c8c3023afa
31 changed files with 973 additions and 182 deletions

View file

@ -1,4 +1,4 @@
dir: '{{ replaceAll .InterfaceDirRelative "internal" "internal_" }}/mocks' dir: '{{.InterfaceDir}}/mocks'
mockname: "{{.InterfaceName}}" mockname: "{{.InterfaceName}}"
outpkg: "mocks" outpkg: "mocks"
filename: "{{.InterfaceName}}.go" filename: "{{.InterfaceName}}.go"
@ -9,3 +9,7 @@ packages:
interfaces: interfaces:
Store: Store:
DBTX: DBTX:
dynatron.me/x/stillbox/pkg/rbac:
config:
interfaces:
RBAC:

View file

@ -15,7 +15,6 @@ import (
"dynatron.me/x/stillbox/internal/common" "dynatron.me/x/stillbox/internal/common"
"dynatron.me/x/stillbox/internal/forms" "dynatron.me/x/stillbox/internal/forms"
"dynatron.me/x/stillbox/pkg/auth"
"dynatron.me/x/stillbox/pkg/calls" "dynatron.me/x/stillbox/pkg/calls"
"dynatron.me/x/stillbox/pkg/sources" "dynatron.me/x/stillbox/pkg/sources"
"dynatron.me/x/stillbox/pkg/users" "dynatron.me/x/stillbox/pkg/users"

View file

@ -14,6 +14,7 @@ import (
"dynatron.me/x/stillbox/pkg/config" "dynatron.me/x/stillbox/pkg/config"
"dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/database"
"dynatron.me/x/stillbox/pkg/notify" "dynatron.me/x/stillbox/pkg/notify"
"dynatron.me/x/stillbox/pkg/rbac"
"dynatron.me/x/stillbox/pkg/sinks" "dynatron.me/x/stillbox/pkg/sinks"
"dynatron.me/x/stillbox/pkg/talkgroups" "dynatron.me/x/stillbox/pkg/talkgroups"
"dynatron.me/x/stillbox/pkg/talkgroups/tgstore" "dynatron.me/x/stillbox/pkg/talkgroups/tgstore"
@ -123,6 +124,8 @@ func New(cfg config.Alerting, tgCache tgstore.Store, opts ...AlertOption) Alerte
// Go is the alerting loop. It does not start a goroutine. // Go is the alerting loop. It does not start a goroutine.
func (as *alerter) Go(ctx context.Context) { func (as *alerter) Go(ctx context.Context) {
ctx = rbac.CtxWithSubject(ctx, &rbac.SystemServiceSubject{Name: "alerter"})
err := as.startBackfill(ctx) err := as.startBackfill(ctx)
if err != nil { if err != nil {
log.Error().Err(err).Msg("backfill") log.Error().Err(err).Msg("backfill")

View file

@ -12,6 +12,7 @@ import (
"dynatron.me/x/stillbox/internal/jsontypes" "dynatron.me/x/stillbox/internal/jsontypes"
"dynatron.me/x/stillbox/internal/trending" "dynatron.me/x/stillbox/internal/trending"
"dynatron.me/x/stillbox/pkg/config" "dynatron.me/x/stillbox/pkg/config"
"dynatron.me/x/stillbox/pkg/database"
"dynatron.me/x/stillbox/pkg/talkgroups" "dynatron.me/x/stillbox/pkg/talkgroups"
"dynatron.me/x/stillbox/pkg/talkgroups/tgstore" "dynatron.me/x/stillbox/pkg/talkgroups/tgstore"
@ -59,8 +60,9 @@ func (s *Simulation) stepClock(t time.Time) {
// Simulate begins the simulation using the DB handle from ctx. It returns final scores. // Simulate begins the simulation using the DB handle from ctx. It returns final scores.
func (s *Simulation) Simulate(ctx context.Context) (trending.Scores[talkgroups.ID], error) { func (s *Simulation) Simulate(ctx context.Context) (trending.Scores[talkgroups.ID], error) {
db := database.FromCtx(ctx)
now := time.Now() now := time.Now()
tgc := tgstore.NewCache() tgc := tgstore.NewCache(db)
s.Enable = true s.Enable = true
s.alerter = New(s.Alerting, tgc, WithClock(&s.clock)).(*alerter) s.alerter = New(s.Alerting, tgc, WithClock(&s.clock)).(*alerter)

View file

@ -7,7 +7,7 @@ import (
"time" "time"
"dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/database"
"dynatron.me/x/stillbox/pkg/users" "dynatron.me/x/stillbox/pkg/rbac"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -16,20 +16,19 @@ import (
type apiKeyAuth interface { type apiKeyAuth interface {
// CheckAPIKey validates the provided key and returns the API owner's users.UserID. // CheckAPIKey validates the provided key and returns the API owner's users.UserID.
// An error is returned if validation fails for any reason. // An error is returned if validation fails for any reason.
CheckAPIKey(ctx context.Context, key string) (*users.UserID, error) CheckAPIKey(ctx context.Context, key string) (rbac.Subject, error)
} }
func (a *Auth) CheckAPIKey(ctx context.Context, key string) (*users.UserID, error) { func (a *Auth) CheckAPIKey(ctx context.Context, key string) (rbac.Subject, error) {
keyUuid, err := uuid.Parse(key) keyUuid, err := uuid.Parse(key)
if err != nil { if err != nil {
log.Error().Str("apikey", key).Msg("cannot parse key") log.Error().Str("apikey", key).Msg("cannot parse key")
return nil, ErrBadRequest return nil, ErrBadRequest
} }
db := database.FromCtx(ctx)
hash := sha256.Sum256([]byte(keyUuid.String())) hash := sha256.Sum256([]byte(keyUuid.String()))
b64hash := base64.StdEncoding.EncodeToString(hash[:]) b64hash := base64.StdEncoding.EncodeToString(hash[:])
apik, err := db.GetAPIKey(ctx, b64hash) apik, err := a.ust.GetAPIKey(ctx, b64hash)
if err != nil { if err != nil {
if database.IsNoRows(err) { if database.IsNoRows(err) {
log.Error().Str("apikey", keyUuid.String()).Msg("no such key") log.Error().Str("apikey", keyUuid.String()).Msg("no such key")
@ -45,7 +44,5 @@ func (a *Auth) CheckAPIKey(ctx context.Context, key string) (*users.UserID, erro
return nil, ErrUnauthorized return nil, ErrUnauthorized
} }
owner := users.UserID(apik.Owner) return a.ust.GetUser(ctx, apik.Username)
return &owner, nil
} }

View file

@ -8,6 +8,8 @@ import (
_ "embed" _ "embed"
"dynatron.me/x/stillbox/pkg/config" "dynatron.me/x/stillbox/pkg/config"
"dynatron.me/x/stillbox/pkg/rbac"
"dynatron.me/x/stillbox/pkg/users"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/go-chi/httprate" "github.com/go-chi/httprate"
"github.com/go-chi/jwtauth/v5" "github.com/go-chi/jwtauth/v5"
@ -22,14 +24,16 @@ type Authenticator interface {
type Auth struct { type Auth struct {
rl *httprate.RateLimiter rl *httprate.RateLimiter
jwt *jwtauth.JWTAuth jwt *jwtauth.JWTAuth
ust users.Store
cfg config.Auth cfg config.Auth
} }
// NewAuthenticator creates a new Authenticator with the provided config. // NewAuthenticator creates a new Authenticator with the provided config.
func NewAuthenticator(cfg config.Auth) *Auth { func NewAuthenticator(cfg config.Auth, ust users.Store) *Auth {
a := &Auth{ a := &Auth{
rl: httprate.NewRateLimiter(5, time.Minute), rl: httprate.NewRateLimiter(5, time.Minute),
cfg: cfg, cfg: cfg,
ust: ust,
} }
a.initJWT() a.initJWT()
@ -51,7 +55,7 @@ var (
// ErrorResponse writes the error and appropriate HTTP response code. // ErrorResponse writes the error and appropriate HTTP response code.
func ErrorResponse(w http.ResponseWriter, err error) { func ErrorResponse(w http.ResponseWriter, err error) {
switch err { switch err {
case ErrLoginFailed, ErrUnauthorized: case ErrLoginFailed, ErrUnauthorized, rbac.ErrBadSubject:
http.Error(w, err.Error(), http.StatusUnauthorized) http.Error(w, err.Error(), http.StatusUnauthorized)
case ErrBadRequest: case ErrBadRequest:
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)

View file

@ -4,7 +4,6 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"strconv"
"strings" "strings"
"time" "time"
@ -153,12 +152,12 @@ func (a *Auth) Login(ctx context.Context, username, password string) (token stri
} }
} }
return a.newToken(found.ID), nil return a.newToken(found.Username), nil
} }
func (a *Auth) newToken(uid int) string { func (a *Auth) newToken(username string) string {
claims := claims{ claims := claims{
"sub": strconv.Itoa(int(uid)), "sub": username,
} }
jwtauth.SetExpiryIn(claims, time.Hour*24*30) // one month jwtauth.SetExpiryIn(claims, time.Hour*24*30) // one month
_, tokenString, err := a.jwt.Encode(claims) _, tokenString, err := a.jwt.Encode(claims)
@ -190,19 +189,14 @@ func (a *Auth) routeRefresh(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Invalid token", http.StatusBadRequest) http.Error(w, "Invalid token", http.StatusBadRequest)
return return
} }
existingSubjectUID := jwToken.Subject()
if existingSubjectUID == "" { existingSubjectUsername := jwToken.Subject()
if existingSubjectUsername == "" {
http.Error(w, "Invalid token", http.StatusBadRequest) http.Error(w, "Invalid token", http.StatusBadRequest)
return return
} }
uid, err := strconv.Atoi(existingSubjectUID)
if err != nil {
log.Error().Str("sub", existingSubjectUID).Err(err).Msg("atoi uid for token refresh")
http.Error(w, "internal server error", http.StatusInternalServerError)
return
}
tok := a.newToken(uid) tok := a.newToken(existingSubjectUsername)
cookie := &http.Cookie{ cookie := &http.Cookie{
Name: CookieName, Name: CookieName,

View file

@ -9,6 +9,9 @@ import (
"dynatron.me/x/stillbox/pkg/calls" "dynatron.me/x/stillbox/pkg/calls"
"dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/database"
"dynatron.me/x/stillbox/pkg/rbac"
"dynatron.me/x/stillbox/pkg/talkgroups/tgstore"
"dynatron.me/x/stillbox/pkg/users"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
@ -16,6 +19,12 @@ import (
) )
type Store interface { type Store interface {
// AddCall adds a call to the database.
AddCall(ctx context.Context, call *calls.Call) error
// DeleteCall deletes a call.
Delete(ctx context.Context, id uuid.UUID) error
// CallAudio returns a CallAudio struct // CallAudio returns a CallAudio struct
CallAudio(ctx context.Context, id uuid.UUID) (*calls.CallAudio, error) CallAudio(ctx context.Context, id uuid.UUID) (*calls.CallAudio, error)
@ -24,10 +33,13 @@ type Store interface {
} }
type store struct { type store struct {
db database.Store
} }
func NewStore() *store { func NewStore(db database.Store) *store {
return new(store) return &store{
db: db,
}
} }
type storeCtxKey string type storeCtxKey string
@ -41,13 +53,77 @@ func CtxWithStore(ctx context.Context, s Store) context.Context {
func FromCtx(ctx context.Context) Store { func FromCtx(ctx context.Context) Store {
s, ok := ctx.Value(StoreCtxKey).(Store) s, ok := ctx.Value(StoreCtxKey).(Store)
if !ok { if !ok {
return NewStore() panic("no call store in context")
} }
return s return s
} }
func toAddCallParams(call *calls.Call) database.AddCallParams {
return database.AddCallParams{
ID: call.ID,
Submitter: call.Submitter.Int32Ptr(),
System: call.System,
Talkgroup: call.Talkgroup,
CallDate: pgtype.Timestamptz{Time: call.DateTime, Valid: true},
AudioName: common.NilIfZero(call.AudioName),
AudioBlob: call.Audio,
AudioType: common.NilIfZero(call.AudioType),
Duration: call.Duration.MsInt32Ptr(),
Frequency: call.Frequency,
Frequencies: call.Frequencies,
Patches: call.Patches,
TGLabel: call.TalkgroupLabel,
TGAlphaTag: call.TGAlphaTag,
TGGroup: call.TalkgroupGroup,
Source: call.Source,
}
}
func (s *store) AddCall(ctx context.Context, call *calls.Call) error {
_, err := rbac.Check(ctx, call, rbac.WithActions(rbac.ActionCreate))
if err != nil {
return err
}
params := toAddCallParams(call)
db := database.FromCtx(ctx)
tgs := tgstore.FromCtx(ctx)
err = db.InTx(ctx, func(tx database.Store) error {
err := tx.AddCall(ctx, params)
if err != nil {
return fmt.Errorf("add call: %w", err)
}
return nil
}, pgx.TxOptions{})
if err != nil && database.IsTGConstraintViolation(err) {
return db.InTx(ctx, func(tx database.Store) error {
_, err := tgs.LearnTG(ctx, call)
if err != nil {
return fmt.Errorf("learn tg: %w", err)
}
err = tx.AddCall(ctx, params)
if err != nil {
return fmt.Errorf("learn tg retry: %w", err)
}
return nil
}, pgx.TxOptions{})
}
return nil
}
func (s *store) CallAudio(ctx context.Context, id uuid.UUID) (*calls.CallAudio, error) { func (s *store) CallAudio(ctx context.Context, id uuid.UUID) (*calls.CallAudio, error) {
_, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceCall), rbac.WithActions(rbac.ActionRead))
if err != nil {
return nil, err
}
db := database.FromCtx(ctx) db := database.FromCtx(ctx)
dbCall, err := db.GetCallAudioByID(ctx, id) dbCall, err := db.GetCallAudioByID(ctx, id)
@ -76,6 +152,11 @@ type CallsParams struct {
} }
func (s *store) Calls(ctx context.Context, p CallsParams) (rows []database.ListCallsPRow, totalCount int, err error) { func (s *store) Calls(ctx context.Context, p CallsParams) (rows []database.ListCallsPRow, totalCount int, err error) {
_, err = rbac.Check(ctx, rbac.UseResource(rbac.ResourceCall), rbac.WithActions(rbac.ActionRead))
if err != nil {
return nil, 0, err
}
db := database.FromCtx(ctx) db := database.FromCtx(ctx)
offset, perPage := p.Pagination.OffsetPerPage(100) offset, perPage := p.Pagination.OffsetPerPage(100)
@ -127,3 +208,28 @@ func (s *store) Calls(ctx context.Context, p CallsParams) (rows []database.ListC
return rows, int(count), err return rows, int(count), err
} }
func (s *store) Delete(ctx context.Context, id uuid.UUID) error {
callOwn, err := s.getCallOwner(ctx, id)
if err != nil {
return err
}
_, err = rbac.Check(ctx, &callOwn, rbac.WithActions(rbac.ActionDelete))
if err != nil {
return err
}
return database.FromCtx(ctx).DeleteCall(ctx, id)
}
func (s *store) getCallOwner(ctx context.Context, id uuid.UUID) (calls.Call, error) {
subInt, err := database.FromCtx(ctx).GetCallSubmitter(ctx, id)
var sub *users.UserID
if subInt != nil {
sub = common.PtrTo(users.UserID(*subInt))
}
return calls.Call{ID: id, Submitter: sub}, err
}

View file

@ -155,6 +155,15 @@ func (q *Queries) CleanupSweptCalls(ctx context.Context, rangeStart pgtype.Times
return result.RowsAffected(), nil return result.RowsAffected(), nil
} }
const deleteCall = `-- name: DeleteCall :exec
DELETE FROM calls WHERE id = $1
`
func (q *Queries) DeleteCall(ctx context.Context, id uuid.UUID) error {
_, err := q.db.Exec(ctx, deleteCall, id)
return err
}
const getCallAudioByID = `-- name: GetCallAudioByID :one const getCallAudioByID = `-- name: GetCallAudioByID :one
SELECT SELECT
c.call_date, c.call_date,
@ -192,6 +201,17 @@ func (q *Queries) GetCallAudioByID(ctx context.Context, id uuid.UUID) (GetCallAu
return i, err return i, err
} }
const getCallSubmitter = `-- name: GetCallSubmitter :one
SELECT submitter FROM calls WHERE id = $1
`
func (q *Queries) GetCallSubmitter(ctx context.Context, id uuid.UUID) (*int32, error) {
row := q.db.QueryRow(ctx, getCallSubmitter, id)
var submitter *int32
err := row.Scan(&submitter)
return submitter, err
}
const getDatabaseSize = `-- name: GetDatabaseSize :one const getDatabaseSize = `-- name: GetDatabaseSize :one
SELECT pg_size_pretty(pg_database_size(current_database())) SELECT pg_size_pretty(pg_database_size(current_database()))
` `

View file

@ -244,6 +244,17 @@ func (q *Queries) GetIncidentCalls(ctx context.Context, id uuid.UUID) ([]GetInci
return items, nil return items, nil
} }
const getIncidentOwner = `-- name: GetIncidentOwner :one
SELECT owner FROM incidents WHERE id = $1
`
func (q *Queries) GetIncidentOwner(ctx context.Context, id uuid.UUID) (int, error) {
row := q.db.QueryRow(ctx, getIncidentOwner, id)
var owner int
err := row.Scan(&owner)
return owner, err
}
const listIncidentsCount = `-- name: ListIncidentsCount :one const listIncidentsCount = `-- name: ListIncidentsCount :one
SELECT COUNT(*) SELECT COUNT(*)
FROM incidents i FROM incidents i

View file

@ -795,6 +795,53 @@ func (_c *Store_DeleteAPIKey_Call) RunAndReturn(run func(context.Context, string
return _c return _c
} }
// DeleteCall provides a mock function with given fields: ctx, id
func (_m *Store) DeleteCall(ctx context.Context, id uuid.UUID) error {
ret := _m.Called(ctx, id)
if len(ret) == 0 {
panic("no return value specified for DeleteCall")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) error); ok {
r0 = rf(ctx, id)
} else {
r0 = ret.Error(0)
}
return r0
}
// Store_DeleteCall_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteCall'
type Store_DeleteCall_Call struct {
*mock.Call
}
// DeleteCall is a helper method to define mock.On call
// - ctx context.Context
// - id uuid.UUID
func (_e *Store_Expecter) DeleteCall(ctx interface{}, id interface{}) *Store_DeleteCall_Call {
return &Store_DeleteCall_Call{Call: _e.mock.On("DeleteCall", ctx, id)}
}
func (_c *Store_DeleteCall_Call) Run(run func(ctx context.Context, id uuid.UUID)) *Store_DeleteCall_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(uuid.UUID))
})
return _c
}
func (_c *Store_DeleteCall_Call) Return(_a0 error) *Store_DeleteCall_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *Store_DeleteCall_Call) RunAndReturn(run func(context.Context, uuid.UUID) error) *Store_DeleteCall_Call {
_c.Call.Return(run)
return _c
}
// DeleteIncident provides a mock function with given fields: ctx, id // DeleteIncident provides a mock function with given fields: ctx, id
func (_m *Store) DeleteIncident(ctx context.Context, id uuid.UUID) error { func (_m *Store) DeleteIncident(ctx context.Context, id uuid.UUID) error {
ret := _m.Called(ctx, id) ret := _m.Called(ctx, id)
@ -1358,6 +1405,65 @@ func (_c *Store_GetCallAudioByID_Call) RunAndReturn(run func(context.Context, uu
return _c return _c
} }
// GetCallSubmitter provides a mock function with given fields: ctx, id
func (_m *Store) GetCallSubmitter(ctx context.Context, id uuid.UUID) (*int32, error) {
ret := _m.Called(ctx, id)
if len(ret) == 0 {
panic("no return value specified for GetCallSubmitter")
}
var r0 *int32
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) (*int32, error)); ok {
return rf(ctx, id)
}
if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) *int32); ok {
r0 = rf(ctx, id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*int32)
}
}
if rf, ok := ret.Get(1).(func(context.Context, uuid.UUID) error); ok {
r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Store_GetCallSubmitter_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCallSubmitter'
type Store_GetCallSubmitter_Call struct {
*mock.Call
}
// GetCallSubmitter is a helper method to define mock.On call
// - ctx context.Context
// - id uuid.UUID
func (_e *Store_Expecter) GetCallSubmitter(ctx interface{}, id interface{}) *Store_GetCallSubmitter_Call {
return &Store_GetCallSubmitter_Call{Call: _e.mock.On("GetCallSubmitter", ctx, id)}
}
func (_c *Store_GetCallSubmitter_Call) Run(run func(ctx context.Context, id uuid.UUID)) *Store_GetCallSubmitter_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(uuid.UUID))
})
return _c
}
func (_c *Store_GetCallSubmitter_Call) Return(_a0 *int32, _a1 error) *Store_GetCallSubmitter_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *Store_GetCallSubmitter_Call) RunAndReturn(run func(context.Context, uuid.UUID) (*int32, error)) *Store_GetCallSubmitter_Call {
_c.Call.Return(run)
return _c
}
// GetDatabaseSize provides a mock function with given fields: ctx // GetDatabaseSize provides a mock function with given fields: ctx
func (_m *Store) GetDatabaseSize(ctx context.Context) (string, error) { func (_m *Store) GetDatabaseSize(ctx context.Context) (string, error) {
ret := _m.Called(ctx) ret := _m.Called(ctx)
@ -1530,6 +1636,63 @@ func (_c *Store_GetIncidentCalls_Call) RunAndReturn(run func(context.Context, uu
return _c return _c
} }
// GetIncidentOwner provides a mock function with given fields: ctx, id
func (_m *Store) GetIncidentOwner(ctx context.Context, id uuid.UUID) (int, error) {
ret := _m.Called(ctx, id)
if len(ret) == 0 {
panic("no return value specified for GetIncidentOwner")
}
var r0 int
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) (int, error)); ok {
return rf(ctx, id)
}
if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) int); ok {
r0 = rf(ctx, id)
} else {
r0 = ret.Get(0).(int)
}
if rf, ok := ret.Get(1).(func(context.Context, uuid.UUID) error); ok {
r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Store_GetIncidentOwner_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetIncidentOwner'
type Store_GetIncidentOwner_Call struct {
*mock.Call
}
// GetIncidentOwner is a helper method to define mock.On call
// - ctx context.Context
// - id uuid.UUID
func (_e *Store_Expecter) GetIncidentOwner(ctx interface{}, id interface{}) *Store_GetIncidentOwner_Call {
return &Store_GetIncidentOwner_Call{Call: _e.mock.On("GetIncidentOwner", ctx, id)}
}
func (_c *Store_GetIncidentOwner_Call) Run(run func(ctx context.Context, id uuid.UUID)) *Store_GetIncidentOwner_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(uuid.UUID))
})
return _c
}
func (_c *Store_GetIncidentOwner_Call) Return(_a0 int, _a1 error) *Store_GetIncidentOwner_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *Store_GetIncidentOwner_Call) RunAndReturn(run func(context.Context, uuid.UUID) (int, error)) *Store_GetIncidentOwner_Call {
_c.Call.Return(run)
return _c
}
// GetShare provides a mock function with given fields: ctx, id // GetShare provides a mock function with given fields: ctx, id
func (_m *Store) GetShare(ctx context.Context, id string) (database.Share, error) { func (_m *Store) GetShare(ctx context.Context, id string) (database.Share, error) {
ret := _m.Called(ctx, id) ret := _m.Called(ctx, id)

View file

@ -13,6 +13,7 @@ import (
"dynatron.me/x/stillbox/internal/isoweek" "dynatron.me/x/stillbox/internal/isoweek"
"dynatron.me/x/stillbox/pkg/config" "dynatron.me/x/stillbox/pkg/config"
"dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/database"
"dynatron.me/x/stillbox/pkg/rbac"
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgtype"
@ -134,6 +135,7 @@ func New(db database.Store, cfg config.Partition) (*partman, error) {
var _ PartitionManager = (*partman)(nil) var _ PartitionManager = (*partman)(nil)
func (pm *partman) Go(ctx context.Context) { func (pm *partman) Go(ctx context.Context) {
ctx = rbac.CtxWithSubject(ctx, &rbac.SystemServiceSubject{Name: "partman"})
tick := time.NewTicker(CheckInterval) tick := time.NewTicker(CheckInterval)
select { select {

View file

@ -23,6 +23,7 @@ type Querier interface {
CreateSystem(ctx context.Context, iD int, name string) error CreateSystem(ctx context.Context, iD int, name string) error
CreateUser(ctx context.Context, arg CreateUserParams) (User, error) CreateUser(ctx context.Context, arg CreateUserParams) (User, error)
DeleteAPIKey(ctx context.Context, apiKey string) error DeleteAPIKey(ctx context.Context, apiKey string) error
DeleteCall(ctx context.Context, id uuid.UUID) error
DeleteIncident(ctx context.Context, id uuid.UUID) error DeleteIncident(ctx context.Context, id uuid.UUID) error
DeleteShare(ctx context.Context, id string) error DeleteShare(ctx context.Context, id string) error
DeleteSystem(ctx context.Context, id int) error DeleteSystem(ctx context.Context, id int) error
@ -32,9 +33,11 @@ type Querier interface {
GetAllTalkgroupTags(ctx context.Context) ([]string, error) GetAllTalkgroupTags(ctx context.Context) ([]string, error)
GetAppPrefs(ctx context.Context, appName string, uid int) ([]byte, error) GetAppPrefs(ctx context.Context, appName string, uid int) ([]byte, error)
GetCallAudioByID(ctx context.Context, id uuid.UUID) (GetCallAudioByIDRow, error) GetCallAudioByID(ctx context.Context, id uuid.UUID) (GetCallAudioByIDRow, error)
GetCallSubmitter(ctx context.Context, id uuid.UUID) (*int32, error)
GetDatabaseSize(ctx context.Context) (string, error) GetDatabaseSize(ctx context.Context) (string, error)
GetIncident(ctx context.Context, id uuid.UUID) (Incident, error) GetIncident(ctx context.Context, id uuid.UUID) (Incident, error)
GetIncidentCalls(ctx context.Context, id uuid.UUID) ([]GetIncidentCallsRow, error) GetIncidentCalls(ctx context.Context, id uuid.UUID) ([]GetIncidentCallsRow, error)
GetIncidentOwner(ctx context.Context, id uuid.UUID) (int, error)
GetShare(ctx context.Context, id string) (Share, error) GetShare(ctx context.Context, id string) (Share, error)
GetSystemName(ctx context.Context, systemID int) (string, error) GetSystemName(ctx context.Context, systemID int) (string, error)
GetTalkgroup(ctx context.Context, systemID int32, tGID int32) (GetTalkgroupRow, error) GetTalkgroup(ctx context.Context, systemID int32, tGID int32) (GetTalkgroupRow, error)

View file

@ -9,6 +9,7 @@ import (
"dynatron.me/x/stillbox/pkg/calls" "dynatron.me/x/stillbox/pkg/calls"
"dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/database"
"dynatron.me/x/stillbox/pkg/incidents" "dynatron.me/x/stillbox/pkg/incidents"
"dynatron.me/x/stillbox/pkg/rbac"
"dynatron.me/x/stillbox/pkg/users" "dynatron.me/x/stillbox/pkg/users"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
@ -72,7 +73,6 @@ func NewStore() Store {
} }
func (s *store) CreateIncident(ctx context.Context, inc incidents.Incident) (*incidents.Incident, error) { func (s *store) CreateIncident(ctx context.Context, inc incidents.Incident) (*incidents.Incident, error) {
// TODO: replace this with a real RBAC check
user, err := users.UserCheck(ctx, new(incidents.Incident), "create") user, err := users.UserCheck(ctx, new(incidents.Incident), "create")
if err != nil { if err != nil {
return nil, err return nil, err
@ -132,6 +132,16 @@ func (s *store) CreateIncident(ctx context.Context, inc incidents.Incident) (*in
} }
func (s *store) AddRemoveIncidentCalls(ctx context.Context, incidentID uuid.UUID, addCallIDs []uuid.UUID, notes []byte, removeCallIDs []uuid.UUID) error { func (s *store) AddRemoveIncidentCalls(ctx context.Context, incidentID uuid.UUID, addCallIDs []uuid.UUID, notes []byte, removeCallIDs []uuid.UUID) error {
inc, err := s.getIncidentOwner(ctx, incidentID)
if err != nil {
return err
}
_, err = rbac.Check(ctx, &inc, rbac.WithActions(rbac.ActionUpdate))
if err != nil {
return err
}
return database.FromCtx(ctx).InTx(ctx, func(db database.Store) error { return database.FromCtx(ctx).InTx(ctx, func(db database.Store) error {
if len(addCallIDs) > 0 { if len(addCallIDs) > 0 {
var noteAr [][]byte var noteAr [][]byte
@ -160,6 +170,10 @@ func (s *store) AddRemoveIncidentCalls(ctx context.Context, incidentID uuid.UUID
} }
func (s *store) Incidents(ctx context.Context, p IncidentsParams) (incs []Incident, totalCount int, err error) { func (s *store) Incidents(ctx context.Context, p IncidentsParams) (incs []Incident, totalCount int, err error) {
_, err = rbac.Check(ctx, new(incidents.Incident), rbac.WithActions(rbac.ActionRead))
if err != nil {
return nil, 0, err
}
db := database.FromCtx(ctx) db := database.FromCtx(ctx)
offset, perPage := p.Pagination.OffsetPerPage(100) offset, perPage := p.Pagination.OffsetPerPage(100)
@ -261,6 +275,11 @@ func fromDBCalls(d []database.GetIncidentCallsRow) []incidents.IncidentCall {
} }
func (s *store) Incident(ctx context.Context, id uuid.UUID) (*incidents.Incident, error) { func (s *store) Incident(ctx context.Context, id uuid.UUID) (*incidents.Incident, error) {
_, err := rbac.Check(ctx, new(incidents.Incident), rbac.WithActions(rbac.ActionRead))
if err != nil {
return nil, err
}
var r incidents.Incident var r incidents.Incident
txErr := database.FromCtx(ctx).InTx(ctx, func(db database.Store) error { txErr := database.FromCtx(ctx).InTx(ctx, func(db database.Store) error {
inc, err := db.GetIncident(ctx, id) inc, err := db.GetIncident(ctx, id)
@ -307,6 +326,16 @@ func (uip UpdateIncidentParams) toDBUIP(id uuid.UUID) database.UpdateIncidentPar
} }
func (s *store) UpdateIncident(ctx context.Context, id uuid.UUID, p UpdateIncidentParams) (*incidents.Incident, error) { func (s *store) UpdateIncident(ctx context.Context, id uuid.UUID, p UpdateIncidentParams) (*incidents.Incident, error) {
ckinc, err := s.getIncidentOwner(ctx, id)
if err != nil {
return nil, err
}
_, err = rbac.Check(ctx, &ckinc, rbac.WithActions(rbac.ActionUpdate))
if err != nil {
return nil, err
}
db := database.FromCtx(ctx) db := database.FromCtx(ctx)
dbInc, err := db.UpdateIncident(ctx, p.toDBUIP(id)) dbInc, err := db.UpdateIncident(ctx, p.toDBUIP(id))
@ -320,9 +349,24 @@ func (s *store) UpdateIncident(ctx context.Context, id uuid.UUID, p UpdateIncide
} }
func (s *store) DeleteIncident(ctx context.Context, id uuid.UUID) error { func (s *store) DeleteIncident(ctx context.Context, id uuid.UUID) error {
inc, err := s.getIncidentOwner(ctx, id)
if err != nil {
return err
}
_, err = rbac.Check(ctx, &inc, rbac.WithActions(rbac.ActionDelete))
if err != nil {
return err
}
return database.FromCtx(ctx).DeleteIncident(ctx, id) return database.FromCtx(ctx).DeleteIncident(ctx, id)
} }
func (s *store) UpdateNotes(ctx context.Context, incidentID uuid.UUID, callID uuid.UUID, notes []byte) error { func (s *store) UpdateNotes(ctx context.Context, incidentID uuid.UUID, callID uuid.UUID, notes []byte) error {
return database.FromCtx(ctx).UpdateCallIncidentNotes(ctx, notes, incidentID, callID) return database.FromCtx(ctx).UpdateCallIncidentNotes(ctx, notes, incidentID, callID)
} }
func (s *store) getIncidentOwner(ctx context.Context, id uuid.UUID) (incidents.Incident, error) {
owner, err := database.FromCtx(ctx).GetIncidentOwner(ctx, id)
return incidents.Incident{ID: id, Owner: users.UserID(owner)}, err
}

View file

@ -6,6 +6,7 @@ import (
"dynatron.me/x/stillbox/pkg/calls" "dynatron.me/x/stillbox/pkg/calls"
"dynatron.me/x/stillbox/pkg/pb" "dynatron.me/x/stillbox/pkg/pb"
"dynatron.me/x/stillbox/pkg/rbac"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@ -38,6 +39,7 @@ func New() *Nexus {
} }
func (n *Nexus) Go(ctx context.Context) { func (n *Nexus) Go(ctx context.Context) {
ctx = rbac.CtxWithSubject(ctx, &rbac.SystemServiceSubject{Name: "nexus"})
for { for {
select { select {
case call, ok := <-n.callCh: case call, ok := <-n.callCh:

113
pkg/rbac/mocks/RBAC.go Normal file
View file

@ -0,0 +1,113 @@
// Code generated by mockery v2.47.0. DO NOT EDIT.
package mocks
import (
context "context"
rbac "dynatron.me/x/stillbox/pkg/rbac"
mock "github.com/stretchr/testify/mock"
restrict "github.com/el-mike/restrict/v2"
)
// RBAC is an autogenerated mock type for the RBAC type
type RBAC struct {
mock.Mock
}
type RBAC_Expecter struct {
mock *mock.Mock
}
func (_m *RBAC) EXPECT() *RBAC_Expecter {
return &RBAC_Expecter{mock: &_m.Mock}
}
// Check provides a mock function with given fields: ctx, res, opts
func (_m *RBAC) Check(ctx context.Context, res restrict.Resource, opts ...rbac.CheckOption) (rbac.Subject, error) {
_va := make([]interface{}, len(opts))
for _i := range opts {
_va[_i] = opts[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, res)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for Check")
}
var r0 rbac.Subject
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, restrict.Resource, ...rbac.CheckOption) (rbac.Subject, error)); ok {
return rf(ctx, res, opts...)
}
if rf, ok := ret.Get(0).(func(context.Context, restrict.Resource, ...rbac.CheckOption) rbac.Subject); ok {
r0 = rf(ctx, res, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(rbac.Subject)
}
}
if rf, ok := ret.Get(1).(func(context.Context, restrict.Resource, ...rbac.CheckOption) error); ok {
r1 = rf(ctx, res, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// RBAC_Check_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Check'
type RBAC_Check_Call struct {
*mock.Call
}
// Check is a helper method to define mock.On call
// - ctx context.Context
// - res restrict.Resource
// - opts ...rbac.CheckOption
func (_e *RBAC_Expecter) Check(ctx interface{}, res interface{}, opts ...interface{}) *RBAC_Check_Call {
return &RBAC_Check_Call{Call: _e.mock.On("Check",
append([]interface{}{ctx, res}, opts...)...)}
}
func (_c *RBAC_Check_Call) Run(run func(ctx context.Context, res restrict.Resource, opts ...rbac.CheckOption)) *RBAC_Check_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]rbac.CheckOption, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(rbac.CheckOption)
}
}
run(args[0].(context.Context), args[1].(restrict.Resource), variadicArgs...)
})
return _c
}
func (_c *RBAC_Check_Call) Return(_a0 rbac.Subject, _a1 error) *RBAC_Check_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *RBAC_Check_Call) RunAndReturn(run func(context.Context, restrict.Resource, ...rbac.CheckOption) (rbac.Subject, error)) *RBAC_Check_Call {
_c.Call.Return(run)
return _c
}
// NewRBAC creates a new instance of RBAC. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewRBAC(t interface {
mock.TestingT
Cleanup(func())
}) *RBAC {
mock := &RBAC{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View file

@ -3,6 +3,8 @@ package rbac
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"reflect"
"github.com/el-mike/restrict/v2" "github.com/el-mike/restrict/v2"
"github.com/el-mike/restrict/v2/adapters" "github.com/el-mike/restrict/v2/adapters"
@ -12,6 +14,7 @@ const (
RoleUser = "User" RoleUser = "User"
RoleSubmitter = "Submitter" RoleSubmitter = "Submitter"
RoleAdmin = "Admin" RoleAdmin = "Admin"
RoleSystem = "System"
RolePublic = "Public" RolePublic = "Public"
RoleShareGuest = "ShareGuest" RoleShareGuest = "ShareGuest"
@ -20,6 +23,7 @@ const (
ResourceTalkgroup = "Talkgroup" ResourceTalkgroup = "Talkgroup"
ResourceAlert = "Alert" ResourceAlert = "Alert"
ResourceShare = "Share" ResourceShare = "Share"
ResourceAPIKey = "APIKey"
ActionRead = "read" ActionRead = "read"
ActionCreate = "create" ActionCreate = "create"
@ -29,6 +33,9 @@ const (
PresetUpdateOwn = "updateOwn" PresetUpdateOwn = "updateOwn"
PresetDeleteOwn = "deleteOwn" PresetDeleteOwn = "deleteOwn"
PresetReadShared = "readShared" PresetReadShared = "readShared"
PresetUpdateSubmitter = "updateSubmitter"
PresetDeleteSubmitter = "deleteSubmitter"
) )
var ( var (
@ -43,6 +50,14 @@ func CtxWithSubject(ctx context.Context, sub Subject) context.Context {
return context.WithValue(ctx, SubjectCtxKey, sub) return context.WithValue(ctx, SubjectCtxKey, sub)
} }
func ErrAccessDenied(err error) *restrict.AccessDeniedError {
if accessErr, ok := err.(*restrict.AccessDeniedError); ok {
return accessErr
}
return nil
}
func SubjectFrom(ctx context.Context) Subject { func SubjectFrom(ctx context.Context) Subject {
sub, ok := ctx.Value(SubjectCtxKey).(Subject) sub, ok := ctx.Value(SubjectCtxKey).(Subject)
if ok { if ok {
@ -87,8 +102,8 @@ var policy = &restrict.PolicyDefinition{
ResourceCall: { ResourceCall: {
&restrict.Permission{Action: ActionRead}, &restrict.Permission{Action: ActionRead},
&restrict.Permission{Action: ActionCreate}, &restrict.Permission{Action: ActionCreate},
&restrict.Permission{Preset: PresetUpdateOwn}, &restrict.Permission{Preset: PresetUpdateSubmitter},
&restrict.Permission{Preset: PresetDeleteOwn}, &restrict.Permission{Preset: PresetDeleteSubmitter},
}, },
ResourceTalkgroup: { ResourceTalkgroup: {
&restrict.Permission{Action: ActionRead}, &restrict.Permission{Action: ActionRead},
@ -107,6 +122,11 @@ var policy = &restrict.PolicyDefinition{
ResourceCall: { ResourceCall: {
&restrict.Permission{Action: ActionCreate}, &restrict.Permission{Action: ActionCreate},
}, },
ResourceTalkgroup: {
// for learning TGs
&restrict.Permission{Action: ActionCreate},
&restrict.Permission{Action: ActionUpdate},
},
}, },
}, },
RoleShareGuest: { RoleShareGuest: {
@ -141,6 +161,9 @@ var policy = &restrict.PolicyDefinition{
}, },
}, },
}, },
RoleSystem: {
Parents: []string{RoleSystem},
},
RolePublic: { RolePublic: {
/* /*
Grants: restrict.GrantsMap{ Grants: restrict.GrantsMap{
@ -184,6 +207,38 @@ var policy = &restrict.PolicyDefinition{
}, },
}, },
}, },
PresetUpdateSubmitter: &restrict.Permission{
Action: ActionUpdate,
Conditions: restrict.Conditions{
&SubmitterEqualCondition{
ID: "isSubmitter",
Left: &restrict.ValueDescriptor{
Source: restrict.ResourceField,
Field: "Submitter",
},
Right: &restrict.ValueDescriptor{
Source: restrict.SubjectField,
Field: "ID",
},
},
},
},
PresetDeleteSubmitter: &restrict.Permission{
Action: ActionDelete,
Conditions: restrict.Conditions{
&SubmitterEqualCondition{
ID: "isSubmitter",
Left: &restrict.ValueDescriptor{
Source: restrict.ResourceField,
Field: "Submitter",
},
Right: &restrict.ValueDescriptor{
Source: restrict.SubjectField,
Field: "ID",
},
},
},
},
PresetReadShared: &restrict.Permission{ PresetReadShared: &restrict.Permission{
Action: ActionRead, Action: ActionRead,
Conditions: restrict.Conditions{ Conditions: restrict.Conditions{
@ -208,15 +263,15 @@ type checkOptions struct {
context restrict.Context context restrict.Context
} }
type checkOption func(*checkOptions) type CheckOption func(*checkOptions)
func WithActions(actions ...string) checkOption { func WithActions(actions ...string) CheckOption {
return func(o *checkOptions) { return func(o *checkOptions) {
o.actions = append(o.actions, actions...) o.actions = append(o.actions, actions...)
} }
} }
func WithContext(ctx restrict.Context) checkOption { func WithContext(ctx restrict.Context) CheckOption {
return func(o *checkOptions) { return func(o *checkOptions) {
o.context = ctx o.context = ctx
} }
@ -228,6 +283,7 @@ func UseResource(rsc string) restrict.Resource {
type Subject interface { type Subject interface {
restrict.Subject restrict.Subject
GetName() string
} }
type Resource interface { type Resource interface {
@ -235,7 +291,7 @@ type Resource interface {
} }
type RBAC interface { type RBAC interface {
Check(ctx context.Context, res restrict.Resource, opts ...checkOption) (Subject, error) Check(ctx context.Context, res restrict.Resource, opts ...CheckOption) (Subject, error)
} }
type rbac struct { type rbac struct {
@ -257,7 +313,12 @@ func New() (*rbac, error) {
}, nil }, nil
} }
func (r *rbac) Check(ctx context.Context, res restrict.Resource, opts ...checkOption) (Subject, error) { // Check is a convenience function to pull the RBAC instance out of ctx and Check.
func Check(ctx context.Context, res restrict.Resource, opts ...CheckOption) (Subject, error) {
return FromCtx(ctx).Check(ctx, res, opts...)
}
func (r *rbac) Check(ctx context.Context, res restrict.Resource, opts ...CheckOption) (Subject, error) {
sub := SubjectFrom(ctx) sub := SubjectFrom(ctx)
o := checkOptions{} o := checkOptions{}
@ -279,6 +340,10 @@ type ShareLinkGuest struct {
ShareID string ShareID string
} }
func (s *ShareLinkGuest) GetName() string {
return "SHARE:" + s.ShareID
}
func (s *ShareLinkGuest) GetRoles() []string { func (s *ShareLinkGuest) GetRoles() []string {
return []string{RoleShareGuest} return []string{RoleShareGuest}
} }
@ -287,6 +352,70 @@ type PublicSubject struct {
RemoteAddr string RemoteAddr string
} }
func (s *PublicSubject) GetName() string {
return "PUBLIC:" + s.RemoteAddr
}
func (s *PublicSubject) GetRoles() []string { func (s *PublicSubject) GetRoles() []string {
return []string{RolePublic} return []string{RolePublic}
} }
type SystemServiceSubject struct {
Name string
}
func (s *SystemServiceSubject) GetName() string {
return "SYSTEM:" + s.Name
}
func (s *SystemServiceSubject) GetRoles() []string {
return []string{RoleSystem}
}
const (
SubmitterEqualConditionType = "SUBMITTER_EQUAL"
)
type SubmitterEqualCondition struct {
ID string `json:"name,omitempty" yaml:"name,omitempty"`
Left *restrict.ValueDescriptor `json:"left" yaml:"left"`
Right *restrict.ValueDescriptor `json:"right" yaml:"right"`
}
func (s *SubmitterEqualCondition) Type() string {
return SubmitterEqualConditionType
}
func (c *SubmitterEqualCondition) Check(r *restrict.AccessRequest) error {
left, err := c.Left.GetValue(r)
if err != nil {
return err
}
right, err := c.Right.GetValue(r)
if err != nil {
return err
}
lVal := reflect.ValueOf(left)
rVal := reflect.ValueOf(right)
// deref Left. this is the difference between us and EqualCondition
for lVal.Kind() == reflect.Pointer {
lVal = lVal.Elem()
}
if !lVal.IsValid() || !reflect.DeepEqual(rVal.Interface(), lVal.Interface()) {
return restrict.NewConditionNotSatisfiedError(c, r, fmt.Errorf("values \"%v\" and \"%v\" are not equal", left, right))
}
return nil
}
func SubmitterEqualConditionFactory() restrict.Condition {
return new(SubmitterEqualCondition)
}
func init() {
restrict.RegisterConditionFactory(SubmitterEqualConditionType, SubmitterEqualConditionFactory)
}

197
pkg/rbac/rbac_test.go Normal file
View file

@ -0,0 +1,197 @@
package rbac_test
import (
"context"
"errors"
"fmt"
"testing"
"dynatron.me/x/stillbox/internal/common"
"dynatron.me/x/stillbox/pkg/calls"
"dynatron.me/x/stillbox/pkg/incidents"
"dynatron.me/x/stillbox/pkg/rbac"
"dynatron.me/x/stillbox/pkg/talkgroups"
"dynatron.me/x/stillbox/pkg/users"
"github.com/el-mike/restrict/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRBAC(t *testing.T) {
tests := []struct {
name string
subject rbac.Subject
resource rbac.Resource
action string
expectErr error
}{
{
name: "admin update talkgroup",
subject: &users.User{
ID: 2,
IsAdmin: true,
},
resource: &talkgroups.Talkgroup{},
action: rbac.ActionUpdate,
expectErr: nil,
},
{
name: "admin update incident",
subject: &users.User{
ID: 2,
IsAdmin: true,
},
resource: &incidents.Incident{
Name: "test incident",
Owner: 4,
},
action: rbac.ActionUpdate,
expectErr: nil,
},
{
name: "user update incident not owner",
subject: &users.User{
ID: 2,
},
resource: &incidents.Incident{
Name: "test incident",
Owner: 4,
},
action: rbac.ActionUpdate,
expectErr: errors.New(`access denied for Action: "update" on Resource: "Incident"`),
},
{
name: "user update incident owner",
subject: &users.User{
ID: 2,
},
resource: &incidents.Incident{
Name: "test incident",
Owner: 2,
},
action: rbac.ActionUpdate,
expectErr: nil,
},
{
name: "user delete incident not owner",
subject: &users.User{
ID: 2,
},
resource: &incidents.Incident{
Name: "test incident",
Owner: 6,
},
action: rbac.ActionDelete,
expectErr: errors.New(`access denied for Action: "delete" on Resource: "Incident"`),
},
{
name: "admin update call",
subject: &users.User{
ID: 2,
IsAdmin: true,
},
resource: &calls.Call{
Submitter: common.PtrTo(users.UserID(4)),
},
action: rbac.ActionUpdate,
expectErr: nil,
},
{
name: "user update call not owner",
subject: &users.User{
ID: 2,
},
resource: &calls.Call{
Submitter: common.PtrTo(users.UserID(4)),
},
action: rbac.ActionUpdate,
expectErr: errors.New(`access denied for Action: "update" on Resource: "Call"`),
},
{
name: "user update call owner",
subject: &users.User{
ID: 2,
},
resource: &calls.Call{
Submitter: common.PtrTo(users.UserID(2)),
},
action: rbac.ActionUpdate,
expectErr: nil,
},
{
name: "user update call nil submitter",
subject: &users.User{
ID: 2,
},
resource: &calls.Call{
Submitter: nil,
},
action: rbac.ActionUpdate,
expectErr: errors.New(`access denied for Action: "update" on Resource: "Call"`),
},
{
name: "user delete call not owner",
subject: &users.User{
ID: 2,
},
resource: &calls.Call{
Submitter: common.PtrTo(users.UserID(6)),
},
action: rbac.ActionDelete,
expectErr: errors.New(`access denied for Action: "delete" on Resource: "Call"`),
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctx := rbac.CtxWithSubject(context.Background(), tc.subject)
rb, err := rbac.New()
require.NoError(t, err)
sub, err := rb.Check(ctx, tc.resource, rbac.WithActions(tc.action))
if tc.expectErr != nil {
assert.Equal(t, tc.expectErr.Error(), err.Error())
} else {
if !assert.NoError(t, err) {
accErr(err)
}
}
assert.Equal(t, tc.subject, sub)
})
}
}
func accErr(err error) {
if accessError, ok := err.(*restrict.AccessDeniedError); ok {
// Error() implementation. Returns a message in a form: "access denied for Action/s: ... on Resource: ..."
fmt.Println(accessError)
// Returns an AccessRequest that failed.
fmt.Println(accessError.Request)
// Returns first reason for the denied access.
// Especially helpful in fail-early mode, where there will only be one Reason.
fmt.Println(accessError.FirstReason())
// Reasons property will hold all errors that caused the access to be denied.
for _, permissionErr := range accessError.Reasons {
fmt.Println(permissionErr)
fmt.Println(permissionErr.Action)
fmt.Println(permissionErr.RoleName)
fmt.Println(permissionErr.ResourceName)
// Returns first ConditionNotSatisfied error for given PermissionError, if any was returned for given PermissionError.
// Especially helpful in fail-early mode, where there will only be one failed Condition.
fmt.Println(permissionErr.FirstConditionError())
// ConditionErrors property will hold all ConditionNotSatisfied errors.
for _, conditionErr := range permissionErr.ConditionErrors {
fmt.Println(conditionErr)
fmt.Println(conditionErr.Reason)
// Every ConditionNotSatisfied contains an instance of Condition that returned it,
// so it can be tested using type assertion to get more details about failed Condition.
if emptyCondition, ok := conditionErr.Condition.(*restrict.EmptyCondition); ok {
fmt.Println(emptyCondition.ID)
}
}
}
}
}

View file

@ -84,6 +84,14 @@ func unauthErrText(err error) render.Renderer {
} }
} }
func forbiddenErrText(err error) render.Renderer {
return &errResponse{
Err: err,
Code: http.StatusForbidden,
Error: "Forbidden: " + err.Error(),
}
}
func constraintErrText(err error) render.Renderer { func constraintErrText(err error) render.Renderer {
return &errResponse{ return &errResponse{
Err: err, Err: err,
@ -147,6 +155,10 @@ func autoError(err error) render.Renderer {
} }
} }
if rbac.ErrAccessDenied(err) != nil {
return forbiddenErrText(err)
}
return internalError(err) return internalError(err)
} }

View file

@ -7,5 +7,6 @@ import (
) )
func (s *Server) Ingest(ctx context.Context, call *calls.Call) error { func (s *Server) Ingest(ctx context.Context, call *calls.Call) error {
return s.sinks.EmitCall(context.Background(), call) ctx = context.WithoutCancel(ctx)
return s.sinks.EmitCall(ctx, call)
} }

View file

@ -67,14 +67,16 @@ func New(ctx context.Context, cfg *config.Configuration) (*Server, error) {
r := chi.NewRouter() r := chi.NewRouter()
authenticator := auth.NewAuthenticator(cfg.Auth) ust := users.NewStore(db)
authenticator := auth.NewAuthenticator(cfg.Auth, ust)
notifier, err := notify.New(cfg.Notify) notifier, err := notify.New(cfg.Notify)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tgCache := tgstore.NewCache() tgCache := tgstore.NewCache(db)
api := rest.New(cfg.BaseURL.URL()) api := rest.New(cfg.BaseURL.URL())
rbacSvc, err := rbac.New() rbacSvc, err := rbac.New()
@ -95,8 +97,8 @@ func New(ctx context.Context, cfg *config.Configuration) (*Server, error) {
sinks: sinks.NewSinkManager(), sinks: sinks.NewSinkManager(),
rest: api, rest: api,
share: share.NewService(), share: share.NewService(),
users: users.NewStore(), users: ust,
calls: callstore.NewStore(), calls: callstore.NewStore(db),
incidents: incstore.NewStore(), incidents: incstore.NewStore(),
rbac: rbacSvc, rbac: rbacSvc,
} }
@ -113,7 +115,7 @@ func New(ctx context.Context, cfg *config.Configuration) (*Server, error) {
} }
} }
srv.sinks.Register("database", sinks.NewDatabaseSink(srv.db, tgCache), true) srv.sinks.Register("database", sinks.NewDatabaseSink(db, tgCache), true)
srv.sinks.Register("nexus", sinks.NewNexusSink(srv.nex), false) srv.sinks.Register("nexus", sinks.NewNexusSink(srv.nex), false)
if srv.alerter.Enabled() { if srv.alerter.Enabled() {

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"time" "time"
"dynatron.me/x/stillbox/pkg/rbac"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@ -26,6 +27,8 @@ func (s *service) ShareStore() Store {
} }
func (s *service) Go(ctx context.Context) { func (s *service) Go(ctx context.Context) {
ctx = rbac.CtxWithSubject(ctx, &rbac.SystemServiceSubject{Name: "share"})
tick := time.NewTicker(PruneInterval) tick := time.NewTicker(PruneInterval)
for { for {

View file

@ -2,15 +2,12 @@ package sinks
import ( import (
"context" "context"
"fmt"
"dynatron.me/x/stillbox/internal/common"
"dynatron.me/x/stillbox/pkg/calls" "dynatron.me/x/stillbox/pkg/calls"
"dynatron.me/x/stillbox/pkg/calls/callstore"
"dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/database"
"dynatron.me/x/stillbox/pkg/talkgroups/tgstore" "dynatron.me/x/stillbox/pkg/talkgroups/tgstore"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@ -29,59 +26,9 @@ func (s *DatabaseSink) Call(ctx context.Context, call *calls.Call) error {
return nil return nil
} }
params := s.toAddCallParams(call) return callstore.FromCtx(ctx).AddCall(ctx, call)
err := s.db.InTx(ctx, func(tx database.Store) error {
err := tx.AddCall(ctx, params)
if err != nil {
return fmt.Errorf("add call: %w", err)
}
log.Debug().Str("id", call.ID.String()).Int("system", call.System).Int("tgid", call.Talkgroup).Msg("stored")
return nil
}, pgx.TxOptions{})
if err != nil && database.IsTGConstraintViolation(err) {
return s.db.InTx(ctx, func(tx database.Store) error {
_, err := s.tgs.LearnTG(ctx, call)
if err != nil {
return fmt.Errorf("learn tg: %w", err)
}
err = tx.AddCall(ctx, params)
if err != nil {
return fmt.Errorf("learn tg retry: %w", err)
}
return nil
}, pgx.TxOptions{})
}
return err
} }
func (s *DatabaseSink) SinkType() string { func (s *DatabaseSink) SinkType() string {
return "database" return "database"
} }
func (s *DatabaseSink) toAddCallParams(call *calls.Call) database.AddCallParams {
return database.AddCallParams{
ID: call.ID,
Submitter: call.Submitter.Int32Ptr(),
System: call.System,
Talkgroup: call.Talkgroup,
CallDate: pgtype.Timestamptz{Time: call.DateTime, Valid: true},
AudioName: common.NilIfZero(call.AudioName),
AudioBlob: call.Audio,
AudioType: common.NilIfZero(call.AudioType),
Duration: call.Duration.MsInt32Ptr(),
Frequency: call.Frequency,
Frequencies: call.Frequencies,
Patches: call.Patches,
TGLabel: call.TalkgroupLabel,
TGAlphaTag: call.TGAlphaTag,
TGGroup: call.TalkgroupGroup,
Source: call.Source,
}
}

View file

@ -9,6 +9,7 @@ import (
"dynatron.me/x/stillbox/internal/forms" "dynatron.me/x/stillbox/internal/forms"
"dynatron.me/x/stillbox/pkg/auth" "dynatron.me/x/stillbox/pkg/auth"
"dynatron.me/x/stillbox/pkg/calls" "dynatron.me/x/stillbox/pkg/calls"
"dynatron.me/x/stillbox/pkg/rbac"
"dynatron.me/x/stillbox/pkg/users" "dynatron.me/x/stillbox/pkg/users"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -99,7 +100,13 @@ func (h *RdioHTTP) routeCallUpload(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
submitter, err := h.auth.CheckAPIKey(ctx, r.Form.Get("key")) submitterSub, err := h.auth.CheckAPIKey(ctx, r.Form.Get("key"))
if err != nil {
auth.ErrorResponse(w, err)
return
}
submitter, err := users.FromSubject(submitterSub)
if err != nil { if err != nil {
auth.ErrorResponse(w, err) auth.ErrorResponse(w, err)
return return
@ -118,20 +125,22 @@ func (h *RdioHTTP) routeCallUpload(w http.ResponseWriter, r *http.Request) {
return return
} }
call, err := cur.ToCall(*submitter) call, err := cur.ToCall(submitter.ID)
if err != nil { if err != nil {
log.Error().Err(err).Msg("toCall failed") log.Error().Err(err).Msg("toCall failed")
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} }
err = h.ing.Ingest(ctx, call) err = h.ing.Ingest(rbac.CtxWithSubject(ctx, submitterSub), call)
if err != nil { if err != nil {
log.Error().Err(err).Msg("ingest failed") if rbac.ErrAccessDenied(err) != nil {
http.Error(w, "Call ingest failed.", http.StatusInternalServerError) log.Error().Err(err).Msg("ingest failed")
http.Error(w, "Call ingest failed.", http.StatusForbidden)
}
return return
} }
log.Info().Int("system", cur.System).Int("tgid", cur.Talkgroup).Str("duration", call.Duration.Duration().String()).Msg("ingested") log.Info().Int("system", cur.System).Int("tgid", cur.Talkgroup).Str("duration", call.Duration.Duration().String()).Str("sub", submitter.Username).Msg("ingested")
written, err := w.Write([]byte("Call imported successfully.")) written, err := w.Write([]byte("Call imported successfully."))
if err != nil { if err != nil {

View file

@ -1,50 +0,0 @@
package store
import (
"context"
"dynatron.me/x/stillbox/pkg/talkgroups/tgstore"
"dynatron.me/x/stillbox/pkg/users"
)
type Store interface {
TG() tgstore.Store
User() users.Store
}
type store struct {
tg tgstore.Store
user users.Store
}
func (s *store) TG() tgstore.Store {
return s.tg
}
func (s *store) User() users.Store {
return s.user
}
func New() Store {
return &store{
tg: tgstore.NewCache(),
user: users.NewStore(),
}
}
type storeCtxKey string
const StoreCtxKey storeCtxKey = "store"
func CtxWithStore(ctx context.Context, s Store) context.Context {
return context.WithValue(ctx, StoreCtxKey, s)
}
func FromCtx(ctx context.Context) Store {
s, ok := ctx.Value(StoreCtxKey).(Store)
if !ok {
return New()
}
return s
}

View file

@ -11,6 +11,7 @@ import (
"dynatron.me/x/stillbox/pkg/calls" "dynatron.me/x/stillbox/pkg/calls"
"dynatron.me/x/stillbox/pkg/config" "dynatron.me/x/stillbox/pkg/config"
"dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/database"
"dynatron.me/x/stillbox/pkg/rbac"
tgsp "dynatron.me/x/stillbox/pkg/talkgroups" tgsp "dynatron.me/x/stillbox/pkg/talkgroups"
"dynatron.me/x/stillbox/pkg/users" "dynatron.me/x/stillbox/pkg/users"
@ -176,7 +177,7 @@ func CtxWithStore(ctx context.Context, s Store) context.Context {
func FromCtx(ctx context.Context) Store { func FromCtx(ctx context.Context) Store {
s, ok := ctx.Value(StoreCtxKey).(Store) s, ok := ctx.Value(StoreCtxKey).(Store)
if !ok { if !ok {
return NewCache() panic("no tg store in context")
} }
return s return s
@ -201,19 +202,23 @@ type cache struct {
sync.RWMutex sync.RWMutex
tgs tgMap tgs tgMap
systems map[int]string systems map[int]string
db database.Store
} }
// NewCache returns a new cache Store. // NewCache returns a new cache Store.
func NewCache() *cache { func NewCache(db database.Store) *cache {
tgc := &cache{ tgc := &cache{
tgs: make(tgMap), tgs: make(tgMap),
systems: make(map[int]string), systems: make(map[int]string),
db: db,
} }
return tgc return tgc
} }
func (t *cache) Hint(ctx context.Context, tgs []tgsp.ID) error { func (t *cache) Hint(ctx context.Context, tgs []tgsp.ID) error {
// since this doesn't actually return data, we can skip rbac checks.
// This is only called by system services anyway.
if len(tgs) < 1 { if len(tgs) < 1 {
return nil return nil
} }
@ -322,11 +327,15 @@ func addToRowList[T rowType](t *cache, tgRecords []T) []*tgsp.Talkgroup {
} }
func (t *cache) TGs(ctx context.Context, tgs tgsp.IDs, opts ...Option) ([]*tgsp.Talkgroup, error) { func (t *cache) TGs(ctx context.Context, tgs tgsp.IDs, opts ...Option) ([]*tgsp.Talkgroup, error) {
db := database.FromCtx(ctx) _, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionRead))
if err != nil {
return nil, err
}
db := t.db
r := make([]*tgsp.Talkgroup, 0, len(tgs)) r := make([]*tgsp.Talkgroup, 0, len(tgs))
opt := sOpt(opts) opt := sOpt(opts)
var err error
if tgs != nil { if tgs != nil {
toGet := make(tgsp.IDs, 0, len(tgs)) toGet := make(tgsp.IDs, 0, len(tgs))
for _, id := range tgs { for _, id := range tgs {
@ -394,7 +403,8 @@ func (t *cache) TGs(ctx context.Context, tgs tgsp.IDs, opts ...Option) ([]*tgsp.
} }
func (t *cache) Load(ctx context.Context, tgs database.TGTuples) error { func (t *cache) Load(ctx context.Context, tgs database.TGTuples) error {
tgRecords, err := database.FromCtx(ctx).GetTalkgroupsWithLearnedBySysTGID(ctx, tgs) // No need for RBAC checks since this merely primes the cache and returns nothing.
tgRecords, err := t.db.GetTalkgroupsWithLearnedBySysTGID(ctx, tgs)
if err != nil { if err != nil {
return err return err
} }
@ -420,9 +430,13 @@ func (t *cache) Weight(ctx context.Context, id tgsp.ID, tm time.Time) float64 {
} }
func (t *cache) SystemTGs(ctx context.Context, systemID int, opts ...Option) ([]*tgsp.Talkgroup, error) { func (t *cache) SystemTGs(ctx context.Context, systemID int, opts ...Option) ([]*tgsp.Talkgroup, error) {
db := database.FromCtx(ctx) _, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionRead))
if err != nil {
return nil, err
}
db := t.db
opt := sOpt(opts) opt := sOpt(opts)
var err error
if opt.pagination != nil { if opt.pagination != nil {
sortDir, err := opt.pagination.SortDir() sortDir, err := opt.pagination.SortDir()
if err != nil { if err != nil {
@ -472,13 +486,18 @@ func (t *cache) SystemTGs(ctx context.Context, systemID int, opts ...Option) ([]
} }
func (t *cache) TG(ctx context.Context, tg tgsp.ID) (*tgsp.Talkgroup, error) { func (t *cache) TG(ctx context.Context, tg tgsp.ID) (*tgsp.Talkgroup, error) {
_, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionRead))
if err != nil {
return nil, err
}
rec, has := t.get(tg) rec, has := t.get(tg)
if has { if has {
return rec, nil return rec, nil
} }
record, err := database.FromCtx(ctx).GetTalkgroupWithLearned(ctx, int32(tg.System), int32(tg.Talkgroup)) record, err := t.db.GetTalkgroupWithLearned(ctx, int32(tg.System), int32(tg.Talkgroup))
switch err { switch err {
case nil: case nil:
case pgx.ErrNoRows: case pgx.ErrNoRows:
@ -494,12 +513,17 @@ func (t *cache) TG(ctx context.Context, tg tgsp.ID) (*tgsp.Talkgroup, error) {
} }
func (t *cache) SystemName(ctx context.Context, id int) (name string, has bool) { func (t *cache) SystemName(ctx context.Context, id int) (name string, has bool) {
_, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionRead))
if err != nil {
return "", false
}
t.RLock() t.RLock()
n, has := t.systems[id] n, has := t.systems[id]
t.RUnlock() t.RUnlock()
if !has { if !has {
sys, err := database.FromCtx(ctx).GetSystemName(ctx, id) sys, err := t.db.GetSystemName(ctx, id)
if err != nil { if err != nil {
return "", false return "", false
} }
@ -525,7 +549,7 @@ func (t *cache) UpdateTG(ctx context.Context, input database.UpdateTalkgroupPara
return nil, ErrNoSuchSystem return nil, ErrNoSuchSystem
} }
db := database.FromCtx(ctx) db := t.db
var tg database.Talkgroup var tg database.Talkgroup
err = db.InTx(ctx, func(db database.Store) error { err = db.InTx(ctx, func(db database.Store) error {
var oerr error var oerr error
@ -563,12 +587,17 @@ func (t *cache) UpdateTG(ctx context.Context, input database.UpdateTalkgroupPara
} }
func (t *cache) DeleteSystem(ctx context.Context, id int) error { func (t *cache) DeleteSystem(ctx context.Context, id int) error {
_, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionDelete))
if err != nil {
return err
}
t.Lock() t.Lock()
defer t.Unlock() defer t.Unlock()
t.invalidate() t.invalidate()
err := database.FromCtx(ctx).DeleteSystem(ctx, id) err = t.db.DeleteSystem(ctx, id)
switch { switch {
case err == nil: case err == nil:
return nil return nil
@ -580,6 +609,11 @@ func (t *cache) DeleteSystem(ctx context.Context, id int) error {
} }
func (t *cache) DeleteTG(ctx context.Context, id tgsp.ID) error { func (t *cache) DeleteTG(ctx context.Context, id tgsp.ID) error {
_, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionDelete))
if err != nil {
return err
}
t.Lock() t.Lock()
defer t.Unlock() defer t.Unlock()
@ -588,7 +622,7 @@ func (t *cache) DeleteTG(ctx context.Context, id tgsp.ID) error {
return err return err
} }
err = database.FromCtx(ctx).InTx(ctx, func(db database.Store) error { err = t.db.InTx(ctx, func(db database.Store) error {
err := db.StoreDeletedTGVersion(ctx, common.PtrTo(int32(id.System)), common.PtrTo(int32(id.Talkgroup)), user.ID.Int32Ptr()) err := db.StoreDeletedTGVersion(ctx, common.PtrTo(int32(id.System)), common.PtrTo(int32(id.Talkgroup)), user.ID.Int32Ptr())
if err != nil { if err != nil {
return err return err
@ -611,7 +645,12 @@ func (t *cache) DeleteTG(ctx context.Context, id tgsp.ID) error {
} }
func (t *cache) LearnTG(ctx context.Context, c *calls.Call) (*tgsp.Talkgroup, error) { func (t *cache) LearnTG(ctx context.Context, c *calls.Call) (*tgsp.Talkgroup, error) {
db := database.FromCtx(ctx) _, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionCreate, rbac.ActionUpdate))
if err != nil {
return nil, err
}
db := t.db
sys, has := t.SystemName(ctx, c.System) sys, has := t.SystemName(ctx, c.System)
if !has { if !has {
@ -649,7 +688,7 @@ func (t *cache) UpsertTGs(ctx context.Context, system int, input []database.Upse
return nil, err return nil, err
} }
db := database.FromCtx(ctx) db := t.db
sysName, hasSys := t.SystemName(ctx, system) sysName, hasSys := t.SystemName(ctx, system)
if !hasSys { if !hasSys {
return nil, ErrNoSuchSystem return nil, ErrNoSuchSystem
@ -725,14 +764,24 @@ func (t *cache) UpsertTGs(ctx context.Context, system int, input []database.Upse
} }
func (t *cache) CreateSystem(ctx context.Context, id int, name string) error { func (t *cache) CreateSystem(ctx context.Context, id int, name string) error {
_, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionCreate))
if err != nil {
return err
}
t.Lock() t.Lock()
defer t.Unlock() defer t.Unlock()
t.addSysNoLock(id, name) t.addSysNoLock(id, name)
return database.FromCtx(ctx).CreateSystem(ctx, id, name) return t.db.CreateSystem(ctx, id, name)
} }
func (t *cache) Tags(ctx context.Context) ([]string, error) { func (t *cache) Tags(ctx context.Context) ([]string, error) {
return database.FromCtx(ctx).GetAllTalkgroupTags(ctx) _, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionRead))
if err != nil {
return nil, err
}
return t.db.GetAllTalkgroupTags(ctx)
} }

View file

@ -14,9 +14,12 @@ import (
"dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/database"
"dynatron.me/x/stillbox/pkg/database/mocks" "dynatron.me/x/stillbox/pkg/database/mocks"
"dynatron.me/x/stillbox/pkg/rbac"
rbacmocks "dynatron.me/x/stillbox/pkg/rbac/mocks"
"dynatron.me/x/stillbox/pkg/talkgroups" "dynatron.me/x/stillbox/pkg/talkgroups"
"dynatron.me/x/stillbox/pkg/talkgroups/tgstore" "dynatron.me/x/stillbox/pkg/talkgroups/tgstore"
"dynatron.me/x/stillbox/pkg/talkgroups/xport" "dynatron.me/x/stillbox/pkg/talkgroups/xport"
"dynatron.me/x/stillbox/pkg/users"
) )
func getFixture(fixture string) []byte { func getFixture(fixture string) []byte {
@ -51,14 +54,19 @@ func TestRadioRef(t *testing.T) {
}, },
} }
subject := users.User{IsAdmin: true}
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
dbMock := mocks.NewStore(t) dbMock := mocks.NewStore(t)
rbacMock := rbacmocks.NewRBAC(t)
rbacMock.EXPECT().Check(mock.AnythingOfType("*context.valueCtx"), rbac.UseResource("Talkgroup"), mock.AnythingOfType("rbac.CheckOption")).Return(&subject, nil)
if tc.expectErr == nil { if tc.expectErr == nil {
dbMock.EXPECT().GetSystemName(mock.AnythingOfType("*context.valueCtx"), tc.sysID).Return(tc.sysName, nil) dbMock.EXPECT().GetSystemName(mock.AnythingOfType("*context.valueCtx"), tc.sysID).Return(tc.sysName, nil)
} }
ctx := database.CtxWithDB(context.Background(), dbMock) ctx := database.CtxWithDB(context.Background(), dbMock)
ctx = tgstore.CtxWithStore(ctx, tgstore.NewCache()) ctx = rbac.CtxWithRBAC(ctx, rbacMock)
ctx = tgstore.CtxWithStore(ctx, tgstore.NewCache(dbMock))
ij := &xport.ImportJob{ ij := &xport.ImportJob{
Type: xport.Format(tc.impType), Type: xport.Format(tc.impType),
SystemID: tc.sysID, SystemID: tc.sysID,

View file

@ -22,15 +22,20 @@ type Store interface {
// UpdateUser updates a user's record // UpdateUser updates a user's record
UpdateUser(ctx context.Context, username string, user UserUpdate) error UpdateUser(ctx context.Context, username string, user UserUpdate) error
// GetUserByAPIKey gets a user by API key.
GetAPIKey(ctx context.Context, key string) (database.GetAPIKeyRow, error)
} }
type postgresStore struct { type postgresStore struct {
cache.Cache[string, *User] cache.Cache[string, *User]
db database.Store
} }
func NewStore() *postgresStore { func NewStore(db database.Store) *postgresStore {
return &postgresStore{ return &postgresStore{
Cache: cache.New[string, *User](), Cache: cache.New[string, *User](),
db: db,
} }
} }
@ -45,7 +50,7 @@ func CtxWithStore(ctx context.Context, s Store) context.Context {
func FromCtx(ctx context.Context) Store { func FromCtx(ctx context.Context) Store {
s, ok := ctx.Value(StoreCtxKey).(Store) s, ok := ctx.Value(StoreCtxKey).(Store)
if !ok { if !ok {
return NewStore() panic("no users store in context")
} }
return s return s
@ -61,7 +66,7 @@ type UserUpdate struct {
} }
func (s *postgresStore) UpdateUser(ctx context.Context, username string, user UserUpdate) error { func (s *postgresStore) UpdateUser(ctx context.Context, username string, user UserUpdate) error {
dbu, err := database.FromCtx(ctx).UpdateUser(ctx, username, user.Email, user.IsAdmin) dbu, err := s.db.UpdateUser(ctx, username, user.Email, user.IsAdmin)
if err != nil { if err != nil {
return err return err
} }
@ -77,8 +82,7 @@ func (s *postgresStore) GetUser(ctx context.Context, username string) (*User, er
return u, nil return u, nil
} }
db := database.FromCtx(ctx) dbu, err := s.db.GetUserByUsername(ctx, username)
dbu, err := db.GetUserByUsername(ctx, username)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -95,9 +99,7 @@ func (s *postgresStore) UserPrefs(ctx context.Context, username string, appName
return nil, err return nil, err
} }
db := database.FromCtx(ctx) prefs, err := s.db.GetAppPrefs(ctx, appName, int(u.ID))
prefs, err := db.GetAppPrefs(ctx, appName, int(u.ID))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -111,7 +113,9 @@ func (s *postgresStore) SetUserPrefs(ctx context.Context, username string, appNa
return err return err
} }
db := database.FromCtx(ctx) return s.db.SetAppPrefs(ctx, appName, prefs, int(u.ID))
}
return db.SetAppPrefs(ctx, appName, prefs, int(u.ID))
func (s *postgresStore) GetAPIKey(ctx context.Context, b64hash string) (database.GetAPIKeyRow, error) {
return s.db.GetAPIKey(ctx, b64hash)
} }

View file

@ -66,6 +66,10 @@ type User struct {
Prefs json.RawMessage Prefs json.RawMessage
} }
func (u *User) GetName() string {
return u.Username
}
func (u *User) GetRoles() []string { func (u *User) GetRoles() []string {
r := make([]string, 1, 2) r := make([]string, 1, 2)

View file

@ -156,3 +156,9 @@ CASE WHEN sqlc.narg('tags_not')::TEXT[] IS NOT NULL THEN
c.duration > @longer_than c.duration > @longer_than
) ELSE TRUE END) ) ELSE TRUE END)
; ;
-- name: DeleteCall :exec
DELETE FROM calls WHERE id = @id;
-- name: GetCallSubmitter :one
SELECT submitter FROM calls WHERE id = @id;

View file

@ -175,3 +175,6 @@ RETURNING *;
-- name: DeleteIncident :exec -- name: DeleteIncident :exec
DELETE FROM incidents CASCADE WHERE id = @id; DELETE FROM incidents CASCADE WHERE id = @id;
-- name: GetIncidentOwner :one
SELECT owner FROM incidents WHERE id = @id;