Global Map (ENSO)

# Global Map (ENSO)

# --- Packages ---
## General Packages
import pandas as pd
import xarray as xr
import numpy as np
import os
import ipynbname

## GeoCAT
import geocat.comp as gccomp
import geocat.viz as gv
import geocat.viz.util as gvutil

## Visualization
import cmaps  
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import shapely.geometry as sgeom

## MatPlotLib
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import matplotlib.patches as mpatches
import matplotlib.dates as mdates

from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib.ticker import MultipleLocator

## Unique Plots
import matplotlib.gridspec as gridspec
# --- Automated output filename ---
def get_filename():
    """Return the base name (without extension) used for the saved figure.

    Resolution order:
      1. the name of the running script, when ``__file__`` is defined;
      2. the notebook name reported by ``ipynbname``;
      3. the literal ``"template_analysis_clim"`` when both lookups fail.
    """
    try:
        # Plain scripts define __file__; in notebooks this raises NameError.
        return os.path.splitext(os.path.basename(__file__))[0]
    except NameError:
        pass

    try:
        # Imported lazily: only needed when no script path is available.
        import ipynbname
        notebook_path = str(ipynbname.path())
        return os.path.splitext(os.path.basename(notebook_path))[0]
    except Exception:
        # Jupyter Book builds or other headless runs cannot resolve the notebook.
        return "template_analysis_clim"

# The figure is saved as "<base name>.png".
fnFIG = f"{get_filename()}.png"
print("Figure filename: " + fnFIG)
Figure filename: template_analysis_clim.png
# --- Parameter setting ---
data_dir = "../data"                  # directory holding the input file
fname = "sst.cobe2.185001-202504.nc"  # NetCDF file to read
fvar = "sst"                          # variable to extract from the file
ystr = 1950                           # first year of the analysis period
yend = 2020                           # last year of the analysis period

# --- Reading NetCDF Dataset ---
path_data = os.path.join(data_dir, fname)
ds = xr.open_dataset(path_data)

# Extract the variable and order its dimensions as (time, lat, lon);
# any of these dimensions that is absent is ignored instead of raising.
var = ds[fvar].transpose("time", "lat", "lon", missing_dims="ignore")

# Re-sort so latitude increases when it is stored in descending order
if var.lat.values[-1] < var.lat.values[0]:
    var = var.sortby("lat")

# Decode the time axis when it is not already datetime64
if not np.issubdtype(var.time.dtype, np.datetime64):
    try:
        var["time"] = xr.decode_cf(ds).time
    except Exception as err:
        raise ValueError(f"Time conversion to datetime64 failed: {err}")

# === Select time range (1 Jan of ystr through 31 Dec of yend) ===
dat = var.sel(time=slice(f"{ystr}-01-01", f"{yend}-12-31"))
print(dat)
<xarray.DataArray 'sst' (time: 852, lat: 180, lon: 360)> Size: 221MB
[55209600 values with dtype=float32]
Coordinates:
  * lat      (lat) float32 720B -89.5 -88.5 -87.5 -86.5 ... 86.5 87.5 88.5 89.5
  * lon      (lon) float32 1kB 0.5 1.5 2.5 3.5 4.5 ... 356.5 357.5 358.5 359.5
  * time     (time) datetime64[ns] 7kB 1950-01-01 1950-02-01 ... 2020-12-01
Attributes:
    long_name:     Monthly Means of Global Sea Surface Temperature
    valid_range:   [-5. 40.]
    units:         degC
    var_desc:      Sea Surface Temperature
    dataset:       COBE-SST2 Sea Surface Temperature
    statistic:     Mean
    parent_stat:   Individual obs
    level_desc:    Surface
    actual_range:  [-3.0000002 34.392    ]
