// In-place forward transform for XOR convolution over a[0..n-1], reduced
// modulo `mod`. Indices are 0-based and n must be a power of two.
// Assumes every entry is already in [0, mod) so that (x - y + mod) cannot go
// negative -- not checked here.
// The commented-out assignments are the alternative butterflies for the
// AND / OR flavours of the transform, kept for reference.
void FWT(int a[], int n) {
    for (int half = 1; half < n; half <<= 1) {
        const int span = half << 1;  // width of one butterfly group
        for (int base = 0; base < n; base += span) {
            for (int off = 0; off < half; ++off) {
                int &lo = a[base + off];
                int &hi = a[base + off + half];
                const int x = lo;
                const int y = hi;
                // xor
                lo = (x + y) % mod;
                hi = (x - y + mod) % mod;
                // and
                // lo = x + y;
                // or
                // hi = x + y;
            }
        }
    }
}
// In-place inverse of the XOR transform applied by FWT, modulo `mod`.
// Same layout rules as FWT: 0-based indices, n a power of two.
// Each butterfly halves its outputs by multiplying with rev2, which the file
// defines as (mod + 1) / 2.
// The commented-out assignments are the inverse butterflies for the
// AND / OR flavours, kept for reference.
void UFWT(int a[], int n) {
    for (int half = 1; half < n; half <<= 1) {
        const int span = half << 1;  // width of one butterfly group
        for (int base = 0; base < n; base += span) {
            for (int off = 0; off < half; ++off) {
                int &lo = a[base + off];
                int &hi = a[base + off + half];
                const int x = lo;
                const int y = hi;
                // xor (1LL widens the product to 64 bits before reducing)
                lo = 1LL * (x + y) * rev2 % mod;
                // x - y may be negative, so shift the remainder back into [0, mod)
                hi = (1LL * (x - y) * rev2 % mod + mod) % mod;
                // and
                // lo = x - y;
                // or
                // hi = y - x;
            }
        }
    }
}
// XOR-convolves a[0..n-1] with b[0..n-1] modulo `mod`; the result overwrites a.
// n must be a power of two (requirement of FWT / UFWT).
// Side effect: b is transformed in place by FWT and is NOT transformed back,
// so its contents after the call are the transformed values.
void solve(int a[], int b[], int n) {
    FWT(a, n);
    FWT(b, n);
    // Pointwise product in the transformed domain.
    for (int k = 0; k < n; ++k) {
        const long long product = 1LL * a[k] * b[k];
        a[k] = product % mod;
    }
    UFWT(a, n);
}

HDU 5909 这种树DP很常见，如果都是异或值可以这么优化，还有的需要类似FFT的优化。

 1 #include <bits/stdc++.h>
// Modulus used for every modular reduction in this file.
const int mod = 1e9+7;
// Inverse of 2 modulo `mod`: 2 * rev2 == mod + 1, which is 1 (mod `mod`).
int rev2 = (mod+1)/2;
// Floating-point tolerance; not referenced by the code visible in this file.
const double ex = 1e-10;
// Large sentinel value; not referenced by the code visible in this file.
#define inf 0x3f3f3f3f
using namespace std;
7 // 下标从0开始 n为2的幂次
8 void FWT(int a[] ,int n){
9     for (int d = 1 ; d < n ; d <<= 1){
10         for (int m = d << 1 ,i = 0;i < n ; i+=m){
11             for (int j = 0 ; j < d ; j++){
12                 int x = a[i+j],y = a[i+j+d];
13                 //xor;
14                 a[i+j] = (x+y) % mod,a[i+j+d] = (x-y+mod)%mod;
15                 //and
16                 //a[i+j]=x+y;
17                 //or
18                 //a[i+j+d]=x+y;
19             }
20         }
21     }
22 }
23 void UFWT(int a[],int n){
24     for (int d = 1 ; d < n ; d<<=1){
25         for (int m = d <<1, i = 0; i < n; i+=m){
26             for (int j = 0 ; j < d ; j++){
27                 int x = a[i+j],y = a[i+j+d];
28                 //xor
29                 a[i+j] = 1LL*(x+y)*rev2%mod,a[i+j+d] = (1LL*(x-y)*rev2%mod + mod) % mod;
30                 //and
31                 //a[i+j] = x-y;
32                 //or
33                 //a[i+j+d] = y-x;
34             }
35         }
36     }
37 }
38 void solve(int a[],int b[],int n){
39     FWT(a,n);
40     FWT(b,n);
41     for (int i = 0 ; i<n ; i++) a[i]=1LL*a[i]*b[i]%mod;
42     UFWT(a,n);
43 }
// Value attached to each vertex; main fills indices 1..n.
int a[1300];
// Adjacency lists; main stores every input edge in both directions.
std::vector<int> E[1300];
// dp[u][x]: table built by dfs for vertex u, indexed by XOR value x.
int dp[1300][1300];
// ans[x]: sum over all vertices u of dp[u][x], modulo mod (accumulated in dfs).
int ans[1300];
// Scratch copy of dp[u] handed to solve(); shared by all dfs frames.
int tmp[1300];
// Table width read from the input; passed as the transform length to solve(),
// whose transforms require a power of two.
int m;
50 void dfs(int u,int fa){
51     dp[u][a[u]] = 1;
52     for (int i = 0 ; i<E[u].size(); i++){
53         int to = E[u][i];
54         if (to == fa) continue;
55         dfs(to,u);
56         for (int j = 0;j<m; j++){
57             tmp[j] = dp[u][j];
58         }
59         solve(tmp,dp[to],m);
60         for (int j = 0 ; j  < m ;j++){
61             dp[u][j]=(dp[u][j] + tmp[j]) % mod;
62         }
63     }
64     for (int  i = 0 ; i < m; i++){
65         ans[i] = (ans[i] + dp[u][i]) % mod;
66     }
67 }
68 int main()
69 {
70     int t;
71     scanf("%d",&t);
72     while (t--){
73         int n;
74         scanf("%d%d",&n,&m);
75         for (int i = 1; i<=n;i++) E[i].clear();
76         for (int i = 1; i<=n ;i++){
77             scanf("%d",&a[i]);
78         }
79         for (int i=1; i<n ; i++){
80             int u,v;
81             scanf("%d%d",&u,&v);
82             E[u].push_back(v);
83             E[v].push_back(u);
84         }
85         memset(dp,0,sizeof(dp));
86         memset(ans,0,sizeof(ans));
87         dfs(1,0);
88         for (int i = 0; i<m; i++){
89             printf("%d%c",ans[i],i+1==m?'\n':' ');
90         }
91     }
92     return 0;
93 }
View Code

posted @ 2017-10-14 18:31 HITLJR 阅读(...) 评论(...) 编辑 收藏