//===-- SPSWrapperFunctionTest.cpp ----------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Test SPSWrapperFunction and associated utilities. // //===----------------------------------------------------------------------===// #include "CommonTestUtils.h" #include "orc-rt/SPSWrapperFunction.h" #include "orc-rt/WrapperFunction.h" #include "orc-rt/move_only_function.h" #include "DirectCaller.h" #include "gtest/gtest.h" static void add_via_function(orc_rt::move_only_function Return, int32_t X, int32_t Y) { Return(X + Y); } // Note: This macro use has been deliberately moved above the // "using namespace orc_rt;" statement below to check that its expansion works // from other namespaces. ORC_RT_SPS_WRAPPER(add_via_function_sps_wrapper, int32_t(int32_t, int32_t), add_via_function); using namespace orc_rt; static void void_noop_sps_wrapper(orc_rt_SessionRef S, uint64_t CallId, orc_rt_WrapperFunctionReturn Return, orc_rt_WrapperFunctionBuffer ArgBytes) { SPSWrapperFunction::handle( S, CallId, Return, ArgBytes, [](move_only_function Return) { Return(); }); } TEST(SPSWrapperFunctionUtilsTest, VoidNoop) { bool Ran = false; SPSWrapperFunction::call(DirectCaller(nullptr, void_noop_sps_wrapper), [&](Error Err) { cantFail(std::move(Err)); Ran = true; }); EXPECT_TRUE(Ran); } static void add_via_lambda_sps_wrapper(orc_rt_SessionRef S, uint64_t CallId, orc_rt_WrapperFunctionReturn Return, orc_rt_WrapperFunctionBuffer ArgBytes) { SPSWrapperFunction::handle( S, CallId, Return, ArgBytes, [](move_only_function Return, int32_t X, int32_t Y) { Return(X + Y); }); } TEST(SPSWrapperFunctionUtilsTest, BinaryOpViaLambda) { int32_t Result = 0; SPSWrapperFunction::call( DirectCaller(nullptr, add_via_lambda_sps_wrapper), [&](Expected R) { Result = cantFail(std::move(R)); }, 41, 1); EXPECT_EQ(Result, 42); } TEST(SPSWrapperFunctionUtilsTest, BinaryOpViaFunction) { int32_t Result = 0; SPSWrapperFunction::call( DirectCaller(nullptr, add_via_function_sps_wrapper), [&](Expected R) { Result = cantFail(std::move(R)); }, 41, 1); EXPECT_EQ(Result, 42); } static void add_via_function_pointer_sps_wrapper(orc_rt_SessionRef S, uint64_t CallId, orc_rt_WrapperFunctionReturn Return, orc_rt_WrapperFunctionBuffer ArgBytes) { SPSWrapperFunction::handle( S, CallId, Return, ArgBytes, &add_via_function); } TEST(SPSWrapperFunctionUtilsTest, BinaryOpViaFunctionPointer) { int32_t Result = 0; SPSWrapperFunction::call( DirectCaller(nullptr, add_via_function_pointer_sps_wrapper), [&](Expected R) { Result = cantFail(std::move(R)); }, 41, 1); EXPECT_EQ(Result, 42); } static void round_trip_string_via_span_sps_wrapper(orc_rt_SessionRef S, uint64_t CallId, orc_rt_WrapperFunctionReturn Return, orc_rt_WrapperFunctionBuffer ArgBytes) { SPSWrapperFunction::handle( S, CallId, Return, ArgBytes, [](move_only_function Return, span S) { Return({S.data(), S.size()}); }); } TEST(SPSWrapperFunctionUtilsTest, RoundTripStringViaSpan) { /// Test that the SPSWrapperFunction<...>::handle call in /// round_trip_string_via_span_sps_wrapper can deserialize into a usable /// span. std::string Result; SPSWrapperFunction::call( DirectCaller(nullptr, round_trip_string_via_span_sps_wrapper), [&](Expected R) { Result = cantFail(std::move(R)); }, std::string_view("hello, world!")); EXPECT_EQ(Result, "hello, world!"); } static void improbable_feat_sps_wrapper(orc_rt_SessionRef S, uint64_t CallId, orc_rt_WrapperFunctionReturn Return, orc_rt_WrapperFunctionBuffer ArgBytes) { SPSWrapperFunction::handle( S, CallId, Return, ArgBytes, [](move_only_function Return, bool LuckyHat) { if (LuckyHat) Return(Error::success()); else Return(make_error("crushed by boulder")); }); } TEST(SPSWrapperFunctionUtilsTest, TransparentConversionErrorSuccessCase) { bool DidRun = false; SPSWrapperFunction::call( DirectCaller(nullptr, improbable_feat_sps_wrapper), [&](Expected E) { DidRun = true; cantFail(cantFail(std::move(E))); }, true); EXPECT_TRUE(DidRun); } TEST(SPSWrapperFunctionUtilsTest, TransparentConversionErrorFailureCase) { std::string ErrMsg; SPSWrapperFunction::call( DirectCaller(nullptr, improbable_feat_sps_wrapper), [&](Expected E) { ErrMsg = toString(cantFail(std::move(E))); }, false); EXPECT_EQ(ErrMsg, "crushed by boulder"); } static void halve_number_sps_wrapper(orc_rt_SessionRef S, uint64_t CallId, orc_rt_WrapperFunctionReturn Return, orc_rt_WrapperFunctionBuffer ArgBytes) { SPSWrapperFunction(int32_t)>::handle( S, CallId, Return, ArgBytes, [](move_only_function)> Return, int N) { if (N % 2 == 0) Return(N >> 1); else Return(make_error("N is not a multiple of 2")); }); } TEST(SPSWrapperFunctionUtilsTest, TransparentConversionExpectedSuccessCase) { int32_t Result = 0; SPSWrapperFunction(int32_t)>::call( DirectCaller(nullptr, halve_number_sps_wrapper), [&](Expected> R) { Result = cantFail(cantFail(std::move(R))); }, 2); EXPECT_EQ(Result, 1); } TEST(SPSWrapperFunctionUtilsTest, TransparentConversionExpectedFailureCase) { std::string ErrMsg; SPSWrapperFunction(int32_t)>::call( DirectCaller(nullptr, halve_number_sps_wrapper), [&](Expected> R) { ErrMsg = toString(cantFail(std::move(R)).takeError()); }, 3); EXPECT_EQ(ErrMsg, "N is not a multiple of 2"); } template struct SPSOpCounter {}; namespace orc_rt { template class SPSSerializationTraits, OpCounter> { public: static size_t size(const OpCounter &O) { return 0; } static bool serialize(SPSOutputBuffer &OB, const OpCounter &O) { return true; } static bool deserialize(SPSInputBuffer &OB, OpCounter &O) { return true; } }; } // namespace orc_rt static void handle_with_reference_types_sps_wrapper(orc_rt_SessionRef S, uint64_t CallId, orc_rt_WrapperFunctionReturn Return, orc_rt_WrapperFunctionBuffer ArgBytes) { SPSWrapperFunction, SPSOpCounter<1>, SPSOpCounter<2>, SPSOpCounter<3>)>::handle(S, CallId, Return, ArgBytes, [](move_only_function Return, OpCounter<0>, OpCounter<1> &, const OpCounter<2> &, OpCounter<3> &&) { Return(); }); } TEST(SPSWrapperFunctionUtilsTest, HandlerWithReferences) { // Test that we can handle by-value, by-ref, by-const-ref, and by-rvalue-ref // arguments, and that we generate the expected number of moves. OpCounter<0>::reset(); OpCounter<1>::reset(); OpCounter<2>::reset(); OpCounter<3>::reset(); bool DidRun = false; SPSWrapperFunction, SPSOpCounter<1>, SPSOpCounter<2>, SPSOpCounter<3>)>:: call( DirectCaller(nullptr, handle_with_reference_types_sps_wrapper), [&](Error R) { cantFail(std::move(R)); DidRun = true; }, OpCounter<0>(), OpCounter<1>(), OpCounter<2>(), OpCounter<3>()); EXPECT_TRUE(DidRun); // We expect two default constructions for each parameter: one for the // argument to call, and one for the object to deserialize into. EXPECT_EQ(OpCounter<0>::defaultConstructions(), 2U); EXPECT_EQ(OpCounter<1>::defaultConstructions(), 2U); EXPECT_EQ(OpCounter<2>::defaultConstructions(), 2U); EXPECT_EQ(OpCounter<3>::defaultConstructions(), 2U); // Pass-by-value: we expect two moves (one for SPS transparent conversion, // one to copy the value to the parameter), and no copies. EXPECT_EQ(OpCounter<0>::moves(), 2U); EXPECT_EQ(OpCounter<0>::copies(), 0U); // Pass-by-lvalue-reference: we expect one move (for SPS transparent // conversion), no copies. EXPECT_EQ(OpCounter<1>::moves(), 1U); EXPECT_EQ(OpCounter<1>::copies(), 0U); // Pass-by-const-lvalue-reference: we expect one move (for SPS transparent // conversion), no copies. EXPECT_EQ(OpCounter<2>::moves(), 1U); EXPECT_EQ(OpCounter<2>::copies(), 0U); // Pass-by-rvalue-reference: we expect one move (for SPS transparent // conversion), no copies. EXPECT_EQ(OpCounter<3>::moves(), 1U); EXPECT_EQ(OpCounter<3>::copies(), 0U); } namespace { class Adder { public: int32_t addSync(int32_t X, int32_t Y) { return X + Y; } void addAsync(move_only_function Return, int32_t X, int32_t Y) { Return(addSync(X, Y)); } }; } // anonymous namespace static void adder_add_async_sps_wrapper(orc_rt_SessionRef S, uint64_t CallId, orc_rt_WrapperFunctionReturn Return, orc_rt_WrapperFunctionBuffer ArgBytes) { SPSWrapperFunction::handle( S, CallId, Return, ArgBytes, WrapperFunction::handleWithAsyncMethod(&Adder::addAsync)); } TEST(SPSWrapperFunctionUtilsTest, HandleWtihAsyncMethod) { auto A = std::make_unique(); int32_t Result = 0; SPSWrapperFunction::call( DirectCaller(nullptr, adder_add_async_sps_wrapper), [&](Expected R) { Result = cantFail(std::move(R)); }, ExecutorAddr::fromPtr(A.get()), 41, 1); EXPECT_EQ(Result, 42); } static void adder_add_sync_sps_wrapper(orc_rt_SessionRef S, uint64_t CallId, orc_rt_WrapperFunctionReturn Return, orc_rt_WrapperFunctionBuffer ArgBytes) { SPSWrapperFunction::handle( S, CallId, Return, ArgBytes, WrapperFunction::handleWithSyncMethod(&Adder::addSync)); } TEST(SPSWrapperFunctionUtilsTest, HandleWithSyncMethod) { auto A = std::make_unique(); int32_t Result = 0; SPSWrapperFunction::call( DirectCaller(nullptr, adder_add_sync_sps_wrapper), [&](Expected R) { Result = cantFail(std::move(R)); }, ExecutorAddr::fromPtr(A.get()), 41, 1); EXPECT_EQ(Result, 42); }