Last active
June 7, 2020 02:41
-
-
Save lynxnathan/9834926bc0e42fac5d840e6f4e9261ee to your computer and use it in GitHub Desktop.
Naive single decision tree implementation, written in the context of the fast.ai ML1 course.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
const csv = require('csv/lib/sync'); | |
const fs = require('fs'); | |
// Naive single decision tree implementation in JavaScript.
// Train.csv can be downloaded from Kaggle's "Blue Book for Bulldozers" competition.
/**
 * Naive regression decision tree.
 *
 * Candidate splits are scored by "branch diffusion": for each side of the split,
 * the number of rows times the standard deviation of log(dependent variable),
 * summed over both sides. Lower is better.
 */
class DecisionTree {
  /**
   * @param {Array<Array<string>>} csvData - Parsed CSV rows; the first row is the header.
   *   The array is not modified.
   * @param {string[]} independentVariables - Column names considered for splits; names
   *   that are not in the header are ignored.
   * @param {string} dependentVariable - Column name of the target variable.
   * @param {{sampleSize?: number, depth?: number}} [options] - `sampleSize` (when > -1)
   *   keeps only the first N data rows; `depth` is how many levels of branches `fit()`
   *   grows below the root (falsy means only the root split).
   * @throws {Error} If `dependentVariable` is not a column of the header.
   */
  constructor(csvData, independentVariables, dependentVariable, options = {}) {
    // Destructure instead of shift() so the caller's array keeps its header row.
    const [columns = [], ...rows] = csvData;
    this.options = options;
    this.depth = 0; // not used by fit(); kept so existing readers of `tree.depth` still work
    this.columns = columns;
    this.data = rows;

    const [dependentVariableIndex] = DecisionTree.getVariablesIndexes(columns, [dependentVariable]);
    if (dependentVariableIndex === undefined) {
      throw new Error(`Dependent variable "${dependentVariable}" not found in CSV header`);
    }
    this.dependentVariableIndex = dependentVariableIndex;
    this.independentVariablesIndexes = DecisionTree.getVariablesIndexes(columns, independentVariables);

    if (options.sampleSize > -1 && this.data.length > options.sampleSize) {
      this.data = this.data.slice(0, options.sampleSize);
    }
  }

  /**
   * Grows the tree and returns its root node.
   *
   * Each node is `{branchDiffusion, splitValue, splitVariable, [leftBranch, rightBranch], sampleSize}`.
   * A node for which no variable yields a split with both sides non-empty is
   * `{branchDiffusion: Infinity}`.
   *
   * @returns {object} The root node.
   */
  fit() {
    const maxDepth = this.options.depth;

    // `level` is the node's own depth (root = 1); it is passed down rather than shared,
    // so sibling subtrees are grown to the same depth.
    const buildNode = (subtreeIndexes, level) => {
      let best = {branchDiffusion: Infinity};
      let bestVariableIndex = -1;
      for (const variableIndex of this.independentVariablesIndexes) {
        const candidate = this.getBestSplitForVariable(subtreeIndexes, variableIndex);
        if (candidate.branchDiffusion < best.branchDiffusion) {
          best = candidate;
          bestVariableIndex = variableIndex;
        }
      }
      if (bestVariableIndex === -1) {
        return {branchDiffusion: Infinity};
      }

      const node = {
        branchDiffusion: best.branchDiffusion,
        splitValue: best.splitValue,
        splitVariable: this.columns[bestVariableIndex],
      };
      // Recurse only once the winning split across all variables is known.
      if (maxDepth && level <= maxDepth) {
        node.leftBranch = buildNode(best.leftBranchIndexes, level + 1);
        node.rightBranch = buildNode(best.rightBranchIndexes, level + 1);
      }
      node.sampleSize = subtreeIndexes.length;
      return node;
    };

    return buildNode(Array.from({length: this.data.length}, (_, i) => i), 1);
  }

  /**
   * Finds the threshold on one column that minimises branch diffusion.
   *
   * Rows whose parsed value is `<=` the threshold go left; all others (including
   * unparseable values) go right. Thresholds that leave one side empty are skipped.
   *
   * @param {number[]} subtreeIndexes - Row indexes (into `this.data`) to split.
   * @param {number} independentVariableIndex - Column index to split on.
   * @returns {{branchDiffusion: number, leftBranchIndexes: number[], rightBranchIndexes: number[], splitValue: (string|null)}}
   *   `branchDiffusion` is Infinity and `splitValue` is null when no valid split exists.
   */
  getBestSplitForVariable(subtreeIndexes, independentVariableIndex) {
    // Parse each row once instead of once per candidate threshold.
    const features = subtreeIndexes.map((index) => parseFloat(this.data[index][independentVariableIndex]));
    const targets = subtreeIndexes.map((index) => Math.log(this.data[index][this.dependentVariableIndex]));
    // Duplicate thresholds produce identical partitions. A Set keeps first-occurrence
    // order, so ties resolve to the same splitValue as scanning every row would.
    const candidates = new Set(this.getValuesFromColumn(subtreeIndexes, independentVariableIndex));

    let split = {branchDiffusion: Infinity, splitValue: null, leftBranchIndexes: [], rightBranchIndexes: []};
    for (const splitValue of candidates) {
      const threshold = parseFloat(splitValue);
      const leftBranchIndexes = [];
      const rightBranchIndexes = [];
      const leftTargets = [];
      const rightTargets = [];
      subtreeIndexes.forEach((index, position) => {
        if (features[position] <= threshold) {
          leftBranchIndexes.push(index);
          leftTargets.push(targets[position]);
        } else {
          rightBranchIndexes.push(index);
          rightTargets.push(targets[position]);
        }
      });
      if (leftBranchIndexes.length === 0 || rightBranchIndexes.length === 0) {
        continue; // not a real split
      }
      const branchDiffusion =
        leftBranchIndexes.length * DecisionTree.standardDeviation(leftTargets) +
        rightBranchIndexes.length * DecisionTree.standardDeviation(rightTargets);
      if (branchDiffusion < split.branchDiffusion) {
        split = {branchDiffusion, leftBranchIndexes, rightBranchIndexes, splitValue};
      }
    }
    return split;
  }

  /**
   * Maps column names to their header indexes, silently dropping unknown names.
   *
   * @param {string[]} columns - Header row.
   * @param {string[]} variables - Column names to look up.
   * @returns {number[]} Indexes of the names that were found, in `variables` order.
   */
  static getVariablesIndexes(columns, variables) {
    return variables
      .map((variable) => columns.indexOf(variable))
      .filter((columnIndex) => columnIndex !== -1);
  }

  /**
   * @param {number[]} subtreeIndexes - Row indexes into `this.data`.
   * @param {number} columnIndex - Column to read.
   * @returns {Array<string>} The raw cell values for those rows.
   */
  getValuesFromColumn(subtreeIndexes, columnIndex) {
    return subtreeIndexes.map((index) => this.data[index][columnIndex]);
  }

  /**
   * Population standard deviation; every value is parsed with parseFloat.
   *
   * @param {Array<number|string>} values
   * @returns {number} NaN for an empty input.
   */
  static standardDeviation(values) {
    const numbers = values.map((value) => parseFloat(value));
    const mean = numbers.reduce((sum, x) => sum + x, 0) / numbers.length;
    const squaredMean = numbers.reduce((sum, x) => sum + Math.pow(x - mean, 2), 0) / numbers.length;
    return Math.sqrt(squaredMean);
  }
}
// Fit a depth-1 tree on the first 1000 rows of the training CSV.
// The CSV path may be passed as the first CLI argument; it defaults to Train.csv.
const csvData = csv.parse(fs.readFileSync(process.argv[2] || 'Train.csv'));
const INDEPENDENT_VARIABLES = ['MachineHoursCurrentMeter', 'YearMade'];
const DEPENDENT_VARIABLE = 'SalePrice';
const tree = new DecisionTree(csvData, INDEPENDENT_VARIABLES, DEPENDENT_VARIABLE, {sampleSize: 1000, depth: 1});
// console.log stops expanding objects nested more than 2 levels deep (prints [Object]);
// depth: null prints the whole tree however deep it was grown.
console.dir(tree.fit(), {depth: null});
// Result: | |
// | |
// { branchDiffusion: 672.0239238184622, | |
// splitValue: '2178', | |
// splitVariable: 'MachineHoursCurrentMeter', | |
// leftBranch: | |
// { branchDiffusion: 299.86590963795163, | |
// splitValue: '2003', | |
// splitVariable: 'YearMade', | |
// sampleSize: 470 }, | |
// rightBranch: | |
// { branchDiffusion: 349.137024345939, | |
// splitValue: '1997', | |
// splitVariable: 'YearMade', | |
// sampleSize: 530 }, | |
// sampleSize: 1000 } | |
// Jeremy's (fast.ai) Python implementation's result: | |
// ml1 course – lesson3 notebook | |
// n: 1000; val:10.160352993311724; score:672.0239238184623; split:2178.0; var:MachineHoursCurrentMeter | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment