[ZJOI2019]语言——树上差分+线段树合并

题面

  LOJ#3046

解析

   题意还是很好懂的,问题也很容易转化为求每个点能到的点的个数之和,最后除以$2$即可

  考虑任意一点i能到的点的个数。这些点所组成的点集等于所有包含节点$i$的链的点集的并集。

  需要哪些信息才能维护出这个点集?

  由于每条链都包含了节点$i$,因此这个点集会组成一个连通块(暂且这么叫吧),这个连通块显然可以通过确定其边界上的点确定下来,那么问题变为如何维护出连通块的大小

  这就是在考察$dfs$序的应用了,先以$dfs$序建立线段树。借助于虚树的思想,我们维护出每一个区间内选择了点到$1$号节点的根缀的并集大小,记为$siz$。假设现在我们已知线段树中一个区间的左区间信息与右区间信息,如何得出这个区间的信息?直接把左右区间的$siz$相加显然不对,有重复的部分,重复部分的大小是多少?结合$dfs$序,可以发现重合部分就是左区间选择的点中$dfs$序最大的点$mx$与右区间选择的点中$dfs$序最小的点$mn$的$lca$的深度,即$dep[lca(mx, mn)]$,左右区间的$siz$相加再减去这个就是这个就是这个区间的$siz$,而最后的连通块的大小就是整个序列的$siz$减去$dep[lca(mx, mn)]$,$mx$是整个点集中$dfs$序最大的点,$mn$是$dfs$序最小的点。

  为了快速求$lca$可以预处理欧拉序与$RMQ$,在$O(1)$内查询$lca$,信息更新就是$log$的时间复杂度

  我们需要一个连通块的边界,这个边界显然是由所有链的边界组成的,因此在树上差分就可维护出边界

  这是对于一个点的情况。对于所有的点,因为使用了树上差分,父亲节点需要从儿子节点继承信息,于是就需要线段树合并

  总时间复杂度是$O(NlogN)$的

 代码:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 100005;

inline int read()
{
    int ret, f=1;
    char c;
    while((c=getchar())&&(c<'0'||c>'9'))if(c=='-')f=-1;
    ret = c-'0';
    while((c=getchar())&&(c>='0'&&c<='9'))ret=(ret<<3)+(ret<<1)+c-'0';
    return ret*f;
}

int n, m, root[maxn];
ll ans;

int head[maxn], tot;
struct edge{
    int nxt, to;
}e[maxn<<1];

void Addedge(int x, int y)
{
    e[++tot] = (edge){head[x], y};
    head[x] = tot;
}

int f[maxn];
int cur, lg[maxn<<1], st[20][maxn<<1], arr[maxn<<1], dfn[maxn], dep[maxn];

void dfs1(int x, int fa)
{
    arr[++cur] = x;
    dfn[x] = cur;
    for(int i = head[x]; i; i = e[i].nxt)
    {
        int id = e[i].to;
        if(id == fa)    continue;
        dep[id] = dep[x] + 1;    
        f[id] = x;
        dfs1(id, x);
        arr[++cur] = x;
    }
}

void RMQ()
{
    lg[0] = -1;
    for(int i = 1; i <= cur; ++i)
    {
        lg[i] = lg[i>>1] + 1;
        st[0][i] = arr[i];
    }
    for(int j = 1; j <= lg[cur]; ++j)
        for(int i = 1; i + (1 << j) - 1 <= cur; ++i)
        {
            int u = st[j-1][i], v = st[j-1][i + (1 << (j - 1))];
            st[j][i] = (dep[u] < dep[v]? u: v);
        }
}

int Get_lca(int x, int y)
{
    if(!x || !y)    return 0;    
    x = dfn[x]; y = dfn[y];
    if(x > y)    swap(x, y);
    int u = st[lg[y-x+1]][x], v = st[lg[y-x+1]][y-(1 << lg[y-x+1])+1];
    return dep[u] < dep[v]? u: v;
}

int ndnum, tim;
struct seg_tree{
    int ls, rs, siz, mx, mn, num;
}tr[maxn*64];

void update(int x)
{
    int lson = tr[x].ls, rson = tr[x].rs;
    tr[x].siz = tr[lson].siz + tr[rson].siz - dep[Get_lca(tr[lson].mx, tr[rson].mn)];
    tr[x].mx = (tr[rson].mx? tr[rson].mx: tr[lson].mx);
    tr[x].mn = (tr[lson].mn? tr[lson].mn: tr[rson].mn); 
}

void Modify(int x, int L, int R, int p, int w)
{
    if(L == R)
    {
        tr[x].num += w;
        tr[x].siz = (tr[x].num? dep[p]: 0);
        tr[x].mx = tr[x].mn = (tr[x].num? p: 0);
        return ;
    }
    int mid = (L + R) >> 1;
    if(dfn[p] <= mid)
    {
        if(!tr[x].ls)    tr[x].ls = ++ ndnum;
        Modify(tr[x].ls, L, mid, p, w);
    }
    else
    {
        if(!tr[x].rs)    tr[x].rs = ++ ndnum;
        Modify(tr[x].rs, mid + 1, R, p, w);
    }
    update(x);
}

int Merg(int x, int y, int L, int R)
{
    if(!x || !y)    return x + y;
    if(L == R)
    {
        tr[x].num += tr[y].num;
        tr[x].siz = (tr[x].num? dep[arr[L]]: 0);
        tr[x].mx = tr[x].mn = (tr[x].num? arr[L]: 0);
        return x;
    }
    int mid = (L + R) >> 1;
    tr[x].ls = Merg(tr[x].ls, tr[y].ls, L, mid);
    tr[x].rs = Merg(tr[x].rs, tr[y].rs, mid + 1, R);
    update(x);
    return x;
}

void dfs2(int x)
{
    for(int i = head[x]; i; i = e[i].nxt)
    {
        int id = e[i].to;
        if(id == f[x])    continue;
        dfs2(id);
        root[x] = Merg(root[x], root[id], 1, cur);
    }    
    ans += (ll)tr[root[x]].siz - dep[Get_lca(tr[root[x]].mx, tr[root[x]].mn)];
}
int main()
{
    n = read(); m = read();
    for(int i = 1; i < n; ++i)
    {
        int u = read(), v = read();
        Addedge(u, v);
        Addedge(v, u);
    }    
    dfs1(1, 0);
    RMQ();
    for(int i = 1; i <= n; ++i)
        root[i] = ++ ndnum;
    while(m --)
    {
        int u = read(), v = read(), lca = Get_lca(u, v);
        Modify(root[u], 1, cur, u, 1);
        Modify(root[u], 1, cur, v, 1);
        Modify(root[v], 1, cur, u, 1);
        Modify(root[v], 1, cur, v, 1);
        Modify(root[lca], 1, cur, u, -1);
        Modify(root[lca], 1, cur, v, -1);
        if(lca != 1)
        {
            Modify(root[f[lca]], 1, cur, u, -1);
            Modify(root[f[lca]], 1, cur, v, -1);
        }
    }
    dfs2(1);
    printf("%lld\n", ans >> 1);
    return 0;
}
View Code

 

posted @ 2019-10-12 23:53  Mr_Joker  阅读(160)  评论(0编辑  收藏  举报