Java中的多线程分段埃拉托斯特尼筛法

7
我正在尝试在Java中创建一个快速的素数生成器。通常认为,最快的方法是Eratosthenes分段筛法:https://en.wikipedia.org/wiki/Sieve_of_Eratosthenes。可以进一步实现许多优化以使其更快。截至目前,我的实现在大约1.6秒内生成了10^9以下的50847534个质数,但我希望让它更快,并至少突破1秒的限制。为了增加获得良好回复的机会,我将包括算法和代码的演示。
作为一个TL;DR,我仍然希望将多线程编程纳入代码中。
为了回答这个问题,我想要区分'Eratosthenes筛法'的“分段”和“传统”筛法。传统筛法需要O(n)的空间,因此在输入范围(它的极限)上非常有限。而分段筛法只需要O(n^0.5)的空间,并且可以处理更大范围的输入(主要加速是使用缓存友好的分段,考虑特定计算机的L1&L2缓存大小)。最后,涉及我的问题的主要区别是传统筛法是顺序的,意味着只有在完成前面的步骤后才能继续进行。然而,分段筛法不是这样的。每个段都是独立的,并且针对素数筛选(小于n^0.5的素数)单独进行“处理”。这意味着理论上,一旦我有了素数筛选,我就可以在多台计算机之间分配工作,每个计算机处理不同的段。彼此的工作是独立的。假设(错误地)每个段需要相同的时间t来完成,且有k个段,则一个计算机需要总时间T = k * t,而k台计算机,每个处理不同的段,则需要总时间T = t来完成整个过程。(实际上,这是错误的,但为了简单起见,这是一个例子)。
这让我开始阅读有关多线程的文章——将工作分配给几个线程,每个线程处理较少量的工作,以更好地利用CPU。据我所知,传统筛法无法进行多线程操作,因为它是顺序的。每个线程都会依赖于前一个线程,使整个想法变得不可行。但是,分段筛法可能确实可以进行多线程处理(我认为)。
在直接进入我的问题之前,我认为介绍我的代码非常重要,因此我在此包含了我目前最快的分段筛法实现。我已经非常努力地工作了。花费了相当长的时间,慢慢地对其进行调整和优化。代码并不简单。我会说是相当复杂的。因此,我假设读者熟悉我正在介绍的概念,例如轮子分解、质数、分段等等。我已经包括了笔记,以使其更容易理解。
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Arrays;

public class primeGen {

    public static long x = (long)Math.pow(10, 9); //limit
    public static int sqrtx;
    public static boolean [] sievingPrimes; //the sieving primes, <= sqrtx

    public static int [] wheels = new int [] {2,3,5,7,11,13,17,19}; // base wheel primes
    public static int [] gaps; //the gaps, according to the wheel. will enable skipping multiples of the wheel primes
    public static int nextp; // the first prime > wheel primes
    public static int l; // the amount of gaps in the wheel

    public static void main(String[] args)
    {
        long startTime = System.currentTimeMillis();

        preCalc();  // creating the sieving primes and calculating the list of gaps

        int segSize = Math.max(sqrtx, 32768*8); //size of each segment
        long u = nextp; // 'u' is the running index of the program. will continue from one segment to the next
        int wh = 0; // the will be the gap index, indicating by how much we increment 'u' each time, skipping the multiples of the wheel primes

        long pi = pisqrtx(); // the primes count. initialize with the number of primes <= sqrtx

        for (long low = 0 ; low < x ; low += segSize) //the heart of the code. enumerating the primes through segmentation. enumeration will begin at p > sqrtx
        {
            long high = Math.min(x, low + segSize);
            boolean [] segment = new boolean [(int) (high - low + 1)];

            int g = -1;
            for (int i = nextp ; i <= sqrtx ; i += gaps[g])
            { 
                if (sievingPrimes[(i + 1) / 2])
                {
                    long firstMultiple = (long) (low / i * i);
                    if (firstMultiple < low) 
                        firstMultiple += i; 
                    if (firstMultiple % 2 == 0) //start with the first odd multiple of the current prime in the segment
                        firstMultiple += i;

                    for (long j = firstMultiple ; j < high ; j += i * 2) 
                        segment[(int) (j - low)] = true; 
                }
                g++;
                //if (g == l) //due to segment size, the full list of gaps is never used **within just one segment** , and therefore this check is redundant. 
                              //should be used with bigger segment sizes or smaller lists of gaps
                    //g = 0;
            }

            while (u <= high)
            {
                if (!segment[(int) (u - low)])
                    pi++;
                u += gaps[wh];
                wh++;
                if (wh == l)
                    wh = 0;
            }
        }

        System.out.println(pi);

        long endTime = System.currentTimeMillis();
        System.out.println("Solution took "+(endTime - startTime) + " ms");
    }

