2017多校4

 

属实自闭。感觉周末要铁。

A B C D E F G H I J K L M
  $\varnothing$ $\varnothing$ $\varnothing$ $\varnothing$ $\varnothing$ $\varnothing$ $\varnothing$  O $\varnothing$ O $\varnothing$  

 

[B. Classic Quotation]

其实就是要求 $S[1 \dots i]$($ i \in [1, L]$) 和 $S[j \dots n]$($j \in [R, n]$) 拼接后有多少个 $T$。

用哈希预处理出 $ok_pre[i][len]$ 表示串 $S$ 以 $i$ 结尾长度为 $len$ 与 $T[1 \dots len]$ 是否相等。

$sumpre[i][len]$ 表示前缀 $i$ 中有多少长度为 $len$ 的子串与 $T[1 \dots len]$ 相等。

$anspre[i]$ 表示 $sumpre[j][m]$ 的前缀和。

那么 $[1, L]$ 中出现串的个数对答案的贡献就为 $anspre[L] \times (n - R + 1)$。

右半部分同理可以处理出 $suf$ 数组。

两端拼起来的贡献为 $\sum_{i = 1}^{m - 1} sumpre[L][i] \times sumsuf[R][m - i]$

#include <bits/stdc++.h>
#define ull unsigned long long
#define ll long long

const int N = 5e4 + 7;
int n, m, k;
char s[N], t[N];
bool op[N][110], os[N][110];
ll sp[N][110], ss[N][110];
ll ans_p[N], ans_s[N];

struct Hash {
    ull base[N];
    ull seed, MOD;
    ull ha[N];
    void setseed(ull s) {
        seed = s;
    }
    void setmod(ull s) {
        MOD = s;
    }
    void init(char *s, int n) {
        for (int i = base[0] = 1; i < N; i++)
            base[i] = base[i - 1] * seed % MOD;
        ha[0] = 0;
        for (int i = 1; i <= n; i++)
            ha[i] = (ha[i - 1] * seed % MOD + s[i]) % MOD;
    }
    ull get(int l, int r) {
        return (ha[r] - ha[l- 1] * base[r - l + 1] % MOD + MOD) % MOD;
    }
} S, T;

inline bool check(int l, int r, int x, int y) {
    ull L = S.get(l, r);
    ull R = T.get(x, y);
    return L == R;
}

ll cal(int l, int r) {
    ll ans = ans_p[l] * (n + 1 - r);
    ans += ans_s[r] * l;
    for (int i = 1; i < m; i++)
        ans += sp[l][i] * ss[r][m - i];
    return ans;
}

void solve() {
    static const ull seed = 146527;
    static const ull MOD = 998244353;
    scanf("%d%d%d", &n, &m, &k);
    for (int i = 0; i <= std::max(n, m) + 2; i++) {
        ans_p[i] = ans_s[i] = 0;
        for (int j = 0; j <= m + 2; j++) 
            sp[i][j] = ss[i][j] = 0;
    }
    scanf("%s%s", s + 1, t + 1);
    S.setseed(seed); S.setmod(MOD);
    T.setseed(seed); T.setmod(MOD);
    S.init(s, n); T.init(t, m);
    for (int i = 1; i <= n; i++) {
        int len = std::min(i, m);
        for (int l = 1; l <= len; l++) {
            if (check(i - l + 1, i, 1, l))
                op[i][l] = 1;
            else 
                op[i][l] = 0;
        }
        for (int l = 1; l <= m; l++)
            sp[i][l] = sp[i - 1][l] + op[i][l];
        ans_p[i] = ans_p[i - 1] + sp[i][m];
    }
    for (int i = n; i >= 1; i--) {
        int len = std::min(n - i + 1, m);
        for (int l = 1; l <= len; l++) {
            if (check(i, i + l - 1, m - l + 1, m))
                os[i][l] = 1;
            else 
                os[i][l] = 0;
        }
        for (int l = 1; l <= m; l++)
            ss[i][l] = ss[i + 1][l] + os[i][l];
        ans_s[i] = ans_s[i + 1] + ss[i][m];
    }
    while (k--) {
        int l, r;
        scanf("%d%d", &l, &r);
        printf("%lld\n", cal(l, r));
    }
}

