Last active: September 9, 2025 06:36
-
-
Save makslevental/ca6170ea884ef2b8d04601be3b6c8cac to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp | |
index 6ee85e8a3149..88557706bd04 100644 | |
--- a/mlir/lib/Bindings/Python/Pass.cpp | |
+++ b/mlir/lib/Bindings/Python/Pass.cpp | |
@@ -59,6 +59,12 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { | |
//---------------------------------------------------------------------------- | |
// Mapping of the top-level PassManager | |
//---------------------------------------------------------------------------- | |
+ | |
+ nb::class_<MlirExternalPass>(m, "ExternalPass"); | |
+ | |
+ m.def("signal_pass_failure", | |
+ [](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); }); | |
+ | |
nb::class_<PyPassManager>(m, "PassManager") | |
.def( | |
"__init__", | |
@@ -182,9 +188,10 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { | |
callbacks.clone = [](void *) -> void * { | |
throw std::runtime_error("Cloning Python passes not supported"); | |
}; | |
- callbacks.run = [](MlirOperation op, MlirExternalPass, | |
+ callbacks.run = [](MlirOperation op, MlirExternalPass pass, | |
void *userData) { | |
- nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op); | |
+ nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op, | |
+ pass); | |
}; | |
auto externalPass = mlirCreateExternalPass( | |
passID, mlirStringRefCreate(name->data(), name->length()), | |
diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py | |
index c94f96e20966..57dc9dd23453 100644 | |
--- a/mlir/test/python/python_pass.py | |
+++ b/mlir/test/python/python_pass.py | |
@@ -64,11 +64,11 @@ def testCustomPass(): | |
""" | |
) | |
- def custom_pass_1(op): | |
+ def custom_pass_1(op, _pass): | |
print("hello from pass 1!!!", file=sys.stderr) | |
class CustomPass2: | |
- def __call__(self, m): | |
+ def __call__(self, m, _pass): | |
apply_patterns_and_fold_greedily(m, frozen) | |
custom_pass_2 = CustomPass2() | |
@@ -86,3 +86,19 @@ def testCustomPass(): | |
# CHECK: llvm.mul | |
pm.add("convert-arith-to-llvm") | |
pm.run(module) | |
+ | |
+ # test signal_pass_failure | |
+ class CustomPassThatFails: | |
+ def __call__(self, m, pass_): | |
+ print("hello from pass that fails", file=sys.stderr) | |
+ signal_pass_failure(pass_) | |
+ | |
+ custom_pass_that_fails = CustomPassThatFails() | |
+ | |
+ pm = PassManager("any") | |
+ pm.add(custom_pass_that_fails, "CustomPassThatFails") | |
+ # CHECK: hello from pass that fails | |
+ try: | |
+ pm.run(module) | |
+ except MLIRError as e: | |
+ assert e.message == "Failure while executing pass pipeline" |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.