第07章 - Python绑定开发指南
第07章:Python绑定开发指南
7.1 Python GDAL 简介
7.1.1 为什么选择 Python
Python 是使用 GDAL 最流行的编程语言,原因包括:
| 优势 | 说明 |
|---|---|
| 易学易用 | Python 语法简洁,学习曲线平缓 |
| 生态丰富 | NumPy、SciPy、Pandas 等科学计算库完美配合 |
| 快速开发 | 脚本式开发,无需编译 |
| 社区活跃 | 大量教程、示例和问答资源 |
| 行业标准 | 地理空间领域的首选语言 |
7.1.2 Python GDAL 生态
┌─────────────────────────────────────────────────────────────┐
│ Python GIS 生态系统 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 高层库(基于 GDAL/OGR) │
│ ├── Rasterio : 现代化栅格处理接口 │
│ ├── Fiona : 现代化矢量处理接口 │
│ ├── GeoPandas : 空间数据分析(基于 Fiona) │
│ └── xarray : 多维数组处理(支持 NetCDF) │
│ │
│ 中间层 │
│ ├── Shapely : 几何操作(基于 GEOS) │
│ ├── PyProj : 坐标转换(基于 PROJ) │
│ └── Cartopy : 地图绘制 │
│ │
│ 底层库 │
│ ├── GDAL/OGR : 数据读写核心 │
│ ├── GEOS : 几何计算 │
│ └── PROJ : 投影计算 │
│ │
└─────────────────────────────────────────────────────────────┘
7.2 安装与配置
7.2.1 使用 Conda 安装(推荐)
# Create a new environment
conda create -n gdal_dev python=3.10
# Activate the environment
conda activate gdal_dev
# Install GDAL from the conda-forge channel
conda install -c conda-forge gdal
# Install the related libraries
conda install -c conda-forge rasterio fiona geopandas shapely pyproj
# Verify the installation by printing the GDAL version seen from Python
python -c "from osgeo import gdal; print(f'GDAL {gdal.__version__}')"
7.2.2 使用 pip 安装
# First make sure GDAL itself is installed on the system
# Ubuntu
sudo apt install gdal-bin libgdal-dev
# Print the installed GDAL version
gdal-config --version
# Install the Python bindings whose version matches the system GDAL
pip install GDAL==$(gdal-config --version)
# If the build fails with compile errors, point the compilers at the GDAL
# headers and retry
export CPLUS_INCLUDE_PATH=/usr/include/gdal
export C_INCLUDE_PATH=/usr/include/gdal
pip install GDAL==$(gdal-config --version)
7.2.3 环境配置
from osgeo import gdal, ogr, osr
# Configuration options (replace the placeholder paths with real directories)
gdal.SetConfigOption('GDAL_DATA', '/path/to/gdal/data')
gdal.SetConfigOption('PROJ_LIB', '/path/to/proj/data')
# Enable exceptions (strongly recommended)
gdal.UseExceptions()
ogr.UseExceptions()
# Set the cache size; note the value passed here is a byte count, not MB
gdal.SetCacheMax(512 * 1024 * 1024) # 512 MB
# Set the number of worker threads
gdal.SetConfigOption('GDAL_NUM_THREADS', 'ALL_CPUS')
# Debug mode: turn on debug output and name a log file for it
gdal.SetConfigOption('CPL_DEBUG', 'ON')
gdal.SetConfigOption('CPL_LOG', '/tmp/gdal_debug.log')
7.3 栅格数据处理
7.3.1 读取栅格数据
from osgeo import gdal
import numpy as np
def read_raster(filepath):
    """Open a raster, print a summary of its metadata, and return band 1.

    Args:
        filepath: Path (or GDAL dataset name) of the raster to open.

    Returns:
        The first band as a NumPy array, as produced by
        ``Band.ReadAsArray()``.

    Raises:
        RuntimeError: Propagated from GDAL when the dataset cannot be
            opened; exceptions are enabled at the top of the function.
    """
    # The accompanying import line only brings in ``gdal``; ``osr`` is needed
    # below to parse the projection WKT, so import it here to avoid a
    # NameError.
    from osgeo import osr

    gdal.UseExceptions()

    # Open the dataset
    ds = gdal.Open(filepath)

    # Basic information
    print(f"驱动: {ds.GetDriver().ShortName}")
    print(f"尺寸: {ds.RasterXSize} x {ds.RasterYSize}")
    print(f"波段数: {ds.RasterCount}")

    # Geotransform: elements 0 and 3 are the origin, 1 and 5 the pixel size
    gt = ds.GetGeoTransform()
    print(f"原点: ({gt[0]}, {gt[3]})")
    print(f"像素大小: ({gt[1]}, {gt[5]})")

    # Projection
    proj = ds.GetProjection()
    srs = osr.SpatialReference(wkt=proj)
    print(f"坐标系: {srs.GetName()}")

    # First band
    band = ds.GetRasterBand(1)
    print(f"数据类型: {gdal.GetDataTypeName(band.DataType)}")
    print(f"无效值: {band.GetNoDataValue()}")

    # Read the band into a NumPy array
    data = band.ReadAsArray()
    print(f"数组形状: {data.shape}")
    print(f"值范围: {data.min()} - {data.max()}")

    # To read every band at once, call ds.ReadAsArray() on the dataset; it is
    # not done here because the result would be discarded and the whole file
    # would be read a second time.

    # Drop the band before the dataset so the dataset can actually close
    band = None
    ds = None
    return data
