[模板]KMP字符串匹配

KMP 匹配利用模式串 t 的 border(既是前缀又是后缀的最长真子串)来跳过无效比较;模板中的 j 表示模式串 t 当前已经匹配上的前缀长度。

#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
// On platforms without AVX2 support, replace avx2 in the target pragma with avx or sse.
#include<bits/stdc++.h>
using namespace std;
#define x first
#define y second
typedef pair<int,int> PII;
typedef long long ll;
typedef unsigned long long ull;
typedef unsigned int uint;
typedef vector<string> VS;
typedef vector<int> VI;
typedef vector<vector<int>> VVI;
// Shared scratch buffer of raw values: divide() turns it into a sorted,
// duplicate-free table and mp() looks values up in that table.
std::vector<int> vx;

// Sorts vx in ascending order and removes duplicate values in place.
inline void divide()
{
    std::sort(vx.begin(), vx.end());
    vx.erase(std::unique(vx.begin(), vx.end()), vx.end());
}

// Returns how many elements of vx are <= value (upper_bound offset).
// After divide(), this is the 1-based rank of value when it is present in vx.
inline int mp(int value)
{
    return static_cast<int>(std::upper_bound(vx.begin(), vx.end(), value) - vx.begin());
}

// Returns the index of the highest set bit of value, i.e. floor(log2(value)).
// value must be non-zero: __builtin_clz(0) has an undefined result.
inline int log_2(int value) { return 31 - __builtin_clz(value); }

// Returns the number of set bits in value.
inline int popcount(int value) { return __builtin_popcount(value); }

// Returns the lowest set bit of value as a power of two (0 when value == 0).
inline int lowbit(int value) { return value & -value; }

// Returns floor(sqrt(value)) using integer binary search; returns 0 for value <= 0.
// The search keeps the invariant lo*lo <= value < hi*hi.  The upper bound
// 3037000500 is the smallest integer whose square exceeds LLONG_MAX, so every
// probed midpoint (strictly below hi) can be squared without overflow and the
// whole non-negative 64-bit range is answered correctly.
inline long long Lsqrt(long long value)
{
    if (value <= 0) return 0;
    long long lo = 0;
    long long hi = 3037000500LL;
    while (lo + 1 < hi)
    {
        const long long mid = lo + (hi - lo) / 2;
        if (mid * mid <= value) lo = mid;
        else hi = mid;
    }
    return lo;
}
// Reads a text and a pattern from stdin, prints the 1-based start index of
// every occurrence of the pattern in the text (one per line, overlapping
// occurrences included), then prints the border length of every pattern
// prefix, each followed by a single space.
void solve()
{
    std::string text, pattern;
    std::cin >> text >> pattern;

    const int textLen = static_cast<int>(text.size());
    const int patLen = static_cast<int>(pattern.size());

    // Prepend a dummy character so both strings are addressed from index 1.
    const std::string paddedText = "-" + text;
    const std::string paddedPat = "-" + pattern;

    // border[i] = length of the longest proper border of pattern[1..i].
    std::vector<int> border(patLen + 1, 0);
    int matched = 0;
    for (int pos = 2; pos <= patLen; ++pos)
    {
        while (matched > 0 && paddedPat[pos] != paddedPat[matched + 1])
            matched = border[matched];
        matched += (paddedPat[pos] == paddedPat[matched + 1]) ? 1 : 0;
        border[pos] = matched;
    }

    // Scan the text; `matched` is the length of the pattern prefix currently matched.
    matched = 0;
    for (int pos = 1; pos <= textLen; ++pos)
    {
        while (matched > 0 && paddedText[pos] != paddedPat[matched + 1])
            matched = border[matched];
        if (paddedText[pos] == paddedPat[matched + 1])
            ++matched;
        if (matched != patLen)
            continue;
        std::cout << pos - patLen + 1 << "\n"; // start position of this occurrence
        matched = border[matched];             // fall back so overlapping matches are found
    }

    for (int pos = 1; pos <= patLen; ++pos)
        std::cout << border[pos] << ' ';
}
// Entry point: switches the C++ streams to unsynchronised, untied mode and
// runs solve() for a fixed count of one test case.
int main()
{
	std::ios::sync_with_stdio(false);
	std::cin.tie(nullptr);
	// Single test case; read the count from stdin here for multi-case inputs.
	int remaining = 1;
	while (remaining != 0)
	{
		--remaining;
		solve();
	}
	return 0;
}

