Skip to content

Commit

Permalink
chore: reimplement TlsSocket::WriteSome
Browse files Browse the repository at this point in the history
Allow cleaner separation between writing to ssl and flushing the ssl output buffer into a socket.

Signed-off-by: Roman Gershman <[email protected]>
  • Loading branch information
romange committed Dec 22, 2024
1 parent 04fd389 commit a9b13d8
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 167 deletions.
91 changes: 49 additions & 42 deletions util/tls/tls_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,46 +33,14 @@ static void ClearSslError() {
} while (l);
}

static Engine::OpResult ToOpResult(const SSL* ssl, int result, const char* location) {
DCHECK_LE(result, 0);

unsigned long error = ERR_get_error();
if (error != 0) {
return nonstd::make_unexpected(error);
}

int ssl_error = SSL_get_error(ssl, result);
int io_err = errno;

switch (ssl_error) {
case SSL_ERROR_ZERO_RETURN:
break;
case SSL_ERROR_WANT_READ:
return Engine::NEED_READ_AND_MAYBE_WRITE;
case SSL_ERROR_WANT_WRITE:
return Engine::NEED_WRITE;
case SSL_ERROR_SYSCALL:
LOG(WARNING) << "SSL syscall error " << io_err << ":" << result << " " << location;
break;
case SSL_ERROR_SSL:
LOG(WARNING) << "SSL protocol error " << io_err << ":" << result << " " << location;
break;
default:
LOG(WARNING) << "Unexpected SSL error " << io_err << ":" << result << " " << location;
break;
}

return Engine::EOF_STREAM;
}

#define S1(x) #x
#define S2(x) S1(x)
#define LOCATION __FILE__ " : " S2(__LINE__)

#define RETURN_RESULT(res) \
if (res > 0) \
return res; \
return ToOpResult(ssl_, res, LOCATION)
return ToOpResult(res, LOCATION)

