哪种神经网络可以处理可变的输入和输出尺寸?

19

我正在尝试使用这篇论文描述的方法https://arxiv.org/abs/1712.01815 来让算法学习一个新游戏。

唯一的问题是该方法并不能直接适用于我要学习的游戏,因为该游戏没有固定的棋盘大小。因此,当前输入张量的维度为m*n*11,其中m和n是游戏棋盘的维度,每次玩游戏时都可以变化。因此,首先我需要一个神经网络能够利用这种可变的输入尺寸。

输出的大小也取决于棋盘大小,因为它有一个矢量,其中包含棋盘上每个可能的移动条目,所以如果棋盘大小增加,则输出向量将更大。

我已经了解了递归和循环神经网络,但它们似乎都与自然语言处理有关,我不确定如何将其应用到我的问题上。

欢迎提供任何能够处理我的情况的神经网络架构的想法。


2
对于不同的输入大小,这里有一些好的答案:https://stats.stackexchange.com/questions/388859/is-it-possible-to-give-variable-sized-images-as-input-to-convolutioal-neural-net以及这里:https://ai.stackexchange.com/questions/2008/how-can-neural-networks-deal-with-varying-input-sizes但对于不同的输出大小,我仍然感到困惑。 - Charlie Parker
高度理论化的完全卷积网络。例如,YOLOv3能够处理不同尺寸的图像(在网络不崩溃的情况下)。 - viceriel
3个回答

11
你需要的是指针网络(Pointer Networks) (https://arxiv.org/abs/1506.03134)。
以下是关于指针网络的引言:
指针网络是一种新的神经结构,可以学习指向输入序列中某个位置的指针。这是新的,因为现有技术需要拥有固定数量的目标类,而这通常并不适用-请考虑旅行商问题,其中类的数量等于输入的数量。另一个例子是对可变大小的序列进行排序。 -https://finbarr.ca/pointer-networks/ 它是一种基于注意力机制的模型。
基本上,指针网络用于预测输入的指针,这意味着您的输出层实际上不是固定的,而是可变的。
我使用它们的一个用例是将原始文本翻译成SQL查询语句。
例如:
  • 输入:“HOW MANY CARS WERE SOLD IN US IN 1983”
  • 输出:SELECT COUNT(Car_id) FROM Car_table WHERE (Country='US' AND Year=='1983')

这样的原始文本的问题在于,它只针对特定的表格(在此情况下是具有围绕汽车销售的一组变量的汽车表格)才有意义。这意味着问题不能成为唯一的输入。因此,实际进入指针网络的输入是以下内容的组合:
  1. 查询
  2. 表格的元数据(列名称)
  3. 所有分类列的标记词汇表
  4. SQL语法中的关键字(SELECT、WHERE等)
所有这些都被追加在一起。
然后,输出层简单地指向输入的特定索引。它会指向元数据中的Country和Year(从分类列的词汇表中的标记中),它会指向输入的SELECT、WHERE等部分中的US和1983(从SQL语法组件中),并且还会指向元数据中的SELECT、WHERE等部分。这些索引的顺序则用作计算图的输出,并使用存在于WIKISQL数据集中的训练数据进行优化。
您的情况非常相似,您需要将输入、游戏的元数据以及作为输出的所需信息作为追加索引传递。然后,指针网络仅从输入中进行选择(指向它们)。

0
全卷积神经网络能够实现此功能。卷积层的参数是卷积核,卷积核不太关心输入大小(是的,与步长、填充输入和核心大小有关的限制是存在的)。
典型用例是一些conv层接着maxpooling,重复几次,直到滤波器被压平并连接到密集层。密集层存在问题,因为它期望固定大小的输入。如果有另一个conv2层,则输出将是适当大小的另一个特征映射。
这样的网络示例可以是YOLOv3。例如,如果您使用416x416x3图像进行馈送,则输出可能为13x13x滤波器数量(我知道YOLOv3具有更多的输出层,但由于简单起见,我仅讨论其中一个)。如果您使用256x256x3图像进行馈送,则输出将成为特征映射6x6x滤波器数量。
所以,网络不会崩溃并产生结果。结果会好吗?我不知道,也许是,也许不是。我从未以这种方式使用它,我总是重新调整图像大小或重新训练网络。

你测试过了吗?我认为TF/keras/...会出现张量大小不匹配的错误(类似这样的问题)。你是不是指填充(padding)? - Babak.Abad

0

你需要回到一个固定的输入/输出问题。

当应用于图像/时间序列时,解决此问题的常见方法是使用滑动窗口进行缩小。也许这可以应用于你的游戏。


我可以使用一个4x4的游戏板来训练网络,然后为每个4x4的棋盘块进行单独的预测。问题在于如何合并结果,以及如何补偿相对于考虑整个棋盘时丢失的信息,因为远离的部分往往会互相影响。 - Damcios
平均数,多数投票,自定义规则... 我不知道你的游戏 - mxdbld

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接