package http_api
import (
"crypto/tls"
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
)
// deadlinedConn wraps a net.Conn together with a timeout value.
//
// The embedded net.Conn supplies every connection method that is not
// overridden below. Timeout is recorded when the value is constructed, but
// the Read and Write methods defined here do not consult it; they forward
// straight to the wrapped connection.
type deadlinedConn struct {
	// Timeout is the duration associated with this connection.
	Timeout time.Duration
	net.Conn
}

// Read forwards to the wrapped connection's Read and returns its result
// unchanged.
func (dc *deadlinedConn) Read(p []byte) (int, error) {
	return dc.Conn.Read(p)
}

// Write forwards to the wrapped connection's Write and returns its result
// unchanged.
func (dc *deadlinedConn) Write(p []byte) (int, error) {
	return dc.Conn.Write(p)
}
// NewDeadlineTransport builds an http.Transport whose connections are dialed
// with net.DialTimeout bounded by connectTimeout and wrapped in a
// deadlinedConn carrying that same connectTimeout value. requestTimeout is
// installed as the transport's ResponseHeaderTimeout.
func NewDeadlineTransport(connectTimeout time.Duration, requestTimeout time.Duration) *http.Transport {
	dial := func(network, address string) (net.Conn, error) {
		conn, err := net.DialTimeout(network, address, connectTimeout)
		if err != nil {
			return nil, err
		}
		return &deadlinedConn{Timeout: connectTimeout, Conn: conn}, nil
	}

	return &http.Transport{
		Dial:                  dial,
		ResponseHeaderTimeout: requestTimeout,
	}
}
// Client is a thin wrapper around an *http.Client; the request helper
// methods defined on it issue their requests through the wrapped client.
type Client struct {
	// c is the underlying HTTP client used to execute requests.
	c *http.Client
}
// NewClient returns a Client backed by a transport from NewDeadlineTransport.
//
// tlsConfig is assigned to the transport's TLSClientConfig. connectTimeout
// and requestTimeout are passed through to NewDeadlineTransport, and
// requestTimeout is additionally set as the http.Client's overall Timeout.
func NewClient(tlsConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration) *Client {
	rt := NewDeadlineTransport(connectTimeout, requestTimeout)
	rt.TLSClientConfig = tlsConfig

	httpClient := &http.Client{
		Transport: rt,
		Timeout:   requestTimeout,
	}
	return &Client{c: httpClient}
}
// NegotiateV1 is a helper function to perform a v1 HTTP request
// and fallback to parsing the old backwards-compatible response format
// storing the result in the value pointed to by v.
//
// TODO: deprecated, remove in 1.0 (replace calls with GETV1)
func (c *Client) NegotiateV1(endpoint string, v interface{}) error {retry:
req, err := http.NewRequest("GET", endpoint, nil) if err != nil { return err
}
req.Header.Add("Accept", "application/vnd.nsq; version=1.0")
resp, err := c.c.Do(req)
if err != nil { return err
}
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil { return err
}
if resp.StatusCode != 200 { if resp.StatusCode == 403 && !strings.HasPrefix(endpoint, "https") { endpoint, err = httpsEndpoint(endpoint, body)
if err != nil { return err
}
goto retry
}
return fmt.Errorf("got response %s %q", resp.Status, body) }
if len(body) == 0 { body = []byte("{}") }
// unwrap pre-1.0 api response
if resp.Header.Get("X-NSQ-Content-Type") != "nsq; version=1.0" { var u struct { StatusCode int64 `json:"status_code"`
Data json.RawMessage `json:"data"`
}
err := json.Unmarshal(body, u)
if err != nil { return err
}
if u.StatusCode != 200 { return fmt.Errorf("got 200 response, but api status code of %d", u.StatusCode) }
body = u.Data
}
return json.Unmarshal(body, v)
}
// GETV1 is a helper function to perform a V1 HTTP request
// and parse our NSQ daemon's expected response format, with deadlines.
func (c *Client) GETV1(endpoint string, v interface{}) error {retry:
req, err := http.NewRequest("GET", endpoint, nil) if err != nil { return err
}
req.Header.Add("Accept", "application/vnd.nsq; version=1.0")
resp, err := c.c.Do(req)
if err != nil { return err
}
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil { return err
}
if resp.StatusCode != 200 { if resp.StatusCode == 403 && !strings.HasPrefix(endpoint, "https") { endpoint, err = httpsEndpoint(endpoint, body)
if err != nil { return err
}
goto retry
}
return fmt.Errorf("got response %s %q", resp.Status, body) }
err = json.Unmarshal(body, &v)
if err != nil { return err
}
return nil
}
// PostV1 is a helper function to perform a V1 HTTP request
// and parse our NSQ daemon's expected response format, with deadlines.
func (c *Client) POSTV1(endpoint string) error {retry:
req, err := http.NewRequest("POST", endpoint, nil) if err != nil { return err
}
req.Header.Add("Accept", "application/vnd.nsq; version=1.0")
resp, err := c.c.Do(req)
if err != nil { return err
}
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil { return err
}
if resp.StatusCode != 200 { if resp.StatusCode == 403 && !strings.HasPrefix(endpoint, "https") { endpoint, err = httpsEndpoint(endpoint, body)
if err != nil { return err
}
goto retry
}
return fmt.Errorf("got response %s %q", resp.Status, body) }
return nil
}
// httpsEndpoint rewrites endpoint to use the "https" scheme and the port
// advertised in body.
//
// body must be a JSON object; its "https_port" member supplies the new port.
// endpoint must parse as a URL whose host part is in host:port form; the
// host is kept and the port replaced. An error is returned if body is not
// valid JSON, endpoint does not parse, or its host has no port.
func httpsEndpoint(endpoint string, body []byte) (string, error) {
	var hint struct {
		HTTPSPort int `json:"https_port"`
	}
	if err := json.Unmarshal(body, &hint); err != nil {
		return "", err
	}

	parsed, err := url.Parse(endpoint)
	if err != nil {
		return "", err
	}

	// Only the host is reused; the original port is discarded.
	hostname, _, err := net.SplitHostPort(parsed.Host)
	if err != nil {
		return "", err
	}

	port := strconv.Itoa(hint.HTTPSPort)
	parsed.Scheme = "https"
	parsed.Host = net.JoinHostPort(hostname, port)
	return parsed.String(), nil
}