CF1709E 题解

这似乎是启发式合并的板子题。看到别人在做这道题,我也找来做了一下;不过就算是板子题,我也想了 10 分钟 /流汗。

首先有 \(\mathrm{dis}(x,y)=d_x \oplus d_y \oplus a_{\mathrm{lca}(x,y)}\),因此路径异或和等于 0 当且仅当 \(d_x \oplus d_y = a_{\mathrm{lca}(x,y)}\)。我们类似树形 dp 自下而上地考虑,每个点上的决策就是把它各个子树的信息合并起来。

这里的 \(d_x\) 是从根到 \(x\) 的路径上所有点权的异或和(包含根和 \(x\) 本身)。

我们注意到:如果对某个子树的根 \(rt\) 进行了操作,那么可以认为这个子树被“砍掉”了。因为我们完全可以把 \(a_{rt}\) 改成一个“神秘”的数(例如一个足够大、且与其他操作所选的数都不同的 2 的幂),这样子树内经过 \(rt\) 的路径都不会再产生贡献(不经过 \(rt\) 的子树内路径在处理子节点时已经保证合法);而外面的点想和子树内的点连成一条路径,也必须经过 \(rt\),同样不会有任何贡献。(这里的“贡献”指对树是否合法的影响。)

接下来考虑哪些子树需要操作。如果子树本身合法,那么不操作一定是最优的:即使之后子树内的点和外面的点连成了不合法的路径,改为在更上面的点操作也不会更差。反之,如果子树不合法,就必须在子树内某处操作,而在当前点(子树的根)操作能一次砍掉整个子树,因此是最优的。于是可以贪心地处理这个问题,每次只需判定子树内是否合法:对每个点维护一个 \(set\),存放子树内所有点的 \(d\);判定时,对每个子树检查它里面的点会不会和之前已合并的子树(以及当前点本身)中的点组成一条异或和为 0 的路径。用启发式合并(每次把小的 \(set\) 并入大的),总复杂度为 \(O(n\log^2 n)\)。

另外,如果操作了一个点,就要把它的 \(set\) 清空,理由见上:可以认为这个子树已经被砍掉了。

#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <cassert>
#include <utility>
#include <queue>
#include <stack>
#include <vector>
#include <map>
#include <set>
#include <bitset>
#include <random>

// Inclusive ascending loop: i = l, l+1, ..., r.
#define rep(i,l,r) for (int i = (l); i <= (r); i ++ )
// Inclusive descending loop: i = r, r-1, ..., l.
#define per(i,r,l) for (int i = (r); i >= (l); i -- )
// Prints "name=value" followed by a newline to cout.
#define debug(x) cout << #x << '=' << x << '\n'
// Begin/end iterator pair for algorithms, e.g. sort(all(v)).
// NOTE(review): the argument is not parenthesized inside the macro.
#define all(vc) vc.begin(), vc.end()
// Container size as a signed int.
#define SZ(x) ((int)(x).size())
// Short spellings of common standard-library names.
#define lwb lower_bound
#define upr upper_bound
#define pb push_back
// Rename x0/y0/x1/y1 so user variables with these names do not collide with
// library declarations (presumably the y0/y1 Bessel functions some <math.h>
// implementations declare; x0/x1 renamed for symmetry) — TODO confirm intent.
#define x0 __xx00
#define y0 __yy00
#define x1 __xx11
#define y1 __yy11
// pair member shorthands.
#define fi first
#define se second

using namespace std;

// Radix of BigInteger limbs (one decimal digit per limb).
const int Base = 10;
// Frequently used prime moduli.
const int mod1 = 1000000007;
const int mod2 = 1000000009;
const int mod3 = 998244353;
// Modulus that Mint is built on.
constexpr int md = 1000000007;

// Short type aliases.
using ULL = unsigned long long;
using LDB = long double;
using LL = long long;
using DB = double;
using PII = pair<int, int>;
using PLL = pair<LL, LL>;


