MAZE(2019年牛客多校第二场E题+线段树+矩阵乘法)

题目链接

传送门

题意

在一张\(n\times m\)的矩阵里面,你每次可以往左右和下三个方向移动(不能回到上一次所在的格子),\(1\)表示这个位置是墙,\(0\)为空地。

现在有\(q\)次操作,操作一是将\((x,y)\)这个位置的状态取反,操作二问你从\((1,x)\)走到\((n,y)\)的方案数。

思路

首先我们考虑不带修改操作时求方案数:

我们发现从第\(i-1\)行到第\(i\)行的\(j\)这个位置只能通过\((i-1,j)\)到达,因此可以从第\(i-1\)行到\((i,j)\)的位置只能是与\((i-1,j)\)的路径上不能有墙的点,从而我们可以得知\(dp[x][y]=\sum\limits_{i=L}^{R}dp[x-1][i]\),其中\(dp[x][y]\)表示到达\((x,y)\)这个位置的方案数,\(L,R\)表示与\((x-1,y)\)联通的左右端点。

我们发现这是一个递推式,因此我们可以用矩阵乘法来维护这个东西,我们用一个例子来理解:

假设要从第\(i-1\)行到达第\(i\)行,且第\(i-1\)行的状态为\("10010"\),那么将递推式表示成矩阵乘法为:

\[\left( \begin{matrix} dp[i][1] & dp[i][2] & dp[i][3] & dp[i][4] & dp[i][5] \end{matrix} \right)= \left( \begin{matrix} dp[i-1][1] & dp[i-1][2] & dp[i-1][3] & dp[i-1][4] & dp[i-1][5] \end{matrix} \right) \times \left( \begin{matrix} 0 & 0 & 0 & 0 & 1\\ 0 & 1 & 1 & 0 & 1\\ 0 & 1 & 1 & 0 & 1\\ 0 & 0 & 0 & 0 & 1\\ 0 & 0 & 0 & 0 & 1 \end{matrix} \right) \]

得到了相邻两项的递推式那么从第\(1\)行到第\(n\)行的答案那么答案就是\(dp[n+1][y]\),为什么是\(n+1\)而不是\(n\)呢?因为如果是\(n\),那么得到的只有从\(n-1\)行到达这个位置的方案数,缺少了从第\(n\)行的其他位置到达这个位置的方案数。

那么待修改操作的我们该怎么处理呢?

我们发现修改一个点的位置只会影响当前行与下一行的系数矩阵,并不会影响其他行直接的系数矩阵,如果我们暴力修改然后暴力求解递推式那么对于每次操作都要从第\(1\)行递推到第\(n\)行,那么每次修改操作复杂度为\(O(1)\),求解复杂度为\(O(nm^3)\),很明显不能满足题目给的时限。

我们发现如果我们用线段树来维护这个东西,那么每次修改的复杂度为\(O(m^3long(n))\),求解复杂度为\(O(1)\),那么总体复杂度就比上面上了一个\(n\)

而维护方式也很简单,定义线段树的每个结点都是一个系数矩阵,表示从结点代表的左端点\((l)\)到右端点\((r)\)\(+1\)的递推式中的矩阵相乘,而且系数矩阵中的\(sum[i][j]\)表示的是从\(l\)的第\(i\)列到\(r+1\)的第\(j\)列的方案数,最后答案为根结点的\(sum[x][y]\)

代码

#include <set>
#include <map>
#include <deque>
#include <queue>
#include <stack>
#include <cmath>
#include <ctime>
#include <bitset>
#include <cstdio>
#include <string>
#include <vector>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;

typedef long long LL;
typedef pair<LL, LL> pLL;
typedef pair<LL, int> pLi;
typedef pair<int, LL> pil;;
typedef pair<int, int> pii;
typedef unsigned long long uLL;

#define lson (rt<<1)
#define rson (rt<<1|1)
#define lowbit(x) x&(-x)
#define name2str(name) (#name)
#define bug printf("*********\n")
#define debug(x) cout<<#x"=["<<x<<"]" <<endl
#define FIN freopen("/home/dillonh/CLionProjects/Dillonh/in.txt","r",stdin)
#define IO ios::sync_with_stdio(false),cin.tie(0)

const double eps = 1e-8;
const int mod = 1000000007;
const int maxn = 50000 + 7;
const double pi = acos(-1);
const int inf = 0x3f3f3f3f;
const LL INF = 0x3f3f3f3f3f3f3f3fLL;

int n, m, q, op, x, y;
int mp[maxn][11];

struct node {
    int l, r, sum[11][11];

    node operator * (const node& a) const {
        node c;
        for(int i = 1; i <= m; ++i) {
            for(int j = 1; j <= m; ++j) {
                c.sum[i][j] = 0;
                for(int k = 1; k <= m; ++k) {
                    c.sum[i][j] = (c.sum[i][j] + 1LL * sum[i][k] * a.sum[k][j] % mod) % mod;
                }
            }
        }
        return c;
    }
}segtree[maxn<<2];

void push_up(int rt) {
    int l = segtree[rt].l, r = segtree[rt].r;
    segtree[rt] = segtree[lson] * segtree[rson];
    segtree[rt].l = l, segtree[rt].r = r;
}

void build(int rt, int l, int r) {
    segtree[rt].l = l, segtree[rt].r = r;
    if(l == r) {
        for(int i = 1; i <= m; ++i) {
            for(int j = 1; j <= m; ++j) segtree[rt].sum[i][j] = 0;
            for(int j = i; j <= m && !mp[l][j]; ++j) segtree[rt].sum[i][j] = 1;
            for(int j = i; j >= 1 && !mp[l][j]; --j) segtree[rt].sum[i][j] = 1;
        }
        return;
    }
    int mid = (l + r) >> 1;
    build(lson, l, mid);
    build(rson, mid + 1, r);
    push_up(rt);
}

void update(int rt, int pos) {
    if(segtree[rt].l == segtree[rt].r) {
        for(int i = 1; i <= m; ++i) {
            for(int j = 1; j <= m; ++j) segtree[rt].sum[i][j] = 0;
            for(int j = i; j <= m && !mp[pos][j]; ++j) segtree[rt].sum[i][j] = 1;
            for(int j = i; j >= 1 && !mp[pos][j]; --j) segtree[rt].sum[i][j] = 1;
        }
        return;
    }
    int mid = (segtree[rt].l + segtree[rt].r) >> 1;
    if(pos <= mid) update(lson, pos);
    else update(rson, pos);
    push_up(rt);
}

int main() {
#ifndef ONLINE_JUDGE
    FIN;
#endif
    scanf("%d%d%d", &n, &m, &q);
    for(int i = 1; i <= n; ++i) {
        for(int j = 1; j <= m; ++j) {
            scanf("%1d", &mp[i][j]);
        }
    }
    build(1, 1, n);
    while(q--) {
        scanf("%d%d%d", &op, &x, &y);
        if(op == 1) {
            mp[x][y] ^= 1;
            update(1, x);
        } else {
            printf("%d\n", segtree[1].sum[x][y]);
        }
    }
    return 0;
}
posted @ 2019-08-22 16:48  Dillonh  阅读(272)  评论(0编辑  收藏  举报