在C++中实现高精度数的模幂运算

17

最近我一直在实现Miller-Rabin素性测试。我将它限制在32位所有数字的范围内,因为这是一个只是为了熟悉C++的有趣项目,而且我不想有一段时间内使用64位的任何东西。一个额外的好处是,该算法对于所有32位数都是确定性的,所以我可以明确测试哪些证人以显著提高效率。

对于较小的数字,该算法工作得非常出色。然而,该过程的一部分依赖于模指数运算,即(num ^ pow) % mod。例如:

3 ^ 2 % 5 = 
9 % 5 = 
4

这是我一直在使用的模幂运算代码:

unsigned mod_pow(unsigned num, unsigned pow, unsigned mod)
{
    unsigned test;
    for(test = 1; pow; pow >>= 1)
    {
        if (pow & 1)
            test = (test * num) % mod;
        num = (num * num) % mod;
    }

    return test;

}

正如你可能已经猜到的那样,当参数都是异常大的数字时就会出现问题。例如,如果我想要测试数字673109是否为素数,我将在某个时候需要找到:

(2 ^ 168277) % 673109

现在2 ^ 168277是一个异常大的数字,在过程中它会发生溢出,导致评估不正确。

另一方面,像

4000111222 ^ 3 % 1608

这样的参数也由于同样的原因而计算不正确。

有人有关于模幂运算的建议,以防止这种溢出并/或操纵它产生正确的结果吗?(在我看来,溢出只是模运算的另一种形式,即num % (UINT_MAX+1))

6个回答

8
平方取幂算法在模指数运算中仍然有效。你的问题不是2 ^ 168277是一个异常大的数字,而是你的其中一个中间结果是一个相当大的数字(比2 ^ 32还要大),因为673109比2 ^ 16还要大。
所以我认为以下方法可以解决问题。我可能会漏掉一些细节,但基本思想是正确的,这就是“真实”加密代码如何进行大模指数运算的方式(虽然不使用32位和64位数字,而是使用永远不必扩展到大于2 * log(模数)的大数):
  • 首先使用平方取幂算法。
  • 在64位无符号整数中执行实际平方操作。
  • 在每个步骤中对673109进行模运算,以回到32位范围内,就像你所做的那样。
显然,如果你的C ++实现没有64位整数,那么有点棘手,尽管你总是可以伪造一个。
在这里有第22页的示例:http://www.cs.princeton.edu/courses/archive/spr05/cos126/lectures/22.pdf,尽管它使用非常小的数字(小于2^16),因此可能不会展示任何你不知道的内容。
您的另一个示例4000111222 ^ 3 % 1608将在您当前的代码中起作用,如果您在开始之前只是将4000111222减少到1608模数。1608足够小,您可以安全地将任意两个模1608的数字相乘在32位int中。

谢谢,伙计,这就解决了问题。只是出于好奇,你知道有没有不需要使用更大内存的方法吗?我相信它们会很有用。 - Axel Magnuson
据我所知没有这样的方法。您需要将两个数字相乘,直到达到673108,然后对673109取模。显然,您可以将其分解并使用较小的“数字”进行长乘法,例如2^10。但是,一旦您在软件中实现了乘法和除法,最好就为将两个32位值相乘以得到64位结果的特殊情况实现它,然后进行除法以提取32位余数。可能有一些硬核优化只做您需要的最少工作,但我不知道它们,而在C++中伪造一个64位int也不是那么难。 - Steve Jessop

7

最近我用C++写了一些关于RSA的东西,但有点凌乱。

#include "BigInteger.h"
#include <iostream>
#include <sstream>
#include <stack>

BigInteger::BigInteger() {
    digits.push_back(0);
    negative = false;
}

BigInteger::~BigInteger() {
}

void BigInteger::addWithoutSign(BigInteger& c, const BigInteger& a, const BigInteger& b) {
    int sum_n_carry = 0;
    int n = (int)a.digits.size();
    if (n < (int)b.digits.size()) {
        n = b.digits.size();
    }
    c.digits.resize(n);
    for (int i = 0; i < n; ++i) {
        unsigned short a_digit = 0;
        unsigned short b_digit = 0;
        if (i < (int)a.digits.size()) {
            a_digit = a.digits[i];
        }
        if (i < (int)b.digits.size()) {
            b_digit = b.digits[i];
        }
        sum_n_carry += a_digit + b_digit;
        c.digits[i] = (sum_n_carry & 0xFFFF);
        sum_n_carry >>= 16;
    }
    if (sum_n_carry != 0) {
        putCarryInfront(c, sum_n_carry);
    }
    while (c.digits.size() > 1 && c.digits.back() == 0) {
        c.digits.pop_back();
    }
    //std::cout << a.toString() << " + " << b.toString() << " == " << c.toString() << std::endl;
}