Engine::Engine(SSL_CTX* context) : ssl_(::SSL_new(context)) {
CHECK(ssl_);
Expand Down Expand Up @@ -135,15 +103,13 @@ void Engine::ConsumeOutputBuf(unsigned sz) {
CHECK_EQ(unsigned(res), sz);
}

auto Engine::WriteBuf(const Buffer& buf) -> OpResult {
unsigned Engine::WriteBuf(const Buffer& buf) {
DCHECK(!buf.empty());

char* cbuf = nullptr;
int res = BIO_nwrite(external_bio_, &cbuf, buf.size());
if (res < 0) {
unsigned long error = ::ERR_get_error();
return nonstd::make_unexpected(error);
} else if (res > 0) {
CHECK_GE(res, 0);
if (res > 0) {
memcpy(cbuf, buf.data(), res);
}
return res;
Expand Down Expand Up @@ -171,20 +137,25 @@ auto Engine::Handshake(HandshakeType type) -> OpResult {
}

auto Engine::Shutdown() -> OpResult {
if (state_ & FATAL_ERROR)
return 1;

int result = SSL_shutdown(ssl_);
// See https://www.openssl.org/docs/man1.1.1/man3/SSL_shutdown.html

// TODO: to handle correctly.
if (result == 0) // First step of Shutdown (close_notify) returns 0.
return result;
return 1;

RETURN_RESULT(result);
}

auto Engine::Write(const Buffer& buf) -> OpResult {
if (buf.empty())
return 0;
CHECK(!buf.empty());
int sz = buf.size() < INT_MAX ? buf.size() : INT_MAX;
int result = SSL_write(ssl_, buf.data(), sz);
RETURN_RESULT(result);

RETURN_RESULT(result); // Should never return 0.
}

auto Engine::Read(uint8_t* dest, size_t len) -> OpResult {
Expand Down Expand Up @@ -245,5 +216,41 @@ int SslProbeSetDefaultCALocation(SSL_CTX* ctx) {
return -1;
}


auto Engine::ToOpResult(int result, const char* location) -> OpResult {
DCHECK_LE(result, 0);

int ssl_error = SSL_get_error(ssl_, result);
unsigned long queue_error = 0;

#define ERROR_DETAILS errno << ":" << queue_error << " " \
<< ERR_reason_error_string(queue_error) << " " << location

switch (ssl_error) {
case SSL_ERROR_ZERO_RETURN: // graceful shutdown of TLS connection.
break;
case SSL_ERROR_WANT_READ:
return Engine::NEED_READ_AND_MAYBE_WRITE;
case SSL_ERROR_WANT_WRITE:
return Engine::NEED_WRITE;
case SSL_ERROR_SYSCALL: // fatal error in system call.
queue_error = ERR_get_error();
VLOG(1) << "SSL syscall error " << ERROR_DETAILS;
break;
case SSL_ERROR_SSL:
state_ |= FATAL_ERROR;
queue_error = ERR_get_error();
LOG(WARNING) << "SSL protocol error " << ERROR_DETAILS;
break;
default:
queue_error = ERR_get_error();
state_ |= FATAL_ERROR;
LOG(WARNING) << "Unexpected SSL error " << ssl_error << " " << ERROR_DETAILS;
break;
}

return Engine::EOF_STREAM;
}

} // namespace tls
} // namespace util
12 changes: 10 additions & 2 deletions util/tls/tls_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class Engine {
// if value == NEED_XXX then it means that it should either write data to IO and then read or just
// write. In any case for non-error OpResult a caller must check OutputPending and write the
// output buffer to the appropriate channel.
using OpResult = io::Result<int, unsigned long>;
using OpResult = int;

// Construct a new engine for the specified context.
explicit Engine(SSL_CTX* context);
Expand All @@ -65,6 +65,7 @@ class Engine {
// SSL_accept (server-side).
OpResult Handshake(HandshakeType type);

// Returns 1 if succeeded or negative opcodes to execute upon.
OpResult Shutdown();

// Write bytes to the SSL session. Non-negative value - says how much was written.
Expand All @@ -85,7 +86,7 @@ class Engine {
//! Writes the buffer into input ssl buffer.
//! Returns number of written bytes or the error.
//! TODO: should be replaced with PeekInputBuf, memcpy, CommitInput sequence.
OpResult WriteBuf(const Buffer& buf);
unsigned WriteBuf(const Buffer& buf);

// We usually use this function to write from the raw socket to SSL engine.
// Returns direct reference to the input (write) buffer. This operation is not destructive.
Expand Down Expand Up @@ -118,6 +119,13 @@ class Engine {
// Perform one operation. Returns > 0 on success.
using EngineOp = int (Engine::*)(void*, std::size_t);

OpResult ToOpResult(int result, const char* location);

enum StateMask {
FATAL_ERROR = 1,
};

uint8_t state_ = 0;
SSL* ssl_;
BIO* external_bio_;
};
Expand Down
22 changes: 8 additions & 14 deletions util/tls/tls_engine_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include "base/gtest.h"
#include "base/logging.h"
#include "util/fibers/fibers.h"
#include "util/tls/tls_socket.h"

namespace util {
namespace tls {
Expand Down Expand Up @@ -133,11 +132,9 @@ static unsigned long RunPeer(SslStreamTest::Options opts, SslStreamTest::OpCb cb
Engine* dest) {
unsigned input_pending = 1;
while (true) {
auto op_result = cb(src);
if (!op_result) {
return op_result.error();
}
VLOG(1) << opts.name << " OpResult: " << *op_result;
Engine::OpResult op_result = cb(src);

VLOG(1) << opts.name << " OpResult: " << op_result;
unsigned output_pending = src->OutputPending();
if (output_pending > 0) {
auto buffer = src->PeekOutputBuf();
Expand All @@ -154,18 +151,15 @@ static unsigned long RunPeer(SslStreamTest::Options opts, SslStreamTest::OpCb cb
src->ConsumeOutputBuf(buffer.size());
} else {
auto write_result = dest->WriteBuf(buffer);
if (!write_result) {
return write_result.error();
}
CHECK_GT(*write_result, 0);
src->ConsumeOutputBuf(*write_result);
CHECK_GT(write_result, 0u);
src->ConsumeOutputBuf(write_result);
}
}

if (*op_result >= 0) { // Shutdown or empty read/write may return 0.
if (op_result >= 0) { // Shutdown or empty read/write may return 0.
return 0;
}
if (*op_result == Engine::EOF_STREAM) {
if (op_result == Engine::EOF_STREAM) {
LOG(WARNING) << opts.name << " stream truncated";
return 0;
}
Expand Down Expand Up @@ -276,7 +270,7 @@ TEST_F(SslStreamTest, HandshakeErrServer) {
LOG(INFO) << SSLError(cl_err);
LOG(INFO) << SSLError(srv_err);

ASSERT_NE(0, cl_err);
ASSERT_EQ(0, cl_err);
}

TEST_F(SslStreamTest, ReadShutdown) {
Expand Down
Loading

0 comments on commit a9b13d8

Please sign in to comment.