diff --git a/CMakeLists.txt b/CMakeLists.txt index 4a6f9c1..9a2e935 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required (VERSION 3.11) -project(Homa VERSION 0.1.1.0 LANGUAGES CXX) +project(Homa VERSION 0.2.0.0 LANGUAGES CXX) ################################################################################ ## Dependency Configuration #################################################### @@ -74,6 +74,7 @@ endif() add_library(Homa src/CodeLocation.cc src/Debug.cc + src/Driver.cc src/Homa.cc src/Perf.cc src/Policy.cc diff --git a/include/Homa/Driver.h b/include/Homa/Driver.h index ecfe666..67df785 100644 --- a/include/Homa/Driver.h +++ b/include/Homa/Driver.h @@ -23,141 +23,85 @@ namespace Homa { /** - * Used by Homa::Transport to send and receive unreliable datagrams. Provides - * the interface to which all Driver implementations must conform. + * A simple wrapper struct around an IP address in binary format. * - * Implementations of this class should be thread-safe. + * This struct is meant to provide some type-safety when manipulating IP + * addresses. In order to avoid any runtime overhead, this struct contains + * nothing more than the IP address, so it is trivially copyable. */ -class Driver { - public: +struct IpAddress final { + /// IPv4 address in host byte order. + uint32_t addr; + /** - * Represents a Network address. - * - * Each Address representation is specific to the Driver instance that - * returned the it; they cannot be use interchangeably between different - * Driver instances. + * Unbox the IP address in binary format. */ - using Address = uint64_t; + explicit operator uint32_t() + { + return addr; + } /** - * Used to hold a driver's serialized byte-format for a network address. - * - * Each driver may define its own byte-format so long as fits within the - * bytes array. + * Equality function for IpAddress, for use in std::unordered_maps etc. */ - struct WireFormatAddress { - uint8_t type; ///< Can be used to distinguish between different wire - ///< address formats. - uint8_t bytes[19]; ///< Holds an Address's serialized byte-format. - } __attribute__((packed)); + bool operator==(const IpAddress& other) const + { + return addr == other.addr; + } /** - * Represents a packet of data that can be send or is received over the - * network. A Packet logically contains only the payload and not any Driver - * specific headers. - * - * A Packet may be Driver specific and should not used interchangeably - * between Driver instances or implementations. - * - * This class is NOT thread-safe but the Transport and Driver's use of - * Packet objects should be allow the Transport and the Driver to execute on - * different threads. + * This class computes a hash of an IpAddress, so that IpAddress can be used + * as keys in unordered_maps. */ - class Packet { - public: - /// Packet's source or destination. When sending a Packet, the address - /// field will contain the destination Address. When receiving a Packet, - /// address field will contain the source Address. - Address address; - - /// Packet's network priority (send only); the lowest possible priority - /// is 0. The highest priority is positive number defined by the - /// Driver; the highest priority can be queried by calling the method - /// getHighestPacketPriority(). - int priority; - - /// Pointer to an array of bytes containing the payload of this Packet. - /// This array is valid until the Packet is released back to the Driver. - void* const payload; + struct Hasher { + /// Return a "hash" of the given IpAddress. + std::size_t operator()(const IpAddress& address) const + { + return std::hash{}(address.addr); + } + }; - /// Number of bytes in the payload. - int length; + static std::string toString(IpAddress address); + static IpAddress fromString(const char* addressStr); +}; +static_assert(std::is_trivially_copyable()); - /// Return the maximum number of bytes the payload can hold. - virtual int getMaxPayloadSize() = 0; +/** + * Represents a packet of data that can be send or is received over the network. + * A Packet logically contains only the transport-layer (L4) Homa header in + * addition to application data. + * + * This struct specifies the minimal object layout of a packet that the core + * Homa protocol depends on (e.g., Homa::Core::{Sender, Receiver}); this is + * useful for applications that only want to use the transport layer of this + * library and have their own infrastructures for sending and receiving packets. + */ +struct PacketSpec { + /// Pointer to an array of bytes containing the payload of this Packet. + /// This array is valid until the Packet is released back to the Driver. + void* payload; - protected: - /** - * Construct a Packet. - */ - explicit Packet(void* payload, int length = 0) - : address() - , priority(0) - , payload(payload) - , length(length) - {} + /// Number of bytes in the payload. + int32_t length; +} __attribute__((packed)); +static_assert(std::is_trivial()); - // DISALLOW_COPY_AND_ASSIGN - Packet(const Packet&) = delete; - Packet& operator=(const Packet&) = delete; - }; +/** + * Used by Homa::Transport to send and receive unreliable datagrams. Provides + * the interface to which all Driver implementations must conform. + * + * Implementations of this class should be thread-safe. + */ +class Driver { + public: + /// Import PacketSpec into the Driver namespace. + using Packet = PacketSpec; /** * Driver destructor. */ virtual ~Driver() = default; - /** - * Return a Driver specific network address for the given string - * representation of the address. - * - * @param addressString - * The string representation of the address to return. The address - * string format can be Driver specific. - * - * @return - * Address that can be the source or destination of a Packet. - * - * @throw BadAddress - * _addressString_ is malformed. - */ - virtual Address getAddress(std::string const* const addressString) = 0; - - /** - * Return a Driver specific network address for the given serialized - * byte-format of the address. - * - * @param wireAddress - * The serialized byte-format of the address to be returned. The - * format can be Driver specific. - * - * @return - * Address that can be the source or destination of a Packet. - * - * @throw BadAddress - * _rawAddress_ is malformed. - */ - virtual Address getAddress(WireFormatAddress const* const wireAddress) = 0; - - /** - * Return the string representation of a network address. - * - * @param address - * Address whose string representation should be returned. - */ - virtual std::string addressToString(const Address address) = 0; - - /** - * Serialize a network address into its Raw byte format. - * - * @param address - * Address to be serialized. - * @param[out] wireAddress - * WireFormatAddress object to which the Address is serialized. - */ - virtual void addressToWireFormat(const Address address, - WireFormatAddress* wireAddress) = 0; - /** * Allocate a new Packet object from the Driver's pool of resources. The * caller must eventually release the packet by passing it to a call to @@ -187,8 +131,16 @@ class Driver { * * @param packet * Packet to be sent over the network. + * @param destination + * IP address of the packet destination. + * @param priority + * Packet's network priority; the lowest possible priority is 0. + * The highest priority is positive number defined by the Driver; + * the highest priority can be queried by calling the method + * getHighestPacketPriority(). */ - virtual void sendPacket(Packet* packet) = 0; + virtual void sendPacket(Packet* packet, IpAddress destination, + int priority) = 0; /** * Request that the Driver enter the "corked" mode where outbound packets @@ -218,9 +170,12 @@ class Driver { * * @param maxPackets * The maximum number of Packet objects that should be returned by - * this method. + * this method. * @param[out] receivedPackets * Received packets are appended to this array in order of arrival. + * @param[out] sourceAddresses + * Source IP addresses of the received packets are appended to this + * array in order of arrival. * * @return * Number of Packet objects being returned. @@ -228,7 +183,8 @@ class Driver { * @sa Driver::releasePackets() */ virtual uint32_t receivePackets(uint32_t maxPackets, - Packet* receivedPackets[]) = 0; + Packet* receivedPackets[], + IpAddress sourceAddresses[]) = 0; /** * Release a collection of Packet objects back to the Driver. Every @@ -273,10 +229,10 @@ class Driver { virtual uint32_t getBandwidth() = 0; /** - * Return this Driver's local network Address which it uses as the source - * Address for outgoing packets. + * Return this Driver's local IP address which it uses as the source + * address for outgoing packets. */ - virtual Address getLocalAddress() = 0; + virtual IpAddress getLocalAddress() = 0; /** * Return the number of bytes that have been passed to the Driver through diff --git a/include/Homa/Drivers/DPDK/DpdkDriver.h b/include/Homa/Drivers/DPDK/DpdkDriver.h index dafb05f..f15d575 100644 --- a/include/Homa/Drivers/DPDK/DpdkDriver.h +++ b/include/Homa/Drivers/DPDK/DpdkDriver.h @@ -53,14 +53,14 @@ class DpdkDriver : public Driver { * has exclusive access to DPDK. Note: This call will initialize the DPDK * EAL with default values. * - * @param port - * Selects which physical port to use for communication. + * @param ifname + * Selects which network interface to use for communication. * @param config * Optional configuration parameters (see Config). * @throw DriverInitFailure * Thrown if DpdkDriver fails to initialize for any reason. */ - DpdkDriver(int port, const Config* const config = nullptr); + DpdkDriver(const char* ifname, const Config* const config = nullptr); /** * Construct a DpdkDriver and initialize the DPDK EAL using the provided @@ -75,7 +75,7 @@ class DpdkDriver : public Driver { * overriding the default affinity set by rte_eal_init(). * * @param port - * Selects which physical port to use for communication. + * Selects which network interface to use for communication. * @param argc * Parameter passed to rte_eal_init(). * @param argv @@ -85,7 +85,7 @@ class DpdkDriver : public Driver { * @throw DriverInitFailure * Thrown if DpdkDriver fails to initialize for any reason. */ - DpdkDriver(int port, int argc, char* argv[], + DpdkDriver(const char* ifname, int argc, char* argv[], const Config* const config = nullptr); /// Used to signal to the DpdkDriver constructor that the DPDK EAL should @@ -101,7 +101,7 @@ class DpdkDriver : public Driver { * called before calling this constructor. * * @param port - * Selects which physical port to use for communication. + * Selects which network interface to use for communication. * @param _ * Parameter is used only to define this constructors alternate * signature. @@ -110,29 +110,20 @@ class DpdkDriver : public Driver { * @throw DriverInitFailure * Thrown if DpdkDriver fails to initialize for any reason. */ - DpdkDriver(int port, NoEalInit _, const Config* const config = nullptr); + DpdkDriver(const char* ifname, NoEalInit _, + const Config* const config = nullptr); /** * DpdkDriver Destructor. */ virtual ~DpdkDriver(); - /// See Driver::getAddress() - virtual Address getAddress(std::string const* const addressString); - virtual Address getAddress(WireFormatAddress const* const wireAddress); - - /// See Driver::addressToString() - virtual std::string addressToString(const Address address); - - /// See Driver::addressToWireFormat() - virtual void addressToWireFormat(const Address address, - WireFormatAddress* wireAddress); - /// See Driver::allocPacket() virtual Packet* allocPacket(); /// See Driver::sendPacket() - virtual void sendPacket(Packet* packet); + virtual void sendPacket(Packet* packet, IpAddress destination, + int priority); /// See Driver::cork() virtual void cork(); @@ -142,7 +133,8 @@ class DpdkDriver : public Driver { /// See Driver::receivePackets() virtual uint32_t receivePackets(uint32_t maxPackets, - Packet* receivedPackets[]); + Packet* receivedPackets[], + IpAddress sourceAddresses[]); /// See Driver::releasePackets() virtual void releasePackets(Packet* packets[], uint16_t numPackets); @@ -157,7 +149,7 @@ class DpdkDriver : public Driver { virtual uint32_t getBandwidth(); /// See Driver::getLocalAddress() - virtual Driver::Address getLocalAddress(); + virtual IpAddress getLocalAddress(); /// See Driver::getQueuedBytes(); virtual uint32_t getQueuedBytes(); diff --git a/include/Homa/Drivers/Fake/FakeDriver.h b/include/Homa/Drivers/Fake/FakeDriver.h index 8413778..5f54586 100644 --- a/include/Homa/Drivers/Fake/FakeDriver.h +++ b/include/Homa/Drivers/Fake/FakeDriver.h @@ -34,7 +34,7 @@ const int NUM_PRIORITIES = 8; /// Maximum number of bytes a packet can hold. const uint32_t MAX_PAYLOAD_SIZE = 1500; -/// A set of methods to contol the underlying FakeNetwork's behavior. +/// A set of methods to control the underlying FakeNetwork's behavior. namespace FakeNetworkConfig { /** * Configure the FakeNetwork to drop packets at the specified loss rate. @@ -51,43 +51,35 @@ void setPacketLossRate(double lossRate); * * @sa Driver::Packet */ -class FakePacket : public Driver::Packet { - public: +struct FakePacket { + /// C-style "inheritance"; used to maintain the base struct as a POD type. + Driver::Packet base; + + /// Raw storage for this packets payload. + char buf[MAX_PAYLOAD_SIZE]; + + /// Source IpAddress of the packet. + IpAddress sourceIp; + /** * FakePacket constructor. - * - * @param maxPayloadSize - * The maximum number of bytes this packet can hold. */ explicit FakePacket() - : Packet(buf, 0) + : base{.payload = buf, .length = 0} + , buf() + , sourceIp() {} /** * Copy constructor. */ FakePacket(const FakePacket& other) - : Packet(buf, other.length) + : base{.payload = buf, .length = other.base.length} + , buf() + , sourceIp() { - address = other.address; - priority = other.priority; - memcpy(buf, other.buf, MAX_PAYLOAD_SIZE); + memcpy(base.payload, other.base.payload, MAX_PAYLOAD_SIZE); } - - virtual ~FakePacket() {} - - /// see Driver::Packet::getMaxPayloadSize() - virtual int getMaxPayloadSize() - { - return MAX_PAYLOAD_SIZE; - } - - private: - /// Raw storage for this packets payload. - char buf[MAX_PAYLOAD_SIZE]; - - // Disable Assignment - FakePacket& operator=(const FakePacket&) = delete; }; /// Holds the incoming packets for a particular driver. @@ -117,25 +109,22 @@ class FakeDriver : public Driver { */ virtual ~FakeDriver(); - virtual Address getAddress(std::string const* const addressString); - virtual Address getAddress(WireFormatAddress const* const wireAddress); - virtual std::string addressToString(const Address address); - virtual void addressToWireFormat(const Address address, - WireFormatAddress* wireAddress); virtual Packet* allocPacket(); - virtual void sendPacket(Packet* packet); + virtual void sendPacket(Packet* packet, IpAddress destination, + int priority); virtual uint32_t receivePackets(uint32_t maxPackets, - Packet* receivedPackets[]); + Packet* receivedPackets[], + IpAddress sourceAddresses[]); virtual void releasePackets(Packet* packets[], uint16_t numPackets); virtual int getHighestPacketPriority(); virtual uint32_t getMaxPayloadSize(); virtual uint32_t getBandwidth(); - virtual Address getLocalAddress(); + virtual IpAddress getLocalAddress(); virtual uint32_t getQueuedBytes(); private: /// Identifier for this driver on the fake network. - uint64_t localAddressId; + uint32_t localAddressId; /// Holds the incoming packets for this driver. FakeNIC nic; diff --git a/include/Homa/Homa.h b/include/Homa/Homa.h index dec090c..b430118 100644 --- a/include/Homa/Homa.h +++ b/include/Homa/Homa.h @@ -37,6 +37,17 @@ namespace Homa { template using unique_ptr = std::unique_ptr; +/** + * Represents a socket address to (from) which we can send (receive) messages. + */ +struct SocketAddress { + /// IPv4 address in host byte order. + IpAddress ip; + + /// Port number in host byte order. + uint16_t port; +}; + /** * Represents an array of bytes that has been received over the network. * @@ -60,51 +71,19 @@ class InMessage { virtual void acknowledge() const = 0; /** - * Returns true if the sender is no longer waiting for this message to be - * processed; false otherwise. - */ - virtual bool dropped() const = 0; - - /** - * Inform the sender that this message has failed to be processed. - */ - virtual void fail() const = 0; - - /** - * Get the contents of a specified range of bytes in the Message by - * copying them into the provided destination memory region. - * - * @param offset - * The number of bytes in the Message preceding the range of bytes - * being requested. - * @param destination - * The pointer to the memory region into which the requested byte - * range will be copied. The caller must ensure that the buffer is - * big enough to hold the requested number of bytes. - * @param count - * The number of bytes being requested. + * Get the underlying contiguous memory buffer serving as message storage. + * The buffer will be large enough to hold at least length() bytes. * * @return - * The number of bytes actually copied out. This number may be less - * than "num" if the requested byte range exceeds the range of - * bytes in the Message. + * Pointer to the message buffer. */ - virtual size_t get(size_t offset, void* destination, - size_t count) const = 0; + virtual void* data() const = 0; /** * Return the number of bytes this Message contains. */ virtual size_t length() const = 0; - /** - * Remove a number of bytes from the beginning of the Message. - * - * @param count - * Number of bytes to remove. - */ - virtual void strip(size_t count) = 0; - protected: /** * Signal that this message is no longer needed. The caller should not @@ -136,15 +115,17 @@ class OutMessage { * Options with which an OutMessage can be sent. */ enum Options { - NONE = 0, //< Default send behavior. - NO_RETRY = 1 << 0, //< Message will not be resent if recoverable send - //< failure occurs; provides at-most-once delivery - //< of messages. - NO_KEEP_ALIVE = 1 << 1, //< Once the Message has been sent, Homa will - //< not automatically ping the Message's - //< receiver to ensure the receiver is still - //< alive and the Message will not "timeout" - //< due to receiver inactivity. + /// Default send behavior. + NONE = 0, + + /// Message will not be resent if recoverable send failure occurs; + /// provides at-most-once delivery of messages. + NO_RETRY = 1 << 0, + + /// Once the Message has been sent, Homa will not automatically ping the + /// Message's receiver to ensure the receiver is still alive and the + /// Message will not "timeout" due to receiver inactivity. + NO_KEEP_ALIVE = 1 << 1, }; /** @@ -220,11 +201,11 @@ class OutMessage { * Send this message to the destination. * * @param destination - * Address of the transport to which this message will be sent. + * Network address to which this message will be sent. * @param options * Flags to request non-default sending behavior. */ - virtual void send(Driver::Address destination, + virtual void send(SocketAddress destination, Options options = Options::NONE) = 0; protected: @@ -265,10 +246,12 @@ class Transport { /** * Allocate Message that can be sent with this Transport. * + * @param sourcePort + * Port number of the socket from which the message will be sent. * @return * A pointer to the allocated message. */ - virtual Homa::unique_ptr alloc() = 0; + virtual Homa::unique_ptr alloc(uint16_t sourcePort) = 0; /** * Check for and return a Message sent to this Transport if available. diff --git a/include/Homa/Perf.h b/include/Homa/Perf.h index a163539..f213acd 100644 --- a/include/Homa/Perf.h +++ b/include/Homa/Perf.h @@ -38,6 +38,27 @@ struct Stats { /// CPU time spent running Homa with no work to do in cycles. uint64_t idle_cycles; + /// Number of InMessages that have been allocated by the Transport. + uint64_t allocated_rx_messages; + + /// Number of InMessages that have been received by the Transport. + uint64_t received_rx_messages; + + /// Number of InMessages delivered to the application. + uint64_t delivered_rx_messages; + + /// Number of InMessages released back to the Transport for destruction. + uint64_t destroyed_rx_messages; + + /// Number of OutMessages allocated for the application. + uint64_t allocated_tx_messages; + + /// Number of OutMessages released back to the transport. + uint64_t released_tx_messages; + + /// Number of OutMessages destroyed. + uint64_t destroyed_tx_messages; + /// Number of bytes sent by the transport. uint64_t tx_bytes; diff --git a/include/Homa/Util.h b/include/Homa/Util.h index 121bb44..462e1ff 100644 --- a/include/Homa/Util.h +++ b/include/Homa/Util.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2009-2018, Stanford University +/* Copyright (c) 2009-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -21,6 +21,14 @@ #include #include +/// Cast a member of a structure out to the containing structure. +template +P* +container_of(M* ptr, const M P::*member) +{ + return (P*)((char*)ptr - (size_t) & (reinterpret_cast(0)->*member)); +} + namespace Homa { namespace Util { @@ -52,6 +60,30 @@ downCast(const Large& large) std::string demangle(const char* name); std::string hexDump(const void* buf, uint64_t bytes); +/** + * Return true if the given number is a power of 2; false, otherwise. + */ +template +constexpr bool +isPowerOfTwo(num_type n) +{ + return (n > 0) && ((n & (n - 1)) == 0); +} + +/** + * Round up the result of x divided by y, where both x and y are positive + * integers. + */ +template +constexpr num_type_x +roundUpIntDiv(num_type_x x, num_type_y y) +{ + static_assert(std::is_integral::value, "Integral required."); + assert(x > 0 && y > 0); + num_type_x yy = downCast(y); + return (x + yy - 1) / yy; +} + /** * This class is used to temporarily release lock in a safe fashion. Creating * an object of this class will unlock its associated mutex; when the object diff --git a/src/ControlPacket.h b/src/ControlPacket.h index a8da070..17310af 100644 --- a/src/ControlPacket.h +++ b/src/ControlPacket.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -31,21 +31,19 @@ namespace ControlPacket { * @param driver * Driver with which to send the packet. * @param address - * Destination address for the packet to be sent. + * Destination IP address for the packet to be sent. * @param args * Arguments to PacketHeaderType's constructor. */ template void -send(Driver* driver, Driver::Address address, Args&&... args) +send(Driver* driver, IpAddress address, Args&&... args) { Driver::Packet* packet = driver->allocPacket(); new (packet->payload) PacketHeaderType(static_cast(args)...); packet->length = sizeof(PacketHeaderType); - packet->address = address; - packet->priority = driver->getHighestPacketPriority(); Perf::counters.tx_bytes.add(packet->length); - driver->sendPacket(packet); + driver->sendPacket(packet, address, driver->getHighestPacketPriority()); driver->releasePackets(&packet, 1); } diff --git a/src/Drivers/RawAddressType.h b/src/Driver.cc similarity index 57% rename from src/Drivers/RawAddressType.h rename to src/Driver.cc index 1def76d..b29c828 100644 --- a/src/Drivers/RawAddressType.h +++ b/src/Driver.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2018-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -13,26 +13,26 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ -#ifndef HOMA_DRIVERS_RAWADDRESSTYPE_H -#define HOMA_DRIVERS_RAWADDRESSTYPE_H +#include + +#include "StringUtil.h" namespace Homa { -namespace Drivers { -/** - * Identifies a particular raw serialized byte-format for a Driver::Address - * supported by this project. The types are enumerated here in one place to - * ensure drivers do have overlapping type identifiers. New drivers that wish - * to claim a type id should add an entry to this enum. - * - * @sa Driver::Address::Raw - */ -enum RawAddressType { - FAKE = 0, - MAC = 1, -}; +std::string +IpAddress::toString(IpAddress address) +{ + uint32_t ip = address.addr; + return StringUtil::format("%d.%d.%d.%d", (ip >> 24) & 0xff, + (ip >> 16) & 0xff, (ip >> 8) & 0xff, ip & 0xff); +} -} // namespace Drivers -} // namespace Homa +IpAddress +IpAddress::fromString(const char* addressStr) +{ + unsigned int b0, b1, b2, b3; + sscanf(addressStr, "%u.%u.%u.%u", &b0, &b1, &b2, &b3); + return IpAddress{(b0 << 24u) | (b1 << 16u) | (b2 << 8u) | b3}; +} -#endif // HOMA_DRIVERS_RAWADDRESSTYPE_H +} // namespace Homa diff --git a/src/Drivers/DPDK/DpdkDriver.cc b/src/Drivers/DPDK/DpdkDriver.cc index c536159..c27d1df 100644 --- a/src/Drivers/DPDK/DpdkDriver.cc +++ b/src/Drivers/DPDK/DpdkDriver.cc @@ -21,50 +21,22 @@ namespace Homa { namespace Drivers { namespace DPDK { -DpdkDriver::DpdkDriver(int port, const Config* const config) - : pImpl(new Impl(port, config)) +DpdkDriver::DpdkDriver(const char* ifname, const Config* const config) + : pImpl(new Impl(ifname, config)) {} -DpdkDriver::DpdkDriver(int port, int argc, char* argv[], +DpdkDriver::DpdkDriver(const char* ifname, int argc, char* argv[], const Config* const config) - : pImpl(new Impl(port, argc, argv, config)) + : pImpl(new Impl(ifname, argc, argv, config)) {} -DpdkDriver::DpdkDriver(int port, NoEalInit _, const Config* const config) - : pImpl(new Impl(port, _, config)) +DpdkDriver::DpdkDriver(const char* ifname, NoEalInit _, + const Config* const config) + : pImpl(new Impl(ifname, _, config)) {} DpdkDriver::~DpdkDriver() = default; -/// See Driver::getAddress() -Driver::Address -DpdkDriver::getAddress(std::string const* const addressString) -{ - return pImpl->getAddress(addressString); -} - -/// See Driver::getAddress() -Driver::Address -DpdkDriver::getAddress(WireFormatAddress const* const wireAddress) -{ - return pImpl->getAddress(wireAddress); -} - -/// See Driver::addressToString() -std::string -DpdkDriver::addressToString(const Address address) -{ - return pImpl->addressToString(address); -} - -/// See Driver::addressToWireFormat() -void -DpdkDriver::addressToWireFormat(const Address address, - WireFormatAddress* wireAddress) -{ - pImpl->addressToWireFormat(address, wireAddress); -} - /// See Driver::allocPacket() Driver::Packet* DpdkDriver::allocPacket() @@ -74,9 +46,9 @@ DpdkDriver::allocPacket() /// See Driver::sendPacket() void -DpdkDriver::sendPacket(Packet* packet) +DpdkDriver::sendPacket(Packet* packet, IpAddress destination, int priority) { - return pImpl->sendPacket(packet); + return pImpl->sendPacket(packet, destination, priority); } /// See Driver::cork() @@ -95,9 +67,10 @@ DpdkDriver::uncork() /// See Driver::receivePackets() uint32_t -DpdkDriver::receivePackets(uint32_t maxPackets, Packet* receivedPackets[]) +DpdkDriver::receivePackets(uint32_t maxPackets, Packet* receivedPackets[], + IpAddress sourceAddresses[]) { - return pImpl->receivePackets(maxPackets, receivedPackets); + return pImpl->receivePackets(maxPackets, receivedPackets, sourceAddresses); } /// See Driver::releasePackets() void @@ -128,7 +101,7 @@ DpdkDriver::getBandwidth() } /// See Driver::getLocalAddress() -Driver::Address +IpAddress DpdkDriver::getLocalAddress() { return pImpl->getLocalAddress(); diff --git a/src/Drivers/DPDK/DpdkDriverImpl.cc b/src/Drivers/DPDK/DpdkDriverImpl.cc index e658ccb..b4372ad 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.cc +++ b/src/Drivers/DPDK/DpdkDriverImpl.cc @@ -19,10 +19,16 @@ #include "DpdkDriverImpl.h" +#include +#include #include +#include #include +#include + #include "CodeLocation.h" +#include "Homa/Util.h" #include "StringUtil.h" namespace Homa { @@ -45,7 +51,7 @@ const char* default_eal_argv[] = {"homa", NULL}; * Memory location in the mbuf where the packet data should be stored. */ DpdkDriver::Impl::Packet::Packet(struct rte_mbuf* mbuf, void* data) - : Driver::Packet(data, 0) + : base{.payload = data, .length = 0} , bufType(MBUF) , bufRef() { @@ -59,7 +65,7 @@ DpdkDriver::Impl::Packet::Packet(struct rte_mbuf* mbuf, void* data) * Overflow buffer that holds this packet. */ DpdkDriver::Impl::Packet::Packet(OverflowBuffer* overflowBuf) - : Driver::Packet(overflowBuf->data, 0) + : base{.payload = overflowBuf->data, .length = 0} , bufType(OVERFLOW_BUF) , bufRef() { @@ -69,17 +75,21 @@ DpdkDriver::Impl::Packet::Packet(OverflowBuffer* overflowBuf) /** * See DpdkDriver::DpdkDriver() */ -DpdkDriver::Impl::Impl(int port, const Config* const config) - : Impl(port, default_eal_argc, const_cast(default_eal_argv), config) +DpdkDriver::Impl::Impl(const char* ifname, const Config* const config) + : Impl(ifname, default_eal_argc, const_cast(default_eal_argv), + config) {} /** * See DpdkDriver::DpdkDriver() */ -DpdkDriver::Impl::Impl(int port, int argc, char* argv[], +DpdkDriver::Impl::Impl(const char* ifname, int argc, char* argv[], const Config* const config) - : port(port) - , localMac(Driver::Address(0)) + : ifname(ifname) + , port() + , arpTable() + , localIp() + , localMac("00:00:00:00:00:00") , HIGHEST_PACKET_PRIORITY( (config == nullptr || config->HIGHEST_PACKET_PRIORITY_OVERRIDE < 0) ? Homa::Util::arrayLength(PRIORITY_TO_PCP) - 1 @@ -87,6 +97,7 @@ DpdkDriver::Impl::Impl(int port, int argc, char* argv[], , packetLock() , packetPool() , overflowBufferPool() + , mbufsOutstanding(0) , mbufPool(nullptr) , loopbackRing(nullptr) , rx() @@ -124,10 +135,14 @@ DpdkDriver::Impl::Impl(int port, int argc, char* argv[], /** * See DpdkDriver::DpdkDriver() */ -DpdkDriver::Impl::Impl(int port, __attribute__((__unused__)) NoEalInit _, +DpdkDriver::Impl::Impl(const char* ifname, + __attribute__((__unused__)) NoEalInit _, const Config* const config) - : port(port) - , localMac(Driver::Address(0)) + : ifname(ifname) + , port() + , arpTable() + , localIp() + , localMac("00:00:00:00:00:00") , HIGHEST_PACKET_PRIORITY( (config == nullptr || config->HIGHEST_PACKET_PRIORITY_OVERRIDE < 0) ? Homa::Util::arrayLength(PRIORITY_TO_PCP) - 1 @@ -159,55 +174,51 @@ DpdkDriver::Impl::~Impl() rte_mempool_free(mbufPool); } -// See Driver::getAddress() -Driver::Address -DpdkDriver::Impl::getAddress(std::string const* const addressString) -{ - return MacAddress(addressString->c_str()).toAddress(); -} - -// See Driver::getAddress() -Driver::Address -DpdkDriver::Impl::getAddress(Driver::WireFormatAddress const* const wireAddress) -{ - return MacAddress(wireAddress).toAddress(); -} - -/// See Driver::addressToString() -std::string -DpdkDriver::Impl::addressToString(const Driver::Address address) -{ - return MacAddress(address).toString(); -} - -/// See Driver::addressToWireFormat() -void -DpdkDriver::Impl::addressToWireFormat(const Driver::Address address, - Driver::WireFormatAddress* wireAddress) -{ - MacAddress(address).toWireFormat(wireAddress); -} - // See Driver::allocPacket() -DpdkDriver::Impl::Packet* +Driver::Packet* DpdkDriver::Impl::allocPacket() { - DpdkDriver::Impl::Packet* packet = _allocMbufPacket(); - if (unlikely(packet == nullptr)) { - SpinLock::Lock lock(packetLock); + DpdkDriver::Impl::Packet* packet = nullptr; + SpinLock::Lock lock(packetLock); + static const int MBUF_ALLOC_LIMIT = NB_MBUF - NB_MBUF_RESERVED; + if (mbufsOutstanding < MBUF_ALLOC_LIMIT) { + struct rte_mbuf* mbuf = rte_pktmbuf_alloc(mbufPool); + if (unlikely(NULL == mbuf)) { + uint32_t numMbufsAvail = rte_mempool_avail_count(mbufPool); + uint32_t numMbufsInUse = rte_mempool_in_use_count(mbufPool); + NOTICE( + "Failed to allocate an mbuf packet buffer; " + "%u mbufs available, %u mbufs in use, %u mbufs held by app", + numMbufsAvail, numMbufsInUse, mbufsOutstanding); + } else { + char* buf = rte_pktmbuf_append( + mbuf, Homa::Util::downCast(PACKET_HDR_LEN + + MAX_PAYLOAD_SIZE)); + if (unlikely(NULL == buf)) { + NOTICE("rte_pktmbuf_append call failed; dropping packet"); + rte_pktmbuf_free(mbuf); + } else { + packet = packetPool.construct(mbuf, buf + PACKET_HDR_LEN); + mbufsOutstanding++; + } + } + } + if (packet == nullptr) { OverflowBuffer* buf = overflowBufferPool.construct(); packet = packetPool.construct(buf); NOTICE("OverflowBuffer used."); } - return packet; + return &packet->base; } // See Driver::sendPacket() void -DpdkDriver::Impl::sendPacket(Driver::Packet* packet) +DpdkDriver::Impl::sendPacket(Driver::Packet* packet, IpAddress destination, + int priority) { + ; DpdkDriver::Impl::Packet* pkt = - static_cast(packet); + container_of(packet, DpdkDriver::Impl::Packet, base); struct rte_mbuf* mbuf = nullptr; // If the packet is held in an Overflow buffer, we need to copy it out // into a new mbuf. @@ -224,14 +235,15 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet) return; } char* buf = rte_pktmbuf_append( - mbuf, Homa::Util::downCast(PACKET_HDR_LEN + pkt->length)); + mbuf, + Homa::Util::downCast(PACKET_HDR_LEN + pkt->base.length)); if (unlikely(NULL == buf)) { WARNING("rte_pktmbuf_append call failed; dropping packet"); rte_pktmbuf_free(mbuf); return; } char* data = buf + PACKET_HDR_LEN; - rte_memcpy(data, pkt->payload, pkt->length); + rte_memcpy(data, pkt->base.payload, pkt->base.length); } else { mbuf = pkt->bufRef.mbuf; @@ -246,9 +258,14 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet) // Fill out the destination and source MAC addresses plus the Ethernet // frame type (i.e., IEEE 802.1Q VLAN tagging). - MacAddress macAddr(pkt->address); + auto it = arpTable.find(destination); + if (it == arpTable.end()) { + WARNING("Failed to find ARP record for packet; dropping packet"); + return; + } + MacAddress& destMac = it->second; struct ether_hdr* ethHdr = rte_pktmbuf_mtod(mbuf, struct ether_hdr*); - rte_memcpy(ðHdr->d_addr, macAddr.address, ETHER_ADDR_LEN); + rte_memcpy(ðHdr->d_addr, destMac.address, ETHER_ADDR_LEN); rte_memcpy(ðHdr->s_addr, localMac.address, ETHER_ADDR_LEN); ethHdr->ether_type = rte_cpu_to_be_16(ETHER_TYPE_VLAN); @@ -256,13 +273,17 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet) // encapsulated frame (DEI and VLAN ID are not relevant and trivially // set to 0). struct vlan_hdr* vlanHdr = reinterpret_cast(ethHdr + 1); - vlanHdr->vlan_tci = rte_cpu_to_be_16(PRIORITY_TO_PCP[pkt->priority]); + vlanHdr->vlan_tci = rte_cpu_to_be_16(PRIORITY_TO_PCP[priority]); vlanHdr->eth_proto = rte_cpu_to_be_16(EthPayloadType::HOMA); + // Store our local IP address right before the payload. + *rte_pktmbuf_mtod_offset(mbuf, uint32_t*, PACKET_HDR_LEN - 4) = + (uint32_t)localIp; + // In the normal case, we pre-allocate a pakcet's mbuf with enough // storage to hold the MAX_PAYLOAD_SIZE. If the actual payload is // smaller, trim the mbuf to size to avoid sending unecessary bits. - uint32_t actualLength = PACKET_HDR_LEN + pkt->length; + uint32_t actualLength = PACKET_HDR_LEN + pkt->base.length; uint32_t mbufDataLength = rte_pktmbuf_pkt_len(mbuf); if (actualLength < mbufDataLength) { if (rte_pktmbuf_trim(mbuf, mbufDataLength - actualLength) < 0) { @@ -274,14 +295,18 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet) } // loopback if src mac == dst mac - if (localMac.toAddress() == pkt->address) { + if (localMac == destMac) { struct rte_mbuf* mbuf_clone = rte_pktmbuf_clone(mbuf, mbufPool); if (unlikely(mbuf_clone == NULL)) { WARNING("Failed to clone packet for loopback; dropping packet"); + return; } int ret = rte_ring_enqueue(loopbackRing, mbuf_clone); if (unlikely(ret != 0)) { - WARNING("rte_ring_enqueue returned %d; packet may be lost?", ret); + WARNING( + "rte_ring_enqueue returned %d with %u packets queued; " + "packet may be lost?", + ret, rte_ring_count(loopbackRing)); rte_pktmbuf_free(mbuf_clone); } return; @@ -327,7 +352,8 @@ DpdkDriver::Impl::uncork() // See Driver::receivePackets() uint32_t DpdkDriver::Impl::receivePackets(uint32_t maxPackets, - Driver::Packet* receivedPackets[]) + Driver::Packet* receivedPackets[], + IpAddress sourceAddresses[]) { uint32_t numPacketsReceived = 0; @@ -335,26 +361,26 @@ DpdkDriver::Impl::receivePackets(uint32_t maxPackets, if (maxPackets > MAX_PACKETS_AT_ONCE) { maxPackets = MAX_PACKETS_AT_ONCE; } + uint32_t maxLoopbackPkts = maxPackets / 2; + struct rte_mbuf* mPkts[MAX_PACKETS_AT_ONCE]; // attempt to dequeue a batch of received packets from the NIC // as well as from the loopback ring. - uint32_t incomingPkts = 0; uint32_t loopbackPkts = 0; + uint32_t incomingPkts = 0; { SpinLock::Lock lock(rx.mutex); - incomingPkts = rte_eth_rx_burst( - port, 0, mPkts, Homa::Util::downCast(maxPackets)); - loopbackPkts = rte_ring_count(loopbackRing); - if (incomingPkts + loopbackPkts > maxPackets) { - loopbackPkts = maxPackets - incomingPkts; - } + loopbackPkts = std::min(loopbackPkts, maxLoopbackPkts); for (uint32_t i = 0; i < loopbackPkts; i++) { - rte_ring_dequeue(loopbackRing, reinterpret_cast( - &mPkts[incomingPkts + i])); + rte_ring_dequeue(loopbackRing, reinterpret_cast(&mPkts[i])); } + + incomingPkts = rte_eth_rx_burst( + port, 0, &(mPkts[loopbackPkts]), + Homa::Util::downCast(maxPackets - loopbackPkts)); } uint32_t totalPkts = incomingPkts + loopbackPkts; @@ -390,6 +416,9 @@ DpdkDriver::Impl::receivePackets(uint32_t maxPackets, } } + uint32_t srcIp = *rte_pktmbuf_mtod_offset(m, uint32_t*, headerLength); + headerLength += sizeof(srcIp); + payload += sizeof(srcIp); assert(rte_pktmbuf_pkt_len(m) >= headerLength); uint32_t length = rte_pktmbuf_pkt_len(m) - headerLength; assert(length <= MAX_PAYLOAD_SIZE); @@ -397,12 +426,21 @@ DpdkDriver::Impl::receivePackets(uint32_t maxPackets, DpdkDriver::Impl::Packet* packet = nullptr; { SpinLock::Lock lock(packetLock); - packet = packetPool.construct(m, payload); + static const int MBUF_ALLOC_LIMIT = NB_MBUF - NB_MBUF_RESERVED; + if (mbufsOutstanding < MBUF_ALLOC_LIMIT) { + packet = packetPool.construct(m, payload); + mbufsOutstanding++; + } else { + OverflowBuffer* buf = overflowBufferPool.construct(); + rte_memcpy(payload, buf->data, length); + packet = packetPool.construct(buf); + } } - packet->address = MacAddress(ethHdr->s_addr.addr_bytes).toAddress(); - packet->length = length; + packet->base.length = length; - receivedPackets[numPacketsReceived++] = packet; + receivedPackets[numPacketsReceived] = &packet->base; + sourceAddresses[numPacketsReceived] = {srcIp}; + ++numPacketsReceived; } return numPacketsReceived; @@ -415,9 +453,10 @@ DpdkDriver::Impl::releasePackets(Driver::Packet* packets[], uint16_t numPackets) for (uint16_t i = 0; i < numPackets; ++i) { SpinLock::Lock lock(packetLock); DpdkDriver::Impl::Packet* packet = - static_cast(packets[i]); + container_of(packets[i], DpdkDriver::Impl::Packet, base); if (likely(packet->bufType == DpdkDriver::Impl::Packet::MBUF)) { rte_pktmbuf_free(packet->bufRef.mbuf); + mbufsOutstanding--; } else { overflowBufferPool.destroy(packet->bufRef.overflowBuf); } @@ -447,10 +486,10 @@ DpdkDriver::Impl::getBandwidth() } // See Driver::getLocalAddress() -Driver::Address +IpAddress DpdkDriver::Impl::getLocalAddress() { - return localMac.toAddress(); + return localIp; } // See Driver::getQueuedBytes(); @@ -490,11 +529,77 @@ DpdkDriver::Impl::_eal_init(int argc, char* argv[]) void DpdkDriver::Impl::_init() { - struct ether_addr mac; struct rte_eth_conf portConf; int ret; uint16_t mtu; + // Populate the ARP table with records in /proc/net/arp (inspired by + // net-tools/arp.c) + std::ifstream input("/proc/net/arp"); + for (std::string line; getline(input, line);) { + char ip[100]; + char hwa[100]; + char mask[100]; + char dev[100]; + int type, flags; + int cols = sscanf(line.c_str(), "%s 0x%x 0x%x %99s %99s %99s\n", ip, + &type, &flags, hwa, mask, dev); + if (cols != 6) + continue; + arpTable.emplace(IpAddress::fromString(ip), hwa); + } + + // Use ioctl to obtain the IP and MAC addresses of the network interface. + struct ifreq ifr; + ifname.copy(ifr.ifr_name, ifname.length()); + ifr.ifr_name[ifname.length() + 1] = 0; + if (ifname.length() >= sizeof(ifr.ifr_name)) { + throw DriverInitFailure( + HERE_STR, + StringUtil::format("Interface name %s too long", ifname.c_str())); + } + + int fd = socket(AF_INET, SOCK_DGRAM, 0); + if (fd == -1) { + throw DriverInitFailure( + HERE_STR, + StringUtil::format("Failed to create socket: %s", strerror(errno))); + } + + if (ioctl(fd, SIOCGIFADDR, &ifr) == -1) { + char* error = strerror(errno); + close(fd); + throw DriverInitFailure( + HERE_STR, + StringUtil::format("Failed to obtain IP address: %s", error)); + } + localIp = {be32toh(((struct sockaddr_in*)&ifr.ifr_addr)->sin_addr.s_addr)}; + + if (ioctl(fd, SIOCGIFHWADDR, &ifr) == -1) { + char* error = strerror(errno); + close(fd); + throw DriverInitFailure( + HERE_STR, + StringUtil::format("Failed to obtain MAC address: %s", error)); + } + close(fd); + memcpy(localMac.address, ifr.ifr_hwaddr.sa_data, 6); + + // Iterate over ethernet devices to locate the port identifier. + int p; + RTE_ETH_FOREACH_DEV(p) + { + struct ether_addr mac; + rte_eth_macaddr_get(p, &mac); + if (MacAddress(mac.addr_bytes) == localMac) { + port = p; + break; + } + } + NOTICE("Using interface %s, ip %s, mac %s, port %u", ifname.c_str(), + IpAddress::toString(localIp).c_str(), localMac.toString().c_str(), + port); + std::string poolName = StringUtil::format("homa_mbuf_pool_%u", port); std::string ringName = StringUtil::format("homa_loopback_ring_%u", port); @@ -518,10 +623,6 @@ DpdkDriver::Impl::_init() StringUtil::format("Ethernet port %u doesn't exist", port)); } - // Read the MAC address from the NIC via DPDK. - rte_eth_macaddr_get(port, &mac); - new (const_cast(&localMac)) MacAddress(mac.addr_bytes); - // configure some default NIC port parameters memset(&portConf, 0, sizeof(portConf)); portConf.rxmode.max_rx_pkt_len = ETHER_MAX_VLAN_FRAME_LEN; @@ -567,14 +668,6 @@ DpdkDriver::Impl::_init() "Cannot allocate buffer for tx on port %u", port)); } rte_eth_tx_buffer_init(tx.buffer, MAX_PKT_BURST); - ret = rte_eth_tx_buffer_set_err_callback(tx.buffer, txBurstErrorCallback, - &tx.stats); - if (ret < 0) { - throw DriverInitFailure( - HERE_STR, - StringUtil::format( - "Cannot set error callback for tx buffer on port %u", port)); - } // get the current MTU. ret = rte_eth_dev_get_mtu(port, &mtu); @@ -633,7 +726,8 @@ DpdkDriver::Impl::_init() // create an in-memory ring, used as a software loopback in order to // handle packets that are addressed to the localhost. - loopbackRing = rte_ring_create(ringName.c_str(), 4096, SOCKET_ID_ANY, 0); + loopbackRing = + rte_ring_create(ringName.c_str(), NB_LOOPBACK_SLOTS, SOCKET_ID_ANY, 0); if (NULL == loopbackRing) { throw DriverInitFailure( HERE_STR, StringUtil::format("Failed to allocate loopback ring: %s", @@ -644,57 +738,6 @@ DpdkDriver::Impl::_init() localMac.toString().c_str(), bandwidthMbps.load(), mtu); } -/** - * Helper function to try to allocation a new Dpdk Packet backed by an mbuf. - * - * @return - * The newly allocated Dpdk Packet; nullptr if the mbuf allocation - * failed. - */ -DpdkDriver::Impl::Packet* -DpdkDriver::Impl::_allocMbufPacket() -{ - DpdkDriver::Impl::Packet* packet = nullptr; - uint32_t numMbufsAvail = rte_mempool_avail_count(mbufPool); - if (unlikely(numMbufsAvail <= NB_MBUF_RESERVED)) { - uint32_t numMbufsInUse = rte_mempool_in_use_count(mbufPool); - NOTICE( - "Driver is running low on mbuf packet buffers; " - "%u mbufs available, %u mbufs in use", - numMbufsAvail, numMbufsInUse); - return nullptr; - } - - struct rte_mbuf* mbuf = rte_pktmbuf_alloc(mbufPool); - - if (unlikely(NULL == mbuf)) { - uint32_t numMbufsAvail = rte_mempool_avail_count(mbufPool); - uint32_t numMbufsInUse = rte_mempool_in_use_count(mbufPool); - NOTICE( - "Failed to allocate an mbuf packet buffer; " - "%u mbufs available, %u mbufs in use", - numMbufsAvail, numMbufsInUse); - return nullptr; - } - - char* buf = rte_pktmbuf_append( - mbuf, - Homa::Util::downCast(PACKET_HDR_LEN + MAX_PAYLOAD_SIZE)); - - if (unlikely(NULL == buf)) { - NOTICE("rte_pktmbuf_append call failed; dropping packet"); - rte_pktmbuf_free(mbuf); - return nullptr; - } - - // Perform packet operations with the lock held. - { - SpinLock::Lock _(packetLock); - packet = packetPool.construct(mbuf, buf + PACKET_HDR_LEN); - } - return packet; -} - /** * Called before a burst of packets is transmitted to update the transmit stats. * @@ -734,24 +777,6 @@ DpdkDriver::Impl::txBurstCallback(uint16_t port_id, uint16_t queue, return nb_pkts; } -/** - * Called to process the packets cannot be sent. - */ -void -DpdkDriver::Impl::txBurstErrorCallback(struct rte_mbuf* pkts[], uint16_t unsent, - void* userdata) -{ - uint64_t bytesDropped = 0; - for (int i = 0; i < unsent; ++i) { - bytesDropped += rte_pktmbuf_pkt_len(pkts[i]); - rte_pktmbuf_free(pkts[i]); - } - Tx::Stats* stats = static_cast(userdata); - SpinLock::Lock lock(stats->mutex); - assert(bytesDropped <= stats->bufferedBytes); - stats->bufferedBytes -= bytesDropped; -} - } // namespace DPDK } // namespace Drivers } // namespace Homa diff --git a/src/Drivers/DPDK/DpdkDriverImpl.h b/src/Drivers/DPDK/DpdkDriverImpl.h index 289e83f..9425d34 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.h +++ b/src/Drivers/DPDK/DpdkDriverImpl.h @@ -28,6 +28,7 @@ #include #include +#include #include "MacAddress.h" #include "ObjectPool.h" @@ -44,7 +45,7 @@ const int NDESC = 256; // Maximum number of packet buffers that the memory pool can hold. The // documentation of `rte_mempool_create` suggests that the optimum value // (in terms of memory usage) of this number is a power of two minus one. -const int NB_MBUF = 8191; +const int NB_MBUF = 16383; // If cache_size is non-zero, the rte_mempool library will try to limit the // accesses to the common lockless pool, by maintaining a per-lcore object @@ -56,7 +57,10 @@ const int MEMPOOL_CACHE_SIZE = 32; // The number of mbufs the driver should try to reserve for receiving packets. // Prevents applications from claiming more mbufs once the number of available // mbufs reaches this level. -const uint32_t NB_MBUF_RESERVED = 1024; +const uint32_t NB_MBUF_RESERVED = 4096; + +// The number of packets that can be held in loopback before they get dropped +const uint32_t NB_LOOPBACK_SLOTS = 4096; // The number of packets that the driver can buffer while corked. const uint16_t MAX_PKT_BURST = 32; @@ -65,8 +69,13 @@ const uint16_t MAX_PKT_BURST = 32; /// field defined in the VLAN tag to specify the packet priority. const uint32_t VLAN_TAG_LEN = 4; -// Size of Ethernet header including VLAN tag, in bytes. -const uint32_t PACKET_HDR_LEN = ETHER_HDR_LEN + VLAN_TAG_LEN; +/// Strictly speaking, this DPDK driver is supposed to send/receive IP packets; +/// however, it currently only records the source IP address right after the +/// Ethernet header for simplicity. +const uint32_t IP_HDR_LEN = sizeof(IpAddress); + +// Size of Ethernet header including VLAN tag plus IP header, in bytes. +const uint32_t PACKET_HDR_LEN = ETHER_HDR_LEN + VLAN_TAG_LEN + IP_HDR_LEN; // The MTU (Maximum Transmission Unit) size of an Ethernet frame, which is the // maximum size of the packet an Ethernet frame can carry in its payload. This @@ -104,16 +113,12 @@ class DpdkDriver::Impl { * Dpdk specific Packet object used to track a its lifetime and * contents. */ - class Packet : public Driver::Packet { - public: + struct Packet { explicit Packet(struct rte_mbuf* mbuf, void* data); explicit Packet(OverflowBuffer* overflowBuf); - /// see Driver::Packet::getMaxPayloadSize() - virtual int getMaxPayloadSize() - { - return MAX_PAYLOAD_SIZE; - } + /// C-style "inheritance" + Driver::Packet base; /// Used to indicate whether the packet is backed by an DPDK mbuf or a /// driver-level OverflowBuffer. @@ -128,53 +133,52 @@ class DpdkDriver::Impl { /// The memory location of this packet's header. The header should be /// PACKET_HDR_LEN in length. void* header; - - private: - Packet(const Packet&) = delete; - Packet& operator=(const Packet&) = delete; }; - Impl(int port, const Config* const config = nullptr); - Impl(int port, int argc, char* argv[], + Impl(const char* ifname, const Config* const config = nullptr); + Impl(const char* ifname, int argc, char* argv[], const Config* const config = nullptr); - Impl(int port, NoEalInit _, const Config* const config = nullptr); + Impl(const char* ifname, NoEalInit _, const Config* const config = nullptr); virtual ~Impl(); // Interface Methods - Driver::Address getAddress(std::string const* const addressString); - Driver::Address getAddress(WireFormatAddress const* const wireAddress); - std::string addressToString(const Address address); - void addressToWireFormat(const Address address, - WireFormatAddress* wireAddress); - Packet* allocPacket(); - void sendPacket(Driver::Packet* packet); + Driver::Packet* allocPacket(); + void sendPacket(Driver::Packet* packet, IpAddress destination, + int priority); void cork(); void uncork(); uint32_t receivePackets(uint32_t maxPackets, - Driver::Packet* receivedPackets[]); + Driver::Packet* receivedPackets[], + IpAddress sourceAddresses[]); void releasePackets(Driver::Packet* packets[], uint16_t numPackets); int getHighestPacketPriority(); uint32_t getMaxPayloadSize(); uint32_t getBandwidth(); - Driver::Address getLocalAddress(); + IpAddress getLocalAddress(); uint32_t getQueuedBytes(); private: void _eal_init(int argc, char* argv[]); void _init(); - Packet* _allocMbufPacket(); static uint16_t txBurstCallback(uint16_t port_id, uint16_t queue, struct rte_mbuf* pkts[], uint16_t nb_pkts, void* user_param); - static void txBurstErrorCallback(struct rte_mbuf* pkts[], uint16_t unsent, - void* userdata); + + /// Name of the Linux network interface to be used by DPDK. + std::string ifname; /// Stores the NIC's physical port id addressed by the instantiated /// driver. - const uint16_t port; + uint16_t port; - /// Stores the address of the NIC (either native or set by override). - const MacAddress localMac; + /// Address resolution table that translates IP addresses to MAC addresses. + std::unordered_map arpTable; + + /// Stores the IpAddress of the driver. + IpAddress localIp; + + /// Stores the HW address of the NIC (either native or set by override). + MacAddress localMac; /// Stores the driver's maximum network packet priority (either default or /// set by override). @@ -185,10 +189,13 @@ class DpdkDriver::Impl { /// Provides memory allocation for the DPDK specific implementation of a /// Driver Packet. - ObjectPool packetPool; + ObjectPool packetPool; /// Provides memory allocation for packet storage when mbuf are running out. - ObjectPool overflowBufferPool; + ObjectPool overflowBufferPool; + + /// The number of mbufs that have been given out to callers in Packets. + uint32_t mbufsOutstanding; /// Holds packet buffers that are dequeued from the NIC's HW queues /// via DPDK. diff --git a/src/Drivers/DPDK/MacAddress.cc b/src/Drivers/DPDK/MacAddress.cc index 0178851..e47f27a 100644 --- a/src/Drivers/DPDK/MacAddress.cc +++ b/src/Drivers/DPDK/MacAddress.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2011-2019, Stanford University +/* Copyright (c) 2011-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -18,7 +18,6 @@ #include "StringUtil.h" #include "../../CodeLocation.h" -#include "../RawAddressType.h" namespace Homa { namespace Drivers { @@ -55,33 +54,6 @@ MacAddress::MacAddress(const char* macStr) address[i] = Util::downCast(bytes[i]); } -/** - * Create a new address from a given address in its raw byte format. - * @param raw - * The raw bytes format. - * - * @sa Driver::Address::Raw - */ -MacAddress::MacAddress(const Driver::WireFormatAddress* const wireAddress) -{ - if (wireAddress->type != RawAddressType::MAC) { - throw BadAddress(HERE_STR, "Bad address: Raw format is not type MAC"); - } - static_assert(sizeof(wireAddress->bytes) >= 6); - memcpy(address, wireAddress->bytes, 6); -} - -/** - * Create a new address given the Driver::Address representation. - * - * @param addr - * The Driver::Address representation of an address. - */ -MacAddress::MacAddress(const Driver::Address addr) -{ - memcpy(address, &addr, 6); -} - /** * Return the string representation of this address. */ @@ -94,31 +66,6 @@ MacAddress::toString() const return buf; } -/** - * Serialized this address into a wire format. - * - * @param[out] wireAddress - * WireFormatAddress object to which the this address is serialized. - */ -void -MacAddress::toWireFormat(Driver::WireFormatAddress* wireAddress) const -{ - static_assert(sizeof(wireAddress->bytes) >= 6); - memcpy(wireAddress->bytes, address, 6); - wireAddress->type = RawAddressType::MAC; -} - -/** - * Return a Driver::Address representation of this address. - */ -Driver::Address -MacAddress::toAddress() const -{ - Driver::Address addr = 0; - memcpy(&addr, address, 6); - return addr; -} - /** * @return * True if the MacAddress consists of all zero bytes, false if not. diff --git a/src/Drivers/DPDK/MacAddress.h b/src/Drivers/DPDK/MacAddress.h index 1106eec..33f47a5 100644 --- a/src/Drivers/DPDK/MacAddress.h +++ b/src/Drivers/DPDK/MacAddress.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2011-2019, Stanford University +/* Copyright (c) 2011-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -28,14 +28,19 @@ namespace DPDK { struct MacAddress { explicit MacAddress(const uint8_t raw[6]); explicit MacAddress(const char* macStr); - explicit MacAddress(const Driver::WireFormatAddress* const wireAddress); - explicit MacAddress(const Driver::Address addr); MacAddress(const MacAddress&) = default; std::string toString() const; - void toWireFormat(Driver::WireFormatAddress* wireAddress) const; - Driver::Address toAddress() const; bool isNull() const; + /** + * Equality function for MacAddress, for use in std::unordered_maps etc. + */ + bool operator==(const MacAddress& other) const + { + return (*(uint32_t*)(address + 0) == *(uint32_t*)(other.address + 0)) && + (*(uint16_t*)(address + 4) == *(uint16_t*)(other.address + 4)); + } + /// The raw bytes of the MAC address. uint8_t address[6]; }; diff --git a/src/Drivers/DPDK/MacAddressTest.cc b/src/Drivers/DPDK/MacAddressTest.cc index 329c309..9b8b8ae 100644 --- a/src/Drivers/DPDK/MacAddressTest.cc +++ b/src/Drivers/DPDK/MacAddressTest.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2011-2019, Stanford University +/* Copyright (c) 2011-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -15,8 +15,6 @@ #include "MacAddress.h" -#include "../RawAddressType.h" - #include namespace Homa { @@ -35,26 +33,6 @@ TEST(MacAddressTest, constructorString) EXPECT_EQ("de:ad:be:ef:98:76", MacAddress("de:ad:be:ef:98:76").toString()); } -TEST(MacAddressTest, constructorWireFormatAddress) -{ - uint8_t bytes[] = {0xde, 0xad, 0xbe, 0xef, 0x98, 0x76}; - Driver::WireFormatAddress wireformatAddress; - wireformatAddress.type = RawAddressType::MAC; - memcpy(wireformatAddress.bytes, bytes, 6); - EXPECT_EQ("de:ad:be:ef:98:76", MacAddress(&wireformatAddress).toString()); - - wireformatAddress.type = RawAddressType::FAKE; - EXPECT_THROW(MacAddress address(&wireformatAddress), BadAddress); -} - -TEST(MacAddressTest, constructorAddress) -{ - uint8_t raw[] = {0xde, 0xad, 0xbe, 0xef, 0x98, 0x76}; - MacAddress(raw).toString(); - Driver::Address addr = MacAddress("de:ad:be:ef:98:76").toAddress(); - EXPECT_EQ("de:ad:be:ef:98:76", MacAddress(addr).toString()); -} - TEST(MacAddressTest, construct_DefaultCopy) { MacAddress source("de:ad:be:ef:98:76"); @@ -67,24 +45,6 @@ TEST(MacAddressTest, toString) // tested sufficiently in constructor tests } -TEST(MacAddressTest, toWireFormat) -{ - Driver::WireFormatAddress wireformatAddress; - MacAddress("de:ad:be:ef:98:76").toWireFormat(&wireformatAddress); - EXPECT_EQ(RawAddressType::MAC, wireformatAddress.type); - EXPECT_EQ(0xde, wireformatAddress.bytes[0]); - EXPECT_EQ(0xad, wireformatAddress.bytes[1]); - EXPECT_EQ(0xbe, wireformatAddress.bytes[2]); - EXPECT_EQ(0xef, wireformatAddress.bytes[3]); - EXPECT_EQ(0x98, wireformatAddress.bytes[4]); - EXPECT_EQ(0x76, wireformatAddress.bytes[5]); -} - -TEST(MacAddressTest, toAddress) -{ - // Tested in constructorAddress -} - TEST(MacAddressTest, isNull) { uint8_t rawNull[] = {0x0, 0x0, 0x0, 0x0, 0x0, 0x0}; diff --git a/src/Drivers/Fake/FakeAddressTest.cc b/src/Drivers/Fake/FakeAddressTest.cc deleted file mode 100644 index 67cef78..0000000 --- a/src/Drivers/Fake/FakeAddressTest.cc +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright (c) 2019, Stanford University - * - * Permission to use, copy, modify, and/or distribute this software for any - * purpose with or without fee is hereby granted, provided that the above - * copyright notice and this permission notice appear in all copies. - * - * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES - * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR - * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ - -#include - -#include "FakeAddress.h" - -#include "../RawAddressType.h" - -namespace Homa { -namespace Drivers { -namespace Fake { -namespace { - -TEST(FakeAddressTest, constructor_id) -{ - FakeAddress address(42); - EXPECT_EQ("42", address.toString()); -} - -TEST(FakeAddressTest, constructor_str) -{ - FakeAddress address("42"); - EXPECT_EQ("42", address.toString()); -} - -TEST(FakeAddressTest, constructor_str_bad) -{ - EXPECT_THROW(FakeAddress address("D42"), BadAddress); -} - -TEST(FakeAddressTest, constructor_raw) -{ - Driver::Address::Raw raw; - raw.type = RawAddressType::FAKE; - *reinterpret_cast(raw.bytes) = 42; - - FakeAddress address(&raw); - EXPECT_EQ("42", address.toString()); -} - -TEST(FakeAddressTest, constructor_raw_bad) -{ - Driver::Address::Raw raw; - raw.type = !RawAddressType::FAKE; - - EXPECT_THROW(FakeAddress address(&raw), BadAddress); -} - -TEST(FakeAddressTest, toString) -{ - // tested sufficiently in constructor tests -} - -TEST(FakeAddressTest, toAddressId) -{ - EXPECT_THROW(FakeAddress::toAddressId("D42"), BadAddress); -} - -} // namespace -} // namespace Fake -} // namespace Drivers -} // namespace Homa diff --git a/src/Drivers/Fake/FakeDriver.cc b/src/Drivers/Fake/FakeDriver.cc index 6200a49..26cb102 100644 --- a/src/Drivers/Fake/FakeDriver.cc +++ b/src/Drivers/Fake/FakeDriver.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -56,24 +56,26 @@ static class FakeNetwork { /// Register the FakeNIC so it can receive packets. Returns the newly /// registered FakeNIC's addressId. - uint64_t registerNIC(FakeNIC* nic) + uint32_t registerNIC(FakeNIC* nic) { std::lock_guard lock(mutex); - uint64_t addressId = nextAddressId.fetch_add(1); - network.insert({addressId, nic}); + uint32_t addressId = nextAddressId.fetch_add(1); + IpAddress ipAddress{addressId}; + network.insert({ipAddress, nic}); return addressId; } /// Remove the FakeNIC from the network. - void deregisterNIC(uint64_t addressId) + void deregisterNIC(uint32_t addressId) { std::lock_guard lock(mutex); - network.erase(addressId); + IpAddress ipAddress{addressId}; + network.erase(ipAddress); } /// Deliver the provide packet to the specified destination. - void sendPacket(FakePacket* packet, Driver::Address src, - Driver::Address dst) + void sendPacket(FakePacket* packet, int priority, IpAddress src, + IpAddress dst) { FakeNIC* nic = nullptr; { @@ -92,10 +94,10 @@ static class FakeNetwork { assert(nic != nullptr); std::lock_guard lock_nic(nic->mutex, std::adopt_lock); FakePacket* dstPacket = new FakePacket(*packet); - dstPacket->address = src; - assert(dstPacket->priority < NUM_PRIORITIES); - assert(dstPacket->priority >= 0); - nic->priorityQueue.at(dstPacket->priority).push_back(dstPacket); + dstPacket->sourceIp = src; + assert(priority < NUM_PRIORITIES); + assert(priority >= 0); + nic->priorityQueue.at(priority).push_back(dstPacket); } void setPacketLossRate(double lossRate) @@ -115,11 +117,10 @@ static class FakeNetwork { std::mutex mutex; /// Holds all the packets being sent through the fake network. - std::unordered_map network; + std::unordered_map network; - /// The FakeAddress identifier for the next FakeDriver that "connects" to - /// the FakeNetwork. - std::atomic nextAddressId; + /// Identifier for the next FakeDriver that "connects" to the FakeNetwork. + std::atomic nextAddressId; /// Rate at which packets should be dropped when sent over this network. double packetLossRate; @@ -177,53 +178,6 @@ FakeDriver::~FakeDriver() fakeNetwork.deregisterNIC(localAddressId); } -/** - * See Driver::getAddress() - */ -Driver::Address -FakeDriver::getAddress(std::string const* const addressString) -{ - char* end; - uint64_t address = std::strtoul(addressString->c_str(), &end, 10); - if (address == 0) { - throw BadAddress(HERE_STR, StringUtil::format("Bad address string: %s", - addressString->c_str())); - } - return address; -} - -/** - * See Driver::getAddress() - */ -Driver::Address -FakeDriver::getAddress(Driver::WireFormatAddress const* const wireAddress) -{ - const Address* address = - reinterpret_cast(wireAddress->bytes); - return *address; -} - -/** - * See Driver::addressToString() - */ -std::string -FakeDriver::addressToString(const Address address) -{ - char buf[21]; - snprintf(buf, sizeof(buf), "%lu", address); - return buf; -} - -/** - * See Driver::addressToWireFormat() - */ -void -FakeDriver::addressToWireFormat(const Address address, - WireFormatAddress* wireAddress) -{ - new (reinterpret_cast(wireAddress->bytes)) Address(address); -} - /** * See Driver::allocPacket() */ @@ -231,19 +185,19 @@ Driver::Packet* FakeDriver::allocPacket() { FakePacket* packet = new FakePacket(); - return packet; + return &packet->base; } /** * See Driver::sendPacket() */ void -FakeDriver::sendPacket(Packet* packet) +FakeDriver::sendPacket(Packet* packet, IpAddress destination, int priority) { - FakePacket* srcPacket = static_cast(packet); - Address srcAddress = getLocalAddress(); - Address dstAddress = srcPacket->address; - fakeNetwork.sendPacket(srcPacket, srcAddress, dstAddress); + FakePacket* srcPacket = container_of(packet, &FakePacket::base); + IpAddress srcAddress = getLocalAddress(); + IpAddress dstAddress = destination; + fakeNetwork.sendPacket(srcPacket, priority, srcAddress, dstAddress); queueEstimator.signalBytesSent(packet->length); } @@ -251,14 +205,17 @@ FakeDriver::sendPacket(Packet* packet) * See Driver::receivePackets() */ uint32_t -FakeDriver::receivePackets(uint32_t maxPackets, Packet* receivedPackets[]) +FakeDriver::receivePackets(uint32_t maxPackets, Packet* receivedPackets[], + IpAddress sourceAddresses[]) { std::lock_guard lock_nic(nic.mutex); uint32_t numReceived = 0; for (int i = NUM_PRIORITIES - 1; i >= 0; --i) { while (numReceived < maxPackets && !nic.priorityQueue.at(i).empty()) { - receivedPackets[numReceived] = nic.priorityQueue.at(i).front(); + FakePacket* fakePacket = nic.priorityQueue.at(i).front(); nic.priorityQueue.at(i).pop_front(); + receivedPackets[numReceived] = &fakePacket->base; + sourceAddresses[numReceived] = fakePacket->sourceIp; numReceived++; } } @@ -272,8 +229,7 @@ void FakeDriver::releasePackets(Packet* packets[], uint16_t numPackets) { for (uint16_t i = 0; i < numPackets; ++i) { - FakePacket* packet = static_cast(packets[i]); - delete packet; + delete container_of(packets[i], &FakePacket::base); } } @@ -308,10 +264,10 @@ FakeDriver::getBandwidth() /** * See Driver::getLocalAddress() */ -Driver::Address +IpAddress FakeDriver::getLocalAddress() { - return localAddressId; + return IpAddress{localAddressId}; } /** diff --git a/src/Drivers/Fake/FakeDriverTest.cc b/src/Drivers/Fake/FakeDriverTest.cc index e410119..cd64917 100644 --- a/src/Drivers/Fake/FakeDriverTest.cc +++ b/src/Drivers/Fake/FakeDriverTest.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -18,7 +18,6 @@ #include -#include "../RawAddressType.h" #include "StringUtil.h" namespace Homa { @@ -28,52 +27,18 @@ namespace { TEST(FakeDriverTest, constructor) { - uint64_t nextAddressId = FakeDriver().localAddressId + 1; + uint32_t nextAddressId = FakeDriver().localAddressId + 1; FakeDriver driver; EXPECT_EQ(nextAddressId, driver.localAddressId); } -TEST(FakeDriverTest, getAddress_string) -{ - FakeDriver driver; - std::string addressStr("42"); - Driver::Address address = driver.getAddress(&addressStr); - EXPECT_EQ("42", driver.addressToString(address)); -} - -TEST(FakeDriverTest, getAddress_wireformat) -{ - FakeDriver driver; - Driver::WireFormatAddress wireformatAddress; - wireformatAddress.type = RawAddressType::FAKE; - *reinterpret_cast(wireformatAddress.bytes) = 42; - Driver::Address address = driver.getAddress(&wireformatAddress); - EXPECT_EQ("42", driver.addressToString(address)); -} - -TEST(FakeDriverTest, addressToString) -{ - FakeDriver driver; - Driver::Address address = 42; - EXPECT_EQ("42", driver.addressToString(address)); -} - -TEST(FakeDriverTest, addressToWireFormat) -{ - FakeDriver driver; - Driver::WireFormatAddress wireformatAddress; - driver.addressToWireFormat(42, &wireformatAddress); - EXPECT_EQ("42", - driver.addressToString(driver.getAddress(&wireformatAddress))); -} - TEST(FakeDriverTest, allocPacket) { FakeDriver driver; Driver::Packet* packet = driver.allocPacket(); // allocPacket doesn't do much so we just need to make sure we can call it. - delete packet; + delete container_of(packet, &FakePacket::base); } TEST(FakeDriverTest, sendPackets) @@ -82,13 +47,14 @@ TEST(FakeDriverTest, sendPackets) FakeDriver driver2; Driver::Packet* packets[4]; + IpAddress destinations[4]; + int prio[4]; for (int i = 0; i < 4; ++i) { packets[i] = driver1.allocPacket(); - packets[i]->address = driver2.getLocalAddress(); - packets[i]->priority = i; + destinations[i] = driver2.getLocalAddress(); + prio[i] = i; } - std::string addressStr("42"); - packets[2]->address = driver1.getAddress(&addressStr); + destinations[2] = IpAddress{42}; EXPECT_EQ(0U, driver2.nic.priorityQueue.at(0).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(1).size()); @@ -99,7 +65,7 @@ TEST(FakeDriverTest, sendPackets) EXPECT_EQ(0U, driver2.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(7).size()); - driver1.sendPacket(packets[0]); + driver1.sendPacket(packets[0], destinations[0], prio[0]); EXPECT_EQ(1U, driver2.nic.priorityQueue.at(0).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(1).size()); @@ -110,13 +76,12 @@ TEST(FakeDriverTest, sendPackets) EXPECT_EQ(0U, driver2.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(7).size()); { - Driver::Packet* packet = static_cast( - driver2.nic.priorityQueue.at(0).front()); - EXPECT_EQ(driver1.getLocalAddress(), packet->address); + FakePacket* packet = driver2.nic.priorityQueue.at(0).front(); + EXPECT_EQ(driver1.getLocalAddress(), packet->sourceIp); } for (int i = 0; i < 4; ++i) { - driver1.sendPacket(packets[i]); + driver1.sendPacket(packets[i], destinations[i], prio[i]); } EXPECT_EQ(2U, driver2.nic.priorityQueue.at(0).size()); @@ -128,7 +93,7 @@ TEST(FakeDriverTest, sendPackets) EXPECT_EQ(0U, driver2.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(7).size()); - delete packets[2]; + delete container_of(packets[2], &FakePacket::base); } TEST(FakeDriverTest, receivePackets) @@ -137,6 +102,7 @@ TEST(FakeDriverTest, receivePackets) FakeDriver driver; Driver::Packet* packets[4]; + IpAddress srcAddrs[4]; // 3 packets at priority 7 for (int i = 0; i < 3; ++i) @@ -158,7 +124,7 @@ TEST(FakeDriverTest, receivePackets) EXPECT_EQ(0U, driver.nic.priorityQueue.at(6).size()); EXPECT_EQ(3U, driver.nic.priorityQueue.at(7).size()); - EXPECT_EQ(4U, driver.receivePackets(4, packets)); + EXPECT_EQ(4U, driver.receivePackets(4, packets, srcAddrs)); driver.releasePackets(packets, 4); EXPECT_EQ(0U, driver.nic.priorityQueue.at(0).size()); @@ -170,7 +136,7 @@ TEST(FakeDriverTest, receivePackets) EXPECT_EQ(0U, driver.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver.nic.priorityQueue.at(7).size()); - EXPECT_EQ(1U, driver.receivePackets(1, packets)); + EXPECT_EQ(1U, driver.receivePackets(1, packets, srcAddrs)); driver.releasePackets(packets, 1); EXPECT_EQ(0U, driver.nic.priorityQueue.at(0).size()); @@ -193,7 +159,7 @@ TEST(FakeDriverTest, receivePackets) EXPECT_EQ(0U, driver.nic.priorityQueue.at(6).size()); EXPECT_EQ(1U, driver.nic.priorityQueue.at(7).size()); - EXPECT_EQ(1U, driver.receivePackets(1, packets)); + EXPECT_EQ(1U, driver.receivePackets(1, packets, srcAddrs)); driver.releasePackets(packets, 1); EXPECT_EQ(0U, driver.nic.priorityQueue.at(0).size()); @@ -205,7 +171,7 @@ TEST(FakeDriverTest, receivePackets) EXPECT_EQ(0U, driver.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver.nic.priorityQueue.at(7).size()); - EXPECT_EQ(3U, driver.receivePackets(4, packets)); + EXPECT_EQ(3U, driver.receivePackets(4, packets, srcAddrs)); driver.releasePackets(packets, 3); } @@ -234,11 +200,9 @@ TEST(FakeDriverTest, getBandwidth) TEST(FakeDriverTest, getLocalAddress) { - uint64_t nextAddressId = FakeDriver().localAddressId + 1; - std::string addressStr = StringUtil::format("%lu", nextAddressId); - + uint32_t nextAddressId = FakeDriver().localAddressId + 1; FakeDriver driver; - EXPECT_EQ(driver.getAddress(&addressStr), driver.getLocalAddress()); + EXPECT_EQ(nextAddressId, (uint32_t)driver.getLocalAddress()); } } // namespace diff --git a/src/Intrusive.h b/src/Intrusive.h index fc3391c..47cddaa 100644 --- a/src/Intrusive.h +++ b/src/Intrusive.h @@ -486,26 +486,25 @@ class List { * @tparam ElementType * Type of the element held in the Intrusive::List. * @tparam Compare - * A weak strict ordering binary comparator for objects of ElementType. + * A weak strict ordering binary comparator for objects of ElementType + * which returns true when the first argument should be ordered before + * the second. The signature should be equivalent to the following: + * bool comp(const ElementType& a, const ElementType& b); * @param list * List that contains the element. * @parma node * Intrusive list node for the element that should be prioritized. - * @param comp - * Comparison function object which returns true when the first argument - * should be ordered before the second. The signature should be equivalent - * to the following: - * bool comp(const ElementType& a, const ElementType& b); */ -template +template void -prioritize(List* list, typename List::Node* node, - Compare comp) +prioritize(List* list, typename List::Node* node) { assert(list->contains(node)); auto it_node = list->get(node); auto it_pos = it_node; while (it_pos != list->begin()) { + Compare comp; if (!comp(*it_node, *std::prev(it_pos))) { // Found the correct location; just before it_pos. break; @@ -528,25 +527,24 @@ prioritize(List* list, typename List::Node* node, * @tparam ElementType * Type of the element held in the Intrusive::List. * @tparam Compare - * A weak strict ordering binary comparator for objects of ElementType. + * A weak strict ordering binary comparator for objects of ElementType + * which returns true when the first argument should be ordered before + * the second. The signature should be equivalent to the following: + * bool comp(const ElementType& a, const ElementType& b); * @param list * List that contains the element. * @parma node * Intrusive list node for the element that should be prioritized. - * @param comp - * Comparison function object which returns true when the first argument - * should be ordered before the second. The signature should be equivalent - * to the following: - * bool comp(const ElementType& a, const ElementType& b); */ -template +template void -deprioritize(List* list, typename List::Node* node, - Compare comp) +deprioritize(List* list, typename List::Node* node) { assert(list->contains(node)); auto it_node = list->get(node); auto it_pos = std::next(it_node); + Compare comp; while (it_pos != list->end()) { if (comp(*it_node, *it_pos)) { // Found the correct location; just before it_pos. diff --git a/src/Mock/MockDriver.h b/src/Mock/MockDriver.h index 6cc5ea7..9ea6ffe 100644 --- a/src/Mock/MockDriver.h +++ b/src/Mock/MockDriver.h @@ -35,32 +35,24 @@ class MockDriver : public Driver { * * @sa Driver::Packet. */ - class MockPacket : public Driver::Packet { - public: - MockPacket(void* payload, uint16_t length = 0) - : Packet(payload, length) - {} - - MOCK_METHOD0(getMaxPayloadSize, int()); - }; - - MOCK_METHOD1(getAddress, Address(std::string const* const addressString)); - MOCK_METHOD1(getAddress, - Address(WireFormatAddress const* const wireAddress)); - MOCK_METHOD1(addressToString, std::string(Address address)); - MOCK_METHOD2(addressToWireFormat, - void(Address address, WireFormatAddress* wireAddress)); - MOCK_METHOD0(allocPacket, Packet*()); - MOCK_METHOD1(sendPacket, void(Packet* packet)); - MOCK_METHOD0(flushPackets, void()); - MOCK_METHOD2(receivePackets, - uint32_t(uint32_t maxPackets, Packet* receivedPackets[])); - MOCK_METHOD2(releasePackets, void(Packet* packets[], uint16_t numPackets)); - MOCK_METHOD0(getHighestPacketPriority, int()); - MOCK_METHOD0(getMaxPayloadSize, uint32_t()); - MOCK_METHOD0(getBandwidth, uint32_t()); - MOCK_METHOD0(getLocalAddress, Address()); - MOCK_METHOD0(getQueuedBytes, uint32_t()); + using MockPacket = Driver::Packet; + + MOCK_METHOD(Packet*, allocPacket, (), (override)); + MOCK_METHOD(void, sendPacket, + (Packet * packet, IpAddress destination, int priority), + (override)); + MOCK_METHOD(void, flushPackets, ()); + MOCK_METHOD(uint32_t, receivePackets, + (uint32_t maxPackets, Packet* receivedPackets[], + IpAddress sourceAddresses[]), + (override)); + MOCK_METHOD(void, releasePackets, (Packet * packets[], uint16_t numPackets), + (override)); + MOCK_METHOD(int, getHighestPacketPriority, (), (override)); + MOCK_METHOD(uint32_t, getMaxPayloadSize, (), (override)); + MOCK_METHOD(uint32_t, getBandwidth, (), (override)); + MOCK_METHOD(IpAddress, getLocalAddress, (), (override)); + MOCK_METHOD(uint32_t, getQueuedBytes, (), (override)); }; } // namespace Mock diff --git a/src/Mock/MockPolicy.h b/src/Mock/MockPolicy.h index 0595f25..32e7be8 100644 --- a/src/Mock/MockPolicy.h +++ b/src/Mock/MockPolicy.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -36,10 +36,10 @@ class MockPolicyManager : public Core::Policy::Manager { MOCK_METHOD0(getResendPriority, int()); MOCK_METHOD0(getScheduledPolicy, Core::Policy::Scheduled()); MOCK_METHOD2(getUnscheduledPolicy, - Core::Policy::Unscheduled(const Driver::Address destination, + Core::Policy::Unscheduled(const IpAddress destination, const uint32_t messageLength)); MOCK_METHOD3(signalNewMessage, - void(const Driver::Address source, uint8_t policyVersion, + void(const IpAddress source, uint8_t policyVersion, uint32_t messageLength)); MOCK_METHOD0(poll, void()); }; diff --git a/src/Mock/MockReceiver.h b/src/Mock/MockReceiver.h index fc0fa13..0646a94 100644 --- a/src/Mock/MockReceiver.h +++ b/src/Mock/MockReceiver.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -36,15 +36,14 @@ class MockReceiver : public Core::Receiver { : Receiver(driver, nullptr, messageTimeoutCycles, resendIntervalCycles) {} - MOCK_METHOD2(handleDataPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD2(handleBusyPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD2(handlePingPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD0(receiveMessage, Homa::InMessage*()); - MOCK_METHOD0(poll, void()); - MOCK_METHOD0(checkTimeouts, uint64_t()); + MOCK_METHOD(void, handleDataPacket, + (Driver::Packet * packet, IpAddress sourceIp), (override)); + MOCK_METHOD(void, handleBusyPacket, (Driver::Packet * packet), (override)); + MOCK_METHOD(void, handlePingPacket, + (Driver::Packet * packet, IpAddress sourceIp), (override)); + MOCK_METHOD(Homa::InMessage*, receiveMessage, (), (override)); + MOCK_METHOD(void, poll, (), (override)); + MOCK_METHOD(void, checkTimeouts, (), (override)); }; } // namespace Mock diff --git a/src/Mock/MockSender.h b/src/Mock/MockSender.h index b67152b..2cf4234 100644 --- a/src/Mock/MockSender.h +++ b/src/Mock/MockSender.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -37,19 +37,16 @@ class MockSender : public Core::Sender { pingIntervalCycles) {} - MOCK_METHOD0(allocMessage, Homa::OutMessage*()); - MOCK_METHOD2(handleDonePacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD2(handleGrantPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD2(handleResendPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD2(handleUnknownPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD2(handleErrorPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD0(poll, void()); - MOCK_METHOD0(checkTimeouts, uint64_t()); + MOCK_METHOD(Homa::OutMessage*, allocMessage, (uint16_t sport), (override)); + MOCK_METHOD(void, handleDonePacket, (Driver::Packet * packet), (override)); + MOCK_METHOD(void, handleGrantPacket, (Driver::Packet * packet), (override)); + MOCK_METHOD(void, handleResendPacket, (Driver::Packet * packet), + (override)); + MOCK_METHOD(void, handleUnknownPacket, (Driver::Packet * packet), + (override)); + MOCK_METHOD(void, handleErrorPacket, (Driver::Packet * packet), (override)); + MOCK_METHOD(void, poll, (), (override)); + MOCK_METHOD(void, checkTimeouts, (), (override)); }; } // namespace Mock diff --git a/src/ObjectPool.h b/src/ObjectPool.h index 3b6a918..860be71 100644 --- a/src/ObjectPool.h +++ b/src/ObjectPool.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2010-2018, Stanford University +/* Copyright (c) 2010-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -16,11 +16,10 @@ #ifndef HOMA_OBJECTPOOL_H #define HOMA_OBJECTPOOL_H -#include "Homa/Exception.h" - -#include "Debug.h" - #include +#include "Debug.h" +#include "Homa/Exception.h" +#include "SpinLock.h" /* * Notes on performance and efficiency: @@ -52,8 +51,11 @@ namespace Homa { * new and delete an relatively fixed set of objects very quickly. * For example, transports use ObjectPool to allocate short-lived rpc * objects that cannot be kept in a stack context. + * + * Thread-safety of this class can be configured statically using the + * template argument ThreadSafe. */ -template +template class ObjectPool { public: /** @@ -63,7 +65,8 @@ class ObjectPool { * allocations. For simplicity, no bulk allocations are performed. */ ObjectPool() - : outstandingObjects(0) + : mutex() + , outstandingObjects(0) , pool() {} @@ -99,24 +102,26 @@ class ObjectPool { template T* construct(Args&&... args) { - void* backing = NULL; + void* backing = nullptr; + enterCS(); if (pool.size() == 0) { backing = operator new(sizeof(T)); } else { backing = pool.back(); pool.pop_back(); } + outstandingObjects++; + exitCS(); - T* object = NULL; try { - object = new (backing) T(static_cast(args)...); + return new (backing) T(static_cast(args)...); } catch (...) { + enterCS(); pool.push_back(backing); + outstandingObjects--; + exitCS(); throw; } - - outstandingObjects++; - return object; } /** @@ -124,13 +129,39 @@ class ObjectPool { */ void destroy(T* object) { - assert(outstandingObjects > 0); object->~T(); + + enterCS(); + assert(outstandingObjects > 0); pool.push_back(static_cast(object)); outstandingObjects--; + exitCS(); } private: + /** + * Enter the critical section guarded by _mutex_. + */ + void enterCS() + { + if (ThreadSafe) { + mutex.lock(); + } + } + + /** + * Exit the critical section guarded by _mutex_. + */ + void exitCS() + { + if (ThreadSafe) { + mutex.unlock(); + } + } + + /// Monitor-style lock to protect the metadata of the pool. + SpinLock mutex; + /// Count of the number of objects for which construct() was called, but /// destroy() was not. uint64_t outstandingObjects; diff --git a/src/Perf.cc b/src/Perf.cc index 155802c..faf154a 100644 --- a/src/Perf.cc +++ b/src/Perf.cc @@ -15,8 +15,6 @@ #include "Perf.h" -#include - #include #include diff --git a/src/Perf.h b/src/Perf.h index bf9668d..3b89438 100644 --- a/src/Perf.h +++ b/src/Perf.h @@ -17,7 +17,7 @@ #define HOMA_PERF_H #include -#include +#include #include @@ -29,7 +29,7 @@ namespace Perf { */ struct Counters { /** - * Wrapper class for individual counter entires to + * Wrapper class for individual counter entries. */ template struct Stat : private std::atomic { @@ -43,8 +43,10 @@ struct Counters { /** * Add the value of another Stat to this Stat. + * + * This method is thread-safe. */ - void add(const Stat& other) + inline void add(const Stat& other) { this->fetch_add(other.load(std::memory_order_relaxed), std::memory_order_relaxed); @@ -52,16 +54,21 @@ struct Counters { /** * Add the given value to this Stat. + * + * This method is not thread-safe. */ - void add(T val) + inline void add(T val) { - this->fetch_add(val, std::memory_order_relaxed); + this->store(this->load(std::memory_order_relaxed) + val, + std::memory_order_relaxed); } /** * Return the stat value. + * + * This method is thread-safe. */ - T get() const + inline T get() const { return this->load(std::memory_order_relaxed); } @@ -71,8 +78,15 @@ struct Counters { * Default constructor. */ Counters() - : active_cycles(0) - , idle_cycles(0) + : total_cycles(0) + , active_cycles(0) + , allocated_rx_messages(0) + , received_rx_messages(0) + , delivered_rx_messages(0) + , destroyed_rx_messages(0) + , allocated_tx_messages(0) + , released_tx_messages(0) + , destroyed_tx_messages(0) , tx_bytes(0) , rx_bytes(0) , tx_data_pkts(0) @@ -103,8 +117,15 @@ struct Counters { */ void add(const Counters* other) { + total_cycles.add(other->total_cycles); active_cycles.add(other->active_cycles); - idle_cycles.add(other->idle_cycles); + allocated_rx_messages.add(other->allocated_rx_messages); + received_rx_messages.add(other->received_rx_messages); + delivered_rx_messages.add(other->delivered_rx_messages); + destroyed_rx_messages.add(other->destroyed_rx_messages); + allocated_tx_messages.add(other->allocated_tx_messages); + released_tx_messages.add(other->released_tx_messages); + destroyed_tx_messages.add(other->destroyed_tx_messages); tx_bytes.add(other->tx_bytes); rx_bytes.add(other->rx_bytes); tx_data_pkts.add(other->tx_data_pkts); @@ -131,7 +152,14 @@ struct Counters { void dumpStats(Stats* stats) { stats->active_cycles = active_cycles.get(); - stats->idle_cycles = idle_cycles.get(); + stats->idle_cycles = total_cycles.get() - active_cycles.get(); + stats->allocated_rx_messages = allocated_rx_messages.get(); + stats->received_rx_messages = received_rx_messages.get(); + stats->delivered_rx_messages = delivered_rx_messages.get(); + stats->destroyed_rx_messages = destroyed_rx_messages.get(); + stats->allocated_tx_messages = allocated_tx_messages.get(); + stats->released_tx_messages = released_tx_messages.get(); + stats->destroyed_tx_messages = destroyed_tx_messages.get(); stats->tx_bytes = tx_bytes.get(); stats->rx_bytes = rx_bytes.get(); stats->tx_data_pkts = tx_data_pkts.get(); @@ -152,11 +180,32 @@ struct Counters { stats->rx_error_pkts = rx_error_pkts.get(); } + /// CPU time spent running the Homa poll loop in cycles. + Stat total_cycles; + /// CPU time spent actively processing Homa messages in cycles. Stat active_cycles; - /// CPU time spent running Homa with no work to do in cycles. - Stat idle_cycles; + /// Number of InMessages that have been allocated by the Transport. + Stat allocated_rx_messages; + + /// Number of InMessages that have been received by the Transport. + Stat received_rx_messages; + + /// Number of InMessages delivered to the application. + Stat delivered_rx_messages; + + /// Number of InMessages released back to the Transport for destruction. + Stat destroyed_rx_messages; + + /// Number of OutMessages allocated for the application. + Stat allocated_tx_messages; + + /// Number of OutMessages released back to the transport. + Stat released_tx_messages; + + /// Number of OutMessages destroyed. + Stat destroyed_tx_messages; /// Number of bytes sent by the transport. Stat tx_bytes; @@ -233,10 +282,10 @@ extern thread_local ThreadCounters counters; class Timer { public: /** - * Construct a new uninitialized Timer. + * Construct a new Timer. */ Timer() - : split_tsc(0) + : split_tsc(PerfUtils::Cycles::rdtsc()) {} /** diff --git a/src/Policy.cc b/src/Policy.cc index cf0e62e..12e7e16 100644 --- a/src/Policy.cc +++ b/src/Policy.cc @@ -97,14 +97,14 @@ Manager::getScheduledPolicy() * unilaterally "granted" (unscheduled) bytes for a new Message to be sent. * * @param destination - * The policy for the Transport at this Address will be returned. + * The policy for the Transport at this IpAddress will be returned. * @param messageLength * The policy for message containing this many bytes will be returned. * * @sa Policy::Unscheduled */ Unscheduled -Manager::getUnscheduledPolicy(const Driver::Address destination, +Manager::getUnscheduledPolicy(const IpAddress destination, const uint32_t messageLength) { SpinLock::Lock lock(mutex); @@ -140,14 +140,14 @@ Manager::getUnscheduledPolicy(const Driver::Address destination, * Called by the Receiver when a new Message has started to arrive. * * @param source - * Address of the Transport from which the new Message was received. + * IpAddress of the Transport from which the new Message was received. * @param policyVersion * Version of the policy the Sender used when sending the Message. * @param messageLength * Number of bytes the new incoming Message contains. */ void -Manager::signalNewMessage(const Driver::Address source, uint8_t policyVersion, +Manager::signalNewMessage(const IpAddress source, uint8_t policyVersion, uint32_t messageLength) { SpinLock::Lock lock(mutex); diff --git a/src/Policy.h b/src/Policy.h index c32bf66..0be1eb2 100644 --- a/src/Policy.h +++ b/src/Policy.h @@ -75,10 +75,9 @@ class Manager { virtual ~Manager() = default; virtual int getResendPriority(); virtual Scheduled getScheduledPolicy(); - virtual Unscheduled getUnscheduledPolicy(const Driver::Address destination, + virtual Unscheduled getUnscheduledPolicy(const IpAddress destination, const uint32_t messageLength); - virtual void signalNewMessage(const Driver::Address source, - uint8_t policyVersion, + virtual void signalNewMessage(const IpAddress source, uint8_t policyVersion, uint32_t messageLength); virtual void poll(); @@ -107,7 +106,8 @@ class Manager { /// The scheduled policy for the Transport that owns this Policy::Manager. Scheduled localScheduledPolicy; /// Collection of the known Policies for each peered Homa::Transport; - std::unordered_map peerPolicies; + std::unordered_map + peerPolicies; /// Number of bytes that can be transmitted in one round-trip-time. const uint32_t RTT_BYTES; /// The highest network packet priority that the driver supports. diff --git a/src/PolicyTest.cc b/src/PolicyTest.cc index ee0dde5..44b8829 100644 --- a/src/PolicyTest.cc +++ b/src/PolicyTest.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -59,7 +59,7 @@ TEST(PolicyManagerTest, getUnscheduledPolicy) EXPECT_CALL(mockDriver, getBandwidth).WillOnce(Return(8000)); EXPECT_CALL(mockDriver, getHighestPacketPriority).WillOnce(Return(7)); Policy::Manager manager(&mockDriver); - Driver::Address dest(22); + IpAddress dest{22}; { Policy::Unscheduled policy = manager.getUnscheduledPolicy(dest, 1); diff --git a/src/Protocol.h b/src/Protocol.h index f83725e..84f8522 100644 --- a/src/Protocol.h +++ b/src/Protocol.h @@ -42,6 +42,9 @@ struct MessageId { uint64_t sequence; ///< Sequence number for this message (unique for ///< transportId, monotonically increasing). + /// MessageId default constructor. + MessageId() = default; + /// MessageId constructor. MessageId(uint64_t transportId, uint64_t sequence) : transportId(transportId) @@ -104,19 +107,26 @@ enum Opcode { /** * This is the first part of the Homa packet header and is common to all - * versions of the protocol. The struct contains version information about the + * versions of the protocol. The first four bytes of the header store the source + * and destination ports, which is common for many transport layer protocols + * (e.g., TCP, UDP, etc.) The struct also contains version information about the * protocol used in the encompassing packet. The Transport should always send * this prefix and can always expect it when receiving a Homa packet. The prefix * is separated into its own struct because the Transport may need to know the * protocol version before interpreting the rest of the packet. */ struct HeaderPrefix { + uint16_t sport, + dport; ///< Transport layer (L4) source and destination ports + ///< in network byte order; only used by DataHeader. uint8_t version; ///< The version of the protocol being used by this ///< packet. /// HeaderPrefix constructor. - HeaderPrefix(uint8_t version) - : version(version) + HeaderPrefix(uint16_t sport, uint16_t dport, uint8_t version) + : sport(sport) + , dport(dport) + , version(version) {} } __attribute__((packed)); @@ -131,7 +141,7 @@ struct CommonHeader { /// CommonHeader constructor. CommonHeader(Opcode opcode, MessageId messageId) - : prefix(1) + : prefix(0, 0, 1) , opcode(opcode) , messageId(messageId) {} @@ -157,14 +167,18 @@ struct DataHeader { // starting at the offset corresponding to the given packet index. /// DataHeader constructor. - DataHeader(MessageId messageId, uint32_t totalLength, uint8_t policyVersion, + DataHeader(uint16_t sport, uint16_t dport, MessageId messageId, + uint32_t totalLength, uint8_t policyVersion, uint16_t unscheduledIndexLimit, uint16_t index) : common(Opcode::DATA, messageId) , totalLength(totalLength) , policyVersion(policyVersion) , unscheduledIndexLimit(unscheduledIndexLimit) , index(index) - {} + { + common.prefix.sport = htobe16(sport); + common.prefix.dport = htobe16(dport); + } } __attribute__((packed)); /** diff --git a/src/Receiver.cc b/src/Receiver.cc index 25e0619..402a39d 100644 --- a/src/Receiver.cc +++ b/src/Receiver.cc @@ -18,7 +18,6 @@ #include #include "Perf.h" -#include "Util.h" namespace Homa { namespace Core { @@ -41,12 +40,18 @@ Receiver::Receiver(Driver* driver, Policy::Manager* policyManager, uint64_t messageTimeoutCycles, uint64_t resendIntervalCycles) : driver(driver) , policyManager(policyManager) - , messageBuckets(messageTimeoutCycles, resendIntervalCycles) + , MESSAGE_TIMEOUT_INTERVALS( + Util::roundUpIntDiv(messageTimeoutCycles, resendIntervalCycles)) + , TRANSPORT_HEADER_LENGTH(sizeof(Protocol::Packet::DataHeader)) + , PACKET_DATA_LENGTH(driver->getMaxPayloadSize() - TRANSPORT_HEADER_LENGTH) + , messageBuckets(resendIntervalCycles) , schedulerMutex() , scheduledPeers() , receivedMessages() , granting() + , nextBucketIndex(0) , messageAllocator() + , externalBuffers() {} /** @@ -54,27 +59,38 @@ Receiver::Receiver(Driver* driver, Policy::Manager* policyManager, */ Receiver::~Receiver() { - schedulerMutex.lock(); - scheduledPeers.clear(); - peerTable.clear(); - receivedMessages.mutex.lock(); + // To ensure that all resources of a Receiver can be freed correctly, it's + // the user's responsibility to ensure the following before destructing the + // Receiver: + // - The transport must have been taken "offline" so that no more incoming + // packets will arrive. + // - All completed incoming messages that are delivered to the application + // must have been returned back to the Receiver (so that they don't hold + // dangling pointers to the destructed Receiver afterwards). + // - There must be only one thread left that can hold a reference to the + // transport (the destructor is designed to run exactly once). + + // Technically speaking, the Receiver is designed in a way that a default + // destructor should be sufficient. However, for clarity and debugging + // purpose, we decided to write the cleanup procedure explicitly anyway. + + // Remove all completed Messages that are still inside the receive queue. receivedMessages.queue.clear(); - for (auto it = messageBuckets.buckets.begin(); - it != messageBuckets.buckets.end(); ++it) { - MessageBucket* bucket = *it; - bucket->mutex.lock(); - auto iit = bucket->messages.begin(); - while (iit != bucket->messages.end()) { - Message* message = &(*iit); - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->resendTimeouts.cancelTimeout(&message->resendTimeout); - iit = bucket->messages.remove(iit); - { - SpinLock::Lock lock_allocator(messageAllocator.mutex); - messageAllocator.pool.destroy(message); - } + + // Destruct all MessageBucket's and the Messages within. + for (auto& bucket : messageBuckets.buckets) { + // Intrusive::List is not responsible for destructing its elements; + // it must be done manually. + for (auto& message : bucket.messages) { + dropMessage(&message); } + assert(bucket.resendTimeouts.empty()); } + messageBuckets.buckets.clear(); + + // Destruct all Peer's. Peer's must be removed from scheduledPeers first. + scheduledPeers.clear(); + peerTable.clear(); } /** @@ -82,86 +98,89 @@ Receiver::~Receiver() * * @param packet * The incoming packet to be processed. - * @param driver - * The driver from which the packet was received. + * @param sourceIp + * Source IP address of the packet. */ void -Receiver::handleDataPacket(Driver::Packet* packet, Driver* driver) +Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) { Protocol::Packet::DataHeader* header = static_cast(packet->payload); - uint16_t dataHeaderLength = sizeof(Protocol::Packet::DataHeader); Protocol::MessageId id = header->common.messageId; MessageBucket* bucket = messageBuckets.getBucket(id); - SpinLock::Lock lock_bucket(bucket->mutex); + SpinLock::UniqueLock lock_bucket(bucket->mutex); Message* message = bucket->findMessage(id, lock_bucket); if (message == nullptr) { // New message int messageLength = header->totalLength; int numUnscheduledPackets = header->unscheduledIndexLimit; - { - SpinLock::Lock lock_allocator(messageAllocator.mutex); - message = messageAllocator.pool.construct( - this, driver, dataHeaderLength, messageLength, id, - packet->address, numUnscheduledPackets); - } + SocketAddress srcAddress = { + .ip = sourceIp, .port = be16toh(header->common.prefix.sport)}; + message = messageAllocator.construct(this, driver, messageLength, id, + srcAddress, numUnscheduledPackets); + Perf::counters.allocated_rx_messages.add(1); + // Start tracking the message. bucket->messages.push_back(&message->bucketNode); - policyManager->signalNewMessage(message->source, header->policyVersion, - header->totalLength); - - if (message->scheduled) { - // Message needs to be scheduled. - SpinLock::Lock lock_scheduler(schedulerMutex); - schedule(message, lock_scheduler); - } + policyManager->signalNewMessage( + message->source.ip, header->policyVersion, header->totalLength); } - // Things that must be true (sanity check) - assert(id == message->id); - assert(message->driver == driver); - assert(message->source == packet->address); + // Sanity checks + assert(message->source.ip == sourceIp); + assert(message->source.port == be16toh(header->common.prefix.sport)); assert(message->messageLength == Util::downCast(header->totalLength)); - // Add the packet - bool packetAdded = message->setPacket(header->index, packet); - if (packetAdded) { - // Update schedule for scheduled messages. - if (message->scheduled) { - SpinLock::Lock lock_scheduler(schedulerMutex); - ScheduledMessageInfo* info = &message->scheduledMessageInfo; - // Update the schedule if the message is still being scheduled - // (i.e. still linked to a scheduled peer). - if (info->peer != nullptr) { - int packetDataBytes = - packet->length - message->TRANSPORT_HEADER_LENGTH; - assert(info->bytesRemaining >= packetDataBytes); - info->bytesRemaining -= packetDataBytes; - updateSchedule(message, lock_scheduler); - } + if (message->received.test(header->index)) { + // Must be a duplicate packet; drop it. + return; + } else { + // Add the packet and copy the payload. + message->received.set(header->index); + message->numPackets++; + std::memcpy( + message->buffer + header->index * PACKET_DATA_LENGTH, + static_cast(packet->payload) + TRANSPORT_HEADER_LENGTH, + packet->length - TRANSPORT_HEADER_LENGTH); + } + + if (message->scheduled) { + // A new Message needs to be entered into the scheduler. + SpinLock::Lock lock_scheduler(schedulerMutex); + schedule(message, lock_scheduler); + } else if (message->scheduled) { + // Update schedule for an existing scheduled message. + SpinLock::Lock lock_scheduler(schedulerMutex); + ScheduledMessageInfo* info = &message->scheduledMessageInfo; + // Update the schedule if the message is still being scheduled + // (i.e. still linked to a scheduled peer). + if (info->peer != nullptr) { + int packetDataBytes = packet->length - TRANSPORT_HEADER_LENGTH; + assert(info->bytesRemaining >= packetDataBytes); + info->bytesRemaining -= packetDataBytes; + updateSchedule(message, lock_scheduler); } + } - // Receiving a new packet means the message is still active so it - // shouldn't time out until a while later. - bucket->messageTimeouts.setTimeout(&message->messageTimeout); - if (message->numPackets < message->numExpectedPackets) { - // Still waiting for more packets to arrive but the arrival of a - // new packet means we should wait a while longer before requesting - // RESENDs of the missing packets. - bucket->resendTimeouts.setTimeout(&message->resendTimeout); - } else { - // All message packets have been received. - message->state.store(Message::State::COMPLETED); + if (message->numPackets == message->numExpectedPackets) { + // All message packets have been received. + message->completed.store(true, std::memory_order_release); + if (message->needTimeout) { bucket->resendTimeouts.cancelTimeout(&message->resendTimeout); - SpinLock::Lock lock_received_messages(receivedMessages.mutex); - receivedMessages.queue.push_back(&message->receivedMessageNode); } - } else { - // must be a duplicate packet; drop packet. - driver->releasePackets(&packet, 1); + lock_bucket.unlock(); + + // Deliver the message to the user of the transport. + SpinLock::Lock lock_received_messages(receivedMessages.mutex); + receivedMessages.queue.push_back(&message->receivedMessageNode); + Perf::counters.received_rx_messages.add(1); + } else if (message->needTimeout) { + // Receiving a new packet means the message is still active so it + // shouldn't time out until a while later. + message->numResendTimeouts = 0; + bucket->resendTimeouts.setTimeout(&message->resendTimeout); } - return; } /** @@ -169,28 +188,25 @@ Receiver::handleDataPacket(Driver::Packet* packet, Driver* driver) * * @param packet * The incoming BUSY packet to be processed. - * @param driver - * The driver from which the BUSY packet was received. */ void -Receiver::handleBusyPacket(Driver::Packet* packet, Driver* driver) +Receiver::handleBusyPacket(Driver::Packet* packet) { Protocol::Packet::BusyHeader* header = static_cast(packet->payload); Protocol::MessageId id = header->common.messageId; MessageBucket* bucket = messageBuckets.getBucket(id); - SpinLock::Lock lock_bucket(bucket->mutex); + SpinLock::UniqueLock lock_bucket(bucket->mutex); Message* message = bucket->findMessage(id, lock_bucket); if (message != nullptr) { // Sender has replied BUSY to our RESEND request; consider this message // still active. - bucket->messageTimeouts.setTimeout(&message->messageTimeout); - if (message->state == Message::State::IN_PROGRESS) { + if (message->needTimeout && !message->completed) { + message->numResendTimeouts = 0; bucket->resendTimeouts.setTimeout(&message->resendTimeout); } } - driver->releasePackets(&packet, 1); } /** @@ -198,24 +214,31 @@ Receiver::handleBusyPacket(Driver::Packet* packet, Driver* driver) * * @param packet * The incoming PING packet to be processed. - * @param driver - * The driver from which the PING packet was received. + * @param sourceIp + * Source IP address of the packet. */ void -Receiver::handlePingPacket(Driver::Packet* packet, Driver* driver) +Receiver::handlePingPacket(Driver::Packet* packet, IpAddress sourceIp) { Protocol::Packet::PingHeader* header = static_cast(packet->payload); Protocol::MessageId id = header->common.messageId; + // FIXME(Yilong): after making the transport purely message-based, we need + // to send back an ACK packet here if this message is complete + MessageBucket* bucket = messageBuckets.getBucket(id); - SpinLock::Lock lock_bucket(bucket->mutex); + SpinLock::UniqueLock lock_bucket(bucket->mutex); Message* message = bucket->findMessage(id, lock_bucket); if (message != nullptr) { - // Sender is checking on this message; consider it still active. - bucket->messageTimeouts.setTimeout(&message->messageTimeout); + // Don't (re-)insert a timeout unless necessary. + if (message->needTimeout && !message->completed) { + // Sender is checking on this message; consider it still active. + message->numResendTimeouts = 0; + bucket->resendTimeouts.setTimeout(&message->resendTimeout); + } - // We are here either because a GRANT got lost, or we haven't issued a + // We are here either because a GRANT got lost, or we haven't issued a // GRANT in along time. Send out the latest GRANT if one exists or just // an "empty" GRANT to let the Sender know we are aware of the message. @@ -234,17 +257,18 @@ Receiver::handlePingPacket(Driver::Packet* packet, Driver* driver) priority = info->priority; } + lock_bucket.unlock(); Perf::counters.tx_grant_pkts.add(1); ControlPacket::send( - driver, message->source, message->id, bytesGranted, priority); + driver, sourceIp, id, bytesGranted, priority); } else { // We are here because we have no knowledge of the message the Sender is // asking about. Reply UNKNOWN so the Sender can react accordingly. + lock_bucket.unlock(); Perf::counters.tx_unknown_pkts.add(1); - ControlPacket::send( - driver, packet->address, id); + ControlPacket::send(driver, sourceIp, + id); } - driver->releasePackets(&packet, 1); } /** @@ -255,8 +279,6 @@ Receiver::handlePingPacket(Driver::Packet* packet, Driver* driver) * * @return * A new Message which has been received, if available; otherwise, nullptr. - * - * @sa dropMessage() */ Homa::InMessage* Receiver::receiveMessage() @@ -266,6 +288,7 @@ Receiver::receiveMessage() if (!receivedMessages.queue.empty()) { message = &receivedMessages.queue.front(); receivedMessages.queue.pop_front(); + Perf::counters.delivered_rx_messages.add(1); } return message; } @@ -279,61 +302,34 @@ void Receiver::poll() { trySendGrants(); + checkTimeouts(); } /** - * Process any Receiver timeouts that have expired. + * Make incremental progress processing expired Receiver timeouts. * - * This method must be called periodically to ensure timely handling of - * expired timeouts. - * - * @return - * The rdtsc cycle time when this method should be called again. + * Pulled out of poll() for ease of testing. */ -uint64_t +void Receiver::checkTimeouts() { - uint64_t nextTimeout; - - // Ping Timeout - nextTimeout = checkResendTimeouts(); - - // Message Timeout - uint64_t messageTimeout = checkMessageTimeouts(); - nextTimeout = nextTimeout < messageTimeout ? nextTimeout : messageTimeout; - - return nextTimeout; + uint index = nextBucketIndex.fetch_add(1, std::memory_order_relaxed) & + MessageBucketMap::HASH_KEY_MASK; + MessageBucket* bucket = &messageBuckets.buckets[index]; + uint64_t now = PerfUtils::Cycles::rdtsc(); + checkResendTimeouts(now, bucket); } /** - * Destruct a Message. Will release all contained Packet objects. + * Destruct a Message. */ Receiver::Message::~Message() { - // Find contiguous ranges of packets and release them back to the - // driver. - int num = 0; - int index = 0; - int packetsFound = 0; - for (int i = 0; i < MAX_MESSAGE_PACKETS && packetsFound < numPackets; ++i) { - if (occupied.test(i)) { - if (num == 0) { - // First packet in new region. - index = i; - } - ++num; - ++packetsFound; - } else { - if (num != 0) { - // End of region; release the last region. - driver->releasePackets(&packets[index], num); - num = 0; - } - } - } - if (num != 0) { - // Release the last region (if any). - driver->releasePackets(&packets[index], num); + // Release the external buffer, if any. + if (buffer != internalBuffer) { + MessageBuffer* externalBuf = + (MessageBuffer*)buffer; + bucket->receiver->externalBuffers.destroy(externalBuf); } } @@ -343,76 +339,17 @@ Receiver::Message::~Message() void Receiver::Message::acknowledge() const { - MessageBucket* bucket = receiver->messageBuckets.getBucket(id); - SpinLock::Lock lock(bucket->mutex); Perf::counters.tx_done_pkts.add(1); - ControlPacket::send(driver, source, id); -} - -/** - * @copydoc Homa::InMessage::dropped() - */ -bool -Receiver::Message::dropped() const -{ - return state.load() == State::DROPPED; -} - -/** - * @copydoc See Homa::InMessage::fail() - */ -void -Receiver::Message::fail() const -{ - MessageBucket* bucket = receiver->messageBuckets.getBucket(id); - SpinLock::Lock lock(bucket->mutex); - Perf::counters.tx_error_pkts.add(1); - ControlPacket::send(driver, source, id); + ControlPacket::send(driver, source.ip, id); } /** - * @copydoc Homa::InMessage::get() + * @copydoc Homa::InMessage::data() */ -size_t -Receiver::Message::get(size_t offset, void* destination, size_t count) const +void* +Receiver::Message::data() const { - // This operation should be performed with the offset relative to the - // logical beginning of the Message. - int _offset = Util::downCast(offset); - int _count = Util::downCast(count); - int realOffset = _offset + start; - int packetIndex = realOffset / PACKET_DATA_LENGTH; - int packetOffset = realOffset % PACKET_DATA_LENGTH; - int bytesCopied = 0; - - // Offset is passed the end of the message. - if (realOffset >= messageLength) { - return 0; - } - - if (realOffset + _count > messageLength) { - _count = messageLength - realOffset; - } - - while (bytesCopied < _count) { - uint32_t bytesToCopy = - std::min(_count - bytesCopied, PACKET_DATA_LENGTH - packetOffset); - Driver::Packet* packet = getPacket(packetIndex); - if (packet != nullptr) { - char* source = static_cast(packet->payload); - source += packetOffset + TRANSPORT_HEADER_LENGTH; - std::memcpy(static_cast(destination) + bytesCopied, source, - bytesToCopy); - } else { - ERROR("Message is missing data starting at packet index %u", - packetIndex); - break; - } - bytesCopied += bytesToCopy; - packetIndex++; - packetOffset = 0; - } - return bytesCopied; + return buffer; } /** @@ -421,16 +358,7 @@ Receiver::Message::get(size_t offset, void* destination, size_t count) const size_t Receiver::Message::length() const { - return Util::downCast(messageLength - start); -} - -/** - * @copydoc Homa::InMessage::strip() - */ -void -Receiver::Message::strip(size_t count) -{ - start = std::min(start + Util::downCast(count), messageLength); + return Util::downCast(messageLength); } /** @@ -439,159 +367,40 @@ Receiver::Message::strip(size_t count) void Receiver::Message::release() { - receiver->dropMessage(this); -} - -/** - * Return the Packet with the given index. - * - * @param index - * A Packet's index in the array of packets that form the message. - * "packet index = "packet message offset" / PACKET_DATA_LENGTH - * @return - * Pointer to a Packet at the given index if it exists; nullptr otherwise. - */ -Driver::Packet* -Receiver::Message::getPacket(size_t index) const -{ - if (occupied.test(index)) { - return packets[index]; - } - return nullptr; + bucket->receiver->dropMessage(this); } /** - * Store the given packet as the Packet of the given index if one does not - * already exist. - * - * Responsibly for releasing the given Packet is passed to this context if the - * Packet is stored (returns true). - * - * @param index - * The Packet's index in the array of packets that form the message. - * "packet index = "packet message offset" / PACKET_DATA_LENGTH - * @param packet - * The packet pointer that should be stored. - * @return - * True if the packet was stored; false if a packet already exists (the new - * packet is not stored). - */ -bool -Receiver::Message::setPacket(size_t index, Driver::Packet* packet) -{ - if (occupied.test(index)) { - return false; - } - packets[index] = packet; - occupied.set(index); - numPackets++; - return true; -} - -/** - * Inform the Receiver that an Message returned by receiveMessage() is not - * needed and can be dropped. + * Drop a message because it's no longer needed (either the application released + * the message or a timeout occurred). * * @param message - * Message which will be dropped. + * Message which will be detached from the transport and destroyed. */ void Receiver::dropMessage(Receiver::Message* message) { - Protocol::MessageId msgId = message->id; - MessageBucket* bucket = messageBuckets.getBucket(msgId); - SpinLock::Lock lock_bucket(bucket->mutex); - Message* foundMessage = bucket->findMessage(msgId, lock_bucket); - if (foundMessage != nullptr) { - assert(message == foundMessage); - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->resendTimeouts.cancelTimeout(&message->resendTimeout); - if (message->scheduled) { - // Unschedule the message if it is still scheduled (i.e. still - // linked to a scheduled peer). - SpinLock::Lock lock_scheduler(schedulerMutex); - ScheduledMessageInfo* info = &message->scheduledMessageInfo; - if (info->peer != nullptr) { - unschedule(message, lock_scheduler); - } - } - bucket->messages.remove(&message->bucketNode); - { - SpinLock::Lock lock_allocator(messageAllocator.mutex); - messageAllocator.pool.destroy(message); + // Unschedule the message if it is still scheduled (i.e. still linked to a + // scheduled peer). + if (message->scheduled) { + SpinLock::Lock lock_scheduler(schedulerMutex); + ScheduledMessageInfo* info = &message->scheduledMessageInfo; + if (info->peer != nullptr) { + unschedule(message, lock_scheduler); } } -} - -/** - * Process any inbound messages that have timed out due to lack of activity from - * the Sender. - * - * Pulled out of checkTimeouts() for ease of testing. - * - * @return - * The rdtsc cycle time when this method should be called again. - */ -uint64_t -Receiver::checkMessageTimeouts() -{ - uint64_t globalNextTimeout = UINT64_MAX; - for (int i = 0; i < MessageBucketMap::NUM_BUCKETS; ++i) { - MessageBucket* bucket = messageBuckets.buckets.at(i); - uint64_t nextTimeout = 0; - while (true) { - SpinLock::Lock lock_bucket(bucket->mutex); - - // No remaining timeouts. - if (bucket->messageTimeouts.list.empty()) { - nextTimeout = PerfUtils::Cycles::rdtsc() + - bucket->messageTimeouts.timeoutIntervalCycles; - break; - } - - Message* message = &bucket->messageTimeouts.list.front(); - - // No remaining expired timeouts. - if (!message->messageTimeout.hasElapsed()) { - nextTimeout = message->messageTimeout.expirationCycleTime; - break; - } - // Found expired timeout. - - // Cancel timeouts - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->resendTimeouts.cancelTimeout(&message->resendTimeout); - - if (message->state == Message::State::IN_PROGRESS) { - // Message timed out before being fully received; drop the - // message. - - // Unschedule the message - if (message->scheduled) { - // Unschedule the message if it is still scheduled (i.e. - // still linked to a scheduled peer). - SpinLock::Lock lock_scheduler(schedulerMutex); - ScheduledMessageInfo* info = &message->scheduledMessageInfo; - if (info->peer != nullptr) { - unschedule(message, lock_scheduler); - } - } - - bucket->messages.remove(&message->bucketNode); - { - SpinLock::Lock lock_allocator(messageAllocator.mutex); - messageAllocator.pool.destroy(message); - } - } else { - // Message timed out but we already made it available to the - // Transport; let the Transport know. - message->state.store(Message::State::DROPPED); - } - } - globalNextTimeout = std::min(globalNextTimeout, nextTimeout); + // Remove this message from the other data structures of the Receiver. + MessageBucket* bucket = message->bucket; + { + SpinLock::Lock bucket_lock(bucket->mutex); + bucket->resendTimeouts.cancelTimeout(&message->resendTimeout); + bucket->messages.remove(&message->bucketNode); } - return globalNextTimeout; + + // Destroy the message. + messageAllocator.destroy(message); + Perf::counters.destroyed_rx_messages.add(1); } /** @@ -599,107 +408,102 @@ Receiver::checkMessageTimeouts() * * Pulled out of checkTimeouts() for ease of testing. * - * @return - * The rdtsc cycle time when this method should be called again. + * @param now + * The rdtsc cycle that should be considered the "current" time. + * @param bucket + * The bucket whose resend timeouts should be checked. */ -uint64_t -Receiver::checkResendTimeouts() +void +Receiver::checkResendTimeouts(uint64_t now, MessageBucket* bucket) { - uint64_t globalNextTimeout = UINT64_MAX; - for (int i = 0; i < MessageBucketMap::NUM_BUCKETS; ++i) { - MessageBucket* bucket = messageBuckets.buckets.at(i); - uint64_t nextTimeout = 0; - while (true) { - SpinLock::Lock lock_bucket(bucket->mutex); - - // No remaining timeouts. - if (bucket->resendTimeouts.list.empty()) { - nextTimeout = PerfUtils::Cycles::rdtsc() + - bucket->resendTimeouts.timeoutIntervalCycles; - break; - } + if (!bucket->resendTimeouts.anyElapsed(now)) { + return; + } - Message* message = &bucket->resendTimeouts.list.front(); + while (true) { + SpinLock::UniqueLock lock_bucket(bucket->mutex); - // No remaining expired timeouts. - if (!message->resendTimeout.hasElapsed()) { - nextTimeout = message->resendTimeout.expirationCycleTime; - break; - } + // No remaining timeouts. + if (bucket->resendTimeouts.empty()) { + break; + } - // Found expired timeout. - assert(message->state == Message::State::IN_PROGRESS); + // No remaining expired timeouts. + Message* message = &bucket->resendTimeouts.front(); + if (!message->resendTimeout.hasElapsed(now)) { + break; + } + + // Found expired timeout. + assert(!message->completed); + message->numResendTimeouts++; + if (message->numResendTimeouts >= MESSAGE_TIMEOUT_INTERVALS) { + // Message timed out before being fully received; drop the message. + lock_bucket.unlock(); + dropMessage(message); + continue; + } else { bucket->resendTimeouts.setTimeout(&message->resendTimeout); + } - // This Receiver expected to have heard from the Sender within the - // last timeout period but it didn't. Request a resend of granted - // packets in case DATA packets got lost. - int index = 0; - int num = 0; - int grantIndexLimit = message->numUnscheduledPackets; - - if (message->scheduled) { - SpinLock::Lock lock_scheduler(schedulerMutex); - ScheduledMessageInfo* info = &message->scheduledMessageInfo; - int receivedBytes = info->messageLength - info->bytesRemaining; - if (receivedBytes >= info->bytesGranted) { - // Sender is blocked on this Receiver; all granted packets - // have already been received. No need to check for resend. - continue; - } else if (grantIndexLimit * message->PACKET_DATA_LENGTH < - info->bytesGranted) { - grantIndexLimit = - (info->bytesGranted + message->PACKET_DATA_LENGTH - 1) / - message->PACKET_DATA_LENGTH; - } + // This Receiver expected to have heard from the Sender within the + // last timeout period but it didn't. Request a resend of granted + // packets in case DATA packets got lost. + uint16_t index = 0; + uint16_t num = 0; + int grantIndexLimit = message->numUnscheduledPackets; + + // The RESEND also includes the current granted priority so that it + // can act as a GRANT in case a GRANT was lost. If this message + // hasn't been scheduled (i.e. no grants have been sent) then the + // priority will hold the default value; this is ok since the Sender + // will ignore the priority field for resends of purely unscheduled + // packets (see Sender::handleResendPacket()). + int resendPriority = 0; + if (message->scheduled) { + SpinLock::Lock lock_scheduler(schedulerMutex); + ScheduledMessageInfo* info = &message->scheduledMessageInfo; + int receivedBytes = info->messageLength - info->bytesRemaining; + if (receivedBytes >= info->bytesGranted) { + // Sender is blocked on this Receiver; all granted packets + // have already been received. No need to check for resend. + continue; + } else if (grantIndexLimit * PACKET_DATA_LENGTH < + info->bytesGranted) { + grantIndexLimit = + Util::roundUpIntDiv(info->bytesGranted, PACKET_DATA_LENGTH); } + resendPriority = info->priority; + } - for (int i = 0; i < grantIndexLimit; ++i) { - if (message->getPacket(i) == nullptr) { - // Unreceived packet - if (num == 0) { - // First unreceived packet - index = i; - } - ++num; - } else { - // Received packet - if (num != 0) { - // Send out the range of packets found so far. - // - // The RESEND also includes the current granted priority - // so that it can act as a GRANT in case a GRANT was - // lost. If this message hasn't been scheduled (i.e. no - // grants have been sent) then the priority will hold - // the default value; this is ok since the Sender will - // ignore the priority field for resends of purely - // unscheduled packets (see - // Sender::handleResendPacket()). - SpinLock::Lock lock_scheduler(schedulerMutex); - Perf::counters.tx_resend_pkts.add(1); - ControlPacket::send( - message->driver, message->source, message->id, - Util::downCast(index), - Util::downCast(num), - message->scheduledMessageInfo.priority); - num = 0; - } + for (int i = 0; i < grantIndexLimit; ++i) { + if (!message->received.test(i)) { + // Unreceived packet + if (num == 0) { + // First unreceived packet + index = i; + } + ++num; + } else { + // Received packet + if (num != 0) { + // Send out the range of packets found so far. + Perf::counters.tx_resend_pkts.add(1); + ControlPacket::send( + message->driver, message->source.ip, message->id, index, + num, resendPriority); + num = 0; } } - if (num != 0) { - // Send out the last range of packets found. - SpinLock::Lock lock_scheduler(schedulerMutex); - Perf::counters.tx_resend_pkts.add(1); - ControlPacket::send( - message->driver, message->source, message->id, - Util::downCast(index), - Util::downCast(num), - message->scheduledMessageInfo.priority); - } } - globalNextTimeout = std::min(globalNextTimeout, nextTimeout); + if (num != 0) { + // Send out the last range of packets found. + Perf::counters.tx_resend_pkts.add(1); + ControlPacket::send( + message->driver, message->source.ip, message->id, index, num, + resendPriority); + } } - return globalNextTimeout; } /** @@ -708,8 +512,7 @@ Receiver::checkResendTimeouts() void Receiver::trySendGrants() { - uint64_t start_tsc = PerfUtils::Cycles::rdtsc(); - bool idle = true; + Perf::Timer timer; // Skip scheduling if another poller is already working on it. if (granting.test_and_set()) { @@ -744,11 +547,11 @@ Receiver::trySendGrants() int slot = 0; while (it != scheduledPeers.end() && slot < policy.degreeOvercommitment) { assert(!it->scheduledMessages.empty()); + // No need to acquire the bucket mutex here because we are only going to + // access the const members of a Message; besides, the message can't get + // destroyed while we are holding the schedulerMutex. Message* message = &it->scheduledMessages.front(); ScheduledMessageInfo* info = &message->scheduledMessageInfo; - // Access message const variables without message mutex. - const Protocol::MessageId id = message->id; - const Driver::Address source = message->source; // Recalculate message priority info->priority = @@ -757,7 +560,6 @@ Receiver::trySendGrants() // Send a GRANT if there are too few bytes granted and unreceived. int receivedBytes = info->messageLength - info->bytesRemaining; if (info->bytesGranted - receivedBytes < policy.minScheduledBytes) { - idle = false; // Calculate new grant limit int newGrantLimit = std::min( receivedBytes + policy.maxScheduledBytes, info->messageLength); @@ -765,8 +567,9 @@ Receiver::trySendGrants() info->bytesGranted = newGrantLimit; Perf::counters.tx_grant_pkts.add(1); ControlPacket::send( - driver, source, id, + driver, message->source.ip, message->id, Util::downCast(info->bytesGranted), info->priority); + Perf::counters.active_cycles.add(timer.split()); } // Update the iterator first since calling unschedule() may cause the @@ -776,19 +579,13 @@ Receiver::trySendGrants() if (info->messageLength <= info->bytesGranted) { // All packets granted, unschedule the message. unschedule(message, lock); + Perf::counters.active_cycles.add(timer.split()); } ++slot; } granting.clear(); - - uint64_t elapsed_cycles = PerfUtils::Cycles::rdtsc() - start_tsc; - if (!idle) { - Perf::counters.active_cycles.add(elapsed_cycles); - } else { - Perf::counters.idle_cycles.add(elapsed_cycles); - } } /** @@ -802,30 +599,27 @@ Receiver::trySendGrants() * Reminder to hold the Receiver::schedulerMutex during this call. */ void -Receiver::schedule(Receiver::Message* message, const SpinLock::Lock& lock) +Receiver::schedule(Receiver::Message* message, + [[maybe_unused]] SpinLock::Lock& lock) { - (void)lock; ScheduledMessageInfo* info = &message->scheduledMessageInfo; - Peer* peer = &peerTable[message->source]; + Peer* peer = &peerTable[message->source.ip]; // Insert the Message peer->scheduledMessages.push_front(&info->scheduledMessageNode); Intrusive::deprioritize(&peer->scheduledMessages, - &info->scheduledMessageNode, - ScheduledMessageInfo::ComparePriority()); + &info->scheduledMessageNode); info->peer = peer; if (!scheduledPeers.contains(&peer->scheduledPeerNode)) { - // Must be the only message of this peer; push the peer to the - // end of list to be moved later. + // Must be the only message of this peer; push the peer to the front of + // list to be moved later. assert(peer->scheduledMessages.size() == 1); scheduledPeers.push_front(&peer->scheduledPeerNode); - Intrusive::deprioritize(&scheduledPeers, &peer->scheduledPeerNode, - Peer::ComparePriority()); + Intrusive::deprioritize(&scheduledPeers, + &peer->scheduledPeerNode); } else if (&info->peer->scheduledMessages.front() == message) { // Update the Peer's position in the queue since the new message is the // peer's first scheduled message. - Intrusive::prioritize(&scheduledPeers, - &info->peer->scheduledPeerNode, - Peer::ComparePriority()); + Intrusive::prioritize(&scheduledPeers, &peer->scheduledPeerNode); } else { // The peer's first scheduled message did not change. Nothing to do. } @@ -842,9 +636,9 @@ Receiver::schedule(Receiver::Message* message, const SpinLock::Lock& lock) * Reminder to hold the Receiver::schedulerMutex during this call. */ void -Receiver::unschedule(Receiver::Message* message, const SpinLock::Lock& lock) +Receiver::unschedule(Receiver::Message* message, + [[maybe_unused]] SpinLock::Lock& lock) { - (void)lock; ScheduledMessageInfo* info = &message->scheduledMessageInfo; assert(info->peer != nullptr); Peer* peer = info->peer; @@ -859,7 +653,8 @@ Receiver::unschedule(Receiver::Message* message, const SpinLock::Lock& lock) // Cleanup the schedule if (peer->scheduledMessages.empty()) { - // Remove the empty peer. + // Remove the empty peer from the schedule (the peer object is still + // alive). scheduledPeers.remove(it); } else if (std::next(it) == scheduledPeers.end() || !comp(*std::next(it), *it)) { @@ -868,8 +663,8 @@ Receiver::unschedule(Receiver::Message* message, const SpinLock::Lock& lock) // since removing a message cannot increase the peer's priority. } else { // Peer needs to be moved. - Intrusive::deprioritize(&scheduledPeers, &peer->scheduledPeerNode, - comp); + Intrusive::deprioritize(&scheduledPeers, + &peer->scheduledPeerNode); } } @@ -886,24 +681,22 @@ Receiver::unschedule(Receiver::Message* message, const SpinLock::Lock& lock) * Reminder to hold the Receiver::schedulerMutex during this call. */ void -Receiver::updateSchedule(Receiver::Message* message, const SpinLock::Lock& lock) +Receiver::updateSchedule(Receiver::Message* message, + [[maybe_unused]] SpinLock::Lock& lock) { - (void)lock; ScheduledMessageInfo* info = &message->scheduledMessageInfo; assert(info->peer != nullptr); assert(info->peer->scheduledMessages.contains(&info->scheduledMessageNode)); // Update the message's position within its Peer scheduled message queue. Intrusive::prioritize(&info->peer->scheduledMessages, - &info->scheduledMessageNode, - ScheduledMessageInfo::ComparePriority()); + &info->scheduledMessageNode); // Update the Peer's position in the queue if this message is now the first // scheduled message. if (&info->peer->scheduledMessages.front() == message) { Intrusive::prioritize(&scheduledPeers, - &info->peer->scheduledPeerNode, - Peer::ComparePriority()); + &info->peer->scheduledPeerNode); } } diff --git a/src/Receiver.h b/src/Receiver.h index 444e1aa..3630d75 100644 --- a/src/Receiver.h +++ b/src/Receiver.h @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -30,6 +31,7 @@ #include "Protocol.h" #include "SpinLock.h" #include "Timeout.h" +#include "Util.h" namespace Homa { namespace Core { @@ -46,34 +48,29 @@ class Receiver { uint64_t messageTimeoutCycles, uint64_t resendIntervalCycles); virtual ~Receiver(); - virtual void handleDataPacket(Driver::Packet* packet, Driver* driver); - virtual void handleBusyPacket(Driver::Packet* packet, Driver* driver); - virtual void handlePingPacket(Driver::Packet* packet, Driver* driver); + virtual void handleDataPacket(Driver::Packet* packet, IpAddress sourceIp); + virtual void handleBusyPacket(Driver::Packet* packet); + virtual void handlePingPacket(Driver::Packet* packet, IpAddress sourceIp); virtual Homa::InMessage* receiveMessage(); virtual void poll(); - virtual uint64_t checkTimeouts(); + virtual void checkTimeouts(); private: // Forward declaration class Message; + struct MessageBucket; struct Peer; + /// Define the maximum number of packets that a message can hold. + static const int MAX_MESSAGE_PACKETS = 1024; + + /// Define the maximum number of bytes within a message. + static const int MAX_MESSAGE_LENGTH = 1u << 20u; + /** * Contains metadata for a Message that requires additional GRANTs. */ struct ScheduledMessageInfo { - /** - * Implements a binary comparison function for the strict weak priority - * ordering of two Message objects. - */ - struct ComparePriority { - bool operator()(const Message& a, const Message& b) - { - return a.scheduledMessageInfo.bytesRemaining < - b.scheduledMessageInfo.bytesRemaining; - } - }; - /** * ScheduledMessageInfo constructor. * @@ -120,73 +117,53 @@ class Receiver { class Message : public Homa::InMessage { public: /** - * Defines the possible states of this Message. + * Implements a binary comparison function for the strict weak priority + * ordering of two Message objects. */ - enum class State { - IN_PROGRESS, //< Receiver is in the process of receiving this - // message. - COMPLETED, //< Receiver has received the entire message. - DROPPED, //< Message was COMPLETED but the Receiver has lost - //< communication with the Sender. + struct ComparePriority { + bool operator()(const Message& a, const Message& b) + { + return a.scheduledMessageInfo.bytesRemaining < + b.scheduledMessageInfo.bytesRemaining; + } }; - explicit Message(Receiver* receiver, Driver* driver, - size_t packetHeaderLength, size_t messageLength, - Protocol::MessageId id, Driver::Address source, + explicit Message(Receiver* receiver, Driver* driver, int messageLength, + Protocol::MessageId id, SocketAddress source, int numUnscheduledPackets) - : receiver(receiver) - , driver(driver) + : driver(driver) , id(id) + , bucket(receiver->messageBuckets.getBucket(id)) , source(source) - , TRANSPORT_HEADER_LENGTH(packetHeaderLength) - , PACKET_DATA_LENGTH(driver->getMaxPayloadSize() - - TRANSPORT_HEADER_LENGTH) - , numExpectedPackets((messageLength + PACKET_DATA_LENGTH - 1) / - PACKET_DATA_LENGTH) + , numExpectedPackets(Util::roundUpIntDiv( + messageLength, receiver->PACKET_DATA_LENGTH)) , numUnscheduledPackets(numUnscheduledPackets) , scheduled(numExpectedPackets > numUnscheduledPackets) - , start(0) , messageLength(messageLength) , numPackets(0) - , occupied() - // packets is not initialized to reduce the work done during - // construction. See Message::occupied. - , state(Message::State::IN_PROGRESS) + , received() + , buffer(messageLength <= Util::arrayLength(internalBuffer) + ? internalBuffer + : receiver->externalBuffers.construct()->raw) + // No need to zero-out internalBuffer. + , completed(false) , bucketNode(this) , receivedMessageNode(this) - , messageTimeout(this) + , needTimeout(numExpectedPackets > 1) + , numResendTimeouts(0) , resendTimeout(this) , scheduledMessageInfo(this, messageLength) - {} - - virtual ~Message(); - virtual void acknowledge() const; - virtual bool dropped() const; - virtual void fail() const; - virtual size_t get(size_t offset, void* destination, - size_t count) const; - virtual size_t length() const; - virtual void strip(size_t count); - virtual void release(); - - /** - * Return the current state of this message. - */ - State getState() const { - return state.load(); + assert(messageLength <= MAX_MESSAGE_LENGTH); } - private: - /// Define the maximum number of packets that a message can hold. - static const int MAX_MESSAGE_PACKETS = 1024; - - Driver::Packet* getPacket(size_t index) const; - bool setPacket(size_t index, Driver::Packet* packet); - - /// The Receiver responsible for this message. - Receiver* const receiver; + virtual ~Message(); + void acknowledge() const override; + void* data() const override; + size_t length() const override; + void release() override; + private: /// Driver from which packets were received and to which they should be /// returned when this message is no longer needed. Driver* const driver; @@ -194,15 +171,11 @@ class Receiver { /// Contains the unique identifier for this message. const Protocol::MessageId id; - /// Contains source address this message. - const Driver::Address source; - - /// Number of bytes at the beginning of each Packet that should be - /// reserved for the Homa transport header. - const int TRANSPORT_HEADER_LENGTH; + /// Message bucket this message belongs to. + MessageBucket* const bucket; - /// Number of bytes of data in each full packet. - const int PACKET_DATA_LENGTH; + /// Contains source address this message. + const SocketAddress source; /// Number of packets the message is expected to contain. const int numExpectedPackets; @@ -214,25 +187,26 @@ class Receiver { /// GRANTs to be sent. const bool scheduled; - /// First byte where data is or will go if empty. - int start; - /// Number of bytes in this Message including any stripped bytes. - int messageLength; + const int messageLength; - /// Number of packets currently contained in this message. + /// Number of packets currently contained in this message. Protected by + /// MessageBucket::mutex. int numPackets; - /// Bit array representing which entires in the _packets_ array are set. - /// Used to avoid having to zero out the entire _packets_ array. - std::bitset occupied; + /// Bit array representing which packets in this message are received. + /// Protected by MessageBucket::mutex. + std::bitset received; + + /// Pointer to the contiguous memory buffer serving as message storage. + char* const buffer; - /// Collection of Packet objects that make up this context's Message. - /// These Packets will be released when this context is destroyed. - Driver::Packet* packets[MAX_MESSAGE_PACKETS]; + /// Internal memory buffer used to store messages within 2KB. + char internalBuffer[2048]; - /// This message's current state. - std::atomic state; + /// Current state of the message. True means the entire message has been + /// received; otherwise, this message is still in progress. + std::atomic completed; /// Intrusive structure used by the Receiver to hold on to this Message /// in one of the Receiver's MessageBuckets. Access to this structure @@ -243,12 +217,18 @@ class Receiver { /// message when it has been completely received. Intrusive::List::Node receivedMessageNode; - /// Intrusive structure used by the Receiver to keep track when the - /// receiving of this message should be considered failed. - Timeout messageTimeout; + /// True if this message needs to be tracked by the timeout manager. + /// As an inbound message, this variable should only be set to false + /// when the message fits in a single packet. + const bool needTimeout; + + /// Number of resend timeouts that occurred in a row. Protected by + /// MessageBucket::mutex. + int numResendTimeouts; /// Intrusive structure used by the Receiver to keep track when - /// unreceived parts of this message should be re-requested. + /// unreceived parts of this message should be re-requested. Protected + /// by MessageBucket::mutex. Timeout resendTimeout; /// Intrusive structure used by the Receiver to keep track of this @@ -259,6 +239,20 @@ class Receiver { friend class Receiver; }; + /** + * Memory buffer used to hold large messages that don't fit in Message's + * internal buffer. It's basically a simple wrapper around an array so + * that it can be allocated from an ObjectPool. + * + * @tparam length + * Number of bytes in the buffer. + */ + template + struct MessageBuffer { + /// Buffer space. + char raw[length]; + }; + /** * A collection of incoming Message objects and their associated timeouts. * @@ -269,57 +263,58 @@ class Receiver { /** * MessageBucket constructor. * - * @param messageTimeoutCycles - * Number of cycles of inactivity to wait before a Message is - * considered failed. + * @param receiver + * Receiver that owns this bucket. * @param resendIntervalCycles * Number of cycles of inactivity to wait between requesting * retransmission of un-received parts of a Message. * liveness of a Message. */ - explicit MessageBucket(uint64_t messageTimeoutCycles, + explicit MessageBucket(Receiver* receiver, uint64_t resendIntervalCycles) - : mutex() + : receiver(receiver) + , mutex() , messages() - , messageTimeouts(messageTimeoutCycles) , resendTimeouts(resendIntervalCycles) {} + // MessageBucket's are only destroyed when the transport is destructed; + // all the real work are done in ~Receiver(). + ~MessageBucket() = default; + /** * Return the Message with the given MessageId. * * @param msgId * MessageId of the Message to be found. * @param lock - * Reminder to hold the MessageBucket::mutex during this call. (Not - * used) + * Reminder to hold the MessageBucket::mutex during this call. * @return * A pointer to the Message if found; nullptr, otherwise. */ Message* findMessage(const Protocol::MessageId& msgId, - const SpinLock::Lock& lock) + const SpinLock::UniqueLock& lock) { - (void)lock; - Message* message = nullptr; - for (auto it = messages.begin(); it != messages.end(); ++it) { - if (it->id == msgId) { - message = &(*it); - break; + assert(lock.owns_lock()); + for (auto& it : messages) { + if (it.id == msgId) { + return ⁢ } } - return message; + return nullptr; } - /// Mutex protecting the contents of this bucket. + /// The Receiver that owns this bucket. + Receiver* const receiver; + + /// Mutex protecting the contents of this bucket. This includes messages + /// within the bucket. SpinLock mutex; /// Collection of inbound messages Intrusive::List messages; - /// Maintains Message objects in increasing order of timeout. - TimeoutManager messageTimeouts; - - /// Maintains Message object in increase order of resend timeout. + /// Maintains Message object in increasing order of resend timeout. TimeoutManager resendTimeouts; }; @@ -334,6 +329,9 @@ class Receiver { */ static const int NUM_BUCKETS = 256; + // Make sure the number of buckets is a power of 2. + static_assert(Util::isPowerOfTwo(NUM_BUCKETS)); + /** * Bit mask used to map from a hashed key to the bucket index. */ @@ -343,66 +341,41 @@ class Receiver { static_assert(NUM_BUCKETS == HASH_KEY_MASK + 1); /** - * Helper method to create the set of buckets. + * MessageBucketMap constructor. * - * @param messageTimeoutCycles - * Number of cycles of inactivity to wait before a Message is - * considered failed. * @param resendIntervalCycles * Number of cycles of inactivity to wait between requesting * retransmission of un-received parts of a Message. * liveness of a Message. */ - static std::array makeBuckets( - uint64_t messageTimeoutCycles, uint64_t resendIntervalCycles) + explicit MessageBucketMap(uint64_t resendIntervalCycles) + : buckets() + , hasher() { - std::array buckets; + buckets.reserve(NUM_BUCKETS); for (int i = 0; i < NUM_BUCKETS; ++i) { - buckets[i] = new MessageBucket(messageTimeoutCycles, - resendIntervalCycles); + buckets.emplace_back(resendIntervalCycles); } - return buckets; } - /** - * MessageBucketMap constructor. - * - * @param messageTimeoutCycles - * Number of cycles of inactivity to wait before a Message is - * considered failed. - * @param resendIntervalCycles - * Number of cycles of inactivity to wait between requesting - * retransmission of un-received parts of a Message. - * liveness of a Message. - */ - explicit MessageBucketMap(uint64_t messageTimeoutCycles, - uint64_t resendIntervalCycles) - : buckets(makeBuckets(messageTimeoutCycles, resendIntervalCycles)) - , hasher() - {} - /** * MessageBucketMap destructor. */ - ~MessageBucketMap() - { - for (int i = 0; i < NUM_BUCKETS; ++i) { - delete buckets[i]; - } - } + ~MessageBucketMap() = default; /** * Return the MessageBucket that should hold a Message with the given * MessageId. */ - MessageBucket* getBucket(const Protocol::MessageId& msgId) const + MessageBucket* getBucket(const Protocol::MessageId& msgId) { uint index = hasher(msgId) & HASH_KEY_MASK; - return buckets[index]; + return &buckets[index]; } - /// Array of buckets. - std::array const buckets; + /// Array of NUM_BUCKETS buckets. Defined as a vector to avoid the need + /// for a default constructor in MessageBucket. + std::vector buckets; /// MessageId hash function container. Protocol::MessageId::Hasher hasher; @@ -410,23 +383,22 @@ class Receiver { /** * Holds the incoming scheduled messages from another transport. + * + * The lifetime of a Peer is the same as Receiver: we never destruct Peer + * objects when the transport is running. */ struct Peer { /** - * Peer constructor. + * Default constructor. Easy to use with std::unorderd_map. */ Peer() : scheduledMessages() , scheduledPeerNode(this) {} - /** - * Peer destructor. - */ - ~Peer() - { - scheduledMessages.clear(); - } + // Peer's are only destroyed when the transport is destructed; all the + // real work are done in ~Receiver(). + ~Peer() = default; /** * Implements a binary comparison function for the strict weak priority @@ -437,7 +409,7 @@ class Receiver { { assert(!a.scheduledMessages.empty()); assert(!b.scheduledMessages.empty()); - ScheduledMessageInfo::ComparePriority comp; + Message::ComparePriority comp; return comp(a.scheduledMessages.front(), b.scheduledMessages.front()); } @@ -450,30 +422,42 @@ class Receiver { }; void dropMessage(Receiver::Message* message); - uint64_t checkMessageTimeouts(); - uint64_t checkResendTimeouts(); + void checkResendTimeouts(uint64_t now, MessageBucket* bucket); void trySendGrants(); - void schedule(Message* message, const SpinLock::Lock& lock); - void unschedule(Message* message, const SpinLock::Lock& lock); - void updateSchedule(Message* message, const SpinLock::Lock& lock); + void schedule(Message* message, SpinLock::Lock& lock); + void unschedule(Message* message, SpinLock::Lock& lock); + void updateSchedule(Message* message, SpinLock::Lock& lock); /// Driver with which all packets will be sent and received. This driver - /// is chosen by the Transport that owns this Sender. + /// is chosen by the Transport that owns this Receiver. Driver* const driver; - /// Provider of network packet priority and grant policy decisions. + /// Provider of network packet priority and grant policy decisions. Not + /// owned by this class. Policy::Manager* const policyManager; + /// The number of resend timeouts to occur before declaring a message + /// timeout. + const int MESSAGE_TIMEOUT_INTERVALS; + + /// Number of bytes at the beginning of each Packet that should be reserved + /// for the Homa transport header. + const int TRANSPORT_HEADER_LENGTH; + + /// Number of bytes of data in each full packet. + const int PACKET_DATA_LENGTH; + /// Tracks the set of inbound messages being received by this Receiver. MessageBucketMap messageBuckets; /// Protects access to the Receiver's scheduler state (i.e. peerTable, - /// scheduledPeers, and ScheduledMessageInfo). + /// scheduledPeers, and ScheduledMessageInfo). Locking order constraint: + /// MessageBucket::mutex must be acquired before schedulerMutex. SpinLock schedulerMutex; /// Collection of all peers; used for fast access. Access is protected by /// the schedulerMutex. - std::unordered_map peerTable; + std::unordered_map peerTable; /// List of peers with inbound messages that require grants to complete. /// Access is protected by the schedulerMutex. @@ -492,13 +476,17 @@ class Receiver { /// each other. std::atomic_flag granting = ATOMIC_FLAG_INIT; + /// The index of the next bucket in the messageBuckets::buckets array to + /// process in the poll loop. The index is held in the lower order bits of + /// this variable; the higher order bits should be masked off using the + /// MessageBucketMap::HASH_KEY_MASK bit mask. + std::atomic nextBucketIndex; + /// Used to allocate Message objects. - struct { - /// Protects the messageAllocator.pool - SpinLock mutex; - /// Pool from which Message objects can be allocated. - ObjectPool pool; - } messageAllocator; + ObjectPool messageAllocator; + + /// Used to allocate large memory buffers outside the Message struct. + ObjectPool> externalBuffers; }; } // namespace Core diff --git a/src/ReceiverTest.cc b/src/ReceiverTest.cc index a49aee2..fcd450f 100644 --- a/src/ReceiverTest.cc +++ b/src/ReceiverTest.cc @@ -37,11 +37,18 @@ using ::testing::NiceMock; using ::testing::Pointee; using ::testing::Return; +/// Helper macro to construct an IpAddress from a numeric number. +#define IP(x) \ + IpAddress \ + { \ + x \ + } + class ReceiverTest : public ::testing::Test { public: ReceiverTest() : mockDriver() - , mockPacket(&payload) + , mockPacket{&payload} , mockPolicyManager(&mockDriver) , payload() , receiver() @@ -68,7 +75,7 @@ class ReceiverTest : public ::testing::Test { static const uint64_t resendIntervalCycles = 100; NiceMock mockDriver; - NiceMock mockPacket; + Homa::Mock::MockDriver::MockPacket mockPacket; NiceMock mockPolicyManager; char payload[1028]; Receiver* receiver; @@ -105,21 +112,21 @@ TEST_F(ReceiverTest, handleDataPacket) header->totalLength = totalMessageLength; header->policyVersion = policyVersion; header->unscheduledIndexLimit = 1; - mockPacket.address = Driver::Address(22); + IpAddress sourceIp{22}; // ------------------------------------------------------------------------- // Receive packet[1]. New message. header->index = 1; mockPacket.length = HEADER_SIZE + 1000; EXPECT_CALL(mockPolicyManager, - signalNewMessage(Eq(mockPacket.address), Eq(policyVersion), + signalNewMessage(Eq(sourceIp), Eq(policyVersion), Eq(totalMessageLength))) .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(0); // TEST CALL - receiver->handleDataPacket(&mockPacket, &mockDriver); + receiver->handleDataPacket(&mockPacket, sourceIp); // --------- { @@ -148,7 +155,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(1); // TEST CALL - receiver->handleDataPacket(&mockPacket, &mockDriver); + receiver->handleDataPacket(&mockPacket, sourceIp); // --------- EXPECT_EQ(1U, message->numPackets); @@ -162,7 +169,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(0); // TEST CALL - receiver->handleDataPacket(&mockPacket, &mockDriver); + receiver->handleDataPacket(&mockPacket, sourceIp); // --------- EXPECT_EQ(2U, message->numPackets); @@ -177,7 +184,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(0); // TEST CALL - receiver->handleDataPacket(&mockPacket, &mockDriver); + receiver->handleDataPacket(&mockPacket, sourceIp); // --------- EXPECT_EQ(3U, message->numPackets); @@ -192,7 +199,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(0); // TEST CALL - receiver->handleDataPacket(&mockPacket, &mockDriver); + receiver->handleDataPacket(&mockPacket, sourceIp); // --------- EXPECT_EQ(4U, message->numPackets); @@ -207,7 +214,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(1); // TEST CALL - receiver->handleDataPacket(&mockPacket, &mockDriver); + receiver->handleDataPacket(&mockPacket, sourceIp); // --------- Mock::VerifyAndClearExpectations(&mockDriver); @@ -217,7 +224,7 @@ TEST_F(ReceiverTest, handleBusyPacket_basic) { Protocol::MessageId id(42, 32); Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(0), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{0, 60001}, 0); Receiver::MessageBucket* bucket = receiver->messageBuckets.getBucket(id); bucket->messages.push_back(&message->bucketNode); @@ -228,7 +235,7 @@ TEST_F(ReceiverTest, handleBusyPacket_basic) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - receiver->handleBusyPacket(&mockPacket, &mockDriver); + receiver->handleBusyPacket(&mockPacket); EXPECT_EQ(11000U, message->messageTimeout.expirationCycleTime); EXPECT_EQ(10100U, message->resendTimeout.expirationCycleTime); @@ -245,15 +252,15 @@ TEST_F(ReceiverTest, handleBusyPacket_unknown) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - receiver->handleBusyPacket(&mockPacket, &mockDriver); + receiver->handleBusyPacket(&mockPacket); } TEST_F(ReceiverTest, handlePingPacket_basic) { Protocol::MessageId id(42, 32); - Driver::Address mockAddress = 22; + IpAddress mockAddress{22}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 20000, id, mockAddress, 0); + receiver, &mockDriver, 0, 20000, id, SocketAddress{mockAddress, 0}, 0); ASSERT_TRUE(message->scheduled); Receiver::ScheduledMessageInfo* info = &message->scheduledMessageInfo; info->bytesGranted = 500; @@ -263,25 +270,25 @@ TEST_F(ReceiverTest, handlePingPacket_basic) bucket->messages.push_back(&message->bucketNode); char pingPayload[1028]; - Homa::Mock::MockDriver::MockPacket pingPacket(pingPayload); - pingPacket.address = mockAddress; + Homa::Mock::MockDriver::MockPacket pingPacket{pingPayload}; + IpAddress sourceIp = mockAddress; Protocol::Packet::PingHeader* pingHeader = (Protocol::Packet::PingHeader*)pingPacket.payload; pingHeader->common.messageId = id; EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), Eq(mockAddress), _)) + .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&pingPacket), Eq(1))) .Times(1); - receiver->handlePingPacket(&pingPacket, &mockDriver); + receiver->handlePingPacket(&pingPacket, sourceIp); EXPECT_EQ(11000U, message->messageTimeout.expirationCycleTime); EXPECT_EQ(0U, message->resendTimeout.expirationCycleTime); - EXPECT_EQ(mockAddress, mockPacket.address); Protocol::Packet::GrantHeader* header = (Protocol::Packet::GrantHeader*)payload; EXPECT_EQ(Protocol::Packet::GRANT, header->common.opcode); @@ -295,22 +302,22 @@ TEST_F(ReceiverTest, handlePingPacket_unknown) Protocol::MessageId id(42, 32); char pingPayload[1028]; - Homa::Mock::MockDriver::MockPacket pingPacket(pingPayload); - pingPacket.address = (Driver::Address)22; + Homa::Mock::MockDriver::MockPacket pingPacket{pingPayload}; + IpAddress mockAddress{22}; Protocol::Packet::PingHeader* pingHeader = (Protocol::Packet::PingHeader*)pingPacket.payload; pingHeader->common.messageId = id; EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), Eq(mockAddress), _)) + .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&pingPacket), Eq(1))) .Times(1); - receiver->handlePingPacket(&pingPacket, &mockDriver); + receiver->handlePingPacket(&pingPacket, mockAddress); - EXPECT_EQ(pingPacket.address, mockPacket.address); Protocol::Packet::UnknownHeader* header = (Protocol::Packet::UnknownHeader*)payload; EXPECT_EQ(Protocol::Packet::UNKNOWN, header->common.opcode); @@ -321,10 +328,10 @@ TEST_F(ReceiverTest, receiveMessage) { Receiver::Message* msg0 = receiver->messageAllocator.pool.construct( receiver, &mockDriver, 0, 0, Protocol::MessageId(42, 0), - Driver::Address(22), 0); + SocketAddress{22, 60001}, 0); Receiver::Message* msg1 = receiver->messageAllocator.pool.construct( receiver, &mockDriver, 0, 0, Protocol::MessageId(42, 0), - Driver::Address(22), 0); + SocketAddress{22, 60001}, 0); receiver->receivedMessages.queue.push_back(&msg0->receivedMessageNode); receiver->receivedMessages.queue.push_back(&msg1->receivedMessageNode); @@ -348,38 +355,26 @@ TEST_F(ReceiverTest, poll) TEST_F(ReceiverTest, checkTimeouts) { - Receiver::Message message(receiver, &mockDriver, 0, 0, - Protocol::MessageId(0, 0), Driver::Address(0), 0); Receiver::MessageBucket* bucket = receiver->messageBuckets.buckets.at(0); - bucket->resendTimeouts.setTimeout(&message.resendTimeout); - bucket->messageTimeouts.setTimeout(&message.messageTimeout); - message.resendTimeout.expirationCycleTime = 10010; - message.messageTimeout.expirationCycleTime = 10020; + EXPECT_EQ(0, receiver->nextBucketIndex.load()); - EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); - EXPECT_EQ(10010U, receiver->checkTimeouts()); - - message.resendTimeout.expirationCycleTime = 10030; + receiver->checkTimeouts(); - EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); - EXPECT_EQ(10020U, receiver->checkTimeouts()); - - bucket->resendTimeouts.cancelTimeout(&message.resendTimeout); - bucket->messageTimeouts.cancelTimeout(&message.messageTimeout); + EXPECT_EQ(1, receiver->nextBucketIndex.load()); } TEST_F(ReceiverTest, Message_destructor_basic) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); const uint16_t NUM_PKTS = 5; message->numPackets = NUM_PKTS; for (int i = 0; i < NUM_PKTS; ++i) { - message->occupied.set(i); + message->received.set(i); } EXPECT_CALL(mockDriver, releasePackets(Eq(message->packets), Eq(NUM_PKTS))) @@ -392,15 +387,15 @@ TEST_F(ReceiverTest, Message_destructor_holes) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); const uint16_t NUM_PKTS = 4; message->numPackets = NUM_PKTS; - message->occupied.set(0); - message->occupied.set(1); - message->occupied.set(3); - message->occupied.set(4); + message->received.set(0); + message->received.set(1); + message->received.set(3); + message->received.set(4); EXPECT_CALL(mockDriver, releasePackets(Eq(&message->packets[0]), Eq(2))) .Times(1); @@ -414,10 +409,12 @@ TEST_F(ReceiverTest, Message_acknowledge) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + EXPECT_CALL(mockDriver, + sendPacket(Eq(&mockPacket), Eq(message->source.ip), _)) + .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); @@ -428,14 +425,13 @@ TEST_F(ReceiverTest, Message_acknowledge) EXPECT_EQ(Protocol::Packet::DONE, header->opcode); EXPECT_EQ(id, header->messageId); EXPECT_EQ(sizeof(Protocol::Packet::DoneHeader), mockPacket.length); - EXPECT_EQ(message->source, mockPacket.address); } TEST_F(ReceiverTest, Message_dropped) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); message->state = Receiver::Message::State::IN_PROGRESS; @@ -450,10 +446,12 @@ TEST_F(ReceiverTest, Message_fail) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + EXPECT_CALL(mockDriver, + sendPacket(Eq(&mockPacket), Eq(message->source.ip), _)) + .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); @@ -464,7 +462,6 @@ TEST_F(ReceiverTest, Message_fail) EXPECT_EQ(Protocol::Packet::ERROR, header->opcode); EXPECT_EQ(id, header->messageId); EXPECT_EQ(sizeof(Protocol::Packet::ErrorHeader), mockPacket.length); - EXPECT_EQ(message->source, mockPacket.address); } TEST_F(ReceiverTest, Message_get_basic) @@ -472,10 +469,10 @@ TEST_F(ReceiverTest, Message_get_basic) ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(2048)); Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 24, 24 + 2007, id, Driver::Address(22), 0); + receiver, &mockDriver, 24, 24 + 2007, id, SocketAddress{22, 60001}, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + 2048); + Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1{buf + 2048}; char source[] = "Hello, world!"; message->setPacket(0, &packet0); @@ -499,10 +496,10 @@ TEST_F(ReceiverTest, Message_get_offsetTooLarge) ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(2048)); Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 24, 24 + 2007, id, Driver::Address(22), 0); + receiver, &mockDriver, 24, 24 + 2007, id, SocketAddress{22, 60001}, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + 2048); + Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1{buf + 2048}; message->setPacket(0, &packet0); message->setPacket(1, &packet1); @@ -525,10 +522,10 @@ TEST_F(ReceiverTest, Message_get_missingPacket) ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(2048)); Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 24, 24 + 2007, id, Driver::Address(22), 0); + receiver, &mockDriver, 24, 24 + 2007, id, SocketAddress{22, 60001}, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + 2048); + Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1{buf + 2048}; char source[] = "Hello,"; message->setPacket(0, &packet0); @@ -557,7 +554,7 @@ TEST_F(ReceiverTest, Message_length) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); message->messageLength = 200; message->start = 20; EXPECT_EQ(180U, message->length()); @@ -567,7 +564,7 @@ TEST_F(ReceiverTest, Message_strip) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); message->messageLength = 30; message->start = 0; @@ -589,14 +586,14 @@ TEST_F(ReceiverTest, Message_getPacket) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); Driver::Packet* packet = (Driver::Packet*)42; message->packets[0] = packet; EXPECT_EQ(nullptr, message->getPacket(0)); - message->occupied.set(0); + message->received.set(0); EXPECT_EQ(packet, message->getPacket(0)); } @@ -605,16 +602,16 @@ TEST_F(ReceiverTest, Message_setPacket) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); Driver::Packet* packet = (Driver::Packet*)42; - EXPECT_FALSE(message->occupied.test(0)); + EXPECT_FALSE(message->received.test(0)); EXPECT_EQ(0U, message->numPackets); EXPECT_TRUE(message->setPacket(0, packet)); EXPECT_EQ(packet, message->packets[0]); - EXPECT_TRUE(message->occupied.test(0)); + EXPECT_TRUE(message->received.test(0)); EXPECT_EQ(1U, message->numPackets); EXPECT_FALSE(message->setPacket(0, packet)); @@ -626,12 +623,12 @@ TEST_F(ReceiverTest, MessageBucket_findMessage) Protocol::MessageId id0 = {42, 0}; Receiver::Message* msg0 = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), 0, id0, 0, - 0); + receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), 0, id0, + SocketAddress{0, 60001}, 0); Protocol::MessageId id1 = {42, 1}; Receiver::Message* msg1 = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), 0, id1, - Driver::Address(0), 0); + SocketAddress{0, 60001}, 0); Protocol::MessageId id_none = {42, 42}; bucket->messages.push_back(&msg0->bucketNode); @@ -659,7 +656,7 @@ TEST_F(ReceiverTest, dropMessage) SpinLock::Lock dummy(dummyMutex); Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 1000, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 1000, id, SocketAddress{22, 60001}, 0); ASSERT_TRUE(message->scheduled); Receiver::MessageBucket* bucket = receiver->messageBuckets.getBucket(id); @@ -670,7 +667,7 @@ TEST_F(ReceiverTest, dropMessage) EXPECT_EQ(1U, receiver->messageAllocator.pool.outstandingObjects); EXPECT_EQ(message, bucket->findMessage(id, dummy)); - EXPECT_EQ(&receiver->peerTable[message->source], + EXPECT_EQ(&receiver->peerTable[message->source.ip], message->scheduledMessageInfo.peer); EXPECT_FALSE(bucket->messageTimeouts.list.empty()); EXPECT_FALSE(bucket->resendTimeouts.list.empty()); @@ -684,7 +681,7 @@ TEST_F(ReceiverTest, dropMessage) EXPECT_TRUE(bucket->resendTimeouts.list.empty()); } -TEST_F(ReceiverTest, checkMessageTimeouts_basic) +TEST_F(ReceiverTest, checkMessageTimeouts) { void* op[3]; Receiver::Message* message[3]; @@ -693,7 +690,7 @@ TEST_F(ReceiverTest, checkMessageTimeouts_basic) Protocol::MessageId id = {42, 10 + i}; op[i] = reinterpret_cast(i); message[i] = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 1000, id, 0, 0); + receiver, &mockDriver, 0, 1000, id, SocketAddress{0, 60001}, 0); bucket->messages.push_back(&message[i]->bucketNode); bucket->messageTimeouts.setTimeout(&message[i]->messageTimeout); bucket->resendTimeouts.setTimeout(&message[i]->resendTimeout); @@ -716,14 +713,17 @@ TEST_F(ReceiverTest, checkMessageTimeouts_basic) // Message[2]: No timeout message[2]->messageTimeout.expirationCycleTime = 10001; + bucket->messageTimeouts.nextTimeout = 9998; + ASSERT_EQ(10000U, PerfUtils::Cycles::rdtsc()); ASSERT_TRUE(message[0]->messageTimeout.hasElapsed()); ASSERT_TRUE(message[1]->messageTimeout.hasElapsed()); ASSERT_FALSE(message[2]->messageTimeout.hasElapsed()); - uint64_t nextTimeout = receiver->checkMessageTimeouts(); + receiver->checkMessageTimeouts(10000, bucket); - EXPECT_EQ(message[2]->messageTimeout.expirationCycleTime, nextTimeout); + EXPECT_EQ(message[2]->messageTimeout.expirationCycleTime, + bucket->messageTimeouts.nextTimeout.load()); // Message[0]: Normal timeout: IN_PROGRESS EXPECT_EQ(nullptr, message[0]->messageTimeout.node.list); @@ -748,26 +748,14 @@ TEST_F(ReceiverTest, checkMessageTimeouts_basic) EXPECT_EQ(2U, receiver->messageAllocator.pool.outstandingObjects); } -TEST_F(ReceiverTest, checkMessageTimeouts_empty) -{ - for (int i = 0; i < Receiver::MessageBucketMap::NUM_BUCKETS; ++i) { - Receiver::MessageBucket* bucket = - receiver->messageBuckets.buckets.at(i); - EXPECT_TRUE(bucket->messageTimeouts.list.empty()); - } - EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); - uint64_t nextTimeout = receiver->checkMessageTimeouts(); - EXPECT_EQ(10000 + messageTimeoutCycles, nextTimeout); -} - -TEST_F(ReceiverTest, checkResendTimeouts_basic) +TEST_F(ReceiverTest, checkResendTimeouts) { Receiver::Message* message[3]; Receiver::MessageBucket* bucket = receiver->messageBuckets.buckets.at(0); for (uint64_t i = 0; i < 3; ++i) { Protocol::MessageId id = {42, 10 + i}; message[i] = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 10000, id, Driver::Address(22), 5); + receiver, &mockDriver, 0, 10000, id, SocketAddress{22, 60001}, 5); bucket->resendTimeouts.setTimeout(&message[i]->resendTimeout); } @@ -799,27 +787,34 @@ TEST_F(ReceiverTest, checkResendTimeouts_basic) ASSERT_EQ(Receiver::Message::State::IN_PROGRESS, message[2]->state); message[2]->resendTimeout.expirationCycleTime = 10001; + bucket->resendTimeouts.nextTimeout = 9999; + EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); char buf1[1024]; char buf2[1024]; - Homa::Mock::MockDriver::MockPacket mockResendPacket1(buf1); - Homa::Mock::MockDriver::MockPacket mockResendPacket2(buf2); + Homa::Mock::MockDriver::MockPacket mockResendPacket1{buf1}; + Homa::Mock::MockDriver::MockPacket mockResendPacket2{buf2}; EXPECT_CALL(mockDriver, allocPacket()) .WillOnce(Return(&mockResendPacket1)) .WillOnce(Return(&mockResendPacket2)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockResendPacket1))).Times(1); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockResendPacket2))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockResendPacket1), + Eq(message[0]->source.ip), _)) + .Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockResendPacket2), + Eq(message[0]->source.ip), _)) + .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockResendPacket1), Eq(1))) .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockResendPacket2), Eq(1))) .Times(1); // TEST CALL - uint64_t nextTimeout = receiver->checkResendTimeouts(); + receiver->checkResendTimeouts(10000, bucket); - EXPECT_EQ(message[2]->resendTimeout.expirationCycleTime, nextTimeout); + EXPECT_EQ(message[2]->resendTimeout.expirationCycleTime, + bucket->resendTimeouts.nextTimeout.load()); // Message[0]: Normal timeout: resends EXPECT_EQ(10100, message[0]->resendTimeout.expirationCycleTime); @@ -830,7 +825,6 @@ TEST_F(ReceiverTest, checkResendTimeouts_basic) EXPECT_EQ(2U, header1->index); EXPECT_EQ(4U, header1->num); EXPECT_EQ(sizeof(Protocol::Packet::ResendHeader), mockResendPacket1.length); - EXPECT_EQ(message[0]->source, mockResendPacket1.address); Protocol::Packet::ResendHeader* header2 = static_cast(mockResendPacket2.payload); EXPECT_EQ(Protocol::Packet::RESEND, header2->common.opcode); @@ -838,7 +832,6 @@ TEST_F(ReceiverTest, checkResendTimeouts_basic) EXPECT_EQ(8U, header2->index); EXPECT_EQ(2U, header2->num); EXPECT_EQ(sizeof(Protocol::Packet::ResendHeader), mockResendPacket2.length); - EXPECT_EQ(message[0]->source, mockResendPacket2.address); // Message[1]: Blocked on grants EXPECT_EQ(10100, message[1]->resendTimeout.expirationCycleTime); @@ -847,27 +840,16 @@ TEST_F(ReceiverTest, checkResendTimeouts_basic) EXPECT_EQ(10001, message[2]->resendTimeout.expirationCycleTime); } -TEST_F(ReceiverTest, checkResendTimeouts_empty) -{ - for (int i = 0; i < Receiver::MessageBucketMap::NUM_BUCKETS; ++i) { - Receiver::MessageBucket* bucket = - receiver->messageBuckets.buckets.at(i); - EXPECT_TRUE(bucket->resendTimeouts.list.empty()); - } - EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); - uint64_t nextTimeout = receiver->checkResendTimeouts(); - EXPECT_EQ(10000 + resendIntervalCycles, nextTimeout); -} - TEST_F(ReceiverTest, trySendGrants) { Receiver::Message* message[4]; Receiver::ScheduledMessageInfo* info[4]; - for (uint64_t i = 0; i < 4; ++i) { + for (uint32_t i = 0; i < 4; ++i) { Protocol::MessageId id = {42, 10 + i}; message[i] = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), - 10000 * (i + 1), id, Driver::Address(100 + i), 10 * (i + 1)); + 10000 * (i + 1), id, SocketAddress{IP(100 + i), 60001}, + 10 * (i + 1)); { SpinLock::Lock lock_scheduler(receiver->schedulerMutex); receiver->schedule(message[i], lock_scheduler); @@ -894,7 +876,7 @@ TEST_F(ReceiverTest, trySendGrants) EXPECT_CALL(mockPolicyManager, getScheduledPolicy()) .WillOnce(Return(policy)); EXPECT_CALL(mockDriver, allocPacket).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), _, _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); @@ -920,7 +902,7 @@ TEST_F(ReceiverTest, trySendGrants) EXPECT_CALL(mockPolicyManager, getScheduledPolicy()) .WillOnce(Return(policy)); EXPECT_CALL(mockDriver, allocPacket).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), _, _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); @@ -941,7 +923,7 @@ TEST_F(ReceiverTest, trySendGrants) policy.maxScheduledBytes = 10000; EXPECT_CALL(mockPolicyManager, getScheduledPolicy()) .WillOnce(Return(policy)); - EXPECT_CALL(mockDriver, sendPacket(_)).Times(0); + EXPECT_CALL(mockDriver, sendPacket(_, _, _)).Times(0); receiver->trySendGrants(); @@ -960,7 +942,7 @@ TEST_F(ReceiverTest, trySendGrants) policy.maxScheduledBytes = 10000; EXPECT_CALL(mockPolicyManager, getScheduledPolicy()) .WillOnce(Return(policy)); - EXPECT_CALL(mockDriver, sendPacket(_)).Times(0); + EXPECT_CALL(mockDriver, sendPacket(_, _, _)).Times(0); receiver->trySendGrants(); @@ -975,13 +957,13 @@ TEST_F(ReceiverTest, schedule) { Receiver::Message* message[4]; Receiver::ScheduledMessageInfo* info[4]; - Driver::Address address[4] = {22, 33, 33, 22}; + IpAddress address[4] = {22, 33, 33, 22}; int messageLength[4] = {2000, 3000, 1000, 4000}; for (uint64_t i = 0; i < 4; ++i) { Protocol::MessageId id = {42, 10 + i}; message[i] = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), - messageLength[i], id, address[i], 0); + messageLength[i], id, SocketAddress{address[i], 60001}, 0); info[i] = &message[i]->scheduledMessageInfo; } @@ -994,7 +976,7 @@ TEST_F(ReceiverTest, schedule) receiver->schedule(message[0], lock); - EXPECT_EQ(&receiver->peerTable.at(22), info[0]->peer); + EXPECT_EQ(&receiver->peerTable.at(IP(22)), info[0]->peer); EXPECT_EQ(message[0], &info[0]->peer->scheduledMessages.front()); EXPECT_EQ(info[0]->peer, &receiver->scheduledPeers.front()); @@ -1006,7 +988,7 @@ TEST_F(ReceiverTest, schedule) receiver->schedule(message[1], lock); - EXPECT_EQ(&receiver->peerTable.at(33), info[1]->peer); + EXPECT_EQ(&receiver->peerTable.at(IP(33)), info[1]->peer); EXPECT_EQ(message[1], &info[1]->peer->scheduledMessages.front()); EXPECT_EQ(info[1]->peer, &receiver->scheduledPeers.back()); @@ -1018,7 +1000,7 @@ TEST_F(ReceiverTest, schedule) receiver->schedule(message[2], lock); - EXPECT_EQ(&receiver->peerTable.at(33), info[2]->peer); + EXPECT_EQ(&receiver->peerTable.at(IP(33)), info[2]->peer); EXPECT_EQ(message[2], &info[2]->peer->scheduledMessages.front()); EXPECT_EQ(info[2]->peer, &receiver->scheduledPeers.front()); @@ -1030,7 +1012,7 @@ TEST_F(ReceiverTest, schedule) receiver->schedule(message[3], lock); - EXPECT_EQ(&receiver->peerTable.at(22), info[3]->peer); + EXPECT_EQ(&receiver->peerTable.at(IP(22)), info[3]->peer); EXPECT_EQ(message[3], &info[3]->peer->scheduledMessages.back()); EXPECT_EQ(info[3]->peer, &receiver->scheduledPeers.back()); } @@ -1041,23 +1023,24 @@ TEST_F(ReceiverTest, unschedule) Receiver::ScheduledMessageInfo* info[5]; SpinLock::Lock lock(receiver->schedulerMutex); int messageLength[5] = {10, 20, 30, 10, 20}; - for (uint64_t i = 0; i < 5; ++i) { + for (uint32_t i = 0; i < 5; ++i) { Protocol::MessageId id = {42, 10 + i}; - Driver::Address source = Driver::Address((i / 3) + 10); + IpAddress source = IP((i / 3) + 10); message[i] = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), - messageLength[i], id, source, 0); + messageLength[i], id, SocketAddress{source, 60001}, 0); info[i] = &message[i]->scheduledMessageInfo; receiver->schedule(message[i], lock); } + auto& scheduledPeers = receiver->scheduledPeers; - ASSERT_EQ(Driver::Address(10), message[0]->source); - ASSERT_EQ(Driver::Address(10), message[1]->source); - ASSERT_EQ(Driver::Address(10), message[2]->source); - ASSERT_EQ(Driver::Address(11), message[3]->source); - ASSERT_EQ(Driver::Address(11), message[4]->source); - ASSERT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(10)); - ASSERT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); + ASSERT_EQ(IP(10), message[0]->source.ip); + ASSERT_EQ(IP(10), message[1]->source.ip); + ASSERT_EQ(IP(10), message[2]->source.ip); + ASSERT_EQ(IP(11), message[3]->source.ip); + ASSERT_EQ(IP(11), message[4]->source.ip); + ASSERT_EQ(&scheduledPeers.front(), &receiver->peerTable.at(IP(10))); + ASSERT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(11))); // <10>: [0](10) -> [1](20) -> [2](30) // <11>: [3](10) -> [4](20) @@ -1075,10 +1058,10 @@ TEST_F(ReceiverTest, unschedule) receiver->unschedule(message[4], lock); EXPECT_EQ(nullptr, info[4]->peer); - EXPECT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(10)); - EXPECT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); - EXPECT_EQ(3U, receiver->peerTable.at(10).scheduledMessages.size()); - EXPECT_EQ(1U, receiver->peerTable.at(11).scheduledMessages.size()); + EXPECT_EQ(&scheduledPeers.front(), &receiver->peerTable.at(IP(10))); + EXPECT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(11))); + EXPECT_EQ(3U, receiver->peerTable.at(IP(10)).scheduledMessages.size()); + EXPECT_EQ(1U, receiver->peerTable.at(IP(11)).scheduledMessages.size()); //-------------------------------------------------------------------------- // Remove message[1]; peer in correct position. @@ -1088,10 +1071,10 @@ TEST_F(ReceiverTest, unschedule) receiver->unschedule(message[1], lock); EXPECT_EQ(nullptr, info[1]->peer); - EXPECT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(10)); - EXPECT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); - EXPECT_EQ(2U, receiver->peerTable.at(10).scheduledMessages.size()); - EXPECT_EQ(1U, receiver->peerTable.at(11).scheduledMessages.size()); + EXPECT_EQ(&scheduledPeers.front(), &receiver->peerTable.at(IP(10))); + EXPECT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(11))); + EXPECT_EQ(2U, receiver->peerTable.at(IP(10)).scheduledMessages.size()); + EXPECT_EQ(1U, receiver->peerTable.at(IP(11)).scheduledMessages.size()); //-------------------------------------------------------------------------- // Remove message[0]; peer needs to be reordered. @@ -1101,10 +1084,10 @@ TEST_F(ReceiverTest, unschedule) receiver->unschedule(message[0], lock); EXPECT_EQ(nullptr, info[0]->peer); - EXPECT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(11)); - EXPECT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(10)); - EXPECT_EQ(1U, receiver->peerTable.at(11).scheduledMessages.size()); - EXPECT_EQ(1U, receiver->peerTable.at(10).scheduledMessages.size()); + EXPECT_EQ(&scheduledPeers.front(), &receiver->peerTable.at(IP(11))); + EXPECT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(10))); + EXPECT_EQ(1U, receiver->peerTable.at(IP(11)).scheduledMessages.size()); + EXPECT_EQ(1U, receiver->peerTable.at(IP(10)).scheduledMessages.size()); //-------------------------------------------------------------------------- // Remove message[3]; peer needs to be removed. @@ -1113,10 +1096,10 @@ TEST_F(ReceiverTest, unschedule) receiver->unschedule(message[3], lock); EXPECT_EQ(nullptr, info[3]->peer); - EXPECT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(10)); - EXPECT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(10)); - EXPECT_EQ(1U, receiver->peerTable.at(10).scheduledMessages.size()); - EXPECT_EQ(0U, receiver->peerTable.at(11).scheduledMessages.size()); + EXPECT_EQ(&scheduledPeers.front(), &receiver->peerTable.at(IP(10))); + EXPECT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(10))); + EXPECT_EQ(1U, receiver->peerTable.at(IP(10)).scheduledMessages.size()); + EXPECT_EQ(0U, receiver->peerTable.at(IP(11)).scheduledMessages.size()); } TEST_F(ReceiverTest, updateSchedule) @@ -1125,25 +1108,26 @@ TEST_F(ReceiverTest, updateSchedule) // 11 : [20][30] SpinLock::Lock lock(receiver->schedulerMutex); Receiver::Message* other[3]; - for (uint64_t i = 0; i < 3; ++i) { + for (uint32_t i = 0; i < 3; ++i) { Protocol::MessageId id = {42, 10 + i}; int messageLength = 10 * (i + 1); - Driver::Address source = Driver::Address(((i + 1) / 2) + 10); + IpAddress source = IP(((i + 1) / 2) + 10); other[i] = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), - 10 * (i + 1), id, source, 0); + 10 * (i + 1), id, SocketAddress{source, 60001}, 0); receiver->schedule(other[i], lock); } Receiver::Message* message = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), 100, - Protocol::MessageId(42, 1), Driver::Address(11), 0); + Protocol::MessageId(42, 1), SocketAddress{11, 60001}, 0); receiver->schedule(message, lock); - ASSERT_EQ(&receiver->peerTable.at(10), other[0]->scheduledMessageInfo.peer); - ASSERT_EQ(&receiver->peerTable.at(11), other[1]->scheduledMessageInfo.peer); - ASSERT_EQ(&receiver->peerTable.at(11), other[2]->scheduledMessageInfo.peer); - ASSERT_EQ(&receiver->peerTable.at(11), message->scheduledMessageInfo.peer); - ASSERT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(10)); - ASSERT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); + auto& peerTable = receiver->peerTable; + ASSERT_EQ(&peerTable.at(IP(10)), other[0]->scheduledMessageInfo.peer); + ASSERT_EQ(&peerTable.at(IP(11)), other[1]->scheduledMessageInfo.peer); + ASSERT_EQ(&peerTable.at(IP(11)), other[2]->scheduledMessageInfo.peer); + ASSERT_EQ(&peerTable.at(IP(11)), message->scheduledMessageInfo.peer); + ASSERT_EQ(&receiver->scheduledPeers.front(), &peerTable.at(IP(10))); + ASSERT_EQ(&receiver->scheduledPeers.back(), &peerTable.at(IP(11))); //-------------------------------------------------------------------------- // Move message up within peer. @@ -1153,11 +1137,12 @@ TEST_F(ReceiverTest, updateSchedule) receiver->updateSchedule(message, lock); - EXPECT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); + auto& scheduledPeers = receiver->scheduledPeers; + EXPECT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(11))); Receiver::Peer* peer = &receiver->scheduledPeers.back(); auto it = peer->scheduledMessages.begin(); EXPECT_TRUE( - std::next(receiver->peerTable.at(11).scheduledMessages.begin()) == + std::next(receiver->peerTable.at(IP(11)).scheduledMessages.begin()) == message->scheduledMessageInfo.peer->scheduledMessages.get( &message->scheduledMessageInfo.scheduledMessageNode)); @@ -1169,8 +1154,8 @@ TEST_F(ReceiverTest, updateSchedule) receiver->updateSchedule(message, lock); - EXPECT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); - EXPECT_EQ(receiver->peerTable.at(11).scheduledMessages.begin(), + EXPECT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(11))); + EXPECT_EQ(receiver->peerTable.at(IP(11)).scheduledMessages.begin(), message->scheduledMessageInfo.peer->scheduledMessages.get( &message->scheduledMessageInfo.scheduledMessageNode)); @@ -1182,8 +1167,8 @@ TEST_F(ReceiverTest, updateSchedule) receiver->updateSchedule(message, lock); - EXPECT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(11)); - EXPECT_EQ(receiver->peerTable.at(11).scheduledMessages.begin(), + EXPECT_EQ(&scheduledPeers.front(), &receiver->peerTable.at(IP(11))); + EXPECT_EQ(receiver->peerTable.at(IP(11)).scheduledMessages.begin(), message->scheduledMessageInfo.peer->scheduledMessages.get( &message->scheduledMessageInfo.scheduledMessageNode)); } diff --git a/src/Sender.cc b/src/Sender.cc index c2d0c3f..646bc25 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -50,27 +50,28 @@ Sender::Sender(uint64_t transportId, Driver* driver, , policyManager(policyManager) , nextMessageSequenceNumber(1) , DRIVER_QUEUED_BYTE_LIMIT(2 * driver->getMaxPayloadSize()) - , messageBuckets(messageTimeoutCycles, pingIntervalCycles) + , MESSAGE_TIMEOUT_INTERVALS( + Util::roundUpIntDiv(messageTimeoutCycles, pingIntervalCycles)) + , messageBuckets(this, pingIntervalCycles) , queueMutex() , sendQueue() , sending() , sendReady(false) + , sentMessages() + , nextBucketIndex(0) , messageAllocator() {} -/** - * Sender Destructor - */ -Sender::~Sender() {} - /** * Allocate an OutMessage that can be sent with this Sender. */ Homa::OutMessage* -Sender::allocMessage() +Sender::allocMessage(uint16_t sourcePort) { - SpinLock::Lock lock_allocator(messageAllocator.mutex); - return messageAllocator.pool.construct(this, driver); + Perf::counters.allocated_tx_messages.add(1); + uint64_t messageId = + nextMessageSequenceNumber.fetch_add(1, std::memory_order_relaxed); + return messageAllocator.construct(this, messageId, sourcePort); } /** @@ -78,24 +79,19 @@ Sender::allocMessage() * * @param packet * Incoming DONE packet to be processed. - * @param driver - * Driver from which the packet was received and to which it should be - * returned after the packet has been processed. */ void -Sender::handleDonePacket(Driver::Packet* packet, Driver* driver) +Sender::handleDonePacket(Driver::Packet* packet) { - Protocol::Packet::DoneHeader* header = - static_cast(packet->payload); - Protocol::MessageId msgId = header->common.messageId; + Protocol::Packet::CommonHeader* header = + static_cast(packet->payload); + Protocol::MessageId msgId = header->messageId; MessageBucket* bucket = messageBuckets.getBucket(msgId); SpinLock::Lock lock(bucket->mutex); Message* message = bucket->findMessage(msgId, lock); - if (message == nullptr) { - // No message for this DONE packet; must be old. Just drop it. - driver->releasePackets(&packet, 1); + // No message for this DONE packet; must be old. return; } @@ -104,9 +100,7 @@ Sender::handleDonePacket(Driver::Packet* packet, Driver* driver) switch (status) { case OutMessage::Status::SENT: // Expected behavior - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); - message->state.store(OutMessage::Status::COMPLETED); + message->setStatus(OutMessage::Status::COMPLETED, false, lock); break; case OutMessage::Status::CANCELED: // Canceled by the the application; just ignore the DONE. @@ -143,8 +137,6 @@ Sender::handleDonePacket(Driver::Packet* packet, Driver* driver) msgId.transportId, msgId.sequence); break; } - - driver->releasePackets(&packet, 1); } /** @@ -152,12 +144,9 @@ Sender::handleDonePacket(Driver::Packet* packet, Driver* driver) * * @param packet * Incoming RESEND packet to be processed. - * @param driver - * Driver from which the packet was received and to which it should be - * returned after the packet has been processed. */ void -Sender::handleResendPacket(Driver::Packet* packet, Driver* driver) +Sender::handleResendPacket(Driver::Packet* packet) { Protocol::Packet::ResendHeader* header = static_cast(packet->payload); @@ -173,7 +162,6 @@ Sender::handleResendPacket(Driver::Packet* packet, Driver* driver) if (message == nullptr) { // No message for this RESEND; RESEND must be old. Just ignore it; this // case should be pretty rare and the Receiver will timeout eventually. - driver->releasePackets(&packet, 1); return; } else if (message->numPackets < 2) { // We should never get a RESEND for a single packet message. Just @@ -181,31 +169,27 @@ Sender::handleResendPacket(Driver::Packet* packet, Driver* driver) WARNING( "Message (%lu, %lu) with only 1 packet received unexpected RESEND " "request; peer Transport may be confused.", - msgId.transportId, msgId.sequence); - driver->releasePackets(&packet, 1); + message->id.transportId, message->id.sequence); return; } - bucket->messageTimeouts.setTimeout(&message->messageTimeout); + message->numPingTimeouts = 0; bucket->pingTimeouts.setTimeout(&message->pingTimeout); - SpinLock::Lock lock_queue(queueMutex); - QueuedMessageInfo* info = &message->queuedMessageInfo; - // Check if RESEND request is out of range. - if (index >= info->packets->numPackets || - resendEnd > info->packets->numPackets) { + if (index >= message->numPackets || resendEnd > message->numPackets) { WARNING( "Message (%lu, %lu) RESEND request range out of bounds: requested " "range [%d, %d); message only contains %d packets; peer Transport " "may be confused.", - msgId.transportId, msgId.sequence, index, resendEnd, - info->packets->numPackets); - driver->releasePackets(&packet, 1); + message->id.transportId, message->id.sequence, index, resendEnd, + message->numPackets); return; } // In case a GRANT may have been lost, consider the RESEND a GRANT. + QueuedMessageInfo* info = &message->queuedMessageInfo; + SpinLock::UniqueLock lock_queue(queueMutex); if (info->packetsGranted < resendEnd) { info->packetsGranted = resendEnd; // Note that the priority of messages under the unscheduled byte limit @@ -215,30 +199,30 @@ Sender::handleResendPacket(Driver::Packet* packet, Driver* driver) sendReady.store(true); } - if (index >= info->packetsSent) { + // Release the queue mutex; the rest of the code won't touch the queue. + int packetsSent = info->packetsSent; + lock_queue.unlock(); + if (index >= packetsSent) { // If this RESEND is only requesting unsent packets, it must be that // this Sender has been busy and the Receiver is trying to ensure there // are no lost packets. Reply BUSY and allow this Sender to send DATA // when it's ready. Perf::counters.tx_busy_pkts.add(1); ControlPacket::send( - driver, info->destination, info->id); + driver, message->destination.ip, message->id); } else { // There are some packets to resend but only resend packets that have // already been sent. - resendEnd = std::min(resendEnd, info->packetsSent); + resendEnd = std::min(resendEnd, packetsSent); int resendPriority = policyManager->getResendPriority(); - for (uint16_t i = index; i < resendEnd; ++i) { - Driver::Packet* packet = info->packets->getPacket(i); - packet->priority = resendPriority; - // Packets will be sent at the priority their original priority. + for (int i = index; i < resendEnd; ++i) { + Driver::Packet* resendPacket = message->getPacket(i); Perf::counters.tx_data_pkts.add(1); - Perf::counters.tx_bytes.add(packet->length); - driver->sendPacket(packet); + Perf::counters.tx_bytes.add(resendPacket->length); + driver->sendPacket(resendPacket, message->destination.ip, + resendPriority); } } - - driver->releasePackets(&packet, 1); } /** @@ -246,12 +230,9 @@ Sender::handleResendPacket(Driver::Packet* packet, Driver* driver) * * @param packet * Incoming GRANT packet to be processed. - * @param driver - * Driver from which the packet was received and to which it should be - * returned after the packet has been processed. */ void -Sender::handleGrantPacket(Driver::Packet* packet, Driver* driver) +Sender::handleGrantPacket(Driver::Packet* packet) { Protocol::Packet::GrantHeader* header = static_cast(packet->payload); @@ -261,48 +242,44 @@ Sender::handleGrantPacket(Driver::Packet* packet, Driver* driver) SpinLock::Lock lock(bucket->mutex); Message* message = bucket->findMessage(msgId, lock); if (message == nullptr) { - // No message for this grant; grant must be old. Just drop it. - driver->releasePackets(&packet, 1); + // No message for this grant; grant must be old. return; } - - bucket->messageTimeouts.setTimeout(&message->messageTimeout); + message->numPingTimeouts = 0; bucket->pingTimeouts.setTimeout(&message->pingTimeout); - if (message->state.load() == OutMessage::Status::IN_PROGRESS) { - SpinLock::Lock lock_queue(queueMutex); - QueuedMessageInfo* info = &message->queuedMessageInfo; - + Protocol::Packet::GrantHeader* grantHeader = + static_cast(packet->payload); + if (message->getStatus() == OutMessage::Status::IN_PROGRESS) { // Convert the byteLimit to a packet index limit such that the packet // that holds the last granted byte is also considered granted. This // can cause at most 1 packet worth of data to be sent without a grant // but allows the sender to always send full packets. - int incomingGrantIndex = - (header->byteLimit + info->packets->PACKET_DATA_LENGTH - 1) / - info->packets->PACKET_DATA_LENGTH; + int incomingGrantIndex = Util::roundUpIntDiv( + grantHeader->byteLimit, message->PACKET_DATA_LENGTH); // Make that grants don't exceed the number of packets. Internally, // the sender always assumes that packetsGranted <= numPackets. - if (incomingGrantIndex > info->packets->numPackets) { + if (incomingGrantIndex > message->numPackets) { WARNING( "Message (%lu, %lu) GRANT exceeds message length; granted " "packets: %d, message packets %d; extra grants are ignored.", - msgId.transportId, msgId.sequence, incomingGrantIndex, - info->packets->numPackets); - incomingGrantIndex = info->packets->numPackets; + message->id.transportId, message->id.sequence, + incomingGrantIndex, message->numPackets); + incomingGrantIndex = message->numPackets; } + QueuedMessageInfo* info = &message->queuedMessageInfo; + SpinLock::Lock lock_queue(queueMutex); if (info->packetsGranted < incomingGrantIndex) { info->packetsGranted = incomingGrantIndex; // Note that the priority of messages under the unscheduled byte // limit will never be overridden since the incomingGrantIndex will // not exceed the preset packetsGranted. - info->priority = header->priority; + info->priority = grantHeader->priority; sendReady.store(true); } } - - driver->releasePackets(&packet, 1); } /** @@ -310,12 +287,9 @@ Sender::handleGrantPacket(Driver::Packet* packet, Driver* driver) * * @param packet * Incoming UNKNOWN packet to be processed. - * @param driver - * Driver from which the packet was received and to which it should be - * returned after the packet has been processed. */ void -Sender::handleUnknownPacket(Driver::Packet* packet, Driver* driver) +Sender::handleUnknownPacket(Driver::Packet* packet) { Protocol::Packet::UnknownHeader* header = static_cast(packet->payload); @@ -326,8 +300,7 @@ Sender::handleUnknownPacket(Driver::Packet* packet, Driver* driver) Message* message = bucket->findMessage(msgId, lock); if (message == nullptr) { - // No message was found. Just drop the packet. - driver->releasePackets(&packet, 1); + // No message was found. return; } @@ -339,99 +312,13 @@ Sender::handleUnknownPacket(Driver::Packet* packet, Driver* driver) // must be a stale response to a ping. } else if (message->options & OutMessage::Options::NO_RETRY) { // Option: NO_RETRY - // Either the Message or the DONE packet was lost; consider the message // failed since the application asked for the message not to be retried. - - // Remove Message from sendQueue. - if (message->numPackets > 1) { - SpinLock::Lock lock_queue(queueMutex); - QueuedMessageInfo* info = &message->queuedMessageInfo; - if (message->state == OutMessage::Status::IN_PROGRESS) { - assert(sendQueue.contains(&info->sendQueueNode)); - sendQueue.remove(&info->sendQueueNode); - } - assert(!sendQueue.contains(&info->sendQueueNode)); - } - - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); - message->state.store(OutMessage::Status::FAILED); + message->setStatus(OutMessage::Status::FAILED, true, lock); } else { // Message isn't done yet so we will restart sending the message. - - // Make sure the message is not in the sendQueue before making any - // changes to the message. - if (message->numPackets > 1) { - SpinLock::Lock lock_queue(queueMutex); - QueuedMessageInfo* info = &message->queuedMessageInfo; - if (message->state == OutMessage::Status::IN_PROGRESS) { - assert(sendQueue.contains(&info->sendQueueNode)); - sendQueue.remove(&info->sendQueueNode); - } - assert(!sendQueue.contains(&info->sendQueueNode)); - } - - message->state.store(OutMessage::Status::IN_PROGRESS); - - // Get the current policy for unscheduled bytes. - Policy::Unscheduled policy = policyManager->getUnscheduledPolicy( - message->destination, message->messageLength); - int unscheduledIndexLimit = - ((policy.unscheduledByteLimit + message->PACKET_DATA_LENGTH - 1) / - message->PACKET_DATA_LENGTH); - - // Update the policy version for each packet - for (uint16_t i = 0; i < message->numPackets; ++i) { - Driver::Packet* dataPacket = message->getPacket(i); - assert(dataPacket != nullptr); - Protocol::Packet::DataHeader* header = - static_cast(dataPacket->payload); - header->policyVersion = policy.version; - header->unscheduledIndexLimit = - Util::downCast(unscheduledIndexLimit); - } - - // Reset the timeouts - bucket->messageTimeouts.setTimeout(&message->messageTimeout); - bucket->pingTimeouts.setTimeout(&message->pingTimeout); - - assert(message->numPackets > 0); - if (message->numPackets == 1) { - // If there is only one packet in the message, send it right away. - Driver::Packet* dataPacket = message->getPacket(0); - assert(dataPacket != nullptr); - dataPacket->priority = policy.priority; - Perf::counters.tx_data_pkts.add(1); - Perf::counters.tx_bytes.add(dataPacket->length); - driver->sendPacket(dataPacket); - message->state.store(OutMessage::Status::SENT); - } else { - // Otherwise, queue the message to be sent in SRPT order. - SpinLock::Lock lock_queue(queueMutex); - QueuedMessageInfo* info = &message->queuedMessageInfo; - // Some of these values should still be set from when the message - // was first queued. - assert(info->id == message->id); - assert(info->destination == message->destination); - assert(info->packets == message); - // Some values need to be updated - info->unsentBytes = message->messageLength; - info->packetsGranted = - std::min(unscheduledIndexLimit, message->numPackets); - info->priority = policy.priority; - info->packetsSent = 0; - // Insert and move message into the correct order in the priority - // queue. - sendQueue.push_front(&info->sendQueueNode); - Intrusive::deprioritize( - &sendQueue, &info->sendQueueNode, - QueuedMessageInfo::ComparePriority()); - sendReady.store(true); - } + startMessage(message, true, lock); } - - driver->releasePackets(&packet, 1); } /** @@ -439,12 +326,9 @@ Sender::handleUnknownPacket(Driver::Packet* packet, Driver* driver) * * @param packet * Incoming ERROR packet to be processed. - * @param driver - * Driver from which the packet was received and to which it should be - * returned after the packet has been processed. */ void -Sender::handleErrorPacket(Driver::Packet* packet, Driver* driver) +Sender::handleErrorPacket(Driver::Packet* packet) { Protocol::Packet::ErrorHeader* header = static_cast(packet->payload); @@ -454,8 +338,7 @@ Sender::handleErrorPacket(Driver::Packet* packet, Driver* driver) SpinLock::Lock lock(bucket->mutex); Message* message = bucket->findMessage(msgId, lock); if (message == nullptr) { - // No message for this ERROR packet; must be old. Just drop it. - driver->releasePackets(&packet, 1); + // No message for this ERROR packet; must be old. return; } @@ -463,9 +346,7 @@ Sender::handleErrorPacket(Driver::Packet* packet, Driver* driver) switch (status) { case OutMessage::Status::SENT: // Message was sent and a failure notification was received. - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); - message->state.store(OutMessage::Status::FAILED); + message->setStatus(OutMessage::Status::FAILED, false, lock); break; case OutMessage::Status::CANCELED: // Canceled by the the application; just ignore the ERROR. @@ -502,8 +383,6 @@ Sender::handleErrorPacket(Driver::Packet* packet, Driver* driver) msgId.transportId, msgId.sequence); break; } - - driver->releasePackets(&packet, 1); } /** @@ -515,34 +394,29 @@ void Sender::poll() { trySend(); + checkTimeouts(); } /** - * Process any Sender timeouts that have expired. - * - * This method must be called periodically to ensure timely handling of - * expired timeouts. + * Make incremental progress processing expired Sender timeouts. * - * @return - * The rdtsc cycle time when this method should be called again. + * Pulled out of poll() for ease of testing. */ -uint64_t +void Sender::checkTimeouts() { - uint64_t nextTimeout; - - // Ping Timeout - nextTimeout = checkPingTimeouts(); - - // Message Timeout - uint64_t messageTimeout = checkMessageTimeouts(); - nextTimeout = nextTimeout < messageTimeout ? nextTimeout : messageTimeout; - - return nextTimeout; + uint index = nextBucketIndex.fetch_add(1, std::memory_order_relaxed) & + MessageBucketMap::HASH_KEY_MASK; + MessageBucket* bucket = &messageBuckets.buckets[index]; + uint64_t now = PerfUtils::Cycles::rdtsc(); + checkPingTimeouts(now, bucket); } /** - * Destruct a Message. Will release all contained Packet objects. + * Destruct a Message. + * + * This method will detach the message from the Sender and release all contained + * Packet objects. */ Sender::Message::~Message() { @@ -572,9 +446,9 @@ Sender::Message::append(const void* source, size_t count) int bytesToCopy = std::min(_count - bytesCopied, PACKET_DATA_LENGTH - packetOffset); Driver::Packet* packet = getOrAllocPacket(packetIndex); - char* destination = static_cast(packet->payload); - destination += packetOffset + TRANSPORT_HEADER_LENGTH; - std::memcpy(destination, static_cast(source) + bytesCopied, + char* copyDst = static_cast(packet->payload); + copyDst += packetOffset + TRANSPORT_HEADER_LENGTH; + std::memcpy(copyDst, static_cast(source) + bytesCopied, bytesToCopy); // TODO(cstlee): A Message probably shouldn't be in charge of setting // the packet length. @@ -594,7 +468,8 @@ Sender::Message::append(const void* source, size_t count) void Sender::Message::cancel() { - sender->cancelMessage(this); + SpinLock::Lock lock(bucket->mutex); + setStatus(OutMessage::Status::CANCELED, true, lock); } /** @@ -603,7 +478,54 @@ Sender::Message::cancel() OutMessage::Status Sender::Message::getStatus() const { - return state.load(); + return state.load(std::memory_order_acquire); +} + +/** + * Change the status of this message. + * + * All status change must be done by this method. + * + * @param newStatus + * The new status. + * @param deschedule + * True if we should remove this message from the send queue. + * @param lock + * Reminder to hold the MessageBucket::mutex during this call. + * + * Note: the caller should not hold Sender::queueMutex when calling this method + * because our spinlock is not reentrant. + */ +void +Sender::Message::setStatus(Status newStatus, bool deschedule, + [[maybe_unused]] SpinLock::Lock& lock) +{ + // Whether to remove the message from the send queue depends on more than + // just the message status; only the caller has enough information to make + // the decision. + if (deschedule) { + // An outgoing message is on the sendQueue iff. it's still in progress + // and subject to the sender's packet pacing mechanism; test this + // condition first to reduce the expensive locking operation + if (throttled && (getStatus() == OutMessage::Status::IN_PROGRESS)) { + SpinLock::Lock lock_queue(sender->queueMutex); + sender->sendQueue.remove(&queuedMessageInfo.sendQueueNode); + } + } + + state.store(newStatus, std::memory_order_release); + bool endState = (newStatus == OutMessage::Status::CANCELED || + newStatus == OutMessage::Status::COMPLETED || + newStatus == OutMessage::Status::FAILED); + if (endState) { + if (!held.load(std::memory_order_acquire)) { + // Ok to delete now that the message has reached an end state. + sender->dropMessage(this, lock); + } else { + // Cancel the timeouts; the message will be dropped upon release(). + bucket->pingTimeouts.cancelTimeout(&pingTimeout); + } + } } /** @@ -635,9 +557,9 @@ Sender::Message::prepend(const void* source, size_t count) std::min(_count - bytesCopied, PACKET_DATA_LENGTH - packetOffset); Driver::Packet* packet = getPacket(packetIndex); assert(packet != nullptr); - char* destination = static_cast(packet->payload); - destination += packetOffset + TRANSPORT_HEADER_LENGTH; - std::memcpy(destination, static_cast(source) + bytesCopied, + char* copyDst = static_cast(packet->payload); + copyDst += packetOffset + TRANSPORT_HEADER_LENGTH; + std::memcpy(copyDst, static_cast(source) + bytesCopied, bytesToCopy); bytesCopied += bytesToCopy; packetIndex++; @@ -651,7 +573,16 @@ Sender::Message::prepend(const void* source, size_t count) void Sender::Message::release() { - sender->dropMessage(this); + held.store(false, std::memory_order_release); + if (getStatus() != OutMessage::Status::IN_PROGRESS) { + // Ok to delete immediately since we don't have to wait for the message + // to be sent. + SpinLock::Lock lock(bucket->mutex); + sender->dropMessage(this, lock); + } else { + // Defer deletion and wait for the message to be SENT at least. + } + Perf::counters.released_tx_messages.add(1); } /** @@ -697,10 +628,19 @@ Sender::Message::reserve(size_t count) * @copydoc Homa::OutMessage::send() */ void -Sender::Message::send(Driver::Address destination, +Sender::Message::send(SocketAddress destination, Sender::Message::Options options) { - sender->sendMessage(this, destination, options); + // Prepare the message + this->destination = destination; + this->options = options; + + // All single-packet messages currently bypass our packet pacing mechanism. + throttled = (numPackets > 1); + + // Kick start the transmission. + SpinLock::Lock lock(bucket->mutex); + sender->startMessage(this, false, lock); } /** @@ -746,259 +686,212 @@ Sender::Message::getOrAllocPacket(size_t index) } /** - * Queue a message to be sent. + * Drop a message that is no longer needed (released by the application). * * @param message - * Sender::Message to be sent. - * @param destination - * Destination address for this message. - * @param options - * Flags indicating requested non-default send behavior. + * Message which will be detached from the transport and destroyed. + * @param lock + * Reminder to hold the MessageBucket::mutex during this call. + */ +void +Sender::dropMessage(Sender::Message* message, + [[maybe_unused]] SpinLock::Lock& lock) +{ + // We assume that this message has been unlinked from the sendQueue before + // this method is invoked. + assert(message->getStatus() != OutMessage::Status::IN_PROGRESS); + + // Remove this message from the other data structures of the Sender. + MessageBucket* bucket = message->bucket; + bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); + bucket->messages.remove(&message->bucketNode); + + Perf::counters.destroyed_tx_messages.add(1); + messageAllocator.destroy(message); +} + +/** + * (Re)start the transmission of an outgoing message. * - * @sa dropMessage() + * @param message + * Sender::Message to be sent. + * @param restart + * False if the message is new to the transport; true means the message is + * restarted by the transport. + * @param lock + * Reminder to hold the MessageBucket::mutex during this call. */ void -Sender::sendMessage(Sender::Message* message, Driver::Address destination, - Sender::Message::Options options) +Sender::startMessage(Sender::Message* message, bool restart, + [[maybe_unused]] SpinLock::Lock& lock) { - // Prepare the message - assert(message->driver == driver); - // Allocate a new message id - Protocol::MessageId id(transportId, nextMessageSequenceNumber++); + // If we are restarting an existing message, make sure it's not in the + // sendQueue before making any changes to it. + message->setStatus(OutMessage::Status::IN_PROGRESS, restart, lock); + // Get the current policy for unscheduled bytes. Policy::Unscheduled policy = policyManager->getUnscheduledPolicy( - destination, message->messageLength); - int unscheduledPacketLimit = - ((policy.unscheduledByteLimit + message->PACKET_DATA_LENGTH - 1) / - message->PACKET_DATA_LENGTH); - - message->id = id; - message->destination = destination; - message->options = options; - message->state.store(OutMessage::Status::IN_PROGRESS); - - int actualMessageLen = 0; - // fill out metadata. - for (int i = 0; i < message->numPackets; ++i) { - Driver::Packet* packet = message->getPacket(i); - if (packet == nullptr) { - PANIC( - "Incomplete message with id (%lu:%lu); missing packet " - "at offset %d; this shouldn't happen.", - message->id.transportId, message->id.sequence, - i * message->PACKET_DATA_LENGTH); + message->destination.ip, message->messageLength); + uint16_t unscheduledIndexLimit = Util::roundUpIntDiv( + policy.unscheduledByteLimit, message->PACKET_DATA_LENGTH); + + if (!restart) { + // Fill out packet headers. + int actualMessageLen = 0; + for (int i = 0; i < message->numPackets; ++i) { + Driver::Packet* packet = message->getPacket(i); + assert(packet != nullptr); + new (packet->payload) Protocol::Packet::DataHeader( + message->source.port, message->destination.port, message->id, + Util::downCast(message->messageLength), + policy.version, unscheduledIndexLimit, + Util::downCast(i)); + actualMessageLen += + (packet->length - message->TRANSPORT_HEADER_LENGTH); } + assert(message->messageLength == actualMessageLen); - packet->address = message->destination; - new (packet->payload) Protocol::Packet::DataHeader( - message->id, Util::downCast(message->messageLength), - policy.version, Util::downCast(unscheduledPacketLimit), - Util::downCast(i)); - actualMessageLen += (packet->length - message->TRANSPORT_HEADER_LENGTH); + // Start tracking the new message + MessageBucket* bucket = message->bucket; + assert(!bucket->messages.contains(&message->bucketNode)); + bucket->messages.push_back(&message->bucketNode); + } else { + // Update the policy version for each packet + for (int i = 0; i < message->numPackets; ++i) { + Driver::Packet* packet = message->getPacket(i); + assert(packet != nullptr); + Protocol::Packet::DataHeader* header = + static_cast(packet->payload); + header->policyVersion = policy.version; + header->unscheduledIndexLimit = unscheduledIndexLimit; + } } - // perform sanity checks. - assert(message->driver == driver); - assert(message->messageLength == actualMessageLen); - assert(message->TRANSPORT_HEADER_LENGTH == - sizeof(Protocol::Packet::DataHeader)); - - // Track message - MessageBucket* bucket = messageBuckets.getBucket(message->id); - SpinLock::Lock lock(bucket->mutex); - assert(!bucket->messages.contains(&message->bucketNode)); - bucket->messages.push_back(&message->bucketNode); - bucket->messageTimeouts.setTimeout(&message->messageTimeout); - bucket->pingTimeouts.setTimeout(&message->pingTimeout); - + // Kick start the message. assert(message->numPackets > 0); - if (message->numPackets == 1) { - // If there is only one packet in the message, send it right away. - Driver::Packet* packet = message->getPacket(0); - assert(packet != nullptr); - packet->priority = policy.priority; + bool needTimeouts = true; + if (!message->throttled) { + // The message is too small to be scheduled, send it right away. + Driver::Packet* dataPacket = message->getPacket(0); + assert(dataPacket != nullptr); Perf::counters.tx_data_pkts.add(1); - Perf::counters.tx_bytes.add(packet->length); - driver->sendPacket(packet); - message->state.store(OutMessage::Status::SENT); + Perf::counters.tx_bytes.add(dataPacket->length); + driver->sendPacket(dataPacket, message->destination.ip, + policy.priority); + message->setStatus(OutMessage::Status::SENT, false, lock); + // This message must be still be held by the application since the + // message still exists (it would have been removed when dropped + // because single packet messages are never IN_PROGRESS). Assuming + // the message is still held, we can skip the auto removal of SENT + // and !held messages. + assert(message->held); + if (message->options & OutMessage::Options::NO_KEEP_ALIVE) { + // No timeouts need to be checked after sending the message when + // the NO_KEEP_ALIVE option is enabled. + needTimeouts = false; + } } else { // Otherwise, queue the message to be sent in SRPT order. SpinLock::Lock lock_queue(queueMutex); QueuedMessageInfo* info = &message->queuedMessageInfo; - info->id = id; - info->destination = message->destination; - info->packets = message; + // Some values need to be updated info->unsentBytes = message->messageLength; info->packetsGranted = - std::min(unscheduledPacketLimit, message->numPackets); + std::min(unscheduledIndexLimit, + Util::downCast(message->numPackets)); info->priority = policy.priority; info->packetsSent = 0; // Insert and move message into the correct order in the priority queue. sendQueue.push_front(&info->sendQueueNode); - Intrusive::deprioritize(&sendQueue, &info->sendQueueNode, - QueuedMessageInfo::ComparePriority()); + Intrusive::deprioritize(&sendQueue, &info->sendQueueNode); sendReady.store(true); } -} -/** - * Inform the Sender that a Message no longer needs to be sent. - * - * @param message - * The Sender::Message that is no longer needs to be sent. - */ -void -Sender::cancelMessage(Sender::Message* message) -{ - Protocol::MessageId msgId = message->id; - MessageBucket* bucket = messageBuckets.getBucket(msgId); - SpinLock::Lock lock(bucket->mutex); - if (bucket->messages.contains(&message->bucketNode)) { - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); - if (message->numPackets > 1 && - message->state == OutMessage::Status::IN_PROGRESS) { - // Check to see if the message needs to be dequeued. - SpinLock::Lock lock_queue(queueMutex); - // Recheck state with lock in case it change right before this. - if (message->state == OutMessage::Status::IN_PROGRESS) { - QueuedMessageInfo* info = &message->queuedMessageInfo; - assert(sendQueue.contains(&info->sendQueueNode)); - sendQueue.remove(&info->sendQueueNode); - } - } - bucket->messages.remove(&message->bucketNode); - message->state.store(OutMessage::Status::CANCELED); + // Initialize the timeouts + if (needTimeouts) { + message->numPingTimeouts = 0; + message->bucket->pingTimeouts.setTimeout(&message->pingTimeout); } } /** - * Inform the Sender that a Message is no longer needed. - * - * @param message - * The Sender::Message that is no longer needed. - */ -void -Sender::dropMessage(Sender::Message* message) -{ - cancelMessage(message); - SpinLock::Lock lock_allocator(messageAllocator.mutex); - messageAllocator.pool.destroy(message); -} - -/** - * Process any outbound messages that have timed out due to lack of activity - * from the Receiver. + * Process any outbound messages in a given bucket that need to be pinged to + * ensure the message is kept alive by the receiver. * * Pulled out of checkTimeouts() for ease of testing. * - * @return - * The rdtsc cycle time when this method should be called again. + * @param now + * The rdtsc cycle that should be considered the "current" time. + * @param bucket + * The bucket whose ping timeouts should be checked. */ -uint64_t -Sender::checkMessageTimeouts() +void +Sender::checkPingTimeouts(uint64_t now, MessageBucket* bucket) { - uint64_t globalNextTimeout = UINT64_MAX; - assert(MessageBucketMap::NUM_BUCKETS > 0); - for (int i = 0; i < MessageBucketMap::NUM_BUCKETS; ++i) { - MessageBucket* bucket = messageBuckets.buckets.at(i); - uint64_t nextTimeout = 0; - while (true) { - SpinLock::Lock lock(bucket->mutex); - // No remaining timeouts. - if (bucket->messageTimeouts.list.empty()) { - nextTimeout = PerfUtils::Cycles::rdtsc() + - bucket->messageTimeouts.timeoutIntervalCycles; - break; - } - Message* message = &bucket->messageTimeouts.list.front(); - // No remaining expired timeouts. - if (!message->messageTimeout.hasElapsed()) { - nextTimeout = message->messageTimeout.expirationCycleTime; - break; - } - // Found expired timeout. - if (message->state != OutMessage::Status::COMPLETED) { - message->state.store(OutMessage::Status::FAILED); - // A sent NO_KEEP_ALIVE message should never reach this state - // since the shorter ping timeout should have already canceled - // the message timeout. - assert( - !((message->state == OutMessage::Status::SENT) && - (message->options & OutMessage::Options::NO_KEEP_ALIVE))); - } - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); - } - globalNextTimeout = std::min(globalNextTimeout, nextTimeout); + if (!bucket->pingTimeouts.anyElapsed(now)) { + return; } - return globalNextTimeout; -} -/** - * Process any outbound messages that need to be pinged to ensure the - * message is kept alive by the receiver. - * - * Pulled out of checkTimeouts() for ease of testing. - * - * @return - * The rdtsc cycle time when this method should be called again. - */ -uint64_t -Sender::checkPingTimeouts() -{ - uint64_t globalNextTimeout = UINT64_MAX; - assert(MessageBucketMap::NUM_BUCKETS > 0); - for (int i = 0; i < MessageBucketMap::NUM_BUCKETS; ++i) { - MessageBucket* bucket = messageBuckets.buckets.at(i); - uint64_t nextTimeout = 0; - while (true) { - SpinLock::Lock lock(bucket->mutex); - // No remaining timeouts. - if (bucket->pingTimeouts.list.empty()) { - nextTimeout = PerfUtils::Cycles::rdtsc() + - bucket->pingTimeouts.timeoutIntervalCycles; - break; - } - Message* message = &bucket->pingTimeouts.list.front(); - // No remaining expired timeouts. - if (!message->pingTimeout.hasElapsed()) { - nextTimeout = message->pingTimeout.expirationCycleTime; - break; - } - // Found expired timeout. - if (message->state == OutMessage::Status::COMPLETED || - message->state == OutMessage::Status::FAILED) { - bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); - continue; - } else if (message->options & OutMessage::Options::NO_KEEP_ALIVE && - message->state == OutMessage::Status::SENT) { - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); + while (true) { + SpinLock::Lock lock_bucket(bucket->mutex); + // No remaining timeouts. + if (bucket->pingTimeouts.empty()) { + break; + } + Message* message = &bucket->pingTimeouts.front(); + // No remaining expired timeouts. + if (!message->pingTimeout.hasElapsed(now)) { + break; + } + // Found expired timeout. + OutMessage::Status status = message->getStatus(); + assert(status == OutMessage::Status::IN_PROGRESS || + status == OutMessage::Status::SENT); + if (message->options & OutMessage::Options::NO_KEEP_ALIVE && + status == OutMessage::Status::SENT) { + bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); + continue; + } else { + message->numPingTimeouts++; + if (message->numPingTimeouts >= MESSAGE_TIMEOUT_INTERVALS) { + // Found expired message. + message->setStatus(OutMessage::Status::FAILED, true, + lock_bucket); continue; } else { bucket->pingTimeouts.setTimeout(&message->pingTimeout); } + } - // Have not heard from the Receiver in the last timeout period. Ping - // the receiver to ensure it still knows about this Message. - Perf::counters.tx_ping_pkts.add(1); - ControlPacket::send( - message->driver, message->destination, message->id); + // Check if sender still has packets to send + if (status == OutMessage::Status::IN_PROGRESS) { + QueuedMessageInfo* info = &message->queuedMessageInfo; + SpinLock::Lock lock_queue(queueMutex); + if (info->packetsSent < info->packetsGranted) { + // Sender is blocked on itself, no need to send ping + continue; + } } - globalNextTimeout = std::min(globalNextTimeout, nextTimeout); + + // Have not heard from the Receiver in the last timeout period. Ping + // the receiver to ensure it still knows about this Message. + Perf::counters.tx_ping_pkts.add(1); + ControlPacket::send( + message->driver, message->destination.ip, message->id); } - return globalNextTimeout; } /** * Send out packets for any messages with unscheduled/granted bytes. + * + * Pulled out of poll() for ease of testing. */ void Sender::trySend() { - uint64_t start_tsc = PerfUtils::Cycles::rdtsc(); + Perf::Timer timer; bool idle = true; + // Skip when there are no messages to send. if (!sendReady) { return; @@ -1022,15 +915,18 @@ Sender::trySend() sendReady = false; auto it = sendQueue.begin(); while (it != sendQueue.end()) { + // No need to acquire the bucket mutex here because we are only going to + // access the const members of a Message; besides, the message can't get + // destroyed concurrently because whoever needs to destroy the message + // will need to acquire queueMutex first. Message& message = *it; - assert(message.state.load() == OutMessage::Status::IN_PROGRESS); + assert(message.getStatus() == OutMessage::Status::IN_PROGRESS); QueuedMessageInfo* info = &message.queuedMessageInfo; - assert(info->packetsGranted <= info->packets->numPackets); + assert(info->packetsGranted <= message.numPackets); while (info->packetsSent < info->packetsGranted) { // There are packets to send idle = false; - Driver::Packet* packet = - info->packets->getPacket(info->packetsSent); + Driver::Packet* packet = message.getPacket(info->packetsSent); assert(packet != nullptr); queuedBytesEstimate += packet->length; // Check if the send limit would be reached... @@ -1038,25 +934,22 @@ Sender::trySend() break; } // ... if not, send away! - packet->priority = info->priority; Perf::counters.tx_data_pkts.add(1); Perf::counters.tx_bytes.add(packet->length); - driver->sendPacket(packet); + driver->sendPacket(packet, message.destination.ip, info->priority); int packetDataBytes = - packet->length - info->packets->TRANSPORT_HEADER_LENGTH; + packet->length - message.TRANSPORT_HEADER_LENGTH; assert(info->unsentBytes >= packetDataBytes); info->unsentBytes -= packetDataBytes; // The Message's unsentBytes only ever decreases. See if the // updated Message should move up in the queue. - Intrusive::prioritize( - &sendQueue, &info->sendQueueNode, - QueuedMessageInfo::ComparePriority()); + Intrusive::prioritize(&sendQueue, &info->sendQueueNode); ++info->packetsSent; } - if (info->packetsSent >= info->packets->numPackets) { + if (info->packetsSent >= message.numPackets) { // We have finished sending the message. - message.state.store(OutMessage::Status::SENT); it = sendQueue.remove(it); + sentMessages.emplace_back(message.bucket, message.id); } else if (info->packetsSent >= info->packetsGranted) { // We have sent every granted packet. ++it; @@ -1067,13 +960,36 @@ Sender::trySend() break; } } - sending.clear(); - uint64_t elapsed_cycles = PerfUtils::Cycles::rdtsc() - start_tsc; + + // Unlock the queueMutex to process any SENT messages to ensure any bucket + // mutex is always acquired before the send queueMutex. + lock_queue.unlock(); + for (auto& [bucket, id] : sentMessages) { + SpinLock::Lock lock(bucket->mutex); + Message* message = bucket->findMessage(id, lock); + if (message == nullptr) { + // Message must have already been deleted. + continue; + } + + // The message status may change after the queueMutex is unlocked; + // ignore this message if that is the case. + if (message->getStatus() == OutMessage::Status::IN_PROGRESS) { + message->setStatus(OutMessage::Status::SENT, false, lock); + if (!message->held.load(std::memory_order_acquire)) { + // Ok to delete now that the message has been sent. + dropMessage(message, lock); + } else if (message->options & OutMessage::Options::NO_KEEP_ALIVE) { + // No timeouts need to be checked after sending the message when + // the NO_KEEP_ALIVE option is enabled. + bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); + } + } + } + if (!idle) { - Perf::counters.active_cycles.add(elapsed_cycles); - } else { - Perf::counters.idle_cycles.add(elapsed_cycles); + Perf::counters.active_cycles.add(timer.split()); } } diff --git a/src/Sender.h b/src/Sender.h index 471925a..b039d51 100644 --- a/src/Sender.h +++ b/src/Sender.h @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -44,37 +45,27 @@ class Sender { explicit Sender(uint64_t transportId, Driver* driver, Policy::Manager* policyManager, uint64_t messageTimeoutCycles, uint64_t pingIntervalCycles); - virtual ~Sender(); - - virtual Homa::OutMessage* allocMessage(); - virtual void handleDonePacket(Driver::Packet* packet, Driver* driver); - virtual void handleResendPacket(Driver::Packet* packet, Driver* driver); - virtual void handleGrantPacket(Driver::Packet* packet, Driver* driver); - virtual void handleUnknownPacket(Driver::Packet* packet, Driver* driver); - virtual void handleErrorPacket(Driver::Packet* packet, Driver* driver); + virtual ~Sender() = default; + + virtual Homa::OutMessage* allocMessage(uint16_t sourcePort); + virtual void handleDonePacket(Driver::Packet* packet); + virtual void handleResendPacket(Driver::Packet* packet); + virtual void handleGrantPacket(Driver::Packet* packet); + virtual void handleUnknownPacket(Driver::Packet* packet); + virtual void handleErrorPacket(Driver::Packet* packet); + virtual void poll(); - virtual uint64_t checkTimeouts(); + virtual void checkTimeouts(); private: /// Forward declarations class Message; + struct MessageBucket; /** * Contains metadata for a Message that has been queued to be sent. */ struct QueuedMessageInfo { - /** - * Implements a binary comparison function for the strict weak priority - * ordering of two Message objects. - */ - struct ComparePriority { - bool operator()(const Message& a, const Message& b) - { - return a.queuedMessageInfo.unsentBytes < - b.queuedMessageInfo.unsentBytes; - } - }; - /** * QueuedMessageInfo constructor. * @@ -82,91 +73,94 @@ class Sender { * Message to which this metadata is associated. */ explicit QueuedMessageInfo(Message* message) - : id(0, 0) - , destination() - , packets(nullptr) - , unsentBytes(0) + : unsentBytes(0) , packetsGranted(0) - , priority(0) , packetsSent(0) + , priority(0) , sendQueueNode(message) {} - /// Contains the unique identifier for this message. - Protocol::MessageId id; - - /// Contains destination address this message. - Driver::Address destination; - - /// Handle to the queue Message for access to the packets that will - /// be sent. This member documents that the packets are logically owned - /// by the sendQueue and thus protected by the queueMutex. - Message* packets; - /// The number of bytes that still need to be sent for a queued Message. + /// This variable is used to rank messages in SRPT order so it must be + /// protected by Sender::queueMutex. int unsentBytes; /// The number of packets that can be sent for this Message. int packetsGranted; - /// The network priority at which this Message should be sent. - int priority; - /// The number of packets that have been sent for this Message. int packetsSent; + /// The network priority at which this Message should be sent. + int priority; + /// Intrusive structure used to enqueue the associated Message into - /// the sendQueue. + /// the sendQueue. Protected by Sender::queueMutex. Intrusive::List::Node sendQueueNode; }; /** * Represents an outgoing message that can be sent. - * - * Sender::Message objects are contained in the Transport::Op but should - * only be accessed by the Sender. */ class Message : public Homa::OutMessage { public: + /** + * Implements a binary comparison function for the strict weak priority + * ordering of two Message objects. + */ + struct ComparePriority { + bool operator()(const Message& a, const Message& b) + { + return a.queuedMessageInfo.unsentBytes < + b.queuedMessageInfo.unsentBytes; + } + }; + /** * Construct an Message. */ - explicit Message(Sender* sender, Driver* driver) + explicit Message(Sender* sender, uint64_t messageId, + uint16_t sourcePort) : sender(sender) - , driver(driver) + , driver(sender->driver) , TRANSPORT_HEADER_LENGTH(sizeof(Protocol::Packet::DataHeader)) , PACKET_DATA_LENGTH(driver->getMaxPayloadSize() - TRANSPORT_HEADER_LENGTH) - , id(0, 0) + , id(sender->transportId, messageId) + , bucket(sender->messageBuckets.getBucket(id)) + , source{driver->getLocalAddress(), sourcePort} , destination() , options(Options::NONE) + , held(true) , start(0) , messageLength(0) , numPackets(0) + , throttled() , occupied() // packets is not initialized to reduce the work done during // construction. See Message::occupied. , state(Status::NOT_STARTED) , bucketNode(this) - , messageTimeout(this) + , numPingTimeouts(0) , pingTimeout(this) , queuedMessageInfo(this) {} virtual ~Message(); - virtual void append(const void* source, size_t count); - virtual void cancel(); - virtual Status getStatus() const; - virtual size_t length() const; - virtual void prepend(const void* source, size_t count); - virtual void release(); - virtual void reserve(size_t count); - virtual void send(Driver::Address destination, - Options options = Options::NONE); + void append(const void* source, size_t count) override; + void cancel() override; + Status getStatus() const override; + void setStatus(Status newStatus, bool deschedule, SpinLock::Lock& lock); + size_t length() const override; + void prepend(const void* source, size_t count) override; + void release() override; + void reserve(size_t count) override; + void send(SocketAddress destination, + Options options = Options::NONE) override; private: /// Define the maximum number of packets that a message can hold. - static const size_t MAX_MESSAGE_PACKETS = 1024; + static const int MAX_MESSAGE_PACKETS = 1024; Driver::Packet* getPacket(size_t index) const; Driver::Packet* getOrAllocPacket(size_t index); @@ -186,29 +180,53 @@ class Sender { const int PACKET_DATA_LENGTH; /// Contains the unique identifier for this message. - Protocol::MessageId id; + const Protocol::MessageId id; + + /// Message bucket this message belongs to. + MessageBucket* const bucket; - /// Contains destination address this message. - Driver::Address destination; + /// Contains source address of this message. + const SocketAddress source; - /// Contains flags for any requested optional send behavior. + /// Contains destination address of this message. Must be constant after + /// send() is invoked. + SocketAddress destination; + + /// Contains flags for any requested optional send behavior. Must be + /// constant after send() is invoked. Options options; - /// First byte where data is or will go if empty. + /// True if a pointer to this message is accessible by the application + /// (e.g. the message has been allocated via allocMessage() but has not + /// been release via dropMessage()); false, otherwise. + std::atomic held; + + /// First byte where data is or will go if empty. Must be constant after + /// send() is invoked. int start; /// Number of bytes in this Message including any reserved headroom. + /// Must be constant after send() is invoked. int messageLength; - /// Number of packets currently contained in this message. + /// Number of packets currently contained in this message. Must be + /// constant after send() is invoked. int numPackets; - /// Bit array representing which entires in the _packets_ array are set. - /// Used to avoid having to zero out the entire _packets_ array. + /// False if this message is so small that we are better off sending it + /// right away to reduce software overhead; otherwise, the message must + /// be put into Sender::sendQueue and transmitted in SRPT order. Must be + /// constant after send() is invoked. + bool throttled; + + /// Bit array representing which entries in the _packets_ array are set. + /// Used to avoid having to zero out the entire _packets_ array. Must be + /// constant after send() is invoked. std::bitset occupied; /// Collection of Packet objects that make up this context's Message. - /// These Packets will be released when this context is destroyed. + /// These Packets will be released when this context is destroyed. Must + /// be constant after send() is invoked. Driver::Packet* packets[MAX_MESSAGE_PACKETS]; /// This message's current state. @@ -219,10 +237,9 @@ class Sender { /// is protected by the associated MessageBucket::mutex; Intrusive::List::Node bucketNode; - /// Intrusive structure used by the Sender to keep track when the - /// sending of this message should be considered failed. Access to this + /// Number of ping timeouts that occurred in a row. Access to this /// structure is protected by the associated MessageBucket::mutex. - Timeout messageTimeout; + int numPingTimeouts; /// Intrusive structure used by the Sender to keep track when this /// message should be checked to ensure progress is still being made. @@ -248,21 +265,31 @@ class Sender { /** * MessageBucket constructor. * - * @param messageTimeoutCycles - * Number of cycles of inactivity to wait before a Message is - * considered failed. + * @param Sender + * Sender that owns this bucket. * @param pingIntervalCycles * Number of cycles of inactivity to wait between checking on the * liveness of a Message. */ - explicit MessageBucket(uint64_t messageTimeoutCycles, - uint64_t pingIntervalCycles) - : mutex() + explicit MessageBucket(Sender* sender, uint64_t pingIntervalCycles) + : sender(sender) + , mutex() , messages() - , messageTimeouts(messageTimeoutCycles) , pingTimeouts(pingIntervalCycles) {} + /** + * Destruct a MessageBucket. Will destroy all contained Message objects. + */ + ~MessageBucket() + { + // Intrusive::List is not responsible for destructing its elements; + // it must be done manually. + for (auto& message : messages) { + sender->messageAllocator.destroy(&message); + } + } + /** * Return the Message with the given MessageId. * @@ -275,29 +302,27 @@ class Sender { * A pointer to the Message if found; nullptr, otherwise. */ Message* findMessage(const Protocol::MessageId& msgId, - const SpinLock::Lock& lock) + [[maybe_unused]] const SpinLock::Lock& lock) { - (void)lock; - Message* message = nullptr; - for (auto it = messages.begin(); it != messages.end(); ++it) { - if (it->id == msgId) { - message = &(*it); - break; + for (auto& it : messages) { + if (it.id == msgId) { + return ⁢ } } - return message; + return nullptr; } - /// Mutex protecting the contents of this bucket. + /// Sender that owns this object. + Sender* const sender; + + /// Mutex protecting the contents of this bucket. See Sender::queueMutex + /// for locking order constraints. SpinLock mutex; /// Collection of outbound messages Intrusive::List messages; - /// Maintains Message objects in increasing order of timeout. - TimeoutManager messageTimeouts; - - /// Maintains Message object in increase order of ping timeout. + /// Maintains Message objects in increasing order of ping timeout. TimeoutManager pingTimeouts; }; @@ -312,6 +337,9 @@ class Sender { */ static const int NUM_BUCKETS = 256; + // Make sure the number of buckets is a power of 2. + static_assert(Util::isPowerOfTwo(NUM_BUCKETS)); + /** * Bit mask used to map from a hashed key to the bucket index. */ @@ -321,75 +349,51 @@ class Sender { static_assert(NUM_BUCKETS == HASH_KEY_MASK + 1); /** - * Helper method to create the set of buckets. + * MessageBucketMap constructor. * - * @param messageTimeoutCycles - * Number of cycles of inactivity to wait before a Message is - * considered failed. + * @param sender + * Sender that owns this bucket map. * @param pingIntervalCycles * Number of cycles of inactivity to wait between checking on the * liveness of a Message. */ - static std::array makeBuckets( - uint64_t messageTimeoutCycles, uint64_t pingIntervalCycles) + explicit MessageBucketMap(Sender* sender, uint64_t pingIntervalCycles) + : buckets() + , hasher() { - std::array buckets; + buckets.reserve(NUM_BUCKETS); for (int i = 0; i < NUM_BUCKETS; ++i) { - buckets[i] = - new MessageBucket(messageTimeoutCycles, pingIntervalCycles); + buckets.emplace_back(sender, pingIntervalCycles); } - return buckets; } - /** - * MessageBucketMap constructor. - * - * @param messageTimeoutCycles - * Number of cycles of inactivity to wait before a Message is - * considered failed. - * @param pingIntervalCycles - * Number of cycles of inactivity to wait between checking on the - * liveness of a Message. - */ - explicit MessageBucketMap(uint64_t messageTimeoutCycles, - uint64_t pingIntervalCycles) - : buckets(makeBuckets(messageTimeoutCycles, pingIntervalCycles)) - , hasher() - {} - /** * MessageBucketMap destructor. */ - ~MessageBucketMap() - { - for (int i = 0; i < NUM_BUCKETS; ++i) { - delete buckets[i]; - } - } + ~MessageBucketMap() = default; /** * Return the MessageBucket that should hold a Message with the given * MessageId. */ - MessageBucket* getBucket(const Protocol::MessageId& msgId) const + MessageBucket* getBucket(const Protocol::MessageId& msgId) { uint index = hasher(msgId) & HASH_KEY_MASK; - return buckets[index]; + return &buckets[index]; } - /// Array of buckets. - std::array const buckets; + /// Array of NUM_BUCKETS buckets. Defined as a vector to avoid the need + /// for a default constructor in MessageBucket. + std::vector buckets; /// MessageId hash function container. Protocol::MessageId::Hasher hasher; }; - void sendMessage(Sender::Message* message, Driver::Address destination, - Message::Options options = Message::Options::NONE); - void cancelMessage(Sender::Message* message); - void dropMessage(Sender::Message* message); - uint64_t checkMessageTimeouts(); - uint64_t checkPingTimeouts(); + void dropMessage(Sender::Message* message, SpinLock::Lock& lock); + void startMessage(Sender::Message* message, bool restart, + SpinLock::Lock& lock); + void checkPingTimeouts(uint64_t now, MessageBucket* bucket); void trySend(); /// Transport identifier. @@ -408,10 +412,16 @@ class Sender { /// The maximum number of bytes that should be queued in the Driver. const uint32_t DRIVER_QUEUED_BYTE_LIMIT; + /// The number of ping timeouts to occur before declaring a message timeout. + const int MESSAGE_TIMEOUT_INTERVALS; + /// Tracks all outbound messages being sent by the Sender. MessageBucketMap messageBuckets; - /// Protects the readyQueue. + /// Protects the sendQueue, including all member variables of its items. + /// When multiple locks must be acquired, this class follows the locking + /// order constraint below ("<" means "acquired before"): + /// MessageBucket::mutex < queueMutex SpinLock queueMutex; /// A list of outbound messages that have unsent packets. Messages are kept @@ -427,13 +437,18 @@ class Sender { /// if there is work to do is more efficient. std::atomic sendReady; + /// Used to temporarily hold messages whose last packets have just been + /// sent out. Always empty unless inside trySend(). + std::vector> sentMessages; + + /// The index of the next bucket in the messageBuckets::buckets array to + /// process in the poll loop. The index is held in the lower order bits of + /// this variable; the higher order bits should be masked off using the + /// MessageBucketMap::HASH_KEY_MASK bit mask. + std::atomic nextBucketIndex; + /// Used to allocate Message objects. - struct { - /// Protects the messageAllocator.pool - SpinLock mutex; - /// Pool allocator for Message objects. - ObjectPool pool; - } messageAllocator; + ObjectPool messageAllocator; }; } // namespace Core diff --git a/src/SenderTest.cc b/src/SenderTest.cc index fdae6ab..07630a8 100644 --- a/src/SenderTest.cc +++ b/src/SenderTest.cc @@ -36,13 +36,13 @@ class SenderTest : public ::testing::Test { public: SenderTest() : mockDriver() - , mockPacket(&payload) + , mockPacket{&payload} , mockPolicyManager(&mockDriver) , sender() , savedLogPolicy(Debug::getLogPolicy()) { ON_CALL(mockDriver, getBandwidth).WillByDefault(Return(8000)); - ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(1027)); + ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(1031)); ON_CALL(mockDriver, getQueuedBytes).WillByDefault(Return(0)); Debug::setLogPolicy( Debug::logPolicyFromString("src/ObjectPool@SILENT")); @@ -59,7 +59,7 @@ class SenderTest : public ::testing::Test { } NiceMock mockDriver; - NiceMock mockPacket; + Homa::Mock::MockDriver::MockPacket mockPacket; NiceMock mockPolicyManager; char payload[1028]; Sender* sender; @@ -124,7 +124,7 @@ TEST_F(SenderTest, allocMessage) { EXPECT_EQ(0U, sender->messageAllocator.pool.outstandingObjects); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); EXPECT_EQ(1U, sender->messageAllocator.pool.outstandingObjects); } @@ -132,7 +132,7 @@ TEST_F(SenderTest, handleDonePacket_basic) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); EXPECT_NE(Homa::OutMessage::Status::COMPLETED, message->state); Protocol::Packet::DoneHeader* header = @@ -143,7 +143,7 @@ TEST_F(SenderTest, handleDonePacket_basic) .Times(2); // No message. - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); EXPECT_NE(Homa::OutMessage::Status::COMPLETED, message->state); @@ -151,7 +151,7 @@ TEST_F(SenderTest, handleDonePacket_basic) message->state = Homa::OutMessage::Status::SENT; // Normal expected behavior. - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); EXPECT_EQ(nullptr, message->messageTimeout.node.list); EXPECT_EQ(nullptr, message->pingTimeout.node.list); @@ -162,7 +162,7 @@ TEST_F(SenderTest, handleDonePacket_CANCELED) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); addMessage(sender, id, message); message->state = Homa::OutMessage::Status::CANCELED; @@ -173,14 +173,14 @@ TEST_F(SenderTest, handleDonePacket_CANCELED) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); } TEST_F(SenderTest, handleDonePacket_COMPLETED) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); addMessage(sender, id, message); message->state = Homa::OutMessage::Status::COMPLETED; @@ -194,7 +194,7 @@ TEST_F(SenderTest, handleDonePacket_COMPLETED) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -211,7 +211,7 @@ TEST_F(SenderTest, handleDonePacket_FAILED) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); addMessage(sender, id, message); message->state = Homa::OutMessage::Status::FAILED; @@ -225,7 +225,7 @@ TEST_F(SenderTest, handleDonePacket_FAILED) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -244,7 +244,7 @@ TEST_F(SenderTest, handleDonePacket_IN_PROGRESS) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); addMessage(sender, id, message); message->state = Homa::OutMessage::Status::IN_PROGRESS; @@ -258,7 +258,7 @@ TEST_F(SenderTest, handleDonePacket_IN_PROGRESS) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -277,7 +277,7 @@ TEST_F(SenderTest, handleDonePacket_NO_STARTED) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); addMessage(sender, id, message); message->state = Homa::OutMessage::Status::NOT_STARTED; @@ -291,7 +291,7 @@ TEST_F(SenderTest, handleDonePacket_NO_STARTED) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -310,10 +310,12 @@ TEST_F(SenderTest, handleResendPacket_basic) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); std::vector packets; + std::vector priorities; for (int i = 0; i < 10; ++i) { - packets.push_back(new Homa::Mock::MockDriver::MockPacket(payload)); + packets.push_back(new Homa::Mock::MockDriver::MockPacket{payload}); + priorities.push_back(0); setMessagePacket(message, i, packets[i]); } SenderTest::addMessage(sender, id, message, true, 5); @@ -331,22 +333,26 @@ TEST_F(SenderTest, handleResendPacket_basic) resendHdr->priority = 4; EXPECT_CALL(mockPolicyManager, getResendPriority).WillOnce(Return(7)); - EXPECT_CALL(mockDriver, sendPacket(Eq(packets[3]))).Times(1); - EXPECT_CALL(mockDriver, sendPacket(Eq(packets[4]))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(packets[3]), _, _)) + .WillOnce( + [&priorities](auto _1, auto _2, int p) { priorities[3] = p; }); + EXPECT_CALL(mockDriver, sendPacket(Eq(packets[4]), _, _)) + .WillOnce( + [&priorities](auto _1, auto _2, int p) { priorities[4] = p; }); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleResendPacket(&mockPacket, &mockDriver); + sender->handleResendPacket(&mockPacket); EXPECT_EQ(5U, info->packetsSent); EXPECT_EQ(8U, info->packetsGranted); EXPECT_EQ(4, info->priority); EXPECT_EQ(11000U, message->messageTimeout.expirationCycleTime); EXPECT_EQ(10100U, message->pingTimeout.expirationCycleTime); - EXPECT_EQ(0, packets[2]->priority); - EXPECT_EQ(7, packets[3]->priority); - EXPECT_EQ(7, packets[4]->priority); - EXPECT_EQ(0, packets[5]->priority); + EXPECT_EQ(0, priorities[2]); + EXPECT_EQ(7, priorities[3]); + EXPECT_EQ(7, priorities[4]); + EXPECT_EQ(0, priorities[5]); EXPECT_TRUE(sender->sendReady.load()); for (int i = 0; i < 10; ++i) { @@ -366,7 +372,7 @@ TEST_F(SenderTest, handleResendPacket_staleResend) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleResendPacket(&mockPacket, &mockDriver); + sender->handleResendPacket(&mockPacket); } TEST_F(SenderTest, handleResendPacket_badRequest_singlePacketMessage) @@ -374,10 +380,10 @@ TEST_F(SenderTest, handleResendPacket_badRequest_singlePacketMessage) Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); Homa::Mock::MockDriver::MockPacket* packet = - new Homa::Mock::MockDriver::MockPacket(payload); + new Homa::Mock::MockDriver::MockPacket{payload}; setMessagePacket(message, 0, packet); Protocol::Packet::ResendHeader* resendHdr = @@ -393,7 +399,7 @@ TEST_F(SenderTest, handleResendPacket_badRequest_singlePacketMessage) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleResendPacket(&mockPacket, &mockDriver); + sender->handleResendPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -414,10 +420,10 @@ TEST_F(SenderTest, handleResendPacket_badRequest_outOfRange) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); std::vector packets; for (int i = 0; i < 10; ++i) { - packets.push_back(new Homa::Mock::MockDriver::MockPacket(payload)); + packets.push_back(new Homa::Mock::MockDriver::MockPacket{payload}); setMessagePacket(message, i, packets[i]); } SenderTest::addMessage(sender, id, message, true, 5); @@ -440,7 +446,7 @@ TEST_F(SenderTest, handleResendPacket_badRequest_outOfRange) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleResendPacket(&mockPacket, &mockDriver); + sender->handleResendPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -464,9 +470,9 @@ TEST_F(SenderTest, handleResendPacket_eagerResend) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); char data[1028]; - Homa::Mock::MockDriver::MockPacket dataPacket(data); + Homa::Mock::MockDriver::MockPacket dataPacket{data}; for (int i = 0; i < 10; ++i) { setMessagePacket(message, i, &dataPacket); } @@ -484,18 +490,18 @@ TEST_F(SenderTest, handleResendPacket_eagerResend) // Expect the BUSY control packet. char busy[1028]; - Homa::Mock::MockDriver::MockPacket busyPacket(busy); + Homa::Mock::MockDriver::MockPacket busyPacket{busy}; EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&busyPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&busyPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&busyPacket), _, _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&busyPacket), Eq(1))) .Times(1); // Expect no data to be sent but the RESEND packet to be release. - EXPECT_CALL(mockDriver, sendPacket(Eq(&dataPacket))).Times(0); + EXPECT_CALL(mockDriver, sendPacket(Eq(&dataPacket), _, _)).Times(0); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleResendPacket(&mockPacket, &mockDriver); + sender->handleResendPacket(&mockPacket); EXPECT_EQ(5U, info->packetsSent); EXPECT_EQ(8U, info->packetsGranted); @@ -511,7 +517,7 @@ TEST_F(SenderTest, handleGrantPacket_basic) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message, true, 5); message->numPackets = 10; message->state = Homa::OutMessage::Status::IN_PROGRESS; @@ -530,7 +536,7 @@ TEST_F(SenderTest, handleGrantPacket_basic) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleGrantPacket(&mockPacket, &mockDriver); + sender->handleGrantPacket(&mockPacket); EXPECT_EQ(7, info->packetsGranted); EXPECT_EQ(6, info->priority); @@ -543,7 +549,7 @@ TEST_F(SenderTest, handleGrantPacket_excessiveGrant) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message, true, 5); message->numPackets = 10; message->state = Homa::OutMessage::Status::IN_PROGRESS; @@ -565,7 +571,7 @@ TEST_F(SenderTest, handleGrantPacket_excessiveGrant) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleGrantPacket(&mockPacket, &mockDriver); + sender->handleGrantPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -590,7 +596,7 @@ TEST_F(SenderTest, handleGrantPacket_staleGrant) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message, true, 5); message->numPackets = 10; Sender::QueuedMessageInfo* info = &message->queuedMessageInfo; @@ -608,7 +614,7 @@ TEST_F(SenderTest, handleGrantPacket_staleGrant) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleGrantPacket(&mockPacket, &mockDriver); + sender->handleGrantPacket(&mockPacket); EXPECT_EQ(5, info->packetsGranted); EXPECT_EQ(2, info->priority); @@ -628,23 +634,23 @@ TEST_F(SenderTest, handleGrantPacket_dropGrant) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleGrantPacket(&mockPacket, &mockDriver); + sender->handleGrantPacket(&mockPacket); } TEST_F(SenderTest, handleUnknownPacket_basic) { Protocol::MessageId id = {42, 1}; - Driver::Address destination = (Driver::Address)22; + SocketAddress destination = {22, 60001}; Core::Policy::Unscheduled policyOld = {1, 2000, 1}; Core::Policy::Unscheduled policyNew = {2, 3000, 2}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); std::vector packets; char payload[5][1028]; for (int i = 0; i < 5; ++i) { Homa::Mock::MockDriver::MockPacket* packet = - new Homa::Mock::MockDriver::MockPacket(payload[i]); + new Homa::Mock::MockDriver::MockPacket{payload[i]}; Protocol::Packet::DataHeader* header = static_cast(packet->payload); header->policyVersion = policyOld.version; @@ -674,12 +680,12 @@ TEST_F(SenderTest, handleUnknownPacket_basic) EXPECT_CALL( mockPolicyManager, - getUnscheduledPolicy(Eq(destination), Eq(message->messageLength))) + getUnscheduledPolicy(Eq(destination.ip), Eq(message->messageLength))) .WillOnce(Return(policyNew)); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleUnknownPacket(&mockPacket, &mockDriver); + sender->handleUnknownPacket(&mockPacket); EXPECT_EQ(Homa::OutMessage::Status::IN_PROGRESS, message->state); for (int i = 0; i < 3; ++i) { @@ -706,13 +712,13 @@ TEST_F(SenderTest, handleUnknownPacket_basic) TEST_F(SenderTest, handleUnknownPacket_singlePacketMessage) { Protocol::MessageId id = {42, 1}; - Driver::Address destination = (Driver::Address)22; + SocketAddress destination = {22, 60001}; Core::Policy::Unscheduled policyOld = {1, 2000, 1}; Core::Policy::Unscheduled policyNew = {2, 3000, 2}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); - Homa::Mock::MockDriver::MockPacket dataPacket(payload); + dynamic_cast(sender->allocMessage(0)); + Homa::Mock::MockDriver::MockPacket dataPacket{payload}; Protocol::Packet::DataHeader* dataHeader = static_cast(dataPacket.payload); dataHeader->policyVersion = policyOld.version; @@ -733,13 +739,13 @@ TEST_F(SenderTest, handleUnknownPacket_singlePacketMessage) EXPECT_CALL( mockPolicyManager, - getUnscheduledPolicy(Eq(destination), Eq(message->messageLength))) + getUnscheduledPolicy(Eq(destination.ip), Eq(message->messageLength))) .WillOnce(Return(policyNew)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&dataPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&dataPacket), _, _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleUnknownPacket(&mockPacket, &mockDriver); + sender->handleUnknownPacket(&mockPacket); EXPECT_EQ(Homa::OutMessage::Status::SENT, message->state); EXPECT_EQ(policyNew.version, dataHeader->policyVersion); @@ -751,20 +757,63 @@ TEST_F(SenderTest, handleUnknownPacket_singlePacketMessage) EXPECT_FALSE(sender->sendReady.load()); } +TEST_F(SenderTest, handleUnknownPacket_NO_KEEP_ALIVE) +{ + Protocol::MessageId id = {42, 1}; + SocketAddress destination = {22, 60001}; + Core::Policy::Unscheduled policyNew = {2, 3000, 2}; + + Sender::Message* message = + dynamic_cast(sender->allocMessage(0)); + Homa::Mock::MockDriver::MockPacket dataPacket{payload}; + Protocol::Packet::DataHeader* dataHeader = + static_cast(dataPacket.payload); + setMessagePacket(message, 0, &dataPacket); + message->destination = destination; + message->messageLength = 500; + message->state.store(Homa::OutMessage::Status::SENT); + message->options = OutMessage::Options::NO_KEEP_ALIVE; + SenderTest::addMessage(sender, id, message); + Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); + bucket->messageTimeouts.setTimeout(&message->messageTimeout); + bucket->pingTimeouts.setTimeout(&message->pingTimeout); + EXPECT_FALSE(bucket->messageTimeouts.empty()); + EXPECT_FALSE(bucket->pingTimeouts.empty()); + + Protocol::Packet::UnknownHeader* header = + static_cast(mockPacket.payload); + header->common.messageId = id; + + EXPECT_CALL( + mockPolicyManager, + getUnscheduledPolicy(Eq(destination.ip), Eq(message->messageLength))) + .WillOnce(Return(policyNew)); + EXPECT_CALL(mockDriver, sendPacket(Eq(&dataPacket), Eq(destination.ip), _)) + .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) + .Times(1); + + sender->handleUnknownPacket(&mockPacket); + + EXPECT_EQ(Homa::OutMessage::Status::SENT, message->state); + EXPECT_TRUE(bucket->messageTimeouts.empty()); + EXPECT_TRUE(bucket->pingTimeouts.empty()); +} + TEST_F(SenderTest, handleUnknownPacket_NO_RETRY) { Protocol::MessageId id = {42, 1}; - Driver::Address destination = (Driver::Address)22; + SocketAddress destination = {22, 60001}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); message->options = OutMessage::Options::NO_RETRY; std::vector packets; char payload[5][1028]; for (int i = 0; i < 5; ++i) { Homa::Mock::MockDriver::MockPacket* packet = - new Homa::Mock::MockDriver::MockPacket(payload[i]); + new Homa::Mock::MockDriver::MockPacket{payload[i]}; packets.push_back(packet); setMessagePacket(message, i, packet); } @@ -784,7 +833,7 @@ TEST_F(SenderTest, handleUnknownPacket_NO_RETRY) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleUnknownPacket(&mockPacket, &mockDriver); + sender->handleUnknownPacket(&mockPacket); EXPECT_FALSE( sender->sendQueue.contains(&message->queuedMessageInfo.sendQueueNode)); @@ -806,14 +855,14 @@ TEST_F(SenderTest, handleUnknownPacket_no_message) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleUnknownPacket(&mockPacket, &mockDriver); + sender->handleUnknownPacket(&mockPacket); } TEST_F(SenderTest, handleUnknownPacket_done) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); message->state.store(Homa::OutMessage::Status::COMPLETED); EXPECT_EQ(0U, message->messageTimeout.expirationCycleTime); @@ -826,7 +875,7 @@ TEST_F(SenderTest, handleUnknownPacket_done) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleUnknownPacket(&mockPacket, &mockDriver); + sender->handleUnknownPacket(&mockPacket); EXPECT_EQ(Homa::OutMessage::Status::COMPLETED, message->state); EXPECT_EQ(0U, message->messageTimeout.expirationCycleTime); @@ -838,7 +887,7 @@ TEST_F(SenderTest, handleErrorPacket_basic) Protocol::MessageId id = {42, 1}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); bucket->messageTimeouts.setTimeout(&message->messageTimeout); bucket->pingTimeouts.setTimeout(&message->pingTimeout); @@ -851,7 +900,7 @@ TEST_F(SenderTest, handleErrorPacket_basic) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); EXPECT_EQ(nullptr, message->messageTimeout.node.list); EXPECT_EQ(nullptr, message->pingTimeout.node.list); @@ -863,7 +912,7 @@ TEST_F(SenderTest, handleErrorPacket_CANCELED) Protocol::MessageId id = {42, 1}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); message->state.store(Homa::OutMessage::Status::CANCELED); @@ -874,7 +923,7 @@ TEST_F(SenderTest, handleErrorPacket_CANCELED) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); EXPECT_EQ(Homa::OutMessage::Status::CANCELED, message->state); } @@ -884,7 +933,7 @@ TEST_F(SenderTest, handleErrorPacket_NOT_STARTED) Protocol::MessageId id = {42, 1}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); message->state.store(Homa::OutMessage::Status::NOT_STARTED); @@ -898,7 +947,7 @@ TEST_F(SenderTest, handleErrorPacket_NOT_STARTED) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -920,7 +969,7 @@ TEST_F(SenderTest, handleErrorPacket_IN_PROGRESS) Protocol::MessageId id = {42, 1}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); message->state.store(Homa::OutMessage::Status::IN_PROGRESS); @@ -934,7 +983,7 @@ TEST_F(SenderTest, handleErrorPacket_IN_PROGRESS) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -956,7 +1005,7 @@ TEST_F(SenderTest, handleErrorPacket_COMPLETED) Protocol::MessageId id = {42, 1}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); message->state.store(Homa::OutMessage::Status::COMPLETED); @@ -970,7 +1019,7 @@ TEST_F(SenderTest, handleErrorPacket_COMPLETED) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -992,7 +1041,7 @@ TEST_F(SenderTest, handleErrorPacket_FAILED) Protocol::MessageId id = {42, 1}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); message->state.store(Homa::OutMessage::Status::FAILED); @@ -1006,7 +1055,7 @@ TEST_F(SenderTest, handleErrorPacket_FAILED) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -1029,7 +1078,7 @@ TEST_F(SenderTest, handleErrorPacket_noMessage) header->common.messageId = id; EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); } TEST_F(SenderTest, poll) @@ -1040,24 +1089,13 @@ TEST_F(SenderTest, poll) TEST_F(SenderTest, checkTimeouts) { - Sender::Message message(sender, &mockDriver); Sender::MessageBucket* bucket = sender->messageBuckets.buckets.at(0); - bucket->pingTimeouts.setTimeout(&message.pingTimeout); - bucket->messageTimeouts.setTimeout(&message.messageTimeout); - - message.pingTimeout.expirationCycleTime = 10010; - message.messageTimeout.expirationCycleTime = 10020; - - EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); - EXPECT_EQ(10010U, sender->checkTimeouts()); - message.pingTimeout.expirationCycleTime = 10030; + EXPECT_EQ(0, sender->nextBucketIndex.load()); - EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); - EXPECT_EQ(10020U, sender->checkTimeouts()); + sender->checkTimeouts(); - bucket->pingTimeouts.cancelTimeout(&message.pingTimeout); - bucket->messageTimeouts.cancelTimeout(&message.messageTimeout); + EXPECT_EQ(1, sender->nextBucketIndex.load()); } TEST_F(SenderTest, Message_destructor) @@ -1065,7 +1103,7 @@ TEST_F(SenderTest, Message_destructor) const int MAX_RAW_PACKET_LENGTH = 2000; ON_CALL(mockDriver, getMaxPayloadSize) .WillByDefault(Return(MAX_RAW_PACKET_LENGTH)); - Sender::Message* msg = new Sender::Message(sender, &mockDriver); + Sender::Message* msg = new Sender::Message(sender, 0); const uint16_t NUM_PKTS = 5; @@ -1086,10 +1124,10 @@ TEST_F(SenderTest, Message_append_basic) ON_CALL(mockDriver, getMaxPayloadSize) .WillByDefault(Return(MAX_RAW_PACKET_LENGTH)); - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); char buf[2 * MAX_RAW_PACKET_LENGTH]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + MAX_RAW_PACKET_LENGTH); + Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1{buf + MAX_RAW_PACKET_LENGTH}; const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1126,10 +1164,10 @@ TEST_F(SenderTest, Message_append_truncated) ON_CALL(mockDriver, getMaxPayloadSize) .WillByDefault(Return(MAX_RAW_PACKET_LENGTH)); - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); char buf[2 * MAX_RAW_PACKET_LENGTH]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + MAX_RAW_PACKET_LENGTH); + Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1{buf + MAX_RAW_PACKET_LENGTH}; const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1155,7 +1193,7 @@ TEST_F(SenderTest, Message_append_truncated) EXPECT_STREQ("append", m.function); EXPECT_EQ(int(Debug::LogLevel::WARNING), m.logLevel); EXPECT_EQ( - "Max message size limit (2020352B) reached; 7 of 14 bytes appended", + "Max message size limit (2016256B) reached; 7 of 14 bytes appended", m.message); Debug::setLogHandler(std::function()); @@ -1174,7 +1212,7 @@ TEST_F(SenderTest, Message_getStatus) TEST_F(SenderTest, Message_length) { ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(2048)); - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); msg.messageLength = 200; msg.start = 20; EXPECT_EQ(180U, msg.length()); @@ -1183,10 +1221,10 @@ TEST_F(SenderTest, Message_length) TEST_F(SenderTest, Message_prepend) { ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(2048)); - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + 2048); + Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1{buf + 2048}; const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1218,10 +1256,10 @@ TEST_F(SenderTest, Message_release) TEST_F(SenderTest, Message_reserve) { - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + 2048); + Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1{buf + 2048}; const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1259,7 +1297,7 @@ TEST_F(SenderTest, Message_send) TEST_F(SenderTest, Message_getPacket) { - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); Driver::Packet* packet = (Driver::Packet*)42; msg.packets[0] = packet; @@ -1273,10 +1311,10 @@ TEST_F(SenderTest, Message_getPacket) TEST_F(SenderTest, Message_getOrAllocPacket) { // TODO(cstlee): cleanup - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + 2048); + Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1{buf + 2048}; EXPECT_FALSE(msg.occupied.test(0)); EXPECT_EQ(0U, msg.numPackets); @@ -1298,9 +1336,9 @@ TEST_F(SenderTest, MessageBucket_findMessage) Sender::MessageBucket* bucket = sender->messageBuckets.buckets.at(0); Sender::Message* msg0 = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); Sender::Message* msg1 = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); msg0->id = {42, 0}; msg1->id = {42, 1}; Protocol::MessageId id_none = {42, 42}; @@ -1329,35 +1367,43 @@ TEST_F(SenderTest, sendMessage_basic) { Protocol::MessageId id = {sender->transportId, sender->nextMessageSequenceNumber}; + uint16_t sport = 0; + uint16_t dport = 60001; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(sport)); Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); setMessagePacket(message, 0, &mockPacket); message->messageLength = 420; mockPacket.length = message->messageLength + message->TRANSPORT_HEADER_LENGTH; - Driver::Address destination = (Driver::Address)22; + SocketAddress destination = {22, dport}; Core::Policy::Unscheduled policy = {1, 3000, 2}; EXPECT_FALSE(bucket->messages.contains(&message->bucketNode)); EXPECT_CALL(mockPolicyManager, - getUnscheduledPolicy(Eq(destination), Eq(420))) + getUnscheduledPolicy(Eq(destination.ip), Eq(420))) .WillOnce(Return(policy)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + int mockPriority = 0; + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), Eq(destination.ip), _)) + .WillOnce( + [&mockPriority](auto _1, auto _2, int p) { mockPriority = p; }); sender->sendMessage(message, destination, Sender::Message::Options::NO_RETRY); // Check Message metadata EXPECT_EQ(id, message->id); - EXPECT_EQ(destination, message->destination); + EXPECT_EQ(destination.ip, message->destination.ip); + EXPECT_EQ(destination.port, message->destination.port); EXPECT_EQ(Sender::Message::Options::NO_RETRY, message->options); // Check packet metadata Protocol::Packet::DataHeader* header = static_cast(mockPacket.payload); + EXPECT_EQ(htobe16(sport), header->common.prefix.sport); + EXPECT_EQ(htobe16(dport), header->common.prefix.dport); EXPECT_EQ(id, header->common.messageId); EXPECT_EQ(420U, header->totalLength); EXPECT_EQ(policy.version, header->policyVersion); @@ -1370,8 +1416,7 @@ TEST_F(SenderTest, sendMessage_basic) EXPECT_EQ(10100U, message->pingTimeout.expirationCycleTime); // Check sent packet metadata - EXPECT_EQ(22U, (uint64_t)mockPacket.address); - EXPECT_EQ(policy.priority, mockPacket.priority); + EXPECT_EQ(policy.priority, mockPriority); EXPECT_EQ(Homa::OutMessage::Status::SENT, message->state); EXPECT_FALSE(sender->sendReady.load()); @@ -1381,48 +1426,48 @@ TEST_F(SenderTest, sendMessage_multipacket) { char payload0[1027]; char payload1[1027]; - NiceMock packet0(payload0); - NiceMock packet1(payload1); + Homa::Mock::MockDriver::MockPacket packet0{payload0}; + Homa::Mock::MockDriver::MockPacket packet1{payload1}; Protocol::MessageId id = {sender->transportId, sender->nextMessageSequenceNumber}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); setMessagePacket(message, 0, &packet0); setMessagePacket(message, 1, &packet1); message->messageLength = 1420; - packet0.length = 1000 + 27; - packet1.length = 420 + 27; - Driver::Address destination = (Driver::Address)22; + packet0.length = 1000 + 31; + packet1.length = 420 + 31; + SocketAddress destination = {22, 60001}; Core::Policy::Unscheduled policy = {1, 1000, 2}; - EXPECT_EQ(27U, sizeof(Protocol::Packet::DataHeader)); + EXPECT_EQ(31U, sizeof(Protocol::Packet::DataHeader)); EXPECT_EQ(1000U, message->PACKET_DATA_LENGTH); EXPECT_CALL(mockPolicyManager, - getUnscheduledPolicy(Eq(destination), Eq(1420))) + getUnscheduledPolicy(Eq(destination.ip), Eq(1420))) .WillOnce(Return(policy)); sender->sendMessage(message, destination); // Check Message metadata EXPECT_EQ(id, message->id); - EXPECT_EQ(destination, message->destination); + EXPECT_EQ(destination.ip, message->destination.ip); + EXPECT_EQ(destination.port, message->destination.port); EXPECT_EQ(Homa::OutMessage::Status::IN_PROGRESS, message->state); // Check packet metadata Protocol::Packet::DataHeader* header = nullptr; // Packet0 - EXPECT_EQ(22U, (uint64_t)packet0.address); header = static_cast(packet0.payload); EXPECT_EQ(message->id, header->common.messageId); EXPECT_EQ(message->messageLength, header->totalLength); // Packet1 - EXPECT_EQ(22U, (uint64_t)packet1.address); header = static_cast(packet1.payload); EXPECT_EQ(message->id, header->common.messageId); - EXPECT_EQ(destination, message->destination); + EXPECT_EQ(destination.ip, message->destination.ip); + EXPECT_EQ(destination.port, message->destination.port); EXPECT_EQ(message->messageLength, header->totalLength); // Check Sender metadata @@ -1436,18 +1481,49 @@ TEST_F(SenderTest, sendMessage_multipacket) EXPECT_TRUE(sender->sendReady.load()); } +TEST_F(SenderTest, sendMessage_NO_KEEP_ALIVE) +{ + Protocol::MessageId id = {sender->transportId, + sender->nextMessageSequenceNumber}; + Sender::Message* message = + dynamic_cast(sender->allocMessage(0)); + Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); + + setMessagePacket(message, 0, &mockPacket); + message->messageLength = 420; + mockPacket.length = + message->messageLength + message->TRANSPORT_HEADER_LENGTH; + SocketAddress destination = {22, 60001}; + Core::Policy::Unscheduled policy = {1, 3000, 2}; + + EXPECT_CALL(mockPolicyManager, + getUnscheduledPolicy(Eq(destination.ip), Eq(420))) + .WillOnce(Return(policy)); + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), Eq(destination.ip), _)) + .Times(1); + + sender->sendMessage(message, destination, + Sender::Message::Options::NO_KEEP_ALIVE); + + EXPECT_EQ(Sender::Message::Options::NO_KEEP_ALIVE, message->options); + EXPECT_TRUE(bucket->messages.contains(&message->bucketNode)); + EXPECT_TRUE(bucket->messageTimeouts.empty()); + EXPECT_TRUE(bucket->pingTimeouts.empty()); + EXPECT_EQ(Homa::OutMessage::Status::SENT, message->state); +} + TEST_F(SenderTest, sendMessage_missingPacket) { Protocol::MessageId id = {sender->transportId, sender->nextMessageSequenceNumber}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); setMessagePacket(message, 1, &mockPacket); Core::Policy::Unscheduled policy = {1, 1000, 2}; ON_CALL(mockPolicyManager, getUnscheduledPolicy(_, _)) .WillByDefault(Return(policy)); - EXPECT_DEATH(sender->sendMessage(message, Driver::Address()), + EXPECT_DEATH(sender->sendMessage(message, SocketAddress{0, 0}), ".*Incomplete message with id \\(22:1\\); missing packet at " "offset 0; this shouldn't happen.*"); } @@ -1457,17 +1533,17 @@ TEST_F(SenderTest, sendMessage_unscheduledLimit) Protocol::MessageId id = {sender->transportId, sender->nextMessageSequenceNumber}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); for (int i = 0; i < 9; ++i) { setMessagePacket(message, i, &mockPacket); } message->messageLength = 9000; mockPacket.length = 1000 + sizeof(Protocol::Packet::DataHeader); - Driver::Address destination = (Driver::Address)22; + SocketAddress destination = {22, 60001}; Core::Policy::Unscheduled policy = {1, 4500, 2}; EXPECT_EQ(9U, message->numPackets); EXPECT_EQ(1000U, message->PACKET_DATA_LENGTH); - EXPECT_CALL(mockPolicyManager, getUnscheduledPolicy(destination, 9000)) + EXPECT_CALL(mockPolicyManager, getUnscheduledPolicy(destination.ip, 9000)) .WillOnce(Return(policy)); sender->sendMessage(message, destination); @@ -1481,7 +1557,7 @@ TEST_F(SenderTest, cancelMessage) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message, true, 5); Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); bucket->pingTimeouts.setTimeout(&message->pingTimeout); @@ -1499,28 +1575,41 @@ TEST_F(SenderTest, cancelMessage) EXPECT_TRUE(bucket->pingTimeouts.list.empty()); EXPECT_TRUE(sender->sendQueue.empty()); EXPECT_EQ(Homa::OutMessage::Status::CANCELED, message->state.load()); - EXPECT_FALSE(bucket->messages.contains(&message->bucketNode)); } -TEST_F(SenderTest, dropMessage) +TEST_F(SenderTest, dropMessage_basic) { Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); EXPECT_EQ(1U, sender->messageAllocator.pool.outstandingObjects); + EXPECT_TRUE(message->held); + EXPECT_EQ(OutMessage::Status::NOT_STARTED, message->state); sender->dropMessage(message); EXPECT_EQ(0U, sender->messageAllocator.pool.outstandingObjects); } -TEST_F(SenderTest, checkMessageTimeouts_basic) +TEST_F(SenderTest, dropMessage_IN_PROGRESS) +{ + Sender::Message* message = + dynamic_cast(sender->allocMessage(0)); + message->state = OutMessage::Status::IN_PROGRESS; + EXPECT_EQ(1U, sender->messageAllocator.pool.outstandingObjects); + EXPECT_TRUE(message->held); + + sender->dropMessage(message); + + EXPECT_EQ(1U, sender->messageAllocator.pool.outstandingObjects); + EXPECT_FALSE(message->held); +} + +TEST_F(SenderTest, checkMessageTimeouts) { Sender::Message* message[4]; + Sender::MessageBucket* bucket = sender->messageBuckets.buckets.at(0); for (uint64_t i = 0; i < 4; ++i) { - Protocol::MessageId id = {42, 10 + i}; - message[i] = dynamic_cast(sender->allocMessage()); - SenderTest::addMessage(sender, id, message[i]); - Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); + message[i] = dynamic_cast(sender->allocMessage(0)); bucket->messageTimeouts.setTimeout(&message[i]->messageTimeout); bucket->pingTimeouts.setTimeout(&message[i]->pingTimeout); } @@ -1528,6 +1617,7 @@ TEST_F(SenderTest, checkMessageTimeouts_basic) // Message[0]: Normal timeout: IN_PROGRESS message[0]->messageTimeout.expirationCycleTime = 9998; message[0]->state = Homa::OutMessage::Status::IN_PROGRESS; + sender->sendQueue.push_front(&message[0]->queuedMessageInfo.sendQueueNode); // Message[1]: Normal timeout: SENT message[1]->messageTimeout.expirationCycleTime = 9999; message[1]->state = Homa::OutMessage::Status::SENT; @@ -1538,15 +1628,17 @@ TEST_F(SenderTest, checkMessageTimeouts_basic) message[3]->messageTimeout.expirationCycleTime = 10001; message[3]->state = Homa::OutMessage::Status::SENT; - EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); + bucket->messageTimeouts.nextTimeout = 9998; - uint64_t nextTimeout = sender->checkMessageTimeouts(); + sender->checkMessageTimeouts(10000, bucket); - EXPECT_EQ(message[3]->messageTimeout.expirationCycleTime, nextTimeout); + EXPECT_EQ(message[3]->messageTimeout.expirationCycleTime, + bucket->messageTimeouts.nextTimeout.load()); // Message[0]: Normal timeout: IN_PROGRESS EXPECT_EQ(nullptr, message[0]->messageTimeout.node.list); EXPECT_EQ(nullptr, message[0]->pingTimeout.node.list); EXPECT_EQ(Homa::OutMessage::Status::FAILED, message[0]->getStatus()); + EXPECT_TRUE(sender->sendQueue.empty()); // Message[1]: Normal timeout: SENT EXPECT_EQ(nullptr, message[1]->messageTimeout.node.list); EXPECT_EQ(nullptr, message[1]->pingTimeout.node.list); @@ -1556,96 +1648,76 @@ TEST_F(SenderTest, checkMessageTimeouts_basic) EXPECT_EQ(nullptr, message[2]->pingTimeout.node.list); EXPECT_EQ(Homa::OutMessage::Status::COMPLETED, message[2]->getStatus()); // Message[3]: No timeout - EXPECT_EQ( - &sender->messageBuckets.getBucket(message[3]->id)->messageTimeouts.list, - message[3]->messageTimeout.node.list); - EXPECT_EQ( - &sender->messageBuckets.getBucket(message[3]->id)->pingTimeouts.list, - message[3]->pingTimeout.node.list); + EXPECT_EQ(&bucket->messageTimeouts.list, + message[3]->messageTimeout.node.list); + EXPECT_EQ(&bucket->pingTimeouts.list, message[3]->pingTimeout.node.list); EXPECT_EQ(Homa::OutMessage::Status::SENT, message[3]->getStatus()); } -TEST_F(SenderTest, checkMessageTimeouts_empty) +TEST_F(SenderTest, checkPingTimeouts) { - for (int i = 0; i < Sender::MessageBucketMap::NUM_BUCKETS; ++i) { - Sender::MessageBucket* bucket = sender->messageBuckets.buckets.at(i); - EXPECT_TRUE(bucket->messageTimeouts.list.empty()); - } - EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); - uint64_t nextTimeout = sender->checkMessageTimeouts(); - EXPECT_EQ(10000 + messageTimeoutCycles, nextTimeout); -} - -TEST_F(SenderTest, checkPingTimeouts_basic) -{ - Sender::Message* message[5]; - for (uint64_t i = 0; i < 5; ++i) { - Protocol::MessageId id = {42, 10 + i}; - message[i] = dynamic_cast(sender->allocMessage()); - SenderTest::addMessage(sender, id, message[i]); - Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); + Sender::Message* message[6]; + Sender::MessageBucket* bucket = sender->messageBuckets.buckets.at(0); + for (uint64_t i = 0; i < 6; ++i) { + message[i] = dynamic_cast(sender->allocMessage(0)); bucket->pingTimeouts.setTimeout(&message[i]->pingTimeout); } // Message[0]: Normal timeout: COMPLETED message[0]->state = Homa::OutMessage::Status::COMPLETED; - message[0]->pingTimeout.expirationCycleTime = 9997; + message[0]->pingTimeout.expirationCycleTime = 9996; // Message[1]: Normal timeout: FAILED message[1]->state = Homa::OutMessage::Status::FAILED; - message[1]->pingTimeout.expirationCycleTime = 9998; + message[1]->pingTimeout.expirationCycleTime = 9997; // Message[2]: Normal timeout: NO_KEEP_ALIVE && SENT message[2]->options = Homa::OutMessage::Options::NO_KEEP_ALIVE; message[2]->state = Homa::OutMessage::Status::SENT; - message[2]->pingTimeout.expirationCycleTime = 9999; - // Message[3]: Normal timeout: SENT - message[3]->state = Homa::OutMessage::Status::SENT; - message[3]->pingTimeout.expirationCycleTime = 10000; - // Message[4]: No timeout - message[4]->pingTimeout.expirationCycleTime = 10001; - - EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); + message[2]->pingTimeout.expirationCycleTime = 9998; + // Message[3]: Normal timeout: IN_PROGRESS + message[3]->state = Homa::OutMessage::Status::IN_PROGRESS; + message[3]->pingTimeout.expirationCycleTime = 9999; + message[3]->queuedMessageInfo.packetsSent = 1; + message[3]->queuedMessageInfo.packetsGranted = 2; + // Message[4]: Normal timeout: SENT + message[4]->state = Homa::OutMessage::Status::SENT; + message[4]->pingTimeout.expirationCycleTime = 10000; + // Message[5]: No timeout + message[5]->pingTimeout.expirationCycleTime = 10001; + + bucket->pingTimeouts.nextTimeout = 9996; EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), _, _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - uint64_t nextTimeout = sender->checkPingTimeouts(); + sender->checkPingTimeouts(10000, bucket); - EXPECT_EQ(message[4]->pingTimeout.expirationCycleTime, nextTimeout); + EXPECT_EQ(message[5]->pingTimeout.expirationCycleTime, + bucket->pingTimeouts.nextTimeout.load()); // Message[0]: Normal timeout: COMPLETED EXPECT_EQ(nullptr, message[0]->pingTimeout.node.list); // Message[1]: Normal timeout: FAILED EXPECT_EQ(nullptr, message[1]->pingTimeout.node.list); // Message[2]: Normal timeout: NO_KEEP_ALIVE && SENT EXPECT_EQ(nullptr, message[2]->pingTimeout.node.list); - // Message[3]: Normal timeout: SENT - EXPECT_EQ(10100, message[3]->pingTimeout.expirationCycleTime); + // Message[3]: Normal timeout: IN_PROGRESS + EXPECT_EQ(10100, message[4]->pingTimeout.expirationCycleTime); + // Message[4]: Normal timeout: SENT + EXPECT_EQ(10100, message[4]->pingTimeout.expirationCycleTime); Protocol::Packet::CommonHeader* header = static_cast(mockPacket.payload); EXPECT_EQ(Protocol::Packet::PING, header->opcode); - EXPECT_EQ(message[3]->id, header->messageId); - // Message[4]: No timeout - EXPECT_EQ(10001, message[4]->pingTimeout.expirationCycleTime); -} - -TEST_F(SenderTest, checkPingTimeouts_empty) -{ - for (int i = 0; i < Sender::MessageBucketMap::NUM_BUCKETS; ++i) { - Sender::MessageBucket* bucket = sender->messageBuckets.buckets.at(i); - EXPECT_TRUE(bucket->pingTimeouts.list.empty()); - } - EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); - sender->checkPingTimeouts(); - uint64_t nextTimeout = sender->checkPingTimeouts(); - EXPECT_EQ(10000 + pingIntervalCycles, nextTimeout); + EXPECT_EQ(message[4]->id, header->messageId); + // Message[5]: No timeout + EXPECT_EQ(10001, message[5]->pingTimeout.expirationCycleTime); } TEST_F(SenderTest, trySend_basic) { Protocol::MessageId id = {42, 10}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); Sender::QueuedMessageInfo* info = &message->queuedMessageInfo; SenderTest::addMessage(sender, id, message, true, 3); Homa::Mock::MockDriver::MockPacket* packet[5]; @@ -1653,12 +1725,13 @@ TEST_F(SenderTest, trySend_basic) const uint32_t PACKET_DATA_SIZE = PACKET_SIZE - message->TRANSPORT_HEADER_LENGTH; for (int i = 0; i < 5; ++i) { - packet[i] = new Homa::Mock::MockDriver::MockPacket(payload); + packet[i] = new Homa::Mock::MockDriver::MockPacket{payload}; packet[i]->length = PACKET_SIZE; setMessagePacket(message, i, packet[i]); info->unsentBytes += PACKET_DATA_SIZE; } message->state = Homa::OutMessage::Status::IN_PROGRESS; + message->held = false; sender->sendReady = true; EXPECT_EQ(5U, message->numPackets); EXPECT_EQ(3U, info->packetsGranted); @@ -1666,10 +1739,11 @@ TEST_F(SenderTest, trySend_basic) EXPECT_EQ(5 * PACKET_DATA_SIZE, info->unsentBytes); EXPECT_NE(Homa::OutMessage::Status::SENT, message->state); EXPECT_TRUE(sender->sendQueue.contains(&info->sendQueueNode)); + EXPECT_EQ(1U, sender->messageAllocator.pool.outstandingObjects); // 3 granted packets; 2 will send; queue limit reached. - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[0]))); - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[1]))); + EXPECT_CALL(mockDriver, sendPacket(Eq(packet[0]), _, _)); + EXPECT_CALL(mockDriver, sendPacket(Eq(packet[1]), _, _)); sender->trySend(); // < test call EXPECT_TRUE(sender->sendReady); EXPECT_EQ(Homa::OutMessage::Status::IN_PROGRESS, message->state); @@ -1678,10 +1752,11 @@ TEST_F(SenderTest, trySend_basic) EXPECT_EQ(3 * PACKET_DATA_SIZE, info->unsentBytes); EXPECT_NE(Homa::OutMessage::Status::SENT, message->state); EXPECT_TRUE(sender->sendQueue.contains(&info->sendQueueNode)); + EXPECT_EQ(1U, sender->messageAllocator.pool.outstandingObjects); Mock::VerifyAndClearExpectations(&mockDriver); // 1 packet to be sent; grant limit reached. - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[2]))); + EXPECT_CALL(mockDriver, sendPacket(Eq(packet[2]), _, _)); sender->trySend(); // < test call EXPECT_FALSE(sender->sendReady); EXPECT_EQ(Homa::OutMessage::Status::IN_PROGRESS, message->state); @@ -1690,6 +1765,7 @@ TEST_F(SenderTest, trySend_basic) EXPECT_EQ(2 * PACKET_DATA_SIZE, info->unsentBytes); EXPECT_NE(Homa::OutMessage::Status::SENT, message->state); EXPECT_TRUE(sender->sendQueue.contains(&info->sendQueueNode)); + EXPECT_EQ(1U, sender->messageAllocator.pool.outstandingObjects); Mock::VerifyAndClearExpectations(&mockDriver); // No additional grants; spurious ready hint. @@ -1703,13 +1779,14 @@ TEST_F(SenderTest, trySend_basic) EXPECT_EQ(2 * PACKET_DATA_SIZE, info->unsentBytes); EXPECT_NE(Homa::OutMessage::Status::SENT, message->state); EXPECT_TRUE(sender->sendQueue.contains(&info->sendQueueNode)); + EXPECT_EQ(1U, sender->messageAllocator.pool.outstandingObjects); Mock::VerifyAndClearExpectations(&mockDriver); // 2 more granted packets; will finish. info->packetsGranted = 5; sender->sendReady = true; - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[3]))); - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[4]))); + EXPECT_CALL(mockDriver, sendPacket(Eq(packet[3]), _, _)); + EXPECT_CALL(mockDriver, sendPacket(Eq(packet[4]), _, _)); sender->trySend(); // < test call EXPECT_FALSE(sender->sendReady); EXPECT_EQ(Homa::OutMessage::Status::SENT, message->state); @@ -1718,6 +1795,7 @@ TEST_F(SenderTest, trySend_basic) EXPECT_EQ(0 * PACKET_DATA_SIZE, info->unsentBytes); EXPECT_EQ(Homa::OutMessage::Status::SENT, message->state); EXPECT_FALSE(sender->sendQueue.contains(&info->sendQueueNode)); + EXPECT_EQ(0U, sender->messageAllocator.pool.outstandingObjects); Mock::VerifyAndClearExpectations(&mockDriver); for (int i = 0; i < 5; ++i) { @@ -1727,15 +1805,18 @@ TEST_F(SenderTest, trySend_basic) TEST_F(SenderTest, trySend_multipleMessages) { + Protocol::MessageId id[3]; Sender::Message* message[3]; Sender::QueuedMessageInfo* info[3]; + Sender::MessageBucket* bucket[3]; Homa::Mock::MockDriver::MockPacket* packet[3]; for (uint64_t i = 0; i < 3; ++i) { - Protocol::MessageId id = {22, 10 + i}; - message[i] = dynamic_cast(sender->allocMessage()); + id[i] = {22, 10 + i}; + message[i] = dynamic_cast(sender->allocMessage(0)); info[i] = &message[i]->queuedMessageInfo; - SenderTest::addMessage(sender, id, message[i], true, 1); - packet[i] = new Homa::Mock::MockDriver::MockPacket(payload); + SenderTest::addMessage(sender, id[i], message[i], true, 1); + bucket[i] = sender->messageBuckets.getBucket(id[i]); + packet[i] = new Homa::Mock::MockDriver::MockPacket{payload}; packet[i]->length = sender->driver->getMaxPayloadSize() / 4; setMessagePacket(message[i], 0, packet[i]); info[i]->unsentBytes += @@ -1744,9 +1825,10 @@ TEST_F(SenderTest, trySend_multipleMessages) } sender->sendReady = true; - // Message 0: Will finish + // Message 0: Will finish, !held EXPECT_EQ(1, info[0]->packetsGranted); info[0]->packetsSent = 0; + message[0]->held = false; // Message 1: Will reach grant limit EXPECT_EQ(1, info[1]->packetsGranted); @@ -1754,32 +1836,39 @@ TEST_F(SenderTest, trySend_multipleMessages) setMessagePacket(message[1], 1, nullptr); EXPECT_EQ(2, message[1]->numPackets); - // Message 2: Will finish + // Message 2: Will finish, NO_KEEP_ALIVE EXPECT_EQ(1, info[2]->packetsGranted); info[2]->packetsSent = 0; + message[2]->options = OutMessage::Options::NO_KEEP_ALIVE; - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[0]))); - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[1]))); - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[2]))); + EXPECT_CALL(mockDriver, sendPacket(Eq(packet[0]), _, _)); + EXPECT_CALL(mockDriver, sendPacket(Eq(packet[1]), _, _)); + EXPECT_CALL(mockDriver, sendPacket(Eq(packet[2]), _, _)); sender->trySend(); EXPECT_EQ(1U, info[0]->packetsSent); EXPECT_EQ(Homa::OutMessage::Status::SENT, message[0]->state); EXPECT_FALSE(sender->sendQueue.contains(&info[0]->sendQueueNode)); + EXPECT_EQ(nullptr, + bucket[0]->findMessage(id[0], SpinLock::Lock(bucket[0]->mutex))); EXPECT_EQ(1U, info[1]->packetsSent); EXPECT_NE(Homa::OutMessage::Status::SENT, message[1]->state); EXPECT_TRUE(sender->sendQueue.contains(&info[1]->sendQueueNode)); EXPECT_EQ(1U, info[2]->packetsSent); EXPECT_EQ(Homa::OutMessage::Status::SENT, message[2]->state); EXPECT_FALSE(sender->sendQueue.contains(&info[2]->sendQueueNode)); + EXPECT_FALSE(bucket[2]->messageTimeouts.list.contains( + &message[2]->messageTimeout.node)); + EXPECT_FALSE( + bucket[2]->pingTimeouts.list.contains(&message[2]->pingTimeout.node)); } TEST_F(SenderTest, trySend_alreadyRunning) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); Sender::QueuedMessageInfo* info = &message->queuedMessageInfo; SenderTest::addMessage(sender, id, message, true, 1); setMessagePacket(message, 0, &mockPacket); diff --git a/src/Timeout.h b/src/Timeout.h index 6931f2c..56710f4 100644 --- a/src/Timeout.h +++ b/src/Timeout.h @@ -18,18 +18,25 @@ #include +#include + #include "Intrusive.h" namespace Homa { namespace Core { +// Forward declaration. +template +class TimeoutManager; + /** * Intrusive structure to keep track of a per object timeout. * * This structure is not thread-safe. */ template -struct Timeout { +class Timeout { + public: /** * Initialize this Timeout, associating it with a particular object. * @@ -38,21 +45,34 @@ struct Timeout { */ explicit Timeout(ElementType* owner) : expirationCycleTime(0) - , node(owner) + , owner(owner) + , node(this) {} /** * Return true if this Timeout has elapsed; false otherwise. + * + * @param now + * Optionally provided "current" timestamp cycle time. Used to avoid + * unnecessary calls to PerfUtils::Cycles::rdtsc() if the current time + * is already available to the caller. */ - bool hasElapsed() + inline bool hasElapsed(uint64_t now = PerfUtils::Cycles::rdtsc()) { - return PerfUtils::Cycles::rdtsc() >= expirationCycleTime; + return now >= expirationCycleTime; } + private: /// Cycle timestamp when timeout should elapse. uint64_t expirationCycleTime; + + /// Pointer to the object that is associated with this timeout. + ElementType* owner; + /// Intrusive member to help track this timeout. - typename Intrusive::List::Node node; + typename Intrusive::List>::Node node; + + friend class TimeoutManager; }; /** @@ -61,7 +81,8 @@ struct Timeout { * This structure is not thread-safe. */ template -struct TimeoutManager { +class TimeoutManager { + public: /** * Construct a new TimeoutManager with a particular timeout interval. All * timeouts tracked by this manager will have the same timeout interval. @@ -69,6 +90,7 @@ struct TimeoutManager { */ explicit TimeoutManager(uint64_t timeoutIntervalCycles) : timeoutIntervalCycles(timeoutIntervalCycles) + , nextTimeout(UINT64_MAX) , list() {} @@ -79,12 +101,14 @@ struct TimeoutManager { * @param timeout * The Timeout that should be scheduled. */ - void setTimeout(Timeout* timeout) + inline void setTimeout(Timeout* timeout) { list.remove(&timeout->node); timeout->expirationCycleTime = PerfUtils::Cycles::rdtsc() + timeoutIntervalCycles; list.push_back(&timeout->node); + nextTimeout.store(list.front().expirationCycleTime, + std::memory_order_relaxed); } /** @@ -93,16 +117,78 @@ struct TimeoutManager { * @param timeout * The Timeout that should be canceled. */ - void cancelTimeout(Timeout* timeout) + inline void cancelTimeout(Timeout* timeout) { list.remove(&timeout->node); + if (list.empty()) { + nextTimeout.store(UINT64_MAX, std::memory_order_relaxed); + } else { + nextTimeout.store(list.front().expirationCycleTime, + std::memory_order_relaxed); + } + } + + /** + * Check if any managed Timeouts have elapsed. + * + * This method is thread-safe but may race with the other + * non-thread-safe methods of the TimeoutManager (e.g. concurrent calls + * to setTimeout() or cancelTimeout() may not be reflected in the result + * of this method call). + * + * @param now + * Optionally provided "current" timestamp cycle time. Used to + * avoid unnecessary calls to PerfUtils::Cycles::rdtsc() if the current + * time is already available to the caller. + */ + inline bool anyElapsed(uint64_t now = PerfUtils::Cycles::rdtsc()) + { + return now >= nextTimeout.load(std::memory_order_relaxed); + } + + /** + * Check if the TimeoutManager manages no Timeouts. + * + * @return + * True, if there are no Timeouts being managed; false, otherwise. + */ + inline bool empty() const + { + return list.empty(); + } + + /** + * Return a reference the managed timeout element that expires first. + * + * Calling front() an empty TimeoutManager is undefined. + */ + inline ElementType& front() + { + return *list.front().owner; } + /** + * Return a const reference the managed timeout element that expires + * first. + * + * Calling front() an empty TimeoutManager is undefined. + */ + inline const ElementType& front() const + { + return *list.front().owner; + } + + private: /// The number of cycles this newly scheduled timeouts would wait before /// they elapse. uint64_t timeoutIntervalCycles; + + /// The smallest timeout expiration time of all timeouts under + /// management. Accessing this value is thread-safe. + std::atomic nextTimeout; + /// Used to keep track of all timeouts under management. - Intrusive::List list; + Intrusive::List> list; }; } // namespace Core diff --git a/src/TimeoutTest.cc b/src/TimeoutTest.cc index 0225918..18ec9d2 100644 --- a/src/TimeoutTest.cc +++ b/src/TimeoutTest.cc @@ -13,12 +13,11 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ +#include #include #include "Timeout.h" -#include - namespace Homa { namespace Core { namespace { @@ -46,19 +45,22 @@ TEST(TimeoutManagerTest, setTimeout) char dummyOwner; char owner; Timeout dummy(&dummyOwner); + dummy.expirationCycleTime = 42; manager.list.push_back(&dummy.node); + manager.nextTimeout = 0; Timeout t(&owner); EXPECT_EQ(0U, t.expirationCycleTime); EXPECT_EQ(nullptr, t.node.list); - EXPECT_EQ(&dummyOwner, &manager.list.back()); + EXPECT_EQ(&dummyOwner, manager.list.back().owner); manager.setTimeout(&t); EXPECT_EQ(10100U, t.expirationCycleTime); + EXPECT_EQ(42U, manager.nextTimeout.load()); EXPECT_EQ(&manager.list, t.node.list); - EXPECT_EQ(&owner, &manager.list.back()); + EXPECT_EQ(&owner, manager.list.back().owner); manager.list.clear(); PerfUtils::Cycles::mockTscValue = 0; @@ -72,21 +74,23 @@ TEST(TimeoutManagerTest, setTimeout_reset) char dummyOwner; Timeout t(&owner); Timeout dummy(&dummyOwner); + dummy.expirationCycleTime = 9001; manager.list.push_back(&t.node); manager.list.push_back(&dummy.node); t.expirationCycleTime = 50; EXPECT_EQ(50U, t.expirationCycleTime); EXPECT_EQ(&manager.list, t.node.list); - EXPECT_EQ(&owner, &manager.list.front()); - EXPECT_EQ(&dummyOwner, &manager.list.back()); + EXPECT_EQ(&owner, &manager.front()); + EXPECT_EQ(&dummyOwner, manager.list.back().owner); manager.setTimeout(&t); EXPECT_EQ(10100U, t.expirationCycleTime); + EXPECT_EQ(9001U, manager.nextTimeout.load()); EXPECT_EQ(&manager.list, t.node.list); - EXPECT_EQ(&dummyOwner, &manager.list.front()); - EXPECT_EQ(&owner, &manager.list.back()); + EXPECT_EQ(&dummyOwner, &manager.front()); + EXPECT_EQ(&owner, manager.list.back().owner); manager.list.clear(); PerfUtils::Cycles::mockTscValue = 0; @@ -96,20 +100,25 @@ TEST(TimeoutManagerTest, cancelTimeout) { TimeoutManager manager(100); char owner; - Timeout t(&owner); - manager.list.push_back(&t.node); + Timeout t1(&owner); + t1.expirationCycleTime = 42; + Timeout t2(&owner); + t2.expirationCycleTime = 9001; + manager.list.push_back(&t1.node); + manager.list.push_back(&t2.node); - EXPECT_EQ(&manager.list, t.node.list); - EXPECT_FALSE(manager.list.empty()); + EXPECT_EQ(2, manager.list.size()); - manager.cancelTimeout(&t); + manager.cancelTimeout(&t1); - EXPECT_EQ(nullptr, t.node.list); - EXPECT_TRUE(manager.list.empty()); + EXPECT_EQ(nullptr, t1.node.list); + EXPECT_EQ(9001U, manager.nextTimeout.load()); + EXPECT_EQ(1, manager.list.size()); - manager.cancelTimeout(&t); + manager.cancelTimeout(&t2); - EXPECT_EQ(nullptr, t.node.list); + EXPECT_EQ(nullptr, t2.node.list); + EXPECT_EQ(UINT64_MAX, manager.nextTimeout.load()); EXPECT_TRUE(manager.list.empty()); } diff --git a/src/TransportImpl.cc b/src/TransportImpl.cc index 310e099..8178b1a 100644 --- a/src/TransportImpl.cc +++ b/src/TransportImpl.cc @@ -14,11 +14,7 @@ */ #include "TransportImpl.h" - #include -#include -#include - #include "Cycles.h" #include "Perf.h" #include "Protocol.h" @@ -67,6 +63,8 @@ TransportImpl::~TransportImpl() = default; void TransportImpl::poll() { + Perf::Timer timer; + // Receive and dispatch incoming packets. processPackets(); @@ -74,15 +72,7 @@ TransportImpl::poll() sender->poll(); receiver->poll(); - if (PerfUtils::Cycles::rdtsc() >= nextTimeoutCycles.load()) { - uint64_t requestedTimeoutCycles; - requestedTimeoutCycles = sender->checkTimeouts(); - nextTimeoutCycles.store(requestedTimeoutCycles); - requestedTimeoutCycles = receiver->checkTimeouts(); - if (nextTimeoutCycles.load() > requestedTimeoutCycles) { - nextTimeoutCycles.store(requestedTimeoutCycles); - } - } + Perf::counters.total_cycles.add(timer.split()); } /** @@ -94,61 +84,74 @@ void TransportImpl::processPackets() { // Keep track of time spent doing active processing versus idle. - Perf::Timer activityTimer; - activityTimer.split(); - uint64_t activeTime = 0; - uint64_t idleTime = 0; + Perf::Timer timer; const int MAX_BURST = 32; Driver::Packet* packets[MAX_BURST]; - int numPackets = driver->receivePackets(MAX_BURST, packets); + IpAddress srcAddrs[MAX_BURST]; + int numPackets = driver->receivePackets(MAX_BURST, packets, srcAddrs); for (int i = 0; i < numPackets; ++i) { - Driver::Packet* packet = packets[i]; - assert(packet->length >= - Util::downCast(sizeof(Protocol::Packet::CommonHeader))); - Perf::counters.rx_bytes.add(packet->length); - Protocol::Packet::CommonHeader* header = - static_cast(packet->payload); - switch (header->opcode) { - case Protocol::Packet::DATA: - Perf::counters.rx_data_pkts.add(1); - receiver->handleDataPacket(packet, driver); - break; - case Protocol::Packet::GRANT: - Perf::counters.rx_grant_pkts.add(1); - sender->handleGrantPacket(packet, driver); - break; - case Protocol::Packet::DONE: - Perf::counters.rx_done_pkts.add(1); - sender->handleDonePacket(packet, driver); - break; - case Protocol::Packet::RESEND: - Perf::counters.rx_resend_pkts.add(1); - sender->handleResendPacket(packet, driver); - break; - case Protocol::Packet::BUSY: - Perf::counters.rx_busy_pkts.add(1); - receiver->handleBusyPacket(packet, driver); - break; - case Protocol::Packet::PING: - Perf::counters.rx_ping_pkts.add(1); - receiver->handlePingPacket(packet, driver); - break; - case Protocol::Packet::UNKNOWN: - Perf::counters.rx_unknown_pkts.add(1); - sender->handleUnknownPacket(packet, driver); - break; - case Protocol::Packet::ERROR: - Perf::counters.rx_error_pkts.add(1); - sender->handleErrorPacket(packet, driver); - break; - } - activeTime += activityTimer.split(); + processPacket(packets[i], srcAddrs[i]); } - idleTime += activityTimer.split(); + driver->releasePackets(packets, numPackets); + + if (numPackets > 0) { + Perf::counters.active_cycles.add(timer.split()); + } +} - Perf::counters.active_cycles.add(activeTime); - Perf::counters.idle_cycles.add(idleTime); +/** + * Process an incoming packet. The transport will have no more use of this + * packet afterwards, so the packet can be released to the driver when the + * method returns. + * + * @param packet + * Incoming packet to be processed. + * @param sourceIp + * Source IP address. + */ +void +TransportImpl::processPacket(Driver::Packet* packet, IpAddress sourceIp) +{ + assert(packet->length >= + Util::downCast(sizeof(Protocol::Packet::CommonHeader))); + Perf::counters.rx_bytes.add(packet->length); + Protocol::Packet::CommonHeader* header = + static_cast(packet->payload); + switch (header->opcode) { + case Protocol::Packet::DATA: + Perf::counters.rx_data_pkts.add(1); + receiver->handleDataPacket(packet, sourceIp); + break; + case Protocol::Packet::GRANT: + Perf::counters.rx_grant_pkts.add(1); + sender->handleGrantPacket(packet); + break; + case Protocol::Packet::DONE: + Perf::counters.rx_done_pkts.add(1); + sender->handleDonePacket(packet); + break; + case Protocol::Packet::RESEND: + Perf::counters.rx_resend_pkts.add(1); + sender->handleResendPacket(packet); + break; + case Protocol::Packet::BUSY: + Perf::counters.rx_busy_pkts.add(1); + receiver->handleBusyPacket(packet); + break; + case Protocol::Packet::PING: + Perf::counters.rx_ping_pkts.add(1); + receiver->handlePingPacket(packet, sourceIp); + break; + case Protocol::Packet::UNKNOWN: + Perf::counters.rx_unknown_pkts.add(1); + sender->handleUnknownPacket(packet); + break; + case Protocol::Packet::ERROR: + Perf::counters.rx_error_pkts.add(1); + sender->handleErrorPacket(packet); + break; + } } } // namespace Core diff --git a/src/TransportImpl.h b/src/TransportImpl.h index 2d559be..ad46f99 100644 --- a/src/TransportImpl.h +++ b/src/TransportImpl.h @@ -47,9 +47,10 @@ class TransportImpl : public Transport { ~TransportImpl(); /// See Homa::Transport::alloc() - virtual Homa::unique_ptr alloc() + virtual Homa::unique_ptr alloc(uint16_t sourcePort) { - return Homa::unique_ptr(sender->allocMessage()); + Homa::OutMessage* outMessage = sender->allocMessage(sourcePort); + return Homa::unique_ptr(outMessage); } /// See Homa::Transport::receive() @@ -74,6 +75,7 @@ class TransportImpl : public Transport { private: void processPackets(); + void processPacket(Driver::Packet* packet, IpAddress source); /// Unique identifier for this transport. const std::atomic transportId; diff --git a/src/TransportImplTest.cc b/src/TransportImplTest.cc index 0e0ab60..b5c708b 100644 --- a/src/TransportImplTest.cc +++ b/src/TransportImplTest.cc @@ -27,6 +27,7 @@ namespace Homa { namespace Core { namespace { +using ::testing::_; using ::testing::DoAll; using ::testing::Eq; using ::testing::NiceMock; @@ -67,32 +68,20 @@ TEST_F(TransportImplTest, poll) EXPECT_CALL(mockDriver, receivePackets).WillOnce(Return(0)); EXPECT_CALL(*mockSender, poll).Times(1); EXPECT_CALL(*mockReceiver, poll).Times(1); - EXPECT_CALL(*mockSender, checkTimeouts).WillOnce(Return(10000)); - EXPECT_CALL(*mockReceiver, checkTimeouts).WillOnce(Return(10100)); transport->poll(); - EXPECT_EQ(10000U, transport->nextTimeoutCycles); - EXPECT_CALL(mockDriver, receivePackets).WillOnce(Return(0)); EXPECT_CALL(*mockSender, poll).Times(1); EXPECT_CALL(*mockReceiver, poll).Times(1); - EXPECT_CALL(*mockSender, checkTimeouts).WillOnce(Return(10200)); - EXPECT_CALL(*mockReceiver, checkTimeouts).WillOnce(Return(10100)); transport->poll(); - EXPECT_EQ(10100U, transport->nextTimeoutCycles); - EXPECT_CALL(mockDriver, receivePackets).WillOnce(Return(0)); EXPECT_CALL(*mockSender, poll).Times(1); EXPECT_CALL(*mockReceiver, poll).Times(1); - EXPECT_CALL(*mockSender, checkTimeouts).Times(0); - EXPECT_CALL(*mockReceiver, checkTimeouts).Times(0); transport->poll(); - - EXPECT_EQ(10100U, transport->nextTimeoutCycles); } TEST_F(TransportImplTest, processPackets) @@ -101,68 +90,60 @@ TEST_F(TransportImplTest, processPackets) Homa::Driver::Packet* packets[8]; // Set DATA packet - Homa::Mock::MockDriver::MockPacket dataPacket(payload[0], 1024); + Homa::Mock::MockDriver::MockPacket dataPacket{payload[0], 1024}; static_cast(dataPacket.payload) ->common.opcode = Protocol::Packet::DATA; packets[0] = &dataPacket; - EXPECT_CALL(*mockReceiver, - handleDataPacket(Eq(&dataPacket), Eq(&mockDriver))); + EXPECT_CALL(*mockReceiver, handleDataPacket(Eq(&dataPacket), _)); // Set GRANT packet - Homa::Mock::MockDriver::MockPacket grantPacket(payload[1], 1024); + Homa::Mock::MockDriver::MockPacket grantPacket{payload[1], 1024}; static_cast(grantPacket.payload) ->common.opcode = Protocol::Packet::GRANT; packets[1] = &grantPacket; - EXPECT_CALL(*mockSender, - handleGrantPacket(Eq(&grantPacket), Eq(&mockDriver))); + EXPECT_CALL(*mockSender, handleGrantPacket(Eq(&grantPacket))); // Set DONE packet - Homa::Mock::MockDriver::MockPacket donePacket(payload[2], 1024); + Homa::Mock::MockDriver::MockPacket donePacket{payload[2], 1024}; static_cast(donePacket.payload) ->common.opcode = Protocol::Packet::DONE; packets[2] = &donePacket; - EXPECT_CALL(*mockSender, - handleDonePacket(Eq(&donePacket), Eq(&mockDriver))); + EXPECT_CALL(*mockSender, handleDonePacket(Eq(&donePacket))); // Set RESEND packet - Homa::Mock::MockDriver::MockPacket resendPacket(payload[3], 1024); + Homa::Mock::MockDriver::MockPacket resendPacket{payload[3], 1024}; static_cast(resendPacket.payload) ->common.opcode = Protocol::Packet::RESEND; packets[3] = &resendPacket; - EXPECT_CALL(*mockSender, - handleResendPacket(Eq(&resendPacket), Eq(&mockDriver))); + EXPECT_CALL(*mockSender, handleResendPacket(Eq(&resendPacket))); // Set BUSY packet - Homa::Mock::MockDriver::MockPacket busyPacket(payload[4], 1024); + Homa::Mock::MockDriver::MockPacket busyPacket{payload[4], 1024}; static_cast(busyPacket.payload) ->common.opcode = Protocol::Packet::BUSY; packets[4] = &busyPacket; - EXPECT_CALL(*mockReceiver, - handleBusyPacket(Eq(&busyPacket), Eq(&mockDriver))); + EXPECT_CALL(*mockReceiver, handleBusyPacket(Eq(&busyPacket))); // Set PING packet - Homa::Mock::MockDriver::MockPacket pingPacket(payload[5], 1024); + Homa::Mock::MockDriver::MockPacket pingPacket{payload[5], 1024}; static_cast(pingPacket.payload) ->common.opcode = Protocol::Packet::PING; packets[5] = &pingPacket; - EXPECT_CALL(*mockReceiver, - handlePingPacket(Eq(&pingPacket), Eq(&mockDriver))); + EXPECT_CALL(*mockReceiver, handlePingPacket(Eq(&pingPacket), _)); // Set UNKNOWN packet - Homa::Mock::MockDriver::MockPacket unknownPacket(payload[6], 1024); + Homa::Mock::MockDriver::MockPacket unknownPacket{payload[6], 1024}; static_cast(unknownPacket.payload) ->common.opcode = Protocol::Packet::UNKNOWN; packets[6] = &unknownPacket; - EXPECT_CALL(*mockSender, - handleUnknownPacket(Eq(&unknownPacket), Eq(&mockDriver))); + EXPECT_CALL(*mockSender, handleUnknownPacket(Eq(&unknownPacket))); // Set ERROR packet - Homa::Mock::MockDriver::MockPacket errorPacket(payload[7], 1024); + Homa::Mock::MockDriver::MockPacket errorPacket{payload[7], 1024}; static_cast(errorPacket.payload) ->common.opcode = Protocol::Packet::ERROR; packets[7] = &errorPacket; - EXPECT_CALL(*mockSender, - handleErrorPacket(Eq(&errorPacket), Eq(&mockDriver))); + EXPECT_CALL(*mockSender, handleErrorPacket(Eq(&errorPacket))); EXPECT_CALL(mockDriver, receivePackets) .WillOnce(DoAll(SetArrayArgument<1>(packets, packets + 8), Return(8))); diff --git a/src/UtilTest.cc b/src/UtilTest.cc index 3f8b7ff..b491220 100644 --- a/src/UtilTest.cc +++ b/src/UtilTest.cc @@ -14,7 +14,6 @@ */ #include - #include namespace Homa { @@ -33,4 +32,10 @@ TEST(UtilTest, downCast) EXPECT_EQ(64, c); } +TEST(UtilTest, isPowerOfTwo) +{ + EXPECT_TRUE(Util::isPowerOfTwo(4)); + EXPECT_FALSE(Util::isPowerOfTwo(3)); +} + } // namespace Homa \ No newline at end of file diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 403a340..e01e945 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -34,6 +34,18 @@ target_link_libraries(system_test docopt ) +## dpdk_test ################################################################# + +add_executable(dpdk_test + dpdk_test.cc +) +target_link_libraries(dpdk_test + PRIVATE + Homa::DpdkDriver + docopt + PerfUtils +) + ## Perf ######################################################################## add_executable(Perf diff --git a/test/Output.h b/test/Output.h new file mode 100644 index 0000000..de8d740 --- /dev/null +++ b/test/Output.h @@ -0,0 +1,120 @@ +/* Copyright (c) 2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#pragma once + +#include +#include +#include +#include + +namespace Output { + +using Latency = std::chrono::duration; + +struct TimeDist { + Latency min; // Fastest time seen (seconds). + Latency p50; // Median time per operation (seconds). + Latency p90; // 90th percentile time/op (seconds). + Latency p99; // 99th percentile time/op (seconds). + Latency p999; // 99.9th percentile time/op (seconds). +}; + +std::string +format(const std::string& format, ...) +{ + va_list args; + va_start(args, format); + size_t len = std::vsnprintf(NULL, 0, format.c_str(), args); + va_end(args); + std::vector vec(len + 1); + va_start(args, format); + std::vsnprintf(&vec[0], len + 1, format.c_str(), args); + va_end(args); + return &vec[0]; +} + +std::string +formatTime(Latency seconds) +{ + if (seconds < std::chrono::duration(1)) { + return format( + "%5.1f ns", + std::chrono::duration(seconds).count()); + } else if (seconds < std::chrono::duration(1)) { + return format( + "%5.1f us", + std::chrono::duration(seconds).count()); + } else if (seconds < std::chrono::duration(1)) { + return format( + "%5.2f ms", + std::chrono::duration(seconds).count()); + } else { + return format("%5.2f s ", seconds.count()); + } +} + +std::string +basicHeader() +{ + return "median min p90 p99 p999 description"; +} + +std::string +basic(std::vector& times, const std::string description) +{ + int count = times.size(); + std::sort(times.begin(), times.end()); + + TimeDist dist; + + dist.min = times[0]; + int index = count / 2; + if (index < count) { + dist.p50 = times.at(index); + } else { + dist.p50 = dist.min; + } + index = count - (count + 5) / 10; + if (index < count) { + dist.p90 = times.at(index); + } else { + dist.p90 = dist.p50; + } + index = count - (count + 50) / 100; + if (index < count) { + dist.p99 = times.at(index); + } else { + dist.p99 = dist.p90; + } + index = count - (count + 500) / 1000; + if (index < count) { + dist.p999 = times.at(index); + } else { + dist.p999 = dist.p99; + } + + std::string output = ""; + output += format("%9s", formatTime(dist.p50).c_str()); + output += format(" %9s", formatTime(dist.min).c_str()); + output += format(" %9s", formatTime(dist.p90).c_str()); + output += format(" %9s", formatTime(dist.p99).c_str()); + output += format(" %9s", formatTime(dist.p999).c_str()); + output += " "; + output += description; + return output; +} + +} // namespace Output diff --git a/test/Perf.cc b/test/Perf.cc index 2b6f04c..2e34399 100644 --- a/test/Perf.cc +++ b/test/Perf.cc @@ -13,22 +13,26 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ +#include #include +#include #include #include #include #include +#include #include #include +#include #include #include #include #include "Cycles.h" -#include "docopt.h" - #include "Homa/Drivers/Util/QueueEstimator.h" #include "Intrusive.h" +#include "ObjectPool.h" +#include "docopt.h" static const char USAGE[] = R"(Performance Nano-Benchmark @@ -62,6 +66,217 @@ struct TestInfo { // should be less than 72 characters long). }; +TestInfo atomicLoadTestInfo = { + "atomicLoad", "Read an std::atomic", + R"(Measure the cost of reading an std::atomic value.)"}; +double +atomicLoadTest() +{ + int count = 1000000; + uint64_t temp; + std::atomic val[count]; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + temp = val[i].load(); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / count; +} + +TestInfo atomicStoreTestInfo = { + "atomicStore", "Write an std::atomic", + R"(Measure the cost of writing an std::atomic value.)"}; +double +atomicStoreTest() +{ + int count = 1000000; + uint64_t temp = std::rand(); + std::atomic val[count]; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + val[i].store(temp); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / count; +} + +TestInfo atomicStoreRelaxedTestInfo = { + "atomicStoreRelaxed", "Write an std::atomic (std::memory_order_relaxed)", + R"(Measure the cost of a relaxed atomic write.)"}; +double +atomicStoreRelaxedTest() +{ + int count = 1000000; + uint64_t temp = std::rand(); + std::atomic val[count]; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + val[i].store(temp, std::memory_order_relaxed); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / count; +} + +TestInfo atomicIncTestInfo = { + "atomicInc", "Increment an std::atomic", + R"(Measure the cost of incrementing an std::atomic value.)"}; +double +atomicIncTest() +{ + int count = 1000000; + uint64_t temp = std::rand() % 100; + std::atomic val[count]; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + val[i].fetch_add(temp); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / count; +} + +TestInfo atomicIncRelaxedTestInfo = { + "atomicIncRelaxed", "Increment an std::atomic (std::memory_order_relaxed)", + R"(Measure the cost of a relaxed atomic incrementing.)"}; +double +atomicIncRelaxedTest() +{ + int count = 1000000; + uint64_t temp = std::rand() % 100; + std::atomic val[count]; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + val[i].fetch_add(temp, std::memory_order_relaxed); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / count; +} + +TestInfo atomicIncUnsafeTestInfo = { + "atomicIncUnsafe", "Increment an std::atomic using read-modify-write", + R"(Measure the cost of a thread unsafe increment of an std::atomic.)"}; +double +atomicIncUnsafeTest() +{ + int count = 1000000; + uint64_t temp = std::rand() % 100; + std::atomic val[count]; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + val[i].store(val[i].load(std::memory_order_relaxed) + temp, + std::memory_order_relaxed); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / count; +} + +TestInfo intReadWriteTestInfo = { + "intReadWrite", "Read and write a uint64_t", + R"(Measure the cost the baseline read/write.)"}; +double +intReadWriteTest() +{ + int count = 1000000; + uint64_t temp = std::rand(); + uint64_t val[count]; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + val[i] = temp; + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / count; +} + +TestInfo branchTestInfo = { + "branch", "If-else statement", + R"(The cost of choosing a branch in an if-else statement)"}; +double +branchTest() +{ + int count = 1000000; + uint64_t temp = std::rand(); + uint64_t a[0xFF + 1]; + uint64_t b[0xFF + 1]; + for (int i = 0; i < 0xFF + 1; i++) { + a[i] = std::rand(); + b[i] = std::rand(); + } + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + int index = i & 0xFF; + if (a[index] < b[index]) { + b[index] = a[index]; + } else { + a[index] = b[index]; + } + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / count; +} + +TestInfo defaultAllocatorTestInfo = { + "defaultAllocator", "Test new and delete of a simple structure", + R"(Measure the cost of allocation and deallocation using new and delete.)"}; +double +defaultAllocatorTest() +{ + struct Foo { + Foo() + : i() + , buf() + {} + + uint64_t i; + char buf[100]; + }; + Foo* foo[0xFFFF + 1]; + for (int i = 0; i < 0xFFFF + 1; ++i) { + foo[i] = new Foo; + } + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + delete foo[i & 0xFFFF]; + foo[i & 0xFFFF] = new Foo; + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < 0xFFFF + 1; ++i) { + delete foo[i]; + } + return PerfUtils::Cycles::toSeconds(stop - start) / (count); +} + +TestInfo objectPoolTestInfo = { + "objectPool", "Test ObjectPool allocation of a simple structure", + R"(Measure the cost of allocation and deallocation using an ObjectPool.)"}; +double +objectPoolTest() +{ + struct Foo { + Foo() + : i() + , buf() + {} + uint64_t i; + char buf[100]; + }; + Homa::ObjectPool pool; + Foo* foo[0xFFFF + 1]; + for (int i = 0; i < 0xFFFF + 1; ++i) { + foo[i] = pool.construct(); + } + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + pool.destroy(foo[i & 0xFFFF]); + foo[i & 0xFFFF] = pool.construct(); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < 0xFFFF + 1; ++i) { + pool.destroy(foo[i]); + } + return PerfUtils::Cycles::toSeconds(stop - start) / (count); +} + TestInfo listSearchTestInfo = { "listSearch", "Linear search an intrusive list", R"(Measure the cost (per entry) of searching through an intrusive list)"}; @@ -194,6 +409,248 @@ mapNullInsertTest() return PerfUtils::Cycles::toSeconds(stop - start) / run; } +TestInfo dequeConstructInfo = { + "dequeConstruct", "Construct an std::deque", + R"(Measure the cost of constructing (and destructing) an std::deque.)"}; +double +dequeConstruct() +{ + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + std::deque deque; + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / (count); +} + +TestInfo dequePushPopTestInfo = { + "dequePushPop", "std::deque operations", + R"(Measure the cost of pushing/popping an element to/from an std::deque.)"}; +double +dequePushPopTest() +{ + std::deque deque; + for (int i = 0; i < 10000; ++i) { + deque.push_back(std::rand()); + } + + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + uint64_t temp = deque.front(); + deque.pop_front(); + deque.push_back(temp); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / (2 * count); +} + +TestInfo vectorConstructInfo = { + "vectorConstruct", "Construct an std::vector", + R"(Measure the cost of constructing (and destructing) an std::vector.)"}; +double +vectorConstruct() +{ + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + std::vector vector; + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / (count); +} + +TestInfo vectorReserveTestInfo = { + "vectorReserve", "Reserve capacity in an std::vector", + R"(Measure the cost of reserving capacity an std::vector.)"}; +double +vectorReserveTest() +{ + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + std::vector vector; + vector.reserve(32); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / (count); +} + +TestInfo vectorPushTestInfo = { + "vectorPush", "std::vector push", + R"(Measure the cost of pushing a new element to an std::vector.)"}; +double +vectorPushTest() +{ + std::vector vector; + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + vector.push_back(i); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / (count); +} + +TestInfo vectorPushPopTestInfo = { + "vectorPushPop", "std::vector operations", + R"(Measure the cost of pushing/popping an element to/from an std::vector.)"}; +double +vectorPushPopTest() +{ + std::vector vector; + for (int i = 0; i < 10000; ++i) { + vector.push_back(std::rand()); + } + + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + uint64_t temp = vector.back(); + vector.pop_back(); + vector.push_back(temp); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / (2 * count); +} + +TestInfo listConstructInfo = { + "listConstruct", "Construct an std::list", + R"(Measure the cost of constructing (and destructing) an std::list.)"}; +double +listConstruct() +{ + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + std::list list; + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / (count); +} + +TestInfo listPushTestInfo = { + "listPush", "std::list push", + R"(Measure the cost of pushing a new element to an std::list.)"}; +double +listPushTest() +{ + std::list list; + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + list.push_back(i); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / (count); +} + +TestInfo listPushPopTestInfo = { + "listPushPop", "std::list operations", + R"(Measure the cost of pushing/popping an element to/from an std::list.)"}; +double +listPushPopTest() +{ + std::list list; + for (int i = 0; i < 10000; ++i) { + list.push_back(std::rand()); + } + + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + uint64_t temp = list.front(); + list.pop_front(); + list.push_back(temp); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / (2 * count); +} + +TestInfo ilistConstructInfo = { + "ilistConstruct", "Construct an Intrusive::list", + R"(Measure the cost of constructing (and destructing) an Intrusive::list.)"}; +double +ilistConstruct() +{ + struct Foo { + Foo() + : node(this) + {} + + Homa::Core::Intrusive::List::Node node; + }; + + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + Homa::Core::Intrusive::List list; + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / (count); +} + +TestInfo ilistPushPopTestInfo = { + "ilistPushPop", "Intrusive::list operations", + R"(Measure the cost of pushing/popping an element to/from an Intrusive::list.)"}; +double +ilistPushPopTest() +{ + struct Foo { + Foo() + : node(this) + {} + + Homa::Core::Intrusive::List::Node node; + }; + + Homa::Core::Intrusive::List list; + for (int i = 0; i < 10000; ++i) { + Foo* foo = new Foo; + list.push_back(&foo->node); + } + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + Foo* foo = &list.front(); + list.pop_front(); + list.push_front(&foo->node); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + while (!list.empty()) { + Foo* foo = &list.front(); + list.pop_front(); + delete foo; + } + return PerfUtils::Cycles::toSeconds(stop - start) / (2 * count); +} + +TestInfo heapTestInfo = { + "heap", "std heap operations", + R"(Measure the cost of pushing/popping an element to/from an std heap.)"}; +double +heapTest() +{ + std::vector heap; + for (uint64_t i = 0; i < 10000; ++i) { + heap.push_back(i); + } + std::make_heap(heap.begin(), heap.end(), std::greater<>{}); + + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + uint64_t temp = heap.front(); + std::pop_heap(heap.begin(), heap.end(), std::greater<>{}); + heap.pop_back(); + heap.push_back(temp + 5000); + std::push_heap(heap.begin(), heap.end(), std::greater<>{}); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / (2 * count); +} + TestInfo queueEstimatorTestInfo = { "queueEstimator", "Update a QueueEstimator", R"(Measure the cost of updating a Homa::Drivers::Util::QueueEstimator.)"}; @@ -244,6 +701,21 @@ rdhrcTest() return PerfUtils::Cycles::toSeconds(stop - start) / count; } +TestInfo rdcscTestInfo = { + "rdcsc", "Read std::chrono::steady_clock", + R"(Measure the cost of reading the std::chrono::steady_clock.)"}; +double +rdcscTest() +{ + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + auto timestamp = std::chrono::steady_clock::now(); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / count; +} + // The following struct and table define each performance test in terms of // function that implements the test and collection of string information about // the test like the test's string name. @@ -255,13 +727,36 @@ struct TestCase { // including the test's string name. }; TestCase tests[] = { + {atomicLoadTest, &atomicLoadTestInfo}, + {atomicStoreTest, &atomicStoreTestInfo}, + {atomicStoreRelaxedTest, &atomicStoreRelaxedTestInfo}, + {atomicIncTest, &atomicIncTestInfo}, + {atomicIncRelaxedTest, &atomicIncRelaxedTestInfo}, + {atomicIncUnsafeTest, &atomicIncUnsafeTestInfo}, + {branchTest, &branchTestInfo}, + {intReadWriteTest, &intReadWriteTestInfo}, + {defaultAllocatorTest, &defaultAllocatorTestInfo}, + {objectPoolTest, &objectPoolTestInfo}, {listSearchTest, &listSearchTestInfo}, {mapFindTest, &mapFindTestInfo}, {mapLookupTest, &mapLookupTestInfo}, {mapNullInsertTest, &mapNullInsertTestInfo}, + {dequeConstruct, &dequeConstructInfo}, + {dequePushPopTest, &dequePushPopTestInfo}, + {vectorConstruct, &vectorConstructInfo}, + {vectorReserveTest, &vectorReserveTestInfo}, + {vectorPushTest, &vectorPushTestInfo}, + {vectorPushPopTest, &vectorPushPopTestInfo}, + {listConstruct, &listConstructInfo}, + {listPushTest, &listPushTestInfo}, + {listPushPopTest, &listPushPopTestInfo}, + {ilistConstruct, &ilistConstructInfo}, + {ilistPushPopTest, &ilistPushPopTestInfo}, + {heapTest, &heapTestInfo}, {queueEstimatorTest, &queueEstimatorTestInfo}, {rdtscTest, &rdtscTestInfo}, {rdhrcTest, &rdhrcTestInfo}, + {rdcscTest, &rdcscTestInfo}, }; /** diff --git a/test/dpdk_test.cc b/test/dpdk_test.cc new file mode 100644 index 0000000..4ca1a82 --- /dev/null +++ b/test/dpdk_test.cc @@ -0,0 +1,105 @@ +/* Copyright (c) 2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#include +#include + +#include +#include +#include +#include + +#include "Output.h" + +static const char USAGE[] = R"(DPDK Driver Test. + + Usage: + dpdk_test [options] (--server | ) + + Options: + -h --help Show this screen. + --version Show version. + --timetrace Enable TimeTrace output [default: false]. +)"; + +int +main(int argc, char* argv[]) +{ + std::map args = + docopt::docopt(USAGE, {argv + 1, argv + argc}, + true, // show help if requested + "DPDK Driver Test"); // version string + + std::string iface = args[""].asString(); + bool isServer = args["--server"].asBool(); + std::string server_ip_string; + if (!isServer) { + server_ip_string = args[""].asString(); + } + + Homa::Drivers::DPDK::DpdkDriver driver(iface.c_str()); + + if (isServer) { + std::cout << Homa::IpAddress::toString(driver.getLocalAddress()) + << std::endl; + while (true) { + Homa::Driver::Packet* incoming[10]; + Homa::IpAddress srcAddrs[10]; + uint32_t receivedPackets; + do { + receivedPackets = driver.receivePackets(10, incoming, srcAddrs); + } while (receivedPackets == 0); + Homa::Driver::Packet* pong = driver.allocPacket(); + pong->length = 100; + driver.sendPacket(pong, srcAddrs[0], 0); + driver.releasePackets(incoming, receivedPackets); + driver.releasePackets(&pong, 1); + } + } else { + Homa::IpAddress server_ip = + Homa::IpAddress::fromString(server_ip_string.c_str()); + std::vector times; + for (int i = 0; i < 100000; ++i) { + uint64_t start = PerfUtils::Cycles::rdtsc(); + PerfUtils::TimeTrace::record(start, "START"); + Homa::Driver::Packet* ping = driver.allocPacket(); + PerfUtils::TimeTrace::record("allocPacket"); + ping->length = 100; + PerfUtils::TimeTrace::record("set ping args"); + driver.sendPacket(ping, server_ip, 0); + PerfUtils::TimeTrace::record("sendPacket"); + driver.releasePackets(&ping, 1); + PerfUtils::TimeTrace::record("releasePacket"); + Homa::Driver::Packet* incoming[10]; + Homa::IpAddress srcAddrs[10]; + uint32_t receivedPackets; + do { + receivedPackets = driver.receivePackets(10, incoming, srcAddrs); + PerfUtils::TimeTrace::record("receivePackets"); + } while (receivedPackets == 0); + driver.releasePackets(incoming, receivedPackets); + PerfUtils::TimeTrace::record("releasePacket"); + uint64_t stop = PerfUtils::Cycles::rdtsc(); + times.emplace_back(PerfUtils::Cycles::toSeconds(stop - start)); + } + if (args["--timetrace"].asBool()) { + PerfUtils::TimeTrace::print(); + } + std::cout << Output::basicHeader() << std::endl; + std::cout << Output::basic(times, "DpdkDriver Ping-Pong") << std::endl; + } + + return 0; +} \ No newline at end of file diff --git a/test/system_test.cc b/test/system_test.cc index 8e43238..88b3814 100644 --- a/test/system_test.cc +++ b/test/system_test.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -70,7 +70,7 @@ struct Node { }; void -serverMain(Node* server, std::vector addresses) +serverMain(Node* server, std::vector addresses) { while (true) { if (server->run.load() == false) { @@ -101,7 +101,7 @@ serverMain(Node* server, std::vector addresses) * Number of Op that failed. */ int -clientMain(int count, int size, std::vector addresses) +clientMain(int count, int size, std::vector addresses) { std::random_device rd; std::mt19937 gen(rd()); @@ -119,9 +119,9 @@ clientMain(int count, int size, std::vector addresses) payload[i] = randData(gen); } - std::string destAddress = addresses[randAddr(gen)]; + Homa::IpAddress destAddress = addresses[randAddr(gen)]; - Homa::unique_ptr message = client.transport->alloc(); + Homa::unique_ptr message = client.transport->alloc(0); { MessageHeader header; header.id = id; @@ -133,7 +133,7 @@ clientMain(int count, int size, std::vector addresses) << std::endl; } } - message->send(client.driver.getAddress(&destAddress)); + message->send(Homa::SocketAddress{destAddress, 60001}); while (1) { Homa::OutMessage::Status status = message->getStatus(); @@ -185,12 +185,11 @@ main(int argc, char* argv[]) Homa::Drivers::Fake::FakeNetworkConfig::setPacketLossRate(packetLossRate); uint64_t nextServerId = 101; - std::vector addresses; + std::vector addresses; std::vector servers; for (int i = 0; i < numServers; ++i) { Node* server = new Node(nextServerId++); - addresses.emplace_back(std::string( - server->driver.addressToString(server->driver.getLocalAddress()))); + addresses.emplace_back(server->driver.getLocalAddress()); servers.push_back(server); }