Fix learning

This commit is contained in:
Daniel 2024-11-21 07:44:08 -05:00
parent 692f7d69a3
commit a1b751fdf0
10 changed files with 179 additions and 91 deletions

View file

@ -46,7 +46,7 @@ func (a *Alert) ToAddAlertParams() database.AddAlertParams {
// Make creates an alert for later rendering or storage.
func Make(ctx context.Context, score trending.Score[talkgroups.ID], origScore float64) (Alert, error) {
store := tgstore.From(ctx)
store := tgstore.FromCtx(ctx)
d := Alert{
Score: score,
Timestamp: time.Now(),

View file

@ -179,54 +179,6 @@ func (_c *Store_AddLearnedTalkgroup_Call) RunAndReturn(run func(context.Context,
return _c
}
// AddTalkgroupWithLearnedFlag provides a mock function with given fields: ctx, systemID, tGID
func (_m *Store) AddTalkgroupWithLearnedFlag(ctx context.Context, systemID int32, tGID int32) error {
ret := _m.Called(ctx, systemID, tGID)
if len(ret) == 0 {
panic("no return value specified for AddTalkgroupWithLearnedFlag")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, int32, int32) error); ok {
r0 = rf(ctx, systemID, tGID)
} else {
r0 = ret.Error(0)
}
return r0
}
// Store_AddTalkgroupWithLearnedFlag_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddTalkgroupWithLearnedFlag'
type Store_AddTalkgroupWithLearnedFlag_Call struct {
*mock.Call
}
// AddTalkgroupWithLearnedFlag is a helper method to define mock.On call
// - ctx context.Context
// - systemID int32
// - tGID int32
func (_e *Store_Expecter) AddTalkgroupWithLearnedFlag(ctx interface{}, systemID interface{}, tGID interface{}) *Store_AddTalkgroupWithLearnedFlag_Call {
return &Store_AddTalkgroupWithLearnedFlag_Call{Call: _e.mock.On("AddTalkgroupWithLearnedFlag", ctx, systemID, tGID)}
}
func (_c *Store_AddTalkgroupWithLearnedFlag_Call) Run(run func(ctx context.Context, systemID int32, tGID int32)) *Store_AddTalkgroupWithLearnedFlag_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int32), args[2].(int32))
})
return _c
}
func (_c *Store_AddTalkgroupWithLearnedFlag_Call) Return(_a0 error) *Store_AddTalkgroupWithLearnedFlag_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *Store_AddTalkgroupWithLearnedFlag_Call) RunAndReturn(run func(context.Context, int32, int32) error) *Store_AddTalkgroupWithLearnedFlag_Call {
_c.Call.Return(run)
return _c
}
// BulkSetTalkgroupTags provides a mock function with given fields: ctx, tgs, tags
func (_m *Store) BulkSetTalkgroupTags(ctx context.Context, tgs database.TGTuples, tags []string) error {
ret := _m.Called(ctx, tgs, tags)
@ -1569,6 +1521,63 @@ func (_c *Store_InTx_Call) RunAndReturn(run func(context.Context, func(database.
return _c
}
// RestoreTalkgroupVersion provides a mock function with given fields: ctx, versionIds
func (_m *Store) RestoreTalkgroupVersion(ctx context.Context, versionIds int) (database.Talkgroup, error) {
ret := _m.Called(ctx, versionIds)
if len(ret) == 0 {
panic("no return value specified for RestoreTalkgroupVersion")
}
var r0 database.Talkgroup
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, int) (database.Talkgroup, error)); ok {
return rf(ctx, versionIds)
}
if rf, ok := ret.Get(0).(func(context.Context, int) database.Talkgroup); ok {
r0 = rf(ctx, versionIds)
} else {
r0 = ret.Get(0).(database.Talkgroup)
}
if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
r1 = rf(ctx, versionIds)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Store_RestoreTalkgroupVersion_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RestoreTalkgroupVersion'
type Store_RestoreTalkgroupVersion_Call struct {
*mock.Call
}
// RestoreTalkgroupVersion is a helper method to define mock.On call
// - ctx context.Context
// - versionIds int
func (_e *Store_Expecter) RestoreTalkgroupVersion(ctx interface{}, versionIds interface{}) *Store_RestoreTalkgroupVersion_Call {
return &Store_RestoreTalkgroupVersion_Call{Call: _e.mock.On("RestoreTalkgroupVersion", ctx, versionIds)}
}
func (_c *Store_RestoreTalkgroupVersion_Call) Run(run func(ctx context.Context, versionIds int)) *Store_RestoreTalkgroupVersion_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int))
})
return _c
}
func (_c *Store_RestoreTalkgroupVersion_Call) Return(_a0 database.Talkgroup, _a1 error) *Store_RestoreTalkgroupVersion_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *Store_RestoreTalkgroupVersion_Call) RunAndReturn(run func(context.Context, int) (database.Talkgroup, error)) *Store_RestoreTalkgroupVersion_Call {
_c.Call.Return(run)
return _c
}
// SetCallTranscript provides a mock function with given fields: ctx, iD, transcript
func (_m *Store) SetCallTranscript(ctx context.Context, iD uuid.UUID, transcript *string) error {
ret := _m.Called(ctx, iD, transcript)

View file

@ -15,7 +15,6 @@ type Querier interface {
AddAlert(ctx context.Context, arg AddAlertParams) error
AddCall(ctx context.Context, arg AddCallParams) error
AddLearnedTalkgroup(ctx context.Context, arg AddLearnedTalkgroupParams) (Talkgroup, error)
AddTalkgroupWithLearnedFlag(ctx context.Context, systemID int32, tGID int32) error
CreateAPIKey(ctx context.Context, owner int, expires pgtype.Timestamp, disabled *bool) (ApiKey, error)
CreateUser(ctx context.Context, arg CreateUserParams) (User, error)
DeleteAPIKey(ctx context.Context, apiKey string) error
@ -35,6 +34,7 @@ type Querier interface {
GetUserByUID(ctx context.Context, id int) (User, error)
GetUserByUsername(ctx context.Context, username string) (User, error)
GetUsers(ctx context.Context) ([]User, error)
RestoreTalkgroupVersion(ctx context.Context, versionIds int) (Talkgroup, error)
SetCallTranscript(ctx context.Context, iD uuid.UUID, transcript *string) error
SetTalkgroupTags(ctx context.Context, tags []string, systemID int32, tGID int32) error
StoreTGVersion(ctx context.Context, arg []StoreTGVersionParams) *StoreTGVersionBatchResults

View file

@ -66,23 +66,6 @@ func (q *Queries) AddLearnedTalkgroup(ctx context.Context, arg AddLearnedTalkgro
return i, err
}
const addTalkgroupWithLearnedFlag = `-- name: AddTalkgroupWithLearnedFlag :exec
INSERT INTO talkgroups (
system_id,
tgid,
learned
) VALUES(
$1,
$2,
TRUE
)
`
func (q *Queries) AddTalkgroupWithLearnedFlag(ctx context.Context, systemID int32, tGID int32) error {
_, err := q.db.Exec(ctx, addTalkgroupWithLearnedFlag, systemID, tGID)
return err
}
const getSystemName = `-- name: GetSystemName :one
SELECT name FROM systems WHERE id = $1
`
@ -394,6 +377,73 @@ func (q *Queries) GetTalkgroupsWithLearnedBySystem(ctx context.Context, system i
return items, nil
}
const restoreTalkgroupVersion = `-- name: RestoreTalkgroupVersion :one
INSERT INTO talkgroups(
system_id,
tgid,
name,
alpha_tag,
tg_group,
frequency,
metadata,
tags,
alert,
alert_config,
weight,
learned,
ignored
)
SELECT
system_id,
tgid,
name,
alpha_tag,
tg_group,
frequency,
metadata,
tags,
alert,
alert_config,
weight,
learned,
ignored
FROM talkgroup_versions tgv ON CONFLICT (system_id, tgid) DO UPDATE SET
name = excluded.name,
alpha_tag = excluded.alpha_tag,
tg_group = excluded.tg_group,
metadata = excluded.metadata,
tags = excluded.tags,
alert = excluded.alert,
alert_config = excluded.alert_config,
weight = excluded.weight,
learned = excluded.learner,
ignored = excluded.ignored
WHERE tgv.id = ANY($1)
RETURNING id, system_id, tgid, name, alpha_tag, tg_group, frequency, metadata, tags, alert, alert_config, weight, learned, ignored
`
func (q *Queries) RestoreTalkgroupVersion(ctx context.Context, versionIds int) (Talkgroup, error) {
row := q.db.QueryRow(ctx, restoreTalkgroupVersion, versionIds)
var i Talkgroup
err := row.Scan(
&i.ID,
&i.SystemID,
&i.TGID,
&i.Name,
&i.AlphaTag,
&i.TGGroup,
&i.Frequency,
&i.Metadata,
&i.Tags,
&i.Alert,
&i.AlertConfig,
&i.Weight,
&i.Learned,
&i.Ignored,
)
return i, err
}
const setTalkgroupTags = `-- name: SetTalkgroupTags :exec
UPDATE talkgroups SET tags = $1
WHERE system_id = $2 AND tgid = $3

View file

@ -60,7 +60,7 @@ func (c *client) SendError(cmd *pb.Command, err error) {
}
func (c *client) Talkgroup(ctx context.Context, tg *pb.Talkgroup) error {
tgi, err := tgstore.From(ctx).TG(ctx, talkgroups.TG(tg.System, tg.Talkgroup))
tgi, err := tgstore.FromCtx(ctx).TG(ctx, talkgroups.TG(tg.System, tg.Talkgroup))
if err != nil {
if err != tgstore.ErrNotFound {
log.Error().Err(err).Int32("sys", tg.System).Int32("tg", tg.Talkgroup).Msg("get talkgroup fail")

View file

@ -54,7 +54,7 @@ func (t tgParams) ToID() talkgroups.ID {
func (tga *talkgroupAPI) get(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
tgs := tgstore.From(ctx)
tgs := tgstore.FromCtx(ctx)
var p tgParams
@ -92,7 +92,7 @@ func (tga *talkgroupAPI) put(w http.ResponseWriter, r *http.Request) {
}
ctx := r.Context()
tgs := tgstore.From(ctx)
tgs := tgstore.FromCtx(ctx)
input := database.UpdateTalkgroupParams{}
@ -143,7 +143,7 @@ func (tga *talkgroupAPI) putTalkgroups(w http.ResponseWriter, r *http.Request) {
}
ctx := r.Context()
tgs := tgstore.From(ctx)
tgs := tgstore.FromCtx(ctx)
var input []database.UpsertTalkgroupParams

View file

@ -47,12 +47,12 @@ func (s *DatabaseSink) Call(ctx context.Context, call *calls.Call) error {
return s.db.InTx(ctx, func(tx database.Store) error {
_, err := s.tgs.LearnTG(ctx, call)
if err != nil {
return fmt.Errorf("add call: learn tg: %w", err)
return fmt.Errorf("learn tg: %w", err)
}
err = tx.AddCall(ctx, params)
if err != nil {
return fmt.Errorf("add call: retry: %w", err)
return fmt.Errorf("learn tg retry: %w", err)
}
return nil

View file

@ -67,7 +67,7 @@ var rrRE = regexp.MustCompile(`DEC\s+HEX\s+Mode\s+Alpha Tag\s+Description\s+Tag`
func (rr *radioReferenceImporter) importTalkgroups(ctx context.Context, sys int, r io.Reader) ([]talkgroups.Talkgroup, error) {
sc := bufio.NewScanner(r)
tgs := make([]talkgroups.Talkgroup, 0, 8)
sysn, has := tgstore.From(ctx).SystemName(ctx, sys)
sysn, has := tgstore.FromCtx(ctx).SystemName(ctx, sys)
if !has {
return nil, tgstore.ErrNoSuchSystem
}

View file

@ -3,7 +3,6 @@ package tgstore
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"time"
@ -72,7 +71,7 @@ func CtxWithStore(ctx context.Context, s Store) context.Context {
return context.WithValue(ctx, StoreCtxKey, s)
}
func From(ctx context.Context) Store {
func FromCtx(ctx context.Context) Store {
s, ok := ctx.Value(StoreCtxKey).(Store)
if !ok {
return NewCache()
@ -307,10 +306,7 @@ func (t *cache) UpdateTG(ctx context.Context, input database.UpdateTalkgroupPara
func (t *cache) LearnTG(ctx context.Context, c *calls.Call) (*tgsp.Talkgroup, error) {
db := database.FromCtx(ctx)
err := db.AddTalkgroupWithLearnedFlag(ctx, int32(c.System), int32(c.Talkgroup))
if err != nil {
return nil, fmt.Errorf("addTalkgroupWithLearnedFlag: %w", err)
}
sys, has := t.SystemName(ctx, c.System)
if !has {
return nil, ErrNoSuchSystem

View file

@ -124,17 +124,6 @@ INSERT INTO talkgroup_versions(time, created_by,
tg.learned
FROM talkgroups tg WHERE tg.system_id = @system_id AND tg.tgid = @tgid;
-- name: AddTalkgroupWithLearnedFlag :exec
INSERT INTO talkgroups (
system_id,
tgid,
learned
) VALUES(
@system_id,
@tgid,
TRUE
);
-- name: AddLearnedTalkgroup :one
INSERT INTO talkgroups(
system_id,
@ -151,3 +140,47 @@ INSERT INTO talkgroups(
sqlc.narg('alpha_tag'),
sqlc.narg('tg_group')
) RETURNING *;
-- name: RestoreTalkgroupVersion :one
INSERT INTO talkgroups(
system_id,
tgid,
name,
alpha_tag,
tg_group,
frequency,
metadata,
tags,
alert,
alert_config,
weight,
learned,
ignored
)
SELECT
system_id,
tgid,
name,
alpha_tag,
tg_group,
frequency,
metadata,
tags,
alert,
alert_config,
weight,
learned,
ignored
FROM talkgroup_versions tgv ON CONFLICT (system_id, tgid) DO UPDATE SET
name = excluded.name,
alpha_tag = excluded.alpha_tag,
tg_group = excluded.tg_group,
metadata = excluded.metadata,
tags = excluded.tags,
alert = excluded.alert,
alert_config = excluded.alert_config,
weight = excluded.weight,
learned = excluded.learner,
ignored = excluded.ignored
WHERE tgv.id = ANY(@version_ids)
RETURNING *;