Decision Tree

What is a decision tree?

Well, a decision tree is one of the simplest ways of reasoning about data: it finds order in an otherwise chaotic system by asking a sequence of yes/no questions. Although rudimentary in nature, decision trees and their variants (bagging methods such as random forests, and boosting methods) can map data, i.e. capture the order in it, to a very high degree of accuracy. Decision trees are mainly binary trees composed of if-else statements that determine the split lines, planes or hyperplanes which break the data space $\mathbb{R}^n$ down into rectangles, cuboids or hyper-cuboids. Because each split compares a single variable against a threshold, these boundaries are axis-aligned.
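To make the if-else picture concrete, here is a minimal sketch of a hand-written depth-2 regression tree on $\mathbb{R}^2$. The thresholds (0.5, 0.3, 0.8) and leaf values are made-up numbers for illustration, not values learned from data; each leaf corresponds to one axis-aligned rectangle of the input space.

```python
def tiny_tree(x1: float, x2: float) -> float:
    """Hypothetical depth-2 regression tree: every return statement is a leaf."""
    if x1 < 0.5:            # root split: vertical line x1 = 0.5
        if x2 < 0.3:        # left child: horizontal line x2 = 0.3
            return 1.0      # leaf for the rectangle x1 < 0.5, x2 < 0.3
        return 2.0          # leaf for the rectangle x1 < 0.5, x2 >= 0.3
    if x2 < 0.8:            # right child: horizontal line x2 = 0.8
        return 3.0          # leaf for the rectangle x1 >= 0.5, x2 < 0.8
    return 4.0              # leaf for the rectangle x1 >= 0.5, x2 >= 0.8


# Every point inside the same rectangle gets the same prediction
print(tiny_tree(0.2, 0.1), tiny_tree(0.4, 0.2))  # both fall in the first rectangle
print(tiny_tree(0.9, 0.9))                       # falls in the last rectangle
```

Real tree learners choose the thresholds and leaf values from data; the rest of this post shows how one such threshold is chosen in one dimension.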

What are we trying to do here?

This post (a Jupyter notebook) visualises how a regression decision tree places its decision line for data in $\mathbb{R}^1$, i.e. data with a single independent variable $x$.

Approximate algorithm for the regression decision tree implemented here

  1. Divide the single dimension of the $\mathbb{R}^1$ input space into equally spaced intervals; the interval boundaries are the candidate split points.
  2. At each candidate split point, compute the mean of all the points to its left and the mean of all the points to its right.
  3. Compute the Mean Squared Error (MSE) of the points to the left of the split point with respect to the left mean obtained in step 2, and likewise the MSE of the points to the right with respect to the right mean.
  4. Add up the two MSEs and keep track of the sum.
  5. Repeat steps 2 to 4 at every candidate split point.
  6. Find the split point where the summed MSE is minimum. This is the optimal split point, which divides the input data into two groups (the two leaves of the tree).
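The six steps above can be sketched in plain NumPy without any plotting. This is a minimal sketch under a few assumptions of its own: the function name `best_split`, its parameters, and the rule that points exactly on a threshold go to the right side are all hypothetical choices, not the notebook's code. The optional `weighted=True` flag weights each side's MSE by its number of points, which is how standard CART scores a split; the default follows the unweighted sum described in step 4.

```python
import numpy as np


def best_split(x, y, n_candidates=50, weighted=False):
    """Hypothetical 1-D split search following steps 1-6 (needs n_candidates >= 3).

    Returns (best_threshold, scores, candidates).
    """
    x = np.asarray(x, dtype=float).ravel()
    y = np.asarray(y, dtype=float).ravel()
    # Step 1: equally spaced candidate thresholds; dropping the two endpoints
    # (min and max of x) keeps at least one point on each side of every split
    candidates = np.linspace(x.min(), x.max(), n_candidates)[1:-1]
    scores = []
    for s in candidates:
        # Assumption: points exactly at s go to the right side
        left, right = y[x < s], y[x >= s]
        # Step 2: the mean of each side
        ml, mr = left.mean(), right.mean()
        # Step 3: MSE of each side about its own mean
        msel = np.mean((left - ml) ** 2)
        mser = np.mean((right - mr) ** 2)
        # Step 4: combine the two errors (step 5 is the loop itself)
        if weighted:
            scores.append((left.size * msel + right.size * mser) / x.size)
        else:
            scores.append(msel + mser)
    scores = np.array(scores)
    # Step 6: the candidate with the smallest combined error is the split
    return candidates[np.argmin(scores)], scores, candidates


# Usage on data shaped like the notebook's two clusters
rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(0, 0.1, 25), rng.normal(0.7, 0.1, 25)])
y = np.concatenate([rng.normal(0, 0.2, 25), rng.normal(0.5, 0.2, 25)])
threshold, scores, candidates = best_split(x, y)
print(f"split at x = {threshold:.3f}")
```

Scanning a fixed grid of thresholds is what makes this an approximation: exact tree learners typically try the midpoints between consecutive sorted $x$ values instead.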

Code start


Handle the imports and set the plot style presets

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation

from IPython.display import Video

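# Interactive figures in Jupyter (requires the ipympl package)
%matplotlib widget

# The 'seaborn' style is called 'seaborn-v0_8' from Matplotlib 3.6 onwards
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')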

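Helper functions to calculate the means and the mean squared errors (MSEs) of the data points on either side of a split

def mean(x, y, split_at):
    """Return the means of y for the points left (x < split_at) and right (x > split_at) of the split."""
    # Points lying exactly at split_at belong to neither side
    ml, mr = np.mean(y[x < split_at]), np.mean(y[x > split_at])
    return ml, mr

def mse(x, y, split_at, ml, mr):
    """Return the MSE of each side of the split about that side's mean (ml, mr)."""
    msel = np.mean(np.square(ml - y[x < split_at]))
    mser = np.mean(np.square(mr - y[x > split_at]))
    return msel, mser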
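Written out, the two helpers together with step 4 compute the following score for a candidate split $s$, where $L(s) = \{i : x_i < s\}$ and $R(s) = \{i : x_i > s\}$ are the points to the left and right of the split (points exactly at $s$ fall in neither set, matching the strict inequalities in the code):

$$\bar{y}_L(s) = \frac{1}{|L(s)|}\sum_{i \in L(s)} y_i, \qquad \bar{y}_R(s) = \frac{1}{|R(s)|}\sum_{i \in R(s)} y_i$$

$$\mathrm{MSE}_L(s) = \frac{1}{|L(s)|}\sum_{i \in L(s)} \left(y_i - \bar{y}_L(s)\right)^2, \qquad \mathrm{MSE}_R(s) = \frac{1}{|R(s)|}\sum_{i \in R(s)} \left(y_i - \bar{y}_R(s)\right)^2$$

$$s^\ast = \arg\min_{s} \left[\mathrm{MSE}_L(s) + \mathrm{MSE}_R(s)\right]$$

Standard CART instead weights each term by the fraction of the $N$ points on that side, minimising $\frac{|L(s)|}{N}\mathrm{MSE}_L(s) + \frac{|R(s)|}{N}\mathrm{MSE}_R(s)$; the unweighted sum above follows this notebook's approach.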

Initialise the random data. The data is generated from two normal distributions, each with its own mean and standard deviation, and an equal number of points is sampled from each. The class also defines a method, step_tree, which performs the repetitive steps 2 to 4 of the above algorithm for one candidate split per animation frame.
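NumPy's normal sampler takes the standard deviation, not the variance, as its `scale` argument, so the notebook's clusters have standard deviations 0.1 and 0.2 (variances 0.01 and 0.04). A quick empirical check, sketched with the newer `np.random.default_rng` Generator API rather than the legacy `np.random.normal` the notebook uses; the seed and sample size are arbitrary choices:

```python
import numpy as np

rng = np.random.default_rng(42)
n = 100_000  # large sample so the empirical statistics sit close to the true ones
cluster = rng.normal(0.7, 0.1, n)  # loc=0.7 is the mean, scale=0.1 the standard deviation

print(round(cluster.mean(), 2))  # close to 0.7
print(round(cluster.std(), 2))   # close to 0.1, i.e. a variance close to 0.01
```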

class Draw:
    
    def __init__(self, samples):
        """
        Get the number of data samples to generate and prepare the figure
        """
        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(111)
        self.ax.set_xlabel('x')
        self.ax.set_ylabel(r'$f(x)$')
        self.samples = samples
        
        # Data along the independent variable x
        # 1st data cluster: mean=0, standard deviation=0.1, count=samples/2
        # 2nd data cluster: mean=0.7, standard deviation=0.1, count=samples/2
        # (np.random.normal takes the standard deviation as its scale, not the variance)
        self.x = np.hstack((np.random.normal(0, 0.1, (1, samples//2)), np.random.normal(0.7, 0.1, (1, samples//2))))
        self.xmin = np.min(self.x)
        self.xmax = np.max(self.x)
        # Divide the search space along the independent variable x into equally spaced candidate split points
        self.linspace = np.linspace(self.xmin, self.xmax, samples)
        
        # Data along the dependent variable f(x) (e.g. temperature over a day or a year)
        # 1st data cluster: mean=0, standard deviation=0.2, count=samples/2
        # 2nd data cluster: mean=0.5, standard deviation=0.2, count=samples/2
        self.f = np.hstack((np.random.normal(0, 0.2, (1, samples//2)), np.random.normal(0.5, 0.2, (1, samples//2))))
        self.fmax = np.max(self.f)
        self.mses = []
        
        self.ax.scatter(self.x, self.f, marker= '.')
        
    
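    def step_tree(self, i):
        """Animation frame i: evaluate the candidate split at self.linspace[i+1] (steps 2 to 4)."""
        # linspace[0] and linspace[-1] are min(x) and max(x); skipping them keeps both sides non-empty
        splitting_at = self.linspace[i+1]
        # Remove the previous frame's lines and scatter plots, keeping the axes and their labels
        for artist in [*self.ax.lines, *self.ax.collections]:
            artist.remove()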
        
        left_mean, right_mean = mean(self.x, self.f, splitting_at)
        left_mse, right_mse = mse(self.x, self.f, splitting_at, left_mean, right_mean)        
        self.mses.append(left_mse+right_mse)
        
        self.ax.scatter(self.x[self.x < splitting_at], self.f[self.x < splitting_at], c = 'c')
        self.ax.scatter(self.x[self.x > splitting_at], self.f[self.x > splitting_at], c = 'k')
        
        xllim, xrlim = self.ax.get_xlim()
        self.ax.axvline(splitting_at, 0, 1, c='r', linewidth=2)
        self.ax.axhline(left_mean, 0, (splitting_at-xllim)/(xrlim - xllim), c='c', linewidth=2, linestyle='-.', label='Left Mean')
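        # Dashed vertical segments from each point to its side's mean show the residuals that make up the MSE
        ymin = self.f.copy()
        ymax = ymin.copy()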
        
        ymax[(self.x < splitting_at) & (self.f < left_mean)] = left_mean
        ymin[(self.x < splitting_at) & (self.f > left_mean)] = left_mean
        ymax[(self.x > splitting_at) & (self.f < right_mean)] = right_mean
        ymin[(self.x > splitting_at) & (self.f > right_mean)] = right_mean
        
        self.ax.vlines(self.x[(self.x < splitting_at)], ymin[(self.x < splitting_at)], ymax[(self.x < splitting_at)], color='c', linestyle='--')
        self.ax.vlines(self.x[(self.x > splitting_at)], ymin[(self.x > splitting_at)], ymax[(self.x > splitting_at)], color='k', linestyle='--')
        self.ax.axhline(right_mean, (splitting_at-xllim)/(xrlim - xllim), 1, c='k', linewidth=2, linestyle='-.', label='Right Mean')
        
        # self.mses[k] belongs to the split at self.linspace[k+1], so plot it against that slice of x
        self.ax.plot(self.linspace[1:len(self.mses)+1], self.mses, c='r', label='MSE error')
        if i == self.samples-3:
            # Last frame: clear it and draw only the best split found over all candidates
            for artist in [*self.ax.lines, *self.ax.collections]:
                artist.remove()
            # +1 because the first candidate split was self.linspace[1]
            boundary = self.linspace[np.argmin(self.mses) + 1]
            # Leaf predictions are the means for the best split, not for the last candidate split
            best_left_mean, best_right_mean = mean(self.x, self.f, boundary)
            boundary_frac = (boundary - xllim)/(xrlim - xllim)
            
            self.ax.scatter(self.x[self.x < boundary], self.f[self.x < boundary], c='c', label='leaf-1 points')
            self.ax.scatter(self.x[self.x > boundary], self.f[self.x > boundary], c='k', label='leaf-2 points')
            
            self.ax.axhline(best_left_mean, 0, boundary_frac, c='c', linewidth=2, linestyle='-.', label='Left Mean')
            self.ax.axhline(best_right_mean, boundary_frac, 1, c='k', linewidth=2, linestyle='-.', label='Right Mean')
            
            self.ax.plot(self.linspace[1:len(self.mses)+1], self.mses, c='r', label='MSE error')
            
            self.ax.axvline(boundary, c='b', label='Decision boundary at minimum MSE')
            self.ax.set_title(f'Decision boundary is at {boundary:.3f}')
        self.ax.legend(loc='upper left')

fps = 10 # frames per second
samples = 50 # number of data samples
assert samples % 2 == 0, "Please specify an even number of data samples"
draw_obj = Draw(samples)

# Saving to MP4 requires ffmpeg to be installed and on the PATH
Writer = animation.writers['ffmpeg']
writer = Writer(fps=fps, metadata=dict(artist='Sai'), bitrate=1800)

# One frame per candidate split: linspace[1] up to linspace[samples-2]
video_path = 'decision_tree.mp4'
animation.FuncAnimation(draw_obj.fig, draw_obj.step_tree, frames=samples-2, interval=(1/fps)*1000, repeat=False).save(video_path, writer=writer)
Video(video_path, embed=True)

Results

As the algorithm steps through the equally spaced candidate split points, we can see how the means of the left and right data points vary with the position of the split (red vertical line) and how this influences the MSE. Finally, we determine the split point where the MSE is minimum and choose it, i.e. the point along the independent variable x which divides the data into two sections.

The result of the entire operation is the final splitting boundary, which decides, for example, whether it is day or night, or summer or winter. Depending on which side of this boundary the independent variable x falls, the decision tree lets us make an educated guess about the temperature, or any other value of f(x); that guess is simply the mean value of the left or right leaf.
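Once the boundary is known, predicting f(x) for a new input is a single if-else that returns the mean of the leaf the input falls in. A minimal sketch: `predict_stump`, `threshold`, `left_mean` and `right_mean` are hypothetical names standing for the boundary and the two leaf means found above, and the numbers in the usage line are made up.

```python
import numpy as np


def predict_stump(x_new, threshold, left_mean, right_mean):
    """Predict f(x) with a depth-1 regression tree (a decision stump)."""
    x_new = np.asarray(x_new, dtype=float)
    # Inputs left of the boundary get leaf-1's mean, all others get leaf-2's mean
    return np.where(x_new < threshold, left_mean, right_mean)


# Made-up split values, only to show the call:
# the first two inputs get 0.0 and the last two get 0.5
print(predict_stump([-0.1, 0.2, 0.6, 0.9], threshold=0.35, left_mean=0.0, right_mean=0.5))
```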