bzoj 4084 双旋转字符串

 

题意：给定两个字符串集合 A、B，求满足 a∈A、b∈B 且 a+b 是双旋转字符串（长度为偶数，且后一半是前一半的某个循环移位）的有序对 (a,b) 的个数。思路：枚举每个 a，求出所有能与它配对的 b 的哈希值（去重），再统计 B 中有多少个串的哈希值落在其中。

另外，这道题用到了字符串哈希（hash）；感觉平时用字符串哈希用得太少了。

 

  1 /**************************************************************
  2     Problem: 4084
  3     User: idy002
  4     Language: C++
  5     Result: Accepted
  6     Time:6456 ms
  7     Memory:290272 kb
  8 ****************************************************************/
  9  
 10 #include <cstdio>
 11 #include <set>
 12 #include <vector>
 13 #include <algorithm>
 14 #define N 8000010
 15 #define Base 31
 16 #define Mod 1000000007
 17 using namespace std;
 18  
 19 typedef long long dnt;
 20  
 21 int n, m, ln, lm;
 22 char *sa[N], *sb[N];
 23 char buf_arr[N], *buf=buf_arr;
 24 dnt fx[N], sx[N], pow[N];
 25 int fail[N];
 26 multiset<int> stb;
 27  
 28 int hash( char *s ) {
 29     dnt rt = 0;
 30     for( int i=0; s[i]; i++ )
 31         rt = (rt*Base + s[i]-'a') % Mod;
 32     return rt;
 33 }
 34 void init_hash( int ln, char *s ) {     //  ln>=1
 35     fx[0] = s[0]-'a';
 36     for( int j=1; j<ln; j++ )
 37         fx[j] = (fx[j-1]*Base + s[j]-'a') % Mod;
 38     sx[ln-1] = s[ln-1]-'a';
 39     for( int j=ln-2; j>=0; j-- )
 40         sx[j] = ((s[j]-'a')*pow[ln-1-j] + sx[j+1]) % Mod;
 41 }
 42 void init_fail( char *s ) {
 43     fail[0] = 0;
 44     fail[1] = 0;
 45     for( int i=1; s[i]; i++ ) {
 46         int j=fail[i];
 47         while( j && s[j]!=s[i] ) j=fail[j];
 48         if( s[j]==s[i] ) 
 49             fail[i+1]=j+1;
 50         else
 51             fail[i+1]=0;
 52     }
 53 }
 54 void work1() {  //  ln > lm
 55     vector<int> vc;
 56     int ll = (ln+lm)>>1;
 57     dnt ans = 0;
 58     for( int t=1; t<=n; t++ ) {
 59         init_fail(sa[t]+ll);
 60         for( int i=0; i<ll; i++ )
 61             buf[i] = sa[t][i];
 62         for( int i=0; i<ll; i++ )
 63             buf[ll+i] = sa[t][i];
 64         init_hash(ll+ll,buf);
 65         vc.clear();
 66         int i=0, j=0;
 67         while( i<ln-1 ) {
 68             while( i<ln-1 && j<ln-ll && buf[i]==sa[t][ll+j] ) i++, j++;
 69             if( j==ln-ll ) {
 70                 dnt v;
 71                 v = fx[i+(ll+ll-ln)-1]-fx[i-1]*pow[ll+ll-ln];
 72                 v = (v%Mod+Mod) % Mod;
 73                 vc.push_back(v);
 74                 j = fail[ln-ll];
 75             }
 76             if( i==ln-1 ) break;
 77             while( j && sa[t][ll+j]!=buf[i] ) j=fail[j];
 78             if( sa[t][ll+j]!=buf[i] ) i++;
 79         }
 80         sort( vc.begin(), vc.end() );
 81         vc.erase( unique(vc.begin(),vc.end()), vc.end() );
 82         for( int t=0; t<vc.size(); t++ )
 83             ans += stb.count(vc[t]);
 84     }
 85     printf( "%lld\n", ans );
 86 }
 87 void work3() {  //  ln = lm
 88     vector<int> vc;
 89     dnt ans = 0;
 90     for( int i=1; i<=n; i++ ) {
 91         init_hash(ln,sa[i]);
 92         vc.clear();
 93         vc.push_back( sx[0] );
 94         for( int j=1; j<ln; j++ ) {
 95             int v = ((dnt)sx[j]*pow[j]+fx[j-1]) % Mod;
 96             vc.push_back(v);
 97         }
 98         sort( vc.begin(), vc.end() );
 99         vc.erase( unique(vc.begin(), vc.end()), vc.end() );
100         for( int t=0; t<vc.size(); t++ )
101             ans += stb.count(vc[t]);
102     }
103     printf( "%lld\n", ans );
104 }
105 int main() {
106     scanf( "%d%d%d%d", &n, &m, &ln, &lm );
107     if( ln>lm ) {
108         for( int i=1; i<=n; i++ ) {
109             scanf( "%s", buf );
110             sa[i] = buf;
111             buf += ln+1;
112         }
113         for( int i=1; i<=m; i++ ) {
114             scanf( "%s", buf );
115             sb[i] = buf;
116             buf += lm+1;
117         }
118     } else {
119         swap(n,m);
120         swap(ln,lm);
121         for( int i=1; i<=m; i++ ) {
122             scanf( "%s", buf );
123             sb[i] = buf;
124             reverse( buf, buf+lm );
125             buf += lm+1;
126         }
127         for( int i=1; i<=n; i++ ) {
128             scanf( "%s", buf );
129             sa[i] = buf;
130             reverse( buf, buf+ln );
131             buf += ln+1;
132         }
133     }
134     pow[0] = 1;
135     for( int i=1; i<=ln; i++ )
136         pow[i] = pow[i-1]*Base % Mod;
137     for( int i=1; i<=m; i++ ) 
138         stb.insert( hash(sb[i]) );
139     if( ln!=lm )
140         work1();
141     else
142         work3();
143 }
View Code

 

posted @ 2015-06-16 15:05  idy002  阅读(561)  评论(0编辑  收藏  举报