Seasonal and Annual Averages

Seasonal, Annual, and Water Year Averages

This section introduces a general method for calculating seasonal means with a trailing running mean, that is, a window that ends on its last month and is labelled by that month. The same approach supports:

  • Seasonal averages (e.g., DJF, NDJFM)

  • Annual averages (12-month running means)

  • Water year averages (October to September)
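Each of these averages is fully described by two numbers: the window length and the end month. As a minimal sketch of that correspondence, the snippet below builds the month initials covered by a trailing window. The helper `season_label` is hypothetical and only for illustration; it is not part of the function presented in this section.

```python
MONTH_INITIALS = "JFMAMJJASOND"


def season_label(window, end_month):
    """Return the month initials covered by a trailing window (hypothetical helper)."""
    # Step back window-1 months from end_month, wrapping from January to December.
    months = [(end_month - 1 - k) % 12 + 1 for k in range(window - 1, -1, -1)]
    return "".join(MONTH_INITIALS[m - 1] for m in months)


print(season_label(3, 2))    # 3-month window ending in February
print(season_label(5, 3))    # 5-month window ending in March
print(season_label(12, 9))   # 12-month window ending in September (water year)
```

A 12-month window ending in December covers a calendar year, while the same window ending in September covers October through September.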

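The function below first converts the data to monthly anomalies by subtracting the monthly climatology, so its output is a seasonal mean anomaly rather than a raw mean. It also masks grid points with insufficient data coverage and can optionally remove the long-term linear trend.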


📌 Function: calc_seasonal_mean() (Trailing Mean Based on End Month)

import xarray as xr


def calc_seasonal_mean(dat, window=5, end_month=1, min_coverage=0.9, dtrend=True):
    """
    Calculate a seasonal mean anomaly: remove the monthly climatology, apply a trailing
    running mean, extract the final month (e.g., Mar for NDJFM), convert to a year-lat-lon
    DataArray, apply a minimum coverage mask, and optionally remove the linear trend.

    Parameters:
    -----------
    dat : xr.DataArray
        Input data with dimensions (time, lat, lon) and datetime64 'time'.
        Assumed to be monthly with no gaps, because the rolling window counts
        time steps rather than calendar months.
    window : int
        Running mean window size in months (default is 5).
    end_month : int
        Target month (1-12) used to extract seasonal means (final month of the trailing average).
    min_coverage : float
        Minimum fraction of years with valid data required at a grid point;
        grid points below this fraction are set to NaN (default 0.9).
    dtrend : bool
        If True, remove linear trend after applying coverage mask.

    Returns:
    --------
    dat_out : xr.DataArray
        Seasonal mean anomaly with dimensions (year, lat, lon), masked and optionally detrended.
    """

    # Compute monthly anomalies
    clm = dat.groupby("time.month").mean(dim="time")
    anm = dat.groupby("time.month") - clm

    # Apply trailing running mean
    dat_rm = anm.rolling(time=window, center=False, min_periods=window).mean()

    # Filter for entries where month == end_month
    dat_tmp = dat_rm.sel(time=dat_rm["time"].dt.month == end_month)

    # Extract year from the end_month timestamps
    years = dat_tmp["time"].dt.year

    # Create clean DataArray with dimensions ['year', 'lat', 'lon']
    datS = xr.DataArray(
        data=dat_tmp.values,
        dims=["year", "lat", "lon"],
        coords={
            "year": years.values,
            "lat": dat_tmp["lat"].values,
            "lon": dat_tmp["lon"].values,
        },
        # A DataArray always has a `name` attribute, but it may be None
        name=dat.name if dat.name is not None else "SeasonalMean",
        attrs=dat.attrs.copy(),
    )

    # Remove edge years if all values are missing
    for yr in [datS.year.values[0], datS.year.values[-1]]:
        if datS.sel(year=yr).isnull().all():
            datS = datS.sel(year=datS.year != yr)

    # Apply coverage mask
    valid_counts = datS.count(dim="year")
    total_years = datS["year"].size
    min_valid_years = int(total_years * min_coverage)
    sufficient_coverage = valid_counts >= min_valid_years
    datS_masked = datS.where(sufficient_coverage)

    dat_out = datS_masked

    # Optional detrending
    if dtrend:
        coeffs = datS_masked.polyfit(dim="year", deg=1)
        trend = xr.polyval(datS_masked["year"], coeffs.polyfit_coefficients)
        dat_out = datS_masked - trend

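    # Restore metadata (arithmetic in the detrending step drops attrs and may drop the name)
    dat_out.name = datS.name
    dat_out.attrs.update(datS.attrs)
    dat_out.attrs["note1"] = (
        f"{window}-month trailing mean ending in month {end_month}. "
        f"Year is the calendar year of month {end_month}."
    )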
    if dtrend:
        dat_out.attrs["note2"] = f"Linear trend removed after applying {int(min_coverage * 100)}% data coverage mask."
    else:
        dat_out.attrs["note2"] = f"Applied {int(min_coverage * 100)}% data coverage mask only (no detrending)."

    return dat_out
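To see what the rolling and end-month selection steps inside the function do, here is a minimal sketch on a one-dimensional toy series. The series and the variable names (`toy`, `toy_rm`, `toy_djf`) are made up for illustration, and the anomaly, masking, and detrending steps are left out on purpose.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical toy series: 36 monthly values 0, 1, 2, ... starting in January 2000.
time = pd.date_range("2000-01-01", periods=36, freq="MS")
toy = xr.DataArray(np.arange(36.0), dims="time", coords={"time": time})

# 3-month trailing mean: each value averages the current month and the two before it.
# min_periods=3 leaves NaN wherever fewer than three months are available.
toy_rm = toy.rolling(time=3, center=False, min_periods=3).mean()

# Keep only the February entries, which now hold the Dec-Jan-Feb means.
toy_djf = toy_rm.sel(time=toy_rm["time"].dt.month == 2)

djf_years = toy_djf["time"].dt.year.values.tolist()
djf_values = toy_djf.values.tolist()
print(djf_years)
print(djf_values)
```

February 2000 comes out as NaN because the series has no December 1999 to complete the window. This is the situation the edge-year removal step in the function is meant to handle.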


📌 Examples

▶️ DJF Seasonal Mean (3-month average)

# DJF: December–January–February, ending in February (end_month=2), window=3
dat_djf = calc_seasonal_mean(dat, window=3, end_month=2, dtrend=True)
print(dat_djf)

▶️ NDJFM Extended Seasonal Mean (5-month average)

# NDJFM: November–March, ending in March (end_month=3), window=5
dat_ndjfm = calc_seasonal_mean(dat, window=5, end_month=3, dtrend=True)
print(dat_ndjfm)

▶️ Annual Mean

# Calendar year average: 12-month mean ending in December (end_month=12)
dat_annual = calc_seasonal_mean(dat, window=12, end_month=12, dtrend=True)
print(dat_annual)

▶️ Water Year Mean (Oct–Sep, ending in September)

# Water year average (Oct–Sep), 12-month mean ending in September (end_month=9)
dat_wy = calc_seasonal_mean(dat, window=12, end_month=9, dtrend=True)
print(dat_wy)

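📎 Note: The output DataArray has a new year coordinate instead of time. Each value is labelled with the calendar year of the end month of its averaging window, so a window that crosses the year boundary starts in the previous calendar year. For example, the DJF value labelled 2001 averages December 2000 through February 2001.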
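The labelling rule in this note can be checked with a little integer arithmetic. The helper `window_start` below is hypothetical and exists only to illustrate which month a labelled window begins in.

```python
def window_start(label_year, end_month, window):
    """Return (year, month) of the first month in a trailing window (hypothetical helper)."""
    # Count months from year 0 so that stepping back across January is plain subtraction.
    end_index = label_year * 12 + (end_month - 1)
    start_index = end_index - (window - 1)
    return start_index // 12, start_index % 12 + 1


print(window_start(2001, 2, 3))    # DJF labelled 2001
print(window_start(2001, 9, 12))   # water year labelled 2001
print(window_start(2001, 12, 12))  # calendar year 2001
```

With this convention, the water year average labelled 2001 covers October 2000 through September 2001, while the calendar year average labelled 2001 stays within 2001.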

This function provides a unified way to extract meaningful long-period averages with optional quality control and trend removal. Adjust the window and end_month parameters to match your season definition.
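The quality control and trend removal steps can be illustrated on a tiny synthetic array. The sketch below mirrors the coverage mask and linear detrending used in the function, under the assumption of a made-up 1x2 grid with ten years of data. All values and variable names here are hypothetical.

```python
import numpy as np
import xarray as xr

years = np.arange(2000, 2010)

# Hypothetical 1x2 grid: both points follow a pure linear trend of 0.5 per year.
data = np.empty((10, 1, 2))
data[:, 0, 0] = 0.5 * (years - 2000)
data[:, 0, 1] = 0.5 * (years - 2000)
data[:3, 0, 1] = np.nan  # second point is missing 3 of 10 years (70% coverage)

toy = xr.DataArray(
    data,
    dims=["year", "lat", "lon"],
    coords={"year": years, "lat": [0.0], "lon": [0.0, 1.0]},
)

# Coverage mask: require at least 90% of years to be valid at each grid point.
min_coverage = 0.9
min_valid_years = int(toy["year"].size * min_coverage)  # 9 of 10 years
sufficient = toy.count(dim="year") >= min_valid_years
masked = toy.where(sufficient)

# Linear detrending: fit a degree-1 polynomial along 'year' and subtract it.
coeffs = masked.polyfit(dim="year", deg=1)
trend = xr.polyval(masked["year"], coeffs.polyfit_coefficients)
detrended = masked - trend

print(sufficient.values)
print(detrended.isel(lat=0, lon=0).values)
```

The first grid point passes the mask and, being a pure trend, detrends to values near zero. The second fails the 90% threshold and is entirely NaN in the output.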