Revert "broken Tx impl"

This reverts commit aeeb8cb776.
This commit is contained in:
Daniel 2024-07-15 22:31:49 -04:00
parent aeeb8cb776
commit e30896ee4c
3 changed files with 5 additions and 50 deletions

View file

@ -10,7 +10,6 @@ import (
"github.com/golang-migrate/migrate/v4"
_ "github.com/golang-migrate/migrate/v4/database/pgx/v5"
"github.com/golang-migrate/migrate/v4/source/iofs"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
@ -46,8 +45,8 @@ type DBCtxKey string
const DBCTXKeyValue DBCtxKey = "dbctx"
func Tx(ctx context.Context) pgx.Tx {
c, ok := ctx.Value(DBCTXKeyValue).(pgx.Tx)
func FromCtx(ctx context.Context) Conn {
c, ok := ctx.Value(DBCTXKeyValue).(Conn)
if !ok {
panic("no DB in context")
}
@ -55,6 +54,6 @@ func Tx(ctx context.Context) pgx.Tx {
return c
}
func CtxWithTx(ctx context.Context, conn pgx.Tx) context.Context {
func CtxWithDB(ctx context.Context, conn Conn) context.Context {
return context.WithValue(ctx, DBCTXKeyValue, conn)
}

View file

@ -26,7 +26,7 @@ var (
)
func (s *Server) Login(ctx context.Context, username, password string) (token string, err error) {
q := database.New(database.Tx(ctx))
q := database.New(database.FromCtx(ctx))
users, err := q.GetUsers(ctx)
if err != nil {
log.Error().Err(err).Msg("getUsers failed")

View file

@ -1,8 +1,6 @@
package server
import (
"context"
"fmt"
"net/http"
"time"
@ -12,53 +10,11 @@ import (
"github.com/go-chi/httprate"
"github.com/go-chi/jwtauth/v5"
"github.com/go-chi/render"
"github.com/rs/zerolog/log"
)
func (s *Server) dbTx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tx, err := s.db.Begin(r.Context())
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
log.Error().Err(err).Msg("tx open failed")
}
r = r.WithContext(database.CtxWithTx(r.Context(), tx))
defer func(ctx context.Context) {
if rec := recover(); rec != nil {
var err error
switch r := rec.(type) {
case error:
err = r
default:
err = fmt.Errorf("%v", r)
}
w.WriteHeader(http.StatusInternalServerError)
log.Error().Err(err).Msg("tx rollback due to panic")
tx.Rollback(ctx)
}
}(r.Context())
err = next.ServeHTTP(w, r)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
log.Error().Err(err).Msg("tx rollback due to error")
tx.Rollback(r.Context())
return
}
err = tx.Commit(r.Context())
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
log.Error().Err(err).Msg("tx commit failed")
}
})
}
func (s *Server) setupRoutes() {
r := s.r
r.Use(s.dbTx)
r.Use(middleware.WithValue(database.DBCTXKeyValue, s.db))
r.Group(func(r chi.Router) {
r.Use(jwtauth.Verifier(s.jwt))