Model fit checks
library(MASS)
library(brms)
library(ggplot2)
Let’s assume that our data \(y\) is real valued and that our predictors are “Age” and a group number.
y <- c(rnorm(15, mean = 1, sd = 2), rnorm(15, mean = 10, sd = 2.5))
dat1 <- data.frame(y)
dat1$age <- rnorm(30, mean = 40, sd = 10)
dat1$grp <- as.factor(c(rep(1,15),rep(2,15)))
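Before fitting anything, it is worth sanity-checking the simulated data. The standalone sketch below repeats the simulation (with a seed added for reproducibility, which the original code does not set) and confirms that the two group means are well separated:

```r
# Repeat the simulation with a seed so the check is reproducible
set.seed(1)  # hypothetical seed, not in the original code
y <- c(rnorm(15, mean = 1, sd = 2), rnorm(15, mean = 10, sd = 2.5))
dat1 <- data.frame(
  y   = y,
  age = rnorm(30, mean = 40, sd = 10),
  grp = as.factor(c(rep(1, 15), rep(2, 15)))
)
# Group-wise means of y: should be near 1 and 10
tapply(dat1$y, dat1$grp, mean)
```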
head(dat1)
We would like to model this data according to the real-valued predictor age, as well as the categorical predictor grp. Let’s imagine we want to create a simple model of the form \(y \sim N(\mu, \sigma)\) where \(\mu = a_i + b \cdot Age\), and \(a_i\) is the intercept for the i-th group. That can be easily encoded in the brms formula:
model_formula <- bf(y ~ 0 + grp + age)
fit0 <- brm(model_formula,
data = dat1, family = gaussian())
Posterior summary
We can examine the posterior samples with posterior_summary.
posterior_summary(fit0)
We can also plot the posterior samples for each parameter.
plot(fit0)
Posterior predictive checks
pp_check is a handy function for visualizing the predictions of the model.
pp_check(fit0)
By default it outputs an overlaid density plot. Another handy visualization is the predictive interval for each datapoint against its true value.
pp_check(fit0,
type = 'intervals')
Call pp_check with an invalid type (for example pp_check(fit0, type = 'xyz')) to get a full list of the available types.
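Mechanically, the default dens_overlay check compares the density of the observed data with the densities of several replicated datasets drawn from the posterior predictive distribution. The following base-R sketch, with made-up parameter draws rather than actual brms output, shows the idea:

```r
set.seed(2)  # hypothetical seed for reproducibility
y <- c(rnorm(15, 1, 2), rnorm(15, 10, 2.5))  # "observed" data
# made-up posterior draws of (mu, sigma) for a single-Gaussian model
mu    <- rnorm(50, mean(y), 0.5)
sigma <- runif(50, 4, 6)
# one replicated dataset per posterior draw (rows = draws, cols = observations)
yrep <- t(sapply(seq_along(mu), function(i) rnorm(length(y), mu[i], sigma[i])))
# overlay: observed density (thick) vs replicated densities (thin)
plot(density(y), lwd = 3, main = "dens_overlay by hand")
for (i in seq_len(nrow(yrep))) lines(density(yrep[i, ]), col = rgb(0, 0, 1, 0.2))
```

A single Gaussian cannot reproduce the two bumps of the bimodal data, which is exactly the kind of mismatch this plot makes visible.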
Let’s assume we have 10 new datapoints, together with their true values,
age <- rnorm(10, mean = 40, sd = 10)
dat_new <- data.frame(age)
dat_new$grp <- as.factor(c(rep(1,5),rep(2,5)))
dat_new$y <- c(rnorm(5, mean = 1, sd = 2), rnorm(5, mean = 10, sd = 2.5))
We can get predictions from our model:
pp_check(fit0,
type='intervals',
newdata = dat_new)
Predictions
Each data point gets a predictive distribution that is implied by the posterior samples. Let’s say we have one data point
y0 <- expand.grid(age = 30,
                  grp = 2)
Our model predicts the following value:
predict(fit0, newdata = y0)
To see the full set of posterior predictive draws, type
predict(fit0, newdata = y0, summary=FALSE)
# or equivalently
posterior_predict(fit0, newdata = y0)
Recall that the model is \(y \sim N(\mu, \sigma)\), and hence the predictions incorporate the uncertainty that stems from the variance \(\sigma\). There is a special function, fitted, for predicting the average \(\mu = a_i + b \cdot Age\):
fitted(fit0, newdata = y0)
Note how the estimate is very close to the value provided by predict. The difference lies in the predicted percentile intervals, which are much narrower for the fitted function, as they should be.
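The distinction is easy to see with simulated posterior draws (a standalone base-R sketch with made-up numbers, not brms output): the mean \(\mu = a_i + b \cdot Age\) varies only with the parameter draws, while a predicted \(y\) adds the observation noise \(\sigma\) on top, so its interval is wider.

```r
set.seed(3)  # hypothetical seed for reproducibility
n_draws <- 4000
# made-up posterior draws of intercept a, slope b, and residual sd sigma
a     <- rnorm(n_draws, 10, 0.8)
b     <- rnorm(n_draws, 0.02, 0.01)
sigma <- abs(rnorm(n_draws, 2.3, 0.2))
age0 <- 30
mu_draws <- a + b * age0                     # what fitted() summarizes
y_draws  <- rnorm(n_draws, mu_draws, sigma)  # what predict() summarizes
# same center, but the predictive spread is wider
c(fitted_sd = sd(mu_draws), predict_sd = sd(y_draws))
```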
Finally, there is marginal_effects, which visualizes the effect of the predictors on the outcome \(y\). For example, we can see the effect of the group as follows:
marginal_effects(fit0, effects = 'grp')
There are more options to visualize interactions or other details:
marginal_effects(fit0, effects = 'grp:age')
Model comparison
Let’s say we want to examine whether ignoring the age effect makes for a better model. We would model it as follows:
model_formula2 <- bf(y ~ 0 + grp)
fit2 <- brm(model_formula2,
data = dat1, family = gaussian())
One method to compare the models is approximate leave-one-out cross-validation (LOO). On the LOOIC scale, smaller values are better (equivalently, a higher elpd_loo is better). Hence we see that model 2 wins here.
loo(fit0, fit2)
Another way to compare models is to compute the weights they would receive if they were to be combined:
loo_model_weights(fit0, fit2)
Another method is WAIC; again, smaller is better.
waic(fit0, fit2)
There is also a handy method for k-fold cross-validation.
kfold(fit2, K = 3)
Model Averaging
If we wanted to extract posterior samples from a weighted average of the two models we can use
posterior_average(fit0, fit2, weights = "loo")
And to retrieve averaged predictions directly:
pp_average(fit0, fit2, method = "fitted", newdata = y0)
References
- See documentation of pp_check
- For examples of bayesplot, the package behind pp_check
- See documentation of predict
- See documentation of fitted
- See documentation of marginal_effects
- See documentation of loo
- See documentation of kfold
- See documentation of posterior_average
- See documentation of pp_average