HDU 5735 - Born Slippy

题意:

  一棵 n 个节点的根树,i 节点权重 wi

  对每一个节点s,找到这样一个长 m 的标号序列 v :

    1. vi是vi-1 的祖先

    2. f[s] = w[vi] + ∑(i=2, m) (w[vi] opt w[vi-1]) 最大

 

  要求输出:S = ∑(i=1, n) (i * f[i])  (mod 1e9 + 7)

  

  opt给出,为 & , ^ , | 的任意一种

  数据范围: 2<= n <= 2^16 , 2 <= wi <= 2^16


分析:

    普通转移方程: DP[i] = max(DP[j] + w[i] opt w[j] ), j 为 i 的祖先。

  折半:

    将权值的高低八位拆分: w[i] = (ai<<8) + bi

    所以转移方程变式: DP[i] = max(DP[j] + (ai opt aj)*(1<<8) + (bi opt bj) ), j 为 i 的祖先。

      再拆  tmp[ajbi] = max(DP[j] + (bi opt bj) )  bi 为 j 点的某个子孙节点的低八位

        DP[i] = max( tmp[ajbi]+ ( (ai opt aj)<<8) ) )  aj 为 i 点的某个祖父节点的高八位 

  

  枚举:

    辅助数组 F[a][b] 表示某点权值(不管哪个点)低八位为 b 时的所有权值高八位为 a 的祖先 j 中 DP[j] + (bi opt bj) 的最大值

    即   F[a][b] = max (DP[j] + (bi opt bj) ) , w[j]>>8 = a .

 

    对于每个 i ,  w[i] = (ai<<8) + bi , 枚举 F[x][bi] , DP[i] = max(F[x][bi] + (x opt ai) ) , 0 <= x <= 2^8 - 1 且 x为祖先高八位.

    之后再更新 F[ai][x] = max (DP[i] + (x opt bi) ) ,0 <= x <= 2^8 - 1 .

 

  用DFS就可以套在树上。

  注意要还原

 1 #include <iostream>
 2 #include <cstdio>
 3 #include <cstring>
 4 #include <vector>
 5 using namespace std;
 6 #define LL long long
 7 const int MOD = 1e9+7;
 8 const int MAXN = 1<<16+5;
 9 LL f[1<<8+1][1<<8+1],w[MAXN],tmp[MAXN][1<<8+1];//tmp:回溯 
10 LL ans;
11 int n,fa[1<<8+1];//fa:是否为祖先 
12 char op[5];
13 vector<int> g[MAXN];
14 LL opt(LL a,LL b)
15 {
16     if (op[0]=='A') return a&b;
17     else if (op[0]=='X') return a^b;
18     else return a|b;
19 }
20 void DFS(int x)
21 {
22     int a = w[x] >> 8,b = w[x] & 255;
23     LL DPx=0;
24     for (int i = 0;i <= 255; i++) tmp[x][b] = f[a][b];
25     for (int i = 0;i <= 255; i++) if(fa[i]) DPx = max(DPx, f[i][b] + ( opt(a,i)<<8 ) );
26     fa[a]++;
27     ans = (ans + x * ( w[x] + DPx ) ) % MOD;
28     for (int i=0;i<g[x].size();i++) DFS(g[x][i]);
29     for (int i = 0;i <= 255; i++) f[a][b] = tmp[x][b];//回溯 
30 }
31 int main()
32 {
33     int t;
34     scanf("%d",&t);
35     while (t--)
36     {
37         scanf("%d%s",&n,op);
38         for (int i=1;i<=n;i++) g[i].clear();
39         for (int i=1;i<=n;i++) scanf("%lld",&w[i]);
40         memset(f,0,sizeof(f));
41         memset(fa,0,sizeof(fa)); 
42         for (int i=2;i<=n;i++)
43         {
44             int x;
45             scanf("%d",&x);
46             g[x].push_back(i);
47         }
48         DFS(1);
49         printf("%lld\n",ans);
50     } 
51 }

 

posted @ 2016-08-08 17:03  nicetomeetu  阅读(140)  评论(0编辑  收藏  举报