Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.
/ corefx Public archive

Plumb CancellationToken through Socket.Receive/SendAsync #36516

Merged
merged 1 commit into from
Apr 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/System.Net.Sockets/System.Net.Sockets.sln
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio 15
VisualStudioVersion = 15.0.27213.1
# Visual Studio Version 16
VisualStudioVersion = 16.0.28725.219
MinimumVisualStudioVersion = 10.0.40219.1
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "System.Net.Sockets.Tests", "tests\FunctionalTests\System.Net.Sockets.Tests.csproj", "{8CBA022C-635F-4C8D-9D29-CD8AAC68C8E6}"
ProjectSection(ProjectDependencies) = postProject
Expand Down Expand Up @@ -31,10 +31,10 @@ Global
Release|Any CPU = Release|Any CPU
EndGlobalSection
GlobalSection(ProjectConfigurationPlatforms) = postSolution
{8CBA022C-635F-4C8D-9D29-CD8AAC68C8E6}.Debug|Any CPU.ActiveCfg = netcoreapp-Debug|Any CPU
{8CBA022C-635F-4C8D-9D29-CD8AAC68C8E6}.Debug|Any CPU.Build.0 = netcoreapp-Debug|Any CPU
{8CBA022C-635F-4C8D-9D29-CD8AAC68C8E6}.Release|Any CPU.ActiveCfg = netcoreapp-Release|Any CPU
{8CBA022C-635F-4C8D-9D29-CD8AAC68C8E6}.Release|Any CPU.Build.0 = netcoreapp-Release|Any CPU
{8CBA022C-635F-4C8D-9D29-CD8AAC68C8E6}.Debug|Any CPU.ActiveCfg = netcoreapp-Unix-Debug|Any CPU
{8CBA022C-635F-4C8D-9D29-CD8AAC68C8E6}.Debug|Any CPU.Build.0 = netcoreapp-Unix-Debug|Any CPU
{8CBA022C-635F-4C8D-9D29-CD8AAC68C8E6}.Release|Any CPU.ActiveCfg = netcoreapp-Unix-Release|Any CPU
{8CBA022C-635F-4C8D-9D29-CD8AAC68C8E6}.Release|Any CPU.Build.0 = netcoreapp-Unix-Release|Any CPU
{BB5C85AD-C51A-4903-80E9-6F6E1AC1AD34}.Debug|Any CPU.ActiveCfg = netstandard-Windows_NT-Debug|Any CPU
{BB5C85AD-C51A-4903-80E9-6F6E1AC1AD34}.Debug|Any CPU.Build.0 = netstandard-Windows_NT-Debug|Any CPU
{BB5C85AD-C51A-4903-80E9-6F6E1AC1AD34}.Release|Any CPU.ActiveCfg = netstandard-Windows_NT-Release|Any CPU
Expand Down
3 changes: 3 additions & 0 deletions src/System.Net.Sockets/src/System.Net.Sockets.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,9 @@
<Compile Include="$(CommonPath)\Interop\Windows\WinSock\WSABuffer.cs">
<Link>Interop\Windows\WinSock\WSABuffer.cs</Link>
</Compile>
<Compile Include="$(CommonPath)\CoreLib\Interop\Windows\Kernel32\Interop.CancelIoEx.cs">
<Link>Common\Interop\Windows\Interop.CancelIoEx.cs</Link>
</Compile>
<Compile Include="$(CommonPath)\Interop\Windows\Kernel32\Interop.SetFileCompletionNotificationModes.cs">
<Link>Interop\Windows\Kernel32\Interop.SetFileCompletionNotificationModes.cs</Link>
</Compile>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Buffers;
using System.Diagnostics;
using System.IO;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