def calc_seasonal_mean(dat, window=5, end_month=1, min_coverage=0.9, dtrend=True):
    """
    Build one seasonal-mean anomaly map per year from monthly data.

    Processing steps:
      1. subtract the monthly climatology (computed over the whole input period);
      2. apply a trailing `window`-month running mean (a value exists only when
         all `window` months are present in the window);
      3. keep the running-mean value at `end_month` of every year, e.g.
         window=5 with end_month=3 gives the Nov-Mar (NDJFM) mean;
      4. relabel the result by calendar year of `end_month`;
      5. drop the first and/or last year when it contains no valid data at all;
      6. mask grid points whose number of valid years is below
         int(n_years * min_coverage);
      7. optionally remove a linear trend along `year`.

    Parameters:
    -----------
    dat : xr.DataArray
        Input data with dimensions (time, lat, lon) and datetime64 'time'.
    window : int
        Running mean window size in months (default is 5).
    end_month : int
        Calendar month (1-12) that closes the trailing window.
    min_coverage : float
        Minimum fraction of years with valid data required at a grid point
        (default 0.9).
    dtrend : bool
        If True, remove the linear trend after applying the coverage mask.

    Returns:
    --------
    dat_out : xr.DataArray
        Seasonal mean anomalies with dimensions (year, lat, lon). The array keeps
        the name of `dat` ("SeasonalMean" when `dat` is unnamed) and its
        attributes, plus `note1`/`note2` describing the processing.
    """

    # Compute monthly anomalies
    clm = dat.groupby("time.month").mean(dim="time")
    anm = dat.groupby("time.month") - clm

    # Apply trailing running mean
    dat_rm = anm.rolling(time=window, center=False, min_periods=window).mean()

    # Filter for entries where month == end_month
    dat_tmp = dat_rm.sel(time=dat_rm["time"].dt.month == end_month)

    # Extract year from the end_month timestamps
    years = dat_tmp["time"].dt.year

    # Create clean DataArray with dimensions ['year', 'lat', 'lon'].
    # A DataArray always has a `name` attribute, so test for None rather than
    # for the attribute's existence to make the fallback effective.
    datS = xr.DataArray(
        data=dat_tmp.values,
        dims=["year", "lat", "lon"],
        coords={
            "year": years.values,
            "lat": dat_tmp["lat"].values,
            "lon": dat_tmp["lon"].values,
        },
        name=dat.name if dat.name is not None else "SeasonalMean",
        attrs=dat.attrs.copy(),
    )

    # Remove edge years if all values are missing. Both edges are evaluated on
    # the same array and dropped in one step, so a single-year input (first and
    # last year identical) no longer fails on the second lookup.
    if datS["year"].size > 0:
        year_values = datS["year"].values
        edge_years = {year_values[0], year_values[-1]}
        empty_edges = [
            yr for yr in edge_years if bool(datS.sel(year=yr).isnull().all())
        ]
        if empty_edges:
            datS = datS.drop_sel(year=empty_edges)

    # Apply coverage mask
    valid_counts = datS.count(dim="year")
    total_years = datS["year"].size
    min_valid_years = int(total_years * min_coverage)
    sufficient_coverage = valid_counts >= min_valid_years
    datS_masked = datS.where(sufficient_coverage)

    dat_out = datS_masked

    # Optional detrending
    if dtrend:
        coeffs = datS_masked.polyfit(dim="year", deg=1)
        trend = xr.polyval(datS_masked["year"], coeffs.polyfit_coefficients)
        dat_out = datS_masked - trend
        # The subtraction discards the variable name; restore it so the result
        # is named the same way with and without detrending.
        dat_out.name = datS.name

    # Update attributes (the subtraction above also discards them)
    dat_out.attrs.update(datS.attrs)
    dat_out.attrs["note1"] = (
        f"{window}-month trailing mean ending in month {end_month}. Year corresponds to {end_month}."
    )
    if dtrend:
        dat_out.attrs["note2"] = f"Linear trend removed after applying {int(min_coverage * 100)}% data coverage mask."
    else:
        dat_out.attrs["note2"] = f"Applied {int(min_coverage * 100)}% data coverage mask only (no detrending)."

    return dat_out


