Luogu P3862 数圈 题解 [ 蓝 ] [ 递推 ] [ 打表 ]

数圈:简单题,主要是递推的思维。

先考虑前三个部分分,首先这玩意是可以容斥算的,只需要求出 \(n\) 个点的完全图的环数,和 \(n\) 个点的无向完全图,经过某条特定边的环数是多少,相减即可得到答案。

直接做显然不好做,容易发现可以采用递推的思想来算。定义 \(f_i\) 表示 \(i\) 个点的无向完全图的环数,则有递推式:

\[f_{i} = f_{i - 1} + \dfrac{\sum_{j = 2}^{i - 1}A_{i - 1}^{j}}{2} \]

具体地,可以想象为统计新加入的点的贡献,然后枚举给这个点贡献的环长即可,注意一个环正反会被数两次,要去重。

而对于 \(i\) 点无向完全图中,经过某条特定边的环数 \(g_i\),显然也可以枚举经过这条边的路径的长度进行计算:

\[g_i = \sum_{j = 1}^{i - 1}A_{i - 1}^{j} \]

为了快速求出 \(\sum_{j = 1}^{i - 1}A_{i - 1}^{j}\),则也需要对其进行递推。设 \(sm = \sum_{j = 1}^{i }A_{i}^{j}\),有:

\[sm_{i} = sm_{i - 1}\times i + i \]

之后就可以利用 \(sm\) 快速计算 \(f, g\) 了。于是最终答案就是 \(f_n - g_n\)

注意到最后一个 Subtask 的取值范围只有 \(10^6\),于是打表打出 \(n = 9.99\times 10^8\) 的答案,然后接着它的结果继续递推即可。

时间复杂度 \(O(Tn)\)。其中 \(n = 10^6\)

#include <bits/stdc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
using pi = pair<int, int>;
const int N = 100005;
const ll mod = 998244353, inv2 = 499122177;
ll n;
void get_999000000()
{
    ll sm0 = 1, sm1 = 4, sm2 = 15, f = 1;
    for(int i = 4; i <= 999000000; i++)
    {
        f = ((f + (sm2 - i + 1) * inv2 % mod) % mod + mod) % mod;
        ll tmp = (sm2 * i % mod + i) % mod;
        sm0 = sm1;
        sm1 = sm2;
        sm2 = tmp;
    }
    cerr << sm0 << " " << sm1 << " " << sm2 << " " << f << endl;
}
void Sub4(int n)
{
    ll sm0 = 653816132, sm1 = 453961452, sm2 = 915343230, f = 930477414;
    for(int i = 999000001; i <= n; i++)
    {
        f = ((f + (sm2 - i + 1) * inv2 % mod) % mod + mod) % mod;
        ll tmp = (sm2 * i % mod + i) % mod;
        sm0 = sm1;
        sm1 = sm2;
        sm2 = tmp;
    }    
    cout << ((f - sm0) % mod + mod) % mod << "\n";
}
void Sub3(int n)
{
    ll sm0 = 1, sm1 = 4, sm2 = 15, f = 1;
    for(int i = 4; i <= n; i++)
    {
        f = ((f + (sm2 - i + 1) * inv2 % mod) % mod + mod) % mod;
        ll tmp = (sm2 * i % mod + i) % mod;
        sm0 = sm1;
        sm1 = sm2;
        sm2 = tmp;
    }    
    cout << ((f - sm0) % mod + mod) % mod << "\n";
}
void solve()
{
	cin >> n;
    if(n > 100000)
    {
        Sub4(n);
        return;
    }
    if(n > 3)
    {
        Sub3(n);
        return;
    }
    cout << "0\n";
}
int main()
{
    //freopen("sample.in", "r", stdin);
    //freopen("sample.out", "w", stdout);
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    int t;
    cin >> t;
    while(t--) solve();
    return 0;
}
posted @ 2025-10-25 00:14  KS_Fszha  阅读(3)  评论(0)    收藏  举报