Loading

[CTSC2011] 字符串重排

是个美男题。其实不美。

首先这道题要最大化 \(\text{LCP}\) 的和,我们不难想到使用 Trie 去维护这些字符串,但是本人赛时瞪了 1h 才想到,我太菜了,我是怎么在不太会说着说忘记 Trie 的情况下想到的呢?

我观察样例发现在按字典序排序后能互相交换的其实不一定是某两位,而是区间,这两个区间满足的条件其实就是在 Trie 树上是兄弟关系。当我发现这些区间满足包含关系时想到了建树,而交换顺序其实就是任取 dfs 序,之后发现这棵树就是 Trie 树。

回归正题,我们转化问题之后该怎么做呢?首先显然我们需要从后往前去一个一个 check 合法,考虑一组附加条件的影响。

首先我们可以利用一个小方法去使串没有前缀关系,就是往每个串后加一个不在 \(a\)\(z\) 之间的字母即可。因为如果满足前缀关系需要限制会变多,而利用这个小 trick 可以去掉前缀关系,基本上是个万金油的小 trick。

然后我们分析发现最终对于一个结点 \(u\),我们会有三种限制:

  1. 子结点 \(v\) 必须是第一个遍历的。
  2. 子结点 \(v\) 必须是最后一个遍历的。
  3. 子节点 \(v\) 的下一个遍历的必须是子结点 \(w\)

我们考虑这三个限制的影响。前两个限制会影响路径中除了首尾和最近公共祖先的位置,而第三个只会影响最近公共祖先。这时我们虽然可以使用树链剖分,但是关于 Trie 我们有一个性质,我们缩二度点之后树高只有 \(\mathcal{O}(\sqrt {\sum s_i})\) 级别,所以我们缩完直接暴力枚举即可。

具体怎么维护呢?对于前两种性质看起来是好维护的,第三种我们考虑使用链表维护,总之是一些很暴力的东西,不要想复杂就好了。比如说我们可以对于第三种限制从 \(v\)\(w\) 连一条边,然后我们可以连出若干连通块,对于连通块内每个点我们可以暴力把链头链尾以及链的长度维护出来,然后在考虑需要满足什么限制即可,实现细节可以借鉴代码:

#include <bits/stdc++.h>
#define int long long
#define rep(i, l, r) for (int i (l); i <= (r); ++ i)
#define rrp(i, l, r) for (int i (r); i >= (l); -- i)
#define eb emplace_back
using namespace std;
#define pii pair <int, int>
#define inf 100000000
#define ls (p << 1)
#define rs (ls | 1)
constexpr int N = 3e5 + 5, M = 1e6 + 5, P = 998244353;
typedef long long ll;
typedef unsigned long long ull;
inline int rd () {
  int x = 0, f = 1;
  char ch = getchar ();
  while (! isdigit (ch)) {
    if (ch == '-') f = -1;
    ch = getchar ();
  }
  while (isdigit (ch)) {
    x = (x << 1) + (x << 3) + (ch ^ 48);
    ch = getchar ();
  }
  return x * f;
}
int qpow (int x, int y) {
  int ret (1);
  for (; y; y >>= 1, x = x * x % P) if (y & 1) ret = ret * x % P;
  return ret;
}
vector <int> e[N];
int tr[N][27], tot = 1, deg[N], dep[N];
int n, m, id[N], pi[N], fa[N];
int sz[N], nxt[N], pre[N], st[N], ed[N], X[N], Y[N];
int hd[N], tl[N], mt[N];
char s[N];
vector <int> g[N];
class node {
  public:
  int u, v, l;
} ;
node LCA (int u, int v) {
  bool fl (0);
  if (dep[u] < dep[v]) swap (u, v), fl = 1;
  while (dep[u] > dep[v]) u = fa[u];
  while (fa[u] != fa[v]) u = fa[u], v = fa[v];
  return (node) {fl ? v : u, fl ? u : v, fa[u]};
}
vector <int> vec;
bool vis[N];
void dfs (int u) {
  if (mt[u]) vec.eb (mt[u]);
  vis[u] = 1;
  if (st[u]) {
    for (int w (st[u]); w; w = nxt[w]) dfs (w);
  }
  for (auto v : g[u]) {
    if (vis[v] || tl[v] == ed[u] || v != hd[v]) continue;
    for (int w (v); w; w = nxt[w]) dfs (w);
  }
  if (ed[u]) {
    for (int w (hd[ed[u]]); w; w = nxt[w]) dfs (w);
  }
}
int32_t main () {
  n = rd (), m = rd ();
  int ans (0);
  rep (i, 1, n) {
    scanf ("%s", s + 1);
    int len (strlen (s + 1)), p (1);
    s[++ len] = 'z' + 1;
    rep (j, 1, len) {
      int o (s[j] - 'a');
      if (! tr[p][o]) {
        tr[p][o] = ++ tot;
        if (++ deg[p] >= 2) ans += (j - 1) * (j - 1);
      }
      p = tr[p][o];
    } id[i] = p; mt[p] = i;
  }
  rrp (i, 1, tot) {
    if (deg[i] == 1 && i > 1) {
      rep (j, 0, 26) {
        if (tr[i][j]) pi[i] = pi[tr[i][j]];
      }
    } else {
      pi[i] = i; 
      rep (j, 0, 26) {
        if (tr[i][j]) g[i].eb (pi[tr[i][j]]), fa[pi[tr[i][j]]] = i;
      }
    }
    hd[i] = tl[i] = i; sz[i] = 1;
  }
  rep (i, 1, tot) for (auto u : g[i]) dep[u] = dep[i] + 1;
  rep (i, 1, m) X[i] = rd (), Y[i] = rd ();
  vector <int> A;
  rrp (i, 1, m) {
    int u (id[X[i]]), v (id[Y[i]]);
    node t (LCA (u, v));
    int l (t.l), p (t.u), q (t.v);
    bool chk (1);
    for (int x (u); fa[x] != l; x = fa[x]) {
      if (ed[fa[x]] && ed[fa[x]] != x) chk = 0;
      if (nxt[x]) chk = 0;
      if (st[fa[x]] == hd[x] && sz[x] < deg[fa[x]]) chk = 0;
    }
    for (int x (v); fa[x] != l; x = fa[x]) {
      if (st[fa[x]] && st[fa[x]] != x) chk = 0;
      if (pre[x]) chk = 0;
      if (ed[fa[x]] == tl[x] && sz[x] < deg[fa[x]]) chk = 0;
    }
    if (nxt[p] != q) {
      if (nxt[p] || pre[q]) chk = 0;
      if (ed[l] == p || st[l] == q) chk = 0;
      if (hd[p] == hd[q]) chk = 0;
      if (hd[p] == st[l] && tl[q] == ed[l] && sz[p] + sz[q] < deg[l]) chk = 0;
    }
    if (chk) {
      A.eb (i);
      for (int x (u); fa[x] != l; x = fa[x]) ed[fa[x]] = x;
      for (int x (v); fa[x] != l; x = fa[x]) st[fa[x]] = x;
      int ns (sz[p] + sz[q]);
      for (int x (hd[p]); x; x = nxt[x]) tl[x] = tl[q], sz[x] = ns;
      for (int x (tl[q]); x; x = pre[x]) hd[x] = hd[p], sz[x] = ns;
      nxt[p] = q, pre[q] = p;
    }
  }
  cout << ans << endl;
  sort (A.begin (), A.end ());
  printf ("%ld ", A.size ());
  assert (A.size () <= m);
  for (auto u : A) printf ("%lld ", u), assert (u <= m); puts ("");
  dfs (1);
  for (auto u : vec) printf ("%lld ", u);
}
posted @ 2025-09-30 15:08  lalaouye  阅读(12)  评论(0)    收藏  举报