一、题目描述:
给定 $n$ 和 $k$,表示有 $n$ 个点,其中有 $k$ 个点是关键点,这 $k$ 个点随机分布。
给出 $n$ 个点的连接方式,保证构成一棵树,求有期望多少个点使得这个点到 $k$ 个关键点的距离之和最小,答案对 $1e9+7$ 取模。
数据范围:$1\leq n\leq 2e5,1\leq k\leq min(n,3)$。
二、解题思路:
$k=1:很明显答案是一,就是这个关键点本身,直接输出即可。$
$k=2:两个关键点之间的路径长度$+1$就是关键点数量,统计路径被经过的次数即可。$
$k=3:枚举每一个点,假设当前枚举到 $u$,考虑计算点 $u$ 是答案点的方案。$
$Situation 1:点 $ $u$ $是关键点$
那么这 $3$ 个关键点必然构成一条链,且点 $u$ 在中间,暴力计算即可。
$Situation 2:点$ $u$ $不是关键点$
那么 $3$ 个关键点必然不在一条链上,且在 $u$ 的不同子树中(假设 $u$ 是根),需要前缀和统计计算。
那么此题已经解决,时间复杂度 $O(n)$。
其实还有许多细节,但我不想再提了。
三、完整代码:
1 #include<iostream> 2 #define N 200010 3 #define lim 200000 4 #define to edge[i].v 5 #define ll long long 6 #define M 1000000007 7 using namespace std; 8 ll n,k,u1,v1,ans; 9 ll s[N],jc[N],inv[N],sum[N]; 10 struct EDGE{ 11 ll v,nxt; 12 }edge[N*2]; 13 ll head[N],cnt; 14 void add(ll u,ll v) 15 { 16 edge[++cnt].v=v; 17 edge[cnt].nxt=head[u]; 18 head[u]=cnt; 19 } 20 ll ksm(ll base,ll p) 21 { 22 ll res=1; 23 while(p) 24 { 25 if(p&1) res*=base,res%=M; 26 base*=base,base%=M,p>>=1; 27 } 28 return res; 29 } 30 void pre_work() 31 { 32 jc[0]=1; 33 for(ll i=1;i<=lim;i++) 34 jc[i]=jc[i-1]*i%M; 35 inv[lim]=ksm(jc[lim],M-2); 36 for(ll i=lim;i>=1;i--) 37 inv[i-1]=inv[i]*i%M; 38 } 39 void update(ll u,ll val) 40 { 41 ans+=val*(n-1)*(n-1-val)%M; 42 ans-=ksm(val,2)*(n-1-val)%M; 43 ans-=val*(sum[u]-ksm(val,2))%M; 44 ans=(ans%M+M)%M; 45 } 46 void dfs1(ll u,ll ff) 47 { 48 for(ll i=head[u];i!=-1;i=edge[i].nxt) 49 if(to!=ff) 50 { 51 dfs1(to,u), s[u]+=s[to]; 52 (sum[u]+=ksm(s[to],2))%=M; 53 } 54 s[u]++,sum[u]+=ksm(n-s[u],2)%M; 55 } 56 void dfs2(ll u,ll ff) 57 { 58 for(ll i=head[u];i!=-1;i=edge[i].nxt) 59 if(to!=ff) 60 (ans+=s[to]*(n-s[to])*2)%=M,dfs2(to,u); 61 } 62 void dfs3(ll u,ll ff) 63 { 64 update(u,n-s[u]),(ans+=3*(n-s[u])*(s[u]-1))%=M; 65 for(ll i=head[u];i!=-1;i=edge[i].nxt) 66 if(to!=ff) 67 { 68 update(u,s[to]);dfs3(to,u); 69 (ans+=3*s[to]*(n-s[to]-1))%=M; 70 } 71 } 72 void baoli() 73 { 74 dfs1(1,0);dfs2(1,0); 75 cout<<(ans+n*n-n)%M*ksm((n*n-n)%M,M-2)%M<<'\n'; 76 } 77 void rwork() 78 { 79 dfs1(1,0);dfs3(1,0); 80 cout<<ans*ksm((n*n-n)%M*(n-2)%M,M-2)%M<<'\n'; 81 } 82 int main() 83 { 84 cin>>n>>k; 85 pre_work(); 86 for(ll i=1;i<=n;i++) 87 head[i]=-1; 88 for(ll i=1;i<n;i++) 89 { 90 cin>>u1>>v1; 91 add(u1,v1); 92 add(v1,u1); 93 } 94 if(k==1) 95 { 96 cout<<1<<'\n'; 97 return 0; 98 } 99 if(k==2) baoli(); 100 if(k==3) rwork(); 101 return 0; 102 }
2023.5.19 upt:
1 #include<iostream> 2 #define N 200010 3 #define to edge[i].v 4 #define ll long long 5 #define M 1000000007 6 using namespace std; 7 ll n,k,u1,v1,ans,s[N]; 8 struct EDGE{ 9 ll v,nxt; 10 }edge[N*2]; 11 ll head[N],cnt; 12 void add(ll u,ll v) 13 { 14 edge[++cnt].v=v; 15 edge[cnt].nxt=head[u]; 16 head[u]=cnt; 17 } 18 ll ksm(ll base,ll p) 19 { 20 ll res=1; 21 while(p) 22 { 23 if(p&1) res*=base,res%=M; 24 base*=base,base%=M,p>>=1; 25 } 26 return res; 27 } 28 void dfs(ll u,ll ff) 29 { 30 s[u]=1; 31 for(ll i=head[u];i!=-1;i=edge[i].nxt) 32 if(to!=ff) 33 { 34 dfs(to,u),s[u]+=s[to]; 35 (ans+=s[to]*(n-s[to])*2)%=M; 36 } 37 } 38 int main() 39 { 40 ios::sync_with_stdio(false); 41 cin.tie(0);cout.tie(0); 42 cin>>n>>k; 43 for(ll i=1;i<=n;i++) 44 head[i]=-1; 45 for(ll i=1;i<n;i++) 46 { 47 cin>>u1>>v1; 48 add(u1,v1); 49 add(v1,u1); 50 } 51 if(k&1) 52 { 53 cout<<1<<'\n'; 54 return 0; 55 } 56 dfs(1,0);ll t=n*n-n; 57 cout<<(ans+t)%M*ksm(t%M,M-2)%M<<'\n'; 58 return 0; 59 }
四、写题心得:
这个题是自己想出来的,很不错。可是很烦的一点就是比赛之后 $C$ 题居然 $Main$ $Test$ $Run$ $Time$ $Error$ 了,导致从 $1400$ 名掉到 $3500$ 名,还掉分了!
而且这个 $D1$ 我的一个同学写的比较简单,比我的代码短多了,还得继续加油啊!拜拜!
2023.5.19 upt:
我说那个同学代码怎么那么短。
哈哈哈,笑死我了,我居然没看出来 $k=3$ 时答案必然为 $1$。我真的哭死。
比赛时白白浪费一堆时间想 $k=3$ 的情况,还没写出来,真的笑死我了$qwq$!
浙公网安备 33010602011771号