题解:AT_arc188_d [ARC188D] Mirror and Order
posted on 2025-02-19 12:56:34 | under | source
闲话:larsr 说评分虚高。
先判定 \(a,b\) 是否合法,不妨令 \(a\) 升序。考虑排序后的序列,将 \(2i,2i+1\) 视为一组,首先 \(a\) 必须在不同组才能合法。\(s_i,t_i\) 回文,相当于确定了每个串的开头结尾数字。
记 \(s_i,t_i\) 中间元素为 \(c_i\)。对于同一组,让其中 \(a\) 对应的 \(c_x\) 向 \(b\) 对应的 \(c_y\) 连有向边。这样的好处是得到若干环(边的起点、终点序列均为排列)。
现在考虑边权是什么。易发现不可能 \(c_x=c_y\),因为这种情况当且仅当 \(x=y\) 也即同一组元素互为回文,那么违背题目“不能有相同字符串”的条件。所以当 \(a\) 在奇数位置上时连小于号、反之大于号。
考虑一个环是否合法。首先边权全都一样必然不行,反之发现必然有构造方案。具体来说,将环通过 \(>\) 号断开,得到全 \(<\) 号段也就是上升序列,那么要让每个上升序列的结尾大于下一个的开头,发现只需让其“环环相扣”即可(单个元素放置在上方)。建议手玩下。
综上,一组 \(a,b\) 合法当且仅当不存在环满足全是奇数位置或全是偶数位置。感觉是可以容斥的亚子。
回到原题,已经建出了一些边。若成环,就扔掉,当然如果非法直接输出零。对于链(包括孤立点),将其分组,记总共有 \(m\) 条链,\(x\) 条全偶链、\(y\) 条全奇链。
容斥,对于一个非法环,可钦定其系数为 \(-1\)。那么我们枚举全偶链中有 \(i\) 条链被钦定用来组成非法环,全奇链有 \(j\) 条。记 \(f_i\) 为 \(i\) 个元素组成若干环,一个环系数为 \(-1\),系数之积的和。
答案即为:
可以卷积做到 \(O(n\log n)\)。
最后考虑 \(f\) 怎么递推,不妨枚举 \(i\) 元素构成的环的大小,则有:
拆开组合数,可以前缀和做到 \(O(n)\)。
总复杂度 \(O(n\log n)\)。偷懒写了 \(O(n^2)\)。
代码
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define pir pair<int, int>
const int N = 3e3 + 5, mod = 998244353;
int n, bel[N << 1], fa[N], siz[N], c[N], pw, m, x, y, ans, jc[N], jcinv[N], f[N], cir[N];
pir a[N];
inline int qstp(int a, int k) {int res = 1; for(; k; a = a * a % mod, k >>= 1) if(k & 1) res = res * a % mod; return res;}
inline int A(int n, int m) {return n < m ? 0 : jc[n] * jcinv[n - m] % mod;}
inline int C(int n, int m) {return n < m ? 0 : jc[n] * jcinv[m] % mod * jcinv[n - m] % mod;}
inline int find(int u) {return fa[u] == u ? u : fa[u] = find(fa[u]);}
inline void mer(int x, int y){
int fx = find(x), fy = find(y);
if(fx ^ fy){
siz[fy] += siz[fx], c[fy] |= c[fx];
fa[fx] = fy;
}
else cir[fx] |= 1;
}
signed main(){
jc[0] = jcinv[0] = 1;
for(int i = 1; i < N; ++i) jcinv[i] = qstp(jc[i] = jc[i - 1] * i % mod, mod - 2);
f[0] = 1;
for(int i = 1, s = 0; i < N; ++i){
s = (s + f[i - 1] * jcinv[i - 1] % mod) % mod;
f[i] = (mod - 1) * jc[i - 1] % mod * s % mod;
}
cin >> n;
for(int i = 1; i <= n; ++i) scanf("%lld", &a[i].first);
for(int i = 1; i <= n; ++i) scanf("%lld", &a[i].second);
sort(a + 1, a + 1 + n);
for(int i = 1; i <= n; ++i){
if((a[i].first + 1) / 2 == (a[i - 1].first + 1) / 2) {puts("0"); return 0;}
if(a[i].second != -1) bel[a[i].second] = i;
fa[i] = i, siz[i] = 1;
c[i] = (1 << (a[i].first & 1));
}
for(int i = 1; i <= n; ++i){
int bro = ((a[i].first & 1) ? (a[i].first + 1) : (a[i].first - 1));
if(bel[bro]) mer(i, bel[bro]);
}
for(int i = 1; i <= n; ++i)
if(fa[i] == i){
if(cir[i]){
if(c[i] != 3) {puts("0"); return 0;}
continue;
}
++m;
if(c[i] == 1) ++x;
if(c[i] == 2) ++y;
}
for(int i = 0; i <= x; ++i)
for(int j = 0; j <= y; ++j){
int res = f[i] * C(x, i) % mod * f[j] % mod * C(y, j) % mod * jc[m - i - j] % mod;
ans = (ans + res) % mod;
}
cout << ans;
return 0;
}

浙公网安备 33010602011771号