# -*- coding: utf-8 -*-
import os
import sys
import numpy as np
import scipy.io
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.ticker as mticker
from matplotlib.font_manager import FontProperties
from matplotlib.patches import FancyBboxPatch

def parse_pct_trial(name):
    """Parse a random-removal scenario key such as ``'Pct40_Trial3'``.

    Args:
        name: Candidate scenario key. Non-string values are accepted and
            simply treated as "no match".

    Returns:
        ``(pct, trial)`` as two ints, e.g. ``(40, 3)``, when ``name`` has the
        form ``Pct<int>_Trial<int>``; ``(None, None)`` otherwise.
    """
    no_match = (None, None)
    if not isinstance(name, str):
        return no_match

    parts = name.split("_")
    if len(parts) != 2:
        return no_match

    pct_part, trial_part = parts
    if not (pct_part.startswith("Pct") and trial_part.startswith("Trial")):
        return no_match

    try:
        # Strip only the leading tag (not every occurrence of it), so that
        # malformed keys such as 'PctPct40_Trial3' are rejected.
        pct = int(pct_part[len("Pct"):])
        trial = int(trial_part[len("Trial"):])
    except ValueError:
        # The remainder is empty or not an integer literal.
        return no_match
    return pct, trial
    
# =============================================================================
# 0. Fonts (Windows)
# =============================================================================
# Explicit font handles loaded from the Windows system font directory; they
# are passed to individual text elements (titles, labels, ticks, legends).
arial = FontProperties(fname=r"C:\Windows\Fonts\arial.ttf")
# NOTE(review): `simsun` is not referenced again in the visible script.
simsun = FontProperties(fname=r"C:\Windows\Fonts\simsun.ttc")

# Global Matplotlib defaults for any text that is not given an explicit font.
mpl.rcParams.update({
    "font.family": "sans-serif",
    "font.sans-serif": ["Arial"],
    "axes.unicode_minus": False,
})



# =============================================================================
# 1. Inputs
# =============================================================================
npz_path = r'./ohc_all_0to2000_2005-01_2023-12.npz'

# np.load() on an .npz archive returns a lazy NpzFile that keeps the file
# handle open.  Read every array inside a `with` block so the handle is
# released, and keep a plain dict that supports the same
# `ohc_all_0to2000[key]` lookups used further down.
with np.load(npz_path, allow_pickle=False) as _npz:
    ohc_all_0to2000 = {name: _npz[name] for name in _npz.files}

# Monthly time axis 2005-01 .. 2023-12 (the stop value is exclusive).
# NOTE(review): not referenced again in the visible script.
months_full = np.arange(
    np.datetime64('2005-01'),
    np.datetime64('2024-01'),
    dtype='datetime64[M]'
)

# Archive keys to process: five random trials per removal percentage, the
# single-exclusion experiments (No*), and the reference run.
scenario_keys = [
    'Pct80_Trial1', 'Pct80_Trial2', 'Pct80_Trial3', 'Pct80_Trial4', 'Pct80_Trial5',
    'Pct60_Trial1', 'Pct60_Trial2', 'Pct60_Trial3', 'Pct60_Trial4', 'Pct60_Trial5',
    'Pct40_Trial1', 'Pct40_Trial2', 'Pct40_Trial3', 'Pct40_Trial4', 'Pct40_Trial5',
    'Pct20_Trial1', 'Pct20_Trial2', 'Pct20_Trial3', 'Pct20_Trial4', 'Pct20_Trial5',
    'NoUS', 'NoAU', 'NoJA', 'NoFR', 'NoGE', 'Ref'
]

# =============================================================================
# 2. Annual mean OHC and OHCT (centered difference)
# =============================================================================
# Mean length of a Gregorian calendar year (365.2425 days) in seconds.
sec_per_annual = 365.2425 * 24 * 3600

# 2.1 Monthly -> annual mean.  Each series is folded into rows of 12 values
# and averaged per row; reshape() raises if the length is not a multiple of 12.
# (Assumes every series starts in January -- TODO confirm against the input.)
ohc_all_0to2000_annual = {}
for exp_name in scenario_keys:
    monthly = ohc_all_0to2000[exp_name]
    ohc_all_0to2000_annual[exp_name] = monthly.reshape(-1, 12).mean(axis=1)

