From 234e2abb4e38f3d6be5b9cbcc86ac6bf27f7490f Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Tue, 6 Oct 2020 01:56:26 -0700 Subject: [PATCH 01/33] Fixed compilation erros in src/Perf.h --- src/Perf.cc | 2 -- src/Perf.h | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/Perf.cc b/src/Perf.cc index 155802c..faf154a 100644 --- a/src/Perf.cc +++ b/src/Perf.cc @@ -15,8 +15,6 @@ #include "Perf.h" -#include - #include #include diff --git a/src/Perf.h b/src/Perf.h index bf9668d..2349b01 100644 --- a/src/Perf.h +++ b/src/Perf.h @@ -17,7 +17,7 @@ #define HOMA_PERF_H #include -#include +#include #include From 9d4e95e090cd06f6c6ce026c961ba7696bf00e02 Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Fri, 17 Jul 2020 15:09:34 -0700 Subject: [PATCH 02/33] Changes to the Driver interface to get Homa ready for Shenango integration Details: - Disable DpdkDriver in CMakeLists.txt temporarily - Remove unused method Driver::Packet::getMaxPayloadSize() - Remove param `driver` in handleXXXPacket - Change Driver::Packet to become a POD struct - Change Driver::Packet::{address,priority} into params in Driver::sendPacket - Remove the opaque Driver::Address - Use IP packets as the common interface between transport and driver - Extend Homa packet headers to include L4 src/dst port numbers - Use SocketAddress (i.e., ip + port) as opposed to Driver::Address to identify the src/dst address of a message --- CMakeLists.txt | 71 +++--- include/Homa/Driver.h | 177 ++++++--------- include/Homa/Drivers/Fake/FakeDriver.h | 52 ++--- include/Homa/Homa.h | 19 +- include/Homa/Util.h | 6 + src/ControlPacket.h | 8 +- src/Drivers/DPDK/DpdkDriverImpl.h | 6 - src/Drivers/Fake/FakeAddressTest.cc | 75 ------- src/Drivers/Fake/FakeDriver.cc | 84 ++------ src/Drivers/Fake/FakeDriverTest.cc | 61 ++---- src/Mock/MockDriver.h | 42 ++-- src/Mock/MockPolicy.h | 4 +- src/Mock/MockReceiver.h | 17 +- src/Mock/MockSender.h | 22 +- src/Policy.cc | 8 +- src/Policy.h | 6 +- src/PolicyTest.cc | 2 +- src/Protocol.h | 16 +- src/Receiver.cc | 47 ++-- src/Receiver.h | 13 +- src/ReceiverTest.cc | 162 +++++++------- src/Sender.cc | 63 ++---- src/Sender.h | 30 +-- src/SenderTest.cc | 286 +++++++++++++------------ src/TransportImpl.cc | 99 +++++---- src/TransportImpl.h | 6 +- src/TransportImplTest.cc | 41 ++-- test/system_test.cc | 15 +- 28 files changed, 607 insertions(+), 831 deletions(-) delete mode 100644 src/Drivers/Fake/FakeAddressTest.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 4a6f9c1..2f82962 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,7 +13,7 @@ list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/modules) find_package(Doxygen OPTIONAL_COMPONENTS dot mscgen dia) # Network Interface library (https://www.dpdk.org/) -find_package(Dpdk REQUIRED) +# find_package(Dpdk REQUIRED) # Source control tool; needed to download external libraries. find_package(Git REQUIRED) @@ -135,34 +135,34 @@ target_compile_options(FakeDriver ) ## lib DpdkDriver ############################################################## -add_library(DpdkDriver - src/Drivers/DPDK/DpdkDriver.cc - src/Drivers/DPDK/DpdkDriverImpl.cc - src/Drivers/DPDK/MacAddress.cc -) -add_library(Homa::DpdkDriver ALIAS DpdkDriver) -target_include_directories(DpdkDriver - PUBLIC - $ - $ - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/src -) -target_link_libraries(DpdkDriver - PRIVATE - Dpdk::Dpdk - PUBLIC - Homa -) -target_compile_features(DpdkDriver - PUBLIC - cxx_std_11 -) -target_compile_options(DpdkDriver - PRIVATE - -Wall - -Wextra -) +#add_library(DpdkDriver +# src/Drivers/DPDK/DpdkDriver.cc +# src/Drivers/DPDK/DpdkDriverImpl.cc +# src/Drivers/DPDK/MacAddress.cc +#) +#add_library(Homa::DpdkDriver ALIAS DpdkDriver) +#target_include_directories(DpdkDriver +# PUBLIC +# $ +# $ +# PRIVATE +# ${CMAKE_CURRENT_SOURCE_DIR}/src +#) +#target_link_libraries(DpdkDriver +# PRIVATE +# Dpdk::Dpdk +# PUBLIC +# Homa +#) +#target_compile_features(DpdkDriver +# PUBLIC +# cxx_std_11 +#) +#target_compile_options(DpdkDriver +# PRIVATE +# -Wall +# -Wextra +#) ################################################################################ ## Tests ####################################################################### @@ -195,7 +195,8 @@ endif() ## Install & Export ############################################################ ################################################################################ -install(TARGETS Homa DpdkDriver FakeDriver EXPORT HomaTargets +#install(TARGETS Homa DpdkDriver FakeDriver EXPORT HomaTargets +install(TARGETS Homa FakeDriver EXPORT HomaTargets LIBRARY DESTINATION lib ARCHIVE DESTINATION lib RUNTIME DESTINATION bin @@ -274,11 +275,11 @@ target_sources(unit_test target_link_libraries(unit_test FakeDriver) #DPDK Tests -target_sources(unit_test - PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR}/src/Drivers/DPDK/MacAddressTest.cc -) -target_link_libraries(unit_test DpdkDriver) +#target_sources(unit_test +# PUBLIC +# ${CMAKE_CURRENT_SOURCE_DIR}/src/Drivers/DPDK/MacAddressTest.cc +#) +#target_link_libraries(unit_test DpdkDriver) target_link_libraries(unit_test gmock_main) # -fno-access-control allows access to private members for testing diff --git a/include/Homa/Driver.h b/include/Homa/Driver.h index ecfe666..d510046 100644 --- a/include/Homa/Driver.h +++ b/include/Homa/Driver.h @@ -22,6 +22,28 @@ namespace Homa { +/// IPv4 address in host byte order. +using IpAddress = uint32_t; + +/** + * Represents a packet of data that can be send or is received over the network. + * A Packet logically contains only the transport-layer (L4) Homa header in + * addition to application data. + * + * This struct specifies the minimal object layout of a packet that the core + * Homa protocol depends on (e.g., Homa::Core::{Sender, Receiver}); this is + * useful for applications that only want to use the transport layer of this + * library and have their own infrastructures for sending and receiving packets. + */ +struct PacketSpec { + /// Pointer to an array of bytes containing the payload of this Packet. + /// This array is valid until the Packet is released back to the Driver. + void* payload; + + /// Number of bytes in the payload. + int32_t length; +} __attribute__((packed)); + /** * Used by Homa::Transport to send and receive unreliable datagrams. Provides * the interface to which all Driver implementations must conform. @@ -31,133 +53,46 @@ namespace Homa { class Driver { public: /** - * Represents a Network address. + * Represents a packet that can be send or is received over the network. * - * Each Address representation is specific to the Driver instance that - * returned the it; they cannot be use interchangeably between different - * Driver instances. - */ - using Address = uint64_t; - - /** - * Used to hold a driver's serialized byte-format for a network address. + * The layout of this struct has two parts: the first part is essentially + * a copy of PacketSpec, while the second part contains members specific + * to our driver implementation. * - * Each driver may define its own byte-format so long as fits within the - * bytes array. + * @sa Homa::PacketSpec */ - struct WireFormatAddress { - uint8_t type; ///< Can be used to distinguish between different wire - ///< address formats. - uint8_t bytes[19]; ///< Holds an Address's serialized byte-format. - } __attribute__((packed)); + struct Packet final { + // === PacketSpec definitions === + // The order and types of the following members must match those in + // PacketSpec precisely. - /** - * Represents a packet of data that can be send or is received over the - * network. A Packet logically contains only the payload and not any Driver - * specific headers. - * - * A Packet may be Driver specific and should not used interchangeably - * between Driver instances or implementations. - * - * This class is NOT thread-safe but the Transport and Driver's use of - * Packet objects should be allow the Transport and the Driver to execute on - * different threads. - */ - class Packet { - public: - /// Packet's source or destination. When sending a Packet, the address - /// field will contain the destination Address. When receiving a Packet, - /// address field will contain the source Address. - Address address; - - /// Packet's network priority (send only); the lowest possible priority - /// is 0. The highest priority is positive number defined by the - /// Driver; the highest priority can be queried by calling the method - /// getHighestPacketPriority(). - int priority; - - /// Pointer to an array of bytes containing the payload of this Packet. - /// This array is valid until the Packet is released back to the Driver. - void* const payload; + /// See Homa::PacketSpec::payload. + void* payload; - /// Number of bytes in the payload. - int length; + /// See Homa::PacketSpec::length + int32_t length; - /// Return the maximum number of bytes the payload can hold. - virtual int getMaxPayloadSize() = 0; + // === Extended definitions === + // The following members are specific to the driver framework bundled + // in this library. Therefore, these members must *NOT* appear in the + // core components of Homa transport; they are only used in a few + // places to facilitate the glue code between transport and driver. - protected: - /** - * Construct a Packet. - */ - explicit Packet(void* payload, int length = 0) - : address() - , priority(0) - , payload(payload) - , length(length) - {} + /// Packet's source IpAddress. Only meaningful when this packet is an + /// incoming packet. + IpAddress sourceIp; + } __attribute__((packed)); - // DISALLOW_COPY_AND_ASSIGN - Packet(const Packet&) = delete; - Packet& operator=(const Packet&) = delete; - }; + // Static checks to enforce the object layout compatibility between + // Driver::Packet and PacketSpec. + static_assert(offsetof(Packet, payload) == offsetof(PacketSpec, payload)); + static_assert(offsetof(Packet, length) == offsetof(PacketSpec, length)); /** * Driver destructor. */ virtual ~Driver() = default; - /** - * Return a Driver specific network address for the given string - * representation of the address. - * - * @param addressString - * The string representation of the address to return. The address - * string format can be Driver specific. - * - * @return - * Address that can be the source or destination of a Packet. - * - * @throw BadAddress - * _addressString_ is malformed. - */ - virtual Address getAddress(std::string const* const addressString) = 0; - - /** - * Return a Driver specific network address for the given serialized - * byte-format of the address. - * - * @param wireAddress - * The serialized byte-format of the address to be returned. The - * format can be Driver specific. - * - * @return - * Address that can be the source or destination of a Packet. - * - * @throw BadAddress - * _rawAddress_ is malformed. - */ - virtual Address getAddress(WireFormatAddress const* const wireAddress) = 0; - - /** - * Return the string representation of a network address. - * - * @param address - * Address whose string representation should be returned. - */ - virtual std::string addressToString(const Address address) = 0; - - /** - * Serialize a network address into its Raw byte format. - * - * @param address - * Address to be serialized. - * @param[out] wireAddress - * WireFormatAddress object to which the Address is serialized. - */ - virtual void addressToWireFormat(const Address address, - WireFormatAddress* wireAddress) = 0; - /** * Allocate a new Packet object from the Driver's pool of resources. The * caller must eventually release the packet by passing it to a call to @@ -187,8 +122,16 @@ class Driver { * * @param packet * Packet to be sent over the network. + * @param destination + * IP address of the packet destination. + * @param priority + * Packet's network priority; the lowest possible priority is 0. + * The highest priority is positive number defined by the Driver; + * the highest priority can be queried by calling the method + * getHighestPacketPriority(). */ - virtual void sendPacket(Packet* packet) = 0; + virtual void sendPacket(Packet* packet, IpAddress destination, + int priority) = 0; /** * Request that the Driver enter the "corked" mode where outbound packets @@ -273,10 +216,10 @@ class Driver { virtual uint32_t getBandwidth() = 0; /** - * Return this Driver's local network Address which it uses as the source - * Address for outgoing packets. + * Return this Driver's local IP address which it uses as the source + * address for outgoing packets. */ - virtual Address getLocalAddress() = 0; + virtual IpAddress getLocalAddress() = 0; /** * Return the number of bytes that have been passed to the Driver through diff --git a/include/Homa/Drivers/Fake/FakeDriver.h b/include/Homa/Drivers/Fake/FakeDriver.h index 8413778..04ce8c0 100644 --- a/include/Homa/Drivers/Fake/FakeDriver.h +++ b/include/Homa/Drivers/Fake/FakeDriver.h @@ -34,7 +34,7 @@ const int NUM_PRIORITIES = 8; /// Maximum number of bytes a packet can hold. const uint32_t MAX_PAYLOAD_SIZE = 1500; -/// A set of methods to contol the underlying FakeNetwork's behavior. +/// A set of methods to control the underlying FakeNetwork's behavior. namespace FakeNetworkConfig { /** * Configure the FakeNetwork to drop packets at the specified loss rate. @@ -51,43 +51,34 @@ void setPacketLossRate(double lossRate); * * @sa Driver::Packet */ -class FakePacket : public Driver::Packet { - public: +struct FakePacket { + /// C-style "inheritance"; used to maintain the base struct as a POD type. + Driver::Packet base; + + /// Raw storage for this packets payload. + char buf[MAX_PAYLOAD_SIZE]; + /** * FakePacket constructor. - * - * @param maxPayloadSize - * The maximum number of bytes this packet can hold. */ explicit FakePacket() - : Packet(buf, 0) + : base{.payload = buf, + .length = 0, + .sourceIp = 0} + , buf() {} /** * Copy constructor. */ FakePacket(const FakePacket& other) - : Packet(buf, other.length) - { - address = other.address; - priority = other.priority; - memcpy(buf, other.buf, MAX_PAYLOAD_SIZE); - } - - virtual ~FakePacket() {} - - /// see Driver::Packet::getMaxPayloadSize() - virtual int getMaxPayloadSize() + : base{.payload = buf, + .length = other.base.length, + .sourceIp = 0} + , buf() { - return MAX_PAYLOAD_SIZE; + memcpy(base.payload, other.base.payload, MAX_PAYLOAD_SIZE); } - - private: - /// Raw storage for this packets payload. - char buf[MAX_PAYLOAD_SIZE]; - - // Disable Assignment - FakePacket& operator=(const FakePacket&) = delete; }; /// Holds the incoming packets for a particular driver. @@ -117,20 +108,15 @@ class FakeDriver : public Driver { */ virtual ~FakeDriver(); - virtual Address getAddress(std::string const* const addressString); - virtual Address getAddress(WireFormatAddress const* const wireAddress); - virtual std::string addressToString(const Address address); - virtual void addressToWireFormat(const Address address, - WireFormatAddress* wireAddress); virtual Packet* allocPacket(); - virtual void sendPacket(Packet* packet); + virtual void sendPacket(Packet* packet, IpAddress destination, int priority); virtual uint32_t receivePackets(uint32_t maxPackets, Packet* receivedPackets[]); virtual void releasePackets(Packet* packets[], uint16_t numPackets); virtual int getHighestPacketPriority(); virtual uint32_t getMaxPayloadSize(); virtual uint32_t getBandwidth(); - virtual Address getLocalAddress(); + virtual IpAddress getLocalAddress(); virtual uint32_t getQueuedBytes(); private: diff --git a/include/Homa/Homa.h b/include/Homa/Homa.h index dec090c..aba9073 100644 --- a/include/Homa/Homa.h +++ b/include/Homa/Homa.h @@ -37,6 +37,17 @@ namespace Homa { template using unique_ptr = std::unique_ptr; +/** + * Represents a socket address to (from) which we can send (receive) messages. + */ +struct SocketAddress { + /// IPv4 address in host byte order. + IpAddress ip; + + /// Port number in host byte order. + uint16_t port; +}; + /** * Represents an array of bytes that has been received over the network. * @@ -220,11 +231,11 @@ class OutMessage { * Send this message to the destination. * * @param destination - * Address of the transport to which this message will be sent. + * Network address to which this message will be sent. * @param options * Flags to request non-default sending behavior. */ - virtual void send(Driver::Address destination, + virtual void send(SocketAddress destination, Options options = Options::NONE) = 0; protected: @@ -265,10 +276,12 @@ class Transport { /** * Allocate Message that can be sent with this Transport. * + * @param sourcePort + * Port number of the socket from which the message will be sent. * @return * A pointer to the allocated message. */ - virtual Homa::unique_ptr alloc() = 0; + virtual Homa::unique_ptr alloc(uint16_t sourcePort) = 0; /** * Check for and return a Message sent to this Transport if available. diff --git a/include/Homa/Util.h b/include/Homa/Util.h index 121bb44..30a3548 100644 --- a/include/Homa/Util.h +++ b/include/Homa/Util.h @@ -21,6 +21,12 @@ #include #include +/// Cast a member of a structure out to the containing structure. +#define container_of(ptr, type, member) ({ \ + const typeof( ((type *)0)->member ) \ + *__mptr = (ptr); \ + (type *)( (char *)__mptr - offsetof(type,member) );}) + namespace Homa { namespace Util { diff --git a/src/ControlPacket.h b/src/ControlPacket.h index a8da070..bc53f10 100644 --- a/src/ControlPacket.h +++ b/src/ControlPacket.h @@ -31,21 +31,19 @@ namespace ControlPacket { * @param driver * Driver with which to send the packet. * @param address - * Destination address for the packet to be sent. + * Destination IP address for the packet to be sent. * @param args * Arguments to PacketHeaderType's constructor. */ template void -send(Driver* driver, Driver::Address address, Args&&... args) +send(Driver* driver, IpAddress address, Args&&... args) { Driver::Packet* packet = driver->allocPacket(); new (packet->payload) PacketHeaderType(static_cast(args)...); packet->length = sizeof(PacketHeaderType); - packet->address = address; - packet->priority = driver->getHighestPacketPriority(); Perf::counters.tx_bytes.add(packet->length); - driver->sendPacket(packet); + driver->sendPacket(packet, address, driver->getHighestPacketPriority()); driver->releasePackets(&packet, 1); } diff --git a/src/Drivers/DPDK/DpdkDriverImpl.h b/src/Drivers/DPDK/DpdkDriverImpl.h index 289e83f..9b77383 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.h +++ b/src/Drivers/DPDK/DpdkDriverImpl.h @@ -109,12 +109,6 @@ class DpdkDriver::Impl { explicit Packet(struct rte_mbuf* mbuf, void* data); explicit Packet(OverflowBuffer* overflowBuf); - /// see Driver::Packet::getMaxPayloadSize() - virtual int getMaxPayloadSize() - { - return MAX_PAYLOAD_SIZE; - } - /// Used to indicate whether the packet is backed by an DPDK mbuf or a /// driver-level OverflowBuffer. enum BufferType { MBUF, OVERFLOW_BUF } bufType; ///< Packet BufferType. diff --git a/src/Drivers/Fake/FakeAddressTest.cc b/src/Drivers/Fake/FakeAddressTest.cc deleted file mode 100644 index 67cef78..0000000 --- a/src/Drivers/Fake/FakeAddressTest.cc +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright (c) 2019, Stanford University - * - * Permission to use, copy, modify, and/or distribute this software for any - * purpose with or without fee is hereby granted, provided that the above - * copyright notice and this permission notice appear in all copies. - * - * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES - * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR - * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ - -#include - -#include "FakeAddress.h" - -#include "../RawAddressType.h" - -namespace Homa { -namespace Drivers { -namespace Fake { -namespace { - -TEST(FakeAddressTest, constructor_id) -{ - FakeAddress address(42); - EXPECT_EQ("42", address.toString()); -} - -TEST(FakeAddressTest, constructor_str) -{ - FakeAddress address("42"); - EXPECT_EQ("42", address.toString()); -} - -TEST(FakeAddressTest, constructor_str_bad) -{ - EXPECT_THROW(FakeAddress address("D42"), BadAddress); -} - -TEST(FakeAddressTest, constructor_raw) -{ - Driver::Address::Raw raw; - raw.type = RawAddressType::FAKE; - *reinterpret_cast(raw.bytes) = 42; - - FakeAddress address(&raw); - EXPECT_EQ("42", address.toString()); -} - -TEST(FakeAddressTest, constructor_raw_bad) -{ - Driver::Address::Raw raw; - raw.type = !RawAddressType::FAKE; - - EXPECT_THROW(FakeAddress address(&raw), BadAddress); -} - -TEST(FakeAddressTest, toString) -{ - // tested sufficiently in constructor tests -} - -TEST(FakeAddressTest, toAddressId) -{ - EXPECT_THROW(FakeAddress::toAddressId("D42"), BadAddress); -} - -} // namespace -} // namespace Fake -} // namespace Drivers -} // namespace Homa diff --git a/src/Drivers/Fake/FakeDriver.cc b/src/Drivers/Fake/FakeDriver.cc index 6200a49..b6355cc 100644 --- a/src/Drivers/Fake/FakeDriver.cc +++ b/src/Drivers/Fake/FakeDriver.cc @@ -72,8 +72,8 @@ static class FakeNetwork { } /// Deliver the provide packet to the specified destination. - void sendPacket(FakePacket* packet, Driver::Address src, - Driver::Address dst) + void sendPacket(FakePacket* packet, int priority, IpAddress src, + IpAddress dst) { FakeNIC* nic = nullptr; { @@ -92,10 +92,10 @@ static class FakeNetwork { assert(nic != nullptr); std::lock_guard lock_nic(nic->mutex, std::adopt_lock); FakePacket* dstPacket = new FakePacket(*packet); - dstPacket->address = src; - assert(dstPacket->priority < NUM_PRIORITIES); - assert(dstPacket->priority >= 0); - nic->priorityQueue.at(dstPacket->priority).push_back(dstPacket); + dstPacket->base.sourceIp = src; + assert(priority < NUM_PRIORITIES); + assert(priority >= 0); + nic->priorityQueue.at(priority).push_back(dstPacket); } void setPacketLossRate(double lossRate) @@ -115,10 +115,9 @@ static class FakeNetwork { std::mutex mutex; /// Holds all the packets being sent through the fake network. - std::unordered_map network; + std::unordered_map network; - /// The FakeAddress identifier for the next FakeDriver that "connects" to - /// the FakeNetwork. + /// Identifier for the next FakeDriver that "connects" to the FakeNetwork. std::atomic nextAddressId; /// Rate at which packets should be dropped when sent over this network. @@ -177,53 +176,6 @@ FakeDriver::~FakeDriver() fakeNetwork.deregisterNIC(localAddressId); } -/** - * See Driver::getAddress() - */ -Driver::Address -FakeDriver::getAddress(std::string const* const addressString) -{ - char* end; - uint64_t address = std::strtoul(addressString->c_str(), &end, 10); - if (address == 0) { - throw BadAddress(HERE_STR, StringUtil::format("Bad address string: %s", - addressString->c_str())); - } - return address; -} - -/** - * See Driver::getAddress() - */ -Driver::Address -FakeDriver::getAddress(Driver::WireFormatAddress const* const wireAddress) -{ - const Address* address = - reinterpret_cast(wireAddress->bytes); - return *address; -} - -/** - * See Driver::addressToString() - */ -std::string -FakeDriver::addressToString(const Address address) -{ - char buf[21]; - snprintf(buf, sizeof(buf), "%lu", address); - return buf; -} - -/** - * See Driver::addressToWireFormat() - */ -void -FakeDriver::addressToWireFormat(const Address address, - WireFormatAddress* wireAddress) -{ - new (reinterpret_cast(wireAddress->bytes)) Address(address); -} - /** * See Driver::allocPacket() */ @@ -231,19 +183,19 @@ Driver::Packet* FakeDriver::allocPacket() { FakePacket* packet = new FakePacket(); - return packet; + return &packet->base; } /** * See Driver::sendPacket() */ void -FakeDriver::sendPacket(Packet* packet) +FakeDriver::sendPacket(Packet* packet, IpAddress destination, int priority) { - FakePacket* srcPacket = static_cast(packet); - Address srcAddress = getLocalAddress(); - Address dstAddress = srcPacket->address; - fakeNetwork.sendPacket(srcPacket, srcAddress, dstAddress); + FakePacket* srcPacket = container_of(packet, FakePacket, base); + IpAddress srcAddress = getLocalAddress(); + IpAddress dstAddress = destination; + fakeNetwork.sendPacket(srcPacket, priority, srcAddress, dstAddress); queueEstimator.signalBytesSent(packet->length); } @@ -257,8 +209,9 @@ FakeDriver::receivePackets(uint32_t maxPackets, Packet* receivedPackets[]) uint32_t numReceived = 0; for (int i = NUM_PRIORITIES - 1; i >= 0; --i) { while (numReceived < maxPackets && !nic.priorityQueue.at(i).empty()) { - receivedPackets[numReceived] = nic.priorityQueue.at(i).front(); + FakePacket* fakePacket = nic.priorityQueue.at(i).front(); nic.priorityQueue.at(i).pop_front(); + receivedPackets[numReceived] = &fakePacket->base; numReceived++; } } @@ -272,8 +225,7 @@ void FakeDriver::releasePackets(Packet* packets[], uint16_t numPackets) { for (uint16_t i = 0; i < numPackets; ++i) { - FakePacket* packet = static_cast(packets[i]); - delete packet; + delete container_of(packets[i], FakePacket, base); } } @@ -308,7 +260,7 @@ FakeDriver::getBandwidth() /** * See Driver::getLocalAddress() */ -Driver::Address +IpAddress FakeDriver::getLocalAddress() { return localAddressId; diff --git a/src/Drivers/Fake/FakeDriverTest.cc b/src/Drivers/Fake/FakeDriverTest.cc index e410119..2390abf 100644 --- a/src/Drivers/Fake/FakeDriverTest.cc +++ b/src/Drivers/Fake/FakeDriverTest.cc @@ -18,7 +18,6 @@ #include -#include "../RawAddressType.h" #include "StringUtil.h" namespace Homa { @@ -34,46 +33,12 @@ TEST(FakeDriverTest, constructor) EXPECT_EQ(nextAddressId, driver.localAddressId); } -TEST(FakeDriverTest, getAddress_string) -{ - FakeDriver driver; - std::string addressStr("42"); - Driver::Address address = driver.getAddress(&addressStr); - EXPECT_EQ("42", driver.addressToString(address)); -} - -TEST(FakeDriverTest, getAddress_wireformat) -{ - FakeDriver driver; - Driver::WireFormatAddress wireformatAddress; - wireformatAddress.type = RawAddressType::FAKE; - *reinterpret_cast(wireformatAddress.bytes) = 42; - Driver::Address address = driver.getAddress(&wireformatAddress); - EXPECT_EQ("42", driver.addressToString(address)); -} - -TEST(FakeDriverTest, addressToString) -{ - FakeDriver driver; - Driver::Address address = 42; - EXPECT_EQ("42", driver.addressToString(address)); -} - -TEST(FakeDriverTest, addressToWireFormat) -{ - FakeDriver driver; - Driver::WireFormatAddress wireformatAddress; - driver.addressToWireFormat(42, &wireformatAddress); - EXPECT_EQ("42", - driver.addressToString(driver.getAddress(&wireformatAddress))); -} - TEST(FakeDriverTest, allocPacket) { FakeDriver driver; Driver::Packet* packet = driver.allocPacket(); // allocPacket doesn't do much so we just need to make sure we can call it. - delete packet; + delete container_of(packet, FakePacket, base); } TEST(FakeDriverTest, sendPackets) @@ -82,13 +47,14 @@ TEST(FakeDriverTest, sendPackets) FakeDriver driver2; Driver::Packet* packets[4]; + IpAddress destinations[4]; + int prio[4]; for (int i = 0; i < 4; ++i) { packets[i] = driver1.allocPacket(); - packets[i]->address = driver2.getLocalAddress(); - packets[i]->priority = i; + destinations[i] = driver2.getLocalAddress(); + prio[i] = i; } - std::string addressStr("42"); - packets[2]->address = driver1.getAddress(&addressStr); + destinations[2] = IpAddress(42); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(0).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(1).size()); @@ -99,7 +65,7 @@ TEST(FakeDriverTest, sendPackets) EXPECT_EQ(0U, driver2.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(7).size()); - driver1.sendPacket(packets[0]); + driver1.sendPacket(packets[0], destinations[0], prio[0]); EXPECT_EQ(1U, driver2.nic.priorityQueue.at(0).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(1).size()); @@ -110,13 +76,12 @@ TEST(FakeDriverTest, sendPackets) EXPECT_EQ(0U, driver2.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(7).size()); { - Driver::Packet* packet = static_cast( - driver2.nic.priorityQueue.at(0).front()); - EXPECT_EQ(driver1.getLocalAddress(), packet->address); + Driver::Packet* packet = &driver2.nic.priorityQueue.at(0).front()->base; + EXPECT_EQ(driver1.getLocalAddress(), packet->sourceIp); } for (int i = 0; i < 4; ++i) { - driver1.sendPacket(packets[i]); + driver1.sendPacket(packets[i], destinations[i], prio[i]); } EXPECT_EQ(2U, driver2.nic.priorityQueue.at(0).size()); @@ -128,7 +93,7 @@ TEST(FakeDriverTest, sendPackets) EXPECT_EQ(0U, driver2.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(7).size()); - delete packets[2]; + delete container_of(packets[2], FakePacket, base); } TEST(FakeDriverTest, receivePackets) @@ -235,10 +200,8 @@ TEST(FakeDriverTest, getBandwidth) TEST(FakeDriverTest, getLocalAddress) { uint64_t nextAddressId = FakeDriver().localAddressId + 1; - std::string addressStr = StringUtil::format("%lu", nextAddressId); - FakeDriver driver; - EXPECT_EQ(driver.getAddress(&addressStr), driver.getLocalAddress()); + EXPECT_EQ(nextAddressId, driver.getLocalAddress()); } } // namespace diff --git a/src/Mock/MockDriver.h b/src/Mock/MockDriver.h index 6cc5ea7..35fd731 100644 --- a/src/Mock/MockDriver.h +++ b/src/Mock/MockDriver.h @@ -35,32 +35,22 @@ class MockDriver : public Driver { * * @sa Driver::Packet. */ - class MockPacket : public Driver::Packet { - public: - MockPacket(void* payload, uint16_t length = 0) - : Packet(payload, length) - {} - - MOCK_METHOD0(getMaxPayloadSize, int()); - }; - - MOCK_METHOD1(getAddress, Address(std::string const* const addressString)); - MOCK_METHOD1(getAddress, - Address(WireFormatAddress const* const wireAddress)); - MOCK_METHOD1(addressToString, std::string(Address address)); - MOCK_METHOD2(addressToWireFormat, - void(Address address, WireFormatAddress* wireAddress)); - MOCK_METHOD0(allocPacket, Packet*()); - MOCK_METHOD1(sendPacket, void(Packet* packet)); - MOCK_METHOD0(flushPackets, void()); - MOCK_METHOD2(receivePackets, - uint32_t(uint32_t maxPackets, Packet* receivedPackets[])); - MOCK_METHOD2(releasePackets, void(Packet* packets[], uint16_t numPackets)); - MOCK_METHOD0(getHighestPacketPriority, int()); - MOCK_METHOD0(getMaxPayloadSize, uint32_t()); - MOCK_METHOD0(getBandwidth, uint32_t()); - MOCK_METHOD0(getLocalAddress, Address()); - MOCK_METHOD0(getQueuedBytes, uint32_t()); + using MockPacket = Driver::Packet; + + MOCK_METHOD(Packet*, allocPacket, (), (override)); + MOCK_METHOD(void, sendPacket, + (Packet* packet, IpAddress destination, int priority), + (override)); + MOCK_METHOD(void, flushPackets, ()); + MOCK_METHOD(uint32_t, receivePackets, + (uint32_t maxPackets, Packet* receivedPackets[]), (override)); + MOCK_METHOD(void, releasePackets, (Packet* packets[], uint16_t numPackets), + (override)); + MOCK_METHOD(int, getHighestPacketPriority, (), (override)); + MOCK_METHOD(uint32_t, getMaxPayloadSize, (), (override)); + MOCK_METHOD(uint32_t, getBandwidth, (), (override)); + MOCK_METHOD(IpAddress, getLocalAddress, (), (override)); + MOCK_METHOD(uint32_t, getQueuedBytes, (), (override)); }; } // namespace Mock diff --git a/src/Mock/MockPolicy.h b/src/Mock/MockPolicy.h index 0595f25..52cb2a5 100644 --- a/src/Mock/MockPolicy.h +++ b/src/Mock/MockPolicy.h @@ -36,10 +36,10 @@ class MockPolicyManager : public Core::Policy::Manager { MOCK_METHOD0(getResendPriority, int()); MOCK_METHOD0(getScheduledPolicy, Core::Policy::Scheduled()); MOCK_METHOD2(getUnscheduledPolicy, - Core::Policy::Unscheduled(const Driver::Address destination, + Core::Policy::Unscheduled(const IpAddress destination, const uint32_t messageLength)); MOCK_METHOD3(signalNewMessage, - void(const Driver::Address source, uint8_t policyVersion, + void(const IpAddress source, uint8_t policyVersion, uint32_t messageLength)); MOCK_METHOD0(poll, void()); }; diff --git a/src/Mock/MockReceiver.h b/src/Mock/MockReceiver.h index fc0fa13..75eea2c 100644 --- a/src/Mock/MockReceiver.h +++ b/src/Mock/MockReceiver.h @@ -36,15 +36,14 @@ class MockReceiver : public Core::Receiver { : Receiver(driver, nullptr, messageTimeoutCycles, resendIntervalCycles) {} - MOCK_METHOD2(handleDataPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD2(handleBusyPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD2(handlePingPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD0(receiveMessage, Homa::InMessage*()); - MOCK_METHOD0(poll, void()); - MOCK_METHOD0(checkTimeouts, uint64_t()); + MOCK_METHOD(void, handleDataPacket, + (Driver::Packet* packet, IpAddress sourceIp), (override)); + MOCK_METHOD(void, handleBusyPacket, (Driver::Packet* packet), (override)); + MOCK_METHOD(void, handlePingPacket, + (Driver::Packet* packet, IpAddress sourceIp), (override)); + MOCK_METHOD(Homa::InMessage*, receiveMessage, (), (override)); + MOCK_METHOD(void, poll, (), (override)); + MOCK_METHOD(uint64_t, checkTimeouts, (), (override)); }; } // namespace Mock diff --git a/src/Mock/MockSender.h b/src/Mock/MockSender.h index b67152b..4a8bd27 100644 --- a/src/Mock/MockSender.h +++ b/src/Mock/MockSender.h @@ -37,19 +37,15 @@ class MockSender : public Core::Sender { pingIntervalCycles) {} - MOCK_METHOD0(allocMessage, Homa::OutMessage*()); - MOCK_METHOD2(handleDonePacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD2(handleGrantPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD2(handleResendPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD2(handleUnknownPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD2(handleErrorPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD0(poll, void()); - MOCK_METHOD0(checkTimeouts, uint64_t()); + MOCK_METHOD(Homa::OutMessage*, allocMessage, (uint16_t sport), (override)); + MOCK_METHOD(void, handleDonePacket, (Driver::Packet* packet), (override)); + MOCK_METHOD(void, handleGrantPacket, (Driver::Packet* packet), (override)); + MOCK_METHOD(void, handleResendPacket, (Driver::Packet* packet), (override)); + MOCK_METHOD(void, handleUnknownPacket, (Driver::Packet* packet), + (override)); + MOCK_METHOD(void, handleErrorPacket, (Driver::Packet* packet), (override)); + MOCK_METHOD(void, poll, (), (override)); + MOCK_METHOD(uint64_t, checkTimeouts, (), (override)); }; } // namespace Mock diff --git a/src/Policy.cc b/src/Policy.cc index cf0e62e..12e7e16 100644 --- a/src/Policy.cc +++ b/src/Policy.cc @@ -97,14 +97,14 @@ Manager::getScheduledPolicy() * unilaterally "granted" (unscheduled) bytes for a new Message to be sent. * * @param destination - * The policy for the Transport at this Address will be returned. + * The policy for the Transport at this IpAddress will be returned. * @param messageLength * The policy for message containing this many bytes will be returned. * * @sa Policy::Unscheduled */ Unscheduled -Manager::getUnscheduledPolicy(const Driver::Address destination, +Manager::getUnscheduledPolicy(const IpAddress destination, const uint32_t messageLength) { SpinLock::Lock lock(mutex); @@ -140,14 +140,14 @@ Manager::getUnscheduledPolicy(const Driver::Address destination, * Called by the Receiver when a new Message has started to arrive. * * @param source - * Address of the Transport from which the new Message was received. + * IpAddress of the Transport from which the new Message was received. * @param policyVersion * Version of the policy the Sender used when sending the Message. * @param messageLength * Number of bytes the new incoming Message contains. */ void -Manager::signalNewMessage(const Driver::Address source, uint8_t policyVersion, +Manager::signalNewMessage(const IpAddress source, uint8_t policyVersion, uint32_t messageLength) { SpinLock::Lock lock(mutex); diff --git a/src/Policy.h b/src/Policy.h index c32bf66..6c80c90 100644 --- a/src/Policy.h +++ b/src/Policy.h @@ -75,9 +75,9 @@ class Manager { virtual ~Manager() = default; virtual int getResendPriority(); virtual Scheduled getScheduledPolicy(); - virtual Unscheduled getUnscheduledPolicy(const Driver::Address destination, + virtual Unscheduled getUnscheduledPolicy(const IpAddress destination, const uint32_t messageLength); - virtual void signalNewMessage(const Driver::Address source, + virtual void signalNewMessage(const IpAddress source, uint8_t policyVersion, uint32_t messageLength); virtual void poll(); @@ -107,7 +107,7 @@ class Manager { /// The scheduled policy for the Transport that owns this Policy::Manager. Scheduled localScheduledPolicy; /// Collection of the known Policies for each peered Homa::Transport; - std::unordered_map peerPolicies; + std::unordered_map peerPolicies; /// Number of bytes that can be transmitted in one round-trip-time. const uint32_t RTT_BYTES; /// The highest network packet priority that the driver supports. diff --git a/src/PolicyTest.cc b/src/PolicyTest.cc index ee0dde5..88cdd45 100644 --- a/src/PolicyTest.cc +++ b/src/PolicyTest.cc @@ -59,7 +59,7 @@ TEST(PolicyManagerTest, getUnscheduledPolicy) EXPECT_CALL(mockDriver, getBandwidth).WillOnce(Return(8000)); EXPECT_CALL(mockDriver, getHighestPacketPriority).WillOnce(Return(7)); Policy::Manager manager(&mockDriver); - Driver::Address dest(22); + IpAddress dest(22); { Policy::Unscheduled policy = manager.getUnscheduledPolicy(dest, 1); diff --git a/src/Protocol.h b/src/Protocol.h index f83725e..25471bb 100644 --- a/src/Protocol.h +++ b/src/Protocol.h @@ -122,16 +122,20 @@ struct HeaderPrefix { /** * Describes the wire format for header fields that are common to all packet - * types. + * types. Note: the first 4 bytes are identical for TCP, UDP, and Homa. */ struct CommonHeader { + uint16_t sport, dport;///< Transport layer (L4) source and destination ports + ///< in network byte order; only used by DataHeader. HeaderPrefix prefix; ///< Common to all versions of the protocol. uint8_t opcode; ///< One of the values of Opcode. MessageId messageId; ///< RemoteOp/Message associated with this packet. /// CommonHeader constructor. CommonHeader(Opcode opcode, MessageId messageId) - : prefix(1) + : sport(0) + , dport(0) + , prefix(1) , opcode(opcode) , messageId(messageId) {} @@ -157,14 +161,18 @@ struct DataHeader { // starting at the offset corresponding to the given packet index. /// DataHeader constructor. - DataHeader(MessageId messageId, uint32_t totalLength, uint8_t policyVersion, + DataHeader(uint16_t sport, uint16_t dport, MessageId messageId, + uint32_t totalLength, uint8_t policyVersion, uint16_t unscheduledIndexLimit, uint16_t index) : common(Opcode::DATA, messageId) , totalLength(totalLength) , policyVersion(policyVersion) , unscheduledIndexLimit(unscheduledIndexLimit) , index(index) - {} + { + common.sport = htobe16(sport); + common.dport = htobe16(dport); + } } __attribute__((packed)); /** diff --git a/src/Receiver.cc b/src/Receiver.cc index 25e0619..d499a61 100644 --- a/src/Receiver.cc +++ b/src/Receiver.cc @@ -82,11 +82,11 @@ Receiver::~Receiver() * * @param packet * The incoming packet to be processed. - * @param driver - * The driver from which the packet was received. + * @param sourceIp + * Source IP address of the packet. */ void -Receiver::handleDataPacket(Driver::Packet* packet, Driver* driver) +Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) { Protocol::Packet::DataHeader* header = static_cast(packet->payload); @@ -102,14 +102,18 @@ Receiver::handleDataPacket(Driver::Packet* packet, Driver* driver) int numUnscheduledPackets = header->unscheduledIndexLimit; { SpinLock::Lock lock_allocator(messageAllocator.mutex); + SocketAddress srcAddress = { + .ip = sourceIp, + .port = be16toh(header->common.sport) + }; message = messageAllocator.pool.construct( this, driver, dataHeaderLength, messageLength, id, - packet->address, numUnscheduledPackets); + srcAddress, numUnscheduledPackets); } bucket->messages.push_back(&message->bucketNode); - policyManager->signalNewMessage(message->source, header->policyVersion, - header->totalLength); + policyManager->signalNewMessage(message->source.ip, + header->policyVersion, header->totalLength); if (message->scheduled) { // Message needs to be scheduled. @@ -121,7 +125,8 @@ Receiver::handleDataPacket(Driver::Packet* packet, Driver* driver) // Things that must be true (sanity check) assert(id == message->id); assert(message->driver == driver); - assert(message->source == packet->address); + assert(message->source.ip == sourceIp); + assert(message->source.port == be16toh(header->common.sport)); assert(message->messageLength == Util::downCast(header->totalLength)); // Add the packet @@ -169,11 +174,9 @@ Receiver::handleDataPacket(Driver::Packet* packet, Driver* driver) * * @param packet * The incoming BUSY packet to be processed. - * @param driver - * The driver from which the BUSY packet was received. */ void -Receiver::handleBusyPacket(Driver::Packet* packet, Driver* driver) +Receiver::handleBusyPacket(Driver::Packet* packet) { Protocol::Packet::BusyHeader* header = static_cast(packet->payload); @@ -198,11 +201,11 @@ Receiver::handleBusyPacket(Driver::Packet* packet, Driver* driver) * * @param packet * The incoming PING packet to be processed. - * @param driver - * The driver from which the PING packet was received. + * @param sourceIp + * Source IP address of the packet. */ void -Receiver::handlePingPacket(Driver::Packet* packet, Driver* driver) +Receiver::handlePingPacket(Driver::Packet* packet, IpAddress sourceIp) { Protocol::Packet::PingHeader* header = static_cast(packet->payload); @@ -236,13 +239,13 @@ Receiver::handlePingPacket(Driver::Packet* packet, Driver* driver) Perf::counters.tx_grant_pkts.add(1); ControlPacket::send( - driver, message->source, message->id, bytesGranted, priority); + driver, message->source.ip, message->id, bytesGranted, priority); } else { // We are here because we have no knowledge of the message the Sender is // asking about. Reply UNKNOWN so the Sender can react accordingly. Perf::counters.tx_unknown_pkts.add(1); ControlPacket::send( - driver, packet->address, id); + driver, sourceIp, id); } driver->releasePackets(&packet, 1); } @@ -346,7 +349,7 @@ Receiver::Message::acknowledge() const MessageBucket* bucket = receiver->messageBuckets.getBucket(id); SpinLock::Lock lock(bucket->mutex); Perf::counters.tx_done_pkts.add(1); - ControlPacket::send(driver, source, id); + ControlPacket::send(driver, source.ip, id); } /** @@ -367,7 +370,7 @@ Receiver::Message::fail() const MessageBucket* bucket = receiver->messageBuckets.getBucket(id); SpinLock::Lock lock(bucket->mutex); Perf::counters.tx_error_pkts.add(1); - ControlPacket::send(driver, source, id); + ControlPacket::send(driver, source.ip, id); } /** @@ -678,7 +681,7 @@ Receiver::checkResendTimeouts() SpinLock::Lock lock_scheduler(schedulerMutex); Perf::counters.tx_resend_pkts.add(1); ControlPacket::send( - message->driver, message->source, message->id, + message->driver, message->source.ip, message->id, Util::downCast(index), Util::downCast(num), message->scheduledMessageInfo.priority); @@ -691,7 +694,7 @@ Receiver::checkResendTimeouts() SpinLock::Lock lock_scheduler(schedulerMutex); Perf::counters.tx_resend_pkts.add(1); ControlPacket::send( - message->driver, message->source, message->id, + message->driver, message->source.ip, message->id, Util::downCast(index), Util::downCast(num), message->scheduledMessageInfo.priority); @@ -748,7 +751,7 @@ Receiver::trySendGrants() ScheduledMessageInfo* info = &message->scheduledMessageInfo; // Access message const variables without message mutex. const Protocol::MessageId id = message->id; - const Driver::Address source = message->source; + const IpAddress sourceIp = message->source.ip; // Recalculate message priority info->priority = @@ -765,7 +768,7 @@ Receiver::trySendGrants() info->bytesGranted = newGrantLimit; Perf::counters.tx_grant_pkts.add(1); ControlPacket::send( - driver, source, id, + driver, sourceIp, id, Util::downCast(info->bytesGranted), info->priority); } @@ -806,7 +809,7 @@ Receiver::schedule(Receiver::Message* message, const SpinLock::Lock& lock) { (void)lock; ScheduledMessageInfo* info = &message->scheduledMessageInfo; - Peer* peer = &peerTable[message->source]; + Peer* peer = &peerTable[message->source.ip]; // Insert the Message peer->scheduledMessages.push_front(&info->scheduledMessageNode); Intrusive::deprioritize(&peer->scheduledMessages, diff --git a/src/Receiver.h b/src/Receiver.h index 444e1aa..c97c462 100644 --- a/src/Receiver.h +++ b/src/Receiver.h @@ -30,6 +30,7 @@ #include "Protocol.h" #include "SpinLock.h" #include "Timeout.h" +#include "Util.h" namespace Homa { namespace Core { @@ -46,9 +47,9 @@ class Receiver { uint64_t messageTimeoutCycles, uint64_t resendIntervalCycles); virtual ~Receiver(); - virtual void handleDataPacket(Driver::Packet* packet, Driver* driver); - virtual void handleBusyPacket(Driver::Packet* packet, Driver* driver); - virtual void handlePingPacket(Driver::Packet* packet, Driver* driver); + virtual void handleDataPacket(Driver::Packet* packet, IpAddress sourceIp); + virtual void handleBusyPacket(Driver::Packet* packet); + virtual void handlePingPacket(Driver::Packet* packet, IpAddress sourceIp); virtual Homa::InMessage* receiveMessage(); virtual void poll(); virtual uint64_t checkTimeouts(); @@ -132,7 +133,7 @@ class Receiver { explicit Message(Receiver* receiver, Driver* driver, size_t packetHeaderLength, size_t messageLength, - Protocol::MessageId id, Driver::Address source, + Protocol::MessageId id, SocketAddress source, int numUnscheduledPackets) : receiver(receiver) , driver(driver) @@ -195,7 +196,7 @@ class Receiver { const Protocol::MessageId id; /// Contains source address this message. - const Driver::Address source; + const SocketAddress source; /// Number of bytes at the beginning of each Packet that should be /// reserved for the Homa transport header. @@ -473,7 +474,7 @@ class Receiver { /// Collection of all peers; used for fast access. Access is protected by /// the schedulerMutex. - std::unordered_map peerTable; + std::unordered_map peerTable; /// List of peers with inbound messages that require grants to complete. /// Access is protected by the schedulerMutex. diff --git a/src/ReceiverTest.cc b/src/ReceiverTest.cc index a49aee2..213e2bd 100644 --- a/src/ReceiverTest.cc +++ b/src/ReceiverTest.cc @@ -41,7 +41,7 @@ class ReceiverTest : public ::testing::Test { public: ReceiverTest() : mockDriver() - , mockPacket(&payload) + , mockPacket {&payload} , mockPolicyManager(&mockDriver) , payload() , receiver() @@ -68,7 +68,7 @@ class ReceiverTest : public ::testing::Test { static const uint64_t resendIntervalCycles = 100; NiceMock mockDriver; - NiceMock mockPacket; + Homa::Mock::MockDriver::MockPacket mockPacket; NiceMock mockPolicyManager; char payload[1028]; Receiver* receiver; @@ -105,21 +105,21 @@ TEST_F(ReceiverTest, handleDataPacket) header->totalLength = totalMessageLength; header->policyVersion = policyVersion; header->unscheduledIndexLimit = 1; - mockPacket.address = Driver::Address(22); + mockPacket.sourceIp = IpAddress(22); // ------------------------------------------------------------------------- // Receive packet[1]. New message. header->index = 1; mockPacket.length = HEADER_SIZE + 1000; EXPECT_CALL(mockPolicyManager, - signalNewMessage(Eq(mockPacket.address), Eq(policyVersion), + signalNewMessage(Eq(mockPacket.sourceIp), Eq(policyVersion), Eq(totalMessageLength))) .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(0); // TEST CALL - receiver->handleDataPacket(&mockPacket, &mockDriver); + receiver->handleDataPacket(&mockPacket, mockPacket.sourceIp); // --------- { @@ -148,7 +148,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(1); // TEST CALL - receiver->handleDataPacket(&mockPacket, &mockDriver); + receiver->handleDataPacket(&mockPacket, mockPacket.sourceIp); // --------- EXPECT_EQ(1U, message->numPackets); @@ -162,7 +162,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(0); // TEST CALL - receiver->handleDataPacket(&mockPacket, &mockDriver); + receiver->handleDataPacket(&mockPacket, mockPacket.sourceIp); // --------- EXPECT_EQ(2U, message->numPackets); @@ -177,7 +177,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(0); // TEST CALL - receiver->handleDataPacket(&mockPacket, &mockDriver); + receiver->handleDataPacket(&mockPacket, mockPacket.sourceIp); // --------- EXPECT_EQ(3U, message->numPackets); @@ -192,7 +192,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(0); // TEST CALL - receiver->handleDataPacket(&mockPacket, &mockDriver); + receiver->handleDataPacket(&mockPacket, mockPacket.sourceIp); // --------- EXPECT_EQ(4U, message->numPackets); @@ -207,7 +207,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(1); // TEST CALL - receiver->handleDataPacket(&mockPacket, &mockDriver); + receiver->handleDataPacket(&mockPacket, mockPacket.sourceIp); // --------- Mock::VerifyAndClearExpectations(&mockDriver); @@ -217,7 +217,7 @@ TEST_F(ReceiverTest, handleBusyPacket_basic) { Protocol::MessageId id(42, 32); Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(0), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{0, 60001}, 0); Receiver::MessageBucket* bucket = receiver->messageBuckets.getBucket(id); bucket->messages.push_back(&message->bucketNode); @@ -228,7 +228,7 @@ TEST_F(ReceiverTest, handleBusyPacket_basic) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - receiver->handleBusyPacket(&mockPacket, &mockDriver); + receiver->handleBusyPacket(&mockPacket); EXPECT_EQ(11000U, message->messageTimeout.expirationCycleTime); EXPECT_EQ(10100U, message->resendTimeout.expirationCycleTime); @@ -245,15 +245,15 @@ TEST_F(ReceiverTest, handleBusyPacket_unknown) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - receiver->handleBusyPacket(&mockPacket, &mockDriver); + receiver->handleBusyPacket(&mockPacket); } TEST_F(ReceiverTest, handlePingPacket_basic) { Protocol::MessageId id(42, 32); - Driver::Address mockAddress = 22; + IpAddress mockAddress = 22; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 20000, id, mockAddress, 0); + receiver, &mockDriver, 0, 20000, id, SocketAddress{mockAddress, 0}, 0); ASSERT_TRUE(message->scheduled); Receiver::ScheduledMessageInfo* info = &message->scheduledMessageInfo; info->bytesGranted = 500; @@ -263,25 +263,25 @@ TEST_F(ReceiverTest, handlePingPacket_basic) bucket->messages.push_back(&message->bucketNode); char pingPayload[1028]; - Homa::Mock::MockDriver::MockPacket pingPacket(pingPayload); - pingPacket.address = mockAddress; + Homa::Mock::MockDriver::MockPacket pingPacket {pingPayload}; + pingPacket.sourceIp = mockAddress; Protocol::Packet::PingHeader* pingHeader = (Protocol::Packet::PingHeader*)pingPacket.payload; pingHeader->common.messageId = id; EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), Eq(mockAddress), _)) + .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&pingPacket), Eq(1))) .Times(1); - receiver->handlePingPacket(&pingPacket, &mockDriver); + receiver->handlePingPacket(&pingPacket, pingPacket.sourceIp); EXPECT_EQ(11000U, message->messageTimeout.expirationCycleTime); EXPECT_EQ(0U, message->resendTimeout.expirationCycleTime); - EXPECT_EQ(mockAddress, mockPacket.address); Protocol::Packet::GrantHeader* header = (Protocol::Packet::GrantHeader*)payload; EXPECT_EQ(Protocol::Packet::GRANT, header->common.opcode); @@ -295,22 +295,23 @@ TEST_F(ReceiverTest, handlePingPacket_unknown) Protocol::MessageId id(42, 32); char pingPayload[1028]; - Homa::Mock::MockDriver::MockPacket pingPacket(pingPayload); - pingPacket.address = (Driver::Address)22; + Homa::Mock::MockDriver::MockPacket pingPacket {pingPayload}; + IpAddress mockAddress = (IpAddress)22; + pingPacket.sourceIp = mockAddress; Protocol::Packet::PingHeader* pingHeader = (Protocol::Packet::PingHeader*)pingPacket.payload; pingHeader->common.messageId = id; EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), Eq(mockAddress), _)) + .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&pingPacket), Eq(1))) .Times(1); - receiver->handlePingPacket(&pingPacket, &mockDriver); + receiver->handlePingPacket(&pingPacket, pingPacket.sourceIp); - EXPECT_EQ(pingPacket.address, mockPacket.address); Protocol::Packet::UnknownHeader* header = (Protocol::Packet::UnknownHeader*)payload; EXPECT_EQ(Protocol::Packet::UNKNOWN, header->common.opcode); @@ -321,10 +322,10 @@ TEST_F(ReceiverTest, receiveMessage) { Receiver::Message* msg0 = receiver->messageAllocator.pool.construct( receiver, &mockDriver, 0, 0, Protocol::MessageId(42, 0), - Driver::Address(22), 0); + SocketAddress{22, 60001}, 0); Receiver::Message* msg1 = receiver->messageAllocator.pool.construct( receiver, &mockDriver, 0, 0, Protocol::MessageId(42, 0), - Driver::Address(22), 0); + SocketAddress{22, 60001}, 0); receiver->receivedMessages.queue.push_back(&msg0->receivedMessageNode); receiver->receivedMessages.queue.push_back(&msg1->receivedMessageNode); @@ -349,7 +350,7 @@ TEST_F(ReceiverTest, poll) TEST_F(ReceiverTest, checkTimeouts) { Receiver::Message message(receiver, &mockDriver, 0, 0, - Protocol::MessageId(0, 0), Driver::Address(0), 0); + Protocol::MessageId(0, 0), SocketAddress{0, 60001}, 0); Receiver::MessageBucket* bucket = receiver->messageBuckets.buckets.at(0); bucket->resendTimeouts.setTimeout(&message.resendTimeout); bucket->messageTimeouts.setTimeout(&message.messageTimeout); @@ -373,7 +374,7 @@ TEST_F(ReceiverTest, Message_destructor_basic) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); const uint16_t NUM_PKTS = 5; @@ -392,7 +393,7 @@ TEST_F(ReceiverTest, Message_destructor_holes) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); const uint16_t NUM_PKTS = 4; @@ -414,10 +415,11 @@ TEST_F(ReceiverTest, Message_acknowledge) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket( + Eq(&mockPacket), Eq(message->source.ip), _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); @@ -428,14 +430,13 @@ TEST_F(ReceiverTest, Message_acknowledge) EXPECT_EQ(Protocol::Packet::DONE, header->opcode); EXPECT_EQ(id, header->messageId); EXPECT_EQ(sizeof(Protocol::Packet::DoneHeader), mockPacket.length); - EXPECT_EQ(message->source, mockPacket.address); } TEST_F(ReceiverTest, Message_dropped) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); message->state = Receiver::Message::State::IN_PROGRESS; @@ -450,10 +451,11 @@ TEST_F(ReceiverTest, Message_fail) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket( + Eq(&mockPacket), Eq(message->source.ip), _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); @@ -464,7 +466,6 @@ TEST_F(ReceiverTest, Message_fail) EXPECT_EQ(Protocol::Packet::ERROR, header->opcode); EXPECT_EQ(id, header->messageId); EXPECT_EQ(sizeof(Protocol::Packet::ErrorHeader), mockPacket.length); - EXPECT_EQ(message->source, mockPacket.address); } TEST_F(ReceiverTest, Message_get_basic) @@ -472,10 +473,10 @@ TEST_F(ReceiverTest, Message_get_basic) ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(2048)); Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 24, 24 + 2007, id, Driver::Address(22), 0); + receiver, &mockDriver, 24, 24 + 2007, id, SocketAddress{22, 60001}, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + 2048); + Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1 {buf + 2048}; char source[] = "Hello, world!"; message->setPacket(0, &packet0); @@ -499,10 +500,10 @@ TEST_F(ReceiverTest, Message_get_offsetTooLarge) ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(2048)); Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 24, 24 + 2007, id, Driver::Address(22), 0); + receiver, &mockDriver, 24, 24 + 2007, id, SocketAddress{22, 60001}, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + 2048); + Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1 {buf + 2048}; message->setPacket(0, &packet0); message->setPacket(1, &packet1); @@ -525,10 +526,10 @@ TEST_F(ReceiverTest, Message_get_missingPacket) ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(2048)); Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 24, 24 + 2007, id, Driver::Address(22), 0); + receiver, &mockDriver, 24, 24 + 2007, id, SocketAddress{22, 60001}, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + 2048); + Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1 {buf + 2048}; char source[] = "Hello,"; message->setPacket(0, &packet0); @@ -557,7 +558,7 @@ TEST_F(ReceiverTest, Message_length) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); message->messageLength = 200; message->start = 20; EXPECT_EQ(180U, message->length()); @@ -567,7 +568,7 @@ TEST_F(ReceiverTest, Message_strip) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); message->messageLength = 30; message->start = 0; @@ -589,7 +590,7 @@ TEST_F(ReceiverTest, Message_getPacket) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); Driver::Packet* packet = (Driver::Packet*)42; message->packets[0] = packet; @@ -605,7 +606,7 @@ TEST_F(ReceiverTest, Message_setPacket) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); Driver::Packet* packet = (Driver::Packet*)42; EXPECT_FALSE(message->occupied.test(0)); @@ -626,12 +627,12 @@ TEST_F(ReceiverTest, MessageBucket_findMessage) Protocol::MessageId id0 = {42, 0}; Receiver::Message* msg0 = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), 0, id0, 0, - 0); + receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), 0, id0, + SocketAddress{0, 60001}, 0); Protocol::MessageId id1 = {42, 1}; Receiver::Message* msg1 = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), 0, id1, - Driver::Address(0), 0); + SocketAddress{0, 60001}, 0); Protocol::MessageId id_none = {42, 42}; bucket->messages.push_back(&msg0->bucketNode); @@ -659,7 +660,7 @@ TEST_F(ReceiverTest, dropMessage) SpinLock::Lock dummy(dummyMutex); Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 1000, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 1000, id, SocketAddress{22, 60001}, 0); ASSERT_TRUE(message->scheduled); Receiver::MessageBucket* bucket = receiver->messageBuckets.getBucket(id); @@ -670,7 +671,7 @@ TEST_F(ReceiverTest, dropMessage) EXPECT_EQ(1U, receiver->messageAllocator.pool.outstandingObjects); EXPECT_EQ(message, bucket->findMessage(id, dummy)); - EXPECT_EQ(&receiver->peerTable[message->source], + EXPECT_EQ(&receiver->peerTable[message->source.ip], message->scheduledMessageInfo.peer); EXPECT_FALSE(bucket->messageTimeouts.list.empty()); EXPECT_FALSE(bucket->resendTimeouts.list.empty()); @@ -693,7 +694,7 @@ TEST_F(ReceiverTest, checkMessageTimeouts_basic) Protocol::MessageId id = {42, 10 + i}; op[i] = reinterpret_cast(i); message[i] = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 1000, id, 0, 0); + receiver, &mockDriver, 0, 1000, id, SocketAddress{0, 60001}, 0); bucket->messages.push_back(&message[i]->bucketNode); bucket->messageTimeouts.setTimeout(&message[i]->messageTimeout); bucket->resendTimeouts.setTimeout(&message[i]->resendTimeout); @@ -767,7 +768,7 @@ TEST_F(ReceiverTest, checkResendTimeouts_basic) for (uint64_t i = 0; i < 3; ++i) { Protocol::MessageId id = {42, 10 + i}; message[i] = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 10000, id, Driver::Address(22), 5); + receiver, &mockDriver, 0, 10000, id, SocketAddress{22, 60001}, 5); bucket->resendTimeouts.setTimeout(&message[i]->resendTimeout); } @@ -803,14 +804,16 @@ TEST_F(ReceiverTest, checkResendTimeouts_basic) char buf1[1024]; char buf2[1024]; - Homa::Mock::MockDriver::MockPacket mockResendPacket1(buf1); - Homa::Mock::MockDriver::MockPacket mockResendPacket2(buf2); + Homa::Mock::MockDriver::MockPacket mockResendPacket1 {buf1}; + Homa::Mock::MockDriver::MockPacket mockResendPacket2 {buf2}; EXPECT_CALL(mockDriver, allocPacket()) .WillOnce(Return(&mockResendPacket1)) .WillOnce(Return(&mockResendPacket2)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockResendPacket1))).Times(1); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockResendPacket2))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockResendPacket1), + Eq(message[0]->source.ip), _)).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockResendPacket2), + Eq(message[0]->source.ip), _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockResendPacket1), Eq(1))) .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockResendPacket2), Eq(1))) @@ -830,7 +833,6 @@ TEST_F(ReceiverTest, checkResendTimeouts_basic) EXPECT_EQ(2U, header1->index); EXPECT_EQ(4U, header1->num); EXPECT_EQ(sizeof(Protocol::Packet::ResendHeader), mockResendPacket1.length); - EXPECT_EQ(message[0]->source, mockResendPacket1.address); Protocol::Packet::ResendHeader* header2 = static_cast(mockResendPacket2.payload); EXPECT_EQ(Protocol::Packet::RESEND, header2->common.opcode); @@ -838,7 +840,6 @@ TEST_F(ReceiverTest, checkResendTimeouts_basic) EXPECT_EQ(8U, header2->index); EXPECT_EQ(2U, header2->num); EXPECT_EQ(sizeof(Protocol::Packet::ResendHeader), mockResendPacket2.length); - EXPECT_EQ(message[0]->source, mockResendPacket2.address); // Message[1]: Blocked on grants EXPECT_EQ(10100, message[1]->resendTimeout.expirationCycleTime); @@ -867,7 +868,8 @@ TEST_F(ReceiverTest, trySendGrants) Protocol::MessageId id = {42, 10 + i}; message[i] = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), - 10000 * (i + 1), id, Driver::Address(100 + i), 10 * (i + 1)); + 10000 * (i + 1), id, SocketAddress{IpAddress(100 + i), 60001}, + 10 * (i + 1)); { SpinLock::Lock lock_scheduler(receiver->schedulerMutex); receiver->schedule(message[i], lock_scheduler); @@ -894,7 +896,7 @@ TEST_F(ReceiverTest, trySendGrants) EXPECT_CALL(mockPolicyManager, getScheduledPolicy()) .WillOnce(Return(policy)); EXPECT_CALL(mockDriver, allocPacket).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), _, _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); @@ -920,7 +922,7 @@ TEST_F(ReceiverTest, trySendGrants) EXPECT_CALL(mockPolicyManager, getScheduledPolicy()) .WillOnce(Return(policy)); EXPECT_CALL(mockDriver, allocPacket).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), _, _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); @@ -941,7 +943,7 @@ TEST_F(ReceiverTest, trySendGrants) policy.maxScheduledBytes = 10000; EXPECT_CALL(mockPolicyManager, getScheduledPolicy()) .WillOnce(Return(policy)); - EXPECT_CALL(mockDriver, sendPacket(_)).Times(0); + EXPECT_CALL(mockDriver, sendPacket(_, _, _)).Times(0); receiver->trySendGrants(); @@ -960,7 +962,7 @@ TEST_F(ReceiverTest, trySendGrants) policy.maxScheduledBytes = 10000; EXPECT_CALL(mockPolicyManager, getScheduledPolicy()) .WillOnce(Return(policy)); - EXPECT_CALL(mockDriver, sendPacket(_)).Times(0); + EXPECT_CALL(mockDriver, sendPacket(_, _, _)).Times(0); receiver->trySendGrants(); @@ -975,13 +977,13 @@ TEST_F(ReceiverTest, schedule) { Receiver::Message* message[4]; Receiver::ScheduledMessageInfo* info[4]; - Driver::Address address[4] = {22, 33, 33, 22}; + IpAddress address[4] = {22, 33, 33, 22}; int messageLength[4] = {2000, 3000, 1000, 4000}; for (uint64_t i = 0; i < 4; ++i) { Protocol::MessageId id = {42, 10 + i}; message[i] = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), - messageLength[i], id, address[i], 0); + messageLength[i], id, SocketAddress{address[i], 60001}, 0); info[i] = &message[i]->scheduledMessageInfo; } @@ -1043,19 +1045,19 @@ TEST_F(ReceiverTest, unschedule) int messageLength[5] = {10, 20, 30, 10, 20}; for (uint64_t i = 0; i < 5; ++i) { Protocol::MessageId id = {42, 10 + i}; - Driver::Address source = Driver::Address((i / 3) + 10); + IpAddress source = IpAddress((i / 3) + 10); message[i] = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), - messageLength[i], id, source, 0); + messageLength[i], id, SocketAddress{source, 60001}, 0); info[i] = &message[i]->scheduledMessageInfo; receiver->schedule(message[i], lock); } - ASSERT_EQ(Driver::Address(10), message[0]->source); - ASSERT_EQ(Driver::Address(10), message[1]->source); - ASSERT_EQ(Driver::Address(10), message[2]->source); - ASSERT_EQ(Driver::Address(11), message[3]->source); - ASSERT_EQ(Driver::Address(11), message[4]->source); + ASSERT_EQ(IpAddress(10), message[0]->source.ip); + ASSERT_EQ(IpAddress(10), message[1]->source.ip); + ASSERT_EQ(IpAddress(10), message[2]->source.ip); + ASSERT_EQ(IpAddress(11), message[3]->source.ip); + ASSERT_EQ(IpAddress(11), message[4]->source.ip); ASSERT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(10)); ASSERT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); @@ -1128,15 +1130,15 @@ TEST_F(ReceiverTest, updateSchedule) for (uint64_t i = 0; i < 3; ++i) { Protocol::MessageId id = {42, 10 + i}; int messageLength = 10 * (i + 1); - Driver::Address source = Driver::Address(((i + 1) / 2) + 10); + IpAddress source = IpAddress(((i + 1) / 2) + 10); other[i] = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), - 10 * (i + 1), id, source, 0); + 10 * (i + 1), id, SocketAddress{source, 60001}, 0); receiver->schedule(other[i], lock); } Receiver::Message* message = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), 100, - Protocol::MessageId(42, 1), Driver::Address(11), 0); + Protocol::MessageId(42, 1), SocketAddress{11, 60001}, 0); receiver->schedule(message, lock); ASSERT_EQ(&receiver->peerTable.at(10), other[0]->scheduledMessageInfo.peer); ASSERT_EQ(&receiver->peerTable.at(11), other[1]->scheduledMessageInfo.peer); diff --git a/src/Sender.cc b/src/Sender.cc index c2d0c3f..b993b6b 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -67,10 +67,10 @@ Sender::~Sender() {} * Allocate an OutMessage that can be sent with this Sender. */ Homa::OutMessage* -Sender::allocMessage() +Sender::allocMessage(uint16_t sourcePort) { SpinLock::Lock lock_allocator(messageAllocator.mutex); - return messageAllocator.pool.construct(this, driver); + return messageAllocator.pool.construct(this, sourcePort); } /** @@ -78,12 +78,9 @@ Sender::allocMessage() * * @param packet * Incoming DONE packet to be processed. - * @param driver - * Driver from which the packet was received and to which it should be - * returned after the packet has been processed. */ void -Sender::handleDonePacket(Driver::Packet* packet, Driver* driver) +Sender::handleDonePacket(Driver::Packet* packet) { Protocol::Packet::DoneHeader* header = static_cast(packet->payload); @@ -152,12 +149,9 @@ Sender::handleDonePacket(Driver::Packet* packet, Driver* driver) * * @param packet * Incoming RESEND packet to be processed. - * @param driver - * Driver from which the packet was received and to which it should be - * returned after the packet has been processed. */ void -Sender::handleResendPacket(Driver::Packet* packet, Driver* driver) +Sender::handleResendPacket(Driver::Packet* packet) { Protocol::Packet::ResendHeader* header = static_cast(packet->payload); @@ -222,7 +216,7 @@ Sender::handleResendPacket(Driver::Packet* packet, Driver* driver) // when it's ready. Perf::counters.tx_busy_pkts.add(1); ControlPacket::send( - driver, info->destination, info->id); + driver, info->destination.ip, info->id); } else { // There are some packets to resend but only resend packets that have // already been sent. @@ -230,11 +224,10 @@ Sender::handleResendPacket(Driver::Packet* packet, Driver* driver) int resendPriority = policyManager->getResendPriority(); for (uint16_t i = index; i < resendEnd; ++i) { Driver::Packet* packet = info->packets->getPacket(i); - packet->priority = resendPriority; // Packets will be sent at the priority their original priority. Perf::counters.tx_data_pkts.add(1); Perf::counters.tx_bytes.add(packet->length); - driver->sendPacket(packet); + driver->sendPacket(packet, message->destination.ip, resendPriority); } } @@ -246,12 +239,9 @@ Sender::handleResendPacket(Driver::Packet* packet, Driver* driver) * * @param packet * Incoming GRANT packet to be processed. - * @param driver - * Driver from which the packet was received and to which it should be - * returned after the packet has been processed. */ void -Sender::handleGrantPacket(Driver::Packet* packet, Driver* driver) +Sender::handleGrantPacket(Driver::Packet* packet) { Protocol::Packet::GrantHeader* header = static_cast(packet->payload); @@ -310,12 +300,9 @@ Sender::handleGrantPacket(Driver::Packet* packet, Driver* driver) * * @param packet * Incoming UNKNOWN packet to be processed. - * @param driver - * Driver from which the packet was received and to which it should be - * returned after the packet has been processed. */ void -Sender::handleUnknownPacket(Driver::Packet* packet, Driver* driver) +Sender::handleUnknownPacket(Driver::Packet* packet) { Protocol::Packet::UnknownHeader* header = static_cast(packet->payload); @@ -376,7 +363,7 @@ Sender::handleUnknownPacket(Driver::Packet* packet, Driver* driver) // Get the current policy for unscheduled bytes. Policy::Unscheduled policy = policyManager->getUnscheduledPolicy( - message->destination, message->messageLength); + message->destination.ip, message->messageLength); int unscheduledIndexLimit = ((policy.unscheduledByteLimit + message->PACKET_DATA_LENGTH - 1) / message->PACKET_DATA_LENGTH); @@ -401,10 +388,10 @@ Sender::handleUnknownPacket(Driver::Packet* packet, Driver* driver) // If there is only one packet in the message, send it right away. Driver::Packet* dataPacket = message->getPacket(0); assert(dataPacket != nullptr); - dataPacket->priority = policy.priority; Perf::counters.tx_data_pkts.add(1); Perf::counters.tx_bytes.add(dataPacket->length); - driver->sendPacket(dataPacket); + driver->sendPacket(dataPacket, message->destination.ip, + policy.priority); message->state.store(OutMessage::Status::SENT); } else { // Otherwise, queue the message to be sent in SRPT order. @@ -413,7 +400,8 @@ Sender::handleUnknownPacket(Driver::Packet* packet, Driver* driver) // Some of these values should still be set from when the message // was first queued. assert(info->id == message->id); - assert(info->destination == message->destination); + assert(!memcmp(&info->destination, &message->destination, + sizeof(info->destination))); assert(info->packets == message); // Some values need to be updated info->unsentBytes = message->messageLength; @@ -439,12 +427,9 @@ Sender::handleUnknownPacket(Driver::Packet* packet, Driver* driver) * * @param packet * Incoming ERROR packet to be processed. - * @param driver - * Driver from which the packet was received and to which it should be - * returned after the packet has been processed. */ void -Sender::handleErrorPacket(Driver::Packet* packet, Driver* driver) +Sender::handleErrorPacket(Driver::Packet* packet) { Protocol::Packet::ErrorHeader* header = static_cast(packet->payload); @@ -697,7 +682,7 @@ Sender::Message::reserve(size_t count) * @copydoc Homa::OutMessage::send() */ void -Sender::Message::send(Driver::Address destination, +Sender::Message::send(SocketAddress destination, Sender::Message::Options options) { sender->sendMessage(this, destination, options); @@ -758,7 +743,7 @@ Sender::Message::getOrAllocPacket(size_t index) * @sa dropMessage() */ void -Sender::sendMessage(Sender::Message* message, Driver::Address destination, +Sender::sendMessage(Sender::Message* message, SocketAddress destination, Sender::Message::Options options) { // Prepare the message @@ -767,7 +752,7 @@ Sender::sendMessage(Sender::Message* message, Driver::Address destination, Protocol::MessageId id(transportId, nextMessageSequenceNumber++); Policy::Unscheduled policy = policyManager->getUnscheduledPolicy( - destination, message->messageLength); + destination.ip, message->messageLength); int unscheduledPacketLimit = ((policy.unscheduledByteLimit + message->PACKET_DATA_LENGTH - 1) / message->PACKET_DATA_LENGTH); @@ -789,10 +774,10 @@ Sender::sendMessage(Sender::Message* message, Driver::Address destination, i * message->PACKET_DATA_LENGTH); } - packet->address = message->destination; new (packet->payload) Protocol::Packet::DataHeader( - message->id, Util::downCast(message->messageLength), - policy.version, Util::downCast(unscheduledPacketLimit), + message->source.port, destination.port, message->id, + Util::downCast(message->messageLength), policy.version, + Util::downCast(unscheduledPacketLimit), Util::downCast(i)); actualMessageLen += (packet->length - message->TRANSPORT_HEADER_LENGTH); } @@ -816,10 +801,9 @@ Sender::sendMessage(Sender::Message* message, Driver::Address destination, // If there is only one packet in the message, send it right away. Driver::Packet* packet = message->getPacket(0); assert(packet != nullptr); - packet->priority = policy.priority; Perf::counters.tx_data_pkts.add(1); Perf::counters.tx_bytes.add(packet->length); - driver->sendPacket(packet); + driver->sendPacket(packet, message->destination.ip, policy.priority); message->state.store(OutMessage::Status::SENT); } else { // Otherwise, queue the message to be sent in SRPT order. @@ -984,7 +968,7 @@ Sender::checkPingTimeouts() // the receiver to ensure it still knows about this Message. Perf::counters.tx_ping_pkts.add(1); ControlPacket::send( - message->driver, message->destination, message->id); + message->driver, message->destination.ip, message->id); } globalNextTimeout = std::min(globalNextTimeout, nextTimeout); } @@ -1038,10 +1022,9 @@ Sender::trySend() break; } // ... if not, send away! - packet->priority = info->priority; Perf::counters.tx_data_pkts.add(1); Perf::counters.tx_bytes.add(packet->length); - driver->sendPacket(packet); + driver->sendPacket(packet, message.destination.ip, info->priority); int packetDataBytes = packet->length - info->packets->TRANSPORT_HEADER_LENGTH; assert(info->unsentBytes >= packetDataBytes); diff --git a/src/Sender.h b/src/Sender.h index 471925a..faa5dee 100644 --- a/src/Sender.h +++ b/src/Sender.h @@ -46,12 +46,12 @@ class Sender { uint64_t messageTimeoutCycles, uint64_t pingIntervalCycles); virtual ~Sender(); - virtual Homa::OutMessage* allocMessage(); - virtual void handleDonePacket(Driver::Packet* packet, Driver* driver); - virtual void handleResendPacket(Driver::Packet* packet, Driver* driver); - virtual void handleGrantPacket(Driver::Packet* packet, Driver* driver); - virtual void handleUnknownPacket(Driver::Packet* packet, Driver* driver); - virtual void handleErrorPacket(Driver::Packet* packet, Driver* driver); + virtual Homa::OutMessage* allocMessage(uint16_t sourcePort); + virtual void handleDonePacket(Driver::Packet* packet); + virtual void handleResendPacket(Driver::Packet* packet); + virtual void handleGrantPacket(Driver::Packet* packet); + virtual void handleUnknownPacket(Driver::Packet* packet); + virtual void handleErrorPacket(Driver::Packet* packet); virtual void poll(); virtual uint64_t checkTimeouts(); @@ -96,7 +96,7 @@ class Sender { Protocol::MessageId id; /// Contains destination address this message. - Driver::Address destination; + SocketAddress destination; /// Handle to the queue Message for access to the packets that will /// be sent. This member documents that the packets are logically owned @@ -131,13 +131,14 @@ class Sender { /** * Construct an Message. */ - explicit Message(Sender* sender, Driver* driver) + explicit Message(Sender* sender, uint16_t sourcePort) : sender(sender) - , driver(driver) + , driver(sender->driver) , TRANSPORT_HEADER_LENGTH(sizeof(Protocol::Packet::DataHeader)) , PACKET_DATA_LENGTH(driver->getMaxPayloadSize() - TRANSPORT_HEADER_LENGTH) , id(0, 0) + , source{driver->getLocalAddress(), sourcePort} , destination() , options(Options::NONE) , start(0) @@ -161,7 +162,7 @@ class Sender { virtual void prepend(const void* source, size_t count); virtual void release(); virtual void reserve(size_t count); - virtual void send(Driver::Address destination, + virtual void send(SocketAddress destination, Options options = Options::NONE); private: @@ -188,8 +189,11 @@ class Sender { /// Contains the unique identifier for this message. Protocol::MessageId id; - /// Contains destination address this message. - Driver::Address destination; + /// Contains source address of this message. + SocketAddress source; + + /// Contains destination address of this message. + SocketAddress destination; /// Contains flags for any requested optional send behavior. Options options; @@ -384,7 +388,7 @@ class Sender { Protocol::MessageId::Hasher hasher; }; - void sendMessage(Sender::Message* message, Driver::Address destination, + void sendMessage(Sender::Message* message, SocketAddress destination, Message::Options options = Message::Options::NONE); void cancelMessage(Sender::Message* message); void dropMessage(Sender::Message* message); diff --git a/src/SenderTest.cc b/src/SenderTest.cc index fdae6ab..244a7c9 100644 --- a/src/SenderTest.cc +++ b/src/SenderTest.cc @@ -36,13 +36,13 @@ class SenderTest : public ::testing::Test { public: SenderTest() : mockDriver() - , mockPacket(&payload) + , mockPacket {&payload} , mockPolicyManager(&mockDriver) , sender() , savedLogPolicy(Debug::getLogPolicy()) { ON_CALL(mockDriver, getBandwidth).WillByDefault(Return(8000)); - ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(1027)); + ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(1031)); ON_CALL(mockDriver, getQueuedBytes).WillByDefault(Return(0)); Debug::setLogPolicy( Debug::logPolicyFromString("src/ObjectPool@SILENT")); @@ -59,7 +59,7 @@ class SenderTest : public ::testing::Test { } NiceMock mockDriver; - NiceMock mockPacket; + Homa::Mock::MockDriver::MockPacket mockPacket; NiceMock mockPolicyManager; char payload[1028]; Sender* sender; @@ -124,7 +124,7 @@ TEST_F(SenderTest, allocMessage) { EXPECT_EQ(0U, sender->messageAllocator.pool.outstandingObjects); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); EXPECT_EQ(1U, sender->messageAllocator.pool.outstandingObjects); } @@ -132,7 +132,7 @@ TEST_F(SenderTest, handleDonePacket_basic) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); EXPECT_NE(Homa::OutMessage::Status::COMPLETED, message->state); Protocol::Packet::DoneHeader* header = @@ -143,7 +143,7 @@ TEST_F(SenderTest, handleDonePacket_basic) .Times(2); // No message. - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); EXPECT_NE(Homa::OutMessage::Status::COMPLETED, message->state); @@ -151,7 +151,7 @@ TEST_F(SenderTest, handleDonePacket_basic) message->state = Homa::OutMessage::Status::SENT; // Normal expected behavior. - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); EXPECT_EQ(nullptr, message->messageTimeout.node.list); EXPECT_EQ(nullptr, message->pingTimeout.node.list); @@ -162,7 +162,7 @@ TEST_F(SenderTest, handleDonePacket_CANCELED) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); addMessage(sender, id, message); message->state = Homa::OutMessage::Status::CANCELED; @@ -173,14 +173,14 @@ TEST_F(SenderTest, handleDonePacket_CANCELED) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); } TEST_F(SenderTest, handleDonePacket_COMPLETED) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); addMessage(sender, id, message); message->state = Homa::OutMessage::Status::COMPLETED; @@ -194,7 +194,7 @@ TEST_F(SenderTest, handleDonePacket_COMPLETED) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -211,7 +211,7 @@ TEST_F(SenderTest, handleDonePacket_FAILED) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); addMessage(sender, id, message); message->state = Homa::OutMessage::Status::FAILED; @@ -225,7 +225,7 @@ TEST_F(SenderTest, handleDonePacket_FAILED) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -244,7 +244,7 @@ TEST_F(SenderTest, handleDonePacket_IN_PROGRESS) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); addMessage(sender, id, message); message->state = Homa::OutMessage::Status::IN_PROGRESS; @@ -258,7 +258,7 @@ TEST_F(SenderTest, handleDonePacket_IN_PROGRESS) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -277,7 +277,7 @@ TEST_F(SenderTest, handleDonePacket_NO_STARTED) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); addMessage(sender, id, message); message->state = Homa::OutMessage::Status::NOT_STARTED; @@ -291,7 +291,7 @@ TEST_F(SenderTest, handleDonePacket_NO_STARTED) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -310,10 +310,12 @@ TEST_F(SenderTest, handleResendPacket_basic) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); std::vector packets; + std::vector priorities; for (int i = 0; i < 10; ++i) { - packets.push_back(new Homa::Mock::MockDriver::MockPacket(payload)); + packets.push_back(new Homa::Mock::MockDriver::MockPacket {payload}); + priorities.push_back(0); setMessagePacket(message, i, packets[i]); } SenderTest::addMessage(sender, id, message, true, 5); @@ -331,22 +333,24 @@ TEST_F(SenderTest, handleResendPacket_basic) resendHdr->priority = 4; EXPECT_CALL(mockPolicyManager, getResendPriority).WillOnce(Return(7)); - EXPECT_CALL(mockDriver, sendPacket(Eq(packets[3]))).Times(1); - EXPECT_CALL(mockDriver, sendPacket(Eq(packets[4]))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(packets[3]), _, _)).WillOnce( + [&priorities] (auto _1, auto _2, int p) { priorities[3] = p; }); + EXPECT_CALL(mockDriver, sendPacket(Eq(packets[4]), _, _)).WillOnce( + [&priorities] (auto _1, auto _2, int p) { priorities[4] = p; }); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleResendPacket(&mockPacket, &mockDriver); + sender->handleResendPacket(&mockPacket); EXPECT_EQ(5U, info->packetsSent); EXPECT_EQ(8U, info->packetsGranted); EXPECT_EQ(4, info->priority); EXPECT_EQ(11000U, message->messageTimeout.expirationCycleTime); EXPECT_EQ(10100U, message->pingTimeout.expirationCycleTime); - EXPECT_EQ(0, packets[2]->priority); - EXPECT_EQ(7, packets[3]->priority); - EXPECT_EQ(7, packets[4]->priority); - EXPECT_EQ(0, packets[5]->priority); + EXPECT_EQ(0, priorities[2]); + EXPECT_EQ(7, priorities[3]); + EXPECT_EQ(7, priorities[4]); + EXPECT_EQ(0, priorities[5]); EXPECT_TRUE(sender->sendReady.load()); for (int i = 0; i < 10; ++i) { @@ -366,7 +370,7 @@ TEST_F(SenderTest, handleResendPacket_staleResend) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleResendPacket(&mockPacket, &mockDriver); + sender->handleResendPacket(&mockPacket); } TEST_F(SenderTest, handleResendPacket_badRequest_singlePacketMessage) @@ -374,10 +378,10 @@ TEST_F(SenderTest, handleResendPacket_badRequest_singlePacketMessage) Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); Homa::Mock::MockDriver::MockPacket* packet = - new Homa::Mock::MockDriver::MockPacket(payload); + new Homa::Mock::MockDriver::MockPacket {payload}; setMessagePacket(message, 0, packet); Protocol::Packet::ResendHeader* resendHdr = @@ -393,7 +397,7 @@ TEST_F(SenderTest, handleResendPacket_badRequest_singlePacketMessage) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleResendPacket(&mockPacket, &mockDriver); + sender->handleResendPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -414,10 +418,10 @@ TEST_F(SenderTest, handleResendPacket_badRequest_outOfRange) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); std::vector packets; for (int i = 0; i < 10; ++i) { - packets.push_back(new Homa::Mock::MockDriver::MockPacket(payload)); + packets.push_back(new Homa::Mock::MockDriver::MockPacket {payload}); setMessagePacket(message, i, packets[i]); } SenderTest::addMessage(sender, id, message, true, 5); @@ -440,7 +444,7 @@ TEST_F(SenderTest, handleResendPacket_badRequest_outOfRange) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleResendPacket(&mockPacket, &mockDriver); + sender->handleResendPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -464,9 +468,9 @@ TEST_F(SenderTest, handleResendPacket_eagerResend) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); char data[1028]; - Homa::Mock::MockDriver::MockPacket dataPacket(data); + Homa::Mock::MockDriver::MockPacket dataPacket {data}; for (int i = 0; i < 10; ++i) { setMessagePacket(message, i, &dataPacket); } @@ -484,18 +488,18 @@ TEST_F(SenderTest, handleResendPacket_eagerResend) // Expect the BUSY control packet. char busy[1028]; - Homa::Mock::MockDriver::MockPacket busyPacket(busy); + Homa::Mock::MockDriver::MockPacket busyPacket {busy}; EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&busyPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&busyPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&busyPacket), _, _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&busyPacket), Eq(1))) .Times(1); // Expect no data to be sent but the RESEND packet to be release. - EXPECT_CALL(mockDriver, sendPacket(Eq(&dataPacket))).Times(0); + EXPECT_CALL(mockDriver, sendPacket(Eq(&dataPacket), _, _)).Times(0); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleResendPacket(&mockPacket, &mockDriver); + sender->handleResendPacket(&mockPacket); EXPECT_EQ(5U, info->packetsSent); EXPECT_EQ(8U, info->packetsGranted); @@ -511,7 +515,7 @@ TEST_F(SenderTest, handleGrantPacket_basic) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message, true, 5); message->numPackets = 10; message->state = Homa::OutMessage::Status::IN_PROGRESS; @@ -530,7 +534,7 @@ TEST_F(SenderTest, handleGrantPacket_basic) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleGrantPacket(&mockPacket, &mockDriver); + sender->handleGrantPacket(&mockPacket); EXPECT_EQ(7, info->packetsGranted); EXPECT_EQ(6, info->priority); @@ -543,7 +547,7 @@ TEST_F(SenderTest, handleGrantPacket_excessiveGrant) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message, true, 5); message->numPackets = 10; message->state = Homa::OutMessage::Status::IN_PROGRESS; @@ -565,7 +569,7 @@ TEST_F(SenderTest, handleGrantPacket_excessiveGrant) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleGrantPacket(&mockPacket, &mockDriver); + sender->handleGrantPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -590,7 +594,7 @@ TEST_F(SenderTest, handleGrantPacket_staleGrant) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message, true, 5); message->numPackets = 10; Sender::QueuedMessageInfo* info = &message->queuedMessageInfo; @@ -608,7 +612,7 @@ TEST_F(SenderTest, handleGrantPacket_staleGrant) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleGrantPacket(&mockPacket, &mockDriver); + sender->handleGrantPacket(&mockPacket); EXPECT_EQ(5, info->packetsGranted); EXPECT_EQ(2, info->priority); @@ -628,23 +632,23 @@ TEST_F(SenderTest, handleGrantPacket_dropGrant) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleGrantPacket(&mockPacket, &mockDriver); + sender->handleGrantPacket(&mockPacket); } TEST_F(SenderTest, handleUnknownPacket_basic) { Protocol::MessageId id = {42, 1}; - Driver::Address destination = (Driver::Address)22; + SocketAddress destination = {22, 60001}; Core::Policy::Unscheduled policyOld = {1, 2000, 1}; Core::Policy::Unscheduled policyNew = {2, 3000, 2}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); std::vector packets; char payload[5][1028]; for (int i = 0; i < 5; ++i) { Homa::Mock::MockDriver::MockPacket* packet = - new Homa::Mock::MockDriver::MockPacket(payload[i]); + new Homa::Mock::MockDriver::MockPacket {payload[i]}; Protocol::Packet::DataHeader* header = static_cast(packet->payload); header->policyVersion = policyOld.version; @@ -674,12 +678,12 @@ TEST_F(SenderTest, handleUnknownPacket_basic) EXPECT_CALL( mockPolicyManager, - getUnscheduledPolicy(Eq(destination), Eq(message->messageLength))) + getUnscheduledPolicy(Eq(destination.ip), Eq(message->messageLength))) .WillOnce(Return(policyNew)); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleUnknownPacket(&mockPacket, &mockDriver); + sender->handleUnknownPacket(&mockPacket); EXPECT_EQ(Homa::OutMessage::Status::IN_PROGRESS, message->state); for (int i = 0; i < 3; ++i) { @@ -706,13 +710,13 @@ TEST_F(SenderTest, handleUnknownPacket_basic) TEST_F(SenderTest, handleUnknownPacket_singlePacketMessage) { Protocol::MessageId id = {42, 1}; - Driver::Address destination = (Driver::Address)22; + SocketAddress destination = {22, 60001}; Core::Policy::Unscheduled policyOld = {1, 2000, 1}; Core::Policy::Unscheduled policyNew = {2, 3000, 2}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); - Homa::Mock::MockDriver::MockPacket dataPacket(payload); + dynamic_cast(sender->allocMessage(0)); + Homa::Mock::MockDriver::MockPacket dataPacket {payload}; Protocol::Packet::DataHeader* dataHeader = static_cast(dataPacket.payload); dataHeader->policyVersion = policyOld.version; @@ -733,13 +737,13 @@ TEST_F(SenderTest, handleUnknownPacket_singlePacketMessage) EXPECT_CALL( mockPolicyManager, - getUnscheduledPolicy(Eq(destination), Eq(message->messageLength))) + getUnscheduledPolicy(Eq(destination.ip), Eq(message->messageLength))) .WillOnce(Return(policyNew)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&dataPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&dataPacket), _, _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleUnknownPacket(&mockPacket, &mockDriver); + sender->handleUnknownPacket(&mockPacket); EXPECT_EQ(Homa::OutMessage::Status::SENT, message->state); EXPECT_EQ(policyNew.version, dataHeader->policyVersion); @@ -754,7 +758,7 @@ TEST_F(SenderTest, handleUnknownPacket_singlePacketMessage) TEST_F(SenderTest, handleUnknownPacket_NO_RETRY) { Protocol::MessageId id = {42, 1}; - Driver::Address destination = (Driver::Address)22; + SocketAddress destination = {22, 60001}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = @@ -764,7 +768,7 @@ TEST_F(SenderTest, handleUnknownPacket_NO_RETRY) char payload[5][1028]; for (int i = 0; i < 5; ++i) { Homa::Mock::MockDriver::MockPacket* packet = - new Homa::Mock::MockDriver::MockPacket(payload[i]); + new Homa::Mock::MockDriver::MockPacket {payload[i]}; packets.push_back(packet); setMessagePacket(message, i, packet); } @@ -806,14 +810,14 @@ TEST_F(SenderTest, handleUnknownPacket_no_message) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleUnknownPacket(&mockPacket, &mockDriver); + sender->handleUnknownPacket(&mockPacket); } TEST_F(SenderTest, handleUnknownPacket_done) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); message->state.store(Homa::OutMessage::Status::COMPLETED); EXPECT_EQ(0U, message->messageTimeout.expirationCycleTime); @@ -826,7 +830,7 @@ TEST_F(SenderTest, handleUnknownPacket_done) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleUnknownPacket(&mockPacket, &mockDriver); + sender->handleUnknownPacket(&mockPacket); EXPECT_EQ(Homa::OutMessage::Status::COMPLETED, message->state); EXPECT_EQ(0U, message->messageTimeout.expirationCycleTime); @@ -838,7 +842,7 @@ TEST_F(SenderTest, handleErrorPacket_basic) Protocol::MessageId id = {42, 1}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); bucket->messageTimeouts.setTimeout(&message->messageTimeout); bucket->pingTimeouts.setTimeout(&message->pingTimeout); @@ -851,7 +855,7 @@ TEST_F(SenderTest, handleErrorPacket_basic) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); EXPECT_EQ(nullptr, message->messageTimeout.node.list); EXPECT_EQ(nullptr, message->pingTimeout.node.list); @@ -863,7 +867,7 @@ TEST_F(SenderTest, handleErrorPacket_CANCELED) Protocol::MessageId id = {42, 1}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); message->state.store(Homa::OutMessage::Status::CANCELED); @@ -874,7 +878,7 @@ TEST_F(SenderTest, handleErrorPacket_CANCELED) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); EXPECT_EQ(Homa::OutMessage::Status::CANCELED, message->state); } @@ -884,7 +888,7 @@ TEST_F(SenderTest, handleErrorPacket_NOT_STARTED) Protocol::MessageId id = {42, 1}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); message->state.store(Homa::OutMessage::Status::NOT_STARTED); @@ -898,7 +902,7 @@ TEST_F(SenderTest, handleErrorPacket_NOT_STARTED) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -920,7 +924,7 @@ TEST_F(SenderTest, handleErrorPacket_IN_PROGRESS) Protocol::MessageId id = {42, 1}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); message->state.store(Homa::OutMessage::Status::IN_PROGRESS); @@ -934,7 +938,7 @@ TEST_F(SenderTest, handleErrorPacket_IN_PROGRESS) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -956,7 +960,7 @@ TEST_F(SenderTest, handleErrorPacket_COMPLETED) Protocol::MessageId id = {42, 1}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); message->state.store(Homa::OutMessage::Status::COMPLETED); @@ -970,7 +974,7 @@ TEST_F(SenderTest, handleErrorPacket_COMPLETED) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -992,7 +996,7 @@ TEST_F(SenderTest, handleErrorPacket_FAILED) Protocol::MessageId id = {42, 1}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); message->state.store(Homa::OutMessage::Status::FAILED); @@ -1006,7 +1010,7 @@ TEST_F(SenderTest, handleErrorPacket_FAILED) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -1029,7 +1033,7 @@ TEST_F(SenderTest, handleErrorPacket_noMessage) header->common.messageId = id; EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); } TEST_F(SenderTest, poll) @@ -1040,7 +1044,7 @@ TEST_F(SenderTest, poll) TEST_F(SenderTest, checkTimeouts) { - Sender::Message message(sender, &mockDriver); + Sender::Message message(sender, 0); Sender::MessageBucket* bucket = sender->messageBuckets.buckets.at(0); bucket->pingTimeouts.setTimeout(&message.pingTimeout); bucket->messageTimeouts.setTimeout(&message.messageTimeout); @@ -1065,7 +1069,7 @@ TEST_F(SenderTest, Message_destructor) const int MAX_RAW_PACKET_LENGTH = 2000; ON_CALL(mockDriver, getMaxPayloadSize) .WillByDefault(Return(MAX_RAW_PACKET_LENGTH)); - Sender::Message* msg = new Sender::Message(sender, &mockDriver); + Sender::Message* msg = new Sender::Message(sender, 0); const uint16_t NUM_PKTS = 5; @@ -1086,10 +1090,10 @@ TEST_F(SenderTest, Message_append_basic) ON_CALL(mockDriver, getMaxPayloadSize) .WillByDefault(Return(MAX_RAW_PACKET_LENGTH)); - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); char buf[2 * MAX_RAW_PACKET_LENGTH]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + MAX_RAW_PACKET_LENGTH); + Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1 {buf + MAX_RAW_PACKET_LENGTH}; const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1126,10 +1130,10 @@ TEST_F(SenderTest, Message_append_truncated) ON_CALL(mockDriver, getMaxPayloadSize) .WillByDefault(Return(MAX_RAW_PACKET_LENGTH)); - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); char buf[2 * MAX_RAW_PACKET_LENGTH]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + MAX_RAW_PACKET_LENGTH); + Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1 {buf + MAX_RAW_PACKET_LENGTH}; const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1155,7 +1159,7 @@ TEST_F(SenderTest, Message_append_truncated) EXPECT_STREQ("append", m.function); EXPECT_EQ(int(Debug::LogLevel::WARNING), m.logLevel); EXPECT_EQ( - "Max message size limit (2020352B) reached; 7 of 14 bytes appended", + "Max message size limit (2016256B) reached; 7 of 14 bytes appended", m.message); Debug::setLogHandler(std::function()); @@ -1183,10 +1187,10 @@ TEST_F(SenderTest, Message_length) TEST_F(SenderTest, Message_prepend) { ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(2048)); - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + 2048); + Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1 {buf + 2048}; const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1218,10 +1222,10 @@ TEST_F(SenderTest, Message_release) TEST_F(SenderTest, Message_reserve) { - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + 2048); + Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1 {buf + 2048}; const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1259,7 +1263,7 @@ TEST_F(SenderTest, Message_send) TEST_F(SenderTest, Message_getPacket) { - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); Driver::Packet* packet = (Driver::Packet*)42; msg.packets[0] = packet; @@ -1273,10 +1277,10 @@ TEST_F(SenderTest, Message_getPacket) TEST_F(SenderTest, Message_getOrAllocPacket) { // TODO(cstlee): cleanup - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + 2048); + Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1 {buf + 2048}; EXPECT_FALSE(msg.occupied.test(0)); EXPECT_EQ(0U, msg.numPackets); @@ -1298,9 +1302,9 @@ TEST_F(SenderTest, MessageBucket_findMessage) Sender::MessageBucket* bucket = sender->messageBuckets.buckets.at(0); Sender::Message* msg0 = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); Sender::Message* msg1 = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); msg0->id = {42, 0}; msg1->id = {42, 1}; Protocol::MessageId id_none = {42, 42}; @@ -1329,35 +1333,42 @@ TEST_F(SenderTest, sendMessage_basic) { Protocol::MessageId id = {sender->transportId, sender->nextMessageSequenceNumber}; + uint16_t sport = 0; + uint16_t dport = 60001; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(sport)); Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); setMessagePacket(message, 0, &mockPacket); message->messageLength = 420; mockPacket.length = message->messageLength + message->TRANSPORT_HEADER_LENGTH; - Driver::Address destination = (Driver::Address)22; + SocketAddress destination = {22, dport}; Core::Policy::Unscheduled policy = {1, 3000, 2}; EXPECT_FALSE(bucket->messages.contains(&message->bucketNode)); EXPECT_CALL(mockPolicyManager, - getUnscheduledPolicy(Eq(destination), Eq(420))) + getUnscheduledPolicy(Eq(destination.ip), Eq(420))) .WillOnce(Return(policy)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + int mockPriority = 0; + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), Eq(destination.ip), _)) + .WillOnce([&mockPriority] (auto _1, auto _2, int p){mockPriority = p;}); sender->sendMessage(message, destination, Sender::Message::Options::NO_RETRY); // Check Message metadata EXPECT_EQ(id, message->id); - EXPECT_EQ(destination, message->destination); + EXPECT_EQ(destination.ip, message->destination.ip); + EXPECT_EQ(destination.port, message->destination.port); EXPECT_EQ(Sender::Message::Options::NO_RETRY, message->options); // Check packet metadata Protocol::Packet::DataHeader* header = static_cast(mockPacket.payload); + EXPECT_EQ(htobe16(sport), header->common.sport); + EXPECT_EQ(htobe16(dport), header->common.dport); EXPECT_EQ(id, header->common.messageId); EXPECT_EQ(420U, header->totalLength); EXPECT_EQ(policy.version, header->policyVersion); @@ -1370,8 +1381,7 @@ TEST_F(SenderTest, sendMessage_basic) EXPECT_EQ(10100U, message->pingTimeout.expirationCycleTime); // Check sent packet metadata - EXPECT_EQ(22U, (uint64_t)mockPacket.address); - EXPECT_EQ(policy.priority, mockPacket.priority); + EXPECT_EQ(policy.priority, mockPriority); EXPECT_EQ(Homa::OutMessage::Status::SENT, message->state); EXPECT_FALSE(sender->sendReady.load()); @@ -1381,48 +1391,48 @@ TEST_F(SenderTest, sendMessage_multipacket) { char payload0[1027]; char payload1[1027]; - NiceMock packet0(payload0); - NiceMock packet1(payload1); + Homa::Mock::MockDriver::MockPacket packet0 {payload0}; + Homa::Mock::MockDriver::MockPacket packet1 {payload1}; Protocol::MessageId id = {sender->transportId, sender->nextMessageSequenceNumber}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); setMessagePacket(message, 0, &packet0); setMessagePacket(message, 1, &packet1); message->messageLength = 1420; - packet0.length = 1000 + 27; - packet1.length = 420 + 27; - Driver::Address destination = (Driver::Address)22; + packet0.length = 1000 + 31; + packet1.length = 420 + 31; + SocketAddress destination = {22, 60001}; Core::Policy::Unscheduled policy = {1, 1000, 2}; - EXPECT_EQ(27U, sizeof(Protocol::Packet::DataHeader)); + EXPECT_EQ(31U, sizeof(Protocol::Packet::DataHeader)); EXPECT_EQ(1000U, message->PACKET_DATA_LENGTH); EXPECT_CALL(mockPolicyManager, - getUnscheduledPolicy(Eq(destination), Eq(1420))) + getUnscheduledPolicy(Eq(destination.ip), Eq(1420))) .WillOnce(Return(policy)); sender->sendMessage(message, destination); // Check Message metadata EXPECT_EQ(id, message->id); - EXPECT_EQ(destination, message->destination); + EXPECT_EQ(destination.ip, message->destination.ip); + EXPECT_EQ(destination.port, message->destination.port); EXPECT_EQ(Homa::OutMessage::Status::IN_PROGRESS, message->state); // Check packet metadata Protocol::Packet::DataHeader* header = nullptr; // Packet0 - EXPECT_EQ(22U, (uint64_t)packet0.address); header = static_cast(packet0.payload); EXPECT_EQ(message->id, header->common.messageId); EXPECT_EQ(message->messageLength, header->totalLength); // Packet1 - EXPECT_EQ(22U, (uint64_t)packet1.address); header = static_cast(packet1.payload); EXPECT_EQ(message->id, header->common.messageId); - EXPECT_EQ(destination, message->destination); + EXPECT_EQ(destination.ip, message->destination.ip); + EXPECT_EQ(destination.port, message->destination.port); EXPECT_EQ(message->messageLength, header->totalLength); // Check Sender metadata @@ -1441,13 +1451,13 @@ TEST_F(SenderTest, sendMessage_missingPacket) Protocol::MessageId id = {sender->transportId, sender->nextMessageSequenceNumber}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); setMessagePacket(message, 1, &mockPacket); Core::Policy::Unscheduled policy = {1, 1000, 2}; ON_CALL(mockPolicyManager, getUnscheduledPolicy(_, _)) .WillByDefault(Return(policy)); - EXPECT_DEATH(sender->sendMessage(message, Driver::Address()), + EXPECT_DEATH(sender->sendMessage(message, SocketAddress{0, 0}), ".*Incomplete message with id \\(22:1\\); missing packet at " "offset 0; this shouldn't happen.*"); } @@ -1457,17 +1467,17 @@ TEST_F(SenderTest, sendMessage_unscheduledLimit) Protocol::MessageId id = {sender->transportId, sender->nextMessageSequenceNumber}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); for (int i = 0; i < 9; ++i) { setMessagePacket(message, i, &mockPacket); } message->messageLength = 9000; mockPacket.length = 1000 + sizeof(Protocol::Packet::DataHeader); - Driver::Address destination = (Driver::Address)22; + SocketAddress destination = {22, 60001}; Core::Policy::Unscheduled policy = {1, 4500, 2}; EXPECT_EQ(9U, message->numPackets); EXPECT_EQ(1000U, message->PACKET_DATA_LENGTH); - EXPECT_CALL(mockPolicyManager, getUnscheduledPolicy(destination, 9000)) + EXPECT_CALL(mockPolicyManager, getUnscheduledPolicy(destination.ip, 9000)) .WillOnce(Return(policy)); sender->sendMessage(message, destination); @@ -1481,7 +1491,7 @@ TEST_F(SenderTest, cancelMessage) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message, true, 5); Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); bucket->pingTimeouts.setTimeout(&message->pingTimeout); @@ -1505,7 +1515,7 @@ TEST_F(SenderTest, cancelMessage) TEST_F(SenderTest, dropMessage) { Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); EXPECT_EQ(1U, sender->messageAllocator.pool.outstandingObjects); sender->dropMessage(message); @@ -1518,7 +1528,7 @@ TEST_F(SenderTest, checkMessageTimeouts_basic) Sender::Message* message[4]; for (uint64_t i = 0; i < 4; ++i) { Protocol::MessageId id = {42, 10 + i}; - message[i] = dynamic_cast(sender->allocMessage()); + message[i] = dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message[i]); Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); bucket->messageTimeouts.setTimeout(&message[i]->messageTimeout); @@ -1581,7 +1591,7 @@ TEST_F(SenderTest, checkPingTimeouts_basic) Sender::Message* message[5]; for (uint64_t i = 0; i < 5; ++i) { Protocol::MessageId id = {42, 10 + i}; - message[i] = dynamic_cast(sender->allocMessage()); + message[i] = dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message[i]); Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); bucket->pingTimeouts.setTimeout(&message[i]->pingTimeout); @@ -1606,7 +1616,7 @@ TEST_F(SenderTest, checkPingTimeouts_basic) EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), _, _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); @@ -1645,7 +1655,7 @@ TEST_F(SenderTest, trySend_basic) { Protocol::MessageId id = {42, 10}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); Sender::QueuedMessageInfo* info = &message->queuedMessageInfo; SenderTest::addMessage(sender, id, message, true, 3); Homa::Mock::MockDriver::MockPacket* packet[5]; @@ -1653,7 +1663,7 @@ TEST_F(SenderTest, trySend_basic) const uint32_t PACKET_DATA_SIZE = PACKET_SIZE - message->TRANSPORT_HEADER_LENGTH; for (int i = 0; i < 5; ++i) { - packet[i] = new Homa::Mock::MockDriver::MockPacket(payload); + packet[i] = new Homa::Mock::MockDriver::MockPacket {payload}; packet[i]->length = PACKET_SIZE; setMessagePacket(message, i, packet[i]); info->unsentBytes += PACKET_DATA_SIZE; @@ -1668,8 +1678,8 @@ TEST_F(SenderTest, trySend_basic) EXPECT_TRUE(sender->sendQueue.contains(&info->sendQueueNode)); // 3 granted packets; 2 will send; queue limit reached. - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[0]))); - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[1]))); + EXPECT_CALL(mockDriver, sendPacket(Eq(packet[0]), _, _)); + EXPECT_CALL(mockDriver, sendPacket(Eq(packet[1]), _, _)); sender->trySend(); // < test call EXPECT_TRUE(sender->sendReady); EXPECT_EQ(Homa::OutMessage::Status::IN_PROGRESS, message->state); @@ -1681,7 +1691,7 @@ TEST_F(SenderTest, trySend_basic) Mock::VerifyAndClearExpectations(&mockDriver); // 1 packet to be sent; grant limit reached. - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[2]))); + EXPECT_CALL(mockDriver, sendPacket(Eq(packet[2]), _, _)); sender->trySend(); // < test call EXPECT_FALSE(sender->sendReady); EXPECT_EQ(Homa::OutMessage::Status::IN_PROGRESS, message->state); @@ -1708,8 +1718,8 @@ TEST_F(SenderTest, trySend_basic) // 2 more granted packets; will finish. info->packetsGranted = 5; sender->sendReady = true; - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[3]))); - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[4]))); + EXPECT_CALL(mockDriver, sendPacket(Eq(packet[3]), _, _)); + EXPECT_CALL(mockDriver, sendPacket(Eq(packet[4]), _, _)); sender->trySend(); // < test call EXPECT_FALSE(sender->sendReady); EXPECT_EQ(Homa::OutMessage::Status::SENT, message->state); @@ -1732,10 +1742,10 @@ TEST_F(SenderTest, trySend_multipleMessages) Homa::Mock::MockDriver::MockPacket* packet[3]; for (uint64_t i = 0; i < 3; ++i) { Protocol::MessageId id = {22, 10 + i}; - message[i] = dynamic_cast(sender->allocMessage()); + message[i] = dynamic_cast(sender->allocMessage(0)); info[i] = &message[i]->queuedMessageInfo; SenderTest::addMessage(sender, id, message[i], true, 1); - packet[i] = new Homa::Mock::MockDriver::MockPacket(payload); + packet[i] = new Homa::Mock::MockDriver::MockPacket {payload}; packet[i]->length = sender->driver->getMaxPayloadSize() / 4; setMessagePacket(message[i], 0, packet[i]); info[i]->unsentBytes += @@ -1758,9 +1768,9 @@ TEST_F(SenderTest, trySend_multipleMessages) EXPECT_EQ(1, info[2]->packetsGranted); info[2]->packetsSent = 0; - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[0]))); - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[1]))); - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[2]))); + EXPECT_CALL(mockDriver, sendPacket(Eq(packet[0]), _, _)); + EXPECT_CALL(mockDriver, sendPacket(Eq(packet[1]), _, _)); + EXPECT_CALL(mockDriver, sendPacket(Eq(packet[2]), _, _)); sender->trySend(); @@ -1779,7 +1789,7 @@ TEST_F(SenderTest, trySend_alreadyRunning) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); Sender::QueuedMessageInfo* info = &message->queuedMessageInfo; SenderTest::addMessage(sender, id, message, true, 1); setMessagePacket(message, 0, &mockPacket); diff --git a/src/TransportImpl.cc b/src/TransportImpl.cc index 310e099..d4ebc70 100644 --- a/src/TransportImpl.cc +++ b/src/TransportImpl.cc @@ -94,61 +94,66 @@ void TransportImpl::processPackets() { // Keep track of time spent doing active processing versus idle. - Perf::Timer activityTimer; - activityTimer.split(); - uint64_t activeTime = 0; - uint64_t idleTime = 0; + uint64_t cycles = PerfUtils::Cycles::rdtsc(); const int MAX_BURST = 32; Driver::Packet* packets[MAX_BURST]; int numPackets = driver->receivePackets(MAX_BURST, packets); for (int i = 0; i < numPackets; ++i) { Driver::Packet* packet = packets[i]; - assert(packet->length >= - Util::downCast(sizeof(Protocol::Packet::CommonHeader))); - Perf::counters.rx_bytes.add(packet->length); - Protocol::Packet::CommonHeader* header = - static_cast(packet->payload); - switch (header->opcode) { - case Protocol::Packet::DATA: - Perf::counters.rx_data_pkts.add(1); - receiver->handleDataPacket(packet, driver); - break; - case Protocol::Packet::GRANT: - Perf::counters.rx_grant_pkts.add(1); - sender->handleGrantPacket(packet, driver); - break; - case Protocol::Packet::DONE: - Perf::counters.rx_done_pkts.add(1); - sender->handleDonePacket(packet, driver); - break; - case Protocol::Packet::RESEND: - Perf::counters.rx_resend_pkts.add(1); - sender->handleResendPacket(packet, driver); - break; - case Protocol::Packet::BUSY: - Perf::counters.rx_busy_pkts.add(1); - receiver->handleBusyPacket(packet, driver); - break; - case Protocol::Packet::PING: - Perf::counters.rx_ping_pkts.add(1); - receiver->handlePingPacket(packet, driver); - break; - case Protocol::Packet::UNKNOWN: - Perf::counters.rx_unknown_pkts.add(1); - sender->handleUnknownPacket(packet, driver); - break; - case Protocol::Packet::ERROR: - Perf::counters.rx_error_pkts.add(1); - sender->handleErrorPacket(packet, driver); - break; - } - activeTime += activityTimer.split(); + processPacket(packet, packet->sourceIp); + } + + cycles = PerfUtils::Cycles::rdtsc() - cycles; + if (numPackets > 0) { + Perf::counters.active_cycles.add(cycles); + } else { + Perf::counters.idle_cycles.add(cycles); } - idleTime += activityTimer.split(); +} - Perf::counters.active_cycles.add(activeTime); - Perf::counters.idle_cycles.add(idleTime); +void +TransportImpl::processPacket(Driver::Packet* packet, IpAddress sourceIp) +{ + assert(packet->length >= + Util::downCast(sizeof(Protocol::Packet::CommonHeader))); + Perf::counters.rx_bytes.add(packet->length); + Protocol::Packet::CommonHeader* header = + static_cast(packet->payload); + switch (header->opcode) { + case Protocol::Packet::DATA: + Perf::counters.rx_data_pkts.add(1); + receiver->handleDataPacket(packet, sourceIp); + break; + case Protocol::Packet::GRANT: + Perf::counters.rx_grant_pkts.add(1); + sender->handleGrantPacket(packet); + break; + case Protocol::Packet::DONE: + Perf::counters.rx_done_pkts.add(1); + sender->handleDonePacket(packet); + break; + case Protocol::Packet::RESEND: + Perf::counters.rx_resend_pkts.add(1); + sender->handleResendPacket(packet); + break; + case Protocol::Packet::BUSY: + Perf::counters.rx_busy_pkts.add(1); + receiver->handleBusyPacket(packet); + break; + case Protocol::Packet::PING: + Perf::counters.rx_ping_pkts.add(1); + receiver->handlePingPacket(packet, sourceIp); + break; + case Protocol::Packet::UNKNOWN: + Perf::counters.rx_unknown_pkts.add(1); + sender->handleUnknownPacket(packet); + break; + case Protocol::Packet::ERROR: + Perf::counters.rx_error_pkts.add(1); + sender->handleErrorPacket(packet); + break; + } } } // namespace Core diff --git a/src/TransportImpl.h b/src/TransportImpl.h index 2d559be..ad46f99 100644 --- a/src/TransportImpl.h +++ b/src/TransportImpl.h @@ -47,9 +47,10 @@ class TransportImpl : public Transport { ~TransportImpl(); /// See Homa::Transport::alloc() - virtual Homa::unique_ptr alloc() + virtual Homa::unique_ptr alloc(uint16_t sourcePort) { - return Homa::unique_ptr(sender->allocMessage()); + Homa::OutMessage* outMessage = sender->allocMessage(sourcePort); + return Homa::unique_ptr(outMessage); } /// See Homa::Transport::receive() @@ -74,6 +75,7 @@ class TransportImpl : public Transport { private: void processPackets(); + void processPacket(Driver::Packet* packet, IpAddress source); /// Unique identifier for this transport. const std::atomic transportId; diff --git a/src/TransportImplTest.cc b/src/TransportImplTest.cc index 0e0ab60..c69a36a 100644 --- a/src/TransportImplTest.cc +++ b/src/TransportImplTest.cc @@ -27,6 +27,7 @@ namespace Homa { namespace Core { namespace { +using ::testing::_; using ::testing::DoAll; using ::testing::Eq; using ::testing::NiceMock; @@ -101,68 +102,60 @@ TEST_F(TransportImplTest, processPackets) Homa::Driver::Packet* packets[8]; // Set DATA packet - Homa::Mock::MockDriver::MockPacket dataPacket(payload[0], 1024); + Homa::Mock::MockDriver::MockPacket dataPacket {payload[0], 1024}; static_cast(dataPacket.payload) ->common.opcode = Protocol::Packet::DATA; packets[0] = &dataPacket; - EXPECT_CALL(*mockReceiver, - handleDataPacket(Eq(&dataPacket), Eq(&mockDriver))); + EXPECT_CALL(*mockReceiver, handleDataPacket(Eq(&dataPacket), _)); // Set GRANT packet - Homa::Mock::MockDriver::MockPacket grantPacket(payload[1], 1024); + Homa::Mock::MockDriver::MockPacket grantPacket {payload[1], 1024}; static_cast(grantPacket.payload) ->common.opcode = Protocol::Packet::GRANT; packets[1] = &grantPacket; - EXPECT_CALL(*mockSender, - handleGrantPacket(Eq(&grantPacket), Eq(&mockDriver))); + EXPECT_CALL(*mockSender, handleGrantPacket(Eq(&grantPacket))); // Set DONE packet - Homa::Mock::MockDriver::MockPacket donePacket(payload[2], 1024); + Homa::Mock::MockDriver::MockPacket donePacket {payload[2], 1024}; static_cast(donePacket.payload) ->common.opcode = Protocol::Packet::DONE; packets[2] = &donePacket; - EXPECT_CALL(*mockSender, - handleDonePacket(Eq(&donePacket), Eq(&mockDriver))); + EXPECT_CALL(*mockSender, handleDonePacket(Eq(&donePacket))); // Set RESEND packet - Homa::Mock::MockDriver::MockPacket resendPacket(payload[3], 1024); + Homa::Mock::MockDriver::MockPacket resendPacket {payload[3], 1024}; static_cast(resendPacket.payload) ->common.opcode = Protocol::Packet::RESEND; packets[3] = &resendPacket; - EXPECT_CALL(*mockSender, - handleResendPacket(Eq(&resendPacket), Eq(&mockDriver))); + EXPECT_CALL(*mockSender, handleResendPacket(Eq(&resendPacket))); // Set BUSY packet - Homa::Mock::MockDriver::MockPacket busyPacket(payload[4], 1024); + Homa::Mock::MockDriver::MockPacket busyPacket {payload[4], 1024}; static_cast(busyPacket.payload) ->common.opcode = Protocol::Packet::BUSY; packets[4] = &busyPacket; - EXPECT_CALL(*mockReceiver, - handleBusyPacket(Eq(&busyPacket), Eq(&mockDriver))); + EXPECT_CALL(*mockReceiver, handleBusyPacket(Eq(&busyPacket))); // Set PING packet - Homa::Mock::MockDriver::MockPacket pingPacket(payload[5], 1024); + Homa::Mock::MockDriver::MockPacket pingPacket {payload[5], 1024}; static_cast(pingPacket.payload) ->common.opcode = Protocol::Packet::PING; packets[5] = &pingPacket; - EXPECT_CALL(*mockReceiver, - handlePingPacket(Eq(&pingPacket), Eq(&mockDriver))); + EXPECT_CALL(*mockReceiver, handlePingPacket(Eq(&pingPacket), _)); // Set UNKNOWN packet - Homa::Mock::MockDriver::MockPacket unknownPacket(payload[6], 1024); + Homa::Mock::MockDriver::MockPacket unknownPacket {payload[6], 1024}; static_cast(unknownPacket.payload) ->common.opcode = Protocol::Packet::UNKNOWN; packets[6] = &unknownPacket; - EXPECT_CALL(*mockSender, - handleUnknownPacket(Eq(&unknownPacket), Eq(&mockDriver))); + EXPECT_CALL(*mockSender, handleUnknownPacket(Eq(&unknownPacket))); // Set ERROR packet - Homa::Mock::MockDriver::MockPacket errorPacket(payload[7], 1024); + Homa::Mock::MockDriver::MockPacket errorPacket {payload[7], 1024}; static_cast(errorPacket.payload) ->common.opcode = Protocol::Packet::ERROR; packets[7] = &errorPacket; - EXPECT_CALL(*mockSender, - handleErrorPacket(Eq(&errorPacket), Eq(&mockDriver))); + EXPECT_CALL(*mockSender, handleErrorPacket(Eq(&errorPacket))); EXPECT_CALL(mockDriver, receivePackets) .WillOnce(DoAll(SetArrayArgument<1>(packets, packets + 8), Return(8))); diff --git a/test/system_test.cc b/test/system_test.cc index 8e43238..266d842 100644 --- a/test/system_test.cc +++ b/test/system_test.cc @@ -70,7 +70,7 @@ struct Node { }; void -serverMain(Node* server, std::vector addresses) +serverMain(Node* server, std::vector addresses) { while (true) { if (server->run.load() == false) { @@ -101,7 +101,7 @@ serverMain(Node* server, std::vector addresses) * Number of Op that failed. */ int -clientMain(int count, int size, std::vector addresses) +clientMain(int count, int size, std::vector addresses) { std::random_device rd; std::mt19937 gen(rd()); @@ -119,9 +119,9 @@ clientMain(int count, int size, std::vector addresses) payload[i] = randData(gen); } - std::string destAddress = addresses[randAddr(gen)]; + Homa::IpAddress destAddress = addresses[randAddr(gen)]; - Homa::unique_ptr message = client.transport->alloc(); + Homa::unique_ptr message = client.transport->alloc(0); { MessageHeader header; header.id = id; @@ -133,7 +133,7 @@ clientMain(int count, int size, std::vector addresses) << std::endl; } } - message->send(client.driver.getAddress(&destAddress)); + message->send(Homa::SocketAddress{destAddress, 60001}); while (1) { Homa::OutMessage::Status status = message->getStatus(); @@ -185,12 +185,11 @@ main(int argc, char* argv[]) Homa::Drivers::Fake::FakeNetworkConfig::setPacketLossRate(packetLossRate); uint64_t nextServerId = 101; - std::vector addresses; + std::vector addresses; std::vector servers; for (int i = 0; i < numServers; ++i) { Node* server = new Node(nextServerId++); - addresses.emplace_back(std::string( - server->driver.addressToString(server->driver.getLocalAddress()))); + addresses.emplace_back(server->driver.getLocalAddress()); servers.push_back(server); } From a305d46158b52eb61a79b6b55cc563b5cd084c7c Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Mon, 20 Jul 2020 22:00:51 -0700 Subject: [PATCH 03/33] Update DpdkDriver to use the new API introduced in the previous commit - Reenable DPDK in CMakeLists.txt - Initialize an ARP table at driver startup using the content of /proc/net/arp - Select eth port via the symbolic name of the network interface (e.g., eno1d1) (the current implementation uses ioctl to obtain the IP and MAC addresses of a network interface) - Add a system test for DPDK driver: test/dpdk_test.cc --- CMakeLists.txt | 71 +++++----- include/Homa/Drivers/DPDK/DpdkDriver.h | 31 ++--- include/Homa/Util.h | 2 + src/Drivers/DPDK/DpdkDriver.cc | 47 ++----- src/Drivers/DPDK/DpdkDriverImpl.cc | 180 ++++++++++++++++--------- src/Drivers/DPDK/DpdkDriverImpl.h | 53 +++++--- src/Drivers/DPDK/MacAddress.cc | 53 -------- src/Drivers/DPDK/MacAddress.h | 13 +- src/Drivers/DPDK/MacAddressTest.cc | 40 ------ src/Drivers/RawAddressType.h | 38 ------ src/Util.cc | 16 +++ test/CMakeLists.txt | 12 ++ test/Output.h | 105 +++++++++++++++ test/dpdk_test.cc | 88 ++++++++++++ 14 files changed, 435 insertions(+), 314 deletions(-) delete mode 100644 src/Drivers/RawAddressType.h create mode 100644 test/Output.h create mode 100644 test/dpdk_test.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 2f82962..4a6f9c1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,7 +13,7 @@ list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/modules) find_package(Doxygen OPTIONAL_COMPONENTS dot mscgen dia) # Network Interface library (https://www.dpdk.org/) -# find_package(Dpdk REQUIRED) +find_package(Dpdk REQUIRED) # Source control tool; needed to download external libraries. find_package(Git REQUIRED) @@ -135,34 +135,34 @@ target_compile_options(FakeDriver ) ## lib DpdkDriver ############################################################## -#add_library(DpdkDriver -# src/Drivers/DPDK/DpdkDriver.cc -# src/Drivers/DPDK/DpdkDriverImpl.cc -# src/Drivers/DPDK/MacAddress.cc -#) -#add_library(Homa::DpdkDriver ALIAS DpdkDriver) -#target_include_directories(DpdkDriver -# PUBLIC -# $ -# $ -# PRIVATE -# ${CMAKE_CURRENT_SOURCE_DIR}/src -#) -#target_link_libraries(DpdkDriver -# PRIVATE -# Dpdk::Dpdk -# PUBLIC -# Homa -#) -#target_compile_features(DpdkDriver -# PUBLIC -# cxx_std_11 -#) -#target_compile_options(DpdkDriver -# PRIVATE -# -Wall -# -Wextra -#) +add_library(DpdkDriver + src/Drivers/DPDK/DpdkDriver.cc + src/Drivers/DPDK/DpdkDriverImpl.cc + src/Drivers/DPDK/MacAddress.cc +) +add_library(Homa::DpdkDriver ALIAS DpdkDriver) +target_include_directories(DpdkDriver + PUBLIC + $ + $ + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/src +) +target_link_libraries(DpdkDriver + PRIVATE + Dpdk::Dpdk + PUBLIC + Homa +) +target_compile_features(DpdkDriver + PUBLIC + cxx_std_11 +) +target_compile_options(DpdkDriver + PRIVATE + -Wall + -Wextra +) ################################################################################ ## Tests ####################################################################### @@ -195,8 +195,7 @@ endif() ## Install & Export ############################################################ ################################################################################ -#install(TARGETS Homa DpdkDriver FakeDriver EXPORT HomaTargets -install(TARGETS Homa FakeDriver EXPORT HomaTargets +install(TARGETS Homa DpdkDriver FakeDriver EXPORT HomaTargets LIBRARY DESTINATION lib ARCHIVE DESTINATION lib RUNTIME DESTINATION bin @@ -275,11 +274,11 @@ target_sources(unit_test target_link_libraries(unit_test FakeDriver) #DPDK Tests -#target_sources(unit_test -# PUBLIC -# ${CMAKE_CURRENT_SOURCE_DIR}/src/Drivers/DPDK/MacAddressTest.cc -#) -#target_link_libraries(unit_test DpdkDriver) +target_sources(unit_test + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/src/Drivers/DPDK/MacAddressTest.cc +) +target_link_libraries(unit_test DpdkDriver) target_link_libraries(unit_test gmock_main) # -fno-access-control allows access to private members for testing diff --git a/include/Homa/Drivers/DPDK/DpdkDriver.h b/include/Homa/Drivers/DPDK/DpdkDriver.h index dafb05f..010d59b 100644 --- a/include/Homa/Drivers/DPDK/DpdkDriver.h +++ b/include/Homa/Drivers/DPDK/DpdkDriver.h @@ -53,14 +53,14 @@ class DpdkDriver : public Driver { * has exclusive access to DPDK. Note: This call will initialize the DPDK * EAL with default values. * - * @param port - * Selects which physical port to use for communication. + * @param ifname + * Selects which network interface to use for communication. * @param config * Optional configuration parameters (see Config). * @throw DriverInitFailure * Thrown if DpdkDriver fails to initialize for any reason. */ - DpdkDriver(int port, const Config* const config = nullptr); + DpdkDriver(const char* ifname, const Config* const config = nullptr); /** * Construct a DpdkDriver and initialize the DPDK EAL using the provided @@ -75,7 +75,7 @@ class DpdkDriver : public Driver { * overriding the default affinity set by rte_eal_init(). * * @param port - * Selects which physical port to use for communication. + * Selects which network interface to use for communication. * @param argc * Parameter passed to rte_eal_init(). * @param argv @@ -85,7 +85,7 @@ class DpdkDriver : public Driver { * @throw DriverInitFailure * Thrown if DpdkDriver fails to initialize for any reason. */ - DpdkDriver(int port, int argc, char* argv[], + DpdkDriver(const char* ifname, int argc, char* argv[], const Config* const config = nullptr); /// Used to signal to the DpdkDriver constructor that the DPDK EAL should @@ -101,7 +101,7 @@ class DpdkDriver : public Driver { * called before calling this constructor. * * @param port - * Selects which physical port to use for communication. + * Selects which network interface to use for communication. * @param _ * Parameter is used only to define this constructors alternate * signature. @@ -110,29 +110,20 @@ class DpdkDriver : public Driver { * @throw DriverInitFailure * Thrown if DpdkDriver fails to initialize for any reason. */ - DpdkDriver(int port, NoEalInit _, const Config* const config = nullptr); + DpdkDriver(const char* ifname, NoEalInit _, + const Config* const config = nullptr); /** * DpdkDriver Destructor. */ virtual ~DpdkDriver(); - /// See Driver::getAddress() - virtual Address getAddress(std::string const* const addressString); - virtual Address getAddress(WireFormatAddress const* const wireAddress); - - /// See Driver::addressToString() - virtual std::string addressToString(const Address address); - - /// See Driver::addressToWireFormat() - virtual void addressToWireFormat(const Address address, - WireFormatAddress* wireAddress); - /// See Driver::allocPacket() virtual Packet* allocPacket(); /// See Driver::sendPacket() - virtual void sendPacket(Packet* packet); + virtual void sendPacket(Packet* packet, IpAddress destination, + int priority); /// See Driver::cork() virtual void cork(); @@ -157,7 +148,7 @@ class DpdkDriver : public Driver { virtual uint32_t getBandwidth(); /// See Driver::getLocalAddress() - virtual Driver::Address getLocalAddress(); + virtual IpAddress getLocalAddress(); /// See Driver::getQueuedBytes(); virtual uint32_t getQueuedBytes(); diff --git a/include/Homa/Util.h b/include/Homa/Util.h index 30a3548..a57a386 100644 --- a/include/Homa/Util.h +++ b/include/Homa/Util.h @@ -57,6 +57,8 @@ downCast(const Large& large) std::string demangle(const char* name); std::string hexDump(const void* buf, uint64_t bytes); +std::string ipToString(uint32_t ip); +uint32_t stringToIp(const char* ip); /** * This class is used to temporarily release lock in a safe fashion. Creating diff --git a/src/Drivers/DPDK/DpdkDriver.cc b/src/Drivers/DPDK/DpdkDriver.cc index c536159..1500c26 100644 --- a/src/Drivers/DPDK/DpdkDriver.cc +++ b/src/Drivers/DPDK/DpdkDriver.cc @@ -21,50 +21,21 @@ namespace Homa { namespace Drivers { namespace DPDK { -DpdkDriver::DpdkDriver(int port, const Config* const config) - : pImpl(new Impl(port, config)) +DpdkDriver::DpdkDriver(const char* ifname, const Config* const config) + : pImpl(new Impl(ifname, config)) {} -DpdkDriver::DpdkDriver(int port, int argc, char* argv[], +DpdkDriver::DpdkDriver(const char* ifname, int argc, char* argv[], const Config* const config) - : pImpl(new Impl(port, argc, argv, config)) + : pImpl(new Impl(ifname, argc, argv, config)) {} -DpdkDriver::DpdkDriver(int port, NoEalInit _, const Config* const config) - : pImpl(new Impl(port, _, config)) +DpdkDriver::DpdkDriver(const char* ifname, NoEalInit _, const Config* const config) + : pImpl(new Impl(ifname, _, config)) {} DpdkDriver::~DpdkDriver() = default; -/// See Driver::getAddress() -Driver::Address -DpdkDriver::getAddress(std::string const* const addressString) -{ - return pImpl->getAddress(addressString); -} - -/// See Driver::getAddress() -Driver::Address -DpdkDriver::getAddress(WireFormatAddress const* const wireAddress) -{ - return pImpl->getAddress(wireAddress); -} - -/// See Driver::addressToString() -std::string -DpdkDriver::addressToString(const Address address) -{ - return pImpl->addressToString(address); -} - -/// See Driver::addressToWireFormat() -void -DpdkDriver::addressToWireFormat(const Address address, - WireFormatAddress* wireAddress) -{ - pImpl->addressToWireFormat(address, wireAddress); -} - /// See Driver::allocPacket() Driver::Packet* DpdkDriver::allocPacket() @@ -74,9 +45,9 @@ DpdkDriver::allocPacket() /// See Driver::sendPacket() void -DpdkDriver::sendPacket(Packet* packet) +DpdkDriver::sendPacket(Packet* packet, IpAddress destination, int priority) { - return pImpl->sendPacket(packet); + return pImpl->sendPacket(packet, destination, priority); } /// See Driver::cork() @@ -128,7 +99,7 @@ DpdkDriver::getBandwidth() } /// See Driver::getLocalAddress() -Driver::Address +IpAddress DpdkDriver::getLocalAddress() { return pImpl->getLocalAddress(); diff --git a/src/Drivers/DPDK/DpdkDriverImpl.cc b/src/Drivers/DPDK/DpdkDriverImpl.cc index e658ccb..ec4c58d 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.cc +++ b/src/Drivers/DPDK/DpdkDriverImpl.cc @@ -17,13 +17,19 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ +#include +#include +#include +#include +#include + #include "DpdkDriverImpl.h" #include -#include #include "CodeLocation.h" #include "StringUtil.h" +#include "Homa/Util.h" namespace Homa { @@ -45,7 +51,7 @@ const char* default_eal_argv[] = {"homa", NULL}; * Memory location in the mbuf where the packet data should be stored. */ DpdkDriver::Impl::Packet::Packet(struct rte_mbuf* mbuf, void* data) - : Driver::Packet(data, 0) + : base {.payload = data, .length = 0, .sourceIp = 0} , bufType(MBUF) , bufRef() { @@ -59,7 +65,7 @@ DpdkDriver::Impl::Packet::Packet(struct rte_mbuf* mbuf, void* data) * Overflow buffer that holds this packet. */ DpdkDriver::Impl::Packet::Packet(OverflowBuffer* overflowBuf) - : Driver::Packet(overflowBuf->data, 0) + : base {.payload = overflowBuf->data, .length = 0, .sourceIp = 0} , bufType(OVERFLOW_BUF) , bufRef() { @@ -69,17 +75,21 @@ DpdkDriver::Impl::Packet::Packet(OverflowBuffer* overflowBuf) /** * See DpdkDriver::DpdkDriver() */ -DpdkDriver::Impl::Impl(int port, const Config* const config) - : Impl(port, default_eal_argc, const_cast(default_eal_argv), config) +DpdkDriver::Impl::Impl(const char* ifname, const Config* const config) + : Impl(ifname, default_eal_argc, const_cast(default_eal_argv), + config) {} /** * See DpdkDriver::DpdkDriver() */ -DpdkDriver::Impl::Impl(int port, int argc, char* argv[], +DpdkDriver::Impl::Impl(const char* ifname, int argc, char* argv[], const Config* const config) - : port(port) - , localMac(Driver::Address(0)) + : ifname(ifname) + , port() + , arpTable() + , localIp() + , localMac("00:00:00:00:00:00") , HIGHEST_PACKET_PRIORITY( (config == nullptr || config->HIGHEST_PACKET_PRIORITY_OVERRIDE < 0) ? Homa::Util::arrayLength(PRIORITY_TO_PCP) - 1 @@ -124,10 +134,14 @@ DpdkDriver::Impl::Impl(int port, int argc, char* argv[], /** * See DpdkDriver::DpdkDriver() */ -DpdkDriver::Impl::Impl(int port, __attribute__((__unused__)) NoEalInit _, +DpdkDriver::Impl::Impl(const char* ifname, + __attribute__((__unused__)) NoEalInit _, const Config* const config) - : port(port) - , localMac(Driver::Address(0)) + : ifname(ifname) + , port() + , arpTable() + , localIp() + , localMac("00:00:00:00:00:00") , HIGHEST_PACKET_PRIORITY( (config == nullptr || config->HIGHEST_PACKET_PRIORITY_OVERRIDE < 0) ? Homa::Util::arrayLength(PRIORITY_TO_PCP) - 1 @@ -159,37 +173,8 @@ DpdkDriver::Impl::~Impl() rte_mempool_free(mbufPool); } -// See Driver::getAddress() -Driver::Address -DpdkDriver::Impl::getAddress(std::string const* const addressString) -{ - return MacAddress(addressString->c_str()).toAddress(); -} - -// See Driver::getAddress() -Driver::Address -DpdkDriver::Impl::getAddress(Driver::WireFormatAddress const* const wireAddress) -{ - return MacAddress(wireAddress).toAddress(); -} - -/// See Driver::addressToString() -std::string -DpdkDriver::Impl::addressToString(const Driver::Address address) -{ - return MacAddress(address).toString(); -} - -/// See Driver::addressToWireFormat() -void -DpdkDriver::Impl::addressToWireFormat(const Driver::Address address, - Driver::WireFormatAddress* wireAddress) -{ - MacAddress(address).toWireFormat(wireAddress); -} - // See Driver::allocPacket() -DpdkDriver::Impl::Packet* +Driver::Packet* DpdkDriver::Impl::allocPacket() { DpdkDriver::Impl::Packet* packet = _allocMbufPacket(); @@ -199,15 +184,17 @@ DpdkDriver::Impl::allocPacket() packet = packetPool.construct(buf); NOTICE("OverflowBuffer used."); } - return packet; + return &packet->base; } // See Driver::sendPacket() void -DpdkDriver::Impl::sendPacket(Driver::Packet* packet) +DpdkDriver::Impl::sendPacket(Driver::Packet* packet, IpAddress destination, + int priority) { + ; DpdkDriver::Impl::Packet* pkt = - static_cast(packet); + container_of(packet, DpdkDriver::Impl::Packet, base); struct rte_mbuf* mbuf = nullptr; // If the packet is held in an Overflow buffer, we need to copy it out // into a new mbuf. @@ -223,15 +210,15 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet) numMbufsAvail, numMbufsInUse); return; } - char* buf = rte_pktmbuf_append( - mbuf, Homa::Util::downCast(PACKET_HDR_LEN + pkt->length)); + char* buf = rte_pktmbuf_append(mbuf, + Homa::Util::downCast(PACKET_HDR_LEN + pkt->base.length)); if (unlikely(NULL == buf)) { WARNING("rte_pktmbuf_append call failed; dropping packet"); rte_pktmbuf_free(mbuf); return; } char* data = buf + PACKET_HDR_LEN; - rte_memcpy(data, pkt->payload, pkt->length); + rte_memcpy(data, pkt->base.payload, pkt->base.length); } else { mbuf = pkt->bufRef.mbuf; @@ -246,9 +233,14 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet) // Fill out the destination and source MAC addresses plus the Ethernet // frame type (i.e., IEEE 802.1Q VLAN tagging). - MacAddress macAddr(pkt->address); + auto it = arpTable.find(destination); + if (it == arpTable.end()) { + WARNING("Failed to find ARP record for packet; dropping packet"); + return; + } + MacAddress& destMac = it->second; struct ether_hdr* ethHdr = rte_pktmbuf_mtod(mbuf, struct ether_hdr*); - rte_memcpy(ðHdr->d_addr, macAddr.address, ETHER_ADDR_LEN); + rte_memcpy(ðHdr->d_addr, destMac.address, ETHER_ADDR_LEN); rte_memcpy(ðHdr->s_addr, localMac.address, ETHER_ADDR_LEN); ethHdr->ether_type = rte_cpu_to_be_16(ETHER_TYPE_VLAN); @@ -256,13 +248,16 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet) // encapsulated frame (DEI and VLAN ID are not relevant and trivially // set to 0). struct vlan_hdr* vlanHdr = reinterpret_cast(ethHdr + 1); - vlanHdr->vlan_tci = rte_cpu_to_be_16(PRIORITY_TO_PCP[pkt->priority]); + vlanHdr->vlan_tci = rte_cpu_to_be_16(PRIORITY_TO_PCP[priority]); vlanHdr->eth_proto = rte_cpu_to_be_16(EthPayloadType::HOMA); + // Store our local IP address right before the payload. + *rte_pktmbuf_mtod_offset(mbuf, uint32_t*, PACKET_HDR_LEN - 4) = localIp; + // In the normal case, we pre-allocate a pakcet's mbuf with enough // storage to hold the MAX_PAYLOAD_SIZE. If the actual payload is // smaller, trim the mbuf to size to avoid sending unecessary bits. - uint32_t actualLength = PACKET_HDR_LEN + pkt->length; + uint32_t actualLength = PACKET_HDR_LEN + pkt->base.length; uint32_t mbufDataLength = rte_pktmbuf_pkt_len(mbuf); if (actualLength < mbufDataLength) { if (rte_pktmbuf_trim(mbuf, mbufDataLength - actualLength) < 0) { @@ -274,7 +269,7 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet) } // loopback if src mac == dst mac - if (localMac.toAddress() == pkt->address) { + if (localMac == destMac) { struct rte_mbuf* mbuf_clone = rte_pktmbuf_clone(mbuf, mbufPool); if (unlikely(mbuf_clone == NULL)) { WARNING("Failed to clone packet for loopback; dropping packet"); @@ -390,6 +385,9 @@ DpdkDriver::Impl::receivePackets(uint32_t maxPackets, } } + uint32_t srcIp = *rte_pktmbuf_mtod_offset(m, uint32_t*, headerLength); + headerLength += sizeof(srcIp); + payload += sizeof(srcIp); assert(rte_pktmbuf_pkt_len(m) >= headerLength); uint32_t length = rte_pktmbuf_pkt_len(m) - headerLength; assert(length <= MAX_PAYLOAD_SIZE); @@ -399,10 +397,10 @@ DpdkDriver::Impl::receivePackets(uint32_t maxPackets, SpinLock::Lock lock(packetLock); packet = packetPool.construct(m, payload); } - packet->address = MacAddress(ethHdr->s_addr.addr_bytes).toAddress(); - packet->length = length; + packet->base.length = length; + packet->base.sourceIp = srcIp; - receivedPackets[numPacketsReceived++] = packet; + receivedPackets[numPacketsReceived++] = &packet->base; } return numPacketsReceived; @@ -415,7 +413,7 @@ DpdkDriver::Impl::releasePackets(Driver::Packet* packets[], uint16_t numPackets) for (uint16_t i = 0; i < numPackets; ++i) { SpinLock::Lock lock(packetLock); DpdkDriver::Impl::Packet* packet = - static_cast(packets[i]); + container_of(packets[i], DpdkDriver::Impl::Packet, base); if (likely(packet->bufType == DpdkDriver::Impl::Packet::MBUF)) { rte_pktmbuf_free(packet->bufRef.mbuf); } else { @@ -447,10 +445,10 @@ DpdkDriver::Impl::getBandwidth() } // See Driver::getLocalAddress() -Driver::Address +IpAddress DpdkDriver::Impl::getLocalAddress() { - return localMac.toAddress(); + return localIp; } // See Driver::getQueuedBytes(); @@ -490,11 +488,71 @@ DpdkDriver::Impl::_eal_init(int argc, char* argv[]) void DpdkDriver::Impl::_init() { - struct ether_addr mac; struct rte_eth_conf portConf; int ret; uint16_t mtu; + // Populate the ARP table with records in /proc/net/arp (inspired by + // net-tools/arp.c) + std::ifstream input("/proc/net/arp"); + for (std::string line; getline(input, line);) { + char ip[100]; + char hwa[100]; + char mask[100]; + char dev[100]; + int type, flags; + int cols = sscanf(line.c_str(), "%s 0x%x 0x%x %99s %99s %99s\n", + ip, &type, &flags, hwa, mask, dev); + if (cols != 6) continue; + arpTable.emplace(Homa::Util::stringToIp(ip), hwa); + } + + // Use ioctl to obtain the IP and MAC addresses of the network interface. + struct ifreq ifr; + ifname.copy(ifr.ifr_name, ifname.length()); + ifr.ifr_name[ifname.length() + 1] = 0; + if (ifname.length() >= sizeof(ifr.ifr_name)) { + throw DriverInitFailure(HERE_STR, + StringUtil::format("Interface name %s too long", ifname.c_str())); + } + + int fd = socket(AF_INET, SOCK_DGRAM, 0); + if (fd == -1) { + throw DriverInitFailure(HERE_STR, + StringUtil::format("Failed to create socket: %s", strerror(errno))); + } + + if (ioctl(fd, SIOCGIFADDR, &ifr) == -1) { + char* error = strerror(errno); + close(fd); + throw DriverInitFailure(HERE_STR, + StringUtil::format("Failed to obtain IP address: %s", error)); + } + localIp = be32toh(((struct sockaddr_in*) &ifr.ifr_addr)->sin_addr.s_addr); + + if (ioctl(fd, SIOCGIFHWADDR, &ifr) == -1) { + char* error = strerror(errno); + close(fd); + throw DriverInitFailure(HERE_STR, + StringUtil::format("Failed to obtain MAC address: %s", error)); + } + close(fd); + memcpy(localMac.address, ifr.ifr_hwaddr.sa_data, 6); + + // Iterate over ethernet devices to locate the port identifier. + int p; + RTE_ETH_FOREACH_DEV(p) { + struct ether_addr mac; + rte_eth_macaddr_get(p, &mac); + if (MacAddress(mac.addr_bytes) == localMac) { + port = p; + break; + } + } + NOTICE("Using interface %s, ip %s, mac %s, port %u", + ifname.c_str(), Homa::Util::ipToString(localIp).c_str(), + localMac.toString().c_str(), port); + std::string poolName = StringUtil::format("homa_mbuf_pool_%u", port); std::string ringName = StringUtil::format("homa_loopback_ring_%u", port); @@ -518,10 +576,6 @@ DpdkDriver::Impl::_init() StringUtil::format("Ethernet port %u doesn't exist", port)); } - // Read the MAC address from the NIC via DPDK. - rte_eth_macaddr_get(port, &mac); - new (const_cast(&localMac)) MacAddress(mac.addr_bytes); - // configure some default NIC port parameters memset(&portConf, 0, sizeof(portConf)); portConf.rxmode.max_rx_pkt_len = ETHER_MAX_VLAN_FRAME_LEN; diff --git a/src/Drivers/DPDK/DpdkDriverImpl.h b/src/Drivers/DPDK/DpdkDriverImpl.h index 9b77383..4ed3406 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.h +++ b/src/Drivers/DPDK/DpdkDriverImpl.h @@ -28,6 +28,7 @@ #include #include +#include #include "MacAddress.h" #include "ObjectPool.h" @@ -65,8 +66,13 @@ const uint16_t MAX_PKT_BURST = 32; /// field defined in the VLAN tag to specify the packet priority. const uint32_t VLAN_TAG_LEN = 4; -// Size of Ethernet header including VLAN tag, in bytes. -const uint32_t PACKET_HDR_LEN = ETHER_HDR_LEN + VLAN_TAG_LEN; +/// Strictly speaking, this DPDK driver is supposed to send/receive IP packets; +/// however, it currently only records the source IP address right after the +/// Ethernet header for simplicity. +const uint32_t IP_HDR_LEN = sizeof(IpAddress); + +// Size of Ethernet header including VLAN tag plus IP header, in bytes. +const uint32_t PACKET_HDR_LEN = ETHER_HDR_LEN + VLAN_TAG_LEN + IP_HDR_LEN; // The MTU (Maximum Transmission Unit) size of an Ethernet frame, which is the // maximum size of the packet an Ethernet frame can carry in its payload. This @@ -104,11 +110,13 @@ class DpdkDriver::Impl { * Dpdk specific Packet object used to track a its lifetime and * contents. */ - class Packet : public Driver::Packet { - public: + struct Packet { explicit Packet(struct rte_mbuf* mbuf, void* data); explicit Packet(OverflowBuffer* overflowBuf); + /// C-style "inheritance" + Driver::Packet base; + /// Used to indicate whether the packet is backed by an DPDK mbuf or a /// driver-level OverflowBuffer. enum BufferType { MBUF, OVERFLOW_BUF } bufType; ///< Packet BufferType. @@ -122,26 +130,18 @@ class DpdkDriver::Impl { /// The memory location of this packet's header. The header should be /// PACKET_HDR_LEN in length. void* header; - - private: - Packet(const Packet&) = delete; - Packet& operator=(const Packet&) = delete; }; - Impl(int port, const Config* const config = nullptr); - Impl(int port, int argc, char* argv[], + Impl(const char* ifname, const Config* const config = nullptr); + Impl(const char* ifname, int argc, char* argv[], const Config* const config = nullptr); - Impl(int port, NoEalInit _, const Config* const config = nullptr); + Impl(const char* ifname, NoEalInit _, const Config* const config = nullptr); virtual ~Impl(); // Interface Methods - Driver::Address getAddress(std::string const* const addressString); - Driver::Address getAddress(WireFormatAddress const* const wireAddress); - std::string addressToString(const Address address); - void addressToWireFormat(const Address address, - WireFormatAddress* wireAddress); - Packet* allocPacket(); - void sendPacket(Driver::Packet* packet); + Driver::Packet* allocPacket(); + void sendPacket(Driver::Packet* packet, IpAddress destination, + int priority); void cork(); void uncork(); uint32_t receivePackets(uint32_t maxPackets, @@ -150,7 +150,7 @@ class DpdkDriver::Impl { int getHighestPacketPriority(); uint32_t getMaxPayloadSize(); uint32_t getBandwidth(); - Driver::Address getLocalAddress(); + IpAddress getLocalAddress(); uint32_t getQueuedBytes(); private: @@ -163,12 +163,21 @@ class DpdkDriver::Impl { static void txBurstErrorCallback(struct rte_mbuf* pkts[], uint16_t unsent, void* userdata); + /// Name of the Linux network interface to be used by DPDK. + std::string ifname; + /// Stores the NIC's physical port id addressed by the instantiated /// driver. - const uint16_t port; + uint16_t port; + + /// Address resolution table that translates IP addresses to MAC addresses. + std::unordered_map arpTable; + + /// Stores the IpAddress of the driver. + IpAddress localIp; - /// Stores the address of the NIC (either native or set by override). - const MacAddress localMac; + /// Stores the HW address of the NIC (either native or set by override). + MacAddress localMac; /// Stores the driver's maximum network packet priority (either default or /// set by override). diff --git a/src/Drivers/DPDK/MacAddress.cc b/src/Drivers/DPDK/MacAddress.cc index 0178851..63149fa 100644 --- a/src/Drivers/DPDK/MacAddress.cc +++ b/src/Drivers/DPDK/MacAddress.cc @@ -18,7 +18,6 @@ #include "StringUtil.h" #include "../../CodeLocation.h" -#include "../RawAddressType.h" namespace Homa { namespace Drivers { @@ -55,33 +54,6 @@ MacAddress::MacAddress(const char* macStr) address[i] = Util::downCast(bytes[i]); } -/** - * Create a new address from a given address in its raw byte format. - * @param raw - * The raw bytes format. - * - * @sa Driver::Address::Raw - */ -MacAddress::MacAddress(const Driver::WireFormatAddress* const wireAddress) -{ - if (wireAddress->type != RawAddressType::MAC) { - throw BadAddress(HERE_STR, "Bad address: Raw format is not type MAC"); - } - static_assert(sizeof(wireAddress->bytes) >= 6); - memcpy(address, wireAddress->bytes, 6); -} - -/** - * Create a new address given the Driver::Address representation. - * - * @param addr - * The Driver::Address representation of an address. - */ -MacAddress::MacAddress(const Driver::Address addr) -{ - memcpy(address, &addr, 6); -} - /** * Return the string representation of this address. */ @@ -94,31 +66,6 @@ MacAddress::toString() const return buf; } -/** - * Serialized this address into a wire format. - * - * @param[out] wireAddress - * WireFormatAddress object to which the this address is serialized. - */ -void -MacAddress::toWireFormat(Driver::WireFormatAddress* wireAddress) const -{ - static_assert(sizeof(wireAddress->bytes) >= 6); - memcpy(wireAddress->bytes, address, 6); - wireAddress->type = RawAddressType::MAC; -} - -/** - * Return a Driver::Address representation of this address. - */ -Driver::Address -MacAddress::toAddress() const -{ - Driver::Address addr = 0; - memcpy(&addr, address, 6); - return addr; -} - /** * @return * True if the MacAddress consists of all zero bytes, false if not. diff --git a/src/Drivers/DPDK/MacAddress.h b/src/Drivers/DPDK/MacAddress.h index 1106eec..148f2ce 100644 --- a/src/Drivers/DPDK/MacAddress.h +++ b/src/Drivers/DPDK/MacAddress.h @@ -28,14 +28,19 @@ namespace DPDK { struct MacAddress { explicit MacAddress(const uint8_t raw[6]); explicit MacAddress(const char* macStr); - explicit MacAddress(const Driver::WireFormatAddress* const wireAddress); - explicit MacAddress(const Driver::Address addr); MacAddress(const MacAddress&) = default; std::string toString() const; - void toWireFormat(Driver::WireFormatAddress* wireAddress) const; - Driver::Address toAddress() const; bool isNull() const; + /** + * Equality function for MacAddress, for use in std::unordered_maps etc. + */ + bool operator==(const MacAddress& other) const + { + return (*(uint32_t*)(address + 0) == *(uint32_t*)(other.address + 0)) && + (*(uint16_t*)(address + 4) == *(uint16_t*)(other.address + 4)); + } + /// The raw bytes of the MAC address. uint8_t address[6]; }; diff --git a/src/Drivers/DPDK/MacAddressTest.cc b/src/Drivers/DPDK/MacAddressTest.cc index 329c309..7587a16 100644 --- a/src/Drivers/DPDK/MacAddressTest.cc +++ b/src/Drivers/DPDK/MacAddressTest.cc @@ -15,8 +15,6 @@ #include "MacAddress.h" -#include "../RawAddressType.h" - #include namespace Homa { @@ -35,26 +33,6 @@ TEST(MacAddressTest, constructorString) EXPECT_EQ("de:ad:be:ef:98:76", MacAddress("de:ad:be:ef:98:76").toString()); } -TEST(MacAddressTest, constructorWireFormatAddress) -{ - uint8_t bytes[] = {0xde, 0xad, 0xbe, 0xef, 0x98, 0x76}; - Driver::WireFormatAddress wireformatAddress; - wireformatAddress.type = RawAddressType::MAC; - memcpy(wireformatAddress.bytes, bytes, 6); - EXPECT_EQ("de:ad:be:ef:98:76", MacAddress(&wireformatAddress).toString()); - - wireformatAddress.type = RawAddressType::FAKE; - EXPECT_THROW(MacAddress address(&wireformatAddress), BadAddress); -} - -TEST(MacAddressTest, constructorAddress) -{ - uint8_t raw[] = {0xde, 0xad, 0xbe, 0xef, 0x98, 0x76}; - MacAddress(raw).toString(); - Driver::Address addr = MacAddress("de:ad:be:ef:98:76").toAddress(); - EXPECT_EQ("de:ad:be:ef:98:76", MacAddress(addr).toString()); -} - TEST(MacAddressTest, construct_DefaultCopy) { MacAddress source("de:ad:be:ef:98:76"); @@ -67,24 +45,6 @@ TEST(MacAddressTest, toString) // tested sufficiently in constructor tests } -TEST(MacAddressTest, toWireFormat) -{ - Driver::WireFormatAddress wireformatAddress; - MacAddress("de:ad:be:ef:98:76").toWireFormat(&wireformatAddress); - EXPECT_EQ(RawAddressType::MAC, wireformatAddress.type); - EXPECT_EQ(0xde, wireformatAddress.bytes[0]); - EXPECT_EQ(0xad, wireformatAddress.bytes[1]); - EXPECT_EQ(0xbe, wireformatAddress.bytes[2]); - EXPECT_EQ(0xef, wireformatAddress.bytes[3]); - EXPECT_EQ(0x98, wireformatAddress.bytes[4]); - EXPECT_EQ(0x76, wireformatAddress.bytes[5]); -} - -TEST(MacAddressTest, toAddress) -{ - // Tested in constructorAddress -} - TEST(MacAddressTest, isNull) { uint8_t rawNull[] = {0x0, 0x0, 0x0, 0x0, 0x0, 0x0}; diff --git a/src/Drivers/RawAddressType.h b/src/Drivers/RawAddressType.h deleted file mode 100644 index 1def76d..0000000 --- a/src/Drivers/RawAddressType.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright (c) 2019, Stanford University - * - * Permission to use, copy, modify, and/or distribute this software for any - * purpose with or without fee is hereby granted, provided that the above - * copyright notice and this permission notice appear in all copies. - * - * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES - * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR - * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ - -#ifndef HOMA_DRIVERS_RAWADDRESSTYPE_H -#define HOMA_DRIVERS_RAWADDRESSTYPE_H - -namespace Homa { -namespace Drivers { - -/** - * Identifies a particular raw serialized byte-format for a Driver::Address - * supported by this project. The types are enumerated here in one place to - * ensure drivers do have overlapping type identifiers. New drivers that wish - * to claim a type id should add an entry to this enum. - * - * @sa Driver::Address::Raw - */ -enum RawAddressType { - FAKE = 0, - MAC = 1, -}; - -} // namespace Drivers -} // namespace Homa - -#endif // HOMA_DRIVERS_RAWADDRESSTYPE_H diff --git a/src/Util.cc b/src/Util.cc index 90ee9f4..fe73752 100644 --- a/src/Util.cc +++ b/src/Util.cc @@ -100,5 +100,21 @@ hexDump(const void* buf, uint64_t bytes) return output.str(); } +std::string +ipToString(uint32_t ip) +{ + return StringUtil::format("%d.%d.%d.%d", + (ip >> 24) & 0xff, (ip >> 16) & 0xff, (ip >> 8) & 0xff, ip & 0xff); +} + +uint32_t +stringToIp(const char* ipStr) +{ + unsigned int bytes[4]; + sscanf(ipStr, "%u.%u.%u.%u", &bytes[0], &bytes[1], &bytes[2], &bytes[3]); + return (bytes[0] << 24) | (bytes[1] << 16) | (bytes[2] << 8) | bytes[3]; +} + + } // namespace Util } // namespace Homa diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 403a340..e01e945 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -34,6 +34,18 @@ target_link_libraries(system_test docopt ) +## dpdk_test ################################################################# + +add_executable(dpdk_test + dpdk_test.cc +) +target_link_libraries(dpdk_test + PRIVATE + Homa::DpdkDriver + docopt + PerfUtils +) + ## Perf ######################################################################## add_executable(Perf diff --git a/test/Output.h b/test/Output.h new file mode 100644 index 0000000..bea8f8b --- /dev/null +++ b/test/Output.h @@ -0,0 +1,105 @@ +#pragma once + +#include +#include +#include +#include + +namespace Output { + +using Latency = std::chrono::duration; + +struct TimeDist { + Latency min; // Fastest time seen (seconds). + Latency p50; // Median time per operation (seconds). + Latency p90; // 90th percentile time/op (seconds). + Latency p99; // 99th percentile time/op (seconds). + Latency p999; // 99.9th percentile time/op (seconds). +}; + +std::string +format(const std::string& format, ...) +{ + va_list args; + va_start(args, format); + size_t len = std::vsnprintf(NULL, 0, format.c_str(), args); + va_end(args); + std::vector vec(len + 1); + va_start(args, format); + std::vsnprintf(&vec[0], len + 1, format.c_str(), args); + va_end(args); + return &vec[0]; +} + +std::string +formatTime(Latency seconds) +{ + if (seconds < std::chrono::duration(1)) { + return format( + "%5.1f ns", + std::chrono::duration(seconds).count()); + } else if (seconds < std::chrono::duration(1)) { + return format( + "%5.1f us", + std::chrono::duration(seconds).count()); + } else if (seconds < std::chrono::duration(1)) { + return format( + "%5.2f ms", + std::chrono::duration(seconds).count()); + } else { + return format("%5.2f s ", seconds.count()); + } +} + +std::string +basicHeader() +{ + return "median min p90 p99 p999 description"; +} + +std::string +basic(std::vector& times, const std::string description) +{ + int count = times.size(); + std::sort(times.begin(), times.end()); + + TimeDist dist; + + dist.min = times[0]; + int index = count / 2; + if (index < count) { + dist.p50 = times.at(index); + } else { + dist.p50 = dist.min; + } + index = count - (count + 5) / 10; + if (index < count) { + dist.p90 = times.at(index); + } else { + dist.p90 = dist.p50; + } + index = count - (count + 50) / 100; + if (index < count) { + dist.p99 = times.at(index); + } else { + dist.p99 = dist.p90; + } + index = count - (count + 500) / 1000; + if (index < count) { + dist.p999 = times.at(index); + } else { + dist.p999 = dist.p99; + } + + std::string output = ""; + output += format("%9s", formatTime(dist.p50).c_str()); + output += format(" %9s", formatTime(dist.min).c_str()); + output += format(" %9s", formatTime(dist.p90).c_str()); + output += format(" %9s", formatTime(dist.p99).c_str()); + output += format(" %9s", formatTime(dist.p999).c_str()); + output += " "; + output += description; + return output; +} + +} // namespace Output diff --git a/test/dpdk_test.cc b/test/dpdk_test.cc new file mode 100644 index 0000000..ebf9ba2 --- /dev/null +++ b/test/dpdk_test.cc @@ -0,0 +1,88 @@ +#include +#include + +#include +#include +#include +#include + +#include "Output.h" + +static const char USAGE[] = R"(DPDK Driver Test. + + Usage: + dpdk_test [options] (--server | ) + + Options: + -h --help Show this screen. + --version Show version. + --timetrace Enable TimeTrace output [default: false]. +)"; + +int +main(int argc, char* argv[]) +{ + std::map args = + docopt::docopt(USAGE, {argv + 1, argv + argc}, + true, // show help if requested + "DPDK Driver Test"); // version string + + std::string iface = args[""].asString(); + bool isServer = args["--server"].asBool(); + std::string server_ip_string; + if (!isServer) { + server_ip_string = args[""].asString(); + } + + Homa::Drivers::DPDK::DpdkDriver driver(iface.c_str()); + + if (isServer) { + std::cout << Homa::Util::ipToString(driver.getLocalAddress()) + << std::endl; + while (true) { + Homa::Driver::Packet* incoming[10]; + uint32_t receivedPackets; + do { + receivedPackets = driver.receivePackets(10, incoming); + } while (receivedPackets == 0); + Homa::Driver::Packet* pong = driver.allocPacket(); + pong->length = 100; + driver.sendPacket(pong, incoming[0]->sourceIp, 0); + driver.releasePackets(incoming, receivedPackets); + driver.releasePackets(&pong, 1); + } + } else { + Homa::IpAddress server_ip = + Homa::Util::stringToIp(server_ip_string.c_str()); + std::vector times; + for (int i = 0; i < 100000; ++i) { + uint64_t start = PerfUtils::Cycles::rdtsc(); + PerfUtils::TimeTrace::record(start, "START"); + Homa::Driver::Packet* ping = driver.allocPacket(); + PerfUtils::TimeTrace::record("allocPacket"); + ping->length = 100; + PerfUtils::TimeTrace::record("set ping args"); + driver.sendPacket(ping, server_ip, 0); + PerfUtils::TimeTrace::record("sendPacket"); + driver.releasePackets(&ping, 1); + PerfUtils::TimeTrace::record("releasePacket"); + Homa::Driver::Packet* incoming[10]; + uint32_t receivedPackets; + do { + receivedPackets = driver.receivePackets(10, incoming); + PerfUtils::TimeTrace::record("receivePackets"); + } while (receivedPackets == 0); + driver.releasePackets(incoming, receivedPackets); + PerfUtils::TimeTrace::record("releasePacket"); + uint64_t stop = PerfUtils::Cycles::rdtsc(); + times.emplace_back(PerfUtils::Cycles::toSeconds(stop - start)); + } + if (args["--timetrace"].asBool()) { + PerfUtils::TimeTrace::print(); + } + std::cout << Output::basicHeader() << std::endl; + std::cout << Output::basic(times, "DpdkDriver Ping-Pong") << std::endl; + } + + return 0; +} \ No newline at end of file From 23e6c28cd70570a6d7bb3f37c67b8d02ebd22f36 Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Thu, 6 Aug 2020 23:04:20 -0700 Subject: [PATCH 04/33] Improvements based on code review discussions with Collin - Change IpAddress from typedef to a POD type to provide some type safety - Add a third argument to Driver::receivePackets() to hold source addresses of ingress packets when the method returns - Eliminate Driver::Packet (use Homa::PacketSpec instead) - Move L4 header fields sport/dport into header prefix --- CMakeLists.txt | 3 +- include/Homa/Driver.h | 89 ++++++++++-------- include/Homa/Drivers/DPDK/DpdkDriver.h | 3 +- include/Homa/Drivers/Fake/FakeDriver.h | 16 ++-- include/Homa/Util.h | 11 +-- src/Driver.cc | 38 ++++++++ src/Drivers/DPDK/DpdkDriver.cc | 5 +- src/Drivers/DPDK/DpdkDriverImpl.cc | 23 +++-- src/Drivers/DPDK/DpdkDriverImpl.h | 5 +- src/Drivers/Fake/FakeDriver.cc | 28 +++--- src/Drivers/Fake/FakeDriverTest.cc | 23 ++--- src/Mock/MockDriver.h | 3 +- src/Policy.h | 3 +- src/PolicyTest.cc | 2 +- src/Protocol.h | 24 ++--- src/Receiver.cc | 4 +- src/Receiver.h | 2 +- src/ReceiverTest.cc | 123 +++++++++++++------------ src/SenderTest.cc | 4 +- src/TransportImpl.cc | 6 +- src/Util.cc | 16 ---- test/Output.h | 15 +++ test/dpdk_test.cc | 27 +++++- 23 files changed, 283 insertions(+), 190 deletions(-) create mode 100644 src/Driver.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 4a6f9c1..f5cb6ef 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required (VERSION 3.11) -project(Homa VERSION 0.1.1.0 LANGUAGES CXX) +project(Homa VERSION 0.1.2.0 LANGUAGES CXX) ################################################################################ ## Dependency Configuration #################################################### @@ -74,6 +74,7 @@ endif() add_library(Homa src/CodeLocation.cc src/Debug.cc + src/Driver.cc src/Homa.cc src/Perf.cc src/Policy.cc diff --git a/include/Homa/Driver.h b/include/Homa/Driver.h index d510046..a5cc855 100644 --- a/include/Homa/Driver.h +++ b/include/Homa/Driver.h @@ -22,8 +22,49 @@ namespace Homa { -/// IPv4 address in host byte order. -using IpAddress = uint32_t; +/** + * A simple wrapper struct around an IP address in binary format. + * + * This struct is meant to provide some type-safety when manipulating IP + * addresses. In order to avoid any runtime overhead, this struct contains + * nothing more than the IP address, so it is trivially copyable. + */ +struct IpAddress final { + /// IPv4 address in host byte order. + uint32_t addr; + + /** + * Unbox the IP address in binary format. + */ + explicit operator uint32_t() + { + return addr; + } + + /** + * Equality function for IpAddress, for use in std::unordered_maps etc. + */ + bool operator==(const IpAddress& other) const + { + return addr == other.addr; + } + + /** + * This class computes a hash of an IpAddress, so that IpAddress can be used + * as keys in unordered_maps. + */ + struct Hasher { + /// Return a "hash" of the given IpAddress. + std::size_t operator()(const IpAddress& address) const + { + return std::hash{}(address.addr); + } + }; + + static std::string toString(IpAddress address); + static IpAddress fromString(const char* addressStr); +}; +static_assert(std::is_trivially_copyable()); /** * Represents a packet of data that can be send or is received over the network. @@ -43,6 +84,7 @@ struct PacketSpec { /// Number of bytes in the payload. int32_t length; } __attribute__((packed)); +static_assert(std::is_trivial()); /** * Used by Homa::Transport to send and receive unreliable datagrams. Provides @@ -52,41 +94,8 @@ struct PacketSpec { */ class Driver { public: - /** - * Represents a packet that can be send or is received over the network. - * - * The layout of this struct has two parts: the first part is essentially - * a copy of PacketSpec, while the second part contains members specific - * to our driver implementation. - * - * @sa Homa::PacketSpec - */ - struct Packet final { - // === PacketSpec definitions === - // The order and types of the following members must match those in - // PacketSpec precisely. - - /// See Homa::PacketSpec::payload. - void* payload; - - /// See Homa::PacketSpec::length - int32_t length; - - // === Extended definitions === - // The following members are specific to the driver framework bundled - // in this library. Therefore, these members must *NOT* appear in the - // core components of Homa transport; they are only used in a few - // places to facilitate the glue code between transport and driver. - - /// Packet's source IpAddress. Only meaningful when this packet is an - /// incoming packet. - IpAddress sourceIp; - } __attribute__((packed)); - - // Static checks to enforce the object layout compatibility between - // Driver::Packet and PacketSpec. - static_assert(offsetof(Packet, payload) == offsetof(PacketSpec, payload)); - static_assert(offsetof(Packet, length) == offsetof(PacketSpec, length)); + /// Import PacketSpec into the Driver namespace. + using Packet = PacketSpec; /** * Driver destructor. @@ -164,6 +173,9 @@ class Driver { * this method. * @param[out] receivedPackets * Received packets are appended to this array in order of arrival. + * @param[out] sourceAddresses + * Source IP addresses of the received packets are appended to this + * array in order of arrival. * * @return * Number of Packet objects being returned. @@ -171,7 +183,8 @@ class Driver { * @sa Driver::releasePackets() */ virtual uint32_t receivePackets(uint32_t maxPackets, - Packet* receivedPackets[]) = 0; + Packet* receivedPackets[], + IpAddress sourceAddresses[]) = 0; /** * Release a collection of Packet objects back to the Driver. Every diff --git a/include/Homa/Drivers/DPDK/DpdkDriver.h b/include/Homa/Drivers/DPDK/DpdkDriver.h index 010d59b..f15d575 100644 --- a/include/Homa/Drivers/DPDK/DpdkDriver.h +++ b/include/Homa/Drivers/DPDK/DpdkDriver.h @@ -133,7 +133,8 @@ class DpdkDriver : public Driver { /// See Driver::receivePackets() virtual uint32_t receivePackets(uint32_t maxPackets, - Packet* receivedPackets[]); + Packet* receivedPackets[], + IpAddress sourceAddresses[]); /// See Driver::releasePackets() virtual void releasePackets(Packet* packets[], uint16_t numPackets); diff --git a/include/Homa/Drivers/Fake/FakeDriver.h b/include/Homa/Drivers/Fake/FakeDriver.h index 04ce8c0..dd01261 100644 --- a/include/Homa/Drivers/Fake/FakeDriver.h +++ b/include/Homa/Drivers/Fake/FakeDriver.h @@ -58,14 +58,17 @@ struct FakePacket { /// Raw storage for this packets payload. char buf[MAX_PAYLOAD_SIZE]; + /// Source IpAddress of the packet. + IpAddress sourceIp; + /** * FakePacket constructor. */ explicit FakePacket() : base{.payload = buf, - .length = 0, - .sourceIp = 0} + .length = 0} , buf() + , sourceIp() {} /** @@ -73,9 +76,9 @@ struct FakePacket { */ FakePacket(const FakePacket& other) : base{.payload = buf, - .length = other.base.length, - .sourceIp = 0} + .length = other.base.length} , buf() + , sourceIp() { memcpy(base.payload, other.base.payload, MAX_PAYLOAD_SIZE); } @@ -111,7 +114,8 @@ class FakeDriver : public Driver { virtual Packet* allocPacket(); virtual void sendPacket(Packet* packet, IpAddress destination, int priority); virtual uint32_t receivePackets(uint32_t maxPackets, - Packet* receivedPackets[]); + Packet* receivedPackets[], + IpAddress sourceAddresses[]); virtual void releasePackets(Packet* packets[], uint16_t numPackets); virtual int getHighestPacketPriority(); virtual uint32_t getMaxPayloadSize(); @@ -121,7 +125,7 @@ class FakeDriver : public Driver { private: /// Identifier for this driver on the fake network. - uint64_t localAddressId; + uint32_t localAddressId; /// Holds the incoming packets for this driver. FakeNIC nic; diff --git a/include/Homa/Util.h b/include/Homa/Util.h index a57a386..ba757e6 100644 --- a/include/Homa/Util.h +++ b/include/Homa/Util.h @@ -22,10 +22,11 @@ #include /// Cast a member of a structure out to the containing structure. -#define container_of(ptr, type, member) ({ \ - const typeof( ((type *)0)->member ) \ - *__mptr = (ptr); \ - (type *)( (char *)__mptr - offsetof(type,member) );}) +template +P* container_of(M* ptr, const M P::*member) +{ + return (P*)((char*) ptr - (size_t) &(reinterpret_cast(0)->*member)); +} namespace Homa { namespace Util { @@ -57,8 +58,6 @@ downCast(const Large& large) std::string demangle(const char* name); std::string hexDump(const void* buf, uint64_t bytes); -std::string ipToString(uint32_t ip); -uint32_t stringToIp(const char* ip); /** * This class is used to temporarily release lock in a safe fashion. Creating diff --git a/src/Driver.cc b/src/Driver.cc new file mode 100644 index 0000000..c7d61cb --- /dev/null +++ b/src/Driver.cc @@ -0,0 +1,38 @@ +/* Copyright (c) 2018-2019, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#include + +#include "StringUtil.h" + +namespace Homa { + +std::string +IpAddress::toString(IpAddress address) +{ + uint32_t ip = address.addr; + return StringUtil::format("%d.%d.%d.%d", + (ip >> 24) & 0xff, (ip >> 16) & 0xff, (ip >> 8) & 0xff, ip & 0xff); +} + +IpAddress +IpAddress::fromString(const char* addressStr) +{ + unsigned int b0, b1, b2, b3; + sscanf(addressStr, "%u.%u.%u.%u", &b0, &b1, &b2, &b3); + return IpAddress{(b0 << 24u) | (b1 << 16u) | (b2 << 8u) | b3}; +} + +} // namespace Homa diff --git a/src/Drivers/DPDK/DpdkDriver.cc b/src/Drivers/DPDK/DpdkDriver.cc index 1500c26..3c8833a 100644 --- a/src/Drivers/DPDK/DpdkDriver.cc +++ b/src/Drivers/DPDK/DpdkDriver.cc @@ -66,9 +66,10 @@ DpdkDriver::uncork() /// See Driver::receivePackets() uint32_t -DpdkDriver::receivePackets(uint32_t maxPackets, Packet* receivedPackets[]) +DpdkDriver::receivePackets(uint32_t maxPackets, Packet* receivedPackets[], + IpAddress sourceAddresses[]) { - return pImpl->receivePackets(maxPackets, receivedPackets); + return pImpl->receivePackets(maxPackets, receivedPackets, sourceAddresses); } /// See Driver::releasePackets() void diff --git a/src/Drivers/DPDK/DpdkDriverImpl.cc b/src/Drivers/DPDK/DpdkDriverImpl.cc index ec4c58d..42a2340 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.cc +++ b/src/Drivers/DPDK/DpdkDriverImpl.cc @@ -51,7 +51,8 @@ const char* default_eal_argv[] = {"homa", NULL}; * Memory location in the mbuf where the packet data should be stored. */ DpdkDriver::Impl::Packet::Packet(struct rte_mbuf* mbuf, void* data) - : base {.payload = data, .length = 0, .sourceIp = 0} + : base {.payload = data, + .length = 0} , bufType(MBUF) , bufRef() { @@ -65,7 +66,8 @@ DpdkDriver::Impl::Packet::Packet(struct rte_mbuf* mbuf, void* data) * Overflow buffer that holds this packet. */ DpdkDriver::Impl::Packet::Packet(OverflowBuffer* overflowBuf) - : base {.payload = overflowBuf->data, .length = 0, .sourceIp = 0} + : base {.payload = overflowBuf->data, + .length = 0} , bufType(OVERFLOW_BUF) , bufRef() { @@ -252,7 +254,8 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet, IpAddress destination, vlanHdr->eth_proto = rte_cpu_to_be_16(EthPayloadType::HOMA); // Store our local IP address right before the payload. - *rte_pktmbuf_mtod_offset(mbuf, uint32_t*, PACKET_HDR_LEN - 4) = localIp; + *rte_pktmbuf_mtod_offset(mbuf, uint32_t*, PACKET_HDR_LEN - 4) = + (uint32_t)localIp; // In the normal case, we pre-allocate a pakcet's mbuf with enough // storage to hold the MAX_PAYLOAD_SIZE. If the actual payload is @@ -322,7 +325,8 @@ DpdkDriver::Impl::uncork() // See Driver::receivePackets() uint32_t DpdkDriver::Impl::receivePackets(uint32_t maxPackets, - Driver::Packet* receivedPackets[]) + Driver::Packet* receivedPackets[], + IpAddress sourceAddresses[]) { uint32_t numPacketsReceived = 0; @@ -398,9 +402,10 @@ DpdkDriver::Impl::receivePackets(uint32_t maxPackets, packet = packetPool.construct(m, payload); } packet->base.length = length; - packet->base.sourceIp = srcIp; - receivedPackets[numPacketsReceived++] = &packet->base; + receivedPackets[numPacketsReceived] = &packet->base; + sourceAddresses[numPacketsReceived] = {srcIp}; + ++numPacketsReceived; } return numPacketsReceived; @@ -504,7 +509,7 @@ DpdkDriver::Impl::_init() int cols = sscanf(line.c_str(), "%s 0x%x 0x%x %99s %99s %99s\n", ip, &type, &flags, hwa, mask, dev); if (cols != 6) continue; - arpTable.emplace(Homa::Util::stringToIp(ip), hwa); + arpTable.emplace(IpAddress::fromString(ip), hwa); } // Use ioctl to obtain the IP and MAC addresses of the network interface. @@ -528,7 +533,7 @@ DpdkDriver::Impl::_init() throw DriverInitFailure(HERE_STR, StringUtil::format("Failed to obtain IP address: %s", error)); } - localIp = be32toh(((struct sockaddr_in*) &ifr.ifr_addr)->sin_addr.s_addr); + localIp = {be32toh(((struct sockaddr_in*) &ifr.ifr_addr)->sin_addr.s_addr)}; if (ioctl(fd, SIOCGIFHWADDR, &ifr) == -1) { char* error = strerror(errno); @@ -550,7 +555,7 @@ DpdkDriver::Impl::_init() } } NOTICE("Using interface %s, ip %s, mac %s, port %u", - ifname.c_str(), Homa::Util::ipToString(localIp).c_str(), + ifname.c_str(), IpAddress::toString(localIp).c_str(), localMac.toString().c_str(), port); std::string poolName = StringUtil::format("homa_mbuf_pool_%u", port); diff --git a/src/Drivers/DPDK/DpdkDriverImpl.h b/src/Drivers/DPDK/DpdkDriverImpl.h index 4ed3406..4d664fb 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.h +++ b/src/Drivers/DPDK/DpdkDriverImpl.h @@ -145,7 +145,8 @@ class DpdkDriver::Impl { void cork(); void uncork(); uint32_t receivePackets(uint32_t maxPackets, - Driver::Packet* receivedPackets[]); + Driver::Packet* receivedPackets[], + IpAddress sourceAddresses[]); void releasePackets(Driver::Packet* packets[], uint16_t numPackets); int getHighestPacketPriority(); uint32_t getMaxPayloadSize(); @@ -171,7 +172,7 @@ class DpdkDriver::Impl { uint16_t port; /// Address resolution table that translates IP addresses to MAC addresses. - std::unordered_map arpTable; + std::unordered_map arpTable; /// Stores the IpAddress of the driver. IpAddress localIp; diff --git a/src/Drivers/Fake/FakeDriver.cc b/src/Drivers/Fake/FakeDriver.cc index b6355cc..5cbafb8 100644 --- a/src/Drivers/Fake/FakeDriver.cc +++ b/src/Drivers/Fake/FakeDriver.cc @@ -56,19 +56,21 @@ static class FakeNetwork { /// Register the FakeNIC so it can receive packets. Returns the newly /// registered FakeNIC's addressId. - uint64_t registerNIC(FakeNIC* nic) + uint32_t registerNIC(FakeNIC* nic) { std::lock_guard lock(mutex); - uint64_t addressId = nextAddressId.fetch_add(1); - network.insert({addressId, nic}); + uint32_t addressId = nextAddressId.fetch_add(1); + IpAddress ipAddress{addressId}; + network.insert({ipAddress, nic}); return addressId; } /// Remove the FakeNIC from the network. - void deregisterNIC(uint64_t addressId) + void deregisterNIC(uint32_t addressId) { std::lock_guard lock(mutex); - network.erase(addressId); + IpAddress ipAddress{addressId}; + network.erase(ipAddress); } /// Deliver the provide packet to the specified destination. @@ -92,7 +94,7 @@ static class FakeNetwork { assert(nic != nullptr); std::lock_guard lock_nic(nic->mutex, std::adopt_lock); FakePacket* dstPacket = new FakePacket(*packet); - dstPacket->base.sourceIp = src; + dstPacket->sourceIp = src; assert(priority < NUM_PRIORITIES); assert(priority >= 0); nic->priorityQueue.at(priority).push_back(dstPacket); @@ -115,10 +117,10 @@ static class FakeNetwork { std::mutex mutex; /// Holds all the packets being sent through the fake network. - std::unordered_map network; + std::unordered_map network; /// Identifier for the next FakeDriver that "connects" to the FakeNetwork. - std::atomic nextAddressId; + std::atomic nextAddressId; /// Rate at which packets should be dropped when sent over this network. double packetLossRate; @@ -192,7 +194,7 @@ FakeDriver::allocPacket() void FakeDriver::sendPacket(Packet* packet, IpAddress destination, int priority) { - FakePacket* srcPacket = container_of(packet, FakePacket, base); + FakePacket* srcPacket = container_of(packet, &FakePacket::base); IpAddress srcAddress = getLocalAddress(); IpAddress dstAddress = destination; fakeNetwork.sendPacket(srcPacket, priority, srcAddress, dstAddress); @@ -203,7 +205,8 @@ FakeDriver::sendPacket(Packet* packet, IpAddress destination, int priority) * See Driver::receivePackets() */ uint32_t -FakeDriver::receivePackets(uint32_t maxPackets, Packet* receivedPackets[]) +FakeDriver::receivePackets(uint32_t maxPackets, Packet* receivedPackets[], + IpAddress sourceAddresses[]) { std::lock_guard lock_nic(nic.mutex); uint32_t numReceived = 0; @@ -212,6 +215,7 @@ FakeDriver::receivePackets(uint32_t maxPackets, Packet* receivedPackets[]) FakePacket* fakePacket = nic.priorityQueue.at(i).front(); nic.priorityQueue.at(i).pop_front(); receivedPackets[numReceived] = &fakePacket->base; + sourceAddresses[numReceived] = fakePacket->sourceIp; numReceived++; } } @@ -225,7 +229,7 @@ void FakeDriver::releasePackets(Packet* packets[], uint16_t numPackets) { for (uint16_t i = 0; i < numPackets; ++i) { - delete container_of(packets[i], FakePacket, base); + delete container_of(packets[i], &FakePacket::base); } } @@ -263,7 +267,7 @@ FakeDriver::getBandwidth() IpAddress FakeDriver::getLocalAddress() { - return localAddressId; + return IpAddress{localAddressId}; } /** diff --git a/src/Drivers/Fake/FakeDriverTest.cc b/src/Drivers/Fake/FakeDriverTest.cc index 2390abf..43802ae 100644 --- a/src/Drivers/Fake/FakeDriverTest.cc +++ b/src/Drivers/Fake/FakeDriverTest.cc @@ -27,7 +27,7 @@ namespace { TEST(FakeDriverTest, constructor) { - uint64_t nextAddressId = FakeDriver().localAddressId + 1; + uint32_t nextAddressId = FakeDriver().localAddressId + 1; FakeDriver driver; EXPECT_EQ(nextAddressId, driver.localAddressId); @@ -38,7 +38,7 @@ TEST(FakeDriverTest, allocPacket) FakeDriver driver; Driver::Packet* packet = driver.allocPacket(); // allocPacket doesn't do much so we just need to make sure we can call it. - delete container_of(packet, FakePacket, base); + delete container_of(packet, &FakePacket::base); } TEST(FakeDriverTest, sendPackets) @@ -54,7 +54,7 @@ TEST(FakeDriverTest, sendPackets) destinations[i] = driver2.getLocalAddress(); prio[i] = i; } - destinations[2] = IpAddress(42); + destinations[2] = IpAddress{42}; EXPECT_EQ(0U, driver2.nic.priorityQueue.at(0).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(1).size()); @@ -76,7 +76,7 @@ TEST(FakeDriverTest, sendPackets) EXPECT_EQ(0U, driver2.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(7).size()); { - Driver::Packet* packet = &driver2.nic.priorityQueue.at(0).front()->base; + FakePacket* packet = driver2.nic.priorityQueue.at(0).front(); EXPECT_EQ(driver1.getLocalAddress(), packet->sourceIp); } @@ -93,7 +93,7 @@ TEST(FakeDriverTest, sendPackets) EXPECT_EQ(0U, driver2.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(7).size()); - delete container_of(packets[2], FakePacket, base); + delete container_of(packets[2], &FakePacket::base); } TEST(FakeDriverTest, receivePackets) @@ -102,6 +102,7 @@ TEST(FakeDriverTest, receivePackets) FakeDriver driver; Driver::Packet* packets[4]; + IpAddress srcAddrs[4]; // 3 packets at priority 7 for (int i = 0; i < 3; ++i) @@ -123,7 +124,7 @@ TEST(FakeDriverTest, receivePackets) EXPECT_EQ(0U, driver.nic.priorityQueue.at(6).size()); EXPECT_EQ(3U, driver.nic.priorityQueue.at(7).size()); - EXPECT_EQ(4U, driver.receivePackets(4, packets)); + EXPECT_EQ(4U, driver.receivePackets(4, packets, srcAddrs)); driver.releasePackets(packets, 4); EXPECT_EQ(0U, driver.nic.priorityQueue.at(0).size()); @@ -135,7 +136,7 @@ TEST(FakeDriverTest, receivePackets) EXPECT_EQ(0U, driver.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver.nic.priorityQueue.at(7).size()); - EXPECT_EQ(1U, driver.receivePackets(1, packets)); + EXPECT_EQ(1U, driver.receivePackets(1, packets, srcAddrs)); driver.releasePackets(packets, 1); EXPECT_EQ(0U, driver.nic.priorityQueue.at(0).size()); @@ -158,7 +159,7 @@ TEST(FakeDriverTest, receivePackets) EXPECT_EQ(0U, driver.nic.priorityQueue.at(6).size()); EXPECT_EQ(1U, driver.nic.priorityQueue.at(7).size()); - EXPECT_EQ(1U, driver.receivePackets(1, packets)); + EXPECT_EQ(1U, driver.receivePackets(1, packets, srcAddrs)); driver.releasePackets(packets, 1); EXPECT_EQ(0U, driver.nic.priorityQueue.at(0).size()); @@ -170,7 +171,7 @@ TEST(FakeDriverTest, receivePackets) EXPECT_EQ(0U, driver.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver.nic.priorityQueue.at(7).size()); - EXPECT_EQ(3U, driver.receivePackets(4, packets)); + EXPECT_EQ(3U, driver.receivePackets(4, packets, srcAddrs)); driver.releasePackets(packets, 3); } @@ -199,9 +200,9 @@ TEST(FakeDriverTest, getBandwidth) TEST(FakeDriverTest, getLocalAddress) { - uint64_t nextAddressId = FakeDriver().localAddressId + 1; + uint32_t nextAddressId = FakeDriver().localAddressId + 1; FakeDriver driver; - EXPECT_EQ(nextAddressId, driver.getLocalAddress()); + EXPECT_EQ(nextAddressId, (uint32_t)driver.getLocalAddress()); } } // namespace diff --git a/src/Mock/MockDriver.h b/src/Mock/MockDriver.h index 35fd731..4080882 100644 --- a/src/Mock/MockDriver.h +++ b/src/Mock/MockDriver.h @@ -43,7 +43,8 @@ class MockDriver : public Driver { (override)); MOCK_METHOD(void, flushPackets, ()); MOCK_METHOD(uint32_t, receivePackets, - (uint32_t maxPackets, Packet* receivedPackets[]), (override)); + (uint32_t maxPackets, Packet* receivedPackets[], + IpAddress sourceAddresses[]), (override)); MOCK_METHOD(void, releasePackets, (Packet* packets[], uint16_t numPackets), (override)); MOCK_METHOD(int, getHighestPacketPriority, (), (override)); diff --git a/src/Policy.h b/src/Policy.h index 6c80c90..5339f32 100644 --- a/src/Policy.h +++ b/src/Policy.h @@ -107,7 +107,8 @@ class Manager { /// The scheduled policy for the Transport that owns this Policy::Manager. Scheduled localScheduledPolicy; /// Collection of the known Policies for each peered Homa::Transport; - std::unordered_map peerPolicies; + std::unordered_map + peerPolicies; /// Number of bytes that can be transmitted in one round-trip-time. const uint32_t RTT_BYTES; /// The highest network packet priority that the driver supports. diff --git a/src/PolicyTest.cc b/src/PolicyTest.cc index 88cdd45..4f23806 100644 --- a/src/PolicyTest.cc +++ b/src/PolicyTest.cc @@ -59,7 +59,7 @@ TEST(PolicyManagerTest, getUnscheduledPolicy) EXPECT_CALL(mockDriver, getBandwidth).WillOnce(Return(8000)); EXPECT_CALL(mockDriver, getHighestPacketPriority).WillOnce(Return(7)); Policy::Manager manager(&mockDriver); - IpAddress dest(22); + IpAddress dest{22}; { Policy::Unscheduled policy = manager.getUnscheduledPolicy(dest, 1); diff --git a/src/Protocol.h b/src/Protocol.h index 25471bb..ef2c723 100644 --- a/src/Protocol.h +++ b/src/Protocol.h @@ -104,38 +104,40 @@ enum Opcode { /** * This is the first part of the Homa packet header and is common to all - * versions of the protocol. The struct contains version information about the + * versions of the protocol. The first four bytes of the header store the source + * and destination ports, which is common for many transport layer protocols + * (e.g., TCP, UDP, etc.) The struct also contains version information about the * protocol used in the encompassing packet. The Transport should always send * this prefix and can always expect it when receiving a Homa packet. The prefix * is separated into its own struct because the Transport may need to know the * protocol version before interpreting the rest of the packet. */ struct HeaderPrefix { + uint16_t sport, dport;///< Transport layer (L4) source and destination ports + ///< in network byte order; only used by DataHeader. uint8_t version; ///< The version of the protocol being used by this ///< packet. /// HeaderPrefix constructor. - HeaderPrefix(uint8_t version) - : version(version) + HeaderPrefix(uint16_t sport, uint16_t dport, uint8_t version) + : sport(sport) + , dport(dport) + , version(version) {} } __attribute__((packed)); /** * Describes the wire format for header fields that are common to all packet - * types. Note: the first 4 bytes are identical for TCP, UDP, and Homa. + * types. */ struct CommonHeader { - uint16_t sport, dport;///< Transport layer (L4) source and destination ports - ///< in network byte order; only used by DataHeader. HeaderPrefix prefix; ///< Common to all versions of the protocol. uint8_t opcode; ///< One of the values of Opcode. MessageId messageId; ///< RemoteOp/Message associated with this packet. /// CommonHeader constructor. CommonHeader(Opcode opcode, MessageId messageId) - : sport(0) - , dport(0) - , prefix(1) + : prefix(0, 0, 1) , opcode(opcode) , messageId(messageId) {} @@ -170,8 +172,8 @@ struct DataHeader { , unscheduledIndexLimit(unscheduledIndexLimit) , index(index) { - common.sport = htobe16(sport); - common.dport = htobe16(dport); + common.prefix.sport = htobe16(sport); + common.prefix.dport = htobe16(dport); } } __attribute__((packed)); diff --git a/src/Receiver.cc b/src/Receiver.cc index d499a61..d007087 100644 --- a/src/Receiver.cc +++ b/src/Receiver.cc @@ -104,7 +104,7 @@ Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) SpinLock::Lock lock_allocator(messageAllocator.mutex); SocketAddress srcAddress = { .ip = sourceIp, - .port = be16toh(header->common.sport) + .port = be16toh(header->common.prefix.sport) }; message = messageAllocator.pool.construct( this, driver, dataHeaderLength, messageLength, id, @@ -126,7 +126,7 @@ Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) assert(id == message->id); assert(message->driver == driver); assert(message->source.ip == sourceIp); - assert(message->source.port == be16toh(header->common.sport)); + assert(message->source.port == be16toh(header->common.prefix.sport)); assert(message->messageLength == Util::downCast(header->totalLength)); // Add the packet diff --git a/src/Receiver.h b/src/Receiver.h index c97c462..65e65ff 100644 --- a/src/Receiver.h +++ b/src/Receiver.h @@ -474,7 +474,7 @@ class Receiver { /// Collection of all peers; used for fast access. Access is protected by /// the schedulerMutex. - std::unordered_map peerTable; + std::unordered_map peerTable; /// List of peers with inbound messages that require grants to complete. /// Access is protected by the schedulerMutex. diff --git a/src/ReceiverTest.cc b/src/ReceiverTest.cc index 213e2bd..da9e0bc 100644 --- a/src/ReceiverTest.cc +++ b/src/ReceiverTest.cc @@ -37,6 +37,9 @@ using ::testing::NiceMock; using ::testing::Pointee; using ::testing::Return; +/// Helper macro to construct an IpAddress from a numeric number. +#define IP(x) IpAddress{x} + class ReceiverTest : public ::testing::Test { public: ReceiverTest() @@ -105,21 +108,21 @@ TEST_F(ReceiverTest, handleDataPacket) header->totalLength = totalMessageLength; header->policyVersion = policyVersion; header->unscheduledIndexLimit = 1; - mockPacket.sourceIp = IpAddress(22); + IpAddress sourceIp{22}; // ------------------------------------------------------------------------- // Receive packet[1]. New message. header->index = 1; mockPacket.length = HEADER_SIZE + 1000; EXPECT_CALL(mockPolicyManager, - signalNewMessage(Eq(mockPacket.sourceIp), Eq(policyVersion), + signalNewMessage(Eq(sourceIp), Eq(policyVersion), Eq(totalMessageLength))) .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(0); // TEST CALL - receiver->handleDataPacket(&mockPacket, mockPacket.sourceIp); + receiver->handleDataPacket(&mockPacket, sourceIp); // --------- { @@ -148,7 +151,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(1); // TEST CALL - receiver->handleDataPacket(&mockPacket, mockPacket.sourceIp); + receiver->handleDataPacket(&mockPacket, sourceIp); // --------- EXPECT_EQ(1U, message->numPackets); @@ -162,7 +165,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(0); // TEST CALL - receiver->handleDataPacket(&mockPacket, mockPacket.sourceIp); + receiver->handleDataPacket(&mockPacket, sourceIp); // --------- EXPECT_EQ(2U, message->numPackets); @@ -177,7 +180,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(0); // TEST CALL - receiver->handleDataPacket(&mockPacket, mockPacket.sourceIp); + receiver->handleDataPacket(&mockPacket, sourceIp); // --------- EXPECT_EQ(3U, message->numPackets); @@ -192,7 +195,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(0); // TEST CALL - receiver->handleDataPacket(&mockPacket, mockPacket.sourceIp); + receiver->handleDataPacket(&mockPacket, sourceIp); // --------- EXPECT_EQ(4U, message->numPackets); @@ -207,7 +210,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(1); // TEST CALL - receiver->handleDataPacket(&mockPacket, mockPacket.sourceIp); + receiver->handleDataPacket(&mockPacket, sourceIp); // --------- Mock::VerifyAndClearExpectations(&mockDriver); @@ -251,7 +254,7 @@ TEST_F(ReceiverTest, handleBusyPacket_unknown) TEST_F(ReceiverTest, handlePingPacket_basic) { Protocol::MessageId id(42, 32); - IpAddress mockAddress = 22; + IpAddress mockAddress{22}; Receiver::Message* message = receiver->messageAllocator.pool.construct( receiver, &mockDriver, 0, 20000, id, SocketAddress{mockAddress, 0}, 0); ASSERT_TRUE(message->scheduled); @@ -264,7 +267,7 @@ TEST_F(ReceiverTest, handlePingPacket_basic) char pingPayload[1028]; Homa::Mock::MockDriver::MockPacket pingPacket {pingPayload}; - pingPacket.sourceIp = mockAddress; + IpAddress sourceIp = mockAddress; Protocol::Packet::PingHeader* pingHeader = (Protocol::Packet::PingHeader*)pingPacket.payload; pingHeader->common.messageId = id; @@ -277,7 +280,7 @@ TEST_F(ReceiverTest, handlePingPacket_basic) EXPECT_CALL(mockDriver, releasePackets(Pointee(&pingPacket), Eq(1))) .Times(1); - receiver->handlePingPacket(&pingPacket, pingPacket.sourceIp); + receiver->handlePingPacket(&pingPacket, sourceIp); EXPECT_EQ(11000U, message->messageTimeout.expirationCycleTime); EXPECT_EQ(0U, message->resendTimeout.expirationCycleTime); @@ -296,8 +299,7 @@ TEST_F(ReceiverTest, handlePingPacket_unknown) char pingPayload[1028]; Homa::Mock::MockDriver::MockPacket pingPacket {pingPayload}; - IpAddress mockAddress = (IpAddress)22; - pingPacket.sourceIp = mockAddress; + IpAddress mockAddress{22}; Protocol::Packet::PingHeader* pingHeader = (Protocol::Packet::PingHeader*)pingPacket.payload; pingHeader->common.messageId = id; @@ -310,7 +312,7 @@ TEST_F(ReceiverTest, handlePingPacket_unknown) EXPECT_CALL(mockDriver, releasePackets(Pointee(&pingPacket), Eq(1))) .Times(1); - receiver->handlePingPacket(&pingPacket, pingPacket.sourceIp); + receiver->handlePingPacket(&pingPacket, mockAddress); Protocol::Packet::UnknownHeader* header = (Protocol::Packet::UnknownHeader*)payload; @@ -864,11 +866,11 @@ TEST_F(ReceiverTest, trySendGrants) { Receiver::Message* message[4]; Receiver::ScheduledMessageInfo* info[4]; - for (uint64_t i = 0; i < 4; ++i) { + for (uint32_t i = 0; i < 4; ++i) { Protocol::MessageId id = {42, 10 + i}; message[i] = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), - 10000 * (i + 1), id, SocketAddress{IpAddress(100 + i), 60001}, + 10000 * (i + 1), id, SocketAddress{IP(100 + i), 60001}, 10 * (i + 1)); { SpinLock::Lock lock_scheduler(receiver->schedulerMutex); @@ -996,7 +998,7 @@ TEST_F(ReceiverTest, schedule) receiver->schedule(message[0], lock); - EXPECT_EQ(&receiver->peerTable.at(22), info[0]->peer); + EXPECT_EQ(&receiver->peerTable.at(IP(22)), info[0]->peer); EXPECT_EQ(message[0], &info[0]->peer->scheduledMessages.front()); EXPECT_EQ(info[0]->peer, &receiver->scheduledPeers.front()); @@ -1008,7 +1010,7 @@ TEST_F(ReceiverTest, schedule) receiver->schedule(message[1], lock); - EXPECT_EQ(&receiver->peerTable.at(33), info[1]->peer); + EXPECT_EQ(&receiver->peerTable.at(IP(33)), info[1]->peer); EXPECT_EQ(message[1], &info[1]->peer->scheduledMessages.front()); EXPECT_EQ(info[1]->peer, &receiver->scheduledPeers.back()); @@ -1020,7 +1022,7 @@ TEST_F(ReceiverTest, schedule) receiver->schedule(message[2], lock); - EXPECT_EQ(&receiver->peerTable.at(33), info[2]->peer); + EXPECT_EQ(&receiver->peerTable.at(IP(33)), info[2]->peer); EXPECT_EQ(message[2], &info[2]->peer->scheduledMessages.front()); EXPECT_EQ(info[2]->peer, &receiver->scheduledPeers.front()); @@ -1032,7 +1034,7 @@ TEST_F(ReceiverTest, schedule) receiver->schedule(message[3], lock); - EXPECT_EQ(&receiver->peerTable.at(22), info[3]->peer); + EXPECT_EQ(&receiver->peerTable.at(IP(22)), info[3]->peer); EXPECT_EQ(message[3], &info[3]->peer->scheduledMessages.back()); EXPECT_EQ(info[3]->peer, &receiver->scheduledPeers.back()); } @@ -1043,23 +1045,24 @@ TEST_F(ReceiverTest, unschedule) Receiver::ScheduledMessageInfo* info[5]; SpinLock::Lock lock(receiver->schedulerMutex); int messageLength[5] = {10, 20, 30, 10, 20}; - for (uint64_t i = 0; i < 5; ++i) { + for (uint32_t i = 0; i < 5; ++i) { Protocol::MessageId id = {42, 10 + i}; - IpAddress source = IpAddress((i / 3) + 10); + IpAddress source = IP((i / 3) + 10); message[i] = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), messageLength[i], id, SocketAddress{source, 60001}, 0); info[i] = &message[i]->scheduledMessageInfo; receiver->schedule(message[i], lock); } + auto& scheduledPeers = receiver->scheduledPeers; - ASSERT_EQ(IpAddress(10), message[0]->source.ip); - ASSERT_EQ(IpAddress(10), message[1]->source.ip); - ASSERT_EQ(IpAddress(10), message[2]->source.ip); - ASSERT_EQ(IpAddress(11), message[3]->source.ip); - ASSERT_EQ(IpAddress(11), message[4]->source.ip); - ASSERT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(10)); - ASSERT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); + ASSERT_EQ(IP(10), message[0]->source.ip); + ASSERT_EQ(IP(10), message[1]->source.ip); + ASSERT_EQ(IP(10), message[2]->source.ip); + ASSERT_EQ(IP(11), message[3]->source.ip); + ASSERT_EQ(IP(11), message[4]->source.ip); + ASSERT_EQ(&scheduledPeers.front(), &receiver->peerTable.at(IP(10))); + ASSERT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(11))); // <10>: [0](10) -> [1](20) -> [2](30) // <11>: [3](10) -> [4](20) @@ -1077,10 +1080,10 @@ TEST_F(ReceiverTest, unschedule) receiver->unschedule(message[4], lock); EXPECT_EQ(nullptr, info[4]->peer); - EXPECT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(10)); - EXPECT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); - EXPECT_EQ(3U, receiver->peerTable.at(10).scheduledMessages.size()); - EXPECT_EQ(1U, receiver->peerTable.at(11).scheduledMessages.size()); + EXPECT_EQ(&scheduledPeers.front(), &receiver->peerTable.at(IP(10))); + EXPECT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(11))); + EXPECT_EQ(3U, receiver->peerTable.at(IP(10)).scheduledMessages.size()); + EXPECT_EQ(1U, receiver->peerTable.at(IP(11)).scheduledMessages.size()); //-------------------------------------------------------------------------- // Remove message[1]; peer in correct position. @@ -1090,10 +1093,10 @@ TEST_F(ReceiverTest, unschedule) receiver->unschedule(message[1], lock); EXPECT_EQ(nullptr, info[1]->peer); - EXPECT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(10)); - EXPECT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); - EXPECT_EQ(2U, receiver->peerTable.at(10).scheduledMessages.size()); - EXPECT_EQ(1U, receiver->peerTable.at(11).scheduledMessages.size()); + EXPECT_EQ(&scheduledPeers.front(), &receiver->peerTable.at(IP(10))); + EXPECT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(11))); + EXPECT_EQ(2U, receiver->peerTable.at(IP(10)).scheduledMessages.size()); + EXPECT_EQ(1U, receiver->peerTable.at(IP(11)).scheduledMessages.size()); //-------------------------------------------------------------------------- // Remove message[0]; peer needs to be reordered. @@ -1103,10 +1106,10 @@ TEST_F(ReceiverTest, unschedule) receiver->unschedule(message[0], lock); EXPECT_EQ(nullptr, info[0]->peer); - EXPECT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(11)); - EXPECT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(10)); - EXPECT_EQ(1U, receiver->peerTable.at(11).scheduledMessages.size()); - EXPECT_EQ(1U, receiver->peerTable.at(10).scheduledMessages.size()); + EXPECT_EQ(&scheduledPeers.front(), &receiver->peerTable.at(IP(11))); + EXPECT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(10))); + EXPECT_EQ(1U, receiver->peerTable.at(IP(11)).scheduledMessages.size()); + EXPECT_EQ(1U, receiver->peerTable.at(IP(10)).scheduledMessages.size()); //-------------------------------------------------------------------------- // Remove message[3]; peer needs to be removed. @@ -1115,10 +1118,10 @@ TEST_F(ReceiverTest, unschedule) receiver->unschedule(message[3], lock); EXPECT_EQ(nullptr, info[3]->peer); - EXPECT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(10)); - EXPECT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(10)); - EXPECT_EQ(1U, receiver->peerTable.at(10).scheduledMessages.size()); - EXPECT_EQ(0U, receiver->peerTable.at(11).scheduledMessages.size()); + EXPECT_EQ(&scheduledPeers.front(), &receiver->peerTable.at(IP(10))); + EXPECT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(10))); + EXPECT_EQ(1U, receiver->peerTable.at(IP(10)).scheduledMessages.size()); + EXPECT_EQ(0U, receiver->peerTable.at(IP(11)).scheduledMessages.size()); } TEST_F(ReceiverTest, updateSchedule) @@ -1127,10 +1130,10 @@ TEST_F(ReceiverTest, updateSchedule) // 11 : [20][30] SpinLock::Lock lock(receiver->schedulerMutex); Receiver::Message* other[3]; - for (uint64_t i = 0; i < 3; ++i) { + for (uint32_t i = 0; i < 3; ++i) { Protocol::MessageId id = {42, 10 + i}; int messageLength = 10 * (i + 1); - IpAddress source = IpAddress(((i + 1) / 2) + 10); + IpAddress source = IP(((i + 1) / 2) + 10); other[i] = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), 10 * (i + 1), id, SocketAddress{source, 60001}, 0); @@ -1140,12 +1143,13 @@ TEST_F(ReceiverTest, updateSchedule) receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), 100, Protocol::MessageId(42, 1), SocketAddress{11, 60001}, 0); receiver->schedule(message, lock); - ASSERT_EQ(&receiver->peerTable.at(10), other[0]->scheduledMessageInfo.peer); - ASSERT_EQ(&receiver->peerTable.at(11), other[1]->scheduledMessageInfo.peer); - ASSERT_EQ(&receiver->peerTable.at(11), other[2]->scheduledMessageInfo.peer); - ASSERT_EQ(&receiver->peerTable.at(11), message->scheduledMessageInfo.peer); - ASSERT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(10)); - ASSERT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); + auto& peerTable = receiver->peerTable; + ASSERT_EQ(&peerTable.at(IP(10)), other[0]->scheduledMessageInfo.peer); + ASSERT_EQ(&peerTable.at(IP(11)), other[1]->scheduledMessageInfo.peer); + ASSERT_EQ(&peerTable.at(IP(11)), other[2]->scheduledMessageInfo.peer); + ASSERT_EQ(&peerTable.at(IP(11)), message->scheduledMessageInfo.peer); + ASSERT_EQ(&receiver->scheduledPeers.front(), &peerTable.at(IP(10))); + ASSERT_EQ(&receiver->scheduledPeers.back(), &peerTable.at(IP(11))); //-------------------------------------------------------------------------- // Move message up within peer. @@ -1155,11 +1159,12 @@ TEST_F(ReceiverTest, updateSchedule) receiver->updateSchedule(message, lock); - EXPECT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); + auto& scheduledPeers = receiver->scheduledPeers; + EXPECT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(11))); Receiver::Peer* peer = &receiver->scheduledPeers.back(); auto it = peer->scheduledMessages.begin(); EXPECT_TRUE( - std::next(receiver->peerTable.at(11).scheduledMessages.begin()) == + std::next(receiver->peerTable.at(IP(11)).scheduledMessages.begin()) == message->scheduledMessageInfo.peer->scheduledMessages.get( &message->scheduledMessageInfo.scheduledMessageNode)); @@ -1171,8 +1176,8 @@ TEST_F(ReceiverTest, updateSchedule) receiver->updateSchedule(message, lock); - EXPECT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); - EXPECT_EQ(receiver->peerTable.at(11).scheduledMessages.begin(), + EXPECT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(11))); + EXPECT_EQ(receiver->peerTable.at(IP(11)).scheduledMessages.begin(), message->scheduledMessageInfo.peer->scheduledMessages.get( &message->scheduledMessageInfo.scheduledMessageNode)); @@ -1184,8 +1189,8 @@ TEST_F(ReceiverTest, updateSchedule) receiver->updateSchedule(message, lock); - EXPECT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(11)); - EXPECT_EQ(receiver->peerTable.at(11).scheduledMessages.begin(), + EXPECT_EQ(&scheduledPeers.front(), &receiver->peerTable.at(IP(11))); + EXPECT_EQ(receiver->peerTable.at(IP(11)).scheduledMessages.begin(), message->scheduledMessageInfo.peer->scheduledMessages.get( &message->scheduledMessageInfo.scheduledMessageNode)); } diff --git a/src/SenderTest.cc b/src/SenderTest.cc index 244a7c9..dfb216e 100644 --- a/src/SenderTest.cc +++ b/src/SenderTest.cc @@ -1367,8 +1367,8 @@ TEST_F(SenderTest, sendMessage_basic) // Check packet metadata Protocol::Packet::DataHeader* header = static_cast(mockPacket.payload); - EXPECT_EQ(htobe16(sport), header->common.sport); - EXPECT_EQ(htobe16(dport), header->common.dport); + EXPECT_EQ(htobe16(sport), header->common.prefix.sport); + EXPECT_EQ(htobe16(dport), header->common.prefix.dport); EXPECT_EQ(id, header->common.messageId); EXPECT_EQ(420U, header->totalLength); EXPECT_EQ(policy.version, header->policyVersion); diff --git a/src/TransportImpl.cc b/src/TransportImpl.cc index d4ebc70..a380944 100644 --- a/src/TransportImpl.cc +++ b/src/TransportImpl.cc @@ -98,10 +98,10 @@ TransportImpl::processPackets() const int MAX_BURST = 32; Driver::Packet* packets[MAX_BURST]; - int numPackets = driver->receivePackets(MAX_BURST, packets); + IpAddress srcAddrs[MAX_BURST]; + int numPackets = driver->receivePackets(MAX_BURST, packets, srcAddrs); for (int i = 0; i < numPackets; ++i) { - Driver::Packet* packet = packets[i]; - processPacket(packet, packet->sourceIp); + processPacket(packets[i], srcAddrs[i]); } cycles = PerfUtils::Cycles::rdtsc() - cycles; diff --git a/src/Util.cc b/src/Util.cc index fe73752..90ee9f4 100644 --- a/src/Util.cc +++ b/src/Util.cc @@ -100,21 +100,5 @@ hexDump(const void* buf, uint64_t bytes) return output.str(); } -std::string -ipToString(uint32_t ip) -{ - return StringUtil::format("%d.%d.%d.%d", - (ip >> 24) & 0xff, (ip >> 16) & 0xff, (ip >> 8) & 0xff, ip & 0xff); -} - -uint32_t -stringToIp(const char* ipStr) -{ - unsigned int bytes[4]; - sscanf(ipStr, "%u.%u.%u.%u", &bytes[0], &bytes[1], &bytes[2], &bytes[3]); - return (bytes[0] << 24) | (bytes[1] << 16) | (bytes[2] << 8) | bytes[3]; -} - - } // namespace Util } // namespace Homa diff --git a/test/Output.h b/test/Output.h index bea8f8b..de8d740 100644 --- a/test/Output.h +++ b/test/Output.h @@ -1,3 +1,18 @@ +/* Copyright (c) 2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + #pragma once #include diff --git a/test/dpdk_test.cc b/test/dpdk_test.cc index ebf9ba2..4ca1a82 100644 --- a/test/dpdk_test.cc +++ b/test/dpdk_test.cc @@ -1,3 +1,18 @@ +/* Copyright (c) 2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + #include #include @@ -37,23 +52,24 @@ main(int argc, char* argv[]) Homa::Drivers::DPDK::DpdkDriver driver(iface.c_str()); if (isServer) { - std::cout << Homa::Util::ipToString(driver.getLocalAddress()) + std::cout << Homa::IpAddress::toString(driver.getLocalAddress()) << std::endl; while (true) { Homa::Driver::Packet* incoming[10]; + Homa::IpAddress srcAddrs[10]; uint32_t receivedPackets; do { - receivedPackets = driver.receivePackets(10, incoming); + receivedPackets = driver.receivePackets(10, incoming, srcAddrs); } while (receivedPackets == 0); Homa::Driver::Packet* pong = driver.allocPacket(); pong->length = 100; - driver.sendPacket(pong, incoming[0]->sourceIp, 0); + driver.sendPacket(pong, srcAddrs[0], 0); driver.releasePackets(incoming, receivedPackets); driver.releasePackets(&pong, 1); } } else { Homa::IpAddress server_ip = - Homa::Util::stringToIp(server_ip_string.c_str()); + Homa::IpAddress::fromString(server_ip_string.c_str()); std::vector times; for (int i = 0; i < 100000; ++i) { uint64_t start = PerfUtils::Cycles::rdtsc(); @@ -67,9 +83,10 @@ main(int argc, char* argv[]) driver.releasePackets(&ping, 1); PerfUtils::TimeTrace::record("releasePacket"); Homa::Driver::Packet* incoming[10]; + Homa::IpAddress srcAddrs[10]; uint32_t receivedPackets; do { - receivedPackets = driver.receivePackets(10, incoming); + receivedPackets = driver.receivePackets(10, incoming, srcAddrs); PerfUtils::TimeTrace::record("receivePackets"); } while (receivedPackets == 0); driver.releasePackets(incoming, receivedPackets); From c1308dc5666e2862c5f14b778602a12c5f83d8bf Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Mon, 5 Oct 2020 23:48:31 -0700 Subject: [PATCH 05/33] fixed copyrights and formatting issues --- include/Homa/Driver.h | 8 ++-- include/Homa/Drivers/Fake/FakeDriver.h | 9 ++-- include/Homa/Util.h | 9 ++-- src/ControlPacket.h | 2 +- src/Driver.cc | 6 +-- src/Drivers/DPDK/DpdkDriver.cc | 3 +- src/Drivers/DPDK/DpdkDriverImpl.cc | 45 +++++++++++--------- src/Drivers/DPDK/MacAddress.cc | 2 +- src/Drivers/DPDK/MacAddress.h | 2 +- src/Drivers/DPDK/MacAddressTest.cc | 2 +- src/Drivers/Fake/FakeDriver.cc | 2 +- src/Drivers/Fake/FakeDriverTest.cc | 2 +- src/Mock/MockDriver.h | 7 +-- src/Mock/MockPolicy.h | 2 +- src/Mock/MockReceiver.h | 8 ++-- src/Mock/MockSender.h | 13 +++--- src/Policy.h | 3 +- src/PolicyTest.cc | 2 +- src/Protocol.h | 5 ++- src/Receiver.cc | 16 +++---- src/ReceiverTest.cc | 47 +++++++++++--------- src/Sender.cc | 4 +- src/SenderTest.cc | 59 ++++++++++++++------------ src/TransportImplTest.cc | 16 +++---- test/system_test.cc | 2 +- 25 files changed, 147 insertions(+), 129 deletions(-) diff --git a/include/Homa/Driver.h b/include/Homa/Driver.h index a5cc855..67df785 100644 --- a/include/Homa/Driver.h +++ b/include/Homa/Driver.h @@ -83,7 +83,7 @@ struct PacketSpec { /// Number of bytes in the payload. int32_t length; -} __attribute__((packed)); +} __attribute__((packed)); static_assert(std::is_trivial()); /** @@ -140,7 +140,7 @@ class Driver { * getHighestPacketPriority(). */ virtual void sendPacket(Packet* packet, IpAddress destination, - int priority) = 0; + int priority) = 0; /** * Request that the Driver enter the "corked" mode where outbound packets @@ -170,12 +170,12 @@ class Driver { * * @param maxPackets * The maximum number of Packet objects that should be returned by - * this method. + * this method. * @param[out] receivedPackets * Received packets are appended to this array in order of arrival. * @param[out] sourceAddresses * Source IP addresses of the received packets are appended to this - * array in order of arrival. + * array in order of arrival. * * @return * Number of Packet objects being returned. diff --git a/include/Homa/Drivers/Fake/FakeDriver.h b/include/Homa/Drivers/Fake/FakeDriver.h index dd01261..5f54586 100644 --- a/include/Homa/Drivers/Fake/FakeDriver.h +++ b/include/Homa/Drivers/Fake/FakeDriver.h @@ -65,8 +65,7 @@ struct FakePacket { * FakePacket constructor. */ explicit FakePacket() - : base{.payload = buf, - .length = 0} + : base{.payload = buf, .length = 0} , buf() , sourceIp() {} @@ -75,8 +74,7 @@ struct FakePacket { * Copy constructor. */ FakePacket(const FakePacket& other) - : base{.payload = buf, - .length = other.base.length} + : base{.payload = buf, .length = other.base.length} , buf() , sourceIp() { @@ -112,7 +110,8 @@ class FakeDriver : public Driver { virtual ~FakeDriver(); virtual Packet* allocPacket(); - virtual void sendPacket(Packet* packet, IpAddress destination, int priority); + virtual void sendPacket(Packet* packet, IpAddress destination, + int priority); virtual uint32_t receivePackets(uint32_t maxPackets, Packet* receivedPackets[], IpAddress sourceAddresses[]); diff --git a/include/Homa/Util.h b/include/Homa/Util.h index ba757e6..4f17acc 100644 --- a/include/Homa/Util.h +++ b/include/Homa/Util.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2009-2018, Stanford University +/* Copyright (c) 2009-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -22,10 +22,11 @@ #include /// Cast a member of a structure out to the containing structure. -template -P* container_of(M* ptr, const M P::*member) +template +P* +container_of(M* ptr, const M P::*member) { - return (P*)((char*) ptr - (size_t) &(reinterpret_cast(0)->*member)); + return (P*)((char*)ptr - (size_t) & (reinterpret_cast(0)->*member)); } namespace Homa { diff --git a/src/ControlPacket.h b/src/ControlPacket.h index bc53f10..17310af 100644 --- a/src/ControlPacket.h +++ b/src/ControlPacket.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above diff --git a/src/Driver.cc b/src/Driver.cc index c7d61cb..b29c828 100644 --- a/src/Driver.cc +++ b/src/Driver.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2018-2019, Stanford University +/* Copyright (c) 2018-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -23,8 +23,8 @@ std::string IpAddress::toString(IpAddress address) { uint32_t ip = address.addr; - return StringUtil::format("%d.%d.%d.%d", - (ip >> 24) & 0xff, (ip >> 16) & 0xff, (ip >> 8) & 0xff, ip & 0xff); + return StringUtil::format("%d.%d.%d.%d", (ip >> 24) & 0xff, + (ip >> 16) & 0xff, (ip >> 8) & 0xff, ip & 0xff); } IpAddress diff --git a/src/Drivers/DPDK/DpdkDriver.cc b/src/Drivers/DPDK/DpdkDriver.cc index 3c8833a..c27d1df 100644 --- a/src/Drivers/DPDK/DpdkDriver.cc +++ b/src/Drivers/DPDK/DpdkDriver.cc @@ -30,7 +30,8 @@ DpdkDriver::DpdkDriver(const char* ifname, int argc, char* argv[], : pImpl(new Impl(ifname, argc, argv, config)) {} -DpdkDriver::DpdkDriver(const char* ifname, NoEalInit _, const Config* const config) +DpdkDriver::DpdkDriver(const char* ifname, NoEalInit _, + const Config* const config) : pImpl(new Impl(ifname, _, config)) {} diff --git a/src/Drivers/DPDK/DpdkDriverImpl.cc b/src/Drivers/DPDK/DpdkDriverImpl.cc index 42a2340..e9fef18 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.cc +++ b/src/Drivers/DPDK/DpdkDriverImpl.cc @@ -17,19 +17,19 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ -#include -#include #include #include +#include #include +#include #include "DpdkDriverImpl.h" #include #include "CodeLocation.h" -#include "StringUtil.h" #include "Homa/Util.h" +#include "StringUtil.h" namespace Homa { @@ -51,8 +51,7 @@ const char* default_eal_argv[] = {"homa", NULL}; * Memory location in the mbuf where the packet data should be stored. */ DpdkDriver::Impl::Packet::Packet(struct rte_mbuf* mbuf, void* data) - : base {.payload = data, - .length = 0} + : base{.payload = data, .length = 0} , bufType(MBUF) , bufRef() { @@ -66,8 +65,7 @@ DpdkDriver::Impl::Packet::Packet(struct rte_mbuf* mbuf, void* data) * Overflow buffer that holds this packet. */ DpdkDriver::Impl::Packet::Packet(OverflowBuffer* overflowBuf) - : base {.payload = overflowBuf->data, - .length = 0} + : base{.payload = overflowBuf->data, .length = 0} , bufType(OVERFLOW_BUF) , bufRef() { @@ -212,7 +210,8 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet, IpAddress destination, numMbufsAvail, numMbufsInUse); return; } - char* buf = rte_pktmbuf_append(mbuf, + char* buf = rte_pktmbuf_append( + mbuf, Homa::Util::downCast(PACKET_HDR_LEN + pkt->base.length)); if (unlikely(NULL == buf)) { WARNING("rte_pktmbuf_append call failed; dropping packet"); @@ -506,9 +505,10 @@ DpdkDriver::Impl::_init() char mask[100]; char dev[100]; int type, flags; - int cols = sscanf(line.c_str(), "%s 0x%x 0x%x %99s %99s %99s\n", - ip, &type, &flags, hwa, mask, dev); - if (cols != 6) continue; + int cols = sscanf(line.c_str(), "%s 0x%x 0x%x %99s %99s %99s\n", ip, + &type, &flags, hwa, mask, dev); + if (cols != 6) + continue; arpTable.emplace(IpAddress::fromString(ip), hwa); } @@ -517,28 +517,32 @@ DpdkDriver::Impl::_init() ifname.copy(ifr.ifr_name, ifname.length()); ifr.ifr_name[ifname.length() + 1] = 0; if (ifname.length() >= sizeof(ifr.ifr_name)) { - throw DriverInitFailure(HERE_STR, + throw DriverInitFailure( + HERE_STR, StringUtil::format("Interface name %s too long", ifname.c_str())); } int fd = socket(AF_INET, SOCK_DGRAM, 0); if (fd == -1) { - throw DriverInitFailure(HERE_STR, + throw DriverInitFailure( + HERE_STR, StringUtil::format("Failed to create socket: %s", strerror(errno))); } if (ioctl(fd, SIOCGIFADDR, &ifr) == -1) { char* error = strerror(errno); close(fd); - throw DriverInitFailure(HERE_STR, + throw DriverInitFailure( + HERE_STR, StringUtil::format("Failed to obtain IP address: %s", error)); } - localIp = {be32toh(((struct sockaddr_in*) &ifr.ifr_addr)->sin_addr.s_addr)}; + localIp = {be32toh(((struct sockaddr_in*)&ifr.ifr_addr)->sin_addr.s_addr)}; if (ioctl(fd, SIOCGIFHWADDR, &ifr) == -1) { char* error = strerror(errno); close(fd); - throw DriverInitFailure(HERE_STR, + throw DriverInitFailure( + HERE_STR, StringUtil::format("Failed to obtain MAC address: %s", error)); } close(fd); @@ -546,7 +550,8 @@ DpdkDriver::Impl::_init() // Iterate over ethernet devices to locate the port identifier. int p; - RTE_ETH_FOREACH_DEV(p) { + RTE_ETH_FOREACH_DEV(p) + { struct ether_addr mac; rte_eth_macaddr_get(p, &mac); if (MacAddress(mac.addr_bytes) == localMac) { @@ -554,9 +559,9 @@ DpdkDriver::Impl::_init() break; } } - NOTICE("Using interface %s, ip %s, mac %s, port %u", - ifname.c_str(), IpAddress::toString(localIp).c_str(), - localMac.toString().c_str(), port); + NOTICE("Using interface %s, ip %s, mac %s, port %u", ifname.c_str(), + IpAddress::toString(localIp).c_str(), localMac.toString().c_str(), + port); std::string poolName = StringUtil::format("homa_mbuf_pool_%u", port); std::string ringName = StringUtil::format("homa_loopback_ring_%u", port); diff --git a/src/Drivers/DPDK/MacAddress.cc b/src/Drivers/DPDK/MacAddress.cc index 63149fa..e47f27a 100644 --- a/src/Drivers/DPDK/MacAddress.cc +++ b/src/Drivers/DPDK/MacAddress.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2011-2019, Stanford University +/* Copyright (c) 2011-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above diff --git a/src/Drivers/DPDK/MacAddress.h b/src/Drivers/DPDK/MacAddress.h index 148f2ce..33f47a5 100644 --- a/src/Drivers/DPDK/MacAddress.h +++ b/src/Drivers/DPDK/MacAddress.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2011-2019, Stanford University +/* Copyright (c) 2011-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above diff --git a/src/Drivers/DPDK/MacAddressTest.cc b/src/Drivers/DPDK/MacAddressTest.cc index 7587a16..9b8b8ae 100644 --- a/src/Drivers/DPDK/MacAddressTest.cc +++ b/src/Drivers/DPDK/MacAddressTest.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2011-2019, Stanford University +/* Copyright (c) 2011-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above diff --git a/src/Drivers/Fake/FakeDriver.cc b/src/Drivers/Fake/FakeDriver.cc index 5cbafb8..26cb102 100644 --- a/src/Drivers/Fake/FakeDriver.cc +++ b/src/Drivers/Fake/FakeDriver.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above diff --git a/src/Drivers/Fake/FakeDriverTest.cc b/src/Drivers/Fake/FakeDriverTest.cc index 43802ae..cd64917 100644 --- a/src/Drivers/Fake/FakeDriverTest.cc +++ b/src/Drivers/Fake/FakeDriverTest.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above diff --git a/src/Mock/MockDriver.h b/src/Mock/MockDriver.h index 4080882..9ea6ffe 100644 --- a/src/Mock/MockDriver.h +++ b/src/Mock/MockDriver.h @@ -39,13 +39,14 @@ class MockDriver : public Driver { MOCK_METHOD(Packet*, allocPacket, (), (override)); MOCK_METHOD(void, sendPacket, - (Packet* packet, IpAddress destination, int priority), + (Packet * packet, IpAddress destination, int priority), (override)); MOCK_METHOD(void, flushPackets, ()); MOCK_METHOD(uint32_t, receivePackets, (uint32_t maxPackets, Packet* receivedPackets[], - IpAddress sourceAddresses[]), (override)); - MOCK_METHOD(void, releasePackets, (Packet* packets[], uint16_t numPackets), + IpAddress sourceAddresses[]), + (override)); + MOCK_METHOD(void, releasePackets, (Packet * packets[], uint16_t numPackets), (override)); MOCK_METHOD(int, getHighestPacketPriority, (), (override)); MOCK_METHOD(uint32_t, getMaxPayloadSize, (), (override)); diff --git a/src/Mock/MockPolicy.h b/src/Mock/MockPolicy.h index 52cb2a5..32e7be8 100644 --- a/src/Mock/MockPolicy.h +++ b/src/Mock/MockPolicy.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above diff --git a/src/Mock/MockReceiver.h b/src/Mock/MockReceiver.h index 75eea2c..61c21ce 100644 --- a/src/Mock/MockReceiver.h +++ b/src/Mock/MockReceiver.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -37,10 +37,10 @@ class MockReceiver : public Core::Receiver { {} MOCK_METHOD(void, handleDataPacket, - (Driver::Packet* packet, IpAddress sourceIp), (override)); - MOCK_METHOD(void, handleBusyPacket, (Driver::Packet* packet), (override)); + (Driver::Packet * packet, IpAddress sourceIp), (override)); + MOCK_METHOD(void, handleBusyPacket, (Driver::Packet * packet), (override)); MOCK_METHOD(void, handlePingPacket, - (Driver::Packet* packet, IpAddress sourceIp), (override)); + (Driver::Packet * packet, IpAddress sourceIp), (override)); MOCK_METHOD(Homa::InMessage*, receiveMessage, (), (override)); MOCK_METHOD(void, poll, (), (override)); MOCK_METHOD(uint64_t, checkTimeouts, (), (override)); diff --git a/src/Mock/MockSender.h b/src/Mock/MockSender.h index 4a8bd27..cb29c90 100644 --- a/src/Mock/MockSender.h +++ b/src/Mock/MockSender.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -38,12 +38,13 @@ class MockSender : public Core::Sender { {} MOCK_METHOD(Homa::OutMessage*, allocMessage, (uint16_t sport), (override)); - MOCK_METHOD(void, handleDonePacket, (Driver::Packet* packet), (override)); - MOCK_METHOD(void, handleGrantPacket, (Driver::Packet* packet), (override)); - MOCK_METHOD(void, handleResendPacket, (Driver::Packet* packet), (override)); - MOCK_METHOD(void, handleUnknownPacket, (Driver::Packet* packet), + MOCK_METHOD(void, handleDonePacket, (Driver::Packet * packet), (override)); + MOCK_METHOD(void, handleGrantPacket, (Driver::Packet * packet), (override)); + MOCK_METHOD(void, handleResendPacket, (Driver::Packet * packet), (override)); - MOCK_METHOD(void, handleErrorPacket, (Driver::Packet* packet), (override)); + MOCK_METHOD(void, handleUnknownPacket, (Driver::Packet * packet), + (override)); + MOCK_METHOD(void, handleErrorPacket, (Driver::Packet * packet), (override)); MOCK_METHOD(void, poll, (), (override)); MOCK_METHOD(uint64_t, checkTimeouts, (), (override)); }; diff --git a/src/Policy.h b/src/Policy.h index 5339f32..0be1eb2 100644 --- a/src/Policy.h +++ b/src/Policy.h @@ -77,8 +77,7 @@ class Manager { virtual Scheduled getScheduledPolicy(); virtual Unscheduled getUnscheduledPolicy(const IpAddress destination, const uint32_t messageLength); - virtual void signalNewMessage(const IpAddress source, - uint8_t policyVersion, + virtual void signalNewMessage(const IpAddress source, uint8_t policyVersion, uint32_t messageLength); virtual void poll(); diff --git a/src/PolicyTest.cc b/src/PolicyTest.cc index 4f23806..44b8829 100644 --- a/src/PolicyTest.cc +++ b/src/PolicyTest.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above diff --git a/src/Protocol.h b/src/Protocol.h index ef2c723..55a34ac 100644 --- a/src/Protocol.h +++ b/src/Protocol.h @@ -113,8 +113,9 @@ enum Opcode { * protocol version before interpreting the rest of the packet. */ struct HeaderPrefix { - uint16_t sport, dport;///< Transport layer (L4) source and destination ports - ///< in network byte order; only used by DataHeader. + uint16_t sport, + dport; ///< Transport layer (L4) source and destination ports + ///< in network byte order; only used by DataHeader. uint8_t version; ///< The version of the protocol being used by this ///< packet. diff --git a/src/Receiver.cc b/src/Receiver.cc index d007087..c850d07 100644 --- a/src/Receiver.cc +++ b/src/Receiver.cc @@ -103,17 +103,15 @@ Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) { SpinLock::Lock lock_allocator(messageAllocator.mutex); SocketAddress srcAddress = { - .ip = sourceIp, - .port = be16toh(header->common.prefix.sport) - }; + .ip = sourceIp, .port = be16toh(header->common.prefix.sport)}; message = messageAllocator.pool.construct( - this, driver, dataHeaderLength, messageLength, id, - srcAddress, numUnscheduledPackets); + this, driver, dataHeaderLength, messageLength, id, srcAddress, + numUnscheduledPackets); } bucket->messages.push_back(&message->bucketNode); - policyManager->signalNewMessage(message->source.ip, - header->policyVersion, header->totalLength); + policyManager->signalNewMessage( + message->source.ip, header->policyVersion, header->totalLength); if (message->scheduled) { // Message needs to be scheduled. @@ -244,8 +242,8 @@ Receiver::handlePingPacket(Driver::Packet* packet, IpAddress sourceIp) // We are here because we have no knowledge of the message the Sender is // asking about. Reply UNKNOWN so the Sender can react accordingly. Perf::counters.tx_unknown_pkts.add(1); - ControlPacket::send( - driver, sourceIp, id); + ControlPacket::send(driver, sourceIp, + id); } driver->releasePackets(&packet, 1); } diff --git a/src/ReceiverTest.cc b/src/ReceiverTest.cc index da9e0bc..bfccc39 100644 --- a/src/ReceiverTest.cc +++ b/src/ReceiverTest.cc @@ -38,13 +38,17 @@ using ::testing::Pointee; using ::testing::Return; /// Helper macro to construct an IpAddress from a numeric number. -#define IP(x) IpAddress{x} +#define IP(x) \ + IpAddress \ + { \ + x \ + } class ReceiverTest : public ::testing::Test { public: ReceiverTest() : mockDriver() - , mockPacket {&payload} + , mockPacket{&payload} , mockPolicyManager(&mockDriver) , payload() , receiver() @@ -266,7 +270,7 @@ TEST_F(ReceiverTest, handlePingPacket_basic) bucket->messages.push_back(&message->bucketNode); char pingPayload[1028]; - Homa::Mock::MockDriver::MockPacket pingPacket {pingPayload}; + Homa::Mock::MockDriver::MockPacket pingPacket{pingPayload}; IpAddress sourceIp = mockAddress; Protocol::Packet::PingHeader* pingHeader = (Protocol::Packet::PingHeader*)pingPacket.payload; @@ -298,7 +302,7 @@ TEST_F(ReceiverTest, handlePingPacket_unknown) Protocol::MessageId id(42, 32); char pingPayload[1028]; - Homa::Mock::MockDriver::MockPacket pingPacket {pingPayload}; + Homa::Mock::MockDriver::MockPacket pingPacket{pingPayload}; IpAddress mockAddress{22}; Protocol::Packet::PingHeader* pingHeader = (Protocol::Packet::PingHeader*)pingPacket.payload; @@ -352,7 +356,8 @@ TEST_F(ReceiverTest, poll) TEST_F(ReceiverTest, checkTimeouts) { Receiver::Message message(receiver, &mockDriver, 0, 0, - Protocol::MessageId(0, 0), SocketAddress{0, 60001}, 0); + Protocol::MessageId(0, 0), + SocketAddress{0, 60001}, 0); Receiver::MessageBucket* bucket = receiver->messageBuckets.buckets.at(0); bucket->resendTimeouts.setTimeout(&message.resendTimeout); bucket->messageTimeouts.setTimeout(&message.messageTimeout); @@ -420,8 +425,9 @@ TEST_F(ReceiverTest, Message_acknowledge) receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket( - Eq(&mockPacket), Eq(message->source.ip), _)).Times(1); + EXPECT_CALL(mockDriver, + sendPacket(Eq(&mockPacket), Eq(message->source.ip), _)) + .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); @@ -456,8 +462,9 @@ TEST_F(ReceiverTest, Message_fail) receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket( - Eq(&mockPacket), Eq(message->source.ip), _)).Times(1); + EXPECT_CALL(mockDriver, + sendPacket(Eq(&mockPacket), Eq(message->source.ip), _)) + .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); @@ -477,8 +484,8 @@ TEST_F(ReceiverTest, Message_get_basic) Receiver::Message* message = receiver->messageAllocator.pool.construct( receiver, &mockDriver, 24, 24 + 2007, id, SocketAddress{22, 60001}, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; - Homa::Mock::MockDriver::MockPacket packet1 {buf + 2048}; + Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1{buf + 2048}; char source[] = "Hello, world!"; message->setPacket(0, &packet0); @@ -504,8 +511,8 @@ TEST_F(ReceiverTest, Message_get_offsetTooLarge) Receiver::Message* message = receiver->messageAllocator.pool.construct( receiver, &mockDriver, 24, 24 + 2007, id, SocketAddress{22, 60001}, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; - Homa::Mock::MockDriver::MockPacket packet1 {buf + 2048}; + Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1{buf + 2048}; message->setPacket(0, &packet0); message->setPacket(1, &packet1); @@ -530,8 +537,8 @@ TEST_F(ReceiverTest, Message_get_missingPacket) Receiver::Message* message = receiver->messageAllocator.pool.construct( receiver, &mockDriver, 24, 24 + 2007, id, SocketAddress{22, 60001}, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; - Homa::Mock::MockDriver::MockPacket packet1 {buf + 2048}; + Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1{buf + 2048}; char source[] = "Hello,"; message->setPacket(0, &packet0); @@ -806,16 +813,18 @@ TEST_F(ReceiverTest, checkResendTimeouts_basic) char buf1[1024]; char buf2[1024]; - Homa::Mock::MockDriver::MockPacket mockResendPacket1 {buf1}; - Homa::Mock::MockDriver::MockPacket mockResendPacket2 {buf2}; + Homa::Mock::MockDriver::MockPacket mockResendPacket1{buf1}; + Homa::Mock::MockDriver::MockPacket mockResendPacket2{buf2}; EXPECT_CALL(mockDriver, allocPacket()) .WillOnce(Return(&mockResendPacket1)) .WillOnce(Return(&mockResendPacket2)); EXPECT_CALL(mockDriver, sendPacket(Eq(&mockResendPacket1), - Eq(message[0]->source.ip), _)).Times(1); + Eq(message[0]->source.ip), _)) + .Times(1); EXPECT_CALL(mockDriver, sendPacket(Eq(&mockResendPacket2), - Eq(message[0]->source.ip), _)).Times(1); + Eq(message[0]->source.ip), _)) + .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockResendPacket1), Eq(1))) .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockResendPacket2), Eq(1))) diff --git a/src/Sender.cc b/src/Sender.cc index b993b6b..ea75bf4 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -391,7 +391,7 @@ Sender::handleUnknownPacket(Driver::Packet* packet) Perf::counters.tx_data_pkts.add(1); Perf::counters.tx_bytes.add(dataPacket->length); driver->sendPacket(dataPacket, message->destination.ip, - policy.priority); + policy.priority); message->state.store(OutMessage::Status::SENT); } else { // Otherwise, queue the message to be sent in SRPT order. @@ -401,7 +401,7 @@ Sender::handleUnknownPacket(Driver::Packet* packet) // was first queued. assert(info->id == message->id); assert(!memcmp(&info->destination, &message->destination, - sizeof(info->destination))); + sizeof(info->destination))); assert(info->packets == message); // Some values need to be updated info->unsentBytes = message->messageLength; diff --git a/src/SenderTest.cc b/src/SenderTest.cc index dfb216e..8085c82 100644 --- a/src/SenderTest.cc +++ b/src/SenderTest.cc @@ -36,7 +36,7 @@ class SenderTest : public ::testing::Test { public: SenderTest() : mockDriver() - , mockPacket {&payload} + , mockPacket{&payload} , mockPolicyManager(&mockDriver) , sender() , savedLogPolicy(Debug::getLogPolicy()) @@ -314,7 +314,7 @@ TEST_F(SenderTest, handleResendPacket_basic) std::vector packets; std::vector priorities; for (int i = 0; i < 10; ++i) { - packets.push_back(new Homa::Mock::MockDriver::MockPacket {payload}); + packets.push_back(new Homa::Mock::MockDriver::MockPacket{payload}); priorities.push_back(0); setMessagePacket(message, i, packets[i]); } @@ -333,10 +333,12 @@ TEST_F(SenderTest, handleResendPacket_basic) resendHdr->priority = 4; EXPECT_CALL(mockPolicyManager, getResendPriority).WillOnce(Return(7)); - EXPECT_CALL(mockDriver, sendPacket(Eq(packets[3]), _, _)).WillOnce( - [&priorities] (auto _1, auto _2, int p) { priorities[3] = p; }); - EXPECT_CALL(mockDriver, sendPacket(Eq(packets[4]), _, _)).WillOnce( - [&priorities] (auto _1, auto _2, int p) { priorities[4] = p; }); + EXPECT_CALL(mockDriver, sendPacket(Eq(packets[3]), _, _)) + .WillOnce( + [&priorities](auto _1, auto _2, int p) { priorities[3] = p; }); + EXPECT_CALL(mockDriver, sendPacket(Eq(packets[4]), _, _)) + .WillOnce( + [&priorities](auto _1, auto _2, int p) { priorities[4] = p; }); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); @@ -381,7 +383,7 @@ TEST_F(SenderTest, handleResendPacket_badRequest_singlePacketMessage) dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); Homa::Mock::MockDriver::MockPacket* packet = - new Homa::Mock::MockDriver::MockPacket {payload}; + new Homa::Mock::MockDriver::MockPacket{payload}; setMessagePacket(message, 0, packet); Protocol::Packet::ResendHeader* resendHdr = @@ -421,7 +423,7 @@ TEST_F(SenderTest, handleResendPacket_badRequest_outOfRange) dynamic_cast(sender->allocMessage(0)); std::vector packets; for (int i = 0; i < 10; ++i) { - packets.push_back(new Homa::Mock::MockDriver::MockPacket {payload}); + packets.push_back(new Homa::Mock::MockDriver::MockPacket{payload}); setMessagePacket(message, i, packets[i]); } SenderTest::addMessage(sender, id, message, true, 5); @@ -470,7 +472,7 @@ TEST_F(SenderTest, handleResendPacket_eagerResend) Sender::Message* message = dynamic_cast(sender->allocMessage(0)); char data[1028]; - Homa::Mock::MockDriver::MockPacket dataPacket {data}; + Homa::Mock::MockDriver::MockPacket dataPacket{data}; for (int i = 0; i < 10; ++i) { setMessagePacket(message, i, &dataPacket); } @@ -488,7 +490,7 @@ TEST_F(SenderTest, handleResendPacket_eagerResend) // Expect the BUSY control packet. char busy[1028]; - Homa::Mock::MockDriver::MockPacket busyPacket {busy}; + Homa::Mock::MockDriver::MockPacket busyPacket{busy}; EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&busyPacket)); EXPECT_CALL(mockDriver, sendPacket(Eq(&busyPacket), _, _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&busyPacket), Eq(1))) @@ -648,7 +650,7 @@ TEST_F(SenderTest, handleUnknownPacket_basic) char payload[5][1028]; for (int i = 0; i < 5; ++i) { Homa::Mock::MockDriver::MockPacket* packet = - new Homa::Mock::MockDriver::MockPacket {payload[i]}; + new Homa::Mock::MockDriver::MockPacket{payload[i]}; Protocol::Packet::DataHeader* header = static_cast(packet->payload); header->policyVersion = policyOld.version; @@ -716,7 +718,7 @@ TEST_F(SenderTest, handleUnknownPacket_singlePacketMessage) Sender::Message* message = dynamic_cast(sender->allocMessage(0)); - Homa::Mock::MockDriver::MockPacket dataPacket {payload}; + Homa::Mock::MockDriver::MockPacket dataPacket{payload}; Protocol::Packet::DataHeader* dataHeader = static_cast(dataPacket.payload); dataHeader->policyVersion = policyOld.version; @@ -768,7 +770,7 @@ TEST_F(SenderTest, handleUnknownPacket_NO_RETRY) char payload[5][1028]; for (int i = 0; i < 5; ++i) { Homa::Mock::MockDriver::MockPacket* packet = - new Homa::Mock::MockDriver::MockPacket {payload[i]}; + new Homa::Mock::MockDriver::MockPacket{payload[i]}; packets.push_back(packet); setMessagePacket(message, i, packet); } @@ -1092,8 +1094,8 @@ TEST_F(SenderTest, Message_append_basic) .WillByDefault(Return(MAX_RAW_PACKET_LENGTH)); Sender::Message msg(sender, 0); char buf[2 * MAX_RAW_PACKET_LENGTH]; - Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; - Homa::Mock::MockDriver::MockPacket packet1 {buf + MAX_RAW_PACKET_LENGTH}; + Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1{buf + MAX_RAW_PACKET_LENGTH}; const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1132,8 +1134,8 @@ TEST_F(SenderTest, Message_append_truncated) .WillByDefault(Return(MAX_RAW_PACKET_LENGTH)); Sender::Message msg(sender, 0); char buf[2 * MAX_RAW_PACKET_LENGTH]; - Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; - Homa::Mock::MockDriver::MockPacket packet1 {buf + MAX_RAW_PACKET_LENGTH}; + Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1{buf + MAX_RAW_PACKET_LENGTH}; const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1189,8 +1191,8 @@ TEST_F(SenderTest, Message_prepend) ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(2048)); Sender::Message msg(sender, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; - Homa::Mock::MockDriver::MockPacket packet1 {buf + 2048}; + Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1{buf + 2048}; const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1224,8 +1226,8 @@ TEST_F(SenderTest, Message_reserve) { Sender::Message msg(sender, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; - Homa::Mock::MockDriver::MockPacket packet1 {buf + 2048}; + Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1{buf + 2048}; const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1279,8 +1281,8 @@ TEST_F(SenderTest, Message_getOrAllocPacket) // TODO(cstlee): cleanup Sender::Message msg(sender, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; - Homa::Mock::MockDriver::MockPacket packet1 {buf + 2048}; + Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1{buf + 2048}; EXPECT_FALSE(msg.occupied.test(0)); EXPECT_EQ(0U, msg.numPackets); @@ -1353,7 +1355,8 @@ TEST_F(SenderTest, sendMessage_basic) .WillOnce(Return(policy)); int mockPriority = 0; EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), Eq(destination.ip), _)) - .WillOnce([&mockPriority] (auto _1, auto _2, int p){mockPriority = p;}); + .WillOnce( + [&mockPriority](auto _1, auto _2, int p) { mockPriority = p; }); sender->sendMessage(message, destination, Sender::Message::Options::NO_RETRY); @@ -1391,8 +1394,8 @@ TEST_F(SenderTest, sendMessage_multipacket) { char payload0[1027]; char payload1[1027]; - Homa::Mock::MockDriver::MockPacket packet0 {payload0}; - Homa::Mock::MockDriver::MockPacket packet1 {payload1}; + Homa::Mock::MockDriver::MockPacket packet0{payload0}; + Homa::Mock::MockDriver::MockPacket packet1{payload1}; Protocol::MessageId id = {sender->transportId, sender->nextMessageSequenceNumber}; Sender::Message* message = @@ -1663,7 +1666,7 @@ TEST_F(SenderTest, trySend_basic) const uint32_t PACKET_DATA_SIZE = PACKET_SIZE - message->TRANSPORT_HEADER_LENGTH; for (int i = 0; i < 5; ++i) { - packet[i] = new Homa::Mock::MockDriver::MockPacket {payload}; + packet[i] = new Homa::Mock::MockDriver::MockPacket{payload}; packet[i]->length = PACKET_SIZE; setMessagePacket(message, i, packet[i]); info->unsentBytes += PACKET_DATA_SIZE; @@ -1745,7 +1748,7 @@ TEST_F(SenderTest, trySend_multipleMessages) message[i] = dynamic_cast(sender->allocMessage(0)); info[i] = &message[i]->queuedMessageInfo; SenderTest::addMessage(sender, id, message[i], true, 1); - packet[i] = new Homa::Mock::MockDriver::MockPacket {payload}; + packet[i] = new Homa::Mock::MockDriver::MockPacket{payload}; packet[i]->length = sender->driver->getMaxPayloadSize() / 4; setMessagePacket(message[i], 0, packet[i]); info[i]->unsentBytes += diff --git a/src/TransportImplTest.cc b/src/TransportImplTest.cc index c69a36a..a0f66c6 100644 --- a/src/TransportImplTest.cc +++ b/src/TransportImplTest.cc @@ -102,56 +102,56 @@ TEST_F(TransportImplTest, processPackets) Homa::Driver::Packet* packets[8]; // Set DATA packet - Homa::Mock::MockDriver::MockPacket dataPacket {payload[0], 1024}; + Homa::Mock::MockDriver::MockPacket dataPacket{payload[0], 1024}; static_cast(dataPacket.payload) ->common.opcode = Protocol::Packet::DATA; packets[0] = &dataPacket; EXPECT_CALL(*mockReceiver, handleDataPacket(Eq(&dataPacket), _)); // Set GRANT packet - Homa::Mock::MockDriver::MockPacket grantPacket {payload[1], 1024}; + Homa::Mock::MockDriver::MockPacket grantPacket{payload[1], 1024}; static_cast(grantPacket.payload) ->common.opcode = Protocol::Packet::GRANT; packets[1] = &grantPacket; EXPECT_CALL(*mockSender, handleGrantPacket(Eq(&grantPacket))); // Set DONE packet - Homa::Mock::MockDriver::MockPacket donePacket {payload[2], 1024}; + Homa::Mock::MockDriver::MockPacket donePacket{payload[2], 1024}; static_cast(donePacket.payload) ->common.opcode = Protocol::Packet::DONE; packets[2] = &donePacket; EXPECT_CALL(*mockSender, handleDonePacket(Eq(&donePacket))); // Set RESEND packet - Homa::Mock::MockDriver::MockPacket resendPacket {payload[3], 1024}; + Homa::Mock::MockDriver::MockPacket resendPacket{payload[3], 1024}; static_cast(resendPacket.payload) ->common.opcode = Protocol::Packet::RESEND; packets[3] = &resendPacket; EXPECT_CALL(*mockSender, handleResendPacket(Eq(&resendPacket))); // Set BUSY packet - Homa::Mock::MockDriver::MockPacket busyPacket {payload[4], 1024}; + Homa::Mock::MockDriver::MockPacket busyPacket{payload[4], 1024}; static_cast(busyPacket.payload) ->common.opcode = Protocol::Packet::BUSY; packets[4] = &busyPacket; EXPECT_CALL(*mockReceiver, handleBusyPacket(Eq(&busyPacket))); // Set PING packet - Homa::Mock::MockDriver::MockPacket pingPacket {payload[5], 1024}; + Homa::Mock::MockDriver::MockPacket pingPacket{payload[5], 1024}; static_cast(pingPacket.payload) ->common.opcode = Protocol::Packet::PING; packets[5] = &pingPacket; EXPECT_CALL(*mockReceiver, handlePingPacket(Eq(&pingPacket), _)); // Set UNKNOWN packet - Homa::Mock::MockDriver::MockPacket unknownPacket {payload[6], 1024}; + Homa::Mock::MockDriver::MockPacket unknownPacket{payload[6], 1024}; static_cast(unknownPacket.payload) ->common.opcode = Protocol::Packet::UNKNOWN; packets[6] = &unknownPacket; EXPECT_CALL(*mockSender, handleUnknownPacket(Eq(&unknownPacket))); // Set ERROR packet - Homa::Mock::MockDriver::MockPacket errorPacket {payload[7], 1024}; + Homa::Mock::MockDriver::MockPacket errorPacket{payload[7], 1024}; static_cast(errorPacket.payload) ->common.opcode = Protocol::Packet::ERROR; packets[7] = &errorPacket; diff --git a/test/system_test.cc b/test/system_test.cc index 266d842..88b3814 100644 --- a/test/system_test.cc +++ b/test/system_test.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above From a75e59c34ba61145e0aa6d3d58ef69f90dfef299 Mon Sep 17 00:00:00 2001 From: Collin Lee Date: Tue, 11 Aug 2020 21:47:17 -0700 Subject: [PATCH 06/33] Include additional Perf microbenchmarks --- test/Perf.cc | 263 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 261 insertions(+), 2 deletions(-) diff --git a/test/Perf.cc b/test/Perf.cc index 2b6f04c..7e9327f 100644 --- a/test/Perf.cc +++ b/test/Perf.cc @@ -13,22 +13,25 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ +#include #include +#include #include #include #include #include +#include #include #include +#include #include #include #include #include "Cycles.h" -#include "docopt.h" - #include "Homa/Drivers/Util/QueueEstimator.h" #include "Intrusive.h" +#include "docopt.h" static const char USAGE[] = R"(Performance Nano-Benchmark @@ -62,6 +65,135 @@ struct TestInfo { // should be less than 72 characters long). }; +TestInfo atomicLoadTestInfo = { + "atomicLoad", "Read an std::atomic", + R"(Measure the cost of reading an std::atomic value.)"}; +double +atomicLoadTest() +{ + int count = 1000000; + uint64_t temp; + std::atomic val[count]; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + temp = val[i].load(); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / count; +} + +TestInfo atomicStoreTestInfo = { + "atomicStore", "Write an std::atomic", + R"(Measure the cost of writing an std::atomic value.)"}; +double +atomicStoreTest() +{ + int count = 1000000; + uint64_t temp = std::rand(); + std::atomic val[count]; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + val[i].store(temp); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / count; +} + +TestInfo atomicStoreRelaxedTestInfo = { + "atomicStoreRelaxed", "Write an std::atomic (std::memory_order_relaxed)", + R"(Measure the cost of a relaxed atomic write.)"}; +double +atomicStoreRelaxedTest() +{ + int count = 1000000; + uint64_t temp = std::rand(); + std::atomic val[count]; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + val[i].store(temp, std::memory_order_relaxed); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / count; +} + +TestInfo atomicIncTestInfo = { + "atomicInc", "Increment an std::atomic", + R"(Measure the cost of incrementing an std::atomic value.)"}; +double +atomicIncTest() +{ + int count = 1000000; + uint64_t temp = std::rand() % 100; + std::atomic val[count]; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + val[i].fetch_add(temp); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / count; +} + +TestInfo atomicIncRelaxedTestInfo = { + "atomicIncRelaxed", "Increment an std::atomic (std::memory_order_relaxed)", + R"(Measure the cost of a relaxed atomic incrementing.)"}; +double +atomicIncRelaxedTest() +{ + int count = 1000000; + uint64_t temp = std::rand() % 100; + std::atomic val[count]; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + val[i].fetch_add(temp, std::memory_order_relaxed); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / count; +} + +TestInfo intReadWriteTestInfo = { + "intReadWrite", "Read and write a uint64_t", + R"(Measure the cost the baseline read/write.)"}; +double +intReadWriteTest() +{ + int count = 1000000; + uint64_t temp = std::rand(); + uint64_t val[count]; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + val[i] = temp; + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / count; +} + +TestInfo branchTestInfo = { + "branchTest", "If-else statement", + R"(The cost of choosing a branch in an if-else statement)"}; +double +branchTest() +{ + int count = 1000000; + uint64_t temp = std::rand(); + uint64_t a[0xFF + 1]; + uint64_t b[0xFF + 1]; + for (int i = 0; i < 0xFF + 1; i++) { + a[i] = std::rand(); + b[i] = std::rand(); + } + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + int index = i & 0xFF; + if (a[index] < b[index]) { + b[index] = a[index]; + } else { + a[index] = b[index]; + } + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / count; +} + TestInfo listSearchTestInfo = { "listSearch", "Linear search an intrusive list", R"(Measure the cost (per entry) of searching through an intrusive list)"}; @@ -194,6 +326,105 @@ mapNullInsertTest() return PerfUtils::Cycles::toSeconds(stop - start) / run; } +TestInfo dequeConstructInfo = { + "dequeConstruct", "Construct an std::deque", + R"(Measure the cost of constructing (and destructing) an std::deque.)"}; +double +dequeConstruct() +{ + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + std::deque deque; + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / (count); +} + +TestInfo dequePushPopTestInfo = { + "dequePushPop", "std::deque operations", + R"(Measure the cost of pushing/popping an element to/from an std::deque.)"}; +double +dequePushPopTest() +{ + std::deque deque; + for (int i = 0; i < 10000; ++i) { + deque.push_back(std::rand()); + } + + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + uint64_t temp = deque.front(); + deque.pop_front(); + deque.push_back(temp); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / (2 * count); +} + +TestInfo listConstructInfo = { + "listConstruct", "Construct an std::list", + R"(Measure the cost of constructing (and destructing) an std::list.)"}; +double +listConstruct() +{ + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + std::list list; + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / (count); +} + +TestInfo listPushPopTestInfo = { + "listPushPop", "std::list operations", + R"(Measure the cost of pushing/popping an element to/from an std::list.)"}; +double +listPushPopTest() +{ + std::list list; + for (int i = 0; i < 10000; ++i) { + list.push_back(std::rand()); + } + + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + uint64_t temp = list.front(); + list.pop_front(); + list.push_back(temp); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / (2 * count); +} + +TestInfo heapTestInfo = { + "heap", "std heap operations", + R"(Measure the cost of pushing/popping an element to/from an std heap.)"}; +double +heapTest() +{ + std::vector heap; + for (uint64_t i = 0; i < 10000; ++i) { + heap.push_back(i); + } + std::make_heap(heap.begin(), heap.end(), std::greater<>{}); + + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + uint64_t temp = heap.front(); + std::pop_heap(heap.begin(), heap.end(), std::greater<>{}); + heap.pop_back(); + heap.push_back(temp + 5000); + std::push_heap(heap.begin(), heap.end(), std::greater<>{}); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / (2 * count); +} + TestInfo queueEstimatorTestInfo = { "queueEstimator", "Update a QueueEstimator", R"(Measure the cost of updating a Homa::Drivers::Util::QueueEstimator.)"}; @@ -244,6 +475,21 @@ rdhrcTest() return PerfUtils::Cycles::toSeconds(stop - start) / count; } +TestInfo rdcscTestInfo = { + "rdcsc", "Read std::chrono::steady_clock", + R"(Measure the cost of reading the std::chrono::steady_clock.)"}; +double +rdcscTest() +{ + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + auto timestamp = std::chrono::steady_clock::now(); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / count; +} + // The following struct and table define each performance test in terms of // function that implements the test and collection of string information about // the test like the test's string name. @@ -255,13 +501,26 @@ struct TestCase { // including the test's string name. }; TestCase tests[] = { + {atomicLoadTest, &atomicLoadTestInfo}, + {atomicStoreTest, &atomicStoreTestInfo}, + {atomicStoreRelaxedTest, &atomicStoreRelaxedTestInfo}, + {atomicIncTest, &atomicIncTestInfo}, + {atomicIncRelaxedTest, &atomicIncRelaxedTestInfo}, + {branchTest, &branchTestInfo}, + {intReadWriteTest, &intReadWriteTestInfo}, {listSearchTest, &listSearchTestInfo}, {mapFindTest, &mapFindTestInfo}, {mapLookupTest, &mapLookupTestInfo}, {mapNullInsertTest, &mapNullInsertTestInfo}, + {dequeConstruct, &dequeConstructInfo}, + {dequePushPopTest, &dequePushPopTestInfo}, + {listConstruct, &listConstructInfo}, + {listPushPopTest, &listPushPopTestInfo}, + {heapTest, &heapTestInfo}, {queueEstimatorTest, &queueEstimatorTestInfo}, {rdtscTest, &rdtscTestInfo}, {rdhrcTest, &rdhrcTestInfo}, + {rdcscTest, &rdcscTestInfo}, }; /** From 1c5e460ddcc31ae52c33ff7bf805995c12d69769 Mon Sep 17 00:00:00 2001 From: Collin Lee Date: Sat, 15 Aug 2020 15:31:35 -0700 Subject: [PATCH 07/33] Introduce incremental timeout checking Timeouts are now checked incrementally with one bucket checked every poll iteration. Previously, timeout checking was more concentrated; checking was triggered less frequently but all buckets would be checked once triggered. The previous design allows the poll iteration to complete quickly when no timeout checking is triggered but significantly increases poll() execution times when checking needs to occur. This causes large, periodic latency spikes when using the transport. In the new design sightly increases the minimum poll() execution times but allows the work of checking timeouts to be more evenly distributed over time and thus reduces latency spikes. --- include/Homa/Util.h | 10 ++ src/Mock/MockReceiver.h | 1 - src/Mock/MockSender.h | 1 - src/Receiver.cc | 311 +++++++++++++++++++-------------------- src/Receiver.h | 16 +- src/ReceiverTest.cc | 71 +++------ src/Sender.cc | 201 ++++++++++++------------- src/Sender.h | 16 +- src/SenderTest.cc | 91 ++++-------- src/Timeout.h | 104 +++++++++++-- src/TimeoutTest.cc | 43 +++--- src/TransportImpl.cc | 10 -- src/TransportImplTest.cc | 12 -- src/UtilTest.cc | 7 +- 14 files changed, 450 insertions(+), 444 deletions(-) diff --git a/include/Homa/Util.h b/include/Homa/Util.h index 4f17acc..1c1b75a 100644 --- a/include/Homa/Util.h +++ b/include/Homa/Util.h @@ -60,6 +60,16 @@ downCast(const Large& large) std::string demangle(const char* name); std::string hexDump(const void* buf, uint64_t bytes); +/** + * Return true if the given number is a power of 2; false, otherwise. + */ +template +constexpr bool +isPowerOfTwo(num_type n) +{ + return (n > 0) && ((n & (n - 1)) == 0); +} + /** * This class is used to temporarily release lock in a safe fashion. Creating * an object of this class will unlock its associated mutex; when the object diff --git a/src/Mock/MockReceiver.h b/src/Mock/MockReceiver.h index 61c21ce..fa9ceba 100644 --- a/src/Mock/MockReceiver.h +++ b/src/Mock/MockReceiver.h @@ -43,7 +43,6 @@ class MockReceiver : public Core::Receiver { (Driver::Packet * packet, IpAddress sourceIp), (override)); MOCK_METHOD(Homa::InMessage*, receiveMessage, (), (override)); MOCK_METHOD(void, poll, (), (override)); - MOCK_METHOD(uint64_t, checkTimeouts, (), (override)); }; } // namespace Mock diff --git a/src/Mock/MockSender.h b/src/Mock/MockSender.h index cb29c90..3f05128 100644 --- a/src/Mock/MockSender.h +++ b/src/Mock/MockSender.h @@ -46,7 +46,6 @@ class MockSender : public Core::Sender { (override)); MOCK_METHOD(void, handleErrorPacket, (Driver::Packet * packet), (override)); MOCK_METHOD(void, poll, (), (override)); - MOCK_METHOD(uint64_t, checkTimeouts, (), (override)); }; } // namespace Mock diff --git a/src/Receiver.cc b/src/Receiver.cc index c850d07..18441c9 100644 --- a/src/Receiver.cc +++ b/src/Receiver.cc @@ -18,7 +18,6 @@ #include #include "Perf.h" -#include "Util.h" namespace Homa { namespace Core { @@ -46,6 +45,7 @@ Receiver::Receiver(Driver* driver, Policy::Manager* policyManager, , scheduledPeers() , receivedMessages() , granting() + , nextBucketIndex(0) , messageAllocator() {} @@ -280,30 +280,7 @@ void Receiver::poll() { trySendGrants(); -} - -/** - * Process any Receiver timeouts that have expired. - * - * This method must be called periodically to ensure timely handling of - * expired timeouts. - * - * @return - * The rdtsc cycle time when this method should be called again. - */ -uint64_t -Receiver::checkTimeouts() -{ - uint64_t nextTimeout; - - // Ping Timeout - nextTimeout = checkResendTimeouts(); - - // Message Timeout - uint64_t messageTimeout = checkMessageTimeouts(); - nextTimeout = nextTimeout < messageTimeout ? nextTimeout : messageTimeout; - - return nextTimeout; + checkTimeouts(); } /** @@ -528,71 +505,68 @@ Receiver::dropMessage(Receiver::Message* message) * Process any inbound messages that have timed out due to lack of activity from * the Sender. * + * * Pulled out of checkTimeouts() for ease of testing. * - * @return - * The rdtsc cycle time when this method should be called again. + * @param now + * The rdtsc cycle that should be considered the "current" time. + * @param bucket + * The bucket whose message timeouts should be checked. */ -uint64_t -Receiver::checkMessageTimeouts() +void +Receiver::checkMessageTimeouts(uint64_t now, MessageBucket* bucket) { - uint64_t globalNextTimeout = UINT64_MAX; - for (int i = 0; i < MessageBucketMap::NUM_BUCKETS; ++i) { - MessageBucket* bucket = messageBuckets.buckets.at(i); - uint64_t nextTimeout = 0; - while (true) { - SpinLock::Lock lock_bucket(bucket->mutex); - - // No remaining timeouts. - if (bucket->messageTimeouts.list.empty()) { - nextTimeout = PerfUtils::Cycles::rdtsc() + - bucket->messageTimeouts.timeoutIntervalCycles; - break; - } + if (!bucket->messageTimeouts.anyElapsed(now)) { + return; + } - Message* message = &bucket->messageTimeouts.list.front(); + while (true) { + SpinLock::Lock lock_bucket(bucket->mutex); - // No remaining expired timeouts. - if (!message->messageTimeout.hasElapsed()) { - nextTimeout = message->messageTimeout.expirationCycleTime; - break; - } + // No remaining timeouts. + if (bucket->messageTimeouts.empty()) { + break; + } - // Found expired timeout. + Message* message = &bucket->messageTimeouts.front(); - // Cancel timeouts - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->resendTimeouts.cancelTimeout(&message->resendTimeout); + // No remaining expired timeouts. + if (!message->messageTimeout.hasElapsed(now)) { + break; + } - if (message->state == Message::State::IN_PROGRESS) { - // Message timed out before being fully received; drop the - // message. + // Found expired timeout. - // Unschedule the message - if (message->scheduled) { - // Unschedule the message if it is still scheduled (i.e. - // still linked to a scheduled peer). - SpinLock::Lock lock_scheduler(schedulerMutex); - ScheduledMessageInfo* info = &message->scheduledMessageInfo; - if (info->peer != nullptr) { - unschedule(message, lock_scheduler); - } - } + // Cancel timeouts + bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); + bucket->resendTimeouts.cancelTimeout(&message->resendTimeout); - bucket->messages.remove(&message->bucketNode); - { - SpinLock::Lock lock_allocator(messageAllocator.mutex); - messageAllocator.pool.destroy(message); + if (message->state == Message::State::IN_PROGRESS) { + // Message timed out before being fully received; drop the + // message. + + // Unschedule the message + if (message->scheduled) { + // Unschedule the message if it is still scheduled (i.e. + // still linked to a scheduled peer). + SpinLock::Lock lock_scheduler(schedulerMutex); + ScheduledMessageInfo* info = &message->scheduledMessageInfo; + if (info->peer != nullptr) { + unschedule(message, lock_scheduler); } - } else { - // Message timed out but we already made it available to the - // Transport; let the Transport know. - message->state.store(Message::State::DROPPED); } + + bucket->messages.remove(&message->bucketNode); + { + SpinLock::Lock lock_allocator(messageAllocator.mutex); + messageAllocator.pool.destroy(message); + } + } else { + // Message timed out but we already made it available to the + // Transport; let the Transport know. + message->state.store(Message::State::DROPPED); } - globalNextTimeout = std::min(globalNextTimeout, nextTimeout); } - return globalNextTimeout; } /** @@ -600,107 +574,118 @@ Receiver::checkMessageTimeouts() * * Pulled out of checkTimeouts() for ease of testing. * - * @return - * The rdtsc cycle time when this method should be called again. + * @param now + * The rdtsc cycle that should be considered the "current" time. + * @param bucket + * The bucket whose resend timeouts should be checked. */ -uint64_t -Receiver::checkResendTimeouts() +void +Receiver::checkResendTimeouts(uint64_t now, MessageBucket* bucket) { - uint64_t globalNextTimeout = UINT64_MAX; - for (int i = 0; i < MessageBucketMap::NUM_BUCKETS; ++i) { - MessageBucket* bucket = messageBuckets.buckets.at(i); - uint64_t nextTimeout = 0; - while (true) { - SpinLock::Lock lock_bucket(bucket->mutex); - - // No remaining timeouts. - if (bucket->resendTimeouts.list.empty()) { - nextTimeout = PerfUtils::Cycles::rdtsc() + - bucket->resendTimeouts.timeoutIntervalCycles; - break; - } + if (!bucket->resendTimeouts.anyElapsed(now)) { + return; + } - Message* message = &bucket->resendTimeouts.list.front(); + while (true) { + SpinLock::Lock lock_bucket(bucket->mutex); - // No remaining expired timeouts. - if (!message->resendTimeout.hasElapsed()) { - nextTimeout = message->resendTimeout.expirationCycleTime; - break; - } + // No remaining timeouts. + if (bucket->resendTimeouts.empty()) { + break; + } - // Found expired timeout. - assert(message->state == Message::State::IN_PROGRESS); - bucket->resendTimeouts.setTimeout(&message->resendTimeout); + Message* message = &bucket->resendTimeouts.front(); - // This Receiver expected to have heard from the Sender within the - // last timeout period but it didn't. Request a resend of granted - // packets in case DATA packets got lost. - int index = 0; - int num = 0; - int grantIndexLimit = message->numUnscheduledPackets; + // No remaining expired timeouts. + if (!message->resendTimeout.hasElapsed(now)) { + break; + } - if (message->scheduled) { - SpinLock::Lock lock_scheduler(schedulerMutex); - ScheduledMessageInfo* info = &message->scheduledMessageInfo; - int receivedBytes = info->messageLength - info->bytesRemaining; - if (receivedBytes >= info->bytesGranted) { - // Sender is blocked on this Receiver; all granted packets - // have already been received. No need to check for resend. - continue; - } else if (grantIndexLimit * message->PACKET_DATA_LENGTH < - info->bytesGranted) { - grantIndexLimit = - (info->bytesGranted + message->PACKET_DATA_LENGTH - 1) / - message->PACKET_DATA_LENGTH; - } + // Found expired timeout. + assert(message->state == Message::State::IN_PROGRESS); + bucket->resendTimeouts.setTimeout(&message->resendTimeout); + + // This Receiver expected to have heard from the Sender within the + // last timeout period but it didn't. Request a resend of granted + // packets in case DATA packets got lost. + int index = 0; + int num = 0; + int grantIndexLimit = message->numUnscheduledPackets; + + if (message->scheduled) { + SpinLock::Lock lock_scheduler(schedulerMutex); + ScheduledMessageInfo* info = &message->scheduledMessageInfo; + int receivedBytes = info->messageLength - info->bytesRemaining; + if (receivedBytes >= info->bytesGranted) { + // Sender is blocked on this Receiver; all granted packets + // have already been received. No need to check for resend. + continue; + } else if (grantIndexLimit * message->PACKET_DATA_LENGTH < + info->bytesGranted) { + grantIndexLimit = + (info->bytesGranted + message->PACKET_DATA_LENGTH - 1) / + message->PACKET_DATA_LENGTH; } + } - for (int i = 0; i < grantIndexLimit; ++i) { - if (message->getPacket(i) == nullptr) { - // Unreceived packet - if (num == 0) { - // First unreceived packet - index = i; - } - ++num; - } else { - // Received packet - if (num != 0) { - // Send out the range of packets found so far. - // - // The RESEND also includes the current granted priority - // so that it can act as a GRANT in case a GRANT was - // lost. If this message hasn't been scheduled (i.e. no - // grants have been sent) then the priority will hold - // the default value; this is ok since the Sender will - // ignore the priority field for resends of purely - // unscheduled packets (see - // Sender::handleResendPacket()). - SpinLock::Lock lock_scheduler(schedulerMutex); - Perf::counters.tx_resend_pkts.add(1); - ControlPacket::send( - message->driver, message->source.ip, message->id, - Util::downCast(index), - Util::downCast(num), - message->scheduledMessageInfo.priority); - num = 0; - } + for (int i = 0; i < grantIndexLimit; ++i) { + if (message->getPacket(i) == nullptr) { + // Unreceived packet + if (num == 0) { + // First unreceived packet + index = i; + } + ++num; + } else { + // Received packet + if (num != 0) { + // Send out the range of packets found so far. + // + // The RESEND also includes the current granted priority + // so that it can act as a GRANT in case a GRANT was + // lost. If this message hasn't been scheduled (i.e. no + // grants have been sent) then the priority will hold + // the default value; this is ok since the Sender will + // ignore the priority field for resends of purely + // unscheduled packets (see + // Sender::handleResendPacket()). + SpinLock::Lock lock_scheduler(schedulerMutex); + Perf::counters.tx_resend_pkts.add(1); + ControlPacket::send( + message->driver, message->source.ip, message->id, + Util::downCast(index), + Util::downCast(num), + message->scheduledMessageInfo.priority); + num = 0; } } - if (num != 0) { - // Send out the last range of packets found. - SpinLock::Lock lock_scheduler(schedulerMutex); - Perf::counters.tx_resend_pkts.add(1); - ControlPacket::send( - message->driver, message->source.ip, message->id, - Util::downCast(index), - Util::downCast(num), - message->scheduledMessageInfo.priority); - } } - globalNextTimeout = std::min(globalNextTimeout, nextTimeout); + if (num != 0) { + // Send out the last range of packets found. + SpinLock::Lock lock_scheduler(schedulerMutex); + Perf::counters.tx_resend_pkts.add(1); + ControlPacket::send( + message->driver, message->source.ip, message->id, + Util::downCast(index), Util::downCast(num), + message->scheduledMessageInfo.priority); + } } - return globalNextTimeout; +} + +/** + * Process any Receiver timeouts that have expired. + * + * Pulled out of poll() for ease of testing. + */ +void +Receiver::checkTimeouts() +{ + uint index = nextBucketIndex.fetch_add(1, std::memory_order_relaxed) & + MessageBucketMap::HASH_KEY_MASK; + MessageBucket* bucket = messageBuckets.buckets.at(index); + uint64_t now = PerfUtils::Cycles::rdtsc(); + checkResendTimeouts(now, bucket); + checkMessageTimeouts(now, bucket); } /** diff --git a/src/Receiver.h b/src/Receiver.h index 65e65ff..bcca5a8 100644 --- a/src/Receiver.h +++ b/src/Receiver.h @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -52,7 +53,6 @@ class Receiver { virtual void handlePingPacket(Driver::Packet* packet, IpAddress sourceIp); virtual Homa::InMessage* receiveMessage(); virtual void poll(); - virtual uint64_t checkTimeouts(); private: // Forward declaration @@ -335,6 +335,9 @@ class Receiver { */ static const int NUM_BUCKETS = 256; + // Make sure the number of buckets is a power of 2. + static_assert(Util::isPowerOfTwo(NUM_BUCKETS)); + /** * Bit mask used to map from a hashed key to the bucket index. */ @@ -451,8 +454,9 @@ class Receiver { }; void dropMessage(Receiver::Message* message); - uint64_t checkMessageTimeouts(); - uint64_t checkResendTimeouts(); + void checkMessageTimeouts(uint64_t now, MessageBucket* bucket); + void checkResendTimeouts(uint64_t now, MessageBucket* bucket); + void checkTimeouts(); void trySendGrants(); void schedule(Message* message, const SpinLock::Lock& lock); void unschedule(Message* message, const SpinLock::Lock& lock); @@ -493,6 +497,12 @@ class Receiver { /// each other. std::atomic_flag granting = ATOMIC_FLAG_INIT; + /// The index of the next bucket in the messageBuckets::buckets array to + /// process in the poll loop. The index is held in the lower order bits of + /// this variable; the higher order bits should be masked off using the + /// MessageBucketMap::HASH_KEY_MASK bit mask. + std::atomic nextBucketIndex; + /// Used to allocate Message objects. struct { /// Protects the messageAllocator.pool diff --git a/src/ReceiverTest.cc b/src/ReceiverTest.cc index bfccc39..9cd01ce 100644 --- a/src/ReceiverTest.cc +++ b/src/ReceiverTest.cc @@ -353,30 +353,6 @@ TEST_F(ReceiverTest, poll) receiver->poll(); } -TEST_F(ReceiverTest, checkTimeouts) -{ - Receiver::Message message(receiver, &mockDriver, 0, 0, - Protocol::MessageId(0, 0), - SocketAddress{0, 60001}, 0); - Receiver::MessageBucket* bucket = receiver->messageBuckets.buckets.at(0); - bucket->resendTimeouts.setTimeout(&message.resendTimeout); - bucket->messageTimeouts.setTimeout(&message.messageTimeout); - - message.resendTimeout.expirationCycleTime = 10010; - message.messageTimeout.expirationCycleTime = 10020; - - EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); - EXPECT_EQ(10010U, receiver->checkTimeouts()); - - message.resendTimeout.expirationCycleTime = 10030; - - EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); - EXPECT_EQ(10020U, receiver->checkTimeouts()); - - bucket->resendTimeouts.cancelTimeout(&message.resendTimeout); - bucket->messageTimeouts.cancelTimeout(&message.messageTimeout); -} - TEST_F(ReceiverTest, Message_destructor_basic) { Protocol::MessageId id = {42, 32}; @@ -694,7 +670,7 @@ TEST_F(ReceiverTest, dropMessage) EXPECT_TRUE(bucket->resendTimeouts.list.empty()); } -TEST_F(ReceiverTest, checkMessageTimeouts_basic) +TEST_F(ReceiverTest, checkMessageTimeouts) { void* op[3]; Receiver::Message* message[3]; @@ -726,14 +702,17 @@ TEST_F(ReceiverTest, checkMessageTimeouts_basic) // Message[2]: No timeout message[2]->messageTimeout.expirationCycleTime = 10001; + bucket->messageTimeouts.nextTimeout = 9998; + ASSERT_EQ(10000U, PerfUtils::Cycles::rdtsc()); ASSERT_TRUE(message[0]->messageTimeout.hasElapsed()); ASSERT_TRUE(message[1]->messageTimeout.hasElapsed()); ASSERT_FALSE(message[2]->messageTimeout.hasElapsed()); - uint64_t nextTimeout = receiver->checkMessageTimeouts(); + receiver->checkMessageTimeouts(10000, bucket); - EXPECT_EQ(message[2]->messageTimeout.expirationCycleTime, nextTimeout); + EXPECT_EQ(message[2]->messageTimeout.expirationCycleTime, + bucket->messageTimeouts.nextTimeout.load()); // Message[0]: Normal timeout: IN_PROGRESS EXPECT_EQ(nullptr, message[0]->messageTimeout.node.list); @@ -758,19 +737,7 @@ TEST_F(ReceiverTest, checkMessageTimeouts_basic) EXPECT_EQ(2U, receiver->messageAllocator.pool.outstandingObjects); } -TEST_F(ReceiverTest, checkMessageTimeouts_empty) -{ - for (int i = 0; i < Receiver::MessageBucketMap::NUM_BUCKETS; ++i) { - Receiver::MessageBucket* bucket = - receiver->messageBuckets.buckets.at(i); - EXPECT_TRUE(bucket->messageTimeouts.list.empty()); - } - EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); - uint64_t nextTimeout = receiver->checkMessageTimeouts(); - EXPECT_EQ(10000 + messageTimeoutCycles, nextTimeout); -} - -TEST_F(ReceiverTest, checkResendTimeouts_basic) +TEST_F(ReceiverTest, checkResendTimeouts) { Receiver::Message* message[3]; Receiver::MessageBucket* bucket = receiver->messageBuckets.buckets.at(0); @@ -809,6 +776,8 @@ TEST_F(ReceiverTest, checkResendTimeouts_basic) ASSERT_EQ(Receiver::Message::State::IN_PROGRESS, message[2]->state); message[2]->resendTimeout.expirationCycleTime = 10001; + bucket->resendTimeouts.nextTimeout = 9999; + EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); char buf1[1024]; @@ -831,9 +800,10 @@ TEST_F(ReceiverTest, checkResendTimeouts_basic) .Times(1); // TEST CALL - uint64_t nextTimeout = receiver->checkResendTimeouts(); + receiver->checkResendTimeouts(10000, bucket); - EXPECT_EQ(message[2]->resendTimeout.expirationCycleTime, nextTimeout); + EXPECT_EQ(message[2]->resendTimeout.expirationCycleTime, + bucket->resendTimeouts.nextTimeout.load()); // Message[0]: Normal timeout: resends EXPECT_EQ(10100, message[0]->resendTimeout.expirationCycleTime); @@ -859,16 +829,15 @@ TEST_F(ReceiverTest, checkResendTimeouts_basic) EXPECT_EQ(10001, message[2]->resendTimeout.expirationCycleTime); } -TEST_F(ReceiverTest, checkResendTimeouts_empty) +TEST_F(ReceiverTest, checkTimeouts) { - for (int i = 0; i < Receiver::MessageBucketMap::NUM_BUCKETS; ++i) { - Receiver::MessageBucket* bucket = - receiver->messageBuckets.buckets.at(i); - EXPECT_TRUE(bucket->resendTimeouts.list.empty()); - } - EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); - uint64_t nextTimeout = receiver->checkResendTimeouts(); - EXPECT_EQ(10000 + resendIntervalCycles, nextTimeout); + Receiver::MessageBucket* bucket = receiver->messageBuckets.buckets.at(0); + + EXPECT_EQ(0, receiver->nextBucketIndex.load()); + + receiver->checkTimeouts(); + + EXPECT_EQ(1, receiver->nextBucketIndex.load()); } TEST_F(ReceiverTest, trySendGrants) diff --git a/src/Sender.cc b/src/Sender.cc index ea75bf4..5a2c0df 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -55,6 +55,7 @@ Sender::Sender(uint64_t transportId, Driver* driver, , sendQueue() , sending() , sendReady(false) + , nextBucketIndex(0) , messageAllocator() {} @@ -500,30 +501,7 @@ void Sender::poll() { trySend(); -} - -/** - * Process any Sender timeouts that have expired. - * - * This method must be called periodically to ensure timely handling of - * expired timeouts. - * - * @return - * The rdtsc cycle time when this method should be called again. - */ -uint64_t -Sender::checkTimeouts() -{ - uint64_t nextTimeout; - - // Ping Timeout - nextTimeout = checkPingTimeouts(); - - // Message Timeout - uint64_t messageTimeout = checkMessageTimeouts(); - nextTimeout = nextTimeout < messageTimeout ? nextTimeout : messageTimeout; - - return nextTimeout; + checkTimeouts(); } /** @@ -871,112 +849,119 @@ Sender::dropMessage(Sender::Message* message) } /** - * Process any outbound messages that have timed out due to lack of activity - * from the Receiver. + * Process any outbound messages in a given bucket that have timed out due to + * lack of activity from the Receiver. * * Pulled out of checkTimeouts() for ease of testing. * - * @return - * The rdtsc cycle time when this method should be called again. + * @param now + * The rdtsc cycle that should be considered the "current" time. + * @param bucket + * The bucket whose message timeouts should be checked. */ -uint64_t -Sender::checkMessageTimeouts() +void +Sender::checkMessageTimeouts(uint64_t now, MessageBucket* bucket) { - uint64_t globalNextTimeout = UINT64_MAX; - assert(MessageBucketMap::NUM_BUCKETS > 0); - for (int i = 0; i < MessageBucketMap::NUM_BUCKETS; ++i) { - MessageBucket* bucket = messageBuckets.buckets.at(i); - uint64_t nextTimeout = 0; - while (true) { - SpinLock::Lock lock(bucket->mutex); - // No remaining timeouts. - if (bucket->messageTimeouts.list.empty()) { - nextTimeout = PerfUtils::Cycles::rdtsc() + - bucket->messageTimeouts.timeoutIntervalCycles; - break; - } - Message* message = &bucket->messageTimeouts.list.front(); - // No remaining expired timeouts. - if (!message->messageTimeout.hasElapsed()) { - nextTimeout = message->messageTimeout.expirationCycleTime; - break; - } - // Found expired timeout. - if (message->state != OutMessage::Status::COMPLETED) { - message->state.store(OutMessage::Status::FAILED); - // A sent NO_KEEP_ALIVE message should never reach this state - // since the shorter ping timeout should have already canceled - // the message timeout. - assert( - !((message->state == OutMessage::Status::SENT) && - (message->options & OutMessage::Options::NO_KEEP_ALIVE))); - } - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); + if (!bucket->messageTimeouts.anyElapsed(now)) { + return; + } + + while (true) { + SpinLock::Lock lock(bucket->mutex); + // No remaining timeouts. + if (bucket->messageTimeouts.empty()) { + break; + } + Message* message = &bucket->messageTimeouts.front(); + // No remaining expired timeouts. + if (!message->messageTimeout.hasElapsed(now)) { + break; + } + // Found expired timeout. + if (message->state != OutMessage::Status::COMPLETED) { + message->state.store(OutMessage::Status::FAILED); + // A sent NO_KEEP_ALIVE message should never reach this state + // since the shorter ping timeout should have already canceled + // the message timeout. + assert(!((message->state == OutMessage::Status::SENT) && + (message->options & OutMessage::Options::NO_KEEP_ALIVE))); } - globalNextTimeout = std::min(globalNextTimeout, nextTimeout); + bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); + bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); } - return globalNextTimeout; } /** - * Process any outbound messages that need to be pinged to ensure the - * message is kept alive by the receiver. + * Process any outbound messages in a given bucket that need to be pinged to + * ensure the message is kept alive by the receiver. * * Pulled out of checkTimeouts() for ease of testing. * - * @return - * The rdtsc cycle time when this method should be called again. + * @param now + * The rdtsc cycle that should be considered the "current" time. + * @param bucket + * The bucket whose ping timeouts should be checked. */ -uint64_t -Sender::checkPingTimeouts() +void +Sender::checkPingTimeouts(uint64_t now, MessageBucket* bucket) { - uint64_t globalNextTimeout = UINT64_MAX; - assert(MessageBucketMap::NUM_BUCKETS > 0); - for (int i = 0; i < MessageBucketMap::NUM_BUCKETS; ++i) { - MessageBucket* bucket = messageBuckets.buckets.at(i); - uint64_t nextTimeout = 0; - while (true) { - SpinLock::Lock lock(bucket->mutex); - // No remaining timeouts. - if (bucket->pingTimeouts.list.empty()) { - nextTimeout = PerfUtils::Cycles::rdtsc() + - bucket->pingTimeouts.timeoutIntervalCycles; - break; - } - Message* message = &bucket->pingTimeouts.list.front(); - // No remaining expired timeouts. - if (!message->pingTimeout.hasElapsed()) { - nextTimeout = message->pingTimeout.expirationCycleTime; - break; - } - // Found expired timeout. - if (message->state == OutMessage::Status::COMPLETED || - message->state == OutMessage::Status::FAILED) { - bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); - continue; - } else if (message->options & OutMessage::Options::NO_KEEP_ALIVE && - message->state == OutMessage::Status::SENT) { - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); - continue; - } else { - bucket->pingTimeouts.setTimeout(&message->pingTimeout); - } + if (!bucket->pingTimeouts.anyElapsed(now)) { + return; + } - // Have not heard from the Receiver in the last timeout period. Ping - // the receiver to ensure it still knows about this Message. - Perf::counters.tx_ping_pkts.add(1); - ControlPacket::send( - message->driver, message->destination.ip, message->id); + while (true) { + SpinLock::Lock lock(bucket->mutex); + // No remaining timeouts. + if (bucket->pingTimeouts.empty()) { + break; + } + Message* message = &bucket->pingTimeouts.front(); + // No remaining expired timeouts. + if (!message->pingTimeout.hasElapsed(now)) { + break; + } + // Found expired timeout. + if (message->state == OutMessage::Status::COMPLETED || + message->state == OutMessage::Status::FAILED) { + bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); + continue; + } else if (message->options & OutMessage::Options::NO_KEEP_ALIVE && + message->state == OutMessage::Status::SENT) { + bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); + bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); + continue; + } else { + bucket->pingTimeouts.setTimeout(&message->pingTimeout); } - globalNextTimeout = std::min(globalNextTimeout, nextTimeout); + + // Have not heard from the Receiver in the last timeout period. Ping + // the receiver to ensure it still knows about this Message. + Perf::counters.tx_ping_pkts.add(1); + ControlPacket::send( + message->driver, message->destination.ip, message->id); } - return globalNextTimeout; +} + +/** + * Process any Sender timeouts that have expired. + * + * Pulled out of poll() for ease of testing. + */ +void +Sender::checkTimeouts() +{ + uint index = nextBucketIndex.fetch_add(1, std::memory_order_relaxed) & + MessageBucketMap::HASH_KEY_MASK; + MessageBucket* bucket = messageBuckets.buckets.at(index); + uint64_t now = PerfUtils::Cycles::rdtsc(); + checkPingTimeouts(now, bucket); + checkMessageTimeouts(now, bucket); } /** * Send out packets for any messages with unscheduled/granted bytes. + * + * Pulled out of poll() for ease of testing. */ void Sender::trySend() diff --git a/src/Sender.h b/src/Sender.h index faa5dee..78fa05f 100644 --- a/src/Sender.h +++ b/src/Sender.h @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -53,7 +54,6 @@ class Sender { virtual void handleUnknownPacket(Driver::Packet* packet); virtual void handleErrorPacket(Driver::Packet* packet); virtual void poll(); - virtual uint64_t checkTimeouts(); private: /// Forward declarations @@ -316,6 +316,9 @@ class Sender { */ static const int NUM_BUCKETS = 256; + // Make sure the number of buckets is a power of 2. + static_assert(Util::isPowerOfTwo(NUM_BUCKETS)); + /** * Bit mask used to map from a hashed key to the bucket index. */ @@ -392,8 +395,9 @@ class Sender { Message::Options options = Message::Options::NONE); void cancelMessage(Sender::Message* message); void dropMessage(Sender::Message* message); - uint64_t checkMessageTimeouts(); - uint64_t checkPingTimeouts(); + void checkMessageTimeouts(uint64_t now, MessageBucket* bucket); + void checkPingTimeouts(uint64_t now, MessageBucket* bucket); + void checkTimeouts(); void trySend(); /// Transport identifier. @@ -431,6 +435,12 @@ class Sender { /// if there is work to do is more efficient. std::atomic sendReady; + /// The index of the next bucket in the messageBuckets::buckets array to + /// process in the poll loop. The index is held in the lower order bits of + /// this variable; the higher order bits should be masked off using the + /// MessageBucketMap::HASH_KEY_MASK bit mask. + std::atomic nextBucketIndex; + /// Used to allocate Message objects. struct { /// Protects the messageAllocator.pool diff --git a/src/SenderTest.cc b/src/SenderTest.cc index 8085c82..61ee299 100644 --- a/src/SenderTest.cc +++ b/src/SenderTest.cc @@ -764,7 +764,7 @@ TEST_F(SenderTest, handleUnknownPacket_NO_RETRY) Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); message->options = OutMessage::Options::NO_RETRY; std::vector packets; char payload[5][1028]; @@ -790,7 +790,7 @@ TEST_F(SenderTest, handleUnknownPacket_NO_RETRY) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleUnknownPacket(&mockPacket, &mockDriver); + sender->handleUnknownPacket(&mockPacket); EXPECT_FALSE( sender->sendQueue.contains(&message->queuedMessageInfo.sendQueueNode)); @@ -1044,28 +1044,6 @@ TEST_F(SenderTest, poll) sender->poll(); } -TEST_F(SenderTest, checkTimeouts) -{ - Sender::Message message(sender, 0); - Sender::MessageBucket* bucket = sender->messageBuckets.buckets.at(0); - bucket->pingTimeouts.setTimeout(&message.pingTimeout); - bucket->messageTimeouts.setTimeout(&message.messageTimeout); - - message.pingTimeout.expirationCycleTime = 10010; - message.messageTimeout.expirationCycleTime = 10020; - - EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); - EXPECT_EQ(10010U, sender->checkTimeouts()); - - message.pingTimeout.expirationCycleTime = 10030; - - EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); - EXPECT_EQ(10020U, sender->checkTimeouts()); - - bucket->pingTimeouts.cancelTimeout(&message.pingTimeout); - bucket->messageTimeouts.cancelTimeout(&message.messageTimeout); -} - TEST_F(SenderTest, Message_destructor) { const int MAX_RAW_PACKET_LENGTH = 2000; @@ -1180,7 +1158,7 @@ TEST_F(SenderTest, Message_getStatus) TEST_F(SenderTest, Message_length) { ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(2048)); - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); msg.messageLength = 200; msg.start = 20; EXPECT_EQ(180U, msg.length()); @@ -1526,14 +1504,12 @@ TEST_F(SenderTest, dropMessage) EXPECT_EQ(0U, sender->messageAllocator.pool.outstandingObjects); } -TEST_F(SenderTest, checkMessageTimeouts_basic) +TEST_F(SenderTest, checkMessageTimeouts) { Sender::Message* message[4]; + Sender::MessageBucket* bucket = sender->messageBuckets.buckets.at(0); for (uint64_t i = 0; i < 4; ++i) { - Protocol::MessageId id = {42, 10 + i}; message[i] = dynamic_cast(sender->allocMessage(0)); - SenderTest::addMessage(sender, id, message[i]); - Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); bucket->messageTimeouts.setTimeout(&message[i]->messageTimeout); bucket->pingTimeouts.setTimeout(&message[i]->pingTimeout); } @@ -1551,11 +1527,12 @@ TEST_F(SenderTest, checkMessageTimeouts_basic) message[3]->messageTimeout.expirationCycleTime = 10001; message[3]->state = Homa::OutMessage::Status::SENT; - EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); + bucket->messageTimeouts.nextTimeout = 9998; - uint64_t nextTimeout = sender->checkMessageTimeouts(); + sender->checkMessageTimeouts(10000, bucket); - EXPECT_EQ(message[3]->messageTimeout.expirationCycleTime, nextTimeout); + EXPECT_EQ(message[3]->messageTimeout.expirationCycleTime, + bucket->messageTimeouts.nextTimeout.load()); // Message[0]: Normal timeout: IN_PROGRESS EXPECT_EQ(nullptr, message[0]->messageTimeout.node.list); EXPECT_EQ(nullptr, message[0]->pingTimeout.node.list); @@ -1569,34 +1546,18 @@ TEST_F(SenderTest, checkMessageTimeouts_basic) EXPECT_EQ(nullptr, message[2]->pingTimeout.node.list); EXPECT_EQ(Homa::OutMessage::Status::COMPLETED, message[2]->getStatus()); // Message[3]: No timeout - EXPECT_EQ( - &sender->messageBuckets.getBucket(message[3]->id)->messageTimeouts.list, - message[3]->messageTimeout.node.list); - EXPECT_EQ( - &sender->messageBuckets.getBucket(message[3]->id)->pingTimeouts.list, - message[3]->pingTimeout.node.list); + EXPECT_EQ(&bucket->messageTimeouts.list, + message[3]->messageTimeout.node.list); + EXPECT_EQ(&bucket->pingTimeouts.list, message[3]->pingTimeout.node.list); EXPECT_EQ(Homa::OutMessage::Status::SENT, message[3]->getStatus()); } -TEST_F(SenderTest, checkMessageTimeouts_empty) -{ - for (int i = 0; i < Sender::MessageBucketMap::NUM_BUCKETS; ++i) { - Sender::MessageBucket* bucket = sender->messageBuckets.buckets.at(i); - EXPECT_TRUE(bucket->messageTimeouts.list.empty()); - } - EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); - uint64_t nextTimeout = sender->checkMessageTimeouts(); - EXPECT_EQ(10000 + messageTimeoutCycles, nextTimeout); -} - -TEST_F(SenderTest, checkPingTimeouts_basic) +TEST_F(SenderTest, checkPingTimeouts) { Sender::Message* message[5]; + Sender::MessageBucket* bucket = sender->messageBuckets.buckets.at(0); for (uint64_t i = 0; i < 5; ++i) { - Protocol::MessageId id = {42, 10 + i}; message[i] = dynamic_cast(sender->allocMessage(0)); - SenderTest::addMessage(sender, id, message[i]); - Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); bucket->pingTimeouts.setTimeout(&message[i]->pingTimeout); } @@ -1616,16 +1577,17 @@ TEST_F(SenderTest, checkPingTimeouts_basic) // Message[4]: No timeout message[4]->pingTimeout.expirationCycleTime = 10001; - EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); + bucket->pingTimeouts.nextTimeout = 9997; EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), _, _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - uint64_t nextTimeout = sender->checkPingTimeouts(); + sender->checkPingTimeouts(10000, bucket); - EXPECT_EQ(message[4]->pingTimeout.expirationCycleTime, nextTimeout); + EXPECT_EQ(message[4]->pingTimeout.expirationCycleTime, + bucket->pingTimeouts.nextTimeout.load()); // Message[0]: Normal timeout: COMPLETED EXPECT_EQ(nullptr, message[0]->pingTimeout.node.list); // Message[1]: Normal timeout: FAILED @@ -1642,16 +1604,15 @@ TEST_F(SenderTest, checkPingTimeouts_basic) EXPECT_EQ(10001, message[4]->pingTimeout.expirationCycleTime); } -TEST_F(SenderTest, checkPingTimeouts_empty) +TEST_F(SenderTest, checkTimeouts) { - for (int i = 0; i < Sender::MessageBucketMap::NUM_BUCKETS; ++i) { - Sender::MessageBucket* bucket = sender->messageBuckets.buckets.at(i); - EXPECT_TRUE(bucket->pingTimeouts.list.empty()); - } - EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); - sender->checkPingTimeouts(); - uint64_t nextTimeout = sender->checkPingTimeouts(); - EXPECT_EQ(10000 + pingIntervalCycles, nextTimeout); + Sender::MessageBucket* bucket = sender->messageBuckets.buckets.at(0); + + EXPECT_EQ(0, sender->nextBucketIndex.load()); + + sender->checkTimeouts(); + + EXPECT_EQ(1, sender->nextBucketIndex.load()); } TEST_F(SenderTest, trySend_basic) diff --git a/src/Timeout.h b/src/Timeout.h index 6931f2c..56710f4 100644 --- a/src/Timeout.h +++ b/src/Timeout.h @@ -18,18 +18,25 @@ #include +#include + #include "Intrusive.h" namespace Homa { namespace Core { +// Forward declaration. +template +class TimeoutManager; + /** * Intrusive structure to keep track of a per object timeout. * * This structure is not thread-safe. */ template -struct Timeout { +class Timeout { + public: /** * Initialize this Timeout, associating it with a particular object. * @@ -38,21 +45,34 @@ struct Timeout { */ explicit Timeout(ElementType* owner) : expirationCycleTime(0) - , node(owner) + , owner(owner) + , node(this) {} /** * Return true if this Timeout has elapsed; false otherwise. + * + * @param now + * Optionally provided "current" timestamp cycle time. Used to avoid + * unnecessary calls to PerfUtils::Cycles::rdtsc() if the current time + * is already available to the caller. */ - bool hasElapsed() + inline bool hasElapsed(uint64_t now = PerfUtils::Cycles::rdtsc()) { - return PerfUtils::Cycles::rdtsc() >= expirationCycleTime; + return now >= expirationCycleTime; } + private: /// Cycle timestamp when timeout should elapse. uint64_t expirationCycleTime; + + /// Pointer to the object that is associated with this timeout. + ElementType* owner; + /// Intrusive member to help track this timeout. - typename Intrusive::List::Node node; + typename Intrusive::List>::Node node; + + friend class TimeoutManager; }; /** @@ -61,7 +81,8 @@ struct Timeout { * This structure is not thread-safe. */ template -struct TimeoutManager { +class TimeoutManager { + public: /** * Construct a new TimeoutManager with a particular timeout interval. All * timeouts tracked by this manager will have the same timeout interval. @@ -69,6 +90,7 @@ struct TimeoutManager { */ explicit TimeoutManager(uint64_t timeoutIntervalCycles) : timeoutIntervalCycles(timeoutIntervalCycles) + , nextTimeout(UINT64_MAX) , list() {} @@ -79,12 +101,14 @@ struct TimeoutManager { * @param timeout * The Timeout that should be scheduled. */ - void setTimeout(Timeout* timeout) + inline void setTimeout(Timeout* timeout) { list.remove(&timeout->node); timeout->expirationCycleTime = PerfUtils::Cycles::rdtsc() + timeoutIntervalCycles; list.push_back(&timeout->node); + nextTimeout.store(list.front().expirationCycleTime, + std::memory_order_relaxed); } /** @@ -93,16 +117,78 @@ struct TimeoutManager { * @param timeout * The Timeout that should be canceled. */ - void cancelTimeout(Timeout* timeout) + inline void cancelTimeout(Timeout* timeout) { list.remove(&timeout->node); + if (list.empty()) { + nextTimeout.store(UINT64_MAX, std::memory_order_relaxed); + } else { + nextTimeout.store(list.front().expirationCycleTime, + std::memory_order_relaxed); + } + } + + /** + * Check if any managed Timeouts have elapsed. + * + * This method is thread-safe but may race with the other + * non-thread-safe methods of the TimeoutManager (e.g. concurrent calls + * to setTimeout() or cancelTimeout() may not be reflected in the result + * of this method call). + * + * @param now + * Optionally provided "current" timestamp cycle time. Used to + * avoid unnecessary calls to PerfUtils::Cycles::rdtsc() if the current + * time is already available to the caller. + */ + inline bool anyElapsed(uint64_t now = PerfUtils::Cycles::rdtsc()) + { + return now >= nextTimeout.load(std::memory_order_relaxed); + } + + /** + * Check if the TimeoutManager manages no Timeouts. + * + * @return + * True, if there are no Timeouts being managed; false, otherwise. + */ + inline bool empty() const + { + return list.empty(); + } + + /** + * Return a reference the managed timeout element that expires first. + * + * Calling front() an empty TimeoutManager is undefined. + */ + inline ElementType& front() + { + return *list.front().owner; } + /** + * Return a const reference the managed timeout element that expires + * first. + * + * Calling front() an empty TimeoutManager is undefined. + */ + inline const ElementType& front() const + { + return *list.front().owner; + } + + private: /// The number of cycles this newly scheduled timeouts would wait before /// they elapse. uint64_t timeoutIntervalCycles; + + /// The smallest timeout expiration time of all timeouts under + /// management. Accessing this value is thread-safe. + std::atomic nextTimeout; + /// Used to keep track of all timeouts under management. - Intrusive::List list; + Intrusive::List> list; }; } // namespace Core diff --git a/src/TimeoutTest.cc b/src/TimeoutTest.cc index 0225918..18ec9d2 100644 --- a/src/TimeoutTest.cc +++ b/src/TimeoutTest.cc @@ -13,12 +13,11 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ +#include #include #include "Timeout.h" -#include - namespace Homa { namespace Core { namespace { @@ -46,19 +45,22 @@ TEST(TimeoutManagerTest, setTimeout) char dummyOwner; char owner; Timeout dummy(&dummyOwner); + dummy.expirationCycleTime = 42; manager.list.push_back(&dummy.node); + manager.nextTimeout = 0; Timeout t(&owner); EXPECT_EQ(0U, t.expirationCycleTime); EXPECT_EQ(nullptr, t.node.list); - EXPECT_EQ(&dummyOwner, &manager.list.back()); + EXPECT_EQ(&dummyOwner, manager.list.back().owner); manager.setTimeout(&t); EXPECT_EQ(10100U, t.expirationCycleTime); + EXPECT_EQ(42U, manager.nextTimeout.load()); EXPECT_EQ(&manager.list, t.node.list); - EXPECT_EQ(&owner, &manager.list.back()); + EXPECT_EQ(&owner, manager.list.back().owner); manager.list.clear(); PerfUtils::Cycles::mockTscValue = 0; @@ -72,21 +74,23 @@ TEST(TimeoutManagerTest, setTimeout_reset) char dummyOwner; Timeout t(&owner); Timeout dummy(&dummyOwner); + dummy.expirationCycleTime = 9001; manager.list.push_back(&t.node); manager.list.push_back(&dummy.node); t.expirationCycleTime = 50; EXPECT_EQ(50U, t.expirationCycleTime); EXPECT_EQ(&manager.list, t.node.list); - EXPECT_EQ(&owner, &manager.list.front()); - EXPECT_EQ(&dummyOwner, &manager.list.back()); + EXPECT_EQ(&owner, &manager.front()); + EXPECT_EQ(&dummyOwner, manager.list.back().owner); manager.setTimeout(&t); EXPECT_EQ(10100U, t.expirationCycleTime); + EXPECT_EQ(9001U, manager.nextTimeout.load()); EXPECT_EQ(&manager.list, t.node.list); - EXPECT_EQ(&dummyOwner, &manager.list.front()); - EXPECT_EQ(&owner, &manager.list.back()); + EXPECT_EQ(&dummyOwner, &manager.front()); + EXPECT_EQ(&owner, manager.list.back().owner); manager.list.clear(); PerfUtils::Cycles::mockTscValue = 0; @@ -96,20 +100,25 @@ TEST(TimeoutManagerTest, cancelTimeout) { TimeoutManager manager(100); char owner; - Timeout t(&owner); - manager.list.push_back(&t.node); + Timeout t1(&owner); + t1.expirationCycleTime = 42; + Timeout t2(&owner); + t2.expirationCycleTime = 9001; + manager.list.push_back(&t1.node); + manager.list.push_back(&t2.node); - EXPECT_EQ(&manager.list, t.node.list); - EXPECT_FALSE(manager.list.empty()); + EXPECT_EQ(2, manager.list.size()); - manager.cancelTimeout(&t); + manager.cancelTimeout(&t1); - EXPECT_EQ(nullptr, t.node.list); - EXPECT_TRUE(manager.list.empty()); + EXPECT_EQ(nullptr, t1.node.list); + EXPECT_EQ(9001U, manager.nextTimeout.load()); + EXPECT_EQ(1, manager.list.size()); - manager.cancelTimeout(&t); + manager.cancelTimeout(&t2); - EXPECT_EQ(nullptr, t.node.list); + EXPECT_EQ(nullptr, t2.node.list); + EXPECT_EQ(UINT64_MAX, manager.nextTimeout.load()); EXPECT_TRUE(manager.list.empty()); } diff --git a/src/TransportImpl.cc b/src/TransportImpl.cc index a380944..d3e4363 100644 --- a/src/TransportImpl.cc +++ b/src/TransportImpl.cc @@ -73,16 +73,6 @@ TransportImpl::poll() // Allow sender and receiver to make incremental progress. sender->poll(); receiver->poll(); - - if (PerfUtils::Cycles::rdtsc() >= nextTimeoutCycles.load()) { - uint64_t requestedTimeoutCycles; - requestedTimeoutCycles = sender->checkTimeouts(); - nextTimeoutCycles.store(requestedTimeoutCycles); - requestedTimeoutCycles = receiver->checkTimeouts(); - if (nextTimeoutCycles.load() > requestedTimeoutCycles) { - nextTimeoutCycles.store(requestedTimeoutCycles); - } - } } /** diff --git a/src/TransportImplTest.cc b/src/TransportImplTest.cc index a0f66c6..b5c708b 100644 --- a/src/TransportImplTest.cc +++ b/src/TransportImplTest.cc @@ -68,32 +68,20 @@ TEST_F(TransportImplTest, poll) EXPECT_CALL(mockDriver, receivePackets).WillOnce(Return(0)); EXPECT_CALL(*mockSender, poll).Times(1); EXPECT_CALL(*mockReceiver, poll).Times(1); - EXPECT_CALL(*mockSender, checkTimeouts).WillOnce(Return(10000)); - EXPECT_CALL(*mockReceiver, checkTimeouts).WillOnce(Return(10100)); transport->poll(); - EXPECT_EQ(10000U, transport->nextTimeoutCycles); - EXPECT_CALL(mockDriver, receivePackets).WillOnce(Return(0)); EXPECT_CALL(*mockSender, poll).Times(1); EXPECT_CALL(*mockReceiver, poll).Times(1); - EXPECT_CALL(*mockSender, checkTimeouts).WillOnce(Return(10200)); - EXPECT_CALL(*mockReceiver, checkTimeouts).WillOnce(Return(10100)); transport->poll(); - EXPECT_EQ(10100U, transport->nextTimeoutCycles); - EXPECT_CALL(mockDriver, receivePackets).WillOnce(Return(0)); EXPECT_CALL(*mockSender, poll).Times(1); EXPECT_CALL(*mockReceiver, poll).Times(1); - EXPECT_CALL(*mockSender, checkTimeouts).Times(0); - EXPECT_CALL(*mockReceiver, checkTimeouts).Times(0); transport->poll(); - - EXPECT_EQ(10100U, transport->nextTimeoutCycles); } TEST_F(TransportImplTest, processPackets) diff --git a/src/UtilTest.cc b/src/UtilTest.cc index 3f8b7ff..b491220 100644 --- a/src/UtilTest.cc +++ b/src/UtilTest.cc @@ -14,7 +14,6 @@ */ #include - #include namespace Homa { @@ -33,4 +32,10 @@ TEST(UtilTest, downCast) EXPECT_EQ(64, c); } +TEST(UtilTest, isPowerOfTwo) +{ + EXPECT_TRUE(Util::isPowerOfTwo(4)); + EXPECT_FALSE(Util::isPowerOfTwo(3)); +} + } // namespace Homa \ No newline at end of file From 679e4830feb636c910e26725f52bb6c9a4a9cbf1 Mon Sep 17 00:00:00 2001 From: Collin Lee Date: Sun, 16 Aug 2020 10:09:19 -0700 Subject: [PATCH 08/33] Clean up comments --- include/Homa/Homa.h | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/include/Homa/Homa.h b/include/Homa/Homa.h index aba9073..d5edf4a 100644 --- a/include/Homa/Homa.h +++ b/include/Homa/Homa.h @@ -147,15 +147,17 @@ class OutMessage { * Options with which an OutMessage can be sent. */ enum Options { - NONE = 0, //< Default send behavior. - NO_RETRY = 1 << 0, //< Message will not be resent if recoverable send - //< failure occurs; provides at-most-once delivery - //< of messages. - NO_KEEP_ALIVE = 1 << 1, //< Once the Message has been sent, Homa will - //< not automatically ping the Message's - //< receiver to ensure the receiver is still - //< alive and the Message will not "timeout" - //< due to receiver inactivity. + /// Default send behavior. + NONE = 0, + + /// Message will not be resent if recoverable send failure occurs; + /// provides at-most-once delivery of messages. + NO_RETRY = 1 << 0, + + /// Once the Message has been sent, Homa will not automatically ping the + /// Message's receiver to ensure the receiver is still alive and the + /// Message will not "timeout" due to receiver inactivity. + NO_KEEP_ALIVE = 1 << 1, }; /** From 05a84b59243682e525ce80485590965b773076e4 Mon Sep 17 00:00:00 2001 From: Collin Lee Date: Thu, 20 Aug 2020 22:36:31 -0700 Subject: [PATCH 09/33] Change default behavior to ensure message is sent Homa will now ensure that an OutMessage is fully sent before it is cleaned up if an application releases the message after calling send(). Previously, released messages that were in progress of being sent were cancelled and immediately cleaned up. Canceling a message now requires an explicit call to cancel(). --- CMakeLists.txt | 2 +- src/Protocol.h | 3 ++ src/Sender.cc | 89 ++++++++++++++++++++++++++++++---- src/Sender.h | 6 +++ src/SenderTest.cc | 118 +++++++++++++++++++++++++++++++++++++++++++--- 5 files changed, 201 insertions(+), 17 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f5cb6ef..99b16b2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required (VERSION 3.11) -project(Homa VERSION 0.1.2.0 LANGUAGES CXX) +project(Homa VERSION 0.1.3.0 LANGUAGES CXX) ################################################################################ ## Dependency Configuration #################################################### diff --git a/src/Protocol.h b/src/Protocol.h index 55a34ac..84f8522 100644 --- a/src/Protocol.h +++ b/src/Protocol.h @@ -42,6 +42,9 @@ struct MessageId { uint64_t sequence; ///< Sequence number for this message (unique for ///< transportId, monotonically increasing). + /// MessageId default constructor. + MessageId() = default; + /// MessageId constructor. MessageId(uint64_t transportId, uint64_t sequence) : transportId(transportId) diff --git a/src/Sender.cc b/src/Sender.cc index 5a2c0df..84bff17 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -394,6 +394,18 @@ Sender::handleUnknownPacket(Driver::Packet* packet) driver->sendPacket(dataPacket, message->destination.ip, policy.priority); message->state.store(OutMessage::Status::SENT); + // This message must be still be held by the application since the + // message still exists (it would have been removed when dropped + // because single packet messages are never IN_PROGRESS). Assuming + // the message is still held, we can skip the auto removal of SENT + // and !held messages. + assert(message->held); + if (message->options & OutMessage::Options::NO_KEEP_ALIVE) { + // No timeouts need to be checked after sending the message when + // the NO_KEEP_ALIVE option is enabled. + bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); + bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); + } } else { // Otherwise, queue the message to be sent in SRPT order. SpinLock::Lock lock_queue(queueMutex); @@ -783,6 +795,16 @@ Sender::sendMessage(Sender::Message* message, SocketAddress destination, Perf::counters.tx_bytes.add(packet->length); driver->sendPacket(packet, message->destination.ip, policy.priority); message->state.store(OutMessage::Status::SENT); + // By definition, this message must be still be held by the application + // the send() call is since the progress. Assuming the message is still + // held, we can skip the auto removal of SENT and !held messages. + assert(message->held); + if (message->options & OutMessage::Options::NO_KEEP_ALIVE) { + // No timeouts need to be checked after sending the message when + // the NO_KEEP_ALIVE option is enabled. + bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); + bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); + } } else { // Otherwise, queue the message to be sent in SRPT order. SpinLock::Lock lock_queue(queueMutex); @@ -829,7 +851,6 @@ Sender::cancelMessage(Sender::Message* message) sendQueue.remove(&info->sendQueueNode); } } - bucket->messages.remove(&message->bucketNode); message->state.store(OutMessage::Status::CANCELED); } } @@ -843,9 +864,21 @@ Sender::cancelMessage(Sender::Message* message) void Sender::dropMessage(Sender::Message* message) { - cancelMessage(message); - SpinLock::Lock lock_allocator(messageAllocator.mutex); - messageAllocator.pool.destroy(message); + Protocol::MessageId msgId = message->id; + MessageBucket* bucket = messageBuckets.getBucket(msgId); + SpinLock::Lock lock(bucket->mutex); + message->held = false; + if (message->state != OutMessage::Status::IN_PROGRESS) { + // Ok to delete immediately since we don't have to wait for the message + // to be sent. + bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); + bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); + bucket->messages.remove(&message->bucketNode); + SpinLock::Lock lock_allocator(messageAllocator.mutex); + messageAllocator.pool.destroy(message); + } else { + // Defer deletion and wait for the message to be SENT. + } } /** @@ -880,11 +913,6 @@ Sender::checkMessageTimeouts(uint64_t now, MessageBucket* bucket) // Found expired timeout. if (message->state != OutMessage::Status::COMPLETED) { message->state.store(OutMessage::Status::FAILED); - // A sent NO_KEEP_ALIVE message should never reach this state - // since the shorter ping timeout should have already canceled - // the message timeout. - assert(!((message->state == OutMessage::Status::SENT) && - (message->options & OutMessage::Options::NO_KEEP_ALIVE))); } bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); @@ -986,6 +1014,8 @@ Sender::trySend() */ SpinLock::UniqueLock lock_queue(queueMutex); uint32_t queuedBytesEstimate = driver->getQueuedBytes(); + std::array sentMessageIds; + std::size_t messagesSent = 0; // Optimistically assume we will finish sending every granted packet this // round; we will set again sendReady if it turns out we don't finish. sendReady = false; @@ -1023,8 +1053,19 @@ Sender::trySend() } if (info->packetsSent >= info->packets->numPackets) { // We have finished sending the message. + sentMessageIds[messagesSent++] = info->id; message.state.store(OutMessage::Status::SENT); it = sendQueue.remove(it); + if (messagesSent >= sentMessageIds.size()) { + // We've reached the maximum number of sent messages we can + // track. If this happens frequently, the size of sentMessageIds + // should be increased. + NOTICE( + "Max sent messages per poll reached; the limit should be " + "increased if this occurs frequently"); + sendReady = true; + break; + } } else if (info->packetsSent >= info->packetsGranted) { // We have sent every granted packet. ++it; @@ -1035,8 +1076,36 @@ Sender::trySend() break; } } - sending.clear(); + + // Unlock the queueMutex to process any SENT messages to ensure any bucket + // mutex is always acquired before the send queueMutex. + lock_queue.unlock(); + for (std::size_t i = 0; i < messagesSent; ++i) { + Protocol::MessageId msgId = sentMessageIds[i]; + MessageBucket* bucket = messageBuckets.getBucket(msgId); + SpinLock::Lock lock(bucket->mutex); + Message* message = bucket->findMessage(msgId, lock); + if (message == nullptr) { + // Message must have already been deleted. + continue; + } + + if (!message->held) { + // Ok to delete now that the message has been sent. + bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); + bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); + bucket->messages.remove(&message->bucketNode); + SpinLock::Lock lock_allocator(messageAllocator.mutex); + messageAllocator.pool.destroy(message); + } else if (message->options & OutMessage::Options::NO_KEEP_ALIVE) { + // No timeouts need to be checked after sending the message when + // the NO_KEEP_ALIVE option is enabled. + bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); + bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); + } + } + uint64_t elapsed_cycles = PerfUtils::Cycles::rdtsc() - start_tsc; if (!idle) { Perf::counters.active_cycles.add(elapsed_cycles); diff --git a/src/Sender.h b/src/Sender.h index 78fa05f..50157fc 100644 --- a/src/Sender.h +++ b/src/Sender.h @@ -141,6 +141,7 @@ class Sender { , source{driver->getLocalAddress(), sourcePort} , destination() , options(Options::NONE) + , held(true) , start(0) , messageLength(0) , numPackets(0) @@ -198,6 +199,11 @@ class Sender { /// Contains flags for any requested optional send behavior. Options options; + /// True if a pointer to this message is accessible by the application + /// (e.g. the message has been allocated via allocMessage() but has not + /// been release via dropMessage()); false, otherwise. + bool held; + /// First byte where data is or will go if empty. int start; diff --git a/src/SenderTest.cc b/src/SenderTest.cc index 61ee299..89bd283 100644 --- a/src/SenderTest.cc +++ b/src/SenderTest.cc @@ -757,6 +757,49 @@ TEST_F(SenderTest, handleUnknownPacket_singlePacketMessage) EXPECT_FALSE(sender->sendReady.load()); } +TEST_F(SenderTest, handleUnknownPacket_NO_KEEP_ALIVE) +{ + Protocol::MessageId id = {42, 1}; + SocketAddress destination = {22, 60001}; + Core::Policy::Unscheduled policyNew = {2, 3000, 2}; + + Sender::Message* message = + dynamic_cast(sender->allocMessage(0)); + Homa::Mock::MockDriver::MockPacket dataPacket{payload}; + Protocol::Packet::DataHeader* dataHeader = + static_cast(dataPacket.payload); + setMessagePacket(message, 0, &dataPacket); + message->destination = destination; + message->messageLength = 500; + message->state.store(Homa::OutMessage::Status::SENT); + message->options = OutMessage::Options::NO_KEEP_ALIVE; + SenderTest::addMessage(sender, id, message); + Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); + bucket->messageTimeouts.setTimeout(&message->messageTimeout); + bucket->pingTimeouts.setTimeout(&message->pingTimeout); + EXPECT_FALSE(bucket->messageTimeouts.empty()); + EXPECT_FALSE(bucket->pingTimeouts.empty()); + + Protocol::Packet::UnknownHeader* header = + static_cast(mockPacket.payload); + header->common.messageId = id; + + EXPECT_CALL( + mockPolicyManager, + getUnscheduledPolicy(Eq(destination.ip), Eq(message->messageLength))) + .WillOnce(Return(policyNew)); + EXPECT_CALL(mockDriver, sendPacket(Eq(&dataPacket), Eq(destination.ip), _)) + .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) + .Times(1); + + sender->handleUnknownPacket(&mockPacket); + + EXPECT_EQ(Homa::OutMessage::Status::SENT, message->state); + EXPECT_TRUE(bucket->messageTimeouts.empty()); + EXPECT_TRUE(bucket->pingTimeouts.empty()); +} + TEST_F(SenderTest, handleUnknownPacket_NO_RETRY) { Protocol::MessageId id = {42, 1}; @@ -1427,6 +1470,37 @@ TEST_F(SenderTest, sendMessage_multipacket) EXPECT_TRUE(sender->sendReady.load()); } +TEST_F(SenderTest, sendMessage_NO_KEEP_ALIVE) +{ + Protocol::MessageId id = {sender->transportId, + sender->nextMessageSequenceNumber}; + Sender::Message* message = + dynamic_cast(sender->allocMessage(0)); + Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); + + setMessagePacket(message, 0, &mockPacket); + message->messageLength = 420; + mockPacket.length = + message->messageLength + message->TRANSPORT_HEADER_LENGTH; + SocketAddress destination = {22, 60001}; + Core::Policy::Unscheduled policy = {1, 3000, 2}; + + EXPECT_CALL(mockPolicyManager, + getUnscheduledPolicy(Eq(destination.ip), Eq(420))) + .WillOnce(Return(policy)); + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), Eq(destination.ip), _)) + .Times(1); + + sender->sendMessage(message, destination, + Sender::Message::Options::NO_KEEP_ALIVE); + + EXPECT_EQ(Sender::Message::Options::NO_KEEP_ALIVE, message->options); + EXPECT_TRUE(bucket->messages.contains(&message->bucketNode)); + EXPECT_TRUE(bucket->messageTimeouts.empty()); + EXPECT_TRUE(bucket->pingTimeouts.empty()); + EXPECT_EQ(Homa::OutMessage::Status::SENT, message->state); +} + TEST_F(SenderTest, sendMessage_missingPacket) { Protocol::MessageId id = {sender->transportId, @@ -1490,20 +1564,35 @@ TEST_F(SenderTest, cancelMessage) EXPECT_TRUE(bucket->pingTimeouts.list.empty()); EXPECT_TRUE(sender->sendQueue.empty()); EXPECT_EQ(Homa::OutMessage::Status::CANCELED, message->state.load()); - EXPECT_FALSE(bucket->messages.contains(&message->bucketNode)); } -TEST_F(SenderTest, dropMessage) +TEST_F(SenderTest, dropMessage_basic) { Sender::Message* message = dynamic_cast(sender->allocMessage(0)); EXPECT_EQ(1U, sender->messageAllocator.pool.outstandingObjects); + EXPECT_TRUE(message->held); + EXPECT_EQ(OutMessage::Status::NOT_STARTED, message->state); sender->dropMessage(message); EXPECT_EQ(0U, sender->messageAllocator.pool.outstandingObjects); } +TEST_F(SenderTest, dropMessage_IN_PROGRESS) +{ + Sender::Message* message = + dynamic_cast(sender->allocMessage(0)); + message->state = OutMessage::Status::IN_PROGRESS; + EXPECT_EQ(1U, sender->messageAllocator.pool.outstandingObjects); + EXPECT_TRUE(message->held); + + sender->dropMessage(message); + + EXPECT_EQ(1U, sender->messageAllocator.pool.outstandingObjects); + EXPECT_FALSE(message->held); +} + TEST_F(SenderTest, checkMessageTimeouts) { Sender::Message* message[4]; @@ -1633,6 +1722,7 @@ TEST_F(SenderTest, trySend_basic) info->unsentBytes += PACKET_DATA_SIZE; } message->state = Homa::OutMessage::Status::IN_PROGRESS; + message->held = false; sender->sendReady = true; EXPECT_EQ(5U, message->numPackets); EXPECT_EQ(3U, info->packetsGranted); @@ -1640,6 +1730,7 @@ TEST_F(SenderTest, trySend_basic) EXPECT_EQ(5 * PACKET_DATA_SIZE, info->unsentBytes); EXPECT_NE(Homa::OutMessage::Status::SENT, message->state); EXPECT_TRUE(sender->sendQueue.contains(&info->sendQueueNode)); + EXPECT_EQ(1U, sender->messageAllocator.pool.outstandingObjects); // 3 granted packets; 2 will send; queue limit reached. EXPECT_CALL(mockDriver, sendPacket(Eq(packet[0]), _, _)); @@ -1652,6 +1743,7 @@ TEST_F(SenderTest, trySend_basic) EXPECT_EQ(3 * PACKET_DATA_SIZE, info->unsentBytes); EXPECT_NE(Homa::OutMessage::Status::SENT, message->state); EXPECT_TRUE(sender->sendQueue.contains(&info->sendQueueNode)); + EXPECT_EQ(1U, sender->messageAllocator.pool.outstandingObjects); Mock::VerifyAndClearExpectations(&mockDriver); // 1 packet to be sent; grant limit reached. @@ -1664,6 +1756,7 @@ TEST_F(SenderTest, trySend_basic) EXPECT_EQ(2 * PACKET_DATA_SIZE, info->unsentBytes); EXPECT_NE(Homa::OutMessage::Status::SENT, message->state); EXPECT_TRUE(sender->sendQueue.contains(&info->sendQueueNode)); + EXPECT_EQ(1U, sender->messageAllocator.pool.outstandingObjects); Mock::VerifyAndClearExpectations(&mockDriver); // No additional grants; spurious ready hint. @@ -1677,6 +1770,7 @@ TEST_F(SenderTest, trySend_basic) EXPECT_EQ(2 * PACKET_DATA_SIZE, info->unsentBytes); EXPECT_NE(Homa::OutMessage::Status::SENT, message->state); EXPECT_TRUE(sender->sendQueue.contains(&info->sendQueueNode)); + EXPECT_EQ(1U, sender->messageAllocator.pool.outstandingObjects); Mock::VerifyAndClearExpectations(&mockDriver); // 2 more granted packets; will finish. @@ -1692,6 +1786,7 @@ TEST_F(SenderTest, trySend_basic) EXPECT_EQ(0 * PACKET_DATA_SIZE, info->unsentBytes); EXPECT_EQ(Homa::OutMessage::Status::SENT, message->state); EXPECT_FALSE(sender->sendQueue.contains(&info->sendQueueNode)); + EXPECT_EQ(0U, sender->messageAllocator.pool.outstandingObjects); Mock::VerifyAndClearExpectations(&mockDriver); for (int i = 0; i < 5; ++i) { @@ -1701,14 +1796,17 @@ TEST_F(SenderTest, trySend_basic) TEST_F(SenderTest, trySend_multipleMessages) { + Protocol::MessageId id[3]; Sender::Message* message[3]; Sender::QueuedMessageInfo* info[3]; + Sender::MessageBucket* bucket[3]; Homa::Mock::MockDriver::MockPacket* packet[3]; for (uint64_t i = 0; i < 3; ++i) { - Protocol::MessageId id = {22, 10 + i}; + id[i] = {22, 10 + i}; message[i] = dynamic_cast(sender->allocMessage(0)); info[i] = &message[i]->queuedMessageInfo; - SenderTest::addMessage(sender, id, message[i], true, 1); + SenderTest::addMessage(sender, id[i], message[i], true, 1); + bucket[i] = sender->messageBuckets.getBucket(id[i]); packet[i] = new Homa::Mock::MockDriver::MockPacket{payload}; packet[i]->length = sender->driver->getMaxPayloadSize() / 4; setMessagePacket(message[i], 0, packet[i]); @@ -1718,9 +1816,10 @@ TEST_F(SenderTest, trySend_multipleMessages) } sender->sendReady = true; - // Message 0: Will finish + // Message 0: Will finish, !held EXPECT_EQ(1, info[0]->packetsGranted); info[0]->packetsSent = 0; + message[0]->held = false; // Message 1: Will reach grant limit EXPECT_EQ(1, info[1]->packetsGranted); @@ -1728,9 +1827,10 @@ TEST_F(SenderTest, trySend_multipleMessages) setMessagePacket(message[1], 1, nullptr); EXPECT_EQ(2, message[1]->numPackets); - // Message 2: Will finish + // Message 2: Will finish, NO_KEEP_ALIVE EXPECT_EQ(1, info[2]->packetsGranted); info[2]->packetsSent = 0; + message[2]->options = OutMessage::Options::NO_KEEP_ALIVE; EXPECT_CALL(mockDriver, sendPacket(Eq(packet[0]), _, _)); EXPECT_CALL(mockDriver, sendPacket(Eq(packet[1]), _, _)); @@ -1741,12 +1841,18 @@ TEST_F(SenderTest, trySend_multipleMessages) EXPECT_EQ(1U, info[0]->packetsSent); EXPECT_EQ(Homa::OutMessage::Status::SENT, message[0]->state); EXPECT_FALSE(sender->sendQueue.contains(&info[0]->sendQueueNode)); + EXPECT_EQ(nullptr, + bucket[0]->findMessage(id[0], SpinLock::Lock(bucket[0]->mutex))); EXPECT_EQ(1U, info[1]->packetsSent); EXPECT_NE(Homa::OutMessage::Status::SENT, message[1]->state); EXPECT_TRUE(sender->sendQueue.contains(&info[1]->sendQueueNode)); EXPECT_EQ(1U, info[2]->packetsSent); EXPECT_EQ(Homa::OutMessage::Status::SENT, message[2]->state); EXPECT_FALSE(sender->sendQueue.contains(&info[2]->sendQueueNode)); + EXPECT_FALSE(bucket[2]->messageTimeouts.list.contains( + &message[2]->messageTimeout.node)); + EXPECT_FALSE( + bucket[2]->pingTimeouts.list.contains(&message[2]->pingTimeout.node)); } TEST_F(SenderTest, trySend_alreadyRunning) From dd631fe4aac1d28643a539bd511a2b491c45bd2d Mon Sep 17 00:00:00 2001 From: Collin Lee Date: Fri, 21 Aug 2020 16:14:10 -0700 Subject: [PATCH 10/33] Add Perf benchmark atomicIncUnsafe --- test/Perf.cc | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/Perf.cc b/test/Perf.cc index 7e9327f..b5db088 100644 --- a/test/Perf.cc +++ b/test/Perf.cc @@ -150,6 +150,24 @@ atomicIncRelaxedTest() return PerfUtils::Cycles::toSeconds(stop - start) / count; } +TestInfo atomicIncUnsafeTestInfo = { + "atomicIncUnsafe", "Increment an std::atomic using read-modify-write", + R"(Measure the cost of a thread unsafe increment of an std::atomic.)"}; +double +atomicIncUnsafeTest() +{ + int count = 1000000; + uint64_t temp = std::rand() % 100; + std::atomic val[count]; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + val[i].store(val[i].load(std::memory_order_relaxed) + temp, + std::memory_order_relaxed); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / count; +} + TestInfo intReadWriteTestInfo = { "intReadWrite", "Read and write a uint64_t", R"(Measure the cost the baseline read/write.)"}; @@ -506,6 +524,7 @@ TestCase tests[] = { {atomicStoreRelaxedTest, &atomicStoreRelaxedTestInfo}, {atomicIncTest, &atomicIncTestInfo}, {atomicIncRelaxedTest, &atomicIncRelaxedTestInfo}, + {atomicIncUnsafeTest, &atomicIncUnsafeTestInfo}, {branchTest, &branchTestInfo}, {intReadWriteTest, &intReadWriteTestInfo}, {listSearchTest, &listSearchTestInfo}, From 3ce52334cb3513b692fe9f3244eb3f0af689ee8a Mon Sep 17 00:00:00 2001 From: Collin Lee Date: Fri, 21 Aug 2020 16:47:25 -0700 Subject: [PATCH 11/33] Remove unneeded Perf::Counters:Stat thread-safety Make the Perf::Counters::Stat::add(T val) operation thread unsafe. This significantly reduces the overhead of updating a Stat and is ok because add() is only called on thread-local Stat instances. --- src/Perf.h | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/Perf.h b/src/Perf.h index 2349b01..da83c16 100644 --- a/src/Perf.h +++ b/src/Perf.h @@ -29,7 +29,7 @@ namespace Perf { */ struct Counters { /** - * Wrapper class for individual counter entires to + * Wrapper class for individual counter entries. */ template struct Stat : private std::atomic { @@ -43,8 +43,10 @@ struct Counters { /** * Add the value of another Stat to this Stat. + * + * This method is thread-safe. */ - void add(const Stat& other) + inline void add(const Stat& other) { this->fetch_add(other.load(std::memory_order_relaxed), std::memory_order_relaxed); @@ -52,16 +54,21 @@ struct Counters { /** * Add the given value to this Stat. + * + * This method is not thread-safe. */ - void add(T val) + inline void add(T val) { - this->fetch_add(val, std::memory_order_relaxed); + this->store(this->load(std::memory_order_relaxed) + val, + std::memory_order_relaxed); } /** * Return the stat value. + * + * This method is thread-safe. */ - T get() const + inline T get() const { return this->load(std::memory_order_relaxed); } From ffab71eef2559bc65e1578e82b6dd0fb317ba212 Mon Sep 17 00:00:00 2001 From: Collin Lee Date: Fri, 21 Aug 2020 17:01:12 -0700 Subject: [PATCH 12/33] Clean up Perf::Timer usage --- src/Perf.h | 18 +++++++++--------- src/Receiver.cc | 13 +++---------- src/Sender.cc | 8 +++----- src/TransportImpl.cc | 11 ++++++----- 4 files changed, 21 insertions(+), 29 deletions(-) diff --git a/src/Perf.h b/src/Perf.h index da83c16..b448edc 100644 --- a/src/Perf.h +++ b/src/Perf.h @@ -78,8 +78,8 @@ struct Counters { * Default constructor. */ Counters() - : active_cycles(0) - , idle_cycles(0) + : total_cycles(0) + , active_cycles(0) , tx_bytes(0) , rx_bytes(0) , tx_data_pkts(0) @@ -110,8 +110,8 @@ struct Counters { */ void add(const Counters* other) { + total_cycles.add(other->total_cycles); active_cycles.add(other->active_cycles); - idle_cycles.add(other->idle_cycles); tx_bytes.add(other->tx_bytes); rx_bytes.add(other->rx_bytes); tx_data_pkts.add(other->tx_data_pkts); @@ -138,7 +138,7 @@ struct Counters { void dumpStats(Stats* stats) { stats->active_cycles = active_cycles.get(); - stats->idle_cycles = idle_cycles.get(); + stats->idle_cycles = total_cycles.get() - active_cycles.get(); stats->tx_bytes = tx_bytes.get(); stats->rx_bytes = rx_bytes.get(); stats->tx_data_pkts = tx_data_pkts.get(); @@ -159,12 +159,12 @@ struct Counters { stats->rx_error_pkts = rx_error_pkts.get(); } + /// CPU time spent running the Homa poll loop in cycles. + Stat total_cycles; + /// CPU time spent actively processing Homa messages in cycles. Stat active_cycles; - /// CPU time spent running Homa with no work to do in cycles. - Stat idle_cycles; - /// Number of bytes sent by the transport. Stat tx_bytes; @@ -240,10 +240,10 @@ extern thread_local ThreadCounters counters; class Timer { public: /** - * Construct a new uninitialized Timer. + * Construct a new Timer. */ Timer() - : split_tsc(0) + : split_tsc(PerfUtils::Cycles::rdtsc()) {} /** diff --git a/src/Receiver.cc b/src/Receiver.cc index 18441c9..e2a8003 100644 --- a/src/Receiver.cc +++ b/src/Receiver.cc @@ -694,8 +694,7 @@ Receiver::checkTimeouts() void Receiver::trySendGrants() { - uint64_t start_tsc = PerfUtils::Cycles::rdtsc(); - bool idle = true; + Perf::Timer timer; // Skip scheduling if another poller is already working on it. if (granting.test_and_set()) { @@ -743,7 +742,6 @@ Receiver::trySendGrants() // Send a GRANT if there are too few bytes granted and unreceived. int receivedBytes = info->messageLength - info->bytesRemaining; if (info->bytesGranted - receivedBytes < policy.minScheduledBytes) { - idle = false; // Calculate new grant limit int newGrantLimit = std::min( receivedBytes + policy.maxScheduledBytes, info->messageLength); @@ -753,6 +751,7 @@ Receiver::trySendGrants() ControlPacket::send( driver, sourceIp, id, Util::downCast(info->bytesGranted), info->priority); + Perf::counters.active_cycles.add(timer.split()); } // Update the iterator first since calling unschedule() may cause the @@ -762,19 +761,13 @@ Receiver::trySendGrants() if (info->messageLength <= info->bytesGranted) { // All packets granted, unschedule the message. unschedule(message, lock); + Perf::counters.active_cycles.add(timer.split()); } ++slot; } granting.clear(); - - uint64_t elapsed_cycles = PerfUtils::Cycles::rdtsc() - start_tsc; - if (!idle) { - Perf::counters.active_cycles.add(elapsed_cycles); - } else { - Perf::counters.idle_cycles.add(elapsed_cycles); - } } /** diff --git a/src/Sender.cc b/src/Sender.cc index 84bff17..d1833f0 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -994,8 +994,9 @@ Sender::checkTimeouts() void Sender::trySend() { - uint64_t start_tsc = PerfUtils::Cycles::rdtsc(); + Perf::Timer timer; bool idle = true; + // Skip when there are no messages to send. if (!sendReady) { return; @@ -1106,11 +1107,8 @@ Sender::trySend() } } - uint64_t elapsed_cycles = PerfUtils::Cycles::rdtsc() - start_tsc; if (!idle) { - Perf::counters.active_cycles.add(elapsed_cycles); - } else { - Perf::counters.idle_cycles.add(elapsed_cycles); + Perf::counters.active_cycles.add(timer.split()); } } diff --git a/src/TransportImpl.cc b/src/TransportImpl.cc index d3e4363..38fe6d3 100644 --- a/src/TransportImpl.cc +++ b/src/TransportImpl.cc @@ -67,12 +67,16 @@ TransportImpl::~TransportImpl() = default; void TransportImpl::poll() { + Perf::Timer timer; + // Receive and dispatch incoming packets. processPackets(); // Allow sender and receiver to make incremental progress. sender->poll(); receiver->poll(); + + Perf::counters.total_cycles.add(timer.split()); } /** @@ -84,7 +88,7 @@ void TransportImpl::processPackets() { // Keep track of time spent doing active processing versus idle. - uint64_t cycles = PerfUtils::Cycles::rdtsc(); + Perf::Timer timer; const int MAX_BURST = 32; Driver::Packet* packets[MAX_BURST]; @@ -94,11 +98,8 @@ TransportImpl::processPackets() processPacket(packets[i], srcAddrs[i]); } - cycles = PerfUtils::Cycles::rdtsc() - cycles; if (numPackets > 0) { - Perf::counters.active_cycles.add(cycles); - } else { - Perf::counters.idle_cycles.add(cycles); + Perf::counters.active_cycles.add(timer.split()); } } From 020468ec69eb6740d9a0e53c413f75362dbe9582 Mon Sep 17 00:00:00 2001 From: Collin Lee Date: Fri, 28 Aug 2020 11:31:39 -0700 Subject: [PATCH 13/33] Add more Perf microbenchmarks --- test/Perf.cc | 219 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 218 insertions(+), 1 deletion(-) diff --git a/test/Perf.cc b/test/Perf.cc index b5db088..2e34399 100644 --- a/test/Perf.cc +++ b/test/Perf.cc @@ -31,6 +31,7 @@ #include "Cycles.h" #include "Homa/Drivers/Util/QueueEstimator.h" #include "Intrusive.h" +#include "ObjectPool.h" #include "docopt.h" static const char USAGE[] = R"(Performance Nano-Benchmark @@ -186,7 +187,7 @@ intReadWriteTest() } TestInfo branchTestInfo = { - "branchTest", "If-else statement", + "branch", "If-else statement", R"(The cost of choosing a branch in an if-else statement)"}; double branchTest() @@ -212,6 +213,70 @@ branchTest() return PerfUtils::Cycles::toSeconds(stop - start) / count; } +TestInfo defaultAllocatorTestInfo = { + "defaultAllocator", "Test new and delete of a simple structure", + R"(Measure the cost of allocation and deallocation using new and delete.)"}; +double +defaultAllocatorTest() +{ + struct Foo { + Foo() + : i() + , buf() + {} + + uint64_t i; + char buf[100]; + }; + Foo* foo[0xFFFF + 1]; + for (int i = 0; i < 0xFFFF + 1; ++i) { + foo[i] = new Foo; + } + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + delete foo[i & 0xFFFF]; + foo[i & 0xFFFF] = new Foo; + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < 0xFFFF + 1; ++i) { + delete foo[i]; + } + return PerfUtils::Cycles::toSeconds(stop - start) / (count); +} + +TestInfo objectPoolTestInfo = { + "objectPool", "Test ObjectPool allocation of a simple structure", + R"(Measure the cost of allocation and deallocation using an ObjectPool.)"}; +double +objectPoolTest() +{ + struct Foo { + Foo() + : i() + , buf() + {} + uint64_t i; + char buf[100]; + }; + Homa::ObjectPool pool; + Foo* foo[0xFFFF + 1]; + for (int i = 0; i < 0xFFFF + 1; ++i) { + foo[i] = pool.construct(); + } + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + pool.destroy(foo[i & 0xFFFF]); + foo[i & 0xFFFF] = pool.construct(); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < 0xFFFF + 1; ++i) { + pool.destroy(foo[i]); + } + return PerfUtils::Cycles::toSeconds(stop - start) / (count); +} + TestInfo listSearchTestInfo = { "listSearch", "Linear search an intrusive list", R"(Measure the cost (per entry) of searching through an intrusive list)"}; @@ -381,6 +446,75 @@ dequePushPopTest() return PerfUtils::Cycles::toSeconds(stop - start) / (2 * count); } +TestInfo vectorConstructInfo = { + "vectorConstruct", "Construct an std::vector", + R"(Measure the cost of constructing (and destructing) an std::vector.)"}; +double +vectorConstruct() +{ + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + std::vector vector; + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / (count); +} + +TestInfo vectorReserveTestInfo = { + "vectorReserve", "Reserve capacity in an std::vector", + R"(Measure the cost of reserving capacity an std::vector.)"}; +double +vectorReserveTest() +{ + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + std::vector vector; + vector.reserve(32); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / (count); +} + +TestInfo vectorPushTestInfo = { + "vectorPush", "std::vector push", + R"(Measure the cost of pushing a new element to an std::vector.)"}; +double +vectorPushTest() +{ + std::vector vector; + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + vector.push_back(i); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / (count); +} + +TestInfo vectorPushPopTestInfo = { + "vectorPushPop", "std::vector operations", + R"(Measure the cost of pushing/popping an element to/from an std::vector.)"}; +double +vectorPushPopTest() +{ + std::vector vector; + for (int i = 0; i < 10000; ++i) { + vector.push_back(std::rand()); + } + + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + uint64_t temp = vector.back(); + vector.pop_back(); + vector.push_back(temp); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / (2 * count); +} + TestInfo listConstructInfo = { "listConstruct", "Construct an std::list", R"(Measure the cost of constructing (and destructing) an std::list.)"}; @@ -396,6 +530,22 @@ listConstruct() return PerfUtils::Cycles::toSeconds(stop - start) / (count); } +TestInfo listPushTestInfo = { + "listPush", "std::list push", + R"(Measure the cost of pushing a new element to an std::list.)"}; +double +listPushTest() +{ + std::list list; + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + list.push_back(i); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / (count); +} + TestInfo listPushPopTestInfo = { "listPushPop", "std::list operations", R"(Measure the cost of pushing/popping an element to/from an std::list.)"}; @@ -418,6 +568,64 @@ listPushPopTest() return PerfUtils::Cycles::toSeconds(stop - start) / (2 * count); } +TestInfo ilistConstructInfo = { + "ilistConstruct", "Construct an Intrusive::list", + R"(Measure the cost of constructing (and destructing) an Intrusive::list.)"}; +double +ilistConstruct() +{ + struct Foo { + Foo() + : node(this) + {} + + Homa::Core::Intrusive::List::Node node; + }; + + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + Homa::Core::Intrusive::List list; + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + return PerfUtils::Cycles::toSeconds(stop - start) / (count); +} + +TestInfo ilistPushPopTestInfo = { + "ilistPushPop", "Intrusive::list operations", + R"(Measure the cost of pushing/popping an element to/from an Intrusive::list.)"}; +double +ilistPushPopTest() +{ + struct Foo { + Foo() + : node(this) + {} + + Homa::Core::Intrusive::List::Node node; + }; + + Homa::Core::Intrusive::List list; + for (int i = 0; i < 10000; ++i) { + Foo* foo = new Foo; + list.push_back(&foo->node); + } + int count = 1000000; + uint64_t start = PerfUtils::Cycles::rdtscp(); + for (int i = 0; i < count; i++) { + Foo* foo = &list.front(); + list.pop_front(); + list.push_front(&foo->node); + } + uint64_t stop = PerfUtils::Cycles::rdtscp(); + while (!list.empty()) { + Foo* foo = &list.front(); + list.pop_front(); + delete foo; + } + return PerfUtils::Cycles::toSeconds(stop - start) / (2 * count); +} + TestInfo heapTestInfo = { "heap", "std heap operations", R"(Measure the cost of pushing/popping an element to/from an std heap.)"}; @@ -527,14 +735,23 @@ TestCase tests[] = { {atomicIncUnsafeTest, &atomicIncUnsafeTestInfo}, {branchTest, &branchTestInfo}, {intReadWriteTest, &intReadWriteTestInfo}, + {defaultAllocatorTest, &defaultAllocatorTestInfo}, + {objectPoolTest, &objectPoolTestInfo}, {listSearchTest, &listSearchTestInfo}, {mapFindTest, &mapFindTestInfo}, {mapLookupTest, &mapLookupTestInfo}, {mapNullInsertTest, &mapNullInsertTestInfo}, {dequeConstruct, &dequeConstructInfo}, {dequePushPopTest, &dequePushPopTestInfo}, + {vectorConstruct, &vectorConstructInfo}, + {vectorReserveTest, &vectorReserveTestInfo}, + {vectorPushTest, &vectorPushTestInfo}, + {vectorPushPopTest, &vectorPushPopTestInfo}, {listConstruct, &listConstructInfo}, + {listPushTest, &listPushTestInfo}, {listPushPopTest, &listPushPopTestInfo}, + {ilistConstruct, &ilistConstructInfo}, + {ilistPushPopTest, &ilistPushPopTestInfo}, {heapTest, &heapTestInfo}, {queueEstimatorTest, &queueEstimatorTestInfo}, {rdtscTest, &rdtscTestInfo}, From 2cd326c9423f4d84683f028b80c48f6e1bd34549 Mon Sep 17 00:00:00 2001 From: Collin Lee Date: Fri, 28 Aug 2020 11:41:48 -0700 Subject: [PATCH 14/33] Remove Sender::trySend() restriction Previously, the number of messages completed per call to trySend() was limited by an std::array allocated on the stack. This was done to reduce the allocation overhead. New Perf microbenchmarks show that reserving capacity in an std::vector is likely efficient enough and would allow trySend() to seamlessly handle an unlimited number of message completions per call to trySend(). --- src/Sender.cc | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/src/Sender.cc b/src/Sender.cc index d1833f0..e9e2bfb 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -1015,8 +1015,8 @@ Sender::trySend() */ SpinLock::UniqueLock lock_queue(queueMutex); uint32_t queuedBytesEstimate = driver->getQueuedBytes(); - std::array sentMessageIds; - std::size_t messagesSent = 0; + std::vector sentMessageIds; + sentMessageIds.reserve(32); // Optimistically assume we will finish sending every granted packet this // round; we will set again sendReady if it turns out we don't finish. sendReady = false; @@ -1054,19 +1054,9 @@ Sender::trySend() } if (info->packetsSent >= info->packets->numPackets) { // We have finished sending the message. - sentMessageIds[messagesSent++] = info->id; + sentMessageIds.push_back(info->id); message.state.store(OutMessage::Status::SENT); it = sendQueue.remove(it); - if (messagesSent >= sentMessageIds.size()) { - // We've reached the maximum number of sent messages we can - // track. If this happens frequently, the size of sentMessageIds - // should be increased. - NOTICE( - "Max sent messages per poll reached; the limit should be " - "increased if this occurs frequently"); - sendReady = true; - break; - } } else if (info->packetsSent >= info->packetsGranted) { // We have sent every granted packet. ++it; @@ -1082,8 +1072,7 @@ Sender::trySend() // Unlock the queueMutex to process any SENT messages to ensure any bucket // mutex is always acquired before the send queueMutex. lock_queue.unlock(); - for (std::size_t i = 0; i < messagesSent; ++i) { - Protocol::MessageId msgId = sentMessageIds[i]; + for (Protocol::MessageId& msgId : sentMessageIds) { MessageBucket* bucket = messageBuckets.getBucket(msgId); SpinLock::Lock lock(bucket->mutex); Message* message = bucket->findMessage(msgId, lock); From 793de5a5fd99ed0e19c9a5d14bbb73bac35a2ce4 Mon Sep 17 00:00:00 2001 From: Collin Lee Date: Thu, 3 Sep 2020 19:33:42 -0700 Subject: [PATCH 15/33] Simplify DpdkDriver Packet allocation --- src/Drivers/DPDK/DpdkDriverImpl.cc | 82 ++++++++++-------------------- src/Drivers/DPDK/DpdkDriverImpl.h | 4 +- 2 files changed, 31 insertions(+), 55 deletions(-) diff --git a/src/Drivers/DPDK/DpdkDriverImpl.cc b/src/Drivers/DPDK/DpdkDriverImpl.cc index e9fef18..085b288 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.cc +++ b/src/Drivers/DPDK/DpdkDriverImpl.cc @@ -177,9 +177,32 @@ DpdkDriver::Impl::~Impl() Driver::Packet* DpdkDriver::Impl::allocPacket() { - DpdkDriver::Impl::Packet* packet = _allocMbufPacket(); - if (unlikely(packet == nullptr)) { - SpinLock::Lock lock(packetLock); + DpdkDriver::Impl::Packet* packet = nullptr; + SpinLock::Lock lock(packetLock); + static const int MBUF_ALLOC_LIMIT = NB_MBUF - NB_MBUF_RESERVED; + if (mbufsOutstanding < MBUF_ALLOC_LIMIT) { + struct rte_mbuf* mbuf = rte_pktmbuf_alloc(mbufPool); + if (unlikely(NULL == mbuf)) { + uint32_t numMbufsAvail = rte_mempool_avail_count(mbufPool); + uint32_t numMbufsInUse = rte_mempool_in_use_count(mbufPool); + NOTICE( + "Failed to allocate an mbuf packet buffer; " + "%u mbufs available, %u mbufs in use", + numMbufsAvail, numMbufsInUse); + } else { + char* buf = rte_pktmbuf_append( + mbuf, Homa::Util::downCast(PACKET_HDR_LEN + + MAX_PAYLOAD_SIZE)); + if (unlikely(NULL == buf)) { + NOTICE("rte_pktmbuf_append call failed; dropping packet"); + rte_pktmbuf_free(mbuf); + } else { + packet = packetPool.construct(mbuf, buf + PACKET_HDR_LEN); + mbufsOutstanding++; + } + } + } + if (packet == nullptr) { OverflowBuffer* buf = overflowBufferPool.construct(); packet = packetPool.construct(buf); NOTICE("OverflowBuffer used."); @@ -399,6 +422,7 @@ DpdkDriver::Impl::receivePackets(uint32_t maxPackets, { SpinLock::Lock lock(packetLock); packet = packetPool.construct(m, payload); + mbufsOutstanding++; } packet->base.length = length; @@ -420,6 +444,7 @@ DpdkDriver::Impl::releasePackets(Driver::Packet* packets[], uint16_t numPackets) container_of(packets[i], DpdkDriver::Impl::Packet, base); if (likely(packet->bufType == DpdkDriver::Impl::Packet::MBUF)) { rte_pktmbuf_free(packet->bufRef.mbuf); + mbufsOutstanding--; } else { overflowBufferPool.destroy(packet->bufRef.overflowBuf); } @@ -708,57 +733,6 @@ DpdkDriver::Impl::_init() localMac.toString().c_str(), bandwidthMbps.load(), mtu); } -/** - * Helper function to try to allocation a new Dpdk Packet backed by an mbuf. - * - * @return - * The newly allocated Dpdk Packet; nullptr if the mbuf allocation - * failed. - */ -DpdkDriver::Impl::Packet* -DpdkDriver::Impl::_allocMbufPacket() -{ - DpdkDriver::Impl::Packet* packet = nullptr; - uint32_t numMbufsAvail = rte_mempool_avail_count(mbufPool); - if (unlikely(numMbufsAvail <= NB_MBUF_RESERVED)) { - uint32_t numMbufsInUse = rte_mempool_in_use_count(mbufPool); - NOTICE( - "Driver is running low on mbuf packet buffers; " - "%u mbufs available, %u mbufs in use", - numMbufsAvail, numMbufsInUse); - return nullptr; - } - - struct rte_mbuf* mbuf = rte_pktmbuf_alloc(mbufPool); - - if (unlikely(NULL == mbuf)) { - uint32_t numMbufsAvail = rte_mempool_avail_count(mbufPool); - uint32_t numMbufsInUse = rte_mempool_in_use_count(mbufPool); - NOTICE( - "Failed to allocate an mbuf packet buffer; " - "%u mbufs available, %u mbufs in use", - numMbufsAvail, numMbufsInUse); - return nullptr; - } - - char* buf = rte_pktmbuf_append( - mbuf, - Homa::Util::downCast(PACKET_HDR_LEN + MAX_PAYLOAD_SIZE)); - - if (unlikely(NULL == buf)) { - NOTICE("rte_pktmbuf_append call failed; dropping packet"); - rte_pktmbuf_free(mbuf); - return nullptr; - } - - // Perform packet operations with the lock held. - { - SpinLock::Lock _(packetLock); - packet = packetPool.construct(mbuf, buf + PACKET_HDR_LEN); - } - return packet; -} - /** * Called before a burst of packets is transmitted to update the transmit stats. * diff --git a/src/Drivers/DPDK/DpdkDriverImpl.h b/src/Drivers/DPDK/DpdkDriverImpl.h index 4d664fb..534f842 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.h +++ b/src/Drivers/DPDK/DpdkDriverImpl.h @@ -157,7 +157,6 @@ class DpdkDriver::Impl { private: void _eal_init(int argc, char* argv[]); void _init(); - Packet* _allocMbufPacket(); static uint16_t txBurstCallback(uint16_t port_id, uint16_t queue, struct rte_mbuf* pkts[], uint16_t nb_pkts, void* user_param); @@ -194,6 +193,9 @@ class DpdkDriver::Impl { /// Provides memory allocation for packet storage when mbuf are running out. ObjectPool overflowBufferPool; + /// The number of mbufs that have been given out to callers in Packets. + uint64_t mbufsOutstanding; + /// Holds packet buffers that are dequeued from the NIC's HW queues /// via DPDK. struct rte_mempool* mbufPool; From e5d45bd9fb4ce7287c01c64096e91fc86575790a Mon Sep 17 00:00:00 2001 From: Collin Lee Date: Fri, 4 Sep 2020 10:50:05 -0700 Subject: [PATCH 16/33] Change Sender to not PING when blocked on itself The Sender will now skip pinging a message if the message still has granted but unsent packets (e.g. it is waiting on itself). --- src/Sender.cc | 10 ++++++++++ src/SenderTest.cc | 43 +++++++++++++++++++++++++------------------ 2 files changed, 35 insertions(+), 18 deletions(-) diff --git a/src/Sender.cc b/src/Sender.cc index e9e2bfb..96ef448 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -962,6 +962,16 @@ Sender::checkPingTimeouts(uint64_t now, MessageBucket* bucket) bucket->pingTimeouts.setTimeout(&message->pingTimeout); } + // Check if sender still has packets to send + if (message->state.load() == OutMessage::Status::IN_PROGRESS) { + SpinLock::Lock lock_queue(queueMutex); + QueuedMessageInfo* info = &message->queuedMessageInfo; + if (info->packetsSent < info->packetsGranted) { + // Sender is blocked on itself, no need to send ping + continue; + } + } + // Have not heard from the Receiver in the last timeout period. Ping // the receiver to ensure it still knows about this Message. Perf::counters.tx_ping_pkts.add(1); diff --git a/src/SenderTest.cc b/src/SenderTest.cc index 89bd283..4e51e17 100644 --- a/src/SenderTest.cc +++ b/src/SenderTest.cc @@ -1643,30 +1643,35 @@ TEST_F(SenderTest, checkMessageTimeouts) TEST_F(SenderTest, checkPingTimeouts) { - Sender::Message* message[5]; + Sender::Message* message[6]; Sender::MessageBucket* bucket = sender->messageBuckets.buckets.at(0); - for (uint64_t i = 0; i < 5; ++i) { + for (uint64_t i = 0; i < 6; ++i) { message[i] = dynamic_cast(sender->allocMessage(0)); bucket->pingTimeouts.setTimeout(&message[i]->pingTimeout); } // Message[0]: Normal timeout: COMPLETED message[0]->state = Homa::OutMessage::Status::COMPLETED; - message[0]->pingTimeout.expirationCycleTime = 9997; + message[0]->pingTimeout.expirationCycleTime = 9996; // Message[1]: Normal timeout: FAILED message[1]->state = Homa::OutMessage::Status::FAILED; - message[1]->pingTimeout.expirationCycleTime = 9998; + message[1]->pingTimeout.expirationCycleTime = 9997; // Message[2]: Normal timeout: NO_KEEP_ALIVE && SENT message[2]->options = Homa::OutMessage::Options::NO_KEEP_ALIVE; message[2]->state = Homa::OutMessage::Status::SENT; - message[2]->pingTimeout.expirationCycleTime = 9999; - // Message[3]: Normal timeout: SENT - message[3]->state = Homa::OutMessage::Status::SENT; - message[3]->pingTimeout.expirationCycleTime = 10000; - // Message[4]: No timeout - message[4]->pingTimeout.expirationCycleTime = 10001; - - bucket->pingTimeouts.nextTimeout = 9997; + message[2]->pingTimeout.expirationCycleTime = 9998; + // Message[3]: Normal timeout: IN_PROGRESS + message[3]->state = Homa::OutMessage::Status::IN_PROGRESS; + message[3]->pingTimeout.expirationCycleTime = 9999; + message[3]->queuedMessageInfo.packetsSent = 1; + message[3]->queuedMessageInfo.packetsGranted = 2; + // Message[4]: Normal timeout: SENT + message[4]->state = Homa::OutMessage::Status::SENT; + message[4]->pingTimeout.expirationCycleTime = 10000; + // Message[5]: No timeout + message[5]->pingTimeout.expirationCycleTime = 10001; + + bucket->pingTimeouts.nextTimeout = 9996; EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), _, _)).Times(1); @@ -1675,7 +1680,7 @@ TEST_F(SenderTest, checkPingTimeouts) sender->checkPingTimeouts(10000, bucket); - EXPECT_EQ(message[4]->pingTimeout.expirationCycleTime, + EXPECT_EQ(message[5]->pingTimeout.expirationCycleTime, bucket->pingTimeouts.nextTimeout.load()); // Message[0]: Normal timeout: COMPLETED EXPECT_EQ(nullptr, message[0]->pingTimeout.node.list); @@ -1683,14 +1688,16 @@ TEST_F(SenderTest, checkPingTimeouts) EXPECT_EQ(nullptr, message[1]->pingTimeout.node.list); // Message[2]: Normal timeout: NO_KEEP_ALIVE && SENT EXPECT_EQ(nullptr, message[2]->pingTimeout.node.list); - // Message[3]: Normal timeout: SENT - EXPECT_EQ(10100, message[3]->pingTimeout.expirationCycleTime); + // Message[3]: Normal timeout: IN_PROGRESS + EXPECT_EQ(10100, message[4]->pingTimeout.expirationCycleTime); + // Message[4]: Normal timeout: SENT + EXPECT_EQ(10100, message[4]->pingTimeout.expirationCycleTime); Protocol::Packet::CommonHeader* header = static_cast(mockPacket.payload); EXPECT_EQ(Protocol::Packet::PING, header->opcode); - EXPECT_EQ(message[3]->id, header->messageId); - // Message[4]: No timeout - EXPECT_EQ(10001, message[4]->pingTimeout.expirationCycleTime); + EXPECT_EQ(message[4]->id, header->messageId); + // Message[5]: No timeout + EXPECT_EQ(10001, message[5]->pingTimeout.expirationCycleTime); } TEST_F(SenderTest, checkTimeouts) From c78f1d967e296063b36267046981e0179383c3d2 Mon Sep 17 00:00:00 2001 From: Collin Lee Date: Fri, 4 Sep 2020 12:09:42 -0700 Subject: [PATCH 17/33] Drop message from sendQueue on Message timeout --- src/Sender.cc | 10 ++++++++++ src/SenderTest.cc | 2 ++ 2 files changed, 12 insertions(+) diff --git a/src/Sender.cc b/src/Sender.cc index 96ef448..0ca7d1a 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -912,6 +912,16 @@ Sender::checkMessageTimeouts(uint64_t now, MessageBucket* bucket) } // Found expired timeout. if (message->state != OutMessage::Status::COMPLETED) { + if (message->state == OutMessage::Status::IN_PROGRESS) { + // Check to see if the message needs to be dequeued. + SpinLock::Lock lock_queue(queueMutex); + // Recheck state with lock in case it change right before this. + if (message->state == OutMessage::Status::IN_PROGRESS) { + QueuedMessageInfo* info = &message->queuedMessageInfo; + assert(sendQueue.contains(&info->sendQueueNode)); + sendQueue.remove(&info->sendQueueNode); + } + } message->state.store(OutMessage::Status::FAILED); } bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); diff --git a/src/SenderTest.cc b/src/SenderTest.cc index 4e51e17..bfc3ecb 100644 --- a/src/SenderTest.cc +++ b/src/SenderTest.cc @@ -1606,6 +1606,7 @@ TEST_F(SenderTest, checkMessageTimeouts) // Message[0]: Normal timeout: IN_PROGRESS message[0]->messageTimeout.expirationCycleTime = 9998; message[0]->state = Homa::OutMessage::Status::IN_PROGRESS; + sender->sendQueue.push_front(&message[0]->queuedMessageInfo.sendQueueNode); // Message[1]: Normal timeout: SENT message[1]->messageTimeout.expirationCycleTime = 9999; message[1]->state = Homa::OutMessage::Status::SENT; @@ -1626,6 +1627,7 @@ TEST_F(SenderTest, checkMessageTimeouts) EXPECT_EQ(nullptr, message[0]->messageTimeout.node.list); EXPECT_EQ(nullptr, message[0]->pingTimeout.node.list); EXPECT_EQ(Homa::OutMessage::Status::FAILED, message[0]->getStatus()); + EXPECT_TRUE(sender->sendQueue.empty()); // Message[1]: Normal timeout: SENT EXPECT_EQ(nullptr, message[1]->messageTimeout.node.list); EXPECT_EQ(nullptr, message[1]->pingTimeout.node.list); From 610b27ace01555401205bc625d3f34461c338bf3 Mon Sep 17 00:00:00 2001 From: Collin Lee Date: Mon, 7 Sep 2020 15:10:26 -0700 Subject: [PATCH 18/33] Add Message statistics --- include/Homa/Perf.h | 21 +++++++++++++++++++++ src/Perf.h | 42 ++++++++++++++++++++++++++++++++++++++++++ src/Receiver.cc | 4 ++++ src/Sender.cc | 4 ++++ 4 files changed, 71 insertions(+) diff --git a/include/Homa/Perf.h b/include/Homa/Perf.h index a163539..f213acd 100644 --- a/include/Homa/Perf.h +++ b/include/Homa/Perf.h @@ -38,6 +38,27 @@ struct Stats { /// CPU time spent running Homa with no work to do in cycles. uint64_t idle_cycles; + /// Number of InMessages that have been allocated by the Transport. + uint64_t allocated_rx_messages; + + /// Number of InMessages that have been received by the Transport. + uint64_t received_rx_messages; + + /// Number of InMessages delivered to the application. + uint64_t delivered_rx_messages; + + /// Number of InMessages released back to the Transport for destruction. + uint64_t destroyed_rx_messages; + + /// Number of OutMessages allocated for the application. + uint64_t allocated_tx_messages; + + /// Number of OutMessages released back to the transport. + uint64_t released_tx_messages; + + /// Number of OutMessages destroyed. + uint64_t destroyed_tx_messages; + /// Number of bytes sent by the transport. uint64_t tx_bytes; diff --git a/src/Perf.h b/src/Perf.h index b448edc..3b89438 100644 --- a/src/Perf.h +++ b/src/Perf.h @@ -80,6 +80,13 @@ struct Counters { Counters() : total_cycles(0) , active_cycles(0) + , allocated_rx_messages(0) + , received_rx_messages(0) + , delivered_rx_messages(0) + , destroyed_rx_messages(0) + , allocated_tx_messages(0) + , released_tx_messages(0) + , destroyed_tx_messages(0) , tx_bytes(0) , rx_bytes(0) , tx_data_pkts(0) @@ -112,6 +119,13 @@ struct Counters { { total_cycles.add(other->total_cycles); active_cycles.add(other->active_cycles); + allocated_rx_messages.add(other->allocated_rx_messages); + received_rx_messages.add(other->received_rx_messages); + delivered_rx_messages.add(other->delivered_rx_messages); + destroyed_rx_messages.add(other->destroyed_rx_messages); + allocated_tx_messages.add(other->allocated_tx_messages); + released_tx_messages.add(other->released_tx_messages); + destroyed_tx_messages.add(other->destroyed_tx_messages); tx_bytes.add(other->tx_bytes); rx_bytes.add(other->rx_bytes); tx_data_pkts.add(other->tx_data_pkts); @@ -139,6 +153,13 @@ struct Counters { { stats->active_cycles = active_cycles.get(); stats->idle_cycles = total_cycles.get() - active_cycles.get(); + stats->allocated_rx_messages = allocated_rx_messages.get(); + stats->received_rx_messages = received_rx_messages.get(); + stats->delivered_rx_messages = delivered_rx_messages.get(); + stats->destroyed_rx_messages = destroyed_rx_messages.get(); + stats->allocated_tx_messages = allocated_tx_messages.get(); + stats->released_tx_messages = released_tx_messages.get(); + stats->destroyed_tx_messages = destroyed_tx_messages.get(); stats->tx_bytes = tx_bytes.get(); stats->rx_bytes = rx_bytes.get(); stats->tx_data_pkts = tx_data_pkts.get(); @@ -165,6 +186,27 @@ struct Counters { /// CPU time spent actively processing Homa messages in cycles. Stat active_cycles; + /// Number of InMessages that have been allocated by the Transport. + Stat allocated_rx_messages; + + /// Number of InMessages that have been received by the Transport. + Stat received_rx_messages; + + /// Number of InMessages delivered to the application. + Stat delivered_rx_messages; + + /// Number of InMessages released back to the Transport for destruction. + Stat destroyed_rx_messages; + + /// Number of OutMessages allocated for the application. + Stat allocated_tx_messages; + + /// Number of OutMessages released back to the transport. + Stat released_tx_messages; + + /// Number of OutMessages destroyed. + Stat destroyed_tx_messages; + /// Number of bytes sent by the transport. Stat tx_bytes; diff --git a/src/Receiver.cc b/src/Receiver.cc index e2a8003..8869595 100644 --- a/src/Receiver.cc +++ b/src/Receiver.cc @@ -107,6 +107,7 @@ Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) message = messageAllocator.pool.construct( this, driver, dataHeaderLength, messageLength, id, srcAddress, numUnscheduledPackets); + Perf::counters.allocated_rx_messages.add(1); } bucket->messages.push_back(&message->bucketNode); @@ -159,6 +160,7 @@ Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) bucket->resendTimeouts.cancelTimeout(&message->resendTimeout); SpinLock::Lock lock_received_messages(receivedMessages.mutex); receivedMessages.queue.push_back(&message->receivedMessageNode); + Perf::counters.received_rx_messages.add(1); } } else { // must be a duplicate packet; drop packet. @@ -267,6 +269,7 @@ Receiver::receiveMessage() if (!receivedMessages.queue.empty()) { message = &receivedMessages.queue.front(); receivedMessages.queue.pop_front(); + Perf::counters.delivered_rx_messages.add(1); } return message; } @@ -497,6 +500,7 @@ Receiver::dropMessage(Receiver::Message* message) { SpinLock::Lock lock_allocator(messageAllocator.mutex); messageAllocator.pool.destroy(message); + Perf::counters.destroyed_rx_messages.add(1); } } } diff --git a/src/Sender.cc b/src/Sender.cc index 0ca7d1a..a589a42 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -71,6 +71,7 @@ Homa::OutMessage* Sender::allocMessage(uint16_t sourcePort) { SpinLock::Lock lock_allocator(messageAllocator.mutex); + Perf::counters.allocated_tx_messages.add(1); return messageAllocator.pool.construct(this, sourcePort); } @@ -868,6 +869,7 @@ Sender::dropMessage(Sender::Message* message) MessageBucket* bucket = messageBuckets.getBucket(msgId); SpinLock::Lock lock(bucket->mutex); message->held = false; + Perf::counters.released_tx_messages.add(1); if (message->state != OutMessage::Status::IN_PROGRESS) { // Ok to delete immediately since we don't have to wait for the message // to be sent. @@ -876,6 +878,7 @@ Sender::dropMessage(Sender::Message* message) bucket->messages.remove(&message->bucketNode); SpinLock::Lock lock_allocator(messageAllocator.mutex); messageAllocator.pool.destroy(message); + Perf::counters.destroyed_tx_messages.add(1); } else { // Defer deletion and wait for the message to be SENT. } @@ -1108,6 +1111,7 @@ Sender::trySend() bucket->messages.remove(&message->bucketNode); SpinLock::Lock lock_allocator(messageAllocator.mutex); messageAllocator.pool.destroy(message); + Perf::counters.destroyed_tx_messages.add(1); } else if (message->options & OutMessage::Options::NO_KEEP_ALIVE) { // No timeouts need to be checked after sending the message when // the NO_KEEP_ALIVE option is enabled. From 3c6b2e198eedaa73270fd7c822bcd8cfc2aa5100 Mon Sep 17 00:00:00 2001 From: Collin Lee Date: Mon, 7 Sep 2020 16:16:43 -0700 Subject: [PATCH 19/33] DpdkDriver bug fixes and improvements --- src/Drivers/DPDK/DpdkDriverImpl.cc | 50 +++++++++++------------------- src/Drivers/DPDK/DpdkDriverImpl.h | 11 ++++--- 2 files changed, 24 insertions(+), 37 deletions(-) diff --git a/src/Drivers/DPDK/DpdkDriverImpl.cc b/src/Drivers/DPDK/DpdkDriverImpl.cc index 085b288..2225bde 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.cc +++ b/src/Drivers/DPDK/DpdkDriverImpl.cc @@ -187,8 +187,8 @@ DpdkDriver::Impl::allocPacket() uint32_t numMbufsInUse = rte_mempool_in_use_count(mbufPool); NOTICE( "Failed to allocate an mbuf packet buffer; " - "%u mbufs available, %u mbufs in use", - numMbufsAvail, numMbufsInUse); + "%u mbufs available, %u mbufs in use, %u mbufs held by app", + numMbufsAvail, numMbufsInUse, mbufsOutstanding); } else { char* buf = rte_pktmbuf_append( mbuf, Homa::Util::downCast(PACKET_HDR_LEN + @@ -298,10 +298,14 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet, IpAddress destination, struct rte_mbuf* mbuf_clone = rte_pktmbuf_clone(mbuf, mbufPool); if (unlikely(mbuf_clone == NULL)) { WARNING("Failed to clone packet for loopback; dropping packet"); + return; } int ret = rte_ring_enqueue(loopbackRing, mbuf_clone); if (unlikely(ret != 0)) { - WARNING("rte_ring_enqueue returned %d; packet may be lost?", ret); + WARNING( + "rte_ring_enqueue returned %d with %u packets queued; " + "packet may be lost?", + ret, rte_ring_count(loopbackRing)); rte_pktmbuf_free(mbuf_clone); } return; @@ -421,8 +425,15 @@ DpdkDriver::Impl::receivePackets(uint32_t maxPackets, DpdkDriver::Impl::Packet* packet = nullptr; { SpinLock::Lock lock(packetLock); - packet = packetPool.construct(m, payload); - mbufsOutstanding++; + static const int MBUF_ALLOC_LIMIT = NB_MBUF - NB_MBUF_RESERVED; + if (mbufsOutstanding < MBUF_ALLOC_LIMIT) { + packet = packetPool.construct(m, payload); + mbufsOutstanding++; + } else { + OverflowBuffer* buf = overflowBufferPool.construct(); + rte_memcpy(payload, buf->data, length); + packet = packetPool.construct(buf); + } } packet->base.length = length; @@ -656,14 +667,6 @@ DpdkDriver::Impl::_init() "Cannot allocate buffer for tx on port %u", port)); } rte_eth_tx_buffer_init(tx.buffer, MAX_PKT_BURST); - ret = rte_eth_tx_buffer_set_err_callback(tx.buffer, txBurstErrorCallback, - &tx.stats); - if (ret < 0) { - throw DriverInitFailure( - HERE_STR, - StringUtil::format( - "Cannot set error callback for tx buffer on port %u", port)); - } // get the current MTU. ret = rte_eth_dev_get_mtu(port, &mtu); @@ -722,7 +725,8 @@ DpdkDriver::Impl::_init() // create an in-memory ring, used as a software loopback in order to // handle packets that are addressed to the localhost. - loopbackRing = rte_ring_create(ringName.c_str(), 4096, SOCKET_ID_ANY, 0); + loopbackRing = + rte_ring_create(ringName.c_str(), NB_LOOPBACK_SLOTS, SOCKET_ID_ANY, 0); if (NULL == loopbackRing) { throw DriverInitFailure( HERE_STR, StringUtil::format("Failed to allocate loopback ring: %s", @@ -772,24 +776,6 @@ DpdkDriver::Impl::txBurstCallback(uint16_t port_id, uint16_t queue, return nb_pkts; } -/** - * Called to process the packets cannot be sent. - */ -void -DpdkDriver::Impl::txBurstErrorCallback(struct rte_mbuf* pkts[], uint16_t unsent, - void* userdata) -{ - uint64_t bytesDropped = 0; - for (int i = 0; i < unsent; ++i) { - bytesDropped += rte_pktmbuf_pkt_len(pkts[i]); - rte_pktmbuf_free(pkts[i]); - } - Tx::Stats* stats = static_cast(userdata); - SpinLock::Lock lock(stats->mutex); - assert(bytesDropped <= stats->bufferedBytes); - stats->bufferedBytes -= bytesDropped; -} - } // namespace DPDK } // namespace Drivers } // namespace Homa diff --git a/src/Drivers/DPDK/DpdkDriverImpl.h b/src/Drivers/DPDK/DpdkDriverImpl.h index 534f842..7305ce8 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.h +++ b/src/Drivers/DPDK/DpdkDriverImpl.h @@ -45,7 +45,7 @@ const int NDESC = 256; // Maximum number of packet buffers that the memory pool can hold. The // documentation of `rte_mempool_create` suggests that the optimum value // (in terms of memory usage) of this number is a power of two minus one. -const int NB_MBUF = 8191; +const int NB_MBUF = 16383; // If cache_size is non-zero, the rte_mempool library will try to limit the // accesses to the common lockless pool, by maintaining a per-lcore object @@ -57,7 +57,10 @@ const int MEMPOOL_CACHE_SIZE = 32; // The number of mbufs the driver should try to reserve for receiving packets. // Prevents applications from claiming more mbufs once the number of available // mbufs reaches this level. -const uint32_t NB_MBUF_RESERVED = 1024; +const uint32_t NB_MBUF_RESERVED = 4096; + +// The number of packets that can be held in loopback before they get dropped +const uint32_t NB_LOOPBACK_SLOTS = 4096; // The number of packets that the driver can buffer while corked. const uint16_t MAX_PKT_BURST = 32; @@ -160,8 +163,6 @@ class DpdkDriver::Impl { static uint16_t txBurstCallback(uint16_t port_id, uint16_t queue, struct rte_mbuf* pkts[], uint16_t nb_pkts, void* user_param); - static void txBurstErrorCallback(struct rte_mbuf* pkts[], uint16_t unsent, - void* userdata); /// Name of the Linux network interface to be used by DPDK. std::string ifname; @@ -194,7 +195,7 @@ class DpdkDriver::Impl { ObjectPool overflowBufferPool; /// The number of mbufs that have been given out to callers in Packets. - uint64_t mbufsOutstanding; + uint32_t mbufsOutstanding; /// Holds packet buffers that are dequeued from the NIC's HW queues /// via DPDK. From 429db01eeef47ac742807cd54416a6465ce32fbb Mon Sep 17 00:00:00 2001 From: Collin Lee Date: Sun, 13 Sep 2020 19:40:21 -0700 Subject: [PATCH 20/33] Prevent DpdkDriver loopback starvation --- src/Drivers/DPDK/DpdkDriverImpl.cc | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/Drivers/DPDK/DpdkDriverImpl.cc b/src/Drivers/DPDK/DpdkDriverImpl.cc index 2225bde..4fb8df3 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.cc +++ b/src/Drivers/DPDK/DpdkDriverImpl.cc @@ -360,26 +360,26 @@ DpdkDriver::Impl::receivePackets(uint32_t maxPackets, if (maxPackets > MAX_PACKETS_AT_ONCE) { maxPackets = MAX_PACKETS_AT_ONCE; } + uint32_t maxLoopbackPkts = maxPackets / 2; + struct rte_mbuf* mPkts[MAX_PACKETS_AT_ONCE]; // attempt to dequeue a batch of received packets from the NIC // as well as from the loopback ring. - uint32_t incomingPkts = 0; uint32_t loopbackPkts = 0; + uint32_t incomingPkts = 0; { SpinLock::Lock lock(rx.mutex); - incomingPkts = rte_eth_rx_burst( - port, 0, mPkts, Homa::Util::downCast(maxPackets)); - loopbackPkts = rte_ring_count(loopbackRing); - if (incomingPkts + loopbackPkts > maxPackets) { - loopbackPkts = maxPackets - incomingPkts; - } + loopbackPkts = std::min(loopbackPkts, maxLoopbackPkts); for (uint32_t i = 0; i < loopbackPkts; i++) { - rte_ring_dequeue(loopbackRing, reinterpret_cast( - &mPkts[incomingPkts + i])); + rte_ring_dequeue(loopbackRing, reinterpret_cast(&mPkts[i])); } + + incomingPkts = rte_eth_rx_burst( + port, 0, &(mPkts[loopbackPkts]), + Homa::Util::downCast(maxPackets - loopbackPkts)); } uint32_t totalPkts = incomingPkts + loopbackPkts; From 55c3e34422b909091ad27006c22fde87521e4e34 Mon Sep 17 00:00:00 2001 From: Collin Lee Date: Tue, 6 Oct 2020 16:55:34 -0700 Subject: [PATCH 21/33] Bump version to 0.2.0.0 --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 99b16b2..9a2e935 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required (VERSION 3.11) -project(Homa VERSION 0.1.3.0 LANGUAGES CXX) +project(Homa VERSION 0.2.0.0 LANGUAGES CXX) ################################################################################ ## Dependency Configuration #################################################### From e143c27df065641fc615282d81e491849dde7595 Mon Sep 17 00:00:00 2001 From: Collin Lee Date: Wed, 7 Oct 2020 10:36:58 -0700 Subject: [PATCH 22/33] Add missing DpdkDriverImpl member init --- src/Drivers/DPDK/DpdkDriverImpl.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/Drivers/DPDK/DpdkDriverImpl.cc b/src/Drivers/DPDK/DpdkDriverImpl.cc index 4fb8df3..b4372ad 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.cc +++ b/src/Drivers/DPDK/DpdkDriverImpl.cc @@ -17,15 +17,15 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ +#include "DpdkDriverImpl.h" + #include #include +#include #include #include -#include -#include "DpdkDriverImpl.h" - -#include +#include #include "CodeLocation.h" #include "Homa/Util.h" @@ -97,6 +97,7 @@ DpdkDriver::Impl::Impl(const char* ifname, int argc, char* argv[], , packetLock() , packetPool() , overflowBufferPool() + , mbufsOutstanding(0) , mbufPool(nullptr) , loopbackRing(nullptr) , rx() From b617c474447093c7c6c120e40c744e488cb5c248 Mon Sep 17 00:00:00 2001 From: Collin Lee Date: Wed, 7 Oct 2020 10:52:23 -0700 Subject: [PATCH 23/33] Expose Receiver::checkTimeouts() as public method Will be used in future integration with non-polling interface. --- src/Mock/MockReceiver.h | 1 + src/Receiver.cc | 32 ++++++++++++++++---------------- src/Receiver.h | 2 +- src/ReceiverTest.cc | 22 +++++++++++----------- 4 files changed, 29 insertions(+), 28 deletions(-) diff --git a/src/Mock/MockReceiver.h b/src/Mock/MockReceiver.h index fa9ceba..0646a94 100644 --- a/src/Mock/MockReceiver.h +++ b/src/Mock/MockReceiver.h @@ -43,6 +43,7 @@ class MockReceiver : public Core::Receiver { (Driver::Packet * packet, IpAddress sourceIp), (override)); MOCK_METHOD(Homa::InMessage*, receiveMessage, (), (override)); MOCK_METHOD(void, poll, (), (override)); + MOCK_METHOD(void, checkTimeouts, (), (override)); }; } // namespace Mock diff --git a/src/Receiver.cc b/src/Receiver.cc index 8869595..2fa98e6 100644 --- a/src/Receiver.cc +++ b/src/Receiver.cc @@ -286,6 +286,22 @@ Receiver::poll() checkTimeouts(); } +/** + * Make incremental progress processing expired Receiver timeouts. + * + * Pulled out of poll() for ease of testing. + */ +void +Receiver::checkTimeouts() +{ + uint index = nextBucketIndex.fetch_add(1, std::memory_order_relaxed) & + MessageBucketMap::HASH_KEY_MASK; + MessageBucket* bucket = messageBuckets.buckets.at(index); + uint64_t now = PerfUtils::Cycles::rdtsc(); + checkResendTimeouts(now, bucket); + checkMessageTimeouts(now, bucket); +} + /** * Destruct a Message. Will release all contained Packet objects. */ @@ -676,22 +692,6 @@ Receiver::checkResendTimeouts(uint64_t now, MessageBucket* bucket) } } -/** - * Process any Receiver timeouts that have expired. - * - * Pulled out of poll() for ease of testing. - */ -void -Receiver::checkTimeouts() -{ - uint index = nextBucketIndex.fetch_add(1, std::memory_order_relaxed) & - MessageBucketMap::HASH_KEY_MASK; - MessageBucket* bucket = messageBuckets.buckets.at(index); - uint64_t now = PerfUtils::Cycles::rdtsc(); - checkResendTimeouts(now, bucket); - checkMessageTimeouts(now, bucket); -} - /** * Send GRANTs to incoming Message according to the Receiver's policy. */ diff --git a/src/Receiver.h b/src/Receiver.h index bcca5a8..5833177 100644 --- a/src/Receiver.h +++ b/src/Receiver.h @@ -53,6 +53,7 @@ class Receiver { virtual void handlePingPacket(Driver::Packet* packet, IpAddress sourceIp); virtual Homa::InMessage* receiveMessage(); virtual void poll(); + virtual void checkTimeouts(); private: // Forward declaration @@ -456,7 +457,6 @@ class Receiver { void dropMessage(Receiver::Message* message); void checkMessageTimeouts(uint64_t now, MessageBucket* bucket); void checkResendTimeouts(uint64_t now, MessageBucket* bucket); - void checkTimeouts(); void trySendGrants(); void schedule(Message* message, const SpinLock::Lock& lock); void unschedule(Message* message, const SpinLock::Lock& lock); diff --git a/src/ReceiverTest.cc b/src/ReceiverTest.cc index 9cd01ce..fdf0ef5 100644 --- a/src/ReceiverTest.cc +++ b/src/ReceiverTest.cc @@ -353,6 +353,17 @@ TEST_F(ReceiverTest, poll) receiver->poll(); } +TEST_F(ReceiverTest, checkTimeouts) +{ + Receiver::MessageBucket* bucket = receiver->messageBuckets.buckets.at(0); + + EXPECT_EQ(0, receiver->nextBucketIndex.load()); + + receiver->checkTimeouts(); + + EXPECT_EQ(1, receiver->nextBucketIndex.load()); +} + TEST_F(ReceiverTest, Message_destructor_basic) { Protocol::MessageId id = {42, 32}; @@ -829,17 +840,6 @@ TEST_F(ReceiverTest, checkResendTimeouts) EXPECT_EQ(10001, message[2]->resendTimeout.expirationCycleTime); } -TEST_F(ReceiverTest, checkTimeouts) -{ - Receiver::MessageBucket* bucket = receiver->messageBuckets.buckets.at(0); - - EXPECT_EQ(0, receiver->nextBucketIndex.load()); - - receiver->checkTimeouts(); - - EXPECT_EQ(1, receiver->nextBucketIndex.load()); -} - TEST_F(ReceiverTest, trySendGrants) { Receiver::Message* message[4]; From 40345afb299d3a2698adc5aba02f2090edcdaa14 Mon Sep 17 00:00:00 2001 From: Collin Lee Date: Wed, 7 Oct 2020 10:57:48 -0700 Subject: [PATCH 24/33] Expose Sender::checkTimeouts() as public method. Will be used by a future non-polling interface. --- src/Mock/MockSender.h | 1 + src/Sender.cc | 32 ++++++++++++++++---------------- src/Sender.h | 2 +- src/SenderTest.cc | 22 +++++++++++----------- 4 files changed, 29 insertions(+), 28 deletions(-) diff --git a/src/Mock/MockSender.h b/src/Mock/MockSender.h index 3f05128..2cf4234 100644 --- a/src/Mock/MockSender.h +++ b/src/Mock/MockSender.h @@ -46,6 +46,7 @@ class MockSender : public Core::Sender { (override)); MOCK_METHOD(void, handleErrorPacket, (Driver::Packet * packet), (override)); MOCK_METHOD(void, poll, (), (override)); + MOCK_METHOD(void, checkTimeouts, (), (override)); }; } // namespace Mock diff --git a/src/Sender.cc b/src/Sender.cc index a589a42..91be2b1 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -517,6 +517,22 @@ Sender::poll() checkTimeouts(); } +/** + * Make incremental progress processing expired Sender timeouts. + * + * Pulled out of poll() for ease of testing. + */ +void +Sender::checkTimeouts() +{ + uint index = nextBucketIndex.fetch_add(1, std::memory_order_relaxed) & + MessageBucketMap::HASH_KEY_MASK; + MessageBucket* bucket = messageBuckets.buckets.at(index); + uint64_t now = PerfUtils::Cycles::rdtsc(); + checkPingTimeouts(now, bucket); + checkMessageTimeouts(now, bucket); +} + /** * Destruct a Message. Will release all contained Packet objects. */ @@ -993,22 +1009,6 @@ Sender::checkPingTimeouts(uint64_t now, MessageBucket* bucket) } } -/** - * Process any Sender timeouts that have expired. - * - * Pulled out of poll() for ease of testing. - */ -void -Sender::checkTimeouts() -{ - uint index = nextBucketIndex.fetch_add(1, std::memory_order_relaxed) & - MessageBucketMap::HASH_KEY_MASK; - MessageBucket* bucket = messageBuckets.buckets.at(index); - uint64_t now = PerfUtils::Cycles::rdtsc(); - checkPingTimeouts(now, bucket); - checkMessageTimeouts(now, bucket); -} - /** * Send out packets for any messages with unscheduled/granted bytes. * diff --git a/src/Sender.h b/src/Sender.h index 50157fc..e498f0d 100644 --- a/src/Sender.h +++ b/src/Sender.h @@ -54,6 +54,7 @@ class Sender { virtual void handleUnknownPacket(Driver::Packet* packet); virtual void handleErrorPacket(Driver::Packet* packet); virtual void poll(); + virtual void checkTimeouts(); private: /// Forward declarations @@ -403,7 +404,6 @@ class Sender { void dropMessage(Sender::Message* message); void checkMessageTimeouts(uint64_t now, MessageBucket* bucket); void checkPingTimeouts(uint64_t now, MessageBucket* bucket); - void checkTimeouts(); void trySend(); /// Transport identifier. diff --git a/src/SenderTest.cc b/src/SenderTest.cc index bfc3ecb..07630a8 100644 --- a/src/SenderTest.cc +++ b/src/SenderTest.cc @@ -1087,6 +1087,17 @@ TEST_F(SenderTest, poll) sender->poll(); } +TEST_F(SenderTest, checkTimeouts) +{ + Sender::MessageBucket* bucket = sender->messageBuckets.buckets.at(0); + + EXPECT_EQ(0, sender->nextBucketIndex.load()); + + sender->checkTimeouts(); + + EXPECT_EQ(1, sender->nextBucketIndex.load()); +} + TEST_F(SenderTest, Message_destructor) { const int MAX_RAW_PACKET_LENGTH = 2000; @@ -1702,17 +1713,6 @@ TEST_F(SenderTest, checkPingTimeouts) EXPECT_EQ(10001, message[5]->pingTimeout.expirationCycleTime); } -TEST_F(SenderTest, checkTimeouts) -{ - Sender::MessageBucket* bucket = sender->messageBuckets.buckets.at(0); - - EXPECT_EQ(0, sender->nextBucketIndex.load()); - - sender->checkTimeouts(); - - EXPECT_EQ(1, sender->nextBucketIndex.load()); -} - TEST_F(SenderTest, trySend_basic) { Protocol::MessageId id = {42, 10}; From 062fa25654939e86b887ec445175535aadef53e9 Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Tue, 13 Oct 2020 21:09:21 -0700 Subject: [PATCH 25/33] Minor cleanups in Sender.h --- src/Sender.cc | 2 +- src/Sender.h | 76 ++++++++++++++++++--------------------------------- 2 files changed, 28 insertions(+), 50 deletions(-) diff --git a/src/Sender.cc b/src/Sender.cc index 91be2b1..f9f707b 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -527,7 +527,7 @@ Sender::checkTimeouts() { uint index = nextBucketIndex.fetch_add(1, std::memory_order_relaxed) & MessageBucketMap::HASH_KEY_MASK; - MessageBucket* bucket = messageBuckets.buckets.at(index); + MessageBucket* bucket = &messageBuckets.buckets[index]; uint64_t now = PerfUtils::Cycles::rdtsc(); checkPingTimeouts(now, bucket); checkMessageTimeouts(now, bucket); diff --git a/src/Sender.h b/src/Sender.h index e498f0d..8f4282e 100644 --- a/src/Sender.h +++ b/src/Sender.h @@ -157,15 +157,15 @@ class Sender { {} virtual ~Message(); - virtual void append(const void* source, size_t count); - virtual void cancel(); - virtual Status getStatus() const; - virtual size_t length() const; - virtual void prepend(const void* source, size_t count); - virtual void release(); - virtual void reserve(size_t count); - virtual void send(SocketAddress destination, - Options options = Options::NONE); + void append(const void* source, size_t count) override; + void cancel() override; + Status getStatus() const override; + size_t length() const override; + void prepend(const void* source, size_t count) override; + void release() override; + void reserve(size_t count) override; + void send(SocketAddress destination, + Options options = Options::NONE) override; private: /// Define the maximum number of packets that a message can hold. @@ -289,14 +289,12 @@ class Sender { const SpinLock::Lock& lock) { (void)lock; - Message* message = nullptr; - for (auto it = messages.begin(); it != messages.end(); ++it) { - if (it->id == msgId) { - message = &(*it); - break; + for (auto& it : messages) { + if (it.id == msgId) { + return ⁢ } } - return message; + return nullptr; } /// Mutex protecting the contents of this bucket. @@ -308,7 +306,7 @@ class Sender { /// Maintains Message objects in increasing order of timeout. TimeoutManager messageTimeouts; - /// Maintains Message object in increase order of ping timeout. + /// Maintains Message objects in increasing order of ping timeout. TimeoutManager pingTimeouts; }; @@ -334,27 +332,6 @@ class Sender { // Make sure bit mask correctly matches the number of buckets. static_assert(NUM_BUCKETS == HASH_KEY_MASK + 1); - /** - * Helper method to create the set of buckets. - * - * @param messageTimeoutCycles - * Number of cycles of inactivity to wait before a Message is - * considered failed. - * @param pingIntervalCycles - * Number of cycles of inactivity to wait between checking on the - * liveness of a Message. - */ - static std::array makeBuckets( - uint64_t messageTimeoutCycles, uint64_t pingIntervalCycles) - { - std::array buckets; - for (int i = 0; i < NUM_BUCKETS; ++i) { - buckets[i] = - new MessageBucket(messageTimeoutCycles, pingIntervalCycles); - } - return buckets; - } - /** * MessageBucketMap constructor. * @@ -367,32 +344,33 @@ class Sender { */ explicit MessageBucketMap(uint64_t messageTimeoutCycles, uint64_t pingIntervalCycles) - : buckets(makeBuckets(messageTimeoutCycles, pingIntervalCycles)) + : buckets() , hasher() - {} - - /** - * MessageBucketMap destructor. - */ - ~MessageBucketMap() { + buckets.reserve(NUM_BUCKETS); for (int i = 0; i < NUM_BUCKETS; ++i) { - delete buckets[i]; + buckets.emplace_back(messageTimeoutCycles, pingIntervalCycles); } } + /** + * MessageBucketMap destructor. + */ + ~MessageBucketMap() = default; + /** * Return the MessageBucket that should hold a Message with the given * MessageId. */ - MessageBucket* getBucket(const Protocol::MessageId& msgId) const + MessageBucket* getBucket(const Protocol::MessageId& msgId) { uint index = hasher(msgId) & HASH_KEY_MASK; - return buckets[index]; + return &buckets[index]; } - /// Array of buckets. - std::array const buckets; + /// Array of NUM_BUCKETS buckets. Defined as a vector to avoid the need + /// for a default constructor in MessageBucket. + std::vector buckets; /// MessageId hash function container. Protocol::MessageId::Hasher hasher; From e788958e583d02fb9c4aeb5301e11e4105a2be40 Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Wed, 14 Oct 2020 03:02:48 -0700 Subject: [PATCH 26/33] wip: complete cleanups for Sender --- src/Receiver.cc | 12 +- src/Receiver.h | 2 +- src/Sender.cc | 407 +++++++++++++++++++++++++------------------ src/Sender.h | 42 +++-- src/SenderTest.cc | 1 + src/TransportImpl.cc | 25 ++- src/TransportImpl.h | 2 +- 7 files changed, 305 insertions(+), 186 deletions(-) diff --git a/src/Receiver.cc b/src/Receiver.cc index 2fa98e6..b2208d5 100644 --- a/src/Receiver.cc +++ b/src/Receiver.cc @@ -84,8 +84,12 @@ Receiver::~Receiver() * The incoming packet to be processed. * @param sourceIp * Source IP address of the packet. + * @return + * True if the Receiver decides to take ownership of the packet. False + * if the Receiver has no more use of this packet and it can be released + * to the driver. */ -void +bool Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) { Protocol::Packet::DataHeader* header = @@ -164,9 +168,9 @@ Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) } } else { // must be a duplicate packet; drop packet. - driver->releasePackets(&packet, 1); + return false; } - return; + return true; } /** @@ -193,7 +197,6 @@ Receiver::handleBusyPacket(Driver::Packet* packet) bucket->resendTimeouts.setTimeout(&message->resendTimeout); } } - driver->releasePackets(&packet, 1); } /** @@ -247,7 +250,6 @@ Receiver::handlePingPacket(Driver::Packet* packet, IpAddress sourceIp) ControlPacket::send(driver, sourceIp, id); } - driver->releasePackets(&packet, 1); } /** diff --git a/src/Receiver.h b/src/Receiver.h index 5833177..e3eac22 100644 --- a/src/Receiver.h +++ b/src/Receiver.h @@ -48,7 +48,7 @@ class Receiver { uint64_t messageTimeoutCycles, uint64_t resendIntervalCycles); virtual ~Receiver(); - virtual void handleDataPacket(Driver::Packet* packet, IpAddress sourceIp); + virtual bool handleDataPacket(Driver::Packet* packet, IpAddress sourceIp); virtual void handleBusyPacket(Driver::Packet* packet); virtual void handlePingPacket(Driver::Packet* packet, IpAddress sourceIp); virtual Homa::InMessage* receiveMessage(); diff --git a/src/Sender.cc b/src/Sender.cc index f9f707b..1fd81a7 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -76,36 +76,54 @@ Sender::allocMessage(uint16_t sourcePort) } /** - * Process an incoming DONE packet. + * Execute the common processing logic that is shared among all incoming control + * packets. * * @param packet - * Incoming DONE packet to be processed. + * Incoming control packet to be processed. + * @param resetTimeout + * True if we should update the timeouts in response to the packet. + * @return + * Pointer to the message targeted by the incoming packet, or nullptr if no + * matching message can be found. */ -void -Sender::handleDonePacket(Driver::Packet* packet) +Sender::Message* +Sender::handleIncomingPacket(Driver::Packet* packet, bool resetTimeout) { - Protocol::Packet::DoneHeader* header = - static_cast(packet->payload); - Protocol::MessageId msgId = header->common.messageId; - + Protocol::Packet::CommonHeader* commonHeader = + static_cast(packet->payload); + Protocol::MessageId msgId = commonHeader->messageId; MessageBucket* bucket = messageBuckets.getBucket(msgId); SpinLock::Lock lock(bucket->mutex); Message* message = bucket->findMessage(msgId, lock); + if (resetTimeout) { + message->resetTimeout(lock); + } + return message; +} +/** + * Process an incoming DONE packet. + * + * @param packet + * Incoming DONE packet to be processed. + */ +void +Sender::handleDonePacket(Driver::Packet* packet) +{ + Message* message = handleIncomingPacket(packet, false); if (message == nullptr) { - // No message for this DONE packet; must be old. Just drop it. - driver->releasePackets(&packet, 1); + // No message for this DONE packet; must be old. return; } // Process DONE packet + Protocol::MessageId msgId = message->id; OutMessage::Status status = message->getStatus(); switch (status) { case OutMessage::Status::SENT: // Expected behavior - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); - message->state.store(OutMessage::Status::COMPLETED); + message->setStatus(OutMessage::Status::COMPLETED); break; case OutMessage::Status::CANCELED: // Canceled by the the application; just ignore the DONE. @@ -142,8 +160,6 @@ Sender::handleDonePacket(Driver::Packet* packet) msgId.transportId, msgId.sequence); break; } - - driver->releasePackets(&packet, 1); } /** @@ -155,21 +171,14 @@ Sender::handleDonePacket(Driver::Packet* packet) void Sender::handleResendPacket(Driver::Packet* packet) { - Protocol::Packet::ResendHeader* header = - static_cast(packet->payload); - Protocol::MessageId msgId = header->common.messageId; - int index = header->index; - int resendEnd = index + header->num; - - MessageBucket* bucket = messageBuckets.getBucket(msgId); - SpinLock::Lock lock(bucket->mutex); - Message* message = bucket->findMessage(msgId, lock); + Message* message = handleIncomingPacket(packet, true); + // FIXME: with handleIncomingPacket, the bucket mutex no longer covers the entire method; need to double-check if this is OK in all methods + // FIXME: in particular, what message states are protected by the bucket lock? do we need a per-message lock? // Check for unexpected conditions if (message == nullptr) { // No message for this RESEND; RESEND must be old. Just ignore it; this // case should be pretty rare and the Receiver will timeout eventually. - driver->releasePackets(&packet, 1); return; } else if (message->numPackets < 2) { // We should never get a RESEND for a single packet message. Just @@ -177,13 +186,14 @@ Sender::handleResendPacket(Driver::Packet* packet) WARNING( "Message (%lu, %lu) with only 1 packet received unexpected RESEND " "request; peer Transport may be confused.", - msgId.transportId, msgId.sequence); - driver->releasePackets(&packet, 1); + message->id.transportId, message->id.sequence); return; } - bucket->messageTimeouts.setTimeout(&message->messageTimeout); - bucket->pingTimeouts.setTimeout(&message->pingTimeout); + Protocol::Packet::ResendHeader* resendHeader = + static_cast(packet->payload); + int index = resendHeader->index; + int resendEnd = index + resendHeader->num; SpinLock::Lock lock_queue(queueMutex); QueuedMessageInfo* info = &message->queuedMessageInfo; @@ -195,9 +205,8 @@ Sender::handleResendPacket(Driver::Packet* packet) "Message (%lu, %lu) RESEND request range out of bounds: requested " "range [%d, %d); message only contains %d packets; peer Transport " "may be confused.", - msgId.transportId, msgId.sequence, index, resendEnd, + message->id.transportId, message->id.sequence, index, resendEnd, info->packets->numPackets); - driver->releasePackets(&packet, 1); return; } @@ -207,7 +216,7 @@ Sender::handleResendPacket(Driver::Packet* packet) // Note that the priority of messages under the unscheduled byte limit // will never be overridden since the resend index will not exceed the // preset packetsGranted. - info->priority = header->priority; + info->priority = resendHeader->priority; sendReady.store(true); } @@ -232,8 +241,6 @@ Sender::handleResendPacket(Driver::Packet* packet) driver->sendPacket(packet, message->destination.ip, resendPriority); } } - - driver->releasePackets(&packet, 1); } /** @@ -245,23 +252,15 @@ Sender::handleResendPacket(Driver::Packet* packet) void Sender::handleGrantPacket(Driver::Packet* packet) { - Protocol::Packet::GrantHeader* header = - static_cast(packet->payload); - Protocol::MessageId msgId = header->common.messageId; - - MessageBucket* bucket = messageBuckets.getBucket(msgId); - SpinLock::Lock lock(bucket->mutex); - Message* message = bucket->findMessage(msgId, lock); + Message* message = handleIncomingPacket(packet, true); if (message == nullptr) { - // No message for this grant; grant must be old. Just drop it. - driver->releasePackets(&packet, 1); + // No message for this grant; grant must be old. return; } - bucket->messageTimeouts.setTimeout(&message->messageTimeout); - bucket->pingTimeouts.setTimeout(&message->pingTimeout); - - if (message->state.load() == OutMessage::Status::IN_PROGRESS) { + Protocol::Packet::GrantHeader* grantHeader = + static_cast(packet->payload); + if (message->getStatus() == OutMessage::Status::IN_PROGRESS) { SpinLock::Lock lock_queue(queueMutex); QueuedMessageInfo* info = &message->queuedMessageInfo; @@ -270,7 +269,7 @@ Sender::handleGrantPacket(Driver::Packet* packet) // can cause at most 1 packet worth of data to be sent without a grant // but allows the sender to always send full packets. int incomingGrantIndex = - (header->byteLimit + info->packets->PACKET_DATA_LENGTH - 1) / + (grantHeader->byteLimit + info->packets->PACKET_DATA_LENGTH - 1) / info->packets->PACKET_DATA_LENGTH; // Make that grants don't exceed the number of packets. Internally, @@ -279,8 +278,8 @@ Sender::handleGrantPacket(Driver::Packet* packet) WARNING( "Message (%lu, %lu) GRANT exceeds message length; granted " "packets: %d, message packets %d; extra grants are ignored.", - msgId.transportId, msgId.sequence, incomingGrantIndex, - info->packets->numPackets); + message->id.transportId, message->id.sequence, + incomingGrantIndex, info->packets->numPackets); incomingGrantIndex = info->packets->numPackets; } @@ -289,12 +288,10 @@ Sender::handleGrantPacket(Driver::Packet* packet) // Note that the priority of messages under the unscheduled byte // limit will never be overridden since the incomingGrantIndex will // not exceed the preset packetsGranted. - info->priority = header->priority; + info->priority = grantHeader->priority; sendReady.store(true); } } - - driver->releasePackets(&packet, 1); } /** @@ -306,17 +303,9 @@ Sender::handleGrantPacket(Driver::Packet* packet) void Sender::handleUnknownPacket(Driver::Packet* packet) { - Protocol::Packet::UnknownHeader* header = - static_cast(packet->payload); - Protocol::MessageId msgId = header->common.messageId; - - MessageBucket* bucket = messageBuckets.getBucket(msgId); - SpinLock::Lock lock(bucket->mutex); - Message* message = bucket->findMessage(msgId, lock); - + Message* message = handleIncomingPacket(packet, false); if (message == nullptr) { - // No message was found. Just drop the packet. - driver->releasePackets(&packet, 1); + // No message was found. return; } @@ -333,35 +322,34 @@ Sender::handleUnknownPacket(Driver::Packet* packet) // failed since the application asked for the message not to be retried. // Remove Message from sendQueue. - if (message->numPackets > 1) { - SpinLock::Lock lock_queue(queueMutex); - QueuedMessageInfo* info = &message->queuedMessageInfo; - if (message->state == OutMessage::Status::IN_PROGRESS) { - assert(sendQueue.contains(&info->sendQueueNode)); - sendQueue.remove(&info->sendQueueNode); - } - assert(!sendQueue.contains(&info->sendQueueNode)); - } - - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); - message->state.store(OutMessage::Status::FAILED); + // FIXME: move the following block into setStatus? +// if (message->numPackets > 1) { +// SpinLock::Lock lock_queue(queueMutex); +// QueuedMessageInfo* info = &message->queuedMessageInfo; +// if (message->getStatus() == OutMessage::Status::IN_PROGRESS) { +// assert(sendQueue.contains(&info->sendQueueNode)); +// sendQueue.remove(&info->sendQueueNode); +// } +// assert(!sendQueue.contains(&info->sendQueueNode)); +// } + message->deschedule(); + message->setStatus(OutMessage::Status::FAILED); } else { // Message isn't done yet so we will restart sending the message. // Make sure the message is not in the sendQueue before making any // changes to the message. - if (message->numPackets > 1) { - SpinLock::Lock lock_queue(queueMutex); - QueuedMessageInfo* info = &message->queuedMessageInfo; - if (message->state == OutMessage::Status::IN_PROGRESS) { - assert(sendQueue.contains(&info->sendQueueNode)); - sendQueue.remove(&info->sendQueueNode); - } - assert(!sendQueue.contains(&info->sendQueueNode)); - } - - message->state.store(OutMessage::Status::IN_PROGRESS); +// if (message->numPackets > 1) { +// SpinLock::Lock lock_queue(queueMutex); +// QueuedMessageInfo* info = &message->queuedMessageInfo; +// if (message->getStatus() == OutMessage::Status::IN_PROGRESS) { +// assert(sendQueue.contains(&info->sendQueueNode)); +// sendQueue.remove(&info->sendQueueNode); +// } +// assert(!sendQueue.contains(&info->sendQueueNode)); +// } + message->deschedule(); + message->setStatus(OutMessage::Status::IN_PROGRESS); // Get the current policy for unscheduled bytes. Policy::Unscheduled policy = policyManager->getUnscheduledPolicy( @@ -381,11 +369,8 @@ Sender::handleUnknownPacket(Driver::Packet* packet) Util::downCast(unscheduledIndexLimit); } - // Reset the timeouts - bucket->messageTimeouts.setTimeout(&message->messageTimeout); - bucket->pingTimeouts.setTimeout(&message->pingTimeout); - assert(message->numPackets > 0); + bool needTimeouts = true; if (message->numPackets == 1) { // If there is only one packet in the message, send it right away. Driver::Packet* dataPacket = message->getPacket(0); @@ -394,18 +379,18 @@ Sender::handleUnknownPacket(Driver::Packet* packet) Perf::counters.tx_bytes.add(dataPacket->length); driver->sendPacket(dataPacket, message->destination.ip, policy.priority); - message->state.store(OutMessage::Status::SENT); + message->setStatus(OutMessage::Status::SENT); // This message must be still be held by the application since the // message still exists (it would have been removed when dropped // because single packet messages are never IN_PROGRESS). Assuming // the message is still held, we can skip the auto removal of SENT // and !held messages. assert(message->held); + // FIXME: wait... this whole chunk of code is copied from sendMessage??? if (message->options & OutMessage::Options::NO_KEEP_ALIVE) { // No timeouts need to be checked after sending the message when // the NO_KEEP_ALIVE option is enabled. - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); + needTimeouts = false; } } else { // Otherwise, queue the message to be sent in SRPT order. @@ -431,9 +416,13 @@ Sender::handleUnknownPacket(Driver::Packet* packet) QueuedMessageInfo::ComparePriority()); sendReady.store(true); } - } - driver->releasePackets(&packet, 1); + // Initialize the timeouts + if (needTimeouts) { + SpinLock::Lock bucket_lock(message->bucket->mutex); + message->resetTimeout(bucket_lock); + } + } } /** @@ -445,26 +434,18 @@ Sender::handleUnknownPacket(Driver::Packet* packet) void Sender::handleErrorPacket(Driver::Packet* packet) { - Protocol::Packet::ErrorHeader* header = - static_cast(packet->payload); - Protocol::MessageId msgId = header->common.messageId; - - MessageBucket* bucket = messageBuckets.getBucket(msgId); - SpinLock::Lock lock(bucket->mutex); - Message* message = bucket->findMessage(msgId, lock); + Message* message = handleIncomingPacket(packet, false); if (message == nullptr) { - // No message for this ERROR packet; must be old. Just drop it. - driver->releasePackets(&packet, 1); + // No message for this ERROR packet; must be old. return; } + Protocol::MessageId msgId = message->id; OutMessage::Status status = message->getStatus(); switch (status) { case OutMessage::Status::SENT: // Message was sent and a failure notification was received. - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); - message->state.store(OutMessage::Status::FAILED); + message->setStatus(OutMessage::Status::FAILED); break; case OutMessage::Status::CANCELED: // Canceled by the the application; just ignore the ERROR. @@ -501,8 +482,6 @@ Sender::handleErrorPacket(Driver::Packet* packet) msgId.transportId, msgId.sequence); break; } - - driver->releasePackets(&packet, 1); } /** @@ -589,13 +568,97 @@ Sender::Message::cancel() sender->cancelMessage(this); } + +// FIXME +void +Sender::Message::destroy(const SpinLock::Lock& bucketMutex) +{ + // TODO: we assume that the message has been unlinked from the sendQueue + // Remove this message from all global data structures of the Sender. + cancelTimeout(bucketMutex); + bucket->messages.remove(&bucketNode); + + // Destruct the Message object. + SpinLock::Lock lock_allocator(sender->messageAllocator.mutex); + sender->messageAllocator.pool.destroy(this); + Perf::counters.destroyed_tx_messages.add(1); +} + /** * @copydoc Homa::OutMessage::getStatus() */ OutMessage::Status Sender::Message::getStatus() const { - return state.load(); + return state.load(std::memory_order_acquire); +} + +/** + * Change the status of this message. + * + * TODO: + */ +void +Sender::Message::setStatus(Status newStatus) +{ + // FIXME: this extra lock argument is ugly and quite confusing for this method + state.store(newStatus, std::memory_order_release); + + // Clean up its state if the scheduler doesn't concern ??? + // FIXME + switch (newStatus) { + case OutMessage::Status::CANCELED: + case OutMessage::Status::COMPLETED: + case OutMessage::Status::FAILED: { + SpinLock::Lock lock(bucket->mutex); + cancelTimeout(lock); + } + default: + break; + } + + // FIXME: why cancel timeouts only? why not also remove itself from the buckets and the SRPT queue? +} + +/** + * Remove this Message from Sender::sendQueue. + */ +void +Sender::Message::deschedule() +{ + // FIXME: I don't think this optimization is correct; it relies on the assumption + // that all single-packet messages will bypass the throttling mechanism; this + // definitely doesn't make sense for jumbo packets... + if (numPackets <= 1) { + return; + } + + // TODO: well, if deschedule is so simple; no need to use a separate method! + SpinLock::Lock lock_queue(sender->queueMutex); + sender->sendQueue.remove(&queuedMessageInfo.sendQueueNode); + // FIXME: why so complicated? +// QueuedMessageInfo* info = &queuedMessageInfo; +// if (getStatus() == OutMessage::Status::IN_PROGRESS) { +// assert(sender->sendQueue.contains(&info->sendQueueNode)); +// sender->sendQueue.remove(&info->sendQueueNode); +// } +// assert(!sender->sendQueue.contains(&info->sendQueueNode)); +} + +void +Sender::Message::resetTimeout(const SpinLock::Lock& lock) +{ + (void)lock; + bucket->messageTimeouts.setTimeout(&messageTimeout); + bucket->pingTimeouts.setTimeout(&pingTimeout); +} + +void +Sender::Message::cancelTimeout(const SpinLock::Lock& lock) +{ + (void)lock; + bucket->messageTimeouts.cancelTimeout(&messageTimeout); + bucket->pingTimeouts.cancelTimeout(&pingTimeout); } /** @@ -755,8 +818,6 @@ Sender::sendMessage(Sender::Message* message, SocketAddress destination, { // Prepare the message assert(message->driver == driver); - // Allocate a new message id - Protocol::MessageId id(transportId, nextMessageSequenceNumber++); Policy::Unscheduled policy = policyManager->getUnscheduledPolicy( destination.ip, message->messageLength); @@ -764,10 +825,9 @@ Sender::sendMessage(Sender::Message* message, SocketAddress destination, ((policy.unscheduledByteLimit + message->PACKET_DATA_LENGTH - 1) / message->PACKET_DATA_LENGTH); - message->id = id; message->destination = destination; message->options = options; - message->state.store(OutMessage::Status::IN_PROGRESS); + message->setStatus(OutMessage::Status::IN_PROGRESS); int actualMessageLen = 0; // fill out metadata. @@ -796,14 +856,15 @@ Sender::sendMessage(Sender::Message* message, SocketAddress destination, sizeof(Protocol::Packet::DataHeader)); // Track message - MessageBucket* bucket = messageBuckets.getBucket(message->id); - SpinLock::Lock lock(bucket->mutex); - assert(!bucket->messages.contains(&message->bucketNode)); - bucket->messages.push_back(&message->bucketNode); - bucket->messageTimeouts.setTimeout(&message->messageTimeout); - bucket->pingTimeouts.setTimeout(&message->pingTimeout); + MessageBucket* bucket = message->bucket; + { + SpinLock::Lock lock(bucket->mutex); + assert(!bucket->messages.contains(&message->bucketNode)); + bucket->messages.push_back(&message->bucketNode); + } assert(message->numPackets > 0); + bool needTimeouts = true; if (message->numPackets == 1) { // If there is only one packet in the message, send it right away. Driver::Packet* packet = message->getPacket(0); @@ -811,7 +872,7 @@ Sender::sendMessage(Sender::Message* message, SocketAddress destination, Perf::counters.tx_data_pkts.add(1); Perf::counters.tx_bytes.add(packet->length); driver->sendPacket(packet, message->destination.ip, policy.priority); - message->state.store(OutMessage::Status::SENT); + message->setStatus(OutMessage::Status::SENT); // By definition, this message must be still be held by the application // the send() call is since the progress. Assuming the message is still // held, we can skip the auto removal of SENT and !held messages. @@ -819,14 +880,13 @@ Sender::sendMessage(Sender::Message* message, SocketAddress destination, if (message->options & OutMessage::Options::NO_KEEP_ALIVE) { // No timeouts need to be checked after sending the message when // the NO_KEEP_ALIVE option is enabled. - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); + needTimeouts = false; } } else { // Otherwise, queue the message to be sent in SRPT order. SpinLock::Lock lock_queue(queueMutex); QueuedMessageInfo* info = &message->queuedMessageInfo; - info->id = id; + info->id = message->id; info->destination = message->destination; info->packets = message; info->unsentBytes = message->messageLength; @@ -840,6 +900,11 @@ Sender::sendMessage(Sender::Message* message, SocketAddress destination, QueuedMessageInfo::ComparePriority()); sendReady.store(true); } + + if (needTimeouts) { + SpinLock::Lock lock(bucket->mutex); + message->resetTimeout(lock); + } } /** @@ -851,25 +916,34 @@ Sender::sendMessage(Sender::Message* message, SocketAddress destination, void Sender::cancelMessage(Sender::Message* message) { - Protocol::MessageId msgId = message->id; - MessageBucket* bucket = messageBuckets.getBucket(msgId); - SpinLock::Lock lock(bucket->mutex); + MessageBucket* bucket = message->bucket; + SpinLock::UniqueLock bucket_lock(bucket->mutex); + + // FIXME: why should we even bother to do the following test? why not just remove it from the bucket, the timeout list, and the SRPT queue? + // TODO: the remove method of an intrusive list should be idempotent, right? if (bucket->messages.contains(&message->bucketNode)) { - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); if (message->numPackets > 1 && - message->state == OutMessage::Status::IN_PROGRESS) { + message->getStatus() == OutMessage::Status::IN_PROGRESS) { // Check to see if the message needs to be dequeued. SpinLock::Lock lock_queue(queueMutex); // Recheck state with lock in case it change right before this. - if (message->state == OutMessage::Status::IN_PROGRESS) { + // FIXME: somehow I feel like I have seen the following code snippet a million times + // FIXME: why is this sendQueue stuff so complicated? + if (message->getStatus() == OutMessage::Status::IN_PROGRESS) { QueuedMessageInfo* info = &message->queuedMessageInfo; assert(sendQueue.contains(&info->sendQueueNode)); sendQueue.remove(&info->sendQueueNode); } } - message->state.store(OutMessage::Status::CANCELED); + + bucket_lock.unlock(); + message->setStatus(OutMessage::Status::CANCELED); + // FIXME: who is responsible for removing this message from the bucket? } + + // FIXME: why not change the entire method to the following: +// message->deschedule(); +// message->setStatus(OutMessage::Status::CANCELED); } /** @@ -881,20 +955,14 @@ Sender::cancelMessage(Sender::Message* message) void Sender::dropMessage(Sender::Message* message) { - Protocol::MessageId msgId = message->id; - MessageBucket* bucket = messageBuckets.getBucket(msgId); + MessageBucket* bucket = message->bucket; SpinLock::Lock lock(bucket->mutex); message->held = false; Perf::counters.released_tx_messages.add(1); - if (message->state != OutMessage::Status::IN_PROGRESS) { + if (message->getStatus() != OutMessage::Status::IN_PROGRESS) { // Ok to delete immediately since we don't have to wait for the message // to be sent. - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); - bucket->messages.remove(&message->bucketNode); - SpinLock::Lock lock_allocator(messageAllocator.mutex); - messageAllocator.pool.destroy(message); - Perf::counters.destroyed_tx_messages.add(1); + message->destroy(lock); } else { // Defer deletion and wait for the message to be SENT. } @@ -919,7 +987,7 @@ Sender::checkMessageTimeouts(uint64_t now, MessageBucket* bucket) } while (true) { - SpinLock::Lock lock(bucket->mutex); + SpinLock::UniqueLock bucket_lock(bucket->mutex); // No remaining timeouts. if (bucket->messageTimeouts.empty()) { break; @@ -929,22 +997,34 @@ Sender::checkMessageTimeouts(uint64_t now, MessageBucket* bucket) if (!message->messageTimeout.hasElapsed(now)) { break; } + + // Release the bucket mutex to avoid deadlock inside setStatus(). + bucket_lock.unlock(); + // Found expired timeout. - if (message->state != OutMessage::Status::COMPLETED) { - if (message->state == OutMessage::Status::IN_PROGRESS) { + if (message->getStatus() != OutMessage::Status::COMPLETED) { + if (message->getStatus() == OutMessage::Status::IN_PROGRESS) { // Check to see if the message needs to be dequeued. SpinLock::Lock lock_queue(queueMutex); + // FIXME: why double-check? why does it even matter? // Recheck state with lock in case it change right before this. - if (message->state == OutMessage::Status::IN_PROGRESS) { + if (message->getStatus() == OutMessage::Status::IN_PROGRESS) { QueuedMessageInfo* info = &message->queuedMessageInfo; assert(sendQueue.contains(&info->sendQueueNode)); sendQueue.remove(&info->sendQueueNode); } } - message->state.store(OutMessage::Status::FAILED); + message->setStatus(OutMessage::Status::FAILED); + } else { + // TODO: double-check with Collin + SpinLock::Lock lock(bucket->mutex); + message->cancelTimeout(lock); + WARNING("SHOULDN'T BE HERE?"); } - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); + // FIXME: I don't understand this; if the message is completed, its + // timeouts should've been cancelled already, no? +// bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); +// bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); } } @@ -978,21 +1058,24 @@ Sender::checkPingTimeouts(uint64_t now, MessageBucket* bucket) break; } // Found expired timeout. - if (message->state == OutMessage::Status::COMPLETED || - message->state == OutMessage::Status::FAILED) { + if (message->getStatus() == OutMessage::Status::COMPLETED || + message->getStatus() == OutMessage::Status::FAILED) { + // FIXME: how is this possible? setStatus ensures that all timeouts + // will be cancelled when the status enters an end state bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); continue; } else if (message->options & OutMessage::Options::NO_KEEP_ALIVE && - message->state == OutMessage::Status::SENT) { - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); + message->getStatus() == OutMessage::Status::SENT) { + message->cancelTimeout(lock); continue; } else { + // TODO: can be change to the following to avoid calling setTimeout directly? + //message->resetTimeout(lock); bucket->pingTimeouts.setTimeout(&message->pingTimeout); } // Check if sender still has packets to send - if (message->state.load() == OutMessage::Status::IN_PROGRESS) { + if (message->getStatus() == OutMessage::Status::IN_PROGRESS) { SpinLock::Lock lock_queue(queueMutex); QueuedMessageInfo* info = &message->queuedMessageInfo; if (info->packetsSent < info->packetsGranted) { @@ -1046,7 +1129,7 @@ Sender::trySend() auto it = sendQueue.begin(); while (it != sendQueue.end()) { Message& message = *it; - assert(message.state.load() == OutMessage::Status::IN_PROGRESS); + assert(message.getStatus() == OutMessage::Status::IN_PROGRESS); QueuedMessageInfo* info = &message.queuedMessageInfo; assert(info->packetsGranted <= info->packets->numPackets); while (info->packetsSent < info->packetsGranted) { @@ -1078,7 +1161,7 @@ Sender::trySend() if (info->packetsSent >= info->packets->numPackets) { // We have finished sending the message. sentMessageIds.push_back(info->id); - message.state.store(OutMessage::Status::SENT); + message.setStatus(OutMessage::Status::SENT); it = sendQueue.remove(it); } else if (info->packetsSent >= info->packetsGranted) { // We have sent every granted packet. @@ -1106,17 +1189,11 @@ Sender::trySend() if (!message->held) { // Ok to delete now that the message has been sent. - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); - bucket->messages.remove(&message->bucketNode); - SpinLock::Lock lock_allocator(messageAllocator.mutex); - messageAllocator.pool.destroy(message); - Perf::counters.destroyed_tx_messages.add(1); + message->destroy(lock); } else if (message->options & OutMessage::Options::NO_KEEP_ALIVE) { // No timeouts need to be checked after sending the message when // the NO_KEEP_ALIVE option is enabled. - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); + message->cancelTimeout(lock); } } diff --git a/src/Sender.h b/src/Sender.h index 8f4282e..c2069e6 100644 --- a/src/Sender.h +++ b/src/Sender.h @@ -53,12 +53,16 @@ class Sender { virtual void handleGrantPacket(Driver::Packet* packet); virtual void handleUnknownPacket(Driver::Packet* packet); virtual void handleErrorPacket(Driver::Packet* packet); + virtual void poll(); virtual void checkTimeouts(); private: /// Forward declarations class Message; + class MessageBucket; + + Message* handleIncomingPacket(Driver::Packet* packet, bool resetTimeout); /** * Contains metadata for a Message that has been queued to be sent. @@ -138,7 +142,8 @@ class Sender { , TRANSPORT_HEADER_LENGTH(sizeof(Protocol::Packet::DataHeader)) , PACKET_DATA_LENGTH(driver->getMaxPayloadSize() - TRANSPORT_HEADER_LENGTH) - , id(0, 0) + , id(sender->transportId, sender->nextMessageSequenceNumber++) + , bucket(sender->messageBuckets.getBucket(id)) , source{driver->getLocalAddress(), sourcePort} , destination() , options(Options::NONE) @@ -159,7 +164,12 @@ class Sender { virtual ~Message(); void append(const void* source, size_t count) override; void cancel() override; + void destroy(const SpinLock::Lock& lock); Status getStatus() const override; + void setStatus(Status newStatus); + void deschedule(); + void resetTimeout(const SpinLock::Lock& lock); + void cancelTimeout(const SpinLock::Lock& lock); size_t length() const override; void prepend(const void* source, size_t count) override; void release() override; @@ -189,15 +199,20 @@ class Sender { const int PACKET_DATA_LENGTH; /// Contains the unique identifier for this message. - Protocol::MessageId id; + const Protocol::MessageId id; + + /// Message bucket this message belongs to. + MessageBucket* const bucket; /// Contains source address of this message. - SocketAddress source; + const SocketAddress source; - /// Contains destination address of this message. + /// Contains destination address of this message. Must be constant after + /// send() is invoked. SocketAddress destination; - /// Contains flags for any requested optional send behavior. + /// Contains flags for any requested optional send behavior. Must be + /// constant after send() is invoked. Options options; /// True if a pointer to this message is accessible by the application @@ -205,21 +220,26 @@ class Sender { /// been release via dropMessage()); false, otherwise. bool held; - /// First byte where data is or will go if empty. + /// First byte where data is or will go if empty. Must be constant after + /// send() is invoked. int start; /// Number of bytes in this Message including any reserved headroom. + /// Must be constant after send() is invoked. int messageLength; - /// Number of packets currently contained in this message. + /// Number of packets currently contained in this message. Must be + /// constant after send() is invoked. int numPackets; - /// Bit array representing which entires in the _packets_ array are set. - /// Used to avoid having to zero out the entire _packets_ array. + /// Bit array representing which entries in the _packets_ array are set. + /// Used to avoid having to zero out the entire _packets_ array. Must be + /// constant after send() is invoked. std::bitset occupied; /// Collection of Packet objects that make up this context's Message. - /// These Packets will be released when this context is destroyed. + /// These Packets will be released when this context is destroyed. Must + /// be constant after send() is invoked. Driver::Packet* packets[MAX_MESSAGE_PACKETS]; /// This message's current state. @@ -395,7 +415,7 @@ class Sender { Policy::Manager* const policyManager; /// The sequence number to be used for the next Message. - std::atomic nextMessageSequenceNumber; + volatile uint64_t nextMessageSequenceNumber; /// The maximum number of bytes that should be queued in the Driver. const uint32_t DRIVER_QUEUED_BYTE_LIMIT; diff --git a/src/SenderTest.cc b/src/SenderTest.cc index 07630a8..b3beda9 100644 --- a/src/SenderTest.cc +++ b/src/SenderTest.cc @@ -1575,6 +1575,7 @@ TEST_F(SenderTest, cancelMessage) EXPECT_TRUE(bucket->pingTimeouts.list.empty()); EXPECT_TRUE(sender->sendQueue.empty()); EXPECT_EQ(Homa::OutMessage::Status::CANCELED, message->state.load()); + // FIXME: shouldn't we check if the bucket is empty? } TEST_F(SenderTest, dropMessage_basic) diff --git a/src/TransportImpl.cc b/src/TransportImpl.cc index 38fe6d3..165cc87 100644 --- a/src/TransportImpl.cc +++ b/src/TransportImpl.cc @@ -94,16 +94,33 @@ TransportImpl::processPackets() Driver::Packet* packets[MAX_BURST]; IpAddress srcAddrs[MAX_BURST]; int numPackets = driver->receivePackets(MAX_BURST, packets, srcAddrs); + int releaseCount = 0; for (int i = 0; i < numPackets; ++i) { - processPacket(packets[i], srcAddrs[i]); + bool retainPacket = processPacket(packets[i], srcAddrs[i]); + if (!retainPacket) { + packets[releaseCount++] = packets[i]; + } } + driver->releasePackets(packets, releaseCount); if (numPackets > 0) { Perf::counters.active_cycles.add(timer.split()); } } -void +/** + * Process an incoming packet. + * + * @param packet + * Incoming packet to be processed. + * @param sourceIp + * Source IP address. + * @return + * True if the transport decides to take ownership of the packet. False + * if the transport has no more use of this packet and it can be released + * to the driver. + */ +bool TransportImpl::processPacket(Driver::Packet* packet, IpAddress sourceIp) { assert(packet->length >= @@ -111,10 +128,11 @@ TransportImpl::processPacket(Driver::Packet* packet, IpAddress sourceIp) Perf::counters.rx_bytes.add(packet->length); Protocol::Packet::CommonHeader* header = static_cast(packet->payload); + bool retainPacket = false; switch (header->opcode) { case Protocol::Packet::DATA: Perf::counters.rx_data_pkts.add(1); - receiver->handleDataPacket(packet, sourceIp); + retainPacket = receiver->handleDataPacket(packet, sourceIp); break; case Protocol::Packet::GRANT: Perf::counters.rx_grant_pkts.add(1); @@ -145,6 +163,7 @@ TransportImpl::processPacket(Driver::Packet* packet, IpAddress sourceIp) sender->handleErrorPacket(packet); break; } + return retainPacket; } } // namespace Core diff --git a/src/TransportImpl.h b/src/TransportImpl.h index ad46f99..79c7453 100644 --- a/src/TransportImpl.h +++ b/src/TransportImpl.h @@ -75,7 +75,7 @@ class TransportImpl : public Transport { private: void processPackets(); - void processPacket(Driver::Packet* packet, IpAddress source); + bool processPacket(Driver::Packet* packet, IpAddress source); /// Unique identifier for this transport. const std::atomic transportId; From 8859e059654105a2a8cf9e2ada5784f4c748463b Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Thu, 15 Oct 2020 02:06:30 -0700 Subject: [PATCH 27/33] completed the first draft of new Sender --- src/Sender.cc | 564 +++++++++++++++---------------------------- src/Sender.h | 46 ++-- src/SenderTest.cc | 1 - src/TransportImpl.cc | 4 - 4 files changed, 209 insertions(+), 406 deletions(-) diff --git a/src/Sender.cc b/src/Sender.cc index 1fd81a7..386ea0c 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -59,11 +59,6 @@ Sender::Sender(uint64_t transportId, Driver* driver, , messageAllocator() {} -/** - * Sender Destructor - */ -Sender::~Sender() {} - /** * Allocate an OutMessage that can be sent with this Sender. */ @@ -90,14 +85,18 @@ Sender::allocMessage(uint16_t sourcePort) Sender::Message* Sender::handleIncomingPacket(Driver::Packet* packet, bool resetTimeout) { + // Find the message bucket Protocol::Packet::CommonHeader* commonHeader = static_cast(packet->payload); Protocol::MessageId msgId = commonHeader->messageId; MessageBucket* bucket = messageBuckets.getBucket(msgId); + + // Find the target message and update its expiration time SpinLock::Lock lock(bucket->mutex); Message* message = bucket->findMessage(msgId, lock); if (resetTimeout) { - message->resetTimeout(lock); + bucket->messageTimeouts.setTimeout(&message->messageTimeout); + bucket->pingTimeouts.setTimeout(&message->pingTimeout); } return message; } @@ -123,7 +122,7 @@ Sender::handleDonePacket(Driver::Packet* packet) switch (status) { case OutMessage::Status::SENT: // Expected behavior - message->setStatus(OutMessage::Status::COMPLETED); + message->setStatus(OutMessage::Status::COMPLETED, false); break; case OutMessage::Status::CANCELED: // Canceled by the the application; just ignore the DONE. @@ -172,8 +171,6 @@ void Sender::handleResendPacket(Driver::Packet* packet) { Message* message = handleIncomingPacket(packet, true); - // FIXME: with handleIncomingPacket, the bucket mutex no longer covers the entire method; need to double-check if this is OK in all methods - // FIXME: in particular, what message states are protected by the bucket lock? do we need a per-message lock? // Check for unexpected conditions if (message == nullptr) { @@ -227,18 +224,18 @@ Sender::handleResendPacket(Driver::Packet* packet) // when it's ready. Perf::counters.tx_busy_pkts.add(1); ControlPacket::send( - driver, info->destination.ip, info->id); + driver, info->packets->destination.ip, info->packets->id); } else { // There are some packets to resend but only resend packets that have // already been sent. resendEnd = std::min(resendEnd, info->packetsSent); int resendPriority = policyManager->getResendPriority(); - for (uint16_t i = index; i < resendEnd; ++i) { - Driver::Packet* packet = info->packets->getPacket(i); - // Packets will be sent at the priority their original priority. + for (int i = index; i < resendEnd; ++i) { + Driver::Packet* resendPacket = info->packets->getPacket(i); Perf::counters.tx_data_pkts.add(1); - Perf::counters.tx_bytes.add(packet->length); - driver->sendPacket(packet, message->destination.ip, resendPriority); + Perf::counters.tx_bytes.add(resendPacket->length); + driver->sendPacket(resendPacket, message->destination.ip, + resendPriority); } } } @@ -317,111 +314,12 @@ Sender::handleUnknownPacket(Driver::Packet* packet) // must be a stale response to a ping. } else if (message->options & OutMessage::Options::NO_RETRY) { // Option: NO_RETRY - // Either the Message or the DONE packet was lost; consider the message // failed since the application asked for the message not to be retried. - - // Remove Message from sendQueue. - // FIXME: move the following block into setStatus? -// if (message->numPackets > 1) { -// SpinLock::Lock lock_queue(queueMutex); -// QueuedMessageInfo* info = &message->queuedMessageInfo; -// if (message->getStatus() == OutMessage::Status::IN_PROGRESS) { -// assert(sendQueue.contains(&info->sendQueueNode)); -// sendQueue.remove(&info->sendQueueNode); -// } -// assert(!sendQueue.contains(&info->sendQueueNode)); -// } - message->deschedule(); - message->setStatus(OutMessage::Status::FAILED); + message->setStatus(OutMessage::Status::FAILED, true); } else { // Message isn't done yet so we will restart sending the message. - - // Make sure the message is not in the sendQueue before making any - // changes to the message. -// if (message->numPackets > 1) { -// SpinLock::Lock lock_queue(queueMutex); -// QueuedMessageInfo* info = &message->queuedMessageInfo; -// if (message->getStatus() == OutMessage::Status::IN_PROGRESS) { -// assert(sendQueue.contains(&info->sendQueueNode)); -// sendQueue.remove(&info->sendQueueNode); -// } -// assert(!sendQueue.contains(&info->sendQueueNode)); -// } - message->deschedule(); - message->setStatus(OutMessage::Status::IN_PROGRESS); - - // Get the current policy for unscheduled bytes. - Policy::Unscheduled policy = policyManager->getUnscheduledPolicy( - message->destination.ip, message->messageLength); - int unscheduledIndexLimit = - ((policy.unscheduledByteLimit + message->PACKET_DATA_LENGTH - 1) / - message->PACKET_DATA_LENGTH); - - // Update the policy version for each packet - for (uint16_t i = 0; i < message->numPackets; ++i) { - Driver::Packet* dataPacket = message->getPacket(i); - assert(dataPacket != nullptr); - Protocol::Packet::DataHeader* header = - static_cast(dataPacket->payload); - header->policyVersion = policy.version; - header->unscheduledIndexLimit = - Util::downCast(unscheduledIndexLimit); - } - - assert(message->numPackets > 0); - bool needTimeouts = true; - if (message->numPackets == 1) { - // If there is only one packet in the message, send it right away. - Driver::Packet* dataPacket = message->getPacket(0); - assert(dataPacket != nullptr); - Perf::counters.tx_data_pkts.add(1); - Perf::counters.tx_bytes.add(dataPacket->length); - driver->sendPacket(dataPacket, message->destination.ip, - policy.priority); - message->setStatus(OutMessage::Status::SENT); - // This message must be still be held by the application since the - // message still exists (it would have been removed when dropped - // because single packet messages are never IN_PROGRESS). Assuming - // the message is still held, we can skip the auto removal of SENT - // and !held messages. - assert(message->held); - // FIXME: wait... this whole chunk of code is copied from sendMessage??? - if (message->options & OutMessage::Options::NO_KEEP_ALIVE) { - // No timeouts need to be checked after sending the message when - // the NO_KEEP_ALIVE option is enabled. - needTimeouts = false; - } - } else { - // Otherwise, queue the message to be sent in SRPT order. - SpinLock::Lock lock_queue(queueMutex); - QueuedMessageInfo* info = &message->queuedMessageInfo; - // Some of these values should still be set from when the message - // was first queued. - assert(info->id == message->id); - assert(!memcmp(&info->destination, &message->destination, - sizeof(info->destination))); - assert(info->packets == message); - // Some values need to be updated - info->unsentBytes = message->messageLength; - info->packetsGranted = - std::min(unscheduledIndexLimit, message->numPackets); - info->priority = policy.priority; - info->packetsSent = 0; - // Insert and move message into the correct order in the priority - // queue. - sendQueue.push_front(&info->sendQueueNode); - Intrusive::deprioritize( - &sendQueue, &info->sendQueueNode, - QueuedMessageInfo::ComparePriority()); - sendReady.store(true); - } - - // Initialize the timeouts - if (needTimeouts) { - SpinLock::Lock bucket_lock(message->bucket->mutex); - message->resetTimeout(bucket_lock); - } + startMessage(message, true); } } @@ -445,7 +343,7 @@ Sender::handleErrorPacket(Driver::Packet* packet) switch (status) { case OutMessage::Status::SENT: // Message was sent and a failure notification was received. - message->setStatus(OutMessage::Status::FAILED); + message->setStatus(OutMessage::Status::FAILED, false); break; case OutMessage::Status::CANCELED: // Canceled by the the application; just ignore the ERROR. @@ -543,9 +441,9 @@ Sender::Message::append(const void* source, size_t count) int bytesToCopy = std::min(_count - bytesCopied, PACKET_DATA_LENGTH - packetOffset); Driver::Packet* packet = getOrAllocPacket(packetIndex); - char* destination = static_cast(packet->payload); - destination += packetOffset + TRANSPORT_HEADER_LENGTH; - std::memcpy(destination, static_cast(source) + bytesCopied, + char* copyDst = static_cast(packet->payload); + copyDst += packetOffset + TRANSPORT_HEADER_LENGTH; + std::memcpy(copyDst, static_cast(source) + bytesCopied, bytesToCopy); // TODO(cstlee): A Message probably shouldn't be in charge of setting // the packet length. @@ -565,18 +463,28 @@ Sender::Message::append(const void* source, size_t count) void Sender::Message::cancel() { - sender->cancelMessage(this); + setStatus(OutMessage::Status::CANCELED, true); } - -// FIXME +/** + * Detach this message from the transport and destruct the Message object. + * + * Note: no one should access this message after the method returns. + */ void -Sender::Message::destroy(const SpinLock::Lock& bucketMutex) +Sender::Message::destroy() { - // TODO: we assume that the message has been unlinked from the sendQueue - // Remove this message from all global data structures of the Sender. - cancelTimeout(bucketMutex); - bucket->messages.remove(&bucketNode); + // We assume that this message has been unlinked from the sendQueue before + // this method is invoked. + assert(getStatus() != OutMessage::Status::IN_PROGRESS); + + // Remove this message from the other data structures of the Sender. + { + SpinLock::Lock bucket_lock(bucket->mutex); + bucket->messageTimeouts.cancelTimeout(&messageTimeout); + bucket->pingTimeouts.cancelTimeout(&pingTimeout); + bucket->messages.remove(&bucketNode); + } // Destruct the Message object. SpinLock::Lock lock_allocator(sender->messageAllocator.mutex); @@ -593,72 +501,48 @@ Sender::Message::getStatus() const return state.load(std::memory_order_acquire); } -/** - * Change the status of this message. - * - * TODO: - */ + /** + * Change the status of this message. + * + * All status change must be done by this method. + * + * @param newStatus + * The new status. + * @param deschedule + * True if we should remove this message from the send queue. + */ void -Sender::Message::setStatus(Status newStatus) +Sender::Message::setStatus(Status newStatus, bool deschedule) { - // FIXME: this extra lock argument is ugly and quite confusing for this method - state.store(newStatus, std::memory_order_release); - - // Clean up its state if the scheduler doesn't concern ??? - // FIXME - switch (newStatus) { - case OutMessage::Status::CANCELED: - case OutMessage::Status::COMPLETED: - case OutMessage::Status::FAILED: { - SpinLock::Lock lock(bucket->mutex); - cancelTimeout(lock); + // Whether to remove the message from the send queue depends on more than + // just the message status; only the caller has enough information to make + // the decision. + if (deschedule) { + // TODO: with jumbo packets, single-packet messages may also be paced + + // An outgoing message is on the sendQueue iff. it's still in progress + // and subject to the sender's packet pacing mechanism; test this + // condition first to reduce the expensive locking operation + if ((numPackets > 1) && + (getStatus() == OutMessage::Status::IN_PROGRESS)) { + SpinLock::Lock lock_queue(sender->queueMutex); + sender->sendQueue.remove(&queuedMessageInfo.sendQueueNode); } - default: - break; } - // FIXME: why cancel timeouts only? why not also remove itself from the buckets and the SRPT queue? -} + state.store(newStatus, std::memory_order_release); -/** - * Remove this Message from Sender::sendQueue. - */ -void -Sender::Message::deschedule() -{ - // FIXME: I don't think this optimization is correct; it relies on the assumption - // that all single-packet messages will bypass the throttling mechanism; this - // definitely doesn't make sense for jumbo packets... - if (numPackets <= 1) { - return; + // Cancel the timeouts if the message reaches an end state. + if (newStatus == OutMessage::Status::CANCELED || + newStatus == OutMessage::Status::COMPLETED || + newStatus == OutMessage::Status::FAILED) { + SpinLock::Lock lock(bucket->mutex); + bucket->messageTimeouts.cancelTimeout(&messageTimeout); + bucket->pingTimeouts.cancelTimeout(&pingTimeout); } - // TODO: well, if deschedule is so simple; no need to use a separate method! - SpinLock::Lock lock_queue(sender->queueMutex); - sender->sendQueue.remove(&queuedMessageInfo.sendQueueNode); - // FIXME: why so complicated? -// QueuedMessageInfo* info = &queuedMessageInfo; -// if (getStatus() == OutMessage::Status::IN_PROGRESS) { -// assert(sender->sendQueue.contains(&info->sendQueueNode)); -// sender->sendQueue.remove(&info->sendQueueNode); -// } -// assert(!sender->sendQueue.contains(&info->sendQueueNode)); -} - -void -Sender::Message::resetTimeout(const SpinLock::Lock& lock) -{ - (void)lock; - bucket->messageTimeouts.setTimeout(&messageTimeout); - bucket->pingTimeouts.setTimeout(&pingTimeout); -} - -void -Sender::Message::cancelTimeout(const SpinLock::Lock& lock) -{ - (void)lock; - bucket->messageTimeouts.cancelTimeout(&messageTimeout); - bucket->pingTimeouts.cancelTimeout(&pingTimeout); + // This method is not the right place to remove the message from the bucket; + // it's the job of Message::release(). } /** @@ -690,9 +574,9 @@ Sender::Message::prepend(const void* source, size_t count) std::min(_count - bytesCopied, PACKET_DATA_LENGTH - packetOffset); Driver::Packet* packet = getPacket(packetIndex); assert(packet != nullptr); - char* destination = static_cast(packet->payload); - destination += packetOffset + TRANSPORT_HEADER_LENGTH; - std::memcpy(destination, static_cast(source) + bytesCopied, + char* copyDst = static_cast(packet->payload); + copyDst += packetOffset + TRANSPORT_HEADER_LENGTH; + std::memcpy(copyDst, static_cast(source) + bytesCopied, bytesToCopy); bytesCopied += bytesToCopy; packetIndex++; @@ -706,7 +590,15 @@ Sender::Message::prepend(const void* source, size_t count) void Sender::Message::release() { - sender->dropMessage(this); + held.store(false, std::memory_order_release); + if (getStatus() != OutMessage::Status::IN_PROGRESS) { + // Ok to delete immediately since we don't have to wait for the message + // to be sent. + destroy(); + } else { + // Defer deletion and wait for the message to be SENT. + } + Perf::counters.released_tx_messages.add(1); } /** @@ -755,7 +647,12 @@ void Sender::Message::send(SocketAddress destination, Sender::Message::Options options) { - sender->sendMessage(this, destination, options); + // Prepare the message + this->destination = destination; + this->options = options; + + // Kick start the transmission. + sender->startMessage(this, false); } /** @@ -801,81 +698,78 @@ Sender::Message::getOrAllocPacket(size_t index) } /** - * Queue a message to be sent. + * (Re)start the transmission of an outgoing message. * * @param message * Sender::Message to be sent. - * @param destination - * Destination address for this message. - * @param options - * Flags indicating requested non-default send behavior. - * - * @sa dropMessage() + * @param restart + * False if the message is new to the transport; true means the message is + * restarted by the transport. */ void -Sender::sendMessage(Sender::Message* message, SocketAddress destination, - Sender::Message::Options options) +Sender::startMessage(Sender::Message* message, bool restart) { - // Prepare the message - assert(message->driver == driver); + // If we are restarting an existing message, make sure it's not in the + // sendQueue before making any changes to it. + message->setStatus(OutMessage::Status::IN_PROGRESS, restart); + // Get the current policy for unscheduled bytes. Policy::Unscheduled policy = policyManager->getUnscheduledPolicy( - destination.ip, message->messageLength); - int unscheduledPacketLimit = + message->destination.ip, message->messageLength); + uint16_t unscheduledIndexLimit = ((policy.unscheduledByteLimit + message->PACKET_DATA_LENGTH - 1) / message->PACKET_DATA_LENGTH); - message->destination = destination; - message->options = options; - message->setStatus(OutMessage::Status::IN_PROGRESS); - - int actualMessageLen = 0; - // fill out metadata. - for (int i = 0; i < message->numPackets; ++i) { - Driver::Packet* packet = message->getPacket(i); - if (packet == nullptr) { - PANIC( - "Incomplete message with id (%lu:%lu); missing packet " - "at offset %d; this shouldn't happen.", - message->id.transportId, message->id.sequence, - i * message->PACKET_DATA_LENGTH); + if (!restart) { + // Fill out packet headers. + int actualMessageLen = 0; + for (int i = 0; i < message->numPackets; ++i) { + Driver::Packet* packet = message->getPacket(i); + assert(packet != nullptr); + new (packet->payload) Protocol::Packet::DataHeader( + message->source.port, message->destination.port, message->id, + Util::downCast(message->messageLength), + policy.version, unscheduledIndexLimit, + Util::downCast(i)); + actualMessageLen += + (packet->length - message->TRANSPORT_HEADER_LENGTH); } + assert(message->messageLength == actualMessageLen); - new (packet->payload) Protocol::Packet::DataHeader( - message->source.port, destination.port, message->id, - Util::downCast(message->messageLength), policy.version, - Util::downCast(unscheduledPacketLimit), - Util::downCast(i)); - actualMessageLen += (packet->length - message->TRANSPORT_HEADER_LENGTH); - } - - // perform sanity checks. - assert(message->driver == driver); - assert(message->messageLength == actualMessageLen); - assert(message->TRANSPORT_HEADER_LENGTH == - sizeof(Protocol::Packet::DataHeader)); - - // Track message - MessageBucket* bucket = message->bucket; - { + // Start tracking the new message + MessageBucket* bucket = message->bucket; SpinLock::Lock lock(bucket->mutex); assert(!bucket->messages.contains(&message->bucketNode)); bucket->messages.push_back(&message->bucketNode); + } else { + // Update the policy version for each packet + for (int i = 0; i < message->numPackets; ++i) { + Driver::Packet* packet = message->getPacket(i); + assert(packet != nullptr); + Protocol::Packet::DataHeader* header = + static_cast(packet->payload); + header->policyVersion = policy.version; + header->unscheduledIndexLimit = unscheduledIndexLimit; + } } + // Kick start the message. assert(message->numPackets > 0); bool needTimeouts = true; if (message->numPackets == 1) { // If there is only one packet in the message, send it right away. - Driver::Packet* packet = message->getPacket(0); - assert(packet != nullptr); + Driver::Packet* dataPacket = message->getPacket(0); + assert(dataPacket != nullptr); Perf::counters.tx_data_pkts.add(1); - Perf::counters.tx_bytes.add(packet->length); - driver->sendPacket(packet, message->destination.ip, policy.priority); - message->setStatus(OutMessage::Status::SENT); - // By definition, this message must be still be held by the application - // the send() call is since the progress. Assuming the message is still - // held, we can skip the auto removal of SENT and !held messages. + Perf::counters.tx_bytes.add(dataPacket->length); + driver->sendPacket(dataPacket, message->destination.ip, + policy.priority); + message->setStatus(OutMessage::Status::SENT, false); + // This message must be still be held by the application since the + // message still exists (it would have been removed when dropped + // because single packet messages are never IN_PROGRESS). Assuming + // the message is still held, we can skip the auto removal of SENT + // and !held messages. assert(message->held); if (message->options & OutMessage::Options::NO_KEEP_ALIVE) { // No timeouts need to be checked after sending the message when @@ -886,85 +780,28 @@ Sender::sendMessage(Sender::Message* message, SocketAddress destination, // Otherwise, queue the message to be sent in SRPT order. SpinLock::Lock lock_queue(queueMutex); QueuedMessageInfo* info = &message->queuedMessageInfo; - info->id = message->id; - info->destination = message->destination; - info->packets = message; + // Some values need to be updated info->unsentBytes = message->messageLength; info->packetsGranted = - std::min(unscheduledPacketLimit, message->numPackets); + std::min(unscheduledIndexLimit, + Util::downCast(message->numPackets)); info->priority = policy.priority; info->packetsSent = 0; - // Insert and move message into the correct order in the priority queue. + // Insert and move message into the correct order in the priority + // queue. sendQueue.push_front(&info->sendQueueNode); - Intrusive::deprioritize(&sendQueue, &info->sendQueueNode, - QueuedMessageInfo::ComparePriority()); + Intrusive::deprioritize( + &sendQueue, &info->sendQueueNode, + QueuedMessageInfo::ComparePriority()); sendReady.store(true); } + // Initialize the timeouts if (needTimeouts) { + MessageBucket* bucket = message->bucket; SpinLock::Lock lock(bucket->mutex); - message->resetTimeout(lock); - } -} - -/** - * Inform the Sender that a Message no longer needs to be sent. - * - * @param message - * The Sender::Message that is no longer needs to be sent. - */ -void -Sender::cancelMessage(Sender::Message* message) -{ - MessageBucket* bucket = message->bucket; - SpinLock::UniqueLock bucket_lock(bucket->mutex); - - // FIXME: why should we even bother to do the following test? why not just remove it from the bucket, the timeout list, and the SRPT queue? - // TODO: the remove method of an intrusive list should be idempotent, right? - if (bucket->messages.contains(&message->bucketNode)) { - if (message->numPackets > 1 && - message->getStatus() == OutMessage::Status::IN_PROGRESS) { - // Check to see if the message needs to be dequeued. - SpinLock::Lock lock_queue(queueMutex); - // Recheck state with lock in case it change right before this. - // FIXME: somehow I feel like I have seen the following code snippet a million times - // FIXME: why is this sendQueue stuff so complicated? - if (message->getStatus() == OutMessage::Status::IN_PROGRESS) { - QueuedMessageInfo* info = &message->queuedMessageInfo; - assert(sendQueue.contains(&info->sendQueueNode)); - sendQueue.remove(&info->sendQueueNode); - } - } - - bucket_lock.unlock(); - message->setStatus(OutMessage::Status::CANCELED); - // FIXME: who is responsible for removing this message from the bucket? - } - - // FIXME: why not change the entire method to the following: -// message->deschedule(); -// message->setStatus(OutMessage::Status::CANCELED); -} - -/** - * Inform the Sender that a Message is no longer needed. - * - * @param message - * The Sender::Message that is no longer needed. - */ -void -Sender::dropMessage(Sender::Message* message) -{ - MessageBucket* bucket = message->bucket; - SpinLock::Lock lock(bucket->mutex); - message->held = false; - Perf::counters.released_tx_messages.add(1); - if (message->getStatus() != OutMessage::Status::IN_PROGRESS) { - // Ok to delete immediately since we don't have to wait for the message - // to be sent. - message->destroy(lock); - } else { - // Defer deletion and wait for the message to be SENT. + bucket->messageTimeouts.setTimeout(&message->messageTimeout); + bucket->pingTimeouts.setTimeout(&message->pingTimeout); } } @@ -1002,29 +839,10 @@ Sender::checkMessageTimeouts(uint64_t now, MessageBucket* bucket) bucket_lock.unlock(); // Found expired timeout. - if (message->getStatus() != OutMessage::Status::COMPLETED) { - if (message->getStatus() == OutMessage::Status::IN_PROGRESS) { - // Check to see if the message needs to be dequeued. - SpinLock::Lock lock_queue(queueMutex); - // FIXME: why double-check? why does it even matter? - // Recheck state with lock in case it change right before this. - if (message->getStatus() == OutMessage::Status::IN_PROGRESS) { - QueuedMessageInfo* info = &message->queuedMessageInfo; - assert(sendQueue.contains(&info->sendQueueNode)); - sendQueue.remove(&info->sendQueueNode); - } - } - message->setStatus(OutMessage::Status::FAILED); - } else { - // TODO: double-check with Collin - SpinLock::Lock lock(bucket->mutex); - message->cancelTimeout(lock); - WARNING("SHOULDN'T BE HERE?"); - } - // FIXME: I don't understand this; if the message is completed, its - // timeouts should've been cancelled already, no? -// bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); -// bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); + OutMessage::Status status = message->getStatus(); + assert(status == OutMessage::Status::IN_PROGRESS || + status == OutMessage::Status::SENT); + message->setStatus(OutMessage::Status::FAILED, true); } } @@ -1047,7 +865,7 @@ Sender::checkPingTimeouts(uint64_t now, MessageBucket* bucket) } while (true) { - SpinLock::Lock lock(bucket->mutex); + SpinLock::UniqueLock bucket_lock(bucket->mutex); // No remaining timeouts. if (bucket->pingTimeouts.empty()) { break; @@ -1058,24 +876,24 @@ Sender::checkPingTimeouts(uint64_t now, MessageBucket* bucket) break; } // Found expired timeout. - if (message->getStatus() == OutMessage::Status::COMPLETED || - message->getStatus() == OutMessage::Status::FAILED) { - // FIXME: how is this possible? setStatus ensures that all timeouts - // will be cancelled when the status enters an end state + OutMessage::Status status = message->getStatus(); + assert(status == OutMessage::Status::IN_PROGRESS || + status == OutMessage::Status::SENT); + if (message->options & OutMessage::Options::NO_KEEP_ALIVE && + status == OutMessage::Status::SENT) { + bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); continue; - } else if (message->options & OutMessage::Options::NO_KEEP_ALIVE && - message->getStatus() == OutMessage::Status::SENT) { - message->cancelTimeout(lock); - continue; } else { - // TODO: can be change to the following to avoid calling setTimeout directly? - //message->resetTimeout(lock); bucket->pingTimeouts.setTimeout(&message->pingTimeout); } + // The following code doesn't access bucket data anymore; release the + // mutex to reduce the critical section. + bucket_lock.unlock(); + // Check if sender still has packets to send - if (message->getStatus() == OutMessage::Status::IN_PROGRESS) { + if (status == OutMessage::Status::IN_PROGRESS) { SpinLock::Lock lock_queue(queueMutex); QueuedMessageInfo* info = &message->queuedMessageInfo; if (info->packetsSent < info->packetsGranted) { @@ -1121,8 +939,6 @@ Sender::trySend() */ SpinLock::UniqueLock lock_queue(queueMutex); uint32_t queuedBytesEstimate = driver->getQueuedBytes(); - std::vector sentMessageIds; - sentMessageIds.reserve(32); // Optimistically assume we will finish sending every granted packet this // round; we will set again sendReady if it turns out we don't finish. sendReady = false; @@ -1160,9 +976,31 @@ Sender::trySend() } if (info->packetsSent >= info->packets->numPackets) { // We have finished sending the message. - sentMessageIds.push_back(info->id); - message.setStatus(OutMessage::Status::SENT); - it = sendQueue.remove(it); + + // Advance the iterator first to avoid invalidation. + ++it; + + // Unlock the queueMutex before setStatus() since our spinlock is + // non-reentrant. + lock_queue.unlock(); + message.setStatus(OutMessage::Status::SENT, true); + + if (!message.held.load(std::memory_order_acquire)) { + // Ok to delete now that the message has been sent. + message.destroy(); + } else if (message.options & OutMessage::Options::NO_KEEP_ALIVE) { + // No timeouts need to be checked after sending the message when + // the NO_KEEP_ALIVE option is enabled. + + // Note: we can't be holding queueMutex here because our locking + // principle dictates that any bucket mutex must be acquired + // before the send queueMutex. + MessageBucket* bucket = message.bucket; + SpinLock::Lock bucket_lock(bucket->mutex); + bucket->messageTimeouts.cancelTimeout(&message.messageTimeout); + bucket->pingTimeouts.cancelTimeout(&message.pingTimeout); + } + lock_queue.lock(); } else if (info->packetsSent >= info->packetsGranted) { // We have sent every granted packet. ++it; @@ -1175,28 +1013,6 @@ Sender::trySend() } sending.clear(); - // Unlock the queueMutex to process any SENT messages to ensure any bucket - // mutex is always acquired before the send queueMutex. - lock_queue.unlock(); - for (Protocol::MessageId& msgId : sentMessageIds) { - MessageBucket* bucket = messageBuckets.getBucket(msgId); - SpinLock::Lock lock(bucket->mutex); - Message* message = bucket->findMessage(msgId, lock); - if (message == nullptr) { - // Message must have already been deleted. - continue; - } - - if (!message->held) { - // Ok to delete now that the message has been sent. - message->destroy(lock); - } else if (message->options & OutMessage::Options::NO_KEEP_ALIVE) { - // No timeouts need to be checked after sending the message when - // the NO_KEEP_ALIVE option is enabled. - message->cancelTimeout(lock); - } - } - if (!idle) { Perf::counters.active_cycles.add(timer.split()); } diff --git a/src/Sender.h b/src/Sender.h index c2069e6..9ef5ea1 100644 --- a/src/Sender.h +++ b/src/Sender.h @@ -45,7 +45,7 @@ class Sender { explicit Sender(uint64_t transportId, Driver* driver, Policy::Manager* policyManager, uint64_t messageTimeoutCycles, uint64_t pingIntervalCycles); - virtual ~Sender(); + virtual ~Sender() = default; virtual Homa::OutMessage* allocMessage(uint16_t sourcePort); virtual void handleDonePacket(Driver::Packet* packet); @@ -60,7 +60,7 @@ class Sender { private: /// Forward declarations class Message; - class MessageBucket; + struct MessageBucket; Message* handleIncomingPacket(Driver::Packet* packet, bool resetTimeout); @@ -87,9 +87,7 @@ class Sender { * Message to which this metadata is associated. */ explicit QueuedMessageInfo(Message* message) - : id(0, 0) - , destination() - , packets(nullptr) + : packets(message) , unsentBytes(0) , packetsGranted(0) , priority(0) @@ -97,16 +95,10 @@ class Sender { , sendQueueNode(message) {} - /// Contains the unique identifier for this message. - Protocol::MessageId id; - - /// Contains destination address this message. - SocketAddress destination; - /// Handle to the queue Message for access to the packets that will /// be sent. This member documents that the packets are logically owned /// by the sendQueue and thus protected by the queueMutex. - Message* packets; + Message* const packets; /// The number of bytes that still need to be sent for a queued Message. int unsentBytes; @@ -128,8 +120,8 @@ class Sender { /** * Represents an outgoing message that can be sent. * - * Sender::Message objects are contained in the Transport::Op but should - * only be accessed by the Sender. + * TODO: document which part of the Message state are immutable, which part + * is thread-safe, and which part should be protected by mutex. */ class Message : public Homa::OutMessage { public: @@ -164,12 +156,9 @@ class Sender { virtual ~Message(); void append(const void* source, size_t count) override; void cancel() override; - void destroy(const SpinLock::Lock& lock); + void destroy(); Status getStatus() const override; - void setStatus(Status newStatus); - void deschedule(); - void resetTimeout(const SpinLock::Lock& lock); - void cancelTimeout(const SpinLock::Lock& lock); + void setStatus(Status newStatus, bool deschedule); size_t length() const override; void prepend(const void* source, size_t count) override; void release() override; @@ -179,7 +168,7 @@ class Sender { private: /// Define the maximum number of packets that a message can hold. - static const size_t MAX_MESSAGE_PACKETS = 1024; + static const int MAX_MESSAGE_PACKETS = 1024; Driver::Packet* getPacket(size_t index) const; Driver::Packet* getOrAllocPacket(size_t index); @@ -218,7 +207,7 @@ class Sender { /// True if a pointer to this message is accessible by the application /// (e.g. the message has been allocated via allocMessage() but has not /// been release via dropMessage()); false, otherwise. - bool held; + std::atomic held; /// First byte where data is or will go if empty. Must be constant after /// send() is invoked. @@ -323,6 +312,8 @@ class Sender { /// Collection of outbound messages Intrusive::List messages; + // FIXME: we should be able eliminate this field if messageTimeout is + // always a multiple of pingTimeout /// Maintains Message objects in increasing order of timeout. TimeoutManager messageTimeouts; @@ -396,11 +387,9 @@ class Sender { Protocol::MessageId::Hasher hasher; }; - void sendMessage(Sender::Message* message, SocketAddress destination, - Message::Options options = Message::Options::NONE); - void cancelMessage(Sender::Message* message); - void dropMessage(Sender::Message* message); - void checkMessageTimeouts(uint64_t now, MessageBucket* bucket); + void startMessage(Sender::Message* message, bool restart); + // FIXME: merge the following two methods + static void checkMessageTimeouts(uint64_t now, MessageBucket* bucket); void checkPingTimeouts(uint64_t now, MessageBucket* bucket); void trySend(); @@ -423,7 +412,10 @@ class Sender { /// Tracks all outbound messages being sent by the Sender. MessageBucketMap messageBuckets; - /// Protects the readyQueue. + // TODO: document the locking principle that if someone want to acquire both + // a bucket mutex and this queueMutex, the bucket mutex must be acquired first! + // TODO: why this principle? why not the reverse order? + /// Protects the sendQueue. SpinLock queueMutex; /// A list of outbound messages that have unsent packets. Messages are kept diff --git a/src/SenderTest.cc b/src/SenderTest.cc index b3beda9..07630a8 100644 --- a/src/SenderTest.cc +++ b/src/SenderTest.cc @@ -1575,7 +1575,6 @@ TEST_F(SenderTest, cancelMessage) EXPECT_TRUE(bucket->pingTimeouts.list.empty()); EXPECT_TRUE(sender->sendQueue.empty()); EXPECT_EQ(Homa::OutMessage::Status::CANCELED, message->state.load()); - // FIXME: shouldn't we check if the bucket is empty? } TEST_F(SenderTest, dropMessage_basic) diff --git a/src/TransportImpl.cc b/src/TransportImpl.cc index 165cc87..38ebddd 100644 --- a/src/TransportImpl.cc +++ b/src/TransportImpl.cc @@ -14,11 +14,7 @@ */ #include "TransportImpl.h" - #include -#include -#include - #include "Cycles.h" #include "Perf.h" #include "Protocol.h" From ee3db759c2a1c8d703e71bf47feb8a84ace02601 Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Thu, 15 Oct 2020 16:47:48 -0700 Subject: [PATCH 28/33] final cleanup on Sender --- include/Homa/Util.h | 13 +++ src/Drivers/DPDK/DpdkDriverImpl.cc | 23 ++---- src/Drivers/DPDK/DpdkDriverImpl.h | 3 - src/ObjectPool.h | 40 ++++++---- src/Sender.cc | 124 +++++++++-------------------- src/Sender.h | 49 +++--------- src/Timeout.h | 10 ++- 7 files changed, 103 insertions(+), 159 deletions(-) diff --git a/include/Homa/Util.h b/include/Homa/Util.h index 1c1b75a..954c206 100644 --- a/include/Homa/Util.h +++ b/include/Homa/Util.h @@ -70,6 +70,19 @@ isPowerOfTwo(num_type n) return (n > 0) && ((n & (n - 1)) == 0); } +/** + * Round up the result of x divided by y, where both x and y are positive + * integers. + */ +template +constexpr num_type +roundUpIntDiv(num_type x, num_type y) +{ + static_assert(std::is_integral::value, "Integral required."); + assert(x > 0 && y > 0); + return (x + y - 1) / y; +} + /** * This class is used to temporarily release lock in a safe fashion. Creating * an object of this class will unlock its associated mutex; when the object diff --git a/src/Drivers/DPDK/DpdkDriverImpl.cc b/src/Drivers/DPDK/DpdkDriverImpl.cc index b4372ad..82ff001 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.cc +++ b/src/Drivers/DPDK/DpdkDriverImpl.cc @@ -94,7 +94,6 @@ DpdkDriver::Impl::Impl(const char* ifname, int argc, char* argv[], (config == nullptr || config->HIGHEST_PACKET_PRIORITY_OVERRIDE < 0) ? Homa::Util::arrayLength(PRIORITY_TO_PCP) - 1 : config->HIGHEST_PACKET_PRIORITY_OVERRIDE) - , packetLock() , packetPool() , overflowBufferPool() , mbufsOutstanding(0) @@ -147,7 +146,6 @@ DpdkDriver::Impl::Impl(const char* ifname, (config == nullptr || config->HIGHEST_PACKET_PRIORITY_OVERRIDE < 0) ? Homa::Util::arrayLength(PRIORITY_TO_PCP) - 1 : config->HIGHEST_PACKET_PRIORITY_OVERRIDE) - , packetLock() , packetPool() , overflowBufferPool() , mbufPool(nullptr) @@ -179,7 +177,6 @@ Driver::Packet* DpdkDriver::Impl::allocPacket() { DpdkDriver::Impl::Packet* packet = nullptr; - SpinLock::Lock lock(packetLock); static const int MBUF_ALLOC_LIMIT = NB_MBUF - NB_MBUF_RESERVED; if (mbufsOutstanding < MBUF_ALLOC_LIMIT) { struct rte_mbuf* mbuf = rte_pktmbuf_alloc(mbufPool); @@ -424,17 +421,14 @@ DpdkDriver::Impl::receivePackets(uint32_t maxPackets, assert(length <= MAX_PAYLOAD_SIZE); DpdkDriver::Impl::Packet* packet = nullptr; - { - SpinLock::Lock lock(packetLock); - static const int MBUF_ALLOC_LIMIT = NB_MBUF - NB_MBUF_RESERVED; - if (mbufsOutstanding < MBUF_ALLOC_LIMIT) { - packet = packetPool.construct(m, payload); - mbufsOutstanding++; - } else { - OverflowBuffer* buf = overflowBufferPool.construct(); - rte_memcpy(payload, buf->data, length); - packet = packetPool.construct(buf); - } + static const int MBUF_ALLOC_LIMIT = NB_MBUF - NB_MBUF_RESERVED; + if (mbufsOutstanding < MBUF_ALLOC_LIMIT) { + packet = packetPool.construct(m, payload); + mbufsOutstanding++; + } else { + OverflowBuffer* buf = overflowBufferPool.construct(); + rte_memcpy(payload, buf->data, length); + packet = packetPool.construct(buf); } packet->base.length = length; @@ -451,7 +445,6 @@ void DpdkDriver::Impl::releasePackets(Driver::Packet* packets[], uint16_t numPackets) { for (uint16_t i = 0; i < numPackets; ++i) { - SpinLock::Lock lock(packetLock); DpdkDriver::Impl::Packet* packet = container_of(packets[i], DpdkDriver::Impl::Packet, base); if (likely(packet->bufType == DpdkDriver::Impl::Packet::MBUF)) { diff --git a/src/Drivers/DPDK/DpdkDriverImpl.h b/src/Drivers/DPDK/DpdkDriverImpl.h index 7305ce8..b8def23 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.h +++ b/src/Drivers/DPDK/DpdkDriverImpl.h @@ -184,9 +184,6 @@ class DpdkDriver::Impl { /// set by override). const int HIGHEST_PACKET_PRIORITY; - /// Protects access to the packetPool. - SpinLock packetLock; - /// Provides memory allocation for the DPDK specific implementation of a /// Driver Packet. ObjectPool packetPool; diff --git a/src/ObjectPool.h b/src/ObjectPool.h index 3b6a918..456c22f 100644 --- a/src/ObjectPool.h +++ b/src/ObjectPool.h @@ -16,11 +16,10 @@ #ifndef HOMA_OBJECTPOOL_H #define HOMA_OBJECTPOOL_H -#include "Homa/Exception.h" - -#include "Debug.h" - #include +#include "Debug.h" +#include "Homa/Exception.h" +#include "SpinLock.h" /* * Notes on performance and efficiency: @@ -52,6 +51,8 @@ namespace Homa { * new and delete an relatively fixed set of objects very quickly. * For example, transports use ObjectPool to allocate short-lived rpc * objects that cannot be kept in a stack context. + * + * This class is thread-safe. */ template class ObjectPool { @@ -99,24 +100,26 @@ class ObjectPool { template T* construct(Args&&... args) { - void* backing = NULL; - if (pool.size() == 0) { - backing = operator new(sizeof(T)); - } else { - backing = pool.back(); - pool.pop_back(); + void* backing = nullptr; + { + SpinLock::Lock lock(mutex); + if (pool.size() == 0) { + backing = operator new(sizeof(T)); + } else { + backing = pool.back(); + pool.pop_back(); + } + outstandingObjects++; } - T* object = NULL; try { - object = new (backing) T(static_cast(args)...); + return new (backing) T(static_cast(args)...); } catch (...) { + SpinLock::Lock lock(mutex); pool.push_back(backing); + outstandingObjects--; throw; } - - outstandingObjects++; - return object; } /** @@ -124,13 +127,18 @@ class ObjectPool { */ void destroy(T* object) { - assert(outstandingObjects > 0); object->~T(); + + SpinLock::Lock lock(mutex); + assert(outstandingObjects > 0); pool.push_back(static_cast(object)); outstandingObjects--; } private: + /// Monitor-style lock to protect the metadata of the pool. + SpinLock mutex; + /// Count of the number of objects for which construct() was called, but /// destroy() was not. uint64_t outstandingObjects; diff --git a/src/Sender.cc b/src/Sender.cc index 386ea0c..ee6c097 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -50,7 +50,9 @@ Sender::Sender(uint64_t transportId, Driver* driver, , policyManager(policyManager) , nextMessageSequenceNumber(1) , DRIVER_QUEUED_BYTE_LIMIT(2 * driver->getMaxPayloadSize()) - , messageBuckets(messageTimeoutCycles, pingIntervalCycles) + , MESSAGE_TIMEOUT_INTERVALS( + Util::roundUpIntDiv(messageTimeoutCycles, pingIntervalCycles)) + , messageBuckets(pingIntervalCycles) , queueMutex() , sendQueue() , sending() @@ -65,9 +67,8 @@ Sender::Sender(uint64_t transportId, Driver* driver, Homa::OutMessage* Sender::allocMessage(uint16_t sourcePort) { - SpinLock::Lock lock_allocator(messageAllocator.mutex); Perf::counters.allocated_tx_messages.add(1); - return messageAllocator.pool.construct(this, sourcePort); + return messageAllocator.construct(this, sourcePort); } /** @@ -95,7 +96,7 @@ Sender::handleIncomingPacket(Driver::Packet* packet, bool resetTimeout) SpinLock::Lock lock(bucket->mutex); Message* message = bucket->findMessage(msgId, lock); if (resetTimeout) { - bucket->messageTimeouts.setTimeout(&message->messageTimeout); + message->numPingTimeouts = 0; bucket->pingTimeouts.setTimeout(&message->pingTimeout); } return message; @@ -407,14 +408,29 @@ Sender::checkTimeouts() MessageBucket* bucket = &messageBuckets.buckets[index]; uint64_t now = PerfUtils::Cycles::rdtsc(); checkPingTimeouts(now, bucket); - checkMessageTimeouts(now, bucket); } /** - * Destruct a Message. Will release all contained Packet objects. + * Destruct a Message. + * + * This method will detach the message from the Sender and release all contained + * Packet objects. */ Sender::Message::~Message() { + Perf::counters.destroyed_tx_messages.add(1); + + // We assume that this message has been unlinked from the sendQueue before + // this method is invoked. + assert(getStatus() != OutMessage::Status::IN_PROGRESS); + + // Remove this message from the other data structures of the Sender. + { + SpinLock::Lock bucket_lock(bucket->mutex); + bucket->pingTimeouts.cancelTimeout(&pingTimeout); + bucket->messages.remove(&bucketNode); + } + // Sender message must be contiguous driver->releasePackets(packets, numPackets); } @@ -466,32 +482,6 @@ Sender::Message::cancel() setStatus(OutMessage::Status::CANCELED, true); } -/** - * Detach this message from the transport and destruct the Message object. - * - * Note: no one should access this message after the method returns. - */ -void -Sender::Message::destroy() -{ - // We assume that this message has been unlinked from the sendQueue before - // this method is invoked. - assert(getStatus() != OutMessage::Status::IN_PROGRESS); - - // Remove this message from the other data structures of the Sender. - { - SpinLock::Lock bucket_lock(bucket->mutex); - bucket->messageTimeouts.cancelTimeout(&messageTimeout); - bucket->pingTimeouts.cancelTimeout(&pingTimeout); - bucket->messages.remove(&bucketNode); - } - - // Destruct the Message object. - SpinLock::Lock lock_allocator(sender->messageAllocator.mutex); - sender->messageAllocator.pool.destroy(this); - Perf::counters.destroyed_tx_messages.add(1); -} - /** * @copydoc Homa::OutMessage::getStatus() */ @@ -537,7 +527,6 @@ Sender::Message::setStatus(Status newStatus, bool deschedule) newStatus == OutMessage::Status::COMPLETED || newStatus == OutMessage::Status::FAILED) { SpinLock::Lock lock(bucket->mutex); - bucket->messageTimeouts.cancelTimeout(&messageTimeout); bucket->pingTimeouts.cancelTimeout(&pingTimeout); } @@ -594,7 +583,7 @@ Sender::Message::release() if (getStatus() != OutMessage::Status::IN_PROGRESS) { // Ok to delete immediately since we don't have to wait for the message // to be sent. - destroy(); + sender->messageAllocator.destroy(this); } else { // Defer deletion and wait for the message to be SENT. } @@ -787,12 +776,10 @@ Sender::startMessage(Sender::Message* message, bool restart) Util::downCast(message->numPackets)); info->priority = policy.priority; info->packetsSent = 0; - // Insert and move message into the correct order in the priority - // queue. + // Insert and move message into the correct order in the priority queue. sendQueue.push_front(&info->sendQueueNode); - Intrusive::deprioritize( - &sendQueue, &info->sendQueueNode, - QueuedMessageInfo::ComparePriority()); + Intrusive::deprioritize(&sendQueue, &info->sendQueueNode, + QueuedMessageInfo::ComparePriority()); sendReady.store(true); } @@ -800,52 +787,11 @@ Sender::startMessage(Sender::Message* message, bool restart) if (needTimeouts) { MessageBucket* bucket = message->bucket; SpinLock::Lock lock(bucket->mutex); - bucket->messageTimeouts.setTimeout(&message->messageTimeout); + message->numPingTimeouts = 0; bucket->pingTimeouts.setTimeout(&message->pingTimeout); } } -/** - * Process any outbound messages in a given bucket that have timed out due to - * lack of activity from the Receiver. - * - * Pulled out of checkTimeouts() for ease of testing. - * - * @param now - * The rdtsc cycle that should be considered the "current" time. - * @param bucket - * The bucket whose message timeouts should be checked. - */ -void -Sender::checkMessageTimeouts(uint64_t now, MessageBucket* bucket) -{ - if (!bucket->messageTimeouts.anyElapsed(now)) { - return; - } - - while (true) { - SpinLock::UniqueLock bucket_lock(bucket->mutex); - // No remaining timeouts. - if (bucket->messageTimeouts.empty()) { - break; - } - Message* message = &bucket->messageTimeouts.front(); - // No remaining expired timeouts. - if (!message->messageTimeout.hasElapsed(now)) { - break; - } - - // Release the bucket mutex to avoid deadlock inside setStatus(). - bucket_lock.unlock(); - - // Found expired timeout. - OutMessage::Status status = message->getStatus(); - assert(status == OutMessage::Status::IN_PROGRESS || - status == OutMessage::Status::SENT); - message->setStatus(OutMessage::Status::FAILED, true); - } -} - /** * Process any outbound messages in a given bucket that need to be pinged to * ensure the message is kept alive by the receiver. @@ -881,11 +827,20 @@ Sender::checkPingTimeouts(uint64_t now, MessageBucket* bucket) status == OutMessage::Status::SENT); if (message->options & OutMessage::Options::NO_KEEP_ALIVE && status == OutMessage::Status::SENT) { - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); continue; } else { - bucket->pingTimeouts.setTimeout(&message->pingTimeout); + message->numPingTimeouts++; + if (message->numPingTimeouts >= MESSAGE_TIMEOUT_INTERVALS) { + // Found expired message. + + // Release the bucket mutex to avoid deadlock in setStatus(). + bucket_lock.unlock(); + message->setStatus(OutMessage::Status::FAILED, true); + continue; + } else { + bucket->pingTimeouts.setTimeout(&message->pingTimeout); + } } // The following code doesn't access bucket data anymore; release the @@ -987,7 +942,7 @@ Sender::trySend() if (!message.held.load(std::memory_order_acquire)) { // Ok to delete now that the message has been sent. - message.destroy(); + messageAllocator.destroy(&message); } else if (message.options & OutMessage::Options::NO_KEEP_ALIVE) { // No timeouts need to be checked after sending the message when // the NO_KEEP_ALIVE option is enabled. @@ -997,7 +952,6 @@ Sender::trySend() // before the send queueMutex. MessageBucket* bucket = message.bucket; SpinLock::Lock bucket_lock(bucket->mutex); - bucket->messageTimeouts.cancelTimeout(&message.messageTimeout); bucket->pingTimeouts.cancelTimeout(&message.pingTimeout); } lock_queue.lock(); diff --git a/src/Sender.h b/src/Sender.h index 9ef5ea1..e26b0ca 100644 --- a/src/Sender.h +++ b/src/Sender.h @@ -119,9 +119,6 @@ class Sender { /** * Represents an outgoing message that can be sent. - * - * TODO: document which part of the Message state are immutable, which part - * is thread-safe, and which part should be protected by mutex. */ class Message : public Homa::OutMessage { public: @@ -148,7 +145,7 @@ class Sender { // construction. See Message::occupied. , state(Status::NOT_STARTED) , bucketNode(this) - , messageTimeout(this) + , numPingTimeouts(0) , pingTimeout(this) , queuedMessageInfo(this) {} @@ -156,7 +153,6 @@ class Sender { virtual ~Message(); void append(const void* source, size_t count) override; void cancel() override; - void destroy(); Status getStatus() const override; void setStatus(Status newStatus, bool deschedule); size_t length() const override; @@ -239,10 +235,9 @@ class Sender { /// is protected by the associated MessageBucket::mutex; Intrusive::List::Node bucketNode; - /// Intrusive structure used by the Sender to keep track when the - /// sending of this message should be considered failed. Access to this + /// Number of ping timeouts that occurred in a row. Access to this /// structure is protected by the associated MessageBucket::mutex. - Timeout messageTimeout; + int numPingTimeouts; /// Intrusive structure used by the Sender to keep track when this /// message should be checked to ensure progress is still being made. @@ -268,18 +263,13 @@ class Sender { /** * MessageBucket constructor. * - * @param messageTimeoutCycles - * Number of cycles of inactivity to wait before a Message is - * considered failed. * @param pingIntervalCycles * Number of cycles of inactivity to wait between checking on the * liveness of a Message. */ - explicit MessageBucket(uint64_t messageTimeoutCycles, - uint64_t pingIntervalCycles) + explicit MessageBucket(uint64_t pingIntervalCycles) : mutex() , messages() - , messageTimeouts(messageTimeoutCycles) , pingTimeouts(pingIntervalCycles) {} @@ -312,11 +302,6 @@ class Sender { /// Collection of outbound messages Intrusive::List messages; - // FIXME: we should be able eliminate this field if messageTimeout is - // always a multiple of pingTimeout - /// Maintains Message objects in increasing order of timeout. - TimeoutManager messageTimeouts; - /// Maintains Message objects in increasing order of ping timeout. TimeoutManager pingTimeouts; }; @@ -346,21 +331,17 @@ class Sender { /** * MessageBucketMap constructor. * - * @param messageTimeoutCycles - * Number of cycles of inactivity to wait before a Message is - * considered failed. * @param pingIntervalCycles * Number of cycles of inactivity to wait between checking on the * liveness of a Message. */ - explicit MessageBucketMap(uint64_t messageTimeoutCycles, - uint64_t pingIntervalCycles) + explicit MessageBucketMap(uint64_t pingIntervalCycles) : buckets() , hasher() { buckets.reserve(NUM_BUCKETS); for (int i = 0; i < NUM_BUCKETS; ++i) { - buckets.emplace_back(messageTimeoutCycles, pingIntervalCycles); + buckets.emplace_back(pingIntervalCycles); } } @@ -388,8 +369,6 @@ class Sender { }; void startMessage(Sender::Message* message, bool restart); - // FIXME: merge the following two methods - static void checkMessageTimeouts(uint64_t now, MessageBucket* bucket); void checkPingTimeouts(uint64_t now, MessageBucket* bucket); void trySend(); @@ -409,13 +388,14 @@ class Sender { /// The maximum number of bytes that should be queued in the Driver. const uint32_t DRIVER_QUEUED_BYTE_LIMIT; + /// The number of ping timeouts to occur before declaring a message timeout. + const int MESSAGE_TIMEOUT_INTERVALS; + /// Tracks all outbound messages being sent by the Sender. MessageBucketMap messageBuckets; - // TODO: document the locking principle that if someone want to acquire both - // a bucket mutex and this queueMutex, the bucket mutex must be acquired first! - // TODO: why this principle? why not the reverse order? - /// Protects the sendQueue. + /// Protects the sendQueue. Locking principle: when a bucket mutex is also + /// required, it must be acquired before the sendQueue mutex. SpinLock queueMutex; /// A list of outbound messages that have unsent packets. Messages are kept @@ -438,12 +418,7 @@ class Sender { std::atomic nextBucketIndex; /// Used to allocate Message objects. - struct { - /// Protects the messageAllocator.pool - SpinLock mutex; - /// Pool allocator for Message objects. - ObjectPool pool; - } messageAllocator; + ObjectPool messageAllocator; }; } // namespace Core diff --git a/src/Timeout.h b/src/Timeout.h index 56710f4..7d17ae0 100644 --- a/src/Timeout.h +++ b/src/Timeout.h @@ -100,12 +100,16 @@ class TimeoutManager { * * @param timeout * The Timeout that should be scheduled. + * @param now + * Optionally provided "current" timestamp cycle time. Used to avoid + * unnecessary calls to PerfUtils::Cycles::rdtsc() if the current time + * is already available to the caller. */ - inline void setTimeout(Timeout* timeout) + inline void setTimeout(Timeout* timeout, + uint64_t now = PerfUtils::Cycles::rdtsc()) { list.remove(&timeout->node); - timeout->expirationCycleTime = - PerfUtils::Cycles::rdtsc() + timeoutIntervalCycles; + timeout->expirationCycleTime = now + timeoutIntervalCycles; list.push_back(&timeout->node); nextTimeout.store(list.front().expirationCycleTime, std::memory_order_relaxed); From c382ea30e5598cff259a6d14e1d3bfdc8722ca69 Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Fri, 16 Oct 2020 23:47:30 -0700 Subject: [PATCH 29/33] similar cleanups for Receiver --- include/Homa/Homa.h | 11 -- include/Homa/Util.h | 7 +- src/Intrusive.h | 34 ++-- src/Receiver.cc | 448 ++++++++++++++++++------------------------- src/Receiver.h | 193 +++++++++---------- src/Sender.cc | 19 +- src/Sender.h | 53 +++-- src/TransportImpl.cc | 2 + 8 files changed, 345 insertions(+), 422 deletions(-) diff --git a/include/Homa/Homa.h b/include/Homa/Homa.h index d5edf4a..62fd4b3 100644 --- a/include/Homa/Homa.h +++ b/include/Homa/Homa.h @@ -70,17 +70,6 @@ class InMessage { */ virtual void acknowledge() const = 0; - /** - * Returns true if the sender is no longer waiting for this message to be - * processed; false otherwise. - */ - virtual bool dropped() const = 0; - - /** - * Inform the sender that this message has failed to be processed. - */ - virtual void fail() const = 0; - /** * Get the contents of a specified range of bytes in the Message by * copying them into the provided destination memory region. diff --git a/include/Homa/Util.h b/include/Homa/Util.h index 954c206..72f0752 100644 --- a/include/Homa/Util.h +++ b/include/Homa/Util.h @@ -74,13 +74,14 @@ isPowerOfTwo(num_type n) * Round up the result of x divided by y, where both x and y are positive * integers. */ -template +template constexpr num_type -roundUpIntDiv(num_type x, num_type y) +roundUpIntDiv(num_type x, Y y) { static_assert(std::is_integral::value, "Integral required."); assert(x > 0 && y > 0); - return (x + y - 1) / y; + num_type yy = downCast(y); + return (x + yy - 1) / yy; } /** diff --git a/src/Intrusive.h b/src/Intrusive.h index fc3391c..47cddaa 100644 --- a/src/Intrusive.h +++ b/src/Intrusive.h @@ -486,26 +486,25 @@ class List { * @tparam ElementType * Type of the element held in the Intrusive::List. * @tparam Compare - * A weak strict ordering binary comparator for objects of ElementType. + * A weak strict ordering binary comparator for objects of ElementType + * which returns true when the first argument should be ordered before + * the second. The signature should be equivalent to the following: + * bool comp(const ElementType& a, const ElementType& b); * @param list * List that contains the element. * @parma node * Intrusive list node for the element that should be prioritized. - * @param comp - * Comparison function object which returns true when the first argument - * should be ordered before the second. The signature should be equivalent - * to the following: - * bool comp(const ElementType& a, const ElementType& b); */ -template +template void -prioritize(List* list, typename List::Node* node, - Compare comp) +prioritize(List* list, typename List::Node* node) { assert(list->contains(node)); auto it_node = list->get(node); auto it_pos = it_node; while (it_pos != list->begin()) { + Compare comp; if (!comp(*it_node, *std::prev(it_pos))) { // Found the correct location; just before it_pos. break; @@ -528,25 +527,24 @@ prioritize(List* list, typename List::Node* node, * @tparam ElementType * Type of the element held in the Intrusive::List. * @tparam Compare - * A weak strict ordering binary comparator for objects of ElementType. + * A weak strict ordering binary comparator for objects of ElementType + * which returns true when the first argument should be ordered before + * the second. The signature should be equivalent to the following: + * bool comp(const ElementType& a, const ElementType& b); * @param list * List that contains the element. * @parma node * Intrusive list node for the element that should be prioritized. - * @param comp - * Comparison function object which returns true when the first argument - * should be ordered before the second. The signature should be equivalent - * to the following: - * bool comp(const ElementType& a, const ElementType& b); */ -template +template void -deprioritize(List* list, typename List::Node* node, - Compare comp) +deprioritize(List* list, typename List::Node* node) { assert(list->contains(node)); auto it_node = list->get(node); auto it_pos = std::next(it_node); + Compare comp; while (it_pos != list->end()) { if (comp(*it_node, *it_pos)) { // Found the correct location; just before it_pos. diff --git a/src/Receiver.cc b/src/Receiver.cc index b2208d5..8cb948d 100644 --- a/src/Receiver.cc +++ b/src/Receiver.cc @@ -40,7 +40,9 @@ Receiver::Receiver(Driver* driver, Policy::Manager* policyManager, uint64_t messageTimeoutCycles, uint64_t resendIntervalCycles) : driver(driver) , policyManager(policyManager) - , messageBuckets(messageTimeoutCycles, resendIntervalCycles) + , MESSAGE_TIMEOUT_INTERVALS( + Util::roundUpIntDiv(messageTimeoutCycles, resendIntervalCycles)) + , messageBuckets(resendIntervalCycles) , schedulerMutex() , scheduledPeers() , receivedMessages() @@ -54,70 +56,83 @@ Receiver::Receiver(Driver* driver, Policy::Manager* policyManager, */ Receiver::~Receiver() { - schedulerMutex.lock(); + // To ensure that all resources of a Receiver can be freed correctly, it's + // the user's responsibility to ensure the following before destructing the + // Receiver: + // - The transport must have been taken "offline" so that no more incoming + // packets will arrive. + // - All completed incoming messages that are delivered to the application + // must have been returned back to the Receiver (so that they don't hold + // dangling pointers to the destructed Receiver afterwards). + // - There must be only one thread left that can hold a reference to the + // transport (the destructor is designed to run exactly once). + + // Technically speaking, the Receiver is designed in a way that a default + // destructor should be sufficient. However, for clarity and debugging + // purpose, we decided to write the cleanup procedure explicitly anyway. + + // Destruct all MessageBucket's and the Messages within. + messageBuckets.buckets.clear(); + + // Destruct all Peer's. Peer's must be removed from scheduledPeers first. scheduledPeers.clear(); peerTable.clear(); - receivedMessages.mutex.lock(); - receivedMessages.queue.clear(); - for (auto it = messageBuckets.buckets.begin(); - it != messageBuckets.buckets.end(); ++it) { - MessageBucket* bucket = *it; - bucket->mutex.lock(); - auto iit = bucket->messages.begin(); - while (iit != bucket->messages.end()) { - Message* message = &(*iit); - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->resendTimeouts.cancelTimeout(&message->resendTimeout); - iit = bucket->messages.remove(iit); - { - SpinLock::Lock lock_allocator(messageAllocator.mutex); - messageAllocator.pool.destroy(message); - } - } + + // Destruct all completed Messages that are not yet delivered. + for (auto& message : receivedMessages.queue) { + messageAllocator.destroy(&message); } } /** - * Process an incoming DATA packet. + * Execute the common processing logic that is shared among all incoming + * packets. * * @param packet - * The incoming packet to be processed. + * Incoming packet to be processed. + * @param createIfAbsent + * True if a new Message should be constructed when no matching message + * can be found. * @param sourceIp - * Source IP address of the packet. + * Source IP address of the packet. Only valid when createIfAbsent is true. * @return - * True if the Receiver decides to take ownership of the packet. False - * if the Receiver has no more use of this packet and it can be released - * to the driver. + * Pointer to the message targeted by the incoming packet, or nullptr if no + * matching message can be found. */ -bool -Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) +Receiver::Message* +Receiver::handleIncomingPacket(Driver::Packet* packet, bool createIfAbsent, + IpAddress* sourceIp) { - Protocol::Packet::DataHeader* header = - static_cast(packet->payload); - uint16_t dataHeaderLength = sizeof(Protocol::Packet::DataHeader); - Protocol::MessageId id = header->common.messageId; + // Find the message bucket. + Protocol::Packet::CommonHeader* commonHeader = + static_cast(packet->payload); + Protocol::MessageId msgId = commonHeader ->messageId; + MessageBucket* bucket = messageBuckets.getBucket(msgId); - MessageBucket* bucket = messageBuckets.getBucket(id); + // Acquire the bucket mutex to ensure that a new message can be constructed + // and inserted to the bucket atomically. SpinLock::Lock lock_bucket(bucket->mutex); - Message* message = bucket->findMessage(id, lock_bucket); - if (message == nullptr) { - // New message + + // Find the target message, or construct a new message if necessary. + Message* message = bucket->findMessage(msgId, lock_bucket); + if (message == nullptr && createIfAbsent) { + // Construct a new message + Protocol::Packet::DataHeader* header = + static_cast(packet->payload); int messageLength = header->totalLength; int numUnscheduledPackets = header->unscheduledIndexLimit; - { - SpinLock::Lock lock_allocator(messageAllocator.mutex); - SocketAddress srcAddress = { - .ip = sourceIp, .port = be16toh(header->common.prefix.sport)}; - message = messageAllocator.pool.construct( - this, driver, dataHeaderLength, messageLength, id, srcAddress, - numUnscheduledPackets); - Perf::counters.allocated_rx_messages.add(1); - } - + SocketAddress srcAddress = { + .ip = *sourceIp, .port = be16toh(header->common.prefix.sport)}; + message = messageAllocator.construct( + this, driver, sizeof(Protocol::Packet::DataHeader), messageLength, + header->common.messageId, srcAddress, numUnscheduledPackets); + Perf::counters.allocated_rx_messages.add(1); + + // Start tracking the message. bucket->messages.push_back(&message->bucketNode); + policyManager->signalNewMessage( message->source.ip, header->policyVersion, header->totalLength); - if (message->scheduled) { // Message needs to be scheduled. SpinLock::Lock lock_scheduler(schedulerMutex); @@ -125,50 +140,85 @@ Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) } } - // Things that must be true (sanity check) - assert(id == message->id); - assert(message->driver == driver); + // The sender is still alive; reschedule the timeout. + if (message != nullptr) { + // Optimization: for single-packet messages, it would be wasteful to set + // a resendTimeout just to cancel it immediately. + bool needTimeout = (message->numExpectedPackets > 1); + + // If the message is not in progress, its resend timeout must have been + // cancelled in setState(); don't re-insert a new one. + if (needTimeout && message->getState() == Message::State::IN_PROGRESS) { + message->numResendTimeouts = 0; + bucket->resendTimeouts.setTimeout(&message->resendTimeout); + } + } + return message; +} + +/** + * Process an incoming DATA packet. + * + * @param packet + * The incoming packet to be processed. + * @param sourceIp + * Source IP address of the packet. + * @return + * True if the Receiver decides to take ownership of the packet. False + * if the Receiver has no more use of this packet and it can be released + * to the driver. + */ +bool +Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) +{ + // Find the message + Message* message = handleIncomingPacket(packet, true, &sourceIp); + Protocol::Packet::DataHeader* header = + static_cast(packet->payload); + + // Sanity checks assert(message->source.ip == sourceIp); assert(message->source.port == be16toh(header->common.prefix.sport)); assert(message->messageLength == Util::downCast(header->totalLength)); // Add the packet bool packetAdded = message->setPacket(header->index, packet); - if (packetAdded) { - // Update schedule for scheduled messages. - if (message->scheduled) { - SpinLock::Lock lock_scheduler(schedulerMutex); - ScheduledMessageInfo* info = &message->scheduledMessageInfo; - // Update the schedule if the message is still being scheduled - // (i.e. still linked to a scheduled peer). - if (info->peer != nullptr) { - int packetDataBytes = - packet->length - message->TRANSPORT_HEADER_LENGTH; - assert(info->bytesRemaining >= packetDataBytes); - info->bytesRemaining -= packetDataBytes; - updateSchedule(message, lock_scheduler); - } + if (!packetAdded) { + // Must be a duplicate packet; drop it. + return false; + } + + // Update schedule for scheduled messages. + if (message->scheduled) { + SpinLock::Lock lock_scheduler(schedulerMutex); + ScheduledMessageInfo* info = &message->scheduledMessageInfo; + // Update the schedule if the message is still being scheduled + // (i.e. still linked to a scheduled peer). + if (info->peer != nullptr) { + int packetDataBytes = + packet->length - message->TRANSPORT_HEADER_LENGTH; + assert(info->bytesRemaining >= packetDataBytes); + info->bytesRemaining -= packetDataBytes; + updateSchedule(message, lock_scheduler); } + } - // Receiving a new packet means the message is still active so it - // shouldn't time out until a while later. - bucket->messageTimeouts.setTimeout(&message->messageTimeout); - if (message->numPackets < message->numExpectedPackets) { - // Still waiting for more packets to arrive but the arrival of a - // new packet means we should wait a while longer before requesting - // RESENDs of the missing packets. - bucket->resendTimeouts.setTimeout(&message->resendTimeout); - } else { - // All message packets have been received. - message->state.store(Message::State::COMPLETED); + // Complete the message if all packets have been received. + if (message->numPackets == message->numExpectedPackets) { + message->state.store(Message::State::COMPLETED, + std::memory_order_release); + // Optimization: for single-packet messages, there is no need to cancel + // the resendTimeout if we don't insert one in the first place. + if (message->numPackets > 1) { + MessageBucket* bucket = message->bucket; + SpinLock::Lock lock(bucket->mutex); bucket->resendTimeouts.cancelTimeout(&message->resendTimeout); - SpinLock::Lock lock_received_messages(receivedMessages.mutex); - receivedMessages.queue.push_back(&message->receivedMessageNode); - Perf::counters.received_rx_messages.add(1); } - } else { - // must be a duplicate packet; drop packet. - return false; + + // Deliver the message to the user of the transport. + SpinLock::Lock lock_received_messages(receivedMessages.mutex); + receivedMessages.queue.push_back(&message->receivedMessageNode); + Perf::counters.received_rx_messages.add(1); } return true; } @@ -182,21 +232,7 @@ Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) void Receiver::handleBusyPacket(Driver::Packet* packet) { - Protocol::Packet::BusyHeader* header = - static_cast(packet->payload); - Protocol::MessageId id = header->common.messageId; - - MessageBucket* bucket = messageBuckets.getBucket(id); - SpinLock::Lock lock_bucket(bucket->mutex); - Message* message = bucket->findMessage(id, lock_bucket); - if (message != nullptr) { - // Sender has replied BUSY to our RESEND request; consider this message - // still active. - bucket->messageTimeouts.setTimeout(&message->messageTimeout); - if (message->state == Message::State::IN_PROGRESS) { - bucket->resendTimeouts.setTimeout(&message->resendTimeout); - } - } + handleIncomingPacket(packet, false, nullptr); } /** @@ -210,17 +246,8 @@ Receiver::handleBusyPacket(Driver::Packet* packet) void Receiver::handlePingPacket(Driver::Packet* packet, IpAddress sourceIp) { - Protocol::Packet::PingHeader* header = - static_cast(packet->payload); - Protocol::MessageId id = header->common.messageId; - - MessageBucket* bucket = messageBuckets.getBucket(id); - SpinLock::Lock lock_bucket(bucket->mutex); - Message* message = bucket->findMessage(id, lock_bucket); + Message* message = handleIncomingPacket(packet, false, nullptr); if (message != nullptr) { - // Sender is checking on this message; consider it still active. - bucket->messageTimeouts.setTimeout(&message->messageTimeout); - // We are here either because a GRANT got lost, or we haven't issued a // GRANT in along time. Send out the latest GRANT if one exists or just // an "empty" GRANT to let the Sender know we are aware of the message. @@ -247,8 +274,10 @@ Receiver::handlePingPacket(Driver::Packet* packet, IpAddress sourceIp) // We are here because we have no knowledge of the message the Sender is // asking about. Reply UNKNOWN so the Sender can react accordingly. Perf::counters.tx_unknown_pkts.add(1); - ControlPacket::send(driver, sourceIp, - id); + Protocol::Packet::CommonHeader* header = + static_cast(packet->payload); + ControlPacket::send( + driver, sourceIp, header->messageId); } } @@ -260,8 +289,6 @@ Receiver::handlePingPacket(Driver::Packet* packet, IpAddress sourceIp) * * @return * A new Message which has been received, if available; otherwise, nullptr. - * - * @sa dropMessage() */ Homa::InMessage* Receiver::receiveMessage() @@ -298,17 +325,42 @@ Receiver::checkTimeouts() { uint index = nextBucketIndex.fetch_add(1, std::memory_order_relaxed) & MessageBucketMap::HASH_KEY_MASK; - MessageBucket* bucket = messageBuckets.buckets.at(index); + MessageBucket* bucket = &messageBuckets.buckets[index]; uint64_t now = PerfUtils::Cycles::rdtsc(); checkResendTimeouts(now, bucket); - checkMessageTimeouts(now, bucket); } /** - * Destruct a Message. Will release all contained Packet objects. + * Destruct a Message. + * + * This method will detach the message from the transport and release all + * contained Packet objects. */ Receiver::Message::~Message() { + Perf::counters.destroyed_rx_messages.add(1); + Receiver* receiver = bucket->receiver; + + // Unschedule the message if it is still scheduled (i.e. still linked to a + // scheduled peer). + if (scheduled) { + SpinLock::Lock lock_scheduler(receiver->schedulerMutex); + ScheduledMessageInfo* info = &scheduledMessageInfo; + if (info->peer != nullptr) { + receiver->unschedule(this, lock_scheduler); + } + } + + // Remove this message from the other data structures of the Receiver. + { + SpinLock::Lock bucket_lock(bucket->mutex); + bucket->resendTimeouts.cancelTimeout(&resendTimeout); + bucket->messages.remove(&bucketNode); + + SpinLock::Lock receive_lock(receiver->receivedMessages.mutex); + receiver->receivedMessages.queue.remove(&receivedMessageNode); + } + // Find contiguous ranges of packets and release them back to the // driver. int num = 0; @@ -342,33 +394,10 @@ Receiver::Message::~Message() void Receiver::Message::acknowledge() const { - MessageBucket* bucket = receiver->messageBuckets.getBucket(id); - SpinLock::Lock lock(bucket->mutex); Perf::counters.tx_done_pkts.add(1); ControlPacket::send(driver, source.ip, id); } -/** - * @copydoc Homa::InMessage::dropped() - */ -bool -Receiver::Message::dropped() const -{ - return state.load() == State::DROPPED; -} - -/** - * @copydoc See Homa::InMessage::fail() - */ -void -Receiver::Message::fail() const -{ - MessageBucket* bucket = receiver->messageBuckets.getBucket(id); - SpinLock::Lock lock(bucket->mutex); - Perf::counters.tx_error_pkts.add(1); - ControlPacket::send(driver, source.ip, id); -} - /** * @copydoc Homa::InMessage::get() */ @@ -438,7 +467,7 @@ Receiver::Message::strip(size_t count) void Receiver::Message::release() { - receiver->dropMessage(this); + bucket->receiver->messageAllocator.destroy(this); } /** @@ -487,110 +516,6 @@ Receiver::Message::setPacket(size_t index, Driver::Packet* packet) return true; } -/** - * Inform the Receiver that an Message returned by receiveMessage() is not - * needed and can be dropped. - * - * @param message - * Message which will be dropped. - */ -void -Receiver::dropMessage(Receiver::Message* message) -{ - Protocol::MessageId msgId = message->id; - MessageBucket* bucket = messageBuckets.getBucket(msgId); - SpinLock::Lock lock_bucket(bucket->mutex); - Message* foundMessage = bucket->findMessage(msgId, lock_bucket); - if (foundMessage != nullptr) { - assert(message == foundMessage); - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->resendTimeouts.cancelTimeout(&message->resendTimeout); - if (message->scheduled) { - // Unschedule the message if it is still scheduled (i.e. still - // linked to a scheduled peer). - SpinLock::Lock lock_scheduler(schedulerMutex); - ScheduledMessageInfo* info = &message->scheduledMessageInfo; - if (info->peer != nullptr) { - unschedule(message, lock_scheduler); - } - } - bucket->messages.remove(&message->bucketNode); - { - SpinLock::Lock lock_allocator(messageAllocator.mutex); - messageAllocator.pool.destroy(message); - Perf::counters.destroyed_rx_messages.add(1); - } - } -} - -/** - * Process any inbound messages that have timed out due to lack of activity from - * the Sender. - * - * - * Pulled out of checkTimeouts() for ease of testing. - * - * @param now - * The rdtsc cycle that should be considered the "current" time. - * @param bucket - * The bucket whose message timeouts should be checked. - */ -void -Receiver::checkMessageTimeouts(uint64_t now, MessageBucket* bucket) -{ - if (!bucket->messageTimeouts.anyElapsed(now)) { - return; - } - - while (true) { - SpinLock::Lock lock_bucket(bucket->mutex); - - // No remaining timeouts. - if (bucket->messageTimeouts.empty()) { - break; - } - - Message* message = &bucket->messageTimeouts.front(); - - // No remaining expired timeouts. - if (!message->messageTimeout.hasElapsed(now)) { - break; - } - - // Found expired timeout. - - // Cancel timeouts - bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); - bucket->resendTimeouts.cancelTimeout(&message->resendTimeout); - - if (message->state == Message::State::IN_PROGRESS) { - // Message timed out before being fully received; drop the - // message. - - // Unschedule the message - if (message->scheduled) { - // Unschedule the message if it is still scheduled (i.e. - // still linked to a scheduled peer). - SpinLock::Lock lock_scheduler(schedulerMutex); - ScheduledMessageInfo* info = &message->scheduledMessageInfo; - if (info->peer != nullptr) { - unschedule(message, lock_scheduler); - } - } - - bucket->messages.remove(&message->bucketNode); - { - SpinLock::Lock lock_allocator(messageAllocator.mutex); - messageAllocator.pool.destroy(message); - } - } else { - // Message timed out but we already made it available to the - // Transport; let the Transport know. - message->state.store(Message::State::DROPPED); - } - } -} - /** * Process any inbound messages that may need to issue resends. * @@ -616,16 +541,22 @@ Receiver::checkResendTimeouts(uint64_t now, MessageBucket* bucket) break; } - Message* message = &bucket->resendTimeouts.front(); - // No remaining expired timeouts. + Message* message = &bucket->resendTimeouts.front(); if (!message->resendTimeout.hasElapsed(now)) { break; } // Found expired timeout. - assert(message->state == Message::State::IN_PROGRESS); - bucket->resendTimeouts.setTimeout(&message->resendTimeout); + assert(message->getState() == Message::State::IN_PROGRESS); + message->numResendTimeouts++; + if (message->numResendTimeouts >= MESSAGE_TIMEOUT_INTERVALS) { + // Message timed out before being fully received; drop the message. + messageAllocator.destroy(message); + continue; + } else { + bucket->resendTimeouts.setTimeout(&message->resendTimeout); + } // This Receiver expected to have heard from the Sender within the // last timeout period but it didn't. Request a resend of granted @@ -644,9 +575,8 @@ Receiver::checkResendTimeouts(uint64_t now, MessageBucket* bucket) continue; } else if (grantIndexLimit * message->PACKET_DATA_LENGTH < info->bytesGranted) { - grantIndexLimit = - (info->bytesGranted + message->PACKET_DATA_LENGTH - 1) / - message->PACKET_DATA_LENGTH; + grantIndexLimit = Util::roundUpIntDiv( + info->bytesGranted, message->PACKET_DATA_LENGTH); } } @@ -795,22 +725,19 @@ Receiver::schedule(Receiver::Message* message, const SpinLock::Lock& lock) // Insert the Message peer->scheduledMessages.push_front(&info->scheduledMessageNode); Intrusive::deprioritize(&peer->scheduledMessages, - &info->scheduledMessageNode, - ScheduledMessageInfo::ComparePriority()); + &info->scheduledMessageNode); info->peer = peer; if (!scheduledPeers.contains(&peer->scheduledPeerNode)) { // Must be the only message of this peer; push the peer to the // end of list to be moved later. assert(peer->scheduledMessages.size() == 1); scheduledPeers.push_front(&peer->scheduledPeerNode); - Intrusive::deprioritize(&scheduledPeers, &peer->scheduledPeerNode, - Peer::ComparePriority()); + Intrusive::deprioritize(&scheduledPeers, + &peer->scheduledPeerNode); } else if (&info->peer->scheduledMessages.front() == message) { // Update the Peer's position in the queue since the new message is the // peer's first scheduled message. - Intrusive::prioritize(&scheduledPeers, - &info->peer->scheduledPeerNode, - Peer::ComparePriority()); + Intrusive::prioritize(&scheduledPeers, &peer->scheduledPeerNode); } else { // The peer's first scheduled message did not change. Nothing to do. } @@ -844,7 +771,8 @@ Receiver::unschedule(Receiver::Message* message, const SpinLock::Lock& lock) // Cleanup the schedule if (peer->scheduledMessages.empty()) { - // Remove the empty peer. + // Remove the empty peer from the schedule (the peer object is still + // alive). scheduledPeers.remove(it); } else if (std::next(it) == scheduledPeers.end() || !comp(*std::next(it), *it)) { @@ -853,8 +781,8 @@ Receiver::unschedule(Receiver::Message* message, const SpinLock::Lock& lock) // since removing a message cannot increase the peer's priority. } else { // Peer needs to be moved. - Intrusive::deprioritize(&scheduledPeers, &peer->scheduledPeerNode, - comp); + Intrusive::deprioritize(&scheduledPeers, + &peer->scheduledPeerNode); } } @@ -880,15 +808,13 @@ Receiver::updateSchedule(Receiver::Message* message, const SpinLock::Lock& lock) // Update the message's position within its Peer scheduled message queue. Intrusive::prioritize(&info->peer->scheduledMessages, - &info->scheduledMessageNode, - ScheduledMessageInfo::ComparePriority()); + &info->scheduledMessageNode); // Update the Peer's position in the queue if this message is now the first // scheduled message. if (&info->peer->scheduledMessages.front() == message) { Intrusive::prioritize(&scheduledPeers, - &info->peer->scheduledPeerNode, - Peer::ComparePriority()); + &info->peer->scheduledPeerNode); } } diff --git a/src/Receiver.h b/src/Receiver.h index e3eac22..4bc64bd 100644 --- a/src/Receiver.h +++ b/src/Receiver.h @@ -58,24 +58,16 @@ class Receiver { private: // Forward declaration class Message; + struct MessageBucket; struct Peer; + Message* handleIncomingPacket(Driver::Packet* packet, bool createIfAbsent, + IpAddress* sourceIp = nullptr); + /** * Contains metadata for a Message that requires additional GRANTs. */ struct ScheduledMessageInfo { - /** - * Implements a binary comparison function for the strict weak priority - * ordering of two Message objects. - */ - struct ComparePriority { - bool operator()(const Message& a, const Message& b) - { - return a.scheduledMessageInfo.bytesRemaining < - b.scheduledMessageInfo.bytesRemaining; - } - }; - /** * ScheduledMessageInfo constructor. * @@ -121,6 +113,18 @@ class Receiver { */ class Message : public Homa::InMessage { public: + /** + * Implements a binary comparison function for the strict weak priority + * ordering of two Message objects. + */ + struct ComparePriority { + bool operator()(const Message& a, const Message& b) + { + return a.scheduledMessageInfo.bytesRemaining < + b.scheduledMessageInfo.bytesRemaining; + } + }; + /** * Defines the possible states of this Message. */ @@ -128,23 +132,21 @@ class Receiver { IN_PROGRESS, //< Receiver is in the process of receiving this // message. COMPLETED, //< Receiver has received the entire message. - DROPPED, //< Message was COMPLETED but the Receiver has lost - //< communication with the Sender. }; explicit Message(Receiver* receiver, Driver* driver, - size_t packetHeaderLength, size_t messageLength, + size_t packetHeaderLength, int messageLength, Protocol::MessageId id, SocketAddress source, int numUnscheduledPackets) - : receiver(receiver) - , driver(driver) + : driver(driver) , id(id) + , bucket(receiver->messageBuckets.getBucket(id)) , source(source) , TRANSPORT_HEADER_LENGTH(packetHeaderLength) , PACKET_DATA_LENGTH(driver->getMaxPayloadSize() - TRANSPORT_HEADER_LENGTH) - , numExpectedPackets((messageLength + PACKET_DATA_LENGTH - 1) / - PACKET_DATA_LENGTH) + , numExpectedPackets( + Util::roundUpIntDiv(messageLength, PACKET_DATA_LENGTH)) , numUnscheduledPackets(numUnscheduledPackets) , scheduled(numExpectedPackets > numUnscheduledPackets) , start(0) @@ -156,27 +158,25 @@ class Receiver { , state(Message::State::IN_PROGRESS) , bucketNode(this) , receivedMessageNode(this) - , messageTimeout(this) + , numResendTimeouts(0) , resendTimeout(this) , scheduledMessageInfo(this, messageLength) {} virtual ~Message(); - virtual void acknowledge() const; - virtual bool dropped() const; - virtual void fail() const; - virtual size_t get(size_t offset, void* destination, - size_t count) const; - virtual size_t length() const; - virtual void strip(size_t count); - virtual void release(); + void acknowledge() const override; + size_t get(size_t offset, void* destination, + size_t count) const override; + size_t length() const override; + void strip(size_t count) override; + void release() override; /** * Return the current state of this message. */ State getState() const { - return state.load(); + return state.load(std::memory_order_acquire); } private: @@ -186,9 +186,6 @@ class Receiver { Driver::Packet* getPacket(size_t index) const; bool setPacket(size_t index, Driver::Packet* packet); - /// The Receiver responsible for this message. - Receiver* const receiver; - /// Driver from which packets were received and to which they should be /// returned when this message is no longer needed. Driver* const driver; @@ -196,6 +193,9 @@ class Receiver { /// Contains the unique identifier for this message. const Protocol::MessageId id; + /// Message bucket this message belongs to. + MessageBucket* const bucket; + /// Contains source address this message. const SocketAddress source; @@ -245,9 +245,9 @@ class Receiver { /// message when it has been completely received. Intrusive::List::Node receivedMessageNode; - /// Intrusive structure used by the Receiver to keep track when the - /// receiving of this message should be considered failed. - Timeout messageTimeout; + /// Number of resend timeouts that occurred in a row. Access to this + /// structure is protected by the associated MessageBucket::mutex. + int numResendTimeouts; /// Intrusive structure used by the Receiver to keep track when /// unreceived parts of this message should be re-requested. @@ -271,22 +271,33 @@ class Receiver { /** * MessageBucket constructor. * - * @param messageTimeoutCycles - * Number of cycles of inactivity to wait before a Message is - * considered failed. + * @param receiver + * Receiver that owns this bucket. * @param resendIntervalCycles * Number of cycles of inactivity to wait between requesting * retransmission of un-received parts of a Message. * liveness of a Message. */ - explicit MessageBucket(uint64_t messageTimeoutCycles, + explicit MessageBucket(Receiver* receiver, uint64_t resendIntervalCycles) - : mutex() + : receiver(receiver) + , mutex() , messages() - , messageTimeouts(messageTimeoutCycles) , resendTimeouts(resendIntervalCycles) {} + /** + * Destruct a MessageBucket. Will destroy all contained Message objects. + */ + ~MessageBucket() + { + // Intrusive::List is not responsible for destructing its elements; + // it must be done manually. + for (auto& message : messages) { + receiver->messageAllocator.destroy(&message); + } + } + /** * Return the Message with the given MessageId. * @@ -302,26 +313,24 @@ class Receiver { const SpinLock::Lock& lock) { (void)lock; - Message* message = nullptr; - for (auto it = messages.begin(); it != messages.end(); ++it) { - if (it->id == msgId) { - message = &(*it); - break; + for (auto& it : messages) { + if (it.id == msgId) { + return ⁢ } } - return message; + return nullptr; } + /// The Receiver that owns this bucket. + Receiver* const receiver; + /// Mutex protecting the contents of this bucket. SpinLock mutex; /// Collection of inbound messages Intrusive::List messages; - /// Maintains Message objects in increasing order of timeout. - TimeoutManager messageTimeouts; - - /// Maintains Message object in increase order of resend timeout. + /// Maintains Message object in increasing order of resend timeout. TimeoutManager resendTimeouts; }; @@ -348,66 +357,41 @@ class Receiver { static_assert(NUM_BUCKETS == HASH_KEY_MASK + 1); /** - * Helper method to create the set of buckets. + * MessageBucketMap constructor. * - * @param messageTimeoutCycles - * Number of cycles of inactivity to wait before a Message is - * considered failed. * @param resendIntervalCycles * Number of cycles of inactivity to wait between requesting * retransmission of un-received parts of a Message. * liveness of a Message. */ - static std::array makeBuckets( - uint64_t messageTimeoutCycles, uint64_t resendIntervalCycles) + explicit MessageBucketMap(uint64_t resendIntervalCycles) + : buckets() + , hasher() { - std::array buckets; + buckets.reserve(NUM_BUCKETS); for (int i = 0; i < NUM_BUCKETS; ++i) { - buckets[i] = new MessageBucket(messageTimeoutCycles, - resendIntervalCycles); + buckets.emplace_back(resendIntervalCycles); } - return buckets; } - /** - * MessageBucketMap constructor. - * - * @param messageTimeoutCycles - * Number of cycles of inactivity to wait before a Message is - * considered failed. - * @param resendIntervalCycles - * Number of cycles of inactivity to wait between requesting - * retransmission of un-received parts of a Message. - * liveness of a Message. - */ - explicit MessageBucketMap(uint64_t messageTimeoutCycles, - uint64_t resendIntervalCycles) - : buckets(makeBuckets(messageTimeoutCycles, resendIntervalCycles)) - , hasher() - {} - /** * MessageBucketMap destructor. */ - ~MessageBucketMap() - { - for (int i = 0; i < NUM_BUCKETS; ++i) { - delete buckets[i]; - } - } + ~MessageBucketMap() = default; /** * Return the MessageBucket that should hold a Message with the given * MessageId. */ - MessageBucket* getBucket(const Protocol::MessageId& msgId) const + MessageBucket* getBucket(const Protocol::MessageId& msgId) { uint index = hasher(msgId) & HASH_KEY_MASK; - return buckets[index]; + return &buckets[index]; } - /// Array of buckets. - std::array const buckets; + /// Array of NUM_BUCKETS buckets. Defined as a vector to avoid the need + /// for a default constructor in MessageBucket. + std::vector buckets; /// MessageId hash function container. Protocol::MessageId::Hasher hasher; @@ -415,6 +399,9 @@ class Receiver { /** * Holds the incoming scheduled messages from another transport. + * + * The lifetime of a Peer is the same as Receiver: we never destruct Peer + * objects when the transport is running. */ struct Peer { /** @@ -426,11 +413,17 @@ class Receiver { {} /** - * Peer destructor. + * Peer destructor. Only invoked from the destructor of Receiver. */ ~Peer() { - scheduledMessages.clear(); + // By the time we need to destruct a peer, all Message's coming from + // it should have been released. + assert(scheduledMessages.empty()); + + // To keep Peer (constructor) simple, we don't store a reference to + // the outer Receiver in Peer. As a result, Receiver is responsible + // for clearing schedulerPeers. } /** @@ -442,7 +435,7 @@ class Receiver { { assert(!a.scheduledMessages.empty()); assert(!b.scheduledMessages.empty()); - ScheduledMessageInfo::ComparePriority comp; + Message::ComparePriority comp; return comp(a.scheduledMessages.front(), b.scheduledMessages.front()); } @@ -454,8 +447,6 @@ class Receiver { Intrusive::List::Node scheduledPeerNode; }; - void dropMessage(Receiver::Message* message); - void checkMessageTimeouts(uint64_t now, MessageBucket* bucket); void checkResendTimeouts(uint64_t now, MessageBucket* bucket); void trySendGrants(); void schedule(Message* message, const SpinLock::Lock& lock); @@ -463,12 +454,17 @@ class Receiver { void updateSchedule(Message* message, const SpinLock::Lock& lock); /// Driver with which all packets will be sent and received. This driver - /// is chosen by the Transport that owns this Sender. + /// is chosen by the Transport that owns this Receiver. Driver* const driver; - /// Provider of network packet priority and grant policy decisions. + /// Provider of network packet priority and grant policy decisions. Not + /// owned by this class. Policy::Manager* const policyManager; + /// The number of resend timeouts to occur before declaring a message + /// timeout. + const int MESSAGE_TIMEOUT_INTERVALS; + /// Tracks the set of inbound messages being received by this Receiver. MessageBucketMap messageBuckets; @@ -504,12 +500,7 @@ class Receiver { std::atomic nextBucketIndex; /// Used to allocate Message objects. - struct { - /// Protects the messageAllocator.pool - SpinLock mutex; - /// Pool from which Message objects can be allocated. - ObjectPool pool; - } messageAllocator; + ObjectPool messageAllocator; }; } // namespace Core diff --git a/src/Sender.cc b/src/Sender.cc index ee6c097..35544be 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -52,7 +52,7 @@ Sender::Sender(uint64_t transportId, Driver* driver, , DRIVER_QUEUED_BYTE_LIMIT(2 * driver->getMaxPayloadSize()) , MESSAGE_TIMEOUT_INTERVALS( Util::roundUpIntDiv(messageTimeoutCycles, pingIntervalCycles)) - , messageBuckets(pingIntervalCycles) + , messageBuckets(this, pingIntervalCycles) , queueMutex() , sendQueue() , sending() @@ -266,9 +266,8 @@ Sender::handleGrantPacket(Driver::Packet* packet) // that holds the last granted byte is also considered granted. This // can cause at most 1 packet worth of data to be sent without a grant // but allows the sender to always send full packets. - int incomingGrantIndex = - (grantHeader->byteLimit + info->packets->PACKET_DATA_LENGTH - 1) / - info->packets->PACKET_DATA_LENGTH; + int incomingGrantIndex = Util::roundUpIntDiv( + grantHeader->byteLimit, info->packets->PACKET_DATA_LENGTH); // Make that grants don't exceed the number of packets. Internally, // the sender always assumes that packetsGranted <= numPackets. @@ -705,9 +704,8 @@ Sender::startMessage(Sender::Message* message, bool restart) // Get the current policy for unscheduled bytes. Policy::Unscheduled policy = policyManager->getUnscheduledPolicy( message->destination.ip, message->messageLength); - uint16_t unscheduledIndexLimit = - ((policy.unscheduledByteLimit + message->PACKET_DATA_LENGTH - 1) / - message->PACKET_DATA_LENGTH); + uint16_t unscheduledIndexLimit = Util::roundUpIntDiv( + policy.unscheduledByteLimit, message->PACKET_DATA_LENGTH); if (!restart) { // Fill out packet headers. @@ -778,8 +776,7 @@ Sender::startMessage(Sender::Message* message, bool restart) info->packetsSent = 0; // Insert and move message into the correct order in the priority queue. sendQueue.push_front(&info->sendQueueNode); - Intrusive::deprioritize(&sendQueue, &info->sendQueueNode, - QueuedMessageInfo::ComparePriority()); + Intrusive::deprioritize(&sendQueue, &info->sendQueueNode); sendReady.store(true); } @@ -924,9 +921,7 @@ Sender::trySend() info->unsentBytes -= packetDataBytes; // The Message's unsentBytes only ever decreases. See if the // updated Message should move up in the queue. - Intrusive::prioritize( - &sendQueue, &info->sendQueueNode, - QueuedMessageInfo::ComparePriority()); + Intrusive::prioritize(&sendQueue, &info->sendQueueNode); ++info->packetsSent; } if (info->packetsSent >= info->packets->numPackets) { diff --git a/src/Sender.h b/src/Sender.h index e26b0ca..8933b59 100644 --- a/src/Sender.h +++ b/src/Sender.h @@ -68,18 +68,6 @@ class Sender { * Contains metadata for a Message that has been queued to be sent. */ struct QueuedMessageInfo { - /** - * Implements a binary comparison function for the strict weak priority - * ordering of two Message objects. - */ - struct ComparePriority { - bool operator()(const Message& a, const Message& b) - { - return a.queuedMessageInfo.unsentBytes < - b.queuedMessageInfo.unsentBytes; - } - }; - /** * QueuedMessageInfo constructor. * @@ -122,6 +110,19 @@ class Sender { */ class Message : public Homa::OutMessage { public: + + /** + * Implements a binary comparison function for the strict weak priority + * ordering of two Message objects. + */ + struct ComparePriority { + bool operator()(const Message& a, const Message& b) + { + return a.queuedMessageInfo.unsentBytes < + b.queuedMessageInfo.unsentBytes; + } + }; + /** * Construct an Message. */ @@ -263,16 +264,31 @@ class Sender { /** * MessageBucket constructor. * + * @param Sender + * Sender that owns this bucket. * @param pingIntervalCycles * Number of cycles of inactivity to wait between checking on the * liveness of a Message. */ - explicit MessageBucket(uint64_t pingIntervalCycles) - : mutex() + explicit MessageBucket(Sender* sender, uint64_t pingIntervalCycles) + : sender(sender) + , mutex() , messages() , pingTimeouts(pingIntervalCycles) {} + /** + * Destruct a MessageBucket. Will destroy all contained Message objects. + */ + ~MessageBucket() + { + // Intrusive::List is not responsible for destructing its elements; + // it must be done manually. + for (auto& message : messages) { + sender->messageAllocator.destroy(&message); + } + } + /** * Return the Message with the given MessageId. * @@ -296,6 +312,9 @@ class Sender { return nullptr; } + /// Sender that owns this object. + Sender* const sender; + /// Mutex protecting the contents of this bucket. SpinLock mutex; @@ -331,17 +350,19 @@ class Sender { /** * MessageBucketMap constructor. * + * @param sender + * Sender that owns this bucket map. * @param pingIntervalCycles * Number of cycles of inactivity to wait between checking on the * liveness of a Message. */ - explicit MessageBucketMap(uint64_t pingIntervalCycles) + explicit MessageBucketMap(Sender* sender, uint64_t pingIntervalCycles) : buckets() , hasher() { buckets.reserve(NUM_BUCKETS); for (int i = 0; i < NUM_BUCKETS; ++i) { - buckets.emplace_back(pingIntervalCycles); + buckets.emplace_back(sender, pingIntervalCycles); } } diff --git a/src/TransportImpl.cc b/src/TransportImpl.cc index 38ebddd..894862a 100644 --- a/src/TransportImpl.cc +++ b/src/TransportImpl.cc @@ -135,6 +135,7 @@ TransportImpl::processPacket(Driver::Packet* packet, IpAddress sourceIp) sender->handleGrantPacket(packet); break; case Protocol::Packet::DONE: + // fixme: rename DONE to ACK? Perf::counters.rx_done_pkts.add(1); sender->handleDonePacket(packet); break; @@ -155,6 +156,7 @@ TransportImpl::processPacket(Driver::Packet* packet, IpAddress sourceIp) sender->handleUnknownPacket(packet); break; case Protocol::Packet::ERROR: + // FIXME: remove ERROR? Perf::counters.rx_error_pkts.add(1); sender->handleErrorPacket(packet); break; From 1973fdd3b1cf4c2b68ea57866b4408f7d13ea965 Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Mon, 26 Oct 2020 23:21:59 -0700 Subject: [PATCH 30/33] pulled in the changes that simplify Homa::InMessage + improvements based on Collin's comments --- include/Homa/Homa.h | 29 +--- include/Homa/Util.h | 10 +- src/Receiver.cc | 375 +++++++++++++++---------------------------- src/Receiver.h | 139 ++++++++-------- src/Sender.cc | 94 ++++++----- src/Sender.h | 34 ++-- src/TransportImpl.cc | 24 +-- src/TransportImpl.h | 2 +- 8 files changed, 275 insertions(+), 432 deletions(-) diff --git a/include/Homa/Homa.h b/include/Homa/Homa.h index 62fd4b3..b430118 100644 --- a/include/Homa/Homa.h +++ b/include/Homa/Homa.h @@ -71,40 +71,19 @@ class InMessage { virtual void acknowledge() const = 0; /** - * Get the contents of a specified range of bytes in the Message by - * copying them into the provided destination memory region. - * - * @param offset - * The number of bytes in the Message preceding the range of bytes - * being requested. - * @param destination - * The pointer to the memory region into which the requested byte - * range will be copied. The caller must ensure that the buffer is - * big enough to hold the requested number of bytes. - * @param count - * The number of bytes being requested. + * Get the underlying contiguous memory buffer serving as message storage. + * The buffer will be large enough to hold at least length() bytes. * * @return - * The number of bytes actually copied out. This number may be less - * than "num" if the requested byte range exceeds the range of - * bytes in the Message. + * Pointer to the message buffer. */ - virtual size_t get(size_t offset, void* destination, - size_t count) const = 0; + virtual void* data() const = 0; /** * Return the number of bytes this Message contains. */ virtual size_t length() const = 0; - /** - * Remove a number of bytes from the beginning of the Message. - * - * @param count - * Number of bytes to remove. - */ - virtual void strip(size_t count) = 0; - protected: /** * Signal that this message is no longer needed. The caller should not diff --git a/include/Homa/Util.h b/include/Homa/Util.h index 72f0752..462e1ff 100644 --- a/include/Homa/Util.h +++ b/include/Homa/Util.h @@ -74,13 +74,13 @@ isPowerOfTwo(num_type n) * Round up the result of x divided by y, where both x and y are positive * integers. */ -template -constexpr num_type -roundUpIntDiv(num_type x, Y y) +template +constexpr num_type_x +roundUpIntDiv(num_type_x x, num_type_y y) { - static_assert(std::is_integral::value, "Integral required."); + static_assert(std::is_integral::value, "Integral required."); assert(x > 0 && y > 0); - num_type yy = downCast(y); + num_type_x yy = downCast(y); return (x + yy - 1) / yy; } diff --git a/src/Receiver.cc b/src/Receiver.cc index 8cb948d..0af7e2f 100644 --- a/src/Receiver.cc +++ b/src/Receiver.cc @@ -42,6 +42,8 @@ Receiver::Receiver(Driver* driver, Policy::Manager* policyManager, , policyManager(policyManager) , MESSAGE_TIMEOUT_INTERVALS( Util::roundUpIntDiv(messageTimeoutCycles, resendIntervalCycles)) + , TRANSPORT_HEADER_LENGTH(sizeof(Protocol::Packet::DataHeader)) + , PACKET_DATA_LENGTH(driver->getMaxPayloadSize() - TRANSPORT_HEADER_LENGTH) , messageBuckets(resendIntervalCycles) , schedulerMutex() , scheduledPeers() @@ -49,6 +51,7 @@ Receiver::Receiver(Driver* driver, Policy::Manager* policyManager, , granting() , nextBucketIndex(0) , messageAllocator() + , externalBuffers() {} /** @@ -72,6 +75,14 @@ Receiver::~Receiver() // purpose, we decided to write the cleanup procedure explicitly anyway. // Destruct all MessageBucket's and the Messages within. + for (auto& bucket : messageBuckets.buckets) { + // Intrusive::List is not responsible for destructing its elements; + // it must be done manually. + for (auto& message : bucket.messages) { + messageAllocator.destroy(&message); + } + assert(bucket.resendTimeouts.empty()); + } messageBuckets.buckets.clear(); // Destruct all Peer's. Peer's must be removed from scheduledPeers first. @@ -84,78 +95,6 @@ Receiver::~Receiver() } } -/** - * Execute the common processing logic that is shared among all incoming - * packets. - * - * @param packet - * Incoming packet to be processed. - * @param createIfAbsent - * True if a new Message should be constructed when no matching message - * can be found. - * @param sourceIp - * Source IP address of the packet. Only valid when createIfAbsent is true. - * @return - * Pointer to the message targeted by the incoming packet, or nullptr if no - * matching message can be found. - */ -Receiver::Message* -Receiver::handleIncomingPacket(Driver::Packet* packet, bool createIfAbsent, - IpAddress* sourceIp) -{ - // Find the message bucket. - Protocol::Packet::CommonHeader* commonHeader = - static_cast(packet->payload); - Protocol::MessageId msgId = commonHeader ->messageId; - MessageBucket* bucket = messageBuckets.getBucket(msgId); - - // Acquire the bucket mutex to ensure that a new message can be constructed - // and inserted to the bucket atomically. - SpinLock::Lock lock_bucket(bucket->mutex); - - // Find the target message, or construct a new message if necessary. - Message* message = bucket->findMessage(msgId, lock_bucket); - if (message == nullptr && createIfAbsent) { - // Construct a new message - Protocol::Packet::DataHeader* header = - static_cast(packet->payload); - int messageLength = header->totalLength; - int numUnscheduledPackets = header->unscheduledIndexLimit; - SocketAddress srcAddress = { - .ip = *sourceIp, .port = be16toh(header->common.prefix.sport)}; - message = messageAllocator.construct( - this, driver, sizeof(Protocol::Packet::DataHeader), messageLength, - header->common.messageId, srcAddress, numUnscheduledPackets); - Perf::counters.allocated_rx_messages.add(1); - - // Start tracking the message. - bucket->messages.push_back(&message->bucketNode); - - policyManager->signalNewMessage( - message->source.ip, header->policyVersion, header->totalLength); - if (message->scheduled) { - // Message needs to be scheduled. - SpinLock::Lock lock_scheduler(schedulerMutex); - schedule(message, lock_scheduler); - } - } - - // The sender is still alive; reschedule the timeout. - if (message != nullptr) { - // Optimization: for single-packet messages, it would be wasteful to set - // a resendTimeout just to cancel it immediately. - bool needTimeout = (message->numExpectedPackets > 1); - - // If the message is not in progress, its resend timeout must have been - // cancelled in setState(); don't re-insert a new one. - if (needTimeout && message->getState() == Message::State::IN_PROGRESS) { - message->numResendTimeouts = 0; - bucket->resendTimeouts.setTimeout(&message->resendTimeout); - } - } - return message; -} - /** * Process an incoming DATA packet. * @@ -163,40 +102,77 @@ Receiver::handleIncomingPacket(Driver::Packet* packet, bool createIfAbsent, * The incoming packet to be processed. * @param sourceIp * Source IP address of the packet. - * @return - * True if the Receiver decides to take ownership of the packet. False - * if the Receiver has no more use of this packet and it can be released - * to the driver. */ -bool +void Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) { - // Find the message - Message* message = handleIncomingPacket(packet, true, &sourceIp); Protocol::Packet::DataHeader* header = static_cast(packet->payload); + Protocol::MessageId id = header->common.messageId; + + bool needSchedule = false; + bool messageComplete; + MessageBucket* bucket = messageBuckets.getBucket(id); + Message* message; + { + // Scoped critical section guarded by MessageBucket::mutex; this ensures + // that the bucket mutex is dropped before acquiring the schedulerMutex. + SpinLock::Lock lock_bucket(bucket->mutex); + message = bucket->findMessage(id, lock_bucket); + if (message == nullptr) { + // New message + int messageLength = header->totalLength; + int numUnscheduledPackets = header->unscheduledIndexLimit; + SocketAddress srcAddress = { + .ip = sourceIp, .port = be16toh(header->common.prefix.sport)}; + message = + messageAllocator.construct(this, driver, messageLength, id, + srcAddress, numUnscheduledPackets); + Perf::counters.allocated_rx_messages.add(1); + + // Start tracking the message. + bucket->messages.push_back(&message->bucketNode); + policyManager->signalNewMessage( + message->source.ip, header->policyVersion, header->totalLength); + + if (message->scheduled) { + // Don't schedule the message while holding the bucket mutex. + needSchedule = true; + } + } + + // Add the packet, but don't copy the payload yet. + if (message->occupied.test(header->index)) { + // Must be a duplicate packet; drop it. + return; + } + message->occupied.set(header->index); + message->numPackets++; + messageComplete = (message->numPackets == message->numExpectedPackets); + } + + // Copy the payload into the message buffer. + std::memcpy(message->buffer + header->index * PACKET_DATA_LENGTH, + static_cast(packet->payload) + TRANSPORT_HEADER_LENGTH, + packet->length - TRANSPORT_HEADER_LENGTH); // Sanity checks assert(message->source.ip == sourceIp); assert(message->source.port == be16toh(header->common.prefix.sport)); assert(message->messageLength == Util::downCast(header->totalLength)); - // Add the packet - bool packetAdded = message->setPacket(header->index, packet); - if (!packetAdded) { - // Must be a duplicate packet; drop it. - return false; - } - - // Update schedule for scheduled messages. - if (message->scheduled) { + if (needSchedule) { + // A new Message needs to be entered into the scheduler. + SpinLock::Lock lock_scheduler(schedulerMutex); + schedule(message, lock_scheduler); + } else if (message->scheduled) { + // Update schedule for an existing scheduled message. SpinLock::Lock lock_scheduler(schedulerMutex); ScheduledMessageInfo* info = &message->scheduledMessageInfo; // Update the schedule if the message is still being scheduled // (i.e. still linked to a scheduled peer). if (info->peer != nullptr) { - int packetDataBytes = - packet->length - message->TRANSPORT_HEADER_LENGTH; + int packetDataBytes = packet->length - TRANSPORT_HEADER_LENGTH; assert(info->bytesRemaining >= packetDataBytes); info->bytesRemaining -= packetDataBytes; updateSchedule(message, lock_scheduler); @@ -204,13 +180,10 @@ Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) } // Complete the message if all packets have been received. - if (message->numPackets == message->numExpectedPackets) { + if (messageComplete) { message->state.store(Message::State::COMPLETED, std::memory_order_release); - // Optimization: for single-packet messages, there is no need to cancel - // the resendTimeout if we don't insert one in the first place. - if (message->numPackets > 1) { - MessageBucket* bucket = message->bucket; + if (message->needTimeout) { SpinLock::Lock lock(bucket->mutex); bucket->resendTimeouts.cancelTimeout(&message->resendTimeout); } @@ -220,7 +193,6 @@ Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) receivedMessages.queue.push_back(&message->receivedMessageNode); Perf::counters.received_rx_messages.add(1); } - return true; } /** @@ -232,7 +204,20 @@ Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) void Receiver::handleBusyPacket(Driver::Packet* packet) { - handleIncomingPacket(packet, false, nullptr); + Protocol::Packet::BusyHeader* header = + static_cast(packet->payload); + Protocol::MessageId id = header->common.messageId; + + MessageBucket* bucket = messageBuckets.getBucket(id); + SpinLock::Lock lock_bucket(bucket->mutex); + Message* message = bucket->findMessage(id, lock_bucket); + if (message != nullptr) { + // Sender has replied BUSY to our RESEND request; consider this message + // still active. + if (message->getState() == Message::State::IN_PROGRESS) { + bucket->resendTimeouts.setTimeout(&message->resendTimeout); + } + } } /** @@ -246,9 +231,18 @@ Receiver::handleBusyPacket(Driver::Packet* packet) void Receiver::handlePingPacket(Driver::Packet* packet, IpAddress sourceIp) { - Message* message = handleIncomingPacket(packet, false, nullptr); + Protocol::Packet::PingHeader* header = + static_cast(packet->payload); + Protocol::MessageId id = header->common.messageId; + + MessageBucket* bucket = messageBuckets.getBucket(id); + Message* message; + { + SpinLock::Lock lock_bucket(bucket->mutex); + message = bucket->findMessage(id, lock_bucket); + } if (message != nullptr) { - // We are here either because a GRANT got lost, or we haven't issued a + // We are here either because a GRANT got lost, or we haven't issued a // GRANT in along time. Send out the latest GRANT if one exists or just // an "empty" GRANT to let the Sender know we are aware of the message. @@ -274,10 +268,8 @@ Receiver::handlePingPacket(Driver::Packet* packet, IpAddress sourceIp) // We are here because we have no knowledge of the message the Sender is // asking about. Reply UNKNOWN so the Sender can react accordingly. Perf::counters.tx_unknown_pkts.add(1); - Protocol::Packet::CommonHeader* header = - static_cast(packet->payload); ControlPacket::send( - driver, sourceIp, header->messageId); + driver, sourceIp, header->common.messageId); } } @@ -361,30 +353,11 @@ Receiver::Message::~Message() receiver->receivedMessages.queue.remove(&receivedMessageNode); } - // Find contiguous ranges of packets and release them back to the - // driver. - int num = 0; - int index = 0; - int packetsFound = 0; - for (int i = 0; i < MAX_MESSAGE_PACKETS && packetsFound < numPackets; ++i) { - if (occupied.test(i)) { - if (num == 0) { - // First packet in new region. - index = i; - } - ++num; - ++packetsFound; - } else { - if (num != 0) { - // End of region; release the last region. - driver->releasePackets(&packets[index], num); - num = 0; - } - } - } - if (num != 0) { - // Release the last region (if any). - driver->releasePackets(&packets[index], num); + // Release the external buffer, if any. + if (buffer != internalBuffer) { + MessageBuffer* externalBuf = + (MessageBuffer*)buffer; + receiver->externalBuffers.destroy(externalBuf); } } @@ -399,48 +372,12 @@ Receiver::Message::acknowledge() const } /** - * @copydoc Homa::InMessage::get() + * @copydoc Homa::InMessage::data() */ -size_t -Receiver::Message::get(size_t offset, void* destination, size_t count) const +void* +Receiver::Message::data() const { - // This operation should be performed with the offset relative to the - // logical beginning of the Message. - int _offset = Util::downCast(offset); - int _count = Util::downCast(count); - int realOffset = _offset + start; - int packetIndex = realOffset / PACKET_DATA_LENGTH; - int packetOffset = realOffset % PACKET_DATA_LENGTH; - int bytesCopied = 0; - - // Offset is passed the end of the message. - if (realOffset >= messageLength) { - return 0; - } - - if (realOffset + _count > messageLength) { - _count = messageLength - realOffset; - } - - while (bytesCopied < _count) { - uint32_t bytesToCopy = - std::min(_count - bytesCopied, PACKET_DATA_LENGTH - packetOffset); - Driver::Packet* packet = getPacket(packetIndex); - if (packet != nullptr) { - char* source = static_cast(packet->payload); - source += packetOffset + TRANSPORT_HEADER_LENGTH; - std::memcpy(static_cast(destination) + bytesCopied, source, - bytesToCopy); - } else { - ERROR("Message is missing data starting at packet index %u", - packetIndex); - break; - } - bytesCopied += bytesToCopy; - packetIndex++; - packetOffset = 0; - } - return bytesCopied; + return buffer; } /** @@ -449,16 +386,7 @@ Receiver::Message::get(size_t offset, void* destination, size_t count) const size_t Receiver::Message::length() const { - return Util::downCast(messageLength - start); -} - -/** - * @copydoc Homa::InMessage::strip() - */ -void -Receiver::Message::strip(size_t count) -{ - start = std::min(start + Util::downCast(count), messageLength); + return Util::downCast(messageLength); } /** @@ -470,52 +398,6 @@ Receiver::Message::release() bucket->receiver->messageAllocator.destroy(this); } -/** - * Return the Packet with the given index. - * - * @param index - * A Packet's index in the array of packets that form the message. - * "packet index = "packet message offset" / PACKET_DATA_LENGTH - * @return - * Pointer to a Packet at the given index if it exists; nullptr otherwise. - */ -Driver::Packet* -Receiver::Message::getPacket(size_t index) const -{ - if (occupied.test(index)) { - return packets[index]; - } - return nullptr; -} - -/** - * Store the given packet as the Packet of the given index if one does not - * already exist. - * - * Responsibly for releasing the given Packet is passed to this context if the - * Packet is stored (returns true). - * - * @param index - * The Packet's index in the array of packets that form the message. - * "packet index = "packet message offset" / PACKET_DATA_LENGTH - * @param packet - * The packet pointer that should be stored. - * @return - * True if the packet was stored; false if a packet already exists (the new - * packet is not stored). - */ -bool -Receiver::Message::setPacket(size_t index, Driver::Packet* packet) -{ - if (occupied.test(index)) { - return false; - } - packets[index] = packet; - occupied.set(index); - numPackets++; - return true; -} - /** * Process any inbound messages that may need to issue resends. * @@ -533,9 +415,8 @@ Receiver::checkResendTimeouts(uint64_t now, MessageBucket* bucket) return; } + SpinLock::Lock lock_bucket(bucket->mutex); while (true) { - SpinLock::Lock lock_bucket(bucket->mutex); - // No remaining timeouts. if (bucket->resendTimeouts.empty()) { break; @@ -561,10 +442,17 @@ Receiver::checkResendTimeouts(uint64_t now, MessageBucket* bucket) // This Receiver expected to have heard from the Sender within the // last timeout period but it didn't. Request a resend of granted // packets in case DATA packets got lost. - int index = 0; - int num = 0; + uint16_t index = 0; + uint16_t num = 0; int grantIndexLimit = message->numUnscheduledPackets; + // The RESEND also includes the current granted priority so that it + // can act as a GRANT in case a GRANT was lost. If this message + // hasn't been scheduled (i.e. no grants have been sent) then the + // priority will hold the default value; this is ok since the Sender + // will ignore the priority field for resends of purely unscheduled + // packets (see Sender::handleResendPacket()). + int resendPriority = 0; if (message->scheduled) { SpinLock::Lock lock_scheduler(schedulerMutex); ScheduledMessageInfo* info = &message->scheduledMessageInfo; @@ -573,15 +461,16 @@ Receiver::checkResendTimeouts(uint64_t now, MessageBucket* bucket) // Sender is blocked on this Receiver; all granted packets // have already been received. No need to check for resend. continue; - } else if (grantIndexLimit * message->PACKET_DATA_LENGTH < + } else if (grantIndexLimit * PACKET_DATA_LENGTH < info->bytesGranted) { - grantIndexLimit = Util::roundUpIntDiv( - info->bytesGranted, message->PACKET_DATA_LENGTH); + grantIndexLimit = + Util::roundUpIntDiv(info->bytesGranted, PACKET_DATA_LENGTH); } + resendPriority = info->priority; } for (int i = 0; i < grantIndexLimit; ++i) { - if (message->getPacket(i) == nullptr) { + if (!message->occupied.test(i)) { // Unreceived packet if (num == 0) { // First unreceived packet @@ -592,34 +481,20 @@ Receiver::checkResendTimeouts(uint64_t now, MessageBucket* bucket) // Received packet if (num != 0) { // Send out the range of packets found so far. - // - // The RESEND also includes the current granted priority - // so that it can act as a GRANT in case a GRANT was - // lost. If this message hasn't been scheduled (i.e. no - // grants have been sent) then the priority will hold - // the default value; this is ok since the Sender will - // ignore the priority field for resends of purely - // unscheduled packets (see - // Sender::handleResendPacket()). - SpinLock::Lock lock_scheduler(schedulerMutex); Perf::counters.tx_resend_pkts.add(1); ControlPacket::send( - message->driver, message->source.ip, message->id, - Util::downCast(index), - Util::downCast(num), - message->scheduledMessageInfo.priority); + message->driver, message->source.ip, message->id, index, + num, resendPriority); num = 0; } } } if (num != 0) { // Send out the last range of packets found. - SpinLock::Lock lock_scheduler(schedulerMutex); Perf::counters.tx_resend_pkts.add(1); ControlPacket::send( - message->driver, message->source.ip, message->id, - Util::downCast(index), Util::downCast(num), - message->scheduledMessageInfo.priority); + message->driver, message->source.ip, message->id, index, num, + resendPriority); } } } @@ -728,8 +603,8 @@ Receiver::schedule(Receiver::Message* message, const SpinLock::Lock& lock) &info->scheduledMessageNode); info->peer = peer; if (!scheduledPeers.contains(&peer->scheduledPeerNode)) { - // Must be the only message of this peer; push the peer to the - // end of list to be moved later. + // Must be the only message of this peer; push the peer to the front of + // list to be moved later. assert(peer->scheduledMessages.size() == 1); scheduledPeers.push_front(&peer->scheduledPeerNode); Intrusive::deprioritize(&scheduledPeers, diff --git a/src/Receiver.h b/src/Receiver.h index 4bc64bd..045b86e 100644 --- a/src/Receiver.h +++ b/src/Receiver.h @@ -48,7 +48,7 @@ class Receiver { uint64_t messageTimeoutCycles, uint64_t resendIntervalCycles); virtual ~Receiver(); - virtual bool handleDataPacket(Driver::Packet* packet, IpAddress sourceIp); + virtual void handleDataPacket(Driver::Packet* packet, IpAddress sourceIp); virtual void handleBusyPacket(Driver::Packet* packet); virtual void handlePingPacket(Driver::Packet* packet, IpAddress sourceIp); virtual Homa::InMessage* receiveMessage(); @@ -61,8 +61,11 @@ class Receiver { struct MessageBucket; struct Peer; - Message* handleIncomingPacket(Driver::Packet* packet, bool createIfAbsent, - IpAddress* sourceIp = nullptr); + /// Define the maximum number of packets that a message can hold. + static const int MAX_MESSAGE_PACKETS = 1024; + + /// Define the maximum number of bytes within a message. + static const int MAX_MESSAGE_LENGTH = 1u << 20u; /** * Contains metadata for a Message that requires additional GRANTs. @@ -134,41 +137,39 @@ class Receiver { COMPLETED, //< Receiver has received the entire message. }; - explicit Message(Receiver* receiver, Driver* driver, - size_t packetHeaderLength, int messageLength, + explicit Message(Receiver* receiver, Driver* driver, int messageLength, Protocol::MessageId id, SocketAddress source, int numUnscheduledPackets) : driver(driver) , id(id) , bucket(receiver->messageBuckets.getBucket(id)) , source(source) - , TRANSPORT_HEADER_LENGTH(packetHeaderLength) - , PACKET_DATA_LENGTH(driver->getMaxPayloadSize() - - TRANSPORT_HEADER_LENGTH) - , numExpectedPackets( - Util::roundUpIntDiv(messageLength, PACKET_DATA_LENGTH)) + , numExpectedPackets(Util::roundUpIntDiv( + messageLength, receiver->PACKET_DATA_LENGTH)) , numUnscheduledPackets(numUnscheduledPackets) , scheduled(numExpectedPackets > numUnscheduledPackets) - , start(0) , messageLength(messageLength) , numPackets(0) , occupied() - // packets is not initialized to reduce the work done during - // construction. See Message::occupied. + , buffer(messageLength <= Util::arrayLength(internalBuffer) + ? internalBuffer + : receiver->externalBuffers.construct()->raw) + // No need to zero-out internalBuffer. , state(Message::State::IN_PROGRESS) , bucketNode(this) , receivedMessageNode(this) + , needTimeout(numExpectedPackets > 1) , numResendTimeouts(0) , resendTimeout(this) , scheduledMessageInfo(this, messageLength) - {} + { + assert(messageLength <= MAX_MESSAGE_LENGTH); + } virtual ~Message(); void acknowledge() const override; - size_t get(size_t offset, void* destination, - size_t count) const override; + void* data() const override; size_t length() const override; - void strip(size_t count) override; void release() override; /** @@ -180,12 +181,6 @@ class Receiver { } private: - /// Define the maximum number of packets that a message can hold. - static const int MAX_MESSAGE_PACKETS = 1024; - - Driver::Packet* getPacket(size_t index) const; - bool setPacket(size_t index, Driver::Packet* packet); - /// Driver from which packets were received and to which they should be /// returned when this message is no longer needed. Driver* const driver; @@ -199,13 +194,6 @@ class Receiver { /// Contains source address this message. const SocketAddress source; - /// Number of bytes at the beginning of each Packet that should be - /// reserved for the Homa transport header. - const int TRANSPORT_HEADER_LENGTH; - - /// Number of bytes of data in each full packet. - const int PACKET_DATA_LENGTH; - /// Number of packets the message is expected to contain. const int numExpectedPackets; @@ -216,22 +204,22 @@ class Receiver { /// GRANTs to be sent. const bool scheduled; - /// First byte where data is or will go if empty. - int start; - /// Number of bytes in this Message including any stripped bytes. - int messageLength; + const int messageLength; - /// Number of packets currently contained in this message. + /// Number of packets currently contained in this message. Protected by + /// MessageBucket::mutex. int numPackets; - /// Bit array representing which entires in the _packets_ array are set. - /// Used to avoid having to zero out the entire _packets_ array. + /// Bit array representing which packets in this message are received. + /// Protected by MessageBucket::mutex. std::bitset occupied; - /// Collection of Packet objects that make up this context's Message. - /// These Packets will be released when this context is destroyed. - Driver::Packet* packets[MAX_MESSAGE_PACKETS]; + /// Pointer to the contiguous memory buffer serving as message storage. + char* const buffer; + + /// Internal memory buffer used to store messages within 2KB. + char internalBuffer[2048]; /// This message's current state. std::atomic state; @@ -245,12 +233,18 @@ class Receiver { /// message when it has been completely received. Intrusive::List::Node receivedMessageNode; - /// Number of resend timeouts that occurred in a row. Access to this - /// structure is protected by the associated MessageBucket::mutex. + /// True if this message needs to be tracked by the timeout manager. + /// As an inbound message, this variable should only be set to false + /// when the message fits in a single packet. + const bool needTimeout; + + /// Number of resend timeouts that occurred in a row. Protected by + /// MessageBucket::mutex. int numResendTimeouts; /// Intrusive structure used by the Receiver to keep track when - /// unreceived parts of this message should be re-requested. + /// unreceived parts of this message should be re-requested. Protected + /// by MessageBucket::mutex. Timeout resendTimeout; /// Intrusive structure used by the Receiver to keep track of this @@ -261,6 +255,20 @@ class Receiver { friend class Receiver; }; + /** + * Memory buffer used to hold large messages that don't fit in Message's + * internal buffer. It's basically a simple wrapper around an array so + * that it can be allocated from an ObjectPool. + * + * @tparam length + * Number of bytes in the buffer. + */ + template + struct MessageBuffer { + /// Buffer space. + char raw[length]; + }; + /** * A collection of incoming Message objects and their associated timeouts. * @@ -286,17 +294,9 @@ class Receiver { , resendTimeouts(resendIntervalCycles) {} - /** - * Destruct a MessageBucket. Will destroy all contained Message objects. - */ - ~MessageBucket() - { - // Intrusive::List is not responsible for destructing its elements; - // it must be done manually. - for (auto& message : messages) { - receiver->messageAllocator.destroy(&message); - } - } + // MessageBucket's are only destroyed when the transport is destructed; + // all the real work are done in ~Receiver(). + ~MessageBucket() = default; /** * Return the Message with the given MessageId. @@ -405,26 +405,16 @@ class Receiver { */ struct Peer { /** - * Peer constructor. + * Default constructor. Easy to use with std::unorderd_map. */ Peer() : scheduledMessages() , scheduledPeerNode(this) {} - /** - * Peer destructor. Only invoked from the destructor of Receiver. - */ - ~Peer() - { - // By the time we need to destruct a peer, all Message's coming from - // it should have been released. - assert(scheduledMessages.empty()); - - // To keep Peer (constructor) simple, we don't store a reference to - // the outer Receiver in Peer. As a result, Receiver is responsible - // for clearing schedulerPeers. - } + // Peer's are only destroyed when the transport is destructed; all the + // real work are done in ~Receiver(). + ~Peer() = default; /** * Implements a binary comparison function for the strict weak priority @@ -465,11 +455,19 @@ class Receiver { /// timeout. const int MESSAGE_TIMEOUT_INTERVALS; + /// Number of bytes at the beginning of each Packet that should be reserved + /// for the Homa transport header. + const int TRANSPORT_HEADER_LENGTH; + + /// Number of bytes of data in each full packet. + const int PACKET_DATA_LENGTH; + /// Tracks the set of inbound messages being received by this Receiver. MessageBucketMap messageBuckets; /// Protects access to the Receiver's scheduler state (i.e. peerTable, - /// scheduledPeers, and ScheduledMessageInfo). + /// scheduledPeers, and ScheduledMessageInfo). Locking order constraint: + /// MessageBucket::mutex must be acquired before schedulerMutex. SpinLock schedulerMutex; /// Collection of all peers; used for fast access. Access is protected by @@ -501,6 +499,9 @@ class Receiver { /// Used to allocate Message objects. ObjectPool messageAllocator; + + /// Used to allocate large memory buffers outside the Message struct. + ObjectPool> externalBuffers; }; } // namespace Core diff --git a/src/Sender.cc b/src/Sender.cc index 35544be..0cb8c9a 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -193,22 +193,20 @@ Sender::handleResendPacket(Driver::Packet* packet) int index = resendHeader->index; int resendEnd = index + resendHeader->num; - SpinLock::Lock lock_queue(queueMutex); - QueuedMessageInfo* info = &message->queuedMessageInfo; - // Check if RESEND request is out of range. - if (index >= info->packets->numPackets || - resendEnd > info->packets->numPackets) { + if (index >= message->numPackets || resendEnd > message->numPackets) { WARNING( "Message (%lu, %lu) RESEND request range out of bounds: requested " "range [%d, %d); message only contains %d packets; peer Transport " "may be confused.", message->id.transportId, message->id.sequence, index, resendEnd, - info->packets->numPackets); + message->numPackets); return; } // In case a GRANT may have been lost, consider the RESEND a GRANT. + QueuedMessageInfo* info = &message->queuedMessageInfo; + SpinLock::UniqueLock lock_queue(queueMutex); if (info->packetsGranted < resendEnd) { info->packetsGranted = resendEnd; // Note that the priority of messages under the unscheduled byte limit @@ -218,21 +216,24 @@ Sender::handleResendPacket(Driver::Packet* packet) sendReady.store(true); } - if (index >= info->packetsSent) { + // Release the queue mutex; the rest of the code won't touch the queue. + int packetsSent = info->packetsSent; + lock_queue.unlock(); + if (index >= packetsSent) { // If this RESEND is only requesting unsent packets, it must be that // this Sender has been busy and the Receiver is trying to ensure there // are no lost packets. Reply BUSY and allow this Sender to send DATA // when it's ready. Perf::counters.tx_busy_pkts.add(1); ControlPacket::send( - driver, info->packets->destination.ip, info->packets->id); + driver, message->destination.ip, message->id); } else { // There are some packets to resend but only resend packets that have // already been sent. - resendEnd = std::min(resendEnd, info->packetsSent); + resendEnd = std::min(resendEnd, packetsSent); int resendPriority = policyManager->getResendPriority(); for (int i = index; i < resendEnd; ++i) { - Driver::Packet* resendPacket = info->packets->getPacket(i); + Driver::Packet* resendPacket = message->getPacket(i); Perf::counters.tx_data_pkts.add(1); Perf::counters.tx_bytes.add(resendPacket->length); driver->sendPacket(resendPacket, message->destination.ip, @@ -259,27 +260,26 @@ Sender::handleGrantPacket(Driver::Packet* packet) Protocol::Packet::GrantHeader* grantHeader = static_cast(packet->payload); if (message->getStatus() == OutMessage::Status::IN_PROGRESS) { - SpinLock::Lock lock_queue(queueMutex); - QueuedMessageInfo* info = &message->queuedMessageInfo; - // Convert the byteLimit to a packet index limit such that the packet // that holds the last granted byte is also considered granted. This // can cause at most 1 packet worth of data to be sent without a grant // but allows the sender to always send full packets. int incomingGrantIndex = Util::roundUpIntDiv( - grantHeader->byteLimit, info->packets->PACKET_DATA_LENGTH); + grantHeader->byteLimit, message->PACKET_DATA_LENGTH); // Make that grants don't exceed the number of packets. Internally, // the sender always assumes that packetsGranted <= numPackets. - if (incomingGrantIndex > info->packets->numPackets) { + if (incomingGrantIndex > message->numPackets) { WARNING( "Message (%lu, %lu) GRANT exceeds message length; granted " "packets: %d, message packets %d; extra grants are ignored.", message->id.transportId, message->id.sequence, - incomingGrantIndex, info->packets->numPackets); - incomingGrantIndex = info->packets->numPackets; + incomingGrantIndex, message->numPackets); + incomingGrantIndex = message->numPackets; } + QueuedMessageInfo* info = &message->queuedMessageInfo; + SpinLock::Lock lock_queue(queueMutex); if (info->packetsGranted < incomingGrantIndex) { info->packetsGranted = incomingGrantIndex; // Note that the priority of messages under the unscheduled byte @@ -490,16 +490,19 @@ Sender::Message::getStatus() const return state.load(std::memory_order_acquire); } - /** - * Change the status of this message. - * - * All status change must be done by this method. - * - * @param newStatus - * The new status. - * @param deschedule - * True if we should remove this message from the send queue. - */ +/** + * Change the status of this message. + * + * All status change must be done by this method. + * + * @param newStatus + * The new status. + * @param deschedule + * True if we should remove this message from the send queue. + * + * Note: special care must be taken when calling this method to avoid deadlocks + * because our spinlock is not reentrant. + */ void Sender::Message::setStatus(Status newStatus, bool deschedule) { @@ -840,14 +843,13 @@ Sender::checkPingTimeouts(uint64_t now, MessageBucket* bucket) } } - // The following code doesn't access bucket data anymore; release the - // mutex to reduce the critical section. + // Release the bucket mutex to follow the locking order constraint. bucket_lock.unlock(); // Check if sender still has packets to send if (status == OutMessage::Status::IN_PROGRESS) { - SpinLock::Lock lock_queue(queueMutex); QueuedMessageInfo* info = &message->queuedMessageInfo; + SpinLock::Lock lock_queue(queueMutex); if (info->packetsSent < info->packetsGranted) { // Sender is blocked on itself, no need to send ping continue; @@ -889,7 +891,7 @@ Sender::trySend() * Each time this method is called we will try to send enough packet to keep * the NIC busy but not too many as to cause excessive queue in the NIC. */ - SpinLock::UniqueLock lock_queue(queueMutex); + SpinLock::Lock lock_queue(queueMutex); uint32_t queuedBytesEstimate = driver->getQueuedBytes(); // Optimistically assume we will finish sending every granted packet this // round; we will set again sendReady if it turns out we don't finish. @@ -899,12 +901,11 @@ Sender::trySend() Message& message = *it; assert(message.getStatus() == OutMessage::Status::IN_PROGRESS); QueuedMessageInfo* info = &message.queuedMessageInfo; - assert(info->packetsGranted <= info->packets->numPackets); + assert(info->packetsGranted <= message.numPackets); while (info->packetsSent < info->packetsGranted) { // There are packets to send idle = false; - Driver::Packet* packet = - info->packets->getPacket(info->packetsSent); + Driver::Packet* packet = message.getPacket(info->packetsSent); assert(packet != nullptr); queuedBytesEstimate += packet->length; // Check if the send limit would be reached... @@ -916,7 +917,7 @@ Sender::trySend() Perf::counters.tx_bytes.add(packet->length); driver->sendPacket(packet, message.destination.ip, info->priority); int packetDataBytes = - packet->length - info->packets->TRANSPORT_HEADER_LENGTH; + packet->length - message.TRANSPORT_HEADER_LENGTH; assert(info->unsentBytes >= packetDataBytes); info->unsentBytes -= packetDataBytes; // The Message's unsentBytes only ever decreases. See if the @@ -924,16 +925,16 @@ Sender::trySend() Intrusive::prioritize(&sendQueue, &info->sendQueueNode); ++info->packetsSent; } - if (info->packetsSent >= info->packets->numPackets) { + if (info->packetsSent >= message.numPackets) { // We have finished sending the message. - // Advance the iterator first to avoid invalidation. - ++it; - - // Unlock the queueMutex before setStatus() since our spinlock is - // non-reentrant. - lock_queue.unlock(); - message.setStatus(OutMessage::Status::SENT, true); + // Note: instead of relying on setStatus(), manually deschedule this + // message since we are already holding the queueMutex (out spinlock + // is not reentrant). + assert(message.numPackets > 1 && + message.getStatus() == OutMessage::Status::IN_PROGRESS); + it = sendQueue.remove(it); + message.setStatus(OutMessage::Status::SENT, false); if (!message.held.load(std::memory_order_acquire)) { // Ok to delete now that the message has been sent. @@ -941,15 +942,10 @@ Sender::trySend() } else if (message.options & OutMessage::Options::NO_KEEP_ALIVE) { // No timeouts need to be checked after sending the message when // the NO_KEEP_ALIVE option is enabled. - - // Note: we can't be holding queueMutex here because our locking - // principle dictates that any bucket mutex must be acquired - // before the send queueMutex. MessageBucket* bucket = message.bucket; - SpinLock::Lock bucket_lock(bucket->mutex); + SpinLock::Lock lock_bucket(bucket->mutex); bucket->pingTimeouts.cancelTimeout(&message.pingTimeout); } - lock_queue.lock(); } else if (info->packetsSent >= info->packetsGranted) { // We have sent every granted packet. ++it; diff --git a/src/Sender.h b/src/Sender.h index 8933b59..f889cea 100644 --- a/src/Sender.h +++ b/src/Sender.h @@ -75,33 +75,29 @@ class Sender { * Message to which this metadata is associated. */ explicit QueuedMessageInfo(Message* message) - : packets(message) - , unsentBytes(0) + : unsentBytes(0) , packetsGranted(0) - , priority(0) , packetsSent(0) + , priority(0) , sendQueueNode(message) {} - /// Handle to the queue Message for access to the packets that will - /// be sent. This member documents that the packets are logically owned - /// by the sendQueue and thus protected by the queueMutex. - Message* const packets; - /// The number of bytes that still need to be sent for a queued Message. + /// This variable is used to rank messages in SRPT order so it must be + /// protected by Sender::queueMutex. int unsentBytes; /// The number of packets that can be sent for this Message. int packetsGranted; - /// The network priority at which this Message should be sent. - int priority; - /// The number of packets that have been sent for this Message. int packetsSent; + /// The network priority at which this Message should be sent. + int priority; + /// Intrusive structure used to enqueue the associated Message into - /// the sendQueue. + /// the sendQueue. Protected by Sender::queueMutex. Intrusive::List::Node sendQueueNode; }; @@ -110,7 +106,6 @@ class Sender { */ class Message : public Homa::OutMessage { public: - /** * Implements a binary comparison function for the strict weak priority * ordering of two Message objects. @@ -218,6 +213,7 @@ class Sender { /// constant after send() is invoked. int numPackets; + // FIXME: seems like an overkill? (e.g., packets should be added in order) /// Bit array representing which entries in the _packets_ array are set. /// Used to avoid having to zero out the entire _packets_ array. Must be /// constant after send() is invoked. @@ -315,7 +311,8 @@ class Sender { /// Sender that owns this object. Sender* const sender; - /// Mutex protecting the contents of this bucket. + /// Mutex protecting the contents of this bucket. See Sender::queueMutex + /// for locking order constraints. SpinLock mutex; /// Collection of outbound messages @@ -415,8 +412,13 @@ class Sender { /// Tracks all outbound messages being sent by the Sender. MessageBucketMap messageBuckets; - /// Protects the sendQueue. Locking principle: when a bucket mutex is also - /// required, it must be acquired before the sendQueue mutex. + /// Protects the sendQueue, including all member variables of its items. + /// When multiple locks must be acquired, this class follows the locking + /// order constraint below ("<" means "acquired before"): + /// queueMutex < MessageBucket::mutex + /// Usually, it's more natural to acquire coarser-grained locks first, + /// unless inverting the order would make the common code path simpler + /// and/or faster. SpinLock queueMutex; /// A list of outbound messages that have unsent packets. Messages are kept diff --git a/src/TransportImpl.cc b/src/TransportImpl.cc index 894862a..8178b1a 100644 --- a/src/TransportImpl.cc +++ b/src/TransportImpl.cc @@ -90,14 +90,10 @@ TransportImpl::processPackets() Driver::Packet* packets[MAX_BURST]; IpAddress srcAddrs[MAX_BURST]; int numPackets = driver->receivePackets(MAX_BURST, packets, srcAddrs); - int releaseCount = 0; for (int i = 0; i < numPackets; ++i) { - bool retainPacket = processPacket(packets[i], srcAddrs[i]); - if (!retainPacket) { - packets[releaseCount++] = packets[i]; - } + processPacket(packets[i], srcAddrs[i]); } - driver->releasePackets(packets, releaseCount); + driver->releasePackets(packets, numPackets); if (numPackets > 0) { Perf::counters.active_cycles.add(timer.split()); @@ -105,18 +101,16 @@ TransportImpl::processPackets() } /** - * Process an incoming packet. + * Process an incoming packet. The transport will have no more use of this + * packet afterwards, so the packet can be released to the driver when the + * method returns. * * @param packet * Incoming packet to be processed. * @param sourceIp * Source IP address. - * @return - * True if the transport decides to take ownership of the packet. False - * if the transport has no more use of this packet and it can be released - * to the driver. */ -bool +void TransportImpl::processPacket(Driver::Packet* packet, IpAddress sourceIp) { assert(packet->length >= @@ -124,18 +118,16 @@ TransportImpl::processPacket(Driver::Packet* packet, IpAddress sourceIp) Perf::counters.rx_bytes.add(packet->length); Protocol::Packet::CommonHeader* header = static_cast(packet->payload); - bool retainPacket = false; switch (header->opcode) { case Protocol::Packet::DATA: Perf::counters.rx_data_pkts.add(1); - retainPacket = receiver->handleDataPacket(packet, sourceIp); + receiver->handleDataPacket(packet, sourceIp); break; case Protocol::Packet::GRANT: Perf::counters.rx_grant_pkts.add(1); sender->handleGrantPacket(packet); break; case Protocol::Packet::DONE: - // fixme: rename DONE to ACK? Perf::counters.rx_done_pkts.add(1); sender->handleDonePacket(packet); break; @@ -156,12 +148,10 @@ TransportImpl::processPacket(Driver::Packet* packet, IpAddress sourceIp) sender->handleUnknownPacket(packet); break; case Protocol::Packet::ERROR: - // FIXME: remove ERROR? Perf::counters.rx_error_pkts.add(1); sender->handleErrorPacket(packet); break; } - return retainPacket; } } // namespace Core diff --git a/src/TransportImpl.h b/src/TransportImpl.h index 79c7453..ad46f99 100644 --- a/src/TransportImpl.h +++ b/src/TransportImpl.h @@ -75,7 +75,7 @@ class TransportImpl : public Transport { private: void processPackets(); - bool processPacket(Driver::Packet* packet, IpAddress source); + void processPacket(Driver::Packet* packet, IpAddress source); /// Unique identifier for this transport. const std::atomic transportId; From 6969f252c3a51940345c7d8da14435f0ad247dfe Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Tue, 10 Nov 2020 00:42:25 -0800 Subject: [PATCH 31/33] fixed some trivial issues raised in the code reviews --- src/Drivers/DPDK/DpdkDriverImpl.cc | 23 ++++++++----- src/Drivers/DPDK/DpdkDriverImpl.h | 7 ++-- src/ObjectPool.h | 53 +++++++++++++++++++++--------- src/Receiver.cc | 9 ++--- src/Receiver.h | 4 +-- src/ReceiverTest.cc | 16 ++++----- src/Sender.cc | 4 ++- src/Sender.h | 7 ++-- src/Timeout.h | 10 ++---- 9 files changed, 83 insertions(+), 50 deletions(-) diff --git a/src/Drivers/DPDK/DpdkDriverImpl.cc b/src/Drivers/DPDK/DpdkDriverImpl.cc index 82ff001..b4372ad 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.cc +++ b/src/Drivers/DPDK/DpdkDriverImpl.cc @@ -94,6 +94,7 @@ DpdkDriver::Impl::Impl(const char* ifname, int argc, char* argv[], (config == nullptr || config->HIGHEST_PACKET_PRIORITY_OVERRIDE < 0) ? Homa::Util::arrayLength(PRIORITY_TO_PCP) - 1 : config->HIGHEST_PACKET_PRIORITY_OVERRIDE) + , packetLock() , packetPool() , overflowBufferPool() , mbufsOutstanding(0) @@ -146,6 +147,7 @@ DpdkDriver::Impl::Impl(const char* ifname, (config == nullptr || config->HIGHEST_PACKET_PRIORITY_OVERRIDE < 0) ? Homa::Util::arrayLength(PRIORITY_TO_PCP) - 1 : config->HIGHEST_PACKET_PRIORITY_OVERRIDE) + , packetLock() , packetPool() , overflowBufferPool() , mbufPool(nullptr) @@ -177,6 +179,7 @@ Driver::Packet* DpdkDriver::Impl::allocPacket() { DpdkDriver::Impl::Packet* packet = nullptr; + SpinLock::Lock lock(packetLock); static const int MBUF_ALLOC_LIMIT = NB_MBUF - NB_MBUF_RESERVED; if (mbufsOutstanding < MBUF_ALLOC_LIMIT) { struct rte_mbuf* mbuf = rte_pktmbuf_alloc(mbufPool); @@ -421,14 +424,17 @@ DpdkDriver::Impl::receivePackets(uint32_t maxPackets, assert(length <= MAX_PAYLOAD_SIZE); DpdkDriver::Impl::Packet* packet = nullptr; - static const int MBUF_ALLOC_LIMIT = NB_MBUF - NB_MBUF_RESERVED; - if (mbufsOutstanding < MBUF_ALLOC_LIMIT) { - packet = packetPool.construct(m, payload); - mbufsOutstanding++; - } else { - OverflowBuffer* buf = overflowBufferPool.construct(); - rte_memcpy(payload, buf->data, length); - packet = packetPool.construct(buf); + { + SpinLock::Lock lock(packetLock); + static const int MBUF_ALLOC_LIMIT = NB_MBUF - NB_MBUF_RESERVED; + if (mbufsOutstanding < MBUF_ALLOC_LIMIT) { + packet = packetPool.construct(m, payload); + mbufsOutstanding++; + } else { + OverflowBuffer* buf = overflowBufferPool.construct(); + rte_memcpy(payload, buf->data, length); + packet = packetPool.construct(buf); + } } packet->base.length = length; @@ -445,6 +451,7 @@ void DpdkDriver::Impl::releasePackets(Driver::Packet* packets[], uint16_t numPackets) { for (uint16_t i = 0; i < numPackets; ++i) { + SpinLock::Lock lock(packetLock); DpdkDriver::Impl::Packet* packet = container_of(packets[i], DpdkDriver::Impl::Packet, base); if (likely(packet->bufType == DpdkDriver::Impl::Packet::MBUF)) { diff --git a/src/Drivers/DPDK/DpdkDriverImpl.h b/src/Drivers/DPDK/DpdkDriverImpl.h index b8def23..9425d34 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.h +++ b/src/Drivers/DPDK/DpdkDriverImpl.h @@ -184,12 +184,15 @@ class DpdkDriver::Impl { /// set by override). const int HIGHEST_PACKET_PRIORITY; + /// Protects access to the packetPool. + SpinLock packetLock; + /// Provides memory allocation for the DPDK specific implementation of a /// Driver Packet. - ObjectPool packetPool; + ObjectPool packetPool; /// Provides memory allocation for packet storage when mbuf are running out. - ObjectPool overflowBufferPool; + ObjectPool overflowBufferPool; /// The number of mbufs that have been given out to callers in Packets. uint32_t mbufsOutstanding; diff --git a/src/ObjectPool.h b/src/ObjectPool.h index 456c22f..860be71 100644 --- a/src/ObjectPool.h +++ b/src/ObjectPool.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2010-2018, Stanford University +/* Copyright (c) 2010-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -52,9 +52,10 @@ namespace Homa { * For example, transports use ObjectPool to allocate short-lived rpc * objects that cannot be kept in a stack context. * - * This class is thread-safe. + * Thread-safety of this class can be configured statically using the + * template argument ThreadSafe. */ -template +template class ObjectPool { public: /** @@ -64,7 +65,8 @@ class ObjectPool { * allocations. For simplicity, no bulk allocations are performed. */ ObjectPool() - : outstandingObjects(0) + : mutex() + , outstandingObjects(0) , pool() {} @@ -101,23 +103,23 @@ class ObjectPool { T* construct(Args&&... args) { void* backing = nullptr; - { - SpinLock::Lock lock(mutex); - if (pool.size() == 0) { - backing = operator new(sizeof(T)); - } else { - backing = pool.back(); - pool.pop_back(); - } - outstandingObjects++; + enterCS(); + if (pool.size() == 0) { + backing = operator new(sizeof(T)); + } else { + backing = pool.back(); + pool.pop_back(); } + outstandingObjects++; + exitCS(); try { return new (backing) T(static_cast(args)...); } catch (...) { - SpinLock::Lock lock(mutex); + enterCS(); pool.push_back(backing); outstandingObjects--; + exitCS(); throw; } } @@ -129,13 +131,34 @@ class ObjectPool { { object->~T(); - SpinLock::Lock lock(mutex); + enterCS(); assert(outstandingObjects > 0); pool.push_back(static_cast(object)); outstandingObjects--; + exitCS(); } private: + /** + * Enter the critical section guarded by _mutex_. + */ + void enterCS() + { + if (ThreadSafe) { + mutex.lock(); + } + } + + /** + * Exit the critical section guarded by _mutex_. + */ + void exitCS() + { + if (ThreadSafe) { + mutex.unlock(); + } + } + /// Monitor-style lock to protect the metadata of the pool. SpinLock mutex; diff --git a/src/Receiver.cc b/src/Receiver.cc index 0af7e2f..d04b0c1 100644 --- a/src/Receiver.cc +++ b/src/Receiver.cc @@ -142,11 +142,11 @@ Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) } // Add the packet, but don't copy the payload yet. - if (message->occupied.test(header->index)) { + if (message->received.test(header->index)) { // Must be a duplicate packet; drop it. return; } - message->occupied.set(header->index); + message->received.set(header->index); message->numPackets++; messageComplete = (message->numPackets == message->numExpectedPackets); } @@ -415,8 +415,9 @@ Receiver::checkResendTimeouts(uint64_t now, MessageBucket* bucket) return; } - SpinLock::Lock lock_bucket(bucket->mutex); while (true) { + SpinLock::Lock lock_bucket(bucket->mutex); + // No remaining timeouts. if (bucket->resendTimeouts.empty()) { break; @@ -470,7 +471,7 @@ Receiver::checkResendTimeouts(uint64_t now, MessageBucket* bucket) } for (int i = 0; i < grantIndexLimit; ++i) { - if (!message->occupied.test(i)) { + if (!message->received.test(i)) { // Unreceived packet if (num == 0) { // First unreceived packet diff --git a/src/Receiver.h b/src/Receiver.h index 045b86e..2085bc4 100644 --- a/src/Receiver.h +++ b/src/Receiver.h @@ -150,7 +150,7 @@ class Receiver { , scheduled(numExpectedPackets > numUnscheduledPackets) , messageLength(messageLength) , numPackets(0) - , occupied() + , received() , buffer(messageLength <= Util::arrayLength(internalBuffer) ? internalBuffer : receiver->externalBuffers.construct()->raw) @@ -213,7 +213,7 @@ class Receiver { /// Bit array representing which packets in this message are received. /// Protected by MessageBucket::mutex. - std::bitset occupied; + std::bitset received; /// Pointer to the contiguous memory buffer serving as message storage. char* const buffer; diff --git a/src/ReceiverTest.cc b/src/ReceiverTest.cc index fdf0ef5..fcd450f 100644 --- a/src/ReceiverTest.cc +++ b/src/ReceiverTest.cc @@ -374,7 +374,7 @@ TEST_F(ReceiverTest, Message_destructor_basic) message->numPackets = NUM_PKTS; for (int i = 0; i < NUM_PKTS; ++i) { - message->occupied.set(i); + message->received.set(i); } EXPECT_CALL(mockDriver, releasePackets(Eq(message->packets), Eq(NUM_PKTS))) @@ -392,10 +392,10 @@ TEST_F(ReceiverTest, Message_destructor_holes) const uint16_t NUM_PKTS = 4; message->numPackets = NUM_PKTS; - message->occupied.set(0); - message->occupied.set(1); - message->occupied.set(3); - message->occupied.set(4); + message->received.set(0); + message->received.set(1); + message->received.set(3); + message->received.set(4); EXPECT_CALL(mockDriver, releasePackets(Eq(&message->packets[0]), Eq(2))) .Times(1); @@ -593,7 +593,7 @@ TEST_F(ReceiverTest, Message_getPacket) EXPECT_EQ(nullptr, message->getPacket(0)); - message->occupied.set(0); + message->received.set(0); EXPECT_EQ(packet, message->getPacket(0)); } @@ -605,13 +605,13 @@ TEST_F(ReceiverTest, Message_setPacket) receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); Driver::Packet* packet = (Driver::Packet*)42; - EXPECT_FALSE(message->occupied.test(0)); + EXPECT_FALSE(message->received.test(0)); EXPECT_EQ(0U, message->numPackets); EXPECT_TRUE(message->setPacket(0, packet)); EXPECT_EQ(packet, message->packets[0]); - EXPECT_TRUE(message->occupied.test(0)); + EXPECT_TRUE(message->received.test(0)); EXPECT_EQ(1U, message->numPackets); EXPECT_FALSE(message->setPacket(0, packet)); diff --git a/src/Sender.cc b/src/Sender.cc index 0cb8c9a..8784237 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -68,7 +68,9 @@ Homa::OutMessage* Sender::allocMessage(uint16_t sourcePort) { Perf::counters.allocated_tx_messages.add(1); - return messageAllocator.construct(this, sourcePort); + uint64_t messageId = + nextMessageSequenceNumber.fetch_add(1, std::memory_order_relaxed); + return messageAllocator.construct(this, messageId, sourcePort); } /** diff --git a/src/Sender.h b/src/Sender.h index f889cea..db62c89 100644 --- a/src/Sender.h +++ b/src/Sender.h @@ -121,13 +121,14 @@ class Sender { /** * Construct an Message. */ - explicit Message(Sender* sender, uint16_t sourcePort) + explicit Message(Sender* sender, uint64_t messageId, + uint16_t sourcePort) : sender(sender) , driver(sender->driver) , TRANSPORT_HEADER_LENGTH(sizeof(Protocol::Packet::DataHeader)) , PACKET_DATA_LENGTH(driver->getMaxPayloadSize() - TRANSPORT_HEADER_LENGTH) - , id(sender->transportId, sender->nextMessageSequenceNumber++) + , id(sender->transportId, messageId) , bucket(sender->messageBuckets.getBucket(id)) , source{driver->getLocalAddress(), sourcePort} , destination() @@ -401,7 +402,7 @@ class Sender { Policy::Manager* const policyManager; /// The sequence number to be used for the next Message. - volatile uint64_t nextMessageSequenceNumber; + std::atomic nextMessageSequenceNumber; /// The maximum number of bytes that should be queued in the Driver. const uint32_t DRIVER_QUEUED_BYTE_LIMIT; diff --git a/src/Timeout.h b/src/Timeout.h index 7d17ae0..56710f4 100644 --- a/src/Timeout.h +++ b/src/Timeout.h @@ -100,16 +100,12 @@ class TimeoutManager { * * @param timeout * The Timeout that should be scheduled. - * @param now - * Optionally provided "current" timestamp cycle time. Used to avoid - * unnecessary calls to PerfUtils::Cycles::rdtsc() if the current time - * is already available to the caller. */ - inline void setTimeout(Timeout* timeout, - uint64_t now = PerfUtils::Cycles::rdtsc()) + inline void setTimeout(Timeout* timeout) { list.remove(&timeout->node); - timeout->expirationCycleTime = now + timeoutIntervalCycles; + timeout->expirationCycleTime = + PerfUtils::Cycles::rdtsc() + timeoutIntervalCycles; list.push_back(&timeout->node); nextTimeout.store(list.front().expirationCycleTime, std::memory_order_relaxed); From ca0a7cb2debb11abfbf36b26a64b994741aa029b Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Tue, 10 Nov 2020 17:07:03 -0800 Subject: [PATCH 32/33] fixed the race conditions in Receiver due to reduced coverage of the bucket mutex --- src/Receiver.cc | 206 +++++++++++++++++++++++++----------------------- src/Receiver.h | 35 +++----- 2 files changed, 116 insertions(+), 125 deletions(-) diff --git a/src/Receiver.cc b/src/Receiver.cc index d04b0c1..fc1cd80 100644 --- a/src/Receiver.cc +++ b/src/Receiver.cc @@ -74,12 +74,15 @@ Receiver::~Receiver() // destructor should be sufficient. However, for clarity and debugging // purpose, we decided to write the cleanup procedure explicitly anyway. + // Remove all completed Messages that are still inside the receive queue. + receivedMessages.queue.clear(); + // Destruct all MessageBucket's and the Messages within. for (auto& bucket : messageBuckets.buckets) { // Intrusive::List is not responsible for destructing its elements; // it must be done manually. for (auto& message : bucket.messages) { - messageAllocator.destroy(&message); + dropMessage(&message); } assert(bucket.resendTimeouts.empty()); } @@ -88,11 +91,6 @@ Receiver::~Receiver() // Destruct all Peer's. Peer's must be removed from scheduledPeers first. scheduledPeers.clear(); peerTable.clear(); - - // Destruct all completed Messages that are not yet delivered. - for (auto& message : receivedMessages.queue) { - messageAllocator.destroy(&message); - } } /** @@ -110,58 +108,44 @@ Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) static_cast(packet->payload); Protocol::MessageId id = header->common.messageId; - bool needSchedule = false; - bool messageComplete; MessageBucket* bucket = messageBuckets.getBucket(id); - Message* message; - { - // Scoped critical section guarded by MessageBucket::mutex; this ensures - // that the bucket mutex is dropped before acquiring the schedulerMutex. - SpinLock::Lock lock_bucket(bucket->mutex); - message = bucket->findMessage(id, lock_bucket); - if (message == nullptr) { - // New message - int messageLength = header->totalLength; - int numUnscheduledPackets = header->unscheduledIndexLimit; - SocketAddress srcAddress = { - .ip = sourceIp, .port = be16toh(header->common.prefix.sport)}; - message = - messageAllocator.construct(this, driver, messageLength, id, - srcAddress, numUnscheduledPackets); - Perf::counters.allocated_rx_messages.add(1); - - // Start tracking the message. - bucket->messages.push_back(&message->bucketNode); - policyManager->signalNewMessage( - message->source.ip, header->policyVersion, header->totalLength); - - if (message->scheduled) { - // Don't schedule the message while holding the bucket mutex. - needSchedule = true; - } - } - - // Add the packet, but don't copy the payload yet. - if (message->received.test(header->index)) { - // Must be a duplicate packet; drop it. - return; - } - message->received.set(header->index); - message->numPackets++; - messageComplete = (message->numPackets == message->numExpectedPackets); + SpinLock::UniqueLock lock_bucket(bucket->mutex); + Message* message = bucket->findMessage(id, lock_bucket); + if (message == nullptr) { + // New message + int messageLength = header->totalLength; + int numUnscheduledPackets = header->unscheduledIndexLimit; + SocketAddress srcAddress = { + .ip = sourceIp, .port = be16toh(header->common.prefix.sport)}; + message = messageAllocator.construct(this, driver, messageLength, id, + srcAddress, numUnscheduledPackets); + Perf::counters.allocated_rx_messages.add(1); + + // Start tracking the message. + bucket->messages.push_back(&message->bucketNode); + policyManager->signalNewMessage( + message->source.ip, header->policyVersion, header->totalLength); } - // Copy the payload into the message buffer. - std::memcpy(message->buffer + header->index * PACKET_DATA_LENGTH, - static_cast(packet->payload) + TRANSPORT_HEADER_LENGTH, - packet->length - TRANSPORT_HEADER_LENGTH); - // Sanity checks assert(message->source.ip == sourceIp); assert(message->source.port == be16toh(header->common.prefix.sport)); assert(message->messageLength == Util::downCast(header->totalLength)); - if (needSchedule) { + if (message->received.test(header->index)) { + // Must be a duplicate packet; drop it. + return; + } else { + // Add the packet and copy the payload. + message->received.set(header->index); + message->numPackets++; + std::memcpy( + message->buffer + header->index * PACKET_DATA_LENGTH, + static_cast(packet->payload) + TRANSPORT_HEADER_LENGTH, + packet->length - TRANSPORT_HEADER_LENGTH); + } + + if (message->scheduled) { // A new Message needs to be entered into the scheduler. SpinLock::Lock lock_scheduler(schedulerMutex); schedule(message, lock_scheduler); @@ -179,19 +163,23 @@ Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) } } - // Complete the message if all packets have been received. - if (messageComplete) { - message->state.store(Message::State::COMPLETED, - std::memory_order_release); + if (message->numPackets == message->numExpectedPackets) { + // All message packets have been received. + message->completed.store(true, std::memory_order_release); if (message->needTimeout) { - SpinLock::Lock lock(bucket->mutex); bucket->resendTimeouts.cancelTimeout(&message->resendTimeout); } + lock_bucket.unlock(); // Deliver the message to the user of the transport. SpinLock::Lock lock_received_messages(receivedMessages.mutex); receivedMessages.queue.push_back(&message->receivedMessageNode); Perf::counters.received_rx_messages.add(1); + } else if (message->needTimeout) { + // Receiving a new packet means the message is still active so it + // shouldn't time out until a while later. + message->numResendTimeouts = 0; + bucket->resendTimeouts.setTimeout(&message->resendTimeout); } } @@ -209,12 +197,13 @@ Receiver::handleBusyPacket(Driver::Packet* packet) Protocol::MessageId id = header->common.messageId; MessageBucket* bucket = messageBuckets.getBucket(id); - SpinLock::Lock lock_bucket(bucket->mutex); + SpinLock::UniqueLock lock_bucket(bucket->mutex); Message* message = bucket->findMessage(id, lock_bucket); if (message != nullptr) { // Sender has replied BUSY to our RESEND request; consider this message // still active. - if (message->getState() == Message::State::IN_PROGRESS) { + if (message->needTimeout && !message->completed) { + message->numResendTimeouts = 0; bucket->resendTimeouts.setTimeout(&message->resendTimeout); } } @@ -235,13 +224,20 @@ Receiver::handlePingPacket(Driver::Packet* packet, IpAddress sourceIp) static_cast(packet->payload); Protocol::MessageId id = header->common.messageId; + // FIXME(Yilong): after making the transport purely message-based, we need + // to send back an ACK packet here if this message is complete + MessageBucket* bucket = messageBuckets.getBucket(id); - Message* message; - { - SpinLock::Lock lock_bucket(bucket->mutex); - message = bucket->findMessage(id, lock_bucket); - } + SpinLock::UniqueLock lock_bucket(bucket->mutex); + Message* message = bucket->findMessage(id, lock_bucket); if (message != nullptr) { + // Don't (re-)insert a timeout unless necessary. + if (message->needTimeout && !message->completed) { + // Sender is checking on this message; consider it still active. + message->numResendTimeouts = 0; + bucket->resendTimeouts.setTimeout(&message->resendTimeout); + } + // We are here either because a GRANT got lost, or we haven't issued a // GRANT in along time. Send out the latest GRANT if one exists or just // an "empty" GRANT to let the Sender know we are aware of the message. @@ -261,15 +257,17 @@ Receiver::handlePingPacket(Driver::Packet* packet, IpAddress sourceIp) priority = info->priority; } + lock_bucket.unlock(); Perf::counters.tx_grant_pkts.add(1); ControlPacket::send( - driver, message->source.ip, message->id, bytesGranted, priority); + driver, sourceIp, id, bytesGranted, priority); } else { // We are here because we have no knowledge of the message the Sender is // asking about. Reply UNKNOWN so the Sender can react accordingly. + lock_bucket.unlock(); Perf::counters.tx_unknown_pkts.add(1); - ControlPacket::send( - driver, sourceIp, header->common.messageId); + ControlPacket::send(driver, sourceIp, + id); } } @@ -324,40 +322,14 @@ Receiver::checkTimeouts() /** * Destruct a Message. - * - * This method will detach the message from the transport and release all - * contained Packet objects. */ Receiver::Message::~Message() { - Perf::counters.destroyed_rx_messages.add(1); - Receiver* receiver = bucket->receiver; - - // Unschedule the message if it is still scheduled (i.e. still linked to a - // scheduled peer). - if (scheduled) { - SpinLock::Lock lock_scheduler(receiver->schedulerMutex); - ScheduledMessageInfo* info = &scheduledMessageInfo; - if (info->peer != nullptr) { - receiver->unschedule(this, lock_scheduler); - } - } - - // Remove this message from the other data structures of the Receiver. - { - SpinLock::Lock bucket_lock(bucket->mutex); - bucket->resendTimeouts.cancelTimeout(&resendTimeout); - bucket->messages.remove(&bucketNode); - - SpinLock::Lock receive_lock(receiver->receivedMessages.mutex); - receiver->receivedMessages.queue.remove(&receivedMessageNode); - } - // Release the external buffer, if any. if (buffer != internalBuffer) { MessageBuffer* externalBuf = (MessageBuffer*)buffer; - receiver->externalBuffers.destroy(externalBuf); + bucket->receiver->externalBuffers.destroy(externalBuf); } } @@ -395,7 +367,40 @@ Receiver::Message::length() const void Receiver::Message::release() { - bucket->receiver->messageAllocator.destroy(this); + bucket->receiver->dropMessage(this); +} + +/** + * Drop a message because it's no longer needed (either the application released + * the message or a timeout occurred). + * + * @param message + * Message which will be detached from the transport and destroyed. + */ +void +Receiver::dropMessage(Receiver::Message* message) +{ + // Unschedule the message if it is still scheduled (i.e. still linked to a + // scheduled peer). + if (message->scheduled) { + SpinLock::Lock lock_scheduler(schedulerMutex); + ScheduledMessageInfo* info = &message->scheduledMessageInfo; + if (info->peer != nullptr) { + unschedule(message, lock_scheduler); + } + } + + // Remove this message from the other data structures of the Receiver. + MessageBucket* bucket = message->bucket; + { + SpinLock::Lock bucket_lock(bucket->mutex); + bucket->resendTimeouts.cancelTimeout(&message->resendTimeout); + bucket->messages.remove(&message->bucketNode); + } + + // Destroy the message. + messageAllocator.destroy(message); + Perf::counters.destroyed_rx_messages.add(1); } /** @@ -416,7 +421,7 @@ Receiver::checkResendTimeouts(uint64_t now, MessageBucket* bucket) } while (true) { - SpinLock::Lock lock_bucket(bucket->mutex); + SpinLock::UniqueLock lock_bucket(bucket->mutex); // No remaining timeouts. if (bucket->resendTimeouts.empty()) { @@ -430,11 +435,12 @@ Receiver::checkResendTimeouts(uint64_t now, MessageBucket* bucket) } // Found expired timeout. - assert(message->getState() == Message::State::IN_PROGRESS); + assert(!message->completed); message->numResendTimeouts++; if (message->numResendTimeouts >= MESSAGE_TIMEOUT_INTERVALS) { // Message timed out before being fully received; drop the message. - messageAllocator.destroy(message); + lock_bucket.unlock(); + dropMessage(message); continue; } else { bucket->resendTimeouts.setTimeout(&message->resendTimeout); @@ -541,11 +547,11 @@ Receiver::trySendGrants() int slot = 0; while (it != scheduledPeers.end() && slot < policy.degreeOvercommitment) { assert(!it->scheduledMessages.empty()); + // No need to acquire the bucket mutex here because we are only going to + // access the const members of a Message; besides, the message can't get + // destroyed while we are holding the schedulerMutex. Message* message = &it->scheduledMessages.front(); ScheduledMessageInfo* info = &message->scheduledMessageInfo; - // Access message const variables without message mutex. - const Protocol::MessageId id = message->id; - const IpAddress sourceIp = message->source.ip; // Recalculate message priority info->priority = @@ -561,7 +567,7 @@ Receiver::trySendGrants() info->bytesGranted = newGrantLimit; Perf::counters.tx_grant_pkts.add(1); ControlPacket::send( - driver, sourceIp, id, + driver, message->source.ip, message->id, Util::downCast(info->bytesGranted), info->priority); Perf::counters.active_cycles.add(timer.split()); } diff --git a/src/Receiver.h b/src/Receiver.h index 2085bc4..8955c3f 100644 --- a/src/Receiver.h +++ b/src/Receiver.h @@ -128,15 +128,6 @@ class Receiver { } }; - /** - * Defines the possible states of this Message. - */ - enum class State { - IN_PROGRESS, //< Receiver is in the process of receiving this - // message. - COMPLETED, //< Receiver has received the entire message. - }; - explicit Message(Receiver* receiver, Driver* driver, int messageLength, Protocol::MessageId id, SocketAddress source, int numUnscheduledPackets) @@ -155,7 +146,7 @@ class Receiver { ? internalBuffer : receiver->externalBuffers.construct()->raw) // No need to zero-out internalBuffer. - , state(Message::State::IN_PROGRESS) + , completed(false) , bucketNode(this) , receivedMessageNode(this) , needTimeout(numExpectedPackets > 1) @@ -172,14 +163,6 @@ class Receiver { size_t length() const override; void release() override; - /** - * Return the current state of this message. - */ - State getState() const - { - return state.load(std::memory_order_acquire); - } - private: /// Driver from which packets were received and to which they should be /// returned when this message is no longer needed. @@ -221,8 +204,9 @@ class Receiver { /// Internal memory buffer used to store messages within 2KB. char internalBuffer[2048]; - /// This message's current state. - std::atomic state; + /// Current state of the message. True means the entire message has been + /// received; otherwise, this message is still in progress. + std::atomic completed; /// Intrusive structure used by the Receiver to hold on to this Message /// in one of the Receiver's MessageBuckets. Access to this structure @@ -304,15 +288,14 @@ class Receiver { * @param msgId * MessageId of the Message to be found. * @param lock - * Reminder to hold the MessageBucket::mutex during this call. (Not - * used) + * Reminder to hold the MessageBucket::mutex during this call. * @return * A pointer to the Message if found; nullptr, otherwise. */ Message* findMessage(const Protocol::MessageId& msgId, - const SpinLock::Lock& lock) + const SpinLock::UniqueLock& lock) { - (void)lock; + assert(lock.owns_lock()); for (auto& it : messages) { if (it.id == msgId) { return ⁢ @@ -324,7 +307,8 @@ class Receiver { /// The Receiver that owns this bucket. Receiver* const receiver; - /// Mutex protecting the contents of this bucket. + /// Mutex protecting the contents of this bucket. This includes messages + /// within the bucket. SpinLock mutex; /// Collection of inbound messages @@ -437,6 +421,7 @@ class Receiver { Intrusive::List::Node scheduledPeerNode; }; + void dropMessage(Receiver::Message* message); void checkResendTimeouts(uint64_t now, MessageBucket* bucket); void trySendGrants(); void schedule(Message* message, const SpinLock::Lock& lock); From f3799b9ffccfef47b9204b6026be7260594e3ba4 Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Wed, 11 Nov 2020 22:17:56 -0800 Subject: [PATCH 33/33] fixed race conditions on the Sender side --- src/Receiver.cc | 12 +-- src/Receiver.h | 6 +- src/Sender.cc | 260 +++++++++++++++++++++++++++--------------------- src/Sender.h | 28 ++++-- 4 files changed, 170 insertions(+), 136 deletions(-) diff --git a/src/Receiver.cc b/src/Receiver.cc index fc1cd80..402a39d 100644 --- a/src/Receiver.cc +++ b/src/Receiver.cc @@ -599,9 +599,9 @@ Receiver::trySendGrants() * Reminder to hold the Receiver::schedulerMutex during this call. */ void -Receiver::schedule(Receiver::Message* message, const SpinLock::Lock& lock) +Receiver::schedule(Receiver::Message* message, + [[maybe_unused]] SpinLock::Lock& lock) { - (void)lock; ScheduledMessageInfo* info = &message->scheduledMessageInfo; Peer* peer = &peerTable[message->source.ip]; // Insert the Message @@ -636,9 +636,9 @@ Receiver::schedule(Receiver::Message* message, const SpinLock::Lock& lock) * Reminder to hold the Receiver::schedulerMutex during this call. */ void -Receiver::unschedule(Receiver::Message* message, const SpinLock::Lock& lock) +Receiver::unschedule(Receiver::Message* message, + [[maybe_unused]] SpinLock::Lock& lock) { - (void)lock; ScheduledMessageInfo* info = &message->scheduledMessageInfo; assert(info->peer != nullptr); Peer* peer = info->peer; @@ -681,9 +681,9 @@ Receiver::unschedule(Receiver::Message* message, const SpinLock::Lock& lock) * Reminder to hold the Receiver::schedulerMutex during this call. */ void -Receiver::updateSchedule(Receiver::Message* message, const SpinLock::Lock& lock) +Receiver::updateSchedule(Receiver::Message* message, + [[maybe_unused]] SpinLock::Lock& lock) { - (void)lock; ScheduledMessageInfo* info = &message->scheduledMessageInfo; assert(info->peer != nullptr); assert(info->peer->scheduledMessages.contains(&info->scheduledMessageNode)); diff --git a/src/Receiver.h b/src/Receiver.h index 8955c3f..3630d75 100644 --- a/src/Receiver.h +++ b/src/Receiver.h @@ -424,9 +424,9 @@ class Receiver { void dropMessage(Receiver::Message* message); void checkResendTimeouts(uint64_t now, MessageBucket* bucket); void trySendGrants(); - void schedule(Message* message, const SpinLock::Lock& lock); - void unschedule(Message* message, const SpinLock::Lock& lock); - void updateSchedule(Message* message, const SpinLock::Lock& lock); + void schedule(Message* message, SpinLock::Lock& lock); + void unschedule(Message* message, SpinLock::Lock& lock); + void updateSchedule(Message* message, SpinLock::Lock& lock); /// Driver with which all packets will be sent and received. This driver /// is chosen by the Transport that owns this Receiver. diff --git a/src/Sender.cc b/src/Sender.cc index 8784237..646bc25 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -57,6 +57,7 @@ Sender::Sender(uint64_t transportId, Driver* driver, , sendQueue() , sending() , sendReady(false) + , sentMessages() , nextBucketIndex(0) , messageAllocator() {} @@ -73,37 +74,6 @@ Sender::allocMessage(uint16_t sourcePort) return messageAllocator.construct(this, messageId, sourcePort); } -/** - * Execute the common processing logic that is shared among all incoming control - * packets. - * - * @param packet - * Incoming control packet to be processed. - * @param resetTimeout - * True if we should update the timeouts in response to the packet. - * @return - * Pointer to the message targeted by the incoming packet, or nullptr if no - * matching message can be found. - */ -Sender::Message* -Sender::handleIncomingPacket(Driver::Packet* packet, bool resetTimeout) -{ - // Find the message bucket - Protocol::Packet::CommonHeader* commonHeader = - static_cast(packet->payload); - Protocol::MessageId msgId = commonHeader->messageId; - MessageBucket* bucket = messageBuckets.getBucket(msgId); - - // Find the target message and update its expiration time - SpinLock::Lock lock(bucket->mutex); - Message* message = bucket->findMessage(msgId, lock); - if (resetTimeout) { - message->numPingTimeouts = 0; - bucket->pingTimeouts.setTimeout(&message->pingTimeout); - } - return message; -} - /** * Process an incoming DONE packet. * @@ -113,19 +83,24 @@ Sender::handleIncomingPacket(Driver::Packet* packet, bool resetTimeout) void Sender::handleDonePacket(Driver::Packet* packet) { - Message* message = handleIncomingPacket(packet, false); + Protocol::Packet::CommonHeader* header = + static_cast(packet->payload); + Protocol::MessageId msgId = header->messageId; + + MessageBucket* bucket = messageBuckets.getBucket(msgId); + SpinLock::Lock lock(bucket->mutex); + Message* message = bucket->findMessage(msgId, lock); if (message == nullptr) { // No message for this DONE packet; must be old. return; } // Process DONE packet - Protocol::MessageId msgId = message->id; OutMessage::Status status = message->getStatus(); switch (status) { case OutMessage::Status::SENT: // Expected behavior - message->setStatus(OutMessage::Status::COMPLETED, false); + message->setStatus(OutMessage::Status::COMPLETED, false, lock); break; case OutMessage::Status::CANCELED: // Canceled by the the application; just ignore the DONE. @@ -173,7 +148,15 @@ Sender::handleDonePacket(Driver::Packet* packet) void Sender::handleResendPacket(Driver::Packet* packet) { - Message* message = handleIncomingPacket(packet, true); + Protocol::Packet::ResendHeader* header = + static_cast(packet->payload); + Protocol::MessageId msgId = header->common.messageId; + int index = header->index; + int resendEnd = index + header->num; + + MessageBucket* bucket = messageBuckets.getBucket(msgId); + SpinLock::Lock lock(bucket->mutex); + Message* message = bucket->findMessage(msgId, lock); // Check for unexpected conditions if (message == nullptr) { @@ -190,10 +173,8 @@ Sender::handleResendPacket(Driver::Packet* packet) return; } - Protocol::Packet::ResendHeader* resendHeader = - static_cast(packet->payload); - int index = resendHeader->index; - int resendEnd = index + resendHeader->num; + message->numPingTimeouts = 0; + bucket->pingTimeouts.setTimeout(&message->pingTimeout); // Check if RESEND request is out of range. if (index >= message->numPackets || resendEnd > message->numPackets) { @@ -214,7 +195,7 @@ Sender::handleResendPacket(Driver::Packet* packet) // Note that the priority of messages under the unscheduled byte limit // will never be overridden since the resend index will not exceed the // preset packetsGranted. - info->priority = resendHeader->priority; + info->priority = header->priority; sendReady.store(true); } @@ -253,11 +234,19 @@ Sender::handleResendPacket(Driver::Packet* packet) void Sender::handleGrantPacket(Driver::Packet* packet) { - Message* message = handleIncomingPacket(packet, true); + Protocol::Packet::GrantHeader* header = + static_cast(packet->payload); + Protocol::MessageId msgId = header->common.messageId; + + MessageBucket* bucket = messageBuckets.getBucket(msgId); + SpinLock::Lock lock(bucket->mutex); + Message* message = bucket->findMessage(msgId, lock); if (message == nullptr) { // No message for this grant; grant must be old. return; } + message->numPingTimeouts = 0; + bucket->pingTimeouts.setTimeout(&message->pingTimeout); Protocol::Packet::GrantHeader* grantHeader = static_cast(packet->payload); @@ -302,7 +291,14 @@ Sender::handleGrantPacket(Driver::Packet* packet) void Sender::handleUnknownPacket(Driver::Packet* packet) { - Message* message = handleIncomingPacket(packet, false); + Protocol::Packet::UnknownHeader* header = + static_cast(packet->payload); + Protocol::MessageId msgId = header->common.messageId; + + MessageBucket* bucket = messageBuckets.getBucket(msgId); + SpinLock::Lock lock(bucket->mutex); + Message* message = bucket->findMessage(msgId, lock); + if (message == nullptr) { // No message was found. return; @@ -318,10 +314,10 @@ Sender::handleUnknownPacket(Driver::Packet* packet) // Option: NO_RETRY // Either the Message or the DONE packet was lost; consider the message // failed since the application asked for the message not to be retried. - message->setStatus(OutMessage::Status::FAILED, true); + message->setStatus(OutMessage::Status::FAILED, true, lock); } else { // Message isn't done yet so we will restart sending the message. - startMessage(message, true); + startMessage(message, true, lock); } } @@ -334,18 +330,23 @@ Sender::handleUnknownPacket(Driver::Packet* packet) void Sender::handleErrorPacket(Driver::Packet* packet) { - Message* message = handleIncomingPacket(packet, false); + Protocol::Packet::ErrorHeader* header = + static_cast(packet->payload); + Protocol::MessageId msgId = header->common.messageId; + + MessageBucket* bucket = messageBuckets.getBucket(msgId); + SpinLock::Lock lock(bucket->mutex); + Message* message = bucket->findMessage(msgId, lock); if (message == nullptr) { // No message for this ERROR packet; must be old. return; } - Protocol::MessageId msgId = message->id; OutMessage::Status status = message->getStatus(); switch (status) { case OutMessage::Status::SENT: // Message was sent and a failure notification was received. - message->setStatus(OutMessage::Status::FAILED, false); + message->setStatus(OutMessage::Status::FAILED, false, lock); break; case OutMessage::Status::CANCELED: // Canceled by the the application; just ignore the ERROR. @@ -419,19 +420,6 @@ Sender::checkTimeouts() */ Sender::Message::~Message() { - Perf::counters.destroyed_tx_messages.add(1); - - // We assume that this message has been unlinked from the sendQueue before - // this method is invoked. - assert(getStatus() != OutMessage::Status::IN_PROGRESS); - - // Remove this message from the other data structures of the Sender. - { - SpinLock::Lock bucket_lock(bucket->mutex); - bucket->pingTimeouts.cancelTimeout(&pingTimeout); - bucket->messages.remove(&bucketNode); - } - // Sender message must be contiguous driver->releasePackets(packets, numPackets); } @@ -480,7 +468,8 @@ Sender::Message::append(const void* source, size_t count) void Sender::Message::cancel() { - setStatus(OutMessage::Status::CANCELED, true); + SpinLock::Lock lock(bucket->mutex); + setStatus(OutMessage::Status::CANCELED, true, lock); } /** @@ -501,41 +490,42 @@ Sender::Message::getStatus() const * The new status. * @param deschedule * True if we should remove this message from the send queue. + * @param lock + * Reminder to hold the MessageBucket::mutex during this call. * - * Note: special care must be taken when calling this method to avoid deadlocks + * Note: the caller should not hold Sender::queueMutex when calling this method * because our spinlock is not reentrant. */ void -Sender::Message::setStatus(Status newStatus, bool deschedule) +Sender::Message::setStatus(Status newStatus, bool deschedule, + [[maybe_unused]] SpinLock::Lock& lock) { // Whether to remove the message from the send queue depends on more than // just the message status; only the caller has enough information to make // the decision. if (deschedule) { - // TODO: with jumbo packets, single-packet messages may also be paced - // An outgoing message is on the sendQueue iff. it's still in progress // and subject to the sender's packet pacing mechanism; test this // condition first to reduce the expensive locking operation - if ((numPackets > 1) && - (getStatus() == OutMessage::Status::IN_PROGRESS)) { + if (throttled && (getStatus() == OutMessage::Status::IN_PROGRESS)) { SpinLock::Lock lock_queue(sender->queueMutex); sender->sendQueue.remove(&queuedMessageInfo.sendQueueNode); } } state.store(newStatus, std::memory_order_release); - - // Cancel the timeouts if the message reaches an end state. - if (newStatus == OutMessage::Status::CANCELED || - newStatus == OutMessage::Status::COMPLETED || - newStatus == OutMessage::Status::FAILED) { - SpinLock::Lock lock(bucket->mutex); - bucket->pingTimeouts.cancelTimeout(&pingTimeout); + bool endState = (newStatus == OutMessage::Status::CANCELED || + newStatus == OutMessage::Status::COMPLETED || + newStatus == OutMessage::Status::FAILED); + if (endState) { + if (!held.load(std::memory_order_acquire)) { + // Ok to delete now that the message has reached an end state. + sender->dropMessage(this, lock); + } else { + // Cancel the timeouts; the message will be dropped upon release(). + bucket->pingTimeouts.cancelTimeout(&pingTimeout); + } } - - // This method is not the right place to remove the message from the bucket; - // it's the job of Message::release(). } /** @@ -587,9 +577,10 @@ Sender::Message::release() if (getStatus() != OutMessage::Status::IN_PROGRESS) { // Ok to delete immediately since we don't have to wait for the message // to be sent. - sender->messageAllocator.destroy(this); + SpinLock::Lock lock(bucket->mutex); + sender->dropMessage(this, lock); } else { - // Defer deletion and wait for the message to be SENT. + // Defer deletion and wait for the message to be SENT at least. } Perf::counters.released_tx_messages.add(1); } @@ -644,8 +635,12 @@ Sender::Message::send(SocketAddress destination, this->destination = destination; this->options = options; + // All single-packet messages currently bypass our packet pacing mechanism. + throttled = (numPackets > 1); + // Kick start the transmission. - sender->startMessage(this, false); + SpinLock::Lock lock(bucket->mutex); + sender->startMessage(this, false, lock); } /** @@ -690,6 +685,31 @@ Sender::Message::getOrAllocPacket(size_t index) return packets[index]; } +/** + * Drop a message that is no longer needed (released by the application). + * + * @param message + * Message which will be detached from the transport and destroyed. + * @param lock + * Reminder to hold the MessageBucket::mutex during this call. + */ +void +Sender::dropMessage(Sender::Message* message, + [[maybe_unused]] SpinLock::Lock& lock) +{ + // We assume that this message has been unlinked from the sendQueue before + // this method is invoked. + assert(message->getStatus() != OutMessage::Status::IN_PROGRESS); + + // Remove this message from the other data structures of the Sender. + MessageBucket* bucket = message->bucket; + bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); + bucket->messages.remove(&message->bucketNode); + + Perf::counters.destroyed_tx_messages.add(1); + messageAllocator.destroy(message); +} + /** * (Re)start the transmission of an outgoing message. * @@ -698,13 +718,16 @@ Sender::Message::getOrAllocPacket(size_t index) * @param restart * False if the message is new to the transport; true means the message is * restarted by the transport. + * @param lock + * Reminder to hold the MessageBucket::mutex during this call. */ void -Sender::startMessage(Sender::Message* message, bool restart) +Sender::startMessage(Sender::Message* message, bool restart, + [[maybe_unused]] SpinLock::Lock& lock) { // If we are restarting an existing message, make sure it's not in the // sendQueue before making any changes to it. - message->setStatus(OutMessage::Status::IN_PROGRESS, restart); + message->setStatus(OutMessage::Status::IN_PROGRESS, restart, lock); // Get the current policy for unscheduled bytes. Policy::Unscheduled policy = policyManager->getUnscheduledPolicy( @@ -730,7 +753,6 @@ Sender::startMessage(Sender::Message* message, bool restart) // Start tracking the new message MessageBucket* bucket = message->bucket; - SpinLock::Lock lock(bucket->mutex); assert(!bucket->messages.contains(&message->bucketNode)); bucket->messages.push_back(&message->bucketNode); } else { @@ -748,15 +770,15 @@ Sender::startMessage(Sender::Message* message, bool restart) // Kick start the message. assert(message->numPackets > 0); bool needTimeouts = true; - if (message->numPackets == 1) { - // If there is only one packet in the message, send it right away. + if (!message->throttled) { + // The message is too small to be scheduled, send it right away. Driver::Packet* dataPacket = message->getPacket(0); assert(dataPacket != nullptr); Perf::counters.tx_data_pkts.add(1); Perf::counters.tx_bytes.add(dataPacket->length); driver->sendPacket(dataPacket, message->destination.ip, policy.priority); - message->setStatus(OutMessage::Status::SENT, false); + message->setStatus(OutMessage::Status::SENT, false, lock); // This message must be still be held by the application since the // message still exists (it would have been removed when dropped // because single packet messages are never IN_PROGRESS). Assuming @@ -787,10 +809,8 @@ Sender::startMessage(Sender::Message* message, bool restart) // Initialize the timeouts if (needTimeouts) { - MessageBucket* bucket = message->bucket; - SpinLock::Lock lock(bucket->mutex); message->numPingTimeouts = 0; - bucket->pingTimeouts.setTimeout(&message->pingTimeout); + message->bucket->pingTimeouts.setTimeout(&message->pingTimeout); } } @@ -813,7 +833,7 @@ Sender::checkPingTimeouts(uint64_t now, MessageBucket* bucket) } while (true) { - SpinLock::UniqueLock bucket_lock(bucket->mutex); + SpinLock::Lock lock_bucket(bucket->mutex); // No remaining timeouts. if (bucket->pingTimeouts.empty()) { break; @@ -835,19 +855,14 @@ Sender::checkPingTimeouts(uint64_t now, MessageBucket* bucket) message->numPingTimeouts++; if (message->numPingTimeouts >= MESSAGE_TIMEOUT_INTERVALS) { // Found expired message. - - // Release the bucket mutex to avoid deadlock in setStatus(). - bucket_lock.unlock(); - message->setStatus(OutMessage::Status::FAILED, true); + message->setStatus(OutMessage::Status::FAILED, true, + lock_bucket); continue; } else { bucket->pingTimeouts.setTimeout(&message->pingTimeout); } } - // Release the bucket mutex to follow the locking order constraint. - bucket_lock.unlock(); - // Check if sender still has packets to send if (status == OutMessage::Status::IN_PROGRESS) { QueuedMessageInfo* info = &message->queuedMessageInfo; @@ -893,13 +908,17 @@ Sender::trySend() * Each time this method is called we will try to send enough packet to keep * the NIC busy but not too many as to cause excessive queue in the NIC. */ - SpinLock::Lock lock_queue(queueMutex); + SpinLock::UniqueLock lock_queue(queueMutex); uint32_t queuedBytesEstimate = driver->getQueuedBytes(); // Optimistically assume we will finish sending every granted packet this // round; we will set again sendReady if it turns out we don't finish. sendReady = false; auto it = sendQueue.begin(); while (it != sendQueue.end()) { + // No need to acquire the bucket mutex here because we are only going to + // access the const members of a Message; besides, the message can't get + // destroyed concurrently because whoever needs to destroy the message + // will need to acquire queueMutex first. Message& message = *it; assert(message.getStatus() == OutMessage::Status::IN_PROGRESS); QueuedMessageInfo* info = &message.queuedMessageInfo; @@ -929,25 +948,8 @@ Sender::trySend() } if (info->packetsSent >= message.numPackets) { // We have finished sending the message. - - // Note: instead of relying on setStatus(), manually deschedule this - // message since we are already holding the queueMutex (out spinlock - // is not reentrant). - assert(message.numPackets > 1 && - message.getStatus() == OutMessage::Status::IN_PROGRESS); it = sendQueue.remove(it); - message.setStatus(OutMessage::Status::SENT, false); - - if (!message.held.load(std::memory_order_acquire)) { - // Ok to delete now that the message has been sent. - messageAllocator.destroy(&message); - } else if (message.options & OutMessage::Options::NO_KEEP_ALIVE) { - // No timeouts need to be checked after sending the message when - // the NO_KEEP_ALIVE option is enabled. - MessageBucket* bucket = message.bucket; - SpinLock::Lock lock_bucket(bucket->mutex); - bucket->pingTimeouts.cancelTimeout(&message.pingTimeout); - } + sentMessages.emplace_back(message.bucket, message.id); } else if (info->packetsSent >= info->packetsGranted) { // We have sent every granted packet. ++it; @@ -960,6 +962,32 @@ Sender::trySend() } sending.clear(); + // Unlock the queueMutex to process any SENT messages to ensure any bucket + // mutex is always acquired before the send queueMutex. + lock_queue.unlock(); + for (auto& [bucket, id] : sentMessages) { + SpinLock::Lock lock(bucket->mutex); + Message* message = bucket->findMessage(id, lock); + if (message == nullptr) { + // Message must have already been deleted. + continue; + } + + // The message status may change after the queueMutex is unlocked; + // ignore this message if that is the case. + if (message->getStatus() == OutMessage::Status::IN_PROGRESS) { + message->setStatus(OutMessage::Status::SENT, false, lock); + if (!message->held.load(std::memory_order_acquire)) { + // Ok to delete now that the message has been sent. + dropMessage(message, lock); + } else if (message->options & OutMessage::Options::NO_KEEP_ALIVE) { + // No timeouts need to be checked after sending the message when + // the NO_KEEP_ALIVE option is enabled. + bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); + } + } + } + if (!idle) { Perf::counters.active_cycles.add(timer.split()); } diff --git a/src/Sender.h b/src/Sender.h index db62c89..b039d51 100644 --- a/src/Sender.h +++ b/src/Sender.h @@ -62,8 +62,6 @@ class Sender { class Message; struct MessageBucket; - Message* handleIncomingPacket(Driver::Packet* packet, bool resetTimeout); - /** * Contains metadata for a Message that has been queued to be sent. */ @@ -137,6 +135,7 @@ class Sender { , start(0) , messageLength(0) , numPackets(0) + , throttled() , occupied() // packets is not initialized to reduce the work done during // construction. See Message::occupied. @@ -151,7 +150,7 @@ class Sender { void append(const void* source, size_t count) override; void cancel() override; Status getStatus() const override; - void setStatus(Status newStatus, bool deschedule); + void setStatus(Status newStatus, bool deschedule, SpinLock::Lock& lock); size_t length() const override; void prepend(const void* source, size_t count) override; void release() override; @@ -214,7 +213,12 @@ class Sender { /// constant after send() is invoked. int numPackets; - // FIXME: seems like an overkill? (e.g., packets should be added in order) + /// False if this message is so small that we are better off sending it + /// right away to reduce software overhead; otherwise, the message must + /// be put into Sender::sendQueue and transmitted in SRPT order. Must be + /// constant after send() is invoked. + bool throttled; + /// Bit array representing which entries in the _packets_ array are set. /// Used to avoid having to zero out the entire _packets_ array. Must be /// constant after send() is invoked. @@ -298,9 +302,8 @@ class Sender { * A pointer to the Message if found; nullptr, otherwise. */ Message* findMessage(const Protocol::MessageId& msgId, - const SpinLock::Lock& lock) + [[maybe_unused]] const SpinLock::Lock& lock) { - (void)lock; for (auto& it : messages) { if (it.id == msgId) { return ⁢ @@ -387,7 +390,9 @@ class Sender { Protocol::MessageId::Hasher hasher; }; - void startMessage(Sender::Message* message, bool restart); + void dropMessage(Sender::Message* message, SpinLock::Lock& lock); + void startMessage(Sender::Message* message, bool restart, + SpinLock::Lock& lock); void checkPingTimeouts(uint64_t now, MessageBucket* bucket); void trySend(); @@ -416,10 +421,7 @@ class Sender { /// Protects the sendQueue, including all member variables of its items. /// When multiple locks must be acquired, this class follows the locking /// order constraint below ("<" means "acquired before"): - /// queueMutex < MessageBucket::mutex - /// Usually, it's more natural to acquire coarser-grained locks first, - /// unless inverting the order would make the common code path simpler - /// and/or faster. + /// MessageBucket::mutex < queueMutex SpinLock queueMutex; /// A list of outbound messages that have unsent packets. Messages are kept @@ -435,6 +437,10 @@ class Sender { /// if there is work to do is more efficient. std::atomic sendReady; + /// Used to temporarily hold messages whose last packets have just been + /// sent out. Always empty unless inside trySend(). + std::vector> sentMessages; + /// The index of the next bucket in the messageBuckets::buckets array to /// process in the poll loop. The index is held in the lower order bits of /// this variable; the higher order bits should be masked off using the