# -*- coding: utf-8 -*-

import sys
from collections import defaultdict
import numpy as np
import xarray as xr
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
from matplotlib.font_manager import FontProperties

import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.util import add_cyclic_point
# =============================================================================
# 0. Fonts (Windows)
# =============================================================================
# Both font files are hard-coded to the Windows system font directory.
arial, simsun = (
    FontProperties(fname=font_file)
    for font_file in (r"C:\Windows\Fonts\arial.ttf", r"C:\Windows\Fonts\simsun.ttc")
)

# =============================================================================
# 1. Custom colormap utility
# =============================================================================
# colormap_utils lives in .\Function, resolved relative to the current working
# directory; the path has to be registered before the import below runs.
sys.path.append(r".\Function")
import colormap_utils  # noqa: E402

# =============================================================================
# 2. Load data
# =============================================================================
npz_path = r"./OHC_total_0to2000_200501_202312.npz"
# allow_pickle is needed because OHC_total is stored as an object array that
# is unwrapped with .item(). Unpickling can execute arbitrary code, so only
# load .npz files produced by a trusted source.
with np.load(npz_path, allow_pickle=True) as data:
    # dict-like: OHC_total['Ref'], OHC_total['Pct80_Trial1'], ...
    OHC_total = data["OHC_total"].item()

scenario_keys_exp = [
    "Pct80_Trial1", "Pct80_Trial2", "Pct80_Trial3", "Pct80_Trial4", "Pct80_Trial5",
    "Pct60_Trial1", "Pct60_Trial2", "Pct60_Trial3", "Pct60_Trial4", "Pct60_Trial5",
    "Pct40_Trial1", "Pct40_Trial2", "Pct40_Trial3", "Pct40_Trial4", "Pct40_Trial5",
    "Pct20_Trial1", "Pct20_Trial2", "Pct20_Trial3", "Pct20_Trial4", "Pct20_Trial5",
    "NoUS", "NoAU", "NoJA", "NoFR", "NoGE",
]

# Grid info (lat/lon/depth/area; plotting uses lat/lon only).
# open_dataset is lazy and keeps the netCDF file open (locked on Windows)
# until closed; the context manager releases it. Every variable is
# materialised with .values before the file is closed, so nothing below
# depends on the handle.
with xr.open_dataset(r".\area_180x360x119.nc") as ds:
    lat = ds["lat"].values
    lon = ds["lon"].values
    depth = ds["depth"].values
    area_xyz = ds["area_xyz"].values

