算术编码和解码算法Python

3

我正在开发一种自适应的算术编码和解码算法,在Python中已经实现了它,但对于某些字符串,我得到了正确的答案,而对于其他字符串,我得到了错误的答案。

程序启动时,提供一个参数来决定符号概率何时更新。例如,如果参数为10,则每传输/接收10个符号后,概率表就会根据迄今为止传输/接收的所有符号重新计算,各符号对应的区间划分也随之改变。最初使用的是[a-z]上的均匀分布,每个符号的概率为1/26。

这对于“heloworldheloworld”和许多其他情况都不起作用。

此外,我了解到算术编码存在下溢问题,但我该如何解决这个问题?

import sys
import random
import string


def encode(encode_str, N):
    """Arithmetic-encode a lowercase string with a periodically adapted model.

    The coder starts from a uniform model over ``string.ascii_lowercase``.
    Each symbol narrows the current interval to the slice the model assigns
    to that symbol.  After every run of ``N`` symbols the model is rebuilt
    from the symbol counts seen so far; every letter starts with a count of
    1, so no letter ever gets a zero-width slice.

    Plain floats are used, so the interval loses precision as it shrinks
    and long inputs may not decode back correctly.

    Args:
        encode_str: Text to encode; every character must be in ``a``-``z``.
        N: Number of symbols between model rebuilds.  The counter is compared
            with ``==``, so a value below 1 never triggers a rebuild.

    Returns:
        The lower end of the final interval (the int ``0`` or a float).

    Raises:
        KeyError: If ``encode_str`` contains a character outside ``a``-``z``.
    """
    letters = sorted(string.ascii_lowercase)
    count = dict.fromkeys(letters, 1)       # occurrences, seeded with 1 each
    step = float(1) / float(26)

    # Initial uniform model: each letter owns one slice of width 1/26.
    cdf_range = {}
    low = 0
    high = step
    for key in letters:
        cdf_range[key] = [low, high]
        low = high
        high += step

    total = 26                              # sum of all entries in count
    since_update = 0                        # symbols coded since last rebuild
    lower_bound = 0
    upper_bound = 1

    for sym in encode_str:
        total += 1
        since_update += 1
        count[sym] += 1

        curr_range = upper_bound - lower_bound
        # upper_bound is assigned first because both lines need the old
        # lower_bound.
        upper_bound = lower_bound + (curr_range * cdf_range[sym][1])
        lower_bound = lower_bound + (curr_range * cdf_range[sym][0])

        # Rebuild the cumulative ranges from the counts every N symbols.
        if since_update == N:
            since_update = 0
            low = 0
            for key in letters:
                high = float(count[key]) / float(total) + low
                cdf_range[key] = [low, high]
                low = high

    return lower_bound

def decode(encoded, strlen, every):
    """Decode a value produced by the matching adaptive arithmetic encoder.

    The model starts uniform over ``string.ascii_lowercase`` and is rebuilt
    from the running symbol counts after every ``every`` decoded symbols, so
    it must be called with the same update period the encoder used.

    Args:
        encoded: Number inside the encoder's final interval.
        strlen: Number of symbols to recover.
        every: Number of symbols between model rebuilds (compared with
            ``==``, so a value below 1 never triggers a rebuild).

    Returns:
        The decoded string.  It is also printed, as before.

    Raises:
        ValueError: If ``encoded`` does not fall inside any symbol's slice of
            the current interval (previously this case looped forever).
    """
    letters = sorted(string.ascii_lowercase)
    count = dict.fromkeys(letters, 1)       # occurrences, seeded with 1 each
    step = float(1) / float(26)

    # Initial uniform model: each letter owns one slice of width 1/26.
    cdf_range = {}
    low = 0
    high = step
    for key in letters:
        cdf_range[key] = [low, high]
        low = high
        high += step

    lower_bound = 0
    upper_bound = 1
    since_update = 0                        # symbols decoded since last rebuild
    decoded = []

    while len(decoded) < strlen:
        curr_range = upper_bound - lower_bound

        # Find the one letter whose slice of the interval contains encoded.
        for key in letters:
            upper_cand = lower_bound + (curr_range * cdf_range[key][1])
            lower_cand = lower_bound + (curr_range * cdf_range[key][0])
            if lower_cand <= encoded < upper_cand:
                break
        else:
            raise ValueError(
                "encoded value is outside the current interval after "
                "%d symbols" % len(decoded))

        decoded.append(key)
        upper_bound = upper_cand
        lower_bound = lower_cand
        count[key] += 1
        since_update += 1

        # Mirror the encoder: rebuild the cumulative ranges from the counts.
        if since_update == every:
            since_update = 0
            total = 26 + len(decoded)
            low = 0
            for letter in letters:
                high = float(count[letter]) / float(total) + low
                cdf_range[letter] = [low, high]
                low = high

    decoded_str = "".join(decoded)
    print(decoded_str)
    return decoded_str

