「6月雅礼集训 2017 Day2」A

【题目大意】

给出一棵树,求有多少对点(u,v)满足其路径上不存在两个点a,b满足(a,b)=1

n<=10^5

【题解】

考虑找出所有不符合的点对,共有n*ln(n)对,他们要么是祖先->儿子边,要么是不是。

考虑祖先->儿子边,那么一个点在祖先以上,一个点在儿子以下的点对全部无法访问。

考虑另外一种边,就是LCA不是两个端点的,这就比较好统计了,两个点在这两棵子树的点对无法访问。

考虑用DFS序,这样子树就是连续的一段(祖先以上是连续两段)

然后就是一个二维覆盖问题,用扫描线+线段树即可解决。

复杂度O(nln(n)logn)

注意。。扫描线数组要开到 4 * n * ln(n) 不然。。会奇怪的WA/RE。。。

# include <stdio.h>
# include <string.h>
# include <iostream>
# include <algorithm>

using namespace std;

typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;

# define RG register
# define ST static

const int M = 2e5 + 10, N = 1e5 + 10, Max = 8 * M;
const int mod = 998244353;

int n, head[N], nxt[M], to[M], tot = 0;
inline void add(int u, int v) {
    ++tot; nxt[tot] = head[u]; head[u] = tot; to[tot] = v;
}
inline void adde(int u, int v) {
    add(u, v), add(v, u);
}

int in[N], out[N], DFN = 0;
int dep[N], fa[N][19];
inline void dfs(int x, int fat = 0) {
    in[x] = ++DFN; dep[x] = dep[fat] + 1;
    fa[x][0] = fat;
    for (int i=1; i<=18; ++i) fa[x][i] = fa[fa[x][i-1]][i-1];
    for (int i=head[x]; i; i=nxt[i]) {
        if(to[i] == fat) continue;
        dfs(to[i], x);
    }
    out[x] = DFN;
}

inline int lca(int u, int v) {
    if(dep[u] < dep[v]) swap(u, v);
    for (int i=18; ~i; --i)
        if((dep[u] - dep[v]) & (1<<i)) u = fa[u][i];
    if(u == v) return u;
    for (int i=18; ~i; --i)
        if(fa[u][i] != fa[v][i]) u = fa[u][i], v = fa[v][i];
    return fa[u][0];
}

inline int jump(int u, int anc) {
    for (int i=18; ~i; --i)
        if(dep[fa[u][i]] > dep[anc]) u = fa[u][i];
    return u;
}

struct pa {
    int x, yl, yr, d;
    pa() {}
    pa(int x, int yl, int yr, int d) : x(x), yl(yl), yr(yr), d(d) {}
    friend bool operator < (pa a, pa b) {
        return a.x < b.x;
    }
}p[Max * 4]; int pn = 0;

inline void ADD(int xl, int xr, int yl, int yr) {
    p[++pn] = pa(xl, yl, yr, 1);
    p[++pn] = pa(xr+1, yl, yr, -1); 
}

inline void doit(int x, int y) {
    int par = lca(x, y);
//    if(par == -1) cout << x << ' ' << y << endl;
    if(dep[x] > dep[y]) swap(x, y);
    if(x == par) {
        int pars = jump(y, par);
        ADD(1, in[pars] - 1, in[y], out[y]);
        ADD(in[y], out[y], out[pars] + 1, n);
    } else {
        if(in[x] > in[y]) swap(x, y);
        ADD(in[x], out[x], in[y], out[y]);
    }
}


struct SMT {
    int w[Max], tag[Max];
    # define ls (x<<1)
    # define rs (x<<1|1)
    inline void set() {
        memset(w, 0, sizeof w);
        memset(tag, 0, sizeof tag);
    }
    inline int gs(int x, int l, int r) {
        if(tag[x]) return r-l+1;
        else return w[x];
    }
    inline void edt(int x, int l, int r, int L, int R, int d) {
        if(L > R) return ;
        if(L <= l && r <= R) {tag[x] += d; return ;}
        int mid = l+r>>1;
        if(L <= mid) edt(ls, l, mid, L, R, d);
        if(R > mid) edt(rs, mid+1, r, L, R, d);
        w[x] = gs(ls, l, mid) + gs(rs, mid+1, r);
    }
    inline int sum(int x, int l, int r, int L, int R) {
        if(L > R) return 0;
        if(tag[x]) return min(R, r) - max(L, l) + 1;
        if(L <= l && r <= R) return gs(x, l, r);
        int mid = l+r>>1, ret = 0;
        if(L <= mid) ret += sum(ls, l, mid, L, R);
        if(R > mid) ret += sum(rs, mid+1, r, L, R);
        return ret;
    }
    # undef ls
    # undef rs
}T;

int main() {
//    freopen("A.in", "r", stdin);
//    freopen("A.out", "w", stdout);
    cin >> n;
    for (int i=1, u, v; i<n; ++i) {
        scanf("%d%d", &u, &v);
        adde(u, v);
    }
    dfs(1, 0);
    for (int i=1; i<=n; ++i)
        for (int j=i+i; j<=n; j+=i) doit(i, j);
    
    sort(p+1, p+pn+1); T.set();
    ll ans = (ll)n * (n-1) / 2;
    for (int i=1, j=1; i<=n; ++i) {
        while(j<=pn && p[j].x == i) T.edt(1, 1, n, p[j].yl, p[j].yr, p[j].d), ++j;
        ans -= T.sum(1, 1, n, i+1, n);
    }
    cout << ans;
    return 0;
}
View Code

 

posted @ 2017-06-18 16:57  Galaxies  阅读(305)  评论(0编辑  收藏  举报