题解 I. Three Body "蔚来杯"2022牛客暑期多校训练营4

传送门


【大意】

给定 \(K\) 维数组 \(S,T\) ,其中每个元素都是不超过 \(K\) 的正整数。求有多少个位置,使得 \(T\)\(T_{0, 0}\) 元素对齐该位置后,整个 \(T\) 数组的值都不超过 \(S\) 数组对应位置的值


【分析】

我们令 \(g_{v, x_1, x_2, \cdots, x_K}=[T_{x_1, x_2, \cdots, x_K}=v]\) ,表示 \(T\) 数组的对应位置是否为 \(v\)

\(f_{v, x_1, x_2, \cdots, x_K}=[S_{x_1, x_2, \cdots, x_K}<v]\) ,表示 \(S\) 数组的对应位置是否严格小于 \(v\)

那么,对于 \(\displaystyle h_{v, x_1, x_2, \cdots, x_K}=\sum_{i_1-j_1=x_1}\sum_{i_2-j_2=x_2}\cdots\sum_{i_K-j_K=x_K}f_{v, i_1, i_2, \cdots, i_K}\cdot g_{v, j_1, j_2, \cdots, j_K}\) ,表示位置 \((x_1, x_2, \cdots, x_K)\) 处放置 \(T\) 数组的 \(T_{0, 0}\) 元素后,因为 \(T\) 数组中大小为 \(v\) 的元素,产生的冲突次数

我们对 \(h_{v, x_1, x_2, \cdots, x_K}\)\(v=1\)\(v=K\) 进行求和,则 \(\displaystyle h_{x_1, x_2, \cdots, x_K}=\sum_{v=1}^Kh_{v, x_1, x_2, \cdots, x_K}\) ,表示位置 \((x_1, x_2, \cdots, x_K)\) 处放置 \(T\) 数组的 \(T_{0, 0}\) 元素后产生的冲突次数

显然只有冲突为 \(0\) 的位置是可以摆放的,故答案为 \(h\) 中,\(0\) 出现的次数


现在的问题化为如何求解 \(h_{v, x_1, x_2, \cdots, x_K}\)

考虑到求解式子类似减法卷积的形式,我们直接定义 \(g[v][x_1][x_2]\cdots[x_K]\to g'[x_1n_2n_3\cdots n_K+x_2n_3\cdots n_K+\cdots x_{K-1}n_K+x_K]\)

同理定义 \(f'\)\(h'\) ,并对之前未定义的位置置 \(0\)

于是,\(\displaystyle h'_x=\sum_{i-j=x}f'_ig'_j\) ,其中 \(x=x_1n_2n_3\cdots n_K+x_2n_3\cdots n_K+\cdots x_{K-1}n_K+x_K, i=i_1n_2n_3\cdots n_K+i_2n_3\cdots n_K+\cdots i_{K-1}n_K+i_K, j=j_1n_2n_3\cdots n_K+j_2n_3\cdots n_K+\cdots j_{K-1}n_K+j_K\)

而对应过来,有 \(i-j=(i_1-j_1)n_2n_3\cdots n_K+(i_2-j_2)n_3\cdots n_K+\cdots+(i_{K-1}-j_{K-1})n_K+(i_K-j_K)=x\) ,故 \(i_t-j_t=x_t\) 在原对应位置均合法

然而,唯一的问题出在边界上,需要特判一下,在该位置上的值,加上 \(T\) 数组的大小后,是否会超过这一维度的边界

由于超过边界的值也是被置 \(0\) 的,故可能计算出的答案也是 \(0\) ,但该位置因为过于靠近边界,答案是不能被统计的


【代码】

#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define sz(a) (int)a.size()
#define de(a) cout << #a << " = " << a << endl
#define dd(a) cout << #a << " = " << a << " "
#define all(a) a.begin(), a.end()
#define pw(x) (1ll<<(x))
#define lc(x) ((x)<<1)
#define rc(x) ((x)<<1|1)
#define rsz(a, x) (a.resize(x))
typedef unsigned long long ull;
typedef long long ll;
typedef pair<int, int> pii;
typedef vector<int> vi;
typedef double db;

const int P=998244353;
inline int kpow(int a, int x, int p=P) { int ans=1; for(;x;x>>=1, a=(ll)a*a%p) if(x&1) ans=(ll)ans*a%p; return ans; }
inline int exgcd(int a, int b, int &x, int &y) {
	static int g;
	return b?(exgcd(b, a%b, y, x), y-=a/b*x, g):(x=1, y=0, g=a);
}
inline int inv(int a, int p=P) {
	static int x, y;
	return exgcd(a, p, x, y)==1?(x<0?x+p:x):(-1);
}
const int LimBit=19;
const int M=1<<LimBit<<1;
namespace Poly{
	const int G=3;
	struct vir {
		int v;
		vir(int v_=0):v(v_>=P?v_-P:v_) {}
		inline vir operator + (const vir &x) const { return vir(v+x.v); }
		inline vir operator - (const vir &x) const { return vir(v+P-x.v); }
		inline vir operator * (const vir &x) const { return vir((ll)v*x.v%P); }
		
		inline vir operator - () const { return vir(P-v); }
		inline vir operator ! () const { return vir(inv(v)); }
		inline operator int() const { return v; }
	};
	struct poly : public vector<vir> {
		inline friend ostream& operator << (ostream& out, const poly &p) {
			if(!p.empty()) out<<(int)p[0];
			for(int i=1; i<sz(p); ++i) out<<" "<<(int)p[i];
			return out;
		}
	};
	
	int N, N_, Stk[M], curStk, rev[M];
	vir invN, Inv[M], w[2][M];
	inline void init() {
		N_=-1;
		curStk=0;
		Inv[1]=1;
		for(int i=2; i<M; ++i)
			Inv[i]=-vir(P/i)*Inv[P%i];
	}
	
	void work(){
		if(N_==N) return ;
		N_=N;
		int d = __builtin_ctz(N);
		vir x(kpow(G, (P-1)/N)), y=!x;
		w[0][0] = w[1][0] = 1;
		for (int i = 1; i < N; ++i) {
			rev[i] = (rev[i>>1] >> 1) | ((i&1) << (d-1));
			w[0][i]=x*w[0][i-1], w[1][i]=y*w[1][i-1];
		}
		invN=!vir(N);
	}
	
	inline void FFT(vir a[M],int f){
		static auto make = [=](vir w, vir &a, vir &b) { w=w*a; a=b-w; b=b+w; };
		for(int i=0;i<N;++i) if(i<rev[i]) swap(a[i],a[rev[i]]);
		for(int i=1;i<N;i<<=1)
			for(int j=0,t=N/(i<<1);j<N;j+=i<<1)
				for(int k=0,l=0;k<i;k++,l+=t)
					make(w[f][l], a[j+k+i], a[j+k]);
		if(f) for(int i=0;i<N;++i) a[i]=a[i]*invN;
	}
	
	vir p1[M], p0[M];
	inline void get_mul(poly &a, poly &b, int na, int nb) {//3*FFT
		for(N=1;N<na+nb-1;N<<=1);
		for(int i=0; i<na; ++i) p1[i]=(int)a[i]; for(int i=na;i<N;++i) p1[i]=0;
		for(int i=0; i<nb; ++i) p0[i]=(int)b[i]; for(int i=nb;i<N;++i) p0[i]=0;
		work(); FFT(p1,0); FFT(p0,0);
		for(int i=0;i<N;++i) p1[i]=p1[i]*p0[i];
		FFT(p1,1);
		rsz(a, na+nb-1); for(int i=0; i<sz(a); ++i) a[i]=p1[i];
	}
	inline void get_mulT(poly &a, poly &b, poly&c, int na, int nb, int n) {//c=a*r(b)
		c=b;
		reverse(all(c));
		get_mul(c, a, nb, na);
		for(int i=0, j=nb-1; i<n; ++i, ++j)
			c[i]=c[j];
		rsz(c, n);
	}
}
using Poly::poly;

const int MAXN=3e5+10;
int k, n[7], val[MAXN], ship[MAXN], m[6];
poly f, g, h;
inline int alw(int pos) {
	if(h[pos].v)
		return 0;
	for(int i=1; i<=k; ++i) {
		if(pos+(m[i]-1)*n[i+1]>=n[i])
			return 0;
		pos%=n[i+1];
	}
	return 1;
}
inline void work() {
	rsz(h, n[1]);
	for(int t=1; t<=k; ++t) {
		rsz(f, n[1]); rsz(g, n[1]);
		for(int i=0; i<n[1]; ++i) {
			f[i]=(val[i]<t);
			g[i]=(ship[i]==t);
		}
		Poly::get_mulT(f, g, g, n[1], n[1], n[1]);
		for(int i=0; i<n[1]; ++i)
			h[i]=h[i]+g[i];
	}
	int res=0;
	for(int i=0; i<n[1]; ++i)
		if(alw(i)) ++res;
	cout<<res;
}
void draw(int m[], int k, int pos, int x) {
	if(pos>k) {
		ship[x]=k;
		return ;
	}
	for(int i=0; i<m[pos]; ++i)
		draw(m, k, pos+1, x+i*n[pos+1]);
}
inline void init() {
	Poly::init();
	cin>>k;
	for(int i=1; i<=k; ++i) cin>>n[i];
	n[k+1]=1;
	for(int i=k; ~i; --i) n[i]*=n[i+1];
	for(int i=0; i<n[1]; ++i) val[i]=k;
	
	int c; cin>>c;
	for(int i=1; i<=c; ++i) {
		int x=0, y;
		for(int j=1; j<=k; ++j)
			cin>>y, x+=y*n[j+1];
		cin>>y;
		val[x]=y;
	}
	
	for(int i=1; i<=k; ++i) cin>>m[i];
	draw(m, k, 1, 0);
	cin>>c;
	for(int i=1; i<=c; ++i) {
		int x=0, y;
		for(int j=1; j<=k; ++j)
			cin>>y, x+=y*n[j+1];
		cin>>y;
		ship[x]=y;
	}
}

int main() {
	ios::sync_with_stdio(0);
	cin.tie(0); cout.tie(0);
	init();
	work();
	cout.flush();
	return 0;
}
posted @ 2022-08-03 10:34  JustinRochester  阅读(30)  评论(0编辑  收藏  举报