[做题记录-计数相关] [AGC023E] Inversions

题意

一个长度为\(n\)的排列数组, 每个位置有上限的限制, 求所有合法的排列的逆序数的和。

\(n \leq 2 \times 10^5\)

题解

经典套路是先计数一下序列的个数然后考虑每一对数对答案的贡献。

考虑从大往小填数, 记\(b_i\)表示满足\(a_j \ge i\)的位置\(j\)的个数(即数\(i\)可以放的位置数)。填数\(i\)的时候已经有\(n - i\)个位置被更大的数占用, 那么合法的排列个数为:

\[cnt = \prod_{i = 1}^{n} \bigl(b_i - (n - i)\bigr) \]

然后现在对于序列上一对位置考虑。不妨设\(i < j\), \(p_i\)表示填好以后\(i\)位置上的数是什么。

\(a_i = a_j\)时, 显然\(p_i > p_j\)和\(p_i < p_j\)的情况数相同, 贡献是\(\frac{cnt}{2}\)

\(a_i < a_j\)的时候, 讨论\(p_j\)的取值。当\(p_j \in [a_i + 1, a_j]\)的时候必有\(p_i \le a_i < p_j\), 这里肯定是没有贡献的, 那么考虑强行让\(a_j = a_i\), 那么这样的话会发现\(b_k\)(\(k \in [a_i + 1, a_j]\))会减少\(1\)。此时两个位置的上限相同, 逆序和顺序的情况各占一半, 那么这里的贡献是:

\[\frac{cnt}{2} \prod _{k = a_i + 1}^{a_j}\frac{b_k - (n - k) - 1}{b_k - (n - k)} \]

\(a_i > a_j\)的时候, 不妨把\(i, j\)反过来, 讨论就变成了\(a_j < a_i\), 那么这里的顺序对就和上面情况中逆序对的情况是一样的, 用总方案减去出现顺序对的情况即可。

\[cnt - \frac{cnt}{2} \prod _{k = a_j + 1}^{a_i}\frac{b_k - (n - k) - 1}{b_k - (n - k)} \]

然后考虑快速计算, 以\(a_i < a_j\)为例, 考虑从小往大枚举\(a_j\), 维护一个全局的数据结构, 每次\(a_j\)增大的时候全局乘, 查询的时候查位置小于\(j\)的所有位置的权值和, 然后在\(j\)位置加入一个\(\frac{cnt}{2}\)

/*
	QiuQiu /qq
  ____    _           _                 __                
  / __ \  (_)         | |               / /                
 | |  | |  _   _   _  | |  _   _       / /    __ _    __ _ 
 | |  | | | | | | | | | | | | | |     / /    / _` |  / _` |
 | |__| | | | | |_| | | | | |_| |    / /    | (_| | | (_| |
  \___\_\ |_|  \__,_| |_|  \__, |   /_/      \__, |  \__, |
                            __/ |               | |     | |
                           |___/                |_|     |_|
*/

#include <bits/stdc++.h>

using namespace std;