    public static boolean [] simpleSieve (int l)
    {
        long sqrtl = (long)Math.sqrt(l);
        boolean [] primes = new boolean [l/2+2];
        Arrays.fill(primes, true);
        int g = -1;
        for (int i = nextp ; i <= sqrtl ; i += gaps[g])
        {
            if (primes[(i + 1) / 2])
                for (int j = i * i ; j <= l ; j += i * 2)
                    primes[(j + 1) / 2]=false;
            g++;
            if (g == l)
                g=0;
        }
        return primes;
    }

    public static long pisqrtx ()
    {
        int pi = wheels.length;
        if (x < wheels[wheels.length-1])
        {
            if (x < 2)
                return 0;
            int k = 0;
            while (wheels[k] <= x)
                k++;
            return k;
        }
        int g = -1;
        for (int i = nextp ; i <= sqrtx ; i += gaps[g])
        {
            if(sievingPrimes[( i + 1 ) / 2])
                pi++;
            g++;
            if (g == l)
                g=0;
        }

        return pi;
    }

    public static void preCalc ()
    {
        sqrtx = (int) Math.sqrt(x);

        int prod = 1;
        for (long p : wheels)
            prod *= p; // primorial
        nextp = BigInteger.valueOf(wheels[wheels.length-1]).nextProbablePrime().intValue(); //the first prime that comes after the wheel
        int lim = prod + nextp; // circumference of the wheel

        boolean [] marks = new boolean [lim + 1];
        Arrays.fill(marks, true);

        for (int j = 2 * 2 ;j <= lim ; j += 2)
            marks[j] = false;
        for (int i = 1 ; i < wheels.length ; i++)
        {
            int p = wheels[i];
            for (int j = p * p ; j <= lim ; j += 2 * p)
                marks[j]=false;   // removing all integers that are NOT comprime with the base wheel primes
        }
        ArrayList <Integer> gs = new ArrayList <Integer>(); //list of the gaps between the integers that are coprime with the base wheel primes
        int d = nextp;
        for (int p = d + 2 ; p < marks.length ; p += 2)
        {
            if (marks[p]) //d is prime. if p is also prime, then a gap is identified, and is noted.
            {
                gs.add(p - d);
                d = p;
            }
        }
        gaps = new int [gs.size()];
        for (int i = 0 ; i < gs.size() ; i++)
            gaps[i] = gs.get(i); // Arrays are faster than lists, so moving the list of gaps to an array
        l = gaps.length;

        sievingPrimes = simpleSieve(sqrtx); //initializing the sieving primes
    }

}
目前,它在大约1.6秒内产生了10^9以下的50847534个质数。按我的标准来说,这非常令人印象深刻,但我希望让它更快,可能突破1秒的限制。即使如此,我认为它仍然可以更快。

整个程序基于轮筛法https://en.wikipedia.org/wiki/Wheel_factorization。我注意到使用包括所有小于19的质数的轮筛可以获得最快的结果。

public static int [] wheels = new int [] {2,3,5,7,11,13,17,19}; // base wheel primes

