如何在numba.njit中实现离散傅里叶变换(FFT)?

8

大家好,我正在尝试使用 numba.njit 装饰器,在这个 最小工作示例 中进行 离散傅里叶变换(discrete Fourier transform)

import numba
import numpy as np
import scipy
import scipy.fftpack

@numba.njit
def main():
    wave = [[[0.09254795,  0.10001078,  0.10744892, 0.07755555,  0.08506225, 0.09254795],
          [0.09907245,  0.10706145,  0.11502401,  0.08302302,  0.09105898, 0.09907245],
          [0.09565098,  0.10336405,  0.11105158,  0.08015589,  0.08791429, 0.09565098],
          [0.00181467,  0.001961,    0.00210684,  0.0015207,   0.00166789, 0.00181467]],
         [[-0.45816267, - 0.46058367, - 0.46289091, - 0.45298182, - 0.45562851, -0.45816267],
          [-0.49046506, - 0.49305676, - 0.49552669, - 0.48491893, - 0.48775223, -0.49046506],
          [-0.47352483, - 0.47602701, - 0.47841162, - 0.46817027, - 0.4709057, -0.47352483],
          [-0.00898358, - 0.00903105, - 0.00907629, - 0.008882, - 0.00893389, -0.00898358]],
         [[0.36561472,  0.36057289,  0.355442,  0.37542627,  0.37056626, 0.36561472],
          [0.39139261,  0.38599531,  0.38050268,  0.40189591,  0.39669325, 0.39139261],
          [0.37787385,  0.37266296,  0.36736003,  0.38801438,  0.38299141, 0.37787385],
          [0.00716892,  0.00707006,  0.00696945,  0.0073613,  0.00726601, 0.00716892]]]

    new_fft = scipy.fftpack.fft(wave)


if __name__ == '__main__':
    main()

输出:

C:\Users\Artur\Anaconda\python.exe C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py
Traceback (most recent call last):
  File "C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py", line 25, in <module>
    main()
  File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 401, in _compile_for_args
    error_rewrite(e, 'typing')
  File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 344, in error_rewrite
    reraise(type(e), e, None)
  File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\utils.py", line 80, in reraise
    raise value.with_traceback(tb)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute 'fft' of type Module(<module 'scipy.fftpack' from 'C:\\Users\\Artur\\Anaconda\\lib\\site-packages\\scipy\\fftpack\\__init__.py'>)

File "test2.py", line 21:
def main():
    <source elided>

    new_fft = scipy.fftpack.fft(wave)
    ^

[1] During: typing of get attribute at C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py (21)

File "test2.py", line 21:
def main():
    <source elided>

    new_fft = scipy.fftpack.fft(wave)
    ^


Process finished with exit code 1

不幸的是,scipy.fftpack.fft 看起来是一个已经被淘汰的函数,不受 numba 支持。因此我寻找替代方案。我找到了两个:

1. scipy.fft(wave) 是上述淘汰函数的更新版本。它会产生以下错误输出:

C:\Users\Artur\Anaconda\python.exe C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py
Traceback (most recent call last):
  File "C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py", line 25, in <module>
    main()
  File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 401, in _compile_for_args
    error_rewrite(e, 'typing')
  File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 344, in error_rewrite
    reraise(type(e), e, None)
  File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\utils.py", line 80, in reraise
    raise value.with_traceback(tb)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Module(<module 'scipy.fft' from 'C:\\Users\\Artur\\Anaconda\\lib\\site-packages\\scipy\\fft\\__init__.py'>) with parameters (list(list(list(float64))))
No type info available for Module(<module 'scipy.fft' from 'C:\\Users\\Artur\\Anaconda\\lib\\site-packages\\scipy\\fft\\__init__.py'>) as a callable.
[1] During: resolving callee type: Module(<module 'scipy.fft' from 'C:\\Users\\Artur\\Anaconda\\lib\\site-packages\\scipy\\fft\\__init__.py'>)
[2] During: typing of call at C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py (21)


File "test2.py", line 21:
def main():
    <source elided>

    new_fft = scipy.fft(wave)
    ^


Process finished with exit code 1

2. np.fft.fft(wave) 似乎被支持,但也会产生错误:

C:\Users\Artur\Anaconda\python.exe C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py
Traceback (most recent call last):
  File "C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py", line 25, in <module>
    main()
  File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 401, in _compile_for_args
    error_rewrite(e, 'typing')
  File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 344, in error_rewrite
    reraise(type(e), e, None)
  File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\utils.py", line 80, in reraise
    raise value.with_traceback(tb)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute 'fft' of type Module(<module 'numpy.fft' from 'C:\\Users\\Artur\\Anaconda\\lib\\site-packages\\numpy\\fft\\__init__.py'>)

File "test2.py", line 21:
def main():
    <source elided>

    new_fft = np.fft.fft(wave)
    ^

[1] During: typing of get attribute at C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py (21)

File "test2.py", line 21:
def main():
    <source elided>

    new_fft = np.fft.fft(wave)
    ^


Process finished with exit code 1

你是否了解一种与 numba.njit 装饰器配合使用的 fft 函数?


我猜 scipy.fft.fft 是一个非常快的实现,你真的需要 JIT 吗? - tstanisl
numba numpy函数列表中可以看出,不支持numpy fft模块,所以你的第二种情况似乎很正常。 - Yacola
@zariiii9003,仅使用对象模式会减慢fft函数还是整个njit函数的速度? - Artur Müller Romanov
1
pyculib.fft 应该得到支持。https://devblogs.nvidia.com/seven-things-numba/ - norok2
@norok2 我尝试使用 pyculib,但在 PyCharm 中似乎无法工作,至少我无法让它工作。我已经发布了一个帖子:https://stackoverflow.com/questions/62360236/how-to-import-pyculib-in-pycharm - Artur Müller Romanov
显示剩余2条评论
3个回答

5
如果您对一维离散傅里叶变换感到满意,那么最好使用快速傅里叶变换(FFT)。这里报告了一个适用于Numba的实现fft_1d(),可处理任意输入大小:
import cmath
import numpy as np
import numba as nb


@nb.jit
def ilog2(n):
    result = -1
    if n < 0:
        n = -n
    while n > 0:
        n >>= 1
        result += 1
    return result


@nb.njit(fastmath=True)
def reverse_bits(val, width):
    result = 0
    for _ in range(width):
        result = (result << 1) | (val & 1)
        val >>= 1
    return result


@nb.njit(fastmath=True)
def fft_1d_radix2_rbi(arr, direct=True):
    arr = np.asarray(arr, dtype=np.complex128)
    n = len(arr)
    levels = ilog2(n)
    e_arr = np.empty_like(arr)
    coeff = (-2j if direct else 2j) * cmath.pi / n
    for i in range(n):
        e_arr[i] = cmath.exp(coeff * i)
    result = np.empty_like(arr)
    for i in range(n):
        result[i] = arr[reverse_bits(i, levels)]
    # Radix-2 decimation-in-time FFT
    size = 2
    while size <= n:
        half_size = size // 2
        step = n // size
        for i in range(0, n, size):
            k = 0
            for j in range(i, i + half_size):
                temp = result[j + half_size] * e_arr[k]
                result[j + half_size] = result[j] - temp
                result[j] += temp
                k += step
        size *= 2
    return result


@nb.njit(fastmath=True)
def fft_1d_arb(arr, fft_1d_r2=fft_1d_radix2_rbi):
    """1D FFT for arbitrary inputs using chirp z-transform"""
    arr = np.asarray(arr, dtype=np.complex128)
    n = len(arr)
    m = 1 << (ilog2(n) + 2)
    e_arr = np.empty(n, dtype=np.complex128)
    for i in range(n):
        e_arr[i] = cmath.exp(-1j * cmath.pi * (i * i) / n)
    result = np.zeros(m, dtype=np.complex128)
    result[:n] = arr * e_arr
    coeff = np.zeros_like(result)
    coeff[:n] = e_arr.conjugate()
    coeff[-n + 1:] = e_arr[:0:-1].conjugate()
    return fft_convolve(result, coeff, fft_1d_r2)[:n] * e_arr / m


@nb.njit(fastmath=True)
def fft_convolve(a_arr, b_arr, fft_1d_r2=fft_1d_radix2_rbi):
    return fft_1d_r2(fft_1d_r2(a_arr) * fft_1d_r2(b_arr), False)


@nb.njit(fastmath=True)
def fft_1d(arr):
    n = len(arr)
    if not n & (n - 1):
        return fft_1d_radix2_rbi(arr)
    else:
        return fft_1d_arb(arr)

与天真的DFT算法(基本上与this相同的dft_1d())相比,您将获得数量级的提升,但速度仍然通常比np.fft.fft()慢得多。

vs_dft

相对速度因输入大小而异。对于2的幂次方输入,这通常在一个数量级内与np.fft.fft()相当。

pow2

对于非2的幂次方,这通常在np.fft.fft()的两个数量级内。

not-pow2

对于最坏情况(如质数等),这比np.fft.fft()快了a倍,其中a是2的幂+1。

primes

FFT时间的非线性行为是需要更复杂的算法来处理不是2的幂次方的任意输入大小的结果。这会影响到这个实现和np.fft.fft()的实现,但是np.fft.fft()包含了更多的优化,使其平均表现更好。
2的幂次方FFT的其他实现在这里展示。

2
Numba文档提到np.fft.fft不受支持。解决方法是使用objmode上下文调用尚未支持的Python函数。仅在objmode上下文中的部分将以对象模式运行,因此可能会很慢。对于您的特定情况,这部分不会太慢,因为np.fft.fft已经非常快,正如@tstanisl在问题的第一个评论中指出的那样。以下是一个例子。
from numba import njit
import numpy as np

@njit()
def compute_fft(x):
   y = np.zeros(., dtype=np.complex128) 
   with objmode(y='type[:]'):
      y = np.fft.fft(x)
   return y

@njit()
def main():
   ...
   x = np.random.randint(100)
   fft_x = compute_fft(x) 
   ...

1
我找到了一个解决方法。需要注意的是,像 numpy.fft.fft 这样的函数有很多方便的操作,所以如果你没有像我一样卡住,应该使用它们。
下面的 njit 函数对一个 一维数组 进行 离散傅里叶变换
import numba
import numpy as np
import cmath

def dft(wave=None):
    dft = np.fft.fft(wave)
    return dft

@numba.njit
def dft_njit(wave=None):
    N = len(wave)
    dft_njit = np.zeros(N, dtype=np.complex128)
    for i in range(N):
        series_element = 0
        for n in range(N):
            series_element += wave[n] * cmath.exp(-2j * cmath.pi * i * n * (1 / N))
        dft_njit[i] = series_element
    return dft_njit

if __name__ == '__main__':

    wave = [1,2,3,4,5]
    wave = np.array(wave)

    print(f' dft: \n{dft(wave=wave)}')
    print(f' dft_njit: \n{dft_njit(wave=wave)}')

输出:

 dft: 
[15. +0.j         -2.5+3.4409548j  -2.5+0.81229924j -2.5-0.81229924j
 -2.5-3.4409548j ]
 dft_njit: 
[15. +0.j         -2.5+3.4409548j  -2.5+0.81229924j -2.5-0.81229924j
 -2.5-3.4409548j ]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接