高斯消元

高斯消元

高斯消元通常用于求解如下形式的 \(n\) 元线性方程组:

\[\begin{cases} a_{1, 1} x_1 + a_{1, 2} x_2 + \cdots + a_{1, n} x_n = b_1 \\ a_{2, 1} x_1 + a_{2, 2} x_2 + \cdots + a_{2, n} x_n = b_2 \\ \vdots \\ a_{m, 1} x_1 + a_{m, 2} x_2 + \cdots + a_{m, n} x_n = b_n \end{cases} \]

考虑将其表示为矩阵:

\[\left[ \begin{array}{cccc|c} a_{1, 1} & a_{1, 2} & \cdots & a_{1, n} & b_1 \\ a_{2, 1} & a_{2, 2} & \cdots & a_{2, n} & b_2 \\ \vdots & \vdots & \ddots & \vdots & \vdots \\ a_{m, 1} & a_{m, 2} & \cdots & a_{m, n} & b_m \\ \end{array} \right] \]

称之为增广矩阵,而没有右边系数的矩阵称为系数矩阵。

考虑把一个线性方程组的 \(n\) 个元的解写在一个列向量中:

\[X = \begin{bmatrix} x_1 \\ x_2 \\ \vdots \\ x_n \end{bmatrix} \]

于是可以得到向量方程 \(AX = B\)

解线性方程组

P3389 【模板】高斯消元法

考虑在增广矩阵上操作,将其化为上三角矩阵,就可以从最后一行逆推上去:

\[\left[ \begin{array}{ccccc|c} a_{1, 1} & a_{1, 2} & \cdots & a_{1, n - 1} & a_{1, n} & b_1 \\ 0 & a_{2, 2} & \cdots & a_{2, n - 1} & a_{2, n} & b_2 \\ 0 & 0 & \cdots & a_{3, n - 1} & a_{3, n} & b_3 \\ \vdots & \vdots & \ddots & \vdots & \vdots & \vdots \\ 0 & 0 & \cdots & 0 & a_{m, n} & b_m \\ \end{array} \right] \]

考虑枚举 \(1 \sim n\) 作为主元,然后用加减消元法将下方的行的该元消去即可。

每次选取系数最大的为主元可以有效降低精度误差。

#include <bits/stdc++.h>
using namespace std;
const double eps = 1e-9;
const int N = 1e2 + 7;

double g[N][N], ans[N];

int n;

inline bool Gauss() {
    for (int i = 1; i <= n; ++i) {
        int mxp = i;

        for (int j = i + 1; j <= n; ++j)
            if (fabs(g[j][i]) > fabs(g[mxp][i]))
                mxp = j;

        if (fabs(g[mxp][i]) < eps)
            return false;

        if (mxp != i)
            swap(g[mxp], g[i]);

        for (int j = i + 1; j <= n; ++j) {
            double div = g[j][i] / g[i][i];

            for (int k = i; k <= n + 1; ++k)
                g[j][k] -= div * g[i][k];
        }
    }

    for (int i = n; i; --i) {
        ans[i] = g[i][n + 1];

        for (int j = i + 1; j <= n; ++j)
            ans[i] -= g[i][j] * ans[j];

        ans[i] /= g[i][i];
    }

    return true;
}

signed main() {
    scanf("%d", &n);

    for (int i = 1; i <= n; ++i)
        for (int j = 1; j <= n + 1; ++j)
            scanf("%lf", g[i] + j);

    if (!Gauss())
        return puts("No Solution"), 0;

    for (int i = 1; i <= n; ++i)
        printf("%.2lf\n", ans[i]);

    return 0;
}

约旦消元法可以避免回带来求出答案,目标是把矩阵变为:

\[\left[ \begin{array}{cccc|c} a_{1, 1} & 0 & \cdots & 0 & b_1 \\ 0 & a_{2, 2} & \cdots & 0 & b_2 \\ \vdots & \vdots & \ddots & \vdots & \vdots & \vdots \\ 0 & 0 & \cdots & a_{m, n} & b_m \\ \end{array} \right] \]

只要在高斯消元时同时消掉上面的行即可。

注意一下特殊情况:

  • 无解:当左边系数为 \(0\) 而右边系数不为 \(0\) 时即为无解。

  • 多解:当左边系数为 \(0\) 而右边系数也为 \(0\) 时即为多解。

inline bool Gauss() {
	for (int i = 1; i <= n; ++i) {
		int mxp = i;
		
		for (int j = i + 1; j <= n; ++j)
			if (fabs(g[j][i]) > fabs(g[mxp][i]))
				mxpos = j;
		
		if (fabs(g[mxp][i]) < eps)
			return false;
       	
        if (i != mxp)
            swap(g[i], g[mxp]);
		
		for (int j = 1; j <= n; ++j)
			if (j != i) {
				double div = g[j][i] / g[i][i];
				
				for (int k = i + 1; k <= n + 1; ++k)
					g[j][k] -= g[i][k] * div;
			}
	}
	
	for (int i = 1; i <= n; ++i)
		ans[i] = g[i][n + 1] / g[i][i];
	
	return true;
}

P3232 [HNOI2013] 游走

给定一张无向图,初始在 \(1\) ,每次等概率走一条出边,获得该边边权的分数,走到 \(n\) 时停止。

需要给每条边确定 \(1 \sim m\) 的边权,其中每条边边权互不相同,最小化期望得分。

\(n \leq 500\)

考虑统计点的期望经过次数:

\[f_u = \sum_{(u, v) \in E, v \neq n} \frac{f_v}{d_v} + [u = 1] \]

注意 \(f_n = 1\) ,不难高消求解,则 \((u, v)\) 边的期望经过次数即为 \(\frac{f_u}{d_u} + \frac{f_v}{d_v}\) ,注意这里 \(u\)\(v\) 恰为 \(n\) 时不能贡献。排序后贪心即可,时间复杂度 \(O(n^3)\)

