语法研究:自适应大小bitset

在 CF1856E2 PermuTree (hard version) 一题中,有一种解法是使用了bitset来维护相关背包。这里不赘述原题,仅需要知道我们想要一个能够满足大小与给定值成正比的bitset,这样才能保证复杂度的正确性。

先上代码再解释。

点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int maxn = 1 << 20; // 举例,最大 2^20 位
// 模板分派:根据 n 选择合适的 len
template<int len>
void fun(bitset<len> &bits,int n){
	for (int i = 0; i < n; ++i) {
		if (i % 2 == 0) bits.set(i);
	}//上为操作示例:赋值
	std::cout << "logical n = " << n << "\n";
	std::cout << "bitset size (len) = " << bits.size() << "\n";
}
template<int len = 1>
void dobits(int n) {
	if (n > len) return dobits<(len*2>maxn)?maxn:len*2>(n);
	// 走到这里时:len >= n 且 len 是某个 2^k(或 maxn)
	bitset<len> bits;
	fun<len>(bits, n);  // 把 bitset 和逻辑长度 n 交给回调
}
int main() {
	int k = 1; // 逻辑上只需要前 k 位
	cin>>k;
	dobits(k);
}


原理:编译器会找出所有可能被调用的实例化,比如dobits<1>(n),dobits<2>(n),dobits<4>(n)...dobits<maxn>(n),调用的时候会递归直到满足条件。
由于这些len都是编译期就可以确定的常量,因此bitset就是一个合法的声明。

但是可惜的是,我们创造的这个bitset只能活在函数dobits里而无法调用或引出,因此想要发挥作用就必须用另一个函数fun来接受其参数并进行你想要的操作。比如这里是赋值所有偶数位。

警示:不能将dobits进行如下修改:

点击查看代码
template<int len = 1>
void dobits(int n) {
    if (n > len) {
        if (len * 2 <= maxn) {
            dobits<len * 2>(n);  // 继续放大
        } else {
            dobits<maxn>(n);     // 封顶
        }
        return;
    }
    std::bitset<len> bits;
    fun(bits, n);
}

由于c++14还未支持if constexpr,因此编译器会认为两个分支都有可能被访问。使得len=maxn时,继续产生len=2*maxn,len=4*maxn这样的类型,无限递归,编译失败。

如果你愿意,将bitset换成vector<bitset<len>> 也是可以的,这样你就不止开了一个bitset,注意要对应的将fun也进行修改。

这里附上一份CF1856E2这题的代码。如果不使用这种技巧,你将被迫写近20个bitset并手工判定处理(想想就吓人)。

