最长递增子序列的数量

17

我正在练习算法,其中一个任务是计算给定 0<n≤10^6 个数字的所有最长递增子序列的数量。 解决方案 O(n^2) 不可行。

我已经实现了查找LIS及其长度的算法(LIS算法),但该算法会将数字转换为尽可能小的数字。因此,不能确定具有先前数字(较大数字)的子序列是否能够达到最长长度,否则我可以只计算这些转换次数。

有什么办法在约 O(nlogn) 的时间复杂度内解决吗?我知道应该使用动态规划来解决。

我已经实现了一种解决方案,并且它运行良好,但它需要两个嵌套循环 (i在1..n) x (j在1..i-1)。所以我认为它是O(n^2),然而它太慢了。

我甚至尝试将那些数字从数组移动到二叉树中(因为在每个 i 迭代中,我都会查找i-1..1元素中所有小于 number[i]的数字),但速度更慢。

示例测试:

1 3 2 2 4
result: 3 (1,3,4 | 1,2,4 | 1,2,4)

3 2 1
result: 3 (1 | 2 | 3)

16 5 8 6 1 10 5 2 15 3 2 4 1
result: 3 (5,8,10,15 | 5,6,10,15 | 1,2,3,4)

无法在O(nlogn)的时间复杂度内解决它。动态规划意味着您将存储先前的子序列,并在从主序列中获取下一个数字时,迭代遍历子序列列表,检查是否可以将此下一个数字添加到每个子序列中。 - Alex Salauyou
在构建这样的子序列时,您应该为每个存储最后一个项目和长度。如果您通过 O(1) 访问它们,则此算法将在 O(n*m) 中运行,其中 m 是存储的子序列数。在主序列降序的最坏情况下(如第二个示例中),n = m。 - Alex Salauyou
@Salau,你确定吗?没有办法通过某种方式仅计算这些子序列的数量吗?我不一定需要重建它们,我只需要它们的数量。这个LIS算法的时间复杂度是O(nlogn),但如果你先思考这个问题,你不会认为它是可能的,所以也许我的问题类似? :-) - Wojciech Kulik
@Salauyou:谢谢,它不一定要是O(n*logn)。我写了应该是关于什么的,因为我不确定这种情况下最好的复杂度是什么。等待您的帖子,谢谢。 - Wojciech Kulik
毕竟,我成功地实现了O(nlognlogn)的复杂度和O(n)的额外内存保证。稍后我会解释算法,因为在平板电脑上打字很困难。 - Alex Salauyou
显示剩余3条评论
5个回答

24

查找所有最长递增子序列的数量

下面是改进后的LIS算法的完整Java代码,它不仅可以发现最长递增子序列的长度,还可以发现这种长度的子序列的数量。我喜欢使用泛型,以允许不仅仅是整数,而是任何可比较的类型。

@Test
public void testLisNumberAndLength() {

    List<Integer> input = Arrays.asList(16, 5, 8, 6, 1, 10, 5, 2, 15, 3, 2, 4, 1);
    int[] result = lisNumberAndlength(input);
    System.out.println(String.format(
            "This sequence has %s longest increasing subsequenses of length %s", 
            result[0], result[1]
            ));
}


/**
 * Body of improved LIS algorithm
 */
