语法研究:自适应大小bitset
在 CF1856E2 PermuTree (hard version) 一题中,有一种解法是使用了bitset来维护相关背包。这里不赘述原题,仅需要知道我们想要一个能够满足大小与给定值成正比的bitset,这样才能保证复杂度的正确性。
先上代码再解释。
点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int maxn = 1 << 20; // 举例,最大 2^20 位
// 模板分派:根据 n 选择合适的 len
template<int len>
void fun(bitset<len> &bits,int n){
for (int i = 0; i < n; ++i) {
if (i % 2 == 0) bits.set(i);
}//上为操作示例:赋值
std::cout << "logical n = " << n << "\n";
std::cout << "bitset size (len) = " << bits.size() << "\n";
}
template<int len = 1>
void dobits(int n) {
if (n > len) return dobits<(len*2>maxn)?maxn:len*2>(n);
// 走到这里时:len >= n 且 len 是某个 2^k(或 maxn)
bitset<len> bits;
fun<len>(bits, n); // 把 bitset 和逻辑长度 n 交给回调
}
int main() {
int k = 1; // 逻辑上只需要前 k 位
cin>>k;
dobits(k);
}
原理:编译器会找出所有可能被调用的实例化,比如dobits<1>(n),dobits<2>(n),dobits<4>(n)...dobits<maxn>(n),调用的时候会递归直到满足条件。
由于这些len都是编译期就可以确定的常量,因此bitset
但是可惜的是,我们创造的这个bitset只能活在函数dobits里而无法调用或引出,因此想要发挥作用就必须用另一个函数fun来接受其参数并进行你想要的操作。比如这里是赋值所有偶数位。
警示:不能将dobits进行如下修改:
点击查看代码
template<int len = 1>
void dobits(int n) {
if (n > len) {
if (len * 2 <= maxn) {
dobits<len * 2>(n); // 继续放大
} else {
dobits<maxn>(n); // 封顶
}
return;
}
std::bitset<len> bits;
fun(bits, n);
}
由于c++14还未支持if constexpr,因此编译器会认为两个分支都有可能被访问。使得len=maxn时,继续产生len=2*maxn,len=4*maxn这样的类型,无限递归,编译失败。
如果你愿意,将bitset
这里附上一份CF1856E2这题的代码。如果不使用这种技巧,你将被迫写近20个bitset并手工判定处理(想想就吓人)。
点击查看代码
#include<bits/stdc++.h>
#define min(x,y) ((x)<(y)?(x):(y))
using namespace std;
typedef long long ll;
const int N=1e6+5;
struct node{
int nxt,to;
}e[N];
int num,head[N],n,son[N],siz[N],c[N],cnt,a[N],tot;
ll ans;
void add(int x,int y){
e[++num]=node{head[x],y};head[x]=num;
}
template<int len>
void fun(bitset<len> &f,bitset<len> &tmp){
f.reset();f[0]=1;
for(int i=1;i<=cnt;i++)
tmp=f<<a[i],f|=tmp;ll ss=0;
for(int i=0;i<=tot/2;i++)
if(f[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
template<int len=1>
void dobits(int ndlen){
if(ndlen>=len) return dobits<min(len*2,N)>(ndlen);
bitset<len> f,tmp;fun<len>(f,tmp);
}
void dfs(int u){
siz[u]=1;vector<int> q;map<int,int> c;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
dfs(v);
siz[u]+=siz[v];
if(!son[u]||siz[son[u]]<siz[v])
son[u]=v;
if(!c[siz[v]])q.push_back(siz[v]);
c[siz[v]]++;
}
tot=(siz[u]-1);
cnt=0;
for(int i=0;i<q.size();i++){
int w=q[i],v=c[q[i]];
for(int j=1;j<=v;j<<=1)
v-=j,a[++cnt]=j*w;
if(v)a[++cnt]=v*w;
c[q[i]]=0;
}
if(!son[u])return;
if(siz[son[u]]>=tot/2)ans+=1ll*siz[son[u]]*(tot-siz[son[u]]);
else dobits(siz[u]);
}
int main(){
scanf("%d",&n);
for(int i=2;i<=n;i++){
int x;scanf("%d",&x);
add(x,i);
}
dfs(1);
cout<<ans<<endl;
return 0;
}
点击查看代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e6+5;
struct node{
int nxt,to;
}e[N];
int num,head[N],n,son[N],siz[N],c[N],cnt,a[N],tot;
ll ans;
void add(int x,int y){
e[++num]=node{head[x],y};head[x]=num;
}
bitset<64> f,tmp;
void solve0(int x){
f.reset();f[0]=1;
for(int i=1;i<=cnt;i++)
tmp=f<<a[i],f|=tmp;ll ss=0;
for(int i=0;i<=tot/2;i++)
if(f[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<128> f1,tmp1;
void solve1(int x){
f1.reset();f1[0]=1;
for(int i=1;i<=cnt;i++)
tmp1=f1<<a[i],f1|=tmp1;ll ss=0;
for(int i=0;i<=tot/2;i++)
if(f1[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<256> f2,tmp2;
void solve2(int x){
f2.reset();f2[0]=1;
for(int i=1;i<=cnt;i++)
tmp2=f2<<a[i],f2|=tmp2;ll ss=0;
for(int i=0;i<=tot/2;i++)
if(f2[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<512> f3,tmp3;
void solve3(int x){
f3.reset();f3[0]=1;
for(int i=1;i<=cnt;i++)
tmp3=f3<<a[i],f3|=tmp3;ll ss=0;
for(int i=0;i<=tot/2;i++)
if(f3[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<1024> f4,tmp4;
void solve4(int x){
f4.reset();f4[0]=1;
for(int i=1;i<=cnt;i++)
tmp4=f4<<a[i],f4|=tmp4;ll ss=0;
for(int i=0;i<=tot/2;i++)
if(f4[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<2048> f5,tmp5;
void solve5(int x){
f5.reset();f5[0]=1;
for(int i=1;i<=cnt;i++)
tmp5=f5<<a[i],f5|=tmp5;ll ss=0;
for(int i=0;i<=tot/2;i++)
if(f5[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<4096> f6,tmp6;
void solve6(int x){
f6.reset();f6[0]=1;
for(int i=1;i<=cnt;i++)
tmp6=f6<<a[i],f6|=tmp6;ll ss=0;
for(int i=0;i<=tot/2;i++)
if(f6[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<8192> f7,tmp7;
void solve7(int x){
f7.reset();f7[0]=1;
for(int i=1;i<=cnt;i++)
tmp7=f7<<a[i],f7|=tmp7;ll ss=0;
for(int i=0;i<=tot/2;i++)
if(f7[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<16384> f8,tmp8;
void solve8(int x){
f8.reset();f8[0]=1;
for(int i=1;i<=cnt;i++)
tmp8=f8<<a[i],f8|=tmp8;ll ss=0;
for(int i=0;i<=tot/2;i++)
if(f8[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<32768> f9,tmp9;
void solve9(int x){
f9.reset();f9[0]=1;
for(int i=1;i<=cnt;i++)
tmp9=f9<<a[i],f9|=tmp9;ll ss=0;
for(int i=0;i<=tot/2;i++)
if(f9[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<65536> f10,tmp10;
void solve10(int x){
f10.reset();f10[0]=1;
for(int i=1;i<=cnt;i++)
tmp10=f10<<a[i],f10|=tmp10;ll ss=0;
for(int i=0;i<=tot/2;i++)
if(f10[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<131072> f11,tmp11;
void solve11(int x){
f11.reset();f11[0]=1;
for(int i=1;i<=cnt;i++)
tmp11=f11<<a[i],f11|=tmp11;ll ss=0;
for(int i=0;i<=tot/2;i++)
if(f11[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<262144> f12,tmp12;
void solve12(int x){
f12.reset();f12[0]=1;
for(int i=1;i<=cnt;i++)
tmp12=f12<<a[i],f12|=tmp12;ll ss=0;
for(int i=0;i<=tot/2;i++)
if(f12[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<524288> f13,tmp13;
void solve13(int x){
f13.reset();f13[0]=1;
for(int i=1;i<=cnt;i++)
tmp13=f13<<a[i],f13|=tmp13;ll ss=0;
for(int i=0;i<=tot/2;i++)
if(f13[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
void dfs(int u){
siz[u]=1;vector<int> q;map<int,int> c;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
dfs(v);
siz[u]+=siz[v];
if(!son[u]||siz[son[u]]<siz[v])
son[u]=v;
if(!c[siz[v]])q.push_back(siz[v]);
c[siz[v]]++;
}
tot=(siz[u]-1);
cnt=0;
for(int i=0;i<q.size();i++){
int w=q[i],v=c[q[i]];
for(int j=1;j<=v;j<<=1)
v-=j,a[++cnt]=j*w;
if(v)a[++cnt]=v*w;
c[q[i]]=0;
}
if(!son[u])return;
if(siz[son[u]]>=tot/2)ans+=1ll*siz[son[u]]*(tot-siz[son[u]]);
else{
if(siz[u]<=128)solve0(siz[son[u]]);
else if(siz[u]<=256)solve1(siz[son[u]]);
else if(siz[u]<=512)solve2(siz[son[u]]);
else if(siz[u]<=1024)solve3(siz[son[u]]);
else if(siz[u]<=2048)solve4(siz[son[u]]);
else if(siz[u]<=4096)solve5(siz[son[u]]);
else if(siz[u]<=8192)solve6(siz[son[u]]);
else if(siz[u]<=16384)solve7(siz[son[u]]);
else if(siz[u]<=32768)solve8(siz[son[u]]);
else if(siz[u]<=65536)solve9(siz[son[u]]);
else if(siz[u]<=131072)solve10(siz[son[u]]);
else if(siz[u]<=262144)solve11(siz[son[u]]);
else if(siz[u]<=524288)solve12(siz[son[u]]);
else solve13(siz[son[u]]);
}
}
int main(){
scanf("%d",&n);
for(int i=2;i<=n;i++){
int x;scanf("%d",&x);
add(x,i);
}
dfs(1);
cout<<ans<<endl;
return 0;
}

浙公网安备 33010602011771号