点击查看代码
#include<bits/stdc++.h>
#define min(x,y) ((x)<(y)?(x):(y))
using namespace std;
typedef long long ll;
const int N=1e6+5;
struct node{
	int nxt,to;
}e[N];
int num,head[N],n,son[N],siz[N],c[N],cnt,a[N],tot;
ll ans;
void add(int x,int y){
	e[++num]=node{head[x],y};head[x]=num;
}
template<int len>
void fun(bitset<len> &f,bitset<len> &tmp){
	f.reset();f[0]=1;
	for(int i=1;i<=cnt;i++)
	tmp=f<<a[i],f|=tmp;ll ss=0;
	for(int i=0;i<=tot/2;i++)
	if(f[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
template<int len=1>
void dobits(int ndlen){
	if(ndlen>=len) return dobits<min(len*2,N)>(ndlen);
	bitset<len> f,tmp;fun<len>(f,tmp);
}
void dfs(int u){
	siz[u]=1;vector<int> q;map<int,int> c;
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;
		dfs(v);
		siz[u]+=siz[v];
		if(!son[u]||siz[son[u]]<siz[v])
		son[u]=v;
		if(!c[siz[v]])q.push_back(siz[v]);
		c[siz[v]]++;
	}
	tot=(siz[u]-1);
	cnt=0;
	for(int i=0;i<q.size();i++){
		int w=q[i],v=c[q[i]];
		for(int j=1;j<=v;j<<=1)
		v-=j,a[++cnt]=j*w;
		if(v)a[++cnt]=v*w;
		c[q[i]]=0;
	}
	if(!son[u])return;
	if(siz[son[u]]>=tot/2)ans+=1ll*siz[son[u]]*(tot-siz[son[u]]);
	else dobits(siz[u]);
}
int main(){
	scanf("%d",&n);
	for(int i=2;i<=n;i++){
		int x;scanf("%d",&x);
		add(x,i);
	}
	dfs(1);
	cout<<ans<<endl;
	return 0;
}
下面是手工分讨代码:
点击查看代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e6+5;
struct node{
	int nxt,to;
}e[N];
int num,head[N],n,son[N],siz[N],c[N],cnt,a[N],tot;
ll ans;
void add(int x,int y){
	e[++num]=node{head[x],y};head[x]=num;
}
bitset<64>  f,tmp;
void solve0(int x){
	f.reset();f[0]=1;
	for(int i=1;i<=cnt;i++)
	tmp=f<<a[i],f|=tmp;ll ss=0;
	for(int i=0;i<=tot/2;i++)
	if(f[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<128>  f1,tmp1;
void solve1(int x){
	f1.reset();f1[0]=1;
	for(int i=1;i<=cnt;i++)
	tmp1=f1<<a[i],f1|=tmp1;ll ss=0;
	for(int i=0;i<=tot/2;i++)
	if(f1[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<256>  f2,tmp2;
void solve2(int x){
	f2.reset();f2[0]=1;
	for(int i=1;i<=cnt;i++)
	tmp2=f2<<a[i],f2|=tmp2;ll ss=0;
	for(int i=0;i<=tot/2;i++)
	if(f2[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<512>  f3,tmp3;
void solve3(int x){
	f3.reset();f3[0]=1;
	for(int i=1;i<=cnt;i++)
	tmp3=f3<<a[i],f3|=tmp3;ll ss=0;
	for(int i=0;i<=tot/2;i++)
	if(f3[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<1024>  f4,tmp4;
void solve4(int x){
	f4.reset();f4[0]=1;
	for(int i=1;i<=cnt;i++)
	tmp4=f4<<a[i],f4|=tmp4;ll ss=0;
	for(int i=0;i<=tot/2;i++)
	if(f4[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<2048>  f5,tmp5;
void solve5(int x){
	f5.reset();f5[0]=1;
	for(int i=1;i<=cnt;i++)
	tmp5=f5<<a[i],f5|=tmp5;ll ss=0;
	for(int i=0;i<=tot/2;i++)
	if(f5[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<4096>  f6,tmp6;
void solve6(int x){
	f6.reset();f6[0]=1;
	for(int i=1;i<=cnt;i++)
	tmp6=f6<<a[i],f6|=tmp6;ll ss=0;
	for(int i=0;i<=tot/2;i++)
	if(f6[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<8192>  f7,tmp7;
void solve7(int x){
	f7.reset();f7[0]=1;
	for(int i=1;i<=cnt;i++)
	tmp7=f7<<a[i],f7|=tmp7;ll ss=0;
	for(int i=0;i<=tot/2;i++)
	if(f7[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<16384>  f8,tmp8;
void solve8(int x){
	f8.reset();f8[0]=1;
	for(int i=1;i<=cnt;i++)
	tmp8=f8<<a[i],f8|=tmp8;ll ss=0;
	for(int i=0;i<=tot/2;i++)
	if(f8[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<32768>  f9,tmp9;
void solve9(int x){
	f9.reset();f9[0]=1;
	for(int i=1;i<=cnt;i++)
	tmp9=f9<<a[i],f9|=tmp9;ll ss=0;
	for(int i=0;i<=tot/2;i++)
	if(f9[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<65536>  f10,tmp10;
void solve10(int x){
	f10.reset();f10[0]=1;
	for(int i=1;i<=cnt;i++)
	tmp10=f10<<a[i],f10|=tmp10;ll ss=0;
	for(int i=0;i<=tot/2;i++)
	if(f10[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<131072>  f11,tmp11;
void solve11(int x){
	f11.reset();f11[0]=1;
	for(int i=1;i<=cnt;i++)
	tmp11=f11<<a[i],f11|=tmp11;ll ss=0;
	for(int i=0;i<=tot/2;i++)
	if(f11[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<262144>  f12,tmp12;
void solve12(int x){
	f12.reset();f12[0]=1;
	for(int i=1;i<=cnt;i++)
	tmp12=f12<<a[i],f12|=tmp12;ll ss=0;
	for(int i=0;i<=tot/2;i++)
	if(f12[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
bitset<524288>  f13,tmp13;
void solve13(int x){
	f13.reset();f13[0]=1;
	for(int i=1;i<=cnt;i++)
	tmp13=f13<<a[i],f13|=tmp13;ll ss=0;
	for(int i=0;i<=tot/2;i++)
	if(f13[i])ss=max(ss,1ll*i*(tot-i));ans+=ss;
}
void dfs(int u){
	siz[u]=1;vector<int> q;map<int,int> c;
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;
		dfs(v);
		siz[u]+=siz[v];
		if(!son[u]||siz[son[u]]<siz[v])
		son[u]=v;
		if(!c[siz[v]])q.push_back(siz[v]);
		c[siz[v]]++;
	}
	tot=(siz[u]-1);
	cnt=0;
	for(int i=0;i<q.size();i++){
		int w=q[i],v=c[q[i]];
		for(int j=1;j<=v;j<<=1)
		v-=j,a[++cnt]=j*w;
		if(v)a[++cnt]=v*w;
		c[q[i]]=0;
	}
	if(!son[u])return;
	if(siz[son[u]]>=tot/2)ans+=1ll*siz[son[u]]*(tot-siz[son[u]]);
	else{
		if(siz[u]<=128)solve0(siz[son[u]]);
		else if(siz[u]<=256)solve1(siz[son[u]]);
		else if(siz[u]<=512)solve2(siz[son[u]]);
		else if(siz[u]<=1024)solve3(siz[son[u]]);
		else if(siz[u]<=2048)solve4(siz[son[u]]);
		else if(siz[u]<=4096)solve5(siz[son[u]]);
		else if(siz[u]<=8192)solve6(siz[son[u]]);
		else if(siz[u]<=16384)solve7(siz[son[u]]);
		else if(siz[u]<=32768)solve8(siz[son[u]]);
		else if(siz[u]<=65536)solve9(siz[son[u]]);
		else if(siz[u]<=131072)solve10(siz[son[u]]);
		else if(siz[u]<=262144)solve11(siz[son[u]]);
		else if(siz[u]<=524288)solve12(siz[son[u]]);
		else solve13(siz[son[u]]);
	}
}
int main(){
	scanf("%d",&n);
	for(int i=2;i<=n;i++){
		int x;scanf("%d",&x);
		add(x,i);
	}
	dfs(1);
	cout<<ans<<endl;
	return 0;
}
posted @ 2025-11-15 11:03  runzelai  阅读(25)  评论(0)    收藏  举报