LOJ 3059 「HNOI2019」序列——贪心与前后缀的思路+线段树上二分

题目:https://loj.ac/problem/3059

一段 A 选一个 B 的话, B 是这段 A 的平均值。因为 \( \sum (A_i-B)^2 = \sum A_i^2 - 2*B \sum A_i + len*B^2 \) ,这是关于 B 的二次方程,对称轴是 \( B = - \frac{-2*\sum A_i}{2*len} \) ,恰是 A 的平均值。

所以自己前 10 分写了 “ dp[ i ][ j ] 表示前 i 个 A 、最后一段的 B = j ” 的 DP , n,m <= 100 的写了 “ dp[ i ] 表示前 i 个 A 的答案、转移枚举 i 所在的段到哪为止 ” 的 DP 。

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
#define db double
using namespace std;
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
const int mod=998244353;
int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;}
int pw(int x,int k)
{int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;}

int n,m;
namespace S1{
  const int N=15,M=1005; const ll INF=1e16;
  ll Mn(ll a,ll b){return a<b?a:b;}
  ll Mx(ll a,ll b){return a>b?a:b;}
  int a[N]; ll dp[N][M],f[N][M];
  ll Sqr(int x){return (ll)x*x;}
  int calc()
  {
    int mx=0;
    for(int i=1;i<=n;i++)mx=Mx(mx,a[i]);
    dp[0][0]=0;
    for(int i=1;i<=n;i++)
      {
    f[i][0]=INF;
    for(int j=1;j<=mx;j++)
      {
        ll sm=Sqr(j-a[i]); dp[i][j]=INF;
        for(int k=i-1;k>=0;k--)
          {
        dp[i][j]=Mn(dp[i][j],f[k][j]+sm);
        if(k)sm+=Sqr(j-a[k]);
          }
        f[i][j]=Mn(f[i][j-1],dp[i][j]);
      }
      }
    ll ret=INF;
    for(int j=1;j<=mx;j++) ret=Mn(ret,dp[n][j]);
    return ret%mod;
  }
  void solve()
  {
    for(int i=1;i<=n;i++) a[i]=rdn();
    printf("%d\n",calc());
    for(int i=1,u,k,d;i<=m;i++)
      {
    u=rdn();k=rdn(); d=a[u];a[u]=k;
    printf("%d\n",calc());
    a[u]=d;
      }
  }
}
namespace S2{
  const int N=105; const db INF=1e16;
  int a[N];db dp[N],f[N];int ans[N];
  db cal(int l,int r,db d)
  {
    db ret=0;
    for(int i=l;i<=r;i++)
      ret+=(a[i]-d)*(a[i]-d);
    return ret;
  }
  int cal2(int l,int r,ll sm)
  {
    sm=(ll)sm*pw(r-l+1,mod-2)%mod;
    int ret=0;
    for(int i=l;i<=r;i++)
      ret=(ret+(ll)(a[i]-sm)*(a[i]-sm))%mod;
    return ret;
  }
  int calc()
  {
    for(int i=1;i<=n;i++)
      {
    db sm=a[i]; dp[i]=INF;
    for(int j=i-1;j>=0;j--)
      {
        db d=sm/(i-j);
        if(d>=f[j])
          {
        db k=cal(j+1,i,d);
        if((dp[j]+k<dp[i])||(dp[j]+k==dp[i]&&d<f[i]))
          {
            dp[i]=dp[j]+k; f[i]=d;
            ans[i]=upt(ans[j]+cal2(j+1,i,sm));
          }
          }
        sm+=a[j];
      }
      }
    return ans[n];
  }
  void solve()
  {
    for(int i=1;i<=n;i++) a[i]=rdn();
    printf("%d\n",calc());
    for(int i=1,u,k,d;i<=m;i++)
      {
    u=rdn();k=rdn(); d=a[u];a[u]=k;
    printf("%d\n",calc());
    a[u]=d;
      }
  } 
}
int main()
{
  n=rdn();m=rdn();
  if(n<=10){S1::solve();return 0;}
  if(n<=100){S2::solve();return 0;}
  return 0;
}
View Code

应该更大胆一点。结论是可以贪心做那个 DP 的过程,用栈维护现有的 A 的段,如果往后添一个 A 会使得栈顶段平均值 > 栈顶前一个段平均值,就把栈顶和它前面那个段合并起来;则合并后的段平均值比原来 “栈顶前面那个段” 的平均值大,不会使更前面不合法。这样就有 50 分了。

考虑每次有修改一个位置该怎么做。

一个很好的思路是预处理每个前缀、后缀的栈的样子(用主席树存各时刻的栈),询问的时候拼一下即可。

刚才那个贪心的过程,不是从前往后做而是从后往前做,做出来的栈的形态还是一样的。因为在一个 A 的最优划分中,任意一个 A 换一下所属的段都不会变优;如果是等价的话,会分成尽量多的段,所以从前往后还是从后往前与最后的形态无关。

预处理的东西用主席树存起来。自己的写法是线段树第 i 个位置存了第 i 个段的右/左端点和平均值,区间存的是区间里最后一个/第一个段的信息。

如果知道修改的这个位置所属的段是 [ L , R ] ,那么 [ 1 , L-1 ] 部分的划分就是预处理出的那个,[ R+1 , n ] 的划分也是预处理出的,所以找一下 [ L , R ] 是哪即可。

找 [ L , R ] 可以在线段树上二分。设修改位置是 qi ,先找 R ,在表示 [ qi+1 , n ] 这个后缀的线段树上二分(就是每次看一下 mid+1 是否可行),如果 mid+1 可行,就进左孩子里找,因为段数越多越优;可行的意思是 mid+1 这个段作为 R 后面的第一个段是否满足 “不降” 的要求。

  固定一个 R ,可以一样地在线段树上找到它对应的 L ,就是在表示 [ 1 , qi-1 ] 的线段树上二分,看 mid 作为 L 前面的第一个段是否可行;mid 是 L 前面的第一个段的话, mid 这段的右端点就是 L ,又知 R ,就可以求出 qi 所在段的平均值,看看是不是比 mid 这段大于等于即可。如果 mid 可行,就尝试去右孩子找(如果找不到还是要返回 mid )。

所以判断 “ mid+1 这个段作为 R 后面的第一个段是否可行 ” 的流程就是先找出 R (此时的 R 就是 mid+1 这段左端点的前一个位置)对应的 L ,则已知 [ L , R ] 的平均值,看看是不是比 mid+1 这一段的平均值小。

算答案的时候别直接用原来的式子枚举 [ L , R ] 地算,用那个 \( \sum A_i^2 - 2*B\sum A_i + len*B^2 \) ,预处理 A 的前缀和、 A2 的前缀和即可。

#include<cstdio>
#include<cstring>
#include<algorithm>
#define db double
#define ll long long
#define ls Ls[cr]
#define rs Rs[cr]
using namespace std;
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
const int N=1e5+5,M=5e6+5,mod=998244353;
int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;}
int pw(int x,int k)
{int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;}
int Sqr(int x){return (ll)x*x%mod;}

