From e30896ee4c15f83b16405419728e0f2f92db2763 Mon Sep 17 00:00:00 2001 From: Daniel Ponte Date: Mon, 15 Jul 2024 22:31:49 -0400 Subject: [PATCH] Revert "broken Tx impl" This reverts commit aeeb8cb776a5a06c661d6b3abb29a83e500e9c4a. --- pkg/gordio/database/database.go | 7 +++-- pkg/gordio/server/auth.go | 2 +- pkg/gordio/server/routes.go | 46 +-------------------------------- 3 files changed, 5 insertions(+), 50 deletions(-) diff --git a/pkg/gordio/database/database.go b/pkg/gordio/database/database.go index bb85984..5e59ef4 100644 --- a/pkg/gordio/database/database.go +++ b/pkg/gordio/database/database.go @@ -10,7 +10,6 @@ import ( "github.com/golang-migrate/migrate/v4" _ "github.com/golang-migrate/migrate/v4/database/pgx/v5" "github.com/golang-migrate/migrate/v4/source/iofs" - "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" ) @@ -46,8 +45,8 @@ type DBCtxKey string const DBCTXKeyValue DBCtxKey = "dbctx" -func Tx(ctx context.Context) pgx.Tx { - c, ok := ctx.Value(DBCTXKeyValue).(pgx.Tx) +func FromCtx(ctx context.Context) Conn { + c, ok := ctx.Value(DBCTXKeyValue).(Conn) if !ok { panic("no DB in context") } @@ -55,6 +54,6 @@ func Tx(ctx context.Context) pgx.Tx { return c } -func CtxWithTx(ctx context.Context, conn pgx.Tx) context.Context { +func CtxWithDB(ctx context.Context, conn Conn) context.Context { return context.WithValue(ctx, DBCTXKeyValue, conn) } diff --git a/pkg/gordio/server/auth.go b/pkg/gordio/server/auth.go index f5076a1..adacfc9 100644 --- a/pkg/gordio/server/auth.go +++ b/pkg/gordio/server/auth.go @@ -26,7 +26,7 @@ var ( ) func (s *Server) Login(ctx context.Context, username, password string) (token string, err error) { - q := database.New(database.Tx(ctx)) + q := database.New(database.FromCtx(ctx)) users, err := q.GetUsers(ctx) if err != nil { log.Error().Err(err).Msg("getUsers failed") diff --git a/pkg/gordio/server/routes.go b/pkg/gordio/server/routes.go index ded55f5..d29e1ed 100644 --- a/pkg/gordio/server/routes.go +++ b/pkg/gordio/server/routes.go @@ -1,8 +1,6 @@ package server import ( - "context" - "fmt" "net/http" "time" @@ -12,53 +10,11 @@ import ( "github.com/go-chi/httprate" "github.com/go-chi/jwtauth/v5" "github.com/go-chi/render" - "github.com/rs/zerolog/log" ) -func (s *Server) dbTx(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - tx, err := s.db.Begin(r.Context()) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - log.Error().Err(err).Msg("tx open failed") - } - - r = r.WithContext(database.CtxWithTx(r.Context(), tx)) - - defer func(ctx context.Context) { - if rec := recover(); rec != nil { - var err error - switch r := rec.(type) { - case error: - err = r - default: - err = fmt.Errorf("%v", r) - } - w.WriteHeader(http.StatusInternalServerError) - log.Error().Err(err).Msg("tx rollback due to panic") - tx.Rollback(ctx) - } - }(r.Context()) - - err = next.ServeHTTP(w, r) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - log.Error().Err(err).Msg("tx rollback due to error") - tx.Rollback(r.Context()) - return - } - - err = tx.Commit(r.Context()) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - log.Error().Err(err).Msg("tx commit failed") - } - }) -} - func (s *Server) setupRoutes() { r := s.r - r.Use(s.dbTx) + r.Use(middleware.WithValue(database.DBCTXKeyValue, s.db)) r.Group(func(r chi.Router) { r.Use(jwtauth.Verifier(s.jwt))