# =============================================================================
# 3. Annual RMSE & nRMSE (grid-point wise)
#    - Convert monthly time series to annual means first
#    - RMSE = sqrt(mean((Exp_annual - Ref_annual)^2))
#    - nRMSE = RMSE / std(Ref_annual)
# =============================================================================
def _to_annual(monthly):
    """Collapse the trailing monthly time axis into annual means.

    Consecutive 12-month blocks are treated as years (the data file name
    suggests Jan 2005 - Dec 2023; confirm the series starts in January).
    Assumes plain ndarrays with NaN marking missing cells.

    Parameters
    ----------
    monthly : array-like, shape (..., time)

    Returns
    -------
    ndarray, shape (..., time // 12), float64

    Raises
    ------
    ValueError
        If the time axis is not a whole number of years.
    """
    monthly = np.asarray(monthly, dtype=np.float64)
    n_time = monthly.shape[-1]
    if n_time % 12:
        raise ValueError(
            f"time axis has {n_time} months, not a whole number of years"
        )
    return monthly.reshape(monthly.shape[:-1] + (n_time // 12, 12)).mean(axis=-1)


# label -> (lat, lon) map. Plain dicts so a mistyped key raises KeyError
# instead of silently returning an empty container.
ohc_annual_RMSE = {}
ohc_annual_nRMSE = {}

OHC_3D_Ref = OHC_total["Ref"]  # expected shape: (lat, lon, time)

# The reference statistics do not depend on the scenario, so compute them once.
ref_annual = _to_annual(OHC_3D_Ref)            # (lat, lon, year)
ref_sigma = np.nanstd(ref_annual, axis=-1)     # (lat, lon)
# nRMSE is only defined where the reference std is finite and non-zero;
# everywhere else it stays NaN.
sigma_ok = np.isfinite(ref_sigma) & (ref_sigma != 0)

for label in scenario_keys_exp:
    exp_annual = _to_annual(OHC_total[label])  # (lat, lon, year)

    # Vectorised over the whole grid instead of a per-cell Python loop.
    delta = exp_annual - ref_annual
    rmse_2d = np.sqrt(np.mean(delta ** 2, axis=-1))

    nrmse_2d = np.full(rmse_2d.shape, np.nan, dtype=np.float64)
    np.divide(rmse_2d, ref_sigma, out=nrmse_2d, where=sigma_ok)

    ohc_annual_RMSE[label] = rmse_2d
    ohc_annual_nRMSE[label] = nrmse_2d

# =============================================================================
# 4. Pattern correlation among trials (used to assess similarity across 5 trials)
# =============================================================================
def pattern_corr(a, b):
    """Centred (Pearson) pattern correlation between two maps.

    Cells where either map is non-finite (NaN/inf) are dropped before the
    anomalies are formed. The result is NaN when no cell survives the mask or
    when either masked map has zero variance (0/0).
    """
    keep = np.isfinite(a) & np.isfinite(b)
    x_anom = a[keep].ravel()
    y_anom = b[keep].ravel()
    x_anom = x_anom - x_anom.mean()
    y_anom = y_anom - y_anom.mean()

    covariance = x_anom @ y_anom
    scale = np.sqrt((x_anom @ x_anom) * (y_anom @ y_anom))
    return covariance / scale

def pairwise_corr(rmse_5):
    """Return the (n, n) pattern-correlation matrix of a stack of maps.

    ``rmse_5`` has shape (n, lat, lon). Entry [i, j] is
    ``pattern_corr(rmse_5[i], rmse_5[j])``; every ordered pair is evaluated,
    including the diagonal (self-correlation).
    """
    n_maps = rmse_5.shape[0]
    corr = np.full((n_maps, n_maps), np.nan)
    for row, col in np.ndindex(n_maps, n_maps):
        corr[row, col] = pattern_corr(rmse_5[row], rmse_5[col])
    return corr

pct_list = ["Pct20", "Pct40", "Pct60", "Pct80"]
ohc_annual_rmse_mean = {}  # pct -> trial-mean RMSE map, (lat, lon)
Corr = {}                  # pct -> (n_trial, n_trial) pattern correlations, NaN diagonal

for pct in pct_list:
    # Match on "<pct>_" so a prefix such as "Pct20" can never pick up an
    # unrelated key; iterate scenario_keys_exp to keep the trial order explicit.
    trial_keys = [k for k in scenario_keys_exp if k.startswith(pct + "_")]
    if not trial_keys:
        raise ValueError(f"no trials found for {pct!r}")

    pct_data = np.array([ohc_annual_RMSE[k] for k in trial_keys])  # (n_trial, lat, lon)
    mean_map = np.nanmean(pct_data, axis=0)                       # (lat, lon)

    C = pairwise_corr(pct_data)
    # Mask self-correlations by position, not by value. Testing `C == 1`
    # relies on exact floating-point equality, and it would also blank any
    # off-diagonal pair that happens to correlate perfectly (e.g. two
    # identical trial maps).
    np.fill_diagonal(C, np.nan)

    ohc_annual_rmse_mean[pct] = mean_map
    Corr[pct] = C

# =============================================================================
# 5. Plotting
# =============================================================================
mpl.rcParams.update({
    "font.family": "Arial",
    "mathtext.fontset": "dejavusans",
    "axes.unicode_minus": False,
})

fig_title = ["(a) No-20%", "(b) No-40%", "(c) No-60%", "(d) No-80%"]

fig = plt.figure(figsize=(7, 4.5), dpi=300)
proj = ccrs.Robinson(central_longitude=200)

# 3 x 3 map panels (each spanning 2 grid rows x 4 grid columns) above a thin
# colorbar row.
gs = fig.add_gridspec(
    nrows=7, ncols=12,
    height_ratios=[1] * 6 + [0.2],
    hspace=0.13, wspace=0.05,
)

# Created row-major: ax1..ax3 top row, ax4..ax6 middle, ax7..ax9 bottom.
axes_map = [
    fig.add_subplot(gs[row:row + 2, col:col + 4], projection=proj)
    for row in (0, 2, 4)
    for col in (0, 4, 8)
]
ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8, ax9 = axes_map
ax_cbar = fig.add_subplot(gs[6, 1:9])

# Hide the Robinson outline ("geo" spine) on every map panel that has one.
for ax in axes_map:
    if "geo" not in ax.spines:
        continue
    ax.spines["geo"].set_visible(False)

# Contour levels in J: 0 to 0.4e9 in steps of 0.04e9; the colorbar ticks are
# labelled in units of 1e9 J further down.
base_levels = np.arange(0, 2.1, 0.2) * 1e9 / 5
cmap, norm, extend, refined_levels = colormap_utils.get_cmap_norm(
    base_levels,
    cmap_name="RdYlBu_r",
    zero_centered_colors=True,
    extend_option="max",
    colors_per_interval=2,
)

def setup_map(ax):
    """Give ``ax`` a global extent and a light-grey land layer (zorder 3)."""
    land_style = {"zorder": 3, "facecolor": "lightgray"}
    ax.set_global()
    ax.add_feature(cfeature.LAND, **land_style)

def _plot_rmse_panel(ax, field, title):
    """Draw one RMSE map on a cartopy GeoAxes and return its filled-contour set.

    Parameters
    ----------
    ax : GeoAxes to draw on (styled via ``setup_map``).
    field : 2-D (lat, lon) map on the module-level ``lat``/``lon`` grid.
    title : panel title.

    A cyclic longitude column is appended so the contours close across the
    longitude seam. The shared module-level levels/cmap/norm are used, so all
    panels agree with the single colorbar.
    """
    setup_map(ax)
    field_cyc, lon_cyc = add_cyclic_point(field, coord=lon)
    contour_set = ax.contourf(
        lon_cyc,
        lat,
        field_cyc,
        levels=refined_levels,
        cmap=cmap,
        norm=norm,
        extend=extend,
        transform=ccrs.PlateCarree(),
        alpha=0.9,
    )
    ax.set_title(title, fontsize=8, pad=2.4)
    return contour_set


# Panel order follows axes_map (row-major):
#   (a)-(d): trial-mean RMSE for Pct20/40/60/80
#   (e)-(i): single country-removal scenarios
# Some titles deliberately differ from the data keys (NoJA -> "JP",
# NoGE -> "DE"), presumably ISO 3166 country codes.
panels = [
    (ohc_annual_rmse_mean[pct], title) for pct, title in zip(pct_list, fig_title)
] + [
    (ohc_annual_RMSE["NoUS"], "(e) No-US"),
    (ohc_annual_RMSE["NoAU"], "(f) No-AU"),
    (ohc_annual_RMSE["NoJA"], "(g) No-JP"),
    (ohc_annual_RMSE["NoFR"], "(h) No-FR"),
    (ohc_annual_RMSE["NoGE"], "(i) No-DE"),
]
if len(panels) != len(axes_map):
    raise ValueError(f"{len(panels)} panels defined for {len(axes_map)} map axes")

cf_for_cbar = None
for ax, (field, title) in zip(axes_map, panels):
    # All panels share levels/norm, so the last contour set can drive the colorbar.
    cf_for_cbar = _plot_rmse_panel(ax, field, title)

# Colorbar shared by all panels, built from the last contour set drawn.
cbar = fig.colorbar(cf_for_cbar, cax=ax_cbar, orientation="horizontal")
ticks = base_levels[::2]  # label every other contour level
cbar.set_ticks(ticks)


def _giga_label(value, pos):
    """Format a tick value given in J as a multiple of 1e9, two decimals."""
    return f"{value / 1e9:.2f}"


cbar_xaxis = cbar.ax.xaxis
cbar_xaxis.set_major_formatter(FuncFormatter(_giga_label))
cbar.ax.tick_params(axis="x", direction="in", length=4, width=0.8, labelsize=8, pad=2)
cbar.ax.minorticks_off()

# Units label placed just to the right of the colorbar (axes coordinates).
unit_label_style = dict(
    transform=cbar.ax.transAxes,
    ha="left",
    va="center",
    fontproperties=arial,
    fontsize=8,
)
cbar.ax.text(1.06, 0.5, r"RMSE ($\times 10^9$ J)", **unit_label_style)

plt.show()
