import os
import re
import glob
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import rasterio
from rasterio.plot import plotting_extent
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.ticker import FuncFormatter
from pyproj import Transformer
import pyproj

# ======== PROJ data path ========

# Point PROJ at pyproj's bundled data directory.
# PROJ >= 9.1 reads PROJ_DATA first and only falls back to the deprecated
# PROJ_LIB, so set both: an inherited PROJ_DATA (for example one exported by a
# conda activation script) would otherwise silently win over this override.
_proj_data_dir = pyproj.datadir.get_data_dir()
os.environ["PROJ_DATA"] = _proj_data_dir
os.environ["PROJ_LIB"] = _proj_data_dir

# ======== Global font style ========

# Global typography: bold, upright Times New Roman at one shared point size
# for body text, titles, axis labels, legends and tick labels.
_FONT_SIZE = 10.5
mpl.rcParams.update({
    "font.family": "Times New Roman",
    "font.size": _FONT_SIZE,
    "font.weight": "bold",
    "axes.labelweight": "bold",
    "axes.titleweight": "bold",
    "axes.titlesize": _FONT_SIZE,
    "axes.labelsize": _FONT_SIZE,
    "legend.fontsize": _FONT_SIZE,
    "xtick.labelsize": _FONT_SIZE,
    "ytick.labelsize": _FONT_SIZE,
    "font.style": "normal",
})

def deg_to_dms(deg: float) -> str:
    """Format decimal degrees as a degrees-minutes-seconds string.

    The angle is rounded to the nearest whole arc-second *before* it is split
    into components, so carries propagate correctly: ``10.2`` gives
    ``"10°12′0″"`` instead of ``"10°11′60″"``.

    Args:
        deg: Angle in decimal degrees. Negative values get a leading ``-``
            unless they round to zero.

    Returns:
        A string such as ``"116°23′28″"``; no hemisphere letter is appended.
    """
    value = float(deg)  # normalise NumPy scalars so round() returns an int
    total_seconds = round(abs(value) * 3600)
    degrees, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    # Avoid "-0°0′0″" for tiny negative values that round to zero.
    sign = "-" if value < 0 and total_seconds else ""
    return f"{sign}{degrees}°{minutes}′{seconds}″"

def _read_masked_band(tif_path):
    """Read band 1 of *tif_path* as a masked array, with its plot extent and CRS."""
    with rasterio.open(tif_path) as src:
        # masked=True masks the dataset's nodata/mask pixels; unlike
        # np.ma.masked_equal(data, nodata) this also works for NaN nodata.
        data = src.read(1, masked=True)
        extent = plotting_extent(src)
        src_crs = src.crs
    # Additionally hide NaN/inf values that are not declared as nodata.
    return np.ma.masked_invalid(data), extent, src_crs


def _lon_label(lon):
    """Format a longitude as D°M′S″ followed by an E/W hemisphere letter."""
    return deg_to_dms(abs(lon)) + ("E" if lon >= 0 else "W")


def _lat_label(lat):
    """Format a latitude as D°M′S″ followed by an N/S hemisphere letter."""
    return deg_to_dms(abs(lat)) + ("N" if lat >= 0 else "S")


def _set_dms_ticks(ax, extent, src_crs, n_xticks=6, n_yticks=5):
    """Put evenly spaced ticks on *ax* and label them as lon/lat in D°M′S″."""
    xmin, xmax, ymin, ymax = extent

    if src_crs is not None and not src_crs.is_geographic:
        # Projected CRS: pick evenly spaced lon/lat values, then project them
        # back into the raster's coordinates so the ticks line up with pixels.
        transformer = Transformer.from_crs(src_crs, "EPSG:4326", always_xy=True)
        lon_min, lat_min = transformer.transform(xmin, ymin)
        lon_max, lat_max = transformer.transform(xmax, ymax)

        x_ticks_lon = np.linspace(lon_min, lon_max, n_xticks)
        y_ticks_lat = np.linspace(lat_min, lat_max, n_yticks)
        ax.set_xticks([
            transformer.transform(lon, lat_min, direction="INVERSE")[0]
            for lon in x_ticks_lon
        ])
        ax.set_yticks([
            transformer.transform(lon_min, lat, direction="INVERSE")[1]
            for lat in y_ticks_lat
        ])

        # Labels are evaluated along the bottom (x) and left (y) raster edges.
        def fmt_x(x, pos):
            lon, _ = transformer.transform(x, ymin, direction="FORWARD")
            return _lon_label(lon)

        def fmt_y(y, pos):
            _, lat = transformer.transform(xmin, y, direction="FORWARD")
            return _lat_label(lat)
    else:
        # Geographic (or unknown) CRS: coordinates are used directly as
        # degrees -- NOTE(review): a missing CRS is assumed to be lon/lat.
        ax.set_xticks(np.linspace(xmin, xmax, n_xticks))
        ax.set_yticks(np.linspace(ymin, ymax, n_yticks))

        def fmt_x(x, pos):
            return _lon_label(x)

        def fmt_y(y, pos):
            return _lat_label(y)

    ax.xaxis.set_major_formatter(FuncFormatter(fmt_x))
    ax.yaxis.set_major_formatter(FuncFormatter(fmt_y))


