将numpy的ndarray
封装到multiprocessing的RawArray()
中
有多种方法可以在进程之间共享内存中的numpy数组。让我们看看如何使用multiprocessing模块来实现。
第一个重要的观察是,numpy提供了np.frombuffer()
函数,可以将ndarray接口封装到支持缓冲区协议(例如bytes()
、bytearray()
、array()
等)的预先存在的对象中。这将从只读对象创建只读数组和从可写对象创建可写数组。
我们可以将其与multiprocessing提供的共享内存RawArray()
结合使用。请注意,Array()
不能用于此目的,因为它是一个带有锁的代理对象,并且不直接公开缓冲区接口。当然,这意味着我们需要自己提供对numpified RawArrays的适当同步。
关于ndarray包装的RawArrays,有一个复杂的问题:当multiprocessing在进程之间发送这样的数组时 - 事实上,一旦创建,它将需要向两个工作进程发送我们的数组 - 它会对它们进行pickle和unpickle。不幸的是,这导致它创建ndarrays的副本,而不是在内存中共享它们。
解决方案有点丑陋,就是保持RawArrays不变,直到它们被传输到工作进程并且只有在每个工作进程启动后再将它们封装在ndarrays中。
此外,最好通过multiprocessing.Queue直接通信数组,无论是纯粹的RawArray还是包装了ndarray,但这也行不通。RawArray不能放入这样的队列中,而且包装了ndarray的数组会被pickled和unpickled,因此实际上是复制了一份。
解决方法是向工作进程发送所有预分配数组的列表,并通过Queues通信索引。这非常像传递令牌(索引),谁持有令牌就允许操作相关数组。
主程序的结构可能如下:
import numpy as np
import queue
from multiprocessing import freeze_support, set_start_method
from multiprocessing import Event, Process, Queue
from multiprocessing.sharedctypes import RawArray
def create_shared_arrays(size, dtype=np.int32, num=2):
dtype = np.dtype(dtype)
if dtype.isbuiltin and dtype.char in 'bBhHiIlLfd':
typecode = dtype.char
else:
typecode, size = 'B', size * dtype.itemsize
return [RawArray(typecode, size) for _ in range(num)]
def main():
my_dtype = np.float32
arrays = create_shared_arrays(125000000, dtype=my_dtype)
q_free = Queue()
q_used = Queue()
bail = Event()
for arr_id in range(len(arrays)):
q_free.put(arr_id)
pr1 = MyDataLoader(arrays, q_free, q_used, bail,
dtype=my_dtype, step=1024)
pr2 = MyDataProcessor(arrays, q_free, q_used, bail,
dtype=my_dtype, step=1024)
pr1.start()
pr2.start()
pr2.join()
print("\n{} joined.".format(pr2.name))
pr1.join()
print("{} joined.".format(pr1.name))
if __name__ == '__main__':
freeze_support()
set_start_method('spawn')
main()
这将准备一个包含两个数组的列表,即两个“队列” - 一个是“自由”队列,
MyDataProcessor 将完成处理的数组索引放入其中,
MyDataLoader 则从它那里获取;另一个是“已使用”队列,
MyDataLoader 将已填充数组的索引放入其中,
MyDataProcessor 则从中获取。此外还有一个
multiprocessing.Event
用户启动所有工作线程的有序退出。当前我们只有一个生产者和一个消费者,可以先不用后者,但为了应对更多工作线程,最好也准备一下。
然后我们用列表中的所有
RawArrays 的索引预先填充“空”
Queue,并实例化每种类型的工作线程,传递必要的通信对象给它们。接着启动这两个工作线程,并等待它们执行完毕(
join()
)。
以下是
MyDataProcessor 的示例,从“已使用”
Queue中消耗数组索引,然后将数据发送到某个外部黑盒(在本例中为
debugio.output
):
class MyDataProcessor(Process):
def __init__(self, arrays, q_free, q_used, bail, dtype=np.int32, step=1):
super().__init__()
self.arrays = arrays
self.q_free = q_free
self.q_used = q_used
self.bail = bail
self.dtype = dtype
self.step = step
def run(self):
arrays = [np.frombuffer(arr, dtype=self.dtype) for arr in self.arrays]
from debugio import output as writer
while True:
arr_id = self.q_used.get()
if arr_id is None:
break
arr = arrays[arr_id]
print('(', end='', flush=True)
for j in range(0, len(arr), self.step):
writer.write(str(arr[j]) + '\n')
print(')', end='', flush=True)
self.q_free.put(arr_id)
writer.flush()
self.bail.set()
self.q_free.put(None, timeout=1)
try:
while True:
self.q_used.get_nowait()
except queue.Empty:
pass
它的第一步是使用'np.frombuffer()'将接收到的原始数组
RawArrays封装成
ndarrays并保留新列表,这样在处理过程中它们就可以用作
numpy数组,而不需要反复封装。
还要注意,
MyDataProcessor仅向
self.bail
Event写入,它从不检查。相反,如果需要告诉它退出,则会在队列上找到一个
None
标记,而不是数组索引。当
MyDataLoader没有更多可用数据并开始拆卸过程时,这么做可以使
MyDataProcessor能够在不过早退出的情况下处理所有有效的数组。
以下是
MyDataLoader的示例代码:
class MyDataLoader(Process):
def __init__(self, arrays, q_free, q_used, bail, dtype=np.int32, step=1):
super().__init__()
self.arrays = arrays
self.q_free = q_free
self.q_used = q_used
self.bail = bail
self.dtype = dtype
self.step = step
def run(self):
arrays = [np.frombuffer(arr, dtype=self.dtype) for arr in self.arrays]
from debugio import input as reader
for _ in range(10):
if self.bail.is_set():
return
arr_id = self.q_free.get()
if arr_id is None:
self.q_free.put(None, timeout=1)
return
if self.bail.is_set():
return
arr = arrays[arr_id]
eof = False
print('<', end='', flush=True)
for j in range(0, len(arr), self.step):
line = reader.readline()
if not line:
eof = True
break
arr[j] = np.fromstring(line, dtype=self.dtype, sep='\n')
if eof:
print('EOF>', end='', flush=True)
break
print('>', end='', flush=True)
if self.bail.is_set():
return
self.q_used.put(arr_id)
if not self.bail.is_set():
self.bail.set()
self.q_used.put(None)
它的结构非常类似于另一个worker。它有点膨胀是因为它在许多点检查self.bail
事件,从而减少被卡住的可能性。(这并不完全可靠,因为在检查和访问队列之间设置事件的微小机会。如果这是个问题,则需要使用一些同步原语仲裁对事件和队列的访问。)
它还在最开始将接收到的RawArrays包装成ndarrays,并从外部黑盒子(例如debugio.input
)读取数据。
注意,在main()
函数中通过调整两个worker的step=
参数,我们可以改变读写比例(严格用于测试目的,在生产环境中step=1
,读写所有numpy数组成员)。
增加这两个值使工作进程只访问numpy数组中的一些值,从而显着提高了速度,这表明性能不受worker进程之间通信的限制。如果我们直接将numpy数组放入队列中,在进程之间完全复制它们来回传递,增加步长不会显著提高性能-它将保持缓慢。
供参考,这是我用于测试的debugio
模块:
from ast import literal_eval
from io import RawIOBase, BufferedReader, BufferedWriter, TextIOWrapper
class DebugInput(RawIOBase):
def __init__(self, end=None):
if end is not None and end < 0:
raise ValueError("end must be non-negative")
super().__init__()
self.pos = 0
self.end = end
def readable(self):
return True
def read(self, size=-1):
if self.end is None:
if size < 0:
raise NotImplementedError("size must be non-negative")
end = self.pos + size
elif size < 0:
end = self.end
else:
end = min(self.pos + size, self.end)
lines = []
while self.pos < end:
offset = self.pos % 400
pos = self.pos - offset
if offset < 18:
i = (offset + 2) // 2
pos += i * 2 - 2
elif offset < 288:
i = (offset + 12) // 3
pos += i * 3 - 12
else:
i = (offset + 112) // 4
pos += i * 4 - 112
line = str(i).encode('ascii') + b'\n'
line = line[self.pos - pos:end - pos]
self.pos += len(line)
size -= len(line)
lines.append(line)
return b''.join(lines)
def readinto(self, b):
data = self.read(len(b))
b[:len(data)] = data
return len(data)
def seekable(self):
return True
def seek(self, offset, whence=0):
if whence == 0:
pos = offset
elif whence == 1:
pos = self.pos + offset
elif whence == 2:
if self.end is None:
raise ValueError("cannot seek to end of infinite stream")
pos = self.end + offset
else:
raise NotImplementedError("unknown whence value")
self.pos = max((pos if self.end is None else min(pos, self.end)), 0)
return self.pos
class DebugOutput(RawIOBase):
def __init__(self):
super().__init__()
self.buf = b''
self.num = 1
def writable(self):
return True
def write(self, b):
*lines, self.buf = (self.buf + b).split(b'\n')
for line in lines:
value = literal_eval(line.decode('ascii'))
if value != int(value) or int(value) & 255 != self.num:
raise ValueError("expected {}, got {}".format(self.num, value))
self.num = self.num % 127 + 1
return len(b)
input = TextIOWrapper(BufferedReader(DebugInput()), encoding='ascii')
output = TextIOWrapper(BufferedWriter(DebugOutput()), encoding='ascii')