Go语言练习:基于最小堆的外部排序

问题:一个很大的数据文件,单节点可用内存有限,使用多节点实现文件数据排序。

思路:

1.一个主节点负责将文件分块,如果分3块,则开3个goroutine分别读取文件的一部分,交给一个外部节点独立排序;

2.每个外部节点接收数据并排序,将排序结果传回主节点;

3.主节点从每个外部节点接收一个数据构建最小堆,最小堆的元素除记录数据外还要记录数据来自于哪个节点,这样将最小数据写入文件写缓冲区后,需要从相应的节点再拿一个数据过来重新构造最小堆,直至所有数据均写入文件,排序完成。

4.考虑到数据可能存在相等的情况,构建的最小堆需要支持记录数值,以及所有数值的来源,这样从最小堆取出一个数据,可以写入几个相等的数据到文件缓冲区,同时从每个相等数据的来源分别读一个数据重新构造最小堆:

type item struct {
    data   interface{}
    source []int // 标记data来源,例如[]chan的id
}

图解:

代码结构:

GOPATH/pipeline/heapsort
--heap
----heap.go //最小堆数据结构
--mergesort
----mergesort.go //多路归并排序
----source.go      //数据源(随机数据源,文件数据源,内存排序数据源,外部排序数据源)
--netnode
----server.go      //外部排序节点
--main.go

代码实现:

最小堆数据结构:

package heap

import (
    "fmt"
)

// 最小堆数据结构
// data 存储堆元素
// Cmp 元素比较函数
type Heap struct {
    data []item
    Cmp  HeapCmpFunc // 比较函数,0相等,<0小于,>0大于
}

type HeapCmpFunc func(interface{}, interface{}) int

type item struct {
    data   interface{}
    source []int // 标记data来源,例如[]chan的id
}

func NewHeap(cap int, cmp HeapCmpFunc) Heap {
    return Heap{
        data: make([]item, 0, cap),
        Cmp:  cmp,
    }
}

func (heap *Heap) Len() int {
    return len(heap.data)
}

func (heap *Heap) Print() {
    for i, v := range heap.data {
        fmt.Println(i, v.data)
    }
}

// 向堆中添加新元素
func (heap *Heap) Add(data interface{}, source int) bool {
    for _, v := range heap.data {
        if v.data == nil {
            break
        }
        if heap.Cmp(data, v.data) == 0 { // 堆中有相等数据,只添加数据来源
            v.source = append(v.source, source)
            return true
        }
    }

    idx := heap.Len()
    if idx >= cap(heap.data) {
        heap.scale()
    }
    heap.data = append(heap.data, item{data, []int{source}})
    heap.shiftUp(idx)
    return true
}

// 获取堆顶元素值与来源标记
func (heap *Heap) Get() (data interface{}, source []int) {
    if heap.Len() < 1 {
        return nil, nil
    }

    data = heap.data[0].data
    source = heap.data[0].source
    heap.data = heap.data[1:]
    heap.heapify()
    return
}

func (heap *Heap) heapify() {
    firstParent := (heap.Len() - 1) / 2
    for i := firstParent; i >= 0; i-- {
        heap.shiftDown(i)
    }
}

func (heap *Heap) shiftUp(idx int) {
    for idx > 0 {
        if heap.Cmp(heap.data[idx].data, heap.data[parent(idx)].data) < 0 {
            heap.swap(idx, parent(idx))
            idx = parent(idx)
        } else {
            break
        }
    }
}

func (heap *Heap) shiftDown(idx int) {
    l, r := left(idx), right(idx)
    if r < heap.Len() && heap.Cmp(heap.data[idx].data, heap.data[r].data) > 0 {
        heap.swap(idx, r)
    }
    if l < heap.Len() && heap.Cmp(heap.data[idx].data, heap.data[l].data) > 0 {
        heap.swap(idx, l)
    }
}
func (heap *Heap) swap(i, j int) {
    heap.data[i], heap.data[j] = heap.data[j], heap.data[i]
}

func (heap *Heap) scale() {
    cap := len(heap.data) * 2
    if cap == 0 {
        cap = 8
    }
    data := make([]item, len(heap.data), cap)
    copy(data, heap.data)
    heap.data = data
}

func parent(idx int) int {
    return (idx - 1) / 2
}
func left(idx int) int {
    return 2*idx + 1
}
func right(idx int) int {
    return 2*idx + 2
}
heap.go

多路归并排序算法:

package mergesort

// 基于最小堆的多路归并排序

import (
    "pipeline/heapsort/heap"
    "strings"
)

func cmpInt(a, b interface{}) int {
    return a.(int) - b.(int)
}

