Loading

P4689 [Ynoi2016] 这是我自己的发明

description

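给你一棵 \(n\) 个结点的树,每个结点有一个权值。要求支持 \(m\) 个操作:每次操作要么换根,要么给出两个结点 \(x, y\),询问在当前根下,从 \(x\) 的子树中选一个点 \(a\)、从 \(y\) 的子树中选一个点 \(b\),满足 \(a\) 与 \(b\) 权值相同的点对 \((a, b)\) 的个数。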

\(1 \le n \le 10^5, 1 \le m \le 5 \times 10^5\)

solution

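首先我们发现换根是假的:以 \(1\) 为根做一遍 DFS,设当前根为 \(rt\)。如果 \(rt = x\),那么 \(x\) 的子树就是整棵树;如果 \(rt\) 在 \(x\) 的子树外面,那么换根对 \(x\) 的子树没有影响;如果 \(rt\) 在 \(x\) 的子树里面(且 \(rt \ne x\)),那么 \(x\) 的子树就是在 DFS 序上挖掉了一段区间。注意挖掉的不是 \(rt\) 自己的子树所对应的 DFS 序区间,而是 \(x\) 的那个包含 \(rt\) 的儿子的子树所对应的 DFS 序区间。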

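然后问题就变成了求两个区间(或区间的并)之间相同的数的对数。设 \(F(l, r)\) 表示 DFS 序前缀 \([1, l]\) 与前缀 \([1, r]\) 之间相同的数的对数,则区间 \([l_1, r_1]\) 与 \([l_2, r_2]\) 的答案可以差分为 \(F(r_1, r_2) - F(l_1 - 1, r_2) - F(r_1, l_2 - 1) + F(l_1 - 1, l_2 - 1)\)。于是用莫队维护两个前缀的末尾 \(l, r\)。注意它们是两个前缀各自的端点,不是同一个区间的左右端点。移动端点时用桶更新贡献即可。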

然后注意不要忘记开 long long:答案最多约为 \(n^2 = 10^{10}\),会超出 int 范围。另外输入量很大,用 ios::sync_with_stdio(false) 关闭同步后 cin 会快很多,是神助攻。

code

#include <bits/stdc++.h>

using namespace std;

// Every `int` in this file is really `long long`. This keeps the pair counts
// (`ans`, `res[]`) 64-bit; it is also why `main` is declared as `signed main`.
#define int long long
// Short aliases for pair members and common std helpers.
#define fir first
#define sec second
#define mkp make_pair 
#define pb push_back
// Inclusive ascending / descending loops: i runs over [l, r] / from r down to l.
#define lep( i, l, r ) for ( int i = ( l ); i <= ( r ); ++ i )
#define rep( i, r, l ) for ( int i = ( r ); i >= ( l ); -- i )

// Common type aliases (pii is a pair of `long long` because of the macro above).
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
typedef pair < int, int > pii;

char _c; bool _f;
// Reads one integer from stdin into `x` using getchar().
// Non-digit characters before the number are skipped; a '-' among them makes
// the result negative. The character that terminates the number is consumed.
// If end of input is reached before any digit, `x` is left as 0 instead of
// looping forever. After the call `_c` holds the last character read and `_f`
// the sign flag.
template < class type > inline void read ( type &x ) {
	// Keep the character as int: EOF must stay distinguishable from a real
	// character, and isdigit() is only defined for EOF or unsigned-char values.
	int ch = getchar ();
	bool neg = false;
	x = 0;
	while ( ch != EOF && ! isdigit ( ch ) ) {
		if ( ch == '-' ) neg = true;
		ch = getchar ();
	}
	while ( ch != EOF && isdigit ( ch ) ) {
		x = x * 10 + ( ch - '0' );
		ch = getchar ();
	}
	if ( neg ) x = -x;
	_c = static_cast < char > ( ch );
	_f = neg;
}

// Overwrite x with y unless x already compares <= y (chkmin) / >= y (chkmax).
template < class type > inline void chkmin ( type &x, type y ) { if ( ! ( x <= y ) ) x = y; }
template < class type > inline void chkmax ( type &x, type y ) { if ( ! ( x >= y ) ) x = y; }

// Capacity of the global arrays (well above n <= 1e5 and m <= 5e5).
const int N = 1000005;
// Bucket width for the left prefix end when sorting queries in Mo's order.
const int B = 100;

// l, r : current ends of the two DFS-order prefixes maintained by Mo's algorithm
// ans  : number of equal-value pairs between prefix [1, l] and prefix [1, r]
// cnt  : number of prefix queries stored in q[]
// idx  : DFS timestamp counter; rt: current root
int n, m, l, r, ans, cnt, idx, rt;
// a[i]         : compressed value of the vertex whose dfn is i
// tong1, tong2 : value counts in the left / right prefix
// res[i]       : answer of operation i; vis[i] marks operations that are queries
// dfn, siz     : DFS entry time and subtree size (tree rooted at 1)
// b, c         : sorted distinct values / compressed value of each vertex
// f, dep       : binary-lifting table (f[x][k] = 2^k-th ancestor) and depth
int a[N], tong1[N], tong2[N], res[N], dfn[N], siz[N], b[N], c[N], vis[N], f[N][20], dep[N];

// Adjacency-list heads and edge counter for the tree.
int head[N], tot;

// Forward-star adjacency list entry: `to` is the edge's endpoint, `next` is the
// index of the next edge leaving the same vertex (0 ends the list).
// Each undirected tree edge is stored in both directions, hence `N << 1`.
struct Graph {
  int to, next;
} edges[N << 1];

// Appends the directed edge u -> v to u's adjacency list (edge ids start at 1).
void add ( int u, int v ) {
  const int id = ++ tot;
  edges[id] = { v, head[u] };
  head[u] = id;
}

// One prefix query for Mo's algorithm: evaluate F(l, r), the number of
// equal-value pairs between DFS-order prefixes [1, l] and [1, r], and add
// op * F(l, r) (op is +1 or -1) to res[id].
// NOTE(review): each query operation may store several entries here (help()
// adds up to 4 per call and one query can call it several times). With m up to
// 5e5 the total can approach or exceed the 4 * N slots reserved below; confirm
// the capacity against worst-case data.
struct Node {
  int l, r, id, op;
} q[N << 2];

// Mo pointer moves. The state is two DFS-order prefixes [1, l] and [1, r];
// tong1 / tong2 count values in the left / right prefix and
// ans = sum over values v of tong1[v] * tong2[v].

// Put position x into the left prefix: it pairs with every equal value on the right.
void addl ( int x ) {
  const int v = a[x];
  ans += tong2[v];
  ++ tong1[v];
}

// Remove position x from the left prefix.
void dell ( int x ) {
  const int v = a[x];
  ans -= tong2[v];
  -- tong1[v];
}

// Put position x into the right prefix: it pairs with every equal value on the left.
void addr ( int x ) {
  const int v = a[x];
  ans += tong1[v];
  ++ tong2[v];
}

// Remove position x from the right prefix.
void delr ( int x ) {
  const int v = a[x];
  ans -= tong1[v];
  -- tong2[v];
}

// Records, for query i, the number of equal-value pairs between DFS interval
// [l1, r1] and DFS interval [l2, r2]. By inclusion-exclusion this equals
//   F(r1, r2) - F(l1-1, r2) - F(r1, l2-1) + F(l1-1, l2-1),
// where F(p1, p2) counts equal-value pairs between prefixes [1, p1] and [1, p2].
// An empty prefix has no pairs, so any term with p1 == 0 or p2 == 0 is zero and
// is not stored. This cuts the number of Mo queries without changing any answer.
// Empty intervals contribute nothing and are ignored.
void help ( int l1, int r1, int l2, int r2, int i ) {
  if ( l1 > r1 || l2 > r2 ) {
    return ;
  }
  auto push = [&] ( int p1, int p2, int sign ) {
    if ( p1 > 0 && p2 > 0 ) {
      q[++ cnt] = { p1, p2, i, sign };
    }
  };
  push ( r1, r2, 1 );
  push ( l1 - 1, r2, -1 );
  push ( r1, l2 - 1, -1 );
  push ( l1 - 1, l2 - 1, 1 );
}

void dfs ( int x, int fa ) {
  dfn[x] = ++ idx;
  siz[x] = 1;
  f[x][0] = fa;
  dep[x] = dep[fa] + 1;
  for ( int i = 1; i <= 19; i ++ ) {
    f[x][i] = f[f[x][i - 1]][i - 1];
  }
  for ( int i = head[x]; i; i = edges[i].next ) { 
    if ( edges[i].to != fa ) {
      dfs ( edges[i].to, x );
      siz[x] += siz[edges[i].to];
    }
  }
}

// Intended for y strictly inside the subtree of x (that is how Solve calls it).
// Lifts y to the child of x that lies on the path to y, and returns that
// child's DFS-order interval [dfn, dfn + siz - 1].
pii calc ( int y, int x ) {
  const int target = dep[x] + 1;
  int cur = y;
  for ( int k = 19; k >= 0; -- k ) {
    const int up = f[cur][k];
    if ( dep[up] >= target ) {
      cur = up;
    }
  }
  const int lo = dfn[cur];
  return { lo, lo + siz[cur] - 1 };
}

void Solve () {
  cin >> n >> m;
  for ( int i = 1; i <= n; i ++ ) {
    cin >> c[i];
    b[i] = c[i];
  }
  for ( int i = 1; i < n; i ++ ) {
    int u, v;
    cin >> u >> v;
    add ( u, v ), add ( v, u );
  }
  sort ( b + 1, b + 1 + n );
  int len = unique ( b + 1, b + 1 + n ) - b - 1;
  for ( int i = 1; i <= n; i ++ ) {
    c[i] = lower_bound ( b + 1, b + 1 + len, c[i] ) - b;
  }
  dfs ( 1, 0 );
  rt = 1;
  for ( int i = 1; i <= n; i ++ ) {
    a[dfn[i]] = c[i];
  }
  // cout << "a: ";
  // for ( int i = 1; i <= n; i ++ ) {
  //   cout << a[i] << " ";
  // }
  // cout << '\n';
  // cout << "dfn: ";
  // for ( int i = 1; i <= n; i ++ ) {
  //   cout << dfn[i] << " ";
  // }
  // cout << '\n';
  for ( int i = 1; i <= m; i ++ ) {
    int op;
    cin >> op;
    if ( op == 1 ) {
      cin >> rt;
    }
    else {
      vis[i] = 1;
      int l1, r1, l2, r2, x, y;
      cin >> x >> y;
      if ( rt == x && rt == y ) {
        // cout << "cy1\n";
        help ( 1, n, 1, n, i );
      }
      else if ( rt == x && dfn[y] <= dfn[rt] && dfn[rt] <= dfn[y] + siz[y] - 1 ) {
        // cout << "cy2\n";
        pii tmp = calc ( rt, y );
        help ( 1, n, 1, tmp.first - 1, i );
        help ( 1, n, tmp.second + 1, n, i );
      }
      else if ( rt == x && ( dfn[rt] < dfn[y] || dfn[rt] > dfn[y] + siz[y] - 1 ) ) {
        // cout << "cy3\n";
        help ( 1, n, dfn[y], dfn[y] + siz[y] - 1, i );
      }
      else if ( dfn[x] <= dfn[rt] && dfn[rt] <= dfn[x] + siz[x] - 1 && rt == y ) {
        // cout << "cy4\n";
        pii tmp = calc ( rt, x );
        help ( 1, tmp.first - 1, 1, n, i );
        help ( tmp.second + 1, n, 1, n, i );
      }
      else if ( dfn[x] <= dfn[rt] && dfn[rt] <= dfn[x] + siz[x] - 1 && dfn[y] <= dfn[rt] && dfn[rt] <= dfn[y] + siz[y] - 1 ) {
        // cout << "cy5\n";
        pii tmp1 = calc ( rt, x ), tmp2 = calc ( rt, y );
        help ( 1, tmp1.first - 1, 1, tmp2.first - 1, i );
        help ( 1, tmp1.first - 1, tmp2.second + 1, n, i );
        help ( tmp1.second + 1, n, 1, tmp2.first - 1, i );
        help ( tmp1.second + 1, n, tmp2.second + 1, n, i );
      }
      else if ( dfn[x] <= dfn[rt] && dfn[rt] <= dfn[x] + siz[x] - 1 && ( dfn[rt] < dfn[y] || dfn[rt] > dfn[y] + siz[y] - 1 ) ) {
        // cout << "cy6\n";
        pii tmp = calc ( rt, x );
        help ( 1, tmp.first - 1, dfn[y], dfn[y] + siz[y] - 1, i );
        help ( tmp.second + 1, n, dfn[y], dfn[y] + siz[y] - 1, i );
      }
      else if ( ( dfn[rt] < dfn[x] || dfn[rt] > dfn[x] + siz[x] - 1 ) && rt == y ) {
        // cout << "cy7\n";
        help ( dfn[x], dfn[x] + siz[x] - 1, 1, n, i );
      }
      else if ( ( dfn[rt] < dfn[x] || dfn[rt] > dfn[x] + siz[x] - 1 ) && dfn[y] <= dfn[rt] && dfn[rt] <= dfn[y] + siz[y] - 1 ) {
        // cout << "cy8\n";
        pii tmp = calc ( rt, y );
        help ( dfn[x], dfn[x] + siz[x] - 1, 1, tmp.first - 1, i );
        help ( dfn[x], dfn[x] + siz[x] - 1, tmp.second + 1, n, i );
      }
      else if ( ( dfn[rt] < dfn[x] || dfn[rt] > dfn[x] + siz[x] - 1 ) && ( dfn[rt] < dfn[y] || dfn[rt] > dfn[y] + siz[y] - 1 ) ) {
        // cout << "cy9\n";
        help ( dfn[x], dfn[x] + siz[x] - 1, dfn[y], dfn[y] + siz[y] - 1, i );
      }
    }
  }
  sort ( q + 1, q + 1 + cnt, [] ( Node x, Node y ) { return x.l / B == y.l / B ? x.r < y.r : x.l < y.l; } );
  for ( int i = 1; i <= cnt; i ++ ) {
    while ( l < q[i].l ) {
      addl ( ++ l );
    }
    while ( l > q[i].l ) {
      dell ( l -- );
    }
    while ( r < q[i].r ) {
      addr ( ++ r );
    }
    while ( r > q[i].r ) {
      delr ( r -- );
    }
    res[q[i].id] += ans * q[i].op;
  }
  for ( int i = 1; i <= m; i ++ ) {
    if ( vis[i] ) {
      cout << res[i] << '\n';
    }
  }
}

signed main () {
#ifdef judge
  freopen ( "Code.in", "r", stdin );
  freopen ( "Code.out", "w", stdout );
  freopen ( "Code.err", "w", stderr );
#endif
  Solve ();
  return 0;
}
posted @ 2024-09-12 15:56  Alexande  阅读(15)  评论(0)    收藏  举报