Grant McDermott developed parttree, a new R package I wish I had thought of:
```r
# install.packages("remotes")
remotes::install_github("grantmcdermott/parttree")
```
Using the familiar ggplot2 syntax, we can simply add decision tree boundaries to a plot of our data.
In this example from his GitHub page, Grant trains a decision tree on the famous Titanic data using the
parsnip package, and then visualizes the resulting partitions (decision boundaries) with the simple function geom_parttree():
```r
library(parsnip)
library(titanic) ## Just for a different data set
library(parttree)
library(ggplot2)
set.seed(123) ## For consistent jitter

titanic_train$Survived = as.factor(titanic_train$Survived)

## Build our tree using parsnip (but with rpart as the model engine)
ti_tree =
  decision_tree() %>%
  set_engine("rpart") %>%
  set_mode("classification") %>%
  fit(Survived ~ Pclass + Age, data = titanic_train)

## Plot the data and model partitions
titanic_train %>%
  ggplot(aes(x=Pclass, y=Age)) +
  geom_jitter(aes(col=Survived), alpha=0.7) +
  geom_parttree(data = ti_tree, aes(fill=Survived), alpha = 0.1) +
  theme_minimal()
```
This visualization shows exactly where the trained decision tree predicts that Titanic passengers survived (blue regions) or did not (red regions), based on their
passenger class (Pclass) and age (Age).
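The shaded regions are simply the tree's predicted class over the Pclass × Age space. As a rough sanity check, here is a minimal sketch that refits the same parsnip model and predicts on a coarse grid of values, so you can read the decision for each combination as a table. The grid spacing (every 10 years of age) and the `grid`/`pred` names are arbitrary choices for illustration, not part of parttree.

```r
library(parsnip)
library(titanic)

titanic_train$Survived = as.factor(titanic_train$Survived)

## Same model as above: a classification tree fit via parsnip with the rpart engine
ti_tree =
  decision_tree() %>%
  set_engine("rpart") %>%
  set_mode("classification") %>%
  fit(Survived ~ Pclass + Age, data = titanic_train)

## Hypothetical evaluation grid: every passenger class crossed with ages 0 to 80
grid = expand.grid(
  Pclass = 1:3,
  Age = seq(0, 80, by = 10)
)

## predict() on a parsnip model_fit returns a tibble with a .pred_class column
grid$pred = predict(ti_tree, new_data = grid)$.pred_class

## Each row is one point in the plot; "1" falls in a blue region, "0" in a red one
print(grid)
```

Each printed row should agree with the colour of the region that the corresponding (Pclass, Age) point falls into in the parttree plot.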
This will be super helpful if you need to explain to yourself, your team, or your stakeholders how your model works. Currently, only
rpart decision trees are supported, but I very much hope that Grant keeps extending this functionality!
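Because the parsnip model above is just a wrapper around rpart, you should also be able to fit the tree with rpart directly and hand that object to geom_parttree(). The sketch below assumes the parttree version you installed accepts a plain `rpart` object (which is the model type described as supported), and it prints the fitted splits as text, which can be a handy companion to the plot when explaining the model. The name `rp_tree` is just an illustrative choice.

```r
library(rpart)
library(ggplot2)
library(parttree)
library(titanic)

set.seed(123) ## For consistent jitter
titanic_train$Survived = as.factor(titanic_train$Survived)

## Fit the same classification tree directly with rpart (no parsnip wrapper)
rp_tree = rpart(Survived ~ Pclass + Age, data = titanic_train, method = "class")

## Text summary of each split and the predicted class at every node
print(rp_tree)

## Same plot as before, now driven by the raw rpart object
ggplot(titanic_train, aes(x = Pclass, y = Age)) +
  geom_jitter(aes(col = Survived), alpha = 0.7) +
  geom_parttree(data = rp_tree, aes(fill = Survived), alpha = 0.1) +
  theme_minimal()
```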