第14章 - 性能优化与最佳实践

第14章:性能优化与最佳实践

14.1 概述

在处理大规模地理空间数据时,性能优化是不可忽视的重要环节。本章将介绍 GDAL 的性能优化技术和开发最佳实践。

14.1.1 性能瓶颈类型

| 类型 | 典型场景 | 优化方向 |
| --- | --- | --- |
| I/O 瓶颈 | 读写大文件、网络数据访问 | 缓存、分块、压缩 |
| CPU 瓶颈 | 复杂计算、格式转换 | 并行处理、算法优化 |
| 内存瓶颈 | 大型数据集处理 | 分块处理、内存映射 |
| 网络瓶颈 | 云端数据访问 | 缓存、HTTP 优化 |

14.1.2 性能优化原则

1. 测量优先:先找到瓶颈,再优化
2. 空间换时间:缓存、预计算
3. 时间换空间:压缩、分块处理
4. 并行化:利用多核 CPU
5. 减少 I/O:批量操作、减少打开/关闭次数

14.2 内存管理

14.2.1 GDAL 缓存配置

from osgeo import gdal

# Cap GDAL's global raster block cache, in bytes (1 << 30 == 1 GiB).
gdal.SetCacheMax(1 << 30)  # 1GB

# Read the configured ceiling back and report it in MiB.
current_cache = gdal.GetCacheMax()
print(f"当前缓存大小: {current_cache / (1024 * 1024):.0f} MB")

# Report how much of the block cache is currently occupied, in MiB.
used_cache = gdal.GetCacheUsed()
print(f"已使用缓存: {used_cache / (1024 * 1024):.0f} MB")

# 根据系统内存动态设置缓存
import os

def set_optimal_cache(fraction=0.25, fallback_size=512 * 1024 * 1024):
    """Size GDAL's block cache relative to the machine's physical memory.

    Args:
        fraction: Share of total physical RAM given to the GDAL cache
            (default 0.25, i.e. 25%). Must be in (0, 1].
        fallback_size: Cache size in bytes used when ``psutil`` is not
            installed and total memory therefore cannot be queried
            (default 512 MB).

    Returns:
        The cache size in bytes passed to ``gdal.SetCacheMax``.

    Raises:
        ValueError: If ``fraction`` is outside (0, 1] or ``fallback_size``
            is not positive.
    """
    if not 0 < fraction <= 1:
        raise ValueError(f"fraction must be in (0, 1], got {fraction!r}")
    if fallback_size <= 0:
        raise ValueError(f"fallback_size must be positive, got {fallback_size!r}")

    try:
        import psutil
    except ImportError:
        # psutil is optional; without it RAM cannot be queried, so use the fallback.
        cache_size = int(fallback_size)
    else:
        total_memory = psutil.virtual_memory().total
        cache_size = int(total_memory * fraction)

    gdal.SetCacheMax(cache_size)
    print(f"设置缓存: {cache_size / 1024 / 1024:.0f} MB")
    return cache_size

# 使用示例
# set_optimal_cache()

14.2.2 分块读取大文件

from osgeo import gdal
import numpy as np

def read_raster_by_blocks(filepath, processing_func, block_size=512, band_index=1):
    """Read one raster band window by window and hand each window to a callback.

    Only one ``block_size`` x ``block_size`` window is held in memory at a
    time, so memory use does not grow with the raster size.

    Args:
        filepath: Path of the raster to read.
        processing_func: Callable invoked as ``processing_func(data, x_off, y_off)``
            for every window. ``data`` is the 2-D array returned by
            ``Band.ReadAsArray``. Edge windows may be smaller than ``block_size``.
        block_size: Window edge length in pixels. A multiple of the file's
            native block size avoids decoding the same block repeatedly.
        band_index: 1-based index of the band to read (default 1).

    Returns:
        A list of ``processing_func``'s return values, one per window, in
        row-major window order.

    Raises:
        ValueError: If ``block_size`` is not positive.
        RuntimeError: Raised by GDAL (exceptions are enabled) when the file or
            band cannot be opened or read.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size!r}")

    gdal.UseExceptions()

    ds = gdal.Open(filepath)
    band = None
    results = []
    try:
        band = ds.GetRasterBand(band_index)

        width = ds.RasterXSize
        height = ds.RasterYSize

        # Ceiling division on each axis gives the number of windows.
        total_blocks = ((height + block_size - 1) // block_size) * \
                       ((width + block_size - 1) // block_size)
        processed = 0

        for y_off in range(0, height, block_size):
            for x_off in range(0, width, block_size):
                # Clip the window at the right/bottom raster edge.
                x_size = min(block_size, width - x_off)
                y_size = min(block_size, height - y_off)

                data = band.ReadAsArray(x_off, y_off, x_size, y_size)
                results.append(processing_func(data, x_off, y_off))

                processed += 1
                if processed % 100 == 0:
                    print(f"进度: {processed}/{total_blocks} ({processed/total_blocks*100:.1f}%)")
    finally:
        # Drop both handles (the band keeps its dataset alive) so GDAL closes
        # the file even when processing_func raises.
        band = None
        ds = None

    print("处理完成")
    return results

def write_raster_by_blocks(filepath, generator_func, width, height,
                           geotransform, projection, block_size=512,
                           data_type=None, nodata=None):
    """Create a single-band tiled GeoTIFF and fill it window by window.

    Args:
        filepath: Output path.
        generator_func: Callable ``generator_func(x_off, y_off, x_size, y_size)``
            returning a ``(y_size, x_size)`` array for that window. Edge
            windows are smaller than ``block_size``.
        width, height: Raster size in pixels.
        geotransform: 6-element geotransform, or None to leave it unset.
        projection: WKT projection string, or None to leave it unset.
        block_size: Window size and GeoTIFF tile size. The TIFF format
            requires tile dimensions to be multiples of 16.
        data_type: GDAL pixel type (default ``gdal.GDT_Float32``).
        nodata: Optional nodata value recorded on the band.

    Raises:
        ValueError: If ``block_size`` is not positive or ``generator_func``
            returns an array of the wrong shape.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size!r}")

    gdal.UseExceptions()

    if data_type is None:
        data_type = gdal.GDT_Float32

    driver = gdal.GetDriverByName('GTiff')

    ds = driver.Create(
        filepath, width, height, 1, data_type,
        options=[
            'COMPRESS=LZW',
            'TILED=YES',
            f'BLOCKXSIZE={block_size}',
            f'BLOCKYSIZE={block_size}',
            'BIGTIFF=IF_SAFER'
        ]
    )

    band = None
    try:
        if geotransform is not None:
            ds.SetGeoTransform(geotransform)
        if projection is not None:
            ds.SetProjection(projection)

        band = ds.GetRasterBand(1)
        if nodata is not None:
            band.SetNoDataValue(nodata)

        for y_off in range(0, height, block_size):
            for x_off in range(0, width, block_size):
                x_size = min(block_size, width - x_off)
                y_size = min(block_size, height - y_off)

                data = np.asarray(generator_func(x_off, y_off, x_size, y_size))
                # A smaller array would silently leave part of the window
                # unwritten; a larger one fails inside GDAL less clearly.
                if data.shape != (y_size, x_size):
                    raise ValueError(
                        f"generator_func returned shape {data.shape} for window "
                        f"({x_off}, {y_off}), expected {(y_size, x_size)}")

                band.WriteArray(data, x_off, y_off)

        ds.FlushCache()
    finally:
        # Release handles so the file is closed even if generation fails
        # (a partially written file is left on disk in that case).
        band = None
        ds = None

    print(f"写入完成: {filepath}")

