Skip to content

Commit e8de0d9

Browse files
authored
Fix sending redundant RequestN frame (#101)
1 parent 86e27e3 commit e8de0d9

File tree

9 files changed

+142
-14
lines changed

9 files changed

+142
-14
lines changed

rsocket-core/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ kotlin {
3636
}
3737
val commonTest by getting {
3838
dependencies {
39+
implementation("app.cash.turbine:turbine:0.2.1")
3940
implementation("io.ktor:ktor-utils:1.4.0")
4041
implementation(project(":rsocket-transport-local"))
4142
}

rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelRequesterFlow.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@ internal class RequestChannelRequesterFlow(
3535
override fun create(context: CoroutineContext, capacity: Int): RequestChannelRequesterFlow =
3636
RequestChannelRequesterFlow(payloads, requester, state, context, capacity)
3737

38-
@OptIn(ExperimentalCoroutinesApi::class)
39-
override suspend fun collectTo(scope: ProducerScope<Payload>): Unit = with(state) {
38+
override suspend fun collectImpl(collectContext: CoroutineContext, collector: FlowCollector<Payload>): Unit = with(state) {
4039
val streamId = requester.createStream()
4140
val receiverDeferred = CompletableDeferred<ReceiveChannel<RequestFrame>?>()
4241
val request = launchCancelable(streamId) {
@@ -47,6 +46,7 @@ internal class RequestChannelRequesterFlow(
4746
}
4847
request.invokeOnCompletion {
4948
if (receiverDeferred.isCompleted) {
49+
@OptIn(ExperimentalCoroutinesApi::class)
5050
if (it != null && it !is CancellationException) receiverDeferred.getCompleted()?.cancelConsumed(it)
5151
} else {
5252
if (it == null) receiverDeferred.complete(null)
@@ -55,7 +55,7 @@ internal class RequestChannelRequesterFlow(
5555
}
5656
try {
5757
val receiver = receiverDeferred.await() ?: return
58-
collectStream(streamId, receiver, scope)
58+
collectStream(streamId, receiver, collectContext, collector)
5959
} catch (e: Throwable) {
6060
if (e is CancellationException) request.cancel(e)
6161
else request.cancel("Receiver failed", e)

rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelResponderFlow.kt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ package io.rsocket.kotlin.internal.flow
1919
import io.rsocket.kotlin.frame.*
2020
import io.rsocket.kotlin.internal.*
2121
import io.rsocket.kotlin.payload.*
22-
import kotlinx.coroutines.*
2322
import kotlinx.coroutines.channels.*
23+
import kotlinx.coroutines.flow.*
2424
import kotlin.coroutines.*
2525

2626
//TODO prevent consuming more then one time - add atomic ?
@@ -35,9 +35,8 @@ internal class RequestChannelResponderFlow(
3535
override fun create(context: CoroutineContext, capacity: Int): RequestChannelResponderFlow =
3636
RequestChannelResponderFlow(streamId, receiver, state, context, capacity)
3737

38-
@OptIn(ExperimentalCoroutinesApi::class)
39-
override suspend fun collectTo(scope: ProducerScope<Payload>): Unit = with(state) {
38+
override suspend fun collectImpl(collectContext: CoroutineContext, collector: FlowCollector<Payload>): Unit = with(state) {
4039
send(RequestNFrame(streamId, requestSize - 1)) //-1 because first payload received
41-
collectStream(streamId, receiver, scope)
40+
collectStream(streamId, receiver, collectContext, collector)
4241
}
4342
}

rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestStreamRequesterFlow.kt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ package io.rsocket.kotlin.internal.flow
1919
import io.rsocket.kotlin.frame.*
2020
import io.rsocket.kotlin.internal.*
2121
import io.rsocket.kotlin.payload.*
22-
import kotlinx.coroutines.*
2322
import kotlinx.coroutines.channels.*
23+
import kotlinx.coroutines.flow.*
2424
import kotlin.coroutines.*
2525

2626
internal class RequestStreamRequesterFlow(
@@ -33,11 +33,10 @@ internal class RequestStreamRequesterFlow(
3333
override fun create(context: CoroutineContext, capacity: Int): RequestStreamRequesterFlow =
3434
RequestStreamRequesterFlow(payload, requester, state, context, capacity)
3535

36-
@OptIn(ExperimentalCoroutinesApi::class)
37-
override suspend fun collectTo(scope: ProducerScope<Payload>): Unit = with(state) {
36+
override suspend fun collectImpl(collectContext: CoroutineContext, collector: FlowCollector<Payload>): Unit = with(state) {
3837
val streamId = requester.createStream()
3938
val receiver = createReceiverFor(streamId)
4039
send(RequestStreamFrame(streamId, requestSize, payload))
41-
collectStream(streamId, receiver, scope)
40+
collectStream(streamId, receiver, collectContext, collector)
4241
}
4342
}

rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/StreamFlow.kt

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import io.rsocket.kotlin.internal.*
2121
import io.rsocket.kotlin.payload.*
2222
import kotlinx.coroutines.*
2323
import kotlinx.coroutines.channels.*
24+
import kotlinx.coroutines.flow.*
2425
import kotlinx.coroutines.flow.internal.*
2526
import kotlin.coroutines.*
2627

@@ -40,18 +41,33 @@ internal abstract class StreamFlow(
4041
else -> capacity.also { check(it >= 1) }
4142
}
4243

44+
protected abstract suspend fun collectImpl(collectContext: CoroutineContext, collector: FlowCollector<Payload>)
45+
46+
final override suspend fun collect(collector: FlowCollector<Payload>) {
47+
val collectContext = context + coroutineContext
48+
withContext(coroutineContext + context) {
49+
collectImpl(collectContext, collector)
50+
}
51+
}
52+
53+
final override suspend fun collectTo(scope: ProducerScope<Payload>): Unit =
54+
collectImpl(scope.coroutineContext, SendingCollector(scope.channel))
55+
4356
protected suspend fun collectStream(
4457
streamId: Int,
4558
receiver: ReceiveChannel<RequestFrame>,
46-
scope: ProducerScope<Payload>,
59+
collectContext: CoroutineContext,
60+
collector: FlowCollector<Payload>,
4761
): Unit = with(state) {
48-
val collector = SendingCollector(scope.channel)
4962
consumeReceiverFor(streamId) {
5063
var consumed = 0
5164
//TODO fragmentation
5265
for (frame in receiver) {
5366
if (frame.complete) return //TODO check next flag
54-
collector.emit(frame.payload)
67+
//emit in collectContext to prevent `Flow invariant is violated`
68+
withContext(collectContext) {
69+
collector.emit(frame.payload)
70+
}
5571
if (++consumed == requestSize) {
5672
consumed = 0
5773
send(RequestNFrame(streamId, requestSize))

rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/Test.common.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,5 @@ import kotlinx.coroutines.*
2020
import kotlin.time.*
2121

2222
expect fun test(timeout: Duration? = 10.seconds, block: suspend CoroutineScope.() -> Unit)
23+
24+
expect val anotherDispatcher: CoroutineDispatcher

rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package io.rsocket.kotlin.internal
1818

19+
import app.cash.turbine.*
1920
import io.rsocket.kotlin.*
2021
import io.rsocket.kotlin.error.*
2122
import io.rsocket.kotlin.frame.*
@@ -24,6 +25,7 @@ import io.rsocket.kotlin.payload.*
2425
import kotlinx.coroutines.*
2526
import kotlinx.coroutines.channels.*
2627
import kotlinx.coroutines.flow.*
28+
import kotlin.coroutines.*
2729
import kotlin.test.*
2830
import kotlin.time.*
2931

@@ -57,6 +59,110 @@ class RSocketRequesterTest {
5759
assertEquals(5, frame.initialRequest)
5860
}
5961

62+
@Test
63+
fun testStreamBuffer() = test {
64+
val flow =
65+
requester.requestStream(Payload.Empty)
66+
.buffer(2)
67+
.take(2)
68+
69+
assertEquals(0, connection.sentFrames.size)
70+
71+
flow.launchIn(CoroutineScope(connection.job))
72+
73+
connection.sentAsFlow().test {
74+
expectItem().let { frame ->
75+
assertTrue(frame is RequestFrame)
76+
assertEquals(FrameType.RequestStream, frame.type)
77+
assertEquals(2, frame.initialRequest)
78+
}
79+
delay(200)
80+
expectNoEvents()
81+
connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty))
82+
delay(200)
83+
expectNoEvents()
84+
connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty))
85+
delay(200)
86+
expectItem().let { frame ->
87+
assertTrue(frame is CancelFrame)
88+
}
89+
delay(200)
90+
expectNoEvents()
91+
}
92+
}
93+
94+
class SomeContext(val context: Int) : AbstractCoroutineContextElement(SomeContext) {
95+
companion object Key : CoroutineContext.Key<SomeContext>
96+
}
97+
98+
@Test
99+
fun testStreamBufferWithAdditionalContext() = test {
100+
val flow =
101+
requester.requestStream(Payload.Empty)
102+
.buffer(2)
103+
.flowOn(SomeContext(2))
104+
.take(2)
105+
106+
assertEquals(0, connection.sentFrames.size)
107+
108+
flow.launchIn(CoroutineScope(connection.job))
109+
110+
connection.sentAsFlow().test {
111+
expectItem().let { frame ->
112+
assertTrue(frame is RequestFrame)
113+
assertEquals(FrameType.RequestStream, frame.type)
114+
assertEquals(2, frame.initialRequest)
115+
}
116+
delay(200)
117+
expectNoEvents()
118+
connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty))
119+
delay(200)
120+
expectNoEvents()
121+
connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty))
122+
delay(200)
123+
expectItem().let { frame ->
124+
assertTrue(frame is CancelFrame)
125+
}
126+
delay(200)
127+
expectNoEvents()
128+
}
129+
}
130+
131+
@Test
132+
fun testStreamBufferWithAnotherDispatcher() = test {
133+
val flow =
134+
requester.requestStream(Payload.Empty)
135+
.buffer(2)
136+
.flowOn(anotherDispatcher) //change dispatcher before take
137+
.take(2)
138+
.transform { emit(it) } //force using SafeCollector to check that `Flow invariant is violated` not happens
139+
140+
assertEquals(0, connection.sentFrames.size)
141+
142+
flow.launchIn(CoroutineScope(connection.job))
143+
144+
connection.sentAsFlow().test {
145+
expectItem().let { frame ->
146+
assertTrue(frame is RequestFrame)
147+
assertEquals(FrameType.RequestStream, frame.type)
148+
assertEquals(2, frame.initialRequest)
149+
}
150+
delay(200)
151+
expectNoEvents()
152+
connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty))
153+
delay(200)
154+
expectNoEvents() //will fail here if `Flow invariant is violated`
155+
connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty))
156+
delay(200)
157+
expectItem().let { frame ->
158+
println(frame)
159+
assertTrue(frame is CancelFrame)
160+
}
161+
delay(200)
162+
expectNoEvents()
163+
}
164+
}
165+
60166
@Test
61167
fun testHandleSetupException() = test {
62168
val errorMessage = "error"

rsocket-core/src/jsTest/kotlin/io/rsocket/kotlin/Test.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,6 @@ actual fun test(timeout: Duration?, block: suspend CoroutineScope.() -> Unit): d
2525
else -> withTimeout(timeout) { block() }
2626
}
2727
}
28+
29+
//JS is single threaded, so it have only one dispatcher backed by one threed
30+
actual val anotherDispatcher: CoroutineDispatcher get() = Dispatchers.Default

rsocket-core/src/jvmTest/kotlin/io/rsocket/kotlin/Test.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,5 @@ actual fun test(timeout: Duration?, block: suspend CoroutineScope.() -> Unit): U
3131
else -> withTimeout(timeout) { block() }
3232
}
3333
}
34+
35+
actual val anotherDispatcher: CoroutineDispatcher get() = Dispatchers.IO

0 commit comments

Comments
 (0)