题解:ARC176D Swap Permutation
题意
给定一个长度为 \(n\) 的排列 \(p\),并执行以下操作 \(m\) 次:选择 \(1\leq i<j\leq n\),交换 \(p_i\) 和 \(p_j\)。
定义一个序列 \(p\) 的权值为 \(\sum_{i=1}^{n-1}|p_i-p_{i-1}|\)。求在 \(\binom{n}{2}^m\) 种可能的操作后,\(p\) 的价值之和。答案对 \(998244353\) 取模。
对于所有数据,\(2\leq n\leq 2\times 10^5\),\(1\leq m\leq 2\times 10^5\)。
题解
一顿乱想后发现序列的权值很难进行统计,所以考虑重新刻画权值。显然
据此,考虑枚举 \(k\in[1,n)\),建立 \(0/1\) 序列 \(v_i=[p_i>k]\)。注意到若 \(v\) 中存在相邻的 \((0,1)/(1,0)\),则说明对应位置上 \(\min(p_i,p_{i+1})<j\leq \max(p_i,p_{i+1})\),与权值的形式恰好一致!因此,问题转化为对于每个 \(0/1\) 序列,计算所有可能的交换操作后相邻的 \((0,1)/(1,0)\) 的数量总和。
我们选定两个相邻的位置 \(j\) 和 \(j+1\) 进行观察,则本质不同的 \(0/1\) 对只有 3 种:\((0,0),(1,1),(0,1)/(1,0)\),依次标号为 \([1,3]\)。考虑 DP。令 \(f_{i,1/2/3}\) 表示 \(i\) 次操作后这两个位置最终形成某种 \(0/1\) 对的方案数。转移很简单,并且显然可以矩阵快速幂优化,这里直接给出转移矩阵:
其中 \(c_0=k,c_1=n-k\) 分别表示 \(v\) 中 \(0,1\) 的个数。依据选定的 \(j\) 和 \(j+1\) 的初始情况,列出初始的 \(1\times 3\) 答案矩阵,再右乘 \(T^m\) 即可得到这两个位置的答案。于是我们得到了 \(O(n^2\log{m})\) 的算法。但既然本质不同的 \(0/1\) 对只有 3 种,那么直接维护它们的数量即可,无需枚举每对相邻的位置。时间复杂度 \(O(n\log{m})\)。
具体实现时,可以将 3 个 \(1\times 3\) 的初始矩阵合并为一个 \(3\times 3\) 矩阵整体转移。
代码
#include <iostream>
using namespace std;
#define lowbit(x) ((x) & -(x))
#define chk_min(x, v) (x) = min((x), (v))
#define chk_max(x, v) (x) = max((x), (v))
typedef long long ll;
typedef pair<int, int> pii;
const int N = 2e5 + 5, M = 2e5 + 5, MOD = 998244353;
int n, m, ans, p[N], pos[N];
int c0, c1, t1, t2, t3;
bool vis[N];
struct Matrix {
int r, c;
ll a[5][5];
void clear(ll v = 0) {
for (int i = 1; i <= r; ++i)
for (int j = 1; j <= c; ++j) a[i][j] = v;
}
Matrix operator*(const Matrix &x) const {
Matrix res = { r, x.c }; res.clear();
for (int i = 1; i <= r; ++i)
for (int j = 1; j <= x.c; ++j)
for (int k = 1; k <= c; ++k)
res.a[i][j] = (res.a[i][j] + a[i][k] * x.a[k][j] % MOD) % MOD;
return res;
}
} f, g, t;
Matrix qpowm(Matrix a, ll b) {
--b;
Matrix res = a;
for (; b; b >>= 1) {
if (b & 1) res = res * a;
a = a * a;
}
return res;
}
void change(int x, int d) {
if (!vis[x] && !vis[x + 1]) f.a[1][1] += d;
else if (vis[x] && vis[x + 1]) f.a[2][2] += d;
else f.a[3][3] += d;
}
int main() {
ios::sync_with_stdio(false); cin.tie(nullptr);
cin >> n >> m;
for (int i = 1; i <= n; ++i) cin >> p[i], pos[p[i]] = i;
f = t = { 3, 3 };
f.clear(); f.a[2][2] = n - 1;
for (int i = 1; i <= n; ++i) vis[i] = true;
for (int i = 1; i < n; ++i) {
if (pos[i] > 1) change(pos[i] - 1, -1);
if (pos[i] < n) change(pos[i], -1);
vis[pos[i]] = 0;
if (pos[i] > 1) change(pos[i] - 1, 1);
if (pos[i] < n) change(pos[i], 1);
c0 = i; c1 = n - i;
ll v = 1ll * (n - 2) * (n - 3) / 2;
t.a[1][1] = v + ((c0 - 2) << 1) + 1; t.a[1][2] = 0; t.a[1][3] = c1 << 1;
t.a[2][1] = 0; t.a[2][2] = v + ((c1 - 2) << 1) + 1; t.a[2][3] = c0 << 1;
t.a[3][1] = c0 - 1; t.a[3][2] = c1 - 1; t.a[3][3] = v + c0 - 1 + c1 - 1 + 1;
for (int i = 1; i <= 3; ++i)
for (int j = 1; j <= 3; ++j) t.a[i][j] %= MOD;
g = f * qpowm(t, m);
ans = (ans + (g.a[1][3] + g.a[2][3] + g.a[3][3]) % MOD) % MOD;
}
cout << ans;
return 0;
}

浙公网安备 33010602011771号