51nod1601完全图MST计数

我们可以根据二进位划分集合,从高到底,越来越

首先,最高位不同的点集合中,必须存在一条边,所以,可以用trie树来处理该操作并统计该边的数目。

然后对于两个不同的集合,递归重复这样的操作,就会得到两个集合的MST。

附代码:

#include "iostream"
#include "cstdio"
#include "algorithm"
#include "cmath"
#include "queue"
#include "cstring"
#define LL long long 
#define fo(i ,j ,k) for(int i=j; i<=k; i++)
#define pa pair<int ,int>
#define inf 0x3f3f3f3f
using namespace std;
const int maxn = 1e5+5;
const int mod = 1e9+7;
const long long mo = 1e9+7;
int n ,a[maxn] ,s[maxn] ,t[maxn] ,cnt;
long long sum ,anscnt = 1;
struct node
{
    int nxt[2],cnt;
}tr[maxn*31];

void init()
{
    fo(i ,0, cnt)
    {
        tr[i].nxt[1] = tr[i].nxt[0] = tr[i].cnt = 0;
    }
    cnt = 0;
}
template<typename T>
void read(T &x)
{
    x = 0;
    char c = getchar();
    while(!isdigit(c))
        c = getchar();
    while(isdigit(c))
    {
        x = (x<<1) + (x<<3) + c -'0';
        c = getchar() ;
    }
    return ;
}

long long power(int x ,int y)
{
    long long res = 1;
    while(y)
    {
            if(y&1) res=1LL*res*x%mo;
        y>>=1;
        x = 1LL*x*x%mo;
    }    
    return res;
}

void insert(int x)
{
    int p = 0;
    for(int i=30; i>=0; i--)
    {
        int y = (x>>i)&1;
        if(!tr[p].nxt[y])
        {
            tr[p].nxt[y] = ++cnt;
            p = cnt;
        }
        else p = tr[p].nxt[y];
     }
     tr[p].cnt++;
}

inline pa find(int x)
{
    int p = 0;
    int ans = 0;
    for(int i=30; i>=0; i--)
    {
        int y = (x>>i)&1;
        if(tr[p].nxt[y])
            p = tr[p].nxt[y],ans|=(y<<i);
        else p = tr[p].nxt[y^1],ans|=((y^1)<<i);
    }
    return make_pair(ans^x ,tr[p].cnt);
}
void solve(int l ,int r , int dep)
{
   if(l>=r) return;
    if(dep<0){
        if(r-l>=1)
            anscnt=1LL*anscnt*power(r-l+1,r-l-1)%mod;
        return;
    }
      int  cnt1 = 0,cnt2 = 0;
    for(int i=l;i<=r;i++)
        if((a[i]>>dep)&1) s[cnt1++]=a[i];
        else t[cnt2++]=a[i];
    for(int i=0;i<cnt1;i++) a[l+i]=s[i];
    for(int i=0;i<cnt2;i++) a[l+cnt1+i]=t[i];

    init();pa tmp;int ans=inf,cnt=0;
    for(int i=0;i<cnt2;i++) insert(t[i]);
    for(int i=0;i<cnt1;i++){
        tmp=find(s[i]);
        if(tmp.first<ans)
            ans=tmp.first,cnt=tmp.second;
        else if(tmp.first==ans)
            cnt+=tmp.second;
    }
     if(sum!=inf&&cnt) sum+=ans,anscnt=1LL*cnt*anscnt%mo;
    solve(l,l+cnt1-1,dep-1);solve(l+cnt1,r,dep-1);
}

int main(void)
{
    read(n);
    fo(i ,1 ,n)
    {
        read(a[i]);
    }
    solve(1 ,n ,30);
    printf("%lld\n%d\n",sum,anscnt);
    return 0;
}

 

posted @ 2017-12-23 14:25  Mnirvana  阅读(298)  评论(0编辑  收藏  举报