【题解】P2487 [SDOI2011]拦截导弹
题意
给定一个长度为 \(n\) 的序列,每个元素有其高度 \(h_i\) 和速度 \(v_i\)。对于两个元素 \(i, j\),当且仅当满足 \(i < j, h_i \leq h_j, v_i \leq v_j\) 时称元素 \(i\) 小于等于元素 \(j\)。已知原序列可能有多个最长不升子序列,现在随机选中一条最长不升子序列,对于每个元素,试求出其被选中的概率。
\(1 \leq n \leq 5 \times 10^4, 1 \leq h_i, v_i \leq 10^9\)
根据 hack,此题似乎没有极端数据。
思路
cdq 分治优化 dp。
令:
-
以第 \(i\) 个元素结尾的不升子序列的最大长度为 \(f1[i]\)
-
以第 \(i\) 个元素开头的最长不升子序列的最大长度为 \(f2[i]\)
-
以第 \(i\) 个元素结尾且长度为 \(f1[i]\) 的不升子序列个数为 \(g1[i]\)
-
以第 \(i\) 个元素开头且长度为 \(f2[i]\) 的不升子序列个数为 \(g2[i]\)
-
原序列最长不升子序列的个数为 \(k\)
容易发现第 \(i\) 个元素被选中的概率为 \(\frac{g1[i] \cdot g2[i]}{k}\)
朴素 dp 求 \(f1, g1\) 的复杂度是 \(O(n^2)\),考虑用 cdq 分治优化。
对于下标在 \([l, r]\) 内的状态,令 \(m = \lfloor \frac{l + r}{2} \rfloor\),考虑将它们之间的转移关系分为:
-
从 \([l, m]\) 中的状态转移到 \([l, m]\) 中的状态。
-
从 \([l, m]\) 中的状态转移到 \((m, r]\) 中的状态。
-
从 \((m, r]\) 中的状态转移到 \((m, r]\) 中的状态。
注意此时应该 按上面列举的顺序转移。原因是用以更新其他状态前,当前状态应该已经完成转移。如果按照一部分 cdq 分治的写法先处理 3 再处理 2,就会导致 \((m, r]\) 中的状态转移错误。
于是考虑处理 2。以处理 \(f1, g1\) 为例。不妨按照普通三维偏序的思路处理,先依 \(h_i\) 排序,再用数据结构维护 \(v_j \in [v_i, m]\) 的状态,其中 \(m\) 是离散化后的值域上界。注意到这里需要查询满足 \(v_j \in [v_i, m]\) 的 \(f1[j]\) 的最大值和其出现次数,因此考虑用线段树维护。
注意到先处理 2 再处理 3 时,对 \((m, r]\) 分治可能会出现左半部分状态的下标不严格小于右半部分的下标(对 \([l, r]\) 分治时排序了 \((m, r]\)),因此每次分治前还要把当前区间 按下标排序 一次。
记原序列的最长不升子序列长度为 \(l\),显然有 \(k = \sum\limits_{i = 1}^n g[i] \cdot [f[i] = l]\)
容易发现对于被原序列的至少一条最长不升子序列包含的元素 \(i\),必然有 \(f1[i] + f2[i] - 1 = l\)。因此考虑处理出 \(f2[i]\)。具体地,有:
\(x \geq y\),则 \(z - x + 1 \leq z - y + 1\)
因为 \(f2\) 的定义以 \(i\) 为开头,所以不妨将整个序列翻转。此时原本求最长不升子序列应改为求最长不降子序列。为方便起见,根据上面的式子,不妨将每个元素的 \(h_i, v_i\) 及其下标在对应值域中翻转,重新把问题转化为求最长不升子序列。上式中 \(z\) 为值域上界,\(x, y\) 为原本元素的值。
具体在代码中体现是这样的:
a[i].h = 1e9 - a[i].h + 1;
a[i].v = m - a[i].v + 1;
a[i].p = n - a[i].p + 1;
线段树可以遍历从根到每个访问过的结点的路径清空,也可以直接遍历整树,但注意遇到已经清空的结点要直接退出。
用 long long 存方案数可能会溢出,因此 \(g1, g2, k\) 等应用 double
时间复杂度 \(O(n \log^2 n)\)
代码
#include <cstdio>
#include <algorithm>
using namespace std;
const int maxn = 5e4 + 5;
struct item
{
int f, p, h, v;
double g;
} a[maxn], b[maxn];
struct node
{
int l, r, val;
double cnt;
} tree[maxn << 2];
int n, m;
int tmp[maxn], f1[maxn], f2[maxn];
double g1[maxn], g2[maxn];
bool cmp1(item a, item b) { return a.h > b.h; }
bool cmp2(item a, item b) { return a.p < b.p; }
void build(int k, int l, int r)
{
tree[k].l = l;
tree[k].r = r;
if (l == r) return;
int mid = (l + r) >> 1;
build(k << 1, l, mid);
build(k << 1 | 1, mid + 1, r);
}
void update(int k, int p, int w, double v)
{
if (tree[k].l == tree[k].r)
{
if (tree[k].val < w) tree[k].cnt = 0;
tree[k].val = max(tree[k].val, w);
if (tree[k].val == w) tree[k].cnt += v;
return;
}
int mid = (tree[k].l + tree[k].r) >> 1;
if (p <= mid) update(k << 1, p, w, v);
else update(k << 1 | 1, p, w, v);
tree[k].val = max(tree[k << 1].val, tree[k << 1 | 1].val);
tree[k].cnt = (tree[k].val == tree[k << 1].val ? 1 : 0) * tree[k << 1].cnt + (tree[k].val == tree[k << 1 | 1].val ? 1 : 0) * tree[k << 1 | 1].cnt;
}
void clear(int k)
{
if (!tree[k].val) return;
tree[k].val = tree[k].cnt = 0;
if (tree[k].l == tree[k].r) return;
clear(k << 1);
clear(k << 1 | 1);
}
pair<int, double> query(int k, int l, int r)
{
if ((tree[k].l >= l) && (tree[k].r <= r)) return make_pair(tree[k].val, tree[k].cnt);
int mid = (tree[k].l + tree[k].r) >> 1;
pair<int, double> res, lres, rres;
res = lres = rres = make_pair(0, 0);
if (l <= mid) lres = query(k << 1, l, r);
if (r > mid) rres = query(k << 1 | 1, l, r);
if (lres.first > rres.first) return lres;
else if (lres.first < rres.first) return rres;
return make_pair(lres.first, lres.second + rres.second);
}
void cdq(int l, int r)
{
if (l == r) return;
int mid = (l + r) >> 1;
sort(a + l, a + r + 1, cmp2);
cdq(l, mid);
clear(1);
sort(a + l, a + mid + 1, cmp1);
sort(a + mid + 1, a + r + 1, cmp1);
int i = mid + 1, j = l;
while (i <= r)
{
while ((j <= mid) && (a[j].h >= a[i].h)) update(1, a[j].v, a[j].f, a[j].g), j++;
pair<int, double> qres = query(1, a[i].v, m);
if (!qres.first)
{
i++;
continue;
}
if (a[i].f < qres.first + 1) a[i].g = 0;
a[i].f = max(a[i].f, qres.first + 1);
a[i].g += (a[i].f == qres.first + 1 ? 1 : 0) * qres.second;
i++;
}
cdq(mid + 1, r);
}
int main()
{
int res = 0;
double cnt = 0;
scanf("%d", &n);
for (int i = 1; i <= n; i++)
{
a[i].p = i;
scanf("%d%d", &a[i].h, &a[i].v);
tmp[i] = a[i].v;
}
sort(tmp + 1, tmp + n + 1);
m = unique(tmp + 1, tmp + n + 1) - tmp - 1;
build(1, 1, m);
for (int i = 1; i <= n; i++)
{
a[i].v = lower_bound(tmp + 1, tmp + m + 1, a[i].v) - tmp;
a[i].f = a[i].g = 1;
b[i] = a[i];
}
cdq(1, n);
for (int i = 1; i <= n; i++) res = max(res, a[i].f);
for (int i = 1; i <= n; i++)
if (a[i].f == res) cnt += a[i].g;
for (int i = 1; i <= n; i++)
{
f1[a[i].p] = a[i].f, g1[a[i].p] = a[i].g;
a[i] = b[n - i + 1];
a[i].h = 1e9 - a[i].h + 1;
a[i].v = m - a[i].v + 1;
a[i].p = n - a[i].p + 1;
}
cdq(1, n);
for (int i = 1; i <= n; i++) f2[n - a[i].p + 1] = a[i].f, g2[n - a[i].p + 1] = a[i].g;
printf("%d\n", res);
for (int i = 1; i <= n; i++)
if (f1[i] + f2[i] - 1 == res) printf("%lf ", g1[i] * g2[i] / cnt);
else printf("%lf ", 0.0);
return 0;
}

浙公网安备 33010602011771号