【ICPC2024 上海F】Fast Bogosort
分治NTT
定义排序算法:
- 若当前序列有序,结束;
- 将当前序列划分成尽可能多的区间,使得每个区间的坐标范围与值域相同;
- 对于每个分出的区间,如果 \(l<r\),则执行 \(\text{shuffle(l, r)}\)。
给一个排列,问使整个序列变得有序所需的期望 \(\text{shuffle}\) 次数,对 \(998244353\) 取模。
Solution
当前序列按照下标分成 \(\cup_{i=1}^m [l_i, r_i]\)。
设 \(f(n)\) 为长度为 \(n\) 的随机排列的期望 \(\text{shuffle}\) 次数。
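\[Ans = \sum_{i=1}^m \left( [l_i < r_i] + f(r_i - l_i + 1) \right)
\]
其中 \([l_i < r_i]\) 表示只有长度大于 \(1\) 的区间才需要先执行一次 \(\text{shuffle}\)(长度为 \(1\) 的区间贡献为 \(0\))。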
设 \(g(n)\) 表示长度为 \(n\) 且不可被划分的排列的方案数,则
\[g(1) = 1, g(n) = n! - \sum_{i=1}^{n-1}g(i)(n-i)!
\]
设 \(p(n,l)\) 表示长度为 \(n\) 的随机排列,第一个满足条件的区间长度为 \(l\) 的概率。
\[p(n, l) = \frac{g(l)(n-l)!}{n!}
\]
\[f(0)=f(1)=0\\
\begin{aligned}
f(n) &=1 - \frac{1}{n} + \sum_{i=1}^n p(n, i)\left[ f(i) + f(n-i)\right] \\
&=1 - \frac1n + \frac{1}{n!}\sum_{i=1}^n \left[ g(i)f(i)(n-i)! + g(i)f(n-i)(n-i)! \right] \\
&= \frac{n!}{n!-g(n)} \left(1 - \frac1n + \frac1{n!}\sum_{i=1}^{n-1} \left[ g(i)f(i)(n-i)! + g(i)f(n-i)(n-i)! \right]\right)
\end{aligned}
\]
对 \(f,g\) 分别进行分治 NTT 即可。
#include "bits/stdc++.h"
using namespace std;
// Short aliases for the scalar types used throughout the file (lll is the 128-bit signed integer).
using ui=unsigned; using db=long double; using ll=long long; using ull=unsigned long long; using lll=__int128;
// Pair aliases built from the scalar aliases above.
using pii=pair<int,int>; using pll=pair<ll,ll>;
// Stream extraction for pairs: reads `first`, then `second`, and hands the stream back for chaining.
template<class T1, class T2>
istream &operator>>(istream &cin, pair<T1, T2> &a)
{
    cin >> a.first;
    cin >> a.second;
    return cin;
}
// Reads tuple elements Index, Index+1, ... from the stream in order.
// The recursion ends once Index reaches the tuple size (nothing is read then).
template <std::size_t Index=0, typename... Ts>
void tuple_read(std::istream &is, std::tuple<Ts...> &t)
{
    if constexpr (Index < sizeof...(Ts))
    {
        is >> std::get<Index>(t);
        tuple_read<Index+1>(is, t);
    }
}
// Stream extraction for tuples: fills every element left to right and returns the stream.
template <typename... Ts>
std::istream &operator>>(std::istream &is, std::tuple<Ts...> &t)
{
    tuple_read(is, t);
    return is;
}
// Stream extraction for vectors: overwrites each existing element in index order.
// The vector is not resized, so the caller chooses how many values are read.
template<class T1>
istream &operator>>(istream &cin, vector<T1> &a)
{
    for (auto it = a.begin(); it != a.end(); ++it)
        cin >> *it;
    return cin;
}
// Stream extraction for valarrays: overwrites each existing element in index order (no resizing).
template<class T1>
istream &operator>>(istream &cin, valarray<T1> &a)
{
    for (std::size_t idx = 0; idx < a.size(); ++idx)
        cin >> a[idx];
    return cin;
}
// "Check-min": lowers x to y when y is strictly smaller.
// Returns true when x was updated, false otherwise.
template<class T1, class T2>
bool cmin(T1 &x, const T2 &y)
{
    if (!(y<x)) return false;
    x=y;
    return true;
}
// "Check-max": raises x to y when y is strictly larger.
// Returns true when x was updated, false otherwise.
template<class T1, class T2>
bool cmax(T1 &x, const T2 &y)
{
    if (!(x<y)) return false;
    x=y;
    return true;
}
// Stream extraction for 128-bit signed integers (__int128, aliased as lll in this file).
// Reads one whitespace-delimited token consisting of an optional '+'/'-' sign followed by decimal digits.
// On any failure (no token available, sign without digits, non-digit character) x is left at 0
// and failbit is set on the stream. Values outside the __int128 range are not detected.
std::istream &operator>>(std::istream &cin, __int128 &x)
{
    x = 0;
    // A local string: a function-static buffer would keep the previous token when the read fails.
    std::string token;
    if (!(cin >> token)) return cin;

    std::size_t pos = 0;
    bool negative = false;
    if (token[0] == '-' || token[0] == '+')
    {
        negative = (token[0] == '-');
        pos = 1;
    }
    if (pos == token.size())
    {
        cin.setstate(std::ios_base::failbit);
        return cin;
    }

    // Accumulate as a negative number so the most negative __int128 value can be parsed too.
    __int128 acc = 0;
    for (; pos < token.size(); ++pos)
    {
        const char c = token[pos];
        if (c < '0' || c > '9')
        {
            cin.setstate(std::ios_base::failbit);
            return cin;
        }
        acc = acc * 10 - (c - '0');
    }
    x = negative ? acc : -acc;
    return cin;
}
// Stream insertion for 128-bit signed integers (__int128, aliased as lll in this file).
// Writes the decimal representation, with a leading '-' for negative values.
std::ostream &operator<<(std::ostream &cout, __int128 x)
{
    // 39 decimal digits plus an optional sign fit comfortably in 48 characters.
    char buf[48];
    int pos = 48;
    const bool negative = x < 0;
    do
    {
        // Take the digit's absolute value instead of negating x, which would overflow for the minimum value.
        const int digit = static_cast<int>(x % 10);
        buf[--pos] = static_cast<char>('0' + (digit < 0 ? -digit : digit));
        x /= 10;
    } while (x != 0);
    if (negative) buf[--pos] = '-';
    cout.write(buf + pos, 48 - pos);
    return cout;
}
// Local builds (ONLINE_JUDGE not defined) include the author's own IO/debug headers.
// On the judge those headers are not included; instead each debug macro below
// is defined to expand to a lone `;`, so calls to them compile to empty statements.
#if !defined(ONLINE_JUDGE)
#include "my_header/IO.h"
#include "my_header/defs.h"
#else
#define dbg(...) ;
#define dbgx(...) ;
#define dbg1(x) ;
#define dbg2(x) ;
#define dbg3(x) ;
#define DEBUG(msg) ;
#define REGISTER_OUTPUT_NAME(Type, ...) ;
#define REGISTER_OUTPUT(Type, ...) ;
#endif
// all(x): expands to the begin/end iterator pair of x, for passing a whole container to an algorithm.
#define all(x) (x).begin(),(x).end()
// print(...) / err(...): write the result of format(...) to cout / cerr.
// NOTE(review): `format` is unqualified — presumably std::format (C++20) via `using namespace std`; confirm.
#define print(...) cout<<format(__VA_ARGS__)
#define err(...) cerr<<format(__VA_ARGS__)
const int mod1 = 998244353, mod2 = 1e9+7;
#define MOD1
// p is the working modulus: mod1 when MOD1 is defined, mod2 otherwise.
#ifdef MOD1
const int p = mod1;
#else
const int p = mod2;
#endif
// Modular exponentiation by squaring: returns x^y mod m as a value in [0, m).
//   x : base; any long long is accepted (it is reduced into [0, m) first, so large
//       or negative bases neither overflow the squaring step nor give negative results).
//   y : exponent; defaults to p-2, so fpow(x) is the modular inverse of x for prime p.
//       Negative exponents are not supported and are treated like 0.
//   m : modulus, must be positive; defaults to p.
int fpow(long long x, long long y=p-2, int m=p)
{
    x %= m;
    if (x < 0) x += m;
    int r = 1 % m;   // stays in range even for m == 1
    for (; y > 0; y >>= 1, x = x*x%m)
        if (y & 1) r = (long long)r*x%m;
    return r;
}
#define BINOM_ // Notice value of LIM !!
#ifdef BINOM_
const int LIM = 1e6+5;
namespace BINOM
{
int fac[LIM], inv[LIM], Inv[LIM];
void init()
{
fac[0] = fac[1] = inv[0] = inv[1] = Inv[0] = 1;
for (int i=2; i<LIM; ++i) fac[i] = (ll)fac[i-1]*i%p, inv[i] = p-(ll)p/i*inv[p%i]%p;
for (int i=1; i<LIM; ++i) Inv[i] = (ll)Inv[i-1]*inv[i]%p;
}
int C(int x, int y)
{
if (x<0||y<0||y>x) return 0;
return (ll)fac[x]*Inv[y]%p*Inv[x-y]%p;
}
int _=(init(), 0);
};
using BINOM::C; using BINOM::fac; using BINOM::inv; using BINOM::Inv;
#endif
// N: capacity of the NTT work arrays (and of the f/g tables declared after this struct).
// mod: the NTT modulus; it has the same value as mod1, which fpow uses by default.
const int N = 1<<20, mod = 998244353;
// Iterative radix-2 number-theoretic transform modulo `mod`, using 3 as the generator,
// plus a helper that multiplies two polynomials with it.
struct NTT
{
// re: bit-reversal permutation for the current length.
// w[0][i] / w[1][i]: i-th power of the current root of unity / of its inverse.
// t0, t1: scratch buffers used by solve().
ll re[N], w[2][N], t0[N], t1[N];
// Prepares the tables for a transform: picks len as the smallest power of two
// strictly greater than n, fills re[] and both rows of w[] for that len, and returns len.
// re[0] is never written here; it relies on the array's zero initialization.
int NTT_init(int n)
{
int len=1, bit=0;
while (len <= n) len<<=1,++bit;
for (int i=1; i<len; ++i) re[i] = (re[i>>1] >> 1) | ((i&1) << (bit-1));
w[0][0] = w[1][0] = 1;
w[0][1] = fpow(3, (mod-1)/len); // root of unity of order len
w[1][1] = fpow(w[0][1]); // its modular inverse, used by the inverse transform
for(int o=0; o<2; ++o)
for(int i=2; i<len; ++i)
w[o][i] = w[o][i-1] * w[o][1] % mod;
return len;
}
// In-place transform of a[0..n). o == 0 is the forward transform; o != 0 is the inverse
// (inverse roots, followed by multiplication with the modular inverse of n).
// n must be the length most recently returned by NTT_init, since re[] and w[] are built for it.
void ntt(ll* a, int n, int o=0)
{
// Reorder the input into bit-reversed order; the i < re[i] test swaps each pair once.
for (int i=1; i<n; ++i) if (i < re[i]) swap(a[i], a[re[i]]);
int R;
// Butterfly passes: k is the half-block size, st the stride into the root table for this pass.
for (int k=1; k<n; k<<=1)
for (int i=0, t=k<<1, st=n/t; i<n; i+=t)
for (int j=0, nw=0; j<k; ++j, nw+=st)
{
R = a[i+j+k] * w[o][nw] % mod;
a[i+j+k] = (a[i+j] - R + mod) % mod;
a[i+j] = (a[i+j] + R) % mod;
}
if(o)
{
ll t = fpow(n);
for (int i=0; i<n; ++i) a[i] = a[i] * t % mod;
}
}
// Multiplies a and b (each truncated to its first n coefficients) and returns the first
// n coefficients of the product, reduced mod `mod`.
// NOTE(review): the transform length is only guaranteed to exceed n, not the full product
// length, so when the truncated inputs' product has more than len coefficients the excess
// wraps around cyclically into the lowest result indices. The callers in this file only
// read indices >= a.size(), which the wrapped terms do not reach; re-check before reusing
// this as a general "a*b mod x^n".
vector<int> solve(vector<int>& a, vector<int>& b, int n)
{
vector <int> res(n);
int len = NTT_init(n);
memset(t0, 0, sizeof(ll)*len);
for (int i=0; i<(int)a.size() && i<n; ++i) t0[i] = a[i];
memset(t1, 0, sizeof(ll)*len);
for (int i=0; i<(int)b.size() && i<n; ++i) t1[i] = b[i];
ntt(t0, len);
ntt(t1, len);
// Pointwise product in the transformed domain, then transform back.
for (int i=0; i<len; ++i) t0[i] = t0[i] * t1[i] % mod;
ntt(t0, len, 1);
for (int i=0; i<n; ++i) res[i] = t0[i];
return res;
}
} ntt;
// f[n]: expected number of shuffles for a uniformly random permutation of length n (mod p), filled by cdqf.
// g[n]: number of length-n permutations that cannot be split into smaller blocks (mod p), filled by cdqg.
int f[N], g[N];
void cdqg(int l, int r)
{
if (l == r)
{
if (l != 1) g[l] = (fac[l] + p - g[l]) % p;
return ;
}
int mid = (l+r)>>1;
cdqg(l, mid);
vector<int> a(mid-l+1), b(r-l+1);
// g
for (int i=l; i<=mid; ++i) a[i-l] = g[i];
for (int i=1; i<=r-l; ++i) b[i] = fac[i];
auto G=ntt.solve(a, b, r-l+1);
for (int i=mid+1; i<=r; ++i) g[i] = (g[i] + G[i-l]) % p;
cdqg(mid+1, r);
}
void cdqf(int l, int r)
{
if (l == r)
{
if (l != 1) f[l] = (((ll)1+p-inv[l])+(ll)Inv[l]*f[l])%p*fac[l]%p*fpow(fac[l]+p-(ll)g[l])%p;
return ;
}
int mid = (l+r)>>1;
cdqf(l, mid);
vector<int> a(mid-l+1), b(r-l+1);
// f1
for (int i=l; i<=mid; ++i) a[i-l] = (ll)g[i] * f[i] % p;
for (int i=1; i<=r-l; ++i) b[i] = fac[i];
auto F1=ntt.solve(a, b, r-l+1);
for (int i=mid+1; i<=r; ++i) f[i] = (f[i] + F1[i-l]) % p;
// f2
for (int i=l; i<=mid; ++i) a[i-l] = ((ll)fac[i] * f[i]) % p;
for (int i=1; i<=r-l; ++i) b[i] = g[i];
auto F2 = ntt.solve(a, b, r-l+1);
for (int i=mid+1; i<=r; ++i) f[i] = (f[i] + F2[i-l]) % p;
cdqf(mid+1, r);
}
// Precomputes g[1..n] and then f[1..n].
// g is built first because the recurrence evaluated by cdqf reads the finished g table.
void init(int n)
{
    // Base cases.
    g[1] = 1;
    f[1] = 0;

    cdqg(1, n);
    cdqf(1, n);
}
// Reads a permutation of 1..n, splits it into the maximal number of consecutive blocks whose
// index range equals their value range, and prints the expected number of shuffles mod p:
// every block longer than one element costs one shuffle plus f[length] for the random
// permutation that shuffle produces.
int main()
{
    ios::sync_with_stdio(0); cin.tie(0);

    int n;
    // Nothing to sort (or no readable n): the answer is 0, and init() must not run on an empty range.
    if (!(cin >> n) || n < 1)
    {
        cout << 0 << "\n";
        return 0;
    }

    // Precompute only up to the actual input size instead of a fixed bound.
    // n must stay below LIM and N, the capacities of the factorial and NTT tables.
    init(n);

    vector<int> a(n);
    cin >> a;

    ll ans = 0;
    for (int l = 0; l < n; )
    {
        // Grow [l, r] until the values seen are exactly l+1 .. r+1 (1-based values, 0-based indices).
        int r = l, mi = a[l], ma = a[l];
        while (r+1 < n && (mi != l+1 || ma != r+1))
        {
            ++r;
            mi = min(mi, a[r]);
            ma = max(ma, a[r]);
        }
        const int len = r-l+1;
        // Blocks of length 1 are already in place and are never shuffled.
        if (len > 1) ans = (ans + 1 + f[len]) % p;
        l = r+1;
    }
    cout << ans << "\n";
    return 0;
}
致虚极,守静笃,万物并作,吾以观其复