数一数

对于长度大于Minsz的s[i],其ans = 0;若长度等于Minsz的串不全相等,则所有答案均为0;否则用这个最短串对每个更长的串各跑一次KMP,把出现次数相乘即为最短串的答案。

#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
// On platforms without AVX2 support, replace avx2 in the target pragma with avx or sse.
#include<bits/stdc++.h>
using namespace std;
#define x first
#define y second
typedef pair<int,int> PII;
typedef long long ll;
typedef unsigned long long ull;
typedef unsigned int uint;
typedef vector<string> VS;
typedef vector<int> VI;
typedef vector<vector<int>> VVI;
// Shared scratch buffer of raw values; divide() turns it into a sorted,
// duplicate-free table.
std::vector<int> vx;

// Sorts vx in ascending order and removes duplicate values in place.
inline void divide()
{
    std::sort(vx.begin(), vx.end());
    vx.erase(std::unique(vx.begin(), vx.end()), vx.end());
}

// The rank lookup helper is disabled in this program:
//inline int mp(int x) {return upper_bound(vx.begin(),vx.end(),x)-vx.begin();}

// Returns the index of the highest set bit of value, i.e. floor(log2(value)).
// value must be non-zero: __builtin_clz(0) has an undefined result.
inline int log_2(int value) { return 31 - __builtin_clz(value); }

// Returns the number of set bits in value.
inline int popcount(int value) { return __builtin_popcount(value); }

// Returns the lowest set bit of value as a power of two (0 when value == 0).
inline int lowbit(int value) { return value & -value; }

// Returns floor(sqrt(value)) using integer binary search; returns 0 for value <= 0.
// The search keeps the invariant lo*lo <= value < hi*hi.  The upper bound
// 3037000500 is the smallest integer whose square exceeds LLONG_MAX, so every
// probed midpoint (strictly below hi) can be squared without overflow and the
// whole non-negative 64-bit range is answered correctly.
inline long long Lsqrt(long long value)
{
    if (value <= 0) return 0;
    long long lo = 0;
    long long hi = 3037000500LL;
    while (lo + 1 < hi)
    {
        const long long mid = lo + (hi - lo) / 2;
        if (mid * mid <= value) lo = mid;
        else hi = mid;
    }
    return lo;
}
// Modulus applied to the product of occurrence counts.
const int p = 998244353;

