2025牛客寒假算法基础集训营2补题笔记

题目难度大致顺序为:\(A、B、F、G、J、K、D、H、M、E、C、I、L\)

\(easy\)\(A、B、F、G、J、K、D\)

\(mid\)\(H、M、E、C\)

\(hard\)\(I、L\)

这场前期打的挺顺的,直到进入 \(mid\) 难度,C题卡了一个多钟,中途代码错了,但我误以为思路错了导致多花半小时debug,然后卡在H题,补完感觉就是个结论题,赛时不知道就只能靠猜或打表找规律。后面又补了M,感觉比H、C要简单,当然H知道结论就秒了,但不知道的我为什么要死犟在H题啊!!!(而且M题过题人很少,让我误以为很难很难)E题补了2小时,主要我推了一个复杂的公式,看完题解我只能说我还是菜。。。

A.一起奏响历史之音!

题意

给定7个整数表示音节序列,问是否是仅由五声音调组成,五声语调:1、2、3、5、6。

思路

简单的输入输出题,按题意模拟判断一下就行。

代码

点击查看代码
void solve()
{
  for (int i = 1; i <= 7; i ++) {
  	int x;
  	cin >> x;
  	if (x == 4 || x == 7) return void(cout << "NO");
  }
  cout << "YES";
}

B.能去你家蹭口饭吃吗

题意

给定一组数组,询问最大一个整数比一半的数组元素更小。

思路

二分的模板题

但可以用排序做,将数组升序排序,然后输出中位数再减一。

代码

点击查看代码
void solve()
{
  cin >> n;
  for (int i = 0; i < n; i ++) cin >> a[i];
  sort(a, a + n);
  
  cout << a[n >> 1] - 1;
}

F.一起找神秘的数!

题意

给一个区间 \([l, r]\),问有多少对不同的整数 \(x,y\) 满足 \(x + y = (x \ or \ y) + (x \ and \ y) + (x \ xor \ y)\)

思路

这里涉及到数学性质:\(x + y = (x \ or \ y) + (x \ and \ y)\)

为什么呢?
因为 \(x \ or \ y\) 可以得到两个整数相加去掉有进位后的数, \(x \ and \ y\) 得到两个整数相加有进位的数。
如:

\[11 + 5 = 1011 + 101 \\ 11 \ | \ 5:1111 \\ 11 \ \& \ 5:0001 \]

所以要满足 \(x + y = (x \ or \ y) + (x \ and \ y) + (x \ xor \ y)\),就必须 \(x \ xor \ y = 0\),此时只有 \(x = y\)才可行。

代码

点击查看代码
void solve()
{
  ll l, r;
  cin >> l >> r;
  
  cout << r - l + 1 << '\n';
}

G.一起铸最好的剑!

题意

给两个整数 \(n, m\),每次可以让 \(m\)\(m\) 倍,问多少次可以 \(m\) 最接近 \(n\)

思路

暴力模拟。

特判 \(m = 1\) 的情况,其他情况让 \(m\) 翻倍到一个数最接近 \(n\) 就行。

这题我愚蠢的没有想到m大于n的情况,wa3发。

代码

点击查看代码
void solve()
{
  cin >> n >> m;
  int ans = 1;
  ll x = m;
  if (m == 1 || m >= n) return void(cout << 1 << '\n');
  while (m * x <= n) {
  	ans ++;
  	m *= x;
  }
  if (m == n) cout << ans << '\n';
    else {
        ll d = n - m;
        m *= x;
        ll dd = m - n;
        if (d <= dd) cout << ans << '\n';
        else cout << ans + 1 << '\n';
    }
}

J.数据时间?

题意

统计某年某月内三个时段下各有多少人登录。

思路

按题意模拟就好,区分时段可以将时间转化为秒。

代码

点击查看代码
#include <iostream>
#include <map>

using namespace std;

int n, h, m;
map<pair<string, int>, bool> st;
int ans[5];

void solve()
{
  cin >> n >> h >> m;
  while (n --)
  {
  	string id;
  	int nian, yue, ri, shi, fen, miao;
  	cin >> id;
  	scanf("%d-%d-%d %d:%d:%d", &nian, &yue, &ri, &shi, &fen, &miao);
  	
  	if (nian != h || yue != m) continue;
  	
  	int time = shi * 3600 + fen * 60 + miao, x = 0;
  	if (time >= 25200 && time <= 32400) x = 1;
    else if (time >= 64800 && time <= 72000) x = 1;
	else if (time >= 39600 && time <= 46800) x = 2;
	else if (time >= 79200 || time <= 3600) x = 3;
  	
  	if (st.count({id, x})) continue;
      st[{id, x}] = 1;
      ans[x] ++;
  }
  
  for (int i = 1; i <= 3; i ++) cout << ans[i] << ' ';
}

int main()
{
  int t = 1;
  while (t --) solve();
  return 0;
}

K.可以分开吗?

题意

给一个 \(n \times m\) 的0-1矩阵,问最少的 \(1\) 连通块相邻 \(0\) 的个数。

思路

遇见一个\(1\) 连通块就\(bfs\) 搜索一下就行,感觉没有好说的,很板的搜索题吧。

代码

点击查看代码
#include <iostream>
#include <queue>
#include <map>

#define PII         pair<int, int> 
#define fi          first
#define se          second

using namespace std;

const int INF = 0x3f3f3f3f;
const int dx[] = {0, 0, 1, -1};
const int dy[] = {-1, 1, 0, 0};
const int M = 1000 + 10;

int n, m;
string g[M];
bool st[M][M];

int bfs(int x, int y)
{
	queue<PII> q;
	q.push({x, y});
	st[x][y] = 1;
	int res = 0;
	map<PII, bool> vis;
	
	while (!q.empty())
	{
		auto t = q.front();
		q.pop();
		
		for (int i = 0; i < 4; i ++) {
			int sx = t.fi + dx[i], sy = t.se + dy[i];
			if (sx < 0 || sx >= n || sy < 0 || sy >= m) continue;
			if (g[sx][sy] == '0' && !vis.count({sx, sy})) {
				res ++;
				vis[{sx, sy}] = 1;
			}
		}
		
		for (int i = 0; i < 4; i ++)
		{
			int xx = t.fi + dx[i], yy = t.se + dy[i];
			if (xx < 0 || xx >= n || yy < 0 || yy >= m) continue;
			if (g[xx][yy] == '0') continue;
			if (st[xx][yy]) continue;
			
			st[xx][yy] = 1;
			q.push({xx, yy});
		}
	}
	return res;
}

void solve()
{
  cin >> n >> m;
  for (int i = 0; i < n; i ++) cin >> g[i];
  
  int ans = INF;
  for (int i = 0; i < n; i ++) 
  	for (int j = 0; j < m; j ++) 
  		if (g[i][j] == '1' && !st[i][j]) ans = min(ans, bfs(i, j));
  
  cout << ans;
}

int main()
{
  ios::sync_with_stdio(false);
  cin.tie(0), cout.tie(0);

  int t = 1;
  while (t --) solve();
  return 0;
}

D.字符串里串

题意

给定长度为 \(n\) 的字符串,询问最大长度 \(k\) 使长度为 \(k\) 的非空子串和长度为 \(k\) 的不连续且非空的子序列相等。

思路

首先,如果存在更大的长度满足要求,那么较小的长度也会存在,这样可以存在一个结果区间左部分是正确的右部分是错误,就满足了二分性。

所以,我们只需要二分一个最大长度满足要求就可。

最后一个点是如何判断非空子串和不连续且非空的子序列相等,可以发现如果非空子串的首位或末尾存在一个可替换的情况就可构造一个不连续且非空的子序列相等。

代码

点击查看代码
#include <iostream>
#include <algorithm>
#include <map>
#include <vector>

#define pk          push_back
#define si(x)       int(x.size())

using namespace std;

int n;
string s;
map<char, vector<int>> mp;

