2020牛客暑期多校训练营(第四场) C - Count New String (字符串,广义后缀自动机,序列自动机)

Count New String

题意:

  1. 定义字符串函数 \(f(S,x,y)(1\le x\le y\le n)\),返回一个长度为y-x+1的字符串,第 i 位是 \(max_{i=x...x+k-1}S_i\)
  2. 设集合\(A = {f(f(S, x_1,y_1),x_2-x_1+1,y_2-x_1+1)|1\le x_1 \le x_2 \le y_2 \le y_2 \le n}\)
  3. 求集合A 的大小
  4. \(N\le 1e5\) 字符集大小 <=10

分析:
先放出官方题解

方法一

核心点1比较容易想到,进一步可以观察到他们之间有很大一部分后面是重复的,感性的想到如果倒着插入Trie树,在Trie树上面的节点可能不会很多。具体证明来讲,当前字符 \(i\), 最近的大于它的字符的位置是 \(j (j > i)\), 那么在将位置 \(i + 1\) 的字符插入到Trie树之后,还要把长度 \((j - i)\) 的字符串插入到Trie树中。考虑一个 \(j_2\),再找一个最大的 \(j_1\)\(S_{j_1} \ge S_{j_2}\),那么\([j_1,j_2]\)这个区间最多利用10次,所以Trie树节点不超过10N。然后就是常规的在Trie树上面建立广义后缀自动机,扫一遍所有节点即可得到本质不同的子串个数

#include<bits/stdc++.h>
//#define ONLINE_JUDGE
using namespace std;
typedef long long ll;
const int inf = 0x3f3f3f3f;
#define dbg(x...) do { cout << "\033[32;1m" << #x <<" -> "; err(x); } while (0)
void err() { cout << "\033[39;0m" << endl; }
template<class T, class... Ts> void err(const T& arg,const Ts&... args) { cout << arg << " "; err(args...); }

const int N = 2000000 + 5; //N为字符串长度两倍
const int P = 10; 
char s[N];
struct node{
    int link, len, trans[P];
	void clear(){
        memset(trans,0, sizeof trans);
		link = len = 0;
	}
};
struct SAM{
    node S[N];
    int p, np, size;
    int b[N], c[N];
	SAM():p(1),np(1),size(1){}
	void clear(){
		for(int i=0;i<=size;i++)S[i].clear();
		np = size = p = 1;
	}
	void insert(char ch){
		int x = ch - 'a';
		np = ++size;
		S[np].len = S[p].len + 1;
		while(p != 0 && !S[p].trans[x]) S[p].trans[x] = np, p = S[p].link;
		if(p == 0)S[np].link = 1;
		else{
			int q, nq;
			q = S[p].trans[x];
			if(S[q].len == S[p].len + 1) S[np].link = q;
			else{
                nq = ++size;
				S[nq] = S[q]; 
                S[nq].len = S[p].len + 1;
				S[np].link = S[q].link = nq;
				while(p != 0 && S[p].trans[x] == q) S[p].trans[x] = nq, p = S[p].link;
			}
		}
		p = np;
	}
}sam;
int pos[N];
int n;

int main(){
#ifndef ONLINE_JUDGE
freopen("i.in","r",stdin);
//  freopen("o.out,"w",stdout);
#endif
    scanf("%s", s+1);
    n = strlen(s+1);
    stack<int> st;
    pos[n+1] = 1;
    st.push(n+1);
    for(int i=n;i>=1;i--){
        while(st.size() != 1 && s[st.top()] < s[i]) st.pop();
        int k = st.top();
        sam.p = pos[k];
        for(int j=i;j<k;j++) sam.insert(s[i]);
        pos[i] = sam.p;
        st.push(i);
    }

    ll res = 0;
    for(int i = 2;i<=sam.size;i++){
        res += sam.S[i].len - sam.S[sam.S[i].link].len;
    }
    printf("%lld\n", res);
    return 0;
}

方法二

然后讲解一下场上自己想出来的一个做法:

观察到\(f(S, x, y)\),最多只会有 10 段连续的一样的字符。采取这样的键值对表示法:从序列\(abcdefghij\)中抽连续的某一个部分,比如aabbc可以用\(\{"abc", [2, 2, 1]\}\) 来表示, bccd 可以用\(\{"bcd", [1, 2, 1]\}\)来表示。注意到这样的表示中,缩减的字符串最多只会有1024种。每个字符串对应的序列的长度最大是10。

对于一个后缀\(f(S, i, n)\),从中可以提取出来的键值对最多有100个。所以整个字符串,可以提取出来的键值对极限只有1e7,实际上会小很多。

对于键值相同的情况,比如 bccdde 与 bcccde, bccddee,将它们一起处理,"bcde" 对应的值是一个二维数组\([[1, 2, 2, 1], [1, 3, 1, 1], [1, 2, 2, 2]]\)。然后重点思考如何处理这个二位数组,因为我们利用"bcde" 去产生子串,具体的选择上面3个串中的某个时,c和d字符的个数是确定的,b和e的个数可以控制。

