Refactor alertrules so they unmarshal by pgx #32

Merged
amigan merged 1 commit from reid into trunk 2024-11-12 08:42:36 -05:00
10 changed files with 95 additions and 79 deletions

View file

@ -0,0 +1,36 @@
package rules
import (
"time"
"dynatron.me/x/stillbox/internal/ruletime"
)
type AlertRules []AlertRule
func (ars *AlertRules) Apply(t time.Time, coversOpts ...ruletime.CoversOption) float64 {
final := 1.0
for _, ar := range *ars {
if ar.MatchTime(t, coversOpts...) {
final *= float64(ar.ScoreMultiplier)
}
}
return final
}
type AlertRule struct {
Times []ruletime.RuleTime `json:"times"`
ScoreMultiplier float32 `json:"mult"`
}
func (ar *AlertRule) MatchTime(t time.Time, coversOpts ...ruletime.CoversOption) bool {
for _, at := range ar.Times {
if at.Covers(t, coversOpts...) {
return true
}
}
return false
}

View file

@ -1,6 +1,7 @@
package talkgroups_test package rules_test
import ( import (
"encoding/json"
"errors" "errors"
"math" "math"
"testing" "testing"
@ -8,6 +9,7 @@ import (
"dynatron.me/x/stillbox/internal/ruletime" "dynatron.me/x/stillbox/internal/ruletime"
"dynatron.me/x/stillbox/internal/trending" "dynatron.me/x/stillbox/internal/trending"
"dynatron.me/x/stillbox/pkg/alerting/rules"
"dynatron.me/x/stillbox/pkg/talkgroups" "dynatron.me/x/stillbox/pkg/talkgroups"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -20,14 +22,14 @@ func TestAlertConfig(t *testing.T) {
name string name string
tg talkgroups.ID tg talkgroups.ID
conf string conf string
compare []talkgroups.AlertRule compare rules.AlertRules
expectErr error expectErr error
}{ }{
{ {
name: "base case", name: "base case",
tg: talkgroups.TG(197, 3), tg: talkgroups.TG(197, 3),
conf: `[{"times":["7:00+2h","01:00+1h","16:00+1h","19:00+4h"],"mult":0.2},{"times":["11:00+1h","15:00+30m","16:03+20m"],"mult":2.0}]`, conf: `[{"times":["7:00+2h","01:00+1h","16:00+1h","19:00+4h"],"mult":0.2},{"times":["11:00+1h","15:00+30m","16:03+20m"],"mult":2.0}]`,
compare: []talkgroups.AlertRule{ compare: rules.AlertRules{
{ {
Times: []ruletime.RuleTime{ Times: []ruletime.RuleTime{
ruletime.Must(ruletime.New("7:00+2h")), ruletime.Must(ruletime.New("7:00+2h")),
@ -57,11 +59,13 @@ func TestAlertConfig(t *testing.T) {
for _, tc := range parseTests { for _, tc := range parseTests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
err := ac.UnmarshalTGRules(tc.tg, []byte(tc.conf)) var ar rules.AlertRules
err := json.Unmarshal([]byte(tc.conf), &ar)
if tc.expectErr != nil { if tc.expectErr != nil {
require.Error(t, err) require.Error(t, err)
assert.Contains(t, err.Error(), tc.expectErr.Error()) assert.Contains(t, err.Error(), tc.expectErr.Error())
} else { } else {
ac.Add(tc.tg, ar)
assert.Equal(t, tc.compare, ac.GetRules(tc.tg)) assert.Equal(t, tc.compare, ac.GetRules(tc.tg))
} }
}) })

View file

@ -7,6 +7,7 @@ package database
import ( import (
"time" "time"
"dynatron.me/x/stillbox/pkg/alerting/rules"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgtype"
) )
@ -82,18 +83,18 @@ type System struct {
} }
type Talkgroup struct { type Talkgroup struct {
ID int64 `json:"id"` ID int64 `json:"id"`
SystemID int32 `json:"system_id"` SystemID int32 `json:"system_id"`
Tgid int32 `json:"tgid"` Tgid int32 `json:"tgid"`
Name *string `json:"name"` Name *string `json:"name"`
AlphaTag *string `json:"alpha_tag"` AlphaTag *string `json:"alpha_tag"`
TgGroup *string `json:"tg_group"` TgGroup *string `json:"tg_group"`
Frequency *int32 `json:"frequency"` Frequency *int32 `json:"frequency"`
Metadata []byte `json:"metadata"` Metadata []byte `json:"metadata"`
Tags []string `json:"tags"` Tags []string `json:"tags"`
Alert bool `json:"alert"` Alert bool `json:"alert"`
AlertConfig []byte `json:"alert_config"` AlertConfig rules.AlertRules `json:"alert_config"`
Weight float32 `json:"weight"` Weight float32 `json:"weight"`
} }
type TalkgroupsLearned struct { type TalkgroupsLearned struct {

View file

@ -7,6 +7,8 @@ package database
import ( import (
"context" "context"
"dynatron.me/x/stillbox/pkg/alerting/rules"
) )
const bulkSetTalkgroupTags = `-- name: BulkSetTalkgroupTags :exec const bulkSetTalkgroupTags = `-- name: BulkSetTalkgroupTags :exec
@ -490,16 +492,16 @@ RETURNING id, system_id, tgid, name, alpha_tag, tg_group, frequency, metadata, t
` `
type UpdateTalkgroupParams struct { type UpdateTalkgroupParams struct {
Name *string `json:"name"` Name *string `json:"name"`
AlphaTag *string `json:"alpha_tag"` AlphaTag *string `json:"alpha_tag"`
TgGroup *string `json:"tg_group"` TgGroup *string `json:"tg_group"`
Frequency *int32 `json:"frequency"` Frequency *int32 `json:"frequency"`
Metadata []byte `json:"metadata"` Metadata []byte `json:"metadata"`
Tags []string `json:"tags"` Tags []string `json:"tags"`
Alert *bool `json:"alert"` Alert *bool `json:"alert"`
AlertConfig []byte `json:"alert_config"` AlertConfig rules.AlertRules `json:"alert_config"`
Weight *float32 `json:"weight"` Weight *float32 `json:"weight"`
ID int64 `json:"id"` ID int64 `json:"id"`
} }
func (q *Queries) UpdateTalkgroup(ctx context.Context, arg UpdateTalkgroupParams) (Talkgroup, error) { func (q *Queries) UpdateTalkgroup(ctx context.Context, arg UpdateTalkgroupParams) (Talkgroup, error) {

View file

@ -6,7 +6,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
const getTalkgroupWithLearnedByPackedIDsTest = `-- name: GetTalkgroupWithLearnedByPackedIDs :many const getTalkgroupsWithLearnedByPackedIDsTest = `-- name: GetTalkgroupsWithLearnedByPackedIDs :many
SELECT SELECT
tg.id, tg.system_id, tg.tgid, tg.name, tg.alpha_tag, tg.tg_group, tg.frequency, tg.metadata, tg.tags, tg.alert, tg.alert_config, tg.weight, sys.id, sys.name, tg.id, tg.system_id, tg.tgid, tg.name, tg.alpha_tag, tg.tg_group, tg.frequency, tg.metadata, tg.tags, tg.alert, tg.alert_config, tg.weight, sys.id, sys.name,
FALSE learned FALSE learned
@ -59,7 +59,7 @@ TRUE, NULL::JSONB, 1.0, sys.id, sys.name,
TRUE learned TRUE learned
FROM talkgroups_learned tgl FROM talkgroups_learned tgl
JOIN systems sys ON tgl.system_id = sys.id JOIN systems sys ON tgl.system_id = sys.id
WHERE tg.system_id = $1 AND ignored IS NOT TRUE WHERE tgl.system_id = $1 AND ignored IS NOT TRUE
` `
const getTalkgroupsWithLearnedTest = `-- name: GetTalkgroupsWithLearned :many const getTalkgroupsWithLearnedTest = `-- name: GetTalkgroupsWithLearned :many
@ -81,7 +81,7 @@ WHERE ignored IS NOT TRUE
` `
func TestQueryColumnsMatch(t *testing.T) { func TestQueryColumnsMatch(t *testing.T) {
require.Equal(t, getTalkgroupsWithLearnedByPackedIDsTest, getTalkgroupWithLearnedByPackedIDs) require.Equal(t, getTalkgroupsWithLearnedByPackedIDsTest, getTalkgroupsWithLearnedByPackedIDs)
require.Equal(t, getTalkgroupWithLearnedTest, getTalkgroupWithLearned) require.Equal(t, getTalkgroupWithLearnedTest, getTalkgroupWithLearned)
require.Equal(t, getTalkgroupsWithLearnedBySystemTest, getTalkgroupsWithLearnedBySystem) require.Equal(t, getTalkgroupsWithLearnedBySystemTest, getTalkgroupsWithLearnedBySystem)
require.Equal(t, getTalkgroupsWithLearnedTest, getTalkgroupsWithLearned) require.Equal(t, getTalkgroupsWithLearnedTest, getTalkgroupsWithLearned)

View file

@ -43,7 +43,6 @@ type errResponse struct {
func (e *errResponse) Render(w http.ResponseWriter, r *http.Request) error { func (e *errResponse) Render(w http.ResponseWriter, r *http.Request) error {
switch e.Code { switch e.Code {
case http.StatusNotFound: case http.StatusNotFound:
case http.StatusBadRequest:
default: default:
log.Error().Str("path", r.URL.Path).Err(e.Err).Int("code", e.Code).Str("msg", e.Error).Msg("request failed") log.Error().Str("path", r.URL.Path).Err(e.Err).Int("code", e.Code).Str("msg", e.Error).Msg("request failed")
} }

View file

@ -37,7 +37,7 @@ type Server struct {
notifier notify.Notifier notifier notify.Notifier
hup chan os.Signal hup chan os.Signal
tgs talkgroups.Store tgs talkgroups.Store
rest rest.API rest rest.API
} }
func New(ctx context.Context, cfg *config.Config) (*Server, error) { func New(ctx context.Context, cfg *config.Config) (*Server, error) {
@ -73,7 +73,7 @@ func New(ctx context.Context, cfg *config.Config) (*Server, error) {
alerter: alerting.New(cfg.Alerting, tgCache, alerting.WithNotifier(notifier)), alerter: alerting.New(cfg.Alerting, tgCache, alerting.WithNotifier(notifier)),
notifier: notifier, notifier: notifier,
tgs: tgCache, tgs: tgCache,
rest: api, rest: api,
} }
srv.sinks.Register("database", sinks.NewDatabaseSink(srv.db), true) srv.sinks.Register("database", sinks.NewDatabaseSink(srv.db), true)

View file

@ -1,54 +1,39 @@
package talkgroups package talkgroups
import ( import (
"encoding/json"
"sync" "sync"
"time" "time"
"dynatron.me/x/stillbox/internal/ruletime" "dynatron.me/x/stillbox/internal/ruletime"
"dynatron.me/x/stillbox/pkg/alerting/rules"
) )
type AlertConfig struct { type AlertConfig struct {
sync.RWMutex sync.RWMutex
m map[ID][]AlertRule m map[ID]rules.AlertRules
}
type AlertRule struct {
Times []ruletime.RuleTime `json:"times"`
ScoreMultiplier float32 `json:"mult"`
} }
func NewAlertConfig() AlertConfig { func NewAlertConfig() AlertConfig {
return AlertConfig{ return AlertConfig{
m: make(map[ID][]AlertRule), m: make(map[ID]rules.AlertRules),
} }
} }
func (ac *AlertConfig) GetRules(tg ID) []AlertRule { func (ac *AlertConfig) Add(tg ID, r rules.AlertRules) error {
ac.Lock()
defer ac.Unlock()
ac.m[tg] = r
return nil
}
func (ac *AlertConfig) GetRules(tg ID) rules.AlertRules {
ac.RLock() ac.RLock()
defer ac.RUnlock() defer ac.RUnlock()
return ac.m[tg] return ac.m[tg]
} }
func (ac *AlertConfig) UnmarshalTGRules(tg ID, confBytes []byte) error {
ac.Lock()
defer ac.Unlock()
if len(confBytes) == 0 {
return nil
}
var rules []AlertRule
err := json.Unmarshal(confBytes, &rules)
if err != nil {
return err
}
ac.m[tg] = rules
return nil
}
func (ac *AlertConfig) ApplyAlertRules(id ID, t time.Time, coversOpts ...ruletime.CoversOption) float64 { func (ac *AlertConfig) ApplyAlertRules(id ID, t time.Time, coversOpts ...ruletime.CoversOption) float64 {
ac.RLock() ac.RLock()
s, has := ac.m[id] s, has := ac.m[id]
@ -57,15 +42,7 @@ func (ac *AlertConfig) ApplyAlertRules(id ID, t time.Time, coversOpts ...ruletim
return 1.0 return 1.0
} }
final := 1.0 return s.Apply(t, coversOpts...)
for _, ar := range s {
if ar.MatchTime(t, coversOpts...) {
final *= float64(ar.ScoreMultiplier)
}
}
return final
} }
func (ac *AlertConfig) Invalidate() { func (ac *AlertConfig) Invalidate() {
@ -74,13 +51,3 @@ func (ac *AlertConfig) Invalidate() {
clear(ac.m) clear(ac.m)
} }
func (ar *AlertRule) MatchTime(t time.Time, coversOpts ...ruletime.CoversOption) bool {
for _, at := range ar.Times {
if at.Covers(t, coversOpts...) {
return true
}
}
return false
}

View file

@ -139,7 +139,9 @@ func (t *cache) add(rec *Talkgroup) error {
t.tgs[tg] = rec t.tgs[tg] = rec
t.systems[int32(rec.System.ID)] = rec.System.Name t.systems[int32(rec.System.ID)] = rec.System.Name
return t.AlertConfig.UnmarshalTGRules(tg, rec.Talkgroup.AlertConfig) t.AlertConfig.Add(tg, rec.AlertConfig)
return nil
} }
type row interface { type row interface {

View file

@ -27,3 +27,8 @@ sql:
go_type: "time.Time" go_type: "time.Time"
- db_type: "pg_catalog.text" - db_type: "pg_catalog.text"
go_type: "string" go_type: "string"
- column: "talkgroups.alert_config"
go_type:
import: "dynatron.me/x/stillbox/pkg/alerting/rules"
type: "AlertRules"
nullable: true