This commit is contained in:
Daniel 2024-11-15 12:18:32 -05:00
parent e82f07e094
commit c9a32cd4bf
10 changed files with 38 additions and 13 deletions

View file

@ -25,6 +25,7 @@ getcalls:
generate: generate:
sqlc generate -f sql/sqlc.yaml sqlc generate -f sql/sqlc.yaml
protoc -I=pkg/pb/ --go_out=pkg/ pkg/pb/stillbox.proto protoc -I=pkg/pb/ --go_out=pkg/ pkg/pb/stillbox.proto
go generate ./...
lint: lint:
golangci-lint run golangci-lint run

1
go.mod
View file

@ -57,6 +57,7 @@ require (
github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect
github.com/segmentio/asm v1.2.0 // indirect github.com/segmentio/asm v1.2.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.5.2 // indirect
go.uber.org/atomic v1.7.0 // indirect go.uber.org/atomic v1.7.0 // indirect
golang.org/x/exp/shiny v0.0.0-20240719175910-8a7402abbf56 // indirect golang.org/x/exp/shiny v0.0.0-20240719175910-8a7402abbf56 // indirect
golang.org/x/image v0.14.0 // indirect golang.org/x/image v0.14.0 // indirect

2
go.sum
View file

@ -134,6 +134,8 @@ github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3k
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=

View file

@ -312,7 +312,7 @@ func (as *alerter) backfill(ctx context.Context, since time.Time, until time.Tim
db := database.FromCtx(ctx) db := database.FromCtx(ctx)
const backfillStatsQuery = `SELECT system, talkgroup, call_date FROM calls WHERE call_date > $1 AND call_date < $2 ORDER BY call_date ASC` const backfillStatsQuery = `SELECT system, talkgroup, call_date FROM calls WHERE call_date > $1 AND call_date < $2 ORDER BY call_date ASC`
rows, err := db.Query(ctx, backfillStatsQuery, since, until) rows, err := db.DB().Query(ctx, backfillStatsQuery, since, until)
if err != nil { if err != nil {
return count, err return count, err
} }

View file

@ -70,7 +70,7 @@ func (a *Auth) initJWT() {
} }
func (a *Auth) Login(ctx context.Context, username, password string) (token string, err error) { func (a *Auth) Login(ctx context.Context, username, password string) (token string, err error) {
q := database.New(database.FromCtx(ctx)) q := database.FromCtx(ctx)
users, err := q.GetUsers(ctx) users, err := q.GetUsers(ctx)
if err != nil { if err != nil {
log.Error().Err(err).Msg("getUsers failed") log.Error().Err(err).Msg("getUsers failed")

View file

@ -16,11 +16,23 @@ import (
) )
// DB is a database handle. // DB is a database handle.
type DB struct { //go:generate mockery
type DB interface {
Querier
talkgroupQuerier
DB() *Database
}
type Database struct {
*pgxpool.Pool *pgxpool.Pool
*Queries *Queries
} }
func (db *Database) DB() *Database {
return db
}
type dbLogger struct{} type dbLogger struct{}
func (m dbLogger) Log(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]any) { func (m dbLogger) Log(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]any) {
@ -28,7 +40,7 @@ func (m dbLogger) Log(ctx context.Context, level tracelog.LogLevel, msg string,
} }
// NewClient creates a new DB using the provided config. // NewClient creates a new DB using the provided config.
func NewClient(ctx context.Context, conf config.DB) (*DB, error) { func NewClient(ctx context.Context, conf config.DB) (DB, error) {
dir, err := iofs.New(sqlembed.Migrations, "postgres/migrations") dir, err := iofs.New(sqlembed.Migrations, "postgres/migrations")
if err != nil { if err != nil {
return nil, err return nil, err
@ -63,7 +75,7 @@ func NewClient(ctx context.Context, conf config.DB) (*DB, error) {
return nil, err return nil, err
} }
db := &DB{ db := &Database{
Pool: pool, Pool: pool,
Queries: New(pool), Queries: New(pool),
} }
@ -76,8 +88,8 @@ type dBCtxKey string
const DBCtxKey dBCtxKey = "dbctx" const DBCtxKey dBCtxKey = "dbctx"
// FromCtx returns the database handle from the provided Context. // FromCtx returns the database handle from the provided Context.
func FromCtx(ctx context.Context) *DB { func FromCtx(ctx context.Context) DB {
c, ok := ctx.Value(DBCtxKey).(*DB) c, ok := ctx.Value(DBCtxKey).(DB)
if !ok { if !ok {
panic("no DB in context") panic("no DB in context")
} }
@ -86,7 +98,7 @@ func FromCtx(ctx context.Context) *DB {
} }
// CtxWithDB returns a Context with the provided database handle. // CtxWithDB returns a Context with the provided database handle.
func CtxWithDB(ctx context.Context, conn *DB) context.Context { func CtxWithDB(ctx context.Context, conn DB) context.Context {
return context.WithValue(ctx, DBCtxKey, conn) return context.WithValue(ctx, DBCtxKey, conn)
} }

View file

@ -4,6 +4,12 @@ import (
"context" "context"
) )
type talkgroupQuerier interface {
GetTalkgroupsWithLearnedBySysTGID(ctx context.Context, ids TGTuples) ([]GetTalkgroupsRow, error)
GetTalkgroupsBySysTGID(ctx context.Context, ids TGTuples) ([]GetTalkgroupsRow, error)
BulkSetTalkgroupTags(ctx context.Context, tgs TGTuples, tags []string) error
}
type TGTuples [2][]uint32 type TGTuples [2][]uint32
func MakeTGTuples(cap int) TGTuples { func MakeTGTuples(cap int) TGTuples {

View file

@ -27,7 +27,7 @@ const shutdownTimeout = 5 * time.Second
type Server struct { type Server struct {
auth *auth.Auth auth *auth.Auth
conf *config.Config conf *config.Config
db *database.DB db database.DB
r *chi.Mux r *chi.Mux
sources sources.Sources sources sources.Sources
sinks sinks.Sinks sinks sinks.Sinks
@ -103,7 +103,7 @@ func New(ctx context.Context, cfg *config.Config) (*Server, error) {
} }
func (s *Server) Go(ctx context.Context) error { func (s *Server) Go(ctx context.Context) error {
defer s.db.Close() defer s.db.DB().Close()
s.installHupHandler() s.installHupHandler()

View file

@ -13,10 +13,10 @@ import (
) )
type DatabaseSink struct { type DatabaseSink struct {
db *database.DB db database.DB
} }
func NewDatabaseSink(db *database.DB) *DatabaseSink { func NewDatabaseSink(db database.DB) *DatabaseSink {
return &DatabaseSink{db: db} return &DatabaseSink{db: db}
} }

View file

@ -10,6 +10,8 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"dynatron.me/x/stillbox/pkg/talkgroups/importer" "dynatron.me/x/stillbox/pkg/talkgroups/importer"
"dynatron.me/x/stillbox/pkg/database"
"dynatron.me/x/stillbox/pkg/database/mocks"
) )
func getFixture(fixture string) []byte { func getFixture(fixture string) []byte {
@ -23,7 +25,6 @@ func getFixture(fixture string) []byte {
} }
func TestRadioReferenceImport(t *testing.T) { func TestRadioReferenceImport(t *testing.T) {
ctx := context.Background()
tests := []struct{ tests := []struct{
name string name string
input []byte input []byte
@ -41,6 +42,8 @@ func TestRadioReferenceImport(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
ctx := context.Background()
ctx = database.CtxWithDB(ctx, mocks.NewDB())
ij := &importer.ImportJob{ ij := &importer.ImportJob{
Type: "radioreference", Type: "radioreference",
SystemID: tc.sysID, SystemID: tc.sysID,