找出所有能被某个特定数整除的掩码的子集。

4
所有4位数1101的子集是00000001010001011000100111001011。该掩码的所有可被2整除的子集是0000010010001100
给定一个64位掩码M和一个64位整数P,如何迭代所有可被P整除的M的子集?
要迭代位掩码的子集,可以这样做:
uint64_t superset = ...;
uint64_t subset = 0;
do {
    print(subset);
    subset = (subset - superset) & superset;
} while (subset != 0);

如果M~0,我可以从0开始,并持续加上P以遍历所有P的倍数。如果P是2的幂,则可以使用M &= ~(P - 1)来截断永远不会设置的位。
但如果没有上述约束条件,我是否有比朴素地检查每个子集是否可被P整除更好的方法?这种朴素算法平均需要O(P)次操作来获得下一个可被P整除的子集。我能做得比O(P)更好吗?

3
这个问题是怎么出现的?是一个学术问题还是挑战性练习?还是源于某个实际问题?一个位集既被视为通过位掩码表示集合和子集,又被视为某个整数的倍数,这是不寻常的。了解它的起因可能会激发直觉或关于解决方案的思考。 - Eric Postpischil
3
这个问题是如何产生的呢?它是一个学术问题还是挑战性练习?或者它是由某个实际问题引起的?一个比特集合既可以通过位掩码表示为集合和子集,又可以表示为某个整数的倍数,这是不寻常的。了解它的产生方式可能会激发直觉或对解决方案的思考。 - Eric Postpischil
3
这个问题是如何产生的呢?是一个学术问题还是一个挑战性的练习?还是源于某个实际问题?一个比特集合既可以通过位掩码表示为集合和子集,又可以表示为某个整数的倍数,这是不寻常的。了解问题的产生方式可能会激发直觉或对解决方案的思考。 - undefined
1
1011如何是1101的子集? - גלעד ברקן
1
1011如何是1101的子集? - גלעד ברקן
显示剩余4条评论
1个回答

1

一个并行算法

对于某些输入来说,检查因子的倍数比检查掩码的子集要高效得多,而对于其他输入则相反。例如,当M0xFFFFFFFFFFFFFFFFP0x4000000000000000时,检查P的三个倍数几乎是瞬间完成的,但即使你每秒能计算和检查十亿个M的子集,枚举所有子集也需要三十年的时间。只找到大于或等于P的子集的优化仅能将时间缩短到四年。

然而,有一个很强的理由来枚举和检查P的倍数,而不是M的子集:并行性。我想强调一下,在其他地方对这段代码的错误评论:OP中的算法是固有的顺序执行,因为每个subset的值都使用前一个subset的值。它不能运行,直到所有较低的子集已经计算完毕。它不能被向量化以使用AVX寄存器或类似的东西。你不能将四个值加载到一个AVX2寄存器中,并在其上运行SIMD指令,因为你需要计算第一个值来初始化第二个元素,第二个值来初始化第三个元素,所有三个值来初始化最后一个元素,然后你又回到了一次只计算一个值。它也不能在不同的CPU核心上的工作线程之间进行分割,这与前面的情况不同。(接受的答案可以修改以实现后者,但没有进行彻底重构就无法实现前者。)你不能将工作负载划分为0到63的子集、64到127的子集等,并让不同的线程并行处理每个子集,因为在你知道第63个子集是什么之前,你不能开始处理第64个子集,而要知道第63个子集,你需要知道第62个子集,依此类推。
如果你从这里没有学到其他东西,我强烈建议你启用全面优化,并亲自在Godbolt上尝试此代码,看看它是如何编译成顺序代码的。如果你熟悉OpenMP,请尝试添加#pragma omp simd#pramga omp parallel指令,看看会发生什么。问题不在于编译器,而是算法本身是顺序的。但至少通过观察真实的编译器行为,你应该能够确信2023年的编译器无法像这样对代码进行向量化。
供参考,以下是Clang 16对find的处理结果:
Find:                                   # @Find
        push    r15
        push    r14
        push    r12
        push    rbx
        push    rax
        mov     rbx, rdi
        cmp     rdi, rsi
        jne     .LBB1_1
