fix: dragonfly_connection should only access the original reply_builder (#3924)

ConnectionContext.reply_builder can be injected and replaced by the service logic.
Before this change, dragonfly_connection accessed it via cc_->reply_builder in some places,
which meant it could end up talking to the injected object. Moreover, EVAL commands can be
offloaded to another thread, and that thread could inject the object, making access to
cc_->reply_builder_ not thread-safe.

Now dragonfly_connection copies aside the reply_builder_ pointer and uses only this pointer for communicating with the client.

Also, remove redundant arguments.
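
As a rough illustration, here is a minimal, self-contained C++ sketch of the pattern described above. It is not the Dragonfly sources; SinkReplyBuilder, ConnectionContext and Connection below are simplified stand-ins. The connection caches the original builder pointer once, right after creating its context, and replies only through that cached pointer, so a builder injected into the context later (for example while an EVAL is offloaded to another thread) is never touched from the connection fiber.

#include <iostream>
#include <memory>
#include <string>

// Simplified stand-in for facade::SinkReplyBuilder.
class SinkReplyBuilder {
 public:
  virtual ~SinkReplyBuilder() = default;
  virtual void SendSimpleString(const std::string& s) {
    std::cout << "+" << s << "\r\n";
  }
};

// Simplified stand-in for ConnectionContext. Service logic may swap the
// builder at any time (e.g. while running EVAL on another thread), so
// reading reply_builder() from the connection fiber is racy.
class ConnectionContext {
 public:
  explicit ConnectionContext(SinkReplyBuilder* rb) : reply_builder_(rb) {}
  void InjectReplyBuilder(SinkReplyBuilder* rb) { reply_builder_ = rb; }
  SinkReplyBuilder* reply_builder() const { return reply_builder_; }

 private:
  SinkReplyBuilder* reply_builder_;
};

class Connection {
 public:
  void HandleRequests() {
    original_builder_ = std::make_unique<SinkReplyBuilder>();
    cc_ = std::make_unique<ConnectionContext>(original_builder_.get());

    // Copy the pointer aside once. Everything below replies through
    // reply_builder_ and never dereferences cc_->reply_builder() again,
    // so a later injection cannot redirect the connection's own replies.
    reply_builder_ = cc_->reply_builder();

    reply_builder_->SendSimpleString("OK");

    cc_.reset();
    reply_builder_ = nullptr;
  }

 private:
  std::unique_ptr<SinkReplyBuilder> original_builder_;
  std::unique_ptr<ConnectionContext> cc_;
  SinkReplyBuilder* reply_builder_ = nullptr;  // cached original builder
};

int main() {
  Connection conn;
  conn.HandleRequests();
  return 0;
}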

Signed-off-by: Roman Gershman <[email protected]>
romange authored Oct 15, 2024
1 parent c3f9ec1 commit e9e169d
Showing 2 changed files with 57 additions and 53 deletions.
93 changes: 47 additions & 46 deletions src/facade/dragonfly_connection.cc
@@ -663,7 +663,6 @@ void Connection::HandleRequests() {

auto remote_ep = RemoteEndpointStr();

FiberSocketBase* peer = socket_.get();
#ifdef DFLY_USE_SSL
if (ssl_ctx_) {
const bool no_tls_on_admin_port = GetFlag(FLAGS_no_tls_on_admin_port);
@@ -680,7 +679,7 @@ void Connection::HandleRequests() {
VLOG(1) << "Bad TLS header "
<< absl::StrCat(absl::Hex(buf[0], absl::kZeroPad2),
absl::Hex(buf[1], absl::kZeroPad2));
peer->Write(
socket_->Write(
io::Buffer("-ERR Bad TLS header, double check "
"if you enabled TLS for your client.\r\n"));
}
@@ -697,15 +696,14 @@ void Connection::HandleRequests() {
LOG(WARNING) << "Error handshaking " << aresult.error().message();
return;
}
peer = socket_.get();
VLOG(1) << "TLS handshake succeeded";
}
}
#endif

io::Result<bool> http_res{false};

http_res = CheckForHttpProto(peer);
http_res = CheckForHttpProto();

// We need to check if the socket is open because the server might be
// shutting down. During the shutdown process, the server iterates over
@@ -717,12 +715,14 @@ void Connection::HandleRequests() {
// because both Write and Recv internally check if the socket was shut
// down and return with an error accordingly.
if (http_res && socket_->IsOpen()) {
cc_.reset(service_->CreateContext(peer, this));
cc_.reset(service_->CreateContext(socket_.get(), this));
reply_builder_ = cc_->reply_builder();

if (*http_res) {
VLOG(1) << "HTTP1.1 identified";
is_http_ = true;
HttpConnection http_conn{http_listener_};
http_conn.SetSocket(peer);
http_conn.SetSocket(socket_.get());
http_conn.set_user_data(cc_.get());

// We validate the http request using basic-auth inside HttpConnection::HandleSingleRequest.
@@ -741,13 +741,14 @@ void Connection::HandleRequests() {
socket_->RegisterOnErrorCb([this](int32_t mask) { this->OnBreakCb(mask); });
}

ConnectionFlow(peer);
ConnectionFlow();

socket_->CancelOnErrorCb(); // noop if nothing is registered.
}
VLOG(1) << "Closed connection for peer "
<< GetClientInfo(fb2::ProactorBase::me()->GetPoolIndex());
cc_.reset();
reply_builder_ = nullptr;
}
}

@@ -801,13 +802,14 @@ std::pair<std::string, std::string> Connection::GetClientInfoBeforeAfterTid() const {
string_view phase_name = PHASE_NAMES[phase_];

if (cc_) {
DCHECK(cc_->reply_builder());
DCHECK(cc_->reply_builder() && reply_builder_);
string cc_info = service_->GetContextInfo(cc_.get()).Format();
if (cc_->reply_builder()->IsSendActive())
if (reply_builder_->IsSendActive())
phase_name = "send";
absl::StrAppend(&after, " ", cc_info);
}
absl::StrAppend(&after, " phase=", phase_name);

return {std::move(before), std::move(after)};
}

@@ -872,7 +874,7 @@ const absl::flat_hash_map<string, uint64_t>& Connection::GetLibStatsTL() {
return g_libname_ver_map;
}

io::Result<bool> Connection::CheckForHttpProto(FiberSocketBase* peer) {
io::Result<bool> Connection::CheckForHttpProto() {
if (!IsPrivileged() && !IsMain()) {
return false;
}
@@ -883,6 +885,7 @@ io::Result<bool> Connection::CheckForHttpProto(FiberSocketBase* peer) {
}

size_t last_len = 0;
auto* peer = socket_.get();
do {
auto buf = io_buf_.AppendBuffer();
DCHECK(!buf.empty());
@@ -916,35 +919,34 @@ io::Result<bool> Connection::CheckForHttpProto(FiberSocketBase* peer) {
return false;
}

void Connection::ConnectionFlow(FiberSocketBase* peer) {
void Connection::ConnectionFlow() {
++stats_->num_conns;
++stats_->conn_received_cnt;
stats_->read_buf_capacity += io_buf_.Capacity();

ParserStatus parse_status = OK;
SinkReplyBuilder* orig_builder = cc_->reply_builder();

// At the start we read from the socket to determine the HTTP/Memstore protocol.
// Therefore we may already have some data in the buffer.
if (io_buf_.InputLen() > 0) {
phase_ = PROCESS;
if (redis_parser_) {
parse_status = ParseRedis(orig_builder);
parse_status = ParseRedis();
} else {
DCHECK(memcache_parser_);
parse_status = ParseMemcache();
}
}

error_code ec = orig_builder->GetError();
error_code ec = reply_builder_->GetError();

// Main loop.
if (parse_status != ERROR && !ec) {
if (io_buf_.AppendLen() < 64) {
UpdateIoBufCapacity(io_buf_, stats_,
[&]() { io_buf_.EnsureCapacity(io_buf_.Capacity() * 2); });
}
auto res = IoLoop(peer, orig_builder);
auto res = IoLoop();

if (holds_alternative<error_code>(res)) {
ec = get<error_code>(res);
@@ -975,10 +977,10 @@ void Connection::ConnectionFlow(FiberSocketBase* peer) {
VLOG(1) << "Error parser status " << parser_error_;

if (redis_parser_) {
SendProtocolError(RedisParser::Result(parser_error_), orig_builder);
SendProtocolError(RedisParser::Result(parser_error_), reply_builder_);
} else {
DCHECK(memcache_parser_);
orig_builder->SendProtocolError("bad command line format");
reply_builder_->SendProtocolError("bad command line format");
}

// Shut down the server's side of the socket to send a FIN to the client
@@ -988,15 +990,15 @@ void Connection::ConnectionFlow(FiberSocketBase* peer) {
// Otherwise the client's write could fail (or block), so they would never
// read the above protocol error (see issue #1327).
// TODO: we have a bug that can potentially deadlock the code below.
// If a peer does not close the socket on the other side, the while loop will never finish.
// If the client does not close the socket on the other side, the while loop will never finish.
// to reproduce: nc localhost 6379 and then run invalid sequence: *1 <enter> *1 <enter>
error_code ec2 = peer->Shutdown(SHUT_WR);
error_code ec2 = socket_->Shutdown(SHUT_WR);
LOG_IF(WARNING, ec2) << "Could not shutdown socket " << ec2;
if (!ec2) {
while (true) {
// Discard any received data.
io_buf_.Clear();
if (!peer->Recv(io_buf_.AppendBuffer())) {
if (!socket_->Recv(io_buf_.AppendBuffer())) {
break;
}
}
@@ -1064,7 +1066,7 @@ void Connection::DispatchSingle(bool has_more, absl::FunctionRef<void()> invoke_
}
}

Connection::ParserStatus Connection::ParseRedis(SinkReplyBuilder* orig_builder) {
Connection::ParserStatus Connection::ParseRedis() {
uint32_t consumed = 0;
RedisParser::Result result = RedisParser::OK;

@@ -1096,7 +1098,7 @@ Connection::ParserStatus Connection::ParseRedis(SinkReplyBuilder* orig_builder)
DispatchSingle(has_more, dispatch_sync, dispatch_async);
}
io_buf_.ConsumeInput(consumed);
} while (RedisParser::OK == result && !orig_builder->GetError());
} while (RedisParser::OK == result && !reply_builder_->GetError());

parser_error_ = result;
if (result == RedisParser::OK)
@@ -1120,7 +1122,7 @@ auto Connection::ParseMemcache() -> ParserStatus {
return {make_unique<MCPipelineMessage>(std::move(cmd), value)};
};

MCReplyBuilder* builder = static_cast<MCReplyBuilder*>(cc_->reply_builder());
MCReplyBuilder* builder = static_cast<MCReplyBuilder*>(reply_builder_);

do {
string_view str = ToSV(io_buf_.InputBuffer());
@@ -1179,10 +1181,10 @@ void Connection::OnBreakCb(int32_t mask) {
return;
}

DCHECK(cc_->reply_builder()) << "[" << id_ << "] " << phase_ << " " << migration_in_process_;
DCHECK(reply_builder_) << "[" << id_ << "] " << phase_ << " " << migration_in_process_;

VLOG(1) << "[" << id_ << "] Got event " << mask << " " << phase_ << " "
<< cc_->reply_builder()->IsSendActive() << " " << cc_->reply_builder()->GetError();
<< reply_builder_->IsSendActive() << " " << reply_builder_->GetError();

cc_->conn_closing = true;
BreakOnce(mask);
@@ -1212,24 +1214,22 @@ void Connection::HandleMigrateRequest() {
// We need to return early as the socket is closing and IoLoop will clean up.
// The reason that this is true is because of the following DCHECK
DCHECK(!dispatch_fb_.IsJoinable());

// which can never trigger since we Joined on the dispatch_fb_ above and we are
// atomic in respect to our proactor meaning that no other fiber will
// launch the DispatchFiber.
if (!this->Migrate(dest)) {
return;
}
}

// In case we Yield()ed in Migrate() above, dispatch_fb_ might have been started.
LaunchDispatchFiberIfNeeded();
}

auto Connection::IoLoop(util::FiberSocketBase* peer, SinkReplyBuilder* orig_builder)
-> variant<error_code, ParserStatus> {
auto Connection::IoLoop() -> variant<error_code, ParserStatus> {
error_code ec;
ParserStatus parse_status = OK;

size_t max_iobfuf_len = GetFlag(FLAGS_max_client_iobuf_len);
auto* peer = socket_.get();

do {
HandleMigrateRequest();
@@ -1256,7 +1256,7 @@ auto Connection::IoLoop(util::FiberSocketBase* peer, SinkReplyBuilder* orig_builder)
bool is_iobuf_full = io_buf_.AppendLen() == 0;

if (redis_parser_) {
parse_status = ParseRedis(orig_builder);
parse_status = ParseRedis();
} else {
DCHECK(memcache_parser_);
parse_status = ParseMemcache();
@@ -1303,7 +1303,7 @@ auto Connection::IoLoop(util::FiberSocketBase* peer, SinkReplyBuilder* orig_builder)
} else if (parse_status != OK) {
break;
}
ec = orig_builder->GetError();
ec = reply_builder_->GetError();
} while (peer->IsOpen() && !ec);

if (ec)
@@ -1337,7 +1337,7 @@ bool Connection::ShouldEndDispatchFiber(const MessageHandle& msg) {
return false;
}

void Connection::SquashPipeline(facade::SinkReplyBuilder* builder) {
void Connection::SquashPipeline() {
DCHECK_EQ(dispatch_q_.size(), pending_pipeline_cmd_cnt_);

vector<CmdArgList> squash_cmds;
@@ -1356,8 +1356,8 @@ void Connection::SquashPipeline(facade::SinkReplyBuilder* builder) {
size_t dispatched = service_->DispatchManyCommands(absl::MakeSpan(squash_cmds), cc_.get());

if (pending_pipeline_cmd_cnt_ == squash_cmds.size()) { // Flush if no new commands appeared
builder->FlushBatch();
builder->SetBatchMode(false); // in case the next dispatch is sync
reply_builder_->FlushBatch();
reply_builder_->SetBatchMode(false); // in case the next dispatch is sync
}

cc_->async_dispatch = false;
@@ -1376,7 +1376,7 @@ void Connection::SquashPipeline(facade::SinkReplyBuilder* builder) {
}

void Connection::ClearPipelinedMessages() {
DispatchOperations dispatch_op{cc_->reply_builder(), this};
DispatchOperations dispatch_op{reply_builder_, this};

// Recycle messages even from a disconnecting client to properly keep track of memory stats
// as well as to avoid pubsub backpressure leakage.
@@ -1421,17 +1421,17 @@ std::string Connection::DebugInfo() const {
// into the dispatch queue and DispatchFiber will run those commands asynchronously with
// InputLoop. Note: in some cases, InputLoop may decide to dispatch directly and bypass the
// DispatchFiber.
void Connection::ExecutionFiber(util::FiberSocketBase* peer) {
void Connection::ExecutionFiber() {
ThisFiber::SetName("ExecutionFiber");
SinkReplyBuilder* builder = cc_->reply_builder();
DispatchOperations dispatch_op{builder, this};

DispatchOperations dispatch_op{reply_builder_, this};

size_t squashing_threshold = GetFlag(FLAGS_pipeline_squash);

uint64_t prev_epoch = fb2::FiberSwitchEpoch();
fb2::NoOpLock noop_lk;

while (!builder->GetError()) {
while (!reply_builder_->GetError()) {
DCHECK_EQ(socket()->proactor(), ProactorBase::me());
cnd_.wait(noop_lk, [this] {
return cc_->conn_closing || (!dispatch_q_.empty() && !cc_->sync_dispatch);
@@ -1455,7 +1455,7 @@ void Connection::ExecutionFiber(util::FiberSocketBase* peer) {
}
prev_epoch = cur_epoch;

builder->SetBatchMode(dispatch_q_.size() > 1);
reply_builder_->SetBatchMode(dispatch_q_.size() > 1);

bool subscriber_over_limit =
stats_->dispatch_queue_subscriber_bytes >= queue_backpressure_->publish_buffer_limit;
@@ -1468,7 +1468,7 @@ void Connection::ExecutionFiber(util::FiberSocketBase* peer) {
bool threshold_reached = pending_pipeline_cmd_cnt_ > squashing_threshold;
bool are_all_plain_cmds = pending_pipeline_cmd_cnt_ == dispatch_q_.size();
if (squashing_enabled && threshold_reached && are_all_plain_cmds && !skip_next_squashing_) {
SquashPipeline(builder);
SquashPipeline();
} else {
MessageHandle msg = std::move(dispatch_q_.front());
dispatch_q_.pop_front();
@@ -1477,7 +1477,7 @@ void Connection::ExecutionFiber(util::FiberSocketBase* peer) {
// last command to reply and flush. If it doesn't reply (i.e. is a control message like
// migrate), we have to flush manually.
if (dispatch_q_.empty() && !msg.IsReplying()) {
builder->FlushBatch();
reply_builder_->FlushBatch();
}

if (ShouldEndDispatchFiber(msg)) {
@@ -1505,7 +1505,7 @@ void Connection::ExecutionFiber(util::FiberSocketBase* peer) {
queue_backpressure_->pubsub_ec.notify();
}

DCHECK(cc_->conn_closing || builder->GetError());
DCHECK(cc_->conn_closing || reply_builder_->GetError());
cc_->conn_closing = true;
queue_backpressure_->pipeline_cnd.notify_all();
}
@@ -1582,6 +1582,7 @@ bool Connection::Migrate(util::fb2::ProactorBase* dest) {
}

listener()->Migrate(this, dest);

// After we migrate, it could be the case the connection was shut down. We should
// act accordingly.
if (!socket()->IsOpen()) {
@@ -1646,8 +1647,8 @@ void Connection::SendInvalidationMessageAsync(InvalidationMessage msg) {
void Connection::LaunchDispatchFiberIfNeeded() {
if (!dispatch_fb_.IsJoinable() && !migration_in_process_) {
VLOG(1) << "[" << id_ << "] LaunchDispatchFiberIfNeeded ";
dispatch_fb_ = fb2::Fiber(fb2::Launch::post, "connection_dispatch",
[this, peer = socket_.get()]() { ExecutionFiber(peer); });
dispatch_fb_ =
fb2::Fiber(fb2::Launch::post, "connection_dispatch", [this]() { ExecutionFiber(); });
}
}
