• 博客园logo
  • 会员
  • 众包
  • 新闻
  • 博问
  • 闪存
  • 赞助商
  • HarmonyOS
  • Chat2DB
    • 搜索
      所有博客
    • 搜索
      当前博客
  • 写随笔 我的博客 短消息 简洁模式
    用户头像
    我的博客 我的园子 账号设置 会员中心 简洁模式 ... 退出登录
    注册 登录
孙龙 程序员
少时总觉为人易,华年方知立业难
博客园    首页    新随笔    联系   管理    订阅  订阅
网络代理之HTTP代理(golang反向代理、负载均衡算法实现)
网络代理与网络转发的区别 golang实现发现代理 ReverseProxy ReverseProxy功能 4中负载轮训类型实现以及接口封装 拓展中间件支持:限流、熔断实现、权限、数据统计 用ReverseProxy实现一个http代理 负载均衡算法

网络代理于网络转发区别

 

 

 

 

网络代理:

用户不直接连接服务器,网络代理去连接,获取数据后返回给用户

网络转发:

是路由器对报文的转发操作,中间可能对数据包修改

 

网络代理类型:

 

 

 

 

 

正向代理:

 

 实现一个web浏览器代理:

 

 代码实现一个web浏览器代理:

 

 代码实现:

package main

import (
    "fmt"
    "io"
    "net"
    "net/http"
    "strings"
)

type Pxy struct{}

func (p *Pxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
    fmt.Printf("Received request %s %s %s\n", req.Method, req.Host, req.RemoteAddr)
    transport := http.DefaultTransport
    // step 1,浅拷贝对象,然后就再新增属性数据
    outReq := new(http.Request)
    *outReq = *req
    if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
        if prior, ok := outReq.Header["X-Forwarded-For"]; ok {
            clientIP = strings.Join(prior, ", ") + ", " + clientIP
        }
        outReq.Header.Set("X-Forwarded-For", clientIP)
    }
    
    // step 2, 请求下游
    res, err := transport.RoundTrip(outReq)
    if err != nil {
        rw.WriteHeader(http.StatusBadGateway)
        return
    }

    // step 3, 把下游请求内容返回给上游
    for key, value := range res.Header {
        for _, v := range value {
            rw.Header().Add(key, v)
        }
    }
    rw.WriteHeader(res.StatusCode)
    io.Copy(rw, res.Body)
    res.Body.Close()
}

func main() {
    fmt.Println("Serve on :8080")
    http.Handle("/", &Pxy{})
    http.ListenAndServe("0.0.0.0:8080", nil)
}

 

 

反向代理:

 

 

如何实现一个反向代理:

  • 这个功能比较复杂,我们先实现一个简版的http反向代理。
  • 代理接收客户端请求,更改请求结构体信息
  • 通过一定的负载均衡算法获取下游服务地址
  • 把请求发送到下游服务器,并获取返回的内容
  • 对返回的内容做一些处理,然后返回给客户端

 

启动两个http服务(真是服务地址)

127.0.0.1:2003
127.0.0.1:2004
package main

import (
    "fmt"
    "io"
    "log"
    "net/http"
    "os"
    "os/signal"
    "syscall"
    "time"
)

func main() {
    rs1 := &RealServer{Addr: "127.0.0.1:2003"}
    rs1.Run()
    rs2 := &RealServer{Addr: "127.0.0.1:2004"}
    rs2.Run()

    //监听关闭信号
    quit := make(chan os.Signal)
    signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
    <-quit
}

type RealServer struct {
    Addr string
}

func (r *RealServer) Run() {
    log.Println("Starting httpserver at " + r.Addr)
    mux := http.NewServeMux()
    mux.HandleFunc("/", r.HelloHandler)
    mux.HandleFunc("/base/error", r.ErrorHandler)
    server := &http.Server{
        Addr:         r.Addr,
        WriteTimeout: time.Second * 3,
        Handler:      mux,
    }
    go func() {
        log.Fatal(server.ListenAndServe())
    }()
}

func (r *RealServer) HelloHandler(w http.ResponseWriter, req *http.Request) {
    //127.0.0.1:8008/abc?sdsdsa=11
    //r.Addr=127.0.0.1:8008
    //req.URL.Path=/abc
    fmt.Println(req.Host)
    upath := fmt.Sprintf("http://%s%s\n", r.Addr, req.URL.Path)
    realIP := fmt.Sprintf("RemoteAddr=%s,X-Forwarded-For=%v,X-Real-Ip=%v\n", req.RemoteAddr, req.Header.Get("X-Forwarded-For"), req.Header.Get("X-Real-Ip"))

    io.WriteString(w, upath)
    io.WriteString(w, realIP)
}

