#================================================================================================
# Load libraries - suppress messages
#
suppressMessages(library(tidyverse))
suppressMessages(library(tidymodels))
suppressMessages(library(rattle))
#================================================================================================
# Load the Titanic training data and convert Sex and Embarked to factors
#
# Read the raw training records; show_col_types = FALSE keeps read_csv() from
# printing the guessed column specification.
titanic_train <- read_csv("titanic_train.csv", show_col_types = FALSE)

# Convert Sex to a factor and swap the single-letter Embarked port codes for
# descriptive labels before factoring them. Rows with no port recorded are
# labelled "missing"; case_when() has no fallback branch here, so any other
# code that matches none of the conditions ends up as NA.
titanic_train <- mutate(
  titanic_train,
  Sex = factor(Sex),
  Embarked = factor(
    case_when(
      Embarked == "C" ~ "Cherbourg",
      Embarked == "Q" ~ "Queenstown",
      Embarked == "S" ~ "Southampton",
      is.na(Embarked) ~ "missing"
    )
  )
)
#================================================================================================
# Craft the recipe - recipes package
#
# Start from the model formula: Survived is the outcome; Sex, Pclass and
# Embarked are the predictors.
titanic_recipe <- recipe(
  Survived ~ Sex + Pclass + Embarked,
  data = titanic_train
)

# Turn the numeric outcome into a labelled factor. step_num2factor() uses the
# (transformed) value as a position in `levels`, so the value is shifted up by
# one first -- presumably Survived is coded 0/1; verify against the data.
# NOTE(review): this step is not marked skip = TRUE, so it also runs when the
# recipe is applied to new data; data without a Survived column may fail at
# that point -- confirm how predictions are made downstream.
titanic_recipe <- step_num2factor(
  titanic_recipe,
  Survived,
  transform = function(x) {
    x + 1
  },
  levels = c("perished", "survived")
)

# Label the passenger class; the raw value is used directly as the position
# in `levels` (assumes Pclass is coded 1/2/3 -- TODO confirm).
titanic_recipe <- step_num2factor(
  titanic_recipe,
  Pclass,
  levels = c("first", "second", "third")
)
#================================================================================================
# Specify the algorithm - parsnip package
#
# Specify a single CART decision tree with no pre-pruning and a value of 14 for
# the min_n hyperparameter.
#
# The hyperparameters are passed explicitly; a bare decision_tree() would fall
# back to the rpart engine defaults and silently ignore the specification
# above.
#   cost_complexity = 0 : no complexity penalty, so no split is rejected for
#                         improving the fit too little ("no pre-pruning")
#   min_n = 14          : minimum number of observations a node must hold
#                         before a further split is attempted
titanic_model <- decision_tree(cost_complexity = 0, min_n = 14) %>%
  set_engine("rpart") %>%
  set_mode("classification")
#================================================================================================
# Set up the workflow
#
# Bundle the preprocessing recipe and the model specification into a single
# workflow object, built up one component at a time.
titanic_workflow <- workflow()
titanic_workflow <- add_recipe(titanic_workflow, titanic_recipe)
titanic_workflow <- add_model(titanic_workflow, titanic_model)
#================================================================================================
# Fit the model to all the Titanic training data
#
# Train the workflow (recipe + model) on the complete training set; no rows
# are held out here.
titanic_fit <- fit(titanic_workflow, data = titanic_train)
#================================================================================================
# Visualize the tree by extracting the trained model
#
# Pull the fitted parsnip model out of the workflow; its $fit element is the
# underlying engine object that the plotting function is given.
titanic_tree <- extract_fit_parsnip(titanic_fit)
# Write the visualization to a file on disk. png() cannot open its file when
# the target directory is absent, so create the directory first (a no-op,
# without a warning, when it already exists).
dir.create("output", showWarnings = FALSE, recursive = TRUE)
png(filename = "output/tree.png", height = 750, width = 750)
# Draw the tree, and close the device opened by the png() function even if
# plotting fails, so the file is flushed and no stray device stays open.
tryCatch(
  fancyRpartPlot(titanic_tree$fit, sub = NULL),
  finally = dev.off()
)