/// Buffered reader over stdin; the single global instance is `Fin`.
///
/// Bytes are pulled from stdin in chunks of MX bytes. The extraction
/// operators mirror a small subset of `std::istream`:
///   - `>> integer` skips everything up to the first digit (remembering a
///     '-' seen on the way), then parses a run of decimal digits;
///   - `>> char`    returns the next byte that is neither ' ' nor '\n';
///   - `>> char*`   copies bytes up to the next ' ', '\n' or end of input
///                  (leading whitespace is NOT skipped).
///
/// End of input is handled explicitly: the integer extractor stores 0 and
/// returns instead of spinning forever, and EOF is tracked as an `int` so it
/// cannot be confused with a data byte.
class Input {
	#define MX 1000000
	private :
		char buf[MX], *p1 = buf, *p2 = buf;
		// Returns the next byte as a value in [0, 255], or EOF once stdin is
		// exhausted. Refills the buffer whenever it has been fully consumed.
		inline int gc() {
			if(p1 == p2) p2 = (p1 = buf) + fread(buf, 1, MX, stdin);
			return p1 == p2 ? EOF : static_cast<unsigned char>(*(p1 ++));
		}
	public :
		// When the build defines Open_File, stdin/stdout are redirected to
		// a.in / a.out before any read takes place.
		Input() {
			#ifdef Open_File
				freopen("a.in", "r", stdin);
				freopen("a.out", "w", stdout);
			#endif
		}
		// Reads a decimal integer into x. If no digit is found before end of
		// input, x is left as 0.
		template <typename T>
		inline Input& operator >>(T &x) {
			x = 0; int f = 1; int a = gc();
			// Stop at EOF as well: isdigit(EOF) is false, so without this
			// check the skip loop would never terminate on exhausted input.
			for(; a != EOF && ! isdigit(a); a = gc()) if(a == '-') f = -1;
			for(; isdigit(a); a = gc())
				x = x * 10 + (a - '0');
			x *= f;
			return *this;
		}
		// Reads the next byte that is not a space or newline. At end of input
		// ch receives EOF converted to char.
		inline Input& operator >>(char &ch) {
			while(1) {
				int c = gc();
				if(c != '\n' && c != ' ') {
					ch = static_cast<char>(c);
					return *this;
				}
			}
		}
		// Copies bytes into s until a space, newline or end of input, then
		// NUL-terminates. The caller must provide enough room; the delimiter
		// is consumed but not stored.
		inline Input& operator >>(char *s) {
			int p = 0;
			for(int c = gc(); c != '\n' && c != ' ' && c != EOF; c = gc())
				s[p ++] = static_cast<char>(c);
			s[p] = '\0';
			return *this;
		}
	#undef MX
} Fin;

/// Buffered writer to stdout; the single global instance is `Fout`.
///
/// Characters are collected in an MX-byte buffer that is written out when it
/// fills up and once more from the destructor. Supported insertions are
/// integer-like values (printed in decimal), single characters and
/// NUL-terminated strings.
class Output {
	#define MX 1000000
	private :
		// Pending output: [head, tail) holds the bytes not yet written.
		char obuf[MX], *head = obuf, *tail = obuf;
		// Scratch stack for decimal digits, filled least-significant first.
		char digits[105], *top = digits, *base = digits;
		// Writes everything pending to stdout and empties the buffer.
		void flush() {
			fwrite(obuf, 1, tail - head, stdout);
			tail = head;
		}
		// Appends one character, flushing as soon as the buffer is full.
		inline void pc(char ch) {
			*tail = ch;
			++ tail;
			if(tail - head == MX) flush();
		}
	public :
		// Prints n in decimal. T only needs to support comparison with 0,
		// unary minus, `% 10` and `/= 10`.
		template <typename T>
		inline Output& operator << (T n) {
			if(n < 0) {
				pc('-');
				n = -n;
			}
			if(n == 0) pc('0');
			// Peel digits off the low end, then emit them in reverse.
			for(; n; n /= 10) {
				*top = (n % 10) ^ 48;
				++ top;
			}
			while(top != base) {
				-- top;
				pc(*top);
			}
			return *this;
		}
		// Prints a single character verbatim.
		inline Output & operator << (char ch) {
			pc(ch);
			return *this;
		}
		// Prints a NUL-terminated string verbatim.
		inline Output & operator <<(const char *ch) {
			for(const char *it = ch; *it != '\0'; ++ it) pc(*it);
			return *this;
		}
		// Makes sure buffered output reaches stdout when the object dies.
		~Output() { flush(); }
	#undef MX
} Fout;

#define cin Fin
#define cout Fout
#define endl '\n'

using LL = long long;

inline int log2(unsigned int x);
inline int popcount(unsigned x);
inline int popcount(unsigned long long x);

/// Integer modulo the compile-time constant `mod`.
///
/// The stored representative `v` is always kept in [0, mod): every entry
/// point that accepts a plain `int` (constructor, assignment and the `int`
/// overloads of the arithmetic operators) reduces its argument first, so
/// negative or oversized inputs are handled correctly.
///
/// Operators:
///   +, -, *, +=, -=, *=   modular arithmetic
///   ~x                    x^(mod-2); this is the multiplicative inverse when
///                         mod is prime and x != 0, and 0 when x == 0
///   x ^ k                 x^k for k >= 0 (binary exponentiation)
///   /=                    plain INTEGER division of the representative, not
///                         modular division; generic decimal-printing code
///                         uses it to strip digits
template <int mod>
class Int {
	private :
		// Maps x from (-mod, mod) into [0, mod): adds mod iff x is negative.
		static inline int Mod(int x) { return x + ((x >> 31) & mod); }
		// Reduces an arbitrary integer to its representative in [0, mod).
		static inline int norm(long long x) { return Mod(static_cast<int>(x % mod)); }
		// Wraps a value that is already reduced, skipping the extra reduction.
		static inline Int raw(int r) { Int t; t.v = r; return t; }
		// Computes x^k mod `mod` by repeated squaring; requires k >= 0.
		static inline int power(int x, int k) {
			int res = 1;
			while(k) {
				if(k & 1) res = 1LL * x * res % mod;
				x = 1LL * x * x % mod; k >>= 1;
			}
			return res;
		}
	public :
		int v;
		Int(int _v = 0) : v(norm(_v)) {}
		operator int() const { return v; }
		
		inline Int operator =(Int x) { return raw(v = x.v); }
		inline Int operator =(int x) { return raw(v = norm(x)); }
		inline Int operator *(Int x) const { return raw(1LL * v * x.v % mod); }
		inline Int operator *(int x) const { return raw(norm(1LL * v * x)); }
		inline Int operator +(Int x) const { return raw( Mod(v + x.v - mod) ); }
		inline Int operator +(int x) const { return *this + Int(x); }
		inline Int operator -(Int x) const { return raw( Mod(v - x.v) ); }
		inline Int operator -(int x) const { return *this - Int(x); }
		inline Int operator ~() const { return raw(power(v, mod - 2)); }
		inline Int operator +=(Int x) { return raw(v = Mod(v + x.v - mod)); }
		inline Int operator +=(int x) { return *this += Int(x); }
		inline Int operator -=(Int x) { return raw(v = Mod(v - x.v)); }
		inline Int operator -=(int x) { return *this -= Int(x); }
		inline Int operator *=(Int x) { return raw(v = 1LL * v * x.v % mod); }
		inline Int operator *=(int x) { return raw(v = norm(1LL * v * x)); }
		// Integer (truncating) division of the representative, see class note.
		inline Int operator /=(Int x) { return raw(v = v / x.v); }
		inline Int operator /=(int x) { return raw(v = v / x); }
		inline Int operator ^(int k) const { return raw(power(v, k)); }
} ;

using mint = Int<(int) (1e9 + 7)>;

const int N = 2e5 + 10;
const mint inv2 = ~ mint(2);

int n;
int a[N], b[N];
mint cnt;

/// Node of a pointer-based segment tree over positions [l, r].
///
/// Each node keeps `cj`, the sum of the values stored in its range, and
/// `tg`, a multiplier that has been applied to `cj` but not yet handed down
/// to the children. Supported operations: set one position, sum a range,
/// and multiply a whole subtree by a factor.
struct Node {
	Node *ls, *rs;   // left / right child (NULL until linked by the builder)
	mint cj, tg;     // range sum and pending multiplier
	int l, r;        // inclusive range covered by this node
	Node() {}
	Node(int _l, int _r) : ls(NULL), rs(NULL), cj(0), tg(1), l(_l), r(_r) {}
	// Recomputes the sum of this node from its two children.
	void upd() {
		cj = ls -> cj + rs -> cj;
	}
	// Multiplies the whole subtree by v: the sum is scaled right away and
	// the factor is remembered for the descendants.
	void downcj(mint v) {
		cj *= v;
		tg *= v;
	}
	// Hands the pending multiplier to both children, if there is one.
	void pushdown() {
		if(tg == 1) return ;
		ls -> downcj(tg);
		rs -> downcj(tg);
		tg = 1;
	}
	// Overwrites the value at position pos with v.
	void modify(int pos, mint v) {
		if(l == r) {
			cj = v;
			return ;
		}
		pushdown();
		int mid = (l + r) >> 1;
		Node *child = (pos <= mid) ? ls : rs;
		child -> modify(pos, v);
		upd();
	}
	// Returns the sum of the values at positions in [L, R]; the query range
	// is expected to intersect this node's range.
	mint qry(int L, int R) {
		if(L <= l && r <= R) return cj;
		pushdown();
		int mid = (l + r) >> 1;
		mint total = 0;
		if(L <= mid) total += ls -> qry(L, R);
		if(R > mid) total += rs -> qry(L, R);
		return total;
	}
	// Multiplies every value in this subtree by v.
	void mul(mint v) {
		downcj(v);
	}
} ;

Node *root;

// Allocates a segment tree covering [l, r] and returns its root. Every node
// starts with sum 0 and multiplier 1; ranges are split at the midpoint until
// they hold a single position.
Node *build(int l, int r) {
	Node *node = new Node(l, r);
	if(l != r) {
		const int mid = (l + r) >> 1;
		node -> ls = build(l, mid);
		node -> rs = build(mid + 1, r);
	}
	return node;
}

using pii = pair<int, int>;

int c[N];
#define lowbit(x) (x & -x)
// Fenwick-tree point update: adds y at index x (1-based, x <= n) by walking
// upwards through the indices obtained by adding the lowest set bit.
void upd(int x, int y) {
	while(x <= n) {
		c[x] += y;
		x += x & -x;
	}
}
// Fenwick-tree prefix query: returns the sum of the values at indices 1..x
// (0 for x == 0) by repeatedly stripping the lowest set bit.
int qry(int x) {
	int total = 0;
	while(x != 0) {
		total += c[x];
		x -= x & -x;
	}
	return total;
}

int main() {
	cin >> n;
	for(int i = 1; i <= n; i ++) cin >> a[i];
	for(int i = 1; i <= n; i ++) b[a[i]] ++;
	for(int i = n; i >= 1; i --) b[i] += b[i + 1];
	cnt = 1;
	for(int i = 1; i <= n; i ++) cnt = cnt * (b[i] - (n - i));
	root = build(1, n);
	mint ans = 0;
	static vector<int> lim[N];
	for(int i = 1; i <= n; i ++) lim[a[i]].push_back(i);
	for(int i = 1; i <= n; i ++) {
		mint value = b[i] - (n - i) - 1;
		value = value * (~ (value + 1));
		root -> mul(value);
		for(int j : lim[i]) ans += root -> qry(1, j);
		for(int j : lim[i]) root -> modify(j, cnt * inv2);
		mint t = lim[i].size();
		ans += t * (t - 1) * inv2 * cnt * inv2;
	}
	//cout << ans << endl;
	for(int i = n; i >= 1; i --) {
		ans += cnt * qry(a[i] - 1);
		upd(a[i], 1); 
	}
//	cerr << ans << endl;
	root = build(1, n);
	for(int i = 1; i <= n; i ++) {
		mint value = b[i] - (n - i) - 1;
		value = value * (~ (value + 1));
		root -> mul(value);
		for(int j : lim[i]) ans -= root -> qry(j, n);
		for(int j : lim[i]) root -> modify(j, cnt * inv2);
		//mint t = lim[i].size();
		//ans -= t * (t - 1) * inv2 * cnt * inv2;
	}
	cout << ans << endl;
	return 0;
}

// Returns floor(log2(x)), i.e. the index of the highest set bit of x, or -1
// when x == 0. (The previous body returned __builtin_ffs(x), which is the
// 1-based index of the LOWEST set bit and therefore not a logarithm.)
inline int log2(unsigned int x) {
	if(x == 0) return -1; // __builtin_clz is undefined for 0
	return static_cast<int>(sizeof(unsigned int) * 8) - 1 - __builtin_clz(x);
}
// Returns the number of set bits in the unsigned int x.
inline int popcount(unsigned int x) {
	const int bits = __builtin_popcount(x);
	return bits;
}
// Returns the number of set bits in the unsigned long long x.
// Uses the `ll` builtin: __builtin_popcountl takes `unsigned long`, which is
// only 32 bits wide on some platforms and would drop the upper half there.
inline int popcount(unsigned long long x) { return __builtin_popcountll(x); }
posted @ 2021-09-17 14:34  HN-wrp  阅读(33)  评论(0编辑  收藏  举报