public <T extends Comparable<T>> int[] lisNumberAndLength(List<T> input) {

    if (input.size() == 0) 
        return new int[] {0, 0};

    List<List<Sub<T>>> subs = new ArrayList<>();
    List<Sub<T>> tails = new ArrayList<>();

    for (T e : input) {
        int pos = search(tails, new Sub<>(e, 0), false);      // row for a new sub to be placed
        int sum = 1;
        if (pos > 0) {
            List<Sub<T>> pRow = subs.get(pos - 1);            // previous row
            int index = search(pRow, new Sub<T>(e, 0), true); // index of most left element that <= e
            if (pRow.get(index).value.compareTo(e) < 0) {
                index--;
            } 
            sum = pRow.get(pRow.size() - 1).sum;              // sum of tail element in previous row
            if (index >= 0) {
                sum -= pRow.get(index).sum;
            }
        }

        if (pos >= subs.size()) {                             // add a new row
            List<Sub<T>> row = new ArrayList<>();
            row.add(new Sub<>(e, sum));
            subs.add(row);
            tails.add(new Sub<>(e, 0));

        } else {                                              // add sub to existing row
            List<Sub<T>> row = subs.get(pos);
            Sub<T> tail = row.get(row.size() - 1); 
            if (tail.value.equals(e)) {
                tail.sum += sum;
            } else {
                row.add(new Sub<>(e, tail.sum + sum));
                tails.set(pos, new Sub<>(e, 0));
            }
        }
    }

    List<Sub<T>> lastRow = subs.get(subs.size() - 1);
    Sub<T> last = lastRow.get(lastRow.size() - 1);
    return new int[]{last.sum, subs.size()};
}



/**
 * Implementation of binary search in a sorted list
 */
public <T> int search(List<? extends Comparable<T>> a, T v, boolean reversed) {

    if (a.size() == 0)
        return 0;

    int sign = reversed ? -1 : 1;
    int right = a.size() - 1;

    Comparable<T> vRight = a.get(right);
    if (vRight.compareTo(v) * sign < 0)
        return right + 1;

    int left = 0;
    int pos = 0;
    Comparable<T> vPos;
    Comparable<T> vLeft = a.get(left);

    for(;;) {
        if (right - left <= 1) {
            if (vRight.compareTo(v) * sign >= 0 && vLeft.compareTo(v) * sign < 0) 
                return right;
            else 
                return left;
        }
        pos = (left + right) >>> 1;
        vPos = a.get(pos);
        if (vPos.equals(v)) {
            return pos;
        } else if (vPos.compareTo(v) * sign > 0) {
            right = pos;
            vRight = vPos;
        } else {
            left = pos;
            vLeft = vPos;
        }
    } 
}



/**
 * Class for 'sub' pairs
 */
public static class Sub<T extends Comparable<T>> implements Comparable<Sub<T>> {

    T value;
    int sum;

    public Sub(T value, int sum) { 
        this.value = value; 
        this.sum = sum; 
    }

    @Override public String toString() {
        return String.format("(%s, %s)", value, sum); 
    }

    @Override public int compareTo(Sub<T> another) { 
        return this.value.compareTo(another.value); 
    }
}

解释

由于我的解释似乎比较长,我将把初始序列称为“seq”,并将任何其子序列称为“sub”。因此,任务是计算可以从seq中获得的最长递增子序列的数量。

正如我之前提到的,想法是保持在先前步骤中获得的所有可能的最长子序列的计数。因此,让我们创建一个带有编号的行列表,其中每行的编号等于存储在该行中的子序列的长度。并且让我们将子序列存储为数字对(v, c),其中“v”是结束元素的值,“c”是以“v”结尾的给定长度的子序列的数量。例如:

1: (16, 1) // that means that so far we have 1 sub of length 1 which ends by 16.

我们将逐步构建这样的列表,按顺序从初始序列中获取元素。在每个步骤中,我们将尝试将此元素添加到可以添加到最长子序列,并记录更改。

构建列表

让我们使用您示例中的序列构建列表,因为它具有所有可能的选项:
 16 5 8 6 1 10 5 2 15 3 2 4 1

首先,取出元素16。由于我们的列表目前为空,因此我们只需将一个键值对放入其中:

1: (16, 1) <= one sub that ends by 16

接下来是5。它不能添加到以16结尾的子集中,因此它将创建一个长度为1的新子集。我们创建一对(5,1)并将其放入第1行:

1: (16, 1)(5, 1)

下一个元素是8。它无法创建长度为2的子数组[16, 8],但可以创建子数组[5, 8]。这就是算法的应用场景。首先,我们从后往前迭代列表行,查看最后一对的“值”。如果我们的元素大于所有行中所有最后一个元素的值,则我们可以将其添加到现有的子数组中,将其长度增加一。因此,值为8将创建列表的新行,因为它大于目前为止列表中所有最后一个元素的值(即> 5):

1: (16, 1)(5, 1) 
2: (8, ?)   <=== need to resolve how many longest subs ending by 8 can be obtained

第8个元素可以继续5,但不能继续16。因此我们需要搜索前一行,从其末尾开始,计算“值”小于8的成对“计数”的总和:

(16, 1)(5, 1)^  // sum = 0
(16, 1)^(5, 1)  // sum = 1
^(16, 1)(5, 1)  // value 16 >= 8: stop. count = sum = 1, so write 1 in pair next to 8

1: (16, 1)(5, 1)
2: (8, 1)  <=== so far we have 1 sub of length 2 which ends by 8.

为什么我们不把值8存储到长度为1的子串(第一行)中?因为我们需要最大可能长度的子串,并且8可以延续一些先前的子串。因此,每个大于8的下一个数字也将继续这样的子串,因此没有必要将8保留为长度小于它可用的子串。
接下来是6。通过行中的最后一个“值”反向搜索:
1: (16, 1)(5, 1)  <=== 5 < 6, go next
2: (8, 1)

1: (16, 1)(5, 1)
2: (8, 1 )  <=== 8 >= 6, so 6 should be put here

找到了可以容纳6人的房间,需要计算人数:

take previous line
(16, 1)(5, 1)^  // sum = 0
(16, 1)^(5, 1)  // 5 < 6: sum = 1
^(16, 1)(5, 1)  // 16 >= 6: stop, write count = sum = 1

1: (16, 1)(5, 1)
2: (8, 1)(6, 1) 

处理完1之后:

1: (16, 1)(5, 1)(1, 1) <===
2: (8, 1)(6, 1)

处理完 10 后:
1: (16, 1)(5, 1)(1, 1)
2: (8, 1)(6, 1)
3: (10, 2) <=== count is 2 because both "values" 8 and 6 from previous row are less than 10, so we summarized their "counts": 1 + 1

处理完5后:

1: (16, 1)(5, 1)(1, 1)
2: (8, 1)(6, 1)(5, 1) <===
3: (10, 2)

处理完2后:

1: (16, 1)(5, 1)(1, 1)
2: (8, 1)(6, 1)(5, 1)(2, 1) <===
3: (10, 2)

在处理完15之后:

1: (16, 1)(5, 1)(1, 1)
2: (8, 1)(6, 1)(5, 1)(2, 1)
3: (10, 2)
4: (15, 2) <===

处理完3后:

1: (16, 1)(5, 1)(1, 1)
2: (8, 1)(6, 1)(5, 1)(2, 1)
3: (10, 2)(3, 1) <===
4: (15, 2)  

处理完2后:

1: (16, 1)(5, 1)(1, 1)
2: (8, 1)(6, 1)(5, 1)(2, 2) <===
3: (10, 2)(3, 1) 
4: (15, 2)  

如果在按最后一个元素搜索行时发现相等的元素,则根据上一行重新计算其“计数”,并添加到现有的“计数”中。

处理完4之后:

1: (16, 1)(5, 1)(1, 1)
2: (8, 1)(6, 1)(5, 1)(2, 2)  
3: (10, 2)(3, 1) 
4: (15, 2)(4, 1) <===

处理1后:

1: (16, 1)(5, 1)(1, 2) <===
2: (8, 1)(6, 1)(5, 1)(2, 2)  
3: (10, 2)(3, 1) 
4: (15, 2)(4, 1)  

