先考虑 n 比较小的时候怎么做。
非常显然地,按照线段树建树,枚举每个点作为 LCA 的答案。假设当前节点 u,对应区间 [l,r],区间长度 len=r−l+1,那么对答案的贡献为 u×(2⌈2len⌉−1)×(2⌊2len⌋−1),即在左右子树的叶子节点中各选若干,且保证左右两侧都至少选了一个点。答案即为每个点贡献之和。
上述的朴素做法复杂度为 O(n)。通常在线段树或者完全二叉树之类的比较优秀的树形结构上求一些贡献,但不能完全建出树时,常用方法是记忆化。例如在线段树上有经典结论,每个点对应的不同的 len=r−l+1 的数量量级为 O(logn)。而上文的做法中确实需要用 len 计算贡献。所以不妨考虑记忆化搜索。
但现在问题是,上文的贡献中还有 u。而 u∈[1,n],不同数量很大。于是我们想把贡献中的 u 分出来。我们猜测以 (u,len) 为根的子树贡献是一个关于 u 的一次函数,其中 len 是定值。这样我们就可以记忆化的过程中维护每个 len 对应的一次函数即可。
现在的问题是,为什么是关于 u 的一次函数。考虑归纳证明:
令 f(u,len) 为根节点为 u,根的区间 [l,r],区间长度 len=r−l+1 时的整棵子树答案。
首先,len=1 时,必有 f(u,1)=u。这很显然是一个 b=0 的一次函数。
考虑 len>1。我们假设 u 的左右儿子的 f(u,len) 都是关于 u 的一次函数。考虑 f(u,len) 的转移形式:f(u,len)=f(lson,⌈2len⌉)+f(rson,⌊2len⌋)+u×(2⌈2len⌉−1)×(2⌊2len⌋−1)。
进一步,由于左右儿子的 f 都是一次函数,且 lson=2u,rson=2u+1。
所以 f(u,len)=k⌈2len⌉×2u+b⌈2len⌉+k⌊2len⌋×(2u+1)+b⌊2len⌋+u×(2⌈2len⌉−1)×(2⌊2len⌋−1)。
其中 k 是斜率,b 是截距。
发现这个式子关于 u 是一次的,把式子整理成关于 u 的一次函数就可以发现 klen 和 blen 关于左右儿子的递推式。
具体地:klen=2×(k⌈2len⌉+k⌊2len⌋)+(2⌈2len⌉−1)×(2⌊2len⌋−1),blen=b⌈2len⌉+b⌊2len⌋+k⌊2len⌋。读者自证不难。
使用 map 记忆化即可。复杂度 O(Tlog2n)。
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#include <queue>
#include <map>
using namespace std;
const long long MOD = 998244353ll;
using ll = long long;
int t;
ll n;
map<ll, pair<ll, ll>> mp;
ll qpow(ll a, ll b)
{
ll res = 1ll, base = a;
while (b)
{
if (b & 1ll) res = res * base % MOD;
base = base * base % MOD;
b >>= 1ll;
}
return res;
}
pair<ll, ll> query(ll u, ll len)
{
if (mp.count(len)) return mp[len];
if (len == 1) return (mp[len] = make_pair(1ll, 0ll));
if (len == 0) return (mp[len] = make_pair(0ll, 0ll));
ll x = (len + 1) / 2ll, y = len / 2ll;
auto lft = query(u << 1ll, x), rit = query(u << 1ll | 1ll, y);
auto res = make_pair(((2ll * lft.first % MOD + 2ll * rit.first % MOD) % MOD + ((qpow(2ll, x) - 1) * (qpow(2ll, y) - 1) % MOD)) % MOD, (lft.second + rit.second + rit.first) % MOD);
return mp[len] = res;
}
int main()
{
ios::sync_with_stdio(0), cin.tie(0);
cin >> t;
while (t--)
{
cin >> n;
mp.clear();
auto g = query(1, n);
cout << (g.first + g.second) % MOD << "\n";
}
return 0;
}