From df3f1385507cc0ee59668eddcdce1c00f53740a1 Mon Sep 17 00:00:00 2001 From: mbinary Date: Tue, 11 Jun 2019 12:48:40 +0800 Subject: [PATCH] Update fft: iteration version --- math/fft.py | 65 +++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 56 insertions(+), 9 deletions(-) diff --git a/math/fft.py b/math/fft.py index 5d6b9c6..487a51f 100644 --- a/math/fft.py +++ b/math/fft.py @@ -1,45 +1,91 @@ import numpy as np +def _fft_n2(a, invert): + '''O(n^2)''' + N = len(a) + w = np.arange(N) + i = 2j if invert else -2j + m = w.reshape((N, 1)) * w + W = np.exp(m * i * np.pi / N) + return np.concatenate(np.dot(W, a.reshape((N, + 1)))) # important, cannot use * + + def _fft(a, invert=False): - '''fft, len(a) is power of two''' + '''recursion version''' N = len(a) if N == 1: return [a[0]] - else: + elif N & (N - 1) == 0: # O(nlogn), 2^k even = _fft(a[::2], invert) odd = _fft(a[1::2], invert) i = 2j if invert else -2j factor = np.exp(i * np.pi * np.arange(N // 2) / N) prod = factor * odd return np.concatenate([even + prod, even - prod]) + else: + return _fft_n2(a, invert) + + +def _fft2(a, invert=False): + ''' iteration version''' + + def rev(x): + ret = 0 + for i in range(r): + ret <<= 1 + if x & 1: + ret += 1 + x >>= 1 + return ret + + N = len(a) + if N & (N - 1) == 0: # O(nlogn), 2^k + r = int(np.log(N)) + c = np.array(a,dtype='complex') + i = 2j if invert else -2j + w = np.exp(i * np.pi / N) + for h in range(r - 1, -1, -1): + p = 2**h + z = w**(N / p / 2) + for k in range(N): + if k % p == k % (2 * p): + c[k], c[k + p] = c[k] + c[k + p], c[k] * z**(k % p) + + return np.asarray([c[rev(i)] for i in range(N)]) + else: # O(n^2) + return _fft_n2(a, invert) def fft(a): '''fourier[a]''' n = len(a) - if n == 0 or n&(n-1)!=0: - raise Exception(f"[Error]: {n} is not power of 2") + if n == 0: + raise Exception("[Error]: Invalid length: 0") return _fft(a) def ifft(a): '''invert fourier[a]''' n = len(a) - if n == 0 or n&(n-1)!=0: - raise Exception(f"[Error]: {n} is not power of 2") + if n == 0: + raise Exception("[Error]: Invalid length: 0") return _fft(a, True) / n def fft2(arr): - return np.apply_along_axis(fft, 0, np.apply_along_axis(fft, 1, np.asarray(arr))) + return np.apply_along_axis(fft, 0, + np.apply_along_axis(fft, 1, np.asarray(arr))) def ifft2(arr): - return np.apply_along_axis(ifft, 0, np.apply_along_axis(ifft, 1, np.asarray(arr))) + return np.apply_along_axis(ifft, 0, + np.apply_along_axis(ifft, 1, np.asarray(arr))) def test(n=128): + print('\nsequence length:', n) print('fft') li = np.random.random(n) print(np.allclose(fft(li), np.fft.fft(li))) @@ -58,4 +104,5 @@ def test(n=128): if __name__ == '__main__': - test(128) + for i in range(1, 4): + test(i * 16)