🚀argparse 进阶实战指南:从脚本到专业命令行工具

一、argparse + 配置文件融合方案

1.1 为什么需要配置文件?

真实场景:模型训练任务

# 开发阶段:频繁调整参数
python train.py --model resnet50 --lr 0.001 --batch-size 64 \
                --epochs 100 --optimizer adam --dropout 0.5 \
                --weight-decay 0.0001 --data-augment True \
                --early-stop 10 --save-dir ./checkpoints \
                --log-interval 100 --val-interval 5

问题

  • 命令太长,容易出错

  • 参数难以复用

  • 版本管理困难

  • 团队协作不便

解决方案:配置文件 + CLI 覆盖

1.2 配置文件格式选择

格式 优点 缺点 推荐场景
YAML 可读性好,广泛使用 需要PyYAML库 机器学习、配置管理
TOML Python内置支持,简洁 嵌套结构稍复杂 项目配置、工具设置
JSON 标准,无需额外依赖 不能写注释,冗长 API配置、Web应用
.py文件 灵活,可编程 有安全风险 动态配置、复杂逻辑

推荐:YAML(最流行)或 TOML(Python原生)

1.3 完整实现方案

基础版:CLI优先原则

import argparse
import yaml
import os
from pathlib import Path

def parse_args():
    """解析命令行参数,支持配置文件"""
    parser = argparse.ArgumentParser(
        description='模型训练程序',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    
    # 配置文件参数(放在最前面)
    parser.add_argument(
        '-c', '--config',
        type=Path,
        help='配置文件路径 (YAML/TOML/JSON)',
        default=None
    )
    
    # 模型参数组
    model_group = parser.add_argument_group('模型参数')
    model_group.add_argument('--model', type=str, help='模型名称')
    model_group.add_argument('--pretrained', action='store_true', help='使用预训练权重')
    
    # 训练参数组
    train_group = parser.add_argument_group('训练参数')
    train_group.add_argument('--epochs', type=int, help='训练轮数')
    train_group.add_argument('--batch-size', type=int, help='批次大小')
    train_group.add_argument('--lr', type=float, help='学习率')
    
    # 系统参数组
    system_group = parser.add_argument_group('系统参数')
    system_group.add_argument('--device', type=str, help='运行设备')
    system_group.add_argument('--workers', type=int, help='数据加载线程数')
    
    # 解析命令行参数
    cli_args = parser.parse_args()
    
    # 最终配置 = 默认值 ← 配置文件 ← 命令行参数
    config = {}
    
    # 1. 加载配置文件(如果有)
    if cli_args.config:
        config = load_config(cli_args.config)
    
    # 2. 合并配置:命令行参数覆盖配置文件
    final_args = merge_configs(parser, config, cli_args)
    
    # 3. 验证必需参数
    validate_required_args(final_args)
    
    return final_args

def load_config(config_path):
    """加载配置文件,支持多种格式"""
    config_path = Path(config_path)
    
    if not config_path.exists():
        raise FileNotFoundError(f"配置文件不存在: {config_path}")
    
    suffix = config_path.suffix.lower()
    
    if suffix == '.yaml' or suffix == '.yml':
        import yaml
        with open(config_path, 'r', encoding='utf-8') as f:
            return yaml.safe_load(f)
    elif suffix == '.toml':
        import tomllib
        with open(config_path, 'rb') as f:
            return tomllib.load(f)
    elif suffix == '.json':
        import json
        with open(config_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    else:
        raise ValueError(f"不支持的配置文件格式: {suffix}")

def merge_configs(parser, config_dict, cli_args):
    """
    合并配置的优先级:
    1. 命令行参数(最高)
    2. 配置文件
    3. argparse默认值(最低)
    """
    # 获取所有参数的定义
    arg_defaults = {
        action.dest: action.default
        for action in parser._actions
        if action.dest != 'help'
    }
    
    # 从配置文件和命令行参数构建最终配置
    final_args = argparse.Namespace()
    
    for dest in arg_defaults:
        # 1. 尝试从命令行获取
        cli_value = getattr(cli_args, dest, None)
        
        # 2. 如果命令行没有,尝试从配置文件获取
        if cli_value is None and dest in config_dict:
            config_value = config_dict[dest]
            setattr(final_args, dest, config_value)
        # 3. 如果命令行有值(即使为None),使用命令行值
        else:
            setattr(final_args, dest, cli_value)
    
    return final_args

def validate_required_args(args):
    """验证必需参数"""
    required_fields = ['model', 'epochs', 'batch_size', 'lr']
    missing = []
    
    for field in required_fields:
        if getattr(args, field, None) is None:
            missing.append(field)
    
    if missing:
        raise ValueError(f"缺少必需参数: {', '.join(missing)}")

if __name__ == '__main__':
    try:
        args = parse_args()
        print("最终配置:")
        for key, value in vars(args).items():
            print(f"  {key}: {value}")
    except Exception as e:
        print(f"错误: {e}")

高级版:支持环境变量和默认配置文件

import os
from typing import Dict, Any, Optional

class ConfigManager:
    """高级配置管理器"""
    
    def __init__(self):
        self.config_hierarchy = []
        
    def load_configs(self, cli_args) -> Dict[str, Any]:
        """按优先级加载配置"""
        configs = []
        
        # 1. 默认配置(代码硬编码)
        configs.append(self.get_default_config())
        
        # 2. 环境配置(根据环境变量)
        env_config = self.load_env_config()
        if env_config:
            configs.append(env_config)
        
        # 3. 用户配置文件(~/.config/)
        user_config = self.load_user_config()
        if user_config:
            configs.append(user_config)
        
        # 4. 项目配置文件(./config/)
        project_config = self.load_project_config()
        if project_config:
            configs.append(project_config)
        
        # 5. 命令行指定的配置文件
        if cli_args.config:
            file_config = self.load_config_file(cli_args.config)
            configs.append(file_config)
        
        # 6. 命令行参数(最高优先级)
        cli_config = self.cli_args_to_dict(cli_args)
        configs.append(cli_config)
        
        # 合并所有配置
        final_config = {}
        for cfg in configs:
            final_config.update(cfg)
        
        return final_config
    
    def get_default_config(self) -> Dict[str, Any]:
        """返回默认配置"""
        return {
            'device': 'cuda' if torch.cuda.is_available() else 'cpu',
            'workers': 4,
            'seed': 42,
            'log_level': 'INFO'
        }
    
    def load_env_config(self) -> Dict[str, Any]:
        """从环境变量加载配置"""
        config = {}
        
        # 环境变量前缀
        prefix = 'MYAPP_'
        
        for key, value in os.environ.items():
            if key.startswith(prefix):
                config_key = key[len(prefix):].lower()
                # 尝试转换为合适的类型
                config[config_key] = self.convert_value(value)
        
        return config
    
    def convert_value(self, value: str) -> Any:
        """智能类型转换"""
        # 布尔值
        if value.lower() in ('true', 'yes', '1'):
            return True
        elif value.lower() in ('false', 'no', '0'):
            return False
        
        # 数字
        try:
            if '.' in value:
                return float(value)
            else:
                return int(value)
        except ValueError:
            pass
        
        # 列表(逗号分隔)
        if ',' in value:
            return [self.convert_value(v.strip()) for v in value.split(',')]
        
        # 字符串
        return value

# 使用示例
config_manager = ConfigManager()
config = config_manager.load_configs(args)

1.4 配置文件示例

config.yaml:

# 训练配置
model: resnet50
pretrained: true
epochs: 100
batch_size: 64
lr: 0.001
optimizer: adam
weight_decay: 0.0001

# 数据配置
data:
  root: ./data
  train_split: 0.8
  val_split: 0.2
  augmentation: true

# 系统配置
device: cuda
workers: 8
seed: 42
log_level: INFO

使用方式:

# 基础使用
python train.py --config config.yaml

# CLI覆盖配置
python train.py --config config.yaml --lr 0.01 --batch-size 32

# 完全使用CLI(无配置文件)
python train.py --model resnet50 --epochs 50 --lr 0.001

二、argparse + logging 专业日志系统

2.1 为什么需要专业的日志?

print()的局限性:

  • 无法控制输出级别

  • 不能输出到文件

  • 没有时间戳

  • 无法区分不同模块

  • 生产环境无法使用

专业的日志系统应具备:

  • 多级别日志(DEBUG/INFO/WARNING/ERROR)

  • 同时输出到控制台和文件

  • 日志轮转(防止日志文件过大)

  • 结构化日志(JSON格式)

  • 性能监控

2.2 完整实现方案

import argparse
import logging
import sys
from logging.handlers import RotatingFileHandler
import json
from datetime import datetime
from pathlib import Path

class StructuredFormatter(logging.Formatter):
    """结构化日志格式化器(JSON格式)"""
    
    def format(self, record):
        log_record = {
            'timestamp': datetime.utcnow().isoformat() + 'Z',
            'level': record.levelname,
            'logger': record.name,
            'module': record.module,
            'function': record.funcName,
            'line': record.lineno,
            'message': record.getMessage(),
        }
        
        # 添加额外字段
        if hasattr(record, 'extra'):
            log_record.update(record.extra)
        
        # 异常信息
        if record.exc_info:
            log_record['exception'] = self.formatException(record.exc_info)
        
        return json.dumps(log_record, ensure_ascii=False)

def setup_logging(args):
    """
    配置日志系统
    
    日志级别对应关系:
    --verbose 0 (默认): WARNING
    --verbose 1       : INFO
    --verbose 2       : DEBUG
    """
    # 确定日志级别
    level_mapping = {
        0: logging.WARNING,
        1: logging.INFO,
        2: logging.DEBUG
    }
    log_level = level_mapping.get(args.verbose, logging.DEBUG)
    
    # 配置根日志记录器
    root_logger = logging.getLogger()
    root_logger.setLevel(log_level)
    
    # 清除已有的处理器
    root_logger.handlers.clear()
    
    # 1. 控制台处理器
    console_handler = logging.StreamHandler(sys.stdout)
    if args.verbose >= 2:  # 调试模式使用详细格式
        console_format = '%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s'
    else:
        console_format = '%(asctime)s - %(levelname)s - %(message)s'
    
    console_formatter = logging.Formatter(console_format, datefmt='%H:%M:%S')
    console_handler.setFormatter(console_formatter)
    root_logger.addHandler(console_handler)
    
    # 2. 文件处理器(如果指定了日志文件)
    if args.log_file:
        log_path = Path(args.log_file)
        
        # 创建日志目录
        log_path.parent.mkdir(parents=True, exist_ok=True)
        
        # 使用轮转文件处理器(每个文件10MB,最多保留5个)
        file_handler = RotatingFileHandler(
            log_path,
            maxBytes=10 * 1024 * 1024,  # 10MB
            backupCount=5,
            encoding='utf-8'
        )
        
        # 文件日志使用JSON格式
        file_handler.setFormatter(StructuredFormatter())
        root_logger.addHandler(file_handler)
    
    # 3. 错误处理器(分离错误日志)
    if args.error_log:
        error_handler = logging.FileHandler(
            args.error_log,
            encoding='utf-8'
        )
        error_handler.setLevel(logging.ERROR)
        error_handler.setFormatter(StructuredFormatter())
        root_logger.addHandler(error_handler)
    
    # 配置第三方库的日志级别
    logging.getLogger('urllib3').setLevel(logging.WARNING)
    logging.getLogger('requests').setLevel(logging.WARNING)
    
    return root_logger

def add_logging_args(parser):
    """添加日志相关参数到argparser"""
    log_group = parser.add_argument_group('日志选项')
    
    log_group.add_argument(
        '-v', '--verbose',
        action='count',
        default=0,
        help='日志详细程度: -v=INFO, -vv=DEBUG'
    )
    
    log_group.add_argument(
        '--log-file',
        type=Path,
        help='日志文件路径'
    )
    
    log_group.add_argument(
        '--error-log',
        type=Path,
        help='错误日志文件路径(单独存储ERROR级别日志)'
    )
    
    log_group.add_argument(
        '--log-json',
        action='store_true',
        help='控制台也输出JSON格式日志'
    )
    
    log_group.add_argument(
        '--quiet',
        action='store_true',
        help='静默模式,不输出任何日志'
    )
    
    return parser

# 使用示例
def main():
    parser = argparse.ArgumentParser(description='带日志系统的程序')
    parser = add_logging_args(parser)
    
    # 添加其他参数...
    parser.add_argument('--input', required=True, help='输入文件')
    
    args = parser.parse_args()
    
    # 设置日志
    logger = setup_logging(args)
    
    # 使用日志
    logger.info("程序启动")
    logger.debug(f"参数: {args}")
    
    try:
        # 业务逻辑
        logger.info(f"处理文件: {args.input}")
        # ...
        logger.info("处理完成")
    except Exception as e:
        logger.error(f"处理失败: {e}", exc_info=True)
        raise

class MetricsCollector:
    """性能指标收集器"""
    
    def __init__(self, logger):
        self.logger = logger
        self.metrics = {}
        self.start_time = datetime.now()
    
    def start_timer(self, name):
        """开始计时"""
        self.metrics[name] = {
            'start': datetime.now(),
            'end': None,
            'duration': None
        }
    
    def stop_timer(self, name):
        """结束计时"""
        if name in self.metrics:
            self.metrics[name]['end'] = datetime.now()
            duration = (self.metrics[name]['end'] - 
                       self.metrics[name]['start']).total_seconds()
            self.metrics[name]['duration'] = duration
            
            # 记录指标
            self.logger.info(
                f"计时器 '{name}' 完成",
                extra={'metric': name, 'duration_seconds': duration}
            )
    
    def log_metrics(self):
        """记录所有指标"""
        total_time = (datetime.now() - self.start_time).total_seconds()
        
        self.logger.info(
            "性能指标汇总",
            extra={
                'total_duration_seconds': total_time,
                'metrics': self.metrics
            }
        )

2.3 高级日志功能

2.3.1 日志上下文管理器

import contextlib

@contextlib.contextmanager
def log_section(logger, message, level=logging.INFO):
    """日志分段上下文管理器"""
    logger.log(level, f"┌─ 开始: {message}")
    start_time = datetime.now()
    
    try:
        yield
    except Exception as e:
        logger.error(f"✗ 失败: {message}")
        raise
    finally:
        duration = (datetime.now() - start_time).total_seconds()
        logger.log(level, f"└─ 完成: {message} ({duration:.2f}s)")

# 使用
with log_section(logger, "数据加载"):
    data = load_data(args.input)

2.3.2 进度条日志

class ProgressLogger:
    """带进度条的日志"""
    
    def __init__(self, logger, total, desc="进度"):
        self.logger = logger
        self.total = total
        self.desc = desc
        self.current = 0
        self.start_time = datetime.now()
        
        logger.info(f"开始 {desc} (共 {total} 项)")
    
    def update(self, n=1):
        """更新进度"""
        self.current += n
        percent = (self.current / self.total) * 100
        
        if self.current % 10 == 0 or self.current == self.total:
            elapsed = (datetime.now() - self.start_time).total_seconds()
            eta = (elapsed / self.current) * (self.total - self.current)
            
            self.logger.info(
                f"{self.desc}: {self.current}/{self.total} "
                f"({percent:.1f}%) ETA: {eta:.1f}s"
            )
    
    def finish(self):
        """完成进度"""
        total_time = (datetime.now() - self.start_time).total_seconds()
        self.logger.info(
            f"完成 {self.desc}, 总耗时: {total_time:.2f}s"
        )

# 使用
progress = ProgressLogger(logger, len(data), "训练批次")
for batch in data:
    train_batch(batch)
    progress.update()
progress.finish()

三、argparse → click 迁移指南

3.1 click 是什么?为什么选择它?

click 的优势:

  1. 装饰器语法:更简洁,更Pythonic

  2. 自动补全:支持bash/zsh自动补全

  3. 更好的错误提示:颜色高亮,更友好的错误信息

  4. 参数验证:内置丰富的验证器

  5. 子命令支持:更优雅的子命令实现

  6. 进度条:内置进度条支持

  7. 颜色输出:终端彩色输出

何时迁移?

  • 保持 argparse:内部脚本、简单工具、依赖少

  • 迁移到 click:对外发布、复杂CLI、需要良好用户体验

3.2 argparse → click 对照表

argparse 功能 click 等价实现 说明
add_argument() @click.option() 可选参数
位置参数 @click.argument() 必需参数
add_subparsers() @click.group() 子命令
action='store_true' is_flag=True 布尔标志
type=int type=click.INT 类型验证
choices=['a','b'] type=click.Choice() 选择限制
nargs='+' nargs=-1 多个参数

3.3 迁移示例

argparse 版本

import argparse

def main():
    parser = argparse.ArgumentParser(description='文件处理工具')
    
    parser.add_argument('input', help='输入文件')
    parser.add_argument('output', nargs='?', help='输出文件')
    parser.add_argument('--verbose', '-v', action='count', default=0)
    parser.add_argument('--force', action='store_true', help='强制覆盖')
    
    subparsers = parser.add_subparsers(dest='command', required=True)
    
    # 压缩命令
    compress_parser = subparsers.add_parser('compress', help='压缩文件')
    compress_parser.add_argument('--level', type=int, choices=range(1, 10), default=6)
    
    # 解压命令
    extract_parser = subparsers.add_parser('extract', help='解压文件')
    
    args = parser.parse_args()
    
    if args.verbose >= 2:
        print(f"调试模式: {args}")
    
    # 处理逻辑...

if __name__ == '__main__':
    main()

click 迁移版本

import click
import sys

CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])

