# 【LG4437】[HNOI/AHOI2018]排列

## 题解

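设 $a$、$b$ 是两个内部顺序已确定的块，长度分别为 $m_1,m_2$，权值和分别为 $W_a,W_b$，并且它们紧接在第 $i$ 个位置之后放置。先放 $a$ 再放 $b$（记为 $W_{ab}$）与先放 $b$ 再放 $a$（记为 $W_{ba}$）的贡献分别为：

$$
\begin{aligned}
W_{ab}&=\sum_{j=1}^{m_1}(i+j)\,w_{a_j}+\sum_{j=1}^{m_2}(i+j+m_1)\,w_{b_j}\\
W_{ba}&=\sum_{j=1}^{m_2}(i+j)\,w_{b_j}+\sum_{j=1}^{m_1}(i+j+m_2)\,w_{a_j}\\
W_{ab}-W_{ba}&=m_1W_b-m_2W_a
\end{aligned}
$$

因此 $a$ 应排在 $b$ 之前当且仅当 $\frac{W_a}{m_1}<\frac{W_b}{m_2}$，即平均权值越小的块越应靠前。贪心时每次取出平均权值最小的块，把它接到其父亲所在块的后面并合并。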

## 代码

#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
using namespace std;
// Reads one decimal integer, optionally preceded by '-', from stdin.
// Any other leading characters are skipped.
// Returns 0 once end-of-input is reached, and also when a '-' is not followed
// by a digit. At EOF it no longer spins forever.
// `ch` is an int rather than a char so that EOF stays distinguishable and
// isdigit() only ever receives values it is defined for.
// The `register` keyword is dropped: it was removed from the language in C++17.
inline int gi() {
    int data = 0, w = 1;
    int ch = getchar();
    while (ch != EOF && !isdigit(ch) && ch != '-') ch = getchar();
    if (ch == '-') w = -1, ch = getchar();
    while (isdigit(ch)) data = 10 * data + (ch - '0'), ch = getchar();
    return w * data;
}
const int MAX_N = 5e5 + 5;

// Adjacency list stored as a "forward star".
// - fir[u] is the index in e[] of the edge most recently added from u, or -1
//   if u has no edges.
// - e[k].next is the index of the previous edge from the same source (-1 ends
//   the chain).
// - e[k].to is the target of edge k.
// - e_cnt is the number of edges stored so far.
struct Graph { int next, to; } e[MAX_N << 1]; int fir[MAX_N], e_cnt;

// Removes all edges: sets every head to -1 and resets the edge counter.
void clearGraph() { memset(fir, -1, sizeof(fir)); e_cnt = 0; }

// Adds the directed edge u -> v at the front of u's edge chain.
// Uses standard aggregate initialization. The former `(Graph){...}` is a C99
// compound literal, which C++ only accepts as a compiler extension.
void Add_Edge(int u, int v) { e[e_cnt] = Graph{fir[u], v}; fir[u] = e_cnt++; }
// Global per-node state, indexed by node id. Node 0 is the start of dfs.

// vis[x]: set once dfs has reached x.
bool vis[MAX_N];

// N: number of input nodes.
int N;
// tot: number of nodes dfs has reached.
int tot;
// pa[x]: union-find parent used by getf.
int pa[MAX_N];
// fa[x]: the parent of x as read from the input.
int fa[MAX_N];
// size[x]: node count of the merged block whose root is x.
// NOTE(review): with `using namespace std`, an unqualified `size` can clash
// with C++17's std::size; qualify uses as `::size`.
int size[MAX_N];

// w[x]: weight of x, later the total weight of the block rooted at x.
long long w[MAX_N];
void dfs(int x) {
vis[x] = 1, ++tot;
for (int i = fir[x]; ~i; i = e[i].next) {
int v = e[i].to;
if (vis[v]) { puts("-1"); exit(0); }
else dfs(v);
}
}
// Union-find "find": returns the root of x's set.
// It also points every node on the path from x directly at that root
// (path compression).
int getf(int x) {
    int root = x;
    while (pa[root] != root) root = pa[root];
    // Second pass: re-link each node on the path straight to the root.
    while (pa[x] != root) {
        int nxt = pa[x];
        pa[x] = root;
        x = nxt;
    }
    return root;
}

// Heap entry for a merged block: representative node u, number of nodes sz,
// and total weight w.
struct Node { int u, sz; long long  w; } ;

// Exact three-way comparison of the averages aw/asz and bw/bsz.
// Requires asz > 0 and bsz > 0.
// Returns -1 if aw/asz < bw/bsz, 0 if they are equal, and 1 if greater.
// Cross-multiplying (aw * bsz) can overflow long long when block weights and
// sizes are large. Comparing integer quotients first and then fractional
// remainders keeps every product below asz * bsz.
static int compareAverage(long long aw, int asz, long long bw, int bsz) {
    long long aq = aw / asz, ar = aw % asz;
    if (ar < 0) ar += asz, --aq;  // floor division: 0 <= ar < asz
    long long bq = bw / bsz, br = bw % bsz;
    if (br < 0) br += bsz, --bq;  // floor division: 0 <= br < bsz
    if (aq != bq) return aq < bq ? -1 : 1;
    long long lhs = ar * bsz, rhs = br * asz;  // compare ar/asz with br/bsz
    return (lhs > rhs) - (lhs < rhs);
}

// Heap ordering: l < r exactly when l's average weight w/sz is larger than r's.
// A max-heap built on this therefore surfaces the block with the smallest
// average weight.
bool operator < (const Node &l, const Node &r) { return compareAverage(l.w, l.sz, r.w, r.sz) > 0; }
struct Heap{
Node h[MAX_N]; int cur;
Node top() { return h[1]; }
void push(const Node &x) { h[++cur] = x; push_heap(&h[1], &h[cur + 1]); }
void pop() { pop_heap(&h[1], &h[cur + 1]); --cur; }
bool empty() { return cur == 0; }
} que;
int main () {
#ifndef ONLINE_JUDGE
freopen("cpp.in", "r", stdin);
#endif
clearGraph();
N = gi();
for (int i = 1; i <= N; i++) fa[i] = gi(), Add_Edge(fa[i], i);
for (int i = 1; i <= N; i++) w[i] = gi();
dfs(0); if (tot <= N) return puts("-1") & 0;
for (int i = 0; i <= N; i++) pa[i] = i, size[i] = 1;
for (int i = 1; i <= N; i++) que.push((Node){i, 1, w[i]});
long long ans = 0;
while (!que.empty()) {
Node p = que.top(); que.pop();
int u = getf(p.u);
if (size[u] != p.sz) continue;
int f = getf(fa[u]); pa[u] = f;
ans += w[u] * size[f], w[f] += w[u], size[f] += size[u];
if (f) que.push((Node){f, size[f], w[f]});
}
printf("%lld\n", ans);
return 0;
} 
posted @ 2019-11-04 16:25  heyujun  阅读(...)  评论(...)  编辑  收藏