Skip to content

Commit

Permalink
Plumb CancellationToken through Socket.Receive/SendAsync (dotnet#36516)
Browse files Browse the repository at this point in the history
In .NET Core 2.1 we added overloads of Send/ReceiveAsync, and we proactively added CancellationToken arguments to them, but those tokens were only checked at the beginning of the call; if a cancellation request came in after that check, the operation would not be interrupted.

This PR plumbs the token through so that a cancellation request at any point in the operation will cancel that operation.  On Windows we register to use CancelIoEx to request cancellation of the specific overlapped operation on the specific socket.  On Unix we use the existing cancellation infrastructure already in place to support the existing custom queueing scheme.

Some caveats:
- On Windows, canceling a TCP receive will end up canceling all TCP receives pending on that socket, even when we request cancellation of a specific overlapped operation; this is just how cancellation works at the OS level, and there's little we can do about it.  It also shouldn't matter much, as multiple pending receives on the same socket are rare.
- If multiple concurrent receives or multiple concurrent sends are issued on the same socket, only the first will actually be cancelable.  This is because this implementation only plumbs the token through the SocketAsyncEventArgs-based code paths, not the APM based code paths, and currently when using the Task-based APIs, we use the SocketAsyncEventArgs under the covers for only one receive and one send at a time; other receives made while that SAEA receive is in progress or other sends made while that SAEA send is in progress will fall back to using the APM code paths.  This could be addressed in the future in various ways, including a) just using the SAEA code paths for all operations and deleting the APM fallback, or b) plumbing cancellation through APM as well.  However, for now, this approach addresses the primary use case and should be sufficient.
- This only affects code paths to which the CancellationToken passed to Send/ReceiveAsync could reach.  If in the future we add additional overloads taking CancellationToken, we will likely need to plumb it to more places.
  • Loading branch information
stephentoub authored Apr 7, 2019
1 parent e5fddbc commit 2190a0f
Show file tree
Hide file tree
Showing 11 changed files with 233 additions and 49 deletions.
12 changes: 6 additions & 6 deletions src/System.Net.Sockets/System.Net.Sockets.sln
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio 15
VisualStudioVersion = 15.0.27213.1
# Visual Studio Version 16
VisualStudioVersion = 16.0.28725.219
MinimumVisualStudioVersion = 10.0.40219.1
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "System.Net.Sockets.Tests", "tests\FunctionalTests\System.Net.Sockets.Tests.csproj", "{8CBA022C-635F-4C8D-9D29-CD8AAC68C8E6}"
ProjectSection(ProjectDependencies) = postProject
Expand Down Expand Up @@ -31,10 +31,10 @@ Global
Release|Any CPU = Release|Any CPU
EndGlobalSection
GlobalSection(ProjectConfigurationPlatforms) = postSolution
{8CBA022C-635F-4C8D-9D29-CD8AAC68C8E6}.Debug|Any CPU.ActiveCfg = netcoreapp-Debug|Any CPU
{8CBA022C-635F-4C8D-9D29-CD8AAC68C8E6}.Debug|Any CPU.Build.0 = netcoreapp-Debug|Any CPU
{8CBA022C-635F-4C8D-9D29-CD8AAC68C8E6}.Release|Any CPU.ActiveCfg = netcoreapp-Release|Any CPU
{8CBA022C-635F-4C8D-9D29-CD8AAC68C8E6}.Release|Any CPU.Build.0 = netcoreapp-Release|Any CPU
{8CBA022C-635F-4C8D-9D29-CD8AAC68C8E6}.Debug|Any CPU.ActiveCfg = netcoreapp-Unix-Debug|Any CPU
{8CBA022C-635F-4C8D-9D29-CD8AAC68C8E6}.Debug|Any CPU.Build.0 = netcoreapp-Unix-Debug|Any CPU
{8CBA022C-635F-4C8D-9D29-CD8AAC68C8E6}.Release|Any CPU.ActiveCfg = netcoreapp-Unix-Release|Any CPU
{8CBA022C-635F-4C8D-9D29-CD8AAC68C8E6}.Release|Any CPU.Build.0 = netcoreapp-Unix-Release|Any CPU
{BB5C85AD-C51A-4903-80E9-6F6E1AC1AD34}.Debug|Any CPU.ActiveCfg = netstandard-Windows_NT-Debug|Any CPU
{BB5C85AD-C51A-4903-80E9-6F6E1AC1AD34}.Debug|Any CPU.Build.0 = netstandard-Windows_NT-Debug|Any CPU
{BB5C85AD-C51A-4903-80E9-6F6E1AC1AD34}.Release|Any CPU.ActiveCfg = netstandard-Windows_NT-Release|Any CPU
Expand Down
3 changes: 3 additions & 0 deletions src/System.Net.Sockets/src/System.Net.Sockets.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,9 @@
<Compile Include="$(CommonPath)\Interop\Windows\WinSock\WSABuffer.cs">
<Link>Interop\Windows\WinSock\WSABuffer.cs</Link>
</Compile>
<Compile Include="$(CommonPath)\CoreLib\Interop\Windows\Kernel32\Interop.CancelIoEx.cs">
<Link>Common\Interop\Windows\Interop.CancelIoEx.cs</Link>
</Compile>
<Compile Include="$(CommonPath)\Interop\Windows\Kernel32\Interop.SetFileCompletionNotificationModes.cs">
<Link>Interop\Windows\Kernel32\Interop.SetFileCompletionNotificationModes.cs</Link>
</Compile>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Buffers;
using System.Diagnostics;
using System.IO;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

