diff --git a/Source/HiveMQtt/Client/HiveMQClient.cs b/Source/HiveMQtt/Client/HiveMQClient.cs index 482dfac6..8b845d2c 100644 --- a/Source/HiveMQtt/Client/HiveMQClient.cs +++ b/Source/HiveMQtt/Client/HiveMQClient.cs @@ -57,10 +57,12 @@ public HiveMQClient(HiveMQClientOptions? options = null) this.Options = options; this.cancellationTokenSource = new CancellationTokenSource(); - this.ClientReceiveSemaphore = new SemaphoreSlim(this.Options.ClientReceiveMaximum); + + // In-flight transaction queues + this.IPubTransactionQueue = new BoundedDictionaryX>(this.Options.ClientReceiveMaximum); // Set protocol default until ConnAck is received - this.BrokerReceiveSemaphore = new SemaphoreSlim(65535); + this.OPubTransactionQueue = new BoundedDictionaryX>(65535); } /// @@ -140,9 +142,15 @@ public async Task ConnectAsync() /// public async Task DisconnectAsync(DisconnectOptions? options = null) { + if (this.ConnectState == ConnectState.Disconnecting) + { + // We're already disconnecting in another task. + return true; + } + if (this.ConnectState != ConnectState.Connected) { - Logger.Warn("DisconnectAsync called but this client is not connected. State is ${this.ConnectState}."); + Logger.Warn($"DisconnectAsync called but this client is not connected. State is {this.ConnectState}."); return false; } @@ -510,10 +518,7 @@ private async Task HandleDisconnectionAsync(bool clean = true) // Cancel all background tasks and close the socket this.ConnectState = ConnectState.Disconnected; - // Don't use CancelAsync here to maintain backwards compatibility - // with >=.net6.0. CancelAsync was introduced in .net8.0 - this.cancellationTokenSource.Cancel(); - this.CloseSocket(); + await this.CloseSocketAsync().ConfigureAwait(false); if (clean) { diff --git a/Source/HiveMQtt/Client/HiveMQClientSocket.cs b/Source/HiveMQtt/Client/HiveMQClientSocket.cs index 3ad80e44..8c3ddd66 100644 --- a/Source/HiveMQtt/Client/HiveMQClientSocket.cs +++ b/Source/HiveMQtt/Client/HiveMQClientSocket.cs @@ -254,25 +254,17 @@ private async Task CreateTLSConnectionAsync(Stream stream) } } - internal bool CloseSocket(bool? shutdownPipeline = true) + internal async Task CloseSocketAsync(bool? shutdownPipeline = true) { - // Cancel the background traffic processing tasks - this.cancellationTokenSource.Cancel(); - - // Reset the tasks - this.ConnectionPublishWriterTask = null; - this.ConnectionWriterTask = null; - this.ConnectionReaderTask = null; - this.ReceivedPacketsHandlerTask = null; - this.ConnectionMonitorTask = null; + await this.CancelBackgroundTasksAsync().ConfigureAwait(false); if (shutdownPipeline == true) { if (this.Reader != null && this.Writer != null) { // Dispose of the PipeReader and PipeWriter - this.Reader.Complete(); - this.Writer.Complete(); + await this.Reader.CompleteAsync().ConfigureAwait(false); + await this.Writer.CompleteAsync().ConfigureAwait(false); // Shutdown the pipeline this.Reader = null; @@ -284,7 +276,7 @@ internal bool CloseSocket(bool? shutdownPipeline = true) { // Dispose of the Stream this.Stream.Close(); - this.Stream.Dispose(); + await this.Stream.DisposeAsync().ConfigureAwait(false); this.Stream = null; } @@ -300,4 +292,64 @@ internal bool CloseSocket(bool? shutdownPipeline = true) return true; } + + /// + /// Cancel all background tasks. + /// + /// A task representing the asynchronous operation. + internal async Task CancelBackgroundTasksAsync() + { + // Don't use CancelAsync here to maintain backwards compatibility + // with >=.net6.0. CancelAsync was introduced in .net8.0 + this.cancellationTokenSource.Cancel(); + + // Delay for a short period to allow the tasks to cancel + await Task.Delay(1000).ConfigureAwait(false); + + // Reset the tasks + if (this.ConnectionPublishWriterTask is not null && this.ConnectionPublishWriterTask.IsCompleted) + { + this.ConnectionPublishWriterTask = null; + } + else + { + Logger.Error("ConnectionPublishWriterTask did not complete"); + } + + if (this.ConnectionWriterTask is not null && this.ConnectionWriterTask.IsCompleted) + { + this.ConnectionWriterTask = null; + } + else + { + Logger.Error("ConnectionWriterTask did not complete"); + } + + if (this.ConnectionReaderTask is not null && this.ConnectionReaderTask.IsCompleted) + { + this.ConnectionReaderTask = null; + } + else + { + Logger.Error("ConnectionReaderTask did not complete"); + } + + if (this.ReceivedPacketsHandlerTask is not null && this.ReceivedPacketsHandlerTask.IsCompleted) + { + this.ReceivedPacketsHandlerTask = null; + } + else + { + Logger.Error("ReceivedPacketsHandlerTask did not complete"); + } + + if (this.ConnectionMonitorTask is not null && this.ConnectionMonitorTask.IsCompleted) + { + this.ConnectionMonitorTask = null; + } + else + { + Logger.Error("ConnectionMonitorTask did not complete"); + } + } } diff --git a/Source/HiveMQtt/Client/HiveMQClientTrafficProcessor.cs b/Source/HiveMQtt/Client/HiveMQClientTrafficProcessor.cs index 6c9b96c0..d38c1a16 100644 --- a/Source/HiveMQtt/Client/HiveMQClientTrafficProcessor.cs +++ b/Source/HiveMQtt/Client/HiveMQClientTrafficProcessor.cs @@ -16,7 +16,6 @@ namespace HiveMQtt.Client; using System; -using System.Collections.Concurrent; using System.Diagnostics; using System.IO.Pipelines; using System.Threading.Tasks; @@ -39,15 +38,11 @@ public partial class HiveMQClient : IDisposable, IHiveMQClient internal AwaitableQueueX ReceivedQueue { get; } = new(); - // Incoming Publish QoS > 0 packets indexed by packet identifier - internal ConcurrentDictionary> IPubTransactionQueue { get; } = new(); + // Incoming Publish QoS > 0 in-flight transactions indexed by packet identifier + internal BoundedDictionaryX> IPubTransactionQueue { get; set; } - // Outgoing Publish QoS > 0 packets indexed by packet identifier - internal ConcurrentDictionary> OPubTransactionQueue { get; } = new(); - - private SemaphoreSlim BrokerReceiveSemaphore { get; set; } - - internal SemaphoreSlim ClientReceiveSemaphore { get; } + // Outgoing Publish QoS > 0 in-flight transactions indexed by packet identifier + internal BoundedDictionaryX> OPubTransactionQueue { get; set; } private readonly Stopwatch lastCommunicationTimer = new(); @@ -106,13 +101,12 @@ private Task ConnectionMonitorAsync(CancellationToken cancellationToken) => Task // Dumping Client State Logger.Debug($"{this.Options.ClientId}-(CM)- {this.ConnectState} lastCommunicationTimer:{this.lastCommunicationTimer.Elapsed}"); - Logger.Debug($"{this.Options.ClientId}-(CM)- SendQueue:............{this.SendQueue.Count}"); - Logger.Debug($"{this.Options.ClientId}-(CM)- ReceivedQueue:........{this.ReceivedQueue.Count}"); - Logger.Debug($"{this.Options.ClientId}-(CM)- OutgoingPublishQueue:.{this.OutgoingPublishQueue.Count}"); - Logger.Debug($"{this.Options.ClientId}-(CM)- BrokerReceiveMaxSem...{this.BrokerReceiveSemaphore.CurrentCount}"); - Logger.Debug($"{this.Options.ClientId}-(CM)- OPubTransactionQueue:.{this.OPubTransactionQueue.Count}"); - Logger.Debug($"{this.Options.ClientId}-(CM)- IPubTransactionQueue:.{this.IPubTransactionQueue.Count}"); - Logger.Debug($"{this.Options.ClientId}-(CM)- # of Subscriptions:...{this.Subscriptions.Count}"); + Logger.Debug($"{this.Options.ClientId}-(CM)- SendQueue:...............{this.SendQueue.Count}"); + Logger.Debug($"{this.Options.ClientId}-(CM)- ReceivedQueue:...........{this.ReceivedQueue.Count}"); + Logger.Debug($"{this.Options.ClientId}-(CM)- OutgoingPublishQueue:....{this.OutgoingPublishQueue.Count}"); + Logger.Debug($"{this.Options.ClientId}-(CM)- OPubTransactionQueue:....{this.OPubTransactionQueue.Count}/{this.OPubTransactionQueue.Capacity}"); + Logger.Debug($"{this.Options.ClientId}-(CM)- IPubTransactionQueue:....{this.IPubTransactionQueue.Count}/{this.IPubTransactionQueue.Capacity}"); + Logger.Debug($"{this.Options.ClientId}-(CM)- # of Subscriptions:......{this.Subscriptions.Count}"); await this.RunTaskHealthCheckAsync(this.ConnectionWriterTask, "ConnectionWriter").ConfigureAwait(false); await this.RunTaskHealthCheckAsync(this.ConnectionReaderTask, "ConnectionReader").ConfigureAwait(false); @@ -173,15 +167,16 @@ private Task ConnectionPublishWriterAsync(CancellationToken cancellationToken) = if (publishPacket.Message.QoS is QualityOfService.AtLeastOnceDelivery || publishPacket.Message.QoS is QualityOfService.ExactlyOnceDelivery) { - // We have the next qos>0 publish packet to send - // Respect the broker's ReceiveMaximum - await this.BrokerReceiveSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); - - // QoS > 0 - Add to transaction queue - if (!this.OPubTransactionQueue.TryAdd(publishPacket.PacketIdentifier, new List { publishPacket })) + // QoS > 0 - Add to transaction queue. OPubTransactionQueue will block when necessary + // to respect the broker's ReceiveMaximum + var success = await this.OPubTransactionQueue.AddAsync( + publishPacket.PacketIdentifier, + new List { publishPacket }, + cancellationToken).ConfigureAwait(false); + + if (!success) { Logger.Warn($"Duplicate packet ID detected {publishPacket.PacketIdentifier} while queueing to transaction queue for an outgoing QoS {publishPacket.Message.QoS} publish ."); - this.BrokerReceiveSemaphore.Release(); continue; } } @@ -276,11 +271,13 @@ private Task ConnectionWriterAsync(CancellationToken cancellationToken) => Task. break; case PublishPacket publishPacket: throw new HiveMQttClientException("PublishPacket should be sent via ConnectionPublishWriterAsync."); + case PubAckPacket pubAckPacket: Logger.Trace($"{this.Options.ClientId}-(W)- --> Sending PubAckPacket id={pubAckPacket.PacketIdentifier} reason={pubAckPacket.ReasonCode}"); writeResult = await this.WriteAsync(pubAckPacket.Encode()).ConfigureAwait(false); - this.OnPubAckSentEventLauncher(pubAckPacket); + this.HandleSentPubAckPacket(pubAckPacket); break; + case PubRecPacket pubRecPacket: Logger.Trace($"{this.Options.ClientId}-(W)- --> Sending PubRecPacket id={pubRecPacket.PacketIdentifier} reason={pubRecPacket.ReasonCode}"); writeResult = await this.WriteAsync(pubRecPacket.Encode()).ConfigureAwait(false); @@ -294,7 +291,7 @@ private Task ConnectionWriterAsync(CancellationToken cancellationToken) => Task. case PubCompPacket pubCompPacket: Logger.Trace($"{this.Options.ClientId}-(W)- --> Sending PubCompPacket id={pubCompPacket.PacketIdentifier} reason={pubCompPacket.ReasonCode}"); writeResult = await this.WriteAsync(pubCompPacket.Encode()).ConfigureAwait(false); - this.OnPubCompSentEventLauncher(pubCompPacket); + this.HandleSentPubCompPacket(pubCompPacket); break; case PingReqPacket pingReqPacket: Logger.Trace($"{this.Options.ClientId}-(W)- --> Sending PingReqPacket id={pingReqPacket.PacketIdentifier}"); @@ -415,7 +412,9 @@ private Task ConnectionReaderAsync(CancellationToken cancellationToken) => // We handle disconnects immediately if (decodedPacket is DisconnectPacket disconnectPacket) { - Logger.Warn($"-(R)- <-- Disconnect received: {disconnectPacket.DisconnectReasonCode} {disconnectPacket.Properties.ReasonString}"); + // FIXME: If we received disconnect another client with the same id connected else where, should we + // Call the OnDisconnectReceivedEventLauncher? + Logger.Error($"--> Disconnect received <--: {disconnectPacket.DisconnectReasonCode} {disconnectPacket.Properties.ReasonString}"); await this.HandleDisconnectionAsync(false).ConfigureAwait(false); this.OnDisconnectReceivedEventLauncher(disconnectPacket); break; @@ -427,18 +426,17 @@ private Task ConnectionReaderAsync(CancellationToken cancellationToken) => if (publishPacket.Message.QoS is QualityOfService.ExactlyOnceDelivery || publishPacket.Message.QoS is QualityOfService.AtLeastOnceDelivery) { - while (true) + var success = await this.IPubTransactionQueue.AddAsync( + publishPacket.PacketIdentifier, + new List { publishPacket }).ConfigureAwait(false); + + if (!success) { - if (this.IPubTransactionQueue.Count >= this.Options.ClientReceiveMaximum) - { - Logger.Trace($"-(R)- The Maximum number of concurrent publishes have been received from broker. Applying back-pressure and waiting for existing transactions to complete."); - await Task.Delay(500).ConfigureAwait(false); - } - else - { - break; - } - } // while (true) + Logger.Warn($"Duplicate packet ID detected {publishPacket.PacketIdentifier} while queueing to transaction queue for an incoming QoS {publishPacket.Message.QoS} publish ."); + + // FIXME: We should potentially disconnect here + continue; + } } } @@ -494,6 +492,7 @@ private Task ReceivedPacketsHandlerAsync(CancellationToken cancellationToken) => var packet = await this.ReceivedQueue.DequeueAsync(cancellationToken).ConfigureAwait(false); if (this.Options.ClientMaximumPacketSize != null) { + // FIXME: Move this to ConnectionReaderAsync/closer to the source instead of after the queue if (packet.PacketSize > this.Options.ClientMaximumPacketSize) { Logger.Warn($"Received packet size {packet.PacketSize} exceeds maximum packet size {this.Options.ClientMaximumPacketSize}. Disconnecting."); @@ -517,8 +516,8 @@ private Task ReceivedPacketsHandlerAsync(CancellationToken cancellationToken) => { Logger.Debug($"{this.Options.ClientId}-(RPH)- <-- Broker says limit concurrent incoming QoS 1 and QoS 2 publishes to {connAckPacket.Properties.ReceiveMaximum}."); - // Replace the BrokerReceiveSemaphore with a new one with the broker's ReceiveMaximum - this.BrokerReceiveSemaphore = new SemaphoreSlim((int)connAckPacket.Properties.ReceiveMaximum); + // Replace the OPubTransactionQueue BoundedDictionary with a new one with the broker's ReceiveMaximum + this.OPubTransactionQueue = new BoundedDictionaryX>((int)connAckPacket.Properties.ReceiveMaximum); } this.ConnectionProperties = connAckPacket.Properties; @@ -587,7 +586,7 @@ private Task ReceivedPacketsHandlerAsync(CancellationToken cancellationToken) => /// Handle an incoming Publish packet. /// /// The received publish packet. - internal void HandleIncomingPublishPacket(PublishPacket publishPacket) + internal async void HandleIncomingPublishPacket(PublishPacket publishPacket) { Logger.Trace($"{this.Options.ClientId}-(RPH)- <-- Received Publish id={publishPacket.PacketIdentifier}"); this.OnPublishReceivedEventLauncher(publishPacket); @@ -600,20 +599,35 @@ internal void HandleIncomingPublishPacket(PublishPacket publishPacket) { // We've received a QoS 1 publish. Send a PubAck and notify subscribers. var pubAckResponse = new PubAckPacket(publishPacket.PacketIdentifier, PubAckReasonCode.Success); + this.SendQueue.Enqueue(pubAckResponse); this.OnMessageReceivedEventLauncher(publishPacket); } else if (publishPacket.Message.QoS is QualityOfService.ExactlyOnceDelivery) { + // We've received a QoS 2 publish. Send a PubRec and add to QoS2 transaction register. // When we get the PubRel, we'll notify subscribers and send the PubComp in HandleIncomingPubRelPacket + bool success; var pubRecResponse = new PubRecPacket(publishPacket.PacketIdentifier, PubRecReasonCode.Success); - var publishQoS2Chain = new List { publishPacket, pubRecResponse }; - if (!this.IPubTransactionQueue.TryAdd(publishPacket.PacketIdentifier, publishQoS2Chain)) + // Get the QoS2 transaction chain for this packet identifier and add the PubRec to it + success = this.IPubTransactionQueue.TryGetValue(publishPacket.PacketIdentifier, out var publishQoS2Chain); + publishQoS2Chain.Add(pubRecResponse); + + if (success) + { + // Update the chain in the queue + if (!this.IPubTransactionQueue.TryUpdate(publishPacket.PacketIdentifier, publishQoS2Chain, publishQoS2Chain)) + { + Logger.Error($"QoS2: Couldn't update Publish --> PubRec QoS2 Chain for packet identifier {publishPacket.PacketIdentifier}. Discarded."); + this.IPubTransactionQueue.Remove(publishPacket.PacketIdentifier, out _); + } + } + else { - Logger.Warn($"Duplicate packet ID detected {publishPacket.PacketIdentifier} while queueing to transaction queue for an incoming QoS {publishPacket.Message.QoS} publish ."); - pubRecResponse.ReasonCode = PubRecReasonCode.PacketIdentifierInUse; + Logger.Error($"QoS2: Received Publish with an unknown packet identifier {publishPacket.PacketIdentifier}. Discarded."); + return; } this.SendQueue.Enqueue(pubRecResponse); @@ -637,17 +651,37 @@ internal void HandleIncomingPubAckPacket(PubAckPacket pubAckPacket) var publishPacket = (PublishPacket)publishQoS1Chain.First(); // We sent a QoS1 publish and received a PubAck. The transaction is complete. + // Trigger the packet specific event + publishPacket.OnPublishQoS1CompleteEventLauncher(pubAckPacket); + } + else + { + Logger.Warn($"QoS1: Received PubAck with an unknown packet identifier {pubAckPacket.PacketIdentifier}. Discarded."); + } + } - // Release the semaphore - this.BrokerReceiveSemaphore.Release(); + /// + /// Handle an incoming PubComp packet. + /// + /// The received PubComp packet. + internal void HandleSentPubAckPacket(PubAckPacket pubAckPacket) + { + // Remove the transaction chain from the transaction queue + var success = this.IPubTransactionQueue.Remove(pubAckPacket.PacketIdentifier, out var publishQoS1Chain); + + if (success) + { + var publishPacket = (PublishPacket)publishQoS1Chain.First(); // Trigger the packet specific event publishPacket.OnPublishQoS1CompleteEventLauncher(pubAckPacket); } else { - Logger.Warn($"QoS1: Received PubAck with an unknown packet identifier {pubAckPacket.PacketIdentifier}. Discarded."); + Logger.Warn($"QoS1: Couldn't remove PubAck --> Publish QoS1 Chain for packet identifier {pubAckPacket.PacketIdentifier}."); } + + this.OnPubAckSentEventLauncher(pubAckPacket); } /// @@ -681,9 +715,8 @@ internal async void HandleIncomingPubRecPacket(PubRecPacket pubRecPacket) { Logger.Error($"QoS2: Couldn't update PubRec --> PubRel QoS2 Chain for packet identifier {pubRecPacket.PacketIdentifier}."); this.OPubTransactionQueue.Remove(pubRecPacket.PacketIdentifier, out _); - this.BrokerReceiveSemaphore.Release(); - // FIXME: Send an appropriate disconnect packet + // FIXME: Send an appropriate disconnect packet? await this.HandleDisconnectionAsync(false).ConfigureAwait(false); } @@ -707,31 +740,22 @@ internal void HandleIncomingPubRelPacket(PubRelPacket pubRelPacket) Logger.Trace($"{this.Options.ClientId}-(RPH)- <-- Received PubRel id={pubRelPacket.PacketIdentifier} reason={pubRelPacket.ReasonCode}"); this.OnPubRelReceivedEventLauncher(pubRelPacket); + PubCompPacket pubCompResponsePacket; + // This is in response to a publish that we received and already sent a pubrec - if (this.IPubTransactionQueue.TryGetValue(pubRelPacket.PacketIdentifier, out var originalPublishQoS2Chain)) + if (this.IPubTransactionQueue.TryGetValue(pubRelPacket.PacketIdentifier, out var publishQoS2Chain)) { - var originalPublishPacket = (PublishPacket)originalPublishQoS2Chain.First(); - // Send a PUBCOMP in response - var pubCompResponsePacket = new PubCompPacket(pubRelPacket.PacketIdentifier, PubCompReasonCode.Success); + pubCompResponsePacket = new PubCompPacket(pubRelPacket.PacketIdentifier, PubCompReasonCode.Success); - // This QoS2 transaction chain is done. Remove it from the transaction queue. - if (this.IPubTransactionQueue.TryRemove(pubRelPacket.PacketIdentifier, out var publishQoS2Chain)) - { - // Update the chain with the latest packets for the event launcher - publishQoS2Chain.Add(pubRelPacket); - publishQoS2Chain.Add(pubCompResponsePacket); + // Update the chain with the latest packets for the event launcher + publishQoS2Chain.Add(pubRelPacket); + publishQoS2Chain.Add(pubCompResponsePacket); - // Trigger the packet specific event - originalPublishPacket.OnPublishQoS2CompleteEventLauncher(publishQoS2Chain); - this.OnMessageReceivedEventLauncher(originalPublishPacket); - } - else + if (!this.IPubTransactionQueue.TryUpdate(pubRelPacket.PacketIdentifier, publishQoS2Chain, publishQoS2Chain)) { - Logger.Warn($"QoS2: Couldn't remove PubRel --> PubComp QoS2 Chain for packet identifier {pubRelPacket.PacketIdentifier}."); + Logger.Warn($"QoS2: Couldn't update PubRel --> PubComp QoS2 Chain for packet identifier {pubRelPacket.PacketIdentifier}."); } - - this.SendQueue.Enqueue(pubCompResponsePacket); } else { @@ -739,9 +763,40 @@ internal void HandleIncomingPubRelPacket(PubRelPacket pubRelPacket) "Responding with PubComp PacketIdentifierNotFound."); // Send a PUBCOMP with PacketIdentifierNotFound - var pubCompResponsePacket = new PubCompPacket(pubRelPacket.PacketIdentifier, PubCompReasonCode.PacketIdentifierNotFound); - this.SendQueue.Enqueue(pubCompResponsePacket); + pubCompResponsePacket = new PubCompPacket(pubRelPacket.PacketIdentifier, PubCompReasonCode.PacketIdentifierNotFound); } + + this.SendQueue.Enqueue(pubCompResponsePacket); + + } + + /// + /// Action to take once a PubComp packet is sent. + /// + /// The sent PubComp packet. + internal void HandleSentPubCompPacket(PubCompPacket pubCompPacket) + { + Logger.Trace($"{this.Options.ClientId}-(RPH)- <-- Sent PubComp id={pubCompPacket.PacketIdentifier} reason={pubCompPacket.ReasonCode}"); + + // PubCompReasonCode is either Success or PacketIdentifierNotFound. If the latter, + // there won't be a transaction chain to remove. + if (pubCompPacket.ReasonCode == PubCompReasonCode.Success) + { + // QoS 2 Transaction is done. Remove the transaction chain from the queue + if (this.IPubTransactionQueue.Remove(pubCompPacket.PacketIdentifier, out var publishQoS2Chain)) + { + var originalPublishPacket = (PublishPacket)publishQoS2Chain.First(); + + // Trigger the packet specific event + originalPublishPacket.OnPublishQoS2CompleteEventLauncher(publishQoS2Chain); + + // Trigger the application message event + this.OnMessageReceivedEventLauncher(originalPublishPacket); + } + } + + // Trigger the general event + this.OnPubCompSentEventLauncher(pubCompPacket); } /// @@ -763,9 +818,6 @@ internal void HandleIncomingPubCompPacket(PubCompPacket pubCompPacket) // Update the chain with this PubComp packet for the event launcher publishQoS2Chain.Add(pubCompPacket); - // Release the semaphore - this.BrokerReceiveSemaphore.Release(); - // Trigger the packet specific event with the entire chain originalPublishPacket.OnPublishQoS2CompleteEventLauncher(publishQoS2Chain); } diff --git a/Source/HiveMQtt/Client/internal/BoundedDictionaryX.cs b/Source/HiveMQtt/Client/internal/BoundedDictionaryX.cs new file mode 100644 index 00000000..9598f585 --- /dev/null +++ b/Source/HiveMQtt/Client/internal/BoundedDictionaryX.cs @@ -0,0 +1,194 @@ +namespace HiveMQtt.Client.Internal; + +using System; +using System.Collections.Concurrent; +using System.Threading; +using System.Threading.Tasks; + +/// +/// A finite (bounded) dictionary that can be awaited on for slots to become available. +/// +/// The type of items to index with. +/// The type of items to store as values. +public class BoundedDictionaryX : IDisposable + where TKey : notnull +{ + private static readonly NLog.Logger Logger = NLog.LogManager.GetCurrentClassLogger(); + + /// + /// The semaphore used to signal when items are enqueued. + /// + private readonly SemaphoreSlim semaphore; + + /// + /// The internal queue of items. + /// + private readonly ConcurrentDictionary dictionary; + + /// + /// Gets the capacity of the queue. + /// + public int Capacity { get; private set; } + + /// + /// Initializes a new instance of the class. + /// + /// The capacity of the queue. + public BoundedDictionaryX(int capacity) + { + this.Capacity = capacity; + this.semaphore = new SemaphoreSlim(capacity); + this.dictionary = new ConcurrentDictionary(); + } + + /// + /// Attempts to add an item to the dictionary. + /// + /// The key to add. + /// The value to add. + /// The cancellation token. + /// true if the item was added; otherwise, false. + public async Task AddAsync(TKey key, TVal value, CancellationToken cancellationToken = default) + { + bool errorDetected; + + Logger.Trace("Adding item {0}", key); + Logger.Trace("Open slots: {0} Dictionary Count: {1}", this.semaphore.CurrentCount, this.dictionary.Count); + + // Wait for an available slot + await this.semaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + + try + { + if (this.dictionary.TryAdd(key, value)) + { + return true; + } + else + { + Logger.Warn("Duplicate key: {0}", key); + + errorDetected = true; + } + } + catch (ArgumentNullException ex) + { + Logger.Warn("ArgumentNull Exception: {0}", ex); + errorDetected = true; + } + catch (OverflowException ex) + { + Logger.Warn("Overflow Exception: {0}", ex); + errorDetected = true; + } + + if (errorDetected) + { + // We failed to add the item, release the slot + this.semaphore.Release(); + } + + return false; + } + + /// + /// Attempts to remove an item from the dictionary. + /// + /// The key to remove. + /// The value removed. + /// true if the item was removed; otherwise, false. + public bool Remove(TKey key, out TVal value) + { + Logger.Trace("Removing item {0}", key); + Logger.Trace("Open slots: {0} Dictionary Count: {1}", this.semaphore.CurrentCount, this.dictionary.Count); + + try + { + if (this.dictionary.TryRemove(key, out value)) + { + // Item successfully removed, release the slot + this.semaphore.Release(); + return true; + } + else + { + Logger.Warn("Key not found: {0}", key); + } + } + catch (ArgumentNullException ex) + { + Logger.Warn("ArgumentNull Exception: {0}", ex); + } + catch (OverflowException ex) + { + Logger.Warn("Overflow Exception: {0}", ex); + } + + value = default!; + return false; + } + + /// + /// Attempts to update an item in the dictionary. + /// + /// The key to update. + /// The new value. + /// The value to compare against. + /// true if the item was updated; otherwise, false. + public bool TryUpdate(TKey key, TVal newValue, TVal comparisonValue) => this.dictionary.TryUpdate(key, newValue, comparisonValue); + + /// + /// Attempts to get a value from the dictionary. + /// + /// The key to get. + /// The value retrieved. + /// true if the item was retrieved; otherwise, false. + public bool TryGetValue(TKey key, out TVal value) => this.dictionary.TryGetValue(key, out value); + + /// + /// Removes all items from the dictionary. + /// + /// true if the dictionary was cleared; otherwise, false. + public bool Clear() + { + try + { + var numItems = this.dictionary.Count; + this.dictionary.Clear(); + this.semaphore.Release(numItems); + return true; + } + catch (ArgumentNullException ex) + { + Logger.Warn("ArgumentNull Exception: {0}", ex); + } + catch (OverflowException ex) + { + Logger.Warn("Overflow Exception: {0}", ex); + } + catch (Exception ex) + { + Logger.Warn("Exception: {0}", ex); + } + + return false; + } + + /// + /// Gets the number of items in the queue. + /// + /// The number of items in the queue. + public int Count => this.dictionary.Count; + + /// + /// Gets a value indicating whether the queue is empty. + /// + public bool IsEmpty => this.dictionary.IsEmpty; + + /// + public void Dispose() + { + this.semaphore.Dispose(); + GC.SuppressFinalize(this); + } +} diff --git a/Tests/HiveMQtt.Test/HiveMQClient/PublishTest.cs b/Tests/HiveMQtt.Test/HiveMQClient/PublishTest.cs index 36c097cf..27a05e53 100644 --- a/Tests/HiveMQtt.Test/HiveMQClient/PublishTest.cs +++ b/Tests/HiveMQtt.Test/HiveMQClient/PublishTest.cs @@ -86,7 +86,7 @@ public async Task MultiPublishWithQoS0Async() Assert.IsType(result); Assert.Null(result.QoS1ReasonCode); Assert.Null(result.QoS2ReasonCode); - Assert.Equal(MQTT5.Types.QualityOfService.AtMostOnceDelivery, result.Message.QoS); + Assert.Equal(QualityOfService.AtMostOnceDelivery, result.Message.QoS); } var disconnectResult = await client.DisconnectAsync().ConfigureAwait(false); @@ -255,13 +255,13 @@ void Client3MessageHandler(object? sender, OnMessageReceivedEventArgs eventArgs) Assert.Equal(0, client2.SendQueue.Count); Assert.Equal(0, client3.SendQueue.Count); - Assert.Empty(client1.OPubTransactionQueue); - Assert.Empty(client2.OPubTransactionQueue); - Assert.Empty(client3.OPubTransactionQueue); + Assert.Equal(0, client1.OPubTransactionQueue.Count); + Assert.Equal(0, client2.OPubTransactionQueue.Count); + Assert.Equal(0, client3.OPubTransactionQueue.Count); - Assert.Empty(client1.IPubTransactionQueue); - Assert.Empty(client2.IPubTransactionQueue); - Assert.Empty(client3.IPubTransactionQueue); + Assert.Equal(0, client1.IPubTransactionQueue.Count); + Assert.Equal(0, client2.IPubTransactionQueue.Count); + Assert.Equal(0, client3.IPubTransactionQueue.Count); // All done, disconnect all clients var disconnectResult = await client1.DisconnectAsync().ConfigureAwait(false); @@ -350,13 +350,13 @@ void Client3MessageHandler(object? sender, OnMessageReceivedEventArgs eventArgs) Assert.Equal(0, client2.SendQueue.Count); Assert.Equal(0, client3.SendQueue.Count); - Assert.Empty(client1.OPubTransactionQueue); - Assert.Empty(client2.OPubTransactionQueue); - Assert.Empty(client3.OPubTransactionQueue); + Assert.Equal(0, client1.OPubTransactionQueue.Count); + Assert.Equal(0, client2.OPubTransactionQueue.Count); + Assert.Equal(0, client3.OPubTransactionQueue.Count); - Assert.Empty(client1.IPubTransactionQueue); - Assert.Empty(client2.IPubTransactionQueue); - Assert.Empty(client3.IPubTransactionQueue); + Assert.Equal(0, client1.IPubTransactionQueue.Count); + Assert.Equal(0, client2.IPubTransactionQueue.Count); + Assert.Equal(0, client3.IPubTransactionQueue.Count); // All done, disconnect all clients var disconnectResult = await client1.DisconnectAsync().ConfigureAwait(false); @@ -444,13 +444,13 @@ void Client3MessageHandler(object? sender, OnMessageReceivedEventArgs eventArgs) Assert.Equal(0, client2.SendQueue.Count); Assert.Equal(0, client3.SendQueue.Count); - Assert.Empty(client1.OPubTransactionQueue); - Assert.Empty(client2.OPubTransactionQueue); - Assert.Empty(client3.OPubTransactionQueue); + Assert.Equal(0, client1.OPubTransactionQueue.Count); + Assert.Equal(0, client2.OPubTransactionQueue.Count); + Assert.Equal(0, client3.OPubTransactionQueue.Count); - Assert.Empty(client1.IPubTransactionQueue); - Assert.Empty(client2.IPubTransactionQueue); - Assert.Empty(client3.IPubTransactionQueue); + Assert.Equal(0, client1.IPubTransactionQueue.Count); + Assert.Equal(0, client2.IPubTransactionQueue.Count); + Assert.Equal(0, client3.IPubTransactionQueue.Count); // All done, disconnect all clients var disconnectResult = await client1.DisconnectAsync().ConfigureAwait(false); diff --git a/Tests/HiveMQtt.Test/Queues/BoundedDictionaryXTest.cs b/Tests/HiveMQtt.Test/Queues/BoundedDictionaryXTest.cs new file mode 100644 index 00000000..f82f17cb --- /dev/null +++ b/Tests/HiveMQtt.Test/Queues/BoundedDictionaryXTest.cs @@ -0,0 +1,136 @@ +namespace HiveMQtt.Test.Packets; + +using HiveMQtt.Client.Options; +using HiveMQtt.Client.Internal; +using HiveMQtt.MQTT5.Packets; +using HiveMQtt.MQTT5; +using Xunit; + +public class BoundedDictionaryXTest +{ + [Fact] + public async Task BlockWhenSlotsFullAsync() + { + var dictionary = new BoundedDictionaryX(2); + Assert.True(dictionary.IsEmpty); + + var options = new HiveMQClientOptions(); + Assert.NotNull(options); + + var cts = new CancellationTokenSource(TimeSpan.FromSeconds(2)); + var packet = new ConnectPacket(options); + var operationCanceled = false; + + try + { + // Add the first + await dictionary.AddAsync(1, packet, cts.Token).ConfigureAwait(false); + Assert.False(dictionary.IsEmpty); + Assert.Equal(1, dictionary.Count); + + // Add the second + await dictionary.AddAsync(2, packet, cts.Token).ConfigureAwait(false); + Assert.False(dictionary.IsEmpty); + Assert.Equal(2, dictionary.Count); + + // The third should block and wait for the cancellation token to timeout + await dictionary.AddAsync(3, packet, cts.Token).ConfigureAwait(false); + + } + catch (OperationCanceledException) + { + operationCanceled = true; + } + + Assert.True(operationCanceled); + Assert.False(dictionary.IsEmpty); + Assert.Equal(2, dictionary.Count); + } + + [Fact] + public async Task SlotsCanBeUpdatedAsync() + { + var dictionary = new BoundedDictionaryX(3); + Assert.True(dictionary.IsEmpty); + + var options = new HiveMQClientOptions(); + Assert.NotNull(options); + + var cts = new CancellationTokenSource(TimeSpan.FromSeconds(2)); + var packet = new ConnectPacket(options); + var result = false; + + // Add the first + result = await dictionary.AddAsync(1, packet, cts.Token).ConfigureAwait(false); + Assert.True(result); + Assert.False(dictionary.IsEmpty); + Assert.Equal(1, dictionary.Count); + + // Add the second + result = await dictionary.AddAsync(2, packet, cts.Token).ConfigureAwait(false); + Assert.True(result); + Assert.False(dictionary.IsEmpty); + Assert.Equal(2, dictionary.Count); + + // Add the third + result = await dictionary.AddAsync(3, packet, cts.Token).ConfigureAwait(false); + Assert.True(result); + Assert.False(dictionary.IsEmpty); + Assert.Equal(3, dictionary.Count); + + // Get the first + result = dictionary.TryGetValue(2, out var origValue); + Assert.True(result); + Assert.False(dictionary.IsEmpty); + Assert.Equal(packet, origValue); + + var cPacket = new PingReqPacket(); + + // Update the second + result = dictionary.TryUpdate(2, cPacket, packet); + Assert.True(result); + + Assert.False(dictionary.IsEmpty); + Assert.Equal(3, dictionary.Count); + + // Re-retrieve the second item to verify the update + result = dictionary.TryGetValue(2, out var newValue); + Assert.True(result); + Assert.False(dictionary.IsEmpty); + Assert.Equal(cPacket, newValue); + } + + [Fact] + public void ExposesCapacity() + { + var dictionary = new BoundedDictionaryX(3); + Assert.True(dictionary.IsEmpty); + Assert.Equal(3, dictionary.Capacity); + } + + [Fact] + public async Task CanBeClearedAsync() + { + var dictionary = new BoundedDictionaryX(3); + Assert.True(dictionary.IsEmpty); + + var options = new HiveMQClientOptions(); + Assert.NotNull(options); + + var packetOne = new ConnectPacket(options); + await dictionary.AddAsync(1, packetOne).ConfigureAwait(false); + Assert.Equal(1, dictionary.Count); + + var packetTwo = new ConnectPacket(options); + await dictionary.AddAsync(2, packetTwo).ConfigureAwait(false); + Assert.Equal(2, dictionary.Count); + + var packetThree = new ConnectPacket(options); + await dictionary.AddAsync(3, packetThree).ConfigureAwait(false); + Assert.Equal(3, dictionary.Count); + + dictionary.Clear(); + Assert.True(dictionary.IsEmpty); + Assert.Equal(0, dictionary.Count); + } +}