KMP

KMP

Border的性质

\(T\)\(S\) 的周期,则 \(|S| - T\)\(S\) 的 Border
周期定理:若p, q均为串 \(S\) 的周期,则 (p,q) 也为 \(S\) 的周期
一个串的 Border 数量时 \(O(N)\) 个,但他们组成了 \(O(\log N)\) 个等差数列

KMP推广
exKMP,aka Z算法
KMP自动机,Border树
AC自动机 KMP 多串模式(字典)
Trie图 KMP多串

模板

luogu模板题

const int N = 1e6 + 5;
char s1[N], s2[N];//文本串,匹配串
int nxt[N];
/*
kmp算法原理
首先进行一个匹配
1.如果当前字符和模式串的当前字符匹配
双指针各自前进一位
2.如果当前字符和模式串的当前字符失配
模式串移动到前一位的最大border
3.
*/
void getnext(char *s,int len){
    for (int i = 2; i <= len; i++) {
        int j = nxt[i - 1];
        while (j != 0 && s[j + 1] != s[i])
            j = nxt[j];
        if (s[j + 1] == s[i])
            nxt[i] = j + 1;
        else
            nxt[i] = 0;
    }
}
void kmp(char* t,char* s,int len1,int len2){
    for (int i = 1, j = 0; i <= len1; i++) {
        while (j != 0 && s[j + 1] != t[i])
            j = nxt[j];
        if (t[i] == s[j + 1])
            j++;
        if (j == len2) {
            // ans.push_back(i - len2 + 1);
            printf("%d\n",i-len2+1);
            j = nxt[j];
        }
    }
}
int main() {
    scanf("%s", s1 + 1);
    scanf("%s", s2 + 1);
    nxt[1] = 0;
    int n = strlen(s1 + 1), m = strlen(s2 + 1);
    getnext(s2,m);
    kmp(s1,s2,n,m);
    // for (auto v : ans)
    //     printf("%d\n", v);
    for (int i = 1; i <= m; i++)
        printf("%d ", nxt[i]);
    return 0;
}

例题

差分