#include <bits/stdc++.h>
using namespace std;
const double eps = 1e-9;
const int N = 5e2 + 7;

struct Graph {
    vector<int> e[N];
    
    inline void insert(int u, int v) {
        e[u].emplace_back(v);
    }
} G;

struct Edge {
    double w;
    int u, v;
} e[N * N];

double a[N][N], f[N];
int deg[N];

int n, m;

inline void Gauss() {
    for (int i = 1; i <= n; ++i) {
        if (fabs(a[i][i]) < eps) {
            for (int j = i + 1; j <= n; ++j)
                if (fabs(a[j][i]) >= eps) {
                    swap(a[j], a[i]);
                    break;
                }
        }

        for (int j = 1; j <= n; ++j) {
            if (j == i)
                continue;

            double div = a[j][i] / a[i][i];

            for (int k = i + 1; k <= n + 1; ++k)
                a[j][k] -= a[i][k] * div;
        }
    }

    for (int i = 1; i <= n; ++i)
        f[i] = a[i][n + 1] / a[i][i];
}

signed main() {
    scanf("%d%d", &n, &m);

    for (int i = 1; i <= m; ++i) {
        scanf("%d%d", &e[i].u, &e[i].v);
        G.insert(e[i].u, e[i].v), G.insert(e[i].v, e[i].u);
        ++deg[e[i].u], ++deg[e[i].v];
    }

    for (int u = 1; u < n; ++u) {
        a[u][u] = 1, a[u][n + 1] = (u == 1);

        for (int v : G.e[u])
            if (v != n)
                a[u][v] -= (double)1 / deg[v];
    }

    a[n][n] = a[n][n + 1] = 1, Gauss();

    for (int i = 1; i <= m; ++i)
        e[i].w = (e[i].u == n ? 0 : f[e[i].u] / deg[e[i].u]) + (e[i].v == n ? 0 : f[e[i].v] / deg[e[i].v]);

    sort(e + 1, e + m + 1, [](const Edge &a, const Edge &b) {
        return a.w > b.w;
    });

    double ans = 0;

    for (int i = 1; i <= m; ++i)
        ans += e[i].w * i;

    printf("%.3lf", ans);
    return 0;
}

解 01 异或方程组

01 异或方程组是指形如

\[\begin{cases} a_{1, 1} x_1 \oplus a_{2, 2} x_2 \oplus \cdots \oplus a_{1, n} x_n = b_1 \\ a_{2, 1} x_1 \oplus a_{2, 2} x_2 \oplus \cdots \oplus a_{2, n} x_n = b_2 \\ \cdots \\ a_{n, 1} x_1 \oplus a_{n, 2} x_2 \oplus \cdots \oplus a_{n, n} x_n = b_n \end{cases} \]

的方程组,其中 \(a_{i, j}, b_i \in \{ 0, 1 \}\)

由于 \(\oplus\) 符合交换律与结合律,故可以按照高斯消元法逐步消元求解。在消元的时候应使用异或消元而非加减消元,且无需配系数(因为系数均为 \(0\)\(1\) )。

可以用 bitset 优化到 \(O(\frac{n^3}{\omega})\)

inline bool Gauss() {
	for (int i = 1; i <= n; ++i) {
		int mxp = i;

		while (mxpos <= n && !g[mxpos].test(i))
			++mxp;

		if (mxpos > n)
			return false;
		
        if (i != mxp)
			swap(g[i], g[mxp]);

		for (int j = 1; j <= n; ++j)
			if (j != i && g[j].test(i))
				g[j] ^= g[i];
	}

	return true;
}

P3429 [POI 2005] DWA-Two Parties

给出一张无向图,需要给每个点染上黑色或白色。定义一个点合法当且仅当其所有同色邻居数量为偶数,最大化合法点的数量,并构造方案。

\(n \leq 200\)\(m \leq \frac{n(n - 1)}{2}\)

可以证明答案始终为 \(n\) ,证明就是求解如下异或方程组:

\[\bigoplus_{(u, v) \in G} (x_u \oplus x_v) = \deg_u \bmod 2 \]

这个方程始终有解,否则若出现能将其消为左边没有元而右边为 \(1\) 的情况,其说明存在奇数个奇度数的点。而每个奇度数的点要选奇数个邻居在集合中,偶度数点要选偶数个邻居在集合中,从而其导出子图的度数为奇数,矛盾。

#include <bits/stdc++.h>
using namespace std;
const int N = 2e2 + 7;

bitset<N> g[N];

int n;

inline void Gauss() {
    for (int i = 1; i <= n; ++i) {
        if (!g[i].test(i)) {
            for (int j = i + 1; j <= n; ++j)
                if (g[j].test(i)) {
                    swap(g[i], g[j]);
                    break;
                }
        }

        for (int j = 1; j <= n; ++j)
            if (j != i && g[j].test(i))
                g[j] ^= g[i];
    }
}

signed main() {
    scanf("%d", &n);

    for (int i = 1; i <= n; ++i) {
        int d;
        scanf("%d", &d);

        if (d & 1)
            g[i].set(i), g[i].set(n + 1);

        while (d--) {
            int x;
            scanf("%d", &x);
            g[i].set(x);
        }
    }

    Gauss();
    int cnt = 0;

    for (int i = 1; i <= n; ++i)
        if (g[i].test(n + 1))
            ++cnt;

    printf("%d\n", cnt);

    for (int i = 1; i <= n; ++i)
        if (g[i].test(n + 1))
            printf("%d ", i);
    
    return 0;
}

QOJ10520. 矩阵除法

给定 \(n \times m\) 的 01 矩阵 \(A\)\(n \times p\) 的 01 矩阵 \(C\) ,求一个 \(m \times p\) 的 01 矩阵 \(B\) 满足 \(A \times B = C\) ,其中 \(C_{i, j} = \oplus_{k = 1}^m A_{i, k} \and B_{k, j}\)

\(n \leq 1000\)

固定第 \(i\) 列,问题转化为求解 \(n\)\(m\) 元的方程组,其中第 \(k\) 个方程为 \(\oplus_{j = 1}^m a_{k, j} \and b_{j, i} = c_{k, i}\) 。直接高斯消元,时间复杂度 \(O(\frac{p n^2 m}{\omega})\) ,无法接受。

注意到系数矩阵是固定的,只有常数项在变,也就是说消元的过程是一样的,考虑同时求解所有方程,将 \(n\) 个常数项都放进 bitset 即可。

时间复杂度 \(O(\frac{n^2 (m + p)}{\omega})\)

#include <bits/stdc++.h>
using namespace std;
const int N = 1e3 + 7;

bitset<N> a[N], b[N], c[N];

int n, m, p;

signed main() {
    scanf("%d%d%d", &n, &m, &p);

    for (int i = 1; i <= n; ++i)
        for (int j = 1; j <= m; ++j) {
            int x;
            scanf("%d", &x);
            a[i].set(j, x);
        }

    for (int i = 1; i <= n; ++i)
        for (int j = 1; j <= p; ++j) {
            int x;
            scanf("%d", &x);
            c[i].set(j, x);
        }

    for (int i = 1, t = 1; i <= m && t <= n; ++i) {
        if (!a[t][i]) {
            for (int j = t + 1; j <= n; ++j)
                if (a[j][i]) {
                    swap(a[t], a[j]), swap(c[t], c[j]);
                    break;
                }

            if (!a[t][i])
                continue;
        }

        for (int j = 1; j <= n; ++j)
            if (j != t && a[j][i])
                a[j] ^= a[t], c[j] ^= c[t];

        ++t;
    }

    for (int i = 1; i <= n; ++i) {
        int pos = 0;

        for (int j = 1; j <= m; ++j)
            if (a[i][j]) {
                pos = j;
                break;
            }

        if (!pos && c[i].any())
            return puts("No"), 0;
        else if (pos)
            b[pos] = c[i];
    }

    puts("Yes");

    for (int i = 1; i <= m; ++i) {
        for (int j = 1; j <= p; ++j)
            printf("%d ", (int)b[i][j]);

        puts("");
    }

    return 0;
}

求解行列式

P7112 【模板】行列式求值

考虑极端情况,当一个矩阵任意一个位置出现 \(0\) ,其直接没有贡献了。

考虑利用初等变换消元,将矩阵消成一个上三角矩阵,此时矩阵的行列式即为对角线元素的乘积。

注意取模意义下某些数可能不存在逆元,此时需要辗转相减。

inline int Gauss(int n) {
	int res = 1;

	for (int i = 1; i <= n; ++i)
		for (int j = i + 1; j <= n; ++j) {
			while (g[i][i]) {
				int div = g[j][i] / g[i][i];

				for (int k = i; k <= n; ++k)
					g[j][k] = dec(g[j][k], 1ll * g[i][k] * div % Mod);

				swap(g[i], g[j]), res = Mod - res;
			}
			
			swap(g[i], g[j]), res = Mod - res;
		}
	
	for (int i = 1; i <= n; ++i)
		res = 1ll * res * g[i][i] % Mod;

	return res;
}

band-matrix

band-matrix 形如下图:

空白部分都为 \(0\) ,橙色部分有值,这样中间就形成了一个宽度为 \(d\) 的带。

可以发现任意一个 \(i\) 满足从 \((i, i)\) 向右或向下拓展都有不超过 \(d - 1\) 个非零数字,即很多位置根本不需要消。

具体地,假设现在要消第 \(i\) 列,那么从第 \(i\) 行开始往下枚举 \(d - 1\) 行,每行往右消 \(d\) 个数字即可,最后仍能得到一个上三角矩阵。

与普通高斯消元有点不一样的地方在于当主元为 \(0\) 的时候的处理方法。在 band-matrix 中,若直接交换行会破坏 band-matrix 。注意到每次交换完后交换的行右边最多多出 \(d\) 个数,于是每次往右消元 \(2d\) 个数即可。

时间复杂度 \(O(nd^2)\)

inline bool Gauss(int n, int d) {
    for (int i = 1; i <= n; ++i) {
        if (fabs(g[i][i]) < eps) {
            for (int j = i + 1; j <= min(i + d, n); ++j)
                if (fabs(g[j][i]) >= eps) {
                    swap(g[i], g[j]);
                    break;
                }

            if (fabs(g[i][i]) < eps)
                return false;
        }
            
        for (int j = i + 1; j <= min(i + d, n); ++j) {
            double div = g[j][i] / g[i][i];
            
            for (int k = i; k <= min(i + 2 * d, n); ++k)
                g[j][k] -= div * g[i][k];
            
            g[j][n + 1] -= div * g[i][n + 1];
        }
    }
    
    for (int i = n; i; --i) {
        ans[i] = g[i][n + 1];
        
        for (int j = i + 1; j <= min(i + 2 * d, n); ++j)
            ans[i] -= g[i][j] * ans[j];
        
        ans[i] /= g[i][i];
    }
    
    return true;
}

还有一种解决主元为 \(0\) 的方法,普通高消是交换行,这里只要交换列就可以保持 band-matrix 的性质了,时间复杂度也是 \(O(nd^2)\)

注意这里要记录一下当前主元是哪一列的,这样回带才是对的。

inline bool Gauss(int n, int d) {
    iota(id + 1, id + 1 + n, 1);
    
    for (int i = 1; i <= n; ++i) {
        if (g[i][i] < eps) {
            for (int j = i + 1; j <= min(n, i + d); ++j)
                if (g[i][j] >= eps) {
                    for (int k = 1; k <= n; ++k)
                        swap(g[k][i], g[k][j]);
                    
                    swap(id[i], id[j]);
                    break;
                }

            if (g[i][i] < eps)
                return false;
        }
        
        for (int j = i + 1; j <= n; ++j) {
            double div = g[j][i] / g[i][i];
            
            for (int k = i; k <= min(n, i + d); ++k)
                g[j][k] -= g[i][k] * div;
            
            g[j][n + 1] -= g[i][n + 1] * div;
        }
    }
    
    for (int i = n; i; --i) {
        ans[id[i]] = g[i][n + 1];
        
        for (int j = i + 1; j <= min(n, i + d); ++j)
            ans[id[i]] -= ans[id[j]] * g[i][j];
        
        ans[id[i]] /= g[i][i];
    }

    return true;
}

CF24D Broken robot

\(n\)\(m\) 列的矩阵,现在在 \((x,y)\),每次等概率向左、右、下走或原地不动,但不能走出去,求走到最后一行期望的步数。

\(n, m \leq 10^3\)

\(f_{i, j}\) 表示机器人在 \((i, j)\) 时走到最后一行的期望步数,\(m = 1\) 时有(省略第二维):

\[f_i = 1 + \dfrac{1}{2} (f_{i + 1} + f_i) \]

即:

\[f_i = f_{i + 1} + 2 \]

\(m > 1\) 时有:

\[f_{i, j} = 1 + \begin{cases} \dfrac{1}{3} (f_{i, j} + f_{i + 1, j} + f_{i, j + 1}), & j = 1 \\ \dfrac{1}{4} (f_{i, j} + f_{i + 1, j} + f_{i, j - 1} + f_{i, j + 1}), & 1 < j < m \\ \dfrac{1}{3} (f_{i, j} + f_{i + 1, j} + f_{i, j - 1}), & j = m \end{cases} \]

注意到这是一个 \(d = 2\) 的 band-matrix ,直接使用 band-matrix 消元即可,时间复杂度 \(O(nmd^2)\)

#include <bits/stdc++.h>
using namespace std;
const double eps = 1e-12;
const int N = 1e3 + 7;

double g[N][N], ans[N][N];

int n, m, x, y;

inline bool Gauss(int n, int d, double *ans) {
    for (int i = 1; i <= n; ++i) {
        if (fabs(g[i][i]) < eps) {
            for (int j = i + 1; j <= min(i + d, n); ++j)
                if (fabs(g[j][i]) >= eps) {
                    swap(g[i], g[j]);
                    break;
                }

            if (fabs(g[i][i]) < eps)
                return false;
        }
            
        for (int j = i + 1; j <= min(i + d, n); ++j) {
            double div = g[j][i] / g[i][i];
            
            for (int k = i; k <= min(i + 2 * d, n); ++k)
                g[j][k] -= div * g[i][k];
            
            g[j][n + 1] -= div * g[i][n + 1];
        }
    }
    
    for (int i = n; i; --i) {
        ans[i] = g[i][n + 1];
        
        for (int j = i + 1; j <= min(i + 2 * d, n); ++j)
            ans[i] -= g[i][j] * ans[j];
        
        ans[i] /= g[i][i];
    }
    
    return true;
}

signed main() {
    scanf("%d%d%d%d", &n, &m, &x, &y);
    
    if (m == 1)
        return printf("%.5lf", 2.0 * (n - x)), 0;
    
    for (int i = n - 1; i; --i) {
        g[1][1] = 2, g[1][2] = -1, g[1][m + 1] = 3 + ans[i + 1][1];
        
        for (int j = 2; j < m; ++j)
            g[j][j] = 3, g[j][j - 1] = g[j][j + 1] = -1, g[j][m + 1] = 4 + ans[i + 1][j];
        
        g[m][m] = 2, g[m][m - 1] = -1, g[m][m + 1] = 3 + ans[i + 1][m];
        Gauss(m, 1, ans[i]);
    }
    
    printf("%.5lf", ans[x][y]);
    return 0;
}

CF963E Circles of Waiting

二维平面上有一个点,一开始在 \((0, 0)\) 。每秒钟其都会随机移动,记当前坐标为 \((x, y)\) ,下一秒:

  • \((x - 1, y)\) 的概率是 \(p_1\)
  • \((x, y - 1)\) 的概率是 \(p_2\)
  • \((x + 1, y)\) 的概率是 \(p_3\)
  • \((x, y + 1)\) 的概率是 \(p_4\)

保证 \(p_1 + p_2 + p_3 + p_4 = 1\) ,求该点距离原点 \(> R\) 的期望步数

\(0 \leq R \leq 50\)

把所有满足 \(i^2 + j^2 \leq R^2\) 的点依次编号,显然有 \(O(R^2)\) 个点。

\(f_{i, j}\) 表示 \((i, j)\) 走出圆的期望步数,\((i, j)\) 能转移到 \((i, j - 1), (i - 1, j), (i, j + 1), (i + 1, j)\) 。因为是依次编号,所以建出来的矩阵带宽 \(\leq 2R + 1\) 。用 band-matrix 即可做到 \(O(R^4)\)

#include <bits/stdc++.h>
using namespace std;
const int Mod = 1e9 + 7;
const int R = 5e1 + 7, N = 8e3 + 7;

vector<pair<int, int> > Pos;

int id[R << 1][R << 1], g[N][N], ans[N];

int r, p1, p2, p3, p4;

inline int add(int x, int y) {
    x += y;
    
    if (x >= Mod)
        x -= Mod;
    
    return x;
}

inline int dec(int x, int y) {
    x -= y;
    
    if (x < 0)
        x += Mod;
    
    return x;
}

inline int mi(int a, int b) {
    int res = 1;
    
    for (; b; b >>= 1, a = 1ll * a * a % Mod)
        if (b & 1)
            res = 1ll * res * a % Mod;
    
    return res;
}

inline int getid(int x, int y) {
    return id[x + R][y + R];
}

inline bool Gauss(int n, int d) {
    for (int i = 1; i <= n; ++i) {
        if (!g[i][i]) {
            for (int j = i + 1; j <= min(i + d, n); ++j)
                if (g[j][i]) {
                    swap(g[i], g[j]);
                    break;
                }

            if (!g[i][i])
                return false;
        }
        
        int inv = mi(g[i][i], Mod - 2);
            
        for (int j = i + 1; j <= min(i + d, n); ++j) {
            int div = 1ll * g[j][i] * inv % Mod;
            
            for (int k = i; k <= min(i + 2 * d, n); ++k)
                g[j][k] = dec(g[j][k], 1ll * div * g[i][k] % Mod);
            
            g[j][n + 1] = dec(g[j][n + 1], 1ll * div * g[i][n + 1] % Mod);
        }
    }
    
    for (int i = n; i; --i) {
        ans[i] = g[i][n + 1];
        
        for (int j = i + 1; j <= min(i + 2 * d, n); ++j)
            ans[i] = dec(ans[i], 1ll * g[i][j] * ans[j] % Mod);
        
        ans[i] = 1ll * ans[i] * mi(g[i][i], Mod - 2) % Mod;
    }

    return true;
}

signed main() {
    scanf("%d%d%d%d%d", &r, &p1, &p2, &p3, &p4);
    int all = mi(add(add(p1, p2), add(p3, p4)), Mod - 2);
    p1 = 1ll * p1 * all % Mod, p2 = 1ll * p2 * all % Mod;
    p3 = 1ll * p3 * all % Mod, p4 = 1ll * p4 * all % Mod;
    
    for (int i = -r; i <= r; ++i)
        for (int j = -r; j <= r; ++j)
            if (i * i + j * j <= r * r)
                Pos.emplace_back(i, j), id[i + R][j + R] = Pos.size();
    
    int n = Pos.size(), band = 0;
    
    for (auto it : Pos) {
        int x = it.first, y = it.second, id = getid(x, y);
        g[id][id] = g[id][n + 1] = 1;
        
        if (getid(x - 1, y)) {
            g[id][getid(x - 1, y)] = Mod - p1;
            band = max(band, abs(id - getid(x - 1, y)));
        }
        
        if (getid(x, y - 1)) {
            g[id][getid(x, y - 1)] = Mod - p2;
            band = max(band, abs(id - getid(x, y - 1)));
        }
        
        if (getid(x + 1, y)) {
            g[id][getid(x + 1, y)] = Mod - p3;
            band = max(band, abs(id - getid(x + 1, y)));
        }
        
        if (getid(x, y + 1)) {
            g[id][getid(x, y + 1)] = Mod - p4;
            band = max(band, abs(id - getid(x, y + 1)));
        }
    }
    
    Gauss(n, band);
    printf("%d", ans[getid(0, 0)]);
    return 0;
}

P4457 [BJOI2018] 治疗之雨

你有 \(p\) 滴血量,血量上限为 \(n\) 。每轮操作如下:

  • 先以 \(\frac{1}{m + 1}\) 的概率增加 \(1\) 滴血,满血时则概率为 \(0\)
  • \(k\) 次判定,每次以 \(\frac{1}{m + 1}\) 的概率减少一滴血,死了则概率为 \(0\)

求期望几轮死亡。

\(n \leq 1500\)\(m, k \leq 10^9\)

\(f_i\) 表示血量为 \(i\) 时期望多少轮结束,则:

\[f_i = 1 + \sum_{j = 0}^{\min(i, k)} (\frac{m}{m + 1} f_{i - j} + \frac{1}{m + 1} f_{i - j + 1}) \times \binom{k}{j} (\frac{1}{m + 1})^j (\frac{m}{m + 1})^{k - j} \]

注意一下 \(i = n\) 时的情况即可。

观察到 \(f_i\) 只与 \(f_{0 \sim i + 1}\) 有关,所以矩阵应该是一个类似下三角矩阵的东西。

因为高斯消元的时候是拿自己这行去减下面的,所以每一行中只有 \(2\) 个系数要去和下面的相减。

时间复杂度 \(O(n^2)\)

#include <bits/stdc++.h>
using namespace std;
const int Mod = 1e9 + 7;
const int N = 1.5e3 + 7;

int g[N][N], inv[N], d[N], id[N], ans[N];

int n, p, m, k;

inline int add(int x, int y) {
    x += y;
    
    if (x >= Mod)
        x -= Mod;
    
    return x;
}

inline int dec(int x, int y) {
    x -= y;
    
    if (x < 0)
        x += Mod;
    
    return x;
}

inline int mi(int a, int b) {
    int res = 1;
    
    for (; b; b >>= 1, a = 1ll * a * a % Mod)
        if (b & 1)
            res = 1ll * res * a % Mod;
    
    return res;
}

