hihocoder-1347 小h的树上的朋友(lca+线段树)

题目链接:

小h的树上的朋友

时间限制:18000ms
单点时限:2000ms
内存限制:512MB

描述

小h拥有n位朋友。每位朋友拥有一个数值Vi代表他与小h的亲密度。亲密度有可能发生变化。

岁月流逝,小h的朋友们形成了一种稳定的树状关系。每位朋友恰好对应树上的一个节点。

每次小h想请两位朋友一起聚餐,他都必须把连接两位朋友的路径上的所有朋友都一起邀请上。并且聚餐的花费是这条路径上所有朋友的亲密度乘积。

小h很苦恼,他需要知道每一次聚餐的花销。小h问小y,小y当然会了,他想考考你。

输入

输入文件第一行是一个整数n,表示朋友的数目,从1开始编号。

输入文件第二行是n个正整数Vi,表示每位朋友的初始的亲密度。

接下来n-1行,每行两个整数u和v,表示u和v有一条边。

然后是一个整数m,代表操作的数目。每次操作为两者之一:

0 u v 询问邀请朋友u和v聚餐的花费

1 u v 改变朋友u的亲密度为v

1<=n,m<=5*105

Vi<=109

输出

对于每一次询问操作,你需要输出一个整数,表示聚餐所需的花费。你的答案应该模1,000,000,007输出。

样例输入
3
1 2 3
1 2
2 3
5
0 1 2
0 1 3
1 2 3
1 3 5
0 1 3
样例输出
2
6
15
题意:
中文的就不说了;

思路:
显然是一个线段树的题;
先dfs,把树映射到区间上同时求出每个点到根节点的花费,
0的时候询问:先找到lca;再dis[u]*dis[v]*w[lca]/(dis[lca]*dis[lca]);可以费马小定理快速幂求逆;
1的时候更新:dfs的时候找到了每个点的包含此点所以子节点的区间,把这个区间的dis都更新同时还要更新w[u]我就是这两个问题写漏了改了一夜晚;

AC代码:

#include <bits/stdc++.h>
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>

using namespace std;

#define For(i,j,n) for(int i=j;i<=n;i++)
#define mst(ss,b) memset(ss,b,sizeof(ss));

typedef  long long LL;

template<class T> void read(T&num) {
    char CH; bool F=false;
    for(CH=getchar();CH<'0'||CH>'9';F= CH=='-',CH=getchar());
    for(num=0;CH>='0'&&CH<='9';num=num*10+CH-'0',CH=getchar());
    F && (num=-num);
}
int stk[70], tp;
template<class T> inline void print(T p) {
    if(!p) { puts("0"); return; }
    while(p) stk[++ tp] = p%10, p/=10;
    while(tp) putchar(stk[tp--] + '0');
    putchar('\n');
}

const LL mod=1e9+7;
const double PI=acos(-1.0);
const int inf=1e9;
const int N=5e5+10;
const int maxn=1e3+10;
const double eps=1e-10;

LL w[N],dis[N];
vector<int>ve[N];

int n,in[N],a[2*N],dep[N],cnt=0,out[N];

LL pow_mod(LL x,LL y)
{
    LL s=1,base=x;
    while(y)
    {
        if(y&1)s=s*base%mod;
        base=base*base%mod;
        y>>=1;
    }
    return s;
}

void dfs(int x,int deep,int fa)
{
    cnt++;
    in[x]=cnt;
    a[cnt]=x;
    dep[x]=deep;
    int len=ve[x].size();
    For(i,0,len-1)
    {
        int y=ve[x][i];
        if(y==fa)continue;
        dis[y]=dis[x]*w[y]%mod;
        dfs(y,deep+1,x);
        cnt++;
        a[cnt]=x;
    }
    out[x]=cnt;
}
struct Tree
{
    int l,r,lca;
    LL dis;
}tr[8*N];
void pushdown(int o)
{
    tr[2*o].dis=tr[2*o].dis*tr[o].dis%mod;
    tr[2*o+1].dis=tr[2*o+1].dis*tr[o].dis%mod;
    tr[o].dis=1;
}
void build(int o,int L,int R)
{
    tr[o].l=L;
    tr[o].r=R;
    tr[o].dis=1;
    if(L==R)
    {
        tr[o].dis=dis[a[L]];
        tr[o].lca=a[L];
        return ;
    }
    int mid=(L+R)>>1;
    build(2*o,L,mid);
    build(2*o+1,mid+1,R);
    if(dep[tr[2*o].lca]>=dep[tr[2*o+1].lca])tr[o].lca=tr[2*o+1].lca;
    else tr[o].lca=tr[2*o].lca;
}
void update(int o,int L,int R,LL val)
{
    if(tr[o].l>=L&&tr[o].r<=R)
    {
        tr[o].dis=tr[o].dis*val%mod;
        return ;
    }
    int mid=(tr[o].l+tr[o].r)>>1;

    if(L>mid)update(2*o+1,L,R,val);
    else if(R<=mid)update(2*o,L,R,val);
    else {
        update(2*o,L,mid,val);
        update(2*o+1,mid+1,R,val);
    }
}
int querylca(int o,int L,int R)
{

        if(tr[o].l>=L&&tr[o].r<=R)return tr[o].lca;
        int mid=(tr[o].l+tr[o].r)>>1;
        if(R<=mid)return querylca(2*o,L,R);
        else if(L>mid)return querylca(2*o+1,L,R);
        else 
        {
            int fl=querylca(2*o,L,mid),fr=querylca(2*o+1,mid+1,R);
            if(dep[fl]<=dep[fr])return fl;
            else return fr;
        }
}
LL query(int o,int pos)
{
    if(tr[o].l==tr[o].r&&tr[o].l==pos)return tr[o].dis;
    int mid=(tr[o].l+tr[o].r)>>1;
    pushdown(o);
    if(pos>mid)return query(2*o+1,pos);
    return query(2*o,pos);
}
int main()
{
        read(n);
        For(i,1,n)read(w[i]);
        int u,v;
        For(i,1,n-1)
        {
            read(u);read(v);
            ve[u].push_back(v);
            ve[v].push_back(u);
        }
        dis[1]=w[1];
        dfs(1,0,0);
        build(1,1,cnt);
        int q,f;
        read(q);
        while(q--)
        {
            read(f);read(u);read(v);
            if(f)
            {
                LL temp=w[u];
                w[u]=(LL)v;
                update(1,in[u],out[u],w[u]*pow_mod(temp,mod-2)%mod);
            }
            else
            {
                if(in[u]>in[v])swap(u,v);
                int lca=querylca(1,in[u],in[v]);
                LL temp=query(1,in[lca]);
                temp=pow_mod(temp,mod-2);
                temp=temp*temp%mod;
                LL ans=query(1,in[u])*query(1,in[v])%mod*temp%mod*w[lca]%mod;
                cout<<ans<<"\n";
            }
        }        
        return 0;
}

  

 

posted @ 2016-07-16 23:38  LittlePointer  阅读(557)  评论(2编辑  收藏  举报