树形DP+树状数组 HDU 5877 Weak Pair

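// Tree DP + Fenwick tree (BIT) — HDU 5877 Weak Pair
// Idea: while walking down the tree, insert k/a[u] into the BIT for every ancestor u on the current path;
// at node v, ans += Sum(rank of a[v]) adds the number of inserted values that are >= a[v].
// The values must be coordinate-compressed (离散化) before they can be used as BIT indices.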
 4 
 5 #include <bits/stdc++.h>
 6 using namespace std;
 7 #define LL long long
 8 typedef pair<int,int> pii;
 9 const double inf = 123456789012345.0;
10 const LL MOD =100000000LL;
11 const int N = 2e5+10;
12 const int maxx = 200010; 
13 #define clc(a,b) memset(a,b,sizeof(a))
14 const double eps = 1e-7;
// Debug helper: re-open standard input on the file "in.txt".
void fre() {
    const char *path = "in.txt";
    freopen(path, "r", stdin);
}
// Debug helper: re-open standard output on the file "out.txt".
void freout() {
    const char *path = "out.txt";
    freopen(path, "w", stdout);
}
// Reads the next integer from stdin: skips non-digit characters (a '-' seen
// while skipping makes the result negative), then accumulates decimal digits.
// Returns 0 if end of input is reached before any digit.
inline int read() {
    int x = 0, f = 1;
    // int, not char: getchar() returns EOF (-1), which a char cannot represent
    // reliably; with `char` the skip loop below never terminated at end of input.
    int ch = getchar();
    while (ch != EOF && (ch > '9' || ch < '0')) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + ch - '0';
        ch = getchar();
    }
    return x * f;
}
18 
19 map<LL,LL> ma;
20 LL a[N];
21 LL c[N],b[N];
22 LL in[N];
23 vector<LL> g[N];
24 LL lowbit(LL x){ return x&(-x);}
25 LL add(LL x,int t){
26     while(x>0){
27        c[x]+=t;
28        x-=lowbit(x);
29     }
30 }
31 LL Sum(LL x){
32     LL sum=0;
33     while(x<maxx){
34         sum+=c[x];
35         x+=lowbit(x);
36     }
37     return sum;
38 }
39 
40 LL ans=0;
41 LL n,k;
42 void dfs(LL rt){
43      for(LL i=0;i<(int)g[rt].size();i++){
44          LL v=g[rt][i];
45          ans+=Sum(ma[a[v]]);
46          if(a[v]==0) add(maxx,1);
47          else add(ma[k/a[v]],1);
48          dfs(v);
49          if(a[v]==0) add(maxx,-1);
50          else add(ma[k/a[v]],-1);
51      }
52 }
53 int main(){
54     int T;
55     scanf("%d",&T);
56     while(T--){
57         ma.clear();
58         memset(c,0,sizeof(c));
59         scanf("%I64d%I64d",&n,&k);
60         for(int i=1;i<=n;i++){
61             scanf("%I64d",&a[i]);
62             b[i*2-2]=a[i];
63             if(a[i]!=0) b[i*2-1]=k/a[i];
64             g[i].clear();
65             in[i]=0;
66         }
67         sort(b,b+2*n);
68         int K=unique(b,b+2*n)-b;
69         int cxt=0;
70         for(int i=0;i<K;i++){
71             ma[b[i]]=++cxt;
72         }
73         for(LL i=0;i<n-1;i++){
74             LL u,v;
75             scanf("%I64d%I64d",&u,&v);
76             g[u].push_back(v);
77             in[v]++;
78         }
79         LL rt;
80         for(LL i=1;i<=n;i++){
81             if(in[i]==0){
82                 rt=i;
83                 break;
84             }
85         }
86         ans=0;
87         if(a[rt]==0) add(maxx,1);
88         else add(ma[k/a[rt]],1);
89         dfs(rt);
90         printf("%I64d\n",ans);
91     }
92     return 0;
93 }

 

posted @ 2016-09-11 12:15  yyblues  阅读(243)  评论(0)  编辑  收藏  举报