@click.group(context_settings=CONTEXT_SETTINGS)
@click.version_option('1.0.0', prog_name='文件处理工具')
@click.pass_context
def cli(ctx):
    """文件处理工具 - 压缩和解压文件"""
    # 初始化上下文
    ctx.ensure_object(dict)
    ctx.obj['verbose'] = 0

@cli.command()
@click.argument('input', type=click.Path(exists=True))
@click.argument('output', type=click.Path(), required=False)
@click.option('--verbose', '-v', count=True, help='详细输出')
@click.option('--force', is_flag=True, help='强制覆盖已存在的文件')
@click.option('--level', type=click.IntRange(1, 9), default=6,
              help='压缩级别 (1-9, 9为最高压缩)')
@click.pass_context
def compress(ctx, input, output, verbose, force, level):
    """压缩文件"""
    # 设置详细级别
    ctx.obj['verbose'] = verbose
    
    if verbose >= 2:
        click.echo(f"调试信息: 输入={input}, 输出={output}, 级别={level}")
    
    # 检查输出文件是否已存在
    if output and click.Path(exists=True).convert(output, None, None) and not force:
        if not click.confirm(f"文件 {output} 已存在,是否覆盖?"):
            click.echo("取消操作")
            return
    
    # 压缩逻辑...
    click.echo(f"正在压缩 {input} -> {output or 'stdout'}")
    
    # 进度条示例
    with click.progressbar(range(100), label='压缩进度') as bar:
        for i in bar:
            # 模拟压缩过程
            import time
            time.sleep(0.01)
    
    click.secho("✓ 压缩完成", fg='green')

