函数调用
首先类似线段树模板 2,如果我们直接模拟一个函数,就可以得到一个长度为 \(n\) 的数组表示加法标记,以及一个数表示乘法标记。
正着做是困难的,我们考虑倒过来,假设一个函数后面的函数的乘法标记是 \(mul\), 那么这个函数得到的所有加法标记都会 \(\times mul\), 这个东西等价于函数被调用了 \(mul\) 次,如果这个函数是会调用其它函数的,那么它调用的函数的调用次数要加上它的调用次数 \(\times mul\), 根据这个,先记忆化求出每个点的乘法标记,再拓扑排序即可。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
// typedef __int128 i128;
typedef pair<int, int> pii;
const int N = 2e5 + 10, mod = 998244353;
template<typename T>
void dbg(const T &t) { cout << t << endl; }
template<typename Type, typename... Types>
void dbg(const Type& arg, const Types&... args) {
#ifdef ONLINE_JUDGE
return ;
#endif
cout << arg << ' ';
dbg(args...);
}
int n, m, Q, typ[N], in[N], f[N];
ll a[N], p[N], val[N], vis[N], t[N], mul[N], add[N];
vector<int>e[N];
inline ll dfs(int u) {
if (mul[u] != -1) return mul[u];
if (typ[u] == 1) return mul[u] = 1;
if (typ[u] == 2) return mul[u] = val[u];
mul[u] = 1;
for (auto v : e[u]) mul[u] = mul[u] * dfs(v) % mod;
return mul[u];
}
int main() {
// freopen("data.in", "r", stdin);
// freopen("data.out", "w", stdout);
ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
cin >> n;
for (int i = 1; i <= n; i++) {
cin >> a[i];
}
memset(mul, 255, sizeof mul);
cin >> m;
for (int i = 1; i <= m; i++) {
cin >> typ[i];
if (typ[i] == 1) {
cin >> p[i] >> val[i];
} else if (typ[i] == 2) {
cin >> val[i];
} else {
int k;
cin >> k;
while (k--) {
int x;
cin >> x;
++in[x];
e[i].push_back(x);
}
reverse(e[i].begin(), e[i].end());
}
}
cin >> Q;
for (int i = 1; i <= Q; i++) cin >> f[i], e[0].push_back(f[i]), ++in[f[i]];
reverse(e[0].begin(), e[0].end());
for (int i = 0; i <= m; i++) if (mul[i] == -1) dfs(i);
queue<int>q;
q.push(0);
t[0] = 1;
for (int i = 1; i <= m; i++) if (!in[i]) q.push(i);
while (!q.empty()) {
int u = q.front(); q.pop();
ll fac = 1;
for (auto v : e[u]) {
t[v] = (t[v] + t[u] * fac) % mod;
// dbg("###", u, v, fac, t[v]);
fac = fac * mul[v] % mod;
if (!--in[v]) q.push(v);
}
}
for (int i = 1; i <= n; i++) {
a[i] = a[i] * mul[0] % mod;
}
for (int i = 1; i <= m; i++) {
if (typ[i] == 1) a[p[i]] = (a[p[i]] + val[i] * t[i]) % mod;
}
for (int i = 1; i <= n; i++) {
cout << a[i] << " \n"[i == n];
}
return 0;
}

浙公网安备 33010602011771号