【noip模拟】tree

Time Limit: 1000 ms        Memory Limit: 128 MB

 

 

 

[吐槽]

  点分治点分治点分治

  嗯。。场上思考树状数组的时候好像傻掉了。。反正就是挂了就是了。。

 

[题解]

  首先如果没有环的话就是一道十分简单的点分治啦

  但是这题有环啊

  

  考虑强行变树

  从题目各种谜一般的描述中得出来的结论是:$m<=n$

  其实也就是说最多只有一个环

  那么就有一个很直接的想法,先把唯一的一个环找出来,断掉其中的一条边

  这样就使它变成一棵树了,直接跑一遍点分就好

 

  考虑断掉的那条边

  这样统计有一个很明显的问题:经过断开那条边的情况全部都没有算进去

  所以现在就考虑怎么算过这条边的ans

  

  首先我们可以将这个环摊开变成这样:

  

  

  然后发现这个东西其实就是一条“链”上面有若干棵树

  断开的那条边显然就是连接这条“链”一头一尾的边(为了方便描述,将这条断开的边记作$(x,y)$

  我们定义

  $rt_i$表示$i$所属的子树的根节点

  $dis_i$ 表示$i$到$rt_i$的的路径上的点数

  $left_i$表示$rt_i$到这条“链”头(也就是图中编号为1的点)的节点数

  $right_i$表述$rt_i$到这条“链”尾(图中编号为5的点)的节点数

  那么要算一条过$(x,y)$的路径$(i,j)$的点数的话,显然就是子树里面的距离+链上要走的距离

  也就是 $dis_i+dis_j+left_i+right_j$ ($rt_i$在$rt_j$左边)

  

  那么就可以用一个树状数组来搞定了

 

  考虑怎么统计

  (其实实现起来并不用上面的那些奇妙数组)

  我们可以先将链上的点(也就是各个子树的根节点)编个号

  那么对于一个这条链上面的第$i$和第$j$ $(i<j)$ 个点,那么链上要走的距离就为 $i+(len-j+1)$

  其中$len$表示的是链的长度

  然后将式子上一步中求路径上点数的式子稍微整理一下,得到

  $(dis_i+i)+(dis_j+len-j+1)  (i<j) $

  

  所以我们可以从左往右一个一个点处理

  先将当前点$i$子树内的$dis$处理出来

  然后对于每一个$dis_j (j \in subtree(i))$ ,在树状数组里面查询大于等于$k-dis_j-(len-j+1)$的数量(原因在后面解释)

   查询完了之后将$dis_j+j$丢入树状数组中

 

  这么处理的原因显然

  整理过后的式子可以分为两部分,分别只与$i$和$j$有关

  然后因为我们是从左到右处理链上面的点的,所以可以保证查询到的点是在当前点的前面的

 

  然后这题就十分愉快地解决啦

 

[一些小细节]

  因为这题是求>=的方案数

  所以树状数组要十分愉快地反过来(也就是insert的时候是x-=x&-x,query的时候是x+=x&-x,见代码)

  以及因为insert的时候是dis+i,所以上限应该是2*n

  以及要用long long

  嗯大概就是这样ovo

 

 

  1 #include<iostream>
  2 #include<cstdio>
  3 #include<cstring>
  4 #include<algorithm>
  5 #define ll long long
  6 using namespace std;
  7 const int MAXN=100010;
  8 int h[MAXN],size[MAXN],mx[MAXN];
  9 ll dis[MAXN];
 10 bool vis[MAXN];
 11 int n,m,k,tot,rt,rt_mx;
 12 ll ans,num;
 13 struct xxx
 14 {
 15     int y,next;
 16     bool flag;
 17 }a[MAXN*2];
 18 struct data
 19 {
 20     ll c[MAXN*2];
 21     int insert(int x,ll delta) {_insert(x,delta);}
 22     int _insert(int x,ll delta)
 23     {
 24         for (;x;x-=x&-x) c[x]+=delta;
 25     }
 26     ll query(int x) {return _query(x);}
 27     ll _query(int x)
 28     {
 29         ll ret=0;
 30         if (x<1) x=1;
 31         for (;x<=2*n;x+=x&-x) ret+=c[x];
 32         return ret;
 33     }
 34 }c;
 35 int pre[MAXN],cir[MAXN];
 36 int add(int x,int y);
 37 int dfs(int x);
 38 int dfs_size(int x,int fa);
 39 int dfs_root(int r,int x,int fa);
 40 int get_dis(int x,int fa,int d);
 41 int get_cir(int fa,int x);
 42 ll cal(int x,int d);
 43 bool cmp(int x,int y){return x>y;}
 44 int solve_cir();
 45 
 46 int main()
 47 {
 48     freopen("a.in","r",stdin);
 49     freopen("a.out","w",stdout);
 50 
 51     int x,y,z;    
 52     scanf("%d%d%d",&n,&m,&k);
 53     tot=1;
 54     memset(h,-1,sizeof(h));
 55     for (int i=1;i<=m;++i)
 56     {
 57         scanf("%d%d",&x,&y);
 58         add(x,y); add(y,x);
 59     }
 60     if (m+1==n) {dfs(1); printf("%lld\n",ans); return 0;}
 61     cir[0]=0;
 62     get_cir(0,1);
 63     solve_cir();
 64 }
 65 
 66 int add(int x,int y)
 67 {
 68     a[++tot].y=y; a[tot].next=h[x]; h[x]=tot; a[tot].flag=true;
 69 }
 70 
 71 int dfs(int x)
 72 {
 73     rt=0,rt_mx=n;
 74     dfs_size(x,0);
 75     dfs_root(x,x,0);
 76     ans=ans+cal(rt,0);
 77     vis[rt]=true;
 78     for (int i=h[rt];i!=-1;i=a[i].next)
 79         if (!vis[a[i].y]&&a[i].flag)
 80         {
 81             ans=ans-cal(a[i].y,1);
 82             dfs(a[i].y);
 83         }
 84 }
 85 
 86 int dfs_size(int x,int fa)
 87 {
 88     size[x]=1;
 89     mx[x]=0;
 90     for (int i=h[x];i!=-1;i=a[i].next)
 91         if (a[i].y!=fa&&!vis[a[i].y]&&a[i].flag)
 92         {
 93             dfs_size(a[i].y,x);
 94             size[x]+=size[a[i].y];
 95             mx[x]=max(mx[x],size[a[i].y]);
 96         }
 97 }
 98 
 99 int dfs_root(int r,int x,int fa)
100 {
101     mx[x]=max(mx[x],size[r]-size[x]);
102     if (rt_mx>mx[x]) rt_mx=mx[x],rt=x;
103     for (int i=h[x];i!=-1;i=a[i].next)
104         if (a[i].y!=fa&&!vis[a[i].y]&&a[i].flag)
105             dfs_root(r,a[i].y,x);
106 }
107 
108 int get_dis(int x,int fa,int d)
109 {
110     dis[++num]=d;
111     for (int i=h[x];i!=-1;i=a[i].next)
112         if (a[i].y!=fa&&!vis[a[i].y]&&a[i].flag)
113             get_dis(a[i].y,x,d+1);
114 }
115 
116 ll cal(int x,int d)
117 {
118     num=0;
119     get_dis(x,0,d);
120     int left=1,right=num;
121     ll re=0;
122     sort(dis+1,dis+1+num,cmp);
123     while (left<right)
124     {
125         while (dis[left]+dis[right]+1<k&&left<right) --right;
126         re+=right-left;
127         ++left;
128     }
129     return re;
130 }
131 
132 int get_cir(int fa,int x)
133 {
134     int u;
135     vis[x]=true; pre[x]=fa;
136     for (int i=h[x];i!=-1;i=a[i].next)
137     {
138         u=a[i].y;
139         if (u==fa) continue;
140         if (vis[u])
141         {
142             a[i].flag=false; a[i^1].flag=false;
143             for (int j=x;j!=u;j=pre[j]) cir[++cir[0]]=j;
144             cir[++cir[0]]=u;
145             return 0;
146         }
147         get_cir(x,u);
148         if (cir[0]) return 0;
149     }
150 }
151 
152 int solve_cir()
153 {
154     for (int i=1;i<=n;++i) vis[i]=false;
155     dfs(1);
156     for (int i=1;i<=n;++i) vis[i]=false;
157     for (int i=1;i<=cir[0];++i) vis[cir[i]]=true;
158     for (int i=1;i<=cir[0];++i)
159     {
160         num=0;
161         get_dis(cir[i],0,0);
162         for (int j=1;j<=num;++j)
163             ans+=c.query(k-dis[j]-(cir[0]-i+1));
164         for (int j=1;j<=num;++j)
165             c.insert(dis[j]+i,1);
166     }
167     printf("%lld\n",ans);
168 }
挫挫的代码

 

posted @ 2017-09-21 20:58  yoyoball  阅读(293)  评论(2编辑  收藏  举报