@cli.command()
@click.argument('input', type=click.Path(exists=True))
@click.argument('output_dir', type=click.Path(file_okay=False), required=False)
@click.option('--verbose', '-v', count=True, help='详细输出')
@click.option('--list', '-l', is_flag=True, help='列出压缩包内容')
@click.pass_context
def extract(ctx, input, output_dir, verbose, list):
    """解压文件"""
    ctx.obj['verbose'] = verbose
    
    if list:
        # 列出文件内容
        click.echo(f"压缩包 {input} 内容:")
        # 模拟列出文件
        files = ['file1.txt', 'file2.jpg', 'subdir/']
        for f in files:
            click.echo(f"  {f}")
        return
    
    # 解压逻辑...
    click.echo(f"正在解压 {input} -> {output_dir or '当前目录'}")

# 自动补全支持
@cli.command()
@click.argument('shell', type=click.Choice(['bash', 'zsh', 'fish']))
def completion(shell):
    """生成自动补全脚本"""
    from click.shell_completion import completion_source
    click.echo(completion_source(cli, shell))

if __name__ == '__main__':
    cli()

渐进式迁移策略

# 第一步:混合模式(同时支持argparse和click)
import argparse
import click
import sys

def legacy_argparse_mode():
    """旧版argparse模式"""
    parser = argparse.ArgumentParser()
    parser.add_argument('--old-arg', help='旧参数')
    return parser.parse_args()

