树上启发式合并


树上启发式合并

解决树上统计问题,时间复杂度为 $O(n \log n)$,可以结合线段树等数据结构维护深度上的信息

博客


入门题

// Capacity of every per-node array in this program.
const int maxn = 1e5 + 7;
// Not referenced anywhere in the visible code of this program.
const int mod = 1e9 + 7;

// n: node count read in main; u, v: endpoints main reads each edge into.
// mx: highest per-colour count seen by count() since df last reset it;
// sum: sum of the colours whose count equals mx.
// m is not referenced in the visible code.
ll n, m, u, v, mx, sum;
// mp[x]: neighbours of node x; main stores every edge in both directions.
vector<int> mp[maxn];
// sz[x]: size of x's subtree; son[x]: child of x with the largest subtree
// (stays 0 when x has no child). Both are filled in by dfs.
int sz[maxn], son[maxn];
// Computes sz[] and son[] for the subtree rooted at u.
// fa is u's parent; main passes 0 for the root.
void dfs(int u, int fa)
{
    sz[u] = 1;
    for (const int child : mp[u]) {
        if (child != fa) {
            dfs(child, u);
            sz[u] += sz[child];
            // Strict comparison: on a tie the earlier child stays the heavy one.
            if (sz[child] > sz[son[u]])
                son[u] = child;
        }
    }
}
// col[x]: colour of node x (read in main); cnt[c]: how many currently counted nodes have colour c;
// ans[x]: value of sum recorded for node x in df; flag: child that count/clear skip (0 skips nothing
// because main numbers nodes from 1).
ll col[maxn], cnt[maxn], ans[maxn], flag;
// Adds the colour of every node in u's subtree to cnt, skipping the child
// equal to `flag`, and keeps mx / sum up to date after each increment.
// fa is u's parent and is never descended into.
void count(int u, int fa)
{
    const auto colour = col[u];
    const auto seen = ++cnt[colour];
    if (seen == mx) {
        // Another colour tied for the maximum: add it to the answer.
        sum += colour;
    } else if (seen > mx) {
        // New strict maximum: the answer restarts from this colour alone.
        mx = seen;
        sum = colour;
    }
    for (const int child : mp[u]) {
        if (child != fa && child != flag)
            count(child, u);
    }
}
// Undoes count(): decrements cnt for the colour of every node in u's subtree,
// skipping the parent fa and the child equal to `flag`.
// mx and sum are not touched here; df resets them afterwards.
void clear(int u, int fa)
{
    --cnt[col[u]];
    for (const int child : mp[u]) {
        if (child == fa)
            continue;
        if (child == flag)
            continue;
        clear(child, u);
    }
}
// Small-to-large pass that fills ans[u] for every node in u's subtree.
// fa is u's parent. When kp is false, everything this call added to cnt is
// removed again and mx / sum are reset before returning.
void df(int u, int fa, bool kp)
{
    const int heavy = son[u];

    // Light children first; each one cleans up after itself.
    for (const int child : mp[u]) {
        if (child != fa && child != heavy)
            df(child, u, false);
    }

    // The heavy child keeps its counts, so count() below must skip it.
    if (heavy != 0) {
        df(heavy, u, true);
        flag = heavy;
    }

    count(u, fa);
    flag = 0;
    ans[u] = sum;

    if (kp)
        return;
    clear(u, fa);
    mx = 0;
    sum = 0;
}

// Reads n, the n colours and the n - 1 edges, runs both passes from node 1
// and prints ans[1..n] separated by spaces.
int main()
{
    cin >> n;
    for (int node = 1; node <= n; ++node) {
        cin >> col[node];
    }
    for (int edge = 1; edge < n; ++edge) {
        cin >> u >> v;
        // Undirected edge: record it on both endpoints.
        mp[u].pb(v);
        mp[v].pb(u);
    }

    dfs(1, 0);
    df(1, 0, 1);

    for (int node = 1; node <= n; ++node) {
        cout << ans[node] << " ";
    }
    return 0;
}

结合线段树维护深度上的最值

// Midpoint of the range stored in segment-tree node k (needs `k` and `tr` in scope).
#define mid ((tr[k].l + tr[k].r) >> 1)
// Indices of the left / right child of node k. The bodies are parenthesised
// so they keep their value inside larger expressions such as `ls + 1`.
#define ls (k << 1)
#define rs (k << 1 | 1)

// Capacity of every per-node array in this program.
const int maxn = 1e5 + 7;
// Not referenced anywhere in the visible code of this program.
const int mod = 1e9 + 7;

