Skip to content
Open
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
360d0e5
Refactor websocket handling to support context propagation
Sep 22, 2025
4be8b4b
linter/AI: nits and fixes
Sep 22, 2025
cce08e7
linter: fixes again
Sep 22, 2025
86d218c
fix test
Sep 22, 2025
0cb3876
Merge branch 'master' into websock_read_ctx
Oct 14, 2025
98c0817
instead of a read function just expose the read only channel field li…
Oct 14, 2025
a87e39f
Merge branch 'master' into websock_read_ctx
Oct 28, 2025
6b0a3b5
Merge branch 'master' into websock_read_ctx
Nov 5, 2025
f380808
glorious: nits
Nov 9, 2025
5eacbae
Merge branch 'master' into websock_read_ctx
Nov 12, 2025
88e48a1
glorious: nits
Nov 12, 2025
c8fb278
rm jank for finding issue
Nov 12, 2025
f6dd7c3
linter: fix
Nov 12, 2025
fc7df14
glorious: more nits
Nov 12, 2025
770d7a6
Add todo for context removal
Nov 12, 2025
7d6e595
Merge branch 'master' into websock_read_ctx
Nov 27, 2025
cabcd83
Merge branch 'master' into websock_read_ctx
Nov 30, 2025
f21e02d
Implement context freezing and thawing functions; update Payload to u…
Dec 1, 2025
7625f60
linter: fix
Dec 1, 2025
7185eb9
linter: fix again
Dec 1, 2025
2eb5f6a
Merge branch 'master' into websock_read_ctx
Dec 1, 2025
e42ae00
refactor: simplify FrozenContext structure and update related functions
Dec 2, 2025
bab00e0
Update common/common.go
shazbert Dec 2, 2025
1680819
Update common/common.go
shazbert Dec 2, 2025
239d160
glorious: removal of check
Dec 3, 2025
07bba2a
thrasher: rm unused error return
Dec 3, 2025
1e63153
Update exchange/websocket/manager.go
shazbert Dec 3, 2025
9f5fbda
rm: check for nil datahandler
Dec 3, 2025
ca24dc1
gk: rm rpcContextToLongLivedSession
Dec 4, 2025
db647bf
gk: rn package
Dec 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ var (
errCannotSetInvalidTimeout = errors.New("cannot set new HTTP client with timeout that is equal or less than 0")
errUserAgentInvalid = errors.New("cannot set invalid user agent")
errHTTPClientInvalid = errors.New("custom http client cannot be nil")
errDuplicateContextKey = errors.New("duplicate context key")
)

// NilGuard returns an ErrNilPointer with the type of the first nil argument
Expand Down Expand Up @@ -695,3 +696,62 @@ func SetIfZero[T comparable](p *T, def T) bool {
*p = def
return true
}

var (
contextKeys []any
contextKeysMu sync.RWMutex
)

// RegisterContextKey registers a key to be captured by FreezeCtx
func RegisterContextKey(key any) {
contextKeysMu.Lock()
defer contextKeysMu.Unlock()
if !slices.Contains(contextKeys, key) {
contextKeys = append(contextKeys, key)
}
}

// FrozenContext holds captured context values
type FrozenContext map[any]any

// FreezeCtx captures values from the context for registered keys
func FreezeCtx(ctx context.Context) FrozenContext {
contextKeysMu.RLock()
defer contextKeysMu.RUnlock()

values := make(FrozenContext, len(contextKeys))
for _, key := range contextKeys {
if val := ctx.Value(key); val != nil {
values[key] = val
}
}
return values
}

// ThawCtx creates a new context from the frozen context using context.Background() as parent
func ThawCtx(fc FrozenContext) (context.Context, error) {
return MergeCtx(context.Background(), fc)
}

// MergeCtx adds the frozen values to an existing context
func MergeCtx(ctx context.Context, fc FrozenContext) (context.Context, error) {
for k := range fc {
if ctx.Value(k) != nil {
return nil, fmt.Errorf("%w: %q", errDuplicateContextKey, k)
}
}
return &mergeCtx{Context: ctx, frozen: fc}, nil
}

