[NEERC2016] Binary Code 题解
一个字符串最多有一个 \(?\),相当于每个字符串最多有两个状态,显然 \(2-sat\)。
看到前缀想到 \(trie\)。考虑将所有可能出现在答案中的字符串全都塞到 \(trie\) 树里,那么对于一个字符串状态 \(x\) 的末尾对应的点 \(cur\),假如选择了这种状态,\(cur\) 的祖先、\(cur\) 自己、\(cur\) 的后代所储存的状态就都不能用,经典 \(2-sat\) 连边了。
当然,这样边数太大,无法通过该题,所以根据 \(trie\) 树的形态,对于 \(cur\) 的祖先,我们建一棵内向树;对于 \(cur\) 的后代,我们建一棵外向树。
但问题随之出现,假如我们按 \(x\to cur\to \operatorname{oth}(x)\) 的顺序去连边,那么自己就连到自己了,不符合要求。所以我们连 \(x\to fa(cur)/sn(cur'),cur/cur'\to \operatorname{oth}(x)\) 四条边,同时对于每个 \(cur\),在以该点结尾的集合内部暴力连边。容易证明,这样点数、边数都是 \(n+\sum |S|\) 级别的。
时间复杂度 \(O(n+\sum |S|)\)。
#include<bits/stdc++.h>
using namespace std;
const int N=5e5+5,M=3e6+5;
int n,m,cnt,idx[M],ans[N],w[N];string st[N];
map<string,int>mp;
int oth(int x){return (x+n-1)%(n+n)+1;}
namespace SAT{
int id,dfn[M],low[M],vs[M];
vector<int>g[M];int st[M],tp;
void tarjan(int x){
dfn[x]=low[x]=++id,vs[st[++tp]=x]=1;
for(auto y:g[x]){
if(!dfn[y]) tarjan(y),low[x]=min(low[x],low[y]);
else if(vs[y]) low[x]=min(low[x],dfn[y]);
}if(dfn[x]!=low[x]) return;cnt++;
while(st[tp+1]!=x) idx[st[tp]]=cnt,vs[st[tp--]]=0;
}
}namespace TRIE{
int tr[N*2][2],fa[N*2];
vector<int>g[N*2];int id=1;
void add(string s,int num){
int nw=1;
for(int i=0;s[i];i++){
int c=s[i]-'0';
if(!tr[nw][c]) fa[tr[nw][c]=++id]=nw;
nw=tr[nw][c];
}g[nw].push_back(num);
}void build(){
for(int i=2;i<=id;i++){
SAT::g[n*2+i].push_back(n*2+fa[i]);
SAT::g[n*2+id+fa[i]].push_back(n*2+id+i);
for(auto x:g[i]){
SAT::g[x].push_back(n*2+fa[i]);
SAT::g[n*2+i].push_back(oth(x));
SAT::g[n*2+id+i].push_back(oth(x));
for(auto y:g[i]) if(x!=y)
SAT::g[x].push_back(oth(y));
}for(auto x:g[fa[i]])
SAT::g[x].push_back(n*2+id+i);
}m=n*2+id*2;
}
}int main(){
ios::sync_with_stdio(0);
cin.tie(0),cout.tie(0),cin>>n;
for(int i=1;i<=n;i++){
cin>>st[i],w[i]=-1,mp[st[i]]++;
if(mp[st[i]]>2) return cout<<"NO",0;
for(int j=0;st[i][j];j++) if(st[i][j]=='?') w[i]=j;
if(~w[i]){
st[i][w[i]]='0',TRIE::add(st[i],i);
st[i][w[i]]='1',TRIE::add(st[i],n+i);
}else TRIE::add(st[i],i),TRIE::add(st[i],n+i);
}TRIE::build();
for(int i=1;i<=m;i++)
if(!SAT::dfn[i]) SAT::tarjan(i);
for(int i=1;i<=n;i++){
if(idx[i]==idx[i+n]) return cout<<"NO",0;
ans[i]=(idx[i]>idx[i+n]);
}cout<<"YES\n";
for(int i=1;i<=n;i++){
if(~w[i]) st[i][w[i]]=(char)(ans[i]+'0');
for(int j=0;st[i][j];j++) cout<<st[i][j];cout<<"\n";
}return 0;
}

浙公网安备 33010602011771号