[bzoj4569][SCOI2016]萌萌哒-并查集+倍增

Brief Description

一个长度为n的大数,用S1S2S3...Sn表示,其中Si表示数的第i位,S1是数的最高位,告诉你一些限制条件,每个条
件表示为四个数,l1,r1,l2,r2,即两个长度相同的区间,表示子串Sl1Sl1+1Sl1+2...Sr1与Sl2Sl2+1Sl2+2...S
r2完全相同。比如n=6时,某限制条件l1=1,r1=3,l2=4,r2=6,那么123123,351351均满足条件,但是12012,13
1141不满足条件,前者数的长度不为6,后者第二位与第五位不同。问满足以上所有条件的数有多少个。

Algorithm Design

朴素的想法是使用并查集合并相同位置元素, 那么设集合的个数是\(cnt\), 那么答案就是
\[9*10^{cnt-1}\].
这样的话有许多重复计算, 观察数据范围, 我们发现我们大致需要一个\(O(mlogn)\)的算法, 然后, 首先想到了使用线段树, 但是线段树分割不灵活, 可以证明, 复杂度仍然是\(O(mn)\)的, 与朴素算法无异, 而且这个上界比较紧. 考虑使用类似ST表的算法, 令\(f[i][j]\)表示从i开始长度为\(2^j\)的区间, 那么每一个限制都可以拆成两个区间分别merge. 我们把这些f按照j的大小建树, 某个节点的两个孩子拼起来可以得到自身.这样的话可以证明复杂度是\(O(mlogn)\)的.
操作处理完了之后, 我们以区间长度为关键字从大到小依次考察每个区间, 如果他并不是一个独立的区间, 那么我们把对于他的限制下放到他的两个儿子那里去, 直到区间长度为1, 然后统计一下集合的个数就好了.
时间复杂度\(O(nlogn\alpha(n)+m\alpha (n))\)

Code

#include <cstdio>
const int mod = 1e9 + 7;
const int maxn = 1e5 + 1e2;
int n, m, cnt;
int f[maxn][18], log_2[maxn], lc[maxn * 18], rc[maxn * 18], fa[maxn * 18];
void init() {
  cnt = 0;
  for (int i = 1; i <= n; i++) {
    if ((1 << (cnt + 1)) <= i)
      cnt++;
    log_2[i] = cnt;
  }
  cnt = 0;
  for (int j = 0; (1 << j) <= n; j++) {
    for (int i = 1; i + (1 << (j)) - 1 <= n; i++) {
      f[i][j] = ++cnt;
      if (j) {
        lc[cnt] = f[i][j - 1];
        rc[cnt] = f[i + (1 << (j - 1))][j - 1];
      }
    }
  }
  for (int i = 1; i <= cnt; i++)
    fa[i] = i;
}
int find(int x) { return fa[x] == x ? x : fa[x] = find(fa[x]); }
void merge(int a, int b) {
  int x = find(a), y = find(b);
  if (x != y)
    fa[x] = y;
}
int main() {
  scanf("%d %d", &n, &m);
  init();
  while (m--) {
    int s1, s2, e1, e2;
    scanf("%d %d %d %d", &s1, &e1, &s2, &e2);
    int len = log_2[e1 - s1 + 1];
    merge(f[s1][len], f[s2][len]);
    merge(f[e1 - (1 << len) + 1][len], f[e2 - (1 << len) + 1][len]);
  }
  int t;
  for (int i = cnt; i > n; i--) {
    if ((t = find(i)) != i) {
      merge(lc[i], lc[t]);
      merge(rc[i], rc[t]);
    }
  }
  t = 0;
  for (int i = 1; i <= n; i++) {
    t += ((find(i)) == i);
  }
  int ans = 9;
  for (int i = 1; i < t; i++)
    ans = (1ll * ans * 10) % mod;
  printf("%d\n", ans);
}

posted on 2017-03-22 21:14 蒟蒻konjac 阅读(...) 评论(...) 编辑 收藏

统计