Skip to content

Commit 9618f7e

Browse files
Avoid deadlocks by consolidating locks
1 parent 8e493f1 commit 9618f7e

File tree

6 files changed

+126
-93
lines changed

6 files changed

+126
-93
lines changed

.github/workflows/go.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
# Needed for the example-test to run.
3030
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
3131
run: |
32-
go test -cover -v ./...
32+
go test -race -cover -v ./...
3333
3434
lint:
3535
name: Lint

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ lint:
77
internal/lint/golangci-lint run ./... --fix
88

99
check: lint
10-
go test -cover ./...
10+
go test -race -cover ./...
1111
go mod tidy
1212

1313
.PHONY: example

graphql/subscription.go

Lines changed: 51 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,80 @@
11
package graphql
22

33
import (
4-
"fmt"
54
"reflect"
6-
"sync"
75
)
86

9-
// map of subscription ID to subscription
7+
// subscriptionMap is a map of subscription ID to subscription.
8+
// It is NOT thread-safe and must be protected by the caller's lock.
109
type subscriptionMap struct {
1110
map_ map[string]subscription
12-
sync.RWMutex
1311
}
1412

1513
type subscription struct {
16-
interfaceChan interface{}
17-
forwardDataFunc ForwardDataFunction
18-
id string
19-
hasBeenUnsubscribed bool
14+
interfaceChan interface{}
15+
forwardDataFunc ForwardDataFunction
16+
id string
17+
closed bool // true if the channel has been closed
2018
}
2119

