洛谷P4173 残缺的字符串

https://www.luogu.com.cn/problem/P4173

给定两个串 \(a, b\),\(|a| = n\),\(|b| = m\),求 \(a\) 中与 \(b\) 匹配的子串个数以及它们的下标。

当不存在通配符时,以 \(x\) 结尾的子串与 \(b\) 匹配,当且仅当对所有 \(0 \le i \le m - 1\) 都有 \(a(x - m + i + 1) - b(i) = 0\)。由于平方和为 0 当且仅当每一项都为 0,这等价于 \(p(x) = \sum_{i = 0}^{m - 1} {(a(x - m + i + 1) - b(i))}^2 = 0\)。

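展开平方和:\(p(x) = \sum_{i = 0}^{m-1}{b(i)}^2 + \sum_{i = x - m + 1}^{x}{a(i)}^2 - 2\sum_{i = 0}^{m - 1}a(x - m + i + 1)b(i)\)。把 \(b\) 翻转(reverse)得到 \(b'(j) = b(m - 1 - j)\),交叉项就变成 \(\sum_{i = 0}^{m - 1} a(x - m + i + 1)\,b'(m - i - 1)\),两个下标之和恒为 \((x - m + i + 1) + (m - i - 1) = x\),即满足 \(a(i)b'(j),\ i + j = x\) 的卷积形式。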

可以通过卷积算

\(T = \sum_{i = 0}^{m - 1} {b(i)}^2 \\ f(x) = \sum_{i = 0}^x {a(i)}^2 \\ g(x) = \sum_{i+j=x} a(i)b'(j)\),其中 \(b'(j) = b(m - 1 - j)\) 是翻转后的 \(b\)。

\(p(x) = T + f(x) - f(x - m) - 2g(x)\),其中约定 \(f(-1) = 0\)。

则每个满足 \(p(x) = 0\) 的 \(x\)(\(m - 1 \le x \le n - 1\))都表示以 \(x\) 结尾的子串是一个匹配,其起始下标为 \(x - m + 1\)(从 0 开始)。

// Modulus shared by the NTT helper and the matching routines in this post.
const int mod = 998244353;

// Number-theoretic transform over Z/modZ with generator g (assumed to be a primitive root of mod).
// init(n) chooses a power-of-two length l >= n (never less than 2); ntt(a, 1) is the forward
// transform and ntt(a, -1) the inverse, already multiplied by l^{-1}. ntt() resizes its argument to l.
template<const int mod = 998244353, const int g = 3>
class NTT{
public:
    using ll = long long;
    int l{}, inv{};      // transform length, and l^{-1} modulo `mod`
    vector<int> W, rev;  // W[k + j]: j-th power of a 2k-th root of unity (k a power of two); rev: bit reversal

    // Binary exponentiation: base^e mod `mod`.
    inline int ksm(ll base, ll e) {
        ll acc = 1;
        while(e) {
            if(e & 1) acc = acc * base % mod;
            e >>= 1;
            base = base * base % mod;
        }
        return acc;
    }

    // Modular addition for operands already in [0, mod).
    inline int add(int a, int b) {
        const int r = a + b;
        return r >= mod ? r - mod : r;
    }

    // Modular subtraction for operands already in [0, mod).
    inline int sub(int a, int b) {
        const int r = a - b;
        return r < 0 ? r + mod : r;
    }

    // Prepares the root table and bit-reversal permutation for length >= n.
    inline void init(int n) {
        l = 2;
        while(l < n) l <<= 1;
        const int wn = ksm(g, mod / l);  // mod / l == (mod - 1) / l because l divides mod - 1
        const int half = l >> 1;
        W.resize(l + 1);
        rev.resize(l + 1);
        W[half] = 1;
        inv = mod - (mod - 1) / l;       // l * inv == -(mod - 1) == 1 (mod `mod`)
        // Top half holds consecutive powers of wn; lower entries reuse every other one.
        for(int i = half + 1; i <= l; ++i) W[i] = (ll)W[i - 1] * wn % mod;
        for(int i = half - 1; i >= 1; --i) W[i] = W[2 * i];
        for(int i = 0; i <= l; ++i) rev[i] = (rev[i / 2] >> 1) | ((i & 1) ? half : 0);
    }

    // Butterfly: (x, y) <- (x + y*z, x - y*z).
    inline void calc(int &x, int &y, int z) {
        const int t = (ll)y * z % mod;
        y = sub(x, t);
        x = add(x, t);
    }

    // In-place transform of a (resized to l); ty == -1 selects the inverse transform.
    inline void ntt(vector<int> &a, int ty) {
        a.resize(l);
        const bool inverse = (ty == -1);
        // Inverse = forward transform of a with indices 1..l-1 reversed, then scaled by 1/l.
        if(inverse) reverse(a.begin() + 1, a.end());
        for(int i = 0; i < l; ++i) {
            if(i < rev[i]) swap(a[i], a[rev[i]]);
        }
        for(int len = 1; len < l; len <<= 1) {
            for(int start = 0; start < l; start += 2 * len) {
                for(int j = 0; j < len; ++j) calc(a[start + j], a[start + j + len], W[len + j]);
            }
        }
        if(inverse) {
            for(int &v : a) v = (ll)v * inv % mod;
        }
    }
};
NTT<mod, 3> f;


// Sum of squared pattern values; kept global as in the original snippet.
int T;
// Exact matching without wildcards. Reads text s and pattern p (letters mapped to 1..26) and
// prints every 1-based start position where p occurs in s, separated by spaces, then a newline.
// For the window ending at text index x: p(x) = T + (d[x] - d[x - m]) - 2 g(x), where g is the
// NTT convolution of the reversed pattern with the text; p(x) == 0 iff the window matches.
void run() {
    scanf("%s %s", s, p);
    int lens = strlen(s), lenp = strlen(p);
    if(lenp == 0 || lenp > lens) {  // no window of length lenp fits inside s
        puts("");
        return;
    }
    reverse(p, p + lenp);            // turn the correlation into a convolution (i + j = x)
    f.init(lenp + lens);
    vector<int> a(lenp), b(lens), d(lens), c(f.l);
    for(int i = 0; i < lenp; ++ i)  a[i] = p[i] - 'a' + 1;
    for(int i = 0; i < lens; ++ i)  b[i] = s[i] - 'a' + 1;
    T = 0;
    for(int i = 0; i < lenp; ++ i)  T += a[i] * a[i];
    // d[i] = b[0]^2 + ... + b[i]^2: prefix sums over the TEXT values (previously seeded from p).
    d[0] = b[0] * b[0];
    for(int i = 1; i < lens; ++ i)  d[i] = d[i - 1] + b[i] * b[i];
    f.ntt(a, 1);    f.ntt(b, 1);
    for(int i = 0; i < f.l; ++ i)   c[i] = 1ll * a[i] * b[i] % mod;
    f.ntt(c, -1);
    for(int i = lenp - 1; i < lens; ++ i) {
        // Window s[i - lenp + 1 .. i]. The first window starts at 0, so it has no d[-1] term;
        // this removes the old out-of-bounds read and the separate (broken) check for position 1.
        const long long window = d[i] - (i >= lenp ? d[i - lenp] : 0);
        long long now = ((long long)T + window - 2LL * c[i]) % mod;
        if(now < 0) now += mod;
        if(now == 0)    printf("%d ", i - lenp + 2);
    }
    puts("");
}

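再考虑通配符:把通配符的值设为 0,由于 0 乘任何数都是 0,含通配符的位置总能匹配上。
于是匹配条件变为 \(\sum_{i = 0}^{m - 1} {(a(x - m + i + 1) - b(i))}^2\,a(x - m + i + 1)\,b(i) = 0\)。因为所有取值都非负,每一项都 \(\ge 0\),所以和为 0 当且仅当每个位置要么字符相同、要么至少一边是通配符。
将平方展开后,就是原来的三项各乘上了 \(a(x - m + i + 1)\,b(i)\);同样把 \(b\) 翻转后,每一项都是 \(i + j = x\) 的卷积形式。
那么对这三项分别做 FFT/NTT 卷积,再按系数相加即可: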

\[F(x) = \sum_{i + j = x} {A(i)}^3B(j) + \sum_{i + j = x} A(i){B(j)}^3 - 2\sum_{i + j = x} {A(i)}^2{B(j)}^2 \]

其中 \(A\) 是翻转后的模式串,\(B\) 是文本串(字母映射为 1 到 26,通配符映射为 0)。

// Wildcard matching via NTT (P4173). Reads "m n", then the pattern p (length m) and the text
// s (length n); letters map to 1..26 and '*' maps to 0. For each window ending at x it computes
//   F(x) = sum A^3 B + sum A B^3 - 2 sum A^2 B^2 = sum A B (A - B)^2
// with A = reversed pattern and B = text, so F(x) == 0 marks a match.
// NOTE: F(x) is only known modulo `mod`; a non-zero F(x) that is a multiple of `mod` would be
// reported as a (false) match.
// Output: the number of matches, then their 1-based start positions separated by spaces.
void run3() {
    int lens, lenp;
    scanf("%d %d", &lenp, &lens);
    scanf("%s %s", p, s);
    reverse(p, p + lenp);   // turn the correlation into a convolution (i + j = x)
    f.init(lenp + lens);
    vector<int> A(lenp), B(lens), c(f.l, 0);
    for(int i = 0; i < lenp; ++ i)  A[i] = p[i] == '*' ? 0 : p[i] - 'a' + 1;
    for(int i = 0; i < lens; ++ i)  B[i] = s[i] == '*' ? 0 : s[i] - 'a' + 1;

    // Element-wise v[i]^e (e >= 0); results stay small (at most 26^3).
    auto powered = [](const vector<int> &v, int e) -> vector<int> {
        vector<int> r(v.size(), 1);
        for(size_t i = 0; i < v.size(); ++ i)
            for(int k = 0; k < e; ++ k) r[i] *= v[i];
        return r;
    };
    // Adds coef * conv(A^ea, B^eb) into c, pointwise in the NTT domain.
    auto addTerm = [&](int ea, int eb, int coef) {
        vector<int> a = powered(A, ea), b = powered(B, eb);
        f.ntt(a, 1);    f.ntt(b, 1);
        const int mul = (coef % mod + mod) % mod;  // maps -2 to mod - 2
        for(int i = 0; i < f.l; ++ i)
            c[i] = f.add(c[i], 1ll * a[i] * b[i] % mod * mul % mod);
    };
    addTerm(3, 1, 1);    // sum A^3 B
    addTerm(1, 3, 1);    // sum A B^3
    addTerm(2, 2, -2);   // -2 sum A^2 B^2
    f.ntt(c, -1);

    vector<int> v;
    for(int i = lenp - 1; i < lens; ++ i) if(!c[i])   v.push_back(i - lenp + 2);
    printf("%d\n", (int)v.size());   // %d needs an int, not size_t
    for(size_t i = 0; i < v.size(); ++ i)  printf("%d%c", v[i], " \n"[i + 1 == v.size()]);
}

fft:

// Minimal complex number used by the FFT code. All operations return new values and never
// modify *this, so they are const-qualified and usable on const objects.
struct Complex{
    double x, y;   // real and imaginary parts
    Complex(double x = 0, double y = 0) : x(x), y(y){}
    // Complex conjugate.
    Complex conj() const {return Complex(x, -y);}
    Complex operator * (const Complex &a) const {
        return Complex(x * a.x - y * a.y, x * a.y + y * a.x);
    }
    Complex operator + (const Complex &a) const {
        return Complex(x + a.x, y + a.y);
    }
    Complex operator - (const Complex &a) const {
        return Complex(x - a.x, y - a.y);
    }
    // Division by a real scalar.
    Complex operator / (const double a) const {
        return Complex(x / a, y / a);
    }
    // Multiplication by a real scalar.
    Complex operator * (const double a) const {
        return Complex(x * a, y * a);
    }
};

class FFT{
public:
    int l, bit;
    const double PI = acos(-1.0);
    vector<int> rev;
    vector<Complex> w, ww;
    inline void init(int n) {
        l = 1; bit = 0;
        while(l < n) l <<= 1, ++bit;
        rev.resize(l + 1);
        //w.resize(l + 1);
        //ww.resize(l + 1);
        //for(int i = 0; i <= l; ++ i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
        for(int i = 0; i <= l; ++ i)    {
            rev[i] = rev[i >> 1] >> 1 | (i & 1 ? l >> 1 : 0);
            //ww[i] = w[i] = {cos(PI * i / l), sin(PI * i / l)};
            //ww[i].y = -ww[i].y;
        }
    }
    void fft(vector<Complex> &a, int type) {
        a.resize(l);
        for(int i = 0; i < l; ++ i) if(i < rev[i])  swap(a[i], a[rev[i]]);//ºûµû±ä»»
        for(int mid = 1, d = 0, z = __builtin_ctz(l); mid < l; mid <<= 1, ++ d)  {
            Complex wn(cos(PI / mid), type * sin(PI / mid));
            for(int i = mid << 1, pos = 0; pos < l; pos += i) {
                Complex w(1, 0);
                for(int k = 0; k < mid; ++ k, w = w * wn) {
                    /*Complex &x = a[pos + k + mid];
                    Complex &y = a[pos + k];
                    Complex t = (type == 1 ? w[k<<(z-d)] : ww[k<<(z-d)]) * x;
                    x = y - t;  y = y + t;*/

                    Complex x = a[pos + k];
                    Complex y = w * a[pos + k + mid];
                    a[pos + k] = x + y;
                    a[pos + k + mid] = x - y;
                }
            }
        }
        if(type == -1)  for(int i = 0; i < l; ++ i) a[i].x /= l;//a[i].y /= len;
    }
};
FFT tt;

void run4() {
    int n, m;
    scanf("%d %d", &n, &m);
    scanf("%s %s", p, s);
    reverse(p, p + n);
	tt.init(n + m);
	vector<int> A(n), B(m);
	vector<Complex> a(n), b(m), P(tt.l);
	for(int i = 0; i < n; ++ i) A[i] = (p[i] != '*') ? (p[i] - 'a' + 1) : 0;
	for(int i = 0; i < m; ++ i) B[i] = (s[i] != '*') ? (s[i] - 'a' + 1) : 0;

	for(int i = 0; i < n; ++ i) a[i] = {A[i] * A[i] * A[i], 0};
	for(int i = 0; i < m; ++ i) b[i] = {B[i], 0};
	tt.fft(a, 1); tt.fft(b, 1);
	for(int i = 0; i < tt.l; ++ i) P[i] = a[i] * b[i];

	a.resize(n); b.resize(m);
	for(int i = 0; i < n; ++ i) a[i]= {A[i], 0};
	for(int i = 0; i < m; ++ i) b[i] = {B[i] * B[i] * B[i], 0};
	tt.fft(a, 1); tt.fft(b, 1);
	for(int i = 0; i < tt.l; ++ i) P[i] = P[i] + a[i] * b[i];

    a.resize(n); b.resize(m);
	for(int i = 0; i < n; ++ i) a[i] = {A[i] * A[i], 0};
	for(int i = 0; i < m; ++ i) b[i] = {B[i] * B[i], 0};
	tt.fft(a, 1); tt.fft(b, 1);
	Complex k(2, 0);
	for(int i = 0; i < tt.l; ++ i) P[i] = P[i] - a[i] * b[i] * k;
	tt.fft(P, -1);

    vector<int> v;
	for(int i = n - 1; i < m; ++ i) if(fabs(P[i].x) < 0.5)  v.push_back(i - n + 2);
	printf("%d\n", v.size());
	for(auto it : v)    printf("%d ", it);  puts("");
}

还有 bitset 的解法:朴素比较的复杂度是 \(O(nm)\),但 bitset 的位运算是按机器字并行进行的,64 位机器上一次就能处理 64 位,复杂度大约能除以 64(即 \(O(nm/64)\)),再加上位运算常数很小,所以也能通过。

// Shift-Or state. part[c] bit i is 0 iff pattern position i can match letter c (pattern has c or
// '*' there). cur bit j is 0 iff the pattern prefix of length j + 1 matches the text ending at the
// current position. The fixed size assumes pattern length <= 300005.
bitset<300005> part[26], cur;
// Wildcard matching with bitsets (Shift-Or), O(n * m / wordsize). Reads "m n", the pattern p and
// the text s; prints the number of matches, then their 1-based start positions.
void run2() {
    int lens, lenp;
    scanf("%d %d", &lenp, &lens);
    cur.set();                                     // nothing matched yet
    for(int i = 0; i < 26; ++ i)    part[i].set(); // every position mismatches by default
    scanf("%s %s", p, s);
    for(int i = 0; i < lenp; ++ i)  {
        if(p[i] == '*')   for(int j = 0; j < 26; ++ j)    part[j].reset(i);  // wildcard matches any letter
        else            part[p[i] - 'a'].reset(i);
    }
    vector<int> v;
    for(int i = 0; i < lens; ++ i) {
        // Extend every live prefix by one character; bit 0 is shifted in as 0 (empty prefix).
        if(s[i] == '*') cur = cur << 1;                           // text wildcard matches anything
        else            cur = cur << 1 | part[s[i] - 'a'];        // kill prefixes that mismatch s[i]
        if(cur[lenp - 1] == 0)  v.push_back(i - lenp + 2);        // full match ends at i
    }
    printf("%d\n", (int)v.size());   // %d needs an int, not size_t
    for(size_t i = 0; i < v.size(); ++ i)  printf("%d%c", v[i], " \n"[i + 1 == v.size()]);
}

posted @ 2021-07-28 22:21  wlhp  阅读(9)  评论(0)    收藏  举报