diff --git a/include/mp/proxy-io.h b/include/mp/proxy-io.h
index 4eb27fae7c4..ee1ecaa9486 100644
--- a/include/mp/proxy-io.h
+++ b/include/mp/proxy-io.h
@@ -14,9 +14,10 @@
 #include <functional>
 #include <kj/async-io.h>
-#include <kj/async.h>
+#include <kj/function.h>
 #include <map>
 #include <memory>
+#include <optional>
 #include <sstream>
 #include <string>
@@ -129,6 +130,25 @@ std::string LongThreadName(const char* exe_name);
 
 //! Event loop implementation.
 //!
+//! The Cap'n Proto threading model is very simple: all I/O operations are
+//! asynchronous and must be performed on a single thread. This includes:
+//!
+//! - Code starting an asynchronous operation (calling a function that returns a
+//!   promise object)
+//! - Code notifying that an asynchronous operation is complete (code using a
+//!   fulfiller object)
+//! - Code handling a completed operation (code chaining or waiting for a promise)
+//!
+//! All this code needs to run on one thread, and the EventLoop::loop() method
+//! is the entry point for this thread. ProxyClient and ProxyServer objects that
+//! use other threads and need to perform I/O operations post to this thread
+//! using EventLoop::post() and EventLoop::sync() methods.
+//!
+//! Specifically, because ProxyClient methods can be called from arbitrary
+//! threads, and ProxyServer methods can run on arbitrary threads, ProxyClient
+//! methods use the EventLoop thread to send requests, and ProxyServer methods
+//! use the thread to return results.
+//!
 //! Based on https://groups.google.com/d/msg/capnproto/TuQFF1eH2-M/g81sHaTAAQAJ
 class EventLoop
 {
@@ -144,7 +164,7 @@ public:
 
     //! Run function on event loop thread. Does not return until function completes.
     //! Must be called while the loop() function is active.
-    void post(const std::function<void()>& fn);
+    void post(kj::Function<void()> fn);
 
     //! Wrapper around EventLoop::post that takes advantage of the
     //! fact that callable will not go out of scope to avoid requirement that it
@@ -152,7 +172,7 @@ public:
     template <typename Callable>
     void sync(Callable&& callable)
     {
-        post(std::ref(callable));
+        post(std::forward<Callable>(callable));
     }
 
     //! Start asynchronous worker thread if necessary. This is only done if
@@ -166,13 +186,10 @@ public:
     //! is important that ProxyServer::m_impl destructors do not run on the
     //! eventloop thread because they may need it to do I/O if they perform
     //! other IPC calls.
-    void startAsyncThread(std::unique_lock<std::mutex>& lock);
+    void startAsyncThread() MP_REQUIRES(m_mutex);
 
-    //! Add/remove remote client reference counts.
-    void addClient(std::unique_lock<std::mutex>& lock);
-    bool removeClient(std::unique_lock<std::mutex>& lock);
     //! Check if loop should exit.
-    bool done(std::unique_lock<std::mutex>& lock);
+    bool done() MP_REQUIRES(m_mutex);
 
     Logger log()
     {
@@ -195,10 +212,10 @@ public:
     std::thread m_async_thread;
 
     //! Callback function to run on event loop thread during post() or sync() call.
-    const std::function<void()>* m_post_fn = nullptr;
+    kj::Function<void()>* m_post_fn MP_GUARDED_BY(m_mutex) = nullptr;
 
     //! Callback functions to run on async thread.
-    CleanupList m_async_fns;
+    CleanupList m_async_fns MP_GUARDED_BY(m_mutex);
 
     //! Pipe read handle used to wake up the event loop thread.
     int m_wait_fd = -1;
@@ -208,11 +225,11 @@ public:
 
     //! Number of clients holding references to ProxyServerBase objects that
     //! reference this event loop.
-    int m_num_clients = 0;
+    int m_num_clients MP_GUARDED_BY(m_mutex) = 0;
 
     //! Mutex and condition variable used to post tasks to event loop and async
     //! thread.
-    std::mutex m_mutex;
+    Mutex m_mutex;
     std::condition_variable m_cv;
 
     //! Capnp IO context.
@@ -263,11 +280,9 @@ struct Waiter
             // in the case where a capnp response is sent and a brand new
             // request is immediately received.
             while (m_fn) {
-                auto fn = std::move(m_fn);
-                m_fn = nullptr;
-                lock.unlock();
-                fn();
-                lock.lock();
+                auto fn = std::move(*m_fn);
+                m_fn.reset();
+                Unlock(lock, fn);
             }
             const bool done = pred();
             return done;
@@ -276,7 +291,7 @@ struct Waiter
 
     std::mutex m_mutex;
     std::condition_variable m_cv;
-    std::function<void()> m_fn;
+    std::optional<kj::Function<void()>> m_fn;
 };
 
 //! Object holding network & rpc state associated with either an incoming server
@@ -290,21 +305,13 @@ public:
     Connection(EventLoop& loop, kj::Own<kj::AsyncIoStream>&& stream_)
         : m_loop(loop), m_stream(kj::mv(stream_)),
           m_network(*m_stream, ::capnp::rpc::twoparty::Side::CLIENT, ::capnp::ReaderOptions()),
-          m_rpc_system(::capnp::makeRpcClient(m_network))
-    {
-        std::unique_lock<std::mutex> lock(m_loop.m_mutex);
-        m_loop.addClient(lock);
-    }
+          m_rpc_system(::capnp::makeRpcClient(m_network)) {}
     Connection(EventLoop& loop, kj::Own<kj::AsyncIoStream>&& stream_,
         const std::function<::capnp::Capability::Client(Connection&)>& make_client)
         : m_loop(loop), m_stream(kj::mv(stream_)),
           m_network(*m_stream, ::capnp::rpc::twoparty::Side::SERVER, ::capnp::ReaderOptions()),
-          m_rpc_system(::capnp::makeRpcServer(m_network, make_client(*this)))
-    {
-        std::unique_lock<std::mutex> lock(m_loop.m_mutex);
-        m_loop.addClient(lock);
-    }
+          m_rpc_system(::capnp::makeRpcServer(m_network, make_client(*this))) {}
 
     //! Run cleanup functions. Must be called from the event loop thread. First
     //! calls synchronous cleanup functions while blocked (to free capnp
@@ -333,12 +340,12 @@ public:
         // to the EventLoop TaskSet to avoid "Promise callback destroyed itself"
        // error in cases where f deletes this Connection object.
         m_on_disconnect.add(m_network.onDisconnect().then(
-            [f = std::move(f), this]() mutable { m_loop.m_task_set->add(kj::evalLater(kj::mv(f))); }));
+            [f = std::move(f), this]() mutable { m_loop->m_task_set->add(kj::evalLater(kj::mv(f))); }));
     }
 
-    EventLoop& m_loop;
+    EventLoopRef m_loop;
     kj::Own<kj::AsyncIoStream> m_stream;
-    LoggingErrorHandler m_error_handler{m_loop};
+    LoggingErrorHandler m_error_handler{*m_loop};
     kj::TaskSet m_on_disconnect{m_error_handler};
     ::capnp::TwoPartyVatNetwork m_network;
     std::optional<::capnp::RpcSystem<::capnp::rpc::twoparty::VatId>> m_rpc_system;
@@ -381,21 +388,12 @@ ProxyClientBase<Interface, Impl>::ProxyClientBase(typename Interface::Client cli
     : m_client(std::move(client)), m_context(connection)
 {
-    {
-        std::unique_lock<std::mutex> lock(m_context.connection->m_loop.m_mutex);
-        m_context.connection->m_loop.addClient(lock);
-    }
-
     // Handler for the connection getting destroyed before this client object.
     auto cleanup_it = m_context.connection->addSyncCleanup([this]() {
         // Release client capability by move-assigning to temporary.
         { typename Interface::Client(std::move(m_client)); }
-        {
-            std::unique_lock<std::mutex> lock(m_context.connection->m_loop.m_mutex);
-            m_context.connection->m_loop.removeClient(lock);
-        }
 
         m_context.connection = nullptr;
     });
@@ -423,16 +421,11 @@ ProxyClientBase<Interface, Impl>::ProxyClientBase(typename Interface::Client cli
     Sub::destroy(*this);
 
     // FIXME: Could just invoke removed addCleanup fn here instead of duplicating code
-    m_context.connection->m_loop.sync([&]() {
+    m_context.loop->sync([&]() {
         // Release client capability by move-assigning to temporary.
         { typename Interface::Client(std::move(m_client)); }
-        {
-            std::unique_lock<std::mutex> lock(m_context.connection->m_loop.m_mutex);
-            m_context.connection->m_loop.removeClient(lock);
-        }
-
         if (destroy_connection) {
             delete m_context.connection;
             m_context.connection = nullptr;
@@ -454,8 +447,6 @@ ProxyServerBase<Interface, Impl>::ProxyServerBase(std::shared_ptr<Impl> impl, Co
     : m_impl(std::move(impl)), m_context(&connection)
 {
     assert(m_impl);
-    std::unique_lock<std::mutex> lock(m_context.connection->m_loop.m_mutex);
-    m_context.connection->m_loop.addClient(lock);
 }
 
 //! ProxyServer destructor, called from the EventLoop thread by Cap'n Proto
@@ -489,8 +480,6 @@ ProxyServerBase<Interface, Impl>::~ProxyServerBase()
         });
     }
     assert(m_context.cleanup_fns.empty());
-    std::unique_lock<std::mutex> lock(m_context.connection->m_loop.m_mutex);
-    m_context.connection->m_loop.removeClient(lock);
 }
 
 //! If the capnp interface defined a special "destroy" method, as described the
diff --git a/include/mp/proxy-types.h b/include/mp/proxy-types.h
index 1a519efde02..0caeb60d937 100644
--- a/include/mp/proxy-types.h
+++ b/include/mp/proxy-types.h
@@ -558,7 +558,7 @@ template <typename Client>
 void clientDestroy(Client& client)
 {
     if (client.m_context.connection) {
-        client.m_context.connection->m_loop.log() << "IPC client destroy " << typeid(client).name();
+        client.m_context.loop->log() << "IPC client destroy " << typeid(client).name();
     } else {
         KJ_LOG(INFO, "IPC interrupted client destroy", typeid(client).name());
     }
@@ -567,7 +567,7 @@ template <typename Server>
 void serverDestroy(Server& server)
 {
-    server.m_context.connection->m_loop.log() << "IPC server destroy " << typeid(server).name();
+    server.m_context.loop->log() << "IPC server destroy " << typeid(server).name();
 }
 
 //! Entry point called by generated client code that looks like:
@@ -582,12 +582,9 @@ void serverDestroy(Server& server)
 template <typename ProxyClient, typename GetRequest, typename... FieldObjs>
 void clientInvoke(ProxyClient& proxy_client, const GetRequest& get_request, FieldObjs&&... fields)
 {
-    if (!proxy_client.m_context.connection) {
-        throw std::logic_error("clientInvoke call made after disconnect");
-    }
     if (!g_thread_context.waiter) {
         assert(g_thread_context.thread_name.empty());
-        g_thread_context.thread_name = ThreadName(proxy_client.m_context.connection->m_loop.m_exe_name);
+        g_thread_context.thread_name = ThreadName(proxy_client.m_context.loop->m_exe_name);
         // If next assert triggers, it means clientInvoke is being called from
         // the capnp event loop thread. This can happen when a ProxyServer
         // method implementation that runs synchronously on the event loop
@@ -598,7 +595,7 @@ void clientInvoke(ProxyClient& proxy_client, const GetRequest& get_request, Fiel
         // declaration so the server method runs in a dedicated thread.
         assert(!g_thread_context.loop_thread);
         g_thread_context.waiter = std::make_unique<Waiter>();
-        proxy_client.m_context.connection->m_loop.logPlain()
+        proxy_client.m_context.loop->logPlain()
             << "{" << g_thread_context.thread_name
             << "} IPC client first request from current thread, constructing waiter";
     }
@@ -606,18 +603,27 @@ void clientInvoke(ProxyClient& proxy_client, const GetRequest& get_request, Fiel
     std::exception_ptr exception;
     std::string kj_exception;
     bool done = false;
-    proxy_client.m_context.connection->m_loop.sync([&]() {
+    const char* disconnected = nullptr;
+    proxy_client.m_context.loop->sync([&]() {
+        if (!proxy_client.m_context.connection) {
+            const std::unique_lock<std::mutex> lock(invoke_context.thread_context.waiter->m_mutex);
+            done = true;
+            disconnected = "IPC client method called after disconnect.";
+            invoke_context.thread_context.waiter->m_cv.notify_all();
+            return;
+        }
+
         auto request = (proxy_client.m_client.*get_request)(nullptr);
         using Request = CapRequestTraits<decltype(request)>;
         using FieldList = typename ProxyClientMethodTraits<typename Request::Params>::Fields;
         IterateFields().handleChain(invoke_context, request, FieldList(), typename FieldObjs::BuildParams{&fields}...);
-        proxy_client.m_context.connection->m_loop.logPlain()
+        proxy_client.m_context.loop->logPlain()
             << "{" << invoke_context.thread_context.thread_name << "} IPC client send "
             << TypeName<typename Request::Params>() << " " << LogEscape(request.toString());
-        proxy_client.m_context.connection->m_loop.m_task_set->add(request.send().then(
+        proxy_client.m_context.loop->m_task_set->add(request.send().then(
             [&](::capnp::Response<typename Request::Results>&& response) {
-                proxy_client.m_context.connection->m_loop.logPlain()
+                proxy_client.m_context.loop->logPlain()
                     << "{" << invoke_context.thread_context.thread_name << "} IPC client recv "
                     << TypeName<typename Request::Results>() << " " << LogEscape(response.toString());
                 try {
@@ -631,9 +637,13 @@ void clientInvoke(ProxyClient& proxy_client, const GetRequest& get_request, Fiel
                 invoke_context.thread_context.waiter->m_cv.notify_all();
             },
             [&](const ::kj::Exception& e) {
-                kj_exception = kj::str("kj::Exception: ", e).cStr();
-                proxy_client.m_context.connection->m_loop.logPlain()
-                    << "{" << invoke_context.thread_context.thread_name << "} IPC client exception " << kj_exception;
+                if (e.getType() == ::kj::Exception::Type::DISCONNECTED) {
+                    disconnected = "IPC client method call interrupted by disconnect.";
+                } else {
+                    kj_exception = kj::str("kj::Exception: ", e).cStr();
+                    proxy_client.m_context.loop->logPlain()
+                        << "{" << invoke_context.thread_context.thread_name << "} IPC client exception " << kj_exception;
+                }
                 const std::unique_lock<std::mutex> lock(invoke_context.thread_context.waiter->m_mutex);
                 done = true;
                 invoke_context.thread_context.waiter->m_cv.notify_all();
@@ -643,7 +653,8 @@ void clientInvoke(ProxyClient& proxy_client, const GetRequest& get_request, Fiel
     std::unique_lock<std::mutex> lock(invoke_context.thread_context.waiter->m_mutex);
     invoke_context.thread_context.waiter->wait(lock, [&done]() { return done; });
     if (exception) std::rethrow_exception(exception);
-    if (!kj_exception.empty()) proxy_client.m_context.connection->m_loop.raise() << kj_exception;
+    if (!kj_exception.empty()) proxy_client.m_context.loop->raise() << kj_exception;
+    if (disconnected) proxy_client.m_context.loop->raise() << disconnected;
 }
 
 //! Invoke callable `fn()` that may return void. If it does return void, replace
@@ -682,7 +693,7 @@ kj::Promise<void> serverInvoke(Server& server, CallContext& call_context, Fn fn)
     using Results = typename decltype(call_context.getResults())::Builds;
 
     int req = ++server_reqs;
-    server.m_context.connection->m_loop.log() << "IPC server recv request #" << req << " "
+    server.m_context.loop->log() << "IPC server recv request #" << req << " "
         << TypeName<typename Params::Reads>() << " " << LogEscape(params.toString());
 
     try {
@@ -699,14 +710,14 @@ kj::Promise<void> serverInvoke(Server& server, CallContext& call_context, Fn fn)
         return ReplaceVoid([&]() { return fn.invoke(server_context, ArgList()); },
             [&]() { return kj::Promise<CallContext>(kj::mv(call_context)); })
             .then([&server, req](CallContext call_context) {
-                server.m_context.connection->m_loop.log() << "IPC server send response #" << req << " " << TypeName<Results>()
+                server.m_context.loop->log() << "IPC server send response #" << req << " " << TypeName<Results>()
                     << " " << LogEscape(call_context.getResults().toString());
             });
     } catch (const std::exception& e) {
-        server.m_context.connection->m_loop.log() << "IPC server unhandled exception: " << e.what();
+        server.m_context.loop->log() << "IPC server unhandled exception: " << e.what();
         throw;
     } catch (...) {
-        server.m_context.connection->m_loop.log() << "IPC server unhandled exception";
+        server.m_context.loop->log() << "IPC server unhandled exception";
         throw;
     }
 }
diff --git a/include/mp/proxy.h b/include/mp/proxy.h
index 76be0992109..34a2540d8ea 100644
--- a/include/mp/proxy.h
+++ b/include/mp/proxy.h
@@ -8,6 +8,7 @@
 #include <mp/util.h>
 
 #include <array>
+#include <cassert>
 #include <functional>
 #include <list>
 #include <memory>
@@ -47,13 +48,34 @@ inline void CleanupRun(CleanupList& fns) {
     }
 }
 
+//! Event loop smart pointer automatically managing m_num_clients.
+//! If a lock pointer argument is passed, the specified lock will be used;
+//! otherwise EventLoop::m_mutex will be locked when needed.
+class EventLoopRef
+{
+public:
+    explicit EventLoopRef(EventLoop& loop, Lock* lock = nullptr);
+    EventLoopRef(EventLoopRef&& other) noexcept : m_loop(other.m_loop) { other.m_loop = nullptr; }
+    EventLoopRef(const EventLoopRef&) = delete;
+    EventLoopRef& operator=(const EventLoopRef&) = delete;
+    EventLoopRef& operator=(EventLoopRef&&) = delete;
+    ~EventLoopRef() { reset(); }
+    EventLoop& operator*() const { assert(m_loop); return *m_loop; }
+    EventLoop* operator->() const { assert(m_loop); return m_loop; }
+    bool reset(Lock* lock = nullptr);
+
+    EventLoop* m_loop{nullptr};
+    Lock* m_lock{nullptr};
+};
+
 //! Context data associated with proxy client and server classes.
 struct ProxyContext
 {
     Connection* connection;
+    EventLoopRef loop;
     CleanupList cleanup_fns;
 
-    ProxyContext(Connection* connection) : connection(connection) {}
+    ProxyContext(Connection* connection);
 };
 
 //! Base class for generated ProxyClient classes that implement a C++ interface
diff --git a/include/mp/type-context.h b/include/mp/type-context.h
index 7c12afe2ff0..ae4a2d89b29 100644
--- a/include/mp/type-context.h
+++ b/include/mp/type-context.h
@@ -64,8 +64,7 @@ auto PassField(Priority<1>, TypeList<>, ServerContext& server_context, const Fn&
     auto future = kj::newPromiseAndFulfiller<typename ServerContext::CallContext>();
     auto& server = server_context.proxy_server;
     int req = server_context.req;
-    auto invoke = MakeAsyncCallable(
-        [fulfiller = kj::mv(future.fulfiller),
+    auto invoke = [fulfiller = kj::mv(future.fulfiller),
          call_context = kj::mv(server_context.call_context), &server, req, fn, args...]() mutable {
         const auto& params = call_context.getParams();
         Context::Reader context_arg = Accessor::get(params);
@@ -132,35 +131,35 @@ auto PassField(Priority<1>, TypeList<>, ServerContext& server_context, const Fn&
                 return;
             }
             KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
-                    server.m_context.connection->m_loop.sync([&] {
+                    server.m_context.loop->sync([&] {
                         auto fulfiller_dispose = kj::mv(fulfiller);
                         fulfiller_dispose->fulfill(kj::mv(call_context));
                     });
                 }))
             {
-                server.m_context.connection->m_loop.sync([&]() {
+                server.m_context.loop->sync([&]() {
                     auto fulfiller_dispose = kj::mv(fulfiller);
                    fulfiller_dispose->reject(kj::mv(*exception));
                 });
             }
-        });
+        };
 
     // Lookup Thread object specified by the client. The specified thread should
     // be a local Thread::Server object, but it needs to be looked up
     // asynchronously with getLocalServer().
     auto thread_client = context_arg.getThread();
     return server.m_context.connection->m_threads.getLocalServer(thread_client)
-        .then([&server, invoke, req](const kj::Maybe<Thread::Server&>& perhaps) {
+        .then([&server, invoke = kj::mv(invoke), req](const kj::Maybe<Thread::Server&>& perhaps) mutable {
             // Assuming the thread object is found, pass it a pointer to the
             // `invoke` lambda above which will invoke the function on that
             // thread.
             KJ_IF_MAYBE (thread_server, perhaps) {
                 const auto& thread = static_cast<const ProxyServer<Thread>&>(*thread_server);
-                server.m_context.connection->m_loop.log()
+                server.m_context.loop->log()
                     << "IPC server post request #" << req << " {" << thread.m_thread_context.thread_name << "}";
                 thread.m_thread_context.waiter->post(std::move(invoke));
             } else {
-                server.m_context.connection->m_loop.log()
+                server.m_context.loop->log()
                     << "IPC server error request #" << req << ", missing thread to execute request";
                 throw std::runtime_error("invalid thread handle");
             }
diff --git a/include/mp/util.h b/include/mp/util.h
index ebfc3b5e720..d3a5c3eb0b4 100644
--- a/include/mp/util.h
+++ b/include/mp/util.h
@@ -6,6 +6,7 @@
 #define MP_UTIL_H
 
 #include <capnp/schema.h>
+#include <cassert>
 #include <cstddef>
 #include <functional>
 #include <future>
@@ -13,11 +14,13 @@
 #include <kj/string-tree.h>
 #include <list>
 #include <memory>
+#include <mutex>
 #include <string.h>
 #include <string>
 #include <tuple>
 #include <type_traits>
 #include <utility>
+#include <variant>
 #include <vector>
 
 namespace mp {
@@ -130,6 +133,58 @@ const char* TypeName()
     return short_name ? short_name + 1 : display_name;
 }
 
+//! Convenient wrapper around std::variant<T*, T>
+template <typename T>
+struct PtrOrValue {
+    std::variant<T*, T> data;
+
+    template <typename... Args>
+    PtrOrValue(T* ptr, Args&&... args) : data(ptr ? ptr : std::variant<T*, T>{std::in_place_type<T>, std::forward<Args>(args)...}) {}
+
+    T& operator*() { return data.index() ? std::get<T>(data) : *std::get<T*>(data); }
+    T* operator->() { return &**this; }
+    T& operator*() const { return data.index() ? std::get<T>(data) : *std::get<T*>(data); }
+    T* operator->() const { return &**this; }
+};
+
+// Annotated mutex and lock class (https://clang.llvm.org/docs/ThreadSafetyAnalysis.html)
+#if defined(__clang__) && (!defined(SWIG))
+#define MP_TSA(x) __attribute__((x))
+#else
+#define MP_TSA(x) // no-op
+#endif
+
+#define MP_CAPABILITY(x) MP_TSA(capability(x))
+#define MP_SCOPED_CAPABILITY MP_TSA(scoped_lockable)
+#define MP_REQUIRES(x) MP_TSA(requires_capability(x))
+#define MP_ACQUIRE(...) MP_TSA(acquire_capability(__VA_ARGS__))
+#define MP_RELEASE(...) MP_TSA(release_capability(__VA_ARGS__))
+#define MP_ASSERT_CAPABILITY(x) MP_TSA(assert_capability(x))
+#define MP_GUARDED_BY(x) MP_TSA(guarded_by(x))
+
+class MP_CAPABILITY("mutex") Mutex {
+public:
+    void lock() MP_ACQUIRE() { m_mutex.lock(); }
+    void unlock() MP_RELEASE() { m_mutex.unlock(); }
+
+    std::mutex m_mutex;
+};
+
+class MP_SCOPED_CAPABILITY Lock {
+public:
+    explicit Lock(Mutex& m) MP_ACQUIRE(m) : m_lock(m.m_mutex) {}
+    ~Lock() MP_RELEASE() {}
+    void unlock() MP_RELEASE() { m_lock.unlock(); }
+    void lock() MP_ACQUIRE() { m_lock.lock(); }
+    void assert_locked(Mutex& mutex) MP_ASSERT_CAPABILITY() MP_ASSERT_CAPABILITY(mutex)
+    {
+        assert(m_lock.mutex() == &mutex.m_mutex);
+        assert(m_lock);
+    }
+
+    std::unique_lock<std::mutex> m_lock;
+};
+
 //! Analog to std::lock_guard that unlocks instead of locks.
 template <typename Lock>
 struct UnlockGuard
@@ -146,46 +201,6 @@ void Unlock(Lock& lock, Callback&& callback)
     callback();
 }
 
-//! Needed for libc++/macOS compatibility. Lets code work with shared_ptr nothrow declaration
-//! https://github.com/capnproto/capnproto/issues/553#issuecomment-328554603
-template <typename T>
-struct DestructorCatcher
-{
-    T value;
-    template <typename... Params>
-    DestructorCatcher(Params&&... params) : value(kj::fwd<Params>(params)...)
-    {
-    }
-    ~DestructorCatcher() noexcept try {
-    } catch (const kj::Exception& e) { // NOLINT(bugprone-empty-catch)
-    }
-};
-
-//! Wrapper around callback function for compatibility with std::async.
-//!
-//! std::async requires callbacks to be copyable and requires noexcept
-//! destructors, but this doesn't work well with kj types which are generally
-//! move-only and not noexcept.
-template <typename Callable>
-struct AsyncCallable
-{
-    AsyncCallable(Callable&& callable) : m_callable(std::make_shared<DestructorCatcher<Callable>>(std::move(callable)))
-    {
-    }
-    AsyncCallable(const AsyncCallable&) = default;
-    AsyncCallable(AsyncCallable&&) = default;
-    ~AsyncCallable() noexcept = default;
-    ResultOf<Callable> operator()() const { return (m_callable->value)(); }
-    mutable std::shared_ptr<DestructorCatcher<Callable>> m_callable;
-};
-
-//! Construct AsyncCallable object.
-template <typename Callable>
-AsyncCallable<std::remove_reference_t<Callable>> MakeAsyncCallable(Callable&& callable)
-{
-    return std::move(callable);
-}
-
 //! Format current thread name as "{exe_name}-{$pid}/{thread_name}-{$tid}".
 std::string ThreadName(const char* exe_name);
diff --git a/src/mp/proxy.cpp b/src/mp/proxy.cpp
index 194ded5f265..4916b31adaa 100644
--- a/src/mp/proxy.cpp
+++ b/src/mp/proxy.cpp
@@ -22,6 +22,7 @@
 #include <kj/common.h>
 #include <kj/debug.h>
 #include <kj/exception.h>
+#include <kj/function.h>
 #include <kj/memory.h>
 #include <map>
 #include <memory>
@@ -48,6 +49,36 @@ void LoggingErrorHandler::taskFailed(kj::Exception&& exception)
     m_loop.log() << "Uncaught exception in daemonized task.";
 }
 
+EventLoopRef::EventLoopRef(EventLoop& loop, Lock* lock) : m_loop(&loop), m_lock(lock)
+{
+    auto loop_lock{PtrOrValue{m_lock, m_loop->m_mutex}};
+    loop_lock->assert_locked(m_loop->m_mutex);
+    m_loop->m_num_clients += 1;
+}
+
+bool EventLoopRef::reset(Lock* lock)
+{
+    bool done = false;
+    if (m_loop) {
+        auto loop_lock{PtrOrValue{lock ? lock : m_lock, m_loop->m_mutex}};
+        loop_lock->assert_locked(m_loop->m_mutex);
+        assert(m_loop->m_num_clients > 0);
+        m_loop->m_num_clients -= 1;
+        if (m_loop->done()) {
+            done = true;
+            m_loop->m_cv.notify_all();
+            int post_fd{m_loop->m_post_fd};
+            loop_lock->unlock();
+            char buffer = 0;
+            KJ_SYSCALL(write(post_fd, &buffer, 1)); // NOLINT(bugprone-suspicious-semicolon)
+        }
+        m_loop = nullptr;
+    }
+    return done;
+}
+
+ProxyContext::ProxyContext(Connection* connection) : connection(connection), loop{*connection->m_loop} {}
+
 Connection::~Connection()
 {
     // Shut down RPC system first, since this will garbage collect Server
@@ -103,18 +134,18 @@ Connection::~Connection()
         m_sync_cleanup_fns.pop_front();
     }
     while (!m_async_cleanup_fns.empty()) {
-        const std::unique_lock<std::mutex> lock(m_loop.m_mutex);
-        m_loop.m_async_fns.emplace_back(std::move(m_async_cleanup_fns.front()));
+        const Lock lock(m_loop->m_mutex);
+        m_loop->m_async_fns.emplace_back(std::move(m_async_cleanup_fns.front()));
         m_async_cleanup_fns.pop_front();
     }
-    std::unique_lock<std::mutex> lock(m_loop.m_mutex);
-    m_loop.startAsyncThread(lock);
-    m_loop.removeClient(lock);
+    Lock lock(m_loop->m_mutex);
+    m_loop->startAsyncThread();
+    m_loop.reset(&lock);
 }
 
 CleanupIt Connection::addSyncCleanup(std::function<void()> fn)
 {
-    const std::unique_lock<std::mutex> lock(m_loop.m_mutex);
+    const Lock lock(m_loop->m_mutex);
     // Add cleanup callbacks to the front of list, so sync cleanup functions run
     // in LIFO order. This is a good approach because sync cleanup functions are
     // added as client objects are created, and it is natural to clean up
@@ -128,13 +159,13 @@ CleanupIt Connection::addSyncCleanup(std::function<void()> fn)
 
 void Connection::removeSyncCleanup(CleanupIt it)
 {
-    const std::unique_lock<std::mutex> lock(m_loop.m_mutex);
+    const Lock lock(m_loop->m_mutex);
     m_sync_cleanup_fns.erase(it);
 }
 
 void Connection::addAsyncCleanup(std::function<void()> fn)
 {
-    const std::unique_lock<std::mutex> lock(m_loop.m_mutex);
+    const Lock lock(m_loop->m_mutex);
     // Add async cleanup callbacks to the back of the list. Unlike the sync
     // cleanup list, this list order is more significant because it determines
     // the order server objects are destroyed when there is a sudden disconnect,
@@ -170,7 +201,7 @@ EventLoop::EventLoop(const char* exe_name, LogFn log_fn, void* context)
 EventLoop::~EventLoop()
 {
     if (m_async_thread.joinable()) m_async_thread.join();
-    const std::lock_guard<std::mutex> lock(m_mutex);
+    const Lock lock(m_mutex);
     KJ_ASSERT(m_post_fn == nullptr);
     KJ_ASSERT(m_async_fns.empty());
     KJ_ASSERT(m_wait_fd == -1);
@@ -195,14 +226,14 @@ void EventLoop::loop()
     for (;;) {
         const size_t read_bytes = wait_stream->read(&buffer, 0, 1).wait(m_io_context.waitScope);
         if (read_bytes != 1) throw std::logic_error("EventLoop wait_stream closed unexpectedly");
-        std::unique_lock<std::mutex> lock(m_mutex);
+        Lock lock(m_mutex);
         if (m_post_fn) {
             Unlock(lock, *m_post_fn);
             m_post_fn = nullptr;
             m_cv.notify_all();
-        } else if (done(lock)) {
+        } else if (done()) {
             // Intentionally do not break if m_post_fn was set, even if done()
-            // would return true, to ensure that the removeClient write(post_fd)
+            // would return true, to ensure that the EventLoopRef write(post_fd)
             // call always succeeds and the loop does not exit between the time
             // that the done condition is set and the write call is made.
             break;
         }
     }
 
@@ -213,75 +244,62 @@ void EventLoop::loop()
     log() << "EventLoop::loop bye.";
     wait_stream = nullptr;
     KJ_SYSCALL(::close(post_fd));
-    const std::unique_lock<std::mutex> lock(m_mutex);
+    const Lock lock(m_mutex);
     m_wait_fd = -1;
     m_post_fd = -1;
 }
 
-void EventLoop::post(const std::function<void()>& fn)
+void EventLoop::post(kj::Function<void()> fn)
 {
     if (std::this_thread::get_id() == m_thread_id) {
         fn();
         return;
     }
-    std::unique_lock<std::mutex> lock(m_mutex);
-    addClient(lock);
-    m_cv.wait(lock, [this] { return m_post_fn == nullptr; });
+    Lock lock(m_mutex);
+    EventLoopRef ref(*this, &lock);
+    m_cv.wait(lock.m_lock, [this]() MP_REQUIRES(m_mutex) { return m_post_fn == nullptr; });
     m_post_fn = &fn;
     int post_fd{m_post_fd};
     Unlock(lock, [&] {
        char buffer = 0;
        KJ_SYSCALL(write(post_fd, &buffer, 1));
    });
-    m_cv.wait(lock, [this, &fn] { return m_post_fn != &fn; });
-    removeClient(lock);
+    m_cv.wait(lock.m_lock, [this, &fn]() MP_REQUIRES(m_mutex) { return m_post_fn != &fn; });
 }
 
-void EventLoop::addClient(std::unique_lock<std::mutex>& lock) { m_num_clients += 1; }
-
-bool EventLoop::removeClient(std::unique_lock<std::mutex>& lock)
-{
-    m_num_clients -= 1;
-    if (done(lock)) {
-        m_cv.notify_all();
-        int post_fd{m_post_fd};
-        lock.unlock();
-        char buffer = 0;
-        KJ_SYSCALL(write(post_fd, &buffer, 1)); // NOLINT(bugprone-suspicious-semicolon)
-        return true;
-    }
-    return false;
-}
-
-void EventLoop::startAsyncThread(std::unique_lock<std::mutex>& lock)
+void EventLoop::startAsyncThread()
 {
     if (m_async_thread.joinable()) {
         m_cv.notify_all();
     } else if (!m_async_fns.empty()) {
         m_async_thread = std::thread([this] {
-            std::unique_lock<std::mutex> lock(m_mutex);
-            while (true) {
+            Lock lock(m_mutex);
+            while (!done()) {
                 if (!m_async_fns.empty()) {
-                    addClient(lock);
+                    EventLoopRef ref{*this, &lock};
                    const std::function<void()> fn = std::move(m_async_fns.front());
                     m_async_fns.pop_front();
                     Unlock(lock, fn);
-                    if (removeClient(lock)) break;
+                    // Important to explicitly call ref.reset() here and
+                    // explicitly break if the EventLoop is done, not relying on
+                    // the while condition above. The reason is that the end of
+                    // `ref`'s lifetime can cause EventLoop::loop() to exit, and
+                    // if there is external code that immediately deletes the
+                    // EventLoop object as soon as the EventLoop::loop() method
+                    // returns, checking the while condition may crash.
+                    if (ref.reset()) break;
+                    // Continue without waiting in case there are more async_fns
                     continue;
-                } else if (m_num_clients == 0) {
-                    break;
                 }
-                m_cv.wait(lock);
+                m_cv.wait(lock.m_lock);
             }
         });
     }
 }
 
-bool EventLoop::done(std::unique_lock<std::mutex>& lock)
+bool EventLoop::done()
 {
     assert(m_num_clients >= 0);
-    assert(lock.owns_lock());
-    assert(lock.mutex() == &m_mutex);
     return m_num_clients == 0 && m_async_fns.empty();
 }
 
@@ -375,7 +393,7 @@ kj::Promise<void> ProxyServer<Thread>::makeThread(MakeThreadContext context)
     const std::string from = context.getParams().getName();
     std::promise<ThreadContext*> thread_context;
     std::thread thread([&thread_context, from, this]() {
-        g_thread_context.thread_name = ThreadName(m_connection.m_loop.m_exe_name) + " (from " + from + ")";
+        g_thread_context.thread_name = ThreadName(m_connection.m_loop->m_exe_name) + " (from " + from + ")";
         g_thread_context.waiter = std::make_unique<Waiter>();
         thread_context.set_value(&g_thread_context);
         std::unique_lock<std::mutex> lock(g_thread_context.waiter->m_mutex);
diff --git a/test/mp/test/test.cpp b/test/mp/test/test.cpp
index 7fc64f6741d..7092836d85e 100644
--- a/test/mp/test/test.cpp
+++ b/test/mp/test/test.cpp
@@ -23,32 +23,82 @@ namespace mp {
 namespace test {
 
+/**
+ * Test setup class creating a two way connection between a
+ * ProxyServer object and a ProxyClient.
+ *
+ * Provides client_disconnect and server_disconnect lambdas that can be used to
+ * trigger disconnects and test handling of broken and closed connections.
+ *
+ * Accepts a client_owns_connection option to test different ProxyClient
+ * destroy_connection values and control whether destroying the ProxyClient
+ * object destroys the client Connection object. Normally it makes sense for
+ * this to be true to simplify shutdown and avoid needing to call
+ * client_disconnect manually, but false allows testing more ProxyClient
+ * behavior and the "IPC client method called after disconnect" code path.
+ */
+class TestSetup
+{
+public:
+    std::thread thread;
+    std::function<void()> server_disconnect;
+    std::function<void()> client_disconnect;
+    std::promise<std::unique_ptr<ProxyClient<messages::FooInterface>>> client_promise;
+    std::unique_ptr<ProxyClient<messages::FooInterface>> client;
+
+    TestSetup(bool client_owns_connection = true)
+        : thread{[&] {
+              EventLoop loop("mptest", [](bool raise, const std::string& log) {
+                  std::cout << "LOG" << raise << ": " << log << "\n";
+                  if (raise) throw std::runtime_error(log);
+              });
+              auto pipe = loop.m_io_context.provider->newTwoWayPipe();
+
+              auto server_connection =
+                  std::make_unique<Connection>(loop, kj::mv(pipe.ends[0]), [&](Connection& connection) {
+                      auto server_proxy = kj::heap<ProxyServer<messages::FooInterface>>(
+                          std::make_shared<FooImplementation>(), connection);
+                      return capnp::Capability::Client(kj::mv(server_proxy));
+                  });
+              server_disconnect = [&] { loop.sync([&] { server_connection.reset(); }); };
+              // Set handler to destroy the server when the client disconnects. This
+              // is ignored if server_disconnect() is called instead.
+              server_connection->onDisconnect([&] { server_connection.reset(); });
+
+              auto client_connection = std::make_unique<Connection>(loop, kj::mv(pipe.ends[1]));
+              auto client_proxy = std::make_unique<ProxyClient<messages::FooInterface>>(
+                  client_connection->m_rpc_system->bootstrap(ServerVatId().vat_id).castAs<messages::FooInterface>(),
+                  client_connection.get(), /* destroy_connection= */ client_owns_connection);
+              if (client_owns_connection) {
+                  client_connection.release();
+              } else {
+                  client_disconnect = [&] { loop.sync([&] { client_connection.reset(); }); };
+              }
+
+              client_promise.set_value(std::move(client_proxy));
+              loop.loop();
+          }}
+    {
+        client = client_promise.get_future().get();
+    }
+
+    ~TestSetup()
+    {
+        // Test that client cleanup_fns are executed.
+        bool destroyed = false;
+        client->m_context.cleanup_fns.emplace_front([&destroyed] { destroyed = true; });
+        client.reset();
+        KJ_EXPECT(destroyed);
+
+        thread.join();
+    }
+};
+
 KJ_TEST("Call FooInterface methods")
 {
-    std::promise<std::unique_ptr<ProxyClient<messages::FooInterface>>> foo_promise;
-    std::function<void()> disconnect_client;
-    std::thread thread([&]() {
-        EventLoop loop("mptest", [](bool raise, const std::string& log) {
-            std::cout << "LOG" << raise << ": " << log << "\n";
-        });
-        auto pipe = loop.m_io_context.provider->newTwoWayPipe();
+    TestSetup setup;
+    ProxyClient<messages::FooInterface>* foo = setup.client.get();
 
-        auto connection_client = std::make_unique<Connection>(loop, kj::mv(pipe.ends[0]));
-        auto foo_client = std::make_unique<ProxyClient<messages::FooInterface>>(
-            connection_client->m_rpc_system->bootstrap(ServerVatId().vat_id).castAs<messages::FooInterface>(),
-            connection_client.get(), /* destroy_connection= */ false);
-        foo_promise.set_value(std::move(foo_client));
-        disconnect_client = [&] { loop.sync([&] { connection_client.reset(); }); };
-
-        auto connection_server = std::make_unique<Connection>(loop, kj::mv(pipe.ends[1]), [&](Connection& connection) {
-            auto foo_server = kj::heap<ProxyServer<messages::FooInterface>>(std::make_shared<FooImplementation>(), connection);
-            return capnp::Capability::Client(kj::mv(foo_server));
-        });
-        connection_server->onDisconnect([&] { connection_server.reset(); });
-        loop.loop();
-    });
-
-    auto foo = foo_promise.get_future().get();
     KJ_EXPECT(foo->add(1, 2) == 3);
 
     FooStruct in;
@@ -127,14 +177,40 @@ KJ_TEST("Call FooInterface methods")
     mut.message = "init";
     foo->passMutable(mut);
     KJ_EXPECT(mut.message == "init build pass call return read");
+}
 
-    disconnect_client();
-    thread.join();
+KJ_TEST("Call IPC method after client connection is closed")
+{
+    TestSetup setup{/*client_owns_connection=*/false};
+    ProxyClient<messages::FooInterface>* foo = setup.client.get();
+    KJ_EXPECT(foo->add(1, 2) == 3);
+    setup.client_disconnect();
 
-    bool destroyed = false;
-    foo->m_context.cleanup_fns.emplace_front([&destroyed]{ destroyed = true; });
-    foo.reset();
-    KJ_EXPECT(destroyed);
+    bool disconnected{false};
+    try {
+        foo->add(1, 2);
+    } catch (const std::runtime_error& e) {
+        KJ_EXPECT(std::string_view{e.what()} == "IPC client method called after disconnect.");
+        disconnected = true;
+    }
+    KJ_EXPECT(disconnected);
+}
+
+KJ_TEST("Call IPC method after server connection is closed")
+{
+    TestSetup setup;
+    ProxyClient<messages::FooInterface>* foo = setup.client.get();
+    KJ_EXPECT(foo->add(1, 2) == 3);
+    setup.server_disconnect();
+
+    bool disconnected{false};
+    try {
+        foo->add(1, 2);
+    } catch (const std::runtime_error& e) {
+        KJ_EXPECT(std::string_view{e.what()} == "IPC client method call interrupted by disconnect.");
+        disconnected = true;
+    }
+    KJ_EXPECT(disconnected);
 }
 
 } // namespace test
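
A note on the pattern used above: EventLoopRef replaces the manual addClient()/removeClient() calls deleted by this patch with RAII-style reference counting, so every code path that keeps the loop alive decrements the count even on early return or exception. The following standalone sketch is a hypothetical, simplified illustration of that pattern; none of these names come from the library:

    // Hypothetical, simplified illustration of the EventLoopRef idea: a
    // move-only RAII handle that increments a shared client count on
    // construction and decrements it on destruction.
    #include <cassert>
    #include <mutex>

    struct DemoLoop
    {
        std::mutex m_mutex;
        int m_num_clients{0}; // how many handles keep the loop alive
    };

    class DemoLoopRef
    {
    public:
        explicit DemoLoopRef(DemoLoop& loop) : m_loop(&loop)
        {
            const std::lock_guard<std::mutex> lock(m_loop->m_mutex);
            ++m_loop->m_num_clients;
        }
        DemoLoopRef(DemoLoopRef&& other) noexcept : m_loop(other.m_loop) { other.m_loop = nullptr; }
        DemoLoopRef(const DemoLoopRef&) = delete;
        DemoLoopRef& operator=(const DemoLoopRef&) = delete;
        ~DemoLoopRef()
        {
            if (!m_loop) return; // moved-from handle owns nothing
            const std::lock_guard<std::mutex> lock(m_loop->m_mutex);
            assert(m_loop->m_num_clients > 0);
            --m_loop->m_num_clients; // loop may exit when this reaches zero
        }

    private:
        DemoLoop* m_loop;
    };

The move constructor nulls out the source handle so a moved-from reference does not double-decrement, mirroring EventLoopRef's move semantics.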
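The MP_TSA macros introduced in util.h map directly onto Clang's thread safety analysis attributes, so annotated members such as m_num_clients are checked at compile time when building with -Wthread-safety. A minimal sketch of how the analysis works, assuming clang++ (illustrative names, not the patch's actual classes):

    // Compile with: clang++ -Wthread-safety -c demo.cpp
    #include <mutex>

    #define CAPABILITY(x) __attribute__((capability(x)))
    #define GUARDED_BY(x) __attribute__((guarded_by(x)))
    #define REQUIRES(x) __attribute__((requires_capability(x)))
    #define ACQUIRE(...) __attribute__((acquire_capability(__VA_ARGS__)))
    #define RELEASE(...) __attribute__((release_capability(__VA_ARGS__)))

    class CAPABILITY("mutex") DemoMutex
    {
    public:
        void lock() ACQUIRE() { m_mutex.lock(); }
        void unlock() RELEASE() { m_mutex.unlock(); }
        std::mutex m_mutex;
    };

    class DemoCounter
    {
    public:
        // Callers must hold m_mutex; the analysis verifies this at compile time.
        bool done() REQUIRES(m_mutex) { return m_count == 0; }

        void decrement()
        {
            m_mutex.lock();
            --m_count;       // ok: m_mutex held
            bool d = done(); // ok: REQUIRES(m_mutex) satisfied
            (void)d;
            m_mutex.unlock();
            // --m_count;    // would warn: m_count requires holding m_mutex
        }

        DemoMutex m_mutex;
        int m_count GUARDED_BY(m_mutex) = 0;
    };

This is why EventLoop::done() and startAsyncThread() can drop their std::unique_lock parameters: the MP_REQUIRES(m_mutex) annotation moves the "caller must hold the lock" contract from runtime asserts into the type system.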
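PtrOrValue, as used by EventLoopRef's constructor and reset(), lets a function either borrow a caller-provided lock or construct its own for the duration of the call. A simplified sketch of that dispatch (hypothetical work() function; PtrOrValue reduced to the members the example needs):

    #include <mutex>
    #include <utility>
    #include <variant>

    // Simplified version of the PtrOrValue dispatch: use the caller's object
    // when a pointer is passed, otherwise construct a local one in place.
    template <typename T>
    struct PtrOrValue {
        std::variant<T*, T> data;
        template <typename... Args>
        PtrOrValue(T* ptr, Args&&... args)
            : data(ptr ? std::variant<T*, T>{ptr}
                       : std::variant<T*, T>{std::in_place_type<T>, std::forward<Args>(args)...}) {}
        T& operator*() { return data.index() ? std::get<T>(data) : *std::get<T*>(data); }
        T* operator->() { return &**this; }
    };

    void work(std::mutex& mutex, std::unique_lock<std::mutex>* external_lock = nullptr)
    {
        // Borrows *external_lock if given, otherwise locks mutex for this scope.
        PtrOrValue<std::unique_lock<std::mutex>> lock(external_lock, mutex);
        // ... access state guarded by mutex for the lifetime of `lock` ...
    }

When the value alternative is active, the variant destroys it at end of scope, releasing the locally-created lock; a borrowed lock is left untouched for the caller.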