From 10322957ab6d4b86397b2bbae3ed869c22ef366e Mon Sep 17 00:00:00 2001 From: Mathieu Blondel <mathieu@mblondel.org> Date: Mon, 29 Aug 2011 22:51:55 +0900 Subject: [PATCH] Fix doctest errors (hopefully!). --- doc/modules/multiclass.rst | 15 ++++++++------- scikits/learn/multiclass.py | 17 ++++++++++------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/doc/modules/multiclass.rst b/doc/modules/multiclass.rst index 07f33d634c..05a5c18813 100644 --- a/doc/modules/multiclass.rst +++ b/doc/modules/multiclass.rst @@ -47,12 +47,12 @@ fair default choice. Below is an example:: >>> X, y = iris.data, iris.target >>> OneVsRestClassifier(LinearSVC()).fit(X, y).predict(X) array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]) + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]) One-Vs-One @@ -125,7 +125,7 @@ Example:: >>> from scikits.learn.svm import LinearSVC >>> iris = datasets.load_iris() >>> X, y = iris.data, iris.target - >>> OutputCodeClassifier(LinearSVC(), code_size=2).fit(X, y).predict(X) + >>> OutputCodeClassifier(LinearSVC(), code_size=2, random_state=0).fit(X, y).predict(X) array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, @@ -134,6 +134,7 @@ Example:: 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]) + .. topic:: References: * [1] "Solving multiclass learning problems via error-correcting ouput codes", diff --git a/scikits/learn/multiclass.py b/scikits/learn/multiclass.py index 8918c1a297..977cc9b02c 100644 --- a/scikits/learn/multiclass.py +++ b/scikits/learn/multiclass.py @@ -22,9 +22,10 @@ improves. import numpy as np -from scikits.learn.base import BaseEstimator, ClassifierMixin, clone -from scikits.learn.preprocessing import LabelBinarizer -from scikits.learn.metrics.pairwise import euclidean_distances +from .base import BaseEstimator, ClassifierMixin, clone +from .preprocessing import LabelBinarizer +from .metrics.pairwise import euclidean_distances +from .utils import check_random_state def fit_binary(estimator, X, y): @@ -240,9 +241,10 @@ class OneVsOneClassifier(BaseEstimator, ClassifierMixin): return predict_ovo(self.estimators_, self.classes_, X) -def fit_ecoc(estimator, X, y, code_size): +def fit_ecoc(estimator, X, y, code_size=1.5, random_state=None): """Fit an error-correcting output-code strategy.""" check_estimator(estimator) + random_state = check_random_state(random_state) classes = np.unique(y) n_classes = classes.shape[0] @@ -250,7 +252,7 @@ def fit_ecoc(estimator, X, y, code_size): # FIXME: there are more elaborate methods than generating the codebook # randomly. - code_book = np.random.random((n_classes, code_size)) + code_book = random_state.random_sample((n_classes, code_size)) code_book[code_book > 0.5] = 1 if hasattr(estimator, "decision_function"): @@ -328,12 +330,13 @@ class OutputCodeClassifier(BaseEstimator, ClassifierMixin): 2008. """ - def __init__(self, estimator, code_size=1.5): + def __init__(self, estimator, code_size=1.5, random_state=None): if (code_size <= 0): raise ValueError("code_size should be greater than 0!") self.estimator = estimator self.code_size = code_size + self.random_state = random_state def fit(self, X, y): """Fit underlying estimators. @@ -351,7 +354,7 @@ class OutputCodeClassifier(BaseEstimator, ClassifierMixin): self """ self.estimators_, self.classes_, self.code_book_ = \ - fit_ecoc(self.estimator, X, y, self.code_size) + fit_ecoc(self.estimator, X, y, self.code_size, self.random_state) return self def predict(self, X): -- GitLab