Codechef March Cook-Off 2018. Maximum Tree Path

@(Codechef March Cook-Off 2018. Maximum Tree Path)

题意

给你一颗\(n(1e5)\)个点有边权有点权的树,\(Min(u,v)\)表示\(u,v\)路径最小点权,\(gcd(u,v)\)表示\(u,v\)路径点权的最大公因数,\(dis(u,v)\)表示\(u,v\)路径大小。
输出\(max(dis(u,v)*gcd(u,v)*Min(u,v))\)


解析

  • 法一:
  • 外层枚举路径的gcd,并把两端点是gcd倍数的边存下,按照两端点较小的权值排序。
  • 每次加一条边,并查集动态维护直径即可。路径最小点权就是当前加的边的最小点权。
  • 因为平均每条边的因子个数不超过\sqrt(10000)
  • 所以总的枚举边的个数不超过n*\sqrt(10000)
  • 复杂度:\(O(nlog(n) + n*\sqrt(10000))\)
  • 法二:暴搜+剪枝
  • 敢剪就敢过
  • dia = 树的直径
  • 剪枝:当diagcdmin <= ans时,return
  • 先从直径的一个端点开始暴搜+剪枝预处理一遍。(不然会tle,可能先以直径为端点搜会搜到大答案的概率更大,剪枝效果更明显)
  • 然后枚举起点开始暴搜,记录当前路径的gcd和min
  • 法三:点分治
  • 按重心分治,从重心开始搜索,记录每条路径的{second ancestor,dis, Gcd, Min}
  • 按Min从小到大排序。这个排序很精髓。
  • 因为所有路径都含有重心这个点,那么gcd的可能性就只有不到 \sqrt (10000)种。
  • 后面在枚举边的时候,动态对每种gcd记录最远的次远的路径,两条路径的次祖先不能相同。
  • 先枚举边,再枚举gcd,看看这个gcd记录的两条路径能否和当前枚举的边组合出更优解。
  • 复杂度肯定小于:\(O(nlog(n) + n*\sqrt(10000))\)

AC_code

/*
 * Codechef March Cook-Off 2018. Maximum Tree Path
 * 法一:
 * 外层枚举路径的gcd,并把两端点是gcd倍数的边存下,按照两端点较小的权值排序。
 * 每次加一条边,并查集动态维护直径即可。路径最小点权就是当前加的边的最小点权。
 * 因为平均每条边的因子个数不超过\sqrt(10000)
 * 所以总的枚举边的个数不超过n*\sqrt(10000)
 * 复杂度:O(nlog(n) + n*\sqrt(10000))
 * 法二:暴搜+剪枝
 * 敢剪就敢过
 * dia = 树的直径
 * 剪枝:当dia*gcd*min <= ans时,return
 * 先从直径的一个端点开始暴搜+剪枝预处理一遍。(不然会tle,可能先以直径为端点搜会搜到大答案的概率更大,剪枝效果更明显)
 * 然后枚举起点开始暴搜,记录当前路径的gcd和min
 * 法三:点分治
 * */
#pragma comment(linker, "/STACK:102400000,102400000")
#include<bits/stdc++.h>
#define fi first
#define se second
#define endl '\n'
#define o2(x) (x)*(x)
#define BASE_MAX 30
#define mk make_pair
#define eb emplace_back
#define all(x) (x).begin(), (x).end()
#define clr(a, b) memset((a),(b),sizeof((a)))
#define iis std::ios::sync_with_stdio(false); cin.tie(0)
#define my_unique(x) sort(all(x)),x.erase(unique(all(x)),x.end())
using namespace std;
#pragma optimize("-O3")
typedef long long LL;
typedef unsigned long long uLL;
typedef pair<int, int> pii;
inline LL read() {
    LL x = 0;int f = 0;
    char ch = getchar();
    while (ch < '0' || ch > '9') f |= (ch == '-'), ch = getchar();
    while (ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
    return x = f ? -x : x;
}
inline void write(LL x, bool f) {
    if (x == 0) {putchar('0'); if(f)putchar('\n');else putchar(' ');return;}
    if (x < 0) {putchar('-');x = -x;}
    static char s[23];
    int l = 0;
    while (x != 0)s[l++] = x % 10 + 48, x /= 10;
    while (l)putchar(s[--l]);
    if(f)putchar('\n');else putchar(' ');
}
int lowbit(int x) { return x & (-x); }
template<class T>T big(const T &a1, const T &a2) { return a1 > a2 ? a1 : a2; }
template<typename T, typename ...R>T big(const T &f, const R &...r) { return big(f, big(r...)); }
template<class T>T sml(const T &a1, const T &a2) { return a1 < a2 ? a1 : a2; }
template<typename T, typename ...R>T sml(const T &f, const R &...r) { return sml(f, sml(r...)); }
void debug_out() { cerr << '\n'; }
template<typename T, typename ...R>void debug_out(const T &f, const R &...r) {cerr << f << " ";debug_out(r...);}
#define debug(...) cerr << "[" << #__VA_ARGS__ << "]: ", debug_out(__VA_ARGS__);


const LL INFLL = 0x3f3f3f3f3f3f3f3fLL;
const int HMOD[] = {1000000009, 1004535809};
const LL BASE[] = {1572872831, 1971536491};
const int mod = 998244353;
const int MOD = 1e9 + 7;
const int INF = 0x3f3f3f3f;
const int MXN = 5e5 + 7;
const int MXE = 1e6 + 7;
int n, m;
vector<pii > mp[MXN];
vector<int> has[MXN];
int ar[MXN];
pair<pii, int> cw[MXN];
int stk[MXN], top;
LL ans;
namespace LCA {
    LL dis[MXN];
    int up[MXN][20], lens[MXN];
    int cnt, dfn[MXN], en[MXN], LOG[MXN];
    void dfs(int u, int ba) {
        lens[u] = lens[ba] + 1;
        dfn[++cnt] = u;
        en[u] = cnt;
        for(auto V: mp[u]) {
            int v = V.fi;
            if(v == ba) continue;
            dis[v] = dis[u] + V.se;
            dfs(v, u);
            dfn[++ cnt] = u;
        }
    }
    inline int cmp(int u, int v) {
        return lens[u] < lens[v] ? u: v;
    }
    void init() {
        cnt = 0;
        dis[0] = lens[0] = 0;
        dfs(1, 0);
        LOG[1] = 0;
        for(int i = 2; i <= cnt; ++i) LOG[i] = LOG[i-1] + ((1<<(LOG[i-1]+1))==i);
        for(int i = 1; i <= cnt; ++i) up[i][0] = dfn[i];
        for(int j = 1; (1<<j) <= cnt; ++j)
            for(int i = 1; i + (1<<j) -1 <= cnt; ++i)
                up[i][j] = cmp(up[i][j-1], up[i+(1<<(j-1))][j-1]);
    }
    inline int lca(int x, int y) {
        int l = en[x], r = en[y];
        if(l > r) swap(l, r);
        int k = LOG[r - l + 1];
        return cmp(up[l][k], up[r-(1<<k)+1][k]);
    }
    inline LL query(int i, int j) {
        return dis[i] + dis[j] - 2 * dis[lca(i, j)];
    }
}
int fa[MXN];
pii data[MXN];
pii merge(pii A, pii B) {
    int a[4];
    a[0] = A.fi, a[1] = A.se, a[2] = B.fi, a[3] = B.se;
    pii tmp = A;
    LL res = 0;
    for(int i = 0; i < 4; ++i) {
        for(int j = i + 1; j < 4; ++j) {
            LL ret = LCA::query(a[i], a[j]);
            if(ret > res) {
                res = ret;
                tmp = mk(a[i], a[j]);
            }
        }
    }
    return tmp;
}
bool cmp(const int&a, const int&b) {
    return cw[a].se > cw[b].se;
}
int Fi(int x) {
    return fa[x] == x? x: fa[x] = Fi(fa[x]);
}
int main() {
#ifndef ONLINE_JUDGE
    freopen("/home/cwolf9/CLionProjects/ccc/in.txt", "r", stdin);
    //freopen("/home/cwolf9/CLionProjects/ccc/out.txt", "w", stdout);
#endif
    for(int i = 0; i <= 10000; ++i) stk[i] = i;
    int tim = read();
    while(tim --) {
        n = read();
        for(int i = 1; i <= n; ++i) mp[i].clear();
        for(int i = 1; i <= n; ++i) ar[i] = read();
        top = 0;
        for(int i = 1, a, b, c; i < n; ++i) {
            a = read(), b = read(), c = read();
            cw[i] = mk(mk(a, b), sml(ar[a], ar[b]));
            mp[a].eb(mk(b, c));
            mp[b].eb(mk(a, c));
//            stk[++top] = c;
            c = __gcd(ar[a], ar[b]);
            for(int j = 1; j * j <= c; ++j) if(c % j == 0) {
                has[j].eb(i);
                if(c/j != j) has[c/j].eb(i);
            }
        }
        ans = 0;
        LCA::init();
//        sort(stk + 1, stk + 1 + top);
//        top = unique(stk + 1, stk + 1 + top) - stk;
        for(int i = 1; i <= 10000; ++i) {
            if((int)has[stk[i]].size() <= 0)continue;
            for(auto V: has[stk[i]]) {
                fa[cw[V].fi.fi] = cw[V].fi.fi, fa[cw[V].fi.se] = cw[V].fi.se;
                data[cw[V].fi.fi] = mk(cw[V].fi.fi, cw[V].fi.fi), data[cw[V].fi.se] = mk(cw[V].fi.se, cw[V].fi.se);
            }
            sort(all(has[stk[i]]), cmp);
            for(auto V: has[stk[i]]) {
                int pa = Fi(cw[V].fi.fi), pb = Fi(cw[V].fi.se);
                fa[pa] = pb;
                data[pb] = merge(data[pa], data[pb]);
                ans = big(ans, LCA::query(data[pb].fi, data[pb].se) * (LL)cw[V].se * stk[i]);
            }
        }
        printf("%lld\n", ans);
        for(int i = 1; i <= 10000; ++i) has[i].clear();
    }
#ifndef ONLINE_JUDGE
    cout << "time cost:" << clock() << "ms" << endl;
#endif
    return 0;
}

const int MXN = 5e5 + 7;
const int MXE = 1e6 + 7;
int n, m;
int ar[MXN], sz[MXN], son[MXN], inde, fid[MXN], lid[MXN], rid[MXN], dep[MXN];
vector<pii > mp[MXN];
int stk[MXN], top, aim;
LL dia, dis[MXN], ans;
void dfs_sz(int u, int ba) {
    sz[u] = 1;
    son[u] = 0;
    fid[u] = ++ inde;
    rid[inde] = u;
    for(auto V: mp[u]) {
        if(V.fi == ba) continue;
        dep[V.fi] = dep[u] + 1;
        dfs_sz(V.fi, u);
        sz[u] += sz[V.fi];
        if(sz[V.fi] > sz[son[u]]) son[u] = V.fi;
    }
    lid[u] = inde;
}
void dfs(int u, int ba) {
    for(auto V: mp[u]) {
        if(V.fi == ba) continue;
        dis[V.fi] = dis[u] + V.se;
        dfs(V.fi, u);
    }
}
void chk(int u, int ba, LL len, int Gcd, int Min) {
    if(dia * Gcd * Min <= ans) return;
    ans = big(ans, len * Gcd * Min);
    for(auto V: mp[u]) {
        if(V.fi == ba) continue;
        chk(V.fi, u, len + V.se, __gcd(Gcd, ar[V.fi]), sml(Min, ar[V.fi]));
    }
}
int main() {
    int tim = read();
    while(tim --) {
        n = read();
        for(int i = 1; i <= n; ++i) mp[i].clear();
        for(int i = 1; i <= n; ++i) ar[i] = read();
        inde = top = ans = 0;
        for(int i = 1, a, b, c; i < n; ++i) {
            a = read(), b = read(), c = read();
            mp[a].eb(mk(b, c));
            mp[b].eb(mk(a, c));
            stk[++top] = c;
        }
        sort(stk + 1, stk + 1 + top);
        top = unique(stk + 1, stk + 1 + top) - stk;
        int S = 1, T = 1;
        dis[1] = 0;
        dfs(1, 0);
        for(int i = 1; i <= n; ++i) if(dis[i] > dis[S]) S = i;
        dis[S] = 0;
        dfs(S, 0);
        for(int i = 1; i <= n; ++i) if(dis[i] > dis[T]) T = i;
        dia = dis[T];
        chk(S, 0, 0, ar[S], ar[S]);
        for(int i = 1; i <= n; ++i) chk(i, 0, 0, ar[i], ar[i]);
        printf("%lld\n", ans);
    }
    return 0;
}

点分治做法我还没补,大致思路和上面一样,先贴一下大佬的代码。

#include<bits/stdc++.h>
#define ll long long
#define ull unsigned ll
#define uint ungigned
#define db double
#define pii pair<int,int>
#define pll pair<ll,ll>
#define pli pair<ll,int>
#define vi vector<int>
#define vpi vector<pii >
#define IT iterator
 
#define PB push_back
#define MK make_pair
#define LB lower_bound
#define UB upper_bound
#define y1 wzpakking 
#define fi first
#define se second
#define BG begin
#define ED end
 
#define For(i,j,k) for (int i=(int)(j);i<=(int)(k);i++)
#define Rep(i,j,k) for (int i=(int)(j);i>=(int)(k);i--)
#define UPD(x,y) (((x)+=(y))>=mo?(x)-=mo:233)
#define CLR(a,v) memset(a,v,sizeof(a))
#define CPY(a,b) memcpy(a,b,sizeof(a))
#define sqr(x) (1ll*x*x)
 
#define LS3 k*2,l,mid
#define RS3 k*2+1,mid+1,r
#define LS5 k*2,l,mid,x,y
#define RS5 k*2+1,mid+1,r,x,y
#define GET pushdown(k);int mid=(l+r)/2
#define INF (1<<29)
using namespace std;
int gcd(int x,int y){
    return y?gcd(y,x%y):x;
}
const int N=100005;
vector<int> divi[N];
int head[N],vis[N],tot;
int sz[N],a[N],mx[N],rt;
int n;
ll ans;
struct edge{
    int to,next,v;
}e[N*2];
struct node{
    int fr;
    ll mn,G,dis;
    bool operator <(const node &a)const{
        return mn<a.mn;
    }
}g[N],g2[N];
vector<node> v;
void add(int x,int y,int v){
    e[++tot]=(edge){y,head[x],v};
    head[x]=tot;
}
void Dfs(int x,int fa,ll dis,int G,int mn,int fr){
    mn=min(mn,a[x]); G=gcd(G,a[x]);
    v.push_back((node){fr,mn,G,dis});
    for (int i=head[x];i;i=e[i].next)
        if (!vis[e[i].to]&&e[i].to!=fa)
            Dfs(e[i].to,x,dis+e[i].v,G,mn,fr);
}
void solve(int x){
    v.clear();
    for (int i=head[x];i;i=e[i].next)
        if (!vis[e[i].to])
            Dfs(e[i].to,0,e[i].v,a[x],a[x],e[i].to);
    sort(v.begin(),v.end());
    for (auto i:divi[a[x]])
        g[i]=g2[i]=(node){0,0,0,0};
    Rep(i,v.size()-1,0){
        int fr=v[i].fr;
        ll mn=v[i].mn;
        ll G=v[i].G;
        ll dis=v[i].dis;
        ans=max(ans,dis*G*mn);
        for (auto di:divi[a[x]])
            if (fr!=g[di].fr&&g[di].dis)
                ans=max(ans,(dis+g[di].dis)*min(mn,g[di].mn)*gcd(G,di));
            else if (g2[di].dis)
                ans=max(ans,(dis+g2[di].dis)*min(mn,g2[di].mn)*gcd(G,di));
        if (dis>g[G].dis){
            if (g[G].fr!=fr) g2[G]=g[G];
            g[G]=v[i];
        }
        else if (dis>g2[G].dis&&fr!=g[G].fr)
            g2[G]=v[i];
    }
}
void dfs(int x,int fa,int Sz){
    mx[x]=0; sz[x]=1;
    for (int i=head[x];i;i=e[i].next)
        if (e[i].to!=fa&&!vis[e[i].to]){
            dfs(e[i].to,x,Sz);
            mx[x]=max(mx[x],sz[e[i].to]);
            sz[x]+=sz[e[i].to];
        }
    mx[x]=max(mx[x],Sz-sz[x]);
    if (mx[x]<mx[rt]) rt=x;
}
void divide(int x,int Sz){
    rt=0; dfs(x,0,Sz); 
    vis[x=rt]=1; solve(x);
    for (int i=head[x];i;i=e[i].next)
        if (!vis[e[i].to]){
            int nsz;
            if (sz[e[i].to]>sz[x])
                nsz=Sz-sz[x];
            else nsz=sz[e[i].to];
            divide(e[i].to,nsz);
        }
}
void solve(){
    ans=tot=0; mx[0]=1e9;
    scanf("%d",&n);
    For(i,0,n+1) sz[i]=head[i]=vis[i]=0;
    For(i,1,n) scanf("%d",&a[i]);
    For(i,1,n-1){
        int x,y,v;
        scanf("%d%d%d",&x,&y,&v);
        add(x,y,v); add(y,x,v);
    }
    divide(1,n);
    printf("%lld\n",ans);
}
void init(){
    For(i,1,10000) For(j,1,10000/i)
        divi[i*j].PB(i);
}
int main(){
    init();
    int T;
    scanf("%d",&T);
    while (T--) solve();
}
posted @ 2019-09-05 20:49 Cwolf9 阅读(...) 评论(...) 编辑 收藏

Contact with me