【ICPC2024 上海F】Fast Bogosort

分治NTT

定义排序算法:

  • 若当前序列有序,结束;
  • 将当前序列划分成尽可能多的区间,使得每个区间的下标范围与值域相同;
  • 对于每个分出的区间,如果 \(l<r\),则执行 \(\text{shuffle(l, r)}\)

给定一个排列,求使整个序列变得有序所需执行的 \(\text{shuffle}\) 次数的期望,对 \(998244353\) 取模。

Solution

当前序列按照下标分成 \(\cup_{i=1}^m [l_i, r_i]\)

\(f(n)\) 为长度为 \(n\) 的随机排列的期望 \(\text{shuffle}\) 次数。

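\[Ans = \sum_{i=1}^m \Big( [l_i < r_i] + f(r_i - l_i + 1) \Big) \]

其中 \([l_i<r_i]\) 表示只有长度大于 \(1\) 的区间才需要先执行一次 \(\text{shuffle}\)(长度为 \(1\) 的区间不执行,且 \(f(1)=0\))。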

\(g(n)\) 表示长度为 \(n\) 且不可被划分的排列的方案数,则

\[g(1) = 1, g(n) = n! - \sum_{i=1}^{n-1}g(i)(n-i)! \]

\(p(n,l)\) 表示长度为 \(n\) 的随机排列,第一个满足条件的区间长度为 \(l\) 的概率。

\[p(n, l) = \frac{g(l)(n-l)!}{n!} \]

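\[f(0)=f(1)=0\\ \begin{aligned} f(n) &=1 - \frac{1}{n} + \sum_{i=1}^n p(n, i)\left[ f(i) + f(n-i)\right] \\ &=1 - \frac1n + \frac{1}{n!}\sum_{i=1}^n \left[ g(i)f(i)(n-i)! + g(i)f(n-i)(n-i)! \right] \\ &= \frac{n!}{n!-g(n)} \left(1 - \frac1n + \frac1{n!}\sum_{i=1}^{n-1} \left[ g(i)f(i)(n-i)! + g(i)f(n-i)(n-i)! \right]\right) \end{aligned} \]

其中 \(1-\frac1n = 1-p(n,1)\) 是第一个区间长度大于 \(1\)(需要执行一次 \(\text{shuffle}\))的概率;最后一步是把 \(i=n\) 的项 \(\frac{g(n)}{n!}f(n)\) 移到等式左边后解出 \(f(n)\)。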

\(f,g\) 分别进行分治 NTT 即可。

#include "bits/stdc++.h"
using namespace std;
// Short aliases for the integer, floating-point and pair types used throughout this file.
// lll is the compiler-provided 128-bit signed integer (__int128).
using ui=unsigned; using db=long double; using ll=long long; using ull=unsigned long long; using lll=__int128;
using pii=pair<int,int>; using pll=pair<ll,ll>;
// Stream extraction for std::pair: reads `first`, then `second`, and returns the stream.
template<class T1, class T2> istream &operator>>(istream &cin, pair<T1, T2> &a)
{
    cin >> a.first;
    cin >> a.second;
    return cin;
}
// Reads tuple elements Index, Index+1, ... from `is` in order.
// Does nothing once Index has reached the number of tuple elements.
template <std::size_t Index=0, typename... Ts> void tuple_read(std::istream &is, std::tuple<Ts...> &t)
{
    if constexpr (Index < sizeof...(Ts))
    {
        is >> std::get<Index>(t);
        tuple_read<Index+1>(is, t);
    }
}
// Stream extraction for std::tuple: fills every element from left to right and returns the stream.
template <typename... Ts> std::istream &operator>>(std::istream &is, std::tuple<Ts...> &t)
{
    tuple_read(is, t);
    return is;
}
// Stream extraction for std::vector: fills the existing elements (the size is not changed) in index order.
template<class T1> istream &operator>>(istream &cin, vector<T1> &a)
{
    for (auto it = a.begin(); it != a.end(); ++it)
        cin >> *it;
    return cin;
}
// Stream extraction for std::valarray: fills the existing elements (the size is not changed) in index order.
template<class T1> istream &operator>>(istream &cin, valarray<T1> &a)
{
    for (size_t idx = 0; idx < a.size(); ++idx)
        cin >> a[idx];
    return cin;
}
// "Check-min": lowers x to y when y is strictly smaller; returns whether x was changed.
template<class T1, class T2> bool cmin(T1 &x, const T2 &y)
{
    if (!(y<x)) return false;
    x=y;
    return true;
}
// "Check-max": raises x to y when y is strictly larger; returns whether x was changed.
template<class T1, class T2> bool cmax(T1 &x, const T2 &y)
{
    if (!(x<y)) return false;
    x=y;
    return true;
}
// Stream extraction for the 128-bit signed integer (lll == __int128).
// Reads one whitespace-delimited token consisting of an optional '+'/'-' sign followed by decimal digits.
//  - x is reset to 0 first; if no token can be read (e.g. EOF) the stream's fail state is left as set
//    by the string extraction and x stays 0 (a stale token is never re-parsed).
//  - A token with no digits or with any non-digit character sets failbit and leaves x == 0.
//  - Overflow of the 128-bit range is not detected.
std::istream &operator>>(std::istream &cin, __int128 &x)
{
    x = 0;
    std::string token;
    if (!(cin >> token)) return cin;

    std::size_t pos = 0;
    bool negative = false;
    if (token[0] == '-' || token[0] == '+')
    {
        negative = (token[0] == '-');
        pos = 1;
    }
    if (pos == token.size())    // a lone sign is not a number
    {
        cin.setstate(std::ios::failbit);
        return cin;
    }

    __int128 value = 0;
    for (; pos < token.size(); ++pos)
    {
        const char c = token[pos];
        if (c < '0' || c > '9')
        {
            cin.setstate(std::ios::failbit);
            return cin;
        }
        value = value * 10 + (c - '0');
    }
    x = negative ? -value : value;
    return cin;
}
// Stream insertion for the 128-bit signed integer (lll == __int128).
// Writes the decimal representation, with a leading '-' for negative values; the most negative
// value is handled because digits are taken from the (possibly negative) remainder without negating x.
// The number is written as a single string, so stream width/fill settings apply to the whole number.
std::ostream &operator<<(std::ostream &cout, __int128 x)
{
    char digits[48];            // 2^127 has 39 decimal digits
    int count = 0;
    const bool negative = x < 0;
    do
    {
        const int d = static_cast<int>(x % 10);   // in [-9, 9]; sign follows x
        digits[count++] = static_cast<char>('0' + (d < 0 ? -d : d));
        x /= 10;
    } while (x != 0);

    std::string text;
    text.reserve(static_cast<std::size_t>(count) + 1);
    if (negative) text.push_back('-');
    while (count--) text.push_back(digits[count]);
    return cout << text;
}
// Local builds (ONLINE_JUDGE not defined) pull in the author's private debugging headers;
// presumably these provide the dbg*/DEBUG/REGISTER_OUTPUT* facilities stubbed out below.
#if !defined(ONLINE_JUDGE)
#include "my_header/IO.h"
#include "my_header/defs.h"
#else
// On the judge every debugging macro expands to an empty statement.
#define dbg(...) ;
#define dbgx(...) ;
#define dbg1(x) ;
#define dbg2(x) ;
#define dbg3(x) ;
#define DEBUG(msg) ;
#define REGISTER_OUTPUT_NAME(Type, ...) ;
#define REGISTER_OUTPUT(Type, ...) ;
#endif
// all(c): expands to the begin/end iterator pair of container c, for algorithm calls.
#define all(x) (x).begin(),(x).end()
// print/err: write format(...) to cout / cerr. They need a `format` function in scope
// wherever they are used (e.g. std::format from C++20).
#define print(...) cout<<format(__VA_ARGS__)
#define err(...) cerr<<format(__VA_ARGS__)
const int mod1 = 998244353, mod2 = 1000000007;
#define MOD1
// Active modulus p: mod1 when MOD1 is defined, mod2 otherwise.
#ifdef MOD1
const int p = mod1;
#else
const int p = mod2;
#endif
// Modular exponentiation by repeated squaring: returns x^y mod m as a value in [0, m).
//   x : base; any long long value, reduced into [0, m) first, so negative or very large
//       bases neither overflow nor produce a negative result.
//   y : exponent, expected to be >= 0 (a negative exponent is treated as 0).
//       Defaults to p-2, so fpow(x) is the modular inverse of x when p is prime and x is not
//       a multiple of p (Fermat's little theorem).
//   m : modulus, expected to be positive; defaults to the active modulus p.
int fpow(long long x, long long y = p - 2, int m = p)
{
    x %= m;
    if (x < 0) x += m;
    int r = 1 % m;               // yields 0 when m == 1
    for (; y > 0; y >>= 1, x = x * x % m)
        if (y & 1) r = (long long)r * x % m;
    return r;
}
#define BINOM_ // Notice value of LIM !!
#ifdef BINOM_
const int LIM = 1e6+5;
namespace BINOM
{
int fac[LIM], inv[LIM], Inv[LIM];
void init()
{
    fac[0] = fac[1] = inv[0] = inv[1] = Inv[0] = 1;
    for (int i=2; i<LIM; ++i) fac[i] = (ll)fac[i-1]*i%p, inv[i] = p-(ll)p/i*inv[p%i]%p;
    for (int i=1; i<LIM; ++i) Inv[i] = (ll)Inv[i-1]*inv[i]%p;
}
int C(int x, int y)
{
    if (x<0||y<0||y>x) return 0;
    return (ll)fac[x]*Inv[y]%p*Inv[x-y]%p;
}
int _=(init(), 0);
};
using BINOM::C; using BINOM::fac; using BINOM::inv; using BINOM::Inv;
#endif

const int N = 1<<20, mod = 998244353;
struct NTT
{
	ll re[N], w[2][N], t0[N], t1[N];
	int NTT_init(int n)
	{
		int len=1, bit=0;
		while (len <= n) len<<=1,++bit;
		for (int i=1; i<len; ++i) re[i] = (re[i>>1] >> 1) | ((i&1) << (bit-1));
		w[0][0] = w[1][0] = 1;
		w[0][1] = fpow(3, (mod-1)/len);
		w[1][1] = fpow(w[0][1]);
		for(int o=0; o<2; ++o)
			for(int i=2; i<len; ++i)
				w[o][i] = w[o][i-1] * w[o][1] % mod;
		return len;
	}
	void ntt(ll* a, int n, int o=0)
	{
		for (int i=1; i<n; ++i) if (i < re[i]) swap(a[i], a[re[i]]);
		int R;
		for (int k=1; k<n; k<<=1)
			for (int i=0, t=k<<1, st=n/t; i<n; i+=t)
				for (int j=0, nw=0; j<k; ++j, nw+=st)
				{
					R = a[i+j+k] * w[o][nw] % mod;
					a[i+j+k] = (a[i+j] - R + mod) % mod;
					a[i+j] = (a[i+j] + R) % mod;
				}
		if(o)
        {
			ll t = fpow(n);
			for (int i=0; i<n; ++i) a[i] = a[i] * t % mod;
		}
	}

	vector<int> solve(vector<int>& a, vector<int>& b, int n)	// return a*b%(x^n)
	{
		vector <int> res(n);
		int len = NTT_init(n);
		memset(t0, 0, sizeof(ll)*len);
		for (int i=0; i<(int)a.size() && i<n; ++i) t0[i] = a[i];
		memset(t1, 0, sizeof(ll)*len);
		for (int i=0; i<(int)b.size() && i<n; ++i) t1[i] = b[i];
		ntt(t0, len);
		ntt(t1, len);
		for (int i=0; i<len; ++i) t0[i] = t0[i] * t1[i] % mod;
		ntt(t0, len, 1);
		for (int i=0; i<n; ++i) res[i] = t0[i];
		return res;
	}
} ntt;

// Tables filled by init(), all values modulo p:
//   g[n] = number of permutations of length n that cannot be split into smaller blocks,
//   f[n] = expected number of shuffles for a uniformly random permutation of length n.
int f[N], g[N];

// Divide-and-conquer (CDQ) evaluation of
//     g(1) = 1,   g(n) = n! - sum_{i=1}^{n-1} g(i) * (n-i)!
// for all n in [l, r]. While the recursion runs, g[n] accumulates the sum over the i already
// finalised; the leaf turns that sum into the final value. g[1] must be set by the caller.
void cdqg(int l, int r)
{
    if (l == r)
    {
        if (l == 1) return;                 // g[1] is the seed value
        g[l] = (fac[l] + p - g[l]) % p;     // n! minus the accumulated sum
        return;
    }
    const int mid = (l+r)>>1;
    cdqg(l, mid);

    // Push contributions of the finished left half [l, mid] into the right half (mid, r].
    const int span = r-l+1;     // length of [l, r]
    const int half = mid-l+1;   // length of [l, mid]
    vector<int> a(half), b(span);
    for (int k=0; k<half; ++k) a[k] = g[l+k];
    for (int k=1; k<span; ++k) b[k] = fac[k];
    auto G = ntt.solve(a, b, span);
    for (int k=half; k<span; ++k) g[l+k] = (g[l+k] + G[k]) % p;

    cdqg(mid+1, r);
}

void cdqf(int l, int r)
{
    if (l == r) 
    {
        if (l != 1) f[l] = (((ll)1+p-inv[l])+(ll)Inv[l]*f[l])%p*fac[l]%p*fpow(fac[l]+p-(ll)g[l])%p;
        return ;
    }
    int mid = (l+r)>>1;
    cdqf(l, mid);
    vector<int> a(mid-l+1), b(r-l+1);
    // f1
    for (int i=l; i<=mid; ++i) a[i-l] = (ll)g[i] * f[i] % p;
    for (int i=1; i<=r-l; ++i) b[i] = fac[i];
    auto F1=ntt.solve(a, b, r-l+1);
    for (int i=mid+1; i<=r; ++i) f[i] = (f[i] + F1[i-l]) % p;
    // f2
    for (int i=l; i<=mid; ++i) a[i-l] = ((ll)fac[i] * f[i]) % p;
    for (int i=1; i<=r-l; ++i) b[i] = g[i];
    auto F2 = ntt.solve(a, b, r-l+1);
    for (int i=mid+1; i<=r; ++i) f[i] = (f[i] + F2[i-l]) % p;
    cdqf(mid+1, r);
}

// Precomputes g[1..n] and f[1..n] (see cdqg / cdqf).
//   n : largest permutation length needed. Values below 1 are a no-op (previously they caused
//       unbounded recursion in cdqg(1, 0)). n must stay below the sizes of the f/g/fac tables;
//       this is not checked.
// Both tables are cleared first, because the divide-and-conquer passes accumulate partial sums
// in place; without the reset a second call would build on stale values.
void init(int n)
{
    if (n < 1) return;
    for (int i=0; i<=n; ++i) f[i] = g[i] = 0;
    g[1] = 1;   // seed; f[1] = 0 is already in place
    cdqg(1, n);
    cdqf(1, n);
}

// Reads n and a permutation a[0..n-1] of 1..n, splits it into the maximal number of consecutive
// blocks whose index range equals their value range, and prints
//     sum over blocks of ( [length > 1] + f[length] )  mod p,
// i.e. one shuffle for every non-trivial block plus the expected number of further shuffles.
// The f/g tables are built only up to the longest block actually present, instead of a fixed
// 100000 that was both wasteful for small inputs and silently wrong for longer blocks.
int main()
{
	ios::sync_with_stdio(0); cin.tie(0);
	cout<<fixed<<setprecision(15);
    int n;
    if (!(cin >> n) || n <= 0)
    {
        // Missing or empty input: nothing to sort.
        cout << 0 << "\n";
        return 0;
    }
    vector<int> a(n);
    cin >> a;

    // Greedy split: extend the block starting at l until its values are exactly l+1 .. r+1.
    vector<int> lens;
    int longest = 1;
    for (int l=0; l<n; )
    {
        int r=l, mi=a[l], ma=a[l];
        while (r+1 < n && (mi != l+1 || ma != r+1))
        {
            ++r;
            mi = min(mi, a[r]);
            ma = max(ma, a[r]);
        }
        const int len = r-l+1;
        lens.push_back(len);
        longest = max(longest, len);
        l = r+1;
    }

    init(longest);

    ll ans = 0;
    for (int len : lens)
        ans = (ans + (len != 1) + f[len]) % p;
    cout << ans << "\n";
    return 0;
}
posted @ 2025-03-10 21:34  PaperCloud  阅读(75)  评论(0)    收藏  举报