Tornado Source Code Analysis, Part 1

Tornado asynchronous DNS resolution

A walkthrough of the DefaultExecutorResolver source

Official Tornado documentation: http://www.tornadoweb.org/en/stable/netutil.html#tornado.netutil.Resolver

import asyncio
import socket
from typing import List, Callable, Any, Type, Dict, Union, Tuple, Awaitable


def _resolve_addr(host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC) -> List[Tuple[int, Any]]:
    addrinfo = socket.getaddrinfo(host, port, family, socket.SOCK_STREAM)  # perform the (blocking) DNS lookup
    results = []
    for fam, socktype, proto, canonname, address in addrinfo:
        results.append((fam, address))
    return results


class DefaultExecutorResolver(object):  # Tornado's default resolver
    async def resolve(self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC) -> List[Tuple[int, Any]]:
        loop = asyncio.get_event_loop()
        result = await loop.run_in_executor(None, _resolve_addr, host, port, family)  # run the blocking lookup in the default thread pool executor
        return result


async def test_resolver():
    resolver = DefaultExecutorResolver()
    tmp = await resolver.resolve("www.baidu.com", 80)
    print(tmp)
    # Result: [(<AddressFamily.AF_INET: 2>, ('220.181.38.150', 80)), (<AddressFamily.AF_INET: 2>, ('220.181.38.149', 80))]


if __name__ == '__main__':
    loop = asyncio.get_event_loop()
    loop.run_until_complete(test_resolver())

For details see: https://docs.python.org/3.6/library/asyncio-eventloop.html#executor

Call a function in an Executor (pool of threads or pool of processes). By default, an event loop uses a thread pool executor (ThreadPoolExecutor).

coroutine AbstractEventLoop.run_in_executor(executor, func, *args)

Arrange for a func to be called in the specified executor.

The executor argument should be an Executor instance. The default executor is used if executor is None.

Use functools.partial to pass keywords to the *func*.

This method is a coroutine.
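As the quoted documentation notes, run_in_executor only forwards positional arguments, so keyword arguments have to be bound with functools.partial first. A minimal sketch (the host and family values here are arbitrary examples):

import asyncio
import functools
import socket


def lookup(host, port, family=socket.AF_UNSPEC):
    return socket.getaddrinfo(host, port, family, socket.SOCK_STREAM)


async def main():
    loop = asyncio.get_event_loop()
    # bind the keyword argument with functools.partial, then hand the callable to the executor
    fn = functools.partial(lookup, "www.baidu.com", 80, family=socket.AF_INET)
    result = await loop.run_in_executor(None, fn)  # None -> the default ThreadPoolExecutor
    print(result)


if __name__ == '__main__':
    asyncio.get_event_loop().run_until_complete(main())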

Examples of call_at, call_later, and call_soon

  • call_soon: schedule a callback to run as soon as possible
  • call_at: schedule a callback to run at a specific time
  • call_later: schedule a callback to run after a given delay
import asyncio
from datetime import datetime


def test_call_at():
    print("时间:%s, I am call_at" % datetime.now())


def test_call_later():
    print("时间:%s, I am call_later" % datetime.now())


def test_call_soon():
    print("时间:%s, I am call_soon" % datetime.now())


if __name__ == '__main__':
    loop = asyncio.get_event_loop()
    print("时间:%s, I am start_time" % datetime.now())
    loop.call_at(loop.time() + 2, test_call_at)  # 在某个时间点执行,loop.time()返回loop当前时间
    loop.call_later(3, test_call_later)  # 3秒后执行
    loop.call_soon(test_call_soon)    # 尽快执行
    try:
        loop.run_forever()
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())
        loop.close()

tornado.gen.sleep    Purpose: an asynchronous (non-blocking) sleep

import asyncio
from asyncio import Future
from datetime import datetime


def sleep(delay):
    future = Future()
    loop = asyncio.get_event_loop()
    loop.call_later(delay, lambda: future.set_result(None))
    return future


async def test_sleep():
    print(datetime.now())
    await sleep(3)
    print(datetime.now())


if __name__ == '__main__':
    loop = asyncio.get_event_loop()
    loop.run_until_complete(test_sleep())

tornado.gen.multi  Purpose: runs multiple asynchronous operations in parallel

Demo

import asyncio
import time
from tornado import gen
from datetime import datetime


async def test_task(name: str = None):
    print("时间:%s, 任务%s开始运行..." % (str(datetime.now()), name))
    # await gen.sleep(3)  # 总执行时间3秒左右
    time.sleep(3)  # 总执行时间6秒左右


async def test_multi():
    print("时间:%s, 多任务开始运行..." % str(datetime.now()))
    await gen.multi([test_task(name='1'), test_task(name='2')])
    print("时间:%s, 多任务结束运行" % str(datetime.now()))


if __name__ == '__main__':
    loop = asyncio.get_event_loop()
    loop.run_until_complete(test_multi())
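gen.multi is Tornado's counterpart to asyncio.gather; the same demo written against the standard library behaves the same way once the tasks actually yield control (a sketch, assuming Python 3.5+):

import asyncio
from datetime import datetime


async def test_task(name: str = None):
    print("time: %s, task %s starting..." % (datetime.now(), name))
    await asyncio.sleep(3)  # non-blocking, so both tasks overlap (~3 seconds total)


async def test_gather():
    print("time: %s, starting tasks..." % datetime.now())
    await asyncio.gather(test_task(name='1'), test_task(name='2'))
    print("time: %s, tasks finished" % datetime.now())


if __name__ == '__main__':
    asyncio.get_event_loop().run_until_complete(test_gather())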

 

A custom multi

import asyncio

from tornado import gen
from datetime import datetime
from asyncio import Future, ensure_future
from typing import Any, Tuple, Dict, List, Union, Awaitable, Coroutine, Set


def multi(children: List[Coroutine]) -> "Future[List]":
    assert all(isinstance(i, Coroutine) for i in children)
    children_futs = list(map(ensure_future, children))
    unfinished_children = set(children_futs)  # futures that have not completed yet
    future = Future()
    if not children_futs:
        future.set_result([])

    def callback(fut: Future) -> None:
        unfinished_children.remove(fut)
        if not unfinished_children:  # every child task has finished
            result_list = []
            for f in children_futs:
                try:
                    result_list.append(f.result())
                except Exception as e:
                    if not future.done():
                        future.set_exception(e)
            if not future.done():
                future.set_result(result_list)  # deliver the list of results

    listening = set()  # type: Set[Future]
    for f in children_futs:
        if f not in listening:
            listening.add(f)
            f.add_done_callback(callback)
    return future


async def test_task(name: str = None):
    print("时间:%s, 任务%s开始运行..." % (str(datetime.now()), name))
    await gen.sleep(3)  # 总执行时间3秒左右
    return name


async def test_multi():
    print("时间:%s, 多任务开始运行..." % str(datetime.now()))
    ret = await multi([test_task(name='1'), test_task(name='2')])
    print(ret)
    print("时间:%s, 多任务结束运行" % str(datetime.now()))


if __name__ == '__main__':
    loop = asyncio.get_event_loop()
    loop.run_until_complete(test_multi())

 

tornado.gen.with_timeout    Purpose: apply a timeout to a single task; if the deadline passes, a timeout exception is raised

A custom implementation, saved as custom_gen.py (it is imported under that name later)

import tornado.gen
import asyncio
from asyncio import Future
from datetime import datetime, timedelta
from typing import List, Callable, Any, Type, Dict, Union, Tuple, Awaitable


async def busy_task(delay=3):
    print("task starting...")
    await tornado.gen.sleep(delay)
    return "task result"


# A custom implementation of gen.with_timeout
def with_timeout(timeout: Union[float, timedelta], future: Awaitable):
    future_converted = asyncio.ensure_future(future)  # wrap the coroutine/awaitable in a Future; also used to collect a late result
    result = Future()  # a new Future that receives either the timeout exception or the real result
    loop = asyncio.get_event_loop()

    def copy(_future) -> None:
        assert _future is future_converted
        if result.done():  # already failed with the timeout exception
            return
        # not timed out: copy the real result (or exception) over
        exc = future_converted.exception()
        if exc is not None:
            result.set_exception(exc)
        else:
            result.set_result(future_converted.result())

    # when the wrapped task finishes, copy its outcome into result
    future_converted.add_done_callback(copy)

    def error_callback(future: Future) -> None:
        try:
            overdue_result = future.result()
            print("late result: %s" % overdue_result)
        except asyncio.CancelledError as e:
            print(str(e))
        except Exception as e:
            print(str(e))

    def timeout_callback() -> None:
        if not result.done():
            # the deadline has passed: fail the outer future with a timeout
            result.set_exception(TimeoutError("Timeout"))
        # still collect (and log) the late result when it eventually arrives
        future_converted.add_done_callback(error_callback)

    # schedule the timeout callback at the deadline
    timeout_handle = loop.call_at(loop.time() + timeout, timeout_callback)

    # if the task finishes before the deadline, cancel the timeout callback
    future_converted.add_done_callback(lambda future: timeout_handle.cancel())
    return result


async def main():
    try:
        result = await with_timeout(3, busy_task(delay=1))
        print("正常的结果:%s" % result)
    except TimeoutError as e:
        print(str(e))
    await tornado.gen.sleep(5)


if __name__ == '__main__':
    loop = asyncio.get_event_loop()
    try:
        loop.run_until_complete(main())
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())
        loop.close()
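For comparison, the real tornado.gen.with_timeout accepts either an absolute deadline (relative to IOLoop.time()) or a datetime.timedelta. A minimal usage sketch, assuming Tornado 5+ running on the asyncio event loop, where gen.TimeoutError is the timeout exception it raises:

import asyncio
from datetime import timedelta

import tornado.gen


async def busy_task(delay=3):
    await tornado.gen.sleep(delay)
    return "task result"


async def main():
    fut = asyncio.ensure_future(busy_task(delay=3))
    try:
        # a timedelta is interpreted as a deadline relative to now
        result = await tornado.gen.with_timeout(timedelta(seconds=1), fut)
        print("result: %s" % result)
    except tornado.gen.TimeoutError:
        print("timed out")


if __name__ == '__main__':
    asyncio.get_event_loop().run_until_complete(main())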

Reading the tornado.tcpclient.TCPClient source (a TCP debugging tool or a simple echo server helps with testing)

TCPClient test example

import asyncio
import tornado.tcpclient


async def test_tcp_client():
    test_stream = await tornado.tcpclient.TCPClient().connect("127.0.0.1", 7023, timeout=3)
    await test_stream.write(b'hello server')
    response = await test_stream.read_until(b"\n")
    print(response.decode(encoding="utf-8"))


if __name__ == "__main__":
    loop = asyncio.get_event_loop()
    loop.run_until_complete(test_tcp_client())
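If no TCP debugging tool is at hand, a minimal asyncio echo server can play the role of the peer. This is a hypothetical helper, listening on 127.0.0.1:7023 to match the test above; run it in a separate process:

import asyncio


async def handle_echo(reader, writer):
    data = await reader.read(100)   # read whatever the client sent
    writer.write(data + b"\n")      # echo it back, newline-terminated so read_until(b"\n") completes
    await writer.drain()
    writer.close()


if __name__ == "__main__":
    loop = asyncio.get_event_loop()
    loop.run_until_complete(asyncio.start_server(handle_echo, "127.0.0.1", 7023))
    loop.run_forever()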

Building a simplified TCPClient: internally, tornado.tcpclient.TCPClient also just creates an IOStream object

import asyncio
import socket
from custom_gen import with_timeout
from tornado.iostream import IOStream
from datetime import datetime, timedelta
from typing import List, Callable, Any, Type, Dict, Union, Tuple, Awaitable, Optional, Set, Iterator


class TCPClient(object):

    def connect(self, host: str, port: int, timeout: Union[float, timedelta] = None):
        # return an IOStream within the timeout, otherwise raise a timeout exception
        socket_obj = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        stream = IOStream(socket_obj)
        stream_future = stream.connect((host, port))  # start connecting to the given host and port
        return with_timeout(timeout, stream_future)


async def test_tcp_client():
    test_stream = await TCPClient().connect("27.25.24.24", 7023, timeout=3)
    await test_stream.write(b'hello server')
    response = await test_stream.read_until(b"\n")
    print(response.decode(encoding="utf-8"))


if __name__ == "__main__":
    loop = asyncio.get_event_loop()
    loop.run_until_complete(test_tcp_client())

A more complete TCPClient

import asyncio
import functools
import socket
from asyncio import Future, TimerHandle

from custom_gen import with_timeout, DefaultExecutorResolver
from tornado.iostream import IOStream
from datetime import datetime, timedelta
from typing import List, Callable, Any, Type, Dict, Union, Tuple, Awaitable, Optional, Set, Iterator

_INITIAL_CONNECT_TIMEOUT = 0.3


class TCPClient(object):
    """connect() returns an IOStream object.
    What this version adds:
    DNS: resolve the host name to a list of IP addresses via DefaultExecutorResolver
    Retry: if no connection is made within a short window, or an attempt fails, retry with another address
    Timeout: report a timeout exception to the caller; late results are dropped
    """

    def __init__(self):
        self.future = Future()
        self.resolver = DefaultExecutorResolver()  # DNS resolver
        self.timeout = None  # type: Optional[TimerHandle]  # short timer after which the secondary addresses are also tried
        self.connect_timeout = None  # type: Optional[TimerHandle]  # the caller-supplied connect timeout
        self.last_error = None  # type: Optional[Exception]  # the exception from the most recent failed attempt
        self.remaining = None  # number of connection attempts that have not yet reported back
        self.streams = set()  # type: Set[IOStream]  # streams created while attempting to connect
        self.loop = asyncio.get_event_loop()  # the current event loop
        self.primary_addrs, self.secondary_addrs = None, None  # primary (first address family, typically IPv4) and secondary address lists

    def split(self, addrinfo: List[Tuple]) -> Tuple[
        List[Tuple[socket.AddressFamily, Tuple]],
        List[Tuple[socket.AddressFamily, Tuple]],
    ]:
        """把域名下的IP地址列表分为IPv4地址列表和IPv6地址列表"""
        primary = []  # IPv4 地址列表
        secondary = []  # IPv6 地址列表
        primary_af = addrinfo[0][0]
        for af, addr in addrinfo:
            if af == primary_af:
                primary.append((af, addr))
            else:
                secondary.append((af, addr))
        return primary, secondary

    async def connect(self, host: str, port: int, timeout: Union[float, timedelta] = None):
        # DNS resolution: a host name may map to several IP addresses
        if timeout is not None:
            addrinfo = await with_timeout(timeout, self.resolver.resolve(host, port))  # DNS lookup with a deadline
        else:
            addrinfo = await self.resolver.resolve(host, port)

        # split the addresses into primary and secondary lists
        self.primary_addrs, self.secondary_addrs = self.split(addrinfo)
        self.remaining = len(addrinfo)  # attempts that have not yet reported back
        self.try_connect(iter(self.primary_addrs))

        # after a short delay (0.3 seconds by default) start trying the secondary addresses as well
        self.set_timeout(_INITIAL_CONNECT_TIMEOUT)
        if timeout is not None:
            self.set_connect_timeout(timeout)
        addr, stream = await self.future
        return stream

    def try_connect(self, addrs) -> None:
        """Try to connect to the next address from the iterator."""
        try:
            af, addr = next(addrs)
        except StopIteration:
            # every address has been tried and every on_connect_done has fired; if the
            # caller set no timeout, report the last error instead of hanging forever
            if self.remaining == 0 and not self.future.done():
                self.future.set_exception(self.last_error or IOError("connection failed"))
            return
        # try to obtain an IOStream for this address
        socket_obj = socket.socket(af)
        stream = IOStream(socket_obj)
        stream_future = stream.connect(addr)
        self.streams.add(stream)
        stream_future.add_done_callback(functools.partial(self.on_connect_done, addrs, addr))

    def on_connect_done(self, addrs, addr, future):
        """Called whether the attempt succeeds, fails, or arrives after a timeout.
        Note: after a timeout the caller has already received a timeout exception;
        the attempt itself still finishes later with a (late) success or failure.
        """
        self.remaining -= 1
        try:
            stream = future.result()
        except Exception as e:
            if self.future.done():
                return
            # remember the last error; if every address fails it is reported to the caller
            self.last_error = e
            self.try_connect(addrs)
            # if the first attempt fails, start on the secondary addresses immediately instead of waiting for the timer
            if self.timeout is not None:
                self.timeout.cancel()
                self.on_timeout()
            return

        # a connection has been established
        self.clear_timeouts()
        if self.future.done():  # this is a late arrival; just drop it
            stream.close()
        else:
            # keep this stream and close every other one
            self.streams.discard(stream)
            self.future.set_result((addr, stream))
            self.close_streams()

    def set_timeout(self, timeout):
        self.timeout = self.loop.call_at(self.loop.time() + timeout, self.on_timeout)

    def on_timeout(self):
        """If no connection has been made within the short window (or the first attempt failed), start on the secondary address list."""
        self.timeout = None
        if not self.future.done():  # no error reported and no connection obtained yet
            # start trying the secondary addresses
            self.try_connect(iter(self.secondary_addrs))

    def set_connect_timeout(self, connect_timeout):
        self.connect_timeout = self.loop.call_at(self.loop.time() + connect_timeout, self.on_connect_timeout)

    def on_connect_timeout(self) -> None:
        """用户定义的超时时间,超时后向用户反馈超时"""
        if not self.future.done():
            self.future.set_exception(TimeoutError())
        self.close_streams()

    def clear_timeouts(self) -> None:
        """成功获取到连接后,清除所有超时定时器"""
        if self.timeout is not None:
            self.timeout.cancel()
        if self.connect_timeout is not None:
            self.connect_timeout.cancel()

    def close_streams(self) -> None:
        for stream in self.streams:
            stream.close()


async def test_tcp_client():
    test_stream = await TCPClient().connect("127.0.0.1", 9000, timeout=3)
    await test_stream.write(b'hello server\n')
    response = await test_stream.read_until(b"\n")
    print(response.decode(encoding="utf-8"))


if __name__ == "__main__":
    loop = asyncio.get_event_loop()
    loop.run_until_complete(test_tcp_client())

Tornado IOStream source walkthrough

IOStream test

import asyncio
import socket
from tornado.iostream import IOStream


async def test_stream():
    socket_obj = socket.socket(socket.AF_INET)
    stream = IOStream(socket_obj)
    await stream.connect(("127.0.0.1", 9000))
    await stream.write(b'hello server\n')
    response = await stream.read_until(b"\n")
    print(response.decode(encoding="utf-8"))


if __name__ == "__main__":
    loop = asyncio.get_event_loop()
    loop.run_until_complete(test_stream())

add_reader, add_writer

Official add_reader example: https://docs.python.org/3.6/library/asyncio-eventloop.html?highlight=add_reader#watch-a-file-descriptor-for-read-events

import asyncio
try:
    from socket import socketpair
except ImportError:
    from asyncio.windows_utils import socketpair

# Create a pair of connected file descriptors
rsock, wsock = socketpair()
loop = asyncio.get_event_loop()

def reader():
    data = rsock.recv(100)
    print("Received:", data.decode())
    # We are done: unregister the file descriptor
    loop.remove_reader(rsock)
    # Stop the event loop
    loop.stop()

# Register the file descriptor for read event
loop.add_reader(rsock, reader)

# Simulate the reception of data from the network
loop.call_soon(wsock.send, 'abc'.encode())

# Run the event loop
loop.run_forever()

# We are done, close sockets and the event loop
rsock.close()
wsock.close()
loop.close()
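The heading also mentions add_writer; the write side works the same way: register a callback that fires when the file descriptor becomes writable, then unregister it. A small sketch along the lines of the docs example above:

import asyncio
import socket

# Create a pair of connected sockets
rsock, wsock = socket.socketpair()
loop = asyncio.get_event_loop()

def writer():
    # fires as soon as wsock is writable; send and unregister
    wsock.send(b"hello")
    loop.remove_writer(wsock)
    loop.stop()

# Register the file descriptor for write events
loop.add_writer(wsock, writer)
loop.run_forever()

print("Received:", rsock.recv(100))
rsock.close()
wsock.close()
loop.close()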

A simplified (incomplete) IOStream

Key event-loop functions in the Python docs: https://docs.python.org/3.6/library/asyncio-eventloop.html?highlight=remove_writer#watch-file-descriptors

socket.recv_into in the Python docs: https://docs.python.org/3.6/library/socket.html?highlight=recv_into#socket.socket.recv_into

Note:

On windows, socket.send blows up if given a write buffer that's too large, instead of just returning the number of bytes it was able to process.  Therefore we must not call socket.send with more than 128KB at a time.
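A sketch of the workaround the note describes: never hand socket.send more than 128 KB per call and resubmit the remainder (the helper name here is made up; Tornado enforces a similar cap when draining its write buffer):

import socket

_WINDOWS_SEND_CAP = 128 * 1024  # 128 KB


def send_all(sock: socket.socket, data: bytes) -> None:
    """Keep calling send() with at most 128 KB at a time until everything is written."""
    view = memoryview(data)
    while view:
        sent = sock.send(view[:_WINDOWS_SEND_CAP])
        view = view[sent:]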

import asyncio
import errno
import socket
from asyncio import Future
from typing import Union, Optional

# These errnos indicate that a non-blocking operation must be retried
# at a later time.  On most platforms they're the same value, but on
# some they differ.
_ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN)

# These errnos indicate that a connection has been abruptly terminated.
# They should be caught and handled less noisily than other errors.
_ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE, errno.ETIMEDOUT)

# More non-portable errnos:
_ERRNO_INPROGRESS = (errno.EINPROGRESS,)


class StreamClosedError(IOError):
    def __init__(self, ) -> None:
        super(StreamClosedError, self).__init__("Stream is closed")


class CustomIOStream(object):
    def __init__(self):
        self.loop = asyncio.get_event_loop()
        self.sock = None
        self.connect_future = None
        self.read_future = None
        self.write_future = None
        self._closed = False

    def connect(self, address):
        socket_obj = socket.socket(socket.AF_INET)
        socket_obj.setblocking(False)
        self.sock = socket_obj
        future = Future()
        try:
            socket_obj.connect(address)
        except socket.error as e:
            if e.errno not in _ERRNO_INPROGRESS and e.errno not in _ERRNO_WOULDBLOCK:
                self.close()
                future.set_exception(StreamClosedError())
                return future
        self.connect_future = future
        self.loop.add_writer(self.sock.fileno(), self._handle_connect)
        return future

    def read_from_fd(self, buf: Union[bytearray, memoryview]) -> Optional[int]:
        try:
            return self.sock.recv_into(buf, len(buf))
        except socket.error as e:
            if e.args[0] in _ERRNO_WOULDBLOCK:
                return None
            else:
                raise
        finally:
            del buf

    def write_to_fd(self, data: memoryview) -> int:
        try:
            return self.sock.send(data)  # type: ignore
        finally:
            del data

    def _handle_connect(self):
        try:
            err = self.sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
        except Exception as e:
            if e.args[0] in _ERRNO_WOULDBLOCK:
                err = 0
            else:
                self.close()
                return
        self.loop.remove_writer(self.sock.fileno())
        if err != 0:  # a non-zero value means the connection attempt failed
            self.connect_future.set_exception(StreamClosedError())
        else:
            self.connect_future.set_result(self)

    def _handle_read(self):
        buf = bytearray(1024)
        bytes_read = self.read_from_fd(buf)
        if bytes_read == 0:
            self.close()
        elif bytes_read is not None and self.read_future is not None:
            b = (memoryview(buf)[0: bytes_read]).tobytes()
            future = self.read_future
            self.read_future = None
            future.set_result(b)

    def _handle_write(self, data=None):
        try:
            if data is None:
                return
            num_bytes = self.write_to_fd(data)
        except (socket.error, IOError, OSError) as e:
            if e.args[0] in _ERRNO_WOULDBLOCK:
                pass
            else:
                self.close()
                return
        future = self.write_future
        self.write_future = None
        if not future.cancelled():
            self.loop.remove_writer(self.sock.fileno())
            future.set_result(None)

    def write(self, data):
        self._check_closed()
        future = Future()  # type: Future[None]
        self.write_future = future
        self.loop.add_writer(self.sock.fileno(), self._handle_write)
        self._handle_write(data)
        return future

    def read(self):
        self._check_closed()
        future = Future()  # type: Future[None]
        future.add_done_callback(lambda f: f.exception())
        self.read_future = future
        self.loop.add_reader(self.sock.fileno(), self._handle_read)
        self._handle_read()
        return future

    def _check_closed(self) -> None:
        if self._closed:
            raise StreamClosedError()

    def close(self):
        self.loop.remove_writer(self.sock.fileno())
        self.loop.remove_reader(self.sock.fileno())
        self.sock.close()
        self.sock = None  # type: ignore
        for future in [self.read_future, self.write_future]:
            if future is not None and not future.done():
                future.set_exception(StreamClosedError())
        self._closed = True


async def test_stream():
    stream = CustomIOStream()
    await stream.connect(("192.168.1.252", 9000))
    await stream.write(b"123\n")
    response = await stream.read()
    print(response)


if __name__ == "__main__":
    loop = asyncio.get_event_loop()
    loop.run_until_complete(test_stream())

HTTP servers

A simple synchronous, blocking, multi-process HTTP server

import socket
from multiprocessing import Process


def handle_client(client_socket):
    """Handle one client request."""
    request_data = client_socket.recv(1024)
    print("request data:", request_data.decode("utf-8"))
    # build the response
    response_start_line = "HTTP/1.1 200 OK\r\n"
    response_headers = "Server: My server\r\n"
    response_body = "<h1>Python HTTP Test</h1>"
    response = response_start_line + response_headers + "\r\n" + response_body

    # send the response back to the client
    client_socket.send(bytes(response, "utf-8"))

    # close the client connection
    client_socket.close()


if __name__ == "__main__":
    server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server_socket.bind(("", 8888))
    server_socket.listen(128)

    while True:
        client_socket, client_address = server_socket.accept()
        print("[%s]用户连接上了" % str(client_address))
        handle_client_process = Process(target=handle_client, args=(client_socket,))
        handle_client_process.start()
        client_socket.close()

An asynchronous, non-blocking HTTP server

import socket
from typing import Any, Tuple
import asyncio
from tornado.iostream import IOStream


class Application(object):
    def __init__(self):
        self.sock = None
        self.loop = asyncio.get_event_loop()

    async def handle_stream(self, stream: IOStream, address: Tuple) -> None:
        header_data = await stream.read_until_regex(b"\r?\n\r?\n", max_bytes=1024)
        print(header_data.decode(encoding="utf-8"))
        await stream.write(b"HTTP/1.1 200 OK\r\ncontent-type: text/html; charset=UTF-8\r\n\r\n<h1>Hello, World</h1>\r\n")
        stream.close()

    def _handle_connection(self) -> None:
        try:
            connection, address = self.sock.accept()
        except socket.error as e:
            raise
        stream = IOStream(connection)
        self.stream = stream
        fut = asyncio.ensure_future(self.handle_stream(stream, address))
        fut.add_done_callback(lambda f: f.result())

    def listen(self, port: int):
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setblocking(False)
        sock.bind(("127.0.0.1", port))
        sock.listen(125)
        self.loop.add_reader(sock, self._handle_connection)
        self.sock = sock


if __name__ == '__main__':
    app = Application()
    app.listen(8888)
    loop = asyncio.get_event_loop()
    loop.run_forever()

A more complete HTTP server

httputil.py

import copy
import re
import time
from asyncio import Future
from typing import Any, Tuple, Dict, List, Union, Callable, Optional, Awaitable, Iterable
from datetime import datetime
from tornado.httpclient import HTTPError
from tornado.escape import parse_qs_bytes

_CRLF_RE = re.compile(r"\r?\n")


class CustomHTTPHeaders(object):
    def __init__(self, *args: Any):
        self._dict = {}  # type: Dict[str, str]
        self._as_list = {}  # type: Dict[str, List[str]]
        if len(args) == 1:
            self.update(*args)

    @classmethod
    def parse(cls, headers: str) -> "CustomHTTPHeaders":
        h = cls()
        for line in _CRLF_RE.split(headers):
            if line:
                h.parse_line(line)
        return h

    def parse_line(self, line: str) -> None:
        try:
            name, value = line.split(":", 1)
        except ValueError:
            raise
        self.add(name, value.strip())

    def add(self, name: str, value: str) -> None:
        """Adds a new value for the given key."""
        if name in self:
            self._dict[name] = (self[name] + "," + value)
            self._as_list[name].append(value)
        else:
            self[name] = value

    def get_all(self) -> Iterable[Tuple[str, str]]:
        """Returns an iterable of all (name, value) pairs.

        If a header has multiple values, multiple pairs will be
        returned with the same name.
        """
        for name, values in self._as_list.items():
            for value in values:
                yield (name, value)

    def update(self, *args):
        for key in args[0].keys():
            self[key] = args[0][key]

    def get(self, name: str) -> str:
        return self._dict[name]

    def __setitem__(self, name: str, value: str) -> None:
        self._dict[name] = value
        self._as_list[name] = [value]

    def __getitem__(self, name: str) -> str:
        return self._dict[name]

    def __contains__(self, key):
        try:
            self[key]
        except KeyError:
            return False
        else:
            return True


class HTTPServerRequest(object):
    path = None  # type: str
    query = None  # type: str

    # HACK: Used for stream_request_body
    _body_future = None  # type: Future[None]

    def __init__(
            self,
            header_data: bytes = None,
            address: Tuple = None,
            protocol: str = None,
    ) -> None:
        _header_data = header_data.decode("latin1").lstrip("\r\n")
        eol = _header_data.find("\n")
        start_line = _header_data[:eol].rstrip("\r")
        self.method, self.uri, self.version = start_line.split(" ")
        self.headers = CustomHTTPHeaders.parse(_header_data[eol:])
        self.body = b""

        # set remote IP and protocol
        self.remote_ip = address[0]
        self.host = self.headers.get("Host") or "127.0.0.1"
        self._start_time = time.time()
        self._finish_time = None
        self.protocol = protocol

        if self.uri is not None:
            self.path, sep, self.query = self.uri.partition("?")
        self.arguments = parse_qs_bytes(self.query, keep_blank_values=True)
        self.query_arguments = copy.deepcopy(self.arguments)
        self.body_arguments = {}  # type: Dict[str, List[bytes]]

    def full_url(self) -> str:
        """Reconstructs the full URL for this request."""
        return self.protocol + "://" + self.host + self.uri

    def request_time(self) -> float:
        """Returns the amount of time it took for this request to execute."""
        if self._finish_time is None:
            return time.time() - self._start_time
        else:
            return self._finish_time - self._start_time

    def _parse_body(self) -> None:
        for k, v in self.body_arguments.items():
            self.arguments.setdefault(k, []).extend(v)

    def __repr__(self) -> str:
        attrs = ("protocol", "host", "method", "uri", "version", "remote_ip")
        args = ", ".join(["%s=%r" % (n, getattr(self, n)) for n in attrs])
        return "%s(%s)" % (self.__class__.__name__, args)


class CustomRequestHandler(object):
    def __init__(self, request):
        self.request = request
        self._headers = CustomHTTPHeaders(
            {
                "Server": "TornadoServer/%s" % "6.1",
                "Content-Type": "text/html; charset=UTF-8",
                "Date": str(datetime.now()),
            }
        )
        self.set_default_headers()
        self._write_buffer = []  # type: List[bytes]
        self._status_code = 200
        # self._reason = httputil.responses[200]

    def set_default_headers(self):
        self.set_header('Access-Control-Allow-Origin', '*')
        self.set_header('Access-Control-Allow-Headers', '*')
        self.set_header('Access-Control-Max-Age', str(1000))
        self.set_header('Content-type', 'application/json')
        self.set_header('Access-Control-Allow-Methods', 'POST, GET, DELETE, PUT, PATCH, OPTIONS')
        self.set_header(
            'Access-Control-Allow-Headers',
            'Content-Type, tsessionid, Access-Control-Allow-Origin, Access-Control-Allow-Headers, X-Requested-By, Access-Control-Allow-Methods')

    def set_header(self, name: str, value: str) -> None:
        self._headers[name] = value

    def finish(self, chunk: Union[str, bytes, dict]) -> bytes:
        chunk = chunk.encode("utf-8")
        self._write_buffer.append(chunk)
        if "Content-Length" not in self._headers:
            content_length = sum(len(part) for part in self._write_buffer)
            self.set_header("Content-Length", str(content_length))
        lines = ["HTTP/1.1 200 OK".encode(encoding="utf-8")]
        header_lines = (
            n + ": " + v for n, v in self._headers.get_all()
        )
        lines.extend(l.encode("latin1") for l in header_lines)
        for line in lines:
            if b"\n" in line:
                raise ValueError("Newline in header: " + repr(line))
        return b"\r\n".join(lines) + b"\r\n\r\n" + chunk

    def write(self, chunk: Union[str, bytes, dict]) -> None:
        chunk = chunk.encode("utf-8")
        self._write_buffer.append(chunk)

    def _unimplemented_method(self, *args: str, **kwargs: str) -> None:
        raise HTTPError(405)

    get = _unimplemented_method  # type: Callable[..., Optional[Awaitable[None]]]
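As a quick sanity check of the parsing path, a throwaway snippet can be appended to httputil.py; the request values below are arbitrary:

if __name__ == '__main__':
    raw = (b"GET /?city=beijing HTTP/1.1\r\n"
           b"Host: 127.0.0.1:8888\r\n"
           b"Accept: */*\r\n\r\n")
    req = HTTPServerRequest(header_data=raw, address=("127.0.0.1", 54321), protocol="http")
    print(req)                      # method, uri, version, remote_ip ...
    print(req.path, req.arguments)  # '/' {'city': [b'beijing']}
    print(req.headers.get("Host"))  # '127.0.0.1:8888'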

 

application.py

import socket
import asyncio
import json
from httputil import CustomHTTPHeaders, HTTPServerRequest, CustomRequestHandler
from typing import Any, Tuple, Dict, List
from tornado.iostream import IOStream, SSLIOStream


class Application(object):
    def __init__(self, handlers):
        self.sock = None
        self.loop = asyncio.get_event_loop()
        self.handlers = {k: v for k, v in handlers}

    async def handle_stream(self, stream: IOStream, address: Tuple) -> None:
        header_data = await stream.read_until_regex(b"\r?\n\r?\n", max_bytes=1024)
        if isinstance(stream, SSLIOStream):
            protocol = "https"
        else:
            protocol = "http"
        request = HTTPServerRequest(header_data=header_data, protocol=protocol, address=address)
        handler_class = self.handlers.get(request.path)
        # result = method(*self.path_args, **self.path_kwargs)
        if handler_class is not None:
            handler = handler_class(request=request)
            method = getattr(handler, request.method.lower())
            response = await method()
            await stream.write(response)
        stream.close()

    def _handle_connection(self) -> None:
        try:
            connection, address = self.sock.accept()
        except socket.error as e:
            raise
        stream = IOStream(connection)
        self.stream = stream
        fut = asyncio.ensure_future(self.handle_stream(stream, address))
        fut.add_done_callback(lambda f: f.result())

    def listen(self, port: int):
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setblocking(False)
        sock.bind(("127.0.0.1", port))
        sock.listen(125)
        self.loop.add_reader(sock, self._handle_connection)
        self.sock = sock


class IndexHandler(CustomRequestHandler):

    async def get(self, *args, **kwargs):
        # city = self.get_query_argument("city", None)
        ret = {
            "sites": [
                {"name": "菜鸟教程", "url": "www.runoob.com"},
                {"name": "google", "url": "www.google.com"},
                {"name": "微博", "url": "www.weibo.com"}
            ]
        }
        self.set_header("Content-Type", "application/json; charset=UTF-8")
        return self.finish(json.dumps(ret))


if __name__ == '__main__':
    app = Application([(r'/', IndexHandler)])
    app.listen(8888)
    loop = asyncio.get_event_loop()
    loop.run_forever()
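With the server running, the route can be exercised from another process using only the standard library; host and port match the listen() call above:

import json
import urllib.request

with urllib.request.urlopen("http://127.0.0.1:8888/") as resp:
    print(resp.status, resp.getheader("Content-Type"))
    print(json.loads(resp.read().decode("utf-8")))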
