QOJ 59. Determinant of A+Bz 题解

QOJ # 59. Determinant of A+Bz 题解

Determinant of A+Bz - Problem - QOJ.ac

起因是有人在联考里考了搬了一个需要用到这个的题,本来原题是 \(\mathcal{O}(n^4)\) 就可以过的,但是丧心病狂的出题人把它加强到了 \(\mathcal{O}(n^3)\)

题意

给定一个 \(n\times n\) 的矩阵,每个位置是一个 \(Ax+B\) 的多项式。你需要求出这个矩阵的行列式,输出这个 \(n\) 次多项式。对 \(9998244353\) 取模。

\(n\le 500,0\le A_{i,j},B_{i,j} < 998244353\)

海森堡算法

海森堡可以在 \(\mathcal{O}(n^3)\) 的时间复杂度内求解一个矩阵的特征多项式,其中矩阵 \(A\) 的特征多项式为 \(det(I_nx-A)\)

首先一个暴力做法是用拉格朗日插值求解,即把 \(x\) 带入 \(0\sim n\) 然后求解行列式,复杂度为 \(\mathcal{O}(n^4)\)。我们发现复杂度的瓶颈为每次求解行列式,所以可以思考将矩阵变为一个比较好看的形式,使得每次求行列式可以快速计算。

初等变换

一个矩阵 \(A\) 的初等变换共有三种:

  • 交换第 \(s\) 行和第 \(t\) 行,矩阵表示为将 \(I_n\) 的第 \(s\) 行和第 \(t\) 行互换后的矩阵,记为 \(P_{s,t}\)\(P_{s,t}A\) 表示交换 \(A\) 的第 \(s\) 行和第 \(t\) 行,\(AP_{s,t}\) 表示交换 \(A\) 的第 \(s\) 列和第 \(t\) 列。
  • 将第 \(s\) 行乘一个数 \(c\),矩阵表示为将 \(I_n\)\((s,s)\) 位置设为 \(c\) 的矩阵,记为 \(D_s(c)\)\(D_s(c)A\) 表示将 \(A\) 的第 \(s\) 行乘 \(c\)\(AD_s(c)\) 表示将 \(A\) 的第 \(i\) 列乘 \(c\)
  • 将第 \(s\) 行乘一个数 \(k\) 加到第 \(t\) 行上,矩阵表示为将 \(I_n\)\((t,s)\) 位置设为 \(k\) 的矩阵,记为 \(T_{s,t}(k)\)\(T_{s,t}(k)A\) 表示将 \(A\) 的第 \(s\) 行乘 \(k\) 加到第 \(t\) 行上。\(AT_{s,t}(k)\) 表示将 \(A\) 的第 \(t\) 列乘 \(k\) 加到第 \(s\) 列上。

三个矩阵都是可逆矩阵,有 \(P^{-1}_{s,t} = P_{s,t},D^{-1}_s(c) = D_s(c^{-1}),T^{-1}_{s,t}(k) = T_{s,t}(-k)\)

相似变换

对于 \(n\times n\) 的矩阵 \(A\)\(B\),当存在可逆矩阵 \(P\) 满足:

\[B = P^{-1}AP \]

则称为 \(A\)\(B\) 相似,记变换 \(A\leftrightarrow P^{-1}AP\) 为相似变换。

命题:相似变换不改变矩阵的行列式和特征多项式,即 \(A\)\(P^{-1}AP\) 有相同的行列式和特征多项式。

证明:首先 \(A\)\(P^{-1}AP\) 的行列式相同是好证的,考虑证明特征多项式相同:

\[\begin{aligned} det(I_nx-P^{-1}AP) &= det(I_nxP^{-1}P-P^{-1}AP) \\ &= det(P^{-1}I_nxP-P^{-1}AP) \\ &= det(P^{-1}(I_nx-A)P) \\ &= det(I_nx-A)det(P^{-1})det(P) \\ &= det(I_nx-A) \end{aligned} \]

\(P\) 写成初等矩阵的积,设 \(P = T_1T_2\ldots T_n\),那么:

\[P^{-1}AP = T_n^{-1}\ldots T_1^{-1}AT_1\ldots T_n \]

于是相似变换相当于每次对 \(A\) 做一个初等行变换,再用这个初等矩阵的逆矩阵对 \(A\) 做一个列变换。三个初等变换分别为:

  • 先交换 \(A\) 的第 \(s\) 行和第 \(t\) 行,再交换 \(A\) 的第 \(s\) 列和第 \(t\) 列。
  • 先将 \(A\) 的第 \(s\) 行乘上 \(c\),再将 \(A\) 的第 \(t\) 列除以 \(c\)
  • 先将 \(A\) 的第 \(s\) 行乘 \(k\) 加到第 \(t\) 行上,再将 \(A\) 的第 \(t\) 列乘 \(-k\) 加到第 \(s\) 列上。

上海森堡矩阵

跟上海没有关系

因为相似变换不改变特征多项式,所以我们可以尝试用对 \(A\) 做相似变换把 \(A\) 消成比较好看的形式。

但是很遗憾,简单的高斯消元无法通过相似变换将 \(A\) 消成上三角矩阵,因为当你用 \((i,i)\) 去消第 \(j\) 行时,第 \(j\) 列又会对第 \(i\) 列有贡献,所以不太好做。于是考虑可以将 \(A\) 消成上海森堡矩阵。

上海森堡矩阵的定义为,除主对角线及以上和主对角线下面一格的位置,其余位置都为 \(0\),即所有满足 \(i > j+1\)\(A_{i,j} = 0\)。这样在做消元的时候就是用 \((i+1,i)\) 去消第 \(j\) 行,那么逆矩阵就是用第 \(j\) 列对第 \(i+1\) 列有贡献,而第 \(i+1\) 列在下一步才会枚举到,没有影响。于是就可以在 \(\mathcal{O}(n^3)\) 的复杂度内将一个矩阵 \(A\) 通过相似变换消成上海森堡矩阵。


\(A\) 消成上海森堡矩阵后,现在的问题就是如何求一个上海森堡矩阵的特征多项式。还是用拉差来计算,每次就是要快速计算一个上海森堡矩阵的行列式(下面这一部分和网上不太相同,没有用一些比较高深的东西)。

我们考虑普通求行列式的算法,每次用 \((i,i)\) 去消第 \(i+1\sim n\) 行第 \(i\) 列的值。而因为这是上海森堡矩阵,所以只有 \(A_{i+1,i}\) 可能有值,其余 \(j>i+1\)\(j\) 都满足 \(A_{j,i} = 0\),所以每次消的时候只用将第 \(i+1\) 行消掉即可,复杂度为 \(\mathcal{O}(n^2)\)

总的时间复杂度为 \(\mathcal{O}(n^3)\),于是我们就将特征多项式的求法由 \(\mathcal{O}(n^4)\) 优化为了 \(\mathcal{O}(n^3)\)

这一部分的代码:

