HDU 6053 TrickGCD 容斥
题目链接:
http://acm.hdu.edu.cn/showproblem.php?pid=6053
题意:
给你序列a,让你构造序列b,要求 1<=b[i]<=a[i],且b序列的gcd>=2。问你方案数。
思路:
容易想到的就是我们枚举整个序列的gcd,然后a[i]/gcd就是i位置能够填的数的个数,然后每个位置累积就能得到数列为gcd时的方案数。
最后容斥一下累加就是答案。但是最大gcd可以是100000和明显这样做n^2,会超时。
那么我们把a[i]/gcd的放在一起,然后用快速幂直接求出值。具体来说,当前枚举的gcd是a,把a,2*a,3*a.......分块,对于每一块对gcd为a的贡献是在这一块的a[i]的个数,那么前缀和处理一下就好了。
对于容斥:
也没做过几道容斥,对于dfs的容斥,第一次见,就是说每次容斥的数都是num*prime[i],保证了dfs了所有数,并且只有log层
cnt[num] 是对于num前面已经加或者减了多少次,那么我们只要每个数一次,就对于当前的数需要加或者减1-cnt[num]才能变成1,对于num的倍数也影响了1-cnt[num]。
还有一种容斥,更简单,就是对于gcd倒着遍历,筛去他的倍数的贡献,也就是重复加的要减去。嗯 很巧妙
代码:
代码一:
#include <bits/stdc++.h> using namespace std; typedef long long ll; #define MS(a) memset(a,0,sizeof(a)) #define MP make_pair #define PB push_back const int INF = 0x3f3f3f3f; const ll INFLL = 0x3f3f3f3f3f3f3f3fLL; inline ll read(){ ll x=0,f=1;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} return x*f; } ////////////////////////////////////////////////////////////////////////// const int maxn = 1e5+10; const int mod = 1e9+7; int v[maxn],s[maxn],cnt[maxn],res[maxn],mi,mx; bool vis[maxn]; vector<int> prime; ll ans; ll qpow(ll a,ll b){ ll res = 1; while(b){ if(b&1) res = (res*a)%mod; a = (a*a)%mod; b >>= 1; } return res; } void init(){ for(int i=2; i<maxn; i++){ if(vis[i]) continue; prime.push_back(i); for(int j=i+i; j<maxn; j+=i){ vis[j] = true; } } } void dfs(ll num){ if(num > mi) return ; if(cnt[num] == 1) return ; if(num > 1){ int dt = 1-cnt[num]; ans = (ans+dt*res[num]+mod)%mod; cnt[num] = 1; for(int i=num*2; i<=mi; i+=num){ cnt[i] += dt; } } for(int i=0; i<(int)prime.size(); i++) dfs(num*prime[i]); } int main(){ init(); int T = read(), kase = 1; while(T--){ MS(v); MS(s); MS(cnt); int n = read(); mi = INF,mx = 0; for(int i=1; i<=n; i++){ int x = read(); v[x]++; mi = min(mi,x); mx = max(mx,x); } for(int i=1; i<maxn; i++) s[i] = s[i-1]+v[i]; for(int a=2; a<=mi; a++){ res[a] = 1; for(int i=a*2,j=1; i<=mx+a; i+=a, j++){ int t = (i-1>mx ? s[mx] : s[i-1]); res[a] = (res[a]*qpow(j,t-s[i-a-1])%mod); } } // for(int i=2; i<=mi; i++) // cout << i << " " << res[i] << endl; ans = 0; dfs(1); printf("Case #%d: %I64d\n",kase++,ans); } return 0; }
代码二:
#include <bits/stdc++.h> using namespace std; typedef long long ll; #define MS(a) memset(a,0,sizeof(a)) #define MP make_pair #define PB push_back const int INF = 0x3f3f3f3f; const ll INFLL = 0x3f3f3f3f3f3f3f3fLL; inline ll read(){ ll x=0,f=1;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} return x*f; } ////////////////////////////////////////////////////////////////////////// const int maxn = 1e5+10; const int mod = 1e9+7; int v[maxn],s[maxn],res[maxn],mi,mx; ll ans,dp[maxn]; ll qpow(ll a,ll b){ ll res = 1; while(b){ if(b&1) res = (res*a)%mod; a = (a*a)%mod; b >>= 1; } return res; } int main(){ int T = read(), kase = 1; while(T--){ MS(v); MS(s); MS(dp); int n = read(); mi = INF,mx = 0; for(int i=1; i<=n; i++){ int x = read(); v[x]++; mi = min(mi,x); mx = max(mx,x); } for(int i=1; i<maxn; i++) s[i] = s[i-1]+v[i]; for(int a=2; a<=mi; a++){ res[a] = 1; for(int i=a*2,j=1; i<=mx+a; i+=a, j++){ int t = (i-1>mx ? s[mx] : s[i-1]); res[a] = (res[a]*qpow(j,t-s[i-a-1])%mod); } } for(int i=mi; i>=2; --i){ dp[i] = res[i]; for(int j=i<<1; j<=mi; j+=i){ dp[i] -= dp[j]; dp[i] = (dp[i]%mod+mod)%mod; } } ans = 0; for(int i=2; i<=mi; i++) ans = (ans+dp[i])%mod; printf("Case #%d: %I64d\n",kase++,ans); } return 0; }