func cmpStr(a, b interface{}) int {
    return strings.Compare(a.(string), b.(string))
}

func MergeSortInt(out chan int, ins ...chan int) {
    MergeSort(cmpInt, out, ins...)
}

// 多路归并排序
func MergeSort(cmp heap.HeapCmpFunc, out chan int, ins ...chan int) {
    hp := heap.NewHeap(len(ins), cmp)
    // 构造堆数据
    for idx, in := range ins {
        v, ok := <-in
        if ok {
            hp.Add(v, idx)
        }
    }

    for hp.Len() > 0 {
        // 从堆中读取最小值
        min, sources := hp.Get()
        if min != nil {
            // 填充堆数据
            for _, idx := range sources {
                out <- min.(int)
                v, ok := <-ins[idx]
                if ok {
                    hp.Add(v, idx)
                }
            }
        }
    }
    close(out)
}
mergesort.go

数据源:

package mergesort

import (
    "bufio"
    "encoding/binary"
    "io"
    "log"
    "math/rand"
    "net"
    "os"
    "sort"
)

// 生成指定数目个随机数据写入out通道
func RandomSource(n int) chan int {
    out := make(chan int)
    go func() {
        for n > 0 {
            out <- rand.Int()
            n--
        }
        close(out)
    }()
    return out
}

// 读取指定文件指定偏移的一块数据放入out通道
func ReaderSource(fileName string, offset int64, chunkSize int64) (chan int, error) {
    f, err := os.Open(fileName)
    if err != nil {
        return nil, err
    }

    if offset > 0 {
        _, err = f.Seek(offset, 0)
        if err != nil {
            return nil, err
        }
    }

    out := make(chan int)
    go func() {
        defer f.Close()

        var num int64
        var count int64
        r := bufio.NewReader(f)
        for {
            if chunkSize != -1 && count >= chunkSize {
                break
            }
            err = binary.Read(r, binary.LittleEndian, &num)
            if err != nil {
                if err != io.EOF {
                    log.Println(err)
                }
                break
            }
            out <- int(num)
            count += int64(binary.Size(num))
        }
        close(out)
    }()
    return out, nil
}

// 内存排序
// 接收in通道的数据放入slice中,将排序的数据写入out通道
func InMemSort(in chan int) chan int {
    out := make(chan int)
    go func() {
        var data []int
        for v := range in {
            data = append(data, v)
        }
        sort.Ints(data)
        for _, v := range data {
            out <- v
        }
        close(out)
    }()
    return out
}

// 网络排序
// 将in通道中的数据传输到addr指定的网络节点,由该节点进行排序然后接收排序的数据传入out通道
func NetworkSort(in chan int, addr string) (out chan int, err error) {
    con, err := net.Dial("tcp", addr)
    if err != nil {
        return nil, err
    }
    out = make(chan int)
    go func() {
        defer con.Close()
        bufw := bufio.NewWriter(con)
        // 将待排序数字发送给远端服务
        for v := range in {
            err = binary.Write(bufw, binary.LittleEndian, int64(v))
            if err != nil && err != io.EOF {
                log.Println(err)
                break
            }
        }
        bufw.Flush()
        // 关闭连接的写半边
        tcpCon := con.(*net.TCPConn)
        err = tcpCon.CloseWrite()
        if err != nil {
            log.Println(err)
            return
        }

        // 接收排序后的数据
        bufr := bufio.NewReader(con)
        for {
            var num int64
            err = binary.Read(bufr, binary.LittleEndian, &num)
            if err != nil {
                if err != io.EOF {
                    log.Println(err)
                }
                break
            }
            out <- int(num)
        }
        close(out)
    }()
    return out, nil
}
source.go

外部排序节点:

package main

import (
    "bufio"
    "encoding/binary"
    "flag"
    "fmt"
    "io"
    "log"
    "net"
    "sort"
    "time"
)

var port = flag.Int("p", 8888, "port to listen on this server,default 8888.")

func main() {
    flag.Parse()
    listener, err := net.Listen("tcp", fmt.Sprintf(":%d", *port))
    if err != nil {
        log.Fatal(err)
    }
    for {
        con, err := listener.Accept()
        if err != nil {
            log.Println(err)
            continue
        }
        log.Println(con.RemoteAddr())
        go netSort(con)
    }
}

func netSort(con net.Conn) {
    start := time.Now()
    defer con.Close()
    bufr := bufio.NewReader(con)
    // 读取客户端发送的数据
    var in []int
    var num int64
    for {
        err := binary.Read(bufr, binary.LittleEndian, &num)
        if err != nil {
            if err != io.EOF {
                log.Println(err)
            }
            break
        }
        in = append(in, int(num))
    }
    log.Println("read done.", time.Since(start))

    // 排序
    sort.Ints(in)
    log.Println("sort done.", time.Since(start))

    // 将排好序的数据写回客户端
    bufw := bufio.NewWriter(con)
    for _, v := range in {
        err := binary.Write(bufw, binary.LittleEndian, int64(v))
        if err != nil {
            log.Println(err)
            break
        }
    }
    bufw.Flush()
    tcpCon := con.(*net.TCPConn)
    tcpCon.CloseWrite()
    log.Println("send done.", time.Since(start))
}
server.go

主节点:

package main

import (
    "bufio"
    "encoding/binary"
    "fmt"
    "math/rand"
    "os"
    "pipeline/heapsort/mergesort"
)

func main() {
    //createTestFile("small.in", 10000)
    //createTestFile("large.in", 1000000)
    distributeFile("small.in", 3)
}

// 文件数据源外部排序测试
func distributeFile(filename string, chunkNum int) {
    f, err := os.Open(filename)
    if err != nil {
        fmt.Println(err)
        return
    }
    defer f.Close()
    finfo, err := f.Stat()
    filesize := finfo.Size()
    ins := make([]chan int, chunkNum)
    var offset int64
    chunkSize := int64(filesize/int64(binary.Size(offset))/int64(chunkNum)) * int64(binary.Size(offset))
    for i, _ := range ins {
        if i == chunkNum-1 {
            chunkSize = -1
        }
        fin, err := mergesort.ReaderSource(filename, offset, chunkSize)
        offset += chunkSize
        if err != nil {
            fmt.Println(err)
            return
        }
        in, err := mergesort.NetworkSort(fin, fmt.Sprintf(":808%d", i))
        if err != nil {
            fmt.Println(err)
            return
        }
        ins[i] = in
    }

    out := make(chan int)
    go mergesort.MergeSortInt(out, ins...)

    if testResult(out) {
        fmt.Println("PASS")
    } else {
        fmt.Println("FAIL")
    }
}

// 随机数据源外部排序测试
func distributeRandom() {
    inCount := 3
    numCount := 100
    ins := make([]chan int, inCount)
    for i, _ := range ins {
        in, err := mergesort.NetworkSort(mergesort.RandomSource(numCount), fmt.Sprintf(":808%d", i))
        if err != nil {
            fmt.Println(err)
            return
        }
        ins[i] = in
    }
    out := make(chan int)
    go mergesort.MergeSortInt(out, ins...)

    if testResult(out) {
        fmt.Println("PASS")
    } else {
        fmt.Println("FAIL")
    }
}

// 随机数据源内存排序测试
func small() {
    inCount := 4
    numCount := 100
    ins := make([]chan int, inCount)
    for i, _ := range ins {
        ins[i] = mergesort.InMemSort(mergesort.RandomSource(numCount))
    }
    out := make(chan int)
    go mergesort.MergeSortInt(out, ins...)

    if testResult(out) {
        fmt.Println("PASS")
    } else {
        fmt.Println("FAIL")
    }
}

//新建测试数据源
func createTestFile(filename string, count int) {
    f, err := os.Create(filename)
    if err != nil {
        fmt.Println(err)
        return
    }
    defer f.Close()

    bufw := bufio.NewWriter(f)
    defer bufw.Flush()
    for i := 0; i < count; i++ {
        num := rand.Int()
        err = binary.Write(bufw, binary.LittleEndian, int64(num))
        if err != nil {
            fmt.Println(err)
            return
        }
    }
}

//检查排序结果
func testResult(out chan int) bool {
    var pre int
    var start bool
    var printCOunt int
    for v := range out {
        if printCOunt < 100 {
            fmt.Println(v)
            printCOunt++
        }
        if !start {
            pre = v
            start = true
            continue
        }
        if v < pre {
            for _ = range out {

            }
            return false
        }
    }
    return true
}
main.go

执行效果:

1.准备测试数据文件:main.createTestFile("small.in",10000)

2.开启三个外部排序节点:

netnode -p 8080

netnode -p 8081

netnode -p 8082

3.运行主节点heapsort

外部排序节点:

问题:

设计堆数据结构时考虑到支持不同数据类型,并可以通过指定比较函数实现最大最小堆的转换,但是没有很好的实现方便的数据类型切换,interface{}可以接受所有类型,但chan interface{}却不能接受chan int或chan string。

 

posted @ 2019-03-20 13:10  zerofl-diary  阅读(932)  评论(0编辑  收藏  举报