[LOJ3109][TJOI2019]甲苯先生的线段树:DP

分析

首先,请允许我 orz HN队长zsy。链接

我们发现树上的链有两种类,一类是直上直下的,一类不是直上直下的(废话)。并且,如果我们确定了左侧和右侧的链的长度和整条链上所有节点的编号之和,那么这个链的深度最浅的的节点的编号也是可以唯一地确定的。(也有可能不存在这样的节点,判掉就好)

以第二类链为例,我们可以枚举左侧链和右侧链的长度,令深度最浅的节点的编号为\(x\),那么我们发现这条链的编号之和的下界可以写成\(kx+b\)的形式。于是我们可以求出\(x=\lfloor\frac{sum-b}{k}\rfloor\),然后类似数位DP那样决策左侧链和右侧链的每个位置分别向左儿子还是右儿子走就好了,这个过程可以通过记忆化搜索实现。

时间复杂度不会算。

代码

#include <bits/stdc++.h>

#define rin(i,a,b) for(int i=(a);i<=(b);++i)
#define irin(i,a,b) for(int i=(a);i>=(b);--i)
#define trav(i,a) for(int i=head[a];i;i=e[i].nxt)
#define Size(a) (int)a.size()
#define pb push_back
#define mkpr std::make_pair
#define fi first
#define se second
#define lowbit(a) ((a)&(-(a)))
typedef long long LL;

using std::cerr;
using std::endl;

inline LL read(){
    LL x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}

int d,c;
LL a,b,n;
std::map<LL,LL> mp1[55],mp2[55][55];

inline LL solve1(LL x,LL y){
    LL ret=0;
    while(x!=y){
        if(x<y)std::swap(x,y);
        ret+=x;x>>=1;
    }
    ret+=x;
    return ret;
}

LL dfs1(int x,LL w){
    if(w<0)return 0;
    if(w>(1ll<<(x+1))-2-x)return 0;
    if(!x)return w==0;
    if(mp1[x].find(w)!=mp1[x].end())return mp1[x][w];
    return mp1[x][w]=dfs1(x-1,w)+dfs1(x-1,w-(1ll<<x)+1);
}

LL dfs2(int l,int r,LL w){
    if(l>r)std::swap(l,r);
    if(w<0)return 0;
    if(w>(1ll<<(l+1))+(1ll<<(r+1))-4-l-r)return 0;
    if(!r)return w==0;
    if(mp2[l][r].find(w)!=mp2[l][r].end())return mp2[l][r][w];
    return mp2[l][r][w]=dfs2(l,r-1,w)+dfs2(l,r-1,w-(1ll<<r)+1);
}

int main(){
    int T=read();
    while(T--){
        d=read(),a=read(),b=read(),c=read(),n=(1<<d)-1;
        LL len=solve1(a,b);
        if(c==1){printf("%lld\n",len);continue;}
        LL ans=0;
        rin(i,0,d-1){
            LL k=(1ll<<(i+1))-1,x=len/k;
            if(k<=len&&(int)log2((long double)x)+1+i<=d){
                ans+=dfs1(i,len%k);
            }
        }
        rin(l,1,d-1)rin(r,1,d-1){
            LL k=(1ll<<(l+1))+(1ll<<(r+1))-3,b=(1ll<<r)-1,x=(len-b)/k;
            if(k+b<=len&&(int)log2((long double)x)+1+std::max(l,r)<=d){
                ans+=dfs2(l-1,r-1,(len-b)%k);
            }
        }
        printf("%lld\n",ans-1);
    }
    return 0;
}

posted on 2019-05-08 07:48  ErkkiErkko  阅读(...)  评论(... 编辑 收藏

统计