在Python中在进程之间共享连续的NumPy数组

24

虽然我找到了很多类似我的问题的答案,但我认为这里还没有直接解决它,并且我有几个额外的问题。共享连续的numpy数组的动机如下:

  • 我正在使用在Caffe上运行的卷积神经网络对图像进行回归,将它们映射到一系列连续值标签。
  • 这些图像需要特定的预处理和数据增强。
  • 标签连续 (是浮点数) 和数据增强的限制意味着我在Python中预处理数据,然后通过在内存中的数据层中使用连续的numpy数组来提供它。
  • 将训练数据加载到内存中相对较慢。我想并行化它,以便:

(1) 我编写的 Python 创建一个 "数据处理器" 类,该类实例化两个连续的 numpy 数组。 (2) 工作进程交替使用这些 numpy 数组,从磁盘加载数据,执行预处理,并将数据插入到 numpy 数组中。 (3) 同时,Python Caffe 包装器将数据从另一个数组发送到 GPU 中运行网络。

我有几个问题:

  1. 是否有可能分配连续的 numpy 数组内存,然后使用 Python 的 multiprocessing 中的 Array 类之类的东西将其包装在共享内存对象中(我不确定“对象”是否是这里的正确术语)?

  2. numpy 数组有一个 .ctypes 属性,我认为这对于使用 Array() 实例化共享内存数组很有用,但似乎无法确定如何精确使用它们。

  3. 如果共享内存没有实例化 numpy 数组,则它是否仍然保持连续? 如果没有,是否有一种方法可以确保它保持连续?

是否有可能做到这样的事情:

import numpy as np
from multiprocessing import Array
contArr = np.ascontiguousarray(np.zeros((n_images, n_channels, img_height, img_width)), dtype=np.float32)
sm_contArr = Array(contArr.ctypes.?, contArr?)

然后用以下方式实例化工作者

p.append(Process(target=some_worker_function, args=(data_to_load, sm_contArr)))
p.start()

谢谢!

编辑:我知道有很多库具有类似的功能,但维护状态各不相同。我更倾向于只使用纯Python和NumPy,但如果不可能,我当然愿意使用其中一个。


这只是用于预测阶段吗?还是您也想以这种方式训练您的神经网络? - user1269942
这实际上是用于训练和预测的。 - eriophora
1
类似于这个链接: https://dev59.com/T2035IYBdhLWcg3wNtVw#5550156 ? - BeRecursive
你的问题听起来和我的非常相似,你能解决了吗?看看我写的内容: https://dev59.com/2pDea4cB1Zd3GeqPgcgT - alfredox
我没能解决它,最后只好把数组复制过去并承受性能损失。 - eriophora
1个回答

7

将numpy的ndarray封装到multiprocessing的RawArray()

有多种方法可以在进程之间共享内存中的numpy数组。让我们看看如何使用multiprocessing模块来实现。

第一个重要的观察是,numpy提供了np.frombuffer()函数,可以将ndarray接口封装到支持缓冲区协议(例如bytes()bytearray()array()等)的预先存在的对象中。这将从只读对象创建只读数组和从可写对象创建可写数组。

我们可以将其与multiprocessing提供的共享内存RawArray()结合使用。请注意,Array()不能用于此目的,因为它是一个带有锁的代理对象,并且不直接公开缓冲区接口。当然,这意味着我们需要自己提供对numpified RawArrays的适当同步。

关于ndarray包装的RawArrays,有一个复杂的问题:当multiprocessing在进程之间发送这样的数组时 - 事实上,一旦创建,它将需要向两个工作进程发送我们的数组 - 它会对它们进行pickle和unpickle。不幸的是,这导致它创建ndarrays的副本,而不是在内存中共享它们。

解决方案有点丑陋,就是保持RawArrays不变,直到它们被传输到工作进程并且只有在每个工作进程启动后再将它们封装在ndarrays中

此外,最好通过multiprocessing.Queue直接通信数组,无论是纯粹的RawArray还是包装了ndarray,但这也行不通。RawArray不能放入这样的队列中,而且包装了ndarray的数组会被pickled和unpickled,因此实际上是复制了一份。

解决方法是向工作进程发送所有预分配数组的列表,并通过Queues通信索引。这非常像传递令牌(索引),谁持有令牌就允许操作相关数组。

主程序的结构可能如下:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import queue

from multiprocessing import freeze_support, set_start_method
from multiprocessing import Event, Process, Queue
from multiprocessing.sharedctypes import RawArray


