DataTree for Exploratory Analysis of Bayesian Models#

Here we present a collection of common manipulations you can use while working with datatree.DataTree objects.

import arviz_base as az
from datatree import DataTree
import numpy as np
import xarray as xr

xr.set_options(display_expand_data=False, display_expand_attrs=False);

display_expand_data=False makes the default view for xarray.DataArray fold the data values to a single line. To explore the values, click on the icon on the left of the view, right under the xarray.DataArray text. It has no effect on Dataset objects, which already default to folded views.

display_expand_attrs=False folds the attributes in both DataArray and Dataset objects to keep the views shorter. In this page we print DataArrays and Datasets several times and they always have the same attributes.

idata = az.load_arviz_data("centered_eight")
idata
<xarray.DatasetView>
Dimensions:  ()
Data variables:
    *empty*

Get a specific group#

post = idata["posterior"]
post
<xarray.DatasetView>
Dimensions:  (chain: 4, draw: 500, school: 8)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    mu       (chain, draw) float64 ...
    theta    (chain, draw, school) float64 ...
    tau      (chain, draw) float64 ...
Attributes: (6)

Tip

You’ll have noticed we stored the posterior group in a new variable: post. Since .copy() was not called, post is a view of the same underlying object, so using idata["posterior"] or post is now equivalent.

Use this to keep your code short yet easy to read. Store the groups you’ll use most often as separate variables, but don’t delete the parent DataTree: many ArviZ functions need it to work properly. For example, plot_pair needs data from the sample_stats group to show divergences, compare needs data from both the log_likelihood and posterior groups, and plot_loo_pit needs not 2 but 3 groups: log_likelihood, posterior_predictive and posterior.

Add a new variable#

post["log_tau"] = np.log(post["tau"])
idata.posterior
<xarray.DatasetView>
Dimensions:  (chain: 4, draw: 500, school: 8)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    mu       (chain, draw) float64 ...
    theta    (chain, draw, school) float64 ...
    tau      (chain, draw) float64 4.726 3.909 4.844 1.857 ... 2.741 2.932 4.461
    log_tau  (chain, draw) float64 1.553 1.363 1.578 ... 1.008 1.076 1.495
Attributes: (6)

Combine chains and draws#

stacked = az.extract(idata)
stacked
<xarray.Dataset>
Dimensions:  (sample: 2000, school: 8)
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) object MultiIndex
  * chain    (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3 3
  * draw     (sample) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
Data variables:
    mu       (sample) float64 7.872 3.385 9.1 7.304 ... 1.859 1.767 3.486 3.404
    theta    (school, sample) float64 12.32 11.29 5.709 ... -2.623 8.452 1.295
    tau      (sample) float64 4.726 3.909 4.844 1.857 ... 2.741 2.932 4.461
    log_tau  (sample) float64 1.553 1.363 1.578 0.6188 ... 1.008 1.076 1.495
Attributes: (6)

arviz.extract is a convenience function aimed at taking care of the most common subsetting operations with MCMC samples. It can:

  • Combine chains and draws

  • Return a subset of variables (with optional filtering with regular expressions or string matching)

  • Return a subset of samples. By default it returns a random subset, to avoid getting non-representative samples due to bad mixing.

  • Access any group
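The combination step can be sketched with plain xarray on a synthetic dataset (an illustrative stand-in, since loading the real example data requires a download): stacking chain and draw into a single sample dimension with xarray.Dataset.stack produces the same layout that az.extract returned above.

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for a posterior group: 4 chains x 500 draws
rng = np.random.default_rng(0)
ds = xr.Dataset(
    {"mu": (("chain", "draw"), rng.normal(size=(4, 500)))},
    coords={"chain": np.arange(4), "draw": np.arange(500)},
)

# Stack chain and draw into a single "sample" dimension,
# mirroring the combination step shown above
stacked = ds.stack(sample=("chain", "draw"))
print(stacked.sizes["sample"])  # 4 * 500 = 2000
```

The stacked result keeps chain and draw as coordinates on the new sample dimension, which is why the MultiIndex appears in the output above.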

Get a random subset of the samples#

az.extract(idata, num_samples=100)
<xarray.Dataset>
Dimensions:  (sample: 100, school: 8)
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) object MultiIndex
  * chain    (sample) int64 2 1 2 0 0 2 0 0 0 3 2 0 ... 1 3 3 1 3 2 3 2 3 1 1 0
  * draw     (sample) int64 120 385 192 83 10 97 441 ... 414 44 48 463 120 478
Data variables:
    mu       (sample) float64 8.793 0.4854 7.523 3.48 ... 9.111 2.821 6.28 5.913
    theta    (school, sample) float64 25.74 7.232 6.255 ... 3.351 3.232 -4.656
    tau      (sample) float64 11.36 9.712 7.546 1.185 ... 0.8965 4.987 5.036
    log_tau  (sample) float64 2.43 2.273 2.021 0.1701 ... -0.1093 1.607 1.617
Attributes: (6)

Tip

Use a random seed to get the same subset from multiple groups: az.extract(idata, num_samples=100, rng=3) and az.extract(idata, group="log_likelihood", num_samples=100, rng=3) will continue to have matching samples.

Obtain a NumPy array for a given parameter#

Let’s say we want to get the values for mu as a NumPy array.

stacked.mu.values
array([7.87179637, 3.38455431, 9.10047569, ..., 1.76673325, 3.48611194,
       3.40446391])

Get the dimension lengths#

Let’s check how many groups are in our hierarchical model.

len(idata.observed_data.school)
8

Get coordinate values#

What are the names of the groups in our hierarchical model? You can access them through the school coordinate in this case.

idata.observed_data.school
<xarray.DataArray 'school' (school: 8)>
'Choate' 'Deerfield' 'Phillips Andover' ... "St. Paul's" 'Mt. Hermon'
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'

Get a subset of chains#

Let’s keep only chains 0 and 2 here. For the subset to take effect on all relevant DataTree groups (posterior, sample_stats, log_likelihood and posterior_predictive) we will use datatree.DataTree.filter before using .sel.

posterior_groups = {"posterior", "posterior_predictive", "sample_stats", "log_likelihood"}
idata.filter(lambda node: node.name in posterior_groups).sel(chain=[0, 2])
<xarray.DatasetView>
Dimensions:  ()
Data variables:
    *empty*

Remove the first n draws (burn-in)#

Let’s say we want to remove the first 100 draws from all the chains and all DataTree groups with a draw dimension.

idata.filter(lambda node: "draw" in node.dims).sel(draw=slice(100, None))
<xarray.DatasetView>
Dimensions:  ()
Data variables:
    *empty*

If you check the resulting object you will see that the posterior, posterior_predictive, prior and sample_stats groups have 400 draws compared to idata, which has 500. Alternatively, you can specify which group or groups you want to change.

idata.filter(lambda node: node.name in posterior_groups).sel(draw=slice(100, None))
<xarray.DatasetView>
Dimensions:  ()
Data variables:
    *empty*
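Note that .sel is label based and inclusive of both slice endpoints: with draw labels starting at 0, slice(100, None) keeps the draws labelled 100 through 499. A minimal sketch on a synthetic dataset with the same chain and draw sizes as the example:

```python
import numpy as np
import xarray as xr

# Synthetic stand-in with the same chain/draw sizes as the example
ds = xr.Dataset(
    {"mu": (("chain", "draw"), np.zeros((4, 500)))},
    coords={"chain": np.arange(4), "draw": np.arange(500)},
)

# sel is label based: keeps every draw whose label is >= 100
burnin = ds.sel(draw=slice(100, None))
print(burnin.sizes["draw"])    # 400
print(int(burnin["draw"][0]))  # 100
```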

Compute posterior mean values along draw and chain dimensions#

To compute the mean value of the posterior samples, do the following:

post.mean()
<xarray.DatasetView>
Dimensions:  ()
Data variables:
    mu       float64 4.486
    theta    float64 4.912
    tau      float64 4.124
    log_tau  float64 1.173

This computes the mean along all dimensions. This is probably what you want for mu and tau, which have two dimensions (chain and draw), but maybe not what you expected for theta, which has the additional school dimension.

You can specify along which dimension you want to compute the mean (or other functions).

post.mean(dim=["chain", "draw"])
<xarray.DatasetView>
Dimensions:  (school: 8)
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    mu       float64 4.486
    theta    (school) float64 6.46 5.028 3.938 4.872 3.667 3.975 6.581 4.772
    tau      float64 4.124
    log_tau  float64 1.173
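The same dim argument works for other xarray reductions, not just mean. A short sketch on synthetic data (the sizes mirror the example) showing that reducing only the sampling dimensions preserves school:

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(1)
ds = xr.Dataset(
    {"theta": (("chain", "draw", "school"), rng.normal(size=(4, 500, 8)))}
)

# Reduce only the sampling dimensions; "school" survives the reduction
sd = ds.std(dim=["chain", "draw"])
q = ds.quantile([0.05, 0.95], dim=["chain", "draw"])
print(sd["theta"].dims)  # ('school',)
print(q["theta"].dims)   # ('quantile', 'school')
```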

Compute and store posterior pushforward quantities#

We use “posterior pushforward quantities” to refer to quantities that are not variables in the posterior but deterministic computations using posterior variables.

You can use xarray for these pushforward operations and store them as a new variable in the posterior group. You’ll then be able to plot them with ArviZ functions, calculate stats and diagnostics on them (like the mcse) or save and share the inferencedata object with the pushforward quantities included.

Compute the rolling mean of \(\log(\tau)\) with xarray.DataArray.rolling, storing the result in the posterior:

post["mlogtau"] = post["log_tau"].rolling({"draw": 50}).mean()
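With a window of 50 and the default min_periods, the first 49 positions of the rolling mean are NaN. If that is undesirable, min_periods can shorten the initial windows; a small sketch on a synthetic series:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(200, dtype=float), dims="draw")

# Default min_periods equals the window size, so the first
# window - 1 values are NaN; min_periods=1 averages whatever is available
full = da.rolling(draw=50).mean()
partial = da.rolling(draw=50, min_periods=1).mean()
print(bool(full.isnull()[:49].all()))  # True
print(bool(partial.isnull().any()))    # False
```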

Using xarray for pushforward calculations has all the advantages of working with xarray. It also inherits its disadvantages, but we believe those are outweighed by the advantages, and we have already shown how to extract the data as NumPy arrays. Working with InferenceData is mainly working with xarray objects, and this is what is shown in this guide.

Some examples of these advantages are specifying operations with named dimensions instead of positional ones (as seen in some previous sections), automatic alignment and broadcasting of arrays (as we’ll see now), or integration with Dask (as shown in the dask_for_arviz guide).

In this cell you will compute pairwise differences between schools in their mean effects (variable theta). To do so, subtract from the original variable a copy of theta whose school dimension has been renamed. Xarray then broadcasts the two variables because they have different dimensions, and the result is a 4-dimensional variable with all the pairwise differences.

Eventually, store the result in the theta_school_diff variable:

post["theta_school_diff"] = post.theta - post.theta.rename(school="school_bis")

The theta_school_diff variable in the posterior has kept the named dimensions and coordinates:

post
<xarray.DatasetView>
Dimensions:            (chain: 4, draw: 500, school: 8, school_bis: 8)
Coordinates:
  * chain              (chain) int64 0 1 2 3
  * draw               (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499
  * school             (school) object 'Choate' 'Deerfield' ... 'Mt. Hermon'
  * school_bis         (school_bis) object 'Choate' 'Deerfield' ... 'Mt. Hermon'
Data variables:
    mu                 (chain, draw) float64 7.872 3.385 9.1 ... 3.486 3.404
    theta              (chain, draw, school) float64 12.32 9.905 ... 6.762 1.295
    tau                (chain, draw) float64 4.726 3.909 4.844 ... 2.932 4.461
    log_tau            (chain, draw) float64 1.553 1.363 1.578 ... 1.076 1.495
    mlogtau            (chain, draw) float64 nan nan nan ... 1.494 1.496 1.511
    theta_school_diff  (chain, draw, school, school_bis) float64 0.0 ... 0.0
Attributes: (6)
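The broadcasting behavior can be seen on a minimal example (the school labels and values here are made up for illustration): renaming the dimension prevents alignment, so xarray broadcasts instead and every pairwise combination appears in the result.

```python
import numpy as np
import xarray as xr

theta = xr.DataArray(
    np.array([1.0, 4.0, 9.0]),
    dims="school",
    coords={"school": ["a", "b", "c"]},
)

# Different dimension names -> broadcasting, not alignment:
# the result holds every pairwise difference
diff = theta - theta.rename(school="school_bis")
print(diff.dims)  # ('school', 'school_bis')
print(float(diff.sel(school="b", school_bis="a")))  # 4.0 - 1.0 = 3.0
```

The diagonal of the result is zero, as each school's effect minus itself vanishes.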

Advanced subsetting#

To select the value corresponding to the difference between the Choate and Deerfield schools do:

post["theta_school_diff"].sel(school="Choate", school_bis="Deerfield")
<xarray.DataArray 'theta_school_diff' (chain: 4, draw: 500)>
2.415 2.156 -0.04943 1.228 3.384 9.662 ... -1.656 -0.4021 1.524 -3.372 -6.305
Coordinates:
  * chain       (chain) int64 0 1 2 3
  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
    school      <U6 'Choate'
    school_bis  <U9 'Deerfield'

For more advanced subsetting (the equivalent to what is sometimes called “fancy indexing” in NumPy) you need to provide the indices as DataArray objects:

school_idx = xr.DataArray(["Choate", "Hotchkiss", "Mt. Hermon"], dims=["pairwise_school_diff"])
school_bis_idx = xr.DataArray(
    ["Deerfield", "Choate", "Lawrenceville"], dims=["pairwise_school_diff"]
)
post["theta_school_diff"].sel(school=school_idx, school_bis=school_bis_idx)
<xarray.DataArray 'theta_school_diff' (chain: 4, draw: 500,
                                       pairwise_school_diff: 3)>
2.415 -6.741 -1.84 2.156 -3.474 3.784 ... -2.619 6.923 -6.305 1.667 -6.641
Coordinates:
  * chain       (chain) int64 0 1 2 3
  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
    school      (pairwise_school_diff) object 'Choate' 'Hotchkiss' 'Mt. Hermon'
    school_bis  (pairwise_school_diff) object 'Deerfield' ... 'Lawrenceville'
Dimensions without coordinates: pairwise_school_diff

Using lists or NumPy arrays instead of DataArrays does column/row based (outer) indexing. As you can see, the result has 9 values of theta_school_diff per sample instead of the 3 pairs of differences we selected in the previous cell:

post["theta_school_diff"].sel(
    school=["Choate", "Hotchkiss", "Mt. Hermon"],
    school_bis=["Deerfield", "Choate", "Lawrenceville"],
)
<xarray.DataArray 'theta_school_diff' (chain: 4, draw: 500, school: 3,
                                       school_bis: 3)>
2.415 0.0 -4.581 -4.326 -6.741 -11.32 ... 1.667 -6.077 -5.203 1.102 -6.641
Coordinates:
  * chain       (chain) int64 0 1 2 3
  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
  * school      (school) object 'Choate' 'Hotchkiss' 'Mt. Hermon'
  * school_bis  (school_bis) object 'Deerfield' 'Choate' 'Lawrenceville'

Add new chains using concat#

After checking the mcse and realizing you need more samples, you rerun the model with two chains and obtain an idata_rerun object.

# once implemented
# idata.merge(idata_rerun)
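Until merge is available, one workaround (a sketch on synthetic data; in practice you would concatenate the matching group datasets of idata and idata_rerun) is xarray.concat along the chain dimension, relabeling the rerun's chains so the chain coordinate stays unique:

```python
import numpy as np
import xarray as xr

def make_posterior(n_chains, first_chain=0):
    # Synthetic stand-in for the posterior group of one run
    return xr.Dataset(
        {"mu": (("chain", "draw"), np.zeros((n_chains, 500)))},
        coords={
            "chain": first_chain + np.arange(n_chains),
            "draw": np.arange(500),
        },
    )

post_run1 = make_posterior(4)
# Relabel the rerun's chains to 4 and 5 so coordinates don't clash
post_run2 = make_posterior(2, first_chain=4)

combined = xr.concat([post_run1, post_run2], dim="chain")
print(combined.sizes["chain"])  # 6
```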

Add a new group to a DataTree#

You can also add new groups to a DataTree with the .merge method as above, or by using the parent argument when creating a new DataTree object. The code below creates an example dataset and adds it to the idata DataTree.

rng = np.random.default_rng(3)
ds = az.dict_to_dataset(
    {"obs": rng.normal(size=(4, 500, 2))},
    dims={"obs": ["new_school"]},
    coords={"new_school": ["Essex College", "Moordale"]},
)
DataTree(ds, name="predictions", parent=idata)
idata
<xarray.DatasetView>
Dimensions:  ()
Data variables:
    *empty*