计算第N个既是三角数又是平方数的数。

3

这道题目是练习比赛中的一道题:

计算第N个既是三角形数又是平方数的数,对10006699取模。(1 ≤ N ≤ 10^18) 最多有10^5组测试数据。

我发现可以使用递推关系式Ti = 6Ti-1 - Ti-2 + 2 来轻松计算,其中T0 = 0T1 = 1

我正在使用矩阵指数运算来实现每个测试用例大约O(log N)的性能,但很明显太慢了,因为有10^5个测试用例。事实上,即使在限制条件仅为(1 ≤ N ≤ 10^6)的情况下,这段代码也太慢了,我可以只做O(N)的预处理和O(1)的查询。

我应该改变解决问题的方法还是只优化代码的一些部分呢?

#include <ios>
#include <iostream>
#include <vector>
#define MOD 10006699

/*
Transformation Matrix:

 0 1 0   t[i]     t[i+1]
-1 6 1 * t[i+1] = t[i+2]
 0 0 1     2        2
*/

std::vector<std::vector<long long int> > multi(std::vector<std::vector<long long int> > a, std::vector<std::vector<long long int> > b)
{
    std::vector<std::vector<long long int> > c(3, std::vector<long long int>(3));
    for (int i = 0; i < 3; i++)
    {
        for (int j = 0; j < 3; j++)
        {
            for (int k = 0; k < 3; k++)
            {
                c[i][j] += (a[i][k] * b[k][j]) % MOD;
                c[i][j] %= MOD;
            }
        }
    }
    return c;
}

std::vector<std::vector<long long int> > power(std::vector<std::vector<long long int> > vec, long long int p)
{
    if (p == 1) return vec;
    else if (p % 2 == 1) return multi(vec, power(vec, p-1));
    else
    {
        std::vector<std::vector<long long int> > x = power(vec, p/2);
        return multi(x, x);
    }
}

int main()
{
    std::ios_base::sync_with_stdio(false);
    long long int n;
    while (std::cin >> n)
    {
        if (n == 0) break;
        else
        {
            std::vector<std::vector<long long int> > trans;
            long long int ans;
            trans.resize(3);

            trans[0].push_back(0);  
            trans[0].push_back(1);
            trans[0].push_back(0);
            trans[1].push_back(-1);
            trans[1].push_back(6);
            trans[1].push_back(1);
            trans[2].push_back(0);
            trans[2].push_back(0);
            trans[2].push_back(1);

            trans = power(trans, n);

            ans = (trans[0][1]%MOD + (2*trans[0][2])%MOD)%MOD;

            if (ans < 0) ans += MOD;

            std::cout << ans << std::endl;
        }
    }
}

1
你能解释一下如何得出这个公式吗? - Pham Trung
1个回答

1

注意: 我删除了旧回答,这个更有用

看来你很难创造出比O(log N)更好渐进算法来解决这个问题。但是,你可以对当前的代码进行一些修改,在不改变渐进时间复杂度的情况下提高性能。

以下是一个修改版的代码,它产生相同的答案:

#include <ctime>
#include <ios>
#include <iostream>
#include <vector>
#define MOD 10006699

void power(std::vector<std::vector<long long int> >& vec, long long int p)
{
    if (p == 1)
        return;

    else if (p & 1)
    {
        std::vector<std::vector<long long int> > copy1 = vec;
        power(copy1, p-1);

        std::vector<std::vector<long long int> > copy2(3, std::vector<long long int>(3));
        for (int i = 0; i < 3; i++)
            for (int j = 0; j < 3; j++)
            {
                for (int k = 0; k < 3; k++)
                    copy2[i][j] += (vec[i][k] * copy1[k][j]) % MOD;
                copy2[i][j] %= MOD;
            }
        vec = copy2;

        return;
    }

    else
    {
        power(vec, p/2);

        std::vector<std::vector<long long int> > copy(3, std::vector<long long int>(3));
        for (int i = 0; i < 3; i++)
            for (int j = 0; j < 3; j++)
            {
                for (int k = 0; k < 3; k++)
                    copy[i][j] += (vec[i][k] * vec[k][j]) % MOD;
                copy[i][j] %= MOD;
            }
        vec = copy;

        return;
    }
}

int main()
{
    std::ios_base::sync_with_stdio(false);
    long long int n;
    while (std::cin >> n)
    {
        std::clock_t start = std::clock();
        if (n == 0) break;

        std::vector<std::vector<long long int> > trans;
        long long int ans;
        trans.resize(3);

        trans[0].push_back(0);  
        trans[0].push_back(1);
        trans[0].push_back(0);
        trans[1].push_back(-1);
        trans[1].push_back(6);
        trans[1].push_back(1);
        trans[2].push_back(0);
        trans[2].push_back(0);
        trans[2].push_back(1);

        power(trans, n);

        ans = (trans[0][1]%MOD + (2*trans[0][2])%MOD)%MOD;
        if (ans < 0) ans += MOD;
        std::cout << "Answer: " << ans << std::endl;

        std::cout << "Time: " << (std::clock() - start) / (double)(CLOCKS_PER_SEC / 1000) << " ms" << std::endl;
    }
}

主要的区别如下:
  • c[i][j] %= MOD; 的代码移动到 k 循环之外
  • 通过引用传递向量
  • 减少函数调用次数

如果我在您的 main 的 while 循环中放置相同的计时代码,就像我在我的代码中一样,将您的文件命名为 "before.cpp",将我的文件命名为 "after.cpp",并使用完全优化运行每个文件 10 次,则这是我的结果:

Alexanders-MBP:Desktop alexandersimes$ g++ before.cpp -O3 -o before
Alexanders-MBP:Desktop alexandersimes$ ./before 
1000000000000000000
Answer: 6635296
Time: 0.708 ms
1000000000000000000
Answer: 6635296
Time: 0.542 ms
1000000000000000000
Answer: 6635296
Time: 0.688 ms
1000000000000000000
Answer: 6635296
Time: 0.634 ms
1000000000000000000
Answer: 6635296
Time: 0.626 ms
1000000000000000000
Answer: 6635296
Time: 0.629 ms
1000000000000000000
Answer: 6635296
Time: 0.629 ms
1000000000000000000
Answer: 6635296
Time: 0.629 ms
1000000000000000000
Answer: 6635296
Time: 0.632 ms
1000000000000000000
Answer: 6635296
Time: 0.695 ms

Alexanders-MBP:Desktop alexandersimes$ g++ after.cpp -O3 -o after
Alexanders-MBP:Desktop alexandersimes$ ./after 
1000000000000000000
Answer: 6635296
Time: 0.283 ms
1000000000000000000
Answer: 6635296
Time: 0.287 ms
1000000000000000000
Answer: 6635296
Time: 0.27 ms
1000000000000000000
Answer: 6635296
Time: 0.27 ms
1000000000000000000
Answer: 6635296
Time: 0.266 ms
1000000000000000000
Answer: 6635296
Time: 0.265 ms
1000000000000000000
Answer: 6635296
Time: 0.266 ms
1000000000000000000
Answer: 6635296
Time: 0.267 ms
1000000000000000000
Answer: 6635296
Time: 0.21 ms
1000000000000000000
Answer: 6635296
Time: 0.208 ms

除非他找到一种O(1)的方法,否则我怀疑他的渐近性能是否比O(log N)更好。然而,我将更改while循环计时以包括O(1)操作,因为它们对于小N是有影响的。(编辑:这句话是针对已删除评论的人说的) - asimes
我无法编辑我的评论,因此将其删除。顺便说一句,这个建议不会在时间复杂度上有任何改进,并且可能会导致溢出问题,因为它在没有应用模数的情况下不断添加到“c”中。 - Pham Trung
原帖中提到 MOD10006699,这个值远未超出 long long int 的范围。在取模之前,(a[i][k] * b[k][j]) 可能会溢出,但是进行三次加法运算不会导致 long long int 溢出。 - asimes
将幂的方法从递归改为迭代,并避免创建新的“向量”,可能会稍微提高性能。虽然有27个加法,但是没错,这不会导致溢出。 - Pham Trung
@PhamTrung,你是对的,避免创建新向量确实有很大的区别,我完全改变了我的代码,并将其放在上面的答案中。此外,虽然有27个加法,但模数每3次迭代发生一次,这就是为什么我在溢出的情况下将其表述为3个加法的原因。 - asimes

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接