Created January 17, 2016 23:23
-
-
Save cwharland/66dcb10b7d7605a96413 to your computer and use it in GitHub Desktop.
Example of skflow's inability to recover weights?
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import pytest
import skflow
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler


class TestLinearRegression:
    """Compare how well skflow and scikit-learn recover a known linear model."""

    @classmethod
    def setup_class(cls):
        """Build a simple linear dataset with a tiny amount of noise.

        Draws ``n_samples`` rows of ``n_weights`` features uniform in
        [-1, 1] and a weight vector of ``10 * N(0, 1)`` entries. The target is
        ``y = X @ weights`` plus per-sample noise and an intercept near
        ``bias``. A fixed seed keeps the data reproducible across runs.
        """
        rng = np.random.RandomState(67)
        n_samples = 1000
        n_weights = 10
        cls.bias = 2
        cls.X = rng.uniform(-1, 1, (n_samples, n_weights))
        cls.weights = 10 * rng.randn(n_weights)
        cls.y = np.dot(cls.X, cls.weights)
        # Per-sample noise (sd 0.05) plus a single scalar draw around the bias
        # (sd 0.01): the effective intercept is close to, not exactly, 2.
        # RNG call order is kept fixed so the seeded data stays identical.
        cls.y += rng.randn(len(cls.X)) * 0.05 + rng.normal(cls.bias, 0.01)

    def test_skflow(self):
        """Check that skflow's linear regressor recovers the weights and bias."""
        print('Fitting skflow model...')
        # NOTE(review): this relies on skflow's default training settings; if
        # the optimiser has not converged, the fitted weights will be far from
        # the true ones. Confirm the defaults (steps, learning rate) suffice.
        regressor = skflow.TensorFlowLinearRegressor()
        regressor.fit(self.X, self.y)
        # skflow's weights_ reportedly comes back as a column (one row per
        # feature), so ravel it to compare against the 1-D true weights.
        fitted_weights = np.ravel(regressor.weights_)
        # assert_allclose(actual, desired): rtol is relative to ``desired``,
        # so the true weights go second.
        np.testing.assert_allclose(fitted_weights, self.weights, rtol=0.01)
        np.testing.assert_allclose(regressor.bias_, self.bias, rtol=0, atol=0.1)

    def test_sklearn(self):
        """Check that sklearn's LinearRegression recovers the weights and bias."""
        print('Fitting sklearn model...')
        regressor = LinearRegression()
        regressor.fit(self.X, self.y)
        fitted_weights = np.ravel(regressor.coef_)
        # Same argument order as the skflow test: fitted first, truth second.
        np.testing.assert_allclose(fitted_weights, self.weights, rtol=0.01)
        np.testing.assert_allclose(regressor.intercept_, self.bias,
                                   rtol=0, atol=0.1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
scikit-learn's LinearRegression passes just fine, but skflow fails with an assertion error.
The weights are way off (the first array is the true weights, the second is skflow's). Standard scaling doesn't seem to fix it, so I must be missing something.
Package details:
Python 3.5
tensorflow 0.6.0
skflow 0.0.1