【题解】火枪打怪

思路

求最小值,很有可能是二分,我们再看数据范围,就应该往这方面想。
其实这道题最坑的还是 check, 是真的不好想。
我们不难写出一个 \(\Theta(n^2)\)check

bool check(int x) {
    int sum = 0;
    for (int i = 1; i <= n; i++) b[i] = a[i];
    for (int i = n; i >= 1; i--) {
        if (b[i] < 0)
            continue;
        int h = ceil((b[i] + 1) * 1.0 / x);
        b[i] -= h * x;
        for (int j = 1; j < i; j++) b[j] -= max(0ll, x - (i - j) * (i - j)) * h;
        sum += h;
    }
    return sum <= m;
}

但是这必定超时,所以我们要想办法优化我们的 check
我们可以化简式子得到:
\(p - (i-j)^2=p-i^2+2*i*j-j^2\)
我们假设射击范围为 r,
所以溅射伤害为:
\(hp_i=\sum_{i + 1}^{i+r-1}p-(i^2 * r+\sum_{j=i+1}^{i+r-1}j^2)+2*i*\sum_{j=i + 1}^{i+r-1} j\)
所以我们只需维护 \(i^2\), \(j^2\)的区间和即可。
就有代码:

bool check(int p) {
    for (int i = 1; i <= n; i++)
		b[i] = 0;
	int r = sqrt(p) + 1, s = 0, v1 = 0, v2 = 0, sum = 0;
    for (int i = n; i >= 1; i--) {
        if (i + r <= n && b[i + r]) {
        	s -= b[i + r];
        	v1 -= (i + r) * (i + r) * b[i + r];
        	v2 -= (i + r) * b[i + r];
		}
        int hp = s * p - v1 - i * i * s + v2 * i * 2;
        if (a[i] >= hp) {
        	b[i] += (a[i] - hp) / p + 1;
        	s += b[i];
        	v1 += i * i * b[i];
        	v2 += i * b[i];
        	sum += b[i];
		}
		if (sum > m)
        	break;
	}
    return sum <= m;
}

其时间复杂度为 \(\Theta(n)\)

Code

#include <cstdio>
#include <cmath>
#include <iostream>
#define int long long
using namespace std;
const int MAXN = 5e5 + 5;
int n, m;
int a[MAXN];
int b[MAXN];
bool check(int p) {
    for (int i = 1; i <= n; i++)
		b[i] = 0;
	int r = sqrt(p) + 1, s = 0, v1 = 0, v2 = 0, sum = 0;
    for (int i = n; i >= 1; i--) {
        if (i + r <= n && b[i + r]) {
        	s -= b[i + r];
        	v1 -= (i + r) * (i + r) * b[i + r];
        	v2 -= (i + r) * b[i + r];
		}
        int hp = s * p - v1 - i * i * s + v2 * i * 2;
        if (a[i] >= hp) {
        	b[i] += (a[i] - hp) / p + 1;
        	s += b[i];
        	v1 += i * i * b[i];
        	v2 += i * b[i];
        	sum += b[i];
		}
		if (sum > m)
        	break;
	}
    return sum <= m;
}
signed main() {
    scanf("%lld %lld", &n, &m);
    for (int i = 1; i <= n; i++)
		scanf("%lld", &a[i]);
    int l = 1, r = 1e15;
    while (l < r) {
        int mid = (l + r) >> 1;
        if (check(mid))
            r = mid;
        else
            l = mid + 1;
    }
    printf("%lld", l);
    return 0;
}
posted @ 2022-07-27 21:26  zhou_ziyi  阅读(30)  评论(0)    收藏  举报