题解 P8098 [USACO22JAN] Tests for Haybales G
第一眼拿到这题,想到的应该直接就是差分约束!不妨令 \(K=n+1\),对于每个条件(第 \(i\) 个点输入是 \(j\))是这样的约束关系:
\[a_j-a_i\le K(1)
\]
\[a_{j+1}-a_i\ge K+1(2)
\]
再补充上:
\[a_{i+1}-a_i \ge 0(3)
\]
按照以上三个条件连边即可。于是你发现样例就过了,提交发现居然 \(50\%\)。然后发现 SPFA 死了。以下是 Luogu \(48\) 分代码(通过 \(50\%\)):
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N=5e5+5;
const ll inf=100000000000000000;
ll dis[N],v[2*N];
int vis[N],n,m,k;
int cnt,nxt[2*N],t[2*N],h[2*N],num[2*N];
inline void add(int x,int y,int z) //x-y<=z
{
swap(x,y);
t[++cnt]=y;
v[cnt]=z;
nxt[cnt]=h[x];
h[x]=cnt;
}
queue <int> q;
void spfa(int s)
{
for (int i=1;i<=n;i++) dis[i]=inf,vis[i]=0;
dis[s]=0,vis[s]=1,num[s]=1;
q.push(s);
while (!q.empty())
{
int u=q.front();
q.pop();
vis[u]=0;
for (int i=h[u];i;i=nxt[i])
{
if (dis[t[i]]>dis[u]+v[i])
{
dis[t[i]]=dis[u]+v[i];
if (vis[t[i]]==0)
{
vis[t[i]]=1;
num[t[i]]++;
if (num[t[i]]>n)
{
printf("NO");
exit(0);
}
q.push(t[i]);
}
}
}
}
}
inline int read()
{
char C=getchar();
int F=1,ANS=0;
while (C<'0'||C>'9')
{
if (C=='-') F=-1;
C=getchar();
}
while (C>='0'&&C<='9')
{
ANS=ANS*10+C-'0';
C=getchar();
}
return F*ANS;
}
int main()
{
n=read(),k=n+1;
for (int i=1;i<=n;i++)
{
int x=read();
if (x!=i) add(x,i,k);
if (x!=n) add(i,x+1,-k-1);
}
for (int i=1;i<=n;i++) add(i,0,0);
for (int i=1;i<n;i++) add(i,i+1,0);
spfa(0);
printf("%d\n",k);
for (int i=1;i<=n;i++) printf("%lld\n",dis[i]+inf);
return 0;
}
因为存在负权边,所以不能使用 Dij 等其他最短路算法。那么怎么解决负权?
先把原序列切成几段,每碰到一个 \(i=j_i\) 的就切一下。显然段和段之间不影响。对于每一个段,我们再考虑分层。最后一个是第 \(0\) 层,其他的第 \(i\) 个就是 \(j_i\) 的层数再加一。
我们令第 \(i\) 层的初始权值为 \(-iK\)。这样会发现,条件 \((1)\) 中的 \(i,j\) 必然隔一层,那么 \(K\) 就消掉了。条件 \((2)\) 中 \(i,j+1\) 割一层或两层,两层的话就不用考虑,否则 \(K\) 也消掉了。条件 \((3)\) 对于同一层的照样做就行了。
于是原来的条件的权值只有 \(0\) 和 \(1\) 了,容易发现图变成了 DAG,因此可以线性求出最长路。
记得把每个数改为非负整数,还有段和段之间的间隔要注意。
// By: Little09
// Problem: P8098 [USACO22JAN] Tests for Haybales G
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P8098
// Memory Limit: 256 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N=5e5+5;
const ll inf=100000000000000000;
ll dis[N],v[2*N],k;
int a[N];
bool used[N];
int n,m;
int cnt,nxt[2*N],t[2*N],h[2*N],d[N];
inline void add(int x,int y,int z) //x-y>=z
{
swap(x,y);
t[++cnt]=y;
v[cnt]=z;
nxt[cnt]=h[x];
h[x]=cnt;
d[y]++;
}
queue<int>q;
void topo(int s,int l,int r)
{
for (int i=l;i<=r;i++) dis[i]=0,used[i]=0;
q.push(s);
dis[s]=0;
for (int i=l;i<=r;i++) if (d[i]==0) q.push(i);
while (!q.empty())
{
int u=q.front();
q.pop();
for (int i=h[u];i;i=nxt[i])
{
d[t[i]]--;
dis[t[i]]=max(dis[t[i]],dis[u]+v[i]);
if (d[t[i]]==0) q.push(t[i]);
}
}
}
inline int read()
{
char C=getchar();
int F=1,ANS=0;
while (C<'0'||C>'9')
{
if (C=='-') F=-1;
C=getchar();
}
while (C>='0'&&C<='9')
{
ANS=ANS*10+C-'0';
C=getchar();
}
return F*ANS;
}
int deth[N];
void work(int l,int r,ll qwq)
{
deth[r]=0;
for (int i=0;i<=cnt;i++)
{
nxt[i]=0,v[i]=0,t[i]=0;
}
for (int i=l;i<=r;i++) h[i]=0,d[i]=0;
h[0]=0,d[0]=0;
cnt=0;
for (int i=r-1;i>=l;i--) deth[i]=deth[a[i]]+1;
for (int i=l;i<r;i++)
{
if (deth[i]==deth[i+1])
{
add(i+1,i,0);
}
}
for (int i=l;i<r;i++)
{
add(i,a[i],0);
if (a[i]+1<=r&&deth[a[i]+1]==deth[a[i]])
add(a[i]+1,i,1);
}
for (int i=l;i<=r;i++) add(i,0,0);
topo(0,l,r);
for (int i=l;i<=r;i++) dis[i]-=1ll*deth[i]*k;
for (int i=r;i>=l;i--) dis[i]-=dis[l];
for (int i=l;i<=r;i++) dis[i]+=qwq;
}
int main()
{
n=read(),k=n+1;
for (int i=1;i<=n;i++) a[i]=read();
int las=1;
ll tmp=0;
for (int i=1;i<=n;i++)
{
if (a[i]==i)
{
work(las,i,tmp+k+1);
las=i+1,tmp=dis[i];
}
}
printf("%lld\n",k);
for (int i=1;i<=n;i++) printf("%lld\n",dis[i]);
return 0;
}