函数调用

题目

https://www.luogu.com.cn/problem/P7077
这题给我好搞。不过思路精巧。

\(Solution\)

首先我们考虑一个比较简单的问题。
称函数1是加法,函数2是乘法,函数3是调用。

首先,要转换思路。原先暴力模拟,考虑函数顺序,\(O(Qn)\), 必炸。我们考虑统计每个函数对答案的贡献。这样是\(O(n)\)复杂度的。其次,我们在考虑贡献时,也不应该直接修改上去。如果对于一个函数2,这个做法瞬间就退化了。类比线段树的懒标记,我们想到,给他们打标记,最后一并计算就行了。记住一个性质:乘可以看成多次加

我们从简单问题入手:
在不考虑函数3的情况下,对于函数2,我们自然可以维护总乘积;对于函数1,我们可以维护一个\(mul[i]\),代表当前函数被乘的次数,注意这里,这个函数被乘的次数等于后面函数的\(mul\)之积。如果前后存在多个\(f[i]\), 那么计算也是一样的。这样我们最后把每个\(a[p]\)加上\(mul[i] * v\)就可以了。

接下来考虑函数3。做法搬过来。我们变一下\(mul\)定义:调用此函数时,全局乘了多少次。也就是后面的mul不用乘了。
先把\(mul\)维护好(下面要用,否则\(sum\)维护不出)。具体来说,我们对于每一个调用函数,按照拓扑逆序递推:

\[mul[u] = \prod \limits_{i \in son} mul[i] \]

我们发现,对于函数调用的函数(姑且称为子函数吧),他们的贡献没有被计算。我们这时候只维护\(mul\)是无法做的,因为子函数无法处理了,子函数会被父函数调用的次数影响。这启发我们维护\(sum[i]\),代表\(f[i]\)被调用的次数。具体来说,我们对于每个父函数(其实就是特判,他们永远只被调用一次)

\[sum[i] = \prod \limits_{j = i + 1}^Q mul[j] \]

对于每个子函数:

\[sum[u] = sum[father] * \prod \limits_{在u右边} mul[brother] \]

其实就是 父函数调用多少次 乘上后面兄弟乘了多少次
到现在为止应该能理解为什么维护\(mul\)了吧。因为后面兄弟调用的次数包括父节点调用次数,不能计算啦~,乘起来答案就不对了。

不要忘记要把相同的函数加起来~

那么我们最后就可以统计答案了~
对于函数1,维护总乘数,直接乘。对于函数2,由于可能做子函数,于是我们拿\(sum\)直接\(w[p] += v * sum[i]\)

这样就把答案统计出来了。

总结:按拓扑逆序,先求单点挂着的\(mul\), 反向枚举sum[i], 算父函数后面乘积,在拓扑正序求出sum, 最后统计答案。


更新:

又想了想,之前没有说的多么清楚,其实可以这么想:类比不带函数3,我们同样考虑维护每个点对答案的贡献\(sum\)。可以分为两类。第一类是大函数,就是最开始调用的函数,第二类是子函数。
对于大函数,我们发现其实需要知道一个函数对前面函数的贡献,也就是让所有数乘了多少倍,这也就是\(mul\),于是大函数类似不带函数3就解决了。
对于小函数,我们发现其实根据当前维护的\(sum\)是求不出\(sum[u]\)的,因为后面兄弟对这个点的贡献不知道是多少,所以我们就需要维护这个值,也就是\(mul\),表示调用函数时给全部数乘了多少。我们发现这个非常好维护。拓扑逆序递推 儿子乘起来就行了。这样我们就能求解了。

代码:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

const int N = 100010, M = 2000010, mod = 998244353;

typedef long long LL;
//sum分为大函数贡献和子函数 子函数贡献需要兄弟对孩子的贡献引出mul

struct Function
{
    int t, p, v, sum, mul;
}f[N];
int n, m, Q;
int h[N], e[M], ne[M], idx;
int w[N];
int q[N], din[N];
int g[N];

void add(int a, int b)
{
    e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}

void topsort()
{
    int hh = 1, tt = 0;
    for (int i = 1; i <= m; i ++ )
        if (!din[i]) q[ ++ tt] = i;
    while (hh <= tt)
    {
        int t = q[hh ++ ];
        for (int i = h[t]; ~i; i = ne[i])
            if ( -- din[e[i]] == 0) q[ ++ tt] = e[i];
    }
}

void get_mul() //调用函数i全局乘几次
{
    for (int k = m; k; k -- )
    {
        int t = q[k];
        for (int i = h[t]; ~i; i = ne[i])
        {
            int j = e[i];
            f[t].mul = (LL)f[t].mul * f[j].mul % mod; 
        }
    }
}

void get_sum()
{
    for (int t = 1; t <= m; t ++ )
    {
        int k = q[t];
        int sum = f[k].sum;
        for (int i = h[k]; ~i; i = ne[i]) //前向星天然头插 最右侧最先遍历
        {
            int j = e[i];
            f[j].sum = ((LL)f[j].sum + sum) % mod; //相同函数不能相乘 是不同的方式增加的
            sum = (LL)sum * f[j].mul % mod;
        }
    }
}

int main()
{
    scanf("%d", &n);
    for (int i = 1; i <= n; i ++ ) scanf("%d", &w[i]);
    
    memset(h, -1, sizeof h);
    scanf("%d", &m);
    for (int i = 1; i <= m; i ++ )
    {
        int opt, p = 0, v = 0;
        scanf("%d", &opt);
        if (opt == 1) scanf("%d%d", &p, &v);
        else if (opt == 2) scanf("%d", &v);
        else
        {
            int cnt = 0;
            scanf("%d", &cnt);
            while (cnt -- )
            {
                int x;
                scanf("%d", &x);
                add(i, x);
                din[x] ++ ;
            }
        }
        
        f[i] = {opt, p, v, 0, 0};
    }
    
    //mul初始化
    for (int i = 1; i <= m; i ++ )
        if (f[i].t == 2) f[i].mul = f[i].v;
        else f[i].mul = 1;
    
    topsort();
    //拓扑逆序求mul 可以不用知道操作序列
    get_mul();
    
    scanf("%d", &Q);
    for (int i = 1; i <= Q; i ++ ) scanf("%d", &g[i]);
    
    //求sum 大函数 递推起点
    int sum = 1;
    for (int i = Q; i; i -- )
    {
        f[g[i]].sum = (f[g[i]].sum + sum) % mod;
        sum = (LL)sum * f[g[i]].mul % mod;
    }
    get_sum(); //正序求sum 子函数
    
    //统计答案
    for (int i = 1; i <= n; i ++ )
        w[i] = (LL)w[i] * sum % mod;
    
    for (int i = 1; i <= m; i ++ )
        if (f[i].t == 1) w[f[i].p] = (w[f[i].p] + (LL)f[i].v * f[i].sum) % mod;
    
    for (int i = 1; i <= n; i ++ )
        printf("%d ", w[i]);
    return 0;
}
posted @ 2024-03-02 16:26  琴忆庭  阅读(43)  评论(0编辑  收藏  举报