# 使用示例
def example_processing(data, x_off, y_off):
    """Sample callback for ``read_raster_by_blocks``: the mean of one window.

    ``x_off`` and ``y_off`` belong to the callback signature but are not used here.
    """
    return np.mean(data)

# read_raster_by_blocks('large_file.tif', example_processing)

14.2.3 内存数据集

from osgeo import gdal
import numpy as np

def create_mem_dataset(data, geotransform=None, projection=None):
    """Wrap a NumPy array in an in-memory (``MEM`` driver) GDAL dataset.

    Useful for feeding arrays to GDAL algorithms without disk I/O.

    Args:
        data: ``(rows, cols)`` or ``(bands, rows, cols)`` array, or anything
            ``np.asarray`` accepts. Boolean arrays are stored as Byte 0/1.
        geotransform: Optional geotransform; skipped when falsy.
        projection: Optional WKT string; skipped when falsy.

    Returns:
        The new ``gdal.Dataset``. Its bands hold a copy of ``data``.

    Raises:
        ValueError: If ``data`` is not 2-D or 3-D.
    """
    data = np.asarray(data)
    if data.ndim == 2:
        data = data[np.newaxis, ...]
    elif data.ndim != 3:
        raise ValueError(f"data must be 2-D or 3-D, got {data.ndim}-D")
    bands, height, width = data.shape

    if data.dtype == np.bool_:
        # GDAL has no boolean pixel type; store masks as 0/1 bytes.
        data = data.astype(np.uint8)

    gdal.UseExceptions()

    # NumPy dtype -> GDAL pixel type.
    dtype_map = {
        np.dtype('uint8'): gdal.GDT_Byte,
        np.dtype('uint16'): gdal.GDT_UInt16,
        np.dtype('int16'): gdal.GDT_Int16,
        np.dtype('uint32'): gdal.GDT_UInt32,
        np.dtype('int32'): gdal.GDT_Int32,
        np.dtype('float32'): gdal.GDT_Float32,
        np.dtype('float64'): gdal.GDT_Float64,
        # Without these, complex data fell back to Float64 and silently lost
        # its imaginary part.
        np.dtype('complex64'): gdal.GDT_CFloat32,
        np.dtype('complex128'): gdal.GDT_CFloat64,
    }
    # Pixel types that only exist in newer GDAL releases are mapped only when
    # this GDAL build defines them.
    for np_name, gdal_name in (('int8', 'GDT_Int8'),
                               ('int64', 'GDT_Int64'),
                               ('uint64', 'GDT_UInt64')):
        gdal_code = getattr(gdal, gdal_name, None)
        if gdal_code is not None:
            dtype_map[np.dtype(np_name)] = gdal_code

    gdal_type = dtype_map.get(data.dtype, gdal.GDT_Float64)

    driver = gdal.GetDriverByName('MEM')
    mem_ds = driver.Create('', width, height, bands, gdal_type)

    if geotransform:
        mem_ds.SetGeoTransform(geotransform)

    if projection:
        mem_ds.SetProjection(projection)

    for i in range(bands):
        mem_ds.GetRasterBand(i + 1).WriteArray(data[i])

    return mem_ds

def process_in_memory(input_path, processing_func):
    """Load a whole raster into RAM, transform it, and wrap the result in a MEM dataset.

    Args:
        input_path: Raster to read (all bands, via ``Dataset.ReadAsArray``).
        processing_func: Callable taking the source array and returning the
            processed array.

    Returns:
        ``(mem_ds, result)``: the MEM dataset built by ``create_mem_dataset``
        (with the source geotransform and projection) and the array returned
        by ``processing_func``.
    """
    gdal.UseExceptions()

    source = gdal.Open(input_path)
    array = source.ReadAsArray()
    gt = source.GetGeoTransform()
    wkt = source.GetProjection()
    # Release the file before the (possibly long) processing step.
    source = None

    processed = processing_func(array)
    return create_mem_dataset(processed, gt, wkt), processed

14.3 并行处理

14.3.1 多线程配置

from osgeo import gdal
import os

# Enable GDAL's built-in multithreading for components that honour GDAL_NUM_THREADS.
gdal.SetConfigOption('GDAL_NUM_THREADS', 'ALL_CPUS')

# Alternatively pin an explicit thread count (this call overrides the one above).
# os.cpu_count() returns None when the count cannot be determined, which would
# produce the invalid value 'None', so fall back to a single thread.
cpu_count = os.cpu_count() or 1
gdal.SetConfigOption('GDAL_NUM_THREADS', str(cpu_count))

# Warp 操作的多线程
def warp_with_threads(src_path, dst_path, dst_srs, num_threads='ALL_CPUS',
                      warp_memory_limit=500):
    """Reproject a raster to a compressed, tiled GeoTIFF using multithreaded warping.

    ``multithread=True`` (gdalwarp ``-multi``) only overlaps I/O with
    computation. The resampling itself is parallelised by the ``NUM_THREADS``
    warp option, so both are set here.

    Args:
        src_path: Input raster path (or an open dataset).
        dst_path: Output GeoTIFF path.
        dst_srs: Target SRS in any form GDAL accepts (e.g. ``'EPSG:4326'``).
        num_threads: Value of the ``NUM_THREADS`` warp option (``'ALL_CPUS'``
            or a thread count). ``None`` omits the option.
        warp_memory_limit: Working buffer passed to ``-wm`` (MB for values
            below 10000).

    Raises:
        RuntimeError: If ``gdal.Warp`` reports failure by returning None.
    """
    warp_opts = [] if num_threads is None else [f'NUM_THREADS={num_threads}']

    warp_options = gdal.WarpOptions(
        format='GTiff',
        dstSRS=dst_srs,
        multithread=True,
        warpMemoryLimit=warp_memory_limit,
        warpOptions=warp_opts,
        creationOptions=['COMPRESS=LZW', 'TILED=YES']
    )

    out_ds = gdal.Warp(dst_path, src_path, options=warp_options)
    if out_ds is None:
        raise RuntimeError(f"gdal.Warp failed: {src_path} -> {dst_path}")
    # Dropping the last reference closes the dataset and flushes it to disk.
    out_ds = None

14.3.2 进程级并行

from osgeo import gdal
import numpy as np
from concurrent.futures import ProcessPoolExecutor
import multiprocessing
from functools import partial