# 2.2 Annual OHC -> OHCT.  A centred difference over neighbouring years gives a
# change per year, which is divided by seconds-per-year and by 3.61e14
# (presumably the global ocean surface area in m^2 -- TODO confirm) to obtain
# a flux in W m^-2.
ohct_all_0to2000 = {}
for sc, series in ohc_all_0to2000_annual.items():
    annual = np.asarray(series)
    change_per_year = (annual[2:] - annual[:-2]) / 2
    ohct_all_0to2000[sc] = change_per_year / sec_per_annual / (3.61 * (10 ** 14))

# Time axis of the centred difference: the first and last year are lost, so
# this covers 2006 .. 2022 (the stop value is exclusive).
years_full_ohct = np.arange(
    np.datetime64('2006-01'),
    np.datetime64('2023-01'),
    dtype='datetime64[Y]'
)

# =============================================================================
# 3. Metrics: RMSE/STD(Ref) and trend error (relative to Ref)
# =============================================================================
# Calendar years matching the centred-difference series: 2006 .. 2022.
years = np.arange(2006, 2023)
ref = np.asarray(ohct_all_0to2000["Ref"], dtype=float)

# Fail early if the reference series does not line up with the year axis.
if ref.shape[0] != years.shape[0]:
    raise ValueError(f"Ref length {ref.shape[0]} != years length {years.shape[0]}")

# NaN-ignoring standard deviation of the reference series; panel (b) uses it
# to normalise the RMSE.
ref_std = np.nanstd(ref)

# 3.1 RMSE of every experiment against Ref, computed only over samples that
# are finite in both series.  An experiment whose length differs from Ref, or
# that shares no finite sample with it, gets NaN.
rmse = {}
for sc, series in ohct_all_0to2000.items():
    if sc in ("NoData", "NoArgo", "Ref"):
        continue

    y = np.asarray(series, dtype=float)
    score = np.nan
    if y.shape[0] == ref.shape[0]:
        mask = np.isfinite(ref) & np.isfinite(y)
        if mask.sum() > 0:
            residual = ref[mask] - y[mask]
            score = np.sqrt(np.nanmean(residual ** 2))
    rmse[sc] = score

# 3.2 Trend slope and error relative to REF
# Linear-fit slope of the reference series (per year); NaN when fewer than
# two finite samples are available.
mask_ref = np.isfinite(ref) & np.isfinite(years)
if mask_ref.sum() >= 2:
    trend_ref = np.polyfit(years[mask_ref], ref[mask_ref], 1)[0]
else:
    trend_ref = np.nan

# Per-experiment slope, absolute slope error (experiment - Ref) and relative
# slope error.  When no fit is possible only `trend_rel_err` receives an
# entry (NaN); `trend` and `trend_err` have no key for that experiment.
trend_rel_err = {}
trend = {}
trend_err = {}
for sc, series in ohct_all_0to2000.items():
    if sc in ("NoData", "NoArgo", "Ref"):
        continue

    y = np.asarray(series, dtype=float)
    can_fit = y.shape[0] == years.shape[0]
    if can_fit:
        mask = np.isfinite(y) & np.isfinite(years)
        # Need two points for a line, and a finite non-zero reference slope
        # for the relative error to be defined.
        can_fit = mask.sum() >= 2 and np.isfinite(trend_ref) and trend_ref != 0
    if not can_fit:
        trend_rel_err[sc] = np.nan
        continue

    slope = np.polyfit(years[mask], y[mask], 1)[0]
    trend[sc] = slope
    trend_err[sc] = slope - trend_ref
    trend_rel_err[sc] = (slope - trend_ref) / trend_ref

# =============================================================================
# 4. Organize random realizations by removal percentage / trial
# =============================================================================
pct_levels = [20, 40, 60, 80]
trials = [1, 2, 3, 4, 5]