然后进一步观察,bccdde与bccddee是同一类,因为其中c和d的数量是一样的,而bcccde是另外一类。我们要对他们进行归类,也就是对二维数组\([[1, 2, 2, 1], [1, 3, 1, 1], [1, 2, 2, 2]]\)中的一位数组元素归类。

如何归类?通过他们的第2到len-1个元素的值进行归类。这个操作可以用set实现,具体来说,起初他们都属于一个集合,然后遍历第2到len-1个数字,遍历所有集合,然后根据当前这个数字,对集合进行拆分即可。

最后考虑如何计算对答案的贡献即可。比如上面的例子,bccdde与bccddee归为一类之后,只考虑b和e的个数:\(\{[1, 1], [1, 2]\}\), 对答案的贡献就是2。(这里处理起来还有一些细节需要注意)

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int inf = 0x3f3f3f3f;
#define dbg(x...) do { cout << "\033[32;1m" << #x <<" -> "; err(x); } while (0)
void err() { cout << "\033[39;0m" << endl; }
template<class T, class... Ts> void err(const T& arg,const Ts&... args) { cout << arg << " "; err(args...); }
const int N = 100000 + 5;
int n;
char s[N];
int next_big[N]; // 下个比它大的字符
int next_not_equ[N]; // 下个不等于的位置
map<string, int> mp;
vector<vector<int>> vv[2000];
set<int> st[2000010];
int vis[N];
int totn;
int get(string str){
    if(mp.count(str)) return mp[str];
    mp[str] = ++totn;
    return totn;
}
ll solve(){
    ll res = 0;
    for(int ii = 1; ii <= totn; ii ++){
        vector<vector<int>> &tt = vv[ii];
        int len = tt[0].size();
        int n = tt.size();
        // 一开始所有的都属于一个集合,这个集合里面应该有一个最大值
        int set_count = 1;
        ll Max = 0;
        for(int i=0;i<n;i++) st[i].clear();
        for(int i = 0; i < n ;i++){
            st[0].insert(i);
            Max = max(Max, 1ll * tt[i][0]);
        }
        if(len == 1) {
            res += Max;
            continue;
        }
        for(int r = 1; r < len-1; r ++){
            // 拆集合了
            int new_set_count = set_count;
            for(int j = 0; j < set_count; j++){
                int num = 0;
                vector<int> tmp;
                set<int> temp = st[j];
                st[j].clear();
                for(auto x : temp){
                    tmp.push_back(tt[x][r]);
                    if(vis[tt[x][r]] == -1) {
                        if(num == 0) vis[tt[x][r]] = j;
                        else {
                            vis[tt[x][r]] = new_set_count ++;
                        }
                        num ++;
                    }
                    int index = vis[tt[x][r]];
                    st[index].insert(x);
                }
                for(auto x : tmp) {
                    vis[x] = -1;
                }
            }
            set_count = new_set_count;
        }
        for(int j = 0; j < set_count; j++){
            vector<pair<int,int>> t;
            for(auto x : st[j]){
                t.push_back({tt[x][0], tt[x][len-1]});
            }
            sort(t.begin(), t.end());
            int rmax = t.back().second;
            for(int i=t.size()-1;i>=0;i--){
                t[i].second = max(t[i].second, rmax);
                rmax = t[i].second;
            }
            for(int i=0;i<t.size();i++){
                ll l = t[i].first, r = t[i].second;
                if(i != 0 && l == t[i-1].first) continue;
                if(i == 0) {
                    res += l * r;
                } else 
                    res += (l - t[i-1].first) * r;
            }
        }
    }
    return res;
}


int main(){
#ifndef ONLINE_JUDGE
freopen("i.in","r",stdin);
//  freopen("o.out,"w",stdout);
#endif
    memset(vis, -1, sizeof vis);
    scanf("%s", s+1);
    n = strlen(s+1);
    int aft[10];
    for(int i=0;i<10;i++) aft[i] = n + 1;
    for(int i = n; i >= 1; i--){
        s[i] = s[i] - 'a';
    }
    aft[s[n]] = n;
    next_big[n] = next_not_equ[n] = n + 1;
    for(int i=n-1;i>=1;i--){
        next_big[i] = n + 1;
        for(int j=s[i]+1;j<10;j++){
            if(aft[j] != n + 1) {
                next_big[i] = min(next_big[i], aft[j]); // 要找最近的
            }
        }
        aft[s[i]] = i;
        if(s[i] == s[i+1]) next_not_equ[i] = next_not_equ[i+1];
        else next_not_equ[i] = i + 1;
    }
    int l = 1;
    while(l <= n){
        string str = "";
        int r = l, pre = 0;
        vector<int> v;
        do {
            pre = r;
            r = next_big[r];
            str += char('0' + s[pre]);
            v.push_back(r - pre); 
            int id = get(str);
            vv[id].push_back(v);      
        }while(r <= n);
        l = next_not_equ[l];
    }
    printf("%lld\n", solve());
    return 0;
}
posted @ 2020-07-20 23:26  kpole  阅读(305)  评论(0编辑  收藏  举报