func (r *RealServer) ErrorHandler(w http.ResponseWriter, req *http.Request) {
    upath := "error handler"
    w.WriteHeader(500)
    io.WriteString(w, upath)
}
real_server

 

 

启动一个代理服务

代理服务 127.0.0.1:2002(此处代码并没有使用负载均衡算法,只是简单地固定代理到其中一个服务器)

package main

import (
    "bufio"
    "log"
    "net/http"
    "net/url"
)

var (
    proxy_addr = "http://127.0.0.1:2003"
    port       = "2002"
)

func handler(w http.ResponseWriter, r *http.Request) {
    //step 1 解析代理地址,并更改请求体的协议和主机
    proxy, err := url.Parse(proxy_addr)
    r.URL.Scheme = proxy.Scheme
    r.URL.Host = proxy.Host

    //step 2 请求下游
    transport := http.DefaultTransport
    resp, err := transport.RoundTrip(r)
    if err != nil {
        log.Print(err)
        return
    }

    //step 3 把下游请求内容返回给上游
    for k, vv := range resp.Header {
        for _, v := range vv {
            w.Header().Add(k, v)
        }
    }
    defer resp.Body.Close()
    bufio.NewReader(resp.Body).WriteTo(w)
}

func main() {
    http.HandleFunc("/", handler)
    log.Println("Start serving on port " + port)
    err := http.ListenAndServe(":"+port, nil)
    if err != nil {
        log.Fatal(err)
    }
}
reverse_proxy

 

 

用户访问127.0.0.1:2002   反向代理到  127.0.0.1:2003

 

http代理

上面演示的是一个简版的http代理,不具备一下功能:

  • 错误回调及错误日志处理
  • 更改代理返回内容
  • 负载均衡
  • url重写
  • 限流、熔断、降级

 

 

用golang官方提供的ReverseProxy实现一个http代理

  • ReverseProxy功能点
  • ReverseProxy实例
  • ReverseProxy源码实现

 

 

拓展ReverseProxy功能

  • 4中负载轮训类型实现以及接口封装
  • 拓展中间件支持:限流、熔断实现、权限、数据统计

 

用ReverseProxy实现一个http代理:

 

 

package main

import (
    "log"
    "net/http"
    "net/http/httputil"
    "net/url"
)

var addr = "127.0.0.1:2002"

func main() {
    //127.0.0.1:2002/xxx  => 127.0.0.1:2003/base/xxx
    //127.0.0.1:2003/base/xxx
    rs1 := "http://127.0.0.1:2003/base"
    url1, err1 := url.Parse(rs1)
    if err1 != nil {
        log.Println(err1)
    }
    proxy := httputil.NewSingleHostReverseProxy(url1)
    log.Println("Starting httpserver at " + addr)
    log.Fatal(http.ListenAndServe(addr, proxy))
}

 

ReverseProxy修改返回的内容

重写 

httputil.NewSingleHostReverseProxy(url1)
package main

import (
	"bytes"
	"errors"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"net/http/httputil"
	"net/url"
	"regexp"
	"strings"
)

var addr = "127.0.0.1:2002"

func main() {
	//127.0.0.1:2002/xxx
	//127.0.0.1:2003/base/xxx
	rs1 := "http://127.0.0.1:2003/base"
	url1, err1 := url.Parse(rs1)
	if err1 != nil {
		log.Println(err1)
	}
	proxy := NewSingleHostReverseProxy(url1)
	log.Println("Starting httpserver at " + addr)
	log.Fatal(http.ListenAndServe(addr, proxy))
}

func NewSingleHostReverseProxy(target *url.URL) *httputil.ReverseProxy {
	//http://127.0.0.1:2002/dir?name=123
	//RayQuery: name=123
	//Scheme: http
	//Host: 127.0.0.1:2002
	targetQuery := target.RawQuery
	director := func(req *http.Request) {
		//url_rewrite
		//127.0.0.1:2002/dir/abc ==> 127.0.0.1:2003/base/abc ??
		//127.0.0.1:2002/dir/abc ==> 127.0.0.1:2002/abc
		//127.0.0.1:2002/abc ==> 127.0.0.1:2003/base/abc
		re, _ := regexp.Compile("^/dir(.*)");
		req.URL.Path = re.ReplaceAllString(req.URL.Path, "$1")

		req.URL.Scheme = target.Scheme
		req.URL.Host = target.Host

		//target.Path : /base
		//req.URL.Path : /dir
		req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
		if targetQuery == "" || req.URL.RawQuery == "" {
			req.URL.RawQuery = targetQuery + req.URL.RawQuery
		} else {
			req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
		}
		if _, ok := req.Header["User-Agent"]; !ok {
			req.Header.Set("User-Agent", "")
		}
	}
	modifyFunc := func(res *http.Response) error {
		if res.StatusCode != 200 {
			return errors.New("error statusCode")
			oldPayload, err := ioutil.ReadAll(res.Body)
			if err != nil {
				return err
			}
			newPayLoad := []byte("hello " + string(oldPayload))
			res.Body = ioutil.NopCloser(bytes.NewBuffer(newPayLoad))
			res.ContentLength = int64(len(newPayLoad))
			res.Header.Set("Content-Length", fmt.Sprint(len(newPayLoad)))
		}
		return nil
	}
	errorHandler := func(res http.ResponseWriter, req *http.Request, err error) {
		res.Write([]byte(err.Error()))
	}
	return &httputil.ReverseProxy{Director: director, ModifyResponse: modifyFunc, ErrorHandler: errorHandler}
}

