MMMPlotSuite.marginal_curve
MMMPlotSuite.marginal_curve(hdi_prob=0.94, ax=None, aggregation=None, subplot_kwargs=None, *, plot_kwargs=None, ylabel='Marginal effect', xlabel='Sweep', title='Marginal effects', add_figure_title=True)
Plot precomputed marginal effects stored under idata.sensitivity_analysis['marginal_effects'].
Parameters:
- hdi_prob : float, default 0.94
Probability mass of the highest density interval (HDI).
- ax : plt.Axes, optional
The axis to plot on.
- aggregation : dict, optional
Aggregation to apply to the data before plotting. For example, {"sum": ("channel",)} sums over the channel dimension.
- subplot_kwargs : dict, optional
Additional subplot configuration forwarded to sensitivity_analysis().
- plot_kwargs : dict, optional
Keyword arguments forwarded to the underlying line plot. Defaults to {"color": "C1"}.
- ylabel : str, optional
Y-axis label. Defaults to "Marginal effect".
- xlabel : str, optional
X-axis label. Defaults to "Sweep".
- title : str, optional
Figure-level title, added when add_figure_title=True. Defaults to "Marginal effects".
- add_figure_title : bool, optional
Whether to add a figure-level title. Defaults to True.
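How hdi_prob and aggregation shape the plotted data is not spelled out above. The following is a minimal NumPy sketch, not the pymc-marketing implementation. It assumes the curve shows a per-sweep posterior mean with an HDI band, and that the stored marginal effects have dims (chain, draw, sweep, channel). The helpers hdi and summarize_marginal_effects are hypothetical, and the sum_over_channel flag only mimics aggregation={"sum": ("channel",)}. The HDI uses the usual narrowest-interval construction for unimodal samples, the same approach ArviZ takes.

```python
import math

import numpy as np


def hdi(samples: np.ndarray, prob: float = 0.94) -> tuple[float, float]:
    """Narrowest interval containing `prob` of the 1-D `samples` (unimodal HDI)."""
    if not 0.0 < prob < 1.0:
        raise ValueError("prob must lie strictly between 0 and 1")
    sorted_samples = np.sort(samples)
    n = sorted_samples.size
    # Number of sorted positions each candidate interval spans.
    window = int(math.floor(prob * n))
    # Width of every interval [sorted[i], sorted[i + window]].
    widths = sorted_samples[window:] - sorted_samples[: n - window]
    start = int(np.argmin(widths))
    return float(sorted_samples[start]), float(sorted_samples[start + window])


def summarize_marginal_effects(effects, prob=0.94, sum_over_channel=False):
    """Per-sweep mean and HDI bounds for an array shaped (chain, draw, sweep, channel).

    Hypothetical helper for illustration only.
    """
    if sum_over_channel:
        # Rough analogue of aggregation={"sum": ("channel",)}: collapse the channel axis.
        effects = effects.sum(axis=-1, keepdims=True)
    n_chain, n_draw, n_sweep, n_channel = effects.shape
    # Pool chains and draws into a single sample axis.
    pooled = effects.reshape(n_chain * n_draw, n_sweep, n_channel)
    mean = pooled.mean(axis=0)
    lower = np.empty((n_sweep, n_channel))
    upper = np.empty((n_sweep, n_channel))
    for s in range(n_sweep):
        for c in range(n_channel):
            lower[s, c], upper[s, c] = hdi(pooled[:, s, c], prob)
    return mean, lower, upper


# Synthetic draws: 2 chains x 500 draws, 11 sweep values, 3 channels.
rng = np.random.default_rng(0)
effects = rng.normal(loc=1.0, scale=0.1, size=(2, 500, 11, 3))
mean, lower, upper = summarize_marginal_effects(effects, prob=0.94)
mean_sum, lower_sum, upper_sum = summarize_marginal_effects(
    effects, prob=0.94, sum_over_channel=True
)
```

Summing over channel leaves a single curve per sweep value (shape (11, 1) here) instead of one curve per channel, which is why the aggregated plot shows one line and band.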
Examples
Persist marginal effects and plot:
import numpy as np
from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis

# `mmm` is assumed to be a fitted MMM whose posterior lives in `mmm.idata`.
sweeps = np.linspace(0.5, 1.5, 11)
sa = SensitivityAnalysis(mmm.model, mmm.idata)
results = sa.run_sweep(
    var_input="channel_data",
    sweep_values=sweeps,
    var_names="channel_contribution",
    sweep_type="multiplicative",
)
# extend_idata=True persists the result so marginal_curve can read it from idata.
me = sa.compute_marginal_effects(results, extend_idata=True)
_ = mmm.plot.marginal_curve(hdi_prob=0.9)
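This page does not define the marginal effect that compute_marginal_effects stores. A common definition is the derivative of the swept response with respect to the sweep value. The sketch below assumes that definition and approximates it with finite differences via np.gradient. The helper finite_difference_marginal_effects and the toy square-root response are hypothetical, and SensitivityAnalysis.compute_marginal_effects may compute the quantity differently.

```python
import numpy as np


def finite_difference_marginal_effects(response, sweep_values):
    """Approximate d(response)/d(sweep) along the last axis.

    `response` has shape (..., sweep), e.g. (chain, draw, sweep) samples of a
    contribution evaluated at each sweep value. Hypothetical helper.
    """
    # Central differences in the interior, one-sided differences at the endpoints.
    return np.gradient(response, sweep_values, axis=-1)


sweeps = np.linspace(0.5, 1.5, 11)
rng = np.random.default_rng(1)
# Toy posterior: one positive scale per (chain, draw), shape (2, 200, 1).
scale = rng.normal(100.0, 5.0, size=(2, 200, 1))
# Saturating toy response with diminishing returns; broadcasts to (2, 200, 11).
response = scale * np.sqrt(sweeps)
effects = finite_difference_marginal_effects(response, sweeps)
```

Under this toy response, the marginal effect shrinks as the sweep multiplier grows, which is the diminishing-returns pattern a marginal curve is typically used to reveal.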