int main() {
    int T;
    scanf("%d", &T);
    while (T--) solve();
    return 0;
}
View Code

 

[C. Counting Divisors]

已经连这种题都不会做了...药丸...

求质因数的幂次用枚举质数的倍数的方法即可...其他都是暴力...

#include <bits/stdc++.h>
#define ll long long

const int N = 1e6 + 7;
const int MOD = 998244353;
int prime[N], tol, res[N];
ll temp[N];
bool vis[N];

void init() {
    for (int i = 2; i < N; i++) {
        if (!vis[i]) prime[++tol] = i;
        for (int j = 1; j <= tol && i * prime[j] < N; j++) {
            vis[i * prime[j]] = 1;
            if (i % prime[j] == 0) break;
        }
    }
}

void M(int &x) {
    if (x >= MOD) x -= MOD;
}    

int main() {
    init();
    int T;
    scanf("%d", &T);
    while (T--) {
        ll l, r;
        int k;
        scanf("%lld%lld%d", &l, &r, &k);
        for (int i = 0; i <= r - l; i++)
            res[i] = 1;
        for (int i = 0; i <= r - l; i++)
            temp[i] = l + i;
        for (int i = 1; i <= tol; i++) {
            int p = prime[i];
            if (1LL * p * p > r) break;
            for (ll j = (l - 1) / p * p + p; j <= r; j += p) {
                int cnt = 0;
                while (temp[j - l] % p == 0) temp[j - l] /= p, cnt++;
                res[j - l] = 1LL * res[j - l] * (1 + 1LL * k * cnt) % MOD;
            }
        }
        for (int i = 0; i <= r - l; i++)
            if (temp[i] > 1)
                res[i] = 1LL * res[i] * (1 + k) % MOD;
        int ans = 0;
        for (int i = 0; i <= r - l; i++)
            M(ans += res[i]);
        printf("%d\n", ans);
    }
    return 0;
}
View Code

 

[D. Dirt Ratio]

已经连这种题都不会做了...药丸...

求区间数字种数与区间长度的比值。看到比值想到分数规划。二分答案。然后枚举区间右端点,然后就是区间修改区间RMQ了...

#include <bits/stdc++.h>

const int N = 6e4 + 7;
const double eps = 1e-6;
int n, a[N], pos[N];

inline int dcmp(double x) {
    if (fabs(x) < eps) return 0;
    return x < 0 ? -1 : 1;
}

struct Seg {
    #define lp p << 1
    #define rp p << 1 | 1
    double tree[N << 2], lazy[N << 2];
    inline void pushup(int p) {
        tree[p] = std::min(tree[lp], tree[rp]);
    }
    inline void tag(int p, double x) {
        tree[p] += x;
        lazy[p] += x;
    }
    inline void pushdown(int p) {
        if (dcmp(lazy[p]) == 0) return;
        tag(lp, lazy[p]);
        tag(rp, lazy[p]);
        lazy[p] = 0;
    }
    void build(int p, int l, int r) {
        lazy[p] = 0;
        tree[p] = 0;
        if (l == r) return;
        int mid = l + r >> 1;
        build(lp, l, mid);
        build(rp, mid + 1, r);
    }
    void update(int p, int l, int r, int x, int y, double f) {
        if (x > y) return;
        if (x <= l && y >= r) {
            tag(p, f);
            return;
        }
        pushdown(p);
        int mid = l + r >> 1;
        if (x <= mid) update(lp, l, mid, x, y, f);
        if (y > mid) update(rp, mid + 1, r, x, y, f);
        pushup(p);
    }
    double query(int p, int l, int r, int x, int y) {
        if (x > y) return 1e18;
        if (x <= l && y >= r) return tree[p];
        pushdown(p);
        int mid = l + r >> 1;
        double ans = 1e18;
        if (x <= mid) ans = std::min(ans, query(lp, l, mid, x, y));
        if (y > mid) ans = std::min(ans, query(rp, mid + 1, r, x, y));
        return ans;
    }
} seg;

bool check(double mid) {
    seg.build(1, 1, n);
    double ans = 1e18;
    for (int i = 1; i <= n; i++)
        pos[i] = 0;
    for (int i = 1; i <= n; i++) {
        seg.update(1, 1, n, i, i, 1 - mid);
        seg.update(1, 1, n, 1, i - 1, -mid);
        seg.update(1, 1, n, pos[a[i]] + 1, i - 1, 1);
        ans = std::min(ans, seg.query(1, 1, n, 1, i));
        pos[a[i]] = i;
    }
    return dcmp(ans) <= 0;
}

void solve() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++)
        scanf("%d", a + i);
    double l = 0, r = 1.0;
    for (int i = 0; i < 15; i++) {
        double mid = (l + r) / 2.0;
        if (check(mid)) r = mid;
        else l = mid;
    }
    printf("%.6f\n", l);
}

int main() {
    int T;
    scanf("%d", &T);
    while (T--)
        solve();
    return 0;
}
View Code

 

[E. Lazy Running]

贪心的想法就是找一个最小的环一直走,走到刚好跨过 $K$。如果存在一个权值为 $w$ 的环,长度为 $c$ 的路径,那么显然存在 $c + w$ 的路径。

那么所有路径在模 $w$ 下的同余,那么对于一个余数找到最短的路径即可构造出其他同余的最短路径。

$d[i][j]$ 表示到 $i$ 点,路径长度模 $w$ 为 $j$ 的最短路径,dijkstra可以求。

$w$ 为 $min(d_{1, 2}, d_{2, 3})$。最后枚举一下 $d[2][i]$,即可构造出最小的大于 $K$ 的路径长度。

#include <bits/stdc++.h>
#define ll long long
#define fi first
#define se second
#define pii pair<ll, int>

ll d[5][130000];
int MOD;
std::vector<std::pii> vec[5];

template<class T>
inline bool chkmax(T &a, const T &b) {
    return a < b ? a = b, 1 : 0;
}

template<class T>
inline bool chkmin(T &a, const T &b) {
    return a > b ? a = b, 1 : 0;
}

inline void add(int u, int v, ll c) {
    vec[u].push_back(std::pii(c, v));
    vec[v].push_back(std::pii(c, u));
}

void dijkstra(int s) {
    memset(d, 0x3f, sizeof(d));
    std::priority_queue< std::pii, std::vector<std::pii>, std::greater<std::pii> > que;
    que.push(std::pii(d[s][0] = 0, s));
    while (!que.empty()) {
        auto p = que.top(); que.pop();
        ll cur = p.fi; int u = p.se;
        if (cur > d[u][cur % MOD]) continue;
        for (auto p: vec[u]) {
            int v = p.se; ll now = p.fi + cur;
            if (chkmin(d[v][now % MOD], now)) 
                que.push(std::pii(now, v));
        }
    }
}

int main() {
    freopen("in.txt", "r", stdin);
    int T;
    scanf("%d", &T);
    while (T--) {
        for (int i = 1; i <= 4; i++) vec[i].clear();
        ll K, d1, d2, d3, d4;
        scanf("%lld%lld%lld%lld%lld", &K, &d1, &d2, &d3, &d4);
        MOD = std::min(d1, d2) * 2;
        add(1, 2, d1); add(1, 4, d4); add(2, 3, d2); add(3, 4, d3);
        ll ans = 0x3f3f3f3f3f3f3f3f;
        dijkstra(2);
        for (int i = 0; i < MOD; i++) {
            if (d[2][i] >= K) chkmin(ans, d[2][i]);
            else chkmin(ans, d[2][i] + ((K - d[2][i] + MOD - 1) / MOD) * MOD);
        }
        printf("%lld\n", ans);
    }
    return 0;
}
View Code

 

[F. Logical Chain]

题意很裸,有向图,动态修改边,求强连通分量。

