如何去除子图之间的间距

24
我正在一个项目中工作,需要组合一个10行3列的绘图网格。虽然我已经能够制作出绘图并安排子图,但是我无法像gridspec文档中的下面这个图那样产生一个没有白边的漂亮图像。image w/o white space
我尝试了以下帖子,但仍然无法完全去除白边,就像示例图像中那样。请问有人能给我一些指导吗?谢谢! 这是我的图片:my image 以下是我的代码。完整的脚本在GitHub上。 注意:images_2和images_fool都是形状为(1032, 10)的扁平化图像的numpy数组,而delta是形状为(28, 28)的图像数组。
def plot_im(array=None, ind=0):
    """A function to plot the image given a images matrix, type of the matrix: \
    either original or fool, and the order of images in the matrix"""
    img_reshaped = array[ind, :].reshape((28, 28))
    imgplot = plt.imshow(img_reshaped)

# Output as a grid of 10 rows and 3 cols with first column being original, second being
# delta and third column being adversaril
nrow = 10
ncol = 3
n = 0

from matplotlib import gridspec
fig = plt.figure(figsize=(30, 30)) 
gs = gridspec.GridSpec(nrow, ncol, width_ratios=[1, 1, 1]) 

for row in range(nrow):
    for col in range(ncol):
        plt.subplot(gs[n])
        if col == 0:
            #plt.subplot(nrow, ncol, n)
            plot_im(array=images_2, ind=row)
        elif col == 1:
            #plt.subplot(nrow, ncol, n)
            plt.imshow(w_delta)
        else:
            #plt.subplot(nrow, ncol, n)
            plot_im(array=images_fool, ind=row)
        n += 1

plt.tight_layout()
#plt.show()
plt.savefig('grid_figure.pdf')
5个回答

40

注意:如果想完全控制间距,请避免使用plt.tight_layout(),因为它会尝试将图形中的子图等量且漂亮地分布。虽然这通常是可以的且产生令人愉悦的结果,但它会随意调整间距。

您引用自Matplotlib示例库的GridSpec示例之所以效果很好,是因为子图的方面比例未预定义。也就是说,子图将在网格上简单扩展并保留设置的间距(在此示例中为 wspace=0.0, hspace=0.0),而独立于图形大小。

相反,当您用imshow绘制图像时,图像的方面比例默认为相等(相当于 ax.set_aspect("equal"))。也就是说,您当然可以将set_aspect("auto")应用于每个子图(并像示例中的GridSpec一样添加wspace=0.0, hspace=0.0作为参数),这将产生一个没有间距的图。

但是,当使用图像时,保持相等的纵横比非常重要,这样每个像素宽度和高度都相同,并且正方形数组显示为正方形图像。
那么您需要做的就是调整图像大小和图形边距,以获得预期结果。figure的figsize参数是英寸为单位的图形宽度和高度,这里可以玩弄两个数字的比例。subplot参数wspace、hspace、top、bottom、left可以手动调整,以达到所需的结果。 下面是一个示例:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec

nrow = 10
ncol = 3

fig = plt.figure(figsize=(4, 10)) 

gs = gridspec.GridSpec(nrow, ncol, width_ratios=[1, 1, 1],
         wspace=0.0, hspace=0.0, top=0.95, bottom=0.05, left=0.17, right=0.845) 

for i in range(10):
    for j in range(3):
        im = np.random.rand(28,28)
        ax= plt.subplot(gs[i,j])
        ax.imshow(im)
        ax.set_xticklabels([])
        ax.set_yticklabels([])

#plt.tight_layout() # do not use this!!
plt.show()

enter image description here

编辑:
当然,不需要手动调整参数是很理想的。 因此,可以根据行数和列数计算一些最优参数。

nrow = 7
ncol = 7

fig = plt.figure(figsize=(ncol+1, nrow+1)) 

gs = gridspec.GridSpec(nrow, ncol,
         wspace=0.0, hspace=0.0, 
         top=1.-0.5/(nrow+1), bottom=0.5/(nrow+1), 
         left=0.5/(ncol+1), right=1-0.5/(ncol+1)) 

for i in range(nrow):
    for j in range(ncol):
        im = np.random.rand(28,28)
        ax= plt.subplot(gs[i,j])
        ax.imshow(im)
        ax.set_xticklabels([])
        ax.set_yticklabels([])

plt.show()

像魔术一样好用,感谢@ImportanceOfBeingErnest!只是想知道为什么您使用figsize =(4,10)而不是figsize =(10,10)...后者会立即带回空间。 - George Liu
1
如果你的行数比列数多3倍,为什么需要一个正方形的图形尺寸呢?当然,你可以将其设置为(10,10),然后再次调整“left”和“right”参数。我选择figsize=(4, 10)更多是因为想法是,对于有n行和m列的情况,一个(m+1, n)的figsize可能是合适的;其余的则通过微调subplot参数来完成。 - ImportanceOfBeingErnest
1
figsize 是图形的尺寸,单位为英寸。我已经编辑了答案,包括一个示例,其中参数是从列数和行数计算出来的。 - ImportanceOfBeingErnest
1
对于那些使用图像数组的人,将 gen = iter(image_array) 放在 for 循环之外,ax.imshow(next(gen, np.full((1, 1, 3), 255))) 放在循环内。当 gen 没有值时,它会渲染一个大小为 1x1 的白色图像。 - DharmaTurtle
我喜欢你,Ser。相比其他不起作用的答案,这是唯一最好的答案。 - anmo
显示剩余2条评论

15

尝试在你的代码中添加这行:

fig.subplots_adjust(wspace=0, hspace=0)

对于每个轴对象集:

ax.set_xticklabels([])
ax.set_yticklabels([])

6
这个解决方案在子图的纵横比设置为“auto”时可以正常工作。对于使用imshow绘制的图像,它将失败。请参考我的解决方案。 - ImportanceOfBeingErnest
2
谢谢您的回复。这确实去掉了垂直空间,但水平空间仍然存在... - George Liu

5

在 ImportanceOfBeingErnest 的答案之后,如果您想使用 plt.subplots 及其功能:

fig, axes = plt.subplots(
    nrow, ncol,
    gridspec_kw=dict(wspace=0.0, hspace=0.0,
                     top=1. - 0.5 / (nrow + 1), bottom=0.5 / (nrow + 1),
                     left=0.5 / (ncol + 1), right=1 - 0.5 / (ncol + 1)),
    figsize=(ncol + 1, nrow + 1),
    sharey='row', sharex='col', #  optionally
)

0
如果您正在使用matplotlib.pyplot.subplots,您可以使用Axes数组显示任意数量的图像。您可以通过对matplotlib.pyplot.subplots配置进行一些调整来消除图像之间的空白。
import matplotlib.pyplot as plt

def show_dataset_overview(self, img_list):
"""show each image in img_list without space"""
    img_number = len(img_list)
    img_number_at_a_row = 3
    row_number = int(img_number /img_number_at_a_row) 
    fig_size = (15*(img_number_at_a_row/row_number), 15)
    _, axs = plt.subplots(row_number, 
                          img_number_at_a_row, 
                          figsize=fig_size , 
                          gridspec_kw=dict(
                                       top = 1, bottom = 0, right = 1, left = 0, 
                                       hspace = 0, wspace = 0
                                       )
                         )
    axs = axs.flatten()

    for i in range(img_number):
        axs[i].imshow(img_list[i])
        axs[i].set_xticks([])
        axs[i].set_yticks([])

由于我们首先在此处创建子图,因此可以使用gridspec_kw参数(source)为grid_spec提供一些参数。 其中包括“top = 1,bottom = 0,right = 1,left = 0,hspace = 0,wspace = 0”参数,可防止图像之间的间距。要查看其他参数,请访问here

当上面设置figure_size时,我通常使用像(30,15)这样的图形大小。我将其概括了一下,并将其添加到了代码中。如果您愿意,可以在此处输入手动大小。


0
这是另一种简单的方法,使用ImageGrid类(改编自this answer)。
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

nrow = 5
ncol = 3
fig = plt.figure(figsize=(4, 10))
grid = ImageGrid(fig, 
                 111, # as in plt.subplot(111)
                 nrows_ncols=(nrow,ncol),
                 axes_pad=0,
                 share_all=True,)

for row in grid.axes_column:
    for ax in row:
        im = np.random.rand(28,28)
        ax.imshow(im)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

enter image description here


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接