Skip to content

Commit a11ef3f

Browse files
OlegDokukarobertroeser
authored andcommitted
fixes uncontrolled data sending in case of direct propagation of request from requester (rsocket#595)
* fixes uncontrolled data sending in case of direct propagation of request from requester * fixes timeout typo * replaces forEach with explicit loop * optimize access to limitableRequestPublisher Signed-off-by: Oleh Dokuka <[email protected]>
1 parent e48658e commit a11ef3f

7 files changed

Lines changed: 224 additions & 25 deletions

File tree

rsocket-core/src/main/java/io/rsocket/RSocketClient.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,13 @@ class RSocketClient implements RSocket {
9898
connection.onClose().doFinally(signalType -> terminate()).subscribe(null, errorConsumer);
9999

100100
connection
101-
.send(sendProcessor)
101+
.send(
102+
sendProcessor.doOnRequest(
103+
r -> {
104+
for (LimitableRequestPublisher lrp : senders.values()) {
105+
lrp.increaseInternalLimit(r);
106+
}
107+
}))
102108
.doFinally(this::handleSendProcessorCancel)
103109
.subscribe(null, this::handleSendProcessorError);
104110

@@ -335,7 +341,8 @@ private Flux<Payload> handleChannel(Flux<Payload> request) {
335341
.transform(
336342
f -> {
337343
LimitableRequestPublisher<Payload> wrapped =
338-
LimitableRequestPublisher.wrap(f);
344+
LimitableRequestPublisher.wrap(
345+
f, sendProcessor.available());
339346
// Need to set this to one for first the frame
340347
wrapped.increaseRequestLimit(1);
341348
senders.put(streamId, wrapped);

rsocket-core/src/main/java/io/rsocket/RSocketServer.java

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class RSocketServer implements ResponderRSocket {
4949
private final PayloadDecoder payloadDecoder;
5050
private final Consumer<Throwable> errorConsumer;
5151

52+
private final Map<Integer, LimitableRequestPublisher> sendingLimitableSubscriptions;
5253
private final Map<Integer, Subscription> sendingSubscriptions;
5354
private final Map<Integer, Processor<Payload, Payload>> channelProcessors;
5455

@@ -84,6 +85,7 @@ class RSocketServer implements ResponderRSocket {
8485

8586
this.payloadDecoder = payloadDecoder;
8687
this.errorConsumer = errorConsumer;
88+
this.sendingLimitableSubscriptions = Collections.synchronizedMap(new IntObjectHashMap<>());
8789
this.sendingSubscriptions = Collections.synchronizedMap(new IntObjectHashMap<>());
8890
this.channelProcessors = Collections.synchronizedMap(new IntObjectHashMap<>());
8991

@@ -92,7 +94,13 @@ class RSocketServer implements ResponderRSocket {
9294
this.sendProcessor = new UnboundedProcessor<>();
9395

9496
connection
95-
.send(sendProcessor)
97+
.send(
98+
sendProcessor.doOnRequest(
99+
r -> {
100+
for (LimitableRequestPublisher lrp : sendingLimitableSubscriptions.values()) {
101+
lrp.increaseInternalLimit(r);
102+
}
103+
}))
96104
.doFinally(this::handleSendProcessorCancel)
97105
.subscribe(null, this::handleSendProcessorError);
98106

@@ -138,6 +146,17 @@ private void handleSendProcessorError(Throwable t) {
138146
}
139147
});
140148

149+
sendingLimitableSubscriptions
150+
.values()
151+
.forEach(
152+
subscription -> {
153+
try {
154+
subscription.cancel();
155+
} catch (Throwable e) {
156+
errorConsumer.accept(e);
157+
}
158+
});
159+
141160
channelProcessors
142161
.values()
143162
.forEach(
@@ -166,6 +185,17 @@ private void handleSendProcessorCancel(SignalType t) {
166185
}
167186
});
168187

188+
sendingLimitableSubscriptions
189+
.values()
190+
.forEach(
191+
subscription -> {
192+
try {
193+
subscription.cancel();
194+
} catch (Throwable e) {
195+
errorConsumer.accept(e);
196+
}
197+
});
198+
169199
channelProcessors
170200
.values()
171201
.forEach(
@@ -261,6 +291,9 @@ private void cleanup() {
261291
private synchronized void cleanUpSendingSubscriptions() {
262292
sendingSubscriptions.values().forEach(Subscription::cancel);
263293
sendingSubscriptions.clear();
294+
295+
sendingLimitableSubscriptions.values().forEach(Subscription::cancel);
296+
sendingLimitableSubscriptions.clear();
264297
}
265298

266299
private synchronized void cleanUpChannelProcessors() {
@@ -391,12 +424,12 @@ private void handleStream(int streamId, Flux<Payload> response, int initialReque
391424
.transform(
392425
frameFlux -> {
393426
LimitableRequestPublisher<Payload> payloads =
394-
LimitableRequestPublisher.wrap(frameFlux);
395-
sendingSubscriptions.put(streamId, payloads);
427+
LimitableRequestPublisher.wrap(frameFlux, sendProcessor.available());
428+
sendingLimitableSubscriptions.put(streamId, payloads);
396429
payloads.increaseRequestLimit(initialRequestN);
397430
return payloads;
398431
})
399-
.doFinally(signalType -> sendingSubscriptions.remove(streamId))
432+
.doFinally(signalType -> sendingLimitableSubscriptions.remove(streamId))
400433
.subscribe(
401434
payload -> {
402435
ByteBuf byteBuf = null;
@@ -449,6 +482,11 @@ private void handleKeepAliveFrame(ByteBuf frame) {
449482

450483
private void handleCancelFrame(int streamId) {
451484
Subscription subscription = sendingSubscriptions.remove(streamId);
485+
486+
if (subscription == null) {
487+
subscription = sendingLimitableSubscriptions.get(streamId);
488+
}
489+
452490
if (subscription != null) {
453491
subscription.cancel();
454492
}
@@ -460,7 +498,12 @@ private void handleError(int streamId, Throwable t) {
460498
}
461499

462500
private void handleRequestN(int streamId, ByteBuf frame) {
463-
final Subscription subscription = sendingSubscriptions.get(streamId);
501+
Subscription subscription = sendingSubscriptions.get(streamId);
502+
503+
if (subscription == null) {
504+
subscription = sendingLimitableSubscriptions.get(streamId);
505+
}
506+
464507
if (subscription != null) {
465508
int n = RequestNFrameFlyweight.requestN(frame);
466509
subscription.request(n >= Integer.MAX_VALUE ? Long.MAX_VALUE : n);

rsocket-core/src/main/java/io/rsocket/internal/LimitableRequestPublisher.java

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ public class LimitableRequestPublisher<T> extends Flux<T> implements Subscriptio
3131

3232
private final AtomicBoolean canceled;
3333

34+
private final long prefetch;
35+
3436
private long internalRequested;
3537

3638
private long externalRequested;
@@ -39,13 +41,14 @@ public class LimitableRequestPublisher<T> extends Flux<T> implements Subscriptio
3941

4042
private volatile @Nullable Subscription internalSubscription;
4143

42-
private LimitableRequestPublisher(Publisher<T> source) {
44+
private LimitableRequestPublisher(Publisher<T> source, long prefetch) {
4345
this.source = source;
46+
this.prefetch = prefetch;
4447
this.canceled = new AtomicBoolean();
4548
}
4649

47-
public static <T> LimitableRequestPublisher<T> wrap(Publisher<T> source) {
48-
return new LimitableRequestPublisher<>(source);
50+
public static <T> LimitableRequestPublisher<T> wrap(Publisher<T> source, long prefetch) {
51+
return new LimitableRequestPublisher<>(source, prefetch);
4952
}
5053

5154
@Override
@@ -60,6 +63,7 @@ public void subscribe(CoreSubscriber<? super T> destination) {
6063

6164
destination.onSubscribe(new InnerSubscription());
6265
source.subscribe(new InnerSubscriber(destination));
66+
increaseInternalLimit(prefetch);
6367
}
6468

6569
public void increaseRequestLimit(long n) {
@@ -70,6 +74,14 @@ public void increaseRequestLimit(long n) {
7074
requestN();
7175
}
7276

77+
public void increaseInternalLimit(long n) {
78+
synchronized (this) {
79+
internalRequested = Operators.addCap(n, internalRequested);
80+
}
81+
82+
requestN();
83+
}
84+
7385
@Override
7486
public void request(long n) {
7587
increaseRequestLimit(n);
@@ -82,9 +94,17 @@ private void requestN() {
8294
return;
8395
}
8496

85-
r = Math.min(internalRequested, externalRequested);
86-
externalRequested -= r;
87-
internalRequested -= r;
97+
if (externalRequested != Long.MAX_VALUE || internalRequested != Long.MAX_VALUE) {
98+
r = Math.min(internalRequested, externalRequested);
99+
if (externalRequested != Long.MAX_VALUE) {
100+
externalRequested -= r;
101+
}
102+
if (internalRequested != Long.MAX_VALUE) {
103+
internalRequested -= r;
104+
}
105+
} else {
106+
r = Long.MAX_VALUE;
107+
}
88108
}
89109

90110
if (r > 0) {
@@ -144,13 +164,7 @@ public void onComplete() {
144164

145165
private class InnerSubscription implements Subscription {
146166
@Override
147-
public void request(long n) {
148-
synchronized (LimitableRequestPublisher.this) {
149-
internalRequested = Operators.addCap(n, internalRequested);
150-
}
151-
152-
requestN();
153-
}
167+
public void request(long n) {}
154168

155169
@Override
156170
public void cancel() {

rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,10 @@ public void onSubscribe(Subscription s) {
221221
}
222222
}
223223

224+
public long available() {
225+
return requested;
226+
}
227+
224228
@Override
225229
public int getPrefetch() {
226230
return Integer.MAX_VALUE;

rsocket-core/src/test/java/io/rsocket/RSocketClientTest.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,15 @@
2828
import io.rsocket.exceptions.ApplicationErrorException;
2929
import io.rsocket.exceptions.RejectedSetupException;
3030
import io.rsocket.frame.*;
31+
import io.rsocket.test.util.TestDuplexConnection;
3132
import io.rsocket.test.util.TestSubscriber;
3233
import io.rsocket.util.DefaultPayload;
3334
import io.rsocket.util.EmptyPayload;
3435
import java.time.Duration;
3536
import java.util.ArrayList;
3637
import java.util.List;
38+
import java.util.Queue;
39+
import java.util.concurrent.ConcurrentLinkedQueue;
3740
import java.util.stream.Collectors;
3841
import org.assertj.core.api.Assertions;
3942
import org.junit.Rule;
@@ -216,6 +219,32 @@ public void testChannelRequestServerSideCancellation() {
216219
Assertions.assertThat(request.isDisposed()).isTrue();
217220
}
218221

222+
@Test(timeout = 2_000)
223+
@SuppressWarnings("unchecked")
224+
public void
225+
testClientSideRequestChannelShouldNotHangInfinitelySendingElementsAndShouldProduceDataValuingConnectionBackpressure() {
226+
final Queue<Long> requests = new ConcurrentLinkedQueue<>();
227+
rule.connection.dispose();
228+
rule.connection = new TestDuplexConnection();
229+
rule.connection.setInitialSendRequestN(256);
230+
rule.init();
231+
232+
rule.socket
233+
.requestChannel(
234+
Flux.<Payload>generate(s -> s.next(EmptyPayload.INSTANCE)).doOnRequest(requests::add))
235+
.subscribe();
236+
237+
int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL);
238+
239+
assertThat("Unexpected error.", rule.errors, is(empty()));
240+
241+
rule.connection.addToReceivedBuffer(
242+
RequestNFrameFlyweight.encode(ByteBufAllocator.DEFAULT, streamId, 2));
243+
rule.connection.addToReceivedBuffer(
244+
RequestNFrameFlyweight.encode(ByteBufAllocator.DEFAULT, streamId, Integer.MAX_VALUE));
245+
Assertions.assertThat(requests).containsOnly(1L, 2L, 253L);
246+
}
247+
219248
public int sendRequestResponse(Publisher<Payload> response) {
220249
Subscriber<Payload> sub = TestSubscriber.create();
221250
response.subscribe(sub);

0 commit comments

Comments
 (0)