bool check(int k)
{
	for (int i = 0; i < n - k + 1; i ++) 
		if (mp[s[i + k - 1]].back() > i + k - 1) return true;
		else if (mp[s[i]][0] < i) return true;
	return false;
}

void solve()
{
  cin >> n >> s;
  for (int i = 0; i < si(s); i ++) mp[s[i]].pk(i);
  
  int l = 2, r = n - 1;
  while (l < r)
  {
  	int mid = (l + r + 1) >> 1;
  	if (check(mid)) l = mid;
  	else r = mid - 1;
  }
  if (check(l)) cout << l;
  else cout << 0;
}

int main()
{
  ios::sync_with_stdio(false);
  cin.tie(0), cout.tie(0);

  int t = 1;
  while (t --) solve();
  return 0;
}

H.一起画很大的圆!

题意

在二维平面上给一个矩形,求矩形边界上的三个点可以画出一个最大的圆。

思路

首先,在矩形的一条边上是无法选择三个点画圆的。

接着,是找一个可以评定的标准,我想到无论是哪个圆,必然圆内有一条弦所在直线可以经过对角线。那么就以对角线所在直线为标准,在这条直线上截取线段为弦,并以一个角为固定点,找到另两个整数点所画的圆是否是最大的。

正确答案的数学结论是当三点共线时,可以看作是画出了一个半径为无穷大的圆。所以,当给定的三个点越接近共线,绘制出的圆也就越大。

代码

点击查看代码
void solve()
{
  cin >> a >> b >> c >> d;
  
  int ba = b - a;
  int dc = d - c;
    
    cout << a << ' ' << c << '\n';
    if (ba > dc) {
        cout << b << ' ' << c + 1 << '\n';
        cout << a + 1 << ' ' << c << '\n';
        
    }
    else {
        cout << a << ' ' << c + 1 << '\n';
        cout << a + 1 << ' ' << d << '\n';
    }
}

M.那是我们的影子

题意

有一个 \(3\)\(n\) 列的异形数独,满足以下规则:

  • 每一个单元格都需要填入 \(1\)\(9\) 之间的整数;
  • 任意一个 \(3×3\) 的子矩阵中都不包含重复的数字;

现在已经填入部分数字,求一共可以构造多少个合法方案。

思路

思维,组合数。

因为任意一个 \(3×3\) 子矩阵都不包含重复的数字,这个限制就表明在第 \(i\) 列出现的数字只能在 \(i\) + (3的倍数)列上出现,不然就会在某个 \(3×3\) 子矩阵中重复,所以每一列的数字都是固定,只是在这一列的顺序是不固定的。

有了上面的结论,我们思考如何保证初始状态的合法性及每列数字是什么。

  • \(set\) 数组记录对应列上的数字,如果某一列数字个数超过了3个,就是不合法的;
  • \(set\) 数组记录每个数字所在的列,如果某个数字在两个不同的列,就是不合法的;
  • 检查某一列中数字是否重复出现;

用以上方法排除掉不合法状态,同时也存下每列数字,考虑解决填充方案。

  • 首先确定有多少个数字是未确定填充在哪一列上的,记为 \(U\)
  • 然后,通过确定前两列的数字就可以确定 \(3×3\) 子矩阵的分布。
    • 假设第一列需填入 \(x_1\) 个数字,那么就是从 \(U\) 中任选 \(x_1\) 个数字,即 \(C_{U}^{x_i}\)
    • 假设第二列需填入 \(x_2\) 个数字,那么就是从 \(U - x_1\) 中任选 \(x_2\) 个数字,即 \(C_{U-x_1}^{x_2}\)
  • 确定 \(3×3\) 子矩阵每一列数字的组合后,就通过每一列上 \(?\) 个数确定当前列的排列方案数。

代码

点击查看代码
#include <iostream>
#include <vector>
#include <set>

using namespace std;

typedef long long ll;

const int mod = 1e9 + 7;
const int N = 1e5 + 10;

