The green decision boundary is fit by gradient descent with momentum; the grey one is fit by plain gradient descent (no momentum). Drag the slider to change the momentum coefficient. Inspired by a recent article on Distill.
Last active
April 14, 2017 23:53
-
-
Save feyderm/6bd8e75420d7aff0b19aa204651eab76 to your computer and use it in GitHub Desktop.
Exploring Gradient Descent with Momentum
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html>
<meta charset="utf-8">
<title>Exploring Gradient Descent with Momentum</title>
<style>
/* Axis tick labels and any other SVG text. */
text {
  font-family: sans-serif;
  fill: #000000;
}

/* Data points: a shared outline, with the fill chosen per group class (see plotPts). */
.pts {
  stroke: #595959;
}

.group1 {
  fill: steelblue;
}

.group2 {
  fill: red;
}

/* All decision-boundary lines are drawn at 60% opacity. */
line {
  fill: none;
  opacity: 0.6;
}

/* Boundary fit WITHOUT momentum: thin black line (appears grey at 60% opacity). */
#dec_boundary {
  stroke: #000000;
  stroke-width: 2px;
}

/* Boundary fit WITH momentum: thicker green line. */
#dec_boundary_m {
  stroke: #008000;
  stroke-width: 4px;
}

/* Text label next to the slider that shows the current momentum coefficient. */
#beta_val {
  font-family: sans-serif;
  position: relative;
  left: 20px;
}
</style>
<body>
<!-- Range slider for beta (i.e. the momentum coefficient).
     oninput only refreshes the label while dragging; onchange (fired when the
     value is committed) restarts gradient descent with the new beta. -->
<form>
  <input type="range" name="beta" min="0" max="1.0" step="0.01" value="0.5"
         oninput="displayBeta(this.value)" onchange="runGradientDescent(this.value)">
  <label id="beta_val"></label>
</form>
<!-- viz -->
<div id="chart"></div>
<script src="https://d3js.org/d3.v4.min.js"></script>
<!-- Loaded over HTTPS so browsers do not block it as mixed content on HTTPS pages. -->
<script src="https://feyderm.github.io/math/math.js"></script>
<script type="text/javascript">
// ---- Layout ---------------------------------------------------------------
// Overall SVG size plus the margins reserved around the plotting area.
var margin = {top: 20, right: 0, bottom: 50, left: 85};
var svg_dx = 500;
var svg_dy = 400;
// Drawable extent after subtracting the horizontal / vertical margins.
var plot_dx = svg_dx - margin.right - margin.left;
var plot_dy = svg_dy - margin.top - margin.bottom;

// ---- Scales ---------------------------------------------------------------
// Only the pixel ranges are fixed here; domains are set once the CSV loads.
// The y range runs from bottom to top so larger values are drawn higher up.
var xPos = d3.scaleLinear().range([margin.left, plot_dx]);
var yPos = d3.scaleLinear().range([plot_dy, margin.top]);

// ---- Root SVG -------------------------------------------------------------
var svg = d3.select("#chart")
  .append("svg")
  .attr("width", svg_dx)
  .attr("height", svg_dy);
// Load the training data, then draw the axes and points and start the descent.
// The two-argument callback form is used so that a failed request is reported
// as an error. With a one-argument callback, d3 passes `null` on failure, and
// d3.extent(null, ...) would then crash with an unrelated TypeError.
d3.csv("logistic_reg_grad_decent.csv", (error, data) => {
  if (error) throw error;

  xPos.domain(d3.extent(data, pt => +pt.x));
  yPos.domain(d3.extent(data, pt => +pt.y));

  plotAxes(d3.axisBottom(xPos), d3.axisLeft(yPos));
  plotPts(data);

  runGradientDescent(0.5); // initial beta = 0.5 (the slider's midpoint)
});
// Timer driving the most recent descent, kept so that a restart (e.g. after
// moving the slider) can stop the previous animation first.
var gdTimer = null;

/**
 * Fit logistic-regression weights to the plotted points twice, side by side:
 * once with plain gradient descent and once with momentum. Each step redraws
 * the two resulting decision boundaries.
 *
 * @param {number|string} beta - momentum coefficient in [0, 1]. The slider
 *   passes its value as a string, so it is coerced to a number for the math.
 */