在处理所有初始序列后,我们得到了什么?看着最后一行,我们可以看到有3个最长的子序列,每个子序列由4个元素组成:2个以15结尾,1个以4结尾。
那复杂度呢?
在每次迭代中,当从初始序列中取下一个元素时,我们进行2个循环:第一个是在迭代行以找到下一个元素的位置,第二个是在汇总前一行的计数时。因此,对于每个元素,我们最多进行n次迭代(最坏情况:如果初始序列按递增顺序排列,我们将得到n行,每行有1对;如果序列按降序排列,则会获得具有n个元素的1行列表)。顺便说一句,O(n²)的复杂度不是我们想要的。
首先,很明显,在每个中间状态下,行都按其最后一个“值”的递增顺序排序。因此,可以执行二进制搜索而不是暴力循环,其复杂度为O(log n)。
其次,我们不需要每次通过循环遍历行元素来汇总子序列的“计数”。我们可以在添加新对到行时的过程中进行汇总,例如:
1: (16, 1)(5, 2) <=== instead of 1, put 1 + "count" of previous element in the row

因此,第二个数字将显示不是以给定值结尾的最长子序列的计数,而是以“值”对中大于或等于“值”的任何元素结尾的所有最长子序列的摘要计数。
因此,“counts”将被替换为“sums”。我们不再迭代前一行中的元素,而是执行二进制搜索(这是可能的,因为任何行中的对始终按其“值”排序),并将新对的“sum”作为上一行中最后一个元素的“sum”减去在上一行中找到的位置左侧的元素的“sum”,再加上当前行中前一个元素的“sum”。
因此,在处理4时:
1: (16, 1)(5, 2)(1, 3)
2: (8, 1)(6, 2)(5, 3)(2, 5) 
3: (10, 2)(3, 3) 
4: (15, 2) <=== room for (4, ?)

search in row 3 by "values" < 4:
3: (10, 2)^(3, 3) 

4将与(3-2+2)配对: (“上一行的最后一对”的总和) - (“上一行在找到位置左侧的一对”的总和) + (“当前行中前一对”的总和):

4: (15, 2)(4, 3)

在这种情况下,最长子串的最终计数是列表的最后一行的最后一对的“sum”,即3,而不是3 + 2。
因此,在进行二分查找时,对行搜索和总和搜索,我们将得到O(n * log n)复杂度。
至于内存消耗,在处理完整个数组后,我们获得最大n对,因此在使用动态数组的情况下,内存消耗为O(n)。此外,当使用动态数组或集合时,需要一些额外的时间来分配和调整它们的大小,但是大多数操作都在O(1)时间内完成,因为我们在过程中不进行任何排序和重新排列。因此,复杂性估计似乎是最终的。

哇,太棒了。这是一堂漫长的讲座 :-)。明天我会仔细阅读并标记您的回答。再次感谢您。 - Wojciech Kulik
1
非常完美的答案,解释得非常清楚,正是我所需要的!我的算法做了类似的事情,但你以聪明的方式存储它,使得可以执行二分查找并显著降低复杂度。感谢您的时间、帮助和出色的答案。 - Wojciech Kulik
我已经实现并通过了测试 :-),但是你的算法有一个错误。如果我们找到等于的元素,我们必须将其保留为新元素(而不仅仅是增加值)。为什么呢?看一下这个链接:http://pastebin.com/qyAxT2yh总之,干得好 :-)。 - Wojciech Kulik
是的,我明白了!当我们遇到相等的元素时,我们应该根据前一行重新计算它的数量,并将其添加到已有的数量中。然后在你的例子中,81将与3配对,而不是2。非常感谢您的注意! - Alex Salauyou
如果是“求和”而不是“计数”,我们应该先提取同一行中前一个元素的总和并减去现有元素的总和来获取现有元素的“计数”,并将其记住。然后从该行中删除具有相等“值”的现有项。接下来像平常一样计算新项,并将记住的计数添加到其“总和”中。 - Alex Salauyou
谢谢,我使用了带有计数和二分搜索的版本,这对我的目的已经足够了 :-)。效果很好。 - Wojciech Kulik

0

以上逻辑的 Cpp 实现:

#include<bits/stdc++.h>
using namespace std;
#define pb push_back
#define pob pop_back
#define pll pair<ll, ll>
#define pii pair<int, int>
#define ll long long
#define ull unsigned long long
#define fori(a,b) for(i=a;i<b;i++)
#define forj(a,b) for(j=a;j<b;j++)
#define fork(a,b) for(k=a;k<b;k++)
#define forl(a,b) for(l=a;l<b;l++)
#define forir(a,b) for(i=a;i>=b;i--)
#define forjr(a,b) for(j=a;j>=b;j--)
#define mod 1000000007
#define boost std::ios::sync_with_stdio(false)

struct comp_pair_int_rev
{
    bool operator()(const pair<int,int> &a, const int & b)
    {
        return (a.first > b);
    }
    bool operator()(const int & a,const pair<int,int> &b)
    {
        return (a > b.first);
    }
};

struct comp_pair_int
{
    bool operator()(const pair<int,int> &a, const int & b)
    {
        return (a.first < b);
    }
    bool operator()(const int & a,const pair<int,int> &b)
    {
        return (a < b.first);
    }
};

int main()
{
    int n,i,mx=0,p,q,r,t;
    cin>>n;

    int a[n];
    vector<vector<pii > > v(100005);
    vector<pii > v1(100005);

    fori(0,n)
    cin>>a[i];

    v[1].pb({a[0], 1} );
    v1[1]= {a[0], 1};

    mx=1;
    fori(1,n)
    {
        if(a[i]<=v1[1].first)
        {
            r=v1[1].second;

            if(v1[1].first==a[i])
                v[1].pob();

            v1[1]= {a[i], r+1};
            v[1].pb({a[i], r+1});
        }
        else if(a[i]>v1[mx].first)
        {
            q=upper_bound(v[mx].begin(), v[mx].end(), a[i], comp_pair_int_rev() )-v[mx].begin();
            if(q==0)
            {
                r=v1[mx].second;
            }
            else
            {
                r=v1[mx].second-v[mx][q-1].second;
            }

            v1[++mx]= {a[i], r};
            v[mx].pb({a[i], r});
        }
        else if(a[i]==v1[mx].first)
        {
            q=upper_bound(v[mx-1].begin(), v[mx-1].end(), a[i], comp_pair_int_rev() )-v[mx-1].begin();
            if(q==0)
            {
                r=v1[mx-1].second;
            }
            else
            {
                r=v1[mx-1].second-v[mx-1][q-1].second;
            }
            p=v1[mx].second;
            v1[mx]= {a[i], p+r};

            v[mx].pob();
            v[mx].pb({a[i], p+r});


        }
        else
        {
            p=lower_bound(v1.begin()+1, v1.begin()+mx+1, a[i], comp_pair_int() )-v1.begin();
            t=v1[p].second;

            if(v1[p].first==a[i])
            {

                v[p].pob();
            }

            q=upper_bound(v[p-1].begin(), v[p-1].end(), a[i], comp_pair_int_rev() )-v[p-1].begin();
            if(q==0)
            {
                r=v1[p-1].second;
            }
            else
            {
                r=v1[p-1].second-v[p-1][q-1].second;
            }

            v1[p]= {a[i], t+r};
            v[p].pb({a[i], t+r});

        }


    }

    cout<<v1[mx].second;

    return 0;
}

0
耐心排序也是O(N*logN),但比基于二分搜索的方法更短、更简单。
static int[] input = {4, 5, 2, 8, 9, 3, 6, 2, 7, 8, 6, 6, 7, 7, 3, 6};

