Codeforces 1580F - Problems for Codeforces(分治+FFT)
好题,可能是这场 div1 质量最高的题(
首先思考如何判定一个序列是不是好的,考虑这样一个策略:令 \(X=\lceil\dfrac{m}{2}\rceil\),那么如果存在两个相邻的 \(\ge X\) 的数那么显然不符合要求,否则我们将环上所有相邻且均 \(<X\) 的位置找出来,在所有这样的两个位置之间切一刀,这样可以得到许多段,令 \(m\leftarrow\lfloor\dfrac{m}{2}\rfloor\),然后对每一段分别判定即可。
考虑用类似于分治的思想来解决这个问题。即对 \(m\) 进行分治。考虑一次分治过程,我们要解决 \(m=M\) 时的子问题。方便起见我们假设 \(n\) 是奇数,\(M\) 是偶数。我们考虑将 \(\ge\dfrac{M}{2}\) 的数记作 \(1\),将 \(<\dfrac{M}{2}\) 的数记作 \(0\),那么由于 \(n\) 是奇数,必然存在相邻两位置 \(i\) 和 \(i\bmod n+1\),满足它们都是 \(0\)。我们按照上面的思路,将序列从所有这样的位置处断开,那么我们肯定会得到若干个连续段,每段内部肯定形如 \(010101…10\),不妨假设段长为 \(len\)(显然,\(len\) 应为奇数)。那么显然如果我们将所有 \(1\) 的位置上的数都减去 \(\dfrac{M}{2}\),我们将得到一个长度为 \(len\) 的、且任意相邻两项之和 \(<\dfrac{M}{2}\) 的序列(注意:此处头尾两项不算相邻!),记这样的序列个数为 \(F(\dfrac{M}{2},i)\)。
考虑倘若我们已经知道了 \(F\),如何求解答案,记 \(dp_i\) 表示用长度为奇数且任意相邻两项之和 \(<\dfrac{M}{2}\) 的序列 拼成长度为 \(i\) 的段的方案数,有 \(dp_i=\sum\limits_{j=0}^{i-1}dp_j·F(\dfrac{M}{2},i-j)\),可以使用多项式求逆优化到 \(n\log n\)。解出 DP 数组后我们枚举 \(1\) 所在的段的长度 \(l\),显然这一段的起点有 \(l\) 种可能的位置,于是 \(ans=\sum\limits_{l=1}^nl·F(\dfrac{M}{2},l)·dp_{n-l}\)。
接下来考虑怎么求 \(F\),考虑一段长度为奇数且相邻元素之和 \(<X\) 的序列怎么构成。我们还是记 \(<\dfrac{X}{2}\) 的为 \(0\),\(\ge\dfrac{X}{2}\) 的为 \(1\),那么考虑取出第一个相邻的 \(0\)(这里,首尾相邻也算相邻,方便后面的计算),那么可以发现,除了开头或结尾都是 \(0\) 的情况,其余情况下第一段和最后一段的长度都是偶数,而实际上对于开头或结尾是 \(0\) 的情况,我们也可以视作长度为 \(0\) 的连续段(毕竟,长度为 \(1\) 且唯一的元素为 \(1\) 的序列个数,等于长度为 \(0\) 的序列个数),也就是说符合条件的序列都可以视作开头一段长度为偶数的连续段,中间若干段长度为奇数的连续段,最后又有一个长度为偶数的连续段。不过仔细想想发现我们漏了一种情况:由于开头和结尾不算相邻,因此对于形如 \(101…01\) 的序列,它实际上是符合要求的序列,但由于我们想当然地认为长度为奇数的序列只可能形如 \(0101..010\),我们就没有将其统计在内。因此我们考虑分治地计算 \(F(\dfrac{X}{2},i)\),如果我们记这部分中下标为奇数部分的 OGF 为 \(A\),下标为偶数部分的 OGF 为 \(B\),那么 \(F(X,i)\) 的 OGF 就是:
同样可以 FFT 求。
这样我们就解决了 \(n\) 为奇数,每一步 \(M\) 都是偶数也就是 \(m\) 是 \(2\) 的整数次幂的情况,接下来考虑在原算法的基础上,解决 \(n\) 是偶数或某一步 \(M\) 是奇数的情况。
首先是 \(n\) 是偶数的情况,不难发现加上 \(n\) 是偶数的时候,“一定存在两个相邻的 \(0\)”的结论就不一定成立了,这时候有且仅有两种情况会被我们漏算:\(1010…10\) 和 \(0101…01\),递归的时候额外传个参数记录下这种情况的方案数即可。具体细节见代码。
接下来是某一步 \(M\) 是奇数的情况。一个让我们比较为难的地方是,将 \(\ge\lceil\dfrac{M}{2}\rceil\) 的数减去 \(\lceil\dfrac{M}{2}\rceil\) 后,会导致两边范围不一样,不过仔细想一下就会发现 \(\lceil\dfrac{M}{2}\rceil-1\) 这个数很特别:它不能与 \(1\) 挨着,因此它只可能出现在长度为 \(1\) 的连续段中,因此只用特殊处理下 \(F(M,1)\) 即可。
时间复杂度 \(n\log n\log m\)。
#include <bits/stdc++.h>
// #include <ext/pb_ds/assoc_container.hpp>
// #include <ext/pb_ds/hash_policy.hpp>
// #include <ext/pb_ds/priority_queue.hpp>
using namespace std;
// using namespace __gnu_pbds;
#define fi first
#define se second
#define fill0(a) memset(a, 0, sizeof(a))
#define fill1(a) memset(a, -1, sizeof(a))
#define fillbig(a) memset(a, 63, sizeof(a))
#define pb push_back
#define ppb pop_back
#define mp make_pair
#define mt make_tuple
#define eprintf(...) fprintf(stderr, __VA_ARGS__)
template <typename T1, typename T2> void chkmin(T1 &x, T2 y){
if (x > y) x = y;
}
template <typename T1, typename T2> void chkmax(T1 &x, T2 y){
if (x < y) x = y;
}
typedef pair<int, int> pii;
typedef long long ll;
typedef unsigned int u32;
typedef unsigned long long u64;
typedef long double ld;
#ifdef FASTIO
namespace fastio {
#define FILE_SIZE 1 << 23
char rbuf[FILE_SIZE], *p1 = rbuf, *p2 = rbuf, wbuf[FILE_SIZE], *p3 = wbuf;
inline char getc() {
return p1 == p2 && (p2 = (p1 = rbuf) + fread(rbuf, 1, FILE_SIZE, stdin), p1 == p2) ? -1: *p1++;
}
inline void putc(char x) {(*p3++ = x);}
template <typename T> void read(T &x) {
x = 0; char c = getc(); T neg = 0;
while (!isdigit(c)) neg |= !(c ^ '-'), c = getc();
while (isdigit(c)) x = (x << 3) + (x << 1) + (c ^ 48), c = getc();
if (neg) x = (~x) + 1;
}
template <typename T> void recursive_print(T x) {
if (!x) return;
recursive_print (x / 10);
putc (x % 10 ^ 48);
}
template <typename T> void print(T x) {
if (!x) putc('0');
if (x < 0) putc('-'), x = -x;
recursive_print(x);
}
template <typename T> void print(T x,char c) {print(x); putc(c);}
void readstr(char *s) {
char c = getc();
while (c <= 32 || c >= 127) c = getc();
while (c > 32 && c < 127) s[0] = c, s++, c = getc();
(*s) = 0;
}
void printstr(string s) {
for (int i = 0; i < s.size(); i++) putc(s[i]);
}
void printstr(char *s) {
int len = strlen(s);
for (int i = 0; i < len; i++) putc(s[i]);
}
void print_final() {fwrite(wbuf, 1, p3 - wbuf, stdout);}
}
using namespace fastio;
#endif
const int MAXN = 50000;
const int MAXP = 262144;
const int pr = 3;
const int ipr = 332748118;
const int MOD = 998244353;
int n, m;
typedef vector<int> vi;
void add(int &x, int y) {((x += y) >= MOD) && (x -= MOD);}
int qpow(int x, int e) {
int ret = 1;
for (; e; e >>= 1, x = 1ll * x * x % MOD)
if (e & 1) ret = 1ll * ret * x % MOD;
return ret;
}
int rev[MAXP + 5], LEN;
void FFT(vi &a, int len, int type) {
int lg = 31 - __builtin_clz(len);
for (int i = 0; i < len; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << lg - 1);
for (int i = 0; i < len; i++) if (rev[i] < i) swap(a[rev[i]], a[i]);
static int W[MAXP + 5];
for (int i = 0; i < len; i++) W[i] = 0; W[0] = 1;
for (int i = 2; i <= len; i <<= 1) {
int wW = qpow((type < 0) ? ipr : pr, (MOD - 1) / i);
for (int j = (i >> 1) - 1; ~j; j--) W[j] = (j & 1) ? (1ll * W[j >> 1] * wW % MOD) : W[j >> 1];
for (int j = 0; j < len; j += i) {
for (int k = 0; k < (i >> 1); k++) {
int X = a[j + k], Y = 1ll * W[k] * a[(i >> 1) + j + k] % MOD;
a[j + k] = (X + Y) % MOD; a[(i >> 1) + j + k] = (X - Y + MOD) % MOD;
}
}
}
if (type < 0) {
int ivn = qpow(len, MOD - 2);
for (int i = 0; i < len; i++) a[i] = 1ll * a[i] * ivn % MOD;
}
}
vi conv(vi a, vi b) {
int LEN = 1; while (LEN < a.size() + b.size()) LEN <<= 1;
a.resize(LEN, 0); b.resize(LEN, 0); FFT(a, LEN, 1); FFT(b, LEN, 1);
for (int i = 0; i < LEN; i++) a[i] = 1ll * a[i] * b[i] % MOD;
FFT(a, LEN, -1); return a;
}
vi getinv(vi a, int len) {
vi b(len, 0); b[0] = qpow(a[0], MOD - 2);
for (int i = 2; i <= len; i <<= 1) {
vi c(b.begin(), b.begin() + (i >> 1));
vi d(a.begin(), a.begin() + i);
c = conv(c, c); d = conv(c, d);
for (int j = 0; j < i; j++) b[j] = (2 * b[j] % MOD - d[j] + MOD) % MOD;
}
return b;
}
pair<vi, pii> solve(int x) {
if (x == 1) {
vector<int> F(n + 1);
for (int i = 0; i <= n; i++) F[i] = 1;
return mp(F, mp(1, 0));
// se.fi: num of seqs with at least one adjacent numbers < floor(x/2)
// se.se: num of seqs with no one adjacent numbers < floor(x/2)
}
auto T = solve(x >> 1);
vi _F = T.fi, F(n + 1), odd(LEN), even(n + 1);
_F[1] = x + 1 >> 1; odd[0] = 1;
for (int j = 0; j <= n; j++) if (j & 1) odd[j] = (MOD - _F[j]) % MOD;
vector<int> H = getinv(odd, LEN);
for (int j = 0; j <= n; j++) if (~j & 1) even[j] = _F[j];
vector<int> mul = conv(conv(even, even), H);
for (int j = 0; j <= n; j++) F[j] = (mul[j] + ((j & 1) ? _F[j] : 0)) % MOD;
int ss1 = 0, ss2 = 2ll * (T.se.fi + T.se.se) % MOD;
for (int j = 1; j <= n; j++) if (j & 1) ss1 = (ss1 + 1ll * _F[j] * j % MOD * H[n - j]) % MOD;
return mp(F, mp(ss1, ss2));
}
int main() {
scanf("%d%d", &n, &m); LEN = 1; while (LEN <= n + n) LEN <<= 1;
auto T = solve(m); printf("%d\n", (T.se.fi + (~n & 1) * T.se.se) % MOD);
return 0;
}

浙公网安备 33010602011771号