ll det(int n)
{
    ll ans = 1;
    for(int i = 1;i <= n;i++)
    {
        if(!f[i][i])
            if(i+1 <= n&&f[i+1][i])swap(f[i],f[i+1]),ans = -ans;
            else return 0;
        (ans *= f[i][i]) %= mod;
        ll s = f[i+1][i]*inv(f[i][i])%mod;
        for(int j = i;j <= n;j++)
            (f[i+1][j] -= f[i][j]*s) %= mod;
    }
    return (ans+mod)%mod;
}
inline int calc(int n,ll x)
{
    for(int i = 1;i <= n;i++)for(int j = 1;j <= n;j++)
        f[i][j] = ((i==j)*x+b[i][j])%mod;
    return det(n);
}
void Hessenberg(int n)
{
    for(int i = 2;i <= n;i++)
    {
        if(!b[i][i-1])
        {
            int p = 0;
            for(int j = i;j <= n;j++)
                if(b[j][i-1]){p = j;break;}
            if(!p)continue;
            for(int k = 1;k <= n;k++)swap(b[i][k],b[p][k]);
            for(int k = 1;k <= n;k++)swap(b[k][i],b[k][p]);
        }
        for(int j = i+1;j <= n;j++)
        {
            ll s = b[j][i-1]*inv(b[i][i-1])%mod;
            for(int k = 1;k <= n;k++)(b[j][k] -= b[i][k]*s) %= mod;
            for(int k = 1;k <= n;k++)(b[k][i] += b[k][j]*s) %= mod;
        }
    }
    for(int i = 0;i <= n;i++)
    {
        ll s = 1;
        for(int j = 0;j <= n;j++)g[j] = !j;
        for(int j = 0;j <= n;j++)if(i != j)
        {
            s = s*(i-j+mod)%mod;
            for(int k = n;~k;k--)
                g[k] = ((k?g[k-1]:0)+g[k]*(mod-j))%mod;
        }
        s = inv(s)*calc(n,i)%mod;
        for(int j = 0;j <= n;j++)(F[j] += g[j]*s) %= mod;
    }
}

题解

如果我们能把 \(det(Ax+B)\) 的形式转化成 \(det(I_nx+C)\) 的形式,就可以用上面的海森堡算法了。(注意 \(C\) 不取反也行)

首先第一想法是乘一个 \(A^{-1}\),变成 \(\frac{det(A^{-1}(Ax+B))}{det(A^{-1})} = \frac{det(I_nx+A^{-1}B)}{det(A^{-1})}\),但是如果 \(A\) 不是满秩的就无法使用这个做法。

我们考虑直接将 \(A,B\) 两个矩阵放在一起,然后尝试将 \(A\) 消成 \(I_n\)。假设到了第 \(i\) 行,先用前 \(i-1\) 列将第 \(i\) 列的前 \(i-1\) 行都消成 \(0\),然后找这一列的主元。如果此时 \(A\) 的第 \(i\) 列全是 \(0\),那么就会有问题。而因为 \(A\) 的第 \(i\) 列全是 \(0\),所以原矩阵中这一列的元素都不含 \(x\),我们可以考虑给这一列的所有元素都乘上 \(x\),最后再将行列式除以 \(x\)

将一列乘上 \(x\) 相当于将 \(A\) 的第 \(i\) 列赋值成 \(B\) 的第 \(i\) 列然后将 \(B\) 的第 \(i\) 列清空,此时再用前 \(i-1\) 列将第 \(i\) 列的前 \(i-1\) 行都消掉,如果第 \(i\) 列有非零元素就可以找到主元了,而如果这一列仍然全部为 \(0\),我们就再进行上述操作(消 \(A\)\(i\) 列前 \(i-1\) 行时 \(B\) 也会跟着操作,所以可能会产生新的非零元素)。记录一个变量表示最后要除多少次 \(x\),如果中途发现除的次数 \(>n\) 了,那么最终答案就是 \(0\)

