在AC自动机上dp

通常AC自动机上的dp至少是两维的,第一维是字符串长度,第二维是AC自动机上的节点编号,dp[i][j]表示长度为i的字符串在自动机上匹配到j节点时的状态值(例如最大得分或方案数)。在进行转移时,选定一个已经匹配到的节点,去更新它可以到达的节点的状态。

 

以洛谷P3041为例,在这一题中,先将所有的组合技插入到AC自动机中。当匹配到j节点上时,对于每一个从j再读入一个字符后可以到达的节点k,可以将匹配到k的最大值更新为匹配到j的最大值+匹配到k节点的得分,即转移方程为:

dp[i+1][k]=max(dp[i+1][k],dp[i][j]+score[k])

其中score[k]是匹配到k节点上可以获得的得分。每个节点的score也很好得出,因为每个组合技的结尾肯定能得到一分,同样,在它fail树子树中的所有节点也肯定能拿到这一分(因为这些节点所代表的字符串以该组合技为后缀,即匹配到它们时也就匹配到了该组合技),由此只需要在AC自动机的fail树上dfs一次,把祖先的得分累加到子孙上就可以了。

AC代码
#include <iostream>
#include <algorithm>
#include <string>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <cctype>
#include <functional>
using namespace std;

// 64-bit integer alias.
typedef long long ll;

const int MAXN = 3e2+5; // capacity of the per-node arrays below
const int INF = 1e9+7; // main() stores -INF in dp for states that have not been reached
const int MOD = 1e4+7; // not referenced by the code of this program

const int TRIE_MAX = 3; // alphabet size (AC_insert maps a letter to p[i] - 'A')

int AC_trie[MAXN][TRIE_MAX]; // trie edges; AC_getfail() replaces missing edges by fallback targets
int AC_trie_end[MAXN]; // score of each node
int AC_trie_pos; // index of the most recently allocated node (the root is node 0)
int AC_fail[MAXN]; // failure links
vector<int> AC_fail_tree[MAXN]; // fail tree: for each node, the nodes whose failure link points to it

// Adds the NUL-terminated word p to the trie, allocating nodes on demand,
// and bumps the score counter of the node where the word ends.
// Letters are mapped to edge indices relative to 'A'.
void AC_insert(char *p){
	int node = 0; // walk starts at the root
	for(const char *cur = p; *cur != '\0'; ++cur){
		int &edge = AC_trie[node][*cur - 'A'];
		if(edge == 0){
			// No child for this letter yet: take the next free node index.
			edge = ++AC_trie_pos;
		}
		node = edge;
	}
	AC_trie_end[node] += 1;
}

void AC_getfail(){ //构建失配指针
	AC_fail[0] = 0;
	queue<int> q;
	for(int i=0;i<TRIE_MAX;i++){
		if(AC_trie[0][i]){
			AC_fail[AC_trie[0][i]] = 0;
			AC_fail_tree[0].push_back(AC_trie[0][i]);
			q.push(AC_trie[0][i]);
		}
	}
	while(!q.empty()){
		int k = q.front(); q.pop();
		for(int i=0;i<TRIE_MAX;i++){
			if(AC_trie[k][i]){
				AC_fail[AC_trie[k][i]] = AC_trie[AC_fail[k]][i];
				AC_fail_tree[AC_trie[AC_fail[k]][i]].push_back(AC_trie[k][i]);
				q.push(AC_trie[k][i]);
			}
			else AC_trie[k][i] = AC_trie[AC_fail[k]][i];
		}
	}
}

void AC_fail_dfs(int k,int p){ //对fail树树上差分,获取每个节点的得分
	AC_trie_end[k] += p;
	for (int i = 0; i < AC_fail_tree[k].size();i++){
		AC_fail_dfs(AC_fail_tree[k][i],AC_trie_end[k]);
	}
}

void AC_init(){ //初始化
	AC_trie_pos = 0;
	memset(AC_trie,0,sizeof(AC_trie));
	memset(AC_trie_end,0,sizeof(AC_trie_end));
	for(int i=0;i<MAXN;i++) AC_fail_tree[i].clear();
}

