分析uvicorn和hypercorn的热更新原理

前言

在用FastAPI框架开发项目时,我们习惯性地用uvicorn启动服务,但uvicorn启动服务时,如果代码有改动,需要手动重启服务。

在生产环境中,我们一般希望服务在启动后,一直保持运行状态,如果代码有改动,自动重启服务。

为了不至于每次代码一有改动就手动重启服务我们会在使用uvicorn命令开启服务时加上--reload参数,这样当代码有改动时,服务会自动重启。

这时的命令如下:

uvicorn main:app --host 0.0.0.0 --port 8000 --reload

我们其实可以探究一下这个--reload是在何时调用的?又是怎样实现热更新的?

通过追踪uvicorn源码执行,我们看到uvicorn走到main.py后会进入run()这个函数,在函数内部与热更新密切相关的代码如下:

from uvicorn.supervisors import ChangeReload
sock = config.bind_socket()
ChangeReload(config, target=server.run, sockets=[sock]).run()#run走的是BaseReload的run,子类没有run()

正是这个supervisors模块,提供了热更新的功能。

热更新核心——supervisors模块

supervisors模块中与热更新相关的supervisor包括:

  • BaseReload
  • WatchFilesReload
  • WatchGodReload
  • StatReload

先指出这个类的出处,便于大家查找:

# uvicorn/supervisors/__init__.py
if TYPE_CHECKING:
    ChangeReload: Type[BaseReload]
else:
    try:  # 有watchfiles模块
        from uvicorn.supervisors.watchfilesreload import (
            WatchFilesReload as ChangeReload,
        )
    except ImportError:
        try:  # 没有watchfiles模块,但有watchgod模块
            from uvicorn.supervisors.watchgodreload import (
                WatchGodReload as ChangeReload,
            )
        except ImportError:  # 两个模块都没有,这时只检查文件最近修改时间mtime
            from uvicorn.supervisors.statreload import StatReload as ChangeReload

这四个类中BaseReload是基类,其他三个类都继承自BaseReload,分别为should_restart()这个方法提供了具体实现。

那么我们先来看一下基类BaseReload做了哪些事情吧!

骨架

# uvicorn/supervisors/basereload.py
HANDLED_SIGNALS = (
    signal.SIGINT,  # Unix signal 2. Sent by Ctrl+C.
    signal.SIGTERM,  # Unix signal 15. Sent by `kill <pid>`.
)
class BaseReload:
    def __init__(
        self,
        config: Config,
        target: Callable[[Optional[List[socket]]], None],
        sockets: List[socket],
    ) -> None:
        self.config = config
        self.target = target
        self.sockets = sockets
        self.should_exit = threading.Event()
        self.pid = os.getpid()
        self.is_restarting = False
        self.reloader_name: Optional[str] = None

    # =============== 主要逻辑 ===============
    def run(self) -> None:
        self.startup()  # 1.注册退出信号+启动进程
        for changes in self:  # 内容有更新就重启
            if changes:
                self.restart()  # 2.退出旧进程开启新进程

        self.shutdown()  # 3.关闭进程+逐个关闭套接字,前台任务做完后主线程才挂join

    def startup(self) -> None:
        for sig in HANDLED_SIGNALS:  # 注册要用到的退出信号,用信号机制实现优雅退出
            signal.signal(sig, self.signal_handler)

        self.process = get_subprocess(  # 创建进程
            config=self.config, target=self.target, sockets=self.sockets
        )
        self.process.start()  # 开启进程

    def restart(self) -> None:
        if sys.platform == "win32": 
            self.is_restarting = True
            assert self.process.pid is not None  # 没有启动着不能重启
            # 用户按下了Ctrl+C时将信号发到当前进程,结束这个进程
            os.kill(self.process.pid, signal.CTRL_C_EVENT)  # 先杀掉旧进程
        else:
            self.process.terminate()  # Linux直接用这句结束旧进程
        self.process.join()  # 还要开启新进程,所以不要退出!

        self.process = get_subprocess(  # 再开启新进程
            config=self.config, target=self.target, sockets=self.sockets
        )
        self.process.start()

    def shutdown(self) -> None:
        if sys.platform == "win32":  # 判断平台
            self.should_exit.set()  
        else:
            self.process.terminate() 
        self.process.join()

        for sock in self.sockets:  # 挨个关闭套接字
            sock.close()

        message = "Stopping reloader process [{}]".format(str(self.pid))
        color_message = "Stopping reloader process [{}]".format(
            click.style(str(self.pid), fg="cyan", bold=True)
        )
        logger.info(message, extra={"color_message": color_message})

    # =============== 迭代器 ===============
    def __iter__(self) -> Iterator[Optional[List[Path]]]:
        return self

    def __next__(self) -> Optional[List[Path]]:  # 使用next()调用
        return self.should_restart()
    
    def pause(self) -> None:
        if self.should_exit.wait(self.config.reload_delay):
            raise StopIteration()

    # =============== 需要子类重写的方法 ===============
    def should_restart(self) -> Optional[List[Path]]:  # 由具体重载策略实现
        raise NotImplementedError("Reload strategies should override should_restart()")