Expand Down Expand Up @@ -846,7 +843,7 @@ public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationTo
return _streamSocket.SendAsyncForNetworkStream(
buffer,
SocketFlags.None,
cancellationToken: cancellationToken);
cancellationToken);
}
catch (Exception exception) when (!(exception is OutOfMemoryException))
{
Expand Down
33 changes: 24 additions & 9 deletions src/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ internal ValueTask<int> ReceiveAsync(Memory<byte> buffer, SocketFlags socketFlag
saea.SetBuffer(buffer);
saea.SocketFlags = socketFlags;
saea.WrapExceptionsInIOExceptions = fromNetworkStream;
return saea.ReceiveAsync(this);
return saea.ReceiveAsync(this, cancellationToken);
}
else
{
Expand Down Expand Up @@ -350,7 +350,7 @@ internal ValueTask<int> SendAsync(ReadOnlyMemory<byte> buffer, SocketFlags socke
saea.SetBuffer(MemoryMarshal.AsMemory(buffer));
saea.SocketFlags = socketFlags;
saea.WrapExceptionsInIOExceptions = false;
return saea.SendAsync(this);
return saea.SendAsync(this, cancellationToken);
}
else
{
Expand Down Expand Up @@ -828,6 +828,8 @@ internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, IVal
/// it's already being reused by someone else.
/// </remarks>
private short _token;
/// <summary>The cancellation token used for the current operation.</summary>
private CancellationToken _cancellationToken;

/// <summary>Initializes the event args.</summary>
public AwaitableSocketAsyncEventArgs() :
Expand All @@ -842,6 +844,7 @@ public bool Reserve() =>

private void Release()
{
_cancellationToken = default;
_token++;
Volatile.Write(ref _continuation, s_availableSentinel);
}
Expand Down Expand Up @@ -882,12 +885,13 @@ protected override void OnCompleted(SocketAsyncEventArgs _)

/// <summary>Initiates a receive operation on the associated socket.</summary>
/// <returns>This instance.</returns>
public ValueTask<int> ReceiveAsync(Socket socket)
public ValueTask<int> ReceiveAsync(Socket socket, CancellationToken cancellationToken)
{
Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use");

if (socket.ReceiveAsync(this))
if (socket.ReceiveAsync(this, cancellationToken))
{
_cancellationToken = cancellationToken;
return new ValueTask<int>(this, _token);
}

Expand All @@ -903,12 +907,13 @@ public ValueTask<int> ReceiveAsync(Socket socket)

/// <summary>Initiates a send operation on the associated socket.</summary>
/// <returns>This instance.</returns>
public ValueTask<int> SendAsync(Socket socket)
public ValueTask<int> SendAsync(Socket socket, CancellationToken cancellationToken)
{
Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use");

if (socket.SendAsync(this))
if (socket.SendAsync(this, cancellationToken))
{
_cancellationToken = cancellationToken;
return new ValueTask<int>(this, _token);
}

Expand Down Expand Up @@ -1059,12 +1064,13 @@ public int GetResult(short token)

SocketError error = SocketError;
int bytes = BytesTransferred;
CancellationToken cancellationToken = _cancellationToken;

Release();

if (error != SocketError.Success)
{
ThrowException(error);
ThrowException(error, cancellationToken);
}
return bytes;
}
Expand All @@ -1077,20 +1083,29 @@ void IValueTaskSource.GetResult(short token)
}

SocketError error = SocketError;
CancellationToken cancellationToken = _cancellationToken;

Release();

if (error != SocketError.Success)
{
ThrowException(error);
ThrowException(error, cancellationToken);
}
}

private void ThrowIncorrectTokenException() => throw new InvalidOperationException(SR.InvalidOperation_IncorrectToken);

private void ThrowMultipleContinuationsException() => throw new InvalidOperationException(SR.InvalidOperation_MultipleContinuations);

private void ThrowException(SocketError error) => throw CreateException(error);
private void ThrowException(SocketError error, CancellationToken cancellationToken)
{
if (error == SocketError.OperationAborted)
{
cancellationToken.ThrowIfCancellationRequested();
}

throw CreateException(error);
}

private Exception CreateException(SocketError error)
{
Expand Down
12 changes: 8 additions & 4 deletions src/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4038,7 +4038,9 @@ public bool DisconnectAsync(SocketAsyncEventArgs e)
return retval;
}

public bool ReceiveAsync(SocketAsyncEventArgs e)
public bool ReceiveAsync(SocketAsyncEventArgs e) => ReceiveAsync(e, default(CancellationToken));

private bool ReceiveAsync(SocketAsyncEventArgs e, CancellationToken cancellationToken)
{
if (NetEventSource.IsEnabled) NetEventSource.Enter(this, e);

Expand All @@ -4057,7 +4059,7 @@ public bool ReceiveAsync(SocketAsyncEventArgs e)
SocketError socketError;
try
{
socketError = e.DoOperationReceive(_handle);
socketError = e.DoOperationReceive(_handle, cancellationToken);
}
catch
{
Expand Down Expand Up @@ -4179,7 +4181,9 @@ public bool ReceiveMessageFromAsync(SocketAsyncEventArgs e)
return pending;
}

public bool SendAsync(SocketAsyncEventArgs e)
public bool SendAsync(SocketAsyncEventArgs e) => SendAsync(e, default(CancellationToken));

private bool SendAsync(SocketAsyncEventArgs e, CancellationToken cancellationToken)
{
if (NetEventSource.IsEnabled) NetEventSource.Enter(this, e);

Expand All @@ -4198,7 +4202,7 @@ public bool SendAsync(SocketAsyncEventArgs e)
SocketError socketError;
try
{
socketError = e.DoOperationSend(_handle);
socketError = e.DoOperationSend(_handle, cancellationToken);
}
catch
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ private enum State
public SocketError ErrorCode;
public byte[] SocketAddress;
public int SocketAddressLen;
public CancellationTokenRegistration CancellationRegistration;

public ManualResetEventSlim Event
{
Expand Down Expand Up @@ -195,6 +196,11 @@ public bool TryCancel()
{
Trace("Enter");

// We're already canceling, so we don't need to still be hooked up to listen to cancellation.
// The cancellation request could also be caused by something other than the token, so it's
// important we clean it up, regardless.
CancellationRegistration.Dispose();

// Try to transition from Waiting to Cancelled
var spinWait = new SpinWait();
bool keepWaiting = true;
Expand Down Expand Up @@ -739,7 +745,7 @@ public bool IsReady(SocketAsyncContext context, out int observedSequenceNumber)
}

// Return true for pending, false for completed synchronously (including failure and abort)
public bool StartAsyncOperation(SocketAsyncContext context, TOperation operation, int observedSequenceNumber)
public bool StartAsyncOperation(SocketAsyncContext context, TOperation operation, int observedSequenceNumber, CancellationToken cancellationToken = default)
{
Trace(context, $"Enter");

Expand Down Expand Up @@ -781,8 +787,16 @@ public bool StartAsyncOperation(SocketAsyncContext context, TOperation operation
}

_tail = operation;

Trace(context, $"Leave, enqueued {IdOf(operation)}");

// Now that the object is enqueued, hook up cancellation.
// Note that it's possible the call to register itself could
// call TryCancel, so we do this after the op is fully enqueued.
if (cancellationToken.CanBeCanceled)
{
operation.CancellationRegistration = cancellationToken.UnsafeRegister(s => ((TOperation)s).TryCancel(), operation);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@stephentoub a canceled operation will remain in the queue until it is removed at the next event or the queue gets stopped.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Is that a problem?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not functionally. It will be alive for longer and maybe keeping some other things alive.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will be alive for longer and maybe keeping some other things alive.

Yes, but then again so are the Socket and the SocketAsyncEventArgs. In comparison this shouldn't keep alive much. And if you're canceling the operation, there's a really high likelihood you're also either tearing everything down or about to do something else that will revisit the queue.

}

return true;

case QueueState.Stopped:
Expand Down Expand Up @@ -866,7 +880,12 @@ internal void ProcessAsyncOperation(TOperation op)
{
// At this point, the operation has completed and it's no longer
// in the queue / no one else has a reference to it. We can invoke
// the callback and let it pool the object if appropriate.
// the callback and let it pool the object if appropriate. This is
// also a good time to unregister from cancellation; we must do
// so before the object is returned to the pool (or else a cancellation
// request for a previous operation could affect a subsequent one)
// and here we know the operation has completed.
op.CancellationRegistration.Dispose();
op.InvokeCallback(allowPooling: true);
}
}
Expand Down Expand Up @@ -1410,10 +1429,10 @@ public SocketError Receive(Span<byte> buffer, ref SocketFlags flags, int timeout
return ReceiveFrom(buffer, ref flags, null, ref socketAddressLen, timeout, out bytesReceived);
}

public SocketError ReceiveAsync(Memory<byte> buffer, SocketFlags flags, out int bytesReceived, out SocketFlags receivedFlags, Action<int, byte[], int, SocketFlags, SocketError> callback)
public SocketError ReceiveAsync(Memory<byte> buffer, SocketFlags flags, out int bytesReceived, out SocketFlags receivedFlags, Action<int, byte[], int, SocketFlags, SocketError> callback, CancellationToken cancellationToken)
{
int socketAddressLen = 0;
return ReceiveFromAsync(buffer, flags, null, ref socketAddressLen, out bytesReceived, out receivedFlags, callback);
return ReceiveFromAsync(buffer, flags, null, ref socketAddressLen, out bytesReceived, out receivedFlags, callback, cancellationToken);
}

public SocketError ReceiveFrom(Memory<byte> buffer, ref SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, int timeout, out int bytesReceived)
Expand Down Expand Up @@ -1478,7 +1497,7 @@ public unsafe SocketError ReceiveFrom(Span<byte> buffer, ref SocketFlags flags,
}
}

public SocketError ReceiveFromAsync(Memory<byte> buffer, SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, Action<int, byte[], int, SocketFlags, SocketError> callback)
public SocketError ReceiveFromAsync(Memory<byte> buffer, SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, Action<int, byte[], int, SocketFlags, SocketError> callback, CancellationToken cancellationToken = default)
{
SetNonBlocking();

Expand All @@ -1497,7 +1516,7 @@ public SocketError ReceiveFromAsync(Memory<byte> buffer, SocketFlags flags, byt
operation.SocketAddress = socketAddress;
operation.SocketAddressLen = socketAddressLen;

if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber))
if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber, cancellationToken))
{
receivedFlags = operation.ReceivedFlags;
bytesReceived = operation.BytesTransferred;
Expand Down Expand Up @@ -1673,10 +1692,10 @@ public SocketError Send(byte[] buffer, int offset, int count, SocketFlags flags,
return SendTo(buffer, offset, count, flags, null, 0, timeout, out bytesSent);
}

public SocketError SendAsync(Memory<byte> buffer, int offset, int count, SocketFlags flags, out int bytesSent, Action<int, byte[], int, SocketFlags, SocketError> callback)
public SocketError SendAsync(Memory<byte> buffer, int offset, int count, SocketFlags flags, out int bytesSent, Action<int, byte[], int, SocketFlags, SocketError> callback, CancellationToken cancellationToken)
{
int socketAddressLen = 0;
return SendToAsync(buffer, offset, count, flags, null, ref socketAddressLen, out bytesSent, callback);
return SendToAsync(buffer, offset, count, flags, null, ref socketAddressLen, out bytesSent, callback, cancellationToken);
}

public SocketError SendTo(byte[] buffer, int offset, int count, SocketFlags flags, byte[] socketAddress, int socketAddressLen, int timeout, out int bytesSent)
Expand Down Expand Up @@ -1745,7 +1764,7 @@ public unsafe SocketError SendTo(ReadOnlySpan<byte> buffer, SocketFlags flags, b
}
}

public SocketError SendToAsync(Memory<byte> buffer, int offset, int count, SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, out int bytesSent, Action<int, byte[], int, SocketFlags, SocketError> callback)
public SocketError SendToAsync(Memory<byte> buffer, int offset, int count, SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, out int bytesSent, Action<int, byte[], int, SocketFlags, SocketError> callback, CancellationToken cancellationToken = default)
{
SetNonBlocking();

Expand All @@ -1768,7 +1787,7 @@ public SocketError SendToAsync(Memory<byte> buffer, int offset, int count, Socke
operation.SocketAddressLen = socketAddressLen;
operation.BytesTransferred = bytesSent;

if (!_sendQueue.StartAsyncOperation(this, operation, observedSequenceNumber))
if (!_sendQueue.StartAsyncOperation(this, operation, observedSequenceNumber, cancellationToken))
{
bytesSent = operation.BytesTransferred;
errorCode = operation.ErrorCode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

using System.Diagnostics;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

namespace System.Net.Sockets
{
Expand Down Expand Up @@ -127,7 +129,7 @@ private void CompleteTransferOperation(int bytesTransferred, byte[] socketAddres
_receivedFlags = receivedFlags;
}

internal unsafe SocketError DoOperationReceive(SafeSocketHandle handle)
internal unsafe SocketError DoOperationReceive(SafeSocketHandle handle, CancellationToken cancellationToken)
{
_receivedFlags = System.Net.Sockets.SocketFlags.None;
_socketAddressSize = 0;
Expand All @@ -137,7 +139,7 @@ internal unsafe SocketError DoOperationReceive(SafeSocketHandle handle)
SocketError errorCode;
if (_bufferList == null)
{
errorCode = handle.AsyncContext.ReceiveAsync(_buffer.Slice(_offset, _count), _socketFlags, out bytesReceived, out flags, TransferCompletionCallback);
errorCode = handle.AsyncContext.ReceiveAsync(_buffer.Slice(_offset, _count), _socketFlags, out bytesReceived, out flags, TransferCompletionCallback, cancellationToken);
}
else
{
Expand Down Expand Up @@ -219,7 +221,7 @@ internal unsafe SocketError DoOperationReceiveMessageFrom(Socket socket, SafeSoc
return socketError;
}

internal unsafe SocketError DoOperationSend(SafeSocketHandle handle)
internal unsafe SocketError DoOperationSend(SafeSocketHandle handle, CancellationToken cancellationToken)
{
_receivedFlags = System.Net.Sockets.SocketFlags.None;
_socketAddressSize = 0;
Expand All @@ -228,7 +230,7 @@ internal unsafe SocketError DoOperationSend(SafeSocketHandle handle)
SocketError errorCode;
if (_bufferList == null)
{
errorCode = handle.AsyncContext.SendAsync(_buffer, _offset, _count, _socketFlags, out bytesSent, TransferCompletionCallback);
errorCode = handle.AsyncContext.SendAsync(_buffer, _offset, _count, _socketFlags, out bytesSent, TransferCompletionCallback, cancellationToken);
}
else
{
Expand Down
Loading