Skip to content

Instantly share code, notes, and snippets.

@tompng
Last active July 18, 2025 18:26
Show Gist options
  • Save tompng/b984ec29987bf55820c9ef675bfb12d2 to your computer and use it in GitHub Desktop.
Save tompng/b984ec29987bf55820c9ef675bfb12d2 to your computer and use it in GitHub Desktop.
# Arithmetic mixin shared by Var, Val and Op.
#
# Every binary operator accepts either another operand or a plain Integer
# (which is wrapped in a Val), applies a few algebraic shortcuts
# (x + 0, x * 1, 0 * x, ...) and otherwise builds an Op node. Two Val
# operands are folded into a single Val immediately.
#
# Including classes must provide `zero?`, `one?` and `minmax`.
module Operand
  # Folds two constants, or records the operation as a new Op node.
  def default_bin(op, other)
    return Op.new(op, self, other) unless is_a?(Val) && other.is_a?(Val)

    Val.new(value.send(op, other.value))
  end

  def +(other)
    rhs = wrap_literal(other)
    return rhs if zero?
    return self if rhs.zero?

    default_bin(:+, rhs)
  end

  def -(other)
    rhs = wrap_literal(other)
    return -rhs if zero?
    return self if rhs.zero?

    default_bin(:-, rhs)
  end

  def *(other)
    rhs = wrap_literal(other)
    return Val.new(0) if zero? || rhs.zero?
    return rhs if one?
    return self if rhs.one?

    default_bin(:*, rhs)
  end

  def /(other)
    rhs = wrap_literal(other)
    return Val.new(0) if zero?

    default_bin(:/, rhs)
  end

  def %(other)
    rhs = wrap_literal(other)
    return Val.new(0) if zero?

    default_bin(:%, rhs)
  end

  # Negation: constants are negated in place, anything else becomes a unary Op.
  def -@
    return Val.new(-value) if is_a?(Val)

    Op.new(:-@, self)
  end

  def +@
    self
  end

  # True when the operand's value range does not fit a signed 32-bit integer.
  def long?
    lo, hi = minmax
    !(lo >= -2**31 && hi <= 2**31 - 1)
  end

  private

  # Wraps plain Integers in a Val; every other object is passed through.
  def wrap_literal(other)
    other.is_a?(Integer) ? Val.new(other) : other
  end
end
# Named symbolic variable with a known inclusive value range.
class Var
  include Operand

  attr_reader :name, :min, :max

  # name - identifier stored for later use by the code generator
  # min  - smallest value the variable may take (default 0)
  # max  - largest value the variable may take (default 2**32-1)
  def initialize(name, min: 0, max: 2**32-1)
    @name = name
    @min = min
    @max = max
  end

  # A variable is never treated as the constant 1.
  def one?
    false
  end

  # A variable is never treated as the constant 0.
  def zero?
    false
  end

  # Inclusive [min, max] range of the variable.
  def minmax
    [@min, @max]
  end

  # Hash derived from the variable name only.
  def hash = name.hash
end
# Integer constant operand.
class Val
  include Operand

  attr_reader :value

  # value - the Integer this constant stands for; anything else raises.
  def initialize(value)
    raise unless value.is_a?(Integer)

    @value = value
  end

  def one?
    @value == 1
  end

  def zero?
    @value == 0
  end

  # A constant's range collapses to a single point.
  def minmax
    [@value, @value]
  end

  # Hash derived from the wrapped Integer.
  def hash = @value.hash
end
# Expression node: an operator applied to one or two operands.
#
# The inclusive value range [min, max] of the result is computed once at
# construction time from the ranges of the arguments.
class Op
  include Operand

  attr_reader :op, :args, :min, :max

  # op    - operator symbol (:+, :-, :*, :/, :%, :-@ or :+@)
  # left  - first operand
  # right - second operand, omitted for unary operators
  def initialize(op, left, right=nil)
    @op = op
    @args = [left, right].compact
    @min, @max = calculate_minmax
  end

  def one? = false
  def zero? = false
  def minmax = [@min, @max]

  # Structural hash over the operator and the argument hashes (memoized).
  def hash
    @hash ||= [@op, @args.map(&:hash)].hash
  end

  # Returns the [min, max] range of this node using interval arithmetic on
  # the argument ranges.
  #
  # Raises ZeroDivisionError when the divisor of a :/ node can only be zero.
  def calculate_minmax
    (lmin, lmax), (rmin, rmax) = @args.map(&:minmax)
    case @op
    in :+@
      [lmin, lmax]
    in :-@
      # Negation mirrors the interval: both bounds change sign and swap.
      [-lmax, -lmin]
    in :+ | :- | :*
      # For these operators the extremes are reached at interval endpoints.
      [lmin, lmax].product([rmin, rmax]).map { |l, r| l.send(@op, r) }.minmax
    in :/
      # Quotients are extreme at the divisor endpoints and at the divisors
      # closest to zero on either side; zero itself is never a valid divisor.
      candidates = [rmin, rmax]
      candidates << -1 if rmin < 0 && rmax >= 0
      candidates << 1 if rmin <= 0 && rmax > 0
      divisors = candidates.reject(&:zero?)
      raise ZeroDivisionError, "divided by 0" if divisors.empty?

      [lmin, lmax].product(divisors).map { |l, r| l / r }.minmax
    in :%
      [0, rmax]
    end
  end