// dp[i][j]: value main() maximises over strings of length i that end at automaton node j.
int dp[1005][MAXN];
// Input buffer main() reads each pattern word into.
char s[MAXN];

// Reads test cases until the input ends. Each case gives the number of
// pattern words n and the target string length len, followed by the n words.
// Prints the maximum total score of a string of length len, where every step
// earns the score of the automaton node that is reached.
int main(){
	int n,len;
	// Require both numbers to be parsed; this stops on EOF and on malformed
	// input instead of looping on a matching failure.
	while(scanf("%d %d",&n,&len)==2){
		AC_init();
		for(int i=1;i<=n;i++){
			scanf("%s",s);
			AC_insert(s);
		}
		AC_getfail();
		AC_fail_dfs(0,0);
		// Rows 0..len are all used (row len holds the answer), so reset every
		// one of them; otherwise values from a previous test case leak in.
		for(int i=0;i<=len;i++){
			for(int j=0;j<=AC_trie_pos;j++){
				dp[i][j]=-INF; // -INF marks a state that has not been reached
			}
		}
		dp[0][0]=0;
		for(int i=0;i<len;i++){
			for(int j=0;j<=AC_trie_pos;j++){
				if(dp[i][j]==-INF) continue; // nothing to extend from here
				for(int k=0;k<TRIE_MAX;k++){
					const int to=AC_trie[j][k];
					dp[i+1][to]=max(dp[i+1][to],dp[i][j]+AC_trie_end[to]);
				}
			}
		}
		int ans = 0;
		for(int j=0;j<=AC_trie_pos;j++){
			ans=max(dp[len][j],ans);
		}
		printf("%d\n",ans);
	}
	return 0;
}

 

再比如HDU2825,这道题的m并不大,我们可以通过状压来表示当前已经匹配到的词。因此我们额外再开一维flag,dp[i][j][flag]表示长度为i的字符串匹配到j节点上,匹配词集为flag时的方案数。推出转移方程为:

dp[i + 1][to][flag|flag[to]] += dp[i][j][flag];

其中to代表要更新的节点,flag[to]则代表字符串匹配到to所包含的词集。这个flag数组,我们可以通过合并fail指针指向节点的词集来获得。即:

flag[k] = flag[k] | flag[AC_fail[k]];
AC代码
#include <iostream>
#include <algorithm>
#include <string>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <cctype>
#include <functional>
using namespace std;

// 64-bit integer alias.
typedef long long ll;

const int MAXN = 1e4 + 5; // capacity of the per-node arrays below
const int INF = 1e9 + 7; // not referenced by the code of this program
const int MOD = 20090717; // modulus applied to every count in main()

const int TRIE_MAX = 26; // alphabet size (AC_insert maps a letter to p[i] - 'a')

int AC_trie[MAXN][TRIE_MAX]; // trie edges; AC_getfail() replaces missing edges by fallback targets
int AC_trie_end[MAXN]; // per-node word counter; in this program it is only cleared by AC_init()
int AC_trie_pos; // index of the most recently allocated node (the root is node 0)
int AC_fail[MAXN]; // failure links

// flag[v]: bitmask of the inserted words recognised at node v
// (bit t belongs to the t-th inserted word; AC_getfail() ORs in the failure target's mask).
int flag[MAXN];
// Number of words inserted so far; also the bit index given to the next word.
int tot;

// Adds the NUL-terminated lowercase word p to the trie and marks its end
// node with the next unused bit of the word bitmask, then advances tot.
void AC_insert(char *p) {
	int node = 0; // walk starts at the root
	for (const char *cur = p; *cur != '\0'; ++cur) {
		int &edge = AC_trie[node][*cur - 'a'];
		if (edge == 0) {
			// No child for this letter yet: take the next free node index.
			edge = ++AC_trie_pos;
		}
		node = edge;
	}
	flag[node] |= (1 << tot);
	tot += 1;
}

void AC_getfail() { //构建失配指针
	AC_fail[0] = 0;
	queue<int> q;
	for (int i = 0; i < TRIE_MAX; i++) {
		if (AC_trie[0][i]) {
			AC_fail[AC_trie[0][i]] = 0;
			//AC_fail_tree[0].push_back(AC_trie[0][i]);
			q.push(AC_trie[0][i]);
		}
	}
	while (!q.empty()) {
		int k = q.front(); q.pop();
		for (int i = 0; i < TRIE_MAX; i++) {
			if (AC_trie[k][i]) {
				AC_fail[AC_trie[k][i]] = AC_trie[AC_fail[k]][i];
				//AC_fail_tree[AC_trie[AC_fail[k]][i]].push_back(AC_trie[k][i]);
				q.push(AC_trie[k][i]);
			}
			else AC_trie[k][i] = AC_trie[AC_fail[k]][i];
		}
		flag[k] = flag[k] | flag[AC_fail[k]];
	}
}

void AC_init() { //初始化
	AC_trie_pos = 0;
	memset(AC_trie, 0, sizeof(AC_trie));
	memset(AC_trie_end, 0, sizeof(AC_trie_end));

	memset(flag, 0, sizeof(flag));
}

// Input buffer main() reads each word into.
char cs[MAXN];
// dp[i][j][k]: number (mod MOD) of strings of length i that end at automaton
// node j and whose set of matched words is the bitmask k.
int dp[35][1005][2205];

// Reads test cases "n m tk" until an all-zero line or the end of input:
// n is the string length, m the number of words that follow, tk the minimum
// number of distinct words a string must contain. Prints, modulo MOD, the
// number of lowercase strings of length n containing at least tk of the words.
int main() {
	int n, m, tk;
	// Require all three numbers to be parsed; this stops on EOF and on
	// malformed input instead of looping on a matching failure.
	while (scanf("%d %d %d", &n, &m, &tk) == 3 && (n || m || tk)) {
		AC_init();
		tot = 0;
		for (int i = 1; i <= m; i++) {
			scanf("%s", cs);
			AC_insert(cs);
		}
		AC_getfail();
		const int mp = 1 << m; // number of word subsets; valid masks are 0..mp-1
		for (int i = 0; i <= n; i++) {
			for (int j = 0; j <= AC_trie_pos; j++) {
				for (int k = 0; k < mp; k++) {
					dp[i][j][k] = 0;
				}
			}
		}
		dp[0][0][0] = 1; // the empty string sits at the root with no word matched
		for (int i = 0; i < n; i++) {
			for (int j = 0; j <= AC_trie_pos; j++) {
				for (int k = 0; k < mp; k++) {
					const int ways = dp[i][j][k];
					if (!ways) continue;
					for (int p = 0; p < TRIE_MAX; p++) {
						const int to = AC_trie[j][p];
						const int f = k | flag[to];
						// Both operands are below MOD, so the sum fits in int.
						dp[i + 1][to][f] = (dp[i + 1][to][f] + ways) % MOD;
					}
				}
			}
		}
		int sum = 0;
		for (int k = 0; k < mp; k++) {
			// Count the words in this subset once per mask.
			int bits = 0;
			for (int rest = k; rest; rest >>= 1) bits += rest & 1;
			if (bits < tk) continue;
			for (int j = 0; j <= AC_trie_pos; j++) {
				sum = (sum + dp[n][j][k]) % MOD;
			}
		}
		printf("%d\n", sum);
	}
	return 0;
}

 

一些AC自动机上的dp可能数据较大,这时候需要用矩阵加速dp。以2021新疆省赛A题为例,容易看出,这道题的dp式子与洛谷P3041相同,但字符串的长度最高可达1e9,逐位转移显然会TLE。由于转移只与节点有关、与当前长度无关,可以把一步转移写成(max,+)意义下的矩阵(即把普通矩阵乘法中的乘法换成加法、加法换成取max),再用矩阵快速幂加速dp,就可以轻易AC了。

