[题解]AT_arc116_b [ARC116B] Products of Min-Max

思路

我们容易可以得到一个朴素的做法,首先对 \(a\) 数组排序,然后枚举最大值和最小值 \(a_i,a_j\),那么对于中间的元素都有选与不选两种情况,得到答案:

\[ \sum_{i = 1}^{n}(a_i \times a_i + (\sum_{j = i + 1}^{n}a_i \times a_j \times 2^{j - i - 1})) \]

然后对这个式子做一个化简:

\[ \sum_{i = 1}^{n}(a_i \times a_i + a_i \times (\sum_{j = i + 1}^{n}a_j \times 2^{j - i - 1})) \]

发现对于每一个 \(i\)\(a_j \times 2^{j - i - 1}\) 都是类似的,所以考虑预处理。

定义 \(m_i = \sum_{j = 1}^{i}(a_j \times 2^j)\),那么发现:

\[ m_n - m_i = \sum_{j = i + 1}^{n}{a_j}\times 2^j \]

然后,发现对于每一项 \(j\) 对于原式都多乘了一个 \(2^{i + 1}\),直接除掉即可。得答案为:

\[ \sum_{i = 1}^n{(a_i \times a_i + \frac{m_n - m_i}{2^{i + 1}} \times a_i)} \]

时间复杂度 \(\Theta(n \log n)\)

Code

#include <bits/stdc++.h>  
#define int long long  
#define re register  
  
using namespace std;  
  
const int N = 2e5 + 10,mod = 998244353;  
int n,ans;  
int arr[N],pot[N],mul[N],inv[N];  
  
inline int read(){  
    int r = 0,w = 1;  
    char c = getchar();  
    while (c < '0' || c > '9'){  
        if (c == '-') w = -1;  
        c = getchar();  
    }  
    while (c >= '0' && c <= '9'){  
        r = (r << 3) + (r << 1) + (c ^ 48);  
        c = getchar();  
    }  
    return r * w;  
}  
  
inline int Add(int a,int b){  
    return (a + b) % mod;  
}  
  
inline int Sub(int a,int b){  
    return ((a - b) % mod + mod) % mod;  
}  
  
inline int Mul(int a,int b){  
    return a * b % mod;  
}  
  
inline void exgcd(int a,int b,int &x,int &y){  
    if (!b){  
        x = 1;  
        y = 0;  
        return;  
    }  
    exgcd(b,a % b,y,x);  
    y = y - a / b * x;  
}  
  
inline int get_inv(int a,int p){  
    int x,y;  
    exgcd(a,p,x,y);  
    return (x % mod + mod) % mod;  
}  
  
inline void init(){  
    pot[0] = 1;  
    for (re int i = 1;i <= n + 1;i++){  
        pot[i] = Mul(pot[i - 1],2);  
        mul[i] = Add(mul[i - 1],Mul(arr[i],pot[i]));  
        inv[i] = get_inv(pot[i],mod);  
    }  
}  
  
signed main(){  
    n = read();  
    for (re int i = 1;i <= n;i++) arr[i] = read();  
    sort(arr + 1,arr + n + 1);  
    init();  
    for (re int i = 1;i <= n;i++){  
        ans = Add(ans,Mul(Mul(Sub(mul[n],mul[i]),inv[i + 1]),arr[i]));  
        ans = Add(ans,Mul(arr[i],arr[i]));  
    }  
    printf("%lld",ans);  
    return 0;  
}  
posted @ 2024-06-23 12:59  WBIKPS  阅读(26)  评论(0)    收藏  举报