# Horizontal offset added to each trial's x-position in panels (b)/(c).
# All zeros: the trials of one removal level are drawn at the same x.
# (A spread of {1: -1.6, 2: -0.8, 3: 0.0, 4: 0.8, 5: 1.6} used to be assigned
# here and was overwritten on the very next line; that dead store is dropped.)
offsets = {1: 0, 2: 0, 3: 0, 4: 0, 5: 0}


def _group_by_pct_trial(metric):
    """Regroup ``{'Pct40_Trial3': v, ...}`` into ``{40: {3: float(v)}, ...}``.

    Keys that do not parse as ``Pct<int>_Trial<int>`` (e.g. ``'NoUS'``), or
    whose percentage / trial is not listed in ``pct_levels`` / ``trials``,
    are skipped.  Every level in ``pct_levels`` is present in the result,
    possibly mapping to an empty dict.
    """
    grouped = {pct: {} for pct in pct_levels}
    for key, value in metric.items():
        pct, tr = parse_pct_trial(key)
        if pct in grouped and tr in trials:
            grouped[pct][tr] = float(value)
    return grouped


# RMSE and absolute trend error of the random realizations, indexed as
# [removal percentage][trial number].
rmse_by = _group_by_pct_trial(rmse)
trend_by = _group_by_pct_trial(trend_err)

# =============================================================================
# 5. Plot styling (single source of truth)
# =============================================================================
# Font sizes, line widths, marker size and alpha values shared by all panels.
STYLE = dict(
    title_fs=12,
    label_fs=12,
    tick_fs=10,
    legend_fs=8,
    lw_ref=1.5,
    lw_group=0.8,
    ms=7.5,
    alpha_line=1.0,
    alpha_marker=0.85,
)

# Line colours of the random-removal groups, in the order 20/40/60/80 %.
pct_color = ["#b6ceeb", "#6276b7", "#3c58a4", "#1c3c70"]
# Colours of the single-exclusion experiments, in the key order of `no_map`.
no_color = ["#9d2932", "#ea5414", "#fbb612", "#d2357d", "pink"]

# Every trial uses the same marker colour in panels (b)/(c).
trial_colors = {tr: "#2775b6" for tr in (1, 2, 3, 4, 5)}

# Scenario key -> (legend label, colour) for the single-exclusion experiments.
no_map = dict(zip(
    ("NoUS", "NoAU", "NoJA", "NoFR", "NoGE"),
    zip(("No-US", "No-AU", "No-JP", "No-FR", "No-DE"), no_color),
))

# Scenario-key prefix -> (legend label, colour) for the random-removal groups.
pct_map = dict(zip(
    ("Pct20", "Pct40", "Pct60", "Pct80"),
    zip(("No-20%", "No-40%", "No-60%", "No-80%"), pct_color),
))

# Display order of the panel (a) legend: reference first, then the
# random-removal groups, then the single-exclusion experiments.
legend_order = ["REF"]
legend_order += [label for label, _ in pct_map.values()]
legend_order += [label for label, _ in no_map.values()]

# X-positions for No-X points in panels (b)(c)
no_xpos = dict(NoUS=53, NoAU=17, NoJA=5, NoFR=5, NoGE=3)

# =============================================================================
# 6. Plot helpers
# =============================================================================
def set_tick_font(ax, font_prop, tick_fs):
    """Apply a font and a font size to every x and y tick label of ``ax``.

    Args:
        ax: Object exposing ``get_xticklabels()`` and ``get_yticklabels()``.
        font_prop: Font properties handed to each label's
            ``set_fontproperties``.
        tick_fs: Font size handed to each label's ``set_fontsize``; applied
            after the font properties.
    """
    for fetch_labels in (ax.get_xticklabels, ax.get_yticklabels):
        for tick_label in fetch_labels():
            tick_label.set_fontproperties(font_prop)
            tick_label.set_fontsize(tick_fs)

def filter_ticks_to_ylim(ticks, ylim):
    """Return the ticks that fall inside ``ylim``, bounds included.

    Args:
        ticks: Iterable of candidate tick values.
        ylim: Pair of axis limits; the order does not matter, so an inverted
            axis is handled as well.

    Returns:
        A new list with the in-range ticks, in their original order.
    """
    lower = min(ylim)
    upper = max(ylim)
    kept = []
    for tick in ticks:
        if lower <= tick <= upper:
            kept.append(tick)
    return kept