def process_single_file(args):
    """Process one raster file; designed to run inside a worker process.

    Args:
        args: A single ``(filepath, output_path, func_name)`` tuple, so the
            function can be passed straight to ``ProcessPoolExecutor.map``.
            ``func_name`` is ``'normalize'`` (min-max scaling to [0, 1]),
            ``'sqrt'`` (square root of the non-negative part) or anything else
            (values copied unchanged).

    Returns:
        ``(True, filepath)`` on success or ``(False, "<filepath>: <error>")``
        on failure. Errors are returned rather than raised so that one bad
        file does not abort the whole batch.
    """
    filepath, output_path, func_name = args

    # Import inside the worker so the function is self-contained when the
    # pool uses the "spawn" start method.
    from osgeo import gdal
    import numpy as np
    gdal.UseExceptions()

    ds = None
    out_ds = None
    try:
        ds = gdal.Open(filepath)
        # Compute in float64: with integer input, ``max - min`` can overflow
        # the source dtype (e.g. int16), and the output is float anyway.
        data = ds.ReadAsArray().astype(np.float64)

        if func_name == 'normalize':
            # nanmin/nanmax so a single NaN pixel does not make everything NaN.
            lo = np.nanmin(data)
            span = np.nanmax(data) - lo
            if span > 0:
                result = (data - lo) / span
            else:
                # Constant raster: avoid 0/0 and emit zeros.
                result = np.zeros_like(data)
        elif func_name == 'sqrt':
            result = np.sqrt(np.maximum(data, 0))
        else:
            result = data

        # Dataset.WriteRaster interprets the raw buffer using the band type
        # (Float32 below), so the bytes must be float32 as well; float64 or
        # integer bytes would be misinterpreted.
        result = np.ascontiguousarray(result, dtype=np.float32)

        driver = gdal.GetDriverByName('GTiff')
        out_ds = driver.Create(
            output_path,
            ds.RasterXSize,
            ds.RasterYSize,
            ds.RasterCount,
            gdal.GDT_Float32,
            ['COMPRESS=LZW']
        )
        out_ds.SetGeoTransform(ds.GetGeoTransform())
        out_ds.SetProjection(ds.GetProjection())
        out_ds.WriteRaster(0, 0, ds.RasterXSize, ds.RasterYSize, result.tobytes())
        # Flush inside the try so write errors are reported as a failure.
        out_ds.FlushCache()

        return True, filepath
    except Exception as e:
        return False, f"{filepath}: {e}"
    finally:
        # Drop the handles so GDAL closes both files, also on failure.
        ds = None
        out_ds = None

def parallel_process_files(input_files, output_dir, func_name, workers=None):
    """Run ``process_single_file`` over many files in a process pool.

    Outputs are written as ``<output_dir>/processed_<basename>``. On platforms
    that start workers with "spawn" (Windows, macOS), call this from code
    guarded by ``if __name__ == "__main__":``.

    Args:
        input_files: Iterable of input raster paths.
        output_dir: Output directory (created if missing).
        func_name: Operation name forwarded to ``process_single_file``.
        workers: Number of worker processes (default: CPU count).

    Returns:
        The list of ``(ok, info)`` tuples from ``process_single_file``, in
        input order.
    """
    import os

    if workers is None:
        workers = multiprocessing.cpu_count()

    # Workers write straight into output_dir. Create it up front; otherwise
    # every task fails with "No such file or directory".
    os.makedirs(output_dir, exist_ok=True)

    tasks = [
        (filepath,
         os.path.join(output_dir, f"processed_{os.path.basename(filepath)}"),
         func_name)
        for filepath in input_files
    ]

    with ProcessPoolExecutor(max_workers=workers) as executor:
        results = list(executor.map(process_single_file, tasks))

    success = sum(1 for ok, _ in results if ok)
    failed = len(results) - success

    print(f"并行处理完成: 成功 {success}, 失败 {failed}")

    return results

# 使用示例
# input_files = ['file1.tif', 'file2.tif', 'file3.tif']
# parallel_process_files(input_files, './output', 'normalize')

14.3.3 瓦片级并行

from osgeo import gdal
import numpy as np
from concurrent.futures import ProcessPoolExecutor
import multiprocessing

def process_tile(args):
    """Read one tile of band 1 and evaluate ``func_code`` on it (worker side).

    Args:
        args: ``(src_path, x_off, y_off, x_size, y_size, func_code)``.
            ``func_code`` is a Python expression evaluated with the tile
            array bound to ``data`` and NumPy available as ``np``.

    Returns:
        ``(x_off, y_off, result)``, so the parent can place the tile.
    """
    src_path, x_off, y_off, x_size, y_size, func_code = args

    from osgeo import gdal
    import numpy as np
    # Without exceptions a failed Open returns None and the next line would
    # fail with an unhelpful AttributeError.
    gdal.UseExceptions()

    ds = gdal.Open(src_path)
    try:
        data = ds.GetRasterBand(1).ReadAsArray(x_off, y_off, x_size, y_size)
    finally:
        ds = None

    # SECURITY: eval() executes arbitrary Python code. func_code must come
    # from a trusted source, never from user/network input.
    result = eval(func_code)

    return x_off, y_off, result

def parallel_process_raster(src_path, dst_path, func_code,
                            tile_size=512, workers=None):
    """Evaluate ``func_code`` on every tile of band 1 in parallel and write a Float32 GeoTIFF.

    Args:
        src_path: Input raster path (band 1 is processed).
        dst_path: Output GeoTIFF path (LZW-compressed, tiled).
        func_code: Python expression evaluated per tile by ``process_tile``
            with the tile bound to ``data``. It is passed to ``eval``, so only
            use trusted code.
        tile_size: Tile edge length in pixels. It is also the GeoTIFF block
            size, which the TIFF format requires to be a multiple of 16.
        workers: Number of worker processes (default: CPU count).

    Note:
        Each evaluation sees only one tile, so an expression such as
        ``(data - data.min()) / (data.max() - data.min())`` normalises *per
        tile* and produces visible seams. Compute global statistics first and
        embed them as constants in ``func_code``.
    """
    gdal.UseExceptions()

    if workers is None:
        workers = multiprocessing.cpu_count()

    src_ds = gdal.Open(src_path)
    width = src_ds.RasterXSize
    height = src_ds.RasterYSize
    geotransform = src_ds.GetGeoTransform()
    projection = src_ds.GetProjection()
    src_ds = None

    tasks = []
    for y in range(0, height, tile_size):
        for x in range(0, width, tile_size):
            x_size = min(tile_size, width - x)
            y_size = min(tile_size, height - y)
            tasks.append((src_path, x, y, x_size, y_size, func_code))

    print(f"总瓦片数: {len(tasks)}, 使用 {workers} 个进程")

    # Create the output first so each tile is written as soon as its result
    # arrives. Collecting every result in a list before writing would need
    # memory for the entire output raster at once.
    driver = gdal.GetDriverByName('GTiff')
    dst_ds = driver.Create(
        dst_path, width, height, 1, gdal.GDT_Float32,
        ['COMPRESS=LZW', 'TILED=YES', f'BLOCKXSIZE={tile_size}', f'BLOCKYSIZE={tile_size}']
    )
    dst_band = None
    try:
        dst_ds.SetGeoTransform(geotransform)
        dst_ds.SetProjection(projection)
        dst_band = dst_ds.GetRasterBand(1)

        with ProcessPoolExecutor(max_workers=workers) as executor:
            for x_off, y_off, result in executor.map(process_tile, tasks):
                dst_band.WriteArray(result, x_off, y_off)

        dst_ds.FlushCache()
    finally:
        # Close the output even on failure (a partial file stays on disk).
        dst_band = None
        dst_ds = None

    print(f"并行处理完成: {dst_path}")

# 使用示例
# 归一化处理
# func_code = "(data - data.min()) / (data.max() - data.min() + 1e-10)"
# parallel_process_raster('input.tif', 'output.tif', func_code)

14.4 I/O 优化

14.4.1 GeoTIFF 优化

from osgeo import gdal