def new_click_mode():
    """新版click模式"""
    @click.command()
    @click.option('--new-arg', help='新参数')
    def cli(new_arg):
        click.echo(f"新模式: {new_arg}")
    
    cli()

if __name__ == '__main__':
    # 检查是否有--new-mode标志
    if '--new-mode' in sys.argv:
        sys.argv.remove('--new-mode')
        new_click_mode()
    else:
        # 旧模式
        args = legacy_argparse_mode()
        print(f"旧模式: {args.old_arg}")

3.4 click 高级特性

3.4.1 参数类型和验证

import click

# 自定义参数类型
class EmailParamType(click.ParamType):
    name = "email"
    
    def convert(self, value, param, ctx):
        if '@' not in value:
            self.fail(f"{value} 不是有效的邮箱地址", param, ctx)
        return value.lower()

EMAIL = EmailParamType()

@click.command()
@click.option('--email', type=EMAIL, required=True)
def register(email):
    click.echo(f"注册邮箱: {email}")

# 文件验证
@click.command()
@click.argument('config', 
                type=click.Path(exists=True, dir_okay=False, readable=True))
def load_config(config):
    click.echo(f"加载配置: {config}")

# 密码输入(隐藏输入)
@click.command()
@click.option('--password', prompt=True, hide_input=True, 
              confirmation_prompt=True)