22-
func (s *subscriptionMap) Create(subscriptionID string, interfaceChan interface{}, forwardDataFunc ForwardDataFunction) {
23-
s.Lock()
24-
defer s.Unlock()
20+
// create adds a new subscription to the map.
21+
// The caller must hold the webSocketClient lock.
22+
func (s *subscriptionMap) create(subscriptionID string, interfaceChan interface{}, forwardDataFunc ForwardDataFunction) {
2523
s.map_[subscriptionID] = subscription{
26-
id: subscriptionID,
27-
interfaceChan: interfaceChan,
28-
forwardDataFunc: forwardDataFunc,
29-
hasBeenUnsubscribed: false,
24+
id: subscriptionID,
25+
interfaceChan: interfaceChan,
26+
forwardDataFunc: forwardDataFunc,
27+
closed: false,
3028
}
3129
}
3230

33-
func (s *subscriptionMap) Unsubscribe(subscriptionID string) error {
34-
s.Lock()
35-
defer s.Unlock()
36-
unsub, success := s.map_[subscriptionID]
37-
if !success {
38-
return fmt.Errorf("tried to unsubscribe from unknown subscription with ID '%s'", subscriptionID)
31+
// get retrieves a subscription by ID.
32+
// The caller must hold the webSocketClient lock.
33+
// Returns nil if not found.
34+
func (s *subscriptionMap) get(subscriptionID string) *subscription {
35+
sub, ok := s.map_[subscriptionID]
36+
if !ok {
37+
return nil
3938
}
40-
hasBeenUnsubscribed := unsub.hasBeenUnsubscribed
41-
unsub.hasBeenUnsubscribed = true
42-
s.map_[subscriptionID] = unsub
39+
return &sub
40+
}
4341

44-
if !hasBeenUnsubscribed {
45-
reflect.ValueOf(s.map_[subscriptionID].interfaceChan).Close()
46-
}
47-
return nil
42+
// update updates a subscription in the map.
43+
// The caller must hold the webSocketClient lock.
44+
func (s *subscriptionMap) update(subscriptionID string, sub subscription) {
45+
s.map_[subscriptionID] = sub
4846
}
4947

50-
func (s *subscriptionMap) GetAllIDs() (subscriptionIDs []string) {
51-
s.RLock()
52-
defer s.RUnlock()
48+
// getAllIDs returns all subscription IDs.
49+
// The caller must hold the webSocketClient lock.
50+
func (s *subscriptionMap) getAllIDs() []string {
51+
subscriptionIDs := make([]string, 0, len(s.map_))
5352
for subID := range s.map_ {
5453
subscriptionIDs = append(subscriptionIDs, subID)
5554
}
5655
return subscriptionIDs
5756
}
5857

59-
func (s *subscriptionMap) Delete(subscriptionID string) {
60-
s.Lock()
61-
defer s.Unlock()
58+
// delete removes a subscription from the map.
59+
// The caller must hold the webSocketClient lock.
60+
func (s *subscriptionMap) delete(subscriptionID string) {
6261
delete(s.map_, subscriptionID)
6362
}
63+
64+
// closeChannel closes a subscription's channel if it hasn't been closed yet.
65+
// The caller must hold the webSocketClient lock.
66+
// Returns true if the channel was closed, false if it was already closed.
67+
func (s *subscriptionMap) closeChannel(subscriptionID string) bool {
68+
sub := s.get(subscriptionID)
69+
if sub == nil || sub.closed {
70+
return false
71+
}
72+
73+
// Mark as closed before actually closing to prevent double-close
74+
sub.closed = true
75+
s.update(subscriptionID, *sub)
76+
77+
// Close the channel
78+
reflect.ValueOf(sub.interfaceChan).Close()
79+
return true
80+
}

graphql/subscription_test.go

Lines changed: 28 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,61 +4,57 @@ import (
44
"testing"
55
)
66

7-
func Test_subscriptionMap_Unsubscribe(t *testing.T) {
8-
type args struct {
9-
subscriptionID string
10-
}
7+
func Test_subscriptionMap_closeChannel(t *testing.T) {
118
tests := []struct {
12-
name string
13-
args args
14-
sm subscriptionMap
15-
wantErr bool
9+
name string
10+
sm subscriptionMap
11+
subscriptionID string
12+
wantClosed bool
1613
}{
1714
{
18-
name: "unsubscribe existing subscription",
15+
name: "close existing open channel",
1916
sm: subscriptionMap{
2017
map_: map[string]subscription{
2118
"sub1": {
22-
id: "sub1",
23-
interfaceChan: make(chan struct{}),
24-
forwardDataFunc: nil,
25-
hasBeenUnsubscribed: false,
19+
id: "sub1",
20+
interfaceChan: make(chan struct{}),
21+
closed: false,
2622
},
2723
},
2824
},
29-
args: args{subscriptionID: "sub1"},
30-
wantErr: false,
25+
subscriptionID: "sub1",
26+
wantClosed: true,
3127
},
3228
{
33-
name: "unsubscribe non-existent subscription",
34-
sm: subscriptionMap{
35-
map_: map[string]subscription{},
36-
},
37-
args: args{subscriptionID: "doesnotexist"},
38-
wantErr: true,
39-
},
40-
{
41-
name: "unsubscribe already unsubscribed subscription",
29+
name: "close already closed channel",
4230
sm: subscriptionMap{
4331
map_: map[string]subscription{
4432
"sub2": {
45-
id: "sub2",
46-
interfaceChan: nil,
47-
forwardDataFunc: nil,
48-
hasBeenUnsubscribed: true,
33+
id: "sub2",
34+
interfaceChan: make(chan struct{}),
35+
closed: true,
4936
},
5037
},
5138
},
52-
args: args{subscriptionID: "sub2"},
53-
wantErr: false,
39+
subscriptionID: "sub2",
40+
wantClosed: false,
41+
},
42+
{
43+
name: "close non-existent subscription",
44+
sm: subscriptionMap{
45+
map_: map[string]subscription{},
46+
},
47+
subscriptionID: "doesnotexist",
48+
wantClosed: false,
5449
},
5550
}
5651
for i := range tests {
5752
tt := &tests[i]
5853
t.Run(tt.name, func(t *testing.T) {
5954
s := &tt.sm
60-
if err := s.Unsubscribe(tt.args.subscriptionID); (err != nil) != tt.wantErr {
61-
t.Errorf("subscriptionMap.Unsubscribe() error = %v, wantErr %v", err, tt.wantErr)
55+
gotClosed := s.closeChannel(tt.subscriptionID)
56+
if gotClosed != tt.wantClosed {
57+
t.Errorf("subscriptionMap.closeChannel() = %v, want %v", gotClosed, tt.wantClosed)
6258
}
6359
})
6460
}

graphql/websocket.go

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"encoding/json"
77
"fmt"
88
"net/http"
9-
"reflect"
109
"strings"
1110
"sync"
1211
"time"
@@ -106,15 +105,19 @@ func (w *webSocketClient) waitForConnAck() error {
106105

107106
func (w *webSocketClient) handleErr(err error) {
108107
w.Lock()
109-
defer w.Unlock()
110-
if !w.isClosing {
108+
isClosing := w.isClosing
109+
w.Unlock()
110+
if !isClosing {
111111
w.errChan <- err
112112
}
113113
}
114114

115115
func (w *webSocketClient) listenWebSocket() {
116116
for {
117-
if w.isClosing {
117+
w.Lock()
118+
isClosing := w.isClosing
119+
w.Unlock()
120+
if isClosing {
118121
return
119122
}
120123
_, message, err := w.conn.ReadMessage()
@@ -139,22 +142,31 @@ func (w *webSocketClient) forwardWebSocketData(message []byte) error {
139142
if wsMsg.ID == "" { // e.g. keep-alive messages
140143
return nil
141144
}
142-
w.subscriptions.Lock()
143-
defer w.subscriptions.Unlock()
144-
sub, success := w.subscriptions.map_[wsMsg.ID]
145-
if !success {
145+
146+
w.Lock()
147+
sub := w.subscriptions.get(wsMsg.ID)
148+
if sub == nil {
149+
w.Unlock()
146150
return fmt.Errorf("received message for unknown subscription ID '%s'", wsMsg.ID)
147151
}
148-
if sub.hasBeenUnsubscribed {
152+
if sub.closed {
153+
// Already closed, ignore message
154+
w.Unlock()
149155
return nil
150156
}
157+
151158
if wsMsg.Type == webSocketTypeComplete {
152-
sub.hasBeenUnsubscribed = true
153-
w.subscriptions.map_[wsMsg.ID] = sub
154-
reflect.ValueOf(sub.interfaceChan).Close()
159+
// Server is telling us the subscription is complete
160+
w.subscriptions.closeChannel(wsMsg.ID)
161+
w.subscriptions.delete(wsMsg.ID)
162+
w.Unlock()
155163
return nil
156164
}
157165

166+
// Forward the data to the subscription channel.
167+
// We release the lock while calling the forward function to avoid holding
168+
// the lock while doing potentially slow user code.
169+
w.Unlock()
158170
return sub.forwardDataFunc(sub.interfaceChan, wsMsg.Payload)
159171
}
160172

@@ -224,15 +236,21 @@ func (w *webSocketClient) Subscribe(req *Request, interfaceChan interface{}, for
224236
}
225237

226238
subscriptionID := uuid.NewString()
227-
w.subscriptions.Create(subscriptionID, interfaceChan, forwardDataFunc)
239+
240+
w.Lock()
241+
w.subscriptions.create(subscriptionID, interfaceChan, forwardDataFunc)
242+
w.Unlock()
243+
228244
subscriptionMsg := webSocketSendMessage{
229245
Type: webSocketTypeSubscribe,
230246
Payload: req,
231247
ID: subscriptionID,
232248
}
233249
err := w.sendStructAsJSON(subscriptionMsg)
234250
if err != nil {
235-
w.subscriptions.Delete(subscriptionID)
251+
w.Lock()
252+
w.subscriptions.delete(subscriptionID)
253+
w.Unlock()
236254
return "", err
237255
}
238256
return subscriptionID, nil
@@ -247,15 +265,19 @@ func (w *webSocketClient) Unsubscribe(subscriptionID string) error {
247265
if err != nil {
248266
return err
249267
}
250-
err = w.subscriptions.Unsubscribe(subscriptionID)
251-
if err != nil {
252-
return err
253-
}
268+
269+
w.Lock()
270+
defer w.Unlock()
271+
w.subscriptions.closeChannel(subscriptionID)
272+
w.subscriptions.delete(subscriptionID)
254273
return nil
255274
}
256275

257276
func (w *webSocketClient) UnsubscribeAll() error {
258-
subscriptionIDs := w.subscriptions.GetAllIDs()
277+
w.Lock()
278+
subscriptionIDs := w.subscriptions.getAllIDs()
279+
w.Unlock()
280+
259281
for _, subscriptionID := range subscriptionIDs {
260282
err := w.Unsubscribe(subscriptionID)
261283
if err != nil {

graphql/websocket_test.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,18 @@ package graphql
22

33
import (
44
"encoding/json"
5-
"sync"
65
"testing"
76
)
87

98
const testSubscriptionID = "test-subscription-id"
109

11-
func forgeTestWebSocketClient(hasBeenUnsubscribed bool) *webSocketClient {
10+
func forgeTestWebSocketClient(closed bool) *webSocketClient {
1211
return &webSocketClient{
1312
subscriptions: subscriptionMap{
14-
RWMutex: sync.RWMutex{},
1513
map_: map[string]subscription{
1614
testSubscriptionID: {
17-
hasBeenUnsubscribed: hasBeenUnsubscribed,
18-
interfaceChan: make(chan any),
15+
closed: closed,
16+
interfaceChan: make(chan any),
1917
forwardDataFunc: func(interfaceChan any, jsonRawMsg json.RawMessage) error {
2018
return nil
2119
},
@@ -60,7 +58,7 @@ func Test_webSocketClient_forwardWebSocketData(t *testing.T) {
6058
wantErr: false,
6159
},
6260
{
63-
name: "unsubscribed subscription",
61+
name: "closed subscription",
6462
args: args{message: []byte(`{"type":"next","id":"test-subscription-id","payload":{}}`)},
6563
wc: forgeTestWebSocketClient(true),
6664
wantErr: false,

0 commit comments

Comments
 (0)