From a1b751fdf0e7db25b983fadf588f4f5777d7adb5 Mon Sep 17 00:00:00 2001 From: Daniel Ponte Date: Thu, 21 Nov 2024 07:44:08 -0500 Subject: [PATCH] Fix learning --- pkg/alerting/alert/alert.go | 2 +- pkg/database/mocks/Store.go | 105 +++++++++++++++------------- pkg/database/querier.go | 2 +- pkg/database/talkgroups.sql.go | 84 +++++++++++++++++----- pkg/nexus/commands.go | 2 +- pkg/rest/talkgroups.go | 6 +- pkg/sinks/database.go | 4 +- pkg/talkgroups/importer/import.go | 2 +- pkg/talkgroups/tgstore/store.go | 8 +-- sql/postgres/queries/talkgroups.sql | 55 ++++++++++++--- 10 files changed, 179 insertions(+), 91 deletions(-) diff --git a/pkg/alerting/alert/alert.go b/pkg/alerting/alert/alert.go index 2e1ada4..060948b 100644 --- a/pkg/alerting/alert/alert.go +++ b/pkg/alerting/alert/alert.go @@ -46,7 +46,7 @@ func (a *Alert) ToAddAlertParams() database.AddAlertParams { // Make creates an alert for later rendering or storage. func Make(ctx context.Context, score trending.Score[talkgroups.ID], origScore float64) (Alert, error) { - store := tgstore.From(ctx) + store := tgstore.FromCtx(ctx) d := Alert{ Score: score, Timestamp: time.Now(), diff --git a/pkg/database/mocks/Store.go b/pkg/database/mocks/Store.go index 1cf50d8..6d78e98 100644 --- a/pkg/database/mocks/Store.go +++ b/pkg/database/mocks/Store.go @@ -179,54 +179,6 @@ func (_c *Store_AddLearnedTalkgroup_Call) RunAndReturn(run func(context.Context, return _c } -// AddTalkgroupWithLearnedFlag provides a mock function with given fields: ctx, systemID, tGID -func (_m *Store) AddTalkgroupWithLearnedFlag(ctx context.Context, systemID int32, tGID int32) error { - ret := _m.Called(ctx, systemID, tGID) - - if len(ret) == 0 { - panic("no return value specified for AddTalkgroupWithLearnedFlag") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, int32, int32) error); ok { - r0 = rf(ctx, systemID, tGID) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// Store_AddTalkgroupWithLearnedFlag_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddTalkgroupWithLearnedFlag' -type Store_AddTalkgroupWithLearnedFlag_Call struct { - *mock.Call -} - -// AddTalkgroupWithLearnedFlag is a helper method to define mock.On call -// - ctx context.Context -// - systemID int32 -// - tGID int32 -func (_e *Store_Expecter) AddTalkgroupWithLearnedFlag(ctx interface{}, systemID interface{}, tGID interface{}) *Store_AddTalkgroupWithLearnedFlag_Call { - return &Store_AddTalkgroupWithLearnedFlag_Call{Call: _e.mock.On("AddTalkgroupWithLearnedFlag", ctx, systemID, tGID)} -} - -func (_c *Store_AddTalkgroupWithLearnedFlag_Call) Run(run func(ctx context.Context, systemID int32, tGID int32)) *Store_AddTalkgroupWithLearnedFlag_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(int32), args[2].(int32)) - }) - return _c -} - -func (_c *Store_AddTalkgroupWithLearnedFlag_Call) Return(_a0 error) *Store_AddTalkgroupWithLearnedFlag_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *Store_AddTalkgroupWithLearnedFlag_Call) RunAndReturn(run func(context.Context, int32, int32) error) *Store_AddTalkgroupWithLearnedFlag_Call { - _c.Call.Return(run) - return _c -} - // BulkSetTalkgroupTags provides a mock function with given fields: ctx, tgs, tags func (_m *Store) BulkSetTalkgroupTags(ctx context.Context, tgs database.TGTuples, tags []string) error { ret := _m.Called(ctx, tgs, tags) @@ -1569,6 +1521,63 @@ func (_c *Store_InTx_Call) RunAndReturn(run func(context.Context, func(database. return _c } +// RestoreTalkgroupVersion provides a mock function with given fields: ctx, versionIds +func (_m *Store) RestoreTalkgroupVersion(ctx context.Context, versionIds int) (database.Talkgroup, error) { + ret := _m.Called(ctx, versionIds) + + if len(ret) == 0 { + panic("no return value specified for RestoreTalkgroupVersion") + } + + var r0 database.Talkgroup + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int) (database.Talkgroup, error)); ok { + return rf(ctx, versionIds) + } + if rf, ok := ret.Get(0).(func(context.Context, int) database.Talkgroup); ok { + r0 = rf(ctx, versionIds) + } else { + r0 = ret.Get(0).(database.Talkgroup) + } + + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, versionIds) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Store_RestoreTalkgroupVersion_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RestoreTalkgroupVersion' +type Store_RestoreTalkgroupVersion_Call struct { + *mock.Call +} + +// RestoreTalkgroupVersion is a helper method to define mock.On call +// - ctx context.Context +// - versionIds int +func (_e *Store_Expecter) RestoreTalkgroupVersion(ctx interface{}, versionIds interface{}) *Store_RestoreTalkgroupVersion_Call { + return &Store_RestoreTalkgroupVersion_Call{Call: _e.mock.On("RestoreTalkgroupVersion", ctx, versionIds)} +} + +func (_c *Store_RestoreTalkgroupVersion_Call) Run(run func(ctx context.Context, versionIds int)) *Store_RestoreTalkgroupVersion_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int)) + }) + return _c +} + +func (_c *Store_RestoreTalkgroupVersion_Call) Return(_a0 database.Talkgroup, _a1 error) *Store_RestoreTalkgroupVersion_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Store_RestoreTalkgroupVersion_Call) RunAndReturn(run func(context.Context, int) (database.Talkgroup, error)) *Store_RestoreTalkgroupVersion_Call { + _c.Call.Return(run) + return _c +} + // SetCallTranscript provides a mock function with given fields: ctx, iD, transcript func (_m *Store) SetCallTranscript(ctx context.Context, iD uuid.UUID, transcript *string) error { ret := _m.Called(ctx, iD, transcript) diff --git a/pkg/database/querier.go b/pkg/database/querier.go index ca151e5..c9c531b 100644 --- a/pkg/database/querier.go +++ b/pkg/database/querier.go @@ -15,7 +15,6 @@ type Querier interface { AddAlert(ctx context.Context, arg AddAlertParams) error AddCall(ctx context.Context, arg AddCallParams) error AddLearnedTalkgroup(ctx context.Context, arg AddLearnedTalkgroupParams) (Talkgroup, error) - AddTalkgroupWithLearnedFlag(ctx context.Context, systemID int32, tGID int32) error CreateAPIKey(ctx context.Context, owner int, expires pgtype.Timestamp, disabled *bool) (ApiKey, error) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) DeleteAPIKey(ctx context.Context, apiKey string) error @@ -35,6 +34,7 @@ type Querier interface { GetUserByUID(ctx context.Context, id int) (User, error) GetUserByUsername(ctx context.Context, username string) (User, error) GetUsers(ctx context.Context) ([]User, error) + RestoreTalkgroupVersion(ctx context.Context, versionIds int) (Talkgroup, error) SetCallTranscript(ctx context.Context, iD uuid.UUID, transcript *string) error SetTalkgroupTags(ctx context.Context, tags []string, systemID int32, tGID int32) error StoreTGVersion(ctx context.Context, arg []StoreTGVersionParams) *StoreTGVersionBatchResults diff --git a/pkg/database/talkgroups.sql.go b/pkg/database/talkgroups.sql.go index 44b2a4f..ff41080 100644 --- a/pkg/database/talkgroups.sql.go +++ b/pkg/database/talkgroups.sql.go @@ -66,23 +66,6 @@ func (q *Queries) AddLearnedTalkgroup(ctx context.Context, arg AddLearnedTalkgro return i, err } -const addTalkgroupWithLearnedFlag = `-- name: AddTalkgroupWithLearnedFlag :exec -INSERT INTO talkgroups ( - system_id, - tgid, - learned -) VALUES( - $1, - $2, - TRUE -) -` - -func (q *Queries) AddTalkgroupWithLearnedFlag(ctx context.Context, systemID int32, tGID int32) error { - _, err := q.db.Exec(ctx, addTalkgroupWithLearnedFlag, systemID, tGID) - return err -} - const getSystemName = `-- name: GetSystemName :one SELECT name FROM systems WHERE id = $1 ` @@ -394,6 +377,73 @@ func (q *Queries) GetTalkgroupsWithLearnedBySystem(ctx context.Context, system i return items, nil } +const restoreTalkgroupVersion = `-- name: RestoreTalkgroupVersion :one +INSERT INTO talkgroups( + system_id, + tgid, + name, + alpha_tag, + tg_group, + frequency, + metadata, + tags, + alert, + alert_config, + weight, + learned, + ignored +) +SELECT + system_id, + tgid, + name, + alpha_tag, + tg_group, + frequency, + metadata, + tags, + alert, + alert_config, + weight, + learned, + ignored +FROM talkgroup_versions tgv ON CONFLICT (system_id, tgid) DO UPDATE SET + name = excluded.name, + alpha_tag = excluded.alpha_tag, + tg_group = excluded.tg_group, + metadata = excluded.metadata, + tags = excluded.tags, + alert = excluded.alert, + alert_config = excluded.alert_config, + weight = excluded.weight, + learned = excluded.learner, + ignored = excluded.ignored +WHERE tgv.id = ANY($1) +RETURNING id, system_id, tgid, name, alpha_tag, tg_group, frequency, metadata, tags, alert, alert_config, weight, learned, ignored +` + +func (q *Queries) RestoreTalkgroupVersion(ctx context.Context, versionIds int) (Talkgroup, error) { + row := q.db.QueryRow(ctx, restoreTalkgroupVersion, versionIds) + var i Talkgroup + err := row.Scan( + &i.ID, + &i.SystemID, + &i.TGID, + &i.Name, + &i.AlphaTag, + &i.TGGroup, + &i.Frequency, + &i.Metadata, + &i.Tags, + &i.Alert, + &i.AlertConfig, + &i.Weight, + &i.Learned, + &i.Ignored, + ) + return i, err +} + const setTalkgroupTags = `-- name: SetTalkgroupTags :exec UPDATE talkgroups SET tags = $1 WHERE system_id = $2 AND tgid = $3 diff --git a/pkg/nexus/commands.go b/pkg/nexus/commands.go index 2ae94fd..37e06a8 100644 --- a/pkg/nexus/commands.go +++ b/pkg/nexus/commands.go @@ -60,7 +60,7 @@ func (c *client) SendError(cmd *pb.Command, err error) { } func (c *client) Talkgroup(ctx context.Context, tg *pb.Talkgroup) error { - tgi, err := tgstore.From(ctx).TG(ctx, talkgroups.TG(tg.System, tg.Talkgroup)) + tgi, err := tgstore.FromCtx(ctx).TG(ctx, talkgroups.TG(tg.System, tg.Talkgroup)) if err != nil { if err != tgstore.ErrNotFound { log.Error().Err(err).Int32("sys", tg.System).Int32("tg", tg.Talkgroup).Msg("get talkgroup fail") diff --git a/pkg/rest/talkgroups.go b/pkg/rest/talkgroups.go index 85a0257..e9c4fc5 100644 --- a/pkg/rest/talkgroups.go +++ b/pkg/rest/talkgroups.go @@ -54,7 +54,7 @@ func (t tgParams) ToID() talkgroups.ID { func (tga *talkgroupAPI) get(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - tgs := tgstore.From(ctx) + tgs := tgstore.FromCtx(ctx) var p tgParams @@ -92,7 +92,7 @@ func (tga *talkgroupAPI) put(w http.ResponseWriter, r *http.Request) { } ctx := r.Context() - tgs := tgstore.From(ctx) + tgs := tgstore.FromCtx(ctx) input := database.UpdateTalkgroupParams{} @@ -143,7 +143,7 @@ func (tga *talkgroupAPI) putTalkgroups(w http.ResponseWriter, r *http.Request) { } ctx := r.Context() - tgs := tgstore.From(ctx) + tgs := tgstore.FromCtx(ctx) var input []database.UpsertTalkgroupParams diff --git a/pkg/sinks/database.go b/pkg/sinks/database.go index 190fa9d..98ea90c 100644 --- a/pkg/sinks/database.go +++ b/pkg/sinks/database.go @@ -47,12 +47,12 @@ func (s *DatabaseSink) Call(ctx context.Context, call *calls.Call) error { return s.db.InTx(ctx, func(tx database.Store) error { _, err := s.tgs.LearnTG(ctx, call) if err != nil { - return fmt.Errorf("add call: learn tg: %w", err) + return fmt.Errorf("learn tg: %w", err) } err = tx.AddCall(ctx, params) if err != nil { - return fmt.Errorf("add call: retry: %w", err) + return fmt.Errorf("learn tg retry: %w", err) } return nil diff --git a/pkg/talkgroups/importer/import.go b/pkg/talkgroups/importer/import.go index 224e350..347c25b 100644 --- a/pkg/talkgroups/importer/import.go +++ b/pkg/talkgroups/importer/import.go @@ -67,7 +67,7 @@ var rrRE = regexp.MustCompile(`DEC\s+HEX\s+Mode\s+Alpha Tag\s+Description\s+Tag` func (rr *radioReferenceImporter) importTalkgroups(ctx context.Context, sys int, r io.Reader) ([]talkgroups.Talkgroup, error) { sc := bufio.NewScanner(r) tgs := make([]talkgroups.Talkgroup, 0, 8) - sysn, has := tgstore.From(ctx).SystemName(ctx, sys) + sysn, has := tgstore.FromCtx(ctx).SystemName(ctx, sys) if !has { return nil, tgstore.ErrNoSuchSystem } diff --git a/pkg/talkgroups/tgstore/store.go b/pkg/talkgroups/tgstore/store.go index 15bf696..9ec106f 100644 --- a/pkg/talkgroups/tgstore/store.go +++ b/pkg/talkgroups/tgstore/store.go @@ -3,7 +3,6 @@ package tgstore import ( "context" "errors" - "fmt" "strings" "sync" "time" @@ -72,7 +71,7 @@ func CtxWithStore(ctx context.Context, s Store) context.Context { return context.WithValue(ctx, StoreCtxKey, s) } -func From(ctx context.Context) Store { +func FromCtx(ctx context.Context) Store { s, ok := ctx.Value(StoreCtxKey).(Store) if !ok { return NewCache() @@ -307,10 +306,7 @@ func (t *cache) UpdateTG(ctx context.Context, input database.UpdateTalkgroupPara func (t *cache) LearnTG(ctx context.Context, c *calls.Call) (*tgsp.Talkgroup, error) { db := database.FromCtx(ctx) - err := db.AddTalkgroupWithLearnedFlag(ctx, int32(c.System), int32(c.Talkgroup)) - if err != nil { - return nil, fmt.Errorf("addTalkgroupWithLearnedFlag: %w", err) - } + sys, has := t.SystemName(ctx, c.System) if !has { return nil, ErrNoSuchSystem diff --git a/sql/postgres/queries/talkgroups.sql b/sql/postgres/queries/talkgroups.sql index d6f4a62..9be53fe 100644 --- a/sql/postgres/queries/talkgroups.sql +++ b/sql/postgres/queries/talkgroups.sql @@ -124,17 +124,6 @@ INSERT INTO talkgroup_versions(time, created_by, tg.learned FROM talkgroups tg WHERE tg.system_id = @system_id AND tg.tgid = @tgid; --- name: AddTalkgroupWithLearnedFlag :exec -INSERT INTO talkgroups ( - system_id, - tgid, - learned -) VALUES( - @system_id, - @tgid, - TRUE -); - -- name: AddLearnedTalkgroup :one INSERT INTO talkgroups( system_id, @@ -151,3 +140,47 @@ INSERT INTO talkgroups( sqlc.narg('alpha_tag'), sqlc.narg('tg_group') ) RETURNING *; + +-- name: RestoreTalkgroupVersion :one +INSERT INTO talkgroups( + system_id, + tgid, + name, + alpha_tag, + tg_group, + frequency, + metadata, + tags, + alert, + alert_config, + weight, + learned, + ignored +) +SELECT + system_id, + tgid, + name, + alpha_tag, + tg_group, + frequency, + metadata, + tags, + alert, + alert_config, + weight, + learned, + ignored +FROM talkgroup_versions tgv ON CONFLICT (system_id, tgid) DO UPDATE SET + name = excluded.name, + alpha_tag = excluded.alpha_tag, + tg_group = excluded.tg_group, + metadata = excluded.metadata, + tags = excluded.tags, + alert = excluded.alert, + alert_config = excluded.alert_config, + weight = excluded.weight, + learned = excluded.learner, + ignored = excluded.ignored +WHERE tgv.id = ANY(@version_ids) +RETURNING *;