-
-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathforestplot.py
530 lines (492 loc) · 22 KB
/
forestplot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
"""Forest plot code."""
from copy import copy
from importlib import import_module
import arviz_stats # pylint: disable=unused-import
import numpy as np
import xarray as xr
from arviz_base import rcParams
from arviz_base.labels import BaseLabeller
from arviz_plots.plot_collection import PlotCollection, process_facet_dims
from arviz_plots.plots.utils import filter_aes, process_group_variables_coords
from arviz_plots.visuals import annotate_label, fill_between_y, line_x, remove_axis, scatter_x
def plot_forest(
    dt,
    var_names=None,
    filter_vars=None,
    group="posterior",
    coords=None,
    sample_dims=None,
    combined=False,
    point_estimate=None,
    ci_kind=None,
    ci_probs=None,
    labels=None,
    shade_label=None,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_map=None,
    plot_kwargs=None,
    stats_kwargs=None,
    pc_kwargs=None,
):
    """Plot 1D marginal credible intervals in a single plot.

    Parameters
    ----------
    dt : DataTree or dict of {str : DataTree}
        Input data. In case of dictionary input, the keys are taken to be model names.
        In such cases, a dimension "model" is generated and can be used to map to aesthetics.

        ``plot_forest`` uses the dimension "column" (creating it if necessary) to generate the grid
        then adds the intervals+point estimates to its "forest" coordinate
        and labels to its "labels" coordinates. The data used to plot is then the subset
        ``column="forest"``.
    var_names : str or list of str, optional
        One or more variables to be plotted.
        Prefix the variables by ~ when you want to exclude them from the plot.
    filter_vars : {None, "like", "regex"}, default None
        If None, interpret var_names as the real variables names.
        If "like", interpret var_names as substrings of the real variables names.
        If "regex", interpret var_names as regular expressions on the real variables names.
    group : str, default "posterior"
        Group to be plotted.
    coords : dict, optional
        Forwarded as ``coords`` to ``process_group_variables_coords`` together with
        `group`, `var_names` and `filter_vars` to select the data to plot.
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    combined : bool, default False
        Whether to plot intervals for each chain or not. Ignored when the "chain" dimension
        is not present.
    point_estimate : {"mean", "median", "mode"}, optional
        Which point estimate to plot. Defaults to rcParam :data:`stats.point_estimate`
    ci_kind : {"eti", "hdi"}, optional
        Which credible interval to use. Defaults to ``rcParams["stats.ci_kind"]``
    ci_probs : (float, float), optional
        Indicates the probabilities that should be contained within the plotted credible intervals.
        It should be sorted as the elements refer to the probabilities of the "trunk" and "twig"
        elements. Defaults to ``(0.5, rcParams["stats.ci_prob"])`` when that rcParam is
        larger than 0.5, otherwise to half the rcParam value and the rcParam value.
    labels : sequence of str, optional
        Sequence with the dimensions to be labelled in the plot. By default all dimensions
        except "chain" and "model" (if present). The order of `labels` is ignored,
        only elements being present in it matters.
        It can include the special "__variable__" indicator, and does so by default.
    shade_label : str, default None
        Element of `labels` that should be used to add shading horizontal strips to the plot.
        Note that labels and credible intervals are plotted in different :term:`plots`.
        The shading is applied to both plots, and the spacing between them is set to 0
        *if possible*, which is not always the case (one notable example being matplotlib's
        constrained layout).
    plot_collection : PlotCollection, optional
        Existing plot collection to draw on. When given, no new grid is created,
        `pc_kwargs` is not used and the "y" aesthetic values are taken as they are.
    backend : {"matplotlib", "bokeh"}, optional
        Plotting backend. Defaults to the backend of `plot_collection` when one is
        given and to ``rcParams["plot.backend"]`` otherwise.
    labeller : labeller, optional
        Defaults to a ``BaseLabeller`` instance.
    aes_map : mapping of {str : sequence of str or False}, optional
        Mapping of artists to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `plot_kwargs` except "ticklabels"
        which doesn't apply and "twig" and "trunk" which, similarly to `stats_kwargs`
        take the same aesthetics through the "credible_interval" key.

        By default, aesthetic mappings are generated for: y, alpha, overlay and color
        (if multiple models are present). All aesthetic mappings but alpha are applied
        to both the credible intervals and the point estimate; overlay is applied
        to labels; and both overlay and alpha are applied to the shade.

        "overlay" is a dummy aesthetic to trigger looping over variables and/or
        dimensions using all aesthetics in every iteration. "alpha" gets two
        values (0, 0.3) in order to trigger the alternate shading effect.
    plot_kwargs : mapping of {str : mapping or False}, optional
        Valid keys are:

        * trunk, twig -> passed to :func:`~.visuals.line_x`
        * point_estimate -> passed to :func:`~.visuals.scatter_x`
        * labels -> passed to :func:`~.visuals.annotate_label`
        * shade -> passed to :func:`~.visuals.fill_between_y`
        * ticklabels -> passed to :func:`~.backend.xticks`
        * remove_axis -> not passed anywhere, can only take ``False`` as value to skip calling
          :func:`~.visuals.remove_axis`

        Setting "trunk", "twig" or "point_estimate" to ``False`` skips both the
        computation and the plotting of that element.
    stats_kwargs : mapping, optional
        Valid keys are:

        * trunk, twig -> passed to eti or hdi
        * point_estimate -> passed to mean, median or mode

        A value that is an :class:`xarray.Dataset` is used directly as the
        already computed statistic instead of being passed as keyword arguments.
    pc_kwargs : mapping
        Passed to :class:`arviz_plots.PlotCollection.grid`

    Returns
    -------
    PlotCollection

    Raises
    ------
    ValueError
        If `ci_kind` is not "hdi", "eti" or None; if ``ci_probs[0] > ci_probs[1]``;
        if ``plot_kwargs["labels"]`` or ``plot_kwargs["shade"]`` are ``False``;
        if `shade_label` is not in `labels`; if the input already has a "column"
        dimension without the "labels" and "forest" coordinates; or if
        ``combined=True`` while an aesthetic is mapped to the "chain" dimension.
    NotImplementedError
        If the point estimate is plotted, it is not provided as a Dataset through
        `stats_kwargs` and `point_estimate` is neither "mean" nor "median".

    Notes
    -----
    The separation between variables and all its coordinate values is set to 1.
    The only two exceptions to this are the dimensions named "chain" and "model"
    in case they are present, which get a smaller spacing to give a sense of
    grouping among visual elements that only differ on their chain or model id.

    See Also
    --------
    :ref:`plots_intro` :
        General introduction to batteries-included plotting functions, common use and logic overview
    plot_ridge : Visual representation of marginal distributions over the y axis

    Examples
    --------
    Single model forest plot with color mapped to the variable (mapping which is also applied
    to the labels) and alternate shading per school.
    Moreover, to ensure the shading looks continuous, we'll specify we don't want to use
    constrained layout (set by the "arviz-clean" theme) and to avoid having the labels
    too squished we'll set the ``width_ratios`` for
    :func:`~arviz_plots.backend.none.create_plotting_grid` via ``pc_kwargs``.

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_forest, style
        >>> style.use("arviz-clean")
        >>> from arviz_base import load_arviz_data
        >>> non_centered = load_arviz_data('non_centered_eight')
        >>> pc = plot_forest(
        >>>     non_centered,
        >>>     var_names=["theta", "mu", "theta_t", "tau"],
        >>>     pc_kwargs={
        >>>         "aes": {"color": ["__variable__"]},
        >>>         "plot_grid_kws": {"width_ratios": [1, 2], "layout": "none"}
        >>>     },
        >>>     aes_map={"labels": ["color"]},
        >>>     shade_label="school",
        >>> )

    .. minigallery:: plot_forest

    """
    # --- argument validation and defaults ---
    if ci_kind not in ("hdi", "eti", None):
        raise ValueError("ci_kind must be either 'hdi' or 'eti'")
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]
    if ci_probs is None:
        rc_ci_prob = rcParams["stats.ci_prob"]
        ci_probs = (0.5, rc_ci_prob) if rc_ci_prob > 0.5 else (0.5 * rc_ci_prob, rc_ci_prob)
    if ci_probs[0] > ci_probs[1]:
        raise ValueError("First element of ci_probs must be smaller than the second")
    if ci_kind is None:
        ci_kind = rcParams["stats.ci_kind"]
    if point_estimate is None:
        point_estimate = rcParams["stats.point_estimate"]
    if plot_kwargs is None:
        plot_kwargs = {}
    if pc_kwargs is None:
        pc_kwargs = {}
    else:
        # shallow copy so the keys set below don't leak into the caller's dict
        pc_kwargs = pc_kwargs.copy()
    if stats_kwargs is None:
        stats_kwargs = {}

    distribution = process_group_variables_coords(
        dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords
    )
    labellable_dims = ["__variable__"] + [
        dim for dim in distribution.dims if (dim not in {"model", "column"}.union(sample_dims))
    ]
    if labels is None:
        labels = labellable_dims
    # without a "chain" dimension there is nothing to split by
    if not combined and "chain" not in distribution.dims:
        combined = True

    labels_kwargs = copy(plot_kwargs.get("labels", {}))
    if labels_kwargs is False:
        raise ValueError("plot_kwargs['labels'] can't be False, use labels=[] to remove all labels")
    shade_kwargs = copy(plot_kwargs.get("shade", {}))
    if shade_kwargs is False:
        raise ValueError(
            "plot_kwargs['shade'] can't be False, use shade_label=None to remove shading"
        )
    if shade_label is not None and shade_label not in labels:
        raise ValueError("shade_label must be one of the elements in labels argument")

    if backend is None:
        if plot_collection is None:
            backend = rcParams["plot.backend"]
        else:
            backend = plot_collection.backend

    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")

    # --- build the PlotCollection when none was provided ---
    given_plotcollection = True
    if plot_collection is None:
        given_plotcollection = False
        pc_data = distribution
        if "column" not in pc_data:
            pc_data = pc_data.expand_dims(column=2).assign_coords(column=["labels", "forest"])
        elif ("forest" not in pc_data.column) or ("labels" not in pc_data.column):
            raise ValueError(
                "Found colum dimension in input data but required coordinates "
                "'labels' and 'forest' are missing."
            )
        pc_kwargs.setdefault("cols", ["column"])
        pc_kwargs["plot_grid_kws"] = pc_kwargs.get("plot_grid_kws", {}).copy()
        pc_kwargs["plot_grid_kws"].setdefault("sharey", True)
        # the "forest" column is wider than the other columns by default
        width_ratios = xr.ones_like(pc_data.column, dtype=float)
        width_ratios.loc[{"column": "forest"}] = 3 if len(labels) < 3 else 2
        pc_kwargs["plot_grid_kws"].setdefault("width_ratios", width_ratios.values)
        pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
        alpha_aes_dims = False
        if shade_label is not None:
            pc_kwargs["plot_grid_kws"].setdefault("plot_hspace", 0)
            alpha_aes_dims = pc_kwargs["aes"].get("alpha", [shade_label])
            pc_kwargs["aes"]["alpha"] = alpha_aes_dims
        pc_kwargs["aes"].setdefault("y", labellable_dims)
        pc_kwargs["aes"].setdefault("overlay", labellable_dims)
        if alpha_aes_dims is not False:
            if ("__variable__" in pc_kwargs["aes"]["alpha"]) or all(
                shade_label in da.dims for da in distribution.values()
            ):
                pc_kwargs.setdefault("alpha", [0, 0.3])
            else:
                # trigger inclusion of neutral element in aes cycle
                pc_kwargs.setdefault("alpha", [0, 0, 0.3])
        if "model" in distribution.dims:
            pc_kwargs["aes"].setdefault("color", ["model"])
        figsize = pc_kwargs.get("plot_grid_kws", {}).get("figsize", None)
        figsize_units = pc_kwargs.get("plot_grid_kws", {}).get("figsize_units", "inches")
        if figsize is None:
            # default figure height grows with the number of y positions,
            # with extra room when chains and/or models are shown separately
            coeff = 0.2
            n_blocks = process_facet_dims(
                pc_data, [dim for dim in pc_kwargs["aes"]["y"] if dim not in ("chain", "model")]
            )[0]
            if not combined and "chain" in distribution.dims:
                coeff += 0.1
            if "model" in distribution.dims:
                coeff += 0.1 * distribution.sizes["model"]
            figsize = plot_bknd.scale_fig_size(
                figsize,
                rows=1 + coeff * n_blocks,
                cols=process_facet_dims(pc_data, pc_kwargs["cols"])[0],
                figsize_units=figsize_units,
            )
            figsize_units = "dots"
        pc_kwargs["plot_grid_kws"]["figsize"] = figsize
        pc_kwargs["plot_grid_kws"]["figsize_units"] = figsize_units
        plot_collection = PlotCollection.grid(
            pc_data,
            backend=backend,
            **pc_kwargs,
        )

    if "column" in distribution.dims:
        distribution = distribution.sel(column="forest")

    if combined:
        chain_mapped_to_aes = {
            aes_key
            for child in plot_collection.aes.children.values()
            for aes_key, aes_vals in child.items()
            if "chain" in aes_vals.dims
        }
        if chain_mapped_to_aes:
            raise ValueError(
                f"Found properties {chain_mapped_to_aes} mapped to the chain dimension, "
                "but combined=True. Set combined=False or modify the aesthetic mappings"
            )

    # fine tune y position for model and chain
    add_factor = 0.2 if (not combined) or ("model" in distribution.dims) else 0
    y_ds = plot_collection.get_aes_as_dataset("y")
    if not given_plotcollection:
        shift = 0
        if combined and "model" in distribution.dims:
            shift = xr.DataArray(
                np.linspace(-0.2, 0.2, distribution.sizes["model"]),
                coords={"model": distribution.model},
            )
        elif (not combined) and ("model" in distribution.dims):
            model_spacing = xr.DataArray(
                np.linspace(-0.2, 0.2, distribution.sizes["model"]),
                coords={"model": distribution.model},
            )
            chain_lim = 0.4 * (model_spacing[1] - model_spacing[0]).item()
            chain_spacing = xr.DataArray(
                np.linspace(-chain_lim, chain_lim, distribution.sizes["chain"]),
                coords={"chain": distribution.chain},
            )
            shift = model_spacing + chain_spacing
        elif not combined:
            shift = xr.DataArray(
                np.linspace(-0.2, 0.2, distribution.sizes["chain"]),
                coords={"chain": distribution.chain},
            )
        # subtracting from the global maximum reverses the order of the y values
        y_ds = y_ds.max().to_array().max() - add_factor - y_ds - shift
        plot_collection.update_aes_from_dataset("y", y_ds)

    if aes_map is None:
        aes_map = {}
    else:
        aes_map = aes_map.copy()
    aes_map.setdefault("credible_interval", plot_collection.aes_set.difference({"alpha"}))
    aes_map.setdefault("point_estimate", plot_collection.aes_set.difference({"alpha"}))
    aes_map["labels"] = {"overlay"}.union(aes_map.get("labels", {}))
    aes_map["shade"] = {"overlay", "alpha"}.union(aes_map.get("shade", {}))
    # NOTE(review): labeller is given a default here but is not forwarded to any
    # visual within this function -- confirm whether annotate_label should get it
    if labeller is None:
        labeller = BaseLabeller()

    # compute credible interval
    ci_dims, ci_aes, ci_ignore = filter_aes(
        plot_collection, aes_map, "credible_interval", sample_dims
    )
    twig_kwargs = copy(plot_kwargs.get("twig", {}))
    trunk_kwargs = copy(plot_kwargs.get("trunk", {}))
    if ci_kind == "eti":
        ci_fun = distribution.azstats.eti
    elif ci_kind == "hdi":
        ci_fun = distribution.azstats.hdi
    if twig_kwargs is not False:
        twig_stats = stats_kwargs.get("twig", {})
        if isinstance(twig_stats, xr.Dataset):
            ci_twig = twig_stats
        else:
            ci_twig = ci_fun(prob=ci_probs[1], dims=ci_dims, **twig_stats)
    if trunk_kwargs is not False:
        trunk_stats = stats_kwargs.get("trunk", {})
        if isinstance(trunk_stats, xr.Dataset):
            ci_trunk = trunk_stats
        else:
            ci_trunk = ci_fun(prob=ci_probs[0], dims=ci_dims, **trunk_stats)

    # compute point estimate
    pe_kwargs = copy(plot_kwargs.get("point_estimate", {}))
    # ``False`` (not ``None``) is the sentinel that disables an element; skipping
    # here avoids computing (and possibly failing on) an estimate that is never drawn
    if pe_kwargs is not False:
        pe_dims, pe_aes, pe_ignore = filter_aes(
            plot_collection, aes_map, "point_estimate", sample_dims
        )
        pe_stats = stats_kwargs.get("point_estimate", {})
        if isinstance(pe_stats, xr.Dataset):
            point = pe_stats
        elif point_estimate == "median":
            point = distribution.median(dim=pe_dims, **pe_stats)
        elif point_estimate == "mean":
            point = distribution.mean(dim=pe_dims, **pe_stats)
        else:
            raise NotImplementedError("coming soon")

    # data whose extent defines the x limits of the shading in the forest column
    if twig_kwargs is not False:
        x_range = ci_twig
    elif trunk_kwargs is not False:
        x_range = ci_trunk
    elif pe_kwargs is not False:
        x_range = point
    else:
        x_range = xr.ones_like(distribution)

    # add labels and shading first, so forest plot is rendered on top
    # each variable+coord combination has the space between i.5 and (i+1).5 "reserved"
    # if there multiple models/combined all lines are plotted in the central region
    # of .4 width, if single model+combined, the line is at the center
    # shade extend is the value to add to the max/substract to the min to shade the
    # whole unit regions reserved to variables/coords
    shade_extend = 0.5 if add_factor == 0 else 0.3
    cumulative_label = []
    x = 0
    for label in labellable_dims:
        cumulative_label.append(label)
        if label not in labels:
            continue
        if label == "__variable__":
            y_max = y_ds.max() + shade_extend
            y_min = y_ds.min() - shade_extend
        else:
            reduce_dims = [dim for dim in y_ds.dims if dim not in cumulative_label]
            y_max = y_ds.max(reduce_dims) + shade_extend
            y_min = y_ds.min(reduce_dims) - shade_extend
        y = (y_max + y_min) / 2

        if shade_label == label:
            _, shade_aes, shade_ignore = filter_aes(plot_collection, aes_map, "shade", sample_dims)
            if "color" not in shade_aes:
                shade_kwargs.setdefault("color", "gray")
            shade_data = xr.concat((y_min, y_max), "kwarg").assign_coords(
                kwarg=["y_bottom", "y_top"]
            )
            shade_start = -0.1 if x == 0 else x - 0.6
            xlim_labels = [-0.1, len(labels) - 0.9]
            plot_collection.map(
                fill_between_y,
                "shade",
                data=shade_data,
                x=[shade_start, xlim_labels[1]],
                coords={"column": "labels"},
                ignore_aes=shade_ignore,
                **shade_kwargs,
            )
            ci_global_min = x_range.min().to_array().min().item()
            ci_global_max = x_range.max().to_array().max().item()
            ci_range_extend = 0.1 * (ci_global_max - ci_global_min)
            xlim_forest = [ci_global_min - ci_range_extend, ci_global_max + ci_range_extend]
            plot_collection.map(
                fill_between_y,
                "shade",
                data=shade_data,
                x=xlim_forest,
                coords={"column": "forest"},
                ignore_aes=shade_ignore,
                **shade_kwargs,
            )

        _, lab_aes, lab_ignore = filter_aes(plot_collection, aes_map, "labels", sample_dims)
        # ignore aesthetics that depend on dimensions not yet covered by the
        # labels processed so far
        extra_ignore_aes = []
        for aes_key in lab_aes:
            if aes_key == "overlay":
                continue
            aes_ds = plot_collection.get_aes_as_dataset(aes_key)
            if set(aes_ds.dims).difference(cumulative_label):
                extra_ignore_aes.append(aes_key)
        lab_aes = set(lab_aes).difference(extra_ignore_aes)
        lab_ignore = set(lab_ignore).union(extra_ignore_aes)
        lab_kwargs = labels_kwargs.copy()
        if "color" not in lab_aes:
            lab_kwargs.setdefault("color", "black")
        if x == 0:
            lab_kwargs.setdefault("horizontal_align", "left")
        if x == len(labels) - 1:
            lab_kwargs.setdefault("horizontal_align", "right")
        plot_collection.map(
            annotate_label,
            f"{label.strip('_')}_label",
            data=y,
            x=x,
            dim=None if label == "__variable__" else label,
            subset_info=True,
            coords={"column": "labels"},
            ignore_aes=lab_ignore,
            **lab_kwargs,
        )
        x += 1

    ticklabel_kwargs = copy(plot_kwargs.get("ticklabels", {}))
    if ticklabel_kwargs is not False:
        plot_bknd.xticks(
            np.arange(len(labels)),
            [label.strip("_") for label in labellable_dims if label in labels],
            plot_collection.get_target(None, {"column": "labels"}),
            **ticklabel_kwargs,
        )

    # plot credible interval
    default_color = plot_bknd.get_default_aes("color", 1, {})[0]
    if twig_kwargs is not False:
        twig_kwargs.setdefault("width", 0.7)
        if "color" not in ci_aes:
            twig_kwargs.setdefault("color", default_color)
        plot_collection.map(
            line_x,
            "twig",
            data=ci_twig,
            ignore_aes=ci_ignore,
            coords={"column": "forest"},
            **twig_kwargs,
        )

    if trunk_kwargs is not False:
        trunk_kwargs.setdefault("width", 2)
        if "color" not in ci_aes:
            trunk_kwargs.setdefault("color", default_color)
        plot_collection.map(
            line_x,
            "trunk",
            data=ci_trunk,
            ignore_aes=ci_ignore,
            coords={"column": "forest"},
            **trunk_kwargs,
        )

    # point estimate
    if pe_kwargs is not False:
        if "color" not in pe_aes:
            pe_kwargs.setdefault("color", default_color)
        if "facecolor" not in pe_aes:
            pe_kwargs.setdefault("facecolor", "white")
        if "width" not in pe_aes:
            pe_kwargs.setdefault("width", 1)
        plot_collection.map(
            scatter_x,
            "point_estimate",
            data=point,
            ignore_aes=pe_ignore,
            coords={"column": "forest"},
            **pe_kwargs,
        )

    # NOTE(review): xlim_labels/xlim_forest are only bound inside the loop above when
    # shade_label matches one of labellable_dims -- confirm callers can't pass a
    # shade_label that is in `labels` but is not a labellable dimension
    if shade_label is not None:
        plot_bknd.xlim(xlim_labels, plot_collection.get_target(None, {"column": "labels"}))
        plot_bknd.xlim(xlim_forest, plot_collection.get_target(None, {"column": "forest"}))

    if plot_kwargs.get("remove_axis", True) is not False:
        plot_collection.map(
            remove_axis, store_artist=False, axis="y", ignore_aes=plot_collection.aes_set
        )

    return plot_collection