Loading

题解—[BJWC2011]禁忌

发现这个题的题解写的有点不亲民啊,所以我来造福一下社会,并且我的做法好像和各位大佬不太一样。

首先,有一个题跟这个题有异曲同工之妙,我也写了题解。
推荐大家可以做一做,有一些思想还是很相似的,尤其是我的思路在这两道题之间还是很连贯的。
\(GT\)考试 题面
\(GT\)考试 题解

1.建自动机


首先这里面有很多串,考虑用\(AC\)自动机,先把这些串插到\(Tire\)树里面,对每一个串的结尾打一个标记,建一个\(AC\)自动机。
这里建自动机有一点需要小注意的,如果我们到达一个带有标记的节点,说明我们已经到了一个串的末尾。
由于这道题是把长串切割成若干段,所以每个字母是不能被重复利用的
那么,到达一个串的末尾之后,我们就应该乖乖回根节点重新匹配,也就是,叶子节点的\(fail\)\(0\)号节点


2.\(dp\)

我们先考虑推一下朴素的\(dp\),也就是统计每个长度下放出他的全排列会收到多少伤害,最后除以\(alphabet^{len}\)
具体的说,我们设计\(f[i][j]\)代表长串匹配到\(i\)位,处于\(AC\)自动机上的\(j\)节点时收到的伤害的和
式子及其好推,其中\(ends\)代表这个节点是不是一个串的结尾,是个布尔变量

\[f[i + 1][tr [j][p]]+=f[i][j]+ends[tr[i][j]] \]

然后我们开心的码出来,发现过不了样例。
(可能各位都不会这么傻码一个看起来都不对的式子,但本弱确实码了)
手模一会样例,我们发现\(f[i][j]\)可能不止包含一个串
以样例为例子,\(f[3][1]\)可能包含了\(aba,bba,aaa\)三个串
所以我们还需要记录一下当前状态包含了多少个串。
我们记录\(size[i][j]\)代表长串匹配到\(i\)位,处于\(AC\)自动机上的\(j\)节点时有多少个串
然后初始化\(size[0][0]=1\),对于每一个状态\(f[i][j]\),枚举\(p\)就可以转移。

\[f[i+1][tr[j][p]]+=f[i][j]+ends[tr[i][j]]*size[i][j]\ \ ,\ \ size[i+1][tr[j][p]]+=size[i][j] \]

很显然,答案就是\(\Large \frac{\sum_{i=0}^{tnt}f[len][i]}{alphabet^{len}}\),其中,\(tnt\)\(Tire\)树上的节点个数。


打完这个\(dp\)之后,我们得到了\(50pts\)的好成绩,但是还有两个点是\(WA\)的,貌似努一努力就可以骗到


3.\(dp\)的小优化

\(WA\)的原因是分子和分母都很大,但是最终的分数却在可表示范围内
这个时候,我们一般要改变\(dp\),从计数\(dp\)改成概率\期望\(dp\)
我们再次设计状态,\(f[i][j]\)代表长串匹配到\(i\)位,处于\(AC\)自动机上的\(j\)节点时收到的伤害的期望
然后,每一个节点有\(alphabet\)种选择转移到其他点,那么每种转移的概率都为\(\frac{1}{alphabet}\)
那么,新的式子就可以推出来了。

\[\large f[i+1][tr[j][p]]+=\frac{f[i][j]+ends[tr[j][p]]*size[i][j]}{alphabet}\ \ ,\ \ size[i+1][tr[j][p]]+=\frac{size[i][j]}{alphabet} \]

答案很显然是\(\large\sum_{i=0}^{tnt}f[len][i]\)


优化完之后,可以得到\(70pts\)的好成绩,剩下三个点显然现在的做法已经无能为力了


4.矩阵优化

这个\(dp\)的复杂度明显很爆炸(\(O(len*tnt*alphabet)\)),并且这个转移就很像矩阵优化(可以参考集训队大佬论文学习\(link\))。
我们直接填一下系数(这个东西填起来不难,但是需要动一下脑子),搞一个矩阵快速幂优化,复杂度就变成\(O(log(len)*(2*tnt)^3)\)
值得一提的是,我直接把\(f_{0~tnt}\)\(size_{0~tnt}\) 压到一个矩阵里面跑,这可以解释复杂度中\((2*tnt)^3\)

5.代码实现

(放这么丑的代码真是对不起大家的眼睛啦)