def create_optimized_geotiff(input_path, output_path,
                             compress='LZW',
                             tiled=True,
                             block_size=256,
                             overview=True,
                             resampling='AVERAGE',
                             overview_levels=(2, 4, 8, 16, 32)):
    """Rewrite a raster as a compressed, tiled GeoTIFF with overviews.

    Args:
        input_path: Source raster path (or open dataset) for ``gdal.Translate``.
        output_path: Destination GeoTIFF path.
        compress: GeoTIFF ``COMPRESS`` value (case-insensitive), or a falsy
            value to omit the option. LZW/DEFLATE/ZSTD get ``PREDICTOR=2``
            (horizontal differencing, suited to integer data; floating-point
            data usually compresses better with ``PREDICTOR=3``). JPEG gets
            ``JPEG_QUALITY=90`` and generally requires 8-bit (Byte) data.
        tiled: Write a tiled rather than stripped file.
        block_size: Tile edge length; TIFF requires a multiple of 16.
        overview: Build overviews after writing.
        resampling: Overview resampling method. ``'AVERAGE'`` suits
            continuous data; use ``'NEAREST'`` or ``'MODE'`` for categorical
            rasters, whose class codes must not be averaged.
        overview_levels: Decimation factors of the overview levels.
    """
    gdal.UseExceptions()

    options = []

    if compress:
        compress = compress.upper()
        options.append(f'COMPRESS={compress}')

        if compress in ('LZW', 'DEFLATE', 'ZSTD'):
            # A predictor usually improves the compression ratio.
            options.append('PREDICTOR=2')
        elif compress == 'JPEG':
            options.append('JPEG_QUALITY=90')

    if tiled:
        options += ['TILED=YES', f'BLOCKXSIZE={block_size}', f'BLOCKYSIZE={block_size}']

    # Switch to BigTIFF when the output might exceed 4 GB.
    options.append('BIGTIFF=IF_SAFER')

    translate_options = gdal.TranslateOptions(
        format='GTiff',
        creationOptions=options
    )

    ds = gdal.Translate(output_path, input_path, options=translate_options)
    try:
        if overview and overview_levels:
            ds.BuildOverviews(resampling, list(overview_levels))
    finally:
        ds = None

    print(f"优化的 GeoTIFF 创建完成: {output_path}")

def compare_compression_methods(input_path, output_dir, methods=None):
    """Write the input once per GeoTIFF compression method and report size and time.

    Args:
        input_path: Raster to convert.
        output_dir: Directory for the ``compressed_<METHOD>.tif`` outputs
            (created if missing).
        methods: ``COMPRESS`` values to try; ``'NONE'`` means uncompressed.
            Defaults to NONE, LZW, DEFLATE, PACKBITS and ZSTD (ZSTD needs a
            GDAL build with ZSTD support).

    Returns:
        A list of ``{'method', 'size_mb', 'time_s'}`` dicts, one per method.

    Raises:
        RuntimeError: If ``gdal.Translate`` returns None for a method.
    """
    import os
    import time

    if methods is None:
        methods = ['NONE', 'LZW', 'DEFLATE', 'PACKBITS', 'ZSTD']

    os.makedirs(output_dir, exist_ok=True)
    results = []

    for method in methods:
        output_path = os.path.join(output_dir, f"compressed_{method}.tif")

        options = ['TILED=YES']
        if method != 'NONE':
            options.append(f'COMPRESS={method}')

        # perf_counter is monotonic and high-resolution; time.time() can jump
        # with wall-clock adjustments.
        start_time = time.perf_counter()

        out_ds = gdal.Translate(
            output_path, input_path,
            options=gdal.TranslateOptions(
                format='GTiff',
                creationOptions=options
            )
        )
        if out_ds is None:
            raise RuntimeError(f"gdal.Translate failed for COMPRESS={method}")
        # Close inside the timed region: closing flushes the remaining data,
        # and the size read below must reflect the finished file.
        out_ds = None

        elapsed = time.perf_counter() - start_time
        file_size = os.path.getsize(output_path)

        results.append({
            'method': method,
            'size_mb': file_size / 1024 / 1024,
            'time_s': elapsed
        })

        print(f"{method}: {file_size/1024/1024:.1f} MB, {elapsed:.2f}s")

    return results

14.4.2 云端数据访问优化

from osgeo import gdal

def configure_cloud_access(aws_no_sign_request=True, aws_region='us-west-2',
                           disable_readdir_on_open=False):
    """Set process-wide GDAL config options for reading remote (HTTP/S3) data.

    All options are global. ``AWS_NO_SIGN_REQUEST=YES`` makes *every*
    /vsis3/ request anonymous, which breaks access to private buckets, so it
    can be switched off.

    Args:
        aws_no_sign_request: True (default) sends unsigned S3 requests, which
            works for public datasets only. False sets the option to ``NO`` so
            normal credential lookup applies.
        aws_region: Value for ``AWS_REGION``; None leaves it untouched.
        disable_readdir_on_open: If True, set
            ``GDAL_DISABLE_READDIR_ON_OPEN=EMPTY_DIR``. Opening a remote file
            then does not list its directory looking for sidecar files. This
            saves a round trip per open, but external sidecars (.aux.xml,
            .ovr) are ignored.
    """
    # HTTP caching.
    # NOTE(review): confirm CPL_VSIL_CURL_USE_CACHE and GDAL_HTTP_MAX_CONNECTIONS
    # against the configuration-options list of the GDAL version in use.
    gdal.SetConfigOption('CPL_VSIL_CURL_CACHE_SIZE', '100000000')  # ~100 MB
    gdal.SetConfigOption('CPL_VSIL_CURL_USE_CACHE', 'YES')

    # Connection pool
    gdal.SetConfigOption('GDAL_HTTP_MAX_CONNECTIONS', '10')

    # Timeout in seconds
    gdal.SetConfigOption('GDAL_HTTP_TIMEOUT', '30')

    # S3
    gdal.SetConfigOption('AWS_NO_SIGN_REQUEST', 'YES' if aws_no_sign_request else 'NO')
    if aws_region is not None:
        gdal.SetConfigOption('AWS_REGION', aws_region)

    if disable_readdir_on_open:
        gdal.SetConfigOption('GDAL_DISABLE_READDIR_ON_OPEN', 'EMPTY_DIR')

    # GCS public data:
    # gdal.SetConfigOption('GS_NO_SIGN_REQUEST', 'YES')

    print("云端访问配置完成")

