Skip to content

Instantly share code, notes, and snippets.

@mildsunrise
Created May 9, 2025 10:13
Show Gist options
  • Save mildsunrise/fc98f83420ff9ac5f3443f128c61532f to your computer and use it in GitHub Desktop.
Save mildsunrise/fc98f83420ff9ac5f3443f128c61532f to your computer and use it in GitHub Desktop.
bit matrices, with a "reader monad" user interface
class MUint:
'''
a bit vector (unsigned integer) which is the result of an affine map.
except for the 'algebra-specific operations', like apply, instances of this
class behave like `int` values when operated on. this includes a subset of
the operations possible on ints, namely bitwise operations:
- and (`&`) and or (`|`) are allowed as long as the other operand is an `int`
- xor (`^`) are allowed with both ints and other MUints
- shifts (`>>`, `<<`) are allowed (the shift amount must be an `int`)
- not (`~`) is only allowed on FixedMUint (the operation makes no sense on unsigned ints without width)
'''
# bit 'i' of this MUint is obtained by taking (bvecs[i] & input_bitstring).parity(),
# but the input_bitstring gets concatenated with a 1 in its LSB, so that the constant
# part lives in the LSB of bvecs[i]. see apply
bvecs: tuple[int, ...]
def __init__(self, bvecs: tuple[int, ...]):
self.bvecs = bvecs
def __delattr__(self, name: str) -> None:
raise Exception('frozen')
def __setattr__(self, name: str, value) -> None:
if hasattr(self, 'bvecs'):
raise Exception('frozen')
return super().__setattr__(name, value)
# algebra-specific operations
def apply(self, input: int) -> int:
''' evaluate this MUint to a concrete value by feeding the affine map an input '''
input = input << 1 | 1
result = 0
for i, bvec in enumerate(self.bvecs):
bit = (bvec & input).bit_count() & 1
result |= bit << i
return result
# TODO: apply_muint(self, input: MUint) -> MUint
# (multiplies the matrices)
def kernel(self, k: int) -> tuple['FixedMUint', int] | None:
'''
find which k-vectors cause this MUint to evaluate to all zeros.
because there could be many of them, the result is returned as
a (solution, n) pair, where `solution` is a FixedMUint that
generates all possible solutions by applying it with n-vectors.
(if there is no solution, None is returned)
'''
col = 1
bvecs = list(self.bvecs)
for i in range(k):
pivot = next((p for p in range(i, len(bvecs)) if (bvecs[p] >> col) & 1), None)
if pivot == None:
bvecs.insert(i, 1 << col)
col += 1
continue
bvecs[pivot], bvecs[i] = bvecs[i], bvecs[pivot]
rmask = ((~0) << col); mask = ~rmask
for j in range(len(bvecs)):
if j != i and (bvecs[j] >> col) & 1:
bvecs[j] ^= bvecs[i]
for j in range(len(bvecs)):
bvecs[j] = (bvecs[j] & mask) | ((bvecs[j] >> 1) & rmask)
if any(bvecs[i] & 1 for i in range(k, len(bvecs))):
return None
return FixedMUint(tuple(bvecs[:k])), col - 1
@staticmethod
def constant(x: int):
''' obtain a MUint that always evaluates to the same constant 'x' '''
return MUint(tuple((x >> i) & 1 for i in range(x.bit_length())))
def simplify(self):
''' obtain an equivalent, but reduced (normalized) representation of this MUint '''
w = len(self.bvecs)
while w and not self.bvecs[w - 1]: w -= 1
return MUint(self.bvecs[:w])
# this is allocated width, not real width (unless simplify has been called)
@property
def _width(self):
return len(self.bvecs)
# ensure _width is at least this value
def _ensure_width(self, width: int):
n = width - len(self.bvecs)
return (self.bvecs + (0,) * n) if n > 0 else self.bvecs
def to_fixed(self, width: int) -> 'FixedMUint':
assert self._width <= width
return FixedMUint(self._ensure_width(width))
# standard operations on int
def __xor__(self, value: 'MUint | int'):
if isinstance(value, int):
value = MUint.constant(value)
assert isinstance(value, MUint)
w = max(self._width, value._width)
a = self._ensure_width(w)
b = value._ensure_width(w)
return MUint(tuple(a^b for a,b in zip(a,b)))
def __or__(self, value: int):
assert isinstance(value, int)
bvecs = self._ensure_width(value.bit_length())
return MUint(tuple(
bit if (bit := (value >> i) & 1) else x
for i, x in enumerate(bvecs)))
def __and__(self, value: int):
assert isinstance(value, int)
bvecs = self.bvecs[:value.bit_length()]
return MUint(tuple(
bit if not (bit := (value >> i) & 1) else x
for i, x in enumerate(bvecs)))
def __rxor__(self, value: 'MUint | int'):
return self.__xor__(value)
def __ror__(self, value: int):
return self.__or__(value)
def __rand__(self, value: int):
return self.__and__(value)
def __lshift__(self, value: int):
assert isinstance(value, int) and value >= 0
return MUint((0,) * value + self.bvecs)
def __rshift__(self, value: int):
assert isinstance(value, int) and value >= 0
return MUint(self.bvecs[value:])
# custom operations that make sense in any int
def bits(self, start: int, count: int):
''' extract n bits starting at bit k '''
assert start >= 0 and count >= 0
return MUint(self.bvecs[start:][:count]).to_fixed(count)
def bit(self, bit: int):
''' extract bit k of x '''
return self.bits(bit, 1)
@staticmethod
def concat(*xs: 'FixedMUint'):
''' concatenate the bits of several uints (the most significant one is passed first) '''
bvecs = []
for x in reversed(xs):
assert isinstance(x, FixedMUint)
bvecs += x.bvecs
return FixedMUint(tuple(bvecs))
class FixedMUint(MUint):
    ''' an unsigned integer that carries an explicit bit width '''

    @staticmethod
    def identity(start: int, width: int):
        '''
        obtain a FixedMUint of `width` bits whose bit i evaluates to input bit
        `start + i`, i.e. the input shifted right by `start` and truncated.
        '''
        # input bit j is selected by bvec bit j+1, since bvec bit 0 is the constant term
        return FixedMUint(tuple(1 << (start + 1 + i) for i in range(width)))

    class Allocator:
        '''
        hands out consecutive, non-overlapping ranges of input bits: each call
        returns an identity FixedMUint over the next `width` unused input bits.
        '''
        start: int  # first input bit not yet handed out

        def __init__(self, start=0):
            self.start = start

        def __call__(self, width: int):
            result = FixedMUint.identity(self.start, width)
            self.start += width
            return result

    @property
    def width(self):
        ''' number of bits of this value '''
        return self._width

    def apply(self, input: int):
        ''' evaluate to a concrete value that keeps this value's width '''
        # NOTE(review): `FixedUint` is not defined in the visible source; presumably a
        # width-carrying int type provided elsewhere -- confirm it is in scope.
        return FixedUint(MUint.apply(self, input), self.width)

    def ror(self, k: int):
        ''' rotate right by k bits (k is taken modulo width; negative k rotates left) '''
        # `or 1` keeps a zero-width value from dividing by zero (rotating it is a no-op)
        k %= self.width or 1
        return type(self)(self.bvecs[k:] + self.bvecs[:k])

    # XORing two FixedMUints of the same width preserves it
    def __xor__(self, other: 'MUint | int'):
        result = MUint.__xor__(self, other)
        if isinstance(other, FixedMUint) and other.width == self.width:
            result = type(self)(result.bvecs)
        return result

    def __invert__(self):
        ''' bitwise not within the fixed width; the result keeps the same width '''
        # toggling the constant term (bit 0) of a row negates that output bit
        return type(self)(tuple(bvec ^ 1 for bvec in self.bvecs))

    def repeat(self, count: int):
        ''' concatenate 'count' copies of self '''
        return type(self)(self.bvecs * count)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment