2025牛客暑期多校训练营1


E. ndless Ladders

题意:所有\(x^2 - y^2\)构成的正整数的集合里,\(|a^2 - b^2|\)是第几位。

打表发现规律,没出现的数中,除了前面\(1, 2, 4, 6\)外其它都是每次加\(4\)

点击查看代码
#include <bits/stdc++.h>

using i64 = long long;

void solve() {
	i64 a, b;
	std::cin >> a >> b;
	i64 d = std::abs(a * a - b * b);
	if (d <= 6) {
		if (d == 3) {
			std::cout << 1 << "\n";
		} else if (d == 5) {
			std::cout << 2 << "\n";
		}
	} else  {
		std::cout << d - ((d - 6) / 4 + 4) << "\n";
	} 
}

int main() {
	std::ios::sync_with_stdio(false), std::cin.tie(0), std::cout.tie(0);
	int t = 1;
	std::cin >> t;
	while (t -- ) {
		solve();
	}
	return 0;
}

G. Symmetry Intervals

题意:给你一个字符串\(S\),每次询问一个字符串\(T\)\(S\)以第\(a\)个为起点,有多少区间和\(T\)相同。

直接枚举,记每个和\(T\)相同的区间的长度为\(len\),那么这个区间任意选两个端点有\(\frac{len(len+1)}{2}\)个子区间满足条件,也可以记录一个变量一边遍历一边加。

点击查看代码
#include <bits/stdc++.h>

using i64 = long long;

void solve() {
	int n, q;
	std::cin >> n >> q;
	std::string s;
	std::cin >> s;
	while (q -- ) {
		std::string t;
		std::cin >> t;
		int a;
		std::cin >> a;
		-- a;
		int m = t.size();
		i64 ans = 0;
		for (int i = 0; i < m; ++ i) {
			if (s[a + i] == t[i]) {
				int j = i;
				while (j < m && s[a + j] == t[j]) {
					++ j;
				}

				i64 len = j - i;
				ans += len * (len + 1) / 2;
				i = j - 1;
			}
		}
		std::cout << ans << "\n";
	}
}

int main() {
	std::ios::sync_with_stdio(false), std::cin.tie(0), std::cout.tie(0);
	int t = 1;
	// std::cin >> t;
	while (t -- ) {
		solve();
	}
	return 0;
}

H. Symmetry Intervals 2

赛后补题
题意:与\(G\)差不多,不过是让\(G\)的两个长度相同的子串比较,而且变成了\(01\)串。然后增加了反转操作,让一个区间取反。保证询问操作不超过\(2500\)次。

以前没怎么写过这种分块的题,也是好好研究了一下,发现其思想很暴力,不是很难理解。核心思想就是把数组分成若干长度相同的块,然后如果能实现块与块之间的合并,那么就可以一次跳一个块去操作。
回到这个题,因为是\(01\)串,我们可以把它按位存储。定一个位数\(B\),然后把\(S\)分成\(\lceil \frac{n}{B} \rceil\)个长度为\(B\)的二进制串,或者说是分块。显然不同的二进制串只有\(2^B\)个,那么我们可以用\(2^B\)预处理这些串,得到\(L[i], r[i], ans[i]\),分别表示二进制\(i\)左边连续的\(1\)的个数、右边连续\(1\)的个数,以及这个串里自己贡献的答案。
那么我们可以拼接两个串,记录前面串有\(cnt\)\(1\)连接当前串的左边,那么可以贡献\(L[i] \times cnt\)个区间,然后更新\(cnt\),如果\(L[i] == B\)则表示全是\(1\),可以与前面的接到一起:\(cnt = cnt + L[i]\),否则和前面的断开了,\(cnt = R[i]\)
那么对于反转操作,我们可以用一个\(tag\)数组差分一下,也就是如果反转\([l, r]\),可以得到\(l, r\)的二进制串的编号为\(\lfloor \frac{l}{B} \rfloor\)\(\lfloor \frac{r}{B} \rfloor\),因为这两个块可能并没有完全反转,所以需要先把这两个块反转的那一部分反转一下,也就是异或一个二进制数,然后就相当于给\([\lfloor \frac{l}{B} \rfloor + 1, \lfloor \frac{r}{B} \rfloor - 1]\)这个区间都异或上\(1\),直接差分记录一下。
对于询问操作,先把所有块跑一遍,用前面记录的差分数组判断要不要反转。然后一块一块的跳,记录答案就行了。