# -- Detrended seasonal-mean anomalies, labelled by the year of the final month.
# A 5-month trailing window covers the end month and the four months before it:
#   end_month=3 -> Nov, Dec, Jan, Feb, Mar (NDJFM)
#   end_month=9 -> May, Jun, Jul, Aug, Sep (MJJAS)
# (end_month=2 would give Oct-Feb, which is not the NDJFM season named here.)
dat_NDJFM = calc_seasonal_mean(dat, window=5, end_month=3, dtrend=True)
dat_MJJAS = calc_seasonal_mean(dat, window=5, end_month=9, dtrend=True)
print(dat_NDJFM)

  
<xarray.DataArray (year: 70, lat: 180, lon: 360)> Size: 36MB
array([[[       nan,        nan,        nan, ...,        nan,
                nan,        nan],
        [       nan,        nan,        nan, ...,        nan,
                nan,        nan],
        [       nan,        nan,        nan, ...,        nan,
                nan,        nan],
        ...,
        [0.02326872, 0.02348154, 0.02298293, ..., 0.02257151,
         0.02365055, 0.02348184],
        [0.02483404, 0.02546713, 0.02575516, ..., 0.02538817,
         0.02564338, 0.02559502],
        [0.02242027, 0.0222255 , 0.02272483, ..., 0.02195308,
         0.02223469, 0.02248361]],

       [[       nan,        nan,        nan, ...,        nan,
                nan,        nan],
        [       nan,        nan,        nan, ...,        nan,
                nan,        nan],
        [       nan,        nan,        nan, ...,        nan,
                nan,        nan],
...
        [0.04113557, 0.04261062, 0.04251236, ..., 0.039855  ,
         0.04042611, 0.04140724],
        [0.05431002, 0.05359076, 0.05437862, ..., 0.05354543,
         0.05313137, 0.05407702],
        [0.04739001, 0.04772313, 0.04759856, ..., 0.04785807,
         0.04732536, 0.0477709 ]],

       [[       nan,        nan,        nan, ...,        nan,
                nan,        nan],
        [       nan,        nan,        nan, ...,        nan,
                nan,        nan],
        [       nan,        nan,        nan, ...,        nan,
                nan,        nan],
        ...,
        [0.04385129, 0.04382134, 0.04361135, ..., 0.04428564,
         0.04415513, 0.04422673],
        [0.02202884, 0.02230439, 0.02168488, ..., 0.02329186,
         0.02287089, 0.02251357],
        [0.01074545, 0.01067164, 0.00994377, ..., 0.01071551,
         0.01077672, 0.01121927]]], shape=(70, 180, 360))
Coordinates:
  * year     (year) int64 560B 1951 1952 1953 1954 1955 ... 2017 2018 2019 2020
  * lat      (lat) float32 720B -89.5 -88.5 -87.5 -86.5 ... 86.5 87.5 88.5 89.5
  * lon      (lon) float32 1kB 0.5 1.5 2.5 3.5 4.5 ... 356.5 357.5 358.5 359.5
Attributes:
    long_name:     Monthly Means of Global Sea Surface Temperature
    valid_range:   [-5. 40.]
    units:         degC
    var_desc:      Sea Surface Temperature
    dataset:       COBE-SST2 Sea Surface Temperature
    statistic:     Mean
    parent_stat:   Individual obs
    level_desc:    Surface
    actual_range:  [-3.0000002 34.392    ]
    note1:         5-month trailing mean ending in month 2. Year corresponds ...
    note2:         Linear trend removed after applying 90% data coverage mask.
