大整数(BigInteger)的实现和性能

4

我在C++中编写了一个BigInteger类,应该能够对任何大小的数字执行操作。目前,我正在尝试通过比较现有算法并测试它们适用于哪些位数来实现非常快速的乘法方法,但我遇到了非常意外的结果。我尝试进行了20次500位数的乘法并计时。这是结果:

karatsuba:
  14.178 seconds

long multiplication:
  0.879 seconds

维基百科告诉我:

因此,对于足够大的n,Karatsuba算法将执行比长手乘法更少的移位和单个数字相加,即使其基本步骤使用比直接公式更多的加法和移位。然而,对于小的n值,额外的移位和加法操作可能使其比长手方法运行得更慢。正回报点取决于计算机平台和上下文。作为经验法则,当乘数长度大于320-640位时,Karatsuba通常更快。

由于我的数字至少有1500位长,这是很意外的,因为维基百科说karatsuba应该运行得更快。我认为我的问题可能在于我的加法算法,但我不知道如何使它更快,因为它已经以O(n)的速度运行。我将在下面发布我的代码,以便您可以检查我的实现。我会略去不相关的部分。
我还在想,也许我使用的结构不是最好的。我用小端表示每个数据段。例如,如果我有数字123456789101112存储在长度为3的数据段中,它将如下所示:

{112,101,789,456,123}

所以我现在问什么是实现BigInteger类的最佳结构和最佳方法?为什么karatsuba算法比长乘法慢?

这是我的代码:(对于长度我很抱歉)

using namespace std;

bool __longmult=true;
bool __karatsuba=false;

struct BigInt {
public:
    vector <int> digits;

    BigInt(const char * number) {
        //constructor is not relevant   
    }
    BigInt() {}

    void BigInt::operator = (BigInt a) {
        digits=a.digits;
    }

    friend BigInt operator + (BigInt,BigInt);
    friend BigInt operator * (BigInt,BigInt);

    friend ostream& operator << (ostream&,BigInt);
};

BigInt operator + (BigInt a,BigInt b) {
    if (b.digits.size()>a.digits.size()) {
        a.digits.swap(b.digits); //make sure a has more or equal amount of digits than b
    }
    int carry=0;

    for (unsigned int i=0;i<a.digits.size();i++) {
        int sum;
        if (i<b.digits.size()) {
            sum=b.digits[i]+a.digits[i]+carry;
        } else if (carry==1) {
            sum=a.digits[i]+carry;
        } else {
            break; // if carry is 0 and no more digits in b are left then we are done already
        }

        if (sum>=1000000000) {
            a.digits[i]=sum-1000000000;
            carry=1;
        } else {
            a.digits[i]=sum;
            carry=0;
        }
    }

    if (carry) {
        a.digits.push_back(1);
    }

    return a;
}

BigInt operator * (BigInt a,BigInt b) {
    if (__longmult) {
        BigInt res;
        for (unsigned int i=0;i<b.digits.size();i++) {
            BigInt temp;
            temp.digits.insert(temp.digits.end(),i,0); //shift to left for i 'digits'

            int carry=0;
            for (unsigned int j=0;j<a.digits.size();j++) {
                long long prod=b.digits[i];
                prod*=a.digits[j];
                prod+=carry;
                int t=prod%1000000000;
                temp.digits.push_back(t);
                carry=(prod-t)/1000000000;
            }
            if (carry>0) {
                temp.digits.push_back(carry);
            }
            res+=temp;
        }
        return res;
    } else if (__karatsuba) {
        BigInt res;
        BigInt a1,a0,b1,b0;
        assert(a.digits.size()>0 && b.digits.size()>0);
        while (a.digits.size()!=b.digits.size()) { //add zeroes for equal size
            if (a.digits.size()>b.digits.size()) {
                b.digits.push_back(0);
            } else {
                a.digits.push_back(0);
            }
        }

        if (a.digits.size()==1) {
            long long prod=a.digits[0];
            prod*=b.digits[0];

            res=prod;//conversion from long long to BigInt runs in constant time
            return res;

        } else {
            for (unsigned int i=0;i<a.digits.size();i++) {
                if (i<(a.digits.size()+(a.digits.size()&1))/2) { //split the number in 2 equal parts
                    a0.digits.push_back(a.digits[i]);
                    b0.digits.push_back(b.digits[i]);
                } else {
                    a1.digits.push_back(a.digits[i]);
                    b1.digits.push_back(b.digits[i]);
                }
            }
        }

        BigInt z2=a1*b1;
        BigInt z0=a0*b0;
        BigInt z1 = (a1 + a0)*(b1 + b0) - z2 - z0;

        if (z2==0 && z1==0) {
            res=z0;
        } else if (z2==0) {
            z1.digits.insert(z1.digits.begin(),a0.digits.size(),0);
            res=z1+z0;
        } else {
            z1.digits.insert(z1.digits.begin(),a0.digits.size(),0);
            z2.digits.insert(z2.digits.begin(),2*a0.digits.size(),0);
            res=z2+z1+z0;
        }

        return res;
    }
}

