@ 20240808 & lth
目标:对硬盘中的所有TIF转出其栅格范围、面积(平方千米)、写入数据库
栅格数据转矢量边界工具
支持处理栅格数据中的空洞,输出为Shapefile或GeoJSON格式
支持超大栅格的分块处理和加速优化
from pg_tools import PostgresDB参考https://www.cnblogs.com/litianhao1998/p/19012366
代码-栅格转矢量
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
栅格数据转矢量边界工具
支持处理栅格数据中的空洞,输出为Shapefile或GeoJSON格式
支持超大栅格的分块处理和加速优化
"""
import os
import sys
import argparse
import time
import gc
from multiprocessing import Pool, cpu_count
from osgeo import gdal, ogr, osr
import numpy as np
from scipy import ndimage
class RasterToVector:
"""栅格转矢量类"""
def __init__(self, chunk_size=512, max_memory_mb=256, use_multiprocessing=True):
"""
初始化栅格转矢量处理器
Args:
chunk_size (int): 分块大小(像素)
max_memory_mb (int): 最大内存使用量(MB)
use_multiprocessing (bool): 是否使用多进程
"""
# 启用GDAL异常
gdal.UseExceptions()
self.chunk_size = chunk_size
self.max_memory_mb = max_memory_mb
self.use_multiprocessing = use_multiprocessing
self.cpu_cores = cpu_count()
def get_raster_info(self, raster_path):
"""
获取栅格基本信息
Args:
raster_path (str): 栅格文件路径
Returns:
dict: 栅格信息字典
"""
dataset = gdal.Open(raster_path, gdal.GA_ReadOnly)
if dataset is None:
raise ValueError(f"无法打开栅格文件: {raster_path}")
info = {
'width': dataset.RasterXSize,
'height': dataset.RasterYSize,
'bands': dataset.RasterCount,
'geotransform': dataset.GetGeoTransform(),
'projection': dataset.GetProjection(),
'datatype': dataset.GetRasterBand(1).DataType,
'nodata': dataset.GetRasterBand(1).GetNoDataValue()
}
# 计算文件大小(估算)
pixel_count = info['width'] * info['height']
bytes_per_pixel = gdal.GetDataTypeSize(info['datatype']) // 8
estimated_size_mb = (pixel_count * bytes_per_pixel) / (1024 * 1024)
info['estimated_size_mb'] = estimated_size_mb
dataset = None
return info
def should_use_chunked_processing(self, raster_info):
"""
判断是否需要使用分块处理
Args:
raster_info (dict): 栅格信息
Returns:
bool: 是否需要分块处理
"""
# 如果估算内存使用超过阈值,或者像素数量过大,则使用分块处理
print(f"估算内存需求: {raster_info['estimated_size_mb']:.1f} MB")
return (raster_info['estimated_size_mb'] > self.max_memory_mb or
raster_info['width'] > self.chunk_size * 4 or
raster_info['height'] > self.chunk_size * 4)
def create_mask_from_raster(self, raster_path, nodata_values=None, bands=None):
"""
从栅格数据创建掩膜
Args:
raster_path (str): 栅格文件路径
nodata_values (list or tuple, optional): 无数据值的列表,如果为None则从栅格文件读取。默认为None。
bands (list or tuple, optional): 要处理的波段列表,如果为None则只处理第一个波段。默认为None。
Returns:
tuple: (mask_array, geotransform, projection, width, height)
"""
# 打开栅格文件
dataset = gdal.Open(raster_path, gdal.GA_ReadOnly)
if dataset is None:
raise ValueError(f"无法打开栅格文件: {raster_path}")
# 获取栅格信息
width = dataset.RasterXSize
height = dataset.RasterYSize
geotransform = dataset.GetGeoTransform()
projection = dataset.GetProjection()
band_count = dataset.RasterCount
print(f"栅格信息:{width} x {height}, 波段数: {band_count}, 投影坐标: {projection}")
# 确定要处理的波段
if bands is None:
bands = [1] # 默认只处理第一个波段
else:
# 确保波段索引在有效范围内
bands = [b for b in bands if 1 <= b <= band_count]
if not bands:
bands = [1] # 如果没有有效波段,则使用第一个波段
print(f"处理波段: {bands}")
# 初始化掩膜为全1(有效数据)
mask = np.ones((height, width), dtype=np.uint8)
# 处理每个波段
for band_index in bands:
band = dataset.GetRasterBand(band_index)
# 获取栅格文件本身的 nodata_value
raster_nodata = band.GetNoDataValue()
# 读取数据
data = band.ReadAsArray()
# 合并传入的 nodata_values 和栅格文件的 nodata_value
all_nodata_values = set()
if nodata_values is not None:
all_nodata_values.update(nodata_values)
if raster_nodata is not None:
all_nodata_values.add(raster_nodata)
# 创建当前波段的掩膜:有效数据为1,无效数据为0
band_mask = np.ones_like(data, dtype=np.uint8)
if all_nodata_values:
for value in all_nodata_values:
band_mask[np.isclose(data, value, equal_nan=True)] = 0
else:
# 如果没有无数据值,则将NaN视为无数据
band_mask = np.where(np.isnan(data), 0, 1).astype(np.uint8)
# 将当前波段掩膜与总掩膜进行"与"操作(只有所有波段都是有效数据的像素才保留)
mask = mask & band_mask
dataset = None # 关闭数据集
return mask, geotransform, projection, width, height
def create_mask_chunked(self, raster_path, nodata_values=None, bands=None):
"""
分块创建掩膜(用于超大栅格)
Args:
raster_path (str): 栅格文件路径
nodata_values (list or tuple, optional): 无数据值的列表,如果为None则从栅格文件读取。默认为None。
bands (list or tuple, optional): 要处理的波段列表,如果为None则只处理第一个波段。默认为None。
Returns:
tuple: (mask_array, geotransform, projection, width, height)
"""
print("使用分块处理模式...")
start_time = time.time()
# 打开栅格文件
dataset = gdal.Open(raster_path, gdal.GA_ReadOnly)
if dataset is None:
raise ValueError(f"无法打开栅格文件: {raster_path}")
# 获取栅格信息
width = dataset.RasterXSize
height = dataset.RasterYSize
geotransform = dataset.GetGeoTransform()
projection = dataset.GetProjection()
band_count = dataset.RasterCount
print(f"栅格信息:{width} x {height}, 波段数: {band_count}, 投影坐标: {projection}")
# 确定要处理的波段
if bands is None:
bands = [1] # 默认只处理第一个波段
else:
# 确保波段索引在有效范围内
bands = [b for b in bands if 1 <= b <= band_count]
if not bands:
bands = [1] # 如果没有有效波段,则使用第一个波段
print(f"处理波段: {bands}")
# 创建输出掩膜数组(初始化为全1,表示所有像素都是有效的)
mask = np.ones((height, width), dtype=np.uint8)
# 计算分块参数
x_chunks = (width + self.chunk_size - 1) // self.chunk_size
y_chunks = (height + self.chunk_size - 1) // self.chunk_size
total_chunks = x_chunks * y_chunks
print(f"分块处理:{x_chunks} x {y_chunks} = {total_chunks} 个块")
# 对每个波段进行处理
for band_index in bands:
band = dataset.GetRasterBand(band_index)
# 获取栅格文件本身的 nodata_value
raster_nodata = band.GetNoDataValue()
# 合并传入的 nodata_values 和栅格文件的 nodata_value
all_nodata_values = set()
if nodata_values is not None:
all_nodata_values.update(nodata_values)
if raster_nodata is not None:
all_nodata_values.add(raster_nodata)
print(f"处理波段 {band_index},无数据值: {all_nodata_values}")
# 分块处理
processed_chunks = 0
for y_chunk in range(y_chunks):
for x_chunk in range(x_chunks):
# 计算当前块的范围
x_start = x_chunk * self.chunk_size
y_start = y_chunk * self.chunk_size
x_size = min(self.chunk_size, width - x_start)
y_size = min(self.chunk_size, height - y_start)
# 读取数据块
data_chunk = band.ReadAsArray(x_start, y_start, x_size, y_size)
# 创建掩膜块
if all_nodata_values:
mask_chunk = np.ones_like(data_chunk, dtype=np.uint8)
for value in all_nodata_values:
mask_chunk[np.isclose(data_chunk, value, equal_nan=True)] = 0
else:
# 如果没有无数据值,则将NaN视为无数据
mask_chunk = np.where(np.isnan(data_chunk), 0, 1)
# 将当前波段的掩膜块与总掩膜进行"与"操作
# 只有当前块中的所有波段都是有效数据的像素才保留
mask[y_start:y_start+y_size, x_start:x_start+x_size] &= mask_chunk.astype(np.uint8)
processed_chunks += 1
if processed_chunks % 100 == 0 or processed_chunks == total_chunks:
progress = (processed_chunks / total_chunks) * 100
elapsed = time.time() - start_time
print(f"波段 {band_index} 进度: {processed_chunks}/{total_chunks} ({progress:.1f}%), 耗时: {elapsed:.1f}s")
# 强制垃圾回收
del data_chunk, mask_chunk
if processed_chunks % 50 == 0:
gc.collect()
dataset = None # 关闭数据集
print(f"分块处理完成,总耗时: {time.time() - start_time:.1f}s")
return mask, geotransform, projection, width, height
def polygonize_mask(self, mask, geotransform, projection, output_path, output_format='ESRI Shapefile'):
"""
将掩膜矢量化
Args:
mask (numpy.ndarray): 掩膜数组
geotransform (tuple): 地理变换参数
projection (str): 投影信息
output_path (str): 输出文件路径
output_format (str): 输出格式 ('ESRI Shapefile' 或 'GeoJSON')
"""
# 创建内存中的栅格数据集
mem_driver = gdal.GetDriverByName('MEM')
mem_dataset = mem_driver.Create('', mask.shape[1], mask.shape[0], 1, gdal.GDT_Byte)
mem_dataset.SetGeoTransform(geotransform)
mem_dataset.SetProjection(projection)
# 写入掩膜数据
mem_band = mem_dataset.GetRasterBand(1)
mem_band.WriteArray(mask)
mem_band.SetNoDataValue(0)
# 创建输出矢量数据集
if output_format == 'GeoJSON':
driver_name = 'GeoJSON'
if not output_path.endswith('.geojson'):
output_path += '.geojson'
else:
driver_name = 'ESRI Shapefile'
if not output_path.endswith('.shp'):
output_path += '.shp'
# 删除已存在的输出文件
if os.path.exists(output_path):
vector_driver = ogr.GetDriverByName(driver_name)
vector_driver.DeleteDataSource(output_path)
# 创建矢量数据源
vector_driver = ogr.GetDriverByName(driver_name)
vector_dataset = vector_driver.CreateDataSource(output_path)
# 创建空间参考系统
srs = osr.SpatialReference()
srs.ImportFromWkt(projection)
# 创建图层
layer = vector_dataset.CreateLayer('footprint', srs, ogr.wkbPolygon)
# 添加字段
field_def = ogr.FieldDefn('DN', ogr.OFTInteger)
layer.CreateField(field_def)
# 执行矢量化
gdal.Polygonize(mem_band, mem_band, layer, 0, [], callback=None)
# 清理
mem_dataset = None
vector_dataset = None
print(f"矢量化完成,输出文件: {output_path}")
def polygonize_mask_optimized(self, mask, geotransform, projection, output_path, output_format='ESRI Shapefile'):
"""
优化的矢量化方法,支持大型掩膜
Args:
mask (numpy.ndarray): 掩膜数组
geotransform (tuple): 地理变换参数
projection (str): 投影信息
output_path (str): 输出文件路径
output_format (str): 输出格式
"""
print("开始优化矢量化...")
start_time = time.time()
# 检查掩膜大小,决定是否需要特殊处理
mask_size_mb = mask.nbytes / (1024 * 1024)
print(f"掩膜大小: {mask_size_mb:.1f} MB")
# 创建内存中的栅格数据集
mem_driver = gdal.GetDriverByName('MEM')
mem_dataset = mem_driver.Create('', mask.shape[1], mask.shape[0], 1, gdal.GDT_Byte)
mem_dataset.SetGeoTransform(geotransform)
mem_dataset.SetProjection(projection)
# 写入掩膜数据
mem_band = mem_dataset.GetRasterBand(1)
mem_band.WriteArray(mask)
mem_band.SetNoDataValue(0)
# 设置缓存大小以优化性能
gdal.SetCacheMax(min(512 * 1024 * 1024, int(mask_size_mb * 2 * 1024 * 1024))) # 设置为掩膜大小的2倍或512MB
# 创建输出矢量数据集
if output_format == 'GeoJSON':
driver_name = 'GeoJSON'
if not output_path.endswith('.geojson'):
output_path += '.geojson'
else:
driver_name = 'ESRI Shapefile'
if not output_path.endswith('.shp'):
output_path += '.shp'
# 删除已存在的输出文件
if os.path.exists(output_path):
vector_driver = ogr.GetDriverByName(driver_name)
vector_driver.DeleteDataSource(output_path)
# 创建矢量数据源
vector_driver = ogr.GetDriverByName(driver_name)
vector_dataset = vector_driver.CreateDataSource(output_path)
# 创建空间参考系统
srs = osr.SpatialReference()
srs.ImportFromWkt(projection)
# 创建图层
layer = vector_dataset.CreateLayer('footprint', srs, ogr.wkbPolygon)
# 添加字段
field_def = ogr.FieldDefn('DN', ogr.OFTInteger)
layer.CreateField(field_def)
# 执行矢量化(带进度回调)
def progress_callback(complete, message, data):
if complete % 0.1 < 0.01: # 每10%显示一次进度
elapsed = time.time() - start_time
print(f"矢量化进度: {complete*100:.1f}%, 耗时: {elapsed:.1f}s")
return 1 # 继续处理
gdal.Polygonize(mem_band, mem_band, layer, 0, [], callback=progress_callback)
# 清理
mem_dataset = None
vector_dataset = None
elapsed = time.time() - start_time
print(f"优化矢量化完成,耗时: {elapsed:.1f}s,输出文件: {output_path}")
def filter_small_regions(self, mask, geotransform, min_area_km2=10):
"""
过滤掩膜中小于指定面积的无效区域(将小的无效区域转换为有效区域)
Args:
mask (numpy.ndarray): 掩膜数组,1表示有效数据,0表示无效数据
geotransform (tuple): 地理变换参数,用于计算像素实际面积
min_area_km2 (float): 最小区域面积(平方千米),小于此面积的无效区域将被视为有效区域
Returns:
numpy.ndarray: 过滤后的掩膜数组
"""
print(f"开始过滤小于 {min_area_km2} 平方千米的无效区域...")
start_time = time.time()
# 计算像素面积(平方米)
pixel_width = abs(geotransform[1])
pixel_height = abs(geotransform[5])
pixel_area_m2 = pixel_width * pixel_height
# 转换最小面积为像素数
min_area_m2 = min_area_km2 * 1000000 # 平方千米转平方米
min_pixels = int(min_area_m2 / pixel_area_m2)
print(f"像素面积: {pixel_area_m2:.2f} 平方米")
print(f"最小区域面积: {min_area_km2} 平方千米 = {min_area_m2} 平方米 = {min_pixels} 像素")
# 创建无效区域掩膜(0表示有效数据,1表示无效数据,与原掩膜相反)
invalid_mask = 1 - mask
# 标记连通区域
labeled_array, num_features = ndimage.label(invalid_mask)
print(f"检测到 {num_features} 个无效区域")
if num_features == 0:
print("没有无效区域需要处理")
return mask
# 计算每个连通区域的大小
sizes = ndimage.sum(invalid_mask, labeled_array, range(1, num_features + 1))
# 找出小于阈值的区域标签
small_regions = np.where(sizes < min_pixels)[0] + 1 # +1 因为标签从1开始
# 创建新掩膜,将小区域转换为有效区域
filtered_mask = mask.copy()
if len(small_regions) > 0:
for region_label in small_regions:
# 将小的无效区域转换为有效区域(设置为1)
filtered_mask[labeled_array == region_label] = 1
print(f"已过滤 {len(small_regions)} 个小于 {min_area_km2} 平方千米的无效区域")
print(f"过滤前无效像素数: {np.sum(invalid_mask)}")
print(f"过滤后无效像素数: {np.sum(1 - filtered_mask)}")
else:
print("没有小于阈值的无效区域需要过滤")
elapsed = time.time() - start_time
print(f"区域过滤完成,耗时: {elapsed:.1f}s")
return filtered_mask
def filter_valid_polygons(self, input_path, output_path, output_format='ESRI Shapefile'):
"""
过滤有效的多边形(DN值为1的多边形,即有数据的区域)
Args:
input_path (str): 输入矢量文件路径
output_path (str): 输出矢量文件路径
output_format (str): 输出格式
"""
# 打开输入数据源
input_dataset = ogr.Open(input_path, 0)
input_layer = input_dataset.GetLayer()
# 创建输出数据源
if output_format == 'GeoJSON':
driver_name = 'GeoJSON'
if not output_path.endswith('.geojson'):
output_path += '.geojson'
else:
driver_name = 'ESRI Shapefile'
if not output_path.endswith('.shp'):
output_path += '.shp'
# 删除已存在的输出文件
if os.path.exists(output_path):
output_driver = ogr.GetDriverByName(driver_name)
output_driver.DeleteDataSource(output_path)
# 创建输出数据源
output_driver = ogr.GetDriverByName(driver_name)
output_dataset = output_driver.CreateDataSource(output_path)
# 创建输出图层
output_layer = output_dataset.CreateLayer(
'footprint',
input_layer.GetSpatialRef(),
ogr.wkbPolygon
)
# 复制字段定义
input_layer_defn = input_layer.GetLayerDefn()
for i in range(input_layer_defn.GetFieldCount()):
field_defn = input_layer_defn.GetFieldDefn(i)
output_layer.CreateField(field_defn)
# 过滤并复制要素
valid_count = 0
for feature in input_layer:
dn_value = feature.GetField('DN')
if dn_value == 1: # 只保留有数据的区域
output_layer.CreateFeature(feature)
valid_count += 1
# 清理
input_dataset = None
output_dataset = None
print(f"过滤完成,保留了 {valid_count} 个有效多边形")
def process_raster(self, input_raster, output_vector, output_format='ESRI Shapefile', nodata_value=None, bands=None, min_area_km2=None):
"""
处理栅格文件,生成矢量边界(智能选择处理策略)
Args:
input_raster (str): 输入栅格文件路径
output_vector (str): 输出矢量文件路径
output_format (str): 输出格式 ('ESRI Shapefile' 或 'GeoJSON')
nodata_value (float or list): 无数据值,可以是单个值或值列表
bands (list or tuple, optional): 要处理的波段列表,如果为None则只处理第一个波段。默认为None。
min_area_km2 (float, optional): 最小无效区域面积(平方千米),小于此面积的无效区域将被视为有效区域。默认为None,表示不进行面积过滤。
"""
try:
print(f"开始处理栅格文件: {input_raster}")
start_time = time.time()
# 0. 获取栅格信息并选择处理策略
print("分析栅格文件...")
raster_info = self.get_raster_info(input_raster)
print(f"栅格大小: {raster_info['width']} x {raster_info['height']}")
print(f"估算内存需求: {raster_info['estimated_size_mb']:.1f} MB")
use_chunked = self.should_use_chunked_processing(raster_info)
if use_chunked:
print("检测到大型栅格,使用分块处理模式")
else:
print("使用标准处理模式")
# 1. 创建掩膜
print("创建掩膜...")
if use_chunked:
mask, geotransform, projection, width, height = self.create_mask_chunked(
input_raster, nodata_values=nodata_value, bands=bands
)
else:
mask, geotransform, projection, width, height = self.create_mask_from_raster(
input_raster, nodata_values=nodata_value, bands=bands
)
# 1.5 过滤小面积无效区域(如果指定了最小面积)
if min_area_km2 is not None and min_area_km2 > 0:
print(f"应用连通性大小筛选,最小无效区域面积: {min_area_km2} 平方千米")
mask = self.filter_small_regions(mask, geotransform, min_area_km2)
# 2. 矢量化掩膜
print("矢量化掩膜...")
temp_output = output_vector + '_temp'
# 根据掩膜大小选择矢量化方法
mask_size_mb = mask.nbytes / (1024 * 1024)
if mask_size_mb > 100: # 大于100MB使用优化方法
self.polygonize_mask_optimized(mask, geotransform, projection, temp_output, output_format)
else:
self.polygonize_mask(mask, geotransform, projection, temp_output, output_format)
# 3. 过滤有效多边形
print("过滤有效多边形...")
if output_format == 'GeoJSON':
temp_output += '.geojson'
else:
temp_output += '.shp'
self.filter_valid_polygons(temp_output, output_vector, output_format)
# 4. 清理临时文件
if output_format == 'ESRI Shapefile':
# 删除shapefile相关文件
for ext in ['.shp', '.shx', '.dbf', '.prj']:
temp_file = temp_output.replace('.shp', ext)
if os.path.exists(temp_file):
os.remove(temp_file)
else:
if os.path.exists(temp_output):
os.remove(temp_output)
# 5. 清理内存
del mask
gc.collect()
total_time = time.time() - start_time
print(f"处理完成!总耗时: {total_time:.1f}s")
except Exception as e:
import traceback
error_details = traceback.format_exc()
print(f"处理过程中发生错误: {str(e)}")
print(f"详细错误信息:")
print(error_details)
raise
def main():
"""主函数"""
parser = argparse.ArgumentParser(description='栅格数据转矢量边界工具(支持超大栅格优化处理)')
parser.add_argument('input', help='输入栅格文件路径 (.tif)')
parser.add_argument('output', help='输出矢量文件路径 (.shp 或 .geojson)')
parser.add_argument('--format', choices=['shapefile', 'geojson'], default='shapefile',
help='输出格式 (默认: shapefile)')
parser.add_argument('--nodata', type=float, nargs='+', help='无数据值,可以指定多个值 (如果不指定则从栅格文件读取)')
parser.add_argument('--bands', type=int, nargs='+', help='要处理的波段索引,可以指定多个波段 (默认: 只处理第一个波段)')
parser.add_argument('--chunk-size', type=int, default=2048,
help='分块大小(像素),用于大型栅格处理 (默认: 2048)')
parser.add_argument('--max-memory', type=int, default=1024,
help='最大内存使用量(MB) (默认: 1024)')
parser.add_argument('--no-multiprocessing', action='store_true',
help='禁用多进程处理')
parser.add_argument('--min-area', type=float, default=10.0,
help='最小无效区域面积(平方千米),小于此面积的无效区域将被视为有效区域 (默认: 10.0)')
args = parser.parse_args()
# 检查输入文件是否存在
if not os.path.exists(args.input):
print(f"错误: 输入文件不存在: {args.input}")
sys.exit(1)
# 确定输出格式
output_format = 'GeoJSON' if args.format == 'geojson' else 'ESRI Shapefile'
# 创建处理器并执行
processor = RasterToVector(
chunk_size=args.chunk_size,
max_memory_mb=args.max_memory,
use_multiprocessing=not args.no_multiprocessing
)
print(f"处理器配置:")
print(f" 分块大小: {args.chunk_size} 像素")
print(f" 最大内存: {args.max_memory} MB")
print(f" 多进程: {'启用' if not args.no_multiprocessing else '禁用'}")
print(f" CPU核心数: {processor.cpu_cores}")
print(f" 最小无效区域面积: {args.min_area} 平方千米")
print()
processor.process_raster(args.input, args.output, output_format, args.nodata, args.bands, args.min_area)
if __name__ == '__main__':
main()
代码-读写数据库
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
从tif文件中提取空间数据并写入PostgreSQL数据库
"""
import os
import sys
import logging
from datetime import datetime
from typing import Optional, List, Dict, Any, Tuple, Union
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
import multiprocessing
import tempfile
import shutil
import numpy as np
# 导入GDAL/OGR模块
from osgeo import gdal, ogr, osr
# 导入栅格转矢量工具
from raster_to_vector import RasterToVector
# 数据库连接
from pg_tools import PostgresDB
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(sys.stdout)
]
)
logger = logging.getLogger(__name__)
class TifGeometryExtractor:
"""从TIF文件提取几何信息并写入数据库"""
def __init__(self, db: PostgresDB, max_workers: int = None):
"""初始化提取器
Args:
db: PostgresDB实例
max_workers: 最大工作线程数,默认为CPU核心数
"""
self.db = db
self.max_workers = max_workers or multiprocessing.cpu_count()
# 创建栅格转矢量处理器
self.raster_processor = RasterToVector(
chunk_size=2048,
max_memory_mb=1024,
use_multiprocessing=True
)
def get_tif_files_from_db(self, disk_ids: Union[List[int], int] = None) -> List[Dict[str, Any]]:
"""从数据库中获取所有TIF文件
Args:
disk_ids: 硬盘ID列表或单个硬盘ID,如果不提供则获取所有硬盘的文件
Returns:
包含TIF文件信息的字典列表
"""
query = """SELECT
id, disk_id, file_path, parent_dir, file_name, file_ext
FROM
ds.ew_disk_index_desc_partitioned edf
WHERE
(edf.file_ext = 'tif' OR edf.file_ext = 'tiff')
AND NOT EXISTS (
SELECT 1
FROM ds.ew_spatial_data esd
WHERE esd.file_id = edf.id
)"""
params = None
if disk_ids is not None:
# 转换为列表
if isinstance(disk_ids, int):
disk_ids = [disk_ids]
# 构建IN查询
placeholders = ','.join(['%s'] * len(disk_ids))
query += f" AND disk_id IN ({placeholders})"
params = tuple(disk_ids)
self.db.execute(query, params)
return self.db.fetchall()
def extract_geometry_from_tif(self, file_path: str) -> Tuple[Optional[str], str, float]:
"""从TIF文件提取几何信息
Args:
file_path: TIF文件路径
Returns:
(WKT几何对象, 源坐标系, 面积)的元组,如果提取失败则返回(None, None, 0)
"""
try:
# 创建临时目录
temp_dir = tempfile.mkdtemp()
temp_output = os.path.join(temp_dir, "temp_vector")
try:
# 使用栅格转矢量工具处理TIF文件
self.raster_processor.process_raster(
input_raster=file_path,
output_vector=temp_output,
output_format='ESRI Shapefile',
nodata_value=[-9999, 0, 255, np.nan],
bands=[1],
# min_area_km2=0.0001
)
# 读取生成的GeoJSON文件
geojson_path = temp_output + '.shp'
if not os.path.exists(geojson_path):
logger.warning(f"未生成GeoJSON文件: {geojson_path}")
return None, None, 0
# 使用GDAL读取GeoJSON文件
from osgeo import ogr
# 打开GeoJSON文件
datasource = ogr.Open(geojson_path)
if datasource is None:
logger.warning(f"无法打开GeoJSON文件: {geojson_path}")
return None, None, 0
# 获取图层
layer = datasource.GetLayer(0)
if layer is None:
logger.warning(f"GeoJSON文件不包含图层: {geojson_path}")
return None, None, 0
# 获取空间参考
spatial_ref = layer.GetSpatialRef()
source_crs = spatial_ref.ExportToWkt() if spatial_ref else None
# 获取要素数量
feature_count = layer.GetFeatureCount()
print(f"GeoJSON文件包含 {feature_count} 个要素")
if feature_count == 0:
logger.warning(f"GeoJSON文件不包含要素: {geojson_path}")
return None, None, 0
# 创建一个多边形集合
multi_polygon = ogr.Geometry(ogr.wkbMultiPolygon)
# 遍历所有要素
layer.ResetReading()
feature = layer.GetNextFeature()
# 寻找最大的多边形
largest_polygon = None
max_area = 0
while feature:
# 获取几何对象
geometry = feature.GetGeometryRef()
if geometry:
# 检查几何类型
geom_type = geometry.GetGeometryType()
# 如果是多边形,检查面积
if geom_type == ogr.wkbPolygon:
area = geometry.GetArea()
if area > max_area:
max_area = area
largest_polygon = geometry.Clone()
# 无论如何,也添加到多边形集合中(作为备用)
multi_polygon.AddGeometry(geometry.Clone())
feature = layer.GetNextFeature()
if largest_polygon:
# 如果找到了最大的多边形,使用它
geometry = largest_polygon
print("使用最大的单个多边形")
elif not multi_polygon.IsEmpty():
# 否则使用多边形集合
geometry = multi_polygon
print("使用多边形集合")
else:
logger.warning(f"所有要素都不包含几何对象: {geojson_path}")
return None, None, 0
# 保存原始几何对象的副本
original_geometry = geometry.Clone()
# 创建目标坐标系统 (EPSG:4547)
target_srs = osr.SpatialReference()
target_srs.ImportFromEPSG(4547)
# 创建坐标转换
is_transformed = False
if spatial_ref:
# 确保空间参考系统正确设置
spatial_ref.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
target_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
# 输出源坐标系统信息用于调试
print(f"源坐标系统: {spatial_ref.ExportToWkt()}")
print(f"目标坐标系统: {target_srs.ExportToWkt()}")
transform = osr.CoordinateTransformation(spatial_ref, target_srs)
# 转换几何对象到目标坐标系
try:
geometry.Transform(transform)
print("坐标系统已转换为 EPSG:4547")
print(f"转换后的几何WKT: {geometry.ExportToWkt()[:100]}...") # 只打印前100个字符
is_transformed = True
except Exception as e:
logger.warning(f"坐标转换失败: {e}")
print("坐标转换失败,将不会写入数据库")
# 返回None表示处理失败
return None, None, 0
# 计算面积(平方千米)
# EPSG:4547 使用米作为单位,所以除以 1,000,000 转换为平方千米
area_sqkm = geometry.GetArea() / 1_000_000
print(f"计算面积: {area_sqkm} 平方千米")
# 获取WKT表示
wkt = geometry.ExportToWkt()
# 更新源坐标系为目标坐标系
source_crs = target_srs.ExportToWkt()
# 清理
datasource = None
return wkt, source_crs, area_sqkm
finally:
# 清理临时目录
shutil.rmtree(temp_dir, ignore_errors=True)
except Exception as e:
logger.error(f"处理文件 {file_path} 时出错: {e}")
return None, None, 0
def process_tif_file(self, file_info: Dict[str, Any]) -> bool:
"""处理单个TIF文件
Args:
file_info: 文件信息字典
Returns:
处理是否成功
"""
file_path = file_info['file_path']
file_id = file_info['id']
disk_id = file_info['disk_id']
try:
# 检查文件是否存在
if not os.path.exists(file_path):
logger.warning(f"文件不存在: {file_path}")
return False
# 提取几何信息
wkt, source_crs, area_sqkm = self.extract_geometry_from_tif(file_path)
if wkt is None:
logger.warning(f"无法从文件 {file_path} 提取几何信息")
return False
# 获取文件名和文件路径
file_name = file_info['file_name']
file_path = file_info['file_path']
# 使用固定的SRID 4547,因为我们已经确保几何对象已成功转换为EPSG:4547
srid = 4547
print(f"使用SRID: {srid} 写入数据库")
# 打印WKT的前100个字符,用于调试
print(f"WKT前100个字符: {wkt[:100]}...")
# 插入数据库
# 使用ST_Force2D确保几何对象是2D的
# 使用ST_SetSRID明确设置SRID
# 使用ST_Multi将Polygon转换为MultiPolygon以匹配数据库列类型
insert_sql = f"""
INSERT INTO ds.ew_spatial_data (
disk_id, file_id, file_name, file_path, geometry, source_crs, area_sqkm, created_at
) VALUES (
%s, %s, %s, %s, ST_Multi(ST_SetSRID(ST_Force2D(ST_GeomFromText(%s)), {srid})), %s, %s, NOW()
) ON CONFLICT (disk_id, file_id) DO UPDATE SET
file_name = %s,
file_path = %s,
geometry = ST_Multi(ST_SetSRID(ST_Force2D(ST_GeomFromText(%s)), {srid})),
source_crs = %s,
area_sqkm = %s,
created_at = NOW()
"""
self.db.execute(insert_sql, (
disk_id, file_id, file_name, file_path, wkt, source_crs, area_sqkm,
file_name, file_path, wkt, source_crs, area_sqkm
))
logger.info(f"成功处理文件 {file_path}")
return True
except Exception as e:
logger.error(f"处理文件 {file_path} 时出错: {e}")
return False
def process_all_tif_files(self, disk_ids: Union[List[int], int] = None) -> Dict[str, int]:
"""处理所有TIF文件
Args:
disk_ids: 硬盘ID列表或单个硬盘ID,如果不提供则处理所有硬盘的文件
Returns:
处理结果统计
"""
# 获取所有TIF文件
tif_files = self.get_tif_files_from_db(disk_ids)
total_files = len(tif_files)
if total_files == 0:
logger.info("未找到TIF文件")
return {"total": 0, "success": 0, "failed": 0}
logger.info(f"找到 {total_files} 个TIF文件")
# 使用线程池并行处理文件
success_count = 0
failed_count = 0
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
# 提交所有任务
future_to_file = {
executor.submit(self.process_tif_file, file_info): file_info
for file_info in tif_files
}
# 获取结果
for i, future in enumerate(concurrent.futures.as_completed(future_to_file)):
try:
success = future.result()
if success:
success_count += 1
else:
failed_count += 1
except Exception as e:
logger.error(f"处理文件时发生错误: {e}")
failed_count += 1
# 显示进度
if (i + 1) % 10 == 0 or (i + 1) == total_files:
progress = (i + 1) / total_files * 100
logger.info(f"进度: {progress:.2f}% ({i + 1}/{total_files})")
logger.info(f"处理完成: 总共 {total_files} 个文件, 成功 {success_count} 个, 失败 {failed_count} 个")
return {
"total": total_files,
"success": success_count,
"failed": failed_count
}
def main():
"""主函数"""
# 数据库配置
DB_CONFIG = {
"dbname": "dc",
"user": "postgres",
"password": "123456",
"host": "172.31.60.107",
"port": "5432"
}
# 解析命令行参数
import argparse
# python extract_tif_geometry.py --disk-ids 1,2 --threads 4
parser = argparse.ArgumentParser(description='从TIF文件提取几何信息并写入数据库')
parser.add_argument('--disk-ids', type=str, help='硬盘ID列表,用逗号分隔,如"1,2,3"')
parser.add_argument('--threads', type=int, default=None, help='线程数,默认为CPU核心数')
args = parser.parse_args()
# 处理硬盘ID参数
disk_ids = None
if args.disk_ids:
try:
disk_ids = [int(id_str.strip()) for id_str in args.disk_ids.split(',') if id_str.strip()]
if not disk_ids:
logger.warning("未提供有效的硬盘ID,将处理所有硬盘")
disk_ids = None
except ValueError:
logger.error("硬盘ID格式错误,应为逗号分隔的整数列表")
sys.exit(1)
try:
# 连接数据库
with PostgresDB(**DB_CONFIG) as db:
# 创建提取器
extractor = TifGeometryExtractor(db, max_workers=args.threads)
# 处理文件
start_time = datetime.now()
result = extractor.process_all_tif_files(disk_ids=disk_ids)
end_time = datetime.now()
# 显示统计信息
duration = (end_time - start_time).total_seconds()
logger.info(f"总耗时: {duration:.2f} 秒")
if result['total'] > 0:
logger.info(f"平均速度: {result['total']/duration:.2f} 文件/秒")
logger.info(f"成功率: {result['success']/result['total']*100:.2f}%")
except Exception as e:
logger.error(f"程序执行出错: {e}")
sys.exit(1)
if __name__ == "__main__":
main()
# C:\Users\Admin\Desktop\ew_disk\ew_tif\disk\disk\python.exe C:\Users\Admin\Desktop\ew_disk\ew_tif\extract_tif_geometry.py --disk-ids 1,2 --threads 1