这意味着跳过这些质数的倍数,从而缩小搜索范围。然后使用“preCalc”方法计算需要获取的数字之间的差距。如果我们在搜索范围内的数字之间进行这些跳跃,就可以跳过基本质数的倍数。
public static void preCalc ()
    {
        sqrtx = (int) Math.sqrt(x);

        int prod = 1;
        for (long p : wheels)
            prod *= p; // primorial
        nextp = BigInteger.valueOf(wheels[wheels.length-1]).nextProbablePrime().intValue(); //the first prime that comes after the wheel
        int lim = prod + nextp; // circumference of the wheel

        boolean [] marks = new boolean [lim + 1];
        Arrays.fill(marks, true);

        for (int j = 2 * 2 ;j <= lim ; j += 2)
            marks[j] = false;
        for (int i = 1 ; i < wheels.length ; i++)
        {
            int p = wheels[i];
            for (int j = p * p ; j <= lim ; j += 2 * p)
                marks[j]=false;   // removing all integers that are NOT comprime with the base wheel primes
        }
        ArrayList <Integer> gs = new ArrayList <Integer>(); //list of the gaps between the integers that are coprime with the base wheel primes
        int d = nextp;
        for (int p = d + 2 ; p < marks.length ; p += 2)
        {
            if (marks[p]) //d is prime. if p is also prime, then a gap is identified, and is noted.
            {
                gs.add(p - d);
                d = p;
            }
        }
        gaps = new int [gs.size()];
        for (int i = 0 ; i < gs.size() ; i++)
            gaps[i] = gs.get(i); // Arrays are faster than lists, so moving the list of gaps to an array
        l = gaps.length;

        sievingPrimes = simpleSieve(sqrtx); //initializing the sieving primes
    } 

preCalc方法的末尾,调用simpleSieve方法,高效地筛选出之前提到的所有小于等于sqrtx的素数。这是一个简单的埃拉托色尼筛法,而不是分段筛法,但仍基于先前计算的轮筛法
 public static boolean [] simpleSieve (int l)
    {
        long sqrtl = (long)Math.sqrt(l);
        boolean [] primes = new boolean [l/2+2];
        Arrays.fill(primes, true);
        int g = -1;
        for (int i = nextp ; i <= sqrtl ; i += gaps[g])
        {
            if (primes[(i + 1) / 2])
                for (int j = i * i ; j <= l ; j += i * 2)
                    primes[(j + 1) / 2]=false;
            g++;
            if (g == l)
                g=0;
        }
        return primes;
    } 

