Skip to content

Commit

Permalink
feat: Implement first part of the azure write file API (#350)
Browse files Browse the repository at this point in the history
  • Loading branch information
romange authored Dec 21, 2024
1 parent 687f708 commit 04fd389
Show file tree
Hide file tree
Showing 8 changed files with 213 additions and 84 deletions.
23 changes: 21 additions & 2 deletions examples/gcs_demo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,29 @@ void RunAzure(SSL_CTX* ctx) {

string prefix = GetFlag(FLAGS_prefix);

if (GetFlag(FLAGS_read) > 0) {
if (GetFlag(FLAGS_write) > 0) {
auto src = io::ReadFileToString("/proc/self/exe");
CHECK(src);
LOG(INFO) << "Writing " << src->size() << " bytes to " << prefix;
for (unsigned i = 0; i < GetFlag(FLAGS_write); ++i) {
string dest_key = absl::StrCat(prefix, "_", i);

cloud::azure::WriteFileOptions opts;
opts.creds_provider = &provider;
opts.ssl_cntx = ctx;
io::Result<io::WriteFile*> dest_res = cloud::azure::OpenWriteFile(bucket, dest_key, opts);
CHECK(dest_res) << "Could not open " << dest_key << " " << dest_res.error().message();
unique_ptr<io::WriteFile> dest(*dest_res);
error_code ec = dest->Write(*src);
CHECK(!ec);
ec = dest->Close();
CHECK(!ec);
CONSOLE_INFO << "Written " << dest_key;
}
} else if (GetFlag(FLAGS_read) > 0) {
for (unsigned i = 0; i < GetFlag(FLAGS_read); ++i) {
string dest_key = prefix;
cloud::azure::AzureReadFileOptions opts;
cloud::azure::ReadFileOptions opts;
opts.creds_provider = &provider;
opts.ssl_cntx = ctx;

Expand Down
29 changes: 20 additions & 9 deletions util/cloud/azure/azure.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,25 +41,35 @@ void HMAC(absl::string_view key, absl::string_view msg, uint8_t dest[32]) {
CHECK_EQ(len, 32u);
}

string ComputeSignature(string_view account, const boost::beast::http::header<true>& req_header,
string ComputeSignature(string_view account, h2::verb verb, const h2::header<true>& req_header,
string_view account_key) {
string key_bin;
CHECK(absl::Base64Unescape(account_key, &key_bin));

// see here:
// https://learn.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key#blob-queue-and-file-services-shared-key-authorization
string new_lines;
for (unsigned i = 0; i < 12; ++i)
absl::StrAppend(&new_lines, "\n");

vector<pair<string_view, string_view>> x_head;
for (const auto& h : req_header) {
if (h.name_string().starts_with("x-ms-")) {
x_head.emplace_back(detail::FromBoostSV(h.name_string()), detail::FromBoostSV(h.value()));
}
}
sort(x_head.begin(), x_head.end());
string to_sign = absl::StrCat("GET", new_lines);
string_view verb_str = detail::FromBoostSV(h2::to_string(verb));

auto it = req_header.find(h2::field::content_length);
string content_length;
if (it != req_header.end() && it->value() != "0") {
absl::StrAppend(&content_length, detail::FromBoostSV(it->value()), "\n");
} else {
content_length = "\n";
}

// see here:
// https://learn.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key#blob-queue-and-file-services-shared-key-authorization
string new_lines;
for (unsigned i = 0; i < 8; ++i)
absl::StrAppend(&new_lines, "\n");

string to_sign = absl::StrCat(verb_str, "\n\n\n", content_length, new_lines);
for (const auto& p : x_head) {
absl::StrAppend(&to_sign, p.first, ":", p.second, "\n");
}
Expand Down Expand Up @@ -117,7 +127,8 @@ void Credentials::Sign(detail::HttpRequestBase* req) const {
req->SetHeader("x-ms-date", date);
req->SetHeader("x-ms-version", kVersion);

string signature = ComputeSignature(account_name_, req->GetHeaders(), account_key_);
string signature =
ComputeSignature(account_name_, req->GetMethod(), req->GetHeaders(), account_key_);
req->SetHeader("Authorization", absl::StrCat("SharedKey ", account_name_, ":", signature));
}

Expand Down
83 changes: 77 additions & 6 deletions util/cloud/azure/storage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,7 @@ detail::EmptyRequestImpl FillRequest(string_view endpoint, string_view url, Cred

class ReadFile final : public io::ReadonlyFile {
public:
ReadFile(string read_obj_url, AzureReadFileOptions opts)
: read_obj_url_(read_obj_url), opts_(opts) {
ReadFile(string read_obj_url, ReadFileOptions opts) : read_obj_url_(read_obj_url), opts_(opts) {
}

virtual ~ReadFile();
Expand All @@ -188,7 +187,7 @@ class ReadFile final : public io::ReadonlyFile {
error_code InitRead();

const string read_obj_url_;
AzureReadFileOptions opts_;
ReadFileOptions opts_;

using Parser = h2::response_parser<h2::buffer_body>;
std::optional<Parser> parser_;
Expand All @@ -199,6 +198,31 @@ class ReadFile final : public io::ReadonlyFile {
size_t size_ = 0, offs_ = 0;
};

// File handle that writes to Azure.
//
// This uses multipart uploads, where it will buffer upto the configured part
// size before uploading.
class WriteFile : public detail::AbstractStorageFile {
public:
WriteFile(string_view container, string_view key, const WriteFileOptions& opts);
~WriteFile();

// Closes the object and completes the multipart upload. Therefore the object
// will not be uploaded unless Close is called.
error_code Close() override;

private:
error_code Upload();

using UploadRequest = detail::DynamicBodyRequestImpl;
unique_ptr<UploadRequest> PrepareRequest();

unique_ptr<http::ClientPool> pool_; // must be before client_handle_.
string target_;
unsigned block_id_ = 1;
WriteFileOptions opts_;
};

ReadFile::~ReadFile() {
}

Expand Down Expand Up @@ -273,6 +297,52 @@ io::SizeOrError ReadFile::Read(size_t offset, const iovec* v, uint32_t len) {
return total;
}

WriteFile::WriteFile(string_view container, string_view key, const WriteFileOptions& opts)
: detail::AbstractStorageFile(key, 1UL << 23), opts_(opts) {
string endpoint = opts_.creds_provider->GetEndpoint();
pool_ = CreatePool(endpoint, opts_.ssl_cntx, fb2::ProactorBase::me());
target_ = absl::StrCat("/", container, "/", key);
}

WriteFile::~WriteFile() {
}

error_code WriteFile::Close() {
return {};
}

error_code WriteFile::Upload() {
size_t body_size = body_mb_.size();
CHECK_GT(body_size, 0u);

auto req = PrepareRequest();

error_code res;
RobustSender sender(pool_.get(), opts_.creds_provider);
RobustSender::SenderResult send_res;
RETURN_ERROR(sender.Send(3, req.get(), &send_res));

auto parser_ptr = std::move(send_res.eb_parser);
const auto& resp_msg = parser_ptr->get();
VLOG(1) << "Upload response: " << resp_msg;

return {};
}

auto WriteFile::PrepareRequest() -> unique_ptr<UploadRequest> {
string url =
absl::StrCat(target_, "?comp=block&blockid=", absl::Dec(block_id_++, absl::kZeroPad4));
unique_ptr<UploadRequest> upload_req(new UploadRequest(url, h2::verb::put));

upload_req->SetBody(std::move(body_mb_));

upload_req->SetHeader(h2::field::host, opts_.creds_provider->GetEndpoint());
upload_req->Finalize();
opts_.creds_provider->Sign(upload_req.get());

return upload_req;
}

} // namespace

error_code Storage::ListContainers(function<void(const ContainerItem&)> cb) {
Expand Down Expand Up @@ -347,16 +417,17 @@ string BuildGetObjUrl(const string& container, const string& key) {
}

io::Result<io::ReadonlyFile*> OpenReadFile(const std::string& container, const std::string& key,
const AzureReadFileOptions& opts) {
const ReadFileOptions& opts) {
DCHECK(opts.creds_provider && opts.ssl_cntx);
string url = BuildGetObjUrl(container, key);
return new ReadFile(url, opts);
}

io::Result<io::WriteFile*> OpenWriteFile(const std::string& container, const std::string& key,
const AzureWriteFileOptions& opts) {
const WriteFileOptions& opts) {
DCHECK(opts.creds_provider && opts.ssl_cntx);
return UnexpectedError(errc::function_not_supported);

return new WriteFile(container, key, opts);
}

} // namespace cloud::azure
Expand Down
8 changes: 4 additions & 4 deletions util/cloud/azure/storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,19 @@ class Storage {
Credentials* creds_;
};

struct AzureReadFileOptions {
struct ReadFileOptions {
Credentials* creds_provider = nullptr;
SSL_CTX* ssl_cntx;
};

using AzureWriteFileOptions = AzureReadFileOptions;
using WriteFileOptions = ReadFileOptions;

io::Result<io::ReadonlyFile*> OpenReadFile(const std::string& container,
const std::string& key,
const AzureReadFileOptions& opts);
const ReadFileOptions& opts);

io::Result<io::WriteFile*> OpenWriteFile(const std::string& container, const std::string& key,
const AzureWriteFileOptions& opts);
const WriteFileOptions& opts);

} // namespace cloud::azure
} // namespace util
10 changes: 0 additions & 10 deletions util/cloud/gcp/gcp_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include <boost/beast/http/parser.hpp>
#include <memory>

#include "io/io.h"
#include "util/cloud/utils.h"
#include "util/http/https_client_pool.h"

Expand All @@ -25,13 +24,4 @@ namespace detail {

std::string AuthHeader(std::string_view access_token);

#define RETURN_UNEXPECTED(x) \
do { \
auto ec = (x); \
if (ec) { \
VLOG(1) << "Failed " << #x << ": " << ec.message(); \
return nonstd::make_unexpected(ec); \
} \
} while (false)

} // namespace util::cloud
53 changes: 4 additions & 49 deletions util/cloud/gcp/gcs_file.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,8 @@ inline void SetRange(size_t from, size_t to, h2::fields* flds) {
//
// This uses multipart uploads, where it will buffer upto the configured part
// size before uploading.
class GcsWriteFile : public io::WriteFile {
class GcsWriteFile : public detail::AbstractStorageFile {
public:
// Writes bytes to the GCS object. This will either buffer internally or
// write a part to GCS.
io::Result<size_t> WriteSome(const iovec* v, uint32_t len) override;

// Closes the object and completes the multipart upload. Therefore the object
// will not be uploaded unless Close is called.
error_code Close() override;
Expand All @@ -87,14 +83,12 @@ class GcsWriteFile : public io::WriteFile {
~GcsWriteFile();

private:
error_code FillBuf(const uint8* buffer, size_t length);
error_code Upload();
error_code Upload() final;

using UploadRequest = detail::DynamicBodyRequestImpl;
unique_ptr<UploadRequest> PrepareRequest(size_t to, ssize_t total);

string upload_id_;
multi_buffer body_mb_;
size_t uploaded_ = 0;
GcsWriteFileOptions opts_;
};
Expand Down Expand Up @@ -139,7 +133,7 @@ class GcsReadFile final : public io::ReadonlyFile {
******************************************************************************************/

GcsWriteFile::GcsWriteFile(string_view key, string_view upload_id, const GcsWriteFileOptions& opts)
: io::WriteFile(key), upload_id_(upload_id), body_mb_(opts.part_size), opts_(opts) {
: detail::AbstractStorageFile(key, opts.part_size), upload_id_(upload_id), opts_(opts) {
}

GcsWriteFile::~GcsWriteFile() {
Expand All @@ -148,15 +142,6 @@ GcsWriteFile::~GcsWriteFile() {
}
}

io::Result<size_t> GcsWriteFile::WriteSome(const iovec* v, uint32_t len) {
size_t total = 0;
for (uint32_t i = 0; i < len; ++i) {
RETURN_UNEXPECTED(FillBuf(reinterpret_cast<const uint8_t*>(v->iov_base), v->iov_len));
total += v->iov_len;
}
return total;
}

error_code GcsWriteFile::Close() {
size_t to = uploaded_ + body_mb_.size();
auto req = PrepareRequest(to, to);
Expand Down Expand Up @@ -208,36 +193,6 @@ error_code GcsWriteFile::Close() {
return {};
}

error_code GcsWriteFile::FillBuf(const uint8* buffer, size_t length) {
while (length >= body_mb_.max_size() - body_mb_.size()) {
size_t prepare_size = body_mb_.max_size() - body_mb_.size();
auto mbs = body_mb_.prepare(prepare_size);
size_t offs = 0;
for (auto mb : mbs) {
memcpy(mb.data(), buffer + offs, mb.size());
offs += mb.size();
}
DCHECK_EQ(offs, prepare_size);
body_mb_.commit(prepare_size);

auto ec = Upload();
if (ec)
return ec;

length -= prepare_size;
buffer += prepare_size;
}

if (length) {
auto mbs = body_mb_.prepare(length);
for (auto mb : mbs) {
memcpy(mb.data(), buffer, mb.size());
buffer += mb.size();
}
body_mb_.commit(length);
}
return {};
}

error_code GcsWriteFile::Upload() {
size_t body_size = body_mb_.size();
Expand Down Expand Up @@ -274,7 +229,7 @@ error_code GcsWriteFile::Upload() {
}

auto GcsWriteFile::PrepareRequest(size_t to, ssize_t total) -> unique_ptr<UploadRequest> {
unique_ptr<UploadRequest> upload_req(new UploadRequest(upload_id_));
unique_ptr<UploadRequest> upload_req(new UploadRequest(upload_id_, h2::verb::post));

upload_req->SetBody(std::move(body_mb_));
upload_req->SetHeader(h2::field::content_range, ContentRangeHeader(uploaded_, to, total));
Expand Down
Loading

0 comments on commit 04fd389

Please sign in to comment.