// n: node count of the current test; t: remaining test cases;
// flag: child that count/clear skip; mxdep: largest depth reached by dfs.
int n, t, flag, mxdep;
// sz[x]: subtree size; son[x]: child with the largest subtree (0 if none);
// dep[x]: depth of x (main starts the root at depth 1); val[x]: value read for node x.
int sz[maxn], son[maxn], dep[maxn], val[maxn];
// sum[d]: total of val over all nodes at depth d; ans[x]: query() result recorded for node x in df.
ll sum[maxn], ans[maxn];
// mp[x]: children of x (main appends node i to the list of the parent it reads).
vector<int> mp[maxn];
// Visits u at depth dp: records dep[u], updates mxdep, adds val[u] to the
// per-depth total sum[dp], and computes sz[] / son[] for the subtree.
void dfs(int u, int dp)
{
    dep[u] = dp;
    sz[u] = 1;
    if (mxdep < dp)
        mxdep = dp;
    sum[dp] += val[u];

    for (const int child : mp[u]) {
        dfs(child, dp + 1);
        sz[u] += sz[child];
        // Strict comparison: on a tie the earlier child stays the heavy one.
        if (sz[son[u]] < sz[child])
            son[u] = child;
    }
}
// Segment-tree node covering the depth range [l, r]; w is the maximum of the
// leaf values in that range (see build / change). tr[1] is the root.
struct node {
    ll l, r, w;
} tr[maxn << 2];
// Builds node k of the max segment tree over depths [l, r]; leaf d starts
// with the per-depth total sum[d].
void build(int k, int l, int r)
{
    auto& cur = tr[k];
    cur.l = l;
    cur.r = r;
    if (l == r) {
        cur.w = sum[l];
        return;
    }

    const int half = (l + r) >> 1;
    const int left = k << 1;
    const int right = left | 1;
    build(left, l, half);
    build(right, half + 1, r);
    cur.w = max(tr[left].w, tr[right].w);
}
// Point update: SUBTRACTS x from the leaf for depth pl (pass a negative x to
// add it back) and recomputes the maxima on the path up to node k.
void change(int k, int pl, int x)
{
    auto& cur = tr[k];
    if (cur.l == cur.r && cur.l == pl) {
        cur.w -= x;
        return;
    }

    const int left = k << 1;
    const int right = k << 1 | 1;
    const auto half = (cur.l + cur.r) >> 1;
    change(pl <= half ? left : right, pl, x);
    cur.w = max(tr[left].w, tr[right].w);
}
// Returns the value stored at the segment-tree root, i.e. the maximum over
// the whole depth range the tree was built on.
ll query()
{
    const auto& root = tr[1];
    return root.w;
}
// Resets the state left by the previous test case: per-node data for nodes
// 1..n and segment-tree slots 0..4n. main calls it right after reading n.
void init()
{
    mxdep = 0;
    for (int node = 1; node <= n; ++node) {
        if (!mp[node].empty())
            mp[node].clear();
        sum[node] = 0;
        dep[node] = 0;
        son[node] = 0;
        sz[node] = 0;
    }

    const int last = n * 4;
    for (int idx = 0; idx <= last; ++idx) {
        tr[idx].w = 0;
        tr[idx].r = 0;
        tr[idx].l = 0;
    }
}
// Removes the values of u's subtree from the per-depth totals held in the
// segment tree (change() subtracts), skipping the child equal to `flag`.
void count(int u)
{
    change(1, dep[u], val[u]);
    for (const int child : mp[u]) {
        if (child == flag)
            continue;
        count(child);
    }
}
// Undoes count(): puts the values of u's subtree back into the per-depth
// totals (a negated amount is passed to change()), skipping the child equal to `flag`.
void clear(int u)
{
    change(1, dep[u], -val[u]);
    for (const int child : mp[u]) {
        if (child == flag)
            continue;
        clear(child);
    }
}
// Small-to-large pass. For every node in u's subtree it records in ans[] the
// largest per-depth total that remains once that node's subtree is removed.
// When kep is 0 the removal of u's subtree is rolled back before returning.
void df(int u, int kep)
{
    const int heavy = son[u];

    // Light children first; each one restores the tree when it finishes.
    for (const int child : mp[u]) {
        if (child != heavy)
            df(child, 0);
    }

    // The heavy child's subtree stays removed, so count() below must skip it.
    if (heavy != 0) {
        df(heavy, 1);
        flag = heavy;
    }

    count(u);
    ans[u] = query();

    if (kep)
        return;
    // Skip nothing while restoring: the whole subtree has to go back.
    flag = 0;
    clear(u);
}

