Luogu P1310 [NOIP 2011 普及组] 表达式的值 题解 [ 绿 ] [ 表达式树 ] [ 栈 ] [ 树形 DP ]
表达式的值:这题就差把用表达式树树形 DP 写脸上了。
首先肯定是要建表达式树的。建出表达式树后,定义 \(dp_{i,0/1}\) 表示节点 \(i\) 取值为 \(0/1\) 的方案数,分三种情况讨论即可:
- 该节点为叶子,\(dp_{i,0/1}=1\)。
- 该节点为与操作,\(dp_{i,0}=dp_{l,0}\times dp_{r,0}+dp_{l,1}\times dp_{r,0}+dp_{l,0}\times dp_{r,1},dp_{i,1}=dp_{l,1}\times dp_{r,1}\)。
- 该节点为或操作,\(dp_{i,1}=dp_{l,1}\times dp_{r,1}+dp_{l,1}\times dp_{r,0}+dp_{l,0}\times dp_{r,1},dp_{i,0}=dp_{l,0}\times dp_{r,0}\)。
其中 \(l,r\) 分别代表 \(i\) 的两个儿子。
注意几个建表达式树的细节:
- 当该字符为乘或加,且前一个字符为左括号时,需要添加一个叶子节点。
- 当该字符为乘或加,且后一个字符为左括号时,不需要添加叶子节点;否则添加叶子节点。
- 建表达式树可以在
cal()函数中实现。 - 一个操作入栈后,只有优先级大于等于它的操作会被操作掉,之后这个操作才入栈(也就是说该操作目前并不执行)。
- 为了让所有操作都被执行,要在表达式前后加括号。
#include <bits/stdc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
using pi=pair<int,int>;
const int N=1000005,mod=10007;
char s[N];
int n,stk1[N],tp1,stk2[N],tp2,cnt,a[N],dp[N][2];
vector<int>g[N];
void cal()
{
g[++cnt].push_back(stk1[tp1--]);
g[cnt].push_back(stk1[tp1--]);
stk1[++tp1]=cnt;
a[cnt]=stk2[tp2];
tp2--;
}
void dfs(int u)
{
if(a[u]==0)
{
dp[u][0]=dp[u][1]=1;
return;
}
int s1=g[u][0],s2=g[u][1];
dfs(s1);
dfs(s2);
if(a[u]==1)
{
dp[u][0]=(dp[s1][0]*dp[s2][0]%mod);
dp[u][1]=((dp[s1][0]*dp[s2][1]%mod)+(dp[s1][1]*dp[s2][0]%mod)+(dp[s1][1]*dp[s2][1]%mod))%mod;
}
else
{
dp[u][1]=(dp[s1][1]*dp[s2][1]%mod);
dp[u][0]=((dp[s1][0]*dp[s2][1]%mod)+(dp[s1][1]*dp[s2][0]%mod)+(dp[s1][0]*dp[s2][0]%mod))%mod;
}
}
int main()
{
// freopen("P1310.in","r",stdin);
// freopen("P1310.out","w",stdout);
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin>>n>>s+1;
s[0]='(',s[n+1]=')';
for(int i=0;i<=n+1;i++)
{
if((s[i]=='+'||s[i]=='*')&&(s[i-1]=='('))
stk1[++tp1]=++cnt;
if(s[i]=='+')
{
while(stk2[tp2]>=1)cal();
stk2[++tp2]=1;
if(s[i+1]!='(')stk1[++tp1]=++cnt;
}
else if(s[i]=='*')
{
while(stk2[tp2]>=2)cal();
stk2[++tp2]=2;
if(s[i+1]!='(')stk1[++tp1]=++cnt;
}
else if(s[i]=='(')
stk2[++tp2]=-1;
else
{
while(stk2[tp2]!=-1)cal();
tp2--;
}
}
dfs(cnt);
cout<<dp[cnt][0]%mod;
return 0;
}

浙公网安备 33010602011771号