UOJ#388. 【UNR #3】配对树 树链剖分+线段树
这道题卡常啊 !
出题人说 $O(n \log^2 n)$ 可过,但我写了个 $O(n \log^2 n)$ 的树剖卡了半天常数.
最暴力的做法:枚举区间,然后跑一个树形DP 来求最小匹配.
显然,因为要求匹配值最小,所以一定是能匹配就先匹配.
也就是说递归完 $x$ 的所有儿子后,$x$ 的每一个儿子最多只有 1 个点还没有匹配.
这个时间复杂度是 $O(n^3)$ 的.
然后我们对每一条边分别考虑:
令 $v[x]$ 表示点 $x$ 到其父亲的边权(以 1 为根),那么 $v[x]$ 能产生贡献,当且仅当一个区间中 $x$ 子树中有奇数个点.
这个很好理解,因为如果有奇数个点,就意味着 1 个点没有被匹配到,而需要向上延伸的 $x$ 的父亲,依此类推......
那么就枚举右端点,然后令 $f[x][0/1]$ 分别表示多少个长度为偶数的区间满足在 $x$ 的子树中有偶数/奇数个点.
由于要求区间长度是偶数,我们可以分别以 $1,2$ 为起点各跑一次,每次同时加入两个点来保证长度为偶数.
考虑加入 $x,y$ 后的影响:
$x$ 到 $lca$ 与 $y$ 到 $lca$ (不包括 lca 这个点)的路径上 $f[x][0]=f[x][1]$,$f[x][1]=f[x][0]+1$
不在 $x,y$ 路径上的点 $f[x][1]$ 不变,$f[x][0] \leftarrow f[x][0]+1$.
这个暴力修改的话是 $O(n^2)$ 的,可以获得 $50$pts.
满分算法的话就是用树链剖分+线段树来维护上面的东西.
我们无外乎就是要支持:每个节点维护 $f[x][0],f[x][1]$,区间加,区间交换.
然后定义标记 $(rev,x,y)$ 表示是否要交换 $f[x][0],f[x][1]$ 的值,交换后对 $f[x][0]$,$f[x][1]$ 分别加上 $x,y$.
时间复杂度为 $O(n \log^2 n)$,但是会有点卡常.
这里说几个卡常技巧:
1. 读入优化
2. 开 long long 要比取模快.
3. 由于上述操作中每次加的数是 1 或 -1,所以这个标记可以直接开 int,然后区间和开 long long.
code:
#include <cstdio>
#include <ctime>
#include <cstring>
#include <algorithm>
#define N 100008
#define ll long long
#define mod 998244353
#define lson now<<1
#define rson now<<1|1
#define setIO(s) freopen(s".in","r",stdin)
using namespace std;
int edges,n,m,tim;
int nd[N],f[N][2],fa[N];
int hd[N],to[N<<1],nex[N<<1],val[N<<1];
int dep[N],a[N],size[N],top[N],son[N],dfn[N],bu[N];
ll ans;
struct data {
int rev;
int vx,vy;
ll sx,sy,sum;
data(int rev=0,int vx=0,int vy=0):rev(rev),vx(vx),vy(vy){}
}s[N<<2];
inline void add(int u,int v,int c) {
nex[++edges]=hd[u];
hd[u]=edges,to[edges]=v,val[edges]=c;
}
void dfs(int x,int ff) {
size[x]=1;
fa[x]=ff,dep[x]=dep[ff]+1;
for(int i=hd[x];i;i=nex[i]) {
int y=to[i];
if(y==ff) continue;
nd[y]=val[i],dfs(y,x);
size[x]+=size[y];
if(size[y]>size[son[x]]) son[x]=y;
}
}
void dfs2(int x,int tp) {
top[x]=tp;
dfn[x]=++tim;
bu[tim]=x;
if(son[x]) dfs2(son[x],tp);
for(int i=hd[x];i;i=nex[i])
if(to[i]!=fa[x]&&to[i]!=son[x])
dfs2(to[i],to[i]);
}
inline int get_lca(int x,int y) {
while(top[x]!=top[y]) {
dep[top[x]]>dep[top[y]]?x=fa[top[x]]:y=fa[top[y]];
}
return dep[x]<dep[y]?x:y;
}
inline void pushup(int now) {
s[now].sx=(ll)(s[lson].sx+s[rson].sx);
s[now].sy=(ll)(s[lson].sy+s[rson].sy);
}
inline void mark_rev(int now) {
swap(s[now].sx,s[now].sy);
swap(s[now].vx,s[now].vy);
s[now].rev^=1;
}
inline void mark_add(int now,int vx,int vy) {
if(vx) (s[now].sx+=(ll)vx*s[now].sum);
if(vy) (s[now].sy+=(ll)vy*s[now].sum);
if(vx) (s[now].vx+=vx);
if(vy) (s[now].vy+=vy);
}
inline void pushdown(int now) {
if(s[now].rev) {
s[now].rev=0;
mark_rev(lson);
mark_rev(rson);
}
if(s[now].vx||s[now].vy) {
mark_add(lson,s[now].vx,s[now].vy);
mark_add(rson,s[now].vx,s[now].vy);
s[now].vx=s[now].vy=0;
}
}
void build(int l,int r,int now) {
s[now]=data();
s[now].sx=0;
s[now].sy=0;
if(l==r) {
s[now].sum=nd[bu[l]];
return;
}
int mid=(l+r)>>1;
build(l,mid,lson),build(mid+1,r,rson);
s[now].sum=(ll)(s[lson].sum+s[rson].sum)%mod;
}
void REV(int l,int r,int now,int L,int R) {
if(l>=L&&r<=R) {
mark_rev(now);
return;
}
pushdown(now);
int mid=(l+r)>>1;
if(L<=mid) REV(l,mid,lson,L,R);
if(R>mid) REV(mid+1,r,rson,L,R);
pushup(now);
}
void ADD(int l,int r,int now,int L,int R,int vx,int vy) {
if(l>=L&&r<=R) {
mark_add(now,vx,vy);
return;
}
pushdown(now);
int mid=(l+r)>>1;
if(L<=mid) ADD(l,mid,lson,L,R,vx,vy);
if(R>mid) ADD(mid+1,r,rson,L,R,vx,vy);
pushup(now);
}
inline void upd(int x,int y) {
while(top[y]!=top[x]) {
ADD(1,n,1,dfn[top[y]],dfn[y],-1,0);
REV(1,n,1,dfn[top[y]],dfn[y]);
ADD(1,n,1,dfn[top[y]],dfn[y],0,1);
y=fa[top[y]];
}
if(y!=x) {
ADD(1,n,1,dfn[x]+1,dfn[y],-1,0);
REV(1,n,1,dfn[x]+1,dfn[y]);
ADD(1,n,1,dfn[x]+1,dfn[y],0,1);
}
}
void sol(int st) {
int x,y,lca;
build(1,n,1);
for(int i=st;i<=m;i+=2) {
if(i+1>m) break;
x=a[i],y=a[i+1];
if(dep[x]>dep[y]) swap(x,y);
lca=get_lca(x,y);
mark_add(1,1,0);
upd(lca,x);
upd(lca,y);
(ans+=s[1].sy)%=mod;
}
}
char *p1,*p2,buf[100000];
#define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)
int rd()
{
int x=0; char c;
while(c<48) c=nc();
while(c>47) x=(((x<<2)+x)<<1)+(c^48),c=nc();
return x;
}
int main() {
// setIO("input");
n=rd(),m=rd();
int x,y,z;
for(int i=1;i<n;++i) {
x=rd(),y=rd(),z=rd();
if(z>=mod) z-=mod;
add(x,y,z),add(y,x,z);
}
dfs(1,0);
dfs2(1,1);
for(int i=1;i<=m;++i) a[i]=rd();
sol(1),sol(2);
printf("%lld\n",ans);
return 0;
}

浙公网安备 33010602011771号