一个长度为n的数列A,一个长度为m的数列B,现在询问A中有多少个长度为m的连续子序列A',
满足\((a'1+b1)\mod k = (a'2+b2)\mod k = \dots = (a'm + bm)\mod k\).

两两相减得

\[(a'1-a'2)\mod k = -(b1-b2)\mod k \]

\[(a'2-a'3)\mod k = -(b2-b3)\mod k \]

原问题等价于,先对 \(A\)\(B\) 分别作差取模,匹配差分串 \(A'\) 和差分串 \(B'\)

while (T--) {
	cin >> l1 >> l2 >> k;
	for(int i=1;i<=l1;i++) cin>> t[i];
	for(int i=1;i<l1;i++)  t[i] = (k + t[i]%k - t[i+1]%k)%k;
	for(int i=1;i<=l2;i++) cin>> s[i];
	for(int i=1;i<l2;i++)  s[i] = (k - s[i]%k + s[i+1]%k)%k;
	l1--,l2--;
	getnext(s);
	cout<<kmp(t,s)<<endl;
}

子串乘积

定义 \(f(s,t) = t\) 的子串中,与 \(s\) 相等的串的个数。如 \(f("ac","acacac")=3\),\(f("bab","babab")=2\)
现在给出 \(n\) 个字符串,第 \(i\) 个字符串为 \(s_i\),对 \(\forall 1 \leq i \leq n\),求出 \(\prod_{j=1}^n {f(s_i,s_j)}\)

首先,除了最短的串,其他串的答案一定是零
如果有多个最短串,且最短串之间存在不同的串,那么答案也为零
因此只需要拿一个最短串和其他所有串做匹配,可以在 \(O(n)\) 内解决本题。

点击查看代码
#include <bits/stdc++.h>
#define ll long long
#define endl '\n'
#define int ll
const int N = 2e6 + 6;
const int mod = 998244353;
using namespace std;
vector<string> str;
int n, nxt[N], len[N];
void getnext(string& s, int len) {
    for (int i = 2; i <= len; i++) {
        int j = nxt[i - 1];
        while (j != 0 && s[j + 1] != s[i])
            j = nxt[j];
        if (s[j + 1] == s[i])
            nxt[i] = j + 1;
        else
            nxt[i] = 0;
    }
}
int kmp(string& t, string& s, int len1, int len2) {
    int tmp = 0;
    for (int i = 1, j = 0; i <= len1; i++) {
        while (j != 0 && s[j + 1] != t[i])
            j = nxt[j];
        if (t[i] == s[j + 1])
            j++;
        if (j == len2) {
            tmp++;
            j = nxt[j];
        }
    }
    return tmp;
}
signed main() {
	ios::sync_with_stdio(0);
	cin.tie(0);cout.tie(0);

    cin >> n;
    string s;
    int mlen = N;

    for (int i = 0; i < n; i++) {
        cin >> s;
		len[i] = s.length();
        s = "#" + s;
        str.push_back(s);
        mlen = min(mlen, len[i]);
    }

    vector<string> mn;
    for (int i = 0; i < n; i++) {
        if (len[i] == mlen) {
            mn.push_back(str[i]);
        }
    }

    getnext(mn[0], mlen);
    // mlen--;
    int flag = 1;
	// cout<<mlen<<endl;
    for (int i = 1; i < mn.size(); i++) {
        if (!kmp(mn[i], mn[0], mlen, mlen)) {
            flag = 0;
            break;
        }
    }
    ll ans = 1;
    if (flag) {
        for (int i = 0; i < n; i++) {
            int tmp = kmp(str[i], mn[0], len[i], mlen);
            ans *= tmp;
            ans %= mod;
			// cout<<len[i]<<" "<<mlen<<endl;
                        // cout << i << " " << tmp << endl;
        }
    }
    // cout<<(flag?1:-1)<<": "<<ans<<endl;
    for (int i = 0; i < n; i++) {
        if (flag && len[i] == mlen) {
            cout << ans << endl;
        } else
            cout << 0 << endl;
    }
    return 0;
}

子串个数

给两个串 \(T, S\),问 \(T\) 有多少个子串中包含 \(S\)
可以容斥,首先 kmp 求出所有出现的下标,之后每一段下标间隔的贡献就是 间隔\(\times\)该位置到末尾的长度

	pos.push_back(0);
    getnext(s, l2);
    kmp(t, s, l1, l2);

    ll ans = 0;
    for (int i = 1; i < pos.size(); i++) {
        ans += 1ll * (pos[i] - pos[i - 1]) * (l1 - l2 - pos[i] + 2);
        // cout<<(pos[i]-pos[i-1])*(l1-l2-pos[i]+2)<<endl;
    }

循环子矩阵(多校)

给一个 \(n*m\) 的矩阵,找到这个矩阵的周期 \(p,q\),使得 \(A[i-p][j-q] = A[i][j]\), 定义矩阵的cost为 \(max(A[i][j])*(p+1)*(q+1)\),输出最小cost

Input:
2 5
acaca
acaca
3 9 2 8 7
4 5 7 3 1
Output:
3(1+1)(2+1) = 18

最小周期一定是行和列同时满足,不难发现行列可以分开处理,分别求出求出行与列最大的周期。求行的循环节时,我们可以对每一列做一次kmp(这里需要先把没列/行的串哈希一下),找到这些列的最大公共循环节,列同理。这样就确定了 \(p, q\).
接下来就是确定每个 \(p*q\) 大小的子矩阵中,局部最大值的最小值。如果只有一维,则是一个滑动窗口问题。二维则只需要先处理一维,同时只中保留该情况最大值,将处理后的矩阵再做另一维滑动窗口即可。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e6 + 5;
int line[N], col[N];
const int seed = 131;
const int mod = 1e9 + 7;
char s[N];
vector<int> val[N];
int nxt[N];
int n, m, key[26];
inline int getnext(int *a,int n) {
	nxt[0]=0;
	int i,c,p;
	for (i=1;i<n&&a[i]==a[i+1];i++);i--;
	nxt[2]=i;c=2,p=2+nxt[2]-1;
	if (nxt[2]==n-1) return 1;
	for (i=3;i<=n;i++) {
		if (p<i||p-i+1<=nxt[i-c+1]) {
			nxt[i]=max(p-i+1,0);
			while (a[nxt[i]+i]==a[nxt[i]+1]) nxt[i]++;
			if (i+nxt[i]>n) return i-1;
			c=i,p=i+nxt[i]-1;
		}
		else nxt[i]=nxt[i-c+1];
	}
	return n;
}
int getval(int x, int y) {
    return val[x][y];
}
deque<int> win(N);
void getmx(int n, int len, int p) {
    win.clear();
    for (int i = 1; i <= n; i++) {
        while (!win.empty() && val[p][win.back()] < val[p][i])
            win.pop_back();
        while (!win.empty() && win.front() < i - len + 1)
            win.pop_front();
        win.push_back(i);
        if (i + 1 >= len)
            val[p][i - len + 1] = val[p][win.front()];
    }
}
int main() {
    cin >> n >> m;
    char c;
    for (int i = 0; i < 26; i++)
        key[i] = rand();
    for (int i = 1; i <= n; i++) {
        cin >> s + 1;
        for (int j = 1; j <= m; j++) {
            c = s[j];
            line[i] = (1ll * line[i] * seed + key[c - 'a']) % mod;
            col[j] = (1ll * col[j] * seed + key[c - 'a']) % mod;
        }
    }

    int x = 0, y = 0;
    int tmp;
    val[0].resize(m + 1);
    for (int i = 1; i <= n; i++) {
        val[i].resize(m + 1);
        for (int j = 1; j <= m; j++)
            cin >> tmp, val[i][j] = tmp;
    }

    x = getnext(col, m);
    y = getnext(line, n);

    for (int i = 1; i <= n; i++) {
        getmx(m, x, i);
    }
    for (int i = 1; i <= m; i++) {
        win.clear();
        for (int j = 1; j <= n; j++) {
            // cout<<i<<","<<j<<endl;
            while (!win.empty() && val[win.back()][i] < val[j][i])
                win.pop_back();
            while (!win.empty() && win.front() < j - y + 1)
                win.pop_front();
            win.push_back(j);
            // cout<< j - y + 1<<" "<<val[win.front()][i]<<endl;
            if (j + 1 >= y)
                val[j + 1 - y][i] = val[win.front()][i];
            // cout<<"---\n";
        }
    }
    int ans = 0x3f3f3f3f;
    for (int i = 1; i <= n - y + 1; i++) {
        for (int j = 1; j <= m - x + 1; j++) {
            ans = min(ans, val[i][j]);

//             cout << val[i][j] << " ";
        }
//         cout << endl;
    }
    cout << 1ll * (x + 1) * (y + 1) * ans << endl;
}

题单地址

posted @ 2022-01-27 13:07  FushimiYuki  阅读(64)  评论(0)    收藏  举报