func singleJoiningSlash(a, b string) string {
	aslash := strings.HasSuffix(a, "/")
	bslash := strings.HasPrefix(b, "/")
	switch {
	case aslash && bslash:
		return a + b[1:]
	case !aslash && !bslash:
		return a + "/" + b
	}
	return a + b
}

  

ReverseProxy补充知识:

特殊Header头:X-Forward-For、X-Real-Ip、Connection、TE、Trailer

第一代理取出标准的逐段传输头(HOP-BY-HOP)

X-Forward-For

  • 记录最后直连实际服务器之前,整个代理过程
  • 可能会被伪造

 

 

X-Real-Ip

  • 请求实际服务器的IP
  • 每过一层代理都会被覆盖掉,只需要第一代里设置转发
  • 不会被伪造

 代码实现:

package main

import (
    "bytes"
    "io/ioutil"
    "log"
    "math/rand"
    "net"
    "net/http"
    "net/http/httputil"
    "net/url"
    "regexp"
    "strconv"
    "strings"
    "time"
)

var addr = "127.0.0.1:2001"

func main() {
    rs1 := "http://127.0.0.1:2002"
    url1, err1 := url.Parse(rs1)
    if err1 != nil {
        log.Println(err1)
    }
    urls := []*url.URL{url1}
    proxy := NewMultipleHostsReverseProxy(urls)
    log.Println("Starting httpserver at " + addr)
    log.Fatal(http.ListenAndServe(addr, proxy))
}

var transport = &http.Transport{
    DialContext: (&net.Dialer{
        Timeout:   30 * time.Second, //连接超时
        KeepAlive: 30 * time.Second, //长连接超时时间
    }).DialContext,
    MaxIdleConns:          100,              //最大空闲连接
    IdleConnTimeout:       90 * time.Second, //空闲超时时间
    TLSHandshakeTimeout:   10 * time.Second, //tls握手超时时间
    ExpectContinueTimeout: 1 * time.Second,  //100-continue 超时时间
}

func NewMultipleHostsReverseProxy(targets []*url.URL) *httputil.ReverseProxy {
    //请求协调者
    director := func(req *http.Request) {
        //url_rewrite
        //127.0.0.1:2002/dir/abc ==> 127.0.0.1:2003/base/abc ??
        //127.0.0.1:2002/dir/abc ==> 127.0.0.1:2002/abc
        //127.0.0.1:2002/abc ==> 127.0.0.1:2003/base/abc
        re, _ := regexp.Compile("^/dir(.*)");
        req.URL.Path = re.ReplaceAllString(req.URL.Path, "$1")

        //随机负载均衡
        targetIndex := rand.Intn(len(targets))
        target := targets[targetIndex]
        targetQuery := target.RawQuery
        req.URL.Scheme = target.Scheme
        req.URL.Host = target.Host

        // url地址重写:重写前:/aa 重写后:/base/aa
        req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
        if targetQuery == "" || req.URL.RawQuery == "" {
            req.URL.RawQuery = targetQuery + req.URL.RawQuery
        } else {
            req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
        }
        if _, ok := req.Header["User-Agent"]; !ok {
            req.Header.Set("User-Agent", "user-agent")
        }
        //只在第一代理中设置此header头
        req.Header.Set("X-Real-Ip", req.RemoteAddr)
    }
    //更改内容
    modifyFunc := func(resp *http.Response) error {
        //请求以下命令:curl 'http://127.0.0.1:2002/error'
        if resp.StatusCode != 200 {
            //获取内容
            oldPayload, err := ioutil.ReadAll(resp.Body)
            if err != nil {
                return err
            }
            //追加内容
            newPayload := []byte("StatusCode error:" + string(oldPayload))
            resp.Body = ioutil.NopCloser(bytes.NewBuffer(newPayload))
            resp.ContentLength = int64(len(newPayload))
            resp.Header.Set("Content-Length", strconv.FormatInt(int64(len(newPayload)), 10))
        }
        return nil
    }
    //错误回调 :关闭real_server时测试,错误回调
    errFunc := func(w http.ResponseWriter, r *http.Request, err error) {
        http.Error(w, "ErrorHandler error:"+err.Error(), 500)
    }
    return &httputil.ReverseProxy{
        Director:       director,
        Transport:      transport,
        ModifyResponse: modifyFunc,
        ErrorHandler:   errFunc}
}