def regional_weighted_mean(data, lat1, lat2, lon1, lon2):
    """
    Average `data` over a latitude-longitude box using cos(latitude) weights.

    The longitude coordinate of `data` is first converted to the convention
    implied by the bounds: -180..180 when either longitude bound is negative,
    0..360 otherwise. The conversion is applied only when `data` actually holds
    longitudes outside that convention.

    Parameters
    ----------
    data : xarray.DataArray
        Input data with 'lat' and 'lon' coordinates; any other dimensions
        are retained in the result.
    lat1, lat2 : float
        Latitude bounds, passed to a label slice as (lat1, lat2).
    lon1, lon2 : float
        Longitude bounds (either in -180 to 180 or 0 to 360), passed to a
        label slice as (lon1, lon2).

    Returns
    -------
    xarray.DataArray
        Weighted mean over 'lat' and 'lon' (missing values skipped), named
        "regional_mean_<lat1>_<lat2>_<lon1>_<lon2>".
    """
    # --- Bring longitudes into the convention used by the bounds ---
    wants_signed_lon = lon1 < 0 or lon2 < 0
    if wants_signed_lon:
        if (data.lon > 180).any():
            signed_lon = (data.lon + 180) % 360 - 180
            data = data.assign_coords(lon=signed_lon).sortby('lon')
    elif (data.lon < 0).any():
        data = data.assign_coords(lon=data.lon % 360).sortby('lon')

    # --- Cut out the box ---
    box = data.sel(lat=slice(lat1, lat2), lon=slice(lon1, lon2))

    # --- Area weighting proportional to cos(latitude) ---
    cos_lat = np.cos(np.deg2rad(box['lat']))
    box_mean = box.weighted(cos_lat).mean(dim=['lat', 'lon'], skipna=True)

    # --- Name the result after the box bounds ---
    box_mean.name = f"regional_mean_{lat1}_{lat2}_{lon1}_{lon2}"
    return box_mean

# --- Niño 3.4 index: mean of dat_NDJFM over 5°S-5°N, 190°-240°E ---
nino34_ts = regional_weighted_mean(dat_NDJFM, lat1=-5, lat2=5, lon1=190, lon2=240)

# --- Standard deviation across the available years (missing values skipped) ---
nino34_std = nino34_ts.std(dim="year", skipna=True)

# --- Scale the index by its standard deviation ---
nino34_ts_std = nino34_ts / nino34_std

# --- Carry the attributes over and record the scaling factor ---
nino34_ts_std.attrs = dict(nino34_ts.attrs)
nino34_ts_std.attrs["note1"] = f"Standardized Niño 3.4 index. Standard deviation: {nino34_std.values:.3f}"

print(nino34_ts_std)
<xarray.DataArray 'regional_mean_-5_5_190_240' (year: 70)> Size: 560B
array([-0.84242396,  0.86353962,  0.15385515,  0.2855797 , -0.93414499,
       -1.22521061, -0.09921426,  1.60278945,  0.58513473,  0.18864412,
       -0.08125704, -0.50144525, -0.42366512,  1.01065251, -0.88651785,
        1.48532399, -0.10824377, -0.47533738,  0.72837009,  0.76351155,
       -1.38114218, -0.85111633,  1.68083599, -1.82370587, -0.59160717,
       -1.48655636,  0.64278058,  0.7629035 , -0.07909781,  0.43672724,
       -0.16109375, -0.26470221,  2.01019884, -0.70970947, -0.85609144,
       -0.37170209,  1.20287726,  0.92019422, -1.73201992, -0.16373535,
        0.28102254,  1.47835879,  0.08953337,  0.29129455,  0.95906713,
       -1.02026181, -0.55540958,  2.35493331, -1.39454424, -1.50470693,
       -0.67777684, -0.17179938,  1.15602986,  0.30964153,  0.54970597,
       -0.58591287,  0.65923747, -1.67075273, -0.72687507,  1.3178921 ,
       -1.60351327, -0.91385179, -0.09065725, -0.49126429,  0.54168987,
        2.41709005, -0.53377439, -0.86609021,  0.7018336 ,  0.4256821 ])
Coordinates:
  * year     (year) int64 560B 1951 1952 1953 1954 1955 ... 2017 2018 2019 2020
Attributes:
    note1:    Standardized Niño 3.4 index. Standard deviation: 1.038