最后,我们来到算法的核心。我们从枚举所有小于等于<= sqrtx的质数开始,使用以下调用:

 long pi = pisqrtx();`

使用了以下方法:
public static long pisqrtx ()
    {
        int pi = wheels.length;
        if (x < wheels[wheels.length-1])
        {
            if (x < 2)
                return 0;
            int k = 0;
            while (wheels[k] <= x)
                k++;
            return k;
        }
        int g = -1;
        for (int i = nextp ; i <= sqrtx ; i += gaps[g])
        {
            if(sievingPrimes[( i + 1 ) / 2])
                pi++;
            g++;
            if (g == l)
                g=0;
        }

        return pi;
    } 

然后,在初始化跟踪素数枚举的变量pi之后,我们执行所提到的分段,从第一个大于> sqrtx的质数开始枚举:
 int segSize = Math.max(sqrtx, 32768*8); //size of each segment
        long u = nextp; // 'u' is the running index of the program. will continue from one segment to the next
        int wh = 0; // the will be the gap index, indicating by how much we increment 'u' each time, skipping the multiples of the wheel primes

        long pi = pisqrtx(); // the primes count. initialize with the number of primes <= sqrtx

        for (long low = 0 ; low < x ; low += segSize) //the heart of the code. enumerating the primes through segmentation. enumeration will begin at p > sqrtx
        {
            long high = Math.min(x, low + segSize);
            boolean [] segment = new boolean [(int) (high - low + 1)];

            int g = -1;
            for (int i = nextp ; i <= sqrtx ; i += gaps[g])
            { 
                if (sievingPrimes[(i + 1) / 2])
                {
                    long firstMultiple = (long) (low / i * i);
                    if (firstMultiple < low) 
                        firstMultiple += i; 
                    if (firstMultiple % 2 == 0) //start with the first odd multiple of the current prime in the segment
                        firstMultiple += i;

                    for (long j = firstMultiple ; j < high ; j += i * 2) 
                        segment[(int) (j - low)] = true; 
                }
                g++;
                //if (g == l) //due to segment size, the full list of gaps is never used **within just one segment** , and therefore this check is redundant. 
                              //should be used with bigger segment sizes or smaller lists of gaps
                    //g = 0;
            }

            while (u <= high)
            {
                if (!segment[(int) (u - low)])
                    pi++;
                u += gaps[wh];
                wh++;
                if (wh == l)
                    wh = 0;
            }
        } 

我也将其作为一个注释包含在内,但也会解释一下。因为片段大小相对较小,我们不会在一个片段内遍历整个间隙列表,并检查它-这是多余的。(假设我们使用一个19-wheel)。但在程序的更广泛范围概述中,我们将利用整个间隙数组,所以变量u必须跟随它,而不是意外地超过它:

 while (u <= high)
            {
                if (!segment[(int) (u - low)])
                    pi++;
                u += gaps[wh];
                wh++;
                if (wh == l)
                    wh = 0;
            } 

使用更高的限制将最终生成更大的片段,这可能会导致需要检查即使在片段内我们也不会超过间隙列表。或者微调轮质数基数可能对程序产生影响。转换为位筛可以大大提高段限制。
  • 作为重要的附注,我知道高效的分段是考虑到L1&L2缓存大小的。我使用32,768 * 8 = 262,144 = 2^18的段大小获得最快的结果。我不确定我的计算机缓存大小是多少,但我认为它不可能那么大,因为我看到大多数缓存大小都小于等于32,768。尽管如此,在我的计算机上这产生了最快的运行时间,所以这就是选择的段大小。
  • 正如我之前提到的,我仍在努力大幅改进。根据我的介绍,我相信使用4个线程(对应4个核心),多线程可以将速度提高4倍。每个线程仍然使用分段筛法的思想,但处理不同的部分。将n分成4个相等的部分-线程,每个线程依次对其负责的n/4个元素进行分段,使用上述程序。我的问题是如何做到这一点?阅读有关多线程和示例的资料,不幸的是,没有给我带来任何关于如何有效实现它的见解。对我来说,与其背后的逻辑相反,线程似乎是顺序运行而不是同时运行。这就是为什么我将其从代码中排除以使其更易读。我真的很希望在此特定代码中提供一个代码示例,但一个好的解释和参考也许能够达到目的。
此外,我希望听到更多加速此程序的方法,如果您有任何想法,我很乐意倾听!真的希望使它非常快速和高效。谢谢!

@9000 1. 是的,每个段可以独立处理。段的处理完全取决于筛选素数,这些素数是预先计算的,并且在整个程序中都是相同的。每个段的更新不会对其他任何内容产生影响。2. 这正是我要问的,如何做到这一点。我的计算机有4个核心,因此有4个线程。在这个特定的程序中,我该如何创建这些线程以便同时运行?3. 是的,我明白了。为了简化,我假设它是线性的,并且每个段都以恒定的时间处理。显然,这些陈述是错误的。 - MC From Scratch
仅供比较,我的简单非分段C++筛法在ideone上花费3.74秒筛选到10^9。我认为单体筛法同样可以进行多线程处理——对于序列中的每个质数,在核心区域标记它,并生成一个独立的工作线程来完成任务。这样你就有了一个工作池,可以与核心工作线程并行无序地工作。只需要为它们提供一个工作窃取调度器即可。 - Will Ness
@WillNess 这很有道理。我的最快的非分段(简单)变体在 10^9 上需要 3.98 秒。但它也使用了 wheel。一般来说,C++ 在涉及计算时更快。至于多线程简单筛法,我很想看到它的实践,听起来很有趣。在简单筛法中,选择下一个“质数”的选择取决于前一个质数的所有倍数是否已被标记。因此,初始化新线程可能会选择一个实际上不是质数的“质数” - 它只是尚未被标记。 - MC From Scratch
或者也许不需要,只要下一个检测到的素数标记向前迈出几步,直到检测到它后面的第一个空洞 - 间隙 - 就是下一个素数。一旦找到了下一个素数P,我们可以立即从P^2开始标记,但我们必须等待上一个素数标记达到sqrt(N)。所以在sqrt(N)以下这是一个微妙的舞蹈,但在sqrt(N)以上则很简单。 - Will Ness
是的,那样更有意义,听起来可行。@WillNess - MC From Scratch
显示剩余2条评论
3个回答

1

像这样的示例应该可以帮助您入门。

解决方案概述:

  • 定义一个数据结构(“任务”),它包含一个特定的段;您可以将所有不可变的共享数据放入其中,以获得额外的整洁度。如果足够小心,您可以向所有任务传递一个公共的可变数组,以及段限制,并仅更新这些限制内的数组部分。这更容易出错,但可以简化结果合并步骤(据我所知;可能会有所不同)。
  • 定义一个数据结构(“结果”),用于存储任务计算的结果。即使您只是更新了共享的结果结构,您可能仍需要指示已更新该结构的哪个部分。
  • 创建一个可运行对象,接受一个任务,运行计算,并将结果放入给定的结果队列中。
  • 为任务创建一个阻塞输入队列和一个结果队列。
  • 使用接近机器核心数量的线程数创建ThreadPoolExecutor。
  • 将所有任务提交到线程池执行程序。它们将被安排在池中的线程上运行,并将其结果放入输出队列中,不一定按顺序。
  • 等待线程池中的所有任务完成。
  • 清空输出队列并将部分结果合并为最终结果。

通过在单独的任务中读取输出队列并将结果合并,或者根据连接步骤涉及的工作量更新可变共享输出结构,在synchronized下可能会实现额外的加速(也可能不会)。

希望这可以帮助到您。


当然有帮助,谢谢。我会仔细阅读并尝试使用它。 - MC From Scratch

1

是的,我确实熟悉他的工作。我的程序的一些部分,比如分割和轮因子分解,也存在于他的工作中。然而,我必须诚实地说,他算法中的一些关键概念超出了我的理解范围。其中大多数不是数学上的问题,而是计算上的问题。例如,“桶筛法”——他所说的在列表中存储素数倍数的方法。我不确定他具体指的是什么,也不知道如何实现它。我确实打算去了解,但现在还没有到那个时候。 - MC From Scratch
@MCFromScratch 分段筛法被描述为具有O(sqrt(N))的核心和分段大小。但另一种看法是,没有N,筛子是无限的,我们筛选连续质数平方之间(距离不小于)的“段”。当核心质数足够大时,任何“固定大小”的段都会开始错过一些来自这些大质数的命中,复杂度将恶化(我们对每个质数进行计算其在段内的起始倍数'm',但它可能根本不存在)。桶式筛法将这样的'm'保存到后面使用。 - Will Ness
@WillNess 这实际上是一个非常好而简单的解释。这对于哪些输入有用?它是如何实现的?我猜它需要预先计算那些倍数的表格,这样我们就不会在算法内部生成一个不会在特定段中使用的倍数。 - MC From Scratch
“节省时间”——只有在某些质数变得比段大小更大的情况下才会发生。当大小足够且没有缺失时,不会浪费任何努力,也不需要桶。我们可以增加段大小,但是到某个时候它将无法再适合缓存中了。 - Will Ness
@WillNess 真的吗?这很奇怪。您还有哪些其他优化方法?您的算法中还包括什么?例如,您知道如何切换到位筛分割而不是直接布尔分割筛法吗? - MC From Scratch
显示剩余5条评论

0

你对速度有多感兴趣?你会考虑使用C++吗?

$ time ../c_code/segmented_bit_sieve 1000000000
50847534 primes found.

real    0m0.875s
user    0m0.813s
sys     0m0.016s
$ time ../c_code/segmented_bit_isprime 1000000000
50847534 primes found.

real    0m0.816s
user    0m0.797s
sys     0m0.000s

(在我新买的i5笔记本电脑上)

第一种方法是由@Kim Walisch使用的奇数质数候选位数组。

https://github.com/kimwalisch/primesieve/wiki/Segmented-sieve-of-Eratosthenes

第二个是我的修改版Kim的代码,其中IsPrime[]也被实现为位数组,读起来稍微不太清晰,尽管对于大N而言,由于减少了内存占用,速度略微更快。
我会仔细阅读您的帖子,因为我对质数和性能感兴趣,无论使用什么语言。我希望这不会太偏离主题或过早。但我注意到我已经超出了您的性能目标。

我的当前版本比这里发布的那个更快。它在10^9的情况下运行大约需要1.05秒。我故意不使用C++,因为我想用Java编程——这是我用于大多数目的的主要语言,这就是为什么我需要这种特定语言的代码。我知道C++在这方面优于Java,但还是有建议可以让代码运行更快,非常感谢。 - MC From Scratch

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接