[ARC126E] Infinite Operations

不妨把 \(a\) 排序。

考虑一个特殊情况:\(a_1=a_2=\cdots=a_{n-1}=0\)\(a_n=x\)。不妨设此时答案为 \(F(n,x)\)

可以递归把 \(a_2,a_3,\cdots,a_{n}\) 全部变为 \(\dfrac{x}{n-1}\),然后全部取相反数后就是相同问题。可以归纳证明 \(F(n,x)\) 的下界是 \(\dfrac{(n-1)x}{2}\)

对于一般情况,考虑依次加入每个元素,然后把所有数操作到相同,然后就是特殊情况。加入 \(a_i\) 时,前 \(i-1\) 个数都是 \(\dfrac{pre_{i-1}}{i-1}\),其中 \(pre\)\(a\) 的前缀和数组。那么加入 \(i\) 的贡献就是 \(\dfrac{i-1}{2}\times(a_i-\dfrac{pre_{i-1}}{i-1})\)

经过一些化简后,这种方案构造出的收益可以表示为 \(\sum\limits_{i=1}^na_i\times i-\dfrac{n+1}{2}\times\sum a\)。可以离散化后树状数组维护做到 \((n+q)\log(n+q)\)

对于这个方案最优的证明,可以考虑设置初始势能 \(\sum\limits_{i=1}^na_i\times i\),那么每次获得 \(x\) 的收益时,势能总会减少 \(\ge x\)。并且有结束时势能为 \(\dfrac{n+1}{2}\times\sum a\),因此答案也不会超过两者的差。

#include <bits/stdc++.h>
#define ALL(x) begin(x), end(x)
using namespace std;
void file() {
  freopen("1.in", "r", stdin);
  freopen("1.out", "w", stdout);
}
using ll = long long;

namespace QwQ {
  const int kMod = 998244353;
  const int inv2 = (kMod + 1) / 2;

  void Add(int& x, int y) { ((x += y) >= kMod) && (x -= kMod); }
  void Sub(int& x, int y) { ((x -= y) < 0) && (x += kMod); }
  int Sum(int x, int y) { return Add(x, y), x; }
  int Dif(int x, int y) { return Sub(x, y), x; }
  
  const int kN = 6e5 + 5;
  int n, q, o;
  array<int, kN> a, b, ord;

  struct Upd {
    int x, y;
    Upd() {  }
  };
  array<Upd, kN> upd;

  struct BIT {
    array<ll, kN> tr;
    void update(int x, ll v) {
      for(; x < kN; x += (x & -x))
        tr[x] += v;
    }
    ll query(int x) {
      ll ans = 0;
      for(; x; x &= (x - 1))
        ans += tr[x];
      return ans;
    }
  }bit1, bit2;

  int main() {
    // file();
    ios::sync_with_stdio(0), cin.tie(0);
    cin >> n >> q;
    for(int i = 1; i <= n; i++) {
      cin >> a[i];
      b[++o] = a[i];
    }
    for(int i = 1; i <= q; i++) {
      cin >> upd[i].x >> upd[i].y;
      b[++o] = upd[i].y;
    }
    sort(b.data() + 1, b.data() + o + 1);
    o = unique(b.data() + 1, b.data() + o + 1) - b.data() - 1;
    auto find = [&](int x) -> int {
      return lower_bound(b.data() + 1, b.data() + o + 1, x) - b.data();
    };
    for(int i = 1; i <= n; i++)
      a[i] = find(a[i]);
    for(int i = 1; i <= q; i++)
      upd[i].y = find(upd[i].y);
    for(int i = 1; i <= o; i++)
      if(b[i] >= kMod) b[i] -= kMod;
    int sum = 0, ans = 0;
    iota(ord.data() + 1, ord.data() + n + 1, 1);
    sort(ord.data() + 1, ord.data() + n + 1,
      [&](int x, int y) -> bool {
        return a[x] < a[y];
      }
    );
    for(int i = 1; i <= n; i++) {
      int x = ord[i];
      bit1.update(a[x], 1);
      bit2.update(a[x], b[a[x]]);
      Add(sum, b[a[x]]);
      Add(ans, (ll)i * b[a[x]] % kMod);
    }
    for(int i = 1; i <= q; i++) {
      int x = upd[i].x, y = upd[i].y;
      Sub(ans, bit1.query(a[x]) * b[a[x]] % kMod);
      Sub(ans, Dif(sum, bit2.query(a[x]) % kMod));
      bit1.update(a[x], -1);
      bit2.update(a[x], -b[a[x]]);
      Add(sum, Dif(b[y], b[a[x]])), a[x] = y;
      bit1.update(a[x], 1);
      bit2.update(a[x], b[a[x]]);
      Add(ans, bit1.query(a[x]) * b[a[x]] % kMod);
      Add(ans, Dif(sum, bit2.query(a[x]) % kMod));
      cout << ((ans * 2ll - (ll)sum * (n + 1)) % kMod + kMod) * inv2 % kMod << "\n";
    }
    return 0;
  }
} // QwQ

int main() { return QwQ::main(); }
posted @ 2024-11-26 20:00  CJzdc  阅读(38)  评论(0)    收藏  举报