LOJ #2769 -「ROI 2017 Day 1」前往大都会(单调栈维护斜率优化)

LOJ 题面传送门

orz 斜率优化……

模拟赛时被这题送走了,所以来写篇题解(

首先这个最短路的求法是 trivial 的,直接一遍 dijkstra 即可(

重点在于怎样求第二问。注意到这个第二问平方和最大要在保证最短路的基础上求,因此考虑建出最短路 DAG,这样最短路径上一条 \(1\to n\) 的路径就对应原图中一条最短路。因此此题等价于求一条 \(1\to n\) 的路径,满足每一个属于同一连续段上路径权值和的平方之和最大。注意到 \(f(x)=x^2\) 是下凸函数,也就是对于一段同一铁路线上的路径,我们肯定不会选择将它们拆成一段段小路径。因此我们考虑这样一个 \(dp\)\(dp_i\) 表示从 \(1\to i\) 路径上每一段权值之和的平方之和的最大值,转移就枚举上一次乘坐的那段铁路线的起始城市 \(j\),由于是最短路径 DAG,必然有这一段路径的长度为 \(dis_j-dis_i\),因此有转移 \(dp_i=\max\{dp_j+(dis_j-dis_i)^2\}\)

直接做是 \(\sum s_i^2\) 的无法通过,考虑怎样优化。考虑将平方拆开来得到 \(dp_i=\max\{dp_j+dis_j^2-2dis_idis_j\}+dis^2_i\),这东西等价于平面上有若干个形如 \((dis_j,dp_j+dis_j^2)\) 的点,现在你要过这些点分别做斜率 \(2dis_i\) 的直线并取最大截距。我们考虑对所有铁路线维护这些点组成的上凸包,查询就直接在凸包中二分找斜率为 \(2dis_i\) 的直线也可以。当然也有线性的维护方法,注意到此题我们需维护上凸包,凸包中直线斜率递减,并且在拓扑排序的过程中我们查询的直线的斜率 \(2dis_i\) 肯定是递增的,因此每查询一条直线毙掉的凸包中的直线肯定一段后缀,也就是最新加入的某一些直线,因此需用单调栈维护。复杂度 \(n\log n\)

const int MAXN=1e6;
int n,m,hd[MAXN+5],to[MAXN+5],nxt[MAXN+5],val[MAXN+5],ec=0;
void adde(int u,int v,int w){to[++ec]=v;val[ec]=w;nxt[ec]=hd[u];hd[u]=ec;}
vector<pii> tr[MAXN+5],pos[MAXN+5];
vector<int> bel[MAXN+5];
ll dis[MAXN+5],dp[MAXN+5];
int ord[MAXN+5],bcnt=0;
bool cmp(int x,int y){return dis[x]<dis[y];}
ld slope(int x,int y){return 1.0*((dis[x]*dis[x]+dp[x])-(dis[y]*dis[y]+dp[y]))/(dis[x]-dis[y]);}
vector<int> stk[MAXN*2+5];
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<=m;i++){
		int cnt,x;scanf("%d%d",&cnt,&x);
		tr[i].pb(mp(x,0));int pre=x;pos[x].pb(mp(i,0));
		for(int j=1;j<=cnt;j++){
			int w,y;scanf("%d%d",&w,&y);
			tr[i].pb(mp(y,w));pos[y].pb(mp(i,j));
			adde(pre,y,w);pre=y;
		} bel[i].resize(cnt+2);
	} memset(dis,63,sizeof(dis));dis[1]=0;
	priority_queue<pair<ll,int>,vector<pair<ll,int> >,greater<pair<ll,int> > > q;
	q.push(mp(0,1));
	while(!q.empty()){
		pair<ll,int> p=q.top();q.pop();int x=p.se;
		for(int e=hd[x];e;e=nxt[e]){
			int y=to[e],z=val[e];
			if(dis[y]>dis[x]+z){
				dis[y]=dis[x]+z;
				q.push(mp(dis[y],y));
			}
		}
	}
	for(int i=1;i<=n;i++) ord[i]=i;
	sort(ord+1,ord+n+1,cmp);
	for(int i=1;i<=m;i++) bel[i][0]=++bcnt;
	for(int i=1;i<=n;i++){
		int x=ord[i];//printf("%d\n",x);
		for(pii p:pos[x]){
			int id=p.fi,t=p.se;if(!t) continue;
			int pre=tr[id][t-1].fi;
			if(dis[x]==dis[pre]+tr[id][t].se) bel[id][t]=bel[id][t-1];
			else bel[id][t]=++bcnt;
//			printf("%d %d %d %d\n",x,id,t,bel[id][t]);
			if(bel[id][t]==bel[id][t-1]){
				int b=bel[id][t];
				assert(!stk[b].empty());
				while(stk[b].size()>=2&&slope(stk[b][stk[b].size()-1],stk[b][stk[b].size()-2])<2*dis[x]) stk[b].ppb();
				int y=stk[b].back();chkmax(dp[x],dp[y]+(dis[x]-dis[y])*(dis[x]-dis[y]));
			}
		}
		for(pii p:pos[x]){
			int id=p.fi,t=p.se,b=bel[id][t];
//			printf("ins %d\n",b);
			while(stk[b].size()>=2&&slope(stk[b][stk[b].size()-1],stk[b][stk[b].size()-2])<
									slope(stk[b][stk[b].size()-1],x)) stk[b].ppb();
			stk[b].pb(x);
		}
	} printf("%lld %lld\n",dis[n],dp[n]);
	return 0;
}
posted @ 2021-09-25 17:05  tzc_wk  阅读(177)  评论(0)    收藏  举报