FJWC2019
FJWC 2019
Day 1
全连
有 \(n\) 个数,第 \(i\) 个的覆盖区间为 \((i - t_i, i + t_i)\) ,价值为 \(a_i\) 。求覆盖区间互不重叠的情况下可以选出的数的最大价值和。
\(n \leq 10^6\)
有一个显然的 \(O(n^2)\) DP:令 \(f_i\) 表示以 \(i\) 结尾的最高分数,枚举上一个数进行转移。
注意到 \(j\) 能转移到 \(i\) 当且仅当 \(j + t_j \leq i\) 且 \(j \leq i - t_i\) 。用树状数组维护当前能产生贡献的所有状态,对于状态 \(i\) ,它只有可能对 \([i + t_i, n]\) 产生贡献,只要在时刻 \(i + t_i\) 将其加入树状数组即可。
剩下就是简单的前缀最大值查询,时间复杂度 \(O(n \log n)\) 。
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 1e6 + 7;
vector<int> upd[N];
ll f[N];
int t[N], a[N];
int n;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
namespace BIT {
ll c[N];
inline void update(int x, ll k) {
for (; x <= n; x += x & -x)
c[x] = max(c[x], k);
}
inline ll query(int x) {
ll res = 0;
for (; x > 0; x -= x & -x)
res = max(res, c[x]);
return res;
}
} // namespace BIT
signed main() {
freopen("fc.in", "r", stdin);
freopen("fc.out", "w", stdout);
n = read();
for (int i = 1; i <= n; ++i)
t[i] = read();
for (int i = 1; i <= n; ++i)
a[i] = read();
for (int i = 1; i <= n; ++i) {
for (int it : upd[i])
BIT::update(it, f[it]);
f[i] = BIT::query(i - t[i]) + 1ll * a[i] * t[i];
if (i + t[i] <= n)
upd[i + t[i]].emplace_back(i);
}
printf("%lld", *max_element(f + 1, f + 1 + n));
return 0;
}
原样输出
给定 \(n\) 个字符串,求将每个串头尾各删去一些字符后(可以删完或不删)拼接在一起的串的种类的数量,并按字典序大小输出。
保证输入文件大小不超过 1MB,保证输出文件大小不超过 200MB,字符集为 \(\{ A, C, G, T \}\) 。
可以发现这题的性质和子序列相似。回顾枚举子序列的方法,枚举下一个字符 \(c\),然后跳到离当前位置最近的字符 \(c\) 继续搜索,显然这样能够不重不漏地遍历所有子序列。
对于这题,考虑在失配时跳到下一个包含字符 \(c\) 的串的对应位置。具体地,对于每个串建立 SAM,然后从后往前考虑每个串,维护 \(lst_c\) 表示最近的包含字符 \(c\) 的串的 SAM 中 \(c\) 代表的状态。枚举当前 SAM 上的所有状态,如果该状态不存在字符 \(c\) 的转移边,只需要将这条边指向 \(lst_c\) 即可。
这样就得到了一个能匹配所有合法结果的自动机,显然它是一个 DAG,求方案数只要按拓扑逆序 DP 即可,时间复杂度线性。
#include <bits/stdc++.h>
using namespace std;
const int Mod = 1e9 + 7;
const int N = 3e6 + 7, S = 4;
const char chr[S] = {'A', 'C', 'G', 'T'};
struct Graph {
vector<pair<int, int> > e[N];
inline void insert(int u, int v, int w) {
e[u].emplace_back(v, w);
}
} G;
int ch[N][S], lst[N][S];
int fa[N], len[N], f[N];
char str[N];
bool vis[N];
int n, testid, tot;
inline int add(int x, int y) {
x += y;
if (x >= Mod)
x -= Mod;
return x;
}
inline int getid(char c) {
for (int i = 0; i < S; ++i)
if (c == chr[i])
return i;
}
struct SAM {
int be, ed, las;
inline void extend(int c) {
int p = las, np = las = ++tot;
len[np] = len[p] + 1;
for (; p && !ch[p][c]; p = fa[p])
ch[p][c] = np;
if (!p)
fa[np] = be;
else {
int q = ch[p][c];
if (len[q] == len[p] + 1)
fa[np] = q;
else {
int nq = ++tot;
fa[nq] = fa[q], len[nq] = len[p] + 1;
memcpy(ch[nq], ch[q], sizeof(int) * S);
fa[q] = fa[np] = nq;
for (; p && ch[p][c] == q; p = fa[p])
ch[p][c] = nq;
}
}
}
} sam[N];
int dfs1(int u) {
if (vis[u])
return f[u];
vis[u] = true, f[u] = 1;
for (auto it : G.e[u])
f[u] = add(f[u], dfs1(it.first));
return f[u];
}
void dfs2(int u, int pos) {
puts(str + 1);
for (auto it : G.e[u]) {
int v = it.first, w = it.second;
str[pos] = chr[w], dfs2(v, pos + 1);
}
str[pos] = '\0';
}
signed main() {
freopen("copy.in", "r", stdin);
freopen("copy.out", "w", stdout);
scanf("%d", &n);
for (int i = 1; i <= n; ++i) {
scanf("%s", str + 1);
sam[i].be = sam[i].las = ++tot;
for (int j = 1, len = strlen(str + 1); j <= len; ++j)
sam[i].extend(getid(str[j]));
memcpy(lst[i], ch[sam[i].be], sizeof(int) * S);
sam[i].ed = tot;
}
scanf("%d", &testid);
vector<int> nxt(4);
for (int i = n; i; --i) {
for (int j = sam[i].be; j <= sam[i].ed; ++j)
for (int k = 0; k < S; ++k) {
if (!ch[j][k] && nxt[k])
G.insert(j, nxt[k], k);
else if (ch[j][k])
G.insert(j, ch[j][k], k);
}
for (int j = 0; j < S; ++j)
if (lst[i][j])
nxt[j] = lst[i][j];
}
if (testid == 1) {
memset(str, '\0', sizeof(str));
dfs2(1, 1);
}
printf("%d", dfs1(1));
return 0;
}
不同的缩写
给定 \(n\) 个串,令每个串仅保留其一个非空子序列,求使得最后保留串互不相同时最长保留串的最短长度,并构造方案。
\(n, |S| \leq 300\)
考虑二分最大长度 \(ans\),然后对每个串 DFS/BFS 找出所有长度不超过 \(ans\) 的子序列,转化成匹配问题,Dinic 即可。
注意到若一个串有至少 \(n\) 个不同的子序列,则无论其他串保留什么子序列,它都能匹配上,因此对于每个串只要至多搜出 \(n\) 个子序列。
这样建立的网络流图点数和边数都是 \(O(n^2)\) 的,时间复杂度 \(O(n^3 \log n)\) 。
#include <bits/stdc++.h>
using namespace std;
const int N = 3e2 + 7, M = 1e6 + 7;
struct SA {
int nxt[N][27];
char str[N];
int len;
inline void build() {
len = strlen(str + 1);
memset(nxt[len + 1], -1, sizeof(nxt[len + 1]));
for (int i = len; ~i; --i) {
memcpy(nxt[i], nxt[i + 1], sizeof(nxt[i]));
nxt[i][str[i + 1] - 'a'] = i + 1;
}
}
inline bool query(char *p) {
int ptr = 0;
for (int i = 1, lenp = strlen(p + 1); i <= lenp && ~ptr; ++i)
ptr = nxt[ptr][p[i] - 'a'];
return ~ptr;
}
} sa[N];
struct Graph {
struct Edge {
int nxt, v;
} e[M];
int head[N];
int tot;
inline void clear() {
memset(head, 0, sizeof(head));
tot = 0;
}
inline void insert(int u, int v) {
e[++tot] = (Edge) {head[u], v}, head[u] = tot;
}
} G;
map<int, string> mp;
map<string, int> idx;
string answer[N], nowstr;
int obj[M], vis[M], num[N];
int n, tot;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
void dfs(int u, int id, const int limit) {
if (!idx.count(nowstr))
mp[idx[nowstr] = ++tot] = nowstr;
G.insert(id, idx[nowstr]), ++num[id];
if (num[id] == n || nowstr.size() >= limit) {
nowstr.pop_back();
return;
}
for (int i = 0; i < 26; ++i)
if (~sa[id].nxt[u][i]) {
nowstr.push_back('a' + i);
dfs(sa[id].nxt[u][i], id, limit);
if (num[id] == n) {
nowstr.pop_back();
return;
}
}
nowstr.pop_back();
}
inline bool Hungary(int u, int tag) {
for (int i = G.head[u]; i; i = G.e[i].nxt) {
int v = G.e[i].v;
if (vis[v] != tag) {
vis[v] = tag;
if (!obj[v] || Hungary(obj[v], tag)) {
obj[v] = u;
return true;
}
}
}
return false;
}
inline bool check(int lambda) {
memset(num, 0, sizeof(num));
G.clear();
for (int i = 1; i <= n; ++i)
for (int j = 0; j < 26; ++j)
if (~sa[i].nxt[0][j] && num[i] < n) {
nowstr.push_back('a' + j);
dfs(sa[i].nxt[0][j], i, lambda);
}
memset(obj, 0, sizeof(obj));
memset(vis, 0, sizeof(vis));
int sum = 0;
for (int i = 1; i <= n; ++i)
sum += Hungary(i, i);
return sum == n;
}
signed main() {
freopen("diff.in", "r", stdin);
freopen("diff.out", "w", stdout);
n = read();
for (int i = 1; i <= n; ++i) {
scanf("%s", sa[i].str + 1);
sa[i].build();
}
int l = 1, r = n, ans = -1;
while (l <= r) {
int mid = (l + r) >> 1;
if (check(mid))
ans = mid, r = mid - 1;
else
l = mid + 1;
}
if (ans == -1)
return puts("-1"), 0;
printf("%d\n", ans);
check(ans);
for (int i = 1; i <= tot; ++i)
if (obj[i])
answer[obj[i]] = mp[i];
for (int i = 1; i <= n; ++i)
puts(answer[i].c_str());
return 0;
}
Day 2
直径
构造一棵 \(n\) 个点,并且有 \(k \leq 5 \times 10^6\) 个直径的树,需要满足 \(n \leq 5000\) 且边权 \(w \in [0, 10^5]\) 。
考虑这样构造:根节点上接三条边权为 \(1\) 的边,三个点底下分别接接 \(a - 1, b - 1, c - 1\) 条边权为 \(0\) 的边。
这样构造的树直径条数为 \(ab + ac + bc\),只要枚举 \(a, b\) 就能算出是否存在对应的 \(c\) 。
显然对于 \(k \le 5 \times 10^6\) 大概率存在这样的解,打表可以说明一定存在,需要注意判断 \(c = 0\) 的情况。
#include <bits/stdc++.h>
using namespace std;
int k;
signed main() {
freopen("diameter.in", "r", stdin);
freopen("diameter.out", "w", stdout);
scanf("%d", &k);
if (k == 1)
return puts("2\n1 2 0"), 0;
else if (k == 2)
return puts("4\n1 2 3\n2 3 1\n2 4 1"), 0;
else if (k == 4)
return puts("7\n1 2 1\n1 3 1\n2 4 1\n2 5 1\n3 6 1\n3 7 1"), 0;
for (int a = 1; a <= 3e2; ++a)
for (int b = a; a + b <= 5e3 && a * b <= k; ++b)
if (!((k - a * b) % (a + b))) {
int c = (k - a * b) / (a + b);
if (a + b + c + 1 > 5e3)
continue;
if (c) {
printf("%d\n1 2 1\n1 3 1\n1 4 1\n", a + b + c + 1);
int tot = 4;
for (int i = 1; i < a; ++i)
printf("2 %d 0\n", ++tot);
for (int i = 1; i < b; ++i)
printf("3 %d 0\n", ++tot);
for (int i = 1; i < c; ++i)
printf("4 %d 0\n", ++tot);
} else {
printf("%d\n1 2 1\n1 3 1\n", a + b + 1);
int tot = 3;
for (int i = 1; i < a; ++i)
printf("2 %d 0\n", ++tot);
for (int i = 1; i < b; ++i)
printf("3 %d 0\n", ++tot);
}
return 0;
}
return 0;
}
定价
给定 \(n, m, q\) ,初始有一个全为 \(2^m - 1\) 的数组 \(a_{1 \sim n}\) ,\(q\) 次操作:
1 r c:\(a_r \gets a_r \oplus 2^{m - c}\) 。2:求数组 \(b_{1 \sim n}\) 满足 \(0 < b_1 < b_2 < \cdots < b_n < 2^m\) 且 \(a_i \and b_i = 0\) ,输出 \(\sum_{i = 1}^n b_i\) 的最小值 \(\bmod 10^9 + 7\) 。\(n \leq 1000\) ,\(m \leq 10^9\) ,\(q \leq 5 \times 10^5\) ,\(2\) 操作不超过 \(1000\) 次
因为查询比较少,所以可以每次查询都做一次。
先将限制看为 \(0\) 和 \(1\) ,\(0\) 代表只能填 \(0\) ,\(1\) 代表都可以。考虑从高位到低位枚举,分四类情况:
- 若上一个选了 \(1\) ,这一位为 \(1\) ,则这一位必选 \(1\) 。
- 若上一个选了 \(0\) ,这一位为 \(1\) ,则可以将这一位选上,更新当前套装的答案,但再往低位枚举时就不用将这一位选上。
- 若上一个选了 \(1\) ,这一位为 \(0\) ,则我们直接退出当前套装的枚举。
- 若上一个选了 \(0\) ,这一位为 \(0\) ,则我们不用管它。
此时我们看看当前套装的答案有没有被更新过,如果没有,就说明没有一种方案满足要求,直接输出 \(-1\) 就好了,时间复杂度 \(O(nm)\) 。
注意到只有为 \(1\) 的位才有贡献,考虑用 set 存每一位的限制,枚举时双指针即可。
#include <bits/stdc++.h>
using namespace std;
const int Mod = 1e9 + 7;
const int N = 1e3 + 7;
set<int> st[N];
int n, m, q;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
inline int add(int x, int y) {
x += y;
if (x >= Mod)
x -= Mod;
return x;
}
inline int dec(int x, int y) {
x -= y;
if (x < 0)
x += Mod;
return x;
}
inline int mi(int a, int b) {
int res = 1;
for (; b; b >>= 1, a = 1ll * a * a % Mod)
if (b & 1)
res = 1ll * res * a % Mod;
return res;
}
inline int solve() {
if (st[1].empty())
return -1;
set<int> lst;
int ans = mi(2, *st[1].begin());
lst.emplace(*st[1].begin());
for (int i = 2; i <= n; ++i) {
if (st[i].empty())
return -1;
auto it2 = prev(lst.end()), it1 = st[i].upper_bound(*it2);
if (it1 == st[i].end())
--it1;
set<int> res, now;
for (;;) {
if (*it2 == *it1) { // 上一个选了1 and 这一位为1
now.emplace(*it1);
if (it1 == st[i].begin())
break;
--it1;
if (it2 == lst.begin()) {
res = now, res.emplace(*st[i].begin());
break;
}
--it2;
} else if (*it2 > *it1) // 上一个选了1 and 这一位为0
break;
else { // 上一个选了0 and 这一位为1
res = now, res.emplace(*it1);
if (it1 == st[i].begin())
break;
--it1;
}
}
if (res.empty())
return -1;
for (auto it1 = res.begin(); it1 != res.end(); ++it1)
ans = add(ans, mi(2, *it1));
lst = res;
}
return ans;
}
signed main() {
freopen("price.in", "r", stdin);
freopen("price.out", "w", stdout);
n = read(), m = read(), q = read();
while (q--) {
if (read() == 1) {
int r = read(), c = m - read();
if (st[r].find(c) == st[r].end())
st[r].emplace(c);
else
st[r].erase(c);
} else
printf("%d\n", solve());
}
return 0;
}
排序
inline void FastSort(int *a, int n) { ll cnt = 0; for (int i = 1; i <= n; ++i) for (int j = i + 1; j <= n; ++j) { if (a[j] < a[i]) swap(a[i], a[j]); ++cnt; } }对给定的 \(1,2 \ldots n\) 的排列 \(a\) 执行算法 \(\text{FastSort}\),问当 \(cnt\) 刚刚变成输入中给定的值时的 \(a\) 序列。
\(n \leq 10^6\)
考虑模拟出 \(i\) 循环了 \(k\) 次后的序列,剩下的 \(O(n)\) 次操作暴力模拟即可。
显然 \(1 \sim k\) 的最终位置已经确定了,考虑依次确定 \(k + 1 \sim n\) 每个数的位置。对于数字 \(i \in [k + 1, n]\) ,后面只要有一个比它小的数,它就会往后移动。执行了 \(k\) 轮,只要存 \(1 \sim i - 1\) 中前 \(k\) 大的位置,加入 \(i\) 后将其移动到最后即可。
时间复杂度 \(O(n \log n)\) 。
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 1e6 + 7;
int a[N], pos[N];
ll cnt;
int n;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
signed main() {
freopen("sort.in", "r", stdin);
freopen("sort.out", "w", stdout);
n = read(), cnt = read<ll>();
for (int i = 1; i <= n; ++i)
pos[a[i] = read()] = i;
int k = 0;
while (cnt >= n - k - 1)
cnt -= n - ++k;
if (k) {
priority_queue<int> q;
for (int i = 1; i <= k; ++i)
a[i] = i, q.emplace(pos[i]);
for (int i = k + 1; i <= n; ++i)
q.emplace(pos[i]), a[q.top()] = i, q.pop();
}
for (int i = 1; i <= cnt; ++i)
if (a[k + i + 1] < a[k + 1])
swap(a[k + i + 1], a[k + 1]);
for (int i = 1; i <= n; ++i)
printf("%d ", a[i]);
return 0;
}
Day 3
签到题
给出 \(a_{1 \sim n}\) ,\(q\) 次询问,每次会将一个位置的值修改。需要在 \(q\) 次修改之前以及每次修改后(询问 \(q + 1\) 次)求一个最小的 \(x\) 使得每个 \(a_i\) 异或 \(x\) 后 \(a\) 升序,或判断无解。
\(n \leq 10^6\)
显然 \(a \ \text{xor} \ x \ge b \ \text{xor} \ x\) 相当于二进制最高的满足 \(a\) 和 \(b\) 不同的位必须是 \(0/1\) 。数组中每相邻两个数都会产生一个这样的限制,对每一位进行统计即可。实现精细一点可以做到 \(O(n)\)。
#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 7, B = 31;
int a[N], cnt[B][2];
int n, q, num, ans;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
inline void insert(int x, int y) {
if (x == y)
return;
int h = __lg(x ^ y);
if (!cnt[h][y >> h & 1] && cnt[h][x >> h & 1])
++num;
++cnt[h][y >> h & 1];
if (cnt[h][1] == 1 && (y >> h & 1))
ans ^= 1 << h;
}
inline void remove(int x, int y) {
if (x == y)
return;
int h = __lg(x ^ y);
if (cnt[h][1] == 1 && (y >> h & 1))
ans ^= 1 << h;
--cnt[h][y >> h & 1];
if (!cnt[h][y >> h & 1] && cnt[h][x >> h & 1])
--num;
}
signed main() {
freopen("sort.in", "r", stdin);
freopen("sort.out", "w", stdout);
n = read();
for (int i = 1; i <= n; ++i)
a[i] = read();
for (int i = 2; i <= n; ++i)
insert(a[i], a[i - 1]);
printf("%d\n", num ? -1 : ans);
q = read();
while (q--) {
int x = read(), k = read();
if (x > 1)
remove(a[x], a[x - 1]), insert(k, a[x - 1]);
if (x < n)
remove(a[x + 1], a[x]), insert(a[x + 1], k);
a[x] = k;
printf("%d\n", num ? -1 : ans);
}
return 0;
}
送分题
有 \(2n\) 个人和两个单间,一个男女通用,一个女性专用。每次:
- 若队首为女性,则优先尝试进入女性专用单间,否则尝试进入男女通用单间。
- 若队首为男性,则尝试进入男女通用单间。若只有女性专用单间为空,则最前面的女生进入。
假设每个人进房间后需要过 \(1\) 单位时间后才出来,忽略在队列中移动的时间。
调整队列顺序使得在 \(n\) 个单位时间内所有人都从房间出来,求一种合法的调整方案使得每个人不满值的最大值最小,一个人的不满值定义为调整队列前在他之后且调整队列后在他之前的人数。
\(n \leq 10^{18}\) ,输入方式为给出 \(m \leq 10^5\) ,接下来 \(m\) 行,每行给出一个字符串和一个数字 \(k\) ,将这个串重复 \(k\) 次插入队列末尾。
因为 \(2n\) 个人 \(n\) 秒结束,所以每秒必须两人进。
显然任意时刻均需满足队列中剩下的男性比剩下的女性少,那么从后往前数排第 \(i\) 个的女性必须在从后往前数第 \(2i\) 个及后面。
又因为要让调整的距离尽量小,那么第 \(i\) 个女性如果位置太靠前,将她移到从后往前数第 \(2i\) 个会最优。
显然这样只有女性会产生不满值,那么对所有女性的不满值取最大值即可,时间复杂度 \(O(n)\) 。
对于所有数据,从后往前考虑每个输入的字符串,并且从后往前考虑每一次重复。设当前一段后缀有 \(m_1\) 个男性与 \(m_2\) 个女性,则贡献为 \(m_1 - m_2\) ,于是对于每个字符串判断重复一次还是 \(k\) 次能取到最大值即可。
注意算总贡献的时候要 \(-1\) ,因为如果恰好有一个男生在后面,多出来的那个男生可以跟着他之前的女生走。
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const ll inf = 1e18;
const int M = 1e5 + 7, S = 2e5 + 7;
ll a[M], mx[M];
char str[S];
ll n;
int m;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
signed main() {
freopen("queue.in", "r", stdin);
freopen("queue.out", "w", stdout);
n = read<ll>(), m = read();
ll MaleSum = 0;
for (int i = 1; i <= m; ++i) {
scanf("%s", str + 1);
a[i] = read<ll>();
int male = 0, female = 0;
mx[i] = -inf;
for (int j = strlen(str + 1); j; --j) {
if (str[j] == 'M')
++male;
else
++female;
mx[i] = max(mx[i],(ll) male - female);
}
if (male > female)
mx[i] += (a[i] - 1) * (male - female);
MaleSum += male * a[i], a[i] *= male - female;
}
if (MaleSum > n)
return puts("-1"), 0;
for (int i = m - 1; i; --i)
a[i] += a[i + 1];
ll ans = 0;
for (int i = m; i; --i)
ans = max(ans, a[i + 1] + mx[i] - 1);
printf("%lld", ans);
return 0;
}
简单题
source:LOJ2731. 「JOISC 2016 Day 1」棋盘游戏
有一个 \(3\times n\) 的棋盘,开始时棋盘有一些格子上已经摆上了棋子(记为
o),剩下的格子都是空的(记为x)。每次可以选择一个空的格子摆上棋子,这个格子必须满足以下两个条件之一:
- 这个格子上下两格都有棋子。
- 这个格子左右两格都有棋子。
求有多少种不同的摆满棋盘的摆放顺序。
\(n \leq 2000\)
先考虑无解情况:四个角有 x ,或第一行有两个相邻的 x ,或第三行有两个相邻的 x 。
若一个格子能填,则要么上下比它早填,要么左右比它早填。考虑给每个格子一个数,表示填入棋子的时间,o 的是 \(0\) ,x 的是 \([1, m]\) ,其中 \(m\) 为空格个数。
然后考虑 DP 求每个 x 的连通块的答案,显然它们的顺序互不干扰,最后用组合数合并即可。
设 \(f_{i, j, 0/1}\) 表示在 \((2, i)\) 填 \(j\) ,\((2, i)\) 是否比 \((2, i + 1)\) 先填(填的数更小),把这个连通块到第 \(i\) 列的所有 \(x\) 填一个排列,都填合法的方案数。不难发现只有 \(f_{i, j, 1}\) 对答案有贡献。由题可得每个 x 要么比左右晚要么比上下晚。
根据 \(f_{i, j}\) 和 \(f_{i - 1, j}\) 的第三维可得 \(i - 1, i, i + 1\) 的时序关系,具体转移:
- 对于该列三个
x的情况:- 左在中前,中在右前,上下在中前:\(f_{i, j, 1} \gets (j - 1) \times (j - 2) \times \sum_{k = 1}^{j - 3} f_{i - 1, k, 1}\) 。
- 左在中前,中在右后,上下在中前:\(f_{i, j, 0} \gets (j - 1) \times (j - 2) \times \sum_{k = 1}^{j - 3} f_{i - 1, k, 1}\) 。
- 左在中前,中在右后,上下一个在中后、一个在中前:\(f_{i, j, 0} \gets 2 \times (j - 1) \times (siz_i - j) \times \sum_{k = 1}^{j - 2} f_{i - 1, k, 1}\) 。
- 左在中前,中在右后,上下在中后:\(f_{i, j, 0} \gets (siz_i - j) \times (siz_i - j - 1) \times \sum_{k = 1}^{j - 1} f_{i - 1, k, 1}\) 。
- 左在中后,中在右前,上下在中前:\(f_{i, j, 1} \gets (j - 1) \times (j - 2) \times \sum_{k = j - 2}^{siz_{i - 1}} f_{i - 1, k, 0}\) 。
- 左在中后,中在右后,上下在中前:\(f_{i, j, 0} \gets (j - 1) \times (j - 2) \times \sum_{k = j - 2}^{siz_{i - 1}} f_{i - 1, k, 0}\) 。
- 对于该列两个
x的情况:- 左在中前,中在右前,上下在中前:\(f_{i, j, 1} \gets (j - 1) \times \sum_{k = 1}^{j - 2} f_{i - 1, k, 1}\) 。
- 左在中前,中在右后,上下在中前: \(f_{i, j, 0} \gets (j - 1) \times \sum_{k = 1}^{j - 2} f_{i - 1, k, 1}\) 。
- 左在中前,中在右后,上下一个在中后、一个在中前:\(f_{i, j, 0} \gets (siz_i - j) \times \sum_{k = 1}^{j - 1} f_{i - 1, k, 1}\) 。
- 左在中后,中在右前,上下在中前:\(f_{i, j, 1} \gets (j - 1) \times \sum_{k = j - 1}^{siz_{i - 1}} f_{i - 1, k, 0}\) 。
- 左在中后,中在右后,上下在中前:\(f_{i, j, 0} \gets (j - 1) \times \sum_{k = j - 1}^{siz_{i - 1}} f_{i - 1, k, 0}\) 。
- 对于该列一个
x的情况:- 左在中前,中在右前,上下在中前:\(f_{i, j, 1} \gets \sum_{k = 1}^{j - 2} f_{i - 1, k, 1}\) 。
- 左在中前,中在右后,上下在中前: \(f_{i, j, 0} \gets \sum_{k = 1}^{j - 2} f_{i - 1, k, 1}\) 。
- 左在中后,中在右前,上下在中前:\(f_{i, j, 1} \gets \sum_{k = j}^{siz_{i - 1}} f_{i - 1, k, 0}\) 。
- 左在中后,中在右后,上下在中前:\(f_{i, j, 0} \gets \sum_{k = j}^{siz_{i - 1}} f_{i - 1, k, 0}\) 。
前缀和优化即可做到 \(O(n^2)\) 。
#include <bits/stdc++.h>
using namespace std;
const int Mod = 1e9 + 7;
const int N = 2e3 + 7;
int f[N][N << 1][2], s[N][N << 1][2];
int fac[N * 3], inv[N * 3], invfac[N * 3], siz[N];
char str[5][N];
int n, m;
inline int add(int x, int y) {
x += y;
if (x >= Mod)
x -= Mod;
return x;
}
inline int dec(int x, int y) {
x -= y;
if (x < 0)
x += Mod;
return x;
}
inline void prework(int n) {
fac[0] = fac[1] = 1;
inv[0] = inv[1] = 1;
invfac[0] = invfac[1] = 1;
for (int i = 2; i <= n; ++i) {
fac[i] = 1ll * fac[i - 1] * i % Mod;
inv[i] = 1ll * (Mod - Mod / i) * inv[Mod % i] % Mod;
invfac[i] = 1ll * invfac[i - 1] * inv[i] % Mod;
}
}
inline bool check() {
for (int i = 1; i < n; ++i) {
if (str[1][i] == 'x' && str[1][i + 1] == 'x')
return false;
else if (str[3][i] == 'x' && str[3][i + 1] == 'x')
return false;
}
return str[1][1] != 'x' && str[1][n] != 'x' && str[3][1] != 'x' && str[3][n] != 'x';
}
signed main() {
scanf("%d", &n);
for (int i = 1; i <= 3; ++i)
scanf("%s", str[i] + 1), m += count(str[i] + 1, str[i] + n + 1, 'x');
prework(n * 3);
if (!check())
return puts("0"), 0;
int ans = 1;
for (int i = 1; i <= n; ++i) {
if (str[2][i] == 'o')
continue;
int cnt = (str[1][i] == 'x') + 1 + (str[3][i] == 'x');
siz[i] = siz[i - 1] + cnt;
if (i == 1 || str[2][i - 1] == 'o') {
f[i - 1][0][1] = 1;
for (int j = 0; j <= m; ++j)
s[i - 1][j][1] = 1;
}
if (cnt == 3) {
for (int j = 1; j <= siz[i]; ++j) {
if (j >= 3) {
f[i][j][0] = add(f[i][j][0], 1ll * (j - 1) * (j - 2) % Mod * s[i - 1][j - 3][1] % Mod);
f[i][j][1] = add(f[i][j][1], 1ll * (j - 1) * (j - 2) % Mod * s[i - 1][j - 3][1] % Mod);
f[i][j][0] = add(f[i][j][0], 1ll * (j - 1) * (j - 2) % Mod *
dec(s[i - 1][siz[i - 1]][0], s[i - 1][j - 3][0]) % Mod);
f[i][j][1] = add(f[i][j][1], 1ll * (j - 1) * (j - 2) % Mod *
dec(s[i - 1][siz[i - 1]][0], s[i - 1][j - 3][0]) % Mod);
}
f[i][j][0] = add(f[i][j][0], 1ll * (siz[i] - j) * (siz[i] - j - 1) % Mod *
s[i - 1][j - 1][1] % Mod);
if (j >= 2)
f[i][j][0] = add(f[i][j][0], 2ll * (j - 1) * (siz[i] - j) % Mod *
s[i - 1][j - 2][1] % Mod);
}
} else if (cnt == 2) {
for (int j = 1; j <= siz[i]; ++j) {
if (j >= 2) {
f[i][j][0] = add(f[i][j][0], 1ll * (j - 1) * s[i - 1][j - 2][1] % Mod);
f[i][j][1] = add(f[i][j][1], 1ll * (j - 1) * s[i - 1][j - 2][1] % Mod);
f[i][j][0] = add(f[i][j][0], 1ll * (j - 1) *
dec(s[i - 1][siz[i - 1]][0], s[i - 1][j - 2][0]) % Mod);
f[i][j][1] = add(f[i][j][1], 1ll * (j - 1) *
dec(s[i - 1][siz[i - 1]][0], s[i - 1][j - 2][0]) % Mod);
}
f[i][j][0] = add(f[i][j][0], 1ll * (siz[i] - j) * s[i - 1][j - 1][1] % Mod);
}
} else {
for (int j = 1; j <= siz[i]; ++j) {
f[i][j][0] = add(f[i][j][0], s[i - 1][j - 1][1]);
f[i][j][1] = add(f[i][j][1], s[i - 1][j - 1][1]);
f[i][j][0] = add(f[i][j][0], dec(s[i - 1][siz[i - 1]][0], s[i - 1][j - 1][0]));
f[i][j][1] = add(f[i][j][1], dec(s[i - 1][siz[i - 1]][0], s[i - 1][j - 1][0]));
}
}
for (int j = 1; j <= m; ++j) {
s[i][j][0] = add(s[i][j - 1][0], f[i][j][0]);
s[i][j][1] = add(s[i][j - 1][1], f[i][j][1]);
}
if (i == n || str[2][i + 1] == 'o')
ans = 1ll * ans * s[i][siz[i]][0] % Mod * invfac[siz[i]] % Mod;
}
printf("%d", 1ll * ans * fac[m] % Mod);
return 0;
}
Day 4
循环流
有一个 \(n\) 个点的循环流(每个点入度等于出度),每条边的流量只有 \(1\) 或 \(2\) ,可能有重边,没有自环。
显然,由于它是一个流网络,它是一个弱连通图(将边视为无向边后为连通图)。
求是否存在一个 \(n\) 个点、\(a\) 条流量为 \(1\) 的边、\(b\) 条流量为 \(2\) 的边的循环流。
\(n, a, b \leq 50\)
首先判掉 \(a + b < n\) 的情况,此时一定没有环,无解。当 \(a + b = n\) 时,显然只能有一种边,即 \(a = n, b = 0\) 或 \(a = 0, b = n\) 。
接下来考虑 \(a + b > n\) 的情况,首先用 \(1\) 边构造一个 \(x\) 个点的环,\(2\) 边构造一个 \(y\) 个点的环,并满足 \(x \leq a, y \leq b, x + y = n + 1\) ,然后两个环各选一个点,让他们合并成一个,即用这个点连接两个环。这样就构造出了 \(n\) 个点、\(n + 1\) 条边的合法循环流。
接下来考虑剩下的边。先考虑 \(1\) 边构成的环,\(2\) 边同理。选择两个环上相邻的点 \(u \to v\) ,断开它们之间的边,然后再找环上的一个点 \(z\) ,连 \(u \to z \to v\) 。需要特判一些情况:
- \(a = 1\) :显然无解。
- \(b = 1\) :先连 \(1\) 的边,然后分裂至合法,然后随便拿出一条边 \(u \to v\) 流量改为 \(2\) ,再连一条 \(v \to u\) 流量为 \(1\) 的边。
- \(n = 2\) :首先要满足 \(a + b \geq 2\)
- 如果 \(b\) 是奇数,\(a\) 必须是非 \(0\) 偶数。
- 如果 \(b\) 是偶数,\(a\) 必须是偶数。
#include <bits/stdc++.h>
using namespace std;
signed main() {
freopen("flow.in", "r", stdin);
freopen("flow.out", "w", stdout);
int testid, T, n, a, b;
scanf("%d%d", &testid, &T);
while (T--) {
scanf("%d%d%d", &n, &a, &b);
if (n == 2)
puts((!a && b && (~b & 1)) || (a && (~a & 1)) ? "1" : "0");
else
puts((!a && b >= n) || (!b && a >= n) || (a != 1 && a + b > n) ? "1" : "0");
}
return 0;
}
整除分块
对于正整数 \(n\),定义无限数列 \(a_{n, i} = \lfloor \dfrac{n}{i} \rfloor\) ,设 \(f(n)\) 表示数列 \(a_n\) 中最小的没有出现过的自然数。
\(T\) 次询问,每次给出 \(l, r\) ,求 \(\sum_{k = l}^r f(k) \bmod 998244353\) 。
\(T \leq 65536\) ,\(1 \leq l \leq r \leq 10^{36}\)
把所有 \(\lfloor \sqrt i \rfloor\) 相同的 \(f(i)-\sqrt i\) 写在同一行,会发现每行都是递减的两段,然后将这两段分开就是这样:

然后答案分两边计算,对于左半边(右半边类似),令 \(k=f(i)-\lfloor \sqrt i \rfloor\),考虑每个 \(k\) 出现次数,将答案分为以下几个部分:
- .由于是 \(f(i) - \lfloor \sqrt{i} \rfloor\),要把 \(-\sqrt{i}\) 加回来。
- 数字 \(k\) 组成的一个腰长为 \(2i\) 的等腰直角三角形。
- 数字 \(k\) 组成的一个边长为 \(2i\) 的平行四边形。
- 最后一行可能会有空缺,单独拎出来计算。
可能需要用到的公式:
具体实现:对于左侧前 \(n\) 行:设出现的最大值为 \(k\),和为:
对于右侧前 \(n\) 行,和为:
对于左侧第 \(n\) 行,设和为 \(f(n)\),则:
则左侧第 \(n\) 行前 \(z\) 个数和为 \(f(n) - f(n - z)\) 。对于右侧第 \(n\) 行,设和为 \(g(n)\) ,则:
则右侧第 \(n\) 行前 \(z\) 个数和为 \(g(n) - g(n - z)\) 。
对于 \(\sum_{i = 1} ^ n \sqrt{i}\):
#include <bits/stdc++.h>
typedef __int128 s128;
using namespace std;
const s128 Mod = 998244353, inv2 = 499122177, inv6 = 166374059, inv30 = 432572553;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
inline s128 S1(s128 x) {
return (x + 1) * x % Mod * inv2 % Mod;
}
inline s128 S2(s128 x) {
return x * (x + 1) % Mod * (2 * x + 1) % Mod * inv6 % Mod;
}
inline s128 S3(int x) {
return S1(x) * S1(x) % Mod;
}
inline s128 S4(s128 x) {
return x * (x + 1) % Mod * (2 * x + 1) % Mod * (3 * x * x + 3 * x - 1) * inv30 % Mod;
}
inline s128 BinarySearch1(s128 limit) {
s128 l = 0, r = 1e18, ans = 0;
while (l <= r) {
s128 mid = (l + r) >> 1;
if (mid * (mid + 1) + 1 > limit)
ans = mid, r = mid - 1;
else
l = mid + 1;
}
return ans;
}
inline s128 BinarySearch2(s128 limit) {
s128 l = 0, r = 1e18, ans = 0;
while (l <= r) {
s128 mid = (l + r) >> 1;
if (mid * mid + 1 > limit)
ans = mid, r = mid - 1;
else
l = mid + 1;
}
return ans;
}
inline s128 Sqrt(s128 n) {
if (n <= 1e18)
return sqrt((long long) n);
s128 l = 1e9, r = 1e18, ans = 0;
while (l <= r) {
s128 mid = (l + r) >> 1;
if (mid * mid <= n)
ans = mid, l = mid + 1;
else
r = mid - 1;
}
return ans;
}
inline s128 query1(s128 n) {
s128 lim = BinarySearch1(n);
s128 ans = (n * (n + 1) / 2) % Mod * lim % Mod;
ans = (ans - (n * 2 + 1) % Mod * ((S1(lim - 1) + S2(lim - 1)) % Mod * inv2 % Mod % Mod) % Mod) % Mod;
ans = (ans + S3(lim - 1)) % Mod;
ans = (ans + (S4(lim - 1) + S2(lim - 1)) % Mod * inv2 % Mod % Mod) % Mod;
return ans;
}
inline s128 query2(s128 n) {
s128 lim = Sqrt(n);
s128 ans = (S1(n + 1) - 1) % Mod;
ans = (ans + lim % Mod * ((n + 1) % Mod) % Mod) % Mod;
ans = (ans - S2(lim)) % Mod;
ans = (ans + n * (n + 1) % Mod * inv2 % Mod * lim % Mod) % Mod;
ans = (ans - n % Mod * S2(lim) % Mod) % Mod;
ans = (ans + (S4(lim) - S2(lim)) % Mod * inv2 % Mod % Mod) % Mod;
return ans;
}
inline s128 query3(s128 n) {
return (2 * S2(n) % Mod + S1(n)) % Mod;
}
inline s128 query4(s128 h, s128 n) {
s128 lim = BinarySearch2(n);
s128 ans = lim * n % Mod;
ans = (ans - S2(lim - 1)) % Mod;
return ans;
}
inline s128 query5(s128 h, s128 n) {
s128 lim = BinarySearch1(n);
s128 ans = lim * (n + 1) % Mod;
if (lim) {
ans = (ans - lim) % Mod;
ans = (ans - S1(lim - 1)) % Mod;
ans = (ans - S2(lim - 1)) % Mod;
}
return ans;
}
inline s128 query(s128 n) {
if (!n)
return 0;
s128 sq = Sqrt(n);
s128 ans = (query1(sq) + query2(sq) + query3(sq)) % Mod;
n = (sq + 1) * (sq + 1) - 1 - n;
ans = (ans - sq * n % Mod) % Mod;
if (n <= sq + 1)
ans = (ans - query4(sq, n)) % Mod;
else
ans = (ans - query4(sq, sq + 1) - query5(sq, n - sq - 1)) % Mod;
return ans;
}
signed main() {
freopen("mex.in", "r", stdin);
freopen("mex.out", "w", stdout);
int testid = read(), T = read();
while (T--) {
s128 l = read<s128>(), r = read<s128>();
printf("%d\n", (int)(((query(r) - query(l - 1)) % Mod + Mod) % Mod));
}
return 0;
}
森林
定义对一棵树做一次变换的含义为:当以 \(1\) 号节点为根时,交换两个互相不为祖先的点的子树。一棵树的权值为对它进行至多一次变换能得到的最大直径长度。
初始时只有一个节点 \(1\),有 \(n-1\) 个操作,第 \(i\) 次操作会给出一个整数 \(x\),表示新加入第 \(i+1\) 号点,并与第 \(x\) 号点连一条边。每次操作后输出当前的树的权值。
\(n \leq 2 \times 10^5\) ,强制在线
显然要找一个这样的东西,使得最长:

可以通过一次变换将这三段连成一条直径。定义第三叉为 \(\min(OA, OB, OC)\) 。
先给出做法:维护直径长度 \(dist\) 与其端点 \(u, v\) ,以及第三叉长度 \(mx\) 。对于每次加入叶子 \(w\) :
- 加入 \(w\) 后直径变长,则第三叉长度不变。
- 加入 \(w\) 后直径不变:则用 \(w\) 到 \((u, v)\) 的距离更新 \(mx\) 。
对于该做法的正确性,只要证明三点:
-
新的直径一定是 \((u, w), (u, v), (w, v)\) 之一。
分两类讨论。若新的直径没有原来长,显然成立。否则设其为 \((w, x)\) ,那么 \((fa_w, x)\) 一定是原来直径之一。由于 \((u, v)\) 也是原来直径之一,且树上所有直径相交与同一点,设其为 \(rt\) ,则 \(dist(u, rt) = dist(v, rt) = dist(fa_w, rt) = dist(x, rt)\) ,则 \((u, w) = (v, w) = (x, w)\) ,与假设矛盾。
-
除去第三叉(不包括中心点)一定是一条直径。
首先如果选直径必定最优,那么选任意一条直径都可以,因为树上所有直径交于一点。
分两类讨论:
- 有多条直径:选三条半径组成答案。
- 只有一条直径:假设有一个更优的答案 \(OF\) 连在非直径 \(DE\) 上。如果 \(DE\) 与直径不交,那么吧 \(DE\) 拿下来给直径更优,否则不如把 \(OF\) 变为第三叉,\(DE\) 伸长为直径肯定更优,与假设矛盾。
-
如果直径变大,第三叉不变,否则第三叉要么不变要么变大。
后半句是废话,只有 \(w\) 到 \((u, v)\) 的距离可能更新第三叉。
对前半句,如果新的直径是 \((u, w)\) (\((v, w)\) 同理),那么用 \(v\) 到 \((u, w)\) 做第三叉,而原本的第三叉是 \(fa_w\) 到 \((u, v)\) ,并且此时必有 \(fa_w\) 到 \((u, v)\) 和 \(v\) 到 \((u, w)\) 一样长。
于是维护一下 LCA 即可,时间复杂度 \(O(n \log n)\) 。
#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 7, LOGN = 23;
int fa[N][LOGN];
int dep[N];
int testid, n, X = 1, Y = 1, Z = 1;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
inline int LCA(int x, int y) {
if (dep[x] < dep[y])
swap(x, y);
for (int i = 0, h = dep[x] - dep[y]; h; ++i, h >>= 1)
if (h & 1)
x = fa[x][i];
if (x == y)
return x;
for (int i = LOGN - 1; ~i; --i)
if (fa[x][i] != fa[y][i])
x = fa[x][i], y = fa[y][i];
return fa[x][0];
}
inline int dist(int x, int y) {
return dep[x] + dep[y] - dep[LCA(x, y)] * 2;
}
inline int getdis(int u) {
int lca = LCA(X, Y);
return LCA(u, lca) == lca ? min(dist(LCA(X, u), u), dist(LCA(Y, u), u)) : dist(u, lca);
}
inline void update(int u) {
if (dist(X, u) < dist(Y, u))
swap(X, Y);
if (dist(X, u) > dist(X, Y))
swap(u, Y);
if (getdis(Z) < getdis(u))
swap(u, Z);
}
signed main() {
freopen("forest.in", "r", stdin);
freopen("forest.out", "w", stdout);
testid = read(), n = read();
dep[1] = 1;
for (int u = 2, lstans = 0; u <= n; ++u) {
dep[u] = dep[fa[u][0] = read() ^ lstans] + 1;
for (int i = 1; i < LOGN; ++i)
fa[u][i] = fa[fa[u][i - 1]][i - 1];
update(u);
printf("%d\n", lstans = dist(X, Y) + max(0, getdis(Z) - 1));
}
return 0;
}
Day 5
最短路
给定无向图,走过每条边都需要花费 \(1\) 秒。请选择至多 \(k\) 个点(不能选择点 \(0\) 或点 \(n - 1\) ),令经过这些点也需要花费 \(1\) 秒,使得从点 \(0\) 走到点 \(n-1\) 的最短时间最大,求出这个最大值。
\(n \leq 100\)
考虑二分答案。判断答案 \(mid\) 是否可行时建立 \(mid\) 层二分图,每层都拆点,则经过同层的边时就代表选择这个点,经过相邻层的边时就代表走这条边。原图的边容量设为 \(+ \infty\) ,拆点之间的边( \(0, n - 1\) 除外)容量设为 \(1\) 。此时需要选择的点数就是最小割,判断其是否不超过 \(k\) 即可。
#include <bits/stdc++.h>
using namespace std;
const int inf = 0x3f3f3f3f;
const int N = 1e2 + 7;
struct Graph {
vector<int> e[N];
inline void insert(int u, int v) {
e[u].emplace_back(v);
}
} G;
int s[N << 1][N], t[N << 1][N];
int n, m, k;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
namespace Dinic {
const int N = 1e5 + 7, M = 1e7 + 7;
struct Edge {
int nxt, v, f;
} e[M];
int head[N], cur[N], dep[N];
bool vis[N];
int n, S, T, tot, maxflow;
inline void prework(int _n, int _S, int _T) {
n = _n, S = _S, T = _T, tot = 1;
memset(head + 1, 0, sizeof(int) * n);
}
inline void insert(int u, int v, int f) {
e[++tot] = (Edge) {head[u], v, f}, head[u] = tot;
e[++tot] = (Edge) {head[v], u, 0}, head[v] = tot;
}
inline bool bfs() {
memcpy(cur + 1, head + 1, sizeof(int) * n);
memset(dep + 1, 0, sizeof(int) * n);
memset(vis + 1, false, sizeof(bool) * n);
queue<int> q;
dep[S] = 1, q.emplace(S);
while (!q.empty()) {
int u = q.front();
q.pop();
for (int i = head[u]; i; i = e[i].nxt) {
int v = e[i].v, f = e[i].f;
if (f && !dep[v])
dep[v] = dep[u] + 1, q.emplace(v);
}
}
return dep[T];
}
int dfs(int u, int flow) {
if (u == T)
return flow;
vis[u] = true;
int outflow = 0;
for (int &i = cur[u]; i; i = e[i].nxt) {
int v = e[i].v, f = e[i].f;
if (f && dep[v] == dep[u] + 1 && !vis[v]) {
int res = dfs(v, min(f, flow - outflow));
e[i].f -= res, e[i ^ 1].f += res, outflow += res;
if (outflow == flow)
break;
}
}
if (outflow == flow)
vis[u] = false;
return outflow;
}
inline int solve(const bool flag = true) {
if (flag)
maxflow = 0;
while (bfs())
maxflow += dfs(S, inf);
return maxflow;
}
} // namespace Dinic
inline bool check(int lambda) {
int tot = 0;
for (int i = 0; i <= lambda; ++i)
for (int u = 1; u <= n; ++u)
s[i][u] = ++tot, t[i][u] = ++tot;
Dinic::prework(tot, t[0][1], t[lambda][n]);
for (int i = 0; i <= lambda; ++i) {
for (int u = 1; u <= n; ++u)
Dinic::insert(s[i][u], t[i][u], u != n);
if (i < lambda) {
for (int u = 1; u <= n; ++u) {
Dinic::insert(s[i][u], t[i + 1][u], inf);
for (int v : G.e[u])
Dinic::insert(t[i][u], s[i + 1][v], inf);
Dinic::insert(s[i][n], s[i + 1][n], inf);
}
}
}
Dinic::solve();
return Dinic::maxflow <= k;
}
signed main() {
freopen("min.in", "r", stdin);
freopen("min.out", "w", stdout);
n = read(), m = read(), k = min(read(), n - 2);
for (int i = 1; i <= m; ++i) {
int u = read() + 1, v = read() + 1;
G.insert(u, v), G.insert(v, u);
}
int l = 1, r = 2 * n, mid, ans = -1;
while (l <= r) {
mid = (l + r) >> 1;
if (check(mid))
l = mid + 1, ans = mid;
else
r = mid - 1;
}
printf("%d", ans);
return 0;
}
子图
给定一张无向图,求从所有边中随机选出 \(k\) 条边的导出子图连通的方案数。
\(n \leq 10^5\) ,\(m \leq 2 \times 10^5\) ,\(k \in \{ 3, 4 \}\)
先考虑 \(k = 3\) 的情况:
- 三元环:每个三元环会算三次。
- 链:枚举中间边,算两边的贡献,但两边的出点重合会退化为三元环。
- 菊花图:枚举根算。
再考虑 \(k = 4\) 的情况:
- 一条链:枚举中点算。
- \((1, 1, 2)\) 的菊花:枚举 \(2\) 那条链和菊花中心的边,然后算这条边两边是挂一个还是两个的贡献。
- \((1, 1, 1, 1)\) 的菊花:直接组合数算即可。
- 三元环挂点:求出每个点所在三元环个数然后枚举点算贡献。
- 四元环:直接求即可。
需要注意退化的情况,细节很多。
#include <bits/stdc++.h>
using namespace std;
const int Mod = 1e9 + 7, inv2 = 500000004, inv4 = 250000002, inv6 = 166666668;
const int N = 1e5 + 7;
struct Graph {
vector<int> e[N];
inline void insert(int u, int v) {
e[u].emplace_back(v);
}
} G, nG;
int deg[N], tag[N], cnt3[N], cnt[N];
int n, m, k;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
inline bool cmp(const int &u, const int &v) {
return deg[u] > deg[v] || (deg[u] == deg[v] && u < v);
}
inline int FindCircle3() {
for (int u = 1; u <= n; ++u)
for (int v : G.e[u])
if (cmp(u, v))
nG.insert(u, v);
int sum = 0;
for (int i = 1; i <= n; ++i) {
for (int j : nG.e[i])
tag[j] = i;
for (int j : nG.e[i])
for (int k : nG.e[j])
if (tag[k] == i) {
++cnt3[i], ++cnt3[j], ++cnt3[k];
++sum;
if (sum == Mod)
sum = 0;
}
}
return sum;
}
inline int FindCircle4() {
memset(tag, 0, sizeof(tag));
int sum = 0;
for (int u = 1; u <= n; ++u)
for (int v : G.e[u])
if (cmp(u, v))
for (int w : G.e[v])
if (cmp(u, w) && w != u) {
if (tag[w] != u)
tag[w] = u, cnt[w] = 1;
else
++cnt[w];
sum = (sum + cnt[w] - 1) % Mod;
}
return sum;
}
signed main() {
freopen("subgraph.in", "r", stdin);
freopen("subgraph.out", "w", stdout);
n = read(), m = read(), k = read();
for (int i = 1; i <= m; ++i) {
int u = read(), v = read();
G.insert(u, v), G.insert(v, u);
++deg[u], ++deg[v];
}
if (k == 3) {
int ans = 0;
for (int u = 1; u <= n; ++u)
for (int v : G.e[u])
ans = (ans + 1ll * (deg[u] - 1) * (deg[v] - 1) % Mod) % Mod;
ans = 1ll * ans * inv2 % Mod;
for (int i = 1; i <= n; ++i)
ans = (ans + 1ll * deg[i] * (deg[i] - 1) % Mod * (deg[i] - 2) % Mod * inv6 % Mod) % Mod;
printf("%d", ((ans - 2ll * FindCircle3() % Mod) % Mod + Mod) % Mod);
} else {
int ans = ((-3ll * FindCircle3() % Mod - 3ll * FindCircle4() % Mod) % Mod + Mod) % Mod;
for (int u = 1; u <= n; ++u) {
ans = (((ans + 1ll * deg[u] * (deg[u] - 1) % Mod * (deg[u] - 2) % Mod * (deg[u] - 3) % Mod * inv4 % Mod * inv6 % Mod) % Mod
- 3ll * cnt3[u] * (deg[u] - 2) % Mod) % Mod + Mod) % Mod;
for (int v : G.e[u])
ans = (ans + 1ll * (deg[u] - 1) * (deg[u] - 2) % Mod * inv2 % Mod * (deg[v] - 1) % Mod) % Mod;
int sum = 0;
for (int v : G.e[u]) {
ans = (ans + 1ll * (deg[v] - 1) * sum % Mod) % Mod;
sum = (sum + deg[v] - 1) % Mod;
}
}
printf("%d", ans);
}
return 0;
}
吃
给出一张无向连通图,每次随机选择一个还未删除的点 \(v\),然后访问所有与 \(v\) 连通的点,然后删除点 \(v\) ,直至所有点都被删除。求期望访问次数。
\(n \leq 10^5\) ,\(m \in \{n - 1, n \}\)
先考虑树的情况。显然这个期望就是每种情况的概率和,又因为期望具有线性性,所以可以考虑两两点对对答案的贡献。对于有序点对 \((x, y)\) ,设树上路径经过 \(c\) 个节点,对答案贡献为 \(\dfrac{1}{c}\) 。
若在访问 \(x\) 时 \(y\) 能够被访问,则 \(x \to y\) 的这条链还没有断。也就是在所有删点方案中,只要 \(x\) 在 \(x \to y\) 这条链除了 \(x\) 之外的任意一个点被删除之前删除,就会产生贡献。这个概率为 \(\dfrac{1}{c}\) 。
所以我们只需要对每个 \(c\) 求出长度为 \(c\) 的路径条数,点分治 + NTT合并即可。
接下来考虑基环树,仍然是考虑有序点对 \((x, y)\) 的贡献。
若 \((x, y)\) 在同一子树内,则与树的计算方法相同。
否则记 \(a, b\) 为 \(x, y\) 所在树的根把环分为两侧后两侧分别的长度, \(c = dep_x + dep_y\) ,则删除 \(x\) 时 \(x, y\) 仍连通的概率等于 \(x\) 在 \(a + c\) 个点中最先删除的概率或在 \(b + c\) 个点中最先删除的概率(分别对应经过环的两侧),此时贡献为 \(\dfrac{1}{a + c} + \dfrac{1}{b + c} - \dfrac{1}{a + b + c}\) 。随便选一条边破环为链,用分治配合 NTT 合并即可。
总时间复杂度 \(O(n \log^2 n)\) 。
#include <bits/stdc++.h>
using namespace std;
const int Mod = 998244353, rt = 3, invrt = 332748118;
const int N = 1e5 + 7;
struct Graph {
vector<int> e[N];
inline void insert(int u, int v) {
e[u].emplace_back(v);
}
} G;
int cnt[N];
int n, m;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
inline int add(int x, int y) {
x += y;
if (x >= Mod)
x -= Mod;
return x;
}
inline int dec(int x, int y) {
x -= y;
if (x < 0)
x += Mod;
return x;
}
inline int mi(int a, int b) {
int res = 1;
for (; b; b >>= 1, a = 1ll * a * a % Mod)
if (b & 1)
res = 1ll * res * a % Mod;
return res;
}
namespace Poly {
#define cpy(f, g, n) memcpy(f, g, sizeof(int) * (n))
#define clr(f, n) memset(f, 0, sizeof(int) * (n))
const int S = 2e6 + 7;
int rev[S], inv[S], iG[S][2];
inline void calrev(int n) {
for (int i = 0; i < n; ++i)
rev[i] = (rev[i >> 1] >> 1) | (i & 1 ? n >> 1 : 0);
}
inline int calc(int n) {
int len = 1;
while (len <= n)
len <<= 1;
return calrev(len), len;
}
inline void NTT(int *f, int n, int op) {
for (int i = 0; i < n; ++i)
if (i < rev[i])
swap(f[i], f[rev[i]]);
for (int k = 1; k < n; k <<= 1) {
int tG = mi(op == 1 ? rt : invrt, (Mod - 1) / (k << 1));
for (int i = 0; i < n; i += k << 1) {
int buf = 1;
for (int j = 0; j < k; ++j) {
int fl = f[i + j], fr = 1ll * buf * f[i + j + k] % Mod;
f[i + j] = add(fl, fr), f[i + j + k] = dec(fl, fr);
buf = 1ll * buf * tG % Mod;
}
}
}
if (op == -1) {
int invn = mi(n, Mod - 2);
for (int i = 0; i < n; ++i)
f[i] = 1ll * f[i] * invn % Mod;
}
}
inline void Times(int *f, int *g, int len) {
NTT(f, len, 1), NTT(g, len, 1);
for (int i = 0; i < len; ++i)
f[i] = 1ll * f[i] * g[i] % Mod;
NTT(f, len, -1);
}
inline void Mul(int *f, int n, int *g, int m, int *res) {
static int a[S], b[S];
int len = calc(n + m - 1);
cpy(a, f, n), clr(a + n, len - n);
cpy(b, g, m), clr(b + m, len - m);
Times(a, b, len);
cpy(res, a, n + m - 1);
}
#undef cpy
#undef clr
} // namespace Poly
namespace Method1 {
int siz[N], mxsiz[N], cur[N], buc[N], sum[N];
bool vis[N];
int root, mxdep;
int getsiz(int u, int f) {
siz[u] = 1;
for (int v : G.e[u])
if (!vis[v] && v != f)
siz[u] += getsiz(v, u);
return siz[u];
}
void getroot(int u, int f, const int Siz) {
siz[u] = 1, mxsiz[u] = 0;
for (int v : G.e[u])
if (!vis[v] && v != f)
getroot(v, u, Siz), siz[u] += siz[v], mxsiz[u] = max(mxsiz[u], siz[v]);
mxsiz[u] = max(mxsiz[u], Siz - siz[u]);
if (!root || mxsiz[u] < mxsiz[root])
root = u;
}
void dfs(int u, int f, int d) {
++cur[d], mxdep = max(mxdep, d);
for (int v : G.e[u])
if (!vis[v] && v != f)
dfs(v, u, d + 1);
}
void calc(int u) {
++cnt[1], buc[1] = 1;
int len = 1;
for (int v : G.e[u]) {
if (vis[v])
continue;
mxdep = 0, dfs(v, u, 1);
Poly::Mul(buc + 1, len, cur + 1, mxdep, sum + 2);
for (int i = 2; i <= len + mxdep; ++i)
cnt[i] = add(cnt[i], 2ll * sum[i] % Mod);
len = max(len, mxdep + 1);
for (int i = 1; i <= mxdep; ++i)
buc[i + 1] += cur[i], cur[i] = 0;
}
memset(buc + 1, 0, sizeof(int) * len);
}
void build(int u) {
vis[u] = true, calc(u);
for (int v : G.e[u])
if (!vis[v])
root = 0, getroot(v, u, getsiz(v, u)), build(root);
}
inline void solve() {
root = 0, getroot(1, 0, n), build(root);
}
} // namespace Method1
namespace Method2 {
int sta[N], circle[N], ldep[N], rdep[N], sum[N];
bool vis[N], incir[N];
int top, len, lmxdep, rmxdep;
bool findcircle(int u, int f) {
vis[u] = true, sta[++top] = u;
for (int v : G.e[u]) {
if (v == f)
continue;
if (vis[v]) {
while (sta[top] != v)
incir[sta[top]] = true, circle[++len] = sta[top--];
incir[v] = true, circle[++len] = v;
return true;
}
if (findcircle(v, u))
return true;
}
--top;
return false;
}
inline void dfs(int u, int f, int d, bool op) {
if (op)
++rdep[d], rmxdep = max(rmxdep, d);
else
++ldep[d], lmxdep = max(lmxdep, d);
for (int v : G.e[u])
if (v != f && !incir[v])
dfs(v, u, d + 1, op);
}
void divide(int l, int r) {
if (l == r)
return;
int mid = (l + r) >> 1;
lmxdep = rmxdep = 0;
for (int i = l; i <= mid; ++i)
dfs(circle[i], 0, i, false);
for (int i = mid + 1; i <= r; ++i)
dfs(circle[i], 0, len - i + 1, true);
int lbe = l, rbe = len - r + 1;
Poly::Mul(ldep + lbe, lmxdep - lbe + 1, rdep + rbe, rmxdep - rbe + 1, sum + lbe + rbe);
for (int i = lbe + rbe; i <= lmxdep + rmxdep; ++i)
cnt[i] = add(cnt[i], 2ll * sum[i] % Mod);
memset(ldep + lbe, 0, sizeof(int) * (lmxdep - lbe + 1));
memset(rdep + rbe, 0, sizeof(int) * (rmxdep - rbe + 1));
lmxdep = rmxdep = 0;
for (int i = l; i <= mid; ++i)
dfs(circle[i], 0, 0, false);
for (int i = mid + 1; i <= r; ++i)
dfs(circle[i], 0, 0, true);
Poly::Mul(ldep, lmxdep + 1, rdep, rmxdep + 1, sum);
for (int i = 0; i <= lmxdep + rmxdep; ++i)
cnt[len + i] = dec(cnt[len + i], 2ll * sum[i] % Mod);
memset(ldep, 0, sizeof(int) * (lmxdep + 1));
memset(rdep, 0, sizeof(int) * (rmxdep + 1));
divide(l, mid), divide(mid + 1, r);
}
inline void solve() {
findcircle(1, 0);
int x = circle[1], y = circle[len];
G.e[x].erase(find(G.e[x].begin(), G.e[x].end(), y));
G.e[y].erase(find(G.e[y].begin(), G.e[y].end(), x));
Method1::build(1);
divide(1, len);
}
} // namespace Method2
signed main() {
freopen("eat.in", "r", stdin);
freopen("eat.out", "w", stdout);
n = read(), m = read();
for (int i = 1; i <= m; ++i) {
int u = read(), v = read();
G.insert(u, v), G.insert(v, u);
}
if (m == n - 1)
Method1::solve();
else
Method2::solve();
int ans = 0;
for (int i = 1; i <= n; ++i)
ans = add(ans, 1ll * cnt[i] * mi(i, Mod - 2) % Mod);
printf("%d", ans);
return 0;
}
Day 6
堆
求有多少种不同的 \(n\) 个节点的二叉堆满足所有节点权值为 \(1 \sim n\) 的排列。
\(n \leq 10^9\)
考虑合法二叉堆的条件:每个点 \(u\) 的权值都是 \(u\) 子树中最小的。设 \(siz_u\) 表示 \(u\) 子树的大小,那么有 \(\dfrac{1}{siz_u}\) 的概率,满足 \(u\) 的权值都是 \(u\) 子树中最小的,因此答案就是 \(\dfrac{n!}{\prod_{i = 1}^n siz_i}\) ,问题转化为求 \(\prod_{i = 1}^n siz_i\) 。
设树高为 \(h\),那么只有第 \(h\) 层的节点不满,去掉第 \(h\) 层的节点之后,必定是一棵满二叉树。
将第 \(h\) 层的节点分成两个部分,在根的左子树和在根的右子树。如果第二部分中没有节点,那么根的右子树是棵满二叉树;如果第二部分中有节点,那么根的左子树肯定是满二叉树。于是根的两棵子树必有其一是满二叉树。又因为对于二叉堆中的每个节点 ,它的子树都有二叉堆的性质,所以每个节点都满足:它的两个子树,必有其一是满二叉树。
记 \(g_j\) 表示 \(j\) 层满二叉树的 \(\prod siz_i\) ,\(f_u\) 表示以 \(u\) 为根的子树中的 \(\prod siz_i\) 。
对于 \(g\) ,显然有 \(g_j = g_{j - 1}^2 \times (2^j - 1)\) 。
对于 \(f\) ,设 \(u\) 的左右子树分别为 \(x, y\) ,\(u\) 子树的树高为 \(p\) 。可以通过计算 \(u\) 子树中最后一层的节点数是否过半,来判断哪个子树是满二叉树。
-
若 \(x\) 子树是满二叉树,则 \(f_u = g_{p - 1} \times f_y \times siz_u\) ,\(siz_u = 2^{p - 1} + siz_y\) 。
-
若 \(y\) 子树是满二叉树,则 \(f_u = g_{p - 2} \times f_x \times siz_u\) ,\(siz_u = 2^{p - 2} + siz_x\) 。
对于计算 \(n!\) ,分块打表即可。
#include <bits/stdc++.h>
using namespace std;
const int fac[] = { 1, 682498929, 491101308, 76479948, 723816384, 67347853, 27368307, 625544428,
199888908, 888050723, 927880474, 281863274, 661224977, 623534362, 970055531, 261384175,
195888993, 66404266, 547665832, 109838563, 933245637, 724691727, 368925948, 268838846,
136026497, 112390913, 135498044, 217544623, 419363534, 500780548, 668123525, 128487469,
30977140, 522049725, 309058615, 386027524, 189239124, 148528617, 940567523, 917084264,
429277690, 996164327, 358655417, 568392357, 780072518, 462639908, 275105629, 909210595,
99199382, 703397904, 733333339, 97830135, 608823837, 256141983, 141827977, 696628828,
637939935, 811575797, 848924691, 131772368, 724464507, 272814771, 326159309, 456152084,
903466878, 92255682, 769795511, 373745190, 606241871, 825871994, 957939114, 435887178,
852304035, 663307737, 375297772, 217598709, 624148346, 671734977, 624500515, 748510389,
203191898, 423951674, 629786193, 672850561, 814362881, 823845496, 116667533, 256473217,
627655552, 245795606, 586445753, 172114298, 193781724, 778983779, 83868974, 315103615,
965785236, 492741665, 377329025, 847549272, 698611116 };
const int Mod = 1e9 + 7;
const int N = 31, B = 1e7;
int g[N];
int n;
inline int add(int x, int y) {
x += y;
if (x >= Mod)
x -= Mod;
return x;
}
inline int dec(int x, int y) {
x -= y;
if (x < 0)
x += Mod;
return x;
}
inline int mi(int a, int b) {
int res = 1;
for (; b; b >>= 1, a = 1ll * a * a % Mod)
if (b & 1)
res = 1ll * res * a % Mod;
return res;
}
inline int factorial(int n) {
int res = fac[n / B];
for (int i = (n / B) * B + 1; i <= n; ++i)
res = 1ll * res * i % Mod;
return res;
}
int dfs(int siz) {
if (siz <= 1)
return 1;
int res = siz, x = __lg(siz + 1);
siz -= (1 << x) - 1;
if (siz < (1 << x - 1))
res = 1ll * res * g[x - 1] % Mod * dfs((1 << x - 1) + siz - 1) % Mod;
else
res = 1ll * res * g[x] % Mod * dfs(siz - 1) % Mod;
return res;
}
signed main() {
freopen("heap.in", "r", stdin);
freopen("heap.out", "w", stdout);
g[0] = 1;
for (int i = 1; i < N; ++i)
g[i] = 1ll * g[i - 1] * g[i - 1] % Mod * ((1 << i) - 1) % Mod;
scanf("%d", &n);
printf("%d", 1ll * factorial(n) * mi(dfs(n), Mod - 2) % Mod);
return 0;
}
密文
有一串长度为 \(n\) 的密文,密文的每一位都可以用一个非负整数来描述,并且每一位都有一个权值 \(a_i\) 。
可以进行任意多次操作,每次操作可以选择连续一段密文,花费选择的所有位上权值的异或和的代价获得这段密文每一位的异或和。
求至少需要花费多少代价才能将密文的每一位都破解出来。
\(n \leq 10^5\)
记 \(s_i\) 为前 \(i\) 个数的异或和,把数列看作 \(0 \sim n\) 共 \(n + 1\) 个点。如果获得了 \([i,j]\) 的异或和,那么需要花费 \(s_i \oplus s_{j - 1}\) 的代价,放到图上就是 \(i \to j - 1\) 连权值为 \(s_i \oplus s_{j - 1}\) 的双向边。
由于 \(s_0\) 已知,而且显然按照这样建图,如果 \(x \to y\) 有路径,那么 \(s_x\) 和 \(s_y\) 可以互推。
目标是让所有的 \(s_i\) 已知,即让 \(1 \sim n\) 都有到 \(0\) 的路径,那么就是让所有点连通,即这张图的 MST 就是答案。
然后就是 CF888G Xor-MST 了。
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int inf = 0x3f3f3f3f;
const int N = 2e5 + 7;
int a[N];
int n;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
namespace Trie {
int ch[N << 5][2];
int L[N << 5], R[N << 5];
int tot = 1;
inline void insert(int k, int id) {
int u = 1;
for (int i = 31; ~i; --i) {
L[u] = (L[u] ? L[u] : id), R[u] = id;
int idx = k >> i & 1;
if (!ch[u][idx])
ch[u][idx] = ++tot;
u = ch[u][idx];
}
L[u] = (L[u] ? L[u] : id), R[u] = id;
}
inline int query(int u, int d, int k) {
int res = 0;
for (int i = d; ~i; --i) {
if (ch[u][k >> i & 1])
u = ch[u][k >> i & 1];
else if (ch[u][~k >> i & 1])
u = ch[u][~k >> i & 1], res |= 1 << i;
else
return 0;
}
return res;
}
ll dfs(int u, int d) {
if (d == -1)
return 0;
else if (ch[u][0] && ch[u][1]) {
int ans = inf;
if (R[ch[u][0]] - L[ch[u][0]] + 1 <= R[ch[u][1]] - L[ch[u][1]] + 1) {
for (int i = L[ch[u][0]]; i <= R[ch[u][0]]; ++i)
ans = min(ans, query(ch[u][1], d - 1, a[i]) | (1 << d));
} else {
for (int i = L[ch[u][1]]; i <= R[ch[u][1]]; ++i)
ans = min(ans, query(ch[u][0], d - 1, a[i]) | (1 << d));
}
return ans + dfs(ch[u][0], d - 1) + dfs(ch[u][1], d - 1);
} else if (ch[u][0])
return dfs(ch[u][0], d - 1);
else if (ch[u][1])
return dfs(ch[u][1], d - 1);
else
return 0;
}
} // namespace Trie
signed main() {
freopen("secret.in", "r", stdin);
freopen("secret.out", "w", stdout);
n = read();
for (int i = 1; i <= n; ++i)
a[i + 1] = a[i] ^ read();
sort(a + 1, a + n + 2);
for (int i = 1; i <= n + 1; ++i)
Trie::insert(a[i], i);
printf("%lld", Trie::dfs(1, 31));
return 0;
}
树
给出一棵 \(n\) 个点的树 \(A\) 和一棵 \(m\) 个点的树 \(B\),求 \(A\) 有多少个不同的连通子图与 \(B\) 同构。
\(n \leq 2000\) ,\(m \leq 12\)
首先判断同构要先给两棵树定一个根,那么枚举 \(B\) 的根,再枚举所有 \(A\) 中的点与之对应。
下记 \(A_u\) 表示 \(A\) 中的节点 \(u\) ,\(rt_A\) 为 \(A\) 的根,\(B_u, rt_B\) 同理。设 \(f_{u, i}\) 表示 \(A_u\) 为根的子图与 \(B_i\) 为根的子树匹配的个数。
具体地,先确定 \(rt_B\) ,然后把 \(A\) dfs 一遍。回溯的时候转移:
- 先枚举 \(A_u\) 对应的 \(B_k\) ,\(B_k\) 可以是任意一个 \(B\) 中的点。
- 记一个状压数组 \(g_{t, s}\) ,\(s\) 的第 \(i\) 位表示 \(B_k\) 的第 \(i\) 个子树是否出现过,\(g_{t, s}\) 为 \(A_u\) 以及前 \(t\) 个子树生成的子图,与 \(s\) 中的子树匹配方案数,边界 \(g_{0, 0} = 1\) ,即 \(A_u\) 和 \(B_k\) 都是叶子的情况。
- 枚举 \(A_v\) ,\(v\) 是 \(u\) 的儿子(也就是枚举 \(t\) 了),枚举 \(B_j\) 和它对应,再枚举 \(s\) ,找到它第一个没出现的第 \(x\) 个子树,转移到 \(g_{t + 1, s + 2^{x - 1}}\) 。
- \(f_{u, k} = g_{S_{A_u}, 2^{S_{B_j}} - 1}\) ,其中 \(S(x)\) 表示 \(x\) 的子树个数。
但是这样会重复,将 \(f_{u, i}\) 定义为 \(A_u\) 和 \(B\) 的第 \(j\) 种子树对应的方案数即可。
判断同构可以用树哈希。
#include <bits/stdc++.h>
typedef unsigned long long ull;
using namespace std;
const int Mod = 1e9 + 7;
const int N = 2e3 + 7, M = 13;
struct Graph {
vector<vector<int> > e;
inline Graph(const int n) {
e.resize(n + 1);
}
inline void insert(int u, int v) {
e[u].emplace_back(v);
}
} A(N), B(M);
set<ull> st;
ull h[M];
int f[N][M];
int fa[M], rk[M];
int n, m, root, len;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
inline int add(int x, int y) {
x += y;
if (x >= Mod)
x -= Mod;
return x;
}
inline int dec(int x, int y) {
x -= y;
if (x < 0)
x += Mod;
return x;
}
void dfs1(int u, int f) {
fa[u] = f, h[u] = 1;
for (int v : B.e[u])
if (v != f)
dfs1(v, u), h[u] += h[v];
h[u] *= h[u];
}
void dfs2(int u, int father) {
for (int v : A.e[u])
if (v != father)
dfs2(v, u);
for (int i = 1; i <= len; ++i) {
f[u][i] = 0;
vector<int> num(len + 1);
int r = rk[i];
for (int v : B.e[r])
if (v != fa[r])
++num[h[v]];
for (int j = 1; j <= len; ++j)
num[j] += num[j - 1];
int S = 1 << (B.e[r].size() - (r != root));
vector<int> g(S);
g[0] = 1;
for (int v : A.e[u]) {
if (v == father)
continue;
vector<int> h = g;
for (int j = 1; j <= len; ++j)
if (num[j] > num[j - 1] && f[v][j])
for (int s = 0; s < S; ++s)
for (int k = num[j - 1]; k < num[j]; ++k)
if (~s >> k & 1) {
g[s | (1 << k)] = add(g[s | (1 << k)], 1ll * h[s] * f[v][j] % Mod);
break;
}
}
f[u][i] = g[S - 1];
}
}
inline int solve(int rt) {
dfs1(root = rt, 0);
if (st.find(h[root]) != st.end())
return 0;
st.insert(h[root]);
vector<ull> vec;
for (int i = 1; i <= m; ++i)
vec.emplace_back(h[i]);
sort(vec.begin(), vec.end());
vec.erase(unique(vec.begin(), vec.end()), vec.end());
len = vec.size();
for (int i = 1; i <= m; ++i)
rk[h[i] = lower_bound(vec.begin(), vec.end(), h[i]) - vec.begin() + 1] = i;
dfs2(1, 0);
int res = 0;
for (int i = 1; i <= n; ++i)
res = add(res, f[i][h[root]]);
return res;
}
signed main() {
freopen("tree.in", "r", stdin);
freopen("tree.out", "w", stdout);
n = read();
for (int i = 1; i < n; ++i) {
int u = read(), v = read();
A.insert(u, v), A.insert(v, u);
}
m = read();
for (int i = 1; i < m; ++i) {
int u = read(), v = read();
B.insert(u, v), B.insert(v, u);
}
int ans = 0;
for (int i = 1; i <= m; ++i)
ans = add(ans, solve(i));
printf("%d", ans);
return 0;
}

浙公网安备 33010602011771号