Luogu P5351 Ruri Loves Maschera

先ORZ\(Owen\)一发。感觉是个很套路的题,这里给一个蒟蒻的需要特判数据的伪\(n\log^2 n\)算法,真正的两只\(\log\)的还是去看标算吧(但这个好想好写跑不满啊)

首先这种树上路径统计的问题我们先套一个点分治上去就是了,然后求出分治中心连出去的每一条路径

我们还是套路地分成不同子树求解,如果我们对一条路径设一个二元组\((dep,val)\)表示到分治中心的深度以及路径上的最大值,那么就考虑怎么把一棵子树内的路径和其它子树的合并了

考虑如果只有一个端点限制怎么做,很容易想到它们的和取值都满足单调性

那么我们直接把两边都按\(dep\)排序,然后直接\(\text{two points}\)扫出一个端点的范围即可

接下来考虑两个端点怎么做,我们发现此时可以选的点形成了一个区间,既然这个区间两个端点都满足单调性,所以这个区间移动也是单调的

那么就很简单了,我们扫出每个点的合法区间之后考虑怎么统计答案

首先所有比它小的数最后的贡献都是它的权值,然后所有大于它的数的贡献即是它们本身

我们直接拿两个树状数组来维护,一个记录比它小的数的个数,另一个记录比它大的数的权值和,直接跟着指针一起添删点即可

很容易发现,这样的复杂度大部分消耗在排序上,如果每次暴力排序的话复杂度最坏会被卡到\(n^2\log^2 n\),但是很多时候一个点的度数没有那么多,所以跑起来非常快,大致比两个\(\log\)多一些常数吧

所以我们写完之后特判一下菊花图即可,这个更加简单,因为路径长度只有\(1/2\),暴力排序后枚举一下即可

CODE

