SNOI2020 部分题解

D1T1

画图可以发现,多了一条边过后的图是串并联图。(暂时不确定)

然后我们考虑把问题变成,若生成树包含一条边\(e\),则使生成树权值乘上\(a_e\),否则乘上\(b_e\),求最终的生成树权值之和。我们只需要支持删去度数为\(1\)的点,同时删去和它相连的那条边;删去度数为2的点,把两条边合并为一条边;合并重边三种操作。

对于第一种操作,把答案乘上\(a_e\),并删去即可。对于第二种操作,把和这个点相邻的两条边记作\(e_1,e_2\),其中\(e_1\)连接\(u, v\)\(e_2\)连接\(u, w\)。则删去两条边,连接一条端点为\(v, w\)\(a_e = a_{e_1}a_{e_2}, b_e = a_{e_1}b_{e_2} + a_{e_2}b_{e_1}\)的边。对于第三种操作,把\(e_1, e_2\)合并为\(e\)时,\(a_e = a_{e_1}b_{e_2} + a_{e_2}b_{e_1}, b_e = b_{e_1}b_{e_2}\)

代码如下:

#include <bits/stdc++.h>
using namespace std;

const int N = 500005;
const long long mod = 998244353ll;

template <class T>
void read (T &x) {
	int sgn = 1;
	char ch;
	x = 0;
	for (ch = getchar(); (ch < '0' || ch > '9') && ch != '-'; ch = getchar()) ;
	if (ch == '-') ch = getchar(), sgn = -1;	
	for (; '0' <= ch && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
	x *= sgn;
}
template <class T>
void write (T x) {
	if (x < 0) putchar('-'), write(-x);
	else if (x < 10) putchar(x + '0');
	else write(x / 10), putchar(x % 10 + '0');
}

int n, m, deg[N];
long long ans = 1ll;
bool vis[N];
vector<int> g[N];
set<pair<int, int> > se;
map<pair<int, int>, pair<long long, long long> > id;
queue<int> que;

int main () {
	read(n), read(m);
	for (int i = 1; i <= n; i++) deg[i] = 0;
	for (int i = 0; i < m; i++) {
		int u, v;
		read(u), read(v);
		if (u > v) swap(u, v);
		if (u == v) continue;
		if (se.count(make_pair(u, v))) id[make_pair(u, v)].second = (id[make_pair(u, v)].second + 1ll) % mod;
		else {
			deg[u]++, deg[v]++;
			g[u].push_back(v), g[v].push_back(u);
			se.insert(make_pair(u, v));
			id[make_pair(u, v)].first = id[make_pair(u, v)].second = 1ll;
		}
	}
	for (int i = 1; i <= n; i++) {
		vis[i] = false;
		if (deg[i] <= 2) que.push(i);
	}
	while (!que.empty()) {
		int u = que.front();
		que.pop();
		if (vis[u]) continue;
		vis[u] = true;
		vector<int> adj;
		for (int i = 0; i < g[u].size(); i++) {
			if (!vis[g[u][i]]) adj.push_back(g[u][i]);
		}
		if (adj.size() == 1) {
			deg[u]--;
			if (--deg[adj[0]] <= 2) que.push(adj[0]);
			ans = ans * id[make_pair(min(u, adj[0]), max(u, adj[0]))].second % mod;
		}
		else if (adj.size() == 2) {
			pair<int, int> e0(min(u, adj[0]), max(u, adj[0]));
			pair<int, int> e1(min(u, adj[1]), max(u, adj[1]));
			pair<int, int> e(min(adj[0], adj[1]), max(adj[0], adj[1]));
			pair<long long, long long> pi((id[e0].second * id[e1].first + id[e0].first * id[e1].second) % mod, id[e0].second * id[e1].second % mod);
			deg[u] -= 2;
			if (se.count(e)) {
				if (--deg[adj[0]] <= 2) que.push(adj[0]);
				if (--deg[adj[1]] <= 2) que.push(adj[1]);
				id[e] = make_pair(id[e].first * pi.first % mod, (id[e].first * pi.second + id[e].second * pi.first) % mod);
			} 
			else {
				g[adj[0]].push_back(adj[1]);
				g[adj[1]].push_back(adj[0]);
				se.insert(e);
				id[e] = pi;
			}
		}
	}
	write(ans), putchar('\n');
	return 0;
}

D1T2

神奇的找规律题,感觉方向不对就很难找出来。

我们仍然考虑打表,把\(k, n\)比较小的情况打出来(记作\(f_{n, k}\))。然后我们发现固定\(n\)\(f_{n, k}\)的取值很少。再仔细观察,发现取值变化的点恰好是斐波那契数列上的数

这启发我们对于每个数\(i > 1\),找先手第一步最小要取多少,才能保证他获胜。记该最小值为\(a_i\)。然后我们写出这个数列某些前缀:

\(1\)

\(1, 2\)

\(1, 2, 3\)

\(1, 2, 3, 1, 5\)

\(1, 2, 3, 1, 5, 1, 2, 8\)

\(\cdots\)

注意这里我们是找\(1\), \(2\), \(3\), \(5\), \(8\), \(\cdots\)的第一次出现位置为终止的前缀!
然后我们发现第\(i\)个前缀是第\(i - 1, i - 3, i - 5, \cdots\)个前缀拼接上\(F_i\)后的前缀。(\(F_0 = F_1 = 1\)
找到了这个规律后,我们就设\(g_{i, j, k}\)表示在第\(i\)个前缀的前\(j\)个数中\(\leq F_k\)的数字的个数。先通过\(dp\)预处理\(j\)就是第\(i\)个前缀的长度的情况,然后每次询问递归下去做即可。(通过类似线段树的证明,可以知道每次询问递归树的点数是\(O(\log^2 N)\)的。

代码如下:

#include <bits/stdc++.h>
using namespace std;

const int S = 505, T = 100005;

template <class T>
void read (T &x) {
	int sgn = 1;
	char ch;
	x = 0;
	for (ch = getchar(); (ch < '0' || ch > '9') && ch != '-'; ch = getchar()) ;
	if (ch == '-') ch = getchar(), sgn = -1;
	for (; '0' <= ch && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
	x *= sgn;
}
template <class T>
void write (T x) {
	if (x < 0) putchar('-'), write(-x);
	else if (x < 10) putchar(x + '0');
	else write(x / 10), putchar(x % 10 + '0');
}

int t, cnt1 = 2, cnt2 = 1;
long long n[T], m[T], mx, fib[S], a[S], f[S][S];

void init () {
	fib[0] = fib[1] = 1;
	for (int i = 2; ; i++) {
		fib[i] = fib[i - 1] + fib[i - 2];
		if (fib[i] > mx) {
			cnt1 = i;
			break;
		}
	}
	a[1] = 1;
	for (int i = 2; ; i++) {
		a[i] = 1;
		for (int j = i - 1; j >= 1; j -= 2) a[i] += a[j];
		if (a[i] > mx) {
			cnt2 = i;
			break;
		}
	}
	for (int i = 1; i <= cnt1; i++) f[1][i] = 1ll;
	for (int i = 2; i <= cnt2; i++) {
		for (int j = 1; j <= cnt1; j++) {
			f[i][j] = i <= j ? 1ll : 0ll;
			for (int k = i - 1; k >= 1; k -= 2) f[i][j] += f[k][j];
		}
	}
}

long long solve (int x, int y, long long len) {
	if (len == a[x]) return f[x][y];
	long long ans = 0ll;
	for (int i = x - 1; i >= 1; i -= 2) {
		ans += solve(i, y, min(len, a[i]));
		len -= a[i];
		if (len <= 0) break;
	}
	if (len > 0) ans++;
	return ans;
}

int main () {
	read(t);
	for (int i = 1; i <= t; i++) {
		read(m[i]), read(n[i]);
		mx = max(mx, max(m[i], n[i]));
	}
	init();
	for (int i = 1; i <= t; i++) {
		int cnt = 0;
		for (int j = 1; j <= cnt1; j++) {
			if (fib[j] <= m[i]) cnt = j;
		}
		write(solve(cnt2, cnt, n[i] - 1)), putchar('\n');
	}
	return 0;
}

D1T3

直接暴力线段树维护最大子段和就过了,正解不会。

代码如下:

#include <bits/stdc++.h>
using namespace std;

const int N = 100005;

template <class T>
void read (T &x) {
	int sgn = 1;
	char ch;
	x = 0;
	for (ch = getchar(); (ch < '0' || ch > '9') && ch != '-'; ch = getchar()) ;
	if (ch == '-') ch = getchar(), sgn = -1;
	for (; '0' <= ch && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
	x *= sgn;
}
template <class T>
void write (T x) {
	if (x < 0) putchar('-'), write(-x);
	else if (x < 10) putchar(x + '0');
	else write(x / 10), putchar(x % 10 + '0');
}

int n, m;
long long a[N];
struct node {
	long long sum, pre, suf, val;
} sgt[N << 2];
node merge (node a, node b) {
	node ans;
	ans.sum = a.sum + b.sum;
	ans.pre = max(a.pre, a.sum + b.pre);
	ans.suf = max(b.suf, b.sum + a.suf);
	ans.val = max(max(a.val, b.val), a.suf + b.pre);
	return ans;
}
void pushup (int now) {
	sgt[now] = merge(sgt[now << 1], sgt[now << 1 | 1]);
}
void build (int l, int r, int now) {
	int mid = l + r >> 1;
	if (l == r) {
		sgt[now].sum = sgt[now].pre = sgt[now].suf = a[mid];
		sgt[now].val = max(a[mid], 0ll);
	}
	else {
		build(l, mid, now << 1), build(mid + 1, r, now << 1 | 1);
		pushup(now);
	}
}

void change (int pos, int l, int r, int now, long long val) {
	int mid = l + r >> 1;
	if (l == r) {
		sgt[now].sum = sgt[now].pre = sgt[now].suf = val;
		sgt[now].val = max(val, 0ll);
	}
	else {
		if (pos <= mid) change(pos, l, mid, now << 1, val);
		else change(pos, mid + 1, r, now << 1 | 1, val);
		pushup(now);
	}
}
node query (int left, int right, int l, int r, int now) {
	int mid = l + r >> 1;
	if (l == left && r == right) return sgt[now];
	else if (right <= mid) return query(left, right, l, mid, now << 1);
	else if (left > mid) return query(left, right, mid + 1, r, now << 1 | 1);
	else return merge(query(left, mid, l, mid, now << 1), query(mid + 1, right, mid + 1, r, now << 1 | 1));
}

int main () {
	read(n), read(m);
	for (int i = 1; i <= n; i++) read(a[i]);
	build(1, n, 1);
	for (int i = 1; i <= m; i++) {
		int ty;
		read(ty);
		if (ty == 0) {
			int l, r, x;
			read(l), read(r), read(x);
			for (int j = l; j <= r; j++) {
				if (a[j] < x) a[j] = x, change(j, 1, n, 1, x);
			}
		}
		else {
			int l, r;
			read(l), read(r);
			node ans = query(l, r, 1, n, 1);
			write(ans.val), putchar('\n');
		}
	}
	return 0;
}

D2T1

我们首先考虑一个\(O(n^2)\)做法,我们把\(A\)\(B\)的所有长度为\(k\)的串建成trie树。然后相当于在trie树上有\(n - k + 1\)\(A\)类节点和\(B\)类节点,然后你要将\(A\)类节点和\(B\)类节点匹配,使得匹配的总距离之和除以二最小!

这是一个经典的问题,在深度不一致的情况也一样能做。我们只需考虑一条边\(u,v\)所经过的最小的次数,设\(v\)\(u\)的儿子,\(v\)的子树中有\(a\)\(A\)类节点,\(B\)\(B\)类节点,则容易证明至少经过\(\lvert a - b \rvert\)次。且我们可以归纳地构造出方案。

再考虑如何取优化它。只需把\(trie\)树换成把两个字符串拼在一起后构成的后缀树(用sam建立),然后在后缀树上定位\(2(n - k + 1)\)个节点即可。这里定位可以考虑倍增,也可以使用NOI2018你的名字的那种two-pointer的trick。

时间复杂度\(O(n)\)\(O(n \log n)\)不等,可以获得\(100\)分。

代码如下:

#include <bits/stdc++.h>
using namespace std;

const int N = 150005;

int n, m, len[N << 2], par[N << 2], num[N << 2], last = 0, cnt = 0;
char a[N], b[N];
long long ans = 0ll;
map<char, int> ch[N << 2];

void extend (char c) {
	int p = last, np = ++cnt;
	len[np] = len[p] + 1;
	for (; ~p && !ch[p][c]; p = par[p]) ch[p][c] = np;
	if (p < 0) par[np] = 0;
	else {
		int q = ch[p][c];
		if (len[q] == len[p] + 1) par[np] = q;
		else {
			int nq = ++cnt;
			ch[nq] = ch[q], len[nq] = len[p] + 1;
			par[nq] = par[q], par[np] = par[q] = nq;
			for (; ~p && ch[p][c] == q; p = par[p]) ch[p][c] = nq;
		}
	}
	last = np;
}

vector<int> child[N << 2];
void dfs (int u) {
	for (int i = 0; i < child[u].size(); i++) {
		dfs(child[u][i]);
		num[u] += num[child[u][i]];
	}
	if (u) ans += 1ll * max(num[u], -num[u]) * (min(len[u], m) - min(len[par[u]], m));
}

int main () {
	scanf("%d%d", &n, &m);
	scanf("%s%s", &a, &b);
	len[0] = 0, par[0] = -1;
	for (int i = n - 1; i >= 0; i--) extend(b[i]);
	for (int i = n - 1; i >= 0; i--) extend(a[i]);
	for (int i = 0; i <= cnt; i++) num[i] = 0;
	int tmp = 0, now = 0;
	for (int i = n - 1; i >= 0; i--) {
		now = ch[now][b[i]], tmp++;
		if (tmp > m) {
			if (len[par[now]] >= m) now = par[now];
			tmp = m;
		}
		if (i <= n - m) num[now]++;
	}
	for (int i = n - 1; i >= 0; i--) {
		now = ch[now][a[i]], tmp++;
		if (tmp > m) {
			if (len[par[now]] >= m) now = par[now];
			tmp = m;
		}
		if (i <= n - m) num[now]--;
	}
	for (int i = 1; i <= cnt; i++) child[par[i]].push_back(i);
	dfs(0);
	ans >>= 1;
	printf("%lld\n", ans);
	return 0;
}

D2T2

这道毒瘤的题目性质太多,细节也太多,我可能难以给出详细的证明和解释,请见谅。

首先我们设\(0, 1, 2, ..., n, n + 1\)中已经被填的数的集合为\(A\)(约定\(0, n + 1 \in A\)),没有被填的数的集合为\(B\)。我们把\(0, 1, ..., n, n + 1\)按照已填和未填分段,设为\(A_0, B_1, A_1, B_2, ..., B_m, A_m\)。若\(1 \leq l \leq r \leq k\),则区间内部的数已经固定,我们无需考虑。我们只需考虑\(k < l \leq r \leq n,l \leq k < r \leq n\)的部分。为了同时让这两部分最大化,我们猜测有如下结论

存在一个最优解满足:

1.对于\(B_1, B_2, ..., B_m\),它在排列中一定构成递增或递减的连续的一段

2.对于每个\(k < i \leq n\)\(p_{k + 1}, p_{k + 2}, ..., p_i\)必定构成\(B\)的连续的一段。(例如\(B = \{2, 4, 6, 8\}\),则\(4,6,2\)算连续的一段,\(2,4,8\)不算)

有了第二条,我们可以统计未填数的每一种前缀填法对答案的贡献是多少,然后做一个\(O(n^2)\)的区间dp。然而毒瘤的出题人有一个神仙的做法,没有考虑到有这种辣鸡想法,所以就没有给\(O(n^2)\)的部分分

为了优化这两个dp,我们需要考虑性质\(1\)了。若\(\{ p_{k + 1}, p_{k + 2}, ..., p_i \}\)满足\(B_1, ..., B_m\)要么全在里面,要么全不在,则称这个集合是好的,反之称为坏的集合。结果我们会发现,每一个未填的后缀最多只对4个好的集合有贡献。

再考虑坏的集合必定恰好夹在两个好的集合之间。而每一个未填的后缀又最多只对4个好集合之间的坏集合有贡献。

我们把所有有贡献的点(对应的是\(O(n)\)种好集合)称之为关键点,按左端点从大往小排序,相同则按右端点从小到大排序。注意到我们的\(dp\)状态只用考虑关键点,所以状态优化到了\(O(n)\)种。然后转移这一维只对右端点的范围有限制,所以可以用树状数组优化,就变成了\(O(n \log n)\)注意经过坏集合的情况需要单独地转移,同时注意这个dp的起始位置和最终位置也最好作为关键点

当我们用记录方案的dp构造出了排列后,就使用线段树或析合树等方法计算排列连续段即可。总时间复杂度\(O(n \log n)\)

代码如下:


#include <bits/stdc++.h>
using namespace std;

const int N = 200005;

template <class T>
void read (T &x) {
	int sgn = 1;
	char ch;
	x = 0;
	for (ch = getchar(); (ch < '0' || ch > '9') && ch != '-'; ch = getchar()) ;
	if (ch == '-') ch = getchar(), sgn = -1;
	for (; '0' <= ch && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
	x *= sgn;
}
template <class T>
void write (T x) {
	if (x < 0) putchar('-'), write(-x);
	else if (x < 10) putchar(x + '0');
	else write(x / 10), putchar(x % 10 + '0');
}

int n, m, p[N], inv[N], mn[N << 2], tot[N << 2], tag[N << 2], cnt = 0;
bool vis[N];
long long ans = 0ll;

void build (int l, int r, int now) {
	int mid = l + r >> 1;
	mn[now] = tag[now] = 0;
	tot[now] = r - l + 1;
	if (l < r) build(l, mid, now << 1), build(mid + 1, r, now << 1 | 1);
}
void cover (int now, int val) {
	mn[now] += val, tag[now] += val;
}
void pushdown (int now) {
	cover(now << 1, tag[now]), cover(now << 1 | 1, tag[now]);
	tag[now] = 0;
}
void pushup (int now) {
	mn[now] = min(mn[now << 1], mn[now << 1 | 1]);
	tot[now] = 0;
	if (mn[now] == mn[now << 1]) tot[now] += tot[now << 1];
	if (mn[now] == mn[now << 1 | 1]) tot[now] += tot[now << 1 | 1];
}

void change (int left, int right, int l, int r, int now, int val) {
	int mid = l + r >> 1;
	if (l == left && r == right) cover(now, val);
	else {
		pushdown(now);
		if (right <= mid) change(left, right, l, mid, now << 1, val);
		else if (left > mid) change(left, right, mid + 1, r, now << 1 | 1, val);
		else change(left, mid, l, mid, now << 1, val), change(mid + 1, right, mid + 1, r, now << 1 | 1, val);
		pushup(now);
	}
}
int query (int left, int right, int l, int r, int now) {
	int mid = l + r >> 1;
	if (l == left && r == right) return mn[now] == 1 ? tot[now] : 0;
	else {
		pushdown(now);
		if (right <= mid) return query(left, right, l, mid, now << 1);
		else if (left > mid) return query(left, right, mid + 1, r, now << 1 | 1);
		else return query(left, mid, l, mid, now << 1) + query(mid + 1, right, mid + 1, r, now << 1 | 1);
	}
}

pair<long long, int> bit[N];
int lowbit (int x) {
	return x & -x;
}
void init () {
	for (int i = 0; i <= n + 1; i++) bit[i] = make_pair(0ll, -1);
}
void add (int pos, pair<long long, int> pi) {
	for (int i = pos; i <= n + 1; i += lowbit(i)) bit[i] = max(bit[i], pi);
}
pair<long long, int> ask (int pos) {
	pair<long long, int> res(-1ll, -1);
	for (; pos; pos ^= lowbit(pos)) res = max(res, bit[pos]);
	return res;
}

struct node {
	int l, r;
	bool operator < (node rhs) const {
		if (l > rhs.l) return true;
		if (l < rhs.l) return false;
		return r < rhs.r;
	}
} ;
node seg (int l, int r) {
	node i = {l, r};
	return i;
}

int rk[N], le[N], ri[N], blo[N];
long long val1[N], val2[N];
set<node> node_set;
vector<node> node_vec;
long long f[N * 5], trans_val1[N * 5], trans_val2[N * 5], val[N * 5];
int prv[N * 5], trans1[N * 5], trans2[N * 5];
void add_seg1 (int l, int r) {
	int bl = blo[l], br = blo[r];
	if (bl == br) {
		if (bl && le[bl] == l) val1[bl - 1] += ri[bl - 1] - le[bl - 1];
		if (br < blo[n + 1] && ri[br] == r) val2[br + 1] += ri[br + 1] - le[br + 1];
	}
	if (bl < br) node_set.insert(seg(bl + 1, br - 1));
	if (bl && le[bl] == l) node_set.insert(seg(bl - 1, br - 1));
	if (br < blo[n + 1] && ri[br] == r) node_set.insert(seg(bl + 1, br + 1));
	if (bl && le[bl] == l && br < blo[n + 1] && ri[br] == r) node_set.insert(seg(bl - 1, br + 1));
}
void add_seg2 (int l, int r) {
	int bl = blo[l], br = blo[r], u, v;
	if (bl < br) {
		u = lower_bound(node_vec.begin(), node_vec.end(), seg(bl + 1, br - 1)) - node_vec.begin();
		val[u]++;
	}
	if (bl && le[bl] == l) {
		u = lower_bound(node_vec.begin(), node_vec.end(), seg(bl - 1, br - 1)) - node_vec.begin();
		val[u]++;
	}
	if (br < blo[n + 1] && ri[br] == r) {
		u = lower_bound(node_vec.begin(), node_vec.end(), seg(bl + 1, br + 1)) - node_vec.begin();
		val[u]++;
	}
	if (bl && le[bl] == l && br < blo[n + 1] && ri[br] == r) {
		u = lower_bound(node_vec.begin(), node_vec.end(), seg(bl - 1, br + 1)) - node_vec.begin();
		val[u]++;
	}
	if (bl < br && bl && le[bl] == l) {
		u = lower_bound(node_vec.begin(), node_vec.end(), seg(bl + 1, br - 1)) - node_vec.begin();
		v = lower_bound(node_vec.begin(), node_vec.end(), seg(bl - 1, br - 1)) - node_vec.begin();
		trans1[v] = u, trans_val1[v] += ri[bl - 1] - le[bl - 1];
	}
	if (bl < br && br < blo[n + 1] && ri[br] == r) {
		u = lower_bound(node_vec.begin(), node_vec.end(), seg(bl + 1, br - 1)) - node_vec.begin();
		v = lower_bound(node_vec.begin(), node_vec.end(), seg(bl + 1, br + 1)) - node_vec.begin();
		trans2[v] = u, trans_val2[v] += ri[br + 1] - le[br + 1];
	}
	if (bl && br < blo[n + 1] && le[bl] == l && ri[br] == r) {
		u = lower_bound(node_vec.begin(), node_vec.end(), seg(bl + 1, br + 1)) - node_vec.begin();
		v = lower_bound(node_vec.begin(), node_vec.end(), seg(bl - 1, br + 1)) - node_vec.begin();
		trans1[v] = u, trans_val1[v] += ri[bl - 1] - le[bl - 1];
		u = lower_bound(node_vec.begin(), node_vec.end(), seg(bl - 1, br - 1)) - node_vec.begin();
		v = lower_bound(node_vec.begin(), node_vec.end(), seg(bl - 1, br + 1)) - node_vec.begin();
		trans2[v] = u, trans_val2[v] += ri[br + 1] - le[br + 1];
	}
}

bool dir[N];
void seg_init () {
	vis[0] = vis[n + 1] = true;
	rk[0] = le[0] = blo[0] = 0;
	for (int i = 1; i <= n + 1; i++) {
		rk[i] = rk[i - 1] + vis[i];
		if (vis[i] != vis[i - 1]) {
			blo[i] = blo[i - 1] + 1;
			le[blo[i]] = ri[blo[i]] = i;
		}
		else blo[i] = blo[i - 1], ri[blo[i]] = i;
	}

	for (int i = 0; i <= blo[n + 1]; i++) val1[i] = val2[i] = 0ll;
	for (int i = m, l = n + 1, r = 0; i >= 1; i--) {
		l = min(l, p[i]), r = max(r, p[i]);
		if (rk[r] - rk[l] == m - i) add_seg1(l, r);
	}
	node_set.insert(seg(1, blo[n + 1] - 1));
	for (int i = 1; i <= blo[n + 1]; i += 2) node_set.insert(seg(i, i));
	for (set<node> :: iterator it = node_set.begin(); it != node_set.end(); it++) node_vec.push_back(*it);
	for (int i = 0; i < node_vec.size(); i++) {
		f[i] = val[i] = trans_val1[i] = trans_val2[i] = 0ll;
		prv[i] = trans1[i] = trans2[i] = -1;
	}
	for (int i = 1; i <= blo[n + 1]; i += 2) {
		int pos = lower_bound(node_vec.begin(), node_vec.end(), seg(i, i)) - node_vec.begin();
		if (val1[i] >= val2[i]) dir[i] = false, val[pos] += val1[i];
		else dir[i] = true, val[pos] += val2[i];
	}
	for (int i = m, l = n + 1, r = 0; i >= 1; i--) {
		l = min(l, p[i]), r = max(r, p[i]);
		if (rk[r] - rk[l] == m - i) add_seg2(l, r);
	}
}
void get_dp () {
	init();
	for (int i = 0; i < node_vec.size(); i++) {
		if (~trans1[i] && f[i] < f[trans1[i]] + trans_val1[i]) {
			prv[i] = trans1[i];
			f[i] = f[trans1[i]] + trans_val1[i];
		}
		if (~trans2[i] && f[i] < f[trans2[i]] + trans_val2[i]) {
			prv[i] = trans2[i];
			f[i] = f[trans2[i]] + trans_val2[i];
		}
		pair<long long, int> pi = ask(node_vec[i].r);
		if (f[i] < pi.first) prv[i] = pi.second, f[i] = pi.first;
		f[i] += val[i], add(node_vec[i].r, make_pair(f[i], i));
	}
}
void construct () {
	vector<int> stk;
	for (int i = node_vec.size() - 1; ~i; i = prv[i]) stk.push_back(i);
	reverse(stk.begin(), stk.end());
	int l = le[node_vec[stk[0]].l], r = ri[node_vec[stk[0]].r];
	if (blo[l] < blo[r] || dir[blo[l]]) r = l;
	else l = r;
	cnt = m, p[++cnt] = l;
	for (int i = 0; i < stk.size(); i++) {
		int L = le[node_vec[stk[i]].l], R = ri[node_vec[stk[i]].r];
		for (; l > L; ) {
			if (!vis[--l]) p[++cnt] = l;
		}
		for (; r < R; ) {
			if (!vis[++r]) p[++cnt] = r;
		}
	}
}

int main () {
	read(n), read(m);
	for (int i = 1; i <= n; i++) vis[i] = false;
	for (int i = 1; i <= m; i++) read(p[i]), vis[p[i]] = true;
	if (m < n) {
		seg_init();
		get_dp();
		construct();
	}
	ans = 0ll;
	for (int i = 1; i <= n; i++) inv[p[i]] = i;
	build(1, n, 1);
	for (int i = 1; i <= n; i++) {
		change(1, i, 1, n, 1, 1);
		if (p[i] > 1 && inv[p[i] - 1] < i) change(1, inv[p[i] - 1], 1, n, 1, -1);
		if (p[i] < n && inv[p[i] + 1] < i) change(1, inv[p[i] + 1], 1, n, 1, -1);
		ans += query(1, i, 1, n, 1);
	}
	write(ans), putchar('\n');
	for (int i = 1; i <= n; i++) write(p[i]), putchar(' ');
	putchar('\n');
	return 0;
}

D2T3

不会正解,咕掉了。

posted @ 2020-07-10 08:32  unzcjouhi  阅读(361)  评论(0编辑  收藏  举报