Makes it possible to include Python-defined rewrite patterns in transform-dialect schedules, inside `transform.apply_patterns`; when the schedule is executed, the patterns are run by a greedy rewriter. Written with the assistance of Claude.
119 lines
4.0 KiB
C++
119 lines
4.0 KiB
C++
//===- Rewrite.h - Rewrite Submodules of pybind module --------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H
|
|
#define MLIR_BINDINGS_PYTHON_REWRITE_H
|
|
|
|
#include "mlir-c/Rewrite.h"
|
|
#include "mlir/Bindings/Python/IRCore.h"
|
|
|
|
#include <nanobind/nanobind.h>
|
|
|
|
namespace mlir {
|
|
namespace python {
|
|
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
|
|
|
|
/// CRTP Base class for rewriter wrappers.
|
|
template <typename DerivedTy>
|
|
class MLIR_PYTHON_API_EXPORTED PyRewriterBase {
|
|
public:
|
|
PyRewriterBase(MlirRewriterBase rewriter)
|
|
: base(rewriter),
|
|
ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}
|
|
|
|
PyInsertionPoint getInsertionPoint() const {
|
|
MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
|
|
MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
|
|
|
|
if (mlirOperationIsNull(op)) {
|
|
MlirOperation owner = mlirBlockGetParentOperation(block);
|
|
auto parent = PyOperation::forOperation(ctx, owner);
|
|
return PyInsertionPoint(PyBlock(parent, block));
|
|
}
|
|
|
|
return PyInsertionPoint(PyOperation::forOperation(ctx, op));
|
|
}
|
|
|
|
static void bind(nanobind::module_ &m) {
|
|
nanobind::class_<DerivedTy>(m, DerivedTy::pyClassName)
|
|
.def_prop_ro("ip", &PyRewriterBase::getInsertionPoint,
|
|
"The current insertion point of the PatternRewriter.")
|
|
.def(
|
|
"replace_op",
|
|
[](DerivedTy &self, PyOperationBase &op, PyOperationBase &newOp) {
|
|
mlirRewriterBaseReplaceOpWithOperation(
|
|
self.base, op.getOperation(), newOp.getOperation());
|
|
},
|
|
"Replace an operation with a new operation.", nanobind::arg("op"),
|
|
nanobind::arg("new_op"))
|
|
.def(
|
|
"replace_op",
|
|
[](DerivedTy &self, PyOperationBase &op,
|
|
const std::vector<PyValue> &values) {
|
|
std::vector<MlirValue> values_(values.size());
|
|
std::copy(values.begin(), values.end(), values_.begin());
|
|
mlirRewriterBaseReplaceOpWithValues(
|
|
self.base, op.getOperation(), values_.size(), values_.data());
|
|
},
|
|
"Replace an operation with a list of values.", nanobind::arg("op"),
|
|
nanobind::arg("values"))
|
|
.def(
|
|
"erase_op",
|
|
[](DerivedTy &self, PyOperationBase &op) {
|
|
mlirRewriterBaseEraseOp(self.base, op.getOperation());
|
|
},
|
|
"Erase an operation.", nanobind::arg("op"));
|
|
}
|
|
|
|
private:
|
|
MlirRewriterBase base;
|
|
PyMlirContextRef ctx;
|
|
};
|
|
|
|
/// Wrapper around MlirRewritePatternSet.
|
|
/// The default constructor creates an owned pattern set that is destroyed
|
|
/// in the destructor. The constructor taking MlirRewritePatternSet creates
|
|
/// a non-owning reference.
|
|
class PyTypeConverter;
|
|
class MLIR_PYTHON_API_EXPORTED PyRewritePatternSet {
|
|
public:
|
|
/// Create an owned pattern set.
|
|
PyRewritePatternSet(MlirContext ctx);
|
|
|
|
/// Create a non-owning reference to an existing pattern set.
|
|
PyRewritePatternSet(MlirRewritePatternSet patterns);
|
|
|
|
~PyRewritePatternSet();
|
|
|
|
MlirRewritePatternSet get() const;
|
|
|
|
bool isOwned() const;
|
|
|
|
/// Add a new rewrite pattern to the pattern set.
|
|
void add(nanobind::handle root, const nanobind::callable &matchAndRewrite,
|
|
unsigned benefit);
|
|
|
|
/// Add a new conversion pattern to the pattern set.
|
|
void addConversion(nanobind::handle root,
|
|
const nanobind::callable &matchAndRewrite,
|
|
PyTypeConverter &typeConverter, unsigned benefit);
|
|
|
|
static void bind(nanobind::module_ &m);
|
|
|
|
private:
|
|
MlirRewritePatternSet patterns;
|
|
bool owned;
|
|
};
|
|
|
|
/// Declared here and defined out of line. Judging by its name and signature,
/// it populates the nanobind module `m` with the rewrite bindings - TODO
/// confirm against the implementation.
void MLIR_PYTHON_API_EXPORTED populateRewriteSubmodule(nanobind::module_ &m);
|
|
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
|
|
} // namespace python
|
|
} // namespace mlir
|
|
|
|
#endif // MLIR_BINDINGS_PYTHON_REWRITE_H
|