HDU3341 Lost's revenge(AC自动机&&dp)

一看到ACGT就会想起AC自动机上的dp,这种奇怪的联想可能是源于某道叫DNA什么的题的。

题意,给你很多个长度不大于10的小串,小串最多有50个,然后有一个长度<40的串,然后让你将这个这个长度<40的串经过重新排列之后,小串在里面出现的次数总和最大。譬如如果我的小串是AA,AAC,长串是CAAA,我们重新排列成AAAC之后,AA在里面出现了2次,AAC出现了1次,总和是3次,这个数字就是我们要求的。

思路:思路跟HDU4758 walk through squares很像的,首先对每个小串插Trie树,建自动机,然后要做一下预处理,对于每个状态预处理出到达该状态时匹配了多少个小串,方法就是沿着失配边将cnt加起来。然后对于每个状态,如果它不存在某个字母的后继,就沿着失配边走找到存在该字母的后继,这样预处理后,后面的状态转移起来就比较方便。然后定义状态dp[A][C][G][T][sta]表示已经匹配的A,C,G,T对应为A,C,G,T个,在自动机上的状态为sta时所能匹配到的最大的状态数。然后转移就好。

Trick的部分是,虽然A,C,G,T所能产生的状态数最大是11*11*11*11(即40平均分的时候的情况),但是因为有可能有些字母出现40次,所以开的时候要dp[41][41][41][41][550],想到这里我就不知道怎么写了- -0。后来发现其实可以先hash一下,对于sta[i][j][k][t]=用一个数字s代表其状态,然后开一个数组p[s][0~3]存的是该状态对应的A,C,G,T数,然后再转移就好。

不过貌似跑的有点慢,3s多,感觉挺容易TLE的。

#pragma warning(disable:4996)
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<cmath>
#include<iostream>
#include<queue>
#define maxn 1500
using namespace std;

char str[50][15];
char T[50];
int n;

void convert(char *s){
	int len = strlen(s);
	for (int i = 0; i < len; i++){
		if (s[i] == 'A') s[i] = 'a';
		else if (s[i] == 'C') s[i] = 'b';
		else if (s[i] == 'G') s[i] = 'c';
		else s[i] = 'd';
	}
}

struct Trie{
	Trie *fail, *go[4];
	int cnt; bool flag;
	void init(){
		memset(go, 0, sizeof(go)); fail = NULL; cnt = 0; flag = false;
	}
}pool[maxn],*root;
int tot;

void insert(char *c){
	int len = strlen(c); Trie *p = root;
	for (int i = 0; i < len; i++){
		if (p->go[c[i] - 'a'] != 0) p = p->go[c[i] - 'a'];
		else{
			pool[tot].init();
			p->go[c[i] - 'a'] = &pool[tot++];
			p = p->go[c[i] - 'a'];
		}
	}
	p->cnt++;
}

void getFail()
{
	queue<Trie*> que;
	que.push(root);
	root->fail = NULL;
	while (!que.empty()){
		Trie *temp = que.front(); que.pop();
		Trie *p = NULL;
		for (int i = 0; i < 4; i++){
			if (temp->go[i] != NULL){
				if (temp == root) temp->go[i]->fail = root;
				else{
					p = temp->fail;
					while (p != NULL){
						if (p->go[i] != NULL){
							temp->go[i]->fail = p->go[i]; break;
						}
						p = p->fail;
					}
					if (p == NULL) temp->go[i]->fail = root;
				}
				que.push(temp->go[i]);
			}
		}
	}
}

int dfs(Trie *p){
	if (p == root) return 0;
	if (p->flag == true) return p->cnt;
	p->cnt += dfs(p->fail); p->flag = true;
	return p->cnt;
}

int sta[45][45][45][45];
int p[15000][4];
int stanum;
int dp[15000][520];
int A, B, C, D;

int main()
{
	int ca = 0;
	while (cin >> n&&n)
	{
		tot = 0; root = &pool[tot++]; root->init();
		for (int i = 0; i < n; i++){
			scanf("%s", str[i]); convert(str[i]);
			insert(str[i]);
		}
		scanf("%s", T); A = B = C = D = 0; int len = strlen(T);
		for (int i = 0; i < len; i++){
			if (T[i] == 'A') A++;
			else if (T[i] == 'C') B++;
			else if (T[i] == 'G') C++;
			else D++;
		}
		getFail();
		for (int i = 0; i < tot; i++) dfs(&pool[i]);
		for (int i = 0; i < tot; i++){
			Trie *p = &pool[i];
			for (int k = 0; k < 4; k++){
				if (p->go[k] == NULL){
					Trie *temp = p; temp = temp->fail;
					while (temp != NULL){
						if (temp->go[k] != NULL) {
							p->go[k] = temp->go[k]; break;
						}
						temp = temp->fail;
					}
					if (temp == NULL) p->go[k] = root;
				}
			}
		}
		stanum = 0;
		for (int i = 0; i <= A; i++){
			for (int j = 0; j <= B; j++){
				for (int k = 0; k <= C; k++){
					for (int t = 0; t <= D; t++){
						sta[i][j][k][t] = stanum;
						p[stanum][0] = i; p[stanum][1] = j;
						p[stanum][2] = k; p[stanum][3] = t; stanum++;
					}
				}
			}
		}
		memset(dp, -1, sizeof(dp)); int a, b, c, d;
		dp[0][0] = 0;
		for (int i = 0; i < stanum; i++){
			a = p[i][0]; b = p[i][1]; c = p[i][2]; d = p[i][3];
			for (int j = 0; j < tot; j++){
				if (dp[i][j] == -1) continue;
				if (a + 1 <= A) dp[sta[a + 1][b][c][d]][pool[j].go[0] - pool] =
					max(dp[sta[a + 1][b][c][d]][pool[j].go[0] - pool], dp[i][j] + pool[j].go[0]->cnt);

				if (b + 1 <= B) dp[sta[a][b + 1][c][d]][pool[j].go[1] - pool] =
					max(dp[sta[a][b + 1][c][d]][pool[j].go[1] - pool], dp[i][j] + pool[j].go[1]->cnt);

				if (c + 1 <= C) dp[sta[a][b][c + 1][d]][pool[j].go[2] - pool] =
					max(dp[sta[a][b][c + 1][d]][pool[j].go[2] - pool], dp[i][j] + pool[j].go[2]->cnt);

				if (d + 1 <= D) dp[sta[a][b][c][d + 1]][pool[j].go[3] - pool] =
					max(dp[sta[a][b][c][d + 1]][pool[j].go[3] - pool], dp[i][j] + pool[j].go[3]->cnt);
			}
		}
		int ans = 0; int fin = sta[A][B][C][D];
		for (int i = 0; i < tot; i++){
			ans = max(ans, dp[fin][i]);
		}
		printf("Case %d: %d\n", ++ca, ans);
	}
	return 0;
}

 

posted @ 2014-04-09 20:43  chanme  阅读(157)  评论(0编辑  收藏  举报