def create_shared_arrays(size, dtype=np.int32, num=2):
    dtype = np.dtype(dtype)
    if dtype.isbuiltin and dtype.char in 'bBhHiIlLfd':
        typecode = dtype.char
    else:
        typecode, size = 'B', size * dtype.itemsize

    return [RawArray(typecode, size) for _ in range(num)]


def main():
    my_dtype = np.float32

    # 125000000 (size) * 4 (dtype) * 2 (num) ~= 1 GB memory usage
    arrays = create_shared_arrays(125000000, dtype=my_dtype)
    q_free = Queue()
    q_used = Queue()
    bail = Event()

    for arr_id in range(len(arrays)):
        q_free.put(arr_id)  # pre-fill free queue with allocated array indices

    pr1 = MyDataLoader(arrays, q_free, q_used, bail,
                       dtype=my_dtype, step=1024)
    pr2 = MyDataProcessor(arrays, q_free, q_used, bail,
                          dtype=my_dtype, step=1024)

    pr1.start()
    pr2.start()

    pr2.join()
    print("\n{} joined.".format(pr2.name))

    pr1.join()
    print("{} joined.".format(pr1.name))


if __name__ == '__main__':
    freeze_support()

    # On Windows, only "spawn" is available.
    # Also, this tests proper sharing of the arrays without "cheating".
    set_start_method('spawn')
    main()

这将准备一个包含两个数组的列表,即两个“队列” - 一个是“自由”队列,MyDataProcessor 将完成处理的数组索引放入其中,MyDataLoader 则从它那里获取;另一个是“已使用”队列,MyDataLoader 将已填充数组的索引放入其中,MyDataProcessor 则从中获取。此外还有一个 multiprocessing.Event 用户启动所有工作线程的有序退出。当前我们只有一个生产者和一个消费者,可以先不用后者,但为了应对更多工作线程,最好也准备一下。
然后我们用列表中的所有 RawArrays 的索引预先填充“空”Queue,并实例化每种类型的工作线程,传递必要的通信对象给它们。接着启动这两个工作线程,并等待它们执行完毕(join())。
以下是 MyDataProcessor 的示例,从“已使用”Queue中消耗数组索引,然后将数据发送到某个外部黑盒(在本例中为 debugio.output):
class MyDataProcessor(Process):
    def __init__(self, arrays, q_free, q_used, bail, dtype=np.int32, step=1):
        super().__init__()
        self.arrays = arrays
        self.q_free = q_free
        self.q_used = q_used
        self.bail = bail
        self.dtype = dtype
        self.step = step

    def run(self):
        # wrap RawArrays inside ndarrays
        arrays = [np.frombuffer(arr, dtype=self.dtype) for arr in self.arrays]

        from debugio import output as writer

        while True:
            arr_id = self.q_used.get()
            if arr_id is None:
                break

            arr = arrays[arr_id]

            print('(', end='', flush=True)          # just visualizing activity
            for j in range(0, len(arr), self.step):
                writer.write(str(arr[j]) + '\n')
            print(')', end='', flush=True)          # just visualizing activity

            self.q_free.put(arr_id)

            writer.flush()

        self.bail.set()                     # tell loaders to bail out ASAP
        self.q_free.put(None, timeout=1)    # wake up loader blocking on get()

        try:
            while True:
                self.q_used.get_nowait()    # wake up loader blocking on put()
        except queue.Empty:
            pass

