diff --git a/snippetbox/cmd/web/middleware_test.go b/snippetbox/cmd/web/middleware_test.go new file mode 100644 index 0000000..ec78d32 --- /dev/null +++ b/snippetbox/cmd/web/middleware_test.go @@ -0,0 +1,43 @@ +package main + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "testing" + + "snippetbox.chaosfem.tw/internal/assert" +) + +func TestSecureHeaders(t *testing.T) { + rr := httptest.NewRecorder() + + r, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatal(err) + } + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("OK")) + }) + + secureHeaders(next).ServeHTTP(rr, r) + + rs := rr.Result() + + assert.Equal(t, rs.Header.Get("Content-Security-Policy"), "default-src 'self'; style-src 'self' fonts.googleapis.com; font-src fonts.gstatic.com") + assert.Equal(t, rs.Header.Get("Referrer-Policy"), "origin-when-cross-origin") + assert.Equal(t, rs.Header.Get("X-Content-Type-Options"), "nosniff") + assert.Equal(t, rs.Header.Get("X-Frame-Options"), "deny") + assert.Equal(t, rs.Header.Get("X-XSS-Protection"), "0") + + defer rs.Body.Close() + + body, err := io.ReadAll(rs.Body) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, string(bytes.TrimSpace(body)), "OK") +}