如何使用延迟标记实现线段树?

5
我正在实现一个线段树,以便能够快速回答数组A中以下查询的问题:
  • 查询 i, j:范围 (i,j) 内所有元素的和
  • 更新 i, j, k:将 k 添加到范围 (i,j) 内的所有元素
这是我的实现代码:
typedef long long intt;

const int max_num=100000,max_tree=4*max_num;
intt A[max_num],ST[max_tree];

void initialize(int node, int be, int en) {
  if(be==en) {
    ST[node]=ST[be];
  } else {
    initialize(2*node+1,be,(be+en)/2);
    initialize(2*node+2,(be+en)/2+1,en);

    ST[node]=ST[2*node+1]+ST[2*node+2];
  }
}

void upg(int node, int be, int en, int i, intt k) {
  if(be>i || en<i || be>en) return;
  if(be==en) {
    ST[node]+=k;
    return;
  }
  upg(2*node+1, be, (be+en)/2, i, k);
  upg(2*node+2, (be+en)/2+1, en, i, k);
  ST[node] = ST[2*node+1]+ST[2*node+2];
}

intt query(int node, int be, int en, int i, int j) {
  if(be>j || en<i) return -1;
  if(be>=i && en<=j) return ST[node];

  intt q1=query(2*node+1, be, (be+en)/2, i, j);
  intt q2=query(2*node+2, (be+en)/2+1, en, i, j);

  if(q1==-1) return q2;
  else if(q2==-1) return q1;
  else return q1+q2;
}

查询函数非常快,其复杂度为O(lg N),其中N是j-i。更新函数在平均情况下也很快,但当j-i很大时,更新的复杂度为O(N lg N),这并不快。
我稍微搜索了一下这个主题,并发现如果我使用带有lazy propagation的线段树来实现,则查询和更新的复杂度都是O(lg N),这比O(N lg N)渐进更快。
我还找到了另一个问题的链接,其中有一个非常好的线段树实现,它使用指针:如何使用lazy propagation实现线段树?。因此,我的问题是:是否有一种更简单的方法来实现lazy propagation,而不使用指针,而是使用数组索引,并且没有segment_tree数据结构?

1
显然,另一个SO问题的第一个答案提供了一个Java实现的链接,它不使用指针。将其翻译成C++应该是可行的。 - didierc
@didierc:最新的,而不是第一个(如果我的答案包含Java实现的链接,我会感到惊讶的;)但是,这里是你提到的链接:http://isharemylearning.blogspot.in/2012/08/lazy-propagation-in-segment-tree.html - Zeta
@Zeta 由于某种原因,它首先出现在我的页面上,但你是正确的。 - didierc
抱歉,但你只是太懒了。 - Alexander
@Alexander,你到底是什么意思? - Rontogiannis Aristofanis
为什么你想要避免使用指针和结构体? - Cheers and hth. - Alf
2个回答

3
这是我对这种数据结构和一些模板戏法的尝试。
在所有这些混乱的底部,有两个平面数组的访问,其中一个包含总和树,另一个包含要向下传播的进位值树。在概念上,它们形成一个二叉树。
二叉树中节点的真实值是存储的总和树中的值加上节点下叶子数乘以从节点返回到根的所有进位树值的总和。
同时,树中每个节点的真实值等于其下叶子节点的真实值。
我编写了一个函数来处理进位和总和,因为它们访问相同的节点。有时读取会写入。因此,通过使用零的“增量”来调用它来获得总和。
所有模板戏法所做的就是计算每个树节点的偏移量以及左右子节点的位置。
虽然我使用了一个结构体,但该结构体是瞬态的——它只是一个包装器,具有指向数组偏移量的一些预计算值。我确实存储了指向数组开头的指针,但是在程序中,每个block_ptr都使用完全相同的root值。
为了调试,我有一些可怜的Assert()和Debug()宏,以及递归求和函数调用的跟踪无参函数(我用它来跟踪调用总数)。再次,为了避免全局状态而不必要地复杂。 :)
#include <memory>
#include <iostream>

// note that you need more than 2^30 space to fit this
enum {max_tier = 30};

typedef long long intt;

#define Assert(x) (!(x)?(std::cout << "ASSERT FAILED: (" << #x << ")\n"):(void*)0)
#define DEBUG(x) 

template<size_t tier, size_t count=0>
struct block_ptr
{
  enum {array_size = 1+block_ptr<tier-1>::array_size * 2};
  enum {range_size = block_ptr<tier-1>::range_size * 2};
  intt* root;
  size_t offset;
  size_t logical_offset;
  explicit block_ptr( intt* start, size_t index, size_t logical_loc=0 ):root(start),offset(index), logical_offset(logical_loc) {}
  intt& operator()() const
  {
    return root[offset];
  }
  block_ptr<tier-1> left() const
  {
    return block_ptr<tier-1>(root, offset+1, logical_offset);
  }
  block_ptr<tier-1> right() const
  {
    return block_ptr<tier-1>(root, offset+1+block_ptr<tier-1>::array_size, logical_offset+block_ptr<tier-1>::range_size);
  }
  enum {is_leaf=false};
};

template<>
struct block_ptr<0>
{
  enum {array_size = 1};
  enum {range_size = 1};
  enum {is_leaf=true};
  intt* root;
  size_t offset;
  size_t logical_offset;

  explicit block_ptr( intt* start, size_t index, size_t logical_loc=0 ):root(start),offset(index), logical_offset(logical_loc)
  {}
  intt& operator()() const
  {
    return root[offset];
  }
  // exists only to make some of the below code easier:
  block_ptr<0> left() const { Assert(false); return *this; }
  block_ptr<0> right() const { Assert(false); return *this; }
};


template<size_t tier>
void propogate_carry( block_ptr<tier> values, block_ptr<tier> carry )
{
  if (carry() != 0)
  {
    values() += carry() * block_ptr<tier>::range_size;
    if (!block_ptr<tier>::is_leaf)
    {
      carry.left()() += carry();
      carry.right()() += carry();
    }
    carry() = 0;
  }
}

// sums the values from begin to end, but not including end!
// ie, the half-open interval [begin, end) in the tree
// if increase is non-zero, increases those values by that much
// before returning it
template<size_t tier, typename trace>
intt query_or_modify( block_ptr<tier> values, block_ptr<tier> carry, int begin, int end, int increase=0, trace const& tr = [](){} )
{
  tr();
  DEBUG(
  std::cout << begin << " " << end << " " << increase << "\n";
  if (increase)
  {
    std::cout << "Increasing " << end-begin << " elements by " << increase << " starting at " << begin+values.offset << "\n";
  }
  else
  {
    std::cout << "Totaling " << end-begin << " elements starting at " << begin+values.logical_offset << "\n";
  }
  )
  if (end <= begin)
    return 0;
  size_t mid = block_ptr<tier>::range_size / 2;
  DEBUG( std::cout << "[" << values.logical_offset << ";" << values.logical_offset+mid << ";" << values.logical_offset+block_ptr<tier>::range_size << "]\n"; )
  // exatch math first:
  bool bExact = (begin == 0 && end >= block_ptr<tier>::range_size);
  if (block_ptr<tier>::is_leaf)
  {
    Assert(bExact);
  }
  bExact = bExact || block_ptr<tier>::is_leaf; // leaves are always exact
  if (bExact)
  {
    carry()+=increase;
    intt retval =  (values()+carry()*block_ptr<tier>::range_size);
    DEBUG( std::cout << "Exact sum is " << retval << "\n"; )
    return retval;
  }
  // we don't have an exact match.  Apply the carry and pass it down to children:
  propogate_carry(values, carry);
  values() += increase * end-begin;
  // Now delegate to children:
  if (begin >= mid)
  {
    DEBUG( std::cout << "Right:"; )
    intt retval = query_or_modify( values.right(), carry.right(), begin-mid, end-mid, increase, tr );
    DEBUG( std::cout << "Right sum is " << retval << "\n"; )
    return retval;
  }
  else if (end <= mid)
  {
    DEBUG( std::cout << "Left:"; )
    intt retval = query_or_modify( values.left(), carry.left(), begin, end, increase, tr );
    DEBUG( std::cout << "Left sum is " << retval << "\n"; )
    return retval;
  }
  else
  {
    DEBUG( std::cout << "Left:"; )
    intt left = query_or_modify( values.left(), carry.left(), begin, mid, increase, tr );
    DEBUG( std::cout << "Right:"; )
    intt right = query_or_modify( values.right(), carry.right(), 0, end-mid, increase, tr );
    DEBUG( std::cout << "Right sum is " << left << " and left sum is " << right << "\n"; )
    return left+right;
  }
}

以下是翻译的结果:

这里提供一些辅助类以便更轻松地创建指定大小的线段树。但需要注意的是,你只需要一个正确大小的数组,并且可以通过指向元素0的指针构造一个block_ptr,就可以开始使用了。

template<size_t tier>
struct segment_tree
{
  typedef block_ptr<tier> full_block_ptr;
  intt block[full_block_ptr::range_size];
  full_block_ptr root() { return full_block_ptr(&block[0],0); }
  void init()
  {
    std::fill_n( &block[0], size_t(full_block_ptr::range_size), 0 );
  }
};

template<size_t entries, size_t starting=0>
struct required_tier
{
  enum{ tier =
    block_ptr<starting>::array_size >= entries
    ?starting
    :required_tier<entries, starting+1>::tier
  };
  enum{ error =
    block_ptr<starting>::array_size >= entries
    ?false
    :required_tier<entries, starting+1>::error
  };
};

// max 2^30, to limit template generation.
template<size_t entries>
struct required_tier<entries, size_t(max_tier)>
{
  enum{ tier = 0 };
  enum{ error = true };
};

// really, these just exist to create an array of the correct size
typedef required_tier< 1000000 > how_big;

enum {tier = how_big::tier};


int main()
{
  segment_tree<tier> values;
  segment_tree<tier> increments;
  Assert(!how_big::error); // can be a static assert -- fails if the enum of max tier is too small for the number of entries you want
  values.init();
  increments.init();
  auto value_root = values.root();
  auto carry_root = increments.root();

  size_t count = 0;
  auto tracer = [&count](){count++;};
  intt zero = query_or_modify( value_root, carry_root, 0, 100000, 0, tracer );
  std::cout << "zero is " << zero << " in " << count << " steps\n";
  count = 0;
  Assert( zero == 0 );
  intt test2 = query_or_modify( value_root, carry_root, 0, 100, 10, tracer ); // increase everything from 0 to 100 by 10
  Assert(test2 == 1000);
  std::cout << "test2 is " << test2 << " in " << count << " steps \n";
  count = 0;
  intt test3 = query_or_modify( value_root, carry_root, 1, 1000, 0, tracer );
  Assert(test3 == 990);
  std::cout << "test3 is " << test3 << " in " << count << " steps\n";
  count = 0;
  intt test4 = query_or_modify( value_root, carry_root, 50, 5000, 87, tracer );
  Assert(test4 == 10*(100-50) + 87*(5000-50) );
  std::cout << "test4 is " << test4 << " in " << count << " steps\n";
  count = 0;
}

虽然这不是你想要的答案,但它可能会让某人更容易地编写它。而且,写这个让我觉得很有趣。希望能有所帮助!

这段代码已在Ideone.com上使用C++0x编译器进行了测试和编译。


-1

惰性传播是指只在必要时进行更新。这是一种技术,可以使范围更新的渐进时间复杂度为O(logN)(N在这里是范围)。

假设您要更新范围[0,15],然后更新节点[0,15]并在节点中设置一个标志,表示将更新其子节点(在未使用标志的情况下使用哨兵值)。

可能的压力测试案例:

0 1 100000

0 1 100000

0 1 100000 ...重复Q次(其中Q = 99999),第100000个查询将是

1 1 100000

在这种情况下,大多数实现都会在最后一次简单查询之前翻转100000个硬币99999次,并超时。

使用惰性传播,您只需要翻转节点[0,100000] 99999次并设置/取消设置其子节点需要更新的标志即可。当实际查询本身被问及时,您开始遍历其子节点并开始翻转它们,将标志向下推送并取消父节点的标志。

哦,确保您使用适当的I/O例程(如果是C ++,请使用scanf和printf而不是cin和cout)。希望这让您了解了什么是惰性传播。更多信息:http://www.spoj.pl/forum/viewtopic.php?f=27&t=8296


我没有问懒惰传播是什么,而是如何在C++中实现它,不使用结构体或指针。 - Rontogiannis Aristofanis
дЄЇдїАдєИдљ†иѓі scanf еТМ printf жѓФ cin жИЦ cout жЫіе•љпЉЯ - AJMansfield

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接