它的第一步是使用'np.frombuffer()'将接收到的原始数组RawArrays封装成ndarrays并保留新列表,这样在处理过程中它们就可以用作numpy数组,而不需要反复封装。
还要注意,MyDataProcessor仅向self.bail Event写入,它从不检查。相反,如果需要告诉它退出,则会在队列上找到一个None标记,而不是数组索引。当MyDataLoader没有更多可用数据并开始拆卸过程时,这么做可以使MyDataProcessor能够在不过早退出的情况下处理所有有效的数组。
以下是MyDataLoader的示例代码:
class MyDataLoader(Process):
    def __init__(self, arrays, q_free, q_used, bail, dtype=np.int32, step=1):
        super().__init__()
        self.arrays = arrays
        self.q_free = q_free
        self.q_used = q_used
        self.bail = bail
        self.dtype = dtype
        self.step = step

    def run(self):
        # wrap RawArrays inside ndarrays
        arrays = [np.frombuffer(arr, dtype=self.dtype) for arr in self.arrays]

        from debugio import input as reader

        for _ in range(10):  # for testing we end after a set amount of passes
            if self.bail.is_set():
                # we were asked to bail out while waiting on put()
                return

            arr_id = self.q_free.get()
            if arr_id is None:
                # we were asked to bail out while waiting on get()
                self.q_free.put(None, timeout=1)  # put it back for next loader
                return

            if self.bail.is_set():
                # we were asked to bail out while we got a normal array
                return

            arr = arrays[arr_id]

            eof = False
            print('<', end='', flush=True)          # just visualizing activity
            for j in range(0, len(arr), self.step):
                line = reader.readline()
                if not line:
                    eof = True
                    break

                arr[j] = np.fromstring(line, dtype=self.dtype, sep='\n')

            if eof:
                print('EOF>', end='', flush=True)   # just visualizing activity
                break

            print('>', end='', flush=True)          # just visualizing activity

            if self.bail.is_set():
                # we were asked to bail out while we filled the array
                return

            self.q_used.put(arr_id)     # tell processor an array is filled

        if not self.bail.is_set():
            self.bail.set()             # tell other loaders to bail out ASAP
            # mark end of data for processor as we are the first to bail out
            self.q_used.put(None)

它的结构非常类似于另一个worker。它有点膨胀是因为它在许多点检查self.bail事件,从而减少被卡住的可能性。(这并不完全可靠,因为在检查和访问队列之间设置事件的微小机会。如果这是个问题,则需要使用一些同步原语仲裁对事件队列的访问。)

它还在最开始将接收到的RawArrays包装成ndarrays,并从外部黑盒子(例如debugio.input)读取数据。

注意,在main()函数中通过调整两个worker的step=参数,我们可以改变读写比例(严格用于测试目的,在生产环境中step=1,读写所有numpy数组成员)。

增加这两个值使工作进程只访问numpy数组中的一些值,从而显着提高了速度,这表明性能不受worker进程之间通信的限制。如果我们直接将numpy数组放入队列中,在进程之间完全复制它们来回传递,增加步长不会显著提高性能-它将保持缓慢。

供参考,这是我用于测试的debugio模块:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from ast import literal_eval
from io import RawIOBase, BufferedReader, BufferedWriter, TextIOWrapper


class DebugInput(RawIOBase):
    def __init__(self, end=None):
        if end is not None and end < 0:
            raise ValueError("end must be non-negative")

        super().__init__()
        self.pos = 0
        self.end = end

    def readable(self):
        return True

    def read(self, size=-1):
        if self.end is None:
            if size < 0:
                raise NotImplementedError("size must be non-negative")
            end = self.pos + size
        elif size < 0:
            end = self.end
        else:
            end = min(self.pos + size, self.end)

        lines = []
        while self.pos < end:
            offset = self.pos % 400
            pos = self.pos - offset
            if offset < 18:
                i = (offset + 2) // 2
                pos += i * 2 - 2
            elif offset < 288:
                i = (offset + 12) // 3
                pos += i * 3 - 12
            else:
                i = (offset + 112) // 4
                pos += i * 4 - 112

            line = str(i).encode('ascii') + b'\n'
            line = line[self.pos - pos:end - pos]
            self.pos += len(line)
            size -= len(line)
            lines.append(line)

        return b''.join(lines)

    def readinto(self, b):
        data = self.read(len(b))
        b[:len(data)] = data
        return len(data)

    def seekable(self):
        return True

    def seek(self, offset, whence=0):
        if whence == 0:
            pos = offset
        elif whence == 1:
            pos = self.pos + offset
        elif whence == 2:
            if self.end is None:
                raise ValueError("cannot seek to end of infinite stream")
            pos = self.end + offset
        else:
            raise NotImplementedError("unknown whence value")

        self.pos = max((pos if self.end is None else min(pos, self.end)), 0)
        return self.pos


class DebugOutput(RawIOBase):
    def __init__(self):
        super().__init__()
        self.buf = b''
        self.num = 1

    def writable(self):
        return True

    def write(self, b):
        *lines, self.buf = (self.buf + b).split(b'\n')

        for line in lines:
            value = literal_eval(line.decode('ascii'))
            if value != int(value) or int(value) & 255 != self.num:
                raise ValueError("expected {}, got {}".format(self.num, value))

            self.num = self.num % 127 + 1

        return len(b)


input = TextIOWrapper(BufferedReader(DebugInput()), encoding='ascii')
output = TextIOWrapper(BufferedWriter(DebugOutput()), encoding='ascii')

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接