洛谷 「P6475 [NOI Online #2 入门组] 建设城市」
\(\texttt{Description}\)
求满足如下条件的序列 \(a\) 的数量:
-
长度为 \(2n\)。
-
\(\forall i\in[1,n],a_i\in[1,m]\) 且 \(a_i\) 为正整数。
-
前 \(n\) 项单调不降,后 \(n\) 项单调不增。
-
要求 \(a_x=a_y\)。
答案对 \(998244353\) 取模。
\(\texttt{Data Range:}1\le x<y\le2n,1\le n,m\le10^5\)
\(\texttt{Solution}\)
分两种情况讨论。
-
若 \(x\) 和 \(y\) 在异侧,即 \(x\le n,y>n\)。
首先枚举 \(x\) 和 \(y\) 的值 \(i\),然后分四段来看。
第一段是 \(1\sim x-1\),第二段是 \(x+1\sim n\),第三段是 \(n+1\sim y-1\),第四段是 \(y+1\sim2n\)。
对于第一段,需要单调不降且范围是 \(1\sim i\),那么可以看成在这些数之间插 \(i-1\) 块板,将其分为 \(i\) 段,第 \(1\) 段代表值为 \(1\) 的,以此类推。
这样就满足了单调不降的限制,我们又知道插板法的公式,所以第一段的答案就求出来了,第二、三、四段同理。
用乘法原理把四段答案相乘,再用加法原理把每一次枚举出的答案相加即可。 -
若 \(x\) 和 \(y\) 在同侧,即 \(x\le n,y\le n\) 或 \(x>n,y>n\)。
这里可以继续沿用上面的 trick。
于是没了 qwq
\(\texttt{Code}\)
#include <cstdio>
#include <iostream>
using namespace std;
typedef long long ll;
const ll mod = 998244353;
ll ans, fac[300005];
inline void init(int n) {
fac[0] = 1;
for (int i = 1; i <= n; i++) fac[i] = fac[i - 1] * i % mod;
}
inline ll quick_pow(ll a, ll k, ll p) {
ll res = 1;
a %= p;
while (k) {
if (k & 1) res = res * a % p;
a = a * a % p;
k >>= 1;
}
return res;
}
inline ll inv(ll a, ll p) {return quick_pow(a, p - 2, p);}
inline ll get_C(int n, int m, ll p) {
if (n < m) return 0;
return fac[n] * inv(fac[m] * fac[n - m], p) % p;
}
int main() {
init(3e5);
int m, n, x, y;
scanf("%d %d %d %d", &m, &n, &x, &y);
if ((x <= n) ^ (y <= n)) {
for (int i = 1; i <= m; i++) {
ll tmp1 = get_C(x + i - 2, i - 1, mod);
ll tmp2 = get_C(n - x + m - i, m - i, mod);
ll tmp3 = get_C(y - n - 1 + m - i, m - i, mod);
ll tmp4 = get_C(n * 2 - y + i - 1, i - 1, mod);
ans = (ans + tmp1 * tmp2 % mod * tmp3 % mod * tmp4 % mod) % mod;
}
printf("%lld", ans);
return 0;
}
if (x <= n && y <= n) {
for (int i = 1; i <= m; i++) {
ll tmp1 = get_C(x + i - 2, i - 1, mod);
ll tmp2 = get_C(n - y + m - i, m - i, mod);
ll tmp3 = get_C(n + m - 1, m - 1, mod);
ans = (ans + tmp1 * tmp2 % mod * tmp3 % mod) % mod;
}
printf("%lld", ans);
return 0;
}
for (int i = 1; i <= m; i++) {
ll tmp1 = get_C(n + m - 1, m - 1, mod);
ll tmp2 = get_C(x - n + m - i - 1, m - i, mod);
ll tmp3 = get_C(n * 2 - y + i - 1, i - 1, mod);
ans = (ans + tmp1 * tmp2 % mod * tmp3 % mod) % mod;
}
printf("%lld", ans);
return 0;
}

浙公网安备 33010602011771号