Gym 100341C AVL Trees NTT

AVL Trees

题目连接:

(http://codeforces.com/gym/100341)

题意

avl树是每棵子树的左右子树高度之差小于等于1,给你节点个数和树高,问有多少种树

题解:

很轻松地我们写出dp:
dp[h][n]表示树高h+1,n个节点的答案
\(dp[h][n]=\sum_{i=0}^{n-1}dp[h-1][i]*(dp[h-1][n-1-i]+2*dp[h-2][n-1-i])\)
一眼看出这是n次ntt,最暴力3h次变换,优化一下2h次,前h/2由于有效的dp值很小,可以暴力算,那么就是h次;
假设我们学过信号与系统,我们知道\(DP[h]=DP[h-1]\times(DP[h-1]+2*DP[h-2])\times\delta[1]\),那么我们只要3次变换就可以完成;
我们观察dp[0],只有dp[0][0]=1,那么\(dp[0]=\delta[0],DP[0]=\{1,1,1,1,..\}\)
同理只有dp[1][1]=1,那么\(dp[1]=\delta[1],DP[1]=\{{g}^{0},{g}^{1},{g}^{2},..\}\)
前两次变换就可以直接赋值,我们观察最后一次反变换,由于我们只要求x[n]的值,我们不需要把整个序列变换回来
根据变换式$$x[n]=\frac{1}{N}\sum_{k=0}{N-1}X[k]*g$$
我们可以O(n)求出x[n]
那么这道题我们可以用O(hn)的算法完美解决,不用ntt变换

代码

//#include <bits/stdc++.h>
#include <stdio.h>
#include <iostream>
#include <string.h>
#include <math.h>
#include <stdlib.h>
#include <limits.h>
#include <algorithm>
#include <queue>
#include <vector>
#include <set>
#include <map>
#include <stack>
#include <bitset>
#include <string>
#include <time.h>
using namespace std;
long double esp=1e-11;
//#pragma comment(linker, "/STACK:1024000000,1024000000")
#define fi first
#define se second
#define all(a) (a).begin(),(a).end()
#define cle(a) while(!a.empty())a.pop()
#define mem(p,c) memset(p,c,sizeof(p))
#define mp(A, B) make_pair(A, B)
#define pb push_back
#define lson l , m , rt << 1
#define rson m + 1 , r , rt << 1 | 1
typedef long long int LL;
const long double PI = acos((long double)-1);
const LL INF=0x3f3f3f3fll;
const int MOD =1000000007ll;
const int maxn=100100;
const int NUM = 1<<17;
int  wn[NUM];
    // p                   | deg | g        长度为2^k,且N|(p-1),p=c*(1<<k)+1,g为原根,g^phi(p)=1 %p的最小g
    // 469762049             26    3
    // 998244353             23    3
    // 1004535809            21    3
    // 1107296257            24    10
    // 10000093151233        26    5
    // 1000000523862017      26    3
    // 1000000000949747713   26    2
LL mu(LL a,LL b,LL P)
{
    LL ans=1;
    while(b)
    {
        if(b&1)
            ans=ans*a%P;
        a=a*a%P;
        b>>=1;
    }
    return ans;
}
void GetWn(int G,int P,int len)
{
    wn[0] = 1, wn[1] = mu(G, (P - 1) / len, P);
	for(int i = 2; i < len; i++)
		wn[i] = 1LL * wn[i - 1] * wn[1] % P;
}
int dp[17][1<<17];
int main()
{
    //freopen("in.txt", "r", stdin);
    freopen("avl.in", "r", stdin);
    freopen("avl.out", "w", stdout);
    //::iterator iter;                  %I64d
    //for(int x=1;x<=n;x++)
    //for(int y=1;y<=n;y++)
    //scanf("%d",&a);
    //printf("%d\n",ans);
    int n,h;
    scanf("%d%d",&n,&h);
    if(n>=1<<(h+1))
	{
		printf("%d\n",0);
		return 0;
	}
	LL N=1ll<<(h+1),g=10,P=786433ll;
    GetWn(g,P,N);
    for(int x=0;x<N;x++)
		dp[0][x]=1,dp[1][x]=wn[x];
	for(int x=2;x<=h+1;x++)
		for(int y=0;y<N;y++)
			dp[x][y]=1ll*dp[x-1][y]*(dp[x-1][y]+2ll*dp[x-2][y])%P*wn[y]%P;
	int ans=0,w=mu(wn[n],P-2,P);
	int t=1;
	for(int x=0;x<N;x++)
		ans=(ans+1ll*t*dp[h+1][x])%P,t=1ll*t*w%P;
	ans=1ll*ans*mu(N,P-2,P)%P;
	printf("%d\n",ans);
    return 0;
}

posted @ 2016-08-24 20:53  femsub  阅读(361)  评论(0编辑  收藏  举报