[BZOJ3037] 创世纪 题解

基环内向树上 dp,不过在这里提供给一种非典型做法。

考虑将环上的每一条边都断开,这样就会形成多棵树,先在这些树上进行树形 \(dp\)。设 \(dp_{i,0/1}\) 表示不选/选 \(i\) 时,\(i\) 子树内的最大选点数。明显方程为:

\[\begin{cases}dp_{u,0}=\sum\limits_{v\in uson}\max(dp_{v,0},dp_{v,1})\\ \\dp_{u,1}=[\sum\limits_{v\in uson}[dp_{v,0}\ge dp_{v,1}]>0]?dp_{u,0}:dp_{u,0}-\min\limits_{v\in uson}(dp_{v,1}-dp_{v,0})\end{cases} \]

接下来,我们开始在环上找答案。考虑断环为链。设 \(f_{i,0/1,0/1}\) 表示在环上的第 \(i\) 个点,选不选,第 \(cnt\) 个点选不选,\(lp_i\) 表示环上第 \(i\) 个点的编号。则转移方程为:

\[\begin{cases} f_{i,0,0/1}=\max(f_{i-1,0,0/1},f_{i-1,1,0/1})+dp_{lp_i,0}\\ f_{i,1,0/1}=\max(f_{i-1,0,0/1}+dp_{lp_i,0}+1,f_{i-1,1,0/1}+dp_{lp_i,1}) \end{cases} \]

时间复杂度瓶颈为并查集(不知道并查集干什么用的,详见上一道题我写的题解),时间复杂度 \(O(n\log n)\)

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=1e6+5;
int n,m,h[N],to[N],nxt[N];
int cnt,lp[N],v[N],fh[N],ans;
int dp[N][2],f[N][2][2],t[N];
void add(int x,int y){
    to[++m]=y;
    nxt[m]=h[x];
    h[x]=m;
}void init(){
    for(int i=1;i<=n;i++)
        fh[i]=i;
}int find
(int x){
    if(fh[x]==x) return x;
    return fh[x]=find(fh[x]);
}void unite(int x,int y){
    x=find(x);
    y=find(y);
    if(x==y) return;
    fh[y]=x;
}void dp_(int x){
    int f=1,mn=1e9;
    for(int i=h[x];i;i=nxt[i]){
        if(v[to[i]]) continue;
        int y=to[i];dp_(y);
        dp[x][0]+=max(dp[y][0],dp[y][1]);
        if(dp[y][0]>=dp[y][1]) f=0;
        else mn=min(mn,dp[y][1]-dp[y][0]);
    }dp[x][1]=dp[x][0]+1;
    if(f) dp[x][1]-=mn;
}void solve(int rt){
    int x=rt,y=t[rt];cnt=0;
    lp[++cnt]=y;v[y]=1;
    while(y!=x){
        lp[++cnt]=t[y];
        v[t[y]]=1;y=t[y];
    }for(int i=1;i<=cnt;i++) dp_(lp[i]);
    f[0][0][1]=f[0][1][0]=-1e9;
    for(int i=1;i<=cnt;i++){
        f[i][0][0]=max(f[i-1][0][0],f[i-1][1][0])+dp[lp[i]][0];
        f[i][1][0]=max(f[i-1][0][0]+dp[lp[i]][0]+1,f[i-1][1][0]+dp[lp[i]][1]);
        f[i][0][1]=max(f[i-1][0][1],f[i-1][1][1])+dp[lp[i]][0];
        f[i][1][1]=max(f[i-1][0][1]+dp[lp[i]][0]+1,f[i-1][1][1]+dp[lp[i]][1]);
    }ans+=max(f[cnt][0][0],f[cnt][1][1]);
}int main(){
    scanf("%d",&n);init();
    for(int i=1;i<=n;i++){
        cin>>t[i];
        add(t[i],i);
        unite(t[i],i);
    }for(int i=1;i<=n;i++)
        if(find(i)==i) solve(i);
    printf("%d",ans);
    return 0;
}//Kaká
posted @ 2024-04-20 10:20  长安一片月_22  阅读(2)  评论(0编辑  收藏  举报