def read_cog_efficiently(cog_url, bounds=None, resolution=None, bounds_srs=None):
    """Open a Cloud Optimized GeoTIFF, optionally cutting a window out into memory.

    Args:
        cog_url: ``http(s)://``, ``s3://`` or ``gs://`` URL, or a path GDAL
            already understands (``/vsicurl/...``, ``/vsis3/...``, a local
            file). URL schemes are rewritten to the matching GDAL virtual file
            system.
        bounds: Optional ``(min_x, min_y, max_x, max_y)`` window. When given,
            the window is warped into a MEM dataset, so typically only the
            remote blocks covering it are fetched.
        resolution: Optional output pixel size (both axes) for the window, in
            output-CRS units. gdalwarp picks a suitable overview by default,
            so coarse values avoid reading full-resolution tiles.
        bounds_srs: CRS that ``bounds`` are expressed in (e.g.
            ``'EPSG:4326'``). Without it, bounds must be in the raster's CRS.

    Returns:
        The opened remote dataset, or the MEM window when ``bounds`` is given.

    Raises:
        RuntimeError: If the dataset cannot be opened.
    """
    gdal.UseExceptions()
    configure_cloud_access()

    # Map URL schemes onto GDAL virtual file systems. /vsi*/ paths and local
    # files pass through unchanged.
    if cog_url.startswith(('http://', 'https://')):
        cog_url = f'/vsicurl/{cog_url}'
    elif cog_url.startswith('s3://'):
        cog_url = '/vsis3/' + cog_url[len('s3://'):]
    elif cog_url.startswith('gs://'):
        cog_url = '/vsigs/' + cog_url[len('gs://'):]

    ds = gdal.Open(cog_url)

    if ds is None:
        raise RuntimeError(f"无法打开: {cog_url}")

    if bounds:
        warp_options = gdal.WarpOptions(
            outputBounds=bounds,
            outputBoundsSRS=bounds_srs,
            xRes=resolution,
            yRes=resolution,
            format='MEM'  # keep the cut-out in memory
        )

        ds = gdal.Warp('', ds, options=warp_options)

    return ds

def download_range(url, bounds, output_path, resolution=None):
    """Fetch only the part of a remote raster inside *bounds* and save it locally.

    Args:
        url: Remote raster URL, in any form ``read_cog_efficiently`` accepts.
        bounds: ``(min_x, min_y, max_x, max_y)`` in the raster's CRS.
        output_path: Destination GeoTIFF (LZW-compressed).
        resolution: Optional output pixel size, forwarded to
            ``read_cog_efficiently``.
    """
    gdal.UseExceptions()

    # read_cog_efficiently() configures cloud access itself.
    src_ds = read_cog_efficiently(url, bounds, resolution)
    if src_ds is None:
        return

    out_ds = gdal.Translate(output_path, src_ds, options=gdal.TranslateOptions(
        format='GTiff',
        creationOptions=['COMPRESS=LZW']
    ))
    # Release both datasets so the output is flushed and closed before reporting.
    out_ds = None
    src_ds = None
    print(f"下载完成: {output_path}")

# 使用示例
# url = "https://example.com/cog.tif"
# bounds = (116, 39, 117, 40)  # 北京区域
# download_range(url, bounds, 'beijing.tif')

14.4.3 批量 I/O 优化

from osgeo import gdal, ogr

class BatchDatasetManager:
    """Cache of open GDAL datasets with least-recently-used eviction.

    Opening a dataset is often much more expensive than reading a few pixels
    from it, so repeated requests for the same path reuse one handle. At most
    ``max_open`` datasets are cached. Use the manager as a context manager so
    every cached reference is dropped on exit.

    Evicting an entry only drops this cache's reference. GDAL closes the
    dataset once no caller still holds one.
    """

    def __init__(self, max_open=100):
        if max_open < 1:
            raise ValueError(f"max_open must be at least 1, got {max_open!r}")
        self.max_open = max_open
        # path -> dataset. Dicts keep insertion order, so the first key is
        # always the least recently used entry (O(1) LRU bookkeeping).
        self.datasets = {}
        # path -> access mode ('r' or 'w') the cached dataset was opened with.
        self._modes = {}

    @property
    def access_order(self):
        """Cached paths, least recently used first (a fresh list)."""
        return list(self.datasets)

    def get_dataset(self, filepath, mode='r'):
        """Return an open dataset for *filepath*, opening it if needed.

        Args:
            filepath: Raster path.
            mode: ``'w'`` for update access; anything else opens read-only.

        Returns:
            The ``gdal.Dataset``, or None if ``gdal.Open`` returned None (only
            when GDAL exceptions are disabled). Failed opens are not cached.
        """
        want_update = mode == 'w'

        if filepath in self.datasets:
            if self._modes[filepath] == 'w' or not want_update:
                # Re-insert to mark the entry as most recently used.
                ds = self.datasets.pop(filepath)
                self.datasets[filepath] = ds
                return ds
            # The cached handle is read-only but update access was requested:
            # reopen instead of returning a read-only dataset.
            self._discard(filepath)

        while len(self.datasets) >= self.max_open:
            self._discard(next(iter(self.datasets)))

        access = gdal.GA_Update if want_update else gdal.GA_ReadOnly
        ds = gdal.Open(filepath, access)
        if ds is None:
            return None

        self.datasets[filepath] = ds
        self._modes[filepath] = 'w' if want_update else 'r'
        return ds

    def _discard(self, filepath):
        """Remove *filepath* from the cache, dropping our dataset reference."""
        self.datasets.pop(filepath, None)
        self._modes.pop(filepath, None)

    def close_all(self):
        """Drop every cached dataset reference and reset the LRU order."""
        self.datasets.clear()
        self._modes.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close_all()
        return False

def batch_read_values(filepaths, x, y, band_index=1):
    """Read the pixel value at map coordinate ``(x, y)`` from each raster.

    Assumes north-up rasters: the rotation terms ``gt[2]`` and ``gt[4]`` are
    ignored.

    Args:
        filepaths: Raster paths to sample.
        x, y: Coordinate in each raster's CRS.
        band_index: 1-based band to sample (default 1).

    Returns:
        One value per path: the pixel value, or None when the file could not
        be opened or the point lies outside the raster.
    """
    import math

    values = []

    with BatchDatasetManager(max_open=50) as manager:
        for filepath in filepaths:
            ds = manager.get_dataset(filepath)
            if ds is None:
                values.append(None)
                continue

            gt = ds.GetGeoTransform()
            # floor(), not int(): int() truncates toward zero, so points just
            # left of or above the raster would map to column/row 0.
            px = math.floor((x - gt[0]) / gt[1])
            py = math.floor((y - gt[3]) / gt[5])

            if not (0 <= px < ds.RasterXSize and 0 <= py < ds.RasterYSize):
                values.append(None)
                continue

            band = ds.GetRasterBand(band_index)
            values.append(band.ReadAsArray(px, py, 1, 1)[0, 0])

    return values

14.5 算法优化

14.5.1 使用 NumPy 向量化

from osgeo import gdal
import numpy as np

def ndvi_vectorized(nir_path, red_path, output_path, nodata=-9999):
    """Compute NDVI = (NIR - Red) / (NIR + Red) with vectorised NumPy and save it.

    Args:
        nir_path: Raster whose band 1 holds near-infrared values.
        red_path: Raster whose band 1 holds red values. It must have the same
            pixel dimensions as ``nir_path``.
        output_path: Output GeoTIFF (Float32, LZW), georeferenced like the
            NIR input.
        nodata: Value written where NDVI is undefined (NIR + Red <= 0, or
            either input equals its band's nodata value). It is also recorded
            as the output band's nodata value.

    Raises:
        ValueError: If the two inputs differ in size.
    """
    gdal.UseExceptions()

    nir_ds = gdal.Open(nir_path)
    red_ds = gdal.Open(red_path)
    out_ds = None
    try:
        nir_band = nir_ds.GetRasterBand(1)
        red_band = red_ds.GetRasterBand(1)
        nir = nir_band.ReadAsArray().astype(np.float32)
        red = red_band.ReadAsArray().astype(np.float32)
        if nir.shape != red.shape:
            raise ValueError(
                f"NIR {nir.shape} and red {red.shape} rasters differ in size")

        denom = nir + red
        valid = denom > 0
        # Pixels flagged as nodata in either input are not valid reflectances.
        for band, values in ((nir_band, nir), (red_band, red)):
            src_nodata = band.GetNoDataValue()
            if src_nodata is not None:
                valid &= values != np.float32(src_nodata)

        # np.errstate restores the caller's floating-point error settings on
        # exit. A global np.seterr(...='warn') reset would clobber them.
        with np.errstate(invalid='ignore', divide='ignore'):
            ndvi = np.where(valid, (nir - red) / denom, nodata).astype(np.float32)

        driver = gdal.GetDriverByName('GTiff')
        out_ds = driver.Create(
            output_path,
            nir_ds.RasterXSize,
            nir_ds.RasterYSize,
            1,
            gdal.GDT_Float32,
            ['COMPRESS=LZW']
        )
        out_ds.SetGeoTransform(nir_ds.GetGeoTransform())
        out_ds.SetProjection(nir_ds.GetProjection())
        out_band = out_ds.GetRasterBand(1)
        out_band.WriteArray(ndvi)
        out_band.SetNoDataValue(nodata)
    finally:
        out_ds = None
        nir_ds = None
        red_ds = None

