2025/5/1 省队集训|容斥
当然,以下是优化排版和数学公式的版本,便于理解与阅读:
[FJOI2017] 矩阵填数
题目链接:https://www.luogu.com.cn/problem/P3813
题目描述
给定一个 \(h \times w\) 的矩阵(编号从 \(1 \sim h\) 行、\(1 \sim w\) 列),每个格子填入 \(1 \sim m\) 中的某个数。
有 \(n\) 个限制,每个限制是一个子矩阵及其最大值 \(v\),要求最终填数方案中该子矩阵的最大值恰为 \(v\)。
问一共有多少种填数方案,满足所有限制条件。
解题思路
一、总体策略
- 每个格子合法值的范围是:\[[1, \min(\text{包含该格子的所有限制中的 } v_i, m)] \]
- 不同值域的点之间互不影响,可以分开统计后再相乘。
二、求每个值域的方案数(容斥原理)
对于一个值域为 \(k\) 的区域集合 \(S_k\),总填法是:
\[k^{|S_k|}
\]
但这会多算一些不合法的情况(即某些限制矩阵中最大值不是 \(k\) 而更小)。
假设第 \(i\) 个限制子矩阵的最大值为 \(k\),它的点集记为 \(T_{k,i}\)。
对这些矩阵执行容斥:
- 先减去:第1个矩阵不满足最大值为 \(k\) 的填法数:\[(k-1)^{|T_{k,1}|} \cdot k^{|S_k| - |T_{k,1}|} \]
- 再加上两个不满足的交集……
- 直到所有集合组合处理完。
最终所有满足值域为 \(k\) 的合法方案数,乘起来即可。
代码
#include<bits/stdc++.h>
using namespace std;
constexpr int maxn = 15, maxm = 1050;
#define int long long
constexpr int mod = 1e9+7;
namespace MATH {
inline int pow(int a, int b) {
int res = 1;
while (b) {
if (b & 1) res = res * a % mod;
a = a * a % mod;
b >>= 1;
}
return res;
}
}
int n, m, h, w;
int s[maxm], u[maxm];
int siz[maxm];
#define y1 s_e_v_e
struct Matrix {
int x, y, x1, y1;
int v;
void read() {
cin >> x >> y >> x1 >> y1 >> v;
}
bool isclean() const {
return (x > x1) || (y > y1);
}
int getS() const {
if (isclean()) return 0;
return (x1 - x + 1) * (y1 - y + 1);
}
void operator &= (const Matrix& a) {
x = max(x, a.x); y = max(y, a.y);
x1 = min(x1, a.x1); y1 = min(y1, a.y1);
}
bool operator < (const Matrix& X) const {
return v < X.v;
}
} r[maxn];
signed main() {
for (int i = 1; i < 1024; i++) {
siz[i] = siz[i >> 1] + (i & 1);
}
ios::sync_with_stdio(0);
cin.tie(0);
int T;
cin >> T;
while (T-- > 0) {
cin >> h >> w >> m >> n;
for (int i = 0; i < n; i++) {
r[i].read();
}
sort(r, r + n);
int N = (1 << n) - 1;
for (int i = 1; i <= N; i++) { //求并集
Matrix t = {1, 1, h, w, 0};
for (int p = i, j = 0; p; p >>= 1, j++) {
if (p & 1) {
t &= r[j];
if (t.isclean()) {
break;
}
}
}
s[i] = t.getS();
}
for (int i = 1; i <= N; i++) {//求交集
u[i] = 0;
for (int j = i; j; j = (j - 1) & i) {
if (siz[j] & 1) u[i] += s[j]; //容斥
else u[i] -= s[j];
}
}
int ns = 0, ls = 0; //目前的,上一次的
int res = 1;
for (int i = 0; i < n; i++) {
ns |= (1 << i);
if (i + 1 < n && r[i].v == r[i + 1].v) continue;
int total = u[ns | ls] - u[ls];
int st = total;
int ret = MATH::pow(r[i].v, total);
for (int k = ns; k; k = (k - 1) & ns) {
int part = u[k | ls] - u[ls]; //这个就是T
int del = MATH::pow(r[i].v - 1, part) * MATH::pow(r[i].v, st - part) % mod;
if (siz[k] % 2) ret = (ret + mod - del) % mod; //容斥
else ret = (ret + del) % mod;
}
res = res * ret % mod;
ls |= ns;
ns = 0;
}
cout << res * MATH::pow(m, h * w - u[N]) % mod << '\n';
for (int i = 0; i <= N; i++) {
u[i] = 0;
}
}
return 0;
}
浙公网安备 33010602011771号