【HDU 5785】Interesting

题意:
给出一个字符串\(S\),询问有多少对\((i, j, k)\)满足\(S[i \cdots j]\)\(S[j + 1 \cdots k]\)是一个回文串,输出\(\sum\sum i \cdot k\)

思路:
考虑\(a\)为以\(i - 1\)结尾的回文串的长度,\(b\)为以\(i\)为开头回文串的长度,那么变换式子有:

\[\begin{eqnarray*} &&\sum\sum i \cdot k \\ &=& \sum (i - 1 - a + 1) \cdot (i + b - 1) \\ &=& \sum i^2 + i \cdot (b - a - 1) + a \cdot (1 - b) \end{eqnarray*} \]

那么令:

  • \(prenum_i\)表示以\(i\)结尾的回文串的个数
  • \(presum_i\)表示以\(i\)结尾的回文串的长度和
  • \(sufnum_i\)表示以\(i\)开头的回文串的个数
  • \(sufsum_i\)表示以\(i\)开头的回文串的长度和

那么最终式子为:

\[\begin{eqnarray*} i^2 \cdot (prenum_{i - 1} \cdot sufnum_i) + i \cdot (sufsum_{i} \cdot prenum_{i - 1} - presum_{i - 1} \cdot sufnum_i - prenum_{i - 1} \cdot sufnum_i) + presum_{i - 1} \cdot (sufnum - sufsum) \end{eqnarray*} \]

然后回文树跑两遍即可。

代码:

#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define N 1000005
#define ALP 26
const int p = 1e9 + 7;
template <class T1, class T2>
void add(T1 &x, T2 y) {
	x += y;
	if (x >= p) x -= p;
}
struct PAM{         
    int Next[N][ALP]; 
    int fail[N];   
    int num[N];      
	int sum[N];    
    int len[N];      
    char s[N];        
    int last;       
    int n;          
    int p;           
 
    int newnode(int w){
        for(int i=0;i<ALP;i++)
            Next[p][i] = 0;
        //cnt[p] = 0;
        num[p] = 0;
        len[p] = w;
        return p++;
    }
    void init(){
        p = 0;
        newnode(0);
        newnode(-1);
        last = 0;
        n = 0;
        s[n] = -1; 		
        fail[0] = 1;
    }
    int get_fail(int x){ 
        while(s[n-len[x]-1] != s[n]) x = fail[x];
        return x;
    }
    bool insert(int c){
		bool F = 0;
        c -= 'a';
        s[++n] = c;
		int cur = get_fail(last);
		if(!Next[cur][c]){
            int now = newnode(len[cur]+2);
            fail[now] = Next[get_fail(fail[cur])][c];
            Next[cur][c] = now;
            num[now] = num[fail[now]] + 1;
			sum[now] = 0;
			add(sum[now], sum[fail[now]] + len[now]);
			F = 1;
		}
        last = Next[cur][c];
		return F;
    }
}pam; 
char s[N];
int prenum[N], presum[N];

int main() {
	while (scanf("%s", s + 1) != EOF) {
		pam.init();
		int len = strlen(s + 1);
		for (int i = 1; i <= len; ++i) {
			pam.insert(s[i]);
			prenum[i] = pam.num[pam.last];
			presum[i] = pam.sum[pam.last];
		}
		pam.init();
		ll res = 0;
		for (int i = len; i > 1; --i) {
			pam.insert(s[i]);
			int sufnum = pam.num[pam.last];
			int sufsum = pam.sum[pam.last];
			add(res, 1ll * prenum[i - 1] * sufnum % p * i % p * i % p);
			add(res, (1ll * sufsum * prenum[i - 1] % p - 1ll * prenum[i - 1] % p * sufnum % p - 1ll * sufnum * presum[i - 1] % p + p) % p * i % p);
			add(res, p - 1ll * presum[i - 1] * sufsum % p);
			add(res, 1ll * presum[i - 1] * sufnum % p);
		}
		printf("%lld\n", res);
	}
	return 0;
}

posted @ 2019-07-27 09:47  Dup4  阅读(123)  评论(0)    收藏  举报