[题解]AT_abc236_f [ABC236F] Spices

思路

首先对所有的 \(c\) 从小到大排序,然后对于每一个值如果之前能凑出就不选,否则就选。

这样做显然是对的。令 \(p_1,p_2,\dots,p_{2^n-1}\) 表示将 \(c\) 排序之后,对应原来的下标;\(S\) 表示选出数的集合;\(S'\) 表示最终选出数的集合。可以证明两个问题:

  1. 如果 \(p_i\) 可以被已选出数凑出,则不需要选 \(p_i\)
  2. 如果 \(p_i\) 不可以被已选出的数凑出,则选 \(p_i\) 最优。

对于第一个问题,我们总可以选出 \(x_1,x_2,\dots,x_k \in S\),使得 \(x_1 \oplus x_2 \oplus \dots \oplus x_k = p_i\)

如果 \(p_i \in S'\),并且能选出 \(p_i,y_1,y_2,\dots,y_q \in S'\),使得 \(z = p_i \oplus y_1 \oplus y_2 \oplus \dots \oplus y_q\),那么一定有 \(z = x_1 \oplus x_2 \oplus \dots \oplus x_k \oplus y_1 \oplus y_2 \oplus \dots \oplus y_q\)

所以选定 \(p_i\) 不是最优的方式。


对于第二个问题,显然会存在 \(p_i \not\in S'\),并且有 \(x_1,x_2,\dots,x_k \in S'\),使得 \(x_1 \oplus x_2 \oplus \dots \oplus x_k = p_i\),即 \(x_1 = p_i \oplus x_2 \oplus \dots \oplus x_k\),即 \(p_i \oplus x_2 \oplus \dots \oplus x_k = x_1\)。又因为 \(p_i\) 不能被凑出,所以 \(x_1,x_2,\dots,x_k\) 中一定有一个元素不在 \(S\) 中。

那么,对于所有的 \(z \in [1,2^n)\) 都存在 \(y_1,y_2,\dots,y_q \in S'\),使得 \(y_1 \oplus y_2 \oplus \dots \oplus y_q = z\),这里假令 \(x_1 = y_1\)。那么有:

\[ z = x_1 \oplus y_2 \oplus y_3 \oplus \dots \oplus y_q = p_i \oplus x_2 \oplus \dots \oplus x_k \oplus y_2 \oplus y_3 \oplus \dots \oplus y_q \]

所以即使 \(S'\) 中没有 \(x_1\),但加上 \(p_i\) 依旧能使得条件成立。

又因为此时 \(S\) 中没有 \(x_1\),所以 \(c_{p_i} \leq c_{x_1}\),因此选 \(p_i\) 更优。


但是这个复杂度看似是 \(\Theta(m^2)\) 的,其中 \(m = 2^n\)。但是其实是 \(\Theta(nm)\) 的。

不难发现最多选出 \(n\) 个数就能将 \([1,2^n)\) 中的所有数凑齐。

Code

#include <bits/stdc++.h>  
#define re register  
#define int long long  
  
using namespace std;  
  
const int N = 1e5 + 10;  
int n,m,ans;  
bool vis[N];  
  
struct point{  
    int x,id;  
  
    friend bool operator <(const point &a,const point &b){  
        return a.x < b.x;  
    }  
}arr[N];  
  
inline int read(){  
    int r = 0,w = 1;  
    char c = getchar();  
    while (c < '0' || c > '9'){  
        if (c == '-') w = -1;  
        c = getchar();  
    }  
    while (c >= '0' && c <= '9'){  
        r = (r << 3) + (r << 1) + (c ^ 48);  
        c = getchar();  
    }  
    return r * w;  
}  
  
signed main(){  
    n = read();  
    m = (1ll << n) - 1;  
    for (re int i = 1;i <= m;i++){  
        arr[i].x = read();  
        arr[i].id = i;  
    }  
    sort(arr + 1,arr + m + 1);  
    for (re int i = 1;i <= m;i++){  
        if (vis[arr[i].id]) continue;  
        ans += arr[i].x;  
        vis[arr[i].id] = true;  
        for (re int j = 1;j <= m;j++) vis[j ^ arr[i].id] |= vis[j];  
    }  
    printf("%lld",ans);  
    return 0;  
}  
posted @ 2024-06-22 10:55  WBIKPS  阅读(15)  评论(0)    收藏  举报