【bzoj4044】[Cerc2014] Virus synthesis【回文自动机+倍增】

[Cerc2014] Virus synthesis

Description

你要用ATGC四个字母用两种操作拼出给定的串:
1.将其中一个字符放在已有串开头或者结尾
2.将已有串复制,然后reverse,再接在已有串的头部或者尾部
一开始已有串为空。求最少操作次数。
len<=100000

题解:
这道题我是乱搞的,时间复杂度O(n log n),正解应该是O(n)的,目前b站倒数第二。= =
首先可以得到最后的字符串一定是由一个回文串再添上一些字符得到的。于是我们只要得到每个回文串最要要多少次才能得到,然后取最优值就可以了。
就上回文自动机了。
可以让dp[i]表示i这个状态的回文串最少要几次才能得到。
可以得到:
len[i]为奇数,则它一定不能经过翻转得到。方程:dp[i]=dp[fa[i]]+2,代表i这个状态是在上一个状态的基础上在左右两边各添加一个相同的字符得到的。
len[i]为偶数,方程:dp[i]=min(dp[fa[i]]+1,dp[half]len[half]+len[i]/2+1) half是任意一个长度不超过i的长度的一半的i的回文后缀。
分别代表两种情况:xSSx,SxxS。只要取个min就可以了。
注意一下,当fa[i]时空串,即串长为2时,方程应该是dp[i]=2
只需要最后bfs一次dp就好啦!
求half我是乱搞的,用倍增跳fail数组,再维护一个f数组代表每个状态到根的路径上的dp[i]len[i]的最小值,就可以很方便地统计啦!细节详见代码。
注意,同一层应该先更新长度为奇数的那棵树的节点,因为长度为偶数的节点的fail可能会会跳到长度为奇数的那棵树的节点。事实上实现很简单,把1先入队,再把0入队就可以了。
UPD:桑森,在b站A了,在xsy上T了,时间卡得好紧,多个log都不行QAQ。
code:

#include<cstdio>
#include<cstring>
#include<queue>
using namespace std;
const int N=100005;
int t,l;
char s[N];
queue<int> q;
struct Pam{
    int n,tot,tmp,last,now,s[N],len[N],dep[N],f[N],dp[N],fail[N][20],ch[N][4];
    int newnode(int l){
        memset(ch[tot],0,sizeof(ch[tot]));
        memset(fail[tot],0,sizeof(fail[tot]));
        len[tot]=l;
        return tot++;
    }
    void init(){
        tot=n=last=0;
        s[0]=-1;
        newnode(0);
        newnode(-1);
        f[0]=0;
        dp[0]=0;
        dp[1]=-1;
        fail[0][0]=1;
    }
    int getfail(int x){
        while(s[n-len[x]-1]!=s[n]){
            x=fail[x][0];
        }
        return x;
    }
    void insert(int x){
        if(x=='A'){
            x=0;
        }else if(x=='T'){
            x=1;
        }else if(x=='C'){
            x=2;
        }else{
            x=3;
        }
        s[++n]=x;
        tmp=getfail(last);
        if(!ch[tmp][x]){
            now=newnode(len[tmp]+2);
            fail[now][0]=ch[getfail(fail[tmp][0])][x];
            dep[now]=dep[tmp]+1;
            for(int i=1;(1<<i)<=dep[now];i++){
                fail[now][i]=fail[fail[now][i-1]][i-1];
            }
            ch[tmp][x]=now;
        }
        last=ch[tmp][x];
    }
    int solve(){
        int ans=l,u,v,x;
        q.push(1);
        q.push(0);
        while(!q.empty()){
            u=q.front();
            q.pop();
            for(int i=0;i<4;i++){
                v=ch[u][i];
                if(!v){
                    continue;
                }
                if(len[v]%2==0){
                    if(u){
                        dp[v]=dp[u]+1;  
                    }else{
                        dp[v]=dp[u]+2;
                    }
                    x=v;
                    for(int i=17;i>=0;i--){
                        if(len[fail[x][i]]*2>len[v]){
                            x=fail[x][i];
                        }
                    }
                    x=fail[x][0];
                    dp[v]=min(dp[v],f[x]+len[v]/2+1);
                }else{
                    dp[v]=dp[u]+2;
                }
                f[v]=min(f[u],dp[v]-len[v]);
                ans=min(ans,dp[v]+l-len[v]);
                q.push(v);
            }
        }
        return ans;
    }
}pam;
int main(){
    scanf("%d",&t);
    while(t--){
        scanf("%s",s+1);
        l=strlen(s+1);
        pam.init();
        for(int i=1;i<=l;i++){
            pam.insert(s[i]);
        }
        printf("%d\n",pam.solve());
    }
    return 0;
}

UPD:参考了一下网上的题解,发现可以证明,只要从half[fa[i]]开始跳就可以了,于是就修改了一下我的代码。现在的做法是,倍增优化跳fail,但还是非常慢。代码:

