在一个数组中计算逆序对数

128
我正在设计一个算法,其功能如下:给定数组A[1... n],对于每个i < j,找到所有逆序对使得A[i] > A[j]。我使用归并排序并将数组A复制到数组B,然后比较这两个数组,但我很难看出如何使用它来找到逆序对的数量。任何提示或帮助将不胜感激。
38个回答

153

所以,这里提供了一个Java语言的O(n log n)解决方案。

long merge(int[] arr, int[] left, int[] right) {
    int i = 0, j = 0;
    long count = 0;
    while (i < left.length || j < right.length) {
        if (i == left.length) {
            arr[i+j] = right[j];
            j++;
        } else if (j == right.length) {
            arr[i+j] = left[i];
            i++;
        } else if (left[i] <= right[j]) {
            arr[i+j] = left[i];
            i++;                
        } else {
            arr[i+j] = right[j];
            count += left.length-i;
            j++;
        }
    }
    return count;
}

long invCount(int[] arr) {
    if (arr.length < 2)
        return 0;

    int m = (arr.length + 1) / 2;
    int left[] = Arrays.copyOfRange(arr, 0, m);
    int right[] = Arrays.copyOfRange(arr, m, arr.length);

    return invCount(left) + invCount(right) + merge(arr, left, right);
}

这几乎是标准的归并排序,整个关键在于合并函数中。 请注意,在排序时,算法会消除逆序对。 在合并时,算法计算出已消除逆序对的数量(可以说是已排序的数量)。

唯一一个消除逆序对的时刻是当算法从数组的右侧取出元素并将其合并到主数组时。 此操作所消除的逆序对数量等于尚未合并的左侧数组中元素的数量。

希望这已经足够解释了。


2
我尝试运行了一下,但是没有得到正确的答案。你需要在main函数中调用invCount(intArray)来开始计算吗?其中intArray是未排序的整数数组?我使用了一个包含许多整数的数组进行测试,但是得到了-1887062008作为我的答案。我做错了什么? - Nearpoint
4
+1,请参见C++11中类似的解决方案,其中包括一种通用的基于迭代器的解决方案以及使用5-25个元素序列的随机测试样例。享受吧! - WhozCraig
3
这不是一个解决方案。我尝试运行它,但得到了错误的结果。 - mirgee
2
抱歉问一个新手问题,但是将left.length - i添加到反转计数器中是什么意思?我认为只需要加1就可以了,因为你已经进入了左子数组元素比右子数组元素大的逻辑情况。有人能像我5岁时那样解释一下吗? - Alfredo Gallegos
3
考虑两个数组:[6,8]和[4,5]。当您看到6大于4时,将4放入“arr”中,但这不是一个逆序对。您找到了左侧数组中所有大于6的元素的逆序对。在我们的情况下,这也包括8。因此,“count”增加了2,等于“left.length-i”。 - ilya
显示剩余4条评论

89

我通过以下方法在O(n * log n)时间内找到了它。

  1. 合并排序数组A并创建一个副本数组B
  2. 取出A[1]并通过二分查找找到其在排序数组B中的位置。该元素的逆序对数将比它在B中的索引号少1,因为A的第一个元素后面出现的每个较小的数字都将是一个逆序对。

    2a. 将逆序对的数量累加到计数器变量num_inversions中。

    2b. 从数组A和相应的位置数组B中删除A[1]

  3. 重复步骤2直到A中没有更多元素。

这里是此算法的一个例子。原始数组A = (6, 9, 1, 14, 8, 12, 3, 2)

1:合并排序并复制到数组B

B = (1, 2, 3, 6, 8, 9, 12, 14)

2:取出A[1]并二分查找在数组B中找到它

A[1] = 6

B = (1, 2, 3, 6, 8, 9, 12, 14)

6在数组B的第4个位置上,因此有3个逆序对。我们之所以知道这一点,是因为6在数组A的第一个位置上,在此之后出现的任何较小值元素都将具有j>i的索引(因为在这种情况下i为1)。

2.b:从数组A和相应的位置数组B中删除A[1](删除的元素已加粗显示)。

A = (6, 9, 1, 14, 8, 12, 3, 2) = (9, 1, 14, 8, 12, 3, 2)

B = (1, 2, 3, 6, 8, 9, 12, 14) = (1, 2, 3, 8, 9, 12, 14)

3: 在新的A和B数组上重新执行第2步。

A[1] = 9

B = (1, 2, 3, 8, 9, 12, 14)

现在9在B数组中的第5个位置,因此有4个逆序对。我们知道这是因为9在A数组中的第一个位置,因此任何后来出现的更小的元素会有一个大于i的索引j(因为在这种情况下i还是1)。 从数组A和其对应的B中删除A [1](加粗的元素被删除)

A = (9, 1, 14, 8, 12, 3, 2) = (1, 14, 8, 12, 3, 2)

