# 洛谷P4719 【模板】动态dp

https://www.luogu.org/problemnew/show/P4719

  1 #include<cstdio>
2 #include<algorithm>
3 #include<cstring>
4 using namespace std;
5 typedef long long ll;
6 struct E
7 {
8     int to,nxt;
9 }e[200011];
10 int f1[100011],ne;
11 struct P1
12 {
13     ll d[2][2];//左侧不选/选，右侧不选/选
14 };
15 struct P2
16 {
17     ll d[2];//自身不选/选
18 };
19 ll a[100101];
20 int sz[100101],hson[100101],ff[100101];
21 int b[100101],pl[100101];
22 int n,m;
23 inline ll max1(ll a,ll b)
24 {
25     return a>b?a:b;
26 }
27 const ll inf1=-0x3f3f3f3f3f3f3f3f;
28 #define max max1
29 #define G(x) max1((x),inf1)
30 inline void merge(P1 &c,const P1 &a,const P1 &b)
31 {
32     c.d[0][0]=G(max(a.d[0][0]+max(b.d[1][0],b.d[0][0]),
33         a.d[0][1]+b.d[0][0]));
34     c.d[0][1]=G(max(a.d[0][0]+max(b.d[1][1],b.d[0][1]),
35         a.d[0][1]+b.d[0][1]));
36     c.d[1][0]=G(max(a.d[1][0]+max(b.d[1][0],b.d[0][0]),
37         a.d[1][1]+b.d[0][0]));
38     c.d[1][1]=G(max(a.d[1][0]+max(b.d[1][1],b.d[0][1]),
39         a.d[1][1]+b.d[0][1]));
40 }
41 inline void initnode(P1 &c,const P2 &a)
42 {
43     c.d[0][0]=a.d[0];c.d[1][1]=a.d[1];
44     c.d[0][1]=c.d[1][0]=inf1;
45 }
46 namespace S
47 {
48 #define lc (num<<1)
49 #define rc (num<<1|1)
50     P1 d[400101];
51     inline void upd(int num){merge(d[num],d[lc],d[rc]);}
52     P1 x;int L;
53     void _setx(int l,int r,int num)
54     {
55         if(l==r)
56         {
57             d[num]=x;
58             return;
59         }
60         int mid=(l+r)>>1;
61         if(L<=mid)    _setx(l,mid,lc);
62         else    _setx(mid+1,r,rc);
63         upd(num);
64     }
65     P1 getx(int L,int R,int l,int r,int num)
66     {
67         if(L<=l&&r<=R)    return d[num];
68         int mid=(l+r)>>1;
69         if(L<=mid&&mid<R)
70         {
71             P1 x;
72             merge(x,getx(L,R,l,mid,lc),getx(L,R,mid+1,r,rc));
73             return x;
74         }
75         else if(L<=mid)
76             return getx(L,R,l,mid,lc);
77         else if(mid<R)
78             return getx(L,R,mid+1,r,rc);
79         else
80             exit(-1);
81     }
82 }
83 void dfs1(int u,int fa)
84 {
85     sz[u]=1;
86     for(int v,k=f1[u];k;k=e[k].nxt)
87         if(e[k].to!=fa)
88         {
89             v=e[k].to;
90             ff[v]=u;
91             dfs1(v,u);
92             sz[u]+=sz[v];
93             if(sz[v]>sz[hson[u]])    hson[u]=v;
94         }
95 }
96 P2 d1[100101];//d1[i]维护i节点及其轻儿子的贡献
97 P2 d2[100101];//d2[i]维护i节点(是重链顶)所在重链的dp值
98 int tp[100101],dwn[100101];//链顶,链底
99 void dfs2(int u,int fa)
100 {
101     d1[u].d[0]=0;d1[u].d[1]=a[u];
102     b[++b[0]]=u;pl[u]=b[0];
103     tp[u]=(u==hson[fa])?tp[fa]:u;
104     if(hson[u])    dfs2(hson[u],u);
105     dwn[u]=hson[u]?dwn[hson[u]]:u;
106     int v,k;
107     for(k=f1[u];k;k=e[k].nxt)
108         if(e[k].to!=fa&&e[k].to!=hson[u])
109         {
110             v=e[k].to;
111             dfs2(v,u);
112             d1[u].d[0]+=max(d2[v].d[0],d2[v].d[1]);
113             d1[u].d[1]+=d2[v].d[0];
114         }
115     initnode(S::x,d1[u]);S::L=pl[u];S::_setx(1,n,1);
116     if(u==tp[u])
117     {
118         P1 t=S::getx(pl[u],pl[dwn[u]],1,n,1);
119         d2[u].d[0]=max(t.d[0][0],t.d[0][1]);
120         d2[u].d[1]=max(t.d[1][0],t.d[1][1]);
121     }
122 }
123 int main()
124 {
125     int i,x,y;ll z;P1 t;
126     scanf("%d%d",&n,&m);
127     for(i=1;i<=n;++i)    scanf("%lld",a+i);
128     for(i=1;i<n;++i)
129     {
130         scanf("%d%d",&x,&y);
131         e[++ne].to=y;e[ne].nxt=f1[x];f1[x]=ne;
132         e[++ne].to=x;e[ne].nxt=f1[y];f1[y]=ne;
133     }
134     dfs1(1,0);
135     dfs2(1,0);
136     while(m--)
137     {
138         scanf("%d%lld",&x,&z);
139         d1[x].d[1]-=a[x];a[x]=z;d1[x].d[1]+=z;
140         while(x)
141         {
142             initnode(S::x,d1[x]);S::L=pl[x];S::_setx(1,n,1);
143             x=tp[x];y=ff[x];
144             t=S::getx(pl[x],pl[dwn[x]],1,n,1);
145             d1[y].d[0]-=max(d2[x].d[0],d2[x].d[1]);
146             d1[y].d[1]-=d2[x].d[0];
147             d2[x].d[0]=max(t.d[0][0],t.d[0][1]);
148             d2[x].d[1]=max(t.d[1][0],t.d[1][1]);
149             d1[y].d[0]+=max(d2[x].d[0],d2[x].d[1]);
150             d1[y].d[1]+=d2[x].d[0];
151             x=y;
152         }
153         //printf("%lld %lld\n",d2[1].d[0],d2[1].d[1]);
154         printf("%lld\n",max(d2[1].d[0],d2[1].d[1]));
155     }
156     return 0;
157 }
View Code

bst版本O(n*log)

posted @ 2018-12-08 15:44  hehe_54321  阅读(262)  评论(0编辑  收藏  举报