Files
llvm-project/mlir/lib/Bindings/Python/IRAttributes.cpp
Sergei Lebedev 2c0e63c9b6 [MLIR] [Python] Relaxed list to Sequence in most parameter types (#188543)
Using `Sequence` frees users from the need to cast to `list` in cases
where the underlying API does not really care about the type of the
container.

Note that accepting an `nb::sequence` is marginally slower than
accepting `nb::list` directly, because `__getitem__`, `__len__` etc need
to go through an extra layer of indirection. However, I expect the
performance difference to be negligible.
2026-03-25 19:15:38 +00:00

1546 lines
59 KiB
C++

//===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/ExtensibleDialect.h"
#include "mlir/Bindings/Python/IRAttributes.h"
#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/NanobindUtils.h"
namespace nb = nanobind;
using namespace nanobind::literals;
using namespace mlir;
using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
//------------------------------------------------------------------------------
// Docstrings (trivial, non-duplicated docstrings are included inline).
//------------------------------------------------------------------------------
/// Python docstring attached to the buffer-based `get` overload that
/// bindFactoryMethods() registers for dense elements attribute classes.
/// The text is emitted verbatim at runtime, so it is data, not a comment.
static const char kDenseElementsAttrGetDocstring[] =
R"(Gets a DenseElementsAttr from a Python buffer or array.
When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.
For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:
* Integer types: the buffer must be byte aligned to the next byte boundary.
* Floating point types: Must be bit-castable to the given floating point
size.
* i1 (bool): Each boolean value is stored as a single byte (0 or 1).
If a single element buffer is passed, then a splat will be created.
Args:
array: The array or buffer to convert.
signless: If inferring an appropriate MLIR type, use signless types for
integers (defaults True).
type: Skips inference of the MLIR element type and uses this instead. The
storage size must be consistent with the actual contents of the buffer.
shape: Overrides the shape of the buffer when constructing the MLIR
shaped type. This is needed when the physical and logical shape differ.
context: Explicit context, if not from context manager.
Returns:
DenseElementsAttr on success.
Raises:
ValueError: If the type of the buffer or array cannot be matched to an MLIR
type or if the buffer does not meet expectations.
)";
/// Python docstring attached to the attribute-sequence `get` overload that
/// bindFactoryMethods() registers for dense elements attribute classes.
/// The Raises section names the `type` parameter, matching the Args section
/// and the bound keyword argument.
static const char kDenseElementsAttrGetFromListDocstring[] =
R"(Gets a DenseElementsAttr from a Python list of attributes.
Note that it can be expensive to construct attributes individually.
For a large number of elements, consider using a Python buffer or array instead.
Args:
attrs: A list of attributes.
type: The desired shape and type of the resulting DenseElementsAttr.
If not provided, the element type is determined based on the type
of the 0th attribute and the shape is `[len(attrs)]`.
context: Explicit context, if not from context manager.
Returns:
DenseElementsAttr on success.
Raises:
ValueError: If the type of the attributes does not match the type
specified by `type`.
)";
/// Python docstring describing construction of a DenseResourceElementsAttr
/// from a buffer. The binding that consumes it is not in this part of the
/// file. The text is emitted verbatim at runtime, so it is data, not a comment.
static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
R"(Gets a DenseResourceElementsAttr from a Python buffer or array.
This function does minimal validation or massaging of the data, and it is
up to the caller to ensure that the buffer meets the characteristics
implied by the shape.
The backing buffer and any user objects will be retained for the lifetime
of the resource blob. This is typically bounded to the context but the
resource can have a shorter lifespan depending on how it is used in
subsequent processing.
Args:
buffer: The array or buffer to convert.
name: Name to provide to the resource (may be changed upon collision).
type: The explicit ShapedType to construct the attribute with.
context: Explicit context, if not from context manager.
Returns:
DenseResourceElementsAttr on success.
Raises:
ValueError: If the type of the buffer or array cannot be matched to an MLIR
type or if the buffer does not meet expectations.
)";
namespace {
/// Minimal scope guard modelled on llvm::scope_exit: the stored callable runs
/// when the guard is destroyed, unless the guard was release()d or moved from.
template <typename Callable>
class [[nodiscard]] scope_exit {
  Callable onExit;
  // Cleared by release() and on the source of a move, so the callable fires
  // at most once.
  bool armed = true;

public:
  /// Stores `F` (perfectly forwarded) as the action to run at scope exit.
  template <typename Fp>
  explicit scope_exit(Fp &&F) : onExit(std::forward<Fp>(F)) {}

  /// Takes over the pending action of `Rhs`; `Rhs` is disarmed.
  scope_exit(scope_exit &&Rhs)
      : onExit(std::move(Rhs.onExit)),
        armed(std::exchange(Rhs.armed, false)) {}

  scope_exit(const scope_exit &) = delete;
  scope_exit &operator=(scope_exit &&) = delete;
  scope_exit &operator=(const scope_exit &) = delete;

  /// Cancels the pending action.
  void release() { armed = false; }

