潜在的O(n)解法:最长递增子序列

22
我试图用递归(动态规划)回答这个问题。 http://en.wikipedia.org/wiki/Longest_increasing_subsequence 从这篇文章和SO周围,我意识到最有效的现有解决方案是O(nlgn)。我的解决方案是O(N),我找不到它失败的情况。我包括我使用的单元测试用例。
import static org.junit.Assert.assertEquals;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.junit.Test;

public class LongestIncreasingSubseq {

    public static void main(String[] args) {
        int[] arr = {0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15, 1};
        getLongestSubSeq(arr);
    }

    public static List<Integer> getLongestSubSeq(int[] arr) {
        List<Integer> indices = longestRecursive(arr, 0, arr.length-1);
        List<Integer> result = new ArrayList<>();
        for (Integer i : indices) {
            result.add(arr[i]);
        }

        System.out.println(result.toString());
        return result;
    }

    private static List<Integer> longestRecursive(int[] arr, int start, int end) {
        if (start == end) {
            List<Integer> singleton = new ArrayList<>();
            singleton.add(start);
            return singleton;
        }

        List<Integer> bestRightSubsequence = longestRecursive(arr, start+1, end); //recursive call down the array to the next start index
        if (bestRightSubsequence.size() == 1 && arr[start] > arr[bestRightSubsequence.get(0)]) {
            bestRightSubsequence.set(0, start); //larger end allows more possibilities ahead
        } else if (arr[start] < arr[bestRightSubsequence.get(0)]) {
            bestRightSubsequence.add(0, start); //add to head
        } else if (bestRightSubsequence.size() > 1 && arr[start] < arr[bestRightSubsequence.get(1)]) {
            //larger than head, but still smaller than 2nd, so replace to allow more possibilities ahead
            bestRightSubsequence.set(0, start); 
        }

        return bestRightSubsequence;
    }

    @Test
    public void test() {
        int[] arr1 = {0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15, 1};
        int[] arr2 = {7, 0, 9, 2, 8, 4, 1};
        int[] arr3 = {9, 11, 2, 13, 7, 15};
        int[] arr4 = {10, 22, 9, 33, 21, 50, 41, 60, 80};
        int[] arr5 = {1, 2, 9, 4, 7, 3, 11, 8, 14, 6};
        assertEquals(getLongestSubSeq(arr1), Arrays.asList(0, 4, 6, 9, 11, 15));
        assertEquals(getLongestSubSeq(arr2), Arrays.asList(0, 2, 8));
        assertEquals(getLongestSubSeq(arr3), Arrays.asList(9, 11, 13, 15));
        assertEquals(getLongestSubSeq(arr4), Arrays.asList(10, 22, 33, 50, 60, 80));
        assertEquals(getLongestSubSeq(arr5), Arrays.asList(1, 2, 4, 7, 11, 14));
    }

}

费用严格为O(n),因为有关系式T(n) = T(n-1) + O(1) => T(n) = O(n)。
是否有人可以找到这种情况失败或者有什么错误?非常感谢。
更新: 感谢大家指出我以前实现中的错误。最终代码如下,通过了所有以前无法通过的测试用例。
思路是列出(计算)所有可能的递增子序列(每个都从索引i开始,i从0到N.length-1),并选择最长的子序列。我使用记忆化技术(使用哈希表),避免重新计算已经计算过的子序列 - 因此对于每个起始索引,我们只计算所有递增子序列一次。
然而,我不确定在这种情况下如何正式推导时间复杂度 - 如果有人能解释一下,我将不胜感激。非常感谢。
import static org.junit.Assert.assertEquals;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.junit.Test;

public class LongestIncreasingSubsequence {

    public static List<Integer> getLongestSubSeq(int[] arr) {
        List<Integer> longest = new ArrayList<>();
        for (int i = 0; i < arr.length; i++) {
            List<Integer> candidate = longestSubseqStartsWith(arr, i);
            if (longest.size() < candidate.size()) {
                longest = candidate;
            }
        }

        List<Integer> result = new ArrayList<>();
        for (Integer i : longest) {
            result.add(arr[i]);
        }

        System.out.println(result.toString());
        cache = new HashMap<>(); //new cache otherwise collision in next use - because object is static
        return result;
    }

    private static Map<Integer, List<Integer>> cache = new HashMap<>();
    private static List<Integer> longestSubseqStartsWith(int[] arr, int startIndex) {
        if (cache.containsKey(startIndex)) { //check if already computed
            //must always return a clone otherwise object sharing messes things up
            return new ArrayList<>(cache.get(startIndex)); 
        }

        if (startIndex == arr.length-1) {
            List<Integer> singleton = new ArrayList<>();
            singleton.add(startIndex);
            return singleton;
        }

        List<Integer> longest = new ArrayList<>();
        for (int i = startIndex + 1; i < arr.length; i++) {
            if (arr[startIndex] < arr[i]) {
                List<Integer> longestOnRight = longestSubseqStartsWith(arr, i);
                if (longestOnRight.size() > longest.size()) {
                    longest = longestOnRight;
                }
            }
        }

        longest.add(0, startIndex);
        List<Integer> cloneOfLongest = new ArrayList<>(longest);
        //must always cache a clone otherwise object sharing messes things up
        cache.put(startIndex, cloneOfLongest); //remember this subsequence
        return longest;
    }

