MMMPlotSuite.marginal_curve

MMMPlotSuite.marginal_curve(hdi_prob=0.94, ax=None, aggregation=None, subplot_kwargs=None, *, plot_kwargs=None, ylabel='Marginal effect', xlabel='Sweep', title='Marginal effects', add_figure_title=True)

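Plot precomputed marginal effects stored under idata.sensitivity_analysis['marginal_effects'], for example as written by SensitivityAnalysis.compute_marginal_effects(results, extend_idata=True).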
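A marginal effect over a sweep can be read as the rate of change of the swept response with respect to the sweep value. The sketch below assumes the precomputed values are finite-difference derivatives along the sweep dimension; it does not call pymc-marketing, and the helper `finite_difference_marginal` and the toy saturating response are hypothetical stand-ins.

```python
import numpy as np


def finite_difference_marginal(response: np.ndarray, sweeps: np.ndarray, axis: int = -1) -> np.ndarray:
    """Hypothetical helper: d(response)/d(sweep) via finite differences.

    np.gradient uses central differences in the interior and one-sided
    differences at the two ends of the sweep grid.
    """
    return np.gradient(response, sweeps, axis=axis)


# Toy inputs (assumed): the sweep grid used in the Examples section and a
# saturating curve standing in for channel_contribution.
sweeps = np.linspace(0.5, 1.5, 11)
response = 100.0 * (1.0 - np.exp(-2.0 * sweeps))

marginal = finite_difference_marginal(response, sweeps)
# Diminishing returns: the marginal effect shrinks as the sweep value grows.
print(marginal.round(2))
```

With a saturating response, the resulting curve decreases across the sweep, which is the shape marginal_curve would then display.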

Parameters:
hdi_prob : float, default 0.94

Probability mass of the highest density interval (HDI) to plot.

ax : plt.Axes, optional

The axis to plot on.

aggregation : dict, optional

Aggregation to apply to the data before plotting, for example {"sum": ("channel",)} to sum over the channel dimension.

subplot_kwargs : dict, optional

Additional subplot configuration forwarded to sensitivity_analysis().

plot_kwargs : dict, optional

Keyword arguments forwarded to the underlying line plot. Defaults to {"color": "C1"}.

ylabel : str, optional

Y-axis label. Defaults to "Marginal effect".

xlabel : str, optional

X-axis label. Defaults to "Sweep".

title : str, optional

Figure-level title, added when add_figure_title=True. Defaults to "Marginal effects".

add_figure_title : bool, optional

Whether to add a figure-level title. Defaults to True.
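To make hdi_prob and aggregation concrete, the sketch below shows one way the plotted summary could be derived from raw draws: reduce the named dimensions listed in aggregation, then compute a per-sweep mean and HDI band. It is an illustration under stated assumptions, not this method's implementation: `hdi_1d`, `apply_aggregation`, and `summarize_over_sweep` are hypothetical names, the draws are assumed to carry (chain, draw, sweep, channel) dimensions, and the HDI uses the standard narrowest-interval method for unimodal samples (the same idea as ArviZ's hdi).

```python
import numpy as np


def hdi_1d(samples: np.ndarray, prob: float) -> tuple[float, float]:
    """Narrowest interval containing `prob` of the samples (unimodal assumption)."""
    sorted_s = np.sort(samples)
    n = sorted_s.size
    k = int(np.floor(prob * n))  # number of sorted positions the interval spans
    widths = sorted_s[k:] - sorted_s[: n - k]
    i = int(np.argmin(widths))
    return float(sorted_s[i]), float(sorted_s[i + k])


def apply_aggregation(values: np.ndarray, dims: list, aggregation: dict):
    """Hypothetical: reduce named dims, e.g. {"sum": ("channel",)}."""
    reducers = {"sum": np.sum, "mean": np.mean}  # assumed operation names
    for op, agg_dims in aggregation.items():
        axes = tuple(dims.index(d) for d in agg_dims)
        values = reducers[op](values, axis=axes)
        dims = [d for d in dims if d not in agg_dims]
    return values, dims


def summarize_over_sweep(values, dims, hdi_prob=0.94, aggregation=None):
    """Per-sweep mean and HDI bounds, pooling chain and draw as samples."""
    dims = list(dims)
    if aggregation:
        values, dims = apply_aggregation(values, dims, aggregation)
    if dims != ["chain", "draw", "sweep"]:
        raise ValueError(f"expected (chain, draw, sweep) after aggregation, got {dims}")
    samples = values.reshape(-1, values.shape[-1])  # (chain * draw, sweep)
    mean = samples.mean(axis=0)
    bands = np.array([hdi_1d(samples[:, j], hdi_prob) for j in range(samples.shape[1])])
    return mean, bands[:, 0], bands[:, 1]


# Toy draws (made up): 2 chains x 500 draws x 11 sweep values x 3 channels.
rng = np.random.default_rng(0)
effects = rng.normal(loc=1.0, scale=0.1, size=(2, 500, 11, 3))
mean, lower, upper = summarize_over_sweep(
    effects,
    ("chain", "draw", "sweep", "channel"),
    hdi_prob=0.94,
    aggregation={"sum": ("channel",)},
)
```

Summing three channels whose effects are each centered at 1.0 gives a per-sweep mean near 3.0, with the 94% HDI band bracketing it.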

Examples

Compute marginal effects, persist them in the InferenceData, and plot them:

import numpy as np
from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis

# `mmm` is assumed to be an already fitted MMM instance.
sweeps = np.linspace(0.5, 1.5, 11)
sa = SensitivityAnalysis(mmm.model, mmm.idata)
results = sa.run_sweep(
    var_input="channel_data",
    sweep_values=sweeps,
    var_names="channel_contribution",
    sweep_type="multiplicative",
)
me = sa.compute_marginal_effects(results, extend_idata=True)
_ = mmm.plot.marginal_curve(hdi_prob=0.9)