  ~scope_exit() {
    if (!armed)
      return;
    onExit();
  }
};
// Deduction guide so that `scope_exit guard(lambda);` works without spelling
// out the callable type.
template <typename Callable>
scope_exit(Callable) -> scope_exit<Callable>;
} // namespace
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// Builds a buffer description from explicit fields. `shape_in`,
/// `strides_in` and `owned_view_in` are moved into the object; `size` is then
/// derived as the product of the first `ndim` shape extents.
nb_buffer_info::nb_buffer_info(
    void *ptr, Py_ssize_t itemsize, const char *format, Py_ssize_t ndim,
    std::vector<Py_ssize_t> shape_in, std::vector<Py_ssize_t> strides_in,
    bool readonly,
    std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view_in)
    : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim),
      shape(std::move(shape_in)), strides(std::move(strides_in)),
      readonly(readonly), owned_view(std::move(owned_view_in)) {
  // Element count: multiply the extents together, starting from 1 so that a
  // rank-0 buffer reports a single element.
  size = 1;
  Py_ssize_t dim = 0;
  while (dim < ndim) {
    size *= shape[dim];
    ++dim;
  }
}
/// Acquires a strided, format-carrying buffer view of the wrapped object.
/// Throws nb::python_error if the object does not export such a buffer.
nb_buffer_info nb_buffer::request() const {
  const int flags = PyBUF_STRIDES | PyBUF_FORMAT;
  // Hold the heap-allocated Py_buffer in a unique_ptr so that the failure
  // path frees it automatically.
  auto view = std::make_unique<Py_buffer>();
  if (PyObject_GetBuffer(ptr(), view.get(), flags) != 0)
    throw nb::python_error();
  // The raw pointer is handed to nb_buffer_info exactly as before; that
  // constructor is declared elsewhere (presumably it takes ownership).
  return nb_buffer_info(view.release());
}
// Specializations of nb_format_descriptor mapping a C++ scalar type to a
// one-character format string. The characters used here are the same ones
// Python's `struct` module uses for the corresponding native types.
// bool -> "?"
template <>
struct nb_format_descriptor<bool> {
static const char *format() { return "?"; }
};
// int8_t -> "b"
template <>
struct nb_format_descriptor<int8_t> {
static const char *format() { return "b"; }
};
// uint8_t -> "B"
template <>
struct nb_format_descriptor<uint8_t> {
static const char *format() { return "B"; }
};
// int16_t -> "h"
template <>
struct nb_format_descriptor<int16_t> {
static const char *format() { return "h"; }
};
// uint16_t -> "H"
template <>
struct nb_format_descriptor<uint16_t> {
static const char *format() { return "H"; }
};
// int32_t -> "i"
template <>
struct nb_format_descriptor<int32_t> {
static const char *format() { return "i"; }
};
// uint32_t -> "I"
template <>
struct nb_format_descriptor<uint32_t> {
static const char *format() { return "I"; }
};
// int64_t -> "q"
template <>
struct nb_format_descriptor<int64_t> {
static const char *format() { return "q"; }
};
// uint64_t -> "Q"
template <>
struct nb_format_descriptor<uint64_t> {
static const char *format() { return "Q"; }
};
// float -> "f"
template <>
struct nb_format_descriptor<float> {
static const char *format() { return "f"; }
};
// double -> "d"
template <>
struct nb_format_descriptor<double> {
static const char *format() { return "d"; }
};
/// Registers the Python-visible members of the AffineMap attribute class:
/// a static `get` factory and a read-only `value` property.
void PyAffineMapAttribute::bindDerived(ClassTy &c) {
  // Factory: wraps the given affine map in an attribute that shares the
  // map's context.
  auto makeFromMap = [](PyAffineMap &affineMap) {
    MlirAttribute wrapped = mlirAffineMapAttrGet(affineMap.get());
    return PyAffineMapAttribute(affineMap.getContext(), wrapped);
  };
  c.def_static("get", makeFromMap, nb::arg("affine_map"),
               "Gets an attribute wrapping an AffineMap.");

  // Accessor: unwraps the affine map held by the attribute.
  auto readMap = [](PyAffineMapAttribute &self) {
    return PyAffineMap(self.getContext(), mlirAffineMapAttrGetValue(self));
  };
  c.def_prop_ro("value", readMap,
                "Returns the value of the AffineMap attribute");
}
/// Registers the static `get` factory of the IntegerSet attribute class.
void PyIntegerSetAttribute::bindDerived(ClassTy &c) {
  // Factory: wraps the given integer set in an attribute that shares the
  // set's context.
  auto makeFromSet = [](PyIntegerSet &integerSet) {
    MlirAttribute wrapped = mlirIntegerSetAttrGet(integerSet.get());
    return PyIntegerSetAttribute(integerSet.getContext(), wrapped);
  };
  c.def_static("get", makeFromSet, nb::arg("integer_set"),
               "Gets an attribute wrapping an IntegerSet.");
}
/// `__next__` implementation: returns the element at the current position
/// (down-cast to its most specific Python class) and advances, or sets
/// StopIteration and returns a null object once the array is exhausted.
nb::typed<nb::object, PyAttribute>
PyArrayAttribute::PyArrayAttributeIterator::dunderNext() {
  // TODO: Throw is an inefficient way to stop iteration.
  const bool exhausted =
      nextIndex >= mlirArrayAttrGetNumElements(attr.get());
  if (exhausted) {
    PyErr_SetNone(PyExc_StopIteration);
    // python functions should return NULL after setting any exception
    return nb::object();
  }
  // Fetch the current element, then step the cursor forward.
  MlirAttribute element = mlirArrayAttrGetElement(attr.get(), nextIndex++);
  return PyAttribute(attr.getContext(), element).maybeDownCast();
}
/// Registers the iterator class in module `m` under the Python name
/// "ArrayAttributeIterator" with its `__iter__` and `__next__` methods.
void PyArrayAttribute::PyArrayAttributeIterator::bind(nb::module_ &m) {
  nb::class_<PyArrayAttributeIterator> iteratorClass(
      m, "ArrayAttributeIterator");
  iteratorClass.def("__iter__", &PyArrayAttributeIterator::dunderIter);
  iteratorClass.def("__next__", &PyArrayAttributeIterator::dunderNext);
}
/// Returns the raw element at position `i`. No bounds check is performed
/// here; callers validate the index first.
MlirAttribute PyArrayAttribute::getItem(intptr_t i) const {
  MlirAttribute element = mlirArrayAttrGetElement(*this, i);
  return element;
}
/// Registers the Python-visible members of the Array attribute class:
///   * `get(attributes, context=None)`: static factory from a sequence.
///   * `__getitem__`, `__len__`, `__iter__`: sequence-style access.
///   * `__add__`: concatenation with a sequence of attributes.
void PyArrayAttribute::bindDerived(ClassTy &c) {
  c.def_static(
      "get",
      [](nb::typed<nb::sequence, PyAttribute> attributes,
         DefaultingPyMlirContext context) {
        std::vector<MlirAttribute> mlirAttributes;
        mlirAttributes.reserve(nb::len(attributes));
        for (nb::handle attribute : attributes)
          mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
        MlirAttribute attr = mlirArrayAttrGet(
            context->get(), mlirAttributes.size(), mlirAttributes.data());
        return PyArrayAttribute(context->getRef(), attr);
      },
      nb::arg("attributes"), nb::arg("context") = nb::none(),
      "Gets a uniqued Array attribute");
  c.def("__getitem__",
        [](PyArrayAttribute &arr,
           intptr_t i) -> nb::typed<nb::object, PyAttribute> {
          const intptr_t numElements = mlirArrayAttrGetNumElements(arr);
          // Follow Python sequence conventions: a negative index counts from
          // the end. Without this adjustment a negative value would be
          // forwarded unchecked to the element accessor.
          if (i < 0)
            i += numElements;
          if (i < 0 || i >= numElements)
            throw nb::index_error("ArrayAttribute index out of range");
          return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast();
        })
      .def("__len__",
           [](const PyArrayAttribute &arr) {
             return mlirArrayAttrGetNumElements(arr);
           })
      .def("__iter__", [](const PyArrayAttribute &arr) {
        return PyArrayAttributeIterator(arr);
      });
  c.def("__add__", [](PyArrayAttribute arr,
                      nb::typed<nb::sequence, PyAttribute> extras) {
    // Build a new array: the existing elements followed by `extras`.
    std::vector<MlirAttribute> attributes;
    intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
    attributes.reserve(numOldElements + nb::len(extras));
    for (intptr_t i = 0; i < numOldElements; ++i)
      attributes.push_back(arr.getItem(i));
    for (nb::handle attr : extras)
      attributes.push_back(pyTryCast<PyAttribute>(attr));
    MlirAttribute arrayAttr = mlirArrayAttrGet(
        arr.getContext()->get(), attributes.size(), attributes.data());
    return PyArrayAttribute(arr.getContext(), arrayAttr);
  });
}
/// Registers the Python-visible members of the Float attribute class:
/// the `get`, `get_unchecked`, `get_f32` and `get_f64` factories, the
/// read-only `value` property and `__float__`.
void PyFloatAttribute::bindDerived(ClassTy &c) {
  // Checked factory: a null result is reported as an MLIRError carrying
  // whatever the ErrorCapture collected.
  auto getChecked = [](PyType &type, double value, DefaultingPyLocation loc) {
    PyMlirContext::ErrorCapture errors(loc->getContext());
    MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
    if (!mlirAttributeIsNull(attr))
      return PyFloatAttribute(type.getContext(), attr);
    throw MLIRError("Invalid attribute", errors.take());
  };
  c.def_static("get", getChecked, nb::arg("type"), nb::arg("value"),
               nb::arg("loc") = nb::none(),
               "Gets an uniqued float point attribute associated to a type");

  // Unchecked factory: same null handling, but built through the context
  // rather than a location.
  auto getUnchecked = [](PyType &type, double value,
                         DefaultingPyMlirContext context) {
    PyMlirContext::ErrorCapture errors(context->getRef());
    MlirAttribute attr =
        mlirFloatAttrDoubleGet(context.get()->get(), type, value);
    if (!mlirAttributeIsNull(attr))
      return PyFloatAttribute(type.getContext(), attr);
    throw MLIRError("Invalid attribute", errors.take());
  };
  c.def_static("get_unchecked", getUnchecked, nb::arg("type"),
               nb::arg("value"), nb::arg("context") = nb::none(),
               "Gets an uniqued float point attribute associated to a type");

  // Convenience factory for f32.
  auto getF32 = [](double value, DefaultingPyMlirContext context) {
    MlirContext ctx = context->get();
    MlirAttribute attr =
        mlirFloatAttrDoubleGet(ctx, mlirF32TypeGet(context->get()), value);
    return PyFloatAttribute(context->getRef(), attr);
  };
  c.def_static(
      "get_f32", getF32, nb::arg("value"), nb::arg("context") = nb::none(),
      "Gets an uniqued float point attribute associated to a f32 type");

  // Convenience factory for f64.
  auto getF64 = [](double value, DefaultingPyMlirContext context) {
    MlirContext ctx = context->get();
    MlirAttribute attr =
        mlirFloatAttrDoubleGet(ctx, mlirF64TypeGet(context->get()), value);
    return PyFloatAttribute(context->getRef(), attr);
  };
  c.def_static(
      "get_f64", getF64, nb::arg("value"), nb::arg("context") = nb::none(),
      "Gets an uniqued float point attribute associated to a f64 type");

  // Value access, both as a property and through float(attr).
  c.def_prop_ro("value", mlirFloatAttrGetValueDouble,
                "Returns the value of the float attribute");
  c.def("__float__", mlirFloatAttrGetValueDouble,
        "Converts the value of the float attribute to a Python float");
}
/// Registers the Python-visible members of the Integer attribute class:
/// the `get` factory, the `value` property, `__int__`, and the static
/// `static_typeid` property.
void PyIntegerAttribute::bindDerived(ClassTy &c) {
  auto makeInteger = [](PyType &type, nb::object value) {
    // IndexType has neither a bit width nor a signedness, so its width is
    // never queried; it always goes through the 64-bit path below.
    const bool isIndex = mlirTypeIsAIndex(type);
    const unsigned bitWidth = isIndex ? 0u : mlirIntegerTypeGetWidth(type);

    // Fast path: the value fits a single int64_t.
    if (isIndex || bitWidth <= 64) {
      int64_t narrowValue = nb::cast<int64_t>(value);
      MlirAttribute attr = mlirIntegerAttrGet(type, narrowValue);
      return PyIntegerAttribute(type.getContext(), attr);
    }

    // Wide path: split the Python int into 64-bit words, least significant
    // word first. The word count is ceil(bitWidth / 64).
    const unsigned numWords = (bitWidth + 63u) / 64u;
    std::vector<uint64_t> words(numWords, 0);
    nb::object mask = nb::int_(0xFFFFFFFFFFFFFFFFULL);
    nb::object shift = nb::int_(64);
    nb::object current = value;

    // Negative values of signed types are first mapped to their two's
    // complement representation by adding 2^bitWidth.
    if (mlirIntegerTypeIsSigned(type)) {
      nb::object zero = nb::int_(0);
      if (nb::cast<bool>(current < zero)) {
        nb::object twoToTheBitWidth = nb::int_(1) << nb::int_(bitWidth);
        current = current + twoToTheBitWidth;
      }
    }

    // Peel off the low 64 bits, then shift the remainder down.
    for (uint64_t &word : words) {
      word = nb::cast<uint64_t>(current & mask);
      current = current >> shift;
    }

    MlirAttribute attr =
        mlirIntegerAttrGetFromWords(type, numWords, words.data());
    return PyIntegerAttribute(type.getContext(), attr);
  };
  c.def_static("get", makeInteger, nb::arg("type"), nb::arg("value"),
               "Gets an uniqued integer attribute associated to a type");

  // Value access, both as a property and through int(attr).
  c.def_prop_ro("value", toPyInt, "Returns the value of the integer attribute");
  c.def("__int__", toPyInt,
        "Converts the value of the integer attribute to a Python int");

  c.def_prop_ro_static("static_typeid", [](nb::object & /*class*/) {
    return PyTypeID(mlirIntegerAttrGetTypeID());
  });
}
/// Converts the attribute's value to a Python int. Values of up to 64 bits
/// are read through the signedness-appropriate scalar accessor; wider values
/// are rebuilt from their 64-bit words.
nb::int_ PyIntegerAttribute::toPyInt(PyIntegerAttribute &self) {
  MlirType type = mlirAttributeGetType(self);
  const unsigned bitWidth = mlirIntegerAttrGetValueBitWidth(self);

  // Fast path: at most 64 bits.
  if (bitWidth <= 64) {
    const bool readAsSignless =
        mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type);
    if (readAsSignless)
      return nb::int_(mlirIntegerAttrGetValueInt(self));
    return mlirIntegerTypeIsSigned(type)
               ? nb::int_(mlirIntegerAttrGetValueSInt(self))
               : nb::int_(mlirIntegerAttrGetValueUInt(self));
  }

  // Wide path: fetch the raw words (least significant first).
  const unsigned numWords = mlirIntegerAttrGetValueNumWords(self);
  std::vector<uint64_t> words(numWords);
  mlirIntegerAttrGetValueWords(self, words.data());

  // Accumulate from the most significant word down: shift the running value
  // left by 64, then OR in the next word.
  nb::object result = nb::int_(0);
  nb::object shift = nb::int_(64);
  for (auto word = words.rbegin(); word != words.rend(); ++word) {
    result = result << shift;
    result = result | nb::int_(*word);
  }

  // Signed types: a set sign bit means the true value is result - 2^bitWidth.
  if (mlirIntegerTypeIsSigned(type)) {
    // The sign bit is bit (bitWidth - 1), which lives in the last word.
    const unsigned signBitInWord = (bitWidth - 1) % 64;
    const bool signBitSet = (words[numWords - 1] >> signBitInWord) & 1;
    if (signBitSet) {
      nb::object twoToTheBitWidth = nb::int_(1) << nb::int_(bitWidth);
      result = result - twoToTheBitWidth;
    }
  }
  return nb::cast<nb::int_>(result);
}
/// Registers the Python-visible members of the Bool attribute class: the
/// `get` factory, the `value` property and `__bool__`.
void PyBoolAttribute::bindDerived(ClassTy &c) {
  // Factory: builds the attribute in the resolved context.
  auto makeBool = [](bool value, DefaultingPyMlirContext context) {
    MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
    return PyBoolAttribute(context->getRef(), attr);
  };
  c.def_static("get", makeBool, nb::arg("value"),
               nb::arg("context") = nb::none(),
               "Gets an uniqued bool attribute");

  // Value access, both as a property and through bool(attr).
  c.def_prop_ro("value", mlirBoolAttrGetValue,
                "Returns the value of the bool attribute");
  c.def("__bool__", mlirBoolAttrGetValue,
        "Converts the value of the bool attribute to a Python bool");
}
/// Builds a SymbolRef attribute from a list of symbol names: the first name
/// is the root reference and every following name becomes a nested flat
/// symbol reference.
/// Throws std::runtime_error when `symbols` is empty.
PySymbolRefAttribute
PySymbolRefAttribute::fromList(const std::vector<std::string> &symbols,
                               PyMlirContext &context) {
  if (symbols.empty())
    throw std::runtime_error("SymbolRefAttr must be composed of at least "
                             "one symbol.");

  // Everything after the root is turned into a FlatSymbolRef attribute.
  std::vector<MlirAttribute> nestedRefs;
  for (auto name = symbols.begin() + 1, last = symbols.end(); name != last;
       ++name) {
    nestedRefs.push_back(
        mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(*name)));
  }

  MlirStringRef root = toMlirStringRef(symbols.front());
  MlirAttribute symbolRef = mlirSymbolRefAttrGet(
      context.get(), root, nestedRefs.size(), nestedRefs.data());
  return PySymbolRefAttribute(context.getRef(), symbolRef);
}
/// Registers the Python-visible members of the SymbolRef attribute class:
/// the `get` factory and the read-only `value` property.
void PySymbolRefAttribute::bindDerived(ClassTy &c) {
  // Factory: delegates to fromList() with the resolved context.
  auto makeFromNames = [](const std::vector<std::string> &symbols,
                          DefaultingPyMlirContext context) {
    return PySymbolRefAttribute::fromList(symbols, context.resolve());
  };
  c.def_static(
      "get", makeFromNames, nb::arg("symbols"),
      nb::arg("context") = nb::none(),
      "Gets a uniqued SymbolRef attribute from a list of symbol names");

  // Accessor: the root reference followed by each nested reference's name.
  auto readNames = [](PySymbolRefAttribute &self) {
    const intptr_t nestedCount = mlirSymbolRefAttrGetNumNestedReferences(self);
    std::vector<MlirStringRef> names;
    names.reserve(nestedCount + 1);
    names.push_back(mlirSymbolRefAttrGetRootReference(self));
    intptr_t idx = 0;
    while (idx < nestedCount) {
      MlirAttribute nested = mlirSymbolRefAttrGetNestedReference(self, idx);
      names.push_back(mlirSymbolRefAttrGetRootReference(nested));
      ++idx;
    }
    return names;
  };
  c.def_prop_ro(
      "value", readNames,
      "Returns the value of the SymbolRef attribute as a list[str]");
}
/// Registers the Python-visible members of the FlatSymbolRef attribute
/// class: the `get` factory and the read-only `value` property.
void PyFlatSymbolRefAttribute::bindDerived(ClassTy &c) {
  // Factory: builds the attribute from a symbol name in the resolved context.
  auto makeFromName = [](const std::string &value,
                         DefaultingPyMlirContext context) {
    MlirStringRef name = toMlirStringRef(value);
    MlirAttribute attr = mlirFlatSymbolRefAttrGet(context->get(), name);
    return PyFlatSymbolRefAttribute(context->getRef(), attr);
  };
  c.def_static("get", makeFromName, nb::arg("value"),
               nb::arg("context") = nb::none(),
               "Gets a uniqued FlatSymbolRef attribute");

  // Accessor: copies the referenced symbol name into a Python str.
  auto readName = [](PyFlatSymbolRefAttribute &self) {
    MlirStringRef name = mlirFlatSymbolRefAttrGetValue(self);
    return nb::str(name.data, name.length);
  };
  c.def_prop_ro(
      "value", readName,
      "Returns the value of the FlatSymbolRef attribute as a string");
}
/// Registers the Python-visible members of the Opaque attribute class: the
/// `get` factory and the read-only `dialect_namespace` and `data`
/// properties.
void PyOpaqueAttribute::bindDerived(ClassTy &c) {
  // Factory: builds the attribute from a dialect namespace, the contents of
  // a Python buffer, and a type.
  auto makeOpaque = [](const std::string &dialectNamespace,
                       const nb_buffer &buffer, PyType &type,
                       DefaultingPyMlirContext context) {
    const nb_buffer_info bufferInfo = buffer.request();
    // NOTE(review): the explicit-field nb_buffer_info constructor in this
    // file computes `size` as a product of shape extents (an element
    // count). If that also holds for views obtained via request(), buffers
    // with itemsize > 1 would pass fewer bytes than they hold -- confirm
    // whether a byte length is intended here.
    const intptr_t payloadSize = bufferInfo.size;
    char *payload = static_cast<char *>(bufferInfo.ptr);
    MlirAttribute attr =
        mlirOpaqueAttrGet(context->get(), toMlirStringRef(dialectNamespace),
                          payloadSize, payload, type);
    return PyOpaqueAttribute(context->getRef(), attr);
  };
  c.def_static(
      "get", makeOpaque, nb::arg("dialect_namespace"), nb::arg("buffer"),
      nb::arg("type"), nb::arg("context") = nb::none(),
      // clang-format off
      nb::sig("def get(dialect_namespace: str, buffer: typing_extensions.Buffer, type: Type, context: Context | None = None) -> OpaqueAttr"),
      // clang-format on
      "Gets an Opaque attribute.");

  // Accessor: dialect namespace as a Python str.
  auto readNamespace = [](PyOpaqueAttribute &self) {
    MlirStringRef ns = mlirOpaqueAttrGetDialectNamespace(self);
    return nb::str(ns.data, ns.length);
  };
  c.def_prop_ro(
      "dialect_namespace", readNamespace,
      "Returns the dialect namespace for the Opaque attribute as a string");

  // Accessor: raw payload as Python bytes.
  auto readData = [](PyOpaqueAttribute &self) {
    MlirStringRef payload = mlirOpaqueAttrGetData(self);
    return nb::bytes(payload.data, payload.length);
  };
  c.def_prop_ro("data", readData,
                "Returns the data for the Opaqued attributes as `bytes`");
}
/// Builds a DenseElementsAttr from a non-empty sequence of attributes.
///
/// With `explicitType`, that type must be a statically shaped ShapedType and
/// is used as the result type. Without it, the result is a rank-1 ranked
/// tensor whose extent is the number of attributes and whose element type is
/// the type of the first attribute.
///
/// Throws nb::value_error if the sequence is empty, if `explicitType` is not
/// a static shaped type, or if any attribute's type differs from the
/// element type of the result.
PyDenseElementsAttribute PyDenseElementsAttribute::getFromList(
    const nb::typed<nb::sequence, PyAttribute> &attributes,
    std::optional<PyType> explicitType,
    DefaultingPyMlirContext contextWrapper) {
  const size_t numAttributes = nb::len(attributes);
  if (numAttributes == 0)
    throw nb::value_error("Attributes list must be non-empty.");

  MlirType shapedType;
  if (!explicitType) {
    // Infer tensor<N x T> from the element count and the first attribute.
    const int64_t extent = static_cast<int64_t>(numAttributes);
    shapedType = mlirRankedTensorTypeGet(
        1, &extent,
        mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
        mlirAttributeGetNull());
  } else {
    const bool isStaticShaped = mlirTypeIsAShaped(*explicitType) &&
                                mlirShapedTypeHasStaticShape(*explicitType);
    if (!isStaticShaped) {
      std::string message = nanobind::detail::join(
          "Expected a static ShapedType for the shaped_type parameter: ",
          nb::cast<std::string>(nb::repr(nb::cast(*explicitType))));
      throw nb::value_error(message.c_str());
    }
    shapedType = *explicitType;
  }

  // Collect the raw attributes, verifying each one against the element type.
  const MlirType expectedElementType = mlirShapedTypeGetElementType(shapedType);
  std::vector<MlirAttribute> mlirAttributes;
  mlirAttributes.reserve(numAttributes);
  for (const nb::handle &attribute : attributes) {
    MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
    MlirType attrType = mlirAttributeGetType(mlirAttribute);
    mlirAttributes.push_back(mlirAttribute);
    if (mlirTypeEqual(expectedElementType, attrType))
      continue;
    std::string message = nanobind::detail::join(
        "All attributes must be of the same type and match the type "
        "parameter: expected=",
        nb::cast<std::string>(nb::repr(nb::cast(shapedType))),
        ", but got=", nb::cast<std::string>(nb::repr(nb::cast(attrType))));
    throw nb::value_error(message.c_str());
  }

  MlirAttribute elements = mlirDenseElementsAttrGet(
      shapedType, mlirAttributes.size(), mlirAttributes.data());
  return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
}
/// Builds a DenseElementsAttr from an object exporting the Python buffer
/// protocol. The buffer view is released before returning, on both the
/// normal and the exceptional path.
///
/// Throws nb::python_error if the buffer cannot be acquired and
/// std::invalid_argument if no attribute could be built from it.
PyDenseElementsAttribute PyDenseElementsAttribute::getFromBuffer(
    const nb_buffer &array, bool signless,
    const std::optional<PyType> &explicitType,
    std::optional<std::vector<int64_t>> explicitShape,
    DefaultingPyMlirContext contextWrapper) {
  // Request a contiguous view. In exotic cases, this will cause a copy.
  // The format string is only needed when the element type must be inferred.
  const int flags = explicitType ? PyBUF_ND : (PyBUF_ND | PyBUF_FORMAT);

  Py_buffer view;
  if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0)
    throw nb::python_error();
  scope_exit releaseView([&]() { PyBuffer_Release(&view); });

  MlirContext context = contextWrapper->get();
  MlirAttribute attr = getAttributeFromBuffer(
      view, signless, explicitType, std::move(explicitShape), context);
  if (!mlirAttributeIsNull(attr))
    return PyDenseElementsAttribute(contextWrapper->getRef(), attr);

  throw std::invalid_argument(
      "DenseElementsAttr could not be constructed from the given buffer. "
      "This may mean that the Python buffer layout does not match that "
      "MLIR expected layout and is a bug.");
}
/// Builds a splat DenseElementsAttr of type `shapedType` in which every
/// element equals `elementAttr`.
///
/// Throws nb::value_error if `elementAttr` is neither an integer nor a float
/// attribute, if `shapedType` is not a statically shaped ShapedType, or if
/// the element type of `shapedType` differs from the attribute's type.
PyDenseElementsAttribute
PyDenseElementsAttribute::getSplat(const PyType &shapedType,
                                   PyAttribute &elementAttr) {
  auto contextWrapper =
      PyMlirContext::forContext(mlirTypeGetContext(shapedType));

  // Python repr() of a bound object, used in the error messages below.
  auto reprOf = [](auto &object) {
    return nb::cast<std::string>(nb::repr(nb::cast(object)));
  };

  const bool isNumericAttr =
      mlirAttributeIsAInteger(elementAttr) || mlirAttributeIsAFloat(elementAttr);
  if (!isNumericAttr) {
    std::string message = "Illegal element type for DenseElementsAttr: ";
    message += reprOf(elementAttr);
    throw nb::value_error(message.c_str());
  }

  const bool isStaticShaped =
      mlirTypeIsAShaped(shapedType) && mlirShapedTypeHasStaticShape(shapedType);
  if (!isStaticShaped) {
    std::string message =
        "Expected a static ShapedType for the shaped_type parameter: ";
    message += reprOf(shapedType);
    throw nb::value_error(message.c_str());
  }

  MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
  MlirType attrType = mlirAttributeGetType(elementAttr);
  if (!mlirTypeEqual(shapedElementType, attrType)) {
    std::string message =
        "Shaped element type and attribute type must be equal: shaped=";
    message += reprOf(shapedType);
    message += ", element=";
    message += reprOf(elementAttr);
    throw nb::value_error(message.c_str());
  }

  MlirAttribute elements =
      mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
  return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
}
/// `__len__` implementation: the number of elements in the attribute.
intptr_t PyDenseElementsAttribute::dunderLen() const {
  const intptr_t numElements = mlirElementsAttrGetNumElements(*this);
  return numElements;
}
std::unique_ptr<nb_buffer_info> PyDenseElementsAttribute::accessBuffer() {
MlirType shapedType = mlirAttributeGetType(*this);
MlirType elementType = mlirShapedTypeGetElementType(shapedType);
std::string format;
if (mlirTypeIsAF32(elementType)) {
// f32
return bufferInfo<float>(shapedType);
}
if (mlirTypeIsAF64(elementType)) {
// f64
return bufferInfo<double>(shapedType);
}
if (mlirTypeIsAF16(elementType)) {
// f16
return bufferInfo<uint16_t>(shapedType, "e");
}
if (mlirTypeIsAIndex(elementType)) {
// Same as IndexType::kInternalStorageBitWidth
return bufferInfo<int64_t>(shapedType);
}
if (mlirTypeIsAInteger(elementType) &&
mlirIntegerTypeGetWidth(elementType) == 32) {
if (mlirIntegerTypeIsSignless(elementType) ||
mlirIntegerTypeIsSigned(elementType)) {
// i32
return bufferInfo<int32_t>(shapedType);
}
if (mlirIntegerTypeIsUnsigned(elementType)) {
// unsigned i32
return bufferInfo<uint32_t>(shapedType);
}
} else if (mlirTypeIsAInteger(elementType) &&
mlirIntegerTypeGetWidth(elementType) == 64) {
if (mlirIntegerTypeIsSignless(elementType) ||
mlirIntegerTypeIsSigned(elementType)) {
// i64
return bufferInfo<int64_t>(shapedType);
}
if (mlirIntegerTypeIsUnsigned(elementType)) {
// unsigned i64
return bufferInfo<uint64_t>(shapedType);
}
} else if (mlirTypeIsAInteger(elementType) &&
mlirIntegerTypeGetWidth(elementType) == 8) {
if (mlirIntegerTypeIsSignless(elementType) ||
mlirIntegerTypeIsSigned(elementType)) {
// i8
return bufferInfo<int8_t>(shapedType);
}
if (mlirIntegerTypeIsUnsigned(elementType)) {
// unsigned i8
return bufferInfo<uint8_t>(shapedType);
}
} else if (mlirTypeIsAInteger(elementType) &&
mlirIntegerTypeGetWidth(elementType) == 16) {
if (mlirIntegerTypeIsSignless(elementType) ||
mlirIntegerTypeIsSigned(elementType)) {
// i16
return bufferInfo<int16_t>(shapedType);
}
if (mlirIntegerTypeIsUnsigned(elementType)) {
// unsigned i16
return bufferInfo<uint16_t>(shapedType);
}
} else if (mlirTypeIsAInteger(elementType) &&
mlirIntegerTypeGetWidth(elementType) == 1) {
// i1 / bool
return bufferInfo<bool>(shapedType);
}
// TODO: Currently crashes the program.
// Reported as https://github.com/pybind/pybind11/issues/3336
throw std::invalid_argument(
"unsupported data type for conversion to Python buffer");
}
/// Registers the static factories shared by dense elements attribute
/// classes on `c`: the two `get` overloads (from a buffer, from a sequence
/// of attributes) and `get_splat`. `pyClassName` is spliced into the
/// advertised Python signatures and into the `get_splat` docstring.
template <typename ClassT>
void PyDenseElementsAttribute::bindFactoryMethods(ClassT &c,
                                                  const char *pyClassName) {
  const std::string className(pyClassName);

  // Signature strings: each ends with the concrete Python class name as the
  // return annotation.
  // clang-format off
  const std::string getSig1 =
      "def get(array: typing_extensions.Buffer, signless: bool = True, type: Type | None = None, shape: Sequence[int] | None = None, context: Context | None = None) -> " + className;
  const std::string getSig2 =
      "def get(attrs: Sequence[Attribute], type: Type | None = None, context: Context | None = None) -> " + className;
  const std::string getSplatSig =
      "def get_splat(shaped_type: Type, element_attr: Attribute) -> " + className;
  // clang-format on
  const std::string getSplatDoc =
      "Gets a " + className + " where all values are the same";

  // Overload 1: construct from a Python buffer.
  c.def_static("get", PyDenseElementsAttribute::getFromBuffer, nb::arg("array"),
               nb::arg("signless") = true, nb::arg("type") = nb::none(),
               nb::arg("shape") = nb::none(), nb::arg("context") = nb::none(),
               nb::sig(getSig1.c_str()), kDenseElementsAttrGetDocstring);
  // Overload 2: construct from a sequence of attributes.
  c.def_static("get", PyDenseElementsAttribute::getFromList, nb::arg("attrs"),
               nb::arg("type") = nb::none(), nb::arg("context") = nb::none(),
               nb::sig(getSig2.c_str()),
               kDenseElementsAttrGetFromListDocstring);
  // Splat factory.
  c.def_static("get_splat", PyDenseElementsAttribute::getSplat,
               nb::arg("shaped_type"), nb::arg("element_attr"),
               nb::sig(getSplatSig.c_str()), getSplatDoc.c_str());
}
/// Registers the Python-visible members of the dense elements attribute
/// class: `__len__`, the shared factories, the `is_splat` property and
/// `get_splat_value()`.
void PyDenseElementsAttribute::bindDerived(ClassTy &c) {
  c.def("__len__", &PyDenseElementsAttribute::dunderLen);
  bindFactoryMethods(c, pyClassName);

  // Whether the attribute is a splat.
  auto isSplat = [](PyDenseElementsAttribute &self) -> bool {
    return mlirDenseElementsAttrIsSplat(self);
  };
  c.def_prop_ro("is_splat", isSplat);

  // The splat value, down-cast to its most specific Python class; raises
  // ValueError on non-splat attributes.
  auto splatValue = [](PyDenseElementsAttribute &self)
      -> nb::typed<nb::object, PyAttribute> {
    if (mlirDenseElementsAttrIsSplat(self)) {
      return PyAttribute(self.getContext(),
                         mlirDenseElementsAttrGetSplatValue(self))
          .maybeDownCast();
    }
    throw nb::value_error("get_splat_value called on a non-splat attribute");
  };
  c.def("get_splat_value", splatValue);
}
/// Returns true when the first character of `format` is one of the unsigned
/// integer codes 'I', 'B', 'H', 'L' or 'Q'. An empty format yields false.
bool PyDenseElementsAttribute::isUnsignedIntegerFormat(
    std::string_view format) {
  if (format.empty())
    return false;
  switch (format.front()) {
  case 'I':
  case 'B':
  case 'H':
  case 'L':
  case 'Q':
    return true;
  default:
    return false;
  }
}
/// Returns true when the first character of `format` is one of the signed
/// integer codes 'i', 'b', 'h', 'l' or 'q'. An empty format yields false.
bool PyDenseElementsAttribute::isSignedIntegerFormat(std::string_view format) {
  if (format.empty())
    return false;
  switch (format.front()) {
  case 'i':
  case 'b':
  case 'h':
  case 'l':
  case 'q':
    return true;
  default:
    return false;
  }
}
/// Resolves the shaped type to use for a bulk load.
///
/// If `bulkLoadElementType` already holds a shaped type it is returned
/// unchanged, and combining it with `explicitShape` is rejected with
/// std::invalid_argument. Otherwise an unencoded ranked tensor of that
/// element type is built, shaped by `explicitShape` when given and by the
/// buffer view's shape when not.
///
/// `bulkLoadElementType` is dereferenced unconditionally, so callers must
/// pass an engaged optional.
MlirType PyDenseElementsAttribute::getShapedType(
    std::optional<MlirType> bulkLoadElementType,
    std::optional<std::vector<int64_t>> explicitShape, Py_buffer &view) {
  const MlirType requestedType = *bulkLoadElementType;

  // A shaped type carries its own shape; an additional one is ambiguous.
  if (mlirTypeIsAShaped(requestedType)) {
    if (explicitShape) {
      throw std::invalid_argument("Shape can only be specified explicitly "
                                  "when the type is not a shaped type.");
    }
    return requestedType;
  }

  // Element type only: take the shape from the caller or from the buffer.
  const std::vector<int64_t> shape =
      explicitShape
          ? std::vector<int64_t>(explicitShape->begin(), explicitShape->end())
          : std::vector<int64_t>(view.shape, view.shape + view.ndim);
  MlirAttribute encodingAttr = mlirAttributeGetNull();
  return mlirRankedTensorTypeGet(shape.size(), shape.data(), requestedType,
                                 encodingAttr);
}
/// Builds a DenseElementsAttr from the raw bytes of `view`.
///
/// With `explicitType`, that type is used as-is. Otherwise the element type
/// is inferred from the buffer's format code and item size:
///   "f" -> f32, "d" -> f64, "e" -> f16, "?" -> i1,
///   signed / unsigned integer codes with item size 1/2/4/8 ->
///   8/16/32/64-bit integers, signless when `signless` is true and
///   signed / unsigned otherwise.
/// A NULL format pointer is treated as "B" (unsigned bytes), as the buffer
/// protocol specifies.
///
/// Throws std::invalid_argument when no element type can be inferred.
MlirAttribute PyDenseElementsAttribute::getAttributeFromBuffer(
    Py_buffer &view, bool signless, std::optional<PyType> explicitType,
    const std::optional<std::vector<int64_t>> &explicitShape,
    MlirContext &context) {
  // Detect format codes that are suitable for bulk loading. This includes
  // all byte aligned integer and floating point types up to 8 bytes.
  // Notably, this excludes exotics types which do not have a direct
  // representation in the buffer protocol (i.e. complex, etc).
  std::optional<MlirType> bulkLoadElementType;
  if (explicitType) {
    bulkLoadElementType = *explicitType;
  } else {
    // Guard against exporters that leave `format` NULL: constructing a
    // string_view from a null pointer is undefined behavior.
    const std::string_view format(view.format ? view.format : "B");
    if (format == "f") {
      // f32
      assert(view.itemsize == 4 && "mismatched array itemsize");
      bulkLoadElementType = mlirF32TypeGet(context);
    } else if (format == "d") {
      // f64
      assert(view.itemsize == 8 && "mismatched array itemsize");
      bulkLoadElementType = mlirF64TypeGet(context);
    } else if (format == "e") {
      // f16
      assert(view.itemsize == 2 && "mismatched array itemsize");
      bulkLoadElementType = mlirF16TypeGet(context);
    } else if (format == "?") {
      // i1
      bulkLoadElementType = mlirIntegerTypeGet(context, 1);
    } else {
      const bool isSigned = isSignedIntegerFormat(format);
      if (isSigned || isUnsignedIntegerFormat(format)) {
        // Only 1/2/4/8-byte items map to an integer type; any other item
        // size leaves the element type unset and is reported below.
        const Py_ssize_t itemsize = view.itemsize;
        if (itemsize == 1 || itemsize == 2 || itemsize == 4 ||
            itemsize == 8) {
          const unsigned bitWidth = static_cast<unsigned>(itemsize) * 8;
          if (signless)
            bulkLoadElementType = mlirIntegerTypeGet(context, bitWidth);
          else if (isSigned)
            bulkLoadElementType = mlirIntegerTypeSignedGet(context, bitWidth);
          else
            bulkLoadElementType =
                mlirIntegerTypeUnsignedGet(context, bitWidth);
        }
      }
    }
    if (!bulkLoadElementType) {
      throw std::invalid_argument(
          std::string("unimplemented array format conversion from format: ") +
          std::string(format));
    }
  }
  MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
  return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
}
// Type slots passed to PyDenseElementsAttribute::bind so that the Python type
// takes part in the Python buffer protocol: Py_bf_getbuffer is served by
// bf_getbuffer and Py_bf_releasebuffer by bf_releasebuffer. The trailing
// {0, nullptr} entry is the sentinel that ends the slot list.
PyType_Slot PyDenseElementsAttribute::slots[] = {
{Py_bf_getbuffer,
reinterpret_cast<void *>(PyDenseElementsAttribute::bf_getbuffer)},
{Py_bf_releasebuffer,
reinterpret_cast<void *>(PyDenseElementsAttribute::bf_releasebuffer)},
{0, nullptr},
};
/// Implements the `bf_getbuffer` slot of the Python buffer protocol.
///
/// Fills `view` from the buffer description returned by `accessBuffer()`.
/// On success returns 0, stores a new strong reference to `obj` in
/// `view->obj` and hands the heap-allocated buffer description over to
/// `view->internal`. On failure returns -1 with a Python `BufferError` set
/// and `view->obj` left null.
/*static*/ int PyDenseElementsAttribute::bf_getbuffer(PyObject *obj,
                                                      Py_buffer *view,
                                                      int flags) {
  view->obj = nullptr;
  std::unique_ptr<nb_buffer_info> info;
  try {
    auto *attr = nb::cast<PyDenseElementsAttribute *>(nb::handle(obj));
    info = attr->accessBuffer();
  } catch (nb::python_error &e) {
    // Re-raise the original Python error and chain a BufferError onto it.
    e.restore();
    nb::chain_error(PyExc_BufferError, "Error converting attribute to buffer");
    return -1;
  } catch (std::exception &e) {
    nb::chain_error(PyExc_BufferError,
                    "Error converting attribute to buffer: %s", e.what());
    return -1;
  }

  view->obj = obj;
  view->ndim = 1;
  view->buf = info->ptr;
  view->itemsize = info->itemsize;
  // Total size in bytes: item size times the product of all extents.
  view->len = info->itemsize;
  for (auto s : info->shape) {
    view->len *= s;
  }
  view->readonly = info->readonly;

  // The caller's Py_buffer is not guaranteed to be zero-initialized, and the
  // buffer protocol requires the optional fields to be NULL unless the
  // consumer asked for them, so reset them before filling in what was
  // requested.
  view->format = nullptr;
  view->shape = nullptr;
  view->strides = nullptr;
  if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
    view->format = const_cast<char *>(info->format);
  }
  if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
    view->ndim = static_cast<int>(info->ndim);
    view->strides = info->strides.data();
    view->shape = info->shape.data();
  }
  view->suboffsets = nullptr;

  // `view->shape`/`view->strides` point into `*info`, so it must stay alive
  // for as long as the view does: ownership moves to `view->internal`.
  view->internal = info.release();
  Py_INCREF(obj);
  return 0;
}
/// Implements the `bf_releasebuffer` slot of the Python buffer protocol:
/// destroys the nb_buffer_info object stored in `view->internal`.
/*static*/ void PyDenseElementsAttribute::bf_releasebuffer(PyObject *,
                                                          Py_buffer *view) {
  auto *ownedInfo = static_cast<nb_buffer_info *>(view->internal);
  delete ownedInfo;
}
/// Returns the element at linear position `pos` as a Python int.
///
/// Throws nb::index_error when `pos` is outside [0, len) and nb::type_error
/// when the element type is an integer type whose width is not one of
/// 1, 8, 16, 32 or 64.
nb::int_ PyDenseIntElementsAttribute::dunderGetItem(intptr_t pos) const {
  if (pos < 0 || pos >= dunderLen()) {
    throw nb::index_error("attempt to access out of bounds element");
  }

  MlirType type = mlirAttributeGetType(*this);
  type = mlirShapedTypeGetElementType(type);
  // Index type can also appear as a DenseIntElementsAttr and therefore can be
  // casted to integer.
  // The disjunction is parenthesized explicitly: `a || b && "msg"` parses as
  // `a || (b && "msg")`, which hides the intent and triggers -Wparentheses.
  assert((mlirTypeIsAInteger(type) || mlirTypeIsAIndex(type)) &&
         "expected integer/index element type in "
         "dense int elements attribute");

  // Dispatch element extraction to an appropriate C function based on the
  // elemental type of the attribute. nb::int_ is implicitly
  // constructible from any C++ integral type and handles bitwidth correctly.
  // TODO: consider caching the type properties in the constructor to avoid
  // querying them on each element access.
  if (mlirTypeIsAIndex(type)) {
    return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos));
  }

  const unsigned width = mlirIntegerTypeGetWidth(type);
  // 1-bit integers are read through the bool accessor whatever their
  // signedness.
  if (width == 1) {
    return nb::int_(
        static_cast<int>(mlirDenseElementsAttrGetBoolValue(*this, pos)));
  }

  if (mlirIntegerTypeIsUnsigned(type)) {
    switch (width) {
    case 8:
      return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos));
    case 16:
      return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos));
    case 32:
      return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos));
    case 64:
      return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos));
    default:
      break;
    }
  } else {
    // Every integer type not reported as unsigned uses the signed accessors.
    switch (width) {
    case 8:
      return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos));
    case 16:
      return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos));
    case 32:
      return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos));
    case 64:
      return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos));
    default:
      break;
    }
  }
  throw nb::type_error("Unsupported integer type");
}
// Registers the Python members of this class on `c`: whatever
// PyDenseElementsAttribute::bindFactoryMethods installs for `pyClassName`
// (defined elsewhere), plus `__getitem__`, which is served by dunderGetItem.
void PyDenseIntElementsAttribute::bindDerived(ClassTy &c) {
PyDenseElementsAttribute::bindFactoryMethods(c, pyClassName);
c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
}
// Py_IsFinalizing is part of the stable ABI since 3.13. Before that, it was
// available as the private _Py_IsFinalizing, which is not part of the limited
// API.
#if defined(Py_LIMITED_API) && Py_LIMITED_API < 0x030d0000
// Under limited API targeting < 3.13, use sys.is_finalizing() via C API.
// PySys_GetObject avoids import machinery (safe during finalization).
//
// Returns 1 when sys.is_finalizing() evaluates to true and 0 otherwise,
// including when the lookup, the call or the truth test fails. No Python
// error is left pending on any of those paths.
//
// NOTE(review): unlike the real Py_IsFinalizing, this fallback calls into the
// interpreter, which presumably requires the caller to hold the GIL -- confirm
// against the call sites.
static int Py_IsFinalizing(void) {
  // PySys_GetObject returns a borrowed reference; no Py_DECREF needed.
  PyObject *fn = PySys_GetObject("is_finalizing");
  if (!fn)
    return 0;
  PyObject *result = PyObject_CallNoArgs(fn);
  if (!result) {
    PyErr_Clear();
    return 0;
  }
  int val = PyObject_IsTrue(result);
  // PyObject_IsTrue reports failure as -1 with an exception set; swallow it
  // just like a failed call above instead of leaking a pending error.
  if (val < 0)
    PyErr_Clear();
  Py_DECREF(result);
  return val > 0 ? 1 : 0;
}
#elif PY_VERSION_HEX < 0x030d0000
#define Py_IsFinalizing _Py_IsFinalizing
#endif
/// Creates a DenseResourceElementsAttr called `name`, of shaped type `type`,
/// backed directly by the memory that `buffer` exports through the Python
/// buffer protocol.
///
/// `alignment`, when given, is passed through unchanged; otherwise it is taken
/// from the buffer (item size for 0-d buffers, last stride otherwise).
/// Throws std::invalid_argument if `type` is not a ShapedType, if the buffer
/// is not contiguous, or if MLIR returns a null attribute, and
/// nb::python_error if the buffer cannot be acquired.
PyDenseResourceElementsAttribute
PyDenseResourceElementsAttribute::getFromBuffer(
    const nb_buffer &buffer, const std::string &name, const PyType &type,
    std::optional<size_t> alignment, bool isMutable,
    DefaultingPyMlirContext contextWrapper) {
  if (!mlirTypeIsAShaped(type))
    throw std::invalid_argument(
        "Constructing a DenseResourceElementsAttr requires a ShapedType.");

  // Only strides are requested -- no conversion flags -- because the resource
  // has to keep using the caller-managed memory.
  auto pyView = std::make_unique<Py_buffer>();
  if (PyObject_GetBuffer(buffer.ptr(), pyView.get(), PyBUF_STRIDES) != 0)
    throw nb::python_error();

  // Releases the acquired buffer on every exit path taken before ownership
  // has been handed over (`pyView` is null once it has).
  scope_exit releaseUnlessTransferred([&]() {
    if (pyView)
      PyBuffer_Release(pyView.get());
  });

  if (!PyBuffer_IsContiguous(pyView.get(), 'A'))
    throw std::invalid_argument("Contiguous buffer is required.");

  // Without an explicit alignment, fall back to the stride of one element.
  size_t resolvedAlignment;
  if (alignment.has_value()) {
    resolvedAlignment = alignment.value();
  } else {
    resolvedAlignment = pyView->ndim == 0
                            ? pyView->itemsize
                            : pyView->strides[pyView->ndim - 1];
  }

  // Invoked with the Py_buffer* as user data, which it then owns. While the
  // interpreter is finalizing, the buffer is deliberately left alone.
  auto releasePyBuffer = [](void *userData, const void *, size_t, size_t) {
    if (Py_IsFinalizing())
      return;
    assert(Py_IsInitialized() && "expected interpreter to be initialized");
    Py_buffer *heldView = static_cast<Py_buffer *>(userData);
    nb::gil_scoped_acquire gil;
    PyBuffer_Release(heldView);
    delete heldView;
  };

  const size_t byteLength = pyView->len;
  MlirAttribute resourceAttr = mlirUnmanagedDenseResourceElementsAttrGet(
      type, toMlirStringRef(name), pyView->buf, byteLength, resolvedAlignment,
      isMutable, releasePyBuffer, static_cast<void *>(pyView.get()));
  if (mlirAttributeIsNull(resourceAttr))
    throw std::invalid_argument(
        "DenseResourceElementsAttr could not be constructed from the given "
        "buffer. "
        "This may mean that the Python buffer layout does not match that "
        "MLIR expected layout and is a bug.");

  // From here on the deleter is responsible for the Py_buffer.
  pyView.release();
  return PyDenseResourceElementsAttribute(contextWrapper->getRef(),
                                          resourceAttr);
}
// Exposes getFromBuffer to Python as the static method `get_from_buffer`.
// `alignment` and `context` default to None and `is_mutable` to False. The
// nb::sig argument supplies the signature text explicitly.
void PyDenseResourceElementsAttribute::bindDerived(ClassTy &c) {
c.def_static(
"get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer,
nb::arg("array"), nb::arg("name"), nb::arg("type"),
nb::arg("alignment") = nb::none(), nb::arg("is_mutable") = false,
nb::arg("context") = nb::none(),
// clang-format off
nb::sig("def get_from_buffer(array: typing_extensions.Buffer, name: str, type: Type, alignment: int | None = None, is_mutable: bool = False, context: Context | None = None) -> DenseResourceElementsAttr"),
// clang-format on
kDenseResourceElementsAttrGetFromBufferDocstring);
}
/// Number of elements reported by MLIR for this dictionary attribute.
intptr_t PyDictAttribute::dunderLen() const {
  const intptr_t entryCount = mlirDictionaryAttrGetNumElements(*this);
  return entryCount;
}
bool PyDictAttribute::dunderContains(const std::string &name) const {
return !mlirAttributeIsNull(
mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
}
/// Registers the Python members of the dictionary attribute on `c`:
/// `__contains__`, `__len__`, the static `get` factory, and `__getitem__`
/// overloads for lookup by name (str) and by position (int).
void PyDictAttribute::bindDerived(ClassTy &c) {
  c.def("__contains__", &PyDictAttribute::dunderContains);
  c.def("__len__", &PyDictAttribute::dunderLen);
  c.def_static(
      "get",
      [](const nb::typed<nb::dict, nb::str, PyAttribute> &attributes,
         DefaultingPyMlirContext context) {
        std::vector<MlirNamedAttribute> mlirNamedAttributes;
        mlirNamedAttributes.reserve(attributes.size());
        for (std::pair<nb::handle, nb::handle> it : attributes) {
          auto &mlirAttr = nb::cast<PyAttribute &>(it.second);
          auto name = nb::cast<std::string>(it.first);
          // The identifier is created in the context of the value attribute.
          mlirNamedAttributes.push_back(mlirNamedAttributeGet(
              mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
                                toMlirStringRef(name)),
              mlirAttr));
        }
        MlirAttribute attr =
            mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
                                  mlirNamedAttributes.data());
        return PyDictAttribute(context->getRef(), attr);
      },
      nb::arg("value") = nb::dict(), nb::arg("context") = nb::none(),
      "Gets an uniqued dict attribute");
  // Lookup by name: raises KeyError when no element has that name.
  c.def("__getitem__",
        [](PyDictAttribute &self,
           const std::string &name) -> nb::typed<nb::object, PyAttribute> {
          MlirAttribute attr =
              mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
          if (mlirAttributeIsNull(attr))
            throw nb::key_error("attempt to access a non-existent attribute");
          return PyAttribute(self.getContext(), attr).maybeDownCast();
        });
  // Lookup by position: raises IndexError outside [0, len).
  c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
    if (index < 0 || index >= self.dunderLen()) {
      throw nb::index_error("attempt to access out of bounds attribute");
    }
    MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
    // An MlirStringRef carries an explicit length; build the std::string from
    // (data, length) rather than relying on a terminating NUL.
    MlirStringRef attrName = mlirIdentifierStr(namedAttr.name);
    return PyNamedAttribute(namedAttr.attribute,
                            std::string(attrName.data, attrName.length));
  });
}
/// Returns the element at linear position `pos` as a Python float.
///
/// Throws nb::index_error when `pos` is outside [0, len) and nb::type_error
/// when the element type is neither f32 nor f64.
nb::float_ PyDenseFPElementsAttribute::dunderGetItem(intptr_t pos) const {
  const bool inBounds = pos >= 0 && pos < dunderLen();
  if (!inBounds)
    throw nb::index_error("attempt to access out of bounds element");

  // Pick the C accessor that matches the element type; both float and double
  // results can be used to construct an nb::float_.
  // TODO: consider caching the type properties in the constructor to avoid
  // querying them on each element access.
  const MlirType elementType =
      mlirShapedTypeGetElementType(mlirAttributeGetType(*this));
  if (mlirTypeIsAF32(elementType))
    return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos));
  if (mlirTypeIsAF64(elementType))
    return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos));

  throw nb::type_error("Unsupported floating-point type");
}
// Registers the Python members of this class on `c`: whatever
// PyDenseElementsAttribute::bindFactoryMethods installs for `pyClassName`
// (defined elsewhere), plus `__getitem__`, which is served by dunderGetItem.
void PyDenseFPElementsAttribute::bindDerived(ClassTy &c) {
PyDenseElementsAttribute::bindFactoryMethods(c, pyClassName);
c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
}
/// Registers the Python members of the type attribute on `c`: the static
/// `get` factory and the read-only `value` property.
void PyTypeAttribute::bindDerived(ClassTy &c) {
  c.def_static(
      "get",
      [](const PyType &value, DefaultingPyMlirContext context) {
        return PyTypeAttribute(context->getRef(),
                               mlirTypeAttrGet(value.get()));
      },
      nb::arg("value"), nb::arg("context") = nb::none(),
      "Gets a uniqued Type attribute");
  c.def_prop_ro(
      "value", [](PyTypeAttribute &self) -> nb::typed<nb::object, PyType> {
        // Wrap the held type and let maybeDownCast choose the Python object
        // to return.
        MlirType heldType = mlirTypeAttrGetValue(self.get());
        return PyType(self.getContext(), heldType).maybeDownCast();
      });
}
/// Registers the static `get` factory of the unit attribute on `c`.
void PyUnitAttribute::bindDerived(ClassTy &c) {
  c.def_static(
      "get",
      [](DefaultingPyMlirContext context) {
        MlirAttribute unitAttr = mlirUnitAttrGet(context->get());
        return PyUnitAttribute(context->getRef(), unitAttr);
      },
      nb::arg("context") = nb::none(), "Create a Unit attribute.");
}
/// Registers the Python members of the strided layout attribute on `c`: the
/// static factories `get` and `get_fully_dynamic`, and the read-only
/// properties `offset` and `strides`.
void PyStridedLayoutAttribute::bindDerived(ClassTy &c) {
  c.def_static(
      "get",
      [](int64_t offset, const std::vector<int64_t> &strides,
         DefaultingPyMlirContext ctx) {
        MlirAttribute attr = mlirStridedLayoutAttrGet(
            ctx->get(), offset, strides.size(), strides.data());
        return PyStridedLayoutAttribute(ctx->getRef(), attr);
      },
      nb::arg("offset"), nb::arg("strides"), nb::arg("context") = nb::none(),
      "Gets a strided layout attribute.");
  c.def_static(
      "get_fully_dynamic",
      [](int64_t rank, DefaultingPyMlirContext ctx) {
        // Reject negative ranks up front instead of letting the conversion to
        // an unsigned size turn them into a huge allocation request.
        if (rank < 0)
          throw nb::value_error("rank must be non-negative");
        auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
        // Every stride, as well as the offset, uses the dynamic sentinel.
        std::vector<int64_t> strides(static_cast<size_t>(rank), dynamic);
        MlirAttribute attr = mlirStridedLayoutAttrGet(
            ctx->get(), dynamic, strides.size(), strides.data());
        return PyStridedLayoutAttribute(ctx->getRef(), attr);
      },
      nb::arg("rank"), nb::arg("context") = nb::none(),
      "Gets a strided layout attribute with dynamic offset and strides of "
      "a given rank.");
  c.def_prop_ro(
      "offset",
      [](PyStridedLayoutAttribute &self) {
        return mlirStridedLayoutAttrGetOffset(self);
      },
      "Returns the offset of the strided layout attribute");
  c.def_prop_ro(
      "strides",
      [](PyStridedLayoutAttribute &self) {
        intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
        std::vector<int64_t> strides(size);
        for (intptr_t i = 0; i < size; i++) {
          strides[i] = mlirStridedLayoutAttrGetStride(self, i);
        }
        return strides;
      },
      "Returns the strides of the strided layout attribute");
}
/// Converts `pyAttribute` into a Python object of the first dense array
/// attribute class whose `isaFunction` accepts it, trying bool, i8, i16, i32,
/// i64, f32 and f64 in that order. Throws nb::type_error if none matches.
nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
  if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute)) {
    return nb::cast(PyDenseBoolArrayAttribute(pyAttribute));
  }
  if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute)) {
    return nb::cast(PyDenseI8ArrayAttribute(pyAttribute));
  }
  if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute)) {
    return nb::cast(PyDenseI16ArrayAttribute(pyAttribute));
  }
  if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute)) {
    return nb::cast(PyDenseI32ArrayAttribute(pyAttribute));
  }
  if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute)) {
    return nb::cast(PyDenseI64ArrayAttribute(pyAttribute));
  }
  if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute)) {
    return nb::cast(PyDenseF32ArrayAttribute(pyAttribute));
  }
  if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute)) {
    return nb::cast(PyDenseF64ArrayAttribute(pyAttribute));
  }

  // No candidate matched: report the attribute's repr in the error.
  std::string description = "Can't cast unknown element type DenseArrayAttr (";
  description += nb::cast<std::string>(nb::repr(nb::cast(pyAttribute)));
  description += ")";
  throw nb::type_error(description.c_str());
}
/// Converts `pyAttribute` into a Python object of the floating-point or the
/// integer dense elements attribute class, checked in that order. Throws
/// nb::type_error if neither `isaFunction` accepts it.
nb::object denseTypedElementsAttributeCaster(PyAttribute &pyAttribute) {
  if (PyDenseFPElementsAttribute::isaFunction(pyAttribute)) {
    return nb::cast(PyDenseFPElementsAttribute(pyAttribute));
  }
  if (PyDenseIntElementsAttribute::isaFunction(pyAttribute)) {
    return nb::cast(PyDenseIntElementsAttribute(pyAttribute));
  }

  // No candidate matched: report the attribute's repr in the error.
  std::string description =
      "Can't cast unknown element type DenseTypedElementsAttr (";
  description += nb::cast<std::string>(nb::repr(nb::cast(pyAttribute)));
  description += ")";
  throw nb::type_error(description.c_str());
}
/// Converts `pyAttribute` into a Python object of the bool or the integer
/// attribute class, checked in that order. Throws nb::type_error if neither
/// `isaFunction` accepts it.
nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
  if (PyBoolAttribute::isaFunction(pyAttribute)) {
    return nb::cast(PyBoolAttribute(pyAttribute));
  }
  if (PyIntegerAttribute::isaFunction(pyAttribute)) {
    return nb::cast(PyIntegerAttribute(pyAttribute));
  }

  // No candidate matched: report the attribute's repr in the error.
  std::string description = "Can't cast unknown attribute type Attr (";
  description += nb::cast<std::string>(nb::repr(nb::cast(pyAttribute)));
  description += ")";
  throw nb::type_error(description.c_str());
}
/// Converts `pyAttribute` into a Python object of the flat symbol reference or
/// the general symbol reference attribute class, checked in that order.
/// Throws nb::type_error if neither `isaFunction` accepts it.
nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
  if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute)) {
    return nb::cast(PyFlatSymbolRefAttribute(pyAttribute));
  }
  if (PySymbolRefAttribute::isaFunction(pyAttribute)) {
    return nb::cast(PySymbolRefAttribute(pyAttribute));
  }

  // No candidate matched: report the attribute's repr in the error.
  std::string description = "Can't cast unknown SymbolRef attribute (";
  description += nb::cast<std::string>(nb::repr(nb::cast(pyAttribute)));
  description += ")";
  throw nb::type_error(description.c_str());
}
/// Registers the Python members of the string attribute on `c`: two `get`
/// overloads (from str, then from bytes), `get_typed`, and the read-only
/// properties `value` (as str) and `value_bytes` (as bytes).
void PyStringAttribute::bindDerived(ClassTy &c) {
  // Overload taking a Python str.
  c.def_static(
      "get",
      [](const std::string &value, DefaultingPyMlirContext context) {
        MlirAttribute created =
            mlirStringAttrGet(context->get(), toMlirStringRef(value));
        return PyStringAttribute(context->getRef(), created);
      },
      nb::arg("value"), nb::arg("context") = nb::none(),
      "Gets a uniqued string attribute");
  // Overload taking Python bytes.
  c.def_static(
      "get",
      [](const nb::bytes &value, DefaultingPyMlirContext context) {
        MlirAttribute created =
            mlirStringAttrGet(context->get(), toMlirStringRef(value));
        return PyStringAttribute(context->getRef(), created);
      },
      nb::arg("value"), nb::arg("context") = nb::none(),
      "Gets a uniqued string attribute");
  // The resulting attribute uses the context of `type`.
  c.def_static(
      "get_typed",
      [](PyType &type, const std::string &value) {
        MlirAttribute created =
            mlirStringAttrTypedGet(type, toMlirStringRef(value));
        return PyStringAttribute(type.getContext(), created);
      },
      nb::arg("type"), nb::arg("value"),
      "Gets a uniqued string attribute associated to a type");
  c.def_prop_ro(
      "value",
      [](PyStringAttribute &self) {
        const MlirStringRef contents = mlirStringAttrGetValue(self);
        return nb::str(contents.data, contents.length);
      },
      "Returns the value of the string attribute");
  c.def_prop_ro(
      "value_bytes",
      [](PyStringAttribute &self) {
        const MlirStringRef contents = mlirStringAttrGetValue(self);
        return nb::bytes(contents.data, contents.length);
      },
      "Returns the value of the string attribute as `bytes`");
}
/// Resolves a name of the form `<dialectName>.<attributeName>` (split at the
/// first '.') to the matching dynamic attribute definition.
///
/// Throws nb::value_error when the name contains no '.', when the dialect is
/// not an extensible dialect, or when the dialect has no attribute definition
/// with that name.
static MlirDynamicAttrDefinition
getDynamicAttrDef(const std::string &fullAttrName,
                  DefaultingPyMlirContext context) {
  const size_t separator = fullAttrName.find('.');
  if (separator == std::string::npos)
    throw nb::value_error("Expected full attribute name to be in the format "
                          "'<dialectName>.<attributeName>'.");

  std::string dialectName = fullAttrName.substr(0, separator);
  std::string attrName = fullAttrName.substr(separator + 1);

  PyDialects dialects(context->getRef());
  MlirDialect dialect = dialects.getDialectForKey(dialectName, false);
  if (!mlirDialectIsAExtensibleDialect(dialect)) {
    std::string message =
        "Dialect '" + dialectName + "' is not an extensible dialect.";
    throw nb::value_error(message.c_str());
  }

  MlirDynamicAttrDefinition definition =
      mlirExtensibleDialectLookupAttrDefinition(dialect,
                                                toMlirStringRef(attrName));
  if (definition.ptr != nullptr)
    return definition;

  std::string message = "Dialect '" + dialectName +
                        "' does not contain an attribute named '" + attrName +
                        "'.";
  throw nb::value_error(message.c_str());
}
/// Registers the Python members of the dynamic attribute on `c`: the static
/// `get` factory, the read-only properties `params` and `attr_name`, and the
/// static `lookup_typeid` helper.
void PyDynamicAttribute::bindDerived(ClassTy &c) {
  c.def_static(
      "get",
      [](const std::string &fullAttrName, const std::vector<PyAttribute> &attrs,
         DefaultingPyMlirContext context) {
        // Unwrap the parameters into the raw handles the C API expects.
        std::vector<MlirAttribute> rawParams;
        rawParams.reserve(attrs.size());
        for (const PyAttribute &param : attrs) {
          rawParams.push_back(param.get());
        }
        MlirDynamicAttrDefinition definition =
            getDynamicAttrDef(fullAttrName, context);
        MlirAttribute created =
            mlirDynamicAttrGet(definition, rawParams.data(), rawParams.size());
        return PyDynamicAttribute(context->getRef(), created);
      },
      nb::arg("full_attr_name"), nb::arg("attributes"),
      nb::arg("context") = nb::none(), "Create a dynamic attribute.");
  c.def_prop_ro(
      "params",
      [](PyDynamicAttribute &self) {
        const size_t paramCount = mlirDynamicAttrGetNumParams(self);
        std::vector<PyAttribute> wrapped;
        wrapped.reserve(paramCount);
        for (size_t idx = 0; idx < paramCount; ++idx) {
          wrapped.emplace_back(self.getContext(),
                               mlirDynamicAttrGetParam(self, idx));
        }
        return wrapped;
      },
      "Returns the parameters of the dynamic attribute as a list of "
      "attributes.");
  c.def_prop_ro("attr_name", [](PyDynamicAttribute &self) {
    MlirDynamicAttrDefinition definition = mlirDynamicAttrGetAttrDef(self);
    MlirStringRef shortName = mlirDynamicAttrDefinitionGetName(definition);
    MlirStringRef dialectNamespace = mlirDialectGetNamespace(
        mlirDynamicAttrDefinitionGetDialect(definition));
    // Assemble "<dialect namespace>.<attribute name>".
    std::string qualified(dialectNamespace.data, dialectNamespace.length);
    qualified += ".";
    qualified.append(shortName.data, shortName.length);
    return qualified;
  });
  c.def_static(
      "lookup_typeid",
      [](const std::string &fullAttrName, DefaultingPyMlirContext context) {
        return PyTypeID(mlirDynamicAttrDefinitionGetTypeID(
            getDynamicAttrDef(fullAttrName, context)));
      },
      nb::arg("full_attr_name"), nb::arg("context") = nb::none(),
      "Look up the TypeID for the given dynamic attribute name.");
}
// Binds the attribute classes listed below into module `m` and registers the
// caster functions with PyGlobals under the DenseArrayAttr,
// DenseTypedElementsAttr, SymbolRefAttr and IntegerAttr TypeIDs.
// NOTE(review): the relative order of the bind calls presumably matters
// (e.g. a base class before the classes derived from it) -- confirm before
// reordering.
void populateIRAttributes(nb::module_ &m) {
PyAffineMapAttribute::bind(m);
// Dense array attributes, each followed by its iterator type.
PyDenseBoolArrayAttribute::bind(m);
PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
PyDenseI8ArrayAttribute::bind(m);
PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
PyDenseI16ArrayAttribute::bind(m);
PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
PyDenseI32ArrayAttribute::bind(m);
PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
PyDenseI64ArrayAttribute::bind(m);
PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
PyDenseF32ArrayAttribute::bind(m);
PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
PyDenseF64ArrayAttribute::bind(m);
PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
PyGlobals::get().registerTypeCaster(
mlirDenseArrayAttrGetTypeID(),
nb::cast<nb::callable>(nb::cpp_function(denseArrayAttributeCaster)));
PyArrayAttribute::bind(m);
PyArrayAttribute::PyArrayAttributeIterator::bind(m);
PyBoolAttribute::bind(m);
// DenseElementsAttr is bound together with its buffer-protocol type slots.
PyDenseElementsAttribute::bind(m, PyDenseElementsAttribute::slots);
PyDenseFPElementsAttribute::bind(m);
PyDenseIntElementsAttribute::bind(m);
PyGlobals::get().registerTypeCaster(mlirDenseTypedElementsAttrGetTypeID(),
nb::cast<nb::callable>(nb::cpp_function(
denseTypedElementsAttributeCaster)));
PyDenseResourceElementsAttribute::bind(m);
PyDictAttribute::bind(m);
PySymbolRefAttribute::bind(m);
PyGlobals::get().registerTypeCaster(
mlirSymbolRefAttrGetTypeID(),
nb::cast<nb::callable>(
nb::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster)));
PyFlatSymbolRefAttribute::bind(m);
PyOpaqueAttribute::bind(m);
PyFloatAttribute::bind(m);
PyIntegerAttribute::bind(m);
PyIntegerSetAttribute::bind(m);
PyStringAttribute::bind(m);
PyTypeAttribute::bind(m);
PyGlobals::get().registerTypeCaster(
mlirIntegerAttrGetTypeID(),
nb::cast<nb::callable>(nb::cpp_function(integerOrBoolAttributeCaster)));
PyUnitAttribute::bind(m);
PyStridedLayoutAttribute::bind(m);
PyDynamicAttribute::bind(m);
}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir