【BZOJ2244】拦截导弹(SDOI2011)-DP+CDQ分治

测试地址:拦截导弹
做法:本题需要用到DP+CDQ分治。
很容易想到,先求出最长不上升子序列的数量,再对每个点进行判断,如果该点可能在最长不上升子序列中,就用包含它的最长不上升子序列数量除以方案总数得到概率。那么我们要求的就是以某枚导弹开头或者结尾的最长不上升子序列长度及数量。下面先讨论从前往后的方向。
f(i)为最后选第i枚导弹能得到的最长不上升序列长度,很容易写出状态转移方程:
f(i)=max{f(j)+1|j<i,hjhi,vjvi}
我们可以把下标,高度,速度作为一个点的三个信息,这显然就是一个三维偏序问题。
按照口诀:一维排序,二维分治,三维数据结构,得到如下算法:
首先,题目的输入已经按照时间顺序排好了,所以第一步完成。
接下来,CDQ分治,先递归处理左半,处理左半对右半的影响,再递归处理右半。
那么难点就在于处理左半对右半的影响。注意到,左半区间的时间一定比右半区间的时间小了,因此不用考虑时间,而高度的限制可以用两个指针来处理(即右边指针先移动,然后左边指针一直移动到不合法位置,顺便插入新信息),于是就只剩下速度的限制了。我们用树状数组维护速度在一个后缀区间上的导弹的f最大值以及对应的方案数(至于如何用树状数组维护后缀……把两个函数的循环变量变化方向取反就好了,这个我是参照ZJOI2017-树状数组的写法写的),就能求出左半区间对右半区间状态转移的影响了。
至于方向相反的求法,只要把下标,高度,速度都取反,然后按照上面的方法再做一遍即可。注意高度和速度要离散化。那么我们就可以计算最后的答案了,对于某枚导弹,它可能在最长不上升子序列中当且仅当以它开头的最长不上升子序列长度,加上以它结尾的最长不上升子序列长度,再减去1(因为它自己重复计算了),这个数和答案相等。而包含它的最长不上升子序列数就是以它开头的最长不上升子序列数,乘上以它结尾的最长不上升子序列数。这样我们就能轻易地计算概率了。于是我们就以O(nlog2n)的时间复杂度解决了这个问题。
以下是本人代码:

#include <bits/stdc++.h>
using namespace std;
int n,maxh,maxv;
int mx[50010]={0},ff[50010]={0},gf[50010]={0},ans=0,changed[50010],top;
double sum[50010]={0},fg[50010]={0},gg[50010]={0},tot=0.0;
struct Q
{
    int id,h,v;
}q[50010];
bool cmpid(Q a,Q b) {return a.id<b.id;}
bool cmph(Q a,Q b) {return a.h<b.h;}
bool cmpv(Q a,Q b) {return a.v<b.v;}

int lowbit(int x)
{
    return x&(-x);
}

void insert(int pos,int f,double g)
{
    changed[++top]=pos;
    for(int i=pos;i;i-=lowbit(i))
    {
        if (f>mx[i]) mx[i]=f,sum[i]=g;
        else if (f==mx[i]) sum[i]+=g;
    }
}

void query(int pos,int &f,double &g)
{
    f=0,g=0.0;
    for(int i=pos;i<=maxv;i+=lowbit(i))
    {
        if (mx[i]>f) f=mx[i],g=sum[i];
        else if (mx[i]==f) g+=sum[i];
    }
}

void clear(int pos)
{
    for(int i=pos;i;i-=lowbit(i))
        mx[i]=sum[i]=0;
}

void solve(int l,int r,int *f,double *g)
{
    if (l==r)
    {
        f[l]++;
        if (f[l]==1) g[l]=1.0;
        return;
    }

    int mid=(l+r)>>1;
    solve(l,mid,f,g);

    sort(q+l,q+mid+1,cmph);
    sort(q+mid+1,q+r+1,cmph);
    int now=mid;
    top=0;
    for(int i=r;i>mid;i--)
    {
        while(now>=l&&q[now].h>=q[i].h)
        {
            insert(q[now].v,f[q[now].id],g[q[now].id]);
            now--;
        }
        int nxtf;
        double nxtg;
        query(q[i].v,nxtf,nxtg);
        if (nxtf>f[q[i].id]) f[q[i].id]=nxtf,g[q[i].id]=nxtg;
        else if (nxtf==f[q[i].id]) g[q[i].id]+=nxtg;
    }
    for(int i=1;i<=top;i++)
        clear(changed[i]);

    sort(q+l,q+r+1,cmpid);
    solve(mid+1,r,f,g);
}

int main()
{
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
    {
        scanf("%d%d",&q[i].h,&q[i].v);
        q[i].id=i;
    }

    sort(q+1,q+n+1,cmph);
    maxh=0;
    for(int i=1;i<=n;i++)
    {
        if (i==1||q[i].h!=q[i-1].h) q[i-1].h=maxh,maxh++;
        else q[i-1].h=maxh;
    }
    q[n].h=maxh;

    sort(q+1,q+n+1,cmpv);
    maxv=0;
    for(int i=1;i<=n;i++)
    {
        if (i==1||q[i].v!=q[i-1].v) q[i-1].v=maxv,maxv++;
        else q[i-1].v=maxv;
    }
    q[n].v=maxv;

    sort(q+1,q+n+1,cmpid);
    solve(1,n,ff,fg);
    for(int i=1;i<=n;i++)
        ans=max(ans,ff[i]);
    for(int i=1;i<=n;i++)
        if (ff[i]==ans) tot+=fg[i];

    for(int i=1;i<=n;i++)
    {
        q[i].id=n-q[i].id+1;
        q[i].h=maxh-q[i].h+1;
        q[i].v=maxv-q[i].v+1;
    }
    sort(q+1,q+n+1,cmpid);
    solve(1,n,gf,gg);

    printf("%d\n",ans);
    for(int i=1;i<n-i+1;i++)
        swap(gf[i],gf[n-i+1]),swap(gg[i],gg[n-i+1]);
    for(int i=1;i<=n;i++)
    {
        if (ff[i]+gf[i]-1==ans)
            printf("%.6lf ",fg[i]*gg[i]/tot);
        else printf("0.000000 ");
    }

    return 0;
}
posted @ 2018-05-21 21:35  Maxwei_wzj  阅读(109)  评论(0编辑  收藏  举报