Processing math: 100%

P4248 [AHOI2013] 差异

传送门

题目描述

给定一个长度为 n 的字符串 S,令 Ti 表示它从第 i 个字符开始的后缀。求

1i<jnlen(Ti)+len(Tj)2×lcp(Ti,Tj)

其中,len(a) 表示字符串 a 的长度,lcp(a,b) 表示字符串 a 和字符串 b 的最长公共前缀。

题解

可能比较简单做法吗?
考虑我们枚举所有子串作为公共前缀,如果这个子串长度为x,出现次数为y, 那么任取两个出现次数都可以构成两个以它为前缀的后缀, 统计y(y1)x为答案就行。

显然的问题来了,题目要求的是最长公共前缀,如果当前统计的不是最长的怎么办呢?
如果x不是最长的,那么就至少存在一个x+1的前缀,由于我们统计了每一个前缀, 我们不妨考虑在x+1时将这个答案减去。

对于每一个长度为x的子串,x1必然也会被统计同样多次,我们把它减去,那么长度为1的前缀会在2处被剪掉,2会被3减, 3会被4减, 除非已经是最长公共前缀了,这样做就没有问题。

对于每一个x,统计答案即为y(y1)(x(x1))=y(y1),也就是遍历sam,对于每一个endpos,将出现次数乘上出现次数减一再乘以当前endpos覆盖区间长度就好了。

题解说按边算贡献,本质差不多,但不失为一种不错角度

实现

没啥好说的,乘法记得都在前面乘个1ll,我直接扣sam板子很快就可以写出(如果不是因为我在学习神秘vim的话

#include <iostream>
#include <cstdio>
#include <vector>
#include <string>
#define ll long long
using namespace std;

int read(){
	int num=0, flag=1; char c=getchar();
	while(!isdigit(c) && c!='-') c=getchar();
	if(c=='-') flag=-1, c=getchar();
	while(isdigit(c)) num=num*10+c-'0', c=getchar();
	return num*flag; 
}

int readc(){
	char c=getchar();
	while(c<'a' || c>'z') c=getchar();
	return c-'a'; 
}

const int N = 500500;
int n, m;
char s[N];

ll ans = 0;

namespace sam{
	struct{
		int len, link, siz=0;
		int ch[26];
	}st[N<<1];
	vector<int> ptr[N<<1];
	int las, sz;
	
	void init(){
		las=1, sz=1;
		st[1].len=0, st[1].link=0; 
	}
	
	void extend(int c){
		int cur=++sz, p=las;
		st[cur].len=st[las].len+1, st[cur].siz=1;
		while(p && !st[p].ch[c]) 
			st[p].ch[c]=cur, p=st[p].link;
		
		if(p){
			int nex = st[p].ch[c];
			if(st[p].len+1 == st[nex].len){
				st[cur].link = nex;
			}else{
				int clone = ++sz;
				st[clone].len=st[p].len+1, st[clone].link = st[nex].link;
				for(int i=0; i<26; i++) st[clone].ch[i] = st[nex].ch[i];
				st[cur].link=clone, st[nex].link=clone;
				
				while(st[p].ch[c]==nex) st[p].ch[c]=clone, p=st[p].link;
			}
		}else{
			st[cur].link=1;
		}
		
		las = cur;
	}
	
	void dfsPtr(int x){
		for(int i=0; i<ptr[x].size(); i++) {
			dfsPtr(ptr[x][i]);
			st[x].siz += st[ptr[x][i]].siz;
		}
	}
	
	void buildPtr(){
		for(int i=1; i<=sz; i++) {
			ptr[st[i].link].push_back(i);
		}
		dfsPtr(1);
	}
	
	void clac1(){
		for(int i=1; i<=sz; i++){
			if(st[i].siz!=0)
            	ans += 1ll*st[i].siz * (st[i].siz - 1) * (st[i].len - st[st[i].link].len);
        } 
	}
}

string str;

int main(){
	ios::sync_with_stdio(false);
	cin.tie(nullptr); cout.tie(nullptr);
	
	cin >> str;
	for(int i=0; i<str.size(); i++){
		s[i+1] = str[i]-'a';
	} n = str.size();
	sam::init();
	for(int i=1; i<=n; i++) sam::extend(s[i]); 
	sam::buildPtr();
	sam::clac1();
    ans *= -1;
    // ll sum = 0;
    // for(int i=1; i<=n; i++){
    //     ans += sum + 1ll*(i-1)*(n-i+1);
    //     sum += n-i+1;
    // }
	ans += 1ll*(n-1)*n*(n+1)/2;
	cout << ans << endl;
	
	return 0;
}
posted @ 2024-08-04 21:33  ltdJcoder  阅读(18)  评论(0)    收藏  举报
点击右上角即可分享
微信分享提示