2025/8/26 矩阵部分测试
T1
最大子矩阵和裸题,唐题,暴力都有\(80pts\)。注意取小的来转移。
时间复杂度\(\mathcal{O}(\min(n,m)^2*\max(n,m))\)。
#include<bits/stdc++.h>
using namespace std;
int main(){
int n,m;
cin>>n>>m;
vector<vector<int>>mat(n,vector<int>(m));
for(int i=0;i<n;i++){
for(int j=0;j<m;j++){
cin>>mat[i][j];
}
}
int ans=INT_MIN;
if(n<=m){
vector<vector<int>>pref_col(n+1,vector<int>(m, 0));
for(int j=0;j<m;j++){
for(int i=1;i<=n;i++){
pref_col[i][j]=pref_col[i-1][j]+mat[i-1][j];
}
}
for(int i=1;i<=n;i++){
for(int j=i;j<=n;j++){
int cur=0;
int best=INT_MIN;
for(int k=0;k<m;k++){
int num=pref_col[j][k]-pref_col[i - 1][k];
cur=max(num,cur+num);
best=max(best,cur);
}
ans=max(ans,best);
}
}
}else{
vector<vector<int>>pref_row(n,vector<int>(m+1,0));
for(int i=0;i<n;i++){
for(int j=1;j<=m;j++){
pref_row[i][j]=pref_row[i][j-1]+mat[i][j-1];
}
}
for(int l=1;l<=m;l++){
for(int r=l;r<=m;r++){
int cur=0;
int best=INT_MIN;
for(int i=0;i<n;i++){
int num=pref_row[i][r]-pref_row[i][l-1];
cur=max(num,cur+num);
best=max(best,cur);
}
ans=max(ans,best);
}
}
}
cout<<ans;
return 0;
}
T2

