qoj2571 Aidana and Pita
题意
给出 \(n\) 个数 \(a_i\),你需要把这些数分成三组,使得每组之和的极差最小,给出方案。
\(3\le n\le 25,1\le a_i\le 10^7\)。
思路
\(n\le 25\),考虑折半搜索。
发现 \(3^{13}\approx 10^6\),所以复杂度是对的。
于是问题就变为给出两个三元组集合 \((a_i,b_i,c_i),(a'_j,b'_j,c'_j)\),要求从两组集合中分别选出一个三元组使得 \((a_i+a'_j,b_i+b'_j,c_i+c'_j)\) 的极差最小。
容易发现只需要统计所有 \(a_i+a'_j\ge b_i+b'_j\ge c_i+c'_j\) 的集合,此时的极差为 \(a_i-c_i+a'_j-c'_j\)。
令 \(x_i=a_i-b_i,y_i=b_i-c_i,x'_j=a'_j-b'_j,y'_j=b'_j-c'_j\),题目转换为给出两个二元组集合 \((x_i,y_i),(x'_j,y'_j)\),从两组集合中分别选出一个二元组满足 \(x_i\ge -x'_j,y_i\ge -y'_j\),且 \((x_i+y_i+x'_j+y'_j)\) 最小。
剩下的就简单了。将 \(x'_j,y'_j\) 取反,按照 \(x\) 为关键字将 \((x_i,y_i)\) 排序并遍历,找出最大的 \(x'_j+y'_j\) 满足 \(x_i\ge x'_j,y_i\ge y'_j\) 即可,可以离散化后使用树状数组(不离散化应该也可以),此时的最小值即为 \(x_i+y_i-x'_j-y'_j\)。
代码
#include <bits/stdc++.h>
using namespace std;
#define ll long long
struct pi{
int x,y,id;
};
int n,a[30],tc[30],c[2200005],ci[2200005],to[2200005],mn=INT_MAX,i1,i2;
vector<pi> p1,p2;
vector<int> lsh,ans;
void add(int x,int v,int id){
while(x<=lsh.size()+1){
if(v>c[x])
c[x]=v,ci[x]=id;
x+=(x&-x);
}
}
pair<int,int> qu(int x){
int res=INT_MIN,id=0;
while(x){
if(c[x]>res)
res=c[x],id=ci[x];
x-=(x&-x);
}
return {res,id};
}
bool cmp(pi x,pi y){
return x.x<y.x;
}
void dfs(vector<pi> &p,int l,int r,pi nw,int x){
if(x==r+1){
p.push_back(nw);
return;
}
dfs(p,l,r,{nw.x+a[x],nw.y,nw.id},x+1);
dfs(p,l,r,{nw.x-a[x],nw.y+a[x],nw.id+tc[x-l]},x+1);
dfs(p,l,r,{nw.x,nw.y-a[x],nw.id+2*tc[x-l]},x+1);
}
signed main(){
ios::sync_with_stdio(false);
cin.tie(nullptr),cout.tie(nullptr);
memset(c,-0x3f,sizeof(c));
tc[0]=1;
for(int i=1;i<=13;i++)
tc[i]=tc[i-1]*3;
cin>>n;
for(int i=1;i<=n;i++)
cin>>a[i];
dfs(p1,1,n/2,{0,0,0},1);
dfs(p2,n/2+1,n,{0,0,0},n/2+1);
for(pi &v:p2)
v.x*=-1,v.y*=-1;
for(pi v:p1)
lsh.push_back(v.y);
for(pi v:p2)
lsh.push_back(v.y);
sort(lsh.begin(),lsh.end());
int len=unique(lsh.begin(),lsh.end())-lsh.begin();
for(pi &v:p1){
int x=lower_bound(lsh.begin(),lsh.begin()+len,v.y)-lsh.begin()+1;
to[x]=v.y,v.y=x;
}
for(pi &v:p2){
int x=lower_bound(lsh.begin(),lsh.begin()+len,v.y)-lsh.begin()+1;
to[x]=v.y,v.y=x;
}
sort(p1.begin(),p1.end(),cmp);
sort(p2.begin(),p2.end(),cmp);
int l=0;
for(pi v:p1){
while(l<p2.size()&&p2[l].x<=v.x)
add(p2[l].y,p2[l].x+to[p2[l].y],p2[l].id),l++;
pair<int,int> tmp=qu(v.y);
int res=v.x+to[v.y]-tmp.first;
if(res<mn)
mn=res,i1=v.id,i2=tmp.second;
}
for(int i=1;i<=n;i++){
if(i<=n/2) cout<<i1%3+1<<" ",i1/=3;
else cout<<i2%3+1<<" ",i2/=3;
}
return 0;
}

浙公网安备 33010602011771号