# 使用
# data = read_raster('example.tif')
7.3.2 创建栅格数据
from osgeo import gdal, osr
import numpy as np
def create_raster(filepath, data, geotransform, epsg=4326, nodata=-9999):
    """Write a NumPy array to a tiled, LZW-compressed GeoTIFF.

    Args:
        filepath: Output file path.
        data: 2-D array (one band) or 3-D array whose first axis is the band.
        geotransform: Six-element geotransform passed to ``SetGeoTransform``.
        epsg: EPSG code of the output coordinate system.
        nodata: NoData value assigned to every band.
    """
    gdal.UseExceptions()

    # Work on a (bands, height, width) stack regardless of the input rank
    stack = np.expand_dims(data, axis=0) if data.ndim == 2 else data
    bands, height, width = stack.shape

    # NumPy dtype -> GDAL type; anything not listed is written as Float64
    known_types = (
        ('uint8', gdal.GDT_Byte),
        ('int16', gdal.GDT_Int16),
        ('uint16', gdal.GDT_UInt16),
        ('int32', gdal.GDT_Int32),
        ('uint32', gdal.GDT_UInt32),
        ('float32', gdal.GDT_Float32),
        ('float64', gdal.GDT_Float64),
    )
    dtype_map = {np.dtype(type_name): code for type_name, code in known_types}
    gdal_dtype = dtype_map.get(stack.dtype, gdal.GDT_Float64)

    # GeoTIFF creation options
    creation_options = ['COMPRESS=LZW', 'TILED=YES',
                        'BLOCKXSIZE=256', 'BLOCKYSIZE=256']

    gtiff = gdal.GetDriverByName('GTiff')
    ds = gtiff.Create(filepath, width, height, bands, gdal_dtype,
                      creation_options)

    # Georeferencing
    ds.SetGeoTransform(geotransform)
    spatial_ref = osr.SpatialReference()
    spatial_ref.ImportFromEPSG(epsg)
    ds.SetProjection(spatial_ref.ExportToWkt())

    # Band numbers are 1-based in GDAL
    for band_number, plane in enumerate(stack, start=1):
        band = ds.GetRasterBand(band_number)
        band.WriteArray(plane)
        band.SetNoDataValue(nodata)
        band.ComputeStatistics(False)

    # Flush and close
    ds.FlushCache()
    ds = None
    print(f"创建完成: {filepath}")
# Usage example: write a 1000 x 1000 random float32 raster
height, width = 1000, 1000
data = np.random.rand(height, width).astype(np.float32)
# Geotransform for an area around Beijing: origin (116.0, 40.0), 0.001-unit pixels
geotransform = (116.0, 0.001, 0, 40.0, 0, -0.001)
create_raster('output.tif', data, geotransform, epsg=4326)
7.3.3 栅格处理操作
from osgeo import gdal
import numpy as np
def raster_calculator(input_files, output_file, calc_func, **kwargs):
    """Apply a function to band 1 of several rasters and save the result.

    Args:
        input_files: Mapping of label to file path, e.g.
            ``{'A': 'file1.tif', 'B': 'file2.tif'}``. The first entry
            supplies the output size, geotransform and projection.
        output_file: Path of the GeoTIFF to create.
        calc_func: Callable that receives a dict of label -> float32 array
            (plus ``**kwargs``) and returns the array to write.
        **kwargs: Extra keyword arguments forwarded to ``calc_func``.

    Raises:
        ValueError: If ``input_files`` is empty.
    """
    gdal.UseExceptions()

    if not input_files:
        raise ValueError("input_files must contain at least one raster")

    # Open every input; the first one acts as the reference grid
    datasets = {key: gdal.Open(path) for key, path in input_files.items()}
    reference_ds = next(iter(datasets.values()))

    # Reference information
    width = reference_ds.RasterXSize
    height = reference_ds.RasterYSize
    geotransform = reference_ds.GetGeoTransform()
    projection = reference_ds.GetProjection()

    # Read band 1 of each input as float32
    data = {
        key: ds.GetRasterBand(1).ReadAsArray().astype(np.float32)
        for key, ds in datasets.items()
    }

    # Everything needed is now in memory, so release the inputs. Assigning
    # None to a loop variable would not do this: the dict and the reference
    # variable would still hold the datasets open.
    reference_ds = None
    datasets.clear()

    # Run the calculation
    result = calc_func(data, **kwargs)

    # Create the output
    driver = gdal.GetDriverByName('GTiff')
    out_ds = driver.Create(output_file, width, height, 1, gdal.GDT_Float32,
                           ['COMPRESS=LZW'])
    out_ds.SetGeoTransform(geotransform)
    out_ds.SetProjection(projection)

    out_band = out_ds.GetRasterBand(1)
    out_band.WriteArray(result)
    out_band.SetNoDataValue(-9999)
    out_band.ComputeStatistics(False)

    # Flush, then drop the band before the dataset so the file is closed
    out_ds.FlushCache()
    out_band = None
    out_ds = None
# NDVI 计算示例
def ndvi_calc(data):
    """Compute NDVI = (NIR - RED) / (NIR + RED).

    Args:
        data: Mapping holding the 'RED' and 'NIR' arrays.

    Returns:
        Floating-point array of NDVI values. Pixels where NIR + RED is not
        positive are set to -9999.
    """
    red = np.asarray(data['RED'])
    nir = np.asarray(data['NIR'])

    # Work in floating point so unsigned integer inputs cannot wrap around
    # in the subtraction or the sum.
    out_dtype = np.result_type(red, nir, np.float32)
    red = red.astype(out_dtype, copy=False)
    nir = nir.astype(out_dtype, copy=False)

    denominator = nir + red
    valid = denominator > 0

    # Divide only where the denominator is usable. np.where would evaluate
    # the division for every pixel first and emit divide-by-zero / invalid
    # value warnings for the masked-out ones.
    ndvi = np.full(denominator.shape, -9999, dtype=out_dtype)
    np.divide(nir - red, denominator, out=ndvi, where=valid)
    return ndvi
