[golang]更新ssl证书的命令行工具

前言

现服务器上有一个别人维护的脚本会定时更新ssl证书,之前还写了一个工具用来检测nginx证书的有效期,但这个工具只会发通知,不会替换证书,搞过几次后就嫌烦了,干脆再写一个工具。

处理逻辑:

  1. 获取别人脚本帮我获取到的证书文件有效时间et1
  2. 获取nginx证书文件的有效时间et2
  3. 如果et1晚于et2,说明证书文件更新了,则用新证书文件覆盖旧证书文件
  4. 热加载nginx配置

其它要求

  • 因为要防止别人脚本生成的证书文件有格式问题,所以要先检测能否正常读取证书文件
  • 用配置文件来动态更新配置

配置文件格式

配置文件内容如下

global:
  expiredays: 3
  logdir: logs

tasks:
  - name: auth
    skip: false
    old_cert: /home/admin/apps/openresty/nginx/cert/auth/auth.wp2.com.crt
    old_key: /home/admin/apps/openresty/nginx/cert/auth/auth.wp2.com.key
    new_cert: /home/admin/tools/update-ssl/auth.wp2.com/auth.wp2.com.crt
    new_key: /home/admin/tools/update-ssl/auth.wp2.com/auth.wp2.com.key
  - name: strategy
    skip: false
    old_cert: /home/admin/apps/openresty/nginx/cert/strategy/strategy.wp2.com.crt
    old_key: /home/admin/apps/openresty/nginx/cert/strategy/strategy.wp2.com.key
    new_cert: /home/admin/tools/update-ssl/strategy.wp2.com/strategy.wp2.com.crt
    new_key: /home/admin/tools/update-ssl/strategy.wp2.com/strategy.wp2.com.key

代码

文件结构

run.sh
app.yaml
tasks.go
main.go
config/
  config.go
log/
  log.go

config

config模块用来读取配置文件

package config

import (
	"log"
	"os"

	"gopkg.in/yaml.v3"
)

type Configuration struct {
	Tasks  []Tasks `yaml:"tasks"`
	Global Global  `yaml:"global"`
}

type Tasks struct {
	OldKey  string `yaml:"old_key"`
	NewCert string `yaml:"new_cert"`
	NewKey  string `yaml:"new_key"`
	Name    string `yaml:"name"`
	Skip    bool   `yaml:"skip"`
	OldCert string `yaml:"old_cert"`
}

type Global struct {
	Expiredays int    `yaml:"expiredays"`
	Logdir     string `yaml:"logdir"`
}

func New() *Configuration {
	content, err := os.ReadFile("app.yaml")
	if err != nil {
		log.Fatalf("Read config file error: %v", err)
	}

	var cfg Configuration
	err = yaml.Unmarshal(content, &cfg)
	if err != nil {
		log.Fatalf("Unmarshal config file error: %v", err)
	}

	return &cfg
}

log

log模块基于logrus简单封装下,增加了日志文件rotate

package log

import (
	"os"
	"path/filepath"
	"update-cert/config"

	"github.com/orandin/lumberjackrus"
	"github.com/sirupsen/logrus"
)

var cfg = config.New()

var logger *logrus.Logger
var logDirPath string = cfg.Global.Logdir
var logFilePath string

func init() {
	// logs目录是否存在, 不存在则创建
	if _, err := os.Stat(logDirPath); os.IsNotExist(err) {
		os.MkdirAll(logDirPath, os.ModePerm)
	}

	logFilePath = filepath.Join(logDirPath, "app.log")

	logger = logrus.New()
	logger.SetOutput(os.Stdout)                // 在标准输出中显示日志内容
	logger.SetLevel(logrus.InfoLevel)          // 标准输出中的日志级别
	logger.SetFormatter(&logrus.JSONFormatter{ // 标准输出以json形式显示日志
		TimestampFormat: "2006-01-02 15:04:05.000",
	})

	logger.AddHook(rotateHook()) // 通过添加hook的方式实现日志输出到文件中并自动切割

}

// 日志文件自动切割
func rotateHook() logrus.Hook {
	hook, err := lumberjackrus.NewHook(&lumberjackrus.LogFile{
		Filename:   logFilePath,
		MaxAge:     1,
		MaxSize:    100,
		MaxBackups: 30,
		Compress:   true,
		LocalTime:  false,
	},
		logrus.InfoLevel,
		&logrus.JSONFormatter{TimestampFormat: "2006-01-02 15:04:05.000"},
		&lumberjackrus.LogFileOpts{})
	if err != nil {
		logrus.Fatal(err)
	}

	return hook
}

func Info(v ...interface{}) {
	logger.Info(v...)
}

func Infof(format string, v ...interface{}) {
	logger.Infof(format, v...)
}

func Warn(v ...interface{}) {
	logger.Warn(v...)
}

