A[1... n]
,对于每个i < j
,找到所有逆序对使得A[i] > A[j]
。我使用归并排序并将数组A复制到数组B,然后比较这两个数组,但我很难看出如何使用它来找到逆序对的数量。任何提示或帮助将不胜感激。A[1... n]
,对于每个i < j
,找到所有逆序对使得A[i] > A[j]
。我使用归并排序并将数组A复制到数组B,然后比较这两个数组,但我很难看出如何使用它来找到逆序对的数量。任何提示或帮助将不胜感激。所以,这里提供了一个Java语言的O(n log n)解决方案。
long merge(int[] arr, int[] left, int[] right) {
int i = 0, j = 0;
long count = 0;
while (i < left.length || j < right.length) {
if (i == left.length) {
arr[i+j] = right[j];
j++;
} else if (j == right.length) {
arr[i+j] = left[i];
i++;
} else if (left[i] <= right[j]) {
arr[i+j] = left[i];
i++;
} else {
arr[i+j] = right[j];
count += left.length-i;
j++;
}
}
return count;
}
long invCount(int[] arr) {
if (arr.length < 2)
return 0;
int m = (arr.length + 1) / 2;
int left[] = Arrays.copyOfRange(arr, 0, m);
int right[] = Arrays.copyOfRange(arr, m, arr.length);
return invCount(left) + invCount(right) + merge(arr, left, right);
}
这几乎是标准的归并排序,整个关键在于合并函数中。 请注意,在排序时,算法会消除逆序对。 在合并时,算法计算出已消除逆序对的数量(可以说是已排序的数量)。
唯一一个消除逆序对的时刻是当算法从数组的右侧取出元素并将其合并到主数组时。 此操作所消除的逆序对数量等于尚未合并的左侧数组中元素的数量。
希望这已经足够解释了。
我通过以下方法在O(n * log n)时间内找到了它。
取出A[1]并通过二分查找找到其在排序数组B中的位置。该元素的逆序对数将比它在B中的索引号少1,因为A的第一个元素后面出现的每个较小的数字都将是一个逆序对。
2a. 将逆序对的数量累加到计数器变量num_inversions中。
2b. 从数组A和相应的位置数组B中删除A[1]
这里是此算法的一个例子。原始数组A = (6, 9, 1, 14, 8, 12, 3, 2)
1:合并排序并复制到数组B
B = (1, 2, 3, 6, 8, 9, 12, 14)
2:取出A[1]并二分查找在数组B中找到它
A[1] = 6
B = (1, 2, 3, 6, 8, 9, 12, 14)
6在数组B的第4个位置上,因此有3个逆序对。我们之所以知道这一点,是因为6在数组A的第一个位置上,在此之后出现的任何较小值元素都将具有j>i的索引(因为在这种情况下i为1)。
2.b:从数组A和相应的位置数组B中删除A[1](删除的元素已加粗显示)。
A = (6, 9, 1, 14, 8, 12, 3, 2) = (9, 1, 14, 8, 12, 3, 2)
B = (1, 2, 3, 6, 8, 9, 12, 14) = (1, 2, 3, 8, 9, 12, 14)
3: 在新的A和B数组上重新执行第2步。
A[1] = 9
B = (1, 2, 3, 8, 9, 12, 14)
现在9在B数组中的第5个位置,因此有4个逆序对。我们知道这是因为9在A数组中的第一个位置,因此任何后来出现的更小的元素会有一个大于i的索引j(因为在这种情况下i还是1)。 从数组A和其对应的B中删除A [1](加粗的元素被删除)
A = (9, 1, 14, 8, 12, 3, 2) = (1, 14, 8, 12, 3, 2)
B = (1, 2, 3, 8, 9, 12, 14) = (1, 2, 3, 8, 12, 14)
继续这样做,直到循环完成,我们将得到数组A的总逆序对数。
执行步骤1(归并排序)需要O(n * log n)的时间。 步骤2将执行n次,每次执行将执行一个二进制搜索,需要O(log n)的时间,总共需要O(n * log n)。因此,总运行时间为O(n * log n) + O(n * log n) = O(n * log n)。
感谢您的帮助。在纸上写出示例数组确实有助于可视化问题。
我想知道为什么还没有人提到二叉索引树。你可以使用它来维护排列元素值的前缀和。然后,你可以从右到左遍历,计算每个元素右边比它小的元素数量:
def count_inversions(a):
res = 0
counts = [0]*(len(a)+1)
rank = { v : i+1 for i, v in enumerate(sorted(a)) }
for x in reversed(a):
i = rank[x] - 1
while i:
res += counts[i]
i -= i & -i
i = rank[x]
while i <= len(a):
counts[i] += 1
i += i & -i
return res
复杂度为O(n log n),而且常数非常低。
i -= i & -i
这行代码的意思是什么?类似地,i += i & -i
是什么意思? - Gerard Condontimeit
比较,因此包括您的代码。您可能会对时间结果感兴趣。 - PM 2Ringsort
排序来计算逆序对感觉有点像作弊;-) 当然,需要排序和存储映射意味着更多的速度复杂性。我发现我们可以通过根据源数组中的最大元素调整BIT的大小,并稍微改变实现方式来以空间复杂度为代价换取时间复杂度;请参见GeeksForGeeks上的这里。尽管如此,非常感谢您提供的宝贵指导,正如所说,这是一个非常有用的简明介绍! - underscore_d在Python中
# O(n log n)
def count_inversion(lst):
return merge_count_inversion(lst)[1]
def merge_count_inversion(lst):
if len(lst) <= 1:
return lst, 0
middle = int( len(lst) / 2 )
left, a = merge_count_inversion(lst[:middle])
right, b = merge_count_inversion(lst[middle:])
result, c = merge_count_split_inversion(left, right)
return result, (a + b + c)
def merge_count_split_inversion(left, right):
result = []
count = 0
i, j = 0, 0
left_len = len(left)
while i < left_len and j < len(right):
if left[i] <= right[j]:
result.append(left[i])
i += 1
else:
result.append(right[j])
count += left_len - i
j += 1
result += left[i:]
result += right[j:]
return result, count
#test code
input_array_1 = [] #0
input_array_2 = [1] #0
input_array_3 = [1, 5] #0
input_array_4 = [4, 1] #1
input_array_5 = [4, 1, 2, 3, 9] #3
input_array_6 = [4, 1, 3, 2, 9, 5] #5
input_array_7 = [4, 1, 3, 2, 9, 1] #8
print count_inversion(input_array_1)
print count_inversion(input_array_2)
print count_inversion(input_array_3)
print count_inversion(input_array_4)
print count_inversion(input_array_5)
print count_inversion(input_array_6)
print count_inversion(input_array_7)
在Niklas B发布的答案中,更加有趣的解决方案之一是使用内置排序来确定数组项的排名,并使用二进制索引树(又称Fenwick树)存储计算逆序对所需的累积和。在尝试理解这个数据结构和Niklas的算法的过程中,我写了几个自己的变体(如下所示)。但我也发现,对于中等大小的列表,实际上使用Python内置的sum
函数比可爱的Fenwick树更快。
def count_inversions(a):
total = 0
counts = [0] * len(a)
rank = {v: i for i, v in enumerate(sorted(a))}
for u in reversed(a):
i = rank[u]
total += sum(counts[:i])
counts[i] += 1
return total
range(b**m)
中的整数排列,其中b
通常为2。在尝试阅读计算排列中“逆序对”的数量中提到的计数逆序对、离线正交范围计数和相关问题后,我添加了几个基于基数排序的版本。seq
中的逆序对,我们可以创建一个由seq
具有相同逆序对数量的range(n)
的排列。通过TimSort,我们可以在(最坏情况下)O(nlogn)时间内完成这些操作。关键是通过对seq
进行排序来排列其索引。通过一个小例子更容易解释这个技巧。seq = [15, 14, 11, 12, 10, 13]
b = [t[::-1] for t in enumerate(seq)]
print(b)
b.sort()
print(b)
输出
[(15, 0), (14, 1), (11, 2), (12, 3), (10, 4), (13, 5)]
[(10, 4), (11, 2), (12, 3), (13, 5), (14, 1), (15, 0)]
seq
的(值,索引)对进行排序,我们已经通过相同数量的交换来重新排列了seq
的索引,以将其从排序顺序恢复到原始顺序。我们可以通过使用适当的键函数对range(n)
进行排序来创建该排列:print(sorted(range(len(seq)), key=lambda k: seq[k]))
输出
[4, 2, 3, 5, 1, 0]
sorted(range(len(seq)), key=seq.__getitem__)
timeit
测试,以及我自己写的一些算法:几个暴力O(n²)版本、一些基于Niklas B算法的变体,当然还有一个基于归并排序的(我没有参考现有答案而写的)。它还包括我的基于列表的树排序代码,大致派生自prasadvk的代码,并且基于基数排序的各种函数,其中一些使用与归并排序方法类似的策略,一些使用sum
或Fenwick树。timeit
调用都会返回一个包含3个结果的向量,我将其排序。在这里要关注的主要值是最小值,其他值仅表示该最小值的可靠性如何,正如timeit
模块文档中的注释所讨论的那样。然而,对于非常大的列表大小,基于二分法的算法无法与真正的O(nlogn)算法竞争。
#!/usr/bin/env python3
''' Test speeds of various ways of counting inversions in a list
The inversion count is a measure of how sorted an array is.
A pair of items in a are inverted if i < j but a[j] > a[i]
See https://dev59.com/RXRC5IYBdhLWcg3wUvQS
This program contains code by the following authors:
mkso
Niklas B
B. M.
Tim Babych
python
Zhe Hu
prasadvk
noman pouigt
PM 2Ring
Timing and verification code by PM 2Ring
Collated 2017.12.16
Updated 2017.12.21
'''
from timeit import Timer
from random import seed, randrange
from bisect import bisect, insort_left
seed('A random seed string')
# Merge sort version by mkso
def count_inversion_mkso(lst):
return merge_count_inversion(lst)[1]
def merge_count_inversion(lst):
if len(lst) <= 1:
return lst, 0
middle = len(lst) // 2
left, a = merge_count_inversion(lst[:middle])
right, b = merge_count_inversion(lst[middle:])
result, c = merge_count_split_inversion(left, right)
return result, (a + b + c)
def merge_count_split_inversion(left, right):
result = []
count = 0
i, j = 0, 0
left_len = len(left)
while i < left_len and j < len(right):
if left[i] <= right[j]:
result.append(left[i])
i += 1
else:
result.append(right[j])
count += left_len - i
j += 1
result += left[i:]
result += right[j:]
return result, count
# . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
# Using a Binary Indexed Tree, aka a Fenwick tree, by Niklas B.
def count_inversions_NiklasB(a):
res = 0
counts = [0] * (len(a) + 1)
rank = {v: i for i, v in enumerate(sorted(a), 1)}
for x in reversed(a):
i = rank[x] - 1
while i:
res += counts[i]
i -= i & -i
i = rank[x]
while i <= len(a):
counts[i] += 1
i += i & -i
return res
# . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
# Merge sort version by B.M
# Modified by PM 2Ring to deal with the global counter
bm_count = 0
def merge_count_BM(seq):
global bm_count
bm_count = 0
sort_bm(seq)
return bm_count
def merge_bm(l1,l2):
global bm_count
l = []
while l1 and l2:
if l1[-1] <= l2[-1]:
l.append(l2.pop())
else:
l.append(l1.pop())
bm_count += len(l2)
l.reverse()
return l1 + l2 + l
def sort_bm(l):
t = len(l) // 2
return merge_bm(sort_bm(l[:t]), sort_bm(l[t:])) if t > 0 else l
# . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
# Bisection based method by Tim Babych
def solution_TimBabych(A):
sorted_left = []
res = 0
for i in range(1, len(A)):
insort_left(sorted_left, A[i-1])
# i is also the length of sorted_left
res += (i - bisect(sorted_left, A[i]))
return res
# Slightly faster, except for very small lists
def solutionE_TimBabych(A):
res = 0
sorted_left = []
for i, u in enumerate(A):
# i is also the length of sorted_left
res += (i - bisect(sorted_left, u))
insort_left(sorted_left, u)
return res
# . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
# Bisection based method by "python"
def solution_python(A):
B = list(A)
B.sort()
inversion_count = 0
for i in range(len(A)):
j = binarySearch_python(B, A[i])
while B[j] == B[j - 1]:
if j < 1:
break
j -= 1
inversion_count += j
B.pop(j)
return inversion_count
def binarySearch_python(alist, item):
first = 0
last = len(alist) - 1
found = False
while first <= last and not found:
midpoint = (first + last) // 2
if alist[midpoint] == item:
return midpoint
else:
if item < alist[midpoint]:
last = midpoint - 1
else:
first = midpoint + 1
# . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
# Merge sort version by Zhe Hu
def inv_cnt_ZheHu(a):
_, count = inv_cnt(a.copy())
return count
def inv_cnt(a):
n = len(a)
if n==1:
return a, 0
left = a[0:n//2] # should be smaller
left, cnt1 = inv_cnt(left)
right = a[n//2:] # should be larger
right, cnt2 = inv_cnt(right)
cnt = 0
i_left = i_right = i_a = 0
while i_a < n:
if (i_right>=len(right)) or (i_left < len(left)
and left[i_left] <= right[i_right]):
a[i_a] = left[i_left]
i_left += 1
else:
a[i_a] = right[i_right]
i_right += 1
if i_left < len(left):
cnt += len(left) - i_left
i_a += 1
return (a, cnt1 + cnt2 + cnt)
# . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
# Merge sort version by noman pouigt
# From https://stackoverflow.com/q/47830098
def reversePairs_nomanpouigt(nums):
def merge(left, right):
if not left or not right:
return (0, left + right)
#if everything in left is less than right
if left[len(left)-1] < right[0]:
return (0, left + right)
else:
left_idx, right_idx, count = 0, 0, 0
merged_output = []
# check for condition before we merge it
while left_idx < len(left) and right_idx < len(right):
#if left[left_idx] > 2 * right[right_idx]:
if left[left_idx] > right[right_idx]:
count += len(left) - left_idx
right_idx += 1
else:
left_idx += 1
#merging the sorted list
left_idx, right_idx = 0, 0
while left_idx < len(left) and right_idx < len(right):
if left[left_idx] > right[right_idx]:
merged_output += [right[right_idx]]
right_idx += 1
else:
merged_output += [left[left_idx]]
left_idx += 1
if left_idx == len(left):
merged_output += right[right_idx:]
else:
merged_output += left[left_idx:]
return (count, merged_output)
def partition(nums):
count = 0
if len(nums) == 1 or not nums:
return (0, nums)
pivot = len(nums)//2
left_count, l = partition(nums[:pivot])
right_count, r = partition(nums[pivot:])
temp_count, temp_list = merge(l, r)
return (temp_count + left_count + right_count, temp_list)
return partition(nums)[0]
# . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
# PM 2Ring
def merge_PM2R(seq):
seq, count = merge_sort_count_PM2R(seq)
return count
def merge_sort_count_PM2R(seq):
mid = len(seq) // 2
if mid == 0:
return seq, 0
left, left_total = merge_sort_count_PM2R(seq[:mid])
right, right_total = merge_sort_count_PM2R(seq[mid:])
total = left_total + right_total
result = []
i = j = 0
left_len, right_len = len(left), len(right)
while i < left_len and j < right_len:
if left[i] <= right[j]:
result.append(left[i])
i += 1
else:
result.append(right[j])
j += 1
total += left_len - i
result.extend(left[i:])
result.extend(right[j:])
return result, total
def rank_sum_PM2R(a):
total = 0
counts = [0] * len(a)
rank = {v: i for i, v in enumerate(sorted(a))}
for u in reversed(a):
i = rank[u]
total += sum(counts[:i])
counts[i] += 1
return total
# Fenwick tree functions adapted from C code on Wikipedia
def fen_sum(tree, i):
''' Return the sum of the first i elements, 0 through i-1 '''
total = 0
while i:
total += tree[i-1]
i -= i & -i
return total
def fen_add(tree, delta, i):
''' Add delta to element i and thus
to fen_sum(tree, j) for all j > i
'''
size = len(tree)
while i < size:
tree[i] += delta
i += (i+1) & -(i+1)
def fenwick_PM2R(a):
total = 0
counts = [0] * len(a)
rank = {v: i for i, v in enumerate(sorted(a))}
for u in reversed(a):
i = rank[u]
total += fen_sum(counts, i)
fen_add(counts, 1, i)
return total
def fenwick_inline_PM2R(a):
total = 0
size = len(a)
counts = [0] * size
rank = {v: i for i, v in enumerate(sorted(a))}
for u in reversed(a):
i = rank[u]
j = i + 1
while i:
total += counts[i]
i -= i & -i
while j < size:
counts[j] += 1
j += j & -j
return total
def bruteforce_loops_PM2R(a):
total = 0
for i in range(1, len(a)):
u = a[i]
for j in range(i):
if a[j] > u:
total += 1
return total
def bruteforce_sum_PM2R(a):
return sum(1 for i in range(1, len(a)) for j in range(i) if a[j] > a[i])
# Using binary tree counting, derived from C++ code (?) by prasadvk
# https://dev59.com/RXRC5IYBdhLWcg3wUvQS#16056139
def ltree_count_PM2R(a):
total, root = 0, None
for u in a:
# Store data in a list-based tree structure
# [data, count, left_child, right_child]
p = [u, 0, None, None]
if root is None:
root = p
continue
q = root
while True:
if p[0] < q[0]:
total += 1 + q[1]
child = 2
else:
q[1] += 1
child = 3
if q[child]:
q = q[child]
else:
q[child] = p
break
return total
# Counting based on radix sort, recursive version
def radix_partition_rec(a, L):
if len(a) < 2:
return 0
if len(a) == 2:
return a[1] < a[0]
left, right = [], []
count = 0
for u in a:
if u & L:
right.append(u)
else:
count += len(right)
left.append(u)
L >>= 1
if L:
count += radix_partition_rec(left, L) + radix_partition_rec(right, L)
return count
# The following functions determine swaps using a permutation of
# range(len(a)) that has the same inversion count as `a`. We can create
# this permutation with `sorted(range(len(a)), key=lambda k: a[k])`
# but `sorted(range(len(a)), key=a.__getitem__)` is a little faster.
# Counting based on radix sort, iterative version
def radix_partition_iter(seq, L):
count = 0
parts = [seq]
while L and parts:
newparts = []
for a in parts:
if len(a) < 2:
continue
if len(a) == 2:
count += a[1] < a[0]
continue
left, right = [], []
for u in a:
if u & L:
right.append(u)
else:
count += len(right)
left.append(u)
if left:
newparts.append(left)
if right:
newparts.append(right)
parts = newparts
L >>= 1
return count
def perm_radixR_PM2R(a):
size = len(a)
b = sorted(range(size), key=a.__getitem__)
n = size.bit_length() - 1
return radix_partition_rec(b, 1 << n)
def perm_radixI_PM2R(a):
size = len(a)
b = sorted(range(size), key=a.__getitem__)
n = size.bit_length() - 1
return radix_partition_iter(b, 1 << n)
# Plain sum of the counts of the permutation
def perm_sum_PM2R(a):
total = 0
size = len(a)
counts = [0] * size
for i in reversed(sorted(range(size), key=a.__getitem__)):
total += sum(counts[:i])
counts[i] = 1
return total
# Fenwick sum of the counts of the permutation
def perm_fenwick_PM2R(a):
total = 0
size = len(a)
counts = [0] * size
for i in reversed(sorted(range(size), key=a.__getitem__)):
j = i + 1
while i:
total += counts[i]
i -= i & -i
while j < size:
counts[j] += 1
j += j & -j
return total
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# All the inversion-counting functions
funcs = (
solution_TimBabych,
solutionE_TimBabych,
solution_python,
count_inversion_mkso,
count_inversions_NiklasB,
merge_count_BM,
inv_cnt_ZheHu,
reversePairs_nomanpouigt,
fenwick_PM2R,
fenwick_inline_PM2R,
merge_PM2R,
rank_sum_PM2R,
bruteforce_loops_PM2R,
bruteforce_sum_PM2R,
ltree_count_PM2R,
perm_radixR_PM2R,
perm_radixI_PM2R,
perm_sum_PM2R,
perm_fenwick_PM2R,
)
def time_test(seq, loops, verify=False):
orig = seq
timings = []
for func in funcs:
seq = orig.copy()
value = func(seq) if verify else None
t = Timer(lambda: func(seq))
result = sorted(t.repeat(3, loops))
timings.append((result, func.__name__, value))
assert seq==orig, 'Sequence altered by {}!'.format(func.__name__)
first = timings[0][-1]
timings.sort()
for result, name, value in timings:
result = ', '.join([format(u, '.5f') for u in result])
print('{:24} : {}'.format(name, result))
if verify:
# Check that all results are identical
bad = ['%s: %d' % (name, value)
for _, name, value in timings if value != first]
if bad:
print('ERROR. Value: {}, bad: {}'.format(first, ', '.join(bad)))
else:
print('Value: {}'.format(first))
print()
#Run the tests
size, loops = 5, 1 << 12
verify = True
for _ in range(7):
hi = size // 2
print('Size = {}, hi = {}, {} loops'.format(size, hi, loops))
seq = [randrange(hi) for _ in range(size)]
time_test(seq, loops, verify)
loops >>= 1
size <<= 1
#size, loops = 640, 8
#verify = False
#for _ in range(5):
#hi = size // 2
#print('Size = {}, hi = {}, {} loops'.format(size, hi, loops))
#seq = [randrange(hi) for _ in range(size)]
#time_test(seq, loops, verify)
#size <<= 1
#size, loops = 163840, 4
#verify = False
#for _ in range(3):
#hi = size // 2
#print('Size = {}, hi = {}, {} loops'.format(size, hi, loops))
#seq = [randrange(hi) for _ in range(size)]
#time_test(seq, loops, verify)
#size <<= 1
bisect
是C语言吗?我非常确定它是Python。 - Stefan Pochmann我曾经有一个类似的问题作业。我的限制是它必须具有O(nlogn)的效率。
我使用了你提出的使用归并排序的想法,因为它已经具有正确的效率。我只是在合并函数中插入了一些代码,基本上是这样的: 每当将右侧数组中的数字添加到输出数组时,我就将剩余左侧数组中的数字数量添加到逆序总数中。
现在我已经足够思考了,这对我来说非常有意义。你计算有多少次在任何数字之前出现更大的数字。
希望对你有所帮助。
def merge(l1,l2):
l = []
# global count
while l1 and l2:
if l1[-1] <= l2[-1]:
l.append(l2.pop())
else:
l.append(l1.pop())
# count += len(l2)
l.reverse()
return l1 + l2 + l
def sort(l):
t = len(l) // 2
return merge(sort(l[:t]), sort(l[t:])) if t > 0 else l
count=0
print(sort([5,1,2,4,9,3]), count)
# [1, 2, 3, 4, 5, 9] 6
编辑1
使用稳定版本的快速排序也可以完成同样的任务,这被认为略微更快:
def part(l):
pivot=l[-1]
small,big = [],[]
count = big_count = 0
for x in l:
if x <= pivot:
small.append(x)
count += big_count
else:
big.append(x)
big_count += 1
return count,small,big
def quick_count(l):
if len(l)<2 : return 0
count,small,big = part(l)
small.pop()
return count + quick_count(small) + quick_count(big)
def count_inversions(a):
n = a.size
counts = np.arange(n) & -np.arange(n) # The BIT
ags = a.argsort(kind='mergesort')
return BIT(ags,counts,n)
关于高效的BIT算法,其中的numba部分:
@numba.njit
def BIT(ags,counts,n):
res = 0
for x in ags :
i = x
while i:
res += counts[i]
i -= i & -i
i = x+1
while i < n:
counts[i] -= 1
i += i & -i
return res
请看这个链接:http://www.cs.jhu.edu/~xfliu/600.363_F03/hw_solution/solution1.pdf
希望它能给你正确的答案。
这里有一个使用二叉树变体的可能解决方案。它向每个树节点添加了一个称为 rightSubTreeSize 的字段。按照它们在数组中出现的顺序将数字插入二叉树中。如果数字进入节点的lhs,则该元素的逆序对数将是(1 + rightSubTreeSize)。因为所有这些元素都大于当前元素,并且它们应该先出现在数组中。如果元素进入节点的rhs,则只需增加其 rightSubTreeSize。以下是代码。
Node {
int data;
Node* left, *right;
int rightSubTreeSize;
Node(int data) {
rightSubTreeSize = 0;
}
};
Node* root = null;
int totCnt = 0;
for(i = 0; i < n; ++i) {
Node* p = new Node(a[i]);
if(root == null) {
root = p;
continue;
}
Node* q = root;
int curCnt = 0;
while(q) {
if(p->data <= q->data) {
curCnt += 1 + q->rightSubTreeSize;
if(q->left) {
q = q->left;
} else {
q->left = p;
break;
}
} else {
q->rightSubTreeSize++;
if(q->right) {
q = q->right;
} else {
q->right = p;
break;
}
}
}
totCnt += curCnt;
}
return totCnt;
if(p->data < q->data)
,否则重复项无法正确处理。并且在循环顶部没有必要测试 q
,无条件的 while
循环可以正常工作。另外,您忘记提到这是什么语言了。 :) 而且您的函数似乎丢失了头行。 - PM 2Ring
left.length - i
添加到反转计数器中是什么意思?我认为只需要加1就可以了,因为你已经进入了左子数组元素比右子数组元素大的逻辑情况。有人能像我5岁时那样解释一下吗? - Alfredo Gallegos