ABC306F 题解
题目链接
题目大意
对于 \(S_1 \cap S_2 = \emptyset\),
定义长度为 \(|S_1|+|S_2|\) 的序列 \(A\),为 \(S_1\cup S_2\) 排序后的结果。
定义二元函数 \(f(S_1,S_2)=\sum\limits_{1\leq i\leq |S_1|+|S_2|} i\times[A_i\in S_1]\) 。
给定 \(n\) 个大小为 \(m\) 的正整数集合 \(S\),保证给的这 \(nm\) 个数互不相同,求 \(\sum\limits_{1\leq i<j\leq n} f(S_i,S_j)\)
\(n\leq 1\times 10^4\)
\(m\leq 1\times 10^2\)
题目分析
注意到答案中要求 \(i<j\),
考虑对于计算每个集合中的每个数的对答案的贡献。
可以发现,元素 \(s\) 在 \(f(S_i,S_j)\) 的贡献就是 \(S_1\cup S_2\) 中小于 \(s\) 的数的个数+1。
所以对于共 \(n\times m\) 个元素,计算它们在每种合法的 \(f(S_i,S_j) (1\leq i<j\leq n)\) 的贡献即可。
具体来说,将所有元素从小到大排序并保存该元素属于第几个集合,从前到后依次扫描这些元素。
设第 \(i\) 个元素属于第 \(p\) 个集合,设前 \(i-1\) 个元素中属于集合 \(k\) 的元素个数有 \(cnt_k\)个,那么第 \(i\) 个元素对答案的贡献就是 \(\sum\limits_{p<j \leq n}{cnt_j+cnt_p+1}\)。
这个可以拆成关于 \(\sum\limits_{p<j\leq n} cnt_j\) 和 \(cnt_p\) 的形式。
这两个都可以用树状数组或线段树维护,支持区间求和单点加即可。
就做完了。
时间复杂度 \(\mathcal{O}(nm\log nm)\)。 (瓶颈在排序qaq
场上直接拿线段树写了,树状数组常数更小。
参考代码
#include<bits/stdc++.h>
using namespace std;
template<typename T>
void read(T &x){
	x=0;
	int sgn=0;
	char c=getchar();
	while(!isdigit(c)) sgn|=(c=='-'),c=getchar();
	while(isdigit(c)) x=x*10-'0'+c,c=getchar();
	if(sgn) x=-x;
}
const int N=1000010;
struct segment{
	struct node{
		int l,r;
		long long sum;
	};
	
	node tr[N<<2];
	
	#define ls u<<1
	#define rs u<<1|1
	
	void build(int u,int l,int r){
		tr[u].l=l,tr[u].r=r,tr[u].sum=r-l+1; //init 1,1,1,....
		if(l==r) return ;
		int mid=(l+r)>>1;
		build(ls,l,mid),build(rs,mid+1,r);
	}
	
	void add(int u,int pos){
		int l=tr[u].l,r=tr[u].r;
		if(l==r){
			tr[u].sum++;
			return ;
		}
		int mid=(l+r)>>1;
		if(pos<=mid) add(ls,pos);
		else add(rs,pos);
		tr[u].sum=tr[ls].sum+tr[rs].sum;
	}
	
	long long query(int u,int L,int R){
		int l=tr[u].l,r=tr[u].r;
		if(L<=l&&r<=R) return tr[u].sum;
		int mid=(l+r)>>1;
		long long ret=0;
		if(L<=mid) ret+=query(ls,L,R);
		if(mid<R) ret+=query(rs,L,R);
		return ret;
	}
	
	#undef ls
	#undef rs
}sgt;
int n,m,idx;
pair<int,int> p[N];
long long ans=0;
int main(){
	
	read(n),read(m);
	for(int i=1; i<=n; i++){
		for(int j=1,a; j<=m; j++){
			read(a);
			p[++idx]={a,i};
		}
	}
	
	sort(p+1,p+1+n*m);
	int t=n*m;
	sgt.build(1,1,n);
	for(int i=1; i<=t; i++){
		ans+=sgt.query(1,p[i].second+1,n)+((sgt.query(1,p[i].second,p[i].second)-1)*(n-p[i].second));
//		cout<<ans<<' ';
		sgt.add(1,p[i].second);
	}//cout<<endl;
	
	cout<<ans;
	
	return 0;
}

 
                
            
         
         浙公网安备 33010602011771号
浙公网安备 33010602011771号