From d16ad6a4ad13b260d06f3a1853966895d24f3d3d Mon Sep 17 00:00:00 2001 From: Daniel Ponte Date: Mon, 10 Feb 2025 21:29:32 -0500 Subject: [PATCH] Shares list endpoint --- pkg/database/mocks/Store.go | 116 +++++++++++++++++++++++++++++++++ pkg/database/querier.go | 2 + pkg/database/share.sql.go | 73 +++++++++++++++++++++ pkg/rest/share.go | 29 +++++++++ pkg/shares/store.go | 58 ++++++++++++++++- sql/postgres/queries/share.sql | 27 ++++++++ 6 files changed, 304 insertions(+), 1 deletion(-) diff --git a/pkg/database/mocks/Store.go b/pkg/database/mocks/Store.go index ad4ad0a..8e33522 100644 --- a/pkg/database/mocks/Store.go +++ b/pkg/database/mocks/Store.go @@ -1924,6 +1924,122 @@ func (_c *Store_GetShare_Call) RunAndReturn(run func(context.Context, string) (d return _c } +// GetSharesP provides a mock function with given fields: ctx, arg +func (_m *Store) GetSharesP(ctx context.Context, arg database.GetSharesPParams) ([]database.Share, error) { + ret := _m.Called(ctx, arg) + + if len(ret) == 0 { + panic("no return value specified for GetSharesP") + } + + var r0 []database.Share + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, database.GetSharesPParams) ([]database.Share, error)); ok { + return rf(ctx, arg) + } + if rf, ok := ret.Get(0).(func(context.Context, database.GetSharesPParams) []database.Share); ok { + r0 = rf(ctx, arg) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]database.Share) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, database.GetSharesPParams) error); ok { + r1 = rf(ctx, arg) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Store_GetSharesP_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSharesP' +type Store_GetSharesP_Call struct { + *mock.Call +} + +// GetSharesP is a helper method to define mock.On call +// - ctx context.Context +// - arg database.GetSharesPParams +func (_e *Store_Expecter) GetSharesP(ctx interface{}, arg interface{}) *Store_GetSharesP_Call { + return &Store_GetSharesP_Call{Call: _e.mock.On("GetSharesP", ctx, arg)} +} + +func (_c *Store_GetSharesP_Call) Run(run func(ctx context.Context, arg database.GetSharesPParams)) *Store_GetSharesP_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(database.GetSharesPParams)) + }) + return _c +} + +func (_c *Store_GetSharesP_Call) Return(_a0 []database.Share, _a1 error) *Store_GetSharesP_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Store_GetSharesP_Call) RunAndReturn(run func(context.Context, database.GetSharesPParams) ([]database.Share, error)) *Store_GetSharesP_Call { + _c.Call.Return(run) + return _c +} + +// GetSharesPCount provides a mock function with given fields: ctx, owner +func (_m *Store) GetSharesPCount(ctx context.Context, owner *int32) (int64, error) { + ret := _m.Called(ctx, owner) + + if len(ret) == 0 { + panic("no return value specified for GetSharesPCount") + } + + var r0 int64 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *int32) (int64, error)); ok { + return rf(ctx, owner) + } + if rf, ok := ret.Get(0).(func(context.Context, *int32) int64); ok { + r0 = rf(ctx, owner) + } else { + r0 = ret.Get(0).(int64) + } + + if rf, ok := ret.Get(1).(func(context.Context, *int32) error); ok { + r1 = rf(ctx, owner) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Store_GetSharesPCount_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSharesPCount' +type Store_GetSharesPCount_Call struct { + *mock.Call +} + +// GetSharesPCount is a helper method to define mock.On call +// - ctx context.Context +// - owner *int32 +func (_e *Store_Expecter) GetSharesPCount(ctx interface{}, owner interface{}) *Store_GetSharesPCount_Call { + return &Store_GetSharesPCount_Call{Call: _e.mock.On("GetSharesPCount", ctx, owner)} +} + +func (_c *Store_GetSharesPCount_Call) Run(run func(ctx context.Context, owner *int32)) *Store_GetSharesPCount_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*int32)) + }) + return _c +} + +func (_c *Store_GetSharesPCount_Call) Return(_a0 int64, _a1 error) *Store_GetSharesPCount_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Store_GetSharesPCount_Call) RunAndReturn(run func(context.Context, *int32) (int64, error)) *Store_GetSharesPCount_Call { + _c.Call.Return(run) + return _c +} + // GetSystemName provides a mock function with given fields: ctx, systemID func (_m *Store) GetSystemName(ctx context.Context, systemID int) (string, error) { ret := _m.Called(ctx, systemID) diff --git a/pkg/database/querier.go b/pkg/database/querier.go index d2cd4a9..c519bed 100644 --- a/pkg/database/querier.go +++ b/pkg/database/querier.go @@ -42,6 +42,8 @@ type Querier interface { GetIncidentOwner(ctx context.Context, id uuid.UUID) (int, error) GetIncidentTalkgroups(ctx context.Context, incidentID uuid.UUID) ([]GetIncidentTalkgroupsRow, error) GetShare(ctx context.Context, id string) (Share, error) + GetSharesP(ctx context.Context, arg GetSharesPParams) ([]Share, error) + GetSharesPCount(ctx context.Context, owner *int32) (int64, error) GetSystemName(ctx context.Context, systemID int) (string, error) GetTalkgroup(ctx context.Context, systemID int32, tGID int32) (GetTalkgroupRow, error) GetTalkgroupIDsByTags(ctx context.Context, anyTags []string, allTags []string, notTags []string) ([]GetTalkgroupIDsByTagsRow, error) diff --git a/pkg/database/share.sql.go b/pkg/database/share.sql.go index ac7bc38..20a77e1 100644 --- a/pkg/database/share.sql.go +++ b/pkg/database/share.sql.go @@ -79,6 +79,79 @@ func (q *Queries) GetShare(ctx context.Context, id string) (Share, error) { return i, err } +const getSharesP = `-- name: GetSharesP :many +SELECT + s.id, + s.entity_type, + s.entity_id, + s.entity_date, + s.owner, + s.expiration +FROM shares s +WHERE +CASE WHEN $1::INTEGER IS NOT NULL THEN + s.owner = $1 ELSE TRUE END +ORDER BY +CASE WHEN $2::TEXT = 'asc' THEN s.entity_date END ASC, +CASE WHEN $2::TEXT = 'desc' THEN s.entity_date END DESC +OFFSET $3 ROWS +FETCH NEXT $4 ROWS ONLY +` + +type GetSharesPParams struct { + Owner *int32 `json:"owner"` + Direction string `json:"direction"` + Offset int32 `json:"offset"` + PerPage int32 `json:"perPage"` +} + +func (q *Queries) GetSharesP(ctx context.Context, arg GetSharesPParams) ([]Share, error) { + rows, err := q.db.Query(ctx, getSharesP, + arg.Owner, + arg.Direction, + arg.Offset, + arg.PerPage, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Share + for rows.Next() { + var i Share + if err := rows.Scan( + &i.ID, + &i.EntityType, + &i.EntityID, + &i.EntityDate, + &i.Owner, + &i.Expiration, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getSharesPCount = `-- name: GetSharesPCount :one +SELECT COUNT(*) +FROM shares s +WHERE +CASE WHEN $1::INTEGER IS NOT NULL THEN + s.owner = $1 ELSE TRUE END +` + +func (q *Queries) GetSharesPCount(ctx context.Context, owner *int32) (int64, error) { + row := q.db.QueryRow(ctx, getSharesPCount, owner) + var count int64 + err := row.Scan(&count) + return count, err +} + const pruneShares = `-- name: PruneShares :exec DELETE FROM shares WHERE expiration < NOW() ` diff --git a/pkg/rest/share.go b/pkg/rest/share.go index 07a161d..45d676e 100644 --- a/pkg/rest/share.go +++ b/pkg/rest/share.go @@ -121,6 +121,7 @@ func (sa *shareAPI) Subrouter() http.Handler { r.Post(`/create`, sa.createShare) r.Delete(`/{id:[A-Za-z0-9_-]{20,}}`, sa.deleteShare) + r.Post(`/`, sa.listShares) return r } @@ -156,6 +157,34 @@ func (sa *shareAPI) createShare(w http.ResponseWriter, r *http.Request) { respond(w, r, sh) } +func (sa *shareAPI) listShares(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + shs := shares.FromCtx(ctx) + + p := shares.SharesParams{} + err := forms.Unmarshal(r, &p, forms.WithTag("json"), forms.WithAcceptBlank(), forms.WithOmitEmpty()) + if err != nil { + wErr(w, r, badRequest(err)) + return + } + + shRes, count, err := shs.Shares(ctx, p) + if err != nil { + wErr(w, r, autoError(err)) + return + } + + response := struct { + Shares []*shares.Share `json:"shares"` + TotalCount int `json:"totalCount"` + }{ + Shares: shRes, + TotalCount: count, + } + + respond(w, r, &response) +} + func (sa *shareAPI) routeShare(w http.ResponseWriter, r *http.Request) { ctx := r.Context() shs := shares.FromCtx(ctx) diff --git a/pkg/shares/store.go b/pkg/shares/store.go index 14287a7..46dd3e6 100644 --- a/pkg/shares/store.go +++ b/pkg/shares/store.go @@ -3,7 +3,9 @@ package shares import ( "context" "errors" + "fmt" + "dynatron.me/x/stillbox/internal/common" "dynatron.me/x/stillbox/internal/jsontypes" "dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/rbac" @@ -12,13 +14,21 @@ import ( "github.com/jackc/pgx/v5" ) +type SharesParams struct { + common.Pagination + Direction *common.SortDirection `json:"dir"` +} + type Shares interface { // NewShare creates a new share. NewShare(ctx context.Context, sh CreateShareParams) (*Share, error) - // Share retreives a share record. + // Share retrieves a share record. GetShare(ctx context.Context, id string) (*Share, error) + // Shares retrieves shares visible by the context Subject. + Shares(ctx context.Context, p SharesParams) (shares []*Share, totalCount int, err error) + // Create stores a new share record. Create(ctx context.Context, share *Share) error @@ -98,6 +108,52 @@ func (s *postgresStore) Delete(ctx context.Context, id string) error { return database.FromCtx(ctx).DeleteShare(ctx, id) } +func (s *postgresStore) Shares(ctx context.Context, p SharesParams) (shares []*Share, totalCount int, err error) { + sub := entities.SubjectFrom(ctx) + + // ersatz RBAC + owner := common.PtrTo(int32(-1)) // invalid UID + switch s := sub.(type) { + case *users.User: + if !s.IsAdmin { + owner = s.ID.Int32Ptr() + } else { + owner = nil + } + case *entities.SystemServiceSubject: + owner = nil + default: + return nil, 0, rbac.ErrAccessDenied(rbac.ErrNotAuthorized) + } + + db := database.FromCtx(ctx) + + count, err := db.GetSharesPCount(ctx, owner) + if err != nil { + return nil, 0, fmt.Errorf("shares count: %w", err) + } + + offset, perPage := p.Pagination.OffsetPerPage(100) + dbParam := database.GetSharesPParams{ + Owner: owner, + Direction: p.Direction.DirString(common.DirAsc), + Offset: offset, + PerPage: perPage, + } + + shs, err := db.GetSharesP(ctx, dbParam) + if err != nil { + return nil, 0, err + } + + shares = make([]*Share, 0, len(shs)) + for _, v := range shs { + shares = append(shares, recToShare(v)) + } + + return shares, int(count), nil +} + func (s *postgresStore) Prune(ctx context.Context) error { return database.FromCtx(ctx).PruneShares(ctx) } diff --git a/sql/postgres/queries/share.sql b/sql/postgres/queries/share.sql index 7854c48..c1e940b 100644 --- a/sql/postgres/queries/share.sql +++ b/sql/postgres/queries/share.sql @@ -24,3 +24,30 @@ DELETE FROM shares WHERE id = @id; -- name: PruneShares :exec DELETE FROM shares WHERE expiration < NOW(); + +-- name: GetSharesP :many +SELECT + s.id, + s.entity_type, + s.entity_id, + s.entity_date, + s.owner, + s.expiration +FROM shares s +WHERE +CASE WHEN sqlc.narg('owner')::INTEGER IS NOT NULL THEN + s.owner = @owner ELSE TRUE END +ORDER BY +CASE WHEN @direction::TEXT = 'asc' THEN s.entity_date END ASC, +CASE WHEN @direction::TEXT = 'desc' THEN s.entity_date END DESC +OFFSET sqlc.arg('offset') ROWS +FETCH NEXT sqlc.arg('per_page') ROWS ONLY +; + +-- name: GetSharesPCount :one +SELECT COUNT(*) +FROM shares s +WHERE +CASE WHEN sqlc.narg('owner')::INTEGER IS NOT NULL THEN + s.owner = @owner ELSE TRUE END +;