diff --git a/.mockery.yaml b/.mockery.yaml index a0b5082..3174fe4 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -1,4 +1,4 @@ -dir: '{{ replaceAll .InterfaceDirRelative "internal" "internal_" }}/mocks' +dir: '{{.InterfaceDir}}/mocks' mockname: "{{.InterfaceName}}" outpkg: "mocks" filename: "{{.InterfaceName}}.go" @@ -9,3 +9,7 @@ packages: interfaces: Store: DBTX: + dynatron.me/x/stillbox/pkg/rbac: + config: + interfaces: + RBAC: diff --git a/go.mod b/go.mod index 5429915..a766b05 100644 --- a/go.mod +++ b/go.mod @@ -39,6 +39,7 @@ require ( github.com/cpuguy83/go-md2man/v2 v2.0.5 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect + github.com/el-mike/restrict/v2 v2.0.0 // indirect github.com/fsnotify/fsnotify v1.8.0 // indirect github.com/go-audio/audio v1.0.0 // indirect github.com/go-audio/riff v1.0.0 // indirect @@ -56,6 +57,7 @@ require ( github.com/lestrrat-go/iter v1.0.2 // indirect github.com/lestrrat-go/jwx/v2 v2.1.3 // indirect github.com/lestrrat-go/option v1.0.1 // indirect + github.com/matoous/go-nanoid v1.5.1 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mitchellh/copystructure v1.2.0 // indirect diff --git a/go.sum b/go.sum index 1efbc87..26398c9 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,12 @@ github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de h1:FxWPpzIjnTlhPwqqXc4/vE0f7GvRjuAsbW+HOIe8KnA= github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de/go.mod h1:DCaWoUhZrYW9p1lxo/cm8EmUOOzAPSEZNGF2DK1dJgw= +github.com/bmatcuk/doublestar/v4 v4.6.1 h1:FH9SifrbvJhnlQpztAx++wlkk70QBf0iBWDwNy7PA4I= +github.com/bmatcuk/doublestar/v4 v4.6.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc= +github.com/casbin/casbin/v2 v2.103.0 h1:dHElatNXNrr8XcseUov0ZSiWjauwmZZE6YMV3eU1yic= +github.com/casbin/casbin/v2 v2.103.0/go.mod h1:Ee33aqGrmES+GNL17L0h9X28wXuo829wnNUnS0edAco= +github.com/casbin/govaluate v1.3.0 h1:VA0eSY0M2lA86dYd5kPPuNZMUD9QkWnOCnavGrw9myc= +github.com/casbin/govaluate v1.3.0/go.mod h1:G/UnbIjZk/0uMNaLwZZmFQrR72tYRZWQkO70si/iR7A= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= @@ -28,6 +34,8 @@ github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/el-mike/restrict/v2 v2.0.0 h1:OuVBseAejSHyfHMUr15c4Gz3WRCEKuuD8IOR/mOIV/o= +github.com/el-mike/restrict/v2 v2.0.0/go.mod h1:ClycXfCKWIZRU1qi2CJIOpHEuonBOj/2GKc+w1lZtrQ= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M= @@ -61,6 +69,7 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-migrate/migrate/v4 v4.18.1 h1:JML/k+t4tpHCpQTCAD62Nu43NUFzHY4CV3uAuvHGC+Y= github.com/golang-migrate/migrate/v4 v4.18.1/go.mod h1:HAX6m3sQgcdO81tdjn5exv20+3Kb13cmGli1hrD6hks= +github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -116,6 +125,8 @@ github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNB github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/matoous/go-nanoid v1.5.1 h1:aCjdvTyO9LLnTIi0fgdXhOPPvOHjpXN6Ik9DaNjIct4= +github.com/matoous/go-nanoid v1.5.1/go.mod h1:zyD2a71IubI24efhpvkJz+ZwfwagzgSO6UNiFsZKN7U= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= @@ -155,12 +166,16 @@ github.com/scylladb/termtables v0.0.0-20191203121021-c4c0b6d42ff4/go.mod h1:C1a7 github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= @@ -181,6 +196,7 @@ go.opentelemetry.io/otel/trace v1.31.0 h1:ffjsj1aRouKewfr85U2aGagJ46+MvodynlQ1HY go.opentelemetry.io/otel/trace v1.31.0/go.mod h1:TXZkRk7SM2ZQLtR6eoAWQFIHPvzQ06FJAsO1tJg480A= go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ= golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -192,8 +208,11 @@ golang.org/x/image v0.22.0/go.mod h1:9hPFhljd4zZ1GNSIZJ49sqbp45GKK9t6w+iXvGqZUz4 golang.org/x/mobile v0.0.0-20190415191353-3e0bab5405d6/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= golang.org/x/mobile v0.0.0-20241108191957-fa514ef75a0f h1:23H/YlmTHfmmvpZ+ajKZL0qLz0+IwFOIqQA0mQbmLeM= golang.org/x/mobile v0.0.0-20241108191957-fa514ef75a0f/go.mod h1:UbSUP4uu/C9hw9R2CkojhXlAxvayHjBdU9aRvE+c1To= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190429190828-d89cdac9e872/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -206,6 +225,7 @@ golang.org/x/term v0.26.0/go.mod h1:Si5m1o57C5nBNQo5z1iq+XDijt21BDBDp2bK0QI8e3E= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= google.golang.org/protobuf v1.35.2 h1:8Ar7bF+apOIoThw1EdZl0p1oWvMqTHmpA2fRTyZO8io= google.golang.org/protobuf v1.35.2/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/cache/cache.go b/internal/cache/cache.go new file mode 100644 index 0000000..a364406 --- /dev/null +++ b/internal/cache/cache.go @@ -0,0 +1,46 @@ +package cache + +import "sync" + +type Cache[K comparable, V any] interface { + Get(K) (V, bool) + Set(K, V) + Delete(K) + Clear() +} + +type inMem[K comparable, V any] struct { + sync.RWMutex + m map[K]V +} + +func New[K comparable, V any]() *inMem[K, V] { + return &inMem[K, V]{ + m: make(map[K]V), + } +} + +func (c *inMem[K, V]) Get(key K) (V, bool) { + c.RLock() + defer c.RUnlock() + v, ok := c.m[key] + return v, ok +} + +func (c *inMem[K, V]) Set(key K, val V) { + c.Lock() + defer c.Unlock() + c.m[key] = val +} + +func (c *inMem[K, V]) Delete(key K) { + c.Lock() + defer c.Unlock() + delete(c.m, key) +} + +func (c *inMem[K, V]) Clear() { + c.Lock() + defer c.Unlock() + clear(c.m) +} diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go new file mode 100644 index 0000000..23079f1 --- /dev/null +++ b/internal/cache/cache_test.go @@ -0,0 +1,33 @@ +package cache_test + +import ( + "testing" + + "dynatron.me/x/stillbox/internal/cache" + + "github.com/stretchr/testify/assert" +) + +func TestCache(t *testing.T) { + c := cache.New[int, string]() + c.Set(4, "asd") + g, ok := c.Get(4) + assert.Equal(t, "asd", g) + assert.True(t, ok) + + _, ok = c.Get(8) + assert.False(t, ok) + + c.Set(7, "fg") + + c.Delete(4) + + g, ok = c.Get(4) + assert.False(t, ok) + assert.NotEqual(t, "asd", g) + + c.Clear() + g, ok = c.Get(7) + assert.False(t, ok) + assert.NotEqual(t, "fg", g) +} diff --git a/internal/forms/marshal_test.go b/internal/forms/marshal_test.go index ae624c2..b762c30 100644 --- a/internal/forms/marshal_test.go +++ b/internal/forms/marshal_test.go @@ -15,9 +15,9 @@ import ( "dynatron.me/x/stillbox/internal/common" "dynatron.me/x/stillbox/internal/forms" - "dynatron.me/x/stillbox/pkg/auth" "dynatron.me/x/stillbox/pkg/calls" "dynatron.me/x/stillbox/pkg/sources" + "dynatron.me/x/stillbox/pkg/users" "github.com/google/uuid" ) @@ -62,16 +62,16 @@ func TestMarshal(t *testing.T) { tests := []struct { name string - submitter auth.UserID + submitter users.UserID apiKey string call calls.Call }{ { name: "base", - submitter: auth.UserID(1), + submitter: users.UserID(1), call: calls.Call{ ID: uuid.UUID([16]byte{0x52, 0xfd, 0xfc, 0x07, 0x21, 0x82, 0x45, 0x4f, 0x96, 0x3f, 0x5f, 0x0f, 0x9a, 0x62, 0x1d, 0x72}), - Submitter: common.PtrTo(auth.UserID(1)), + Submitter: common.PtrTo(users.UserID(1)), System: 197, Talkgroup: 10101, DateTime: time.Date(2024, 11, 10, 23, 33, 02, 0, time.Local), diff --git a/pkg/alerting/alerting.go b/pkg/alerting/alerting.go index 259d767..d78a0cf 100644 --- a/pkg/alerting/alerting.go +++ b/pkg/alerting/alerting.go @@ -14,6 +14,7 @@ import ( "dynatron.me/x/stillbox/pkg/config" "dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/notify" + "dynatron.me/x/stillbox/pkg/rbac" "dynatron.me/x/stillbox/pkg/sinks" "dynatron.me/x/stillbox/pkg/talkgroups" "dynatron.me/x/stillbox/pkg/talkgroups/tgstore" @@ -123,6 +124,8 @@ func New(cfg config.Alerting, tgCache tgstore.Store, opts ...AlertOption) Alerte // Go is the alerting loop. It does not start a goroutine. func (as *alerter) Go(ctx context.Context) { + ctx = rbac.CtxWithSubject(ctx, &rbac.SystemServiceSubject{Name: "alerter"}) + err := as.startBackfill(ctx) if err != nil { log.Error().Err(err).Msg("backfill") diff --git a/pkg/alerting/simulate.go b/pkg/alerting/simulate.go index 641ab28..1486460 100644 --- a/pkg/alerting/simulate.go +++ b/pkg/alerting/simulate.go @@ -12,6 +12,7 @@ import ( "dynatron.me/x/stillbox/internal/jsontypes" "dynatron.me/x/stillbox/internal/trending" "dynatron.me/x/stillbox/pkg/config" + "dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/talkgroups" "dynatron.me/x/stillbox/pkg/talkgroups/tgstore" @@ -59,8 +60,9 @@ func (s *Simulation) stepClock(t time.Time) { // Simulate begins the simulation using the DB handle from ctx. It returns final scores. func (s *Simulation) Simulate(ctx context.Context) (trending.Scores[talkgroups.ID], error) { + db := database.FromCtx(ctx) now := time.Now() - tgc := tgstore.NewCache() + tgc := tgstore.NewCache(db) s.Enable = true s.alerter = New(s.Alerting, tgc, WithClock(&s.clock)).(*alerter) diff --git a/pkg/auth/apikey.go b/pkg/auth/apikey.go index d18a303..f88c8f9 100644 --- a/pkg/auth/apikey.go +++ b/pkg/auth/apikey.go @@ -7,28 +7,28 @@ import ( "time" "dynatron.me/x/stillbox/pkg/database" + "dynatron.me/x/stillbox/pkg/rbac" "github.com/google/uuid" "github.com/rs/zerolog/log" ) type apiKeyAuth interface { - // CheckAPIKey validates the provided key and returns the API owner's UserID. + // CheckAPIKey validates the provided key and returns the API owner's users.UserID. // An error is returned if validation fails for any reason. - CheckAPIKey(ctx context.Context, key string) (*UserID, error) + CheckAPIKey(ctx context.Context, key string) (rbac.Subject, error) } -func (a *Auth) CheckAPIKey(ctx context.Context, key string) (*UserID, error) { +func (a *Auth) CheckAPIKey(ctx context.Context, key string) (rbac.Subject, error) { keyUuid, err := uuid.Parse(key) if err != nil { log.Error().Str("apikey", key).Msg("cannot parse key") return nil, ErrBadRequest } - db := database.FromCtx(ctx) hash := sha256.Sum256([]byte(keyUuid.String())) b64hash := base64.StdEncoding.EncodeToString(hash[:]) - apik, err := db.GetAPIKey(ctx, b64hash) + apik, err := a.ust.GetAPIKey(ctx, b64hash) if err != nil { if database.IsNoRows(err) { log.Error().Str("apikey", keyUuid.String()).Msg("no such key") @@ -44,7 +44,5 @@ func (a *Auth) CheckAPIKey(ctx context.Context, key string) (*UserID, error) { return nil, ErrUnauthorized } - owner := UserID(apik.Owner) - - return &owner, nil + return a.ust.GetUser(ctx, apik.Username) } diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index e7bd43c..4bd073b 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -8,23 +8,13 @@ import ( _ "embed" "dynatron.me/x/stillbox/pkg/config" + "dynatron.me/x/stillbox/pkg/rbac" + "dynatron.me/x/stillbox/pkg/users" "github.com/go-chi/chi/v5" "github.com/go-chi/httprate" "github.com/go-chi/jwtauth/v5" ) -type UserID int - -func (u *UserID) Int32Ptr() *int32 { - if u == nil { - return nil - } - - i := int32(*u) - - return &i -} - // Authenticator performs API key and user JWT authentication. type Authenticator interface { jwtAuth @@ -34,14 +24,16 @@ type Authenticator interface { type Auth struct { rl *httprate.RateLimiter jwt *jwtauth.JWTAuth + ust users.Store cfg config.Auth } // NewAuthenticator creates a new Authenticator with the provided config. -func NewAuthenticator(cfg config.Auth) *Auth { +func NewAuthenticator(cfg config.Auth, ust users.Store) *Auth { a := &Auth{ rl: httprate.NewRateLimiter(5, time.Minute), cfg: cfg, + ust: ust, } a.initJWT() @@ -63,7 +55,7 @@ var ( // ErrorResponse writes the error and appropriate HTTP response code. func ErrorResponse(w http.ResponseWriter, err error) { switch err { - case ErrLoginFailed, ErrUnauthorized: + case ErrLoginFailed, ErrUnauthorized, rbac.ErrBadSubject: http.Error(w, err.Error(), http.StatusUnauthorized) case ErrBadRequest: http.Error(w, err.Error(), http.StatusBadRequest) diff --git a/pkg/auth/jwt.go b/pkg/auth/jwt.go index 99c2ba7..97867b1 100644 --- a/pkg/auth/jwt.go +++ b/pkg/auth/jwt.go @@ -4,17 +4,19 @@ import ( "context" "encoding/json" "net/http" - "strconv" "strings" "time" "golang.org/x/crypto/bcrypt" "dynatron.me/x/stillbox/pkg/database" + "dynatron.me/x/stillbox/pkg/rbac" + "dynatron.me/x/stillbox/pkg/users" "github.com/go-chi/chi/v5" "github.com/go-chi/jwtauth/v5" "github.com/go-chi/render" + "github.com/lestrrat-go/jwx/v2/jwt" "github.com/rs/zerolog/log" ) @@ -44,21 +46,16 @@ type jwtAuth interface { type claims map[string]interface{} -func UIDFrom(ctx context.Context) *int32 { +// UsernameFrom gets the username (just the subject from token) from ctx. +func UsernameFrom(ctx context.Context) *string { tok, _, err := jwtauth.FromContext(ctx) if err != nil { return nil } - uidStr := tok.Subject() - uidInt, err := strconv.Atoi(uidStr) - if err != nil { - return nil - } + username := tok.Subject() - uid := int32(uidInt) - - return &uid + return &username } func (a *Auth) Authenticated(r *http.Request) (claims, bool) { @@ -88,7 +85,38 @@ func TokenFromCookie(r *http.Request) string { } func (a *Auth) AuthMiddleware() func(http.Handler) http.Handler { - return jwtauth.Authenticator(a.jwt) + return func(next http.Handler) http.Handler { + hfn := func(w http.ResponseWriter, r *http.Request) { + token, _, err := jwtauth.FromContext(r.Context()) + + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + + if token != nil && jwt.Validate(token, a.jwt.ValidateOptions()...) == nil { + ctx := r.Context() + username := token.Subject() + + sub, err := users.FromCtx(ctx).GetUser(ctx, username) + if err != nil { + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + + ctx = rbac.CtxWithSubject(ctx, sub) + + next.ServeHTTP(w, r.WithContext(ctx)) + + return + } + + // Token is authenticated, pass it through + next.ServeHTTP(w, r) + } + return http.HandlerFunc(hfn) + } + } func (a *Auth) initJWT() { @@ -124,12 +152,12 @@ func (a *Auth) Login(ctx context.Context, username, password string) (token stri } } - return a.newToken(found.ID), nil + return a.newToken(found.Username), nil } -func (a *Auth) newToken(uid int) string { +func (a *Auth) newToken(username string) string { claims := claims{ - "sub": strconv.Itoa(int(uid)), + "sub": username, } jwtauth.SetExpiryIn(claims, time.Hour*24*30) // one month _, tokenString, err := a.jwt.Encode(claims) @@ -161,19 +189,14 @@ func (a *Auth) routeRefresh(w http.ResponseWriter, r *http.Request) { http.Error(w, "Invalid token", http.StatusBadRequest) return } - existingSubjectUID := jwToken.Subject() - if existingSubjectUID == "" { + + existingSubjectUsername := jwToken.Subject() + if existingSubjectUsername == "" { http.Error(w, "Invalid token", http.StatusBadRequest) return } - uid, err := strconv.Atoi(existingSubjectUID) - if err != nil { - log.Error().Str("sub", existingSubjectUID).Err(err).Msg("atoi uid for token refresh") - http.Error(w, "internal server error", http.StatusInternalServerError) - return - } - tok := a.newToken(uid) + tok := a.newToken(existingSubjectUsername) cookie := &http.Cookie{ Name: CookieName, diff --git a/pkg/calls/call.go b/pkg/calls/call.go index effdbce..090fa4f 100644 --- a/pkg/calls/call.go +++ b/pkg/calls/call.go @@ -7,9 +7,10 @@ import ( "dynatron.me/x/stillbox/internal/audio" "dynatron.me/x/stillbox/internal/jsontypes" - "dynatron.me/x/stillbox/pkg/auth" "dynatron.me/x/stillbox/pkg/pb" + "dynatron.me/x/stillbox/pkg/rbac" "dynatron.me/x/stillbox/pkg/talkgroups" + "dynatron.me/x/stillbox/pkg/users" "github.com/google/uuid" "google.golang.org/protobuf/types/known/timestamppb" @@ -52,27 +53,31 @@ type CallAudio struct { // further transformation. relayOut exists for compatibility with http // source CallUploadRequest as used in the relay sink. type Call struct { - ID uuid.UUID `json:"id" relayOut:"id"` - Audio []byte `json:"audio,omitempty" relayOut:"audio,omitempty" filenameField:"AudioName"` - AudioName string `json:"audioName,omitempty" relayOut:"audioName,omitempty"` - AudioType string `json:"audioType,omitempty" relayOut:"audioType,omitempty"` - Duration CallDuration `json:"duration,omitempty" relayOut:"duration,omitempty"` - DateTime time.Time `json:"call_date,omitempty" relayOut:"dateTime,omitempty"` - Frequencies []int `json:"frequencies,omitempty" relayOut:"frequencies,omitempty"` - Frequency int `json:"frequency,omitempty" relayOut:"frequency,omitempty"` - Patches []int `json:"patches,omitempty" relayOut:"patches,omitempty"` - Source int `json:"source,omitempty" relayOut:"source,omitempty"` - System int `json:"system_id,omitempty" relayOut:"system,omitempty"` - Submitter *auth.UserID `json:"submitter,omitempty" relayOut:"submitter,omitempty"` - SystemLabel string `json:"system_name,omitempty" relayOut:"systemLabel,omitempty"` - Talkgroup int `json:"tgid,omitempty" relayOut:"talkgroup,omitempty"` - TalkgroupGroup *string `json:"talkgroupGroup,omitempty" relayOut:"talkgroupGroup,omitempty"` - TalkgroupLabel *string `json:"talkgroupLabel,omitempty" relayOut:"talkgroupLabel,omitempty"` - TGAlphaTag *string `json:"tg_name,omitempty" relayOut:"talkgroupTag,omitempty"` + ID uuid.UUID `json:"id" relayOut:"id"` + Audio []byte `json:"audio,omitempty" relayOut:"audio,omitempty" filenameField:"AudioName"` + AudioName string `json:"audioName,omitempty" relayOut:"audioName,omitempty"` + AudioType string `json:"audioType,omitempty" relayOut:"audioType,omitempty"` + Duration CallDuration `json:"duration,omitempty" relayOut:"duration,omitempty"` + DateTime time.Time `json:"call_date,omitempty" relayOut:"dateTime,omitempty"` + Frequencies []int `json:"frequencies,omitempty" relayOut:"frequencies,omitempty"` + Frequency int `json:"frequency,omitempty" relayOut:"frequency,omitempty"` + Patches []int `json:"patches,omitempty" relayOut:"patches,omitempty"` + Source int `json:"source,omitempty" relayOut:"source,omitempty"` + System int `json:"system_id,omitempty" relayOut:"system,omitempty"` + Submitter *users.UserID `json:"submitter,omitempty" relayOut:"submitter,omitempty"` + SystemLabel string `json:"system_name,omitempty" relayOut:"systemLabel,omitempty"` + Talkgroup int `json:"tgid,omitempty" relayOut:"talkgroup,omitempty"` + TalkgroupGroup *string `json:"talkgroupGroup,omitempty" relayOut:"talkgroupGroup,omitempty"` + TalkgroupLabel *string `json:"talkgroupLabel,omitempty" relayOut:"talkgroupLabel,omitempty"` + TGAlphaTag *string `json:"tg_name,omitempty" relayOut:"talkgroupTag,omitempty"` shouldStore bool `json:"-"` } +func (c *Call) GetResourceName() string { + return rbac.ResourceCall +} + func (c *Call) String() string { return fmt.Sprintf("%s to %d from %d", c.AudioName, c.Talkgroup, c.Source) } diff --git a/pkg/calls/callstore/store.go b/pkg/calls/callstore/store.go index 10c1c9d..218113d 100644 --- a/pkg/calls/callstore/store.go +++ b/pkg/calls/callstore/store.go @@ -9,6 +9,9 @@ import ( "dynatron.me/x/stillbox/pkg/calls" "dynatron.me/x/stillbox/pkg/database" + "dynatron.me/x/stillbox/pkg/rbac" + "dynatron.me/x/stillbox/pkg/talkgroups/tgstore" + "dynatron.me/x/stillbox/pkg/users" "github.com/google/uuid" "github.com/jackc/pgx/v5" @@ -16,6 +19,12 @@ import ( ) type Store interface { + // AddCall adds a call to the database. + AddCall(ctx context.Context, call *calls.Call) error + + // DeleteCall deletes a call. + Delete(ctx context.Context, id uuid.UUID) error + // CallAudio returns a CallAudio struct CallAudio(ctx context.Context, id uuid.UUID) (*calls.CallAudio, error) @@ -24,10 +33,13 @@ type Store interface { } type store struct { + db database.Store } -func NewStore() *store { - return new(store) +func NewStore(db database.Store) *store { + return &store{ + db: db, + } } type storeCtxKey string @@ -41,13 +53,77 @@ func CtxWithStore(ctx context.Context, s Store) context.Context { func FromCtx(ctx context.Context) Store { s, ok := ctx.Value(StoreCtxKey).(Store) if !ok { - return NewStore() + panic("no call store in context") } return s } +func toAddCallParams(call *calls.Call) database.AddCallParams { + return database.AddCallParams{ + ID: call.ID, + Submitter: call.Submitter.Int32Ptr(), + System: call.System, + Talkgroup: call.Talkgroup, + CallDate: pgtype.Timestamptz{Time: call.DateTime, Valid: true}, + AudioName: common.NilIfZero(call.AudioName), + AudioBlob: call.Audio, + AudioType: common.NilIfZero(call.AudioType), + Duration: call.Duration.MsInt32Ptr(), + Frequency: call.Frequency, + Frequencies: call.Frequencies, + Patches: call.Patches, + TGLabel: call.TalkgroupLabel, + TGAlphaTag: call.TGAlphaTag, + TGGroup: call.TalkgroupGroup, + Source: call.Source, + } +} + +func (s *store) AddCall(ctx context.Context, call *calls.Call) error { + _, err := rbac.Check(ctx, call, rbac.WithActions(rbac.ActionCreate)) + if err != nil { + return err + } + + params := toAddCallParams(call) + db := database.FromCtx(ctx) + tgs := tgstore.FromCtx(ctx) + + err = db.InTx(ctx, func(tx database.Store) error { + err := tx.AddCall(ctx, params) + if err != nil { + return fmt.Errorf("add call: %w", err) + } + + return nil + }, pgx.TxOptions{}) + + if err != nil && database.IsTGConstraintViolation(err) { + return db.InTx(ctx, func(tx database.Store) error { + _, err := tgs.LearnTG(ctx, call) + if err != nil { + return fmt.Errorf("learn tg: %w", err) + } + + err = tx.AddCall(ctx, params) + if err != nil { + return fmt.Errorf("learn tg retry: %w", err) + } + + return nil + }, pgx.TxOptions{}) + } + + return nil +} + func (s *store) CallAudio(ctx context.Context, id uuid.UUID) (*calls.CallAudio, error) { + _, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceCall), rbac.WithActions(rbac.ActionRead)) + if err != nil { + return nil, err + } + db := database.FromCtx(ctx) dbCall, err := db.GetCallAudioByID(ctx, id) @@ -76,6 +152,11 @@ type CallsParams struct { } func (s *store) Calls(ctx context.Context, p CallsParams) (rows []database.ListCallsPRow, totalCount int, err error) { + _, err = rbac.Check(ctx, rbac.UseResource(rbac.ResourceCall), rbac.WithActions(rbac.ActionRead)) + if err != nil { + return nil, 0, err + } + db := database.FromCtx(ctx) offset, perPage := p.Pagination.OffsetPerPage(100) @@ -127,3 +208,28 @@ func (s *store) Calls(ctx context.Context, p CallsParams) (rows []database.ListC return rows, int(count), err } + +func (s *store) Delete(ctx context.Context, id uuid.UUID) error { + callOwn, err := s.getCallOwner(ctx, id) + if err != nil { + return err + } + + _, err = rbac.Check(ctx, &callOwn, rbac.WithActions(rbac.ActionDelete)) + if err != nil { + return err + } + + return database.FromCtx(ctx).DeleteCall(ctx, id) +} + +func (s *store) getCallOwner(ctx context.Context, id uuid.UUID) (calls.Call, error) { + subInt, err := database.FromCtx(ctx).GetCallSubmitter(ctx, id) + + var sub *users.UserID + + if subInt != nil { + sub = common.PtrTo(users.UserID(*subInt)) + } + return calls.Call{ID: id, Submitter: sub}, err +} diff --git a/pkg/database/calls.sql.go b/pkg/database/calls.sql.go index 22da55b..a26efe8 100644 --- a/pkg/database/calls.sql.go +++ b/pkg/database/calls.sql.go @@ -155,6 +155,15 @@ func (q *Queries) CleanupSweptCalls(ctx context.Context, rangeStart pgtype.Times return result.RowsAffected(), nil } +const deleteCall = `-- name: DeleteCall :exec +DELETE FROM calls WHERE id = $1 +` + +func (q *Queries) DeleteCall(ctx context.Context, id uuid.UUID) error { + _, err := q.db.Exec(ctx, deleteCall, id) + return err +} + const getCallAudioByID = `-- name: GetCallAudioByID :one SELECT c.call_date, @@ -192,6 +201,17 @@ func (q *Queries) GetCallAudioByID(ctx context.Context, id uuid.UUID) (GetCallAu return i, err } +const getCallSubmitter = `-- name: GetCallSubmitter :one +SELECT submitter FROM calls WHERE id = $1 +` + +func (q *Queries) GetCallSubmitter(ctx context.Context, id uuid.UUID) (*int32, error) { + row := q.db.QueryRow(ctx, getCallSubmitter, id) + var submitter *int32 + err := row.Scan(&submitter) + return submitter, err +} + const getDatabaseSize = `-- name: GetDatabaseSize :one SELECT pg_size_pretty(pg_database_size(current_database())) ` diff --git a/pkg/database/incidents.sql.go b/pkg/database/incidents.sql.go index d422db0..c39e6ca 100644 --- a/pkg/database/incidents.sql.go +++ b/pkg/database/incidents.sql.go @@ -44,6 +44,7 @@ const createIncident = `-- name: CreateIncident :one INSERT INTO incidents ( id, name, + owner, description, start_time, end_time, @@ -56,14 +57,16 @@ INSERT INTO incidents ( $4, $5, $6, - $7 + $7, + $8 ) -RETURNING id, name, description, start_time, end_time, location, metadata +RETURNING id, name, owner, description, start_time, end_time, location, metadata ` type CreateIncidentParams struct { ID uuid.UUID `json:"id"` Name string `json:"name"` + Owner int `json:"owner"` Description *string `json:"description"` StartTime pgtype.Timestamptz `json:"start_time"` EndTime pgtype.Timestamptz `json:"end_time"` @@ -75,6 +78,7 @@ func (q *Queries) CreateIncident(ctx context.Context, arg CreateIncidentParams) row := q.db.QueryRow(ctx, createIncident, arg.ID, arg.Name, + arg.Owner, arg.Description, arg.StartTime, arg.EndTime, @@ -85,6 +89,7 @@ func (q *Queries) CreateIncident(ctx context.Context, arg CreateIncidentParams) err := row.Scan( &i.ID, &i.Name, + &i.Owner, &i.Description, &i.StartTime, &i.EndTime, @@ -107,6 +112,7 @@ const getIncident = `-- name: GetIncident :one SELECT i.id, i.name, + i.owner, i.description, i.start_time, i.end_time, @@ -122,6 +128,7 @@ func (q *Queries) GetIncident(ctx context.Context, id uuid.UUID) (Incident, erro err := row.Scan( &i.ID, &i.Name, + &i.Owner, &i.Description, &i.StartTime, &i.EndTime, @@ -237,6 +244,17 @@ func (q *Queries) GetIncidentCalls(ctx context.Context, id uuid.UUID) ([]GetInci return items, nil } +const getIncidentOwner = `-- name: GetIncidentOwner :one +SELECT owner FROM incidents WHERE id = $1 +` + +func (q *Queries) GetIncidentOwner(ctx context.Context, id uuid.UUID) (int, error) { + row := q.db.QueryRow(ctx, getIncidentOwner, id) + var owner int + err := row.Scan(&owner) + return owner, err +} + const listIncidentsCount = `-- name: ListIncidentsCount :one SELECT COUNT(*) FROM incidents i @@ -262,6 +280,7 @@ const listIncidentsP = `-- name: ListIncidentsP :many SELECT i.id, i.name, + i.owner, i.description, i.start_time, i.end_time, @@ -299,6 +318,7 @@ type ListIncidentsPParams struct { type ListIncidentsPRow struct { ID uuid.UUID `json:"id"` Name string `json:"name"` + Owner int `json:"owner"` Description *string `json:"description"` StartTime pgtype.Timestamptz `json:"start_time"` EndTime pgtype.Timestamptz `json:"end_time"` @@ -326,6 +346,7 @@ func (q *Queries) ListIncidentsP(ctx context.Context, arg ListIncidentsPParams) if err := rows.Scan( &i.ID, &i.Name, + &i.Owner, &i.Description, &i.StartTime, &i.EndTime, @@ -375,7 +396,7 @@ SET metadata = COALESCE($6, metadata) WHERE id = $7 -RETURNING id, name, description, start_time, end_time, location, metadata +RETURNING id, name, owner, description, start_time, end_time, location, metadata ` type UpdateIncidentParams struct { @@ -402,6 +423,7 @@ func (q *Queries) UpdateIncident(ctx context.Context, arg UpdateIncidentParams) err := row.Scan( &i.ID, &i.Name, + &i.Owner, &i.Description, &i.StartTime, &i.EndTime, diff --git a/pkg/database/mocks/Store.go b/pkg/database/mocks/Store.go index b308e41..11f1ae2 100644 --- a/pkg/database/mocks/Store.go +++ b/pkg/database/mocks/Store.go @@ -502,6 +502,53 @@ func (_c *Store_CreatePartition_Call) RunAndReturn(run func(context.Context, str return _c } +// CreateShare provides a mock function with given fields: ctx, arg +func (_m *Store) CreateShare(ctx context.Context, arg database.CreateShareParams) error { + ret := _m.Called(ctx, arg) + + if len(ret) == 0 { + panic("no return value specified for CreateShare") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, database.CreateShareParams) error); ok { + r0 = rf(ctx, arg) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Store_CreateShare_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateShare' +type Store_CreateShare_Call struct { + *mock.Call +} + +// CreateShare is a helper method to define mock.On call +// - ctx context.Context +// - arg database.CreateShareParams +func (_e *Store_Expecter) CreateShare(ctx interface{}, arg interface{}) *Store_CreateShare_Call { + return &Store_CreateShare_Call{Call: _e.mock.On("CreateShare", ctx, arg)} +} + +func (_c *Store_CreateShare_Call) Run(run func(ctx context.Context, arg database.CreateShareParams)) *Store_CreateShare_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(database.CreateShareParams)) + }) + return _c +} + +func (_c *Store_CreateShare_Call) Return(_a0 error) *Store_CreateShare_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Store_CreateShare_Call) RunAndReturn(run func(context.Context, database.CreateShareParams) error) *Store_CreateShare_Call { + _c.Call.Return(run) + return _c +} + // CreateSystem provides a mock function with given fields: ctx, iD, name func (_m *Store) CreateSystem(ctx context.Context, iD int, name string) error { ret := _m.Called(ctx, iD, name) @@ -748,6 +795,53 @@ func (_c *Store_DeleteAPIKey_Call) RunAndReturn(run func(context.Context, string return _c } +// DeleteCall provides a mock function with given fields: ctx, id +func (_m *Store) DeleteCall(ctx context.Context, id uuid.UUID) error { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for DeleteCall") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) error); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Store_DeleteCall_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteCall' +type Store_DeleteCall_Call struct { + *mock.Call +} + +// DeleteCall is a helper method to define mock.On call +// - ctx context.Context +// - id uuid.UUID +func (_e *Store_Expecter) DeleteCall(ctx interface{}, id interface{}) *Store_DeleteCall_Call { + return &Store_DeleteCall_Call{Call: _e.mock.On("DeleteCall", ctx, id)} +} + +func (_c *Store_DeleteCall_Call) Run(run func(ctx context.Context, id uuid.UUID)) *Store_DeleteCall_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(uuid.UUID)) + }) + return _c +} + +func (_c *Store_DeleteCall_Call) Return(_a0 error) *Store_DeleteCall_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Store_DeleteCall_Call) RunAndReturn(run func(context.Context, uuid.UUID) error) *Store_DeleteCall_Call { + _c.Call.Return(run) + return _c +} + // DeleteIncident provides a mock function with given fields: ctx, id func (_m *Store) DeleteIncident(ctx context.Context, id uuid.UUID) error { ret := _m.Called(ctx, id) @@ -795,6 +889,53 @@ func (_c *Store_DeleteIncident_Call) RunAndReturn(run func(context.Context, uuid return _c } +// DeleteShare provides a mock function with given fields: ctx, id +func (_m *Store) DeleteShare(ctx context.Context, id string) error { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for DeleteShare") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Store_DeleteShare_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteShare' +type Store_DeleteShare_Call struct { + *mock.Call +} + +// DeleteShare is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *Store_Expecter) DeleteShare(ctx interface{}, id interface{}) *Store_DeleteShare_Call { + return &Store_DeleteShare_Call{Call: _e.mock.On("DeleteShare", ctx, id)} +} + +func (_c *Store_DeleteShare_Call) Run(run func(ctx context.Context, id string)) *Store_DeleteShare_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *Store_DeleteShare_Call) Return(_a0 error) *Store_DeleteShare_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Store_DeleteShare_Call) RunAndReturn(run func(context.Context, string) error) *Store_DeleteShare_Call { + _c.Call.Return(run) + return _c +} + // DeleteSystem provides a mock function with given fields: ctx, id func (_m *Store) DeleteSystem(ctx context.Context, id int) error { ret := _m.Called(ctx, id) @@ -1033,22 +1174,22 @@ func (_c *Store_DropPartition_Call) RunAndReturn(run func(context.Context, strin } // GetAPIKey provides a mock function with given fields: ctx, apiKey -func (_m *Store) GetAPIKey(ctx context.Context, apiKey string) (database.ApiKey, error) { +func (_m *Store) GetAPIKey(ctx context.Context, apiKey string) (database.GetAPIKeyRow, error) { ret := _m.Called(ctx, apiKey) if len(ret) == 0 { panic("no return value specified for GetAPIKey") } - var r0 database.ApiKey + var r0 database.GetAPIKeyRow var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) (database.ApiKey, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, string) (database.GetAPIKeyRow, error)); ok { return rf(ctx, apiKey) } - if rf, ok := ret.Get(0).(func(context.Context, string) database.ApiKey); ok { + if rf, ok := ret.Get(0).(func(context.Context, string) database.GetAPIKeyRow); ok { r0 = rf(ctx, apiKey) } else { - r0 = ret.Get(0).(database.ApiKey) + r0 = ret.Get(0).(database.GetAPIKeyRow) } if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { @@ -1079,12 +1220,12 @@ func (_c *Store_GetAPIKey_Call) Run(run func(ctx context.Context, apiKey string) return _c } -func (_c *Store_GetAPIKey_Call) Return(_a0 database.ApiKey, _a1 error) *Store_GetAPIKey_Call { +func (_c *Store_GetAPIKey_Call) Return(_a0 database.GetAPIKeyRow, _a1 error) *Store_GetAPIKey_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *Store_GetAPIKey_Call) RunAndReturn(run func(context.Context, string) (database.ApiKey, error)) *Store_GetAPIKey_Call { +func (_c *Store_GetAPIKey_Call) RunAndReturn(run func(context.Context, string) (database.GetAPIKeyRow, error)) *Store_GetAPIKey_Call { _c.Call.Return(run) return _c } @@ -1264,6 +1405,65 @@ func (_c *Store_GetCallAudioByID_Call) RunAndReturn(run func(context.Context, uu return _c } +// GetCallSubmitter provides a mock function with given fields: ctx, id +func (_m *Store) GetCallSubmitter(ctx context.Context, id uuid.UUID) (*int32, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for GetCallSubmitter") + } + + var r0 *int32 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) (*int32, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) *int32); ok { + r0 = rf(ctx, id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*int32) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, uuid.UUID) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Store_GetCallSubmitter_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCallSubmitter' +type Store_GetCallSubmitter_Call struct { + *mock.Call +} + +// GetCallSubmitter is a helper method to define mock.On call +// - ctx context.Context +// - id uuid.UUID +func (_e *Store_Expecter) GetCallSubmitter(ctx interface{}, id interface{}) *Store_GetCallSubmitter_Call { + return &Store_GetCallSubmitter_Call{Call: _e.mock.On("GetCallSubmitter", ctx, id)} +} + +func (_c *Store_GetCallSubmitter_Call) Run(run func(ctx context.Context, id uuid.UUID)) *Store_GetCallSubmitter_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(uuid.UUID)) + }) + return _c +} + +func (_c *Store_GetCallSubmitter_Call) Return(_a0 *int32, _a1 error) *Store_GetCallSubmitter_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Store_GetCallSubmitter_Call) RunAndReturn(run func(context.Context, uuid.UUID) (*int32, error)) *Store_GetCallSubmitter_Call { + _c.Call.Return(run) + return _c +} + // GetDatabaseSize provides a mock function with given fields: ctx func (_m *Store) GetDatabaseSize(ctx context.Context) (string, error) { ret := _m.Called(ctx) @@ -1436,6 +1636,120 @@ func (_c *Store_GetIncidentCalls_Call) RunAndReturn(run func(context.Context, uu return _c } +// GetIncidentOwner provides a mock function with given fields: ctx, id +func (_m *Store) GetIncidentOwner(ctx context.Context, id uuid.UUID) (int, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for GetIncidentOwner") + } + + var r0 int + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) (int, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) int); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Get(0).(int) + } + + if rf, ok := ret.Get(1).(func(context.Context, uuid.UUID) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Store_GetIncidentOwner_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetIncidentOwner' +type Store_GetIncidentOwner_Call struct { + *mock.Call +} + +// GetIncidentOwner is a helper method to define mock.On call +// - ctx context.Context +// - id uuid.UUID +func (_e *Store_Expecter) GetIncidentOwner(ctx interface{}, id interface{}) *Store_GetIncidentOwner_Call { + return &Store_GetIncidentOwner_Call{Call: _e.mock.On("GetIncidentOwner", ctx, id)} +} + +func (_c *Store_GetIncidentOwner_Call) Run(run func(ctx context.Context, id uuid.UUID)) *Store_GetIncidentOwner_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(uuid.UUID)) + }) + return _c +} + +func (_c *Store_GetIncidentOwner_Call) Return(_a0 int, _a1 error) *Store_GetIncidentOwner_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Store_GetIncidentOwner_Call) RunAndReturn(run func(context.Context, uuid.UUID) (int, error)) *Store_GetIncidentOwner_Call { + _c.Call.Return(run) + return _c +} + +// GetShare provides a mock function with given fields: ctx, id +func (_m *Store) GetShare(ctx context.Context, id string) (database.Share, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for GetShare") + } + + var r0 database.Share + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (database.Share, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, string) database.Share); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Get(0).(database.Share) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Store_GetShare_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetShare' +type Store_GetShare_Call struct { + *mock.Call +} + +// GetShare is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *Store_Expecter) GetShare(ctx interface{}, id interface{}) *Store_GetShare_Call { + return &Store_GetShare_Call{Call: _e.mock.On("GetShare", ctx, id)} +} + +func (_c *Store_GetShare_Call) Run(run func(ctx context.Context, id string)) *Store_GetShare_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *Store_GetShare_Call) Return(_a0 database.Share, _a1 error) *Store_GetShare_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Store_GetShare_Call) RunAndReturn(run func(context.Context, string) (database.Share, error)) *Store_GetShare_Call { + _c.Call.Return(run) + return _c +} + // GetSystemName provides a mock function with given fields: ctx, systemID func (_m *Store) GetSystemName(ctx context.Context, systemID int) (string, error) { ret := _m.Called(ctx, systemID) @@ -2433,63 +2747,6 @@ func (_c *Store_GetUserByID_Call) RunAndReturn(run func(context.Context, int) (d return _c } -// GetUserByUID provides a mock function with given fields: ctx, id -func (_m *Store) GetUserByUID(ctx context.Context, id int) (database.User, error) { - ret := _m.Called(ctx, id) - - if len(ret) == 0 { - panic("no return value specified for GetUserByUID") - } - - var r0 database.User - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, int) (database.User, error)); ok { - return rf(ctx, id) - } - if rf, ok := ret.Get(0).(func(context.Context, int) database.User); ok { - r0 = rf(ctx, id) - } else { - r0 = ret.Get(0).(database.User) - } - - if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { - r1 = rf(ctx, id) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// Store_GetUserByUID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetUserByUID' -type Store_GetUserByUID_Call struct { - *mock.Call -} - -// GetUserByUID is a helper method to define mock.On call -// - ctx context.Context -// - id int -func (_e *Store_Expecter) GetUserByUID(ctx interface{}, id interface{}) *Store_GetUserByUID_Call { - return &Store_GetUserByUID_Call{Call: _e.mock.On("GetUserByUID", ctx, id)} -} - -func (_c *Store_GetUserByUID_Call) Run(run func(ctx context.Context, id int)) *Store_GetUserByUID_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(int)) - }) - return _c -} - -func (_c *Store_GetUserByUID_Call) Return(_a0 database.User, _a1 error) *Store_GetUserByUID_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *Store_GetUserByUID_Call) RunAndReturn(run func(context.Context, int) (database.User, error)) *Store_GetUserByUID_Call { - _c.Call.Return(run) - return _c -} - // GetUserByUsername provides a mock function with given fields: ctx, username func (_m *Store) GetUserByUsername(ctx context.Context, username string) (database.User, error) { ret := _m.Called(ctx, username) @@ -2887,6 +3144,52 @@ func (_c *Store_ListIncidentsP_Call) RunAndReturn(run func(context.Context, data return _c } +// PruneShares provides a mock function with given fields: ctx +func (_m *Store) PruneShares(ctx context.Context) error { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for PruneShares") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Store_PruneShares_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PruneShares' +type Store_PruneShares_Call struct { + *mock.Call +} + +// PruneShares is a helper method to define mock.On call +// - ctx context.Context +func (_e *Store_Expecter) PruneShares(ctx interface{}) *Store_PruneShares_Call { + return &Store_PruneShares_Call{Call: _e.mock.On("PruneShares", ctx)} +} + +func (_c *Store_PruneShares_Call) Run(run func(ctx context.Context)) *Store_PruneShares_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *Store_PruneShares_Call) Return(_a0 error) *Store_PruneShares_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Store_PruneShares_Call) RunAndReturn(run func(context.Context) error) *Store_PruneShares_Call { + _c.Call.Return(run) + return _c +} + // RemoveFromIncident provides a mock function with given fields: ctx, iD, callIds func (_m *Store) RemoveFromIncident(ctx context.Context, iD uuid.UUID, callIds []uuid.UUID) error { ret := _m.Called(ctx, iD, callIds) @@ -3505,6 +3808,65 @@ func (_c *Store_UpdateTalkgroup_Call) RunAndReturn(run func(context.Context, dat return _c } +// UpdateUser provides a mock function with given fields: ctx, username, email, isAdmin +func (_m *Store) UpdateUser(ctx context.Context, username string, email *string, isAdmin *bool) (database.User, error) { + ret := _m.Called(ctx, username, email, isAdmin) + + if len(ret) == 0 { + panic("no return value specified for UpdateUser") + } + + var r0 database.User + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, *string, *bool) (database.User, error)); ok { + return rf(ctx, username, email, isAdmin) + } + if rf, ok := ret.Get(0).(func(context.Context, string, *string, *bool) database.User); ok { + r0 = rf(ctx, username, email, isAdmin) + } else { + r0 = ret.Get(0).(database.User) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, *string, *bool) error); ok { + r1 = rf(ctx, username, email, isAdmin) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Store_UpdateUser_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateUser' +type Store_UpdateUser_Call struct { + *mock.Call +} + +// UpdateUser is a helper method to define mock.On call +// - ctx context.Context +// - username string +// - email *string +// - isAdmin *bool +func (_e *Store_Expecter) UpdateUser(ctx interface{}, username interface{}, email interface{}, isAdmin interface{}) *Store_UpdateUser_Call { + return &Store_UpdateUser_Call{Call: _e.mock.On("UpdateUser", ctx, username, email, isAdmin)} +} + +func (_c *Store_UpdateUser_Call) Run(run func(ctx context.Context, username string, email *string, isAdmin *bool)) *Store_UpdateUser_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(*string), args[3].(*bool)) + }) + return _c +} + +func (_c *Store_UpdateUser_Call) Return(_a0 database.User, _a1 error) *Store_UpdateUser_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Store_UpdateUser_Call) RunAndReturn(run func(context.Context, string, *string, *bool) (database.User, error)) *Store_UpdateUser_Call { + _c.Call.Return(run) + return _c +} + // UpsertTalkgroup provides a mock function with given fields: ctx, arg func (_m *Store) UpsertTalkgroup(ctx context.Context, arg []database.UpsertTalkgroupParams) *database.UpsertTalkgroupBatchResults { ret := _m.Called(ctx, arg) diff --git a/pkg/database/models.go b/pkg/database/models.go index 66ff0ed..5080c95 100644 --- a/pkg/database/models.go +++ b/pkg/database/models.go @@ -58,6 +58,7 @@ type Call struct { type Incident struct { ID uuid.UUID `json:"id,omitempty"` Name string `json:"name,omitempty"` + Owner int `json:"owner,omitempty"` Description *string `json:"description,omitempty"` StartTime pgtype.Timestamptz `json:"start_time,omitempty"` EndTime pgtype.Timestamptz `json:"end_time,omitempty"` @@ -80,6 +81,14 @@ type Setting struct { Value []byte `json:"value,omitempty"` } +type Share struct { + ID string `json:"id,omitempty"` + EntityType string `json:"entity_type,omitempty"` + EntityID uuid.UUID `json:"entity_id,omitempty"` + Owner int `json:"owner,omitempty"` + Expiration pgtype.Timestamptz `json:"expiration,omitempty"` +} + type SweptCall struct { ID uuid.UUID `json:"id,omitempty"` Submitter *int32 `json:"submitter,omitempty"` diff --git a/pkg/database/partman/partman.go b/pkg/database/partman/partman.go index db9f3ca..d19683b 100644 --- a/pkg/database/partman/partman.go +++ b/pkg/database/partman/partman.go @@ -13,6 +13,7 @@ import ( "dynatron.me/x/stillbox/internal/isoweek" "dynatron.me/x/stillbox/pkg/config" "dynatron.me/x/stillbox/pkg/database" + "dynatron.me/x/stillbox/pkg/rbac" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" @@ -134,6 +135,7 @@ func New(db database.Store, cfg config.Partition) (*partman, error) { var _ PartitionManager = (*partman)(nil) func (pm *partman) Go(ctx context.Context) { + ctx = rbac.CtxWithSubject(ctx, &rbac.SystemServiceSubject{Name: "partman"}) tick := time.NewTicker(CheckInterval) select { diff --git a/pkg/database/querier.go b/pkg/database/querier.go index 4ef8f5e..d0a8553 100644 --- a/pkg/database/querier.go +++ b/pkg/database/querier.go @@ -19,20 +19,26 @@ type Querier interface { CleanupSweptCalls(ctx context.Context, rangeStart pgtype.Timestamptz, rangeEnd pgtype.Timestamptz) (int64, error) CreateAPIKey(ctx context.Context, owner int, expires pgtype.Timestamp, disabled *bool) (ApiKey, error) CreateIncident(ctx context.Context, arg CreateIncidentParams) (Incident, error) + CreateShare(ctx context.Context, arg CreateShareParams) error CreateSystem(ctx context.Context, iD int, name string) error CreateUser(ctx context.Context, arg CreateUserParams) (User, error) DeleteAPIKey(ctx context.Context, apiKey string) error + DeleteCall(ctx context.Context, id uuid.UUID) error DeleteIncident(ctx context.Context, id uuid.UUID) error + DeleteShare(ctx context.Context, id string) error DeleteSystem(ctx context.Context, id int) error DeleteTalkgroup(ctx context.Context, systemID int32, tGID int32) error DeleteUser(ctx context.Context, username string) error - GetAPIKey(ctx context.Context, apiKey string) (ApiKey, error) + GetAPIKey(ctx context.Context, apiKey string) (GetAPIKeyRow, error) GetAllTalkgroupTags(ctx context.Context) ([]string, error) GetAppPrefs(ctx context.Context, appName string, uid int) ([]byte, error) GetCallAudioByID(ctx context.Context, id uuid.UUID) (GetCallAudioByIDRow, error) + GetCallSubmitter(ctx context.Context, id uuid.UUID) (*int32, error) GetDatabaseSize(ctx context.Context) (string, error) GetIncident(ctx context.Context, id uuid.UUID) (Incident, error) GetIncidentCalls(ctx context.Context, id uuid.UUID) ([]GetIncidentCallsRow, error) + GetIncidentOwner(ctx context.Context, id uuid.UUID) (int, error) + GetShare(ctx context.Context, id string) (Share, error) GetSystemName(ctx context.Context, systemID int) (string, error) GetTalkgroup(ctx context.Context, systemID int32, tGID int32) (GetTalkgroupRow, error) GetTalkgroupIDsByTags(ctx context.Context, anyTags []string, allTags []string, notTags []string) ([]GetTalkgroupIDsByTagsRow, error) @@ -47,13 +53,13 @@ type Querier interface { GetTalkgroupsWithLearnedCount(ctx context.Context, filter *string) (int64, error) GetTalkgroupsWithLearnedP(ctx context.Context, arg GetTalkgroupsWithLearnedPParams) ([]GetTalkgroupsWithLearnedPRow, error) GetUserByID(ctx context.Context, id int) (User, error) - GetUserByUID(ctx context.Context, id int) (User, error) GetUserByUsername(ctx context.Context, username string) (User, error) GetUsers(ctx context.Context) ([]User, error) ListCallsCount(ctx context.Context, arg ListCallsCountParams) (int64, error) ListCallsP(ctx context.Context, arg ListCallsPParams) ([]ListCallsPRow, error) ListIncidentsCount(ctx context.Context, start pgtype.Timestamptz, end pgtype.Timestamptz, filter *string) (int64, error) ListIncidentsP(ctx context.Context, arg ListIncidentsPParams) ([]ListIncidentsPRow, error) + PruneShares(ctx context.Context) error RemoveFromIncident(ctx context.Context, iD uuid.UUID, callIds []uuid.UUID) error RestoreTalkgroupVersion(ctx context.Context, versionIds int) (Talkgroup, error) SetAppPrefs(ctx context.Context, appName string, prefs []byte, uid int) error @@ -67,6 +73,7 @@ type Querier interface { UpdateIncident(ctx context.Context, arg UpdateIncidentParams) (Incident, error) UpdatePassword(ctx context.Context, username string, password string) error UpdateTalkgroup(ctx context.Context, arg UpdateTalkgroupParams) (Talkgroup, error) + UpdateUser(ctx context.Context, username string, email *string, isAdmin *bool) (User, error) UpsertTalkgroup(ctx context.Context, arg []UpsertTalkgroupParams) *UpsertTalkgroupBatchResults } diff --git a/pkg/database/share.sql.go b/pkg/database/share.sql.go new file mode 100644 index 0000000..b7b76a3 --- /dev/null +++ b/pkg/database/share.sql.go @@ -0,0 +1,84 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 +// source: share.sql + +package database + +import ( + "context" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" +) + +const createShare = `-- name: CreateShare :exec +INSERT INTO shares ( + id, + entity_type, + entity_id, + owner, + expiration +) VALUES ($1, $2, $3, $4, $5) +` + +type CreateShareParams struct { + ID string `json:"id"` + EntityType string `json:"entity_type"` + EntityID uuid.UUID `json:"entity_id"` + Owner int `json:"owner"` + Expiration pgtype.Timestamptz `json:"expiration"` +} + +func (q *Queries) CreateShare(ctx context.Context, arg CreateShareParams) error { + _, err := q.db.Exec(ctx, createShare, + arg.ID, + arg.EntityType, + arg.EntityID, + arg.Owner, + arg.Expiration, + ) + return err +} + +const deleteShare = `-- name: DeleteShare :exec +DELETE FROM shares WHERE id = $1 +` + +func (q *Queries) DeleteShare(ctx context.Context, id string) error { + _, err := q.db.Exec(ctx, deleteShare, id) + return err +} + +const getShare = `-- name: GetShare :one +SELECT + id, + entity_type, + entity_id, + owner, + expiration +FROM shares +WHERE id = $1 +` + +func (q *Queries) GetShare(ctx context.Context, id string) (Share, error) { + row := q.db.QueryRow(ctx, getShare, id) + var i Share + err := row.Scan( + &i.ID, + &i.EntityType, + &i.EntityID, + &i.Owner, + &i.Expiration, + ) + return i, err +} + +const pruneShares = `-- name: PruneShares :exec +DELETE FROM shares WHERE expiration < NOW() +` + +func (q *Queries) PruneShares(ctx context.Context) error { + _, err := q.db.Exec(ctx, pruneShares) + return err +} diff --git a/pkg/database/users.sql.go b/pkg/database/users.sql.go index 94d144f..74f2b43 100644 --- a/pkg/database/users.sql.go +++ b/pkg/database/users.sql.go @@ -7,6 +7,7 @@ package database import ( "context" + "time" "github.com/jackc/pgx/v5/pgtype" ) @@ -91,12 +92,32 @@ func (q *Queries) DeleteUser(ctx context.Context, username string) error { } const getAPIKey = `-- name: GetAPIKey :one -SELECT id, owner, created_at, expires, disabled, api_key FROM api_keys WHERE api_key = $1 +SELECT + a.id, + a.owner, + a.created_at, + a.expires, + a.disabled, + a.api_key, + u.username +FROM api_keys a +JOIN users u ON (a.owner = u.id) +WHERE api_key = $1 ` -func (q *Queries) GetAPIKey(ctx context.Context, apiKey string) (ApiKey, error) { +type GetAPIKeyRow struct { + ID int `json:"id"` + Owner int `json:"owner"` + CreatedAt time.Time `json:"created_at"` + Expires pgtype.Timestamp `json:"expires"` + Disabled *bool `json:"disabled"` + ApiKey string `json:"api_key"` + Username string `json:"username"` +} + +func (q *Queries) GetAPIKey(ctx context.Context, apiKey string) (GetAPIKeyRow, error) { row := q.db.QueryRow(ctx, getAPIKey, apiKey) - var i ApiKey + var i GetAPIKeyRow err := row.Scan( &i.ID, &i.Owner, @@ -104,6 +125,7 @@ func (q *Queries) GetAPIKey(ctx context.Context, apiKey string) (ApiKey, error) &i.Expires, &i.Disabled, &i.ApiKey, + &i.Username, ) return i, err } @@ -121,7 +143,7 @@ func (q *Queries) GetAppPrefs(ctx context.Context, appName string, uid int) ([]b const getUserByID = `-- name: GetUserByID :one SELECT id, username, password, email, is_admin, prefs FROM users -WHERE id = $1 LIMIT 1 +WHERE id = $1 ` func (q *Queries) GetUserByID(ctx context.Context, id int) (User, error) { @@ -138,28 +160,9 @@ func (q *Queries) GetUserByID(ctx context.Context, id int) (User, error) { return i, err } -const getUserByUID = `-- name: GetUserByUID :one -SELECT id, username, password, email, is_admin, prefs FROM users -WHERE id = $1 LIMIT 1 -` - -func (q *Queries) GetUserByUID(ctx context.Context, id int) (User, error) { - row := q.db.QueryRow(ctx, getUserByUID, id) - var i User - err := row.Scan( - &i.ID, - &i.Username, - &i.Password, - &i.Email, - &i.IsAdmin, - &i.Prefs, - ) - return i, err -} - const getUserByUsername = `-- name: GetUserByUsername :one SELECT id, username, password, email, is_admin, prefs FROM users -WHERE username = $1 LIMIT 1 +WHERE username = $1 ` func (q *Queries) GetUserByUsername(ctx context.Context, username string) (User, error) { @@ -224,3 +227,26 @@ func (q *Queries) UpdatePassword(ctx context.Context, username string, password _, err := q.db.Exec(ctx, updatePassword, username, password) return err } + +const updateUser = `-- name: UpdateUser :one +UPDATE users SET + email = COALESCE($2, email), + is_admin = COALESCE($3, is_admin) +WHERE + username = $1 +RETURNING id, username, password, email, is_admin, prefs +` + +func (q *Queries) UpdateUser(ctx context.Context, username string, email *string, isAdmin *bool) (User, error) { + row := q.db.QueryRow(ctx, updateUser, username, email, isAdmin) + var i User + err := row.Scan( + &i.ID, + &i.Username, + &i.Password, + &i.Email, + &i.IsAdmin, + &i.Prefs, + ) + return i, err +} diff --git a/pkg/incidents/incident.go b/pkg/incidents/incident.go index c2ee068..b48f152 100644 --- a/pkg/incidents/incident.go +++ b/pkg/incidents/incident.go @@ -5,11 +5,14 @@ import ( "dynatron.me/x/stillbox/internal/jsontypes" "dynatron.me/x/stillbox/pkg/calls" + "dynatron.me/x/stillbox/pkg/rbac" + "dynatron.me/x/stillbox/pkg/users" "github.com/google/uuid" ) type Incident struct { ID uuid.UUID `json:"id"` + Owner users.UserID `json:"owner"` Name string `json:"name"` Description *string `json:"description"` StartTime *jsontypes.Time `json:"startTime"` @@ -19,6 +22,10 @@ type Incident struct { Calls []IncidentCall `json:"calls"` } +func (inc *Incident) GetResourceName() string { + return rbac.ResourceIncident +} + type IncidentCall struct { calls.Call Notes json.RawMessage `json:"notes"` diff --git a/pkg/incidents/incstore/store.go b/pkg/incidents/incstore/store.go index 0a4ba6e..1ba52fa 100644 --- a/pkg/incidents/incstore/store.go +++ b/pkg/incidents/incstore/store.go @@ -6,10 +6,11 @@ import ( "dynatron.me/x/stillbox/internal/common" "dynatron.me/x/stillbox/internal/jsontypes" - "dynatron.me/x/stillbox/pkg/auth" "dynatron.me/x/stillbox/pkg/calls" "dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/incidents" + "dynatron.me/x/stillbox/pkg/rbac" + "dynatron.me/x/stillbox/pkg/users" "github.com/google/uuid" "github.com/jackc/pgx/v5" ) @@ -72,6 +73,11 @@ func NewStore() Store { } func (s *store) CreateIncident(ctx context.Context, inc incidents.Incident) (*incidents.Incident, error) { + user, err := users.UserCheck(ctx, new(incidents.Incident), "create") + if err != nil { + return nil, err + } + db := database.FromCtx(ctx) var dbInc database.Incident @@ -81,6 +87,7 @@ func (s *store) CreateIncident(ctx context.Context, inc incidents.Incident) (*in var err error dbInc, err = db.CreateIncident(ctx, database.CreateIncidentParams{ ID: id, + Owner: user.ID.Int(), Name: inc.Name, Description: inc.Description, StartTime: inc.StartTime.PGTypeTSTZ(), @@ -125,6 +132,16 @@ func (s *store) CreateIncident(ctx context.Context, inc incidents.Incident) (*in } func (s *store) AddRemoveIncidentCalls(ctx context.Context, incidentID uuid.UUID, addCallIDs []uuid.UUID, notes []byte, removeCallIDs []uuid.UUID) error { + inc, err := s.getIncidentOwner(ctx, incidentID) + if err != nil { + return err + } + + _, err = rbac.Check(ctx, &inc, rbac.WithActions(rbac.ActionUpdate)) + if err != nil { + return err + } + return database.FromCtx(ctx).InTx(ctx, func(db database.Store) error { if len(addCallIDs) > 0 { var noteAr [][]byte @@ -153,6 +170,10 @@ func (s *store) AddRemoveIncidentCalls(ctx context.Context, incidentID uuid.UUID } func (s *store) Incidents(ctx context.Context, p IncidentsParams) (incs []Incident, totalCount int, err error) { + _, err = rbac.Check(ctx, new(incidents.Incident), rbac.WithActions(rbac.ActionRead)) + if err != nil { + return nil, 0, err + } db := database.FromCtx(ctx) offset, perPage := p.Pagination.OffsetPerPage(100) @@ -196,6 +217,7 @@ func (s *store) Incidents(ctx context.Context, p IncidentsParams) (incs []Incide func fromDBIncident(id uuid.UUID, d database.Incident) incidents.Incident { return incidents.Incident{ ID: id, + Owner: users.UserID(d.Owner), Name: d.Name, Description: d.Description, StartTime: jsontypes.TimePtrFromTSTZ(d.StartTime), @@ -214,6 +236,7 @@ func fromDBListInPRow(id uuid.UUID, d database.ListIncidentsPRow) Incident { return Incident{ Incident: incidents.Incident{ ID: id, + Owner: users.UserID(d.Owner), Name: d.Name, Description: d.Description, StartTime: jsontypes.TimePtrFromTSTZ(d.StartTime), @@ -228,7 +251,7 @@ func fromDBCalls(d []database.GetIncidentCallsRow) []incidents.IncidentCall { r := make([]incidents.IncidentCall, 0, len(d)) for _, v := range d { dur := calls.CallDuration(time.Duration(common.ZeroIfNil(v.Duration)) * time.Millisecond) - sub := common.PtrTo(auth.UserID(common.ZeroIfNil(v.Submitter))) + sub := common.PtrTo(users.UserID(common.ZeroIfNil(v.Submitter))) r = append(r, incidents.IncidentCall{ Call: calls.Call{ ID: v.CallID, @@ -252,6 +275,11 @@ func fromDBCalls(d []database.GetIncidentCallsRow) []incidents.IncidentCall { } func (s *store) Incident(ctx context.Context, id uuid.UUID) (*incidents.Incident, error) { + _, err := rbac.Check(ctx, new(incidents.Incident), rbac.WithActions(rbac.ActionRead)) + if err != nil { + return nil, err + } + var r incidents.Incident txErr := database.FromCtx(ctx).InTx(ctx, func(db database.Store) error { inc, err := db.GetIncident(ctx, id) @@ -298,6 +326,16 @@ func (uip UpdateIncidentParams) toDBUIP(id uuid.UUID) database.UpdateIncidentPar } func (s *store) UpdateIncident(ctx context.Context, id uuid.UUID, p UpdateIncidentParams) (*incidents.Incident, error) { + ckinc, err := s.getIncidentOwner(ctx, id) + if err != nil { + return nil, err + } + + _, err = rbac.Check(ctx, &ckinc, rbac.WithActions(rbac.ActionUpdate)) + if err != nil { + return nil, err + } + db := database.FromCtx(ctx) dbInc, err := db.UpdateIncident(ctx, p.toDBUIP(id)) @@ -311,9 +349,24 @@ func (s *store) UpdateIncident(ctx context.Context, id uuid.UUID, p UpdateIncide } func (s *store) DeleteIncident(ctx context.Context, id uuid.UUID) error { + inc, err := s.getIncidentOwner(ctx, id) + if err != nil { + return err + } + + _, err = rbac.Check(ctx, &inc, rbac.WithActions(rbac.ActionDelete)) + if err != nil { + return err + } + return database.FromCtx(ctx).DeleteIncident(ctx, id) } func (s *store) UpdateNotes(ctx context.Context, incidentID uuid.UUID, callID uuid.UUID, notes []byte) error { return database.FromCtx(ctx).UpdateCallIncidentNotes(ctx, notes, incidentID, callID) } + +func (s *store) getIncidentOwner(ctx context.Context, id uuid.UUID) (incidents.Incident, error) { + owner, err := database.FromCtx(ctx).GetIncidentOwner(ctx, id) + return incidents.Incident{ID: id, Owner: users.UserID(owner)}, err +} diff --git a/pkg/nexus/nexus.go b/pkg/nexus/nexus.go index 14d52ad..fcfe056 100644 --- a/pkg/nexus/nexus.go +++ b/pkg/nexus/nexus.go @@ -6,6 +6,7 @@ import ( "dynatron.me/x/stillbox/pkg/calls" "dynatron.me/x/stillbox/pkg/pb" + "dynatron.me/x/stillbox/pkg/rbac" "github.com/rs/zerolog/log" ) @@ -38,6 +39,7 @@ func New() *Nexus { } func (n *Nexus) Go(ctx context.Context) { + ctx = rbac.CtxWithSubject(ctx, &rbac.SystemServiceSubject{Name: "nexus"}) for { select { case call, ok := <-n.callCh: diff --git a/pkg/rbac/mocks/RBAC.go b/pkg/rbac/mocks/RBAC.go new file mode 100644 index 0000000..d7de98b --- /dev/null +++ b/pkg/rbac/mocks/RBAC.go @@ -0,0 +1,113 @@ +// Code generated by mockery v2.47.0. DO NOT EDIT. + +package mocks + +import ( + context "context" + + rbac "dynatron.me/x/stillbox/pkg/rbac" + mock "github.com/stretchr/testify/mock" + + restrict "github.com/el-mike/restrict/v2" +) + +// RBAC is an autogenerated mock type for the RBAC type +type RBAC struct { + mock.Mock +} + +type RBAC_Expecter struct { + mock *mock.Mock +} + +func (_m *RBAC) EXPECT() *RBAC_Expecter { + return &RBAC_Expecter{mock: &_m.Mock} +} + +// Check provides a mock function with given fields: ctx, res, opts +func (_m *RBAC) Check(ctx context.Context, res restrict.Resource, opts ...rbac.CheckOption) (rbac.Subject, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, res) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for Check") + } + + var r0 rbac.Subject + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, restrict.Resource, ...rbac.CheckOption) (rbac.Subject, error)); ok { + return rf(ctx, res, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, restrict.Resource, ...rbac.CheckOption) rbac.Subject); ok { + r0 = rf(ctx, res, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(rbac.Subject) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, restrict.Resource, ...rbac.CheckOption) error); ok { + r1 = rf(ctx, res, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RBAC_Check_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Check' +type RBAC_Check_Call struct { + *mock.Call +} + +// Check is a helper method to define mock.On call +// - ctx context.Context +// - res restrict.Resource +// - opts ...rbac.CheckOption +func (_e *RBAC_Expecter) Check(ctx interface{}, res interface{}, opts ...interface{}) *RBAC_Check_Call { + return &RBAC_Check_Call{Call: _e.mock.On("Check", + append([]interface{}{ctx, res}, opts...)...)} +} + +func (_c *RBAC_Check_Call) Run(run func(ctx context.Context, res restrict.Resource, opts ...rbac.CheckOption)) *RBAC_Check_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]rbac.CheckOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(rbac.CheckOption) + } + } + run(args[0].(context.Context), args[1].(restrict.Resource), variadicArgs...) + }) + return _c +} + +func (_c *RBAC_Check_Call) Return(_a0 rbac.Subject, _a1 error) *RBAC_Check_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *RBAC_Check_Call) RunAndReturn(run func(context.Context, restrict.Resource, ...rbac.CheckOption) (rbac.Subject, error)) *RBAC_Check_Call { + _c.Call.Return(run) + return _c +} + +// NewRBAC creates a new instance of RBAC. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewRBAC(t interface { + mock.TestingT + Cleanup(func()) +}) *RBAC { + mock := &RBAC{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/rbac/rbac.go b/pkg/rbac/rbac.go new file mode 100644 index 0000000..b626f34 --- /dev/null +++ b/pkg/rbac/rbac.go @@ -0,0 +1,421 @@ +package rbac + +import ( + "context" + "errors" + "fmt" + "reflect" + + "github.com/el-mike/restrict/v2" + "github.com/el-mike/restrict/v2/adapters" +) + +const ( + RoleUser = "User" + RoleSubmitter = "Submitter" + RoleAdmin = "Admin" + RoleSystem = "System" + RolePublic = "Public" + RoleShareGuest = "ShareGuest" + + ResourceCall = "Call" + ResourceIncident = "Incident" + ResourceTalkgroup = "Talkgroup" + ResourceAlert = "Alert" + ResourceShare = "Share" + ResourceAPIKey = "APIKey" + + ActionRead = "read" + ActionCreate = "create" + ActionUpdate = "update" + ActionDelete = "delete" + + PresetUpdateOwn = "updateOwn" + PresetDeleteOwn = "deleteOwn" + PresetReadShared = "readShared" + + PresetUpdateSubmitter = "updateSubmitter" + PresetDeleteSubmitter = "deleteSubmitter" +) + +var ( + ErrBadSubject = errors.New("bad subject in token") +) + +type subjectContextKey string + +const SubjectCtxKey subjectContextKey = "sub" + +func CtxWithSubject(ctx context.Context, sub Subject) context.Context { + return context.WithValue(ctx, SubjectCtxKey, sub) +} + +func ErrAccessDenied(err error) *restrict.AccessDeniedError { + if accessErr, ok := err.(*restrict.AccessDeniedError); ok { + return accessErr + } + + return nil +} + +func SubjectFrom(ctx context.Context) Subject { + sub, ok := ctx.Value(SubjectCtxKey).(Subject) + if ok { + return sub + } + + return new(PublicSubject) +} + +type rbacCtxKey string + +const RBACCtxKey rbacCtxKey = "rbac" + +func FromCtx(ctx context.Context) RBAC { + rbac, ok := ctx.Value(RBACCtxKey).(RBAC) + if !ok { + panic("no RBAC in context") + } + + return rbac +} + +func CtxWithRBAC(ctx context.Context, rbac RBAC) context.Context { + return context.WithValue(ctx, RBACCtxKey, rbac) +} + +var ( + ErrNotAuthorized = errors.New("not authorized") +) + +var policy = &restrict.PolicyDefinition{ + Roles: restrict.Roles{ + RoleUser: { + Description: "An authenticated user", + Grants: restrict.GrantsMap{ + ResourceIncident: { + &restrict.Permission{Action: ActionRead}, + &restrict.Permission{Action: ActionCreate}, + &restrict.Permission{Preset: PresetUpdateOwn}, + &restrict.Permission{Preset: PresetDeleteOwn}, + }, + ResourceCall: { + &restrict.Permission{Action: ActionRead}, + &restrict.Permission{Action: ActionCreate}, + &restrict.Permission{Preset: PresetUpdateSubmitter}, + &restrict.Permission{Preset: PresetDeleteSubmitter}, + }, + ResourceTalkgroup: { + &restrict.Permission{Action: ActionRead}, + }, + ResourceShare: { + &restrict.Permission{Action: ActionRead}, + &restrict.Permission{Action: ActionCreate}, + &restrict.Permission{Preset: PresetUpdateOwn}, + &restrict.Permission{Preset: PresetDeleteOwn}, + }, + }, + }, + RoleSubmitter: { + Description: "A role that can submit calls", + Grants: restrict.GrantsMap{ + ResourceCall: { + &restrict.Permission{Action: ActionCreate}, + }, + ResourceTalkgroup: { + // for learning TGs + &restrict.Permission{Action: ActionCreate}, + &restrict.Permission{Action: ActionUpdate}, + }, + }, + }, + RoleShareGuest: { + Description: "Someone who has a valid share link", + Grants: restrict.GrantsMap{ + ResourceCall: { + &restrict.Permission{Preset: PresetReadShared}, + }, + ResourceIncident: { + &restrict.Permission{Preset: PresetReadShared}, + }, + ResourceTalkgroup: { + &restrict.Permission{Action: ActionRead}, + }, + }, + }, + RoleAdmin: { + Parents: []string{RoleUser}, + Grants: restrict.GrantsMap{ + ResourceIncident: { + &restrict.Permission{Action: ActionUpdate}, + &restrict.Permission{Action: ActionDelete}, + }, + ResourceCall: { + &restrict.Permission{Action: ActionUpdate}, + &restrict.Permission{Action: ActionDelete}, + }, + ResourceTalkgroup: { + &restrict.Permission{Action: ActionUpdate}, + &restrict.Permission{Action: ActionCreate}, + &restrict.Permission{Action: ActionDelete}, + }, + }, + }, + RoleSystem: { + Parents: []string{RoleSystem}, + }, + RolePublic: { + /* + Grants: restrict.GrantsMap{ + ResourceShare: { + &restrict.Permission{Action: ActionRead}, + }, + }, + */ + }, + }, + PermissionPresets: restrict.PermissionPresets{ + PresetUpdateOwn: &restrict.Permission{ + Action: ActionUpdate, + Conditions: restrict.Conditions{ + &restrict.EqualCondition{ + ID: "isOwner", + Left: &restrict.ValueDescriptor{ + Source: restrict.ResourceField, + Field: "Owner", + }, + Right: &restrict.ValueDescriptor{ + Source: restrict.SubjectField, + Field: "ID", + }, + }, + }, + }, + PresetDeleteOwn: &restrict.Permission{ + Action: ActionDelete, + Conditions: restrict.Conditions{ + &restrict.EqualCondition{ + ID: "isOwner", + Left: &restrict.ValueDescriptor{ + Source: restrict.ResourceField, + Field: "Owner", + }, + Right: &restrict.ValueDescriptor{ + Source: restrict.SubjectField, + Field: "ID", + }, + }, + }, + }, + PresetUpdateSubmitter: &restrict.Permission{ + Action: ActionUpdate, + Conditions: restrict.Conditions{ + &SubmitterEqualCondition{ + ID: "isSubmitter", + Left: &restrict.ValueDescriptor{ + Source: restrict.ResourceField, + Field: "Submitter", + }, + Right: &restrict.ValueDescriptor{ + Source: restrict.SubjectField, + Field: "ID", + }, + }, + }, + }, + PresetDeleteSubmitter: &restrict.Permission{ + Action: ActionDelete, + Conditions: restrict.Conditions{ + &SubmitterEqualCondition{ + ID: "isSubmitter", + Left: &restrict.ValueDescriptor{ + Source: restrict.ResourceField, + Field: "Submitter", + }, + Right: &restrict.ValueDescriptor{ + Source: restrict.SubjectField, + Field: "ID", + }, + }, + }, + }, + PresetReadShared: &restrict.Permission{ + Action: ActionRead, + Conditions: restrict.Conditions{ + &restrict.EqualCondition{ + ID: "isOwner", + Left: &restrict.ValueDescriptor{ + Source: restrict.ContextField, + Field: "Owner", + }, + Right: &restrict.ValueDescriptor{ + Source: restrict.SubjectField, + Field: "ID", + }, + }, + }, + }, + }, +} + +type checkOptions struct { + actions []string + context restrict.Context +} + +type CheckOption func(*checkOptions) + +func WithActions(actions ...string) CheckOption { + return func(o *checkOptions) { + o.actions = append(o.actions, actions...) + } +} + +func WithContext(ctx restrict.Context) CheckOption { + return func(o *checkOptions) { + o.context = ctx + } +} + +func UseResource(rsc string) restrict.Resource { + return restrict.UseResource(rsc) +} + +type Subject interface { + restrict.Subject + GetName() string +} + +type Resource interface { + restrict.Resource +} + +type RBAC interface { + Check(ctx context.Context, res restrict.Resource, opts ...CheckOption) (Subject, error) +} + +type rbac struct { + policy *restrict.PolicyManager + access *restrict.AccessManager +} + +func New() (*rbac, error) { + adapter := adapters.NewInMemoryAdapter(policy) + polMan, err := restrict.NewPolicyManager(adapter, true) + if err != nil { + return nil, err + } + + accMan := restrict.NewAccessManager(polMan) + return &rbac{ + policy: polMan, + access: accMan, + }, nil +} + +// Check is a convenience function to pull the RBAC instance out of ctx and Check. +func Check(ctx context.Context, res restrict.Resource, opts ...CheckOption) (Subject, error) { + return FromCtx(ctx).Check(ctx, res, opts...) +} + +func (r *rbac) Check(ctx context.Context, res restrict.Resource, opts ...CheckOption) (Subject, error) { + sub := SubjectFrom(ctx) + o := checkOptions{} + + for _, opt := range opts { + opt(&o) + } + + req := &restrict.AccessRequest{ + Subject: sub, + Resource: res, + Actions: o.actions, + Context: o.context, + } + + return sub, r.access.Authorize(req) +} + +type ShareLinkGuest struct { + ShareID string +} + +func (s *ShareLinkGuest) GetName() string { + return "SHARE:" + s.ShareID +} + +func (s *ShareLinkGuest) GetRoles() []string { + return []string{RoleShareGuest} +} + +type PublicSubject struct { + RemoteAddr string +} + +func (s *PublicSubject) GetName() string { + return "PUBLIC:" + s.RemoteAddr +} + +func (s *PublicSubject) GetRoles() []string { + return []string{RolePublic} +} + +type SystemServiceSubject struct { + Name string +} + +func (s *SystemServiceSubject) GetName() string { + return "SYSTEM:" + s.Name +} + +func (s *SystemServiceSubject) GetRoles() []string { + return []string{RoleSystem} +} + +const ( + SubmitterEqualConditionType = "SUBMITTER_EQUAL" +) + +type SubmitterEqualCondition struct { + ID string `json:"name,omitempty" yaml:"name,omitempty"` + Left *restrict.ValueDescriptor `json:"left" yaml:"left"` + Right *restrict.ValueDescriptor `json:"right" yaml:"right"` +} + +func (s *SubmitterEqualCondition) Type() string { + return SubmitterEqualConditionType +} + +func (c *SubmitterEqualCondition) Check(r *restrict.AccessRequest) error { + left, err := c.Left.GetValue(r) + if err != nil { + return err + } + + right, err := c.Right.GetValue(r) + if err != nil { + return err + } + + lVal := reflect.ValueOf(left) + rVal := reflect.ValueOf(right) + + // deref Left. this is the difference between us and EqualCondition + for lVal.Kind() == reflect.Pointer { + lVal = lVal.Elem() + } + + if !lVal.IsValid() || !reflect.DeepEqual(rVal.Interface(), lVal.Interface()) { + return restrict.NewConditionNotSatisfiedError(c, r, fmt.Errorf("values \"%v\" and \"%v\" are not equal", left, right)) + } + + return nil +} + +func SubmitterEqualConditionFactory() restrict.Condition { + return new(SubmitterEqualCondition) +} + +func init() { + restrict.RegisterConditionFactory(SubmitterEqualConditionType, SubmitterEqualConditionFactory) +} diff --git a/pkg/rbac/rbac_test.go b/pkg/rbac/rbac_test.go new file mode 100644 index 0000000..582732d --- /dev/null +++ b/pkg/rbac/rbac_test.go @@ -0,0 +1,197 @@ +package rbac_test + +import ( + "context" + "errors" + "fmt" + "testing" + + "dynatron.me/x/stillbox/internal/common" + "dynatron.me/x/stillbox/pkg/calls" + "dynatron.me/x/stillbox/pkg/incidents" + "dynatron.me/x/stillbox/pkg/rbac" + "dynatron.me/x/stillbox/pkg/talkgroups" + "dynatron.me/x/stillbox/pkg/users" + "github.com/el-mike/restrict/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRBAC(t *testing.T) { + tests := []struct { + name string + subject rbac.Subject + resource rbac.Resource + action string + expectErr error + }{ + { + name: "admin update talkgroup", + subject: &users.User{ + ID: 2, + IsAdmin: true, + }, + resource: &talkgroups.Talkgroup{}, + action: rbac.ActionUpdate, + expectErr: nil, + }, + { + name: "admin update incident", + subject: &users.User{ + ID: 2, + IsAdmin: true, + }, + resource: &incidents.Incident{ + Name: "test incident", + Owner: 4, + }, + action: rbac.ActionUpdate, + expectErr: nil, + }, + { + name: "user update incident not owner", + subject: &users.User{ + ID: 2, + }, + resource: &incidents.Incident{ + Name: "test incident", + Owner: 4, + }, + action: rbac.ActionUpdate, + expectErr: errors.New(`access denied for Action: "update" on Resource: "Incident"`), + }, + { + name: "user update incident owner", + subject: &users.User{ + ID: 2, + }, + resource: &incidents.Incident{ + Name: "test incident", + Owner: 2, + }, + action: rbac.ActionUpdate, + expectErr: nil, + }, + { + name: "user delete incident not owner", + subject: &users.User{ + ID: 2, + }, + resource: &incidents.Incident{ + Name: "test incident", + Owner: 6, + }, + action: rbac.ActionDelete, + expectErr: errors.New(`access denied for Action: "delete" on Resource: "Incident"`), + }, + { + name: "admin update call", + subject: &users.User{ + ID: 2, + IsAdmin: true, + }, + resource: &calls.Call{ + Submitter: common.PtrTo(users.UserID(4)), + }, + action: rbac.ActionUpdate, + expectErr: nil, + }, + { + name: "user update call not owner", + subject: &users.User{ + ID: 2, + }, + resource: &calls.Call{ + Submitter: common.PtrTo(users.UserID(4)), + }, + action: rbac.ActionUpdate, + expectErr: errors.New(`access denied for Action: "update" on Resource: "Call"`), + }, + { + name: "user update call owner", + subject: &users.User{ + ID: 2, + }, + resource: &calls.Call{ + Submitter: common.PtrTo(users.UserID(2)), + }, + action: rbac.ActionUpdate, + expectErr: nil, + }, + { + name: "user update call nil submitter", + subject: &users.User{ + ID: 2, + }, + resource: &calls.Call{ + Submitter: nil, + }, + action: rbac.ActionUpdate, + expectErr: errors.New(`access denied for Action: "update" on Resource: "Call"`), + }, + { + name: "user delete call not owner", + subject: &users.User{ + ID: 2, + }, + resource: &calls.Call{ + Submitter: common.PtrTo(users.UserID(6)), + }, + action: rbac.ActionDelete, + expectErr: errors.New(`access denied for Action: "delete" on Resource: "Call"`), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctx := rbac.CtxWithSubject(context.Background(), tc.subject) + rb, err := rbac.New() + require.NoError(t, err) + sub, err := rb.Check(ctx, tc.resource, rbac.WithActions(tc.action)) + if tc.expectErr != nil { + assert.Equal(t, tc.expectErr.Error(), err.Error()) + } else { + if !assert.NoError(t, err) { + accErr(err) + } + } + assert.Equal(t, tc.subject, sub) + }) + } +} + +func accErr(err error) { + if accessError, ok := err.(*restrict.AccessDeniedError); ok { + // Error() implementation. Returns a message in a form: "access denied for Action/s: ... on Resource: ..." + fmt.Println(accessError) + // Returns an AccessRequest that failed. + fmt.Println(accessError.Request) + // Returns first reason for the denied access. + // Especially helpful in fail-early mode, where there will only be one Reason. + fmt.Println(accessError.FirstReason()) + + // Reasons property will hold all errors that caused the access to be denied. + for _, permissionErr := range accessError.Reasons { + fmt.Println(permissionErr) + fmt.Println(permissionErr.Action) + fmt.Println(permissionErr.RoleName) + fmt.Println(permissionErr.ResourceName) + + // Returns first ConditionNotSatisfied error for given PermissionError, if any was returned for given PermissionError. + // Especially helpful in fail-early mode, where there will only be one failed Condition. + fmt.Println(permissionErr.FirstConditionError()) + + // ConditionErrors property will hold all ConditionNotSatisfied errors. + for _, conditionErr := range permissionErr.ConditionErrors { + fmt.Println(conditionErr) + fmt.Println(conditionErr.Reason) + + // Every ConditionNotSatisfied contains an instance of Condition that returned it, + // so it can be tested using type assertion to get more details about failed Condition. + if emptyCondition, ok := conditionErr.Condition.(*restrict.EmptyCondition); ok { + fmt.Println(emptyCondition.ID) + } + } + } + } +} diff --git a/pkg/rest/api.go b/pkg/rest/api.go index 50bed21..e94194b 100644 --- a/pkg/rest/api.go +++ b/pkg/rest/api.go @@ -6,6 +6,7 @@ import ( "net/url" "dynatron.me/x/stillbox/internal/common" + "dynatron.me/x/stillbox/pkg/rbac" "dynatron.me/x/stillbox/pkg/talkgroups/tgstore" "github.com/go-chi/chi/v5" @@ -37,6 +38,7 @@ func (a *api) Subrouter() http.Handler { r.Mount("/call", new(callsAPI).Subrouter()) r.Mount("/user", new(usersAPI).Subrouter()) r.Mount("/incident", newIncidentsAPI(&a.baseURL).Subrouter()) + r.Mount("/share", newShareHandler(&a.baseURL).Subrouter()) return r } @@ -82,6 +84,14 @@ func unauthErrText(err error) render.Renderer { } } +func forbiddenErrText(err error) render.Renderer { + return &errResponse{ + Err: err, + Code: http.StatusForbidden, + Error: "Forbidden: " + err.Error(), + } +} + func constraintErrText(err error) render.Renderer { return &errResponse{ Err: err, @@ -127,9 +137,10 @@ var statusMapping = map[error]errResponder{ ErrTGIDMismatch: badRequestErrText, ErrSysMismatch: badRequestErrText, tgstore.ErrReference: constraintErrText, - ErrBadUID: unauthErrText, + rbac.ErrBadSubject: unauthErrText, ErrBadAppName: unauthErrText, common.ErrPageOutOfRange: badRequestErrText, + rbac.ErrNotAuthorized: unauthErrText, } func autoError(err error) render.Renderer { @@ -144,6 +155,10 @@ func autoError(err error) render.Renderer { } } + if rbac.ErrAccessDenied(err) != nil { + return forbiddenErrText(err) + } + return internalError(err) } diff --git a/pkg/rest/share.go b/pkg/rest/share.go new file mode 100644 index 0000000..b829f52 --- /dev/null +++ b/pkg/rest/share.go @@ -0,0 +1,231 @@ +package rest + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/url" + + "dynatron.me/x/stillbox/internal/common" + "dynatron.me/x/stillbox/internal/forms" + "dynatron.me/x/stillbox/internal/jsontypes" + "dynatron.me/x/stillbox/pkg/incidents" + "dynatron.me/x/stillbox/pkg/incidents/incstore" + "dynatron.me/x/stillbox/pkg/talkgroups/tgstore" + + "github.com/go-chi/chi/v5" + "github.com/google/uuid" +) + +type shareAPI struct { + baseURL *url.URL +} + +func newShareHandler(baseURL *url.URL) API { + return &shareAPI{baseURL} +} + +func (ia *shareAPI) Subrouter() http.Handler { + r := chi.NewMux() + + //r.Get(`/{id:[A-Za-z0-9_-]{20,}}`, ia.getShare) + //r.Post('/create', ia.createShare) + //r.Delete(`/{id:[A-Za-z0-9_-]{20,}}`, ia.deleteShare) + //r.Get(`/`, ia.getShares) + + return r +} + +func (ia *shareAPI) listIncidents(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + incs := incstore.FromCtx(ctx) + + p := incstore.IncidentsParams{} + err := forms.Unmarshal(r, &p, forms.WithTag("json"), forms.WithAcceptBlank(), forms.WithOmitEmpty()) + if err != nil { + wErr(w, r, badRequest(err)) + return + } + + res := struct { + Incidents []incstore.Incident `json:"incidents"` + Count int `json:"count"` + }{} + + res.Incidents, res.Count, err = incs.Incidents(ctx, p) + if err != nil { + wErr(w, r, autoError(err)) + return + } + + respond(w, r, res) +} + +func (ia *shareAPI) createIncident(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + incs := incstore.FromCtx(ctx) + + p := incidents.Incident{} + err := forms.Unmarshal(r, &p, forms.WithTag("json"), forms.WithAcceptBlank(), forms.WithOmitEmpty()) + if err != nil { + wErr(w, r, badRequest(err)) + return + } + + inc, err := incs.CreateIncident(ctx, p) + if err != nil { + wErr(w, r, autoError(err)) + return + } + + respond(w, r, inc) +} + +func (ia *shareAPI) getIncident(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + incs := incstore.FromCtx(ctx) + + id, err := idOnlyParam(w, r) + if err != nil { + return + } + + inc, err := incs.Incident(ctx, id) + if err != nil { + wErr(w, r, autoError(err)) + return + } + + respond(w, r, inc) +} + +func (ia *shareAPI) updateIncident(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + incs := incstore.FromCtx(ctx) + + id, err := idOnlyParam(w, r) + if err != nil { + return + } + + p := incstore.UpdateIncidentParams{} + err = forms.Unmarshal(r, &p, forms.WithTag("json"), forms.WithAcceptBlank(), forms.WithOmitEmpty()) + if err != nil { + wErr(w, r, badRequest(err)) + return + } + + inc, err := incs.UpdateIncident(ctx, id, p) + if err != nil { + wErr(w, r, autoError(err)) + return + } + + respond(w, r, inc) +} + +func (ia *shareAPI) deleteIncident(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + incs := incstore.FromCtx(ctx) + + urlParams := struct { + ID uuid.UUID `param:"id"` + }{} + + err := decodeParams(&urlParams, r) + if err != nil { + wErr(w, r, badRequest(err)) + return + } + + err = incs.DeleteIncident(ctx, urlParams.ID) + if err != nil { + wErr(w, r, autoError(err)) + return + } + + w.WriteHeader(http.StatusNoContent) +} + +type CallIncidentParams2 struct { + Add jsontypes.UUIDs `json:"add"` + Notes json.RawMessage `json:"notes"` + + Remove jsontypes.UUIDs `json:"remove"` +} + +func (ia *shareAPI) postCalls(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + incs := incstore.FromCtx(ctx) + + id, err := idOnlyParam(w, r) + if err != nil { + return + } + + p := CallIncidentParams2{} + err = forms.Unmarshal(r, &p, forms.WithTag("json"), forms.WithAcceptBlank(), forms.WithOmitEmpty()) + if err != nil { + wErr(w, r, badRequest(err)) + return + } + + err = incs.AddRemoveIncidentCalls(ctx, id, p.Add.UUIDs(), p.Notes, p.Remove.UUIDs()) + if err != nil { + wErr(w, r, autoError(err)) + return + } + + w.WriteHeader(http.StatusNoContent) +} + +func (ia *shareAPI) getCallsM3U(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + incs := incstore.FromCtx(ctx) + tgst := tgstore.FromCtx(ctx) + + id, err := idOnlyParam(w, r) + if err != nil { + return + } + + inc, err := incs.Incident(ctx, id) + if err != nil { + wErr(w, r, autoError(err)) + return + } + + b := new(bytes.Buffer) + + callUrl := common.PtrTo(*ia.baseURL) + + b.WriteString("#EXTM3U\n\n") + for _, c := range inc.Calls { + tg, err := tgst.TG(ctx, c.TalkgroupTuple()) + if err != nil { + wErr(w, r, autoError(err)) + return + } + var from string + if c.Source != 0 { + from = fmt.Sprintf(" from %d", c.Source) + } + + callUrl.Path = "/api/call/" + c.ID.String() + + fmt.Fprintf(b, "#EXTINF:%d,%s%s (%s)\n%s\n\n", + c.Duration.Seconds(), + tg.StringTag(true), + from, + c.DateTime.Format("15:04 01/02"), + callUrl, + ) + } + + // Not a lot of agreement on which MIME type to use for non-HLS m3u, + // let's hope this is good enough + w.Header().Set("Content-Type", "audio/x-mpegurl") + w.WriteHeader(http.StatusOK) + _, _ = b.WriteTo(w) +} diff --git a/pkg/rest/users.go b/pkg/rest/users.go index 704088b..1fa6703 100644 --- a/pkg/rest/users.go +++ b/pkg/rest/users.go @@ -7,13 +7,13 @@ import ( "strings" "dynatron.me/x/stillbox/pkg/auth" + "dynatron.me/x/stillbox/pkg/rbac" "dynatron.me/x/stillbox/pkg/users" "github.com/go-chi/chi/v5" ) var ( - ErrBadUID = errors.New("bad UID in token") ErrBadAppName = errors.New("bad app name") ) @@ -32,10 +32,10 @@ func (ua *usersAPI) Subrouter() http.Handler { func (ua *usersAPI) getPrefs(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - uid := auth.UIDFrom(ctx) + username := auth.UsernameFrom(ctx) - if uid == nil { - wErr(w, r, autoError(ErrBadUID)) + if username == nil { + wErr(w, r, autoError(rbac.ErrBadSubject)) return } @@ -55,7 +55,7 @@ func (ua *usersAPI) getPrefs(w http.ResponseWriter, r *http.Request) { } us := users.FromCtx(ctx) - prefs, err := us.UserPrefs(ctx, *uid, *p.AppName) + prefs, err := us.UserPrefs(ctx, *username, *p.AppName) if err != nil { wErr(w, r, autoError(err)) return @@ -67,10 +67,10 @@ func (ua *usersAPI) getPrefs(w http.ResponseWriter, r *http.Request) { func (ua *usersAPI) putPrefs(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - uid := auth.UIDFrom(ctx) + username := auth.UsernameFrom(ctx) - if uid == nil { - wErr(w, r, autoError(ErrBadUID)) + if username == nil { + wErr(w, r, autoError(rbac.ErrBadSubject)) return } @@ -102,7 +102,7 @@ func (ua *usersAPI) putPrefs(w http.ResponseWriter, r *http.Request) { } us := users.FromCtx(ctx) - err = us.SetUserPrefs(ctx, *uid, *p.AppName, prefs) + err = us.SetUserPrefs(ctx, *username, *p.AppName, prefs) if err != nil { wErr(w, r, autoError(err)) return diff --git a/pkg/server/ingest.go b/pkg/server/ingest.go index 8f42950..a8c766e 100644 --- a/pkg/server/ingest.go +++ b/pkg/server/ingest.go @@ -7,5 +7,6 @@ import ( ) func (s *Server) Ingest(ctx context.Context, call *calls.Call) error { - return s.sinks.EmitCall(context.Background(), call) + ctx = context.WithoutCancel(ctx) + return s.sinks.EmitCall(ctx, call) } diff --git a/pkg/server/routes.go b/pkg/server/routes.go index a7f0788..cb5f26e 100644 --- a/pkg/server/routes.go +++ b/pkg/server/routes.go @@ -47,9 +47,11 @@ func (s *Server) setupRoutes() { }) r.Group(func(r chi.Router) { - // auth routes get rate-limited heavily, but not using middleware + // auth/share routes get rate-limited heavily, but not using middleware + s.rateLimit(r) r.Use(render.SetContentType(render.ContentTypeJSON)) s.auth.PublicRoutes(r) + // r.Mount("/share", s.share.ShareRouter(s.rest)) }) r.Group(func(r chi.Router) { @@ -66,7 +68,7 @@ func (s *Server) setupRoutes() { func (s *Server) WithCtxStores() func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { - r = r.WithContext(s.addStoresTo(r.Context())) + r = r.WithContext(s.fillCtx(r.Context())) next.ServeHTTP(w, r) } return http.HandlerFunc(fn) diff --git a/pkg/server/server.go b/pkg/server/server.go index 0ea300d..6c51931 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -15,7 +15,9 @@ import ( "dynatron.me/x/stillbox/pkg/incidents/incstore" "dynatron.me/x/stillbox/pkg/nexus" "dynatron.me/x/stillbox/pkg/notify" + "dynatron.me/x/stillbox/pkg/rbac" "dynatron.me/x/stillbox/pkg/rest" + "dynatron.me/x/stillbox/pkg/share" "dynatron.me/x/stillbox/pkg/sinks" "dynatron.me/x/stillbox/pkg/sources" "dynatron.me/x/stillbox/pkg/talkgroups/tgstore" @@ -48,6 +50,8 @@ type Server struct { users users.Store calls callstore.Store incidents incstore.Store + share share.Service + rbac rbac.RBAC } func New(ctx context.Context, cfg *config.Configuration) (*Server, error) { @@ -63,16 +67,23 @@ func New(ctx context.Context, cfg *config.Configuration) (*Server, error) { r := chi.NewRouter() - authenticator := auth.NewAuthenticator(cfg.Auth) + ust := users.NewStore(db) + + authenticator := auth.NewAuthenticator(cfg.Auth, ust) notifier, err := notify.New(cfg.Notify) if err != nil { return nil, err } - tgCache := tgstore.NewCache() + tgCache := tgstore.NewCache(db) api := rest.New(cfg.BaseURL.URL()) + rbacSvc, err := rbac.New() + if err != nil { + return nil, err + } + srv := &Server{ auth: authenticator, conf: cfg, @@ -85,9 +96,11 @@ func New(ctx context.Context, cfg *config.Configuration) (*Server, error) { tgs: tgCache, sinks: sinks.NewSinkManager(), rest: api, - users: users.NewStore(), - calls: callstore.NewStore(), + share: share.NewService(), + users: ust, + calls: callstore.NewStore(db), incidents: incstore.NewStore(), + rbac: rbacSvc, } if cfg.DB.Partition.Enabled { @@ -102,7 +115,7 @@ func New(ctx context.Context, cfg *config.Configuration) (*Server, error) { } } - srv.sinks.Register("database", sinks.NewDatabaseSink(srv.db, tgCache), true) + srv.sinks.Register("database", sinks.NewDatabaseSink(db, tgCache), true) srv.sinks.Register("nexus", sinks.NewNexusSink(srv.nex), false) if srv.alerter.Enabled() { @@ -135,12 +148,14 @@ func New(ctx context.Context, cfg *config.Configuration) (*Server, error) { return srv, nil } -func (s *Server) addStoresTo(ctx context.Context) context.Context { +func (s *Server) fillCtx(ctx context.Context) context.Context { ctx = database.CtxWithDB(ctx, s.db) ctx = tgstore.CtxWithStore(ctx, s.tgs) ctx = users.CtxWithStore(ctx, s.users) ctx = callstore.CtxWithStore(ctx, s.calls) ctx = incstore.CtxWithStore(ctx, s.incidents) + ctx = share.CtxWithStore(ctx, s.share.ShareStore()) + ctx = rbac.CtxWithRBAC(ctx, s.rbac) return ctx } @@ -150,7 +165,7 @@ func (s *Server) Go(ctx context.Context) error { s.installHupHandler() - ctx = s.addStoresTo(ctx) + ctx = s.fillCtx(ctx) httpSrv := &http.Server{ Addr: s.conf.Listen, @@ -159,6 +174,7 @@ func (s *Server) Go(ctx context.Context) error { go s.nex.Go(ctx) go s.alerter.Go(ctx) + go s.share.Go(ctx) if pm := s.partman; pm != nil { go pm.Go(ctx) diff --git a/pkg/share/service.go b/pkg/share/service.go new file mode 100644 index 0000000..eea1edd --- /dev/null +++ b/pkg/share/service.go @@ -0,0 +1,52 @@ +package share + +import ( + "context" + "time" + + "dynatron.me/x/stillbox/pkg/rbac" + "github.com/rs/zerolog/log" +) + +const ( + PruneInterval = time.Hour * 4 +) + +type Service interface { + ShareStore() Store + + Go(ctx context.Context) +} + +type service struct { + Store +} + +func (s *service) ShareStore() Store { + return s.Store +} + +func (s *service) Go(ctx context.Context) { + ctx = rbac.CtxWithSubject(ctx, &rbac.SystemServiceSubject{Name: "share"}) + + tick := time.NewTicker(PruneInterval) + + for { + select { + case <-tick.C: + err := s.Prune(ctx) + if err != nil { + log.Error().Err(err).Msg("share prune failed") + } + case <-ctx.Done(): + tick.Stop() + return + } + } +} + +func NewService() *service { + return &service{ + Store: NewStore(), + } +} diff --git a/pkg/share/share.go b/pkg/share/share.go new file mode 100644 index 0000000..8865d71 --- /dev/null +++ b/pkg/share/share.go @@ -0,0 +1,61 @@ +package share + +import ( + "context" + "time" + + "dynatron.me/x/stillbox/internal/jsontypes" + + "github.com/google/uuid" + "github.com/matoous/go-nanoid" +) + +const ( + SlugLength = 20 +) + +type EntityType string + +const ( + EntityIncident EntityType = "incident" + EntityCall EntityType = "call" +) + +// If an incident is shared, all calls that are part of it must be shared too, but this can be through the incident share (/share/bLaH/callID[.mp3]) + +type Share struct { + ID string `json:"id"` + Type EntityType `json:"entityType"` + EntityID uuid.UUID `json:"entityID"` + Expiration *jsontypes.Time `json:"expiration"` +} + +// NewShare creates a new share. +func (s *service) NewShare(ctx context.Context, shType EntityType, shID uuid.UUID, exp *time.Duration) (id string, err error) { + id, err = gonanoid.ID(SlugLength) + if err != nil { + return + } + + store := FromCtx(ctx) + + var expT *jsontypes.Time + if exp != nil { + tt := time.Now().Add(*exp) + expT = (*jsontypes.Time)(&tt) + } + + share := &Share{ + ID: id, + Type: shType, + EntityID: shID, + Expiration: expT, + } + + err = store.Create(ctx, share) + if err != nil { + return + } + + return id, nil +} diff --git a/pkg/share/store.go b/pkg/share/store.go new file mode 100644 index 0000000..4ee6ec8 --- /dev/null +++ b/pkg/share/store.go @@ -0,0 +1,85 @@ +package share + +import ( + "context" + + "dynatron.me/x/stillbox/internal/jsontypes" + "dynatron.me/x/stillbox/pkg/database" +) + +type Store interface { + // Get retreives a share record. + Get(ctx context.Context, id string) (*Share, error) + + // Create stores a new share record. + Create(ctx context.Context, share *Share) error + + // Delete deletes a share record. + Delete(ctx context.Context, id string) error + + // Prune removes expired share records. + Prune(ctx context.Context) error +} + +type postgresStore struct { +} + +func recToShare(share database.Share) *Share { + return &Share{ + ID: share.ID, + Type: EntityType(share.EntityType), + EntityID: share.EntityID, + Expiration: jsontypes.TimePtrFromTSTZ(share.Expiration), + } +} + +func (s *postgresStore) Get(ctx context.Context, id string) (*Share, error) { + db := database.FromCtx(ctx) + rec, err := db.GetShare(ctx, id) + if err != nil { + return nil, err + } + + return recToShare(rec), nil +} + +func (s *postgresStore) Create(ctx context.Context, share *Share) error { + db := database.FromCtx(ctx) + err := db.CreateShare(ctx, database.CreateShareParams{ + ID: share.ID, + EntityType: string(share.Type), + EntityID: share.EntityID, + Expiration: share.Expiration.PGTypeTSTZ(), + }) + + return err +} + +func (s *postgresStore) Delete(ctx context.Context, id string) error { + return database.FromCtx(ctx).DeleteShare(ctx, id) +} + +func (s *postgresStore) Prune(ctx context.Context) error { + return database.FromCtx(ctx).PruneShares(ctx) +} + +func NewStore() *postgresStore { + return new(postgresStore) +} + +type storeCtxKey string + +const StoreCtxKey storeCtxKey = "store" + +func CtxWithStore(ctx context.Context, s Store) context.Context { + return context.WithValue(ctx, StoreCtxKey, s) +} + +func FromCtx(ctx context.Context) Store { + s, ok := ctx.Value(StoreCtxKey).(Store) + if !ok { + return NewStore() + } + + return s +} diff --git a/pkg/sinks/database.go b/pkg/sinks/database.go index 68cf70b..8018117 100644 --- a/pkg/sinks/database.go +++ b/pkg/sinks/database.go @@ -2,15 +2,12 @@ package sinks import ( "context" - "fmt" - "dynatron.me/x/stillbox/internal/common" "dynatron.me/x/stillbox/pkg/calls" + "dynatron.me/x/stillbox/pkg/calls/callstore" "dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/talkgroups/tgstore" - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" "github.com/rs/zerolog/log" ) @@ -29,59 +26,9 @@ func (s *DatabaseSink) Call(ctx context.Context, call *calls.Call) error { return nil } - params := s.toAddCallParams(call) - - err := s.db.InTx(ctx, func(tx database.Store) error { - err := tx.AddCall(ctx, params) - if err != nil { - return fmt.Errorf("add call: %w", err) - } - - log.Debug().Str("id", call.ID.String()).Int("system", call.System).Int("tgid", call.Talkgroup).Msg("stored") - - return nil - }, pgx.TxOptions{}) - - if err != nil && database.IsTGConstraintViolation(err) { - return s.db.InTx(ctx, func(tx database.Store) error { - _, err := s.tgs.LearnTG(ctx, call) - if err != nil { - return fmt.Errorf("learn tg: %w", err) - } - - err = tx.AddCall(ctx, params) - if err != nil { - return fmt.Errorf("learn tg retry: %w", err) - } - - return nil - }, pgx.TxOptions{}) - } - - return err + return callstore.FromCtx(ctx).AddCall(ctx, call) } func (s *DatabaseSink) SinkType() string { return "database" } - -func (s *DatabaseSink) toAddCallParams(call *calls.Call) database.AddCallParams { - return database.AddCallParams{ - ID: call.ID, - Submitter: call.Submitter.Int32Ptr(), - System: call.System, - Talkgroup: call.Talkgroup, - CallDate: pgtype.Timestamptz{Time: call.DateTime, Valid: true}, - AudioName: common.NilIfZero(call.AudioName), - AudioBlob: call.Audio, - AudioType: common.NilIfZero(call.AudioType), - Duration: call.Duration.MsInt32Ptr(), - Frequency: call.Frequency, - Frequencies: call.Frequencies, - Patches: call.Patches, - TGLabel: call.TalkgroupLabel, - TGAlphaTag: call.TGAlphaTag, - TGGroup: call.TalkgroupGroup, - Source: call.Source, - } -} diff --git a/pkg/sinks/relay_test.go b/pkg/sinks/relay_test.go index b99d03b..d0891c0 100644 --- a/pkg/sinks/relay_test.go +++ b/pkg/sinks/relay_test.go @@ -13,10 +13,10 @@ import ( "dynatron.me/x/stillbox/internal/common" "dynatron.me/x/stillbox/internal/forms" - "dynatron.me/x/stillbox/pkg/auth" "dynatron.me/x/stillbox/pkg/calls" "dynatron.me/x/stillbox/pkg/config" "dynatron.me/x/stillbox/pkg/sources" + "dynatron.me/x/stillbox/pkg/users" "github.com/google/uuid" ) @@ -32,16 +32,16 @@ func TestRelay(t *testing.T) { tests := []struct { name string - submitter auth.UserID + submitter users.UserID apiKey string call calls.Call }{ { name: "base", - submitter: auth.UserID(1), + submitter: users.UserID(1), call: calls.Call{ ID: uuid.UUID([16]byte{0x52, 0xfd, 0xfc, 0x07, 0x21, 0x82, 0x45, 0x4f, 0x96, 0x3f, 0x5f, 0x0f, 0x9a, 0x62, 0x1d, 0x72}), - Submitter: common.PtrTo(auth.UserID(1)), + Submitter: common.PtrTo(users.UserID(1)), System: 197, Talkgroup: 10101, DateTime: time.Date(2024, 11, 10, 23, 33, 02, 0, time.Local), diff --git a/pkg/sources/http.go b/pkg/sources/http.go index 1c18b39..dfd8df5 100644 --- a/pkg/sources/http.go +++ b/pkg/sources/http.go @@ -9,6 +9,8 @@ import ( "dynatron.me/x/stillbox/internal/forms" "dynatron.me/x/stillbox/pkg/auth" "dynatron.me/x/stillbox/pkg/calls" + "dynatron.me/x/stillbox/pkg/rbac" + "dynatron.me/x/stillbox/pkg/users" "github.com/go-chi/chi/v5" "github.com/rs/zerolog/log" ) @@ -70,7 +72,7 @@ func (car *CallUploadRequest) mimeType() string { return "" } -func (car *CallUploadRequest) ToCall(submitter auth.UserID) (*calls.Call, error) { +func (car *CallUploadRequest) ToCall(submitter users.UserID) (*calls.Call, error) { return calls.Make(&calls.Call{ Submitter: &submitter, System: car.System, @@ -98,7 +100,13 @@ func (h *RdioHTTP) routeCallUpload(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - submitter, err := h.auth.CheckAPIKey(ctx, r.Form.Get("key")) + submitterSub, err := h.auth.CheckAPIKey(ctx, r.Form.Get("key")) + if err != nil { + auth.ErrorResponse(w, err) + return + } + + submitter, err := users.FromSubject(submitterSub) if err != nil { auth.ErrorResponse(w, err) return @@ -117,20 +125,22 @@ func (h *RdioHTTP) routeCallUpload(w http.ResponseWriter, r *http.Request) { return } - call, err := cur.ToCall(*submitter) + call, err := cur.ToCall(submitter.ID) if err != nil { log.Error().Err(err).Msg("toCall failed") http.Error(w, err.Error(), http.StatusBadRequest) return } - err = h.ing.Ingest(ctx, call) + err = h.ing.Ingest(rbac.CtxWithSubject(ctx, submitterSub), call) if err != nil { - log.Error().Err(err).Msg("ingest failed") - http.Error(w, "Call ingest failed.", http.StatusInternalServerError) + if rbac.ErrAccessDenied(err) != nil { + log.Error().Err(err).Msg("ingest failed") + http.Error(w, "Call ingest failed.", http.StatusForbidden) + } return } - log.Info().Int("system", cur.System).Int("tgid", cur.Talkgroup).Str("duration", call.Duration.Duration().String()).Msg("ingested") + log.Info().Int("system", cur.System).Int("tgid", cur.Talkgroup).Str("duration", call.Duration.Duration().String()).Str("sub", submitter.Username).Msg("ingested") written, err := w.Write([]byte("Call imported successfully.")) if err != nil { diff --git a/pkg/store/store.go b/pkg/store/store.go deleted file mode 100644 index bb85c85..0000000 --- a/pkg/store/store.go +++ /dev/null @@ -1,50 +0,0 @@ -package store - -import ( - "context" - - "dynatron.me/x/stillbox/pkg/talkgroups/tgstore" - "dynatron.me/x/stillbox/pkg/users" -) - -type Store interface { - TG() tgstore.Store - User() users.Store -} - -type store struct { - tg tgstore.Store - user users.Store -} - -func (s *store) TG() tgstore.Store { - return s.tg -} - -func (s *store) User() users.Store { - return s.user -} - -func New() Store { - return &store{ - tg: tgstore.NewCache(), - user: users.NewStore(), - } -} - -type storeCtxKey string - -const StoreCtxKey storeCtxKey = "store" - -func CtxWithStore(ctx context.Context, s Store) context.Context { - return context.WithValue(ctx, StoreCtxKey, s) -} - -func FromCtx(ctx context.Context) Store { - s, ok := ctx.Value(StoreCtxKey).(Store) - if !ok { - return New() - } - - return s -} diff --git a/pkg/talkgroups/talkgroup.go b/pkg/talkgroups/talkgroup.go index ded46fc..7965e98 100644 --- a/pkg/talkgroups/talkgroup.go +++ b/pkg/talkgroups/talkgroup.go @@ -9,6 +9,7 @@ import ( "strings" "dynatron.me/x/stillbox/pkg/database" + "dynatron.me/x/stillbox/pkg/rbac" ) type Talkgroup struct { @@ -17,6 +18,10 @@ type Talkgroup struct { Learned bool `json:"learned"` } +func (t *Talkgroup) GetResourceName() string { + return rbac.ResourceTalkgroup +} + func (t Talkgroup) String() string { if t.System.Name == "" { t.System.Name = strconv.Itoa(int(t.Talkgroup.TGID)) diff --git a/pkg/talkgroups/tgstore/store.go b/pkg/talkgroups/tgstore/store.go index cf45be8..64a010e 100644 --- a/pkg/talkgroups/tgstore/store.go +++ b/pkg/talkgroups/tgstore/store.go @@ -8,11 +8,12 @@ import ( "time" "dynatron.me/x/stillbox/internal/common" - "dynatron.me/x/stillbox/pkg/auth" "dynatron.me/x/stillbox/pkg/calls" "dynatron.me/x/stillbox/pkg/config" "dynatron.me/x/stillbox/pkg/database" + "dynatron.me/x/stillbox/pkg/rbac" tgsp "dynatron.me/x/stillbox/pkg/talkgroups" + "dynatron.me/x/stillbox/pkg/users" "github.com/jackc/pgx/v5" "github.com/rs/zerolog/log" @@ -176,7 +177,7 @@ func CtxWithStore(ctx context.Context, s Store) context.Context { func FromCtx(ctx context.Context) Store { s, ok := ctx.Value(StoreCtxKey).(Store) if !ok { - return NewCache() + panic("no tg store in context") } return s @@ -201,19 +202,23 @@ type cache struct { sync.RWMutex tgs tgMap systems map[int]string + db database.Store } // NewCache returns a new cache Store. -func NewCache() *cache { +func NewCache(db database.Store) *cache { tgc := &cache{ tgs: make(tgMap), systems: make(map[int]string), + db: db, } return tgc } func (t *cache) Hint(ctx context.Context, tgs []tgsp.ID) error { + // since this doesn't actually return data, we can skip rbac checks. + // This is only called by system services anyway. if len(tgs) < 1 { return nil } @@ -322,11 +327,15 @@ func addToRowList[T rowType](t *cache, tgRecords []T) []*tgsp.Talkgroup { } func (t *cache) TGs(ctx context.Context, tgs tgsp.IDs, opts ...Option) ([]*tgsp.Talkgroup, error) { - db := database.FromCtx(ctx) + _, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionRead)) + if err != nil { + return nil, err + } + + db := t.db r := make([]*tgsp.Talkgroup, 0, len(tgs)) opt := sOpt(opts) - var err error if tgs != nil { toGet := make(tgsp.IDs, 0, len(tgs)) for _, id := range tgs { @@ -394,7 +403,8 @@ func (t *cache) TGs(ctx context.Context, tgs tgsp.IDs, opts ...Option) ([]*tgsp. } func (t *cache) Load(ctx context.Context, tgs database.TGTuples) error { - tgRecords, err := database.FromCtx(ctx).GetTalkgroupsWithLearnedBySysTGID(ctx, tgs) + // No need for RBAC checks since this merely primes the cache and returns nothing. + tgRecords, err := t.db.GetTalkgroupsWithLearnedBySysTGID(ctx, tgs) if err != nil { return err } @@ -420,9 +430,13 @@ func (t *cache) Weight(ctx context.Context, id tgsp.ID, tm time.Time) float64 { } func (t *cache) SystemTGs(ctx context.Context, systemID int, opts ...Option) ([]*tgsp.Talkgroup, error) { - db := database.FromCtx(ctx) + _, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionRead)) + if err != nil { + return nil, err + } + + db := t.db opt := sOpt(opts) - var err error if opt.pagination != nil { sortDir, err := opt.pagination.SortDir() if err != nil { @@ -472,13 +486,18 @@ func (t *cache) SystemTGs(ctx context.Context, systemID int, opts ...Option) ([] } func (t *cache) TG(ctx context.Context, tg tgsp.ID) (*tgsp.Talkgroup, error) { + _, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionRead)) + if err != nil { + return nil, err + } + rec, has := t.get(tg) if has { return rec, nil } - record, err := database.FromCtx(ctx).GetTalkgroupWithLearned(ctx, int32(tg.System), int32(tg.Talkgroup)) + record, err := t.db.GetTalkgroupWithLearned(ctx, int32(tg.System), int32(tg.Talkgroup)) switch err { case nil: case pgx.ErrNoRows: @@ -494,12 +513,17 @@ func (t *cache) TG(ctx context.Context, tg tgsp.ID) (*tgsp.Talkgroup, error) { } func (t *cache) SystemName(ctx context.Context, id int) (name string, has bool) { + _, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionRead)) + if err != nil { + return "", false + } + t.RLock() n, has := t.systems[id] t.RUnlock() if !has { - sys, err := database.FromCtx(ctx).GetSystemName(ctx, id) + sys, err := t.db.GetSystemName(ctx, id) if err != nil { return "", false } @@ -515,20 +539,26 @@ func (t *cache) SystemName(ctx context.Context, id int) (name string, has bool) } func (t *cache) UpdateTG(ctx context.Context, input database.UpdateTalkgroupParams) (*tgsp.Talkgroup, error) { + user, err := users.UserCheck(ctx, new(tgsp.Talkgroup), "update") + if err != nil { + return nil, err + } + sysName, has := t.SystemName(ctx, int(*input.SystemID)) if !has { return nil, ErrNoSuchSystem } - db := database.FromCtx(ctx) + + db := t.db var tg database.Talkgroup - err := db.InTx(ctx, func(db database.Store) error { + err = db.InTx(ctx, func(db database.Store) error { var oerr error tg, oerr = db.UpdateTalkgroup(ctx, input) if oerr != nil { return oerr } versionBatch := db.StoreTGVersion(ctx, []database.StoreTGVersionParams{{ - Submitter: auth.UIDFrom(ctx), + Submitter: user.ID.Int32Ptr(), TGID: *input.TGID, }}) defer versionBatch.Close() @@ -557,12 +587,17 @@ func (t *cache) UpdateTG(ctx context.Context, input database.UpdateTalkgroupPara } func (t *cache) DeleteSystem(ctx context.Context, id int) error { + _, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionDelete)) + if err != nil { + return err + } + t.Lock() defer t.Unlock() t.invalidate() - err := database.FromCtx(ctx).DeleteSystem(ctx, id) + err = t.db.DeleteSystem(ctx, id) switch { case err == nil: return nil @@ -574,11 +609,21 @@ func (t *cache) DeleteSystem(ctx context.Context, id int) error { } func (t *cache) DeleteTG(ctx context.Context, id tgsp.ID) error { + _, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionDelete)) + if err != nil { + return err + } + t.Lock() defer t.Unlock() - err := database.FromCtx(ctx).InTx(ctx, func(db database.Store) error { - err := db.StoreDeletedTGVersion(ctx, common.PtrTo(int32(id.System)), common.PtrTo(int32(id.Talkgroup)), auth.UIDFrom(ctx)) + user, err := users.UserCheck(ctx, new(tgsp.Talkgroup), "update") + if err != nil { + return err + } + + err = t.db.InTx(ctx, func(db database.Store) error { + err := db.StoreDeletedTGVersion(ctx, common.PtrTo(int32(id.System)), common.PtrTo(int32(id.Talkgroup)), user.ID.Int32Ptr()) if err != nil { return err } @@ -600,7 +645,12 @@ func (t *cache) DeleteTG(ctx context.Context, id tgsp.ID) error { } func (t *cache) LearnTG(ctx context.Context, c *calls.Call) (*tgsp.Talkgroup, error) { - db := database.FromCtx(ctx) + _, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionCreate, rbac.ActionUpdate)) + if err != nil { + return nil, err + } + + db := t.db sys, has := t.SystemName(ctx, c.System) if !has { @@ -633,7 +683,12 @@ func (t *cache) LearnTG(ctx context.Context, c *calls.Call) (*tgsp.Talkgroup, er } func (t *cache) UpsertTGs(ctx context.Context, system int, input []database.UpsertTalkgroupParams) ([]*tgsp.Talkgroup, error) { - db := database.FromCtx(ctx) + user, err := users.UserCheck(ctx, new(tgsp.Talkgroup), "create+update") + if err != nil { + return nil, err + } + + db := t.db sysName, hasSys := t.SystemName(ctx, system) if !hasSys { return nil, ErrNoSuchSystem @@ -645,7 +700,7 @@ func (t *cache) UpsertTGs(ctx context.Context, system int, input []database.Upse tgs := make([]*tgsp.Talkgroup, 0, len(input)) - err := db.InTx(ctx, func(db database.Store) error { + err = db.InTx(ctx, func(db database.Store) error { versionParams := make([]database.StoreTGVersionParams, 0, len(input)) for i := range input { // normalize tags @@ -670,7 +725,7 @@ func (t *cache) UpsertTGs(ctx context.Context, system int, input []database.Upse versionParams = append(versionParams, database.StoreTGVersionParams{ SystemID: int32(system), TGID: r.TGID, - Submitter: auth.UIDFrom(ctx), + Submitter: user.ID.Int32Ptr(), }) tgs = append(tgs, &tgsp.Talkgroup{ Talkgroup: r, @@ -709,14 +764,24 @@ func (t *cache) UpsertTGs(ctx context.Context, system int, input []database.Upse } func (t *cache) CreateSystem(ctx context.Context, id int, name string) error { + _, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionCreate)) + if err != nil { + return err + } + t.Lock() defer t.Unlock() t.addSysNoLock(id, name) - return database.FromCtx(ctx).CreateSystem(ctx, id, name) + return t.db.CreateSystem(ctx, id, name) } func (t *cache) Tags(ctx context.Context) ([]string, error) { - return database.FromCtx(ctx).GetAllTalkgroupTags(ctx) + _, err := rbac.Check(ctx, rbac.UseResource(rbac.ResourceTalkgroup), rbac.WithActions(rbac.ActionRead)) + if err != nil { + return nil, err + } + + return t.db.GetAllTalkgroupTags(ctx) } diff --git a/pkg/talkgroups/xport/radioref/radioreference_test.go b/pkg/talkgroups/xport/radioref/radioreference_test.go index c453e2f..42fb340 100644 --- a/pkg/talkgroups/xport/radioref/radioreference_test.go +++ b/pkg/talkgroups/xport/radioref/radioreference_test.go @@ -14,9 +14,12 @@ import ( "dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/database/mocks" + "dynatron.me/x/stillbox/pkg/rbac" + rbacmocks "dynatron.me/x/stillbox/pkg/rbac/mocks" "dynatron.me/x/stillbox/pkg/talkgroups" "dynatron.me/x/stillbox/pkg/talkgroups/tgstore" "dynatron.me/x/stillbox/pkg/talkgroups/xport" + "dynatron.me/x/stillbox/pkg/users" ) func getFixture(fixture string) []byte { @@ -51,14 +54,19 @@ func TestRadioRef(t *testing.T) { }, } + subject := users.User{IsAdmin: true} + for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { dbMock := mocks.NewStore(t) + rbacMock := rbacmocks.NewRBAC(t) + rbacMock.EXPECT().Check(mock.AnythingOfType("*context.valueCtx"), rbac.UseResource("Talkgroup"), mock.AnythingOfType("rbac.CheckOption")).Return(&subject, nil) if tc.expectErr == nil { dbMock.EXPECT().GetSystemName(mock.AnythingOfType("*context.valueCtx"), tc.sysID).Return(tc.sysName, nil) } ctx := database.CtxWithDB(context.Background(), dbMock) - ctx = tgstore.CtxWithStore(ctx, tgstore.NewCache()) + ctx = rbac.CtxWithRBAC(ctx, rbacMock) + ctx = tgstore.CtxWithStore(ctx, tgstore.NewCache(dbMock)) ij := &xport.ImportJob{ Type: xport.Format(tc.impType), SystemID: tc.sysID, diff --git a/pkg/users/guest.go b/pkg/users/guest.go new file mode 100644 index 0000000..c38d2dc --- /dev/null +++ b/pkg/users/guest.go @@ -0,0 +1,21 @@ +package users + +import ( + "dynatron.me/x/stillbox/pkg/rbac" +) + +type ShareLinkGuest struct { + ShareID string +} + +func (s *ShareLinkGuest) GetRoles() []string { + return []string{rbac.RoleShareGuest} +} + +type Public struct { + RemoteAddr string +} + +func (s *Public) GetRoles() []string { + return []string{rbac.RolePublic} +} diff --git a/pkg/users/store.go b/pkg/users/store.go index b37384c..0129181 100644 --- a/pkg/users/store.go +++ b/pkg/users/store.go @@ -3,22 +3,40 @@ package users import ( "context" + "dynatron.me/x/stillbox/internal/cache" "dynatron.me/x/stillbox/pkg/database" ) type Store interface { + // GetUser gets a user by UID. + GetUser(ctx context.Context, username string) (*User, error) + // UserPrefs gets the preferences for the specified user and app name. - UserPrefs(ctx context.Context, uid int32, appName string) ([]byte, error) + UserPrefs(ctx context.Context, username string, appName string) ([]byte, error) // SetUserPrefs sets the preferences for the specified user and app name. - SetUserPrefs(ctx context.Context, uid int32, appName string, prefs []byte) error + SetUserPrefs(ctx context.Context, username string, appName string, prefs []byte) error + + // Invalidate clears the user cache. + Invalidate() + + // UpdateUser updates a user's record + UpdateUser(ctx context.Context, username string, user UserUpdate) error + + // GetUserByAPIKey gets a user by API key. + GetAPIKey(ctx context.Context, key string) (database.GetAPIKeyRow, error) } -type store struct { +type postgresStore struct { + cache.Cache[string, *User] + db database.Store } -func NewStore() *store { - return new(store) +func NewStore(db database.Store) *postgresStore { + return &postgresStore{ + Cache: cache.New[string, *User](), + db: db, + } } type storeCtxKey string @@ -32,16 +50,56 @@ func CtxWithStore(ctx context.Context, s Store) context.Context { func FromCtx(ctx context.Context) Store { s, ok := ctx.Value(StoreCtxKey).(Store) if !ok { - return NewStore() + panic("no users store in context") } return s } -func (s *store) UserPrefs(ctx context.Context, uid int32, appName string) ([]byte, error) { - db := database.FromCtx(ctx) +func (s *postgresStore) Invalidate() { + s.Clear() +} - prefs, err := db.GetAppPrefs(ctx, appName, int(uid)) +type UserUpdate struct { + Email *string `json:"email"` + IsAdmin *bool `json:"isAdmin"` +} + +func (s *postgresStore) UpdateUser(ctx context.Context, username string, user UserUpdate) error { + dbu, err := s.db.UpdateUser(ctx, username, user.Email, user.IsAdmin) + if err != nil { + return err + } + + s.Set(username, fromDBUser(dbu)) + + return nil +} + +func (s *postgresStore) GetUser(ctx context.Context, username string) (*User, error) { + u, has := s.Get(username) + if has { + return u, nil + } + + dbu, err := s.db.GetUserByUsername(ctx, username) + if err != nil { + return nil, err + } + + u = fromDBUser(dbu) + s.Set(username, u) + + return u, nil +} + +func (s *postgresStore) UserPrefs(ctx context.Context, username string, appName string) ([]byte, error) { + u, err := s.GetUser(ctx, username) + if err != nil { + return nil, err + } + + prefs, err := s.db.GetAppPrefs(ctx, appName, int(u.ID)) if err != nil { return nil, err } @@ -49,8 +107,15 @@ func (s *store) UserPrefs(ctx context.Context, uid int32, appName string) ([]byt return []byte(prefs), err } -func (s *store) SetUserPrefs(ctx context.Context, uid int32, appName string, prefs []byte) error { - db := database.FromCtx(ctx) +func (s *postgresStore) SetUserPrefs(ctx context.Context, username string, appName string, prefs []byte) error { + u, err := s.GetUser(ctx, username) + if err != nil { + return err + } - return db.SetAppPrefs(ctx, appName, prefs, int(uid)) + return s.db.SetAppPrefs(ctx, appName, prefs, int(u.ID)) +} + +func (s *postgresStore) GetAPIKey(ctx context.Context, b64hash string) (database.GetAPIKeyRow, error) { + return s.db.GetAPIKey(ctx, b64hash) } diff --git a/pkg/users/user.go b/pkg/users/user.go new file mode 100644 index 0000000..d4904a7 --- /dev/null +++ b/pkg/users/user.go @@ -0,0 +1,94 @@ +package users + +import ( + "context" + "encoding/json" + "strings" + + "dynatron.me/x/stillbox/pkg/database" + "dynatron.me/x/stillbox/pkg/rbac" +) + +type UserID int + +func (u *UserID) Int32Ptr() *int32 { + if u == nil { + return nil + } + + i := int32(*u) + + return &i +} + +func (u UserID) Int() int { + return int(u) +} + +func (u UserID) IsValid() bool { + return u > 0 +} + +func From(ctx context.Context) (*User, error) { + sub := rbac.SubjectFrom(ctx) + return FromSubject(sub) +} + +func UserCheck(ctx context.Context, rsc rbac.Resource, actions string) (*User, error) { + acts := strings.Split(actions, "+") + subj, err := rbac.FromCtx(ctx).Check(ctx, rsc, rbac.WithActions(acts...)) + if err != nil { + return nil, err + } + + return FromSubject(subj) +} + +func FromSubject(sub rbac.Subject) (*User, error) { + if sub == nil { + return nil, rbac.ErrBadSubject + } + + user, isUser := sub.(*User) + if !isUser || user == nil || !user.ID.IsValid() { + return nil, rbac.ErrBadSubject + } + + return user, nil +} + +type User struct { + ID UserID + Username string + Password string + Email string + IsAdmin bool + Prefs json.RawMessage +} + +func (u *User) GetName() string { + return u.Username +} + +func (u *User) GetRoles() []string { + r := make([]string, 1, 2) + + r[0] = rbac.RoleUser + + if u.IsAdmin { + r = append(r, rbac.RoleAdmin) + } + + return r +} + +func fromDBUser(dbu database.User) *User { + return &User{ + ID: UserID(dbu.ID), + Username: dbu.Username, + Password: dbu.Password, + Email: dbu.Email, + IsAdmin: dbu.IsAdmin, + Prefs: dbu.Prefs, + } +} diff --git a/sql/postgres/migrations/001_initial.up.sql b/sql/postgres/migrations/001_initial.up.sql index b9c9304..f7889cb 100644 --- a/sql/postgres/migrations/001_initial.up.sql +++ b/sql/postgres/migrations/001_initial.up.sql @@ -1,5 +1,5 @@ CREATE TABLE IF NOT EXISTS users( - id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY (START WITH 1), username VARCHAR (255) UNIQUE NOT NULL, password TEXT NOT NULL, email TEXT NOT NULL, @@ -141,6 +141,7 @@ CREATE TABLE IF NOT EXISTS settings( CREATE TABLE IF NOT EXISTS incidents( id UUID PRIMARY KEY, name TEXT NOT NULL, + owner INTEGER NOT NULL, description TEXT, start_time TIMESTAMPTZ, end_time TIMESTAMPTZ, @@ -163,3 +164,11 @@ CREATE TABLE IF NOT EXISTS incidents_calls( FOREIGN KEY (calls_tbl_id, call_date) REFERENCES calls(id, call_date), PRIMARY KEY (incident_id, call_id) ); + +CREATE TABLE IF NOT EXISTS shares( + id TEXT PRIMARY KEY, + entity_type TEXT NOT NULL, + entity_id UUID NOT NULL, + owner INTEGER NOT NULL REFERENCES users(id), + expiration TIMESTAMPTZ NULL +); diff --git a/sql/postgres/queries/calls.sql b/sql/postgres/queries/calls.sql index 998566a..cceaea1 100644 --- a/sql/postgres/queries/calls.sql +++ b/sql/postgres/queries/calls.sql @@ -156,3 +156,9 @@ CASE WHEN sqlc.narg('tags_not')::TEXT[] IS NOT NULL THEN c.duration > @longer_than ) ELSE TRUE END) ; + +-- name: DeleteCall :exec +DELETE FROM calls WHERE id = @id; + +-- name: GetCallSubmitter :one +SELECT submitter FROM calls WHERE id = @id; diff --git a/sql/postgres/queries/incidents.sql b/sql/postgres/queries/incidents.sql index 214a123..98b3fc8 100644 --- a/sql/postgres/queries/incidents.sql +++ b/sql/postgres/queries/incidents.sql @@ -33,6 +33,7 @@ WHERE incident_id = @incident_id AND call_id = @call_id; INSERT INTO incidents ( id, name, + owner, description, start_time, end_time, @@ -41,6 +42,7 @@ INSERT INTO incidents ( ) VALUES ( @id, @name, + @owner, sqlc.narg('description'), sqlc.narg('start_time'), sqlc.narg('end_time'), @@ -54,6 +56,7 @@ RETURNING *; SELECT i.id, i.name, + i.owner, i.description, i.start_time, i.end_time, @@ -148,6 +151,7 @@ ORDER BY ic.call_date ASC; SELECT i.id, i.name, + i.owner, i.description, i.start_time, i.end_time, @@ -171,3 +175,6 @@ RETURNING *; -- name: DeleteIncident :exec DELETE FROM incidents CASCADE WHERE id = @id; + +-- name: GetIncidentOwner :one +SELECT owner FROM incidents WHERE id = @id; diff --git a/sql/postgres/queries/share.sql b/sql/postgres/queries/share.sql new file mode 100644 index 0000000..11931e3 --- /dev/null +++ b/sql/postgres/queries/share.sql @@ -0,0 +1,24 @@ +-- name: GetShare :one +SELECT + id, + entity_type, + entity_id, + owner, + expiration +FROM shares +WHERE id = @id; + +-- name: CreateShare :exec +INSERT INTO shares ( + id, + entity_type, + entity_id, + owner, + expiration +) VALUES (@id, @entity_type, @entity_id, @owner, sqlc.narg('expiration')); + +-- name: DeleteShare :exec +DELETE FROM shares WHERE id = @id; + +-- name: PruneShares :exec +DELETE FROM shares WHERE expiration < NOW(); diff --git a/sql/postgres/queries/users.sql b/sql/postgres/queries/users.sql index 36924b9..dfce324 100644 --- a/sql/postgres/queries/users.sql +++ b/sql/postgres/queries/users.sql @@ -1,14 +1,10 @@ -- name: GetUserByID :one SELECT * FROM users -WHERE id = $1 LIMIT 1; +WHERE id = $1; -- name: GetUserByUsername :one SELECT * FROM users -WHERE username = $1 LIMIT 1; - --- name: GetUserByUID :one -SELECT * FROM users -WHERE id = $1 LIMIT 1; +WHERE username = $1; -- name: GetUsers :many SELECT * FROM users; @@ -28,6 +24,14 @@ DELETE FROM users WHERE username = $1; -- name: UpdatePassword :exec UPDATE users SET password = $2 WHERE username = $1; +-- name: UpdateUser :one +UPDATE users SET + email = COALESCE(sqlc.narg('email'), email), + is_admin = COALESCE(sqlc.narg('is_admin'), is_admin) +WHERE + username = $1 +RETURNING *; + -- name: CreateAPIKey :one INSERT INTO api_keys( owner, @@ -42,7 +46,17 @@ RETURNING *; DELETE FROM api_keys WHERE api_key = $1; -- name: GetAPIKey :one -SELECT * FROM api_keys WHERE api_key = $1; +SELECT + a.id, + a.owner, + a.created_at, + a.expires, + a.disabled, + a.api_key, + u.username +FROM api_keys a +JOIN users u ON (a.owner = u.id) +WHERE api_key = $1; -- name: GetAppPrefs :one SELECT (prefs->>(@app_name::TEXT))::JSONB FROM users WHERE id = @uid;