如何使用线段树计算在区间[a,b]内小于给定常数的元素数量? (ACM)

3

我对线段树比较新,并希望通过一些线段树的练习来让自己忙起来。

这个问题更像是ACM类型的,有以下条件:

  • 有n个数字和m个操作,其中n,m<=10,000,每个操作可以是以下之一:
    • 将区间更新为减去一个数字x,每次的x可以不同
    • 查询一个区间,找到区间中小于等于0的数字数量

建立线段树和更新显然可以在O(nlog n) / O(log n)内完成。

但是我无法想出如何在O(log n)内进行查询,请问是否有人可以给我一些建议/提示?任何意见都将非常有帮助!谢谢!

TL;DR:

给定n个数字和2种操作:

  1. 将x添加到[a,b]中的所有元素中,其中x每次可以不同
  2. 查询[a,b]中元素数<C,其中C是已知的常数

如何使操作1和2都可以在O(log n)内完成?


这个问题来自哪里?是作业吗? - Codie CodeMonkey
请查看我在 http://stackoverflow.com/questions/18687589 的回答。 - Peter de Rivaz
Codie: 不,实际上我本科课程中从未学过线段树,只是当时从我的ACM团队培训中学到了它。所以是的,我现在正在工作,这不是我的家庭作业。Peter: 真的非常感谢,很抱歉我之前找不到这篇文章...它真的很有帮助,像往常一样,主要问题在于树节点的设计...我错过了“所有子节点中最小正值”的部分,因此无法完全使用惰性传播...再次感谢! - shole
这个问题似乎不适合此处,因为它涉及计算机科学并应该发布在CS.SE上。 - Saeed Amiri
2个回答

4

很棒的问题:)

我思考了一段时间,但仍然无法用线段树解决这个问题,但我尝试使用“桶方法”来解决这个问题。

我们可以将初始的n个数字分成B个桶,在每个桶中对数字进行排序,并维护每个桶中的总添加值。然后对于每个查询:

  • “添加”更新区间 [a,b] 为 c

    我们只需要重建最多两个桶,并将c添加到(b-a)/BUCKET_SIZE个桶中

  • “查询”查询区间 [a,b] <= c

    我们只需要扫描最多两个桶,每个值逐个快速通过(b-a)/BUCKET_SIZE个桶进行二进制搜索即可

每个查询应在O(N / BUCKET_SIZE * log(BUCKET_SIZE,2))内运行,这比暴力方法(O(N))更小。虽然它比O(logN)大,但在大多数情况下可能已经足够。

这里是测试代码:

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <vector>
#include <set>
#include <map>
#include <ctime>
#include <cassert>

using namespace std;

struct Query {
    //A a b c  add c in [a, b] of arr
    //Q a b c  Query number of i in [a, b] which arr[i] <= c
    char ty;
    int a, b, c;
    Query(char _ty, int _a, int _b, int _c):ty(_ty), a(_a), b(_b), c(_c){}
};

int n, m;
vector<int> arr;
vector<Query> queries;

vector<int> bruteforce() {
    vector<int> ret;
    vector<int> numbers = arr;
    for (int i = 0; i < m; i++) {
        Query q = queries[i];
        if (q.ty == 'A') {
            for (int i = q.a; i <= q.b; i++) {
                numbers[i] += q.c;
            }
            ret.push_back(-1);
        } else {
            int tmp = 0;
            for(int i = q.a; i <= q.b; i++) {
                tmp += numbers[i] <= q.c;
            }
            ret.push_back(tmp);
        }
    }
    return ret;
}

struct Bucket {
    vector<int> numbers;
    vector<int> numbers_sorted;
    int add;
    Bucket() {
        add = 0;
        numbers_sorted.clear();
        numbers.clear();
    }
    int query(int pos) {
        return numbers[pos] + add;
    }
    void add_pos(int pos, int val) {
        numbers[pos] += val;
    }
    void build() {
        numbers_sorted = numbers;
        sort(numbers_sorted.begin(), numbers_sorted.end());
    }
};

vector<int> bucket_count(int bucket_size) {
    vector<int> ret;

    vector<Bucket> buckets;
    buckets.resize(int(n / bucket_size) + 5);
    for (int i = 0; i < n; i++) {
        buckets[i / bucket_size].numbers.push_back(arr[i]);
    }

    for (int i = 0; i <= n / bucket_size; i++) {
        buckets[i].build();
    }

    for (int i = 0; i < m; i++) {
        Query q = queries[i];
        char ty = q.ty;
        int a, b, c;
        a = q.a, b = q.b, c = q.c;
        if (ty == 'A') {
            set<int> affect_buckets;
            while (a < b && a % bucket_size != 0) buckets[a/ bucket_size].add_pos(a % bucket_size, c), affect_buckets.insert(a/bucket_size), a++;
            while (a < b && b % bucket_size != 0) buckets[b/ bucket_size].add_pos(b % bucket_size, c), affect_buckets.insert(b/bucket_size), b--;
            while (a < b) {
                buckets[a/bucket_size].add += c;
                a += bucket_size;
            }
            buckets[a/bucket_size].add_pos(a % bucket_size, c), affect_buckets.insert(a / bucket_size);
            for (set<int>::iterator it = affect_buckets.begin(); it != affect_buckets.end(); it++) {
                int id = *it;
                buckets[id].build();
            }
            ret.push_back(-1);
        } else {
            int tmp = 0;
            while (a < b && a % bucket_size != 0) tmp += (buckets[a/ bucket_size].query(a % bucket_size) <=c), a++;
            while (a < b && b % bucket_size != 0) tmp += (buckets[b/ bucket_size].query(b % bucket_size) <=c), b--;
            while (a < b) {
                int pos = a / bucket_size;
                tmp += upper_bound(buckets[pos].numbers_sorted.begin(), buckets[pos].numbers_sorted.end(), c - buckets[pos].add) - buckets[pos].numbers_sorted.begin();
                a += bucket_size;
            }
            tmp += (buckets[a / bucket_size].query(a % bucket_size) <= c);
            ret.push_back(tmp);
        }
    }

    return ret;
}

void process(int cas) {

    clock_t begin_t=clock();

    vector<int> bf_ans = bruteforce();
    clock_t  bf_end_t =clock();
    double bf_sec = ((1.0 * bf_end_t - begin_t)) / CLOCKS_PER_SEC;

    //bucket_size is important
    int bucket_size = 200;
    vector<int> ans = bucket_count(bucket_size);

    clock_t  bucket_end_t =clock();
    double bucket_sec = ((1.0 * bucket_end_t - bf_end_t)) / CLOCKS_PER_SEC;

    bool correct = true;
    for (int i = 0; i < ans.size(); i++) {
        if (ans[i] != bf_ans[i]) {
            cout << "query " << i + 1 << " bf = " << bf_ans[i] << " bucket  = " << ans[i] << "  bucket size = " <<  bucket_size << " " << n << " " << m <<  endl;
            correct = false;
        }
    }
    printf("Case #%d:%s bf_sec = %.9lf, bucket_sec = %.9lf\n", cas, correct ? "YES":"NO", bf_sec, bucket_sec);
}

void read() {
    cin >> n >> m;
    arr.clear();
    for (int i = 0; i < n; i++) {
        int val;
        cin >> val;
        arr.push_back(val);
    }
    queries.clear();
    for (int i = 0; i < m; i++) {
        char ty;
        int a, b, c;
        // a, b, c in [0, n - 1], a <= b
        cin >> ty >> a >> b >> c;
        queries.push_back(Query(ty, a, b, c));
    }
}

void run(int cas) {
    read();
    process(cas);
}

int main() {
    freopen("bucket.in", "r", stdin);
    //freopen("bucket.out", "w", stdout);
    int T;
    scanf("%d", &T);
    for (int cas  = 1; cas <= T; cas++) {
        run(cas);
    }
    return 0;
}

这里是数据生成代码:
#coding=utf8

import random
import math

def gen_buckets(f):
    t = random.randint(10, 20)
    print >> f, t
    nlimit = 100000
    mlimit = 10000
    limit = 100000
    for i in xrange(t):
        n = random.randint(1, nlimit)
        m = random.randint(1, mlimit)
        print >> f, n, m

        for i in xrange(n):
            val = random.randint(1, limit)
            print >> f, val ,
        print >> f
        for i in xrange(m):
            ty = random.randint(1, 2)
            a = random.randint(0, n - 1)
            b = random.randint(a, n - 1)
            #a = 0
            #b = n - 1
            c = random.randint(-limit, limit)
            print >> f, 'A' if ty == 1 else 'Q', a, b, c


f = open("bucket.in", "w")
gen_buckets(f)

0
尝试使用树状数组(Binary Index Trees,BIT)代替分段树。这里是 tutorial 的链接。

谢谢你的帮助,Dennis。实际上,在我设计解决问题的算法时,BIT是我的第一次尝试,但由于某些原因,我知道BIT不能在O(nlogn)的时间内解决这个问题,只有线段树才能胜任... - shole

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接