int n,a[N],a2[N],qi,qk,qk2,s2[N];ll s[N];
struct Node{ int p;db v; Node(int p=0,db v=0):p(p),v(v) {} }tI;
struct Dt{ int x,y; Dt(int x=0,int y=0):x(x),y(y) {} }dI;
db cal(int l,int r)
{
  ll ret=s[r]-s[l-1]; if(qi>=l&&qi<=r)ret+=qk-a[qi];
  return (db)ret/(r-l+1);
}
int cal2(int l,int r)
{
  ll d=s[r]-s[l-1]; if(qi>=l&&qi<=r)d+=qk-a[qi];
  ll x=d%mod*pw(r-l+1,mod-2)%mod; d%=mod;//d%=mod
  int ret=upt(s2[r]-s2[l-1]);
  if(qi>=l&&qi<=r)ret=upt(ret+qk2-a2[qi]);
  ret=upt(ret-2*x*d%mod);
  ret=(ret+(ll)(r-l+1)*x%mod*x)%mod;
  return ret;
}
namespace P{
  int ct[N],ans[N];
  int tot,rt[N],Ls[M],Rs[M],dfn[M],tim; bool tg[M];
  Node vl[M],I;
  int nwnd(int pr)
  {
    if(pr&&dfn[pr]==tim)return pr;
    int cr=++tot; dfn[cr]=tim; if(!pr)return cr;
    ls=Ls[pr]; rs=Rs[pr]; vl[cr]=vl[pr]; tg[cr]=tg[pr];
    return cr;
  }
  void pshd(int cr)
  {
    if(!tg[cr])return; tg[cr]=0;
    ls=nwnd(ls); rs=nwnd(rs); vl[ls]=vl[rs]=0; tg[ls]=tg[rs]=1;
  }
  void pshp(int cr){if(vl[rs].p)vl[cr]=vl[rs]; else vl[cr]=vl[ls];}
  void ins(int l,int r,int &cr,int p,Node k)
  {
    cr=nwnd(cr); if(l==r){vl[cr]=k;return;}
    int mid=l+r>>1; pshd(cr);
    if(p<=mid)ins(l,mid,ls,p,k); else ins(mid+1,r,rs,p,k);
    pshp(cr);
  }
  void mdfy(int l,int r,int &cr,int L,int R)
  {
    cr=nwnd(cr); if(l>=L&&r<=R){vl[cr]=I;tg[cr]=1;return;}
    int mid=l+r>>1; pshd(cr);
    if(L<=mid)mdfy(l,mid,ls,L,R); if(mid<R)mdfy(mid+1,r,rs,L,R);
    pshp(cr);
  }
  Dt qry(int l,int r,int cr,int R,int p)//no !cr appear
  {
    if(l==r)
      {
    if(cal(vl[cr].p+1,p)>=vl[cr].v)return Dt(l,vl[cr].p);
    else return dI;
      }
    int mid=l+r>>1; pshd(cr);
    if(mid>=R)return qry(l,mid,ls,R,p);
    if(cal(vl[ls].p+1,p)>=vl[ls].v)//mid is ok
      {
    Dt d=qry(mid+1,r,rs,R,p);
    if(d.y)return d; else return Dt(mid,vl[ls].p);
      }
    else return qry(l,mid,ls,R,p);
  }
  int qryx(int l,int r,int cr,int R,int p)
  {
    if(l==r)
      {
    if(cal(vl[cr].p+1,p)>=vl[cr].v)return vl[cr].p;
    else return 0;
      }
    int mid=l+r>>1; pshd(cr);
    if(mid>=R)return qryx(l,mid,ls,R,p);
    if(cal(vl[ls].p+1,p)>=vl[ls].v)//mid is ok
      {
    int d=qryx(mid+1,r,rs,R,p);
    if(d)return d; else return vl[ls].p;
      }
    else return qryx(l,mid,ls,R,p);
  }
  void solve()
  {
    ins(0,n,rt[0],0,I); Dt d;
    for(int i=1;i<=n;i++)
      {
    tim++; rt[i]=nwnd(rt[i-1]);
    d=qry(0,n,rt[i],ct[i-1],i);
    if(d.x<ct[i-1])mdfy(0,n,rt[i],d.x+1,ct[i-1]);
    ct[i]=d.x+1;
    ins(0,n,rt[i],ct[i],Node(i,cal(d.y+1,i)));
    ans[i]=upt(ans[d.y]+cal2(d.y+1,i));
      }
  }
  int qryx(int p){ return qryx(0,n,rt[qi-1],ct[qi-1],p);}
};
namespace S{
  const db INF=1e9+5;
  int ct[N],ans[N],lm;
  int tot,rt[N],Ls[M],Rs[M],dfn[M],tim; bool tg[M];
  Node vl[M],I;
  int nwnd(int pr)
  {
    if(pr&&dfn[pr]==tim)return pr;
    int cr=++tot; dfn[cr]=tim; if(!pr)return cr;
    ls=Ls[pr]; rs=Rs[pr]; vl[cr]=vl[pr]; tg[cr]=tg[pr];
    return cr;
  }
  void pshd(int cr)
  {
    if(!tg[cr])return; tg[cr]=0;
    ls=nwnd(ls); rs=nwnd(rs); vl[ls]=vl[rs]=0; tg[ls]=tg[rs]=1;
  }
  void pshp(int cr){if(vl[ls].p)vl[cr]=vl[ls]; else vl[cr]=vl[rs];}
  void ins(int l,int r,int &cr,int p,Node k)
  {
    cr=nwnd(cr); if(l==r){vl[cr]=k;return;}
    int mid=l+r>>1; pshd(cr);
    if(p<=mid)ins(l,mid,ls,p,k); else ins(mid+1,r,rs,p,k);
    pshp(cr);
  }
  void mdfy(int l,int r,int &cr,int L,int R)
  {
    cr=nwnd(cr); if(l>=L&&r<=R){vl[cr]=I;tg[cr]=1;return;}
    int mid=l+r>>1; pshd(cr);
    if(L<=mid)mdfy(l,mid,ls,L,R); if(mid<R)mdfy(mid+1,r,rs,L,R);
    pshp(cr);
  }
  Dt qry(int l,int r,int cr,int L,int p)
  {
    if(l==r)
      {
    if(cal(p,vl[cr].p-1)<=vl[cr].v) return Dt(l,vl[cr].p);
    else return dI;
      }
    int mid=l+r>>1; pshd(cr);
    if(mid<L)return qry(mid+1,r,rs,L,p);
    if(cal(p,vl[rs].p-1)<=vl[rs].v)//mid+1 is ok
      {
    Dt d=qry(l,mid,ls,L,p);
    if(d.y)return d; else return Dt(mid+1,vl[rs].p);
      }
    else return qry(mid+1,r,rs,L,p);
  }
  Dt qryx(int l,int r,int cr,int L)
  {
    if(l==r)
      {
    int d=P::qryx(vl[cr].p-1);
    if(cal(d+1,vl[cr].p-1)<=vl[cr].v)return Dt(d,vl[cr].p);
    else return dI;
      }
    int mid=l+r>>1; pshd(cr);
    if(mid<L)return qryx(mid+1,r,rs,L);
    int d=P::qryx(vl[rs].p-1);
    if(cal(d+1,vl[rs].p-1)<=vl[rs].v)//mid+1 is ok
      {
    Dt ret=qryx(l,mid,ls,L);
    if(ret.y)return ret; else return Dt(d,vl[rs].p);//.y for .x can be 0
      }
    else return qryx(mid+1,r,rs,L);
  }
  void solve()
  {
    I=Node(n+1,INF); lm=n+1;
    ins(1,lm,rt[lm],lm,I); ct[lm]=lm; Dt d;
    for(int i=n;i;i--)
      {
    tim++; rt[i]=nwnd(rt[i+1]);
    d=qry(1,lm,rt[i],ct[i+1],i);
    if(d.x>ct[i+1])mdfy(1,lm,rt[i],ct[i+1],d.x-1);
    ct[i]=d.x-1;
    ins(1,lm,rt[i],ct[i],Node(i,cal(i,d.y-1)));
    ans[i]=upt(ans[d.y]+cal2(i,d.y-1));
      }
  }
  Dt qryx(){ return qryx(1,lm,rt[qi+1],ct[qi+1]);}
}
int main()
{
  n=rdn();int m=rdn();
  for(int i=1;i<=n;i++)
    { a[i]=rdn(); a2[i]=(ll)a[i]*a[i]%mod;
      s[i]=s[i-1]+a[i]; s2[i]=(s2[i-1]+(ll)a[i]*a[i])%mod;}
  P::solve(); S::solve();
  printf("%d\n",P::ans[n]);
  while(m--)
    {
      qi=rdn(); qk=rdn(); qk2=(ll)qk*qk%mod;
      Dt d=S::qryx();
      int ans=upt(P::ans[d.x]+S::ans[d.y]);
      ans=upt(ans+cal2(d.x+1,d.y-1));
      printf("%d\n",ans);
    }
  return 0;
}

 

posted on 2019-04-18 13:26  Narh  阅读(356)  评论(0编辑  收藏  举报

导航