From e9e169d074b73b60aa090f467da3e939407ee10d Mon Sep 17 00:00:00 2001
From: Roman Gershman
Date: Tue, 15 Oct 2024 15:37:50 +0300
Subject: [PATCH] fix: dragonfly_connection should only access the original
 reply_builder (#3924)

ConnectionContext.reply_builder can be injected and replaced by the service
logic. Before, dragonfly_connection accessed it via cc_->reply_builder in some
places, which led it to access the injected object. Moreover, EVAL commands can
be offloaded to another thread, and that thread could inject the object, making
access to cc_->reply_builder_ not thread-safe.

Now dragonfly_connection copies aside the reply_builder_ pointer and uses only
this pointer for communicating with the client.

Also, remove redundant arguments.

Signed-off-by: Roman Gershman
---
 src/facade/dragonfly_connection.cc | 93 +++++++++++++++---------------
 src/facade/dragonfly_connection.h  | 17 +++---
 2 files changed, 57 insertions(+), 53 deletions(-)

diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc
index 9ea6cf5b79f3..fc6e40a5fdea 100644
--- a/src/facade/dragonfly_connection.cc
+++ b/src/facade/dragonfly_connection.cc
@@ -663,7 +663,6 @@ void Connection::HandleRequests() {
 
   auto remote_ep = RemoteEndpointStr();
 
-  FiberSocketBase* peer = socket_.get();
 #ifdef DFLY_USE_SSL
   if (ssl_ctx_) {
     const bool no_tls_on_admin_port = GetFlag(FLAGS_no_tls_on_admin_port);
@@ -680,7 +679,7 @@ void Connection::HandleRequests() {
         VLOG(1) << "Bad TLS header "
                 << absl::StrCat(absl::Hex(buf[0], absl::kZeroPad2),
                                 absl::Hex(buf[1], absl::kZeroPad2));
-        peer->Write(
+        socket_->Write(
            io::Buffer("-ERR Bad TLS header, double check "
                       "if you enabled TLS for your client.\r\n"));
      }
@@ -697,7 +696,6 @@ void Connection::HandleRequests() {
      LOG(WARNING) << "Error handshaking " << aresult.error().message();
      return;
    }
-    peer = socket_.get();
    VLOG(1) << "TLS handshake succeeded";
  }
 }
@@ -705,7 +703,7 @@ void Connection::HandleRequests() {
 
   io::Result<bool> http_res{false};
 
-  http_res = CheckForHttpProto(peer);
+  http_res = CheckForHttpProto();
 
   // We need to check if the socket is open because the server might be
   // shutting down. During the shutdown process, the server iterates over
@@ -717,12 +715,14 @@ void Connection::HandleRequests() {
   // because both Write and Recv internally check if the socket was shut
   // down and return with an error accordingly.
   if (http_res && socket_->IsOpen()) {
-    cc_.reset(service_->CreateContext(peer, this));
+    cc_.reset(service_->CreateContext(socket_.get(), this));
+    reply_builder_ = cc_->reply_builder();
+
     if (*http_res) {
       VLOG(1) << "HTTP1.1 identified";
       is_http_ = true;
       HttpConnection http_conn{http_listener_};
-      http_conn.SetSocket(peer);
+      http_conn.SetSocket(socket_.get());
       http_conn.set_user_data(cc_.get());
 
       // We validate the http request using basic-auth inside HttpConnection::HandleSingleRequest.
@@ -741,13 +741,14 @@ void Connection::HandleRequests() {
       socket_->RegisterOnErrorCb([this](int32_t mask) { this->OnBreakCb(mask); });
     }
 
-    ConnectionFlow(peer);
+    ConnectionFlow();
 
     socket_->CancelOnErrorCb();  // noop if nothing is registered.
   }
 
   VLOG(1) << "Closed connection for peer "
           << GetClientInfo(fb2::ProactorBase::me()->GetPoolIndex());
 
   cc_.reset();
+  reply_builder_ = nullptr;
 }
 
@@ -801,13 +802,14 @@ std::pair<std::string, std::string> Connection::GetClientInfoBeforeAfterTid() co
   string_view phase_name = PHASE_NAMES[phase_];
 
   if (cc_) {
-    DCHECK(cc_->reply_builder());
+    DCHECK(cc_->reply_builder() && reply_builder_);
     string cc_info = service_->GetContextInfo(cc_.get()).Format();
-    if (cc_->reply_builder()->IsSendActive())
+    if (reply_builder_->IsSendActive())
       phase_name = "send";
     absl::StrAppend(&after, " ", cc_info);
   }
   absl::StrAppend(&after, " phase=", phase_name);
+
   return {std::move(before), std::move(after)};
 }
 
@@ -872,7 +874,7 @@ const absl::flat_hash_map<std::string, uint64_t>& Connection::GetLibStatsTL() {
   return g_libname_ver_map;
 }
 
-io::Result<bool> Connection::CheckForHttpProto(FiberSocketBase* peer) {
+io::Result<bool> Connection::CheckForHttpProto() {
   if (!IsPrivileged() && !IsMain()) {
     return false;
   }
@@ -883,6 +885,7 @@ io::Result<bool> Connection::CheckForHttpProto(FiberSocketBase* peer) {
   }
 
   size_t last_len = 0;
+  auto* peer = socket_.get();
   do {
     auto buf = io_buf_.AppendBuffer();
     DCHECK(!buf.empty());
@@ -916,27 +919,26 @@ io::Result<bool> Connection::CheckForHttpProto(FiberSocketBase* peer) {
   return false;
 }
 
-void Connection::ConnectionFlow(FiberSocketBase* peer) {
+void Connection::ConnectionFlow() {
   ++stats_->num_conns;
   ++stats_->conn_received_cnt;
   stats_->read_buf_capacity += io_buf_.Capacity();
 
   ParserStatus parse_status = OK;
 
-  SinkReplyBuilder* orig_builder = cc_->reply_builder();
-
   // At the start we read from the socket to determine the HTTP/Memstore protocol.
   // Therefore we may already have some data in the buffer.
   if (io_buf_.InputLen() > 0) {
     phase_ = PROCESS;
     if (redis_parser_) {
-      parse_status = ParseRedis(orig_builder);
+      parse_status = ParseRedis();
     } else {
       DCHECK(memcache_parser_);
       parse_status = ParseMemcache();
     }
   }
 
-  error_code ec = orig_builder->GetError();
+  error_code ec = reply_builder_->GetError();
 
   // Main loop.
   if (parse_status != ERROR && !ec) {
@@ -944,7 +946,7 @@ void Connection::ConnectionFlow(FiberSocketBase* peer) {
       UpdateIoBufCapacity(io_buf_, stats_,
                           [&]() { io_buf_.EnsureCapacity(io_buf_.Capacity() * 2); });
     }
 
-    auto res = IoLoop(peer, orig_builder);
+    auto res = IoLoop();
 
     if (holds_alternative<error_code>(res)) {
       ec = get<error_code>(res);
@@ -975,10 +977,10 @@ void Connection::ConnectionFlow(FiberSocketBase* peer) {
     VLOG(1) << "Error parser status " << parser_error_;
 
     if (redis_parser_) {
-      SendProtocolError(RedisParser::Result(parser_error_), orig_builder);
+      SendProtocolError(RedisParser::Result(parser_error_), reply_builder_);
     } else {
       DCHECK(memcache_parser_);
-      orig_builder->SendProtocolError("bad command line format");
+      reply_builder_->SendProtocolError("bad command line format");
     }
 
     // Shut down the servers side of the socket to send a FIN to the client
     // then keep draining the socket (discarding any received data) until
     // the client closes the connection.
     //
     // Otherwise the clients write could fail (or block), so they would never
     // read the above protocol error (see issue #1327).
     // TODO: we have a bug that can potentially deadlock the code below.
-    // If a peer does not close the socket on the other side, the while loop will never finish.
+    // If the client does not close the socket on the other side, the while loop will never finish.
     // to reproduce: nc localhost 6379 and then run invalid sequence: *1 *1
-    error_code ec2 = peer->Shutdown(SHUT_WR);
+    error_code ec2 = socket_->Shutdown(SHUT_WR);
     LOG_IF(WARNING, ec2) << "Could not shutdown socket " << ec2;
     if (!ec2) {
       while (true) {
         // Discard any received data.
         io_buf_.Clear();
-        if (!peer->Recv(io_buf_.AppendBuffer())) {
+        if (!socket_->Recv(io_buf_.AppendBuffer())) {
           break;
         }
       }
     }
@@ -1064,7 +1066,7 @@ void Connection::DispatchSingle(bool has_more, absl::FunctionRef<void()> invoke_
   }
 }
 
-Connection::ParserStatus Connection::ParseRedis(SinkReplyBuilder* orig_builder) {
+Connection::ParserStatus Connection::ParseRedis() {
   uint32_t consumed = 0;
   RedisParser::Result result = RedisParser::OK;
 
@@ -1096,7 +1098,7 @@ Connection::ParserStatus Connection::ParseRedis(SinkReplyBuilder* orig_builder)
       DispatchSingle(has_more, dispatch_sync, dispatch_async);
     }
     io_buf_.ConsumeInput(consumed);
-  } while (RedisParser::OK == result && !orig_builder->GetError());
+  } while (RedisParser::OK == result && !reply_builder_->GetError());
 
   parser_error_ = result;
   if (result == RedisParser::OK)
@@ -1120,7 +1122,7 @@ auto Connection::ParseMemcache() -> ParserStatus {
     return {make_unique<MCPipelineMessage>(std::move(cmd), value)};
   };
 
-  MCReplyBuilder* builder = static_cast<MCReplyBuilder*>(cc_->reply_builder());
+  MCReplyBuilder* builder = static_cast<MCReplyBuilder*>(reply_builder_);
 
   do {
     string_view str = ToSV(io_buf_.InputBuffer());
@@ -1179,10 +1181,10 @@ void Connection::OnBreakCb(int32_t mask) {
     return;
   }
 
-  DCHECK(cc_->reply_builder()) << "[" << id_ << "] " << phase_ << " " << migration_in_process_;
+  DCHECK(reply_builder_) << "[" << id_ << "] " << phase_ << " " << migration_in_process_;
 
   VLOG(1) << "[" << id_ << "] Got event " << mask << " " << phase_ << " "
-          << cc_->reply_builder()->IsSendActive() << " " << cc_->reply_builder()->GetError();
+          << reply_builder_->IsSendActive() << " " << reply_builder_->GetError();
 
   cc_->conn_closing = true;
   BreakOnce(mask);
@@ -1212,6 +1214,7 @@ void Connection::HandleMigrateRequest() {
       // We need to return early as the socket is closing and IoLoop will clean up.
       // The reason that this is true is because of the following DCHECK
       DCHECK(!dispatch_fb_.IsJoinable());
+
       // which can never trigger since we Joined on the dispatch_fb_ above and we are
       // atomic in respect to our proactor meaning that no other fiber will
       // launch the DispatchFiber.
@@ -1219,17 +1222,14 @@ void Connection::HandleMigrateRequest() {
       return;
     }
   }
-
-  // In case we Yield()ed in Migrate() above, dispatch_fb_ might have been started.
-  LaunchDispatchFiberIfNeeded();
 }
 
-auto Connection::IoLoop(util::FiberSocketBase* peer, SinkReplyBuilder* orig_builder)
-    -> variant<error_code, ParserStatus> {
+auto Connection::IoLoop() -> variant<error_code, ParserStatus> {
   error_code ec;
   ParserStatus parse_status = OK;
 
   size_t max_iobfuf_len = GetFlag(FLAGS_max_client_iobuf_len);
+  auto* peer = socket_.get();
 
   do {
     HandleMigrateRequest();
@@ -1256,7 +1256,7 @@ auto Connection::IoLoop(util::FiberSocketBase* peer, SinkReplyBuilder* orig_buil
     bool is_iobuf_full = io_buf_.AppendLen() == 0;
 
     if (redis_parser_) {
-      parse_status = ParseRedis(orig_builder);
+      parse_status = ParseRedis();
     } else {
       DCHECK(memcache_parser_);
       parse_status = ParseMemcache();
@@ -1303,7 +1303,7 @@ auto Connection::IoLoop(util::FiberSocketBase* peer, SinkReplyBuilder* orig_buil
     } else if (parse_status != OK) {
       break;
     }
-    ec = orig_builder->GetError();
+    ec = reply_builder_->GetError();
   } while (peer->IsOpen() && !ec);
 
   if (ec)
@@ -1337,7 +1337,7 @@ bool Connection::ShouldEndDispatchFiber(const MessageHandle& msg) {
   return false;
 }
 
-void Connection::SquashPipeline(facade::SinkReplyBuilder* builder) {
+void Connection::SquashPipeline() {
   DCHECK_EQ(dispatch_q_.size(), pending_pipeline_cmd_cnt_);
 
   vector<CmdArgList> squash_cmds;
@@ -1356,8 +1356,8 @@ void Connection::SquashPipeline(facade::SinkReplyBuilder* builder) {
   size_t dispatched = service_->DispatchManyCommands(absl::MakeSpan(squash_cmds), cc_.get());
 
   if (pending_pipeline_cmd_cnt_ == squash_cmds.size()) {  // Flush if no new commands appeared
-    builder->FlushBatch();
-    builder->SetBatchMode(false);  // in case the next dispatch is sync
+    reply_builder_->FlushBatch();
+    reply_builder_->SetBatchMode(false);  // in case the next dispatch is sync
   }
 
   cc_->async_dispatch = false;
@@ -1376,7 +1376,7 @@ void Connection::SquashPipeline(facade::SinkReplyBuilder* builder) {
 }
 
 void Connection::ClearPipelinedMessages() {
-  DispatchOperations dispatch_op{cc_->reply_builder(), this};
+  DispatchOperations dispatch_op{reply_builder_, this};
 
   // Recycle messages even from disconnecting client to keep properly track of memory stats
   // As well as to avoid pubsub backpressure leakege.
@@ -1421,17 +1421,17 @@ std::string Connection::DebugInfo() const {
 // into the dispatch queue and DispatchFiber will run those commands asynchronously with
 // InputLoop. Note: in some cases, InputLoop may decide to dispatch directly and bypass the
 // DispatchFiber.
-void Connection::ExecutionFiber(util::FiberSocketBase* peer) {
+void Connection::ExecutionFiber() {
   ThisFiber::SetName("ExecutionFiber");
 
-  SinkReplyBuilder* builder = cc_->reply_builder();
-  DispatchOperations dispatch_op{builder, this};
+
+  DispatchOperations dispatch_op{reply_builder_, this};
 
   size_t squashing_threshold = GetFlag(FLAGS_pipeline_squash);
 
   uint64_t prev_epoch = fb2::FiberSwitchEpoch();
   fb2::NoOpLock noop_lk;
 
-  while (!builder->GetError()) {
+  while (!reply_builder_->GetError()) {
     DCHECK_EQ(socket()->proactor(), ProactorBase::me());
     cnd_.wait(noop_lk, [this] {
       return cc_->conn_closing || (!dispatch_q_.empty() && !cc_->sync_dispatch);
     });
@@ -1455,7 +1455,7 @@ void Connection::ExecutionFiber(util::FiberSocketBase* peer) {
     }
     prev_epoch = cur_epoch;
 
-    builder->SetBatchMode(dispatch_q_.size() > 1);
+    reply_builder_->SetBatchMode(dispatch_q_.size() > 1);
 
     bool subscriber_over_limit =
         stats_->dispatch_queue_subscriber_bytes >= queue_backpressure_->publish_buffer_limit;
@@ -1468,7 +1468,7 @@ void Connection::ExecutionFiber(util::FiberSocketBase* peer) {
       bool threshold_reached = pending_pipeline_cmd_cnt_ > squashing_threshold;
       bool are_all_plain_cmds = pending_pipeline_cmd_cnt_ == dispatch_q_.size();
       if (squashing_enabled && threshold_reached && are_all_plain_cmds && !skip_next_squashing_) {
-        SquashPipeline(builder);
+        SquashPipeline();
       } else {
         MessageHandle msg = std::move(dispatch_q_.front());
         dispatch_q_.pop_front();
@@ -1477,7 +1477,7 @@ void Connection::ExecutionFiber(util::FiberSocketBase* peer) {
         // last command to reply and flush. If it doesn't reply (i.e. is a control message like
         // migrate), we have to flush manually.
         if (dispatch_q_.empty() && !msg.IsReplying()) {
-          builder->FlushBatch();
+          reply_builder_->FlushBatch();
         }
 
         if (ShouldEndDispatchFiber(msg)) {
@@ -1505,7 +1505,7 @@ void Connection::ExecutionFiber(util::FiberSocketBase* peer) {
     queue_backpressure_->pubsub_ec.notify();
   }
 
-  DCHECK(cc_->conn_closing || builder->GetError());
+  DCHECK(cc_->conn_closing || reply_builder_->GetError());
   cc_->conn_closing = true;
   queue_backpressure_->pipeline_cnd.notify_all();
 }
@@ -1582,6 +1582,7 @@ bool Connection::Migrate(util::fb2::ProactorBase* dest) {
   }
 
   listener()->Migrate(this, dest);
+
   // After we migrate, it could be the case the connection was shut down. We should
   // act accordingly.
   if (!socket()->IsOpen()) {
@@ -1646,8 +1647,8 @@ void Connection::SendInvalidationMessageAsync(InvalidationMessage msg) {
 void Connection::LaunchDispatchFiberIfNeeded() {
   if (!dispatch_fb_.IsJoinable() && !migration_in_process_) {
     VLOG(1) << "[" << id_ << "] LaunchDispatchFiberIfNeeded ";
-    dispatch_fb_ = fb2::Fiber(fb2::Launch::post, "connection_dispatch",
-                              [this, peer = socket_.get()]() { ExecutionFiber(peer); });
+    dispatch_fb_ =
+        fb2::Fiber(fb2::Launch::post, "connection_dispatch", [this]() { ExecutionFiber(); });
   }
 }
 
diff --git a/src/facade/dragonfly_connection.h b/src/facade/dragonfly_connection.h
index bb9613ff1844..2ff58198e4e5 100644
--- a/src/facade/dragonfly_connection.h
+++ b/src/facade/dragonfly_connection.h
@@ -331,14 +331,13 @@ class Connection : public util::Connection {
   void HandleRequests() final;
 
   // Start dispatch fiber and run IoLoop.
-  void ConnectionFlow(util::FiberSocketBase* peer);
+  void ConnectionFlow();
 
   // Main loop reading client messages and passing requests to dispatch queue.
-  std::variant<std::error_code, ParserStatus> IoLoop(util::FiberSocketBase* peer,
-                                                     SinkReplyBuilder* orig_builder);
+  std::variant<std::error_code, ParserStatus> IoLoop();
 
   // Returns true if HTTP header is detected.
-  io::Result<bool> CheckForHttpProto(util::FiberSocketBase* peer);
+  io::Result<bool> CheckForHttpProto();
 
   // Dispatches a single (Redis or MC) command.
   // `has_more` should indicate whether the io buffer has more commands
@@ -348,7 +347,7 @@ class Connection : public util::Connection {
                       absl::FunctionRef<MessageHandle()> cmd_msg_cb);
 
   // Handles events from dispatch queue.
-  void ExecutionFiber(util::FiberSocketBase* peer);
+  void ExecutionFiber();
 
   void SendAsync(MessageHandle msg);
 
@@ -358,7 +357,7 @@ class Connection : public util::Connection {
   // Create new pipeline request, re-use from pool when possible.
   PipelineMessagePtr FromArgs(RespVec args, mi_heap_t* heap);
 
-  ParserStatus ParseRedis(SinkReplyBuilder* orig_builder);
+  ParserStatus ParseRedis();
   ParserStatus ParseMemcache();
 
   void OnBreakCb(int32_t mask);
@@ -373,8 +372,9 @@ class Connection : public util::Connection {
   bool ShouldEndDispatchFiber(const MessageHandle& msg);
 
   void LaunchDispatchFiberIfNeeded();  // Dispatch fiber is started lazily
+
   // Squashes pipelined commands from the dispatch queue to spread load over all threads
-  void SquashPipeline(facade::SinkReplyBuilder*);
+  void SquashPipeline();
 
   // Clear pipelined messages, disaptching only intrusive ones.
   void ClearPipelinedMessages();
@@ -398,6 +398,9 @@ class Connection : public util::Connection {
   Protocol protocol_;
   ConnectionStats* stats_ = nullptr;
 
+  // cc_->reply_builder may change during the lifetime of the connection, due to injections.
+  // This is a pointer to the original, socket-based reply builder that never changes.
+  SinkReplyBuilder* reply_builder_ = nullptr;
   util::HttpListenerBase* http_listener_;
   SSL_CTX* ssl_ctx_;
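
The sketch below is not part of the commit; it restates the commit message's
argument in a self-contained form. All names in it (ReplyBuilder, Context,
Connection, Send) are illustrative stand-ins, not Dragonfly's actual classes.
It contrasts the old pattern (re-reading the injectable cc_->reply_builder slot
on every reply) with the new one (snapshotting the original, socket-backed
builder once and using only that):

    // C++17, self-contained illustration; compile and run as-is.
    #include <cassert>
    #include <string>
    #include <vector>

    struct ReplyBuilder {
      virtual ~ReplyBuilder() = default;
      virtual void Send(const std::string& msg) = 0;
    };

    // Stands in for the builder that writes directly to the client socket.
    struct SocketReplyBuilder : ReplyBuilder {
      std::vector<std::string> wire;  // stands in for the socket
      void Send(const std::string& msg) override { wire.push_back(msg); }
    };

    // Stands in for a builder injected by service logic, e.g. an EVAL flow
    // that captures replies instead of sending them to the client.
    struct CapturingReplyBuilder : ReplyBuilder {
      std::vector<std::string> captured;
      void Send(const std::string& msg) override { captured.push_back(msg); }
    };

    struct Context {
      // Service logic may swap this pointer mid-connection; with offloaded
      // EVAL it may even do so from another thread, which is a data race for
      // any reader that keeps dereferencing the slot.
      ReplyBuilder* reply_builder;
    };

    struct Connection {
      Context* cc;
      ReplyBuilder* reply_builder;  // snapshot taken once, never re-read

      explicit Connection(Context* c) : cc(c), reply_builder(c->reply_builder) {}

      // Old pattern: every reply re-reads the injectable slot.
      void ReplyViaContext(const std::string& m) { cc->reply_builder->Send(m); }

      // New pattern: the connection talks only to the original builder.
      void ReplyDirect(const std::string& m) { reply_builder->Send(m); }
    };

    int main() {
      SocketReplyBuilder sock;
      Context cc{&sock};
      Connection conn(&cc);

      CapturingReplyBuilder capture;
      cc.reply_builder = &capture;  // injection happens mid-connection

      conn.ReplyViaContext("+OK\r\n");  // lands in the injected builder
      conn.ReplyDirect("+OK\r\n");      // still reaches the client socket

      assert(capture.captured.size() == 1);
      assert(sock.wire.size() == 1);
      return 0;
    }

In the patch itself the snapshot is taken right after CreateContext() in
HandleRequests() and cleared together with cc_, so the cached pointer never
outlives the context that owns the original builder.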