点击查看代码
#include <bits/stdc++.h>

using i64 = long long;

const int B = 16;
const int N = (1e6 + 5) / B + 10;

int A[N], tag[N];
int ans[1 << B], L[1 << B], R[1 << B];

int get(int l, int r) {
	return (1 << r) - 1 - ((1 << l) - 1);
}

int get_st(int id, int k, int len) {
	if (len == B) {
		if (k == 0) {
			return A[id];
		} else {
			return ((A[id] & get(k, B)) >> k) + ((A[id + 1] & get(0, k)) << (B - k));
		}
	} 

	int st = 0;
	for (int i = id * B + k, j = 0; j < len; ++ i, ++ j) {
		st |= (A[i / B] >> (i % B) & 1) << j;
	}

	return st;
}

void solve() {
	int n, q;
	std::cin >> n >> q;
	std::string s;
	std::cin >> s;
	for (int i = 0; i < n; ++ i) {
		int x = i / B, y = i % B;
		if (s[i] == '1') {
			A[x] |= 1 << y;
		}
	}

	for (int i = 0; i < 1 << B; ++ i) {
		for (int j = 0; j < B && (i >> j & 1); ++ j) {
			++ L[i];
		}

		for (int j = B - 1; j >= 0 && (i >> j & 1); -- j) {
			++ R[i];
		}

		for (int j = 0, x = 0; j < B; ++ j) {
			if (i >> j & 1) {
				++ x;
			} else {
				x = 0;
			}

			ans[i] += x;
		}
	}


	while (q -- ) {
		int op, l, r, a, b;
		std::cin >> op;
		if (op == 1) {
			std::cin >> l >> r;
			-- l, -- r;
			if (l / B == r / B) {
				A[l / B] ^= get(l % B, r % B + 1);
			} else {
				A[l / B] ^= get(l % B, B);
				A[r / B] ^= get(0, r % B + 1);
				tag[l / B + 1] ^= 1;
				tag[r / B] ^= 1;
			}
		} else {
			std::cin >> l >> a >> b;
			-- a, -- b;
			for (int i = 0, flag = 0; i <= n / B; ++ i) {
				flag ^= tag[i];
				tag[i] = 0;
				if (flag) {
					A[i] ^= get(0, B);
				}
			}

			i64 cnt = 0;
			i64 res = 0;
			int a1 = a / B, a2 = a % B, b1 = b / B, b2 = b % B;
			while (l) {
				int len = std::min(l, B);
				int st = get_st(a1, a2, len) ^ get_st(b1, b2, len) ^ get(0, len);
				res += ans[st];
				res += L[st] * cnt;
				if (L[st] == B) {
					cnt += L[st];
				} else {
					cnt = R[st];
				}

				++ a1, ++ b1;
				l -= len;
			}
			std::cout << res << "\n";
		}
	}
}

int main() {
	std::ios::sync_with_stdio(false), std::cin.tie(0), std::cout.tie(0);
	int t = 1;
	// std::cin >> t;
	while (t -- ) {
		solve();
	}
	return 0;
}

I. Iron Bars Cutting

赛后补题。
题意:一个数组,每次选一个位置\(k\)\([l, r]\)切割为\([l, k], [k + 1, r]\)两个部分,切割有不平衡度和代价,需要把每个位置单独切出来。一个合法的切割方案需要满足,每次切割的不平衡度小于等于上一次切割的不平衡度。现在求一开始把\([1, n]\)\(k\)处切割的合法方案的最小总代价。