可以看出基类只是说了热更新是怎么一个过程,但执行热更新的时机需要子类去进一步交代。

单就uvicorn而言,从历史的角度看,其实经历了三个版本:

1.第一个版本只是简单地检查文件是否修改过,如果修改过文件的元数据mtime就会改变,只要比较开启热更新功能时指定路径下诸文件的mtime,与当前的mtime,如果不同,就说明有文件被修改过,此时就会执行热更新;这个版本存在的问题是如果文件确实修改过了,但一通修改后,文件内容并没有发生改变,此时执行热更新是没有意义的,而且会使得服务不必要地频繁重启,生产环境中要避免发生这种情况。

2.第二个版本(名为watchgod)引入了文件内容比较,如果文件内容确实发生了改变,才会执行热更新。

3.与上个版本类似地,仍然比较文件内容,不过由于解释型语言执行效率相比编译型语言要低一些,同时也得益于编译型语言Rust日益流行,作者尝试对第二个版本用Rust语言进行了重写。第二个版本直到2022年3月23日停止维护,以后建议使用Rust重写的版本,同时名字也改成了watchfiles。

下面我们沿着这条脉络来品一品他是怎么做的。

热更新策略

  • 检查文件修改时间是否改变
class StatReload(BaseReload):
    def __init__(
        self,
        config: Config,
        target: Callable[[Optional[List[socket]]], None],
        sockets: List[socket],
    ) -> None:
        super().__init__(config, target, sockets)
        self.reloader_name = "StatReload"
        self.mtimes: Dict[Path, float] = {}

    # 只检查文件修改时间,最近修改时间比打开监控程序时修改时间晚表示文件有变化,不检查文件内容是否改变
    def should_restart(self) -> Optional[List[Path]]:
        self.pause()

        for file in self.iter_py_files():
            try:
                mtime = file.stat().st_mtime
            except OSError:
                continue

            old_time = self.mtimes.get(file)
            if old_time is None:
                self.mtimes[file] = mtime
                continue
            elif mtime > old_time:
                return [file]
        return None

    def restart(self) -> None:
        self.mtimes = {}
        return super().restart()

  • 检查文件内容是否改变
# 用到模块watchfiles
# https://github.com/samuelcolvin/watchfiles.git
# pip install watchfiles
from watchfiles import watch
class FileFilter:  # 筛选监控哪些文件的改变
    def __init__(self, config: Config):  # 设置哪些文件监控哪些不监控
        default_includes = ["*.py"]
        self.includes = [  # 要监控的文件类型
            default
            for default in default_includes
            if default not in config.reload_excludes
        ]
        self.includes.extend(config.reload_includes)  # 加入写在配置中的文件类型
        self.includes = list(set(self.includes))  # 类型去重

        default_excludes = [".*", ".py[cod]", ".sw.*", "~*"]  # 排除监控的类型
        self.excludes = [
            default
            for default in default_excludes
            if default not in config.reload_includes
        ]
        self.exclude_dirs = []
        for e in config.reload_excludes:
            p = Path(e)
            try:
                is_dir = p.is_dir()  # 是否目录
            except OSError:  
                is_dir = False

            if is_dir:
                self.exclude_dirs.append(p)
            else:
                self.excludes.append(e)
        self.excludes = list(set(self.excludes))  # 不监控的类型

    def __call__(self, path: Path) -> bool:
        for include_pattern in self.includes:
            if path.match(include_pattern):
                for exclude_dir in self.exclude_dirs:
                    if exclude_dir in path.parents:
                        return False

                for exclude_pattern in self.excludes:
                    if path.match(exclude_pattern):
                        return False

                return True
        return False


