-
Notifications
You must be signed in to change notification settings - Fork 897
websocket: Initial refactor for websocket handling to support context propagation #2066
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 15 commits
360d0e5
4be8b4b
cce08e7
86d218c
0cb3876
98c0817
a87e39f
6b0a3b5
f380808
5eacbae
88e48a1
c8fb278
f6dd7c3
fc7df14
770d7a6
7d6e595
cabcd83
f21e02d
7625f60
7185eb9
2eb5f6a
e42ae00
bab00e0
1680819
239d160
07bba2a
1e63153
9f5fbda
ca24dc1
db647bf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2947,7 +2947,7 @@ func (s *RPCServer) WebsocketGetInfo(_ context.Context, r *gctrpc.WebsocketGetIn | |
| } | ||
|
|
||
| // WebsocketSetEnabled enables or disables the websocket client | ||
| func (s *RPCServer) WebsocketSetEnabled(_ context.Context, r *gctrpc.WebsocketSetEnabledRequest) (*gctrpc.GenericResponse, error) { | ||
| func (s *RPCServer) WebsocketSetEnabled(ctx context.Context, r *gctrpc.WebsocketSetEnabledRequest) (*gctrpc.GenericResponse, error) { | ||
| exch, err := s.GetExchangeByName(r.Exchange) | ||
| if err != nil { | ||
| return nil, err | ||
|
|
@@ -2964,7 +2964,7 @@ func (s *RPCServer) WebsocketSetEnabled(_ context.Context, r *gctrpc.WebsocketSe | |
| } | ||
|
|
||
| if r.Enable { | ||
| err = w.Enable() | ||
| err = w.Enable(s.rpcContextToLongLivedSession(ctx)) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
|
|
@@ -3013,7 +3013,7 @@ func (s *RPCServer) WebsocketGetSubscriptions(_ context.Context, r *gctrpc.Webso | |
| } | ||
|
|
||
| // WebsocketSetProxy sets client websocket connection proxy | ||
| func (s *RPCServer) WebsocketSetProxy(_ context.Context, r *gctrpc.WebsocketSetProxyRequest) (*gctrpc.GenericResponse, error) { | ||
| func (s *RPCServer) WebsocketSetProxy(ctx context.Context, r *gctrpc.WebsocketSetProxyRequest) (*gctrpc.GenericResponse, error) { | ||
| exch, err := s.GetExchangeByName(r.Exchange) | ||
| if err != nil { | ||
| return nil, err | ||
|
|
@@ -3024,7 +3024,7 @@ func (s *RPCServer) WebsocketSetProxy(_ context.Context, r *gctrpc.WebsocketSetP | |
| return nil, fmt.Errorf("websocket not supported for exchange %s", r.Exchange) | ||
| } | ||
|
|
||
| err = w.SetProxyAddress(r.Proxy) | ||
| err = w.SetProxyAddress(s.rpcContextToLongLivedSession(ctx), r.Proxy) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
|
|
@@ -5884,3 +5884,13 @@ func (s *RPCServer) GetCurrencyTradeURL(ctx context.Context, r *gctrpc.GetCurren | |
| Url: url, | ||
| }, nil | ||
| } | ||
|
|
||
| // rpcContextToLongLivedSession converts a short-lived incoming context to a long-lived outgoing context, this is due | ||
| // to the incoming context being cancelled when the RPC call completes. | ||
| func (s *RPCServer) rpcContextToLongLivedSession(ctx context.Context) context.Context { | ||
|
||
| md, ok := metadata.FromIncomingContext(ctx) | ||
| if !ok { | ||
| md = metadata.New(nil) // Fallback to empty metadata | ||
| } | ||
| return metadata.NewOutgoingContext(context.Background(), md) | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| package message | ||
|
||
|
|
||
| import ( | ||
| "context" | ||
| "errors" | ||
| "fmt" | ||
| ) | ||
|
|
||
| var errChannelBufferFull = errors.New("channel buffer is full") | ||
|
|
||
| // Relay defines a channel relay for messages | ||
| type Relay struct { | ||
| C <-chan Payload | ||
| comm chan Payload | ||
| } | ||
|
|
||
| // Payload represents a relayed message with a context | ||
| type Payload struct { | ||
| // TODO: remove context from payload see: https://github.com/thrasher-corp/gocryptotrader/pull/2066#discussion_r2501403057 | ||
| Ctx context.Context //nolint:containedctx // context needed for tracing/metrics | ||
gbjk marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| Data any | ||
| } | ||
|
|
||
| // NewRelay creates a new Relay instance with a specified buffer size | ||
| func NewRelay(buffer uint) *Relay { | ||
| if buffer == 0 { | ||
| panic("buffer size must be greater than 0") | ||
| } | ||
| comm := make(chan Payload, buffer) | ||
| return &Relay{comm: comm, C: comm} | ||
| } | ||
|
|
||
| // Send sends a message to the channel receiver | ||
| // This is non-blocking and returns an error if the channel buffer is full | ||
| func (r *Relay) Send(ctx context.Context, data any) error { | ||
| select { | ||
| case r.comm <- Payload{Ctx: ctx, Data: data}: | ||
| return nil | ||
| default: | ||
| return fmt.Errorf("%w: failed to relay <%T>", errChannelBufferFull, data) | ||
| } | ||
| } | ||
|
|
||
| // Close closes the relay channel | ||
| func (r *Relay) Close() { | ||
| close(r.comm) | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| package message | ||
|
|
||
| import ( | ||
| "testing" | ||
|
|
||
| "github.com/stretchr/testify/assert" | ||
| "github.com/stretchr/testify/require" | ||
| ) | ||
|
|
||
| func TestNewRelay(t *testing.T) { | ||
| t.Parallel() | ||
| assert.Panics(t, func() { NewRelay(0) }, "buffer size should be greater than 0") | ||
| r := NewRelay(5) | ||
| require.NotNil(t, r) | ||
| assert.Equal(t, 5, cap(r.comm)) | ||
| } | ||
|
|
||
| func TestSend(t *testing.T) { | ||
| t.Parallel() | ||
| r := NewRelay(1) | ||
| require.NotNil(t, r) | ||
| assert.NoError(t, r.Send(t.Context(), "test")) | ||
| assert.ErrorIs(t, r.Send(t.Context(), "overflow"), errChannelBufferFull) | ||
| } | ||
|
|
||
| func TestRead(t *testing.T) { | ||
| t.Parallel() | ||
| r := NewRelay(1) | ||
| require.NotNil(t, r) | ||
| require.Empty(t, r.C) | ||
| assert.NoError(t, r.Send(t.Context(), "test")) | ||
| require.Len(t, r.C, 1) | ||
| assert.Equal(t, "test", (<-r.C).Data) | ||
| } | ||
|
|
||
| func TestClose(t *testing.T) { | ||
| t.Parallel() | ||
| r := NewRelay(1) | ||
| require.NotNil(t, r) | ||
| r.Close() | ||
| _, ok := <-r.C | ||
| assert.False(t, ok) | ||
| } |
Uh oh!
There was an error while loading. Please reload this page.