end
# Counts the distinct Op nodes reachable from +values+.
#
# Each node is visited once (tracked through a Hash keyed by the node), so
# sub-expressions shared between several values are counted a single time.
def count_op(values)
  seen = {}
  pending = values.to_a.dup
  total = 0
  until pending.empty?
    node = pending.pop
    next if seen[node]

    seen[node] = true
    next unless node.is_a?(Op)

    total += 1
    pending.concat(node.args)
  end
  total
end
def generate_code(assigns, values)
node_counts = Hash.new(0)
pre_visit = ->(node) {
next unless node.is_a?(Op)
node_counts[node] += 1
node.args.each(&pre_visit) if node_counts[node] == 1
}
values.each(&pre_visit)
vars = {}
reused = node_counts.select { |_, count| count > 1 }.transform_values { nil }
genexp = ->(node) {
return reused[node] if reused[node]
exp = (
case node
in Var
node.name
in Val
node.value < 0 ? "(#{node.value})" : node.value.to_s
in Op
l, r = node.args
le, re = node.args.map(&genexp)
if node.op == :-@
"(-#{le})"
elsif node.long? && !l.long? && r.long?
"((long long)#{le}#{node.op}#{re})"
else
"(#{le}#{node.op}#{re})"
end
end
)
if reused.key?(node)
v = "v#{vars.size}"
vars[v] = exp
reused[node] = v
else
exp
end
}
value_exps = values.map(&genexp)
code = [
vars.map {|v, exp| "#{v} = #{exp}" }.each_slice(4).map do
"long long #{it.join(", ")};"
end,
assigns.zip(value_exps).map {|assign, exp| "#{assign} = #{exp};" }.each_slice(4).map{it.join(' ')}
].join("\n")
end
# Multiplies two coefficient arrays (least significant first) with the
# subtractive Karatsuba scheme.
#
# a, b      - Arrays of operands supporting +, - and *
# threshold - operands shorter than this are handed to normal_mult
#
# Both operands are split at the same position and each half is padded with
# Val.new(0) up to that length, so operands of different sizes are accepted.
# The result may carry trailing zero coefficients beyond a.size + b.size - 1.
#
# Returns the Array of product coefficients.
def karatsuba_mult(a, b, threshold = 6)
  # Splitting only shrinks the problem for sizes >= 2; a smaller threshold
  # would recurse forever.
  cutoff = [threshold, 2].max
  return normal_mult(a, b) if a.size < cutoff || b.size < cutoff

  half = [(a.size + 1) / 2, (b.size + 1) / 2].max
  split = ->(coeffs) {
    [coeffs.first(half), coeffs.drop(half)].map do |part|
      part + Array.new(half - part.size) { Val.new(0) }
    end
  }
  a_low, a_high = split.call(a)
  b_low, b_high = split.call(b)

  low = karatsuba_mult(a_low, b_low, threshold)
  high = karatsuba_mult(a_high, b_high, threshold)
  a_diff = a_low.zip(a_high).map { |lo, hi| lo - hi }
  b_diff = b_low.zip(b_high).map { |lo, hi| lo - hi }
  cross = karatsuba_mult(a_diff, b_diff, threshold)
  # a_low*b_high + a_high*b_low == low + high - (a_low - a_high)*(b_low - b_high)
  middle = low.zip(high, cross).map { |l, h, x| l + h - x }

  # Overlay the three partial products at offsets 0, half and 2*half.
  result = []
  [low, middle, high].each_with_index do |part, block|
    part.each_with_index do |coef, offset|
      idx = block * half + offset
      result[idx] = result[idx] ? result[idx] + coef : coef
    end
  end
  result