def lead_lag_correlation_geocat(ts, ystr1, yend1, dat, ystr2, yend2, min_coverage=0.9):
    """
    Correlate a yearly index with a yearly gridded field via geocat.comp.

    Each input is cut to its own year window; the two windows must contain the
    same number of years and are then paired position by position, which is how
    a lead or lag between index and field is expressed. Grid points with too few
    valid years, or with (near-)zero standard deviation, are set to NaN before
    the correlation is computed.

    Parameters
    ----------
    ts : xr.DataArray
        1D time series with dimension 'year'.
    ystr1, yend1 : int
        First and last year taken from `ts`.
    dat : xr.DataArray
        3D gridded data with dimensions ('year', 'lat', 'lon').
    ystr2, yend2 : int
        First and last year taken from `dat`.
    min_coverage : float
        Minimum fraction of years with valid data required at a grid point
        (default 0.9).

    Returns
    -------
    xr.DataArray
        Result of geocat.comp.stats.pearson_r along 'year' for the masked field.

    Raises
    ------
    ValueError
        If the two year windows differ in length.
    """
    series = ts.sel(year=slice(ystr1, yend1))
    field = dat.sel(year=slice(ystr2, yend2))

    # The windows are paired by position, so their lengths must agree
    n_series = series['year'].size
    n_field = field['year'].size
    if n_series != n_field:
        raise ValueError(f"Year mismatch: ts has {n_series}, dat has {n_field}.")

    # Give both inputs the same positional 'year' index so that xarray does not
    # align them on their (deliberately different) calendar years
    position = np.arange(n_series)
    series = series.assign_coords(year=position)
    field = field.assign_coords(year=position)

    # Per-grid-point coverage and variability
    n_valid = field.count(dim="year")
    spread = field.std(dim="year", skipna=True)
    n_required = int(n_field * min_coverage)

    # Keep a grid point unless it is (near-)constant or too sparsely sampled
    usable = ~((spread < 1e-10) | (n_valid < n_required))
    field = field.where(usable)

    return gccomp.stats.pearson_r(series, field, dim="year", skipna=True, keep_attrs=True)
def lead_lag_correlation(ts, ystr1, yend1, dat, ystr2, yend2, min_coverage=0.9):
    """
    Correlate a yearly index with a yearly gridded field using numpy.corrcoef.

    Each input is cut to its own year window; the windows must have the same
    length and are paired position by position. A correlation is computed only
    at grid points with enough valid years, using the years where both the
    index and the field are finite.

    Parameters
    ----------
    ts : xr.DataArray
        1D time series with dimension 'year'.
    ystr1, yend1 : int
        First and last year taken from `ts`.
    dat : xr.DataArray
        3D gridded data with dimensions ('year', 'lat', 'lon').
    ystr2, yend2 : int
        First and last year taken from `dat`.
    min_coverage : float
        Minimum fraction of years with valid data required at a grid point
        (default 0.9).

    Returns
    -------
    xr.DataArray
        Correlation map named 'lead_lag_correlation' with dimensions
        ('lat', 'lon'); NaN where coverage is insufficient.

    Raises
    ------
    ValueError
        If the two year windows differ in length.
    """
    # Plain numpy views of the two year windows
    series = ts.sel(year=slice(ystr1, yend1)).values   # (year,)
    field = dat.sel(year=slice(ystr2, yend2)).values   # (year, lat, lon)

    if series.shape[0] != field.shape[0]:
        raise ValueError(f"Number of years mismatch: ts has {series.shape[0]}, dat has {field.shape[0]}.")

    nyear, nlat, nlon = field.shape

    # Finite-value masks, computed once up front
    series_ok = np.isfinite(series)
    field_ok = np.isfinite(field)

    # Grid points with at least the required number of valid years
    enough = field_ok.sum(axis=0) >= int(nyear * min_coverage)

    # Start from an all-NaN map and fill in the covered grid points only
    corr_map = np.full((nlat, nlon), np.nan)
    for i, j in np.argwhere(enough):
        both = field_ok[:, i, j] & series_ok
        if both.sum() > 2:
            column = field[:, i, j]
            corr_map[i, j] = np.corrcoef(series[both], column[both])[0, 1]

    # Wrap the result with the field's coordinates
    corr_da = xr.DataArray(
        corr_map,
        dims=['lat', 'lon'],
        coords={'lat': dat['lat'], 'lon': dat['lon']},
        name='lead_lag_correlation'
    )

    # Explicitly blank grid points without sufficient coverage
    return corr_da.where(enough)
# Correlation maps between the standardized Niño 3.4 index and the seasonal
# fields. Arguments are (ts, ystr1, yend1, dat, ystr2, yend2); shifting the
# field window one year earlier pairs each index value with the field of the
# preceding year.

# dat_NDJFM of the preceding year vs. the index
corr_m12 = lead_lag_correlation_geocat(
    nino34_ts_std, 1952, 2020, dat_NDJFM, 1951, 2019, min_coverage=0.9
)