// Reads n, then n strings, and prints one number per string in input order:
//   * every string longer than the shortest length gets 0;
//   * if the strings of the shortest length are not all identical, every
//     string gets 0;
//   * otherwise each shortest string gets the product, modulo p, of the
//     number of (possibly overlapping) occurrences of that shortest string
//     inside each longer string.
// The strings are never sorted or copied: the shortest length is found with a
// linear scan and the KMP failure table of the shortest string is built once
// and reused for every longer string.
void solve()
{
    int n;
    std::cin >> n;
    std::vector<std::string> words(n > 0 ? n : 0);
    for (std::string &w : words) std::cin >> w;
    if (words.empty()) return; // nothing to report, and words[0] must not be touched

    std::size_t minLen = words.front().size();
    for (const std::string &w : words) minLen = std::min(minLen, w.size());

    // Pick the first shortest string as the pattern and check that every
    // other shortest string is identical to it.
    const std::string *pattern = nullptr;
    bool shortestAllEqual = true;
    for (const std::string &w : words)
    {
        if (w.size() != minLen) continue;
        if (pattern == nullptr) pattern = &w;
        else if (w != *pattern) { shortestAllEqual = false; break; }
    }

    long long ans = 0; // stays 0 when the shortest strings disagree
    if (shortestAllEqual)
    {
        const std::string &pat = *pattern;
        const int m = static_cast<int>(pat.size());

        // fail[i] = length of the longest proper border of pat[0..i].
        std::vector<int> fail(m, 0);
        for (int i = 1, len = 0; i < m; ++i)
        {
            while (len > 0 && pat[i] != pat[len]) len = fail[len - 1];
            if (pat[i] == pat[len]) ++len;
            fail[i] = len;
        }

        // Counts occurrences of pat in text, overlaps included.
        const auto countOccurrences = [&](const std::string &text) -> int
        {
            if (m == 0) return 0; // tokens read by operator>> are never empty; guard keeps the scan in bounds
            int found = 0;
            int len = 0;
            for (const char ch : text)
            {
                while (len > 0 && ch != pat[len]) len = fail[len - 1];
                if (ch == pat[len]) ++len;
                if (len == m)
                {
                    ++found;
                    len = fail[len - 1];
                }
            }
            return found;
        };

        ans = 1;
        for (const std::string &w : words)
        {
            if (w.size() == minLen) continue; // identical to the pattern: contributes a factor of 1
            ans = ans * countOccurrences(w) % p;
            if (ans == 0) break; // the product can no longer become non-zero
        }
    }

    for (const std::string &w : words)
        std::cout << (w.size() == minLen ? ans : 0LL) << '\n';
}
// Entry point: switches the C++ streams to unsynchronised, untied mode and
// runs solve() for a fixed count of one test case.
int main()
{
	std::ios::sync_with_stdio(false);
	std::cin.tie(nullptr);
	// Single test case; read the count from stdin here for multi-case inputs.
	int remaining = 1;
	while (remaining != 0)
	{
		--remaining;
		solve();
	}
	return 0;
}

栗酱的数列

将式子变形后可知:只需a的相邻差分 a[i+1]-a[i] 与b的相邻差分的相反数 b[i]-b[i+1] 在模k意义下逐项相等,因此构造这两个差分数组跑kmp即可

#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
// On platforms without AVX2 support, replace avx2 in the target pragma with avx or sse.
#include<bits/stdc++.h>
using namespace std;
#define x first
#define y second
typedef pair<int,int> PII;
typedef long long ll;
typedef unsigned long long ull;
typedef unsigned int uint;
typedef vector<string> VS;
typedef vector<int> VI;
typedef vector<vector<int>> VVI;
// Shared scratch buffer of raw values: divide() turns it into a sorted,
// duplicate-free table and mp() looks values up in that table.
std::vector<int> vx;

// Sorts vx in ascending order and removes duplicate values in place.
inline void divide()
{
    std::sort(vx.begin(), vx.end());
    vx.erase(std::unique(vx.begin(), vx.end()), vx.end());
}

// Returns how many elements of vx are <= value (upper_bound offset).
// After divide(), this is the 1-based rank of value when it is present in vx.
inline int mp(int value)
{
    return static_cast<int>(std::upper_bound(vx.begin(), vx.end(), value) - vx.begin());
}

// Returns the index of the highest set bit of value, i.e. floor(log2(value)).
// value must be non-zero: __builtin_clz(0) has an undefined result.
inline int log_2(int value) { return 31 - __builtin_clz(value); }

// Returns the number of set bits in value.
inline int popcount(int value) { return __builtin_popcount(value); }

// Returns the lowest set bit of value as a power of two (0 when value == 0).
inline int lowbit(int value) { return value & -value; }