如果用邻接表的话就无法做到高效修改边,如果用邻接矩阵的话就会多了枚举的常数,大概可以看成 $O(n^2)$ 的。

所以就得用 bitset 优化枚举边的过程,然后 Kosaraju 算法是比较适合这么写的...

然后每次询问都暴力即可。

学习了一发手写 bitset。秀秀秀。

#include <bits/stdc++.h>

const int N = 256;

struct Bitset {
    unsigned v[8];
    void reset() { for (int i = 0; i < 8; i++) v[i] = 0; }
    void set(int x) { v[x >> 5] |= 1u << (x & 31); }
    void flip(int x) { v[x >> 5] ^= 1u << (x & 31); }
    bool test(int x) { return v[x >> 5] >> (x & 31) & 1; }
} vis, G[N], G2[N];

std::vector<int> S;
int n, m, cnt;
char s[N];

void dfs(int u) {
    vis.flip(u);
    for (int i = 0; i < 8; i++) {
        while (1) {
            unsigned temp = vis.v[i] & G[u].v[i];
            if (!temp) break;
            dfs(i << 5 | __builtin_ctz(temp));
        }
    }
    S.push_back(u);
}

void dfs2(int u) {
    vis.flip(u);
    ++cnt;
    for (int i = 0; i < 8; i++) {
        while (1) {
            unsigned temp = vis.v[i] & G2[u].v[i];
            if (!temp) break;
            dfs2(i << 5 | __builtin_ctz(temp));
        }
    }
}

void solve() {
    S.clear();
    int ans = 0;
    for (int i = 0; i < n; i++)
        vis.set(i);
    for (int i = 0; i < n; i++)
        if (vis.test(i))
            dfs(i);
    for (int i = 0; i < n; i++)
        vis.set(i);
    for (int i = n - 1; i >= 0; i--) {
        if (vis.test(S[i])) {
            cnt = 0;
            dfs2(S[i]);
            ans += cnt * (cnt - 1) / 2;
        }
    }
    printf("%d\n", ans);
}

int main() {
    int T;
    scanf("%d", &T);
    while (T--) {
        scanf("%d%d", &n, &m);
        for (int i = 0; i < n; i++)
            G[i].reset(), G2[i].reset();
        vis.reset();
        for (int i = 0; i < n; i++) {
            scanf("%s", s);
            for (int j = 0; j < n; j++)
                if (s[j] == '1')
                    G[i].flip(j), G2[j].flip(i);
        }
        while (m--) {
            int k;
            scanf("%d", &k);
            while (k--) {
                int u, v;
                scanf("%d%d", &u, &v);
                u--, v--;
                G[u].flip(v), G2[v].flip(u);
            }
            solve();
        }
    }
    return 0;
}
View Code

 

[G. Matching In Multiplication]

对于 $V$ 部的点,若其度数为 $1$,那么其匹配就已经确定了。拓扑排序去除这些点即可。

然后剩下的图一定是 $U$ 部和 $V$ 部各剩 $m$ 个点,其中 $U$ 部每个点的度数为 $2$,那么总共只有 $2m$ 条边,所以 $V$ 部每个点的度数也为 $2$,否则不存在完美匹配。

那么对于剩下的每一个连通块来说是一个环,匹配方式只有两种,dfs统计一下即可。

#include <bits/stdc++.h>
#define pii pair<int, int>
#define fi first
#define se second

