codeforces 17c balance

给出一个长度为 \(n\) 的字符串 \(s\) ,有小写字母 \(a\) \(b\) \(c\) 组成,可以在字符串上进行如下两种操作:

  1. 选中两个相邻的字符,把第二个字符修改成第一个字符.
  2. 选中两个相邻的字符,把第一个字符修改成第二个字符.

定义 \(|a|\) 为字符串中 \(a\) 的出现次数,\(|b|\) 为字符串中 \(b\) 的出现次数, \(|c|\) 为字符串中 \(c\) 的出现次数.

如果一个字符串的 \(-1\leq |a|-|b|\leq 1\) , \(-1\leq |a|-|c|\leq 1\) , \(-1\leq |b|-|c|\leq 1\),则称此字符串为balanced string.

问,可以通过上面的操作获得多少个balanced string?

\(1\leq n\leq 150\)

做法1 & my solution

我一开始想了很久操作的形态和变化,发现都不好设计dp状态,因为是考虑上了操作,所以还有重复的问题.

这个时候,我就想,跳出操作的圈子,看操作后的序列的性质,发现字母之间的相对顺序是不变的.

两个字符串 \(A\) , \(B\) , 去重后得到了 \(A'\) , \(B'\) . 如果 \(A\) 操作可以得到 \(B\) ,那么 \(B'\) 必须是 \(A'\) 里的一个子序列.

此时,可以去重后得到 \(s\) . 此时,又出现了一个问题,可能下标不同的子序列最后得到的字符串相同.

遇到这种问题,通常都是规定取最小的下标. 用 \(nxt(i,0/1/2)\) 表示 \(i\) 之后第一个为 \(a/b/c\) 的位置.

可以设计得到 \(dp(i,a,b,c)\) 为到了 \(s\) 串中的第 \(i\) 位,\(|a|\ |b|\ |c|\) 分别为 \(a\ b\ c\) 时的方案数.

\(dp\) 转移有 \(6\) 种,

\(if\ s_i=a,\ dp(i,a,b,c)\longrightarrow dp(i,a+1,b,c)\)

\(else \ dp(i,a,b,c)\longrightarrow dp(nxt(i,0),a+1,b,c)\)

\(if\ s_i=b,\ dp(i,a,b,c)\longrightarrow dp(i,a,b+1,c)\)

\(else \ dp(i,a,b,c)\longrightarrow dp(nxt(i,1),a,b+1,c)\)

\(if\ s_i=c, dp(i,a,b,c)\longrightarrow dp(i,a,b,c+1)\)

\(else dp(i,a,b,c)\longrightarrow dp(nxt(i,2),a,b,c+1)\)

时间复杂度: \(O(n^4)\) ,但远远跑不到,因为时间后3维是 \(50\times 50\times 50\) 的.

空间复杂度: \(O(n^4)\)

第一次提交: Wrong answer on test 32

my code
#include<bits/stdc++.h>
using namespace std;
const int mod=51123987;
int n,m;
string s,t;
int nxt[152][3];
int dp[152][52][52][52];
inline void upd(int &a,int b){a+=b;a%=mod;}
int main(){
	ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
	cin>>n>>s;
	memset(nxt,-1,sizeof(nxt));
	t+=s[0];
	for(int i=1;i<n;i++){
		if(s[i]==s[i-1])continue;
		t+=s[i];
	}
	m=(int)t.size();
	for(int i=0;i<m;i++){
		if(t[i]!='a'){
			for(int j=i+1;j<m;j++){
				if(t[j]=='a'){
					nxt[i][0]=j;
					break;
				}
			}
		}
		if(t[i]!='b'){
			for(int j=i+1;j<m;j++){
				if(t[j]=='b'){
					nxt[i][1]=j;
					break;
				}
			}
		}
		if(t[i]!='c'){
			for(int j=i+1;j<m;j++){
				if(t[j]=='c'){
					nxt[i][2]=j;
					break;
				}
			}
		}
	}
//	cout<<t<<endl;
//	for(int i=0;i<m;i++)cout<<nxt[i][0]<<" ";cout<<endl;
//	for(int i=0;i<m;i++)cout<<nxt[i][1]<<" ";cout<<endl;
//	for(int i=0;i<m;i++)cout<<nxt[i][2]<<" ";cout<<endl;
	int ans=0;
	dp[0][0][0][0]=1;
	for(int i=0;i<m;i++){
		for(int a=0;a<=n/3+1;a++)for(int b=0;b<=n/3+1;b++)for(int c=0;c<=n/3+1;c++){
			if(dp[i][a][b][c]==0)continue;
		//	cout<<i<<","<<a<<","<<b<<","<<c<<","<<dp[i][a][b][c]<<endl;
			if(a+b+c==n&&abs(a-b)<=1&&abs(a-c)<=1&&abs(b-c)<=1)upd(ans,dp[i][a][b][c]);
			if(a+1<=n/3+1){
				if(t[i]=='a')upd(dp[i][a+1][b][c],dp[i][a][b][c]);
				if(t[i]!='a'&&nxt[i][0]!=-1)upd(dp[nxt[i][0]][a+1][b][c],dp[i][a][b][c]);
			}
			if(b+1<=n/3+1){
				if(t[i]=='b')upd(dp[i][a][b+1][c],dp[i][a][b][c]);
				if(t[i]!='b'&&nxt[i][1]!=-1)upd(dp[nxt[i][1]][a][b+1][c],dp[i][a][b][c]);
			}
			if(c+1<=n/3+1){
				if(t[i]=='c')upd(dp[i][a][b][c+1],dp[i][a][b][c]);
				if(t[i]!='c'&&nxt[i][2]!=-1)upd(dp[nxt[i][2]][a][b][c+1],dp[i][a][b][c]);
			}
		}
	}	
	cout<<ans<<endl;
	return 0;
}
/*inline? ll or int? size? min max?*/

ftiasch's code
// Codeforces Beta Round #17
// Problem C -- Balance
#include <cstdio>
#include <cstring>
#include <algorithm>

#define N 150
#define M 50
#define MOD 51123987

#define INC(x, a) x = (x + (a)) % MOD

using namespace std;

int n, m, dp[N][M + 2][M + 2][M + 2], next[N + 1][3];
char buffer[N + 1];

int main()
{
    scanf("%d%s", &n, buffer);
    m = (n + 2) / 3;
    next[n][0] = next[n][1] = next[n][2] = n;
    for(int i = n - 1; i != -1; -- i)
        for(int j = 0; j != 3; ++ j)
            next[i][j] = buffer[i] - 'a' == j? i: next[i + 1][j];
    int answer = 0;
    memset(dp, 0, sizeof(dp));
    dp[0][0][0][0] = 1;
    for(int i = 0; i != n; ++ i)
        for(int a = 0; a <= m; ++ a)
            for(int b = 0; b <= m; ++ b)
                for(int c = 0; c <= m; ++ c)
                {
                    if(a + b + c == n && abs(a - b) <= 1 && abs(b - c) <= 1 && abs(c - a) <= 1)
                        INC(answer, dp[i][a][b][c]);
                    if(next[i][0] != n)
                        INC(dp[next[i][0]][a + 1][b][c], dp[i][a][b][c]);
                    if(next[i][1] != n)
                        INC(dp[next[i][1]][a][b + 1][c], dp[i][a][b][c]);                   
                    if(next[i][2] != n)
                        INC(dp[next[i][2]][a][b][c + 1], dp[i][a][b][c]);
                }
    printf("%d\n", answer);
    return 0;
}

除了使用\(nxt\)数组,也可以用 \(mask\) 记录不可以使用的字符. 这个思路我觉得很妙. 此时,还可以节省空间.

时间复杂度: \(O(n^4)\)

空间复杂度: \(O(n^3)\)

