POJ3417 LCA+树dp
http://poj.org/problem?id=3417
题意:先给出一棵无根树,然后下面再给出m条边,把这m条边连上,然后每次你能毁掉两条边,规定一条是树边,一条是新边,问有多少种方案能使树断裂。
我们考虑加上每一条新边的情况,当一条新边加上之后,原本的树就会成环,环上除了所有的树边要断的话必然要砍掉这条新边才可行。
每一条新边成的环就是u - lca(u,v) - v,对每一条边的覆盖次数++
考虑所有的树边,被覆盖 == 0的时候,意味着单独砍掉这条树边即可,其他随便选一个新边就是一种方案,贡献值 += M;
被覆盖 == 1的时候,意味着砍掉这条树边必须砍掉另一条与他匹配的新边,贡献值 ++
被覆盖 >= 2的时候,这条树边被砍掉是没有意义的,因为不可能同时砍掉两条以上的新边
下面的问题就变成了如何求每一条边的被覆盖次数,我们只要对dp[lca] -= 2,dp[u]++,dp[v]++从根节点向下推,到叶子节点之后回溯,更新dp值即可
这就变成了一个喜闻乐见的树dp、
#include <map> #include <set> #include <ctime> #include <cmath> #include <queue> #include <stack> #include <vector> #include <string> #include <cstdio> #include <cstdlib> #include <cstring> #include <sstream> #include <iostream> #include <algorithm> #include <functional> using namespace std; #define For(i, x, y) for(int i=x;i<=y;i++) #define _For(i, x, y) for(int i=x;i>=y;i--) #define Mem(f, x) memset(f,x,sizeof(f)) #define Sca(x) scanf("%d", &x) #define Sca2(x,y) scanf("%d%d",&x,&y) #define Scl(x) scanf("%lld",&x); #define Pri(x) printf("%d\n", x) #define Prl(x) printf("%lld\n",x); #define CLR(u) for(int i=0;i<=N;i++)u[i].clear(); #define LL long long #define ULL unsigned long long #define mp make_pair #define PII pair<int,int> #define PIL pair<int,long long> #define PLL pair<long long,long long> #define pb push_back #define fi first #define se second typedef vector<int> VI; const double eps = 1e-9; const int maxn = 1e5 + 10; const int INF = 0x3f3f3f3f; const int mod = 1e9 + 7; int N,M,tmp,K; int head[maxn],tot,cnt; bool vis[maxn]; int F[maxn * 2],P[maxn],rmq[maxn * 2]; struct Edge{ int to,next; }edge[maxn * 2]; int dp[maxn]; LL sum; struct ST{ int dp[maxn * 2][20]; int mm[maxn * 2]; void init(int n){ mm[0] = -1; for(int i = 1; i <= n ; i ++){ mm[i] = ((i & (i - 1)) == 0)?mm[i - 1] + 1:mm[i - 1]; dp[i][0] = i; } for(int j = 1; j <= mm[n]; j ++){ for(int i = 1; i + (1 << j) - 1 <= n ; i ++){ dp[i][j] = rmq[dp[i][j - 1]] < rmq[dp[i + (1 << (j - 1))][j - 1]]?dp[i][j - 1]:dp[i + (1 << (j - 1))][j - 1]; } } } int query(int a,int b){ if(a > b) swap(a,b); int k = mm[b - a + 1]; return rmq[dp[a][k]] <= rmq[dp[b - (1 << k) + 1][k]]?dp[a][k]:dp[b - (1 << k) + 1][k]; } }st; void init(){ Mem(head,-1); tot = 0; } void add(int u,int v){ edge[tot].next = head[u]; edge[tot].to = v; head[u] = tot++; } void dfs(int u,int pre,int dep){ F[++cnt] = u; rmq[cnt] = dep; P[u] = cnt; for(int i = head[u]; ~i; i = edge[i].next){ int v = edge[i].to; if(v == pre ) continue; dfs(v,u,dep + 1); F[++cnt] = u; rmq[cnt] = dep; } } void LCA_init(int root){ cnt = 0; dfs(root,root,0); st.init(2 * N - 1); } int lca(int u,int v){ return F[st.query(P[u],P[v])]; } int dfs2(int x,int last){ for(int i = head[x]; ~i ; i = edge[i].next){ int to = edge[i].to; if(to == last) continue; dfs2(to,x); dp[x] += dp[to]; if(dp[to] == 1) sum++; else if(!dp[to]) sum += M; } return dp[x]; } int main() { Sca2(N,M); init(); For(i,1,N - 1){ int u,v; Sca2(u,v); add(u,v); add(v,u); } LCA_init(1); For(i,1,M){ int u,v; Sca2(u,v); dp[u]++; dp[v]++; dp[lca(u,v)] -= 2; } dfs2(1,-1); Prl(sum); #ifdef VSCode system("pause"); #endif return 0; }