namespace IO
{
    char buf[1 << 21], buf2[1 << 21], a[20], *p1 = buf, *p2 = buf, hh = '\n';
    int p, p3 = -1;
    void read() {}
    void print() {}
    inline int getc() {
        return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++;
    }
    inline void flush() {
        fwrite(buf2, 1, p3 + 1, stdout), p3 = -1;
    }
    template <typename T, typename... T2>
    inline void read(T &x, T2 &... oth) {
        T f = 1; x = 0;
        char ch = getc();
        while (!isdigit(ch)) { if (ch == '-') f = -1; ch = getc(); }
        while (isdigit(ch)) { x = x * 10 + ch - 48; ch = getc(); }
        x *= f;
        read(oth...);
    }
    template <typename T, typename... T2>
    inline void print(T x, T2... oth) {
        if (p3 > 1 << 20) flush();
        if (x < 0) buf2[++p3] = 45, x = -x;
        do {
            a[++p] = x % 10 + 48;
        } while (x /= 10);
        do {
            buf2[++p3] = a[p];
        } while (--p);
        buf2[++p3] = hh;
        print(oth...);
    }
} // using namespace IO
#define read IO::read
#define print IO::print
#define flush IO::flush

const int N = 6e5 + 7;
const int MOD = 998244353;
bool vis[N];
std::vector<std::pii> vec[N];
int n, temp[2], degree[N];

void dfs(int u, int type, int fa) {
    for (auto p: vec[u]) {
        int v = p.fi, c = p.se;
        if (vis[v] || v == fa) continue;
        vis[v] = 1;
        temp[type] = 1LL * temp[type] * c % MOD;
        dfs(v, type ^ 1, u);
    }
}

int main() {
    int T;
    read(T);
    while (T--) {
        read(n);
        for (int i = 1; i <= 2 * n; i++) {
            vec[i].clear();
            degree[i] = vis[i] = 0;
        }
        for (int u = 1; u <= n; u++) {
            int v, c;
            read(v, c);
            v += n;
            vec[u].push_back(std::pii(v, c));
            vec[v].push_back(std::pii(u, c));
            degree[u]++; degree[v]++;
            read(v, c);
            v += n;
            vec[u].push_back(std::pii(v, c));
            vec[v].push_back(std::pii(u, c));
            degree[u]++; degree[v]++;
        }
        std::queue<int> que;
        for (int i = 1; i <= 2 * n; i++) {
            if (degree[i] == 1)
                que.push(i), vis[i] = 1;
        }
        int ans = 1;
        while (!que.empty()) {
            int u = que.front(); que.pop();
            for (auto p: vec[u]) {
                int v = p.fi, c = p.se;
                if (vis[v]) continue;
                ans = 1LL * ans * c % MOD;
                vis[v] = 1;
                for (auto q: vec[v]) {
                    int vv = q.fi;
                    if (--degree[vv] == 1)
                        que.push(vv), vis[vv] = 1;
                }
            }
        }
        for (int i = 1; i <= n; i++)
            if (!vis[i])
                temp[0] = temp[1] = 1, dfs(i, 0, 0), ans = 1LL * (temp[0] + temp[1]) * ans % MOD;
        print(ans);
    }
    flush();
    return 0;
}
View Code

 

[H. Phone Call]

类似于最小生成树的做法,先把电话线按权值排序,然后对于每条电话线,先把 $a$ 和 $b$ 分别合并到它们的 LCA 上,再把 $c$ 和 $d$ 分别合并到它们的 LCA 上,最后再合并两个 LCA。

对于每个节点 $i$ 维护一个 $up[i]$,表示 $i$ 往上的路径上,深度最大的和其不在一个连通块中的儿子,也就是沿着 $up$ 一直往上走,直到一个 $up$ 等于自身,表示这些点都在一个连通块中,那么每次合并就用这个 $up$ 解决,并且路径压缩一下,复杂度就没问题了。

#include <bits/stdc++.h>
#define ll long long

const int N = 1e5 + 7;

int n, m, ffa[N], fa[N], up[N], ssz[N], sz[N], son[N], top[N], dep[N];
ll cost[N];
std::vector<int> vec[N];

struct Edge {
    int a, b, c, d;
    ll cc;
    bool operator < (const Edge &rhs) const {
        return cc < rhs.cc;
    }
} edge[N];

int find(int x, int *f) {
    return x == f[x] ? x : f[x] = find(f[x], f);
}

void dfs1(int u, int f) {
    ffa[u] = f;
    dep[u] = dep[f] + 1;
    ssz[u] = 1;
    son[u] = 0;
    for (int v: vec[u]) {
        if (v == f) continue;
        dfs1(v, u);
        ssz[u] += ssz[v];
        if (ssz[v] > ssz[son[u]]) son[u] = v;
    }
}

void dfs2(int u, int tp) {
    top[u] = tp;
    if (!son[u]) return;
    dfs2(son[u], tp);
    for (int v: vec[u])
        if (v != ffa[u] && v != son[u])
            dfs2(v, v);
}

int Lca(int u, int v) {
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) std::swap(u, v);
        u = ffa[top[u]];
    }
    if (dep[u] > dep[v]) std::swap(u, v);
    return u;
}

void merge(int u, int v, ll cc) {
    u = find(u, fa), v = find(v, fa);
    if (u == v) return;
    fa[u] = v;
    sz[v] += sz[u];
    cost[v] += cost[u] + cc;
}

void solve(int u, int v, ll cc) {
    while (1) {
        u = find(u, up);
        if (dep[u] <= dep[v]) return;
        int uu = ffa[u];
        merge(u, uu, cc);
        up[u] = uu;
    }
}

int main() {
    int T;
    scanf("%d", &T);
    while (T--) {
        scanf("%d%d", &n, &m);
        for (int i = 1; i <= n; i++) {
            up[i] = fa[i] = i;
            sz[i] = 1;
            cost[i] = 0;
            vec[i].clear();
        }
        for (int i = 1; i < n; i++) {
            int u, v;
            scanf("%d%d", &u, &v);
            vec[u].push_back(v);
            vec[v].push_back(u);
        }
        dfs1(1, 0);
        dfs2(1, 1);
        for (int i = 0; i < m; i++)
            scanf("%d%d%d%d%lld", &edge[i].a, &edge[i].b, &edge[i].c, &edge[i].d, &edge[i].cc);
        std::sort(edge, edge + m);
        for (int i = 0; i < m; i++) {
            int lca = Lca(edge[i].a, edge[i].b);
            solve(edge[i].a, lca, edge[i].cc);
            solve(edge[i].b, lca, edge[i].cc);
            lca = Lca(edge[i].c, edge[i].d);
            solve(edge[i].c, lca, edge[i].cc);
            solve(edge[i].d, lca, edge[i].cc);
            merge(edge[i].a, edge[i].c, edge[i].cc);
        }
        int ans = find(1, fa);
        printf("%d %lld\n", sz[ans], cost[ans]);
    }
    return 0;
}
View Code

 

[I. Questionnaire]

取 $m=2$ 即可。

#include <bits/stdc++.h>

const int N = 1e5 + 7;
int n;

int main() {
    int T;
    scanf("%d", &T);
    while (T--) {
        int cnt0 = 0, cnt1 = 0;
        scanf("%d", &n);
        for (int i = 1; i <= n; i++) {
            int x;
            scanf("%d", &x);
            if (x & 1) cnt1++;
            else cnt0++;
        }
        if (cnt1 >= (n + 1) / 2)
            printf("2 1\n");
        else 
            printf("2 0\n");
    }
    return 0;
}
View Code

 

[J. Security Check]

暴力DP的话,$dp[i][j] = dp[i - 1][j - 1] + 1$,$|a[i] - b[j]| > k$

$dp[i][j] = min(dp[i - 1][j], dp[i][j - 1]) + 1$,$|a[i] - b[j]| \leq k$

 第二个可以暴力,因为每个数最多只有 $2k$ 个位置。

第一个可以二分找到一个位置 $t$ 使得 $dp[i][j] = dp[i - t][j - t] + t$。

因为最多只有 $k + 1$ 段,所以复杂度是 $O(nklogn)$。

#include <bits/stdc++.h>

const int N = 6e4 + 7;
std::vector<int> vec[N * 2];
int a[N], b[N], dp[N][25], pos[N], k;