# =============================================================================
# 7. Figure layout
# =============================================================================
# 7x7 inch canvas on a 2x2 grid: the top row is slightly taller and is
# spanned entirely by panel (a); panels (b) and (c) share the bottom row.
fig = plt.figure(figsize=(7, 7), dpi=500)
gs = fig.add_gridspec(2, 2, height_ratios=[1.2, 1], hspace=0.2, wspace=0.52)

ax1 = fig.add_subplot(gs[0, :])                                 # (a)
ax2, ax3 = (fig.add_subplot(gs[1, col]) for col in (0, 1))      # (b), (c)

# =============================================================================
# (a) OHCT time series
# =============================================================================
# Draw every OHCT series on panel (a) and remember the first line drawn for
# each legend group so the legend shows one entry per group.
legend_handles = {}

for sc, series in ohct_all_0to2000.items():
    if sc in ("NoData", "NoArgo"):
        continue

    if sc == "Ref":
        # Reference run: dashed black, drawn on top.
        key = "REF"
        line_kw = dict(color="black", linestyle="--",
                       linewidth=STYLE["lw_ref"], zorder=3)
    elif sc in no_map:
        # Single-exclusion experiment: solid, own colour.
        key, color = no_map[sc]
        line_kw = dict(color=color, linestyle="-",
                       linewidth=STYLE["lw_ref"], zorder=2)
    else:
        # Random-removal trial: thin line coloured by its percentage group.
        group = next(
            ((lab, col) for tag, (lab, col) in pct_map.items() if sc.startswith(tag)),
            None,
        )
        if group is None:
            continue  # key belongs to no known group
        key, color = group
        line_kw = dict(color=color, linewidth=STYLE["lw_group"], zorder=1)

    line, = ax1.plot(years_full_ohct, series, alpha=STYLE["alpha_line"], **line_kw)
    legend_handles.setdefault(key, line)

# Light-grey rounded box behind the legend.  It is positioned in axes
# coordinates and sits above the data lines (zorder 1-3); the legend itself
# is raised above the box further down.
ax1.add_patch(
    FancyBboxPatch(
        (0.67, 0.032), 0.31, 0.34,
        boxstyle="round,pad=0.01,rounding_size=0.02",
        transform=ax1.transAxes,
        facecolor="#F5F5F5",
        edgecolor="none",
        alpha=1,
        zorder=10,
    )
)

# One legend entry per group that was actually drawn, in the fixed order.
handles = [legend_handles[k] for k in legend_order if k in legend_handles]
labels  = [k for k in legend_order if k in legend_handles]

# `prop` alone determines the legend font: Matplotlib only uses `fontsize`
# when `prop` is not given, so the `fontsize=6` that used to be passed here
# never took effect.  It is omitted so the call states what is rendered.
leg = ax1.legend(
    handles, labels,
    frameon=False,
    prop=arial,
    loc="lower right",
    bbox_to_anchor=(1.008, -0.026),
    ncol=2,
    columnspacing=1.1,
    handletextpad=0.3,
)
leg.set_zorder(100)

# Legend sample lines get their own widths, independent of the plotted lines:
# random-removal groups 1.2, single-exclusion experiments 2.0, anything else
# (the reference) 2.5.
for lh, txt in zip(leg.get_lines(), leg.get_texts()):
    lab = txt.get_text()
    if lab in ["No-20%", "No-40%", "No-60%", "No-80%"]:
        lh.set_linewidth(1.2)
    elif lab in ["No-US", "No-AU", "No-JP", "No-FR", "No-DE", "No-EU"]:
        lh.set_linewidth(2.0)
    else:
        lh.set_linewidth(2.5)

# Panel (a) title and y-axis label.
ax1.set_title("(a) Ocean heating rate (annual)", fontproperties=arial, fontsize=STYLE["title_fs"], pad=6)
ax1.set_ylabel("W m$^{-2}$", fontproperties=arial, fontsize=STYLE["label_fs"])