function runGradientDescent(beta) {
  // Stop a descent that is still animating from an earlier call. Otherwise its
  // timer keeps firing (updating detached lines) alongside the new one.
  if (gdTimer) gdTimer.stop();
  removeDecBnds();
  displayBeta(beta);

  var momentum = +beta;

  // The training data is read back from the points bound by plotPts.
  var d = d3.selectAll(".pts").data();
  var d_extent_x = d3.extent(d, pt => +pt.x);

  // Design matrix with a leading bias column: each row is [1, x, y].
  var X = math.matrix(d.map(pt => [1, +pt.x, +pt.y])),
      y = math.matrix(d.map(pt => +pt.group));

  var iteration = 0,
      iterationNumber = 400,
      m = d.length,               // number of training examples (rows of X)
      alpha = 0.0004,             // learning rate
      velocity = math.matrix([0.0, 0.0, 0.0]),
      theta = math.matrix([-24, 0.5, 0.2]),
      theta_m = math.matrix([-24, 0.5, 0.2]);

  // decision boundary w/o momentum
  var dec_bnd = svg.append("line")
    .attr("class", "dec_boundary")
    .attr("id", "dec_boundary");

  // decision boundary w/ momentum
  var dec_bnd_m = svg.append("line")
    .attr("class", "dec_boundary")
    .attr("id", "dec_boundary_m");

  // d3.timer calls the callback on each animation frame, starting after a
  // 200 ms delay. Each call performs one descent step for both fits.
  var iterate = d3.timer(() => {
    // Plain gradient descent: theta = theta - alpha * grad
    var h = math.multiply(X, theta).map(z => sigmoid(z)),
        grad = computeGradient(m, y, h, X);
    theta = math.subtract(theta, math.multiply(alpha, grad));
    updateDecisionBoundary(dec_bnd, theta, d_extent_x);

    // Momentum: velocity = beta * velocity + grad; theta_m = theta_m - alpha * velocity
    var h_m = math.multiply(X, theta_m).map(z => sigmoid(z)),
        grad_m = computeGradient(m, y, h_m, X);
    velocity = math.add(math.multiply(momentum, velocity), grad_m);
    theta_m = math.subtract(theta_m, math.multiply(alpha, velocity));
    updateDecisionBoundary(dec_bnd_m, theta_m, d_extent_x);

    // Stop after exactly iterationNumber steps.
    if (++iteration >= iterationNumber) {
      iterate.stop();
    }
  }, 200);

  gdTimer = iterate;
}
/**
 * Move a boundary <line> to the decision boundary of the logistic-regression
 * weights theta.
 *
 * The boundary is where theta0 + theta1 * x + theta2 * y = 0, i.e.
 * y = -(theta1 * x + theta0) / theta2. The line spans the data's x extent.
 * Both endpoints evaluate y at the same x they are drawn at, so the rendered
 * segment lies exactly on the boundary.
 *
 * @param dec_bnd     d3 selection of the <line> element to update
 * @param theta       math.js vector [theta0, theta1, theta2]
 * @param d_extent_x  [min, max] of the data's x values
 */
function updateDecisionBoundary(dec_bnd, theta, d_extent_x) {
  var theta0 = math.subset(theta, math.index(0)),
      theta1 = math.subset(theta, math.index(1)),
      theta2 = math.subset(theta, math.index(2));

  // With theta2 == 0 the boundary is vertical and y(x) is undefined. Keep the
  // previous line instead of writing Infinity/NaN coordinates.
  if (theta2 === 0) return;

  var boundaryY = x => (-1 / theta2) * (theta1 * x + theta0),
      x_lo = d_extent_x[0],
      x_hi = d_extent_x[1];

  dec_bnd.attr("x1", xPos(x_lo))
    .attr("y1", yPos(boundaryY(x_lo)))
    .attr("x2", xPos(x_hi))
    .attr("y2", yPos(boundaryY(x_hi)));
}
/**
 * Logistic sigmoid 1 / (1 + e^-z), mapping z into [0, 1] (saturating to
 * exactly 0 or 1 for large |z| in floating point).
 *
 * Math.exp is the dedicated exponential and avoids compounding the rounding
 * error of the constant Math.E, which Math.pow(Math.E, -z) would do. For very
 * negative z, exp(-z) overflows to Infinity and the result is exactly 0, the
 * correct limit, so no special case is needed.
 *
 * @param {number} z
 * @returns {number}
 */
function sigmoid(z) {
  return 1 / (1 + Math.exp(-z));
}
/**
 * Gradient of the logistic-regression cost with respect to theta.
 * Vectorized port of the Octave expression grad = (1 / m) * (h - y)' * X.
 *
 * @param m  number of training examples
 * @param y  math.js vector of labels (length m)
 * @param h  math.js vector of predictions sigmoid(X * theta) (length m)
 * @param X  math.js design matrix (m rows, one column per weight)
 * @returns  math.js vector with one partial derivative per weight
 */
function computeGradient(m, y, h, X) {
  // Element-wise subtraction avoids per-element subset lookups (which depended
  // on a one-element Matrix being coerced to a number) and the shadowed `h`.
  var residual = math.subtract(h, y);
  return math.multiply(1 / m, math.multiply(residual, X));
}
/**
 * Remove every decision-boundary <line> (class "dec_boundary") from the page.
 */
function removeDecBnds() {
  var staleLines = d3.selectAll(".dec_boundary");
  staleLines.remove();
}
/**
 * Write the current momentum coefficient into the #beta_val label.
 * beta is echoed as given (string from the slider, or a number).
 */
function displayBeta(beta) {
  var caption = "Momentum Coefficient: " + beta;
  d3.select("#beta_val").text(caption);
}
/**
 * Draw one <path> symbol per data row inside a new <g> on the SVG.
 * Rows whose group equals "1" get a circle and class "pts group1". All other
 * rows get a cross and class "pts group2". Each symbol is positioned via the
 * xPos / yPos scales.
 *
 * @param d  array of rows with x, y and group fields
 */
function plotPts(d) {
  // Loose equality, so both the CSV string "1" and a numeric 1 match.
  var inGroup1 = pt => pt.group == "1";
  var marker = d3.symbol()
    .type(pt => inGroup1(pt) ? d3.symbolCircle : d3.symbolCross);

  var points = svg.append("g")
    .selectAll("path")
    .data(d)
    .enter()
    .append("path");

  points.attr("class", pt => inGroup1(pt) ? "pts group1" : "pts group2");
  points.attr("d", marker);
  points.attr("transform", pt => "translate(" + xPos(pt.x) + "," + yPos(pt.y) + ")");
}
/**
 * Append the x and y axes to the SVG, each in its own <g>.
 * The x axis sits below the plot (half the bottom margin under plot_dy) and the
 * y axis half the left margin in from the SVG's left edge.
 *
 * @param x  d3 axis generator for the horizontal axis
 * @param y  d3 axis generator for the vertical axis
 */