B = (1, 2, 3, 8, 9, 12, 14) = (1, 2, 3, 8, 12, 14)

继续这样做,直到循环完成,我们将得到数组A的总逆序对数。

执行步骤1(归并排序)需要O(n * log n)的时间。 步骤2将执行n次,每次执行将执行一个二进制搜索,需要O(log n)的时间,总共需要O(n * log n)。因此,总运行时间为O(n * log n) + O(n * log n) = O(n * log n)。

感谢您的帮助。在纸上写出示例数组确实有助于可视化问题。


1
为什么要使用归并排序而不是快速排序? - Alcott
6
快速排序在列表已经排好序的情况下,每一轮都选择第一个元素作为枢轴,时间复杂度最坏可达O(n^2)。归并排序的最坏情况时间复杂度为O(n log n)。 - user482594
35
从标准数组中删除步骤会使你的算法变成O(n^2),因为需要移动数值。(这就是为什么插入排序是O(n^2)的原因)。 - Kyle Butt
从数组B的第一个元素开始计数,在数组A中排在它前面的元素也会得到相同的结果,只要你按照你在答案中描述的那样将它们排除即可。 - tutak
1
@el diablo,很棒的想法!但除了所有删除操作的O(n ^ 2)复杂度外,还有一个问题。二分查找不能搜索第一次出现的元素。但我们需要第一个元素。考虑一个数组[4,7,4]。您的方法将返回2个逆序对,而不是1个。更具体地说,在第一步中,二分查找在原始“4”的索引0处找到具有索引1的“4”,因此我们得到了错误的1(= 1-0)个逆序对。 - ilya
显示剩余2条评论

40

我想知道为什么还没有人提到二叉索引树。你可以使用它来维护排列元素值的前缀和。然后,你可以从右到左遍历,计算每个元素右边比它小的元素数量:

def count_inversions(a):
  res = 0
  counts = [0]*(len(a)+1)
  rank = { v : i+1 for i, v in enumerate(sorted(a)) }
  for x in reversed(a):
    i = rank[x] - 1
    while i:
      res += counts[i]
      i -= i & -i
    i = rank[x]
    while i <= len(a):
      counts[i] += 1
      i += i & -i
  return res

复杂度为O(n log n),而且常数非常低。


可能是最好的方法 :) - Nilutpal Borgohain
1
@NilutpalBorgohain 谢谢 :) 至少在O(n log n)的候选算法中,它似乎需要最少的代码。 - Niklas B.
2
谢谢。i -= i & -i 这行代码的意思是什么?类似地,i += i & -i 是什么意思? - Gerard Condon
1
@GerardCondon,那基本上就是BIT数据结构。可以在答案中找到解释它的链接。 - Niklas B.
1
我了解了Fenwick树。谢谢!我已经发布了一个答案,其中包括对这个问题的所有Python答案进行timeit比较,因此包括您的代码。您可能会对时间结果感兴趣。 - PM 2Ring
这是一个精简而又出色的实现和起点。然而,通过对数组进行sort排序来计算逆序对感觉有点像作弊;-) 当然,需要排序和存储映射意味着更多的速度复杂性。我发现我们可以通过根据源数组中的最大元素调整BIT的大小,并稍微改变实现方式来以空间复杂度为代价换取时间复杂度;请参见GeeksForGeeks上的这里。尽管如此,非常感谢您提供的宝贵指导,正如所说,这是一个非常有用的简明介绍! - underscore_d

29

在Python中

# O(n log n)

def count_inversion(lst):
    return merge_count_inversion(lst)[1]

def merge_count_inversion(lst):
    if len(lst) <= 1:
        return lst, 0
    middle = int( len(lst) / 2 )
    left, a = merge_count_inversion(lst[:middle])
    right, b = merge_count_inversion(lst[middle:])
    result, c = merge_count_split_inversion(left, right)
    return result, (a + b + c)

def merge_count_split_inversion(left, right):
    result = []
    count = 0
    i, j = 0, 0
    left_len = len(left)
    while i < left_len and j < len(right):
        if left[i] <= right[j]:
            result.append(left[i])
            i += 1
        else:
            result.append(right[j])
            count += left_len - i
            j += 1
    result += left[i:]
    result += right[j:]
    return result, count        


#test code
input_array_1 = []  #0
input_array_2 = [1] #0
input_array_3 = [1, 5]  #0
input_array_4 = [4, 1] #1
input_array_5 = [4, 1, 2, 3, 9] #3
input_array_6 = [4, 1, 3, 2, 9, 5]  #5
input_array_7 = [4, 1, 3, 2, 9, 1]  #8

print count_inversion(input_array_1)
print count_inversion(input_array_2)
print count_inversion(input_array_3)
print count_inversion(input_array_4)
print count_inversion(input_array_5)
print count_inversion(input_array_6)
print count_inversion(input_array_7)

