From 9e4144545a830833ff33f8f65ccba3fcaa1c7c6e Mon Sep 17 00:00:00 2001 From: Daniel Ponte Date: Tue, 12 Nov 2024 08:40:56 -0500 Subject: [PATCH] Refactor alertrules so they unmarshal by pgx --- pkg/alerting/rules/alertrules.go | 36 +++++++++++ .../rules/alertrules_test.go} | 12 ++-- pkg/database/models.go | 25 ++++---- pkg/database/talkgroups.sql.go | 22 +++---- pkg/database/talkgroups.sql_test.go | 6 +- pkg/rest/api.go | 1 - pkg/server/server.go | 4 +- pkg/talkgroups/alertconfig.go | 59 ++++--------------- pkg/talkgroups/cache.go | 4 +- sql/sqlc.yaml | 5 ++ 10 files changed, 95 insertions(+), 79 deletions(-) create mode 100644 pkg/alerting/rules/alertrules.go rename pkg/{talkgroups/alertconfig_test.go => alerting/rules/alertrules_test.go} (93%) diff --git a/pkg/alerting/rules/alertrules.go b/pkg/alerting/rules/alertrules.go new file mode 100644 index 0000000..22cd0f1 --- /dev/null +++ b/pkg/alerting/rules/alertrules.go @@ -0,0 +1,36 @@ +package rules + +import ( + "time" + + "dynatron.me/x/stillbox/internal/ruletime" +) + +type AlertRules []AlertRule + +func (ars *AlertRules) Apply(t time.Time, coversOpts ...ruletime.CoversOption) float64 { + final := 1.0 + + for _, ar := range *ars { + if ar.MatchTime(t, coversOpts...) { + final *= float64(ar.ScoreMultiplier) + } + } + + return final +} + +type AlertRule struct { + Times []ruletime.RuleTime `json:"times"` + ScoreMultiplier float32 `json:"mult"` +} + +func (ar *AlertRule) MatchTime(t time.Time, coversOpts ...ruletime.CoversOption) bool { + for _, at := range ar.Times { + if at.Covers(t, coversOpts...) { + return true + } + } + + return false +} diff --git a/pkg/talkgroups/alertconfig_test.go b/pkg/alerting/rules/alertrules_test.go similarity index 93% rename from pkg/talkgroups/alertconfig_test.go rename to pkg/alerting/rules/alertrules_test.go index 4819f13..22c1564 100644 --- a/pkg/talkgroups/alertconfig_test.go +++ b/pkg/alerting/rules/alertrules_test.go @@ -1,6 +1,7 @@ -package talkgroups_test +package rules_test import ( + "encoding/json" "errors" "math" "testing" @@ -8,6 +9,7 @@ import ( "dynatron.me/x/stillbox/internal/ruletime" "dynatron.me/x/stillbox/internal/trending" + "dynatron.me/x/stillbox/pkg/alerting/rules" "dynatron.me/x/stillbox/pkg/talkgroups" "github.com/stretchr/testify/assert" @@ -20,14 +22,14 @@ func TestAlertConfig(t *testing.T) { name string tg talkgroups.ID conf string - compare []talkgroups.AlertRule + compare rules.AlertRules expectErr error }{ { name: "base case", tg: talkgroups.TG(197, 3), conf: `[{"times":["7:00+2h","01:00+1h","16:00+1h","19:00+4h"],"mult":0.2},{"times":["11:00+1h","15:00+30m","16:03+20m"],"mult":2.0}]`, - compare: []talkgroups.AlertRule{ + compare: rules.AlertRules{ { Times: []ruletime.RuleTime{ ruletime.Must(ruletime.New("7:00+2h")), @@ -57,11 +59,13 @@ func TestAlertConfig(t *testing.T) { for _, tc := range parseTests { t.Run(tc.name, func(t *testing.T) { - err := ac.UnmarshalTGRules(tc.tg, []byte(tc.conf)) + var ar rules.AlertRules + err := json.Unmarshal([]byte(tc.conf), &ar) if tc.expectErr != nil { require.Error(t, err) assert.Contains(t, err.Error(), tc.expectErr.Error()) } else { + ac.Add(tc.tg, ar) assert.Equal(t, tc.compare, ac.GetRules(tc.tg)) } }) diff --git a/pkg/database/models.go b/pkg/database/models.go index 7a2b67d..9e3eb2a 100644 --- a/pkg/database/models.go +++ b/pkg/database/models.go @@ -7,6 +7,7 @@ package database import ( "time" + "dynatron.me/x/stillbox/pkg/alerting/rules" "github.com/google/uuid" "github.com/jackc/pgx/v5/pgtype" ) @@ -82,18 +83,18 @@ type System struct { } type Talkgroup struct { - ID int64 `json:"id"` - SystemID int32 `json:"system_id"` - Tgid int32 `json:"tgid"` - Name *string `json:"name"` - AlphaTag *string `json:"alpha_tag"` - TgGroup *string `json:"tg_group"` - Frequency *int32 `json:"frequency"` - Metadata []byte `json:"metadata"` - Tags []string `json:"tags"` - Alert bool `json:"alert"` - AlertConfig []byte `json:"alert_config"` - Weight float32 `json:"weight"` + ID int64 `json:"id"` + SystemID int32 `json:"system_id"` + Tgid int32 `json:"tgid"` + Name *string `json:"name"` + AlphaTag *string `json:"alpha_tag"` + TgGroup *string `json:"tg_group"` + Frequency *int32 `json:"frequency"` + Metadata []byte `json:"metadata"` + Tags []string `json:"tags"` + Alert bool `json:"alert"` + AlertConfig rules.AlertRules `json:"alert_config"` + Weight float32 `json:"weight"` } type TalkgroupsLearned struct { diff --git a/pkg/database/talkgroups.sql.go b/pkg/database/talkgroups.sql.go index 81d224f..d511a2b 100644 --- a/pkg/database/talkgroups.sql.go +++ b/pkg/database/talkgroups.sql.go @@ -7,6 +7,8 @@ package database import ( "context" + + "dynatron.me/x/stillbox/pkg/alerting/rules" ) const bulkSetTalkgroupTags = `-- name: BulkSetTalkgroupTags :exec @@ -490,16 +492,16 @@ RETURNING id, system_id, tgid, name, alpha_tag, tg_group, frequency, metadata, t ` type UpdateTalkgroupParams struct { - Name *string `json:"name"` - AlphaTag *string `json:"alpha_tag"` - TgGroup *string `json:"tg_group"` - Frequency *int32 `json:"frequency"` - Metadata []byte `json:"metadata"` - Tags []string `json:"tags"` - Alert *bool `json:"alert"` - AlertConfig []byte `json:"alert_config"` - Weight *float32 `json:"weight"` - ID int64 `json:"id"` + Name *string `json:"name"` + AlphaTag *string `json:"alpha_tag"` + TgGroup *string `json:"tg_group"` + Frequency *int32 `json:"frequency"` + Metadata []byte `json:"metadata"` + Tags []string `json:"tags"` + Alert *bool `json:"alert"` + AlertConfig rules.AlertRules `json:"alert_config"` + Weight *float32 `json:"weight"` + ID int64 `json:"id"` } func (q *Queries) UpdateTalkgroup(ctx context.Context, arg UpdateTalkgroupParams) (Talkgroup, error) { diff --git a/pkg/database/talkgroups.sql_test.go b/pkg/database/talkgroups.sql_test.go index 248b3dc..10f9888 100644 --- a/pkg/database/talkgroups.sql_test.go +++ b/pkg/database/talkgroups.sql_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/require" ) -const getTalkgroupWithLearnedByPackedIDsTest = `-- name: GetTalkgroupWithLearnedByPackedIDs :many +const getTalkgroupsWithLearnedByPackedIDsTest = `-- name: GetTalkgroupsWithLearnedByPackedIDs :many SELECT tg.id, tg.system_id, tg.tgid, tg.name, tg.alpha_tag, tg.tg_group, tg.frequency, tg.metadata, tg.tags, tg.alert, tg.alert_config, tg.weight, sys.id, sys.name, FALSE learned @@ -59,7 +59,7 @@ TRUE, NULL::JSONB, 1.0, sys.id, sys.name, TRUE learned FROM talkgroups_learned tgl JOIN systems sys ON tgl.system_id = sys.id -WHERE tg.system_id = $1 AND ignored IS NOT TRUE +WHERE tgl.system_id = $1 AND ignored IS NOT TRUE ` const getTalkgroupsWithLearnedTest = `-- name: GetTalkgroupsWithLearned :many @@ -81,7 +81,7 @@ WHERE ignored IS NOT TRUE ` func TestQueryColumnsMatch(t *testing.T) { - require.Equal(t, getTalkgroupsWithLearnedByPackedIDsTest, getTalkgroupWithLearnedByPackedIDs) + require.Equal(t, getTalkgroupsWithLearnedByPackedIDsTest, getTalkgroupsWithLearnedByPackedIDs) require.Equal(t, getTalkgroupWithLearnedTest, getTalkgroupWithLearned) require.Equal(t, getTalkgroupsWithLearnedBySystemTest, getTalkgroupsWithLearnedBySystem) require.Equal(t, getTalkgroupsWithLearnedTest, getTalkgroupsWithLearned) diff --git a/pkg/rest/api.go b/pkg/rest/api.go index 0c08daa..86eab07 100644 --- a/pkg/rest/api.go +++ b/pkg/rest/api.go @@ -43,7 +43,6 @@ type errResponse struct { func (e *errResponse) Render(w http.ResponseWriter, r *http.Request) error { switch e.Code { case http.StatusNotFound: - case http.StatusBadRequest: default: log.Error().Str("path", r.URL.Path).Err(e.Err).Int("code", e.Code).Str("msg", e.Error).Msg("request failed") } diff --git a/pkg/server/server.go b/pkg/server/server.go index ebec903..f527214 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -37,7 +37,7 @@ type Server struct { notifier notify.Notifier hup chan os.Signal tgs talkgroups.Store - rest rest.API + rest rest.API } func New(ctx context.Context, cfg *config.Config) (*Server, error) { @@ -73,7 +73,7 @@ func New(ctx context.Context, cfg *config.Config) (*Server, error) { alerter: alerting.New(cfg.Alerting, tgCache, alerting.WithNotifier(notifier)), notifier: notifier, tgs: tgCache, - rest: api, + rest: api, } srv.sinks.Register("database", sinks.NewDatabaseSink(srv.db), true) diff --git a/pkg/talkgroups/alertconfig.go b/pkg/talkgroups/alertconfig.go index 1a7f7b6..eaa7294 100644 --- a/pkg/talkgroups/alertconfig.go +++ b/pkg/talkgroups/alertconfig.go @@ -1,54 +1,39 @@ package talkgroups import ( - "encoding/json" "sync" "time" "dynatron.me/x/stillbox/internal/ruletime" + "dynatron.me/x/stillbox/pkg/alerting/rules" ) type AlertConfig struct { sync.RWMutex - m map[ID][]AlertRule -} - -type AlertRule struct { - Times []ruletime.RuleTime `json:"times"` - ScoreMultiplier float32 `json:"mult"` + m map[ID]rules.AlertRules } func NewAlertConfig() AlertConfig { return AlertConfig{ - m: make(map[ID][]AlertRule), + m: make(map[ID]rules.AlertRules), } } -func (ac *AlertConfig) GetRules(tg ID) []AlertRule { +func (ac *AlertConfig) Add(tg ID, r rules.AlertRules) error { + ac.Lock() + defer ac.Unlock() + + ac.m[tg] = r + return nil +} + +func (ac *AlertConfig) GetRules(tg ID) rules.AlertRules { ac.RLock() defer ac.RUnlock() return ac.m[tg] } -func (ac *AlertConfig) UnmarshalTGRules(tg ID, confBytes []byte) error { - ac.Lock() - defer ac.Unlock() - - if len(confBytes) == 0 { - return nil - } - - var rules []AlertRule - err := json.Unmarshal(confBytes, &rules) - if err != nil { - return err - } - - ac.m[tg] = rules - return nil -} - func (ac *AlertConfig) ApplyAlertRules(id ID, t time.Time, coversOpts ...ruletime.CoversOption) float64 { ac.RLock() s, has := ac.m[id] @@ -57,15 +42,7 @@ func (ac *AlertConfig) ApplyAlertRules(id ID, t time.Time, coversOpts ...ruletim return 1.0 } - final := 1.0 - - for _, ar := range s { - if ar.MatchTime(t, coversOpts...) { - final *= float64(ar.ScoreMultiplier) - } - } - - return final + return s.Apply(t, coversOpts...) } func (ac *AlertConfig) Invalidate() { @@ -74,13 +51,3 @@ func (ac *AlertConfig) Invalidate() { clear(ac.m) } - -func (ar *AlertRule) MatchTime(t time.Time, coversOpts ...ruletime.CoversOption) bool { - for _, at := range ar.Times { - if at.Covers(t, coversOpts...) { - return true - } - } - - return false -} diff --git a/pkg/talkgroups/cache.go b/pkg/talkgroups/cache.go index 16b8ac6..9e14588 100644 --- a/pkg/talkgroups/cache.go +++ b/pkg/talkgroups/cache.go @@ -139,7 +139,9 @@ func (t *cache) add(rec *Talkgroup) error { t.tgs[tg] = rec t.systems[int32(rec.System.ID)] = rec.System.Name - return t.AlertConfig.UnmarshalTGRules(tg, rec.Talkgroup.AlertConfig) + t.AlertConfig.Add(tg, rec.AlertConfig) + + return nil } type row interface { diff --git a/sql/sqlc.yaml b/sql/sqlc.yaml index c0e573a..6b617e4 100644 --- a/sql/sqlc.yaml +++ b/sql/sqlc.yaml @@ -27,3 +27,8 @@ sql: go_type: "time.Time" - db_type: "pg_catalog.text" go_type: "string" + - column: "talkgroups.alert_config" + go_type: + import: "dynatron.me/x/stillbox/pkg/alerting/rules" + type: "AlertRules" + nullable: true