Fix learning
This commit is contained in:
parent
692f7d69a3
commit
a1b751fdf0
10 changed files with 179 additions and 91 deletions
|
@ -46,7 +46,7 @@ func (a *Alert) ToAddAlertParams() database.AddAlertParams {
|
|||
|
||||
// Make creates an alert for later rendering or storage.
|
||||
func Make(ctx context.Context, score trending.Score[talkgroups.ID], origScore float64) (Alert, error) {
|
||||
store := tgstore.From(ctx)
|
||||
store := tgstore.FromCtx(ctx)
|
||||
d := Alert{
|
||||
Score: score,
|
||||
Timestamp: time.Now(),
|
||||
|
|
|
@ -179,54 +179,6 @@ func (_c *Store_AddLearnedTalkgroup_Call) RunAndReturn(run func(context.Context,
|
|||
return _c
|
||||
}
|
||||
|
||||
// AddTalkgroupWithLearnedFlag provides a mock function with given fields: ctx, systemID, tGID
|
||||
func (_m *Store) AddTalkgroupWithLearnedFlag(ctx context.Context, systemID int32, tGID int32) error {
|
||||
ret := _m.Called(ctx, systemID, tGID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for AddTalkgroupWithLearnedFlag")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, int32, int32) error); ok {
|
||||
r0 = rf(ctx, systemID, tGID)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Store_AddTalkgroupWithLearnedFlag_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddTalkgroupWithLearnedFlag'
|
||||
type Store_AddTalkgroupWithLearnedFlag_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// AddTalkgroupWithLearnedFlag is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - systemID int32
|
||||
// - tGID int32
|
||||
func (_e *Store_Expecter) AddTalkgroupWithLearnedFlag(ctx interface{}, systemID interface{}, tGID interface{}) *Store_AddTalkgroupWithLearnedFlag_Call {
|
||||
return &Store_AddTalkgroupWithLearnedFlag_Call{Call: _e.mock.On("AddTalkgroupWithLearnedFlag", ctx, systemID, tGID)}
|
||||
}
|
||||
|
||||
func (_c *Store_AddTalkgroupWithLearnedFlag_Call) Run(run func(ctx context.Context, systemID int32, tGID int32)) *Store_AddTalkgroupWithLearnedFlag_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(int32), args[2].(int32))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Store_AddTalkgroupWithLearnedFlag_Call) Return(_a0 error) *Store_AddTalkgroupWithLearnedFlag_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Store_AddTalkgroupWithLearnedFlag_Call) RunAndReturn(run func(context.Context, int32, int32) error) *Store_AddTalkgroupWithLearnedFlag_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// BulkSetTalkgroupTags provides a mock function with given fields: ctx, tgs, tags
|
||||
func (_m *Store) BulkSetTalkgroupTags(ctx context.Context, tgs database.TGTuples, tags []string) error {
|
||||
ret := _m.Called(ctx, tgs, tags)
|
||||
|
@ -1569,6 +1521,63 @@ func (_c *Store_InTx_Call) RunAndReturn(run func(context.Context, func(database.
|
|||
return _c
|
||||
}
|
||||
|
||||
// RestoreTalkgroupVersion provides a mock function with given fields: ctx, versionIds
|
||||
func (_m *Store) RestoreTalkgroupVersion(ctx context.Context, versionIds int) (database.Talkgroup, error) {
|
||||
ret := _m.Called(ctx, versionIds)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RestoreTalkgroupVersion")
|
||||
}
|
||||
|
||||
var r0 database.Talkgroup
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, int) (database.Talkgroup, error)); ok {
|
||||
return rf(ctx, versionIds)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, int) database.Talkgroup); ok {
|
||||
r0 = rf(ctx, versionIds)
|
||||
} else {
|
||||
r0 = ret.Get(0).(database.Talkgroup)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
|
||||
r1 = rf(ctx, versionIds)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Store_RestoreTalkgroupVersion_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RestoreTalkgroupVersion'
|
||||
type Store_RestoreTalkgroupVersion_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RestoreTalkgroupVersion is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - versionIds int
|
||||
func (_e *Store_Expecter) RestoreTalkgroupVersion(ctx interface{}, versionIds interface{}) *Store_RestoreTalkgroupVersion_Call {
|
||||
return &Store_RestoreTalkgroupVersion_Call{Call: _e.mock.On("RestoreTalkgroupVersion", ctx, versionIds)}
|
||||
}
|
||||
|
||||
func (_c *Store_RestoreTalkgroupVersion_Call) Run(run func(ctx context.Context, versionIds int)) *Store_RestoreTalkgroupVersion_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(int))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Store_RestoreTalkgroupVersion_Call) Return(_a0 database.Talkgroup, _a1 error) *Store_RestoreTalkgroupVersion_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Store_RestoreTalkgroupVersion_Call) RunAndReturn(run func(context.Context, int) (database.Talkgroup, error)) *Store_RestoreTalkgroupVersion_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetCallTranscript provides a mock function with given fields: ctx, iD, transcript
|
||||
func (_m *Store) SetCallTranscript(ctx context.Context, iD uuid.UUID, transcript *string) error {
|
||||
ret := _m.Called(ctx, iD, transcript)
|
||||
|
|
|
@ -15,7 +15,6 @@ type Querier interface {
|
|||
AddAlert(ctx context.Context, arg AddAlertParams) error
|
||||
AddCall(ctx context.Context, arg AddCallParams) error
|
||||
AddLearnedTalkgroup(ctx context.Context, arg AddLearnedTalkgroupParams) (Talkgroup, error)
|
||||
AddTalkgroupWithLearnedFlag(ctx context.Context, systemID int32, tGID int32) error
|
||||
CreateAPIKey(ctx context.Context, owner int, expires pgtype.Timestamp, disabled *bool) (ApiKey, error)
|
||||
CreateUser(ctx context.Context, arg CreateUserParams) (User, error)
|
||||
DeleteAPIKey(ctx context.Context, apiKey string) error
|
||||
|
@ -35,6 +34,7 @@ type Querier interface {
|
|||
GetUserByUID(ctx context.Context, id int) (User, error)
|
||||
GetUserByUsername(ctx context.Context, username string) (User, error)
|
||||
GetUsers(ctx context.Context) ([]User, error)
|
||||
RestoreTalkgroupVersion(ctx context.Context, versionIds int) (Talkgroup, error)
|
||||
SetCallTranscript(ctx context.Context, iD uuid.UUID, transcript *string) error
|
||||
SetTalkgroupTags(ctx context.Context, tags []string, systemID int32, tGID int32) error
|
||||
StoreTGVersion(ctx context.Context, arg []StoreTGVersionParams) *StoreTGVersionBatchResults
|
||||
|
|
|
@ -66,23 +66,6 @@ func (q *Queries) AddLearnedTalkgroup(ctx context.Context, arg AddLearnedTalkgro
|
|||
return i, err
|
||||
}
|
||||
|
||||
const addTalkgroupWithLearnedFlag = `-- name: AddTalkgroupWithLearnedFlag :exec
|
||||
INSERT INTO talkgroups (
|
||||
system_id,
|
||||
tgid,
|
||||
learned
|
||||
) VALUES(
|
||||
$1,
|
||||
$2,
|
||||
TRUE
|
||||
)
|
||||
`
|
||||
|
||||
func (q *Queries) AddTalkgroupWithLearnedFlag(ctx context.Context, systemID int32, tGID int32) error {
|
||||
_, err := q.db.Exec(ctx, addTalkgroupWithLearnedFlag, systemID, tGID)
|
||||
return err
|
||||
}
|
||||
|
||||
const getSystemName = `-- name: GetSystemName :one
|
||||
SELECT name FROM systems WHERE id = $1
|
||||
`
|
||||
|
@ -394,6 +377,73 @@ func (q *Queries) GetTalkgroupsWithLearnedBySystem(ctx context.Context, system i
|
|||
return items, nil
|
||||
}
|
||||
|
||||
const restoreTalkgroupVersion = `-- name: RestoreTalkgroupVersion :one
|
||||
INSERT INTO talkgroups(
|
||||
system_id,
|
||||
tgid,
|
||||
name,
|
||||
alpha_tag,
|
||||
tg_group,
|
||||
frequency,
|
||||
metadata,
|
||||
tags,
|
||||
alert,
|
||||
alert_config,
|
||||
weight,
|
||||
learned,
|
||||
ignored
|
||||
)
|
||||
SELECT
|
||||
system_id,
|
||||
tgid,
|
||||
name,
|
||||
alpha_tag,
|
||||
tg_group,
|
||||
frequency,
|
||||
metadata,
|
||||
tags,
|
||||
alert,
|
||||
alert_config,
|
||||
weight,
|
||||
learned,
|
||||
ignored
|
||||
FROM talkgroup_versions tgv ON CONFLICT (system_id, tgid) DO UPDATE SET
|
||||
name = excluded.name,
|
||||
alpha_tag = excluded.alpha_tag,
|
||||
tg_group = excluded.tg_group,
|
||||
metadata = excluded.metadata,
|
||||
tags = excluded.tags,
|
||||
alert = excluded.alert,
|
||||
alert_config = excluded.alert_config,
|
||||
weight = excluded.weight,
|
||||
learned = excluded.learner,
|
||||
ignored = excluded.ignored
|
||||
WHERE tgv.id = ANY($1)
|
||||
RETURNING id, system_id, tgid, name, alpha_tag, tg_group, frequency, metadata, tags, alert, alert_config, weight, learned, ignored
|
||||
`
|
||||
|
||||
func (q *Queries) RestoreTalkgroupVersion(ctx context.Context, versionIds int) (Talkgroup, error) {
|
||||
row := q.db.QueryRow(ctx, restoreTalkgroupVersion, versionIds)
|
||||
var i Talkgroup
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.SystemID,
|
||||
&i.TGID,
|
||||
&i.Name,
|
||||
&i.AlphaTag,
|
||||
&i.TGGroup,
|
||||
&i.Frequency,
|
||||
&i.Metadata,
|
||||
&i.Tags,
|
||||
&i.Alert,
|
||||
&i.AlertConfig,
|
||||
&i.Weight,
|
||||
&i.Learned,
|
||||
&i.Ignored,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const setTalkgroupTags = `-- name: SetTalkgroupTags :exec
|
||||
UPDATE talkgroups SET tags = $1
|
||||
WHERE system_id = $2 AND tgid = $3
|
||||
|
|
|
@ -60,7 +60,7 @@ func (c *client) SendError(cmd *pb.Command, err error) {
|
|||
}
|
||||
|
||||
func (c *client) Talkgroup(ctx context.Context, tg *pb.Talkgroup) error {
|
||||
tgi, err := tgstore.From(ctx).TG(ctx, talkgroups.TG(tg.System, tg.Talkgroup))
|
||||
tgi, err := tgstore.FromCtx(ctx).TG(ctx, talkgroups.TG(tg.System, tg.Talkgroup))
|
||||
if err != nil {
|
||||
if err != tgstore.ErrNotFound {
|
||||
log.Error().Err(err).Int32("sys", tg.System).Int32("tg", tg.Talkgroup).Msg("get talkgroup fail")
|
||||
|
|
|
@ -54,7 +54,7 @@ func (t tgParams) ToID() talkgroups.ID {
|
|||
|
||||
func (tga *talkgroupAPI) get(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
tgs := tgstore.From(ctx)
|
||||
tgs := tgstore.FromCtx(ctx)
|
||||
|
||||
var p tgParams
|
||||
|
||||
|
@ -92,7 +92,7 @@ func (tga *talkgroupAPI) put(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
ctx := r.Context()
|
||||
tgs := tgstore.From(ctx)
|
||||
tgs := tgstore.FromCtx(ctx)
|
||||
|
||||
input := database.UpdateTalkgroupParams{}
|
||||
|
||||
|
@ -143,7 +143,7 @@ func (tga *talkgroupAPI) putTalkgroups(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
ctx := r.Context()
|
||||
tgs := tgstore.From(ctx)
|
||||
tgs := tgstore.FromCtx(ctx)
|
||||
|
||||
var input []database.UpsertTalkgroupParams
|
||||
|
||||
|
|
|
@ -47,12 +47,12 @@ func (s *DatabaseSink) Call(ctx context.Context, call *calls.Call) error {
|
|||
return s.db.InTx(ctx, func(tx database.Store) error {
|
||||
_, err := s.tgs.LearnTG(ctx, call)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add call: learn tg: %w", err)
|
||||
return fmt.Errorf("learn tg: %w", err)
|
||||
}
|
||||
|
||||
err = tx.AddCall(ctx, params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add call: retry: %w", err)
|
||||
return fmt.Errorf("learn tg retry: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -67,7 +67,7 @@ var rrRE = regexp.MustCompile(`DEC\s+HEX\s+Mode\s+Alpha Tag\s+Description\s+Tag`
|
|||
func (rr *radioReferenceImporter) importTalkgroups(ctx context.Context, sys int, r io.Reader) ([]talkgroups.Talkgroup, error) {
|
||||
sc := bufio.NewScanner(r)
|
||||
tgs := make([]talkgroups.Talkgroup, 0, 8)
|
||||
sysn, has := tgstore.From(ctx).SystemName(ctx, sys)
|
||||
sysn, has := tgstore.FromCtx(ctx).SystemName(ctx, sys)
|
||||
if !has {
|
||||
return nil, tgstore.ErrNoSuchSystem
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package tgstore
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -72,7 +71,7 @@ func CtxWithStore(ctx context.Context, s Store) context.Context {
|
|||
return context.WithValue(ctx, StoreCtxKey, s)
|
||||
}
|
||||
|
||||
func From(ctx context.Context) Store {
|
||||
func FromCtx(ctx context.Context) Store {
|
||||
s, ok := ctx.Value(StoreCtxKey).(Store)
|
||||
if !ok {
|
||||
return NewCache()
|
||||
|
@ -307,10 +306,7 @@ func (t *cache) UpdateTG(ctx context.Context, input database.UpdateTalkgroupPara
|
|||
|
||||
func (t *cache) LearnTG(ctx context.Context, c *calls.Call) (*tgsp.Talkgroup, error) {
|
||||
db := database.FromCtx(ctx)
|
||||
err := db.AddTalkgroupWithLearnedFlag(ctx, int32(c.System), int32(c.Talkgroup))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("addTalkgroupWithLearnedFlag: %w", err)
|
||||
}
|
||||
|
||||
sys, has := t.SystemName(ctx, c.System)
|
||||
if !has {
|
||||
return nil, ErrNoSuchSystem
|
||||
|
|
|
@ -124,17 +124,6 @@ INSERT INTO talkgroup_versions(time, created_by,
|
|||
tg.learned
|
||||
FROM talkgroups tg WHERE tg.system_id = @system_id AND tg.tgid = @tgid;
|
||||
|
||||
-- name: AddTalkgroupWithLearnedFlag :exec
|
||||
INSERT INTO talkgroups (
|
||||
system_id,
|
||||
tgid,
|
||||
learned
|
||||
) VALUES(
|
||||
@system_id,
|
||||
@tgid,
|
||||
TRUE
|
||||
);
|
||||
|
||||
-- name: AddLearnedTalkgroup :one
|
||||
INSERT INTO talkgroups(
|
||||
system_id,
|
||||
|
@ -151,3 +140,47 @@ INSERT INTO talkgroups(
|
|||
sqlc.narg('alpha_tag'),
|
||||
sqlc.narg('tg_group')
|
||||
) RETURNING *;
|
||||
|
||||
-- name: RestoreTalkgroupVersion :one
|
||||
INSERT INTO talkgroups(
|
||||
system_id,
|
||||
tgid,
|
||||
name,
|
||||
alpha_tag,
|
||||
tg_group,
|
||||
frequency,
|
||||
metadata,
|
||||
tags,
|
||||
alert,
|
||||
alert_config,
|
||||
weight,
|
||||
learned,
|
||||
ignored
|
||||
)
|
||||
SELECT
|
||||
system_id,
|
||||
tgid,
|
||||
name,
|
||||
alpha_tag,
|
||||
tg_group,
|
||||
frequency,
|
||||
metadata,
|
||||
tags,
|
||||
alert,
|
||||
alert_config,
|
||||
weight,
|
||||
learned,
|
||||
ignored
|
||||
FROM talkgroup_versions tgv ON CONFLICT (system_id, tgid) DO UPDATE SET
|
||||
name = excluded.name,
|
||||
alpha_tag = excluded.alpha_tag,
|
||||
tg_group = excluded.tg_group,
|
||||
metadata = excluded.metadata,
|
||||
tags = excluded.tags,
|
||||
alert = excluded.alert,
|
||||
alert_config = excluded.alert_config,
|
||||
weight = excluded.weight,
|
||||
learned = excluded.learner,
|
||||
ignored = excluded.ignored
|
||||
WHERE tgv.id = ANY(@version_ids)
|
||||
RETURNING *;
|
||||
|
|
Loading…
Reference in a new issue