diff --git a/pkg/alerting/alerting.go b/pkg/alerting/alerting.go index 5d5d09d..59020af 100644 --- a/pkg/alerting/alerting.go +++ b/pkg/alerting/alerting.go @@ -209,8 +209,6 @@ func (as *alerter) eval(ctx context.Context, now time.Time, testMode bool) ([]Al } func (as *alerter) testNotifyHandler(w http.ResponseWriter, r *http.Request) { - as.RLock() - defer as.RUnlock() alerts := make([]Alert, 0, len(as.scores)) ctx := r.Context() @@ -257,9 +255,6 @@ func (as *alerter) notify(ctx context.Context) error { return nil } - as.Lock() - defer as.Unlock() - notifications, err := as.eval(ctx, time.Now(), false) if err != nil { return err diff --git a/pkg/talkgroups/alertconfig.go b/pkg/talkgroups/alertconfig.go index 65837a0..12d8305 100644 --- a/pkg/talkgroups/alertconfig.go +++ b/pkg/talkgroups/alertconfig.go @@ -2,20 +2,40 @@ package talkgroups import ( "encoding/json" + "sync" "time" "dynatron.me/x/stillbox/internal/ruletime" "dynatron.me/x/stillbox/internal/trending" ) -type AlertConfig map[ID][]AlertRule +type AlertConfig struct { + sync.RWMutex + m map[ID][]AlertRule +} type AlertRule struct { Times []ruletime.RuleTime `json:"times"` ScoreMultiplier float32 `json:"mult"` } -func (ac AlertConfig) AddAlertConfig(tg ID, confBytes []byte) error { +func NewAlertConfig() AlertConfig { + return AlertConfig{ + m: make(map[ID][]AlertRule), + } +} + +func (ac *AlertConfig) GetRules(tg ID) []AlertRule { + ac.RLock() + defer ac.RUnlock() + + return ac.m[tg] +} + +func (ac *AlertConfig) UnmarshalTGRules(tg ID, confBytes []byte) error { + ac.Lock() + defer ac.Unlock() + if len(confBytes) == 0 { return nil } @@ -26,12 +46,14 @@ func (ac AlertConfig) AddAlertConfig(tg ID, confBytes []byte) error { return err } - ac[tg] = rules + ac.m[tg] = rules return nil } -func (ac AlertConfig) ApplyAlertRules(score trending.Score[ID], t time.Time, coversOpts ...ruletime.CoversOption) float64 { - s, has := ac[score.ID] +func (ac *AlertConfig) ApplyAlertRules(score trending.Score[ID], t time.Time, coversOpts ...ruletime.CoversOption) float64 { + ac.RLock() + s, has := ac.m[score.ID] + ac.RUnlock() if !has { return score.Score } @@ -47,6 +69,13 @@ func (ac AlertConfig) ApplyAlertRules(score trending.Score[ID], t time.Time, cov return final } +func (ac *AlertConfig) Invalidate() { + ac.Lock() + defer ac.Unlock() + + clear(ac.m) +} + func (ar *AlertRule) MatchTime(t time.Time, coversOpts ...ruletime.CoversOption) bool { for _, at := range ar.Times { if at.Covers(t, coversOpts...) { diff --git a/pkg/talkgroups/alertconfig_test.go b/pkg/talkgroups/alertconfig_test.go index 2d26765..e966831 100644 --- a/pkg/talkgroups/alertconfig_test.go +++ b/pkg/talkgroups/alertconfig_test.go @@ -15,7 +15,7 @@ import ( ) func TestAlertConfig(t *testing.T) { - ac := make(talkgroups.AlertConfig) + ac := talkgroups.NewAlertConfig() parseTests := []struct { name string tg talkgroups.ID @@ -57,12 +57,12 @@ func TestAlertConfig(t *testing.T) { for _, tc := range parseTests { t.Run(tc.name, func(t *testing.T) { - err := ac.AddAlertConfig(tc.tg, []byte(tc.conf)) + err := ac.UnmarshalTGRules(tc.tg, []byte(tc.conf)) if tc.expectErr != nil { require.Error(t, err) assert.Contains(t, err.Error(), tc.expectErr.Error()) } else { - assert.Equal(t, tc.compare, ac[tc.tg]) + assert.Equal(t, tc.compare, ac.GetRules(tc.tg)) } }) } diff --git a/pkg/talkgroups/cache.go b/pkg/talkgroups/cache.go index bc60417..f1c29fc 100644 --- a/pkg/talkgroups/cache.go +++ b/pkg/talkgroups/cache.go @@ -59,7 +59,7 @@ func (t *cache) Invalidate() { defer t.Unlock() clear(t.tgs) clear(t.systems) - clear(t.AlertConfig) + t.AlertConfig.Invalidate() } type cache struct { @@ -74,7 +74,7 @@ func NewCache() Store { tgc := &cache{ tgs: make(tgMap), systems: make(map[int32]string), - AlertConfig: make(AlertConfig), + AlertConfig: NewAlertConfig(), } return tgc @@ -108,11 +108,14 @@ func (t *cache) Hint(ctx context.Context, tgs []ID) error { } func (t *cache) add(rec Talkgroup) error { + t.Lock() + defer t.Unlock() + tg := TG(rec.System.ID, int(rec.Talkgroup.Tgid)) t.tgs[tg] = rec t.systems[int32(rec.System.ID)] = rec.System.Name - return t.AlertConfig.AddAlertConfig(tg, rec.Talkgroup.AlertConfig) + return t.AlertConfig.UnmarshalTGRules(tg, rec.Talkgroup.AlertConfig) } func rowToTalkgroup(r database.GetTalkgroupWithLearnedByPackedIDsRow) Talkgroup { @@ -129,9 +132,6 @@ func (t *cache) Load(ctx context.Context, tgs []int64) error { return err } - t.Lock() - defer t.Unlock() - for _, rec := range tgRecords { err := t.add(rowToTalkgroup(rec)) @@ -168,8 +168,6 @@ func (t *cache) TG(ctx context.Context, tg ID) (Talkgroup, error) { return Talkgroup{}, ErrNoTG } - t.Lock() - defer t.Unlock() err = t.add(rowToTalkgroup(recs[0])) if err != nil { log.Error().Err(err).Msg("TG() cache add") @@ -180,7 +178,9 @@ func (t *cache) TG(ctx context.Context, tg ID) (Talkgroup, error) { } func (t *cache) SystemName(ctx context.Context, id int) (name string, has bool) { + t.RLock() n, has := t.systems[int32(id)] + t.RUnlock() if !has { sys, err := database.FromCtx(ctx).GetSystemName(ctx, id) @@ -188,6 +188,10 @@ func (t *cache) SystemName(ctx context.Context, id int) (name string, has bool) return "", false } + t.Lock() + t.systems[int32(id)] = sys + t.Unlock() + return sys, true }