bzoj 1112 poi 2008 砖块

这滞胀题调了两天了...

好愚蠢的错误啊...

其实这道题思维比较简单,就是利用treap进行维护(有人说线段树好写,表示treap真心很模板)

就是枚举所有长度为k的区间,查出中位数,计算代价即可。

(根据绝对值不等式的几何意义,中位数一定是最优解)

而维护长度为k的区间也很简单,就是首先把前k个扔到树上,然后每次把新来的插入,把最前面的一个删除即可

至于求中位数,简直就是基础操作嘛

关键在于...代价怎么算?

显然我们不能把所有数枚举出来挨个加减,这样会T飞的...

所以我们考虑直接在treap上维护,根据treap很重要的性质:左树<根<右树

那么我们对每个节点,维护一个子树权值和,这样就可以做到在查询中位数的同时查出小于中位数的数之和和大于中位数的数之和了

注意一个小细节,就是在查询的时候,要把重复出现的中位数分左右放到左右的和里,否则计算会有bug

剩下的就是模板了

不要像我一样,插点不修改树的大小,输出全是负数...

贴代码(巨丑)

#include <cstdio>
#include <cmath>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#include <queue>
#include <stack>
#define ls tree[rt].lson
#define rs tree[rt].rson
#define ll long long
using namespace std;
struct Treap
{
    int lson;
    int rson;
    int huge;
    int same;
    ll val;
    int rank;
    ll sum;
}tree[100005];
int a[100005];
int tot=0;
int n,k,mid;
ll s[100005];
int rot=0;
inline int read()
{
    int f=1,x=0;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
void update(int rt)
{
    tree[rt].huge=tree[ls].huge+tree[rs].huge+tree[rt].same;
    tree[rt].sum=(ll)tree[ls].sum+(ll)tree[rs].sum+(ll)tree[rt].same*(ll)tree[rt].val;
}
void lturn(int &rt)
{
    int temp=rs;
    rs=tree[rs].lson;
    tree[temp].huge=tree[rt].huge;
    tree[temp].sum=tree[rt].sum;
    tree[temp].lson=rt;
    update(rt);
    rt=temp;
}
void rturn(int &rt)
{
    int temp=ls;
    ls=tree[ls].rson;
    tree[temp].huge=tree[rt].huge;
    tree[temp].rson=rt;
    tree[temp].sum=tree[rt].sum;
    update(rt);
    rt=temp;
}
void ins(int &rt,ll v)
{
    if(!rt)
    {
        rt=++tot;
        tree[rt].huge=1;
        tree[rt].same=1;
        tree[rt].val=v;
        tree[rt].rank=rand();
        tree[rt].sum=v;
        return;
    }
    tree[rt].sum+=v;
    tree[rt].huge++;
    if(tree[rt].val==v)
    {
        tree[rt].same++;
        return;
    }
    if(tree[rt].val>v)
    {
        ins(ls,v);
        if(tree[ls].rank<tree[rt].rank)
        {
            rturn(rt);
        }
    }else
    {
        ins(rs,v);
        if(tree[rs].rank<tree[rt].rank)
        {
            lturn(rt);
        }
    }
}
void del(int &rt,ll v)
{
    if(!rt)
    {
        return;
    }
    if(tree[rt].val==v)
    {
        if(tree[rt].same>1)
        {
            tree[rt].huge--;
            tree[rt].same--;
            tree[rt].sum-=(ll)v;
            return;
        }else if(ls*rs==0)
        {
            rt=ls+rs;
            return;
        }else
        {
            if(tree[ls].rank<tree[rs].rank)
            {
                rturn(rt);
                del(rt,v);
            }else
            {
                lturn(rt);
                del(rt,v);
            }
        }
        return;
    }
    tree[rt].huge--;
    tree[rt].sum-=v;
    if(tree[rt].val>v)
    {
        del(ls,v);
    }else
    {
        del(rs,v);
    }
    update(rt);
}
ll Lsum,Rsum;
int tt;
int query_num(int rt,int v)
{
    if(!rt)
    {
        return 0;
    }
    if(tree[ls].huge>=v)
    {
        Rsum+=(ll)tree[rs].sum+(ll)tree[rt].same*(ll)tree[rt].val;
        return query_num(ls,v);
    }else if(tree[ls].huge+tree[rt].same<v)
    {
        Lsum+=(ll)tree[ls].sum+(ll)tree[rt].val*(ll)tree[rt].same;
        return query_num(rs,v-tree[ls].huge-tree[rt].same);
    }else
    {
        Lsum+=(ll)tree[ls].sum+(ll)(v-tree[ls].huge-1)*(ll)tree[rt].val;
        Rsum+=(ll)tree[rs].sum+(ll)(tree[ls].huge+tree[rt].same-v)*(ll)tree[rt].val;
        return tree[rt].val;
    }
}
int main()
{
    n=read(),k=read();
    mid=(k+1)/2;
    for(int i=1;i<=n;i++)
    {
        a[i]=read();
    }
    for(int i=1;i<=k;i++)
    {
        ins(rot,a[i]);
    }
    int lret=1,rret=k;
    int v0=query_num(rot,mid);
    int ret=v0;
    ll co=(ll)(mid-1)*(ll)v0-Lsum+Rsum-(ll)(k-mid)*(ll)v0;
    for(int i=k+1;i<=n;i++)
    {
        int st=i-k+1;
        del(rot,a[st-1]);
        ins(rot,a[i]);
        tt=0,Lsum=0,Rsum=0;
        int v1=query_num(rot,mid);
        ll temp=(ll)(mid-1)*(ll)v1-Lsum+Rsum-(ll)(k-mid)*(ll)v1;
        if(co>temp)
        {
            co=temp;
            lret=st;
            rret=i;
            ret=v1;
        }
    }
    printf("%lld\n",co);
    for(int i=1;i<=n;i++)
    {
        if(i<lret||i>rret)
        {
            printf("%d\n",a[i]);
        }else
        {
            printf("%d\n",ret);
        }
    }
    return 0;
}

 

posted @ 2018-09-15 14:14  lleozhang  Views(134)  Comments(0Edit  收藏  举报
levels of contents