.LBB1_6:
        lea     rdi, [rip + .L.str]
        mov     rsi, rbx
        xor     eax, eax
        add     rsp, 8
        pop     rbx
        pop     r12
        pop     r14
        pop     r15
        jmp     printf@PLT                      # TAILCALL
.LBB1_1:
        mov     r14, rdx
        mov     r15, rsi
        jmp     .LBB1_2
.LBB1_5:                                #   in Loop: Header=BB1_2 Depth=1
        imul    r12, r14
        add     r15, r12
        cmp     r15, rbx
        je      .LBB1_6
.LBB1_2:                                # =>This Inner Loop Header: Depth=1
        cmp     r15, rbx
        ja      .LBB1_7
        mov     rax, r15
        xor     rax, rbx
        blsi    r12, rax
        test    r12, rbx
        je      .LBB1_5
        mov     rdi, rbx
        sub     rdi, r12
        mov     rsi, r15
        mov     rdx, r14
        call    Find
        jmp     .LBB1_5
.LBB1_7:
        add     rsp, 8
        pop     rbx
        pop     r12
        pop     r14
        pop     r15
        ret

列举并检查倍数而不是子集
除了具有更多的并行性外,这种方法在速度上还有几个优点:
- 找到后继元素,即给定pi*p,将其应用于一个包含四个元素的向量,可以通过强制减少为单个加法操作。 - 测试因子是否为子集只需要一个与操作,而测试子集是否为因子则需要一个%操作,大多数CPU没有作为本地指令,并且即使存在时也始终是最慢的ALU操作。
因此,使用多线程和SIMD进行加速的代码版本如下:
#include <assert.h>
#include <omp.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>


typedef uint_fast32_t word;

/* Sets each element results[i], where i <= mask/factor, to true if factor*i
 * is a subset of the mask, false otherwise.  The results array MUST have at
 * least (mask/factor + 1U) elements.  The capacity of results in elements is
 * required and checked, just in case.
 *
 * Returns a pointer to the results.
 */
static bool* check_multiples( const word mask,
                              const word factor,
                              const size_t n,
                              bool results[n] )
{
    const word end = mask/factor;
    const word complement = ~mask;
    assert(&results);
    assert(n > end);

    #pragma omp parallel for simd schedule(static)
    for (word i = 0; i <= end; ++i) {
        results[i] = (factor*i & complement) == 0;
    }

    return results;
}

/* Replace these with non-constants so that the compiler actually
 * actually instantiates the function:
 */
/*
#define MASK 0xA0A0UL
#define FACTOR 0x50UL
#define NRESULTS (MASK/FACTOR + 1U)
 */
extern const word MASK, FACTOR;
#define NRESULTS 1024UL

int main(void)
{
    bool are_subsets[NRESULTS] = {false};
    (void)check_multiples(MASK, FACTOR, NRESULTS, are_subsets);

    for (word i = 0; i < NRESULTS; ++i) {
        if (are_subsets[i]) {
            const unsigned long long multiple = (unsigned long long)FACTOR*i;
            printf("%llx ", multiple);
            assert((multiple & MASK) == multiple && (multiple & ~MASK) == 0U);
        }
    }

    return EXIT_SUCCESS;
}

check_multiples的内部循环编译成在ICX 2022上

