在Redis中存储NumPy数组的最快方法

17

我在一个AI项目中使用redis。

这个想法是让多个环境模拟器在很多CPU核心上运行策略。 模拟器将体验(状态/动作/奖励元组的列表)写入redis服务器(重放缓冲区)。 然后,训练过程将经验数据集读取为一个新策略生成工具。 新策略被部署到模拟器上,之前运行的数据被删除,然后该过程继续执行。

大部分经验都包含在“状态”中。 它通常表示为一个大的numpy数组,维度为80 x 80。 模拟器能够尽可能快速地生成这些数组。

为此,有没有人有好的想法或经验来将大量的numpy数组快速简单地写入redis。 这都在同一台机器上,但以后可能会在一组云服务器上运行。 欢迎提供代码示例!


使用https://pypi.org/project/direct-redis/,无需任何麻烦。 - undefined
5个回答

33

我不知道这是否是最快的方法,但你可以尝试像这样做...

将Numpy数组存储到Redis中的方法如下 - 参见函数toRedis():

  • 获取Numpy数组的形状并编码
  • 将Numpy数组以字节形式附加到形状上
  • 在提供的键下存储编码数组

检索Numpy数组的方法如下 - 参见函数fromRedis()

  • 从Redis中检索与提供的键对应的编码字符串
  • 从字符串中提取Numpy数组的形状
  • 提取数据并重新填充Numpy数组,将其重新塑造为原始形状

#!/usr/bin/env python3

import struct
import redis
import numpy as np

def toRedis(r,a,n):
   """Store given Numpy array 'a' in Redis under key 'n'"""
   h, w = a.shape
   shape = struct.pack('>II',h,w)
   encoded = shape + a.tobytes()

   # Store encoded data in Redis
   r.set(n,encoded)
   return

def fromRedis(r,n):
   """Retrieve Numpy array from Redis key 'n'"""
   encoded = r.get(n)
   h, w = struct.unpack('>II',encoded[:8])
   # Add slicing here, or else the array would differ from the original
   a = np.frombuffer(encoded[8:]).reshape(h,w)
   return a

# Create 80x80 numpy array to store
a0 = np.arange(6400,dtype=np.uint16).reshape(80,80) 

# Redis connection
r = redis.Redis(host='localhost', port=6379, db=0)

# Store array a0 in Redis under name 'a0array'
toRedis(r,a0,'a0array')

# Retrieve from Redis
a1 = fromRedis(r,'a0array')

np.testing.assert_array_equal(a0,a1)
你可以通过在形状信息中编码 Numpy 数组的 dtype 来增加更多的灵活性。我没有这样做是因为你可能已经知道所有数组都是特定类型的,那样代码只会变得更大更难读,而没有必要。
现代 iMac 上的粗略基准测试:
80x80 Numpy array of np.uint16   => 58 microseconds to write
200x200 Numpy array of np.uint16 => 88 microseconds to write

关键词: Python, Numpy, Redis, 数组, 序列化, 键, 自增, 唯一


@MarkSetchell 这个脚本现在抛出了一个错误 ValueError: 无法将大小为1600的数组重塑为形状(80,80) - ashnair1
@AshwinNair 一个80x80的数组需要6400个元素,所以这是正确的,它不能从1600个元素的数组中重新塑形。 - Mark Setchell
是的,但您已经声明了a0有6400个元素。我的意思是说,按照现在的脚本运行会失败。不确定原因。 - ashnair1
好的,这是由于数据类型(dtype)的原因。你可以通过np.frombuffer(encoded[8:], dtype=np.uint16)指定数据类型。但更好的选择确实是将数据类型也进行编码。 - ashnair1
请问您能否更新您的答案,同时添加数据类型持久性呢? - Jean Carlo Machado
显示剩余6条评论

12
您还可以考虑使用msgpack-numpy,它提供了“编码和解码例程,使用高效的msgpack格式对numpy提供的数字和数组数据类型进行序列化和反序列化”。--请参见https://msgpack.org/
快速概念验证:
import msgpack
import msgpack_numpy as m
import numpy as np
m.patch()               # Important line to monkey-patch for numpy support!

from redis import Redis

r = Redis('127.0.0.1')

# Create an array, then use msgpack to serialize it 
d_orig = np.array([1,2,3,4])
d_orig_packed = m.packb(d_orig)

# Set the data in redis
r.set('d', d_orig_packed)

# Retrieve and unpack the data
d_out = m.unpackb(r.get('d'))

# Check they match
assert np.alltrue(d_orig == d_out)
assert d_orig.dtype == d_out.dtype

在我的机器上,使用 msgpack 比使用 struct 运行速度更快:

In: %timeit struct.pack('4096L', *np.arange(0, 4096))
1000 loops, best of 3: 443 µs per loop

In: %timeit m.packb(np.arange(0, 4096))
The slowest run took 7.74 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 32.6 µs per loop

1
尽管我确实欣赏使用 msgpack 的简单和优雅,但我不确定您的样本时间是什么意思。您似乎将 msgpack 的计时与结构装配进行了比较,但如果您仔细阅读我的回答,您会发现我只结构化包装维度而不是数组数据本身,对于这些数组数据本身,我使用 np.tobytes()。在我的机器上,如果您将 np.tobytes()msgpack 进行比较,它至少快了50倍,即314纳秒对比17.3微秒。 - Mark Setchell
@MarkSetchell,是的,你完全正确,这不是一个公平的比较。如果我从你的答案中仅提取打包逻辑以测试速度并将其命名为def pack(a),那么在一个80x80数组上,%timeit pack(a)给出4.62us,而%timeit m.packb(a)需要12us,因此慢了2.5倍。尽管如此,msgpack-numpy是一个很棒的软件包! - telegraphic

6
你可以查看Mark Setchell的答案,了解如何将字节写入Redis。在下面,我将重新编写函数 fromRedis toRedis 以考虑具有可变维度大小的数组,并包括数组形状信息。请参考以下内容:

def toRedis(arr: np.array) -> str:
    arr_dtype = bytearray(str(arr.dtype), 'utf-8')
    arr_shape = bytearray(','.join([str(a) for a in arr.shape]), 'utf-8')
    sep = bytearray('|', 'utf-8')
    arr_bytes = arr.ravel().tobytes()
    to_return = arr_dtype + sep + arr_shape + sep + arr_bytes
    return to_return

def fromRedis(serialized_arr: str) -> np.array:
    sep = '|'.encode('utf-8')
    i_0 = serialized_arr.find(sep)
    i_1 = serialized_arr.find(sep, i_0 + 1)
    arr_dtype = serialized_arr[:i_0].decode('utf-8')
    arr_shape = tuple([int(a) for a in serialized_arr[i_0 + 1:i_1].decode('utf-8').split(',')])
    arr_str = serialized_arr[i_1 + 1:]
    arr = np.frombuffer(arr_str, dtype = arr_dtype).reshape(arr_shape)
    return arr

4

尝试使用Plasma,因为它避免了串行化/反序列化开销。

使用pip install pyarrow安装Plasma。

文档:https://arrow.apache.org/docs/python/plasma.html

首先,使用1GB内存启动Plasma [终端]:

plasma_store -m 1000000000 -s /tmp/plasma

import pyarrow.plasma as pa
import numpy as np
client = pa.connect("/tmp/plasma")
temp = np.random.rand(80,80)

写入时间: 130微秒 vs 782微秒 (Redis实现:Mark Setchell的回答)

使用Plasma巨页可以提高写入时间,但仅适用于Linux机器:https://arrow.apache.org/docs/python/plasma.html#using-plasma-with-huge-pages

读取时间: 31.2微秒 vs 99.5微秒 (Redis实现:Mark Setchell的回答)

PS: 代码在MacPro上运行。


感谢pyarrow示例。欢迎贡献! - Duane
1
有意思 - 我不知道plasma/pyarrow。但有几个问题要提出来。1)你的代码并没有展示如何写入或读取plasma;2)你的代码在另一台机器上使用了不同的dtype和不同的数据,因此时间无法比较;3)如果我使用plasma和client.put()将与我在回答中创建的相同数组,那么Redis需要大约70微秒,而plasma需要196微秒 - 尽管我必须说我没有使用plasma或优化它的经验。 - Mark Setchell

1

tobytes() 函数并不是非常节省存储空间的。 为了减少需要写入 Redis 服务器的存储空间,可以使用 base64 包:

def encode_vector(ar):
    return base64.encodestring(ar.tobytes()).decode('ascii')

def decode_vector(ar):
    return np.fromstring(base64.decodestring(bytes(ar.decode('ascii'), 'ascii')), dtype='uint16')

@EDIT: 好的,由于Redis将值存储为字节字符串,直接存储字节字符串更加高效。但是,如果您将其转换为字符串,将其打印到控制台或将其存储在文本文件中,则进行编码是有意义的。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接