Skip to content
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
360d0e5
Refactor websocket handling to support context propagation
Sep 22, 2025
4be8b4b
linter/AI: nits and fixes
Sep 22, 2025
cce08e7
linter: fixes again
Sep 22, 2025
86d218c
fix test
Sep 22, 2025
0cb3876
Merge branch 'master' into websock_read_ctx
Oct 14, 2025
98c0817
instead of a read function just expose the read only channel field li…
Oct 14, 2025
a87e39f
Merge branch 'master' into websock_read_ctx
Oct 28, 2025
6b0a3b5
Merge branch 'master' into websock_read_ctx
Nov 5, 2025
f380808
glorious: nits
Nov 9, 2025
5eacbae
Merge branch 'master' into websock_read_ctx
Nov 12, 2025
88e48a1
glorious: nits
Nov 12, 2025
c8fb278
rm jank for finding issue
Nov 12, 2025
f6dd7c3
linter: fix
Nov 12, 2025
fc7df14
glorious: more nits
Nov 12, 2025
770d7a6
Add todo for context removal
Nov 12, 2025
7d6e595
Merge branch 'master' into websock_read_ctx
Nov 27, 2025
cabcd83
Merge branch 'master' into websock_read_ctx
Nov 30, 2025
f21e02d
Implement context freezing and thawing functions; update Payload to u…
Dec 1, 2025
7625f60
linter: fix
Dec 1, 2025
7185eb9
linter: fix again
Dec 1, 2025
2eb5f6a
Merge branch 'master' into websock_read_ctx
Dec 1, 2025
e42ae00
refactor: simplify FrozenContext structure and update related functions
Dec 2, 2025
bab00e0
Update common/common.go
shazbert Dec 2, 2025
1680819
Update common/common.go
shazbert Dec 2, 2025
239d160
glorious: removal of check
Dec 3, 2025
07bba2a
thrasher: rm unused error return
Dec 3, 2025
1e63153
Update exchange/websocket/manager.go
shazbert Dec 3, 2025
9f5fbda
rm: check for nil datahandler
Dec 3, 2025
ca24dc1
gk: rm rpcContextToLongLivedSession
Dec 4, 2025
db647bf
gk: rn package
Dec 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 13 additions & 17 deletions docs/ADD_NEW_EXCHANGE.md
Original file line number Diff line number Diff line change
Expand Up @@ -693,19 +693,17 @@ func (e *Exchange) WsConnect() error {
// KeepAuthKeyAlive will continuously send messages to
// keep the WS auth key active
func (e *Exchange) KeepAuthKeyAlive(ctx context.Context) {
e.Websocket.Wg.Add(1)
defer e.Websocket.Wg.Done()
ticks := time.NewTicker(time.Minute * 30)
for {
select {
case <-e.Websocket.ShutdownC:
ticks.Stop()
return
case <-ticks.C:
err := e.MaintainWsAuthStreamKey(ctx)
if err != nil {
e.Websocket.DataHandler <- err
log.Warnf(log.ExchangeSys, "%s - Unable to renew auth websocket token, may experience shutdown", e.Name)
case <-time.After(time.Minute * 30):
if err := e.MaintainWsAuthStreamKey(ctx); err != nil {
if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil {
log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err)
}
log.Warnf(log.ExchangeSys, "%s %s: Unable to renew auth websocket token, may experience shutdown", e.Name, e.Websocket.Conn.GetURL())
}
}
}
Expand Down Expand Up @@ -817,9 +815,7 @@ Run gocryptotrader with the following settings enabled in config
```go
// wsReadData gets and passes on websocket messages for processing
func (e *Exchange) wsReadData() {
e.Websocket.Wg.Add(1)
defer e.Websocket.Wg.Done()

for {
select {
case <-e.Websocket.ShutdownC:
Expand All @@ -829,10 +825,10 @@ func (e *Exchange) wsReadData() {
if resp.Raw == nil {
return
}

err := e.wsHandleData(resp.Raw)
if err != nil {
e.Websocket.DataHandler <- err
if err := e.wsHandleData(ctx, resp.Raw); err != nil {
if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil {
log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err)
}
}
}
}
Expand Down Expand Up @@ -875,15 +871,15 @@ If a suitable struct does not exist in wshandler, wrapper types are the next pre
if err := json.Unmarshal(respRaw, &resultData);err != nil {
return err
}
e.Websocket.DataHandler <- &ticker.Price{
return e.Websocket.DataHandler.Send(ctx, &ticker.Price{
ExchangeName: e.Name,
Bid: resultData.Ticker.Bid,
Ask: resultData.Ticker.Ask,
Last: resultData.Ticker.Last,
LastUpdated: resultData.Ticker.Time,
Pair: p,
AssetType: a,
}
})
}
```

Expand All @@ -896,7 +892,7 @@ If neither of those provide a suitable struct to store the data in, the data can
if err != nil {
return err
}
e.Websocket.DataHandler <- resultData.FillsData
return e.Websocket.DataHandler.Send(ctx, resultData.FillsData)
```

- Data Handling can be tested offline similar to the following example:
Expand Down
18 changes: 14 additions & 4 deletions engine/rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2947,7 +2947,7 @@ func (s *RPCServer) WebsocketGetInfo(_ context.Context, r *gctrpc.WebsocketGetIn
}

// WebsocketSetEnabled enables or disables the websocket client
func (s *RPCServer) WebsocketSetEnabled(_ context.Context, r *gctrpc.WebsocketSetEnabledRequest) (*gctrpc.GenericResponse, error) {
func (s *RPCServer) WebsocketSetEnabled(ctx context.Context, r *gctrpc.WebsocketSetEnabledRequest) (*gctrpc.GenericResponse, error) {
exch, err := s.GetExchangeByName(r.Exchange)
if err != nil {
return nil, err
Expand All @@ -2964,7 +2964,7 @@ func (s *RPCServer) WebsocketSetEnabled(_ context.Context, r *gctrpc.WebsocketSe
}

if r.Enable {
err = w.Enable()
err = w.Enable(s.rpcContextToLongLivedSession(ctx))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -3013,7 +3013,7 @@ func (s *RPCServer) WebsocketGetSubscriptions(_ context.Context, r *gctrpc.Webso
}

// WebsocketSetProxy sets client websocket connection proxy
func (s *RPCServer) WebsocketSetProxy(_ context.Context, r *gctrpc.WebsocketSetProxyRequest) (*gctrpc.GenericResponse, error) {
func (s *RPCServer) WebsocketSetProxy(ctx context.Context, r *gctrpc.WebsocketSetProxyRequest) (*gctrpc.GenericResponse, error) {
exch, err := s.GetExchangeByName(r.Exchange)
if err != nil {
return nil, err
Expand All @@ -3024,7 +3024,7 @@ func (s *RPCServer) WebsocketSetProxy(_ context.Context, r *gctrpc.WebsocketSetP
return nil, fmt.Errorf("websocket not supported for exchange %s", r.Exchange)
}

err = w.SetProxyAddress(r.Proxy)
err = w.SetProxyAddress(s.rpcContextToLongLivedSession(ctx), r.Proxy)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -5884,3 +5884,13 @@ func (s *RPCServer) GetCurrencyTradeURL(ctx context.Context, r *gctrpc.GetCurren
Url: url,
}, nil
}

// rpcContextToLongLivedSession converts a short-lived incoming context to a long-lived outgoing context, this is due
// to the incoming context being cancelled when the RPC call completes.
func (s *RPCServer) rpcContextToLongLivedSession(ctx context.Context) context.Context {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🐒 Optional nitpick:
Okay; I kinda get it, except LongLivedSession seems ... well weird.
We're basically just stripping out the lifecycle from the context, right?
So ... isn't this just a very fancy context.WithoutCancel ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah good point, any other naming suggestions and I will change it, and should this function live somewhere like common?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean... if it's context.WithoutCancel you don't need this function... right?
Just drop in replace it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well I feel like an idiot. 😆

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch: ca24dc1

md, ok := metadata.FromIncomingContext(ctx)
if !ok {
md = metadata.New(nil) // Fallback to empty metadata
}
return metadata.NewOutgoingContext(context.Background(), md)
}
48 changes: 48 additions & 0 deletions engine/rpcserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3869,3 +3869,51 @@ func TestGetCurrencyTradeURL(t *testing.T) {
assert.NoError(t, err)
assert.NotEmpty(t, resp.Url)
}

func TestRPCContextToLongLivedSession(t *testing.T) {
t.Parallel()
s := &RPCServer{}
t.Run("preserve metadata", func(t *testing.T) {
t.Parallel()
incomingMD := metadata.Pairs(
"authorization", "Bearer token",
"x-id", "123",
)
inCtx, cancel := context.WithCancel(metadata.NewIncomingContext(context.Background(), incomingMD))
defer cancel()

newCtx := s.rpcContextToLongLivedSession(inCtx)
require.NotNil(t, newCtx)

outMD, ok := metadata.FromOutgoingContext(newCtx)
require.True(t, ok)
require.Len(t, outMD, len(incomingMD))
assert.ElementsMatch(t, incomingMD.Get("authorization"), outMD.Get("authorization"))
assert.ElementsMatch(t, incomingMD.Get("x-id"), outMD.Get("x-id"))

cancel()
assert.Eventually(t, func() bool {
return inCtx.Err() != nil
}, time.Second, 10*time.Millisecond)
assert.Eventually(t, func() bool {
return newCtx.Err() == nil
}, 500*time.Millisecond, 20*time.Millisecond)
})

t.Run("no metadata", func(t *testing.T) {
t.Parallel()
inCtx, cancel := context.WithCancel(context.Background())
defer cancel()

newCtx := s.rpcContextToLongLivedSession(inCtx)
require.NotNil(t, newCtx)

outMD, ok := metadata.FromOutgoingContext(newCtx)
require.True(t, ok)
assert.Empty(t, outMD)

cancel()
assert.Eventually(t, func() bool { return inCtx.Err() != nil }, time.Second, 10*time.Millisecond)
assert.Eventually(t, func() bool { return newCtx.Err() == nil }, 500*time.Millisecond, 20*time.Millisecond)
})
}
10 changes: 5 additions & 5 deletions engine/websocketroutine_manager.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package engine

import (
"context"
"fmt"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -139,7 +140,7 @@ func (m *WebsocketRoutineManager) websocketRoutine() {
log.Errorf(log.WebsocketMgr, "%v", err)
}

if err := ws.Connect(); err != nil {
if err := ws.Connect(context.TODO()); err != nil {
log.Errorf(log.WebsocketMgr, "%v", err)
}
})
Expand Down Expand Up @@ -167,14 +168,13 @@ func (m *WebsocketRoutineManager) websocketDataReceiver(ws *websocket.Manager) e
select {
case <-m.shutdown:
return
case data := <-ws.ToRoutine:
if data == nil {
case payload := <-ws.DataHandler.C:
if payload.Data == nil {
log.Errorf(log.WebsocketMgr, "exchange %s nil data sent to websocket", ws.GetName())
}
m.mu.RLock()
for x := range m.dataHandlers {
err := m.dataHandlers[x](ws.GetName(), data)
if err != nil {
if err := m.dataHandlers[x](ws.GetName(), payload.Data); err != nil {
log.Errorln(log.WebsocketMgr, err)
}
}
Expand Down
10 changes: 6 additions & 4 deletions engine/websocketroutine_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,16 +258,18 @@ func TestRegisterWebsocketDataHandlerWithFunctionality(t *testing.T) {
}

mock := websocket.NewManager()
mock.ToRoutine = make(chan any)
m.state = readyState
err = m.websocketDataReceiver(mock)
if err != nil {
t.Fatal(err)
}

mock.ToRoutine <- nil
mock.ToRoutine <- 1336
mock.ToRoutine <- "intercepted"
err = mock.DataHandler.Send(t.Context(), nil)
require.NoError(t, err)
err = mock.DataHandler.Send(t.Context(), 1336)
require.NoError(t, err)
err = mock.DataHandler.Send(t.Context(), "intercepted")
require.NoError(t, err)

if r := <-dataChan; r != "intercepted" {
t.Fatal("unexpected value received")
Expand Down
47 changes: 47 additions & 0 deletions exchange/message/message.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package message
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚧 Change request:
exchange/message feels like a very overloaded term and it feels like you're domain-squatting it a bit here.
I get that this could represent any message queue.
But right now we don't have an idea of any other type of messages here.
And this is fundamentally a relay, not a message thing.

So I'd like to see this in exchange/websocket/relay/relay.go instead.
But I'm also thinking it's even better in exchange/websocket/buffer/relay.go 🎉

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My reasoning for placing this in exchange/message was that the relay is intentionally generic — it doesn’t depend on WebSockets or exchange code, and its API (message.Relay) is transport-agnostic by design.

If we move it into websocket/relay, it implicitly couples the type to a subsystem it doesn’t rely on and might make it harder to reuse later for other messaging paths (e.g., FIX, -- these are all a strech --> RPC, internal event pumps, REST batch queues).

That said, I’m happy to relocate it if the consensus is that the Relay should remain WebSocket-specific for now.

@thrasher- @gloriousCode

Also Buffer naming I think will be subject to change because no websocket orderbooks should buffer events or sort them.

Also what about exchange/stream/relay.go?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think exchange/stream/relay.go is the best. Removes the websocket coupling since it's protocol agnostic and I agree that message is too generic

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • exchange/stream/relay.go is good 👍

🥃 Your reasoning about buffer was backwards. Just because websocket orderbooks are in a package called buffer doesn't mean that nothing else can be, or that when they move out we couldn't leave a relay in there. But that's just an aside.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was more meaning, its going to change, then we will have to change the package name for all relay instances and it doesn't really reflect buffer that's all.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done: db647bf


import (
"context"
"errors"
"fmt"
)

var errChannelBufferFull = errors.New("channel buffer is full")

// Relay defines a channel relay for messages
type Relay struct {
C <-chan Payload
comm chan Payload
}

// Payload represents a relayed message with a context
type Payload struct {
// TODO: remove context from payload see: https://github.com/thrasher-corp/gocryptotrader/pull/2066#discussion_r2501403057
Ctx context.Context //nolint:containedctx // context needed for tracing/metrics
Data any
}

// NewRelay creates a new Relay instance with a specified buffer size
func NewRelay(buffer uint) *Relay {
if buffer == 0 {
panic("buffer size must be greater than 0")
}
comm := make(chan Payload, buffer)
return &Relay{comm: comm, C: comm}
}

// Send sends a message to the channel receiver
// This is non-blocking and returns an error if the channel buffer is full
func (r *Relay) Send(ctx context.Context, data any) error {
select {
case r.comm <- Payload{Ctx: ctx, Data: data}:
return nil
default:
return fmt.Errorf("%w: failed to relay <%T>", errChannelBufferFull, data)
}
}

// Close closes the relay channel
func (r *Relay) Close() {
close(r.comm)
}
43 changes: 43 additions & 0 deletions exchange/message/message_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package message

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNewRelay(t *testing.T) {
t.Parallel()
assert.Panics(t, func() { NewRelay(0) }, "buffer size should be greater than 0")
r := NewRelay(5)
require.NotNil(t, r)
assert.Equal(t, 5, cap(r.comm))
}

func TestSend(t *testing.T) {
t.Parallel()
r := NewRelay(1)
require.NotNil(t, r)
assert.NoError(t, r.Send(t.Context(), "test"))
assert.ErrorIs(t, r.Send(t.Context(), "overflow"), errChannelBufferFull)
}

func TestRead(t *testing.T) {
t.Parallel()
r := NewRelay(1)
require.NotNil(t, r)
require.Empty(t, r.C)
assert.NoError(t, r.Send(t.Context(), "test"))
require.Len(t, r.C, 1)
assert.Equal(t, "test", (<-r.C).Data)
}

func TestClose(t *testing.T) {
t.Parallel()
r := NewRelay(1)
require.NotNil(t, r)
r.Close()
_, ok := <-r.C
assert.False(t, ok)
}
Loading
Loading