#include <cstring>
#include <algorithm>
#include <cstdio>
#include <queue>
#define mp make_pair
#define R register int
#define int long 
#define printf Ruusupuu = printf
#define scanf Ruusupuu = scanf

int Ruusupuu ;

using namespace std ;
typedef long long L ;
typedef long double D ;
typedef unsigned long long G ;
typedef pair< int , int > PI ;
const int N = 2e2 + 10 ;

inline int read(){
	int w = 0 ; bool fg = 0 ; char ch = getchar() ;
	while( ch < '0' || ch > '9' ) fg |= ( ch == '-' ) , ch = getchar() ;
	while( ch >= '0' && ch <= '9' ) w = ( w << 1 ) + ( w << 3 ) + ( ch ^ '0' ) , ch = getchar() ;
	return fg ? -w : w ; 
}

inline D qpow( D a , int b ){
	D ans = 1.0 ; 
	while( b ){
		if( b & 1 ) ans *= a ;
		a *= a ; b >>= 1 ;
	} return ans ;
}

int n , l , abt ;
int tnt , tr [N][27] , fail [N] ; 
char s [N] ; 
D g [N][N] ,  f [N] , ends [N] , c [N][N] , q [N] ;

inline void ins(){
	int x = 0 , len = strlen( s + 1 ) ;
	for( R i = 1 ; i <= len ; i ++ ){
		int ch = s [i] - 'a' ;
		if( !tr [x][ch] ) tr [x][ch] = ++ tnt ;
			x = tr [x][ch] ;
	} ends [x] = 1.0 ; 
}

void build(){
	queue< int > q ; 
	for( R i = 0 ; i < abt ; i ++ ) if( tr [0][i] ) q.push( tr [0][i] ) ;
	
	while( !q.empty() ){
		int x = q.front() ; q.pop() ;
		for( R i = 0 ; i < abt ; i ++ )
			if( tr [x][i] && !ends [tr [x][i]] ) fail [tr [x][i]] = tr [fail [x]][i] , q.push( tr [x][i] ) ;
			else if( tr [x][i] && ends [tr [x][i]] ) q.push( tr [x][i] ) ;
			else tr [x][i] = tr [fail [x]][i] ; 
	} tnt ++ ;
	
	for( R i = 0 ; i < tnt ; i ++ )
		for( R j = 0 ; j < abt ; j ++ ){
			g [i][tr [i][j]] += 1.0 / (D) abt ;
			g [i + tnt][tr [i][j]] += ends [tr [i][j]] / (D) abt ;
			g [i + tnt][tr [i][j] + tnt] += 1.0 / (D) abt ;
		} 
	
}

void sc(){
	n = read() , l = read() , abt = read() ;
	for( R i = 1 ; i <= n ; i ++ ) 
		scanf( "%s" , s + 1 ) , ins() ; 
}

inline void mul( D a [N] , D b [N][N] ){
	for( R i = 0 ; i < N ; i ++ ) q [i] = 0.0 ;

	for( R i = 0 ; i < 2 * tnt ; i ++ )
		for( R j = 0 ; j < 2 * tnt ; j ++ )
			q [i] += ( a [j] * b [j][i] ) ;

	memcpy( a , q , sizeof( q ) ) ;
} 

inline void mulself( D a [N][N] ){
	for( R i = 0 ; i < N ; i ++ ) for( R j = 0 ; j < N ; j ++ ) c [i][j] = 0.0 ;
	
	for( R i = 0 ; i < 2 * tnt ; i ++ )
		for( R j = 0 ; j < 2 * tnt ; j ++ )
			for( R k = 0 ; k < 2 * tnt ; k ++ ) 
				c [i][j] += ( a [i][k] * a [k][j] ) ;
			
	memcpy( a , c , sizeof( c ) ) ;
}

void work(){	
	build() , f [tnt] = 1.0 ;

	for( ; l ; l >>= 1 ){
		if( l & 1 ) mul( f , g ) ;
		mulself( g ) ;
	}
	
	D ans = 0 ;
	for( R i = 0 ; i < tnt ; i ++ ) 
		ans += f [i] ;
	
	printf( "%.7Lf\n" , ans ) ;
}

signed main(){	
	sc() ;
	work() ;
	return 0 ;
} 
posted @ 2021-06-30 18:52  Soresen  阅读(81)  评论(2)    收藏  举报