func singleJoiningSlash(a, b string) string {
    aslash := strings.HasSuffix(a, "/")
    bslash := strings.HasPrefix(b, "/")
    switch {
    case aslash && bslash:
        return a + b[1:]
    case !aslash && !bslash:
        return a + "/" + b
    }
    return a + b
}
第一层代理

 

 

第二层代理

package main

import (
    "bytes"
    "compress/gzip"
    "io/ioutil"
    "log"
    "math/rand"
    "net"
    "net/http"
    "net/http/httputil"
    "net/url"
    "regexp"
    "strconv"
    "strings"
    "time"
)

var addr = "127.0.0.1:2002"

func main() {
    //rs1 := "http://www.baidu.com"
    rs1 := "http://127.0.0.1:2003"
    url1, err1 := url.Parse(rs1)
    if err1 != nil {
        log.Println(err1)
    }

    //rs2 := "http://www.baidu.com"
    rs2 := "http://127.0.0.1:2004"
    url2, err2 := url.Parse(rs2)
    if err2 != nil {
        log.Println(err2)
    }
    urls := []*url.URL{url1, url2}
    proxy := NewMultipleHostsReverseProxy(urls)
    log.Println("Starting httpserver at " + addr)
    log.Fatal(http.ListenAndServe(addr, proxy))
}

var transport = &http.Transport{
    DialContext: (&net.Dialer{
        Timeout:   30 * time.Second, //连接超时
        KeepAlive: 30 * time.Second, //长连接超时时间
    }).DialContext,
    MaxIdleConns:          100,              //最大空闲连接
    IdleConnTimeout:       90 * time.Second, //空闲超时时间
    TLSHandshakeTimeout:   10 * time.Second, //tls握手超时时间
    ExpectContinueTimeout: 1 * time.Second,  //100-continue 超时时间
}

func NewMultipleHostsReverseProxy(targets []*url.URL) *httputil.ReverseProxy {
    //请求协调者
    director := func(req *http.Request) {
        //url_rewrite
        //127.0.0.1:2002/dir/abc ==> 127.0.0.1:2003/base/abc ??
        //127.0.0.1:2002/dir/abc ==> 127.0.0.1:2002/abc
        //127.0.0.1:2002/abc ==> 127.0.0.1:2003/base/abc
        re, _ := regexp.Compile("^/dir(.*)");
        req.URL.Path = re.ReplaceAllString(req.URL.Path, "$1")

        //随机负载均衡
        targetIndex := rand.Intn(len(targets))
        target := targets[targetIndex]
        targetQuery := target.RawQuery
        req.URL.Scheme = target.Scheme
        req.URL.Host = target.Host

        //todo 部分章节补充1
        //todo 当对域名(非内网)反向代理时需要设置此项。当作后端反向代理时不需要
        req.Host = target.Host

        // url地址重写:重写前:/aa 重写后:/base/aa
        req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
        if targetQuery == "" || req.URL.RawQuery == "" {
            req.URL.RawQuery = targetQuery + req.URL.RawQuery
        } else {
            req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
        }
        if _, ok := req.Header["User-Agent"]; !ok {
            req.Header.Set("User-Agent", "user-agent")
        }
        //只在第一代理中设置此header头
        //req.Header.Set("X-Real-Ip", req.RemoteAddr)
    }
    //更改内容
    modifyFunc := func(resp *http.Response) error {
        //请求以下命令:curl 'http://127.0.0.1:2002/error'
        //todo 部分章节功能补充2
        //todo 兼容websocket
        if strings.Contains(resp.Header.Get("Connection"), "Upgrade") {
            return nil
        }
        var payload []byte
        var readErr error

        //todo 部分章节功能补充3
        //todo 兼容gzip压缩
        if strings.Contains(resp.Header.Get("Content-Encoding"), "gzip") {
            gr, err := gzip.NewReader(resp.Body)
            if err != nil {
                return err
            }
            payload, readErr = ioutil.ReadAll(gr)
            resp.Header.Del("Content-Encoding")
        } else {
            payload, readErr = ioutil.ReadAll(resp.Body)
        }
        if readErr != nil {
            return readErr
        }

        //异常请求时设置StatusCode
        if resp.StatusCode != 200 {
            payload = []byte("StatusCode error:" + string(payload))
        }

        //todo 部分章节功能补充4
        //todo 因为预读了数据所以内容重新回写
        resp.Body = ioutil.NopCloser(bytes.NewBuffer(payload))
        resp.ContentLength = int64(len(payload))
        resp.Header.Set("Content-Length", strconv.FormatInt(int64(len(payload)), 10))
        return nil
    }
    //错误回调 :关闭real_server时测试,错误回调
    errFunc := func(w http.ResponseWriter, r *http.Request, err error) {
        http.Error(w, "ErrorHandler error:"+err.Error(), 500)
    }
    return &httputil.ReverseProxy{
        Director:       director,
        Transport:      transport,
        ModifyResponse: modifyFunc,
        ErrorHandler:   errFunc}
}