def set_password(password):
    click.echo(f"密码已设置")

3.4.2 彩色输出和表格

import click
from tabulate import tabulate

@click.command()
def list_users():
    """列出用户(带颜色和表格)"""
    users = [
        {'id': 1, 'name': 'Alice', 'active': True, 'role': 'admin'},
        {'id': 2, 'name': 'Bob', 'active': False, 'role': 'user'},
        {'id': 3, 'name': 'Charlie', 'active': True, 'role': 'user'},
    ]
    
    # 带颜色的表格
    headers = ['ID', 'Name', 'Status', 'Role']
    rows = []
    
    for user in users:
        status = (click.style('✓ 活跃', fg='green') 
                  if user['active'] else 
                  click.style('✗ 禁用', fg='red'))
        
        rows.append([
            user['id'],
            click.style(user['name'], fg='blue'),
            status,
            click.style(user['role'], 
                       fg='yellow' if user['role'] == 'admin' else 'white')
        ])
    
    click.echo(tabulate(rows, headers=headers, tablefmt='grid'))

四、打包成专业命令行工具

4.1 项目结构

mytool/
├── mytool/                    # 主包
│   ├── __init__.py
│   ├── cli.py                # CLI入口
│   ├── commands/             # 子命令模块
│   │   ├── __init__.py
│   │   ├── train.py
│   │   ├── predict.py
│   │   └── config.py
│   ├── core.py               # 核心逻辑
│   └── utils/                # 工具函数
│       ├── __init__.py
│       ├── logging.py
│       └── config.py
├── tests/                    # 测试
│   ├── __init__.py
│   └── test_cli.py
├── configs/                  # 配置文件示例
│   ├── default.yaml
│   └── production.yaml
├── pyproject.toml           # 构建配置
├── README.md
└── LICENSE

