AtCoder Beginner Contest 257 Ex
这么难的题出在ABC真的好吗?
首先是一波推式子,假设我们选的骰子集合为\(D=\{i1,i2,...,ik\}\)(这里\(D\)储存的是骰子的下标),那么我们不难列出下面的式子:
考虑前半部分很烦,我们可以拆掉平方,即利用完全平方公式
可以得到:
我们发现,对于\(i\in D,1\leq j\leq 6\),\(A_{i,j}^2\)的贡献是\(6^{k-1}\)(相当于固定了这一位,其它\((k-1)\)个骰子可以在\(1\)到\(6\)中任意选择),而\(A_{i,j}A_{k,l}\)的贡献是\(6^{k-2}\)(相当于固定这\(2\)位,其它\((k-1)\)个骰子任意选择)
所以,我们就可以愉快的把最外面,也是最难枚举的\(\sum_{j1,j2,...,jk}\)给去除了!
将\(\frac{1}{6^k}\)乘进去:
我们发现:
(相当于两两组合,但是每一对\((i,j)\)要被计算两遍:在\((i,j)\)和\((j,i)\)时都会被计算,故要除以\(2\),并且还要减去自己和自己组合的情况)
那么我们设\(S_{i}=\sum_{j=1}^{6} A_{i,j}\),\(P_{i}=\sum_{j=1}^{6} A_{i,j}^2\),可得:
设\(T_{i}=6P_{i} - {S_{i}}^2 - 36C_{i}\)。
则:
\(S_{i},T_{i}\)可以在输入时就计算好。那么现在我们就将原问题转成如下的问题:
给你\(n\)个数对\((x_{i},y_{i})\),让你选出\(k\)个数对,使得\((\sum x_{i})^2+\sum y_{i}\)最大。
此时我们可以得到一个非常重要的结论:
若我们选择集合\(D\)最优,那么一定存在一个整数\(c\),满足对于任意\(i\in D,j\notin D\),\(cx_{i}+y_{i}\geq cx_{j}+y_{j}\)。
证明:
咕咕咕(马上补)
那么此时我们就可以枚举整数\(c\)从\(-\inf\)到\(\inf\),找到当前\(cx_{i}+y_{i}\)最大的\(k\)个,将\((\sum x)^2+\sum y\)更新到答案。
此时时间复杂度为\(O(n\log n\times \inf)\),虽然不能AC,但是已经不是一个指数级别的做法了。我们已经有了很大的进步。
其实还可以优化。我们利用类似离散化的思想:如果\(c\)时刻的\(cx_{i}+y_{i}\)的数列和\((c+1)\)时刻的不变,那么我们没有必要枚举\((c+1)\)时刻。所以,数列只会在从某对\((i,j)\)从\(cx_{i}+y_{i} < cx_{j}+y_{j}\)变成\(cx_{i}+y_{i} \geq cx_{j}+y_{j}\)时改变。即\(c=\lceil \frac{y_{j}-y_{i}}{x_{i}-x_{j}}\rceil\)时改变。那么我们把这\(O(n^2)\)个时刻存储下来,枚举\(c\)等于它们的时候,以及最初(\(c=-\inf\))时即可。时间复杂度为\(O(n^3\log n)\)。
继续优化!我们发现上面做法慢的原因在于每次枚举新的\(c\)时,总是要对数组重新排序。如果我们可以利用单调性维护排序数组呢?我们发现,当\(c=\lceil \frac{y_{j}-y_{i}}{x_{i}-x_{j}}\rceil\)时,我们只改变\(i,j\)的大小关系,也就是说对于之前的数组,我们只要交换\(i,j\)就可以得到现在的数组。故我们就可以先排序,之后每枚举一个\(c\),交换对应的下标即可,时间复杂度为\(O(n^2\log n)\)。
代码如下:
#include<bits/stdc++.h>
#define debug(...) std::cerr<<#__VA_ARGS__<<" : "<<__VA_ARGS__<<std::endl
using ll=long long;
const int maxn=1005;
const ll inf=4e18;
int n,k;
int c[maxn],a[maxn][10],id[maxn];
ll s[maxn],t[maxn],p[maxn];
ll calc(int ind) {
int x=id[ind],y=id[ind+1];
if(s[x]>=s[y]) return inf;
return (t[x]-t[y]+s[y]-s[x]-1ll)/(s[y]-s[x]);
}
int main() {
scanf("%d%d",&n,&k);
for(int i=1;i<=n;i++) scanf("%d",&c[i]);
for(int i=1;i<=n;i++)
for(int j=1;j<=6;j++)
scanf("%d",&a[i][j]);
ll Ss=0,St=0,ans=-inf;
for(int i=1;i<=n;i++) {
for(int j=1;j<=6;j++) {
s[i]+=1ll*a[i][j];
p[i]+=1ll*a[i][j]*a[i][j];
}
t[i]=6ll*p[i]-s[i]*s[i]-36ll*c[i];
id[i]=i;
}
std::sort(id+1,id+n+1,[](int id1,int id2){return t[id1]>t[id2];});
for(int i=1;i<=k;i++) Ss+=s[id[i]],St+=t[id[i]];
ans=std::max(ans,Ss*Ss+St);
std::set<std::pair<long long,int>> S;
for(int i=1;i<n;i++) S.insert({calc(i),i});
while(!S.empty()) {
std::pair<long long,int> rec=*S.begin();
if(rec.first==inf) break;
S.erase({calc(rec.second),rec.second});
if(rec.second>=2) {
S.erase({calc(rec.second-1),rec.second-1});
}
if(rec.second<=n-2) {
S.erase({calc(rec.second+1),rec.second+1});
}
if(rec.second==k) {
Ss-=s[id[rec.second]]; St-=t[id[rec.second]];
}
std::swap(id[rec.second],id[rec.second+1]);
S.insert({calc(rec.second),rec.second});
if(rec.second>=2) {
S.insert({calc(rec.second-1),rec.second-1});
}
if(rec.second<=n-2) {
S.insert({calc(rec.second+1),rec.second+1});
}
if(rec.second==k) {
Ss+=s[id[rec.second]]; St+=t[id[rec.second]];
}
ans=std::max(ans,Ss*Ss+St);
}
//注意ans还有可能为负数
ans=(ans%998244353+998244353)%998244353;
printf("%lld\n",ans*859599304ll%998244353ll);
return 0;
}
浙公网安备 33010602011771号