方法一\((n^3logn)\)
可以记\(f[i][j][k]\)\([i, j]\)分成\([i, k], [k + 1, j]\)的最小代价,记\(pair\)类型\(的dp[i][j]\)数组存\([i, j]\)所有切割方案的不平衡值和总代价。那么对于\([i, k], [k + 1, j]\)这两个区间,我们就可以得到小于等于当前不平衡值的切割方案的最小代价。
具体实现是,我们按区间长度从小到大求,这也是区间\(dp\)的常规转移方式,然后对于每个\(f[i][j][k]\)去找左右区间的最小代价,每个\(dp[i][j]\)都是按不平衡值从小到大排序,并且最小代价取一个前缀\(min\)。因为如果当前方案可行,而还有一个不平衡值比它小且代价更小的,那么这个方案比它更优。那么就可以二分查找到这个值。
然后关于减少空间复杂度的方法,可以只开一个\(f[]\)数组作为临时数组,省略掉两位,然后枚举到\([1, n]\)这个区间时输出方案。

点击查看代码

#include <bits/stdc++.h>

using i64 = long long;

void solve() {
	int n;
	std::cin >> n;
	std::vector<i64> sum(n + 1);
	for (int i = 1; i <= n; ++ i) {
		std::cin >> sum[i];
		sum[i] += sum[i - 1];
	}

	const i64 inf = 1e18;

	std::vector dp(n + 1, std::vector<std::vector<std::pair<i64, i64>>>(n + 1));
	auto get = [&](int i, int j, i64 b) -> i64 {
		if (dp[i][j].empty() || dp[i][j][0].first > b) {
			return inf;
		}

		int l = 0, r = (int)dp[i][j].size() - 1;
		while (l < r) {
			int mid = l + r + 1 >> 1;
			if (dp[i][j][mid].first <= b) {
				l = mid;
			} else {
				r = mid - 1;
			}
		}

		return dp[i][j][l].second;
	};

	std::vector<i64> f(n + 1);
	for (int len = 1; len <= n; ++ len) {
		for (int i = 1; i + len - 1 <= n; ++ i) {
			int j = i + len - 1;
			if (len == 1) {
				dp[i][j].emplace_back(0, 0);
			} else {
				int lg = std::ceil(std::log2(sum[j] - sum[i - 1]));
				for (int k = i; k < j; ++ k) {
					i64 l1 = sum[k] - sum[i - 1], l2 = sum[j] - sum[k];
					i64 b = std::abs(l1 - l2);
					f[k] = get(i, k, b) + get(k + 1, j, b) + std::min(l1, l2) * lg;
					f[k] = std::min(f[k], inf);
					dp[i][j].emplace_back(b, f[k]);
				}

				std::ranges::sort(dp[i][j]);
				for (int k = 1; k < dp[i][j].size(); ++ k) {
					dp[i][j][k].second = std::min(dp[i][j][k].second, dp[i][j][k - 1].second);
				}

				if (i == 1 && j == n) {
					for (int k = 1; k < n; ++ k) {
						std::cout << (f[k] >= inf ? -1 : f[k]) << " \n"[k == n - 1];
					}
				}
			}
		}
	}
}

int main() {
	std::ios::sync_with_stdio(false), std::cin.tie(0), std::cout.tie(0);
	int t = 1;
	std::cin >> t;
	while (t -- ) {
		solve();
	}
	return 0;
}