4.2 pyproject.toml 配置

[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "mytool"
version = "1.0.0"
description = "专业的机器学习工具"
readme = "README.md"
authors = [
    {name = "Your Name", email = "your.email@example.com"}
]
license = {text = "MIT"}
keywords = ["machine-learning", "cli", "tools"]
classifiers = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Developers",
    "License :: OSI Approved :: MIT License",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.8",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
]

dependencies = [
    "click>=8.0.0",
    "PyYAML>=6.0",
    "rich>=13.0.0",  # 更好的终端输出
    "tqdm>=4.65.0",  # 进度条
]

[project.optional-dependencies]
dev = [
    "pytest>=7.0.0",
    "black>=23.0.0",
    "flake8>=6.0.0",
    "mypy>=1.0.0",
]
ml = [
    "torch>=2.0.0",
    "scikit-learn>=1.2.0",
    "pandas>=2.0.0",
]

[project.urls]
Homepage = "https://github.com/yourname/mytool"
Documentation = "https://mytool.readthedocs.io/"
Repository = "https://github.com/yourname/mytool"
Issues = "https://github.com/yourname/mytool/issues"

[project.scripts]
mytool = "mytool.cli:main"

[project.entry-points."mytool.commands"]
train = "mytool.commands.train:train_cmd"
predict = "mytool.commands.predict:predict_cmd"

