动态前缀和

8
有没有一种数据结构能够在O(log n)时间内返回数组的前缀和[1]、更新元素以及向数组中插入/删除元素?
[1] "前缀和"是指从第一个元素到给定索引的所有元素的总和。
例如,给定非负整数数组8 1 10 7,前三个元素的前缀和为19(8 + 1 + 10)。将第一个元素更新为7,将3插入为第二个元素并删除第三个元素得到7 3 10 7。同样,前三个元素的前缀和为20。
对于前缀和和更新,可以使用Fenwick树。但我不知道如何在其中处理添加/删除操作以使其复杂度为O(log n)。
另一方面,还有一些二叉搜索树,比如红黑树,它们都可以在对数时间内处理更新/插入/删除。但我不知道如何保持给定的顺序并在O(log n)内执行前缀和。
2个回答

5

带有隐式键的treap可以在每个查询中以O(log n)时间执行所有这些操作。 隐式键的想法非常简单:我们不在节点中存储任何键,而是为所有节点维护子树的大小,并使用此信息在添加或删除元素时找到适当的位置。

以下是我的实现:

#include <iostream>
#include <memory>

struct Node {
  int priority;
  int val;
  long long sum;
  int size;
  std::shared_ptr<Node> left;
  std::shared_ptr<Node> right;

  Node(long val): 
    priority(rand()), val(val), sum(val), size(1), left(), right() {}
};

// Returns the size of a node owned by t if it is not empty and 0 otherwise.
int getSize(std::shared_ptr<Node> t) {
  if (!t)
    return 0;
  return t->size;
}

// Returns the sum of a node owned by t if it is not empty and 0 otherwise.
long long getSum(std::shared_ptr<Node> t) {
  if (!t)
    return 0;
  return t->sum;
}


// Updates a node owned by t if it is not empty.
void update(std::shared_ptr<Node> t) {
  if (t) {
    t->size = 1 + getSize(t->left) + getSize(t->right);
    t->sum = t->val + getSum(t->left) + getSum(t->right);
  }
}

// Merges the nodes owned by L and R and returns the result.
std::shared_ptr<Node> merge(std::shared_ptr<Node> L, 
    std::shared_ptr<Node> R) {
  if (!L || !R)
    return L ? L : R;
  if (L->priority > R->priority) {
    L->right = merge(L->right, R);
    update(L);
    return L;
  } else {
    R->left = merge(L, R->left);
    update(R);
    return R;
  }
}

// Splits a subtree rooted in t by pos. 
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> split(
    std::shared_ptr<Node> t,
    int pos, int add) {
  if (!t)
    return make_pair(std::shared_ptr<Node>(), std::shared_ptr<Node>());
  int cur = getSize(t->left) + add;
  std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> res;
  if (pos <= cur) {
    auto ret = split(t->left, pos, add);
    t->left = ret.second;
    res = make_pair(ret.first, t); 
  } else {
    auto ret = split(t->right, pos, cur + 1);
    t->right = ret.first;
    res = make_pair(t, ret.second); 
  }
  update(t);
  return res;
}

// Returns a prefix sum of [0 ... pos]
long long getPrefixSum(std::shared_ptr<Node>& root, int pos) {
  auto parts = split(root, pos + 1, 0);
  long long res = getSum(parts.first);
  root = merge(parts.first, parts.second);
  return res;
}

// Adds a new element at a position pos with a value newValue.
// Indices are zero-based.
void addElement(std::shared_ptr<Node>& root, int pos, int newValue) {
  auto parts = split(root, pos, 0);
  std::shared_ptr<Node> newNode = std::make_shared<Node>(newValue);
  auto temp = merge(parts.first, newNode);
  root = merge(temp, parts.second);
}

// Removes an element at the given position pos.
// Indices are zero-based.
void removeElement(std::shared_ptr<Node>& root, int pos) {
  auto parts1 = split(root, pos, 0);
  auto parts2 = split(parts1.second, 1, 0);
  root = merge(parts1.first, parts2.second);
}

int main() {
  std::shared_ptr<Node> root;
  int n;
  std::cin >> n;
  for (int i = 0; i < n; i++) {
    std::string s;
    std::cin >> s;
    if (s == "add") {
      int pos, val;
      std::cin >> pos >> val;
      addElement(root, pos, val);
    } else if (s == "remove") {
      int pos;
      std::cin >> pos;
      removeElement(root, pos);
    } else {
      int pos;
      std::cin >> pos;
      std::cout << getPrefixSum(root, pos) << std::endl;
    }
  }
  return 0;
}

抱歉,我完全不明白:treap具有随机排序,那么如何更新“第一个元素”?“找到适当的位置” - 怎么做? - Ecir Hana
1
@EcirHana 一个Treap并不具有随机排序。 - kraskevich
1
@EcirHana 优先级仅用于平衡,不影响元素的顺序。 - kraskevich
我现在才意识到我和你有相同的想法(使用隐式键)。但我不确定为什么你选择了Treap而不是AVL或RB树。两者都可以使用相同的想法。 - Juan Lopes
不错!我从未想过一种方法,即回答每个查询都涉及修改数据结构。您为什么要给split()一个add参数,而不是在第二个递归调用中只调用split(t->right, pos - cur - 1) - j_random_hacker
显示剩余2条评论

3
一个想法:修改AVL树。通过索引执行添加和删除。每个节点保留其子树的计数和总和,以便所有操作都可以O(log n)完成。
使用“add_node”、“update_node”和“prefix_sum”实现了概念证明:
class Node:
    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None
        self.left_height = 0
        self.right_height = 0
        self.left_count = 1
        self.left_sum = value
        self.right_count = 0
        self.right_sum = 0

    def set_value(self, value):
        self.value = value
        self.left_sum = self.left.left_sum + self.left.right_sum+self.value if self.left else self.value

    def set_left(self, node):
        self.left = node
        self.left_height = max(node.left_height, node.right_height)+1 if node else 0
        self.left_count = node.left_count + node.right_count+1 if node else 1
        self.left_sum = node.left_sum + node.right_sum+self.value if node else self.value

    def set_right(self, node):
        self.right = node
        self.right_height = max(node.left_height, node.right_height)+1 if node else 0
        self.right_count = node.left_count + node.right_count if node else 0
        self.right_sum = node.left_sum + node.right_sum if node else 0

    def rotate_left(self):
        b = self.right
        self.set_right(b.left)
        b.set_left(self)
        return b

    def rotate_right(self):
        a = self.left
        self.set_left(a.right)
        a.set_right(self)
        return a

    def factor(self):
        return self.right_height - self.left_height

def add_node(root, index, node):
    if root is None: return node

    if index < root.left_count:
        root.set_left(add_node(root.left, index, node))
        if root.factor() < -1:
            if root.left.factor() > 0:
                root.set_left(root.left.rotate_left())
            return root.rotate_right()
    else:
        root.set_right(add_node(root.right, index-root.left_count, node))
        if root.factor() > 1:
            if root.right.factor() < 0:
                root.set_right(root.right.rotate_right())
            return root.rotate_left()

    return root

def update_node(root, index, value):
    if root is None: return root

    if index+1 < root.left_count:
        root.set_left(update_node(root.left, index, value))
    elif index+1 > root.left_count:
        root.set_right(update_node(root.right, index - root.left_count, value))
    else:
        root.set_value(value)

    return root


def prefix_sum(root, index):
    if root is None: return 0

    if index+1 < root.left_count:
        return prefix_sum(root.left, index)
    else:
        return root.left_sum + prefix_sum(root.right, index-root.left_count)


import random
tree = None
tree = add_node(tree, 0, Node(10))
tree = add_node(tree, 1, Node(40))
tree = add_node(tree, 1, Node(20))
tree = add_node(tree, 2, Node(70))

tree = update_node(tree, 2, 30)

print prefix_sum(tree, 0)
print prefix_sum(tree, 1)
print prefix_sum(tree, 2)
print prefix_sum(tree, 3)
print prefix_sum(tree, 4)

太棒了,谢谢!但是如果我添加一个元素会发生什么?难道不需要重新索引所有剩余的节点吗?这是O(n)的…… - Ecir Hana
不,元素索引将隐式地保留(通过每个子树上的节点计数)。 - Juan Lopes

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接