// Modular inverse of x modulo y via the extended Euclidean algorithm.
// Returns the Bezout coefficient s with s * x == 1 (mod y); s is NOT
// normalized and may be negative, so callers reduce it themselves.
// Asserts that x != 0 and that gcd(x, y) == 1.
template <typename T>
T inv(const T& x, const T& y) {
    assert(x != 0);
    T prev_r = y, cur_r = x;   // remainder sequence
    T prev_s = 0, cur_s = 1;   // coefficients of x along that sequence
    while (cur_r != 0) {
        const T q = prev_r / cur_r;
        const T next_r = prev_r - q * cur_r;
        prev_r = cur_r;
        cur_r = next_r;
        const T next_s = prev_s - q * cur_s;
        prev_s = cur_s;
        cur_s = next_s;
    }
    // prev_r now holds gcd(x, y); an inverse exists only when it is 1.
    assert(prev_r == 1);
    return prev_s;
}

// Modular-arithmetic value type. T is a tag whose static member `value` is the
// modulus (this file uses std::integral_constant<int, md>). Every constructor
// and operator keeps the stored residue normalized to [0, mod()).
template <typename T>
class Modular {
public:
    // Integer type of the modulus and of the stored residue.
    using Type = typename decay<decltype(T::value)>::type;

    constexpr Modular() : value() {}
    // Implicit conversion from any arithmetic value, reduced into [0, mod()).
    template <typename U> Modular(const U& x) { value = normalize(x); }

    // Reduces x into [0, mod()). The % is skipped when x already lies in
    // [-mod(), mod()), where one conditional addition suffices.
    template <typename U>
    static Type normalize(const U& x) {
        Type v = static_cast<Type>((-mod() <= x && x < mod()) ? x : x % mod());
        if (v < 0) v += mod();
        return v;
    }

    // Raw residue access.
    const Type& operator()() const { return value; }
    // Explicit conversion of the residue to another type.
    template <typename U> explicit operator U() const { return static_cast<U>(value); }
    constexpr static Type mod() { return T::value; }

    // Both operands are normalized, so a single correction step keeps the
    // result in [0, mod()).
    Modular& operator+=(const Modular& other) {
        if ((value += other.value) >= mod()) value -= mod();
        return *this;
    }
    Modular& operator-=(const Modular& other) {
        if ((value -= other.value) < 0) value += mod();
        return *this;
    }
    // Mixed forms: convert the raw operand to Modular first.
    template <typename U> Modular& operator+=(const U& other) { return *this += Modular(other); }
    template <typename U> Modular& operator-=(const U& other) { return *this -= Modular(other); }
    Modular& operator++() { return *this += 1; }
    Modular& operator--() { return *this -= 1; }
    Modular operator++(int) {
        Modular result(*this);
        *this += 1;
        return result;
    }
    Modular operator--(int) {
        Modular result(*this);
        *this -= 1;
        return result;
    }
    Modular operator-() const { return Modular(-value); }

    // Multiplication; exactly one overload is enabled, chosen by Type.
    // Type == int: 64-bit product reduced modulo mod(). Under _WIN32 the
    // reduction uses x86 `divl` inline asm (GCC-style syntax; presumably
    // targets MinGW — MSVC does not accept it): edx:eax = product, the
    // remainder lands in edx (m).
    template <typename U = T>
    typename enable_if<is_same<typename Modular<U>::Type, int>::value, Modular>::type& operator*=(const Modular& rhs) {
#ifdef _WIN32
        uint64_t x = static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value);
        uint32_t xh = static_cast<uint32_t>(x >> 32), xl = static_cast<uint32_t>(x), d, m;
        asm(
            "divl %4; \n\t"
            : "=a"(d), "=d"(m)
            : "d"(xh), "a"(xl), "r"(mod()));
        value = m;
#else
        value = normalize(static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value));
