题解:P10104 [GDKOI2023 提高组] 异或图
题意:给定一张 \(n\) 个点 \(m\) 条边的无向图和一个长度为 \(n\) 的数组 \(a_1, a_2, \cdots , a_n\) 以及一个整数 \(C\),你需要求出有多少个长度为 \(n\) 的数组 \(b\) 满足:
- \(0 ≤ b_i ≤ a_i,\forall 1 ≤ i ≤ n\)。
- 对于每条边 \((u, v)\),\(b_u \ne b_v\)。
- \(b_1 ⊕ b_2 ⊕ \cdots ⊕ b_n = C\),其中 \(\oplus\) 代表异或。
做法:
首先先要会做 \(m=0\),那么我们可以在后面加一个元素 \(a_{n+1}=C\) 算出异或和为 \(0\) 方案数,然后再改成 \(a_{n+1}=C-1\) 算出方案数,做差即是原题要求的。
怎么求异或和为 \(0\) 的方案数?其实是 这个题,可以见我在这个题的 题解。
然后考虑 \(m\not = 0\),看到有互异这个限制,很自然地直接容斥,考虑每次加入一个相等的连通块,那么容斥系数就应该是导出子图的 \(\sum\limits_{E'\in E}(-1)^{|E'|}\),观察到这个柿子在 \(|E'| \ge 1\) 的时候为 \(0\),所以只需要独立集即可,总的系数就是枚举一个包含 \(1\) 的连通块划分出去然后去计算即可。
然后考虑,如果划分的是一个偶数大小的连通块,那么这个连通块怎么划都可以,方案数为连通块最小值 \(+1\) 种方式;如果是奇数大小的,那么就需要跑我们开始的那个东西,全部乘起来的到贡献,可以做到贝尔数的复杂度。
但是我们考虑可以直接在加入连通块的时候记录最小值,这样可以做到 \(4^n\),但是还是不够快。
我们这里做一个简单的优化,我们把所有数按 \(a\) 排个序,同时集合里对于 \(\le i\) 的位,我们只记录他是否是作为最小值加入奇连通块的,而大于的就记录是否已经被加入集合。具体的,在转移的时候我们考虑 \(i\) 是否在我们枚举的集合中,如果在,那么就意味着他是通过以前被加入的,那么贡献时直接把这一位变成 \(0\) 即可;然后考虑他是最小值,如果是偶连通块,直接乘上贡献,并且加入 \(>i\) 的点,\(i\) 这一位保持 \(0\),奇连通块则变为 \(1\),可以见代码实现。这样我们就可以很方便地在最后直接拉出来哪些是需要的,复杂度是 \(O(3^nn)\)。
最后再对跑出来的那些有用点集去跑 \(m=0\) 的做法就可以了,总复杂度 \(O(3^nn)\)。
因为 \(a\) 可以到达 \(10^{18}\),记得取模。
代码:
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxn = 2e5 + 5, inf = 2e9, mod = 998244353;
int n, m, c, val[maxn], coef[maxn], a[maxn];
int dp[20][2][2];
int cal(vector<int> v) {
int ans = 0, s = 0;
for (int i = 1; i < v.size(); i++)
s ^= v[i];
ans += (!s); s = 0;
for (int d = 60; d >= 0; d--) {
memset(dp, 0, sizeof(dp));
dp[0][0][0] = 1;
s = 0;
int all = (1ll << d) - 1;
for (int i = 1; i < v.size(); i++) {
s ^= (v[i] >> d + 1);
int t = ((v[i] >> d) & 1);
if(!t) {
for (int x = 0; x <= 1; x++)
for (int y = 0; y <= 1; y++)
dp[i][x][y] = dp[i - 1][x][y] * (((v[i] & all) + 1) % mod) % mod;
}
else {
dp[i][0][0] = dp[i - 1][1][0] * (((v[i] & all) + 1) % mod) % mod;
dp[i][1][0] = dp[i - 1][0][0] * (((v[i] & all) + 1) % mod) % mod;
dp[i][0][1] = (dp[i - 1][1][1] * (((v[i] & all) + 1) % mod) % mod + dp[i - 1][0][1] * ((all + 1) % mod) % mod + dp[i - 1][0][0]) % mod;
dp[i][1][1] = (dp[i - 1][0][1] * (((v[i] & all) + 1) % mod) % mod + dp[i - 1][1][1] * ((all + 1) % mod) % mod + dp[i - 1][1][0]) % mod;
}
}
if(!s)
ans = (ans + dp[v.size() - 1][0][1]) % mod;
}
return ans;
}
int solve(vector<int> v) {
v.push_back(c);
int ans = cal(v);
if(c == 0)
return ans;
v[v.size() - 1]--;
// cout << ans << " " << cal(v) << "Adsf" << endl;
return (ans - cal(v) + mod) % mod;
}
int dpt[2][maxn], id[maxn], p[maxn];
bool cmp(int x, int y) {
return a[x] < a[y];
}
int lowbit(int x) {
return x & (-x);
}
int cnt[maxn];
signed main() {
cin >> n >> m >> c;
for (int i = 1; i <= n; i++)
cin >> a[i], p[i] = i;
sort(p + 1, p + n + 1, cmp);
sort(a + 1, a + n + 1);
for (int i = 1; i <= n; i++)
id[p[i]] = i;
for (int i = 1; i <= m; i++) {
int x, y; cin >> x >> y;
x = id[x], y = id[y];
int s = (1 << x - 1) | (1 << y - 1);
for (int t = 0; t < (1 << n); t++)
if((s & t) == s)
val[t]++;
}
for (int s = 0; s < (1 << n); s++)
val[s] = !val[s], cnt[s] = cnt[s >> 1] + (s & 1);
dpt[0][0] = 1;
int cur = 0;
for (int s = 0; s < (1 << n); s++) {
coef[s] = val[s];
int lb = lowbit(s);
for (int t = (s - 1) & s; t; t = (t - 1) & s) {
if((t & lb) == lb)
coef[s] = (coef[s] - coef[t] * val[s ^ t] % mod + mod) % mod;
}
// cout << s << " " << coef[s] << endl;
}
for (int i = 1; i <= n; i++) {
cur ^= 1, memset(dpt[cur], 0, sizeof(dpt[cur]));
for (int s = 0; s < (1 << n); s++) {
if(!dpt[cur ^ 1][s])
continue;
if((s >> i - 1) & 1) {
dpt[cur][s ^ (1 << i - 1)] = (dpt[cur][s ^ (1 << i - 1)] + dpt[cur ^ 1][s]) % mod;
continue;
}
int res = (((1 << n) - 1) ^ ((1 << i) - 1));
// cout << res << endl;
res = res ^ (s & res);
for (int t = res; ; t = (t - 1) & res) {
int u = (t | (1 << i - 1));
if(cnt[u] & 1)
dpt[cur][s | u] = (dpt[cur][s | u] + dpt[cur ^ 1][s] * coef[u] % mod) % mod;
else
dpt[cur][s | t] = (dpt[cur][s | t] + dpt[cur ^ 1][s] * coef[u] % mod * ((a[i] + 1) % mod) % mod) % mod;
if(!t)
break;
}
// cout << dpt[cur][9] << " " << dpt[cur ^ 1][s] << " " << s << " " << res << " " << (s | res) << "Adsf" << endl;
}
// for (int s = 0; s < (1 << n); s++)
// cout << dpt[cur][s] << " " << i << " " << s << endl;
}
int ans = 0;
for (int i = 0; i < (1 << n); i++) {
vector<int> nw; nw.push_back(0);
for (int j = 1; j <= n; j++)
if((i >> j - 1) & 1)
nw.push_back(a[j]);
// cout << i << " " << solve(nw) << " " << dpt[cur][i] << endl;
ans = (ans + solve(nw) * dpt[cur][i] % mod) % mod;
}
cout << ans << endl;
return 0;
}

浙公网安备 33010602011771号