func singleJoiningSlash(a, b string) string {
    aslash := strings.HasSuffix(a, "/")
    bslash := strings.HasPrefix(b, "/")
    switch {
    case aslash && bslash:
        return a + b[1:]
    case !aslash && !bslash:
        return a + "/" + b
    }
    return a + b
}
View Code

 

 

实际服务器:

package main

import (
    "fmt"
    "io"
    "log"
    "net/http"
    "os"
    "os/signal"
    "syscall"
    "time"
)

func main() {
    rs1 := &RealServer{Addr: "127.0.0.1:2003"}
    rs1.Run()
    rs2 := &RealServer{Addr: "127.0.0.1:2004"}
    rs2.Run()

    //监听关闭信号
    quit := make(chan os.Signal)
    signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
    <-quit
}

type RealServer struct {
    Addr string
}

func (r *RealServer) Run() {
    log.Println("Starting httpserver at " + r.Addr)
    mux := http.NewServeMux()
    mux.HandleFunc("/", r.HelloHandler)
    mux.HandleFunc("/base/error", r.ErrorHandler)
    server := &http.Server{
        Addr:         r.Addr,
        WriteTimeout: time.Second * 3,
        Handler:      mux,
    }
    go func() {
        log.Fatal(server.ListenAndServe())
    }()
}

func (r *RealServer) HelloHandler(w http.ResponseWriter, req *http.Request) {
    //127.0.0.1:8008/abc?sdsdsa=11
    //r.Addr=127.0.0.1:8008
    //req.URL.Path=/abc
    fmt.Println(req.Host)
    upath := fmt.Sprintf("http://%s%s\n", r.Addr, req.URL.Path)
    realIP := fmt.Sprintf("RemoteAddr=%s,X-Forwarded-For=%v,X-Real-Ip=%v\n", req.RemoteAddr, req.Header.Get("X-Forwarded-For"), req.Header.Get("X-Real-Ip"))

    io.WriteString(w, upath)
    io.WriteString(w, realIP)
}

func (r *RealServer) ErrorHandler(w http.ResponseWriter, req *http.Request) {
    upath := "error handler"
    w.WriteHeader(500)
    io.WriteString(w, upath)
}
View Code

 

负载均衡策略:

  • 随机负载
  •   随机挑选目标服务器ip
  • 轮询负载
  •   ABC三台服务器,ABCABC一次轮询
  • 加权负载
  •   给目标设置访问权重,按照权重轮询
  • 一致性hash负载
  •   请求固定的url访问固定的ip

 

随机负载:

package load_balance

import (
    "errors"
    "fmt"
    "math/rand"
    "strings"
)

type RandomBalance struct {
    curIndex int
    rss      []string
    //观察主体
    conf LoadBalanceConf
}

func (r *RandomBalance) Add(params ...string) error {
    if len(params) == 0 {
        return errors.New("param len 1 at least")
    }
    addr := params[0]
    r.rss = append(r.rss, addr)
    return nil
}

func (r *RandomBalance) Next() string {
    if len(r.rss) == 0 {
        return ""
    }
    r.curIndex = rand.Intn(len(r.rss))
    return r.rss[r.curIndex]
}

func (r *RandomBalance) Get(key string) (string, error) {
    return r.Next(), nil
}

func (r *RandomBalance) SetConf(conf LoadBalanceConf) {
    r.conf = conf
}

func (r *RandomBalance) Update() {
    if conf, ok := r.conf.(*LoadBalanceZkConf); ok {
        fmt.Println("Update get conf:", conf.GetConf())
        r.rss = []string{}
        for _, ip := range conf.GetConf() {
            r.Add(strings.Split(ip, ",")...)
        }
    }
    if conf, ok := r.conf.(*LoadBalanceCheckConf); ok {
        fmt.Println("Update get conf:", conf.GetConf())
        r.rss = nil
        for _, ip := range conf.GetConf() {
            r.Add(strings.Split(ip, ",")...)
        }
    }
}
random.go

 

package load_balance

import (
    "fmt"
    "testing"
)

func TestRandomBalance(t *testing.T) {
    rb := &RandomBalance{}
    rb.Add("127.0.0.1:2003") //0
    rb.Add("127.0.0.1:2004") //1
    rb.Add("127.0.0.1:2005") //2
    rb.Add("127.0.0.1:2006") //3
    rb.Add("127.0.0.1:2007") //4

    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
}
random_test

 

