parent
aeeb8cb776
commit
e30896ee4c
3 changed files with 5 additions and 50 deletions
|
@ -10,7 +10,6 @@ import (
|
||||||
"github.com/golang-migrate/migrate/v4"
|
"github.com/golang-migrate/migrate/v4"
|
||||||
_ "github.com/golang-migrate/migrate/v4/database/pgx/v5"
|
_ "github.com/golang-migrate/migrate/v4/database/pgx/v5"
|
||||||
"github.com/golang-migrate/migrate/v4/source/iofs"
|
"github.com/golang-migrate/migrate/v4/source/iofs"
|
||||||
"github.com/jackc/pgx/v5"
|
|
||||||
"github.com/jackc/pgx/v5/pgxpool"
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -46,8 +45,8 @@ type DBCtxKey string
|
||||||
|
|
||||||
const DBCTXKeyValue DBCtxKey = "dbctx"
|
const DBCTXKeyValue DBCtxKey = "dbctx"
|
||||||
|
|
||||||
func Tx(ctx context.Context) pgx.Tx {
|
func FromCtx(ctx context.Context) Conn {
|
||||||
c, ok := ctx.Value(DBCTXKeyValue).(pgx.Tx)
|
c, ok := ctx.Value(DBCTXKeyValue).(Conn)
|
||||||
if !ok {
|
if !ok {
|
||||||
panic("no DB in context")
|
panic("no DB in context")
|
||||||
}
|
}
|
||||||
|
@ -55,6 +54,6 @@ func Tx(ctx context.Context) pgx.Tx {
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func CtxWithTx(ctx context.Context, conn pgx.Tx) context.Context {
|
func CtxWithDB(ctx context.Context, conn Conn) context.Context {
|
||||||
return context.WithValue(ctx, DBCTXKeyValue, conn)
|
return context.WithValue(ctx, DBCTXKeyValue, conn)
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,7 +26,7 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) Login(ctx context.Context, username, password string) (token string, err error) {
|
func (s *Server) Login(ctx context.Context, username, password string) (token string, err error) {
|
||||||
q := database.New(database.Tx(ctx))
|
q := database.New(database.FromCtx(ctx))
|
||||||
users, err := q.GetUsers(ctx)
|
users, err := q.GetUsers(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("getUsers failed")
|
log.Error().Err(err).Msg("getUsers failed")
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -12,53 +10,11 @@ import (
|
||||||
"github.com/go-chi/httprate"
|
"github.com/go-chi/httprate"
|
||||||
"github.com/go-chi/jwtauth/v5"
|
"github.com/go-chi/jwtauth/v5"
|
||||||
"github.com/go-chi/render"
|
"github.com/go-chi/render"
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) dbTx(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
tx, err := s.db.Begin(r.Context())
|
|
||||||
if err != nil {
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
log.Error().Err(err).Msg("tx open failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
r = r.WithContext(database.CtxWithTx(r.Context(), tx))
|
|
||||||
|
|
||||||
defer func(ctx context.Context) {
|
|
||||||
if rec := recover(); rec != nil {
|
|
||||||
var err error
|
|
||||||
switch r := rec.(type) {
|
|
||||||
case error:
|
|
||||||
err = r
|
|
||||||
default:
|
|
||||||
err = fmt.Errorf("%v", r)
|
|
||||||
}
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
log.Error().Err(err).Msg("tx rollback due to panic")
|
|
||||||
tx.Rollback(ctx)
|
|
||||||
}
|
|
||||||
}(r.Context())
|
|
||||||
|
|
||||||
err = next.ServeHTTP(w, r)
|
|
||||||
if err != nil {
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
log.Error().Err(err).Msg("tx rollback due to error")
|
|
||||||
tx.Rollback(r.Context())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = tx.Commit(r.Context())
|
|
||||||
if err != nil {
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
log.Error().Err(err).Msg("tx commit failed")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) setupRoutes() {
|
func (s *Server) setupRoutes() {
|
||||||
r := s.r
|
r := s.r
|
||||||
r.Use(s.dbTx)
|
r.Use(middleware.WithValue(database.DBCTXKeyValue, s.db))
|
||||||
|
|
||||||
r.Group(func(r chi.Router) {
|
r.Group(func(r chi.Router) {
|
||||||
r.Use(jwtauth.Verifier(s.jwt))
|
r.Use(jwtauth.Verifier(s.jwt))
|
||||||
|
|
Loading…
Reference in a new issue