简洁。。。当时唐了,只想到\(\left( {\begin{array}{c:c}
\begin{matrix}
A
\end{matrix}&
\begin{matrix}
I
\end{matrix}
\end{array}} \right)
\)高斯消元成为\(\left( {\begin{array}{c:c}
\begin{matrix}
I
\end{matrix}&
\begin{matrix}
A^{-1}
\end{matrix}
\end{array}} \right)
\)了,再用\(A^{-1}*B\)算出\(X\),但是只要对\(\left( {\begin{array}{c:c}
\begin{matrix}
A
\end{matrix}&
\begin{matrix}
B
\end{matrix}
\end{array}} \right)
\)消元不就完了??虽然复杂度均为\(\mathcal{O}(n^3)\),但有卡常,卡在\(4s\)。
自己写的代码。
#include<bits/stdc++.h>
using namespace std;
using ll=long long;
const ll mod=998244353;
ll qp(ll a,ll x){
ll ans=1;
while(x){
if(x&1){
ans*=a;
ans%=mod;
}
a*=a;
a%=mod;
x>>=1;
}
return ans%mod;
}
signed main(){
// freopen("monica.in","r",stdin);
// freopen("monica.out","w",stdout);
ll n,m,r,w;
scanf("%lld%lld",&n,&m);
vector<vector<ll>>b=vector<vector<ll>>(n+1,vector<ll>(2*n+1,0));
for(ll i=1;i<=n;i++){
for(ll j=1;j<=m;j++){
scanf("%lld",&b[i][j]),b[i][j]%=mod;
}
}
scanf("%lld%lld",&n,&r);
vector<vector<ll>>a=vector<vector<ll>>(n+1,vector<ll>(n+r+1,0));
for(ll i=1;i<=n;i++){
for(ll j=1;j<=m;j++){
a[i][j]=b[i][j];
}
}
for(ll i=1;i<=n;i++){
for(ll j=n+1;j<=n+r;j++){
scanf("%lld",&a[i][j]);
a[i][j]%=mod;
}
}
w=n+r;
for(ll i=1;i<=n;i++){
ll pos=i;
for(ll j=i+1;j<=n;j++){
if(abs(a[pos][i])<abs(a[j][i]))pos=j;
}
if(i!=pos)swap(a[i],a[pos]);
ll inv=qp(a[i][i],mod-2);
for(ll j=1;j<=n;j++)
if(j!=i){
ll mul=a[j][i]*inv%mod;
for(ll k=i;k<=w;k++)
a[j][k]=((a[j][k]-a[i][k]*mul)%mod+mod)%mod;
}
for(ll j=1;j<=w;j++)a[i][j]=(a[i][j]*inv%mod);
}
for(ll i=1;i<=n;i++){
for(ll j=n+1;j<=n+r;j++)
printf("%lld ", a[i][j]%mod);
printf("\n");
}
return 0;
}
T3
略
T4
题目描述
给定\(n\)个\(2*2\)的下三角矩阵,有下列两种操作:
- 对区间\((L,R)\)所有的矩阵逐个相乘。
- 对区间\((L,R)\)所有的矩阵左下角的数字增加\(x\)。
对每一个操作1,给出答案矩阵。
输入格式
第一行两个整数\(n,q\),表示共\(n\)个矩阵,\(q\)次询问。
之后\(2n\)行,每两行代表一个矩阵。
接下来\(q\)行,每行代表一个询问,第一个数字为\(op\),若\(op\)为\(1\)接下来输入两个数字\(L,R\),若\(op\)为\(2\)接下来输入三个数字\(L,R,x\),意义如题目描述。
输出格式
对每一个\(op=1\)的操作,输出一个\(2*2\)矩阵代表答案。需要对\(998244353\)取模。
Sol:
维护区间肯定线段树。但是这个矩阵如何贴在线段树上?
下三角矩阵只有三个数字,没必要存成数组。
现在推柿子。
\(A=\begin{bmatrix} a_{1}& 0\\b_{1}& c_{1}\\ \end{bmatrix}\)
\(B=\begin{bmatrix} a_{2}& 0\\b_{2}& c_{2}\\ \end{bmatrix}\)
\(C=\begin{bmatrix} a_{3}& 0\\b_{3}& c_{3}\\ \end{bmatrix}\)
则\(A*B=\begin{bmatrix} a_{1}a_{2}& 0\\a_{1}b_{1}+c_{1}b_{2}& c_{1}c_{2}\\ \end{bmatrix}\)
\(A*B*C=\begin{bmatrix} a_{1}a_{2}a_{3}& 0\\a_{1}a_{3}b_{1}+c_{1}a_{3}b_{2}+c_{1}c_{2}b_{3}& c_{1}c_{2}c_{3}\\ \end{bmatrix}\)
推导可得\(Ans.b = sum_{i=1}^{k} [ (prod_{j=1}^{i-1} c_j) * b_i * (prod_{j=i+1}^{k} a_j) ]\)
线段树设计:
a:区间内矩阵a的乘积。
c:区间内矩阵c的乘积。
b:区间内矩阵乘积的左下角元素。
d:辅助值,用于计算b的区间加法的影响。
tag:懒惰标记,记录未传递的加法操作。
这样就完了。注意查询线段树中区间[L, R]的矩阵乘积的时候,合并顺序为从右到左。
#include <iostream>
#include <cstdio>
#include <algorithm>
using namespace std;
typedef long long ll;
const ll mod = 998244353;
const int maxn = 200010;
struct Node {
ll a, b, c, d;
ll tag;
} tree[4 * maxn];
void push_up(int idx) {
int lson = idx * 2;
int rson = idx * 2 + 1;
tree[idx].a = tree[rson].a * tree[lson].a % mod;
tree[idx].c = tree[rson].c * tree[lson].c % mod;
tree[idx].b = (tree[rson].a * tree[lson].b % mod + tree[rson].b * tree[lson].c % mod) % mod;
tree[idx].d = (tree[rson].a * tree[lson].d % mod + tree[rson].d * tree[lson].c % mod) % mod;
}
void push_down(int idx) {
if (tree[idx].tag) {
int lson = idx * 2;
int rson = idx * 2 + 1;
tree[lson].tag = (tree[lson].tag + tree[idx].tag) % mod;
tree[lson].b = (tree[lson].b + tree[idx].tag * tree[lson].d) % mod;
tree[rson].tag = (tree[rson].tag + tree[idx].tag) % mod;
tree[rson].b = (tree[rson].b + tree[idx].tag * tree[rson].d) % mod;
tree[idx].tag = 0;
}
}
void build(int idx, int l, int r) {
if (l == r) {
ll a_val, b_val, c_val, tmp;
scanf("%lld %lld", &a_val, &tmp);
scanf("%lld %lld", &b_val, &c_val);
tree[idx].a = a_val % mod;
tree[idx].b = b_val % mod;
tree[idx].c = c_val % mod;
tree[idx].d = 1;
tree[idx].tag = 0;
return;
}
int mid = (l + r) >> 1;
build(idx * 2, l, mid);
build(idx * 2 + 1, mid + 1, r);
push_up(idx);
}
void update(int idx, int l, int r, int ul, int ur, ll x) {
if (ul <= l && r <= ur) {
tree[idx].tag = (tree[idx].tag + x) % mod;
tree[idx].b = (tree[idx].b + x * tree[idx].d) % mod;
return;
}
push_down(idx);
int mid = (l + r) >> 1;
if (ul <= mid) update(idx * 2, l, mid, ul, ur, x);
if (ur > mid) update(idx * 2 + 1, mid + 1, r, ul, ur, x);
push_up(idx);
}
Node query(int idx, int l, int r, int ql, int qr) {
if (ql <= l && r <= qr) {
return tree[idx];
}
push_down(idx);
int mid = (l + r) >> 1;
if (qr <= mid) return query(idx * 2, l, mid, ql, qr);
else if (ql > mid) return query(idx * 2 + 1, mid + 1, r, ql, qr);
else {
Node right_res = query(idx * 2 + 1, mid + 1, r, ql, qr);
Node left_res = query(idx * 2, l, mid, ql, qr);
Node res;
res.a = right_res.a * left_res.a % mod;
res.c = right_res.c * left_res.c % mod;
res.b = (right_res.a * left_res.b % mod + right_res.b * left_res.c % mod) % mod;
res.d = (right_res.a * left_res.d % mod + right_res.d * left_res.c % mod) % mod;
return res;
}
}
int main() {
int n, q;
scanf("%d %d", &n, &q);
build(1, 1, n);
while (q--) {
int op;
scanf("%d", &op);
if (op == 1) {
int L, R;
scanf("%d %d", &L, &R);
Node res = query(1, 1, n, L, R);
printf("%lld 0\n", res.a);
printf("%lld %lld\n", res.b, res.c);
} else {
int L, R;
ll x;
scanf("%d %d %lld", &L, &R, &x);
update(1, 1, n, L, R, x);
}
}
return 0;
}

浙公网安备 33010602011771号