[csp2020] 函数调用
前言
考试时想着用每个点map维护加法操作,然后启发式合并乱搞,然而复杂度有两个log,而且这是个DAG...
最后打了暴力滚粗
题目
https://www.luogu.com.cn/problem/P7077
题解
首先我们可以建一个0号点,向序列中的所有函数结点连边,这样就变成了调用一次0函数后的状态
考虑把加法和乘法分开来,
假设原来是$a_1+2 ,\ all*3$这样的操作,可以变为$all*3 ,\ a_1+6$
即,如果一个加法操作后面有乘法,相当于这个加法会被执行乘数次
即对于序列每一项,它的最终结果是$a_i*x+add_i$,x是全局乘上的数
考虑如何计算每一项加法操作的执行次数
因为只有后面的乘法会影响到当前的加法,所以倒序遍历子节点
对于这样一个图

i的执行次数为【fa的执行次数】*【执行j后全局会多乘上多少】*【执行k后全局会多乘上多少】
我们可以在遍历子节点时维护一个变量multiply,表示已遍历的子节点对全局乘法的贡献
每遍历完一个,就将multiply乘上它对全局乘法的贡献,这个可以用一个dfs预处理出来
另外,由于这是个DAG,所以一个函数节点必须入读为0时才能得出最终计算次数,所以我们可以用拓扑排序这整张图
另外,对于入读为0,但不会调用的,也要将其放入初始队列
否则有些点入边消不完,导致无法进入。

另外,注意乘数和加数有可能为0.
代码
#include<iostream>
#include<cstdio>
#include<vector>
#include<queue>
using namespace std;
#define N 1000010
#define int long long
#define mod 998244353
int type[N],p1[N],p2[N],val[N],n,deg[N],m,vis[N];
int add[N]/*每一个位置的增加量*/,mul[N]/*调用这个函数后全局乘上的值*/,times[N]/*每个函数的调用次数*/;
vector<int> vec[N];
void dfs(int id)//计算mul
{
mul[id]=1;
vis[id]=1;
//cout<<id<<endl;
if(type[id]!=3)
{
if(type[id]==2) mul[id]=p1[id];
return;
}
for(int i=0;i<vec[id].size();i++)
{
int to=vec[id][i];
if(!vis[to]) dfs(to);
mul[id]*=mul[to],mul[id]%=mod;
}
}
void topic()
{
queue<int> q;
times[0]=1;
for(int i=0;i<=m;i++) if(!deg[i]) q.push(i);
while(!q.empty())
{
int now=q.front();
q.pop();
int multiply=times[now];//下一个调用的函数的调用次数
for(int i=vec[now].size()-1;i>=0;i--)
{
int to=vec[now][i];
if(--deg[to]==0) q.push(to);
times[to]+=multiply,times[to]%=mod;
if(type[to]==1) add[p1[to]]+=p2[to]*multiply%mod,add[p1[to]]%=mod;
multiply*=mul[to],multiply%=mod;
}
}
}
signed main()
{
//freopen("call.in","r",stdin);
//freopen("call.out","w",stdout);
cin>>n;
for(int i=1;i<=n;i++)
{
scanf("%lld",&val[i]);
}
cin>>m;
for(int i=1;i<=m;i++)
{
int t;
scanf("%lld",&t);
type[i]=t;
if(t==1)
{
scanf("%lld%lld",&p1[i],&p2[i]);
}
else if(t==2)
{
scanf("%lld",&p1[i]);
}
else
{
int c;
scanf("%lld",&c);
for(int j=1;j<=c;j++)
{
int a;
scanf("%lld",&a);
vec[i].push_back(a);
deg[a]++;
}
}
}
int q;
cin>>q;
for(int i=1;i<=q;i++)
{
int a;
scanf("%lld",&a);
vec[0].push_back(a);
deg[a]++;
}
type[0]=3;
dfs(0);
topic();
for(int i=1;i<=n;i++) printf("%lld ",(val[i]*mul[0]%mod+add[i])%mod);
}
看都看了,顺手点个推荐呗 :)

浙公网安备 33010602011771号