ll qmi(ll a, ll b)
{
    ll res = 1;
    while (b)
    {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}

ll fact[N], infact[N];

void init()
{
    fact[0] = infact[0] = 1;
    for (int i = 1; i < N; i ++) fact[i] = fact[i - 1] * i % mod;
    infact[N - 1] = qmi(fact[N - 1], mod - 2);
    for (int i = N - 2; i; i --) infact[i] = infact[i + 1] * (i + 1) % mod;
}

ll C(int a, int b)
{
    if (b < 0 || a < b) return 0;
    return fact[a] * infact[b] % mod * infact[a - b] % mod;
}

int n;
string s[5];

void solve()
{
    cin >> n;
    for (int i = 1; i <= 3; i ++) {
        cin >> s[i];
        s[i] = ' ' + s[i];
    }
    
    set<int> v[3], num[10];
    for (int i = 1; i <= 3; i ++) {
        for (int j = 1; j <= n; j ++)
            if (s[i][j] != '?') {
                v[j % 3].insert(s[i][j] - '0');
                num[s[i][j] - '0'].insert(j % 3);
                
                // 当前列的数字个数超过了3个
                if (v[j % 3].size() > 3) return void(cout << 0 << '\n');
                // 这个数字所处的列个数超过1个
                if (num[s[i][j] - '0'].size() > 1) return void(cout << 0 << '\n');
            }
    }
        
    vector<ll> f(n + 1);
    for (int i = 1; i <= n; i ++) {
        vector<bool> vis(10, false);
        for (int j = 1; j <= 3; j ++)
            if (s[j][i] != '?') {
                // 如果数字重复出现
                if (vis[s[j][i] - '0']) return void(cout << 0 << '\n');
                vis[s[j][i] - '0'] = 1;
            }  else f[i] ++;
    }
    
    int used = 0;
    for (int i = 0; i < 3; i ++) used += 3 - v[i].size();
    
    ll ans = C(used, 3 - v[0].size()) * C(used - 3 + v[0].size(), 3 - v[1].size()) % mod;
    for (int i = 1; i <= n; i ++) ans = ans * fact[f[i]] % mod;
    cout << ans << '\n';
}

int main()
{
    init();
    
    int t;
    cin >> t;
    while (t --) solve();
    return 0;
}

C.字符串外串

题意

\(D\) 的扩展。

询问是否存在长度为 \(n\) 的字符串,使得字符串中有最大长度为 \(m\) 的非空子串和不连续且非空的子序列相等。

思路

思维,构造。

首先明确当 \(n = m\) 的时候,是不存在的,因为子序列要求不连续,就意味着 \(m\) 最大是 \(n - 1\)

然后,从最简单的开始,思考构造最大长度为 \(m = 1\) 时的非空子串和不连续且非空的子序列相等。
\(n = 2\) 时,\(aa\)
\(n = 3\) 时,\(aba\)
\(n = 4\) 时,\(abca\)
\(n = 27\) 时,\(abcdefghijklmnopqrstuvwxyza\)
从上面构造过程可以发现,除了目标 \(a\) 外,其他字母出现次数超过一次都会失败,以\(n = 4\)举例\(b\)出现两次即\(abba\),可以构造出字符串\(ab\)符合构造规则,但长度超过\(m = 1\),同理到\(n = 27\)时。所以当\(n = 28\)及以上时,不可避免的会让某个字母出现次数超过一次导致失败。
那么是所有的\(n > 27\)都不可能还是\(n > m + 26\)都不可能呢。

继续思考构造最大长度为 \(m = 2\) 时的非空子串和不连续且非空的子序列相等。
\(n = 3\) 时,\(aab\)
\(n = 4\) 时,\(abab\)
\(n = 5\) 时,\(abcab\)
\(n = 27\) 时,\(abcdefghijklmnopqrstuvwxyab\)
\(n = 28\) 时,\(abcdefghijklmnopqrstuvwxyzab\)
由此可以排除所有的\(n > 27\)都不可能这一猜测。
同时,可以发现,固定字符串\(ab\)为目标串,字母\(a和b\)最多再出现一次(可以自己手动构造一下),其他字母最多出现一次,换句话说,对于字符串\(ab\),每个字母的另一个出现位置必须最多相隔25个字母且除该字符串外的字母最多出现一次,而且字符串的首尾两字母中必须有一个字母,存在另一个相同字母,符合构造非连续子序列。

明显的,在 \(m < 27\) 的情况都符合上面找到的规律,设 \(abcdefghijklmnopqrstuvwxyz = S\)
\(m = 27\)时,\(n = 53\)能构造的就是\(SSa\)
\(m = 53\)时,\(n = 79\)能构造的就是\(SSSa\)

通过以上构造过程可以找到构造规律

  • \(n > m + 26\) 时不存在。
  • 构造长度为\(m\)的字符串按26个字母顺序构造
  • \(n - m\)的部分按26个字母顺序拼接在长度为\(m\)的字符串前。

(解释的不太清楚,要手动模拟构造一遍理解。)

代码

点击查看代码
#include <iostream>

using namespace std;

int n, m;
string tmp = "abcdefghijklmnopqrstuvwxyz";

void solve()
{
  cin >> n >> m;
  
  if (m == n) return void(cout << "NO" << '\n');
  if (n > m + 26) return void(cout << "NO" << '\n');
  
  cout << "YES" << '\n';
  string ans = "", c = "";
  
  int x = m / 26;
  while (x --) ans += tmp;
  int y = m % 26;
  for (char i = 'a'; y > 0; i ++, y --) ans += i;
  int d = n - m;
  for (char i = 'a'; d > 0; i ++, d --) c += i;
  ans = c + ans;
  
  cout << ans << '\n';
}

int main()
{
  ios::sync_with_stdio(false);
  cin.tie(0), cout.tie(0);

  int t = 1;
  cin >> t;
  while (t --) solve();

  return 0;
}

E.一起走很长的路!

题意

给一个长度为 \(n\) 的序列 \(a\),询问 \(q\) 次,每次询问区间 \([l,r]\),最少需要调整多少次可以使得 \(\forall i \in [l, r],\sum_{k=l}^{i - 1}a_k \ge a_i\),每次调整可以让 \(a_i\) 加一或减一。

思路

线段树、st表、数列分块

首先,当 \(l = r\) 时不需要调整。

然后,从区间的方向思考,假设已知区间A的结果和区间B的结果,询问区间AB,即涉及两个区间的合并问题。
当区间合并时,可以发现对于左边的区间即区间A内的元素是没有影响的,但对于区间B,则是多了区间A的元素和区间A调整的值,具体为:
设区间调整值为 \(d\) ,区间和为 \(S\),每个区间边界为 \(l, r\)
则区间B的调整值会改变为:\(d_b = \max({d_b - S_a, 0, l_b - S_a})\),即合并后区间调整值要减去区间A的总和,区间B的左端点元素对于区间A的总和的差
那么对于整个区间AB的影响:
\(S_{ab} = S_a + S_b\) , \(d_{ab} = \max(d_a, d_b)\)

把区间合并的问题解决后,用线段树维护每个区间的值就很简单了,套模板就行。

代码

点击查看代码
#include <iostream>
#include <algorithm>

using namespace std;

typedef long long ll;

const int N = 2e5 + 10;

int n, q;
ll a[N], s[N], c[N];
struct node {
    int l, r;
    ll sum;
    ll d;
} tr[N * 4];

ll func(node& l, node& r)
{
    return max({r.d - l.sum, l.d, a[r.l] - l.sum});
}

void pushup(int u)
{
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
    tr[u].d = func(tr[u << 1], tr[u << 1 | 1]);
}

void build(int u, int l, int r)
{
    if (l == r) tr[u] = {l, r, a[l], 0ll};
    else {
        tr[u] = {l, r};
        int mid = (l + r) >> 1;
        build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }
}

node query(int u, int l, int r)
{
    if (tr[u].l >= l && tr[u].r <= r) return tr[u];
    int mid = (tr[u].l + tr[u].r) >> 1;
    node R, res = {-1, -1, 0, 0};
    if (mid >= l) res = query(u << 1, l, r);
    if (mid < r) {
        R = query(u << 1 | 1, l, r);
        if (res.l != -1) {
            res.d = func(res, R);
            res.r = R.r;
            res.sum += R.sum;
        }
        else res = R;
    }
    return res;
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    
    cin >> n >> q;
    for (int i = 1; i <= n; i ++) cin >> a[i];
    
    build(1, 1, n);
    
    while (q --)
    {
        int l, r;
        cin >> l >> r;
        cout << query(1, l, r).d << '\n';
    }
    return 0;
}

ST表同理,但要提前做前缀和,方便求区间和。

点击查看代码
#include <iostream>
#include <algorithm>
#include <cmath>

using namespace std;

typedef long long ll;

const int N = 2e5 + 10, M = 18;

int n, q;
ll a[N], s[N];
ll f[M][N];

ll S(int l, int r)
{
    return s[r] - s[l - 1];
}

void ST()
{
    for (int i = 0; i < M; i ++)
        for (int j = 1; j + (1 << i) - 1 <= n; j ++) 
            if (i) {
                ll l = f[i - 1][j], r = f[i - 1][j + (1 << i - 1)];
                ll lsum = S(j, j + (1 << i - 1) - 1);
                f[i][j] = max({l, a[j + (1 << i - 1)] - lsum, r - lsum});
            }
}

ll query(int l, int r)
{
    int len = r - l + 1;
    int k = log(len) / log(2);
    int R = l + (1 << k) - 1;
    ll lsum = S(l, R);
    if (R == r) return f[k][l];
    else {
        ll res = 0;
        ll x = query(R + 1, r);
        res = max({res, a[R + 1] - lsum, f[k][l], x - lsum});
        return res;
    }
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    
    cin >> n >> q;
    for (int i = 1; i <= n; i ++) {
        cin >> a[i];
        s[i] = s[i - 1] + a[i];
    }
    
    ST();
    while (q --)
    {
        int l, r;
        cin >> l >> r;
        cout << query(l, r) << '\n';
    }
    
    return 0;
}

通过上面的分析,其实可以发现,答案其实就是 \(\max_{i=l+1}^{r}(0, a_i - \sum_{j = l}^{i - 1}a_j)\)
所以我们可以先预处理\(a_i - \sum_{k=1}^{i - 1}a_k\),然后找到区间中最大的值后再加上\(\sum_{k=1}^{l-1}a_k\)即是答案。

这样的话,可以考虑用数列分块来做,因为前面已经写过线段树和ST表了

点击查看代码
#include <iostream>
#include <cmath>
#include <algorithm>

using namespace std;

typedef long long ll;

const int N = 2e5 + 10;

int n, q, len;
ll a[N], id[N], s[N], sc[N], b[N];

ll query(int l, int r)
{
    if (l == r) return 0;
    int sd = id[l + 1], ed = id[r];
    ll res = sc[l + 1];
    if (sd == ed) {
        for (int i = l + 1; i <= r; i ++) res = max(res, sc[i]);
    } else {
        for (int i = l + 1; id[i] == sd; i ++) res = max(res, sc[i]);
        for (int i = sd + 1; i < ed; i ++) res = max(res, b[i]);
        for (int i = r; id[i] == ed; i --) res = max(res, sc[i]);
    }
    return max(0ll, res + s[l - 1]);
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    
    cin >> n >> q;
    len = sqrt(n);
    for  (int i = 1; i <= n; i ++) {
        b[i] = -1e18;
        cin >> a[i];
        id[i] = (i - 1) / len + 1;
        s[i] = s[i - 1] + a[i];
        sc[i] = a[i] - s[i - 1];
        b[id[i]] = max(b[id[i]], sc[i]);
    }
    
    while (q --)
    {
        int l, r;
        cin >> l >> r;
        cout << query(l, r) << '\n';
    }
    return 0;
}

I.一起看很美的日落!

题意

给一个由 \(n\) 个节点构成的树,第 \(i\) 个节点权值为 \(a_i\)
定义一个连通块的权值为连通块内两两节点的异或值的总和。
求这棵树内所有连通块的权值之和。

思路

树形dp,二进制优化

明显的,暴力求连通块内两两节点异或值再求和一定会超时,思考优化。从异或运算的角度出发,转为二进制后,某一位上只有 \(0 \oplus 1\) 时才对该位存在贡献。

树上求权值之和一般涉及到树形dp。

定义 \(dp[u]\):以 \(u\) 为根的子树中所有连通块的权值之和;
定义 \(f[u][i][0]\):以 \(u\) 为根的子树中第 \(i\) 为0的状态下的贡献;
定义 \(f[u][i][1]\):以 \(u\) 为根的子树中第 \(i\) 为1的状态下的贡献;
定义 \(s[u]\):以 \(u\) 为根的子树中包含 \(u\) 节点的连通块个数。

思考状态转移,即两棵子树合并的影响

  • 首先,原先的权值和 \(dp[u]\) 将会增加 \(s[v]\) 倍,因为合并了子树v中 \(s[v]\) 个连通块;
  • 其次,子树v的权值和 \(dp[v]\) 将会增加 \(s[u]\) 倍,原理同上;
  • 然后思考,两棵子树节点异或产生的权值:以 \(u\) 为根的子树中第 \(i\) 为0的贡献乘上以 \(v\) 为根的子树中第 \(i\) 为1的贡献,以及以 \(u\) 为根的子树中第 \(i\) 为1的的贡献乘上以 \(v\) 为根的子树中第 \(i\) 为0的贡献。
  • 将以上求和,即为 \(dp[u]\) 状态转移$。

f的状态转移类似:

  • \(f[u][i][0] += f[u][i][0] * s[v] + f[v][i][0] * s[u]\)
  • \(f[u][i][1] += f[u][i][1] * s[v] + f[v][i][1] * s[u]\)

\(s[u] += s[u] * s[v]\)

将每个节点的dp值求和后,再乘2(因为两两节点异或)即为答案。

代码

点击查看代码
#include <iostream>
#include <vector>

using namespace std;

typedef long long ll;

const int mod = 1e9 + 7;
const int N = 1e5 +  10;
const int M = 30;

int n;
int a[N];
vector<int> g[N];
ll dp[N], f[N][M][2], s[N], ans;

void dfs(int u, int fa)
{
    for (int i = 0; i < M; i ++) 
        if (a[u] >> i & 1) f[u][i][1] ++;
        else  f[u][i][0] ++;
    
    s[u] = 1;
    for (auto v : g[u]) {
        if (v == fa) continue;
        dfs(v, u);
        dp[u] = (dp[u] + dp[u] * s[v] % mod + dp[v] * s[u] % mod) % mod;
        for (int i = 0; i < M; i ++) {
            dp[u] = (dp[u] + (f[u][i][1] * f[v][i][0] % mod + f[u][i][0] * f[v][i][1] % mod) % mod * (1ll << i) % mod) % mod;
            f[u][i][0] = (f[u][i][0] + f[u][i][0] * s[v] % mod + f[v][i][0] * s[u] % mod) % mod;
            f[u][i][1] = (f[u][i][1] + f[u][i][1] * s[v] % mod + f[v][i][1] * s[u] % mod) % mod;
        }
        s[u] = (s[u] * s[v] % mod + s[u]) % mod;
    }
    ans = (ans + dp[u]) % mod;
}

int main()
{
    cin >> n;
    for (int i = 1; i <= n; i ++) cin >> a[i];
    for (int i = 1; i < n; i ++) {
        int u, v;
        cin >> u >> v;
        g[u].emplace_back(v);
        g[v].emplace_back(u);
    }
    
    dfs(1, 0);
    cout << ans * 2 % mod;
    
    return 0;
}

posted @ 2025-02-17 15:54  Natural_TLP  阅读(33)  评论(0)    收藏  举报