codeforces 17c balance
给出一个长度为 \(n\) 的字符串 \(s\) ,有小写字母 \(a\) \(b\) \(c\) 组成,可以在字符串上进行如下两种操作:
- 选中两个相邻的字符,把第二个字符修改成第一个字符.
- 选中两个相邻的字符,把第一个字符修改成第二个字符.
定义 \(|a|\) 为字符串中 \(a\) 的出现次数,\(|b|\) 为字符串中 \(b\) 的出现次数, \(|c|\) 为字符串中 \(c\) 的出现次数.
如果一个字符串的 \(-1\leq |a|-|b|\leq 1\) , \(-1\leq |a|-|c|\leq 1\) , \(-1\leq |b|-|c|\leq 1\),则称此字符串为balanced string.
问,可以通过上面的操作获得多少个balanced string?
\(1\leq n\leq 150\)
做法1 & my solution
我一开始想了很久操作的形态和变化,发现都不好设计dp状态,因为是考虑上了操作,所以还有重复的问题.
这个时候,我就想,跳出操作的圈子,看操作后的序列的性质,发现字母之间的相对顺序是不变的.
两个字符串 \(A\) , \(B\) , 去重后得到了 \(A'\) , \(B'\) . 如果 \(A\) 操作可以得到 \(B\) ,那么 \(B'\) 必须是 \(A'\) 里的一个子序列.
此时,可以去重后得到 \(s\) . 此时,又出现了一个问题,可能下标不同的子序列最后得到的字符串相同.
遇到这种问题,通常都是规定取最小的下标. 用 \(nxt(i,0/1/2)\) 表示 \(i\) 之后第一个为 \(a/b/c\) 的位置.
可以设计得到 \(dp(i,a,b,c)\) 为到了 \(s\) 串中的第 \(i\) 位,\(|a|\ |b|\ |c|\) 分别为 \(a\ b\ c\) 时的方案数.
\(dp\) 转移有 \(6\) 种,
\(if\ s_i=a,\ dp(i,a,b,c)\longrightarrow dp(i,a+1,b,c)\)
\(else \ dp(i,a,b,c)\longrightarrow dp(nxt(i,0),a+1,b,c)\)
\(if\ s_i=b,\ dp(i,a,b,c)\longrightarrow dp(i,a,b+1,c)\)
\(else \ dp(i,a,b,c)\longrightarrow dp(nxt(i,1),a,b+1,c)\)
\(if\ s_i=c, dp(i,a,b,c)\longrightarrow dp(i,a,b,c+1)\)
\(else dp(i,a,b,c)\longrightarrow dp(nxt(i,2),a,b,c+1)\)
时间复杂度: \(O(n^4)\) ,但远远跑不到,因为时间后3维是 \(50\times 50\times 50\) 的.
空间复杂度: \(O(n^4)\)
第一次提交: Wrong answer on test 32
my code
#include<bits/stdc++.h>
using namespace std;
const int mod=51123987;
int n,m;
string s,t;
int nxt[152][3];
int dp[152][52][52][52];
inline void upd(int &a,int b){a+=b;a%=mod;}
int main(){
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
cin>>n>>s;
memset(nxt,-1,sizeof(nxt));
t+=s[0];
for(int i=1;i<n;i++){
if(s[i]==s[i-1])continue;
t+=s[i];
}
m=(int)t.size();
for(int i=0;i<m;i++){
if(t[i]!='a'){
for(int j=i+1;j<m;j++){
if(t[j]=='a'){
nxt[i][0]=j;
break;
}
}
}
if(t[i]!='b'){
for(int j=i+1;j<m;j++){
if(t[j]=='b'){
nxt[i][1]=j;
break;
}
}
}
if(t[i]!='c'){
for(int j=i+1;j<m;j++){
if(t[j]=='c'){
nxt[i][2]=j;
break;
}
}
}
}
// cout<<t<<endl;
// for(int i=0;i<m;i++)cout<<nxt[i][0]<<" ";cout<<endl;
// for(int i=0;i<m;i++)cout<<nxt[i][1]<<" ";cout<<endl;
// for(int i=0;i<m;i++)cout<<nxt[i][2]<<" ";cout<<endl;
int ans=0;
dp[0][0][0][0]=1;
for(int i=0;i<m;i++){
for(int a=0;a<=n/3+1;a++)for(int b=0;b<=n/3+1;b++)for(int c=0;c<=n/3+1;c++){
if(dp[i][a][b][c]==0)continue;
// cout<<i<<","<<a<<","<<b<<","<<c<<","<<dp[i][a][b][c]<<endl;
if(a+b+c==n&&abs(a-b)<=1&&abs(a-c)<=1&&abs(b-c)<=1)upd(ans,dp[i][a][b][c]);
if(a+1<=n/3+1){
if(t[i]=='a')upd(dp[i][a+1][b][c],dp[i][a][b][c]);
if(t[i]!='a'&&nxt[i][0]!=-1)upd(dp[nxt[i][0]][a+1][b][c],dp[i][a][b][c]);
}
if(b+1<=n/3+1){
if(t[i]=='b')upd(dp[i][a][b+1][c],dp[i][a][b][c]);
if(t[i]!='b'&&nxt[i][1]!=-1)upd(dp[nxt[i][1]][a][b+1][c],dp[i][a][b][c]);
}
if(c+1<=n/3+1){
if(t[i]=='c')upd(dp[i][a][b][c+1],dp[i][a][b][c]);
if(t[i]!='c'&&nxt[i][2]!=-1)upd(dp[nxt[i][2]][a][b][c+1],dp[i][a][b][c]);
}
}
}
cout<<ans<<endl;
return 0;
}
/*inline? ll or int? size? min max?*/
ftiasch's code
// Codeforces Beta Round #17
// Problem C -- Balance
#include <cstdio>
#include <cstring>
#include <algorithm>
#define N 150
#define M 50
#define MOD 51123987
#define INC(x, a) x = (x + (a)) % MOD
using namespace std;
int n, m, dp[N][M + 2][M + 2][M + 2], next[N + 1][3];
char buffer[N + 1];
int main()
{
scanf("%d%s", &n, buffer);
m = (n + 2) / 3;
next[n][0] = next[n][1] = next[n][2] = n;
for(int i = n - 1; i != -1; -- i)
for(int j = 0; j != 3; ++ j)
next[i][j] = buffer[i] - 'a' == j? i: next[i + 1][j];
int answer = 0;
memset(dp, 0, sizeof(dp));
dp[0][0][0][0] = 1;
for(int i = 0; i != n; ++ i)
for(int a = 0; a <= m; ++ a)
for(int b = 0; b <= m; ++ b)
for(int c = 0; c <= m; ++ c)
{
if(a + b + c == n && abs(a - b) <= 1 && abs(b - c) <= 1 && abs(c - a) <= 1)
INC(answer, dp[i][a][b][c]);
if(next[i][0] != n)
INC(dp[next[i][0]][a + 1][b][c], dp[i][a][b][c]);
if(next[i][1] != n)
INC(dp[next[i][1]][a][b + 1][c], dp[i][a][b][c]);
if(next[i][2] != n)
INC(dp[next[i][2]][a][b][c + 1], dp[i][a][b][c]);
}
printf("%d\n", answer);
return 0;
}
除了使用\(nxt\)数组,也可以用 \(mask\) 记录不可以使用的字符. 这个思路我觉得很妙. 此时,还可以节省空间.
时间复杂度: \(O(n^4)\)
空间复杂度: \(O(n^3)\)
my code
#include<bits/stdc++.h>
using namespace std;
const int mod=51123987;
int n,m,k=0;
string s,t;
int dp[9][150100];
int id[52][52][52];
int to[150100][3];
inline void upd(int &x,int y){x+=y;x%=mod;}
int main(){
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
cin>>n>>s;
t+=s[0];
for(int i=1;i<n;i++){
if(s[i]==s[i-1])continue;
t+=s[i];
}
m=(int)t.size();
memset(id,-1,sizeof id);
for(int a=0;a<=n/3+1;a++)for(int b=0;a+b<=n&&b<=n/3+1;b++)
for(int c=0;c+a+b<=n&&c<=n/3+1;c++){
id[a][b][c]=k++;
// cout<<"a:"<<a<<" b:"<<b<<" c:"<<c<<" "<<id[a][b][c]<<endl;
}
memset(to,-1,sizeof to);
for(int a=0;a<=n/3+1;a++)for(int b=0;a+b<=n&&b<=n/3+1;b++)
for(int c=0;c+a+b<=n&&c<=n/3+1;c++){
if(a+1<=n/3+1)to[id[a][b][c]][0]=id[a+1][b][c];
if(b+1<=n/3+1)to[id[a][b][c]][1]=id[a][b+1][c];
if(c+1<=n/3+1)to[id[a][b][c]][2]=id[a][b][c+1];
}
dp[0][id[0][0][0]]=1;
for(int i=0;i<m;i++){
int x=t[i]-'a';
for(int mask=0;mask<=8;mask++)if(~mask&(1<<x)){
for(int state=0;state<k;state++){
if(dp[mask][state]==0)continue;
if(to[state][x]!=-1){
upd(dp[8][to[state][x]],dp[mask][state]);
}
}
}
for(int mask=0;mask<=8;mask++)if(~mask&(1<<x)){
int mask2=mask&7|1<<x;
for(int state=0;state<k;state++){
if(dp[mask][state]==0)continue;
// if(mask2==1&&state==0&&i==0)
// cout<<"!"<<mask<<","<<state<<","<<dp[mask][state]<<endl;
upd(dp[mask2][state],dp[mask][state]);
dp[mask][state]=0;
}
}
/* for(int mask=0;mask<=8;mask++)for(int state=0;state<k;state++){
if(dp[mask][state]==0)continue;
cout<<i<<","<<mask<<","<<state<<":"<<dp[mask][state]<<endl;
}
*/
}
int ans=0;
for(int a=n/3;a<=n/3+1;a++)for(int b=n/3;b<=n/3+1;b++)for(int c=n/3;c<=n/3+1;c++){
for(int mask=0;mask<8;mask++){
if(dp[mask][id[a][b][c]]==0)continue;
if(a+b+c==n&&abs(a-b)<=1&&abs(a-c)<=1&&abs(b-c)<=1){
// cout<<mask<<","<<id[a][b][c]<<","<<dp[mask][id[a][b][c]]<<endl;
upd(ans,dp[mask][id[a][b][c]]);
}
}
}
cout<<ans<<endl;
return 0;
}
/*inline? ll or int? size? min max?*/
/*
4
cbba
*/
hos.lyric's code
// Codeforces Beta Round #17
// Problem C
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <cassert>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include <queue>
#include <set>
#include <map>
#include <utility>
#include <numeric>
#include <algorithm>
#include <bitset>
#include <complex>
using namespace std;
typedef unsigned uint;
typedef long long Int;
typedef vector<int> vint;
typedef pair<int,int> pint;
#define mp make_pair
template<class T> void pv(T a, T b) { for (T i = a; i != b; ++i) cout << *i << " "; cout << endl; }
int in_c() { int c; for (; (c = getchar()) <= ' '; ) { if (!~c) throw ~0; } return c; }
int in() {
int x = 0, c;
for (; (uint)((c = getchar()) - '0') >= 10; ) { if (c == '-') return -in(); if (!~c) throw ~0; }
do { x = (x << 3) + (x << 1) + (c - '0'); } while ((uint)((c = getchar()) - '0') < 10);
return x;
}
const int MO = 51123987;
void pl(int &t, int f) { if ((t += f) >= MO) t -= MO; }
int M, N;
char S[160];
int K, is[160][160][160];
int to[600010][3];
int dp[9][600010];
int main() {
int i;
int a, b, c;
int z, x, y, s;
scanf("%d%s", &M, S);
N = unique(S, S + M) - S;
for (a = 0; a <= M; ++a) for (b = 0; a + b <= M; ++b) for (c = 0; a + b + c <= M; ++c) {
is[a][b][c] = ++K;
}
for (a = 0; a <= M; ++a) for (b = 0; a + b <= M; ++b) for (c = 0; a + b + c <= M; ++c) {
to[is[a][b][c]][0] = is[a + 1][b][c];
to[is[a][b][c]][1] = is[a][b + 1][c];
to[is[a][b][c]][2] = is[a][b][c + 1];
}
dp[0][is[0][0][0]] = 1;
for (i = 0; i < N; ++i) {
s = S[i] - 'a';
for (x = 0; x <= 8; ++x) if (~x & 1 << s) {
for (z = 1; z <= K; ++z) if (dp[x][z]) {
pl(dp[8][to[z][s]], dp[x][z]);
}
}
for (x = 0; x <= 8; ++x) if (~x & 1 << s) {
y = x & 7 | 1 << s;
for (z = 1; z <= K; ++z) if (dp[x][z]) {
pl(dp[y][z], dp[x][z]);
dp[x][z] = 0;
}
}
}
int ans = 0;
for (a = 0; a <= M; ++a) for (b = 0; a + b <= M; ++b) for (c = 0; a + b + c <= M; ++c) {
if (a + b + c == M && abs(a - b) <= 1 && abs(a - c) <= 1 && abs(b - c) <= 1) {
for (x = 0; x < 8; ++x) {
pl(ans, dp[x][is[a][b][c]]);
}
}
}
cout << ans << endl;
return 0;
}
做法2
最后的 \(s'\) 需要是 \(s\) 的子序列,\(s'\) 相邻两位不能相同,并且要下表最小,考虑求 \(s'\) 的数量.
用 \(f(i,a,b,c,S)\) 表示到了原先串 \(t\) 的第 \(i\) 为,\(|a|\ |b|\ |c|\) 的数量分别为 \(a\ b\ c\) ,不可选择的字符为 \(mask\) 时的方案数.
\(mask\) 是用来限制下标最小和相邻不同这两个条件的.
这个做法不是我想出来的,但是我觉得用 \(mask\) 来限制下表最小的思路很妙,如果遇到第一个 \(a\) 不选,那么在选下一个字符之前第一个遇到的 \(a\) 是不可能被选中的.
有两种 \(dp\) 转移.
\(if\ t_i\notin S,\ f(i,a,b,c,S)\longrightarrow f(i+1,a,b,c,\{t_i\})\)
\(f(i,a,b,c,S)\longrightarrow f(i+1,a,b,c,S\cup\{ti\}))\)
考虑如何将得到的 \(s'\) 扩充成长度为 \(n\) 的字符串,要加入一些 \(a\), 一些\(b\),一些\(c\) .
可以用插板法,也可以用dp求.
\(dp(i,j)\) 表示将 \(j\) 个字母分成 \(i\) 个区间的方案数.
\(dp(i,j)=\sum\limits_{i-1}^{j-1} dp(i-1,k)\).
时间复杂度: \(O(n^4+n^2)\)
空间复杂度: \(O(n^3)\)
第一次提交: Accept
my code
#include<bits/stdc++.h>
using namespace std;
const int mod=51123987;
int n;
string s;
int f[2][53][53][53][8];//pos a b c mask
int dp[53][53];//i->j
inline void upd(int &x,int y){x+=y;x%=mod;}
int main(){
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
cin>>n>>s;
f[0][0][0][0][0]=1;
for(int i=0;i<n;i++){
int cur=i&1,nxt=cur^1,tmp=s[i]-'a';
memset(f[nxt],0,sizeof f[nxt]);
for(int a=0;a<=n/3+1;a++)for(int b=0;b<=n/3+1;b++)for(int c=0;c<=n/3+1;c++){
for(int mask=0;mask<8;mask++){
if(f[cur][a][b][c][mask]==0)continue;
// cout<<i<<","<<cur<<","<<a<<","<<b<<","<<c<<","<<mask<<","<<tmp<<","
// <<f[cur][a][b][c][mask]<<endl;
if(!(mask&(1<<tmp))){
if(tmp==0)upd(f[nxt][a+1][b][c][1<<tmp],f[cur][a][b][c][mask]);
if(tmp==1)upd(f[nxt][a][b+1][c][1<<tmp],f[cur][a][b][c][mask]);
if(tmp==2)upd(f[nxt][a][b][c+1][1<<tmp],f[cur][a][b][c][mask]);
}
upd(f[nxt][a][b][c][mask|(1<<tmp)],f[cur][a][b][c][mask]);
}
}
}
// cout<<"ok"<<endl;
dp[0][0]=1;
for(int i=1;i<=n/3+1;i++)for(int j=i;j<=n/3+1;j++){
for(int mid=i-1;mid<=j-1;mid++){
upd(dp[i][j],dp[i-1][mid]);
}
// cout<<i<<","<<j<<","<<dp[i][j]<<endl;
}
// for(int i=1;i<=n/3+1;i++)for(int j=i;j<=n/3+1;j++)
// cout<<i<<","<<j<<","<<dp[i][j]<<endl;
int ans=0;
for(int a=0;a<=n/3+1;a++)for(int b=0;b<=n/3+1;b++)for(int c=0;c<=n/3+1;c++){
for(int mask=0;mask<8;mask++){
if(f[n&1][a][b][c][mask]==0)continue;
// cout<<a<<","<<b<<","<<c<<","<<mask<<":"<<f[n&1][a][b][c][mask]<<endl;
for(int ta=n/3;ta<=n/3+1;ta++)for(int tb=n/3;tb<=n/3+1;tb++)
for(int tc=n/3;tc<=n/3+1;tc++){
if(ta+tb+tc==n&&abs(ta-tb)<=1&&abs(ta-tc)<=1&&abs(tb-tc)<=1){
int tmp=1ll*f[n&1][a][b][c][mask]*dp[a][ta]%mod*
dp[b][tb]%mod*dp[c][tc]%mod;
// cout<<ta<<","<<tc<<","<<tc<<","<<dp[a][ta]<<","<<dp[b][tb]<<","<<
// dp[c][tc]<<","<<tmp<<endl;
upd(ans,tmp);
}
}
}
}
cout<<ans<<endl;
return 0;
}
/*inline? ll or int? size? min max?*/
rng_58's code
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include <deque>
#include <queue>
#include <set>
#include <map>
#include <algorithm>
#include <functional>
#include <utility>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <cstdio>
using namespace std;
#define REP(i,n) for((i)=0;(i)<(int)(n);(i)++)
#define foreach(c,itr) for(__typeof((c).begin()) itr=(c).begin();itr!=(c).end();itr++)
template <class T> inline string itos(T n) {return (n)<0?"-"+itos(-(n)):(n)<10?(string)("")+(char)('0'+(n)):itos((n)/10)+itos((n)%10);}
#define MOD 51123987
#define PLUS(x,y) {(x) += (y); if((x) >= MOD) (x) -= MOD;}
typedef long long ll;
int dp[2][152][77][52][8]; // pos, a, b, c, mask
int dp2[160][160];
string pre(string s){
int n=s.length(),i,cnt[3]={0};
REP(i,n) cnt[s[i]-'a']++;
pair <int, char> p[3];
REP(i,3) p[i] = make_pair(cnt[i],'a'+i);
sort(p,p+3); reverse(p,p+3);
map <char, char> mp;
REP(i,3) mp[p[i].second] = 'a' + i;
REP(i,n) s[i] = mp[s[i]];
return s;
}
int main(void){
int N,i,j,k,A=0,B=0,C=0,a,b,c,mask;
string s;
cin >> N >> s;
s = pre(s);
dp[0][0][0][0][0] = 1;
REP(i,N){
int cur = i%2, next = (i+1)%2;
REP(a,A+1) REP(b,B+1) REP(c,C+1) REP(mask,8) dp[next][a][b][c][mask] = 0;
int x = s[i] - 'a';
REP(a,A+1) REP(b,B+1) REP(c,C+1) REP(mask,8) if(dp[cur][a][b][c][mask] > 0){
if(!(mask&(1<<x))) PLUS(dp[next][a+(x==0?1:0)][b+(x==1?1:0)][c+(x==2?1:0)][(1<<x)],dp[cur][a][b][c][mask]);
PLUS(dp[next][a][b][c][mask|(1<<x)],dp[cur][a][b][c][mask]);
}
if(x == 0) A++; if(x == 1) B++; if(x == 2) C++;
}
dp2[0][0] = 1;
REP(i,N+5) REP(j,N+5) if(dp2[i][j] > 0) for(k=j+1;k<=N+5;k++) PLUS(dp2[i+1][k],dp2[i][j]);
int ans = 0;
REP(a,A+1) REP(b,B+1) REP(c,C+1) REP(mask,8) if(dp[N%2][a][b][c][mask] > 0){
ll tmp = dp[N%2][a][b][c][mask];
for(int a2=N/3;a2<=N/3+1;a2++) for(int b2=N/3;b2<=N/3+1;b2++) for(int c2=N/3;c2<=N/3+1;c2++) if(a2+b2+c2 == N){
ll tmp2 = tmp * dp2[a][a2] % MOD * dp2[b][b2] % MOD * dp2[c][c2] % MOD;
ans = (ans + tmp2) % MOD;
}
}
cout << ans << endl;
return 0;
}
小结
第一次见到此题是在20年的寒假,那时候死活想不出来这个dp是如何设计状态,如何转移的,现在能在20min中内想出来,还是有些进步的.
我觉得这是一个非常好的题目,从各个方面理解这个题目,不同的思路,不同的方法,最直观的方法是我的思路,但是,如果空间是64mb的话,肯定是不可以的.
我觉得用 \(mask\) 的方法不是很自然,但是很巧妙,转移也变得简单,更重要的是,可以将空间优化成 \(n^3\) ,偷偷积累成一个trick.
个人认为方法二没有方法一来得直观和简约. 但是给计数题的求法有启发.

浙公网安备 33010602011771号