AC代码
#include <iostream>
#include <algorithm>
#include <string>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <cctype>
#include <functional>
using namespace std;

// 64-bit integer alias.
typedef long long ll;

const int MAXN = 1e6 + 5; // capacity of the per-node arrays below
// NOTE: 1e18 + 7 is evaluated in double precision before the conversion to ll,
// so the "+ 7" is lost to rounding and the stored value is 10^18.
const ll INF = 1e18 + 7; // main() starts its running maximum at -2 * INF
const ll MOD = 1e18 + 7; // only mentioned in a commented-out line of this program

const int TRIE_MAX = 26; // alphabet size (AC_insert maps a letter to p[i] - 'a')

int AC_trie[MAXN][TRIE_MAX]; // trie edges; AC_getfail() replaces missing edges by fallback targets
ll AC_trie_end[MAXN]; // value of each node: AC_insert() adds kp, AC_getfail() adds the failure target's value
int AC_trie_pos; // index of the most recently allocated node (the root is node 0)
int AC_fail[MAXN]; // failure links

// Adds the NUL-terminated lowercase word p to the trie and adds kp to the
// value stored at the node where the word ends.
void AC_insert(char *p,int kp) {
	int node = 0; // walk starts at the root
	for (const char *cur = p; *cur != '\0'; ++cur) {
		int &edge = AC_trie[node][*cur - 'a'];
		if (edge == 0) {
			// No child for this letter yet: take the next free node index.
			edge = ++AC_trie_pos;
		}
		node = edge;
	}
	AC_trie_end[node] += kp;
}

void AC_getfail() { //构建失配指针
	AC_fail[0] = 0;
	queue<int> q;
	for (int i = 0; i < TRIE_MAX; i++) {
		if (AC_trie[0][i]) {
			AC_fail[AC_trie[0][i]] = 0;
			//AC_fail_tree[0].push_back(AC_trie[0][i]);
			q.push(AC_trie[0][i]);
		}
	}
	while (!q.empty()) {
		int k = q.front(); q.pop();
		for (int i = 0; i < TRIE_MAX; i++) {
			if (AC_trie[k][i]) {
				AC_fail[AC_trie[k][i]] = AC_trie[AC_fail[k]][i];
				//AC_fail_tree[AC_trie[AC_fail[k]][i]].push_back(AC_trie[k][i]);
				q.push(AC_trie[k][i]);
			}
			else AC_trie[k][i] = AC_trie[AC_fail[k]][i];
		}
		AC_trie_end[k] += AC_trie_end[AC_fail[k]];
	}
}

// Resets the automaton to an empty trie.
//
// Only nodes 0..AC_trie_pos are cleared. In the code of this program,
// AC_insert() and AC_getfail() write AC_trie / AC_trie_end only for nodes
// that have been allocated, and the arrays start out zero-initialised, so
// every other row is still zero. This avoids wiping the whole
// MAXN x TRIE_MAX table (roughly 100 MB) on every test case.
void AC_init() {
	for (int node = 0; node <= AC_trie_pos; node++) {
		memset(AC_trie[node], 0, sizeof(AC_trie[node]));
		AC_trie_end[node] = 0;
	}
	AC_trie_pos = 0;
}

const int MATRIX_MAXN = 2e2 + 5; // largest supported matrix dimension

// Square matrix over the (max, +) semiring: the "product" of two matrices
// takes, for each cell, the maximum over k of a[i][k] + b[k][j].
// Only the top-left n x n part of m is meaningful.
struct matrix {
	long long m[MATRIX_MAXN][MATRIX_MAXN];
	int n;

	// Sentinel meaning "no path" (minus infinity). Two sentinels still add up
	// without overflowing a 64-bit integer, but operator* never adds them.
	static long long neg_inf() { return -0x3f3f3f3f3f3f3f3fLL; }

