From 01bfefac13544b5ebf727f54226523228642d02d Mon Sep 17 00:00:00 2001 From: tamsin johnson Date: Wed, 14 Feb 2024 14:46:27 -0800 Subject: [PATCH] lets-go:14.3 refactor test server setup --- snippetbox/cmd/web/handlers_test.go | 29 ++++----------- snippetbox/cmd/web/testutils_test.go | 53 ++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 23 deletions(-) create mode 100644 snippetbox/cmd/web/testutils_test.go diff --git a/snippetbox/cmd/web/handlers_test.go b/snippetbox/cmd/web/handlers_test.go index 30cfe19..5994471 100644 --- a/snippetbox/cmd/web/handlers_test.go +++ b/snippetbox/cmd/web/handlers_test.go @@ -1,36 +1,19 @@ package main import ( - "bytes" - "io" - "log/slog" "net/http" - "net/http/httptest" - "testing" - "snippetbox.chaosfem.tw/internal/assert" + "testing" ) func TestPing(t *testing.T) { - app := &application{ - logger: slog.New(slog.NewTextHandler(io.Discard, nil)), - } + app := newTestApplication(t) - ts := httptest.NewTLSServer(app.routes()) + ts := newTestServer(t, app.routes()) defer ts.Close() - rs, err := ts.Client().Get(ts.URL + "/ping") - if err != nil { - t.Fatal(err) - } + statusCode, _, body := ts.get(t, "/ping") - assert.Equal(t, rs.StatusCode, http.StatusOK) - - defer rs.Body.Close() - body, err := io.ReadAll(rs.Body) - if err != nil { - t.Fatal(err) - } - - assert.Equal(t, string(bytes.TrimSpace(body)), "OK") + assert.Equal(t, statusCode, http.StatusOK) + assert.Equal(t, body, "OK") } diff --git a/snippetbox/cmd/web/testutils_test.go b/snippetbox/cmd/web/testutils_test.go new file mode 100644 index 0000000..bc97f05 --- /dev/null +++ b/snippetbox/cmd/web/testutils_test.go @@ -0,0 +1,53 @@ +package main + +import ( + "bytes" + "io" + "log/slog" + "net/http" + "net/http/cookiejar" + "net/http/httptest" + "testing" +) + +func newTestApplication(t *testing.T) *application { + return &application{ + logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + } +} + +type testServer struct { + *httptest.Server +} + +func newTestServer(t *testing.T, h http.Handler) *testServer { + ts := httptest.NewTLSServer(h) + + jar, err := cookiejar.New(nil) + if err != nil { + t.Fatal(err) + } + + ts.Client().Jar = jar + + ts.Client().CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + + return &testServer{ts} +} + +func (ts *testServer) get(t *testing.T, urlPath string) (int, http.Header, string) { + rs, err := ts.Client().Get(ts.URL + urlPath) + if err != nil { + t.Fatal(err) + } + + defer rs.Body.Close() + body, err := io.ReadAll(rs.Body) + if err != nil { + t.Fatal(err) + } + + return rs.StatusCode, rs.Header, string(bytes.TrimSpace(body)) +}