Skip to content

Commit

Permalink
chore: reimplement TlsSocket::WriteSome (#351)
Browse files Browse the repository at this point in the history
Allow cleaner separation between writing to ssl and flushing the ssl output buffer into a socket.

Signed-off-by: Roman Gershman <[email protected]>
  • Loading branch information
romange authored Dec 22, 2024
1 parent 04fd389 commit b508ad5
Show file tree
Hide file tree
Showing 5 changed files with 231 additions and 212 deletions.
93 changes: 50 additions & 43 deletions util/tls/tls_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,46 +33,14 @@ static void ClearSslError() {
} while (l);
}

static Engine::OpResult ToOpResult(const SSL* ssl, int result, const char* location) {
DCHECK_LE(result, 0);

unsigned long error = ERR_get_error();
if (error != 0) {
return nonstd::make_unexpected(error);
}

int ssl_error = SSL_get_error(ssl, result);
int io_err = errno;

switch (ssl_error) {
case SSL_ERROR_ZERO_RETURN:
break;
case SSL_ERROR_WANT_READ:
return Engine::NEED_READ_AND_MAYBE_WRITE;
case SSL_ERROR_WANT_WRITE:
return Engine::NEED_WRITE;
case SSL_ERROR_SYSCALL:
LOG(WARNING) << "SSL syscall error " << io_err << ":" << result << " " << location;
break;
case SSL_ERROR_SSL:
LOG(WARNING) << "SSL protocol error " << io_err << ":" << result << " " << location;
break;
default:
LOG(WARNING) << "Unexpected SSL error " << io_err << ":" << result << " " << location;
break;
}

return Engine::EOF_STREAM;
}

#define S1(x) #x
#define S2(x) S1(x)
#define LOCATION __FILE__ " : " S2(__LINE__)

#define RETURN_RESULT(res) \
if (res > 0) \
return res; \
return ToOpResult(ssl_, res, LOCATION)
return ToOpResult(res, LOCATION)

Engine::Engine(SSL_CTX* context) : ssl_(::SSL_new(context)) {
CHECK(ssl_);
Expand Down Expand Up @@ -135,15 +103,13 @@ void Engine::ConsumeOutputBuf(unsigned sz) {
CHECK_EQ(unsigned(res), sz);
}

auto Engine::WriteBuf(const Buffer& buf) -> OpResult {
unsigned Engine::WriteBuf(const Buffer& buf) {
DCHECK(!buf.empty());

char* cbuf = nullptr;
int res = BIO_nwrite(external_bio_, &cbuf, buf.size());
if (res < 0) {
unsigned long error = ::ERR_get_error();
return nonstd::make_unexpected(error);
} else if (res > 0) {
CHECK_GE(res, 0);
if (res > 0) {
memcpy(cbuf, buf.data(), res);
}
return res;
Expand Down Expand Up @@ -171,20 +137,25 @@ auto Engine::Handshake(HandshakeType type) -> OpResult {
}

auto Engine::Shutdown() -> OpResult {
if (state_ & FATAL_ERROR)
return 1;

int result = SSL_shutdown(ssl_);
// See https://www.openssl.org/docs/man1.1.1/man3/SSL_shutdown.html
if (result == 0) // First step of Shutdown (close_notify) returns 0.
return result;

if (result == 0) { // First step of Shutdown (close_notify) returns 0.
result = SSL_shutdown(ssl_); // Initiate the second step.
}

RETURN_RESULT(result);
}

auto Engine::Write(const Buffer& buf) -> OpResult {
if (buf.empty())
return 0;
CHECK(!buf.empty());
int sz = buf.size() < INT_MAX ? buf.size() : INT_MAX;
int result = SSL_write(ssl_, buf.data(), sz);
RETURN_RESULT(result);

RETURN_RESULT(result); // Should never return 0.
}

auto Engine::Read(uint8_t* dest, size_t len) -> OpResult {
Expand Down Expand Up @@ -245,5 +216,41 @@ int SslProbeSetDefaultCALocation(SSL_CTX* ctx) {
return -1;
}


auto Engine::ToOpResult(int result, const char* location) -> OpResult {
DCHECK_LE(result, 0);

int ssl_error = SSL_get_error(ssl_, result);
unsigned long queue_error = 0;

#define ERROR_DETAILS errno << ":" << queue_error << " " \
<< ERR_reason_error_string(queue_error) << " " << location

switch (ssl_error) {
case SSL_ERROR_ZERO_RETURN: // graceful shutdown of TLS connection.
break;
case SSL_ERROR_WANT_READ:
return Engine::NEED_READ_AND_MAYBE_WRITE;
case SSL_ERROR_WANT_WRITE:
return Engine::NEED_WRITE;
case SSL_ERROR_SYSCALL: // fatal error in system call.
queue_error = ERR_get_error();
VLOG(1) << "SSL syscall error " << ERROR_DETAILS;
break;
case SSL_ERROR_SSL:
state_ |= FATAL_ERROR;
queue_error = ERR_get_error();
LOG(WARNING) << "SSL protocol error " << ERROR_DETAILS;
break;
default:
queue_error = ERR_get_error();
state_ |= FATAL_ERROR;
LOG(WARNING) << "Unexpected SSL error " << ssl_error << " " << ERROR_DETAILS;
break;
}

return Engine::EOF_STREAM;
}

} // namespace tls
} // namespace util
19 changes: 14 additions & 5 deletions util/tls/tls_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class Engine {
// if value == NEED_XXX then it means that it should either write data to IO and then read or just
// write. In any case for non-error OpResult a caller must check OutputPending and write the
// output buffer to the appropriate channel.
using OpResult = io::Result<int, unsigned long>;
using OpResult = int;

// Construct a new engine for the specified context.
explicit Engine(SSL_CTX* context);
Expand All @@ -63,14 +63,16 @@ class Engine {

// Perform an SSL handshake using either SSL_connect (client-side) or
// SSL_accept (server-side).
// Returns 1 if succeeded or negative opcodes to execute upon.
OpResult Handshake(HandshakeType type);

// Returns 1 if succeeded or negative opcodes to execute upon.
OpResult Shutdown();

// Write bytes to the SSL session. Non-negative value - says how much was written.
// Write decrypted bytes to the SSL session. Non-negative value - says how much was written.
OpResult Write(const Buffer& data);

// Read bytes from the SSL session.
// Read bytes from the SSL session (decrypted side).
OpResult Read(uint8_t* dest, size_t len);

//! Returns output buffer which is the read buffer of tls engine.
Expand All @@ -82,10 +84,10 @@ class Engine {
//! sz should be not greater than the buffer size from the last PeekOutputBuf() call.
void ConsumeOutputBuf(unsigned sz);

//! Writes the buffer into input ssl buffer.
//! Writes encrypted data into input ssl buffer.
//! Returns number of written bytes or the error.
//! TODO: should be replaced with PeekInputBuf, memcpy, CommitInput sequence.
OpResult WriteBuf(const Buffer& buf);
unsigned WriteBuf(const Buffer& buf);

// We usually use this function to write from the raw socket to SSL engine.
// Returns direct reference to the input (write) buffer. This operation is not destructive.
Expand Down Expand Up @@ -118,6 +120,13 @@ class Engine {
// Perform one operation. Returns > 0 on success.
using EngineOp = int (Engine::*)(void*, std::size_t);

OpResult ToOpResult(int result, const char* location);

enum StateMask {
FATAL_ERROR = 1,
};

uint8_t state_ = 0;
SSL* ssl_;
BIO* external_bio_;
};
Expand Down
Loading

0 comments on commit b508ad5

Please sign in to comment.