根据下标获取行的均值的快速方法

3

我有一份数据,可以通过以下方式进行模拟:

N = 10^6;%10^8;
K = 10^4;%10^6; 

subs = randi([1 K],N,1);
M = [randn(N,5) subs];
M(M<-1.2) = nan;

换句话说,它是一个矩阵,其中最后一行是下标。 现在我想要为每个下标计算 nanmean()。同时,我还想保存每个下标的行数。这里有一个 'dummy' 代码:

uniqueSubs = unique(M(:,6));
avM = nan(numel(uniqueSubs),6);
for iSub = 1:numel(uniqueSubs)
    tmpM = M(M(:,6)==uniqueSubs(iSub),1:5);
    avM(iSub,:) = [nanmean(tmpM,1) size(tmpM,1)];
end

问题在于它太慢了。我希望它能处理 N = 10^8K = 10^6(请参见这些变量的定义中的注释部分)。
有什么方法可以更快地找到数据的平均值吗?
2个回答

6
这似乎是使用 findgroupssplitapply 的完美工作。
% Find groups in the final column
G = findgroups(M(:,6));
% function to apply per group
fcn = @(group) [mean(group, 1, 'omitnan'), size(group, 1)];
% Use splitapply to apply fcn to each group in M(:,1:5)
result = splitapply(fcn, M(:, 1:5), G);
% Check
assert(isequaln(result, avM));

1
注意:我使用了 mean 函数的 'omitnan' 版本,因为它存在于基础 MATLAB 中并且执行完全相同的操作。 - Edric
比我的循环更易读,但慢了大约15%;这是因为findgroups和/或splitapply在底层进行了检查,而我在循环中省略了吗? - Adriaan
1
说实话,我不确定额外的开销从哪里来。findgroupssplitapply都是MATLAB代码,在我的机器上,两者都没有明显的浪费。我认为findgroupssplitapply更易读,并且(现在我已经学会了它们)我认为更多的人应该利用它们。 - Edric
我同意你的观点。差别并不大,不会成为问题,而且函数名称非常描述性,使得它们比使用自定义变量名的for循环更易读。 - Adriaan

5
M = sortrows(M,6); % sort the data per subscript
IDX = diff(M(:,6)); % find where the subscript changes
tmp = find(IDX);
tmp = [0 ;tmp;size(M,1)]; % add start and end of data
for iSub= 2:numel(tmp)
    % Calculate the mean over just a single subscript, store in iSub-1
    avM2(iSub-1,:) = [nanmean(M(tmp(iSub-1)+1:tmp(iSub),1:5),1) tmp(iSub)-tmp(iSub-1)];tmp(iSub-1)];
end

这比您在我的计算机上的原始代码快60倍。加速主要来自对数据进行预排序,然后找到所有下标更改的位置。这样,每次不必遍历整个数组以找到正确的下标,而是仅在每次迭代中检查必要的内容。因此,您可以在约100行上计算平均值,而无需首先检查在该迭代中是否需要每行1,000,000行。

因此:在原始代码中,您检查numel(uniqueSubs),在这种情况下为10,000,看所有N,在这里为1,000,000,数字是否属于某个类别,这导致了10^12次检查。建议的代码对行进行排序(排序是NlogN,因此在此处为6,000,000),然后循环一次完整的数组而不进行额外的检查。


为了完成,这里是原始代码和我的版本,两者是相同的:

N = 10^6;%10^8;
K = 10^4;%10^6; 

subs = randi([1 K],N,1);
M = [randn(N,5) subs];
M(M<-1.2) = nan;

uniqueSubs = unique(M(:,6));
%% zlon's original code
avM = nan(numel(uniqueSubs),7); % add the subscript for comparison later
tic
uniqueSubs = unique(M(:,6));
for iSub = 1:numel(uniqueSubs)
    tmpM = M(M(:,6)==uniqueSubs(iSub),1:5);
    avM(iSub,:) = [nanmean(tmpM,1) size(tmpM,1) uniqueSubs(iSub)];
end
toc
%%%%% End of zlon's code
avM = sortrows(avM,7); % Sort for comparison

%% Start of Adriaan's code
avM2 = nan(numel(uniqueSubs),6);
tic
M = sortrows(M,6);
IDX = diff(M(:,6));
tmp = find(IDX);
tmp = [0 ;tmp;size(M,1)];
for iSub = 2:numel(tmp)
    avM2(iSub-1,:) = [nanmean(M(tmp(iSub-1)+1:tmp(iSub),1:5),1) tmp(iSub)-tmp(iSub-1)];
end
toc %tic/toc should not be used for accurate timing, this is just for order of magnitude
%%%% End of Adriaan's code

all(avM(:,1:6) == avM2) % Do the comparison
% End of script

% Output
Elapsed time is 58.561347 seconds.
Elapsed time is 0.843124 seconds. % ~70 times faster

ans =

  1×6 logical array

   1   1   1   1   1   1 % i.e. the matrices are equal to one another

这是因为你的代码输出超出了理解范围!首先,iSub没有声明,但我猜它应该是ii。它返回一个numel(tmp)*6的矩阵!这些行和列应该代表什么?请详细说明! - Rahul
1
@Rahul 谢谢你发现了这个错别字,我已经更正了以匹配 OP 的代码。avM 的声明是我从原始代码中复制的,关于列的含义,你应该问他,我只是提高了他的代码速度。前五列给出了某个下标元素的平均值,第六列给出了具有相应下标的元素数量。如果您还想知道它是哪个下标,那么它就是简单地索引为 1:numel(rows),因此行号就是下标。如果您需要更多信息,请告诉我。 - Adriaan

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接