def zonal_statistics_optimized(raster_path, vector_path, band_index=1):
    """Per-feature statistics of one raster band over each feature's bounding box.

    NOTE: pixels are selected by the feature's *envelope*, not its exact
    outline. Rasterise the geometry for exact zonal statistics. Geometries
    are assumed to be in the raster's CRS and the raster north-up; neither is
    checked.

    Args:
        raster_path: Raster to summarise.
        vector_path: Vector dataset; its first layer supplies the zones.
        band_index: 1-based raster band (default 1).

    Returns:
        A list of dicts with ``fid``, ``count``, ``min``, ``max``, ``mean``,
        ``std`` and ``sum``. Features with no valid pixels, no geometry, or
        no overlap with the raster are skipped.
    """
    import math
    # This snippet's module imports only cover gdal and numpy.
    from osgeo import ogr

    gdal.UseExceptions()

    raster_ds = gdal.Open(raster_path)
    vector_ds = None
    band = None
    try:
        band = raster_ds.GetRasterBand(band_index)
        gt = raster_ds.GetGeoTransform()
        nodata = band.GetNoDataValue()
        n_cols = raster_ds.RasterXSize
        n_rows = raster_ds.RasterYSize

        vector_ds = ogr.Open(vector_path)
        layer = vector_ds.GetLayer()

        results = []

        for feature in layer:
            geom = feature.GetGeometryRef()
            if geom is None:
                continue

            min_x, max_x, min_y, max_y = geom.GetEnvelope()

            # floor/ceil so pixels only partly covered by the envelope are
            # kept; int() truncation dropped the right/bottom edge pixels.
            x_min = max(0, math.floor((min_x - gt[0]) / gt[1]))
            x_max = min(n_cols, math.ceil((max_x - gt[0]) / gt[1]))
            y_min = max(0, math.floor((max_y - gt[3]) / gt[5]))
            y_max = min(n_rows, math.ceil((min_y - gt[3]) / gt[5]))

            if x_min >= x_max or y_min >= y_max:
                continue

            # Read only this window instead of holding the whole raster (and,
            # for multi-band files, a 3-D array) in memory.
            subset = band.ReadAsArray(x_min, y_min, x_max - x_min, y_max - y_min)

            if nodata is None:
                valid_data = subset.ravel()
            elif math.isnan(nodata):
                # NaN != NaN, so a plain comparison would keep nodata pixels.
                valid_data = subset[~np.isnan(subset)]
            else:
                valid_data = subset[subset != nodata]

            if valid_data.size > 0:
                results.append({
                    'fid': feature.GetFID(),
                    'count': int(valid_data.size),
                    'min': float(np.min(valid_data)),
                    'max': float(np.max(valid_data)),
                    'mean': float(np.mean(valid_data)),
                    'std': float(np.std(valid_data)),
                    'sum': float(np.sum(valid_data))
                })
    finally:
        band = None
        raster_ds = None
        vector_ds = None

    return results

14.5.2 使用空间索引

from osgeo import ogr
import json

def create_spatial_index(shp_path):
    """Create a .qix spatial index for a shapefile.

    OGR's Python ``Layer`` API has no ``CreateSpatialIndex()`` method. The
    shapefile driver exposes index creation through the SQL command
    ``CREATE SPATIAL INDEX ON <layer>`` instead.

    Args:
        shp_path: Path of the shapefile (opened in update mode).
    """
    ogr.UseExceptions()

    ds = ogr.Open(shp_path, 1)  # update mode: the index is written next to the .shp
    try:
        layer_name = ds.GetLayer().GetName()
        result = ds.ExecuteSQL(f'CREATE SPATIAL INDEX ON {layer_name}')
        # The command normally returns no result set; release one if it does.
        if result is not None:
            ds.ReleaseResultSet(result)
    finally:
        ds = None

    print(f"空间索引创建完成: {shp_path}")

def spatial_query_with_index(shp_path, query_geom):
    """Return ``{'fid', 'geometry'}`` dicts for features passing a spatial filter.

    The filter is set with ``Layer.SetSpatialFilter``. Drivers that have a
    spatial index (e.g. a shapefile with a .qix file) can use it to skip
    non-matching features. The filter is cleared again before returning.

    Args:
        shp_path: Vector dataset path; its first layer is queried.
        query_geom: ``ogr.Geometry`` used as the spatial filter.

    Returns:
        A list of dicts with the feature id and its geometry as WKT.
    """
    ogr.UseExceptions()

    source = ogr.Open(shp_path)
    lyr = source.GetLayer()

    lyr.SetSpatialFilter(query_geom)
    matches = [
        {'fid': feat.GetFID(), 'geometry': feat.GetGeometryRef().ExportToWkt()}
        for feat in lyr
    ]
    lyr.SetSpatialFilter(None)

    source = None
    return matches

def build_rtree_index(features):
    """Build an in-memory ``rtree`` index over feature bounding boxes.

    Args:
        features: Sequence of dicts. A feature's ``'geometry'`` value is a WKT
            string; features without one are skipped.

    Returns:
        An ``rtree.index.Index`` whose ids are the positions in ``features``,
        or None when ``rtree`` is not installed.
    """
    try:
        from rtree import index
    except ImportError:
        print("请安装 rtree: pip install rtree")
        return None

    # Import once, outside the per-feature loop.
    from osgeo import ogr

    idx = index.Index()

    for i, feature in enumerate(features):
        wkt = feature.get('geometry')
        if not wkt:
            continue

        geom = ogr.CreateGeometryFromWkt(wkt)
        if geom is None:
            # Unparsable WKT (possible when OGR exceptions are disabled): skip it.
            continue

        # OGR envelopes are (min_x, max_x, min_y, max_y); rtree's default
        # interleaved layout is (min_x, min_y, max_x, max_y).
        min_x, max_x, min_y, max_y = geom.GetEnvelope()
        idx.insert(i, (min_x, min_y, max_x, max_y))

    return idx

def query_rtree_index(idx, bbox):
    """Return, as a list, the ids of indexed entries whose boxes intersect *bbox*.

    *bbox* uses the same coordinate layout the index was built with.
    """
    hits = idx.intersection(bbox)
    return list(hits)

14.6 最佳实践

14.6.1 代码组织

"""
gdal_utils.py - GDAL 工具模块最佳实践
"""

