在numpy/openblas上运行时设置线程的最大数量

17

我想知道是否可以在(Python)运行时更改numpy后面的OpenBLAS使用的最大线程数?

我知道可以通过环境变量OMP_NUM_THREADS在运行解释器之前设置它,但我想在运行时更改它。

通常,在使用MKL而不是OpenBLAS时,这是可能的:

import mkl
mkl.set_num_threads(n)

2
您可以尝试使用ctypes模块调用openblas_set_num_threads函数。类似于这个问题 - user2379410
2个回答

16

您可以通过使用ctypes调用openblas_set_num_threads函数来实现此操作。我经常发现自己想要这样做,所以我编写了一个小的上下文管理器:

import contextlib
import ctypes
from ctypes.util import find_library

# Prioritize hand-compiled OpenBLAS library over version in /usr/lib/
# from Ubuntu repos
try_paths = ['/opt/OpenBLAS/lib/libopenblas.so',
             '/lib/libopenblas.so',
             '/usr/lib/libopenblas.so.0',
             find_library('openblas')]
openblas_lib = None
for libpath in try_paths:
    try:
        openblas_lib = ctypes.cdll.LoadLibrary(libpath)
        break
    except OSError:
        continue
if openblas_lib is None:
    raise EnvironmentError('Could not locate an OpenBLAS shared library', 2)


def set_num_threads(n):
    """Set the current number of threads used by the OpenBLAS server."""
    openblas_lib.openblas_set_num_threads(int(n))


# At the time of writing these symbols were very new:
# https://github.com/xianyi/OpenBLAS/commit/65a847c
try:
    openblas_lib.openblas_get_num_threads()
    def get_num_threads():
        """Get the current number of threads used by the OpenBLAS server."""
        return openblas_lib.openblas_get_num_threads()
except AttributeError:
    def get_num_threads():
        """Dummy function (symbol not present in %s), returns -1."""
        return -1
    pass

try:
    openblas_lib.openblas_get_num_procs()
    def get_num_procs():
        """Get the total number of physical processors"""
        return openblas_lib.openblas_get_num_procs()
except AttributeError:
    def get_num_procs():
        """Dummy function (symbol not present), returns -1."""
        return -1
    pass


@contextlib.contextmanager
def num_threads(n):
    """Temporarily changes the number of OpenBLAS threads.

    Example usage:

        print("Before: {}".format(get_num_threads()))
        with num_threads(n):
            print("In thread context: {}".format(get_num_threads()))
        print("After: {}".format(get_num_threads()))
    """
    old_n = get_num_threads()
    set_num_threads(n)
    try:
        yield
    finally:
        set_num_threads(old_n)

您可以像这样使用它:

with num_threads(8):
    np.dot(x, y)

如评论中所述,openblas_get_num_threadsopenblas_get_num_procs是非常新的功能,在编写此文时可能不可用,除非您从最新版本的源代码编译OpenBLAS。


2
请注意,从v0.2.14版本开始,pthread openblas_get_num_procs不考虑亲和性,因此在可用CPU数量受限(例如在容器中)时可能会导致过度订阅。请改用len(os.sched_getaffinity(0))(Python >= 3.3)。 - jtaylor
@jtaylor 很好的想法,我在思考,是否可以在运行时更改线程绑定。例如,我希望某些任务在第一个 CPU 插槽上使用 8 个线程完成,而其他任务在第二个 CPU 插槽上使用单个线程完成。 - Y00

14
我们最近开发了一个名为 threadpoolctl 的跨平台包,用于控制 Python 中调用 C 级别线程池时所使用的线程数。它类似于 @ali_m 给出的解决方案,但是它可以通过循环遍历所有加载的库来自动检测需要限制的库。它还提供了内省 API。
该软件包可使用 pip install threadpoolctl 安装,并带有一个上下文管理器,允许您控制诸如 numpy 等软件包使用的线程数:
from threadpoolctl import threadpool_limits
import numpy as np


with threadpool_limits(limits=1, user_api='blas'):
    # In this block, calls to blas implementation (like openblas or MKL)
    # will be limited to use only one thread. They can thus be used jointly
    # with thread-parallelism.
    a = np.random.randn(1000, 1000)
    a_squared = a @ a

您还可以对不同的线程池进行更精细的控制(例如区分blasopenmp调用)。

注意:该软件包仍在开发中,欢迎提供任何反馈意见。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接