在AC自动机上dp
通常AC自动机上的dp至少是两维的,第一维是字符串长度,第二维是AC自动机上的节点数,dp[i][j]表示长度为i的字符串在自动机上匹配到j节点。在进行转移时,选定一个已经匹配到的节点,去更新它可以到达的节点的状态。
以洛谷P3041为例,在这一题中,先将所有的组合技插入到AC自动机中。当匹配到j节点上时,对于每一个可以匹配到的节点k,可以将匹配到k的最大值更新为匹配到j的最大值+匹配到k节点的得分,即转移方程为:
dp[i+1][k]=max(dp[i+1][k],dp[i][j]+score[k])
其中score[k]是匹配到k节点上可以获得的得分。每个节点的score也很好得出:每个组合技的结尾节点肯定能得到一分,同样,在fail树上位于它子树中的所有节点也肯定能拿到一分(因为这些节点所代表的字符串都以该组合技为后缀,匹配到它们时也就匹配到了该组合技)。由此只需要在AC自动机的fail树上dfs一次,把每个节点的得分沿fail树向下累加就可以了。
AC代码
#include <iostream>
#include <algorithm>
#include <string>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <cctype>
#include <functional>
using namespace std;
typedef long long ll;
const int MAXN = 3e2+5; // capacity of the per-node arrays below
const int INF = 1e9+7; // main() uses -INF to mark unreachable dp states
const int MOD = 1e4+7; // not referenced anywhere else in this program
const int TRIE_MAX = 3; // alphabet size (AC_insert maps a character to an index via c - 'A')
int AC_trie[MAXN][TRIE_MAX]; // trie child table; 0 means "no child" until AC_getfail fills in fallback transitions
int AC_trie_end[MAXN]; // score of a node: number of words ending there; AC_fail_dfs adds the scores of its fail-tree ancestors
int AC_trie_pos; // index of the most recently allocated trie node (node 0 is the root)
int AC_fail[MAXN]; // failure links
vector<int> AC_fail_tree[MAXN]; // fail tree: AC_fail_tree[v] lists the nodes whose failure link is v
// Adds the NUL-terminated word p to the trie, allocating nodes on demand,
// and bumps the score counter of the node where the word ends.
// Characters are mapped to child slots with (ch - 'A').
void AC_insert(char *p){
    int node = 0;
    for(const char *it = p; *it != '\0'; ++it){
        const int slot = *it - 'A';
        int &next = AC_trie[node][slot];
        if(next == 0) next = ++AC_trie_pos; // first time this edge is used
        node = next;
    }
    ++AC_trie_end[node];
}
void AC_getfail(){ //构建失配指针
AC_fail[0] = 0;
queue<int> q;
for(int i=0;i<TRIE_MAX;i++){
if(AC_trie[0][i]){
AC_fail[AC_trie[0][i]] = 0;
AC_fail_tree[0].push_back(AC_trie[0][i]);
q.push(AC_trie[0][i]);
}
}
while(!q.empty()){
int k = q.front(); q.pop();
for(int i=0;i<TRIE_MAX;i++){
if(AC_trie[k][i]){
AC_fail[AC_trie[k][i]] = AC_trie[AC_fail[k]][i];
AC_fail_tree[AC_trie[AC_fail[k]][i]].push_back(AC_trie[k][i]);
q.push(AC_trie[k][i]);
}
else AC_trie[k][i] = AC_trie[AC_fail[k]][i];
}
}
}
void AC_fail_dfs(int k,int p){ //对fail树树上差分,获取每个节点的得分
AC_trie_end[k] += p;
for (int i = 0; i < AC_fail_tree[k].size();i++){
AC_fail_dfs(AC_fail_tree[k][i],AC_trie_end[k]);
}
}
void AC_init(){ //初始化
AC_trie_pos = 0;
memset(AC_trie,0,sizeof(AC_trie));
memset(AC_trie_end,0,sizeof(AC_trie_end));
for(int i=0;i<MAXN;i++) AC_fail_tree[i].clear();
}
int dp[1005][MAXN]; // dp[i][j]: best score over strings of length i that end at automaton node j
char s[MAXN]; // input buffer for one word read by main()
// Reads test cases of the form "n len" followed by n words, builds the
// automaton, and prints the maximum score obtainable by a string of length
// len, where dp[i][j] is the best score of a length-i string ending at node j.
int main(){
    int n,len;
    // Stop on EOF and also on a malformed header (scanf returning 0 or 1).
    while(scanf("%d %d",&n,&len)==2){
        AC_init();
        for(int i=1;i<=n;i++){
            scanf("%s",s);
            AC_insert(s);
        }
        AC_getfail();
        AC_fail_dfs(0,0);
        // Reset rows 0..len inclusive. Row len is written by the last
        // transition step and read for the answer, so it must not keep values
        // left over from an earlier test case.
        for(int i=0;i<=len;i++){
            for(int j=0;j<=AC_trie_pos;j++){
                dp[i][j]=-INF;
            }
        }
        dp[0][0]=0;
        for(int i=0;i<len;i++){
            for(int j=0;j<=AC_trie_pos;j++){
                // Node scores are non-negative counts, so every reachable
                // state is >= 0; anything negative is still the sentinel.
                if(dp[i][j]<0) continue;
                for(int k=0;k<TRIE_MAX;k++){
                    const int to = AC_trie[j][k];
                    dp[i+1][to]=max(dp[i+1][to],dp[i][j]+AC_trie_end[to]);
                }
            }
        }
        int ans = 0;
        for(int j=0;j<=AC_trie_pos;j++){
            ans=max(dp[len][j],ans);
        }
        printf("%d\n",ans);
    }
    return 0;
}
再比如HDU2825,这道题的m并不大,我们可以通过状压来表示当前已经匹配到的词。因此我们额外再开一维flag,dp[i][j][flag]表示长度为i的字符串匹配到j节点上,匹配词集为flag时的方案数。推出转移方程为:
dp[i + 1][to][flag|flag[to]] += dp[i][j][flag];
其中to代表要更新的节点,flag[to]则代表字符串匹配到to所包含的词集。这个flag数组,我们可以通过合并fail指针指向节点的词集来获得。即:
flag[k] = flag[k] | flag[AC_fail[k]];
AC代码
#include <iostream>
#include <algorithm>
#include <string>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <cctype>
#include <functional>
using namespace std;
typedef long long ll;
const int MAXN = 1e4 + 5; // capacity of the per-node arrays and of the input buffer
const int INF = 1e9 + 7;
const int MOD = 20090717; // all counts in dp are kept modulo this value
const int TRIE_MAX = 26; // alphabet size (AC_insert maps a character to an index via c - 'a')
int AC_trie[MAXN][TRIE_MAX]; // trie child table; 0 means "no child" until AC_getfail fills in fallback transitions
int AC_trie_end[MAXN]; // only cleared by AC_init in this program (the assignment in AC_insert is commented out)
int AC_trie_pos; // index of the most recently allocated trie node (node 0 is the root)
int AC_fail[MAXN]; // failure links
int flag[MAXN]; // flag[v]: bitmask of words matched at node v; AC_getfail ORs in the mask of the failure target
int tot; // number of words inserted so far; also the bit index given to the next word
// Adds the NUL-terminated lowercase word p to the trie and marks its end
// node with the bit of the current word index (tot), then advances tot so
// the next word receives the next bit.
void AC_insert(char *p) {
    int node = 0;
    for (const char *it = p; *it != '\0'; ++it) {
        const int slot = *it - 'a';
        if (AC_trie[node][slot] == 0) AC_trie[node][slot] = ++AC_trie_pos;
        node = AC_trie[node][slot];
    }
    flag[node] |= 1 << tot;
    tot++;
}
void AC_getfail() { //构建失配指针
AC_fail[0] = 0;
queue<int> q;
for (int i = 0; i < TRIE_MAX; i++) {
if (AC_trie[0][i]) {
AC_fail[AC_trie[0][i]] = 0;
//AC_fail_tree[0].push_back(AC_trie[0][i]);
q.push(AC_trie[0][i]);
}
}
while (!q.empty()) {
int k = q.front(); q.pop();
for (int i = 0; i < TRIE_MAX; i++) {
if (AC_trie[k][i]) {
AC_fail[AC_trie[k][i]] = AC_trie[AC_fail[k]][i];
//AC_fail_tree[AC_trie[AC_fail[k]][i]].push_back(AC_trie[k][i]);
q.push(AC_trie[k][i]);
}
else AC_trie[k][i] = AC_trie[AC_fail[k]][i];
}
flag[k] = flag[k] | flag[AC_fail[k]];
}
}
void AC_init() { //初始化
AC_trie_pos = 0;
memset(AC_trie, 0, sizeof(AC_trie));
memset(AC_trie_end, 0, sizeof(AC_trie_end));
memset(flag, 0, sizeof(flag));
}
char cs[MAXN]; // input buffer for one word read by main()
// dp[i][j][k]: number (mod MOD) of length-i strings that end at node j
// having matched exactly the word set k (bitmask).
// NOTE(review): this is 35 * 1005 * 2205 = 77,560,875 ints (about 310 MB with
// 4-byte int). main() only indexes i <= n, j <= AC_trie_pos and k <= (1 << m);
// confirm the problem limits and shrink the dimensions accordingly.
int dp[35][1005][2205];
// For each test case "n m tk" (terminated by "0 0 0"), reads m words and
// prints the number (mod MOD) of length-n lowercase strings that contain at
// least tk of them. dp[i][j][mask] counts length-i strings ending at node j
// whose matched word set is mask.
int main() {
    int n, m, tk;
    // Require all three header fields; stop on EOF, bad input or "0 0 0".
    while (scanf("%d %d %d", &n, &m, &tk) == 3 && (n || m || tk)) {
        AC_init();
        tot = 0;
        for (int i = 1; i <= m; i++) {
            scanf("%s", cs);
            AC_insert(cs);
        }
        AC_getfail();
        // Valid word sets are 0 .. (1 << m) - 1, so every mask loop below
        // is exclusive of mp.
        const int mp = 1 << m;
        for (int i = 0; i <= n; i++) {
            for (int j = 0; j <= AC_trie_pos; j++) {
                for (int k = 0; k < mp; k++) {
                    dp[i][j][k] = 0;
                }
            }
        }
        dp[0][0][0] = 1;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j <= AC_trie_pos; j++) {
                for (int k = 0; k < mp; k++) {
                    if (!dp[i][j][k]) continue; // nothing to propagate
                    for (int p = 0; p < TRIE_MAX; p++) {
                        const int to = AC_trie[j][p];
                        const int f = k | flag[to];
                        dp[i + 1][to][f] = (dp[i + 1][to][f] + dp[i][j][k]) % MOD;
                    }
                }
            }
        }
        int sum = 0;
        for (int k = 0; k < mp; k++) {
            // The popcount depends only on the mask, so compute it once
            // instead of once per node.
            int bits = 0;
            for (int rest = k; rest; rest >>= 1) bits += rest & 1;
            if (bits < tk) continue;
            for (int j = 0; j <= AC_trie_pos; j++) {
                sum = (sum + dp[n][j][k]) % MOD;
            }
        }
        printf("%d\n", sum);
    }
    return 0;
}
一些AC自动机上的dp可能数据较大,这时候需要用矩阵加速dp。以2021新疆省赛A题为例,容易看出,这道题的dp式子与洛谷P3041相同,但字符串的长度最高可达1e9,逐位转移显然会TLE。由于转移只用到max和加法,可以把一步转移写成(max,+)意义下的矩阵乘法,再用矩阵快速幂加速dp,就可以轻易AC了。
AC代码
#include <iostream>
#include <algorithm>
#include <string>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <cctype>
#include <functional>
using namespace std;
typedef long long ll;
// NOTE(review): AC_trie below holds MAXN * TRIE_MAX ints and AC_init clears
// it for every test case, while main() builds a matrix limited to
// MATRIX_MAXN rows; presumably MAXN is far larger than needed. Confirm
// against the problem limits.
const int MAXN = 1e6 + 5;
// NOTE(review): 1e18 + 7 is evaluated as a double before conversion; with
// IEEE-754 doubles the +7 is below the available precision, so the stored
// value is likely exactly 1e18. Confirm if the exact value ever matters.
const ll INF = 1e18 + 7; // main() starts its answer at -2 * INF
const ll MOD = 1e18 + 7; // only appears in a commented-out line of matrix::operator*
const int TRIE_MAX = 26; // alphabet size (AC_insert maps a character to an index via c - 'a')
int AC_trie[MAXN][TRIE_MAX]; // trie child table; 0 means "no child" until AC_getfail fills in fallback transitions
ll AC_trie_end[MAXN]; // total weight of the words ending at a node; AC_getfail adds the total of the failure target
int AC_trie_pos; // index of the most recently allocated trie node (node 0 is the root)
int AC_fail[MAXN]; // failure links
// Adds the NUL-terminated lowercase word p to the trie and adds its weight
// kp to the node where the word ends (weights of repeated words add up).
void AC_insert(char *p,int kp) {
    const int len = strlen(p);
    int cur = 0;
    int i = 0;
    while (i < len) {
        int *row = AC_trie[cur];
        const int slot = p[i++] - 'a';
        if (row[slot] == 0) row[slot] = ++AC_trie_pos; // allocate a new node
        cur = row[slot];
    }
    AC_trie_end[cur] += kp;
}
void AC_getfail() { //构建失配指针
AC_fail[0] = 0;
queue<int> q;
for (int i = 0; i < TRIE_MAX; i++) {
if (AC_trie[0][i]) {
AC_fail[AC_trie[0][i]] = 0;
//AC_fail_tree[0].push_back(AC_trie[0][i]);
q.push(AC_trie[0][i]);
}
}
while (!q.empty()) {
int k = q.front(); q.pop();
for (int i = 0; i < TRIE_MAX; i++) {
if (AC_trie[k][i]) {
AC_fail[AC_trie[k][i]] = AC_trie[AC_fail[k]][i];
//AC_fail_tree[AC_trie[AC_fail[k]][i]].push_back(AC_trie[k][i]);
q.push(AC_trie[k][i]);
}
else AC_trie[k][i] = AC_trie[AC_fail[k]][i];
}
AC_trie_end[k] += AC_trie_end[AC_fail[k]];
}
}
// Prepares the automaton for a new test case: zeroes the per-node weights
// and the whole transition table, then rewinds the node counter so that
// only the root (node 0) exists.
void AC_init() {
    memset(AC_trie_end, 0, sizeof AC_trie_end);
    memset(AC_trie, 0, sizeof AC_trie);
    AC_trie_pos = 0;
}
const int MATRIX_MAXN = 2e2 + 5; // maximum supported matrix dimension
// Square matrix over the (max, +) semiring: "addition" is max and
// "multiplication" is +. Only the top-left n x n part of m is meaningful.
// Unreachable entries hold a very negative sentinel instead of a true
// minus infinity.
struct matrix {
    long long m[MATRIX_MAXN][MATRIX_MAXN];
    int n;
    // Fills every cell with the sentinel. memset writes single bytes, so each
    // cell becomes 0xC1C1C1C1C1C1C1C1 (about -4.49e18). Two sentinels can be
    // added without overflowing a 64-bit signed integer.
    void fill_unreachable() {
        std::memset(m, 0xC1, sizeof(m));
    }
    // Empty (dimension 0) matrix.
    matrix() : n(0) {
        fill_unreachable();
    }
    // n x n matrix with every entry unreachable.
    matrix(int n) : n(n) {
        fill_unreachable();
    }
    // n x n matrix; when p is true it is the (max, +) identity, i.e. 0 on the
    // diagonal (the neutral element of +) and unreachable elsewhere.
    matrix(int n, bool p) : n(n) {
        fill_unreachable();
        if (p) {
            for (int i = 0; i < n; i++)
                m[i][i] = 0;
        }
    }
    // (max, +) product: result[i][j] = max over k of (m[i][k] + p.m[k][j]).
    matrix operator * (const matrix &p) const {
        matrix ret(n);
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++) {
                for (int k = 0; k < n; k++)
                    ret.m[i][j] = std::max(ret.m[i][j], m[i][k] + p.m[k][j]);
            }
        return ret;
    }
    // Debug helper: prints the n x n part, one row per line, space separated.
    void print() {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (j) std::printf(" ");
                std::printf("%lld", m[i][j]); // entries are long long
            }
            std::printf("\n");
        }
    }
};
// Fast exponentiation in the (max, +) semiring.
// Because the accumulator starts at base rather than at the identity, the
// result is base^(k+1), not base^k; main() compensates by passing n - 1.
// A negative k used to loop forever (shifting a negative int right never
// reaches 0). It now returns the (max, +) identity, which is the correct
// base^0 for k == -1 and a safe clamp for smaller values.
matrix MAT_pow(matrix base, int k) {
    if (k < 0) {
        matrix identity(base.n);
        for (int i = 0; i < base.n; i++)
            identity.m[i][i] = 0; // neutral element of +
        return identity;
    }
    matrix ret = base;
    while (k > 0) {
        if (k & 1) ret = ret * base;
        base = base * base;
        k = k >> 1;
    }
    return ret;
}
char cs[205]; // input buffer for one word read by main()
// For each test case "n m" followed by m lines "word weight", prints the
// maximum total weight of a length-n lowercase string. One automaton step is
// a (max, +) matrix z with z[j][to] = weight gained on entering node `to`;
// the answer is the largest entry in row 0 (the root) of z^n.
int main() {
    int n, m;
    // Stop on EOF and also on a malformed header (scanf returning 0 or 1).
    while (scanf("%d %d", &n, &m) == 2) {
        AC_init();
        for (int i = 1; i <= m; i++) {
            int k;
            scanf("%s %d", cs, &k);
            AC_insert(cs, k);
        }
        // A string of length 0 matches nothing. Answer directly instead of
        // asking MAT_pow for a negative exponent.
        if (n <= 0) {
            printf("0\n");
            continue;
        }
        AC_getfail();
        // Assumes AC_trie_pos + 1 <= MATRIX_MAXN, since z.m is indexed by
        // node number below.
        matrix z(AC_trie_pos+1);
        for (int j = 0; j <= AC_trie_pos; j++) {
            for (int p = 0; p < TRIE_MAX; p++) {
                const int to = AC_trie[j][p];
                z.m[j][to] = AC_trie_end[to];
            }
        }
        // MAT_pow(base, k) yields base^(k+1), so n - 1 gives z^n.
        z = MAT_pow(z, n - 1);
        ll ans = -2 * INF;
        for (int j = 0; j <= AC_trie_pos; j++) {
            ans = max(ans, z.m[0][j]);
        }
        printf("%lld\n", ans);
    }
    return 0;
}

浙公网安备 33010602011771号