	// Empty (0 x 0) matrix with every cell set to minus infinity.
	matrix() : n(0) {
		fill_neg_inf();
	}
	// n x n matrix with every cell set to minus infinity.
	matrix(int n) : n(n) {
		fill_neg_inf();
	}
	// n x n matrix; when p is true it is the (max, +) identity, i.e. 0 on the
	// diagonal and minus infinity elsewhere, so that I * A == A.
	matrix(int n, bool p) : n(n) {
		fill_neg_inf();
		if (p) {
			for (int i = 0; i < n; i++)
				m[i][i] = 0;
		}
	}
	// (max, +) product of *this and p; both are treated as n x n with this->n.
	matrix operator * (const matrix &p) const {
		matrix ret(n);
		for (int i = 0; i < n; i++)
			for (int k = 0; k < n; k++) {
				// minus infinity plus anything stays minus infinity: skip it
				if (m[i][k] == neg_inf()) continue;
				for (int j = 0; j < n; j++) {
					if (p.m[k][j] == neg_inf()) continue;
					ret.m[i][j] = std::max(ret.m[i][j], m[i][k] + p.m[k][j]);
				}
			}
		return ret;
	}
	// Writes the n x n part to stdout, one row per line, cells separated by spaces.
	void print() const {
		for (int i = 0; i < n; i++) {
			for (int j = 0; j < n; j++) {
				if (j) printf(" ");
				printf("%lld", m[i][j]); // cells are long long
			}
			printf("\n");
		}
	}

private:
	// Sets every cell of the backing array to the minus-infinity sentinel.
	void fill_neg_inf() {
		for (int i = 0; i < MATRIX_MAXN; i++)
			for (int j = 0; j < MATRIX_MAXN; j++)
				m[i][j] = neg_inf();
	}
};

// Repeated squaring under matrix::operator*.
//
// The accumulator starts as base (not as an identity matrix), so the result
// is base multiplied by itself k more times, i.e. base^(k+1). Callers
// therefore pass "exponent - 1". For k <= 0 the function returns base
// unchanged; the loop condition is k > 0 because a negative k would never
// reach zero under arithmetic right shift.
matrix MAT_pow(matrix base, int k) {
	matrix ret = base;
	while (k > 0) {
		if (k & 1) ret = ret * base;
		k >>= 1;
		// Square only when another bit remains; the final squaring is never used.
		if (k > 0) base = base * base;
	}
	return ret;
}

// Input buffer main() reads each word into (at most 204 characters plus the terminator).
char cs[205];

// Reads test cases "n m" until the input ends: n is the string length and m
// the number of "word value" pairs that follow. Prints the maximum total
// value of a string of length n, where every step earns the value of the
// automaton node that is reached. The one-step transition is written as a
// (max, +) matrix and raised to the n-th power.
int main() {
	int n, m;
	// Require both numbers to be parsed; this stops on EOF and on malformed
	// input instead of looping on a matching failure.
	while (scanf("%d %d", &n, &m) == 2) {
		AC_init();
		for (int i = 1; i <= m; i++) {
			int k;
			scanf("%s %d", cs, &k);
			AC_insert(cs, k);
		}
		if (n <= 0) {
			// An empty string takes no step and earns nothing; it also must
			// not reach MAT_pow with a negative exponent.
			printf("0\n");
			continue;
		}
		AC_getfail();
		matrix z(AC_trie_pos+1);
		for (int j = 0; j <= AC_trie_pos; j++) {
			for (int p = 0; p < TRIE_MAX; p++) {
				const int to = AC_trie[j][p];
				// One step from j to `to` earns the value of `to`.
				z.m[j][to] = AC_trie_end[to];
			}
		}
		// MAT_pow starts from its argument, so exponent n - 1 yields z^n.
		z = MAT_pow(z, n - 1);
		ll ans = -2 * INF;
		// Row 0 holds the best totals of all walks that start at the root.
		for (int j = 0; j <= AC_trie_pos; j++) {
			ans = max(ans, z.m[0][j]);
		}
		printf("%lld\n", ans);
	}
	return 0;
}
posted @ 2021-08-24 15:12  樱与梅子  阅读(735)  评论(0)    收藏  举报