from osgeo import gdal, ogr, osr
import numpy as np
from pathlib import Path
from typing import Optional, Tuple, List, Union
from contextlib import contextmanager
import logging

# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 启用异常
gdal.UseExceptions()
ogr.UseExceptions()

# 默认配置
DEFAULT_CACHE_SIZE = 512 * 1024 * 1024  # 512 MB
DEFAULT_BLOCK_SIZE = 256
DEFAULT_COMPRESSION = 'LZW'

def init_gdal(cache_size=None, threads=None):
    """Apply this module's default GDAL settings: cache ceiling and thread count.

    Args:
        cache_size: Block-cache ceiling in bytes; a falsy value selects
            ``DEFAULT_CACHE_SIZE``.
        threads: Thread count for ``GDAL_NUM_THREADS``; a falsy value selects
            ``'ALL_CPUS'``.
    """
    gdal.SetCacheMax(cache_size if cache_size else DEFAULT_CACHE_SIZE)

    thread_setting = str(threads) if threads else 'ALL_CPUS'
    gdal.SetConfigOption('GDAL_NUM_THREADS', thread_setting)

    cache_mb = gdal.GetCacheMax() / 1024 / 1024
    logger.info(f"GDAL 初始化完成,缓存: {cache_mb:.0f} MB")

@contextmanager
def open_raster(filepath: str, mode: str = 'r'):
    """Open a raster for the duration of a ``with`` block and close it afterwards.

    Args:
        filepath: Raster path (``str`` or path-like; converted with ``str()``).
        mode: ``'w'`` opens in update mode; anything else opens read-only.

    Yields:
        The open ``gdal.Dataset``.

    Raises:
        IOError: If ``gdal.Open`` returns None (GDAL exceptions disabled).
    """
    access = gdal.GA_Update if mode == 'w' else gdal.GA_ReadOnly
    ds = gdal.Open(str(filepath), access)

    if ds is None:
        raise IOError(f"无法打开栅格: {filepath}")

    try:
        yield ds
    finally:
        ds.FlushCache()
        # Rebinding the local to None does not close the dataset while the
        # caller's ``as`` variable still references it. Dataset.Close() (newer
        # GDAL bindings) closes it deterministically.
        close = getattr(ds, 'Close', None)
        if close is not None:
            close()
        ds = None

@contextmanager
def open_vector(filepath: str, mode: str = 'r'):
    """Open a vector dataset for the duration of a ``with`` block and close it afterwards.

    Args:
        filepath: Vector path (``str`` or path-like; converted with ``str()``).
        mode: ``'w'`` opens for update; anything else opens read-only.

    Yields:
        The open OGR data source.

    Raises:
        IOError: If ``ogr.Open`` returns None.
    """
    update = 1 if mode == 'w' else 0
    ds = ogr.Open(str(filepath), update)

    if ds is None:
        raise IOError(f"无法打开矢量: {filepath}")

    try:
        yield ds
    finally:
        # The caller's ``as`` variable keeps the data source alive, so dropping
        # our reference alone does not close it. Use Close() when the bindings
        # provide it.
        close = getattr(ds, 'Close', None)
        if close is not None:
            close()
        ds = None

class RasterProcessor:
    """Context-managed helper that opens a raster, reads it and saves processed results.

    The dataset is opened read-only in ``__enter__`` and released in
    ``__exit__``. The reading and processing methods must be called inside
    the ``with`` block.
    """

    def __init__(self, filepath: str):
        self.filepath = Path(filepath)
        self._ds = None

    def __enter__(self):
        self._ds = gdal.Open(str(self.filepath))
        if self._ds is None:
            raise IOError(f"无法打开: {self.filepath}")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Compare with None explicitly instead of relying on the truthiness
        # of a SWIG dataset object.
        if self._ds is not None:
            self._ds = None
        return False

    @property
    def info(self) -> dict:
        """Size, band count, driver, geotransform and projection of the raster."""
        return {
            'width': self._ds.RasterXSize,
            'height': self._ds.RasterYSize,
            'bands': self._ds.RasterCount,
            'driver': self._ds.GetDriver().ShortName,
            'geotransform': self._ds.GetGeoTransform(),
            'projection': self._ds.GetProjection()
        }

    def read(self, band: int = 1) -> np.ndarray:
        """Return one band (1-based index) as a 2-D array."""
        return self._ds.GetRasterBand(band).ReadAsArray()

    def read_all(self) -> np.ndarray:
        """Return all bands (2-D for single-band rasters, else ``(bands, rows, cols)``)."""
        return self._ds.ReadAsArray()

    def process(self, func, output_path: str, **kwargs):
        """Apply ``func(data, **kwargs)`` to all bands and save the result to *output_path*."""
        data = self.read_all()
        result = func(data, **kwargs)

        self._save_result(result, output_path)

    def _save_result(self, data: np.ndarray, output_path: str):
        """Write *data* as a Float32 GeoTIFF georeferenced like the source raster.

        Args:
            data: ``(rows, cols)`` or ``(bands, rows, cols)`` array.
            output_path: Destination path.

        Raises:
            ValueError: If ``data`` is not 2-D or 3-D.
        """
        data = np.asarray(data)
        if data.ndim == 2:
            bands = 1
            height, width = data.shape
        elif data.ndim == 3:
            bands, height, width = data.shape
        else:
            raise ValueError(f"data must be 2-D or 3-D, got {data.ndim}-D")

        # Dataset.WriteRaster interprets the raw bytes using the band type
        # (Float32), so results of another dtype (e.g. float64) must be
        # converted first or they would be written as garbage.
        data = np.ascontiguousarray(data, dtype=np.float32)

        driver = gdal.GetDriverByName('GTiff')
        out_ds = driver.Create(
            output_path, width, height, bands,
            gdal.GDT_Float32,
            ['COMPRESS=LZW', 'TILED=YES']
        )
        try:
            out_ds.SetGeoTransform(self._ds.GetGeoTransform())
            out_ds.SetProjection(self._ds.GetProjection())

            if bands == 1:
                out_ds.GetRasterBand(1).WriteArray(data)
            else:
                out_ds.WriteRaster(0, 0, width, height, data.tobytes())
        finally:
            out_ds = None

        logger.info(f"结果保存到: {output_path}")

# 使用示例
# init_gdal()
# with RasterProcessor('input.tif') as processor:
#     print(processor.info)
#     processor.process(lambda x: x * 2, 'output.tif')

14.6.2 错误处理

from osgeo import gdal, ogr
import logging
import traceback
from functools import wraps

logger = logging.getLogger(__name__)

class GDALError(Exception):
    """Root of this module's GDAL-related exception hierarchy."""


class RasterOpenError(GDALError):
    """A raster dataset could not be opened."""


class VectorOpenError(GDALError):
    """A vector dataset could not be opened."""


class ProcessingError(GDALError):
    """An operation failed while processing data."""