19
我对这个得到了+13的答案感到困惑 - 我不是特别擅长Python,但它似乎与2年前提出的Java版本基本相同,唯一的区别是没有提供任何解释。在其他语言发布答案在我看来是有害的 - 可能有成千上万甚至更多的语言 - 我希望没有人会认为我们应该在一个问题中发布成千上万的答案 - [se]并不是为此而存在的。 - Bernhard Barker
2
@tennenrishin 好吧,也许不是成千上万。但我们应该在哪里划线呢?目前有十个答案已经给出了相同的方法,这大约占了43%的答案(不包括非答案) - 这占据了相当多的空间,因为这里还有其他几种方法。即使只有两个答案采用相同的方法,这也会不必要地稀释答案。而且我在之前的评论中对于这个特定的答案并不有用做出了相当不错的论点。 - Bernhard Barker
3
和你一样,我对Python不熟悉,更加熟悉Java。我认为这个解决方案比Java的解决方案不容易阅读。因此,同样地,对于某些人来说,反过来也可能同样成立。 - Museful
4
对于绝大多数用户来说,Python 就像伪代码一样。即使没有解释,我也觉得 Python 比 Java 更易读。如果这有助于某些读者,我认为没必要过于烦恼。 - Francisco Vargas
3
这个解决方案对于Python用户来说非常不错且易懂。人们想看看别人如何在Python中实现它。 - aerin
显示剩余5条评论

20
这篇文章的主要目的是比较各种Python版本的速度,但我也有一些自己的贡献。在CPython中实现的算法的相对执行速度可能与从简单分析和其他语言的经验所期望的不同。这是因为Python提供了许多强大的函数和方法,这些函数和方法在C中实现,可以对列表和其他集合进行操作,其速度接近于全编译语言,因此这些操作比使用Python代码手动实现的等效算法运行得快得多。利用这些工具的代码通常可以胜过理论上更优越的算法,这些算法尝试通过对集合的每个单独项进行Python操作来完成所有工作。当然,处理的数据量实际上也会对此产生影响。但对于适度数量的数据,使用以C速度运行的O(n²)算法的代码可以轻松地击败使用O(n log n)算法的代码,后者使用Python操作完成大部分工作。许多发布的答案都基于归并排序的算法。从理论上讲,这是一个好的方法,除非数组大小非常小。但是Python内置的TimSort(一种混合稳定排序算法,源自归并排序和插入排序)以C速度运行,而手工编写的归并排序无法与其竞争速度。

Niklas B发布的答案中,更加有趣的解决方案之一是使用内置排序来确定数组项的排名,并使用二进制索引树(又称Fenwick树)存储计算逆序对所需的累积和。在尝试理解这个数据结构和Niklas的算法的过程中,我写了几个自己的变体(如下所示)。但我也发现,对于中等大小的列表,实际上使用Python内置的sum函数比可爱的Fenwick树更快。

def count_inversions(a):
    total = 0
    counts = [0] * len(a)
    rank = {v: i for i, v in enumerate(sorted(a))}
    for u in reversed(a):
        i = rank[u]
        total += sum(counts[:i])
        counts[i] += 1
    return total

最终,当列表大小达到约500时,在那个for循环中调用sum的O(n²)方面就会显现出来,性能开始急剧下降。
归并排序不是唯一的O(nlogn)排序算法,还有其他几种可用于执行逆序计数。prasadvk的答案使用了二叉树排序,但他的代码似乎是用C++或其衍生语言编写的。因此,我添加了一个Python版本。我最初使用一个类来实现树节点,但发现字典要快得多。最终我使用了列表,它甚至更快,尽管这使得代码稍微难以阅读。
树排序的一个额外好处是它比归并排序更容易迭代实现。Python不会优化递归,并且它有一个递归深度限制(虽然如果你真的需要,可以增加它)。当然,Python函数调用相对较慢,因此在尝试优化速度时,尽可能避免函数调用是明智的。
另一个O(nlogn)排序算法是著名的基数排序。它的优点是不需要将键与其他键进行比较。缺点是最适合连续的整数序列,理想情况下是range(b**m)中的整数排列,其中b通常为2。在尝试阅读计算排列中“逆序对”的数量中提到的计数逆序对、离线正交范围计数和相关问题后,我添加了几个基于基数排序的版本。
要有效地使用基数排序计算长度为n的一般序列seq中的逆序对,我们可以创建一个由seq具有相同逆序对数量的range(n)的排列。通过TimSort,我们可以在(最坏情况下)O(nlogn)时间内完成这些操作。关键是通过对seq进行排序来排列其索引。通过一个小例子更容易解释这个技巧。
seq = [15, 14, 11, 12, 10, 13]
b = [t[::-1] for t in enumerate(seq)]
print(b)
b.sort()
print(b)

输出

[(15, 0), (14, 1), (11, 2), (12, 3), (10, 4), (13, 5)]
[(10, 4), (11, 2), (12, 3), (13, 5), (14, 1), (15, 0)]

通过对seq的(值,索引)对进行排序,我们已经通过相同数量的交换来重新排列了seq的索引,以将其从排序顺序恢复到原始顺序。我们可以通过使用适当的键函数对range(n)进行排序来创建该排列:
print(sorted(range(len(seq)), key=lambda k: seq[k]))

输出

[4, 2, 3, 5, 1, 0]

我们可以通过使用`seq`的`. __getitem__`方法来避免那个 `lambda` 函数:
sorted(range(len(seq)), key=seq.__getitem__)

这只是略微更快,但我们正在寻求所有可以获得的速度提升。;)
以下代码对这个页面上所有现有的Python算法进行了timeit测试,以及我自己写的一些算法:几个暴力O(n²)版本、一些基于Niklas B算法的变体,当然还有一个基于归并排序的(我没有参考现有答案而写的)。它还包括我的基于列表的树排序代码,大致派生自prasadvk的代码,并且基于基数排序的各种函数,其中一些使用与归并排序方法类似的策略,一些使用sum或Fenwick树。
该程序测量每个函数在一系列随机整数列表上的执行时间;它还可以验证每个函数给出的结果是否与其他函数相同,并且不会修改输入列表。
每个timeit调用都会返回一个包含3个结果的向量,我将其排序。在这里要关注的主要值是最小值,其他值仅表示该最小值的可靠性如何,正如timeit模块文档中的注释所讨论的那样。
不幸的是,该程序的输出太大,无法包含在此答案中,因此我将其发布在自己的社区维基答案中。
输出是在我的古老的32位单核2GHz机器上运行Python 3.6.0的旧Debian衍生版上进行的3次运行。你的结果可能会有所不同。在测试期间,我关闭了Web浏览器并断开了路由器以最小化对CPU的影响。
第一次运行测试所有函数,列表大小从5到320,循环大小从4096到64(随着列表大小加倍,循环大小减半)。用于构建每个列表的随机池的大小是列表本身的一半,因此我们可能会得到很多重复项。某些计数反演算法对重复项更敏感。
第二次运行使用更大的列表:640到10240,并且循环大小为8。为节省时间,它从测试中删除了几个最慢的函数。在这些大小上,我的暴力O(n²)函数太慢了,如前所述,我的使用sum的代码在小到中等列表上表现得非常好,但在大列表上无法跟上。
最后一次运行涵盖列表大小从20480到655360,固定循环大小为4,并使用8个最快速的函数。对于列表大小在40,000以下的情况下,Tim Babych的代码是明显的赢家。做得好,Tim!Niklas B的代码也是一个良好的全能表演者,尽管在较小的列表上被击败。 "Python"的二分代码也表现不错,尽管对于具有大量重复项的巨大列表似乎稍微慢了一些,可能是由于它使用的线性while循环来跨越重复项。

然而,对于非常大的列表大小,基于二分法的算法无法与真正的O(nlogn)算法竞争。

#!/usr/bin/env python3

''' Test speeds of various ways of counting inversions in a list

    The inversion count is a measure of how sorted an array is.
    A pair of items in a are inverted if i < j but a[j] > a[i]

    See https://dev59.com/RXRC5IYBdhLWcg3wUvQS

    This program contains code by the following authors:
    mkso
    Niklas B
    B. M.
    Tim Babych
    python
    Zhe Hu
    prasadvk
    noman pouigt
    PM 2Ring

    Timing and verification code by PM 2Ring
    Collated 2017.12.16
    Updated 2017.12.21
'''

from timeit import Timer
from random import seed, randrange
from bisect import bisect, insort_left

seed('A random seed string')

# Merge sort version by mkso
def count_inversion_mkso(lst):
    return merge_count_inversion(lst)[1]

def merge_count_inversion(lst):
    if len(lst) <= 1:
        return lst, 0
    middle = len(lst) // 2
    left, a = merge_count_inversion(lst[:middle])
    right, b = merge_count_inversion(lst[middle:])
    result, c = merge_count_split_inversion(left, right)
    return result, (a + b + c)

def merge_count_split_inversion(left, right):
    result = []
    count = 0
    i, j = 0, 0
    left_len = len(left)
    while i < left_len and j < len(right):
        if left[i] <= right[j]:
            result.append(left[i])
            i += 1
        else:
            result.append(right[j])
            count += left_len - i
            j += 1
    result += left[i:]
    result += right[j:]
    return result, count

# . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
# Using a Binary Indexed Tree, aka a Fenwick tree, by Niklas B.
def count_inversions_NiklasB(a):
    res = 0
    counts = [0] * (len(a) + 1)
    rank = {v: i for i, v in enumerate(sorted(a), 1)}
    for x in reversed(a):
        i = rank[x] - 1
        while i:
            res += counts[i]
            i -= i & -i
        i = rank[x]
        while i <= len(a):
            counts[i] += 1
            i += i & -i
    return res

# . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
# Merge sort version by B.M
# Modified by PM 2Ring to deal with the global counter
bm_count = 0

def merge_count_BM(seq):
    global bm_count
    bm_count = 0
    sort_bm(seq)
    return bm_count

def merge_bm(l1,l2):
    global bm_count
    l = []
    while l1 and l2:
        if l1[-1] <= l2[-1]:
            l.append(l2.pop())
        else:
            l.append(l1.pop())
            bm_count += len(l2)
    l.reverse()
    return l1 + l2 + l

def sort_bm(l):
    t = len(l) // 2
    return merge_bm(sort_bm(l[:t]), sort_bm(l[t:])) if t > 0 else l

# . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
# Bisection based method by Tim Babych
def solution_TimBabych(A):
    sorted_left = []
    res = 0
    for i in range(1, len(A)):
        insort_left(sorted_left, A[i-1])
        # i is also the length of sorted_left
        res += (i - bisect(sorted_left, A[i]))
    return res

# Slightly faster, except for very small lists
def solutionE_TimBabych(A):
    res = 0
    sorted_left = []
    for i, u in enumerate(A):
        # i is also the length of sorted_left
        res += (i - bisect(sorted_left, u))
        insort_left(sorted_left, u)
    return res

# . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
# Bisection based method by "python"
def solution_python(A):
    B = list(A)
    B.sort()
    inversion_count = 0
    for i in range(len(A)):
        j = binarySearch_python(B, A[i])
        while B[j] == B[j - 1]:
            if j < 1:
                break
            j -= 1
        inversion_count += j
        B.pop(j)
    return inversion_count

def binarySearch_python(alist, item):
    first = 0
    last = len(alist) - 1
    found = False
    while first <= last and not found:
        midpoint = (first + last) // 2
        if alist[midpoint] == item:
            return midpoint
        else:
            if item < alist[midpoint]:
                last = midpoint - 1
            else:
                first = midpoint + 1

# . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
# Merge sort version by Zhe Hu
def inv_cnt_ZheHu(a):
    _, count = inv_cnt(a.copy())
    return count

def inv_cnt(a):
    n = len(a)
    if n==1:
        return a, 0
    left = a[0:n//2] # should be smaller
    left, cnt1 = inv_cnt(left)
    right = a[n//2:] # should be larger
    right, cnt2 = inv_cnt(right)

    cnt = 0
    i_left = i_right = i_a = 0
    while i_a < n:
        if (i_right>=len(right)) or (i_left < len(left)
            and left[i_left] <= right[i_right]):
            a[i_a] = left[i_left]
            i_left += 1
        else:
            a[i_a] = right[i_right]
            i_right += 1
            if i_left < len(left):
                cnt += len(left) - i_left
        i_a += 1
    return (a, cnt1 + cnt2 + cnt)

# . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
# Merge sort version by noman pouigt
# From https://stackoverflow.com/q/47830098
def reversePairs_nomanpouigt(nums):
    def merge(left, right):
        if not left or not right:
            return (0, left + right)
        #if everything in left is less than right
        if left[len(left)-1] < right[0]:
            return (0, left + right)
        else:
            left_idx, right_idx, count = 0, 0, 0
            merged_output = []

            # check for condition before we merge it
            while left_idx < len(left) and right_idx < len(right):
                #if left[left_idx] > 2 * right[right_idx]:
                if left[left_idx] > right[right_idx]:
                    count += len(left) - left_idx
                    right_idx += 1
                else:
                    left_idx += 1

            #merging the sorted list
            left_idx, right_idx = 0, 0
            while left_idx < len(left) and right_idx < len(right):
                if left[left_idx] > right[right_idx]:
                    merged_output += [right[right_idx]]
                    right_idx += 1
                else:
                    merged_output += [left[left_idx]]
                    left_idx += 1
            if left_idx == len(left):
                merged_output += right[right_idx:]
            else:
                merged_output += left[left_idx:]
        return (count, merged_output)

    def partition(nums):
        count = 0
        if len(nums) == 1 or not nums:
            return (0, nums)
        pivot = len(nums)//2
        left_count, l = partition(nums[:pivot])
        right_count, r = partition(nums[pivot:])
        temp_count, temp_list = merge(l, r)
        return (temp_count + left_count + right_count, temp_list)
    return partition(nums)[0]

# . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
# PM 2Ring
def merge_PM2R(seq):
    seq, count = merge_sort_count_PM2R(seq)
    return count

def merge_sort_count_PM2R(seq):
    mid = len(seq) // 2
    if mid == 0:
        return seq, 0
    left, left_total = merge_sort_count_PM2R(seq[:mid])
    right, right_total = merge_sort_count_PM2R(seq[mid:])
    total = left_total + right_total
    result = []
    i = j = 0
    left_len, right_len = len(left), len(right)
    while i < left_len and j < right_len:
        if left[i] <= right[j]:
            result.append(left[i])
            i += 1
        else:
            result.append(right[j])
            j += 1
            total += left_len - i
    result.extend(left[i:])
    result.extend(right[j:])
    return result, total

def rank_sum_PM2R(a):
    total = 0
    counts = [0] * len(a)
    rank = {v: i for i, v in enumerate(sorted(a))}
    for u in reversed(a):
        i = rank[u]
        total += sum(counts[:i])
        counts[i] += 1
    return total

# Fenwick tree functions adapted from C code on Wikipedia
def fen_sum(tree, i):
    ''' Return the sum of the first i elements, 0 through i-1 '''
    total = 0
    while i:
        total += tree[i-1]
        i -= i & -i
    return total

def fen_add(tree, delta, i):
    ''' Add delta to element i and thus 
        to fen_sum(tree, j) for all j > i 
    '''
    size = len(tree)
    while i < size:
        tree[i] += delta
        i += (i+1) & -(i+1)

def fenwick_PM2R(a):
    total = 0
    counts = [0] * len(a)
    rank = {v: i for i, v in enumerate(sorted(a))}
    for u in reversed(a):
        i = rank[u]
        total += fen_sum(counts, i)
        fen_add(counts, 1, i)
    return total

def fenwick_inline_PM2R(a):
    total = 0
    size = len(a)
    counts = [0] * size
    rank = {v: i for i, v in enumerate(sorted(a))}
    for u in reversed(a):
        i = rank[u]
        j = i + 1
        while i:
            total += counts[i]
            i -= i & -i
        while j < size:
            counts[j] += 1
            j += j & -j
    return total

def bruteforce_loops_PM2R(a):
    total = 0
    for i in range(1, len(a)):
        u = a[i]
        for j in range(i):
            if a[j] > u:
                total += 1
    return total

def bruteforce_sum_PM2R(a):
    return sum(1 for i in range(1, len(a)) for j in range(i) if a[j] > a[i])

# Using binary tree counting, derived from C++ code (?) by prasadvk
# https://dev59.com/RXRC5IYBdhLWcg3wUvQS#16056139
def ltree_count_PM2R(a):
    total, root = 0, None
    for u in a:
        # Store data in a list-based tree structure
        # [data, count, left_child, right_child]
        p = [u, 0, None, None]
        if root is None:
            root = p
            continue
        q = root
        while True:
            if p[0] < q[0]:
                total += 1 + q[1]
                child = 2
            else:
                q[1] += 1
                child = 3
            if q[child]:
                q = q[child]
            else:
                q[child] = p
                break
    return total

# Counting based on radix sort, recursive version
def radix_partition_rec(a, L):
    if len(a) < 2:
        return 0
    if len(a) == 2:
        return a[1] < a[0]
    left, right = [], []
    count = 0
    for u in a:
        if u & L:
            right.append(u)
        else:
            count += len(right)
            left.append(u)
    L >>= 1
    if L:
        count += radix_partition_rec(left, L) + radix_partition_rec(right, L)
    return count

# The following functions determine swaps using a permutation of 
# range(len(a)) that has the same inversion count as `a`. We can create
# this permutation with `sorted(range(len(a)), key=lambda k: a[k])`
# but `sorted(range(len(a)), key=a.__getitem__)` is a little faster.

# Counting based on radix sort, iterative version
def radix_partition_iter(seq, L):
    count = 0
    parts = [seq]
    while L and parts:
        newparts = []
        for a in parts:
            if len(a) < 2:
                continue
            if len(a) == 2:
                count += a[1] < a[0]
                continue
            left, right = [], []
            for u in a:
                if u & L:
                    right.append(u)
                else:
                    count += len(right)
                    left.append(u)
            if left:
                newparts.append(left)
            if right:
                newparts.append(right)
        parts = newparts
        L >>= 1
    return count

def perm_radixR_PM2R(a):
    size = len(a)
    b = sorted(range(size), key=a.__getitem__)
    n = size.bit_length() - 1
    return radix_partition_rec(b, 1 << n)

def perm_radixI_PM2R(a):
    size = len(a)
    b = sorted(range(size), key=a.__getitem__)
    n = size.bit_length() - 1
    return radix_partition_iter(b, 1 << n)

# Plain sum of the counts of the permutation
def perm_sum_PM2R(a):
    total = 0
    size = len(a)
    counts = [0] * size
    for i in reversed(sorted(range(size), key=a.__getitem__)):
        total += sum(counts[:i])
        counts[i] = 1
    return total

# Fenwick sum of the counts of the permutation
def perm_fenwick_PM2R(a):
    total = 0
    size = len(a)
    counts = [0] * size
    for i in reversed(sorted(range(size), key=a.__getitem__)):
        j = i + 1
        while i:
            total += counts[i]
            i -= i & -i
        while j < size:
            counts[j] += 1
            j += j & -j
    return total

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# All the inversion-counting functions
funcs = (
    solution_TimBabych,
    solutionE_TimBabych,
    solution_python,
    count_inversion_mkso,
    count_inversions_NiklasB,
    merge_count_BM,
    inv_cnt_ZheHu,
    reversePairs_nomanpouigt,
    fenwick_PM2R,
    fenwick_inline_PM2R,
    merge_PM2R,
    rank_sum_PM2R,
    bruteforce_loops_PM2R,
    bruteforce_sum_PM2R,
    ltree_count_PM2R,
    perm_radixR_PM2R,
    perm_radixI_PM2R,
    perm_sum_PM2R,
    perm_fenwick_PM2R,
)

def time_test(seq, loops, verify=False):
    orig = seq
    timings = []
    for func in funcs:
        seq = orig.copy()
        value = func(seq) if verify else None
        t = Timer(lambda: func(seq))
        result = sorted(t.repeat(3, loops))
        timings.append((result, func.__name__, value))
        assert seq==orig, 'Sequence altered by {}!'.format(func.__name__)
    first = timings[0][-1]
    timings.sort()
    for result, name, value in timings:
        result = ', '.join([format(u, '.5f') for u in result])
        print('{:24} : {}'.format(name, result))

    if verify:
        # Check that all results are identical
        bad = ['%s: %d' % (name, value)
            for _, name, value in timings if value != first]
        if bad:
            print('ERROR. Value: {}, bad: {}'.format(first, ', '.join(bad)))
        else:
            print('Value: {}'.format(first))
    print()

#Run the tests
size, loops = 5, 1 << 12
verify = True
for _ in range(7):
    hi = size // 2
    print('Size = {}, hi = {}, {} loops'.format(size, hi, loops))
    seq = [randrange(hi) for _ in range(size)]
    time_test(seq, loops, verify)
    loops >>= 1
    size <<= 1

#size, loops = 640, 8
#verify = False
#for _ in range(5):
    #hi = size // 2
    #print('Size = {}, hi = {}, {} loops'.format(size, hi, loops))
    #seq = [randrange(hi) for _ in range(size)]
    #time_test(seq, loops, verify)
    #size <<= 1

#size, loops = 163840, 4
#verify = False
#for _ in range(3):
    #hi = size // 2
    #print('Size = {}, hi = {}, {} loops'.format(size, hi, loops))
    #seq = [randrange(hi) for _ in range(size)]
    #time_test(seq, loops, verify)
    #size <<= 1

请参阅这里的输出结果


谢谢,这很有趣 :) 清楚地展示了使用C模块——二分法的好处。 - Tim Babych
问题在于获胜者使用(理论上)二次算法。对于大小约为100,000的问题,它将被其它解法超越。我编辑了我的帖子,提供了一个几乎线性且速度快的Python解决方案。 - B. M.
@B.M. 当然可以,但是Tim的二分法在处理45000左右的数据时非常好用。我还有几个解决方案,明天或者后天会添加在这里。 - PM 2Ring
@TimBabych你是在说bisect是C语言吗?我非常确定它是Python。 - Stefan Pochmann
1
Python的bisect模块是用C编写的,参见 https://github.com/python/cpython/blob/master/Modules/_bisectmodule.c https://github.com/python/cpython/blob/master/Lib/bisect.py#L84 - Tim Babych
我在我的帖子中添加了一个基于快速排序的竞争性纯Python解决方案。 - B. M.

18

我曾经有一个类似的问题作业。我的限制是它必须具有O(nlogn)的效率。

我使用了你提出的使用归并排序的想法,因为它已经具有正确的效率。我只是在合并函数中插入了一些代码,基本上是这样的: 每当将右侧数组中的数字添加到输出数组时,我就将剩余左侧数组中的数字数量添加到逆序总数中。

现在我已经足够思考了,这对我来说非常有意义。你计算有多少次在任何数字之前出现更大的数字。

希望对你有所帮助。


7
我支持你的答案,与归并排序的根本区别在于合并函数,在将第二个右数组的元素复制到输出数组时,需要将逆序对计数器增加剩余的第一个左数组中的元素数量。 - Alex.Salnikov

11
在归并排序的合并过程中,可以通过分析找到逆序对的数量: merge process 当将第二个数组中的元素(例如这个例子中的9)复制到合并数组时,它相对于其他元素保持不变。当将第一个数组中的元素(这里是5)复制到合并数组时,它与所有留在第二个数组中的元素都是逆序的(与3和4共有2个逆序对)。因此,对归并排序进行一些小的修改即可在O(n ln n)的时间内解决问题。 例如,只需取消下面Python代码中的两行#注释即可进行计数。
def merge(l1,l2):
    l = []
    # global count
    while l1 and l2:
        if l1[-1] <= l2[-1]:
            l.append(l2.pop())
        else:
            l.append(l1.pop())
            # count += len(l2)
    l.reverse()
    return l1 + l2 + l

def sort(l): 
    t = len(l) // 2
    return merge(sort(l[:t]), sort(l[t:])) if t > 0 else l

count=0
print(sort([5,1,2,4,9,3]), count)
# [1, 2, 3, 4, 5, 9] 6

编辑1

使用稳定版本的快速排序也可以完成同样的任务,这被认为略微更快:

def part(l):
    pivot=l[-1]
    small,big = [],[]
    count = big_count = 0
    for x in l:
        if x <= pivot:
            small.append(x)
            count += big_count
        else:
            big.append(x)
            big_count += 1
    return count,small,big

def quick_count(l):
    if len(l)<2 : return 0
    count,small,big = part(l)
    small.pop()
    return count + quick_count(small) + quick_count(big)

选择最后一个元素作为枢轴,逆序对可以很好地计数,并且执行时间比上面的合并方法快40%。 编辑2 为了在Python中提高性能,可以使用numpy和numba版本:
首先是numpy部分,它使用argsort O(n ln n):
def count_inversions(a):
    n = a.size
    counts = np.arange(n) & -np.arange(n)  # The BIT
    ags = a.argsort(kind='mergesort')    
    return  BIT(ags,counts,n)

关于高效的BIT算法,其中的numba部分:

@numba.njit
def BIT(ags,counts,n):
    res = 0        
    for x in ags :
        i = x
        while i:
            res += counts[i]
            i -= i & -i
        i = x+1
        while i < n:
            counts[i] -= 1
            i += i & -i
    return  res  

我已经发布了一个答案,它使用timeit比较了所有Python回答这个问题的代码,所以包括你的代码。你可能会对时间结果感兴趣。 - PM 2Ring
这篇文章没有性能问题... 我会在一段时间内尝试。可以使用Numpy numba吗? - B. M.
我从未使用过Numba,但我有一点使用Numpy的经验,并考虑自己添加一个Numpy版本,但我决定将测试限制在仅使用标准库的解决方案上。但我想看看Numpy解决方案的性能如何比较有趣。我怀疑它在小列表上不会更快。 - PM 2Ring
100倍的加速真是令人印象深刻!但我无法运行它,因为我没有安装Numba。正如我之前所说,将其包含在我的“timeit”集合中是不公平的。 - PM 2Ring

8
请注意,Geoffrey Irving的答案是错误的。
一个数组中逆序对的数量是排序该数组所需移动元素总距离的一半。因此,可以通过对数组进行排序,维护得到的排列p[i],然后计算abs(p[i]-i)/2的和来计算它。这需要O(n log n)时间,这是最优的。
另一种方法在http://mathworld.wolfram.com/PermutationInversion.html给出。这种方法等价于max(0, p[i]-i)的和,这与abs(p[i]-i])/2的和相等,因为元素向左移动的总距离等于元素向右移动的总距离。
以{3, 2, 1}序列为例。有三个逆序对:(3, 2), (3, 1), (2, 1),所以逆序对数为3。然而,根据引用的方法,答案应该是2。

正确答案可以通过计算最少需要多少相邻交换来找到。请参阅讨论:https://dev59.com/U2Ei5IYBdhLWcg3wwecN - Isaac Turner

5

4

这里有一个使用二叉树变体的可能解决方案。它向每个树节点添加了一个称为 rightSubTreeSize 的字段。按照它们在数组中出现的顺序将数字插入二叉树中。如果数字进入节点的lhs,则该元素的逆序对数将是(1 + rightSubTreeSize)。因为所有这些元素都大于当前元素,并且它们应该先出现在数组中。如果元素进入节点的rhs,则只需增加其 rightSubTreeSize。以下是代码。

Node { 
    int data;
    Node* left, *right;
    int rightSubTreeSize;

    Node(int data) { 
        rightSubTreeSize = 0;
    }   
};

Node* root = null;
int totCnt = 0;
for(i = 0; i < n; ++i) { 
    Node* p = new Node(a[i]);
    if(root == null) { 
        root = p;
        continue;
    } 

    Node* q = root;
    int curCnt = 0;
    while(q) { 
        if(p->data <= q->data) { 
            curCnt += 1 + q->rightSubTreeSize;
            if(q->left) { 
                q = q->left;
            } else { 
                q->left = p;
                break;
            }
        } else { 
            q->rightSubTreeSize++;
            if(q->right) { 
                q = q->right;
            } else { 
                q->right = p;
                break;
            }
        }
    }

    totCnt += curCnt;
  }
  return totCnt;

这是一种有趣的方法,看起来非常快。然而,比较需要是 if(p->data < q->data),否则重复项无法正确处理。并且在循环顶部没有必要测试 q,无条件的 while 循环可以正常工作。另外,您忘记提到这是什么语言了。 :) 而且您的函数似乎丢失了头行。 - PM 2Ring
我刚刚根据你的树算法添加了一个Python版本到我的答案中。当然,相对于其他Python版本,它并不像完全编译的版本那样快,但表现还是相当不错的。 - PM 2Ring

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接