From d7609b38a7d08bfc317a08bdf5912aeac08e184b Mon Sep 17 00:00:00 2001
From: Olivier Grisel <olivier.grisel@ensta.org>
Date: Fri, 10 Dec 2010 20:04:51 +0100
Subject: [PATCH] consistently rename n_comp to n_components

---
 .../applications/plot_face_recognition.py     |  4 +-
 examples/cluster/kmeans_digits.py             |  2 +-
 examples/plot_pca.py                          |  2 +-
 scikits/learn/fastica.py                      | 48 +++++++++----------
 scikits/learn/pca.py                          |  6 +--
 scikits/learn/tests/test_fastica.py           |  2 +-
 scikits/learn/utils/_csgraph.py               |  6 +--
 7 files changed, 35 insertions(+), 35 deletions(-)

diff --git a/examples/applications/plot_face_recognition.py b/examples/applications/plot_face_recognition.py
index b80cc7fd5e..2d17a9f719 100644
--- a/examples/applications/plot_face_recognition.py
+++ b/examples/applications/plot_face_recognition.py
@@ -34,7 +34,7 @@ import pylab as pl
 from scikits.learn.grid_search import GridSearchCV
 from scikits.learn.metrics import classification_report
 from scikits.learn.metrics import confusion_matrix
-from scikits.learn.pca import PCA
+from scikits.learn.pca import RandomizedPCA
 from scikits.learn.svm import SVC
 
 ################################################################################
@@ -115,7 +115,7 @@ y_train, y_test = y[:split], y[split:]
 n_components = 150
 
 print "Extracting the top %d eigenfaces" % n_components
-pca = PCA(n_comp=n_components, whiten=True, do_fast_svd=True).fit(X_train)
+pca = RandomizedPCA(n_components=n_components, whiten=True).fit(X_train)
 
 eigenfaces = pca.components_.T.reshape((n_components, 64, 64))
 
diff --git a/examples/cluster/kmeans_digits.py b/examples/cluster/kmeans_digits.py
index ba63405f3b..c902413f7f 100644
--- a/examples/cluster/kmeans_digits.py
+++ b/examples/cluster/kmeans_digits.py
@@ -51,7 +51,7 @@ print "Raw k-means with PCA-based centroid init..."
 # in this case the seeding of the centers is deterministic, hence we run the
 # kmeans algorithm only once with n_init=1
 t0 = time()
-pca = PCA(n_comp=n_digits).fit(data)
+pca = PCA(n_components=n_digits).fit(data)
 km = KMeans(init=pca.components_.T, k=n_digits, n_init=1).fit(data)
 print "done in %0.3fs" % (time() - t0)
 print "inertia: %f" % km.inertia_
diff --git a/examples/plot_pca.py b/examples/plot_pca.py
index 5832c0bd7d..7e40b2c736 100644
--- a/examples/plot_pca.py
+++ b/examples/plot_pca.py
@@ -25,7 +25,7 @@ X = iris.data
 y = iris.target
 target_names = iris.target_names
 
-pca = PCA(n_comp=2)
+pca = PCA(n_components=2)
 X_r = pca.fit(X).transform(X)
 
 # Percentage of variance explained for each components
diff --git a/scikits/learn/fastica.py b/scikits/learn/fastica.py
index 6deee471a9..9a0638e50b 100644
--- a/scikits/learn/fastica.py
+++ b/scikits/learn/fastica.py
@@ -55,11 +55,11 @@ def _ica_def(X, tol, g, gprime, fun_args, maxit, w_init):
     Used internally by FastICA.
     """
 
-    n_comp = w_init.shape[0]
-    W = np.zeros((n_comp, n_comp), dtype=float)
+    n_components = w_init.shape[0]
+    W = np.zeros((n_components, n_components), dtype=float)
 
     # j is the index of the extracted component
-    for j in range(n_comp):
+    for j in range(n_components):
         w = w_init[j, :].copy()
         w /= np.sqrt((w**2).sum())
 
@@ -114,7 +114,7 @@ def _ica_par(X, tol, g, gprime, fun_args, maxit, w_init):
     return W
 
 
-def fastica(X, n_comp=None, algorithm="parallel", whiten=True,
+def fastica(X, n_components=None, algorithm="parallel", whiten=True,
             fun="logcosh", fun_prime='', fun_args={}, maxit=200,
             tol=1e-04, w_init=None):
     """Perform Fast Independent Component Analysis.
@@ -124,7 +124,7 @@ def fastica(X, n_comp=None, algorithm="parallel", whiten=True,
     X : (n, p) array of shape = [n_samples, n_features]
         Training vector, where n_samples is the number of samples and
         n_features is the number of features.
-    n_comp : int, optional
+    n_components : int, optional
         Number of components to extract. If None no dimension reduction
         is performed.
     algorithm : {'parallel','deflation'}
@@ -151,22 +151,22 @@ def fastica(X, n_comp=None, algorithm="parallel", whiten=True,
     tol : float
           A positive scalar giving the tolerance at which the
           un-mixing matrix is considered to have converged
-    w_init : (n_comp,n_comp) array
+    w_init : (n_components,n_components) array
              Initial un-mixing array of dimension (n.comp,n.comp).
              If None (default) then an array of normal r.v.'s is used
     source_only: if True, only the sources matrix is returned
 
     Results
     -------
-    K : (n_comp, p) array
+    K : (n_components, p) array
         pre-whitening matrix that projects data onto th first n.comp
         principal components. Returned only if whiten is True
-    W : (n_comp, n_comp) array
+    W : (n_components, n_components) array
         estimated un-mixing matrix
         The mixing matrix can be obtained by::
             w = np.dot(W, K.T)
             A = w.T * (w * w.T).I
-    S : (n_comp, n) array
+    S : (n_components, n) array
         estimated source matrix
 
 
@@ -227,11 +227,11 @@ def fastica(X, n_comp=None, algorithm="parallel", whiten=True,
 
     n, p = X.shape
 
-    if n_comp is None:
-        n_comp = min(n, p)
-    if (n_comp > min(n, p)):
-        n_comp = min(n, p)
-        print("n_comp is too large: it will be set to %s" % n_comp)
+    if n_components is None:
+        n_components = min(n, p)
+    if (n_components > min(n, p)):
+        n_components = min(n, p)
+        print("n_components is too large: it will be set to %s" % n_components)
 
     if whiten:
         # Centering the columns (ie the variables)
@@ -241,7 +241,7 @@ def fastica(X, n_comp=None, algorithm="parallel", whiten=True,
         u, d, _ = linalg.svd(X, full_matrices=False)
 
         del _
-        K = (u/d).T[:n_comp]  # see (6.33) p.140
+        K = (u/d).T[:n_components]  # see (6.33) p.140
         del u, d
         X1 = np.dot(K, X)
         # see (13.6) p.267 Here X1 is white and data
@@ -251,12 +251,12 @@ def fastica(X, n_comp=None, algorithm="parallel", whiten=True,
     X1 *= np.sqrt(p)
 
     if w_init is None:
-        w_init = np.random.normal(size=(n_comp, n_comp))
+        w_init = np.random.normal(size=(n_components, n_components))
     else:
         w_init = np.asarray(w_init)
-        if w_init.shape != (n_comp, n_comp):
+        if w_init.shape != (n_components, n_components):
             raise ValueError("w_init has invalid shape -- should be %(shape)s"
-                             % {'shape': (n_comp, n_comp)})
+                             % {'shape': (n_components, n_components)})
 
     kwargs = {'tol': tol,
               'g': g,
@@ -283,7 +283,7 @@ class FastICA(BaseEstimator):
 
     Parameters
     ----------
-    n_comp : int, optional
+    n_components : int, optional
         Number of components to use. If none is passed, all are used.
     algorithm: {'parallel', 'deflation'}
         Apply parallel or deflational algorithm for FastICA
@@ -300,12 +300,12 @@ class FastICA(BaseEstimator):
         Maximum number of iterations during fit
     tol : float, optional
         Tolerance on update at each iteration
-    w_init: None of an (n_comp, n_comp) ndarray
+    w_init: None of an (n_components, n_components) ndarray
         The mixing matrix to be used to initialize the algorithm.
 
     Attributes
     ----------
-    unmixing_matrix_ : 2D array, [n_comp, n_samples]
+    unmixing_matrix_ : 2D array, [n_components, n_samples]
 
     Methods
     -------
@@ -322,11 +322,11 @@ class FastICA(BaseEstimator):
 
     """
 
-    def __init__(self, n_comp=None, algorithm='parallel', whiten=True,
+    def __init__(self, n_components=None, algorithm='parallel', whiten=True,
                 fun='logcosh', fun_prime='', fun_args={}, maxit=200, tol=1e-4,
                 w_init=None):
         super(FastICA, self).__init__()
-        self.n_comp = n_comp
+        self.n_components = n_components
         self.algorithm = algorithm
         self.whiten = whiten
         self.fun = fun
@@ -338,7 +338,7 @@ class FastICA(BaseEstimator):
 
     def fit(self, X, **params):
         self._set_params(**params)
-        whitening_, unmixing_, sources_ = fastica(X, self.n_comp,
+        whitening_, unmixing_, sources_ = fastica(X, self.n_components,
                         self.algorithm, self.whiten,
                         self.fun, self.fun_prime, self.fun_args, self.maxit,
                         self.tol, self.w_init)
diff --git a/scikits/learn/pca.py b/scikits/learn/pca.py
index 3e475079a8..ce2e623d9a 100644
--- a/scikits/learn/pca.py
+++ b/scikits/learn/pca.py
@@ -129,17 +129,17 @@ class PCA(BaseEstimator):
 
     Attributes
     ----------
-    components_: array, [n_features, n_comp]
+    components_: array, [n_features, n_components]
         Components with maximum variance.
 
-    explained_variance_ratio_: array, [n_comp]
+    explained_variance_ratio_: array, [n_components]
         Percentage of variance explained by each of the selected components.
         k is not set then all components are stored and the sum of
         explained variances is equal to 1.0
 
     Notes
     -----
-    For n_comp='mle', this class uses the method of Thomas P. Minka:
+    For n_components='mle', this class uses the method of Thomas P. Minka:
     Automatic Choice of Dimensionality for PCA. NIPS 2000: 598-604
 
     Examples
diff --git a/scikits/learn/tests/test_fastica.py b/scikits/learn/tests/test_fastica.py
index a3fad0acdb..6bf7871f7e 100644
--- a/scikits/learn/tests/test_fastica.py
+++ b/scikits/learn/tests/test_fastica.py
@@ -121,7 +121,7 @@ def test_non_square_fastica(add_noise=False):
 
     center_and_norm(m)
 
-    k_, mixing_, s_ = fastica.fastica(m, n_comp=2)
+    k_, mixing_, s_ = fastica.fastica(m, n_components=2)
 
     # Check that the mixing model described in the docstring holds:
     np.testing.assert_almost_equal(s_, np.dot(np.dot(mixing_, k_), m))
diff --git a/scikits/learn/utils/_csgraph.py b/scikits/learn/utils/_csgraph.py
index e395c6b1d8..c577fc35a3 100644
--- a/scikits/learn/utils/_csgraph.py
+++ b/scikits/learn/utils/_csgraph.py
@@ -32,7 +32,7 @@ def cs_graph_components(x):
 
     Returns
     --------
-    n_comp: int
+    n_components: int
         The number of connected components.
     label: ndarray (ints, 1 dimension):
         The label array of each connected component (-2 is used to
@@ -74,8 +74,8 @@ def cs_graph_components(x):
 
     label = np.empty((shape[0],), dtype=x.indptr.dtype)
 
-    n_comp = _cs_graph_components(shape[0], x.indptr, x.indices, label)
+    n_components = _cs_graph_components(shape[0], x.indptr, x.indices, label)
 
-    return n_comp, label
+    return n_components, label
 
 
-- 
GitLab