int DP(int n, int m) {
    if (!n || !m) return n + m;
    if (std::abs(a[n] - b[m]) > k) {
        int pos = std::lower_bound(vec[m - n + N].begin(), vec[m - n + N].end(), n) - vec[m - n + N].begin();
        if (!pos) return std::max(n, m);
        pos = vec[m - n + N][pos - 1];
        return DP(pos, m - n + pos) + n - pos;
    }
    int pos = b[m] - a[n] + k;
    if (dp[m][pos] == -1)
        dp[m][pos] = std::min(DP(n - 1, m), DP(n, m - 1)) + 1;
    return dp[m][pos];
}

int main() {
    int T;
    scanf("%d", &T);
    for (int n; T--; ) {
        for (int i = 0; i < N * 2; i++)
            vec[i].clear();
        memset(dp, -1, sizeof dp);
        scanf("%d%d", &n, &k);
        for (int i = 1; i <= n; i++)
            scanf("%d", a + i), pos[a[i]] = i;
        for (int i = 1; i <= n; i++)
            scanf("%d", b + i);
        for (int i = 1; i <= n; i++)
            for (int j = std::max(1, b[i] - k); j <= std::min(b[i] + k, n); j++)
                vec[i - pos[j] + N].push_back(pos[j]);
        for (int i = 0; i < N * 2; i++)
            std::sort(vec[i].begin(), vec[i].end());
        printf("%d\n", DP(n, n));
    }
    return 0;
}
View Code

 

[K. Time To Get Up]

签到。

 

[L. Wavel Sequence]

$dp[i][j][k]$ 表示第一个考虑前 $i$ 个元素,第二个数组以第 $j$ 个元素结尾,当前上升状态为 $k$($0$/$1$)。

转移方程

$dp[i][j][k] += dp[i - 1][j][k]$

$a[i] = b[j]$ 时,

$dp[i][j][0] += \sum_{k < j \wedge b[k] < b[j]} dp[i - 1][k][1]$

$dp[i][j][1] += \sum_{k < j \wedge b[j] < b[k]} dp[i - 1][k][0]$

暴力做是 $O(n ^ 3)$ 的。

一看显然可以前缀和优化...

枚举 $i$ 之后,第二部分能转移就只有当 $b[j] = a[i]$ 的时候,即 $b[j]$ 也跟着确定下来了。

那么就用两个变量来保存 $\sum_{b[j] < a[i]} dp[i - 1][j][1]$,$\sum_{a[i] < b[j]} dp[i - 1][j][0]$ 的和即可。

#include <bits/stdc++.h>

const int MOD = 998244353;
const int N = 2200;

inline void M(int &x) {
    if (x >= MOD)
        x -= MOD;
}

int dp[N][N][2], a[N], b[N], n, m;

int main() {
    int T;
    scanf("%d", &T);
    while (T--) {
        scanf("%d%d", &n, &m);
        for (int i = 1; i <= n; i++)
            scanf("%d", a + i);
        for (int i = 1; i <= m; i++)
            scanf("%d", b + i);
        for (int i = 1; i <= n; i++)
            for (int j = 1; j <= m; j++)
                dp[i][j][0] = dp[i][j][1] = 0;
        for (int i = 0; i <= n + 10; i++)
            dp[i][0][0] = 1;
        for (int i = 1; i <= n; i++) {
            int sum0 = 1, sum1 = 0;
            for (int j = 1; j <= m; j++) {
                M(dp[i][j][0] += dp[i - 1][j][0]);
                M(dp[i][j][1] += dp[i - 1][j][1]);
                if (a[i] == b[j]) {
                    M(dp[i][j][0] += sum1);
                    M(dp[i][j][1] += sum0);
                }
                if (b[j] < a[i])
                    M(sum1 += dp[i - 1][j][1]);
                if (b[j] > a[i])
                    M(sum0 += dp[i - 1][j][0]);
            }
        }
        int ans = 0;
        for (int i = 1; i <= m; i++)
            M(ans += dp[n][i][0]), M(ans += dp[n][i][1]);
        printf("%d\n", ans);
    }
    return 0;
}
View Code

 

posted @ 2019-10-30 23:47  Mrzdtz220  阅读(164)  评论(0)    收藏  举报