CF618F Double Knapsack

给定长度为 \(n\),值域为 \([1,n]\) 的整数序列 \(A\)\(B\)。你需要找到 \(A\) 的非空子序列 \([a_{p_1},a_{p_2},\dots,a_{p_x}]\)\(B\) 的非空子序列 \([b_{q_1},b_{q_2},\dots,b_{q_y}]\),使得 \(\sum\limits_{i=1}^{x} a_{p_i}=\sum\limits_{i=1}^{y} b_{q_i}\)。若无解返回 \(-1\)

\(1 \leq n \leq 10^6\)

首先我们证明如下命题:

一定存在 \(A\) 的非空子段 \([a_{l_a},a_{l_a+1},\dots,a_{r_a}]\)\(B\) 的非空子段 \([b_{l_b},b_{l_b+1},\dots,b_{r_b}]\),使得 \(\sum\limits_{i=l_a}^{r_a} a_i=\sum\limits_{i=l_b}^{r_b} b_i\)

考虑 \(A\)\(B\) 的前缀和数组 \(SA\)\(SB\)。容易发现 \(sa_0=sb_0\)。不妨令 \(sa_n \leq sb_n\),此时对于每个 \(i \in [0,n]\),必然存在 \(c_i\),使得 \(c_i\) 是满足 \(sa_i \geq sb_{c_i}\) 的最大下标。容易发现 \(c_i \in [0,n]\)。此时一定有 \(sa_i < sb_{c_i+1} = sb_{c_i} + b_{c_{i}+1}\),有 \(sa_i-sb_{c_i}<b_{c_i+1}\),又下标 \(i\)\(n+1\) 个,而 \(sa_i-sb_{c_i} \geq 0\)\(sa_i-sb_{c_i} < b_{c_i+1} \leq n\) 说明了不同的 \(sa_{i}-sb_{c_i}\) 只有 \(n\) 个,必然会有两个下标 \(i<j\) 满足 \(sa_i-sb_{c_i}=sa_j-sb_{c_j}\),即 \(sa_j-sa_i=sb_{c_j}-sb_{c_i}\),取 \(A\) 的子段 \([i+1,j]\)\(B\) 的子段 \([c_i+1,c_j]\) 即可。

使用双指针维护 \(c_i\),时间复杂度 \(O(n)\)

#include<iostream>
#include<cstdio>
using namespace std;
int a[1000010],b[1000010],c[1000010],lst[1000010];
long long sa[1000010],sb[1000010];
int main(){
	int n;
	scanf("%d",&n);
	for(int i=1;i<=n;i++){
		scanf("%d",&a[i]);
		sa[i]=sa[i-1]+a[i];
	}
	for(int i=1;i<=n;i++){
		scanf("%d",&b[i]);
		sb[i]=sb[i-1]+b[i];
	}
	bool rev=false;
	if(sa[n]>sb[n]){
		rev=true;
		for(int i=1;i<=n;i++){
			swap(a[i],b[i]);
			swap(sa[i],sb[i]);
		}
	}
	for(int i=0;i<n;i++){
		lst[i]=-1;
	}
	lst[0]=0;
	for(int i=1;i<=n;i++){
		c[i]=c[i-1];
		while(c[i]<n  &&  sb[c[i]+1]<=sa[i]){
			c[i]++;
		}
		int val=sa[i]-sb[c[i]];
		if(lst[val]!=-1){
			int la=lst[val]+1,ra=i;
			int lb=c[lst[val]]+1,rb=c[i];
			if(!rev){
				printf("%d\n",ra-la+1);
				for(int j=la;j<=ra;j++){
					printf("%d ",j);
				}
				printf("\n");
				printf("%d\n",rb-lb+1);
				for(int j=lb;j<=rb;j++){
					printf("%d ",j);
				}
				printf("\n");
			}
			else{
				printf("%d\n",rb-lb+1);
				for(int j=lb;j<=rb;j++){
					printf("%d ",j);
				}
				printf("\n");
				printf("%d\n",ra-la+1);
				for(int j=la;j<=ra;j++){
					printf("%d ",j);
				}
				printf("\n");
			}
			return 0;
		}
		lst[val]=i;
	}
	return 0;
}
posted @ 2025-12-11 14:57  Oken喵~  阅读(3)  评论(0)    收藏  举报