#endif
        return *this;
    }
    // Type == long long: the quotient is estimated in long double and the
    // remainder is recovered from the 64-bit difference, then normalized.
    // NOTE(review): `value * rhs.value - q * mod()` relies on signed 64-bit
    // wraparound, which is formally undefined behavior.
    template <typename U = T>
    typename enable_if<is_same<typename Modular<U>::Type, long long>::value, Modular>::type& operator*=(const Modular& rhs) {
        long long q = static_cast<long long>(static_cast<long double>(value) * rhs.value / mod());
        value = normalize(value * rhs.value - q * mod());
        return *this;
    }
    // Non-integral Type: plain product, then normalize.
    template <typename U = T>
    typename enable_if<!is_integral<typename Modular<U>::Type>::value, Modular>::type& operator*=(const Modular& rhs) {
        value = normalize(value * rhs.value);
        return *this;
    }

    // Division = multiplication by the inverse from inv(), which asserts when
    // the divisor is 0 or not coprime with mod().
    Modular& operator/=(const Modular& other) { return *this *= Modular(inv(other.value, mod())); }

    // abs() of a residue is the residue itself.
    friend const Type& abs(const Modular& x) { return x.value; }
    // Free operators defined after the class need access to `value`.
    template <typename U> friend bool operator==(const Modular<U>& lhs, const Modular<U>& rhs);
    template <typename U> friend bool operator<(const Modular<U>& lhs, const Modular<U>& rhs);
    template <typename V, typename U> friend V& operator>>(V& stream, Modular<U>& number);

private:
    // Normalized residue in [0, mod()).
    Type value;
};

template <typename T> bool operator==(const Modular<T>& lhs, const Modular<T>& rhs) { return lhs.value == rhs.value; }
template <typename T, typename U> bool operator==(const Modular<T>& lhs, U rhs) { return lhs == Modular<T>(rhs); }
template <typename T, typename U> bool operator==(U lhs, const Modular<T>& rhs) { return Modular<T>(lhs) == rhs; }

template <typename T> bool operator!=(const Modular<T>& lhs, const Modular<T>& rhs) { return !(lhs == rhs); }
template <typename T, typename U> bool operator!=(const Modular<T>& lhs, U rhs) { return !(lhs == rhs); }
template <typename T, typename U> bool operator!=(U lhs, const Modular<T>& rhs) { return !(lhs == rhs); }

template <typename T> bool operator<(const Modular<T>& lhs, const Modular<T>& rhs) { return lhs.value < rhs.value; }

template <typename T> Modular<T> operator+(const Modular<T>& lhs, const Modular<T>& rhs) { return Modular<T>(lhs) += rhs; }
template <typename T, typename U> Modular<T> operator+(const Modular<T>& lhs, U rhs) { return Modular<T>(lhs) += rhs; }
template <typename T, typename U> Modular<T> operator+(U lhs, const Modular<T>& rhs) { return Modular<T>(lhs) += rhs; }

template <typename T> Modular<T> operator-(const Modular<T>& lhs, const Modular<T>& rhs) { return Modular<T>(lhs) -= rhs; }
template <typename T, typename U> Modular<T> operator-(const Modular<T>& lhs, U rhs) { return Modular<T>(lhs) -= rhs; }
template <typename T, typename U> Modular<T> operator-(U lhs, const Modular<T>& rhs) { return Modular<T>(lhs) -= rhs; }

template <typename T> Modular<T> operator*(const Modular<T>& lhs, const Modular<T>& rhs) { return Modular<T>(lhs) *= rhs; }
template <typename T, typename U> Modular<T> operator*(const Modular<T>& lhs, U rhs) { return Modular<T>(lhs) *= rhs; }
template <typename T, typename U> Modular<T> operator*(U lhs, const Modular<T>& rhs) { return Modular<T>(lhs) *= rhs; }