# dat_MJJAS of the preceding year vs. the index
corr_m6 = lead_lag_correlation_geocat(
    nino34_ts_std, 1952, 2020, dat_MJJAS, 1951, 2019, min_coverage=0.9
)

# dat_NDJFM of the same year vs. the index
corr_0 = lead_lag_correlation_geocat(
    nino34_ts_std, 1951, 2020, dat_NDJFM, 1951, 2020, min_coverage=0.9
)

# dat_MJJAS of the same year vs. the index
corr_p6 = lead_lag_correlation_geocat(
    nino34_ts_std, 1951, 2020, dat_MJJAS, 1951, 2020, min_coverage=0.9
)
/Users/a02235045/miniconda3/envs/vbook/lib/python3.13/site-packages/xskillscore/core/np_deterministic.py:314: RuntimeWarning: invalid value encountered in divide
  r = r_num / r_den
/Users/a02235045/miniconda3/envs/vbook/lib/python3.13/site-packages/xskillscore/core/np_deterministic.py:314: RuntimeWarning: invalid value encountered in divide
  r = r_num / r_den
/Users/a02235045/miniconda3/envs/vbook/lib/python3.13/site-packages/xskillscore/core/np_deterministic.py:314: RuntimeWarning: invalid value encountered in divide
  r = r_num / r_den
/Users/a02235045/miniconda3/envs/vbook/lib/python3.13/site-packages/xskillscore/core/np_deterministic.py:314: RuntimeWarning: invalid value encountered in divide
  r = r_num / r_den
def plot_global_panel_OP2(ax, std_data, std_levels, titleM, titleL, color_list,
                      clm_data=None, clm_levels=None):
    """
    Draw one global map panel: filled contours of a 2D field, optional contour
    line overlay of a second field, land/coastlines, and manually set
    longitude/latitude ticks and labels.

    Parameters
    ----------
    ax : cartopy.mpl.geoaxes.GeoAxesSubplot
        Axes to plot on.
    std_data : xarray.DataArray
        Field drawn with filled contours; must have 'lat' and 'lon' coordinates.
        (The `std_` prefix is historical; any 2D lat-lon field can be passed.)
    std_levels : list or array-like
        Contour levels for `std_data`.
    titleM : str
        Main (centered) title of the panel.
    titleL : str
        Left-aligned title of the panel.
    color_list : matplotlib colormap
        Passed as `cmap` to `contourf`.
    clm_data : xarray.DataArray, optional
        Second field drawn as labelled contour lines on top of the filled
        contours; must have 'lat' and 'lon' coordinates.
    clm_levels : list or array-like, optional
        Contour levels for `clm_data`; 10 is passed to `contour` when omitted.

    Returns
    -------
    im
        The object returned by `ax.contourf`, usable as the colorbar mappable.
    """

    # --- Add cyclic longitudes ---
    std_cyclic = gvutil.xr_add_cyclic_longitudes(std_data, "lon")

    # --- Filled contours of the main field (data coordinates are lon/lat) ---
    im = ax.contourf(
        std_cyclic["lon"],
        std_cyclic["lat"],
        std_cyclic,
        levels=std_levels,
        cmap=color_list,
        extend="both",
        transform=ccrs.PlateCarree()
    )

    # --- Optional contour line overlay ---
    if clm_data is not None:
        clm_cyclic = gvutil.xr_add_cyclic_longitudes(clm_data, "lon")
        levels = clm_levels if clm_levels is not None else 10
        
        # Each contour is drawn twice: a wide white line underneath a thin
        # black line, so the black line has a white outline.
        csW = ax.contour(
            clm_cyclic["lon"], clm_cyclic["lat"], clm_cyclic,
            levels=levels, colors="white", linewidths=3,
            transform=ccrs.PlateCarree()
        )
        csB = ax.contour(
            clm_cyclic["lon"], clm_cyclic["lat"], clm_cyclic,
            levels=levels, colors="black", linewidths=1,
            transform=ccrs.PlateCarree()
        )

        # Label both contour sets with integer-formatted values
        ax.clabel(csW, inline=1, fontsize=8, fmt="%.0f")
        ax.clabel(csB, inline=1, fontsize=8, fmt="%.0f")

    # --- Map features ---
    ax.add_feature(cfeature.LAND, facecolor='lightgray', edgecolor='none', zorder=0)
    ax.coastlines(linewidth=0.5, alpha=0.6)
    ax.set_extent([-180, 180, -90, 90], crs=ccrs.PlateCarree())

    # --- Option 2: Manual setting ---
    # --- Set major ticks explicitly (every 60° longitude, 30° latitude) ---
    ax.set_xticks(np.arange(-180, 181, 60), crs=ccrs.PlateCarree())
    ax.set_yticks(np.arange(-90, 91, 30), crs=ccrs.PlateCarree())
    # Custom tick labels, listed in the same order as the tick values above;
    # the label for +180° is left empty because -180° already carries "180°".
    ax.set_yticklabels(["90°S", "60°S", "30°S", "EQ", "30°N", "60°N", "90°N"])
    ax.set_xticklabels(["180°", "120°W", "60°W", "0°", "60°E", "120°E", ""])

    # --- Set minor tick locators ---
    # NOTE(review): add_major_minor_ticks below presumably installs its own
    # minor locators, which would supersede these two calls - confirm.
    ax.xaxis.set_minor_locator(mticker.MultipleLocator(15))
    ax.yaxis.set_minor_locator(mticker.MultipleLocator(10))

    # --- Add tick marks ---
    gvutil.add_major_minor_ticks(
        ax,
        labelsize=10,
        x_minor_per_major=4,  # 15° spacing for minor if major is 60°
        y_minor_per_major=3   # 10° spacing for minor if major is 30°
    )

    # --- Draw ticks outward on all four sides ---
    ax.tick_params(
        axis="both",
        which="both",
        direction="out",
        length=6,
        width=1,
        top=True,
        bottom=True,
        left=True,
        right=True
    )

    # --- Minor ticks are shorter and thinner than major ticks ---
    ax.tick_params(
        axis="both",
        which="minor",
        length=4,
        width=0.75
    )

    gvutil.set_titles_and_labels(ax, maintitle=titleM, maintitlefontsize=16, lefttitle=titleL, lefttitlefontsize=14)

    return im
