QOJ11111 + and × with a sugar / + , × 与糖 题解 [ 蓝 ] [ 线性 DP ] [ 结论题 ] [ 值域分治 ]

+ and × with a sugar / + , × 与糖:神奇结论题。

首先可以发现,大多数情况下直接用乘法肯定是更优的。于是我们可以先来研究什么情况下用加法可能是更优的。

假设当前某一个长度为 \(n\) 的段的乘积为 \(x\),且首尾两个元素都不为 \(1\)(因为首尾的 \(1\) 从这一段剥离一定更优),那么我们尝试构造出和最大的方案。因为在和一定的时候,两数差越小,乘积越大;所以在乘积一定的时候,两数差越大和越大。由此可以想到,对于一个长度为 \(n\) 的段,我们把头尾设为最大值,使得他们的乘积为 \(x\),然后中间全放 \(1\),就能取到和的最大值。

不难估计和的最大值的上界,首先全 \(1\) 段的和显然是 \(n-2\),而因为头尾乘积为 \(x\),所以头尾差最小肯定最优,都取 \(\sqrt{x}\) 即可。总和为 \(n-2 + 2\sqrt{x}\)。如果总和要比乘积小的话,要满足如下条件:

\[n-2 + 2\sqrt{x} \le x \]

随便代入几个值,可以发现如果 \(x = 2n\) 了,那么这个不等式对于 \(n \ge 0\) 的情况就是成立的,因此大于 \(2n\) 的就更加成立了。具体的可以参考下面的函数:

image

这当然不是一个很紧的界,只是这个性质已经足以将这题做出了。

对于一个段,我们可以先去掉首尾\(1\),然后对于中间的段分类讨论

  • 若中间段的乘积 \(>2n\) 了,那么可以直接将中间段相乘,然后与左右两段的 \(1\) 相加。时间复杂度 \(O(n)\)
  • 否则就说明中间段的乘积 \(\le 2n\),可以发现此时 \(>1\) 的数大概最多有 \(\log 2n\) 个,于是我们可以对这 \(\log\) 个数进行 DP,具体地,定义 $dp_{i} $ 表示考虑到第 \(i\)\(>1\) 的数时的最大答案,然后枚举这一段的左端点 \(j\) 进行转移即可。时间复杂度 \(O(\log^2 n)\)

总体时间复杂度 \(O(n + \log^2 n)\)

#include <bits/stdc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
using pi=pair<int,int>;
const int N = 200005;
const ll mod = 1e9 + 7;
int n, lx, rx;
ll a[N], ans, mul, dp[105], pos[105], cnt, pre[105];
void sub()
{
    cnt = 0;
    memset(dp, 0, sizeof(dp));
    for(int i = lx + 1; i <= rx - 1; i++)
    {
        if(a[i] > 1)
        {
            pos[++cnt] = i;
            if(cnt == 1) pre[cnt] = 0;
            else pre[cnt] = i - 1 - pos[cnt - 1];
        }
    }
    for(int i = 1; i <= cnt; i++)
    {
        ll sufmul = 1;
        for(int j = i; j >= 1; j--)
        {
            sufmul *= a[pos[j]];
            dp[i] = max(dp[i], dp[j - 1] + pre[j] + sufmul);
        }
    }
    ans = (ans + dp[cnt]) % mod;
    cout << ans << "\n";
}
void solve()
{
    cin >> n;
    for(int i = 1; i <= n; i++) cin >> a[i];
    ans = 0;
    lx = 0, rx = n + 1;
    for(int i = 1; i <= n; i++)
    {
        if(a[i] == 1) ans++, lx = i;
        else break;
    }
    for(int i = n; i >= 1; i--)
    {
        if(a[i] == 1) ans++, rx = i;
        else break;
    }
    if(ans > n)
    {
        cout << n << '\n';
        return;
    }
    mul = 1;
    for(int i = lx + 1; i <= rx - 1; i++)
    {
        mul *= a[i];
        if(mul > 2 * n) break;
    }
    if(mul <= 2 * n)
    {
        sub();
        return;
    }
    mul = 1;
    for(int i = lx + 1; i <= rx - 1; i++) mul = (mul * a[i]) % mod;
    ans = (ans + mul) % mod;
    cout << ans << "\n";
}
int main()
{
    //freopen("sample.in","r",stdin);
    //freopen("sample.out","w",stdout);
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    int t;
    cin >> t;
    while(t--) solve();
    return 0;
}
posted @ 2025-08-08 20:21  KS_Fszha  阅读(19)  评论(0)    收藏  举报