USACO 2023 Feb. Problem 1. Hungry Cow 解法

USACO 2023 February Contest, Platinum
Problem 1. Hungry Cow

USACO 2023 の 3rd Contest の復習をする。

全体感想

問題名 略解 目標時間
Hungry Cow セグツリー 2時間
Problem Setting bit dp 20分
Watching Cowflix 木dp 3時間

解説を見てコンテスト4時間で3問満点をだすのは不可能だと思ったが、満点は複数人居る。

Hungry Cow 問題

usaco.org

牧場の牛は毎晩、エサの干し草があれば 1 単位の干し草を食べる、干し草がない場合は何もしない。
最初、牧場に干し草が届く予定はない。
U 個の予定の更新が与えられる。i 番目の予定は:

  •  d_i 日目の朝に  b_i 単位の干し草が届けられる

同じ日の予定がある場合は古い予定は消して、新しいもので上書きする。
牛が干し草を食べられる日にちの総和を出力せよ。(ただし総和は大きい可能性があるので 1000000007 で割った余りを出力せよ)

制約

 1 ≦ U ≦ 100000
 1 ≦ d_i ≦ 10 ^ {14}
 1 ≦ b_i ≦ 10 ^ 9

3
4 3
1 5
1 2

各予定を更新すると、

  • 4 + 5 + 6 = 15
  • 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 = 36
  • 1 + 2 + 4 + 5 + 6 = 18

解法

  1. クエリをノードとするようなセグツリーを作る。U個のアップデートをセグメントツリーに分解する
  2. 日にちの区間をノードとするセグツリーを作る。このセグツリーはロールバックに対応する必要がある

クエリのセグツリーをDFSしながら、日にちのセグツリーを更新・ロールバックする

クエリのセグツリー

同じ日の更新が複数あって、特に前回より減るような更新があると難しい。そのため各更新に対して上書きされるまでの期間を求めて、その範囲に対応するセグツリーのノードに加算する。
セグツリーをDFSすると、ツリーの子に移動するときにそのノードの更新を適用させ、親に戻るときには更新をロールバックする。

  • 更新の回数が logU 倍になる
  • 増える更新だけ対応すればよい
  • 更新のロールバックが必要になる (または永続化)

日にちのセグツリー

増える更新だけ対応するとき、重なる区間を併合する必要があるが、一つずつ併合するとロールバックの時に復活して計算量が悪くなるので、一回の更新を logD 時間で行いたい(D = max d + sum b)。
日にちの区間を完全に覆うならば、子ノードに再帰しないようなセグツリーを作る必要がある。
日にちの範囲が広いので座標圧縮する。干し草を食べられる連続した区間の左端は  d_i のいずれかと等しい。右端は必ずしも  d_i と一致するとは限らない。

  • リーフ:  d_i 以上  d_{i+1} 未満の区間に対応して、この区間の prefix m 日は連続して干し草を食べられる ( 0 ≦ m ≦ d_{i+1} - d_i )
  • ノード:区間の間が何日干し草を食べられるか、また、その日にちの総和

コード

#pragma GCC optimize ("O3")
#pragma GCC target ("sse4")
#pragma GCC optimize("unroll-loops")
#include<stdio.h>
#include<iostream>
#include<vector>
#include<algorithm>
#include<string>
#include<string.h>
#include<map>
#include<memory>

#ifdef LOCAL
#define eprintf(...) fprintf(stderr, __VA_ARGS__)
#else
#define NDEBUG
#define eprintf(...) do {} while (0)
#endif
#include<cassert>

using namespace std;

typedef long long LL;
typedef vector<int> VI;

#define REP(i,n) for(int i=0, i##_len=(n); i<i##_len; ++i)
#define EACH(i,c) for(__typeof((c).begin()) i=(c).begin(),i##_end=(c).end();i!=i##_end;++i)

template<class T> inline void amin(T &x, const T &y) { if (y<x) x=y; }
template<class T> inline void amax(T &x, const T &y) { if (x<y) x=y; }
#define rprintf(fmt, begin, end) do { const auto end_rp = (end); auto it_rp = (begin); for (bool sp_rp=0; it_rp!=end_rp; ++it_rp) { if (sp_rp) putchar(' '); else sp_rp = true; printf(fmt, *it_rp); } putchar('\n'); } while(0)

template<class T> void sort_unique(vector<T> &v) {
    sort(v.begin(), v.end());
    v.erase(unique(v.begin(), v.end()), v.end());
}

template<unsigned MOD_> struct ModInt {
    static constexpr unsigned MOD = MOD_;
    unsigned x;
    void undef() { x = (unsigned)-1; }
    bool isnan() const { return x == (unsigned)-1; }
    inline int geti() const { return (int)x; }
    ModInt() { x = 0; }
    ModInt(int y) { if (y<0 || (int)MOD<=y) y %= (int)MOD; if (y<0) y += MOD; x=y; }
    ModInt(unsigned y) { if (MOD<=y) x = y % MOD; else x = y; }
    ModInt(long long y) { if (y<0 || MOD<=y) y %= MOD; if (y<0) y += MOD; x=y; }
    ModInt(unsigned long long y) { if (MOD<=y) x = y % MOD; else x = y; }
    ModInt &operator+=(const ModInt y) { if ((x += y.x) >= MOD) x -= MOD; return *this; }
    ModInt &operator-=(const ModInt y) { if ((x -= y.x) & (1u<<31)) x += MOD; return *this; }
    ModInt &operator*=(const ModInt y) { x = (unsigned long long)x * y.x % MOD; return *this; }
    ModInt &operator/=(const ModInt y) { x = (unsigned long long)x * y.inv().x % MOD; return *this; }
    ModInt operator-() const { return (x ? MOD-x: 0); }

    ModInt inv() const { return pow(MOD-2); }
    ModInt pow(long long y) const {
	ModInt b = *this, r = 1;
	if (y < 0) { b = b.inv(); y = -y; }
	for (; y; y>>=1) {
	    if (y&1) r *= b;
	    b *= b;
	}
	return r;
    }

    friend ModInt operator+(ModInt x, const ModInt y) { return x += y; }
    friend ModInt operator-(ModInt x, const ModInt y) { return x -= y; }
    friend ModInt operator*(ModInt x, const ModInt y) { return x *= y; }
    friend ModInt operator/(ModInt x, const ModInt y) { return x *= y.inv(); }
    friend bool operator<(const ModInt x, const ModInt y) { return x.x < y.x; }
    friend bool operator==(const ModInt x, const ModInt y) { return x.x == y.x; }
    friend bool operator!=(const ModInt x, const ModInt y) { return x.x != y.x; }
};

constexpr LL MOD = 1000000007;

using Mint = ModInt<MOD>;
const Mint inv2 = Mint(1) / 2;

Mint f(LL start, LL len) {
    return (Mint(start) * 2 + len - 1) * len * inv2;
}


int U;
LL D[100011];
LL B[100011];

int SIZE;
vector<pair<LL, LL> > DB[1<<19];

void STORE(int left, int right, LL d, LL b) {
    left += SIZE;
    right += SIZE;
    for (; left<right; left/=2, right/=2) {
	if (left & 1) { DB[left++].emplace_back(d, b); }
	if (right & 1) { DB[--right].emplace_back(d, b); }
    }
}

struct Node {
    Node *left, *right;
    LL start;
    LL len;
    LL num;
    Mint sum;

    void clear() {
	left = right = nullptr;
	start = len = num = 0;
	sum = 0;
    }
} nodes[5000000];
int nodei;

Node* new_node() {
    assert(nodei < 5000000);
    nodes[nodei].clear();
    return nodes + nodei++;
}

Node* new_node(Node *left, Node *right) {
    Node *x = new_node();
    x->left = left;
    x->right = right;
    x->start = left->start;
    x->len = left->len + right->len;
    x->num = left->num + right->num;
    x->sum = left->sum + right->sum;
    return x;
}

vector<LL> Ds;

Node* build(int l, int r) {
    if (l == r) {
	return nullptr;
    } else if (l + 1 == r) {
	Node *x = new_node();
	x->start = Ds[l];
	x->len = Ds[r] - Ds[l];
	return x;
    } else {
	Node *x = new_node();
	x->start = Ds[l];
	x->len = Ds[r] - Ds[l];
	x->left = build(l, (l+r)/2);
	x->right = build((l+r)/2, r);
	assert(x->start == x->left->start);
	assert(x->len = x->left->len + x->right->len);
	return x;
    }
}

pair<Node*, LL> add_rec(Node *tree, LL start, LL num) {
    if (!tree || num == 0 || tree->start + tree->len <= start || tree->len == tree->num) {
	return make_pair(tree, num);
    } else if (start <= tree->start && tree->len <= num + tree->num) {
	Node *x = new_node();
	x->start = tree->start;
	x->len = tree->len;
	x->num = x->len;
	x->sum = f(tree->start, x->num);
	return make_pair(x, num + tree->num - x->num);
    } else if (!tree->left && !tree->right) {
	assert(tree->len >= tree->num + num);
	Node *x = new_node();
	x->start = tree->start;
	x->len = tree->len;
	x->num = tree->num + num;
	x->sum = f(tree->start, x->num);
	return make_pair(x, 0LL);
    } else {
	auto L = add_rec(tree->left, start, num);
	auto R = add_rec(tree->right, start, L.second);
	Node *x = new_node(L.first, R.first);
	return make_pair(x, R.second);
    }
}


Node* add(Node *tree, LL d, LL b) {
    auto p = add_rec(tree, d, b);
    assert(p.second == 0);
    return p.first;
}

void DFS(Node *tree, int k) {
    int backup = nodei;
    EACH (e, DB[k]) {
	tree = add(tree, e->first, e->second);
    }
    if (SIZE <= k) {
	if (k-SIZE < U) {
	    printf("%d\n", tree->sum.geti());
	}
    } else {
	DFS(tree, k*2);
	DFS(tree, k*2+1);
    }
    nodei = backup;
}

void MAIN() {
    scanf("%d", &U);

    REP (i, U) {
	scanf("%lld%lld", D+i, B+i);
    }

    Ds.reserve(U+1);
    Ds.assign(D, D+U);
    Ds.push_back(1LL<<50);
    sort_unique(Ds);


    SIZE = 1;
    while (SIZE < U) SIZE += SIZE;

    vector<pair<LL, int> > V;

    REP (i, U) {
	V.emplace_back(D[i], i);
    }
    sort(V.begin(), V.end());

    for (int i=0, j=0; i<U; i=j) {
	while (j < U && V[i].first == V[j].first) {
	    j++;
	}

	int right = U;
	for (int k=j-1; k>=i; k--) {
	    int left = V[k].second;
	    STORE(left, right, D[left], B[left]);
	    right = left;
	}
    }

    Node *tree = build(0, Ds.size()-1u);
    DFS(tree, 1);
}

int main() {
    int TC = 1;
//    scanf("%d", &TC);
    REP (tc, TC) MAIN();
    return 0;
}