# 使用
# raster_calculator(
# {'RED': 'red_band.tif', 'NIR': 'nir_band.tif'},
# 'ndvi.tif',
# ndvi_calc
# )
7.3.4 使用 GDAL Warp
from osgeo import gdal
def reproject_raster(src_path, dst_path, dst_crs,
                     resolution=None, bounds=None, resampling='bilinear'):
    """Reproject a raster to another coordinate system with gdal.Warp.

    Args:
        src_path: Input raster path.
        dst_path: Output GeoTIFF path.
        dst_crs: Target coordinate system, e.g. 'EPSG:3857'.
        resolution: Output pixel size used for both x and y, or None.
        bounds: Output bounds forwarded as ``outputBounds``, or None.
        resampling: Name of the resampling method; names that are not
            recognised fall back to bilinear.
    """
    gdal.UseExceptions()

    algorithms_by_name = dict(
        nearest=gdal.GRA_NearestNeighbour,
        bilinear=gdal.GRA_Bilinear,
        cubic=gdal.GRA_Cubic,
        cubicspline=gdal.GRA_CubicSpline,
        lanczos=gdal.GRA_Lanczos,
        average=gdal.GRA_Average,
        mode=gdal.GRA_Mode,
    )
    if resampling in algorithms_by_name:
        algorithm = algorithms_by_name[resampling]
    else:
        algorithm = gdal.GRA_Bilinear

    settings = gdal.WarpOptions(
        format='GTiff',
        dstSRS=dst_crs,
        resampleAlg=algorithm,
        xRes=resolution,
        yRes=resolution,
        outputBounds=bounds,
        creationOptions=['COMPRESS=LZW', 'TILED=YES'],
        multithread=True,
        warpMemoryLimit=512,
    )

    # Dropping the returned dataset closes the output file
    warped = gdal.Warp(dst_path, src_path, options=settings)
    warped = None
    print(f"重投影完成: {dst_path}")
def clip_raster(src_path, dst_path, cutline_path=None, bounds=None):
    """Clip a raster by a cutline dataset and/or output bounds.

    Args:
        src_path: Input raster path.
        dst_path: Output GeoTIFF path.
        cutline_path: Optional vector dataset used as the cutline; when
            given, the output is also cropped to the cutline.
        bounds: Optional output bounds forwarded as ``outputBounds``.
    """
    gdal.UseExceptions()

    crop_to_cutline = bool(cutline_path)
    settings = gdal.WarpOptions(
        format='GTiff',
        creationOptions=['COMPRESS=LZW'],
        dstNodata=-9999,
        outputBounds=bounds,
        cutlineDSName=cutline_path,
        cropToCutline=crop_to_cutline,
    )

    # Dropping the returned dataset closes the output file
    clipped = gdal.Warp(dst_path, src_path, options=settings)
    clipped = None
def mosaic_rasters(src_files, dst_path, nodata=None):
    """Mosaic several rasters into one GeoTIFF via an in-memory VRT.

    Args:
        src_files: List of input raster paths.
        dst_path: Output GeoTIFF path.
        nodata: Value used as both source and VRT NoData, or None.
    """
    gdal.UseExceptions()

    # Step 1: build an unnamed VRT that references all the inputs
    build_settings = gdal.BuildVRTOptions(
        resampleAlg='nearest',
        srcNodata=nodata,
        VRTNodata=nodata,
    )
    mosaic_vrt = gdal.BuildVRT('', src_files, options=build_settings)

    # Step 2: materialise the VRT as a compressed, tiled GeoTIFF
    output_settings = gdal.TranslateOptions(
        format='GTiff',
        creationOptions=['COMPRESS=LZW', 'TILED=YES', 'BIGTIFF=IF_SAFER'],
    )
    written = gdal.Translate(dst_path, mosaic_vrt, options=output_settings)

    # Release both datasets
    mosaic_vrt = None
    written = None
# 使用示例
# reproject_raster('input.tif', 'output.tif', 'EPSG:3857')
# clip_raster('input.tif', 'clipped.tif', cutline_path='boundary.shp')
# mosaic_rasters(['tile1.tif', 'tile2.tif'], 'mosaic.tif')
7.4 矢量数据处理
7.4.1 读取矢量数据
from osgeo import ogr, osr
def read_vector(filepath):
    """Print a summary of every layer in a vector data source.

    For each layer the name, feature count, geometry type, extent,
    coordinate system (when present) and field definitions are printed.

    Args:
        filepath: Path of the vector data source to open.
    """
    ogr.UseExceptions()

    source = ogr.Open(filepath)
    layer_total = source.GetLayerCount()
    print(f"数据源: {source.GetDescription()}")
    print(f"图层数: {layer_total}")

    for i in range(layer_total):
        layer = source.GetLayer(i)
        geometry_name = ogr.GeometryTypeToName(layer.GetGeomType())
        print(f"\n图层 {i}: {layer.GetName()}")
        print(f" 要素数: {layer.GetFeatureCount()}")
        print(f" 几何类型: {geometry_name}")

        # Extent
        print(f" 范围: {layer.GetExtent()}")

        # Spatial reference (may be absent)
        srs = layer.GetSpatialRef()
        if srs:
            print(f" 坐标系: {srs.GetName()}")

        # Field definitions
        schema = layer.GetLayerDefn()
        field_total = schema.GetFieldCount()
        print(f" 字段数: {field_total}")
        for j in range(field_total):
            field_defn = schema.GetFieldDefn(j)
            print(f" - {field_defn.GetName()}: {field_defn.GetTypeName()}")

    source = None
