diff --git a/doc/modules/manifold.rst b/doc/modules/manifold.rst
index a0a65cd82729e0be051300f35786fd361ddf2565..0a3d18b31b9f06047113fe960e84d9a890a3cea7 100644
--- a/doc/modules/manifold.rst
+++ b/doc/modules/manifold.rst
@@ -212,7 +212,7 @@ vectors in each neighborhood.  This is the essence of *modified locally
 linear embedding* (MLLE).  MLLE can be  performed with function
 :func:`locally_linear_embedding` or its object-oriented counterpart
 :class:`LocallyLinearEmbedding`, with the keyword ``method = 'modified'``.
-It requires ``n_neighbors > out_dim``.
+It requires ``n_neighbors > n_components``.
 
 .. figure:: ../auto_examples/manifold/images/plot_lle_digits_7.png
    :target: ../auto_examples/manifold/plot_lle_digits.html
@@ -262,7 +262,7 @@ improvements which make its cost comparable to that of other LLE variants
 for small output dimension.  HLLE can be  performed with function
 :func:`locally_linear_embedding` or its object-oriented counterpart
 :class:`LocallyLinearEmbedding`, with the keyword ``method = 'hessian'``.
-It requires ``n_neighbors > out_dim * (out_dim + 3) / 2``.
+It requires ``n_neighbors > n_components * (n_components + 3) / 2``.
 
 .. figure:: ../auto_examples/manifold/images/plot_lle_digits_8.png
    :target: ../auto_examples/manifold/plot_lle_digits.html
@@ -355,7 +355,7 @@ Tips on practical use
 * The reconstruction error computed by each routine can be used to choose
   the optimal output dimension.  For a :math:`d`-dimensional manifold embedded
   in a :math:`D`-dimensional parameter space, the reconstruction error will
-  decrease as ``out_dim`` is increased until ``out_dim == d``.
+  decrease as ``n_components`` is increased until ``n_components == d``.
 
 * Note that noisy data can "short-circuit" the manifold, in essence acting
   as a bridge between parts of the manifold that would otherwise be
diff --git a/examples/applications/plot_stock_market.py b/examples/applications/plot_stock_market.py
index d59cfab23b7c5d13dc5a529681ffd1ecb77cfe3a..d68ea340334b73882c6870ff68d7964146e0753b 100644
--- a/examples/applications/plot_stock_market.py
+++ b/examples/applications/plot_stock_market.py
@@ -185,7 +185,7 @@ for i in range(n_labels + 1):
 # initiated with random vectors that we don't control). In addition, we
 # use a large number of neighbors to capture the large-scale structure.
 node_position_model = manifold.LocallyLinearEmbedding(
-    out_dim=2, eigen_solver='dense', n_neighbors=6)
+    n_components=2, eigen_solver='dense', n_neighbors=6)
 
 embedding = node_position_model.fit_transform(X.T).T
 
diff --git a/examples/manifold/plot_compare_methods.py b/examples/manifold/plot_compare_methods.py
index 1f9a8ddea0784446f45cc493e0e6f4d852e41a0f..78e756553771bfb77d89f306d27afff7f7252816 100644
--- a/examples/manifold/plot_compare_methods.py
+++ b/examples/manifold/plot_compare_methods.py
@@ -28,7 +28,7 @@ Axes3D
 n_points = 1000
 X, color = datasets.samples_generator.make_s_curve(n_points)
 n_neighbors = 10
