
Use marginaleffects::avg_predictions() to estimate marginal predictions for each variable of a model and return a tibble tidied so that it can be used by broom.helpers functions. See marginaleffects::avg_predictions() for a list of supported models.

Usage

tidy_marginal_predictions(
  x,
  variables_list = "auto",
  conf.int = TRUE,
  conf.level = 0.95,
  ...
)

variables_to_predict(
  model,
  interactions = TRUE,
  categorical = unique,
  continuous = stats::fivenum
)

plot_marginal_predictions(x, variables_list = "auto", conf.level = 0.95, ...)

Arguments

x

(a model object, e.g. glm)
A model to be tidied.

variables_list

(list or string)
A list whose elements are passed, one after the other, to the variables argument of marginaleffects::avg_predictions() (see Details below); alternatively, the string "auto" (default) or "no_interaction".

conf.int

(logical)
Whether or not to include a confidence interval in the tidied output.

conf.level

(numeric)
The confidence level to use for the confidence interval (between 0 and 1).

...

Additional parameters passed to marginaleffects::avg_predictions().

model

(a model object, e.g. glm)
A model.

interactions

(logical)
Should combinations of variables corresponding to interactions be returned?

categorical

(predictor values)
Default values for categorical variables (by default, unique, i.e. all observed values).

continuous

(predictor values)
Default values for continuous variables (by default, stats::fivenum, i.e. Tukey's five-number summary).

Details

Marginal predictions are obtained by calling marginaleffects::avg_predictions() once per variable, passing that variable to both the variables and the by arguments.

Considering a categorical variable named cat, tidy_marginal_predictions() will call avg_predictions(model, variables = list(cat = unique), by = "cat") to obtain average marginal predictions for this variable.

Considering a continuous variable named cont, tidy_marginal_predictions() will call avg_predictions(model, variables = list(cont = "fivenum"), by = "cont") to obtain average marginal predictions for this variable at the minimum, the first quartile, the median, the third quartile and the maximum of the observed values of cont.
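For a continuous predictor, the values at which predictions are computed can be checked with base R. The sketch below uses iris$Petal.Width, the continuous variable of the iris examples further down. stats::fivenum() returns Tukey's five-number summary (minimum, lower hinge, median, upper hinge, maximum); the hinges may differ slightly from the quartiles returned by quantile() with its default settings. The "threenum" values are assumed here to be the mean minus one standard deviation, the mean, and the mean plus one standard deviation, a reading consistent with the terms printed in the "threenum" example.

```r
x <- iris$Petal.Width

# Tukey's five-number summary: the default values for continuous variables
five <- stats::fivenum(x)
five
# returns c(0.1, 0.3, 1.3, 1.8, 2.5), the Petal.Width terms in the examples

# Assumed meaning of "threenum": mean minus/plus one standard deviation
three <- c(mean(x) - stats::sd(x), mean(x), mean(x) + stats::sd(x))
three
```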

By default, average marginal predictions are computed: for each value of the variable of interest, predictions are made on a counterfactual grid (a copy of the data in which that variable is set to the value) and then averaged. Marginal predictions at the mean can be obtained by passing newdata = "mean". Other assumptions are possible; see the help file of marginaleffects::avg_predictions().
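To make the averaging step concrete, the sketch below recomputes point estimates by hand with base R for Class in the Titanic model used in the examples (no standard errors or confidence intervals). counterfactual_mean() and mode_of() are hypothetical helpers written for this illustration, and counterfactual_mean() only handles factor variables. Treating newdata = "mean" as "other categorical predictors at their most frequent level" is an assumption about marginaleffects' behaviour, consistent with the printed example where Crew, Adult and Male share the same estimate.

```r
# Rebuild the Titanic data (one row per passenger) with base R only
tab <- as.data.frame(Titanic)
df <- tab[rep(seq_len(nrow(tab)), tab$Freq), c("Class", "Sex", "Age", "Survived")]
mod <- glm(Survived ~ Class + Age + Sex, data = df, family = binomial)

# Hypothetical helper: mean prediction over a counterfactual copy of the
# data in which the factor `var` is set to `value` for every row
counterfactual_mean <- function(model, data, var, value) {
  nd <- data
  nd[[var]] <- factor(rep(value, nrow(nd)), levels = levels(data[[var]]))
  mean(stats::predict(model, newdata = nd, type = "response"))
}

lev <- levels(df$Class)
ame <- vapply(lev, function(v) counterfactual_mean(mod, df, "Class", v), numeric(1))

# Assumed "at the mean" analogue: one row per Class level, with the other
# categorical predictors fixed at their most frequent level
mode_of <- function(f) names(which.max(table(f)))
at_mode <- stats::predict(
  mod,
  newdata = data.frame(Class = lev, Age = mode_of(df$Age), Sex = mode_of(df$Sex)),
  type = "response"
)

round(cbind(average = ame, at_mode = at_mode), 3)
```

If these assumptions hold, the two columns should match the Class estimates printed by tidy_marginal_predictions(mod) and tidy_marginal_predictions(mod, newdata = "mean") respectively.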

tidy_marginal_predictions() computes marginal predictions for each variable or combination of variables, then stacks the results into a single tibble. This is why tidy_marginal_predictions() has a variables_list argument: a list of specifications, each passed in turn to the variables argument of marginaleffects::avg_predictions().

The helper function variables_to_predict() can be used to automatically generate a suitable list for variables_list. By default, all unique values are retained for categorical variables, and fivenum (i.e. Tukey's five-number summary: minimum, hinges, median and maximum) for continuous variables. When interactions = FALSE, variables_to_predict() returns a list of all individual variables used in the model. When interactions = TRUE, it searches for higher-order combinations of variables (see model_list_higher_order_variables()).
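The distinction between individual variables and higher-order combinations can be approximated from a model formula in base R. This is not broom.helpers' implementation (which relies on model_list_higher_order_variables() and maps terms such as poly(Petal.Width, 2) back to their variables); individual_variables() and highest_order_terms() are hypothetical helpers that mimic the "no_interaction" and "auto" behaviour for the formula of mod3 in the examples. highest_order_terms() works on term labels only, so it keeps a transformed term as written.

```r
f <- Survived ~ Sex * Age + Class

# Every variable on the right-hand side (analogous to interactions = FALSE)
individual_variables <- function(formula) all.vars(formula[[3]])

# Terms not contained in a larger interaction (analogous to interactions = TRUE)
highest_order_terms <- function(formula) {
  labs <- attr(stats::terms(formula), "term.labels")
  vars <- strsplit(labs, ":", fixed = TRUE)
  # A term is nested if all its variables appear in a strictly larger term
  is_nested <- function(i) {
    any(vapply(seq_along(vars), function(j) {
      length(vars[[j]]) > length(vars[[i]]) && all(vars[[i]] %in% vars[[j]])
    }, logical(1)))
  }
  labs[!vapply(seq_along(vars), is_nested, logical(1))]
}

individual_variables(f) # returns c("Sex", "Age", "Class")
highest_order_terms(f)  # returns c("Class", "Sex:Age")
```

Both results follow the order of the rows printed by tidy_marginal_predictions(mod3, "no_interaction") and tidy_marginal_predictions(mod3).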

variables_list's default value, "auto", calls variables_to_predict(interactions = TRUE) while "no_interaction" is a shortcut for variables_to_predict(interactions = FALSE).

You can also provide custom specifications (see examples).
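When one element of variables_list names several variables, as in list(list(Species = unique, Petal.Width = 1:3)) in the examples, the result contains one row per combination of values. The base-R sketch below builds that grid for the iris model with expand.grid(), an illustration rather than how marginaleffects constructs its counterfactual grid. Because mod2 has no other predictors, fixing both variables makes every row of the counterfactual data identical, so each averaged prediction reduces to a single model prediction.

```r
mod2 <- lm(Petal.Length ~ poly(Petal.Width, 2) + Species, data = iris)

# One row per Species x Petal.Width combination; expand.grid() varies its
# first argument fastest, so Petal.Width changes within each Species
grid <- expand.grid(
  Petal.Width = 1:3,
  Species = levels(iris$Species)
)
grid$estimate <- stats::predict(mod2, newdata = grid)
grid$term <- paste(grid$Species, "*", grid$Petal.Width)
grid[, c("term", "estimate")]
```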

plot_marginal_predictions() works in a similar way and returns a list of plots that can be combined with patchwork::wrap_plots() (see examples).

For more information, see vignette("marginal_tidiers", "broom.helpers").

Examples

# example code

# \donttest{
# Average Marginal Predictions
df <- Titanic |>
  dplyr::as_tibble() |>
  tidyr::uncount(n) |>
  dplyr::mutate(Survived = factor(Survived, c("No", "Yes")))
mod <- glm(
  Survived ~ Class + Age + Sex,
  data = df, family = binomial
)
tidy_marginal_predictions(mod)
#>   variable   term  estimate   std.error statistic       p.value   s.value
#> 1    Class    1st 0.5181374 0.027913631  18.56217  6.503000e-77 253.08736
#> 2    Class    2nd 0.3229724 0.024270457  13.30722  2.101515e-40 131.80569
#> 3    Class    3rd 0.2115172 0.013107421  16.13721  1.397101e-58 192.18939
#> 4    Class   Crew 0.3504908 0.015030218  23.31908 2.839357e-120 397.12581
#> 5      Age  Adult 0.3140462 0.008827755  35.57486 3.429334e-277 918.39615
#> 6      Age  Child 0.5112612 0.047737204  10.70991  9.145005e-27  86.49907
#> 7      Sex Female 0.7255784 0.021732728  33.38644 2.157304e-244 809.44123
#> 8      Sex   Male 0.2182251 0.009932243  21.97138 5.410116e-107 353.01065
#>    conf.low conf.high
#> 1 0.4634277 0.5728471
#> 2 0.2754031 0.3705416
#> 3 0.1858272 0.2372073
#> 4 0.3210322 0.3799495
#> 5 0.2967441 0.3313483
#> 6 0.4176980 0.6048244
#> 7 0.6829830 0.7681737
#> 8 0.1987582 0.2376919
tidy_plus_plus(mod, tidy_fun = tidy_marginal_predictions)
#> # A tibble: 8 × 20
#>   term   variable var_label var_class var_type    var_nlevels contrasts      
#>   <chr>  <chr>    <chr>     <chr>     <chr>             <int> <chr>          
#> 1 1st    Class    Class     character categorical           4 contr.treatment
#> 2 2nd    Class    Class     character categorical           4 contr.treatment
#> 3 3rd    Class    Class     character categorical           4 contr.treatment
#> 4 Crew   Class    Class     character categorical           4 contr.treatment
#> 5 Adult  Age      Age       character dichotomous           2 contr.treatment
#> 6 Child  Age      Age       character dichotomous           2 contr.treatment
#> 7 Female Sex      Sex       character dichotomous           2 contr.treatment
#> 8 Male   Sex      Sex       character dichotomous           2 contr.treatment
#> # ℹ 13 more variables: contrasts_type <chr>, reference_row <lgl>, label <chr>,
#> #   n_obs <dbl>, n_event <dbl>, estimate <dbl>, std.error <dbl>,
#> #   statistic <dbl>, p.value <dbl>, s.value <dbl>, conf.low <dbl>,
#> #   conf.high <dbl>, label_attr <chr>
if (require("patchwork")) {
  # Inside an if block only the last value is auto-printed, so print() the first plot
  print(plot_marginal_predictions(mod) |> patchwork::wrap_plots())
  plot_marginal_predictions(mod) |>
    patchwork::wrap_plots() &
    ggplot2::scale_y_continuous(limits = c(0, 1), labels = scales::percent)
}
#> Loading required package: patchwork


mod2 <- lm(Petal.Length ~ poly(Petal.Width, 2) + Species, data = iris)
tidy_marginal_predictions(mod2)
#>      variable       term estimate  std.error statistic       p.value   s.value
#> 1 Petal.Width        0.1 2.302462 0.28351129  8.121235  4.614622e-16  50.94464
#> 2 Petal.Width        0.3 2.626585 0.20194705 13.006308  1.126554e-38 126.06135
#> 3 Petal.Width        1.3 3.997316 0.09875586 40.476747  0.000000e+00       Inf
#> 4 Petal.Width        1.8 4.526502 0.14311589 31.628228 1.511205e-219 726.90655
#> 5 Petal.Width        2.5 5.092442 0.19970017 25.500436 1.949411e-143 474.07268
#> 6     Species     setosa 2.681553 0.22806208 11.757994  6.424222e-32 103.61818
#> 7     Species versicolor 3.998581 0.10599505 37.724219  0.000000e+00       Inf
#> 8     Species  virginica 4.593867 0.15720110 29.222868 9.935198e-188 621.20993
#>   conf.low conf.high
#> 1 1.746790  2.858134
#> 2 2.230776  3.022394
#> 3 3.803758  4.190874
#> 4 4.246000  4.807004
#> 5 4.701036  5.483847
#> 6 2.234559  3.128546
#> 7 3.790834  4.206327
#> 8 4.285758  4.901975
if (require("patchwork")) {
  plot_marginal_predictions(mod2) |> patchwork::wrap_plots()
}

tidy_marginal_predictions(
  mod2,
  variables_list = variables_to_predict(mod2, continuous = "threenum")
)
#>      variable              term estimate  std.error statistic       p.value
#> 1 Petal.Width 0.437095664372987 2.839141 0.15360160  18.48380  2.788266e-76
#> 2 Petal.Width  1.19933333333333 3.878182 0.08698259  44.58572  0.000000e+00
#> 3 Petal.Width  1.96157100229368 4.675245 0.15291920  30.57331 2.771682e-205
#> 4     Species            setosa 2.681553 0.22806208  11.75799  6.424222e-32
#> 5     Species        versicolor 3.998581 0.10599505  37.72422  0.000000e+00
#> 6     Species         virginica 4.593867 0.15720110  29.22287 9.935198e-188
#>    s.value conf.low conf.high
#> 1 250.9872 2.538088  3.140195
#> 2      Inf 3.707699  4.048664
#> 3 679.5245 4.375529  4.974962
#> 4 103.6182 2.234559  3.128546
#> 5      Inf 3.790834  4.206327
#> 6 621.2099 4.285758  4.901975
tidy_marginal_predictions(
  mod2,
  variables_list = list(
    list(Petal.Width = c(0, 1, 2, 3)),
    list(Species = unique)
  )
)
#>      variable       term estimate  std.error statistic       p.value   s.value
#> 1 Petal.Width          0 2.134153 0.32883856  6.489972  8.585239e-11  33.43935
#> 2 Petal.Width          1 3.629827 0.06654813 54.544386  0.000000e+00       Inf
#> 3 Petal.Width          2 4.709023 0.15519349 30.342916 3.115795e-202 669.38988
#> 4 Petal.Width          3 5.371741 0.31274119 17.176313  3.995104e-66 217.24902
#> 5     Species     setosa 2.681553 0.22806208 11.757994  6.424222e-32 103.61818
#> 6     Species versicolor 3.998581 0.10599505 37.724219  0.000000e+00       Inf
#> 7     Species  virginica 4.593867 0.15720110 29.222868 9.935198e-188 621.20993
#>   conf.low conf.high
#> 1 1.489641  2.778665
#> 2 3.499395  3.760259
#> 3 4.404849  5.013197
#> 4 4.758779  5.984702
#> 5 2.234559  3.128546
#> 6 3.790834  4.206327
#> 7 4.285758  4.901975
tidy_marginal_predictions(
  mod2,
  variables_list = list(list(Species = unique, Petal.Width = 1:3))
)
#>              variable           term estimate  std.error  statistic
#> 1 Species:Petal.Width     setosa * 1 2.553380 0.25261804  10.107670
#> 2 Species:Petal.Width     setosa * 2 3.632576 0.37561389   9.671036
#> 3 Species:Petal.Width     setosa * 3 4.295293 0.42138340  10.193314
#> 4 Species:Petal.Width versicolor * 1 3.870408 0.08240088  46.970464
#> 5 Species:Petal.Width versicolor * 2 4.949603 0.11523718  42.951446
#> 6 Species:Petal.Width versicolor * 3 5.612321 0.35269179  15.912820
#> 7 Species:Petal.Width  virginica * 1 4.465694 0.16675860  26.779393
#> 8 Species:Petal.Width  virginica * 2 5.544890 0.05494211 100.922404
#> 9 Species:Petal.Width  virginica * 3 6.207607 0.27675590  22.429901
#>         p.value   s.value conf.low conf.high
#> 1  5.108488e-24  77.37338 2.058257  3.048502
#> 2  4.003040e-22  71.08132 2.896386  4.368765
#> 3  2.123982e-24  78.63950 3.469397  5.121190
#> 4  0.000000e+00       Inf 3.708905  4.031910
#> 5  0.000000e+00       Inf 4.723743  5.175464
#> 6  5.163415e-57 186.98158 4.921058  6.303584
#> 7 5.616443e-158 522.37498 4.138853  4.792535
#> 8  0.000000e+00       Inf 5.437205  5.652574
#> 9 2.010821e-111 367.72623 5.665176  6.750039

# Model with interactions
mod3 <- glm(
  Survived ~ Sex * Age + Class,
  data = df, family = binomial
)
tidy_marginal_predictions(mod3)
#>   variable           term  estimate   std.error statistic       p.value
#> 1    Class            1st 0.5122895 0.027808505 18.422043  8.743951e-76
#> 2    Class            2nd 0.3202040 0.023739078 13.488477  1.828376e-41
#> 3    Class            3rd 0.2095767 0.012826120 16.339834  5.139230e-60
#> 4    Class           Crew 0.3587744 0.014931458 24.028090 1.414851e-127
#> 5  Sex:Age Female * Adult 0.7419417 0.021989818 33.740241 1.486474e-249
#> 6  Sex:Age Female * Child 0.7217706 0.059574157 12.115498  8.743100e-34
#> 7  Sex:Age   Male * Adult 0.2032339 0.009817932 20.700274  3.444077e-95
#> 8  Sex:Age   Male * Child 0.5723894 0.060149442  9.516122  1.797624e-21
#>    s.value  conf.low conf.high
#> 1 249.3382 0.4577858 0.5667931
#> 2 135.3285 0.2736763 0.3667318
#> 3 196.9541 0.1844379 0.2347154
#> 4 421.3842 0.3295093 0.3880395
#> 5 826.5882 0.6988425 0.7850410
#> 6 109.8174 0.6050074 0.8385338
#> 7 313.7991 0.1839911 0.2224767
#> 8  68.9144 0.4544987 0.6902802
tidy_marginal_predictions(mod3, "no_interaction")
#>   variable   term  estimate   std.error statistic       p.value  s.value
#> 1      Sex Female 0.7407251 0.021107074  35.09369 8.413373e-270 893.8479
#> 2      Sex   Male 0.2191225 0.009889623  22.15681 8.967719e-109 358.9254
#> 3      Age  Adult 0.3142777 0.008706739  36.09591 2.629104e-285 945.3549
#> 4      Age  Child 0.6033792 0.050063951  12.05217  1.889135e-33 108.7059
#> 5    Class    1st 0.5122895 0.027808505  18.42204  8.743951e-76 249.3382
#> 6    Class    2nd 0.3202040 0.023739078  13.48848  1.828376e-41 135.3285
#> 7    Class    3rd 0.2095767 0.012826120  16.33983  5.139230e-60 196.9541
#> 8    Class   Crew 0.3587744 0.014931458  24.02809 1.414851e-127 421.3842
#>    conf.low conf.high
#> 1 0.6993560 0.7820942
#> 2 0.1997392 0.2385058
#> 3 0.2972128 0.3313426
#> 4 0.5052556 0.7015027
#> 5 0.4577858 0.5667931
#> 6 0.2736763 0.3667318
#> 7 0.1844379 0.2347154
#> 8 0.3295093 0.3880395
if (require("patchwork")) {
  # Inside an if block only the last value is auto-printed, so print() the first plot
  print(
    plot_marginal_predictions(mod3) |>
      patchwork::wrap_plots()
  )
  plot_marginal_predictions(mod3, "no_interaction") |>
    patchwork::wrap_plots()
}

tidy_marginal_predictions(
  mod3,
  variables_list = list(
    list(Class = unique, Sex = "Female"),
    list(Age = unique)
  )
)
#>    variable          term  estimate   std.error statistic       p.value
#> 1 Class:Sex  1st * Female 0.8980680 0.015940375  56.33920  0.000000e+00
#> 2 Class:Sex  2nd * Female 0.7580905 0.030941811  24.50052 1.458306e-132
#> 3 Class:Sex  3rd * Female 0.5904013 0.031062754  19.00673  1.500218e-80
#> 4 Class:Sex Crew * Female 0.7978123 0.026646079  29.94108 5.749545e-197
#> 5       Age         Adult 0.3142777 0.008706739  36.09591 2.629104e-285
#> 6       Age         Child 0.6033792 0.050063951  12.05217  1.889135e-33
#>    s.value  conf.low conf.high
#> 1      Inf 0.8668255 0.9293106
#> 2 437.9502 0.6974457 0.8187354
#> 3 265.1691 0.5295195 0.6512832
#> 4 651.8964 0.7455869 0.8500376
#> 5 945.3549 0.2972128 0.3313426
#> 6 108.7059 0.5052556 0.7015027

# Marginal Predictions at the Mean
tidy_marginal_predictions(mod, newdata = "mean")
#>   variable   term  estimate  std.error statistic       p.value   s.value
#> 1    Class    1st 0.4070382 0.03286740 12.384253  3.180314e-35 114.59831
#> 2    Class    2nd 0.1987193 0.02474435  8.030897  9.676261e-16  49.87640
#> 3    Class    3rd 0.1039594 0.01182240  8.793426  1.450667e-18  59.25799
#> 4    Class   Crew 0.2254997 0.01405834 16.040279  6.685371e-58 189.93082
#> 5      Age  Adult 0.2254997 0.01405834 16.040279  6.685371e-58 189.93082
#> 6      Age  Child 0.4570172 0.06370387  7.174088  7.279073e-13  40.32131
#> 7      Sex Female 0.7660538 0.02841756 26.957057 4.714993e-160 529.27124
#> 8      Sex   Male 0.2254997 0.01405834 16.040279  6.685371e-58 189.93082
#>     conf.low conf.high
#> 1 0.34261928 0.4714571
#> 2 0.15022129 0.2472174
#> 3 0.08078793 0.1271309
#> 4 0.19794588 0.2530536
#> 5 0.19794588 0.2530536
#> 6 0.33215989 0.5818745
#> 7 0.71035641 0.8217512
#> 8 0.19794588 0.2530536
if (require("patchwork")) {
  plot_marginal_predictions(mod, newdata = "mean") |>
    patchwork::wrap_plots()
}

# }