Codeforces 1197F Coloring Game 矩阵快速幂 (看题解)

Coloring Game

我写的复杂度是 1000 * 64 * 64 * 64 * log(1e9),  感觉这个东西是很好想的, 肯定是T了的。

其实可以优化掉一个64, 就是在转移的时候用64 * 64的矩阵和 64 * 1的答案相邻相乘, 

这样就可以优化掉一个64了, 以前好像没有见过这种小技巧。

#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<bits/stdc++.h>
#define LL long long
#define LD long double
#define ull unsigned long long
#define fi first
#define se second
#define mk make_pair
#define PLL pair<LL, LL>
#define PLI pair<LL, int>
#define PII pair<int, int>
#define SZ(x) ((int)x.size())
#define ALL(x) (x).begin(), (x).end()
#define fio ios::sync_with_stdio(false); cin.tie(0);

using namespace std;

const int N = 1000 + 7;
const int inf = 0x3f3f3f3f;
const LL INF = 0x3f3f3f3f3f3f3f3f;
const int mod = 998244353;
const double eps = 1e-8;
const double PI = acos(-1);

template<class T, class S> inline void add(T &a, S b) {a += b; if(a >= mod) a -= mod;}
template<class T, class S> inline void sub(T &a, S b) {a -= b; if(a < 0) a += mod;}
template<class T, class S> inline bool chkmax(T &a, S b) {return a < b ? a = b, true : false;}
template<class T, class S> inline bool chkmin(T &a, S b) {return a > b ? a = b, true : false;}

const int MN = 64;

struct Vec {
    int a[MN];
    Vec() {
        for(int i = 0; i < MN; i++) {
            a[i] = 0;
        }
    }
};

struct Matrix {
    int a[MN][MN];
    Matrix(int v = 0) {
        for(int i = 0; i < MN; i++) {
            for(int j = 0; j < MN; j++) {
                a[i][j] = (i == j) ? v : 0;
            }
        }
    }
    inline Matrix operator * (const Matrix &B) const {
        Matrix C(0);
        for(int i = 0; i < MN; i++) {
            for(int k = 0; k < MN; k++) {
                if(!a[i][k]) continue;
                for(int j = 0; j < MN; j++) {
                    C.a[i][j] += 1LL * a[i][k] * B.a[k][j] % mod;
                    if(C.a[i][j] >= mod) C.a[i][j] -= mod;
                }
            }
        }
        return C;
    }
    Vec operator * (const Vec &B) const {
        Vec C;
        for(int i = 0; i < MN; i++) {
            for(int j = 0; j < MN; j++) {
                add(C.a[i], 1LL *  a[i][j] * B.a[j] % mod);
            }
        }
        return C;
    }
} M[30], M2[4];



int n, m, a[N];
int ret[N][4];
vector<PII> V[N];

int f[4][4];
int dp[N][4];
bool vis[4];
int v[3];

int main() {
    scanf("%d", &n);
    for(int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
    }
    scanf("%d", &m);
    for(int i = 1; i <= m; i++) {
        int x, y, c;
        scanf("%d%d%d", &x, &y, &c);
        V[x].push_back(mk(y, c));
    }
    for(int i = 1; i <= n; i++) {
        sort(ALL(V[i]));
    }

    for(int i = 1; i <= n; i++) {
        V[i].push_back(mk(a[i] + 1, 1));
    }

    for(int i = 1; i <= 3; i++) {
        for(int j = 1; j <= 3; j++) {
            scanf("%d", &f[i][j]);
        }
    }

    for(int mask = 0; mask < MN; mask++) {
        for(int i = 0; i < 3; i++) {
            v[i] = mask >> (i * 2) & 3;
        }
        for(int color = 1; color <= 3; color++) {
            memset(vis, 0, sizeof(vis));
            if(f[color][1]) vis[v[0]] = true;
            if(f[color][2]) vis[v[1]] = true;
            if(f[color][3]) vis[v[2]] = true;
            int sg = -1;
            for(int i = 0; i < 4; i++) {
                if(!vis[i]) {
                    sg = i;
                    break;
                }
            }
            int nmask = sg + (v[0] << 2) + (v[1] << 4);
            add(M[0].a[nmask][mask], 1);
            add(M2[color].a[nmask][mask], 1);
        }
    }

    for(int i = 1; i < 30; i++) {
        M[i] = M[i - 1] * M[i - 1];
    }

    for(int i = 1; i <= n; i++) {
        Vec tmp; tmp.a[63] = 1;
        int last = 0;
        for(auto &t : V[i]) {
            int cnt = t.fi - last - 1;
            for(int j = 0; j < 30; j++) {
                if(cnt >> j & 1) {
                    tmp = M[j] * tmp;
                }
            }
            if(t.fi != a[i] + 1) {
                tmp = M2[t.se] * tmp;
            }
            last = t.fi;
        }
        for(int j = 0; j < MN; j++) {
            add(ret[i][j & 3], tmp.a[j]);
        }
    }
    dp[0][0] = 1;
    for(int i = 0; i < n; i++) {
        for(int j = 0; j < 4; j++) {
            for(int k = 0; k < 4; k++) {
                add(dp[i + 1][j ^ k], 1LL * dp[i][j] * ret[i + 1][k] % mod);
            }
        }
    }
    printf("%d\n", dp[n][0]);
    return 0;
}

/*
*/

 

posted @ 2019-07-23 23:10  NotNight  阅读(464)  评论(0编辑  收藏  举报