.LBB1_5:                                # =>This Inner Loop Header: Depth=1
        vpmullq         ymm15, ymm1, ymm0
        vpmullq         ymm16, ymm2, ymm0
        vpmullq         ymm17, ymm3, ymm0
        vpmullq         ymm18, ymm4, ymm0
        vpmullq         ymm19, ymm5, ymm0
        vpmullq         ymm20, ymm6, ymm0
        vpmullq         ymm21, ymm7, ymm0
        vpmullq         ymm22, ymm8, ymm0
        vptestnmq       k0, ymm22, ymm9
        vptestnmq       k1, ymm21, ymm9
        kshiftlb        k1, k1, 4
        korb            k0, k0, k1
        vptestnmq       k1, ymm20, ymm9
        vptestnmq       k2, ymm19, ymm9
        kshiftlb        k2, k2, 4
        korb            k1, k1, k2
        kunpckbw        k0, k1, k0
        vptestnmq       k1, ymm18, ymm9
        vptestnmq       k2, ymm17, ymm9
        kshiftlb        k2, k2, 4
        korb            k1, k1, k2
        vptestnmq       k2, ymm16, ymm9
        vptestnmq       k3, ymm15, ymm9
        kshiftlb        k3, k3, 4
        korb            k2, k2, k3
        kunpckbw        k1, k2, k1
        kunpckwd        k1, k1, k0
        vmovdqu8        ymm15 {k1} {z}, ymm10
        vmovdqu         ymmword ptr [rbx + rsi], ymm15
        vpaddq          ymm15, ymm11, ymm7
        vpaddq          ymm16, ymm6, ymm11
        vpaddq          ymm17, ymm5, ymm11
        vpaddq          ymm18, ymm4, ymm11
        vpaddq          ymm19, ymm3, ymm11
        vpaddq          ymm20, ymm2, ymm11
        vpaddq          ymm21, ymm1, ymm11
        vpmullq         ymm21, ymm21, ymm0
        vpmullq         ymm20, ymm20, ymm0
        vpmullq         ymm19, ymm19, ymm0
        vpmullq         ymm18, ymm18, ymm0
        vpmullq         ymm17, ymm17, ymm0
        vpmullq         ymm16, ymm16, ymm0
        vpmullq         ymm15, ymm15, ymm0
        vpaddq          ymm22, ymm8, ymm11
        vpmullq         ymm22, ymm22, ymm0
        vptestnmq       k0, ymm22, ymm9
        vptestnmq       k1, ymm15, ymm9
        kshiftlb        k1, k1, 4
        korb            k0, k0, k1
        vptestnmq       k1, ymm16, ymm9
        vptestnmq       k2, ymm17, ymm9
        kshiftlb        k2, k2, 4
        korb            k1, k1, k2
        kunpckbw        k0, k1, k0
        vptestnmq       k1, ymm18, ymm9
        vptestnmq       k2, ymm19, ymm9
        kshiftlb        k2, k2, 4
        korb            k1, k1, k2
        vptestnmq       k2, ymm20, ymm9
        vptestnmq       k3, ymm21, ymm9
        kshiftlb        k3, k3, 4
        korb            k2, k2, k3
        kunpckbw        k1, k2, k1
        kunpckwd        k1, k1, k0
        vmovdqu8        ymm15 {k1} {z}, ymm10
        vmovdqu         ymmword ptr [rbx + rsi + 32], ymm15
        vpaddq          ymm15, ymm12, ymm7
        vpaddq          ymm16, ymm6, ymm12
        vpaddq          ymm17, ymm5, ymm12
        vpaddq          ymm18, ymm4, ymm12
        vpaddq          ymm19, ymm3, ymm12
        vpaddq          ymm20, ymm2, ymm12
        vpaddq          ymm21, ymm1, ymm12
        vpmullq         ymm21, ymm21, ymm0
        vpmullq         ymm20, ymm20, ymm0
        vpmullq         ymm19, ymm19, ymm0
        vpmullq         ymm18, ymm18, ymm0
        vpmullq         ymm17, ymm17, ymm0
        vpmullq         ymm16, ymm16, ymm0
        vpmullq         ymm15, ymm15, ymm0
        vpaddq          ymm22, ymm8, ymm12
        vpmullq         ymm22, ymm22, ymm0
        vptestnmq       k0, ymm22, ymm9
        vptestnmq       k1, ymm15, ymm9
        kshiftlb        k1, k1, 4
        korb            k0, k0, k1
        vptestnmq       k1, ymm16, ymm9
        vptestnmq       k2, ymm17, ymm9
        kshiftlb        k2, k2, 4
        korb            k1, k1, k2
        kunpckbw        k0, k1, k0
        vptestnmq       k1, ymm18, ymm9
        vptestnmq       k2, ymm19, ymm9
        kshiftlb        k2, k2, 4
        korb            k1, k1, k2
        vptestnmq       k2, ymm20, ymm9
        vptestnmq       k3, ymm21, ymm9
        kshiftlb        k3, k3, 4
        korb            k2, k2, k3
        kunpckbw        k1, k2, k1
        kunpckwd        k1, k1, k0
        vmovdqu8        ymm15 {k1} {z}, ymm10
        vmovdqu         ymmword ptr [rbx + rsi + 64], ymm15
        vpaddq          ymm15, ymm13, ymm7
        vpaddq          ymm16, ymm6, ymm13
        vpaddq          ymm17, ymm5, ymm13
        vpaddq          ymm18, ymm4, ymm13
        vpaddq          ymm19, ymm3, ymm13
        vpaddq          ymm20, ymm2, ymm13
        vpaddq          ymm21, ymm1, ymm13
        vpmullq         ymm21, ymm21, ymm0
        vpmullq         ymm20, ymm20, ymm0
        vpmullq         ymm19, ymm19, ymm0
        vpmullq         ymm18, ymm18, ymm0
        vpmullq         ymm17, ymm17, ymm0
        vpmullq         ymm16, ymm16, ymm0
        vpmullq         ymm15, ymm15, ymm0
        vpaddq          ymm22, ymm8, ymm13
        vpmullq         ymm22, ymm22, ymm0
        vptestnmq       k0, ymm22, ymm9
        vptestnmq       k1, ymm15, ymm9
        kshiftlb        k1, k1, 4
        korb            k0, k0, k1
        vptestnmq       k1, ymm16, ymm9
        vptestnmq       k2, ymm17, ymm9
        kshiftlb        k2, k2, 4
        korb            k1, k1, k2
        kunpckbw        k0, k1, k0
        vptestnmq       k1, ymm18, ymm9
        vptestnmq       k2, ymm19, ymm9
        kshiftlb        k2, k2, 4
        korb            k1, k1, k2
        vptestnmq       k2, ymm20, ymm9
        vptestnmq       k3, ymm21, ymm9
        kshiftlb        k3, k3, 4
        korb            k2, k2, k3
        kunpckbw        k1, k2, k1
        kunpckwd        k1, k1, k0
        vmovdqu8        ymm15 {k1} {z}, ymm10
        vmovdqu         ymmword ptr [rbx + rsi + 96], ymm15
        vpaddq          ymm8, ymm8, ymm14
        vpaddq          ymm7, ymm14, ymm7
        vpaddq          ymm6, ymm14, ymm6
        vpaddq          ymm5, ymm14, ymm5
        vpaddq          ymm4, ymm14, ymm4
        vpaddq          ymm3, ymm14, ymm3
        vpaddq          ymm2, ymm14, ymm2
        vpaddq          ymm1, ymm14, ymm1
        sub             rsi, -128
        add             rdi, -4
        jne             .LBB1_5

我鼓励你在这个编译器中尝试使用相同的设置对算法进行变化,并观察结果。如果你认为在这些子集上生成向量化代码是可能的,并且和之前一样好,那么你应该多练习。

一个可能的改进

要检查的候选者数量可能非常大,但一种限制方法是同时计算P的乘法逆元,并在更好的情况下使用它。

每个P的值可分解为2ⁱ · Q,其中Q是奇数。由于Q和2⁶⁴互质,Q将具有模数乘法逆元Q',其乘积QQ'= 1 (mod 2⁶⁴)。你可以使用扩展欧几里德算法找到它(但不是我最初提出的方法)。

这对于优化算法非常有用,因为对于许多值的P,Q' < P。如果m是一个解,m = nP,其中n是整数。将两边乘以Q',得到Q'Pm = 2ⁱ · m = Q'n。这意味着我们可以列举出Q'或P的倍数(稍微增加一些逻辑以确保它们具有足够的尾零位)。请注意,由于Q'是奇数,不需要检查所有Q'的倍数;例如,如果前面的常数是4,则只需检查4·_Q'_的乘积。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接