Using Numba with Scipy routines


I'm writing a program that uses the Scipy CubicSpline routine at certain points. Because of that Scipy routine, I can't apply Numba's @jit to the whole program.

I recently came across the @overload feature and wondered whether it could be used like this:

from numba.extending import overload
from numba import jit
from scipy.interpolate import CubicSpline
import numpy as np

x = np.arange(10)
y = np.sin(x)
xs = np.arange(-0.5, 9.6, 0.1)

def Spline_interp(xs,x,y):
    cs = CubicSpline(x, y)
    ds = cs(xs)
    return ds

@overload(Spline_interp)
def jit_Spline_interp(xs,x,y):
   ds = Spline_interp(xs,x,y)
   def jit_Spline_interp_impl(xs,x, y):
       return ds
   return jit_Spline_interp_impl

@jit(nopython=True)
def main():

    # other codes compatible with @njit

    ds = Spline_interp(xs,x,y)

    # other codes compatible with @njit

    return ds

print(main())

If I have misunderstood the @overload feature, please correct me. And how can I use these Scipy libraries together with Numba?


Scipy is already a high-performance library, so I don't expect numba to make much of a difference here. - s.ouchene
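Is this part performance-critical, or are other parts of your code more relevant? If the latter, see https://numba.pydata.org/numba-doc/latest/user/withobjmode.html; otherwise it will take quite a bit more work. - max9111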
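Thanks for your reply. I know scipy's interpolation is very fast. It isn't the bottleneck of my code, but other parts of the code are. What I currently do is split the program into many numba-jitted functions that don't use the cubic spline, but this makes the program rather messy. I wondered whether the overload feature is the solution to this, so that I could jit the whole program and still use scipy's interpolation instead of writing my own interpolation routine. - sreenath subramaniam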
3 Answers


Wrapping a compiled function with ctypes (Numba)

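Especially for more complex functions, reimplementing everything in numba-compilable Python code can be a lot of work and is sometimes even slower. The following answer shows how to call C-like functions directly from a shared object or dynamic library.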

Compiling the Fortran routine

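This example shows how to do it on Windows, but it should be just as straightforward on other operating systems. For portability, defining an ISO_C_BINDING interface is recommended; in this answer I try to do it without an interface.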

dll.def

EXPORTS
   SPLEV @1

Compile

ifort /dll dll.def splev.f fpbspl.f /O3 /fast

Calling this function directly from Numba

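  • Look at what the Fortran routine expects.
  • Check every input in the wrapper (data type, contiguity). You only pass some pointers to the Fortran function; there are no additional safety checks.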
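Because nothing on the Fortran side validates its arguments, it can help to normalize every array before its pointer is taken. A small sketch in plain NumPy; `as_pointer_arg` is a hypothetical helper, not part of the original answer. It returns valid arrays unchanged, copies when the dtype or layout is wrong, or raises if copying is disallowed.

```python
import numpy as np


def as_pointer_arg(a, dtype=np.float64, copy_if_needed=True):
    """Hypothetical helper: return `a` as a C-contiguous array of `dtype`."""
    a = np.asarray(a)
    if a.dtype == np.dtype(dtype) and a.flags.c_contiguous:
        return a  # already safe to hand over as a raw pointer
    if not copy_if_needed:
        raise ValueError(
            f"need a C-contiguous {np.dtype(dtype)} array, got {a.dtype}, "
            f"c_contiguous={a.flags.c_contiguous}"
        )
    # Copies into a fresh C-contiguous buffer with the requested dtype.
    return np.ascontiguousarray(a, dtype=dtype)


strided = np.arange(10.0)[::2]      # a stride-2 view
print(strided.flags.c_contiguous)   # False
fixed = as_pointer_arg(strided)
print(fixed.flags.c_contiguous, fixed.dtype)
```

Inside an `@njit` wrapper the same idea is expressed with `np.ascontiguousarray`, as the answer's code does for `t` and `x`.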

Wrapper

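The following code shows two ways to call this function. In Numba it isn't possible to pass a scalar by reference directly. You can either allocate an array on the heap (slow for functions with a very short runtime) or use intrinsics to get stack-allocated values.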

import numba as nb
import numpy as np
import ctypes

lib = ctypes.cdll.LoadLibrary("splev.dll")

dble_p=ctypes.POINTER(ctypes.c_double)
int_p =ctypes.POINTER(ctypes.c_longlong)

SPLEV=lib.SPLEV
SPLEV.restype = None  # SPLEV is a subroutine and returns nothing
SPLEV.argtypes = (dble_p,int_p,dble_p,int_p,dble_p,dble_p,int_p,int_p,int_p)

from numba import types
from numba.extending import intrinsic
from numba.core import cgutils

@intrinsic
def val_to_ptr(typingctx, data):
    def impl(context, builder, signature, args):
        ptr = cgutils.alloca_once_value(builder,args[0])
        return ptr
    # `data` is already the Numba type of the argument (e.g. int64)
    sig = types.CPointer(data)(data)
    return sig, impl

@intrinsic
def ptr_to_val(typingctx, data):
    def impl(context, builder, signature, args):
        val = builder.load(args[0])
        return val
    sig = data.dtype(types.CPointer(data.dtype))
    return sig, impl

# With intrinsics, the by-reference scalars are allocated on the stack.
# This is faster, which matters mostly for functions with a very short runtime.
@nb.njit()
def splev_wrapped(x, coeff,e):
    #There are just pointers passed to the fortran function.
    #The arrays have to be contiguous!
    t=np.ascontiguousarray(coeff[0])
    x=np.ascontiguousarray(x)
    
    c=coeff[1]
    k=coeff[2]
    
    y=np.empty(x.shape[0],dtype=np.float64)
    
    n_arr=val_to_ptr(nb.int64(t.shape[0]))
    k_arr=val_to_ptr(nb.int64(k))
    m_arr=val_to_ptr(nb.int64(x.shape[0]))
    e_arr=val_to_ptr(nb.int64(e))
    ier_arr=val_to_ptr(nb.int64(0))
    
    SPLEV(t.ctypes,n_arr,c.ctypes,k_arr,x.ctypes,
        y.ctypes,m_arr,e_arr,ier_arr)
    return y, ptr_to_val(ier_arr)

#without using intrinsics
@nb.njit()
def splev_wrapped_2(x, coeff,e):
    #There are just pointers passed to the fortran function.
    #The arrays have to be contiguous!
    t=np.ascontiguousarray(coeff[0])
    x=np.ascontiguousarray(x)
    
    c=coeff[1]
    k=coeff[2]
    y=np.empty(x.shape[0],dtype=np.float64)
    
    n_arr = np.empty(1,  dtype=np.int64)
    k_arr = np.empty(1,  dtype=np.int64)
    m_arr = np.empty(1,  dtype=np.int64)
    e_arr = np.empty(1,  dtype=np.int64)
    ier_arr = np.zeros(1,  dtype=np.int64)
    
    n_arr[0]=t.shape[0]
    k_arr[0]=k
    m_arr[0]=x.shape[0]
    e_arr[0]=e
    
    SPLEV(t.ctypes,n_arr.ctypes,c.ctypes,k_arr.ctypes,x.ctypes,
        y.ctypes,m_arr.ctypes,e_arr.ctypes,ier_arr.ctypes)
    return y, ier_arr[0]

Thank you very much for the effort and for writing up a good example. I'm sure it will help others who want to use scipy routines inside numba-jitted functions. - sreenath subramaniam

You can either fall back to object mode locally (as @max9111 suggested), or implement the CubicSpline function yourself in Numba.
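As far as I know, the overload decorator "only" tells the compiler that a Numba-compatible implementation exists for the overloaded function when it encounters it. It can't magically make the function Numba-compatible.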
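There is a package that exposes some Scipy functionality to Numba, but it seems to be at an early stage and currently only covers some scipy.special functions.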

https://github.com/numba/numba-scipy


This is a repost of the solution I posted on the numba forum: https://numba.discourse.group/t/call-scipy-splev-routine-in-numba-jitted-function/1122/7
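At first I used objmode as suggested by @max9111, which provided a temporary solution. But since the code is performance-critical, I eventually wrote a numba version of scipy's 'interpolate.splev' subroutine for spline interpolation.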
import numpy as np
import numba
from scipy import interpolate
import matplotlib.pyplot as plt
import time

# Custom wrap of scipy's splrep
def custom_splrep(x, y, k=3):
    
    """
    Custom wrap of scipy's splrep for calculating spline coefficients, 
    which also checks whether the data is equispaced.
    
    """
    
    # Check if x is equispaced
    x_diff = np.diff(x)
    equi_spaced = all(np.round(x_diff,5) == np.round(x_diff[0],5))
    dx = x_diff[0]
    
    # Calculate knots & coefficients (cubic spline by default)
    t,c,k = interpolate.splrep(x,y, k=k) 
    
    return (t,c,k,equi_spaced,dx) 

# Numba accelerated implementation of scipy's splev
@numba.njit(cache=True)
def numba_splev(x, coeff):
    
    """
    Custom implementation of scipy's splev for spline interpolation, 
    with additional section for faster search of knot interval, if knots are equispaced.
    Spline is extrapolated from the end spans for points not in the support.
    
    """
    t,c,k, equi_spaced, dx = coeff
    
    t0 = t[0]
    
    n = t.size
    m = x.size
    
    k1  = k+1
    k2  = k1+1
    nk1 = n - k1
    
    l  = k1
    l1 = l+1
    
    y = np.zeros(m)
    
    h  = np.zeros(20)
    hh = np.zeros(19)

    for i in range(m):
        
       # fetch a new x-value arg
       arg = x[i]
       
       # search for knot interval t[l] <= arg <= t[l+1]
       if(equi_spaced):
           l = int((arg-t0)/dx) + k
           l = min(max(l, k1), nk1)
       else:
           while not ((arg >= t[l-1]) or (l1 == k2)):
               l1 = l
               l  = l-1
           while not ((arg < t[l1-1]) or (l == nk1)):
               l = l1
               l1 = l+1
       
       # evaluate the non-zero b-splines at arg.    
       h[:]  = 0.0
       hh[:] = 0.0
       
       h[0] = 1.0
       
       for j in range(k):
       
           for ll in range(j+1):
               hh[ll] = h[ll]
           h[0] = 0.0
       
           for ll in range(j+1):
               li = l + ll 
               lj = li - j - 1
               if(t[li] != t[lj]):
                   f = hh[ll]/(t[li]-t[lj])
                   h[ll] += f*(t[li]-arg)
                   h[ll+1] = f*(arg-t[lj])
               else:
                   h[ll+1] = 0.0
                   break
       
       sp = 0.0
       ll = l - 1 - k1
       
       for j in range(k1):
           ll += 1
           sp += c[ll]*h[j]
       y[i] = sp
    
    return y

######################### Testing and comparison #############################

# Generate a data set for interpolation
x, dx = np.linspace(10,100,200, retstep=True)
y = np.sin(x)

# Calculate the cubic spline coefficients
coeff_1 = interpolate.splrep(x,y, k=3)  # scipy's splrep
coeff_2 = custom_splrep(x,y, k=3)       # Custom wrap of scipy's splrep

# Generate data for interpolation and randomize
x2 = np.linspace(0,110,10000) 
np.random.shuffle(x2)

# Interpolate
y2 = interpolate.splev(x2, coeff_1) # scipy's splev
y3 = numba_splev(x2, coeff_2)       # Numba accelerated implementation of scipy's splev

# Plot data
plt.plot(x,y,'--', linewidth=1.0,color='green', label='data')
plt.plot(x2,y2,'o',color='blue', markersize=2.0, label='scipy splev')
plt.plot(x2,y3,'.',color='red',  markersize=1.0, label='numba splev')
plt.legend()
plt.show()

print("\nTime for random interpolations")
# Calculation time evaluation for scipy splev
t1 = time.time()
for n in range(0,10000):
  y2 = interpolate.splev(x2, coeff_1)
print("scipy splev", time.time() - t1)

# Calculation time evaluation for numba splev
t1 = time.time()
for n in range(0,10000):
  y2 = numba_splev(x2, coeff_2)
print("numba splev",time.time() - t1)

print("\nTime for non random interpolations")
# Generate data for interpolation without randomize
x2 = np.linspace(0,110,10000) 

# Calculation time evaluation for scipy splev
t1 = time.time()
for n in range(0,10000):
  y2 = interpolate.splev(x2, coeff_1)
print("scipy splev", time.time() - t1)

# Calculation time evaluation for numba splev
t1 = time.time()
for n in range(0,10000):
  y2 = numba_splev(x2, coeff_2)
print("numba splev",time.time() - t1)

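The code above is optimized for a faster knot search when the knots are equally spaced. On my Core i7 machine, the numba version is faster when interpolating at randomly ordered points:

scipy splev = 0.896 s, numba splev = 0.375 s

When the points are not randomized (i.e. sorted), scipy's version is faster:

scipy splev = 0.281 s, numba splev = 0.375 s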
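To see why the equispaced shortcut in `numba_splev` is valid, here is a NumPy/SciPy-only check (no Numba needed). With `s=0` and `k=3`, `splrep` places the interior knots at the data points `x[2:-2]`, so the knot span can be computed in O(1) from `(arg - t[0]) / dx` instead of being searched for. The sketch compares that formula against a binary search over `t`; `span_by_search` and `span_equispaced` are hypothetical helper names.

```python
import numpy as np
from scipy import interpolate


def span_by_search(t, k, arg):
    # Reference: 1-based (Fortran-style) span l with t(l) <= arg < t(l+1),
    # clamped to [k+1, n-k-1]. side="right" counts the knots <= arg.
    n = t.size
    l = int(np.searchsorted(t, arg, side="right"))
    return min(max(l, k + 1), n - k - 1)


def span_equispaced(t, k, dx, arg):
    # The O(1) formula from numba_splev's equispaced branch.
    n = t.size
    l = int((arg - t[0]) / dx) + k
    return min(max(l, k + 1), n - k - 1)


x, dx = np.linspace(10, 100, 200, retstep=True)
t, c, k = interpolate.splrep(x, np.sin(x), k=3)

# Interior knots coincide with x[2:-2], so they are dx apart.
print(np.allclose(t[k + 1:-(k + 1)], x[2:-2]))

rng = np.random.default_rng(0)
args = rng.uniform(0.0, 110.0, 1000)
mismatches = sum(span_by_search(t, k, a) != span_equispaced(t, k, dx, a) for a in args)
print(mismatches)
```

A binary search costs O(log n) per point, whereas the non-equispaced branch of `numba_splev` walks from the previous span; that walk is cheap for sorted points and expensive for shuffled ones, which is consistent with the timings above.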

References: https://github.com/scipy/scipy/tree/v1.7.1/scipy/interpolate/fitpack, https://github.com/dbstein/fast_splines


Did you manage to call the precompiled Fortran function directly from Numba? If needed, I can add an answer on how to do that. - max9111
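@max9111, I haven't been able to do that yet. It would be very helpful if you could kindly post the steps for calling the compiled Fortran function, together with a simple example. - sreenath subramaniam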

Page content provided by Stack Overflow.