Created
May 9, 2025 10:13
-
-
Save mildsunrise/fc98f83420ff9ac5f3443f128c61532f to your computer and use it in GitHub Desktop.
bit matrices, with a "reader monad" user interface
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class MUint: | |
''' | |
a bit vector (unsigned integer) which is the result of an affine map. | |
except for the 'algebra-specific operations', like apply, instances of this | |
class behave like `int` values when operated on. this includes a subset of | |
the operations possible on ints, namely bitwise operations: | |
- and (`&`) and or (`|`) are allowed as long as the other operand is an `int` | |
- xor (`^`) are allowed with both ints and other MUints | |
- shifts (`>>`, `<<`) are allowed (the shift amount must be an `int`) | |
- not (`~`) is only allowed on FixedMUint (the operation makes no sense on unsigned ints without width) | |
''' | |
# bit 'i' of this MUint is obtained by taking (bvecs[i] & input_bitstring).parity(), | |
# but the input_bitstring gets concatenated with a 1 in its LSB, so that the constant | |
# part lives in the LSB of bvecs[i]. see apply | |
bvecs: tuple[int, ...] | |
def __init__(self, bvecs: tuple[int, ...]): | |
self.bvecs = bvecs | |
def __delattr__(self, name: str) -> None: | |
raise Exception('frozen') | |
def __setattr__(self, name: str, value) -> None: | |
if hasattr(self, 'bvecs'): | |
raise Exception('frozen') | |
return super().__setattr__(name, value) | |
# algebra-specific operations | |
def apply(self, input: int) -> int: | |
''' evaluate this MUint to a concrete value by feeding the affine map an input ''' | |
input = input << 1 | 1 | |
result = 0 | |
for i, bvec in enumerate(self.bvecs): | |
bit = (bvec & input).bit_count() & 1 | |
result |= bit << i | |
return result | |
# TODO: apply_muint(self, input: MUint) -> MUint | |
# (multiplies the matrices) | |
def kernel(self, k: int) -> tuple['FixedMUint', int] | None: | |
''' | |
find which k-vectors cause this MUint to evaluate to all zeros. | |
because there could be many of them, the result is returned as | |
a (solution, n) pair, where `solution` is a FixedMUint that | |
generates all possible solutions by applying it with n-vectors. | |
(if there is no solution, None is returned) | |
''' | |
col = 1 | |
bvecs = list(self.bvecs) | |
for i in range(k): | |
pivot = next((p for p in range(i, len(bvecs)) if (bvecs[p] >> col) & 1), None) | |
if pivot == None: | |
bvecs.insert(i, 1 << col) | |
col += 1 | |
continue | |
bvecs[pivot], bvecs[i] = bvecs[i], bvecs[pivot] | |
rmask = ((~0) << col); mask = ~rmask | |
for j in range(len(bvecs)): | |
if j != i and (bvecs[j] >> col) & 1: | |
bvecs[j] ^= bvecs[i] | |
for j in range(len(bvecs)): | |
bvecs[j] = (bvecs[j] & mask) | ((bvecs[j] >> 1) & rmask) | |
if any(bvecs[i] & 1 for i in range(k, len(bvecs))): | |
return None | |
return FixedMUint(tuple(bvecs[:k])), col - 1 | |
@staticmethod | |
def constant(x: int): | |
''' obtain a MUint that always evaluates to the same constant 'x' ''' | |
return MUint(tuple((x >> i) & 1 for i in range(x.bit_length()))) | |
def simplify(self): | |
''' obtain an equivalent, but reduced (normalized) representation of this MUint ''' | |
w = len(self.bvecs) | |
while w and not self.bvecs[w - 1]: w -= 1 | |
return MUint(self.bvecs[:w]) | |
# this is allocated width, not real width (unless simplify has been called) | |
@property | |
def _width(self): | |
return len(self.bvecs) | |
# ensure _width is at least this value | |
def _ensure_width(self, width: int): | |
n = width - len(self.bvecs) | |
return (self.bvecs + (0,) * n) if n > 0 else self.bvecs | |
def to_fixed(self, width: int) -> 'FixedMUint': | |
assert self._width <= width | |
return FixedMUint(self._ensure_width(width)) | |
# standard operations on int | |
def __xor__(self, value: 'MUint | int'): | |
if isinstance(value, int): | |
value = MUint.constant(value) | |
assert isinstance(value, MUint) | |
w = max(self._width, value._width) | |
a = self._ensure_width(w) | |
b = value._ensure_width(w) | |
return MUint(tuple(a^b for a,b in zip(a,b))) | |
def __or__(self, value: int): | |
assert isinstance(value, int) | |
bvecs = self._ensure_width(value.bit_length()) | |
return MUint(tuple( | |
bit if (bit := (value >> i) & 1) else x | |
for i, x in enumerate(bvecs))) | |
def __and__(self, value: int): | |
assert isinstance(value, int) | |
bvecs = self.bvecs[:value.bit_length()] | |
return MUint(tuple( | |
bit if not (bit := (value >> i) & 1) else x | |
for i, x in enumerate(bvecs))) | |
def __rxor__(self, value: 'MUint | int'): | |
return self.__xor__(value) | |
def __ror__(self, value: int): | |
return self.__or__(value) | |
def __rand__(self, value: int): | |
return self.__and__(value) | |
def __lshift__(self, value: int): | |
assert isinstance(value, int) and value >= 0 | |
return MUint((0,) * value + self.bvecs) | |
def __rshift__(self, value: int): | |
assert isinstance(value, int) and value >= 0 | |
return MUint(self.bvecs[value:]) | |
# custom operations that make sense in any int | |
def bits(self, start: int, count: int): | |
''' extract n bits starting at bit k ''' | |
assert start >= 0 and count >= 0 | |
return MUint(self.bvecs[start:][:count]).to_fixed(count) | |
def bit(self, bit: int): | |
''' extract bit k of x ''' | |
return self.bits(bit, 1) | |
@staticmethod | |
def concat(*xs: 'FixedMUint'): | |
''' concatenate the bits of several uints (the most significant one is passed first) ''' | |
bvecs = [] | |
for x in reversed(xs): | |
assert isinstance(x, FixedMUint) | |
bvecs += x.bvecs | |
return FixedMUint(tuple(bvecs)) | |
class FixedMUint(MUint):
    ''' an unsigned integer that carries an explicit bit width '''

    @staticmethod
    def identity(start: int, width: int):
        '''
        obtain a FixedMUint of the given width whose bit i is input bit `start + i`
        (i.e. it passes the input slice [start, start + width) through unchanged)
        '''
        # the +1 skips bit 0 of each bvec, which holds the constant term
        return FixedMUint(tuple(1 << (start + 1 + i) for i in range(width)))

    class Allocator:
        '''
        hands out consecutive, non-overlapping slices of the input: each call
        returns an identity FixedMUint over the next `width` input bits.
        '''
        start: int  # index of the next input bit to hand out

        def __init__(self, start=0):
            self.start = start

        def __call__(self, width: int):
            result = FixedMUint.identity(self.start, width)
            self.start += width
            return result

    @property
    def width(self):
        ''' the explicit bit width of this value (number of rows) '''
        return self._width

    def apply(self, input: int):
        # NOTE(review): FixedUint is not defined in this file view; presumably a
        # concrete (value, width) integer type defined elsewhere -- confirm
        return FixedUint(MUint.apply(self, input), self.width)

    def ror(self, k: int):
        ''' rotate right by k bits (k must be between 0 and width) '''
        assert 0 <= k <= self.width
        return type(self)(self.bvecs[k:] + self.bvecs[:k])

    # XORing two FixedMUints of the same width preserves it
    def __xor__(self, other: 'MUint | int'):
        result = MUint.__xor__(self, other)
        if isinstance(other, FixedMUint) and other.width == self.width:
            result = type(self)(result.bvecs)
        return result

    def __invert__(self):
        ''' bitwise not within this value's width; the result keeps the same width '''
        # negating output bit i means xoring its affine row with the constant 1 (bit 0)
        return type(self)(tuple(bvec ^ 1 for bvec in self.bvecs))

    def repeat(self, count: int):
        ''' concatenate 'count' copies of self '''
        return type(self)(self.bvecs * count)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment