Granular locking

This commit is contained in:
Daniel 2024-11-03 14:11:38 -05:00
parent 5d348e7888
commit 20dff10db8
4 changed files with 49 additions and 21 deletions

View file

@ -209,8 +209,6 @@ func (as *alerter) eval(ctx context.Context, now time.Time, testMode bool) ([]Al
} }
func (as *alerter) testNotifyHandler(w http.ResponseWriter, r *http.Request) { func (as *alerter) testNotifyHandler(w http.ResponseWriter, r *http.Request) {
as.RLock()
defer as.RUnlock()
alerts := make([]Alert, 0, len(as.scores)) alerts := make([]Alert, 0, len(as.scores))
ctx := r.Context() ctx := r.Context()
@ -257,9 +255,6 @@ func (as *alerter) notify(ctx context.Context) error {
return nil return nil
} }
as.Lock()
defer as.Unlock()
notifications, err := as.eval(ctx, time.Now(), false) notifications, err := as.eval(ctx, time.Now(), false)
if err != nil { if err != nil {
return err return err

View file

@ -2,20 +2,40 @@ package talkgroups
import ( import (
"encoding/json" "encoding/json"
"sync"
"time" "time"
"dynatron.me/x/stillbox/internal/ruletime" "dynatron.me/x/stillbox/internal/ruletime"
"dynatron.me/x/stillbox/internal/trending" "dynatron.me/x/stillbox/internal/trending"
) )
type AlertConfig map[ID][]AlertRule type AlertConfig struct {
sync.RWMutex
m map[ID][]AlertRule
}
type AlertRule struct { type AlertRule struct {
Times []ruletime.RuleTime `json:"times"` Times []ruletime.RuleTime `json:"times"`
ScoreMultiplier float32 `json:"mult"` ScoreMultiplier float32 `json:"mult"`
} }
func (ac AlertConfig) AddAlertConfig(tg ID, confBytes []byte) error { func NewAlertConfig() AlertConfig {
return AlertConfig{
m: make(map[ID][]AlertRule),
}
}
func (ac *AlertConfig) GetRules(tg ID) []AlertRule {
ac.RLock()
defer ac.RUnlock()
return ac.m[tg]
}
func (ac *AlertConfig) UnmarshalTGRules(tg ID, confBytes []byte) error {
ac.Lock()
defer ac.Unlock()
if len(confBytes) == 0 { if len(confBytes) == 0 {
return nil return nil
} }
@ -26,12 +46,14 @@ func (ac AlertConfig) AddAlertConfig(tg ID, confBytes []byte) error {
return err return err
} }
ac[tg] = rules ac.m[tg] = rules
return nil return nil
} }
func (ac AlertConfig) ApplyAlertRules(score trending.Score[ID], t time.Time, coversOpts ...ruletime.CoversOption) float64 { func (ac *AlertConfig) ApplyAlertRules(score trending.Score[ID], t time.Time, coversOpts ...ruletime.CoversOption) float64 {
s, has := ac[score.ID] ac.RLock()
s, has := ac.m[score.ID]
ac.RUnlock()
if !has { if !has {
return score.Score return score.Score
} }
@ -47,6 +69,13 @@ func (ac AlertConfig) ApplyAlertRules(score trending.Score[ID], t time.Time, cov
return final return final
} }
func (ac *AlertConfig) Invalidate() {
ac.Lock()
defer ac.Unlock()
clear(ac.m)
}
func (ar *AlertRule) MatchTime(t time.Time, coversOpts ...ruletime.CoversOption) bool { func (ar *AlertRule) MatchTime(t time.Time, coversOpts ...ruletime.CoversOption) bool {
for _, at := range ar.Times { for _, at := range ar.Times {
if at.Covers(t, coversOpts...) { if at.Covers(t, coversOpts...) {

View file

@ -15,7 +15,7 @@ import (
) )
func TestAlertConfig(t *testing.T) { func TestAlertConfig(t *testing.T) {
ac := make(talkgroups.AlertConfig) ac := talkgroups.NewAlertConfig()
parseTests := []struct { parseTests := []struct {
name string name string
tg talkgroups.ID tg talkgroups.ID
@ -57,12 +57,12 @@ func TestAlertConfig(t *testing.T) {
for _, tc := range parseTests { for _, tc := range parseTests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
err := ac.AddAlertConfig(tc.tg, []byte(tc.conf)) err := ac.UnmarshalTGRules(tc.tg, []byte(tc.conf))
if tc.expectErr != nil { if tc.expectErr != nil {
require.Error(t, err) require.Error(t, err)
assert.Contains(t, err.Error(), tc.expectErr.Error()) assert.Contains(t, err.Error(), tc.expectErr.Error())
} else { } else {
assert.Equal(t, tc.compare, ac[tc.tg]) assert.Equal(t, tc.compare, ac.GetRules(tc.tg))
} }
}) })
} }

View file

@ -59,7 +59,7 @@ func (t *cache) Invalidate() {
defer t.Unlock() defer t.Unlock()
clear(t.tgs) clear(t.tgs)
clear(t.systems) clear(t.systems)
clear(t.AlertConfig) t.AlertConfig.Invalidate()
} }
type cache struct { type cache struct {
@ -74,7 +74,7 @@ func NewCache() Store {
tgc := &cache{ tgc := &cache{
tgs: make(tgMap), tgs: make(tgMap),
systems: make(map[int32]string), systems: make(map[int32]string),
AlertConfig: make(AlertConfig), AlertConfig: NewAlertConfig(),
} }
return tgc return tgc
@ -108,11 +108,14 @@ func (t *cache) Hint(ctx context.Context, tgs []ID) error {
} }
func (t *cache) add(rec Talkgroup) error { func (t *cache) add(rec Talkgroup) error {
t.Lock()
defer t.Unlock()
tg := TG(rec.System.ID, int(rec.Talkgroup.Tgid)) tg := TG(rec.System.ID, int(rec.Talkgroup.Tgid))
t.tgs[tg] = rec t.tgs[tg] = rec
t.systems[int32(rec.System.ID)] = rec.System.Name t.systems[int32(rec.System.ID)] = rec.System.Name
return t.AlertConfig.AddAlertConfig(tg, rec.Talkgroup.AlertConfig) return t.AlertConfig.UnmarshalTGRules(tg, rec.Talkgroup.AlertConfig)
} }
func rowToTalkgroup(r database.GetTalkgroupWithLearnedByPackedIDsRow) Talkgroup { func rowToTalkgroup(r database.GetTalkgroupWithLearnedByPackedIDsRow) Talkgroup {
@ -129,9 +132,6 @@ func (t *cache) Load(ctx context.Context, tgs []int64) error {
return err return err
} }
t.Lock()
defer t.Unlock()
for _, rec := range tgRecords { for _, rec := range tgRecords {
err := t.add(rowToTalkgroup(rec)) err := t.add(rowToTalkgroup(rec))
@ -168,8 +168,6 @@ func (t *cache) TG(ctx context.Context, tg ID) (Talkgroup, error) {
return Talkgroup{}, ErrNoTG return Talkgroup{}, ErrNoTG
} }
t.Lock()
defer t.Unlock()
err = t.add(rowToTalkgroup(recs[0])) err = t.add(rowToTalkgroup(recs[0]))
if err != nil { if err != nil {
log.Error().Err(err).Msg("TG() cache add") log.Error().Err(err).Msg("TG() cache add")
@ -180,7 +178,9 @@ func (t *cache) TG(ctx context.Context, tg ID) (Talkgroup, error) {
} }
func (t *cache) SystemName(ctx context.Context, id int) (name string, has bool) { func (t *cache) SystemName(ctx context.Context, id int) (name string, has bool) {
t.RLock()
n, has := t.systems[int32(id)] n, has := t.systems[int32(id)]
t.RUnlock()
if !has { if !has {
sys, err := database.FromCtx(ctx).GetSystemName(ctx, id) sys, err := database.FromCtx(ctx).GetSystemName(ctx, id)
@ -188,6 +188,10 @@ func (t *cache) SystemName(ctx context.Context, id int) (name string, has bool)
return "", false return "", false
} }
t.Lock()
t.systems[int32(id)] = sys
t.Unlock()
return sys, true return sys, true
} }