CF1366E 两个数组
1 CF1366E 两个数组
2 题目描述
时间限制 \(2s\) | 空间限制 \(256M\)
给出两个数组 \(𝑎_1,𝑎_2,…,𝑎_𝑛\) 和 \(𝑏_1,𝑏_2,…,𝑏_𝑚\)。数组 \(b\) 是按照升序排好序。你需要把数组 \(a\) 分成 \(m\) 个连续子序列,使得对于每个 \(i\) 从 \(1\) 到 \(m\),第 \(i\) 个字串的最小值和 \(b_i\) 相等。注意每个元素严格的属于某个子序列,他们是这么生成的:\(a\) 的前几个元素构成第一个子序列,后面跟着的几个元素生成第二个子序列,以此类推。例如,如果 \(𝑎=[12,10,20,20,25,30]\), \(𝑏=[10,20,30]\),那么 \(a\) 的两个好的分割如下:
- \([12,10,20],[20,25],[30]\);
- \([12,10],[20,20,25],[30]\)。
计算字符串 \(a\) 分割的方法的个数,结果对 \(998244353\) 取模。
数据范围:\(1≤𝑛,𝑚≤2\times 10^5\), \(1≤𝑎_𝑖,b_i≤10^9\), \(𝑏_𝑖<\)𝑏\(_𝑖\)\(_+\)\(_1\)。
3 题解
我们先来思考答案为 \(0\) 的情况。如果 \(n < m\),那我们必然无法在 \(a\) 中找到最小值等于 \(b\) 中元素的 $m $ 段区间。如果我们在 \(b_1\) 这个值在 \(a\) 数组出现之前出现了一个比 \(b_1\) 还小的数,那么我们无论如何也得把这个数并到第一个区间内,而第一个区间的最小值也就无论如何都不是 \(b_1\) 了。如果在 \(b_i\) 和 \(b_{i+1}\) 之间存在一个比 \(b_i\) 小的数,那么我们必定无法将第 \(i\) 个区间和第 \(i+1\) 个区间的最小值都处理好:\(b_i < b_{i+1}\),且必定有一个区间包含这个较小的数。同时,如果在 \(b\) 中的某个数并不在 \(a\) 中出现过,我们肯定无法将某个区间的最小值变为这个 \(b\) 数组中的数。上述的这些情况答案都为 \(0\)。
我们考虑 \(b_i\) 和 \(b_{i+1}\) 之间的可能方案。我们可以通过将 \(b_{i+1}\) 或者在 \(b_i\) 与 \(b_{i+1}\) 中出现过的与 \(b_{i+1}\) 相邻的大于 \(b_{i+1}\) 的数分给第 \(i+1\) 个区间或者第 \(i\) 个区间来增加方案数。我们最后将得到的 \(b_i\) 与 \(b_{i+1}\) 之间的可能方案数乘上之前所有可能方案数之积,得到的就是到前 \(i+1\) 个数中可能的总方案数。
4 代码(空格警告):
#include <iostream>
using namespace std;
const int N = 2e5+10;
const int mod = 998244353;
#define int long long
int n, m, pos, ans, cnt, cnt2, flag;
int a[N], b[N], s[N], len[N];
signed main()
{
cin >> n >> m;
if (n < m)
{
cout << 0;
return 0;
}
for (int i = 1; i <= n; i++) cin >> a[i];
for (int i = 1; i <= m; i++) cin >> b[i];
for (int i = 1; i <= n; i++)
{
if (a[i] < b[1])
{
cout << 0;
return 0;
}
if (a[i] == b[1])
{
pos = i;
break;
}
}
cnt = 2;
if (pos == 0)
{
cout << 0;
return 0;
}
for (int i = pos+1; i <= n; i++)
{
if (a[i] == b[cnt])
{
if (a[i+1] != b[cnt]) cnt++;
}
else
{
if (a[i] < b[cnt-1])
{
cout << 0;
return 0;
}
}
}
if (cnt != m+1)
{
cout << 0;
return 0;
}
ans = 1;
cnt = 1;
for (int i = 1; i <= n; i++)
{
if (a[i] == b[cnt] && !flag)
{
len[cnt]++;
flag = 1;
s[cnt] = i;
}
else if (a[i] == b[cnt])
{
len[cnt]++;
}
else
{
flag = 0;
if (a[i] == b[cnt+1])
{
cnt++;
len[cnt]++;
flag = 1;
s[cnt] = i;
}
}
}
cnt = 2;
cnt2 = 0;
for (int i = pos+1; i <= n; i++)
{
if (i == s[cnt])
{
ans = (ans * (cnt2 + len[cnt]) % mod) % mod;
i += len[cnt] - 1;
cnt2 = 0;
cnt++;
}
else
{
if (a[i] > b[cnt])
{
cnt2++;
}
if (a[i] < b[cnt])
{
cnt2 = 0;
}
}
}
cout << ans;
return 0;
}
欢迎关注我的公众号


浙公网安备 33010602011771号