在numpy数组中找到平衡点

3

考虑以下数组:

a = np.array([1,2,3,4,3,2,1])

我希望能够获取将数组均匀分割的元素,即在该元素之前数组的总和等于该元素之后数组的总和。在这种情况下,第4个元素a [3]可以将数组均匀分割。有没有更快的(numpy)方法来实现呢?还是我必须遍历所有元素?
期望的函数:
f(a) =  3
5个回答

4
我会选择类似这样的东西:

我会选择类似这样的东西:

def equib(a):
    c = a.cumsum()
    return np.argmin(np.abs(c-(c[-1]/2)))

首先,我们建立 a 的累积和。累加和意味着 c[i] = sum(a[:i])。然后,我们找出值与总权重之差的绝对值最小的位置。 更新:@DSM注意到我的第一个版本有一点偏移,因此这里提供另一个版本:
def equib(a):
    c1 = a.cumsum()
    c2 = a[::-1].cumsum()[::-1]
    return np.argmin(np.abs(c1-c2))

1
几乎可以肯定,你应该对减法使用np.abs(否则你会得到第一个元素作为最小值)。 - agomcas
更新后的版本在 OP 的示例中仍然返回 2,但实际上应该返回 3,因为它无法区分参数中的第一个 2 和第二个左括号(我的意思是 argmin 中的参数)。 - DSM
@chw21:你的方法至少快了两倍,如果不是因为那个“区分”的问题。 - ahmohamed
2
我建议使用 searchsorted 来在 c 中搜索 c[-1] / 2 - user2357112
1
合并 chw21 和 user2357112:np.searchsorted(a.cumsum(), a.sum()/2.) - agomcas
显示剩余7条评论

4

如果所有的输入值都是非负数,那么最有效的方法之一可能是构建一个累加和数组,并在其中进行二分查找,以找到两侧和为总和一半的位置。然而,很容易出现二分查找错误的情况。在试图处理所有边缘情况时,我得到了以下测试结果:

class SplitpointTest(unittest.TestCase):
    def testFloatRounding(self):
        # Due to rounding error, the cumulative sums for these inputs are
        # [1.1, 3.3000000000000003, 3.3000000000000003, 5.5, 6.6]
        # and [0.1, 0.7999999999999999, 0.7999999999999999, 1.5, 1.6]
        # Note that under default settings, numpy won't display
        # enough precision to see that.
        self.assertEquals(2, splitpoint([1.1, 2.2, 1e-20, 2.2, 1.1]))
        self.assertEquals(2, splitpoint([0.1, 0.7, 1e-20, 0.7, 0.1]))

    def testIntRounding(self):
        self.assertEquals(1, splitpoint([1, 1, 1]))
    def testIntPrecision(self):
        self.assertEquals(2, splitpoint([2**60, 1, 1, 1, 2**60]))
    def testIntMax(self):
        self.assertEquals(
            2,
            splitpoint(numpy.array([40, 23, 1, 63], dtype=numpy.int8))
        )

    def testIntZeros(self):
        self.assertEquals(
            4,
            splitpoint(numpy.array([0, 1, 0, 2, 0, 2, 0, 1], dtype=int))
        )
    def testFloatZeros(self):
        self.assertEquals(
            4,
            splitpoint(numpy.array([0, 1, 0, 2, 0, 2, 0, 1], dtype=float))
        )

在决定放弃之前,我尝试了以下几个版本:

def splitpoint(a):
    c = numpy.cumsum(a)
    return numpy.searchsorted(c, c[-1]/2)
    # Fails on [1, 1, 1]

def splitpoint(a):
    c = numpy.cumsum(a)
    return numpy.searchsorted(c, c[-1]/2.0)
    # Fails on [2**60, 1, 1, 1, 2**60]

def splitpoint(a):
    c = numpy.cumsum(a)
    if c.dtype.kind == 'f':
        # Floating-point input.
        return numpy.searchsorted(c, c[-1]/2.0)
    elif c.dtype.kind in ('i', 'u'):
        # Integer input.
        return numpy.searchsorted(c, (c[-1]+1)//2)
    else:
        # Probably an object dtype. No great options.
        return numpy.searchsorted(c, c[-1]/2.0)
    # Fails on numpy.array([63, 1, 63], dtype=int8)

def splitpoint(a):
    c = numpy.cumsum(a)
    if c.dtype.kind == 'f':
        # Floating-point input.
        return numpy.searchsorted(c, c[-1]/2.0)
    elif c.dtype.kind in ('i', 'u'):
        # Integer input.
        return numpy.searchsorted(c, c[-1]//2 + c[-1]%2)
    else:
        # Probably an object dtype. No great options.
        return numpy.searchsorted(c, c[-1]/2.0)
    # Still fails the floating-point rounding and zeros tests.

如果一直尝试,可能会使其工作,但不值得。基于chw21的第二种解决方案,即通过显式最小化左右两侧和的绝对差异来实现,更容易理解并更普遍适用。加上a = numpy.asarray(a),它通过了所有以上测试用例以及以下测试用例,这些测试用例扩展了算法预计要处理的输入类型:

class SplitpointGeneralizedTest(unittest.TestCase):
    def testNegatives(self):
        self.assertEquals(2, splitpoint([-1, 5, 2, 4]))
    def testComplex(self):
        self.assertEquals(2, splitpoint([1+1j, -5+2j, 43, -4+3j]))
    def testObjectDtype(self):
        from fractions import Fraction
        from decimal import Decimal
        self.assertEquals(2, splitpoint(map(Fraction, [1.5, 2.5, 3.5, 4])))
        self.assertEquals(2, splitpoint(map(Decimal, [1.5, 2.5, 3.5, 4])))

除非明确发现速度过慢,否则我会选择chw21的第二个解决方案。在我测试过的稍作修改后,如下:

def splitpoint(a):
    a = np.asarray(a)
    c1 = a.cumsum()
    c2 = a[::-1].cumsum()[::-1]
    return np.argmin(np.abs(c1-c2))

我唯一能看出的缺陷是,如果输入具有无符号数据类型并且没有确切分割输入的索引,由于np.abs(c1-c2)对于无符号数据类型不起作用,因此此算法可能不会返回最接近分割输入的索引。从未指定算法在没有分割索引的情况下应该执行什么操作,因此这种行为是可接受的,但可能值得在注释中提到np.abs(c1-c2)和无符号数据类型。如果我们想要最接近分割输入的索引,则可以以一些额外的运行时间为代价获得它:

def splitpoint(a):
    a = np.asarray(a)
    c1 = a.cumsum()
    c2 = a[::-1].cumsum()[::-1]
    if a.dtype.kind == 'u':
        # np.abs(c1-c2) doesn't work on unsigned ints
        absdiffs = np.where(c1>c2, c1-c2, c2-c1)
    else:
        # c1>c2 doesn't work on complex input.
        # We also use this case for other dtypes, since it's
        # probably faster.
        absdiffs = np.abs(c1-c2)
    return np.argmin(absdiffs)

当然,这里是针对这种行为的测试,修改后的表单通过了测试,未修改的表单失败了:
class SplitpointUnsignedTest(unittest.TestCase):
    def testBestApproximation(self):
        self.assertEquals(1, splitpoint(numpy.array([5, 5, 4, 5], dtype=numpy.uint32)))

仅仅因为你在一个回答中付出了如此大量的工作,我就必须点赞你 ;) - chw21
这是非常棒的工作。我为你感到敬佩。我会接受这个答案,因为它是对这个问题最完整的描述。 - ahmohamed

2
你可以从以下代码中找到平衡点。
    a = [1,3,5,2,2]
b = equilibirum(a)
n = len(a)
first_sum = 0
last_sum = 0
if n==1:
    print (1)
    return 0
for i in range(n):
    first_sum=first_sum+a[i]
    for j in range(i+2,n):
        last_sum=last_sum+a[j]
    if first_sum ==last_sum:
        s=i+2
        print (s)
        return 0
    last_sum=0

1

好的,这是我得到的内容,但我不确定这是否是最快的方法:

def eq(a)
    c = np.cumsum(a)
    return sum(c <= c[-1]/2)

0
你的数组中有两个“3”。看起来你应该返回所需元素的索引,而不是它的值,这样就知道哪一个“3”才是实际的平衡点了。在这个特定的例子中,它应该是:
f(a) = 4 # (because a[4] = 3)

改为:

f(a) = 3

如果是这种情况,那么这里是我关于如何定义一个适当的函数的建议:
def equi(arr):
  length = len(arr)
  if length == 0: return -1
  if length == 1: return 0

  i = 1
  j = 0 

  # starting sum1 (the 'left' sum)
  sum1 = 0

  # starting sum2 (the 'right' sum)
  sum2 = sum(arr[1:])

  while j < length:
      if sum1 == sum2: return j
      if j == length-1:
        sum2 = 0
      else:
        sum1 += arr[j]
        sum2 -= arr[j+1]
      j += 1
  if sum1 != 0: return -1

顺便说一下,这是我在Stack Overflow上的第一次贡献,我是一个编程初学者。如果我的解决方案不好,请随意评论!


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接