void BigInteger::subWithoutSign(BigInteger& c, const BigInteger& a, const BigInteger& b) {
    int sub_n_borrow = 0;
    int n = a.digits.size();
    if (n < (int)b.digits.size())
        n = (int)b.digits.size();
    c.digits.resize(n);
    for (int i = 0; i < n; ++i) {
        unsigned short a_digit = 0;
        unsigned short b_digit = 0;
        if (i < (int)a.digits.size())
            a_digit = a.digits[i];
        if (i < (int)b.digits.size())
            b_digit = b.digits[i];
        sub_n_borrow += a_digit - b_digit;
        if (sub_n_borrow >= 0) {
            c.digits[i] = sub_n_borrow;
            sub_n_borrow = 0;
        } else {
            c.digits[i] = 0x10000 + sub_n_borrow;
            sub_n_borrow = -1;
        }
    }
    while (c.digits.size() > 1 && c.digits.back() == 0) {
        c.digits.pop_back();
    }
    //std::cout << a.toString() << " - " << b.toString() << " == " << c.toString() << std::endl;
}

int BigInteger::cmpWithoutSign(const BigInteger& a, const BigInteger& b) {
    int n = (int)a.digits.size();
    if (n < (int)b.digits.size())
        n = (int)b.digits.size();
    //std::cout << "cmp(" << a.toString() << ", " << b.toString() << ") == ";
    for (int i = n-1; i >= 0; --i) {
        unsigned short a_digit = 0;
        unsigned short b_digit = 0;
        if (i < (int)a.digits.size())
            a_digit = a.digits[i];
        if (i < (int)b.digits.size())
            b_digit = b.digits[i];
        if (a_digit < b_digit) {
            //std::cout << "-1" << std::endl;
            return -1;
        } else if (a_digit > b_digit) {
            //std::cout << "+1" << std::endl;
            return +1;
        }
    }
    //std::cout << "0" << std::endl;
    return 0;
}

void BigInteger::multByDigitWithoutSign(BigInteger& c, const BigInteger& a, unsigned short b) {
    unsigned int mult_n_carry = 0;
    c.digits.clear();
    c.digits.resize(a.digits.size());
    for (int i = 0; i < (int)a.digits.size(); ++i) {
        unsigned short a_digit = 0;
        unsigned short b_digit = b;
        if (i < (int)a.digits.size())
            a_digit = a.digits[i];
        mult_n_carry += a_digit * b_digit;
        c.digits[i] = (mult_n_carry & 0xFFFF);
        mult_n_carry >>= 16;
    }
    if (mult_n_carry != 0) {
        putCarryInfront(c, mult_n_carry);
    }
    //std::cout << a.toString() << " x " << b << " == " << c.toString() << std::endl;
}

void BigInteger::shiftLeftByBase(BigInteger& b, const BigInteger& a, int times) {
    b.digits.resize(a.digits.size() + times);
    for (int i = 0; i < times; ++i) {
        b.digits[i] = 0;
    }
    for (int i = 0; i < (int)a.digits.size(); ++i) {
        b.digits[i + times] = a.digits[i];
    }
}

void BigInteger::shiftRight(BigInteger& a) {
    //std::cout << "shr " << a.toString() << " == ";
    for (int i = 0; i < (int)a.digits.size(); ++i) {
        a.digits[i] >>= 1;
        if (i+1 < (int)a.digits.size()) {
            if ((a.digits[i+1] & 0x1) != 0) {
                a.digits[i] |= 0x8000;
            }
        }
    }
    //std::cout << a.toString() << std::endl;
}

void BigInteger::shiftLeft(BigInteger& a) {
    bool lastBit = false;
    for (int i = 0; i < (int)a.digits.size(); ++i) {
        bool bit = (a.digits[i] & 0x8000) != 0;
        a.digits[i] <<= 1;
        if (lastBit)
            a.digits[i] |= 1;
        lastBit = bit;
    }
    if (lastBit) {
        a.digits.push_back(1);
    }
}

void BigInteger::putCarryInfront(BigInteger& a, unsigned short carry) {
    BigInteger b;
    b.negative = a.negative;
    b.digits.resize(a.digits.size() + 1);
    b.digits[a.digits.size()] = carry;
    for (int i = 0; i < (int)a.digits.size(); ++i) {
        b.digits[i] = a.digits[i];
    }
    a.digits.swap(b.digits);
}

void BigInteger::divideWithoutSign(BigInteger& c, BigInteger& d, const BigInteger& a, const BigInteger& b) {
    c.digits.clear();
    c.digits.push_back(0);
    BigInteger two("2");
    BigInteger e = b;
    BigInteger f("1");
    BigInteger g = a;
    BigInteger one("1");
    while (cmpWithoutSign(g, e) >= 0) {
        shiftLeft(e);
        shiftLeft(f);
    }
    shiftRight(e);
    shiftRight(f);
    while (cmpWithoutSign(g, b) >= 0) {
        g -= e;
        c += f;
        while (cmpWithoutSign(g, e) < 0) {
            shiftRight(e);
            shiftRight(f);
        }
    }
    e = c;
    e *= b;
    f = a;
    f -= e;
    d = f;
}

BigInteger::BigInteger(const BigInteger& other) {
    digits = other.digits;
    negative = other.negative;
}

BigInteger::BigInteger(const char* other) {
    digits.push_back(0);
    negative = false;
    BigInteger ten;
    ten.digits[0] = 10;
    const char* c = other;
    bool make_negative = false;
    if (*c == '-') {
        make_negative = true;
        ++c;
    }
    while (*c != 0) {
        BigInteger digit;
        digit.digits[0] = *c - '0';
        *this *= ten;
        *this += digit;
        ++c;
    }
    negative = make_negative;
}

bool BigInteger::isOdd() const {
    return (digits[0] & 0x1) != 0;
}

BigInteger& BigInteger::operator=(const BigInteger& other) {
    if (this == &other) // handle self assignment
        return *this;
    digits = other.digits;
    negative = other.negative;
    return *this;
}

BigInteger& BigInteger::operator+=(const BigInteger& other) {
    BigInteger result;
    if (negative) {
        if (other.negative) {
            result.negative = true;
            addWithoutSign(result, *this, other);
        } else {
            int a = cmpWithoutSign(*this, other);
            if (a < 0) {
                result.negative = false;
                subWithoutSign(result, other, *this);
            } else if (a > 0) {
                result.negative = true;
                subWithoutSign(result, *this, other);
            } else {
                result.negative = false;
                result.digits.clear();
                result.digits.push_back(0);
            }
        }
    } else {
        if (other.negative) {
            int a = cmpWithoutSign(*this, other);
            if (a < 0) {
                result.negative = true;
                subWithoutSign(result, other, *this);
            } else if (a > 0) {
                result.negative = false;
                subWithoutSign(result, *this, other);
            } else {
                result.negative = false;
                result.digits.clear();
                result.digits.push_back(0);
            }
        } else {
            result.negative = false;
            addWithoutSign(result, *this, other);
        }
    }
    negative = result.negative;
    digits.swap(result.digits);
    return *this;
}

BigInteger& BigInteger::operator-=(const BigInteger& other) {
    BigInteger neg_other = other;
    neg_other.negative = !neg_other.negative;
    return *this += neg_other;
}

BigInteger& BigInteger::operator*=(const BigInteger& other) {
    BigInteger result;
    for (int i = 0; i < (int)digits.size(); ++i) {
        BigInteger mult;
        multByDigitWithoutSign(mult, other, digits[i]);
        BigInteger shift;
        shiftLeftByBase(shift, mult, i);
        BigInteger add;
        addWithoutSign(add, result, shift);
        result = add;
    }
    if (negative != other.negative) {
        result.negative = true;
    } else {
        result.negative = false;
    }
    //std::cout << toString() << " x " << other.toString() << " == " << result.toString() << std::endl;
    negative = result.negative;
    digits.swap(result.digits);
    return *this;
}

BigInteger& BigInteger::operator/=(const BigInteger& other) {
    BigInteger result, tmp;
    divideWithoutSign(result, tmp, *this, other);
    result.negative = (negative != other.negative);
    negative = result.negative;
    digits.swap(result.digits);
    return *this;
}

BigInteger& BigInteger::operator%=(const BigInteger& other) {
    BigInteger c, d;
    divideWithoutSign(c, d, *this, other);
    *this = d;
    return *this;
}

bool BigInteger::operator>(const BigInteger& other) const {
    if (negative) {
        if (other.negative) {
            return cmpWithoutSign(*this, other) < 0;
        } else {
            return false;
        }
    } else {
        if (other.negative) {
            return true;
        } else {
            return cmpWithoutSign(*this, other) > 0;
        }
    }
}

BigInteger& BigInteger::powAssignUnderMod(const BigInteger& exponent, const BigInteger& modulus) {
    BigInteger zero("0");
    BigInteger one("1");
    BigInteger e = exponent;
    BigInteger base = *this;
    *this = one;
    while (cmpWithoutSign(e, zero) != 0) {
        //std::cout << e.toString() << " : " << toString() << " : " << base.toString() << std::endl;
        if (e.isOdd()) {
            *this *= base;
            *this %= modulus;
        }
        shiftRight(e);
        base *= BigInteger(base);
        base %= modulus;
    }
    return *this;
}

std::string BigInteger::toString() const {
    std::ostringstream os;
    if (negative)
        os << "-";
    BigInteger tmp = *this;
    BigInteger zero("0");
    BigInteger ten("10");
    tmp.negative = false;
    std::stack<char> s;
    while (cmpWithoutSign(tmp, zero) != 0) {
        BigInteger tmp2, tmp3;
        divideWithoutSign(tmp2, tmp3, tmp, ten);
        s.push((char)(tmp3.digits[0] + '0'));
        tmp = tmp2;
    }
    while (!s.empty()) {
        os << s.top();
        s.pop();
    }
    /*
    for (int i = digits.size()-1; i >= 0; --i) {
        os << digits[i];
        if (i != 0) {
            os << ",";
        }
    }
    */
    return os.str();

并提供一个使用示例。

BigInteger a("87682374682734687"), b("435983748957348957349857345"), c("2348927349872344")

// Will Calculate pow(87682374682734687, 435983748957348957349857345) % 2348927349872344
a.powAssignUnderMod(b, c);

它也很快,并且具有无限数量的数字。


谢谢分享!一个问题,digit是std::vector<unsigned short>吗? - darius
是的,但在底层使用的是65536进制而不是10进制。 - clinux

4

两件事情:

  • 你是否使用了适当的数据类型?换句话说,UINT_MAX是否允许你将673109作为参数?

不,它不允许,因为在某个时刻你有一个代码行num = 2^16,这会导致溢出。使用更大的数据类型来保存此中间值。

  • 如何在每个可能的溢出机会上进行取模操作,例如:

    test = ((test % mod) * (num % mod)) % mod;

编辑:

unsigned mod_pow(unsigned num, unsigned pow, unsigned mod)
{
    unsigned long long test;
    unsigned long long n = num;
    for(test = 1; pow; pow >>= 1)
    {
        if (pow & 1)
            test = ((test % mod) * (n % mod)) % mod;
        n = ((n % mod) * (n % mod)) % mod;
    }

    return test; /* note this is potentially lossy */
}

int main(int argc, char* argv[])
{

    /* (2 ^ 168277) % 673109 */
    printf("%u\n", mod_pow(2, 168277, 673109));
    return 0;
}

1
    package playTime;

    public class play {

        public static long count = 0; 
        public static long binSlots = 10; 
        public static long y = 645; 
        public static long finalValue = 1; 
        public static long x = 11; 

        public static void main(String[] args){

            int[] binArray = new int[]{0,0,1,0,0,0,0,1,0,1};  

            x = BME(x, count, binArray); 

            System.out.print("\nfinal value:"+finalValue);

        }

        public static long BME(long x, long count, int[] binArray){

            if(count == binSlots){
                return finalValue; 
            }

            if(binArray[(int) count] == 1){
                finalValue = finalValue*x%y; 
            }

            x = (x*x)%y; 
            System.out.print("Array("+binArray[(int) count]+") "
                            +"x("+x+")" +" finalVal("+              finalValue + ")\n");

            count++; 


            return BME(x, count,binArray); 
        }

    }

这是我很快用Java编写的代码。我使用的示例是11^644mod 645 = 1。我们知道645的二进制是1010000100。我有点作弊,硬编码了变量,但它可以正常工作。 - ShowLove
输出为 Array(0) x(121) finalVal(1) Array(0) x(451) finalVal(1) Array(1) x(226) finalVal(451) Array(0) x(121) finalVal(451) Array(0) x(451) finalVal(451) Array(0) x(226) finalVal(451) Array(0) x(121) finalVal(451) Array(1) x(451) finalVal(391) Array(0) x(226) finalVal(391) Array(1) x(121) finalVal(1)最终值:1 - ShowLove

0

LL 是代表 long long int

LL power_mod(LL a, LL k) {
    if (k == 0)
        return 1;
    LL temp = power(a, k/2);
    LL res;

    res = ( ( temp % P ) * (temp % P) ) % P;
    if (k % 2 == 1)
        res = ((a % P) * (res % P)) % P;
    return res;
}

使用上述递归函数来查找数字的模指数。这不会导致溢出,因为它以自下而上的方式进行计算。

对于以下示例测试运行: a = 2k = 168277,输出结果为518358,这是正确的,并且该函数在O(log(k))时间内运行;


-1
你可以使用以下恒等式:
(a * b) (mod m) === (a (mod m)) * (b (mod m)) (mod m)
尝试直接使用它,并逐步改进。
    if (pow & 1)
        test = ((test % mod) * (num % mod)) % mod;
    num = ((num % mod) * (num % mod)) % mod;

2
谢谢你们两位的建议,但由于算法的特性,测试和数字将始终小于模数,因此:{(test%mod)= test}和{(num%mod)= test},因此该身份证无法帮助我,因为即使num和test小于mod时,函数也会失败。此外,无符号整数允许我将673109作为参数。 UINT_MAX = 4 294 967 295适用于我的计算机。 - Axel Magnuson

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接