Created
April 19, 2021 14:26
-
-
Save attilaolah/4bb5d5de607bb42922f48436b64f83c5 to your computer and use it in GitHub Desktop.
MASS plot file verifier
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""MASS database format verifier.""" | |
import hashlib | |
import os.path | |
import sys | |
# Command-line help text; the placeholder receives the program name.
USAGE = """Usage:
{} proof_dir/*.massdb
"""

# Size in bytes of the fixed header at the start of every MassDB file.
HEADER_LEN = 4096

# Magic code opening every header: double SHA-256 of b"MASSDB".
HEADER_CODE = hashlib.sha256(
    hashlib.sha256(b"MASSDB").digest()
).digest()

# Accepted table bit lengths: the even numbers from 24 through 40.
VALID_BIT_LEN = tuple(n for n in range(24, 42) if n % 2 == 0)

# TODO: The type codes were expected to be (0, 1), but the files observed so
# far store 1 for A tables and 2 for B tables; the reason is not yet known,
# so accept the observed values.
VALID_DB_TYPE = (1, 2)  # (A, B) db types
def main(args: list[str]) -> None:
    """Verify every MassDB file named in the argument vector.

    Args:
        args: Full argument vector as in ``sys.argv``: ``args[0]`` is the
            program name; every following entry is a path to verify.

    Raises:
        ValueError: If ``args`` is empty (not even a program name).
    """
    if not args:
        raise ValueError("Empty args list!")
    if len(args) == 1:
        print(USAGE.format(args[0]))
        return

    for path in args[1:]:
        filename = os.path.basename(path)
        # Flush so the name is visible while a (possibly slow) full table
        # check runs: without a newline, stdout may otherwise stay buffered.
        print(f"Checking {filename}:", end=" ", flush=True)
        try:
            errors = verify_massdb_file(path)
        except OSError as err:
            # A missing/unreadable file must not abort the remaining checks.
            errors = [f"Could not read file: {err}."]
        if errors:
            print("FAIL! Errors:")
            for error in errors:
                print(f" - {error}")
        else:
            print("PASS")
def verify_massdb_file(path: str, full_check: bool = True) -> list[str]:
    """Verify a MassDB file.

    The fixed-size header is checked field by field. All integers are
    little-endian, and the offsets follow from the parsing below:

        [0:32]     file code, must equal HEADER_CODE
        [32:40]    version, must be 1
        [40:41]    bit length, must be in VALID_BIT_LEN
        [41:42]    database type, must be in VALID_DB_TYPE
        [42:50]    checkpoint, must be 2 ** (bit_len - 1) ("complete" file)
        [50:82]    public key hash
        [82:115]   public key (its double SHA-256 must equal the hash)
        [115:HEADER_LEN]  padding, must be all zero

    When ``full_check`` is set and the header's bit length and type are
    valid, the table data after the header is verified as well.

    Args:
        path: Path of the MassDB file.
        full_check: Also verify the table contents after the header.

    Returns:
        Human-readable error strings; an empty list means the file passed.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(path, "rb") as dbf:
        data = dbf.read(HEADER_LEN)
    size = len(data)
    if size != HEADER_LEN:
        return [f"File is too small: expected at least {HEADER_LEN}b, found: "
                f"{size}b."]

    errors: list[str] = []

    pos, size = 0, len(HEADER_CODE)
    header_code = data[pos:pos+size]
    if header_code != HEADER_CODE:
        # bytes.hex() keeps two digits per byte, so the dump is unambiguous.
        exp = HEADER_CODE.hex().upper()
        got = header_code.hex().upper()
        errors.append(f"Bad file code at [{pos}:{pos+size}]: expected {exp}, "
                      f"found: {got}.")

    pos, size = pos + size, 8
    version = int.from_bytes(data[pos:pos+size], byteorder="little")
    if version != 1:
        errors.append(f"Bad file version at [{pos}:{pos+size}]: expected 1, "
                      f"found: {version}.")

    pos, size = pos + size, 1
    bit_len = int.from_bytes(data[pos:pos+size], byteorder="little")
    if bit_len not in VALID_BIT_LEN:
        errors.append(f"Bad bit length at [{pos}:{pos+size}]: expected one of "
                      f"{VALID_BIT_LEN}, found: {bit_len}.")

    pos, size = pos + size, 1
    db_type = int.from_bytes(data[pos:pos+size], byteorder="little")
    if db_type not in VALID_DB_TYPE:
        errors.append(f"Bad database type at [{pos}:{pos+size}]: expected one "
                      f"of {VALID_DB_TYPE}, found: {db_type}.")

    exp = 2 ** (bit_len - 1)
    pos, size = pos + size, 8
    checkpoint = int.from_bytes(data[pos:pos+size], byteorder="little")
    if checkpoint != exp:
        errors.append(f"Incomplete file: at position [{pos}:{pos+size}], "
                      f"expected checkpoint {exp}, found: {checkpoint}")

    pos, size = pos + size, 32
    pub_key_hash = data[pos:pos+size]
    pos, size = pos + size, 33
    pub_key = data[pos:pos+size]
    errors.extend(verify_pub_key_hash(pub_key, pub_key_hash))

    pos += size
    if any(data[pos:HEADER_LEN]):
        errors.append(f"Found non-zero bytes in padding [{pos}:{HEADER_LEN}].")

    # Only reconstruct tables for a usable header: the bit length is a raw
    # byte (up to 255), and table sizes grow as 2 ** bit_len.
    if full_check and bit_len in VALID_BIT_LEN and db_type in VALID_DB_TYPE:
        with open(path, "rb") as dbf:
            dbf.seek(HEADER_LEN)
            body = dbf.read()
        if db_type == 1:
            try:
                errors += check_type_a(body)
            except NotImplementedError:
                # Report instead of crashing a multi-file run; contents of
                # this file were not verified, so it must not PASS.
                errors.append("Type A table check is not implemented; "
                              "table contents were not verified.")
        elif db_type == 2:
            errors += check_type_b(body, pub_key_hash, bit_len)
    return errors
def check_type_a(data: bytes) -> list[str]:
    """Check Type A table contents.

    Placeholder: verification of Type A tables has not been written yet,
    so this never inspects ``data``.

    Raises:
        NotImplementedError: Always.
    """
    reason = "TODO!"
    raise NotImplementedError(reason)
def check_type_b(data: bytes, pub_key_hash: bytes, bit_len: int,
                 max_reported: int = 10) -> list[str]:
    """Check Type B table contents.

    Rebuilds the expected Type B table from a freshly plotted Type A table
    and compares it with ``data``.

    The table has ``2 ** bit_len`` records of two little-endian values of
    ``byte_len`` bytes each, so ``data`` should be
    ``2 * 2 ** bit_len * byte_len`` bytes long.

    For every Type A row ``r`` holding ``x``, ``x'`` is the value at row
    ``r ^ mask``. The record index ``z`` is
    ``sha256(sha256(b"MASS") + pub_key_hash + x + x')``, truncated to
    ``byte_len`` bytes and masked to ``bit_len`` bits. Record ``z`` holds
    ``(x, x')``; on collisions the later write wins.

    Args:
        data: File contents following the header.
        pub_key_hash: Public key hash from the header.
        bit_len: Bit length from the header.
        max_reported: Maximum number of individual mismatches to list;
            any further ones are summarised in one extra error.

    Returns:
        Human-readable error strings; empty if the table matches.
    """
    errors: list[str] = []
    bit_mask = (1 << bit_len) - 1
    byte_len = (bit_len + 7) // 8
    row_count = 1 << bit_len
    prefix = hashlib.sha256(b"MASS").digest() + pub_key_hash

    exp = (row_count * 2) * byte_len
    if len(data) != exp:
        errors.append(f"Incorrect data length, expected {exp} bytes, found: "
                      f"{len(data)}.")

    tbl_a = plot_table_a(pub_key_hash, bit_len)
    # The generated table should have the correct number of rows:
    assert len(tbl_a) == row_count

    # Flat list with two value slots per record: 2*z -> x, 2*z + 1 -> x'.
    tbl_b = [0] * (row_count * 2)
    for row, val in enumerate(tbl_a):
        val_p = tbl_a[row ^ bit_mask]
        val_z = hashlib.sha256(
            prefix
            + val.to_bytes(byte_len, byteorder="little")
            + val_p.to_bytes(byte_len, byteorder="little")
        ).digest()[:byte_len]
        # Whole-byte truncation may exceed bit_len bits when bit_len is not a
        # multiple of 8; mask so z is a valid record index (no-op otherwise).
        # NOTE(review): assumes the low bits are kept -- confirm vs. spec.
        z_key = int.from_bytes(val_z, byteorder="little") & bit_mask
        tbl_b[2 * z_key] = val
        tbl_b[2 * z_key + 1] = val_p

    mismatches = 0
    for slot, val in enumerate(tbl_b):
        want = val.to_bytes(byte_len, byteorder="little")
        got = data[slot * byte_len:(slot + 1) * byte_len]
        if want == got:
            continue
        mismatches += 1
        if mismatches <= max_reported:
            errors.append(f"Table B @ row {slot // 2} (value {slot % 2}): "
                          f"expected {want.hex()}, got: {got.hex()}.")
    if mismatches > max_reported:
        errors.append(f"Table B: {mismatches - max_reported} more "
                      f"mismatching values not listed ({mismatches} total).")
    return errors
def plot_table_a(pub_key_hash: bytes, bit_len: int) -> list[int]:
    """Reconstruct a Type A table from scratch.

    Every ``x`` in ``[0, 2 ** bit_len)`` is hashed as
    ``sha256(sha256(b"MASS") + pub_key_hash + x)``, with ``x`` encoded
    little-endian in ``byte_len`` bytes. The digest is truncated to
    ``byte_len`` bytes, read little-endian and masked to ``bit_len`` bits,
    and ``x`` is stored at that key.

    Values are visited in complementary pairs ``(x, x ^ mask)`` for ``x``
    in the lower half. On key collisions the later write wins, and slots
    never written hold 0.

    Args:
        pub_key_hash: Public key hash from the file header.
        bit_len: Table size exponent; the table has ``2 ** bit_len`` rows.

    Returns:
        The table as a list of ``2 ** bit_len`` integers.
    """
    bit_mask = (1 << bit_len) - 1
    byte_len = (bit_len + 7) // 8
    half_count = 1 << (bit_len - 1)
    prefix = hashlib.sha256(b"MASS").digest() + pub_key_hash

    def key_of(x: int) -> int:
        # Hash x and reduce the digest to a row index.
        digest = hashlib.sha256(
            prefix + x.to_bytes(byte_len, byteorder="little")
        ).digest()
        # Whole-byte truncation may exceed bit_len bits when bit_len is not a
        # multiple of 8; mask to keep the key in range (no-op otherwise).
        # NOTE(review): assumes the low bits are kept -- confirm vs. spec.
        return int.from_bytes(digest[:byte_len], byteorder="little") & bit_mask

    tbl_a: list[int] = [0] * (half_count * 2)
    for x in range(half_count):
        x_p = x ^ bit_mask  # x' = ~x within bit_len bits
        tbl_a[key_of(x)] = x
        tbl_a[key_of(x_p)] = x_p
    return tbl_a
def verify_pub_key_hash(pub_key: bytes, pub_key_hash: bytes) -> list[str]:
    """Verify the public key hash.

    Args:
        pub_key: Public key bytes from the header.
        pub_key_hash: Hash stored in the header; must equal
            ``sha256(sha256(pub_key))``.

    Returns:
        An empty list if the hash matches. Otherwise, a single error string
        showing both hashes as upper-case hex.
    """
    checksum = hashlib.sha256(hashlib.sha256(pub_key).digest()).digest()
    if checksum == pub_key_hash:
        return []
    # bytes.hex() keeps two digits per byte, so the dump is unambiguous.
    exp = pub_key_hash.hex().upper()
    got = checksum.hex().upper()
    return [f"Bad public key hash, expected: {exp}, got: {got}."]
if __name__ == "__main__":
    # Script entry point: pass the raw argument vector (program name first).
    main(args=sys.argv)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment