From a9b13d86b7a6db96924f0d18269ff9ff1424c503 Mon Sep 17 00:00:00 2001 From: Roman Gershman Date: Sat, 21 Dec 2024 22:44:01 +0200 Subject: [PATCH] chore: reimplement TlsSocket::WriteSome Allow cleaner separation between writing to ssl and flushing the ssl output buffer into a socket. Signed-off-by: Roman Gershman --- util/tls/tls_engine.cc | 91 +++++++++-------- util/tls/tls_engine.h | 12 ++- util/tls/tls_engine_test.cc | 22 ++-- util/tls/tls_socket.cc | 193 ++++++++++++++++-------------------- util/tls/tls_socket.h | 11 +- 5 files changed, 162 insertions(+), 167 deletions(-) diff --git a/util/tls/tls_engine.cc b/util/tls/tls_engine.cc index bc94e573..e8684909 100644 --- a/util/tls/tls_engine.cc +++ b/util/tls/tls_engine.cc @@ -33,38 +33,6 @@ static void ClearSslError() { } while (l); } -static Engine::OpResult ToOpResult(const SSL* ssl, int result, const char* location) { - DCHECK_LE(result, 0); - - unsigned long error = ERR_get_error(); - if (error != 0) { - return nonstd::make_unexpected(error); - } - - int ssl_error = SSL_get_error(ssl, result); - int io_err = errno; - - switch (ssl_error) { - case SSL_ERROR_ZERO_RETURN: - break; - case SSL_ERROR_WANT_READ: - return Engine::NEED_READ_AND_MAYBE_WRITE; - case SSL_ERROR_WANT_WRITE: - return Engine::NEED_WRITE; - case SSL_ERROR_SYSCALL: - LOG(WARNING) << "SSL syscall error " << io_err << ":" << result << " " << location; - break; - case SSL_ERROR_SSL: - LOG(WARNING) << "SSL protocol error " << io_err << ":" << result << " " << location; - break; - default: - LOG(WARNING) << "Unexpected SSL error " << io_err << ":" << result << " " << location; - break; - } - - return Engine::EOF_STREAM; -} - #define S1(x) #x #define S2(x) S1(x) #define LOCATION __FILE__ " : " S2(__LINE__) @@ -72,7 +40,7 @@ static Engine::OpResult ToOpResult(const SSL* ssl, int result, const char* locat #define RETURN_RESULT(res) \ if (res > 0) \ return res; \ - return ToOpResult(ssl_, res, LOCATION) + return ToOpResult(res, LOCATION) 
Engine::Engine(SSL_CTX* context) : ssl_(::SSL_new(context)) { CHECK(ssl_); @@ -135,15 +103,13 @@ void Engine::ConsumeOutputBuf(unsigned sz) { CHECK_EQ(unsigned(res), sz); } -auto Engine::WriteBuf(const Buffer& buf) -> OpResult { +unsigned Engine::WriteBuf(const Buffer& buf) { DCHECK(!buf.empty()); char* cbuf = nullptr; int res = BIO_nwrite(external_bio_, &cbuf, buf.size()); - if (res < 0) { - unsigned long error = ::ERR_get_error(); - return nonstd::make_unexpected(error); - } else if (res > 0) { + CHECK_GE(res, 0); + if (res > 0) { memcpy(cbuf, buf.data(), res); } return res; @@ -171,20 +137,25 @@ auto Engine::Handshake(HandshakeType type) -> OpResult { } auto Engine::Shutdown() -> OpResult { + if (state_ & FATAL_ERROR) + return 1; + int result = SSL_shutdown(ssl_); // See https://www.openssl.org/docs/man1.1.1/man3/SSL_shutdown.html + + // TODO: handle the two-step shutdown correctly; for now a result of 0 is reported as success. if (result == 0) // First step of Shutdown (close_notify) returns 0. - return result; + return 1; RETURN_RESULT(result); } auto Engine::Write(const Buffer& buf) -> OpResult { - if (buf.empty()) - return 0; + CHECK(!buf.empty()); int sz = buf.size() < INT_MAX ? buf.size() : INT_MAX; int result = SSL_write(ssl_, buf.data(), sz); - RETURN_RESULT(result); + + RETURN_RESULT(result); // Should never return 0. } auto Engine::Read(uint8_t* dest, size_t len) -> OpResult { @@ -245,5 +216,41 @@ int SslProbeSetDefaultCALocation(SSL_CTX* ctx) { return -1; } + +auto Engine::ToOpResult(int result, const char* location) -> OpResult { + DCHECK_LE(result, 0); + + int ssl_error = SSL_get_error(ssl_, result); + unsigned long queue_error = 0; + +#define ERROR_DETAILS errno << ":" << queue_error << " " \ + << ERR_reason_error_string(queue_error) << " " << location + + switch (ssl_error) { + case SSL_ERROR_ZERO_RETURN: // graceful shutdown of TLS connection. 
+ break; + case SSL_ERROR_WANT_READ: + return Engine::NEED_READ_AND_MAYBE_WRITE; + case SSL_ERROR_WANT_WRITE: + return Engine::NEED_WRITE; + case SSL_ERROR_SYSCALL: // fatal error in system call. + queue_error = ERR_get_error(); + VLOG(1) << "SSL syscall error " << ERROR_DETAILS; + break; + case SSL_ERROR_SSL: + state_ |= FATAL_ERROR; + queue_error = ERR_get_error(); + LOG(WARNING) << "SSL protocol error " << ERROR_DETAILS; + break; + default: + queue_error = ERR_get_error(); + state_ |= FATAL_ERROR; + LOG(WARNING) << "Unexpected SSL error " << ssl_error << " " << ERROR_DETAILS; + break; + } + + return Engine::EOF_STREAM; +} + } // namespace tls } // namespace util diff --git a/util/tls/tls_engine.h b/util/tls/tls_engine.h index 4c07bdcf..ab3e09c5 100644 --- a/util/tls/tls_engine.h +++ b/util/tls/tls_engine.h @@ -42,7 +42,7 @@ class Engine { // if value == NEED_XXX then it means that it should either write data to IO and then read or just // write. In any case for non-error OpResult a caller must check OutputPending and write the // output buffer to the appropriate channel. - using OpResult = io::Result; + using OpResult = int; // Construct a new engine for the specified context. explicit Engine(SSL_CTX* context); @@ -65,6 +65,7 @@ class Engine { // SSL_accept (server-side). OpResult Handshake(HandshakeType type); + // Returns 1 on success, or a negative opcode that the caller must act upon. OpResult Shutdown(); // Write bytes to the SSL session. Non-negative value - says how much was written. @@ -85,7 +86,7 @@ class Engine { //! Writes the buffer into input ssl buffer. //! Returns number of written bytes or the error. //! TODO: should be replaced with PeekInputBuf, memcpy, CommitInput sequence. - OpResult WriteBuf(const Buffer& buf); + unsigned WriteBuf(const Buffer& buf); // We usually use this function to write from the raw socket to SSL engine. // Returns direct reference to the input (write) buffer. This operation is not destructive. 
@@ -118,6 +119,13 @@ class Engine { // Perform one operation. Returns > 0 on success. using EngineOp = int (Engine::*)(void*, std::size_t); + OpResult ToOpResult(int result, const char* location); + + enum StateMask { + FATAL_ERROR = 1, + }; + + uint8_t state_ = 0; SSL* ssl_; BIO* external_bio_; }; diff --git a/util/tls/tls_engine_test.cc b/util/tls/tls_engine_test.cc index 11bde551..be6ffdb6 100644 --- a/util/tls/tls_engine_test.cc +++ b/util/tls/tls_engine_test.cc @@ -11,7 +11,6 @@ #include "base/gtest.h" #include "base/logging.h" #include "util/fibers/fibers.h" -#include "util/tls/tls_socket.h" namespace util { namespace tls { @@ -133,11 +132,9 @@ static unsigned long RunPeer(SslStreamTest::Options opts, SslStreamTest::OpCb cb Engine* dest) { unsigned input_pending = 1; while (true) { - auto op_result = cb(src); - if (!op_result) { - return op_result.error(); - } - VLOG(1) << opts.name << " OpResult: " << *op_result; + Engine::OpResult op_result = cb(src); + + VLOG(1) << opts.name << " OpResult: " << op_result; unsigned output_pending = src->OutputPending(); if (output_pending > 0) { auto buffer = src->PeekOutputBuf(); @@ -154,18 +151,15 @@ static unsigned long RunPeer(SslStreamTest::Options opts, SslStreamTest::OpCb cb src->ConsumeOutputBuf(buffer.size()); } else { auto write_result = dest->WriteBuf(buffer); - if (!write_result) { - return write_result.error(); - } - CHECK_GT(*write_result, 0); - src->ConsumeOutputBuf(*write_result); + CHECK_GT(write_result, 0u); + src->ConsumeOutputBuf(write_result); } } - if (*op_result >= 0) { // Shutdown or empty read/write may return 0. + if (op_result >= 0) { // Shutdown or empty read/write may return 0. 
return 0; } - if (*op_result == Engine::EOF_STREAM) { + if (op_result == Engine::EOF_STREAM) { LOG(WARNING) << opts.name << " stream truncated"; return 0; } @@ -276,7 +270,7 @@ TEST_F(SslStreamTest, HandshakeErrServer) { LOG(INFO) << SSLError(cl_err); LOG(INFO) << SSLError(srv_err); - ASSERT_NE(0, cl_err); + ASSERT_EQ(0, cl_err); } TEST_F(SslStreamTest, ReadShutdown) { diff --git a/util/tls/tls_socket.cc b/util/tls/tls_socket.cc index f476958c..f93dd115 100644 --- a/util/tls/tls_socket.cc +++ b/util/tls/tls_socket.cc @@ -15,8 +15,8 @@ #define VSOCK(verbosity) \ VLOG(verbosity) << "sock[" << native_handle() << "], state " << int(state_) \ - << ", write_total:" << upstream_write_ << " " << " pending output: " \ - << engine_->OutputPending() << " " + << ", write_total:" << upstream_write_ << " " \ + << " pending output: " << engine_->OutputPending() << " " namespace util { namespace tls { @@ -91,8 +91,7 @@ void TlsSocket::InitSSL(SSL_CTX* context, Buffer prefix) { engine_.reset(new Engine{context}); if (!prefix.empty()) { Engine::OpResult op_result = engine_->WriteBuf(prefix); - CHECK(op_result); - CHECK_EQ(unsigned(*op_result), prefix.size()); + CHECK_EQ(unsigned(op_result), prefix.size()); } } @@ -131,6 +130,14 @@ auto TlsSocket::Accept() -> AcceptResult { while (true) { Engine::OpResult op_result = engine_->Handshake(Engine::SERVER); + if (op_result == Engine::EOF_STREAM) { + return make_unexpected(make_error_code(errc::connection_aborted)); + } + + if (op_result == 1) { // Success. + break; + } + // it is important to send output (protocol errors) before we return from this function. error_code ec = MaybeSendOutput(); if (ec) { @@ -138,18 +145,7 @@ auto TlsSocket::Accept() -> AcceptResult { return make_unexpected(ec); } - // now check the result of the handshake. - if (!op_result) { - return make_unexpected(SSL2Error(__LINE__, op_result.error())); - } - - int op_val = *op_result; - - if (op_val >= 0) { // Shutdown or empty read/write may return 0. 
- break; - } - - ec = HandleOp(op_val); + ec = HandleOp(op_result); if (ec) return make_unexpected(ec); } @@ -160,29 +156,25 @@ auto TlsSocket::Accept() -> AcceptResult { error_code TlsSocket::Connect(const endpoint_type& endpoint, std::function on_pre_connect) { DCHECK(engine_); - Engine::OpResult op_result = engine_->Handshake(Engine::HandshakeType::CLIENT); - if (!op_result) { - return std::error_code(op_result.error(), std::system_category()); - } + while (true) { + Engine::OpResult op_result = engine_->Handshake(Engine::HandshakeType::CLIENT); + if (op_result == 1) { + break; + } - // If the socket is already open, we should not call connect on it - if (!IsOpen()) { - RETURN_ON_ERROR(next_sock_->Connect(endpoint, std::move(on_pre_connect))); - } + if (op_result == Engine::EOF_STREAM) { + return make_error_code(errc::connection_refused); + } - // Flush the ssl data to the socket and run the loop that ensures handshaking converges. - int op_val = *op_result; + // If the socket is already open, we should not call connect on it + if (!IsOpen()) { + RETURN_ON_ERROR(next_sock_->Connect(endpoint, std::move(on_pre_connect))); + } - // it should guide us to write and then read. - DCHECK_EQ(op_val, Engine::NEED_READ_AND_MAYBE_WRITE); - while (op_val < 0) { - RETURN_ON_ERROR(HandleOp(op_val)); + // Flush the ssl data to the socket and run the loop that ensures handshaking converges. 
+ int op_val = op_result; - op_result = engine_->Handshake(Engine::HandshakeType::CLIENT); - if (!op_result) { - return std::error_code(op_result.error(), std::system_category()); - } - op_val = *op_result; + RETURN_ON_ERROR(HandleOp(op_val)); } const SSL_CIPHER* cipher = SSL_get_current_cipher(engine_->native_handle()); @@ -241,7 +233,6 @@ io::Result TlsSocket::RecvMsg(const msghdr& msg, int flags) { Engine::MutableBuffer dest{reinterpret_cast(io->iov_base), io->iov_len}; size_t read_total = 0; - SpinCounter spin_count(20); while (true) { DCHECK(!dest.empty()); @@ -249,21 +240,12 @@ io::Result TlsSocket::RecvMsg(const msghdr& msg, int flags) { size_t read_len = std::min(dest.size(), size_t(INT_MAX)); Engine::OpResult op_result = engine_->Read(dest.data(), read_len); - if (!op_result) { - return make_unexpected(SSL2Error(__LINE__, op_result.error())); - } - int op_val = *op_result; - DVLOG(2) << "Engine::Read " << dest.size() << " bytes, got " << op_val; + int op_val = op_result; - if (spin_count.Check(op_val <= 0)) { - // Once every 30 seconds. - LOG_EVERY_T(WARNING, 30) << "IO loop spin limit reached. Limit: " << spin_count.Limit() - << " Spin: " << spin_count.Spins(); - } + DVLOG(2) << "Engine::Read " << dest.size() << " bytes, got " << op_val; if (op_val > 0) { - spin_count.Reset(); read_total += op_val; // I do not understand this code and what the hell I meant to do here. Seems to work @@ -304,84 +286,79 @@ io::Result TlsSocket::Recv(const io::MutableBytes& mb, int flags) { } io::Result TlsSocket::WriteSome(const iovec* ptr, uint32_t len) { - // Chosen to be sufficiently smaller than the usual MTU (1500) and a multiple of 16. - // IP - max 24 bytes. TCP - max 60 bytes. TLS - max 21 bytes. 
- constexpr size_t kBufferSize = 1392; - io::Result res; - size_t total_sent = 0; + while (true) { + io::Result push_res = PushToEngine(ptr, len); + if (!push_res) { + return make_unexpected(push_res.error()); + } - while (len) { - if (ptr->iov_len > kBufferSize || len == 1) { - res = SendBuffer(Engine::Buffer{reinterpret_cast(ptr->iov_base), ptr->iov_len}); - ptr++; - len--; - } else { - alignas(64) uint8_t scratch[kBufferSize]; - size_t buffered_size = 0; - while (len && (buffered_size + ptr->iov_len) <= kBufferSize) { - std::memcpy(scratch + buffered_size, ptr->iov_base, ptr->iov_len); - buffered_size += ptr->iov_len; - ptr++; - len--; + if (push_res->engine_opcode < 0) { + auto ec = HandleOp(push_res->engine_opcode); + if (ec) { + VLOG(1) << "HandleOp failed " << ec.message(); + return make_unexpected(ec); } - res = SendBuffer({scratch, buffered_size}); } - if (!res) { - return res; + + if (push_res->written > 0) { + auto ec = MaybeSendOutput(); + if (ec) { + VLOG(1) << "MaybeSendOutput failed " << ec.message(); + return make_unexpected(ec); + } + return push_res->written; } - total_sent += *res; } - return total_sent; } -io::Result TlsSocket::SendBuffer(Engine::Buffer buf) { - DVLOG(2) << "TlsSocket::SendBuffer " << buf.size() << " bytes"; +io::Result TlsSocket::PushToEngine(const iovec* ptr, uint32_t len) { + PushResult res; - // Sending buffer into ssl. - DCHECK(engine_); - DCHECK_GT(buf.size(), 0u); + // Chosen to be sufficiently smaller than the usual MTU (1500) and a multiple of 16. + // IP - max 24 bytes. TCP - max 60 bytes. TLS - max 21 bytes. 
+ static constexpr size_t kBatchSize = 1392; - size_t send_total = 0; - SpinCounter spin_count(20); + while (len) { + Engine::OpResult op_result; + Engine::Buffer buf; - while (true) { - Engine::OpResult op_result = engine_->Write(buf); - if (!op_result) { - return make_unexpected(SSL2Error(__LINE__, op_result.error())); - } + if (ptr->iov_len >= kBatchSize || len == 1) { + buf = {reinterpret_cast(ptr->iov_base), ptr->iov_len}; + op_result = engine_->Write(buf); + ptr++; + len--; + } else { + size_t batch_size = 0; + uint8_t batch_buf[kBatchSize]; - int op_val = *op_result; + do { + std::memcpy(batch_buf + batch_size, ptr->iov_base, ptr->iov_len); + batch_size += ptr->iov_len; + ptr++; + len--; + } while (len && (batch_size + ptr->iov_len) <= kBatchSize); - if (op_val > 0) { - send_total += op_val; + buf = {batch_buf, batch_size}; - if (size_t(op_val) == buf.size()) { - break; - } - spin_count.Reset(); - buf.remove_prefix(op_val); - continue; + // In general we should pass the same arguments in case of retries, but since we + // configure the engine with SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER, we can change the + // buffer between retries. + op_result = engine_->Write(buf); } - if (spin_count.Check(op_val <= 0)) { - // Once every 30 seconds. - LOG_EVERY_T(WARNING, 30) << "IO loop spin limit reached. Limit: " << spin_count.Limit() - << " Spins: " << spin_count.Spins(); + int op_val = op_result; + if (op_val < 0) { + res.engine_opcode = op_val; + return res; } - error_code ec = HandleOp(op_val); - if (ec) - return make_unexpected(ec); - } - - // Usually we want to batch writes as much as possible, but here we can not now if more writes - // will follow. We must flush the output buffer, so that data will be sent down the socket. - error_code ec = MaybeSendOutput(); - if (ec) { - return make_unexpected(ec); + CHECK_GT(op_val, 0); + res.written += op_val; + if (unsigned(op_val) != buf.size()) { + break; // need to flush the SSL output buffer to the underlying socket. 
+ } + } - - return send_total; + return res; } // TODO: to implement async functionality. @@ -495,7 +472,7 @@ error_code TlsSocket::HandleOp(int op_val) { switch (op_val) { case Engine::EOF_STREAM: VLOG(1) << "EOF_STREAM received " << next_sock_->native_handle(); - return make_error_code(errc::connection_reset); + return make_error_code(errc::connection_aborted); case Engine::NEED_READ_AND_MAYBE_WRITE: return HandleUpstreamRead(); case Engine::NEED_WRITE: diff --git a/util/tls/tls_socket.h b/util/tls/tls_socket.h index 16627551..d4edf7cd 100644 --- a/util/tls/tls_socket.h +++ b/util/tls/tls_socket.h @@ -89,7 +89,16 @@ class TlsSocket final : public FiberSocketBase { virtual void SetProactor(ProactorBase* p) override; private: - io::Result SendBuffer(Buffer buf); + + struct PushResult { + size_t written = 0; + int engine_opcode = 0; // Engine::OpCode + }; + + // Pushes the buffers into the ssl engine until everything is written, + // an error occurs, or the engine needs to flush its output. Does not interact with the network, + // only with the engine. It is up to the caller to send the output buffer to the network. + io::Result PushToEngine(const iovec* ptr, uint32_t len); /// Feed encrypted data from the TLS engine into the network socket. error_code MaybeSendOutput();