AtCoder Beginner Contest 266 G,Ex
G
考虑先放G和B,此时共有\(C_{G+B}^{B}\)种方案。
然后选出\(k\)个G,在前面放上\(R\),共有\(C_{G}^{k}\)种方案。
最后我们放剩下的\(R-K\)个R,考虑目前哪些区间内部可以放一段连续的\(R\)。可以发现,单独G的后面,以及B的前后,RG的前后是可以放的,总共是\(B-k+1\)个区间内可以放\(R\)。那么此时就是一个经典的问题——给你\(N\)个桶和\(M\)个球,每个桶可以放任意多个球,球是无标号的,求总方案数。用隔板法发现答案就是\(C^{N-1}_{N+M-1}\)。
最后将这些乘起来就可以得到总方案数了。
#include<bits/stdc++.h>
#define debug(...) std::cerr<<#__VA_ARGS__<<" : "<<__VA_ARGS__<<std::endl
const int maxn=3000005,mod=998244353;
int r,g,b,k;
int fac[maxn];
int qpow(int x,int y) {
if(y==0) return 1;
int ret=qpow(x,y>>1);
ret=1ll*ret*ret%mod;
if(y&1) ret=1ll*ret*x%mod;
return ret;
}
int C(int x,int y) {
return 1ll*fac[x]*qpow(fac[y],mod-2)%mod*qpow(fac[x-y],mod-2)%mod;
}
int D(int x,int y) {
return C(x+y-1,x-1);
}
int main() {
fac[0]=1; for(int i=1;i<maxn;i++) fac[i]=1ll*fac[i-1]*i%mod;
scanf("%d%d%d%d",&r,&g,&b,&k);
printf("%d\n",(int)(1ll*C(g+b,b)*C(g,k)%mod*D(k+b+1,r-k)%mod));
return 0;
}
H
我们发现很多坐标是没用的,准确的说,是只有出发点\((0,0)\)、每个snuke出现的点\((x_{i},y_{i})\)只在\(t_{i}\)时刻是有用的。
那么设\(dp(i)\)表示到了第\(i\)个点,且目前时间是\(t_{i}\),我们得到最大Size的和。
不妨设\(dp(0)\)表示在初始点,并且\(x_{0}=y_{0}=t_{0}=a_{0}=0\)。
那么\(dp(j)\)能转移到\(dp(i)\),当且仅当\(y_{i}\geq y_{j},y_{i}-y_{j}+|x_{i}-x_{j}|\leq t_{i}-t_{j}\)。
转移方程式就是\(dp(i)=\max_{y_{i}\geq y_{j},y_{i}-y_{j}+|x_{i}-x_{j}|\leq t_{i}-t_{j}}(dp(j))+a_{i}\)
考虑优化:发现\(y_{i}-y_{j}+|x_{i}-x_{j}|\leq t_{i}-t_{j}\)这个条件有绝对值符号,我们想要去掉。最简单的想法就是分类讨论。不过我们这里没用必要,因为\(x_{i}-x_{j}<0\)时,\(y_{i}-y_{j}+x_{i}-x_{j}\)没有\(y_{i}-y_{j}-x_{j}+x_{i}\)大。故我们可以将其拆成两个限制:\(y_{i}-y_{j}+x_{i}-x_{j}\leq t_{i}-t_{j}\)和\(y_{i}-y_{j}+x_{j}-x_{i}\leq t_{i}-t_{j}\)。
再简化,可以得到:\(t_{j}-x_{j}-y_{j}\leq t_{i}-x_{i}-y_{i}\),\(t_{j}+x_{j}-y_{j}\leq t_{i}+x_{i}-y_{i}\)。我们设\(t1_{i}=y_{i},t2_{i}=t_{i}-x_{i}-y_{i},t3_{i}=t_{i}+x_{i}-y_{i}\),那么原来的转移就能写成:
也就是一个三维偏序的转移,可以使用CDQ分治解决。
时间复杂度为\(O(n\log^2 n)\)。
#include<bits/stdc++.h>
#define debug(...) std::cerr<<#__VA_ARGS__<<" : "<<__VA_ARGS__<<std::endl
using ll=long long;
const int maxn=100005;
int n,t[maxn],x[maxn],y[maxn],a[maxn];
ll ans,dp[maxn],val[maxn*4];
std::vector<std::array<int,4>> vec;
bool cmp(int id1,int id2) {
// return vec[id1][1]<vec[id2][1]; 这样写错3个点
// return vec[id1][1]==vec[id2][1]?vec[id1][2]<vec[id2][2]:vec[id1][1]<vec[id2][1]; 这样写错2个点
// return vec[id1][1]==vec[id2][1]?id1<id2:vec[id1][1]<vec[id2][1]; 这样写可以AC
return vec[id1][1]==vec[id2][1]?(vec[id1][2]==vec[id2][2]?id1<id2:vec[id1][2]<vec[id2][2]):vec[id1][1]<vec[id2][1];
//这样写也可以AC
}
#define lowbit(x) (x&-x)
void upd(int pos,ll num) {
for(int i=pos;i<maxn*4;i+=lowbit(i)) {
val[i]=std::max(val[i],num);
}
}
void updt(int pos) {
for(int i=pos;i<maxn*4;i+=lowbit(i)) {
val[i]=-1e18;
}
}
ll qry(int pos) {
ll ret=-1e18;
for(int i=pos;i;i-=lowbit(i)) {
ret=std::max(ret,val[i]);
}
return ret;
}
void divide(int l,int r) {
if(l==r) return;
int mid=l+r>>1;
divide(l,mid);
std::vector<int> ids;
for(int i=l;i<=r;i++) {
ids.push_back(i);
}
std::sort(ids.begin(),ids.end(),cmp);
for(auto id : ids) {
if(id<=mid) {
upd(vec[id][2]+1,dp[id]);
} else {
dp[id]=std::max(dp[id],qry(vec[id][2]+1)+vec[id][3]);
}
}
for(auto id : ids) {
if(id<=mid) {
updt(vec[id][2]+1);
}
}
divide(mid+1,r);
}
int main() {
scanf("%d",&n);
std::vector<int> nums;
for(int i=1;i<=n;i++) {
scanf("%d%d%d%d",&t[i],&x[i],&y[i],&a[i]);
if(x[i]+y[i]<=t[i]) {
nums.push_back(y[i]);
nums.push_back(t[i]-x[i]-y[i]);
nums.push_back(t[i]+x[i]-y[i]);
}
}
nums.push_back(0);
std::sort(nums.begin(),nums.end());
nums.erase(std::unique(nums.begin(),nums.end()),nums.end());
for(int i=1;i<=n;i++) {
if(x[i]+y[i]<=t[i]) {
int y_=std::lower_bound(nums.begin(),nums.end(),y[i])-nums.begin();
int t1=std::lower_bound(nums.begin(),nums.end(),t[i]-x[i]-y[i])-nums.begin();
int t2=std::lower_bound(nums.begin(),nums.end(),t[i]+x[i]-y[i])-nums.begin();
vec.push_back({y_,t1,t2,a[i]});
}
}
std::sort(vec.begin(),vec.end());
vec.insert(vec.begin(),{0,0,0,0});
for(int i=1;i<(int)vec.size();i++) dp[i]=-1e18;
for(int i=0;i<maxn*4;i++) val[i]=-1e18;
divide(0,(int)vec.size()-1);
for(int i=0;i<(int)vec.size();i++) ans=std::max(ans,dp[i]);
/* for(int i=1;i<(int)vec.size();i++) dp[i]=-1e18;
for(int i=1;i<(int)vec.size();i++) {
for(int j=0;j<i;j++) {
if(vec[j][0]<=vec[i][0]&&vec[j][1]<=vec[i][1]&&vec[j][2]<=vec[i][2]){
dp[i]=std::max(dp[i],dp[j]+vec[i][3]);
}
}
ans=std::max(ans,dp[i]);
}
//O(n^2)的解法
*/
printf("%lld\n",ans);
return 0;
}
浙公网安备 33010602011771号