class WatchFilesReload(BaseReload):
    def __init__(
        self,
        config: Config,
        target: Callable[[Optional[List[socket]]], None],
        sockets: List[socket],
    ) -> None:
        super().__init__(config, target, sockets)
        self.reloader_name = "WatchFiles"
        self.reload_dirs = []  # 上面设置了对哪些类型文件监控,这里设置监控文件在哪些目录
        for directory in config.reload_dirs:  # 检查配置文件
            if Path.cwd() not in directory.parents:
                self.reload_dirs.append(directory)  # 启动监控命令的那个目录加进来,监控范围:当前目录及其子目录
        if Path.cwd() not in self.reload_dirs:  # 检查非配置文件
            self.reload_dirs.append(Path.cwd())

        self.watch_filter = FileFilter(config)  # 过滤监控文件类型
        self.watcher = watch(  # 核心逻辑由Rust语言实现,检查哪些路径下内容有变化,返回生成器,底层会开单独的线程作为后台线程
            *self.reload_dirs,
            watch_filter=None,
            stop_event=self.should_exit,
            yield_on_timeout=True,
        )

    def should_restart(self) -> Optional[List[Path]]:
        self.pause()

        changes = next(self.watcher)  # 把所有变化的路径去重后加到列表中并返回
        if changes:
            unique_paths = {Path(c[1]) for c in changes}
            return [p for p in unique_paths if self.watch_filter(p)]
        return None

至于watch函数的实现,这里不展开讲,仅附上实现

pub fn watch(
        slf: &PyCell<Self>,
        py: Python,
        debounce_ms: u64,
        step_ms: u64,
        timeout_ms: u64,
        stop_event: PyObject,
    ) -> PyResult<PyObject> {
        // 核心代码
        let mut max_debounce_time: Option<SystemTime> = None;
        let step_time = Duration::from_millis(step_ms);
        let mut last_size: usize = 0;  // 注意这个变量!
        let max_timeout_time: Option<SystemTime> = match timeout_ms {
            0 => None,
            _ => Some(SystemTime::now() + Duration::from_millis(timeout_ms)),
        };
        loop {
            let size = slf.borrow().changes.lock().unwrap().len();  // 可以看出其实就是随时检测文件大小,一旦增减字节文件大小就发生变化
            if size > 0 {
                if size == last_size {
                    break;
                }
                last_size = size;

                let now = SystemTime::now();
                if let Some(max_time) = max_debounce_time {
                    if now > max_time {
                        break;
                    }
                } else {
                    max_debounce_time = Some(now + Duration::from_millis(debounce_ms));
                }
            } else if let Some(max_time) = max_timeout_time {
                if SystemTime::now() > max_time {
                    slf.borrow().clear();
                    return Ok("timeout".to_object(py));
                }
            }
        }
        let py_changes = slf.borrow().changes.lock().unwrap().to_object(py);
        slf.borrow().clear();
        Ok(py_changes)
    }

当然,以上只是核心代码,具体实现有兴趣的朋友可以自己去看,这里就不展开说了。

该代码仓库在Github上的地址:https://github.com/samuelcolvin/watchfiles.git

最后,作为这篇短文的结尾,我们来聊聊另一款著名的web服务器:hypercorn是怎么实现热更新的。

大家用我这里的方法可以自己去追踪函数入口,这里只贴出主要代码。

def wait_for_changes(shutdown_event: EventType) -> None:  # 只检查文件最近修改时间
    last_updates: Dict[Path, float] = {}  # hypercorn实现热更新最重要的就是这个叫last_updates的字典
    for module in list(sys.modules.values()):
        filename = getattr(module, "__file__", None)
        if filename is None:
            continue
        path = Path(filename)
        try:
            last_updates[Path(filename)] = path.stat().st_mtime
        except (FileNotFoundError, NotADirectoryError):
            pass

    while not shutdown_event.is_set():
        time.sleep(1)

        for index, (path, last_mtime) in enumerate(last_updates.items()):
            if index % 10 == 0:
                # Yield to the event loop
                time.sleep(0)

            try:
                mtime = path.stat().st_mtime
            except FileNotFoundError:
                return
            else:
                if mtime > last_mtime:
                    return
                else:
                    last_updates[path] = mtime

由此可见,hypercorn的实现方式是:当文件有修改时,就重新启动服务。也许以后hypercorn也会实现类似watchfiles的方案吧!

posted @ 2023-11-21 15:33  你好aloha  阅读(928)  评论(0)    收藏  举报