dij费用流/Johnson Reweighting
dij费用流/Johnson Reweighting
我们一般敲的费用流都是套\(SPFA\)的\(dinic\),这是因为会有负边权,\(dij\)做不了,考虑能不能动点手脚使得我们的边权变成正的
可能会直接想到给每条边+\(INF\),这显然有点难蚌
\(Johnson\ Reweighting\)就是用来解决这个问题的
考虑给每个点赋一个点权\(h[x]\),表示\(S\)走到\(x\)的最短路的长度,我们现在让边\((u,v,w)\)的权变成\(w+h[u]-h[v]\),因为\(h\)是最短距离,所以显然这个值\(\geq 0\),且这样直接跑\(dij\)得到的\(S\)到\(T\)的最短路距离\(v\)加上\(-dis[S]+dis[T]\)就是真正的最短路距离
那么考虑怎么快速得到这个\(h\),显然第一次就直接跑个\(SPFA\),考虑后面怎么办
对于第\(i\)(\(i>1\))次跑\(dij\)时,记\(h_i\)表示跑完第\(i\)次\(dij\)后的最短路
显然我们想用\(h_i\)来操作第\(i\)次\(dij\),但显然这不现实
发现就用\(h_{i-1}\)来操作的话,所有边权也都是非负的,因为对于第\(i-1\)次\(dij\),它会改变的边\((u,v,w)\)满足\(w+h[u]-h[v]=0\)
复杂度是\(O(F(M+NlogN))\),好像\(OI\)里一般常写成\(O(F(N+M)logN)\)
模板(https://www.luogu.com.cn/problem/P3381):
#include<bits/stdc++.h>
using namespace std;
const int N=5e3+5,M=5e4+5,INF=1e9;
int n,m,S,T;
int head[N],cnt=1;
struct node{ int nxt,v,val,w; }tree[M<<1];
void add(int u,int v,int val,int w){
tree[++cnt]={head[u],v,val,w},head[u]=cnt;
tree[++cnt]={head[v],u,0,-w},head[v]=cnt;
}
int dis[N]; bool vis[N];
void spfa(){
for(int i=1;i<=n;++i) dis[i]=INF;
dis[S]=0; queue<int> q; q.push(S);
while(!q.empty()){
int x=q.front(); q.pop(),vis[x]=false;
for(int i=head[x],y;i;i=tree[i].nxt) if(tree[i].val&&dis[y=tree[i].v]>dis[x]+tree[i].w){
dis[y]=dis[x]+tree[i].w;
if(!vis[y]) q.push(y),vis[y]=true;
}
}
}
struct use{
int x,d;
bool operator < (const use &other)const{
return d>other.d;
}
};
int h[N],fr[N],ans,ans1;
bool dij(){
priority_queue<use> q;
for(int i=1;i<=n;++i) dis[i]=INF,vis[i]=false;
q.push({S,dis[S]=0});
while(!q.empty()){
int x=q.top().x; q.pop();
if(vis[x]) continue;
vis[x]=true;
for(int i=head[x],y;i;i=tree[i].nxt) if(tree[i].val&&!vis[y=tree[i].v]){
int t=dis[x]+tree[i].w+h[x]-h[y];
if(t<dis[y]) dis[y]=t,fr[y]=i,q.push({y,dis[y]});
}
}
if(dis[T]==INF) return false;
int mn=INF;
for(int x=T,i=fr[x];x!=S;x=tree[i^1].v,i=fr[x]) mn=min(mn,tree[i].val);
ans+=mn;
for(int x=T,i=fr[x];x!=S;x=tree[i^1].v,i=fr[x]) tree[i].val-=mn,tree[i^1].val+=mn,ans1+=mn*tree[i].w;
for(int i=1;i<=n;++i) if(dis[i]<INF) h[i]+=dis[i];
return true;
}
int main(){
scanf("%d%d%d%d",&n,&m,&S,&T);
for(int i=1,u,v,val,w;i<=m;++i) scanf("%d%d%d%d",&u,&v,&val,&w),add(u,v,val,w);
spfa(); for(int i=1;i<=n;++i) h[i]=dis[i];
while(dij());
printf("%d %d\n",ans,ans1);
return 0;
}

浙公网安备 33010602011771号