习题:冒泡排序(逆序对&权值线段树)

题目

传送门

思路

不好想到

定义t数组,其中\(t_i\)表示在\(1\le j<i\)中有多少个大于\(a_i\)\(a_j\)

那么逆序对就是\(\sum_{i=1}^n{t_i}\)

考虑一操作

交换\(a_i\)\(a_{i+1}\)

因为只是交换这两个数

所以除了这两个数之外的\(t_j\)不会发生任何变化

而维护\(t_i\)\(t_{i+1}\)只需要考虑\(a_i\)\(a_{i+1}\)的大小就可以了

再考虑二操作

注意到一个性质

一轮冒泡排序会让\(t_i=max(t_i-1,0)\)

其理由是在任何一轮的冒泡排序之中,我们我们最多只会让\(1\le j <i\)的一个\(a_j\)跑到\(a_i\)的右边

而如果要让\(a_j\)跑到\(a_i\)右边,那么一定有\(a_j>a_i\),所以此时\(t_i\)会减1

那k轮呢?

应当有\(t_i=max(t_i-k,0)\)

看上去有一个max很难维护

但其实用权值线段树就好了

最后的答案就是询问值域\(k+1\rightarrow n\)的和减去k乘上值域\(k+1\rightarrow n\)有多少个\(t_i\)

总时间复杂度为\(O(n*log_n)\)

代码

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<vector>
using namespace std;
struct node
{
    int l;
    int r;
    int tot;
    long long s;
}tre[1000005];
int n,m;
int a[200005];
int b[200005];
long long bit[200005];
int lowbit(int x)
{
    return x&(-x);
}
void update(int x,int val)
{
    for(int i=x;i<=n;i+=lowbit(i))
        bit[i]+=val;
}
long long solve_s(int x)
{
    long long ret=0;
    for(int i=x;i>0;i-=lowbit(i))
        ret+=bit[i];
    return ret;
}   
void build(int l,int r,int k)
{
    //cout<<k<<'\n';
    tre[k].l=l;
    tre[k].r=r;
    tre[k].s=0;
    if(l==r)
        return;
    int mid=(l+r)>>1;
    build(l,mid,k<<1);
    build(mid+1,r,k<<1|1);
}
void add(int a,int f,int k)
{
    if(tre[k].l>a||a>tre[k].r)
        return;
    if(tre[k].l==tre[k].r)
    {
        tre[k].tot+=f;
        tre[k].s=tre[k].s+tre[k].l*f;
        return;
    }
    add(a,f,k<<1);
    add(a,f,k<<1|1);
    tre[k].tot=tre[k<<1].tot+tre[k<<1|1].tot;
    tre[k].s=tre[k<<1].s+tre[k<<1|1].s;
}
long long ask(int l,int r,int k)
{
    if(l>r)
        return 0;
    if(l>tre[k].r||tre[k].l>r)
        return 0;
    if(l<=tre[k].l&&tre[k].r<=r)    
        return tre[k].s-1ll*tre[k].tot*(l-1);
    return ask(l,r,k<<1)+ask(l,r,k<<1|1);
}
int main()
{
    //freopen("bubble.in","r",stdin);
    //freopen("bubble.out","w",stdout); 
    scanf("%d %d",&n,&m);
    build(0,n,1);
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&a[i]);
        b[i]=i-1-solve_s(a[i]);
        update(a[i],1);
        add(b[i],1,1);
    }
    for(int i=1;i<=m;i++)
    {
        int t;
        long long c;
        scanf("%d %lld",&t,&c);
        if(t==1)
        {
            if(a[c]>a[c+1])
            {
                add(b[c],-1,1);
                add(b[c+1],-1,1);
                int x=b[c];
                int y=b[c+1];
                b[c]=y-1;
                b[c+1]=x;
                add(b[c],1,1);
                add(b[c+1],1,1);
                swap(a[c],a[c+1]);
            }
            else if(a[c]<a[c+1])
            {
                add(b[c],-1,1);
                add(b[c+1],-1,1);
                int x=b[c];
                int y=b[c+1];
                b[c]=y;
                b[c+1]=x+1;
                add(b[c],1,1);
                add(b[c+1],1,1);
                swap(a[c],a[c+1]);
            }
        }
        else
        {
            printf("%lld\n",ask(c+1,n,1));
        }
    }
    return 0;
}


posted @ 2020-03-07 22:57  loney_s  阅读(333)  评论(0)    收藏  举报