#include<cstdio>
#include<cctype>
#include<algorithm>
#define RI int
#define CI const int&
#define Tp template <typename T>
using namespace std;
typedef long long LL;
const int N=100005,INF=1e9;
struct edge
{
    int to,nxt,v;
}e[N<<1]; int n,head[N],rst[N],num,cnt,L,R,x,y,z; LL ans; bool flag;
class FileInputOutput
{
    private:
        static const int S=1<<21;
        #define tc() (A==B&&(B=(A=Fin)+fread(Fin,1,S,stdin),A==B)?EOF:*A++)
        char Fin[S],*A,*B;
    public:
        Tp inline void read(T& x)
        {
            x=0; char ch; while (!isdigit(ch=tc()));
            while (x=(x<<3)+(x<<1)+(ch&15),isdigit(ch=tc()));
        }
        #undef tc
}F;
inline int min(CI a,CI b)
{
    return a<b?a:b;
}
inline void addedge(CI x,CI y,CI z)
{
    e[++cnt]=(edge){y,head[x],z}; head[x]=cnt;
    e[++cnt]=(edge){x,head[y],z}; head[y]=cnt;
}
namespace Case1 //Point Division Solver
{
    class Tree_Array
    {
        private:
            int size[N]; LL sum[N]; bool vis[N];
        public:
            #define lowbit(x) x&-x
            inline void add(CI x,CI y,CI opt,RI i=0)
            {
                for (i=x;i<=num;i+=lowbit(i)) size[i]+=opt;
                for (i=x;i;i-=lowbit(i)) sum[i]+=opt*y;
            }
            inline int query_rk(RI x,int ret=0)
            {
                for (;x;x-=lowbit(x)) ret+=size[x]; return ret;
            }
            inline LL query_sum(RI x,LL ret=0)
            {
                for (;x<=num;x+=lowbit(x)) ret+=sum[x]; return ret;
            }
            #undef lowbit
    }BIT;
    #define to e[i].to
    class Point_Division_Solver
    {
        private:
            struct data
            {
                int dep,mxv;
                friend inline bool operator < (const data& A,const data& B)
                {
                    return A.dep<B.dep;
                }
            }s[N]; int size[N],tot,pl[N],pr[N]; bool vis[N];
            inline int max(CI a,CI b)
            {
                return a>b?a:b;
            }
            inline void maxer(int& x,CI y)
            {
                if (y>x) x=y;
            }
            inline void DFS(CI now,CI fa,CI d,CI mx)
            {
                s[++tot]=(data){d,mx}; if (L<=d&&d<=R) ans+=rst[mx];
                for (RI i=head[now];i;i=e[i].nxt)
                if (to!=fa&&!vis[to]) DFS(to,now,d+1,max(mx,e[i].v));
            }
        public:
            int mx[N],ots,rt;
            inline void getrt(CI now,CI fa=1)
            {
                size[now]=1; mx[now]=0;
                for (RI i=head[now];i;i=e[i].nxt) if (to!=fa&&!vis[to])
                getrt(to,now),size[now]+=size[to],maxer(mx[now],size[to]);
                if (maxer(mx[now],ots-size[now]),mx[now]<mx[rt]) rt=now;
            } 
            inline void solve(CI now)
            {
                RI i,j; vis[now]=1; tot=0; int lst=0;
                for (i=head[now];i;i=e[i].nxt) if (!vis[to])
                {
                    DFS(to,now,1,e[i].v); sort(s+lst+2,s+tot+1); if (lst)
                    {
                        int p1=lst; for (j=lst+1;j<=tot;++j)
                        {
                            while (p1&&s[p1].dep+s[j].dep>R) --p1; pr[j]=p1; 
                        }
                        int p2=1; for (j=tot;j>lst;--j)
                        {
                            while (p2<=lst&&s[p2].dep+s[j].dep<L) ++p2; pl[j]=p2;
                        }
                        p1=lst+1; p2=lst; for (j=lst+1;j<=tot;++j)
                        {
                            while (p1>pl[j]) --p1,BIT.add(s[p1].mxv,rst[s[p1].mxv],1);
                            while (p2>pr[j]) BIT.add(s[p2].mxv,rst[s[p2].mxv],-1),--p2;
                            ans+=1LL*BIT.query_rk(s[j].mxv)*rst[s[j].mxv]+BIT.query_sum(s[j].mxv+1);
                        }
                        for (j=p1;j<=p2;++j) BIT.add(s[j].mxv,rst[s[j].mxv],-1);
                    }
                    lst=tot; sort(s+1,s+lst+1);
                }
                for (i=head[now];i;i=e[i].nxt) if (!vis[to])
                mx[rt=0]=INF,ots=size[to],getrt(to,now),solve(rt);
            }
    }PD;
    #undef to
    inline void solve(void)
    {
        sort(rst+1,rst+n); num=unique(rst+1,rst+n)-rst-1;
        for (RI i=1;i<=cnt;++i) e[i].v=lower_bound(rst+1,rst+num+1,e[i].v)-rst;
        PD.mx[PD.rt=0]=INF; PD.ots=n; PD.getrt(1); PD.solve(PD.rt); 
    }
};
namespace Case2 //Flower Graph Solver
{
    inline void solve(void)
    {
        RI i; sort(rst+1,rst+n); if (L<=1&&1<=R)
        for (i=1;i<n;++i) ans+=rst[i]; if (L<=2&&2<=R)
        for (i=1;i<n;++i) ans+=1LL*rst[i]*(i-1);
    }
};
int main()
{
    //freopen("CODE.in","r",stdin); freopen("CODE.out","w",stdout);
    RI i; for (F.read(n),F.read(L),F.read(R),flag=i=1;i<n;++i)
    {
        F.read(x); F.read(y); F.read(z);
        addedge(x,y,z); rst[i]=z; if (min(x,y)!=1) flag=0;
    }
    if (flag) Case2::solve(); else Case1::solve();
    return printf("%lld",ans<<1LL),0;
}
posted @ 2019-05-08 21:30 hl666 阅读(...) 评论(...) 编辑 收藏