Created May 27, 2020 12:33
Save glemaitre/4f56118e42018c1906fff7744a2d0fac to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@pytest.mark.parametrize("name, Tree", REG_TREES.items())
@pytest.mark.parametrize("criterion", REG_CRITERIONS)
def test_diabetes_overfit(name, Tree, criterion):
    # Check consistency of overfitted regression trees on the diabetes
    # dataset. Only `criterion` and `random_state` are set on the estimator
    # (no depth or feature limit), so the trees are expected to overfit the
    # training data and reach a training MSE of 0.
    # `name` is only used to identify the failing estimator in the message.
    reg = Tree(criterion=criterion, random_state=0)
    reg.fit(diabetes.data, diabetes.target)
    score = mean_squared_error(diabetes.target, reg.predict(diabetes.data))
    # Compare with pytest.approx rather than `== 0` so that negligible
    # floating-point residuals do not fail the test.
    assert score == pytest.approx(0), (
        f"Failed with {name}, criterion = {criterion} and score = {score}"
    )


@pytest.mark.parametrize("name, Tree", REG_TREES.items())
@pytest.mark.parametrize(
    "criterion, max_depth",
    [("mse", 15), ("mae", 20), ("friedman_mse", 15)],
)
def test_diabetes_underfit(name, Tree, criterion, max_depth):
    # Check consistency of trees when the depth and the number of features
    # are limited. Each criterion is paired with its own `max_depth`, and
    # `max_features` is capped at 6. With these limits the test expects the
    # tree NOT to fit the training data exactly (MSE strictly positive),
    # while still keeping the training MSE below 60.
    # NOTE(review): the 60 threshold is an empirical bound; its derivation
    # is not visible here.
    reg = Tree(
        criterion=criterion,
        max_depth=max_depth,
        max_features=6,
        random_state=0,
    )
    reg.fit(diabetes.data, diabetes.target)
    score = mean_squared_error(diabetes.target, reg.predict(diabetes.data))
    assert 0 < score < 60, (
        f"Failed with {name}, criterion = {criterion} and score = {score}"
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.