broken Tx impl
This commit is contained in:
parent
c75dd9ec43
commit
aeeb8cb776
3 changed files with 50 additions and 5 deletions
|
@ -10,6 +10,7 @@ import (
|
||||||
"github.com/golang-migrate/migrate/v4"
|
"github.com/golang-migrate/migrate/v4"
|
||||||
_ "github.com/golang-migrate/migrate/v4/database/pgx/v5"
|
_ "github.com/golang-migrate/migrate/v4/database/pgx/v5"
|
||||||
"github.com/golang-migrate/migrate/v4/source/iofs"
|
"github.com/golang-migrate/migrate/v4/source/iofs"
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgx/v5/pgxpool"
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -45,8 +46,8 @@ type DBCtxKey string
|
||||||
|
|
||||||
const DBCTXKeyValue DBCtxKey = "dbctx"
|
const DBCTXKeyValue DBCtxKey = "dbctx"
|
||||||
|
|
||||||
func FromCtx(ctx context.Context) Conn {
|
func Tx(ctx context.Context) pgx.Tx {
|
||||||
c, ok := ctx.Value(DBCTXKeyValue).(Conn)
|
c, ok := ctx.Value(DBCTXKeyValue).(pgx.Tx)
|
||||||
if !ok {
|
if !ok {
|
||||||
panic("no DB in context")
|
panic("no DB in context")
|
||||||
}
|
}
|
||||||
|
@ -54,6 +55,6 @@ func FromCtx(ctx context.Context) Conn {
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func CtxWithDB(ctx context.Context, conn Conn) context.Context {
|
func CtxWithTx(ctx context.Context, conn pgx.Tx) context.Context {
|
||||||
return context.WithValue(ctx, DBCTXKeyValue, conn)
|
return context.WithValue(ctx, DBCTXKeyValue, conn)
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,7 +26,7 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) Login(ctx context.Context, username, password string) (token string, err error) {
|
func (s *Server) Login(ctx context.Context, username, password string) (token string, err error) {
|
||||||
q := database.New(database.FromCtx(ctx))
|
q := database.New(database.Tx(ctx))
|
||||||
users, err := q.GetUsers(ctx)
|
users, err := q.GetUsers(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("getUsers failed")
|
log.Error().Err(err).Msg("getUsers failed")
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -10,11 +12,53 @@ import (
|
||||||
"github.com/go-chi/httprate"
|
"github.com/go-chi/httprate"
|
||||||
"github.com/go-chi/jwtauth/v5"
|
"github.com/go-chi/jwtauth/v5"
|
||||||
"github.com/go-chi/render"
|
"github.com/go-chi/render"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (s *Server) dbTx(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
tx, err := s.db.Begin(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
log.Error().Err(err).Msg("tx open failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
r = r.WithContext(database.CtxWithTx(r.Context(), tx))
|
||||||
|
|
||||||
|
defer func(ctx context.Context) {
|
||||||
|
if rec := recover(); rec != nil {
|
||||||
|
var err error
|
||||||
|
switch r := rec.(type) {
|
||||||
|
case error:
|
||||||
|
err = r
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("%v", r)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
log.Error().Err(err).Msg("tx rollback due to panic")
|
||||||
|
tx.Rollback(ctx)
|
||||||
|
}
|
||||||
|
}(r.Context())
|
||||||
|
|
||||||
|
err = next.ServeHTTP(w, r)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
log.Error().Err(err).Msg("tx rollback due to error")
|
||||||
|
tx.Rollback(r.Context())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Commit(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
log.Error().Err(err).Msg("tx commit failed")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) setupRoutes() {
|
func (s *Server) setupRoutes() {
|
||||||
r := s.r
|
r := s.r
|
||||||
r.Use(middleware.WithValue(database.DBCTXKeyValue, s.db))
|
r.Use(s.dbTx)
|
||||||
|
|
||||||
r.Group(func(r chi.Router) {
|
r.Group(func(r chi.Router) {
|
||||||
r.Use(jwtauth.Verifier(s.jwt))
|
r.Use(jwtauth.Verifier(s.jwt))
|
||||||
|
|
Loading…
Reference in a new issue