template <typename T> Modular<T> operator/(const Modular<T>& lhs, const Modular<T>& rhs) { return Modular<T>(lhs) /= rhs; }
template <typename T, typename U> Modular<T> operator/(const Modular<T>& lhs, U rhs) { return Modular<T>(lhs) /= rhs; }
template <typename T, typename U> Modular<T> operator/(U lhs, const Modular<T>& rhs) { return Modular<T>(lhs) /= rhs; }

// Binary exponentiation in Modular<T>: returns a^b. b must be a non-negative
// integer (asserted).
template <typename T, typename U>
Modular<T> qpow(const Modular<T>& a, const U& b) {
    assert(b >= 0);
    Modular<T> x = a, res = 1;
    // The exponent counter has the exponent's own type U. T is only the
    // modulus tag (e.g. std::integral_constant), which can neither be
    // constructed from b nor shifted, so `T p = b` could never compile.
    for (U p = b; p; x *= x, p >>= 1)
        if (p & 1) res *= x;
    return res;
}

// True when the stored residue is 0.
template <typename T> bool IsZero(const Modular<T>& number) { return number() == 0; }
// Decimal text of the stored residue.
template <typename T> string to_string(const Modular<T>& number) { return to_string(number()); }

// Writes the residue to `stream` and returns `stream` itself. Returning the
// result of `stream << number()` would yield a base-class reference (e.g.
// std::ostream&) that cannot bind to U& when U is a derived stream such as
// std::ostringstream or std::ofstream.
template <typename U, typename T> U& operator<<(U& stream, const Modular<T>& number) {
    stream << number();
    return stream;
}

// Reads an integer from `stream` and stores it reduced modulo mod().
// `x` is value-initialized, so a failed extraction stores 0 instead of
// normalizing an indeterminate value.
template <typename U, typename T>
U& operator>>(U& stream, Modular<T>& number) {
    typename common_type<typename Modular<T>::Type, long long>::type x{};
    stream >> x;
    number.value = Modular<T>::normalize(x);
    return stream;
}

// using ModType = int;
// struct VarMod { static ModType value; };
// ModType VarMod::value;
// ModType& md = VarMod::value;// for mod can change
// using Mint = Modular<VarMod>;

// Residues modulo md; the tag's value type follows md's own (decayed) type.
typedef Modular<std::integral_constant<decay<decltype(md)>::type, md>> Mint;

// Factorials and inverse factorials modulo md for 0..n, used for binomial
// coefficients C(n, k) and ordered selections A(n, k).
// Requires 0 <= n, and n! must be invertible modulo md (inv() asserts otherwise).
struct Fact {
    // Declared before the tables so the declaration order (which fixes the
    // real initialization order) matches the constructor's initializer list.
    const int n;
    vector<Mint> fact, factinv;
    Fact(const int& _n) : n(_n), fact(_n + 1, Mint(1)), factinv(_n + 1) {
        assert(n >= 0);
        for (int i = 1; i <= n; ++i) fact[i] = fact[i - 1] * i;
        // One modular inverse, then walk down: (i-1)!^{-1} = i!^{-1} * i.
        factinv[n] = inv(fact[n](), md);
        for (int i = n; i; --i) factinv[i - 1] = factinv[i] * i;
    }
    // C(n, k) = n! / (k! (n-k)!); 0 when k is outside [0, n].
    // n must not exceed the size the table was built for.
    Mint C(const int& n, const int& k) {
        if (n < 0 || k < 0 || n < k) return 0;
        assert(n <= this->n);
        return fact[n] * factinv[k] * factinv[n - k];
    }
    // A(n, k) = n! / (n-k)!; 0 when k is outside [0, n].
    // n must not exceed the size the table was built for.
    Mint A(const int& n, const int& k) {
        if (n < 0 || k < 0 || n < k) return 0;
        assert(n <= this->n);
        return fact[n] * factinv[n - k];
    }
};

