Converting emcee objects to DataTree#
DataTree
is the data format ArviZ relies on.
This page covers multiple ways to generate a DataTree
from emcee objects.
See also
Conversion from Python, numpy or pandas objects
xarray_for_arviz for an overview of
InferenceData
and its role within ArviZ.schema describes the structure of
InferenceData
objects and the assumptions made by ArviZ to ease your exploratory analysis of Bayesian models.
We will start by importing the required packages and defining the model. The famous 8 school model.
import arviz_base as az
import numpy as np
import emcee
J = 8
y_obs = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
def log_prior_8school(theta):
mu, tau, eta = theta[0], theta[1], theta[2:]
# Half-cauchy prior, hwhm=25
if tau < 0:
return -np.inf
prior_tau = -np.log(tau**2 + 25**2)
prior_mu = -((mu / 10) ** 2) # normal prior, loc=0, scale=10
prior_eta = -np.sum(eta**2) # normal prior, loc=0, scale=1
return prior_mu + prior_tau + prior_eta
def log_likelihood_8school(theta, y, s):
mu, tau, eta = theta[0], theta[1], theta[2:]
return -(((mu + tau * eta - y) / s) ** 2)
def lnprob_8school(theta, y, s):
prior = log_prior_8school(theta)
like_vect = log_likelihood_8school(theta, y, s)
like = np.sum(like_vect)
return like + prior
nwalkers = 40 # called chains in ArviZ
ndim = J + 2
draws = 1500
pos = np.random.normal(size=(nwalkers, ndim))
pos[:, 1] = np.absolute(pos[:, 1])
sampler = emcee.EnsembleSampler(nwalkers, ndim, lnprob_8school, args=(y_obs, sigma))
sampler.run_mcmc(pos, draws);
Manually set variable names#
This first example will show how to convert manually setting the variable names only, leaving everything else to ArviZ defaults.
# define variable names, it cannot be inferred from emcee
var_names = ["mu", "tau"] + ["eta{}".format(i) for i in range(J)]
idata1 = az.from_emcee(sampler, var_names=var_names)
idata1
<xarray.DatasetView> Dimensions: () Data variables: *empty*
- draw: 1500
- chain: 40
- draw(draw)int640 1 2 3 4 ... 1496 1497 1498 1499
array([ 0, 1, 2, ..., 1497, 1498, 1499])
- chain(chain)int640 1 2 3 4 5 6 ... 34 35 36 37 38 39
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39])
- mu(draw, chain)float640.04892 0.4058 ... 6.357 6.915
array([[ 0.04892431, 0.40584996, -0.49545336, ..., -2.10983344, -1.03126245, 0.23335853], [ 0.54321117, 0.40584996, -0.12629448, ..., -2.10983344, -1.04969914, 0.23335853], [ 0.54321117, 0.40584996, -0.12629448, ..., -2.10983344, -1.04969914, 0.23335853], ..., [ 3.43496427, 2.98598133, 7.99173958, ..., 3.38840016, 6.35686609, 7.1535426 ], [ 3.43496427, 2.98598133, 7.99173958, ..., 0.92003529, 6.35686609, 6.91497731], [ 3.43496427, 2.98598133, 7.99173958, ..., 0.92003529, 6.35686609, 6.91497731]])
- tau(draw, chain)float642.306 0.1492 2.226 ... 15.16 0.8366
array([[ 2.30584775, 0.14918842, 2.22637404, ..., 0.44663061, 0.3656756 , 0.6971695 ], [ 1.56412863, 0.14918842, 1.60096208, ..., 0.44663061, 0.36917635, 0.6971695 ], [ 1.56412863, 0.14918842, 1.60096208, ..., 0.44663061, 0.36917635, 0.6971695 ], ..., [15.28055298, 8.27559843, 14.76199644, ..., 18.77776866, 15.16116783, 3.92189291], [15.28055298, 8.27559843, 14.76199644, ..., 24.55433169, 15.16116783, 0.83661283], [15.28055298, 8.27559843, 14.76199644, ..., 24.55433169, 15.16116783, 0.83661283]])
- eta0(draw, chain)float64-0.6786 1.546 ... 0.8633 0.2857
array([[-0.67857628, 1.5456469 , -0.2988126 , ..., -0.31146215, -0.58921979, -0.1978413 ], [-0.41869793, 1.5456469 , -0.43300786, ..., -0.31146215, -0.65531654, -0.1978413 ], [-0.41869793, 1.5456469 , -0.43300786, ..., -0.31146215, -0.65531654, -0.1978413 ], ..., [ 0.31934935, 1.3757245 , 0.7989722 , ..., 0.83954286, 0.86329705, 0.39941205], [ 0.31934935, 1.3757245 , 0.7989722 , ..., 0.77549432, 0.86329705, 0.28569035], [ 0.31934935, 1.3757245 , 0.7989722 , ..., 0.77549432, 0.86329705, 0.28569035]])
- eta1(draw, chain)float64-0.1647 0.01271 ... -0.6913 -0.3475
array([[-0.16468065, 0.0127109 , 1.82189573, ..., -0.94215065, -0.427894 , 0.25562326], [ 0.18647111, 0.0127109 , 1.36391804, ..., -0.94215065, -0.52226628, 0.25562326], [ 0.18647111, 0.0127109 , 1.36391804, ..., -0.94215065, -0.52226628, 0.25562326], ..., [ 0.08043603, 0.72003818, 0.19739264, ..., 0.36524514, -0.69131298, -0.2267906 ], [ 0.08043603, 0.72003818, 0.19739264, ..., 0.69970205, -0.69131298, -0.34752045], [ 0.08043603, 0.72003818, 0.19739264, ..., 0.69970205, -0.69131298, -0.34752045]])
- eta2(draw, chain)float64-0.5788 -0.6777 ... -0.1428 -0.1201
array([[-0.57879642, -0.6776883 , -0.99582489, ..., 1.09926195, 0.53215862, -0.30163309], [-0.77652266, -0.6776883 , -0.79823514, ..., 1.09926195, 0.49966297, -0.30163309], [-0.77652266, -0.6776883 , -0.79823514, ..., 1.09926195, 0.49966297, -0.30163309], ..., [-0.76803917, -0.09366257, -1.3474608 , ..., -0.10542609, -0.14277396, -0.39200948], [-0.76803917, -0.09366257, -1.3474608 , ..., -0.08869062, -0.14277396, -0.12007159], [-0.76803917, -0.09366257, -1.3474608 , ..., -0.08869062, -0.14277396, -0.12007159]])
- eta3(draw, chain)float64-0.9262 0.363 ... -0.1585 1.011
array([[-0.92617107, 0.36303128, 1.34032775, ..., 0.68979108, 1.41984251, 0.71753099], [-0.90282373, 0.36303128, 0.72003849, ..., 0.68979108, 1.42857041, 0.71753099], [-0.90282373, 0.36303128, 0.72003849, ..., 0.68979108, 1.42857041, 0.71753099], ..., [-0.08625927, -0.48122536, 0.33414925, ..., 0.3222754 , -0.15848865, 0.86072871], [-0.08625927, -0.48122536, 0.33414925, ..., 0.34921107, -0.15848865, 1.01060229], [-0.08625927, -0.48122536, 0.33414925, ..., 0.34921107, -0.15848865, 1.01060229]])
- eta4(draw, chain)float641.789 0.1214 ... -0.1618 -1.608
array([[ 1.78892783, 0.12139442, -0.52563632, ..., 0.25075209, -1.31276178, -0.74144372], [ 1.14559647, 0.12139442, -0.1387207 , ..., 0.25075209, -1.41567032, -0.74144372], [ 1.14559647, 0.12139442, -0.1387207 , ..., 0.25075209, -1.41567032, -0.74144372], ..., [-0.36611581, -0.07175391, -0.69425985, ..., 0.16778185, -0.16179822, -1.40592675], [-0.36611581, -0.07175391, -0.69425985, ..., 0.11688505, -0.16179822, -1.6084794 ], [-0.36611581, -0.07175391, -0.69425985, ..., 0.11688505, -0.16179822, -1.6084794 ]])
- eta5(draw, chain)float640.3746 0.1143 ... 0.3285 -0.2215
array([[ 0.37463455, 0.11426478, 1.87001865, ..., 0.96390528, -0.28103287, -1.35825061], [ 0.35675372, 0.11426478, 1.18707109, ..., 0.96390528, -0.31727261, -1.35825061], [ 0.35675372, 0.11426478, 1.18707109, ..., 0.96390528, -0.31727261, -1.35825061], ..., [-0.27236932, -0.64094235, -0.6463856 , ..., 0.05030224, 0.32852098, -0.31561289], [-0.27236932, -0.64094235, -0.6463856 , ..., 0.00841656, 0.32852098, -0.22146929], [-0.27236932, -0.64094235, -0.6463856 , ..., 0.00841656, 0.32852098, -0.22146929]])
- eta6(draw, chain)float640.2838 -0.6526 ... 0.6523 0.1229
array([[ 0.2838023 , -0.6525665 , 0.29470158, ..., 0.46763536, 0.17895492, 0.53966469], [ 0.01150702, -0.6525665 , -0.08897712, ..., 0.46763536, 0.29457348, 0.53966469], [ 0.01150702, -0.6525665 , -0.08897712, ..., 0.46763536, 0.29457348, 0.53966469], ..., [ 0.37489382, -0.22304128, 0.89148754, ..., 0.617635 , 0.65230976, 0.29320448], [ 0.37489382, -0.22304128, 0.89148754, ..., 0.90433984, 0.65230976, 0.12292282], [ 0.37489382, -0.22304128, 0.89148754, ..., 0.90433984, 0.65230976, 0.12292282]])
- eta7(draw, chain)float64-1.689 0.5856 ... -0.9638 -0.3025
array([[-1.68872586, 0.58564227, 0.58635191, ..., -0.2291999 , -0.98494518, -0.7510712 ], [-0.73387141, 0.58564227, 0.65329154, ..., -0.2291999 , -1.05025954, -0.7510712 ], [-0.73387141, 0.58564227, 0.65329154, ..., -0.2291999 , -1.05025954, -0.7510712 ], ..., [ 0.52550953, 0.38359654, 0.38363129, ..., -0.00329923, -0.96380601, -0.15046677], [ 0.52550953, 0.38359654, 0.38363129, ..., -0.02956931, -0.96380601, -0.30248027], [ 0.52550953, 0.38359654, 0.38363129, ..., -0.02956931, -0.96380601, -0.30248027]])
- created_at :
- 2023-06-16T00:40:42.099033
- creation_library :
- ArviZ
- creation_library_version :
- 0.1
- creation_library_language :
- Python
- inference_library :
- emcee
- inference_library_version :
- 3.1.4
<xarray.DatasetView> Dimensions: (draw: 1500, chain: 40) Coordinates: * draw (draw) int64 0 1 2 3 4 5 6 7 ... 1493 1494 1495 1496 1497 1498 1499 * chain (chain) int64 0 1 2 3 4 5 6 7 8 9 ... 30 31 32 33 34 35 36 37 38 39 Data variables: mu (draw, chain) float64 0.04892 0.4058 -0.4955 ... 0.92 6.357 6.915 tau (draw, chain) float64 2.306 0.1492 2.226 ... 24.55 15.16 0.8366 eta0 (draw, chain) float64 -0.6786 1.546 -0.2988 ... 0.8633 0.2857 eta1 (draw, chain) float64 -0.1647 0.01271 1.822 ... -0.6913 -0.3475 eta2 (draw, chain) float64 -0.5788 -0.6777 -0.9958 ... -0.1428 -0.1201 eta3 (draw, chain) float64 -0.9262 0.363 1.34 ... 0.3492 -0.1585 1.011 eta4 (draw, chain) float64 1.789 0.1214 -0.5256 ... -0.1618 -1.608 eta5 (draw, chain) float64 0.3746 0.1143 1.87 ... 0.3285 -0.2215 eta6 (draw, chain) float64 0.2838 -0.6526 0.2947 ... 0.6523 0.1229 eta7 (draw, chain) float64 -1.689 0.5856 0.5864 ... -0.9638 -0.3025 Attributes: created_at: 2023-06-16T00:40:42.099033 creation_library: ArviZ creation_library_version: 0.1 creation_library_language: Python inference_library: emcee inference_library_version: 3.1.4
posterior- arg_0_dim_0: 8
- arg_1_dim_0: 8
- arg_0_dim_0(arg_0_dim_0)int640 1 2 3 4 5 6 7
array([0, 1, 2, 3, 4, 5, 6, 7])
- arg_1_dim_0(arg_1_dim_0)int640 1 2 3 4 5 6 7
array([0, 1, 2, 3, 4, 5, 6, 7])
- arg_0(arg_0_dim_0)float6428.0 8.0 -3.0 7.0 ... 1.0 18.0 12.0
array([28., 8., -3., 7., -1., 1., 18., 12.])
- arg_1(arg_1_dim_0)float6415.0 10.0 16.0 ... 11.0 10.0 18.0
array([15., 10., 16., 11., 9., 11., 10., 18.])
- created_at :
- 2023-06-16T00:40:42.092684
- creation_library :
- ArviZ
- creation_library_version :
- 0.1
- creation_library_language :
- Python
- inference_library :
- emcee
- inference_library_version :
- 3.1.4
<xarray.DatasetView> Dimensions: (arg_0_dim_0: 8, arg_1_dim_0: 8) Coordinates: * arg_0_dim_0 (arg_0_dim_0) int64 0 1 2 3 4 5 6 7 * arg_1_dim_0 (arg_1_dim_0) int64 0 1 2 3 4 5 6 7 Data variables: arg_0 (arg_0_dim_0) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0 arg_1 (arg_1_dim_0) float64 15.0 10.0 16.0 11.0 9.0 11.0 10.0 18.0 Attributes: created_at: 2023-06-16T00:40:42.092684 creation_library: ArviZ creation_library_version: 0.1 creation_library_language: Python inference_library: emcee inference_library_version: 3.1.4
observed_data- draw: 1500
- chain: 40
- draw(draw)int640 1 2 3 4 ... 1496 1497 1498 1499
array([ 0, 1, 2, ..., 1497, 1498, 1499])
- chain(chain)int640 1 2 3 4 5 6 ... 34 35 36 37 38 39
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39])
- lp(draw, chain)float64-23.76 -18.08 ... -14.46 -15.33
array([[-23.75656766, -18.08303488, -24.62686788, ..., -20.89346634, -21.40430919, -18.31719405], [-18.27861331, -18.08303488, -19.58450327, ..., -20.89346634, -22.07702707, -18.31719405], [-18.27861331, -18.08303488, -19.58450327, ..., -20.89346634, -22.07702707, -18.31719405], ..., [-11.16721517, -14.54501244, -12.98669067, ..., -10.00527149, -14.46337015, -13.65646064], [-11.16721517, -14.54501244, -12.98669067, ..., -11.52177615, -14.46337015, -15.32749452], [-11.16721517, -14.54501244, -12.98669067, ..., -11.52177615, -14.46337015, -15.32749452]])
- created_at :
- 2023-06-16T00:40:42.090166
- creation_library :
- ArviZ
- creation_library_version :
- 0.1
- creation_library_language :
- Python
- inference_library :
- emcee
- inference_library_version :
- 3.1.4
<xarray.DatasetView> Dimensions: (draw: 1500, chain: 40) Coordinates: * draw (draw) int64 0 1 2 3 4 5 6 7 ... 1493 1494 1495 1496 1497 1498 1499 * chain (chain) int64 0 1 2 3 4 5 6 7 8 9 ... 30 31 32 33 34 35 36 37 38 39 Data variables: lp (draw, chain) float64 -23.76 -18.08 -24.63 ... -11.52 -14.46 -15.33 Attributes: created_at: 2023-06-16T00:40:42.090166 creation_library: ArviZ creation_library_version: 0.1 creation_library_language: Python inference_library: emcee inference_library_version: 3.1.4
sample_stats
ArviZ has stored the posterior variables with the provided names as expected, but it has also included other useful information in the InferenceData
object. The log probability of each sample is stored in the sample_stats
group under the name lp
and all the arguments passed to the sampler as args
have been saved in the observed_data
group.
It can also be useful to perform a burn in cut to the MCMC samples (see :meth:arviz.InferenceData.sel
for more details)
#idata1.sel(draw=slice(100, None))
From an InferenceData object, ArviZ’s native data structure, the posterior plot
of a few variables can be done in one line:
#az.plot_posterior(idata1, var_names=["mu", "tau", "eta4"])
Structuring the posterior as multidimensional variables#
This way of calling from_emcee
stores each eta
as a different variable, called eta#
,
however, they are in fact different dimensions of the same variable.
This can be seen in the code of the likelihood and prior functions, where theta
is unpacked as:
mu, tau, eta = theta[0], theta[1], theta[2:]
ArviZ has support for multidimensional variables, and there is a way to tell it how to split the variables like it was done in the likelihood and prior functions:
idata2 = az.from_emcee(sampler, slices=[0, 1, slice(2, None)])
idata2
<xarray.DatasetView> Dimensions: () Data variables: *empty*
- draw: 1500
- chain: 40
- var_2_dim_0: 8
- draw(draw)int640 1 2 3 4 ... 1496 1497 1498 1499
array([ 0, 1, 2, ..., 1497, 1498, 1499])
- chain(chain)int640 1 2 3 4 5 6 ... 34 35 36 37 38 39
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39])
- var_2_dim_0(var_2_dim_0)int640 1 2 3 4 5 6 7
array([0, 1, 2, 3, 4, 5, 6, 7])
- var_0(draw, chain)float640.04892 0.4058 ... 6.357 6.915
array([[ 0.04892431, 0.40584996, -0.49545336, ..., -2.10983344, -1.03126245, 0.23335853], [ 0.54321117, 0.40584996, -0.12629448, ..., -2.10983344, -1.04969914, 0.23335853], [ 0.54321117, 0.40584996, -0.12629448, ..., -2.10983344, -1.04969914, 0.23335853], ..., [ 3.43496427, 2.98598133, 7.99173958, ..., 3.38840016, 6.35686609, 7.1535426 ], [ 3.43496427, 2.98598133, 7.99173958, ..., 0.92003529, 6.35686609, 6.91497731], [ 3.43496427, 2.98598133, 7.99173958, ..., 0.92003529, 6.35686609, 6.91497731]])
- var_1(draw, chain)float642.306 0.1492 2.226 ... 15.16 0.8366
array([[ 2.30584775, 0.14918842, 2.22637404, ..., 0.44663061, 0.3656756 , 0.6971695 ], [ 1.56412863, 0.14918842, 1.60096208, ..., 0.44663061, 0.36917635, 0.6971695 ], [ 1.56412863, 0.14918842, 1.60096208, ..., 0.44663061, 0.36917635, 0.6971695 ], ..., [15.28055298, 8.27559843, 14.76199644, ..., 18.77776866, 15.16116783, 3.92189291], [15.28055298, 8.27559843, 14.76199644, ..., 24.55433169, 15.16116783, 0.83661283], [15.28055298, 8.27559843, 14.76199644, ..., 24.55433169, 15.16116783, 0.83661283]])
- var_2(draw, chain, var_2_dim_0)float64-0.6786 -0.1647 ... 0.1229 -0.3025
array([[[-0.67857628, -0.16468065, -0.57879642, ..., 0.37463455, 0.2838023 , -1.68872586], [ 1.5456469 , 0.0127109 , -0.6776883 , ..., 0.11426478, -0.6525665 , 0.58564227], [-0.2988126 , 1.82189573, -0.99582489, ..., 1.87001865, 0.29470158, 0.58635191], ..., [-0.31146215, -0.94215065, 1.09926195, ..., 0.96390528, 0.46763536, -0.2291999 ], [-0.58921979, -0.427894 , 0.53215862, ..., -0.28103287, 0.17895492, -0.98494518], [-0.1978413 , 0.25562326, -0.30163309, ..., -1.35825061, 0.53966469, -0.7510712 ]], [[-0.41869793, 0.18647111, -0.77652266, ..., 0.35675372, 0.01150702, -0.73387141], [ 1.5456469 , 0.0127109 , -0.6776883 , ..., 0.11426478, -0.6525665 , 0.58564227], [-0.43300786, 1.36391804, -0.79823514, ..., 1.18707109, -0.08897712, 0.65329154], ... [ 0.77549432, 0.69970205, -0.08869062, ..., 0.00841656, 0.90433984, -0.02956931], [ 0.86329705, -0.69131298, -0.14277396, ..., 0.32852098, 0.65230976, -0.96380601], [ 0.28569035, -0.34752045, -0.12007159, ..., -0.22146929, 0.12292282, -0.30248027]], [[ 0.31934935, 0.08043603, -0.76803917, ..., -0.27236932, 0.37489382, 0.52550953], [ 1.3757245 , 0.72003818, -0.09366257, ..., -0.64094235, -0.22304128, 0.38359654], [ 0.7989722 , 0.19739264, -1.3474608 , ..., -0.6463856 , 0.89148754, 0.38363129], ..., [ 0.77549432, 0.69970205, -0.08869062, ..., 0.00841656, 0.90433984, -0.02956931], [ 0.86329705, -0.69131298, -0.14277396, ..., 0.32852098, 0.65230976, -0.96380601], [ 0.28569035, -0.34752045, -0.12007159, ..., -0.22146929, 0.12292282, -0.30248027]]])
- created_at :
- 2023-06-16T00:40:42.482732
- creation_library :
- ArviZ
- creation_library_version :
- 0.1
- creation_library_language :
- Python
- inference_library :
- emcee
- inference_library_version :
- 3.1.4
<xarray.DatasetView> Dimensions: (draw: 1500, chain: 40, var_2_dim_0: 8) Coordinates: * draw (draw) int64 0 1 2 3 4 5 6 ... 1494 1495 1496 1497 1498 1499 * chain (chain) int64 0 1 2 3 4 5 6 7 8 ... 31 32 33 34 35 36 37 38 39 * var_2_dim_0 (var_2_dim_0) int64 0 1 2 3 4 5 6 7 Data variables: var_0 (draw, chain) float64 0.04892 0.4058 -0.4955 ... 6.357 6.915 var_1 (draw, chain) float64 2.306 0.1492 2.226 ... 24.55 15.16 0.8366 var_2 (draw, chain, var_2_dim_0) float64 -0.6786 -0.1647 ... -0.3025 Attributes: created_at: 2023-06-16T00:40:42.482732 creation_library: ArviZ creation_library_version: 0.1 creation_library_language: Python inference_library: emcee inference_library_version: 3.1.4
posterior- arg_0_dim_0: 8
- arg_1_dim_0: 8
- arg_0_dim_0(arg_0_dim_0)int640 1 2 3 4 5 6 7
array([0, 1, 2, 3, 4, 5, 6, 7])
- arg_1_dim_0(arg_1_dim_0)int640 1 2 3 4 5 6 7
array([0, 1, 2, 3, 4, 5, 6, 7])
- arg_0(arg_0_dim_0)float6428.0 8.0 -3.0 7.0 ... 1.0 18.0 12.0
array([28., 8., -3., 7., -1., 1., 18., 12.])
- arg_1(arg_1_dim_0)float6415.0 10.0 16.0 ... 11.0 10.0 18.0
array([15., 10., 16., 11., 9., 11., 10., 18.])
- created_at :
- 2023-06-16T00:40:42.481120
- creation_library :
- ArviZ
- creation_library_version :
- 0.1
- creation_library_language :
- Python
- inference_library :
- emcee
- inference_library_version :
- 3.1.4
<xarray.DatasetView> Dimensions: (arg_0_dim_0: 8, arg_1_dim_0: 8) Coordinates: * arg_0_dim_0 (arg_0_dim_0) int64 0 1 2 3 4 5 6 7 * arg_1_dim_0 (arg_1_dim_0) int64 0 1 2 3 4 5 6 7 Data variables: arg_0 (arg_0_dim_0) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0 arg_1 (arg_1_dim_0) float64 15.0 10.0 16.0 11.0 9.0 11.0 10.0 18.0 Attributes: created_at: 2023-06-16T00:40:42.481120 creation_library: ArviZ creation_library_version: 0.1 creation_library_language: Python inference_library: emcee inference_library_version: 3.1.4
observed_data- draw: 1500
- chain: 40
- draw(draw)int640 1 2 3 4 ... 1496 1497 1498 1499
array([ 0, 1, 2, ..., 1497, 1498, 1499])
- chain(chain)int640 1 2 3 4 5 6 ... 34 35 36 37 38 39
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39])
- lp(draw, chain)float64-23.76 -18.08 ... -14.46 -15.33
array([[-23.75656766, -18.08303488, -24.62686788, ..., -20.89346634, -21.40430919, -18.31719405], [-18.27861331, -18.08303488, -19.58450327, ..., -20.89346634, -22.07702707, -18.31719405], [-18.27861331, -18.08303488, -19.58450327, ..., -20.89346634, -22.07702707, -18.31719405], ..., [-11.16721517, -14.54501244, -12.98669067, ..., -10.00527149, -14.46337015, -13.65646064], [-11.16721517, -14.54501244, -12.98669067, ..., -11.52177615, -14.46337015, -15.32749452], [-11.16721517, -14.54501244, -12.98669067, ..., -11.52177615, -14.46337015, -15.32749452]])
- created_at :
- 2023-06-16T00:40:42.479558
- creation_library :
- ArviZ
- creation_library_version :
- 0.1
- creation_library_language :
- Python
- inference_library :
- emcee
- inference_library_version :
- 3.1.4
<xarray.DatasetView> Dimensions: (draw: 1500, chain: 40) Coordinates: * draw (draw) int64 0 1 2 3 4 5 6 7 ... 1493 1494 1495 1496 1497 1498 1499 * chain (chain) int64 0 1 2 3 4 5 6 7 8 9 ... 30 31 32 33 34 35 36 37 38 39 Data variables: lp (draw, chain) float64 -23.76 -18.08 -24.63 ... -11.52 -14.46 -15.33 Attributes: created_at: 2023-06-16T00:40:42.479558 creation_library: ArviZ creation_library_version: 0.1 creation_library_language: Python inference_library: emcee inference_library_version: 3.1.4
sample_stats
After checking the default variable names, the trace of one dimension of eta can be plotted using ArviZ syntax:
#az.plot_trace(idata2, var_names=["var_2"], coords={"var_2_dim_0": 4});
blobs
: unlock sample stats, posterior predictive and miscellanea#
Emcee does not store per-draw sample stats, however, it has a functionality called blobs that allows to store any variable on a per-draw basis. It can be used to store some sample_stats or even posterior_predictive data.
You can modify the probability function to use this blobs
functionality and store the pointwise log likelihood,
then rerun the sampler using the new function:
def lnprob_8school_blobs(theta, y, s):
prior = log_prior_8school(theta)
like_vect = log_likelihood_8school(theta, y, s)
like = np.sum(like_vect)
return like + prior, like_vect
sampler_blobs = emcee.EnsembleSampler(
nwalkers,
ndim,
lnprob_8school_blobs,
args=(y_obs, sigma),
)
sampler_blobs.run_mcmc(pos, draws);
You can now use the blob_names
argument to indicate how to store this blob-defined variable. As the group is not specified, it will go to sample_stats.
Note that the argument blob_names is added to the arguments covered in the previous examples and we are also introducing coords
and dims
arguments to show the power and flexibility of the converter. For more on coords
and dims
see page_in_construction
.
dims = {"eta": ["school"], "log_likelihood": ["school"]}
idata3 = az.from_emcee(
sampler_blobs,
var_names=["mu", "tau", "eta"],
slices=[0, 1, slice(2, None)],
blob_names=["y"],
dims=dims,
coords={"school": range(8)},
)
idata3
<xarray.DatasetView> Dimensions: () Data variables: *empty*
- draw: 1500
- chain: 40
- school: 8
- draw(draw)int640 1 2 3 4 ... 1496 1497 1498 1499
array([ 0, 1, 2, ..., 1497, 1498, 1499])
- chain(chain)int640 1 2 3 4 5 6 ... 34 35 36 37 38 39
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39])
- school(school)int640 1 2 3 4 5 6 7
array([0, 1, 2, 3, 4, 5, 6, 7])
- mu(draw, chain)float640.04892 0.4058 ... 6.357 6.915
array([[ 0.04892431, 0.40584996, -0.49545336, ..., -2.10983344, -1.03126245, 0.23335853], [ 0.54321117, 0.40584996, -0.12629448, ..., -2.10983344, -1.04969914, 0.23335853], [ 0.54321117, 0.40584996, -0.12629448, ..., -2.10983344, -1.04969914, 0.23335853], ..., [ 3.43496427, 2.98598133, 7.99173958, ..., 3.38840016, 6.35686609, 7.1535426 ], [ 3.43496427, 2.98598133, 7.99173958, ..., 0.92003529, 6.35686609, 6.91497731], [ 3.43496427, 2.98598133, 7.99173958, ..., 0.92003529, 6.35686609, 6.91497731]])
- tau(draw, chain)float642.306 0.1492 2.226 ... 15.16 0.8366
array([[ 2.30584775, 0.14918842, 2.22637404, ..., 0.44663061, 0.3656756 , 0.6971695 ], [ 1.56412863, 0.14918842, 1.60096208, ..., 0.44663061, 0.36917635, 0.6971695 ], [ 1.56412863, 0.14918842, 1.60096208, ..., 0.44663061, 0.36917635, 0.6971695 ], ..., [15.28055298, 8.27559843, 14.76199644, ..., 18.77776866, 15.16116783, 3.92189291], [15.28055298, 8.27559843, 14.76199644, ..., 24.55433169, 15.16116783, 0.83661283], [15.28055298, 8.27559843, 14.76199644, ..., 24.55433169, 15.16116783, 0.83661283]])
- eta(draw, chain, school)float64-0.6786 -0.1647 ... 0.1229 -0.3025
array([[[-0.67857628, -0.16468065, -0.57879642, ..., 0.37463455, 0.2838023 , -1.68872586], [ 1.5456469 , 0.0127109 , -0.6776883 , ..., 0.11426478, -0.6525665 , 0.58564227], [-0.2988126 , 1.82189573, -0.99582489, ..., 1.87001865, 0.29470158, 0.58635191], ..., [-0.31146215, -0.94215065, 1.09926195, ..., 0.96390528, 0.46763536, -0.2291999 ], [-0.58921979, -0.427894 , 0.53215862, ..., -0.28103287, 0.17895492, -0.98494518], [-0.1978413 , 0.25562326, -0.30163309, ..., -1.35825061, 0.53966469, -0.7510712 ]], [[-0.41869793, 0.18647111, -0.77652266, ..., 0.35675372, 0.01150702, -0.73387141], [ 1.5456469 , 0.0127109 , -0.6776883 , ..., 0.11426478, -0.6525665 , 0.58564227], [-0.43300786, 1.36391804, -0.79823514, ..., 1.18707109, -0.08897712, 0.65329154], ... [ 0.77549432, 0.69970205, -0.08869062, ..., 0.00841656, 0.90433984, -0.02956931], [ 0.86329705, -0.69131298, -0.14277396, ..., 0.32852098, 0.65230976, -0.96380601], [ 0.28569035, -0.34752045, -0.12007159, ..., -0.22146929, 0.12292282, -0.30248027]], [[ 0.31934935, 0.08043603, -0.76803917, ..., -0.27236932, 0.37489382, 0.52550953], [ 1.3757245 , 0.72003818, -0.09366257, ..., -0.64094235, -0.22304128, 0.38359654], [ 0.7989722 , 0.19739264, -1.3474608 , ..., -0.6463856 , 0.89148754, 0.38363129], ..., [ 0.77549432, 0.69970205, -0.08869062, ..., 0.00841656, 0.90433984, -0.02956931], [ 0.86329705, -0.69131298, -0.14277396, ..., 0.32852098, 0.65230976, -0.96380601], [ 0.28569035, -0.34752045, -0.12007159, ..., -0.22146929, 0.12292282, -0.30248027]]])
- created_at :
- 2023-06-16T00:40:44.547558
- creation_library :
- ArviZ
- creation_library_version :
- 0.1
- creation_library_language :
- Python
- inference_library :
- emcee
- inference_library_version :
- 3.1.4
<xarray.DatasetView> Dimensions: (draw: 1500, chain: 40, school: 8) Coordinates: * draw (draw) int64 0 1 2 3 4 5 6 7 ... 1493 1494 1495 1496 1497 1498 1499 * chain (chain) int64 0 1 2 3 4 5 6 7 8 9 ... 30 31 32 33 34 35 36 37 38 39 * school (school) int64 0 1 2 3 4 5 6 7 Data variables: mu (draw, chain) float64 0.04892 0.4058 -0.4955 ... 0.92 6.357 6.915 tau (draw, chain) float64 2.306 0.1492 2.226 ... 24.55 15.16 0.8366 eta (draw, chain, school) float64 -0.6786 -0.1647 ... 0.1229 -0.3025 Attributes: created_at: 2023-06-16T00:40:44.547558 creation_library: ArviZ creation_library_version: 0.1 creation_library_language: Python inference_library: emcee inference_library_version: 3.1.4
posterior- arg_0_dim_0: 8
- arg_1_dim_0: 8
- arg_0_dim_0(arg_0_dim_0)int640 1 2 3 4 5 6 7
array([0, 1, 2, 3, 4, 5, 6, 7])
- arg_1_dim_0(arg_1_dim_0)int640 1 2 3 4 5 6 7
array([0, 1, 2, 3, 4, 5, 6, 7])
- arg_0(arg_0_dim_0)float6428.0 8.0 -3.0 7.0 ... 1.0 18.0 12.0
array([28., 8., -3., 7., -1., 1., 18., 12.])
- arg_1(arg_1_dim_0)float6415.0 10.0 16.0 ... 11.0 10.0 18.0
array([15., 10., 16., 11., 9., 11., 10., 18.])
- created_at :
- 2023-06-16T00:40:44.545948
- creation_library :
- ArviZ
- creation_library_version :
- 0.1
- creation_library_language :
- Python
- inference_library :
- emcee
- inference_library_version :
- 3.1.4
<xarray.DatasetView> Dimensions: (arg_0_dim_0: 8, arg_1_dim_0: 8) Coordinates: * arg_0_dim_0 (arg_0_dim_0) int64 0 1 2 3 4 5 6 7 * arg_1_dim_0 (arg_1_dim_0) int64 0 1 2 3 4 5 6 7 Data variables: arg_0 (arg_0_dim_0) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0 arg_1 (arg_1_dim_0) float64 15.0 10.0 16.0 11.0 9.0 11.0 10.0 18.0 Attributes: created_at: 2023-06-16T00:40:44.545948 creation_library: ArviZ creation_library_version: 0.1 creation_library_language: Python inference_library: emcee inference_library_version: 3.1.4
observed_data- draw: 1500
- chain: 40
- y_dim_0: 8
- draw(draw)int640 1 2 3 4 ... 1496 1497 1498 1499
array([ 0, 1, 2, ..., 1497, 1498, 1499])
- chain(chain)int640 1 2 3 4 5 6 ... 34 35 36 37 38 39
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39])
- y_dim_0(y_dim_0)int640 1 2 3 4 5 6 7
array([0, 1, 2, 3, 4, 5, 6, 7])
- y(draw, chain, y_dim_0)float64-3.872 -0.694 ... -1.206 -0.08795
array([[[-3.87191394e+00, -6.94022986e-01, -1.14798889e-02, ..., -6.28783546e-05, -2.99174820e+00, -7.74890964e-01], [-3.32784122e+00, -5.76423166e-01, -4.26615268e-02, ..., -2.75246231e-03, -3.12989363e+00, -4.08660379e-01], [-3.77932314e+00, -1.97067808e-01, -3.22804007e-04, ..., -5.88242234e-02, -3.18241959e+00, -3.86470459e-01], ..., [-4.06665996e+00, -1.10894100e+00, -7.45125668e-03, ..., -5.93287294e-02, -3.96048733e+00, -6.23415620e-01], [-3.80164874e+00, -8.44144349e-01, -1.82813211e-02, ..., -3.76370341e-02, -3.59702442e+00, -5.53489119e-01], [-3.46073355e+00, -5.75842505e-01, -3.56990119e-02, ..., -2.42671922e-02, -3.02426141e+00, -4.66205628e-01]], [[-3.51229735e+00, -5.13390023e-01, -2.11817067e-02, ..., -8.46732416e-05, -3.04111409e+00, -4.90362364e-01], [-3.32784122e+00, -5.76423166e-01, -4.26615268e-02, ..., -2.75246231e-03, -3.12989363e+00, -4.08660379e-01], [-3.69139975e+00, -3.53158428e-01, -9.94708674e-03, ..., -4.95310535e-03, -3.33746988e+00, -3.78935966e-01], ... [-2.87168808e-01, -1.02025182e+00, -1.18577993e-02, ..., -1.32665095e-04, -2.62707062e-01, -4.30191650e-01], [-3.25245315e-01, -1.46997341e+00, -2.02064081e-01, ..., -8.83194611e-01, -3.07425783e-02, -1.26631992e+00], [-1.93136068e+00, -1.89272317e-02, -3.76269058e-01, ..., -2.71317230e-01, -1.20608363e+00, -8.79478853e-02]], [[-1.72225395e+00, -1.11284203e-01, -1.09772073e-01, ..., -2.46487004e-02, -7.80828632e-01, -8.83276560e-04], [-8.25563060e-01, -8.92511272e-03, -1.06066952e-01, ..., -9.09954745e-02, -2.84253489e+00, -1.05247174e-01], [-2.99853755e-01, -8.44279656e-02, -3.09377351e-01, ..., -5.37481955e-02, -9.93431892e-02, -8.45279306e-03], ..., [-2.87168808e-01, -1.02025182e+00, -1.18577993e-02, ..., -1.32665095e-04, -2.62707062e-01, -4.30191650e-01], [-3.25245315e-01, -1.46997341e+00, -2.02064081e-01, ..., -8.83194611e-01, -3.07425783e-02, -1.26631992e+00], [-1.93136068e+00, -1.89272317e-02, -3.76269058e-01, ..., -2.71317230e-01, -1.20608363e+00, -8.79478853e-02]]])
- created_at :
- 2023-06-16T00:40:44.542903
- creation_library :
- ArviZ
- creation_library_version :
- 0.1
- creation_library_language :
- Python
- inference_library :
- emcee
- inference_library_version :
- 3.1.4
<xarray.DatasetView> Dimensions: (draw: 1500, chain: 40, y_dim_0: 8) Coordinates: * draw (draw) int64 0 1 2 3 4 5 6 7 ... 1493 1494 1495 1496 1497 1498 1499 * chain (chain) int64 0 1 2 3 4 5 6 7 8 9 ... 30 31 32 33 34 35 36 37 38 39 * y_dim_0 (y_dim_0) int64 0 1 2 3 4 5 6 7 Data variables: y (draw, chain, y_dim_0) float64 -3.872 -0.694 ... -1.206 -0.08795 Attributes: created_at: 2023-06-16T00:40:44.542903 creation_library: ArviZ creation_library_version: 0.1 creation_library_language: Python inference_library: emcee inference_library_version: 3.1.4
log_likelihood- draw: 1500
- chain: 40
- draw(draw)int640 1 2 3 4 ... 1496 1497 1498 1499
array([ 0, 1, 2, ..., 1497, 1498, 1499])
- chain(chain)int640 1 2 3 4 5 6 ... 34 35 36 37 38 39
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39])
- lp(draw, chain)float64-23.76 -18.08 ... -14.46 -15.33
array([[-23.75656766, -18.08303488, -24.62686788, ..., -20.89346634, -21.40430919, -18.31719405], [-18.27861331, -18.08303488, -19.58450327, ..., -20.89346634, -22.07702707, -18.31719405], [-18.27861331, -18.08303488, -19.58450327, ..., -20.89346634, -22.07702707, -18.31719405], ..., [-11.16721517, -14.54501244, -12.98669067, ..., -10.00527149, -14.46337015, -13.65646064], [-11.16721517, -14.54501244, -12.98669067, ..., -11.52177615, -14.46337015, -15.32749452], [-11.16721517, -14.54501244, -12.98669067, ..., -11.52177615, -14.46337015, -15.32749452]])
- created_at :
- 2023-06-16T00:40:44.545000
- creation_library :
- ArviZ
- creation_library_version :
- 0.1
- creation_library_language :
- Python
- inference_library :
- emcee
- inference_library_version :
- 3.1.4
<xarray.DatasetView> Dimensions: (draw: 1500, chain: 40) Coordinates: * draw (draw) int64 0 1 2 3 4 5 6 7 ... 1493 1494 1495 1496 1497 1498 1499 * chain (chain) int64 0 1 2 3 4 5 6 7 8 9 ... 30 31 32 33 34 35 36 37 38 39 Data variables: lp (draw, chain) float64 -23.76 -18.08 -24.63 ... -11.52 -14.46 -15.33 Attributes: created_at: 2023-06-16T00:40:44.545000 creation_library: ArviZ creation_library_version: 0.1 creation_library_language: Python inference_library: emcee inference_library_version: 3.1.4
sample_stats
Multi-group blobs#
You might even have more complicated blobs, each corresponding to a different group of the InferenceData object. Moreover, you can store the variables passed to the EnsembleSampler
via the args
argument in observed or constant data groups. This is shown in the example below:
sampler_blobs.blobs[0, 1]
array([-3.32784122e+00, -5.76423166e-01, -4.26615268e-02, -3.53483214e-01,
-2.50328863e-02, -2.75246231e-03, -3.12989363e+00, -4.08660379e-01])
def lnprob_8school_blobs(theta, y, sigma):
mu, tau, eta = theta[0], theta[1], theta[2:]
prior = log_prior_8school(theta)
like_vect = log_likelihood_8school(theta, y, sigma)
like = np.sum(like_vect)
# store pointwise log likelihood, useful for model comparison with az.loo or az.waic
# and posterior predictive samples as blobs
return like + prior, (like_vect, np.random.normal((mu + tau * eta), sigma))
sampler_blobs = emcee.EnsembleSampler(
nwalkers,
ndim,
lnprob_8school_blobs,
args=(y_obs, sigma),
)
sampler_blobs.run_mcmc(pos, draws)
dims = {"eta": ["school"], "log_likelihood": ["school"], "y": ["school"]}
idata4 = az.from_emcee(
sampler_blobs,
var_names=["mu", "tau", "eta"],
slices=[0, 1, slice(2, None)],
arg_names=["y", "sigma"],
arg_groups=["observed_data", "constant_data"],
blob_names=["y", "y"],
blob_groups=["log_likelihood", "posterior_predictive"],
dims=dims,
coords={"school": range(8)},
)
idata4
<xarray.DatasetView> Dimensions: () Data variables: *empty*
- draw: 1500
- chain: 40
- school: 8
- draw(draw)int640 1 2 3 4 ... 1496 1497 1498 1499
array([ 0, 1, 2, ..., 1497, 1498, 1499])
- chain(chain)int640 1 2 3 4 5 6 ... 34 35 36 37 38 39
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39])
- school(school)int640 1 2 3 4 5 6 7
array([0, 1, 2, 3, 4, 5, 6, 7])
- mu(draw, chain)float640.04892 0.4058 ... 6.357 6.915
array([[ 0.04892431, 0.40584996, -0.49545336, ..., -2.10983344, -1.03126245, 0.23335853], [ 0.54321117, 0.40584996, -0.12629448, ..., -2.10983344, -1.04969914, 0.23335853], [ 0.54321117, 0.40584996, -0.12629448, ..., -2.10983344, -1.04969914, 0.23335853], ..., [ 3.43496427, 2.98598133, 7.99173958, ..., 3.38840016, 6.35686609, 7.1535426 ], [ 3.43496427, 2.98598133, 7.99173958, ..., 0.92003529, 6.35686609, 6.91497731], [ 3.43496427, 2.98598133, 7.99173958, ..., 0.92003529, 6.35686609, 6.91497731]])
- tau(draw, chain)float642.306 0.1492 2.226 ... 15.16 0.8366
array([[ 2.30584775, 0.14918842, 2.22637404, ..., 0.44663061, 0.3656756 , 0.6971695 ], [ 1.56412863, 0.14918842, 1.60096208, ..., 0.44663061, 0.36917635, 0.6971695 ], [ 1.56412863, 0.14918842, 1.60096208, ..., 0.44663061, 0.36917635, 0.6971695 ], ..., [15.28055298, 8.27559843, 14.76199644, ..., 18.77776866, 15.16116783, 3.92189291], [15.28055298, 8.27559843, 14.76199644, ..., 24.55433169, 15.16116783, 0.83661283], [15.28055298, 8.27559843, 14.76199644, ..., 24.55433169, 15.16116783, 0.83661283]])
- eta(draw, chain, school)float64-0.6786 -0.1647 ... 0.1229 -0.3025
array([[[-0.67857628, -0.16468065, -0.57879642, ..., 0.37463455, 0.2838023 , -1.68872586], [ 1.5456469 , 0.0127109 , -0.6776883 , ..., 0.11426478, -0.6525665 , 0.58564227], [-0.2988126 , 1.82189573, -0.99582489, ..., 1.87001865, 0.29470158, 0.58635191], ..., [-0.31146215, -0.94215065, 1.09926195, ..., 0.96390528, 0.46763536, -0.2291999 ], [-0.58921979, -0.427894 , 0.53215862, ..., -0.28103287, 0.17895492, -0.98494518], [-0.1978413 , 0.25562326, -0.30163309, ..., -1.35825061, 0.53966469, -0.7510712 ]], [[-0.41869793, 0.18647111, -0.77652266, ..., 0.35675372, 0.01150702, -0.73387141], [ 1.5456469 , 0.0127109 , -0.6776883 , ..., 0.11426478, -0.6525665 , 0.58564227], [-0.43300786, 1.36391804, -0.79823514, ..., 1.18707109, -0.08897712, 0.65329154], ... [ 0.77549432, 0.69970205, -0.08869062, ..., 0.00841656, 0.90433984, -0.02956931], [ 0.86329705, -0.69131298, -0.14277396, ..., 0.32852098, 0.65230976, -0.96380601], [ 0.28569035, -0.34752045, -0.12007159, ..., -0.22146929, 0.12292282, -0.30248027]], [[ 0.31934935, 0.08043603, -0.76803917, ..., -0.27236932, 0.37489382, 0.52550953], [ 1.3757245 , 0.72003818, -0.09366257, ..., -0.64094235, -0.22304128, 0.38359654], [ 0.7989722 , 0.19739264, -1.3474608 , ..., -0.6463856 , 0.89148754, 0.38363129], ..., [ 0.77549432, 0.69970205, -0.08869062, ..., 0.00841656, 0.90433984, -0.02956931], [ 0.86329705, -0.69131298, -0.14277396, ..., 0.32852098, 0.65230976, -0.96380601], [ 0.28569035, -0.34752045, -0.12007159, ..., -0.22146929, 0.12292282, -0.30248027]]])
- created_at :
- 2023-06-16T00:40:47.329101
- creation_library :
- ArviZ
- creation_library_version :
- 0.1
- creation_library_language :
- Python
- inference_library :
- emcee
- inference_library_version :
- 3.1.4
<xarray.DatasetView> Dimensions: (draw: 1500, chain: 40, school: 8) Coordinates: * draw (draw) int64 0 1 2 3 4 5 6 7 ... 1493 1494 1495 1496 1497 1498 1499 * chain (chain) int64 0 1 2 3 4 5 6 7 8 9 ... 30 31 32 33 34 35 36 37 38 39 * school (school) int64 0 1 2 3 4 5 6 7 Data variables: mu (draw, chain) float64 0.04892 0.4058 -0.4955 ... 0.92 6.357 6.915 tau (draw, chain) float64 2.306 0.1492 2.226 ... 24.55 15.16 0.8366 eta (draw, chain, school) float64 -0.6786 -0.1647 ... 0.1229 -0.3025 Attributes: created_at: 2023-06-16T00:40:47.329101 creation_library: ArviZ creation_library_version: 0.1 creation_library_language: Python inference_library: emcee inference_library_version: 3.1.4
posterior- school: 8
- school(school)int640 1 2 3 4 5 6 7
array([0, 1, 2, 3, 4, 5, 6, 7])
- y(school)float6428.0 8.0 -3.0 7.0 ... 1.0 18.0 12.0
array([28., 8., -3., 7., -1., 1., 18., 12.])
- created_at :
- 2023-06-16T00:40:47.327326
- creation_library :
- ArviZ
- creation_library_version :
- 0.1
- creation_library_language :
- Python
- inference_library :
- emcee
- inference_library_version :
- 3.1.4
<xarray.DatasetView> Dimensions: (school: 8) Coordinates: * school (school) int64 0 1 2 3 4 5 6 7 Data variables: y (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0 Attributes: created_at: 2023-06-16T00:40:47.327326 creation_library: ArviZ creation_library_version: 0.1 creation_library_language: Python inference_library: emcee inference_library_version: 3.1.4
observed_data- sigma_dim_0: 8
- sigma_dim_0(sigma_dim_0)int640 1 2 3 4 5 6 7
array([0, 1, 2, 3, 4, 5, 6, 7])
- sigma(sigma_dim_0)float6415.0 10.0 16.0 ... 11.0 10.0 18.0
array([15., 10., 16., 11., 9., 11., 10., 18.])
- created_at :
- 2023-06-16T00:40:47.327868
- creation_library :
- ArviZ
- creation_library_version :
- 0.1
- creation_library_language :
- Python
- inference_library :
- emcee
- inference_library_version :
- 3.1.4
<xarray.DatasetView> Dimensions: (sigma_dim_0: 8) Coordinates: * sigma_dim_0 (sigma_dim_0) int64 0 1 2 3 4 5 6 7 Data variables: sigma (sigma_dim_0) float64 15.0 10.0 16.0 11.0 9.0 11.0 10.0 18.0 Attributes: created_at: 2023-06-16T00:40:47.327868 creation_library: ArviZ creation_library_version: 0.1 creation_library_language: Python inference_library: emcee inference_library_version: 3.1.4
constant_data- draw: 1500
- chain: 40
- school: 8
- draw(draw)int640 1 2 3 4 ... 1496 1497 1498 1499
array([ 0, 1, 2, ..., 1497, 1498, 1499])
- chain(chain)int640 1 2 3 4 5 6 ... 34 35 36 37 38 39
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39])
- school(school)int640 1 2 3 4 5 6 7
array([0, 1, 2, 3, 4, 5, 6, 7])
- y(draw, chain, school)float64-3.872 -0.694 ... -1.206 -0.08795
array([[[-3.87191394e+00, -6.94022986e-01, -1.14798889e-02, ..., -6.28783546e-05, -2.99174820e+00, -7.74890964e-01], [-3.32784122e+00, -5.76423166e-01, -4.26615268e-02, ..., -2.75246231e-03, -3.12989363e+00, -4.08660379e-01], [-3.77932314e+00, -1.97067808e-01, -3.22804007e-04, ..., -5.88242234e-02, -3.18241959e+00, -3.86470459e-01], ..., [-4.06665996e+00, -1.10894100e+00, -7.45125668e-03, ..., -5.93287294e-02, -3.96048733e+00, -6.23415620e-01], [-3.80164874e+00, -8.44144349e-01, -1.82813211e-02, ..., -3.76370341e-02, -3.59702442e+00, -5.53489119e-01], [-3.46073355e+00, -5.75842505e-01, -3.56990119e-02, ..., -2.42671922e-02, -3.02426141e+00, -4.66205628e-01]], [[-3.51229735e+00, -5.13390023e-01, -2.11817067e-02, ..., -8.46732416e-05, -3.04111409e+00, -4.90362364e-01], [-3.32784122e+00, -5.76423166e-01, -4.26615268e-02, ..., -2.75246231e-03, -3.12989363e+00, -4.08660379e-01], [-3.69139975e+00, -3.53158428e-01, -9.94708674e-03, ..., -4.95310535e-03, -3.33746988e+00, -3.78935966e-01], ... [-2.87168808e-01, -1.02025182e+00, -1.18577993e-02, ..., -1.32665095e-04, -2.62707062e-01, -4.30191650e-01], [-3.25245315e-01, -1.46997341e+00, -2.02064081e-01, ..., -8.83194611e-01, -3.07425783e-02, -1.26631992e+00], [-1.93136068e+00, -1.89272317e-02, -3.76269058e-01, ..., -2.71317230e-01, -1.20608363e+00, -8.79478853e-02]], [[-1.72225395e+00, -1.11284203e-01, -1.09772073e-01, ..., -2.46487004e-02, -7.80828632e-01, -8.83276560e-04], [-8.25563060e-01, -8.92511272e-03, -1.06066952e-01, ..., -9.09954745e-02, -2.84253489e+00, -1.05247174e-01], [-2.99853755e-01, -8.44279656e-02, -3.09377351e-01, ..., -5.37481955e-02, -9.93431892e-02, -8.45279306e-03], ..., [-2.87168808e-01, -1.02025182e+00, -1.18577993e-02, ..., -1.32665095e-04, -2.62707062e-01, -4.30191650e-01], [-3.25245315e-01, -1.46997341e+00, -2.02064081e-01, ..., -8.83194611e-01, -3.07425783e-02, -1.26631992e+00], [-1.93136068e+00, -1.89272317e-02, -3.76269058e-01, ..., -2.71317230e-01, -1.20608363e+00, -8.79478853e-02]]])
- created_at :
- 2023-06-16T00:40:47.325055
- creation_library :
- ArviZ
- creation_library_version :
- 0.1
- creation_library_language :
- Python
- inference_library :
- emcee
- inference_library_version :
- 3.1.4
<xarray.DatasetView> Dimensions: (draw: 1500, chain: 40, school: 8) Coordinates: * draw (draw) int64 0 1 2 3 4 5 6 7 ... 1493 1494 1495 1496 1497 1498 1499 * chain (chain) int64 0 1 2 3 4 5 6 7 8 9 ... 30 31 32 33 34 35 36 37 38 39 * school (school) int64 0 1 2 3 4 5 6 7 Data variables: y (draw, chain, school) float64 -3.872 -0.694 ... -1.206 -0.08795 Attributes: created_at: 2023-06-16T00:40:47.325055 creation_library: ArviZ creation_library_version: 0.1 creation_library_language: Python inference_library: emcee inference_library_version: 3.1.4
log_likelihood- draw: 1500
- chain: 40
- school: 8
- draw(draw)int640 1 2 3 4 ... 1496 1497 1498 1499
array([ 0, 1, 2, ..., 1497, 1498, 1499])
- chain(chain)int640 1 2 3 4 5 6 ... 34 35 36 37 38 39
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39])
- school(school)int640 1 2 3 4 5 6 7
array([0, 1, 2, 3, 4, 5, 6, 7])
- y(draw, chain, school)float643.644 0.9363 ... 8.866 -9.154
array([[[ 3.64444301, 0.93627746, -27.95471325, ..., -9.99784581, 0.28896344, -30.26726699], [ -0.25593492, 13.81492871, -14.52405367, ..., 27.3201404 , -9.94851736, 2.06695807], [-22.49810996, 22.25987303, -16.79148337, ..., -3.83669991, 1.46500625, -60.06139352], ..., [ -3.56507768, -5.39260169, -4.05555162, ..., 18.9799484 , 0.62889902, 3.72569114], [ 9.67923188, -4.50928018, 6.02030581, ..., -31.55775182, -13.97103863, -12.70981608], [ 5.72884071, 0.39934476, -15.00999849, ..., -22.78750573, -1.28445009, -20.43985564]], [[ 0.56384831, -3.7538796 , 7.70583145, ..., -2.76801086, -8.61628088, -38.56446778], [ -0.25593492, 13.81492871, -14.52405367, ..., 27.3201404 , -9.94851736, 2.06695807], [-12.42893102, 1.56476609, -4.6207624 , ..., 0.81839767, 9.27141857, -2.48550082], ... [ -7.71386549, 12.44015448, -5.10372343, ..., 22.58734655, 33.07260743, -21.86446195], [ 2.26692903, 18.79463854, 10.55069279, ..., 20.09157967, 29.43386598, -28.99272247], [ 22.75751782, 16.69591651, -0.70045719, ..., 23.25471255, 8.86645123, -9.15383678]], [[ 10.34485489, 2.93430042, -2.56071238, ..., 7.86815171, -1.42835234, 31.58535553], [ -5.83262901, 27.77231996, -20.65292971, ..., 15.88132336, 4.14621985, 29.62166945], [ 32.4162994 , -7.33176216, -28.00357432, ..., 14.25747663, 25.71469929, 25.36272718], ..., [ -7.71386549, 12.44015448, -5.10372343, ..., 22.58734655, 33.07260743, -21.86446195], [ 2.26692903, 18.79463854, 10.55069279, ..., 20.09157967, 29.43386598, -28.99272247], [ 22.75751782, 16.69591651, -0.70045719, ..., 23.25471255, 8.86645123, -9.15383678]]])
- created_at :
- 2023-06-16T00:40:47.326020
- creation_library :
- ArviZ
- creation_library_version :
- 0.1
- creation_library_language :
- Python
- inference_library :
- emcee
- inference_library_version :
- 3.1.4
<xarray.DatasetView> Dimensions: (draw: 1500, chain: 40, school: 8) Coordinates: * draw (draw) int64 0 1 2 3 4 5 6 7 ... 1493 1494 1495 1496 1497 1498 1499 * chain (chain) int64 0 1 2 3 4 5 6 7 8 9 ... 30 31 32 33 34 35 36 37 38 39 * school (school) int64 0 1 2 3 4 5 6 7 Data variables: y (draw, chain, school) float64 3.644 0.9363 -27.95 ... 8.866 -9.154 Attributes: created_at: 2023-06-16T00:40:47.326020 creation_library: ArviZ creation_library_version: 0.1 creation_library_language: Python inference_library: emcee inference_library_version: 3.1.4
posterior_predictive- draw: 1500
- chain: 40
- draw(draw)int640 1 2 3 4 ... 1496 1497 1498 1499
array([ 0, 1, 2, ..., 1497, 1498, 1499])
- chain(chain)int640 1 2 3 4 5 6 ... 34 35 36 37 38 39
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39])
- lp(draw, chain)float64-23.76 -18.08 ... -14.46 -15.33
array([[-23.75656766, -18.08303488, -24.62686788, ..., -20.89346634, -21.40430919, -18.31719405], [-18.27861331, -18.08303488, -19.58450327, ..., -20.89346634, -22.07702707, -18.31719405], [-18.27861331, -18.08303488, -19.58450327, ..., -20.89346634, -22.07702707, -18.31719405], ..., [-11.16721517, -14.54501244, -12.98669067, ..., -10.00527149, -14.46337015, -13.65646064], [-11.16721517, -14.54501244, -12.98669067, ..., -11.52177615, -14.46337015, -15.32749452], [-11.16721517, -14.54501244, -12.98669067, ..., -11.52177615, -14.46337015, -15.32749452]])
- created_at :
- 2023-06-16T00:40:47.326738
- creation_library :
- ArviZ
- creation_library_version :
- 0.1
- creation_library_language :
- Python
- inference_library :
- emcee
- inference_library_version :
- 3.1.4
<xarray.DatasetView> Dimensions: (draw: 1500, chain: 40) Coordinates: * draw (draw) int64 0 1 2 3 4 5 6 7 ... 1493 1494 1495 1496 1497 1498 1499 * chain (chain) int64 0 1 2 3 4 5 6 7 8 9 ... 30 31 32 33 34 35 36 37 38 39 Data variables: lp (draw, chain) float64 -23.76 -18.08 -24.63 ... -11.52 -14.46 -15.33 Attributes: created_at: 2023-06-16T00:40:47.326738 creation_library: ArviZ creation_library_version: 0.1 creation_library_language: Python inference_library: emcee inference_library_version: 3.1.4
sample_stats
This last version, which contains both observed data and posterior predictive can be used to plot posterior predictive checks:
#az.plot_ppc(idata4, var_names=["y"], alpha=0.3, num_pp_samples=200);
%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Fri Jun 16 2023
Python implementation: CPython
Python version : 3.10.11
IPython version : 8.14.0
arviz_base: 0.1
numpy : 1.24.3
emcee : 3.1.4
Watermark: 2.3.1