Educational Codeforces Round 45 (Rated for Div. 2) G - GCD Counting
思路:以 u 为端点向下的路径的 gcd 一定是 a[u] 的约数,所以不同 gcd 的个数不会很多。于是在 dfs 回溯时,用 map 记录每种 gcd 对应的路径条数,并暴力合并子树的 map 即可。
终测时被卡了 MLE:每个子树的 map 合并到父节点之后需要立即清空,否则所有节点的 map 会同时占用内存。下面是 MLE 的代码:
#include <bits/stdc++.h>
#define LL long long
#define fi first
#define se second
#define mk make_pair
#define pii pair<int,int>
#define piii pair<int, pair<int,int> >
using namespace std;

const int N = 2e5 + 7;
const int M = 10 + 7;
const int inf = 0x3f3f3f3f;
const LL INF = 0x3f3f3f3f3f3f3f3f;
const int mod = 1e9 + 7;
const double eps = 1e-6;

int n, a[N];          // vertex count and vertex values (1-indexed)
LL cnt[N];            // cnt[g]: number of paths (single vertices included) whose gcd is g
vector<int> edge[N];  // adjacency lists of the tree
map<int, LL> mp[N];   // mp[u][g]: number of downward paths starting at u with gcd g

// Euclid's algorithm; gcd(x, 0) == x.
int gcd(int a, int b) {
    while (b != 0) {
        int rest = a % b;
        a = b;
        b = rest;
    }
    return a;
}

// Fills mp[u] for the subtree of u (parent p) and adds to cnt every path
// whose topmost vertex is u and which has at least two vertices.
//
// NOTE: the children's maps stay allocated after they have been merged.
// According to the accompanying write-up, this version exceeded the memory
// limit in final testing for exactly that reason.
void dfs(int u, int p) {
    map<int, LL>& own = mp[u];
    own[a[u]] += 1;  // the path consisting of u alone

    for (int v : edge[u]) {
        if (v == p) {
            continue;
        }
        dfs(v, u);
        const map<int, LL>& sub = mp[v];

        // Pair every path ending at u (coming from u itself or from earlier
        // children) with every downward path starting at v.
        for (const auto& top : own) {
            for (const auto& bottom : sub) {
                cnt[gcd(top.first, bottom.first)] += top.second * bottom.second;
            }
        }

        // Extend the child's paths by u and make them available to later children.
        for (const auto& bottom : sub) {
            own[gcd(a[u], bottom.first)] += bottom.second;
        }
    }
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i) {
        scanf("%d", &a[i]);
        cnt[a[i]]++;  // single-vertex path
    }
    for (int i = 1; i < n; ++i) {
        int u, v;
        scanf("%d%d", &u, &v);
        edge[u].push_back(v);
        edge[v].push_back(u);
    }

    dfs(1, 0);

    // Print every gcd value that occurs, in increasing order.
    for (int g = 1; g < N; ++g) {
        if (cnt[g] == 0) {
            continue;
        }
        printf("%d %lld\n", g, cnt[g]);
    }
    return 0;
}
通过的代码(合并后清空子树的 map):
#include <bits/stdc++.h>
#define LL long long
#define fi first
#define se second
#define mk make_pair
#define pii pair<int,int>
#define piii pair<int, pair<int,int> >
using namespace std;

const int N = 2e5 + 7;
const int M = 10 + 7;
const int inf = 0x3f3f3f3f;
const LL INF = 0x3f3f3f3f3f3f3f3f;
const int mod = 1e9 + 7;
const double eps = 1e-6;

int n, a[N];          // vertex count and vertex values (1-indexed)
LL cnt[N];            // cnt[g]: number of paths (single vertices included) whose gcd is g
vector<int> edge[N];  // adjacency lists of the tree
map<int, LL> mp[N];   // mp[u][g]: number of downward paths starting at u with gcd g

// Euclid's algorithm; gcd(x, 0) == x.
int gcd(int a, int b) {
    while (b != 0) {
        int rest = a % b;
        a = b;
        b = rest;
    }
    return a;
}

// Fills mp[u] for the subtree of u (parent p) and adds to cnt every path
// whose topmost vertex is u and which has at least two vertices.
// Each child's map is emptied as soon as it has been merged into mp[u].
void dfs(int u, int p) {
    map<int, LL>& here = mp[u];
    here[a[u]] += 1;  // the path consisting of u alone

    for (int v : edge[u]) {
        if (v == p) {
            continue;
        }
        dfs(v, u);
        map<int, LL>& below = mp[v];

        // Pair every path ending at u (coming from u itself or from earlier
        // children) with every downward path starting at v.
        for (const auto& upper : here) {
            for (const auto& lower : below) {
                cnt[gcd(upper.first, lower.first)] += upper.second * lower.second;
            }
        }

        // Extend the child's paths by u and make them available to later children.
        for (const auto& lower : below) {
            here[gcd(a[u], lower.first)] += lower.second;
        }

        // The child's map is no longer needed; release its nodes now.
        below.clear();
    }
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i) {
        scanf("%d", &a[i]);
        cnt[a[i]]++;  // single-vertex path
    }
    for (int i = 1; i < n; ++i) {
        int u, v;
        scanf("%d%d", &u, &v);
        edge[u].push_back(v);
        edge[v].push_back(u);
    }

    dfs(1, 0);

    // Print every gcd value that occurs, in increasing order.
    for (int g = 1; g < N; ++g) {
        if (cnt[g] == 0) {
            continue;
        }
        printf("%d %lld\n", g, cnt[g]);
    }
    return 0;
}