🚀argparse 进阶实战指南:从脚本到专业命令行工具
一、argparse + 配置文件融合方案
1.1 为什么需要配置文件?
真实场景:模型训练任务
# 开发阶段:频繁调整参数
python train.py --model resnet50 --lr 0.001 --batch-size 64 \
--epochs 100 --optimizer adam --dropout 0.5 \
--weight-decay 0.0001 --data-augment True \
--early-stop 10 --save-dir ./checkpoints \
--log-interval 100 --val-interval 5
问题:
-
命令太长,容易出错
-
参数难以复用
-
版本管理困难
-
团队协作不便
解决方案:配置文件 + CLI 覆盖
1.2 配置文件格式选择
| 格式 | 优点 | 缺点 | 推荐场景 |
|---|---|---|---|
| YAML | 可读性好,广泛使用 | 需要PyYAML库 | 机器学习、配置管理 |
| TOML | Python内置支持,简洁 | 嵌套结构稍复杂 | 项目配置、工具设置 |
| JSON | 标准,无需额外依赖 | 不能写注释,冗长 | API配置、Web应用 |
| .py文件 | 灵活,可编程 | 有安全风险 | 动态配置、复杂逻辑 |
推荐:YAML(最流行)或 TOML(Python原生)
1.3 完整实现方案
基础版:CLI优先原则
import argparse
import yaml
import os
from pathlib import Path
def parse_args():
"""解析命令行参数,支持配置文件"""
parser = argparse.ArgumentParser(
description='模型训练程序',
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
# 配置文件参数(放在最前面)
parser.add_argument(
'-c', '--config',
type=Path,
help='配置文件路径 (YAML/TOML/JSON)',
default=None
)
# 模型参数组
model_group = parser.add_argument_group('模型参数')
model_group.add_argument('--model', type=str, help='模型名称')
model_group.add_argument('--pretrained', action='store_true', help='使用预训练权重')
# 训练参数组
train_group = parser.add_argument_group('训练参数')
train_group.add_argument('--epochs', type=int, help='训练轮数')
train_group.add_argument('--batch-size', type=int, help='批次大小')
train_group.add_argument('--lr', type=float, help='学习率')
# 系统参数组
system_group = parser.add_argument_group('系统参数')
system_group.add_argument('--device', type=str, help='运行设备')
system_group.add_argument('--workers', type=int, help='数据加载线程数')
# 解析命令行参数
cli_args = parser.parse_args()
# 最终配置 = 默认值 ← 配置文件 ← 命令行参数
config = {}
# 1. 加载配置文件(如果有)
if cli_args.config:
config = load_config(cli_args.config)
# 2. 合并配置:命令行参数覆盖配置文件
final_args = merge_configs(parser, config, cli_args)
# 3. 验证必需参数
validate_required_args(final_args)
return final_args
def load_config(config_path):
"""加载配置文件,支持多种格式"""
config_path = Path(config_path)
if not config_path.exists():
raise FileNotFoundError(f"配置文件不存在: {config_path}")
suffix = config_path.suffix.lower()
if suffix == '.yaml' or suffix == '.yml':
import yaml
with open(config_path, 'r', encoding='utf-8') as f:
return yaml.safe_load(f)
elif suffix == '.toml':
import tomllib
with open(config_path, 'rb') as f:
return tomllib.load(f)
elif suffix == '.json':
import json
with open(config_path, 'r', encoding='utf-8') as f:
return json.load(f)
else:
raise ValueError(f"不支持的配置文件格式: {suffix}")
def merge_configs(parser, config_dict, cli_args):
"""
合并配置的优先级:
1. 命令行参数(最高)
2. 配置文件
3. argparse默认值(最低)
"""
# 获取所有参数的定义
arg_defaults = {
action.dest: action.default
for action in parser._actions
if action.dest != 'help'
}
# 从配置文件和命令行参数构建最终配置
final_args = argparse.Namespace()
for dest in arg_defaults:
# 1. 尝试从命令行获取
cli_value = getattr(cli_args, dest, None)
# 2. 如果命令行没有,尝试从配置文件获取
if cli_value is None and dest in config_dict:
config_value = config_dict[dest]
setattr(final_args, dest, config_value)
# 3. 如果命令行有值(即使为None),使用命令行值
else:
setattr(final_args, dest, cli_value)
return final_args
def validate_required_args(args):
"""验证必需参数"""
required_fields = ['model', 'epochs', 'batch_size', 'lr']
missing = []
for field in required_fields:
if getattr(args, field, None) is None:
missing.append(field)
if missing:
raise ValueError(f"缺少必需参数: {', '.join(missing)}")
if __name__ == '__main__':
try:
args = parse_args()
print("最终配置:")
for key, value in vars(args).items():
print(f" {key}: {value}")
except Exception as e:
print(f"错误: {e}")
高级版:支持环境变量和默认配置文件
import os
from typing import Dict, Any, Optional
class ConfigManager:
"""高级配置管理器"""
def __init__(self):
self.config_hierarchy = []
def load_configs(self, cli_args) -> Dict[str, Any]:
"""按优先级加载配置"""
configs = []
# 1. 默认配置(代码硬编码)
configs.append(self.get_default_config())
# 2. 环境配置(根据环境变量)
env_config = self.load_env_config()
if env_config:
configs.append(env_config)
# 3. 用户配置文件(~/.config/)
user_config = self.load_user_config()
if user_config:
configs.append(user_config)
# 4. 项目配置文件(./config/)
project_config = self.load_project_config()
if project_config:
configs.append(project_config)
# 5. 命令行指定的配置文件
if cli_args.config:
file_config = self.load_config_file(cli_args.config)
configs.append(file_config)
# 6. 命令行参数(最高优先级)
cli_config = self.cli_args_to_dict(cli_args)
configs.append(cli_config)
# 合并所有配置
final_config = {}
for cfg in configs:
final_config.update(cfg)
return final_config
def get_default_config(self) -> Dict[str, Any]:
"""返回默认配置"""
return {
'device': 'cuda' if torch.cuda.is_available() else 'cpu',
'workers': 4,
'seed': 42,
'log_level': 'INFO'
}
def load_env_config(self) -> Dict[str, Any]:
"""从环境变量加载配置"""
config = {}
# 环境变量前缀
prefix = 'MYAPP_'
for key, value in os.environ.items():
if key.startswith(prefix):
config_key = key[len(prefix):].lower()
# 尝试转换为合适的类型
config[config_key] = self.convert_value(value)
return config
def convert_value(self, value: str) -> Any:
"""智能类型转换"""
# 布尔值
if value.lower() in ('true', 'yes', '1'):
return True
elif value.lower() in ('false', 'no', '0'):
return False
# 数字
try:
if '.' in value:
return float(value)
else:
return int(value)
except ValueError:
pass
# 列表(逗号分隔)
if ',' in value:
return [self.convert_value(v.strip()) for v in value.split(',')]
# 字符串
return value
# 使用示例
config_manager = ConfigManager()
config = config_manager.load_configs(args)
1.4 配置文件示例
config.yaml:
# 训练配置
model: resnet50
pretrained: true
epochs: 100
batch_size: 64
lr: 0.001
optimizer: adam
weight_decay: 0.0001
# 数据配置
data:
root: ./data
train_split: 0.8
val_split: 0.2
augmentation: true
# 系统配置
device: cuda
workers: 8
seed: 42
log_level: INFO
使用方式:
# 基础使用
python train.py --config config.yaml
# CLI覆盖配置
python train.py --config config.yaml --lr 0.01 --batch-size 32
# 完全使用CLI(无配置文件)
python train.py --model resnet50 --epochs 50 --lr 0.001
二、argparse + logging 专业日志系统
2.1 为什么需要专业的日志?
print()的局限性:
-
无法控制输出级别
-
不能输出到文件
-
没有时间戳
-
无法区分不同模块
-
生产环境无法使用
专业的日志系统应具备:
-
多级别日志(DEBUG/INFO/WARNING/ERROR)
-
同时输出到控制台和文件
-
日志轮转(防止日志文件过大)
-
结构化日志(JSON格式)
-
性能监控
2.2 完整实现方案
import argparse
import logging
import sys
from logging.handlers import RotatingFileHandler
import json
from datetime import datetime
from pathlib import Path
class StructuredFormatter(logging.Formatter):
"""结构化日志格式化器(JSON格式)"""
def format(self, record):
log_record = {
'timestamp': datetime.utcnow().isoformat() + 'Z',
'level': record.levelname,
'logger': record.name,
'module': record.module,
'function': record.funcName,
'line': record.lineno,
'message': record.getMessage(),
}
# 添加额外字段
if hasattr(record, 'extra'):
log_record.update(record.extra)
# 异常信息
if record.exc_info:
log_record['exception'] = self.formatException(record.exc_info)
return json.dumps(log_record, ensure_ascii=False)
def setup_logging(args):
"""
配置日志系统
日志级别对应关系:
--verbose 0 (默认): WARNING
--verbose 1 : INFO
--verbose 2 : DEBUG
"""
# 确定日志级别
level_mapping = {
0: logging.WARNING,
1: logging.INFO,
2: logging.DEBUG
}
log_level = level_mapping.get(args.verbose, logging.DEBUG)
# 配置根日志记录器
root_logger = logging.getLogger()
root_logger.setLevel(log_level)
# 清除已有的处理器
root_logger.handlers.clear()
# 1. 控制台处理器
console_handler = logging.StreamHandler(sys.stdout)
if args.verbose >= 2: # 调试模式使用详细格式
console_format = '%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s'
else:
console_format = '%(asctime)s - %(levelname)s - %(message)s'
console_formatter = logging.Formatter(console_format, datefmt='%H:%M:%S')
console_handler.setFormatter(console_formatter)
root_logger.addHandler(console_handler)
# 2. 文件处理器(如果指定了日志文件)
if args.log_file:
log_path = Path(args.log_file)
# 创建日志目录
log_path.parent.mkdir(parents=True, exist_ok=True)
# 使用轮转文件处理器(每个文件10MB,最多保留5个)
file_handler = RotatingFileHandler(
log_path,
maxBytes=10 * 1024 * 1024, # 10MB
backupCount=5,
encoding='utf-8'
)
# 文件日志使用JSON格式
file_handler.setFormatter(StructuredFormatter())
root_logger.addHandler(file_handler)
# 3. 错误处理器(分离错误日志)
if args.error_log:
error_handler = logging.FileHandler(
args.error_log,
encoding='utf-8'
)
error_handler.setLevel(logging.ERROR)
error_handler.setFormatter(StructuredFormatter())
root_logger.addHandler(error_handler)
# 配置第三方库的日志级别
logging.getLogger('urllib3').setLevel(logging.WARNING)
logging.getLogger('requests').setLevel(logging.WARNING)
return root_logger
def add_logging_args(parser):
"""添加日志相关参数到argparser"""
log_group = parser.add_argument_group('日志选项')
log_group.add_argument(
'-v', '--verbose',
action='count',
default=0,
help='日志详细程度: -v=INFO, -vv=DEBUG'
)
log_group.add_argument(
'--log-file',
type=Path,
help='日志文件路径'
)
log_group.add_argument(
'--error-log',
type=Path,
help='错误日志文件路径(单独存储ERROR级别日志)'
)
log_group.add_argument(
'--log-json',
action='store_true',
help='控制台也输出JSON格式日志'
)
log_group.add_argument(
'--quiet',
action='store_true',
help='静默模式,不输出任何日志'
)
return parser
# 使用示例
def main():
parser = argparse.ArgumentParser(description='带日志系统的程序')
parser = add_logging_args(parser)
# 添加其他参数...
parser.add_argument('--input', required=True, help='输入文件')
args = parser.parse_args()
# 设置日志
logger = setup_logging(args)
# 使用日志
logger.info("程序启动")
logger.debug(f"参数: {args}")
try:
# 业务逻辑
logger.info(f"处理文件: {args.input}")
# ...
logger.info("处理完成")
except Exception as e:
logger.error(f"处理失败: {e}", exc_info=True)
raise
class MetricsCollector:
"""性能指标收集器"""
def __init__(self, logger):
self.logger = logger
self.metrics = {}
self.start_time = datetime.now()
def start_timer(self, name):
"""开始计时"""
self.metrics[name] = {
'start': datetime.now(),
'end': None,
'duration': None
}
def stop_timer(self, name):
"""结束计时"""
if name in self.metrics:
self.metrics[name]['end'] = datetime.now()
duration = (self.metrics[name]['end'] -
self.metrics[name]['start']).total_seconds()
self.metrics[name]['duration'] = duration
# 记录指标
self.logger.info(
f"计时器 '{name}' 完成",
extra={'metric': name, 'duration_seconds': duration}
)
def log_metrics(self):
"""记录所有指标"""
total_time = (datetime.now() - self.start_time).total_seconds()
self.logger.info(
"性能指标汇总",
extra={
'total_duration_seconds': total_time,
'metrics': self.metrics
}
)
2.3 高级日志功能
2.3.1 日志上下文管理器
import contextlib
@contextlib.contextmanager
def log_section(logger, message, level=logging.INFO):
"""日志分段上下文管理器"""
logger.log(level, f"┌─ 开始: {message}")
start_time = datetime.now()
try:
yield
except Exception as e:
logger.error(f"✗ 失败: {message}")
raise
finally:
duration = (datetime.now() - start_time).total_seconds()
logger.log(level, f"└─ 完成: {message} ({duration:.2f}s)")
# 使用
with log_section(logger, "数据加载"):
data = load_data(args.input)
2.3.2 进度条日志
class ProgressLogger:
"""带进度条的日志"""
def __init__(self, logger, total, desc="进度"):
self.logger = logger
self.total = total
self.desc = desc
self.current = 0
self.start_time = datetime.now()
logger.info(f"开始 {desc} (共 {total} 项)")
def update(self, n=1):
"""更新进度"""
self.current += n
percent = (self.current / self.total) * 100
if self.current % 10 == 0 or self.current == self.total:
elapsed = (datetime.now() - self.start_time).total_seconds()
eta = (elapsed / self.current) * (self.total - self.current)
self.logger.info(
f"{self.desc}: {self.current}/{self.total} "
f"({percent:.1f}%) ETA: {eta:.1f}s"
)
def finish(self):
"""完成进度"""
total_time = (datetime.now() - self.start_time).total_seconds()
self.logger.info(
f"完成 {self.desc}, 总耗时: {total_time:.2f}s"
)
# 使用
progress = ProgressLogger(logger, len(data), "训练批次")
for batch in data:
train_batch(batch)
progress.update()
progress.finish()
三、argparse → click 迁移指南
3.1 click 是什么?为什么选择它?
click 的优势:
-
装饰器语法:更简洁,更Pythonic
-
自动补全:支持bash/zsh自动补全
-
更好的错误提示:颜色高亮,更友好的错误信息
-
参数验证:内置丰富的验证器
-
子命令支持:更优雅的子命令实现
-
进度条:内置进度条支持
-
颜色输出:终端彩色输出
何时迁移?
-
❌ 保持 argparse:内部脚本、简单工具、依赖少
-
✅ 迁移到 click:对外发布、复杂CLI、需要良好用户体验
3.2 argparse → click 对照表
| argparse 功能 | click 等价实现 | 说明 |
|---|---|---|
add_argument() |
@click.option() |
可选参数 |
| 位置参数 | @click.argument() |
必需参数 |
add_subparsers() |
@click.group() |
子命令 |
action='store_true' |
is_flag=True |
布尔标志 |
type=int |
type=click.INT |
类型验证 |
choices=['a','b'] |
type=click.Choice() |
选择限制 |
nargs='+' |
nargs=-1 |
多个参数 |
3.3 迁移示例
argparse 版本
import argparse
def main():
parser = argparse.ArgumentParser(description='文件处理工具')
parser.add_argument('input', help='输入文件')
parser.add_argument('output', nargs='?', help='输出文件')
parser.add_argument('--verbose', '-v', action='count', default=0)
parser.add_argument('--force', action='store_true', help='强制覆盖')
subparsers = parser.add_subparsers(dest='command', required=True)
# 压缩命令
compress_parser = subparsers.add_parser('compress', help='压缩文件')
compress_parser.add_argument('--level', type=int, choices=range(1, 10), default=6)
# 解压命令
extract_parser = subparsers.add_parser('extract', help='解压文件')
args = parser.parse_args()
if args.verbose >= 2:
print(f"调试模式: {args}")
# 处理逻辑...
if __name__ == '__main__':
main()
click 迁移版本
import click
import sys
CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
@click.group(context_settings=CONTEXT_SETTINGS)
@click.version_option('1.0.0', prog_name='文件处理工具')
@click.pass_context
def cli(ctx):
"""文件处理工具 - 压缩和解压文件"""
# 初始化上下文
ctx.ensure_object(dict)
ctx.obj['verbose'] = 0
@cli.command()
@click.argument('input', type=click.Path(exists=True))
@click.argument('output', type=click.Path(), required=False)
@click.option('--verbose', '-v', count=True, help='详细输出')
@click.option('--force', is_flag=True, help='强制覆盖已存在的文件')
@click.option('--level', type=click.IntRange(1, 9), default=6,
help='压缩级别 (1-9, 9为最高压缩)')
@click.pass_context
def compress(ctx, input, output, verbose, force, level):
"""压缩文件"""
# 设置详细级别
ctx.obj['verbose'] = verbose
if verbose >= 2:
click.echo(f"调试信息: 输入={input}, 输出={output}, 级别={level}")
# 检查输出文件是否已存在
if output and click.Path(exists=True).convert(output, None, None) and not force:
if not click.confirm(f"文件 {output} 已存在,是否覆盖?"):
click.echo("取消操作")
return
# 压缩逻辑...
click.echo(f"正在压缩 {input} -> {output or 'stdout'}")
# 进度条示例
with click.progressbar(range(100), label='压缩进度') as bar:
for i in bar:
# 模拟压缩过程
import time
time.sleep(0.01)
click.secho("✓ 压缩完成", fg='green')
@cli.command()
@click.argument('input', type=click.Path(exists=True))
@click.argument('output_dir', type=click.Path(file_okay=False), required=False)
@click.option('--verbose', '-v', count=True, help='详细输出')
@click.option('--list', '-l', is_flag=True, help='列出压缩包内容')
@click.pass_context
def extract(ctx, input, output_dir, verbose, list):
"""解压文件"""
ctx.obj['verbose'] = verbose
if list:
# 列出文件内容
click.echo(f"压缩包 {input} 内容:")
# 模拟列出文件
files = ['file1.txt', 'file2.jpg', 'subdir/']
for f in files:
click.echo(f" {f}")
return
# 解压逻辑...
click.echo(f"正在解压 {input} -> {output_dir or '当前目录'}")
# 自动补全支持
@cli.command()
@click.argument('shell', type=click.Choice(['bash', 'zsh', 'fish']))
def completion(shell):
"""生成自动补全脚本"""
from click.shell_completion import completion_source
click.echo(completion_source(cli, shell))
if __name__ == '__main__':
cli()
渐进式迁移策略
# 第一步:混合模式(同时支持argparse和click)
import argparse
import click
import sys
def legacy_argparse_mode():
"""旧版argparse模式"""
parser = argparse.ArgumentParser()
parser.add_argument('--old-arg', help='旧参数')
return parser.parse_args()
def new_click_mode():
"""新版click模式"""
@click.command()
@click.option('--new-arg', help='新参数')
def cli(new_arg):
click.echo(f"新模式: {new_arg}")
cli()
if __name__ == '__main__':
# 检查是否有--new-mode标志
if '--new-mode' in sys.argv:
sys.argv.remove('--new-mode')
new_click_mode()
else:
# 旧模式
args = legacy_argparse_mode()
print(f"旧模式: {args.old_arg}")
3.4 click 高级特性
3.4.1 参数类型和验证
import click
# 自定义参数类型
class EmailParamType(click.ParamType):
name = "email"
def convert(self, value, param, ctx):
if '@' not in value:
self.fail(f"{value} 不是有效的邮箱地址", param, ctx)
return value.lower()
EMAIL = EmailParamType()
@click.command()
@click.option('--email', type=EMAIL, required=True)
def register(email):
click.echo(f"注册邮箱: {email}")
# 文件验证
@click.command()
@click.argument('config',
type=click.Path(exists=True, dir_okay=False, readable=True))
def load_config(config):
click.echo(f"加载配置: {config}")
# 密码输入(隐藏输入)
@click.command()
@click.option('--password', prompt=True, hide_input=True,
confirmation_prompt=True)
def set_password(password):
click.echo(f"密码已设置")
3.4.2 彩色输出和表格
import click
from tabulate import tabulate
@click.command()
def list_users():
"""列出用户(带颜色和表格)"""
users = [
{'id': 1, 'name': 'Alice', 'active': True, 'role': 'admin'},
{'id': 2, 'name': 'Bob', 'active': False, 'role': 'user'},
{'id': 3, 'name': 'Charlie', 'active': True, 'role': 'user'},
]
# 带颜色的表格
headers = ['ID', 'Name', 'Status', 'Role']
rows = []
for user in users:
status = (click.style('✓ 活跃', fg='green')
if user['active'] else
click.style('✗ 禁用', fg='red'))
rows.append([
user['id'],
click.style(user['name'], fg='blue'),
status,
click.style(user['role'],
fg='yellow' if user['role'] == 'admin' else 'white')
])
click.echo(tabulate(rows, headers=headers, tablefmt='grid'))
四、打包成专业命令行工具
4.1 项目结构
mytool/
├── mytool/ # 主包
│ ├── __init__.py
│ ├── cli.py # CLI入口
│ ├── commands/ # 子命令模块
│ │ ├── __init__.py
│ │ ├── train.py
│ │ ├── predict.py
│ │ └── config.py
│ ├── core.py # 核心逻辑
│ └── utils/ # 工具函数
│ ├── __init__.py
│ ├── logging.py
│ └── config.py
├── tests/ # 测试
│ ├── __init__.py
│ └── test_cli.py
├── configs/ # 配置文件示例
│ ├── default.yaml
│ └── production.yaml
├── pyproject.toml # 构建配置
├── README.md
└── LICENSE
4.2 pyproject.toml 配置
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "mytool"
version = "1.0.0"
description = "专业的机器学习工具"
readme = "README.md"
authors = [
{name = "Your Name", email = "your.email@example.com"}
]
license = {text = "MIT"}
keywords = ["machine-learning", "cli", "tools"]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
]
dependencies = [
"click>=8.0.0",
"PyYAML>=6.0",
"rich>=13.0.0", # 更好的终端输出
"tqdm>=4.65.0", # 进度条
]
[project.optional-dependencies]
dev = [
"pytest>=7.0.0",
"black>=23.0.0",
"flake8>=6.0.0",
"mypy>=1.0.0",
]
ml = [
"torch>=2.0.0",
"scikit-learn>=1.2.0",
"pandas>=2.0.0",
]
[project.urls]
Homepage = "https://github.com/yourname/mytool"
Documentation = "https://mytool.readthedocs.io/"
Repository = "https://github.com/yourname/mytool"
Issues = "https://github.com/yourname/mytool/issues"
[project.scripts]
mytool = "mytool.cli:main"
[project.entry-points."mytool.commands"]
train = "mytool.commands.train:train_cmd"
predict = "mytool.commands.predict:predict_cmd"
[tool.setuptools.packages.find]
include = ["mytool*"]
[tool.black]
line-length = 88
target-version = ['py38', 'py39', 'py310', 'py311']
[tool.isort]
profile = "black"
[tool.mypy]
python_version = "3.8"
warn_return_any = true
warn_unused_configs = true
4.3 专业CLI入口
# mytool/cli.py
import sys
import click
from importlib import metadata
from pathlib import Path
# 动态加载子命令
def load_commands():
"""动态发现和加载子命令"""
try:
# 从entry_points加载
eps = metadata.entry_points()
if hasattr(eps, 'select'):
# Python 3.10+
command_eps = eps.select(group='mytool.commands')
else:
command_eps = eps.get('mytool.commands', [])
commands = {}
for ep in command_eps:
try:
cmd_func = ep.load()
commands[ep.name] = cmd_func
except Exception as e:
click.echo(f"警告: 无法加载命令 {ep.name}: {e}", err=True)
return commands
except Exception as e:
click.echo(f"警告: 无法加载动态命令: {e}", err=True)
return {}
CONTEXT_SETTINGS = {
'help_option_names': ['-h', '--help'],
'max_content_width': 120,
'color': True,
}
@click.group(context_settings=CONTEXT_SETTINGS)
@click.version_option(metadata.version('mytool'),
'--version', '-V',
message='%(prog)s v%(version)s')
@click.pass_context
def cli(ctx):
"""
🚀 MyTool - 专业的机器学习工具
官方网站: https://mytool.example.com
文档: https://docs.mytool.example.com
"""
# 确保上下文对象
ctx.ensure_object(dict)
# 初始化配置
config_path = Path.home() / '.config' / 'mytool' / 'config.yaml'
ctx.obj['config_path'] = config_path
# 设置默认配置目录
config_dir = config_path.parent
config_dir.mkdir(parents=True, exist_ok=True)
# 动态添加子命令
commands = load_commands()
for name, cmd_func in commands.items():
cli.add_command(cmd_func, name=name)
# 内置命令
@cli.command()
@click.pass_context
def init(ctx):
"""初始化配置文件"""
config_path = ctx.obj['config_path']
if config_path.exists():
if not click.confirm(f"配置文件已存在,是否覆盖?"):
return
# 创建默认配置
default_config = {
'model': 'resnet50',
'training': {
'epochs': 100,
'batch_size': 64,
'learning_rate': 0.001
},
'paths': {
'data': './data',
'checkpoints': './checkpoints',
'logs': './logs'
}
}
import yaml
with open(config_path, 'w') as f:
yaml.dump(default_config, f, default_flow_style=False)
click.echo(f"✅ 配置文件已创建: {config_path}")
@cli.command()
@click.argument('paths', nargs=-1, type=click.Path())
@click.option('--recursive', '-r', is_flag=True, help='递归查找')
def find(paths, recursive):
"""查找配置文件"""
from mytool.utils.config import find_config_files
files = find_config_files(paths, recursive)
if not files:
click.echo("未找到配置文件")
return
for f in files:
click.echo(f"📄 {f}")
def main():
"""主入口函数"""
try:
cli()
except KeyboardInterrupt:
click.echo("\n⚠️ 操作被用户中断")
sys.exit(130) # 标准退出码
except Exception as e:
click.echo(f"❌ 错误: {e}", err=True)
sys.exit(1)
if __name__ == '__main__':
main()
4.4 安装和使用
本地开发安装:
# 开发模式安装
pip install -e .
# 安装所有依赖(包括可选依赖)
pip install -e ".[dev,ml]"
用户安装:
# 从PyPI安装
pip install mytool
# 使用
mytool --help
mytool init
mytool train --config config.yaml
打包发布:
# 构建
python -m build
# 上传到PyPI
python -m twine upload dist/*
五、最佳实践总结
5.1 设计原则
-
CLI优先原则:命令行参数覆盖配置文件
-
渐进式配置:默认值 ← 配置文件 ← 环境变量 ← 命令行
-
模块化设计:命令、参数、逻辑分离
-
良好错误处理:友好的错误信息,正确的退出码
-
完整文档:帮助文档、示例、使用说明
5.2 常用退出码
| 退出码 | 含义 | 适用场景 |
|---|---|---|
| 0 | 成功 | 正常执行完成 |
| 1 | 一般错误 | 参数错误、文件不存在 |
| 2 | 错误用法 | argparse/click显示帮助后 |
| 130 | 被中断 | Ctrl+C中断 |
| 其他 | 业务错误 | 自定义错误类型 |
5.3 测试策略
# tests/test_cli.py
import pytest
from click.testing import CliRunner
from mytool.cli import cli
def test_version():
runner = CliRunner()
result = runner.invoke(cli, ['--version'])
assert result.exit_code == 0
assert 'mytool v' in result.output
def test_help():
runner = CliRunner()
result = runner.invoke(cli, ['--help'])
assert result.exit_code == 0
assert 'MyTool' in result.output
def test_init_command(tmp_path):
runner = CliRunner()
with runner.isolated_filesystem(temp_dir=tmp_path):
result = runner.invoke(cli, ['init'])
assert result.exit_code == 0
assert '配置文件已创建' in result.output
5.4 性能考虑
-
延迟导入:只在需要时导入大模块
-
命令响应速度:子命令加载应快速
-
内存管理:及时释放资源
-
并发处理:支持并行执行

浙公网安备 33010602011771号