def read_features(filepath, layer_name=None, where=None, bbox=None):
    """Read features from a layer and return them as a list of dicts.

    Args:
        filepath: Path of the vector data source.
        layer_name: Layer to read; the first layer is used when omitted.
        where: Optional attribute filter expression.
        bbox: Optional sequence unpacked into ``SetSpatialFilterRect``,
            i.e. (minx, miny, maxx, maxy).

    Returns:
        One dict per feature with the keys 'fid', 'geometry' (WKT string,
        or None when the feature has no geometry) and one key per field.

    Raises:
        ValueError: If the requested layer does not exist.
    """
    ogr.UseExceptions()
    ds = ogr.Open(filepath)

    layer = ds.GetLayerByName(layer_name) if layer_name else ds.GetLayer(0)
    if layer is None:
        # Fail with a clear message instead of an AttributeError further down
        wanted = layer_name if layer_name else 0
        raise ValueError(f"Layer not found in {filepath}: {wanted}")

    # Filters
    if where:
        layer.SetAttributeFilter(where)
    if bbox:
        layer.SetSpatialFilterRect(*bbox)

    # The schema is the same for every feature, so resolve field names once
    layer_defn = layer.GetLayerDefn()
    field_names = [
        layer_defn.GetFieldDefn(i).GetName()
        for i in range(layer_defn.GetFieldCount())
    ]

    features = []
    for feature in layer:
        geom = feature.GetGeometryRef()
        feat_dict = {
            'fid': feature.GetFID(),
            'geometry': geom.ExportToWkt() if geom is not None else None,
        }
        for index, field_name in enumerate(field_names):
            feat_dict[field_name] = feature.GetField(index)
        features.append(feat_dict)

    # Drop the layer before the data source that owns it
    layer = None
    ds = None
    return features
# 使用示例
# read_vector('data.shp')
# features = read_features('data.shp', where="population > 1000000")
7.4.2 创建矢量数据
from osgeo import ogr, osr
def create_point_shapefile(filepath, points, fields, epsg=4326):
    """Create a point Shapefile from a list of dicts.

    Args:
        filepath: Output path; an existing file at this path is replaced.
        points: List of dicts such as
            ``{'x': ..., 'y': ..., 'field1': ..., ...}``.
        fields: Field definitions ``[(name, type, width), ...]``; the width
            is optional and only applied to string fields.
        epsg: EPSG code of the coordinate system.
    """
    import os

    ogr.UseExceptions()

    driver = ogr.GetDriverByName('ESRI Shapefile')

    # Remove a previous output. Only attempt this when the file exists:
    # with exceptions enabled, deleting a missing data source raises instead
    # of returning an error code.
    if os.path.exists(filepath):
        driver.DeleteDataSource(filepath)
    ds = driver.CreateDataSource(filepath)

    # Spatial reference
    srs = osr.SpatialReference()
    srs.ImportFromEPSG(epsg)

    # Layer
    layer = ds.CreateLayer('points', srs, ogr.wkbPoint)

    # Fields
    for field_def in fields:
        name, field_type = field_def[0], field_def[1]
        field = ogr.FieldDefn(name, field_type)
        if field_type == ogr.OFTString and len(field_def) > 2:
            field.SetWidth(field_def[2])
        layer.CreateField(field)

    layer_defn = layer.GetLayerDefn()
    field_names = [field_def[0] for field_def in fields]

    # Features
    for point_data in points:
        feature = ogr.Feature(layer_defn)

        # AddPoint_2D keeps the geometry two-dimensional; AddPoint would add
        # a Z coordinate of 0 and turn it into a 2.5D point, which does not
        # match the wkbPoint layer type.
        point = ogr.Geometry(ogr.wkbPoint)
        point.AddPoint_2D(point_data['x'], point_data['y'])
        feature.SetGeometry(point)

        # Attributes: only the fields present in this point's dict
        for field_name in field_names:
            if field_name in point_data:
                feature.SetField(field_name, point_data[field_name])

        layer.CreateFeature(feature)
        feature = None

    # Drop the layer before the data source so everything is written out
    layer = None
    ds = None
    print(f"创建完成: {filepath}")