inline bool Gauss(int n, int d) {
    iota(id + 1, id + 1 + n, 1);
    
    for (int i = 1; i <= n; ++i) {
        if (!g[i][i]) {
            for (int j = i + 1; j <= min(n, i + d); ++j)
                if (g[i][j]) {
                    for (int k = 1; k <= n; ++k)
                        swap(g[k][i], g[k][j]);
                    
                    swap(id[i], id[j]);
                    break;
                }

            if (!g[i][i])
                return false;
        }
        
        int Inv = mi(g[i][i], Mod - 2);
        
        for (int j = i + 1; j <= n; ++j) {
            int div = 1ll * g[j][i] * Inv % Mod;
            
            for (int k = i; k <= min(n, i + d); ++k)
                g[j][k] = dec(g[j][k], 1ll * g[i][k] * div % Mod);
            
            g[j][n + 1] = dec(g[j][n + 1], 1ll * g[i][n + 1] * div % Mod);
        }
    }
    
    for (int i = n; i; --i) {
        ans[id[i]] = g[i][n + 1];
        
        for (int j = i + 1; j <= min(n, i + d); ++j)
            ans[id[i]] = dec(ans[id[i]], 1ll * ans[id[j]] * g[i][j] % Mod);
        
        ans[id[i]] = 1ll * ans[id[i]] * mi(g[i][i], Mod - 2) % Mod;
    }

    return true;
}

signed main() {
    inv[0] = inv[1] = 1;
    
    for (int i = 2; i < N; ++i) 
        inv[i] = 1ll * (Mod - Mod / i) * inv[Mod % i] % Mod;
    
    int T;
    scanf("%d", &T);
    
    while (T--) {
        scanf("%d%d%d%d", &n, &p, &m, &k);
        
        if (!k) {
            puts("-1");
            continue;
        } else if (!m) {
            if (k == 1)
                puts("-1");
            else
                printf("%d\n", (p + k - 2 - (p == n)) / (k - 1));
            
            continue;
        }
        
        d[0] = mi(1ll * m * mi(m + 1, Mod - 2) % Mod, k);
        
        for (int i = 1, invm = mi(m, Mod - 2); i <= min(n, k); ++i)
            d[i] = 1ll * d[i - 1] * invm % Mod * (k - i + 1) % Mod * inv[i] % Mod;
        
        memset(g, 0, sizeof(g));
        
        for (int i = 1, div = mi(m + 1, Mod - 2); i < n; ++i) {
            g[i][i] = g[i][n + 1] = 1;
            
            for (int j = 0; j <= min(i, k); ++j) {
                g[i][i - j] = dec(g[i][i - j], 1ll * dec(1, div) * d[j] % Mod);
                g[i][i - j + 1] = dec(g[i][i - j + 1], 1ll * div * d[j] % Mod);
            }
        }

        g[n][n + 1] = g[n][n] = 1;

        for (int i = 0; i <= min(n, k); ++i)
            g[n][n - i] = dec(g[n][n - i], d[i]);
        
        Gauss(n, 1);
        printf("%d\n", ans[p]);
    }
    
    return 0;
}

P6899 [ICPC2014 WF] Pachinko

有一个宽度为 \(w\) 高度为 \(h\) 的方格纸, $ w \times h$ 的格子中,有一些是空的,有一些是洞,有一些是障碍物。从第一行的空的格子中随机选一个放置一个球,向上下左右移动的概率比为 \(p_u : p_d : p_l : p_r\) ,不能移动到有障碍物的格子上。对于每个洞,输出落入该洞的概率。

\(2 \leq 20, h \leq 10^4, p_u + p_d + p_l + p_r = 100\)

和上面比较类似。

#include <bits/stdc++.h>
using namespace std;
const double eps = 1e-12;
const int dx[] = {-1, 1, 0, 0};
const int dy[] = {0, 0, -1, 1};
const int N = 1e4 + 7, M = 2e1 + 7;

double g[N * M][M << 1], ans[N * M];
int id[N][M];
char str[N][M];

int n, m, p[4], tot;

inline bool check(int x, int y) {
	return x && x <= n && y && y <= m;
}

inline void Gauss(int n, int d) {
	for (int i = 1; i <= n; ++i) {
		if (fabs(g[i][i]) < eps)
			continue;
		
		for (int j = i + 1; j <= min(i + d, n); ++j)
			g[i][j] /= g[i][i];

		ans[i] /= g[i][i], g[i][i] = 1;

		for (int j = i + 1; j <= min(i + d, n); ++j) {
			double div = g[j][i] / g[i][i];

			for (int k = i; k <= min(i + d, n); ++k)
				g[j][k] -= g[i][k] * div;

			ans[j] -= ans[i] * div;
		}
	}

	for (int i = n; i; --i)
		for (int j = i + 1; j <= min(i + d, n); ++j)
			ans[i] -= g[i][j] * ans[j];
}

signed main() {
	scanf("%d%d", &m, &n);

	for (int i = 0; i < 4; ++i)
		scanf("%d", p + i);

	for (int i = 1; i <= n; ++i) {
		scanf("%s", str[i] + 1);

		for (int j = 1; j <= m; ++j)
			if (str[i][j] != 'X')
				id[i][j] = ++tot;
	}

	int sum = m - count(id[1] + 1, id[1] + 1 + m, 0);

	for (int i = 1; i <= n; ++i)
		for (int j = 1; j <= m; ++j) {
			if (!id[i][j])
				continue;

			if (i == 1)
				ans[id[i][j]] = 1.0 / sum;

			if (str[i][j] == 'T')
				continue;

			g[id[i][j]][id[i][j]] = 1;
			int base = 100;

			for (int k = 0; k < 4; ++k) {
				int x = i + dx[k], y = j + dy[k];

				if (!check(x, y) || !id[x][y])
					base -= p[k];
			}

			for (int k = 0; k < 4; ++k) {
				int x = i + dx[k], y = j + dy[k];

				if (check(x, y) && id[x][y])
					g[id[x][y]][id[i][j]] = -1.0 * p[k] / base;
			}
		}

	Gauss(tot, m);

	for (int i = 1; i <= n; ++i)
		for (int j = 1; j <= m; ++j)
			if (str[i][j] == 'T')
				printf("%.9lf\n", ans[id[i][j]]);

	return 0;
}

树上高消