#include<cstdio>
#include<cstring>
#include<queue>
using namespace std;
const int N=100005;
int t,l,lg2[N];
char s[N];
queue<int> q;
struct Pam{
    int n,tot,tmp,last,now,s[N],len[N],dep[N],f[N],dp[N],half[N],fail[N][20],ch[N][4];
    int newnode(int l){
        memset(ch[tot],0,sizeof(ch[tot]));
        memset(fail[tot],0,sizeof(fail[tot]));
        len[tot]=l;
        return tot++;
    }
    void init(){
        tot=n=last=0;
        s[0]=-1;
        newnode(0);
        newnode(-1);
        f[0]=0;
        dp[0]=0;
        dp[1]=-1;
        fail[0][0]=1;
    }
    int getfail(int x){
        while(s[n-len[x]-1]!=s[n]){
            x=fail[x][0];
        }
        return x;
    }
    void insert(int x){
        if(x=='A'){
            x=0;
        }else if(x=='T'){
            x=1;
        }else if(x=='C'){
            x=2;
        }else{
            x=3;
        }
        s[++n]=x;
        tmp=getfail(last);
        if(!ch[tmp][x]){
            now=newnode(len[tmp]+2);
            fail[now][0]=ch[getfail(fail[tmp][0])][x];
            dep[now]=dep[tmp]+1;
            for(int i=1;(1<<i)<=dep[now];i++){
                fail[now][i]=fail[fail[now][i-1]][i-1];
            }
            if(len[now]<=2){
                half[now]=fail[now][0];
            }else{
                int f=half[tmp];
                for(int i=lg2[dep[f]];i>=0;i--){
                    if((len[fail[f][i]]+2)*2>len[now]){
                        f=fail[f][i];
                    }
                }
                if((len[f]+2)*2>len[now]){
                    f=fail[f][0];
                }
                while(s[n-len[f]-1]!=s[n]){
                    f=fail[f][0];
                }
                half[now]=ch[f][x];
            }
            ch[tmp][x]=now;
        }
        last=ch[tmp][x];
    }
    int solve(){
        int ans=l,u,v;
        q.push(1);
        q.push(0);
        while(!q.empty()){
            u=q.front();
            q.pop();
            for(int i=0;i<4;i++){
                v=ch[u][i];
                if(!v){
                    continue;
                }
                if(len[v]%2==0){
                    if(u){
                        dp[v]=dp[u]+1;  
                    }else{
                        dp[v]=dp[u]+2;
                    }
                    dp[v]=min(dp[v],f[half[v]]+len[v]/2+1);
                }else{
                    dp[v]=dp[u]+2;
                }
                f[v]=min(f[u],dp[v]-len[v]);
                ans=min(ans,dp[v]+l-len[v]);
                q.push(v);
            }
        }
        return ans;
    }
}pam;
int main(){
    for(int i=2;i<=100000;i++){
        lg2[i]=lg2[i/2]+1;
    }
    scanf("%d",&t);
    while(t--){
        scanf("%s",s+1);
        l=strlen(s+1);
        pam.init();
        for(int i=1;i<=l;i++){
            pam.insert(s[i]);
        }
        printf("%d\n",pam.solve());
    }
    return 0;
}

一气之下,改成直接跳fail了。快了不少。代码:

#include<cstdio>
#include<cstring>
#include<queue>
using namespace std;
const int N=100005;
int t,l;
char s[N];
queue<int> q;
struct Pam{
    int n,tot,tmp,last,now,s[N],len[N],dep[N],f[N],dp[N],half[N],fail[N],ch[N][4];
    int newnode(int l){
        memset(ch[tot],0,sizeof(ch[tot]));
        len[tot]=l;
        return tot++;
    }
    void init(){
        tot=n=last=0;
        s[0]=-1;
        newnode(0);
        newnode(-1);
        f[0]=0;
        dp[0]=0;
        dp[1]=-1;
        fail[0]=1;
    }
    int getfail(int x){
        while(s[n-len[x]-1]!=s[n]){
            x=fail[x];
        }
        return x;
    }
    void insert(int x){
        if(x=='A'){
            x=0;
        }else if(x=='T'){
            x=1;
        }else if(x=='C'){
            x=2;
        }else{
            x=3;
        }
        s[++n]=x;
        tmp=getfail(last);
        if(!ch[tmp][x]){
            now=newnode(len[tmp]+2);
            fail[now]=ch[getfail(fail[tmp])][x];
            dep[now]=dep[tmp]+1;
            if(len[now]<=2){
                half[now]=fail[now];
            }else{
                int f=half[tmp];
                while(s[n-len[f]-1]!=s[n]||(len[f]+2)*2>len[now]){
                    f=fail[f];
                }
                half[now]=ch[f][x];
            }
            ch[tmp][x]=now;
        }
        last=ch[tmp][x];
    }
    int solve(){
        int ans=l,u,v;
        q.push(1);
        q.push(0);
        while(!q.empty()){
            u=q.front();
            q.pop();
            for(int i=0;i<4;i++){
                v=ch[u][i];
                if(!v){
                    continue;
                }
                if(len[v]%2==0){
                    if(u){
                        dp[v]=dp[u]+1;  
                    }else{
                        dp[v]=dp[u]+2;
                    }
                    dp[v]=min(dp[v],f[half[v]]+len[v]/2+1);
                }else{
                    dp[v]=dp[u]+2;
                }
                f[v]=min(f[u],dp[v]-len[v]);
                ans=min(ans,dp[v]+l-len[v]);
                q.push(v);
            }
        }
        return ans;
    }
}pam;
int main(){
    scanf("%d",&t);
    while(t--){
        scanf("%s",s+1);
        l=strlen(s+1);
        pam.init();
        for(int i=1;i<=l;i++){
            pam.insert(s[i]);
        }
        printf("%d\n",pam.solve());
    }
    return 0;
}
posted @ 2018-01-31 20:14  ez_2016gdgzoi471  阅读(161)  评论(0编辑  收藏  举报