Expand Down Expand Up @@ -846,7 +843,7 @@ public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationTo
return _streamSocket.SendAsyncForNetworkStream(
buffer,
SocketFlags.None,
cancellationToken: cancellationToken);
cancellationToken);
}
catch (Exception exception) when (!(exception is OutOfMemoryException))
{
Expand Down
33 changes: 24 additions & 9 deletions src/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ internal ValueTask<int> ReceiveAsync(Memory<byte> buffer, SocketFlags socketFlag
saea.SetBuffer(buffer);
saea.SocketFlags = socketFlags;
saea.WrapExceptionsInIOExceptions = fromNetworkStream;
return saea.ReceiveAsync(this);
return saea.ReceiveAsync(this, cancellationToken);
}
else
{
Expand Down Expand Up @@ -350,7 +350,7 @@ internal ValueTask<int> SendAsync(ReadOnlyMemory<byte> buffer, SocketFlags socke
saea.SetBuffer(MemoryMarshal.AsMemory(buffer));
saea.SocketFlags = socketFlags;
saea.WrapExceptionsInIOExceptions = false;
return saea.SendAsync(this);
return saea.SendAsync(this, cancellationToken);
}
else
{
Expand Down Expand Up @@ -828,6 +828,8 @@ internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, IVal
/// it's already being reused by someone else.
/// </remarks>
private short _token;
/// <summary>The cancellation token used for the current operation.</summary>
private CancellationToken _cancellationToken;

/// <summary>Initializes the event args.</summary>
public AwaitableSocketAsyncEventArgs() :
Expand All @@ -842,6 +844,7 @@ public bool Reserve() =>

private void Release()
{
_cancellationToken = default;
_token++;
Volatile.Write(ref _continuation, s_availableSentinel);
}
Expand Down Expand Up @@ -882,12 +885,13 @@ protected override void OnCompleted(SocketAsyncEventArgs _)

/// <summary>Initiates a receive operation on the associated socket.</summary>
/// <returns>This instance.</returns>
public ValueTask<int> ReceiveAsync(Socket socket)
public ValueTask<int> ReceiveAsync(Socket socket, CancellationToken cancellationToken)
{
Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use");

if (socket.ReceiveAsync(this))
if (socket.ReceiveAsync(this, cancellationToken))
{
_cancellationToken = cancellationToken;
return new ValueTask<int>(this, _token);
}

Expand All @@ -903,12 +907,13 @@ public ValueTask<int> ReceiveAsync(Socket socket)

/// <summary>Initiates a send operation on the associated socket.</summary>
/// <returns>This instance.</returns>
public ValueTask<int> SendAsync(Socket socket)
public ValueTask<int> SendAsync(Socket socket, CancellationToken cancellationToken)
{
Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use");

if (socket.SendAsync(this))
if (socket.SendAsync(this, cancellationToken))
{
_cancellationToken = cancellationToken;
return new ValueTask<int>(this, _token);
}

Expand Down Expand Up @@ -1059,12 +1064,13 @@ public int GetResult(short token)

SocketError error = SocketError;
int bytes = BytesTransferred;
CancellationToken cancellationToken = _cancellationToken;

Release();

if (error != SocketError.Success)
{
ThrowException(error);
ThrowException(error, cancellationToken);
}
return bytes;
}
Expand All @@ -1077,20 +1083,29 @@ void IValueTaskSource.GetResult(short token)
}

SocketError error = SocketError;
CancellationToken cancellationToken = _cancellationToken;

Release();

if (error != SocketError.Success)
{
ThrowException(error);
ThrowException(error, cancellationToken);
}
}

private void ThrowIncorrectTokenException() => throw new InvalidOperationException(SR.InvalidOperation_IncorrectToken);

private void ThrowMultipleContinuationsException() => throw new InvalidOperationException(SR.InvalidOperation_MultipleContinuations);

private void ThrowException(SocketError error) => throw CreateException(error);
private void ThrowException(SocketError error, CancellationToken cancellationToken)
{
if (error == SocketError.OperationAborted)
{
cancellationToken.ThrowIfCancellationRequested();
}

throw CreateException(error);
}

private Exception CreateException(SocketError error)
{
Expand Down
12 changes: 8 additions & 4 deletions src/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4038,7 +4038,9 @@ public bool DisconnectAsync(SocketAsyncEventArgs e)
return retval;
}

public bool ReceiveAsync(SocketAsyncEventArgs e)
public bool ReceiveAsync(SocketAsyncEventArgs e) => ReceiveAsync(e, default(CancellationToken));

private bool ReceiveAsync(SocketAsyncEventArgs e, CancellationToken cancellationToken)
{
if (NetEventSource.IsEnabled) NetEventSource.Enter(this, e);

Expand All @@ -4057,7 +4059,7 @@ public bool ReceiveAsync(SocketAsyncEventArgs e)
SocketError socketError;
try
{
socketError = e.DoOperationReceive(_handle);
socketError = e.DoOperationReceive(_handle, cancellationToken);
}
catch
{
Expand Down Expand Up @@ -4179,7 +4181,9 @@ public bool ReceiveMessageFromAsync(SocketAsyncEventArgs e)
return pending;
}

public bool SendAsync(SocketAsyncEventArgs e)
public bool SendAsync(SocketAsyncEventArgs e) => SendAsync(e, default(CancellationToken));

private bool SendAsync(SocketAsyncEventArgs e, CancellationToken cancellationToken)
{
if (NetEventSource.IsEnabled) NetEventSource.Enter(this, e);

Expand All @@ -4198,7 +4202,7 @@ public bool SendAsync(SocketAsyncEventArgs e)
SocketError socketError;
try
{
socketError = e.DoOperationSend(_handle);
socketError = e.DoOperationSend(_handle, cancellationToken);
}
catch
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ private enum State
public SocketError ErrorCode;
public byte[] SocketAddress;
public int SocketAddressLen;
public CancellationTokenRegistration CancellationRegistration;

public ManualResetEventSlim Event
{
Expand Down Expand Up @@ -195,6 +196,11 @@ public bool TryCancel()
{
Trace("Enter");

// We're already canceling, so we don't need to still be hooked up to listen to cancellation.
// The cancellation request could also be caused by something other than the token, so it's
// important we clean it up, regardless.
CancellationRegistration.Dispose();

// Try to transition from Waiting to Cancelled
var spinWait = new SpinWait();
bool keepWaiting = true;
Expand Down Expand Up @@ -739,7 +745,7 @@ public bool IsReady(SocketAsyncContext context, out int observedSequenceNumber)
}

// Return true for pending, false for completed synchronously (including failure and abort)
public bool StartAsyncOperation(SocketAsyncContext context, TOperation operation, int observedSequenceNumber)
public bool StartAsyncOperation(SocketAsyncContext context, TOperation operation, int observedSequenceNumber, CancellationToken cancellationToken = default)
{
Trace(context, $"Enter");

Expand Down Expand Up @@ -781,8 +787,16 @@ public bool StartAsyncOperation(SocketAsyncContext context, TOperation operation
}

_tail = operation;

Trace(context, $"Leave, enqueued {IdOf(operation)}");

// Now that the object is enqueued, hook up cancellation.
// Note that it's possible the call to register itself could
// call TryCancel, so we do this after the op is fully enqueued.
if (cancellationToken.CanBeCanceled)
{
operation.CancellationRegistration = cancellationToken.UnsafeRegister(s => ((TOperation)s).TryCancel(), operation);
}

return true;

case QueueState.Stopped:
Expand Down Expand Up @@ -866,7 +880,12 @@ internal void ProcessAsyncOperation(TOperation op)
{
// At this point, the operation has completed and it's no longer
// in the queue / no one else has a reference to it. We can invoke
// the callback and let it pool the object if appropriate.
// the callback and let it pool the object if appropriate. This is
// also a good time to unregister from cancellation; we must do
// so before the object is returned to the pool (or else a cancellation
// request for a previous operation could affect a subsequent one)
// and here we know the operation has completed.
op.CancellationRegistration.Dispose();
op.InvokeCallback(allowPooling: true);
}
}
Expand Down Expand Up @@ -1410,10 +1429,10 @@ public SocketError Receive(Span<byte> buffer, ref SocketFlags flags, int timeout
return ReceiveFrom(buffer, ref flags, null, ref socketAddressLen, timeout, out bytesReceived);
}

public SocketError ReceiveAsync(Memory<byte> buffer, SocketFlags flags, out int bytesReceived, out SocketFlags receivedFlags, Action<int, byte[], int, SocketFlags, SocketError> callback)
public SocketError ReceiveAsync(Memory<byte> buffer, SocketFlags flags, out int bytesReceived, out SocketFlags receivedFlags, Action<int, byte[], int, SocketFlags, SocketError> callback, CancellationToken cancellationToken)
{
int socketAddressLen = 0;
return ReceiveFromAsync(buffer, flags, null, ref socketAddressLen, out bytesReceived, out receivedFlags, callback);
return ReceiveFromAsync(buffer, flags, null, ref socketAddressLen, out bytesReceived, out receivedFlags, callback, cancellationToken);
}

public SocketError ReceiveFrom(Memory<byte> buffer, ref SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, int timeout, out int bytesReceived)
Expand Down Expand Up @@ -1478,7 +1497,7 @@ public unsafe SocketError ReceiveFrom(Span<byte> buffer, ref SocketFlags flags,
}
}

public SocketError ReceiveFromAsync(Memory<byte> buffer, SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, Action<int, byte[], int, SocketFlags, SocketError> callback)
public SocketError ReceiveFromAsync(Memory<byte> buffer, SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, Action<int, byte[], int, SocketFlags, SocketError> callback, CancellationToken cancellationToken = default)
{
SetNonBlocking();

Expand All @@ -1497,7 +1516,7 @@ public SocketError ReceiveFromAsync(Memory<byte> buffer, SocketFlags flags, byt
operation.SocketAddress = socketAddress;
operation.SocketAddressLen = socketAddressLen;

if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber))
if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber, cancellationToken))
{
receivedFlags = operation.ReceivedFlags;
bytesReceived = operation.BytesTransferred;
Expand Down Expand Up @@ -1673,10 +1692,10 @@ public SocketError Send(byte[] buffer, int offset, int count, SocketFlags flags,
return SendTo(buffer, offset, count, flags, null, 0, timeout, out bytesSent);
}

public SocketError SendAsync(Memory<byte> buffer, int offset, int count, SocketFlags flags, out int bytesSent, Action<int, byte[], int, SocketFlags, SocketError> callback)
public SocketError SendAsync(Memory<byte> buffer, int offset, int count, SocketFlags flags, out int bytesSent, Action<int, byte[], int, SocketFlags, SocketError> callback, CancellationToken cancellationToken)
{
int socketAddressLen = 0;
return SendToAsync(buffer, offset, count, flags, null, ref socketAddressLen, out bytesSent, callback);
return SendToAsync(buffer, offset, count, flags, null, ref socketAddressLen, out bytesSent, callback, cancellationToken);
}

public SocketError SendTo(byte[] buffer, int offset, int count, SocketFlags flags, byte[] socketAddress, int socketAddressLen, int timeout, out int bytesSent)
Expand Down Expand Up @@ -1745,7 +1764,7 @@ public unsafe SocketError SendTo(ReadOnlySpan<byte> buffer, SocketFlags flags, b
}
}

public SocketError SendToAsync(Memory<byte> buffer, int offset, int count, SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, out int bytesSent, Action<int, byte[], int, SocketFlags, SocketError> callback)
public SocketError SendToAsync(Memory<byte> buffer, int offset, int count, SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, out int bytesSent, Action<int, byte[], int, SocketFlags, SocketError> callback, CancellationToken cancellationToken = default)
{
SetNonBlocking();

Expand All @@ -1768,7 +1787,7 @@ public SocketError SendToAsync(Memory<byte> buffer, int offset, int count, Socke
operation.SocketAddressLen = socketAddressLen;
operation.BytesTransferred = bytesSent;

if (!_sendQueue.StartAsyncOperation(this, operation, observedSequenceNumber))
if (!_sendQueue.StartAsyncOperation(this, operation, observedSequenceNumber, cancellationToken))
{
bytesSent = operation.BytesTransferred;
errorCode = operation.ErrorCode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

using System.Diagnostics;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

namespace System.Net.Sockets
{
Expand Down Expand Up @@ -127,7 +129,7 @@ private void CompleteTransferOperation(int bytesTransferred, byte[] socketAddres
_receivedFlags = receivedFlags;
}

internal unsafe SocketError DoOperationReceive(SafeSocketHandle handle)
internal unsafe SocketError DoOperationReceive(SafeSocketHandle handle, CancellationToken cancellationToken)
{
_receivedFlags = System.Net.Sockets.SocketFlags.None;
_socketAddressSize = 0;
Expand All @@ -137,7 +139,7 @@ internal unsafe SocketError DoOperationReceive(SafeSocketHandle handle)
SocketError errorCode;
if (_bufferList == null)
{
errorCode = handle.AsyncContext.ReceiveAsync(_buffer.Slice(_offset, _count), _socketFlags, out bytesReceived, out flags, TransferCompletionCallback);
errorCode = handle.AsyncContext.ReceiveAsync(_buffer.Slice(_offset, _count), _socketFlags, out bytesReceived, out flags, TransferCompletionCallback, cancellationToken);
}
else
{
Expand Down Expand Up @@ -219,7 +221,7 @@ internal unsafe SocketError DoOperationReceiveMessageFrom(Socket socket, SafeSoc
return socketError;
}

internal unsafe SocketError DoOperationSend(SafeSocketHandle handle)
internal unsafe SocketError DoOperationSend(SafeSocketHandle handle, CancellationToken cancellationToken)
{
_receivedFlags = System.Net.Sockets.SocketFlags.None;
_socketAddressSize = 0;
Expand All @@ -228,7 +230,7 @@ internal unsafe SocketError DoOperationSend(SafeSocketHandle handle)
SocketError errorCode;
if (_bufferList == null)
{
errorCode = handle.AsyncContext.SendAsync(_buffer, _offset, _count, _socketFlags, out bytesSent, TransferCompletionCallback);
errorCode = handle.AsyncContext.SendAsync(_buffer, _offset, _count, _socketFlags, out bytesSent, TransferCompletionCallback, cancellationToken);
}
else
{
Expand Down
Loading

0 comments on commit 2190a0f

Please sign in to comment.