Discover the essentials of decision trees in machine learning, from understanding overfitting and hyperparameter tuning to a step-by-step guide on sepsis survival prediction using Python. Learn how to improve model performance and interpretability with practical examples and code snippets.
Table of Contents:
- Introduction to Decision Trees in Machine Learning
- Overview of Decision Trees and Their Importance
- Understanding Decision Tree Terminology
- Nodes, Splits, and Leaves Explained
- The Problem of Overfitting in Decision Trees
- How Overfitting Affects Model Performance
- Hyperparameter Tuning for Decision Trees
- Techniques to Prevent Overfitting in Decision Trees
- Real-World Example: Sepsis Survival Prediction
- Applying Decision Trees to Predict Sepsis Outcomes
- Using Scikit-Learn for Decision Tree Implementation
- Step-by-Step Guide to Building a Decision Tree in Python
- Balancing Datasets with SMOTE for Better Decision Trees
- How to Handle Imbalanced Data for Accurate Predictions
- Evaluating Decision Tree Performance with Confusion Matrix
- Understanding True Positives, True Negatives, and Model Accuracy
- Optimizing Decision Trees with Precision, Recall, and F1 Score
- How to Choose the Right Metrics for Your Model
- Conclusion and Next Steps in Decision Tree Modeling
- Improving Decision Tree Models for Better Generalizability
Introduction to Decision Trees in Machine Learning
Overview of Decision Trees and Their Importance
So let's look at a concrete example. Let's say we want to predict whether I'm going to drink tea or coffee, and to make that prediction, we can use a decision tree like the one shown here. The way this works is, we start at the top of the decision tree and answer a series of yes or no questions.
Understanding Decision Tree Terminology
Nodes, Splits, and Leaves Explained
These rectangles that we see throughout the decision tree are called nodes, and we have different types of nodes. For example, the node that sits at the top of the decision tree is called the root node.
Over here in green, we have what is called a leaf node; these are nodes that don't ask any yes or no questions and where we assign our prediction. Then we have splitting nodes, which ask yes or no questions but are not the root node. I personally find the following a much more intuitive way to think about things. Here we have the example from before, where we have two pieces of information: time of day, which we're plotting on the x-axis, and the amount of sleep I got last night, which we're plotting on the y-axis.
Another way we can represent what the decision tree is doing is by partitioning the predictor space, meaning the space defined by our two predictor variables, time of day and hours of sleep, into different sections and assigning a label to each section. For splitting on 4 pm and 6 hours of sleep, that looks something like this: here's 4 pm, so we draw a line here, and here's six hours of sleep, so we draw another line here. Now we just look at the leaf node for each of these splits and assign a label to each section. Intuitively, this is all a decision tree is doing: taking the predictor space, splitting it into different sections, and assigning a label to each section. So now we have a basic understanding of what decision trees are and an intuition for how they work.
The Problem of Overfitting in Decision Trees
A natural question is: how can we bring this into practice? Namely, how can I use a decision tree in the real world? We can use decision trees in practice by developing them from data; put another way, we can learn decision tree structure from data.
How Overfitting Affects Model Performance
I'm going to walk through an example here just to give you a qualitative sense of how this works. I'll start with the disclaimer that there are many ways to grow decision trees from data, but what I'm going to describe here is a widely used methodology. Before getting into it, I need to introduce the concept of Gini impurity.
I'm putting the equation here for completeness, for those who think in terms of math. The Gini impurity of a sample is G = 1 − Σᵢ pᵢ², where pᵢ corresponds to the probability of the i-th class. Looking at this example, we have two possible classes, tea or coffee, so the Gini impurity of this sample would simply be 1 minus the probability of tea squared, minus the probability of coffee squared. If that doesn't quite make sense, no worries: we can think of the Gini impurity in terms of its extremes, namely its minimum and maximum value. Visualizing this, we have minimum impurity whenever every class in our sample is identical, so either every class in the sample is tea or every class is coffee. On the other end of the spectrum, we have maximum impurity when each class is equally likely. For those of you familiar with information theory or the concept of entropy, you'll notice that the Gini impurity behaves very much like information entropy. Okay, so you might be asking: why are we talking about this Gini impurity? I'm glad you asked, because we can use the Gini impurity to learn decision tree structure from data. The goal when growing a decision tree is to use our predictor variables to split our data such that the overall Gini impurity is minimized. So, essentially, growing a decision tree is an optimization problem.
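To make this concrete, here's a minimal sketch of Gini impurity as a Python function; the tea/coffee labels are just illustrative.

```python
import numpy as np

def gini_impurity(labels):
    """Gini impurity of a sample: G = 1 - sum_i p_i^2."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()  # class probabilities p_i
    return 1 - np.sum(p ** 2)

# The two extremes for a two-class sample:
print(gini_impurity(["tea", "tea", "tea"]))               # 0.0, pure sample (minimum)
print(gini_impurity(["tea", "tea", "coffee", "coffee"]))  # 0.5, 50/50 mix (maximum)
```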
Real-World Example: Sepsis Survival Prediction
Applying Decision Trees to Predict Sepsis Outcomes
For example, we take the first value, 7:21 am, and split our data based on time being less than or equal to 7:21 am; the resulting split would look like this. Here we have a sample with just one record, this first one, and then we have everything else in the other sample. We can then evaluate this split option by computing its Gini impurity.
What I mean by that is we calculate the Gini impurity of one sample and the Gini impurity of the other, and then take their weighted average. Here the first sample has just one class, so it actually has the minimum impurity of zero, while the other is a bit of a mix, so it's going to be pretty close to the maximum impurity. We then weight the average by the number of records in each sample: the first sample gets a very low weight because it consists of just one record, and the other gets a very high weight because it contains a lot of records. So this split gives us a number corresponding to its Gini impurity, and then we just continue the process: we split at 8:47 am and calculate the average Gini impurity, then at 9:30 am, and so on and so forth for every possible value of time in the data set. Then we do the same thing for amount of sleep.
We look at our first option, which is 5.5 hours, and get something like this; then 5.9, then 5.95, and so on for every single value we observe in our data set. Let's say that after calculating all these average Gini impurity values, we discover that the split option of sleeping less than or equal to 6.75 hours is optimal: it gives us the smallest Gini impurity of all the split options we observe in our data set. Now notice that this node over here is pure; it has minimum impurity, so it doesn't really make sense to split that sample further. But on the left-hand side, we still have some impurity in this node, and we can do additional splits.
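Here's a minimal sketch of the split-evaluation step just described, reusing the gini_impurity helper from above; the data values are illustrative.

```python
def split_gini(values, labels, threshold):
    """Weighted-average Gini impurity of splitting on value <= threshold."""
    values, labels = np.asarray(values), np.asarray(labels)
    left, right = labels[values <= threshold], labels[values > threshold]
    n = len(labels)
    # Weight each side's impurity by the number of records it contains
    return (len(left) / n) * gini_impurity(left) + (len(right) / n) * gini_impurity(right)

# Try every observed value as a threshold and keep the best one
hours_of_sleep = [5.5, 5.9, 5.95, 6.75, 7.5, 8.0]
drinks = ["coffee", "coffee", "coffee", "coffee", "tea", "tea"]
best = min(set(hours_of_sleep), key=lambda t: split_gini(hours_of_sleep, drinks, t))
print(best)  # 6.75 for this toy data: a perfect split
```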
Now we have a smaller data set, so instead of starting with all the data in our table, we just have the subset shown in this smaller table, and we repeat the exact same process as before, evaluating every split option. Let's say that after evaluating all the split options, we discover that splitting on time less than or equal to 1:45 pm gives us the smallest Gini impurity. Notice again that we have a pure node here, but we still have some impurity over here.
We could just keep splitting the data until every single node is pure, meaning every single node has only one class in it and a Gini impurity equal to zero. At first this might sound great; you might think, oh, we can have a perfect classifier, an absolutely perfect decision tree.
However, this is not such a great idea, because it brings up a very well-known problem in machine learning: the overfitting problem. Overfitting is when you learn a machine learning model from some data set but the model becomes over-optimized on the data it was trained on, so that when you try to apply the model to new data it has never seen before, you find it is actually very inaccurate. So instead of allowing our decision tree to grow without end and become hyper-optimized to our data set, we can control its growth.
Hyperparameter Tuning for Decision Trees
Techniques to Prevent Overfitting in Decision Trees
On the right, first we have the maximum number of splits. In the original decision tree, you can see that two splits happen, but we could easily have constrained the size of the tree by setting the max number of splits equal to one. Another hyperparameter we could have used is the minimum leaf size: the original decision tree has a minimum leaf size of two, but if we had set the minimum leaf size to something like five, this additional split could never have happened. Finally, we could have controlled the number of splitting variables.
In the decision tree from the previous slide, we split on both hours of sleep and time of day, but if we had set the number of splitting variables to one, that would have constrained our decision tree to split on just one of these variables. The key point is that hyperparameter tuning can help avoid the overfitting problem and improve your decision tree's generalizability, that is, its ability to perform well on new data.
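As a rough sketch, here's how those three controls might map onto scikit-learn's DecisionTreeClassifier, which we'll use later in this article. The mapping is approximate: for instance, max_features limits the variables considered at each individual split, not across the whole tree.

```python
from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier(
    max_leaf_nodes=2,    # k splits produce k + 1 leaves, so this caps the tree at 1 split
    min_samples_leaf=5,  # minimum leaf size of 5
    max_features=1,      # consider only 1 candidate variable at each split
)
```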
As a final note, although this is a very widely used way to develop decision trees, it is not the only way; I talk a little more about alternative strategies for developing decision trees in the blog associated with this article, so if you're interested, be sure to check that out. Okay, with the theoretical foundation set, let's dive into a concrete example with code and data from the real world: sepsis survival prediction using decision trees.
Using Scikit-Learn for Decision Tree Implementation
Step-by-Step Guide to Building a Decision Tree in Python
The first step is to import some helpful Python libraries. Pandas is going to help us with formatting our data, NumPy helps with math and calculations, and we use Matplotlib to make visualizations. We import several things from Sklearn, and finally we import the SMOTE function to help balance our data set, which we'll talk about soon.
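A sketch of what that import cell might look like; note that SMOTE comes from the separate imbalanced-learn package (pip install imbalanced-learn).

```python
import pandas as pd              # data formatting
import numpy as np               # math and calculations
import matplotlib.pyplot as plt  # visualizations

# scikit-learn: model building, tree plotting, data splitting, and metrics
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import (confusion_matrix, precision_score,
                             recall_score, f1_score)

# SMOTE for balancing the data set
from imblearn.over_sampling import SMOTE
```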
With our libraries imported, we can read in our data set. With pandas, this is just one line of code, and the CSV file used here is available at the GitHub repo, along with two additional CSV files that can be used for validating our decision tree. With our data read in, we can plot histograms for every variable in the data set. Here we have just four variables: the age of the patient, whether the patient is male or female, the number of sepsis episodes the patient has experienced, and finally the outcome variable, an indicator of whether the patient survived or died.
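In code, that might look like the following; the file name here is a placeholder, so grab the actual CSV from the GitHub repo.

```python
# Read the sepsis data set (placeholder file name)
df = pd.read_csv("sepsis_survival_primary_cohort.csv")

# One histogram per variable: age, sex, episode count, and outcome
df.hist(figsize=(8, 6))
plt.tight_layout()
plt.show()
```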
Balancing Datasets with SMOTE for Better Decision Trees
How to Handle Imbalanced Data for Accurate Predictions
These histograms reveal a problem: the two outcome classes are far from evenly represented. One way we can correct this is using SMOTE, which stands for Synthetic Minority Over-sampling Technique. It's basically a way to over-sample the minority class so that it becomes comparable in size to the majority class, ultimately reducing bias in our decision tree. Applying it is pretty straightforward. First we grab the predictor variable names and the outcome variable name and store the predictor and outcome variables as X and y. Then, with just one line of code, we use SMOTE to over-sample the minority class, and we can plot the results using Matplotlib. Look at that: we have a much more balanced data set.
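A sketch of those steps; the column names below are assumptions for illustration, so swap in whatever the CSV actually uses.

```python
# Assumed column names (check the actual CSV headers)
predictor_names = ["age_years", "sex_0male_1female", "episode_number"]
outcome_name = "hospital_outcome_1alive_0dead"

X = df[predictor_names]
y = df[outcome_name]

# One line to over-sample the minority class
X_resampled, y_resampled = SMOTE(random_state=0).fit_resample(X, y)

# Plot the class counts: the two outcomes should now be balanced
y_resampled.value_counts().plot(kind="bar")
plt.show()
```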
Now that we have balanced our data set, we can create our training and testing data sets. The point of this is that the training data set will be used to grow our decision tree, and the testing data set will be used to evaluate its performance. Here we use an 80-20 split, so 80% of the data is used for training and 20% for testing. With that, growing the decision tree is very straightforward; we can do it with just two lines of code.
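Here's a sketch of the 80-20 split, the two model-building lines, and the tree plot discussed next, reusing the resampled data from above.

```python
# 80% of the data for growing the tree, 20% for evaluating it
X_train, X_test, y_train, y_test = train_test_split(
    X_resampled, y_resampled, test_size=0.2, random_state=0
)

# Growing the decision tree takes just two lines
clf = DecisionTreeClassifier()
clf.fit(X_train, y_train)

# Visualize the (very large) fully grown tree
plt.figure(figsize=(16, 8))
plot_tree(clf, feature_names=predictor_names, filled=True)
plt.show()
```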
The first line initializes the decision tree classifier, and the second fits the decision tree to our data, and that's it: we have our decision tree. We can take a look at it using the built-in plot_tree functionality inside scikit-learn, and this is what it looks like. Needless to say, this is a very big decision tree, and it's hard to imagine that a doctor or any medical staff would be able to interpret it and extract anything meaningful. But let's put that point aside for now and evaluate our decision tree's performance.
Evaluating Decision Tree Performance with Confusion Matrix
Understanding True Positives, True Negatives, and Model Accuracy
On the left, what this is showing is the number of true negatives, true positives, false positives, and false negatives. In other words, it is comparing the decision tree's predictions to the ground truth. I don't want to get into too many details about interpreting confusion matrices for this discussion; I'll just say that when it comes to confusion matrices, you generally want to maximize the diagonal elements and minimize the off-diagonal elements. What that means is we want our predictions and the ground truth to agree as much as possible and disagree as little as possible.
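Generating the matrix with scikit-learn takes just a couple of lines; here's a sketch reusing the fitted tree and the test split from above.

```python
# Compare the tree's predictions against the ground truth
y_pred = clf.predict(X_test)

# Rows correspond to the true classes, columns to the predicted classes;
# the diagonal holds the counts where prediction and truth agree
print(confusion_matrix(y_test, y_pred))
```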
Optimizing Decision Trees with Precision, Recall, and F1 Score
How to Choose the Right Metrics for Your Model
In this case, I'd say precision is something we care about more than recall, because in this context we probably care more about false positives than false negatives. The reason is that a false positive corresponds to the case where the decision tree predicted that the patient would survive and they did not.
If we're using this decision tree to quantify patient risk, there's a lot more downside to predicting that a patient would survive who didn't than to predicting that a patient would die who doesn't. So clearly, which of these metrics you want to look at is highly context-dependent.
Sometimes you care more about false negatives; sometimes, like here, you care more about false positives. Which metric you use to evaluate your model depends on the problem and the context you're looking at. There's a handy function available in the GitHub repo that generates all of this; a sketch of what such a helper might look like follows.
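This is my own stand-in for the repo's helper, not necessarily the same function; it assumes the positive class is labeled 1.

```python
def evaluate(model, X_test, y_test):
    """Print the confusion matrix plus precision, recall, and F1 score."""
    y_pred = model.predict(X_test)
    print(confusion_matrix(y_test, y_pred))
    print("precision:", precision_score(y_test, y_pred))
    print("recall:   ", recall_score(y_test, y_pred))
    print("F1 score: ", f1_score(y_test, y_pred))

evaluate(clf, X_test, y_test)
```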
Conclusion and Next Steps in Decision Tree Modeling
Improving Decision Tree Models for Better Generalizability
Coming back to the massive decision tree from a couple of sections ago, this brings up once again the overfitting problem. While that decision tree might work reasonably well on this data set, a tree that looks like that is prone to overfitting, meaning it may not generalize well to new data sets. To avoid this problem, we can use hyperparameters, and here we're just going to use one: the maximum depth, which we'll set equal to three.
This ensures that the decision tree doesn't grow too many branches. It's super easy with Sklearn: we just pass this input argument into our decision tree classifier, fit the model just like we did before, and out comes our tuned decision tree.
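A sketch of the tuning step, reusing the training data and the evaluate helper from above.

```python
# Cap the tree at three levels of splits to curb overfitting
clf_tuned = DecisionTreeClassifier(max_depth=3)
clf_tuned.fit(X_train, y_train)

# Plot the much smaller tree and re-run the evaluation
plt.figure(figsize=(12, 6))
plot_tree(clf_tuned, feature_names=predictor_names, filled=True)
plt.show()

evaluate(clf_tuned, X_test, y_test)
```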
Plotting the tuned decision tree, it looks like this, and already we can see it is much more interpretable: we can actually read the text. Looking at the different splitting nodes, we see the age predictor appearing a lot, which indicates that age seems to be a very important risk factor when it comes to sepsis survival prediction. We also see that sex is playing a role. It's a little surprising that we're not seeing the number of episodes, but surely if we were to increase the max depth or do some other hyperparameter tuning, we would see that variable appear in additional splits. So our hyperparameter-tuned decision tree seems more comprehensible, but how does it perform? We can once again look at the confusion matrix and the three performance metrics, and surprisingly, the precision is actually a little better for this hyperparameter-tuned decision tree than for the fully grown one. But notice that the recall and F1 scores are significantly lower than what we saw before. In this case, we may not care about that, because, as I was saying before, precision is likely the metric we're really trying to optimize in this context: we want to weight false positives more heavily than false negatives.
If you enjoyed this content, please consider liking, subscribing, and sharing your thoughts in the comments section below. As always, thank you for your time, and thanks for reading.