给定M个长度为N的序列,从每个序列中任取一个数求和,可以构成NM个和。求其中最小的N个和。

先考虑M=2的一般情形,从2个序列中任取一个数构成的N2个和中求出前N小的和。

设两个数列A,B,把他们都排序。可以知道最小的一定是A[0]+B[0],次小和是min(A[0]+B[1],A[1]+B[0]),假设次小和是A[1]+B[0],那么第三小就是min(A[0]+B[1],A[1]+B[1],A[2]+B[0])

也就是说,当确定了A[i]+B[j]是第k小和后,A[i]+B[j+1]A[i+1]+B[j]就加入了第k+1小和的竞争者中。

需要注意的是,A[1]+B[2]和A[2]+B[1]都能产生A[2]+B[2]这个答案。为避免重复,可以规定在把j+1加入备选答案,以后只能增加j,不能再增加i。以避免和I+1相重复。

(i+1,j)=>(i+2,j) or (i+1,j+1),

(i,j+1)=>(i,j+2)

那么需要建立一个小根堆,堆中存储(i,j,last)这个三元祖,其中last表示上一次移动的是不是j,堆的比较以A[I]+B[J]作为权值。

起初只有(i,j,last),然后把(i,j+1,true)插入堆,如果last是false,再把(i+1,j,false)插入堆。

重复N次,每次取出堆顶节点的权值一起构成前N小和。算法复杂度是O(NlogN)

根据数学归纳法,可以知道先把前2行构成的前N小和作为新的序列,再与第三个序列构成新的前N小和。最终就能得到M个序列构成的前N小和。整个复杂度是O(MNlogN)。

 

代码如下:

#include <iostream>
#include <stdio.h>
#include <string.h>
#include <queue>
#include <algorithm>
using namespace std;

int a[2333][2333];
int m,n;
int ans[2333],anss[2333],p;
struct node 
{
    int i,j;
    int x,y;
    bool last;
    node(){}
    node(int a,int b,int _x,int _y,bool c):x(_x),y(_y),i(a),j(b),last(c){}
    bool operator < (const node &b)const {
        return x+y>b.x+b.y;
    }
};
priority_queue<node,vector<node> > que;
int main() {
    int t;scanf("%d",&t);
    while(t--){
        scanf("%d%d",&m,&n);
        for(int i=0;i<m;++i)
            for(int j=0;j<n;++j)
                scanf("%d",&a[i][j]);
        
        for(int i=0;i<m;++i)
            sort(a[i],a[i]+n);
        for(int i=0;i<n;++i)
            ans[i]=a[0][i];

        for(int i=1;i<m;++i){
            while(que.size())que.pop();
            que.push(node(0,0,ans[0],a[1][0],false));
            for(int j=0;j<n;++j){
                node now=que.top();
                que.pop();
                anss[p++]=ans[now.i]+a[i][now.j];
                //printf("%d-%d=>%d\n",now.i,now.j,anss[p-1]);
                if(now.last==false)
                    que.push(node(now.i+1,now.j,ans[now.i+1],a[i][now.j],false));
                que.push(node(now.i,now.j+1,ans[now.i],a[i][now.j+1],true));
            }
            for(int i=0;i<n;++i)
                ans[i]=anss[i];
            
            p=0;
        }
        printf("%d",ans[0]);
        for(int i=1;i<n;++i)
            printf(" %d",ans[i]);
        puts("");
    }
    return 0;
}

 

posted on 2018-04-28 15:33  chagin  阅读(127)  评论(0编辑  收藏  举报