int main() {
    clock_t start, end;

    BigInt a("984561231354629875468546546534125215534125215634987498548489456125215421563498749854848945612385663498749854848945612521542156349874985484894561238561698774565123165221393856169877456512316552156349874985484894561238561698774565123165221392213935215634987498548489456123856169877456512316522139521563498749854848945612385616987745651231652213949651465123151354686324848945612385616987745651231652213949651465123151354684132319321005482265341252156349874985484894561252154215634987498548489456123856264596162131");
    BigInt b("453412521563498749853412521563498749854848945612521542156349874985484894561238565484894561252154215634987498548489456123856848945612385616935462987546854521563498749854848945653412521563498749854848945612521542156349874985484894561238561238754579785616987745651231652213965465341235215634987495215634987498548489456123856169877456512316522139854848774565123165223546298754685465465341235215634987498548354629875468546546534123521563498749854844139496514651231513546298754685465465341235215634987498548435468");

    __longmult=false;
    __karatsuba=true;

    start=clock();
    for (int i=0;i<20;i++) {
        a*b;
    }
    end=clock();
    printf("\nTook %f seconds\n", (double)(end-start)/CLOCKS_PER_SEC);

    __longmult=true;
    __karatsuba=false;

    start=clock();
    for (int i=0;i<20;i++) {
        a*b;
    }
    end=clock();
    printf("\nTook %f seconds\n", (double)(end-start)/CLOCKS_PER_SEC);

    return 0;
}

看起来你把小于10^9的数字当作了一个位数。这么做会让Karatsuba算法与朴素乘法越来越相似。建议你把位数定为小于10的数字。另外,你确定进位的最大值只有1吗? - Abhishek Bansal
这正是我之前所做的。结果:16秒内完成了500位数的乘法... 我认为使用<10^9作为数字很方便,因为这些值在硬件中只需几个时钟就可以相乘,而总是有比软件更好的硬件实现。 是的,最坏情况是999 999 999 + 999 999 999,这将给出1 999 999 998,这将给出进位值1。 - gelatine1
你能展示你的整个代码,以便其他人可以尝试吗? - Abhishek Bansal
此外,根据您的系统,您可能会遇到int溢出的问题。 - Abhishek Bansal
你应该使用性能分析器来找出时间花费在哪里。 - Oliver Charlesworth
完整代码在此处可用:(654行) http://codepad.org/dzvPbqIV - gelatine1
1个回答

3
  1. You use std::vector

    for your digits make sure there are no unnecessary reallocations in it. So allocate space before operation to avoid it. Also I do not use it so I do not know the array range checking slowdowns.

    Check if you do not shift it !!! which is O(N) ... i.e. insert to first position...

  2. Optimize your implementation

    here you can find mine implementation optimized an unoptimized for comparison

    x=0.98765588997654321000000009876... | 98*32 bits...
    mul1[ 363.472 ms ]... O(N^2) classic multiplication
    mul2[ 349.384 ms ]... O(3*(N^log2(3))) optimized karatsuba multiplication
    mul3[ 9345.127 ms]... O(3*(N^log2(3))) unoptimized karatsuba multiplication 
    

    mine implementation threshold for Karatsuba is around 3100bits ... ~ 944 digits!!! The more optimized the code the lover threshold is.


    Try to remove unnecessary data from function operands

    //BigInt operator + (BigInt a,BigInt b)
    BigInt operator + (const BigInt &a,const BigInt &b)
    

    this is way you will not create another copy of a,b on heap in every + call also even faster is this:

    mul(BigInt &ab,const BigInt &a,const BigInt &b) // ab = a*b
    
  3. Schönhage-Strassen multiplication

    this one is FFT or NTT based. Mine threshold for it is big ... ~ 49700bits ... ~ 15000digits so if you do not plan to use such big numbers then forget about it. Implementation is also in the link above.


    here is mine NTT implementation (optimized as much as I could)

  4. Summary

    Does not matter if you use little or big endian but you should code your operations in a way that they do not use insert operations.


    You use decadic base for digits that is slow because you need to use division and modulo operations. If you choose base as power of 2 then just bit operations are enough and also it removes many if statements from code which are slowing all the most. If you need the base as power of 10 then use biggest you can in some cases this reduce the div,mod to few subtractions

    2^32 = 4 294 967 296 ... int = +/- 2147483648
    base = 1 000 000 000
    
    //x%=base
    while (x>=base) x-=base;
    

    max number of cycles is 2^32/base or 2^31/base on some platform is this faster then modulo and also the bigger the base the less operations you need but beware the overflows !!!


所以,您建议使用 vector <bool> digits 代替 vector <int> digits; 并将其用作二进制表示? - gelatine1
@gelatine1 不,你仍然可以使用vector<int>,但基数不是1000000000而是0x100000000,唯一需要添加的是十进制和十六进制字符串之间的转换,就像这里所示:https://dev59.com/BXE85IYBdhLWcg3wr1yx#18231860。我更喜欢十六进制表示,因为它在打印和赋值操作时要快得多(没有除法,只有位运算)。当然,对于大数字,这种转换比对它们本身的操作要慢,所以请注意你正在测量的时间。 - Spektre

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接