Last active
March 7, 2020 07:54
-
-
Save jonathan-taylor/a4311d4c0f662c4e97f99475f389ef90 to your computer and use it in GitHub Desktop.
Comparison to regular sparse group LASSO
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": { | |
| "lines_to_next_cell": 2 | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array(['X_list', 'y_list', 'censor_list'], dtype='<U11')" | |
| ] | |
| }, | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "import numpy as np\n", | |
| "from regreg.smooth.cox import cox_loglike\n", | |
| "import regreg.api as rr\n", | |
| "import regreg.affine as ra\n", | |
| "%load_ext rpy2.ipython\n", | |
| "%R load('instance.RData')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def load_data(idx):\n", | |
| " %R -i idx -o X X = X_list[[idx]]\n", | |
| " %R -o Y Y = y_list[[idx]]\n", | |
| " %R -o C C = censor_list[[idx]]\n", | |
| " return X, Y, C\n", | |
| "datasets = [load_data(idx) for idx in range(1, 21)]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "losses = [rr.cox_loglike(Y.shape[0], Y.reshape(-1), C.reshape(-1), coef=1./Y.shape[0]) for _, Y, C in datasets]\n", | |
| "Xblock = ra.block_transform([X for X, _, _ in datasets])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "((5000, 20), (22100,))" | |
| ] | |
| }, | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "Xblock.input_shape, Xblock.output_shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(22100,)" | |
| ] | |
| }, | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "class cox_stacked(rr.smooth_atom):\n", | |
| "\n", | |
| " def __init__(self,\n", | |
| " losses,\n", | |
| " X,\n", | |
| " quadratic=None, \n", | |
| " initial=None,\n", | |
| " offset=None):\n", | |
| " \n", | |
| " self.losses = losses\n", | |
| " self.ndisease = len(losses)\n", | |
| " self.nfeature = X.shape[0]\n", | |
| "\n", | |
| " self.X, self.X_T = X, X.T\n", | |
| " \n", | |
| " rr.smooth_atom.__init__(self,\n", | |
| " self.X.output_shape,\n", | |
| " offset=offset,\n", | |
| " quadratic=quadratic,\n", | |
| " initial=initial)\n", | |
| " self._gradient = np.zeros(X.output_shape)\n", | |
| "\n", | |
| " def smooth_objective(self, arg, mode='both', check_feasibility=False):\n", | |
| "\n", | |
| " arg = self.apply_offset(arg) # (nfeature, ndisease)\n", | |
| " linpred = self.X.dot(arg) # (ndisease, ncase)\n", | |
| " if mode == 'grad':\n", | |
| " for d, slice in enumerate(self.X._slices):\n", | |
| " self._gradient[slice] = self.losses[d].smooth_objective(linpred[slice], 'grad')\n", | |
| " return self.scale(self.X_T.dot(self._gradient))\n", | |
| " elif mode == 'func':\n", | |
| " value = 0\n", | |
| " for d, slice in enumerate(self.X._slices):\n", | |
| " value += self.losses[d].smooth_objective(linpred[slice], 'func')\n", | |
| " return self.scale(value)\n", | |
| " elif mode == 'both':\n", | |
| " value = 0\n", | |
| " for d, slice in enumerate(self.X._slices):\n", | |
| " f, g = self.losses[d].smooth_objective(linpred[slice], 'both')\n", | |
| " self._gradient[slice] = g\n", | |
| " value += f\n", | |
| " return self.scale(value), self.scale(self.X_T.dot(self._gradient))\n", | |
| " else:\n", | |
| " raise ValueError(\"mode incorrectly specified\")\n", | |
| "\n", | |
| "loss = cox_stacked(losses, Xblock)\n", | |
| "loss.shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Check that the loss can be computed\n", | |
| "\n", | |
| "- We'll use `G`, the gradient of the loss at zero, to compute $\\lambda_{\\max}$" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(5000, 20)" | |
| ] | |
| }, | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "F, G = loss.smooth_objective(np.zeros(Xblock.input_shape), 'both')\n", | |
| "G.shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.009502477943897247" | |
| ] | |
| }, | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "nfeature = Xblock.input_shape[0]\n", | |
| "alpha = 0.95\n", | |
| "penalty = rr.sparse_group_block(Xblock.input_shape, l1_weight=alpha, \n", | |
| " l2_weight=(1-alpha)*np.sqrt(nfeature), lagrange=1.)\n", | |
| "dual = penalty.conjugate\n", | |
| "lambda_max = dual.seminorm(G, lagrange=1)\n", | |
| "penalty.lagrange = lambda_max\n", | |
| "lambda_max" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "1.0" | |
| ] | |
| }, | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "problem = rr.simple_problem(loss, penalty)\n", | |
| "soln = problem.solve(tol=1.e-9)\n", | |
| "np.mean(soln == 0)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## First 10 values of a log-spaced grid of length 100 from $\\lambda_{\\max}$ down to $0.01\\lambda_{\\max}$" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([0.00950248, 0.00907058, 0.0086583 , 0.00826477, 0.00788912,\n", | |
| " 0.00753055, 0.00718828, 0.00686156, 0.00654969, 0.006252 ])" | |
| ] | |
| }, | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "lagrange_vals = np.exp(np.linspace(0, np.log(0.01), 100))[:10] * lambda_max\n", | |
| "lagrange_vals" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Timing" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "0 0 0.009502477943897247 0.009502477943897247\n", | |
| "5 33 0.00907132774591446 0.009070575655810235\n", | |
| "17 119 0.008658461272716522 0.008658303993288064\n", | |
| "34 248 0.008264981210231781 0.008264770714102115\n", | |
| "76 587 0.007889382541179657 0.0078891241298101\n", | |
| "149 1230 0.007531158626079559 0.0075305512625238645\n", | |
| "235 2018 0.0071884579956531525 0.007188276085454979\n", | |
| "362 3209 0.0068618617951869965 0.0068615578434302205\n", | |
| "504 4537 0.006549973040819168 0.00654968944974222\n", | |
| "710 6535 0.006252247840166092 0.006251995955865733\n", | |
| "time: 43.948302\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "from time import time\n", | |
| "toc = time()\n", | |
| "solns = []\n", | |
| "problem.coefs[:] = 0\n", | |
| "for lagrange in lagrange_vals:\n", | |
| " penalty.lagrange = lagrange\n", | |
| " soln = problem.solve(tol=1.e-12)\n", | |
| " solns.append(soln.copy())\n", | |
| " print(np.sum(np.sum(soln**2, 1) > 0), np.sum(soln != 0), dual.seminorm(loss.smooth_objective(soln, 'grad'), lagrange=1.), lagrange)\n", | |
| "tic = time()\n", | |
| "print('time: %f' % (tic-toc))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## What if $\\alpha=1$? (the group $\\ell_2$ weight is then zero, leaving a pure $\\ell_1$ penalty)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.040050357580184937" | |
| ] | |
| }, | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "nfeature = Xblock.input_shape[0]\n", | |
| "alpha = 1\n", | |
| "penalty = rr.sparse_group_block(Xblock.input_shape, l1_weight=alpha, \n", | |
| " l2_weight=(1-alpha)*np.sqrt(nfeature), lagrange=1.)\n", | |
| "dual = penalty.conjugate\n", | |
| "lambda_max = dual.seminorm(G, lagrange=1)\n", | |
| "penalty.lagrange = lambda_max\n", | |
| "lambda_max" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "1.0" | |
| ] | |
| }, | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "problem = rr.simple_problem(loss, penalty)\n", | |
| "soln = problem.solve(tol=1.e-9)\n", | |
| "np.mean(soln == 0)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([0.04005036, 0.03823001, 0.03649239, 0.03483376, 0.03325051,\n", | |
| " 0.03173922, 0.03029663, 0.0289196 , 0.02760516, 0.02635046])" | |
| ] | |
| }, | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "lagrange_vals = np.exp(np.linspace(0, np.log(0.01), 100))[:10] * lambda_max\n", | |
| "lagrange_vals" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "0 0 0.040050357580184937 0.040050357580184937\n", | |
| "3 3 0.03822973370552063 0.03823000701692012\n", | |
| "5 5 0.03649678826332092 0.03649239419617218\n", | |
| "11 11 0.034835249185562134 0.03483375855985143\n", | |
| "15 15 0.033251434564590454 0.033250510473037134\n", | |
| "25 26 0.03173968195915222 0.03173922345525575\n", | |
| "57 58 0.03029760718345642 0.03029662676485945\n", | |
| "101 103 0.028920933604240417 0.028919598320456204\n", | |
| "163 165 0.02760659158229828 0.02760515794407164\n", | |
| "237 241 0.02635185420513153 0.026350460911419748\n", | |
| "time: 30.266288\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "from time import time\n", | |
| "toc = time()\n", | |
| "solns = []\n", | |
| "problem.coefs[:] = 0\n", | |
| "for lagrange in lagrange_vals:\n", | |
| " penalty.lagrange = lagrange\n", | |
| " soln = problem.solve(tol=1.e-12)\n", | |
| " solns.append(soln.copy())\n", | |
| " print(np.sum(np.sum(soln**2, 1) > 0), np.sum(soln != 0), dual.seminorm(loss.smooth_objective(soln, 'grad'), lagrange=1.), lagrange)\n", | |
| "tic = time()\n", | |
| "print('time: %f' % (tic-toc))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "jupytext": { | |
| "cell_metadata_filter": "all,-slideshow", | |
| "formats": "ipynb,Rmd" | |
| }, | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.5.4" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment