关于一类容斥原理设计 dp 状态的探讨

写在前面

为什么要写?因为自己学不明白希望日后能掌握。

大体思路大概是

  1. 设计一个容斥的方案,并使其贡献可以便于计算。
  2. 得出 dp 状态,然后优化以得出答案。

下列所有类似 \([l,r]\) 这样的都是离散的。

1.

\(n\) 个点,每个点有一个能选择的颜色 \(a_i\),左右相邻的点不能同色,求方案数。

如果我们使用容斥的思想,强制 \(k\) 段的颜色相同,这个限制下的方案数对答案的贡献的容斥系数就是 \((-1)^{n-k}\)。这应该是相邻颜色不同的方案数的一个非常平凡的trick。(但是我不会

可以设 \(f_i\) 表示统计到前 \(i\) 个点所容斥的答案和。枚举 \([j,i]\) 这一段强制颜色相等。

\[f_i=-\sum\limits_{j=1}^i f_{j-1}\min\limits_{j\le k\le i}a_i \]

这个东西可以用单调栈维护一下。

注意到这个东西可以拓展到环上。把 \(a_i\) 最小的位置轮换到最前面,然后你发现 \(f_i\) 其实就是强制了 \([i+1,n]\)\(1\) 的颜色相同的答案。全部加起来就好了。

    s[0] = f[0] = 1; int top = 0;
    ll sum = 0;
    fo(i, 1, c) {
        while(top && b[stc[top]] > b[i])
            sum = (sum + (ll)(s[stc[top] - 1] - (stc[top] == 1 ? 0 :
                  s[stc[top - 1] - 1]) + mod) * (b[i] - b[stc[top]] + mod)) % mod,
            --top;
        stc[++top] = i;
        sum = (sum + (ll)f[i - 1] * b[i]) % mod;
        f[i] = mod - sum;
        s[i] = (f[i] + s[i - 1]) % mod;
    }

2.

\(n\) 个点,一个区间可以覆盖 \([l_i,r_i]\) 这一段,每个区间有一个价值 \(v_i\) ,定义一种“覆盖”为每个点至少被一个区间所覆盖的方案,其价值为所有所选区间的价值积,求所有覆盖的价值之和。

考虑强制 \(k\) 个点不被覆盖,那么这种情况对答案的贡献的容斥系数就是 \((-1)^k\)。其贡献就是这些点之间的区间的乘积之和。

这样的话,设 \(f_i\) 表示 \(i\) 点被钦定,枚举 \(j\) 表示上一个钦定点,有

\[f_i=-\sum_{j=1}^{i-1}f_j \prod_{j<l_k\le r_k<i}(v_k+1) \]

这玩意可以线段树优化!考虑线段树的每一个位置记录的是它作为 \(j\) 造成的贡献,假设现在新加入一个区间 \(k\) ,它能使 \([0,l_k)\) 的位置的贡献发生变化,乘上 \((1+v_k)\)

#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
#define ll long long
#define fo(i, a, b) for(int i = (a); i <= (b); ++i)
#define fd(i, a, b) for(int i = (a); i >= (b); --i)
using namespace std;
inline void read(int &x) {
	x = 0; char ch = getchar();
	while(ch < '0' || ch > '9')	ch = getchar();
	while(ch >= '0' && ch <= '9')	x = (x << 1) + (x << 3) + (ch ^ 48), ch = getchar();
}
const int N = 2e5 + 10, mod = 1e9 + 7;
namespace Seg {
	#define ls t << 1
	#define rs ls | 1
	#define mid ((l + r) >> 1)
	int tr[N << 2], pro[N << 2];
	inline void mul(int t, int v) {tr[t] = (ll)tr[t] * v % mod, pro[t] = (ll)pro[t] * v % mod;}
	inline void push_down(int t) {
		if(pro[t] > 1) {
			mul(ls, pro[t]), mul(rs, pro[t]);
			pro[t] = 1;
		}
	}
	void build(int t, int l, int r) {
		pro[t] = 1;
		if(l == r)	return;
		build(ls, l, mid), build(rs, mid + 1, r);
	}
	void change(int t, int l, int r, int w, int v) {
		tr[t] = (tr[t] + v) % mod;
		if(l == r)	return ;
		push_down(t);
		w <= mid ? change(ls, l, mid, w, v) : change(rs, mid + 1, r, w, v);
	}
	void update(int t, int l, int r, int fl, int fr, int v) {
		if(fl <= l && r <= fr)	return mul(t, v);
		push_down(t);
		fl <= mid && (update(ls, l, mid, fl, fr, v), 1);
		fr > mid && (update(rs, mid + 1, r, fl, fr, v), 1);
		tr[t] = (tr[ls] + tr[rs]) % mod;
	}
	int query(int t, int l, int r, int fl, int fr) {
		if(fl <= l && r <= fr)	return tr[t];
		push_down(t);
		int ret = 0;
		fl <= mid && (ret = (ret + query(ls, l, mid, fl, fr)) % mod);
		fr > mid && (ret = (ret + query(rs, mid + 1, r, fl, fr)) % mod);
		return ret;
	}
}
struct Op {
	int l, r, v;
}p[N];
vector<int> q[N];
int n, m, f[N];
int main() {
	freopen("gugugu.in", "r", stdin);
	freopen("gugugu.out", "w", stdout);
	read(n), read(m);
	fo(i, 1, m)	read(p[i].l), read(p[i].r), read(p[i].v), q[p[i].r].push_back(i);
	Seg::build(1, 0, n);
	Seg::change(1, 0, n, 0, 1);
	fo(i, 1, n + 1) {
		Seg::change(1, 0, n, i, f[i] = mod - Seg::query(1, 0, n, 0, i - 1));
		for(auto k : q[i])
			Seg::update(1, 0, n, 0, p[k].l - 1, (p[k].v + 1) % mod);
	}
	printf("%d\n", mod - f[n + 1]);
	return 0;
}

To be continued..

posted @ 2021-10-21 16:29  Martin_MHT  阅读(89)  评论(0)    收藏  举报