Last active
March 7, 2020 07:54
-
-
Save jonathan-taylor/a4311d4c0f662c4e97f99475f389ef90 to your computer and use it in GitHub Desktop.
Comparison to regular sparse group LASSO
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": { | |
| "lines_to_next_cell": 2 | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "from regreg.smooth.cox import cox_loglike\n", | |
| "import regreg.api as rr" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "ncase, nfeature, ndisease = 1000, 5000, 20\n", | |
| "losses = []\n", | |
| "for _ in range(ndisease):\n", | |
| " times = np.random.exponential(size=(ncase,))\n", | |
| " censoring = np.array([0]*int(0.3*ncase) + [1]*int(0.7*ncase))\n", | |
| " np.random.shuffle(censoring)\n", | |
| " losses.append(cox_loglike((ncase,),\n", | |
| " times,\n", | |
| " censoring))\n", | |
| "X = np.random.standard_normal((ncase, nfeature))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "class cox_stacked(rr.smooth_atom): \n", | |
| " \n", | |
| " def __init__(self,\n", | |
| " losses,\n", | |
| " X,\n", | |
| " quadratic=None, \n", | |
| " initial=None,\n", | |
| " offset=None):\n", | |
| " \n", | |
| " self.losses = losses\n", | |
| " self.ndisease = len(losses)\n", | |
| " self.ncase, self.nfeature = X.shape\n", | |
| "\n", | |
| " self.X, self.X_T = X, X.T\n", | |
| "\n", | |
| " assert(np.all(np.array([loss.shape[0] for loss in losses]) == self.ncase))\n", | |
| " \n", | |
| " rr.smooth_atom.__init__(self,\n", | |
| " (self.nfeature, self.ndisease),\n", | |
| " offset=offset,\n", | |
| " quadratic=quadratic,\n", | |
| " initial=initial)\n", | |
| " self._gradient = np.zeros((self.ndisease, self.ncase))\n", | |
| "\n", | |
| " def smooth_objective(self, arg, mode='both', check_feasibility=False):\n", | |
| "\n", | |
| " arg = self.apply_offset(arg) # (nfeature, ndisease)\n", | |
| " linpred = arg.T.dot(self.X_T) # (ndisease, ncase)\n", | |
| " if mode == 'grad':\n", | |
| " for d in range(self.ndisease):\n", | |
| " self._gradient[d] = self.losses[d].smooth_objective(linpred[d], 'grad')\n", | |
| " return self.scale(self._gradient.dot(self.X).T)\n", | |
| " elif mode == 'func':\n", | |
| " value = 0\n", | |
| " for d in range(self.ndisease):\n", | |
| " value += self.losses[d].smooth_objective(linpred[d], 'func')\n", | |
| " return self.scale(value)\n", | |
| " elif mode == 'both':\n", | |
| " value = 0\n", | |
| " for d in range(self.ndisease):\n", | |
| " f, g = self.losses[d].smooth_objective(linpred[d], 'both')\n", | |
| " self._gradient[d] = g\n", | |
| " value += f\n", | |
| " return self.scale(value), self.scale(self._gradient.dot(self.X).T)\n", | |
| " else:\n", | |
| " raise ValueError(\"mode incorrectly specified\")\n", | |
| "\n", | |
| "loss = cox_stacked(losses, X)\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Check that the loss can be computed\n", | |
| "\n", | |
| "- We'll use the gradient `G` at zero to compute $\\lambda_{\\max}$" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "F, G = loss.smooth_objective(np.zeros((nfeature, ndisease)), 'both')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "25.649429321289062" | |
| ] | |
| }, | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "penalty = rr.sparse_group_block(loss.shape, l1_weight=1., l2_weight=np.sqrt(ndisease), lagrange=1.)\n", | |
| "dual = penalty.conjugate\n", | |
| "lambda_max = dual.seminorm(G, lagrange=1)\n", | |
| "penalty.lagrange = lambda_max\n", | |
| "lambda_max" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Timing at $\\lambda_{\\max}$" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "4.67 s ± 392 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%timeit\n", | |
| "problem = rr.simple_problem(loss, penalty)\n", | |
| "soln = problem.solve(tol=1.e-9, min_its=100)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.99987" | |
| ] | |
| }, | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "problem = rr.simple_problem(loss, penalty)\n", | |
| "soln = problem.solve(tol=1.e-9, min_its=100)\n", | |
| "np.mean(soln == 0)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Timing at $0.9 \\lambda_{\\max}$" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "4.48 s ± 360 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%timeit\n", | |
| "# Time a solve of the sparse_group_block problem at 0.9 * lambda_max, started from zero.\n", | |
| "penalty.lagrange = 0.9 * lambda_max\n", | |
| "problem = rr.simple_problem(loss, penalty)\n", | |
| "problem.coefs[:] = 0 # start at 0\n", | |
| "soln = problem.solve(tol=1.e-9, min_its=100)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(5, 52)" | |
| ] | |
| }, | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "penalty.lagrange = 0.9 * lambda_max\n", | |
| "problem = rr.simple_problem(loss, penalty)\n", | |
| "soln = problem.solve(tol=1.e-9, min_its=100)\n", | |
| "np.sum(np.sum(soln**2, 1) > 0), np.sum(soln != 0)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Timing at $0.8 \\lambda_{\\max}$" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "4.14 s ± 234 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%timeit\n", | |
| "# Time a solve of the sparse_group_block problem at 0.8 * lambda_max, started from zero.\n", | |
| "penalty.lagrange = 0.8 * lambda_max\n", | |
| "problem = rr.simple_problem(loss, penalty)\n", | |
| "problem.coefs[:] = 0 # start at 0\n", | |
| "soln = problem.solve(tol=1.e-9, min_its=100)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(95, 1032, 20.519546508789062, 20.51954345703125)" | |
| ] | |
| }, | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "penalty.lagrange = 0.8 * lambda_max\n", | |
| "problem = rr.simple_problem(loss, penalty)\n", | |
| "soln = problem.solve(tol=1.e-9, min_its=100) \n", | |
| "np.sum(np.sum(soln**2, 1) > 0), np.sum(soln != 0), dual.seminorm(loss.smooth_objective(soln, 'grad'), lagrange=1.), 0.8 * lambda_max" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Timing at $0.7 \\lambda_{\\max}$" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "4.04 s ± 148 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%timeit\n", | |
| "# Time a solve of the sparse_group_block problem at 0.7 * lambda_max, started from zero.\n", | |
| "penalty.lagrange = 0.7 * lambda_max\n", | |
| "problem = rr.simple_problem(loss, penalty)\n", | |
| "problem.coefs[:] = 0 # start at 0\n", | |
| "soln = problem.solve(tol=1.e-9, min_its=100)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(475, 5351, 17.954605102539062, 17.954600524902343)" | |
| ] | |
| }, | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "penalty.lagrange = 0.7 * lambda_max\n", | |
| "problem = rr.simple_problem(loss, penalty)\n", | |
| "problem.coefs[:] = 0\n", | |
| "soln = problem.solve(tol=1.e-9, min_its=100) \n", | |
| "np.sum(np.sum(soln**2, 1) > 0), np.sum(soln != 0), dual.seminorm(loss.smooth_objective(soln, 'grad'), lagrange=1.), 0.7 * lambda_max" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Comparison with the ordinary (non-block) sparse group LASSO" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "groups = np.multiply.outer(np.arange(nfeature), np.ones(ndisease)).reshape(-1)\n", | |
| "group_penalty = rr.sparse_group_lasso(groups, np.ones(nfeature * ndisease), lagrange=0.7 * lambda_max)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "class cox_stacked_flat(rr.smooth_atom): \n", | |
| " \n", | |
| " def __init__(self,\n", | |
| " losses,\n", | |
| " X,\n", | |
| " quadratic=None, \n", | |
| " initial=None,\n", | |
| " offset=None):\n", | |
| " \n", | |
| " self.losses = losses\n", | |
| " self.ndisease = len(losses)\n", | |
| " self.ncase, self.nfeature = X.shape\n", | |
| "\n", | |
| " self.X, self.X_T = X, X.T\n", | |
| "\n", | |
| " assert(np.all(np.array([loss.shape[0] for loss in losses]) == self.ncase))\n", | |
| " \n", | |
| " rr.smooth_atom.__init__(self,\n", | |
| " (self.nfeature * self.ndisease,),\n", | |
| " offset=offset,\n", | |
| " quadratic=quadratic,\n", | |
| " initial=initial)\n", | |
| " self._gradient = np.zeros((self.ndisease, self.ncase))\n", | |
| "\n", | |
| " def smooth_objective(self, arg, mode='both', check_feasibility=False):\n", | |
| "\n", | |
| " arg = self.apply_offset(arg) # (nfeature, ndisease)\n", | |
| " arg = arg.reshape((self.nfeature, self.ndisease))\n", | |
| " linpred = arg.T.dot(self.X_T) # (ndisease, ncase)\n", | |
| " if mode == 'grad':\n", | |
| " for d in range(self.ndisease):\n", | |
| " self._gradient[d] = self.losses[d].smooth_objective(linpred[d], 'grad')\n", | |
| " return self.scale(self._gradient.dot(self.X).T.reshape(-1))\n", | |
| " elif mode == 'func':\n", | |
| " value = 0\n", | |
| " for d in range(self.ndisease):\n", | |
| " value += self.losses[d].smooth_objective(linpred[d], 'func')\n", | |
| " return self.scale(value)\n", | |
| " elif mode == 'both':\n", | |
| " value = 0\n", | |
| " for d in range(self.ndisease):\n", | |
| " f, g = self.losses[d].smooth_objective(linpred[d], 'both')\n", | |
| " self._gradient[d] = g\n", | |
| " value += f\n", | |
| " return self.scale(value), self.scale(self._gradient.dot(self.X).T.reshape(-1))\n", | |
| " else:\n", | |
| " raise ValueError(\"mode incorrectly specified\")\n", | |
| "\n", | |
| "loss_flat = cox_stacked_flat(losses, X)\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "5.27 s ± 216 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%timeit\n", | |
| "# Time the flat (non-block) sparse group LASSO solve at 0.7 * lambda_max, started from zero.\n", | |
| "group_penalty.lagrange = 0.7 * lambda_max\n", | |
| "problem = rr.simple_problem(loss_flat, group_penalty)\n", | |
| "problem.coefs[:] = 0 # start at 0\n", | |
| "soln = problem.solve(tol=1.e-9, min_its=100)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Maybe we haven't solved accurately enough -- compare with a solve started from a random point\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "group_penalty.lagrange = 0.7 * lambda_max\n", | |
| "problem = rr.simple_problem(loss_flat, group_penalty)\n", | |
| "problem.coefs[:] = 0\n", | |
| "soln_flat = problem.solve(tol=1.e-9, min_its=100)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "group_penalty.lagrange = 0.7 * lambda_max\n", | |
| "problem = rr.simple_problem(loss_flat, group_penalty)\n", | |
| "problem.coefs[:] = np.random.standard_normal(group_penalty.shape) * 0.1\n", | |
| "soln_flat_r = problem.solve(tol=1.e-9, min_its=100)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.0022612382977817364" | |
| ] | |
| }, | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "np.linalg.norm(soln_flat - soln_flat_r) / np.linalg.norm(soln_flat)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Let's increase the number of iterations a bit" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "group_penalty.lagrange = 0.7 * lambda_max\n", | |
| "problem = rr.simple_problem(loss_flat, group_penalty)\n", | |
| "problem.coefs[:] = 0\n", | |
| "soln_flat = problem.solve(tol=1.e-9, min_its=200)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "group_penalty.lagrange = 0.7 * lambda_max\n", | |
| "problem = rr.simple_problem(loss_flat, group_penalty)\n", | |
| "problem.coefs[:] = np.random.standard_normal(group_penalty.shape) * 0.1\n", | |
| "soln_flat_r = problem.solve(tol=1.e-9, min_its=200)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "6.217488899625311e-14" | |
| ] | |
| }, | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "np.linalg.norm(soln_flat - soln_flat_r) / np.linalg.norm(soln_flat) # now we've essentially found same solution" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Timing at 200 iterations\n", | |
| "\n", | |
| "This should take roughly twice as long, but possibly a bit less: the Lipschitz constant estimate (inverse step size) may have settled down, so less backtracking is needed." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 23, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "10.1 s ± 265 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%timeit\n", | |
| "# Time the flat sparse group LASSO solve at 0.7 * lambda_max with min_its=200, started from zero.\n", | |
| "group_penalty.lagrange = 0.7 * lambda_max\n", | |
| "problem = rr.simple_problem(loss_flat, group_penalty)\n", | |
| "problem.coefs[:] = 0 # start at 0\n", | |
| "soln = problem.solve(tol=1.e-9, min_its=200)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 24, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "7.46 s ± 55 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%timeit\n", | |
| "# Time the sparse_group_block solve at 0.7 * lambda_max with min_its=200, for comparison with the flat timing.\n", | |
| "penalty.lagrange = 0.7 * lambda_max\n", | |
| "problem = rr.simple_problem(loss, penalty)\n", | |
| "problem.coefs[:] = 0 # start at 0\n", | |
| "soln = problem.solve(tol=1.e-9, min_its=200)" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "jupytext": { | |
| "cell_metadata_filter": "all,-slideshow", | |
| "formats": "ipynb,Rmd" | |
| }, | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.6.2" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment