Last active
August 16, 2019 04:46
-
-
Save MaskRay/258495af940ca1d50fd1a26088403b38 to your computer and use it in GitHub Desktop.
FFT技巧
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
https://www.hackerrank.com/contests/w23/challenges/sasha-and-swaps-ii | |
感谢ftiasch老师教导,参考了 https://async.icpc-camp.org/d/408-fft 和其他一些地方的东西 | |
P = 1e9+7 | |
Q = ceil(sqrt(P)) | |
a, b为两个向量 | |
c = convolution(a, b) | |
c系数取值为 [0, n*(P-1)^2] 的整数,若n*(P-1)^2的表示需要超过53 bits(double mantissa)则很可能会出错 | |
通常用的 IFFT 是 unnormalized 的,中间结果是 [0, n^2*(P-1)^2] 的整数,但因为涉及到 `x /= n` 所以只要求能表示 [0, n*(P-1)^2] 的整数 | |
# 技巧 | |
## 技巧0:折半 | |
分解a = a0 + Q * a1, 0 <= a0,a1 < Q | |
求出 | |
c00 = convolution(a0, b0) | |
c01 = convolution(a0, b1) | |
c10 = convolution(a1, b0) | |
c11 = convolution(a1, b1) | |
以上四个 convolution 系数最大值为 n*(Q-1)^2 ~= n*P | |
c = c00 + Q*(c01+c10) + Q*Q*c11 | |
使用Toom-2 (Karatsuba)可以简化为三次convolution | |
## 技巧1:[0,P-1] => [-(P-1)/2, (P-1)/2] | |
设 a,b 系数取自 [0,P-1] 的 uniform distribution | |
则 c 系数均值阶为 n*P^2/4,方差阶为 n*P^4/9 | |
若平移至 [-(P-1)/2, (P-1)/2] | |
则 c 系数均值 0,方差阶为 n*P^4/144 | |
由Chebyshev's inequality,系数绝对值在若干倍标准差以内 | |
## 技巧2a:正交地计算两个FFT;辅助技巧0 | |
取S与sqrt(P)接近且 M = P - (S*S % P) 尽可能小,于是 S*S ≡ -M (mod P) | |
分解 a = a0 + S * a1, b = b0 + S * b1 | |
用两次FFT一次IFFT计算 convolution(a0+i*sqrt(M)*a1, b0+i*sqrt(M)*b1) 即得到 | |
convolution(a0,b0) - M*convolution(a1,b1) + i*sqrt(M)*(convolution(a0,b1)+convolution(a1,b0)) | |
分离real和imag即可算出 c = convolution(a0, b0) + S * (convolution(a0, b1) + convolution(a1, b0)) - M * convolution(a1, b1) | |
## 技巧2b | |
效率比技巧2a略低,用两次FFT和两次IFFT,但系数绝对值更小 | |
分解 a = a0 + Q * a1, b = b0 + Q * b1 | |
记 A = a0+i*a1, B = b0+i*b1,函数 rev(a) = {a[0], a[n-1], a[n-2], ..., a[1]} | |
计算 FFT(A) 与 FFT(B) 后求出: | |
- FFT(a0) = FFT(re(A)) = [FFT(A) + FFT(conj(A))] / 2 = [FFT(A) + conj(rev(FFT(A)))] / 2 | |
- FFT(a1) = FFT(im(A)) = [FFT(A) - FFT(conj(A))] * -0.5i = [FFT(A) - conj(rev(FFT(A)))] * -0.5i | |
- FFT(b0) = FFT(re(B)) = [FFT(B) + FFT(conj(B))] / 2 = [FFT(B) + conj(rev(FFT(B)))] / 2 | |
- FFT(b1) = FFT(im(B)) = [FFT(B) - FFT(conj(B))] * -0.5i = [FFT(B) - conj(rev(FFT(B)))] * -0.5i | |
再用 IFFT 计算: | |
convolution(a0, b0) + i * convolution(a0, b1) = IFFT(FFT(a0)*FFT(b0) + i*FFT(a0)*FFT(b1)) | |
convolution(a1, b0) + i * convolution(a1, b1) = IFFT(FFT(a1)*FFT(b0) + i*FFT(a1)*FFT(b1)) | |
分离real和imag即可算出 c = convolution(a0, b0) + Q * (convolution(a0, b1) + convolution(a1, b0)) + Q * Q * convolution(a1, b1) | |
另外,用Haskell的记号:ifft = (/n) . fft . rev | |
https://www.hackerrank.com/rest/contests/w23/challenges/sasha-and-swaps-ii/hackers/Hezhu/download_solution | |
# 题目 | |
sasha-and-swaps-ii 题中 P = 10^9+7,取S=10^5,M=70 | |
绝对值<=P-1的原系数调整为 (floor((P-1)/2/Q) + i*M*(Q-1)) | |
结果的real绝对值最大值约为 n * ((P/Q)^2/4+M*Q^2),55.96 bits,但均值为0取到最大值可能性极低 | |
防止最坏情况下出问题,可以取 independent and identically distributed 随机数向量 noise | |
convolution(a, b) = convolution(a+noise, b) - convolution(noise, b) | |
# 其他 | |
根据 Roundoff Error Analysis of the Fast Fourier Transform,没仔细看 | |
relative error 均值为 log2(n)*浮点运算精度*变换前系数最大值 | |
哪里看到的,unit root一定要用cos(2*M_PI/n*m) sin(2*M_PI/n*m)或者lookup table,用乘法`w *= dw`计算会使误差达到指数级。我觉得可能指误差达到 n*浮点运算精度*变换前系数最大值 | |
有了技巧0+1+2,感觉 complex<double> 的 Fast Fourier transform 恒优于 Fast number theoretic transform | |
涉及 int64 时,a*b % m 性能很差。`long x = a*b, r = x - mod*long(double(a)*double(b)/mod+0.5); return r < 0 ? r + mod : r;`比 汇编MUL DIV 快,但还是不如 complex<double> | |
代码中调用 complex<double>::operator* 的地方会编译为 call __muldc3,__muldc3 会判断NAN INF,有很多多余操作,性能很低。如果编译时带上 -ffast-math 可以快很多,但 __attribute__((optimize("fast-math"))) 这些没有效果,因为 __muldc3 在其他文件中不受到影响。最好自行实现 complex<double> 的乘法 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cmath> | |
#include <complex> | |
#include <iostream> | |
#include <type_traits> | |
#include <utility> | |
#include <vector> | |
using namespace std; | |
// Double-precision complex number type used by all FFT routines in this file.
typedef complex<double> cd;
// FOR(i, a, b): ascending loop, i = a, a+1, ..., b-1. The index is declared
// with the type of `b` (references and cv-qualifiers stripped).
#define FOR(i, a, b) for (remove_cv<remove_reference<decltype(b)>::type>::type i = (a); i < (b); i++)
// REP(i, n): ascending loop, i = 0, 1, ..., n-1.
#define REP(i, n) FOR(i, 0, n)
// ROF(i, a, b): descending loop, i = b-1, b-2, ..., a. The index is declared
// with the type of `b`; the loop ends when the pre-decremented index drops
// below `a`, so with a == 0 the type of `b` must be signed.
#define ROF(i, a, b) for (remove_cv<remove_reference<decltype(b)>::type>::type i = (b); --i >= (a); )
// MOD: modulus applied to every polynomial coefficient.
// SQ:  base used by fft_interleave() to split a coefficient z into z%SQ and z/SQ.
// NN:  number of entries in the bitrev[] and units[] tables below.
const long MOD = 1000000007, SQ = 100000, NN = 262144;
// IROOT*IROOT == MOD - SQ*SQ%MOD, i.e. SQ*SQ is congruent to -(IROOT*IROOT)
// modulo MOD. fft_interleave() scales the high part of each coefficient by it.
const double IROOT = sqrt(double(MOD-SQ*SQ%MOD));
// Bit-reversal permutation table; filled by fft_prepare() and read by fft_dit2().
long bitrev[NN];
// Table of unit roots; filled by fft_prepare() and read by fft_dit2().
cd units[NN];
void fft_prepare(long n) | |
{ | |
long logn = 63-__builtin_clzl(n); | |
REP(i, n) | |
bitrev[i] = bitrev[i>>1] >> 1 | (i & 1) << logn-1; | |
double ph = 2*M_PI/n; | |
REP(i, n) | |
units[i] = {cos(ph*i), sin(ph*i)}; | |
} | |
// In-place radix-2 decimation-in-time FFT of a[0..n).
//
// Parameters:
//   a  - array of n complex values, overwritten with the result.
//   n  - transform length; must be the same power of two that was last passed
//        to fft_prepare(), because bitrev[] and units[] are read as tables
//        for exactly that length.
//   is - direction: is < 0 selects the inverse transform, any other value the
//        forward transform.
//
// The forward transform uses the roots stored in units[], i.e. the kernel
// exp(+2*pi*sqrt(-1)*j*k/n). The inverse is obtained from the same butterflies
// by first reversing a[1..n-1] (which conjugates the kernel) and finally
// scaling every element by 1/n.
void fft_dit2(cd a[], long n, int is)
{
  const bool inverse = is < 0;
  if (inverse)
    for (long i = 1, j = n-1; i < j; i++, j--)
      std::swap(a[i], a[j]);
  // Bit-reversal permutation; the i < bitrev[i] test swaps each pair once.
  for (long i = 0; i < n; i++)
    if (i < bitrev[i])
      std::swap(a[i], a[bitrev[i]]);
  // Butterfly passes over blocks of size m = 2, 4, ..., n. `stride` is n/m,
  // the step through units[] that yields the m-th roots of unity.
  for (long m = 2, stride = n>>1; m <= n; m <<= 1, stride >>= 1) {
    const long half = m >> 1;
    for (long base = 0; base < n; base += m) {
      cd *lo = a+base, *hi = a+base+half;
      const cd *w = units;
      for (long j = 0; j < half; j++) {
        // Hand-written complex product (*hi) * (*w): avoids the generic
        // complex operator* and its NaN/Inf handling.
        const cd t{hi->real()*w->real()-hi->imag()*w->imag(), hi->real()*w->imag()+hi->imag()*w->real()};
        *hi++ = *lo-t;
        *lo++ += t;
        w += stride;
      }
    }
  }
  if (inverse) {
    const double scale = 1.0/n;
    for (long i = 0; i < n; i++)
      a[i] *= scale;
  }
}
vector<cd> fft_interleave(const vector<int>& a, long n) | |
{ | |
vector<cd> r(n); | |
REP(i, a.size()) { | |
long z = a[i] <= MOD/2 ? a[i] : a[i]-MOD; | |
r[i] = cd(z%SQ, z/SQ*IROOT); | |
} | |
fft_dit2(&r[0], n, 1); | |
return r; | |
} | |
vector<int> ifft_interleave(vector<cd>& a) | |
{ | |
fft_dit2(&a[0], a.size(), -1); | |
vector<int> r(a.size()); | |
REP(i, a.size()) { | |
long x = round(a[i].real()), y = long(round(a[i].imag()/IROOT)); | |
r[i] = (x+y%MOD*SQ)%MOD; | |
if (r[i] < 0) r[i] += MOD; | |
} | |
return r; | |
} | |
vector<int> rising_factorial(long l, long h) | |
{ | |
if (h-l <= 64-1) { | |
vector<int> r(h-l+1); | |
r[0] = 1; | |
FOR(i, l, h) { | |
int ul = r[0]; | |
r[0] = r[0]*i%MOD; | |
REP(j, i-l+1) { | |
int t = (r[j+1]*i+ul)%MOD; | |
ul = r[j+1]; | |
r[j+1] = t; | |
} | |
} | |
r.resize(1 << 63-__builtin_clzl(r.size()-1)+1); | |
return r; | |
} | |
long m = l+h >> 1; | |
auto a = rising_factorial(l, m), b = rising_factorial(m, h); | |
long n = 1 << 63-__builtin_clzl(a.size()+b.size()-2)+1; | |
fft_prepare(n); | |
auto aa = fft_interleave(a, n), bb = fft_interleave(b, n); | |
REP(i, n) | |
aa[i] *= bb[i]; | |
return ifft_interleave(aa); | |
} | |
int main() | |
{ | |
ios::sync_with_stdio(0); | |
cin.tie(0); | |
long n; | |
cin >> n; | |
vector<int> stirling1 = rising_factorial(0, n); | |
ROF(i, 0, n-1) | |
stirling1[i] = (stirling1[i]+stirling1[i+2])%MOD; | |
ROF(i, 1, n) | |
cout << stirling1[i] << ' '; | |
cout << endl; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment