在Divakar的回答的基础上,如果原始数组A的范围比B更广,那么就会出现“索引超出范围”的错误。请参见:
A = np.array([1,1,2,3,3,3,4,5,6,7,8,8,10,12,14])
B = np.array([1,2,8])
A[B[np.searchsorted(B,A)] != A]
>> IndexError: index 3 is out of bounds for axis 0 with size 3
这将发生是因为
np.searchsorted
会将A中的10、12和14元素分配给B中的索引3(即B中插入位置的最后一个位置)。因此,在
B[np.searchsorted(B,A)]
中会出现IndexError。
为了避免这种情况,一种可能的方法是:
def subset_sorted_array(A,B):
Aa = A[np.where(A <= np.max(B))]
Bb = (B[np.searchsorted(B,Aa)] != Aa)
Bb = np.pad(Bb,(0,A.shape[0]-Aa.shape[0]), method='constant', constant_values=True)
return A[Bb]
它的工作方式如下:
Aa = A[np.where(A <= np.max(B))]
Bb = (B[np.searchsorted(B,Aa)] != Aa)
Bb = np.pad(Bb,(0,A.shape[0]-Aa.shape[0]), method='constant', constant_values=True)
A[Bb]
>> array([ 3, 3, 3, 4, 5, 6, 7, 10, 12, 14])
请注意,这也适用于字符串数组和其他类型之间的比较(对于所有定义了比较运算符“<=”的类型)。
numpy.setdiff1d
不正是你所需要的吗? - der_herr_g