/**
 * Every time a value is tested it either adds to the length of LIS (by calling decs.add() with it), or reduces the remaining smaller cards that must be found before LIS consists of smaller cards. This way all inputs/cards contribute in one way or another (except if they're equal to the biggest number in the sequence; if want't to include in sequence, replace 'card <= decs.get(decIndex)' with 'card < decs.get(decIndex)'. If they're bigger than all decs, they add to the length of LIS (which is something we want), while if they're smaller than a dec, they replace it. We want this, because the smaller the biggest dec is, the smaller input we need before we can add onto LIS.
 *
 * If we run into a decreasing sequence the input from this sequence will replace each other (because they'll always replace the leftmost dec). Thus this algorithm won't wrongfully register e.g. {2, 1, 3} as {2, 3}, but rather {2} -> {1} -> {1, 3}.
 *
 * WARNING: This can only be used to find length, not actual sequence, seeing how parts of the sequence will be replaced by smaller numbers trying to make their sequence dominate
 *
 * Due to bigger decs being added to the end/right of 'decs' and the leftmost decs always being the first to be replaced with smaller decs, the further a dec is to the right (the bigger it's index), the bigger it must be. Thus, by always replacing the leftmost decs, we don't run the risk of replacing the biggest number in a sequence (the number which determines if more cards can be added to that sequence) before a sequence with the same length but smaller numbers (thus currently equally good, due to length, and potentially better, due to less needed to increase length) has been found.
 */
static void patienceFindLISLength() {
    ArrayList<Integer> decs = new ArrayList<>();
    inputLoop: for (Integer card : input) {
        for (int decIndex = 0; decIndex < decs.size(); decIndex++) {
            if (card <= decs.get(decIndex)) {
                decs.set(decIndex, card);
                continue inputLoop;
            }
        }
        decs.add(card);
    }
    System.out.println(decs.size());
}

这是n的平方,不是吗?最坏情况单调递增。 - Dylan Madisetti

0
Sasha Salauyou的回答很好,但我不清楚为什么。
sum -= pRow.get(index).sum;

这是基于相同思路的我的代码

import java.math.BigDecimal;
import java.util.*;

class lisCount {
  static BigDecimal lisCount(int[] a) {
    class Container {
      Integer    v;
      BigDecimal count;

      Container(Integer v) {
        this.v = v;
      }
    }
    List<List<Container>> lisIdxSeq = new ArrayList<List<Container>>();
    int lisLen, lastIdx;
    List<Container> lisSeqL;
    Container lisEle;
    BigDecimal count;
    int pre;
    for (int i = 0; i < a.length; i++){
      pre = -1;
      count = new BigDecimal(1);
      lisLen = lisIdxSeq.size();
      lastIdx = lisLen - 1;
      lisEle = new Container(i);
      if(lisLen == 0 || a[i] > a[lisIdxSeq.get(lastIdx).get(0).v]){
        // lis len increased
        lisSeqL = new ArrayList<Container>();
        lisSeqL.add(lisEle);
        lisIdxSeq.add(lisSeqL);
        pre = lastIdx;
      }else{
        int h = lastIdx;
        int l = 0;

        while(l < h){
          int m = (l + h) / 2;
          if(a[lisIdxSeq.get(m).get(0).v] < a[i]) l = m + 1;
          else h = m;
        }

        List<Container> lisSeqC = lisIdxSeq.get(l);
        if(a[i] <= a[lisSeqC.get(0).v]){
          int hi = lisSeqC.size() - 1;
          int lo = 0;
          while(hi < lo){
            int mi = (hi + lo) / 2;
            if(a[lisSeqC.get(mi).v] < a[i]) lo = mi + 1;
            else hi = mi;
          }
          lisSeqC.add(lo, lisEle);
          pre = l - 1;
        }
      }
      if(pre >= 0){
        Iterator<Container> it = lisIdxSeq.get(pre).iterator();
        count = new BigDecimal(0);
        while(it.hasNext()){
          Container nt = it.next();
          if(a[nt.v] < a[i]){
            count = count.add(nt.count);
          }else break;
        }
      }
      lisEle.count = count;
    }

    BigDecimal rst = new BigDecimal(0);
    Iterator<Container> i = lisIdxSeq.get(lisIdxSeq.size() - 1).iterator();
    while(i.hasNext()){
      rst = rst.add(i.next().count);
    }
    return rst;
  }

  public static void main(String[] args) {
    System.out.println(lisCount(new int[] { 1, 3, 2, 2, 4 }));
    System.out.println(lisCount(new int[] { 3, 2, 1 }));
    System.out.println(lisCount(new int[] { 16, 5, 8, 6, 1, 10, 5, 2, 15, 3, 2, 4, 1 }));
  }
}

