2626import io .rsocket .util .PayloadImpl ;
2727import java .nio .channels .ClosedChannelException ;
2828import java .time .Duration ;
29+ import java .util .Collection ;
30+ import java .util .concurrent .atomic .AtomicBoolean ;
2931import java .util .concurrent .atomic .AtomicInteger ;
3032import java .util .function .Consumer ;
3133import java .util .function .Function ;
3436import org .reactivestreams .Publisher ;
3537import org .reactivestreams .Subscriber ;
3638import reactor .core .Disposable ;
37- import reactor .core .publisher .Flux ;
38- import reactor .core .publisher .Mono ;
39- import reactor .core .publisher .MonoProcessor ;
40- import reactor .core .publisher .UnicastProcessor ;
39+ import reactor .core .publisher .*;
4140
4241/** Client Side of a RSocket socket. Sends {@link Frame}s to a {@link RSocketServer} */
4342class RSocketClient implements RSocket {
@@ -53,8 +52,9 @@ class RSocketClient implements RSocket {
5352 private final IntObjectHashMap <Subscriber <Payload >> receivers ;
5453 private final AtomicInteger missedAckCounter ;
5554
56- private @ Nullable Disposable keepAliveSendSub ;
55+ private final EmitterProcessor < Frame > sendProcessor ;
5756
57+ private @ Nullable Disposable keepAliveSendSub ;
5858 private volatile long timeLastTickSentMs ;
5959
6060 RSocketClient (
@@ -78,6 +78,7 @@ class RSocketClient implements RSocket {
7878 this .senders = new IntObjectHashMap <>(256 , 0.9f );
7979 this .receivers = new IntObjectHashMap <>(256 , 0.9f );
8080 this .missedAckCounter = new AtomicInteger ();
81+ this .sendProcessor = EmitterProcessor .create ();
8182
8283 if (!Duration .ZERO .equals (tickPeriod )) {
8384 long ackTimeoutMs = ackTimeout .toMillis ();
@@ -103,33 +104,87 @@ class RSocketClient implements RSocket {
103104 .doOnNext (this ::handleIncomingFrames )
104105 .doOnError (errorConsumer )
105106 .subscribe ();
107+
108+ connection
109+ .send (sendProcessor )
110+ .doOnError (this ::handleSendProcessorError )
111+ .doFinally (this ::handleSendProcessorCancel )
112+ .subscribe ();
106113 }
107114
108- private Mono <Void > sendKeepAlive (long ackTimeoutMs , int missedAcks ) {
109- long now = System .currentTimeMillis ();
110- if (now - timeLastTickSentMs > ackTimeoutMs ) {
111- int count = missedAckCounter .incrementAndGet ();
112- if (count >= missedAcks ) {
113- String message =
114- String .format (
115- "Missed %d keep-alive acks with a threshold of %d and a ack timeout of %d ms" ,
116- count , missedAcks , ackTimeoutMs );
117- return Mono .error (new ConnectionException (message ));
115+ private void handleSendProcessorError (Throwable t ) {
116+ Collection <Subscriber <Payload >> values ;
117+ Collection <LimitableRequestPublisher > values1 ;
118+ synchronized (RSocketClient .this ) {
119+ values = receivers .values ();
120+ values1 = senders .values ();
121+ }
122+
123+ for (Subscriber subscriber : values ) {
124+ try {
125+ subscriber .onError (t );
126+ } catch (Throwable e ) {
127+ errorConsumer .accept (e );
128+ }
129+ }
130+
131+ for (LimitableRequestPublisher p : values1 ) {
132+ p .cancel ();
133+ }
134+ }
135+
136+ private void handleSendProcessorCancel (SignalType t ) {
137+ if (SignalType .ON_ERROR == t ) {
138+ return ;
139+ }
140+ Collection <Subscriber <Payload >> values ;
141+ Collection <LimitableRequestPublisher > values1 ;
142+ synchronized (RSocketClient .this ) {
143+ values = receivers .values ();
144+ values1 = senders .values ();
145+ }
146+
147+ for (Subscriber subscriber : values ) {
148+ try {
149+ subscriber .onError (new Throwable ("closed connection" ));
150+ } catch (Throwable e ) {
151+ errorConsumer .accept (e );
118152 }
119153 }
120154
121- return connection .sendOne (Frame .Keepalive .from (Unpooled .EMPTY_BUFFER , true ));
155+ for (LimitableRequestPublisher p : values1 ) {
156+ p .cancel ();
157+ }
158+ }
159+
160+ private Mono <Void > sendKeepAlive (long ackTimeoutMs , int missedAcks ) {
161+ return Mono .fromRunnable (
162+ () -> {
163+ long now = System .currentTimeMillis ();
164+ if (now - timeLastTickSentMs > ackTimeoutMs ) {
165+ int count = missedAckCounter .incrementAndGet ();
166+ if (count >= missedAcks ) {
167+ String message =
168+ String .format (
169+ "Missed %d keep-alive acks with a threshold of %d and a ack timeout of %d ms" ,
170+ count , missedAcks , ackTimeoutMs );
171+ throw new ConnectionException (message );
172+ }
173+ }
174+
175+ sendProcessor .onNext (Frame .Keepalive .from (Unpooled .EMPTY_BUFFER , true ));
176+ });
122177 }
123178
124179 @ Override
125180 public Mono <Void > fireAndForget (Payload payload ) {
126181 Mono <Void > defer =
127- Mono .defer (
182+ Mono .fromRunnable (
128183 () -> {
129184 final int streamId = streamIdSupplier .nextStreamId ();
130185 final Frame requestFrame =
131186 Frame .Request .from (streamId , FrameType .FIRE_AND_FORGET , payload , 1 );
132- return connection . sendOne (requestFrame );
187+ sendProcessor . onNext (requestFrame );
133188 });
134189
135190 return started .then (defer );
@@ -142,7 +197,7 @@ public Mono<Payload> requestResponse(Payload payload) {
142197
143198 @ Override
144199 public Flux <Payload > requestStream (Payload payload ) {
145- return handleStreamResponse ( Flux . just ( payload ), FrameType . REQUEST_STREAM );
200+ return handleRequestStream ( payload );
146201 }
147202
148203 @ Override
@@ -153,7 +208,8 @@ public Flux<Payload> requestChannel(Publisher<Payload> payloads) {
153208 @ Override
154209 public Mono <Void > metadataPush (Payload payload ) {
155210 final Frame requestFrame = Frame .Request .from (0 , FrameType .METADATA_PUSH , payload , 1 );
156- return connection .sendOne (requestFrame );
211+ sendProcessor .onNext (requestFrame );
212+ return Mono .empty ();
157213 }
158214
159215 @ Override
@@ -171,59 +227,77 @@ public Mono<Void> onClose() {
171227 return connection .onClose ();
172228 }
173229
174- private Mono <Payload > handleRequestResponse (final Payload payload ) {
175- return started .then (
176- Mono .defer (
230+ public Flux <Payload > handleRequestStream (final Payload payload ) {
231+ return started .thenMany (
232+ Flux .defer (
177233 () -> {
178234 int streamId = streamIdSupplier .nextStreamId ();
179- final Frame requestFrame =
180- Frame .Request .from (streamId , FrameType .REQUEST_RESPONSE , payload , 1 );
181235
182- MonoProcessor <Payload > receiver = MonoProcessor .create ();
236+ UnicastProcessor <Payload > receiver = UnicastProcessor .create ();
183237
184238 synchronized (this ) {
185239 receivers .put (streamId , receiver );
186240 }
187241
188- MonoProcessor <Void > subscribedRequest =
189- connection
190- .sendOne (requestFrame )
191- .doOnError (
192- t -> {
193- errorConsumer .accept (t );
194- receiver .cancel ();
195- })
196- .toProcessor ();
197- subscribedRequest .subscribe ();
242+ AtomicBoolean first = new AtomicBoolean (false );
198243
199244 return receiver
245+ .doOnRequest (
246+ l -> {
247+ if (first .compareAndSet (false , true ) && !receiver .isTerminated ()) {
248+ final Frame requestFrame =
249+ Frame .Request .from (streamId , FrameType .REQUEST_STREAM , payload , l );
250+
251+ sendProcessor .onNext (requestFrame );
252+ } else if (contains (streamId )
253+ && connection .availability () > 0.0
254+ && !receiver .isTerminated ()) {
255+ sendProcessor .onNext (Frame .RequestN .from (streamId , l ));
256+ }
257+ })
200258 .doOnError (
201259 t -> {
202260 if (contains (streamId )
203261 && connection .availability () > 0.0
204262 && !receiver .isTerminated ()) {
205- connection
206- .sendOne (Frame .Error .from (streamId , t ))
207- .doOnError (errorConsumer )
208- .subscribe ();
263+ sendProcessor .onNext (Frame .Error .from (streamId , t ));
209264 }
210265 })
211266 .doOnCancel (
212267 () -> {
213268 if (contains (streamId )
214269 && connection .availability () > 0.0
215270 && !receiver .isTerminated ()) {
216- connection
217- .sendOne (Frame .Cancel .from (streamId ))
218- .doOnError (errorConsumer )
219- .subscribe ();
271+ sendProcessor .onNext (Frame .Cancel .from (streamId ));
220272 }
221- subscribedRequest .cancel ();
222273 })
223274 .doFinally (s -> removeReceiver (streamId ));
224275 }));
225276 }
226277
278+ private Mono <Payload > handleRequestResponse (final Payload payload ) {
279+ return started .then (
280+ Mono .defer (
281+ () -> {
282+ int streamId = streamIdSupplier .nextStreamId ();
283+ final Frame requestFrame =
284+ Frame .Request .from (streamId , FrameType .REQUEST_RESPONSE , payload , 1 );
285+
286+ MonoProcessor <Payload > receiver = MonoProcessor .create ();
287+
288+ synchronized (this ) {
289+ receivers .put (streamId , receiver );
290+ }
291+
292+ sendProcessor .onNext (requestFrame );
293+
294+ return receiver
295+ .doOnError (t -> sendProcessor .onNext (Frame .Error .from (streamId , t )))
296+ .doOnCancel (() -> sendProcessor .onNext (Frame .Cancel .from (streamId )))
297+ .doFinally (s -> removeReceiver (streamId ));
298+ }));
299+ }
300+
227301 private Flux <Payload > handleStreamResponse (Flux <Payload > request , FrameType requestType ) {
228302 return started .thenMany (
229303 Flux .defer (
@@ -241,7 +315,7 @@ boolean isValidToSendFrame() {
241315
242316 void sendOneFrame (Frame frame ) {
243317 if (isValidToSendFrame ()) {
244- connection . sendOne (frame ). doOnError ( errorConsumer ). subscribe ( );
318+ sendProcessor . onNext (frame );
245319 }
246320 }
247321
@@ -306,16 +380,14 @@ public Frame apply(Payload payload) {
306380 }
307381 });
308382
309- subscribedRequests =
310- connection
311- .send (requestFrames )
312- .doOnError (
313- t -> {
314- errorConsumer .accept (t );
315- receiver .cancel ();
316- })
317- .toProcessor ();
318- subscribedRequests .subscribe ();
383+ requestFrames
384+ .doOnNext (sendProcessor ::onNext )
385+ .doOnError (
386+ t -> {
387+ errorConsumer .accept (t );
388+ receiver .cancel ();
389+ })
390+ .subscribe ();
319391 } else {
320392 sendOneFrame (Frame .RequestN .from (streamId , l ));
321393 }
0 commit comments