洛谷 P5658 [CSP-S 2019] 括号树 题解
题目大意
给定一棵树,每个节点有一个括号。对于每个节点 \(i\),定义 \(s_i\) 为从根节点到 \(i\) 的路径上所有括号按顺序组成的字符串。求每个 \(s_i\) 中互不相同的合法括号子串的个数 \(k_i\)。
思路
首先,\(k_i\) 可以从父节点递推得到,\(k_i=k_{f_i}+a_i\)。其中 \(a_i\) 为以节点 \(i\) 结尾的合法括号序列数量。因此只要求出每个节点的 \(a\)。
以 ( 为 \(1\) ) 为 \(−1\) 做树上前缀和,设点 \(u\)
的前缀和为 \(sum_u\)。则以 \(u\) 结尾的合法括号子串的开头 \(v\) 需要满足:
- \(sum_{f_v}=sum_u\)。
- 对于 \(v\to u\) 这条链上的所有点 \(x\),有 \(sum_x\ge sum_u\)。
在 DFS 过程中开一棵值域线段树维护 \(1\to u\) 这条链上每个 \(sum\) 值对应的最大节点深度。这样就能找到 \(sum_p<sum_u\) 且深度最大的节点 \(p\)。
设 \(ask(x,y)\) 表示 \(1\to x\) 链上 \(sum=y\) 的节点数量。则 \(a_u=ask(f_u,k)-ask(p,k)\)。
第一遍 DFS 求出所有询问并离线下来。
第二遍 DFS 求出所有点的 \(a\)。
第三遍 DFS 对 \(a\) 做树上前缀和得到所有点的 \(k\) 即可。
Code
#include <bits/stdc++.h>
#define rept(i,a,b) for(int i(a);i<=b;++i)
#define ls(p) ((p)<<1)
#define rs(p) ((p)<<1|1)
#define eb emplace_back
#define int long long
using namespace std;
constexpr int N=5e5+5;
struct Query{
int k,coef,id;
// k:目标值
// coef:贡献系数,1/-1
// id:贡献给到的节点
Query(int _k,int _coef,int _id):k(_k),coef(_coef),id(_id){}
};
struct SegTree{
int t[N<<3];
void update(int p,int pl,int pr,int pos,int x){ // 单点修改
if(pl==pr) return void(t[p]=x);
int mid=pl+pr>>1;
if(pos<=mid) update(ls(p),pl,mid,pos,x);
else update(rs(p),mid+1,pr,pos,x);
t[p]=max(t[ls(p)],t[rs(p)]);
}
int query(int p,int pl,int pr,int l,int r){ // 区间max
if(l>r) return 0;
if(l<=pl&&pr<=r) return t[p];
int mid=pl+pr>>1,a=0;
if(l<=mid) a=max(a,query(ls(p),pl,mid,l,r));
if(mid<r) a=max(a,query(rs(p),mid+1,pr,l,r));
return a;
}
}sgt;
char s[N];
int sum[N],dep[N],cnt[N<<1],a[N],st[N];
int n,m,ans;
vector<int> g[N];
vector<Query> q[N];
void dfs1(int u){
int lst=sgt.query(1,1,m,sum[u],sum[u]);
sgt.update(1,1,m,sum[u],dep[u]);
st[dep[u]]=u;
for(int v:g[u]){
sum[v]=sum[u]+(s[v]=='('?1:-1);
dep[v]=dep[u]+1;
if(s[v]==')'){
int bound=sgt.query(1,1,m,1,sum[v]-1);
q[u].eb(sum[v],1,v);
if(bound) q[st[bound]].eb(sum[v],-1,v);
}
dfs1(v);
}
sgt.update(1,1,m,sum[u],lst);
}
void dfs2(int u){
++cnt[sum[u]];
for(Query x:q[u]){
a[x.id]+=x.coef*cnt[x.k];
}
for(int v:g[u]) dfs2(v);
--cnt[sum[u]];
}
void dfs3(int u){
for(int v:g[u]){
a[v]+=a[u];
dfs3(v);
}
ans^=u*a[u];
}
signed main(){
cin.tie(0)->sync_with_stdio(0);
cin>>n>>s+1;
m=n<<1;
rept(i,2,n){
int x;cin>>x;
g[x].eb(i);
}
g[0].eb(1);
sum[0]=n,dep[0]=1; // 为了不出负数,sum统一加上n
dfs1(0),dfs2(0),dfs3(0);
cout<<ans;
return 0;
}

浙公网安备 33010602011771号