// mergeCtx is a context that merges values from a frozen context and a parent context.
type mergeCtx struct {
context.Context //nolint:containedctx // Using context.WithValue will nest contexts and cause lookup latency
frozen FrozenContext
}

func (m *mergeCtx) Value(key any) any {
if val, ok := m.frozen[key]; ok {
return val
}
return m.Context.Value(key)
}
39 changes: 39 additions & 0 deletions common/common_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package common

import (
"context"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -691,3 +692,41 @@ func TestSetIfZero(t *testing.T) {
assert.True(t, changed, "SetIfZero should change a zero value")
assert.Equal(t, "world", s, "SetIfZero should change a zero value")
}

func TestContextFunctions(t *testing.T) {
t.Parallel()

type key string
const k1 key = "key1"
const k2 key = "key2"
const k3 key = "key3"

RegisterContextKey(k1)
RegisterContextKey(k2)

ctx := context.WithValue(context.Background(), k1, "value1")
ctx = context.WithValue(ctx, k2, "value2")
ctx = context.WithValue(ctx, k3, "value3") // Not registered

frozen := FreezeCtx(ctx)

assert.Equal(t, "value1", frozen[k1], "should have captured k1")
assert.Equal(t, "value2", frozen[k2], "should have captured k2")
assert.Zero(t, frozen[k3], "k3 should not be captured")

thawed, err := ThawCtx(frozen)
require.NoError(t, err)
assert.Equal(t, "value1", thawed.Value(k1), "should have k1 after thaw")
assert.Equal(t, "value2", thawed.Value(k2), "should have k2 after thaw")
assert.Nil(t, thawed.Value(k3), "Thawed context should not have k3")

ctx2 := context.WithValue(context.Background(), k3, "value3_new")
merged, err := MergeCtx(ctx2, frozen)
require.NoError(t, err)
assert.Equal(t, "value1", merged.Value(k1), "should have k1 from frozen")
assert.Equal(t, "value2", merged.Value(k2), "should have k2 from frozen")
assert.Equal(t, "value3_new", merged.Value(k3), "should have k3 from parent")

_, err = MergeCtx(merged, frozen)
require.ErrorIs(t, err, errDuplicateContextKey, "must error on duplicate keys")
}
30 changes: 13 additions & 17 deletions docs/ADD_NEW_EXCHANGE.md
Original file line number Diff line number Diff line change
Expand Up @@ -693,19 +693,17 @@ func (e *Exchange) WsConnect() error {
// KeepAuthKeyAlive will continuously send messages to
// keep the WS auth key active
func (e *Exchange) KeepAuthKeyAlive(ctx context.Context) {
e.Websocket.Wg.Add(1)
defer e.Websocket.Wg.Done()
ticks := time.NewTicker(time.Minute * 30)
for {
select {
case <-e.Websocket.ShutdownC:
ticks.Stop()
return
case <-ticks.C:
err := e.MaintainWsAuthStreamKey(ctx)
if err != nil {
e.Websocket.DataHandler <- err
log.Warnf(log.ExchangeSys, "%s - Unable to renew auth websocket token, may experience shutdown", e.Name)
case <-time.After(time.Minute * 30):
if err := e.MaintainWsAuthStreamKey(ctx); err != nil {
if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil {
log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err)
}
log.Warnf(log.ExchangeSys, "%s %s: Unable to renew auth websocket token, may experience shutdown", e.Name, e.Websocket.Conn.GetURL())
}
}
}
Expand Down Expand Up @@ -817,9 +815,7 @@ Run gocryptotrader with the following settings enabled in config
```go
// wsReadData gets and passes on websocket messages for processing
func (e *Exchange) wsReadData() {
e.Websocket.Wg.Add(1)
defer e.Websocket.Wg.Done()

for {
select {
case <-e.Websocket.ShutdownC:
Expand All @@ -829,10 +825,10 @@ func (e *Exchange) wsReadData() {
if resp.Raw == nil {
return
}

err := e.wsHandleData(resp.Raw)
if err != nil {
e.Websocket.DataHandler <- err
if err := e.wsHandleData(ctx, resp.Raw); err != nil {
if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil {
log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err)
}
}
}
}
Expand Down Expand Up @@ -875,15 +871,15 @@ If a suitable struct does not exist in wshandler, wrapper types are the next pre
if err := json.Unmarshal(respRaw, &resultData);err != nil {
return err
}
e.Websocket.DataHandler <- &ticker.Price{
return e.Websocket.DataHandler.Send(ctx, &ticker.Price{
ExchangeName: e.Name,
Bid: resultData.Ticker.Bid,
Ask: resultData.Ticker.Ask,
Last: resultData.Ticker.Last,
LastUpdated: resultData.Ticker.Time,
Pair: p,
AssetType: a,
}
})
}
```

Expand All @@ -896,7 +892,7 @@ If neither of those provide a suitable struct to store the data in, the data can
if err != nil {
return err
}
e.Websocket.DataHandler <- resultData.FillsData
return e.Websocket.DataHandler.Send(ctx, resultData.FillsData)
```

- Data Handling can be tested offline similar to the following example:
Expand Down
18 changes: 14 additions & 4 deletions engine/rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2947,7 +2947,7 @@ func (s *RPCServer) WebsocketGetInfo(_ context.Context, r *gctrpc.WebsocketGetIn
}

// WebsocketSetEnabled enables or disables the websocket client
func (s *RPCServer) WebsocketSetEnabled(_ context.Context, r *gctrpc.WebsocketSetEnabledRequest) (*gctrpc.GenericResponse, error) {
func (s *RPCServer) WebsocketSetEnabled(ctx context.Context, r *gctrpc.WebsocketSetEnabledRequest) (*gctrpc.GenericResponse, error) {
exch, err := s.GetExchangeByName(r.Exchange)
if err != nil {
return nil, err
Expand All @@ -2964,7 +2964,7 @@ func (s *RPCServer) WebsocketSetEnabled(_ context.Context, r *gctrpc.WebsocketSe
}

if r.Enable {
err = w.Enable()
err = w.Enable(s.rpcContextToLongLivedSession(ctx))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -3013,7 +3013,7 @@ func (s *RPCServer) WebsocketGetSubscriptions(_ context.Context, r *gctrpc.Webso
}

// WebsocketSetProxy sets client websocket connection proxy
func (s *RPCServer) WebsocketSetProxy(_ context.Context, r *gctrpc.WebsocketSetProxyRequest) (*gctrpc.GenericResponse, error) {
func (s *RPCServer) WebsocketSetProxy(ctx context.Context, r *gctrpc.WebsocketSetProxyRequest) (*gctrpc.GenericResponse, error) {
exch, err := s.GetExchangeByName(r.Exchange)
if err != nil {
return nil, err
Expand All @@ -3024,7 +3024,7 @@ func (s *RPCServer) WebsocketSetProxy(_ context.Context, r *gctrpc.WebsocketSetP
return nil, fmt.Errorf("websocket not supported for exchange %s", r.Exchange)
}

err = w.SetProxyAddress(r.Proxy)
err = w.SetProxyAddress(s.rpcContextToLongLivedSession(ctx), r.Proxy)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -5884,3 +5884,13 @@ func (s *RPCServer) GetCurrencyTradeURL(ctx context.Context, r *gctrpc.GetCurren
Url: url,
}, nil
}

// rpcContextToLongLivedSession converts a short-lived incoming context to a long-lived outgoing context, this is due
// to the incoming context being cancelled when the RPC call completes.
func (s *RPCServer) rpcContextToLongLivedSession(ctx context.Context) context.Context {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🐒 Optional nitpick:
Okay; I kinda get it, except LongLivedSession seems ... well weird.
We're basically just stripping out the lifecycle from the context, right?
So ... isn't this just a very fancy context.WithoutCancel ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah good point, any other naming suggestions and I will change it, and should this function live somewhere like common?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean... if it's context.WithoutCancel you don't need this function... right?
Just drop in replace it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well I feel like an idiot. 😆

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch: ca24dc1

md, ok := metadata.FromIncomingContext(ctx)
if !ok {
md = metadata.New(nil) // Fallback to empty metadata
}
return metadata.NewOutgoingContext(context.Background(), md)
}
48 changes: 48 additions & 0 deletions engine/rpcserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3869,3 +3869,51 @@ func TestGetCurrencyTradeURL(t *testing.T) {
assert.NoError(t, err)
assert.NotEmpty(t, resp.Url)
}

func TestRPCContextToLongLivedSession(t *testing.T) {
t.Parallel()
s := &RPCServer{}
t.Run("preserve metadata", func(t *testing.T) {
t.Parallel()
incomingMD := metadata.Pairs(
"authorization", "Bearer token",
"x-id", "123",
)
inCtx, cancel := context.WithCancel(metadata.NewIncomingContext(context.Background(), incomingMD))
defer cancel()

newCtx := s.rpcContextToLongLivedSession(inCtx)
require.NotNil(t, newCtx)

outMD, ok := metadata.FromOutgoingContext(newCtx)
require.True(t, ok)
require.Len(t, outMD, len(incomingMD))
assert.ElementsMatch(t, incomingMD.Get("authorization"), outMD.Get("authorization"))
assert.ElementsMatch(t, incomingMD.Get("x-id"), outMD.Get("x-id"))

cancel()
assert.Eventually(t, func() bool {
return inCtx.Err() != nil
}, time.Second, 10*time.Millisecond)
assert.Eventually(t, func() bool {
return newCtx.Err() == nil
}, 500*time.Millisecond, 20*time.Millisecond)
})

t.Run("no metadata", func(t *testing.T) {
t.Parallel()
inCtx, cancel := context.WithCancel(context.Background())
defer cancel()

newCtx := s.rpcContextToLongLivedSession(inCtx)
require.NotNil(t, newCtx)

outMD, ok := metadata.FromOutgoingContext(newCtx)
require.True(t, ok)
assert.Empty(t, outMD)

cancel()
assert.Eventually(t, func() bool { return inCtx.Err() != nil }, time.Second, 10*time.Millisecond)
assert.Eventually(t, func() bool { return newCtx.Err() == nil }, 500*time.Millisecond, 20*time.Millisecond)
})
}
10 changes: 5 additions & 5 deletions engine/websocketroutine_manager.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package engine

import (
"context"
"fmt"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -139,7 +140,7 @@ func (m *WebsocketRoutineManager) websocketRoutine() {
log.Errorf(log.WebsocketMgr, "%v", err)
}

if err := ws.Connect(); err != nil {
if err := ws.Connect(context.TODO()); err != nil {
log.Errorf(log.WebsocketMgr, "%v", err)
}
})
Expand Down Expand Up @@ -167,14 +168,13 @@ func (m *WebsocketRoutineManager) websocketDataReceiver(ws *websocket.Manager) e
select {
case <-m.shutdown:
return
case data := <-ws.ToRoutine:
if data == nil {
case payload := <-ws.DataHandler.C:
if payload.Data == nil {
log.Errorf(log.WebsocketMgr, "exchange %s nil data sent to websocket", ws.GetName())
}
m.mu.RLock()
for x := range m.dataHandlers {
err := m.dataHandlers[x](ws.GetName(), data)
if err != nil {
if err := m.dataHandlers[x](ws.GetName(), payload.Data); err != nil {
log.Errorln(log.WebsocketMgr, err)
}
}
Expand Down
10 changes: 6 additions & 4 deletions engine/websocketroutine_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,16 +258,18 @@ func TestRegisterWebsocketDataHandlerWithFunctionality(t *testing.T) {
}

mock := websocket.NewManager()
mock.ToRoutine = make(chan any)
m.state = readyState
err = m.websocketDataReceiver(mock)
if err != nil {
t.Fatal(err)
}

mock.ToRoutine <- nil
mock.ToRoutine <- 1336
mock.ToRoutine <- "intercepted"
err = mock.DataHandler.Send(t.Context(), nil)
require.NoError(t, err)
err = mock.DataHandler.Send(t.Context(), 1336)
require.NoError(t, err)
err = mock.DataHandler.Send(t.Context(), "intercepted")
require.NoError(t, err)

if r := <-dataChan; r != "intercepted" {
t.Fatal("unexpected value received")
Expand Down
Loading
Loading