P1637 三元上升子序列 权值线段树

解题思路

这段代码使用权值线段树高效统计三元上升子序列的数量。主要思路是:

  1. 离散化处理原始数据,将大范围的数值映射到紧凑的区间

  2. 两次遍历序列:

    • 从左到右计算每个元素左侧比它小的元素个数(存储在ls数组)

    • 从右到左计算每个元素右侧比它大的元素个数(存储在rs数组)

  3. 统计结果:对于每个中间元素,它能组成的三元组数量等于左侧较小元素个数乘以右侧较大元素个数

代码注释

#include<bits/stdc++.h>
#define lc rt << 1      // 左子节点宏定义
#define rc rt << 1 | 1   // 右子节点宏定义
#define lson lc,l,mid    // 左子树区间宏定义
#define rson rc,mid + 1,r // 右子树区间宏定义
#define ll long long     // 长整型别名
using namespace std;
const int N = 1e5 + 10;  // 数组大小常量

struct node{
    int sum;             // 线段树节点,记录区间内数字出现次数
};

node t[N << 2];          // 线段树数组(开4倍空间)
int n,a[N],b[N];         // n:数字个数 a:原始数组 b:离散化辅助数组
ll ls[N],rs[N];          // ls:左侧比a[i]小的个数 rs:右侧比a[i]大的个数

// 线段树向上更新函数
void pushup(int rt)
{
    t[rt].sum = t[lc].sum + t[rc].sum; // 当前节点值为左右子节点值之和
}

// 线段树单点更新函数
void change(int rt,int l,int r,int x)
{
    if(x < l || r < x) return;    // 超出当前区间范围则返回
    if(l == r){                   // 找到目标叶子节点
        t[rt].sum++;              // 该值出现次数+1
        return;
    }
    int mid = (l + r) >> 1;       // 计算中点
    change(lson,x);               // 递归更新左子树
    change(rson,x);               // 递归更新右子树
    pushup(rt);                   // 更新当前节点值
}

// 线段树区间查询函数
int query(int rt,int l,int r,int x,int y)
{
    if(r < x || y < l) return 0;  // 查询区间与当前区间无交集
    if(x <= l && r <= y) return t[rt].sum; // 当前区间完全包含在查询区间内
    int mid = (l + r) >> 1;       // 计算中点
    return query(lson,x,y) + query(rson,x,y); // 返回左右子树查询结果之和
}

int main()
{
    cin >> n;
    // 读取输入数据并准备离散化
    for(int i = 1; i <= n; i++)
    {
        cin >> a[i];
        b[i] = a[i];              // 复制到b数组用于离散化
    }
    
    // 离散化处理
    sort(b + 1,b + 1 + n);        // 排序
    unique(b + 1,b + 1 + n);      // 去重
    
    // 将原始数据映射为离散化后的值
    for(int i = 1; i <= n; i++)
    {
        int x = lower_bound(b + 1,b + 1 + n,a[i]) - b; // 查找离散化后的值
        a[i] = x;                 // 替换为离散化值
    }
    
    // 第一次遍历:从左到右计算每个元素左侧比它小的元素个数
    for(int i = 1; i <= n; i++)
    {
        change(1,1,N,a[i]);       // 将当前元素插入线段树
        ls[i] = query(1,1,N,1,a[i] - 1); // 查询比当前元素小的元素个数
    }
    
    // 清空线段树准备第二次遍历
    memset(t,0,sizeof(t));
    
    // 第二次遍历:从右到左计算每个元素右侧比它大的元素个数
    for(int i = n; i >= 1; i--)
    {
        change(1,1,N,a[i]);       // 将当前元素插入线段树
        rs[i] = query(1,1,N,a[i] + 1,N); // 查询比当前元素大的元素个数
    }
    
    // 计算最终答案:每个元素能组成的三元组数量=左侧较小数×右侧较大数
    ll ans = 0;
    for(int i = 1; i <= n; i++)
        ans += (ll)ls[i] * rs[i]; // 累加所有可能的三元组
    
    cout << ans;
    return 0;
}

 

posted @ 2025-06-13 15:26  CRt0729  阅读(12)  评论(0)    收藏  举报