// a = max(a, v), using only tp's operator>.
template <typename tp> inline void tomax(tp &a, tp v) {
    if (!(v > a)) return;
    a = v;
}
// a = min(a, v), using only tp's operator<.
template <typename tp> inline void tomin(tp &a, tp v) {
    if (!(v < a)) return;
    a = v;
}
// Fast integer reader from stdin (getchar). Skips every non-digit character;
// a '-' among the skipped characters makes the result negative. If EOF is
// reached before any digit, n is set to 0 instead of looping forever.
template <typename tp> inline void read(tp &n) {
    tp x = 0, f = 1;
    // int, not char: getchar() reports EOF as an out-of-band int value that a
    // char cannot represent reliably (it never compares equal where char is unsigned).
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = (x << 1) + (x << 3) + (ch ^ 48);  // x * 10 + digit
        ch = getchar();
    }
    n = x * f;
}
// Writes x in decimal to stdout via putchar (no separator or newline).
template<typename tp> inline void print(tp x) {
    if (x < 0) {
        putchar('-');
        // Peel off the last digit before negating: -x overflows for the most
        // negative value of a signed type, but -(x / 10) never does.
        tp last = -(x % 10);  // x % 10 lies in [-9, 0] for negative x
        x = -(x / 10);
        if (x) print(x);
        putchar(last + '0');
        return;
    }
    if (x > 9) print(x / 10);
    putchar(x % 10 + '0');
}
// Computes a^b mod p by binary exponentiation.
// Preconditions (unchecked): p >= 1, b >= 0. Intermediate products are formed
// in long long, so p must satisfy (p-1)^2 <= LLONG_MAX. A negative a can
// yield a negative (but congruent) result.
template<typename tp> inline tp qmi(tp a, tp b, tp p) {
    tp res = 1 % p;  // a^0 is 1 mod p, which is 0 when p == 1
    a %= p;          // reduce first so 1ll * a * a cannot overflow for large a
    while (b) {
        if (b & 1) res = 1ll * res * a % p;
        a = 1ll * a * a % p;
        b >>= 1;
    }
    return res;
}
// Greatest common divisor by the iterative Euclidean algorithm; gcd(a, 0) == a.
template<typename tp> inline tp gcd(tp a, tp b) {
    while (b) {
        const tp rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}
// Extended Euclid: returns (x, y) with a*x + b*y == gcd(a, b); (1, 0) when b == 0.
// Iterative form that carries both coefficient rows forward; it uses the same
// quotient sequence as the textbook recursion and so yields the same (x, y).
PLL exgcd(LL a, LL b) {
    LL x_prev = 1, y_prev = 0;  // coefficients of the current a
    LL x_cur = 0, y_cur = 1;    // coefficients of the current b
    while (b) {
        const LL q = a / b;
        LL tmp = a - q * b;
        a = b;
        b = tmp;
        tmp = x_prev - q * x_cur;
        x_prev = x_cur;
        x_cur = tmp;
        tmp = y_prev - q * y_cur;
        y_prev = y_cur;
        y_cur = tmp;
    }
    return {x_prev, y_prev};
}
typedef class BigInteger: public vector<LL> {
    public: 
        using vector<LL>:: vector;
        void shrink() { while(size()>1u&&!back())pop_back(); }
        friend BigInteger operator + (BigInteger a, BigInteger b) { int n=max(a.size(),b.size())+1;a.resize(n,0);b.resize(n,0);rep(j,0,n-1)if((a[j]+=b[j])>=Base){a[j]-=Base;a[j+1]+=1;}a.shrink();return a; }
        friend BigInteger operator / (BigInteger a, int b) { int n=a.size(),p=0;per(j,n-1,0){p=p*Base+a[j];a[j]=p/b;p%=b;}a.shrink();return a; }
        friend BigInteger operator * (BigInteger a, BigInteger b) { int n=a.size(),m=b.size();BigInteger c(n+m,0);rep(i,0,n-1)for(int j=0,s=i;j<m;j++,s++){c[s]+=a[i]*b[j];c[s+1]+=c[s]/Base;c[s]%=Base;}for(int i=1;i<n+m;i++){c[i]+=c[i-1]/Base;c[i-1]%=Base;}while(c.size()>1u&&!c.back())c.pop_back();return c; }
        friend istream& operator >> (istream& is, BigInteger& a) { string s;cin>>s;int n=s.size();per(j,n-1,0)a.pb(s[j]-'0');return is; }
        friend ostream& operator << (ostream& os, BigInteger& a) { int n=a.size();per(j,n-1,0)print(a[j]);return os; }
}BigInteger;

// Capacity for 1-indexed vertices (2e5 + slack).
constexpr int N = 200005;
int n;        // number of vertices
int a[N];     // a[v]: weight of vertex v
int ans;      // number of vertices whose weight gets changed
int dis[N];   // dis[v]: xor of weights on the path from the root to v
std::vector<int> g[N];  // adjacency lists
std::set<int> st[N];    // st[v]: dis-values of v's subtree still relevant above v
// Fills dis[v] = dis[parent(v)] ^ a[v] for every vertex below x, treating fa
// as x's parent. dis[x] itself must already be set by the caller.
// An explicit stack of (vertex, parent) pairs replaces the recursion.
void init(int x, int fa) {
    std::vector<std::pair<int, int>> pending;
    pending.push_back(std::make_pair(x, fa));
    while (!pending.empty()) {
        const int u = pending.back().first;
        const int parent = pending.back().second;
        pending.pop_back();
        for (int v : g[u]) {
            if (v == parent) continue;
            dis[v] = dis[u] ^ a[v];
            pending.push_back(std::make_pair(v, u));
        }
    }
}
// Post-order greedy over x's subtree (fa is x's parent).
// The path u..v whose lca is x has xor dis[u] ^ dis[v] ^ a[x], so it is 0
// iff dis[u] == a[x] ^ dis[v]. Each child's set is tested against everything
// already merged into st[x], including dis[x] itself (paths ending at x).
// If any zero path through x exists, x is changed: ans is incremented and
// st[x] is emptied, which cuts the subtree off from everything above.
void dfs(int x, int fa) {
    bool tag = false;
    st[x].insert(dis[x]);
    for (auto to : g[x]) {
        if (to == fa) continue;
        dfs(to, x);
        // Once a conflict is found, x is changed and st[x] will be cleared,
        // so further checks or merges cannot affect the result.
        if (!tag) {
            // Small-to-large: iterate over the smaller set (the swap is O(1)).
            if (st[x].size() < st[to].size()) swap(st[x], st[to]);
            for (auto v : st[to])
                if (st[x].count(a[x] ^ v)) { tag = true; break; }
            if (!tag)
                for (auto v : st[to]) st[x].insert(v);
        }
        // st[to] is never read again; free its nodes now rather than keeping
        // O(n log n) dead set nodes alive until the program exits.
        st[to].clear();
    }
    ans += tag;
    if (tag) st[x].clear();
}


void solve() {
    cin >> n;
    rep(i, 1, n) cin >> a[i];
    rep(i, 2, n) {
        int u, v; cin >> u >> v;
        g[u].pb(v), g[v].pb(u);
    }
    dis[1] = a[1];
    init(1, -1);
    dfs(1, -1);
    cout << ans << '\n';
}


int main() {
    // Unsynced, untied iostreams for speed. After this, C stdio helpers such as
    // read()/print() must not be mixed with cin/cout.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    // Single test case; read T from input instead for multi-test variants.
    int T = 1;
    // cin >> T;
    for (int t = 0; t < T; ++t) solve();

    return 0;
}
posted @ 2025-08-03 17:53  v1ne0qrs  阅读(9)  评论(0)    收藏  举报