=== RUN   TestRandomBalance
127.0.0.1:2004
127.0.0.1:2005
127.0.0.1:2005
127.0.0.1:2007
127.0.0.1:2004
127.0.0.1:2006
127.0.0.1:2003
127.0.0.1:2003
127.0.0.1:2004
--- PASS: TestRandomBalance (0.00s)
PASS

 

轮询负载:

package load_balance

import (
    "errors"
    "fmt"
    "strings"
)

type RoundRobinBalance struct {
    curIndex int
    rss      []string
    //观察主体
    conf LoadBalanceConf
}

func (r *RoundRobinBalance) Add(params ...string) error {
    if len(params) == 0 {
        return errors.New("param len 1 at least")
    }
    addr := params[0]
    r.rss = append(r.rss, addr)
    return nil
}

func (r *RoundRobinBalance) Next() string {
    if len(r.rss) == 0 {
        return ""
    }
    lens := len(r.rss) //5
    if r.curIndex >= lens {
        r.curIndex = 0
    }
    curAddr := r.rss[r.curIndex]
    r.curIndex = (r.curIndex + 1) % lens
    return curAddr
}

func (r *RoundRobinBalance) Get(key string) (string, error) {
    return r.Next(), nil
}

func (r *RoundRobinBalance) SetConf(conf LoadBalanceConf) {
    r.conf = conf
}

func (r *RoundRobinBalance) Update() {
    if conf, ok := r.conf.(*LoadBalanceZkConf); ok {
        fmt.Println("Update get conf:", conf.GetConf())
        r.rss = []string{}
        for _, ip := range conf.GetConf() {
            r.Add(strings.Split(ip, ",")...)
        }
    }
    if conf, ok := r.conf.(*LoadBalanceCheckConf); ok {
        fmt.Println("Update get conf:", conf.GetConf())
        r.rss = nil
        for _, ip := range conf.GetConf() {
            r.Add(strings.Split(ip, ",")...)
        }
    }
}
round_tobin

 

package load_balance

import (
    "fmt"
    "testing"
)

func Test_main(t *testing.T) {
    rb := &RoundRobinBalance{}
    rb.Add("127.0.0.1:2003") //0
    rb.Add("127.0.0.1:2004") //1
    rb.Add("127.0.0.1:2005") //2
    rb.Add("127.0.0.1:2006") //3
    rb.Add("127.0.0.1:2007") //4

    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
}
round_robin_test

 

=== RUN   Test_main
127.0.0.1:2003
127.0.0.1:2004
127.0.0.1:2005
127.0.0.1:2006
127.0.0.1:2007
127.0.0.1:2003
127.0.0.1:2004
127.0.0.1:2005
127.0.0.1:2006
--- PASS: Test_main (0.00s)
PASS

 

加权负载均衡:

  • Weight
  • 初始化时对节点约定的权重
  • currentWeight
  • 节点临时权重,每轮都会变化
  • effectiveWeight
  • 节点有效权重,默认与Weight相同
  • totalWeight
  • 所有节点有效权重之和:sum(effectiveWeight)

 

type WeightNode struct {
    addr            string
    weight          int //权重值
    currentWeight   int //节点当前权重
    effectiveWeight int //有效权重
}

 

  • 1,currentWeight = currentWeight + effectiveWeight
  • 2,选中最大的currentWeight节点为选中的节点
  • 3,currentWeight = currentWeight-totalWeight(4+3+2=9)

 

 计算方法如下:

第一次:

  •   currentWeight = currentWeight + effectiveWeight
  •     currentWeight     {A=4+4,B=3+3,C=2+2}   ==  {A=8,B=6,C=4}
  •   选中最大的currentWeight节点为选中的节点
  •     A最大 此时作为节点
  •   currentWeight = currentWeight-totalWeight(4+3+2=9)  【选中的节点currentWeight = currentWeight-totalWeight】
  •     currentWeight  {A=8-9,B=6,C=4}  == {A=-1,B=6,C=4}

第二次:{A=-1,B=6,C=4} 开始

  •   currentWeight = currentWeight + effectiveWeight
  •     currentWeight     {A=-1+4,B=6+3,C=4+2}   ==  {A=3,B=9,C=6}
  •   选中最大的currentWeight节点为选中的节点
  •     B最大 此时作为节点
  •   选中的节点currentWeight = currentWeight-totalWeight(4+3+2=9)
  •     currentWeight  {A=3,B=9-9,C=6}  == {A=3,B=0,C=6}

。。。。。。。以此类推。。。。。。。。。

 

 

package load_balance

import (
    "errors"
    "fmt"
    "strconv"
    "strings"
)

type WeightRoundRobinBalance struct {
    curIndex int
    rss      []*WeightNode
    rsw      []int
    //观察主体
    conf LoadBalanceConf
}

