Loading

【题解】P2487 [SDOI2011]拦截导弹

题意

P2487 [SDOI2011]拦截导弹

给定一个长度为 \(n\) 的序列,每个元素有其高度 \(h_i\) 和速度 \(v_i\)。对于两个元素 \(i, j\),当且仅当满足 \(i < j, h_i \leq h_j, v_i \leq v_j\) 时称元素 \(i\) 小于等于元素 \(j\)。已知原序列可能有多个最长不升子序列,现在随机选中一条最长不升子序列,对于每个元素,试求出其被选中的概率。

\(1 \leq n \leq 5 \times 10^4, 1 \leq h_i, v_i \leq 10^9\)

根据 hack,此题似乎没有极端数据。

思路

cdq 分治优化 dp。

令:

  1. 以第 \(i\) 个元素结尾的不升子序列的最大长度为 \(f1[i]\)

  2. 以第 \(i\) 个元素开头的最长不升子序列的最大长度为 \(f2[i]\)

  3. 以第 \(i\) 个元素结尾且长度为 \(f1[i]\) 的不升子序列个数为 \(g1[i]\)

  4. 以第 \(i\) 个元素开头且长度为 \(f2[i]\) 的不升子序列个数为 \(g2[i]\)

  5. 原序列最长不升子序列的个数为 \(k\)

容易发现第 \(i\) 个元素被选中的概率为 \(\frac{g1[i] \cdot g2[i]}{k}\)

朴素 dp 求 \(f1, g1\) 的复杂度是 \(O(n^2)\),考虑用 cdq 分治优化。

对于下标在 \([l, r]\) 内的状态,令 \(m = \lfloor \frac{l + r}{2} \rfloor\),考虑将它们之间的转移关系分为:

  1. \([l, m]\) 中的状态转移到 \([l, m]\) 中的状态。

  2. \([l, m]\) 中的状态转移到 \((m, r]\) 中的状态。

  3. \((m, r]\) 中的状态转移到 \((m, r]\) 中的状态。

注意此时应该 按上面列举的顺序转移。原因是用以更新其他状态前,当前状态应该已经完成转移。如果按照一部分 cdq 分治的写法先处理 3 再处理 2,就会导致 \((m, r]\) 中的状态转移错误。

于是考虑处理 2。以处理 \(f1, g1\) 为例。不妨按照普通三维偏序的思路处理,先依 \(h_i\) 排序,再用数据结构维护 \(v_j \in [v_i, m]\) 的状态,其中 \(m\) 是离散化后的值域上界。注意到这里需要查询满足 \(v_j \in [v_i, m]\)\(f1[j]\) 的最大值和其出现次数,因此考虑用线段树维护。

注意到先处理 2 再处理 3 时,对 \((m, r]\) 分治可能会出现左半部分状态的下标不严格小于右半部分的下标(对 \([l, r]\) 分治时排序了 \((m, r]\)),因此每次分治前还要把当前区间 按下标排序 一次。

记原序列的最长不升子序列长度为 \(l\),显然有 \(k = \sum\limits_{i = 1}^n g[i] \cdot [f[i] = l]\)

容易发现对于被原序列的至少一条最长不升子序列包含的元素 \(i\),必然有 \(f1[i] + f2[i] - 1 = l\)。因此考虑处理出 \(f2[i]\)。具体地,有:

\(x \geq y\),则 \(z - x + 1 \leq z - y + 1\)

因为 \(f2\) 的定义以 \(i\) 为开头,所以不妨将整个序列翻转。此时原本求最长不升子序列应改为求最长不降子序列。为方便起见,根据上面的式子,不妨将每个元素的 \(h_i, v_i\) 及其下标在对应值域中翻转,重新把问题转化为求最长不升子序列。上式中 \(z\) 为值域上界,\(x, y\) 为原本元素的值。

具体在代码中体现是这样的:

a[i].h = 1e9 - a[i].h + 1;
a[i].v = m - a[i].v + 1;
a[i].p = n - a[i].p + 1;

线段树可以遍历从根到每个访问过的结点的路径清空,也可以直接遍历整树,但注意遇到已经清空的结点要直接退出。

long long 存方案数可能会溢出,因此 \(g1, g2, k\) 等应用 double

时间复杂度 \(O(n \log^2 n)\)

代码

#include <cstdio>
#include <algorithm>
using namespace std;

const int maxn = 5e4 + 5;

struct item
{
    int f, p, h, v;
    double g;
} a[maxn], b[maxn];

struct node
{
    int l, r, val;
    double cnt;
} tree[maxn << 2];

int n, m;
int tmp[maxn], f1[maxn], f2[maxn];
double g1[maxn], g2[maxn];