end
# Schoolbook multiplication of two coefficient arrays.
#
# Entry k of the result accumulates every product a[i] * b[j] with
# i + j == k, in the order the pairs are visited (i outer, j inner).
def normal_mult(a, b)
  a.each_with_index.with_object([]) do |(left, i), acc|
    b.each_with_index do |right, j|
      term = left * right
      slot = i + j
      previous = acc[slot]
      acc[slot] = previous ? previous + term : term
    end
  end
end
# Builds the C source of a fully unrolled multiplication routine for numbers
# stored as arrays of base-10000 digits.
#
# nmin, nmax - digit counts supported by the generated function. With
#              nmin == nmax a fixed-size `mult_<n>` is produced; otherwise a
#              `mult_<nmin>_<nmax>` that also takes the actual lengths an/bn.
#
# Returns the C function as a String.
def generate_inline_mult(nmin, nmax = nmin)
n = nmax
dig = 10000
# Symbolic input digits a0.., b0.., each known to lie in 0...dig.
a = n.times.map { |i| Var.new("a#{i}", min: 0, max: dig-1) }
b = n.times.map { |i| Var.new("b#{i}", min: 0, max: dig-1) }
values = karatsuba_mult(a, b, 6)
# Carry propagation: every position keeps sum % dig and hands sum / dig to
# the next one; the last carry becomes one more output digit.
carry = 0
(0...values.size).each do |i|
sum = values[i] + carry
values[i] = sum % dig
carry = sum / dig
end
values << carry
if nmin == nmax
# Fixed-size variant: all 2*n output digits are stored unconditionally.
assigns = (2 * n).times.map { |i| "c[#{i}]" }
<<~CODE
void mult_#{n}(int *a, int *b, int *c) {
int #{n.times.map{"a#{it}=a[#{it}]"}.join(", ")};
int #{n.times.map{"b#{it}=b[#{it}]"}.join(", ")};
#{generate_code(assigns, values).gsub(/\n/, "\n ")}
}
CODE
else
# Variable-size variant: digits from index nmin*2 on are stored inside
# `if (abn > i) {` blocks. Each such target opens a C block that is left
# open here and closed by the run of '}' emitted at the end of the function.
assigns = (nmin * 2).times.map { "c[#{it}]" }
assigns += (nmin * 2...nmax * 2).map { "if (abn > #{it}) { c[#{it}]" }
<<~CODE
void mult_#{nmin}_#{nmax}(int an, int bn, int *a, int *b, int *c) {
int #{n.times.map{"a#{it}"}.join(', ')};
int #{n.times.map{"b#{it}"}.join(', ')};
#{(0...nmin).map { "a#{it} = a[#{it}]" }.join('; ')};
#{(0...nmin).map { "b#{it} = b[#{it}]" }.join('; ')};
#{(nmin...nmax).map{"a#{it}=an>#{it}?a[#{it}]:0"}.join('; ')};
#{(nmin...nmax).map{"b#{it}=bn>#{it}?b[#{it}]:0"}.join('; ')};
int abn = an + bn;
#{generate_code(assigns, values).gsub(/\n/, "\n ")}
#{'}' * (nmax-nmin) * 2}
}
CODE
end
end
# Script entry: compare operation counts, then emit the generated C file.
n = 16
a = n.times.map{Var.new("a#{it}")}
b = n.times.map{Var.new("b#{it}")}
# Print the number of distinct Op nodes for schoolbook vs. Karatsuba at n = 16.
p count_op(normal_mult(a, b))
p count_op(karatsuba_mult(a, b, 6))
# Write three size-ranged multipliers plus a dispatcher that switches on `an`.
# NOTE(review): the dispatcher selects the routine from `an` only; a `bn`
# larger than the selected range is presumably not expected - confirm.
File.write("inline_mult.c", <<~C)
#{generate_inline_mult(16, 19)}
#{generate_inline_mult(20, 25)}
#{generate_inline_mult(26, 32)}
void mult_16_32(int an, int bn, int *a, int *b, int *c) {
switch(an){
#{(16..19).map { |n| "case #{n}:" }.join("\n ")}
mult_16_19(an, bn, a, b, c); break;
#{(20..25).map { |n| "case #{n}:" }.join("\n ")}
mult_20_25(an, bn, a, b, c); break;
#{(26..32).map { |n| "case #{n}:" }.join("\n ")}
mult_26_32(an, bn, a, b, c); break;
}
}
C
#include <stdio.h>
#include <sys/time.h>
#include "inline_mult.c"
#include "inline_mult_hand.c"
/*
 * Reference schoolbook multiplication of two base-10000 digit arrays.
 *
 * a has an digits, b has bn digits (least significant first); the an + bn
 * result digits are written to c. Partial sums are kept in a 64-entry
 * long long scratch array before the carry pass.
 */
