broken Tx impl

This commit is contained in:
Daniel 2024-07-15 22:31:43 -04:00
parent c75dd9ec43
commit aeeb8cb776
3 changed files with 50 additions and 5 deletions

View file

@ -10,6 +10,7 @@ import (
"github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4"
_ "github.com/golang-migrate/migrate/v4/database/pgx/v5" _ "github.com/golang-migrate/migrate/v4/database/pgx/v5"
"github.com/golang-migrate/migrate/v4/source/iofs" "github.com/golang-migrate/migrate/v4/source/iofs"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
) )
@ -45,8 +46,8 @@ type DBCtxKey string
const DBCTXKeyValue DBCtxKey = "dbctx" const DBCTXKeyValue DBCtxKey = "dbctx"
func FromCtx(ctx context.Context) Conn { func Tx(ctx context.Context) pgx.Tx {
c, ok := ctx.Value(DBCTXKeyValue).(Conn) c, ok := ctx.Value(DBCTXKeyValue).(pgx.Tx)
if !ok { if !ok {
panic("no DB in context") panic("no DB in context")
} }
@ -54,6 +55,6 @@ func FromCtx(ctx context.Context) Conn {
return c return c
} }
func CtxWithDB(ctx context.Context, conn Conn) context.Context { func CtxWithTx(ctx context.Context, conn pgx.Tx) context.Context {
return context.WithValue(ctx, DBCTXKeyValue, conn) return context.WithValue(ctx, DBCTXKeyValue, conn)
} }

View file

@ -26,7 +26,7 @@ var (
) )
func (s *Server) Login(ctx context.Context, username, password string) (token string, err error) { func (s *Server) Login(ctx context.Context, username, password string) (token string, err error) {
q := database.New(database.FromCtx(ctx)) q := database.New(database.Tx(ctx))
users, err := q.GetUsers(ctx) users, err := q.GetUsers(ctx)
if err != nil { if err != nil {
log.Error().Err(err).Msg("getUsers failed") log.Error().Err(err).Msg("getUsers failed")

View file

@ -1,6 +1,8 @@
package server package server
import ( import (
"context"
"fmt"
"net/http" "net/http"
"time" "time"
@ -10,11 +12,53 @@ import (
"github.com/go-chi/httprate" "github.com/go-chi/httprate"
"github.com/go-chi/jwtauth/v5" "github.com/go-chi/jwtauth/v5"
"github.com/go-chi/render" "github.com/go-chi/render"
"github.com/rs/zerolog/log"
) )
func (s *Server) dbTx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tx, err := s.db.Begin(r.Context())
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
log.Error().Err(err).Msg("tx open failed")
}
r = r.WithContext(database.CtxWithTx(r.Context(), tx))
defer func(ctx context.Context) {
if rec := recover(); rec != nil {
var err error
switch r := rec.(type) {
case error:
err = r
default:
err = fmt.Errorf("%v", r)
}
w.WriteHeader(http.StatusInternalServerError)
log.Error().Err(err).Msg("tx rollback due to panic")
tx.Rollback(ctx)
}
}(r.Context())
err = next.ServeHTTP(w, r)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
log.Error().Err(err).Msg("tx rollback due to error")
tx.Rollback(r.Context())
return
}
err = tx.Commit(r.Context())
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
log.Error().Err(err).Msg("tx commit failed")
}
})
}
func (s *Server) setupRoutes() { func (s *Server) setupRoutes() {
r := s.r r := s.r
r.Use(middleware.WithValue(database.DBCTXKeyValue, s.db)) r.Use(s.dbTx)
r.Group(func(r chi.Router) { r.Group(func(r chi.Router) {
r.Use(jwtauth.Verifier(s.jwt)) r.Use(jwtauth.Verifier(s.jwt))