def _add_inset_colorbar(fig, ax, im, vmin, vmax, n_ticks=6):
    """Draw a small vertical colorbar in the lower-left corner of *ax*."""
    cax = inset_axes(
        ax,
        width="2%",
        height="30%",
        loc="lower left",
        bbox_to_anchor=(0.02, 0.06, 1, 1),
        bbox_transform=ax.transAxes,
        borderpad=0,
    )
    cb = fig.colorbar(im, cax=cax, orientation="vertical")
    cb.ax.tick_params(labelsize=6)

    # Pin the ticks so the last one sits exactly at vmax. Relabelling whatever
    # the automatic locator produced could put ">vmax" on a lower tick.
    ticks = np.linspace(vmin, vmax, n_ticks)
    tick_labels = [f"{t:.2f}" for t in ticks]
    tick_labels[-1] = f">{vmax:g}"
    cb.set_ticks(ticks)
    cb.set_ticklabels(tick_labels)
    return cb


def plot_smc_one(tif_path, out_png, n, *, vmin=0.0, vmax=0.5, cmap_name="Spectral"):
    """Render one SMC retrieval GeoTIFF as a PNG map with D°M′S″ axis labels.

    Args:
        tif_path: Path to the raster; band 1 is plotted.
        out_png: Destination PNG path (written at 300 dpi).
        n: Value shown in the title ``SMCRetrievalN={n}.tif``.
        vmin: Lower colour-scale limit.
        vmax: Upper colour-scale limit; larger values saturate and the top
            colourbar tick is labelled ``>vmax``.
        cmap_name: Matplotlib colormap name.
    """
    data, extent, src_crs = _read_masked_band(tif_path)
    # Limit values to [0, 1] before display (masked pixels stay masked).
    data = np.clip(data, 0, 1.0)
    xmin, xmax, ymin, ymax = extent

    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        im = ax.imshow(
            data,
            extent=extent,
            origin="upper",
            cmap=cmap_name,
            vmin=vmin,
            vmax=vmax,
        )

        _set_dms_ticks(ax, extent, src_crs)

        # 10% blank margin around the raster on every side.
        margin_x = (xmax - xmin) * 0.1
        margin_y = (ymax - ymin) * 0.1
        ax.set_xlim(xmin - margin_x, xmax + margin_x)
        ax.set_ylim(ymin - margin_y, ymax + margin_y)

        # Heavy black frame.
        for spine in ax.spines.values():
            spine.set_edgecolor("black")
            spine.set_linewidth(2)

        _add_inset_colorbar(fig, ax, im, vmin, vmax)

        ax.set_title(f"SMCRetrievalN={n}.tif", pad=10)

        # Manual layout rather than tight_layout (which the original author
        # avoided alongside the inset colorbar axes).
        fig.subplots_adjust(left=0.04, right=0.99, top=0.92, bottom=0.04)

        # Render once before saving (kept from the original as a workaround
        # for a `_get_renderer` error); deliberately no bbox_inches="tight".
        fig.canvas.draw()
        fig.savefig(out_png, dpi=300)
    finally:
        # Always release the figure so a failing file in a batch run does not
        # leave open figures behind.
        plt.close(fig)

# ================= Batch processing =================

# Fill these in before running: the folder holding the SMCRetrievalN=<n>.tif
# rasters and the folder that receives the PNGs. An empty string means the
# current working directory.
root_dir = r""
out_dir = r""

_N_VALUE_RE = re.compile(r"N=(\d+)")


def find_smc_rasters(folder):
    """Return ``(n_value, path)`` pairs for SMCRetrievalN=<n>.tif files in *folder*.

    Only files whose name contains ``N=`` followed by digits are kept. The
    result is sorted numerically by N, because glob order is arbitrary.
    ``n_value`` is the digit string taken from the file name.
    """
    # Escape the folder so characters such as "[" are not treated as wildcards.
    pattern = os.path.join(glob.escape(folder), "SMCRetrievalN=*.tif")
    found = []
    for path in glob.glob(pattern):
        match = _N_VALUE_RE.search(os.path.basename(path))
        if match is not None:
            found.append((match.group(1), path))
    found.sort(key=lambda item: (int(item[0]), item[1]))
    return found


if __name__ == "__main__":
    # os.makedirs("") raises FileNotFoundError, so map an empty out_dir to ".".
    os.makedirs(out_dir or ".", exist_ok=True)
    for n_value, tif_path in find_smc_rasters(root_dir):
        out_png = os.path.join(out_dir, f"SMCRetrievalN={n_value}.png")
        plot_smc_one(tif_path, out_png=out_png, n=n_value)