# Usage example
# Sample points: x/y coordinates plus the attribute values to store
points = [
{'x': 116.4, 'y': 39.9, 'name': '北京', 'population': 21540000},
{'x': 121.5, 'y': 31.2, 'name': '上海', 'population': 24280000},
{'x': 113.3, 'y': 23.1, 'name': '广州', 'population': 18680000},
]
# Field definitions as (name, OGR field type[, width])
fields = [
('name', ogr.OFTString, 50),
('population', ogr.OFTInteger64),
]
# create_point_shapefile('cities.shp', points, fields)
7.4.3 空间分析
from osgeo import ogr, osr
def buffer_layer(input_path, output_path, distance, dissolve=False):
    """Buffer every geometry of a layer and write the result to a Shapefile.

    Args:
        input_path: Input vector data source; its default layer is used.
        output_path: Output Shapefile path.
        distance: Buffer distance passed to ``Geometry.Buffer``.
        dissolve: When True, all buffers are unioned into a single feature
            without attributes; otherwise one buffered feature is written
            per input feature with its attributes copied.
    """
    ogr.UseExceptions()

    # Input
    src_ds = ogr.Open(input_path)
    src_layer = src_ds.GetLayer()

    # Output layer shares the input's spatial reference
    shp_driver = ogr.GetDriverByName('ESRI Shapefile')
    dst_ds = shp_driver.CreateDataSource(output_path)
    dst_layer = dst_ds.CreateLayer('buffer', src_layer.GetSpatialRef(),
                                   ogr.wkbPolygon)

    # Mirror the input schema on the output
    src_defn = src_layer.GetLayerDefn()
    field_total = src_defn.GetFieldCount()
    for idx in range(field_total):
        dst_layer.CreateField(src_defn.GetFieldDefn(idx))
    dst_defn = dst_layer.GetLayerDefn()

    if not dissolve:
        # One output feature per input feature, attributes carried over
        for src_feature in src_layer:
            geom = src_feature.GetGeometryRef()
            if geom:
                buffered = geom.Buffer(distance)
                out_feature = ogr.Feature(dst_defn)
                out_feature.SetGeometry(buffered)
                for idx in range(field_total):
                    out_feature.SetField(idx, src_feature.GetField(idx))
                dst_layer.CreateFeature(out_feature)
    else:
        # Accumulate the union of all buffers, then write a single feature
        merged = None
        for src_feature in src_layer:
            geom = src_feature.GetGeometryRef()
            if geom:
                buffered = geom.Buffer(distance)
                merged = buffered if merged is None else merged.Union(buffered)
        if merged:
            out_feature = ogr.Feature(dst_defn)
            out_feature.SetGeometry(merged)
            dst_layer.CreateFeature(out_feature)

    src_ds = None
    dst_ds = None
def spatial_join(target_path, join_path, output_path, how='intersects'):
    """Join attributes of the first matching join feature onto each target.

    Only target features that have a match are written. The output carries
    the target geometry, the target fields, and the join fields prefixed
    with ``join_``.

    Args:
        target_path: Data source whose features are kept.
        join_path: Data source whose attributes are attached.
        output_path: Output Shapefile path.
        how: Spatial predicate evaluated as ``target.<predicate>(join)``:
            'intersects' (default), 'within', 'contains', 'touches',
            'crosses' or 'overlaps'.

    Raises:
        ValueError: If ``how`` is not one of the supported predicates.
    """
    ogr.UseExceptions()

    # Validate up front: an unknown predicate used to produce an empty
    # output file without any indication of why.
    predicates = {
        'intersects': 'Intersects',
        'within': 'Within',
        'contains': 'Contains',
        'touches': 'Touches',
        'crosses': 'Crosses',
        'overlaps': 'Overlaps',
    }
    if how not in predicates:
        supported = ', '.join(sorted(predicates))
        raise ValueError(f"Unsupported how={how!r}; expected one of: {supported}")
    predicate_name = predicates[how]

    target_ds = ogr.Open(target_path)
    target_layer = target_ds.GetLayer()
    join_ds = ogr.Open(join_path)
    join_layer = join_ds.GetLayer()

    # Output
    driver = ogr.GetDriverByName('ESRI Shapefile')
    out_ds = driver.CreateDataSource(output_path)
    out_layer = out_ds.CreateLayer(
        'joined',
        target_layer.GetSpatialRef(),
        target_layer.GetGeomType()
    )

    # Copy the target layer's fields
    target_defn = target_layer.GetLayerDefn()
    target_field_count = target_defn.GetFieldCount()
    for i in range(target_field_count):
        out_layer.CreateField(target_defn.GetFieldDefn(i))

    # Add the join layer's fields with a prefix, keeping width and precision
    # so values are not truncated to the driver defaults.
    # NOTE: Shapefile field names are limited to 10 characters, so long
    # prefixed names may be shortened by the driver.
    join_defn = join_layer.GetLayerDefn()
    join_field_count = join_defn.GetFieldCount()
    for i in range(join_field_count):
        field_defn = join_defn.GetFieldDefn(i)
        new_field = ogr.FieldDefn(f'join_{field_defn.GetName()}', field_defn.GetType())
        new_field.SetWidth(field_defn.GetWidth())
        new_field.SetPrecision(field_defn.GetPrecision())
        out_layer.CreateField(new_field)

    out_defn = out_layer.GetLayerDefn()

    # Perform the join
    for target_feature in target_layer:
        target_geom = target_feature.GetGeometryRef()
        if target_geom is None:
            continue

        # Restrict candidates to features near the target geometry
        join_layer.SetSpatialFilter(target_geom)
        matches = getattr(target_geom, predicate_name)

        for join_feature in join_layer:
            join_geom = join_feature.GetGeometryRef()
            # Skip candidates without geometry or failing the exact test
            if join_geom is None or not matches(join_geom):
                continue

            out_feature = ogr.Feature(out_defn)
            out_feature.SetGeometry(target_geom.Clone())

            # Target attributes
            for i in range(target_field_count):
                out_feature.SetField(i, target_feature.GetField(i))

            # Join attributes follow the target fields in the output schema
            for i in range(join_field_count):
                out_feature.SetField(target_field_count + i,
                                     join_feature.GetField(i))

            out_layer.CreateFeature(out_feature)
            break  # keep only the first match

        join_layer.ResetReading()

    # Leave the join layer unfiltered before releasing everything
    join_layer.SetSpatialFilter(None)
    target_layer = None
    join_layer = None
    out_layer = None
    target_ds = None
    join_ds = None
    out_ds = None
