在Python中绘制三维数组的最有效方法是什么?

9
什么是 Python 中绘制三维数组最有效的方法?
例如:
volume = np.random.rand(512, 512, 512)

其中数组项表示每个像素的灰度颜色。


以下代码运行速度过慢:

import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.gca(projection='3d')
volume = np.random.rand(20, 20, 20)
for x in range(len(volume[:, 0, 0])):
    for y in range(len(volume[0, :, 0])):
        for z in range(len(volume[0, 0, :])):
            ax.scatter(x, y, z, c = tuple([volume[x, y, z], volume[x, y, z], volume[x, y, z], 1]))
plt.show()

如果你想浏览数据而不是一次性绘制图形,可以使用以下链接:https://www.datacamp.com/community/tutorials/matplotlib-3d-volumetric-data - ferdymercury
3个回答

6
为了获得更好的性能,如果可能的话,请避免多次调用ax.scatter。相反,将所有的xyz坐标和颜色打包成一维数组(或列表),然后只调用一次ax.scatter
ax.scatter(x, y, z, c=volume.ravel())

这个问题(无论是CPU时间还是内存)会随着size(立方体的边长)的增加呈现size**3的规模增长。

此外,ax.scatter将尝试渲染所有size**3个点,而不考虑大部分点都被外层点所遮挡的事实。

在渲染之前,通过对volume进行汇总或重新采样/插值等方式来减少volume中的点数量可能有所帮助。

我们还可以通过仅绘制外壳将CPU和内存的需求从O(size**3)降低到O(size**2)

import functools
import itertools as IT
import numpy as np
import scipy.ndimage as ndimage
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def cartesian_product_broadcasted(*arrays):
    """
    https://dev59.com/0Wgu5IYBdhLWcg3wso-j#11146645 (senderle)
    """
    broadcastable = np.ix_(*arrays)
    broadcasted = np.broadcast_arrays(*broadcastable)
    dtype = np.result_type(*arrays)
    rows, cols = functools.reduce(np.multiply, broadcasted[0].shape), len(broadcasted)
    out = np.empty(rows * cols, dtype=dtype)
    start, end = 0, rows
    for a in broadcasted:
        out[start:end] = a.reshape(-1)
        start, end = end, end + rows
    return out.reshape(cols, rows).T

# @profile  # used with `python -m memory_profiler script.py` to measure memory usage
def main():
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1, projection='3d')

    size = 512
    volume = np.random.rand(size, size, size)
    x, y, z = cartesian_product_broadcasted(*[np.arange(size, dtype='int16')]*3).T
    mask = ((x == 0) | (x == size-1) 
            | (y == 0) | (y == size-1) 
            | (z == 0) | (z == size-1))
    x = x[mask]
    y = y[mask]
    z = z[mask]
    volume = volume.ravel()[mask]

    ax.scatter(x, y, z, c=volume, cmap=plt.get_cmap('Greys'))
    plt.show()

if __name__ == '__main__':
    main()

enter image description here

请注意,即使只绘制外壳,要实现size=512的绘图,我们仍需要大约1.3 GiB的内存。还要注意,即使您拥有足够的总内存,但由于缺乏RAM,程序使用交换空间,则程序的整体速度将会显著减慢。如果您发现自己处于这种情况,那么唯一的解决方案就是找到一种更聪明的方法来渲染一个可接受的图像,使用较少的点,或购买更多的RAM。

enter image description here


这段代码对于大小不超过(100,100,100)的数组运行速度足够快。我能否期望更快的构建x、y、z的方法有助于使其适用于(512,512,512)的数组? - Dmitry
1
对于这样大小的数组,我认为主要瓶颈是绘制512**3(约1.34亿)个点,而不是创建坐标数组。我已经编辑了上面的帖子,展示了如何只绘制立方体的外壳。这将把CPU和内存复杂度从O(size**3)降低到O(size**2)。然而,根据您的机器速度和资源情况,渲染可能需要相当长的时间。 - unutbu
数组中的大部分项都是零。因此,结果应该在立方体内显示相同的形状... - Dmitry

5
首先,一个密集的 512x512x512 点的网格数据太多了,不仅从技术角度而言,而且在观察这个图时很难看到有用的东西。您可能需要提取一些等值面、查看切片等。如果大部分点是不可见的,则可能还好,但然后您应该要求 ax.scatter 只显示非零点以使其更快。
话虽如此,您可以更快地完成它。诀窍是消除所有 Python 循环,包括那些隐藏在库中的循环,例如 itertools
import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import matplotlib.pyplot as plt

# Make this bigger to generate a dense grid.
N = 8

# Create some random data.
volume = np.random.rand(N, N, N)

# Create the x, y, and z coordinate arrays.  We use 
# numpy's broadcasting to do all the hard work for us.
# We could shorten this even more by using np.meshgrid.
x = np.arange(volume.shape[0])[:, None, None]
y = np.arange(volume.shape[1])[None, :, None]
z = np.arange(volume.shape[2])[None, None, :]
x, y, z = np.broadcast_arrays(x, y, z)

# Turn the volumetric data into an RGB array that's
# just grayscale.  There might be better ways to make
# ax.scatter happy.
c = np.tile(volume.ravel()[:, None], [1, 3])

# Do the plotting in a single call.
fig = plt.figure()
ax = fig.gca(projection='3d')
ax.scatter(x.ravel(),
           y.ravel(),
           z.ravel(),
           c=c)

非常感谢您的回答!但是它的速度与unutbu的解决方案相同。只适用于大小不超过~(100,100,100)的数组。但我会尽力清理它,只留下非零点。 - Dmitry

1

使用 itertools 中的 product 函数也可以实现类似的解决方案:

from itertools import product
from matplotlib import pyplot as plt
N = 8
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(projection="3d")
space = np.array([*product(range(N), range(N), range(N))]) # all possible triplets of numbers from 0 to N-1
volume = np.random.rand(N, N, N) # generate random data
ax.scatter(space[:,0], space[:,1], space[:,2], c=space/8, s=volume*300)

enter image description here


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接