算法学习(15):堆应用的补充

堆应用的补充

题目:数据流取得中位数

解题思路:准备一个大根堆和小根堆,首先把第一个数加进大根堆,然后后面的数依次与大根堆顶的数比较,如果小于等于大根堆顶的数,放入大根堆,如果大于大根堆顶,放入小根堆,当大根堆和小根堆的元素个数差到达了2,则把多的那个堆的堆顶弹出进小根堆。
C++代码实现:

class Median
{
public:
    Median()
    {
        this->num = 0;
    }
    void insert(int val)
    {
        this->num++;
        if (this->big.empty())
        {
            this->big.push(val);
            return;
        }
        if (val <= this->big.top())
        {
            this->big.push(val);
        }
        else
        {
            this->small.push(val);
        }
        int sSize = this->small.size();
        int bSize = this->big.size();
        if (bSize - sSize > 1)
        {
            int temp = this->big.top();
            this->big.pop();
            this->small.push(temp);
        }
        if (sSize - bSize > 1)
        {
            int temp = this->small.top();
            this->small.pop();
            this->big.push(temp);
        }
    }
    int getMedian()
    {
        if (this->num == 0)
        {
            return NULL;
        }
        if (this->num % 2 == 0)
        {
            int temp = 0;
            temp += big.top();
            temp += small.top();
            return temp / 2;
        }
        else
        {
            int sSize = this->small.size();
            int bSize = this->big.size();
            if (sSize > bSize)
            {
                return this->small.top();
            }
            else
            {
                return this->big.top();
            }
        }
    }
    bool isEmpty()
    {
        if (this->num == 0)
            return true;
        return false;
    }
private:
    priority_queue<int, vector<int>, greater<int>> small;
    priority_queue<int> big;
    int num;
};

堆的改写:Dijkstra算法改写堆后的优化

系统提供的堆只能是插入一个数,然后自动调整,在不插入数的时候,内部是不会变的。Dijkstra算法每次选取表中路径最小的点,如果采用堆来实现,会更好,但是这种堆是需要修改内部的数值,然后再调整,系统提供的堆是插入新数字再调整,显然是不能满足需求的,所以我们需要自己改写堆。

C++代码实现堆的改写

改写后的堆提供三个接口,isEmpty()判断堆是否是空的;addOrUpdateOrIgnore()用来新增节点和更新节点的距离,以前没有的节点就插入,以前有且距离比旧的距离小则更新距离,距离比旧的距离大则忽视;pop()用来弹出堆顶元素,返回值是一个NodeRecord类,里面共有两条记录,一条是节点,一条是距离。

class NodeRecord
{
public:
    NodeRecord(Node &node, int distance)
    {
        this->m_Node = &node;
        this->m_Distance = distance;
    }
public:
    Node* m_Node;
    int m_Distance;
};

class NodeHeap
{
public:
    bool isEmpty()
    {
        return m_Nodes.size() == 0;
    }

    void addOrUpdateOrIgnore(Node& node, int distance)
    {
        if (inHeap(node))
        {
            distanceMap.at(&node) = min(distanceMap.at(&node), distance);
            heapInsert(node, heapIndexMap.at(&node));
        }
        if (!isEntered(node))
        {
            m_Nodes.push_back(&node);
            heapIndexMap.insert(make_pair(&node, m_Nodes.size() - 1));
            distanceMap.insert(make_pair(&node, distance));
            heapInsert(node, m_Nodes.size() - 1);
        }
    }

    NodeRecord* pop()
    {
        if (m_Nodes.size() != 0)
        {
            NodeRecord* nodeRecord = new NodeRecord(*m_Nodes[0], distanceMap.at(m_Nodes[0]));
            heapIndexMap.at(m_Nodes[0]) = -1;
            distanceMap.erase(m_Nodes[0]);
            swap(0, m_Nodes.size() - 1);
            m_Nodes.pop_back();
            heapify(0);
            return nodeRecord;
        }
        return NULL;
    }

private:
    void heapInsert(Node &node, int index)
    {
        while (distanceMap.at(m_Nodes[index]) < distanceMap.at(m_Nodes[(index - 1) / 2]))
        {
            swap(index, (index - 1) / 2);
            index = (index - 1) / 2;
        }
    }
    
    void heapify(int index)
    {
        int left = index * 2 + 1;
        while (left < m_Nodes.size())
        {
            int smallest = left + 1 < m_Nodes.size() && distanceMap.at(m_Nodes[left]) < distanceMap.at(m_Nodes[left + 1]) ? left : left + 1;
            smallest = distanceMap.at(m_Nodes[index]) < distanceMap.at(m_Nodes[smallest]) ? index : smallest;
            if (smallest == index)
            {
                break;
            }
            swap(index, smallest);
            index = smallest;
            left = index * 2 + 1;
        }
    }

    bool isEntered(Node &node)
    {
        return heapIndexMap.count(&node) != 0;
    }

    bool inHeap(Node& node)
    {
        return isEntered(node) && heapIndexMap.at(&node) != -1;
    }

    void swap(int index1, int index2)
    {
        heapIndexMap.at(m_Nodes[index1]) = index2;
        heapIndexMap.at(m_Nodes[index2]) = index1;
        Node* temp = m_Nodes[index2];
        m_Nodes[index2] = m_Nodes[index1];
        m_Nodes[index1] = temp;
    }

private:
    vector<Node*> m_Nodes;
    unordered_map<Node*, int> heapIndexMap;
    unordered_map<Node*, int> distanceMap;
};

Dijkstra算法利用改写的堆后的优化代码(C++)

unordered_map<Node*, int> Dijkstra2(Node &node)
{
    NodeHeap nodeHeap;
    nodeHeap.addOrUpdateOrIgnore(node, 0);
    unordered_map<Node*, int> result;
    while (!nodeHeap.isEmpty())
    {
        NodeRecord* curRecord = nodeHeap.pop();
        int distance = curRecord->m_Distance;
        for (auto edge : curRecord->m_Node->m_Edges)
        {
            nodeHeap.addOrUpdateOrIgnore(edge.m_To, edge.m_Weight + distance);
        }
        result.insert(make_pair(curRecord->m_Node, curRecord->m_Distance));
        delete curRecord;
        curRecord = nullptr;
    }
    return result;
}
posted @ 2022-07-27 13:25  小肉包i  阅读(18)  评论(0)    收藏  举报