# X range: from the first plotted year to one month past the last one.
ax1.set_xlim(years_full_ohct[0], years_full_ohct[-1] + np.timedelta64(1, "M"))
# Major tick every second year, labelled with the four-digit year.
ax1.xaxis.set_major_locator(mdates.YearLocator(2))
ax1.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
# Inward major ticks only; no minor ticks.
ax1.tick_params(axis="both", which="major", direction="in")
ax1.minorticks_off()
set_tick_font(ax1, arial, STYLE["tick_fs"])
# Dashed zero reference line, drawn beneath the data (zorder 0).
ax1.axhline(0, color="k", lw=1.0, ls="--", alpha=0.7, zorder=0)

# =============================================================================
# (b) RMSE / std(Ref) (%)
# =============================================================================
# Left axis (ax2): RMSE as a percentage of std(Ref).  Right axis (ax2r): the
# same points in absolute units, plotted fully transparent.
ax2r = ax2.twinx()

# Random realizations: one plot call per trial; only the first call that
# draws anything carries the legend label.
first_trial = True
for tr in trials:
    levels_present = [pct for pct in pct_levels if tr in rmse_by[pct]]
    if not levels_present:
        continue

    xs = [pct + offsets[tr] for pct in levels_present]
    ys = np.asarray([rmse_by[pct][tr] for pct in levels_present], dtype=float)

    label = "Random realizations (n = 5)" if first_trial else "_nolegend_"
    first_trial = False

    marker_color = trial_colors[tr]
    ax2.plot(
        xs, (ys / ref_std) * 100.0,
        "o", color=marker_color,
        markersize=STYLE["ms"],
        markeredgewidth=0,
        alpha=STYLE["alpha_marker"],
        label=label
    )
    ax2r.plot(xs, ys, "o", color=marker_color, alpha=0.0, label="_nolegend_")

# Single-exclusion experiments: one labelled point each at its fixed x
# position; experiments without a finite RMSE are skipped.
for k, (lab, col) in no_map.items():
    y_abs = rmse.get(k, np.nan)
    if not np.isfinite(y_abs):
        continue
    y_abs = float(y_abs)
    x_here = no_xpos[k]
    ax2.plot(
        x_here, (y_abs / ref_std) * 100.0,
        "o", color=col,
        markersize=STYLE["ms"],
        alpha=STYLE["alpha_marker"],
        markeredgewidth=0,
        label=lab
    )
    ax2r.plot(x_here, y_abs, "o", color=col, alpha=0.0, label="_nolegend_")

# Fixed x ticks at the removal percentages and y ticks every 20 % (20..160).
ax2.set_xticks([0, 20, 40, 60, 80])
ax2.set_xticklabels(["0", "20", "40", "60", "80"])
ax2.set_yticks(np.arange(20, 161, 20))
ax2.set_yticklabels([f"{x}%" for x in np.arange(20, 161, 20)])

ax2.set_title("(b) Annual heating rate error", fontproperties=arial, fontsize=STYLE["title_fs"], pad=6)
ax2.tick_params(axis="both", which="major", direction="in")
ax2.minorticks_off()
ax2.set_ylabel("RMSE/STD(REF)", fontproperties=arial, fontsize=10, labelpad=1)
set_tick_font(ax2, arial, STYLE["tick_fs"])

# Align the right axis with the left one: a left-axis value p (percent of
# std(Ref)) corresponds to p * ref_std / 100 in absolute units, so the left
# limits are converted with that factor.  Must run after the left axis
# limits are final.
yl0, yl1 = ax2.get_ylim()
scale_b = ref_std / 100.0
ax2r.set_ylim(yl0 * scale_b, yl1 * scale_b)

# Hand-picked right-axis ticks; keep only those inside the current limits.
rmse_ticks = [0.07, 0.14, 0.21, 0.28, 0.35, 0.42, 0.49, 0.56]
rmse_ticks = filter_ticks_to_ylim(rmse_ticks, ax2r.get_ylim())
ax2r.set_yticks(rmse_ticks)