// Returns floor(sqrt(value)) using integer binary search; returns 0 for value <= 0.
// The search keeps the invariant lo*lo <= value < hi*hi.  The upper bound
// 3037000500 is the smallest integer whose square exceeds LLONG_MAX, so every
// probed midpoint (strictly below hi) can be squared without overflow and the
// whole non-negative 64-bit range is answered correctly.
inline long long Lsqrt(long long value)
{
    if (value <= 0) return 0;
    long long lo = 0;
    long long hi = 3037000500LL;
    while (lo + 1 < hi)
    {
        const long long mid = lo + (hi - lo) / 2;
        if (mid * mid <= value) lo = mid;
        else hi = mid;
    }
    return lo;
}
// Reads n, m, k followed by a[1..n] and b[1..m], and prints how many
// length-m windows of a satisfy, for every adjacent pair inside the window,
//     a'[j+1] - a'[j]  ==  b[j] - b[j+1]   (mod k),
// which is the same as (a'[j] + b[j]) mod k being constant over the window.
// The check is done by KMP-matching the m-1 difference residues of b (negated)
// against the n-1 difference residues of a.
// A window of length 1 has no adjacent pair, so when m == 1 all n windows
// qualify; this case is answered directly because the difference pattern is
// empty.  Assumes k > 0.
void solve()
{
    int n, m, k;
    std::cin >> n >> m >> k;
    std::vector<long long> a(n > 0 ? n : 0), b(m > 0 ? m : 0);
    for (long long &value : a) std::cin >> value;
    for (long long &value : b) std::cin >> value;

    if (m < 1 || m > n)
    {
        std::cout << 0 << '\n';
        return;
    }
    if (m == 1)
    {
        std::cout << n << '\n';
        return;
    }

    // Maps any difference to its representative in [0, k), also for negative inputs.
    const auto residue = [k](long long diff) { return ((diff % k) + k) % k; };

    const int textLen = n - 1;
    const int patLen = m - 1;
    std::vector<long long> text(textLen), pat(patLen);
    for (int i = 0; i < textLen; ++i) text[i] = residue(a[i + 1] - a[i]);
    for (int i = 0; i < patLen; ++i) pat[i] = residue(b[i] - b[i + 1]);

    // fail[i] = length of the longest proper border of pat[0..i].
    std::vector<int> fail(patLen, 0);
    for (int i = 1, len = 0; i < patLen; ++i)
    {
        while (len > 0 && pat[i] != pat[len]) len = fail[len - 1];
        if (pat[i] == pat[len]) ++len;
        fail[i] = len;
    }

    // Count (possibly overlapping) occurrences of pat in text.
    int windows = 0;
    for (int i = 0, len = 0; i < textLen; ++i)
    {
        while (len > 0 && text[i] != pat[len]) len = fail[len - 1];
        if (text[i] == pat[len]) ++len;
        if (len == patLen)
        {
            ++windows;
            len = fail[len - 1];
        }
    }
    std::cout << windows << '\n';
}
// Entry point: switches the C++ streams to unsynchronised, untied mode, reads
// the number of test cases from stdin and runs solve() once per case.
int main()
{
	std::ios::sync_with_stdio(false);
	std::cin.tie(nullptr);
	int cases;
	std::cin >> cases;
	while (cases != 0)
	{
		--cases;
		solve();
	}
	return 0;
}

K匹配

可以用所有子串数减去无匹配的子串数,或者按新增贡献来计数

#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
// On platforms without AVX2 support, replace avx2 in the target pragma with avx or sse.
#include<bits/stdc++.h>
using namespace std;
#define x first
#define y second
typedef pair<int,int> PII;
typedef long long ll;
typedef unsigned long long ull;
typedef unsigned int uint;
typedef vector<string> VS;
typedef vector<int> VI;
typedef vector<vector<int>> VVI;
// Shared scratch buffer of raw values: divide() turns it into a sorted,
// duplicate-free table and mp() looks values up in that table.
std::vector<int> vx;

// Sorts vx in ascending order and removes duplicate values in place.
inline void divide()
{
    std::sort(vx.begin(), vx.end());
    vx.erase(std::unique(vx.begin(), vx.end()), vx.end());
}

// Returns how many elements of vx are <= value (upper_bound offset).
// After divide(), this is the 1-based rank of value when it is present in vx.
inline int mp(int value)
{
    return static_cast<int>(std::upper_bound(vx.begin(), vx.end(), value) - vx.begin());
}

