diff --git a/src/MongoDB.Driver/Core/Compression/SnappyCompressor.cs b/src/MongoDB.Driver/Core/Compression/SnappyCompressor.cs index dbb11f48e59..7419244a46c 100644 --- a/src/MongoDB.Driver/Core/Compression/SnappyCompressor.cs +++ b/src/MongoDB.Driver/Core/Compression/SnappyCompressor.cs @@ -34,7 +34,7 @@ public void Compress(Stream input, Stream output) { var uncompressedSize = (int)(input.Length - input.Position); var uncompressedBytes = new byte[uncompressedSize]; // does not include uncompressed message headers - input.ReadBytes(uncompressedBytes, offset: 0, count: uncompressedSize, CancellationToken.None); + input.ReadBytes(uncompressedBytes, offset: 0, count: uncompressedSize, Timeout.InfiniteTimeSpan, CancellationToken.None); var maxCompressedSize = Snappy.GetMaxCompressedLength(uncompressedSize); var compressedBytes = new byte[maxCompressedSize]; var compressedSize = Snappy.Compress(uncompressedBytes, compressedBytes); @@ -50,7 +50,7 @@ public void Decompress(Stream input, Stream output) { var compressedSize = (int)(input.Length - input.Position); var compressedBytes = new byte[compressedSize]; - input.ReadBytes(compressedBytes, offset: 0, count: compressedSize, CancellationToken.None); + input.ReadBytes(compressedBytes, offset: 0, count: compressedSize, Timeout.InfiniteTimeSpan, CancellationToken.None); var uncompressedSize = Snappy.GetUncompressedLength(compressedBytes); var decompressedBytes = new byte[uncompressedSize]; var decompressedSize = Snappy.Decompress(compressedBytes, decompressedBytes); diff --git a/src/MongoDB.Driver/Core/Connections/BinaryConnection.cs b/src/MongoDB.Driver/Core/Connections/BinaryConnection.cs index 380526e8348..d026da63a5e 100644 --- a/src/MongoDB.Driver/Core/Connections/BinaryConnection.cs +++ b/src/MongoDB.Driver/Core/Connections/BinaryConnection.cs @@ -333,14 +333,15 @@ private IByteBuffer ReceiveBuffer(CancellationToken cancellationToken) try { var messageSizeBytes = new byte[4]; - _stream.ReadBytes(messageSizeBytes, 0, 4, cancellationToken); + var readTimeout = _stream.CanTimeout ? TimeSpan.FromMilliseconds(_stream.ReadTimeout) : Timeout.InfiniteTimeSpan; + _stream.ReadBytes(messageSizeBytes, 0, 4, readTimeout, cancellationToken); var messageSize = BinaryPrimitives.ReadInt32LittleEndian(messageSizeBytes); EnsureMessageSizeIsValid(messageSize); var inputBufferChunkSource = new InputBufferChunkSource(BsonChunkPool.Default); var buffer = ByteBufferFactory.Create(inputBufferChunkSource, messageSize); buffer.Length = messageSize; buffer.SetBytes(0, messageSizeBytes, 0, 4); - _stream.ReadBytes(buffer, 4, messageSize - 4, cancellationToken); + _stream.ReadBytes(buffer, 4, messageSize - 4, readTimeout, cancellationToken); _lastUsedAtUtc = DateTime.UtcNow; buffer.MakeReadOnly(); return buffer; @@ -535,7 +536,8 @@ private void SendBuffer(IByteBuffer buffer, CancellationToken cancellationToken) try { - _stream.WriteBytes(buffer, 0, buffer.Length, cancellationToken); + var writeTimeout = _stream.CanTimeout ? TimeSpan.FromMilliseconds(_stream.WriteTimeout) : Timeout.InfiniteTimeSpan; + _stream.WriteBytes(buffer, 0, buffer.Length, writeTimeout, cancellationToken); _lastUsedAtUtc = DateTime.UtcNow; } catch (Exception ex) diff --git a/src/MongoDB.Driver/Core/Connections/TcpStreamFactory.cs b/src/MongoDB.Driver/Core/Connections/TcpStreamFactory.cs index ae2c1fb69b4..0cef690d661 100644 --- a/src/MongoDB.Driver/Core/Connections/TcpStreamFactory.cs +++ b/src/MongoDB.Driver/Core/Connections/TcpStreamFactory.cs @@ -138,7 +138,11 @@ private void Connect(Socket socket, EndPoint endPoint, CancellationToken cancell if (!connectOperation.IsCompleted) { - try { socket.Dispose(); } catch { } + try + { + socket.Dispose(); + socket.EndConnect(connectOperation); + } catch { } cancellationToken.ThrowIfCancellationRequested(); throw new TimeoutException($"Timed out connecting to {endPoint}. Timeout was {_settings.ConnectTimeout}."); @@ -164,7 +168,11 @@ private async Task ConnectAsync(Socket socket, EndPoint endPoint, CancellationTo if (!connectTask.IsCompleted) { - try { socket.Dispose(); } catch { } + try + { + socket.Dispose(); + await connectTask.ConfigureAwait(false); + } catch { } cancellationToken.ThrowIfCancellationRequested(); throw new TimeoutException($"Timed out connecting to {endPoint}. Timeout was {_settings.ConnectTimeout}."); diff --git a/src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs b/src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs index 248aa756272..c1e51cd074f 100644 --- a/src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs +++ b/src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs @@ -36,46 +36,64 @@ public static void EfficientCopyTo(this Stream input, Stream output) } } - public static async Task ReadAsync(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) + public static int Read(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) { - var state = 1; // 1 == reading, 2 == done reading, 3 == timedout, 4 == cancelled - - var bytesRead = 0; - using (new Timer(_ => ChangeState(3), null, timeout, Timeout.InfiniteTimeSpan)) - using (cancellationToken.Register(() => ChangeState(4))) + try { - try - { - bytesRead = await stream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); - ChangeState(2); // note: might not actually go to state 2 if already in state 3 or 4 - } - catch when (state == 1) - { - try { stream.Dispose(); } catch { } - throw; - } - catch when (state >= 3) + var readOperation = stream.BeginRead(buffer, offset, count, null, null); + WaitHandle.WaitAny([readOperation.AsyncWaitHandle, cancellationToken.WaitHandle], timeout); + + if (!readOperation.IsCompleted) { - // a timeout or operation cancelled exception will be thrown instead + try + { + stream.Dispose(); + stream.EndRead(readOperation); + } + catch + { + // ignore any exceptions + } + + cancellationToken.ThrowIfCancellationRequested(); + throw new TimeoutException(); } - if (state == 3) { throw new TimeoutException(); } - if (state == 4) { throw new OperationCanceledException(); } + return stream.EndRead(readOperation); + } + catch (ObjectDisposedException ex) + { + throw new EndOfStreamException("The connection was interrupted.", ex); } + } + + public static async Task ReadAsync(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) + { + var timeoutTask = Task.Delay(timeout, cancellationToken); + var readTask = stream.ReadAsync(buffer, offset, count); - return bytesRead; + await Task.WhenAny(readTask, timeoutTask).ConfigureAwait(false); - void ChangeState(int to) + if (!readTask.IsCompleted) { - var from = Interlocked.CompareExchange(ref state, to, 1); - if (from == 1 && to >= 3) + try + { + stream.Dispose(); + await readTask.ConfigureAwait(false); + } + catch { - try { stream.Dispose(); } catch { } // disposing the stream aborts the read attempt + // ignore any exceptions } + + cancellationToken.ThrowIfCancellationRequested(); + throw new TimeoutException(); } + + return await readTask.ConfigureAwait(false); } - public static void ReadBytes(this Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken) + public static void ReadBytes(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) { Ensure.IsNotNull(stream, nameof(stream)); Ensure.IsNotNull(buffer, nameof(buffer)); @@ -84,7 +102,7 @@ public static void ReadBytes(this Stream stream, byte[] buffer, int offset, int while (count > 0) { - var bytesRead = stream.Read(buffer, offset, count); // TODO: honor cancellationToken? + var bytesRead = stream.Read(buffer, offset, count, timeout, cancellationToken); if (bytesRead == 0) { throw new EndOfStreamException(); @@ -94,7 +112,7 @@ public static void ReadBytes(this Stream stream, byte[] buffer, int offset, int } } - public static void ReadBytes(this Stream stream, IByteBuffer buffer, int offset, int count, CancellationToken cancellationToken) + public static void ReadBytes(this Stream stream, IByteBuffer buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) { Ensure.IsNotNull(stream, nameof(stream)); Ensure.IsNotNull(buffer, nameof(buffer)); @@ -105,7 +123,7 @@ public static void ReadBytes(this Stream stream, IByteBuffer buffer, int offset, { var backingBytes = buffer.AccessBackingBytes(offset); var bytesToRead = Math.Min(count, backingBytes.Count); - var bytesRead = stream.Read(backingBytes.Array, backingBytes.Offset, bytesToRead); // TODO: honor cancellationToken? + var bytesRead = stream.Read(backingBytes.Array, backingBytes.Offset, bytesToRead, timeout, cancellationToken); if (bytesRead == 0) { throw new EndOfStreamException(); @@ -155,44 +173,64 @@ public static async Task ReadBytesAsync(this Stream stream, IByteBuffer buffer, } } + public static void Write(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) + { + try + { + var writeOperation = stream.BeginWrite(buffer, offset, count, null, null); + WaitHandle.WaitAny([writeOperation.AsyncWaitHandle, cancellationToken.WaitHandle], timeout); + + if (!writeOperation.IsCompleted) + { + try + { + stream.Dispose(); + stream.EndWrite(writeOperation); + } + catch + { + // ignore any exceptions + } + + cancellationToken.ThrowIfCancellationRequested(); + throw new TimeoutException(); + } + + stream.EndWrite(writeOperation); + } + catch (ObjectDisposedException ex) + { + throw new EndOfStreamException("The connection was interrupted.", ex); + } + } public static async Task WriteAsync(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) { - var state = 1; // 1 == writing, 2 == done writing, 3 == timedout, 4 == cancelled + var timeoutTask = Task.Delay(timeout); + var writeTask = stream.WriteAsync(buffer, offset, count, cancellationToken); + + await Task.WhenAny(writeTask, timeoutTask).ConfigureAwait(false); - using (new Timer(_ => ChangeState(3), null, timeout, Timeout.InfiniteTimeSpan)) - using (cancellationToken.Register(() => ChangeState(4))) + if (!writeTask.IsCompleted) { try { - await stream.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); - ChangeState(2); // note: might not actually go to state 2 if already in state 3 or 4 - } - catch when (state == 1) - { - try { stream.Dispose(); } catch { } - throw; + stream.Dispose(); + await writeTask.ConfigureAwait(false); } - catch when (state >= 3) + catch { - // a timeout or operation cancelled exception will be thrown instead + // ignore any exceptions } - if (state == 3) { throw new TimeoutException(); } - if (state == 4) { throw new OperationCanceledException(); } + cancellationToken.ThrowIfCancellationRequested(); + throw new TimeoutException(); } - void ChangeState(int to) - { - var from = Interlocked.CompareExchange(ref state, to, 1); - if (from == 1 && to >= 3) - { - try { stream.Dispose(); } catch { } // disposing the stream aborts the write attempt - } - } + await writeTask.ConfigureAwait(false); } - public static void WriteBytes(this Stream stream, IByteBuffer buffer, int offset, int count, CancellationToken cancellationToken) + public static void WriteBytes(this Stream stream, IByteBuffer buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) { Ensure.IsNotNull(stream, nameof(stream)); Ensure.IsNotNull(buffer, nameof(buffer)); @@ -204,7 +242,7 @@ public static void WriteBytes(this Stream stream, IByteBuffer buffer, int offset cancellationToken.ThrowIfCancellationRequested(); var backingBytes = buffer.AccessBackingBytes(offset); var bytesToWrite = Math.Min(count, backingBytes.Count); - stream.Write(backingBytes.Array, backingBytes.Offset, bytesToWrite); // TODO: honor cancellationToken? + stream.Write(backingBytes.Array, backingBytes.Offset, bytesToWrite, timeout, cancellationToken); // TODO: honor cancellationToken? offset += bytesToWrite; count -= bytesToWrite; } diff --git a/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnectionTests.cs b/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnectionTests.cs index 9a904c3db08..321d572d61e 100644 --- a/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnectionTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnectionTests.cs @@ -14,7 +14,6 @@ */ using System; -using System.Collections.Generic; using System.IO; using System.Net; using System.Net.Sockets; @@ -652,9 +651,13 @@ public void ReceiveMessage_should_not_produce_unobserved_task_exceptions_on_fail } else { + var task = Task.FromException(new SocketException()); mockStream - .Setup(s => s.Read(It.IsAny(), It.IsAny(), It.IsAny())) - .Throws(new SocketException()); + .Setup(s => s.BeginRead(It.IsAny(), It.IsAny(), It.IsAny(), null, null)) + .Returns(task); + mockStream + .Setup(s => s.EndRead(It.IsAny())) + .Returns(x => ((Task)x).GetAwaiter().GetResult()); } _subject.Open(CancellationToken.None); @@ -682,6 +685,93 @@ public void ReceiveMessage_should_not_produce_unobserved_task_exceptions_on_fail } } + [Theory] + [ParameterAttributeData] + public void ReceiveMessage_should_not_produce_unobserved_task_exceptions_on_timeout( + [Values(false, true)] bool async) + { + GC.Collect(); // Collects the unobserved tasks + GC.WaitForPendingFinalizers(); // Assures finalizers are executed + + Exception ex = null; + var mockStream = new Mock(); + EventHandler eventHandler = (s, args) => + { + ex = args.Exception; + }; + + try + { + TaskScheduler.UnobservedTaskException += eventHandler; + var encoderSelector = new ReplyMessageEncoderSelector(BsonDocumentSerializer.Instance); + + _mockStreamFactory + .Setup(f => f.CreateStream(_endPoint, CancellationToken.None)) + .Returns(mockStream.Object); + + mockStream.SetupGet(s => s.CanTimeout).Returns(true); + mockStream.SetupGet(s => s.ReadTimeout).Returns(20); + Task task = null; + if (async) + { + mockStream.Setup(s => s.ReadAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(async () => + { + await Task.Delay(40); + throw new SocketException(); + }); + } + else + { + mockStream.Setup(s => s.BeginRead(It.IsAny(), It.IsAny(), It.IsAny(), null, null)) + .Returns(() => + { + task = Task.Run(() => + { + Thread.Sleep(40); + throw new SocketException(); + }); + return task; + }); + mockStream.Setup(s => s.EndRead(It.IsAny())) + .Returns(x => + { + task.GetAwaiter().GetResult(); + return 0; + }); + } + + _subject.Open(CancellationToken.None); + + Exception exception; + if (async) + { + exception = Record.Exception(() => _subject.ReceiveMessageAsync(1, encoderSelector, _messageEncoderSettings, CancellationToken.None).GetAwaiter().GetResult()); + } + else + { + exception = Record.Exception(() => _subject.ReceiveMessage(1, encoderSelector, _messageEncoderSettings, CancellationToken.None)); + } + exception.Should().BeOfType(); + exception.InnerException.Should().BeOfType(); + + task = null; + mockStream.Reset(); + GC.Collect(); // Collects the unobserved tasks + GC.WaitForPendingFinalizers(); // Assures finalizers are executed + + if (ex != null) + { + Assert.Fail($"{ex.Message} - {ex}"); + } + } + finally + { + TaskScheduler.UnobservedTaskException -= eventHandler; + mockStream.Object?.Dispose(); + } + } + [Theory] [ParameterAttributeData] public void ReceiveMessage_should_throw_network_exception_to_all_awaiters( @@ -698,8 +788,10 @@ public void ReceiveMessage_should_throw_network_exception_to_all_awaiters( _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, CancellationToken.None)) .Returns(mockStream.Object); var readTcs = new TaskCompletionSource(); - mockStream.Setup(s => s.Read(It.IsAny(), It.IsAny(), It.IsAny())) - .Returns(() => readTcs.Task.GetAwaiter().GetResult()); + mockStream.Setup(s => s.BeginRead(It.IsAny(), It.IsAny(), It.IsAny(), null, null)) + .Returns(() => readTcs.Task); + mockStream.Setup(s => s.EndRead(It.IsAny())) + .Returns(x => ((Task)x).GetAwaiter().GetResult()); mockStream.Setup(s => s.ReadAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) .Returns(readTcs.Task); _subject.Open(CancellationToken.None); @@ -763,8 +855,10 @@ public void ReceiveMessage_should_throw_MongoConnectionClosedException_when_conn .Returns(mockStream.Object); var readTcs = new TaskCompletionSource(); readTcs.SetException(new SocketException()); - mockStream.Setup(s => s.Read(It.IsAny(), It.IsAny(), It.IsAny())) - .Returns(() => readTcs.Task.GetAwaiter().GetResult()); + mockStream.Setup(s => s.BeginRead(It.IsAny(), It.IsAny(), It.IsAny(), null, null)) + .Returns(() => readTcs.Task); + mockStream.Setup(s => s.EndRead(It.IsAny())) + .Returns(x => ((Task)x).GetAwaiter().GetResult()); mockStream.Setup(s => s.ReadAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) .Returns(readTcs.Task); _subject.Open(CancellationToken.None); diff --git a/tests/MongoDB.Driver.Tests/Jira/CSharp3188Tests.cs b/tests/MongoDB.Driver.Tests/Jira/CSharp3188Tests.cs index c3395ad3d04..5ef47a27fdd 100644 --- a/tests/MongoDB.Driver.Tests/Jira/CSharp3188Tests.cs +++ b/tests/MongoDB.Driver.Tests/Jira/CSharp3188Tests.cs @@ -65,33 +65,12 @@ public void Connection_timeout_should_throw_expected_exception([Values(false, tr .Limit(1) .Project(projectionDefinition); - if (async) - { - var exception = Record.Exception(() => collection.AggregateAsync(pipeline).GetAwaiter().GetResult()); + var exception = Record.Exception(() => collection.AggregateAsync(pipeline).GetAwaiter().GetResult()); - var mongoConnectionException = exception.Should().BeOfType().Subject; -#pragma warning disable CS0618 // Type or member is obsolete - mongoConnectionException.ContainsSocketTimeoutException.Should().BeFalse(); -#pragma warning restore CS0618 // Type or member is obsolete - mongoConnectionException.ContainsTimeoutException.Should().BeTrue(); - var baseException = GetBaseException(mongoConnectionException); - baseException.Should().BeOfType().Which.InnerException.Should().BeNull(); - } - else - { - var exception = Record.Exception(() => collection.Aggregate(pipeline)); - - var mongoConnectionException = exception.Should().BeOfType().Subject; -#pragma warning disable CS0618 // Type or member is obsolete - mongoConnectionException.ContainsSocketTimeoutException.Should().BeTrue(); -#pragma warning restore CS0618 // Type or member is obsolete - mongoConnectionException.ContainsTimeoutException.Should().BeTrue(); - var baseException = GetBaseException(mongoConnectionException); - var socketException = baseException.Should().BeOfType() - .Which.InnerException.Should().BeOfType().Subject; - socketException.SocketErrorCode.Should().Be(SocketError.TimedOut); - socketException.InnerException.Should().BeNull(); - } + var mongoConnectionException = exception.Should().BeOfType().Subject; + mongoConnectionException.ContainsTimeoutException.Should().BeTrue(); + var baseException = GetBaseException(mongoConnectionException); + baseException.Should().BeOfType().Which.InnerException.Should().BeNull(); } Exception GetBaseException(MongoConnectionException mongoConnectionException) diff --git a/tests/MongoDB.Driver.Tests/Specifications/UnifiedTestSpecRunner.cs b/tests/MongoDB.Driver.Tests/Specifications/UnifiedTestSpecRunner.cs index fe8d7683961..54efafe132c 100644 --- a/tests/MongoDB.Driver.Tests/Specifications/UnifiedTestSpecRunner.cs +++ b/tests/MongoDB.Driver.Tests/Specifications/UnifiedTestSpecRunner.cs @@ -262,9 +262,6 @@ private static void RequireKmsMock() => "legacy hello without speculativeAuthenticate is always observed", // transactions - // Skipped because CSharp Driver has an issue with handling read timeout for sync code-path. CSHARP-3662 - "add RetryableWriteError and UnknownTransactionCommitResult labels to connection errors", - // CSHARP Driver does not comply with the requirement to throw in case explicit writeConcern were used, see CSHARP-5468 "client bulkWrite with writeConcern in a transaction causes a transaction error", ]); diff --git a/tests/MongoDB.Driver.Tests/Specifications/server-discovery-and-monitoring/ServerDiscoveryAndMonitoringProseTests.cs b/tests/MongoDB.Driver.Tests/Specifications/server-discovery-and-monitoring/ServerDiscoveryAndMonitoringProseTests.cs index 2d9d3251d7b..dc305b85185 100644 --- a/tests/MongoDB.Driver.Tests/Specifications/server-discovery-and-monitoring/ServerDiscoveryAndMonitoringProseTests.cs +++ b/tests/MongoDB.Driver.Tests/Specifications/server-discovery-and-monitoring/ServerDiscoveryAndMonitoringProseTests.cs @@ -96,7 +96,7 @@ public void Heartbeat_should_be_emitted_before_connection_open() var mockStream = new Mock(); mockStream - .Setup(s => s.Write(It.IsAny(), It.IsAny(), It.IsAny())) + .Setup(s => s.BeginWrite(It.IsAny(), It.IsAny(), It.IsAny(), null, null)) .Callback(() => EnqueueEvent(HelloReceivedEvent)) .Throws(new Exception("Stream is closed."));