arviz.extract

arviz.extract(data, group='posterior', combined=True, var_names=None, filter_vars=None, num_samples=None, keep_dataset=False, rng=None)

Extract an InferenceData group or subset of it.

Parameters:
data : InferenceData or InferenceData_like

InferenceData from which to extract the data.

group : str, optional

Which InferenceData group to extract data from. Defaults to "posterior".

combined : bool, optional

Combine the chain and draw dimensions into a single sample dimension. Defaults to True. This fails if the group already contains a dimension named sample.

var_names : str or list of str, optional

Variables to extract. Prefix a variable name with ~ to exclude it instead.
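A minimal sketch of exclusion with ~, assuming a small synthetic posterior built with az.from_dict. The variable names mu, tau and theta, the "school" dimension and the random values are illustrative only.

```python
import numpy as np
import arviz as az

# Synthetic posterior (illustrative): 2 chains x 50 draws, theta has a "school" dim.
gen = np.random.default_rng(0)
idata = az.from_dict(
    posterior={
        "mu": gen.normal(size=(2, 50)),
        "tau": gen.gamma(2.0, size=(2, 50)),
        "theta": gen.normal(size=(2, 50, 3)),
    },
    dims={"theta": ["school"]},
)

# "~theta" keeps every posterior variable except theta.
no_theta = az.extract(idata, var_names=["~theta"])
print(sorted(no_theta.data_vars))  # ['mu', 'tau']
```

Because two variables remain, the result is a Dataset whose sample dimension has 2 x 50 = 100 entries.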

filter_vars : {None, "like", "regex"}, optional

If None (default), interpret var_names as the exact variable names. If "like", interpret var_names as substrings of the variable names. If "regex", interpret var_names as regular expressions matched against the variable names. This follows pandas.filter. As with the plotting functions, it is sometimes easier to subset by saying what to exclude rather than what to include.
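The sketch below illustrates the three filter_vars modes on a hypothetical posterior with variables alpha, alpha_raw and beta (names and values are made up for the example). It also assumes, as the description above suggests, that the ~ exclusion prefix can be combined with "like".

```python
import numpy as np
import arviz as az

# Synthetic posterior with illustrative variable names.
gen = np.random.default_rng(1)
idata = az.from_dict(
    posterior={
        "alpha": gen.normal(size=(2, 50)),
        "alpha_raw": gen.normal(size=(2, 50)),
        "beta": gen.normal(size=(2, 50)),
    }
)

# "like": "alpha" is a substring of both alpha and alpha_raw.
like = az.extract(idata, var_names="alpha", filter_vars="like")
print(sorted(like.data_vars))  # ['alpha', 'alpha_raw']

# "regex": "_raw$" only matches alpha_raw; a single match comes back as a DataArray.
regex = az.extract(idata, var_names="_raw$", filter_vars="regex")
print(type(regex).__name__, regex.name)  # DataArray alpha_raw

# Exclusion combined with "like": drop every variable whose name contains "raw".
no_raw = az.extract(idata, var_names="~raw", filter_vars="like")
print(sorted(no_raw.data_vars))  # ['alpha', 'beta']
```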

num_samples : int, optional

Extract only a subset of num_samples samples. Only valid if combined=True.

keep_dataset : bool, optional

If True, always return a Dataset. If False (default), return a DataArray when only a single variable is extracted.
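A short sketch of the effect of keep_dataset when a single variable is selected, assuming a synthetic posterior with illustrative variables mu and tau.

```python
import numpy as np
import arviz as az

# Synthetic posterior (illustrative): 2 chains x 50 draws.
gen = np.random.default_rng(2)
idata = az.from_dict(
    posterior={"mu": gen.normal(size=(2, 50)), "tau": gen.gamma(2.0, size=(2, 50))}
)

single = az.extract(idata, var_names="mu")  # one variable -> DataArray by default
wrapped = az.extract(idata, var_names="mu", keep_dataset=True)  # always a Dataset
print(type(single).__name__, type(wrapped).__name__)  # DataArray Dataset
```

Forcing a Dataset can be convenient when downstream code always expects .data_vars, regardless of how many variables were selected.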

rng : bool, int, numpy.Generator, optional

Shuffle the samples. Only valid if combined=True. By default, samples are shuffled if num_samples is not None and left in their original order otherwise. Shuffling ensures that a subset of the samples is not taken from consecutive draws of a single chain. Passing an int seed or a numpy.Generator makes the shuffle reproducible.

Returns:
xarray.DataArray or xarray.Dataset

Examples

The default behaviour is to return the posterior group after stacking the chain and draw dimensions.

import arviz as az
idata = az.load_arviz_data("centered_eight")
az.extract(idata)
<xarray.Dataset>
Dimensions:  (sample: 2000, school: 8)
Coordinates:
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) object MultiIndex
  * chain    (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3 3
  * draw     (sample) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
Data variables:
    mu       (sample) float64 7.872 3.385 9.1 7.304 ... 1.859 1.767 3.486 3.404
    theta    (school, sample) float64 12.32 11.29 5.709 ... -2.623 8.452 1.295
    tau      (sample) float64 4.726 3.909 4.844 1.857 ... 2.741 2.932 4.461
Attributes:
    created_at:                 2022-10-13T14:37:37.315398
    arviz_version:              0.13.0.dev0
    inference_library:          pymc
    inference_library_version:  4.2.2
    sampling_time:              7.480114936828613
    tuning_steps:               1000
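As we understand it, the default combined=True amounts to stacking chain and draw into a sample MultiIndex (chain as the outer level, draw as the inner one, matching the coordinates printed above), which xarray's Dataset.stack reproduces. The sketch checks this on a synthetic 4-chain, 500-draw posterior chosen to mirror the 2000 samples above; it does not use the centered_eight data.

```python
import numpy as np
import arviz as az

# Synthetic posterior (illustrative): 4 chains x 500 draws.
gen = np.random.default_rng(4)
idata = az.from_dict(
    posterior={"mu": gen.normal(size=(4, 500)), "tau": gen.gamma(2.0, size=(4, 500))}
)

stacked = az.extract(idata)
# Manual equivalent: stack chain (outer level) and draw (inner level) into "sample".
manual = idata.posterior.stack(sample=("chain", "draw"))

print(stacked.sizes["sample"])  # 2000, i.e. 4 chains x 500 draws
print(np.array_equal(stacked["mu"].values, manual["mu"].values))  # True
```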

You can also restrict the result to a subset of the variables and of the samples:

az.extract(idata, var_names="theta", num_samples=100)
<xarray.DataArray 'theta' (school: 8, sample: 100)>
array([[ 1.80641321e+00,  1.15003230e+01,  3.08531844e-02,
         4.78763947e+00,  1.67363399e+00, -1.45575080e+00,
         5.69355298e+00,  1.42338706e+01,  4.09880116e+00,
        -3.91057140e-01,  2.97984261e+00,  8.03717738e+00,
         2.12092258e+01,  8.26200915e+00,  3.25808233e+00,
         7.88942950e+00,  6.00429650e+00,  2.99332684e+00,
         9.12787329e+00,  1.48492425e+01,  1.02958492e+01,
         7.59329274e+00,  9.32455129e+00,  7.09929855e+00,
        -1.10356423e+01,  9.41495072e+00,  8.07411402e+00,
         3.49517632e+00,  9.29485423e+00,  9.56574403e+00,
        -1.62812501e-01,  1.09299330e+01,  5.38266189e+00,
         1.18297230e+01,  7.98906311e+00, -7.82667421e-01,
         1.51649688e+01,  6.32443017e+00,  5.61472796e+00,
         2.28772366e+00,  7.03974855e+00,  5.35480976e+00,
         1.09081595e+01,  7.39820257e+00,  8.96712480e+00,
         7.61099342e+00,  9.36990215e+00,  6.75293935e-01,
         2.17756186e+00,  6.75846366e+00,  1.45711219e+01,
        -1.45575080e+00,  5.74800450e+00,  1.26162675e+01,
         3.25808233e+00, -6.11834987e+00,  3.10856643e+00,
         1.79981973e+00,  1.19226808e+01,  2.22671520e+01,
...
         8.24713823e+00,  7.58553160e+00,  8.12276457e+00,
         7.88053428e+00,  9.31886541e+00,  6.01996324e-01,
         9.70130676e-01,  3.08674539e+00,  1.97055711e+00,
         6.61190575e-02,  7.44372537e+00,  5.49871226e+00,
         3.35132225e+00, -2.58684632e+00,  3.56038714e+00,
         5.59174676e+00,  4.01302519e+00,  6.33952795e+00,
        -3.13202161e-01,  3.23104629e+00,  4.67752202e+00,
         2.22113452e+00,  1.51420100e+01,  9.24872843e+00,
         9.84207192e-01,  8.54143826e-01,  2.70468967e+00,
         2.30988559e+00,  2.99111323e+00,  1.96163207e+01,
         5.55486533e+00,  1.22603714e+01, -2.38039401e+00,
        -2.79705277e+00,  6.40489757e+00, -2.71736424e+00,
        -1.00264551e-01,  1.02279786e+01,  5.60222262e+00,
         4.11176477e+00,  8.76202503e+00,  5.47024771e+00,
         1.01830503e+01,  9.29037990e+00, -7.79125947e-01,
         6.05209366e+00,  9.24872843e+00,  4.49810267e+00,
         8.71118605e+00, -3.00193706e+00,  6.87930293e+00,
         2.81711205e-01,  7.97279212e+00, -7.01906741e-01,
         1.72614613e+01,  2.77586579e+00,  4.54495889e-01,
        -3.26253213e+00]])
Coordinates:
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) object MultiIndex
  * chain    (sample) int64 0 1 1 2 1 3 2 0 2 3 2 3 ... 3 0 2 1 3 2 2 2 0 1 0 1
  * draw     (sample) int64 161 276 377 128 430 257 ... 459 69 146 409 207 381

To keep the chain and draw dimensions, use combined=False. This example also uses group to extract the prior group instead of the posterior.

az.extract(idata, group="prior", combined=False)
<xarray.Dataset>
Dimensions:  (chain: 1, draw: 500, school: 8)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    tau      (chain, draw) float64 ...
    theta    (chain, draw, school) float64 ...
    mu       (chain, draw) float64 ...
Attributes:
    arviz_version:              0.13.0.dev0
    created_at:                 2022-10-13T14:37:26.602116
    inference_library:          pymc
    inference_library_version:  4.2.2