diff --git a/snippetbox/cmd/web/context.go b/snippetbox/cmd/web/context.go new file mode 100644 index 0000000..01d1722 --- /dev/null +++ b/snippetbox/cmd/web/context.go @@ -0,0 +1,5 @@ +package main + +type contextKey string + +const isAuthenticatedContextKey = contextKey("isAuthenticated") diff --git a/snippetbox/cmd/web/helpers.go b/snippetbox/cmd/web/helpers.go index 5b3dcb2..415badd 100644 --- a/snippetbox/cmd/web/helpers.go +++ b/snippetbox/cmd/web/helpers.go @@ -85,5 +85,10 @@ func (app *application) decodePostForm(r *http.Request, dst any) error { } func (app *application) isAuthenticated(r *http.Request) bool { - return app.sessionManager.Exists(r.Context(), "authenticatedUserID") + isAuthenticated, ok := r.Context().Value(isAuthenticatedContextKey).(bool) + if !ok { + return false + } + + return isAuthenticated } diff --git a/snippetbox/cmd/web/middleware.go b/snippetbox/cmd/web/middleware.go index ddc8104..65f572d 100644 --- a/snippetbox/cmd/web/middleware.go +++ b/snippetbox/cmd/web/middleware.go @@ -1,12 +1,36 @@ package main import ( + "context" "fmt" "net/http" "github.com/justinas/nosurf" ) +func (app *application) authenticate(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + id := app.sessionManager.GetInt(r.Context(), "authenticatedUserID") + if id == 0 { + next.ServeHTTP(w, r) + return + } + + exists, err := app.users.Exists(id) + if err != nil { + app.serverError(w, r, err) + return + } + + if exists { + ctx := context.WithValue(r.Context(), isAuthenticatedContextKey, true) + r = r.WithContext(ctx) + } + + next.ServeHTTP(w, r) + }) +} + // logRequest ... func (app *application) logRequest(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/snippetbox/cmd/web/routes.go b/snippetbox/cmd/web/routes.go index c06f233..1553ad5 100644 --- a/snippetbox/cmd/web/routes.go +++ b/snippetbox/cmd/web/routes.go @@ -23,7 +23,7 @@ func (app *application) routes() http.Handler { fileServer := http.FileServer(http.Dir("./ui/static")) router.Handler(http.MethodGet, "/static/*filepath", http.StripPrefix("/static", fileServer)) - dynamic := alice.New(app.sessionManager.LoadAndSave, noSurf) + dynamic := alice.New(app.sessionManager.LoadAndSave, noSurf, app.authenticate) router.Handler(http.MethodGet, "/", dynamic.ThenFunc(app.home)) router.Handler(http.MethodGet, "/snippet/view/:id", dynamic.ThenFunc(app.snippetView)) diff --git a/snippetbox/internal/models/users.go b/snippetbox/internal/models/users.go index 494e5b0..ca20bec 100644 --- a/snippetbox/internal/models/users.go +++ b/snippetbox/internal/models/users.go @@ -76,5 +76,10 @@ func (m *UserModel) Authenticate(email, password string) (int, error) { // Exists func (m *UserModel) Exists(id int) (bool, error) { - return false, nil + var exists bool + + stmt := "SELECT EXISTS(SELECT true FROM users WHERE id = ?)" + + err := m.DB.QueryRow(stmt, id).Scan(&exists) + return exists, err }