def main():
    """Encode a sample string and decode it again; decode prints the result."""
    encode_str = "yyyyuuuuyyyy"
    every = 3                               # rebuild the model every 3 symbols
    encoded = encode(encode_str, every)
    decode(encoded, len(encode_str), every)


if __name__ == '__main__':
    main()

所以它对于 heloworldheloworld 不起作用。它会抛出异常还是产生意料之外的输出?您能否澄清一下? - Konstantin
它没有抛出错误。我先用算术编码算法对“helloworldhelloworld”进行编码,得到一个浮点数,然后必须用这个浮点数解码,还原出“helloworldhelloworld”。对这个字符串解码结果不正确,但对其他字符串(如“yyyyhhhh”或“helloworld”)是有效的。没有语法错误,是算法在某个地方出了问题,但我很难追踪它。 - Pete
2个回答

2
这是因为Python的float只有53位精度。您无法编码非常长的字符串。
您可能想使用decimal代替floats以获得任意精度。
import sys
import random
import string

import decimal
from decimal import Decimal

# Use 100 significant digits for every Decimal operation below.  The coding
# interval narrows with each symbol, so this precision limits how long a
# message can be and still decode correctly.
decimal.getcontext().prec=100

def encode(encode_str, N):
    """Arithmetic-encode a lowercase string using ``Decimal`` arithmetic.

    The coder starts from a uniform model over ``string.ascii_lowercase``.
    Each symbol narrows the current interval to the slice the model assigns
    to that symbol.  After every run of ``N`` symbols the model is rebuilt
    from the symbol counts seen so far; every letter starts with a count of
    1, so no letter ever gets a zero-width slice.

    All arithmetic runs in the active ``decimal`` context, so its precision
    decides how long a message can be encoded reliably.

    Args:
        encode_str: Text to encode; every character must be in ``a``-``z``.
        N: Number of symbols between model rebuilds.  The counter is compared
            with ``==``, so a value below 1 never triggers a rebuild.

    Returns:
        The lower end of the final interval (the int ``0`` or a ``Decimal``).

    Raises:
        KeyError: If ``encode_str`` contains a character outside ``a``-``z``.
    """
    letters = sorted(string.ascii_lowercase)
    count = dict.fromkeys(letters, 1)       # occurrences, seeded with 1 each
    step = Decimal(1) / Decimal(26)

    # Initial uniform model: each letter owns one slice of width 1/26.
    cdf_range = {}
    low = 0
    high = step
    for key in letters:
        cdf_range[key] = [low, high]
        low = high
        high += step

    total = 26                              # sum of all entries in count
    since_update = 0                        # symbols coded since last rebuild
    lower_bound = 0
    upper_bound = 1

    for sym in encode_str:
        total += 1
        since_update += 1
        count[sym] += 1

        curr_range = upper_bound - lower_bound
        # upper_bound is assigned first because both lines need the old
        # lower_bound.
        upper_bound = lower_bound + (curr_range * cdf_range[sym][1])
        lower_bound = lower_bound + (curr_range * cdf_range[sym][0])

        # Rebuild the cumulative ranges from the counts every N symbols.
        if since_update == N:
            since_update = 0
            low = 0
            for key in letters:
                high = Decimal(count[key]) / Decimal(total) + low
                cdf_range[key] = [low, high]
                low = high

    return lower_bound

def decode(encoded, strlen, every):
    """Decode a value produced by the matching ``Decimal`` arithmetic encoder.

    The model starts uniform over ``string.ascii_lowercase`` and is rebuilt
    from the running symbol counts after every ``every`` decoded symbols, so
    it must be called with the same update period (and the same ``decimal``
    context) the encoder used.

    Args:
        encoded: Number inside the encoder's final interval.
        strlen: Number of symbols to recover.
        every: Number of symbols between model rebuilds (compared with
            ``==``, so a value below 1 never triggers a rebuild).

    Returns:
        The decoded string.  It is also printed, as before.

    Raises:
        ValueError: If ``encoded`` does not fall inside any symbol's slice of
            the current interval (previously this case looped forever).
    """
    letters = sorted(string.ascii_lowercase)
    count = dict.fromkeys(letters, 1)       # occurrences, seeded with 1 each
    step = Decimal(1) / Decimal(26)

    # Initial uniform model: each letter owns one slice of width 1/26.
    cdf_range = {}
    low = 0
    high = step
    for key in letters:
        cdf_range[key] = [low, high]
        low = high
        high += step

    lower_bound = 0
    upper_bound = 1
    since_update = 0                        # symbols decoded since last rebuild
    decoded = []

    while len(decoded) < strlen:
        curr_range = upper_bound - lower_bound

        # Find the one letter whose slice of the interval contains encoded.
        for key in letters:
            upper_cand = lower_bound + (curr_range * cdf_range[key][1])
            lower_cand = lower_bound + (curr_range * cdf_range[key][0])
            if lower_cand <= encoded < upper_cand:
                break
        else:
            raise ValueError(
                "encoded value is outside the current interval after "
                "%d symbols" % len(decoded))

        decoded.append(key)
        upper_bound = upper_cand
        lower_bound = lower_cand
        count[key] += 1
        since_update += 1

        # Mirror the encoder: rebuild the cumulative ranges from the counts.
        if since_update == every:
            since_update = 0
            total = 26 + len(decoded)
            low = 0
            for letter in letters:
                high = Decimal(count[letter]) / Decimal(total) + low
                cdf_range[letter] = [low, high]
                low = high

    decoded_str = "".join(decoded)
    print(decoded_str)
    return decoded_str

def main():
    """Encode a sample string and decode it again; decode prints the result."""
    encode_str = "heloworldheloworld"
    every = 3                               # rebuild the model every 3 symbols
    encoded = encode(encode_str, every)
    decode(encoded, len(encode_str), every)


if __name__ == '__main__':
    main()

1
错误出现在长度约为12个字符的字符串上。这已接近Python所用双精度浮点数的精度极限(53位尾数;在均匀分布下每个符号约消耗log2(26)≈4.7位,因此大约只能容纳11个符号),这很可能就是问题的原因。
我使用了具有任意精度的BigFloat库进行了快速测试,并获得了正确的答案:
import sys
import random
import string
from bigfloat import *

# Width of one symbol's slice under the initial uniform model (1/26).
# NOTE(review): evaluated at import time, outside the `precision(200)` blocks
# used below, so it presumably carries bigfloat's default precision -- confirm
# that this is intended.
factor = BigFloat(1)/BigFloat(26)

def encode(encode_str, N):
    """Arithmetic-encode a lowercase string using 200-bit ``BigFloat`` math.

    The coder starts from a uniform model over ``string.ascii_lowercase``
    built from the module-level ``factor``.  Each symbol narrows the current
    interval to the slice the model assigns to that symbol.  After every run
    of ``N`` symbols the model is rebuilt from the symbol counts seen so far;
    every letter starts with a count of 1.

    Args:
        encode_str: Text to encode; every character must be in ``a``-``z``.
        N: Number of symbols between model rebuilds.  The counter is compared
            with ``==``, so a value below 1 never triggers a rebuild.

    Returns:
        The lower end of the final interval (the int ``0`` for empty input).

    Raises:
        KeyError: If ``encode_str`` contains a character outside ``a``-``z``.
    """
    letters = sorted(string.ascii_lowercase)
    count = dict.fromkeys(letters, 1)       # occurrences, seeded with 1 each
    cdf_range = {}

    with precision(200) + RoundTowardZero:
        # Initial uniform model: each letter owns one slice of width factor.
        low = 0
        high = factor
        for key in letters:
            cdf_range[key] = [low, high]
            low = high
            high += factor

        total = 26                          # sum of all entries in count
        since_update = 0                    # symbols coded since last rebuild
        lower_bound = 0
        upper_bound = 1

        for sym in encode_str:
            total += 1
            since_update += 1
            count[sym] += 1

            curr_range = upper_bound - lower_bound
            # upper_bound is assigned first because both lines need the old
            # lower_bound.
            upper_bound = lower_bound + (curr_range * cdf_range[sym][1])
            lower_bound = lower_bound + (curr_range * cdf_range[sym][0])

            # Rebuild the cumulative ranges from the counts every N symbols.
            if since_update == N:
                since_update = 0
                low = 0
                for key in letters:
                    high = BigFloat(count[key]) / BigFloat(total) + low
                    cdf_range[key] = [low, high]
                    low = high

    return lower_bound

def decode(encoded, strlen, every):
    """Decode a value produced by the matching ``BigFloat`` encoder.

    The model starts uniform over ``string.ascii_lowercase`` (built from the
    module-level ``factor``) and is rebuilt from the running symbol counts
    after every ``every`` decoded symbols, so it must be called with the same
    update period the encoder used.

    Args:
        encoded: Number inside the encoder's final interval.
        strlen: Number of symbols to recover.
        every: Number of symbols between model rebuilds (compared with
            ``==``, so a value below 1 never triggers a rebuild).

    Returns:
        The decoded string.  It is also printed, as before.

    Raises:
        ValueError: If ``encoded`` does not fall inside any symbol's slice of
            the current interval (previously this case looped forever).
    """
    letters = sorted(string.ascii_lowercase)
    count = dict.fromkeys(letters, 1)       # occurrences, seeded with 1 each
    cdf_range = {}
    decoded = []

    with precision(200) + RoundTowardZero:
        # Initial uniform model: each letter owns one slice of width factor.
        low = 0
        high = factor
        for key in letters:
            cdf_range[key] = [low, high]
            low = high
            high += factor

        lower_bound = BigFloat(0)
        upper_bound = BigFloat(1)
        since_update = 0                    # symbols decoded since last rebuild

        while len(decoded) < strlen:
            curr_range = upper_bound - lower_bound

            # Find the one letter whose slice contains encoded.
            for key in letters:
                upper_cand = lower_bound + (curr_range * cdf_range[key][1])
                lower_cand = lower_bound + (curr_range * cdf_range[key][0])
                if lower_cand <= encoded < upper_cand:
                    break
            else:
                raise ValueError(
                    "encoded value is outside the current interval after "
                    "%d symbols" % len(decoded))

            decoded.append(key)
            upper_bound = upper_cand
            lower_bound = lower_cand
            count[key] += 1
            since_update += 1

            # Mirror the encoder: rebuild the ranges from the counts.
            if since_update == every:
                since_update = 0
                total = 26 + len(decoded)
                low = 0
                for letter in letters:
                    high = BigFloat(count[letter]) / BigFloat(total) + low
                    cdf_range[letter] = [low, high]
                    low = high

    decoded_str = "".join(decoded)
    print(decoded_str)
    return decoded_str

def main():
    """Encode a sample string and decode it again; decode prints the result."""
    encode_str = "heloworldheloworld"
    every = 3                               # rebuild the model every 3 symbols
    encoded = encode(encode_str, every)
    decode(encoded, len(encode_str), every)


if __name__ == '__main__':
    main()

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接