void mult_loop(int an, int bn, int *a, int *b, int *c) {
    long long acc[64] = {0};
    const int total = an + bn;
    int i, j, k;

    /* Accumulate every digit product into its column. */
    for (i = 0; i < an; i++) {
        const int ai = a[i];
        for (j = 0; j < bn; j++) {
            acc[i + j] += ai * b[j];
        }
    }

    /* Normalize each column to one base-10000 digit, carrying upwards. */
    long long carry = 0;
    for (k = 0; k < total; k++) {
        const long long column = carry + acc[k];
        c[k] = column % 10000;
        carry = column / 10000;
    }
}
#define N 100000
int main(int argc, char *argv[]) {
int a[32] = {1234, 5678, 9101, 1121, 3141, 5161, 7181, 9202, 2233, 4455, 6677, 8899, 1010, 1213, 1415, 1617, 1234, 5678, 9101, 1121, 3141, 5161, 7181, 9202, 2233, 4455, 6677, 8899, 1010, 1213, 1415, 1617};
int b[32] = {1716, 1514, 1312, 1110, 9987, 7765, 5543, 3321, 2109, 8765, 5432, 1098, 7654, 3210, 9876, 5432, 1716, 1514, 1312, 1110, 9987, 7765, 5543, 3321, 2109, 8765, 5432, 1098, 7654, 3210, 9876, 5432};
int size_from = 16;
int size_to = 32;
if (argc > 100) {
for (int i = 0; i < 32; i++) a[i] += (argc + i) % (10 + i);
size_from += argc;
size_to += argc;
}
struct timespec start_time, end_time;
int sum1 = 0;
int sum2 = 0;
long long tsum1 = 0;
long long tsum2 = 0;
for (int size = size_from; size <= size_to; size++) {
printf("size=%d\n", size);
clock_gettime(CLOCK_REALTIME, &start_time);
for(int i=0;i<N;i++){
int c[64] = {0};
mult_16_32(size, size, a, b, c);
for (int j = 0; j < 64; j++) sum1 += c[j];
}
clock_gettime(CLOCK_REALTIME, &end_time);
long long t1 = (end_time.tv_sec - start_time.tv_sec) * 1000000000 + (end_time.tv_nsec - start_time.tv_nsec);
printf("time: %lld\n", t1);
tsum1 += t1;
clock_gettime(CLOCK_REALTIME, &start_time);
for(int i=0;i<N;i++){
int c[64] = {0};
mult_loop(size, size, a, b, c);
for (int j = 0; j < 64; j++) sum2 += c[j];
}
clock_gettime(CLOCK_REALTIME, &end_time);
long long t2 = (end_time.tv_sec - start_time.tv_sec) * 1000000000 + (end_time.tv_nsec - start_time.tv_nsec);
tsum2 += t2;
printf("time: %lld\n", t2);
}
printf("t1=%lld, t2=%lld\n", tsum1, tsum2);
printf("sum1=%d, sum2=%d\n", sum1, sum2);
{
int c[64] = {0};
mult_16_32(30, 30, a, b, c);
for (int i = 0; i < 64; i++) printf("%d ", c[i]);
printf("\n");
}
{
int c[64] = {0};
mult_loop(30, 30, a, b, c);
for (int i = 0; i < 64; i++) printf("%d ", c[i]);
printf("\n");
}
printf("\n");
return 0;
}
require 'matrix'
# Toom-Cook parameters for multiplying an an-segment number by a bn-segment
# number.
#
# Evaluation points are taken in the order 0, infinity, 1, -1, 2, -2, 1/2,
# -1/2, 3, ... until there are an + bn - 1 of them.
#
# Returns [mults, matrix]:
#   mults  - one [a_weights, b_weights] pair per point; the Integer weights
#            evaluate each operand at that point (scaled by a power of the
#            point's denominator so they stay integral)
#   matrix - inverse of the matrix built from the same weights for the
#            an + bn - 1 product coefficients
def toomcook_params(an, bn)
  cn = an + bn - 1
  points = [0, Float::INFINITY, 1, -1]
  until points.size >= cn
    k = points.size / 4 + 1
    points.push(k, -k, 1r / k, -1r / k)
  end
  points = points.first(cn)

  # Weights of the coefficients of a polynomial with +len+ terms at point x.
  weights_at = lambda do |len, x|
    if x == Float::INFINITY
      [0] * (len - 1) + [1]
    else
      scale = x.denominator**(len - 1)
      (0...len).map { |power| (x**power * scale).to_i }
    end
  end

  mults = points.map { |x| [weights_at.call(an, x), weights_at.call(bn, x)] }
  rows = points.map { |x| weights_at.call(cn, x) }
  [mults, Matrix[*rows].inv]
end
# Builds the 5x5 parameters once and discards the result.
# NOTE(review): looks like a leftover smoke test - confirm it can be removed.
toomcook_params 5,5;
# Multiplies two coefficient arrays (least significant first) through
# Toom-Cook evaluation and interpolation.
#
# a, b - Arrays of numeric coefficients; the parameters are derived from
#        their lengths, so the operands may have any (non-zero) size.
#
# Returns a Vector holding the a.size + b.size - 1 product coefficients.
def mult_with_toomcook(a, b)
  mults, matrix = toomcook_params(a.size, b.size)
  # One pointwise product per evaluation point.
  products = mults.map do |a_weights, b_weights|
    a_weights.zip(a).sum { |w, x| w * x } * b_weights.zip(b).sum { |w, x| w * x }
  end
  # Interpolate the product coefficients from the pointwise products.
  matrix * Vector[*products]
end
# Toom-Cook experiment on BigDecimal operands: both numbers consist of four
# segments of 2500 '1' digits each.
#
# Returns the product, reassembled from the interpolated coefficients, as a
# BigDecimal.
def bd_mult_with_toomcook
  segments = 4
  digits = 2500
  mults, matrix = toomcook_params(segments, segments)
  a = Array.new(segments) { BigDecimal('1' * digits) }
  b = Array.new(segments) { BigDecimal('1' * digits) }

  # Pointwise products at every evaluation point.
  products = mults.map do |a_weights, b_weights|
    a_weights.zip(a).sum { |w, x| w * x } * b_weights.zip(b).sum { |w, x| w * x }
  end

  # Interpolation: each matrix row is scaled to integers by the lcm of its
  # denominators, and the weighted sum is divided by that lcm afterwards.
  coefficients = matrix.to_a.map do |row|
    common = row.map(&:denominator).reduce(:lcm)
    row.zip(products).sum { |entry, product| (entry * common).to_i * product } / common
  end

  # Place coefficient i at decimal offset i * digits and add everything up.
  coefficients.each_with_index.sum { |coef, idx| coef * BigDecimal("1e#{idx * digits}") }
end
require 'bigdecimal'
# Open an interactive IRB session with everything defined so far.
binding.irb
# Stop the script here.
# NOTE(review): everything after this `exit` is unreachable unless the line
# is removed - confirm whether the calls below are still wanted.
exit
bd_mult_with_toomcook
p mult_with_toomcook([1,2,3,4,5], [3,4,7,1,2])
binding.irb
# Emits the C source of `toomcook_multiply_<an>_<bn>`, which splits its
# operands into an resp. bn segments of length s and issues one recursive
# `toomcook_multiply` call per evaluation point.
#
# an, bn - number of segments for the first / second operand
#
# Returns the C function as a String.
#
# NOTE(review): the generator looks unfinished - the evaluation weights
# (amult, bmult) and the interpolation matrix are fetched but not used yet,
# and the emitted code neither fills the evaluation sums nor interpolates.
def generate_multiplication_code(an, bn)
mults, matrix = toomcook_params(an, bn)
cn = an + bn - 1
<<~C_CODE
void toomcook_multiply_#{an}_#{bn}(int a_size, const int *a, int b_size, const int *b, int *c, int *tmp) {
int s = Max((a_size+#{an-1})/#{an}, (b_size+#{bn-1})/#{bn});
int a0_size = a_size - s*#{an-1};
int b0_size = b_size - s*#{bn-1};
int *a0=a,#{(1...an).map { |i| "*a#{i}=a+a0_size+s*#{i-1}" }.join(', ')};
int *b0=b,#{(1...bn).map { |i| "*b#{i}=b+b0_size+s*#{i-1}" }.join(', ')};
int #{(0..cn+2).map { |i| "*c#{i}=tmp+s*#{2*i+2}" }.join(', ')};
toomcook_multiply(a0_size, a0, b0_size, b0, c0, c1);
toomcook_multiply(s, a1, s, b1, c1, c2);
#{
(2...cn).map do |i|
amult, bmult = mults[i]
asum_target = "c#{i+1}"
bsum_target = "c#{i+2}"
mult_target = "c#{i}"
mult_tmp = "c#{i+1}"
<<~C_CODE
toomcook_multiply(s, #{asum_target}, s, #{bsum_target}, #{mult_target}, #{mult_tmp});
C_CODE
end.join("\n").gsub(/^/, ' '*2)
}
}
C_CODE
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment