Python,元组索引必须是整数而不是元组?

9

所以,我不太确定这里发生了什么,但由于某种原因,Python向我抛出了这个错误。作为参考,它是我为了好玩而构建的一个小型神经网络的一部分,但它使用了许多np.array等,所以有很多矩阵被抛出,所以我认为它正在创建某种数据类型冲突。也许有人可以帮我解决这个问题,因为我已经盯着这个错误看了太久,却无法修复它。

#cross-entropy error
#y is a vector of size N and output is an Nx3 array
def CalculateError(self, output, y): 

    #calculate total error against the vector y for the neurons where output = 1 (the rest are 0)
    totalError = 0
    for i in range(0,len(y)):
       totalError += -np.log(output[i, int(y[i])]) #error is thrown here

    #now account for regularizer
    totalError+=(self.regLambda/self.inputDim) * (np.sum(np.square(self.W1))+np.sum(np.square(self.W2)))     

    error=totalError/len(y) #divide ny N
    return error

编辑:以下是返回输出的函数,以便您知道它来自哪里。y是直接从文本文件中获取的长度为150的向量。在y的每个索引处,它包含索引1、2或3中的一个:

#forward propogation algorithm takes a matrix "X" of size 150 x 3
def ForProp(self, X):            
        #signal vector for hidden layer
        #tanh activation function
        S1 = X.dot(self.W1) + self.b1
        Z1 = np.tanh(S1)

        #vector for the final output layer
        S2 = Z1.dot(self.W2)+ self.b2
        #softmax for output layer activation
        expScores = np.exp(S2)
        output = expScores/(np.sum(expScores, axis=1, keepdims=True))
        return output,Z1

2
看起来 output 并不像你想的那样是一个 Nx4 的数组。 - user2357112
2
请包含完整的回溯信息。 - Christian Dean
1
你如何保证 y[i] 在范围 [0,3] 内?这似乎是你的问题。这要么是一个明显的错误,要么是需要在架构中修复的临时解决方案。 - Harrichael
1
你需要如何处理ForProp的结果呢?它返回一个元组,很可能你正在传递产生的元组,其中包含outputZ1。我猜想你有类似这样的代码:output = ForProp( ... ),但实际上应该是 output, Z1 = ForProp( ... ) - lejlot
1个回答

19

你的 output 变量不是一个 N x 4 的矩阵,至少从 Python 类型 的角度来看不是。它是一个元组,只能通过单个数字进行索引,而你试图通过一个元组(两个用逗号分隔的数字)进行索引,这只适用于 numpy 矩阵。打印你的输出,找出问题是否仅仅是类型问题(那么只需转换为 np.array),或者是否传递了完全不同的内容(然后修复产生 output 的任何问题)。

发生的示例:

import numpy as np
output = ((1,2,3,5), (1,2,1,1))

print output[1, 2] # your error
print output[(1, 2)] # your error as well - these are equivalent calls

print output[1][2] # ok
print np.array(output)[1, 2] # ok
print np.array(output)[(1, 2)] # ok
print np.array(output)[1][2] # ok

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接