Created
July 2, 2018 23:24
-
-
Save pjbull/8206464b19261fc1e55817246c1fa5ce to your computer and use it in GitHub Desktop.
VIF Calculations
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"\n", | |
"from sklearn.linear_model import LinearRegression\n", | |
"from statsmodels.stats.outliers_influence import variance_inflation_factor" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Set up data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 1. , -0.74693463, 10.58762484, 10.64695346,\n", | |
" 9.59141025],\n", | |
" [ 1. , -0.33770897, 11.00045577, 9.3593188 ,\n", | |
" 9.53698997],\n", | |
" [ 1. , -0.1262439 , 10.60092148, 10.44927083,\n", | |
" 10.29898203],\n", | |
" ...,\n", | |
" [ 1. , 0.95956534, 100.88819035, 100.815621 ,\n", | |
" 102.04864684],\n", | |
" [ 1. , -2.24937759, 99.32643405, 100.84714279,\n", | |
" 99.83283641],\n", | |
" [ 1. , -0.9830458 , 100.02128736, 99.28083099,\n", | |
" 98.71993343]])" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Fixed seed so that Restart & Run All reproduces the same data (and VIFs).\n", | |
"np.random.seed(0)\n", | |
"\n", | |
"size = 5000000    # number of rows (observations)\n", | |
"n_corr_col = 3    # number of mutually collinear columns\n", | |
"\n", | |
"# Columns: [intercept, independent noise, n_corr_col collinear columns].\n", | |
"# Each collinear column is the same linear trend (10 -> 100) plus its own\n", | |
"# standard-normal noise, so those columns are highly correlated.\n", | |
"data = np.array([\n", | |
"    np.ones(size),  # intercept\n", | |
"    np.random.randn(size),  # non-collinear column\n", | |
"    *[np.linspace(10, 100, size) + np.random.randn(size) for _ in range(n_corr_col)]  # collinear\n", | |
"])\n", | |
"\n", | |
"# np.array stacks the columns as rows; transpose to shape (size, 2 + n_corr_col).\n", | |
"data = data.T\n", | |
"data" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### `statsmodels` VIF\n", | |
"\n", | |
" - requires the intercept (constant) column to be present in `data`" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[5.47942299070207, 1.0000003863432638, 450.72491870300564, 450.9618899491534, 450.5721576459639]\n", | |
"CPU times: user 12.7 s, sys: 3.92 s, total: 16.6 s\n", | |
"Wall time: 4.42 s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"# One VIF per column of `data`, including the intercept column itself.\n", | |
"n_cols = data.shape[1]\n", | |
"print([variance_inflation_factor(data, col) for col in range(n_cols)])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### `scikit-learn` VIF\n", | |
"\n", | |
" - accepts a `has_intercept` parameter indicating whether `data` already contains an intercept column" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[1.0, 1.0000003863432654, 450.72491870300564, 450.9618899491534, 450.5721576459639]\n", | |
"CPU times: user 6.36 s, sys: 2.66 s, total: 9.03 s\n", | |
"Wall time: 2.3 s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"def skvif(data, i, has_intercept=True):\n", | |
"    \"\"\"Return the variance inflation factor of column `i` of the 2-D array `data`.\n", | |
"\n", | |
"    Column `i` is regressed on all other columns with scikit-learn and the\n", | |
"    VIF is 1 / (1 - R^2) of that fit.\n", | |
"\n", | |
"    has_intercept: True if `data` already contains a constant column, in\n", | |
"        which case LinearRegression is fit without adding its own intercept.\n", | |
"\n", | |
"    Returns inf when column `i` is perfectly explained by the others\n", | |
"    (R^2 == 1) instead of raising ZeroDivisionError.\n", | |
"    \"\"\"\n", | |
"    not_i = np.arange(data.shape[1]) != i\n", | |
"    X, y = data[:, not_i], data[:, i]\n", | |
"    r2 = LinearRegression(fit_intercept=not has_intercept).fit(X, y).score(X, y)\n", | |
"    if r2 >= 1.:\n", | |
"        # Perfect collinearity: the limit of 1 / (1 - R^2).\n", | |
"        return np.inf\n", | |
"    return 1. / (1. - r2)\n", | |
"\n", | |
"print([skvif(data, i) for i in range(data.shape[1])])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### `numpy` VIF\n", | |
"\n", | |
" - accepts `intercept_ix`, the index of the intercept column, and drops that column before computing VIFs (so no VIF is returned for it)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[ 1.00000039 450.7249187 450.96188995 450.57215765]\n", | |
"CPU times: user 393 ms, sys: 289 ms, total: 682 ms\n", | |
"Wall time: 176 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"def npvif(data, intercept_ix=None):\n", | |
"    \"\"\"Return the VIFs of the non-intercept columns of the 2-D array `data`.\n", | |
"\n", | |
"    The VIFs are the diagonal of the inverse of the columns' correlation\n", | |
"    matrix. Correlations are already centered, so no intercept column is\n", | |
"    needed. A constant column has zero variance and would make the\n", | |
"    correlation matrix NaN, so pass its index as `intercept_ix` to drop it\n", | |
"    first. Negative indices (e.g. -1 for the last column) are accepted.\n", | |
"\n", | |
"    The result has one entry per remaining column, in their original order.\n", | |
"    \"\"\"\n", | |
"    if intercept_ix is not None:\n", | |
"        # np.delete handles negative indices (e.g. -1), unlike an arange mask.\n", | |
"        data = np.delete(data, intercept_ix, axis=1)\n", | |
"\n", | |
"    # rowvar=False: each column (not each row) is a variable.\n", | |
"    return np.diag(np.linalg.inv(np.corrcoef(data, rowvar=False)))\n", | |
"\n", | |
"print(npvif(data, intercept_ix=0))" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment