[Tree DP] Distance in Tree
A tree is a connected graph that doesn't contain any cycles.
The distance between two vertices of a tree is the length (in edges) of the shortest path between these vertices.
You are given a tree with n vertices and a positive number k. Find the number of distinct pairs of the vertices which have a distance of exactly k between them. Note that pairs (v, u) and (u, v) are considered to be the same pair.
The first line contains two integers n and k (1 ≤ n ≤ 50000, 1 ≤ k ≤ 500) — the number of vertices and the required distance between the vertices.
Next n - 1 lines describe the edges as "a_i b_i" (without the quotes) (1 ≤ a_i, b_i ≤ n, a_i ≠ b_i), where a_i and b_i are the vertices connected by the i-th edge. All given edges are different.
Print a single integer — the number of distinct pairs of the tree's vertices which have a distance of exactly k between them.
Please do not use the %lld specifier to read or write 64-bit integers in C++. It is preferred to use the cin, cout streams or the %I64d specifier.
Sample input 1:
5 2
1 2
2 3
3 4
2 5
Sample output 1:
4
Sample input 2:
5 3
1 2
2 3
3 4
4 5
Sample output 2:
2
Note: in the first sample the pairs of vertices at distance 2 from each other are (1, 3), (1, 5), (3, 5) and (2, 4).
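Problem summary: given a tree with N vertices, count the unordered pairs of distinct vertices (u, v) whose shortest-path distance is exactly K.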
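Analysis: root the tree at vertex 1 and let dp[i][j] be the number of vertices in the subtree of i that are exactly j steps below i.
Initialization: dp[i][0]=1; // with zero steps, i reaches only itself
Transition: dp[i][j]=sum(dp[i->son][j-1]), summed over every child i->son of i.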
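The answer is accumulated at each vertex i in two parts. First, pairs in which i itself is one endpoint and the other endpoint is K steps below it: dp[i][K].
Second, pairs whose lowest common ancestor is i and whose endpoints both lie strictly below i: for each child i->son and each t from 1 to K-1, add dp[i->son][t-1]*(dp[i][K-t]-dp[i->son][K-t-1]). The first factor counts endpoints t steps below i through this child; the second counts endpoints K-t steps below i, minus those that also go through this child (a path between two vertices in the same child's subtree does not pass through i).
Each such pair is counted once from each endpoint's side, and (u, v) and (v, u) are the same pair, so this sum tmp is added to ans as tmp/2.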
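On the first sample (K = 2, tree rooted at vertex 1) the dp rows, listed for j = 0, 1, 2, are dp[1] = (1, 1, 2), dp[2] = (1, 2, 1), dp[3] = (1, 1, 0) and dp[4] = dp[5] = (1, 0, 0). Plugging them into the two counting rules reproduces the expected answer of 4:

```latex
\begin{aligned}
\text{vertices } 4, 5:&\quad dp[\cdot][2] = 0,\ \text{no children} \Rightarrow 0\\
\text{vertex } 3:&\quad dp[3][2] = 0,\quad tmp = dp[4][0]\,(dp[3][1]-dp[4][0]) = 1\cdot(1-1) = 0 \Rightarrow 0\\
\text{vertex } 2:&\quad dp[2][2] = 1 \ \text{(pair } (2,4)\text{)},\\
&\quad tmp = dp[3][0]\,(dp[2][1]-dp[3][0]) + dp[5][0]\,(dp[2][1]-dp[5][0]) = 1\cdot 1 + 1\cdot 1 = 2 \Rightarrow tmp/2 = 1 \ \text{(pair } (3,5)\text{)}\\
\text{vertex } 1:&\quad dp[1][2] = 2 \ \text{(pairs } (1,3), (1,5)\text{)},\quad tmp = dp[2][0]\,(dp[1][1]-dp[2][0]) = 1\cdot(1-1) = 0\\
ans &= 0 + 0 + (1 + 1) + (2 + 0) = 4
\end{aligned}
```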
#include<iostream>
#include<cstdio>
#include<cctype>
#include<vector>
using namespace std;
// Fast getchar-based integer reader; handles a leading minus sign.
inline int read(){
    int x=0,f=1;char c=getchar();
    for(;!isdigit(c);c=getchar()) if(c=='-') f=-1;
    for(;isdigit(c);c=getchar()) x=x*10+c-'0';
    return x*f;
}
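int N,K;
// dp[x][j]: number of vertices in x's subtree exactly j edges below x.
// 50001 x 501 long longs take roughly 200 MB.
long long dp[50001][501];
vector<int> vec[50001];                 // adjacency lists
long long ans;                          // final count of pairs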
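// Fills dp[x][*] for the subtree rooted at x and adds to ans every pair
// whose lowest common ancestor is x.
void dfs(int x,int fa){
    dp[x][0]=1;                         // x itself, 0 steps away
    for(int i=0;i<(int)vec[x].size();i++){
        if(vec[x][i]==fa) continue;
        dfs(vec[x][i],x);               // finish every child first
    }
    // transition: dp[x][j] = sum of dp[child][j-1]
    for(int i=0;i<(int)vec[x].size();i++){
        if(vec[x][i]==fa) continue;
        for(int j=1;j<=K;j++) dp[x][j]+=dp[vec[x][i]][j-1];
    }
    ans+=dp[x][K];                      // pairs (x, v) with v exactly K below x
    // pairs with LCA x: one endpoint j below x through child s,
    // the other K-j below x through a different child
    long long tmp=0;
    for(int i=0;i<(int)vec[x].size();i++){
        int s=vec[x][i];
        if(s==fa) continue;
        for(int j=1;j<K;j++) tmp+=dp[s][j-1]*(dp[x][K-j]-dp[s][K-j-1]);
    }
    ans+=tmp/2;                         // each such pair was counted twice
}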
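int main(){
    N=read(),K=read();
    for(int i=1;i<N;i++){               // read the N-1 edges
        int u=read(),v=read();
        vec[u].push_back(v);
        vec[v].push_back(u);
    }
    dfs(1,-1);                          // root the tree at vertex 1
    cout<<ans<<endl;                    // ans is long long: print via cout, not %d
    return 0;
}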
