bzoj2208: [Jsoi2010]连通数

tarjan,状态压缩。

首先直接暴力可过。

第一步tarjan缩强联通分量,图变成一个dag。跑一个拓扑排序。

然后倒序用一个f[i]二进制数组表示i能到达的点。

因为2000个点已知数据类型放不下,用一个bitset。

然后答案就是sum(size[u]*size[v]) f[u][v]=1,u能到v。

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<bitset>
using namespace std;
const int maxn = 2000 + 10;
const int maxm = 4000000 + 10;

int G[maxn],V[maxm],Next[maxm],Eid;
int low[maxn],dfn[maxn],color[maxn],size[maxn],vis[maxn],s[maxn],sp,vid,cid;
int q[maxn],l,r;
bitset<maxn> f[maxn];
int n,res;

void Addedge(int a,int b) {
    V[Eid]=b; Next[Eid]=G[a]; G[a]=Eid++;    
}

void build() {
    memset(G,-1,sizeof(G));
    scanf("%d",&n);
    for(int i=1,t;i<=n;i++)
    for(int j=1;j<=n;j++) {
        scanf("%1d",&t);
        if(t) Addedge(i,j);
    }
}

void tarjan(int u) {
    dfn[u]=low[u]=++vid;
    s[++sp]=u; vis[u]=1;
    
    for(int i=G[u];~i;i=Next[i]) {
        if(vis[V[i]]==0) {
            tarjan(V[i]);
            low[u]=min(low[u],low[V[i]]);    
        }
        else if(vis[V[i]]==1) 
            low[u]=min(low[u],dfn[V[i]]);    
    }
    
    if(low[u]==dfn[u]) {
        ++cid;
        do {
            color[s[sp]]=cid;
            size[cid]++;        
            vis[s[sp]]=2;
        }while(s[sp--]!=u);
    }
}

int g[maxn],v[maxm],next[maxm],in[maxn],eid;


void addedge(int a,int b) {
    v[eid]=b; next[eid]=g[a]; g[a]=eid++;
    in[b]++;    
}

void predo() {
    for(int i=1;i<=n;i++) if(!vis[i]) tarjan(i);    
    
    memset(g,-1,sizeof(g));
    for(int u=1;u<=n;u++) 
        for(int i=G[u];~i;i=Next[i]) 
            if(color[u]!=color[V[i]]) 
                addedge(color[u],color[V[i]]);
}

void toposort() {
    l=r=1;
    for(int i=1;i<=cid;i++) if(!in[i]) q[r++]=i;
    
    while(l<r) {
        int u=q[l++];
        for(int i=g[u];~i;i=next[i]) if(!--in[v[i]]) 
            q[r++]=v[i];    
    } 
}

void solve() {
    toposort();
    for(int i=1;i<=cid;i++) f[i][i]=1;
    
    for(int x=r-1,u;x;x--) {
        u=q[x];
        for(int i=g[u];~i;i=next[i]) 
            f[u]|=f[v[i]];
    }
    
    for(int i=1;i<=cid;i++)
    for(int j=1;j<=cid;j++)
        if(f[i][j]) 
            res+=size[i]*size[j];
    printf("%d\n",res);
}

int main() {
    build();
    predo();
    solve();
    return 0;
}
posted @ 2016-06-17 12:11  invoid  阅读(161)  评论(0编辑  收藏  举报