Last active
March 7, 2020 07:54
-
-
Save jonathan-taylor/a4311d4c0f662c4e97f99475f389ef90 to your computer and use it in GitHub Desktop.
Comparison to regular sparse group LASSO
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "import regreg.api as rr" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(500, 1001)" | |
| ] | |
| }, | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "X = np.loadtxt('X.csv', delimiter=',')\n", | |
| "n, p = X.shape\n", | |
| "# center each column of X and scale it to unit sample (ddof=1) standard deviation\n", | |
| "X -= X.mean(0)[None, :]\n", | |
| "X /= (X.std(0)[None, :] * np.sqrt(n / (n-1)))\n", | |
| "X = np.hstack([np.ones((X.shape[0],1)), X])\n", | |
| "X.shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(500,)" | |
| ] | |
| }, | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "Y = np.loadtxt('Y.csv', delimiter=',')\n", | |
| "Y.shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,\n", | |
| " 16, 17, 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,\n", | |
| " 13, 14, 15, 16, 17, 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, 8])" | |
| ] | |
| }, | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "groups = np.array([-1] + list(np.multiply.outer(np.ones(50, int), np.arange(20)).reshape(-1)))\n", | |
| "groups[:50]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "{-1: 0,\n", | |
| " 0: 4.47213595499958,\n", | |
| " 1: 4.47213595499958,\n", | |
| " 2: 4.47213595499958,\n", | |
| " 3: 4.47213595499958,\n", | |
| " 4: 4.47213595499958,\n", | |
| " 5: 4.47213595499958,\n", | |
| " 6: 4.47213595499958,\n", | |
| " 7: 4.47213595499958,\n", | |
| " 8: 4.47213595499958,\n", | |
| " 9: 4.47213595499958,\n", | |
| " 10: 4.47213595499958,\n", | |
| " 11: 4.47213595499958,\n", | |
| " 12: 4.47213595499958,\n", | |
| " 13: 4.47213595499958,\n", | |
| " 14: 4.47213595499958,\n", | |
| " 15: 4.47213595499958,\n", | |
| " 16: 4.47213595499958,\n", | |
| " 17: 4.47213595499958,\n", | |
| " 18: 4.47213595499958,\n", | |
| " 19: 4.47213595499958,\n", | |
| " 20: 4.47213595499958,\n", | |
| " 21: 4.47213595499958,\n", | |
| " 22: 4.47213595499958,\n", | |
| " 23: 4.47213595499958,\n", | |
| " 24: 4.47213595499958,\n", | |
| " 25: 4.47213595499958,\n", | |
| " 26: 4.47213595499958,\n", | |
| " 27: 4.47213595499958,\n", | |
| " 28: 4.47213595499958,\n", | |
| " 29: 4.47213595499958,\n", | |
| " 30: 4.47213595499958,\n", | |
| " 31: 4.47213595499958,\n", | |
| " 32: 4.47213595499958,\n", | |
| " 33: 4.47213595499958,\n", | |
| " 34: 4.47213595499958,\n", | |
| " 35: 4.47213595499958,\n", | |
| " 36: 4.47213595499958,\n", | |
| " 37: 4.47213595499958,\n", | |
| " 38: 4.47213595499958,\n", | |
| " 39: 4.47213595499958,\n", | |
| " 40: 4.47213595499958,\n", | |
| " 41: 4.47213595499958,\n", | |
| " 42: 4.47213595499958,\n", | |
| " 43: 4.47213595499958,\n", | |
| " 44: 4.47213595499958,\n", | |
| " 45: 4.47213595499958,\n", | |
| " 46: 4.47213595499958,\n", | |
| " 47: 4.47213595499958,\n", | |
| " 48: 4.47213595499958,\n", | |
| " 49: 4.47213595499958}" | |
| ] | |
| }, | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "lasso_weights = np.ones_like(groups)\n", | |
| "lasso_weights[0] = 0 # unpenalized intercept\n", | |
| "group_weights = dict([(-1,0)] + [(j, np.sqrt(20)) for j in range(50)]) # no group penalty on [-1] group -- intercept\n", | |
| "group_weights" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "penalty = rr.sparse_group_lasso(groups, lasso_weights, weights=group_weights, lagrange=1.)\n", | |
| "alpha = 0.95\n", | |
| "loss = rr.squared_error(X, Y, coef=1./n)\n", | |
| "lagrange_val = np.array([0.0051939447, 0.0046011431, 0.0040759999, 0.0036107929, 0.0031986816,\n", | |
| " 0.0028336058, 0.0025101972, 0.0022237004, 0.0019699023, 0.0017450710,\n", | |
| " 0.0015459005, 0.0013694619, 0.0012131608, 0.0010746989, 0.0009520400,\n", | |
| " 0.0008433807, 0.0007471229, 0.0006618514, 0.0005863122, 0.0005193945])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "enet_diag = np.ones(X.shape[1])\n", | |
| "enet_diag[0] = 0 # no ridge on intercept\n", | |
| "enet_loss = rr.quadratic_loss(X.shape[1], enet_diag, Qdiag=True)\n", | |
| "final_loss = rr.smooth_sum([loss, enet_loss])\n", | |
| "problem = rr.simple_problem(final_loss, penalty)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "1.28 s ± 37.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%timeit\n", | |
| "soln = np.zeros(X.shape[1])\n", | |
| "for lagrange in lagrange_val:\n", | |
| " enet_loss.coef = (1 - alpha) * lagrange\n", | |
| " penalty.lagrange = alpha * lagrange\n", | |
| " problem.coefs[:] = soln\n", | |
| " soln = problem.solve(tol=1.e-12)\n", | |
| " soln[:] = problem.coefs" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "jupytext": { | |
| "cell_metadata_filter": "all,-slideshow" | |
| }, | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.5.4" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment