Last active
July 18, 2025 18:26
-
-
Save tompng/b984ec29987bf55820c9ef675bfb12d2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Arithmetic mixin for symbolic expression nodes.
#
# Including classes must provide `zero?`, `one?` and `minmax`; constant nodes
# (Val) must also provide `value`. Each operator returns a node: it applies the
# identity/absorbing-element shortcuts below, folds Val-with-Val operations
# into a new Val, and otherwise builds an Op.
module Operand
  # Folds two constants into a Val, or builds an Op node for anything else.
  def default_bin(op, other)
    both_constant = is_a?(Val) && other.is_a?(Val)
    return Op.new(op, self, other) unless both_constant

    Val.new(value.send(op, other.value))
  end

  # 0 + x => x, x + 0 => x.
  def +(other)
    rhs = wrap_integer(other)
    return rhs if zero?
    return self if rhs.zero?

    default_bin(:+, rhs)
  end

  # 0 - x => -x, x - 0 => x.
  def -(other)
    rhs = wrap_integer(other)
    return -rhs if zero?
    return self if rhs.zero?

    default_bin(:-, rhs)
  end

  # Anything times 0 => 0, 1 * x => x, x * 1 => x.
  def *(other)
    rhs = wrap_integer(other)
    return Val.new(0) if zero? || rhs.zero?
    return rhs if one?
    return self if rhs.one?

    default_bin(:*, rhs)
  end

  # 0 / x => 0 (the divisor is not inspected in that case).
  def /(other)
    rhs = wrap_integer(other)
    return Val.new(0) if zero?

    default_bin(:/, rhs)
  end

  # 0 % x => 0 (the divisor is not inspected in that case).
  def %(other)
    rhs = wrap_integer(other)
    return Val.new(0) if zero?

    default_bin(:%, rhs)
  end

  # Negation: constants are negated directly, other nodes get a unary Op.
  def -@
    return Val.new(-value) if is_a?(Val)

    Op.new(:-@, self)
  end

  # Unary plus is the identity.
  def +@
    self
  end

  # True when the node's [min, max] bounds do not fit a signed 32-bit integer.
  def long?
    lowest, highest = minmax
    lowest < -2**31 || highest > 2**31 - 1
  end

  private

  # Wraps a bare Integer in a Val; any other operand is returned unchanged.
  def wrap_integer(operand)
    operand.is_a?(Integer) ? Val.new(operand) : operand
  end
end
# Named symbolic input whose value is known to lie in [min, max].
class Var
  include Operand

  attr_reader :name, :min, :max

  # name: identifier used verbatim in generated code.
  # min/max: inclusive value bounds (defaults span 0..2**32-1).
  def initialize(name, min: 0, max: 2**32 - 1)
    @name = name
    @min = min
    @max = max
  end

  # A variable is never treated as the constant 0.
  def zero?
    false
  end

  # A variable is never treated as the constant 1.
  def one?
    false
  end

  # Inclusive value bounds as a two-element array.
  def minmax
    [min, max]
  end

  # Hash value derived from the name only (eql? is not overridden here).
  def hash
    @name.hash
  end
end
# Integer constant node.
class Val
  include Operand

  attr_reader :value

  # Raises (bare `raise`, i.e. RuntimeError) unless value is an Integer.
  def initialize(value)
    raise unless Integer === value

    @value = value
  end

  # True for the constant 0; lets Operand apply additive/absorbing shortcuts.
  def zero?
    @value.zero?
  end

  # True for the constant 1; lets Operand apply the multiplicative identity.
  def one?
    @value == 1
  end

  # A constant's bounds collapse to the value itself.
  def minmax
    [@value, @value]
  end

  # Hash value derived from the wrapped integer (eql? is not overridden here).
  def hash
    @value.hash
  end
end
# Operation node: a unary or binary operator applied to operand nodes.
# Value bounds are computed once at construction by interval arithmetic on the
# operands' bounds.
class Op
  include Operand

  attr_reader :op, :args, :min, :max

  # op: operator symbol (:+@, :-@, :+, :-, :*, :/ or :%).
  # left/right: operand nodes; right is omitted for unary operators.
  def initialize(op, left, right = nil)
    @op = op
    @args = [left, right].compact
    @min, @max = calculate_minmax
  end

  def one? = false
  def zero? = false
  def minmax = [@min, @max]

  # Memoized hash built from the operator and the operands' hashes
  # (eql? is not overridden here).
  def hash
    @hash ||= [@op, @args.map(&:hash)].hash
  end

  # Returns [min, max] for this node from the operands' bounds.
  # Raises NoMatchingPatternError for an unknown operator and
  # ZeroDivisionError when the divisor's bounds are exactly [0, 0].
  def calculate_minmax
    (lmin, lmax), (rmin, rmax) = @args.map(&:minmax)
    case @op
    in :+@
      [lmin, lmax]
    in :-@
      # Negation mirrors the interval: both ends flip sign and swap places.
      [-lmax, -lmin]
    in :+ | :- | :*
      [lmin, lmax].product([rmin, rmax]).map { _1.send(@op, _2) }.minmax
    in :/
      # Extremes occur at the divisor's endpoints and, when its range touches
      # or crosses zero, at the non-zero divisors closest to zero (-1 / 1).
      divisors = [rmin, rmax]
      divisors << -1 if rmin < 0 && rmax >= 0
      divisors << 1 if rmin <= 0 && rmax > 0
      # A zero endpoint is not a usable divisor; bound over the rest.
      divisors = divisors.reject(&:zero?)
      raise ZeroDivisionError, 'divided by 0' if divisors.empty?

      [lmin, lmax].product(divisors).map { _1 / _2 }.minmax
    in :%
      [0, rmax]
    end
  end
end
# Counts the distinct Op nodes reachable from the given root nodes.
# Nodes are tracked in a Hash, so a node reached through several parents is
# counted once.
def count_op(values)
  seen = {}
  total = 0
  pending = []
  values.each { |root| pending << root }

  until pending.empty?
    node = pending.pop
    next if seen[node]

    seen[node] = true
    next unless node.is_a?(Op)

    total += 1
    pending.concat(node.args)
  end

  total
end
# Renders expression nodes as C statements.
#
# assigns: left-hand-side strings, paired positionally with `values`.
# values:  expression nodes (Var / Val / Op).
#
# Op nodes referenced more than once are hoisted into `long long vN`
# temporaries, declared in dependency order, four per line. The assignments
# follow, four per line. Returns the joined source text.
def generate_code(assigns, values)
  # Pass 1: count references to each Op; descend only on the first visit.
  ref_counts = Hash.new(0)
  count_refs = lambda do |node|
    next unless node.is_a?(Op)

    ref_counts[node] += 1
    node.args.each(&count_refs) if ref_counts[node] == 1
  end
  values.each(&count_refs)

  temps = {} # temporary name => defining expression, in emission order
  temp_names = ref_counts.select { |_, count| count > 1 }.transform_values { nil }

  # Pass 2: emit expressions, replacing shared nodes by their temporary.
  emit = lambda do |node|
    next temp_names[node] if temp_names[node]

    exp =
      case node
      in Var
        node.name
      in Val
        node.value.negative? ? "(#{node.value})" : node.value.to_s
      in Op
        left, = node.args
        le, re = node.args.map(&emit)
        if node.op == :-@
          "(-#{le})"
        elsif node.long? && !left.long?
          # The result needs more than 32 bits but the left operand is narrow:
          # widen it so the C arithmetic is carried out in long long. This also
          # covers the case where both operands are narrow, which would
          # otherwise overflow int in the generated code.
          "((long long)#{le}#{node.op}#{re})"
        else
          "(#{le}#{node.op}#{re})"
        end
      end
    next exp unless temp_names.key?(node)

    name = "v#{temps.size}"
    temps[name] = exp
    temp_names[node] = name
  end

  value_exps = values.map(&emit)
  declarations = temps.map { |name, exp| "#{name} = #{exp}" }
                      .each_slice(4)
                      .map { |group| "long long #{group.join(', ')};" }
  assignments = assigns.zip(value_exps)
                       .map { |target, exp| "#{target} = #{exp};" }
                       .each_slice(4)
                       .map { |group| group.join(' ') }
  [declarations, assignments].join("\n")
end
# Multiplies two coefficient arrays (lowest index first) with Karatsuba's
# three-multiplication split, falling back to normal_mult for short inputs.
#
# a, b:      arrays of operands supporting +, - and *.
# threshold: inputs with fewer elements than this use normal_mult.
#
# Returns the product's coefficient array. It may be longer than
# a.size + b.size - 1 because the upper halves are zero-padded.
def karatsuba_mult(a, b, threshold = 6)
  shortest = [a.size, b.size].min
  # Sizes below 2 cannot be split further; treating them as a base case keeps
  # the recursion finite even for threshold <= 1.
  return normal_mult(a, b) if shortest < threshold || shortest < 2

  n = [(a.size + 1) / 2, (b.size + 1) / 2].max
  # Array#[] returns nil when the start index is past the end, which happens
  # when one operand is shorter than half of the other; treat that as an
  # empty upper half.
  a0 = a[0...n]
  a1 = a[n..] || []
  b0 = b[0...n]
  b1 = b[n..] || []
  a1 << Val.new(0) while a1.size < n
  b1 << Val.new(0) while b1.size < n

  ab0 = karatsuba_mult(a0, b0, threshold)
  ab2 = karatsuba_mult(a1, b1, threshold)
  ad = a0.zip(a1).map { _1 - _2 }
  bd = b0.zip(b1).map { _1 - _2 }
  abd = karatsuba_mult(ad, bd, threshold)
  # a0*b1 + a1*b0 == a0*b0 + a1*b1 - (a0 - a1)*(b0 - b1)
  ab1 = ab0.zip(ab2, abd).map { _1 + _2 - _3 }

  # Recombine: the low, middle and high parts are offset by 0, n and 2n.
  c = []
  [ab0, ab1, ab2].each_with_index do |part, i|
    part.each_with_index do |val, j|
      idx = i * n + j
      c[idx] = c[idx] ? c[idx] + val : val
    end
  end
  c
end
# Schoolbook polynomial multiplication: returns the coefficient array whose
# entry k accumulates every a[i] * b[j] with i + j == k, in row-major order.
def normal_mult(a, b)
  a.each_with_index.with_object([]) do |(left, i), product|
    b.each_with_index do |right, j|
      term = left * right
      slot = i + j
      existing = product[slot]
      product[slot] = existing ? existing + term : term
    end
  end
end
# Builds the C source of a fully unrolled multiplication routine for digit
# arrays in base 10000 (lowest digit first).
#
# nmin, nmax: supported operand lengths. With nmin == nmax the generated
#             function is `mult_<n>(a, b, c)`. Otherwise it is
#             `mult_<nmin>_<nmax>(an, bn, a, b, c)`, which reads digits at
#             index >= nmin as 0 when they lie beyond an/bn and only writes
#             c[i] for i >= 2*nmin when an + bn > i.
# Returns the C source as a String.
def generate_inline_mult(nmin, nmax = nmin) | |
# The expression graph is always built for the largest supported size.
n = nmax | |
dig = 10000 | |
# Symbolic input digits, each bounded to 0..dig-1 so Op can track value ranges.
a = n.times.map { |i| Var.new("a#{i}", min: 0, max: dig-1) } | |
b = n.times.map { |i| Var.new("b#{i}", min: 0, max: dig-1) } | |
values = karatsuba_mult(a, b, 6) | |
# Carry propagation: each column keeps sum % dig and passes sum / dig upward.
carry = 0 | |
(0...values.size).each do |i| | |
sum = values[i] + carry | |
values[i] = sum % dig | |
carry = sum / dig | |
end | |
# The final carry becomes one more output digit.
values << carry | |
if nmin == nmax | |
# Fixed-size variant: every output digit is written unconditionally.
assigns = (2 * n).times.map { |i| "c[#{i}]" } | |
<<~CODE | |
void mult_#{n}(int *a, int *b, int *c) { | |
int #{n.times.map{"a#{it}=a[#{it}]"}.join(", ")}; | |
int #{n.times.map{"b#{it}=b[#{it}]"}.join(", ")}; | |
#{generate_code(assigns, values).gsub(/\n/, "\n ")} | |
} | |
CODE | |
else | |
# Ranged variant: the first 2*nmin digits are always written; each later
# digit opens an `if (abn > i) {` that stays open for the following digits
# and is closed by the run of '}' emitted at the end of the function.
assigns = (nmin * 2).times.map { "c[#{it}]" } | |
assigns += (nmin * 2...nmax * 2).map { "if (abn > #{it}) { c[#{it}]" } | |
<<~CODE | |
void mult_#{nmin}_#{nmax}(int an, int bn, int *a, int *b, int *c) { | |
int #{n.times.map{"a#{it}"}.join(', ')}; | |
int #{n.times.map{"b#{it}"}.join(', ')}; | |
#{(0...nmin).map { "a#{it} = a[#{it}]" }.join('; ')}; | |
#{(0...nmin).map { "b#{it} = b[#{it}]" }.join('; ')}; | |
#{(nmin...nmax).map{"a#{it}=an>#{it}?a[#{it}]:0"}.join('; ')}; | |
#{(nmin...nmax).map{"b#{it}=bn>#{it}?b[#{it}]:0"}.join('; ')}; | |
int abn = an + bn; | |
#{generate_code(assigns, values).gsub(/\n/, "\n ")} | |
#{'}' * (nmax-nmin) * 2} | |
} | |
CODE | |
end | |
end | |
# Script driver. First print the number of distinct Op nodes produced by the
# schoolbook and the Karatsuba multiplication of two 16-element inputs.
# These Vars use the default bounds (0..2**32-1).
n = 16 | |
a = n.times.map{Var.new("a#{it}")} | |
b = n.times.map{Var.new("b#{it}")} | |
p count_op(normal_mult(a, b)) | |
p count_op(karatsuba_mult(a, b, 6)) | |
# Then write inline_mult.c: three ranged multipliers (16..19, 20..25, 26..32)
# plus mult_16_32, which dispatches on `an` alone and forwards `bn` unchanged.
File.write("inline_mult.c", <<~C) | |
#{generate_inline_mult(16, 19)} | |
#{generate_inline_mult(20, 25)} | |
#{generate_inline_mult(26, 32)} | |
void mult_16_32(int an, int bn, int *a, int *b, int *c) { | |
switch(an){ | |
#{(16..19).map { |n| "case #{n}:" }.join("\n ")} | |
mult_16_19(an, bn, a, b, c); break; | |
#{(20..25).map { |n| "case #{n}:" }.join("\n ")} | |
mult_20_25(an, bn, a, b, c); break; | |
#{(26..32).map { |n| "case #{n}:" }.join("\n ")} | |
mult_26_32(an, bn, a, b, c); break; | |
} | |
} | |
C |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdio.h>
#include <sys/time.h>
#include <time.h>
#include "inline_mult.c"
#include "inline_mult_hand.c"
/*
 * Reference schoolbook multiplication of two base-10000 digit arrays
 * (lowest digit first).
 *
 * an, bn: number of digits in a and b; an + bn must not exceed 64, the size
 *         of the local accumulator.
 * c:      receives an + bn output digits.
 */
void mult_loop(int an, int bn, int *a, int *b, int *c) {
  long long columns[64] = {0};

  /* Accumulate every partial product into its output column. */
  for (int i = 0; i < an; i++) {
    long long *row = columns + i;
    for (int j = 0; j < bn; j++) {
      row[j] += a[i] * b[j];
    }
  }

  /* Propagate carries from the lowest column upward. */
  const int digits = an + bn;
  long long carry = 0;
  for (int k = 0; k < digits; k++) {
    const long long total = carry + columns[k];
    c[k] = total % 10000;
    carry = total / 10000;
  }
}
#define N 100000 | |
int main(int argc, char *argv[]) { | |
int a[32] = {1234, 5678, 9101, 1121, 3141, 5161, 7181, 9202, 2233, 4455, 6677, 8899, 1010, 1213, 1415, 1617, 1234, 5678, 9101, 1121, 3141, 5161, 7181, 9202, 2233, 4455, 6677, 8899, 1010, 1213, 1415, 1617}; | |
int b[32] = {1716, 1514, 1312, 1110, 9987, 7765, 5543, 3321, 2109, 8765, 5432, 1098, 7654, 3210, 9876, 5432, 1716, 1514, 1312, 1110, 9987, 7765, 5543, 3321, 2109, 8765, 5432, 1098, 7654, 3210, 9876, 5432}; | |
int size_from = 16; | |
int size_to = 32; | |
if (argc > 100) { | |
for (int i = 0; i < 32; i++) a[i] += (argc + i) % (10 + i); | |
size_from += argc; | |
size_to += argc; | |
} | |
struct timespec start_time, end_time; | |
int sum1 = 0; | |
int sum2 = 0; | |
long long tsum1 = 0; | |
long long tsum2 = 0; | |
for (int size = size_from; size <= size_to; size++) { | |
printf("size=%d\n", size); | |
clock_gettime(CLOCK_REALTIME, &start_time); | |
for(int i=0;i<N;i++){ | |
int c[64] = {0}; | |
mult_16_32(size, size, a, b, c); | |
for (int j = 0; j < 64; j++) sum1 += c[j]; | |
} | |
clock_gettime(CLOCK_REALTIME, &end_time); | |
long long t1 = (end_time.tv_sec - start_time.tv_sec) * 1000000000 + (end_time.tv_nsec - start_time.tv_nsec); | |
printf("time: %lld\n", t1); | |
tsum1 += t1; | |
clock_gettime(CLOCK_REALTIME, &start_time); | |
for(int i=0;i<N;i++){ | |
int c[64] = {0}; | |
mult_loop(size, size, a, b, c); | |
for (int j = 0; j < 64; j++) sum2 += c[j]; | |
} | |
clock_gettime(CLOCK_REALTIME, &end_time); | |
long long t2 = (end_time.tv_sec - start_time.tv_sec) * 1000000000 + (end_time.tv_nsec - start_time.tv_nsec); | |
tsum2 += t2; | |
printf("time: %lld\n", t2); | |
} | |
printf("t1=%lld, t2=%lld\n", tsum1, tsum2); | |
printf("sum1=%d, sum2=%d\n", sum1, sum2); | |
{ | |
int c[64] = {0}; | |
mult_16_32(30, 30, a, b, c); | |
for (int i = 0; i < 64; i++) printf("%d ", c[i]); | |
printf("\n"); | |
} | |
{ | |
int c[64] = {0}; | |
mult_loop(30, 30, a, b, c); | |
for (int i = 0; i < 64; i++) printf("%d ", c[i]); | |
printf("\n"); | |
} | |
printf("\n"); | |
return 0; | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'matrix' | |
# Toom-Cook parameters for multiplying an an-segment operand by a bn-segment
# operand.
#
# Returns [mults, matrix]:
#   mults  - one [a_weights, b_weights] pair per evaluation point; the
#            weights are integers (rational points are scaled by a power of
#            their denominator).
#   matrix - inverse of the evaluation matrix for the (an + bn - 1)-term
#            product, mapping pointwise products back to coefficients.
def toomcook_params(an, bn)
  cn = an + bn - 1

  # Evaluation points: 0, infinity, 1, -1, then k, -k, 1/k, -1/k for k = 2, 3, ...
  points = [0, Float::INFINITY, 1, -1]
  until points.size >= cn
    k = points.size / 4 + 1
    points.push(k, -k, 1r / k, -1r / k)
  end
  points = points.first(cn)

  # Integer weights of an n-term polynomial evaluated at x.
  weights_at = lambda do |n, x|
    # Evaluating "at infinity" selects the leading coefficient.
    next Array.new(n - 1, 0) << 1 if x == Float::INFINITY

    scale = x.denominator**(n - 1)
    (0...n).map { |power| (x**power * scale).to_i }
  end

  mults = points.map { |x| [weights_at.call(an, x), weights_at.call(bn, x)] }
  matrix = Matrix.rows(points.map { |x| weights_at.call(cn, x) }).inverse
  [mults, matrix]
end
# Scratch call: the result is discarded, so this only exercises
# toomcook_params(5, 5) once when the script is loaded.
toomcook_params 5,5; | |
# Multiplies two coefficient arrays (lowest index first) by Toom-Cook
# evaluation and interpolation.
#
# a, b: arrays of numbers; each element is treated as one segment.
# Returns a Vector of a.size + b.size - 1 product coefficients (an empty
# Vector when either input is empty).
def mult_with_toomcook(a, b)
  return Vector[] if a.empty? || b.empty?

  # Derive the evaluation scheme from the actual operand lengths instead of
  # assuming five segments each.
  mults, matrix = toomcook_params(a.size, b.size)
  pointwise = mults.map do |a_weights, b_weights|
    a_weights.zip(a).sum { _1 * _2 } * b_weights.zip(b).sum { _1 * _2 }
  end
  matrix * Vector[*pointwise]
end
# Toom-Cook multiplication experiment on BigDecimal operands: multiplies two
# numbers that each consist of n segments of m decimal digits, every digit
# being 1, and returns the product as a BigDecimal.
#
# n: number of segments per operand (default 4).
# m: decimal digits per segment (default 2500).
#
# BigDecimal must already be loaded when this is called.
def bd_mult_with_toomcook(n = 4, m = 2500)
  mults, matrix = toomcook_params(n, n)
  a = Array.new(n) { BigDecimal('1' * m) }
  b = Array.new(n) { BigDecimal('1' * m) }

  # Pointwise products at each evaluation point.
  pointwise = mults.map do |a_weights, b_weights|
    a_weights.zip(a).sum { _1 * _2 } * b_weights.zip(b).sum { _1 * _2 }
  end

  # Interpolate row by row: scale the row by the lcm of its denominators so
  # the weights are integers, then divide the weighted sum by that lcm.
  segments = matrix.to_a.map do |row|
    lcm = row.map(&:denominator).reduce(:lcm)
    row.zip(pointwise).sum { |weight, product| (weight * lcm).to_i * product } / lcm
  end

  # Recombine: segment i is weighted by 10**(i * m).
  segments.each_with_index.sum { |segment, i| segment * BigDecimal("1e#{i * m}") }
end
# Interactive scratch section: load BigDecimal, then open an IRB session in
# the top-level binding.
require 'bigdecimal' | |
binding.irb | |
# `exit` ends the script here, so none of the statements or definitions below
# this line are reached when the file is run.
exit | |
bd_mult_with_toomcook | |
p mult_with_toomcook([1,2,3,4,5], [3,4,7,1,2]) | |
binding.irb | |
# Emits the C source of `toomcook_multiply_<an>_<bn>`, a work-in-progress
# Toom-Cook multiplier skeleton.
#
# an, bn: number of segments the a and b operands are split into.
#         NOTE(review): the emitted declarations and the fixed a1/b1 call
#         look like they assume an >= 2 and bn >= 2 - confirm.
# Returns the C source as a String.
#
# The emitted code splits both inputs into segments, multiplies the lowest
# and the following segment pair directly, and emits one recursive
# `toomcook_multiply` call per remaining evaluation point.
# TODO: the per-point operand sums and the interpolation step are not
# generated yet, so the result of toomcook_params is currently unused.
def generate_multiplication_code(an, bn)
  _mults, _matrix = toomcook_params(an, bn)
  cn = an + bn - 1

  point_calls = (2...cn).map do |i|
    asum_target = "c#{i + 1}"
    bsum_target = "c#{i + 2}"
    mult_target = "c#{i}"
    mult_tmp = "c#{i + 1}"
    "toomcook_multiply(s, #{asum_target}, s, #{bsum_target}, #{mult_target}, #{mult_tmp});\n"
  end.join("\n").gsub(/^/, ' ' * 2)

  # Fixes to the emitted C: `b_size` now has its `int` type, the `Max(...)`
  # arguments are balanced, and the segment pointers stay `const` because
  # they alias the `const int *` parameters.
  <<~C_CODE
    void toomcook_multiply_#{an}_#{bn}(int a_size, const int *a, int b_size, const int *b, int *c, int *tmp) {
      int s = Max((a_size+#{an - 1})/#{an}, (b_size+#{bn - 1})/#{bn});
      int a0_size = a_size - s*#{an - 1};
      int b0_size = b_size - s*#{bn - 1};
      const int *a0=a,#{(1...an).map { |i| "*a#{i}=a+a0_size+s*#{i - 1}" }.join(', ')};
      const int *b0=b,#{(1...bn).map { |i| "*b#{i}=b+b0_size+s*#{i - 1}" }.join(', ')};
      int #{(0..cn + 2).map { |i| "*c#{i}=tmp+s*#{2 * i + 2}" }.join(', ')};
      toomcook_multiply(a0_size, a0, b0_size, b0, c0, c1);
      toomcook_multiply(s, a1, s, b1, c1, c2);
    #{point_calls}
    }
  C_CODE
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment