[orc-rt] Add managed-code-calls TaskGroup. (#190740)

Adds a ManagedCodeCallsGroup TaskGroup to Session, and updates the
shutdown sequence to wait until all calls into managed code have
completed before proceeding to shut down the Session's Services and the
Session itself.

To support safe calls into managed code two new helper template methods
are added:

callManagedCodeSync attempts to acquire a TaskGroup::Token for the
ManagedCodeCallsGroup before calling the given function and returning
its result.

callManagedCodeAsync attempts to acquire a TaskGroup::Token for the
ManagedCodeCallsGroup before calling the given async function. The
wrapped Return call for the async function will carry the acquired
Token, ensuring that shutdown waits for the async Return call to be
destroyed (whether or not it's actually called).
This commit is contained in:
Lang Hames
2026-04-07 20:53:28 +10:00
committed by GitHub
parent 546787ec97
commit 543ec358dd
3 changed files with 370 additions and 20 deletions

View File

@@ -20,6 +20,7 @@
#include "orc-rt/Service.h"
#include "orc-rt/SimpleSymbolTable.h"
#include "orc-rt/TaskDispatcher.h"
#include "orc-rt/TaskGroup.h"
#include "orc-rt/WrapperFunction.h"
#include "orc-rt/move_only_function.h"
@@ -30,10 +31,10 @@
#include <future>
#include <memory>
#include <mutex>
#include <type_traits>
#include <vector>
namespace orc_rt {
class Session;
inline orc_rt_SessionRef wrap(Session *S) noexcept {
@@ -46,6 +47,60 @@ inline Session *unwrap(orc_rt_SessionRef S) noexcept {
/// Represents an ORC executor Session.
class Session {
private:
// Implementation helper for callManagedCodeSync (non-void version).
template <typename RetT> struct ManagedCodeSyncCaller {
template <typename FnT, typename... ArgTs>
static std::optional<RetT> call(TaskGroup::Token Tok, FnT &&Fn,
ArgTs &&...Args) {
if (!Tok)
return std::nullopt;
return std::forward<FnT>(Fn)(std::forward<ArgTs>(Args)...);
}
};
// Implementation helper for callManagedCodeSync (void version).
template <> struct ManagedCodeSyncCaller<void> {
template <typename FnT, typename... ArgTs>
static bool call(TaskGroup::Token Tok, FnT &&Fn, ArgTs &&...Args) {
if (!Tok)
return false;
std::forward<FnT>(Fn)(std::forward<ArgTs>(Args)...);
return true;
}
};
template <typename ReturnArgTupleT> struct ManagedCodeAsyncCaller;
// Implementation helper for callManagedCodeAsync (non-void version).
template <typename T>
struct ManagedCodeAsyncCaller<std::tuple<std::optional<T>>> {
template <typename ReturnT, typename FnT, typename... ArgTs>
static void call(TaskGroup::Token Tok, ReturnT &&Return, FnT &&Fn,
ArgTs &&...Args) {
if (!Tok)
return std::forward<ReturnT>(Return)(std::nullopt);
std::forward<FnT>(Fn)([Tok = std::move(Tok), R = std::move(Return)](
T Value) { R(std::move(Value)); },
std::forward<ArgTs>(Args)...);
}
};
// Implementation helper for callManagedCodeAsync (void version).
template <> struct ManagedCodeAsyncCaller<std::tuple<bool>> {
template <typename ReturnT, typename FnT, typename... ArgTs>
static void call(TaskGroup::Token Tok, ReturnT &&Return, FnT &&Fn,
ArgTs &&...Args) {
if (!Tok)
return std::forward<ReturnT>(Return)(false);
std::forward<FnT>(Fn)(
[Tok = std::move(Tok), R = std::move(Return)]() { R(true); },
std::forward<ArgTs>(Args)...);
}
};
public:
using ErrorReporterFn = move_only_function<void(Error)>;
using OnDetachFn = move_only_function<void()>;
@@ -242,6 +297,64 @@ public:
/// Session has already shut down, the callback will be called immediately.
void addOnShutdown(OnShutdownFn OnShutdown);
/// Returns a reference to this Session's ManagedCodeCallsGroup.
///
/// When calling code managed by a Session (e.g. JIT'd code, or library code
/// loaded on behalf of JIT'd code), clients should hold a token for this
/// group. That token will prevent the Session from shutting down any Services
/// (and the Session itself) until calls into managed code have completed.
///
/// Clients should prefer using the callManagedCodeSync and
/// callManagedCodeAsync helpers to automatically acquire and hold a token
/// for the duration of a call.
const std::shared_ptr<TaskGroup> &managedCodeCallsGroup() const {
return ManagedCodeCallsGroup;
}
/// Synchronously call managed code.
///
/// This helper tries to acquire a ManagedCodeCallsGroup token and then call
/// the given function object with the given arguments while holding the
/// token.
///
/// If the token is successfully acquired then this function will return the
/// call result as a std::optional<T> (for a non-void return type T), or
/// boolean true (for void returns).
///
/// If the token is not successfully acquired then this function will return
/// std::nullopt (for non-void return type) or boolean false (for void
/// returns).
template <typename FnT, typename... ArgTs>
decltype(auto) callManagedCodeSync(FnT &&Fn, ArgTs &&...Args) {
return ManagedCodeSyncCaller<std::invoke_result_t<FnT, ArgTs...>>::call(
TaskGroup::Token(ManagedCodeCallsGroup), std::forward<FnT>(Fn),
std::forward<ArgTs>(Args)...);
}
/// Asynchronously call managed code.
///
/// ReturnT must be a function object that takes either a boolean or a
/// std::optional<T>.
///
/// callManagedCodeAsync tries to acquire a ManagedCodeCallsGroup token and
/// then call the given async function object while holding that token.
///
/// If the token is successfully acquired then this function will call Fn,
/// passing in a wrapped version of Return that takes a T (if Return takes a
/// std::optional<T>), or a wrapped version of Return that takes no arguments
/// (if Return takes a bool).
///
/// If the token is not successfully acquired then this function will not
/// call Fn, but instead immediately call Return with std::nullopt (if Return
/// takes a std::optional<T>), or false (if Return takes a boolean).
template <typename ReturnT, typename FnT, typename... ArgTs>
void callManagedCodeAsync(ReturnT &&Return, FnT &&Fn, ArgTs &&...Args) {
ManagedCodeAsyncCaller<typename CallableArgInfo<ReturnT>::args_tuple_type>::
call(TaskGroup::Token(ManagedCodeCallsGroup),
std::forward<ReturnT>(Return), std::forward<FnT>(Fn),
std::forward<ArgTs>(Args)...);
}
/// Call a tagged handler in the Controller.
///
/// This method can be called directly, but is expected to be more commonly
@@ -309,12 +422,21 @@ private:
void detachServices(std::vector<Service *> ToNotify, bool ShutdownRequested);
void completeDetach();
void proceedToShutdown(std::unique_lock<std::mutex> &Lock);
void waitForManagedCodeCallsThenShutdown();
void proceedToShutdown();
void shutdownServices(std::vector<Service *> ToNotify);
void completeShutdown();
void handleWrapperCall(uint64_t CallId, orc_rt_WrapperFunction Fn,
WrapperFunctionBuffer ArgBytes) {
if (!ManagedCodeCallsGroup->acquireToken()) {
// The ManagedCodeCallsGroup is only closed after detach, so if token
// acquisition fails we don't try to return an error: the controller
// should already have signalled error to the caller, and we have no
// way to transmit an error anyway.
return;
}
dispatch(makeGenericTask([=, ArgBytes = std::move(ArgBytes)]() mutable {
Fn(wrap(this), CallId, wrapperReturn, ArgBytes.release());
}));
@@ -323,6 +445,7 @@ private:
void sendWrapperResult(uint64_t CallId, WrapperFunctionBuffer ResultBytes) {
if (auto TmpCA = std::atomic_load(&CA))
TmpCA->sendWrapperResult(CallId, std::move(ResultBytes));
ManagedCodeCallsGroup->releaseToken();
}
static void wrapperReturn(orc_rt_SessionRef S, uint64_t CallId,
@@ -330,6 +453,7 @@ private:
ExecutorProcessInfo EPI;
std::unique_ptr<TaskDispatcher> Dispatcher;
std::shared_ptr<TaskGroup> ManagedCodeCallsGroup = TaskGroup::Create();
std::shared_ptr<ControllerAccess> CA;
ErrorReporterFn ReportError;

View File

@@ -183,7 +183,8 @@ void Session::shutdown(OnShutdownFn OnShutdown) {
TmpCA = std::atomic_load(&this->CA);
break;
case State::Detached:
proceedToShutdown(Lock);
Lock.unlock();
waitForManagedCodeCallsThenShutdown();
return;
default:
assert(false && "Illegal state");
@@ -314,27 +315,35 @@ void Session::detachServices(std::vector<Service *> ToNotify,
}
void Session::completeDetach() {
std::unique_lock<std::mutex> Lock(M);
assert(CurrentState == State::Detached);
if (TargetState == State::Detached) {
TargetState = State::None;
return;
{
std::scoped_lock<std::mutex> Lock(M);
assert(CurrentState == State::Detached);
if (TargetState == State::Detached) {
TargetState = State::None;
return;
}
// Someone must have requested shutdown.
assert(TargetState == State::Shutdown);
}
// Someone must have requested shutdown.
assert(TargetState == State::Shutdown);
proceedToShutdown(Lock);
waitForManagedCodeCallsThenShutdown();
}
void Session::proceedToShutdown(std::unique_lock<std::mutex> &Lock) {
std::vector<Service *> ToNotify;
ToNotify.reserve(Services.size());
for (auto &Srv : Services)
ToNotify.push_back(Srv.get());
CurrentState = State::Shutdown;
Lock.unlock();
void Session::waitForManagedCodeCallsThenShutdown() {
ManagedCodeCallsGroup->addOnComplete([this]() { proceedToShutdown(); });
ManagedCodeCallsGroup->close();
}
void Session::proceedToShutdown() {
std::vector<Service *> ToNotify;
{
std::scoped_lock<std::mutex> Lock(M);
ToNotify.reserve(Services.size());
for (auto &Srv : Services)
ToNotify.push_back(Srv.get());
CurrentState = State::Shutdown;
}
// Notify services.
shutdownServices(std::move(ToNotify));
}

View File

@@ -357,7 +357,7 @@ TEST(SessionTest, MultipleServices) {
}
}
TEST(SessionTest, ExpectedShutdownSequence) {
TEST(SessionTest, ExpectedShutdownSequenceWithNoActiveManagedCodeCalls) {
// Check that Session shutdown results in...
// 1. Services being shut down.
// 2. The TaskDispatcher being shut down.
@@ -392,6 +392,223 @@ TEST(SessionTest, ExpectedShutdownSequence) {
EXPECT_TRUE(SessionShutdownComplete);
}
TEST(SessionTest, ActiveManagedCallsDelayShutdown) {
std::deque<std::unique_ptr<Task>> Tasks;
Session S(mockExecutorProcessInfo(),
std::make_unique<EnqueueingDispatcher>(Tasks), noErrors);
size_t OpIdx = 0;
std::optional<size_t> DetachOpIdx;
std::optional<size_t> ShutdownOpIdx;
S.createService<MockService>(DetachOpIdx, ShutdownOpIdx, OpIdx);
ASSERT_FALSE(DetachOpIdx);
ASSERT_FALSE(ShutdownOpIdx);
// Take a managed code call token. This should succeed.
auto Tok = TaskGroup::Token(S.managedCodeCallsGroup());
ASSERT_TRUE(Tok);
// We expect shutdown to wait for any active managed calls to complete.
bool ShutdownComplete = false;
S.shutdown([&]() { ShutdownComplete = true; });
// Detach should have happened, but shutdown should be waiting on token.
EXPECT_EQ(DetachOpIdx, 0U);
EXPECT_FALSE(ShutdownOpIdx);
EXPECT_FALSE(ShutdownComplete);
// The managed calls code group should have been closed. Assert that we
// can't get a new token.
ASSERT_FALSE(TaskGroup::Token(S.managedCodeCallsGroup()));
Tok = TaskGroup::Token(); // Reset token.
EXPECT_EQ(ShutdownOpIdx, 1U);
EXPECT_TRUE(ShutdownComplete);
}
static void managedSyncVoidFunction(int *P) { *P = 42; }
TEST(SessionTest, SyncCallManagedCodeVoidFn) {
// Test synchronous calls to a void function while holding a
// ManagedCodeCallsGroup token.
Session S(mockExecutorProcessInfo(), std::make_unique<NoDispatcher>(),
noErrors);
{
// Pre-shutdown we expect token acquisition to succeed and the function to
// run.
int X = 0;
bool CallSucceeded = S.callManagedCodeSync(managedSyncVoidFunction, &X);
EXPECT_TRUE(CallSucceeded);
EXPECT_EQ(X, 42U);
}
S.waitForShutdown();
{
// Post-shutdown we expect token acquisition to fail, and
// callManagedCodeSync to return false.
int X = 0;
bool CallSucceeded = S.callManagedCodeSync(managedSyncVoidFunction, &X);
EXPECT_FALSE(CallSucceeded);
}
}
static int managedSyncNonVoidFunction(int N) { return N + 1; }
TEST(SessionTest, SyncCallManagedCodeNonVoidFn) {
// Test synchronous calls to a non-void function while holding a
// ManagedCodeCallsGroup token.
Session S(mockExecutorProcessInfo(), std::make_unique<NoDispatcher>(),
noErrors);
{
// Pre-shutdown we expect token acquisition to succeed, the function to be
// run, and the result to be returned.
auto Result = S.callManagedCodeSync(managedSyncNonVoidFunction, 41);
EXPECT_TRUE(Result);
EXPECT_EQ(*Result, 42U);
}
S.waitForShutdown();
{
// Post-shutdown we expect token acquisition to fail, and
// callManagedCodeSync to return std::nullopt.
auto Result = S.callManagedCodeSync(managedSyncNonVoidFunction, 41);
EXPECT_EQ(Result, std::nullopt);
}
}
static void managedAsyncVoidFunction(move_only_function<void()> Return,
int *P) {
*P = 42;
Return();
}
TEST(SessionTest, AsyncCallManagedCodeVoidFn) {
// Test asynchronous calls to a void function while holding a
// ManagedCodeCallsGroup token.
Session S(mockExecutorProcessInfo(), std::make_unique<NoDispatcher>(),
noErrors);
{
// Pre-shutdown we expect token acquisition to succeed, and the function
// and Return callback to be run.
int X = 0;
bool ReturnSucceeded = false;
S.callManagedCodeAsync([&](bool B) { ReturnSucceeded = B; },
managedAsyncVoidFunction, &X);
EXPECT_TRUE(ReturnSucceeded);
EXPECT_EQ(X, 42U);
}
S.waitForShutdown();
{
// Post-shutdown we expect token acquisition to fail. Return should be
// with `false` and the function should not be called.
int X = 0;
bool ReturnSucceeded = false;
S.callManagedCodeAsync([&](bool B) { ReturnSucceeded = B; },
managedAsyncVoidFunction, &X);
EXPECT_FALSE(ReturnSucceeded);
EXPECT_EQ(X, 0U);
}
}
static void managedAsyncNonVoidFunction(move_only_function<void(int)> Return,
int *P) {
Return(++*P);
}
TEST(SessionTest, AsyncCallManagedCodeNonVoidFn) {
// Test asynchronous calls to a non-void function while holding a
// ManagedCodeCallsGroup token.
Session S(mockExecutorProcessInfo(), std::make_unique<NoDispatcher>(),
noErrors);
{
// Pre-shutdown we expect token acquisition to succeed, and the function
// and Return callback to be run.
int N = 41;
std::optional<int> Result;
S.callManagedCodeAsync([&](std::optional<int> N) { Result = N; },
managedAsyncNonVoidFunction, &N);
EXPECT_TRUE(Result);
EXPECT_EQ(*Result, 42U);
EXPECT_EQ(N, 42U);
}
S.waitForShutdown();
{
// Post-shutdown we expect token acquisition to fail. Return should be
// with `std::nullopt` and the function should not be called.
int N = 41;
std::optional<int> Result;
S.callManagedCodeAsync([&](std::optional<int> N) { Result = N; },
managedAsyncNonVoidFunction, &N);
EXPECT_EQ(Result, std::nullopt);
EXPECT_EQ(N, 41U);
}
}
TEST(SessionTest, AsyncCallManagedCodeHoldsTokenAcrossAsyncGap) {
// Verify that the ManagedCodeCallsGroup token is held until the async
// continuation runs, not just until callManagedCodeAsync returns. This
// ensures shutdown blocks for the duration of the actual async work.
Session S(mockExecutorProcessInfo(), std::make_unique<NoDispatcher>(),
noErrors);
size_t OpIdx = 0;
std::optional<size_t> DetachOpIdx;
std::optional<size_t> ShutdownOpIdx;
S.createService<MockService>(DetachOpIdx, ShutdownOpIdx, OpIdx);
// The managed code function stashes its continuation instead of calling it.
std::optional<int> Result;
move_only_function<void(int)> StashedContinuation;
S.callManagedCodeAsync([&](std::optional<int> N) { Result = std::move(N); },
[&](move_only_function<void(int)> Return, int N) {
// Stash the continuation and return without calling
// it.
StashedContinuation = std::move(Return);
},
41);
// callManagedCodeAsync has returned, but the continuation hasn't been
// called yet. The token should still be held inside StashedContinuation.
ASSERT_TRUE(StashedContinuation);
// Request shutdown. It should detach but block on the outstanding token.
bool ShutdownComplete = false;
S.shutdown([&]() { ShutdownComplete = true; });
EXPECT_EQ(DetachOpIdx, 0U);
EXPECT_FALSE(ShutdownOpIdx);
EXPECT_FALSE(ShutdownComplete);
// Now invoke the stashed continuation and then destroy it, releasing the
// token.
StashedContinuation(42);
StashedContinuation = {};
// Check result.
EXPECT_EQ(Result, 42);
// Shutdown should now have completed.
EXPECT_EQ(ShutdownOpIdx, 1U);
EXPECT_TRUE(ShutdownComplete);
}
TEST(SessionTest, AddServiceAndUseRef) {
Session S(mockExecutorProcessInfo(), std::make_unique<NoDispatcher>(),
noErrors);