使用Python实现归并排序

47

我找不到任何可行的Python 3.3归并排序算法代码,所以我自己写了一个。有没有办法让它更快?它可以在大约0.3-0.5秒内对20,000个数字进行排序。

def msort(x):
    result = []
    if len(x) < 2:
        return x
    mid = int(len(x)/2)
    y = msort(x[:mid])
    z = msort(x[mid:])
    while (len(y) > 0) or (len(z) > 0):
        if len(y) > 0 and len(z) > 0:
            if y[0] > z[0]:
                result.append(z[0])
                z.pop(0)
            else:
                result.append(y[0])
                y.pop(0)
        elif len(z) > 0:
            for i in z:
                result.append(i)
                z.pop(0)
        else:
            for i in y:
                result.append(i)
                y.pop(0)
    return result

9
在迭代列表时,不应该使用 pop 操作来从列表中弹出元素,因为这会导致数组元素的不必要移动。在迭代列表时,应该避免修改列表本身。 - poke
1
此外,在普通的归并排序实现中,可能没有针对Python 3.3的特定内容,因此您可以在Google上搜索“python mergesort”,并使用任何您找到的实现,即使它是针对旧版本的。例如,这个网址:http://www.geekviewpoint.com/python/sorting/mergesort - Tamás
这个问题可能有点老了,但是使用归并排序合并结果数组会不会占用更多的内存呢?因为归并排序需要使用两倍于原数组的内存来进行排序,而我们又要在结果中再次生成该数组。 - siddhesh
32个回答

77

第一个改进是简化主循环中的三种情况:不是在一些序列中迭代,而是在两个序列都有元素的情况下进行迭代。当离开循环时,它们中的一个将为空,我们不知道是哪一个,但我们不关心:我们将它们附加到结果的末尾。

def msort2(x):
    if len(x) < 2:
        return x
    result = []          # moved!
    mid = int(len(x) / 2)
    y = msort2(x[:mid])
    z = msort2(x[mid:])
    while (len(y) > 0) and (len(z) > 0):
        if y[0] > z[0]:
            result.append(z[0])
            z.pop(0)
        else:
            result.append(y[0])
            y.pop(0)
    result += y
    result += z
    return result
第二个优化是避免 pop 弹出元素。相反,使用两个索引:
def msort3(x):
    if len(x) < 2:
        return x
    result = []
    mid = int(len(x) / 2)
    y = msort3(x[:mid])
    z = msort3(x[mid:])
    i = 0
    j = 0
    while i < len(y) and j < len(z):
        if y[i] > z[j]:
            result.append(z[j])
            j += 1
        else:
            result.append(y[i])
            i += 1
    result += y[i:]
    result += z[j:]
    return result

最后一个改进是使用非递归算法对短序列进行排序。在这种情况下,我使用内置的sorted函数,并在输入大小小于20时使用它:

def msort4(x):
    if len(x) < 20:
        return sorted(x)
    result = []
    mid = int(len(x) / 2)
    y = msort4(x[:mid])
    z = msort4(x[mid:])
    i = 0
    j = 0
    while i < len(y) and j < len(z):
        if y[i] > z[j]:
            result.append(z[j])
            j += 1
        else:
            result.append(y[i])
            i += 1
    result += y[i:]
    result += z[j:]
    return result

我对于随机列表中的 100000 个整数进行排序的测量结果如下:原始版本需要 2.46 秒,msort2 需要 2.33 秒,msort3 需要 0.60 秒,msort4 需要 0.40 秒。参考一下,使用 sorted 对整个列表进行排序只需要 0.03 秒。


60
使用 sorted() 感觉像是作弊。 - simonzack
我在Python 2.7.6中尝试了您的msort3方法,但是出现了以下错误-Traceback(最近的调用最后):File“mergesort.py”,第21行,在<module>中msort3([5,24,87,55,32,1,45]);File“mergesort.py”,第6行,在msort3中y = msort3(x [:mid])File“mergesort.py”,第10行,在msort3中i <len(y)and j <len(z)while i <len(y)and j <len(z):TypeError:类型为'NoneType'的对象没有len() - Abhishek Prakash
我在Python 3.4.0中尝试了相同的msort3方法,但出现了以下错误 - [24, 87] Traceback (most recent call last): File "mergesort.py", line 21, in <module> msort3([5,24, 87, 55, 32, 1, 45]); File "mergesort.py", line 6, in msort3 y = msort3(x[:mid]) File "mergesort.py", line 10, in msort3 while i < len(y) and j < len(z): TypeError: object of type 'NoneType' has no len() - Abhishek Prakash
@AbhishekPrakash:我无法在Python 2.7.5中重现错误。稍后会在另一台机器上尝试。return语句是否书写良好? - anumi
2
@AbhishekPrakash:我在Python 2.7.6和Python 3.4.0(Ubuntu 14.04)下运行了您的测试,没有出现任何问题。如果您使用的是print而不是return,则函数将返回None(因为没有找到返回值),并且会中断递归。 - anumi

29

来自麻省理工学院课程的代码。(带有通用合作者)

import operator


def merge(left, right, compare):
    result = []
    i, j = 0, 0
    while i < len(left) and j < len(right):
        if compare(left[i], right[j]):
            result.append(left[i])
            i += 1
        else:
            result.append(right[j])
            j += 1
    while i < len(left):
        result.append(left[i])
        i += 1
    while j < len(right):
        result.append(right[j])
        j += 1
    return result


def mergeSort(L, compare=operator.lt):
    if len(L) < 2:
        return L[:]
    else:
        middle = int(len(L) / 2)
        left = mergeSort(L[:middle], compare)
        right = mergeSort(L[middle:], compare)
        return merge(left, right, compare)

1
当我们跳出第一个while循环后,我们可以执行以下操作: 如果len(left) == i: result.extend(right[j:]) 否则: result.extend(left[i:]) - Kishan Mehta

21
def merge_sort(x):

    if len(x) < 2:return x

    result,mid = [],int(len(x)/2)

    y = merge_sort(x[:mid])
    z = merge_sort(x[mid:])

    while (len(y) > 0) and (len(z) > 0):
            if y[0] > z[0]:result.append(z.pop(0))   
            else:result.append(y.pop(0))

    result.extend(y+z)
    return result

你正在创建一个新列表而不是修改原始列表...这不是一个好主意! - NoobEditor
非常简约的方法,但使用extend()无法展示合并的概念/算法……我的意思是,没有合并算法实现的归并排序算法是什么! - grepit

16
您可以在归并排序的顶层调用中初始化整个结果列表:
result = [0]*len(x)   # replace 0 with a suitable default element if necessary. 
                      # or just copy x (result = x[:])

然后对于递归调用,您可以使用一个辅助函数,将索引而不是子列表传递给它。底层调用从x中读取值并直接写入result

这样,您就可以避免所有的popappend操作,从而提高性能。


13

采用我的实现

def merge_sort(sequence):
    """
    Sequence of numbers is taken as input, and is split into two halves, following which they are recursively sorted.
    """
    if len(sequence) < 2:
        return sequence

    mid = len(sequence) // 2     # note: 7//2 = 3, whereas 7/2 = 3.5

    left_sequence = merge_sort(sequence[:mid])
    right_sequence = merge_sort(sequence[mid:])

    return merge(left_sequence, right_sequence)

def merge(left, right):
    """
    Traverse both sorted sub-arrays (left and right), and populate the result array
    """
    result = []
    i = j = 0
    while i < len(left) and j < len(right):
        if left[i] < right[j]:
            result.append(left[i])
            i += 1
        else:
            result.append(right[j])
            j += 1
    result += left[i:]
    result += right[j:]

    return result

# Print the sorted list.
print(merge_sort([5, 2, 6, 8, 5, 8, 1]))

返回错误:切片索引必须是整数、无或具有 index 方法 - Claudiu Creanga
1
在Python 2.7.5上运行良好。 - Dimitri W
这是Tim Roughgarden的《算法点亮》一书的实现。 - user9652688
如何将值按顺序保存,而不是创建一个名为“result”的新列表? - user

7

如前所述,l.pop(0) 是一个 O(len(l)) 的操作,必须避免使用,上述的msort函数是O(n**2)的。如果效率很重要,索引更好但也有成本。for x in l 更快,但对于归并排序来说不易实现:可以在这里使用iter。最后,检查i < len(l) 会被测试两次,因为在访问元素时再次进行了测试:异常机制(try except)更好,并且可以提高30%的性能。

def msort(l):
    if len(l)>1:
        t=len(l)//2
        it1=iter(msort(l[:t]));x1=next(it1)
        it2=iter(msort(l[t:]));x2=next(it2)
        l=[]
        try:
            while True:
                if x1<=x2: l.append(x1);x1=next(it1)
                else     : l.append(x2);x2=next(it2)
        except:
            if x1<=x2: l.append(x2);l.extend(it2)
            else:      l.append(x1);l.extend(it1)
    return l

6
循环类似这样的代码可能可以加速:
for i in z:
    result.append(i)
    z.pop(0)

相反,只需执行以下操作:
result.extend(z)

请注意,无需清理z的内容,因为您不会使用它。

5

这是一个更长的计算逆序对并符合sorted接口的算法。很容易将其修改为就地排序对象的方法。

import operator

class MergeSorted:

    def __init__(self):
        self.inversions = 0

    def __call__(self, l, key=None, reverse=False):

        self.inversions = 0

        if key is None:
            self.key = lambda x: x
        else:
            self.key = key

        if reverse:
            self.compare = operator.gt
        else:
            self.compare = operator.lt

        dest = list(l)
        working = [0] * len(l)
        self.inversions = self._merge_sort(dest, working, 0, len(dest))
        return dest

    def _merge_sort(self, dest, working, low, high):
        if low < high - 1:
            mid = (low + high) // 2
            x = self._merge_sort(dest, working, low, mid)
            y = self._merge_sort(dest, working, mid, high)
            z = self._merge(dest, working, low, mid, high)
            return (x + y + z)
        else:
            return 0

    def _merge(self, dest, working, low, mid, high):
        i = 0
        j = 0
        inversions = 0

        while (low + i < mid) and (mid + j < high):
            if self.compare(self.key(dest[low + i]), self.key(dest[mid + j])):
                working[low + i + j] = dest[low + i]
                i += 1
            else:
                working[low + i + j] = dest[mid + j]
                inversions += (mid - (low + i))
                j += 1

        while low + i < mid:
            working[low + i + j] = dest[low + i]
            i += 1

        while mid + j < high:
            working[low + i + j] = dest[mid + j]
            j += 1

        for k in range(low, high):
            dest[k] = working[k]

        return inversions


msorted = MergeSorted()

用途

>>> l = [5, 2, 3, 1, 4]
>>> s = msorted(l)
>>> s
[1, 2, 3, 4, 5]
>>> msorted.inversions
6

>>> l = ['e', 'b', 'c', 'a', 'd']
>>> d = {'a': 10,
...      'b': 4,
...      'c': 2,
...      'd': 5,
...      'e': 9}
>>> key = lambda x: d[x]
>>> s = msorted(l, key=key)
>>> s
['c', 'b', 'd', 'e', 'a']
>>> msorted.inversions
5

>>> l = [5, 2, 3, 1, 4]
>>> s = msorted(l, reverse=True)
>>> s
[5, 4, 3, 2, 1]
>>> msorted.inversions
4

>>> l = ['e', 'b', 'c', 'a', 'd']
>>> d = {'a': 10,
...      'b': 4,
...      'c': 2,
...      'd': 5,
...      'e': 9}
>>> key = lambda x: d[x]
>>> s = msorted(l, key=key, reverse=True)
>>> s
['a', 'e', 'd', 'b', 'c']
>>> msorted.inversions
5

3
这里是CLRS实现的代码:

CLRS

def merge(arr, p, q, r):
    n1 = q - p + 1
    n2 = r - q
    right, left = [], []
    for i in range(n1):
        left.append(arr[p + i])
    for j in range(n2):
        right.append(arr[q + j + 1])
    left.append(float('inf'))
    right.append(float('inf'))
    i = j = 0
    for k in range(p, r + 1):
        if left[i] <= right[j]:
            arr[k] = left[i]
            i += 1
        else:
            arr[k] = right[j]
            j += 1


def merge_sort(arr, p, r):
    if p < r:
        q = (p + r) // 2
        merge_sort(arr, p, q)
        merge_sort(arr, q + 1, r)
        merge(arr, p, q, r)


if __name__ == '__main__':
    test = [5, 2, 4, 7, 1, 3, 2, 6]
    merge_sort(test, 0, len(test) - 1)
    print test

结果:

[1, 2, 2, 3, 4, 5, 6, 7]

使用left.append(float('inf'))right.append(float('inf'))的原因是什么?是否有其他替代方法? - user9652688

3
许多人已经正确回答了这个问题,这只是另一种解决方案(虽然我的解决方案与Max Montana非常相似),但我在实现上有一些不同:
让我们在进入代码之前先回顾一下一般思路:
- 将列表分成两个大致相等的部分。 - 对左半部分进行排序。 - 对右半部分进行排序。 - 将两个排序后的部分合并为一个排序后的列表。
以下是代码(使用python 3.7测试):
def merge(left,right):
    result=[] 
    i,j=0,0
    while i<len(left) and j<len(right):
        if left[i] < right[j]:
            result.append(left[i])
            i+=1
        else:
            result.append(right[j])
            j+=1
    result.extend(left[i:]) # since we want to add each element and not the object list
    result.extend(right[j:])
    return result

def merge_sort(data):
    if len(data)==1:
        return data
    middle=len(data)//2
    left_data=merge_sort(data[:middle])
    right_data=merge_sort(data[middle:])
    return merge(left_data,right_data)


data=[100,5,200,3,100,4,8,9] 
print(merge_sort(data))

我想知道while循环块是否会使您的解决方案不稳定,如果i == j:将j附加到结果中,[1, 2, 3],[1, 8, 9],如果我没有弄错,结果将从右侧列表附加。 - Vitaliy Terziev

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接