
Xarray 2: Computations and Masks


Overview

  1. Do basic arithmetic with DataArrays and Datasets
  2. Perform aggregation (reduction) along one or multiple dimensions of a DataArray or Dataset
  3. Compute climatology and anomaly using xarray’s “split-apply-combine” approach via .groupby()
  4. Perform weighted reductions along one or multiple dimensions of a DataArray or Dataset
  5. Provide an overview of masking data in xarray
  6. Mask data using .where() method

Prerequisites

Concepts | Importance | Notes
Xarray Intro | Necessary |
  • Time to learn: 45 minutes

Imports

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from pythia_datasets import DATASETS

Let’s open the monthly sea surface temperature dataset from the Community Earth System Model v2 (CESM2), which is a Global Climate Model:

filepath = DATASETS.fetch('CESM2_sst_data.nc')
ds = xr.open_dataset(filepath)
ds

Arithmetic Operations

Arithmetic operations with a single DataArray automatically apply over all array values (like NumPy). This process is called vectorization. Let’s convert the sea surface temperature from degrees Celsius to kelvins:

ds.tos + 273.15

Let’s square all values in tos:

ds.tos ** 2
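
As a minimal sketch of the same vectorized arithmetic without downloading the CESM2 file, the example below uses a small synthetic array. The name toy_sst and its values are made up for illustration and are not part of the dataset:

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for ds.tos: temperatures in degrees Celsius (made-up values)
toy_sst = xr.DataArray(
    np.array([[0.0, 10.0, 20.0], [5.0, 15.0, 25.0]]),
    dims=("lat", "lon"),
    coords={"lat": [-10.0, 10.0], "lon": [100.0, 110.0, 120.0]},
    name="tos",
)

# Scalar arithmetic is applied to every element; dims and coords are kept
toy_kelvin = toy_sst + 273.15

# Arithmetic between two DataArrays on the same grid is also element-wise
toy_diff = toy_kelvin - toy_sst
```

Every element of toy_diff should equal 273.15 (up to floating-point rounding), and both results keep the lat and lon dimensions of the input.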

Aggregation Methods

A very common step during data analysis is to summarize the data in question by computing aggregations like sum(), mean(), median(), min(), and max(). These reduced data provide insight into the nature of large datasets. Let’s explore some of these aggregation methods.

Compute the mean:

ds.tos.mean()

Because we specified no dim argument, the function was applied over all dimensions, computing the mean of every element of tos across time and space. It is possible to specify a dimension along which to compute an aggregation. For example, to calculate the mean in time for all locations, specify the time dimension as the dimension along which the mean should be calculated. Using Xarray’s plot function, visualize the mean SST.

ds.tos.mean(dim='time').plot(size=7);
<Figure size 933.333x700 with 2 Axes>

Compute the temporal min:

ds.tos.min(dim=['time'])

Compute the spatial sum:

ds.tos.sum(dim=['lat', 'lon'])

Compute the temporal median:

ds.tos.median(dim='time')

The following table summarizes some other built-in xarray aggregations:

Aggregation | Description
count() | Total number of items
mean(), median() | Mean and median
min(), max() | Minimum and maximum
std(), var() | Standard deviation and variance
prod() | Compute product of elements
sum() | Compute sum of elements
argmin(), argmax() | Find index of minimum and maximum value
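
To see how the dim argument changes the shape of the result, here is a small sketch on a synthetic (time, lat, lon) array. The array toy and its values are hypothetical. One element is set to NaN to show that, for floating-point data, these aggregations skip missing values by default:

```python
import numpy as np
import xarray as xr

# Synthetic (time, lat, lon) array holding 0..23, with one missing value
toy = xr.DataArray(
    np.arange(24, dtype=float).reshape(2, 3, 4),
    dims=("time", "lat", "lon"),
)
toy[0, 0, 0] = np.nan

time_mean = toy.mean(dim="time")            # reduces time, leaves (lat, lon)
spatial_sum = toy.sum(dim=["lat", "lon"])   # reduces space, leaves (time,)
n_valid = toy.count()                       # counts non-NaN elements only
```

Here time_mean has shape (3, 4), spatial_sum has one value per time step, and n_valid is 23 rather than 24 because the NaN is not counted.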

GroupBy: Split, Apply, Combine

Simple aggregations can give useful summaries of our dataset, but often we would prefer to aggregate conditionally on some coordinate labels or groups. Xarray provides the so-called groupby operation which enables the split-apply-combine workflow on Xarray DataArrays and Datasets.

First, let’s select a gridpoint closest to a specified lat-lon, and plot a time series of SST at that point. The annual cycle will be quite evident.

ds.tos.sel(lon=310, lat=50, method='nearest').plot();
<Figure size 640x480 with 1 Axes>

Split

Let’s group data by month, i.e. all Januaries in one group, all Februaries in one group, etc.

ds.tos.groupby(ds.time.dt.month)
<DataArrayGroupBy, grouped over 1 grouper(s), 12 groups in total: 'month': UniqueGrouper('month'), 12/12 groups with labels 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>
In the above example, we are using xarray's .dt accessor (DatetimeAccessor) to extract specific components of dates/times in our time coordinate dimension. For example, we can extract the year with ds.time.dt.year.

Xarray also offers a more concise syntax when the variable you’re grouping on is already present in the dataset. This is identical to ds.tos.groupby(ds.time.dt.month):

ds.tos.groupby('time.month')
<DataArrayGroupBy, grouped over 1 grouper(s), 12 groups in total: 'month': UniqueGrouper('month'), 12/12 groups with labels 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>

Apply & Combine

Now that we have groups defined, it’s time to “apply” a calculation to the group. These calculations can either be:

  • aggregation: reduces the size of the group
  • transformation: preserves the group’s full size

At the end of the apply step, xarray will automatically combine the aggregated/transformed groups back into a single object.
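
The difference between the two kinds of calculation can be sketched on a short synthetic monthly series. The data and the names toy, monthly_mean, and monthly_demeaned are illustrative assumptions, not part of the CESM2 example:

```python
import numpy as np
import pandas as pd
import xarray as xr

# Two years of synthetic monthly data (made-up values 0..23)
time = pd.date_range("2000-01-01", periods=24, freq="MS")
toy = xr.DataArray(np.arange(24, dtype=float), dims="time", coords={"time": time})

grouped = toy.groupby("time.month")

# Aggregation: each group collapses to one value, giving a 12-element 'month' dimension
monthly_mean = grouped.mean()

# Transformation: each group keeps its size, so the result still has 24 time steps
monthly_demeaned = grouped.map(lambda g: g - g.mean())
```

In both cases the combine step happens automatically: monthly_mean is indexed by month, while monthly_demeaned is indexed by the original time coordinate.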

Compute climatology

Let’s calculate the climatology at every point in the dataset:

tos_clim = ds.tos.groupby('time.month').mean()
tos_clim

Plot climatology at a specific point:

tos_clim.sel(lon=310, lat=50, method='nearest').plot();
<Figure size 640x480 with 1 Axes>

Plot zonal mean climatology:

tos_clim.mean(dim='lon').transpose().plot.contourf(levels=12, robust=True, cmap='turbo');
<Figure size 640x480 with 2 Axes>

Calculate and plot the difference between January and December climatologies:

(tos_clim.sel(month=1) - tos_clim.sel(month=12)).plot(size=6, robust=True);
<Figure size 800x600 with 2 Axes>

Compute anomaly

Now let’s combine the previous steps to compute climatology and use xarray’s groupby arithmetic to remove this climatology from our original data:

gb = ds.tos.groupby('time.month')
tos_anom = gb - gb.mean(dim='time')
tos_anom
tos_anom.sel(lon=310, lat=50, method='nearest').plot();
<Figure size 640x480 with 1 Axes>
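
A quick way to convince yourself that groupby arithmetic removes the seasonal cycle is to try it on a synthetic series. The sketch below assumes a made-up signal (a sine-shaped annual cycle plus a linear trend), and the names toy, toy_clim, and toy_anom are hypothetical:

```python
import numpy as np
import pandas as pd
import xarray as xr

# Three years of synthetic monthly data: a seasonal cycle plus a slow trend
time = pd.date_range("2000-01-01", periods=36, freq="MS")
months = np.arange(36)
values = 10.0 * np.sin(2 * np.pi * months / 12) + 0.1 * months
toy = xr.DataArray(values, dims="time", coords={"time": time}, name="toy_sst")

gb = toy.groupby("time.month")
toy_clim = gb.mean(dim="time")   # 12-value climatology, one per calendar month
toy_anom = gb - toy_clim         # subtract each month's climatology from the matching times

# Sanity check: every calendar-month group has had its own mean removed,
# so the anomalies sum to zero up to floating-point rounding
total = float(toy_anom.sum())
```

The anomaly series keeps all 36 time steps. What remains after removing the climatology is the trend relative to each month's mean, for example about -1.2 for January 2000.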

Let’s compute and visualize the mean global anomaly over time. We need to specify both lat and lon dimensions in the dim argument to mean():

unweighted_mean_global_anom = tos_anom.mean(dim=['lat', 'lon'])
unweighted_mean_global_anom.plot();
<Figure size 640x480 with 1 Axes>
Note: An operation which combines grid cells of different size is not scientifically valid unless each cell is weighted by the size of the grid cell. xarray has a convenient weighted method to accomplish this.

Let’s first load the cell area data, from another CESM2 dataset that contains the weights for the grid cells:

filepath2 = DATASETS.fetch('CESM2_grid_variables.nc')
areacello = xr.open_dataset(filepath2).areacello
areacello

As before, let’s calculate area-weighted mean global anomaly:

weighted_mean_global_anom = tos_anom.weighted(areacello).mean(dim=['lat', 'lon'])

Let’s plot both unweighted and weighted means:

unweighted_mean_global_anom.plot(size=7)
weighted_mean_global_anom.plot()
plt.legend(['unweighted', 'weighted']);
<Figure size 933.333x700 with 1 Axes>
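
To illustrate what weighted() computes, the sketch below compares it with the weighted mean written out by hand, sum(w * x) / sum(w). It uses a tiny synthetic field and, as an assumption for this example only, cos(latitude) as a stand-in for cell area instead of the areacello variable:

```python
import numpy as np
import xarray as xr

# Synthetic field on a coarse grid: values depend only on latitude (made-up data)
lat = np.array([-60.0, 0.0, 60.0])
lon = np.array([0.0, 120.0, 240.0])
field = xr.DataArray(
    np.array([[1.0, 1.0, 1.0], [4.0, 4.0, 4.0], [1.0, 1.0, 1.0]]),
    dims=("lat", "lon"),
    coords={"lat": lat, "lon": lon},
)

# Assumption: cos(latitude) as a proxy for relative cell area on a regular lat-lon grid
weights = np.cos(np.deg2rad(field.lat))

unweighted = field.mean(dim=["lat", "lon"])
weighted = field.weighted(weights).mean(dim=["lat", "lon"])

# The same weighted mean by hand: sum(w * x) / sum(w), with w broadcast to the full grid
by_hand = (field * weights).sum(dim=["lat", "lon"]) / (weights * xr.ones_like(field)).sum()
```

The unweighted mean is 2.0, while the weighted mean is 2.5 because the equatorial row carries twice the weight of the rows at 60 degrees.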

Xarray's resample method provides groupby-like functionality specifically for time dimensions. For example, resample to annual frequency:

r = ds.tos.resample(time='YS')
r
<DataArrayResample, grouped over 1 grouper(s), 15 groups in total: '__resample_dim__': TimeResampler('__resample_dim__'), 15/15 groups with labels 2000-01-01, 00:00:00, ..., 201...>
r.mean()
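
As a self-contained sketch of annual resampling, the example below applies the same 'YS' (year start) frequency to two years of synthetic monthly values. The array toy is made up for illustration:

```python
import numpy as np
import pandas as pd
import xarray as xr

# Two years of synthetic monthly data (made-up values 0..23)
time = pd.date_range("2000-01-01", periods=24, freq="MS")
toy = xr.DataArray(np.arange(24, dtype=float), dims="time", coords={"time": time})

# 'YS' labels each annual bin by its first day
annual_mean = toy.resample(time="YS").mean()
```

The result has one value per year (5.5 and 17.5 here), labelled 2000-01-01 and 2001-01-01 along the time dimension.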

Compute a 5-month moving average:

m_avg = ds.tos.rolling(time=5, center=True).mean()
m_avg
lat = 50
lon = 310

m_avg.sel(lat=lat, lon=lon, method='nearest').plot(size=6)
ds.tos.sel(lat=lat, lon=lon, method='nearest').plot()
plt.legend(['5-month moving average', 'monthly data']);
<Figure size 800x600 with 1 Axes>
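
The behavior of a centered window at the ends of a series is easiest to see on a short synthetic array. This is a sketch on made-up values. The min_periods argument shown here is an optional extra that the example above does not use:

```python
import numpy as np
import xarray as xr

toy = xr.DataArray(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]), dims="time")

# Centered 5-point window: the first and last two positions lack a full window, so they are NaN
centered = toy.rolling(time=5, center=True).mean()

# min_periods=1 relaxes that requirement and averages whatever values fall inside the window
relaxed = toy.rolling(time=5, center=True, min_periods=1).mean()
```

With the default settings only the three interior points (3.0, 4.0, 5.0) are defined, which is why a moving average plotted against the raw data appears shorter at both ends.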

Masking Data

Using the xr.where() or .where() method, elements of an xarray Dataset or xarray DataArray that satisfy a given condition or multiple conditions can be replaced/masked. To demonstrate this, we are going to use the .where() method on the tos DataArray.

We will use the same sea surface temperature dataset:

ds

Using where with one condition

Imagine we wish to analyze just the last time in the dataset. We could of course use isel for this:

sample = ds.tos.isel(time=-1)
sample

Unlike .isel() and .sel() that change the shape of the returned results, .where() preserves the shape of the original data. It accomplishes this by returning values from the original DataArray or Dataset if the condition is True, and fills in values (by default nan) wherever the condition is False.
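
The contrast between selecting and masking can be sketched on a 2x2 synthetic array (hypothetical values, not the SST data):

```python
import numpy as np
import xarray as xr

toy = xr.DataArray(
    np.array([[-1.0, 2.0], [3.0, -4.0]]),
    dims=("lat", "lon"),
    coords={"lat": [0.0, 10.0], "lon": [100.0, 110.0]},
)

subset = toy.sel(lat=10.0)     # selection drops the lat dimension
masked = toy.where(toy > 0)    # masking keeps the full (lat, lon) shape
```

Here subset has shape (2,), while masked still has shape (2, 2), with NaN in the two positions where the condition is False.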

Before applying it, let’s look at the .where() documentation. As the documentation points out, the conditional expression in .where() can be:

  • a DataArray
  • a Dataset
  • a function
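
As a brief sketch of two of these forms, the example below passes the condition first as a boolean DataArray and then as a function. The array toy and the lambda are illustrative assumptions:

```python
import numpy as np
import xarray as xr

toy = xr.DataArray(np.array([-2.0, -1.0, 0.0, 1.0, 2.0]), dims="x")

# Condition given as a boolean DataArray
masked_by_array = toy.where(toy < 0)

# Condition given as a function: it receives the object .where() is called on,
# which is convenient at the end of a method chain
masked_by_callable = (toy * 10).where(lambda arr: arr < 0)
```

The function form avoids having to name the intermediate result (toy * 10) just to build the condition.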

For demonstration purposes, let’s use .where() to mask locations with temperature values greater than or equal to 0:

masked_sample = sample.where(sample < 0.0)
masked_sample

Let’s plot both our original sample, and the masked sample:

fig, axes = plt.subplots(ncols=2, figsize=(19, 6))
sample.plot(ax=axes[0], robust=True)
masked_sample.plot(ax=axes[1], robust=True);
<Figure size 1900x600 with 4 Axes>
fig, axes = plt.subplots(ncols=2, figsize=(19, 6))
sample.plot(ax=axes[0])
masked_sample.plot(ax=axes[1]);
<Figure size 1900x600 with 4 Axes>

Using where with multiple conditions

.where() allows providing multiple conditions. To do this, we need to make sure each conditional expression is enclosed in (). To combine conditions, we use the bit-wise and (&) operator and/or the bit-wise or (|). Let’s use .where() to mask locations whose temperature is not strictly between 25 and 30:

sample.where((sample > 25) & (sample < 30)).plot(size=6, robust=True);
<Figure size 800x600 with 2 Axes>
sample.where((sample > 25) & (sample < 30)).plot(size=6);
<Figure size 800x600 with 2 Axes>
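
The following sketch shows how & and | combine boolean masks on a small synthetic array (made-up values). The parentheses matter because & and | bind more tightly than comparison operators in Python:

```python
import numpy as np
import xarray as xr

toy = xr.DataArray(np.array([20.0, 26.0, 28.0, 31.0]), dims="x")

in_range = (toy > 25) & (toy < 30)         # True where both comparisons hold
out_of_range = (toy <= 25) | (toy >= 30)   # True where either comparison holds

kept_inside = toy.where(in_range)
kept_outside = toy.where(out_of_range)

# ~ inverts a boolean mask, so this keeps the same values as kept_outside
same_as_outside = toy.where(~in_range)
```

kept_inside retains only 26.0 and 28.0, while kept_outside and same_as_outside retain only 20.0 and 31.0.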

We can use coordinates to apply a mask as well. Below, we use the latitude and longitude coordinates to mask everywhere outside of the Niño 3.4 region:

sample.where(
    (sample.lat < 5) & (sample.lat > -5) & (sample.lon > 190) & (sample.lon < 240)
).plot(size=6, robust=True);
<Figure size 800x600 with 2 Axes>

Using where with a custom fill value

.where() can take a second argument, which, if supplied, defines a fill value for the masked region. Below we fill masked regions with a constant 0:

sample.where((sample > 25) & (sample < 30), 0).plot(size=6, robust=False);
<Figure size 800x600 with 2 Axes>
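
For comparison, here is a sketch of the method form next to the function form xr.where(cond, x, y), using a small synthetic array (made-up values):

```python
import numpy as np
import xarray as xr

toy = xr.DataArray(np.array([20.0, 26.0, 28.0, 31.0]), dims="x")
cond = (toy > 25) & (toy < 30)

filled = toy.where(cond, 0)      # keep values where cond is True, else 0
same = xr.where(cond, toy, 0)    # function form: x where cond is True, y elsewhere
flags = xr.where(cond, 1, 0)     # both branches can be constants
```

The first two results are equal. The function form is handy when neither branch is the original data, as in flags, which marks the in-range positions with 1 and the rest with 0.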

Summary

  • Similar to NumPy, arithmetic operations are vectorized over a DataArray
  • Xarray provides aggregation methods like sum() and mean(), with the option to specify the dimension(s) over which the operation is performed
  • groupby enables the convenient split-apply-combine workflow
  • The .where() method allows for filtering or replacing of data based on one or more provided conditions

What’s next?

In an upcoming notebook, we will work through an example of plotting the Niño 3.4 Index.