//===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include #include #include #include #include #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/IR.h" #include "mlir-c/Interfaces.h" #include "mlir-c/Support.h" #include "mlir/Bindings/Python/IRCore.h" #include "mlir/Bindings/Python/IRInterfaces.h" namespace nb = nanobind; namespace mlir { namespace python { namespace MLIR_BINDINGS_PYTHON_DOMAIN { constexpr static const char *inferReturnTypesDoc = R"(Given the arguments required to build an operation, attempts to infer its return types. Raises ValueError on failure.)"; constexpr static const char *inferReturnTypeComponentsDoc = R"(Given the arguments required to build an operation, attempts to infer its return shaped type components. Raises ValueError on failure.)"; namespace { /// Takes in an optional ist of operands and converts them into a std::vector /// of MlirVlaues. Returns an empty std::vector if the list is empty. std::vector wrapOperands(std::optional operandList) { std::vector mlirOperands; if (!operandList || nb::len(*operandList) == 0) { return mlirOperands; } // Note: as the list may contain other lists this may not be final size. mlirOperands.reserve(nb::len(*operandList)); for (size_t i = 0, e = nb::len(*operandList); i < e; ++i) { nb::handle operand = (*operandList)[i]; intptr_t index = static_cast(i); if (operand.is_none()) continue; PyValue *val; try { val = nb::cast(operand); if (!val) throw nb::cast_error(); mlirOperands.push_back(val->get()); continue; } catch (nb::cast_error &err) { // Intentionally unhandled to try sequence below first. (void)err; } try { auto vals = nb::cast(operand); for (nb::handle v : vals) { try { val = nb::cast(v); if (!val) throw nb::cast_error(); mlirOperands.push_back(val->get()); } catch (nb::cast_error &err) { throw nb::value_error( nanobind::detail::join("Operand ", index, " must be a Value or Sequence of Values (", err.what(), ")") .c_str()); } } continue; } catch (nb::cast_error &err) { throw nb::value_error( nanobind::detail::join("Operand ", index, " must be a Value or Sequence of Values (", err.what(), ")") .c_str()); } throw nb::cast_error(); } return mlirOperands; } /// Takes in an optional vector of PyRegions and returns a std::vector of /// MlirRegion. Returns an empty std::vector if the list is empty. std::vector wrapRegions(std::optional> regions) { std::vector mlirRegions; if (regions) { mlirRegions.reserve(regions->size()); for (PyRegion ®ion : *regions) { mlirRegions.push_back(region); } } return mlirRegions; } } // namespace /// Python wrapper for InferTypeOpInterface. This interface has only static /// methods. class PyInferTypeOpInterface : public PyConcreteOpInterface { public: using PyConcreteOpInterface::PyConcreteOpInterface; constexpr static const char *pyClassName = "InferTypeOpInterface"; constexpr static GetTypeIDFunctionTy getInterfaceID = &mlirInferTypeOpInterfaceTypeID; /// C-style user-data structure for type appending callback. struct AppendResultsCallbackData { std::vector &inferredTypes; PyMlirContext &pyMlirContext; }; /// Appends the types provided as the two first arguments to the user-data /// structure (expects AppendResultsCallbackData). static void appendResultsCallback(intptr_t nTypes, MlirType *types, void *userData) { auto *data = static_cast(userData); data->inferredTypes.reserve(data->inferredTypes.size() + nTypes); for (intptr_t i = 0; i < nTypes; ++i) { data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]); } } /// Given the arguments required to build an operation, attempts to infer its /// return types. Throws value_error on failure. std::vector inferReturnTypes(std::optional operandList, std::optional attributes, void *properties, std::optional> regions, DefaultingPyMlirContext context, DefaultingPyLocation location) { std::vector mlirOperands = wrapOperands(std::move(operandList)); std::vector mlirRegions = wrapRegions(std::move(regions)); std::vector inferredTypes; PyMlirContext &pyContext = context.resolve(); AppendResultsCallbackData data{inferredTypes, pyContext}; MlirStringRef opNameRef = mlirStringRefCreate(getOpName().data(), getOpName().length()); MlirAttribute attributeDict = attributes ? attributes->get() : mlirAttributeGetNull(); MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes( opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), mlirOperands.data(), attributeDict, properties, mlirRegions.size(), mlirRegions.data(), &appendResultsCallback, &data); if (mlirLogicalResultIsFailure(result)) { throw nb::value_error("Failed to infer result types"); } return inferredTypes; } static void bindDerived(ClassTy &cls) { cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes, nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(), nb::arg("properties") = nb::none(), nb::arg("regions") = nb::none(), nb::arg("context") = nb::none(), nb::arg("loc") = nb::none(), inferReturnTypesDoc); } }; /// Wrapper around an shaped type components. class PyShapedTypeComponents { public: PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {} PyShapedTypeComponents(nb::list shape, MlirType elementType) : shape(std::move(shape)), elementType(elementType), ranked(true) {} PyShapedTypeComponents(nb::list shape, MlirType elementType, MlirAttribute attribute) : shape(std::move(shape)), elementType(elementType), attribute(attribute), ranked(true) {} PyShapedTypeComponents(PyShapedTypeComponents &) = delete; PyShapedTypeComponents(PyShapedTypeComponents &&other) noexcept : shape(other.shape), elementType(other.elementType), attribute(other.attribute), ranked(other.ranked) {} static void bind(nb::module_ &m) { nb::class_(m, "ShapedTypeComponents") .def_prop_ro( "element_type", [](PyShapedTypeComponents &self) { return self.elementType; }, nb::sig("def element_type(self) -> Type"), "Returns the element type of the shaped type components.") .def_static( "get", [](PyType &elementType) { return PyShapedTypeComponents(elementType); }, nb::arg("element_type"), "Create an shaped type components object with only the element " "type.") .def_static( "get", [](nb::typed shape, PyType &elementType) { return PyShapedTypeComponents(std::move(shape), elementType); }, nb::arg("shape"), nb::arg("element_type"), "Create a ranked shaped type components object.") .def_static( "get", [](nb::typed shape, PyType &elementType, PyAttribute &attribute) { return PyShapedTypeComponents(std::move(shape), elementType, attribute); }, nb::arg("shape"), nb::arg("element_type"), nb::arg("attribute"), "Create a ranked shaped type components object with attribute.") .def_prop_ro( "has_rank", [](PyShapedTypeComponents &self) -> bool { return self.ranked; }, "Returns whether the given shaped type component is ranked.") .def_prop_ro( "rank", [](PyShapedTypeComponents &self) -> std::optional { if (!self.ranked) return {}; return nb::int_(self.shape.size()); }, "Returns the rank of the given ranked shaped type components. If " "the shaped type components does not have a rank, None is " "returned.") .def_prop_ro( "shape", [](PyShapedTypeComponents &self) -> std::optional { if (!self.ranked) return {}; return nb::list(self.shape); }, "Returns the shape of the ranked shaped type components as a list " "of integers. Returns none if the shaped type component does not " "have a rank."); } nb::object getCapsule(); static PyShapedTypeComponents createFromCapsule(nb::object capsule); private: nb::list shape; MlirType elementType; MlirAttribute attribute; bool ranked{false}; }; /// Python wrapper for InferShapedTypeOpInterface. This interface has only /// static methods. class PyInferShapedTypeOpInterface : public PyConcreteOpInterface { public: using PyConcreteOpInterface< PyInferShapedTypeOpInterface>::PyConcreteOpInterface; constexpr static const char *pyClassName = "InferShapedTypeOpInterface"; constexpr static GetTypeIDFunctionTy getInterfaceID = &mlirInferShapedTypeOpInterfaceTypeID; /// C-style user-data structure for type appending callback. struct AppendResultsCallbackData { std::vector &inferredShapedTypeComponents; }; /// Appends the shaped type components provided as unpacked shape, element /// type, attribute to the user-data. static void appendResultsCallback(bool hasRank, intptr_t rank, const int64_t *shape, MlirType elementType, MlirAttribute attribute, void *userData) { auto *data = static_cast(userData); if (!hasRank) { data->inferredShapedTypeComponents.emplace_back(elementType); } else { nb::list shapeList; for (intptr_t i = 0; i < rank; ++i) { shapeList.append(shape[i]); } data->inferredShapedTypeComponents.emplace_back(shapeList, elementType, attribute); } } /// Given the arguments required to build an operation, attempts to infer the /// shaped type components. Throws value_error on failure. std::vector inferReturnTypeComponents( std::optional operandList, std::optional attributes, void *properties, std::optional> regions, DefaultingPyMlirContext context, DefaultingPyLocation location) { std::vector mlirOperands = wrapOperands(std::move(operandList)); std::vector mlirRegions = wrapRegions(std::move(regions)); std::vector inferredShapedTypeComponents; PyMlirContext &pyContext = context.resolve(); AppendResultsCallbackData data{inferredShapedTypeComponents}; MlirStringRef opNameRef = mlirStringRefCreate(getOpName().data(), getOpName().length()); MlirAttribute attributeDict = attributes ? attributes->get() : mlirAttributeGetNull(); MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes( opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), mlirOperands.data(), attributeDict, properties, mlirRegions.size(), mlirRegions.data(), &appendResultsCallback, &data); if (mlirLogicalResultIsFailure(result)) { throw nb::value_error("Failed to infer result shape type components"); } return inferredShapedTypeComponents; } static void bindDerived(ClassTy &cls) { cls.def("inferReturnTypeComponents", &PyInferShapedTypeOpInterface::inferReturnTypeComponents, nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(), nb::arg("regions") = nb::none(), nb::arg("properties") = nb::none(), nb::arg("context") = nb::none(), nb::arg("loc") = nb::none(), inferReturnTypeComponentsDoc); } }; /// Wrapper around the MemoryEffectsOpInterface. class PyMemoryEffectsOpInterface : public PyConcreteOpInterface { public: using PyConcreteOpInterface< PyMemoryEffectsOpInterface>::PyConcreteOpInterface; constexpr static const char *pyClassName = "MemoryEffectsOpInterface"; constexpr static GetTypeIDFunctionTy getInterfaceID = &mlirMemoryEffectsOpInterfaceTypeID; /// Attach a new MemoryEffectsOpInterface FallbackModel to the named /// operation. The FallbackModel acts as a trampoline for callbacks on the /// Python class. static void attach(nb::object &target, const std::string &opName, DefaultingPyMlirContext ctx) { MlirMemoryEffectsOpInterfaceCallbacks callbacks; callbacks.userData = target.ptr(); nb::handle(static_cast(callbacks.userData)).inc_ref(); callbacks.construct = nullptr; callbacks.destruct = [](void *userData) { nb::handle(static_cast(userData)).dec_ref(); }; callbacks.getEffects = [](MlirOperation op, MlirMemoryEffectInstancesList effects, void *userData) { nb::handle pyClass(static_cast(userData)); // Get the 'get_effects' method from the Python class. auto pyGetEffects = nb::cast(nb::getattr(pyClass, "get_effects")); PyMemoryEffectsInstanceList effectsWrapper{effects}; PyMlirContextRef context = PyMlirContext::forContext(mlirOperationGetContext(op)); auto opview = PyOperation::forOperation(context, op)->createOpView(); // Invoke `pyClass.get_effects(op, effects)`. pyGetEffects(opview, effectsWrapper); }; mlirMemoryEffectsOpInterfaceAttachFallbackModel( ctx->get(), mlirStringRefCreate(opName.c_str(), opName.size()), callbacks); } static void bindDerived(ClassTy &cls) { cls.attr("attach") = classmethod( [](const nb::object &cls, const nb::object &opName, nb::object target, DefaultingPyMlirContext context) { if (target.is_none()) target = cls; return attach(target, nb::cast(opName), context); }, nb::arg("cls"), nb::arg("op_name"), nb::kw_only(), nb::arg("target").none() = nb::none(), nb::arg("context").none() = nb::none(), "Attach the interface subclass to the given operation name."); } }; void populateIRInterfaces(nb::module_ &m) { nb::class_(m, "MemoryEffectInstancesList"); PyInferShapedTypeOpInterface::bind(m); PyInferTypeOpInterface::bind(m); PyMemoryEffectsOpInterface::bind(m); PyShapedTypeComponents::bind(m); } } // namespace MLIR_BINDINGS_PYTHON_DOMAIN } // namespace python } // namespace mlir