# -*- coding: utf-8 -*-

import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
from matplotlib.lines import Line2D
import cartopy.crs as ccrs
import cartopy.feature as cfeature
# Fonts are taken from the Windows font directory. If a file is missing (e.g.
# on Linux/macOS) fall back to a generic family instead of crashing at draw time.
def _font_or_fallback(path, **fallback):
    """Return ``FontProperties`` for the font file *path* if it exists.

    Otherwise return ``FontProperties(**fallback)``, e.g. ``family="sans-serif"``.
    ``FontProperties(fname=...)`` does not validate the path itself; a missing
    file would only raise once text using it is rendered.
    """
    if os.path.isfile(path):
        return FontProperties(fname=path)
    return FontProperties(**fallback)


arial = _font_or_fallback(r"C:\Windows\Fonts\arial.ttf", family="sans-serif")
# NOTE(review): ARLRDBD.TTF is presumably "Arial Rounded MT Bold" on standard
# Windows installs, not regular Arial Bold — confirm this face is intended.
arial_bold = _font_or_fallback(
    r"C:\Windows\Fonts\ARLRDBD.TTF", family="sans-serif", weight="bold"
)
# Not referenced by the plotting code below; kept for callers/experiments.
simsun = _font_or_fallback(r"C:\Windows\Fonts\simsun.ttc", family="serif")


# =============================================================================
# 1) Input data
# =============================================================================
# Input files live in the current working directory. Plain relative names
# (instead of r".\...") resolve the same on Windows and also work on POSIX.
# NOTE: despite its name, ``save_path`` is the NPZ that is *read* below.
save_path = "WOD_profiles_200501_202312.npz"
df_all = pd.read_parquet("WOD_profiles_200501_202312.parquet")


def _concat_plain(chunks):
    """Concatenate a sequence of 1-D arrays into one plain ``np.ndarray``.

    The chunks may be masked arrays (presumably from netCDF reads — confirm) or
    plain ndarrays. ``np.ma.getdata`` returns the raw underlying values for
    masked input (masks are not applied, same as the previous ``.data``) and
    the array itself for plain input, where ``ndarray.data`` would instead
    return a ``memoryview`` that breaks the comparisons/indexing used later.
    """
    return np.ma.getdata(np.concatenate(chunks))


# allow_pickle is required because the per-chunk arrays are stored as an
# object array; only load NPZ files you produced yourself / trust.
with np.load(save_path, allow_pickle=True) as data:
    lats = _concat_plain(data["lats"])
    lons = _concat_plain(data["lons"])
    datasets = _concat_plain(data["datasets"])
# =============================================================================
# 2) Country grouping: top-5 + OTHER
# =============================================================================
# The five countries shown individually; every other value (including
# missing countries) is lumped into "OTHER".
top_countries = ["UNITED STATES", "AUSTRALIA", "JAPAN", "FRANCE", "GERMANY"]
df_all["country_top"] = df_all["country"].mask(
    ~df_all["country"].isin(top_countries), "OTHER"
)

# One freshly indexed sub-frame per group, keyed by group name; groups with
# no rows still get an (empty) entry.
df_by_country = {}
for group in [*top_countries, "OTHER"]:
    in_group = df_all["country_top"] == group
    df_by_country[group] = df_all.loc[in_group].reset_index(drop=True)

# =============================================================================
# 3) Dataset metadata (names, draw order, colors, marker sizes)
# =============================================================================
# Dataset code -> short name. Codes 1..14 are consecutive; -999 is "Unknown".
DATASET_NAME = dict(
    enumerate(
        [
            "OSD", "CTD", "MBT", "XBT", "SUR", "APB", "MRB",
            "PFL", "DRB", "UOR", "GLD", "DBT", "STD", "microBT",
        ],
        start=1,
    )
)
DATASET_NAME[-999] = "Unknown"

# Sequence in which dataset layers are scattered. All layers use the same
# zorder, so codes later in this list are painted on top of earlier ones.
DRAW_ORDER = np.asarray(
    (8, 6, 4, 2, 3, 5, 9, 1, 10, 11, 12, 13, 14, 7, -999),
    dtype=np.int32,
)

# Per-code scatter style: (code, color, marker size passed as ``s``).
# Row order fixes the key order of both dictionaries built from it.
_STYLE_ROWS = (
    (8, "#6B6D70", 0.05),
    (6, "#3BB6DB", 0.15),
    (4, "#B11810", 0.15),
    (11, "royalblue", 0.10),
    (2, "darkorange", 0.15),
    (7, "#00FF00", 0.7),
    (3, "#8c564b", 0.15),
    (5, "#d62728", 0.15),
    (9, "gold", 0.10),
    (10, "purple", 0.15),
    (1, "lightpink", 0.15),
    (12, "#ffbb78", 0.15),
    (13, "#98df8a", 0.15),
    (14, "#ff9896", 0.15),
    (-999, "#c5b0d5", 0.15),
)
DRAW_COLORS = {code: color for code, color, _size in _STYLE_ROWS}
DRAW_SIZES = {code: size for code, _color, size in _STYLE_ROWS}
del _STYLE_ROWS