# 使用示例
# buffer_layer('points.shp', 'buffer.shp', 1000) # 1000米缓冲区
# spatial_join('parcels.shp', 'zones.shp', 'parcels_with_zone.shp')
7.5 与 NumPy 集成
7.5.1 高效数据转换
from osgeo import gdal, gdal_array
import numpy as np
def read_raster_as_array(filepath, band_index=1):
    """Read one raster band together with its georeferencing metadata.

    Args:
        filepath: Raster to open.
        band_index: 1-based band number.

    Returns:
        Dict with the keys 'data' (band array), 'nodata', 'geotransform'
        and 'projection'.
    """
    source = gdal.Open(filepath)
    band = source.GetRasterBand(band_index)

    # Collect everything while the dataset is still open
    contents = {
        'data': gdal_array.BandReadAsArray(band),
        'nodata': band.GetNoDataValue(),
        'geotransform': source.GetGeoTransform(),
        'projection': source.GetProjection(),
    }

    source = None
    return contents
def write_array_as_raster(filepath, data, geotransform, projection, nodata=-9999):
    """Write a NumPy array to a tiled, LZW-compressed GeoTIFF.

    Args:
        filepath: Output path.
        data: 2-D array (one band) or 3-D array whose first axis is the band.
        geotransform: Six-element geotransform for the output.
        projection: Projection string passed to ``SetProjection``.
        nodata: NoData value assigned to every band.
    """
    # NumPy dtype -> GDAL type. int8 must not map to GDT_Byte: that type is
    # unsigned, so negative values would be corrupted. Use GDT_Int8 when this
    # GDAL build provides it, otherwise widen to Int16.
    signed_byte = getattr(gdal, 'GDT_Int8', gdal.GDT_Int16)
    numpy_to_gdal = {
        np.dtype('uint8'): gdal.GDT_Byte,
        np.dtype('int8'): signed_byte,
        np.dtype('uint16'): gdal.GDT_UInt16,
        np.dtype('int16'): gdal.GDT_Int16,
        np.dtype('uint32'): gdal.GDT_UInt32,
        np.dtype('int32'): gdal.GDT_Int32,
        np.dtype('float32'): gdal.GDT_Float32,
        np.dtype('float64'): gdal.GDT_Float64,
    }
    # Unlisted dtypes are written as Float64
    gdal_type = numpy_to_gdal.get(data.dtype, gdal.GDT_Float64)

    # Treat a 2-D array as a one-band stack so a single write path serves
    # both cases
    stack = data[np.newaxis, ...] if data.ndim == 2 else data
    bands, height, width = stack.shape

    driver = gdal.GetDriverByName('GTiff')
    ds = driver.Create(
        filepath, width, height, bands, gdal_type,
        ['COMPRESS=LZW', 'TILED=YES']
    )
    ds.SetGeoTransform(geotransform)
    ds.SetProjection(projection)

    # Band numbers are 1-based
    for band_number, plane in enumerate(stack, start=1):
        band = ds.GetRasterBand(band_number)
        gdal_array.BandWriteArray(band, plane)
        band.SetNoDataValue(nodata)

    ds.FlushCache()
    band = None
    ds = None
7.5.2 内存数据集
from osgeo import gdal
import numpy as np
def create_memory_raster(data, geotransform, projection):
    """Build an in-memory (MEM driver) dataset from a NumPy array.

    Args:
        data: 2-D array (one band) or 3-D array whose first axis is the band.
        geotransform: Six-element geotransform.
        projection: Projection string passed to ``SetProjection``.

    Returns:
        The open in-memory dataset; the caller owns it.
    """
    # Always work on a (bands, height, width) stack
    stack = data[np.newaxis, ...] if data.ndim == 2 else data
    bands, height, width = stack.shape

    mem_driver = gdal.GetDriverByName('MEM')

    # Supported dtypes; anything else is stored as Float64
    supported = (
        ('float32', gdal.GDT_Float32),
        ('float64', gdal.GDT_Float64),
        ('int32', gdal.GDT_Int32),
        ('uint8', gdal.GDT_Byte),
    )
    dtype_map = {np.dtype(type_name): code for type_name, code in supported}
    gdal_type = dtype_map.get(stack.dtype, gdal.GDT_Float64)

    mem_ds = mem_driver.Create('', width, height, bands, gdal_type)
    mem_ds.SetGeoTransform(geotransform)
    mem_ds.SetProjection(projection)

    # Band numbers are 1-based
    for band_number, plane in enumerate(stack, start=1):
        mem_ds.GetRasterBand(band_number).WriteArray(plane)

    return mem_ds
def array_to_vrt(data, geotransform, projection):
    """Wrap a NumPy array in a VRT backed by an in-memory dataset.

    Returns:
        Tuple ``(vrt, mem_ds)``. The caller must keep ``mem_ds`` referenced
        for as long as the VRT is in use.
    """
    # The VRT is built on top of an in-memory dataset holding the array
    backing_ds = create_memory_raster(data, geotransform, projection)
    virtual_ds = gdal.BuildVRT('', backing_ds)
    return virtual_ds, backing_ds
7.6 高级主题
7.6.1 并行处理(多进程)
from osgeo import gdal
import numpy as np
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import multiprocessing
def process_tile(args):
    """Read one tile of band 1, run the processing function, return it.

    Args:
        args: Tuple ``(src_path, x_off, y_off, x_size, y_size,
            processing_func)``.

    Returns:
        Tuple ``(x_off, y_off, result)`` so the caller knows where the
        processed tile belongs.
    """
    src_path, x_off, y_off, x_size, y_size, processing_func = args

    # Each worker process has to open the dataset itself
    dataset = gdal.Open(src_path)
    tile = dataset.GetRasterBand(1).ReadAsArray(x_off, y_off, x_size, y_size)

    processed = processing_func(tile)

    dataset = None
    return x_off, y_off, processed
