CF1119H Triple
CF1119H Triple [* hard]
给定 \(n,k\),以及 \(n\) 个三元组 \((a_i,b_i,c_i)\),常数 \(x,y,z\) 表示每个三元组均有 \(x\) 个 \(a_i\),\(y\) 个 \(b_i\),\(z\) 个 \(c_i\),求从每个数组中选出一个数使得异或和为 \(t\in [0,2^k)\) 的方案数。
保证 \(a_i,b_i,c_i<2^k,n\le 10^5,k\le 17\)
\(\rm Sol:\)
首先可以 \(\mathcal O(n2^kk)\) 的做 FWT
然后众所周知 \(c(i,j)=(-1)^{i\& j}\),每个元素都形如 \(f_1\times x^{a_i}+f_2\times x^{b_i}+f_3\times x^{c_i}\) 这样的一个集合幂级数,可能的 FWT 值仅有 \(8\) 种,如果能对每个点值 \(x\) 算出取值数就可以知道答案了。
然后考虑一个极其巧妙的转换,不妨给所有三元组均异或上 \(a_i\) 然后视为 \((0,a_i\oplus b_i,a_i\oplus c_i)\) 这样的三元组,这样考虑 FWT 点值必然形如 \(f_1+f_2\times (-1)^{i\&(...)}+f_3\times (-1)^{i\&(...)}\)
于是 FWT 数组有且仅有 \(4\) 种点值 \((f_1+f_2+f_3,f_1-f_2-f_3,f_1+f_2-f_3,f_1-f_2+f_3)\),分别设为 \((a,b,c,d)\)。
于是只需要考虑列出 \(4\) 个方程来求解这 \(4\) 个点值的数量(黎明前的巧克力)
- \(c_1+c_2+c_3+c_4=n\)
事实上,我们进行分类讨论:
- \(c_1\) 即 \(b\) 与 \(c\) 取值均为 \(1\) 的数量。
- \(c_2\) 即 \(b\) 与 \(c\) 取值均为 \(-1\) 的数量。
- \(c_3\) 即 \(b\) 取值为 \(1\),\(c\) 取值为 \(-1\) 的数量。
- \(c_4\) 即 \(b\) 取值为 \(-1\),\(c\) 取值为 \(1\) 的数量。
于是将上值加起来做 FWT 得到的结果应该是 \(c_1+c_3-c_2-c_4\)
将 \(c\) 处取值设为 \(1\),那么 FWT 得到的结果应该是 \(c_1+c_4-c_2-c_3\) 的值。
然后最后最为巧妙的是,我们令 \(b\oplus c\) 为 \(1\)
这样得到的结果显然就是 \(c_1+c_2-c_3-c_4\)
于是设我们得到的值分别为 \(A,B,C,D\),那么就有:
解得:
然后我们进行 IFWT 即可得到答案。
最后这一步考虑正负来计数实在是 tql !
综上,我们得到了一个 \(\mathcal O(k2^k+n)\) 的优秀做法辣!
\(Code:\)
#include<bits/stdc++.h>
using namespace std ;
#define Next( i, x ) for( register int i = head[x]; i; i = e[i].next )
#define rep( i, s, t ) for( register int i = (s); i <= (t); ++ i )
#define drep( i, s, t ) for( register int i = (t); i >= (s); -- i )
#define re register
#define int long long
int gi() {
char cc = getchar() ; int cn = 0, flus = 1 ;
while( cc < '0' || cc > '9' ) { if( cc == '-' ) flus = - flus ; cc = getchar() ; }
while( cc >= '0' && cc <= '9' ) cn = cn * 10 + cc - '0', cc = getchar() ;
return cn * flus ;
}
const int P = 998244353 ;
const int IP = 499122177 ;
const int N = 1e5 + 5 ;
const int M = 3e5 + 5 ;
int n, f1, f2, f3, f4, limit, fA, f[M] ;
struct node { int x, y ; } c[N] ;
struct fucti { int a1, a2, a3, a4 ; } g[M] ;
int fpow( int x, int k ) {
int ans = 1, base = x ;
while(k) {
if(k & 1) ans = ans * base % P ;
base = base * base % P, k >>= 1 ;
} return ans ;
}
void FWT( int *a, int type ) {
for( re int k = 1; k < limit; k <<= 1 )
for( re int i = 0; i < limit; i += ( k << 1 ) )
for( re int j = i; j < i + k; ++ j ) {
int nx = a[j], ny = a[j + k] ;
a[j] = (nx + ny) % P, a[j + k] = (nx - ny + P) % P ;
if( !type ) a[j] = a[j] * IP % P, a[j + k] = a[j + k] * IP % P ;
}
}
void init() {
memset( f, 0, sizeof(f) ) ;
}
void Mod( int &x ) {
x += P, x += P, x %= P ;
}
signed main()
{
n = gi(), limit = gi(), limit = 1 << limit ; int x, y, z ;
x = gi(), y = gi(), z = gi() ;
f1 = x + y + z, f2 = x - y - z, f3 = x + y - z, f4 = x - y + z ;
Mod(f1), Mod(f2), Mod(f3), Mod(f4) ;
rep( i, 1, n ) {
x = gi(), y = gi() ^ x, z = gi() ^ x, fA ^= x ;
c[i].x = y, c[i].y = z ;
}
rep( i, 1, n ) ++ f[c[i].x] ; FWT( f, 1 ) ;
for( re int i = 0; i < limit; ++ i ) g[i].a1 = n, g[i].a2 = f[i] ;
init() ; rep( i, 1, n ) ++ f[c[i].y] ; FWT( f, 1 ) ;
for( re int i = 0; i < limit; ++ i ) g[i].a3 = f[i] ;
init() ; rep( i, 1, n ) ++ f[c[i].x ^ c[i].y] ; FWT( f, 1 ) ;
for( re int i = 0; i < limit; ++ i ) g[i].a4 = f[i] ;
int ip = IP * IP % P ;
for( re int i = 0; i < limit; ++ i ) {
int A = g[i].a1, B = g[i].a2, C = g[i].a3, D = g[i].a4 ;
int c1 = A + B + C + D, c2 = A + D - B - C,
c3 = A + B - C - D, c4 = A + C - B - D ;
Mod(c1), Mod(c2), Mod(c3), Mod(c4),
c1 = c1 * ip % P, c2 = c2 * ip % P, c3 = c3 * ip % P, c4 = c4 * ip % P,
f[i] = fpow( f1, c1 ) * fpow( f2, c2 ) % P * fpow( f3, c3 ) % P * fpow( f4, c4 ) % P ;
}
FWT( f, 0 ) ;
for( re int i = 0; i < limit; ++ i ) printf("%lld ", f[fA ^ i] ) ;
return 0 ;
}

浙公网安备 33010602011771号