在Matlab 3D数组中查找重复的2D数组

4

worldStates是一个Matlab MxNxL三维数组(张量),包含MxN网格的L种二进制值状态。

ps是与不同状态相关的概率的长度为L的列表。

函数[worldStates,ps] = StateMerge(worldStates,ps)应该删除重复的世界状态并将合并状态的概率相加到剩余的单个状态中。 重复状态是具有完全相同的二进制值配置的状态。

这是此功能的当前实现:

function [worldStates, ps] = StateMerge(worldStates, ps)
    
    M = containers.Map;
    
    for i = 1:length(ps)
        s = worldStates(:,:,i); 
        s = mat2str(s);
        if isKey(M, s)
            M(s) = M(s) + ps(i);
        else
            M(s) = ps(i);
        end
    end
    
    stringStates = keys(M);
    n = length(stringStates);
    
    sz = size(worldStates);
    worldStates = zeros([sz(1:2), n]);
    ps = zeros(1, 1, n);
    
    for i = 1:n
        worldStates(:,:,i) = eval(stringStates{i});
        ps(i) = M(stringStates{i});
    end
end 

它使用Map来能够在O(L)时间内删除重复项,使用状态作为键和概率作为值。由于Matlab映射不允许将一般数据结构用作键,因此将状态转换为字符串表示形式以用作键,然后使用eval函数将其转换回数组。

事实证明,对于我需要处理多个状态(数量约为10^6)多次的需求(10^3),此代码速度太慢了。问题在于将矩阵转换为字符串需要大量时间,并且与状态大小缩放得不好。下面是一个小型25x25状态的示例:

enter image description here

如何更有效地创建键?除了使用Map之外,是否有其他解决方案可以产生更好的结果?

编辑:按请求的可运行代码。此示例使合并变得不太可能:

worldStates = double(rand(25,25, 1000) > 0.5);

weights = rand(1,1, 1000);
ps = weights./sum(weights);

[worldStates, ps] = StateMerge(worldStates, ps);

在这个例子中会有很多合并操作:

worldStates = double(rand(25,25) > 0.5) .* ones(1,1,1000);
worldStates(1:2,1:2,:) = rand(2,2,1000) > 0.5;

weights = rand(1,1, 1000);
ps = weights./sum(weights);

[worldStates, ps] = StateMerge(worldStates, ps);

4
摒弃eval()函数,它会禁用JIT(即时编译),降低代码执行速度,并伴随着一系列问题。详见我的这个回答及其引用。请加上一个minimal, complete, and verifiable example (mcve),也就是可供我们运行的代码?对于纯数值矩阵使用eval和字符串并不高效。 - Adriaan
@Adriaan,根据您的要求,我添加了可运行的示例代码。我知道eval被认为是邪恶的,但不知道如何有效地对矩阵进行编码以便于映射并快速解析它。如果您知道如何做到这一点,那正是我正在寻找的。 - Emil Jansson
1
因为它以微妙的方式显示出来:通过禁用整个函数的JIT(即时编译)。因此,包含eval的行可能很快,但其他函数可能会受到缺乏JIT的严重影响。 - Adriaan
2
我对字典、映射或哈希知识了解不多。话虽如此,我觉得你的问题似乎是要在一个三维数值矩阵中找到重复的页面(第三维)。这可以通过保持数字状态来更轻松、更快速地完成。 - Adriaan
1
基本上,all(A(:,:,1)==A(:,:,2),'all') 将比较两个矩阵是否完全相等(适用于二进制矩阵),然后通过简单的循环遍历所有可能性即可。可能可以通过利用 'dim' 参数或者更智能的检查来使其更加智能化(例如,~any() 有可能在不相等时提前终止,而不是像 all() 那样一直执行到最后)。 - Adriaan
显示剩余7条评论
1个回答

6
使用unique函数提取唯一(合并)状态,使用accumarray函数求和合并状态的概率。请注意,像您的解决方案一样,此解决方案不保留原始状态的顺序。如@Wolfie在评论中建议的,您可以使用带有“stable”选项的unique函数来保留状态的顺序:
function [worldStates, ps] = StateMerge(worldStates, ps)
    [M, N, L] = size (worldStates);
    worldStates1 = reshape(worldStates, M*N, L).';
    [~, uc, ui] = unique(worldStates1, 'rows');
    ps = accumarray(ui, ps(:));
    worldStates = worldStates (:, :, uc);
end

该解决方案运行良好,但有一个警告,即我在形式上收到了ps,其形式为[1,1,L],导致在accumarray中出现异常。不过这很容易通过reshape来解决,而且我在我的问题中并没有完全清楚地表达出来。 - Emil Jansson
1
@EmilJansson 我编辑了答案,重新调整了 ps - rahnema1
2
参考测试结果显示,该函数在 L = 1000 的情况下比原来快了10倍,而且对于更大的L,其正面效果似乎更加明显。 - Emil Jansson
1
accumarray 万岁 :-) - Luis Mendo
1
你能否使用带有“stable”参数的“unique”来保留一些顺序? - Wolfie
@Wolfie 谢谢!我已经将你的建议添加到答案中了。 - rahnema1

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接