# =============================================================================
# 4) Figure & map layout (6 map panels + 1 legend panel)
# =============================================================================
fig = plt.figure(figsize=(7, 6.5), dpi=500)
proj = ccrs.Robinson(central_longitude=200)

# Grid: three rows of two map panels, then a thin full-width legend strip.
gs = fig.add_gridspec(
    4,
    2,
    height_ratios=[1, 1, 1, 0.05],
    hspace=0.13,
    wspace=0.02,
)

# Map axes are created row by row, left to right (same order as named below).
geo_axes = [
    fig.add_subplot(gs[row, col], projection=proj)
    for row in range(3)
    for col in range(2)
]
ax_0_0, ax_0_1, ax_1_0, ax_1_1, ax_2_0, ax_2_1 = geo_axes
ax_3 = fig.add_subplot(gs[3, :])

for gax in geo_axes:
    # Hide the map outline where cartopy exposes it as a "geo" spine.
    if "geo" in gax.spines:
        gax.spines["geo"].set_visible(False)
    gax.set_global()
    # Land on zorder 3 masks any scatter points (zorder 2) lying over land.
    gax.add_feature(cfeature.LAND, zorder=3, facecolor="lightgrey")

def _scatter_by_dataset(ax, lon, lat, codes):
    """Scatter lon/lat points on *ax*, one layer per dataset code.

    Layers are drawn in ``DRAW_ORDER`` with the per-code ``DRAW_SIZES`` /
    ``DRAW_COLORS`` style, so later codes are painted on top (all layers share
    zorder 2). Codes not present in *codes* are skipped; codes absent from
    ``DRAW_ORDER`` are never drawn.

    Parameters
    ----------
    ax : cartopy GeoAxes to draw on.
    lon, lat : array-like of longitudes/latitudes in degrees (plotted with a
        PlateCarree transform).
    codes : array-like of dataset codes, aligned element-wise with lon/lat.
        Assumes plain integer codes (no pandas NA) — TODO confirm for the
        parquet ``dataset`` column.
    """
    lon = np.asarray(lon)
    lat = np.asarray(lat)
    codes = np.asarray(codes)
    data_crs = ccrs.PlateCarree()  # one transform instance shared by all layers
    for ds in DRAW_ORDER.tolist():
        sel = codes == ds
        if not sel.any():
            continue
        ax.scatter(
            lon[sel],
            lat[sel],
            s=DRAW_SIZES[ds],
            c=DRAW_COLORS[ds],
            alpha=0.9,
            linewidths=0,
            transform=data_crs,
            zorder=2,
        )


# Panel (a): every profile, regardless of country.
_scatter_by_dataset(ax_0_0, lons, lats, datasets)
ax_0_0.set_title("(a) All countries", fontproperties=arial, fontsize=10, pad=1)

# Panels (b)-(f): (axis, key in df_by_country, title). The ranks in the titles
# are hard-coded — keep them in sync with the country list used for grouping.
_COUNTRY_PANELS = (
    (ax_0_1, "UNITED STATES", "(b) United States (US) - 1st"),
    (ax_1_0, "AUSTRALIA", "(c) Australia (AU) - 2nd"),
    (ax_1_1, "JAPAN", "(d) Japan (JP) - 3rd"),
    (ax_2_0, "FRANCE", "(e) France (FR) - 4th"),
    (ax_2_1, "GERMANY", "(f) Germany (DE) - 5th"),
)
for panel_ax, cname, clab in _COUNTRY_PANELS:
    df_c = df_by_country[cname]
    _scatter_by_dataset(panel_ax, df_c["lon"], df_c["lat"], df_c["dataset"])
    panel_ax.set_title(clab, fontproperties=arial, fontsize=10, pad=1)

# Legend lists only the codes that actually occur in the data, in draw order.
existing_ds = np.unique(datasets.astype(int))
draw_ds = [code for code in DRAW_ORDER.tolist() if code in existing_ds]

legend_labels, legend_colors = [], []
for code in draw_ds:
    legend_labels.append(DATASET_NAME.get(code, f"Code{code}"))
    legend_colors.append(DRAW_COLORS.get(code, "#CCCCCC"))
# Handles with no line (and, by default, no marker): combined with
# handlelength=0, only the coloured label text is visible.
legend_handles = [Line2D([], [], linestyle="None") for _code in draw_ds]

ax_3.axis("off")
leg = ax_3.legend(
    handles=legend_handles,
    labels=legend_labels,
    loc="center",
    ncol=9,
    frameon=False,
    # columnspacing / handletextpad are measured in units of this font size;
    # the label text itself is enlarged to 10 pt afterwards.
    fontsize=7,
    columnspacing=1.5,
    handlelength=0,
    handletextpad=0.2,
    labelcolor=legend_colors,
)

# Restyle every label. Order matters: set_fontproperties replaces the whole
# font spec, then weight and size are applied on top of it.
for label_text in leg.get_texts():
    label_text.set_fontproperties(arial_bold)
    label_text.set_fontweight("bold")
    label_text.set_fontsize(10)

plt.show()
