#317. 01背包
题面
分析
分治背包模板题,只需对一边的01背包做前缀和,对于每个询问时间复杂度为\(O(m)\)
#include<bits/stdc++.h>
#define ls(p) p<<1
#define rs(p) p<<1|1
#define ll long long
const int P=998244353;
using namespace std;
const int N=1e5+5,M=505,K=2e4+5;;
struct A{
int i,x,y,k;
};
vector<A>V[N];
struct B {
int sum,num;
}t[M],ans[N],f[K][M],g[K][M];
int n,a[N],b[N];
void dfs(int p,int l,int r) {
if(l>r) return;
if(l==r) {
for(auto v:V[p]) {
ans[v.i]=(v.k>=b[l])?(B){a[l],1}:(B){0,0};
}
return;
}
int mid=l+r>>1;
for(int i=1;i<=500;i++) {
t[i]=(B){0,0};
}
t[0]=(B){0,1};
for(int i=mid;i>=l;i--) {
for(int j=500;j>=b[i];j--) {
if(t[j].sum<t[j-b[i]].sum+a[i]) {
t[j].sum=t[j-b[i]].sum+a[i];
t[j].num=t[j-b[i]].num;
} else if(t[j].sum==t[j-b[i]].sum+a[i]) {
t[j].num=(t[j].num+t[j-b[i]].num)%P;
}
}
f[i][0]=t[0];
for(int j=1;j<=500;j++) {
f[i][j]=t[j];
if(f[i][j].sum==f[i][j-1].sum) f[i][j].num=(f[i][j].num+f[i][j-1].num)%P;
else if(f[i][j].sum<f[i][j-1].sum) {
f[i][j].sum=f[i][j-1].sum;
f[i][j].num=t[j].num;
}
}
}
for(int i=1;i<=500;i++) {
t[i]=(B){0,0};
}
t[0]=(B){0,1};
for(int i=mid+1;i<=r;i++) {
for(int j=500;j>=b[i];j--) {
if(t[j].sum<t[j-b[i]].sum+a[i]) {
t[j].sum=t[j-b[i]].sum+a[i];
t[j].num=t[j-b[i]].num;
} else if(t[j].sum==t[j-b[i]].sum+a[i]) {
t[j].num=(t[j].num+t[j-b[i]].num)%P;
}
}
for(int j=0;j<=500;j++) {
g[i][j]=t[j];
}
}
for(auto v:V[p]) {
int x=v.x,y=v.y,m=v.k;
ans[v.i]=f[x][m];
for(int i=1;i<=m;i++) {
if(ans[v.i].sum<f[x][m-i].sum+g[y][i].sum) {
ans[v.i].sum=f[x][m-i].sum+g[y][i].sum;
ans[v.i].num=(ll)f[x][m-i].num*g[y][i].num%P;
} else if(ans[v.i].sum==f[x][m-i].sum+g[y][i].sum) {
ans[v.i].num=((ll)ans[v.i].num+(ll)f[x][m-i].num*g[y][i].num)%P;
}
}
if(ans[v.i].sum==0) ans[v.i].num=0;
}
dfs(ls(p),l,mid),dfs(rs(p),mid+1,r);
}
int main() {
freopen("knapsack.in","r",stdin);
freopen("knapsack.out","w",stdout);
scanf("%d",&n);
for(int i=1;i<=n;i++) {
scanf("%d%d",&a[i],&b[i]);
}
int q; scanf("%d",&q);
for(int i=1;i<=q;i++) {
int x,y,k; scanf("%d%d%d",&x,&y,&k);
int l=1,r=n,p=1;
while(l<=r) {
if(l==r) {
V[p].push_back((A){i,x,y,k});
break;
}
int mid=l+r>>1;
if(y<=mid) p=ls(p),r=mid;
else if(x>mid) p=rs(p),l=mid+1;
else {
V[p].push_back((A){i,x,y,k});
break;
}
}
}
dfs(1,1,n);
for(int i=1;i<=q;i++) {
printf("%d %d\n",ans[i].sum,ans[i].num);
}
return 0;
}