bool cmp1(item a, item b) { return a.h > b.h; }

bool cmp2(item a, item b) { return a.p < b.p; }

void build(int k, int l, int r)
{
    tree[k].l = l;
    tree[k].r = r;
    if (l == r) return;
    int mid = (l + r) >> 1;
    build(k << 1, l, mid);
    build(k << 1 | 1, mid + 1, r);
}

void update(int k, int p, int w, double v)
{
    if (tree[k].l == tree[k].r)
    {
        if (tree[k].val < w) tree[k].cnt = 0;
        tree[k].val = max(tree[k].val, w);
        if (tree[k].val == w) tree[k].cnt += v;
        return;
    }
    int mid = (tree[k].l + tree[k].r) >> 1;
    if (p <= mid) update(k << 1, p, w, v);
    else update(k << 1 | 1, p, w, v);
    tree[k].val = max(tree[k << 1].val, tree[k << 1 | 1].val);
    tree[k].cnt = (tree[k].val == tree[k << 1].val ? 1 : 0) * tree[k << 1].cnt + (tree[k].val == tree[k << 1 | 1].val ? 1 : 0) * tree[k << 1 | 1].cnt;
}

void clear(int k)
{
    if (!tree[k].val) return;
    tree[k].val = tree[k].cnt = 0;
    if (tree[k].l == tree[k].r) return;
    clear(k << 1);
    clear(k << 1 | 1);
}

pair<int, double> query(int k, int l, int r)
{
    if ((tree[k].l >= l) && (tree[k].r <= r)) return make_pair(tree[k].val, tree[k].cnt);
    int mid = (tree[k].l + tree[k].r) >> 1;
    pair<int, double> res, lres, rres;
    res = lres = rres = make_pair(0, 0);
    if (l <= mid) lres = query(k << 1, l, r);
    if (r > mid) rres = query(k << 1 | 1, l, r);
    if (lres.first > rres.first) return lres;
    else if (lres.first < rres.first) return rres;
    return make_pair(lres.first, lres.second + rres.second);
}

void cdq(int l, int r)
{
    if (l == r) return;
    int mid = (l + r) >> 1;
    sort(a + l, a + r + 1, cmp2);
    cdq(l, mid);
    clear(1);
    sort(a + l, a + mid + 1, cmp1);
    sort(a + mid + 1, a + r + 1, cmp1);
    int i = mid + 1, j = l;
    while (i <= r)
    {
        while ((j <= mid) && (a[j].h >= a[i].h)) update(1, a[j].v, a[j].f, a[j].g), j++;
        pair<int, double> qres = query(1, a[i].v, m);
        if (!qres.first)
        {
            i++;
            continue;
        }
        if (a[i].f < qres.first + 1) a[i].g = 0;
        a[i].f = max(a[i].f, qres.first + 1);
        a[i].g += (a[i].f == qres.first + 1 ? 1 : 0) * qres.second;
        i++;
    }
    cdq(mid + 1, r);
}

int main()
{
    int res = 0;
    double cnt = 0;
    scanf("%d", &n);
    for (int i = 1; i <= n; i++)
    {
        a[i].p = i;
        scanf("%d%d", &a[i].h, &a[i].v);
        tmp[i] = a[i].v;
    }
    sort(tmp + 1, tmp + n + 1);
    m = unique(tmp + 1, tmp + n + 1) - tmp - 1;
    build(1, 1, m);
    for (int i = 1; i <= n; i++)
    {
        a[i].v = lower_bound(tmp + 1, tmp + m + 1, a[i].v) - tmp;
        a[i].f = a[i].g = 1;
        b[i] = a[i];
    }
    cdq(1, n);
    for (int i = 1; i <= n; i++) res = max(res, a[i].f);
    for (int i = 1; i <= n; i++)
        if (a[i].f == res) cnt += a[i].g;
    for (int i = 1; i <= n; i++)
    {
        f1[a[i].p] = a[i].f, g1[a[i].p] = a[i].g;
        a[i] = b[n - i + 1];
        a[i].h = 1e9 - a[i].h + 1;
        a[i].v = m - a[i].v + 1;
        a[i].p = n - a[i].p + 1;
    }
    cdq(1, n);
    for (int i = 1; i <= n; i++) f2[n - a[i].p + 1] = a[i].f, g2[n - a[i].p + 1] = a[i].g;
    printf("%d\n", res);
    for (int i = 1; i <= n; i++)
        if (f1[i] + f2[i] - 1 == res) printf("%lf ", g1[i] * g2[i] / cnt);
        else printf("%lf ", 0.0);
    return 0;
}
posted @ 2022-07-11 16:57  kymru  阅读(107)  评论(0)    收藏  举报