def fmt_rmse(v, pos):
    """Format a tick value with at most two decimals.

    Args:
        v: Tick value.
        pos: Tick position passed by the formatter machinery; unused.

    Returns:
        ``v`` formatted to two decimals with trailing zeros and a dangling
        decimal point removed; a result of ``"-0"`` is returned as ``"0"``.
    """
    text = f"{v:.2f}"
    text = text.rstrip("0")
    text = text.rstrip(".")
    if text == "-0":
        return "0"
    return text

# Right axis of panel (b): compact tick labels via fmt_rmse, absolute-unit
# label, inward major ticks, no minor ticks, and the shared tick font.
ax2r.yaxis.set_major_formatter(mticker.FuncFormatter(fmt_rmse))
ax2r.set_ylabel("RMSE (W m$^{-2}$)", fontproperties=arial, fontsize=10, labelpad=1)
ax2r.tick_params(axis="y", which="major", direction="in")
ax2r.minorticks_off()
set_tick_font(ax2r, arial, STYLE["tick_fs"])

# =============================================================================
# (c) Trend error / Ref trend (%)
# =============================================================================
# Left axis (ax3): |trend error / Ref trend| in percent.  Right axis (ax3r):
# the trend error times 10 (per decade), plotted fully transparent.
ax3r = ax3.twinx()

# Random realizations: one plot call per trial; only the first call that
# draws anything carries the legend label.
first_trial = True
for tr in trials:
    levels_present = [pct for pct in pct_levels if tr in trend_by[pct]]
    if not levels_present:
        continue

    xs = [pct + offsets[tr] for pct in levels_present]
    ys = np.array([trend_by[pct][tr] for pct in levels_present])

    label = "Random realizations (n = 5)" if first_trial else "_nolegend_"
    first_trial = False

    marker_color = trial_colors[tr]
    ax3.plot(
        xs, abs((ys / trend_ref) * 100.0),
        "o", color=marker_color,
        markersize=STYLE["ms"],
        alpha=STYLE["alpha_marker"],
        markeredgewidth=0,
        label=label
    )
    ax3r.plot(xs, ys * 10, "o", color=marker_color, alpha=0.0, label="_nolegend_")

# Single-exclusion experiments: one labelled point each at its fixed x
# position; experiments without a finite relative trend error are skipped.
for k, (lab, col) in no_map.items():
    if not (k in trend_rel_err and np.isfinite(trend_rel_err[k])):
        continue
    ys = np.array(trend_err[k])
    x_here = no_xpos[k]
    ax3.plot(
        x_here, abs((ys / trend_ref) * 100.0),
        "o", color=col,
        markersize=STYLE["ms"],
        alpha=STYLE["alpha_marker"],
        markeredgewidth=0,
        label=lab
    )
    ax3r.plot(x_here, ys * 10, "o", color=col, alpha=0.0, label="_nolegend_")

# Dashed zero reference line, drawn beneath the data (zorder 0).
ax3.axhline(0, color="k", lw=1.0, ls="--", alpha=0.7, zorder=0)

# Fixed x ticks at the removal percentages and y ticks every 10 % (0..40).
ax3.set_xticks([0, 20, 40, 60, 80])
ax3.set_xticklabels(["0", "20", "40", "60", "80"])
ax3.set_yticks([0, 10, 20, 30, 40])
ax3.set_yticklabels(["0", "10%", "20%", "30%", "40%"])

ax3.set_title("(c) Warming acceleration error", fontproperties=arial, fontsize=STYLE["title_fs"], pad=6)
ax3.set_ylabel(r"$\vert\mathrm{DIFF}/\mathrm{TREND(REF)}\vert$", fontproperties=arial, fontsize=10, labelpad=1)
ax3.tick_params(axis="both", which="major", direction="in")
ax3.minorticks_off()
set_tick_font(ax3, arial, STYLE["tick_fs"])

# Align the right axis with the left one: a left-axis value p (percent of
# |Ref trend|) corresponds to p * |trend_ref| / 100 per year, times 10 for
# the per-decade unit shown on the right.  Must run after the left axis
# limits are final.
yl0, yl1 = ax3.get_ylim()
scale_c = abs(trend_ref) / 100.0 * 10.0
ax3r.set_ylim(yl0 * scale_c, yl1 * scale_c)

