From 170970e92d0c0124792fc68cd75e0770855d38df Mon Sep 17 00:00:00 2001 From: Daniel Ponte Date: Mon, 4 Nov 2024 11:15:24 -0500 Subject: [PATCH] wip --- pkg/api/api.go | 122 ++++++++++++++++++++++++++++++++++++ pkg/server/routes.go | 1 + pkg/server/server.go | 4 ++ pkg/talkgroups/cache.go | 53 +++++++++++++--- pkg/talkgroups/talkgroup.go | 11 ++++ 5 files changed, 182 insertions(+), 9 deletions(-) create mode 100644 pkg/api/api.go diff --git a/pkg/api/api.go b/pkg/api/api.go new file mode 100644 index 0000000..eb7eb7e --- /dev/null +++ b/pkg/api/api.go @@ -0,0 +1,122 @@ +package api + +import ( + "encoding/json" + "errors" + "net/http" + + "dynatron.me/x/stillbox/pkg/talkgroups" + + "github.com/go-chi/chi/v5" + "github.com/go-viper/mapstructure/v2" + "github.com/jackc/pgx/v5" + "github.com/rs/zerolog/log" +) + +type API interface { + Subrouter() http.Handler +} + +type api struct { + tgs talkgroups.Store +} + +func New(tgs talkgroups.Store) API { + s := &api{ + tgs: tgs, + } + + return s +} + +func (a *api) Subrouter() http.Handler { + r := chi.NewMux() + + r.Get("/talkgroup/{system:\\d+}/{id:\\d+}", a.talkgroup) + r.Get("/talkgroup/{system:\\d+}/", a.talkgroup) + r.Get("/talkgroup/", a.talkgroup) + return r +} + +var statusMapping = map[error]int{ + talkgroups.ErrNotFound: http.StatusNotFound, + pgx.ErrNoRows: http.StatusNotFound, +} + +func httpCode(err error) int { + c, ok := statusMapping[err] + if ok { + return c + } + + for e, c := range statusMapping { // check if err wraps an error we know about + if errors.Is(err, e) { + return c + } + } + + return http.StatusInternalServerError +} + +func (a *api) writeJSON(w http.ResponseWriter, r *http.Request, data interface{}, err error) { + if err != nil { + log.Error().Str("path", r.URL.Path).Err(err).Msg("request failed") + http.Error(w, err.Error(), httpCode(err)) + return + } + + enc := json.NewEncoder(w) + err = enc.Encode(data) + if err != nil { + log.Error().Str("path", r.URL.Path).Err(err).Msg("response marshal failed") + http.Error(w, err.Error(), httpCode(err)) + return + } +} + +func decodeParams(d interface{}, r *http.Request) error { + params := chi.RouteContext(r.Context()).URLParams + m := make(map[string]string, len(params.Keys)) + + for i, k := range params.Keys { + m[k] = params.Values[i] + } + + dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + Metadata: nil, + Result: d, + TagName: "param", + WeaklyTypedInput: true, + }) + if err != nil { + return err + } + + return dec.Decode(m) +} + +func (a *api) badReq(w http.ResponseWriter, err error) { + http.Error(w, err.Error(), http.StatusBadRequest) +} + +func (a *api) talkgroup(w http.ResponseWriter, r *http.Request) { + p := struct { + System *int `param:"system"` + ID *int `param:"id"` + }{} + + err := decodeParams(&p, r) + if err != nil { + a.badReq(w, err) + return + } + + var res interface{} + switch { + case p.System != nil && p.ID != nil: + res, err = a.tgs.TG(r.Context(), talkgroups.TG(*p.System, *p.ID)) + case p.System != nil: + default: + } + a.writeJSON(w, r, res, err) +} diff --git a/pkg/server/routes.go b/pkg/server/routes.go index 1d6b87b..bcc7c3b 100644 --- a/pkg/server/routes.go +++ b/pkg/server/routes.go @@ -36,6 +36,7 @@ func (s *Server) setupRoutes() { s.nex.PrivateRoutes(r) s.auth.PrivateRoutes(r) s.alerter.PrivateRoutes(r) + r.Mount("/api", s.api.Subrouter()) }) r.Group(func(r chi.Router) { diff --git a/pkg/server/server.go b/pkg/server/server.go index fb0c23b..ed11996 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -7,6 +7,7 @@ import ( "time" "dynatron.me/x/stillbox/pkg/alerting" + "dynatron.me/x/stillbox/pkg/api" "dynatron.me/x/stillbox/pkg/auth" "dynatron.me/x/stillbox/pkg/config" "dynatron.me/x/stillbox/pkg/database" @@ -36,6 +37,7 @@ type Server struct { notifier notify.Notifier hup chan os.Signal tgs talkgroups.Store + api api.API } func New(ctx context.Context, cfg *config.Config) (*Server, error) { @@ -59,6 +61,7 @@ func New(ctx context.Context, cfg *config.Config) (*Server, error) { } tgCache := talkgroups.NewCache() + api := api.New(tgCache) srv := &Server{ auth: authenticator, @@ -70,6 +73,7 @@ func New(ctx context.Context, cfg *config.Config) (*Server, error) { alerter: alerting.New(cfg.Alerting, tgCache, alerting.WithNotifier(notifier)), notifier: notifier, tgs: tgCache, + api: api, } srv.sinks.Register("database", sinks.NewDatabaseSink(srv.db), true) diff --git a/pkg/talkgroups/cache.go b/pkg/talkgroups/cache.go index 22e9500..dd7a6cb 100644 --- a/pkg/talkgroups/cache.go +++ b/pkg/talkgroups/cache.go @@ -15,11 +15,14 @@ import ( "github.com/rs/zerolog/log" ) -type tgMap map[ID]Talkgroup +type tgMap map[ID]*Talkgroup type Store interface { // TG retrieves a Talkgroup from the Store. - TG(ctx context.Context, tg ID) (Talkgroup, error) + TG(ctx context.Context, tg ID) (*Talkgroup, error) + + // TGs retrieves many talkgroups from the Store. + TGs(ctx context.Context, tgs IDs) ([]*Talkgroup, error) // SystemName retrieves a system name from the store. It returns the record and whether one was found. SystemName(ctx context.Context, id int) (string, bool) @@ -117,7 +120,7 @@ func (t *cache) Hint(ctx context.Context, tgs []ID) error { return nil } -func (t *cache) add(rec Talkgroup) error { +func (t *cache) add(rec *Talkgroup) error { t.Lock() defer t.Unlock() @@ -128,14 +131,46 @@ func (t *cache) add(rec Talkgroup) error { return t.AlertConfig.UnmarshalTGRules(tg, rec.Talkgroup.AlertConfig) } -func rowToTalkgroup(r database.GetTalkgroupWithLearnedByPackedIDsRow) Talkgroup { - return Talkgroup{ +func rowToTalkgroup(r database.GetTalkgroupWithLearnedByPackedIDsRow) *Talkgroup { + return &Talkgroup{ Talkgroup: r.Talkgroup, System: r.System, Learned: r.Learned, } } +func (t *cache) TGs(ctx context.Context, tgs IDs) ([]*Talkgroup, error) { + r := make([]*Talkgroup, 0, len(tgs)) + toGet := make(IDs, 0, len(tgs)) + t.RLock() + for _, id := range tgs { + rec, has := t.tgs[id] + if has { + r = append(r, rec) + } else { + toGet = append(toGet, id) + } + } + t.RUnlock() + + tgRecords, err := database.FromCtx(ctx).GetTalkgroupWithLearnedByPackedIDs(ctx, toGet.Packed()) + if err != nil { + return nil, err + } + + for _, rec := range tgRecords { + tg := rowToTalkgroup(rec) + err := t.add(tg) + if err != nil { + return nil, err + } + + r = append(r, tg) + } + + return r, nil +} + func (t *cache) Load(ctx context.Context, tgs []int64) error { tgRecords, err := database.FromCtx(ctx).GetTalkgroupWithLearnedByPackedIDs(ctx, tgs) if err != nil { @@ -168,7 +203,7 @@ func (t *cache) Weight(ctx context.Context, id ID, tm time.Time) float64 { return float64(m) } -func (t *cache) TG(ctx context.Context, tg ID) (Talkgroup, error) { +func (t *cache) TG(ctx context.Context, tg ID) (*Talkgroup, error) { t.RLock() rec, has := t.tgs[tg] t.RUnlock() @@ -181,14 +216,14 @@ func (t *cache) TG(ctx context.Context, tg ID) (Talkgroup, error) { switch err { case nil: case pgx.ErrNoRows: - return Talkgroup{}, ErrNotFound + return nil, ErrNotFound default: log.Error().Err(err).Msg("TG() cache add db get") - return Talkgroup{}, errors.Join(ErrNotFound, err) + return nil, errors.Join(ErrNotFound, err) } if len(recs) < 1 { - return Talkgroup{}, ErrNotFound + return nil, ErrNotFound } err = t.add(rowToTalkgroup(recs[0])) diff --git a/pkg/talkgroups/talkgroup.go b/pkg/talkgroups/talkgroup.go index bf68dfa..a988c24 100644 --- a/pkg/talkgroups/talkgroup.go +++ b/pkg/talkgroups/talkgroup.go @@ -17,6 +17,17 @@ type ID struct { Talkgroup uint32 } +type IDs []ID + +func (ids *IDs) Packed() []int64 { + r := make([]int64, len(*ids)) + for i := range *ids { + r[i] = (*ids)[i].Pack() + } + + return r +} + func TG[T int | uint | int64 | uint64 | int32 | uint32](sys, tgid T) ID { return ID{ System: uint32(sys),