type WeightNode struct {
    addr            string
    weight          int //权重值
    currentWeight   int //节点当前权重
    effectiveWeight int //有效权重
}

func (r *WeightRoundRobinBalance) Add(params ...string) error {
    if len(params) != 2 {
        return errors.New("param len need 2")
    }
    parInt, err := strconv.ParseInt(params[1], 10, 64)
    if err != nil {
        return err
    }
    node := &WeightNode{addr: params[0], weight: int(parInt)}
    node.effectiveWeight = node.weight
    r.rss = append(r.rss, node)
    return nil
}

func (r *WeightRoundRobinBalance) Next() string {
    total := 0
    var best *WeightNode
    for i := 0; i < len(r.rss); i++ {
        w := r.rss[i]
        //step 1 统计所有有效权重之和
        total += w.effectiveWeight

        //step 2 变更节点临时权重为的节点临时权重+节点有效权重
        w.currentWeight += w.effectiveWeight

        //step 3 有效权重默认与权重相同,通讯异常时-1, 通讯成功+1,直到恢复到weight大小
        if w.effectiveWeight < w.weight {
            w.effectiveWeight++
        }
        //step 4 选择最大临时权重点节点
        if best == nil || w.currentWeight > best.currentWeight {
            best = w
        }
    }
    if best == nil {
        return ""
    }
    //step 5 变更临时权重为 临时权重-有效权重之和
    best.currentWeight -= total
    return best.addr
}

func (r *WeightRoundRobinBalance) Get(key string) (string, error) {
    return r.Next(), nil
}

func (r *WeightRoundRobinBalance) SetConf(conf LoadBalanceConf) {
    r.conf = conf
}

func (r *WeightRoundRobinBalance) Update() {
    if conf, ok := r.conf.(*LoadBalanceZkConf); ok {
        fmt.Println("WeightRoundRobinBalance get conf:", conf.GetConf())
        r.rss = nil
        for _, ip := range conf.GetConf() {
            r.Add(strings.Split(ip, ",")...)
        }
    }
    if conf, ok := r.conf.(*LoadBalanceCheckConf); ok {
        fmt.Println("WeightRoundRobinBalance get conf:", conf.GetConf())
        r.rss = nil
        for _, ip := range conf.GetConf() {
            r.Add(strings.Split(ip, ",")...)
        }
    }
}
weight_tound_robin.go
package load_balance

import (
    "fmt"
    "testing"
)

func TestLB(t *testing.T) {
    rb := &WeightRoundRobinBalance{}
    rb.Add("127.0.0.1:2003", "4") //0
    rb.Add("127.0.0.1:2004", "3") //1
    rb.Add("127.0.0.1:2005", "2") //2

    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
    fmt.Println(rb.Next())
}
test

 

一致性hash(ip_hash、url_hash)

为了解决平衡性:引入了虚拟节点概念(把个节点 均匀的覆盖到环上)

package load_balance

import (
    "errors"
    "fmt"
    "hash/crc32"
    "sort"
    "strconv"
    "strings"
    "sync"
)

type Hash func(data []byte) uint32

type UInt32Slice []uint32

func (s UInt32Slice) Len() int {
    return len(s)
}

func (s UInt32Slice) Less(i, j int) bool {
    return s[i] < s[j]
}

func (s UInt32Slice) Swap(i, j int) {
    s[i], s[j] = s[j], s[i]
}

type ConsistentHashBanlance struct {
    mux      sync.RWMutex
    hash     Hash
    replicas int               //复制因子 虚拟节点数
    keys     UInt32Slice       //已排序的节点hash切片 映射在环上的虚拟节点
    hashMap  map[uint32]string //节点哈希和Key的map,键是hash值,值是节点key

    //观察主体
    conf LoadBalanceConf
}

func NewConsistentHashBanlance(replicas int, fn Hash) *ConsistentHashBanlance {
    m := &ConsistentHashBanlance{
        replicas: replicas,//复制因子 虚拟节点数
        hash:     fn,
        hashMap:  make(map[uint32]string),
    }
    if m.hash == nil {
        //最多32位,保证是一个2^32-1环
        m.hash = crc32.ChecksumIEEE
    }
    return m
}

// 验证是否为空
func (c *ConsistentHashBanlance) IsEmpty() bool {
    return len(c.keys) == 0
}

// Add 方法用来添加缓存节点,参数为节点key,比如使用IP
func (c *ConsistentHashBanlance) Add(params ...string) error {
    if len(params) == 0 {
        return errors.New("param len 1 at least")
    }
    addr := params[0]
    c.mux.Lock()
    defer c.mux.Unlock()
    // 结合复制因子计算所有虚拟节点的hash值,并存入m.keys中,同时在m.hashMap中保存哈希值和key的映射
    for i := 0; i < c.replicas; i++ {
        hash := c.hash([]byte(strconv.Itoa(i) + addr))
        c.keys = append(c.keys, hash)
        c.hashMap[hash] = addr
    }
    // 对所有虚拟节点的哈希值进行排序,方便之后进行二分查找
    sort.Sort(c.keys)
    return nil
}

// Get 方法根据给定的对象获取最靠近它的那个节点
func (c *ConsistentHashBanlance) Get(key string) (string, error) {
    if c.IsEmpty() {
        return "", errors.New("node is empty")
    }
    hash := c.hash([]byte(key))

    // 通过二分查找获取最优节点,第一个"服务器hash"值大于"数据hash"值的就是最优"服务器节点"
    idx := sort.Search(len(c.keys), func(i int) bool { return c.keys[i] >= hash })

    // 如果查找结果 大于 服务器节点哈希数组的最大索引,表示此时该对象哈希值位于最后一个节点之后,那么放入第一个节点中
    if idx == len(c.keys) {
        idx = 0
    }
    c.mux.RLock()
    defer c.mux.RUnlock()
    return c.hashMap[c.keys[idx]], nil
}

func (c *ConsistentHashBanlance) SetConf(conf LoadBalanceConf) {
    c.conf = conf
}

func (c *ConsistentHashBanlance) Update() {
    if conf, ok := c.conf.(*LoadBalanceZkConf); ok {
        fmt.Println("Update get conf:", conf.GetConf())
        c.mux.Lock()
        defer c.mux.Unlock()
        c.keys = nil
        c.hashMap = nil
        for _, ip := range conf.GetConf() {
            c.Add(strings.Split(ip, ",")...)
        }
    }
    if conf, ok := c.conf.(*LoadBalanceCheckConf); ok {
        fmt.Println("Update get conf:", conf.GetConf())
        c.mux.Lock()
        defer c.mux.Unlock()
        c.keys = nil
        c.hashMap = nil
        for _, ip := range conf.GetConf() {
            c.Add(strings.Split(ip, ",")...)
        }
    }
}
consistent_hash.go

 

package load_balance

import (
    "fmt"
    "testing"
)

func TestNewConsistentHashBanlance(t *testing.T) {
    rb := NewConsistentHashBanlance(10, nil)
    rb.Add("127.0.0.1:2003") //0
    rb.Add("127.0.0.1:2004") //1
    rb.Add("127.0.0.1:2005") //2
    rb.Add("127.0.0.1:2006") //3
    rb.Add("127.0.0.1:2007") //4

    //url hash
    fmt.Println(rb.Get("http://127.0.0.1:2002/base/getinfo"))
    fmt.Println(rb.Get("http://127.0.0.1:2002/base/error"))
    fmt.Println(rb.Get("http://127.0.0.1:2002/base/getinfo"))
    fmt.Println(rb.Get("http://127.0.0.1:2002/base/changepwd"))

    //ip hash
    fmt.Println(rb.Get("127.0.0.1"))
    fmt.Println(rb.Get("192.168.0.1"))
    fmt.Println(rb.Get("127.0.0.1"))
}
test.go

 

工厂方法简单封装上述几种拒载均衡调用:

interface.go

package load_balance

type LoadBalance interface {
    Add(...string) error
    Get(string) (string, error)

    //后期服务发现补充
    Update()
}

 

 

factory.go

package load_balance

type LbType int

const (
    LbRandom LbType = iota
    LbRoundRobin
    LbWeightRoundRobin
    LbConsistentHash
)

func LoadBanlanceFactory(lbType LbType) LoadBalance {
    switch lbType {
    case LbRandom:
        return &RandomBalance{}
    case LbConsistentHash:
        return NewConsistentHashBanlance(10, nil)
    case LbRoundRobin:
        return &RoundRobinBalance{}
    case LbWeightRoundRobin:
        return &WeightRoundRobinBalance{}
    default:
        return &RandomBalance{}
    }
}

 

 

调用:

func main() {
    rb := load_balance.LoadBanlanceFactory(load_balance.LbWeightRoundRobin)
    if err := rb.Add("http://127.0.0.1:2003/base", "10"); err != nil {
        log.Println(err)
    }
    if err := rb.Add("http://127.0.0.1:2004/base", "20"); err != nil {
        log.Println(err)
    }
   // 。。。。。。。。。。。。。。。
}

 

本文来自博客园,作者:孙龙-程序员,转载请注明原文链接:https://www.cnblogs.com/sunlong88/p/13512362.html

posted on 2020-08-16 14:01  孙龙-程序员  阅读(1083)  评论(0)    收藏  举报
刷新页面返回顶部
博客园  ©  2004-2025
浙公网安备 33010602011771号 浙ICP备2021040463号-3