my code
#include<bits/stdc++.h>
using namespace std;
const int mod=51123987;
int n,m,k=0;
string s,t;
int dp[9][150100];
int id[52][52][52];
int to[150100][3];
inline void upd(int &x,int y){x+=y;x%=mod;}
int main(){
	ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
	cin>>n>>s;
	t+=s[0];
	for(int i=1;i<n;i++){
		if(s[i]==s[i-1])continue;
		t+=s[i]; 
	}
	m=(int)t.size();
	memset(id,-1,sizeof id);
	for(int a=0;a<=n/3+1;a++)for(int b=0;a+b<=n&&b<=n/3+1;b++)
		for(int c=0;c+a+b<=n&&c<=n/3+1;c++){
			id[a][b][c]=k++;
		//	cout<<"a:"<<a<<" b:"<<b<<" c:"<<c<<" "<<id[a][b][c]<<endl;
		}
	memset(to,-1,sizeof to);
	for(int a=0;a<=n/3+1;a++)for(int b=0;a+b<=n&&b<=n/3+1;b++)
		for(int c=0;c+a+b<=n&&c<=n/3+1;c++){
			if(a+1<=n/3+1)to[id[a][b][c]][0]=id[a+1][b][c];
			if(b+1<=n/3+1)to[id[a][b][c]][1]=id[a][b+1][c];
			if(c+1<=n/3+1)to[id[a][b][c]][2]=id[a][b][c+1];
		}
	dp[0][id[0][0][0]]=1;
	for(int i=0;i<m;i++){
		int x=t[i]-'a';
		for(int mask=0;mask<=8;mask++)if(~mask&(1<<x)){
			for(int state=0;state<k;state++){
				if(dp[mask][state]==0)continue;
				if(to[state][x]!=-1){
					upd(dp[8][to[state][x]],dp[mask][state]);
				}
			}
		}
		for(int mask=0;mask<=8;mask++)if(~mask&(1<<x)){
			int mask2=mask&7|1<<x;
			for(int state=0;state<k;state++){
				if(dp[mask][state]==0)continue;
			//	if(mask2==1&&state==0&&i==0)
			//		cout<<"!"<<mask<<","<<state<<","<<dp[mask][state]<<endl;
				upd(dp[mask2][state],dp[mask][state]);
				dp[mask][state]=0;
			}
		}
	/*	for(int mask=0;mask<=8;mask++)for(int state=0;state<k;state++){
			if(dp[mask][state]==0)continue;
			cout<<i<<","<<mask<<","<<state<<":"<<dp[mask][state]<<endl;
		}
	*/
	}
	int ans=0;
	for(int a=n/3;a<=n/3+1;a++)for(int b=n/3;b<=n/3+1;b++)for(int c=n/3;c<=n/3+1;c++){
		for(int mask=0;mask<8;mask++){
			if(dp[mask][id[a][b][c]]==0)continue;
			if(a+b+c==n&&abs(a-b)<=1&&abs(a-c)<=1&&abs(b-c)<=1){
		//	cout<<mask<<","<<id[a][b][c]<<","<<dp[mask][id[a][b][c]]<<endl;
				upd(ans,dp[mask][id[a][b][c]]);
			}
		}
	}
	cout<<ans<<endl;
	return 0;
}
/*inline? ll or int? size? min max?*/
/*
4
cbba
*/
hos.lyric's code
//  Codeforces Beta Round #17
//  Problem C

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <cassert>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include <queue>
#include <set>
#include <map>
#include <utility>
#include <numeric>
#include <algorithm>
#include <bitset>
#include <complex>

using namespace std;

typedef unsigned uint;
typedef long long Int;
typedef vector<int> vint;
typedef pair<int,int> pint;
#define mp make_pair

template<class T> void pv(T a, T b) { for (T i = a; i != b; ++i) cout << *i << " "; cout << endl; }
int in_c() { int c; for (; (c = getchar()) <= ' '; ) { if (!~c) throw ~0; } return c; }
int in() {
    int x = 0, c;
    for (; (uint)((c = getchar()) - '0') >= 10; ) { if (c == '-') return -in(); if (!~c) throw ~0; }
    do { x = (x << 3) + (x << 1) + (c - '0'); } while ((uint)((c = getchar()) - '0') < 10);
    return x;
}

const int MO = 51123987;

void pl(int &t, int f) { if ((t += f) >= MO) t -= MO; }

int M, N;
char S[160];
int K, is[160][160][160];
int to[600010][3];
int dp[9][600010];

int main() {
    int i;
    int a, b, c;
    int z, x, y, s;
    
    scanf("%d%s", &M, S);
    N = unique(S, S + M) - S;
    for (a = 0; a <= M; ++a) for (b = 0; a + b <= M; ++b) for (c = 0; a + b + c <= M; ++c) {
        is[a][b][c] = ++K;
    }
    for (a = 0; a <= M; ++a) for (b = 0; a + b <= M; ++b) for (c = 0; a + b + c <= M; ++c) {
        to[is[a][b][c]][0] = is[a + 1][b][c];
        to[is[a][b][c]][1] = is[a][b + 1][c];
        to[is[a][b][c]][2] = is[a][b][c + 1];
    }
    dp[0][is[0][0][0]] = 1;
    for (i = 0; i < N; ++i) {
        s = S[i] - 'a';
        for (x = 0; x <= 8; ++x) if (~x & 1 << s) {
            for (z = 1; z <= K; ++z) if (dp[x][z]) {
                pl(dp[8][to[z][s]], dp[x][z]);
            }
        }
        for (x = 0; x <= 8; ++x) if (~x & 1 << s) {
            y = x & 7 | 1 << s;
            for (z = 1; z <= K; ++z) if (dp[x][z]) {
                pl(dp[y][z], dp[x][z]);
                dp[x][z] = 0;
            }
        }
    }
    
    int ans = 0;
    for (a = 0; a <= M; ++a) for (b = 0; a + b <= M; ++b) for (c = 0; a + b + c <= M; ++c) {
        if (a + b + c == M && abs(a - b) <= 1 && abs(a - c) <= 1 && abs(b - c) <= 1) {
            for (x = 0; x < 8; ++x) {
                pl(ans, dp[x][is[a][b][c]]);
            }
        }
    }
    cout << ans << endl;
    
    return 0;
}