# --- Global settings ---
# All map panels share a PlateCarree projection centred on 200°E
proj = ccrs.PlateCarree(central_longitude=200)

fig = plt.figure(figsize=(12, 11))

# Layout: 3 equal-height rows on an 8-column grid.
#   row 0: index panel spanning the 6 middle columns
#   rows 1-2: two map panels per row, 4 columns each
gs = gridspec.GridSpec(3, 8, height_ratios=[1.0, 1.0, 1.0], hspace=0.5, wspace=1.0)
ax0 = fig.add_subplot(gs[0, 1:7])
ax1 = fig.add_subplot(gs[1, :4], projection=proj)
ax2 = fig.add_subplot(gs[1, 4:], projection=proj)
ax3 = fig.add_subplot(gs[2, :4], projection=proj)
ax4 = fig.add_subplot(gs[2, 4:], projection=proj)

# Colormap: AMWG BlueYellowRed; values beyond the contour levels reuse the
# first / last colour of the table
colormap = cmaps.BlueYellowRed.colors
color_list = ListedColormap(colormap)
color_list.set_under(colormap[0])
color_list.set_over(colormap[-1])

# Contour levels for the correlation maps, in steps of 0.1 starting at -0.9
clevs_corr = np.arange(-0.9, 1.0, 0.1)



# === Top panel: Niño 3.4 index ===

