Skip to content

Instantly share code, notes, and snippets.

@makslevental
Last active September 9, 2025 06:36
Show Gist options
  • Save makslevental/ca6170ea884ef2b8d04601be3b6c8cac to your computer and use it in GitHub Desktop.
Save makslevental/ca6170ea884ef2b8d04601be3b6c8cac to your computer and use it in GitHub Desktop.
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 6ee85e8a3149..88557706bd04 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -59,6 +59,22 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
   //----------------------------------------------------------------------------
   // Mapping of the top-level PassManager
   //----------------------------------------------------------------------------
+
+  // Opaque handle passed to Python pass callbacks as their second argument.
+  // No lifetime tracking is attached to it, so it should only be used from
+  // within the callback that received it.
+  nb::class_<MlirExternalPass>(m, "ExternalPass")
+      .def(
+          "signal_pass_failure",
+          [](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); },
+          "Signal that this pass has failed.");
+
+  // Module-level spelling of ExternalPass.signal_pass_failure.
+  m.def(
+      "signal_pass_failure",
+      [](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); },
+      nb::arg("pass_"), "Signal that the given pass has failed.");
+
   nb::class_<PyPassManager>(m, "PassManager")
       .def(
           "__init__",
@@ -182,9 +198,10 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
           callbacks.clone = [](void *) -> void * {
             throw std::runtime_error("Cloning Python passes not supported");
           };
-          callbacks.run = [](MlirOperation op, MlirExternalPass,
+          callbacks.run = [](MlirOperation op, MlirExternalPass pass,
                              void *userData) {
-            nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op);
+            nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op,
+                                                                        pass);
           };
           auto externalPass = mlirCreateExternalPass(
               passID, mlirStringRefCreate(name->data(), name->length()),
diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py
index c94f96e20966..57dc9dd23453 100644
--- a/mlir/test/python/python_pass.py
+++ b/mlir/test/python/python_pass.py
@@ -64,11 +64,11 @@ def testCustomPass():
         """
         )
 
-        def custom_pass_1(op):
+        def custom_pass_1(op, _pass):
             print("hello from pass 1!!!", file=sys.stderr)
 
         class CustomPass2:
-            def __call__(self, m):
+            def __call__(self, m, _pass):
                 apply_patterns_and_fold_greedily(m, frozen)
 
         custom_pass_2 = CustomPass2()
@@ -86,3 +86,23 @@ def testCustomPass():
         # CHECK: llvm.mul
         pm.add("convert-arith-to-llvm")
         pm.run(module)
+
+        # test signal_pass_failure
+        class CustomPassThatFails:
+            def __call__(self, m, pass_):
+                print("hello from pass that fails", file=sys.stderr)
+                signal_pass_failure(pass_)
+
+        custom_pass_that_fails = CustomPassThatFails()
+
+        pm = PassManager("any")
+        pm.add(custom_pass_that_fails, "CustomPassThatFails")
+        # CHECK: hello from pass that fails
+        try:
+            pm.run(module)
+        except MLIRError as e:
+            assert e.message == "Failure while executing pass pipeline"
+        else:
+            # Without this branch the test would pass even if the failure
+            # signal were ignored and the pipeline succeeded.
+            raise AssertionError("expected pm.run to fail")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment