From 33760888594eedd402094cf5e3a0a120c44ef4d9 Mon Sep 17 00:00:00 2001 From: Roman Gershman Date: Sat, 21 Dec 2024 22:44:01 +0200 Subject: [PATCH] chore: reimplement TlsSocket::WriteSome Allow cleaner separation between writing to ssl and flushing the ssl output buffer into a socket. Signed-off-by: Roman Gershman --- util/tls/tls_engine.cc | 93 +++++++++-------- util/tls/tls_engine.h | 19 +++- util/tls/tls_engine_test.cc | 117 +++++++++++---------- util/tls/tls_socket.cc | 203 +++++++++++++++++------------------- util/tls/tls_socket.h | 11 +- 5 files changed, 231 insertions(+), 212 deletions(-) diff --git a/util/tls/tls_engine.cc b/util/tls/tls_engine.cc index bc94e573..d17f91e2 100644 --- a/util/tls/tls_engine.cc +++ b/util/tls/tls_engine.cc @@ -33,38 +33,6 @@ static void ClearSslError() { } while (l); } -static Engine::OpResult ToOpResult(const SSL* ssl, int result, const char* location) { - DCHECK_LE(result, 0); - - unsigned long error = ERR_get_error(); - if (error != 0) { - return nonstd::make_unexpected(error); - } - - int ssl_error = SSL_get_error(ssl, result); - int io_err = errno; - - switch (ssl_error) { - case SSL_ERROR_ZERO_RETURN: - break; - case SSL_ERROR_WANT_READ: - return Engine::NEED_READ_AND_MAYBE_WRITE; - case SSL_ERROR_WANT_WRITE: - return Engine::NEED_WRITE; - case SSL_ERROR_SYSCALL: - LOG(WARNING) << "SSL syscall error " << io_err << ":" << result << " " << location; - break; - case SSL_ERROR_SSL: - LOG(WARNING) << "SSL protocol error " << io_err << ":" << result << " " << location; - break; - default: - LOG(WARNING) << "Unexpected SSL error " << io_err << ":" << result << " " << location; - break; - } - - return Engine::EOF_STREAM; -} - #define S1(x) #x #define S2(x) S1(x) #define LOCATION __FILE__ " : " S2(__LINE__) @@ -72,7 +40,7 @@ static Engine::OpResult ToOpResult(const SSL* ssl, int result, const char* locat #define RETURN_RESULT(res) \ if (res > 0) \ return res; \ - return ToOpResult(ssl_, res, LOCATION) + return ToOpResult(res, LOCATION) Engine::Engine(SSL_CTX* context) : ssl_(::SSL_new(context)) { CHECK(ssl_); @@ -135,15 +103,13 @@ void Engine::ConsumeOutputBuf(unsigned sz) { CHECK_EQ(unsigned(res), sz); } -auto Engine::WriteBuf(const Buffer& buf) -> OpResult { +unsigned Engine::WriteBuf(const Buffer& buf) { DCHECK(!buf.empty()); char* cbuf = nullptr; int res = BIO_nwrite(external_bio_, &cbuf, buf.size()); - if (res < 0) { - unsigned long error = ::ERR_get_error(); - return nonstd::make_unexpected(error); - } else if (res > 0) { + CHECK_GE(res, 0); + if (res > 0) { memcpy(cbuf, buf.data(), res); } return res; @@ -171,20 +137,25 @@ auto Engine::Handshake(HandshakeType type) -> OpResult { } auto Engine::Shutdown() -> OpResult { + if (state_ & FATAL_ERROR) + return 1; + int result = SSL_shutdown(ssl_); // See https://www.openssl.org/docs/man1.1.1/man3/SSL_shutdown.html - if (result == 0) // First step of Shutdown (close_notify) returns 0. - return result; + + if (result == 0) { // First step of Shutdown (close_notify) returns 0. + result = SSL_shutdown(ssl_); // Initiate the second step. + } RETURN_RESULT(result); } auto Engine::Write(const Buffer& buf) -> OpResult { - if (buf.empty()) - return 0; + CHECK(!buf.empty()); int sz = buf.size() < INT_MAX ? buf.size() : INT_MAX; int result = SSL_write(ssl_, buf.data(), sz); - RETURN_RESULT(result); + + RETURN_RESULT(result); // Should never return 0. } auto Engine::Read(uint8_t* dest, size_t len) -> OpResult { @@ -245,5 +216,41 @@ int SslProbeSetDefaultCALocation(SSL_CTX* ctx) { return -1; } + +auto Engine::ToOpResult(int result, const char* location) -> OpResult { + DCHECK_LE(result, 0); + + int ssl_error = SSL_get_error(ssl_, result); + unsigned long queue_error = 0; + +#define ERROR_DETAILS errno << ":" << queue_error << " " \ + << ERR_reason_error_string(queue_error) << " " << location + + switch (ssl_error) { + case SSL_ERROR_ZERO_RETURN: // graceful shutdown of TLS connection. + break; + case SSL_ERROR_WANT_READ: + return Engine::NEED_READ_AND_MAYBE_WRITE; + case SSL_ERROR_WANT_WRITE: + return Engine::NEED_WRITE; + case SSL_ERROR_SYSCALL: // fatal error in system call. + queue_error = ERR_get_error(); + VLOG(1) << "SSL syscall error " << ERROR_DETAILS; + break; + case SSL_ERROR_SSL: + state_ |= FATAL_ERROR; + queue_error = ERR_get_error(); + LOG(WARNING) << "SSL protocol error " << ERROR_DETAILS; + break; + default: + queue_error = ERR_get_error(); + state_ |= FATAL_ERROR; + LOG(WARNING) << "Unexpected SSL error " << ssl_error << " " << ERROR_DETAILS; + break; + } + + return Engine::EOF_STREAM; +} + } // namespace tls } // namespace util diff --git a/util/tls/tls_engine.h b/util/tls/tls_engine.h index 4c07bdcf..8df6b84e 100644 --- a/util/tls/tls_engine.h +++ b/util/tls/tls_engine.h @@ -42,7 +42,7 @@ class Engine { // if value == NEED_XXX then it means that it should either write data to IO and then read or just // write. In any case for non-error OpResult a caller must check OutputPending and write the // output buffer to the appropriate channel. - using OpResult = io::Result; + using OpResult = int; // Construct a new engine for the specified context. explicit Engine(SSL_CTX* context); @@ -63,14 +63,16 @@ class Engine { // Perform an SSL handshake using either SSL_connect (client-side) or // SSL_accept (server-side). + // Returns 1 if succeeded or negative opcodes to execute upon. OpResult Handshake(HandshakeType type); + // Returns 1 if succeeded or negative opcodes to execute upon. OpResult Shutdown(); - // Write bytes to the SSL session. Non-negative value - says how much was written. + // Write decrypted bytes to the SSL session. Non-negative value - says how much was written. OpResult Write(const Buffer& data); - // Read bytes from the SSL session. + // Read bytes from the SSL session (decrypted side). OpResult Read(uint8_t* dest, size_t len); //! Returns output buffer which is the read buffer of tls engine. @@ -82,10 +84,10 @@ class Engine { //! sz should be not greater than the buffer size from the last PeekOutputBuf() call. void ConsumeOutputBuf(unsigned sz); - //! Writes the buffer into input ssl buffer. + //! Writes encrypted data into input ssl buffer. //! Returns number of written bytes or the error. //! TODO: should be replaced with PeekInputBuf, memcpy, CommitInput sequence. - OpResult WriteBuf(const Buffer& buf); + unsigned WriteBuf(const Buffer& buf); // We usually use this function to write from the raw socket to SSL engine. // Returns direct reference to the input (write) buffer. This operation is not destructive. @@ -118,6 +120,13 @@ class Engine { // Perform one operation. Returns > 0 on success. using EngineOp = int (Engine::*)(void*, std::size_t); + OpResult ToOpResult(int result, const char* location); + + enum StateMask { + FATAL_ERROR = 1, + }; + + uint8_t state_ = 0; SSL* ssl_; BIO* external_bio_; }; diff --git a/util/tls/tls_engine_test.cc b/util/tls/tls_engine_test.cc index 11bde551..a1e68c74 100644 --- a/util/tls/tls_engine_test.cc +++ b/util/tls/tls_engine_test.cc @@ -11,7 +11,6 @@ #include "base/gtest.h" #include "base/logging.h" #include "util/fibers/fibers.h" -#include "util/tls/tls_socket.h" namespace util { namespace tls { @@ -50,17 +49,17 @@ SSL_CTX* CreateSslCntx() { } static void* TestMalloc(size_t num, const char* file, int line) { - VLOG(1) << "MyTestMalloc " << num << " " << file << ":" << line; + DVLOG(2) << "MyTestMalloc " << num << " " << file << ":" << line; return malloc(num); } static void* TestRealloc(void* addr, size_t num, const char* file, int line) { - VLOG(1) << "MyTestRealloc " << num << " " << file << ":" << line; + DVLOG(2) << "MyTestRealloc " << num << " " << file << ":" << line; return realloc(addr, num); } static void TestFree(void* addr, const char* file, int line) { - VLOG(1) << "TestFree " << file << ":" << line << " " << addr; + DVLOG(2) << "TestFree " << file << ":" << line << " " << addr; free(addr); } @@ -120,6 +119,7 @@ void SslStreamTest::SetUp() { srv_handshake_ = [](Engine* eng) { return eng->Handshake(Engine::SERVER); }; client_handshake_ = [](Engine* eng) { return eng->Handshake(Engine::CLIENT); }; + read_op_ = [this](Engine* eng) { return eng->Read(tmp_buf_.get(), TMP_CAPACITY); }; shutdown_op_ = [](Engine* eng) { return eng->Shutdown(); }; write_op_ = [this](Engine* eng) { @@ -129,56 +129,63 @@ void SslStreamTest::SetUp() { ERR_print_errors_fp(stderr); // Empties the queue. } +void TransmitData(Engine* src, Engine* dest, SslStreamTest::Options* opts) { + auto buffer = src->PeekOutputBuf(); + VLOG(1) << opts->name << " wrote " << buffer.size() << " bytes"; + CHECK(!buffer.empty()); + + if (opts->mutate_indx) { + uint8_t* mem = const_cast(buffer.data()); + mem[opts->mutate_indx % buffer.size()] = opts->mutate_val; + opts->mutate_indx = 0; + } + + if (opts->drain_output) { + src->ConsumeOutputBuf(buffer.size()); + } else { + auto dest_buf = dest->PeekInputBuf(); + CHECK_LT(buffer.size(), dest_buf.size()); + memcpy(dest_buf.data(), buffer.data(), buffer.size()); + dest->CommitInput(buffer.size()); + src->ConsumeOutputBuf(buffer.size()); + } +} + static unsigned long RunPeer(SslStreamTest::Options opts, SslStreamTest::OpCb cb, Engine* src, Engine* dest) { - unsigned input_pending = 1; while (true) { - auto op_result = cb(src); - if (!op_result) { - return op_result.error(); - } - VLOG(1) << opts.name << " OpResult: " << *op_result; - unsigned output_pending = src->OutputPending(); - if (output_pending > 0) { - auto buffer = src->PeekOutputBuf(); - VLOG(1) << opts.name << " wrote " << buffer.size() << " bytes"; - CHECK(!buffer.empty()); - - if (opts.mutate_indx) { - uint8_t* mem = const_cast(buffer.data()); - mem[opts.mutate_indx % buffer.size()] = opts.mutate_val; - opts.mutate_indx = 0; - } + Engine::OpResult op_result = cb(src); - if (opts.drain_output) { - src->ConsumeOutputBuf(buffer.size()); - } else { - auto write_result = dest->WriteBuf(buffer); - if (!write_result) { - return write_result.error(); - } - CHECK_GT(*write_result, 0); - src->ConsumeOutputBuf(*write_result); + VLOG(1) << opts.name << " OpResult: " << op_result; + if (op_result > 0) { // Successful op + if (src->OutputPending()) { + TransmitData(src, dest, &opts); } + return op_result; } - if (*op_result >= 0) { // Shutdown or empty read/write may return 0. - return 0; + if (op_result == Engine::NEED_READ_AND_MAYBE_WRITE) { + ThisFiber::Yield(); // another peer will write. + if (src->OutputPending()) + op_result = Engine::NEED_WRITE; + } + + bool dirty = false; + if (op_result == Engine::NEED_WRITE) { + dirty = true; + TransmitData(src, dest, &opts); } - if (*op_result == Engine::EOF_STREAM) { + + if (op_result == Engine::EOF_STREAM) { LOG(WARNING) << opts.name << " stream truncated"; return 0; } - if (input_pending == 0 && output_pending == 0) { // dropped connection - LOG(INFO) << "Dropped connections for " << opts.name; - - return ERR_PACK(ERR_LIB_USER, 0, ERR_R_OPERATION_FAIL); - } ThisFiber::Yield(); - - input_pending = src->InputPending(); - VLOG(1) << "Input size: " << input_pending; + if (!dirty && src->InputPending() == 0) { + return 0; + } + VLOG(1) << opts.name << ", input size: " << src->InputPending(); } } @@ -252,8 +259,8 @@ TEST_F(SslStreamTest, Handshake) { client_fb.Join(); server_fb.Join(); - ASSERT_EQ(0, cl_err); - ASSERT_EQ(0, srv_err); + ASSERT_EQ(1, cl_err); + ASSERT_EQ(1, srv_err); } TEST_F(SslStreamTest, HandshakeErrServer) { @@ -276,7 +283,7 @@ TEST_F(SslStreamTest, HandshakeErrServer) { LOG(INFO) << SSLError(cl_err); LOG(INFO) << SSLError(srv_err); - ASSERT_NE(0, cl_err); + ASSERT_EQ(0, cl_err); } TEST_F(SslStreamTest, ReadShutdown) { @@ -293,8 +300,8 @@ TEST_F(SslStreamTest, ReadShutdown) { client_fb.Join(); server_fb.Join(); - ASSERT_EQ(0, cl_err); - ASSERT_EQ(0, srv_err); + EXPECT_EQ(1, cl_err); + EXPECT_EQ(1, srv_err); client_fb = Fiber([&] { cl_err = RunPeer(client_opts_, shutdown_op_, client_engine_.get(), server_engine_.get()); @@ -305,13 +312,13 @@ TEST_F(SslStreamTest, ReadShutdown) { server_fb.Join(); client_fb.Join(); - ASSERT_EQ(0, cl_err); - ASSERT_EQ(0, srv_err); + EXPECT_EQ(0, cl_err); + EXPECT_EQ(0, srv_err); int shutdown_srv = SSL_get_shutdown(server_engine_->native_handle()); int shutdown_client = SSL_get_shutdown(client_engine_->native_handle()); - ASSERT_EQ(SSL_RECEIVED_SHUTDOWN, shutdown_srv); - ASSERT_EQ(SSL_SENT_SHUTDOWN, shutdown_client); + EXPECT_EQ(SSL_RECEIVED_SHUTDOWN, shutdown_srv); + EXPECT_EQ(SSL_SENT_SHUTDOWN, shutdown_client); client_fb = Fiber([&] { cl_err = RunPeer(client_opts_, shutdown_op_, client_engine_.get(), server_engine_.get()); @@ -323,13 +330,13 @@ TEST_F(SslStreamTest, ReadShutdown) { server_fb.Join(); client_fb.Join(); - ASSERT_EQ(0, cl_err) << SSLError(cl_err); - ASSERT_EQ(0, srv_err); + EXPECT_EQ(1, cl_err) << SSLError(cl_err); + EXPECT_EQ(1, srv_err); shutdown_srv = SSL_get_shutdown(server_engine_->native_handle()); shutdown_client = SSL_get_shutdown(client_engine_->native_handle()); - ASSERT_EQ(SSL_RECEIVED_SHUTDOWN | SSL_SENT_SHUTDOWN, shutdown_srv); - ASSERT_EQ(shutdown_client, shutdown_srv); + EXPECT_EQ(SSL_RECEIVED_SHUTDOWN | SSL_SENT_SHUTDOWN, shutdown_srv); + EXPECT_EQ(shutdown_client, shutdown_srv); } TEST_F(SslStreamTest, Write) { @@ -348,7 +355,7 @@ TEST_F(SslStreamTest, Write) { client_opts_.drain_output = true; for (size_t i = 0; i < 10; ++i) { cl_err = RunPeer(client_opts_, write_op_, client_engine_.get(), server_engine_.get()); - ASSERT_EQ(0, cl_err); + ASSERT_EQ(TMP_CAPACITY, cl_err); } } diff --git a/util/tls/tls_socket.cc b/util/tls/tls_socket.cc index f476958c..b1a24d15 100644 --- a/util/tls/tls_socket.cc +++ b/util/tls/tls_socket.cc @@ -15,8 +15,8 @@ #define VSOCK(verbosity) \ VLOG(verbosity) << "sock[" << native_handle() << "], state " << int(state_) \ - << ", write_total:" << upstream_write_ << " " << " pending output: " \ - << engine_->OutputPending() << " " + << ", write_total:" << upstream_write_ << " " \ + << " pending output: " << engine_->OutputPending() << " " namespace util { namespace tls { @@ -91,8 +91,7 @@ void TlsSocket::InitSSL(SSL_CTX* context, Buffer prefix) { engine_.reset(new Engine{context}); if (!prefix.empty()) { Engine::OpResult op_result = engine_->WriteBuf(prefix); - CHECK(op_result); - CHECK_EQ(unsigned(*op_result), prefix.size()); + CHECK_EQ(unsigned(op_result), prefix.size()); } } @@ -131,25 +130,32 @@ auto TlsSocket::Accept() -> AcceptResult { while (true) { Engine::OpResult op_result = engine_->Handshake(Engine::SERVER); - // it is important to send output (protocol errors) before we return from this function. - error_code ec = MaybeSendOutput(); - if (ec) { - VSOCK(1) << "MaybeSendOutput failed " << ec; - return make_unexpected(ec); + if (op_result == Engine::EOF_STREAM) { + return make_unexpected(make_error_code(errc::connection_aborted)); } - // now check the result of the handshake. - if (!op_result) { - return make_unexpected(SSL2Error(__LINE__, op_result.error())); - } + if (op_result == 1) { // Success. + if (VLOG_IS_ON(1)) { + const SSL_CIPHER* cipher = SSL_get_current_cipher(engine_->native_handle()); + string_view proto_version = SSL_get_version(engine_->native_handle()); - int op_val = *op_result; + // IANA mapping https://testssl.sh/openssl-iana.mapping.html + uint16_t protocol_id = SSL_CIPHER_get_protocol_id(cipher); - if (op_val >= 0) { // Shutdown or empty read/write may return 0. + LOG(INFO) << "SSL accept success, chosen " << SSL_CIPHER_get_name(cipher) << "/" + << proto_version << " " << protocol_id; + } break; } - ec = HandleOp(op_val); + // it is important to send output (protocol errors) before we return from this function. + error_code ec = MaybeSendOutput(); + if (ec) { + VSOCK(1) << "MaybeSendOutput failed " << ec; + return make_unexpected(ec); + } + + ec = HandleOp(op_result); if (ec) return make_unexpected(ec); } @@ -160,29 +166,25 @@ auto TlsSocket::Accept() -> AcceptResult { error_code TlsSocket::Connect(const endpoint_type& endpoint, std::function on_pre_connect) { DCHECK(engine_); - Engine::OpResult op_result = engine_->Handshake(Engine::HandshakeType::CLIENT); - if (!op_result) { - return std::error_code(op_result.error(), std::system_category()); - } + while (true) { + Engine::OpResult op_result = engine_->Handshake(Engine::HandshakeType::CLIENT); + if (op_result == 1) { + break; + } - // If the socket is already open, we should not call connect on it - if (!IsOpen()) { - RETURN_ON_ERROR(next_sock_->Connect(endpoint, std::move(on_pre_connect))); - } + if (op_result == Engine::EOF_STREAM) { + return make_error_code(errc::connection_refused); + } - // Flush the ssl data to the socket and run the loop that ensures handshaking converges. - int op_val = *op_result; + // If the socket is already open, we should not call connect on it + if (!IsOpen()) { + RETURN_ON_ERROR(next_sock_->Connect(endpoint, std::move(on_pre_connect))); + } - // it should guide us to write and then read. - DCHECK_EQ(op_val, Engine::NEED_READ_AND_MAYBE_WRITE); - while (op_val < 0) { - RETURN_ON_ERROR(HandleOp(op_val)); + // Flush the ssl data to the socket and run the loop that ensures handshaking converges. + int op_val = op_result; - op_result = engine_->Handshake(Engine::HandshakeType::CLIENT); - if (!op_result) { - return std::error_code(op_result.error(), std::system_category()); - } - op_val = *op_result; + RETURN_ON_ERROR(HandleOp(op_val)); } const SSL_CIPHER* cipher = SSL_get_current_cipher(engine_->native_handle()); @@ -241,7 +243,6 @@ io::Result TlsSocket::RecvMsg(const msghdr& msg, int flags) { Engine::MutableBuffer dest{reinterpret_cast(io->iov_base), io->iov_len}; size_t read_total = 0; - SpinCounter spin_count(20); while (true) { DCHECK(!dest.empty()); @@ -249,21 +250,12 @@ io::Result TlsSocket::RecvMsg(const msghdr& msg, int flags) { size_t read_len = std::min(dest.size(), size_t(INT_MAX)); Engine::OpResult op_result = engine_->Read(dest.data(), read_len); - if (!op_result) { - return make_unexpected(SSL2Error(__LINE__, op_result.error())); - } - int op_val = *op_result; - DVLOG(2) << "Engine::Read " << dest.size() << " bytes, got " << op_val; + int op_val = op_result; - if (spin_count.Check(op_val <= 0)) { - // Once every 30 seconds. - LOG_EVERY_T(WARNING, 30) << "IO loop spin limit reached. Limit: " << spin_count.Limit() - << " Spin: " << spin_count.Spins(); - } + DVLOG(2) << "Engine::Read " << dest.size() << " bytes, got " << op_val; if (op_val > 0) { - spin_count.Reset(); read_total += op_val; // I do not understand this code and what the hell I meant to do here. Seems to work @@ -304,84 +296,79 @@ io::Result TlsSocket::Recv(const io::MutableBytes& mb, int flags) { } io::Result TlsSocket::WriteSome(const iovec* ptr, uint32_t len) { - // Chosen to be sufficiently smaller than the usual MTU (1500) and a multiple of 16. - // IP - max 24 bytes. TCP - max 60 bytes. TLS - max 21 bytes. - constexpr size_t kBufferSize = 1392; - io::Result res; - size_t total_sent = 0; + while (true) { + io::Result push_res = PushToEngine(ptr, len); + if (!push_res) { + return make_unexpected(push_res.error()); + } - while (len) { - if (ptr->iov_len > kBufferSize || len == 1) { - res = SendBuffer(Engine::Buffer{reinterpret_cast(ptr->iov_base), ptr->iov_len}); - ptr++; - len--; - } else { - alignas(64) uint8_t scratch[kBufferSize]; - size_t buffered_size = 0; - while (len && (buffered_size + ptr->iov_len) <= kBufferSize) { - std::memcpy(scratch + buffered_size, ptr->iov_base, ptr->iov_len); - buffered_size += ptr->iov_len; - ptr++; - len--; + if (push_res->engine_opcode < 0) { + auto ec = HandleOp(push_res->engine_opcode); + if (ec) { + VLOG(1) << "HandleOp failed " << ec.message(); + return make_unexpected(ec); } - res = SendBuffer({scratch, buffered_size}); } - if (!res) { - return res; + + if (push_res->written > 0) { + auto ec = MaybeSendOutput(); + if (ec) { + VLOG(1) << "MaybeSendOutput failed " << ec.message(); + return make_unexpected(ec); + } + return push_res->written; } - total_sent += *res; } - return total_sent; } -io::Result TlsSocket::SendBuffer(Engine::Buffer buf) { - DVLOG(2) << "TlsSocket::SendBuffer " << buf.size() << " bytes"; +io::Result TlsSocket::PushToEngine(const iovec* ptr, uint32_t len) { + PushResult res; - // Sending buffer into ssl. - DCHECK(engine_); - DCHECK_GT(buf.size(), 0u); + // Chosen to be sufficiently smaller than the usual MTU (1500) and a multiple of 16. + // IP - max 24 bytes. TCP - max 60 bytes. TLS - max 21 bytes. + static constexpr size_t kBatchSize = 1392; - size_t send_total = 0; - SpinCounter spin_count(20); + while (len) { + Engine::OpResult op_result; + Engine::Buffer buf; - while (true) { - Engine::OpResult op_result = engine_->Write(buf); - if (!op_result) { - return make_unexpected(SSL2Error(__LINE__, op_result.error())); - } + if (ptr->iov_len >= kBatchSize || len == 1) { + buf = {reinterpret_cast(ptr->iov_base), ptr->iov_len}; + op_result = engine_->Write(buf); + ptr++; + len--; + } else { + size_t batch_size = 0; + uint8_t batch_buf[kBatchSize]; - int op_val = *op_result; + do { + std::memcpy(batch_buf + batch_size, ptr->iov_base, ptr->iov_len); + batch_size += ptr->iov_len; + ptr++; + len--; + } while (len && (batch_size + ptr->iov_len) <= kBatchSize); - if (op_val > 0) { - send_total += op_val; + buf = {batch_buf, batch_size}; - if (size_t(op_val) == buf.size()) { - break; - } - spin_count.Reset(); - buf.remove_prefix(op_val); - continue; + // In general we should pass the same arguments in case of retries, but since we + // configure the engine with SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER, we can change the + // buffer between retries. + op_result = engine_->Write(buf); } - if (spin_count.Check(op_val <= 0)) { - // Once every 30 seconds. - LOG_EVERY_T(WARNING, 30) << "IO loop spin limit reached. Limit: " << spin_count.Limit() - << " Spins: " << spin_count.Spins(); + int op_val = op_result; + if (op_val < 0) { + res.engine_opcode = op_val; + return res; } - error_code ec = HandleOp(op_val); - if (ec) - return make_unexpected(ec); - } - - // Usually we want to batch writes as much as possible, but here we can not now if more writes - // will follow. We must flush the output buffer, so that data will be sent down the socket. - error_code ec = MaybeSendOutput(); - if (ec) { - return make_unexpected(ec); + CHECK_GT(op_val, 0); + res.written += op_val; + if (unsigned(op_val) != buf.size()) { + break; // need to flush the SSL output buffer to the underlying socket. + } } - - return send_total; + return res; } // TODO: to implement async functionality. @@ -495,7 +482,7 @@ error_code TlsSocket::HandleOp(int op_val) { switch (op_val) { case Engine::EOF_STREAM: VLOG(1) << "EOF_STREAM received " << next_sock_->native_handle(); - return make_error_code(errc::connection_reset); + return make_error_code(errc::connection_aborted); case Engine::NEED_READ_AND_MAYBE_WRITE: return HandleUpstreamRead(); case Engine::NEED_WRITE: diff --git a/util/tls/tls_socket.h b/util/tls/tls_socket.h index 16627551..d4edf7cd 100644 --- a/util/tls/tls_socket.h +++ b/util/tls/tls_socket.h @@ -89,7 +89,16 @@ class TlsSocket final : public FiberSocketBase { virtual void SetProactor(ProactorBase* p) override; private: - io::Result SendBuffer(Buffer buf); + + struct PushResult { + size_t written = 0; + int engine_opcode = 0; // Engine::OpCode + }; + + // Pushes the buffers into input ssl buffer until either everything is written, + // or an error occurs or the engine needs to flush its output. Does not interact with the network, + // just with the engine. It's up to the caller to send the output buffer to the network. + io::Result PushToEngine(const iovec* ptr, uint32_t len); /// Feed encrypted data from the TLS engine into the network socket. error_code MaybeSendOutput();