def setup_axis(ax, ystr, yend, ymin, ymax, y_major, y_minor):
    """
    Apply axis limits, tick spacing and grid styling to a time-series panel.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to configure.
    ystr, yend : float
        Lower and upper x-axis limits (years).
    ymin, ymax : float
        Lower and upper y-axis limits.
    y_major : float
        Spacing between major y ticks.
    y_minor : float
        Spacing between minor y ticks.
    """
    # Make grid go below the bars
    ax.set_axisbelow(True)

    # X-Axis: one minor tick per year; major ticks are left to matplotlib
    ax.xaxis.set_minor_locator(MultipleLocator(1))
    ax.set_xlim(ystr, yend)

    # Y-Axis: honour the requested major spacing (previously this argument was
    # accepted but never applied)
    ax.set_ylim(ymin, ymax)
    ax.yaxis.set_major_locator(MultipleLocator(y_major))
    ax.yaxis.set_minor_locator(MultipleLocator(y_minor))

    # Ticks on all four sides; longer for major, shorter for minor
    ax.tick_params(axis="both", which="major", length=7, top=True, right=True)
    ax.tick_params(axis="both", which="minor", length=4, top=True, right=True)

    # Solid grid on major ticks, fainter dashed grid on minor ticks
    ax.grid(visible=True, which="major", linestyle="-", linewidth=0.7, alpha=0.7)
    ax.grid(visible=True, which="minor", linestyle="--", linewidth=0.4, alpha=0.5)

# Bar styling: one fill/edge colour pair for non-negative index values and
# another for negative values
positive_color = "pink"
negative_color = "lightblue"
positive_color_edge = "red"
negative_color_edge = "blue"
# Iterate over the plain numpy values rather than 0-d DataArray elements
bar_colors = [positive_color if val >= 0 else negative_color for val in nino34_ts_std.values]
edge_colors = [positive_color_edge if val >= 0 else negative_color_edge for val in nino34_ts_std.values]

ax0.bar(nino34_ts_std.year, nino34_ts_std, color=bar_colors, edgecolor=edge_colors,
        linewidth=1, label="Nino 3.4 index")
# x range runs one year past yend so the last bar is not cut off at the edge
setup_axis(ax0, ystr, yend+1, -3., 3., 1, 0.2)

gvutil.set_titles_and_labels(ax0,
    maintitle="SST Anomalies during ENSO events",
    lefttitle="(a) Nino 3.4 Index in NDJFM",
    righttitle="",
    ylabel="Normalized",
    xlabel="Year",
    maintitlefontsize=14,  # Main title font size
    lefttitlefontsize=12,  # Left title font size
    righttitlefontsize=12,  # Right title font size
    labelfontsize=10  # Axis label font size
)
# Zero reference line
ax0.axhline(0, color="gray", linestyle="-")


# === Map panels ===
# Panel letters continue from the index panel "(a)" above, so the maps are
# labelled (b)-(e) and no letter is used twice in the figure.
im1 = plot_global_panel_OP2(ax1, std_data=corr_m12, std_levels=clevs_corr,
                            titleM="", titleL="(b) SST anomalies at -12 month",
                            color_list=color_list)


im2 = plot_global_panel_OP2(ax2, std_data=corr_m6, std_levels=clevs_corr,
                            titleM="", titleL="(c) SST anomalies at -6 month",
                            color_list=color_list)


im3 = plot_global_panel_OP2(ax3, std_data=corr_0, std_levels=clevs_corr,
                            titleM="", titleL="(d) SST anomalies at 0 month",
                            color_list=color_list)

im4 = plot_global_panel_OP2(ax4, std_data=corr_p6, std_levels=clevs_corr,
                            titleM="", titleL="(e) SST anomalies at +6 month",
                            color_list=color_list)

# --- Add a single shared colorbar for the maps ---
# All four maps use the same levels and colormap, so the first map's contour
# set serves as the mappable for the shared colorbar.
cbar_ax = fig.add_axes([0.25, 0.05, 0.5, 0.02])
cbar = plt.colorbar(im1, cax=cbar_ax, orientation="horizontal",
                    spacing="uniform", extend='both', extendfrac=0.07,
                    drawedges=True)
cbar.set_label("Correlation Coefficient", fontsize=12)
cbar.ax.tick_params(length=0, labelsize=10)
cbar.outline.set_edgecolor("black")
cbar.outline.set_linewidth(1.0)

# --- Finalize ---
plt.savefig(fnFIG, dpi=300, bbox_inches="tight")
plt.show()
../_images/f9f240216be2dbd2785897e85d04b554ab5d9522fe62636e89894b5f18e720a0.png