function plotAxes(x, y) {
  var axisSpecs = [
    { id: "axis_x", transform: "translate(0," + (plot_dy + margin.bottom / 2) + ")", axis: x },
    { id: "axis_y", transform: "translate(" + (margin.left / 2) + ", 0)", axis: y }
  ];

  axisSpecs.forEach(spec => {
    svg.append("g")
      .attr("id", spec.id)
      .attr("transform", spec.transform)
      .call(spec.axis);
  });
}
</script> | |
</body> |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
x | y | group | |
---|---|---|---|
34.62365962451697 | 78.0246928153624 | 0 | |
30.28671076822607 | 43.89499752400101 | 0 | |
35.84740876993872 | 72.90219802708364 | 0 | |
60.18259938620976 | 86.30855209546826 | 1 | |
79.0327360507101 | 75.3443764369103 | 1 | |
45.08327747668339 | 56.3163717815305 | 0 | |
61.10666453684766 | 96.51142588489624 | 1 | |
75.02474556738889 | 46.55401354116538 | 1 | |
76.09878670226257 | 87.42056971926803 | 1 | |
84.43281996120035 | 43.53339331072109 | 1 | |
95.86155507093572 | 38.22527805795094 | 0 | |
75.01365838958247 | 30.60326323428011 | 0 | |
82.30705337399482 | 76.48196330235604 | 1 | |
69.36458875970939 | 97.71869196188608 | 1 | |
39.53833914367223 | 76.03681085115882 | 0 | |
53.9710521485623 | 89.20735013750205 | 1 | |
69.07014406283025 | 52.74046973016765 | 1 | |
67.94685547711617 | 46.67857410673128 | 0 | |
70.66150955499435 | 92.92713789364831 | 1 | |
76.97878372747498 | 47.57596364975532 | 1 | |
67.37202754570876 | 42.83843832029179 | 0 | |
89.67677575072079 | 65.79936592745237 | 1 | |
50.534788289883 | 48.85581152764205 | 0 | |
34.21206097786789 | 44.20952859866288 | 0 | |
77.9240914545704 | 68.9723599933059 | 1 | |
62.27101367004632 | 69.95445795447587 | 1 | |
80.1901807509566 | 44.82162893218353 | 1 | |
93.114388797442 | 38.80067033713209 | 0 | |
61.83020602312595 | 50.25610789244621 | 0 | |
38.78580379679423 | 64.99568095539578 | 0 | |
61.379289447425 | 72.80788731317097 | 1 | |
85.40451939411645 | 57.05198397627122 | 1 | |
52.10797973193984 | 63.12762376881715 | 0 | |
52.04540476831827 | 69.43286012045222 | 1 | |
40.23689373545111 | 71.16774802184875 | 0 | |
54.63510555424817 | 52.21388588061123 | 0 | |
33.91550010906887 | 98.86943574220611 | 0 | |
64.17698887494485 | 80.90806058670817 | 1 | |
74.78925295941542 | 41.57341522824434 | 0 | |
34.1836400264419 | 75.2377203360134 | 0 | |
83.90239366249155 | 56.30804621605327 | 1 | |
51.54772026906181 | 46.85629026349976 | 0 | |
94.44336776917852 | 65.56892160559052 | 1 | |
82.36875375713919 | 40.61825515970618 | 0 | |
51.04775177128865 | 45.82270145776001 | 0 | |
62.22267576120188 | 52.06099194836679 | 0 | |
77.19303492601364 | 70.45820000180959 | 1 | |
97.77159928000232 | 86.7278223300282 | 1 | |
62.07306379667647 | 96.76882412413983 | 1 | |
91.56497449807442 | 88.69629254546599 | 1 | |
79.94481794066932 | 74.16311935043758 | 1 | |
99.2725269292572 | 60.99903099844988 | 1 | |
90.54671411399852 | 43.39060180650027 | 1 | |
34.52451385320009 | 60.39634245837173 | 0 | |
50.2864961189907 | 49.80453881323059 | 0 | |
49.58667721632031 | 59.80895099453265 | 0 | |
97.64563396007767 | 68.86157272420604 | 1 | |
32.57720016809309 | 95.59854761387875 | 0 | |
74.24869136721598 | 69.82457122657193 | 1 | |
71.79646205863379 | 78.45356224515052 | 1 | |
75.3956114656803 | 85.75993667331619 | 1 | |
35.28611281526193 | 47.02051394723416 | 0 | |
56.25381749711624 | 39.26147251058019 | 0 | |
30.05882244669796 | 49.59297386723685 | 0 | |
44.66826172480893 | 66.45008614558913 | 0 | |
66.56089447242954 | 41.09209807936973 | 0 | |
40.45755098375164 | 97.53518548909936 | 1 | |
49.07256321908844 | 51.88321182073966 | 0 | |
80.27957401466998 | 92.11606081344084 | 1 | |
66.74671856944039 | 60.99139402740988 | 1 | |
32.72283304060323 | 43.30717306430063 | 0 | |
64.0393204150601 | 78.03168802018232 | 1 | |
72.34649422579923 | 96.22759296761404 | 1 | |
60.45788573918959 | 73.09499809758037 | 1 | |
58.84095621726802 | 75.85844831279042 | 1 | |
99.82785779692128 | 72.36925193383885 | 1 | |
47.26426910848174 | 88.47586499559782 | 1 | |
50.45815980285988 | 75.80985952982456 | 1 | |
60.45555629271532 | 42.50840943572217 | 0 | |
82.22666157785568 | 42.71987853716458 | 0 | |
88.9138964166533 | 69.80378889835472 | 1 | |
94.83450672430196 | 45.69430680250754 | 1 | |
67.31925746917527 | 66.58935317747915 | 1 | |
57.23870631569862 | 59.51428198012956 | 1 | |
80.36675600171273 | 90.96014789746954 | 1 | |
68.46852178591112 | 85.59430710452014 | 1 | |
42.0754545384731 | 78.84478600148043 | 0 | |
75.47770200533905 | 90.42453899753964 | 1 | |
78.63542434898018 | 96.64742716885644 | 1 | |
52.34800398794107 | 60.76950525602592 | 0 | |
94.09433112516793 | 77.15910509073893 | 1 | |
90.44855097096364 | 87.50879176484702 | 1 | |
55.48216114069585 | 35.57070347228866 | 0 | |
74.49269241843041 | 84.84513684930135 | 1 | |
89.84580670720979 | 45.35828361091658 | 1 | |
83.48916274498238 | 48.38028579728175 | 1 | |
42.2617008099817 | 87.10385094025457 | 1 | |
99.31500880510394 | 68.77540947206617 | 1 | |
55.34001756003703 | 64.9319380069486 | 1 | |
74.77589300092767 | 89.52981289513276 | 1 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment