Fix learning

This commit is contained in:
Daniel 2024-11-21 07:44:08 -05:00
parent 692f7d69a3
commit a1b751fdf0
10 changed files with 179 additions and 91 deletions

View file

@ -46,7 +46,7 @@ func (a *Alert) ToAddAlertParams() database.AddAlertParams {
// Make creates an alert for later rendering or storage. // Make creates an alert for later rendering or storage.
func Make(ctx context.Context, score trending.Score[talkgroups.ID], origScore float64) (Alert, error) { func Make(ctx context.Context, score trending.Score[talkgroups.ID], origScore float64) (Alert, error) {
store := tgstore.From(ctx) store := tgstore.FromCtx(ctx)
d := Alert{ d := Alert{
Score: score, Score: score,
Timestamp: time.Now(), Timestamp: time.Now(),

View file

@ -179,54 +179,6 @@ func (_c *Store_AddLearnedTalkgroup_Call) RunAndReturn(run func(context.Context,
return _c return _c
} }
// AddTalkgroupWithLearnedFlag provides a mock function with given fields: ctx, systemID, tGID
func (_m *Store) AddTalkgroupWithLearnedFlag(ctx context.Context, systemID int32, tGID int32) error {
ret := _m.Called(ctx, systemID, tGID)
if len(ret) == 0 {
panic("no return value specified for AddTalkgroupWithLearnedFlag")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, int32, int32) error); ok {
r0 = rf(ctx, systemID, tGID)
} else {
r0 = ret.Error(0)
}
return r0
}
// Store_AddTalkgroupWithLearnedFlag_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddTalkgroupWithLearnedFlag'
type Store_AddTalkgroupWithLearnedFlag_Call struct {
*mock.Call
}
// AddTalkgroupWithLearnedFlag is a helper method to define mock.On call
// - ctx context.Context
// - systemID int32
// - tGID int32
func (_e *Store_Expecter) AddTalkgroupWithLearnedFlag(ctx interface{}, systemID interface{}, tGID interface{}) *Store_AddTalkgroupWithLearnedFlag_Call {
return &Store_AddTalkgroupWithLearnedFlag_Call{Call: _e.mock.On("AddTalkgroupWithLearnedFlag", ctx, systemID, tGID)}
}
func (_c *Store_AddTalkgroupWithLearnedFlag_Call) Run(run func(ctx context.Context, systemID int32, tGID int32)) *Store_AddTalkgroupWithLearnedFlag_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int32), args[2].(int32))
})
return _c
}
func (_c *Store_AddTalkgroupWithLearnedFlag_Call) Return(_a0 error) *Store_AddTalkgroupWithLearnedFlag_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *Store_AddTalkgroupWithLearnedFlag_Call) RunAndReturn(run func(context.Context, int32, int32) error) *Store_AddTalkgroupWithLearnedFlag_Call {
_c.Call.Return(run)
return _c
}
// BulkSetTalkgroupTags provides a mock function with given fields: ctx, tgs, tags // BulkSetTalkgroupTags provides a mock function with given fields: ctx, tgs, tags
func (_m *Store) BulkSetTalkgroupTags(ctx context.Context, tgs database.TGTuples, tags []string) error { func (_m *Store) BulkSetTalkgroupTags(ctx context.Context, tgs database.TGTuples, tags []string) error {
ret := _m.Called(ctx, tgs, tags) ret := _m.Called(ctx, tgs, tags)
@ -1569,6 +1521,63 @@ func (_c *Store_InTx_Call) RunAndReturn(run func(context.Context, func(database.
return _c return _c
} }
// RestoreTalkgroupVersion provides a mock function with given fields: ctx, versionIds
func (_m *Store) RestoreTalkgroupVersion(ctx context.Context, versionIds int) (database.Talkgroup, error) {
ret := _m.Called(ctx, versionIds)
if len(ret) == 0 {
panic("no return value specified for RestoreTalkgroupVersion")
}
var r0 database.Talkgroup
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, int) (database.Talkgroup, error)); ok {
return rf(ctx, versionIds)
}
if rf, ok := ret.Get(0).(func(context.Context, int) database.Talkgroup); ok {
r0 = rf(ctx, versionIds)
} else {
r0 = ret.Get(0).(database.Talkgroup)
}
if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
r1 = rf(ctx, versionIds)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Store_RestoreTalkgroupVersion_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RestoreTalkgroupVersion'
type Store_RestoreTalkgroupVersion_Call struct {
*mock.Call
}
// RestoreTalkgroupVersion is a helper method to define mock.On call
// - ctx context.Context
// - versionIds int
func (_e *Store_Expecter) RestoreTalkgroupVersion(ctx interface{}, versionIds interface{}) *Store_RestoreTalkgroupVersion_Call {
return &Store_RestoreTalkgroupVersion_Call{Call: _e.mock.On("RestoreTalkgroupVersion", ctx, versionIds)}
}
func (_c *Store_RestoreTalkgroupVersion_Call) Run(run func(ctx context.Context, versionIds int)) *Store_RestoreTalkgroupVersion_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int))
})
return _c
}
func (_c *Store_RestoreTalkgroupVersion_Call) Return(_a0 database.Talkgroup, _a1 error) *Store_RestoreTalkgroupVersion_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *Store_RestoreTalkgroupVersion_Call) RunAndReturn(run func(context.Context, int) (database.Talkgroup, error)) *Store_RestoreTalkgroupVersion_Call {
_c.Call.Return(run)
return _c
}
// SetCallTranscript provides a mock function with given fields: ctx, iD, transcript // SetCallTranscript provides a mock function with given fields: ctx, iD, transcript
func (_m *Store) SetCallTranscript(ctx context.Context, iD uuid.UUID, transcript *string) error { func (_m *Store) SetCallTranscript(ctx context.Context, iD uuid.UUID, transcript *string) error {
ret := _m.Called(ctx, iD, transcript) ret := _m.Called(ctx, iD, transcript)

View file

@ -15,7 +15,6 @@ type Querier interface {
AddAlert(ctx context.Context, arg AddAlertParams) error AddAlert(ctx context.Context, arg AddAlertParams) error
AddCall(ctx context.Context, arg AddCallParams) error AddCall(ctx context.Context, arg AddCallParams) error
AddLearnedTalkgroup(ctx context.Context, arg AddLearnedTalkgroupParams) (Talkgroup, error) AddLearnedTalkgroup(ctx context.Context, arg AddLearnedTalkgroupParams) (Talkgroup, error)
AddTalkgroupWithLearnedFlag(ctx context.Context, systemID int32, tGID int32) error
CreateAPIKey(ctx context.Context, owner int, expires pgtype.Timestamp, disabled *bool) (ApiKey, error) CreateAPIKey(ctx context.Context, owner int, expires pgtype.Timestamp, disabled *bool) (ApiKey, error)
CreateUser(ctx context.Context, arg CreateUserParams) (User, error) CreateUser(ctx context.Context, arg CreateUserParams) (User, error)
DeleteAPIKey(ctx context.Context, apiKey string) error DeleteAPIKey(ctx context.Context, apiKey string) error
@ -35,6 +34,7 @@ type Querier interface {
GetUserByUID(ctx context.Context, id int) (User, error) GetUserByUID(ctx context.Context, id int) (User, error)
GetUserByUsername(ctx context.Context, username string) (User, error) GetUserByUsername(ctx context.Context, username string) (User, error)
GetUsers(ctx context.Context) ([]User, error) GetUsers(ctx context.Context) ([]User, error)
RestoreTalkgroupVersion(ctx context.Context, versionIds int) (Talkgroup, error)
SetCallTranscript(ctx context.Context, iD uuid.UUID, transcript *string) error SetCallTranscript(ctx context.Context, iD uuid.UUID, transcript *string) error
SetTalkgroupTags(ctx context.Context, tags []string, systemID int32, tGID int32) error SetTalkgroupTags(ctx context.Context, tags []string, systemID int32, tGID int32) error
StoreTGVersion(ctx context.Context, arg []StoreTGVersionParams) *StoreTGVersionBatchResults StoreTGVersion(ctx context.Context, arg []StoreTGVersionParams) *StoreTGVersionBatchResults

View file

@ -66,23 +66,6 @@ func (q *Queries) AddLearnedTalkgroup(ctx context.Context, arg AddLearnedTalkgro
return i, err return i, err
} }
const addTalkgroupWithLearnedFlag = `-- name: AddTalkgroupWithLearnedFlag :exec
INSERT INTO talkgroups (
system_id,
tgid,
learned
) VALUES(
$1,
$2,
TRUE
)
`
func (q *Queries) AddTalkgroupWithLearnedFlag(ctx context.Context, systemID int32, tGID int32) error {
_, err := q.db.Exec(ctx, addTalkgroupWithLearnedFlag, systemID, tGID)
return err
}
const getSystemName = `-- name: GetSystemName :one const getSystemName = `-- name: GetSystemName :one
SELECT name FROM systems WHERE id = $1 SELECT name FROM systems WHERE id = $1
` `
@ -394,6 +377,73 @@ func (q *Queries) GetTalkgroupsWithLearnedBySystem(ctx context.Context, system i
return items, nil return items, nil
} }
const restoreTalkgroupVersion = `-- name: RestoreTalkgroupVersion :one
INSERT INTO talkgroups(
system_id,
tgid,
name,
alpha_tag,
tg_group,
frequency,
metadata,
tags,
alert,
alert_config,
weight,
learned,
ignored
)
SELECT
system_id,
tgid,
name,
alpha_tag,
tg_group,
frequency,
metadata,
tags,
alert,
alert_config,
weight,
learned,
ignored
FROM talkgroup_versions tgv ON CONFLICT (system_id, tgid) DO UPDATE SET
name = excluded.name,
alpha_tag = excluded.alpha_tag,
tg_group = excluded.tg_group,
metadata = excluded.metadata,
tags = excluded.tags,
alert = excluded.alert,
alert_config = excluded.alert_config,
weight = excluded.weight,
learned = excluded.learner,
ignored = excluded.ignored
WHERE tgv.id = ANY($1)
RETURNING id, system_id, tgid, name, alpha_tag, tg_group, frequency, metadata, tags, alert, alert_config, weight, learned, ignored
`
func (q *Queries) RestoreTalkgroupVersion(ctx context.Context, versionIds int) (Talkgroup, error) {
row := q.db.QueryRow(ctx, restoreTalkgroupVersion, versionIds)
var i Talkgroup
err := row.Scan(
&i.ID,
&i.SystemID,
&i.TGID,
&i.Name,
&i.AlphaTag,
&i.TGGroup,
&i.Frequency,
&i.Metadata,
&i.Tags,
&i.Alert,
&i.AlertConfig,
&i.Weight,
&i.Learned,
&i.Ignored,
)
return i, err
}
const setTalkgroupTags = `-- name: SetTalkgroupTags :exec const setTalkgroupTags = `-- name: SetTalkgroupTags :exec
UPDATE talkgroups SET tags = $1 UPDATE talkgroups SET tags = $1
WHERE system_id = $2 AND tgid = $3 WHERE system_id = $2 AND tgid = $3

View file

@ -60,7 +60,7 @@ func (c *client) SendError(cmd *pb.Command, err error) {
} }
func (c *client) Talkgroup(ctx context.Context, tg *pb.Talkgroup) error { func (c *client) Talkgroup(ctx context.Context, tg *pb.Talkgroup) error {
tgi, err := tgstore.From(ctx).TG(ctx, talkgroups.TG(tg.System, tg.Talkgroup)) tgi, err := tgstore.FromCtx(ctx).TG(ctx, talkgroups.TG(tg.System, tg.Talkgroup))
if err != nil { if err != nil {
if err != tgstore.ErrNotFound { if err != tgstore.ErrNotFound {
log.Error().Err(err).Int32("sys", tg.System).Int32("tg", tg.Talkgroup).Msg("get talkgroup fail") log.Error().Err(err).Int32("sys", tg.System).Int32("tg", tg.Talkgroup).Msg("get talkgroup fail")

View file

@ -54,7 +54,7 @@ func (t tgParams) ToID() talkgroups.ID {
func (tga *talkgroupAPI) get(w http.ResponseWriter, r *http.Request) { func (tga *talkgroupAPI) get(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
tgs := tgstore.From(ctx) tgs := tgstore.FromCtx(ctx)
var p tgParams var p tgParams
@ -92,7 +92,7 @@ func (tga *talkgroupAPI) put(w http.ResponseWriter, r *http.Request) {
} }
ctx := r.Context() ctx := r.Context()
tgs := tgstore.From(ctx) tgs := tgstore.FromCtx(ctx)
input := database.UpdateTalkgroupParams{} input := database.UpdateTalkgroupParams{}
@ -143,7 +143,7 @@ func (tga *talkgroupAPI) putTalkgroups(w http.ResponseWriter, r *http.Request) {
} }
ctx := r.Context() ctx := r.Context()
tgs := tgstore.From(ctx) tgs := tgstore.FromCtx(ctx)
var input []database.UpsertTalkgroupParams var input []database.UpsertTalkgroupParams

View file

@ -47,12 +47,12 @@ func (s *DatabaseSink) Call(ctx context.Context, call *calls.Call) error {
return s.db.InTx(ctx, func(tx database.Store) error { return s.db.InTx(ctx, func(tx database.Store) error {
_, err := s.tgs.LearnTG(ctx, call) _, err := s.tgs.LearnTG(ctx, call)
if err != nil { if err != nil {
return fmt.Errorf("add call: learn tg: %w", err) return fmt.Errorf("learn tg: %w", err)
} }
err = tx.AddCall(ctx, params) err = tx.AddCall(ctx, params)
if err != nil { if err != nil {
return fmt.Errorf("add call: retry: %w", err) return fmt.Errorf("learn tg retry: %w", err)
} }
return nil return nil

View file

@ -67,7 +67,7 @@ var rrRE = regexp.MustCompile(`DEC\s+HEX\s+Mode\s+Alpha Tag\s+Description\s+Tag`
func (rr *radioReferenceImporter) importTalkgroups(ctx context.Context, sys int, r io.Reader) ([]talkgroups.Talkgroup, error) { func (rr *radioReferenceImporter) importTalkgroups(ctx context.Context, sys int, r io.Reader) ([]talkgroups.Talkgroup, error) {
sc := bufio.NewScanner(r) sc := bufio.NewScanner(r)
tgs := make([]talkgroups.Talkgroup, 0, 8) tgs := make([]talkgroups.Talkgroup, 0, 8)
sysn, has := tgstore.From(ctx).SystemName(ctx, sys) sysn, has := tgstore.FromCtx(ctx).SystemName(ctx, sys)
if !has { if !has {
return nil, tgstore.ErrNoSuchSystem return nil, tgstore.ErrNoSuchSystem
} }

View file

@ -3,7 +3,6 @@ package tgstore
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -72,7 +71,7 @@ func CtxWithStore(ctx context.Context, s Store) context.Context {
return context.WithValue(ctx, StoreCtxKey, s) return context.WithValue(ctx, StoreCtxKey, s)
} }
func From(ctx context.Context) Store { func FromCtx(ctx context.Context) Store {
s, ok := ctx.Value(StoreCtxKey).(Store) s, ok := ctx.Value(StoreCtxKey).(Store)
if !ok { if !ok {
return NewCache() return NewCache()
@ -307,10 +306,7 @@ func (t *cache) UpdateTG(ctx context.Context, input database.UpdateTalkgroupPara
func (t *cache) LearnTG(ctx context.Context, c *calls.Call) (*tgsp.Talkgroup, error) { func (t *cache) LearnTG(ctx context.Context, c *calls.Call) (*tgsp.Talkgroup, error) {
db := database.FromCtx(ctx) db := database.FromCtx(ctx)
err := db.AddTalkgroupWithLearnedFlag(ctx, int32(c.System), int32(c.Talkgroup))
if err != nil {
return nil, fmt.Errorf("addTalkgroupWithLearnedFlag: %w", err)
}
sys, has := t.SystemName(ctx, c.System) sys, has := t.SystemName(ctx, c.System)
if !has { if !has {
return nil, ErrNoSuchSystem return nil, ErrNoSuchSystem

View file

@ -124,17 +124,6 @@ INSERT INTO talkgroup_versions(time, created_by,
tg.learned tg.learned
FROM talkgroups tg WHERE tg.system_id = @system_id AND tg.tgid = @tgid; FROM talkgroups tg WHERE tg.system_id = @system_id AND tg.tgid = @tgid;
-- name: AddTalkgroupWithLearnedFlag :exec
INSERT INTO talkgroups (
system_id,
tgid,
learned
) VALUES(
@system_id,
@tgid,
TRUE
);
-- name: AddLearnedTalkgroup :one -- name: AddLearnedTalkgroup :one
INSERT INTO talkgroups( INSERT INTO talkgroups(
system_id, system_id,
@ -151,3 +140,47 @@ INSERT INTO talkgroups(
sqlc.narg('alpha_tag'), sqlc.narg('alpha_tag'),
sqlc.narg('tg_group') sqlc.narg('tg_group')
) RETURNING *; ) RETURNING *;
-- name: RestoreTalkgroupVersion :one
INSERT INTO talkgroups(
system_id,
tgid,
name,
alpha_tag,
tg_group,
frequency,
metadata,
tags,
alert,
alert_config,
weight,
learned,
ignored
)
SELECT
system_id,
tgid,
name,
alpha_tag,
tg_group,
frequency,
metadata,
tags,
alert,
alert_config,
weight,
learned,
ignored
FROM talkgroup_versions tgv ON CONFLICT (system_id, tgid) DO UPDATE SET
name = excluded.name,
alpha_tag = excluded.alpha_tag,
tg_group = excluded.tg_group,
metadata = excluded.metadata,
tags = excluded.tags,
alert = excluded.alert,
alert_config = excluded.alert_config,
weight = excluded.weight,
learned = excluded.learner,
ignored = excluded.ignored
WHERE tgv.id = ANY(@version_ids)
RETURNING *;