Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GridSearchColorPlot tests + minor gridsearch change, part of #308 #1097

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
131 changes: 131 additions & 0 deletions tests/test_gridsearch/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# # tests.test_gridsearch.test_base.py
# # Test the GridSearchColorPlot (standard and quick visualizers).
# #
# # Author: Tan Tran
# # Created: Sat Aug 29 12:00:00 2020 -0400
# #
# # Copyright (C) 2020 The scikit-yb developers
# # For license information, see LICENSE.txt
# #

"""
Test the GridSearchColorPlot visualizer.
"""

# ##########################################################################
# ## Imports
# ##########################################################################

import pytest

from tests.base import VisualTestCase
from tests.fixtures import Dataset

from yellowbrick.datasets import load_occupancy
from yellowbrick.gridsearch import GridSearchColorPlot, gridsearch_color_plot

from sklearn.datasets import make_classification
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV

import pandas as pd

# ##########################################################################
# ## Test fixtures
# ##########################################################################

@pytest.fixture(scope="class")
def binary(request):
"""
Creates a random binary classification dataset fixture
"""
X, y = make_classification(
n_samples=1000,
n_features=4,
n_informative=2,
n_redundant=2,
n_classes=2,
n_clusters_per_class=2,
random_state=1234,
)

request.cls.binary = Dataset(X, y)

@pytest.fixture(scope="class")
def gridsearchcv(request):
"""
Creates an sklearn SVC, a GridSearchCV for testing through the SVC's kernel,
gamma, and C parameters, and returns the GridSearchCV.
"""

svc = SVC()
grid = [{'kernel': ['rbf'], 'gamma': [1e-3, 1e-4], 'C': [0.01, 0.1, 1, 10]},
{'kernel': ['linear'], 'C': [0.01, 0.1, 1, 10]}]
gridsearchcv = GridSearchCV(svc, grid, n_jobs=4)

request.cls.gridsearchcv = gridsearchcv

@pytest.mark.usefixtures("binary", "gridsearchcv")
class TestGridSearchColorPlot(VisualTestCase):
"""
Tests of basic GridSearchColorPlot functionality
"""

# ##########################################################################
# ## GridSearchColorPlot Base Test Cases
# ##########################################################################

def test_gridsearchcolorplot(self):
"""
Test GridSearchColorPlot drawing
"""

gs_viz = GridSearchColorPlot(self.gridsearchcv, 'C', 'kernel')
gs_viz.fit(self.binary.X, self.binary.y)
self.assert_images_similar(gs_viz, tol=0.95)

def test_quick_method(self):
"""
Test gridsearch_color_plot quick method
"""

gs = self.gridsearchcv

# If no X data is passed to quick method, model is assumed to be fit
# already
gs.fit(self.binary.X, self.binary.y)

gs_viz = gridsearch_color_plot(gs, 'gamma', 'C')
assert isinstance(gs_viz, GridSearchColorPlot)
self.assert_images_similar(gs_viz, tol=0.95)

# ##########################################################################
# ## Integration Tests
# ##########################################################################

@pytest.mark.skipif(pd is None, reason="test requires pandas")
def test_pandas_integration(self):
"""
Test GridSearchColorPlot on sklearn occupancy data set (as pandas df)
"""

X, y = load_occupancy(return_dataset=True).to_pandas()
X, y = X.head(1000), y.head(1000)

gs_viz = GridSearchColorPlot(self.gridsearchcv, 'C', 'kernel')
gs_viz.fit(X, y)

self.assert_images_similar(gs_viz, tol=0.95)

def test_numpy_integration(self):
"""
Test GridSearchColorPlot on sklearn occupancy data set (as numpy df)
"""

X, y = load_occupancy(return_dataset=True).to_numpy()
X, y = X[:1000], y[:1000]

gs_viz = GridSearchColorPlot(self.gridsearchcv, 'C', 'kernel')
gs_viz.fit(X, y)

self.assert_images_similar(gs_viz, tol=0.95)
8 changes: 4 additions & 4 deletions yellowbrick/gridsearch/pcolor.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def gridsearch_color_plot(model, x_param, y_param, X=None, y=None, ax=None, **kw

Returns
-------
ax : matplotlib axes
Returns the axes that the classification report was drawn on.
visualizer : GridSearchColorPlot
Returns visualizer
"""
# Instantiate the visualizer
visualizer = GridSearchColorPlot(model, x_param, y_param, ax=ax, **kwargs)
Expand All @@ -80,8 +80,8 @@ def gridsearch_color_plot(model, x_param, y_param, X=None, y=None, ax=None, **kw
else:
visualizer.draw()

# Return the axes object on the visualizer
return visualizer.ax
# Return the visualizer
return visualizer


class GridSearchColorPlot(GridSearchVisualizer):
Expand Down