做法2

最后的 \(s'\) 需要是 \(s\) 的子序列,\(s'\) 相邻两位不能相同,并且要下表最小,考虑求 \(s'\) 的数量.

\(f(i,a,b,c,S)\) 表示到了原先串 \(t\) 的第 \(i\) 为,\(|a|\ |b|\ |c|\) 的数量分别为 \(a\ b\ c\) ,不可选择的字符为 \(mask\) 时的方案数.

\(mask\) 是用来限制下标最小和相邻不同这两个条件的.

这个做法不是我想出来的,但是我觉得用 \(mask\) 来限制下表最小的思路很妙,如果遇到第一个 \(a\) 不选,那么在选下一个字符之前第一个遇到的 \(a\) 是不可能被选中的.

有两种 \(dp\) 转移.

\(if\ t_i\notin S,\ f(i,a,b,c,S)\longrightarrow f(i+1,a,b,c,\{t_i\})\)

\(f(i,a,b,c,S)\longrightarrow f(i+1,a,b,c,S\cup\{ti\}))\)

考虑如何将得到的 \(s'\) 扩充成长度为 \(n\) 的字符串,要加入一些 \(a\), 一些\(b\),一些\(c\) .

可以用插板法,也可以用dp求.

\(dp(i,j)\) 表示将 \(j\) 个字母分成 \(i\) 个区间的方案数.

\(dp(i,j)=\sum\limits_{i-1}^{j-1} dp(i-1,k)\).

时间复杂度: \(O(n^4+n^2)\)

空间复杂度: \(O(n^3)\)

第一次提交: Accept

my code
#include<bits/stdc++.h>
using namespace std;
const int mod=51123987;
int n;
string s;
int f[2][53][53][53][8];//pos a b c mask
int dp[53][53];//i->j
inline void upd(int &x,int y){x+=y;x%=mod;}
int main(){
	ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
	cin>>n>>s;
	f[0][0][0][0][0]=1;
	for(int i=0;i<n;i++){
		int cur=i&1,nxt=cur^1,tmp=s[i]-'a';
		memset(f[nxt],0,sizeof f[nxt]);
		for(int a=0;a<=n/3+1;a++)for(int b=0;b<=n/3+1;b++)for(int c=0;c<=n/3+1;c++){
			for(int mask=0;mask<8;mask++){
				if(f[cur][a][b][c][mask]==0)continue;
			//	cout<<i<<","<<cur<<","<<a<<","<<b<<","<<c<<","<<mask<<","<<tmp<<","
			//	<<f[cur][a][b][c][mask]<<endl;
				if(!(mask&(1<<tmp))){
					if(tmp==0)upd(f[nxt][a+1][b][c][1<<tmp],f[cur][a][b][c][mask]);
					if(tmp==1)upd(f[nxt][a][b+1][c][1<<tmp],f[cur][a][b][c][mask]);
					if(tmp==2)upd(f[nxt][a][b][c+1][1<<tmp],f[cur][a][b][c][mask]);
				}
				upd(f[nxt][a][b][c][mask|(1<<tmp)],f[cur][a][b][c][mask]);
			}
		}
	}
//	cout<<"ok"<<endl;
	dp[0][0]=1;
	for(int i=1;i<=n/3+1;i++)for(int j=i;j<=n/3+1;j++){
		for(int mid=i-1;mid<=j-1;mid++){
			upd(dp[i][j],dp[i-1][mid]);
		}
	//	cout<<i<<","<<j<<","<<dp[i][j]<<endl;
	}
//	for(int i=1;i<=n/3+1;i++)for(int j=i;j<=n/3+1;j++)
//		cout<<i<<","<<j<<","<<dp[i][j]<<endl;
	int ans=0;
	for(int a=0;a<=n/3+1;a++)for(int b=0;b<=n/3+1;b++)for(int c=0;c<=n/3+1;c++){
		for(int mask=0;mask<8;mask++){
			if(f[n&1][a][b][c][mask]==0)continue;
		//	cout<<a<<","<<b<<","<<c<<","<<mask<<":"<<f[n&1][a][b][c][mask]<<endl;
			for(int ta=n/3;ta<=n/3+1;ta++)for(int tb=n/3;tb<=n/3+1;tb++)
				for(int tc=n/3;tc<=n/3+1;tc++){
					if(ta+tb+tc==n&&abs(ta-tb)<=1&&abs(ta-tc)<=1&&abs(tb-tc)<=1){
						int tmp=1ll*f[n&1][a][b][c][mask]*dp[a][ta]%mod*
						dp[b][tb]%mod*dp[c][tc]%mod;
					//	cout<<ta<<","<<tc<<","<<tc<<","<<dp[a][ta]<<","<<dp[b][tb]<<","<<
					//	dp[c][tc]<<","<<tmp<<endl;
						upd(ans,tmp);
					}
				}
		}
	}
	cout<<ans<<endl;
	return 0;
}
/*inline? ll or int? size? min max?*/

rng_58's code
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include <deque>
#include <queue>
#include <set>
#include <map>
#include <algorithm>
#include <functional>
#include <utility>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <cstdio>

using namespace std;

#define REP(i,n) for((i)=0;(i)<(int)(n);(i)++)
#define foreach(c,itr) for(__typeof((c).begin()) itr=(c).begin();itr!=(c).end();itr++)
template <class T> inline string itos(T n) {return (n)<0?"-"+itos(-(n)):(n)<10?(string)("")+(char)('0'+(n)):itos((n)/10)+itos((n)%10);}

#define MOD 51123987
#define PLUS(x,y) {(x) += (y); if((x) >= MOD) (x) -= MOD;}

typedef long long ll;

int dp[2][152][77][52][8]; // pos, a, b, c, mask
int dp2[160][160];

string pre(string s){
    int n=s.length(),i,cnt[3]={0};
    
    REP(i,n) cnt[s[i]-'a']++;
    
    pair <int, char> p[3];
    REP(i,3) p[i] = make_pair(cnt[i],'a'+i);
    sort(p,p+3); reverse(p,p+3);
    
    map <char, char> mp;
    REP(i,3) mp[p[i].second] = 'a' + i;
    
    REP(i,n) s[i] = mp[s[i]];
    return s;
}

int main(void){
    int N,i,j,k,A=0,B=0,C=0,a,b,c,mask;
    
    string s;
    cin >> N >> s;
    s = pre(s);
    
    dp[0][0][0][0][0] = 1;
    REP(i,N){
        int cur = i%2, next = (i+1)%2;
        REP(a,A+1) REP(b,B+1) REP(c,C+1) REP(mask,8) dp[next][a][b][c][mask] = 0;
        
        int x = s[i] - 'a';
        REP(a,A+1) REP(b,B+1) REP(c,C+1) REP(mask,8) if(dp[cur][a][b][c][mask] > 0){
            if(!(mask&(1<<x))) PLUS(dp[next][a+(x==0?1:0)][b+(x==1?1:0)][c+(x==2?1:0)][(1<<x)],dp[cur][a][b][c][mask]);
            PLUS(dp[next][a][b][c][mask|(1<<x)],dp[cur][a][b][c][mask]);
        }
        
        if(x == 0) A++; if(x == 1) B++; if(x == 2) C++;
    }
    
    dp2[0][0] = 1;
    REP(i,N+5) REP(j,N+5) if(dp2[i][j] > 0) for(k=j+1;k<=N+5;k++) PLUS(dp2[i+1][k],dp2[i][j]);
    
    int ans = 0;
    REP(a,A+1) REP(b,B+1) REP(c,C+1) REP(mask,8) if(dp[N%2][a][b][c][mask] > 0){
        ll tmp = dp[N%2][a][b][c][mask];
        for(int a2=N/3;a2<=N/3+1;a2++) for(int b2=N/3;b2<=N/3+1;b2++) for(int c2=N/3;c2<=N/3+1;c2++) if(a2+b2+c2 == N){
            ll tmp2 = tmp * dp2[a][a2] % MOD * dp2[b][b2] % MOD * dp2[c][c2] % MOD;
            ans = (ans + tmp2) % MOD;
        }
    }
    cout << ans << endl;
            
    return 0;
}
小结

第一次见到此题是在20年的寒假,那时候死活想不出来这个dp是如何设计状态,如何转移的,现在能在20min中内想出来,还是有些进步的.

我觉得这是一个非常好的题目,从各个方面理解这个题目,不同的思路,不同的方法,最直观的方法是我的思路,但是,如果空间是64mb的话,肯定是不可以的.

我觉得用 \(mask\) 的方法不是很自然,但是很巧妙,转移也变得简单,更重要的是,可以将空间优化成 \(n^3\) ,偷偷积累成一个trick.

个人认为方法二没有方法一来得直观和简约. 但是给计数题的求法有启发.

posted @ 2021-07-01 16:20  xyangh  阅读(2359)  评论(0)    收藏  举报