0

虽然我完全同意Alex的观点,但使用线段树非常容易实现。 以下是使用线段树在NlogN中查找LIS长度的逻辑。 https://www.quora.com/What-is-the-approach-to-find-the-length-of-the-strictly-increasing-longest-subsequence 这里有一种方法可以找到LIS的数量,但时间复杂度为N^2。 https://codeforces.com/blog/entry/48677

我们使用线段树(如此处所示)来优化此处给出的方法。 以下是逻辑:

首先按升序对数组进行排序(同时保留原始顺序),用零初始化线段树,对于给定范围,线段树应查询两个内容(使用pair):        a. 第一个的最大值。        b. 对应于max-first的第二个的总和。 遍历排序后的数组。     让j成为当前元素的原始索引,然后我们查询(0-j-1)并更新j-th元素(如果查询结果为0,0,则将其更新为(1,1))。

这是我的C++代码:

#include<bits/stdc++.h>
#define tr(container, it) for(typeof(container.begin()) it = container.begin(); it != container.end(); it++)
#define ll          long long
#define pb          push_back
#define endl        '\n'
#define pii         pair<ll int,ll int>
#define vi          vector<ll int>
#define all(a)      (a).begin(),(a).end()
#define F           first
#define S           second
#define sz(x)       (ll int)x.size()
#define hell        1000000007
#define rep(i,a,b)  for(ll int i=a;i<b;i++)
#define lbnd        lower_bound
#define ubnd        upper_bound
#define bs          binary_search
#define mp          make_pair
using namespace std;

#define N  100005

ll max(ll a , ll b)

{
    if( a > b) return a ;
    else return
         b;
}
ll n,l,r;
vector< pii > seg(4*N);

pii query(ll cur,ll st,ll end,ll l,ll r)
{
    if(l<=st&&r>=end)
    return seg[cur];
    if(r<st||l>end)
    return mp(0,0);                           /*  2-change here  */
    ll mid=(st+end)>>1;
    pii ans1=query(2*cur,st,mid,l,r);
    pii ans2=query(2*cur+1,mid+1,end,l,r);
    if(ans1.F>ans2.F)
        return ans1;
    if(ans2.F>ans1.F)
        return ans2;

    return make_pair(ans1.F,ans2.S+ans1.S);                 /*  3-change here  */
}
void update(ll cur,ll st,ll end,ll pos,ll upd1, ll upd2)
{
    if(st==end)
    {
        // a[pos]=upd;                  /*  4-change here  */
        seg[cur].F=upd1;    
        seg[cur].S=upd2;            /*  5-change here  */
        return;
    }
    ll mid=(st+end)>>1;
    if(st<=pos&&pos<=mid)
        update(2*cur,st,mid,pos,upd1,upd2);
    else
        update(2*cur+1,mid+1,end,pos,upd1,upd2);
    seg[cur].F=max(seg[2*cur].F,seg[2*cur+1].F);


    if(seg[2*cur].F==seg[2*cur+1].F)
        seg[cur].S = seg[2*cur].S+seg[2*cur+1].S;
    else
    {
        if(seg[2*cur].F>seg[2*cur+1].F)
            seg[cur].S = seg[2*cur].S;
        else
            seg[cur].S = seg[2*cur+1].S;
        /*  6-change here  */
    }
}

int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int TESTS=1;
//  cin>>TESTS;
    while(TESTS--)
    {
        int n ;
        cin >> n;
        vector< pii > arr(n);
        rep(i,0,n)
        {
            cin >> arr[i].F;
            arr[i].S = -i;
        }

        sort(all(arr));
        update(1,0,n-1,-arr[0].S,1,1);
        rep(i,1,n)
        {
            pii x = query(1,0,n-1,-1,-arr[i].S - 1 );
            update(1,0,n-1,-arr[i].S,x.F+1,max(x.S,1));

        }

        cout<<seg[1].S;//answer



    }
    return 0;
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接