这样子做下去,要么是发现答案为 \(0\),要么是将 \(A\) 消成了单位矩阵,此时再做海森堡算法即可。整个复杂度依然为 \(\mathcal{O}(n^3)\)

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#define ll long long
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
using namespace std;
const int N = 505,mod = 998244353;
int n,c;
ll a[N][N],b[N][N],f[N][N],g[N],F[N],sum;
inline ll inv(ll x){return x<0?inv((x+mod)%mod):x==1?1:(mod-mod/x)*inv(mod%x)%mod;}
ll det(int n)
{
    ll ans = 1;
    for(int i = 1;i <= n;i++)
    {
        if(!f[i][i])
            if(i+1 <= n&&f[i+1][i])swap(f[i],f[i+1]),ans = -ans;
            else return 0;
        (ans *= f[i][i]) %= mod;
        ll s = f[i+1][i]*inv(f[i][i])%mod;
        for(int j = i;j <= n;j++)
            (f[i+1][j] -= f[i][j]*s) %= mod;
    }
    return (ans+mod)%mod;
}
inline int calc(int n,ll x)
{
    for(int i = 1;i <= n;i++)for(int j = 1;j <= n;j++)
        f[i][j] = ((i==j)*x+b[i][j])%mod;
    return det(n);
}
void Hessenberg(int n)
{
    for(int i = 2;i <= n;i++)
    {
        if(!b[i][i-1])
        {
            int p = 0;
            for(int j = i;j <= n;j++)
                if(b[j][i-1]){p = j;break;}
            if(!p)continue;
            for(int k = 1;k <= n;k++)swap(b[i][k],b[p][k]);
            for(int k = 1;k <= n;k++)swap(b[k][i],b[k][p]);
        }
        for(int j = i+1;j <= n;j++)
        {
            ll s = b[j][i-1]*inv(b[i][i-1])%mod;
            for(int k = 1;k <= n;k++)(b[j][k] -= b[i][k]*s) %= mod;
            for(int k = 1;k <= n;k++)(b[k][i] += b[k][j]*s) %= mod;
        }
    }
    for(int i = 0;i <= n;i++)
    {
        ll s = 1;
        for(int j = 0;j <= n;j++)g[j] = !j;
        for(int j = 0;j <= n;j++)if(i != j)
        {
            s = s*(i-j+mod)%mod;
            for(int k = n;~k;k--)
                g[k] = ((k?g[k-1]:0)+g[k]*(mod-j))%mod;
        }
        s = inv(s)*calc(n,i)%mod;
        for(int j = 0;j <= n;j++)(F[j] += g[j]*s) %= mod;
    }
}
char buf[1<<21],*p1,*p2;
inline int rd()
{
    char c;int f = 1;
    while(!isdigit(c = getchar()))if(c=='-')f = -1;
    int x = c-'0';
    while(isdigit(c = getchar()))x = x*10+(c^48);
    return x*f;
}
int main()
{
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    n = rd();ll sum = 1;
    for(int i = 1;i <= n;i++)for(int j = 1;j <= n;j++)b[i][j] = rd();
    for(int i = 1;i <= n;i++)for(int j = 1;j <= n;j++)a[i][j] = rd();
    for(int i = 1;i <= n;i++)
    {
        int p = 0;
        while(!p)
        {
            for(int j = i;j <= n;j++)
                if(a[j][i]){p = j;break;}
            if(p)break;
            for(int j = 1;j < i;j++)if(a[j][i])
            {
                ll s = a[j][i]*inv(a[j][j])%mod;
                for(int k = 1;k <= n;k++)
                    (a[k][i] -= a[k][j]*s) %= mod,(b[k][i] -= b[k][j]*s) %= mod;
            }
            c++;for(int j = 1;j <= n;j++)a[j][i] = b[j][i],b[j][i] = 0;
            if(c > n){for(int i = 0;i <= n;i++)printf("%d%c",0," \n"[i==n]);return 0;}
        }
        swap(a[i],a[p]);swap(b[i],b[p]);
        if(i != p)sum = -sum;
        ll s = inv(a[i][i]);(sum *= a[i][i]) %= mod;
        for(int j = 1;j <= n;j++)
            (a[i][j] *= s) %= mod,(b[i][j] *= s) %= mod;
        for(int j = 1;j <= n;j++)if(i != j)
        {
            ll s = a[j][i]*inv(a[i][i])%mod;
            for(int k = 1;k <= n;k++)
                (a[j][k] -= a[i][k]*s) %= mod,(b[j][k] -= b[i][k]*s) %= mod;
        }
    }
    Hessenberg(n);
    for(int i = 0;i <= n;i++)F[i] = i<=n-c?F[i+c]*(sum+mod)%mod:0;
    for(int i = 0;i <= n;i++)printf("%lld%c",F[i]," \n"[i==n]);
    return 0;
}

一些例题

posted @ 2025-06-12 11:09  max0810  阅读(89)  评论(1)    收藏  举报