Last active
March 7, 2020 07:54
-
-
Save jonathan-taylor/a4311d4c0f662c4e97f99475f389ef90 to your computer and use it in GitHub Desktop.
Comparison to regular sparse group LASSO
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": { | |
| "collapsed": true, | |
| "lines_to_next_cell": 2 | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "from regreg.smooth.cox import cox_loglike\n", | |
| "import regreg.api as rr" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "ncase, nfeature, ndisease = 1000, 5000, 20\n", | |
| "losses = []\n", | |
| "for _ in range(ndisease):\n", | |
| " times = np.random.exponential(size=(ncase,))\n", | |
| " censoring = np.array([0]*int(0.3*ncase) + [1]*int(0.7*ncase))\n", | |
| " np.random.shuffle(censoring)\n", | |
| " losses.append(cox_loglike((ncase,),\n", | |
| " times,\n", | |
| " censoring))\n", | |
| "X = np.random.standard_normal((ncase, nfeature))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "class cox_stacked(rr.smooth_atom):\n", | |
| "\n", | |
| " \n", | |
| " \n", | |
| " def __init__(self,\n", | |
| " losses,\n", | |
| " X,\n", | |
| " quadratic=None, \n", | |
| " initial=None,\n", | |
| " offset=None):\n", | |
| " \n", | |
| " self.losses = losses\n", | |
| " self.ndisease = len(losses)\n", | |
| " self.ncase, self.nfeature = X.shape\n", | |
| "\n", | |
| " self.X, self.X_T = X, X.T\n", | |
| "\n", | |
| " assert(np.all(np.array([loss.shape[0] for loss in losses]) == self.ncase))\n", | |
| " \n", | |
| " rr.smooth_atom.__init__(self,\n", | |
| " (self.nfeature, self.ndisease),\n", | |
| " offset=offset,\n", | |
| " quadratic=quadratic,\n", | |
| " initial=initial)\n", | |
| " self._gradient = np.zeros((self.ndisease, self.ncase))\n", | |
| "\n", | |
| " def smooth_objective(self, arg, mode='both', check_feasibility=False):\n", | |
| "\n", | |
| " arg = self.apply_offset(arg) # (nfeature, ndisease)\n", | |
| " linpred = arg.T.dot(self.X_T) # (ndisease, ncase)\n", | |
| " if mode == 'grad':\n", | |
| " for d in range(self.ndisease):\n", | |
| " self._gradient[d] = self.losses[d].smooth_objective(linpred[d], 'grad')\n", | |
| " return self.scale(self._gradient.dot(self.X).T)\n", | |
| " elif mode == 'func':\n", | |
| " value = 0\n", | |
| " for d in range(self.ndisease):\n", | |
| " value += self.losses[d].smooth_objective(linpred[d], 'func')\n", | |
| " return self.scale(value)\n", | |
| " elif mode == 'both':\n", | |
| " value = 0\n", | |
| " for d in range(self.ndisease):\n", | |
| " f, g = self.losses[d].smooth_objective(linpred[d], 'both')\n", | |
| " self._gradient[d] = g\n", | |
| " value += f\n", | |
| " return self.scale(value), self.scale(self._gradient.dot(self.X).T)\n", | |
| " else:\n", | |
| " raise ValueError(\"mode incorrectly specified\")\n", | |
| "\n", | |
| "loss = cox_stacked(losses, X)\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Check the loss can be computed\n", | |
| "\n", | |
| "- We'll use `G`, the gradient of the loss at zero, to compute $\\lambda_{\\max}$." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "F, G = loss.smooth_objective(np.zeros((nfeature, ndisease)), 'both')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "2.667806625366211" | |
| ] | |
| }, | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "penalty = rr.sparse_group_block(loss.shape, l1_weight=1., l2_weight=np.sqrt(nfeature), lagrange=1.)\n", | |
| "dual = penalty.conjugate\n", | |
| "lambda_max = dual.seminorm(G, lagrange=1)\n", | |
| "penalty.lagrange = lambda_max\n", | |
| "lambda_max" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Timing at $\\lambda_{\\max}$" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "1.63 s ± 308 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%timeit\n", | |
| "problem = rr.simple_problem(loss, penalty)\n", | |
| "soln = problem.solve(tol=1.e-9)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.99982" | |
| ] | |
| }, | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "problem = rr.simple_problem(loss, penalty)\n", | |
| "soln = problem.solve(tol=1.e-9)\n", | |
| "np.mean(soln == 0)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Timing at $0.9 \\lambda_{\\max}$" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "691 ms ± 23.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%timeit\n", | |
| "penalty.lagrange = 0.9 * lambda_max\n", | |
| "problem = rr.simple_problem(loss, penalty)\n", | |
| "problem.coefs[:] = 0 # start at 0\n", | |
| "soln = problem.solve(tol=1.e-9)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(4, 77)" | |
| ] | |
| }, | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "penalty.lagrange = 0.9 * lambda_max\n", | |
| "problem = rr.simple_problem(loss, penalty)\n", | |
| "soln = problem.solve(tol=1.e-9)\n", | |
| "np.sum(np.sum(soln**2, 1) > 0), np.sum(soln != 0)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Timing at $0.8 \\lambda_{\\max}$" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "701 ms ± 15.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%timeit\n", | |
| "penalty.lagrange = 0.8 * lambda_max\n", | |
| "problem = rr.simple_problem(loss, penalty)\n", | |
| "problem.coefs[:] = 0 # start at 0\n", | |
| "soln = problem.solve(tol=1.e-9)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(56, 1063, 2.134244918823242, 2.1342453002929687)" | |
| ] | |
| }, | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "penalty.lagrange = 0.8 * lambda_max\n", | |
| "problem = rr.simple_problem(loss, penalty)\n", | |
| "soln = problem.solve(tol=1.e-9) \n", | |
| "np.sum(np.sum(soln**2, 1) > 0), np.sum(soln != 0), dual.seminorm(loss.smooth_objective(soln, 'grad'), lagrange=1.), 0.8 * lambda_max" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Timing at $0.7 \\lambda_{\\max}$" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "843 ms ± 17.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%timeit\n", | |
| "penalty.lagrange = 0.7 * lambda_max\n", | |
| "problem = rr.simple_problem(loss, penalty)\n", | |
| "problem.coefs[:] = 0 # start at 0\n", | |
| "soln = problem.solve(tol=1.e-9)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(409, 7807, 1.8674650192260742, 1.8674646377563475)" | |
| ] | |
| }, | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "penalty.lagrange = 0.7 * lambda_max\n", | |
| "problem = rr.simple_problem(loss, penalty)\n", | |
| "soln = problem.solve(tol=1.e-9) \n", | |
| "np.sum(np.sum(soln**2, 1) > 0), np.sum(soln != 0), dual.seminorm(loss.smooth_objective(soln, 'grad'), lagrange=1.), 0.7 * lambda_max" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "jupytext": { | |
| "cell_metadata_filter": "all,-slideshow", | |
| "formats": "ipynb,Rmd" | |
| }, | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.6.2" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment