In this article, we’ll learn how to program a decision tree from scratch in Python, using nothing but basic math and a few lines of code.
Let’s get started.
Why is building a decision tree from scratch useful?
When studying a new machine learning model, I always get lost between the formulas and the theory and never quite grasp how the algorithm actually works.
And there is no better way to understand an algorithm than to write it from zero, without any help or starting point.
Disclaimer
I have already written an article that discusses decision trees thoroughly, explaining the mathematical concepts and steps of the algorithm with pictures and examples.
I suggest you read it before continuing.
Decision tree in Python
Problem statement
We want to solve a regression problem with only numerical features by fitting a decision tree to the data.
1. Import necessary libraries
```python
import numpy
```
In this code I only use NumPy, a numerical library, which lets us compute the mean of a list in a single call instead of writing the loop by hand.
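As a tiny standalone illustration (the prices here are the same sale prices we define in the next step), numpy.mean simply replaces the usual sum-and-divide:

```python
import numpy

prices = [100, 105, 180]

# manual mean: sum the values and divide by the count
manual_mean = sum(prices) / len(prices)

# the same thing with NumPy, in a single call
numpy_mean = numpy.mean(prices)

print(manual_mean, numpy_mean)  # 128.33333333333334 128.33333333333334
```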
2. Define a dataset
```python
X = {
    "LotArea": [50, 70, 100],
    "Quality": [8, 7.5, 9]
}

y = {
    "SalePrice": [100, 105, 180]
}
```
This is a tiny house price dataset with two input features, the lot area and the quality of the materials, and one target feature, the sale price.
3. Define nodes
```python
def create_node(train_y, datapoints):
    # mean target value of the data points in the node
    mean_value = numpy.mean([train_y["SalePrice"][i] for i in datapoints])
    # variance of the target values, used as the impurity of the node
    mean_variance = numpy.mean([(train_y["SalePrice"][i] - mean_value) ** 2 for i in datapoints])
    # a node with zero variance is pure, so it becomes a leaf
    leaf = bool(mean_variance == 0)

    return {
        # indexes of the data points in the node
        "datapoints": datapoints,
        # mean output value of the data points in the node
        "mean_value": mean_value,
        # impurity of the node
        "mean_variance": mean_variance,
        # leaf state
        "leaf": leaf,
        # best feature and threshold to split the node
        "feature": None,
        "threshold": None,
        # the left sub-node of the node
        "left": None,
        # the right sub-node of the node
        "right": None
    }
```
The function create_node:
- Takes as input the training output dataset and the indexes of the samples that belong to the node.
- Returns a dictionary with those indexes and all of the properties of the node (see the quick usage sketch below).
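For instance, once the dataset from step 2 is defined, creating the root node that holds all three houses looks like this (the numbers match the output we will see in step 8):

```python
root = create_node(y, [0, 1, 2])

print(root["mean_value"])     # 128.33333333333334
print(root["mean_variance"])  # 1338.888888888889
print(root["leaf"])           # False
```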
4. Define the split function
```python
def split_node(train_X, train_y, node, feature, threshold):
    node_left = create_node(train_y, [x for x in node["datapoints"] if train_X[feature][x] <= threshold])
    node_right = create_node(train_y, [x for x in node["datapoints"] if train_X[feature][x] > threshold])

    return node_left, node_right
```
The function split_node:
- Takes as input the training X, the training y, the node we want to split, the feature, and the threshold.
- Creates two nodes: each data point in the parent goes to the left if its feature value is less than or equal to the threshold, and to the right otherwise.
- Returns the left and right nodes (a small example follows).
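As a quick sanity check, splitting the root node on LotArea at 70 sends the two smaller houses to the left and the biggest one to the right (values again consistent with the tree printed in step 8):

```python
root = create_node(y, [0, 1, 2])
left, right = split_node(X, y, root, "LotArea", 70)

print(left["datapoints"], right["datapoints"])  # [0, 1] [2]
print(left["mean_value"], right["mean_value"])  # 102.5 180.0
```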
5. Find the best threshold for a node
```python
def find_threshold(train_X, train_y, node):
    values = []
    impurities = []

    for feature in train_X:
        for threshold_index in node["datapoints"]:
            node_left, node_right = split_node(train_X, train_y, node, feature, train_X[feature][threshold_index])

            # skip splits that leave one side empty
            if len(node_left["datapoints"]) == 0 or len(node_right["datapoints"]) == 0:
                continue

            # impurity of each child weighted by its share of the parent's data points
            weighted_impurity = (
                node_left["mean_variance"] * len(node_left["datapoints"]) / len(node["datapoints"])
                + node_right["mean_variance"] * len(node_right["datapoints"]) / len(node["datapoints"])
            )

            values.append([feature, train_X[feature][threshold_index]])
            impurities.append(weighted_impurity)

    # keep the feature and threshold that give the lowest weighted impurity
    best_split = impurities.index(min(impurities))

    node["feature"] = values[best_split][0]
    node["threshold"] = values[best_split][1]
```
The algorithm tries every value of every feature in the node as a candidate threshold, computes the weighted impurity of each resulting split, and stores in the node the feature and threshold that give the minimal weighted impurity.
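To make the weighted impurity concrete, here is the value for the candidate split LotArea <= 70 on our root node, computed by hand (the child variances are the ones create_node returns):

```python
# left child:  houses 0 and 1 -> variance 6.25, 2 of 3 samples
# right child: house 2        -> variance 0.0,  1 of 3 samples
weighted_impurity = 6.25 * 2 / 3 + 0.0 * 1 / 3

print(weighted_impurity)  # 4.166666666666667
```

The split on Quality <= 8 ties with this value, but LotArea is evaluated first, so the root ends up split on LotArea at 70, as the output in step 8 confirms.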
6. Build the tree structure
```python
def build_tree(train_X, train_y, node, max_depth, depth=1):
    # keep splitting only while the current depth is below max_depth
    if depth < max_depth:
        # find the best threshold and split the node
        find_threshold(train_X, train_y, node)
        node["left"], node["right"] = split_node(train_X, train_y, node, node["feature"], node["threshold"])

        if not node["left"]["leaf"]:
            # re-execute this function with the left node as main node
            build_tree(train_X, train_y, node["left"], max_depth, depth + 1)
        if not node["right"]["leaf"]:
            # re-execute this function with the right node as main node
            build_tree(train_X, train_y, node["right"], max_depth, depth + 1)
    else:
        # maximum depth reached: the node becomes a leaf
        node["leaf"] = True
```
This is a recursive function that builds the tree from a root node: it keeps splitting until a node is pure or the max_depth parameter is reached, at which point the node becomes a leaf.
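For example, calling it with max_depth set to 1 never splits the root, so the whole tree collapses into a single leaf that predicts the dataset mean (a quick sketch using the functions above):

```python
# with max_depth=1 the root is never split: the tree is one leaf
stump = create_node(y, [0, 1, 2])
build_tree(X, y, stump, 1)

print(stump["leaf"], stump["mean_value"])  # True 128.33333333333334
```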
7. Define the predict function
```python
def predict(val_X, tree):
    y = []

    for index in range(len(val_X["LotArea"])):
        current_node = tree

        # walk down the tree until a leaf is reached
        while not current_node["leaf"]:
            # choose the path of the input sample
            if val_X[current_node["feature"]][index] <= current_node["threshold"]:
                current_node = current_node["left"]
            else:
                current_node = current_node["right"]

        # predicted output is the mean value of the leaf where the input falls
        y.append(current_node["mean_value"])

    return y
```
For each input sample, this function walks from the root down to the leaf the sample falls into and returns the mean value of that leaf.
8. Fit a tree to the data
```python
tree = create_node(y, [0, 1, 2])
build_tree(X, y, tree, 2)

print(tree)
```
```
> {'datapoints': [0, 1, 2], 'mean_value': 128.33333333333334, 'mean_variance': 1338.888888888889, 'leaf': False, 'feature': 'LotArea', 'threshold': 70, 'left': {'datapoints': [0, 1], 'mean_value': 102.5, 'mean_variance': 6.25, 'leaf': True, 'feature': None, 'threshold': None, 'left': None, 'right': None}, 'right': {'datapoints': [2], 'mean_value': 180.0, 'mean_variance': 0.0, 'leaf': True, 'feature': None, 'threshold': None, 'left': None, 'right': None}}
```
This is the printed structure of a decision tree, much like the one scikit-learn builds internally. We’ve done it!
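Before predicting unseen houses, a quick sanity check: predicting on the training inputs themselves should return the mean value of the leaf each house belongs to.

```python
# each training house gets the mean value of its own leaf
print(predict(X, tree))  # [102.5, 102.5, 180.0]
```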
9. Predict unseen values
```python
val_X = {
    "LotArea": [50, 90],
    "Quality": [7.5, 7.5]
}

print(predict(val_X, tree))
```
```
[102.5, 180.0]
```
Let’s go! These predictions make sense: the first house (LotArea 50) falls into the leaf containing the two cheaper training houses, while the second one (LotArea 90) falls into the leaf containing the most expensive house.
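If you want to cross-check against a library implementation, a rough comparison with scikit-learn's DecisionTreeRegressor (assuming scikit-learn is installed) could look like the sketch below. Note that scikit-learn places thresholds at midpoints between feature values and counts depth differently, so our max_depth=2 corresponds roughly to its max_depth=1.

```python
from sklearn.tree import DecisionTreeRegressor

# reshape our column-oriented dicts into the row-oriented format scikit-learn expects
X_rows = list(zip(X["LotArea"], X["Quality"]))  # [(50, 8), (70, 7.5), (100, 9)]
y_rows = y["SalePrice"]

# max_depth=1 here roughly matches our max_depth=2, since we start counting at 1
model = DecisionTreeRegressor(max_depth=1)
model.fit(X_rows, y_rows)

print(model.predict([[50, 7.5], [90, 7.5]]))
```

With this tiny dataset two splits tie at the root, so the split scikit-learn picks (and therefore its second prediction) may differ from ours.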
Decision tree from scratch full code
```python
import numpy

X = {
    "LotArea": [50, 70, 100],
    "Quality": [8, 7.5, 9]
}

y = {
    "SalePrice": [100, 105, 180]
}

def create_node(train_y, datapoints):
    mean_value = numpy.mean([train_y["SalePrice"][i] for i in datapoints])
    mean_variance = numpy.mean([(train_y["SalePrice"][i] - mean_value) ** 2 for i in datapoints])
    leaf = bool(mean_variance == 0)

    return {
        # indexes of the data points in the node
        "datapoints": datapoints,
        # mean output value of the data points in the node
        "mean_value": mean_value,
        # impurity of the node
        "mean_variance": mean_variance,
        # leaf state
        "leaf": leaf,
        # best feature and threshold to split the node
        "feature": None,
        "threshold": None,
        # the left sub-node of the node
        "left": None,
        # the right sub-node of the node
        "right": None
    }

def split_node(train_X, train_y, node, feature, threshold):
    node_left = create_node(train_y, [x for x in node["datapoints"] if train_X[feature][x] <= threshold])
    node_right = create_node(train_y, [x for x in node["datapoints"] if train_X[feature][x] > threshold])

    return node_left, node_right

def find_threshold(train_X, train_y, node):
    values = []
    impurities = []

    for feature in train_X:
        for threshold_index in node["datapoints"]:
            node_left, node_right = split_node(train_X, train_y, node, feature, train_X[feature][threshold_index])

            # skip splits that leave one side empty
            if len(node_left["datapoints"]) == 0 or len(node_right["datapoints"]) == 0:
                continue

            weighted_impurity = (
                node_left["mean_variance"] * len(node_left["datapoints"]) / len(node["datapoints"])
                + node_right["mean_variance"] * len(node_right["datapoints"]) / len(node["datapoints"])
            )

            values.append([feature, train_X[feature][threshold_index]])
            impurities.append(weighted_impurity)

    best_split = impurities.index(min(impurities))

    node["feature"] = values[best_split][0]
    node["threshold"] = values[best_split][1]

def build_tree(train_X, train_y, node, max_depth, depth=1):
    # keep splitting only while the current depth is below max_depth
    if depth < max_depth:
        # find the best threshold and split the node
        find_threshold(train_X, train_y, node)
        node["left"], node["right"] = split_node(train_X, train_y, node, node["feature"], node["threshold"])

        if not node["left"]["leaf"]:
            # re-execute this function with the left node as main node
            build_tree(train_X, train_y, node["left"], max_depth, depth + 1)
        if not node["right"]["leaf"]:
            # re-execute this function with the right node as main node
            build_tree(train_X, train_y, node["right"], max_depth, depth + 1)
    else:
        node["leaf"] = True

def predict(val_X, tree):
    y = []

    for index in range(len(val_X["LotArea"])):
        current_node = tree

        while not current_node["leaf"]:
            # choose the path of the input sample
            if val_X[current_node["feature"]][index] <= current_node["threshold"]:
                current_node = current_node["left"]
            else:
                current_node = current_node["right"]

        # predicted output is the mean value of the leaf where the input falls
        y.append(current_node["mean_value"])

    return y

tree = create_node(y, [0, 1, 2])
build_tree(X, y, tree, 2)
print(tree)

val_X = {
    "LotArea": [50, 90],
    "Quality": [7.5, 7.5]
}

print(predict(val_X, tree))
```