0313

今天是来到南京外培的第一天,打了模拟赛,挂大分……

T1

从特殊情况入手:\(f(m)\) 怎么求?

首先 \(f(m)\) 的上限是长度不小于 \(m\) 的区间的数量,如果答案能达到上限就好了。怎么样才能达到上界?那就必须要每一个长度为 \(m\) 的区间都是 \(0 \sim m-1\) 的排列,如果 \(s_i\) 都为 \(0\) 我们全部按照 \(0\sim m-1\) 依次周期排列就行,但是有一些 \(s_i\)\(1\) 怎么办?可以把这些 \(s_i\)\(1\) 的数放在每个周期的最前面,然后最后一个周期只有 \(s_i=1\) 的,比如 \(1320132013\),于是 \(f(m)\) 的答案一定能取到上限。

对于一般的 \(i\),我们把所有数分为两类 —— 小于 \(i\) 的和大于等于 \(i\) 的,我们先把小于 \(i\) 的按上述的方式放好,然后考虑把大于等于 \(i\) 的一个一个插入序列。我们把小于 \(i\) 的数写作 \(0\),大于 \(i\) 的写作 \(1\)。考虑新增一个 \(1\) 的贡献,和这个 \(1\) 之间有超过 \(i-1\)\(0\) 的位置可以作为合法区间的另一个端点。可以推理得出——我们一定会把 \(1\) 分为若干组,每组插入开头/结尾/一段 \(0\) 周期的缝隙之中。证明:对于不在同一组的两个 \(1\) 以它们为端点的区间是合法的,如果两个组之间的距离不到一个周期,那么还不如把它们合并到一起,为了能塞下尽量多的组,所以每个一个周期塞一个组。

然后我们要考虑如何把这些 \(1\) 分组,考虑每个组的贡献,对于两个端点都是 \(1\) 的区间,合法区间数量是 \(C_{cnt1}^2 -\sum C_{sz_i}^2\),然后对于一个端点是 \(0\),一个端点是 \(1\) 的区间,合法的数量是 \((sz_1+sz_p) \times (tot_1-k+1) + \sum_{i=2}^p sz_i \times (tot1-2k+2)\)。乍一看这是很难计算的,但是我们可以把每一组插入每一个数的贡献差分出来:对于第一组和最后一组,插入第 \(i\) 个数的贡献是 \(tot_1-k+2-i\) ,对于中间的组,插入第 \(i\) 个数的贡献是 \(tot_1-2k+3-i\) 然后我们显然可以每次贪心取最大的贡献。所以就得到结论:先在两边的组轮流插,每个组插了 \(k-1\) 个以后所有区间的贡献都一致了,然后对着所有组轮流插。

得到策略以后,直接计算即可,注意实现细节!!!

然后不知道为什么,我赛时读入 \(n\) 的时候取了个模!!!!!然后当 \(n>=998244353\) 会寄,然后挂了 45 分!!!!!你肉眼检查了个寂寞!!!!不要只检查代码的核心部分啊!!!!!!!!!!!!!!!

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod=998244353;
inline int add(int a,int b){return a+b>=mod?a+b-mod:a+b;}
inline int mul(int a,int b){return 1ll*a*b%mod;}
inline int qpow(int a,int b){
	int c=1;
	for(;b;b>>=1,a=mul(a,a))if(b&1)c=mul(c,a);
	return c;
}
int m,l,r,x;
char s[10000005];
int main(){
	scanf("%d%d%d%d",&m,&l,&r,&x);
	scanf("%s",s);
	long long n=1ll*m*x;
	for(int i=0;i<m;i++)n+=s[i]=='1';
	int res=1ll*n*(n+1)/2%mod;
	int ans=0;if(l==0)ans=1ll*n*(n+1)/2%mod;
	ll n1=0,n2=n;int p=x+1;
	for(int i=1,pw=233;i<=r;i++,pw=233ll*pw%mod){
		n1+=x+(s[i-1]=='1'),n2-=x+(s[i-1]=='1');
        int val2=0,m1=n1%mod,m2=n2%mod;
		int val1=(1ll*m1*(m1+1)>>1)%mod;
		p=min(p,x+(s[i-1]=='1'));
		if(i<l)continue;
		val1=add(val1,mod-mul(i-1,m1+1));
		val1=add(val1,(1ll*i*(i-1)>>1)%mod);
		val2=1ll*m2*(m2+1)/2%mod;
		if(n2<=i-1<<1){
			val2=add(val2,mul(m2,add(m1,mod-(i-1))));
			int x=n2>>1,y=n2-x;
			val2=add(val2,mod-(1ll*x*(x+1)>>1)%mod);
			val2=add(val2,mod-(1ll*y*(y+1)>>1)%mod);
		}else{
			val2=add(val2,add(mod-1ll*i*(i-1)%mod,mul(i-1<<1,add(m1,mod-(i-1)))));
			ll nn2=n2-(i-1<<1);
			val2=add(val2,mul(nn2%mod,add(m1,mod-(i-1<<1))));
			ll x=nn2/(p+1),y=nn2-x*(p+1);
			val2=add(val2,mod-mul((1ll*x*(x+1)>>1)%mod,p+1-y));
			val2=add(val2,mod-mul((1ll*(x+1)*(x+2)>>1)%mod,y));
		}
		int val=mul(add(val1,val2),pw);
		if(i>=l)ans^=val;
	}
	cout<<ans<<endl;
	return 0;
}

T2

好像是孙心去年讲过的题,但是我当年就没听懂,现在更是忘完了!

倒着想,问题变成了我们从 \((a,b,c)\) 出发,求不走出边界的长度为 \(d\) 的路径的方案数。三个维度互相不影响,我们可以对三个维度分别算出走 \(i\) 步的路径的方案 \(X_i,Y_i,Z_i\),然后把它们的 egf 卷起来,就得到了最后的答案。

那么 \(X,Y,Z\) 怎么算捏?

如果不考虑边界的限制,那么 \(X_i\) 就是一段组合数的和,这是可以 \(O(n)\) 直接算的。如果只有下边界的限制,我们考虑穿过下边界的路径,我们在它第一次到达 \(0\) 的时候将后面的路径沿 \(y=0\) 对称,这样就可以让穿过下边界的路径和终点为 \([-n,-1]\) 的路径一一对应。只考虑上边界也同理,但是有的路径会被重复计算,所以要容斥。容斥的结果是终点在 \([k(n+1)+1,k(n+1)+n]\) 的路径的容斥系数是 \((-1)^k\)。然后我们求 \(X_i\),就是对于容斥出来的每一段,算出一行组合数的一段区间和,在 \(i\) 变化的时候每段只会变化 \(O(1)\)。然后就可以直接维护,复杂度 \(O(d \frac n d)\)。再搭配上 \(O(dn)\) 的暴力就可以做到 \(O(d \sqrt d)\)

除此以外,我们还可以分治 ntt。将路径走到 \(0,n-1\) 视为关键点,我们钦定一些时刻必须在关键点来容斥,然后可以分治 ntt,这样做思维难度要小一些,但是代码更难写且跑的更慢。复杂度 \(O(n \log^2 n)\)

#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
inline int add(int a,int b){return a+b>=mod?a+b-mod:a+b;}
inline int mul(int a,int b){return 1ll*a*b%mod;}
inline int qpow(int a,int b){
    int c=1;
    for(;b;b>>=1,a=mul(a,a))if(b&1)c=mul(c,a);
    return c;
}
const int N=1<<20|5,inf=1e9;
int D;
void ntt(int *a,int n,int op){
    static int pw[N],rev[N];
    for(int i=0;i<n;i++)rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
    for(int i=0;i<n;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
    for(int i=1;i<n;i<<=1){
        int g=qpow(3,mod/i/2);
        pw[0]=1;
        for(int j=1;j<i;j++)pw[j]=mul(pw[j-1],g);
        for(int j=0;j<n;j+=i<<1)
            for(int k=0;k<i;k++){
                int x=a[j+k],y=mul(a[i+j+k],pw[k]);
                a[j+k]=add(x,y),a[i+j+k]=add(x,mod-y);
            }
    }
    if(op==-1){
        int inv=qpow(n,mod-2);
        for(int i=0;i<n;i++)a[i]=mul(a[i],inv);
        reverse(a+1,a+n);
    }
}
int f[N],f1[N];
int A[N],B[N],C[N];
void init1(int *A,int n,int a){
    for(int i=0;i<=n+1;i++)f[i]=0,f1[i]=0;
    f[a]=1,A[0]=1;
    for(int i=1;i<=D;i++){
        for(int j=1;j<=n;j++)f1[j]=add(f[j-1],f[j+1]);
        long long s=0;
        for(int j=1;j<=n;j++)f[j]=f1[j],s+=f[j],f1[j]=0;
        A[i]=s%mod;
    }
}
int fac[N],ifac[N];
inline int getC(int a,int b){
    if(a<b||b<0)return 0;
    return mul(fac[a],mul(ifac[b],ifac[a-b]));
}
int l0[N],r0[N],l[N],r[N],val[N];
void init(int *A,int n,int a){
    if(n<=1300){
        init1(A,n,a);
        return;
    }
    for(int i=0;i<=205;i++)l0[i]=inf,r0[i]=-inf,l[i]=inf;
    for(int i=a-D;i<=D+a;i++){
        if(i%(n+1)==0)continue;
        int bl;
        if(i>0)bl=(i-1)/(n+1);
        else bl=-i/(n+1)+101;
        l0[bl]=min(l0[bl],i);
        r0[bl]=max(r0[bl],i);
    }
    for(int i=0;i<=D;i++){
        for(int j=0;j<=205;j++){
            if(l0[j]==inf)continue;
            int l1=max(0,l0[j]-a+i+1>>1),r1=min(i,r0[j]-a+i>>1);
            if(l1>r1)continue;
            if(l[j]==inf){
                l[j]=l1,r[j]=r1,val[j]=0;
                for(int k=l1;k<=r1;k++)val[j]=add(val[j],getC(i,k));
            }else{
                val[j]=add(val[j],val[j]),val[j]=add(val[j],add(getC(i-1,l[j]-1),mod-getC(i-1,r[j])));
                while(l[j]<l1)val[j]=add(val[j],mod-getC(i,l[j])),l[j]++;
                while(r[j]<r1)r[j]++,val[j]=add(val[j],getC(i,r[j]));
            }
            if(j&1)A[i]=add(A[i],mod-val[j]);
            else A[i]=add(A[i],val[j]);
        }
    }
}
int main(){
    freopen("walk.in","r",stdin);
    freopen("walk.out","w",stdout);
    int n,m,k,a,b,c;
    cin>>D>>n>>m>>k>>a>>b>>c;
    fac[0]=1;
    for(int i=1;i<=D;i++)fac[i]=mul(fac[i-1],i);
    ifac[D]=qpow(fac[D],mod-2);
    for(int i=D;i>=1;i--)ifac[i-1]=mul(ifac[i],i);
    init(A,n,a);
    init(B,m,b);
    init(C,k,c);
    for(int i=0;i<=D;i++)A[i]=mul(A[i],ifac[i]);
    for(int i=0;i<=D;i++)B[i]=mul(B[i],ifac[i]);
    for(int i=0;i<=D;i++)C[i]=mul(C[i],ifac[i]);
    int o=1<<19;
    ntt(A,o,1);ntt(B,o,1);ntt(C,o,1);
    for(int i=0;i<o;i++)A[i]=mul(mul(A[i],B[i]),C[i]);
    ntt(A,o,-1);
    cout<<mul(A[D],fac[D])<<endl;
    return 0;
}

T3

比较麻烦的大讨论 dp 乱搞题。

首先考虑一个比较直观的 dp:\(f_{0/1,i,j}\) 表示 \(t_i\) 时刻,人/分身在 \(x_i\) 另一个在 \(j\)。状态数是 \(O(nx)\) 的,然后讨论一下转移,每一个位置转移到的是一段区间,于是可以差分来转移,于是我们就得到了一个 \(O(nx)\) 的 dp。结果发现所有部分分都没有限制 x 的范围

然后考虑优化,讨论一下所有的转移,发现有效的转移不是很多,并且操作大多是比较简单的区间赋值,感觉可以优化。

所以我们来详细地讨论一下所有的转移:

\(T=t_i-t_{i-1},D=x_i-x_{i-1}\)

  1. 先考虑从 \(f_{0,i-1}\) 而来的转移,如果 \(D\le T\),那么我们可以从 \(x_{i-1}\) 直接走到 \(x_i\),所以 \(f_{0,i-1,j}\) 可以转移到 \(f_{0,i,j}\)。除此之外,在从 \(x_{i-1}\) 走到 \(x_i\) 的过程中,我们可以放下分身,具体的,我们可以在区间 \(\big[\min(x_i,x_{i-1})-\lfloor \frac {T-D} 2\rfloor,\max(x_i,x_{i-1})+\lfloor \frac {T-D} 2\rfloor\big]\) 放下分身,于是我们要把这一段区间的 \(f_{0,i+1}\) 赋值为 \(1\)。还有一种走的方法是我们先走到 \(x_i\),并在 \(x_i\) 放下分身,然后我们继续走。这样我们转移到的是 \(\big [x_i-(T-D),x_i+(T-D) \big ]\),把 \(f_{1,i}\) 的这个区间赋值为 \(1\)

  2. 再来考虑从 \(f_{1,i-1}\) 而来的转移,还是和上述一样的,大体上有两种走法:一种是从 \(j\) 走到 \(x_i\),并在路上放下分身,但是 \(j\) 很多,我们不能一一转移,但是可以发现,所有 \(j\) 转移到的区间都包含了 \(x_i\),即都是相交的,于是这些区间的并集就是最左和最右的区间的并,最左的区间就是不小于 \(x_i-T\) 的最小的 \(j\),最右的是不大于 \(x_i+T\) 的最大的 \(j\),把他们抓出来对 \(f_{0,i}\) 区间赋值即可。还有一种是从 \(j\) 走到 \(x_i\) 放下分身,然后继续走,最后可以走到的区间是 \(\big [x_i-(T-D),x_i+(T-D)\big ]\),我们要区间尽量大,就要让 \(D\) 尽量小,所以要查找离 \(x_i\) 最近的 \(j\),然后对 \(f_{1,i}\) 区间赋值。最后还有一种情况,如果 \(x_{i-1}=x_i\) 那么分身可以不动,人继续走,那么每个 \(j\) 可以转移到 \([j-T,j+T]\) 这段区间。

列完了所有转移,发现除了最后的操作,其它都是简单区间赋值,而最后一种操作其实相当于将每一个 true 连续段的左右端点延长 \(T\)。于是我们考虑使用 map 维护连续段,然后对连续段的扩展打上 tag,然后再用一个 set 记录相邻两个连续段之间的距离,以便在打 tag 的时候合并连续段。但是因为这个题的至于只有 \(1e9\),所以每次我们暴力地将所有的连续段延长,就 AC 了!好像卡不掉!

#include<bits/stdc++.h>
using namespace std;
const int BS=1<<20|5;
char buf[BS],*P1,*P2;
inline char gc(){
    if(P1==P2)P2=(P1=buf)+fread(buf,1,BS,stdin);
    return P1==P2?EOF:*(P1++);
}
inline int in(){
    int x=0,f=1;char ch=gc();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=gc();}
    while(ch>='0'&&ch<='9')x=x*10+ch-'0',ch=gc();
    return x*f;
}
const int N=1e6+5,inf=1e9;
int n,pt[N],px[N];
map<int,int> f0,f1,f2;
void insert(map<int,int> &f,int l,int r){
    l=max(l,-inf),r=min(r,inf);
    if(!f.count(r))f[r]=l;
    else if(f[r]<=l)return;
    auto it=f.find(r),it1=it;
    bool flag=0;
    while(it!=f.begin()&&(it1=prev(it))->first>=l){
        l=min(l,it1->second),f.erase(it1);
    }
    it1=next(it);
    if(it1!=f.end()&&it1->second<=r)it1->second=min(it1->second,l),f.erase(it);
    else it->second=l;
}
int findl(map<int,int> &f,int p){
    auto it=f.lower_bound(p);
    if(it!=f.end()&&it->second<=p)return p;
    if(it!=f.begin())return prev(it)->first;
    return inf<<1|1;
}
int findr(map<int,int> &f,int p){
    auto it=f.lower_bound(p);
    if(it!=f.end())return max(p,it->second);
    return inf<<1|1;
}
int main(){
    n=in();
    for(int i=1;i<=n;i++)pt[i]=in(),px[i]=in();
    f0[0]=0,f1[0]=0;
    for(int i=1;i<=n;i++){
        int d=abs(px[i]-px[i-1]),t=pt[i]-pt[i-1];
        int l1=-inf,r1=-inf;
        if(findl(f0,px[i])==px[i])insert(f2,px[i-1]-t,px[i-1]+t);
        if(t<d)f0.clear();
        else if(f0.size()){
            insert(f2,px[i]-(t-d),px[i]+(t-d));
            int l=min(px[i],px[i-1]),r=max(px[i],px[i-1]);
            insert(f0,l-(t-d>>1),r+(t-d>>1));
        }
        if(px[i]==px[i-1]){
            for(auto p:f1){
                int l=max(-inf,p.second-t),r=min(inf,p.first+t);
                insert(f2,l,r);
            }
        }
        int lp=findr(f1,px[i]-t),rp=findl(f1,px[i]+t);
        bool flag=0;
        if(lp>=px[i]-t&&lp<=px[i]+t){
            int d=abs(px[i]-lp),l=min(px[i],lp),r=max(px[i],lp);
            insert(f0,l-(t-d>>1),r+(t-d>>1));
            flag=1;
        }
        if(rp>=px[i]-t&&rp<=px[i]+t){
            int d=abs(px[i]-rp),l=min(px[i],rp),r=max(px[i],rp);
            insert(f0,l-(t-d>>1),r+(t-d>>1));
            flag=1;
        }
        if(flag)insert(f0,px[i-1],px[i-1]);
        lp=findl(f1,px[i]),rp=findr(f1,px[i]);
        d=min(abs(px[i]-lp),abs(px[i]-rp));
        if(d<=t)insert(f2,px[i]-(t-d),px[i]+(t-d));
        swap(f1,f2),f2.clear();
        if(!f1.size()&&!f0.size()){puts("NO");return 0;}
    }
    puts("YES");
    cerr<<clock()<<endl;
    return 0;
}
posted @ 2023-03-13 16:47  蒟蒻_william555  阅读(131)  评论(0)    收藏  举报