Skip to content

Commit df7e229

Browse files
feat: Helpers for working with middlewares
1 parent f7afd17 commit df7e229

File tree

7 files changed

+338
-0
lines changed

7 files changed

+338
-0
lines changed

README.md

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,50 @@ func addProduct(w http.ResponseWriter, r *http.Request) error {
3333
return errhandler.SendJSON(w, p)
3434
}
3535
```
36+
37+
### Middleware
38+
39+
errhandler contains helper objects for building middleware
40+
41+
* A `errhandler.Middleware` type, which is simply a function that takes a Wrap function and returns a Wrap function:
42+
43+
```go
44+
mux := http.NewServeMux()
45+
mux.Handle("GET /products/{id}", errhandler.Wrap(midLog(getProduct)))
46+
47+
...
48+
49+
func midLog(n errhandler.Wrap) errhandler.Wrap {
50+
return func(w http.ResponseWriter, r *http.Request) error {
51+
log.Printf("1 %s %s", r.Method, r.URL.Path)
52+
return n(w, r)
53+
}
54+
}
55+
56+
func addProduct(w http.ResponseWriter, r *http.Request) error {
57+
...
58+
}
59+
```
60+
61+
* A `errhandler.Chain` function, which allows `Middleware` functions easy to chain:
62+
63+
```go
64+
chain := errhandler.Chain(midLog1, midLog2)
65+
66+
mux := http.NewServeMux()
67+
mux.Handle("GET /products", errhandler.Wrap(chain(getProducts)))
68+
69+
func midLog1(n errhandler.Wrap) errhandler.Wrap {
70+
return func(w http.ResponseWriter, r *http.Request) error {
71+
log.Printf("1 %s %s", r.Method, r.URL.Path)
72+
return n(w, r)
73+
}
74+
}
75+
76+
func midLog2(n errhandler.Wrap) errhandler.Wrap {
77+
return func(w http.ResponseWriter, r *http.Request) error {
78+
log.Printf("2 %s %s", r.Method, r.URL.Path)
79+
return n(w, r)
80+
}
81+
}
82+
```

examples/basic/README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
Start server
2+
3+
```sh
4+
go run examples/basic/main.go
5+
```
6+
7+
Test server
8+
9+
```sh
10+
curl -s http://localhost:3000/products/a32fb2bd-b402-4bea-93c2-4a0a567b2261 | jq
11+
12+
curl -s http://localhost:3000/products | jq
13+
```

examples/basic/main.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"log"
6+
"net/http"
7+
8+
"github.com/codingconcepts/errhandler"
9+
)
10+
11+
var products = map[string]product{
12+
"a32fb2bd-b402-4bea-93c2-4a0a567b2261": {
13+
ID: "a32fb2bd-b402-4bea-93c2-4a0a567b2261",
14+
Name: "a",
15+
Price: 10.99,
16+
},
17+
"b68ed795-0604-4696-8eb2-5b4b927330a0": {
18+
ID: "b68ed795-0604-4696-8eb2-5b4b927330a0",
19+
Name: "b",
20+
Price: 20.99,
21+
},
22+
}
23+
24+
func main() {
25+
mux := http.NewServeMux()
26+
mux.Handle("GET /products", errhandler.Wrap(getProducts))
27+
mux.Handle("GET /products/{id}", errhandler.Wrap(getProduct))
28+
29+
server := &http.Server{Addr: "localhost:3000", Handler: mux}
30+
log.Fatal(server.ListenAndServe())
31+
}
32+
33+
func getProducts(w http.ResponseWriter, r *http.Request) error {
34+
return errhandler.SendJSON(w, products)
35+
}
36+
37+
func getProduct(w http.ResponseWriter, r *http.Request) error {
38+
id := r.PathValue("id")
39+
40+
p, ok := products[id]
41+
if !ok {
42+
return fmt.Errorf("no product with id: %s", id)
43+
}
44+
45+
return errhandler.SendJSON(w, p)
46+
}
47+
48+
type product struct {
49+
ID string `json:"id"`
50+
Name string `json:"name"`
51+
Price float64 `json:"price"`
52+
}

examples/middleware/README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
Start server
2+
3+
```sh
4+
go run examples/middleware/main.go
5+
```
6+
7+
Test server
8+
9+
```sh
10+
curl -s http://localhost:3000/products/a32fb2bd-b402-4bea-93c2-4a0a567b2261 | jq
11+
12+
curl -s http://localhost:3000/products | jq
13+
```

examples/middleware/main.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"log"
6+
"net/http"
7+
8+
"github.com/codingconcepts/errhandler"
9+
)
10+
11+
var products = map[string]product{
12+
"a32fb2bd-b402-4bea-93c2-4a0a567b2261": {
13+
ID: "a32fb2bd-b402-4bea-93c2-4a0a567b2261",
14+
Name: "a",
15+
Price: 10.99,
16+
},
17+
"b68ed795-0604-4696-8eb2-5b4b927330a0": {
18+
ID: "b68ed795-0604-4696-8eb2-5b4b927330a0",
19+
Name: "b",
20+
Price: 20.99,
21+
},
22+
}
23+
24+
func main() {
25+
chain := errhandler.Chain(midLog1, midLog2)
26+
27+
mux := http.NewServeMux()
28+
mux.Handle("GET /products", errhandler.Wrap(chain(getProducts)))
29+
mux.Handle("GET /products/{id}", errhandler.Wrap(midLog1(getProduct)))
30+
31+
server := &http.Server{Addr: "localhost:3000", Handler: mux}
32+
log.Fatal(server.ListenAndServe())
33+
}
34+
35+
func midLog1(n errhandler.Wrap) errhandler.Wrap {
36+
return func(w http.ResponseWriter, r *http.Request) error {
37+
log.Printf("1 %s %s", r.Method, r.URL.Path)
38+
return n(w, r)
39+
}
40+
}
41+
42+
func midLog2(n errhandler.Wrap) errhandler.Wrap {
43+
return func(w http.ResponseWriter, r *http.Request) error {
44+
log.Printf("2 %s %s", r.Method, r.URL.Path)
45+
return n(w, r)
46+
}
47+
}
48+
49+
func getProducts(w http.ResponseWriter, r *http.Request) error {
50+
return errhandler.SendJSON(w, products)
51+
}
52+
53+
func getProduct(w http.ResponseWriter, r *http.Request) error {
54+
id := r.PathValue("id")
55+
56+
p, ok := products[id]
57+
if !ok {
58+
return fmt.Errorf("no product with id: %s", id)
59+
}
60+
61+
return errhandler.SendJSON(w, p)
62+
}
63+
64+
type product struct {
65+
ID string `json:"id"`
66+
Name string `json:"name"`
67+
Price float64 `json:"price"`
68+
}

middleware.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package errhandler
2+
3+
// Middleware is a helper for adding middleware to wrapped handelers.
4+
type Middleware func(Wrap) Wrap
5+
6+
// Chain multiple middleware functions together into a single middleware.
7+
func Chain(m ...Middleware) Middleware {
8+
return func(n Wrap) Wrap {
9+
for i := len(m) - 1; i >= 0; i-- {
10+
n = m[i](n)
11+
}
12+
13+
return n
14+
}
15+
}

middleware_test.go

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
package errhandler
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
)
9+
10+
type contextKey string
11+
12+
func TestMiddleware(t *testing.T) {
13+
tests := []struct {
14+
name string
15+
handler Wrap
16+
middleware Middleware
17+
expectedStatus int
18+
expectedBody string
19+
}{
20+
{
21+
name: "single middleware",
22+
middleware: func(n Wrap) Wrap {
23+
return func(w http.ResponseWriter, r *http.Request) error {
24+
ctx := context.WithValue(r.Context(), contextKey("value"), 1)
25+
r = r.WithContext(ctx)
26+
27+
return n(w, r)
28+
}
29+
},
30+
handler: Wrap(func(w http.ResponseWriter, r *http.Request) error {
31+
m := map[string]any{
32+
"value": r.Context().Value(contextKey("value")),
33+
}
34+
35+
return SendJSON(w, m)
36+
}),
37+
expectedStatus: http.StatusOK,
38+
expectedBody: "{\"value\":1}\n",
39+
},
40+
}
41+
42+
for _, tt := range tests {
43+
t.Run(tt.name, func(t *testing.T) {
44+
req := httptest.NewRequest(http.MethodGet, "http://example.com/foo", nil)
45+
w := httptest.NewRecorder()
46+
47+
tt.middleware(tt.handler).ServeHTTP(w, req)
48+
49+
res := w.Result()
50+
defer res.Body.Close()
51+
52+
if res.StatusCode != tt.expectedStatus {
53+
t.Errorf("expected status %d, got %d", tt.expectedStatus, res.StatusCode)
54+
}
55+
56+
actualBody := w.Body.String()
57+
if actualBody != tt.expectedBody {
58+
t.Errorf("expected body %q, got %q", tt.expectedBody, actualBody)
59+
}
60+
})
61+
}
62+
}
63+
64+
func TestChain(t *testing.T) {
65+
tests := []struct {
66+
name string
67+
handler Wrap
68+
middlewares []Middleware
69+
expectedStatus int
70+
expectedBody string
71+
}{
72+
{
73+
name: "multiple middlewares",
74+
middlewares: []Middleware{
75+
func(n Wrap) Wrap {
76+
return func(w http.ResponseWriter, r *http.Request) error {
77+
ctx := context.WithValue(r.Context(), contextKey("value"), 1)
78+
r = r.WithContext(ctx)
79+
80+
return n(w, r)
81+
}
82+
},
83+
func(n Wrap) Wrap {
84+
return func(w http.ResponseWriter, r *http.Request) error {
85+
value, ok := r.Context().Value(contextKey("value")).(int)
86+
if !ok {
87+
t.Fatalf("expected integer but didn't get one")
88+
}
89+
90+
ctx := context.WithValue(r.Context(), contextKey("value"), value+2)
91+
r = r.WithContext(ctx)
92+
93+
return n(w, r)
94+
}
95+
},
96+
},
97+
handler: Wrap(func(w http.ResponseWriter, r *http.Request) error {
98+
m := map[string]any{
99+
"value": r.Context().Value(contextKey("value")),
100+
}
101+
102+
return SendJSON(w, m)
103+
}),
104+
expectedStatus: http.StatusOK,
105+
expectedBody: "{\"value\":3}\n",
106+
},
107+
}
108+
109+
for _, tt := range tests {
110+
t.Run(tt.name, func(t *testing.T) {
111+
req := httptest.NewRequest(http.MethodGet, "http://example.com/foo", nil)
112+
w := httptest.NewRecorder()
113+
114+
chain := Chain(tt.middlewares...)
115+
chain(tt.handler).ServeHTTP(w, req)
116+
117+
res := w.Result()
118+
defer res.Body.Close()
119+
120+
if res.StatusCode != tt.expectedStatus {
121+
t.Errorf("expected status %d, got %d", tt.expectedStatus, res.StatusCode)
122+
}
123+
124+
actualBody := w.Body.String()
125+
if actualBody != tt.expectedBody {
126+
t.Errorf("expected body %q, got %q", tt.expectedBody, actualBody)
127+
}
128+
})
129+
}
130+
}

0 commit comments

Comments
 (0)