【YbtOJ#763】攻城略池

题目

题目链接:https://www.ybtoj.com.cn/contest/120/problem/1

\(n\leq 10^5,l_i\leq 10^3,d_i\leq 10^8\)

思路

\(f_x\) 是点 \(x\) 被攻占的时间。显然这个值可以二分,然后枚举子树内的每一个点,计算在二分到的时间内从枚举到的点可以过去多少人。
\(mid\) 时间内会被攻占当且仅当

\[d_x\leq \sum_{y\in\text{subtree}(x)}\max(mid-f_y-(\text{dep}_y-\text{dep}_x),0) \]

把括号拆开来,考虑把 \(x\) 子树内的点扔到权值线段树上,权值线段树上的节点 \([l,r]\) 储存所有 \(\text{dep}_y+f_y\in[l,r]\) 的点的权值之和以及数量。
然后二分可以直接在线段树上二分,当我们到达区间 \([l,r]\) 时,记 \(c\) 为权值在 \([1,mid]\) 的点的数量,\(v\) 为权值在 \([1,mid]\) 的点的权值和,那么我们往右边二分当且仅当

\[c\times mid-v<a_x \]

然后往上的时候线段树合并就可以了。
线段树值域上界是 \(3\times 10^8\),动态开点就可以了。
时间复杂度 \(O(n\log (d+nl))\)

代码

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int N=100010,LG=30,MAXN=N*LG*4,Lim=3e8;
int n,ans,tot,a[N],head[N],dep[N],rt[N],f[N];

struct edge
{
	int next,to,dis;
}e[N*2];

void add(int from,int to,int dis)
{
	e[++tot]=(edge){head[from],to,dis};
	head[from]=tot;
}

struct SegTree
{
	int lc[MAXN],rc[MAXN],cnt[MAXN];
	ll sum[MAXN];
	
	int merge(int x,int y)
	{
		if (!x || !y) return x|y;
		sum[x]+=sum[y]; cnt[x]+=cnt[y];
		lc[x]=merge(lc[x],lc[y]);
		rc[x]=merge(rc[x],rc[y]);
		return x;
	}
	
	int update(int x,int l,int r,int v)
	{
		if (!x) x=++tot;
		cnt[x]++; sum[x]+=v;
		if (l==r) return x;
		int mid=(l+r)>>1;
		if (v<=mid) lc[x]=update(lc[x],l,mid,v);
			else rc[x]=update(rc[x],mid+1,r,v);
		return x;
	}
	
	int query(int x,int l,int r,int k,int c,ll s)
	{
		if (l==r) return l;
		int mid=(l+r)>>1;
		ll ans=1LL*(cnt[lc[x]]+c)*mid-(sum[lc[x]]+s);
		if (ans>=a[k]) return query(lc[x],l,mid,k,c,s);
			else return query(rc[x],mid+1,r,k,c+cnt[lc[x]],s+sum[lc[x]]);
	}
}seg;

void dfs(int x,int fa)
{
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=fa)
		{
			dep[v]=dep[x]+e[i].dis;
			dfs(v,x);
			rt[x]=seg.merge(rt[x],rt[v]);
		}
	}
	f[x]=max(seg.query(rt[x],0,Lim,x,0,0)-dep[x],0);
	rt[x]=seg.update(rt[x],0,Lim,f[x]+dep[x]);
	ans=max(ans,f[x]);
}

signed main()
{
	freopen("conquer.in","r",stdin);
	freopen("conquer.out","w",stdout);
//	return printf("%d\n",sizeof(seg)/1024/1024),0;
	memset(head,-1,sizeof(head));
	scanf("%d",&n);
	for (int i=1;i<=n;i++)
		scanf("%d",&a[i]);
	for (int i=1,x,y,z;i<n;i++)
	{
		scanf("%d%d%d",&x,&y,&z);
		add(x,y,z); add(y,x,z);
	}
	dfs(1,0);
	printf("%lld",ans);
	return 0;
}
posted @ 2021-02-21 16:35  stoorz  阅读(253)  评论(0)    收藏  举报