    @Test
    public void test() {
        int[] arr1 = {0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15, 1};
        int[] arr2 = {7, 0, 9, 2, 8, 4, 1};
        int[] arr3 = {9, 11, 2, 13, 7, 15};
        int[] arr4 = {10, 22, 9, 33, 21, 50, 41, 60, 80};
        int[] arr5 = {1, 2, 9, 4, 7, 3, 11, 8, 14, 6};
        int[] arr6 = {0,0,0,0,0,0,1,1,1,1,2,3,0,0,0,1,1,0,1,1,0,1,0,3};
        int[] arr7 = {0,1,2,0,1,3};
        int[] arr8 = {0,1,2,3,4,5,1,3,8};
        assertEquals(getLongestSubSeq(arr1), Arrays.asList(0, 4, 6, 9, 13, 15));
        assertEquals(getLongestSubSeq(arr2), Arrays.asList(0, 2, 8));
        assertEquals(getLongestSubSeq(arr3), Arrays.asList(9, 11, 13, 15));
        assertEquals(getLongestSubSeq(arr4), Arrays.asList(10, 22, 33, 50, 60, 80));
        assertEquals(getLongestSubSeq(arr5), Arrays.asList(1, 2, 4, 7, 11, 14));
        assertEquals(getLongestSubSeq(arr6), Arrays.asList(0,1,2,3));
        assertEquals(getLongestSubSeq(arr7), Arrays.asList(0,1,2,3));
        assertEquals(getLongestSubSeq(arr8), Arrays.asList(0, 1, 2, 3, 4, 5, 8));
    }

    public static void main(String[] args) {
        int[] arr1 = {7, 0, 9, 2, 8, 4, 1};
        System.out.println(getLongestSubSeq(arr1));
    }

}

16
为什么会有踩和关闭投票?这个问题表达得很清楚,展示了一定的研究努力,并且十分有用。 - kba
7
使用模糊测试/随机测试可能会更容易找到反例。实现一个已知正确的算法,生成随机序列,并比较两种实现的结果。 - user395760
@kba 没有点赞也没有踩,但它看起来像是典型的找出我的代码中的错误,这些通常会被相当大幅度地踩(并且不是很有用)(我无法看到这些和这个之间有显着的区别,但投票结果却非常不同)。此外,一般来说,伪代码/高级描述而不是/除了实际代码使得更容易看出发生了什么。 - Bernhard Barker
看了你提供的维基百科页面,它引用了一篇关于“Ω(nlogn)”下界的论文。虽然我没有亲自验证过,但我敢打赌“Ω(nlogn)”确实是通常计算模型的下界。引文链接指向这里,你可以下载PDF文件:http://www.sciencedirect.com/science/article/pii/0012365X7590103X - rliu
有两种构建软件设计的方式:一种方法是使它非常简单,以至于没有明显的缺陷;另一种方法是使其非常复杂,以至于没有明显的缺陷。第一种方法要困难得多。 - Paul Hankin
5个回答

15

您的程序在这个测试用例中失败了

int[] arr5 = {0,0,0,0,0,0,1,1,1,1,2,3,0,0,0,1,1,0,1,1,0,1,0,3};

你的结果是[0, 1, 3],难道不应该是[0,1,2,3]吗?


5
好的,它甚至对于 {0,1,2,0,1,3} 也会失败。 - Niklas B.

5

刚才我试了一下你的算法,使用了以下测试案例:

 @Test
    public void test() {

      int[] arr1 = {0,1,2,3,4,5,1,3,8};
      assertEquals(getLongestSubSeq(arr1), Arrays.asList(0, 1, 2, 3, 4, 5, 8));
    }

根据您的评论进行了编辑:尝试失败,输出为{1、3、8}。


您的测试用例应该返回 Arrays.asList(0, 1, 2, 3, 4, 5, 8) - PoweredByRice

1

很抱歉给你带来不好的消息,但实际上这是O(n2)。我不确定你是否想要更正式的分析,但这是我的分析:

consider the case when the input is sorted in descending order
  (longestRecursive is never executed recursively, and the cache has no effect)

getLongestSubSeq iterates over the entire input -> 1:n
  each iteration calls longestRecursive
  longestRecursive compares arr[startIndex] < arr[i] for startIndex+1:n -> i - 1

因此,比较 arr[startIndex] < arr[i] 恰好发生 sum(i - 1, 1, n) = n * (n - 1) / 2 次,这显然是 O(n2)。你可以通过发送升序排序的输入来强制使用最大缓存。在这种情况下,getLongestSubSeq 将调用 longestRecursive n 次;其中第一个将触发 n - 1 个递归调用,每个调用都会导致缓存未命中并运行 i - 1 次比较 arr[startIndex] < arr[i],因为在递归开始展开之前不会将任何内容放入缓存中。比较次数与我们绕过缓存的示例完全相同。实际上,比较次数始终相同;在输入中引入反转仅使代码将递归交换为迭代。

1

这是一个O(n^2)算法,因为有两个循环。第二个循环隐藏在一个方法调用中。

这是第一个循环:for (int i = 0; i < arr.length; i++)。在这个循环内部,你调用了longestSubseqStartsWith(arr, i);。看一下longestSubseqStartWith的实现,我们可以看到for (int i = startIndex + 1; i < arr.length; i++)


-2
这是我在Python3.x中的潜在O(N)解决方案:
l = list(map(int,input().split()))
t = []
t2 = []
m = 0
for i in l:
    if(len(t)!=0):
        if(t[-1]<=i):
            if(t[-1]!=1):
                 t.append(i)
        else:
            if(len(t)>m):
                t2 = t
                m = len(t)
            t = [i]
    else:
        t.append(i)
print(t2,len(t2))

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接