2022 RoboCom 世界机器人开发者大赛-本科组(国赛)RC-u5 养老社区

题目大意是给定一棵树,每个节点有个权值。
之后在树上找到3个点,使他们两两之间距离相等,并且两两之间权值不同。
问三元组的个数

  • 首先因为是一棵树,所以我们可以通过广搜得到两两节点之间的距离。时间复杂度是O(\(n^2\)

接下来很容易想到直接找到两个权值不同的点(a, b),然后得出a b之间的距离d,然后再找到与a, b距离为d的点c,并且c的权值和a,b的权值都不同

后续的操作可以用bitset解决
B[a][d] 代表距离a节点长度为d的点的集合。
所以 B[a][d] & B[b][d] 即可得到距离a,b节点长度都为d的点的集合。
然后我们再定义一个bitset T,T[i]表示权值为i的所有点集合。
w[a]代表a点上的权值。
最后 (B[a][d] & B[b][d] & (~T[w[a]]) & (~T[w[b]]))代表的就是距离a,b节点长度为d,并且权值与a,b权值不同的点的集合。

答案即是上面公式所有代表的集合中点的个数之和。

很遗憾的是并不能开bitset<N>B[N][N]这么大的空间。

因此我们需要进行优化,将bitset降维

  • 第一点,将所有长度为d的二元组(a,b)存入vector中。
  • 第二点,将bitset降维后,长度这一维消失,B[a]即代表到a节点长度为d的点的集合,我们枚举vector中的边,修改bitset中的值。
  • 第三点,当a,b点的权值不同时,将集合(B[a] & B[b] & (~T[w[a]]) & (~T[w[b]]))点的个数加入到答案之中
  • 最后我们算出来的答案是需要除三,很明显就是某个三元组的三条边都枚举了,时间复杂度O(\(\frac{N}{w}n^2\))
AC代码
#include <iostream>
#include <cstring>
#include <algorithm>
#include <bitset>
#include <queue>
using namespace std;
typedef pair<int, int> PII;
typedef long long ll;
const int N = 2020, M = 2 * N;
int h[N], e[M], ne[M], idx;
int d[N][N], w[N];
int n;
bitset<N> T[N];
vector<PII> s[N];
void add(int a, int b) {
    e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
void bfs(int x) {
    queue<int> que;
    d[x][x] = 0;
    que.push(x);
    while (!que.empty()) {
        int u = que.front(); que.pop();
        for (int i = h[u]; ~i; i = ne[i]) {
            int j = e[i];
            if (d[x][j] > d[x][u] + 1) {
                d[x][j] = d[x][u] + 1;
                que.push(j);
            }
        }
    }
}
int main() {
    memset(h, -1, sizeof(h));
    memset(d, 0x3f, sizeof(d));
    scanf("%d", &n);
    for (int i = 1; i < n; i++) {
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b), add(b, a);
    }
    for (int i = 1; i <= n; i++) T[i] = ~T[i], bfs(i);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &w[i]);
        T[w[i]].set(i, false);
    }
    for (int i = 1; i <= n; i++) {
        for (int j = i + 1; j <= n; j++) {
            int c = d[i][j];
            s[c].push_back({i, j});
        }
    }
    ll res = 0;
    for (int i = 1; i <= n; i++) {
        bitset<N> B[N];
        for (int j = 0; j < s[i].size(); j++) {
            int a = s[i][j].first, b = s[i][j].second;
            B[a].set(b, true);
            B[b].set(a, true);
        }
        for (int j = 0; j < s[i].size(); j++) {
            int a = s[i][j].first, b = s[i][j].second;
            if (w[a] == w[b]) continue;
            res += (B[a] & B[b] & T[w[a]] & T[w[b]]).count();
        }
    }
    printf("%lld\n", res / 3);
    return 0;
}
posted @ 2022-08-10 15:52  什么都不会的娃娃  阅读(317)  评论(0编辑  收藏  举报