// Returns the index of the highest set bit of value, i.e. floor(log2(value)).
// value must be non-zero: __builtin_clz(0) has an undefined result.
inline int log_2(int value) { return 31 - __builtin_clz(value); }

// Returns the number of set bits in value.
inline int popcount(int value) { return __builtin_popcount(value); }

// Returns the lowest set bit of value as a power of two (0 when value == 0).
inline int lowbit(int value) { return value & -value; }

// Returns floor(sqrt(value)) using integer binary search; returns 0 for value <= 0.
// The search keeps the invariant lo*lo <= value < hi*hi.  The upper bound
// 3037000500 is the smallest integer whose square exceeds LLONG_MAX, so every
// probed midpoint (strictly below hi) can be squared without overflow and the
// whole non-negative 64-bit range is answered correctly.
inline long long Lsqrt(long long value)
{
    if (value <= 0) return 0;
    long long lo = 0;
    long long hi = 3037000500LL;
    while (lo + 1 < hi)
    {
        const long long mid = lo + (hi - lo) / 2;
        if (mid * mid <= value) lo = mid;
        else hi = mid;
    }
    return lo;
}
// Returns value * (value + 1) / 2, i.e. 1 + 2 + ... + value for value >= 0.
// The operand is widened to 64 bits before the increment so that value + 1
// cannot overflow int when value is INT_MAX.
inline long long cal(int value)
{
    const long long wide = value;
    return wide * (wide + 1) / 2;
}
// Reads the declared lengths n and m, a text s and a pattern t, and prints the
// number of substrings s[l..r] that contain t at least once.
//
// Counting is done by contribution: KMP reports occurrence starts in
// increasing order; for an occurrence starting at q whose predecessor started
// at prev, every left end l in (prev, q] has this occurrence as its earliest
// one, and each such l admits every right end from the occurrence's last
// character to the end of s.
// The scan is bounded by the actual string lengths rather than the declared
// n and m, so a mismatching header cannot push indices out of range.
void solve()
{
    int n, m;
    std::cin >> n >> m; // declared lengths; consumed to keep the input format
    std::string s, t;
    std::cin >> s >> t;

    const int textLen = static_cast<int>(s.size());
    const int patLen = static_cast<int>(t.size());
    long long res = 0;

    if (patLen > 0 && patLen <= textLen)
    {
        // fail[i] = length of the longest proper border of t[0..i].
        std::vector<int> fail(patLen, 0);
        for (int i = 1, len = 0; i < patLen; ++i)
        {
            while (len > 0 && t[i] != t[len]) len = fail[len - 1];
            if (t[i] == t[len]) ++len;
            fail[i] = len;
        }

        int prevStart = -1; // 0-based start of the previous occurrence
        for (int i = 0, len = 0; i < textLen; ++i)
        {
            while (len > 0 && s[i] != t[len]) len = fail[len - 1];
            if (s[i] == t[len]) ++len;
            if (len != patLen) continue;

            const int start = i - patLen + 1; // 0-based start, occurrence ends at i
            res += static_cast<long long>(start - prevStart) * (textLen - i);
            prevStart = start;
            len = fail[len - 1]; // allow overlapping occurrences
        }
    }
    std::cout << res << '\n';
}
// Entry point: switches the C++ streams to unsynchronised, untied mode and
// runs solve() for a fixed count of one test case.
int main()
{
	std::ios::sync_with_stdio(false);
	std::cin.tie(nullptr);
	// Single test case; read the count from stdin here for multi-case inputs.
	int remaining = 1;
	while (remaining != 0)
	{
		--remaining;
		solve();
	}
	return 0;
}

如果我让你查回文你还爱我吗

将区间[l,r]拆成[l,mid]、[mid+1,r]两段,对每个回文中心i分别计算其对左、右两段的贡献。查询离线后,计算左段贡献时按右端点升序处理,以免后加入的中心影响先前的查询;右段反之,按左端点降序处理。注意偶回文需要单独讨论,也可以采用减去一次重复回文区间的方法来避免讨论。

// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
// On platforms without AVX2 support, replace avx2 in the target pragma with avx or sse.
#include<bits/stdc++.h>
using namespace std;
#define x first
#define y second
typedef pair<int,int> PII;
typedef long long ll;
typedef unsigned long long ull;
typedef unsigned int uint;
typedef vector<string> VS;
typedef vector<int> VI;
typedef vector<vector<int>> VVI;
// Shared scratch buffer of raw values; divide() turns it into a sorted,
// duplicate-free table.
std::vector<int> vx;

// Sorts vx in ascending order and removes duplicate values in place.
inline void divide()
{
    std::sort(vx.begin(), vx.end());
    vx.erase(std::unique(vx.begin(), vx.end()), vx.end());
}

// The rank lookup helper is disabled in this program:
// inline int mp(int x) {return upper_bound(vx.begin(),vx.end(),x)-vx.begin();}

// Returns the index of the highest set bit of value, i.e. floor(log2(value)).
// value must be non-zero: __builtin_clz(0) has an undefined result.
inline int log_2(int value) { return 31 - __builtin_clz(value); }

// Returns the number of set bits in value.
inline int popcount(int value) { return __builtin_popcount(value); }

// Returns the lowest set bit of value as a power of two (0 when value == 0).
inline int lowbit(int value) { return value & -value; }

// Returns floor(sqrt(value)) using integer binary search; returns 0 for value <= 0.
// The search keeps the invariant lo*lo <= value < hi*hi.  The upper bound
// 3037000500 is the smallest integer whose square exceeds LLONG_MAX, so every
// probed midpoint (strictly below hi) can be squared without overflow and the
// whole non-negative 64-bit range is answered correctly.
inline long long Lsqrt(long long value)
{
    if (value <= 0) return 0;
    long long lo = 0;
    long long hi = 3037000500LL;
    while (lo + 1 < hi)
    {
        const long long mid = lo + (hi - lo) / 2;
        if (mid * mid <= value) lo = mid;
        else hi = mid;
    }
    return lo;
}
// Upper bound on the number of positions; the node array below holds 4*N entries.
const int N = 2e5+10;

// Aggregate kept for a segment: the sum of its values and how many positions it covers.
struct info
{
	long long sum;
	long long sz;
};

// Pending range-add: the amount still to be added to every position of a segment.
struct tag
{
	long long add;
};

// Combines the aggregates of a left and a right child segment.
info operator + (const info &l,const info &r)
{
	info merged;
	merged.sum = l.sum + r.sum;
	merged.sz = l.sz + r.sz;
	return merged;
}

// Applies a pending add to an aggregate: every one of the sz positions gains t.add.
info operator + (const info &v,const tag &t)
{
	info shifted;
	shifted.sum = v.sum + t.add*v.sz;
	shifted.sz = v.sz;
	return shifted;
}

// Composes two pending adds (t1 applied on top of t2); adds simply accumulate.
tag operator + (const tag &t1, const tag &t2)
{
	tag combined;
	combined.add = t1.add + t2.add;
	return combined;
}

// One segment-tree node: the covered range [l, r], its aggregate and its pending add.
struct node
{
	int l;
	int r;
	info val;
	tag t;
};

// Node storage, addressed heap-style (children of p are p<<1 and p<<1|1).
node tr[N<<2];