[tool.setuptools.packages.find]
include = ["mytool*"]

[tool.black]
line-length = 88
target-version = ['py38', 'py39', 'py310', 'py311']

[tool.isort]
profile = "black"

[tool.mypy]
python_version = "3.8"
warn_return_any = true
warn_unused_configs = true

4.3 专业CLI入口

# mytool/cli.py
import sys
import click
from importlib import metadata
from pathlib import Path

# 动态加载子命令
def load_commands():
    """动态发现和加载子命令"""
    try:
        # 从entry_points加载
        eps = metadata.entry_points()
        if hasattr(eps, 'select'):
            # Python 3.10+
            command_eps = eps.select(group='mytool.commands')
        else:
            command_eps = eps.get('mytool.commands', [])
        
        commands = {}
        for ep in command_eps:
            try:
                cmd_func = ep.load()
                commands[ep.name] = cmd_func
            except Exception as e:
                click.echo(f"警告: 无法加载命令 {ep.name}: {e}", err=True)
        
        return commands
    except Exception as e:
        click.echo(f"警告: 无法加载动态命令: {e}", err=True)
        return {}

CONTEXT_SETTINGS = {
    'help_option_names': ['-h', '--help'],
    'max_content_width': 120,
    'color': True,
}

@click.group(context_settings=CONTEXT_SETTINGS)
@click.version_option(metadata.version('mytool'), 
                     '--version', '-V',
                     message='%(prog)s v%(version)s')
@click.pass_context
def cli(ctx):
    """
    🚀 MyTool - 专业的机器学习工具
    
    官方网站: https://mytool.example.com
    文档: https://docs.mytool.example.com
    """
    # 确保上下文对象
    ctx.ensure_object(dict)
    
    # 初始化配置
    config_path = Path.home() / '.config' / 'mytool' / 'config.yaml'
    ctx.obj['config_path'] = config_path
    
    # 设置默认配置目录
    config_dir = config_path.parent
    config_dir.mkdir(parents=True, exist_ok=True)

