Created March 31, 2019 11:20
Save jamescasbon/b0e1f2113a28e523ff3326d7b93eda19 to your computer and use it in GitHub Desktop.
Evil monkeypatch for numpy cmp with attrs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import attr | |
import numpy as np | |
import attr._make | |
# Capture attrs' original comparison-method factory before it is replaced
# further down; numpy_make_cmp delegates to it for non-array fields.
from attr._make import _make_cmp as original_make_cmp
def _is_np_attr(x): | |
return x.type == np.ndarray | |
def numpy_make_cmp(attrs):
    """Build attrs comparison methods that compare ndarray fields by value.

    Drop-in replacement for ``attr._make._make_cmp``. It takes the class's
    attribute list and returns the 6-tuple
    ``(__eq__, __ne__, __lt__, __le__, __gt__, __ge__)``.

    If no compared attribute is an ndarray (see ``_is_np_attr``), the
    original attrs factory is used unchanged. Otherwise:

    * non-array attributes are compared by the original attrs ``__eq__``;
    * array attributes are compared with ``np.array_equal``;
    * ordering methods are returned as ``None``, presumably because arrays
      have no single natural ordering.
    """
    # Honour attrs' per-attribute cmp flag, as the original factory does for
    # the attributes it handles; getattr default keeps duck-typed inputs working.
    np_names = [
        a.name for a in attrs if _is_np_attr(a) and getattr(a, "cmp", True)
    ]
    if not np_names:
        return original_make_cmp(attrs)

    other_attrs = [a for a in attrs if not _is_np_attr(a)]
    eq = original_make_cmp(other_attrs)[0]

    def __eq__(self, other):
        # Mirror attrs: defer to the other operand for a different class
        # instead of reading (possibly missing) fields off it.
        if other.__class__ is not self.__class__:
            return NotImplemented
        if not eq(self, other):
            return False
        # Names are read at call time, so each array field is checked
        # (a list of lambdas closing over a loop variable would only ever
        # check the last one).
        return all(
            np.array_equal(getattr(self, name), getattr(other, name))
            for name in np_names
        )

    def __ne__(self, other):
        result = __eq__(self, other)
        if result is NotImplemented:
            return NotImplemented
        return not result

    return __eq__, __ne__, None, None, None, None
# Install the numpy-aware factory in place of attrs' own. Presumably attrs
# looks this name up when a class is decorated, so this must run before any
# @attr.s class that should get the new behaviour is defined.
setattr(attr._make, "_make_cmp", numpy_make_cmp)
@attr.attrs(auto_attribs=True)
class C:
    """Demo attrs class mixing an ndarray field with a plain int field."""

    x: np.ndarray  # array field: picked out by _is_np_attr
    y: int  # scalar field: left to attrs' own comparison
# Equal array contents and equal int -> equal, even though the two arrays
# are distinct objects.
c1 = C(y=1, x=np.array([1, 2]))
c2 = C(y=1, x=np.array([1, 2]))
assert c1 == c2

# Only the scalar field differs -> unequal.
c2 = C(y=2, x=np.array([1, 2]))
assert c1 != c2

# Only the array contents differ -> unequal.
c2 = C(y=1, x=np.array([1, 3]))
assert c1 != c2
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.