树上 DP 遇到需要高斯消元的情况时,有时叶子和其父亲的 DP 值呈一次函数关系。

那么类似数学归纳法地向上递推,可以发现所有点的值和根的值都是一次函数关系。

此时会发现所有值之和为定值,或者某些点(如根、叶子)的值容易求,那么就可以递推优化到 \(O(n)\)

P5643 [PKUWC2018] 随机游走

给定一棵有根树,\(q\) 次询问从根走到 \(S\) 中的所有点至少一次的期望步数。

\(n \leq 18\)\(q \leq 5000\)

考虑 Min-Max 容斥,求经过 \(S\) 里所有元素的期望时间,即到达 \(S\) 中最后一个点的期望步数( \(\max\) ),那么可以转化为枚举 \(S\) 的子集 \(T\) ,求到达 \(T\) 中第一个元素的期望时间( \(\min\) )。

\(f_{u, S}\) 表示 \(u\) 第一次走到 \(S\) 中的点的期望步数,\(d_u\) 表示 \(u\) 的度数,则:

\[f_{u, S} = \frac{f_{fa_u, S} + \sum_{v \in son(u)} f_{v, S}}{d_u} + 1 \quad (x \not \in S) \\ f_{u, S} = 0 \quad (x \in S) \]

考虑将每个点的值都写作 \(f_{u, S} = k_u \times f_{fa_u, S} + b_u\) 的形式,记:

\[K_u = \sum_{v \in son(u)} k_v, B_u = \sum_{v \in son(u)} b_v \]

则得到:

\[f_{u, S} = \dfrac{1}{d_u - K_u} \times f_{fa_u, S} + \dfrac{d_u + B_u}{d_u - K_u} \]

即:

\[k_u = \dfrac{1}{d_u - K_u}, b_u = \dfrac{d_u + B_u}{d_u - K_u} \]

答案即为 \(\sum_{T \subseteq S} (-1)^{|T| + 1} f_{r, T}\) ,不难用高维前缀和预处理后 \(O(1)\) 查询。时间复杂度 \(O(n 2^n + q)\)

#include <bits/stdc++.h>
using namespace std;
const int Mod = 998244353;
const int N = 19;

struct Graph {
    vector<int> e[N];
    
    inline void insert(int u, int v) {
        e[u].emplace_back(v);
    }
} G;

int f[1 << N], g[1 << N], deg[N];

int n, q, r;

inline int add(int x, int y) {
    x += y;
    
    if (x >= Mod)
        x -= Mod;
    
    return x;
}

inline int dec(int x, int y) {
    x -= y;
    
    if (x < 0)
        x += Mod;
    
    return x;
}

inline int mi(int a, int b) {
    int res = 1;
    
    for (; b; b >>= 1, a = 1ll * a * a % Mod)
        if (b & 1)
            res = 1ll * res * a % Mod;
    
    return res;
}

inline int sgn(int n) {
    return n & 1 ? Mod - 1 : 1;
}

struct Node {
    int k, b;

    inline Node(int _b = 0) : k(0), b(_b) {}

    inline Node(int _k, int _b) : k(_k), b(_b) {}

    inline Node operator + (const Node &rhs) const {
        return Node(add(k, rhs.k), add(b, rhs.b));
    }
} nd[N];

void dfs(int u, int f, int state) {
    Node res = 0;

    for (int v : G.e[u])
        if (v != f)
            dfs(v, u, state), res = res + nd[v];

    if (~state >> u & 1)
        nd[u].k = mi(dec(deg[u], res.k), Mod - 2), nd[u].b = 1ll * add(deg[u], res.b) * nd[u].k % Mod;
    else
        nd[u] = 0;
}

signed main() {
    scanf("%d%d%d", &n, &q, &r), --r;

    for (int i = 1; i < n; ++i) {
        int u, v;
        scanf("%d%d", &u, &v);
        ++deg[--u], ++deg[--v];
        G.insert(u, v), G.insert(v, u);
    }

    for (int i = 1; i < (1 << n); ++i)
        dfs(r, -1, i), f[i] = 1ll * nd[r].b * sgn(__builtin_parity(i) ^ 1) % Mod;

    for (int i = 0; i < n; ++i)
        for (int j = 0; j < (1 << n); ++j)
            if (j >> i & 1)
                f[j] = add(f[j], f[j ^ (1 << i)]);

    while (q--) {
        int k, state = 0;
        scanf("%d", &k);

        while (k--) {
            int x;
            scanf("%d", &x);
            state |= 1 << (x - 1);
        }

        printf("%d\n", f[state]);
    }

    return 0;
}

P11736 [集训队互测 2015] 胡策的小树

有一棵树,点权 \(a_{1 \sim n}\) 构成了一个 \(0 \sim n - 1\) 的排列,并且满足 \(a_1 = 0\)

初始时,每个节点各有一只猴子,每一秒第 \(i\) 个点上的猴子会有 \(p(i)\) 的概率跳到父亲,\(\frac{1 - p(i)}{siz_i}\) 的概率跳到子树内的任意点,其中 \(p(i) = \begin{cases} 0 & i = 1 \\ \frac{a_i}{n} & 2 \leq i \leq n \end{cases}\)

初始时将所有点的权值变为 \((a_i + x) \bmod n\) ,其中 \(x\) 为自行选定任意非负整数。

记第 \(i\) 秒成功跳到父亲的猴子数量为 \(g_i\) ,幸福指数被定义为 \(g_{0 \sim +\infty}\) 的平均数。选取适当的 \(x\) ,求幸福指数期望的最大值。

\(n \leq 5 \times 10^5\) ,节点 \(i\) 的父亲从 \(1 \sim i - 1\) 中等概率选取

