Merge commit 'c021835739272d1daa8b4a2afc8e8c7d093d534c' into pr/ipc-stop-base

Ryan Ofsky 2025-04-24 10:53:03 +01:00
commit b9e16ff790
7 changed files with 324 additions and 194 deletions


@ -14,9 +14,10 @@
#include <assert.h>
#include <functional>
#include <optional>
#include <kj/function.h>
#include <map>
#include <memory>
#include <optional>
#include <sstream>
#include <string>
@ -129,6 +130,25 @@ std::string LongThreadName(const char* exe_name);
//! Event loop implementation.
//!
//! The Cap'n Proto threading model is very simple: all I/O operations are
//! asynchronous and must be performed on a single thread. This includes:
//!
//! - Code starting an asynchronous operation (calling a function that returns a
//! promise object)
//! - Code notifying that an asynchronous operation is complete (code using a
//! fulfiller object)
//! - Code handling a completed operation (code chaining or waiting for a promise)
//!
//! All this code needs to run on one thread, and the EventLoop::loop() method
//! is the entry point for this thread. ProxyClient and ProxyServer objects that
//! use other threads and need to perform I/O operations post to this thread
//! using EventLoop::post() and EventLoop::sync() methods.
//!
//! Specifically, because ProxyClient methods can be called from arbitrary
//! threads, and ProxyServer methods can run on arbitrary threads, ProxyClient
//! methods use the EventLoop thread to send requests, and ProxyServer methods
//! use the thread to return results.
//!
//! Based on https://groups.google.com/d/msg/capnproto/TuQFF1eH2-M/g81sHaTAAQAJ
class EventLoop
{
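
A minimal usage sketch of this threading model (illustrative only, not part of this commit): it assumes an EventLoop whose loop() method is already running on a dedicated I/O thread, as the test code later in this commit sets up, and a hypothetical doIpcWork() helper.

// Illustrative sketch only. Assumes <mp/proxy-io.h> is included and that
// loop.loop() is already running on a separate thread.
void doIpcWork(mp::EventLoop& loop)
{
    int result = 0;
    // sync() blocks the calling thread until the callable has finished
    // running on the event loop thread, so locals can safely be captured
    // by reference.
    loop.sync([&] {
        // All Cap'n Proto I/O (starting requests, fulfilling promises,
        // chaining or waiting on promises) must happen here, on the
        // EventLoop thread.
        result = 1 + 2; // placeholder for real capnp work
    });
    // result is now safe to read on the calling thread.
}
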
@ -144,7 +164,7 @@ public:
//! Run function on event loop thread. Does not return until function completes.
//! Must be called while the loop() function is active.
void post(const std::function<void()>& fn);
void post(kj::Function<void()> fn);
//! Wrapper around EventLoop::post() that takes advantage of the fact that the
//! callable will not go out of scope, to avoid the requirement that it
@ -152,7 +172,7 @@ public:
template <typename Callable>
void sync(Callable&& callable)
{
post(std::ref(callable));
post(std::forward<Callable>(callable));
}
//! Start asynchronous worker thread if necessary. This is only done if
@ -166,13 +186,10 @@ public:
//! is important that ProxyServer::m_impl destructors do not run on the
//! event loop thread because they may need it to do I/O if they perform
//! other IPC calls.
void startAsyncThread(std::unique_lock<std::mutex>& lock);
void startAsyncThread() MP_REQUIRES(m_mutex);
//! Add/remove remote client reference counts.
void addClient(std::unique_lock<std::mutex>& lock);
bool removeClient(std::unique_lock<std::mutex>& lock);
//! Check if loop should exit.
bool done(std::unique_lock<std::mutex>& lock);
bool done() MP_REQUIRES(m_mutex);
Logger log()
{
@ -195,10 +212,10 @@ public:
std::thread m_async_thread;
//! Callback function to run on event loop thread during post() or sync() call.
const std::function<void()>* m_post_fn = nullptr;
kj::Function<void()>* m_post_fn MP_GUARDED_BY(m_mutex) = nullptr;
//! Callback functions to run on async thread.
CleanupList m_async_fns;
CleanupList m_async_fns MP_GUARDED_BY(m_mutex);
//! Pipe read handle used to wake up the event loop thread.
int m_wait_fd = -1;
@ -208,11 +225,11 @@ public:
//! Number of clients holding references to ProxyServerBase objects that
//! reference this event loop.
int m_num_clients = 0;
int m_num_clients MP_GUARDED_BY(m_mutex) = 0;
//! Mutex and condition variable used to post tasks to event loop and async
//! thread.
std::mutex m_mutex;
Mutex m_mutex;
std::condition_variable m_cv;
//! Capnp IO context.
@ -263,11 +280,9 @@ struct Waiter
// in the case where a capnp response is sent and a brand new
// request is immediately received.
while (m_fn) {
auto fn = std::move(m_fn);
m_fn = nullptr;
lock.unlock();
fn();
lock.lock();
auto fn = std::move(*m_fn);
m_fn.reset();
Unlock(lock, fn);
}
const bool done = pred();
return done;
@ -276,7 +291,7 @@ struct Waiter
std::mutex m_mutex;
std::condition_variable m_cv;
std::function<void()> m_fn;
std::optional<kj::Function<void()>> m_fn;
};
//! Object holding network & rpc state associated with either an incoming server
@ -290,21 +305,13 @@ public:
Connection(EventLoop& loop, kj::Own<kj::AsyncIoStream>&& stream_)
: m_loop(loop), m_stream(kj::mv(stream_)),
m_network(*m_stream, ::capnp::rpc::twoparty::Side::CLIENT, ::capnp::ReaderOptions()),
m_rpc_system(::capnp::makeRpcClient(m_network))
{
std::unique_lock<std::mutex> lock(m_loop.m_mutex);
m_loop.addClient(lock);
}
m_rpc_system(::capnp::makeRpcClient(m_network)) {}
Connection(EventLoop& loop,
kj::Own<kj::AsyncIoStream>&& stream_,
const std::function<::capnp::Capability::Client(Connection&)>& make_client)
: m_loop(loop), m_stream(kj::mv(stream_)),
m_network(*m_stream, ::capnp::rpc::twoparty::Side::SERVER, ::capnp::ReaderOptions()),
m_rpc_system(::capnp::makeRpcServer(m_network, make_client(*this)))
{
std::unique_lock<std::mutex> lock(m_loop.m_mutex);
m_loop.addClient(lock);
}
m_rpc_system(::capnp::makeRpcServer(m_network, make_client(*this))) {}
//! Run cleanup functions. Must be called from the event loop thread. First
//! calls synchronous cleanup functions while blocked (to free capnp
@ -333,12 +340,12 @@ public:
// to the EventLoop TaskSet to avoid "Promise callback destroyed itself"
// error in cases where f deletes this Connection object.
m_on_disconnect.add(m_network.onDisconnect().then(
[f = std::move(f), this]() mutable { m_loop.m_task_set->add(kj::evalLater(kj::mv(f))); }));
[f = std::move(f), this]() mutable { m_loop->m_task_set->add(kj::evalLater(kj::mv(f))); }));
}
EventLoop& m_loop;
EventLoopRef m_loop;
kj::Own<kj::AsyncIoStream> m_stream;
LoggingErrorHandler m_error_handler{m_loop};
LoggingErrorHandler m_error_handler{*m_loop};
kj::TaskSet m_on_disconnect{m_error_handler};
::capnp::TwoPartyVatNetwork m_network;
std::optional<::capnp::RpcSystem<::capnp::rpc::twoparty::VatId>> m_rpc_system;
@ -381,21 +388,12 @@ ProxyClientBase<Interface, Impl>::ProxyClientBase(typename Interface::Client cli
: m_client(std::move(client)), m_context(connection)
{
{
std::unique_lock<std::mutex> lock(m_context.connection->m_loop.m_mutex);
m_context.connection->m_loop.addClient(lock);
}
// Handler for the connection getting destroyed before this client object.
auto cleanup_it = m_context.connection->addSyncCleanup([this]() {
// Release client capability by move-assigning to temporary.
{
typename Interface::Client(std::move(m_client));
}
{
std::unique_lock<std::mutex> lock(m_context.connection->m_loop.m_mutex);
m_context.connection->m_loop.removeClient(lock);
}
m_context.connection = nullptr;
});
@ -423,16 +421,11 @@ ProxyClientBase<Interface, Impl>::ProxyClientBase(typename Interface::Client cli
Sub::destroy(*this);
// FIXME: Could just invoke removed addCleanup fn here instead of duplicating code
m_context.connection->m_loop.sync([&]() {
m_context.loop->sync([&]() {
// Release client capability by move-assigning to temporary.
{
typename Interface::Client(std::move(m_client));
}
{
std::unique_lock<std::mutex> lock(m_context.connection->m_loop.m_mutex);
m_context.connection->m_loop.removeClient(lock);
}
if (destroy_connection) {
delete m_context.connection;
m_context.connection = nullptr;
@ -454,8 +447,6 @@ ProxyServerBase<Interface, Impl>::ProxyServerBase(std::shared_ptr<Impl> impl, Co
: m_impl(std::move(impl)), m_context(&connection)
{
assert(m_impl);
std::unique_lock<std::mutex> lock(m_context.connection->m_loop.m_mutex);
m_context.connection->m_loop.addClient(lock);
}
//! ProxyServer destructor, called from the EventLoop thread by Cap'n Proto
@ -489,8 +480,6 @@ ProxyServerBase<Interface, Impl>::~ProxyServerBase()
});
}
assert(m_context.cleanup_fns.empty());
std::unique_lock<std::mutex> lock(m_context.connection->m_loop.m_mutex);
m_context.connection->m_loop.removeClient(lock);
}
//! If the capnp interface defined a special "destroy" method, as described in the


@ -558,7 +558,7 @@ template <typename Client>
void clientDestroy(Client& client)
{
if (client.m_context.connection) {
client.m_context.connection->m_loop.log() << "IPC client destroy " << typeid(client).name();
client.m_context.loop->log() << "IPC client destroy " << typeid(client).name();
} else {
KJ_LOG(INFO, "IPC interrupted client destroy", typeid(client).name());
}
@ -567,7 +567,7 @@ void clientDestroy(Client& client)
template <typename Server>
void serverDestroy(Server& server)
{
server.m_context.connection->m_loop.log() << "IPC server destroy " << typeid(server).name();
server.m_context.loop->log() << "IPC server destroy " << typeid(server).name();
}
//! Entry point called by generated client code that looks like:
@ -582,12 +582,9 @@ void serverDestroy(Server& server)
template <typename ProxyClient, typename GetRequest, typename... FieldObjs>
void clientInvoke(ProxyClient& proxy_client, const GetRequest& get_request, FieldObjs&&... fields)
{
if (!proxy_client.m_context.connection) {
throw std::logic_error("clientInvoke call made after disconnect");
}
if (!g_thread_context.waiter) {
assert(g_thread_context.thread_name.empty());
g_thread_context.thread_name = ThreadName(proxy_client.m_context.connection->m_loop.m_exe_name);
g_thread_context.thread_name = ThreadName(proxy_client.m_context.loop->m_exe_name);
// If next assert triggers, it means clientInvoke is being called from
// the capnp event loop thread. This can happen when a ProxyServer
// method implementation that runs synchronously on the event loop
@ -598,7 +595,7 @@ void clientInvoke(ProxyClient& proxy_client, const GetRequest& get_request, Fiel
// declaration so the server method runs in a dedicated thread.
assert(!g_thread_context.loop_thread);
g_thread_context.waiter = std::make_unique<Waiter>();
proxy_client.m_context.connection->m_loop.logPlain()
proxy_client.m_context.loop->logPlain()
<< "{" << g_thread_context.thread_name
<< "} IPC client first request from current thread, constructing waiter";
}
@ -606,18 +603,27 @@ void clientInvoke(ProxyClient& proxy_client, const GetRequest& get_request, Fiel
std::exception_ptr exception;
std::string kj_exception;
bool done = false;
proxy_client.m_context.connection->m_loop.sync([&]() {
const char* disconnected = nullptr;
proxy_client.m_context.loop->sync([&]() {
if (!proxy_client.m_context.connection) {
const std::unique_lock<std::mutex> lock(invoke_context.thread_context.waiter->m_mutex);
done = true;
disconnected = "IPC client method called after disconnect.";
invoke_context.thread_context.waiter->m_cv.notify_all();
return;
}
auto request = (proxy_client.m_client.*get_request)(nullptr);
using Request = CapRequestTraits<decltype(request)>;
using FieldList = typename ProxyClientMethodTraits<typename Request::Params>::Fields;
IterateFields().handleChain(invoke_context, request, FieldList(), typename FieldObjs::BuildParams{&fields}...);
proxy_client.m_context.connection->m_loop.logPlain()
proxy_client.m_context.loop->logPlain()
<< "{" << invoke_context.thread_context.thread_name << "} IPC client send "
<< TypeName<typename Request::Params>() << " " << LogEscape(request.toString());
proxy_client.m_context.connection->m_loop.m_task_set->add(request.send().then(
proxy_client.m_context.loop->m_task_set->add(request.send().then(
[&](::capnp::Response<typename Request::Results>&& response) {
proxy_client.m_context.connection->m_loop.logPlain()
proxy_client.m_context.loop->logPlain()
<< "{" << invoke_context.thread_context.thread_name << "} IPC client recv "
<< TypeName<typename Request::Results>() << " " << LogEscape(response.toString());
try {
@ -631,9 +637,13 @@ void clientInvoke(ProxyClient& proxy_client, const GetRequest& get_request, Fiel
invoke_context.thread_context.waiter->m_cv.notify_all();
},
[&](const ::kj::Exception& e) {
kj_exception = kj::str("kj::Exception: ", e).cStr();
proxy_client.m_context.connection->m_loop.logPlain()
<< "{" << invoke_context.thread_context.thread_name << "} IPC client exception " << kj_exception;
if (e.getType() == ::kj::Exception::Type::DISCONNECTED) {
disconnected = "IPC client method call interrupted by disconnect.";
} else {
kj_exception = kj::str("kj::Exception: ", e).cStr();
proxy_client.m_context.loop->logPlain()
<< "{" << invoke_context.thread_context.thread_name << "} IPC client exception " << kj_exception;
}
const std::unique_lock<std::mutex> lock(invoke_context.thread_context.waiter->m_mutex);
done = true;
invoke_context.thread_context.waiter->m_cv.notify_all();
@ -643,7 +653,8 @@ void clientInvoke(ProxyClient& proxy_client, const GetRequest& get_request, Fiel
std::unique_lock<std::mutex> lock(invoke_context.thread_context.waiter->m_mutex);
invoke_context.thread_context.waiter->wait(lock, [&done]() { return done; });
if (exception) std::rethrow_exception(exception);
if (!kj_exception.empty()) proxy_client.m_context.connection->m_loop.raise() << kj_exception;
if (!kj_exception.empty()) proxy_client.m_context.loop->raise() << kj_exception;
if (disconnected) proxy_client.m_context.loop->raise() << disconnected;
}
//! Invoke callable `fn()` that may return void. If it does return void, replace
@ -682,7 +693,7 @@ kj::Promise<void> serverInvoke(Server& server, CallContext& call_context, Fn fn)
using Results = typename decltype(call_context.getResults())::Builds;
int req = ++server_reqs;
server.m_context.connection->m_loop.log() << "IPC server recv request #" << req << " "
server.m_context.loop->log() << "IPC server recv request #" << req << " "
<< TypeName<typename Params::Reads>() << " " << LogEscape(params.toString());
try {
@ -699,14 +710,14 @@ kj::Promise<void> serverInvoke(Server& server, CallContext& call_context, Fn fn)
return ReplaceVoid([&]() { return fn.invoke(server_context, ArgList()); },
[&]() { return kj::Promise<CallContext>(kj::mv(call_context)); })
.then([&server, req](CallContext call_context) {
server.m_context.connection->m_loop.log() << "IPC server send response #" << req << " " << TypeName<Results>()
server.m_context.loop->log() << "IPC server send response #" << req << " " << TypeName<Results>()
<< " " << LogEscape(call_context.getResults().toString());
});
} catch (const std::exception& e) {
server.m_context.connection->m_loop.log() << "IPC server unhandled exception: " << e.what();
server.m_context.loop->log() << "IPC server unhandled exception: " << e.what();
throw;
} catch (...) {
server.m_context.connection->m_loop.log() << "IPC server unhandled exception";
server.m_context.loop->log() << "IPC server unhandled exception";
throw;
}
}


@ -8,6 +8,7 @@
#include <mp/util.h>
#include <array>
#include <cassert>
#include <functional>
#include <list>
#include <stddef.h>
@ -47,13 +48,34 @@ inline void CleanupRun(CleanupList& fns) {
}
}
//! Event loop smart pointer that automatically increments and decrements
//! EventLoop::m_num_clients. If a lock pointer argument is passed, the
//! specified lock will be used; otherwise EventLoop::m_mutex will be locked
//! when needed.
class EventLoopRef
{
public:
explicit EventLoopRef(EventLoop& loop, Lock* lock = nullptr);
EventLoopRef(EventLoopRef&& other) noexcept : m_loop(other.m_loop) { other.m_loop = nullptr; }
EventLoopRef(const EventLoopRef&) = delete;
EventLoopRef& operator=(const EventLoopRef&) = delete;
EventLoopRef& operator=(EventLoopRef&&) = delete;
~EventLoopRef() { reset(); }
EventLoop& operator*() const { assert(m_loop); return *m_loop; }
EventLoop* operator->() const { assert(m_loop); return m_loop; }
bool reset(Lock* lock = nullptr);
EventLoop* m_loop{nullptr};
Lock* m_lock{nullptr};
};
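
As a rough illustration (not part of the diff), constructing an EventLoopRef is what keeps the loop alive, and releasing the last one is what lets EventLoop::loop() return:

// Illustrative sketch only; keepLoopAlive() is a hypothetical helper.
void keepLoopAlive(mp::EventLoop& loop)
{
    mp::EventLoopRef ref{loop}; // locks EventLoop::m_mutex, increments m_num_clients
    // ... use the loop; it cannot exit while `ref` is alive ...
}   // ~EventLoopRef() decrements m_num_clients and, if the loop is now done,
    // notifies the condition variable and writes to the post fd so loop()
    // can wake up and return.
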
//! Context data associated with proxy client and server classes.
struct ProxyContext
{
Connection* connection;
EventLoopRef loop;
CleanupList cleanup_fns;
ProxyContext(Connection* connection) : connection(connection) {}
ProxyContext(Connection* connection);
};
//! Base class for generated ProxyClient classes that implement a C++ interface


@ -64,8 +64,7 @@ auto PassField(Priority<1>, TypeList<>, ServerContext& server_context, const Fn&
auto future = kj::newPromiseAndFulfiller<typename ServerContext::CallContext>();
auto& server = server_context.proxy_server;
int req = server_context.req;
auto invoke = MakeAsyncCallable(
[fulfiller = kj::mv(future.fulfiller),
auto invoke = [fulfiller = kj::mv(future.fulfiller),
call_context = kj::mv(server_context.call_context), &server, req, fn, args...]() mutable {
const auto& params = call_context.getParams();
Context::Reader context_arg = Accessor::get(params);
@ -132,35 +131,35 @@ auto PassField(Priority<1>, TypeList<>, ServerContext& server_context, const Fn&
return;
}
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
server.m_context.connection->m_loop.sync([&] {
server.m_context.loop->sync([&] {
auto fulfiller_dispose = kj::mv(fulfiller);
fulfiller_dispose->fulfill(kj::mv(call_context));
});
}))
{
server.m_context.connection->m_loop.sync([&]() {
server.m_context.loop->sync([&]() {
auto fulfiller_dispose = kj::mv(fulfiller);
fulfiller_dispose->reject(kj::mv(*exception));
});
}
});
};
// Lookup Thread object specified by the client. The specified thread should
// be a local Thread::Server object, but it needs to be looked up
// asynchronously with getLocalServer().
auto thread_client = context_arg.getThread();
return server.m_context.connection->m_threads.getLocalServer(thread_client)
.then([&server, invoke, req](const kj::Maybe<Thread::Server&>& perhaps) {
.then([&server, invoke = kj::mv(invoke), req](const kj::Maybe<Thread::Server&>& perhaps) mutable {
// Assuming the thread object is found, pass it a pointer to the
// `invoke` lambda above which will invoke the function on that
// thread.
KJ_IF_MAYBE (thread_server, perhaps) {
const auto& thread = static_cast<ProxyServer<Thread>&>(*thread_server);
server.m_context.connection->m_loop.log()
server.m_context.loop->log()
<< "IPC server post request #" << req << " {" << thread.m_thread_context.thread_name << "}";
thread.m_thread_context.waiter->post(std::move(invoke));
} else {
server.m_context.connection->m_loop.log()
server.m_context.loop->log()
<< "IPC server error request #" << req << ", missing thread to execute request";
throw std::runtime_error("invalid thread handle");
}


@ -6,6 +6,7 @@
#define MP_UTIL_H
#include <capnp/schema.h>
#include <cassert>
#include <cstddef>
#include <functional>
#include <future>
@ -13,11 +14,13 @@
#include <kj/exception.h>
#include <kj/string-tree.h>
#include <memory>
#include <mutex>
#include <string.h>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>
namespace mp {
@ -130,6 +133,58 @@ const char* TypeName()
return short_name ? short_name + 1 : display_name;
}
//! Convenient wrapper around std::variant<T*, T>
template <typename T>
struct PtrOrValue {
std::variant<T*, T> data;
template <typename... Args>
PtrOrValue(T* ptr, Args&&... args) : data(ptr ? ptr : std::variant<T*, T>{std::in_place_type<T>, std::forward<Args>(args)...}) {}
T& operator*() { return data.index() ? std::get<T>(data) : *std::get<T*>(data); }
T* operator->() { return &**this; }
T& operator*() const { return data.index() ? std::get<T>(data) : *std::get<T*>(data); }
T* operator->() const { return &**this; }
};
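
A hedged usage sketch (not in the diff), mirroring how the EventLoopRef constructor added by this commit uses PtrOrValue to either borrow a caller-provided Lock or construct one in place; withLock() is a hypothetical helper.

// Illustrative sketch only.
void withLock(mp::Mutex& mutex, mp::Lock* existing_lock = nullptr)
{
    // Borrows existing_lock if it is non-null, otherwise constructs a
    // Lock(mutex) in place inside the variant.
    mp::PtrOrValue<mp::Lock> lock{existing_lock, mutex};
    lock->assert_locked(mutex); // same call whether borrowed or owned
}
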
// Annotated mutex and lock classes (https://clang.llvm.org/docs/ThreadSafetyAnalysis.html)
#if defined(__clang__) && (!defined(SWIG))
#define MP_TSA(x) __attribute__((x))
#else
#define MP_TSA(x) // no-op
#endif
#define MP_CAPABILITY(x) MP_TSA(capability(x))
#define MP_SCOPED_CAPABILITY MP_TSA(scoped_lockable)
#define MP_REQUIRES(x) MP_TSA(requires_capability(x))
#define MP_ACQUIRE(...) MP_TSA(acquire_capability(__VA_ARGS__))
#define MP_RELEASE(...) MP_TSA(release_capability(__VA_ARGS__))
#define MP_ASSERT_CAPABILITY(x) MP_TSA(assert_capability(x))
#define MP_GUARDED_BY(x) MP_TSA(guarded_by(x))
class MP_CAPABILITY("mutex") Mutex {
public:
void lock() MP_ACQUIRE() { m_mutex.lock(); }
void unlock() MP_RELEASE() { m_mutex.unlock(); }
std::mutex m_mutex;
};
class MP_SCOPED_CAPABILITY Lock {
public:
explicit Lock(Mutex& m) MP_ACQUIRE(m) : m_lock(m.m_mutex) {}
~Lock() MP_RELEASE() {}
void unlock() MP_RELEASE() { m_lock.unlock(); }
void lock() MP_ACQUIRE() { m_lock.lock(); }
void assert_locked(Mutex& mutex) MP_ASSERT_CAPABILITY() MP_ASSERT_CAPABILITY(mutex)
{
assert(m_lock.mutex() == &mutex.m_mutex);
assert(m_lock);
}
std::unique_lock<std::mutex> m_lock;
};
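
A small sketch (illustrative, not from this commit) of how the annotated Mutex, Lock, MP_GUARDED_BY, and MP_REQUIRES pieces fit together, in the same style as the EventLoop members changed above; Counter is a hypothetical class.

// Illustrative sketch only.
class Counter
{
public:
    void increment()
    {
        const mp::Lock lock(m_mutex); // scoped: acquired here, released at end of scope
        bump();                       // thread safety analysis verifies m_mutex is held
    }
private:
    void bump() MP_REQUIRES(m_mutex) { ++m_value; }
    mp::Mutex m_mutex;
    int m_value MP_GUARDED_BY(m_mutex) = 0;
};
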
//! Analog to std::lock_guard that unlocks instead of locks.
template <typename Lock>
struct UnlockGuard
@ -146,46 +201,6 @@ void Unlock(Lock& lock, Callback&& callback)
callback();
}
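
And a short sketch (again illustrative) of the Unlock helper, which is the pattern Waiter::wait and EventLoop::loop in this commit use to run a callback with the mutex temporarily released; drainQueue() is a hypothetical helper.

// Illustrative sketch only. The lock must be held on entry; Unlock()
// releases it while fn runs and re-acquires it before returning.
void drainQueue(mp::Lock& lock, std::list<std::function<void()>>& queue)
{
    while (!queue.empty()) {
        const std::function<void()> fn = std::move(queue.front());
        queue.pop_front();
        mp::Unlock(lock, fn);
    }
}
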
//! Needed for libc++/macOS compatibility. Lets code work with shared_ptr nothrow declaration
//! https://github.com/capnproto/capnproto/issues/553#issuecomment-328554603
template <typename T>
struct DestructorCatcher
{
T value;
template <typename... Params>
DestructorCatcher(Params&&... params) : value(kj::fwd<Params>(params)...)
{
}
~DestructorCatcher() noexcept try {
} catch (const kj::Exception& e) { // NOLINT(bugprone-empty-catch)
}
};
//! Wrapper around callback function for compatibility with std::async.
//!
//! std::async requires callbacks to be copyable and requires noexcept
//! destructors, but this doesn't work well with kj types which are generally
//! move-only and not noexcept.
template <typename Callable>
struct AsyncCallable
{
AsyncCallable(Callable&& callable) : m_callable(std::make_shared<DestructorCatcher<Callable>>(std::move(callable)))
{
}
AsyncCallable(const AsyncCallable&) = default;
AsyncCallable(AsyncCallable&&) = default;
~AsyncCallable() noexcept = default;
ResultOf<Callable> operator()() const { return (m_callable->value)(); }
mutable std::shared_ptr<DestructorCatcher<Callable>> m_callable;
};
//! Construct AsyncCallable object.
template <typename Callable>
AsyncCallable<std::remove_reference_t<Callable>> MakeAsyncCallable(Callable&& callable)
{
return std::move(callable);
}
//! Format current thread name as "{exe_name}-{$pid}/{thread_name}-{$tid}".
std::string ThreadName(const char* exe_name);


@ -22,6 +22,7 @@
#include <kj/common.h>
#include <kj/debug.h>
#include <kj/exception.h>
#include <kj/function.h>
#include <kj/memory.h>
#include <map>
#include <memory>
@ -48,6 +49,36 @@ void LoggingErrorHandler::taskFailed(kj::Exception&& exception)
m_loop.log() << "Uncaught exception in daemonized task.";
}
EventLoopRef::EventLoopRef(EventLoop& loop, Lock* lock) : m_loop(&loop), m_lock(lock)
{
auto loop_lock{PtrOrValue{m_lock, m_loop->m_mutex}};
loop_lock->assert_locked(m_loop->m_mutex);
m_loop->m_num_clients += 1;
}
bool EventLoopRef::reset(Lock* lock)
{
bool done = false;
if (m_loop) {
auto loop_lock{PtrOrValue{lock ? lock : m_lock, m_loop->m_mutex}};
loop_lock->assert_locked(m_loop->m_mutex);
assert(m_loop->m_num_clients > 0);
m_loop->m_num_clients -= 1;
if (m_loop->done()) {
done = true;
m_loop->m_cv.notify_all();
int post_fd{m_loop->m_post_fd};
loop_lock->unlock();
char buffer = 0;
KJ_SYSCALL(write(post_fd, &buffer, 1)); // NOLINT(bugprone-suspicious-semicolon)
}
m_loop = nullptr;
}
return done;
}
ProxyContext::ProxyContext(Connection* connection) : connection(connection), loop{*connection->m_loop} {}
Connection::~Connection()
{
// Shut down RPC system first, since this will garbage collect Server
@ -103,18 +134,18 @@ Connection::~Connection()
m_sync_cleanup_fns.pop_front();
}
while (!m_async_cleanup_fns.empty()) {
const std::unique_lock<std::mutex> lock(m_loop.m_mutex);
m_loop.m_async_fns.emplace_back(std::move(m_async_cleanup_fns.front()));
const Lock lock(m_loop->m_mutex);
m_loop->m_async_fns.emplace_back(std::move(m_async_cleanup_fns.front()));
m_async_cleanup_fns.pop_front();
}
std::unique_lock<std::mutex> lock(m_loop.m_mutex);
m_loop.startAsyncThread(lock);
m_loop.removeClient(lock);
Lock lock(m_loop->m_mutex);
m_loop->startAsyncThread();
m_loop.reset(&lock);
}
CleanupIt Connection::addSyncCleanup(std::function<void()> fn)
{
const std::unique_lock<std::mutex> lock(m_loop.m_mutex);
const Lock lock(m_loop->m_mutex);
// Add cleanup callbacks to the front of list, so sync cleanup functions run
// in LIFO order. This is a good approach because sync cleanup functions are
// added as client objects are created, and it is natural to clean up
@ -128,13 +159,13 @@ CleanupIt Connection::addSyncCleanup(std::function<void()> fn)
void Connection::removeSyncCleanup(CleanupIt it)
{
const std::unique_lock<std::mutex> lock(m_loop.m_mutex);
const Lock lock(m_loop->m_mutex);
m_sync_cleanup_fns.erase(it);
}
void Connection::addAsyncCleanup(std::function<void()> fn)
{
const std::unique_lock<std::mutex> lock(m_loop.m_mutex);
const Lock lock(m_loop->m_mutex);
// Add async cleanup callbacks to the back of the list. Unlike the sync
// cleanup list, this list order is more significant because it determines
// the order server objects are destroyed when there is a sudden disconnect,
@ -170,7 +201,7 @@ EventLoop::EventLoop(const char* exe_name, LogFn log_fn, void* context)
EventLoop::~EventLoop()
{
if (m_async_thread.joinable()) m_async_thread.join();
const std::lock_guard<std::mutex> lock(m_mutex);
const Lock lock(m_mutex);
KJ_ASSERT(m_post_fn == nullptr);
KJ_ASSERT(m_async_fns.empty());
KJ_ASSERT(m_wait_fd == -1);
@ -195,14 +226,14 @@ void EventLoop::loop()
for (;;) {
const size_t read_bytes = wait_stream->read(&buffer, 0, 1).wait(m_io_context.waitScope);
if (read_bytes != 1) throw std::logic_error("EventLoop wait_stream closed unexpectedly");
std::unique_lock<std::mutex> lock(m_mutex);
Lock lock(m_mutex);
if (m_post_fn) {
Unlock(lock, *m_post_fn);
m_post_fn = nullptr;
m_cv.notify_all();
} else if (done(lock)) {
} else if (done()) {
// Intentionally do not break if m_post_fn was set, even if done()
// would return true, to ensure that the removeClient write(post_fd)
// would return true, to ensure that the EventLoopRef write(post_fd)
// call always succeeds and the loop does not exit between the time
// that the done condition is set and the write call is made.
break;
@ -213,75 +244,62 @@ void EventLoop::loop()
log() << "EventLoop::loop bye.";
wait_stream = nullptr;
KJ_SYSCALL(::close(post_fd));
const std::unique_lock<std::mutex> lock(m_mutex);
const Lock lock(m_mutex);
m_wait_fd = -1;
m_post_fd = -1;
}
void EventLoop::post(const std::function<void()>& fn)
void EventLoop::post(kj::Function<void()> fn)
{
if (std::this_thread::get_id() == m_thread_id) {
fn();
return;
}
std::unique_lock<std::mutex> lock(m_mutex);
addClient(lock);
m_cv.wait(lock, [this] { return m_post_fn == nullptr; });
Lock lock(m_mutex);
EventLoopRef ref(*this, &lock);
m_cv.wait(lock.m_lock, [this]() MP_REQUIRES(m_mutex) { return m_post_fn == nullptr; });
m_post_fn = &fn;
int post_fd{m_post_fd};
Unlock(lock, [&] {
char buffer = 0;
KJ_SYSCALL(write(post_fd, &buffer, 1));
});
m_cv.wait(lock, [this, &fn] { return m_post_fn != &fn; });
removeClient(lock);
m_cv.wait(lock.m_lock, [this, &fn]() MP_REQUIRES(m_mutex) { return m_post_fn != &fn; });
}
void EventLoop::addClient(std::unique_lock<std::mutex>& lock) { m_num_clients += 1; }
bool EventLoop::removeClient(std::unique_lock<std::mutex>& lock)
{
m_num_clients -= 1;
if (done(lock)) {
m_cv.notify_all();
int post_fd{m_post_fd};
lock.unlock();
char buffer = 0;
KJ_SYSCALL(write(post_fd, &buffer, 1)); // NOLINT(bugprone-suspicious-semicolon)
return true;
}
return false;
}
void EventLoop::startAsyncThread(std::unique_lock<std::mutex>& lock)
void EventLoop::startAsyncThread()
{
if (m_async_thread.joinable()) {
m_cv.notify_all();
} else if (!m_async_fns.empty()) {
m_async_thread = std::thread([this] {
std::unique_lock<std::mutex> lock(m_mutex);
while (true) {
Lock lock(m_mutex);
while (!done()) {
if (!m_async_fns.empty()) {
addClient(lock);
EventLoopRef ref{*this, &lock};
const std::function<void()> fn = std::move(m_async_fns.front());
m_async_fns.pop_front();
Unlock(lock, fn);
if (removeClient(lock)) break;
// Important to explicitly call ref.reset() here and
// explicitly break if the EventLoop is done, rather than
// relying on the while condition above. The reason is that
// the end of `ref`'s lifetime can cause EventLoop::loop() to
// exit, and if there is external code that immediately deletes
// the EventLoop object as soon as EventLoop::loop() returns,
// checking the while condition may crash.
if (ref.reset()) break;
// Continue without waiting in case there are more async_fns
continue;
} else if (m_num_clients == 0) {
break;
}
m_cv.wait(lock);
m_cv.wait(lock.m_lock);
}
});
}
}
bool EventLoop::done(std::unique_lock<std::mutex>& lock)
bool EventLoop::done()
{
assert(m_num_clients >= 0);
assert(lock.owns_lock());
assert(lock.mutex() == &m_mutex);
return m_num_clients == 0 && m_async_fns.empty();
}
@ -375,7 +393,7 @@ kj::Promise<void> ProxyServer<ThreadMap>::makeThread(MakeThreadContext context)
const std::string from = context.getParams().getName();
std::promise<ThreadContext*> thread_context;
std::thread thread([&thread_context, from, this]() {
g_thread_context.thread_name = ThreadName(m_connection.m_loop.m_exe_name) + " (from " + from + ")";
g_thread_context.thread_name = ThreadName(m_connection.m_loop->m_exe_name) + " (from " + from + ")";
g_thread_context.waiter = std::make_unique<Waiter>();
thread_context.set_value(&g_thread_context);
std::unique_lock<std::mutex> lock(g_thread_context.waiter->m_mutex);


@ -23,32 +23,82 @@
namespace mp {
namespace test {
/**
* Test setup class creating a two way connection between a
* ProxyServer<FooInterface> object and a ProxyClient<FooInterface>.
*
* Provides client_disconnect and server_disconnect lambdas that can be used to
* trigger disconnects and test handling of broken and closed connections.
*
* Accepts a client_owns_connection option to test different ProxyClient
* destroy_connection values and control whether destroying the ProxyClient
* object destroys the client Connection object. Normally it makes sense for
* this to be true to simplify shutdown and avoid needing to call
* client_disconnect manually, but false allows testing more ProxyClient
* behavior and the "IPC client method called after disconnect" code path.
*/
class TestSetup
{
public:
std::thread thread;
std::function<void()> server_disconnect;
std::function<void()> client_disconnect;
std::promise<std::unique_ptr<ProxyClient<messages::FooInterface>>> client_promise;
std::unique_ptr<ProxyClient<messages::FooInterface>> client;
TestSetup(bool client_owns_connection = true)
: thread{[&] {
EventLoop loop("mptest", [](bool raise, const std::string& log) {
std::cout << "LOG" << raise << ": " << log << "\n";
if (raise) throw std::runtime_error(log);
});
auto pipe = loop.m_io_context.provider->newTwoWayPipe();
auto server_connection =
std::make_unique<Connection>(loop, kj::mv(pipe.ends[0]), [&](Connection& connection) {
auto server_proxy = kj::heap<ProxyServer<messages::FooInterface>>(
std::make_shared<FooImplementation>(), connection);
return capnp::Capability::Client(kj::mv(server_proxy));
});
server_disconnect = [&] { loop.sync([&] { server_connection.reset(); }); };
// Set handler to destroy the server when the client disconnects. This
// is ignored if server_disconnect() is called instead.
server_connection->onDisconnect([&] { server_connection.reset(); });
auto client_connection = std::make_unique<Connection>(loop, kj::mv(pipe.ends[1]));
auto client_proxy = std::make_unique<ProxyClient<messages::FooInterface>>(
client_connection->m_rpc_system->bootstrap(ServerVatId().vat_id).castAs<messages::FooInterface>(),
client_connection.get(), /* destroy_connection= */ client_owns_connection);
if (client_owns_connection) {
client_connection.release();
} else {
client_disconnect = [&] { loop.sync([&] { client_connection.reset(); }); };
}
client_promise.set_value(std::move(client_proxy));
loop.loop();
}}
{
client = client_promise.get_future().get();
}
~TestSetup()
{
// Test that client cleanup_fns are executed.
bool destroyed = false;
client->m_context.cleanup_fns.emplace_front([&destroyed] { destroyed = true; });
client.reset();
KJ_EXPECT(destroyed);
thread.join();
}
};
KJ_TEST("Call FooInterface methods")
{
std::promise<std::unique_ptr<ProxyClient<messages::FooInterface>>> foo_promise;
std::function<void()> disconnect_client;
std::thread thread([&]() {
EventLoop loop("mptest", [](bool raise, const std::string& log) {
std::cout << "LOG" << raise << ": " << log << "\n";
});
auto pipe = loop.m_io_context.provider->newTwoWayPipe();
TestSetup setup;
ProxyClient<messages::FooInterface>* foo = setup.client.get();
auto connection_client = std::make_unique<Connection>(loop, kj::mv(pipe.ends[0]));
auto foo_client = std::make_unique<ProxyClient<messages::FooInterface>>(
connection_client->m_rpc_system->bootstrap(ServerVatId().vat_id).castAs<messages::FooInterface>(),
connection_client.get(), /* destroy_connection= */ false);
foo_promise.set_value(std::move(foo_client));
disconnect_client = [&] { loop.sync([&] { connection_client.reset(); }); };
auto connection_server = std::make_unique<Connection>(loop, kj::mv(pipe.ends[1]), [&](Connection& connection) {
auto foo_server = kj::heap<ProxyServer<messages::FooInterface>>(std::make_shared<FooImplementation>(), connection);
return capnp::Capability::Client(kj::mv(foo_server));
});
connection_server->onDisconnect([&] { connection_server.reset(); });
loop.loop();
});
auto foo = foo_promise.get_future().get();
KJ_EXPECT(foo->add(1, 2) == 3);
FooStruct in;
@ -127,14 +177,40 @@ KJ_TEST("Call FooInterface methods")
mut.message = "init";
foo->passMutable(mut);
KJ_EXPECT(mut.message == "init build pass call return read");
}
disconnect_client();
thread.join();
KJ_TEST("Call IPC method after client connection is closed")
{
TestSetup setup{/*client_owns_connection=*/false};
ProxyClient<messages::FooInterface>* foo = setup.client.get();
KJ_EXPECT(foo->add(1, 2) == 3);
setup.client_disconnect();
bool destroyed = false;
foo->m_context.cleanup_fns.emplace_front([&destroyed]{ destroyed = true; });
foo.reset();
KJ_EXPECT(destroyed);
bool disconnected{false};
try {
foo->add(1, 2);
} catch (const std::runtime_error& e) {
KJ_EXPECT(std::string_view{e.what()} == "IPC client method called after disconnect.");
disconnected = true;
}
KJ_EXPECT(disconnected);
}
KJ_TEST("Calling IPC method after server connection is closed")
{
TestSetup setup;
ProxyClient<messages::FooInterface>* foo = setup.client.get();
KJ_EXPECT(foo->add(1, 2) == 3);
setup.server_disconnect();
bool disconnected{false};
try {
foo->add(1, 2);
} catch (const std::runtime_error& e) {
KJ_EXPECT(std::string_view{e.what()} == "IPC client method call interrupted by disconnect.");
disconnected = true;
}
KJ_EXPECT(disconnected);
}
} // namespace test