方法二\(O(n^3)\)
注意到如果固定区间,切割点从左往右移动,那么不平衡值是先降后增,因为左边的和不断变大,右边的不断变小。同理固定左端点和切割点,移动右端点,以及固定右端点和切割点,移动左端点。也是先降后增。
那么可以记\(L[i][j][k]\)为在\([i, j]\)\(k\)点切割后\([i, k]\)可以取得的最小代价。\(R[i][j][k]\)\([i, j]\)\(k\)点切割后\([k + 1, j]\)可以取得的最小代价。那么有\(f[i][j][k] = L[i][j][k] + R[i][j][k] + cost(i, j, k)\)
那么如果我们得到了\(f[i][j][k]\),则可以用它更新\(L[i][r][j], r\in [j+1, n]\),和\(R[l][j][i], l\in [1, i - 1]\)。具体实现就是按照上述结论,分别找到先降后增的分割点,然后分左右两边,按不平衡值从小到大移动。
然后需要注意的是,这题不能开三维数组,我们可以给每个\((l, r, m)\)编号,用一个大一维数组存,因为合法的\((l, r, m)\)肯定少于\(n^3\)个,可以计算出来,这样三元组的总数是\(\sum_{r=1}^{n} \sum_{m=1}^{r-1} \sum_{l=1}^{m} = \sum_{r=1}^{n} \frac{r(r-1)}{2} = \frac{n(n-1)(n-2)}{6}\)\(n=420\)的情况下开\(longlong\)数组大概占用\(200MB\),可以通过。至于编号,可以先按照\(r\)再按照\(m\),根据上面的式子,固定\(r\)时,所有\(r' < r\)的三元组有\(\sum_{r'=1}^{r - 1} \sum_{m=1}^{r' - 1} m = \frac{r(r-1)(r-2)}{6}\)。同理,在同一个\(r\)下,固定\(m\),小于它的有\(\frac{m(m-1)}{2}\)个,那么\((l, r, m)\)可以编号为\(\frac{r(r-1)(r-2)}{6} + \frac{m(m-1)}{2} + l\)。正式比赛中可以不用算这些,直接把数组开到极限,然后随便搞个不会重复最大值不会越界的编号方式就行了。
感觉这个写法细节多一点,而且代码也长,如果是区域赛遇到还是写方法一好了。

点击查看代码

#include <bits/stdc++.h>

using i64 = long long;

const i64 inf = 1e18;
const int N = 425, MAXL = N * N * N / 6;

i64 L[MAXL], R[MAXL];
i64 s[N], f[N];

i64 & get(i64 dp[], int l, int r, int m) {
	return dp[r * (r - 1) * (r - 2) / 6 + m * (m - 1) / 2 + l];
}

void solve() {
	int n;
	std::cin >> n;
	for (int i = 1; i <= n; ++ i) {
		std::cin >> s[i];
		s[i] += s[i - 1];
	}

	for (int i = 1; i <= n; ++ i) {
		for (int j = i + 1; j <= n; ++ j) {
			for (int k = i; k < j; ++ k) {
			 	get(L, i, j, k) = inf;
			 	get(R, i, j, k) = inf;
			 	if (i == k) {
			 		get(L, i, j, k) = 0;
			 	} 

			 	if (k + 1 == j) {
			 		get(R, i, j, k) = 0;
			 	}
			}
		}
	}

	auto sum = [&](int l, int r) -> i64 {
		return s[r] - s[l - 1];
	};

	for (int len = 2; len <= n; ++ len) {
		for (int i = 1; i + len - 1 <= n; ++ i) {
			int j = i + len - 1;
			i64 lg = std::ceil(std::log2(sum(i, j)));
			for (int k = i; k < j; ++ k) {
				i64 l1 = sum(i, k), l2 = sum(k + 1, j);
				i64 b = std::abs(l1 - l2);
				f[k] = get(L, i, j, k) + get(R, i, j, k) + std::min(l1, l2) * lg;
				f[k] = std::min(f[k], inf);
			}

			int p1 = i;
			while (p1 < j && sum(i, p1) < sum(p1 + 1, j)) {
				++ p1;
			}

			//L[i][r][j];
			int p2 = j + 1;
			while (p2 <= n && sum(i, j) > sum(j + 1, p2)) {
				++ p2;
			}

			i64 min = inf;
			for (int k = p2, x = p1 - 1, y = p1; k <= n; ++ k) {
				i64 cur = sum(j + 1, k) - sum(i, j);
				while (x >= i && sum(x + 1, j) - sum(i, x) <= cur) {
					min = std::min(min, f[x]);
					-- x;
				}

				while (y < j && sum(i, y) - sum(y + 1, j) <= cur) {
					min = std::min(min, f[y]);
					++ y;
				}

				get(L, i, k, j) = std::min(get(L, i, k, j), min);
			}

			min = inf;
			for (int k = p2 - 1, x = p1 - 1, y = p1; k > j; -- k) {
				i64 cur = sum(i, j) - sum(j + 1, k);
				while (x >= i && sum(x + 1, j) - sum(i, x) <= cur) {
					min = std::min(min, f[x]);
					-- x;
				}

				while (y < j && sum(i, y) - sum(y + 1, j) <= cur) {
					min = std::min(min, f[y]);
					++ y;
				}

				get(L, i, k, j) = std::min(get(L, i, k, j), min);
			}


			//R[l][j][i];
			p2 = i - 1;
			while (p2 > 0 && sum(p2, i - 1) < sum(i, j)) {
				-- p2;
			}

			min = inf;
			for (int k = p2 + 1, x = p1 - 1, y = p1; k < i; ++ k) {
				i64 cur = sum(i, j) - sum(k, i - 1);
				while (x >= i && sum(x + 1, j) - sum(i, x) <= cur) {
					min = std::min(min, f[x]);
					-- x;
				}

				while (y < j && sum(i, y) - sum(y + 1, j) <= cur) {
					min = std::min(min, f[y]);
					++ y;
				}

				get(R, k, j, i - 1) = std::min(get(R, k, j, i - 1), min);
			}

			min = inf;
			for (int k = p2, x = p1 - 1, y = p1; k > 0; -- k) {
				i64 cur = sum(k, i - 1) - sum(i, j);
				while (x >= i && sum(x + 1, j) - sum(i, x) <= cur) {
					min = std::min(min, f[x]);
					-- x;
				}

				while (y < j && sum(i, y) - sum(y + 1, j) <= cur) {
					min = std::min(min, f[y]);
					++ y;
				}

				get(R, k, j, i - 1) = std::min(get(R, k, j, i - 1), min);
			}

			if (i == 1 && j == n) {
				for (int k = 1; k < n; ++ k) {
					std::cout << (f[k] >= inf ? -1 : f[k]) << " \n"[k == n - 1];
				}
			}
		}
	}
}

int main() {
	std::ios::sync_with_stdio(false), std::cin.tie(0), std::cout.tie(0);
	int t = 1;
	std::cin >> t;
	while (t -- ) {
		solve();
	}
	return 0;
}

K. Museum Acceptance

题意:一个图,每个点最多\(3\)条边,每条边有编号。一条边对于连接的两个点对应的编号可能不同。从一个点出发,一开始走\(1\)号边,后面会选择序号比前面大一的边走,如果没有则走\(1\)号边。求从每个点出发可以经过多少不同的边。

按边建图,那么有\(3n\)个点,发现执行点都恰好有一个出边一个入边,那么这些点组成了一个个环。把这些环找出来,用\(set\)给边去重就可以知道这个环里有多少边,然后一号边在这个环里的点答案就是边的数量。

点击查看代码
#include <bits/stdc++.h>

using i64 = long long;

void solve() {
	int n;
	std::cin >> n;
	std::vector p(n + 1, std::array<int, 4>{});
	std::map<int, std::map<int, int>> mp;
	for (int i = 1; i <= n; ++ i) {
		int k;
		std::cin >> k;
		p[i][0] = k;
		for (int j = 1; j <= k; ++ j) {
			std::cin >> p[i][j];
			mp[i][p[i][j]] = j;
		}
	}

	std::vector<std::vector<int>> adj(3 * n + 1);
	for (int i = 1; i <= n; ++ i) {
		for (int j = 1; j <= p[i][0]; ++ j) {
			int next = mp[p[i][j]][i] == p[p[i][j]][0] ? 1 : mp[p[i][j]][i] + 1;
			adj[(i - 1) * 3 + j].push_back((p[i][j] - 1) * 3 + next);
		}
	}

	std::vector<int> st(3 * n + 1), f(n + 1);
	for (int i = 1; i <= 3 * n; ++ i) {
		if (st[i]) {
			continue;
		}
		std::vector<int> a;
		auto dfs = [&](auto & self, int u) -> void {
			st[u] = 1;
			a.push_back(u);
			for (auto & v : adj[u]) {
				if (!st[v]) {
					self(self, v);
				}
			}
		};

		dfs(dfs, i);

		std::set<std::pair<int, int>> s;
		for (auto & u : a) {
			for (auto & v : adj[u]) {
				int x = (u - 1) / 3 + 1, y = (v - 1) / 3 + 1;
				s.insert({std::min(x, y), std::max(x, y)});
			}
		}

		for (auto & u : a) {
			if (u % 3 == 1) {
				f[(u - 1) / 3 + 1] = s.size();
			}
		}
	}

	for (int i = 1; i <= n; ++ i) {
		std::cout << f[i] << "\n";
	}
}

int main() {
	std::ios::sync_with_stdio(false), std::cin.tie(0), std::cout.tie(0);
	int t = 1;
	// std::cin >> t;
	while (t -- ) {
		solve();
	}
	return 0;
}

L. Numb Numbers

题意:一个数组,单点修改。每次求有多少数至少有\(\lfloor \frac{n}{2} \rfloor\)个数比它大。

我是直接用动态开点线段树加线段树内二分维护的。具体的维护一个\([1, 10^{15}]\)的区间每个点有多少数就行。

点击查看代码
#include <bits/stdc++.h>

using i64 = long long;

const int N = 2e5 + 5;

#define ls(u) tr[u].lson
#define rs(u) tr[u].rson

struct Node {
	int lson, rson;
	int sum;
}tr[N << 5];

const i64 R = 1e15;

i64 a[N];
int root, idx;

void modify(int & u, i64 l, i64 r, i64 p, int add) {
	if (u == 0) {
		u = ++ idx;
	}

	tr[u].sum += add;
	if (l == r) {
		return;
	}

	i64 mid = l + r >> 1ll;
	if (p <= mid) {
		modify(ls(u), l, mid, p, add);
	} else {
		modify(rs(u), mid + 1, r, p, add);
	}
}

int query(int x) {
	int u = 1;
	int res = 0;
	i64 l = 1, r = R;
	while (u && x > 0) {
		// std::cout << l << " " << r << ":\n";
		// std::cout << tr[ls(u)].sum << " " << tr[rs(u)].sum << " " << x << "\n";
		if (tr[u].sum <= x) {
			res += tr[u].sum;
			break;
		}
		i64 mid = l + r >> 1ll;
		if (tr[ls(u)].sum >= x) {
			u = ls(u);
			r = mid;
		} else {
			x -= tr[ls(u)].sum;
			res += tr[ls(u)].sum;
			u = rs(u);
			l = mid + 1;
		}
	}

	return res;
}

void clear(int & u) {
	if (ls(u)) {
		clear(ls(u));
	}

	if (rs(u)) {
		clear(rs(u));
	}

	tr[u] = {0, 0, 0};
	u = 0;
}

void solve() {
	int n, q;
	std::cin >> n >> q;
	root = idx = 0;
	for (int i = 1; i <= n; ++ i) {
		std::cin >> a[i];
		modify(root, 1, R, a[i], 1);
	}

	while (q -- ) {
		int p, v;
		std::cin >> p >> v;
		modify(root, 1, R, a[p], -1);
		a[p] += v;
		modify(root, 1, R, a[p], 1);
		std::cout << query((n + 1) / 2) << "\n";
	}

	clear(root);
}

int main() {
	std::ios::sync_with_stdio(false), std::cin.tie(0), std::cout.tie(0);
	int t = 1;
	std::cin >> t;
	while (t -- ) {
		solve();
	}
	return 0;
}	
posted @ 2025-07-15 21:12  maburb  阅读(671)  评论(0)    收藏  举报