func Warnf(format string, v ...interface{}) {
	logger.Warnf(format, v...)
}

func Debug(v ...interface{}) {
	logger.Debug(v...)
}

func Debugf(format string, v ...interface{}) {
	logger.Debugf(format, v...)
}

func Error(v ...interface{}) {
	logger.Error(v...)
}

func Errorf(format string, v ...interface{}) {
	logger.Errorf(format, v...)
}

func Panic(v ...interface{}) {
	logger.Panic(v...)
}

func Panicf(format string, v ...interface{}) {
	logger.Panicf(format, v...)
}

func Fatal(v ...interface{}) {
	logger.Fatal(v...)
}

func Fatalf(format string, v ...interface{}) {
	logger.Fatalf(format, v...)
}

tasks

tasks包含了主要的处理逻辑,用WaitGroup做了个并发处理

package main

import (
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"os"
	"os/exec"
	"sync"
	"time"
	"update-cert/config"
	"update-cert/log"
)

/*
解析SSL证书文件
返回证书过期时间
*/
func ParseCertificate(certFile string) (time.Time, error) {
	certData, err := os.ReadFile(certFile)
	if err != nil {
		log.Errorf("Read cert file error: %v", err)
		return time.Time{}, err
	}

	block, _ := pem.Decode(certData)
	if block == nil {
		log.Errorf("Decode cert file error: %v", err)
		return time.Time{}, err
	}

	cert, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		log.Errorf("Parse cert file error: %v", err)
		return time.Time{}, err
	}

	return cert.NotAfter, nil
}

// 使用新证书文件直接覆盖旧证书文件
func updateCert(oldFile, newFile string) error {
	newContent, err := os.ReadFile(newFile)
	if err != nil {
		return err
	}

	f, err := os.Create(oldFile)
	if err != nil {
		return err
	}
	defer f.Close()

	_, err = f.Write(newContent)
	if err != nil {
		return fmt.Errorf("write file error: %v", err)
	}
	return nil
}

func subProcess(wg *sync.WaitGroup, task config.Tasks) {
	defer wg.Done()
	oldCertNotAfter, err := ParseCertificate(task.OldCert)
	if err != nil {
		log.Errorf("Parse cert file error: %v", err)
		return
	}

	newCertNotAfter, err := ParseCertificate(task.NewCert)
	if err != nil {
		log.Errorf("Parse cert file error: %v", err)
		return
	}

	if newCertNotAfter.After(oldCertNotAfter) {
		log.Infof("Update cert: %s", task.Name)

		if err := updateCert(task.OldCert, task.NewCert); err != nil {
			log.Errorf("Update cert error: %v", err)
			return
		}
		if err := updateCert(task.OldKey, task.NewKey); err != nil {
			log.Errorf("Update key error: %v", err)
			return
		}
	}
}

func Process() {
	cfg := config.New()
	var wg sync.WaitGroup
	// wg.Add(len(cfg.Tasks))
	for _, task := range cfg.Tasks {
		if task.Skip {
			log.Infof("Skip task: %s", task.Name)
			continue
		}
		wg.Add(1)
		go subProcess(&wg, task)
	}
	wg.Wait()
	ReloadNginx()
}

// 热加载 openresty 配置
func ReloadNginx() {
	log.Info("Reload nginx")
	cmd := exec.Command("openresty", "-s", "reload")
	_, err := cmd.CombinedOutput()
	if err != nil {
		log.Errorf("Reload nginx error: %v", err)
	}
}

main

main部分就很简单了,直接调用tasks的函数。注释部分是用来测试读取证书文件的,实际代码中可去掉。

package main

func main() {
	// 指定证书文件路径
	// certFilePath := "auth.wp2.com"

	// // 读取证书文件
	// certData, err := os.ReadFile(certFilePath)
	// if err != nil {
	// 	log.Fatalf("无法读取证书文件: %v", err)
	// }

	// // 解析证书
	// block, _ := pem.Decode(certData)
	// if block == nil {
	// 	log.Fatalf("无法解析 PEM 格式的证书")
	// }

	// cert, err := x509.ParseCertificate(block.Bytes)
	// if err != nil {
	// 	log.Fatalf("无法解析证书: %v", err)
	// }

	// // 获取证书有效期
	// fmt.Printf("证书生效时间: %v\n", cert.NotBefore)
	// fmt.Printf("证书到期时间: %v\n", cert.NotAfter)
	Process()
}

run脚本

run.sh脚本用来执行这个工具

#!/bin/bash

set -u

script_dir=$(cd $(dirname $0) && pwd)

cd $script_dir

if [ ! -f "update-cert" ]; then
    echo "update-cert not found"
    go build -o update-cert
fi

./update-cert
posted @ 2025-02-04 14:29  花酒锄作田  阅读(24)  评论(0)    收藏  举报