CSP-S模拟20

没时间写了,所以这篇写得很简洁……

A. 归隐

给了我推式子的自信,好像只用到了等比数列求和。

code
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
const int mod = 998244353;
const ll inv2 = 499122177;

ll n, ans;

// Fast integer reader for stdin.
// Skips every character that is not a decimal digit, then parses the
// following run of digits. If a '-' occurs anywhere in the skipped prefix
// the result is negated.
// Returns the parsed value, or 0 when the input ends before any digit.
inline ll read()
{
    ll x = 0, f = 1;
    // getchar() returns an int so that EOF is distinguishable from every
    // real character. Storing it in a char and never testing for EOF made
    // the skip loop below spin forever once the input was exhausted.
    int ch = getchar();
    while(ch != EOF && (ch < '0' || ch > '9'))
    {
        if(ch == '-')
        {
            f = -1;
        }
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}

// Binary exponentiation: computes a^b modulo `mod`.
// Intended for b >= 0; the loop runs until b has been shifted down to zero.
ll qpow(ll a, ll b)
{
    ll result = 1;
    for(; b; b >>= 1, a = a * a % mod)
    {
        // Multiply in the current power of a whenever the low bit of b is set.
        if(b & 1) result = result * a % mod;
    }
    return result;
}

int main()
{
    freopen("gy.in", "r", stdin);
    freopen("gy.out", "w", stdout);
    
    n = read();
    //ans = ((qpow(3, n-1)-1+mod)%mod*inv2%mod+1)%mod; is a[i]
    ans = (((qpow(3, n)-1+mod)%mod*inv2%mod-n%mod+mod)%mod*inv2%mod+n%mod)%mod;
    printf("%lld\n", ans);

    return 0;
}

 

B. 按位或

我觉得模 3 的余数应该是“个数 × 每一位的贡献”再相加:奇数位上 2^i mod 3 = 2,偶数位上 2^i mod 3 = 1,所以选了 p 个奇数位、q 个偶数位时余数是 (2*p+q) mod 3。于是把 Chen_jr 代码里的判断改成了 (2*p+q)%3==0,也是对的。

容斥就是第一个是加法,后面的和上一个状态相反。

code
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
const int maxn = 2003;
const ll mod = 998244353;
const ll inv2 = 499122177;

ll n, t;
int c0, c1, f[75][75], c[75][75];

// Fast integer reader for stdin.
// Skips every character that is not a decimal digit, then parses the
// following run of digits. If a '-' occurs anywhere in the skipped prefix
// the result is negated.
// Returns the parsed value, or 0 when the input ends before any digit.
inline ll read()
{
    ll x = 0, f = 1;
    // Keep the getchar() result in an int and stop on EOF: with a char and
    // no EOF test the skip loop never terminated on exhausted input.
    int ch = getchar();
    while(ch != EOF && (ch < '0' || ch > '9'))
    {
        if(ch == '-')
        {
            f = -1;
        }
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}

// Square-and-multiply: returns a^b modulo `mod`.
// Intended for b >= 0.
ll qpow(ll a, ll b)
{
    ll acc = 1;
    while(b != 0)
    {
        if((b & 1) != 0)
        {
            acc = acc * a % mod;
        }
        a = a * a % mod;
        b >>= 1;
    }
    return acc;
}

int main()
{
    freopen("or.in", "r", stdin);
    freopen("or.out", "w", stdout);
    
    n = read(); t = read();
    for(int i=0; i<=70; i++)
    {
        c[i][0] = 1;
        for(int j=1; j<=i; j++)
        {
            c[i][j] = (c[i-1][j]+c[i-1][j-1])%mod;
        }
    }
    for(int i=0; i<=60; i++) if(t& (1ll<<i))
    {
        if(i & 1) c1++; else c0++;
    }
    for(int i=0; i<=c1; i++)
    {
        for(int j=0; j<=c0; j++)
        {
            for(int p=0; p<=i; p++)
            {
                for(int q=0; q<=j; q++)
                {
                    if((2*p+q)%3==0) f[i][j] = (f[i][j]+1ll*c[i][p]*c[j][q]%mod)%mod;
                }
            }
        }
    }
    int ans = 0;
    n %= (mod-1);
    for(int i=c1; i>=0; i--)
    {
        for(int j=c0; j>=0; j--)
        {
            int opt = (c1 - i + c0 - j) & 1 ? -1 : 1;
            ans = (ans+1ll*opt*c[c1][i]*c[c0][j]%mod*qpow(f[i][j], n)%mod)%mod;
        }
    }
    ans = (ans%mod+mod)%mod;
    printf("%d\n", ans);

    return 0;
}

 

C. 最短路径

感觉和期望的线性性没什么关系,正解是直接按定义求的:数学期望就是所有方案的平均值(总和除以方案数)!

先不考虑要减掉的最长链,按每条边贡献的次数来累加答案:对每条边,讨论它两侧分别选了多少个关键点。最后再枚举合法的最长链:一个点能成为链的端点,说明所选的点里没有深度比它更大的,所以把关键点按深度从大到小排序后依次枚举;去掉深度更大的点之后,剩下的点都在可选范围内,在其中选离它最远的点作为另一个端点,其余的点任选,方案数就是组合数。

40 pts
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
const int maxn = 2003;
const ll mod = 998244353;
const ll inv2 = 499122177;

int n, m, sp[maxn], k, lca;
//vector<int> v1;
int dep[maxn], fa[maxn], siz[maxn], top[maxn], son[maxn];
ll ans;
//bool vis[maxn];

// Fast integer reader for stdin.
// Skips every character that is not a decimal digit, then parses the
// following run of digits. If a '-' occurs anywhere in the skipped prefix
// the result is negated.
// Returns the parsed value, or 0 when the input ends before any digit.
inline int read()
{
    int x = 0, f = 1;
    // getchar() returns an int so that EOF is distinguishable from every
    // real character. Storing it in a char and never testing for EOF made
    // the skip loop below spin forever once the input was exhausted.
    int ch = getchar();
    while(ch != EOF && (ch < '0' || ch > '9'))
    {
        if(ch == '-')
        {
            f = -1;
        }
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}

// One directed edge of a linked adjacency list: `to` is the endpoint and
// `next` is the index of the following edge of the same vertex (0 ends it).
struct node
{
    int next;
    int to;
};
// Edge pool `a` with per-vertex list heads `head` and edge counter `len`
// (filled by add); a second pool `e` with heads `h` and counter `tot`.
node a[maxn<<1], e[maxn];
int head[maxn], len;
int h[maxn], tot;

// Prepends the directed edge x -> y to x's list in the pool `a`.
void add(int x, int y)
{
    ++len;
    a[len].to = y;
    a[len].next = head[x];   // old first edge becomes the successor
    head[x] = len;
}

void add_c(int x, int y)
{
    e[++tot].to = y; e[tot].next = head[x];
    head[x] = len;
}

// First pass of heavy-light decomposition, rooted at the initial call.
// Records for u its depth, parent `fat` and subtree size, and picks as
// son[u] the child with the largest subtree (0 if u has no child).
// dep[] doubles as the visited marker, so `depth` must start non-zero.
void find_heavy_edge(int u, int fat, int depth)
{
    dep[u] = depth;
    fa[u] = fat;
    siz[u] = 1;
    son[u] = 0;
    int heaviest = 0;
    for(int i=head[u]; i; i=a[i].next)
    {
        const int v = a[i].to;
        if(dep[v] != 0) continue;          // already visited
        find_heavy_edge(v, u, depth+1);
        siz[u] += siz[v];
        if(siz[v] <= heaviest) continue;
        heaviest = siz[v];
        son[u] = v;
    }
}

// Second pass of heavy-light decomposition: assigns top[u] = ancestor,
// extends the current chain through son[u], and starts a new chain at
// every other child.
void connect_heavy_edge(int u, int ancestor)
{
    top[u] = ancestor;
    if(son[u] != 0) connect_heavy_edge(son[u], ancestor);
    for(int i=head[u]; i; i=a[i].next)
    {
        const int v = a[i].to;
        // Skip the parent and the heavy child, which is already handled.
        if(v != fa[u] && v != son[u]) connect_heavy_edge(v, v);
    }
}

// Lowest common ancestor of x and y using the chain tops computed by
// find_heavy_edge / connect_heavy_edge.
int LCA(int x, int y)
{
    while(top[x] != top[y])
    {
        // Lift whichever vertex sits on the chain with the deeper top.
        if(dep[top[x]] >= dep[top[y]]) x = fa[top[x]];
        else y = fa[top[y]];
    }
    // Both are on one chain now; the shallower vertex is the answer.
    return dep[x] > dep[y] ? y : x;
}

// Binary exponentiation: computes a^b modulo `mod`.
// Intended for b >= 0.
ll qpow(ll a, ll b)
{
    ll result = 1;
    for(; b; b >>= 1, a = a * a % mod)
    {
        if(b & 1) result = result * a % mod;
    }
    return result;
}

// dp[u]: accumulated by dfs below; vis[x]: set in main for every key vertex.
ll dp[maxn];
bool vis[maxn];

// Post-order walk from u (parent `fa`). Each child v that is a key vertex
// or already has a non-zero dp adds dp[v] + 2 to dp[u], i.e. two units for
// the edge u-v plus whatever v's subtree accumulated.
void dfs(int u, int fa)
{
    for(int i=head[u]; i; i=a[i].next)
    {
        const int v = a[i].to;
        if(v == fa) continue;
        dfs(v, u);
        if(!dp[v] && !vis[v]) continue;    // nothing relevant below v
        dp[u] = (dp[u] + dp[v] + 2) % mod;
    }
}

int main()
{
    freopen("tree.in", "r", stdin);
    freopen("tree.out", "w", stdout);
    
    n = read(); m = read(); k = read();
    for(int i=1; i<=m; i++)
    {
        int x = read(); vis[x] = 1;
        sp[i] = x;
    }
    for(int i=1; i<n; i++)
    {
        int x = read(), y = read();
        add(x, y); add(y, x);
    }
    find_heavy_edge(1, 1, 1);
    connect_heavy_edge(1, 1);
    if(k == 2)
    {
        for(int i=1; i<m; i++)
        {
            for(int j=i+1; j<=m; j++)
            {
                int lca = LCA(sp[i], sp[j]);
                ans = (ans+dep[sp[i]]+dep[sp[j]]-dep[lca]-dep[lca])%mod;
            }
        }
        ans = ans * qpow(m*(m-1)%mod*inv2%mod, mod-2) % mod;
        printf("%lld\n", ans);
        exit(0);
    }
    if(k == m)
    {
        dfs(sp[1], 0);
        ll Max = 0;
        for(int i=1; i<m; i++)
        {
            for(int j=i+1; j<=m; j++)
            {
                int lca = LCA(sp[i], sp[j]);
                Max = max(Max, (ll)dep[sp[i]]+dep[sp[j]]-dep[lca]-dep[lca]);
            }
        }
        ans = (dp[sp[1]] - Max + mod) % mod;
        printf("%lld\n", ans);
    }

    return 0;
}
code
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
const int maxn = 2003;
const ll mod = 998244353;
const ll inv2 = 499122177;

int dist[maxn], si[maxn], ans, fac[maxn], inv[maxn], key[maxn], now;
int dis[maxn][maxn], iskey[maxn], n, m, k;

// Fast integer reader for stdin.
// Skips every character that is not a decimal digit, then parses the
// following run of digits. If a '-' occurs anywhere in the skipped prefix
// the result is negated.
// Returns the parsed value, or 0 when the input ends before any digit.
inline int read()
{
    int x = 0, f = 1;
    // Keep the getchar() result in an int and stop on EOF: with a char and
    // no EOF test the skip loop never terminated on exhausted input.
    int ch = getchar();
    while(ch != EOF && (ch < '0' || ch > '9'))
    {
        if(ch == '-')
        {
            f = -1;
        }
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}

// x = (x + y) reduced by at most one subtraction of mod; callers pass
// operands already in [0, mod].
void add(int &x, int y)
{
    x += y;
    if(x >= mod) x -= mod;
}

// Square-and-multiply: returns a^b modulo `mod`.
// Intended for b >= 0.
ll qpow(ll a, ll b)
{
    ll acc = 1;
    while(b != 0)
    {
        if((b & 1) != 0)
        {
            acc = acc * a % mod;
        }
        a = a * a % mod;
        b >>= 1;
    }
    return acc;
}

// Binomial coefficient C(n, m) modulo mod from the precomputed factorial
// table `fac` and inverse-factorial table `inv`.
// Returns 0 for any argument outside 0 <= m <= n.
int C(int n, int m)
{
    const bool in_range = n >= 0 && m >= 0 && m <= n;
    if(!in_range) return 0;
    long long value = fac[n];
    value = value * inv[m] % mod;
    value = value * inv[n-m] % mod;
    return value;
}

// One directed edge of a linked adjacency list: `to` is the endpoint and
// `next` is the index of the following edge of the same vertex (0 ends it).
struct node
{
    int next;
    int to;
};
// Edge pool, per-vertex list heads and number of edges stored so far.
node a[maxn<<1];
int head[maxn], len;

// Prepends the directed edge x -> y to x's adjacency list.
void link(int x, int y)
{
    ++len;
    a[len].to = y;
    a[len].next = head[x];   // old first edge becomes the successor
    head[x] = len;
}

// Fills dis[now][*] with edge counts from the source `now` by walking the
// tree from x (parent `fa`); dis[now][x] must already be set.
void get_dis(int x, int fa)
{
    for(int i=head[x]; i; i=a[i].next)
    {
        const int v = a[i].to;
        if(v != fa)
        {
            dis[now][v] = dis[now][x] + 1;
            get_dis(v, x);
        }
    }
}

// Computes si[x] = number of key vertices in x's subtree and, for every
// tree edge x-v, adds to `ans` twice the number of k-subsets of key
// vertices that have at least one member on each side of that edge.
void dfs(int x, int fa)
{
    si[x] = iskey[x];
    for(int i=head[x]; i; i=a[i].next)
    {
        const int v = a[i].to;
        if(v == fa) continue;
        dfs(v, x);
        si[x] += si[v];

        const int below = si[v];          // key vertices under the edge
        const int above = m - si[v];      // key vertices on the other side
        const int limit = min(below, k-1);
        // j = how many of the k chosen vertices lie below the edge.
        for(int j=1; j<=limit; j++)
        {
            const long long both_sides = 1ll*C(below, j)*C(above, k-j)*2%mod;
            add(ans, both_sides);
        }
    }
}

// Sort order for key vertices: larger distance from vertex 1 comes first.
bool cmp(int x, int y)
{
    return dis[1][y] < dis[1][x];
}

// Subtracts from `ans` the chain term of every k-subset. Expects `key`
// to be sorted by cmp. For each key[i], only the later keys are candidates;
// with their distances sorted ascending, the j-th one is the farthest chosen
// vertex in C(j-1, k-2) subsets (the other k-2 come from the j-1 closer ones).
void del()
{
    for(int i=1; i<=m; i++)
    {
        const int x = key[i];
        int p = 0;
        for(int j=i+1; j<=m; j++)
        {
            dist[++p] = dis[x][key[j]];
        }
        sort(dist+1, dist+p+1);
        for(int j=p; j>=1; j--)
        {
            const long long removed = 1ll*C(j-1, k-2)*dist[j]%mod;
            add(ans, mod-removed);   // modular subtraction
        }
    }
}

int main()
{
    freopen("tree.in", "r", stdin);
    freopen("tree.out", "w", stdout);
    
    n = read(); m = read(); k = read();
    for(int i=1; i<=m; i++) iskey[key[i]=read()] = true;
    for(int i=1; i<n; i++)
    {
        int u = read(), v = read();
        link(u, v); link(v, u);
    }
    for(int i=1; i<=n; i++) {now = i; dis[i][i] = 0; get_dis(i, 0);}
    fac[0] = inv[0] = 1; for(int i=1; i<=n; i++) fac[i] = 1ll * fac[i-1] * i % mod;
    inv[n] = qpow(fac[n], mod-2); for(int i=n-1; i>=1; i--) inv[i] = 1ll*inv[i+1]*(i+1)%mod;
    dfs(1, 0);
    sort(key+1, key+1+m, cmp);
    del();
    ans = 1ll * ans * qpow(C(m, k), mod-2) % mod;
    printf("%d\n", ans);

    return 0;
}

 

D. 最短路

跟2的乘方的性质关系不大,也不是想告诉我们取模之后也有比较大小的方法,就是个高精度。

至于搞个可持久化的树为什么可以优化。。我并不知道。

posted @ 2022-10-21 07:51  Catherine_leah  阅读(25)  评论(0)    收藏  举报
/* */