From aeeb8cb776a5a06c661d6b3abb29a83e500e9c4a Mon Sep 17 00:00:00 2001 From: Daniel Ponte Date: Mon, 15 Jul 2024 22:31:43 -0400 Subject: [PATCH] broken Tx impl --- pkg/gordio/database/database.go | 7 ++--- pkg/gordio/server/auth.go | 2 +- pkg/gordio/server/routes.go | 46 ++++++++++++++++++++++++++++++++- 3 files changed, 50 insertions(+), 5 deletions(-) diff --git a/pkg/gordio/database/database.go b/pkg/gordio/database/database.go index 5e59ef4..bb85984 100644 --- a/pkg/gordio/database/database.go +++ b/pkg/gordio/database/database.go @@ -10,6 +10,7 @@ import ( "github.com/golang-migrate/migrate/v4" _ "github.com/golang-migrate/migrate/v4/database/pgx/v5" "github.com/golang-migrate/migrate/v4/source/iofs" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" ) @@ -45,8 +46,8 @@ type DBCtxKey string const DBCTXKeyValue DBCtxKey = "dbctx" -func FromCtx(ctx context.Context) Conn { - c, ok := ctx.Value(DBCTXKeyValue).(Conn) +func Tx(ctx context.Context) pgx.Tx { + c, ok := ctx.Value(DBCTXKeyValue).(pgx.Tx) if !ok { panic("no DB in context") } @@ -54,6 +55,6 @@ func FromCtx(ctx context.Context) Conn { return c } -func CtxWithDB(ctx context.Context, conn Conn) context.Context { +func CtxWithTx(ctx context.Context, conn pgx.Tx) context.Context { return context.WithValue(ctx, DBCTXKeyValue, conn) } diff --git a/pkg/gordio/server/auth.go b/pkg/gordio/server/auth.go index adacfc9..f5076a1 100644 --- a/pkg/gordio/server/auth.go +++ b/pkg/gordio/server/auth.go @@ -26,7 +26,7 @@ var ( ) func (s *Server) Login(ctx context.Context, username, password string) (token string, err error) { - q := database.New(database.FromCtx(ctx)) + q := database.New(database.Tx(ctx)) users, err := q.GetUsers(ctx) if err != nil { log.Error().Err(err).Msg("getUsers failed") diff --git a/pkg/gordio/server/routes.go b/pkg/gordio/server/routes.go index d29e1ed..ded55f5 100644 --- a/pkg/gordio/server/routes.go +++ b/pkg/gordio/server/routes.go @@ -1,6 +1,8 @@ package server import ( + "context" + "fmt" "net/http" "time" @@ -10,11 +12,53 @@ import ( "github.com/go-chi/httprate" "github.com/go-chi/jwtauth/v5" "github.com/go-chi/render" + "github.com/rs/zerolog/log" ) +func (s *Server) dbTx(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tx, err := s.db.Begin(r.Context()) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + log.Error().Err(err).Msg("tx open failed") + } + + r = r.WithContext(database.CtxWithTx(r.Context(), tx)) + + defer func(ctx context.Context) { + if rec := recover(); rec != nil { + var err error + switch r := rec.(type) { + case error: + err = r + default: + err = fmt.Errorf("%v", r) + } + w.WriteHeader(http.StatusInternalServerError) + log.Error().Err(err).Msg("tx rollback due to panic") + tx.Rollback(ctx) + } + }(r.Context()) + + err = next.ServeHTTP(w, r) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + log.Error().Err(err).Msg("tx rollback due to error") + tx.Rollback(r.Context()) + return + } + + err = tx.Commit(r.Context()) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + log.Error().Err(err).Msg("tx commit failed") + } + }) +} + func (s *Server) setupRoutes() { r := s.r - r.Use(middleware.WithValue(database.DBCTXKeyValue, s.db)) + r.Use(s.dbTx) r.Group(func(r chi.Router) { r.Use(jwtauth.Verifier(s.jwt))