# Hand-picked right-axis ticks; keep only those inside the current limits.
trend_ticks = [0.0, 0.05, 0.1, 0.15, 0.2]
trend_ticks = filter_ticks_to_ylim(trend_ticks, ax3r.get_ylim())
ax3r.set_yticks(trend_ticks)

def fmt_trend(v, pos):
    """Format a tick value with at most two decimals.

    Args:
        v: Tick value.
        pos: Tick position passed by the formatter machinery; unused.

    Returns:
        ``v`` formatted to two decimals with trailing zeros and a dangling
        decimal point removed.  Any value that rounds to zero at two
        decimals is returned as ``"0"``, whatever its sign.
    """
    text = f"{v:.2f}".rstrip("0").rstrip(".")
    # A small negative value formats as "-0.00" and strips down to "-0".
    # Checking the formatted text (rather than testing |v| against a fixed
    # threshold) catches every value that rounds to zero.
    return "0" if text in ("0", "-0") else text

# Right axis of panel (c): compact tick labels via fmt_trend, per-decade
# unit label, inward major ticks, no minor ticks, and the shared tick font.
ax3r.yaxis.set_major_formatter(mticker.FuncFormatter(fmt_trend))
ax3r.set_ylabel(r"$\vert\mathrm{DIFF}\vert$ (W m$^{-2}$ dec$^{-1}$)", fontproperties=arial, fontsize=10, labelpad=1)
ax3r.tick_params(axis="y", which="major", direction="in")
ax3r.minorticks_off()
set_tick_font(ax3r, arial, STYLE["tick_fs"])

# =============================================================================
# 8. Shared labels and legends (panel b -> panel c)
# =============================================================================
# One x-axis label shared by the two bottom panels.
fig.supxlabel("Percentage of data removed (%)", fontproperties=arial, fontsize=11, y=0.072)

# Two light-grey rounded boxes on panel (c) that together form the legend
# background: a wide upper box and a narrow box directly beneath it.  Both
# use axes coordinates and zorder 0, i.e. they are drawn before the legends.
for box_xy, box_width, box_height in (
    ((0.035, 0.66 + 0.09), 0.71, 0.23),
    ((0.035, 0.66), 0.25, 0.09),
):
    ax3.add_patch(
        FancyBboxPatch(
            box_xy, box_width, box_height,
            boxstyle="round,pad=0.01,rounding_size=0.02",
            transform=ax3.transAxes,
            facecolor="#F5F5F5",
            edgecolor="none",
            alpha=1,
            zorder=0
        )
    )

# The legend entries come from panel (b) but are displayed inside panel (c).
handles2, labels2 = ax2.get_legend_handles_labels()

# Keep the first handle seen for each label, preserving order.
first_handle_by_label = {}
for h, l in zip(handles2, labels2):
    first_handle_by_label.setdefault(l, h)
labels_u = list(first_handle_by_label)
handles_u = list(first_handle_by_label.values())

# Keyword arguments common to both legends.
shared_legend_kw = dict(
    frameon=False,
    loc="upper left",
    prop=arial,
    fontsize=STYLE["legend_fs"],
    handletextpad=-0.5,
    borderaxespad=0.0,
)


def _restyle_legend(legend):
    """Re-apply the font and size to each legend text and shrink its markers."""
    for text in legend.get_texts():
        text.set_fontproperties(arial)
        text.set_fontsize(STYLE["legend_fs"])
    for handle in legend.legend_handles:
        handle.set_markersize(6)


# First legend: only the first entry (the random-realization marker).
leg_main = ax3.legend(
    handles_u[:1], labels_u[:1],
    bbox_to_anchor=(-0.02, 1.00),
    **shared_legend_kw,
)
_restyle_legend(leg_main)
# Re-attach it as a plain artist so the next ax3.legend() call does not
# replace it.
ax3.add_artist(leg_main)

# Second legend: the remaining entries in two columns, placed just below.
leg_sub = ax3.legend(
    handles_u[1:], labels_u[1:],
    bbox_to_anchor=(-0.02, 0.92),
    ncol=2,
    columnspacing=0.4,
    **shared_legend_kw,
)
_restyle_legend(leg_sub)

plt.show()
