Last active
March 7, 2020 07:54
-
-
Save jonathan-taylor/a4311d4c0f662c4e97f99475f389ef90 to your computer and use it in GitHub Desktop.
Comparison to regular sparse group LASSO
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import regreg.api as rr" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(500, 1001)" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X = np.loadtxt('X.csv', delimiter=',')\n", | |
"n, p = X.shape\n", | |
"# scale X\n", | |
"X -= X.mean(0)[None, :]\n", | |
"X /= (X.std(0)[None, :] * np.sqrt(n / (n-1)))\n", | |
"X = np.hstack([np.ones((X.shape[0],1)), X])\n", | |
"X.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(500,)" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"Y = np.loadtxt('Y.csv', delimiter=',')\n", | |
"Y.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,\n", | |
" 16, 17, 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,\n", | |
" 13, 14, 15, 16, 17, 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, 8])" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"groups = np.array([-1] + list(np.multiply.outer(np.ones(50, np.int), np.arange(20)).reshape(-1)))\n", | |
"groups[:50]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{-1: 0,\n", | |
" 0: 4.47213595499958,\n", | |
" 1: 4.47213595499958,\n", | |
" 2: 4.47213595499958,\n", | |
" 3: 4.47213595499958,\n", | |
" 4: 4.47213595499958,\n", | |
" 5: 4.47213595499958,\n", | |
" 6: 4.47213595499958,\n", | |
" 7: 4.47213595499958,\n", | |
" 8: 4.47213595499958,\n", | |
" 9: 4.47213595499958,\n", | |
" 10: 4.47213595499958,\n", | |
" 11: 4.47213595499958,\n", | |
" 12: 4.47213595499958,\n", | |
" 13: 4.47213595499958,\n", | |
" 14: 4.47213595499958,\n", | |
" 15: 4.47213595499958,\n", | |
" 16: 4.47213595499958,\n", | |
" 17: 4.47213595499958,\n", | |
" 18: 4.47213595499958,\n", | |
" 19: 4.47213595499958,\n", | |
" 20: 4.47213595499958,\n", | |
" 21: 4.47213595499958,\n", | |
" 22: 4.47213595499958,\n", | |
" 23: 4.47213595499958,\n", | |
" 24: 4.47213595499958,\n", | |
" 25: 4.47213595499958,\n", | |
" 26: 4.47213595499958,\n", | |
" 27: 4.47213595499958,\n", | |
" 28: 4.47213595499958,\n", | |
" 29: 4.47213595499958,\n", | |
" 30: 4.47213595499958,\n", | |
" 31: 4.47213595499958,\n", | |
" 32: 4.47213595499958,\n", | |
" 33: 4.47213595499958,\n", | |
" 34: 4.47213595499958,\n", | |
" 35: 4.47213595499958,\n", | |
" 36: 4.47213595499958,\n", | |
" 37: 4.47213595499958,\n", | |
" 38: 4.47213595499958,\n", | |
" 39: 4.47213595499958,\n", | |
" 40: 4.47213595499958,\n", | |
" 41: 4.47213595499958,\n", | |
" 42: 4.47213595499958,\n", | |
" 43: 4.47213595499958,\n", | |
" 44: 4.47213595499958,\n", | |
" 45: 4.47213595499958,\n", | |
" 46: 4.47213595499958,\n", | |
" 47: 4.47213595499958,\n", | |
" 48: 4.47213595499958,\n", | |
" 49: 4.47213595499958}" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"lasso_weights = np.ones_like(groups)\n", | |
"lasso_weights[0] = 0 # unpenalized intercept\n", | |
"group_weights = dict([(-1,0)] + [(j, np.sqrt(20)) for j in range(50)]) # no group penalty on [-1] group -- intercept\n", | |
"group_weights" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"penalty = rr.sparse_group_lasso(groups, lasso_weights, weights=group_weights, lagrange=1.)\n", | |
"alpha = 0.95\n", | |
"loss = rr.squared_error(X, Y, coef=1./n)\n", | |
"lagrange_val = np.array([0.0051939447, 0.0046011431, 0.0040759999, 0.0036107929, 0.0031986816,\n", | |
" 0.0028336058, 0.0025101972, 0.0022237004, 0.0019699023, 0.0017450710,\n", | |
" 0.0015459005, 0.0013694619, 0.0012131608, 0.0010746989, 0.0009520400,\n", | |
" 0.0008433807, 0.0007471229, 0.0006618514, 0.0005863122, 0.0005193945])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"enet_diag = np.ones(X.shape[1])\n", | |
"enet_diag[0] = 0 # no ridge on intercept\n", | |
"enet_loss = rr.quadratic_loss(X.shape[1], enet_diag, Qdiag=True)\n", | |
"final_loss = rr.smooth_sum([loss, enet_loss])\n", | |
"problem = rr.simple_problem(final_loss, penalty)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1.28 s ± 37.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"soln = np.zeros(X.shape[1])\n", | |
"for lagrange in lagrange_val:\n", | |
" enet_loss.coef = (1 - alpha) * lagrange\n", | |
" penalty.lagrange = alpha * lagrange\n", | |
" problem.coefs[:] = soln\n", | |
" soln = problem.solve(tol=1.e-12)\n", | |
" soln[:] = problem.coefs" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"jupytext": { | |
"cell_metadata_filter": "all,-slideshow" | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment