Codeforces 1086F - Forest Fires(插值/找性质+容斥)

Codeforces 题面传送门 & 洛谷题面传送门

这里介绍两种方法,一种是基于拉格朗日插值的最优可做到 \(n^3\log n\) 的做法,另一种是基于容斥的最优可做到 \(n^2\log n\) 的做法。

首先先推一波式子,直接计算每个点着火的时刻显然没啥性质可言,因此我们不妨将贡献摊到每个时刻上,我们考虑对每个时刻 \(T\in[0,t]\) 计算着火格子数并累计求和,不难发现这样对于一个着火时刻为 \(t_0\) 的格子,其会产生 \(t+1-t_0\) 的贡献。因此我们只用拿 \((t+1)·t\text{时刻着火的格子数}\) 再减去之前求得的和式即可算出答案,前者直接矩形面积交即可,于是瓶颈在于计算后者。

一个很显然的想法是,注意到这个值域很大,但是只要矩形之间两两是否相交的情况不变,矩形面积并就是关于时间的多项式,因此我们考虑像 2022 年省选 D1T2 那样,把 \([1,t]\) 分成若干个区间使得每个区间的面积并都是关于时间的多项式,这样我们对于每个区间插三个值(因为多项式次数显然只有 \(2\))然后累加求前缀和即可。注意到区间个数是 \(\mathcal O(n^2)\) 级别的,再加上矩形面积并的复杂度,总复杂度 \(n^3\log n\),已经足够通过 CF 的数据。

接下来考虑复杂度更优的做法,之前的插值做法好像没啥优化的余地,于是我们不妨另辟蹊径。我们考虑求矩形并的经典做法:求矩形交,然后容斥。于是我们考虑这样计算 \([0,t]\) 所有时刻着火格子数之和:枚举一个子集 \(S\),然后求 \([0,t]\) 时刻对应矩形交大小之和,乘上容斥系数 \((-1)^{|S|+1}\)。这样看似复杂度升高了,变成了 \(2^nn\),但是仔细分析可以发现,如果一个点集中存在某个点被其他点包围了,那么中间这个点就是可有可无的,此时矩形的贡献就变成了 \(0\),有点类似于求矩阵行列式的正负相消的思想。于是我们称不存在一个点被包围的集合是“好的”,我们不妨枚举这些“好的”集合的左右边界,并用 multiset 存储中间那些点的纵坐标,可以发现,只有左右边界对应的点在 multiset 中相邻才有可能产生贡献,因此“好的”集合个数是 \(\mathcal O(n^2)\) 的,再加上 multiset 的复杂度,可得最终复杂度 \(n^2\log n\)

const int MAXN = 2000;
const int MOD = 998244353;
const int INV2 = MOD + 1 >> 1;
const int INV6 = (MOD + 1) / 6;
int n, t, x[MAXN + 5], y[MAXN + 5], res = 0;
namespace CalcUnion {
	int sum[MAXN * 2 + 5][MAXN * 2 + 5];
	int kx[MAXN * 2 + 5], ky[MAXN * 2 + 5], ux[MAXN * 2 + 5], uy[MAXN * 2 + 5], nx, ny, cx, cy;
	int calc() {
		cx = cy = nx = ny = 0; kx[0] = ky[0] = -1e9;
		for (int i = 1; i <= n; i++) {
			kx[++cx] = x[i] - t; kx[++cx] = x[i] + t + 1;
			ky[++cy] = y[i] - t; ky[++cy] = y[i] + t + 1;
		}
		sort(kx + 1, kx + cx + 1); sort(ky + 1, ky + cy + 1);
		for (int i = 1; i <= cx; i++) if (kx[i] != kx[i - 1]) ux[++nx] = kx[i];
		for (int i = 1; i <= cy; i++) if (ky[i] != ky[i - 1]) uy[++ny] = ky[i];
		for (int i = 1; i <= nx; i++) for (int j = 1; j <= ny; j++) sum[i][j] = 0;
		for (int i = 1; i <= n; i++) {
			int lx = lower_bound(ux + 1, ux + nx + 1, x[i] - t) - ux;
			int rx = lower_bound(ux + 1, ux + nx + 1, x[i] + t + 1) - ux;
			int ly = lower_bound(uy + 1, uy + ny + 1, y[i] - t) - uy;
			int ry = lower_bound(uy + 1, uy + ny + 1, y[i] + t + 1) - uy;
			sum[lx][ly]++; sum[lx][ry]--; sum[rx][ly]--; sum[rx][ry]++;
		}
		int ret = 0;
		for (int i = 1; i <= nx; i++) for (int j = 1; j <= ny; j++)
			sum[i][j] += sum[i - 1][j] + sum[i][j - 1] - sum[i - 1][j - 1];
		for (int i = 1; i <= nx; i++) for (int j = 1; j <= ny; j++) {
			if (sum[i][j]) ret = (ret + 1ll * (ux[i + 1] - ux[i]) * (uy[j + 1] - uy[j])) % MOD;
		} 		
		return ret;
	}
}
int getsum2(int x) {return 1ll * x * (x + 1) % MOD * (2 * x + 1) % MOD * INV6 % MOD;}
int getsum2(int l, int r) {return (getsum2(r) - getsum2(l - 1) + MOD) % MOD;}
int getsum1(int x) {return 1ll * x * (x + 1) % MOD * INV2 % MOD;}
int getsum1(int l, int r) {return (getsum1(r) - getsum1(l - 1) + MOD) % MOD;}
namespace CalcInter {
	int calc(vector<int> vec) {
		for (int x : vec) if (!x) return 0;
		int mnx = 1e9, mxx = -1e9, mny = 1e9, mxy = -1e9;
		for (int id : vec) {
			chkmin(mnx, x[id]); chkmax(mxx, x[id]);
			chkmin(mny, y[id]); chkmax(mxy, y[id]);
		}
		int lim = max(mxx - mnx + 1, mxy - mny + 1) / 2;
		int V1 = mxx - mnx - 1, V2 = mxy - mny - 1;
		if (lim <= t) {
			// \sum\limits_{i=lim}^t(2i-V1)(2i-V2)
			int sum = (4ll * getsum2(lim, t) % MOD -
					2ll * (V1 + V2) * getsum1(lim, t) % MOD +
					1ll * V1 * V2 % MOD * (t - lim + 1) % MOD +
					MOD) % MOD;
			return sum;
		} else return 0;
	}
}
int ord[MAXN + 5];
void solve() {
	scanf("%d%d", &n, &t);
	for (int i = 1; i <= n; i++) scanf("%d%d", &x[i], &y[i]);
	res = 1ll * CalcUnion :: calc() * (t + 1) % MOD;
	for (int i = 1; i <= n; i++) res = (res - CalcInter :: calc({i}) + MOD) % MOD;
	for (int i = 1; i <= n; i++) ord[i] = i;
	sort(ord + 1, ord + n + 1, [&](int a, int b) {
		if (x[a] != x[b]) return x[a] < x[b];
		if (y[a] != y[b]) return y[a] < y[b];
		return a < b;
	});
	for (int i = 1; i <= n; i++) {
		set<pii> st; st.insert(mp(-1e9, 0)); st.insert(mp(1e9, 0));
		int id = ord[i]; st.insert(mp(y[id], id));
		for (int j = i + 1; j <= n; j++) {
			st.insert(mp(y[ord[j]], ord[j]));
			auto it1 = st.find(mp(y[id], id));
			auto it2 = st.find(mp(y[ord[j]], ord[j]));
			if (it1 == next(it2)) {
				res = (res + CalcInter :: calc({id, ord[j]})) % MOD;
				res = (res - CalcInter :: calc({id, ord[j], prev(it2) -> se}) + MOD) % MOD;
				res = (res - CalcInter :: calc({id, ord[j], next(it1) -> se}) + MOD) % MOD;
				res = (res + CalcInter :: calc({id, ord[j], next(it1) -> se, prev(it2) -> se})) % MOD;
			}
			if (it2 == next(it1)) {
				res = (res + CalcInter :: calc({id, ord[j]})) % MOD;
				res = (res - CalcInter :: calc({id, ord[j], prev(it1) -> se}) + MOD) % MOD;
				res = (res - CalcInter :: calc({id, ord[j], next(it2) -> se}) + MOD) % MOD;
				res = (res + CalcInter :: calc({id, ord[j], next(it2) -> se, prev(it1) -> se})) % MOD;
			}
		}
	}
	printf("%d\n", res);
}
int main() {
	freopen("fire.in", "r", stdin);
	freopen("fire.out", "w", stdout);
	int qu; scanf("%d", &qu);
	while (qu--) solve();	
	return 0;
}
posted @ 2022-06-06 17:14  tzc_wk  阅读(81)  评论(0)    收藏  举报