Chapter 14: Performance Optimization and Best Practices
14.1 Overview
When working with large geospatial datasets, performance optimization is an essential concern. This chapter covers GDAL performance-tuning techniques and development best practices.
14.1.1 Types of Performance Bottlenecks
| Type | Typical scenario | Optimization approach |
|---|---|---|
| I/O bound | Reading/writing large files, network data access | Caching, tiling, compression |
| CPU bound | Complex computation, format conversion | Parallel processing, algorithmic optimization |
| Memory bound | Processing very large datasets | Block-wise processing, memory mapping |
| Network bound | Cloud data access | Caching, HTTP tuning |
14.1.2 Optimization Principles
1. Measure first: locate the bottleneck before optimizing (see the timing sketch below).
2. Trade space for time: caching, precomputation.
3. Trade time for space: compression, block-wise processing.
4. Parallelize: make use of multiple CPU cores.
5. Reduce I/O: batch operations, fewer open/close cycles.
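To put "measure first" into practice, a lightweight timing harness is often all that is needed. Below is a minimal sketch; the `timed` decorator and the file name are illustrative assumptions, not GDAL APIs:

import time
from functools import wraps

def timed(func):
    """Decorator that prints how long a function call takes."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        print(f"{func.__name__}: {elapsed:.3f}s")
        return result
    return wrapper

@timed
def read_full_raster(path):
    """Baseline measurement: read an entire raster into memory."""
    from osgeo import gdal
    ds = gdal.Open(path)
    data = ds.ReadAsArray()
    ds = None
    return data

# Usage (hypothetical file):
# read_full_raster('large_file.tif')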
14.2 Memory Management
14.2.1 Configuring the GDAL Cache
from osgeo import gdal

# Set the global cache size (bytes)
gdal.SetCacheMax(1024 * 1024 * 1024)  # 1 GB

# Query the current cache setting
current_cache = gdal.GetCacheMax()
print(f"Current cache size: {current_cache / 1024 / 1024:.0f} MB")

# Query current cache usage
used_cache = gdal.GetCacheUsed()
print(f"Cache in use: {used_cache / 1024 / 1024:.0f} MB")

# Size the cache dynamically from system memory
def set_optimal_cache():
    """Set an optimal cache size based on total system memory."""
    try:
        import psutil
        total_memory = psutil.virtual_memory().total
        # Use 25% of system memory for the cache
        cache_size = int(total_memory * 0.25)
    except ImportError:
        # Fall back to 512 MB
        cache_size = 512 * 1024 * 1024
    gdal.SetCacheMax(cache_size)
    print(f"Cache set to: {cache_size / 1024 / 1024:.0f} MB")
    return cache_size

# Usage example
# set_optimal_cache()
14.2.2 Reading Large Files in Blocks
from osgeo import gdal
import numpy as np

def read_raster_by_blocks(filepath, processing_func, block_size=512):
    """
    Read and process a large raster file block by block.

    Args:
        filepath: path to the raster file
        processing_func: callback receiving (data, x_off, y_off)
        block_size: block edge length in pixels
    """
    gdal.UseExceptions()
    ds = gdal.Open(filepath)
    band = ds.GetRasterBand(1)
    width = ds.RasterXSize
    height = ds.RasterYSize
    total_blocks = ((height + block_size - 1) // block_size) * \
                   ((width + block_size - 1) // block_size)
    processed = 0
    for y_off in range(0, height, block_size):
        for x_off in range(0, width, block_size):
            # Clamp the block size at the right/bottom edges
            x_size = min(block_size, width - x_off)
            y_size = min(block_size, height - y_off)
            # Read the block
            data = band.ReadAsArray(x_off, y_off, x_size, y_size)
            # Process it
            processing_func(data, x_off, y_off)
            processed += 1
            if processed % 100 == 0:
                print(f"Progress: {processed}/{total_blocks} ({processed/total_blocks*100:.1f}%)")
    ds = None
    print("Processing complete")

def write_raster_by_blocks(filepath, generator_func, width, height,
                           geotransform, projection, block_size=512):
    """
    Write a large raster file block by block.

    Args:
        filepath: output file path
        generator_func: data generator receiving (x_off, y_off, x_size, y_size)
        width, height: raster dimensions
        geotransform: geotransform tuple
        projection: projection WKT
        block_size: block edge length in pixels
    """
    gdal.UseExceptions()
    driver = gdal.GetDriverByName('GTiff')
    ds = driver.Create(
        filepath, width, height, 1, gdal.GDT_Float32,
        options=[
            'COMPRESS=LZW',
            'TILED=YES',
            f'BLOCKXSIZE={block_size}',
            f'BLOCKYSIZE={block_size}',
            'BIGTIFF=IF_SAFER'
        ]
    )
    ds.SetGeoTransform(geotransform)
    ds.SetProjection(projection)
    band = ds.GetRasterBand(1)
    for y_off in range(0, height, block_size):
        for x_off in range(0, width, block_size):
            x_size = min(block_size, width - x_off)
            y_size = min(block_size, height - y_off)
            # Generate the block
            data = generator_func(x_off, y_off, x_size, y_size)
            # Write it
            band.WriteArray(data, x_off, y_off)
    ds.FlushCache()
    ds = None
    print(f"Write complete: {filepath}")

# Usage example
def example_processing(data, x_off, y_off):
    """Example processing callback: compute the block mean."""
    mean = np.mean(data)
    return mean

# read_raster_by_blocks('large_file.tif', example_processing)
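A matching usage sketch for write_raster_by_blocks; the gradient generator and all parameter values below are illustrative assumptions:

def example_generator(x_off, y_off, x_size, y_size):
    """Example generator: fill each block with a simple gradient."""
    # np.mgrid yields coordinate grids of shape (y_size, x_size)
    yy, xx = np.mgrid[y_off:y_off + y_size, x_off:x_off + x_size]
    return (xx + yy).astype(np.float32)

# write_raster_by_blocks('generated.tif', example_generator,
#                        width=4096, height=4096,
#                        geotransform=(0, 1, 0, 0, 0, -1),
#                        projection='', block_size=512)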
14.2.3 In-Memory Datasets
from osgeo import gdal
import numpy as np

def create_mem_dataset(data, geotransform=None, projection=None):
    """
    Create an in-memory dataset (avoids disk I/O).
    """
    gdal.UseExceptions()
    if data.ndim == 2:
        bands = 1
        height, width = data.shape
        data = data[np.newaxis, ...]
    else:
        bands, height, width = data.shape
    # Map NumPy dtypes to GDAL data types
    dtype_map = {
        np.dtype('uint8'): gdal.GDT_Byte,
        np.dtype('uint16'): gdal.GDT_UInt16,
        np.dtype('int16'): gdal.GDT_Int16,
        np.dtype('uint32'): gdal.GDT_UInt32,
        np.dtype('int32'): gdal.GDT_Int32,
        np.dtype('float32'): gdal.GDT_Float32,
        np.dtype('float64'): gdal.GDT_Float64,
    }
    gdal_type = dtype_map.get(data.dtype, gdal.GDT_Float64)
    # Create the in-memory dataset
    driver = gdal.GetDriverByName('MEM')
    mem_ds = driver.Create('', width, height, bands, gdal_type)
    if geotransform:
        mem_ds.SetGeoTransform(geotransform)
    if projection:
        mem_ds.SetProjection(projection)
    # Write the data band by band
    for i in range(bands):
        band = mem_ds.GetRasterBand(i + 1)
        band.WriteArray(data[i])
    return mem_ds

def process_in_memory(input_path, processing_func):
    """
    Load a dataset fully into memory and process it there.
    """
    gdal.UseExceptions()
    # Read the source data
    src_ds = gdal.Open(input_path)
    data = src_ds.ReadAsArray()
    geotransform = src_ds.GetGeoTransform()
    projection = src_ds.GetProjection()
    src_ds = None
    # Process in memory
    result = processing_func(data)
    # Wrap the result in an in-memory dataset
    mem_ds = create_mem_dataset(result, geotransform, projection)
    return mem_ds, result
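A short usage sketch; the input name, output name, and the doubling function are illustrative assumptions:

# Double all pixel values in memory, then persist the result once
# mem_ds, result = process_in_memory('input.tif', lambda a: a.astype('float32') * 2)
# gdal.Translate('doubled.tif', mem_ds,
#                options=gdal.TranslateOptions(format='GTiff',
#                                              creationOptions=['COMPRESS=LZW']))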
14.3 Parallel Processing
14.3.1 Multithreading Configuration
from osgeo import gdal
import os

# GDAL's built-in multithreading support
gdal.SetConfigOption('GDAL_NUM_THREADS', 'ALL_CPUS')

# Or specify an explicit thread count
cpu_count = os.cpu_count()
gdal.SetConfigOption('GDAL_NUM_THREADS', str(cpu_count))

# Multithreaded Warp operations
def warp_with_threads(src_path, dst_path, dst_srs):
    """Run a Warp operation with multithreading enabled."""
    warp_options = gdal.WarpOptions(
        format='GTiff',
        dstSRS=dst_srs,
        multithread=True,
        warpMemoryLimit=500,  # MB
        creationOptions=['COMPRESS=LZW', 'TILED=YES']
    )
    gdal.Warp(dst_path, src_path, options=warp_options)
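For GTiff outputs, compression itself can also be multithreaded via the NUM_THREADS creation option (available in modern GDAL builds). A minimal sketch; the file names are illustrative:

from osgeo import gdal

# Multithreaded DEFLATE compression while writing a tiled output
gdal.Translate(
    'output_mt.tif', 'input.tif',
    options=gdal.TranslateOptions(
        format='GTiff',
        creationOptions=['COMPRESS=DEFLATE', 'TILED=YES', 'NUM_THREADS=ALL_CPUS']
    )
)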
14.3.2 Process-Level Parallelism
from osgeo import gdal
import numpy as np
from concurrent.futures import ProcessPoolExecutor
import multiprocessing

def process_single_file(args):
    """Process one file (runs inside a worker process)."""
    filepath, output_path, func_name = args
    # Each worker process must import GDAL itself
    from osgeo import gdal
    gdal.UseExceptions()
    try:
        ds = gdal.Open(filepath)
        data = ds.ReadAsArray()
        # Apply the requested operation
        if func_name == 'normalize':
            result = (data - data.min()) / (data.max() - data.min())
        elif func_name == 'sqrt':
            result = np.sqrt(np.maximum(data, 0))
        else:
            result = data
        # Save the result
        driver = gdal.GetDriverByName('GTiff')
        out_ds = driver.Create(
            output_path,
            ds.RasterXSize,
            ds.RasterYSize,
            ds.RasterCount,
            gdal.GDT_Float32,
            ['COMPRESS=LZW']
        )
        out_ds.SetGeoTransform(ds.GetGeoTransform())
        out_ds.SetProjection(ds.GetProjection())
        # Cast to float32 so the raw buffer matches the output data type
        out_ds.WriteRaster(0, 0, ds.RasterXSize, ds.RasterYSize,
                           result.astype(np.float32).tobytes())
        ds = None
        out_ds = None
        return True, filepath
    except Exception as e:
        return False, f"{filepath}: {e}"

def parallel_process_files(input_files, output_dir, func_name, workers=None):
    """Process multiple files in parallel."""
    if workers is None:
        workers = multiprocessing.cpu_count()
    # Build the task list
    import os
    tasks = []
    for filepath in input_files:
        filename = os.path.basename(filepath)
        output_path = os.path.join(output_dir, f"processed_{filename}")
        tasks.append((filepath, output_path, func_name))
    # Run in parallel
    with ProcessPoolExecutor(max_workers=workers) as executor:
        results = list(executor.map(process_single_file, tasks))
    # Summarize
    success = sum(1 for r in results if r[0])
    failed = sum(1 for r in results if not r[0])
    print(f"Parallel processing finished: {success} succeeded, {failed} failed")
    return results

# Usage example
# input_files = ['file1.tif', 'file2.tif', 'file3.tif']
# parallel_process_files(input_files, './output', 'normalize')
14.3.3 Tile-Level Parallelism
from osgeo import gdal
import numpy as np
from concurrent.futures import ProcessPoolExecutor
import multiprocessing

def process_tile(args):
    """Process a single tile (runs inside a worker process)."""
    src_path, x_off, y_off, x_size, y_size, func_code = args
    from osgeo import gdal
    import numpy as np
    ds = gdal.Open(src_path)
    data = ds.GetRasterBand(1).ReadAsArray(x_off, y_off, x_size, y_size)
    ds = None
    # Evaluate the expression. Note that eval() executes arbitrary code,
    # so func_code must come from a trusted source.
    result = eval(func_code)
    return x_off, y_off, result

def parallel_process_raster(src_path, dst_path, func_code,
                            tile_size=512, workers=None):
    """Process a raster tile by tile, in parallel."""
    gdal.UseExceptions()
    if workers is None:
        workers = multiprocessing.cpu_count()
    # Read source metadata
    src_ds = gdal.Open(src_path)
    width = src_ds.RasterXSize
    height = src_ds.RasterYSize
    geotransform = src_ds.GetGeoTransform()
    projection = src_ds.GetProjection()
    src_ds = None
    # Build the task list
    tasks = []
    for y in range(0, height, tile_size):
        for x in range(0, width, tile_size):
            x_size = min(tile_size, width - x)
            y_size = min(tile_size, height - y)
            tasks.append((src_path, x, y, x_size, y_size, func_code))
    print(f"Total tiles: {len(tasks)}, using {workers} processes")
    # Process in parallel
    with ProcessPoolExecutor(max_workers=workers) as executor:
        results = list(executor.map(process_tile, tasks))
    # Create the output dataset
    driver = gdal.GetDriverByName('GTiff')
    dst_ds = driver.Create(
        dst_path, width, height, 1, gdal.GDT_Float32,
        ['COMPRESS=LZW', 'TILED=YES', f'BLOCKXSIZE={tile_size}', f'BLOCKYSIZE={tile_size}']
    )
    dst_ds.SetGeoTransform(geotransform)
    dst_ds.SetProjection(projection)
    dst_band = dst_ds.GetRasterBand(1)
    # Write the results
    for x_off, y_off, result in results:
        dst_band.WriteArray(result, x_off, y_off)
    dst_ds.FlushCache()
    dst_ds = None
    print(f"Parallel processing complete: {dst_path}")

# Usage example: per-tile normalization
# func_code = "(data - data.min()) / (data.max() - data.min() + 1e-10)"
# parallel_process_raster('input.tif', 'output.tif', func_code)
14.4 I/O Optimization
14.4.1 GeoTIFF Optimization
from osgeo import gdal

def create_optimized_geotiff(input_path, output_path,
                             compress='LZW',
                             tiled=True,
                             block_size=256,
                             overview=True):
    """
    Create an optimized GeoTIFF.
    """
    gdal.UseExceptions()
    options = []
    # Compression options
    if compress:
        options.append(f'COMPRESS={compress}')
        if compress in ['LZW', 'DEFLATE']:
            # A predictor can improve the compression ratio
            options.append('PREDICTOR=2')
        elif compress == 'JPEG':
            options.append('JPEG_QUALITY=90')
    # Tiling options
    if tiled:
        options.append('TILED=YES')
        options.append(f'BLOCKXSIZE={block_size}')
        options.append(f'BLOCKYSIZE={block_size}')
    # Large-file support
    options.append('BIGTIFF=IF_SAFER')
    # Run the conversion
    translate_options = gdal.TranslateOptions(
        format='GTiff',
        creationOptions=options
    )
    ds = gdal.Translate(output_path, input_path, options=translate_options)
    # Build overview pyramids
    if overview:
        ds.BuildOverviews('AVERAGE', [2, 4, 8, 16, 32])
    ds = None
    print(f"Optimized GeoTIFF created: {output_path}")

def compare_compression_methods(input_path, output_dir):
    """Compare the effect of different compression methods."""
    import os
    import time
    methods = ['NONE', 'LZW', 'DEFLATE', 'PACKBITS', 'ZSTD']
    results = []
    for method in methods:
        output_path = os.path.join(output_dir, f"compressed_{method}.tif")
        options = ['TILED=YES']
        if method != 'NONE':
            options.append(f'COMPRESS={method}')
        start_time = time.time()
        gdal.Translate(
            output_path, input_path,
            options=gdal.TranslateOptions(
                format='GTiff',
                creationOptions=options
            )
        )
        elapsed = time.time() - start_time
        file_size = os.path.getsize(output_path)
        results.append({
            'method': method,
            'size_mb': file_size / 1024 / 1024,
            'time_s': elapsed
        })
        print(f"{method}: {file_size/1024/1024:.1f} MB, {elapsed:.2f}s")
    return results
14.4.2 Optimizing Cloud Data Access
from osgeo import gdal

def configure_cloud_access():
    """Configure cloud data access."""
    # HTTP caching
    gdal.SetConfigOption('CPL_VSIL_CURL_CACHE_SIZE', '100000000')  # ~100 MB
    gdal.SetConfigOption('CPL_VSIL_CURL_USE_CACHE', 'YES')
    # Connection pool
    gdal.SetConfigOption('GDAL_HTTP_MAX_CONNECTIONS', '10')
    # Timeouts
    gdal.SetConfigOption('GDAL_HTTP_TIMEOUT', '30')
    # S3 settings
    gdal.SetConfigOption('AWS_NO_SIGN_REQUEST', 'YES')  # public data
    gdal.SetConfigOption('AWS_REGION', 'us-west-2')
    # GCS settings
    # gdal.SetConfigOption('GS_NO_SIGN_REQUEST', 'YES')
    print("Cloud access configured")

def read_cog_efficiently(cog_url, bounds=None, resolution=None):
    """
    Read a Cloud Optimized GeoTIFF efficiently.
    """
    gdal.UseExceptions()
    configure_cloud_access()
    # Route the URL through a virtual file system
    if not cog_url.startswith('/vsicurl/'):
        if cog_url.startswith('http'):
            cog_url = f'/vsicurl/{cog_url}'
        elif cog_url.startswith('s3://'):
            cog_url = cog_url.replace('s3://', '/vsis3/')
    ds = gdal.Open(cog_url)
    if ds is None:
        raise IOError(f"Cannot open: {cog_url}")
    # If bounds were given, read only that window
    if bounds:
        warp_options = gdal.WarpOptions(
            outputBounds=bounds,
            xRes=resolution,
            yRes=resolution,
            format='MEM'  # keep the result in memory
        )
        ds = gdal.Warp('', ds, options=warp_options)
    return ds

def download_range(url, bounds, output_path):
    """Download only the data within the given bounds."""
    gdal.UseExceptions()
    configure_cloud_access()
    src_ds = read_cog_efficiently(url, bounds)
    if src_ds:
        gdal.Translate(output_path, src_ds, options=gdal.TranslateOptions(
            format='GTiff',
            creationOptions=['COMPRESS=LZW']
        ))
        src_ds = None
        print(f"Download complete: {output_path}")

# Usage example
# url = "https://example.com/cog.tif"
# bounds = (116, 39, 117, 40)  # Beijing area (min_x, min_y, max_x, max_y)
# download_range(url, bounds, 'beijing.tif')
14.4.3 Batch I/O Optimization
from osgeo import gdal

class BatchDatasetManager:
    """Manage many open datasets with an LRU-style cache."""
    def __init__(self, max_open=100):
        self.max_open = max_open
        self.datasets = {}
        self.access_order = []

    def get_dataset(self, filepath, mode='r'):
        """Get a dataset, reusing an already-open handle if possible."""
        if filepath in self.datasets:
            # Refresh the access order
            self.access_order.remove(filepath)
            self.access_order.append(filepath)
            return self.datasets[filepath]
        # Evict the least recently used dataset if at capacity
        if len(self.datasets) >= self.max_open:
            oldest = self.access_order.pop(0)
            self.datasets[oldest] = None
            del self.datasets[oldest]
        # Open a new dataset
        access = gdal.GA_Update if mode == 'w' else gdal.GA_ReadOnly
        ds = gdal.Open(filepath, access)
        self.datasets[filepath] = ds
        self.access_order.append(filepath)
        return ds

    def close_all(self):
        """Close all open datasets."""
        for filepath in list(self.datasets.keys()):
            self.datasets[filepath] = None
        self.datasets.clear()
        self.access_order.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close_all()
        return False

def batch_read_values(filepaths, x, y):
    """Read the value at the same map coordinate from many files."""
    values = []
    with BatchDatasetManager(max_open=50) as manager:
        for filepath in filepaths:
            ds = manager.get_dataset(filepath)
            if ds:
                # Convert map coordinates to pixel coordinates
                gt = ds.GetGeoTransform()
                px = int((x - gt[0]) / gt[1])
                py = int((y - gt[3]) / gt[5])
                if 0 <= px < ds.RasterXSize and 0 <= py < ds.RasterYSize:
                    # Read the value
                    value = ds.GetRasterBand(1).ReadAsArray(px, py, 1, 1)[0, 0]
                    values.append(value)
                else:
                    # Point falls outside this raster
                    values.append(None)
            else:
                values.append(None)
    return values
14.5 Algorithmic Optimization
14.5.1 NumPy Vectorization
from osgeo import gdal, ogr
import numpy as np

def ndvi_vectorized(nir_path, red_path, output_path):
    """
    Compute NDVI with vectorized NumPy operations.
    """
    gdal.UseExceptions()
    nir_ds = gdal.Open(nir_path)
    red_ds = gdal.Open(red_path)
    nir = nir_ds.ReadAsArray().astype(np.float32)
    red = red_ds.ReadAsArray().astype(np.float32)
    # Vectorized computation (much faster than explicit loops)
    np.seterr(invalid='ignore', divide='ignore')
    ndvi = np.where(
        (nir + red) > 0,
        (nir - red) / (nir + red),
        -9999
    )
    np.seterr(invalid='warn', divide='warn')
    # Save the result
    driver = gdal.GetDriverByName('GTiff')
    out_ds = driver.Create(
        output_path,
        nir_ds.RasterXSize,
        nir_ds.RasterYSize,
        1,
        gdal.GDT_Float32,
        ['COMPRESS=LZW']
    )
    out_ds.SetGeoTransform(nir_ds.GetGeoTransform())
    out_ds.SetProjection(nir_ds.GetProjection())
    out_ds.GetRasterBand(1).WriteArray(ndvi)
    out_ds.GetRasterBand(1).SetNoDataValue(-9999)
    nir_ds = None
    red_ds = None
    out_ds = None

def zonal_statistics_optimized(raster_path, vector_path):
    """
    Optimized zonal statistics.
    """
    gdal.UseExceptions()
    # Read the raster
    raster_ds = gdal.Open(raster_path)
    raster_data = raster_ds.ReadAsArray()
    gt = raster_ds.GetGeoTransform()
    nodata = raster_ds.GetRasterBand(1).GetNoDataValue()
    # Read the vector data
    vector_ds = ogr.Open(vector_path)
    layer = vector_ds.GetLayer()
    results = []
    for feature in layer:
        fid = feature.GetFID()
        geom = feature.GetGeometryRef()
        # Get the feature envelope (min_x, max_x, min_y, max_y)
        env = geom.GetEnvelope()
        # Convert to pixel coordinates
        x_min = int((env[0] - gt[0]) / gt[1])
        x_max = int((env[1] - gt[0]) / gt[1])
        y_min = int((env[3] - gt[3]) / gt[5])
        y_max = int((env[2] - gt[3]) / gt[5])
        # Clamp to the raster extent
        x_min = max(0, x_min)
        y_min = max(0, y_min)
        x_max = min(raster_ds.RasterXSize, x_max)
        y_max = min(raster_ds.RasterYSize, y_max)
        if x_min >= x_max or y_min >= y_max:
            continue
        # Extract the sub-window
        subset = raster_data[y_min:y_max, x_min:x_max]
        # Build the mask
        # (simplified: a rigorous version rasterizes the geometry, see below)
        if nodata is not None:
            valid_mask = subset != nodata
            valid_data = subset[valid_mask]
        else:
            valid_data = subset.flatten()
        if len(valid_data) > 0:
            results.append({
                'fid': fid,
                'count': len(valid_data),
                'min': float(np.min(valid_data)),
                'max': float(np.max(valid_data)),
                'mean': float(np.mean(valid_data)),
                'std': float(np.std(valid_data)),
                'sum': float(np.sum(valid_data))
            })
    raster_ds = None
    vector_ds = None
    return results
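The envelope-based mask above counts every pixel inside the feature's bounding box, including pixels outside the polygon itself. A more rigorous mask rasterizes the geometry; the sketch below shows one way to do that with gdal.RasterizeLayer. The helper name is ours, and it assumes the layer and raster share a coordinate system and an axis-aligned geotransform:

from osgeo import gdal, ogr

def rasterize_feature_mask(geom, spatial_ref, x_min, y_min, x_max, y_max, gt):
    """Burn one geometry into an in-memory mask matching a raster sub-window."""
    cols, rows = x_max - x_min, y_max - y_min
    # In-memory mask dataset aligned with the sub-window
    mask_ds = gdal.GetDriverByName('MEM').Create('', cols, rows, 1, gdal.GDT_Byte)
    mask_ds.SetGeoTransform((gt[0] + x_min * gt[1], gt[1], 0,
                             gt[3] + y_min * gt[5], 0, gt[5]))
    # Temporary in-memory layer holding just this geometry
    mem_vec = ogr.GetDriverByName('Memory').CreateDataSource('')
    mem_layer = mem_vec.CreateLayer('mask', srs=spatial_ref)
    feat = ogr.Feature(mem_layer.GetLayerDefn())
    feat.SetGeometry(geom.Clone())
    mem_layer.CreateFeature(feat)
    # Burn value 1 inside the geometry
    gdal.RasterizeLayer(mask_ds, [1], mem_layer, burn_values=[1])
    return mask_ds.ReadAsArray().astype(bool)

# Inside the feature loop, combine the geometry mask with the nodata mask:
# inside = rasterize_feature_mask(geom, layer.GetSpatialRef(),
#                                 x_min, y_min, x_max, y_max, gt)
# valid_data = subset[inside & (subset != nodata)]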
14.5.2 Using Spatial Indexes
from osgeo import ogr

def create_spatial_index(shp_path):
    """Create a spatial index for a shapefile."""
    ogr.UseExceptions()
    ds = ogr.Open(shp_path, 1)  # open in update mode
    layer = ds.GetLayer()
    # Build an R-tree spatial index (.qix) via the shapefile driver
    ds.ExecuteSQL(f'CREATE SPATIAL INDEX ON {layer.GetName()}')
    ds = None
    print(f"Spatial index created: {shp_path}")

def spatial_query_with_index(shp_path, query_geom):
    """Query using the spatial index."""
    ogr.UseExceptions()
    ds = ogr.Open(shp_path)
    layer = ds.GetLayer()
    # Set a spatial filter (the spatial index is used automatically)
    layer.SetSpatialFilter(query_geom)
    results = []
    for feature in layer:
        results.append({
            'fid': feature.GetFID(),
            'geometry': feature.GetGeometryRef().ExportToWkt()
        })
    layer.SetSpatialFilter(None)
    ds = None
    return results

def build_rtree_index(features):
    """Build a spatial index with the rtree library."""
    try:
        from rtree import index
    except ImportError:
        print("Please install rtree: pip install rtree")
        return None
    # Create the index
    idx = index.Index()
    for i, feature in enumerate(features):
        geom = feature.get('geometry')
        if geom:
            # Compute the bounding box from the WKT geometry
            g = ogr.CreateGeometryFromWkt(geom)
            env = g.GetEnvelope()
            # Insert as (id, (min_x, min_y, max_x, max_y))
            idx.insert(i, (env[0], env[2], env[1], env[3]))
    return idx

def query_rtree_index(idx, bbox):
    """Query the R-tree index; returns IDs of features intersecting bbox."""
    return list(idx.intersection(bbox))
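A short usage sketch tying the two helpers together; the WKT polygons are illustrative:

features = [
    {'geometry': 'POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'},
    {'geometry': 'POLYGON((5 5, 6 5, 6 6, 5 6, 5 5))'},
]
idx = build_rtree_index(features)
if idx:
    # Find features whose bounding boxes intersect (min_x, min_y, max_x, max_y)
    hits = query_rtree_index(idx, (0.5, 0.5, 2.0, 2.0))
    print(hits)  # -> [0]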
14.6 Best Practices
14.6.1 Code Organization
"""
gdal_utils.py - GDAL 工具模块最佳实践
"""
from osgeo import gdal, ogr, osr
import numpy as np
from pathlib import Path
from typing import Optional, Tuple, List, Union
from contextlib import contextmanager
import logging
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 启用异常
gdal.UseExceptions()
ogr.UseExceptions()
# 默认配置
DEFAULT_CACHE_SIZE = 512 * 1024 * 1024 # 512 MB
DEFAULT_BLOCK_SIZE = 256
DEFAULT_COMPRESSION = 'LZW'
def init_gdal(cache_size=None, threads=None):
"""初始化 GDAL 配置"""
gdal.SetCacheMax(cache_size or DEFAULT_CACHE_SIZE)
if threads:
gdal.SetConfigOption('GDAL_NUM_THREADS', str(threads))
else:
gdal.SetConfigOption('GDAL_NUM_THREADS', 'ALL_CPUS')
logger.info(f"GDAL 初始化完成,缓存: {gdal.GetCacheMax() / 1024 / 1024:.0f} MB")
@contextmanager
def open_raster(filepath: str, mode: str = 'r'):
"""栅格数据集上下文管理器"""
access = gdal.GA_Update if mode == 'w' else gdal.GA_ReadOnly
ds = gdal.Open(str(filepath), access)
if ds is None:
raise IOError(f"无法打开栅格: {filepath}")
try:
yield ds
finally:
ds.FlushCache()
ds = None
@contextmanager
def open_vector(filepath: str, mode: str = 'r'):
"""矢量数据集上下文管理器"""
update = 1 if mode == 'w' else 0
ds = ogr.Open(str(filepath), update)
if ds is None:
raise IOError(f"无法打开矢量: {filepath}")
try:
yield ds
finally:
ds = None
class RasterProcessor:
"""栅格处理器类"""
def __init__(self, filepath: str):
self.filepath = Path(filepath)
self._ds = None
def __enter__(self):
self._ds = gdal.Open(str(self.filepath))
if self._ds is None:
raise IOError(f"无法打开: {self.filepath}")
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self._ds:
self._ds = None
return False
@property
def info(self) -> dict:
"""获取栅格信息"""
return {
'width': self._ds.RasterXSize,
'height': self._ds.RasterYSize,
'bands': self._ds.RasterCount,
'driver': self._ds.GetDriver().ShortName,
'geotransform': self._ds.GetGeoTransform(),
'projection': self._ds.GetProjection()
}
def read(self, band: int = 1) -> np.ndarray:
"""读取波段数据"""
return self._ds.GetRasterBand(band).ReadAsArray()
def read_all(self) -> np.ndarray:
"""读取所有波段"""
return self._ds.ReadAsArray()
def process(self, func, output_path: str, **kwargs):
"""应用处理函数并保存结果"""
data = self.read_all()
result = func(data, **kwargs)
self._save_result(result, output_path)
def _save_result(self, data: np.ndarray, output_path: str):
"""保存处理结果"""
if data.ndim == 2:
bands = 1
height, width = data.shape
else:
bands, height, width = data.shape
driver = gdal.GetDriverByName('GTiff')
out_ds = driver.Create(
output_path, width, height, bands,
gdal.GDT_Float32,
['COMPRESS=LZW', 'TILED=YES']
)
out_ds.SetGeoTransform(self._ds.GetGeoTransform())
out_ds.SetProjection(self._ds.GetProjection())
if bands == 1:
out_ds.GetRasterBand(1).WriteArray(data)
else:
out_ds.WriteRaster(0, 0, width, height, data.tobytes())
out_ds = None
logger.info(f"结果保存到: {output_path}")
# 使用示例
# init_gdal()
# with RasterProcessor('input.tif') as processor:
# print(processor.info)
# processor.process(lambda x: x * 2, 'output.tif')
14.6.2 Error Handling
from osgeo import gdal, ogr
import logging
import traceback
from functools import wraps

logger = logging.getLogger(__name__)

class GDALError(Exception):
    """Base class for GDAL errors."""
    pass

class RasterOpenError(GDALError):
    """Raised when a raster cannot be opened."""
    pass

class VectorOpenError(GDALError):
    """Raised when a vector dataset cannot be opened."""
    pass

class ProcessingError(GDALError):
    """Raised when processing fails."""
    pass

def gdal_error_handler(err_class, err_num, err_msg):
    """Custom GDAL error handler."""
    err_types = {
        gdal.CE_None: 'None',
        gdal.CE_Debug: 'Debug',
        gdal.CE_Warning: 'Warning',
        gdal.CE_Failure: 'Failure',
        gdal.CE_Fatal: 'Fatal'
    }
    err_type = err_types.get(err_class, 'Unknown')
    if err_class >= gdal.CE_Warning:
        logger.warning(f"GDAL {err_type} ({err_num}): {err_msg}")
    if err_class >= gdal.CE_Failure:
        raise GDALError(f"GDAL {err_type}: {err_msg}")

# Register the error handler
# gdal.PushErrorHandler(gdal_error_handler)

def safe_gdal_operation(func):
    """Decorator wrapping GDAL operations with uniform error handling."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except GDALError as e:
            logger.error(f"GDAL error: {e}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error: {e}")
            logger.debug(traceback.format_exc())
            raise ProcessingError(str(e)) from e
    return wrapper

@safe_gdal_operation
def safe_open_raster(filepath):
    """Open a raster, raising a descriptive error on failure."""
    ds = gdal.Open(filepath)
    if ds is None:
        raise RasterOpenError(f"Cannot open raster: {filepath}")
    return ds

def validate_raster(filepath):
    """Validate a raster file; returns a list of problems found."""
    errors = []
    try:
        ds = gdal.Open(filepath)
        if ds is None:
            errors.append("Cannot open file")
            return errors
        # Check dimensions
        if ds.RasterXSize <= 0 or ds.RasterYSize <= 0:
            errors.append("Invalid raster dimensions")
        # Check bands
        if ds.RasterCount <= 0:
            errors.append("No bands")
        # Check projection
        if not ds.GetProjection():
            errors.append("Missing projection")
        # Check geotransform (GDAL's default indicates none was set)
        gt = ds.GetGeoTransform()
        if gt == (0, 1, 0, 0, 0, 1):
            errors.append("Missing geotransform")
        ds = None
    except Exception as e:
        errors.append(f"Validation failed: {e}")
    return errors
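A short usage sketch for validate_raster; the file name is illustrative:

errors = validate_raster('input.tif')
if errors:
    for problem in errors:
        logger.warning(f"input.tif: {problem}")
else:
    logger.info("input.tif passed validation")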
14.6.3 Configuration Management
"""
config.py - GDAL 配置管理
"""
import os
import json
from dataclasses import dataclass
from typing import Optional
@dataclass
class GDALConfig:
"""GDAL 配置类"""
cache_size: int = 512 * 1024 * 1024 # 512 MB
num_threads: int = 0 # 0 表示 ALL_CPUS
compression: str = 'LZW'
tiled: bool = True
block_size: int = 256
# 云端访问
http_timeout: int = 30
http_max_connections: int = 10
curl_cache_size: int = 100 * 1024 * 1024 # 100 MB
# 调试
debug: bool = False
log_file: Optional[str] = None
def apply(self):
"""应用配置"""
from osgeo import gdal
# 缓存
gdal.SetCacheMax(self.cache_size)
# 线程
if self.num_threads > 0:
gdal.SetConfigOption('GDAL_NUM_THREADS', str(self.num_threads))
else:
gdal.SetConfigOption('GDAL_NUM_THREADS', 'ALL_CPUS')
# HTTP 配置
gdal.SetConfigOption('GDAL_HTTP_TIMEOUT', str(self.http_timeout))
gdal.SetConfigOption('GDAL_HTTP_MAX_CONNECTIONS', str(self.http_max_connections))
gdal.SetConfigOption('CPL_VSIL_CURL_CACHE_SIZE', str(self.curl_cache_size))
# 调试
if self.debug:
gdal.SetConfigOption('CPL_DEBUG', 'ON')
if self.log_file:
gdal.SetConfigOption('CPL_LOG', self.log_file)
def get_creation_options(self):
"""获取创建选项"""
options = []
if self.compression:
options.append(f'COMPRESS={self.compression}')
if self.compression in ['LZW', 'DEFLATE']:
options.append('PREDICTOR=2')
if self.tiled:
options.append('TILED=YES')
options.append(f'BLOCKXSIZE={self.block_size}')
options.append(f'BLOCKYSIZE={self.block_size}')
options.append('BIGTIFF=IF_SAFER')
return options
@classmethod
def from_file(cls, filepath):
"""从文件加载配置"""
with open(filepath, 'r') as f:
data = json.load(f)
return cls(**data)
def save(self, filepath):
"""保存配置到文件"""
with open(filepath, 'w') as f:
json.dump(self.__dict__, f, indent=2)
# 全局配置实例
_config = GDALConfig()
def get_config():
"""获取全局配置"""
return _config
def set_config(config: GDALConfig):
"""设置全局配置"""
global _config
_config = config
config.apply()
# 使用示例
# config = GDALConfig(cache_size=1024*1024*1024, debug=True)
# set_config(config)
14.7 Chapter Summary
This chapter covered GDAL performance optimization and best practices:
- Memory management: cache configuration, block-wise processing, in-memory datasets
- Parallel processing: multithreading configuration, process-level parallelism, tile-level parallelism
- I/O optimization: GeoTIFF tuning, cloud access, batch operations
- Algorithmic optimization: NumPy vectorization, spatial indexes
- Best practices: code organization, error handling, configuration management
14.8 Exercises
- Measure and compare the effect of different cache sizes on processing speed.
- Implement a batch-processing framework that adapts its degree of parallelism.
- Compare the compression ratio and processing speed of different compression algorithms.
- Write a performance benchmarking tool.
- Implement a large-file processor that supports resumable processing.
