Created
May 31, 2025 02:02
-
-
Save Rexicon226/64a0d6f5e20a90985e7e707424fcae8b to your computer and use it in GitHub Desktop.
Pippenger MSM with no allocation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Returns how many digit slots a 256-bit value needs when split into
/// `w`-bit windows: `ceil(256 / w)`, plus one extra slot when `w == 8`.
///
/// Only window widths 4 through 8 are supported; any other width is
/// treated as unreachable.
inline fn radixSizeHint(w: u64) u64 {
    if (w < 4 or w > 8) unreachable;

    const window_count = (@as(u64, 256) + w - 1) / w;
    return if (w == 8) window_count + 1 else window_count;
}
/// Returns `2^k * p`, computed by doubling `p` exactly `k` times.
///
/// `k == 0` yields `p` unchanged. The previous form looped `k - 1` times
/// and then doubled once more, which underflowed the unsigned `k - 1`
/// when `k` was zero.
inline fn mulByPow2(p: Ed25519, k: u32) Ed25519 {
    var acc = p;
    for (0..k) |_| acc = acc.dbl();
    return acc;
}
/// Recodes the 256-bit scalar `c` into signed radix-`2^w` digits.
///
/// The scalar bytes are read as four little-endian 64-bit limbs, so the
/// result does not depend on host byte order.
/// NOTE(review): little-endian matches the usual Ed25519 scalar encoding;
/// confirm against the definition of `CompressedScalar`.
///
/// `w` must be in `4...8`. The returned array always has 64 slots; only the
/// first `ceil(256 / w)` slots (plus one more when `w == 8`) are written and
/// the remainder stay zero. Digits are emitted least-significant first, and
/// each loop iteration rounds its window into the range
/// `[-2^(w-1), 2^(w-1))`, pushing the excess into the next window as a carry.
fn asRadix(c: CompressedScalar, w: u6) [64]i8 {
    // Outside 4...8 the code below would divide by zero (w == 0), index past
    // the 64 output slots (w < 4), or truncate digits that no longer fit in
    // an i8 (w > 8).
    std.debug.assert(w >= 4 and w <= 8);

    var limbs: [4]u64 = undefined;
    for (&limbs, 0..) |*limb, i| {
        limb.* = std.mem.readInt(u64, c[i * 8 ..][0..8], .little);
    }

    const radix = @as(u64, 1) << w;
    const window_mask = radix - 1;
    const digits_count: usize = (@as(usize, 256) + w - 1) / w;

    var digits: [64]i8 = @splat(0);
    var carry: u64 = 0;
    for (0..digits_count) |i| {
        const bit_offset = i * w;
        const limb_idx = bit_offset / 64;
        const bit_idx: u6 = @intCast(bit_offset % 64);

        // Take the bits from the current limb; when the window straddles a
        // limb boundary (and there is a next limb), pull in the low bits of
        // the following limb as well.
        const low_bits = limbs[limb_idx] >> bit_idx;
        const bit_buf: u64 = if (bit_idx < @as(u64, 64) - w or limb_idx == 3)
            low_bits
        else
            low_bits | (limbs[limb_idx + 1] << @intCast(@as(u64, 64) - bit_idx));

        // Add the incoming carry, then round to the nearest multiple of the
        // radix: the remainder becomes this (signed) digit and the quotient
        // is carried into the next window.
        const coef = carry + (bit_buf & window_mask);
        carry = (coef + (radix / 2)) >> w;
        const signed_coef: i64 = @bitCast(coef);
        digits[i] = @truncate(signed_coef - @as(i64, @bitCast(carry << w)));
    }

    // Dispose of the carry left after the last window: w == 8 stores it in
    // one extra digit slot, every other width folds it back into the top
    // digit.
    switch (w) {
        8 => digits[digits_count] += @intCast(@as(i64, @bitCast(carry))),
        else => digits[digits_count - 1] += @intCast(@as(i64, @bitCast(carry << w))),
    }
    return digits;
}
/// Multi-scalar multiplication using the bucket (Pippenger) method with
/// stack-only storage: returns the sum over `i` of
/// `compressed_scalars[i] * ed_points[i]`.
///
/// `max_elements` is a compile-time upper bound on the number of inputs. It
/// sizes the on-stack digit table and selects the window width `w`
/// (6 below 500 elements, 7 below 800, otherwise 8). The two slices must
/// have equal length, and that length must not exceed `max_elements`; both
/// conditions are asserted.
fn fastMsm(
    comptime max_elements: comptime_int,
    compressed_scalars: []const CompressedScalar,
    ed_points: []const Ed25519,
) Ed25519 {
    std.debug.assert(compressed_scalars.len == ed_points.len);
    std.debug.assert(compressed_scalars.len <= max_elements);

    const w: u6 = if (max_elements < 500)
        6
    else if (max_elements < 800)
        7
    else
        8;
    const digits_count = radixSizeHint(w);
    // Signed digits only need buckets for magnitudes 1..2^(w-1).
    const buckets_count = (@as(u64, 1) << w) / 2;

    // Recode every scalar into signed radix-2^w digits up front.
    var scalar_storage: [max_elements][64]i8 = undefined;
    const scalars = scalar_storage[0..compressed_scalars.len];
    for (scalars, compressed_scalars) |*digits, s| {
        digits.* = asRadix(s, w);
    }

    // columns[0] holds the sum for the most significant digit position.
    var columns: [digits_count]Ed25519 = undefined;
    for (&columns, 0..) |*column, fwd| {
        const digit_index = digits_count - 1 - fwd;
        var buckets: [buckets_count]Ed25519 = @splat(.identityElement);

        for (scalars, ed_points) |digits, pt| {
            // Widen before negating: with w == 8 a digit can be -128, whose
            // negation does not fit in an i8.
            const digit: i16 = digits[digit_index];
            if (digit > 0) {
                const b: usize = @intCast(digit - 1);
                buckets[b] = buckets[b].add(pt);
            } else if (digit < 0) {
                const b: usize = @intCast(-digit - 1);
                buckets[b] = buckets[b].sub(pt);
            }
        }

        // Running-sum trick: walking from the highest bucket down, `running`
        // is the sum of all buckets seen so far, and adding it to `total` on
        // every step weights bucket j by (j + 1) without any multiplication.
        var running = buckets[buckets_count - 1];
        var total = running;
        var i: usize = buckets_count - 1;
        while (i > 0) {
            i -= 1;
            running = running.add(buckets[i]);
            total = total.add(running);
        }
        column.* = total;
    }

    // Horner evaluation in base 2^w, most significant column first.
    var acc = columns[0];
    for (columns[1..]) |column| {
        acc = mulByPow2(acc, w).add(column);
    }
    return acc;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment