ABC329G Delivery on Tree
statement
有一个 \(n\) 个节点的二叉树和 \(m\) 个球,球初始在 \(s_i\),目标是 \(t_i\)。
你现在要以 \(1\) 为根遍历这棵树,每到一个点(包括到达和回溯),可以选择拿起该点的若干个球(如果有的话),或者选择你手中的若干个球并将其放下。
你需要保证每时每刻你手中的球数不超过 \(k\),求遍历方案数,对 \(998244353\) 取模。
\(n\le 10^4,m\le 2\times10^5,k\le 10^3\)。
solution
对于球 \(i\)(暂且先考虑 \(s_i,t_i\) 不互为祖孙关系的点),我们在离开 \(s_i\) 的时候带上这个球,在到达 \(t_i\) 时放下这个球,设 \(lca\) 为 \(s_i,t_i\) 的最近公共祖先,那么球经过的路线可以拆分为 \(s_i\rightarrow lca,lca\rightarrow t_i\)。对于每一条向上经过的边 \((u,v)\)(钦定 \(dep_u>dep_v\)),我们肯定会带走 \(u\) 子树内某些球。对于每一条向下经过的边,我们肯定会带走 \(u\) 子树外的某些球。对于每条边,分别求出向上、向下的贡献 \(a_i,b_i\),那么到该点时,手上的球数的最小值就是 \(\max(a_i,b_i)\),令 \(c_i=\max(a_i,b_i)\)。如果存在 \(c_i>k\),那么无解,输出 0 即可。
接下来考虑 dp。设 \(f_{x,i}\) 表示 \(x\) 的子树内,手上的球的最大数量为 \(i\) 的方案数。
转移考虑讨论一下该点的儿子数量:
- 若为 \(0\)(即是叶子节点),\(f_{x,c_x}=1\)。
- 若为 \(1\),\(f_{x,\max(i,c_x)}=\sum f_{u,i}\)。
- 若为 \(2\):
- 考虑先走哪一边的子树。设两个儿子分别为 \(u,v\),再设先走 \(u\) 给 \(v\) 子树的贡献为 \(d1\),先走 \(v\) 子树的贡献为 \(d2\),\(f_{x,\max(i+d2,j+d1,\textit{\textbf{c}})}=\sum_{i,j}f_{u,i}\times f_{v,j}\),这里 \(c\) 就是 \(c_x\)。
不处理 \(c_i\) 会 WA 掉 9 个点
- 考虑先走哪一边的子树。设两个儿子分别为 \(u,v\),再设先走 \(u\) 给 \(v\) 子树的贡献为 \(d1\),先走 \(v\) 子树的贡献为 \(d2\),\(f_{x,\max(i+d2,j+d1,\textit{\textbf{c}})}=\sum_{i,j}f_{u,i}\times f_{v,j}\),这里 \(c\) 就是 \(c_x\)。
补充:对于每个 \(s_i,t_i\) 不互为祖孙关系的球,设 \(s_i\) 在 \(lca\) 的儿子 \(u\) 的子树内,\(t_i\) 在 \(lca\) 的儿子 \(v\) 的子树内,则将 \(s_i\rightarrow u\) 的链的 \(d1\) 加 1,将 \(t_i\rightarrow v\) 的链的 \(d2\) 加 1。
时间复杂度 \(O(nk^2)\),瓶颈在于二叉节点的转移,过不掉。
优化:令 \(i+d2\ge j+d1\),枚举 \(i\),则能转移到 \(f_x\) 上的 \(j\) 一定是一段前缀,前缀和优化掉即可。另一种情况反过来转移就行。注意另一种情况不要取等号,要不然会算重。
时间复杂度 \(O(nk)\)。
附了一些注释和优化前的 dp 式子。
点击查看代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<cmath>
#include<map>
#include<unordered_map>
#include<vector>
#include<queue>
#include<bitset>
#include<set>
#include<ctime>
#include<random>
#define x1 xx1
#define y1 yy1
#define IOS ios::sync_with_stdio(false)
#define ITIE cin.tie(0);
#define OTIE cout.tie(0);
#define PY puts("Yes")
#define PN puts("No")
#define PW puts("-1")
#define P__ puts("")
#define PU puts("--------------------")
#define popc __builtin_popcount
#define pii pair<int,int>
#define mp make_pair
#define fi first
#define se second
#define gc getchar
#define pc putchar
#define pb emplace_back
#define rep(a,b,c) for(int a=b;a<=c;a++)
#define per(a,b,c) for(int a=b;a>=c;a--)
#define reprange(a,b,c,d) for(int a=b;a<=c;a+=d)
#define perrange(a,b,c,d) for(int a=b;a>=c;a-=d)
#define graph(i,j,k,l) for(int i=k[j];i;i=l[i].nxt)
#define lowbit(x) (x&-x)
#define lson(x) (x<<1)
#define rson(x) (x<<1|1)
#define mem(x,y) memset(x,y,sizeof x)
//#define double long double
#define int long long
//#define int __int128
using namespace std;
bool greating(int x,int y){return x>y;}
bool greatingll(long long x,long long y){return x>y;}
bool smallingll(long long x,long long y){return x<y;}
inline int rd(){
int x=0,f=1;int ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=x*10+ch-48;ch=getchar();}return x*f;
}
inline void write(int x,char ch='\0'){
if(x<0){x=-x;putchar('-');}
int y=0;char z[40];
while(x||!y){z[y++]=x%10+48;x/=10;}
while(y--)putchar(z[y]);if(ch!='\0')putchar(ch);
}
bool Mbg;
const int maxn=1e4+5,maxm=2e5+5,maxk=2e3+5,inf=0x3f3f3f3f,mod=998244353;
const long long llinf=0x3f3f3f3f3f3f3f3f;
int n,m,k;
vector<int>G[maxn];
int st[maxm],ed[maxm];
namespace LowestCommonAncestor{
int lg[maxn],f[maxn][20],dep[maxn];
void dfs(int x,int y){
dep[x]=dep[y]+1,f[x][0]=y;
rep(i,1,lg[dep[x]])f[x][i]=f[f[x][i-1]][i-1];
for(int u:G[x])dfs(u,x);
}
int LCA(int x,int y){
if(dep[x]<dep[y])swap(x,y);
while(dep[x]!=dep[y])x=f[x][lg[dep[x]-dep[y]]-1];
if(x==y)return x;
per(i,15,0)if(f[x][i]^f[y][i])x=f[x][i],y=f[y][i];
return f[x][0];
}
void init(){
rep(i,1,n){
lg[i]=lg[i-1];
if((1<<lg[i])==i)lg[i]++;
}
dfs(1,0);
}
}
//树上差分求的贡献
int a[maxn],b[maxn],c[maxn],d[maxn];
void dfs_1(int x,int y){
for(int u:G[x])dfs_1(u,x),a[x]+=a[u],b[x]+=b[u],c[x]+=c[u],d[x]+=d[u];
}
int f[maxn][maxk];
int g[2][maxk];
int vis[maxn];//限定顺序
/*
0:无限制 -1:先左后右 1:先右后左
*/
void add(int &x,int y){x+=y;if(x>mod)x-=mod;}
int get(int x,int y){return y<0?0:g[x][y];}
void dp(int x){
if(G[x].size()==0){
f[x][a[x]]=1;
}else if(G[x].size()==1){
int u=G[x][0];dp(u);
rep(i,0,k)add(f[x][max(i,a[x])],f[u][i]);
}else{
int u=G[x][0],v=G[x][1];dp(u),dp(v);
g[0][0]=f[u][0];
rep(i,1,k)g[0][i]=(g[0][i-1]+f[u][i])%mod;
g[1][0]=f[v][0];
rep(i,1,k)g[1][i]=(g[1][i-1]+f[v][i])%mod;
if(vis[x]!=1){
// rep(i,0,k)rep(j,0,k){
// add(f[x][max(i+d[v],j+c[u])],f[u][i]*f[v][j]%mod);
// }
rep(i,0,k){
add(f[x][max(a[x],i+d[v])],f[u][i]*get(1,i+d[v]-c[u])%mod);
}
rep(j,0,k){
add(f[x][max(a[x],j+c[u])],f[v][j]*get(0,j+c[u]-d[v]-1)%mod);
}
}
if(vis[x]!=-1){
// rep(i,0,k)rep(j,0,k){
// add(f[x][max(i+d[u],j+c[v])],f[v][i]*f[u][j]%mod);
// }
rep(i,0,k){
add(f[x][max(i+d[u],a[x])],f[v][i]*get(0,i+d[u]-c[v])%mod);
}
rep(j,0,k){
add(f[x][max(j+c[v],a[x])],f[u][j]*get(1,j+c[v]-d[u]-1)%mod);
}
}
}
}
void solve_the_problem(){
n=rd(),m=rd(),k=rd();
rep(i,2,n){int x=rd();G[x].pb(i);}
rep(i,1,m)st[i]=rd(),ed[i]=rd();
LowestCommonAncestor::init();
rep(i,1,m){
int lca=LowestCommonAncestor::LCA(st[i],ed[i]);
a[st[i]]++,a[lca]--,b[ed[i]]++,b[lca]--;
}
rep(i,1,m){
int lca=LowestCommonAncestor::LCA(st[i],ed[i]);
int belx=LowestCommonAncestor::LCA(G[lca][0],st[i])==lca?1:0;
int bely=LowestCommonAncestor::LCA(G[lca][0],ed[i])==lca?1:0;
if(st[i]^lca)c[st[i]]++,c[G[lca][belx]]--;
if(ed[i]^lca)d[ed[i]]++,d[G[lca][bely]]--;
}
dfs_1(1,0);
rep(i,1,n){
a[i]=max(a[i],b[i]);
if(a[i]>k){
puts("0");return;
}
}
rep(i,1,m){
int lca=LowestCommonAncestor::LCA(st[i],ed[i]);
int belx=LowestCommonAncestor::LCA(G[lca][0],st[i])==lca?1:0;
int bely=LowestCommonAncestor::LCA(G[lca][0],ed[i])==lca?1:0;
if(st[i]==lca||ed[i]==lca)continue;
if(vis[lca]&&vis[lca]!=belx-bely){
puts("0");return;
}
vis[lca]=belx-bely;
}
dp(1);
int ans=0;
rep(i,0,k)ans=(ans+f[1][i])%mod;
write(ans);
}
bool Med;
signed main(){
// freopen(".in","r",stdin);freopen(".out","w",stdout);
// fprintf(stderr,"%.3lfMB\n",(&Mbg-&Med)/1048576.0);
int _=1;while(_--)solve_the_problem();
}
/*
*/

浙公网安备 33010602011771号