Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tls: Add optional builder + future-wait to cert reload callback + expose rebuild #2573

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions include/seastar/net/tls.hh
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,10 @@ namespace tls {
};

class reloadable_credentials_base;
class credentials_builder;

using reload_callback = std::function<void(const std::unordered_set<sstring>&, std::exception_ptr)>;
using reload_callback_ex = std::function<future<>(const credentials_builder&, const std::unordered_set<sstring>&, std::exception_ptr)>;

/**
* Intentionally "primitive", and more importantly, copyable
Expand Down Expand Up @@ -320,10 +322,16 @@ namespace tls {
shared_ptr<certificate_credentials> build_certificate_credentials() const;
shared_ptr<server_credentials> build_server_credentials() const;

void rebuild(certificate_credentials&) const;
void rebuild(server_credentials&) const;

// same as above, but any files used for certs/keys etc will be watched
// for modification and reloaded if changed
future<shared_ptr<certificate_credentials>> build_reloadable_certificate_credentials(reload_callback = {}, std::optional<std::chrono::milliseconds> tolerance = {}) const;
future<shared_ptr<server_credentials>> build_reloadable_server_credentials(reload_callback = {}, std::optional<std::chrono::milliseconds> tolerance = {}) const;
future<shared_ptr<certificate_credentials>> build_reloadable_certificate_credentials(reload_callback_ex = {}, std::optional<std::chrono::milliseconds> tolerance = {}) const;
future<shared_ptr<server_credentials>> build_reloadable_server_credentials(reload_callback_ex = {}, std::optional<std::chrono::milliseconds> tolerance = {}) const;

future<shared_ptr<certificate_credentials>> build_reloadable_certificate_credentials(reload_callback, std::optional<std::chrono::milliseconds> tolerance = {}) const;
future<shared_ptr<server_credentials>> build_reloadable_server_credentials(reload_callback, std::optional<std::chrono::milliseconds> tolerance = {}) const;
private:
friend class reloadable_credentials_base;

Expand Down
44 changes: 33 additions & 11 deletions src/net/tls.cc
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,7 @@ class tls::reloadable_credentials_base {
public:
using time_point = std::chrono::system_clock::time_point;

reloading_builder(credentials_builder b, reload_callback cb, reloadable_credentials_base* creds, delay_type delay)
reloading_builder(credentials_builder b, reload_callback_ex cb, reloadable_credentials_base* creds, delay_type delay)
: credentials_builder(std::move(b))
, _cb(std::move(cb))
, _creds(creds)
Expand Down Expand Up @@ -955,7 +955,7 @@ class tls::reloadable_credentials_base {
}
void do_callback(std::exception_ptr ep = {}) {
if (_cb && !_files.empty()) {
_cb(boost::copy_range<std::unordered_set<sstring>>(_files | boost::adaptors::map_keys), std::move(ep));
_cb(*this, boost::copy_range<std::unordered_set<sstring>>(_files | boost::adaptors::map_keys), std::move(ep)).get();
}
}
// called from seastar::thread
Expand Down Expand Up @@ -988,7 +988,7 @@ class tls::reloadable_credentials_base {
});
}

reload_callback _cb;
reload_callback_ex _cb;
reloadable_credentials_base* _creds;
fsnotifier _fsn;
std::unordered_map<fsnotifier::watch_token, std::pair<fsnotifier::watch, sstring>> _watches;
Expand All @@ -997,7 +997,7 @@ class tls::reloadable_credentials_base {
timer<> _timer;
delay_type _delay;
};
reloadable_credentials_base(credentials_builder builder, reload_callback cb, delay_type delay = default_tolerance)
reloadable_credentials_base(credentials_builder builder, reload_callback_ex cb, delay_type delay = default_tolerance)
: _builder(seastar::make_shared<reloading_builder>(std::move(builder), std::move(cb), this, delay))
{
_builder->start();
Expand All @@ -1016,7 +1016,7 @@ class tls::reloadable_credentials_base {
template<typename Base>
class tls::reloadable_credentials : public Base, public tls::reloadable_credentials_base {
public:
reloadable_credentials(credentials_builder builder, reload_callback cb, Base b, delay_type delay = default_tolerance)
reloadable_credentials(credentials_builder builder, reload_callback_ex cb, Base b, delay_type delay = default_tolerance)
: Base(std::move(b))
, tls::reloadable_credentials_base(std::move(builder), std::move(cb), delay)
{}
Expand All @@ -1025,30 +1025,52 @@ class tls::reloadable_credentials : public Base, public tls::reloadable_credenti

template<>
void tls::reloadable_credentials<tls::certificate_credentials>::rebuild(const credentials_builder& builder) {
auto tmp = builder.build_certificate_credentials();
this->_impl = std::move(tmp->_impl);
builder.rebuild(*this);
}

template<>
void tls::reloadable_credentials<tls::server_credentials>::rebuild(const credentials_builder& builder) {
auto tmp = builder.build_server_credentials();
this->_impl = std::move(tmp->_impl);
builder.rebuild(*this);
}

future<shared_ptr<tls::certificate_credentials>> tls::credentials_builder::build_reloadable_certificate_credentials(reload_callback cb, std::optional<std::chrono::milliseconds> tolerance) const {
void tls::credentials_builder::rebuild(certificate_credentials& creds) const {
auto tmp = build_certificate_credentials();
creds._impl = std::move(tmp->_impl);
}

void tls::credentials_builder::rebuild(server_credentials& creds) const {
auto tmp = build_server_credentials();
creds._impl = std::move(tmp->_impl);
}

future<shared_ptr<tls::certificate_credentials>> tls::credentials_builder::build_reloadable_certificate_credentials(reload_callback_ex cb, std::optional<std::chrono::milliseconds> tolerance) const {
auto creds = seastar::make_shared<reloadable_credentials<tls::certificate_credentials>>(*this, std::move(cb), std::move(*build_certificate_credentials()), tolerance.value_or(reloadable_credentials_base::default_tolerance));
return creds->init().then([creds] {
return make_ready_future<shared_ptr<tls::certificate_credentials>>(creds);
});
}

future<shared_ptr<tls::server_credentials>> tls::credentials_builder::build_reloadable_server_credentials(reload_callback cb, std::optional<std::chrono::milliseconds> tolerance) const {
future<shared_ptr<tls::server_credentials>> tls::credentials_builder::build_reloadable_server_credentials(reload_callback_ex cb, std::optional<std::chrono::milliseconds> tolerance) const {
auto creds = seastar::make_shared<reloadable_credentials<tls::server_credentials>>(*this, std::move(cb), std::move(*build_server_credentials()), tolerance.value_or(reloadable_credentials_base::default_tolerance));
return creds->init().then([creds] {
return make_ready_future<shared_ptr<tls::server_credentials>>(creds);
});
}

future<shared_ptr<tls::certificate_credentials>> tls::credentials_builder::build_reloadable_certificate_credentials(reload_callback cb, std::optional<std::chrono::milliseconds> tolerance) const {
return build_reloadable_certificate_credentials([cb = std::move(cb)](const credentials_builder&, const std::unordered_set<sstring>& files, std::exception_ptr p) {
cb(files, p);
return make_ready_future<>();
}, tolerance);
}

future<shared_ptr<tls::server_credentials>> tls::credentials_builder::build_reloadable_server_credentials(reload_callback cb, std::optional<std::chrono::milliseconds> tolerance) const {
return build_reloadable_server_credentials([cb = std::move(cb)](const credentials_builder&, const std::unordered_set<sstring>& files, std::exception_ptr p) {
cb(files, p);
return make_ready_future<>();
}, tolerance);
}

namespace tls {

/**
Expand Down
124 changes: 124 additions & 0 deletions tests/unit/tls_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include <seastar/net/inet_address.hh>
#include <seastar/testing/test_case.hh>
#include <seastar/testing/thread_test_case.hh>
#include <seastar/util/defer.hh>

#include <boost/dll.hpp>

Expand Down Expand Up @@ -1595,3 +1596,126 @@ SEASTAR_THREAD_TEST_CASE(test_tls13_session_tickets) {
}

}

SEASTAR_THREAD_TEST_CASE(test_reload_certificates_with_only_shard0_notify) {
tmpdir tmp;

namespace fs = std::filesystem;

// copy the wrong certs. We don't trust these
// blocking calls, but this is a test and seastar does not have a copy
// util and I am lazy...
fs::copy_file(certfile("other.crt"), tmp.path() / "test.crt");
fs::copy_file(certfile("other.key"), tmp.path() / "test.key");

auto cert = (tmp.path() / "test.crt").native();
auto key = (tmp.path() / "test.key").native();
promise<> p;

tls::credentials_builder b;
b.set_x509_key_file(cert, key, tls::x509_crt_format::PEM).get();
b.set_dh_level();

auto certs = b.build_server_credentials();

auto shard_1_certs = smp::submit_to(1, [&]() -> future<shared_ptr<tls::server_credentials>> {
co_return co_await b.build_reloadable_server_credentials([&, changed = std::unordered_set<sstring>{}](const tls::credentials_builder& builder, const std::unordered_set<sstring>& files, std::exception_ptr ep) mutable -> future<> {
if (ep) {
co_return;
}
changed.insert(files.begin(), files.end());
if (changed.count(cert) && changed.count(key)) {
// shard one certs are not reloadable. We issue a reload of them from shard 0
// - to save inotify instances.
co_await smp::submit_to(0, [&] {
builder.rebuild(*certs);
p.set_value();
});
}
});
}).get();

auto def = defer([&]() noexcept {
try {
smp::submit_to(1, [&] {
shard_1_certs = nullptr;
}).get();
} catch (...) {}
});

::listen_options opts;
opts.reuse_address = true;
auto addr = ::make_ipv4_address( {0x7f000001, 4712});
auto server = tls::listen(certs, addr, opts);

tls::credentials_builder b2;
b2.set_x509_trust_file(certfile("catest.pem"), tls::x509_crt_format::PEM).get();

{
auto sa = server.accept();
auto c = tls::connect(b2.build_certificate_credentials(), addr).get();
auto s = sa.get();
auto in = s.connection.input();

output_stream<char> out(c.output().detach(), 4096);

try {
out.write("apa").get();
auto f = out.flush();
auto f2 = in.read();

try {
f.get();
BOOST_FAIL("should not reach");
} catch (tls::verification_error&) {
// ok
}
try {
out.close().get();
} catch (...) {
}

try {
f2.get();
BOOST_FAIL("should not reach");
} catch (...) {
// ok
}
try {
in.close().get();
} catch (...) {
}
} catch (tls::verification_error&) {
// ok
}
}

// copy the right (trusted) certs over the old ones.
fs::copy_file(certfile("test.crt"), tmp.path() / "test0.crt");
fs::copy_file(certfile("test.key"), tmp.path() / "test0.key");

rename_file((tmp.path() / "test0.crt").native(), (tmp.path() / "test.crt").native()).get();
rename_file((tmp.path() / "test0.key").native(), (tmp.path() / "test.key").native()).get();

p.get_future().get();

// now it should work
{
auto sa = server.accept();
auto c = tls::connect(b2.build_certificate_credentials(), addr).get();
auto s = sa.get();
auto in = s.connection.input();

output_stream<char> out(c.output().detach(), 4096);

out.write("apa").get();
auto f = out.flush();
auto buf = in.read().get();
f.get();
out.close().get();
in.read().get(); // ignore - just want eof
in.close().get();

BOOST_CHECK_EQUAL(sstring(buf.begin(), buf.end()), "apa");
}
}
Loading