Luogu P1310 [NOIP 2011 普及组] 表达式的值 题解 [ 绿 ] [ 表达式树 ] [ 栈 ] [ 树形 DP ]

表达式的值:这题就差把用表达式树树形 DP 写脸上了。

首先肯定是要建表达式树的。建出表达式树后,定义 \(dp_{i,0/1}\) 表示节点 \(i\) 取值为 \(0/1\) 的方案数,分三种情况讨论即可:

  • 该节点为叶子,\(dp_{i,0/1}=1\)
  • 该节点为与操作,\(dp_{i,0}=dp_{l,0}\times dp_{r,0}+dp_{l,1}\times dp_{r,0}+dp_{l,0}\times dp_{r,1},dp_{i,1}=dp_{l,1}\times dp_{r,1}\)
  • 该节点为或操作,\(dp_{i,1}=dp_{l,1}\times dp_{r,1}+dp_{l,1}\times dp_{r,0}+dp_{l,0}\times dp_{r,1},dp_{i,0}=dp_{l,0}\times dp_{r,0}\)

其中 \(l,r\) 分别代表 \(i\) 的两个儿子。

注意几个建表达式树的细节:

  • 当该字符为乘或加,且前一个字符为左括号时,需要添加一个叶子节点。
  • 当该字符为乘或加,且后一个字符为左括号时,不需要添加叶子节点;否则添加叶子节点。
  • 建表达式树可以在 cal() 函数中实现。
  • 一个操作入栈后,只有优先级大于等于它的操作会被操作掉,之后这个操作才入栈(也就是说该操作目前并不执行)。
  • 为了让所有操作都被执行,要在表达式前后加括号
#include <bits/stdc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
using pi=pair<int,int>;
const int N=1000005,mod=10007;
char s[N];
int n,stk1[N],tp1,stk2[N],tp2,cnt,a[N],dp[N][2];
vector<int>g[N];
void cal()
{
    g[++cnt].push_back(stk1[tp1--]);
    g[cnt].push_back(stk1[tp1--]);
    stk1[++tp1]=cnt;
    a[cnt]=stk2[tp2];
    tp2--;    
}
void dfs(int u)
{
    if(a[u]==0)
    {
        dp[u][0]=dp[u][1]=1;
        return;
    }
    int s1=g[u][0],s2=g[u][1];
    dfs(s1);
    dfs(s2);
    if(a[u]==1)
    {
        dp[u][0]=(dp[s1][0]*dp[s2][0]%mod);
        dp[u][1]=((dp[s1][0]*dp[s2][1]%mod)+(dp[s1][1]*dp[s2][0]%mod)+(dp[s1][1]*dp[s2][1]%mod))%mod;
    }
    else
    {
        dp[u][1]=(dp[s1][1]*dp[s2][1]%mod);
        dp[u][0]=((dp[s1][0]*dp[s2][1]%mod)+(dp[s1][1]*dp[s2][0]%mod)+(dp[s1][0]*dp[s2][0]%mod))%mod;
    }
}
int main()
{
    // freopen("P1310.in","r",stdin);
    // freopen("P1310.out","w",stdout);
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    cin>>n>>s+1;
    s[0]='(',s[n+1]=')';
    for(int i=0;i<=n+1;i++)
    {
        if((s[i]=='+'||s[i]=='*')&&(s[i-1]=='('))
            stk1[++tp1]=++cnt;
        if(s[i]=='+')
        {
            while(stk2[tp2]>=1)cal();
            stk2[++tp2]=1;
            if(s[i+1]!='(')stk1[++tp1]=++cnt;
        }
        else if(s[i]=='*')
        {
            while(stk2[tp2]>=2)cal();
            stk2[++tp2]=2;
            if(s[i+1]!='(')stk1[++tp1]=++cnt;
        }
        else if(s[i]=='(')
            stk2[++tp2]=-1;
        else
        {
            while(stk2[tp2]!=-1)cal();
            tp2--;
        }
    }
    dfs(cnt);
    cout<<dp[cnt][0]%mod;
    return 0;
}
posted @ 2025-05-14 23:46  KS_Fszha  阅读(25)  评论(0)    收藏  举报