R - Weak Pair HDU - 5877 离散化+权值线段树+dfs序 区间种类数

R - Weak Pair

 HDU - 5877 

离散化+权值线段树

这个题目的初步想法,首先用dfs序建一颗树,然后判断对于每一个节点进行遍历,判断他的子节点和他相乘是不是小于等于k,

这么暴力的算法很自然的超时了。

然后上网搜了一下题解,感觉想的很巧妙。

就是我们要搜 子节点和父节点的乘积小于一个定值的对数。

一般求对数,有逆序对,都是把满足的放进去,到时候直接求答案就可以了。这个题目也很类似,但是还是有很大的区别的。

这个题目就是先把所有a[i] 和 k/a[i] 都放进一个数组,离散化,这一步是因为要直接求值,就是要把这个值放进线段树的这个离散化后的位置,权值为1 .

这个满足了a[i]*a[j]<=k 的要求,然后就是他们的关系必须是子节点和父节点。

这一点可以用dfs序来实现,先把父节点放进去,然后之后的子节点都可以查找这个节点,最后这个父节点的所有子节点都查找完之后就是把这个父节点弹出。

 

以上做法都是上网看题解的,我觉得还是没有那么难想了,这种差不多就是树上要满足是父节点子节点的关系都是可以用dfs来满足的。

其次就是弹出操作没有那么好想,最后就是放入线段树直接查找的这种建一棵权值线段树思想。

 

这个知道怎么写之后就很好写了,注意细节

#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <iostream>
#include <queue>
#include <string>
#include <cmath>
#include <vector>
#include <map>
#define inf 0x3f3f3f3f
#define inf64 0x3f3f3f3f3f3f3f3f
using namespace std;
const int maxn = 2e5 + 10;
typedef long long ll;
ll sum[maxn * 8], a[maxn], b[maxn], ans;
vector<int>G[maxn];
int len, f[maxn];
bool vis[maxn];
ll n, k;

void update(int id, int l, int r, int pos, int val) {
    if (l == r) {
        sum[id] += val;
        return;
    }
    int mid = (l + r) >> 1;
    if (pos <= mid) update(id << 1, l, mid, pos, val);
    else update(id << 1 | 1, mid + 1, r, pos, val);
    sum[id] = sum[id << 1] + sum[id << 1 | 1];
}

ll query(int id, int l, int r, int x, int y) {
    if (x <= l && y >= r) return sum[id];
    int mid = (l + r) >> 1;
    ll ans = 0;
    if (x <= mid) ans += query(id << 1, l, mid, x, y);
    if (y > mid) ans += query(id << 1 | 1, mid + 1, r, x, y);
    return ans;
}

void dfs(int u) {
    vis[u] = 1;
    int t2 = lower_bound(b + 1, b + 1 + len, a[u]) - b;
    update(1, 1, len, t2, 1);
    for (int i = 0; i < G[u].size(); i++) {
        int v = G[u][i];
        if (vis[v]) continue;
        int t1 = lower_bound(b + 1, b + 1 + len, k / a[v]) - b;
        ans += query(1, 1, len, 1, t1);
        dfs(v);
    }
    update(1, 1, len, t2, -1);
}

int main() {
    int t;
    scanf("%d", &t);
    while (t--) {
        ans = 0;
        scanf("%lld%lld", &n, &k);
        memset(vis, 0, sizeof(vis));
        for (int i = 1; i <= n; i++) scanf("%lld", &a[i]), b[i] = a[i], G[i].clear();
        for (int i = 1; i <= n; i++) b[i + n] = k / a[i];
        sort(b + 1, b + 1 + 2 * n);
        len = unique(b + 1, b + 1 + 2 * n) - b - 1;
        memset(sum, 0, sizeof(sum));
        for (int i = 1; i < n; i++) {
            int u, v;
            scanf("%d%d", &u, &v);
            G[u].push_back(v);
            G[v].push_back(u);
            vis[v] = 1;
        }
        int root = 1;
        for (int i = 1; i <= n; i++) {
            if (vis[i] == 0) {
                root = i;
                break;
            }
        }
        memset(vis, 0, sizeof(vis));
        dfs(root);
        printf("%lld\n", ans);
    }
    return 0;
}
View Code
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <iostream>
#include <queue>
#include <string>
#include <cmath>
#include <vector>
#include <map>
#define inf 0x3f3f3f3f
#define inf64 0x3f3f3f3f3f3f3f3f
using namespace std;
const int maxn = 2e5 + 10;
typedef long long ll;
ll sum[maxn*8], a[maxn], b[maxn], ans;
vector<int>G[maxn];
int len, f[maxn];
bool vis[maxn];
ll n, k;
void build(int id,int l,int r)
{
    sum[id] = 0;
    if (l == r) return;
    int mid = (l + r) >> 1;
    build(id << 1, l, mid);
    build(id << 1 | 1, mid + 1, r);
}

void update(int id,int l,int r,int pos,int val)
{
    if(l==r)
    {
        sum[id] += val;
        return;
    }
    int mid = (l + r) >> 1;
    if(pos<=mid) update(id << 1, l, mid, pos, val);
    else update(id << 1 | 1, mid + 1, r, pos, val);
    sum[id] = sum[id << 1] + sum[id << 1 | 1];
}

ll query(int id,int l,int r,int x,int y)
{
    if (x <= l && y >= r) return sum[id];
    int mid = (l + r) >> 1;
    ll ans = 0;
    if (x <= mid) ans += query(id << 1, l, mid, x, y);
    if (y > mid) ans += query(id << 1 | 1, mid + 1, r, x, y);
    return ans;
}

void dfs(int u)
{
    vis[u] = 1;
    int t1 = lower_bound(b + 1, b + 1 + len, k / a[u]) - b;
    int t2 = lower_bound(b + 1, b + 1 + len, a[u]) - b;
    ans += query(1, 1, len, 1, t1);
    update(1, 1, len, t2, 1);
    for(int i=0;i<G[u].size();i++)
    {
        int v = G[u][i];
        if (vis[v]) continue;
        dfs(v);
    }
    update(1, 1, len, t2, -1);
}

int main()
{
    int t;
    scanf("%d", &t);
    while(t--)
    {
        ans = 0;
        scanf("%lld%lld", &n, &k);
        memset(vis, 0, sizeof(vis));
        for (int i = 1; i <= n; i++) scanf("%lld", &a[i]), b[i] = a[i], G[i].clear();
        for (int i = 1; i <= n; i++) b[i + n] = k / a[i];
        sort(b + 1, b + 1 + 2 * n);
        len = unique(b + 1, b + 1 + 2 * n) - b - 1;
        build(1, 1, len);
        for(int i=1;i<n;i++)
        {
            int u, v;
            scanf("%d%d", &u, &v);
            G[u].push_back(v);
            G[v].push_back(u);
            vis[v] = 1;
        }
        int root = 1;
        for (int i = 1; i <= n; i++) {
            if (vis[i] == 0) {
                root = i;
                break;
            }
        }
        memset(vis, 0, sizeof(vis));
        dfs(root);
        printf("%lld\n", ans);
    }
    return 0;
}
View Code

 

posted @ 2019-07-27 10:47  EchoZQN  阅读(161)  评论(0编辑  收藏  举报