def parallel_raster_processing(src_path, dst_path, processing_func,
                               tile_size=512, num_workers=None):
    """Process band 1 of a raster tile by tile in a pool of processes.

    Args:
        src_path: Input raster path.
        dst_path: Output GeoTIFF path (single Float32 band).
        processing_func: Function applied to each tile array. It is sent to
            worker processes, so it has to be importable at module level.
            It sees one tile at a time, not the whole raster.
        tile_size: Edge length of the square tiles in pixels.
        num_workers: Number of worker processes; defaults to the CPU count.
    """
    gdal.UseExceptions()

    if num_workers is None:
        num_workers = multiprocessing.cpu_count()

    # Source dimensions
    src_ds = gdal.Open(src_path)
    width = src_ds.RasterXSize
    height = src_ds.RasterYSize

    # One task per tile; edge tiles are clipped to the raster extent
    tasks = [
        (src_path, x, y,
         min(tile_size, width - x), min(tile_size, height - y),
         processing_func)
        for y in range(0, height, tile_size)
        for x in range(0, width, tile_size)
    ]

    # Output dataset with the source's georeferencing
    driver = gdal.GetDriverByName('GTiff')
    dst_ds = driver.Create(
        dst_path, width, height, 1, gdal.GDT_Float32,
        ['COMPRESS=LZW', 'TILED=YES']
    )
    dst_ds.SetGeoTransform(src_ds.GetGeoTransform())
    dst_ds.SetProjection(src_ds.GetProjection())
    dst_band = dst_ds.GetRasterBand(1)
    src_ds = None

    # Write each tile as soon as it comes back instead of collecting every
    # result in a list first, which would hold the whole raster in memory.
    # Only this parent process writes to the output.
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        for x_off, y_off, result in executor.map(process_tile, tasks):
            dst_band.WriteArray(result, x_off, y_off)

    dst_ds.FlushCache()
    dst_band = None
    dst_ds = None
    print(f"并行处理完成: {dst_path}")
# 使用示例
def normalize(data):
    """Min-max scale an array to the range [0, 1].

    Args:
        data: NumPy array to scale.

    Returns:
        Array of the same shape. A constant input maps to all zeros rather
        than NaN.
    """
    lowest = data.min()
    span = data.max() - lowest
    if span == 0:
        # Constant input: dividing by the zero range would give 0/0 = NaN
        span = 1
    return (data - lowest) / span