-out_dim = 2
+n_components = 2
 
 fig = pl.figure(figsize=(12, 8))
 pl.suptitle("Manifold Learning with %i points, %i neighbors"
@@ -48,7 +48,7 @@ labels = ['LLE', 'LTSA', 'Hessian LLE', 'Modified LLE']
 
 for i, method in enumerate(methods):
     t0 = time()
-    Y = manifold.LocallyLinearEmbedding(n_neighbors, out_dim,
+    Y = manifold.LocallyLinearEmbedding(n_neighbors, n_components,
                                         eigen_solver='auto',
                                         method=method).fit_transform(X)
     t1 = time()
@@ -62,7 +62,7 @@ for i, method in enumerate(methods):
     pl.axis('tight')
 
 t0 = time()
-Y = manifold.Isomap(n_neighbors, out_dim).fit_transform(X)
+Y = manifold.Isomap(n_neighbors, n_components).fit_transform(X)
 t1 = time()
 print "Isomap: %.2g sec" % (t1 - t0)
 ax = fig.add_subplot(236)
diff --git a/examples/manifold/plot_lle_digits.py b/examples/manifold/plot_lle_digits.py
index 2a0ec53015b6a3faf9d11e54708f107246574d95..b290ace26609222b84419c443fb3d7a867c4635c 100644
--- a/examples/manifold/plot_lle_digits.py
+++ b/examples/manifold/plot_lle_digits.py
@@ -109,7 +109,7 @@ plot_embedding(X_lda,
 # Isomap projection of the digits dataset
 print "Computing Isomap embedding"
 t0 = time()
-X_iso = manifold.Isomap(n_neighbors, out_dim=2).fit_transform(X)
+X_iso = manifold.Isomap(n_neighbors, n_components=2).fit_transform(X)
 print "Done."
 plot_embedding(X_iso,
     "Isomap projection of the digits (time %.2fs)" %
@@ -119,7 +119,7 @@ plot_embedding(X_iso,
 #----------------------------------------------------------------------
 # Locally linear embedding of the digits dataset
 print "Computing LLE embedding"
-clf = manifold.LocallyLinearEmbedding(n_neighbors, out_dim=2,
+clf = manifold.LocallyLinearEmbedding(n_neighbors, n_components=2,
                                       method='standard')
 t0 = time()
 X_lle = clf.fit_transform(X)
@@ -132,7 +132,7 @@ plot_embedding(X_lle,
 #----------------------------------------------------------------------
 # Modified Locally linear embedding of the digits dataset
 print "Computing modified LLE embedding"
-clf = manifold.LocallyLinearEmbedding(n_neighbors, out_dim=2,
+clf = manifold.LocallyLinearEmbedding(n_neighbors, n_components=2,
                                       method='modified')
 t0 = time()
 X_mlle = clf.fit_transform(X)
@@ -145,7 +145,7 @@ plot_embedding(X_mlle,
 #----------------------------------------------------------------------
 # HLLE embedding of the digits dataset
 print "Computing Hessian LLE embedding"
-clf = manifold.LocallyLinearEmbedding(n_neighbors, out_dim=2,
+clf = manifold.LocallyLinearEmbedding(n_neighbors, n_components=2,
                                       method='hessian')
 t0 = time()
 X_hlle = clf.fit_transform(X)
@@ -158,7 +158,7 @@ plot_embedding(X_hlle,
 #----------------------------------------------------------------------
 # LTSA embedding of the digits dataset
 print "Computing LTSA embedding"
-clf = manifold.LocallyLinearEmbedding(n_neighbors, out_dim=2,
+clf = manifold.LocallyLinearEmbedding(n_neighbors, n_components=2,
                                       method='ltsa')
 t0 = time()
 X_ltsa = clf.fit_transform(X)
diff --git a/sklearn/manifold/isomap.py b/sklearn/manifold/isomap.py
index 6873d1c2ad3fe79be5281bfa84fff53fc25b495d..dff954215f4a4d4925166690e09e22dfc17c3bd7 100644
--- a/sklearn/manifold/isomap.py
+++ b/sklearn/manifold/isomap.py
@@ -4,6 +4,7 @@
 # License: BSD, (C) 2011
 
 import numpy as np
+import warnings
 from ..base import BaseEstimator
 from ..neighbors import NearestNeighbors, kneighbors_graph
 from ..utils.graph import graph_shortest_path
@@ -21,7 +22,7 @@ class Isomap(BaseEstimator):
     n_neighbors : integer
         number of neighbors to consider for each point.
 
-    out_dim : integer
+    n_components : integer
         number of coordinates for the manifold
 
     eigen_solver : ['auto'|'arpack'|'dense']
@@ -53,7 +54,7 @@ class Isomap(BaseEstimator):
 
     Attributes
     ----------
-    `embedding_` : array-like, shape (n_samples, out_dim)
+    `embedding_` : array-like, shape (n_samples, n_components)
         Stores the embedding vectors
 
     `kernel_pca_` : `KernelPCA` object used to implement the embedding
@@ -75,12 +76,16 @@ class Isomap(BaseEstimator):
         framework for nonlinear dimensionality reduction. Science 290 (5500)
     """
 
-    def __init__(self, n_neighbors=5, out_dim=2,
-                 eigen_solver='auto', tol=0,
-                 max_iter=None, path_method='auto',
-                 neighbors_algorithm='auto'):
+    def __init__(self, n_neighbors=5, n_components=2, eigen_solver='auto',
+            tol=0, max_iter=None, path_method='auto',
+            neighbors_algorithm='auto', out_dim=None):
+
+        if not out_dim is None:
+            warnings.warn("Parameter ``out_dim`` was renamed to "
+                "``n_components`` and is now deprecated.", DeprecationWarning)
+            n_components = n_components
         self.n_neighbors = n_neighbors
-        self.out_dim = out_dim
+        self.n_components = n_components
         self.eigen_solver = eigen_solver
         self.tol = tol
         self.max_iter = max_iter
@@ -92,7 +97,7 @@ class Isomap(BaseEstimator):
     def _fit_transform(self, X):
         self.nbrs_.fit(X)
         self.training_data_ = self.nbrs_._fit_X
-        self.kernel_pca_ = KernelPCA(n_components=self.out_dim,
+        self.kernel_pca_ = KernelPCA(n_components=self.n_components,
                                      kernel="precomputed",
                                      eigen_solver=self.eigen_solver,
                                      tol=self.tol, max_iter=self.max_iter)
@@ -160,7 +165,7 @@ class Isomap(BaseEstimator):
 
         Returns
         -------
-        X_new: array-like, shape (n_samples, out_dim)
+        X_new: array-like, shape (n_samples, n_components)
         """
         self._fit_transform(X)
         return self.embedding_
@@ -182,7 +187,7 @@ class Isomap(BaseEstimator):
 
         Returns
         -------
-        X_new: array-like, shape (n_samples, out_dim)
+        X_new: array-like, shape (n_samples, n_components)
         """
         distances, indices = self.nbrs_.kneighbors(X, return_distance=True)
 
diff --git a/sklearn/manifold/locally_linear.py b/sklearn/manifold/locally_linear.py
index 28c51005504a4be91967cf4dd5bd096ac1f2399e..b9529670f437c7a4b3c7e19174f466218a02315d 100644
--- a/sklearn/manifold/locally_linear.py
+++ b/sklearn/manifold/locally_linear.py
@@ -5,6 +5,7 @@
 # License: BSD, (C) INRIA 2011
 
 import numpy as np
+import warnings
 from scipy.linalg import eigh, svd, qr, solve
 from scipy.sparse import eye, csr_matrix
 from ..base import BaseEstimator
@@ -177,10 +178,10 @@ def null_space(M, k, k_skip=1, eigen_solver='arpack', tol=1E-6, max_iter=100,
 
 
 def locally_linear_embedding(
-    X, n_neighbors, out_dim, reg=1e-3, eigen_solver='auto',
+    X, n_neighbors, n_components, reg=1e-3, eigen_solver='auto',
     tol=1e-6, max_iter=100, method='standard',
     hessian_tol=1E-4, modified_tol=1E-12,
-    random_state=None):
+    random_state=None, out_dim=None):
     """Perform a Locally Linear Embedding analysis on the data.
 
     Parameters
@@ -193,7 +194,7 @@ def locally_linear_embedding(
     n_neighbors : integer
         number of neighbors to consider for each point.
 
-    out_dim : integer
+    n_components : integer
         number of coordinates for the manifold.
 
     reg : float
@@ -223,7 +224,7 @@ def locally_linear_embedding(
         standard : use the standard locally linear embedding algorithm.
                    see reference [1]_
         hessian  : use the Hessian eigenmap method.  This method requires
-                   n_neighbors > out_dim * (1 + (out_dim + 1) / 2.
+                   n_neighbors > n_components * (1 + (n_components + 1) / 2.
                    see reference [2]_
         modified : use the modified locally linear embedding algorithm.
                    see reference [3]_
@@ -243,7 +244,7 @@ def locally_linear_embedding(
 
     Returns
     -------
-    Y : array-like, shape [n_samples, out_dim]
+    Y : array-like, shape [n_samples, n_components]
         Embedding vectors.
 
     squared_error : float
@@ -271,13 +272,18 @@ def locally_linear_embedding(
     if method not in ('standard', 'hessian', 'modified', 'ltsa'):
         raise ValueError("unrecognized method '%s'" % method)
 
+    if not out_dim is None:
+        warnings.warn("Parameter ``out_dim`` was renamed to ``n_components`` "
+                "and is now deprecated.", DeprecationWarning)
+        n_components = n_components
+
     nbrs = NearestNeighbors(n_neighbors=n_neighbors + 1)
     nbrs.fit(X)
     X = nbrs._fit_X
 
     N, d_in = X.shape
 
-    if out_dim > d_in:
+    if n_components > d_in:
         raise ValueError("output dimension must be less than or equal "
                          "to input dimension")
     if n_neighbors >= N:
@@ -302,17 +308,17 @@ def locally_linear_embedding(
             M.flat[::M.shape[0] + 1] += 1  # W = W - I = W - I
 
     elif method == 'hessian':
-        dp = out_dim * (out_dim + 1) / 2
+        dp = n_components * (n_components + 1) / 2
 
-        if n_neighbors <= out_dim + dp:
+        if n_neighbors <= n_components + dp:
             raise ValueError("for method='hessian', n_neighbors must be "
-                             "greater than [out_dim * (out_dim + 3) / 2]")
+                    "greater than [n_components * (n_components + 3) / 2]")
 
         neighbors = nbrs.kneighbors(X, n_neighbors=n_neighbors + 1,
                                     return_distance=False)
         neighbors = neighbors[:, 1:]
 
-        Yi = np.empty((n_neighbors, 1 + out_dim + dp), dtype=np.float)
+        Yi = np.empty((n_neighbors, 1 + n_components + dp), dtype=np.float)
         Yi[:, 0] = 1
 
         M = np.zeros((N, N), dtype=np.float)
@@ -330,16 +336,17 @@ def locally_linear_embedding(
                 Ci = np.dot(Gi, Gi.T)
                 U = eigh(Ci)[1][:, ::-1]
 
-            Yi[:, 1:1 + out_dim] = U[:, :out_dim]
+            Yi[:, 1:1 + n_components] = U[:, :n_components]
 
-            j = 1 + out_dim
-            for k in range(out_dim):
-                Yi[:, j:j + out_dim - k] = U[:, k:k + 1] * U[:, k:out_dim]
-                j += out_dim - k
+            j = 1 + n_components
+            for k in range(n_components):
+                Yi[:, j:j + n_components - k] = \
+                        U[:, k:k + 1] * U[:, k:n_components]
+                j += n_components - k
 
             Q, R = qr(Yi)
 
-            w = Q[:, out_dim + 1:]
+            w = Q[:, n_components + 1:]
             S = w.sum(0)
 
             S[np.where(abs(S) < hessian_tol)] = 1
@@ -352,8 +359,9 @@ def locally_linear_embedding(
             M = csr_matrix(M)
 
     elif method == 'modified':
-        if n_neighbors < out_dim:
-            raise ValueError("modified LLE requires n_neighbors >= out_dim")
+        if n_neighbors < n_components:
+            raise ValueError("modified LLE requires "
+                "n_neighbors >= n_components")
 
         neighbors = nbrs.kneighbors(X, n_neighbors=n_neighbors + 1,
                                     return_distance=False)
@@ -399,7 +407,7 @@ def locally_linear_embedding(
 
         #calculate eta: the median of the ratio of small to large eigenvalues
         # across the points.  This is used to determine s_i, below
-        rho = evals[:, out_dim:].sum(1) / evals[:, :out_dim].sum(1)
+        rho = evals[:, n_components:].sum(1) / evals[:, :n_components].sum(1)
         eta = np.median(rho)
 
         #find s_i, the size of the "almost null space" for each point:
@@ -470,15 +478,15 @@ def locally_linear_embedding(
             Xi = X[neighbors[i]]
             Xi -= Xi.mean(0)
 
-            # compute out_dim largest eigenvalues of Xi * Xi^T
+            # compute n_components largest eigenvalues of Xi * Xi^T
             if use_svd:
                 v = svd(Xi, full_matrices=True)[0]
             else:
                 Ci = np.dot(Xi, Xi.T)
                 v = eigh(Ci)[1][:, ::-1]
 
-            Gi = np.zeros((n_neighbors, out_dim + 1))
-            Gi[:, 1:] = v[:, :out_dim]
+            Gi = np.zeros((n_neighbors, n_components + 1))
+            Gi[:, 1:] = v[:, :n_components]
             Gi[:, 0] = 1. / np.sqrt(n_neighbors)
 
             GiGiT = np.dot(Gi, Gi.T)
@@ -487,7 +495,7 @@ def locally_linear_embedding(
             M[nbrs_x, nbrs_y] -= GiGiT
             M[neighbors[i], neighbors[i]] += 1
 
-    return null_space(M, out_dim, k_skip=1, eigen_solver=eigen_solver,
+    return null_space(M, n_components, k_skip=1, eigen_solver=eigen_solver,
                       tol=tol, max_iter=max_iter, random_state=random_state)
 
 
@@ -499,7 +507,7 @@ class LocallyLinearEmbedding(BaseEstimator):
     n_neighbors : integer
         number of neighbors to consider for each point.
 
-    out_dim : integer
+    n_components : integer
         number of coordinates for the manifold
 
     reg : float
@@ -530,7 +538,7 @@ class LocallyLinearEmbedding(BaseEstimator):
         standard : use the standard locally linear embedding algorithm.
                    see reference [1]
         hessian  : use the Hessian eigenmap method.  This method requires
-                   n_neighbors > out_dim * (1 + (out_dim + 1) / 2.
+                   n_neighbors > n_components * (1 + (n_components + 1) / 2.
                    see reference [2]
         modified : use the modified locally linear embedding algorithm.
                    see reference [3]
@@ -555,7 +563,7 @@ class LocallyLinearEmbedding(BaseEstimator):
 
     Attributes
     ----------
-    `embedding_vectors_` : array-like, shape [out_dim, n_samples]
+    `embedding_vectors_` : array-like, shape [n_components, n_samples]
         Stores the embedding vectors
 
     `reconstruction_error_` : float
@@ -581,12 +589,18 @@ class LocallyLinearEmbedding(BaseEstimator):
         Journal of Shanghai Univ.  8:406 (2004)`
     """
 
-    def __init__(self, n_neighbors=5, out_dim=2, reg=1E-3,
-                 eigen_solver='auto', tol=1E-6, max_iter=100,
-                 method='standard', hessian_tol=1E-4, modified_tol=1E-12,
-                 neighbors_algorithm='auto', random_state=None):
+    def __init__(self, n_neighbors=5, n_components=2, reg=1E-3,
+            eigen_solver='auto', tol=1E-6, max_iter=100, method='standard',
+            hessian_tol=1E-4, modified_tol=1E-12, neighbors_algorithm='auto',
+            random_state=None, out_dim=None):
+
+        if not out_dim is None:
+            warnings.warn("Parameter ``out_dim`` was renamed to "
+                "``n_components`` and is now deprecated.", DeprecationWarning)
+            n_components = n_components
+
         self.n_neighbors = n_neighbors
-        self.out_dim = out_dim
+        self.n_components = n_components
         self.reg = reg
         self.eigen_solver = eigen_solver
         self.tol = tol
@@ -603,7 +617,7 @@ class LocallyLinearEmbedding(BaseEstimator):
         self.nbrs_.fit(X)
         self.embedding_, self.reconstruction_error_ = \
             locally_linear_embedding(
-                self.nbrs_, self.n_neighbors, self.out_dim,
+                self.nbrs_, self.n_neighbors, self.n_components,
                 eigen_solver=self.eigen_solver, tol=self.tol,
                 max_iter=self.max_iter, method=self.method,
                 hessian_tol=self.hessian_tol, modified_tol=self.modified_tol,
@@ -634,7 +648,7 @@ class LocallyLinearEmbedding(BaseEstimator):
 
         Returns
         -------
-        X_new: array-like, shape (n_samples, out_dim)
+        X_new: array-like, shape (n_samples, n_components)
         """
         self._fit_transform(X)
         return self.embedding_
@@ -649,7 +663,7 @@ class LocallyLinearEmbedding(BaseEstimator):
 
         Returns
         -------
-        X_new : array, shape = [n_samples, out_dim]
+        X_new : array, shape = [n_samples, n_components]
 
         Notes
         -----
@@ -661,7 +675,7 @@ class LocallyLinearEmbedding(BaseEstimator):
                                     return_distance=False)
         weights = barycenter_weights(X, self.nbrs_._fit_X[ind],
                                      reg=self.reg)
-        X_new = np.empty((X.shape[0], self.out_dim))
+        X_new = np.empty((X.shape[0], self.n_components))
         for i in range(X.shape[0]):
             X_new[i] = np.dot(self.embedding_[ind[i]].T, weights[i])
         return X_new
diff --git a/sklearn/manifold/tests/test_isomap.py b/sklearn/manifold/tests/test_isomap.py
index 17155473294929fb50a85c951b67013129afd0f0..0ba3ee59cbb4b76a47ce0e263922c8ff8daa4b9d 100644
--- a/sklearn/manifold/tests/test_isomap.py
+++ b/sklearn/manifold/tests/test_isomap.py
@@ -21,7 +21,7 @@ def test_isomap_simple_grid():
     Npts = N_per_side ** 2
     n_neighbors = Npts - 1
 
-    # grid of equidistant points in 2D, out_dim = n_dim
+    # grid of equidistant points in 2D, n_components = n_dim
     X = np.array(list(product(range(N_per_side), repeat=2)))
 
     # distances from each point to all others
@@ -30,7 +30,7 @@ def test_isomap_simple_grid():
 
     for eigen_solver in eigen_solvers:
         for path_method in path_methods:
-            clf = manifold.Isomap(n_neighbors=n_neighbors, out_dim=2,
+            clf = manifold.Isomap(n_neighbors=n_neighbors, n_components=2,
                                   eigen_solver=eigen_solver,
                                   path_method=path_method)
             clf.fit(X)
@@ -47,7 +47,7 @@ def test_isomap_reconstruction_error():
     Npts = N_per_side ** 2
     n_neighbors = Npts - 1
 
-    # grid of equidistant points in 2D, out_dim = n_dim
+    # grid of equidistant points in 2D, n_components = n_dim
     X = np.array(list(product(range(N_per_side), repeat=2)))
 
     # add noise in a third dimension
@@ -64,7 +64,7 @@ def test_isomap_reconstruction_error():
 
     for eigen_solver in eigen_solvers:
         for path_method in path_methods:
-            clf = manifold.Isomap(n_neighbors=n_neighbors, out_dim=2,
+            clf = manifold.Isomap(n_neighbors=n_neighbors, n_components=2,
                                   eigen_solver=eigen_solver,
                                   path_method=path_method)
             clf.fit(X)
diff --git a/sklearn/manifold/tests/test_locally_linear.py b/sklearn/manifold/tests/test_locally_linear.py
index a21ccbdd55538196b771c90549135cd297dc4817..a45c8addb873cd10d4864ceb3151a78182a7529a 100644
--- a/sklearn/manifold/tests/test_locally_linear.py
+++ b/sklearn/manifold/tests/test_locally_linear.py
@@ -34,11 +34,12 @@ def test_barycenter_kneighbors_graph():
 
 def test_lle_simple_grid():
     rng = np.random.RandomState(0)
-    # grid of equidistant points in 2D, out_dim = n_dim
+    # grid of equidistant points in 2D, n_components = n_dim
     X = np.array(list(product(range(5), repeat=2)))
     X = X + 1e-10 * rng.uniform(size=X.shape)
-    out_dim = 2
-    clf = manifold.LocallyLinearEmbedding(n_neighbors=5, out_dim=out_dim)
+    n_components = 2
+    clf = manifold.LocallyLinearEmbedding(n_neighbors=5,
+            n_components=n_components)
     tol = .1
 
     N = barycenter_kneighbors_graph(X, clf.n_neighbors).todense()
@@ -48,7 +49,7 @@ def test_lle_simple_grid():
     for solver in eigen_solvers:
         clf.set_params(eigen_solver=solver)
         clf.fit(X)
-        assert_true(clf.embedding_.shape[1] == out_dim)
+        assert_true(clf.embedding_.shape[1] == n_components)
         reconstruction_error = np.linalg.norm(
             np.dot(N, clf.embedding_) - clf.embedding_, 'fro') ** 2
         # FIXME: ARPACK fails this test ...
@@ -68,8 +69,9 @@ def test_lle_manifold():
     X = np.array(list(product(range(20), repeat=2)))
     X = np.c_[X, X[:, 0] ** 2 / 20]
     X = X + 1e-10 * np.random.uniform(size=X.shape)
-    out_dim = 2
-    clf = manifold.LocallyLinearEmbedding(n_neighbors=5, out_dim=out_dim,
+    n_components = 2
+    clf = manifold.LocallyLinearEmbedding(n_neighbors=5,
+            n_components=n_components,
                                           random_state=0)
     tol = 1.5
 
@@ -80,7 +82,7 @@ def test_lle_manifold():
     for solver in eigen_solvers:
         clf.set_params(eigen_solver=solver)
         clf.fit(X)
-        assert_true(clf.embedding_.shape[1] == out_dim)
+        assert_true(clf.embedding_.shape[1] == n_components)
         reconstruction_error = np.linalg.norm(
             np.dot(N, clf.embedding_) - clf.embedding_, 'fro') ** 2
         details = "solver: " + solver