parent
aeeb8cb776
commit
e30896ee4c
3 changed files with 5 additions and 50 deletions
|
@ -10,7 +10,6 @@ import (
|
|||
"github.com/golang-migrate/migrate/v4"
|
||||
_ "github.com/golang-migrate/migrate/v4/database/pgx/v5"
|
||||
"github.com/golang-migrate/migrate/v4/source/iofs"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
|
@ -46,8 +45,8 @@ type DBCtxKey string
|
|||
|
||||
const DBCTXKeyValue DBCtxKey = "dbctx"
|
||||
|
||||
func Tx(ctx context.Context) pgx.Tx {
|
||||
c, ok := ctx.Value(DBCTXKeyValue).(pgx.Tx)
|
||||
func FromCtx(ctx context.Context) Conn {
|
||||
c, ok := ctx.Value(DBCTXKeyValue).(Conn)
|
||||
if !ok {
|
||||
panic("no DB in context")
|
||||
}
|
||||
|
@ -55,6 +54,6 @@ func Tx(ctx context.Context) pgx.Tx {
|
|||
return c
|
||||
}
|
||||
|
||||
func CtxWithTx(ctx context.Context, conn pgx.Tx) context.Context {
|
||||
func CtxWithDB(ctx context.Context, conn Conn) context.Context {
|
||||
return context.WithValue(ctx, DBCTXKeyValue, conn)
|
||||
}
|
||||
|
|
|
@ -26,7 +26,7 @@ var (
|
|||
)
|
||||
|
||||
func (s *Server) Login(ctx context.Context, username, password string) (token string, err error) {
|
||||
q := database.New(database.Tx(ctx))
|
||||
q := database.New(database.FromCtx(ctx))
|
||||
users, err := q.GetUsers(ctx)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("getUsers failed")
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
|
@ -12,53 +10,11 @@ import (
|
|||
"github.com/go-chi/httprate"
|
||||
"github.com/go-chi/jwtauth/v5"
|
||||
"github.com/go-chi/render"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func (s *Server) dbTx(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tx, err := s.db.Begin(r.Context())
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
log.Error().Err(err).Msg("tx open failed")
|
||||
}
|
||||
|
||||
r = r.WithContext(database.CtxWithTx(r.Context(), tx))
|
||||
|
||||
defer func(ctx context.Context) {
|
||||
if rec := recover(); rec != nil {
|
||||
var err error
|
||||
switch r := rec.(type) {
|
||||
case error:
|
||||
err = r
|
||||
default:
|
||||
err = fmt.Errorf("%v", r)
|
||||
}
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
log.Error().Err(err).Msg("tx rollback due to panic")
|
||||
tx.Rollback(ctx)
|
||||
}
|
||||
}(r.Context())
|
||||
|
||||
err = next.ServeHTTP(w, r)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
log.Error().Err(err).Msg("tx rollback due to error")
|
||||
tx.Rollback(r.Context())
|
||||
return
|
||||
}
|
||||
|
||||
err = tx.Commit(r.Context())
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
log.Error().Err(err).Msg("tx commit failed")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) setupRoutes() {
|
||||
r := s.r
|
||||
r.Use(s.dbTx)
|
||||
r.Use(middleware.WithValue(database.DBCTXKeyValue, s.db))
|
||||
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(jwtauth.Verifier(s.jwt))
|
||||
|
|
Loading…
Reference in a new issue