# parallel_raster_processing('input.tif', 'output.tif', normalize)
7.6.2 错误处理和日志
from osgeo import gdal
import logging
# Configure logging: DEBUG level, timestamped records, named application logger
logging.basicConfig(
level=logging.DEBUG,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('gdal_app')
# GDAL 错误处理器
def gdal_error_handler(err_class, err_num, err_msg):
    """Log GDAL warnings and errors, and raise on failures.

    Args:
        err_class: GDAL error class (one of the ``gdal.CE_*`` constants).
        err_num: GDAL error number.
        err_msg: Error message text.

    Raises:
        RuntimeError: When ``err_class`` is ``gdal.CE_Failure`` or higher.
    """
    class_names = dict([
        (gdal.CE_None, 'None'),
        (gdal.CE_Debug, 'Debug'),
        (gdal.CE_Warning, 'Warning'),
        (gdal.CE_Failure, 'Failure'),
        (gdal.CE_Fatal, 'Fatal'),
    ])
    if err_class in class_names:
        err_type = class_names[err_class]
    else:
        err_type = 'Unknown'

    # Warnings and anything more severe are logged
    should_log = err_class >= gdal.CE_Warning
    if should_log:
        logger.warning(f"GDAL {err_type} ({err_num}): {err_msg}")

    # Failures and fatal errors are additionally raised
    should_raise = err_class >= gdal.CE_Failure
    if should_raise:
        raise RuntimeError(f"GDAL {err_type}: {err_msg}")
# 注册错误处理器
# gdal.PushErrorHandler(gdal_error_handler)
# 使用上下文管理器进行错误处理
from contextlib import contextmanager
@contextmanager
def gdal_context():
    """Context manager for GDAL operations.

    Enables GDAL exceptions, logs any exception raised inside the ``with``
    body before re-raising it, and restores the ``CPL_DEBUG`` configuration
    option to the value it had on entry.
    """
    gdal.UseExceptions()

    # Remember the current setting; None means the option was not set
    old_debug = gdal.GetConfigOption('CPL_DEBUG')

    try:
        yield
    except Exception as e:
        logger.error(f"GDAL 操作失败: {e}")
        raise
    finally:
        # Restore unconditionally. Passing None clears the option, so a value
        # set inside the body is also undone when the option was originally
        # unset; restoring only truthy values would leave it switched on.
        gdal.SetConfigOption('CPL_DEBUG', old_debug)
# Usage example
with gdal_context():
ds = gdal.Open('example.tif')
# Process the data...
ds = None
7.7 最佳实践
7.7.1 代码组织
"""
gdal_utils.py - GDAL 工具函数模块
"""
from osgeo import gdal, ogr, osr
import numpy as np
from pathlib import Path
from typing import Optional, Tuple, List, Dict, Any
import logging
# 模块初始化
gdal.UseExceptions()
ogr.UseExceptions()
logger = logging.getLogger(__name__)
class RasterDataset:
    """Context-manager wrapper around a GDAL raster dataset.

    The file is opened on ``__enter__`` (update access when ``mode`` is
    'w', read-only otherwise) and flushed and released on ``__exit__``.
    """

    def __init__(self, filepath: str, mode: str = 'r'):
        self._ds = None
        self.mode = mode
        self.filepath = Path(filepath)

    def __enter__(self):
        if self.mode == 'w':
            access = gdal.GA_Update
        else:
            access = gdal.GA_ReadOnly
        self._ds = gdal.Open(str(self.filepath), access)
        if self._ds is None:
            raise IOError(f"无法打开: {self.filepath}")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._ds:
            self._ds.FlushCache()
            self._ds = None
        # Never suppress an exception raised inside the with-block
        return False

    @property
    def width(self) -> int:
        """Raster width in pixels."""
        return self._ds.RasterXSize

    @property
    def height(self) -> int:
        """Raster height in pixels."""
        return self._ds.RasterYSize

    @property
    def bands(self) -> int:
        """Number of raster bands."""
        return self._ds.RasterCount

    @property
    def geotransform(self) -> Tuple:
        """Geotransform of the dataset."""
        return self._ds.GetGeoTransform()

    @property
    def projection(self) -> str:
        """Projection string of the dataset."""
        return self._ds.GetProjection()

    def read_band(self, band_index: int = 1) -> np.ndarray:
        """Return the given band (1-based index) as an array."""
        return self._ds.GetRasterBand(band_index).ReadAsArray()

    def read_all(self) -> np.ndarray:
        """Return the whole dataset as an array."""
        return self._ds.ReadAsArray()
class VectorDataset:
    """Context-manager wrapper around an OGR vector data source.

    The data source is opened on ``__enter__`` (update mode when ``mode``
    is 'w', read-only otherwise) and released on ``__exit__``.
    """

    def __init__(self, filepath: str, mode: str = 'r'):
        self.filepath = Path(filepath)
        self.mode = mode
        self._ds = None

    def __enter__(self):
        update = 1 if self.mode == 'w' else 0
        self._ds = ogr.Open(str(self.filepath), update)
        if self._ds is None:
            raise IOError(f"无法打开: {self.filepath}")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._ds = None
        # Never suppress an exception raised inside the with-block
        return False

    @property
    def layer_count(self) -> int:
        """Number of layers in the data source."""
        return self._ds.GetLayerCount()

    def get_layer(self, index: int = 0):
        """Return the layer at ``index``."""
        return self._ds.GetLayer(index)

    def features(self, layer_index: int = 0) -> List[Dict[str, Any]]:
        """Return every feature of a layer as a dict.

        Each dict holds 'fid', 'geometry' (WKT, only when the feature has a
        geometry) and one key per field.
        """
        layer = self._ds.GetLayer(layer_index)

        # The schema does not change between features, so resolve the field
        # names once instead of once per feature.
        layer_defn = layer.GetLayerDefn()
        field_names = [
            layer_defn.GetFieldDefn(i).GetName()
            for i in range(layer_defn.GetFieldCount())
        ]

        features = []
        for feature in layer:
            feat_dict = {'fid': feature.GetFID()}

            geom = feature.GetGeometryRef()
            if geom:
                feat_dict['geometry'] = geom.ExportToWkt()

            for index, field_name in enumerate(field_names):
                feat_dict[field_name] = feature.GetField(index)

            features.append(feat_dict)

        # Rewind so the layer can be iterated again by the caller
        layer.ResetReading()
        return features
# 使用示例
# with RasterDataset('input.tif') as ds:
# data = ds.read_all()
# print(f"数据形状: {data.shape}")
7.7.2 性能优化建议
# 1. 使用分块读取大文件
def read_large_raster_efficiently(filepath, processing_func, block_size=512):
    """Stream band 1 of a raster through a callback, one block at a time.

    Args:
        filepath: Raster to open.
        processing_func: Called as ``processing_func(data, x, y)`` for each
            block, where ``x``/``y`` are the block's pixel offsets.
        block_size: Edge length of the square blocks in pixels.
    """
    dataset = gdal.Open(filepath)
    band = dataset.GetRasterBand(1)
    cols, rows = dataset.RasterXSize, dataset.RasterYSize

    row_offsets = range(0, rows, block_size)
    col_offsets = range(0, cols, block_size)

    # Row-major walk; blocks on the right/bottom edge are clipped
    for y in row_offsets:
        y_size = min(block_size, rows - y)
        for x in col_offsets:
            x_size = min(block_size, cols - x)
            data = band.ReadAsArray(x, y, x_size, y_size)
            processing_func(data, x, y)

    dataset = None
# 2. Set a suitable cache size (the argument is a byte count)
gdal.SetCacheMax(1024 * 1024 * 1024) # 1GB
# 3. Use multiple threads
gdal.SetConfigOption('GDAL_NUM_THREADS', 'ALL_CPUS')
# 4. For network data, enable the HTTP cache
gdal.SetConfigOption('CPL_VSIL_CURL_CACHE_SIZE', '100000000') # 100MB
# 5. Use VRT for virtual processing
# A VRT can avoid unnecessary copies of the data
7.8 本章小结
本章详细介绍了 Python GDAL 开发:
- 环境安装:Conda 安装(推荐)和 pip 安装方法
- 栅格处理:读取、创建、裁剪、重投影、镶嵌
- 矢量处理:读取、创建、空间分析
- NumPy 集成:高效的数据转换和内存数据集
- 高级主题:并行处理(多进程)、错误处理和日志
- 最佳实践:代码组织和性能优化
7.9 思考与练习
- 比较 osgeo.gdal 和 rasterio 的优缺点。
- 编写一个函数,读取多波段影像并计算指定的波段指数。
- 实现一个支持断点续传的大文件处理器。
- 如何处理 GDAL 操作中的内存溢出问题?
- 编写一个矢量数据转换器,支持多种输出格式。
- 实现并行处理 1000 个栅格文件的批处理脚本。

浙公网安备 33010602011771号