//settag是在pushdown里面用表示用tagt更新p点
void settag(int p,tag t)
{
	//注意tag表示的是下面的区段未更新的值,settag当前p应该由t更新
	tr[p].val = tr[p].val + t;
	tr[p].t = t + tr[p].t;
}
// Recomputes the aggregate of node p from its two children.
void pushup(int p)
{
	const int leftChild = p<<1;
	const int rightChild = p<<1|1;
	tr[p].val = tr[leftChild].val + tr[rightChild].val;
}
void pushdown(int p)
{
	auto &t = tr[p].t;
	if(t.add)
	{
		settag(p<<1,t);
		settag(p<<1|1,t);
		t.add = 0;
	}
}
void build(int l,int r,int p)
{
	tr[p].l = l, tr[p].r = r;
	if(l == r) 
	{
		tr[p].val = {0, 1};
		return ;
	}
	int m = (l+r)/2;
	build(l,m,p<<1);
	build(m+1,r,p<<1|1);
	pushup(p);
}
// Applies tag C to every position of [L, R] inside the subtree rooted at p.
// Callers pass a range that lies within the node's own range; when the range
// straddles the midpoint it is split so each recursive call keeps that property.
void update(int L,int R,tag C,int p)
{
	const int l = tr[p].l;
	const int r = tr[p].r;
	if(l >= L && R >= r)
	{
		// The node is fully covered: record the tag here and stop descending.
		settag(p,C);
		return ;
	}
	pushdown(p); // children must be up to date before they are modified
	const int m = (l+r)/2;
	if(R > m && L <= m)
	{
		update(L,m,C,p<<1);
		update(m+1,R,C,p<<1|1);
	}
	else if(R <= m)
	{
		update(L,R,C,p<<1);
	}
	else
	{
		update(L,R,C,p<<1|1);
	}
	pushup(p);
}
// Returns the aggregate of positions [L, R] inside the subtree rooted at p.
// Callers pass a range that lies within the node's own range; when the range
// straddles the midpoint it is split so each recursive call keeps that property.
info query(int L,int R,int p)
{
	const int l = tr[p].l;
	const int r = tr[p].r;
	if(l >= L && R >= r) return tr[p].val;
	pushdown(p); // children must be up to date before they are read
	const int m = (l+r)/2;
	if(R > m && L <= m)
	{
		const info leftPart = query(L,m,p<<1);
		const info rightPart = query(m+1,R,p<<1|1);
		return leftPart + rightPart;
	}
	return R <= m ? query(L,R,p<<1) : query(L,R,p<<1|1);
}
// Manacher's algorithm on a 1-indexed string (s[0] is a placeholder).
// Builds t = " $" + s[1] + "$" + s[2] + "$" + ..., so that s[i] sits at t[2*i]
// and the separator between s[i] and s[i+1] sits at t[2*i+1]; the separators
// remove the need to treat odd and even palindromes differently.
// Returns {radius, t}, where radius[i] is the palindrome radius around t[i]
// (counting the centre itself); radius[0] is unused and stays 0.
std::pair<std::vector<int>, std::string> manacher(std::string &s)
{
	const int n = static_cast<int>(s.size()) - 1;
	const int m = 2*n + 1; // last valid index of t

	std::string t = " $";
	for(int i = 1; i <= n; ++i)
	{
		t.push_back(s[i]);
		t.push_back('$');
	}

	std::vector<int> radius(n*2 + 2);
	int center = 0; // centre of the palindrome that reaches furthest right
	int reach = 0;  // rightmost index covered by that palindrome
	for(int i = 1; i <= m; ++i)
	{
		// Start from the mirrored radius when i lies inside the known palindrome.
		int k = (i <= reach) ? std::min(radius[center*2 - i], reach - i + 1) : 1;
		while(i - k > 0 && i + k <= m && t[i - k] == t[i + k]) ++k;
		radius[i] = k;
		if(i + k - 1 > reach)
		{
			center = i;
			reach = i + k - 1;
		}
	}
	return {radius, t};
}
// ans[id] accumulates the answer of query number id (1-based input order).
ll ans[N];
// Per-centre range-add endpoints filled by solve():
//   v1[i]: left ends  t of the ranges [t, i]     applied in the left sweep,
//   v2[i]: right ends t of the ranges [i, t]     applied in the right sweep (odd palindromes),
//   v3[i]: right ends t of the ranges [i + 1, t] applied in the right sweep (even palindromes).
// These arrays are never cleared, so solve() relies on being called once.
vector<int> v1[N],v2[N],v3[N];
// Reads n, q, a string and q ranges [l, r], and prints one accumulated value per
// query in input order.  Each query range is split into a left part [l, mid]
// and a right part [mid + 1, r]; the queries are answered offline with two
// sweeps over the centres, each sweep using range-add / range-sum on the
// segment tree.
void solve()
{
    int n,q;
    cin>>n>>q;
    build(1, n, 1);
    string s;
    cin>>s;
    s = ' ' + s;
    auto [p, t] = manacher(s);
	// Record, for every centre i, the ranges it contributes to each side.
    for(int i=1;i<=n;++i)
    {
        // Odd palindromes centred at i: p[2*i] is the radius around s[i] in the
        // '$'-interleaved string, so ro palindromes have left ends
        // i-ro+1..i and right ends i..i+ro-1.
        int ro = p[i*2]/2;
        if(ro)
        {
        	v1[i].push_back(i - ro + 1);
        	v2[i].push_back(i + ro - 1);
        }
        // Even palindromes whose left middle character is i: p[2*i+1] is the
        // radius around the separator between s[i] and s[i+1].
        int re = p[i*2+1]/2;
        if(re)
        {
        	v1[i].push_back(i - re + 1);
        	// Their right ends are i+1..i+re, so the right-side range is [i + 1, i + re].
        	v3[i].push_back(i + re);	
        }
    }
    // NOTE(review): m is not used anywhere below.
    int m = t.size();
    // Lq holds {l, mid, id} for the left parts, Rq holds {mid + 1, r, id} for the right parts.
    vector<array<int,3>> Lq,Rq;
    for(int i=1;i<=q;++i)
    {
        int l,r;
        cin>>l>>r;
        int mid = (l + r)/2;
        // For an odd-length range the split point is moved one step left.
        if((r - l + 1) & 1) mid --;
        if(mid >= l) Lq.push_back({l, mid, i});
        if(r >= mid + 1) Rq.push_back({mid + 1, r, i});
    }
    // Left parts are processed by increasing right end, so that when a query is
    // answered only centres up to its right end have been inserted.
    sort(Lq.begin(),Lq.end(),[&](auto A, auto B){return A[1] < B[1];});
    // Right parts are processed by decreasing left end for the symmetric reason.
    sort(Rq.begin(),Rq.end(),[&](auto A, auto B){return A[0] > B[0];});
    // Left sweep: j is the last centre already inserted.
    int j = 0;
    for(auto [l, r, id]:Lq)
    {
    	while(j < r)
    	{
    		j ++;
    		for(auto t:v1[j])
    		{
    			update(t, j, tag{1}, 1);
    		}
    	}
    	ans[id] += query(l, r, 1).sum;
    }
    // Reset the tree: build() does not clear pending tags, hence the memset.
    memset(tr, 0, sizeof tr);
    build(1, n, 1);
    // Right sweep: j is the smallest centre already inserted.
    j = n + 1;
    for(auto [l, r, id]:Rq)
    {
    	while(j > l)
    	{
    		j --;
    		for(auto t:v2[j])
    		{
    			update(j, t, tag{1}, 1);
    		}
    		for(auto t:v3[j])
    		{
    			update(j + 1, t, tag{1}, 1);
    		}
    	}
    	ans[id] += query(l, r, 1).sum;
    }
    for(int i=1;i<=q;++i) cout<<ans[i]<<'\n';
}
// Entry point: switches the C++ streams to unsynchronised, untied mode and
// runs solve() for a fixed count of one test case.
int main()
{
	std::ios::sync_with_stdio(false);
	std::cin.tie(nullptr);
	// Single test case; read the count from stdin here for multi-case inputs.
	int remaining = 1;
	while (remaining != 0)
	{
		--remaining;
		solve();
	}
	return 0;
}

 posted on 2025-02-22 10:53  ruoye123456  阅读(12)  评论(0)    收藏  举报