# Last active: October 9, 2019 21:43
# GitHub gist glemaitre/8fcc24bdfc7dc38ca0c09c56e26b9386 (can be saved to your
# computer and used in GitHub Desktop).
# GitHub notice: this file contains hidden or bidirectional Unicode text that may
# be interpreted or compiled differently than what appears below. To review it,
# open the file in an editor that reveals hidden Unicode characters
# ("Learn more about bidirectional Unicode characters").
# %%
# Setup: a synthetic binary classification problem, the CV splitter and an
# unpenalized logistic-regression pipeline shared by the cells below.
import re

import sklearn
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import cross_validate
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

RANDOM_SEED = 2

# "No regularization" is spelled penalty=None from scikit-learn 1.2 on; the
# string 'none' was deprecated in 1.2 and removed in 1.4, while releases older
# than 1.2 only accept 'none'. Use whichever the installed version understands.
_SKLEARN_VERSION = tuple(
    int(part)
    for part in re.match(r"(\d+)\.(\d+)", sklearn.__version__).groups()
)
NO_PENALTY = None if _SKLEARN_VERSION >= (1, 2) else 'none'

# 400 samples, 45 features: 10 informative, no redundant/repeated ones, so the
# remaining 35 are noise. With shuffle=False the informative features occupy
# the first columns (they appear at the left of the coefficient plots); the
# samples are not shuffled either, which the shuffling splitter below handles.
X, y = make_classification(
    n_samples=400,
    n_features=45,
    n_informative=10,
    n_redundant=0,
    n_repeated=0,
    n_classes=2,
    n_clusters_per_class=1,
    random_state=RANDOM_SEED,
    shuffle=False,
)

# A fixed integer random_state makes every pass over sss.split() yield the same
# 10 train/test partitions, so the lbfgs and saga runs see identical folds.
sss = StratifiedShuffleSplit(
    n_splits=10, test_size=0.2, random_state=RANDOM_SEED
)

# The solver is chosen per cell via set_params. C is deliberately not passed:
# it is ignored when there is no penalty, and a non-default value only makes
# scikit-learn warn that it "will ignore the C ... parameters" on every fit.
pipe = make_pipeline(
    StandardScaler(),
    LogisticRegression(
        fit_intercept=True,
        max_iter=100000,
        tol=1e-8,
        verbose=1,
        penalty=NO_PENALTY,
    ),
)
# %%
# Cross-validate the pipeline with the lbfgs solver. cross_validate fits a
# clone of `pipe` on each split, so the fitted models are kept through
# return_estimator=True. Note that set_params mutates the shared `pipe`.
# With n_jobs=-1 the fits run in worker processes, so the solver's verbose
# output may be interleaved or printed outside the notebook.
pipe.set_params(logisticregression__solver='lbfgs')
results_lbfgs = cross_validate(
    pipe,
    X,
    y,
    cv=sss,
    return_estimator=True,
    n_jobs=-1,
)
# %%
# Same cross-validation with the saga solver, on the same folds (sss has a
# fixed random_state). saga's convergence relies on features having similar
# scales, which the StandardScaler step of the pipeline provides.
# set_params mutates the shared `pipe`, switching its solver to saga.
pipe.set_params(logisticregression__solver='saga')
results_saga = cross_validate(
    pipe,
    X,
    y,
    cv=sss,
    return_estimator=True,
    n_jobs=-1,
)
# %%
# One figure per CV split: overlay the coefficients each solver found on the
# same fold and report how many iterations each one needed.
import matplotlib.pyplot as plt

for pipe_lbfgs, pipe_saga in zip(results_lbfgs['estimator'],
                                 results_saga['estimator']):
    # The last pipeline step is the fitted LogisticRegression.
    clf_lbfgs = pipe_lbfgs[-1]
    clf_saga = pipe_saga[-1]
    plt.figure()
    plt.plot(clf_lbfgs.coef_.ravel(), 'r.--', label='lbfgs')
    plt.plot(clf_saga.coef_.ravel(), 'bs-.', label='saga')
    # n_iter_ is an ndarray (shape (1,) for a binary problem); take its max so
    # the title shows a plain integer rather than e.g. "[123]".
    plt.title('#iterations: {} LBFGS - {} SAGA'.format(
        clf_lbfgs.n_iter_.max(),
        clf_saga.n_iter_.max(),
    ))
    plt.xlabel('feature index')
    plt.ylabel('coefficient')
    plt.legend()

# Required for the figures to appear when the file runs as a plain script;
# harmless in an interactive/notebook session.
plt.show()
# %%
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.