// For each of t test cases: reads n, the n values and the parent of every
// node 2..n, then prints ans[2..n], one value per line.
int main()
{
    ioss;
    cin >> t;
    while (t--) {
        cin >> n;
        init();
        for (int node = 1; node <= n; ++node) {
            cin >> val[node];
        }
        for (int node = 2; node <= n; ++node) {
            int parent;
            cin >> parent;
            mp[parent].pb(node);
        }

        dfs(1, 1);
        build(1, 1, mxdep);
        df(1, 1);

        for (int node = 2; node <= n; ++node) {
            cout << ans[node] << "\n";
        }
    }
    return 0;
}


按层维护26个字母的数量

// Capacity of every per-node / per-query array in this program.
const int maxn = 5e5 + 7;
// Not referenced anywhere in the visible code of this program.
const int mod = 1e9 + 7;

// n: node count; q: query count; flag: child that count/clear skip;
// mxdep: largest depth reached by dfs. t is not referenced in the visible code.
int n, q, t, flag, mxdep;
// sz[x]: subtree size; son[x]: child with the largest subtree (0 if none); dep[x]: depth of x;
// ques[0][i] / ques[1][i]: node and depth of the i-th query, kept for output order.
int sz[maxn], son[maxn], dep[maxn], ques[2][maxn];
// val[x]: letter of node x; main reads the string starting at index 1.
char val[maxn];
// mp[x]: children of x; que[x]: depths asked about for node x.
vector<int> mp[maxn], que[maxn];
// cnt[d][c]: number of currently counted nodes at depth d whose letter is 'a' + c.
int cnt[maxn][40];
// Maps (node, depth) to the result of check() for that query.
// NOTE(review): pii is presumably pair<int, int> — defined outside this excerpt.
map<pii, int> ans;
// Visits u at depth dp: records dep[u], updates mxdep and computes
// sz[] / son[] for the subtree rooted at u.
void dfs(int u, int dp)
{
    dep[u] = dp;
    sz[u] = 1;
    if (mxdep < dp)
        mxdep = dp;

    for (const int child : mp[u]) {
        dfs(child, dp + 1);
        sz[u] += sz[child];
        // Strict comparison: on a tie the earlier child stays the heavy one.
        if (sz[son[u]] < sz[child])
            son[u] = child;
    }
}
// Adds the letter of every node in u's subtree to cnt at that node's depth,
// skipping the child equal to `flag`.
void count(int u)
{
    const int letter = val[u] - 'a';
    ++cnt[dep[u]][letter];
    for (const int child : mp[u]) {
        if (child == flag)
            continue;
        count(child);
    }
}
// Undoes count(): removes the letter of every node in u's subtree from cnt,
// skipping the child equal to `flag`.
void clear(int u)
{
    const int letter = val[u] - 'a';
    --cnt[dep[u]][letter];
    for (const int child : mp[u]) {
        if (child == flag)
            continue;
        clear(child);
    }
}
// Reports whether the letters currently counted at depth dp can be arranged
// into a palindrome, i.e. at most one letter occurs an odd number of times.
bool check(int dp)
{
    int odd = 0;
    // count/clear index cnt by val - 'a', so columns 0..25 are the whole
    // alphabet; the old bound `i <= 26` also read a 27th column.
    // NOTE(review): assumes val holds only 'a'..'z' — confirm against the input format.
    for (int letter = 0; letter < 26; letter++)
        odd += cnt[dp][letter] % 2;
    return odd <= 1;
}
// Small-to-large pass. Once cnt holds exactly the nodes of u's subtree, every
// depth queried for u (que[u]) is answered and stored in ans.
// When kep is 0 the subtree's contribution is removed again before returning.
void df(int u, int kep)
{
    const int heavy = son[u];

    // Light children first; each one cleans up after itself.
    for (const int child : mp[u]) {
        if (child != heavy)
            df(child, 0);
    }

    // The heavy child keeps its counts, so count() below must skip it.
    if (heavy != 0) {
        df(heavy, 1);
        flag = heavy;
    }

    count(u);
    for (const int depth : que[u]) {
        const bool ok = check(depth);
        ans[{ u, depth }] = ok;
    }

    if (kep)
        return;
    // Skip nothing while clearing: the whole subtree has to be removed.
    flag = 0;
    clear(u);
}
int main()
{
    ioss;
    cin >> n >> q;
    for (int i = 2, v; i <= n; i++) {
        cin >> v;
        mp[v].pb(i);
    }
    cin >> val + 1;
    dfs(1, 1);
    for (int i = 1, a, b; i <= q; i++) {
        cin >> a >> b;
        ques[0][i] = a, ques[1][i] = b;
        que[a].pb(b);
    }
    df(1, 1);
    for (int i = 1; i <= q; i++) {
        if (ans[{ ques[0][i], ques[1][i] }])
            cout << "Yes" << endl;
        else
            cout << "No" << endl;
    }
    return 0;
}
posted @ 2020-11-29 20:18  naymi  阅读(31)  评论(0)    收藏  举报