如何获取前一层的形状并将其传递给下一层?

5

我希望将传递给输入层的形状为(None,)的输入数据作为模板,并在for循环中使用它来完成某些任务。

以下是我的代码实现部分:

lst_rfrm = []
Inpt_lyr = keras.Input(shape = (None,))
for k in range(tm_stp):
  F = keras.layers.Lambda(lambda x, i, j: x[:, None, j : j + i])
  F.arguments = {'i' : sub_len, 'j' : k}
  tmp_rfrm = F(Inpt_lyr)
  lst_rfrm.append(tmp_rfrm)
cnctnt_lyr = keras.layers.merge.Concatenate(axis = 1)(lst_rfrm)
#defining other layers ...

由于输入形状为(None,),我不知道该在for循环的range()中传递什么参数(我用"tm_stp"来描述)。在这种情况下,我如何获取输入层(传递到输入层的数据)的形状?非常感谢您的任何帮助。


3
你应该尝试以"张量"的方式处理,而不是迭代。 - Daniel Möller
也许可以在预处理中使用生成器来实现。 - Daniel Möller
你的意思是在这种情况下无法获取输入形状吗?我已经搜索了很多关于张量的方法,但不幸的是没有找到任何方法。 - lordof4towers
非常感谢任何帮助。 - lordof4towers
在TensorFlow中,可以通过tf.shape(layer_name).eval()来获取层的形状(在输入形状为(None,)的动态形状的情况下),但我不知道如何在Keras中实现... - lordof4towers
1个回答

1

你可以尝试不同类型的循环。看起来你正在尝试滑动窗口,对吧?你不知道要运行的“长度”,但你知道窗口的大小以及要删除多少边界...所以....

Slicing differently

该函数根据以下原则获取切片:
windowSize = sub_len
def getWindows(x):
    borderCut = windowSize - 1 #lost length in the length dimension

    leftCut = range(windowSize) #start of sequence
    rightCut = [i - borderCut for i in leftCut] #end of sequence - negative
    rightCut[-1] = None #because it can't be zero for slicing

    croppedSequences = K.stack([x[:, l: r] for l,r in zip(leftCut, rightCut)], axis=-1)
    return croppedSequences

运行测试:
from keras.layers import *
from keras.models import Model
import keras.backend as K
import numpy as np

windowSize = 3
batchSize = 5

randomLength = np.random.randint(5,10)
inputData = np.arange(randomLength * batchSize).reshape((batchSize, randomLength))

def getWindows(x):
    borderCut = windowSize - 1

    leftCut = range(windowSize)
    rightCut = [i - borderCut for i in leftCut]
    rightCut[-1] = None

    croppedSequences = K.stack([x[:, l: r] for l,r in zip(leftCut, rightCut)], axis=-1)
    return croppedSequences

inputs = Input((None,))
outputs = Lambda(getWindows)(inputs)
model = Model(inputs, outputs)

preds = model.predict(inputData)

for i, (inData, pred) in enumerate(zip(inputData, preds)):
    print('sample: ', i)
    print('input sequence: ', inData)
    print('output sequence: \n', pred, '\n\n')

谢谢你,Daniel。它起作用了。我非常感激你的帮助。 - lordof4towers

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接