P4842 城市旅行

$ \color{#0066ff}{ 题目描述 }$

W国地人物博,有n座城市组成,共n-1条双向道路连接其中的两座城市,且任意两座城市都可相互到达。

风景秀美的w国吸引了无数慕名而来的游客,根据游客对每座城市的打分,我们定义第i座城市的美丽度为\(a_i\)。一次从城市x到城市y的旅行,所获得的的偷悦指数为从城市x到城市y所有城市的美丽度之和(包括X和y)。我们诅义这个值为H(x,y)。

现在小A在城市X,Sharon在城市Y,他们想知道如果在城市X到城市Y之间的所有城市中任选两座城市x和y(x可以等于y),那么H(x,y)的期望值是多少,我们记这个期望值为E(x,y)。

当然,城市之间的交通状况飘忽不定,因此我们不能排除某些时刻某些道路将无法通行。某些时刻会突然添加新的道路。以及游客们审美观的改变,某些城市的美丽度也会发现变化。作为W国负责旅游行业的T君,他要求你来写一个程序来模拟上而的所有过程。

\(\color{#0066ff}{输入格式}\)

第一行两个整数,n,m表示城市个数和操作个数。

接下来一行n个整数,第i个表示\(a_i\)。 接下来n-1行,每行两个整数u,v,表示u和v之间有一条路。 接下来m行,是进行下面的操作:

  • 1 u v 如果城市u和城市v已经无直接连接的道路,则忽略这个操作,否则删除u,v之间的道路。
  • 2 u v 如果城市u和城市v联通那么忽略。否则在u,v之间添加一条道路。
  • 3 u v d 如果城市u和城市v不连通,那么忽略。否则将城市u到城市v的路径中所有城市(包括u和v)的美丽度都增加d。
  • 4 u v 询问E(u,v)的值

\(\color{#0066ff}{输出格式}\)

对于操作4,输出答案,一个经过化简的分数p/q。如果u和v不连通输出-1。

\(\color{#0066ff}{输入样例}\)

4 5
1 3 2 5
1 2
1 3
2 4
4 2 4
1 2 4
2 3 4
3 1 4 1
4 1 4

\(\color{#0066ff}{输出样例}\)

16/3
6/1

\(\color{#0066ff}{数据范围与提示}\)

对于所有数据满足 \(1<=N<=50,000 1<=M<=50,000 1<=a_i<=10^6 1<=D<=100 1<=U,V<=N\)

\(\color{#0066ff}{题解}\)

根据题目,十分显然你需要维护一个LCT,于是。。。先不管维护啥东西,先把板子bia上

然后我们考虑维护啥东西就行了

如果我们把查询的树链抽象成一个长度为L的序列,显然答案就是所有的子段和的和/子段个数

既然是所有子段和的和,那么我们可以考虑每个元素的贡献,即哪个子段会包含它

显然是\(l的范围*r的范围*val\),如果是序列上,暴力我们可以\(O(n)\)统计

现在我们把它放在树上,考虑Splay维护的是一个深度严格递增的链,可以想到l的范围是左子树大小+1,r 的范围是右子树大小+1

但是很明显这个答案当且仅当这个点是Splay的根的时候才是成立的,所以不能直接这样维护

现在的问题是怎么把两个子树的答案合并

比如当前点链长为9, 左子树链长3,右子树链长5,点权按深度依次为\(a_1,a_2\dots a_9\)

那么左子树的答案就是\(1*3*a_1+2*2*a_2+3*1*a_3\)

右子树的答案就是\(1*5*a_5+2*4*a_6+3*3*a_7+4*2*a_8+5*1*a_9\)

而当前点的答案是\(1*9*a_1+2*8*a_2+3*7*a_3+4*6*a_4+5*5*a_5+6*4*a_6+7*3*a_7+8*2*a_8+9*1*a_9\)

观察一下式子,把式子分成三部分,左子树贡献,右子树贡献,自己的贡献

自己的贡献就是\((左siz+1)*(右siz+1)*a_4\)

可以发现,左子树的那些项,到了当前点,实际上是加了\(1*6*a_1+2*6*a_2+3*6*a_3\)

把6提出来,我们发现,需要维护一个\(a_i*i\)

同理,推一下右孩子,发现还需要维护一个\(a_i*(len-i+1)\),就是反过来乘

对于这两个的维护,也要类似上面推一下,发现需要维护子树点权和

然后无修改的询问就处理完了(注意翻转的时候正反的\(a_i*i\)要交换)

如果有修改呢?

肯定是打标记跑不了

那么只需改当前的点的值就行了

把推的式子的\(a_i\)换成\(a_i+d\),看看发生了哪些变化,统计一下就行

由于维护的东西比较多,要是不写哨兵得乱死。。。

// luogu-judger-enable-o2
#include<bits/stdc++.h>
#define LL long long
LL read() {
	char ch; LL x = 0, f = 1;
	while(!isdigit(ch = getchar()))(ch == '-') && (f = -f);
	for(x = ch ^ 48; isdigit(ch = getchar()); x = (x << 1) + (x << 3) + (ch ^ 48));
	return x * f;
}
template<class T> bool chkmax(T &a, const T &b) { return a < b? a = b, 1 : 0; }
template<class T> bool chkmin(T &a, const T &b) { return b < a? a = b, 1 : 0; }
const int maxn = 1e5 + 10;
LL f[maxn];
void predoit(int n) {
	for(int i = 1; i <= n; i++) f[i] = i;
	for(int i = 1; i <= n; i++) f[i] += f[i - 1];
	for(int i = 1; i <= n; i++) f[i] += f[i - 1];
}
LL gcd(LL x, LL y) { return y? gcd(y, x % y) : x; }
struct LCT {
	protected:
		struct node {
			node *ch[2], *fa;
			LL sum, val, add, lsum, rsum, totval;
			int siz, rev;
			node(LL sum = 0, LL val = 0, LL add = 0, LL lsum = 0, LL rsum = 0, LL totval = 0, int siz = 0, int rev = 0): sum(sum), val(val), add(add), lsum(lsum), rsum(rsum), totval(totval), siz(siz), rev(rev) { ch[0] = ch[1] = fa = NULL; }
			void upd() {
				siz = ch[0]->siz + ch[1]->siz + 1;
				totval = ch[0]->totval + ch[1]->totval + val;
				lsum = ch[0]->lsum + ch[1]->lsum + 1LL * (ch[1]->totval + val) * (ch[0]->siz + 1);
				rsum = ch[1]->rsum + ch[0]->rsum + 1LL * (ch[0]->totval + val) * (ch[1]->siz + 1);
				sum = ch[0]->sum + ch[1]->sum + 1LL * ch[0]->lsum * (ch[1]->siz + 1) + 1LL * ch[1]->rsum * (ch[0]->siz + 1) + 1LL * (ch[0]->siz + 1) * (ch[1]->siz + 1) * val; 
			}
			void trnadd(LL v) { 
				if(this->fa == this) return;
				totval += 1LL * siz * v;
				add += v;
				lsum += (1LL * siz * (siz + 1) >> 1) * v;
				rsum += (1LL * siz * (siz + 1) >> 1) * v;
				sum += f[siz] * v;
				val += v;
			}
			void trnrev() { 
				if(this->fa == this) return;
				std::swap(ch[0], ch[1]), rev ^= 1; 
				std::swap(lsum, rsum);
			}
			void dwn() {
				if(rev) {
					ch[0]->trnrev();
					ch[1]->trnrev();
					rev = 0;
				}
				if(add) {
					ch[0]->trnadd(add);
					ch[1]->trnadd(add);
					add = 0;
				}
			}
			bool ntr() { return (fa->ch[0] == this || fa->ch[1] == this); }
			bool isr() { return this == fa->ch[1]; }
		}pool[maxn], *null;
		void rot(node *x) {
			node *y = x->fa, *z = y->fa;
			bool k = x->isr(); node *w = x->ch[!k];
			if(y->ntr()) z->ch[y->isr()] = x;
			(x->ch[!k] = y)->ch[k] = w;
			(y->fa = x)->fa = z;
			if(w != null) w->fa = y;
			y->upd(), x->upd();
		}
		void splay(node *o) {
			static node *st[maxn]; int top;
			st[top = 1] = o;
			while(st[top]->ntr()) st[top + 1] = st[top]->fa, top++;
			while(top) st[top--]->dwn();
			while(o->ntr()) {
				if(o->fa->ntr()) rot(o->isr() ^ o->fa->isr()? o : o->fa);
				rot(o);
			}
		}
		void access(node *x) {
			for(node *y = null; x != null; x = (y = x)->fa)
				splay(x), x->ch[1] = y, x->upd();
		}
		void makeroot(node *x) { access(x), splay(x), x->trnrev(); }
		node *findroot(node *x) {
			access(x), splay(x);
			while(x->dwn(), x->ch[0] != null) x = x->ch[0];
			return x;
		}
	public:
		void link(int l, int r) {
			node *x = pool + l, *y = pool + r;
			if(findroot(x) == findroot(y)) return;
			makeroot(x), x->fa = y;
		}
		void cut(int l, int r) {
			node *x = pool + l, *y = pool + r;
			makeroot(x), access(y), splay(y);
			if(y->ch[0] == x && x->ch[1] == null) y->ch[0] = x->fa = null, y->upd();
		}
		void add(int l, int r, LL val) {
			node *x = pool + l, *y = pool + r;
			if(findroot(x) != findroot(y)) return;
			makeroot(x), access(y), splay(y);
			y->trnadd(val);
		}
		void query(int l, int r) {
			node *x = pool + l, *y = pool + r;
			if(findroot(x) != findroot(y)) return (void)(puts("-1"));
			makeroot(x), access(y), splay(y);
			LL dn = 1LL * y->siz * (y->siz + 1) >> 1;
			LL up = y->sum;
			LL gc = gcd(up, dn);
			printf("%lld/%lld\n", up / gc, dn / gc);
		}
		void set(int x, LL val) { 
			pool[x].val = val, pool[x].siz = 1; 
			pool[x].ch[0] = pool[x].ch[1] = pool[x].fa = null;
			pool[x].upd();
		}
		void init() {
			null = new node();
			null->ch[0] = null->ch[1] = null->fa = null;
		}
}T;
int main() {
	int n = read(), m = read();
	T.init(); predoit(n);
	for(int i = 1; i <= n; i++) T.set(i, read());
	for(int i = 1; i < n; i++) T.link(read(), read());
	LL p, x, y, d;
	while(m --> 0) {
		p = read();
		if(p == 1) x = read(), y = read(), T.cut(x, y);
		if(p == 2) x = read(), y = read(), T.link(x, y);
		if(p == 3) x = read(), y = read(), d = read(), T.add(x, y, d);
		if(p == 4) x = read(), y = read(), T.query(x, y);
	}
	return 0;
}
posted @ 2019-04-04 07:15  olinr  阅读(222)  评论(0编辑  收藏  举报