hdu 5593 ZYB's Tree 树形dp

ZYB's Tree

Time Limit: 3000/1500 MS (Java/Others)    Memory Limit: 131072/131072 K (Java/Others)


Problem Description
ZYB has a tree with $N$ nodes. For each node, he wants you to count the nodes whose distance from it is at most $K$.
The distance between two nodes $(x, y)$ is defined as the number of edges on the shortest path between them in the tree.

To save time on reading and printing, we use the following scheme:

For reading: we are given two numbers $A$ and $B$. Let $fa_i$ be the father of node $i$; then $fa_1 = 0$ and $fa_i = (A \cdot i + B) \bmod (i - 1) + 1$ for $i \in [2, N]$.

For printing: let $ans_i$ be the answer for node $i$; you only need to print the XOR sum of all $ans_i$.
 

 

Input
In the first line there is the number of testcases T.

For each test case:

The first line contains four numbers $N, K, A, B$.

$1 \le T \le 5$, $1 \le N \le 500000$, $1 \le K \le 10$, $1 \le A, B \le 1000000$
 

 

Output
Print $T$ lines; each line contains the answer for one test case.

Please increase the stack size yourself.

Only two of the final tests have $N > 100000$.
 

 

Sample Input
1 3 1 1 1
 

 

Sample Output
3
 

 

Source
#pragma comment(linker, "/STACK:1024000000,1024000000")
#include<iostream>
#include<cstdio>
#include<cmath>
#include<string>
#include<queue>
#include<algorithm>
#include<stack>
#include<cstring>
#include<vector>
#include<list>
#include<set>
#include<map>
#include<bitset>
#include<time.h>
using namespace std;
// Contest-template shorthands. Within this file only LL (used by INF/mod) and
// N (used as an array bound) are referenced; the others are unused here.
// Typed constants replace the original object-like macros so they obey scope
// and carry a fixed type.
typedef long long LL;
const double pi = 4 * std::atan(1.0);  // pi, since atan(1) == pi/4
const double eps = 1e-4;               // template constant (unused here)
#define bug(x)  cout<<"bug"<<x<<endl;  // debug print; relies on `using namespace std`
// N: capacity for 1-indexed node arrays (the problem guarantees N <= 500000).
const int N=5e5+10,M=1e6+10,inf=1e9+7,MOD=1e9+7;
const LL INF=1e18+10,mod=1e9+7;

// edge[x] lists the children of node x (nodes are 1-indexed; see the tree build in main).
vector<int> edge[N];
// n: node count, k: distance limit, a/b: parent-generator parameters; read per test case.
int n, a, b, k;
// dp[x][d][0]: number of nodes at distance exactly d from x inside x's own subtree.
// dp[x][d][1]: number of nodes at distance exactly d from x in the whole tree.
// The middle dimension (12) must exceed k (k <= 10).
int dp[N][12][2];
// Bottom-up pass: for every node x in the subtree rooted at u, fills
// dp[x][j][0] (0 <= j <= k) with the number of nodes in x's subtree at
// distance exactly j from x.
//
// Precondition: dp[x][*][0] is zero for every x in the subtree (main clears it).
//
// Implemented iteratively. A recursive version uses one stack frame per tree
// level, which can overflow on deep trees. The `#pragma comment(linker, ...)`
// stack-size request at the top of the file is honored only by MSVC.
void dfs(int u)
{
    // Collect the subtree in pre-order: every parent appears before its children.
    std::vector<int> order;
    std::vector<int> pending(1, u);
    while(!pending.empty())
    {
        int x = pending.back();
        pending.pop_back();
        order.push_back(x);
        for(std::size_t c = 0; c < edge[x].size(); c++)
            pending.push_back(edge[x][c]);
    }
    // Walking the pre-order backwards finishes every child before its parent,
    // so each child's counts are final when they are folded into the parent.
    for(std::size_t idx = order.size(); idx-- > 0; )
    {
        int x = order[idx];
        dp[x][0][0] = 1;  // x itself is at distance 0
        for(std::size_t c = 0; c < edge[x].size(); c++)
        {
            int v = edge[x][c];
            // A node at distance j-1 from child v is at distance j from x.
            for(int j = 1; j <= k; j++)
                dp[x][j][0] += dp[v][j-1][0];
        }
    }
}
// Top-down pass: for every proper descendant v of u, fills dp[v][i][1]
// (0 <= i <= k) with the number of nodes in the WHOLE tree at distance exactly
// i from v.
//
// Preconditions: dfs() has filled dp[*][*][0], and dp[u][*][1] is already set
// for the start node (main copies the root's subtree counts).
//
// Recurrence for child v of parent x, for i >= 1:
//   dp[v][i][1] = dp[v][i][0]            (inside v's subtree)
//               + dp[x][i-1][1]          (reached through the parent)
//               - dp[v][i-2][0]          (i >= 2: those paths that dip back into v's subtree)
//
// Implemented with an explicit stack for the same stack-depth reason as dfs().
void dfs2(int u)
{
    std::vector<int> pending(1, u);
    while(!pending.empty())
    {
        int x = pending.back();
        pending.pop_back();
        // x's whole-tree counts are final here, so each child can be computed.
        for(std::size_t c = 0; c < edge[x].size(); c++)
        {
            int v = edge[x][c];
            dp[v][0][1] = dp[v][0][0];
            for(int i = 1; i <= k; i++)
            {
                dp[v][i][1] = dp[x][i-1][1] + dp[v][i][0];
                if(i >= 2)
                    dp[v][i][1] -= dp[v][i-2][0];
            }
            pending.push_back(v);
        }
    }
}
// Reads T test cases. Each case builds the tree from the (N, K, A, B)
// generator, computes for every node the number of nodes within distance K,
// and prints the XOR of those counts.
int main()
{
    int T;
    if(scanf("%d",&T)!=1)
        return 0;
    while(T--)
    {
        if(scanf("%d%d%d%d",&n,&k,&a,&b)!=4)
            break;  // truncated input: stop instead of computing on stale values
        // Only rows 0..n of dp are read or written for this test case, so clear
        // just those rather than the whole table (about 48 MB).
        memset(dp,0,sizeof(dp[0])*(n+1));
        for(int i=1;i<=n;i++)
            edge[i].clear();
        // Node i's father is (A*i + B) mod (i-1) + 1; the 64-bit product avoids overflow.
        for(int i=2;i<=n;i++)
        {
            int x=(int)((1LL*a*i+b)%(i-1)+1);
            edge[x].push_back(i);
        }
        dfs(1);
        // For the root, "whole tree" counts equal its subtree counts.
        for(int d=0;d<=k;d++)
            dp[1][d][1]=dp[1][d][0];
        dfs2(1);
        int ans=0;
        for(int i=1;i<=n;i++)
        {
            int sum=0;  // nodes within distance k of node i (at most n, fits in int)
            for(int d=0;d<=k;d++)
                sum+=dp[i][d][1];
            ans^=sum;
        }
        printf("%d\n",ans);
    }
    return 0;
}

 

 

posted @ 2017-10-08 15:25  jhz033  阅读(239)  评论(0)  编辑  收藏  举报