先考虑初始操作 \(x = 0\) 的情况。可以发现每只猴子对答案的贡献是独立的,进一步发现每只猴子的初始位置是不重要的。因为时间无穷,因此一只猴子一定会跳到根,并可以忽略跳到根之前的贡献。下面讨论从根开始跳的答案。

\(p_i = p(i), q_i = \frac{1 - p_i}{siz_i}\) ,设猴子处于 \(u\) 的概率为 \(f_u\) ,则:

\[f_u = \left( \sum_{w \in anc(u) \cup \{ u \}} q_w f_w \right) + \left( \sum_{v \in son(u)} p_v f_v \right) \\ \sum f_u = 1 \]

暴力高消可以做到 \(O(n^3)\) ,但是没有利用树上高消的性质。

枚举祖先并不好树上高消,记 \(F_u = \sum_{w \in anc(u)} q_w f_w\) ,则 \(f_u = \frac{1}{q_u} (F_u - F_{fa_u})\) ,带入得到:

\[\frac{1}{q_u} (F_u - F_{fa_u}) = F_u + \sum_{v \in son(u)} \frac{p_v}{q_v} (F_v - F_u) \\ (\frac{1}{q_u} - 1 + \sum_{v \in son(u)} \frac{p_v}{q_v}) F_u = \frac{1}{q_u} F_{fa_u} + \sum_{v \in son(u)} \frac{p_v}{q_v} F_v \]

记:

\[\begin{align} A_u &= \frac{1}{q_u} - 1 + \sum_{v \in son(u)} \frac{p_v}{q_v} \\ B_u &= \frac{1}{q_u} \\ C_u &= \frac{p_u}{q_u} \end{align} \]

则:

\[A_u F_u = B_u F_{fa_u} + \sum_{v \in son(u)} C_v F_v \]

考虑设 \(F_u = k_u F_{fa_u} + b_u\) ,则:

\[\begin{align} A_u (k_u F_{fa_u} + b_u) &= B_u F_{fa_u} + \sum_{v \in son(u)} C_v (k_v F_u + b_v) \\ (A_u - \sum_{v \in son(u)} C_v k_v) F_u &= B_u F_{fa_u} + \sum_{v \in son(u)} C_v b_v \end{align} \]

于是可以 \(O(n)\) 高消求解 \(x = 0\) 时幸福指数的期望 \(\sum p_i f_i\)

考虑 \(x \neq 0\) 的情况,此时一定存在一个点 \(i\) 满足 \(p(i) = 0\) ,也就是说跳到 \(i\) 子树内后就跳不出来了。于是对于每个点的子树做一遍树上高消即可。

时间复杂度 \(O(\sum siz_i)\) ,因为树的形态随机,因此时间复杂度 \(O(n \log n)\)

#include <bits/stdc++.h>
typedef long double ld;
using namespace std;
const int N = 5e5 + 7;

struct Graph {
    vector<int> e[N];
    
    inline void insert(int u, int v) {
        e[u].emplace_back(v);
    }
} G;

struct Node {
    ld k, b;

    inline Node() {}

    inline Node(ld _b) : b(_b) {}

    inline Node(ld _k, ld _b) : k(_k), b(_b) {}

    inline Node operator + (Node rhs) {
        return Node(k + rhs.k, b + rhs.b);
    }

    inline Node operator - (Node rhs) {
        return Node(k - rhs.k, b - rhs.b);
    }

    inline Node operator * (ld x) {
        return Node(k * x, b * x);
    }

    inline Node operator / (ld x) {
        return Node(k / x, b / x);
    }
} nd[N], F[N];

ld p[N], q[N], A[N], B[N], C[N], f[N];
int fa[N], a[N], siz[N], in[N], out[N], id[N];

int n, dfstime;

void dfs(int u) {
    siz[u] = 1, id[in[u] = ++dfstime] = u;

    for (int v : G.e[u])
        dfs(v), siz[u] += siz[v];

    out[u] = dfstime;
}

inline ld solve(int r) {
    for (int i = in[r]; i <= out[r]; ++i) {
        int u = id[i];
        p[u] = (ld)((a[u] + n - a[r]) % n) / n;
        q[u] = (1 - p[u]) / siz[u];
        A[u] = 1 / q[u] - 1, B[u] = 1 / q[u], C[u] = p[u] / q[u];

        if (u != r)
            A[fa[u]] += p[u] / q[u];
    }

    for (int i = out[r]; i > in[r]; --i) {
        int u = id[i];
        nd[u] = Node(B[u], 0);
        ld div = A[u];

        for (int v : G.e[u])
            nd[u].b += C[v] * nd[v].b, div -= C[v] * nd[v].k;

        nd[u] = nd[u] / div;
    }

    F[r] = Node(1, 0);
    Node all = F[r] / q[r];

    for (int i = in[r] + 1; i <= out[r]; ++i) {
        int u = id[i];
        F[u] = F[fa[u]] * nd[u].k + nd[u].b;
        all = all + (F[u] - F[fa[u]]) / q[u];
    }

    f[r] = (1 - all.b) / all.k / q[r];

    for (int i = in[r] + 1; i <= out[r]; ++i) {
        int u = id[i];
        Node now = (F[u] - F[fa[u]]) / q[u];
        f[u] = now.k * f[r] * q[r] + now.b;
    }

    ld res = 0;

    for (int i = in[r]; i <= out[r]; ++i) {
        int u = id[i];
        res += f[u] * p[u];
    }

    return res;
}

signed main() {
    scanf("%d", &n);

    for (int i = 1; i <= n; ++i)
        scanf("%d", fa + i), G.insert(fa[i], i);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i);

    dfs(1);
    ld ans = 0;

    for (int i = 1; i <= n; ++i)
        ans = max(ans, solve(i));

    printf("%.9LF", ans);
    return 0;
}
posted @ 2024-07-22 19:48  wshcl  阅读(48)  评论(0)    收藏  举报