# 动态添加子命令
commands = load_commands()
for name, cmd_func in commands.items():
    cli.add_command(cmd_func, name=name)

# 内置命令
@cli.command()
@click.pass_context
def init(ctx):
    """初始化配置文件"""
    config_path = ctx.obj['config_path']
    
    if config_path.exists():
        if not click.confirm(f"配置文件已存在,是否覆盖?"):
            return
    
    # 创建默认配置
    default_config = {
        'model': 'resnet50',
        'training': {
            'epochs': 100,
            'batch_size': 64,
            'learning_rate': 0.001
        },
        'paths': {
            'data': './data',
            'checkpoints': './checkpoints',
            'logs': './logs'
        }
    }
    
    import yaml
    with open(config_path, 'w') as f:
        yaml.dump(default_config, f, default_flow_style=False)
    
    click.echo(f"✅ 配置文件已创建: {config_path}")

@cli.command()
@click.argument('paths', nargs=-1, type=click.Path())
@click.option('--recursive', '-r', is_flag=True, help='递归查找')
def find(paths, recursive):
    """查找配置文件"""
    from mytool.utils.config import find_config_files
    
    files = find_config_files(paths, recursive)
    
    if not files:
        click.echo("未找到配置文件")
        return
    
    for f in files:
        click.echo(f"📄 {f}")

def main():
    """主入口函数"""
    try:
        cli()
    except KeyboardInterrupt:
        click.echo("\n⚠️  操作被用户中断")
        sys.exit(130)  # 标准退出码
    except Exception as e:
        click.echo(f"❌ 错误: {e}", err=True)
        sys.exit(1)

if __name__ == '__main__':
    main()

4.4 安装和使用

本地开发安装:

# 开发模式安装
pip install -e .

# 安装所有依赖(包括可选依赖)
pip install -e ".[dev,ml]"

用户安装:

# 从PyPI安装
pip install mytool

# 使用
mytool --help
mytool init
mytool train --config config.yaml

打包发布:

# 构建
python -m build

# 上传到PyPI
python -m twine upload dist/*

五、最佳实践总结

5.1 设计原则

  1. CLI优先原则:命令行参数覆盖配置文件

  2. 渐进式配置:默认值 ← 配置文件 ← 环境变量 ← 命令行

  3. 模块化设计:命令、参数、逻辑分离

  4. 良好错误处理:友好的错误信息,正确的退出码

  5. 完整文档:帮助文档、示例、使用说明

5.2 常用退出码

退出码 含义 适用场景
0 成功 正常执行完成
1 一般错误 参数错误、文件不存在
2 错误用法 argparse/click显示帮助后
130 被中断 Ctrl+C中断
其他 业务错误 自定义错误类型

5.3 测试策略

# tests/test_cli.py
import pytest
from click.testing import CliRunner
from mytool.cli import cli

def test_version():
    runner = CliRunner()
    result = runner.invoke(cli, ['--version'])
    assert result.exit_code == 0
    assert 'mytool v' in result.output

def test_help():
    runner = CliRunner()
    result = runner.invoke(cli, ['--help'])
    assert result.exit_code == 0
    assert 'MyTool' in result.output

def test_init_command(tmp_path):
    runner = CliRunner()
    with runner.isolated_filesystem(temp_dir=tmp_path):
        result = runner.invoke(cli, ['init'])
        assert result.exit_code == 0
        assert '配置文件已创建' in result.output

5.4 性能考虑

  1. 延迟导入:只在需要时导入大模块

  2. 命令响应速度:子命令加载应快速

  3. 内存管理:及时释放资源

  4. 并发处理:支持并行执行

posted @ 2025-12-28 19:32  kyle_7Qc  阅读(19)  评论(0)    收藏  举报