QOJ11111 + and × with a sugar / + , × 与糖 题解 [ 蓝 ] [ 线性 DP ] [ 结论题 ] [ 值域分治 ]
+ and × with a sugar / + , × 与糖:神奇结论题。
首先可以发现,大多数情况下直接用乘法肯定是更优的。于是我们可以先来研究什么情况下用加法可能是更优的。
假设当前某一个长度为 \(n\) 的段的乘积为 \(x\),且首尾两个元素都不为 \(1\)(因为首尾的 \(1\) 从这一段剥离一定更优),那么我们尝试构造出和最大的方案。因为在和一定的时候,两数差越小,乘积越大;所以在乘积一定的时候,两数差越大,和越大。由此可以想到,对于一个长度为 \(n\) 的段,我们把头尾设为最大值,使得他们的乘积为 \(x\),然后中间全放 \(1\),就能取到和的最大值。
不难估计和的最大值的上界,首先全 \(1\) 段的和显然是 \(n-2\),而因为头尾乘积为 \(x\),所以头尾差最小肯定最优,都取 \(\sqrt{x}\) 即可。总和为 \(n-2 + 2\sqrt{x}\)。如果总和要比乘积小的话,要满足如下条件:
\[n-2 + 2\sqrt{x} \le x
\]
随便代入几个值,可以发现如果 \(x = 2n\) 了,那么这个不等式对于 \(n \ge 0\) 的情况就是成立的,因此大于 \(2n\) 的就更加成立了。具体的可以参考下面的函数:

这当然不是一个很紧的界,只是这个性质已经足以将这题做出了。
对于一个段,我们可以先去掉首尾的 \(1\),然后对于中间的段分类讨论:
- 若中间段的乘积 \(>2n\) 了,那么可以直接将中间段相乘,然后与左右两段的 \(1\) 相加。时间复杂度 \(O(n)\)。
- 否则就说明中间段的乘积 \(\le 2n\),可以发现此时 \(>1\) 的数大概最多有 \(\log 2n\) 个,于是我们可以对这 \(\log\) 个数进行 DP,具体地,定义 $dp_{i} $ 表示考虑到第 \(i\) 个 \(>1\) 的数时的最大答案,然后枚举这一段的左端点 \(j\) 进行转移即可。时间复杂度 \(O(\log^2 n)\)。
总体时间复杂度 \(O(n + \log^2 n)\)。
#include <bits/stdc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
using pi=pair<int,int>;
const int N = 200005;
const ll mod = 1e9 + 7;
int n, lx, rx;
ll a[N], ans, mul, dp[105], pos[105], cnt, pre[105];
void sub()
{
cnt = 0;
memset(dp, 0, sizeof(dp));
for(int i = lx + 1; i <= rx - 1; i++)
{
if(a[i] > 1)
{
pos[++cnt] = i;
if(cnt == 1) pre[cnt] = 0;
else pre[cnt] = i - 1 - pos[cnt - 1];
}
}
for(int i = 1; i <= cnt; i++)
{
ll sufmul = 1;
for(int j = i; j >= 1; j--)
{
sufmul *= a[pos[j]];
dp[i] = max(dp[i], dp[j - 1] + pre[j] + sufmul);
}
}
ans = (ans + dp[cnt]) % mod;
cout << ans << "\n";
}
void solve()
{
cin >> n;
for(int i = 1; i <= n; i++) cin >> a[i];
ans = 0;
lx = 0, rx = n + 1;
for(int i = 1; i <= n; i++)
{
if(a[i] == 1) ans++, lx = i;
else break;
}
for(int i = n; i >= 1; i--)
{
if(a[i] == 1) ans++, rx = i;
else break;
}
if(ans > n)
{
cout << n << '\n';
return;
}
mul = 1;
for(int i = lx + 1; i <= rx - 1; i++)
{
mul *= a[i];
if(mul > 2 * n) break;
}
if(mul <= 2 * n)
{
sub();
return;
}
mul = 1;
for(int i = lx + 1; i <= rx - 1; i++) mul = (mul * a[i]) % mod;
ans = (ans + mul) % mod;
cout << ans << "\n";
}
int main()
{
//freopen("sample.in","r",stdin);
//freopen("sample.out","w",stdout);
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int t;
cin >> t;
while(t--) solve();
return 0;
}

浙公网安备 33010602011771号