def gdal_error_handler(err_class, err_num, err_msg):
    """Route GDAL diagnostics to Python ``logging``; raise on failures.

    Intended for ``gdal.PushErrorHandler(gdal_error_handler)``. ``CE_None``
    and ``CE_Debug`` messages are ignored. Warnings are logged at WARNING.
    Failures and fatal errors are logged at ERROR and then raised as
    ``GDALError``.

    NOTE(review): this raises from inside a callback invoked by GDAL's C code.
    Whether the exception reaches the Python caller depends on the GDAL Python
    bindings; confirm with the GDAL version in use.
    """
    err_types = {
        gdal.CE_None: 'None',
        gdal.CE_Debug: 'Debug',
        gdal.CE_Warning: 'Warning',
        gdal.CE_Failure: 'Failure',
        gdal.CE_Fatal: 'Fatal'
    }

    err_type = err_types.get(err_class, 'Unknown')

    if err_class >= gdal.CE_Failure:
        # Log failures at ERROR, not WARNING, so their severity is visible.
        logger.error(f"GDAL {err_type} ({err_num}): {err_msg}")
        raise GDALError(f"GDAL {err_type}: {err_msg}")

    if err_class >= gdal.CE_Warning:
        logger.warning(f"GDAL {err_type} ({err_num}): {err_msg}")

# 注册错误处理器
# gdal.PushErrorHandler(gdal_error_handler)

def safe_gdal_operation(func):
    """Decorator that logs failures of a GDAL operation and normalises its exceptions.

    ``GDALError`` and its subclasses are logged and re-raised unchanged. Any
    other exception is logged, with its traceback at DEBUG level, and
    re-raised as ``ProcessingError``. The new exception is chained to the
    original with ``from`` so the root cause stays in the traceback.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except GDALError as e:
            logger.error(f"GDAL 错误: {e}")
            raise
        except Exception as e:
            logger.error(f"未知错误: {e}")
            logger.debug(traceback.format_exc())
            raise ProcessingError(str(e)) from e

    return wrapper

@safe_gdal_operation
def safe_open_raster(filepath):
    """Open *filepath* with GDAL; raise RasterOpenError when GDAL returns None.

    Wrapped by ``safe_gdal_operation``, so errors are logged before they
    propagate.
    """
    dataset = gdal.Open(filepath)
    if dataset is not None:
        return dataset
    raise RasterOpenError(f"无法打开栅格: {filepath}")

def validate_raster(filepath):
    """Check a raster for common structural and georeferencing problems.

    Args:
        filepath: Raster to inspect.

    Returns:
        A list of problem descriptions; empty when nothing was found.
        Unexpected errors are reported as a list entry instead of being raised.
    """
    errors = []
    ds = None

    try:
        ds = gdal.Open(filepath)

        if ds is None:
            errors.append("无法打开文件")
            return errors

        if ds.RasterXSize <= 0 or ds.RasterYSize <= 0:
            errors.append("无效的栅格尺寸")

        if ds.RasterCount <= 0:
            errors.append("没有波段")

        if not ds.GetProjection():
            errors.append("缺少投影信息")

        # can_return_null=True yields None when the file has no geotransform.
        # Without it GDAL returns the default (0, 1, 0, 0, 0, 1), which is
        # indistinguishable from a real (if unusual) geotransform.
        gt = ds.GetGeoTransform(can_return_null=True)
        if gt is None:
            errors.append("缺少地理变换信息")
        elif gt[1] == 0 or gt[5] == 0:
            errors.append("无效的地理变换(像元尺寸为 0)")

    except Exception as e:
        errors.append(f"验证过程出错: {e}")
    finally:
        ds = None

    return errors

14.6.3 配置管理

"""
config.py - GDAL 配置管理
"""

import os
import json
from dataclasses import dataclass
from typing import Optional

@dataclass
class GDALConfig:
    """Bundle of GDAL runtime settings that can be applied, saved and loaded as JSON."""

    cache_size: int = 512 * 1024 * 1024  # block cache, bytes (512 MB)
    num_threads: int = 0  # GDAL_NUM_THREADS; values <= 0 mean ALL_CPUS
    compression: str = 'LZW'
    tiled: bool = True
    block_size: int = 256

    # Cloud / HTTP access
    http_timeout: int = 30  # seconds
    http_max_connections: int = 10
    curl_cache_size: int = 100 * 1024 * 1024  # bytes (100 MB)

    # Debugging
    debug: bool = False
    log_file: Optional[str] = None

    def apply(self):
        """Push these settings into GDAL's process-wide configuration.

        With ``debug`` False, CPL_DEBUG/CPL_LOG are left as they were.
        """
        from osgeo import gdal

        gdal.SetCacheMax(self.cache_size)

        if self.num_threads > 0:
            gdal.SetConfigOption('GDAL_NUM_THREADS', str(self.num_threads))
        else:
            gdal.SetConfigOption('GDAL_NUM_THREADS', 'ALL_CPUS')

        gdal.SetConfigOption('GDAL_HTTP_TIMEOUT', str(self.http_timeout))
        gdal.SetConfigOption('GDAL_HTTP_MAX_CONNECTIONS', str(self.http_max_connections))
        gdal.SetConfigOption('CPL_VSIL_CURL_CACHE_SIZE', str(self.curl_cache_size))

        if self.debug:
            gdal.SetConfigOption('CPL_DEBUG', 'ON')
            if self.log_file:
                gdal.SetConfigOption('CPL_LOG', self.log_file)

    def get_creation_options(self):
        """Return GeoTIFF creation options (list of ``KEY=VALUE``) for these settings."""
        options = []

        if self.compression:
            options.append(f'COMPRESS={self.compression}')
            if self.compression in ['LZW', 'DEFLATE']:
                options.append('PREDICTOR=2')

        if self.tiled:
            options.append('TILED=YES')
            options.append(f'BLOCKXSIZE={self.block_size}')
            options.append(f'BLOCKYSIZE={self.block_size}')

        options.append('BIGTIFF=IF_SAFER')

        return options

    @classmethod
    def from_file(cls, filepath):
        """Load a config from a UTF-8 JSON file written by :meth:`save`."""
        # Explicit encoding: the platform default (e.g. GBK on Chinese Windows)
        # cannot decode UTF-8 files containing non-ASCII paths.
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        return cls(**data)

    def save(self, filepath):
        """Write this config to *filepath* as UTF-8 JSON."""
        from dataclasses import asdict

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(asdict(self), f, indent=2, ensure_ascii=False)

# Process-wide configuration instance (defaults until set_config() is called).
_config = GDALConfig()

def get_config():
    """Return the current process-wide GDALConfig."""
    return _config

def set_config(config: GDALConfig):
    """Apply *config* to GDAL and make it the process-wide configuration.

    The config is applied before it replaces the global. If ``apply()``
    raises, the previous config object stays registered.
    """
    global _config
    config.apply()
    _config = config

# 使用示例
# config = GDALConfig(cache_size=1024*1024*1024, debug=True)
# set_config(config)

14.7 本章小结

本章介绍了 GDAL 的性能优化和最佳实践:

  1. 内存管理:缓存配置、分块处理、内存数据集
  2. 并行处理:多线程配置、进程级并行、瓦片级并行
  3. I/O 优化:GeoTIFF 优化、云端访问、批量操作
  4. 算法优化:NumPy 向量化、空间索引
  5. 最佳实践:代码组织、错误处理、配置管理

14.8 思考与练习

  1. 测量并比较不同缓存大小对处理速度的影响。
  2. 实现一个自适应调整并行度的批处理框架。
  3. 比较不同压缩算法的压缩率和处理速度。
  4. 编写一个性能基准测试工具。
  5. 实现一个支持断点续传的大文件处理器。
posted @ 2025-12-29 10:48  我才是银古  阅读(1)  评论(0)    收藏  举报