diff --git a/build_tools/travis/flake8_diff.sh b/build_tools/travis/flake8_diff.sh
index 87ffdffd345ce15655e1a4f7cb67bcb374070de3..cf3dcb5577e9c4202d1903649a00fdb9b72a5067 100755
--- a/build_tools/travis/flake8_diff.sh
+++ b/build_tools/travis/flake8_diff.sh
@@ -137,8 +137,10 @@ check_files() {
 if [[ "$MODIFIED_FILES" == "no_match" ]]; then
     echo "No file outside sklearn/externals and doc/sphinxext/sphinx_gallery has been modified"
 else
-    check_files "$(echo "$MODIFIED_FILES" | grep -v ^examples)"
+    check_files "$(echo "$MODIFIED_FILES" | grep -v ^examples)" --ignore=W503
     # Examples are allowed to not have imports at top of file
-    check_files "$(echo "$MODIFIED_FILES" | grep ^examples)" --ignore=E402
+    check_files "$(echo "$MODIFIED_FILES" | grep ^examples)" --ignore=E402,W503
 fi
 echo -e "No problem detected by flake8\n"
diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index ab5a27f832609c0e02190a2fae51384109f12ee0..d03b92d4aaed8661c5a7999dcec70a976180c810 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -1198,6 +1198,7 @@ See the :ref:`metrics` section of the user guide for further details.
    preprocessing.Normalizer
    preprocessing.OneHotEncoder
    preprocessing.PolynomialFeatures
+   preprocessing.QuantileTransformer
    preprocessing.RobustScaler
    preprocessing.StandardScaler
 
@@ -1211,6 +1212,7 @@ See the :ref:`metrics` section of the user guide for further details.
    preprocessing.maxabs_scale
    preprocessing.minmax_scale
    preprocessing.normalize
+   preprocessing.quantile_transform
    preprocessing.robust_scale
    preprocessing.scale
 
diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst
index 709239687158e6c0402c5b1f79e5cbcdf00df67f..3b75eed6a7ff2bb288b09c4535b38682814fcf3c 100644
--- a/doc/modules/preprocessing.rst
+++ b/doc/modules/preprocessing.rst
@@ -10,6 +10,13 @@ The ``sklearn.preprocessing`` package provides several common
 utility functions and transformer classes to change raw feature vectors
 into a representation that is more suitable for the downstream estimators.
 
+In general, learning algorithms benefit from standardization of the data set. If
+some outliers are present in the set, robust scalers or transformers are more
+appropriate. The behaviors of the different scalers, transformers, and
+normalizers on a dataset containing marginal outliers are highlighted in
+:ref:`sphx_glr_auto_examples_preprocessing_plot_all_scaling.py`.
+
+
 .. _preprocessing_scaler:
 
 Standardization, or mean removal and variance scaling
@@ -39,10 +46,10 @@ operation on a single array-like dataset::
 
   >>> from sklearn import preprocessing
   >>> import numpy as np
-  >>> X = np.array([[ 1., -1.,  2.],
-  ...               [ 2.,  0.,  0.],
-  ...               [ 0.,  1., -1.]])
-  >>> X_scaled = preprocessing.scale(X)
+  >>> X_train = np.array([[ 1., -1.,  2.],
+  ...                     [ 2.,  0.,  0.],
+  ...                     [ 0.,  1., -1.]])
+  >>> X_scaled = preprocessing.scale(X_train)
 
   >>> X_scaled                                          # doctest: +ELLIPSIS
   array([[ 0.  ..., -1.22...,  1.33...],
@@ -71,7 +78,7 @@ able to later reapply the same transformation on the testing set.
 This class is hence suitable for use in the early steps of a
 :class:`sklearn.pipeline.Pipeline`::
 
-  >>> scaler = preprocessing.StandardScaler().fit(X)
+  >>> scaler = preprocessing.StandardScaler().fit(X_train)
   >>> scaler
   StandardScaler(copy=True, with_mean=True, with_std=True)
 
@@ -81,7 +88,7 @@ This class is hence suitable for use in the early steps of a
   >>> scaler.scale_                                       # doctest: +ELLIPSIS
   array([ 0.81...,  0.81...,  1.24...])
 
-  >>> scaler.transform(X)                               # doctest: +ELLIPSIS
+  >>> scaler.transform(X_train)                           # doctest: +ELLIPSIS
   array([[ 0.  ..., -1.22...,  1.33...],
          [ 1.22...,  0.  ..., -0.26...],
          [-1.22...,  1.22..., -1.06...]])
@@ -90,7 +97,8 @@ This class is hence suitable for use in the early steps of a
 The scaler instance can then be used on new data to transform it the
 same way it did on the training set::
 
-  >>> scaler.transform([[-1.,  1., 0.]])                # doctest: +ELLIPSIS
+  >>> X_test = [[-1., 1., 0.]]
+  >>> scaler.transform(X_test)                # doctest: +ELLIPSIS
   array([[-2.44...,  1.22..., -0.26...]])
 
 It is possible to disable either centering or scaling by either
@@ -248,6 +256,83 @@ a :class:`KernelCenterer` can transform the kernel matrix
 so that it contains inner products in the feature space
 defined by :math:`phi` followed by removal of the mean in that space.
 
+.. _preprocessing_transformer:
+
+Non-linear transformation
+=========================
+
+Like scalers, :class:`QuantileTransformer` puts each feature into the same
+range or distribution. However, by performing a rank transformation, it smooths
+out unusual distributions and is less influenced by outliers than scaling
+methods. It does, however, distort correlations and distances within and across
+features.
+
+:class:`QuantileTransformer` and :func:`quantile_transform` provide a
+non-parametric transformation based on the quantile function to map the data to
+a uniform distribution with values between 0 and 1::
+
+  >>> from sklearn.datasets import load_iris
+  >>> from sklearn.model_selection import train_test_split
+  >>> iris = load_iris()
+  >>> X, y = iris.data, iris.target
+  >>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+  >>> quantile_transformer = preprocessing.QuantileTransformer(random_state=0)
+  >>> X_train_trans = quantile_transformer.fit_transform(X_train)
+  >>> X_test_trans = quantile_transformer.transform(X_test)
+  >>> np.percentile(X_train[:, 0], [0, 25, 50, 75, 100]) # doctest: +SKIP
+  array([ 4.3,  5.1,  5.8,  6.5,  7.9])
+
+This feature corresponds to the sepal length in cm. Once the quantile
+transformation is applied, those landmarks closely approach the previously
+defined percentiles::
+
+  >>> np.percentile(X_train_trans[:, 0], [0, 25, 50, 75, 100])
+  ... # doctest: +ELLIPSIS +SKIP
+  array([ 0.00... ,  0.24...,  0.49...,  0.73...,  0.99... ])
+
+This can be confirmed on an independent test set, where similar observations hold::
+
+  >>> np.percentile(X_test[:, 0], [0, 25, 50, 75, 100])
+  ... # doctest: +SKIP
+  array([ 4.4  ,  5.125,  5.75 ,  6.175,  7.3  ])
+  >>> np.percentile(X_test_trans[:, 0], [0, 25, 50, 75, 100])
+  ... # doctest: +ELLIPSIS +SKIP
+  array([ 0.01...,  0.25...,  0.46...,  0.60... ,  0.94...])
+
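+The transformation is invertible: ``inverse_transform`` maps the transformed
+values back to the original feature space (a quick sketch; the round-trip is
+only approximate, up to interpolation error)::
+
+  >>> X_back = quantile_transformer.inverse_transform(X_train_trans)
+  >>> np.allclose(X_back, X_train)  # doctest: +SKIP
+  True
+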
+It is also possible to map the transformed data to a normal distribution by
+setting ``output_distribution='normal'``::
+
+  >>> quantile_transformer = preprocessing.QuantileTransformer(
+  ...     output_distribution='normal', random_state=0)
+  >>> X_trans = quantile_transformer.fit_transform(X)
+  >>> quantile_transformer.quantiles_ # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+  array([[ 4.3...,   2...,     1...,     0.1...],
+         [ 4.31...,  2.02...,  1.01...,  0.1...],
+         [ 4.32...,  2.05...,  1.02...,  0.1...],
+         ...,
+         [ 7.84...,  4.34...,  6.84...,  2.5...],
+         [ 7.87...,  4.37...,  6.87...,  2.5...],
+         [ 7.9...,   4.4...,   6.9...,   2.5...]])
+
+Thus the median of the input becomes the mean of the output, centered at 0. The
+normal output is clipped so that the input's minimum and maximum ---
+corresponding to the 1e-7 and 1 - 1e-7 quantiles respectively --- do not
+become infinite under the transformation.
+
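+A quick way to verify that no value is mapped to infinity (a sketch reusing
+``X_trans`` from the snippet above)::
+
+  >>> np.isfinite(X_trans).all()  # doctest: +SKIP
+  True
+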
 .. _preprocessing_normalization:
 
 Normalization
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 75e616d97d53799a52a280e42e68b754cac4f033..5f365444c21fbb7a580f7c757eef3f92bfd32d13 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -62,6 +62,13 @@ New features
      during the first epochs of ridge and logistic regression.
      By `Arthur Mensch`_.
 
+   - Added :class:`preprocessing.QuantileTransformer` class and
+     :func:`preprocessing.quantile_transform` function for features
+     normalization based on quantiles.
+     :issue:`8363` by :user:`Denis Engemann <dengemann>`,
+     :user:`Guillaume Lemaitre <glemaitre>`, `Olivier Grisel`_, `Raghav RV`_,
+     :user:`Thierry Guillemot <tguillemot>`, and `Gael Varoquaux`_.
+
 Enhancements
 ............
 
@@ -172,7 +179,7 @@ Enhancements
    - Add ``sample_weight`` parameter to :func:`metrics.cohen_kappa_score` by
      Victor Poughon.
 
-   - In :class:`gaussian_process.GaussianProcessRegressor`, method ``predict`` 
+   - In :class:`gaussian_process.GaussianProcessRegressor`, method ``predict``
      is a lot faster with ``return_std=True`` by :user:`Hadrien Bertrand <hbertrand>`.
 
    - Added ability to use sparse matrices in :func:`feature_selection.f_regression`
@@ -331,7 +338,7 @@ Bug fixes
      both ``'binary'`` but the union of ``y_true`` and ``y_pred`` was
      ``'multiclass'``. :issue:`8377` by `Loic Esteve`_.
 
-   - Fix :func:`sklearn.linear_model.BayesianRidge.fit` to return 
+   - Fix :func:`sklearn.linear_model.BayesianRidge.fit` to return
      ridge parameter `alpha_` and `lambda_` consistent with calculated
      coefficients `coef_` and `intercept_`.
      :issue:`8224` by :user:`Peter Gedeck <gedeck>`.
diff --git a/examples/preprocessing/plot_all_scaling.py b/examples/preprocessing/plot_all_scaling.py
new file mode 100755
index 0000000000000000000000000000000000000000..677386a00191c84955ba51f35626374635198c47
--- /dev/null
+++ b/examples/preprocessing/plot_all_scaling.py
@@ -0,0 +1,330 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""
+=============================================================
+Compare the effect of different scalers on data with outliers
+=============================================================
+
+Feature 0 (median income in a block) and feature 5 (number of households) of
+the `California housing dataset
+<http://www.dcc.fc.up.pt/~ltorgo/Regression/cal_housing.html>`_ have very
+different scales and contain some very large outliers. These two
+characteristics make it difficult to visualize the data and, more
+importantly, can degrade the predictive performance of many machine
+learning algorithms. Unscaled data can also slow down or even prevent the
+convergence of many gradient-based estimators.
+
+Indeed, many estimators are designed with the assumption that each feature
+takes values close to zero or, more importantly, that all features vary on
+comparable scales. In particular, metric-based and gradient-based estimators
+often assume approximately standardized data (centered features with unit
+variances). A notable exception is decision tree-based estimators, which are
+robust to arbitrary scaling of the data.
+
+This example uses different scalers, transformers, and normalizers to bring the
+data within a pre-defined range.
+
+Scalers are linear (or more precisely affine) transformers and differ from each
+other in the way they estimate the parameters used to shift and scale each
+feature.
+
+``QuantileTransformer`` provides a non-linear transformation in which distances
+between marginal outliers and inliers are shrunk.
+
+Unlike the previous transformations, normalization refers to a per sample
+transformation instead of a per feature transformation.
+
+The following code is a bit verbose; feel free to jump directly to the analysis
+of the results_.
+
+"""
+
+# Author:  Raghav RV <rvraghav93@gmail.com>
+#          Guillaume Lemaitre <g.lemaitre58@gmail.com>
+#          Thomas Unterthiner
+# License: BSD 3 clause
+
+from __future__ import print_function
+
+import numpy as np
+
+import matplotlib as mpl
+from matplotlib import pyplot as plt
+from matplotlib import cm
+
+from sklearn.preprocessing import MinMaxScaler
+from sklearn.preprocessing import minmax_scale
+from sklearn.preprocessing import MaxAbsScaler
+from sklearn.preprocessing import StandardScaler
+from sklearn.preprocessing import RobustScaler
+from sklearn.preprocessing import Normalizer
+from sklearn.preprocessing import QuantileTransformer
+
+from sklearn.datasets import fetch_california_housing
+
+print(__doc__)
+
+dataset = fetch_california_housing()
+X_full, y_full = dataset.data, dataset.target
+
+# Take only 2 features to make visualization easier
+# Feature 0 has a long tail distribution.
+# Feature 5 has a few but very large outliers.
+
+X = X_full[:, [0, 5]]
+
+distributions = [
+    ('Unscaled data', X),
+    ('Data after standard scaling',
+        StandardScaler().fit_transform(X)),
+    ('Data after min-max scaling',
+        MinMaxScaler().fit_transform(X)),
+    ('Data after max-abs scaling',
+        MaxAbsScaler().fit_transform(X)),
+    ('Data after robust scaling',
+        RobustScaler(quantile_range=(25, 75)).fit_transform(X)),
+    ('Data after quantile transformation (uniform pdf)',
+        QuantileTransformer(output_distribution='uniform')
+        .fit_transform(X)),
+    ('Data after quantile transformation (gaussian pdf)',
+        QuantileTransformer(output_distribution='normal')
+        .fit_transform(X)),
+    ('Data after sample-wise L2 normalizing',
+        Normalizer().fit_transform(X))
+]
+
+# scale the output between 0 and 1 for the colorbar
+y = minmax_scale(y_full)
+
+
+def create_axes(title, figsize=(16, 6)):
+    fig = plt.figure(figsize=figsize)
+    fig.suptitle(title)
+
+    # define the axis for the first plot
+    left, width = 0.1, 0.22
+    bottom, height = 0.1, 0.7
+    bottom_h = height + 0.15
+    left_h = left + width + 0.02
+
+    rect_scatter = [left, bottom, width, height]
+    rect_histx = [left, bottom_h, width, 0.1]
+    rect_histy = [left_h, bottom, 0.05, height]
+
+    ax_scatter = plt.axes(rect_scatter)
+    ax_histx = plt.axes(rect_histx)
+    ax_histy = plt.axes(rect_histy)
+
+    # define the axis for the zoomed-in plot
+    left = width + left + 0.2
+    left_h = left + width + 0.02
+
+    rect_scatter = [left, bottom, width, height]
+    rect_histx = [left, bottom_h, width, 0.1]
+    rect_histy = [left_h, bottom, 0.05, height]
+
+    ax_scatter_zoom = plt.axes(rect_scatter)
+    ax_histx_zoom = plt.axes(rect_histx)
+    ax_histy_zoom = plt.axes(rect_histy)
+
+    # define the axis for the colorbar
+    left, width = width + left + 0.13, 0.01
+
+    rect_colorbar = [left, bottom, width, height]
+    ax_colorbar = plt.axes(rect_colorbar)
+
+    return ((ax_scatter, ax_histy, ax_histx),
+            (ax_scatter_zoom, ax_histy_zoom, ax_histx_zoom),
+            ax_colorbar)
+
+
+def plot_distribution(axes, X, y, hist_nbins=50, title="",
+                      x0_label="", x1_label=""):
+    ax, hist_X1, hist_X0 = axes
+
+    ax.set_title(title)
+    ax.set_xlabel(x0_label)
+    ax.set_ylabel(x1_label)
+
+    # The scatter plot
+    colors = cm.plasma_r(y)
+    ax.scatter(X[:, 0], X[:, 1], alpha=0.5, marker='o', s=5, lw=0, c=colors)
+
+    # Removing the top and the right spine for aesthetics
+    # make nice axis layout
+    ax.spines['top'].set_visible(False)
+    ax.spines['right'].set_visible(False)
+    ax.get_xaxis().tick_bottom()
+    ax.get_yaxis().tick_left()
+    ax.spines['left'].set_position(('outward', 10))
+    ax.spines['bottom'].set_position(('outward', 10))
+
+    # Histogram for axis X1 (feature 5)
+    hist_X1.set_ylim(ax.get_ylim())
+    hist_X1.hist(X[:, 1], bins=hist_nbins, orientation='horizontal',
+                 color='grey', ec='grey')
+    hist_X1.axis('off')
+
+    # Histogram for axis X0 (feature 0)
+    hist_X0.set_xlim(ax.get_xlim())
+    hist_X0.hist(X[:, 0], bins=hist_nbins, orientation='vertical',
+                 color='grey', ec='grey')
+    hist_X0.axis('off')
+
+###############################################################################
+# Two plots will be shown for each scaler/normalizer/transformer. The left
+# figure will show a scatter plot of the full data set while the right figure
+# will exclude the extreme values considering only 99 % of the data set,
+# excluding marginal outliers. In addition, the marginal distributions for each
+# feature will be shown on the side of the scatter plot.
+
+
+def make_plot(item_idx):
+    title, X = distributions[item_idx]
+    ax_zoom_out, ax_zoom_in, ax_colorbar = create_axes(title)
+    axarr = (ax_zoom_out, ax_zoom_in)
+    plot_distribution(axarr[0], X, y, hist_nbins=200,
+                      x0_label="Median Income",
+                      x1_label="Number of households",
+                      title="Full data")
+
+    # zoom-in
+    zoom_in_percentile_range = (0, 99)
+    cutoffs_X0 = np.percentile(X[:, 0], zoom_in_percentile_range)
+    cutoffs_X1 = np.percentile(X[:, 1], zoom_in_percentile_range)
+
+    non_outliers_mask = (
+        np.all(X > [cutoffs_X0[0], cutoffs_X1[0]], axis=1) &
+        np.all(X < [cutoffs_X0[1], cutoffs_X1[1]], axis=1))
+    plot_distribution(axarr[1], X[non_outliers_mask], y[non_outliers_mask],
+                      hist_nbins=50,
+                      x0_label="Median Income",
+                      x1_label="Number of households",
+                      title="Zoom-in")
+
+    norm = mpl.colors.Normalize(y_full.min(), y_full.max())
+    mpl.colorbar.ColorbarBase(ax_colorbar, cmap=cm.plasma_r,
+                              norm=norm, orientation='vertical',
+                              label='Color mapping for values of y')
+
+
+########################################################################
+# .. _results:
+#
+# Original data
+# -------------
+#
+# Each transformation is plotted showing two transformed features, with the
+# left plot showing the entire dataset, and the right zoomed-in to show the
+# dataset without the marginal outliers. A large majority of the samples are
+# compacted to a specific range, [0, 10] for the median income and [0, 6] for
+# the number of households. Note that there are some marginal outliers (some
+# blocks have more than 1200 households). Therefore, a specific pre-processing
+# can be very beneficial depending on the application. In the following, we
+# present some insights and behaviors of those pre-processing methods in the
+# presence of marginal outliers.
+
+make_plot(0)
+
+#######################################################################
+# StandardScaler
+# --------------
+#
+# ``StandardScaler`` removes the mean and scales the data to unit variance.
+# However, the outliers have an influence when computing the empirical mean and
+# standard deviation, which shrinks the range of the feature values as shown in
+# the left figure below. Note in particular that because the outliers on each
+# feature have different magnitudes, the spread of the transformed data on
+# each feature is very different: most of the data lie in the [-2, 4] range for
+# the transformed median income feature while the same data is squeezed in the
+# smaller [-0.2, 0.2] range for the transformed number of households.
+#
+# ``StandardScaler`` therefore cannot guarantee balanced feature scales in the
+# presence of outliers.
+
+make_plot(1)
+
+##########################################################################
+# MinMaxScaler
+# ------------
+#
+# ``MinMaxScaler`` rescales the data set such that all feature values are in
+# the range [0, 1] as shown in the right panel below. However, this scaling
+# compresses all inliers into the narrow range [0, 0.005] for the transformed
+# number of households.
+#
+# Like ``StandardScaler``, ``MinMaxScaler`` is very sensitive to the presence of
+# outliers.
+
+make_plot(2)
+
+#############################################################################
+# MaxAbsScaler
+# ------------
+#
+# ``MaxAbsScaler`` differs from the previous scaler in that the absolute
+# values are mapped to the range [0, 1]. On positive-only data, this scaler
+# behaves similarly to ``MinMaxScaler`` and therefore also suffers from the
+# presence of large outliers.
+
+make_plot(3)
+
+##############################################################################
+# RobustScaler
+# ------------
+#
+# Unlike the previous scalers, the centering and scaling statistics of this
+# scaler are based on percentiles and are therefore not influenced by a small
+# number of very large marginal outliers. Consequently, the resulting range of
+# the transformed feature values is larger than for the previous scalers and,
+# more importantly, approximately similar: for both features most of the
+# transformed values lie in a [-2, 3] range as seen in the zoomed-in figure.
+# Note that the outliers themselves are still present in the transformed data.
+# If a separate outlier clipping is desirable, a non-linear transformation is
+# required (see below).
+
+make_plot(4)
+
+###################################################################
+# QuantileTransformer (uniform output)
+# ------------------------------------
+#
+# ``QuantileTransformer`` applies a non-linear transformation such that the
+# probability density function of each feature will be mapped to a uniform
+# distribution. In this case, all the data will be mapped to the range
+# [0, 1], including the outliers, which can no longer be distinguished from
+# the inliers.
+#
+# Like ``RobustScaler``, ``QuantileTransformer`` is robust to outliers in the
+# sense that adding or removing outliers in the training set will yield
+# approximately the same transformation on held out data. But contrary to
+# ``RobustScaler``, ``QuantileTransformer`` will also automatically collapse
+# any outliers by setting them to the a priori defined range boundaries (0
+# and 1).
+
+make_plot(5)
+
+##############################################################################
+# QuantileTransformer (Gaussian output)
+# -------------------------------------
+#
+# ``QuantileTransformer`` has an additional ``output_distribution`` parameter
+# that allows matching a Gaussian distribution instead of a uniform one.
+# Note that this non-parametric transformer introduces saturation artifacts
+# for extreme values.
+
+make_plot(6)
+
+##############################################################################
+# Normalizer
+# ----------
+#
+# The ``Normalizer`` rescales the vector for each sample to have unit norm,
+# independently of the distribution of the samples. This can be seen in both
+# figures below, where all samples are mapped onto the unit circle. In our
+# example the two selected features have only positive values; therefore the
+# transformed data only lie in the positive quadrant. This would not be the
+# case if some original features had a mix of positive and negative values.
+
+make_plot(7)
+plt.show()
diff --git a/examples/preprocessing/plot_robust_scaling.py b/examples/preprocessing/plot_robust_scaling.py
deleted file mode 100644
index e752284147b4d9a9872bf758bdad6aa473f4ff3e..0000000000000000000000000000000000000000
--- a/examples/preprocessing/plot_robust_scaling.py
+++ /dev/null
@@ -1,84 +0,0 @@
-#!/usr/bin/python
-# -*- coding: utf-8 -*-
-
-"""
-=========================================================
-Robust Scaling on Toy Data
-=========================================================
-
-Making sure that each Feature has approximately the same scale can be a
-crucial preprocessing step. However, when data contains outliers,
-:class:`StandardScaler <sklearn.preprocessing.StandardScaler>` can often
-be mislead. In such cases, it is better to use a scaler that is robust
-against outliers.
-
-Here, we demonstrate this on a toy dataset, where one single datapoint
-is a large outlier.
-"""
-from __future__ import print_function
-print(__doc__)
-
-
-# Code source: Thomas Unterthiner
-# License: BSD 3 clause
-
-import matplotlib.pyplot as plt
-import numpy as np
-from sklearn.preprocessing import StandardScaler, RobustScaler
-
-# Create training and test data
-np.random.seed(42)
-n_datapoints = 100
-Cov = [[0.9, 0.0], [0.0, 20.0]]
-mu1 = [100.0, -3.0]
-mu2 = [101.0, -3.0]
-X1 = np.random.multivariate_normal(mean=mu1, cov=Cov, size=n_datapoints)
-X2 = np.random.multivariate_normal(mean=mu2, cov=Cov, size=n_datapoints)
-Y_train = np.hstack([[-1]*n_datapoints, [1]*n_datapoints])
-X_train = np.vstack([X1, X2])
-
-X1 = np.random.multivariate_normal(mean=mu1, cov=Cov, size=n_datapoints)
-X2 = np.random.multivariate_normal(mean=mu2, cov=Cov, size=n_datapoints)
-Y_test = np.hstack([[-1]*n_datapoints, [1]*n_datapoints])
-X_test = np.vstack([X1, X2])
-
-X_train[0, 0] = -1000  # a fairly large outlier
-
-
-# Scale data
-standard_scaler = StandardScaler()
-Xtr_s = standard_scaler.fit_transform(X_train)
-Xte_s = standard_scaler.transform(X_test)
-
-robust_scaler = RobustScaler()
-Xtr_r = robust_scaler.fit_transform(X_train)
-Xte_r = robust_scaler.transform(X_test)
-
-
-# Plot data
-fig, ax = plt.subplots(1, 3, figsize=(12, 4))
-ax[0].scatter(X_train[:, 0], X_train[:, 1],
-              color=np.where(Y_train > 0, 'r', 'b'))
-ax[1].scatter(Xtr_s[:, 0], Xtr_s[:, 1], color=np.where(Y_train > 0, 'r', 'b'))
-ax[2].scatter(Xtr_r[:, 0], Xtr_r[:, 1], color=np.where(Y_train > 0, 'r', 'b'))
-ax[0].set_title("Unscaled data")
-ax[1].set_title("After standard scaling (zoomed in)")
-ax[2].set_title("After robust scaling (zoomed in)")
-# for the scaled data, we zoom in to the data center (outlier can't be seen!)
-for a in ax[1:]:
-    a.set_xlim(-3, 3)
-    a.set_ylim(-3, 3)
-plt.tight_layout()
-plt.show()
-
-
-# Classify using k-NN
-from sklearn.neighbors import KNeighborsClassifier
-
-knn = KNeighborsClassifier()
-knn.fit(Xtr_s, Y_train)
-acc_s = knn.score(Xte_s, Y_test)
-print("Testset accuracy using standard scaler: %.3f" % acc_s)
-knn.fit(Xtr_r, Y_train)
-acc_r = knn.score(Xte_r, Y_test)
-print("Testset accuracy using robust scaler:   %.3f" % acc_r)
diff --git a/sklearn/preprocessing/__init__.py b/sklearn/preprocessing/__init__.py
index cabbd469c10d4511def4daf5a727b61d4c96fa6a..2b105709ffe08f2eb7ec4c5f6df6fcce998fbd11 100644
--- a/sklearn/preprocessing/__init__.py
+++ b/sklearn/preprocessing/__init__.py
@@ -12,6 +12,7 @@ from .data import MaxAbsScaler
 from .data import Normalizer
 from .data import RobustScaler
 from .data import StandardScaler
+from .data import QuantileTransformer
 from .data import add_dummy_feature
 from .data import binarize
 from .data import normalize
@@ -19,6 +20,7 @@ from .data import scale
 from .data import robust_scale
 from .data import maxabs_scale
 from .data import minmax_scale
+from .data import quantile_transform
 from .data import OneHotEncoder
 
 from .data import PolynomialFeatures
@@ -41,6 +43,7 @@ __all__ = [
     'MultiLabelBinarizer',
     'MinMaxScaler',
     'MaxAbsScaler',
+    'QuantileTransformer',
     'Normalizer',
     'OneHotEncoder',
     'RobustScaler',
@@ -54,4 +57,5 @@ __all__ = [
     'maxabs_scale',
     'minmax_scale',
     'label_binarize',
+    'quantile_transform',
 ]
diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py
index 46937c77bee46bc743af9d804f8cd82f5ed8434d..107656702bad95690f14c3bd225fd6614cfb17cf 100644
--- a/sklearn/preprocessing/data.py
+++ b/sklearn/preprocessing/data.py
@@ -6,6 +6,8 @@
 #          Giorgio Patrini <giorgio.patrini@anu.edu.au>
 # License: BSD 3 clause
 
+from __future__ import division
+
 from itertools import chain, combinations
 import numbers
 import warnings
@@ -13,6 +15,7 @@ from itertools import combinations_with_replacement as combinations_w_r
 
 import numpy as np
 from scipy import sparse
+from scipy import stats
 
 from ..base import BaseEstimator, TransformerMixin
 from ..externals import six
@@ -24,7 +27,9 @@ from ..utils.sparsefuncs_fast import (inplace_csr_row_normalize_l1,
 from ..utils.sparsefuncs import (inplace_column_scale,
                                  mean_variance_axis, incr_mean_variance_axis,
                                  min_max_axis)
-from ..utils.validation import check_is_fitted, FLOAT_DTYPES
+from ..utils.validation import (check_is_fitted, check_random_state,
+                                FLOAT_DTYPES)
+BOUNDS_THRESHOLD = 1e-7
 
 
 zip = six.moves.zip
@@ -40,6 +45,7 @@ __all__ = [
     'OneHotEncoder',
     'RobustScaler',
     'StandardScaler',
+    'QuantileTransformer',
     'add_dummy_feature',
     'binarize',
     'normalize',
@@ -47,6 +53,7 @@ __all__ = [
     'robust_scale',
     'maxabs_scale',
     'minmax_scale',
+    'quantile_transform',
 ]
 
 
@@ -110,10 +117,14 @@ def scale(X, axis=0, with_mean=True, with_std=True, copy=True):
 
     To avoid memory copy the caller should pass a CSC matrix.
 
+    See examples/preprocessing/plot_all_scaling.py for a comparison of the
+    different scalers, transformers, and normalizers.
+
     See also
     --------
     StandardScaler: Performs scaling to unit variance using the``Transformer`` API
         (e.g. as part of a preprocessing :class:`sklearn.pipeline.Pipeline`).
+
     """  # noqa
     X = check_array(X, accept_sparse='csc', copy=copy, ensure_2d=False,
                     warn_on_dtype=True, estimator='the scale function',
@@ -233,7 +244,12 @@ class MinMaxScaler(BaseEstimator, TransformerMixin):
 
     See also
     --------
-    minmax_scale: Equivalent function without the object oriented API.
+    minmax_scale: Equivalent function without the estimator API.
+
+    Notes
+    -----
+    See examples/preprocessing/plot_all_scaling.py for a comparison of the
+    different scalers, transformers, and normalizers.
     """
 
     def __init__(self, feature_range=(0, 1), copy=True):
@@ -390,6 +406,11 @@ def minmax_scale(X, feature_range=(0, 1), axis=0, copy=True):
     --------
     MinMaxScaler: Performs scaling to a given range using the``Transformer`` API
         (e.g. as part of a preprocessing :class:`sklearn.pipeline.Pipeline`).
+
+    Notes
+    -----
+    See examples/preprocessing/plot_all_scaling.py for a comparison of the
+    different scalers, transformers, and normalizers.
     """  # noqa
     # Unlike the scaler object, this function allows 1d input.
     # If copy is required, it will be done inside the scaler object.
@@ -478,10 +499,15 @@ class StandardScaler(BaseEstimator, TransformerMixin):
 
     See also
     --------
-    scale: Equivalent function without the object oriented API.
+    scale: Equivalent function without the estimator API.
 
     :class:`sklearn.decomposition.PCA`
         Further removes the linear correlation across features with 'whiten=True'.
+
+    Notes
+    -----
+    See examples/preprocessing/plot_all_scaling.py for a comparison of the
+    different scalers, transformers, and normalizers.
     """  # noqa
 
     def __init__(self, copy=True, with_mean=True, with_std=True):
@@ -683,7 +709,12 @@ class MaxAbsScaler(BaseEstimator, TransformerMixin):
 
     See also
     --------
-    maxabs_scale: Equivalent function without the object oriented API.
+    maxabs_scale: Equivalent function without the estimator API.
+
+    Notes
+    -----
+    See examples/preprocessing/plot_all_scaling.py for a comparison of the
+    different scalers, transformers, and normalizers.
     """
 
     def __init__(self, copy=True):
@@ -811,6 +842,11 @@ def maxabs_scale(X, axis=0, copy=True):
     --------
     MaxAbsScaler: Performs scaling to the [-1, 1] range using the``Transformer`` API
         (e.g. as part of a preprocessing :class:`sklearn.pipeline.Pipeline`).
+
+    Notes
+    -----
+    See examples/preprocessing/plot_all_scaling.py for a comparison of the
+    different scalers, transformers, and normalizers.
     """  # noqa
     # Unlike the scaler object, this function allows 1d input.
 
@@ -895,7 +931,7 @@ class RobustScaler(BaseEstimator, TransformerMixin):
 
     See also
     --------
-    robust_scale: Equivalent function without the object oriented API.
+    robust_scale: Equivalent function without the estimator API.
 
     :class:`sklearn.decomposition.PCA`
         Further removes the linear correlation across features with
@@ -903,7 +939,7 @@ class RobustScaler(BaseEstimator, TransformerMixin):
 
     Notes
     -----
-    See examples/preprocessing/plot_robust_scaling.py for an example.
+    See examples/preprocessing/plot_all_scaling.py for an example.
 
     https://en.wikipedia.org/wiki/Median_(statistics)
     https://en.wikipedia.org/wiki/Interquartile_range
@@ -1053,6 +1089,9 @@ def robust_scale(X, axis=0, with_centering=True, with_scaling=True,
 
     To avoid memory copy the caller should pass a CSR matrix.
 
+    See examples/preprocessing/plot_all_scaling.py for a comparison of the
+    different scalers, transformers, and normalizers.
+
     See also
     --------
     RobustScaler: Performs centering and scaling using the ``Transformer`` API
@@ -1269,6 +1308,11 @@ def normalize(X, norm='l2', axis=1, copy=True, return_norm=False):
     --------
     Normalizer: Performs normalization using the ``Transformer`` API
         (e.g. as part of a preprocessing :class:`sklearn.pipeline.Pipeline`).
+
+    Notes
+    -----
+    See examples/preprocessing/plot_all_scaling.py for a comparison of the
+    different scalers, transformers, and normalizers.
     """
     if norm not in ('l1', 'l2', 'max'):
         raise ValueError("'%s' is not a supported norm" % norm)
@@ -1352,9 +1396,12 @@ class Normalizer(BaseEstimator, TransformerMixin):
     This estimator is stateless (besides constructor parameters), the
     fit method does nothing but is useful when used in a pipeline.
 
+    See examples/preprocessing/plot_all_scaling.py for a comparison of the
+    different scalers, transformers, and normalizers.
+
     See also
     --------
-    normalize: Equivalent function without the object oriented API.
+    normalize: Equivalent function without the estimator API.
     """
 
     def __init__(self, norm='l2', copy=True):
@@ -1465,7 +1512,7 @@ class Binarizer(BaseEstimator, TransformerMixin):
 
     See also
     --------
-    binarize: Equivalent function without the object oriented API.
+    binarize: Equivalent function without the estimator API.
     """
 
     def __init__(self, threshold=0.0, copy=True):
@@ -1900,3 +1947,480 @@ class OneHotEncoder(BaseEstimator, TransformerMixin):
         """
         return _transform_selected(X, self._transform,
                                    self.categorical_features, copy=True)
+
+
+class QuantileTransformer(BaseEstimator, TransformerMixin):
+    """Transform features using quantiles information.
+
+    This method transforms the features to follow a uniform or a normal
+    distribution. Therefore, for a given feature, this transformation tends
+    to spread out the most frequent values. It also reduces the impact of
+    (marginal) outliers: this is therefore a robust preprocessing scheme.
+
+    The transformation is applied on each feature independently.
+    The cumulative distribution function of a feature is used to project the
+    original values. Feature values of new/unseen data that fall below
+    or above the fitted range will be mapped to the bounds of the output
+    distribution. Note that this transform is non-linear. It may distort linear
+    correlations between variables measured at the same scale but renders
+    variables measured at different scales more directly comparable.
+
+    Read more in the :ref:`User Guide <preprocessing_transformer>`.
+
+    Parameters
+    ----------
+    n_quantiles : int, optional (default=1000)
+        Number of quantiles to be computed. It corresponds to the number
+        of landmarks used to discretize the cumulative distribution function.
+
+    output_distribution : str, optional (default='uniform')
+        Marginal distribution for the transformed data. The choices are
+        'uniform' (default) or 'normal'.
+
+    ignore_implicit_zeros : bool, optional (default=False)
+        Only applies to sparse matrices. If True, the sparse entries of the
+        matrix are discarded to compute the quantile statistics. If False,
+        these entries are treated as zeros.
+
+    subsample : int, optional (default=1e5)
+        Maximum number of samples used to estimate the quantiles for
+        computational efficiency. Note that the subsampling procedure may
+        differ for value-identical sparse and dense matrices.
+
+    random_state : int, RandomState instance or None, optional (default=None)
+        If int, random_state is the seed used by the random number generator;
+        If RandomState instance, random_state is the random number generator;
+        If None, the random number generator is the RandomState instance used
+        by np.random. Note that this is used by subsampling and smoothing
+        noise.
+
+    copy : boolean, optional (default=True)
+        Set to False to perform inplace transformation and avoid a copy (if the
+        input is already a numpy array).
+
+    Attributes
+    ----------
+    quantiles_ : ndarray, shape (n_quantiles, n_features)
+        The values corresponding to the quantiles of reference.
+
+    references_ : ndarray, shape (n_quantiles,)
+        Quantiles of reference.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> from sklearn.preprocessing import QuantileTransformer
+    >>> rng = np.random.RandomState(0)
+    >>> X = np.sort(rng.normal(loc=0.5, scale=0.25, size=(25, 1)), axis=0)
+    >>> qt = QuantileTransformer(n_quantiles=10, random_state=0)
+    >>> qt.fit_transform(X) # doctest: +ELLIPSIS
+    array([...])
+
+    See also
+    --------
+    quantile_transform : Equivalent function without the estimator API.
+    StandardScaler : perform standardization that is faster, but less robust
+        to outliers.
+    RobustScaler : perform robust standardization that removes the influence
+        of outliers but does not put outliers and inliers on the same scale.
+
+    Notes
+    -----
+    See examples/preprocessing/plot_all_scaling.py for a comparison of the
+    different scalers, transformers, and normalizers.
+
+    """
+
+    def __init__(self, n_quantiles=1000, output_distribution='uniform',
+                 ignore_implicit_zeros=False, subsample=int(1e5),
+                 random_state=None, copy=True):
+        self.n_quantiles = n_quantiles
+        self.output_distribution = output_distribution
+        self.ignore_implicit_zeros = ignore_implicit_zeros
+        self.subsample = subsample
+        self.random_state = random_state
+        self.copy = copy
+
+    def _dense_fit(self, X, random_state):
+        """Compute percentiles for dense matrices.
+
+        Parameters
+        ----------
+        X : ndarray, shape (n_samples, n_features)
+            The data used to scale along the features axis.
+        """
+        if self.ignore_implicit_zeros:
+            warnings.warn("'ignore_implicit_zeros' takes effect only with"
+                          " sparse matrices. This parameter has no effect.")
+
+        n_samples, n_features = X.shape
+        # for compatibility with numpy <= 1.8.X, references
+        # need to be a list scaled between 0 and 100
+        references = (self.references_ * 100).tolist()
+        self.quantiles_ = []
+        for col in X.T:
+            if self.subsample < n_samples:
+                subsample_idx = random_state.choice(n_samples,
+                                                    size=self.subsample,
+                                                    replace=False)
+                col = col.take(subsample_idx, mode='clip')
+            self.quantiles_.append(np.percentile(col, references))
+        self.quantiles_ = np.transpose(self.quantiles_)
+
+    def _sparse_fit(self, X, random_state):
+        """Compute percentiles for sparse matrices.
+
+        Parameters
+        ----------
+        X : sparse matrix CSC, shape (n_samples, n_features)
+            The data used to scale along the features axis. The sparse matrix
+            needs to be nonnegative.
+        """
+        n_samples, n_features = X.shape
+
+        # for compatibility with numpy <= 1.8.X, references
+        # need to be a list scaled between 0 and 100
+        references = list(map(lambda x: x * 100, self.references_))
+        self.quantiles_ = []
+        for feature_idx in range(n_features):
+            column_nnz_data = X.data[X.indptr[feature_idx]:
+                                     X.indptr[feature_idx + 1]]
+            if len(column_nnz_data) > self.subsample:
+                column_subsample = (self.subsample * len(column_nnz_data) //
+                                    n_samples)
+                if self.ignore_implicit_zeros:
+                    column_data = np.zeros(shape=column_subsample,
+                                           dtype=X.dtype)
+                else:
+                    column_data = np.zeros(shape=self.subsample, dtype=X.dtype)
+                column_data[:column_subsample] = random_state.choice(
+                    column_nnz_data, size=column_subsample, replace=False)
+            else:
+                if self.ignore_implicit_zeros:
+                    column_data = np.zeros(shape=len(column_nnz_data),
+                                           dtype=X.dtype)
+                else:
+                    column_data = np.zeros(shape=n_samples, dtype=X.dtype)
+                column_data[:len(column_nnz_data)] = column_nnz_data
+
+            if not column_data.size:
+                # if there are no nnz values, computing the quantiles would
+                # raise an error. Force the quantiles to be zeros.
+                self.quantiles_.append([0] * len(references))
+            else:
+                self.quantiles_.append(
+                    np.percentile(column_data, references))
+        self.quantiles_ = np.transpose(self.quantiles_)
+
+    def fit(self, X, y=None):
+        """Compute the quantiles used for transforming.
+
+        Parameters
+        ----------
+        X : ndarray or sparse matrix, shape (n_samples, n_features)
+            The data used to scale along the features axis. If a sparse
+            matrix is provided, it will be converted into a sparse
+            ``csc_matrix``. Additionally, the sparse matrix needs to be
+            nonnegative if `ignore_implicit_zeros` is False.
+
+        Returns
+        -------
+        self : object
+            Returns self
+        """
+        if self.n_quantiles <= 0:
+            raise ValueError("Invalid value for 'n_quantiles': %d. "
+                             "The number of quantiles must be at least one."
+                             % self.n_quantiles)
+
+        if self.subsample <= 0:
+            raise ValueError("Invalid value for 'subsample': %d. "
+                             "The number of subsamples must be at least one."
+                             % self.subsample)
+
+        if self.n_quantiles > self.subsample:
+            raise ValueError("The number of quantiles cannot be greater than"
+                             " the number of samples used. Got {} quantiles"
+                             " and {} samples.".format(self.n_quantiles,
+                                                       self.subsample))
+
+        X = self._check_inputs(X)
+        rng = check_random_state(self.random_state)
+
+        # Create the quantiles of reference
+        self.references_ = np.linspace(0, 1, self.n_quantiles,
+                                       endpoint=True)
+        if sparse.issparse(X):
+            self._sparse_fit(X, rng)
+        else:
+            self._dense_fit(X, rng)
+
+        return self
+
+    def _transform_col(self, X_col, quantiles, inverse):
+        """Private function to transform a single feature"""
+
+        if self.output_distribution == 'normal':
+            output_distribution = 'norm'
+        else:
+            output_distribution = self.output_distribution
+        output_distribution = getattr(stats, output_distribution)
+
+        # older versions of scipy do not handle tuples as fill_value
+        # clipping the values before the transform solves the issue
+        if not inverse:
+            lower_bound_x = quantiles[0]
+            upper_bound_x = quantiles[-1]
+            lower_bound_y = 0
+            upper_bound_y = 1
+        else:
+            lower_bound_x = 0
+            upper_bound_x = 1
+            lower_bound_y = quantiles[0]
+            upper_bound_y = quantiles[-1]
+            # for inverse transform, match a uniform PDF
+            X_col = output_distribution.cdf(X_col)
+        # find index for lower and higher bounds
+        lower_bounds_idx = (X_col - BOUNDS_THRESHOLD <
+                            lower_bound_x)
+        upper_bounds_idx = (X_col + BOUNDS_THRESHOLD >
+                            upper_bound_x)
+
+        if not inverse:
+            # Interpolate in one direction and in the other and take the
+            # mean. This is in case of repeated values in the features
+            # and hence repeated quantiles
+            #
+            # If we don't do this, only one extreme of the duplicated values
+            # is used (the upper when interpolating in ascending order, and
+            # the lower in descending order). We take the mean of these two
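+            # (for instance, with hypothetical quantiles [0, 0, 1] and
+            # references [0, .5, 1], ascending interpolation maps 0 to .5,
+            # descending maps 0 to 0, and their mean maps 0 to .25)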
+            X_col = .5 * (np.interp(X_col, quantiles, self.references_)
+                          - np.interp(-X_col, -quantiles[::-1],
+                                      -self.references_[::-1]))
+        else:
+            X_col = np.interp(X_col, self.references_, quantiles)
+
+        X_col[upper_bounds_idx] = upper_bound_y
+        X_col[lower_bounds_idx] = lower_bound_y
+        # for forward transform, match the output PDF
+        if not inverse:
+            X_col = output_distribution.ppf(X_col)
+            # find the value to clip the data to avoid mapping to
+            # infinity. Clip such that the inverse transform will be
+            # consistent
+            clip_min = output_distribution.ppf(BOUNDS_THRESHOLD -
+                                               np.spacing(1))
+            clip_max = output_distribution.ppf(1 - (BOUNDS_THRESHOLD -
+                                                    np.spacing(1)))
+            X_col = np.clip(X_col, clip_min, clip_max)
+
+        return X_col
+
+    def _check_inputs(self, X, accept_sparse_negative=False):
+        """Check inputs before fit and transform"""
+        X = check_array(X, accept_sparse='csc', copy=self.copy,
+                        dtype=[np.float64, np.float32])
+        # we only accept positive sparse matrices when ignore_implicit_zeros
+        # is False and when we call fit or transform.
+        if (not accept_sparse_negative and not self.ignore_implicit_zeros and
+                (sparse.issparse(X) and np.any(X.data < 0))):
+            raise ValueError('QuantileTransformer only accepts non-negative'
+                             ' sparse matrices.')
+
+        # check the output PDF
+        if self.output_distribution not in ('normal', 'uniform'):
+            raise ValueError("'output_distribution' has to be either 'normal'"
+                             " or 'uniform'. Got '{}' instead.".format(
+                                 self.output_distribution))
+
+        return X
+
+    def _check_is_fitted(self, X):
+        """Check the inputs before transforming"""
+        check_is_fitted(self, 'quantiles_')
+        # check that the dimensions of X are consistent with the fitted data
+        if X.shape[1] != self.quantiles_.shape[1]:
+            raise ValueError('X does not have the same number of features as'
+                             ' the previously fitted data. Got {} instead of'
+                             ' {}.'.format(X.shape[1],
+                                           self.quantiles_.shape[1]))
+
+    def _transform(self, X, inverse=False):
+        """Forward and inverse transform.
+
+        Parameters
+        ----------
+        X : ndarray, shape (n_samples, n_features)
+            The data used to scale along the features axis.
+
+        inverse : bool, optional (default=False)
+            If False, apply forward transform. If True, apply
+            inverse transform.
+
+        Returns
+        -------
+        X : ndarray, shape (n_samples, n_features)
+            Projected data
+        """
+
+        if sparse.issparse(X):
+            for feature_idx in range(X.shape[1]):
+                column_slice = slice(X.indptr[feature_idx],
+                                     X.indptr[feature_idx + 1])
+                X.data[column_slice] = self._transform_col(
+                    X.data[column_slice], self.quantiles_[:, feature_idx],
+                    inverse)
+        else:
+            for feature_idx in range(X.shape[1]):
+                X[:, feature_idx] = self._transform_col(
+                    X[:, feature_idx], self.quantiles_[:, feature_idx],
+                    inverse)
+
+        return X
+
+    def transform(self, X):
+        """Feature-wise transformation of the data.
+
+        Parameters
+        ----------
+        X : ndarray or sparse matrix, shape (n_samples, n_features)
+            The data used to scale along the features axis. If a sparse
+            matrix is provided, it will be converted into a sparse
+            ``csc_matrix``. Additionally, the sparse matrix needs to be
+            nonnegative if `ignore_implicit_zeros` is False.
+
+        Returns
+        -------
+        Xt : ndarray or sparse matrix, shape (n_samples, n_features)
+            The projected data.
+        """
+        X = self._check_inputs(X)
+        self._check_is_fitted(X)
+
+        return self._transform(X, inverse=False)
+
+    def inverse_transform(self, X):
+        """Back-projection to the original space.
+
+        Parameters
+        ----------
+        X : ndarray or sparse matrix, shape (n_samples, n_features)
+            The data used to scale along the features axis. If a sparse
+            matrix is provided, it will be converted into a sparse
+            ``csc_matrix``. Additionally, the sparse matrix needs to be
+            nonnegative if `ignore_implicit_zeros` is False.
+
+        Returns
+        -------
+        Xt : ndarray or sparse matrix, shape (n_samples, n_features)
+            The projected data.
+        """
+        X = self._check_inputs(X, accept_sparse_negative=True)
+        self._check_is_fitted(X)
+
+        return self._transform(X, inverse=True)
+
+
+def quantile_transform(X, axis=0, n_quantiles=1000,
+                       output_distribution='uniform',
+                       ignore_implicit_zeros=False,
+                       subsample=int(1e5),
+                       random_state=None,
+                       copy=False):
+    """Transform features using quantiles information.
+
+    This method transforms the features to follow a uniform or a normal
+    distribution. Therefore, for a given feature, this transformation tends
+    to spread out the most frequent values. It also reduces the impact of
+    (marginal) outliers: this is therefore a robust preprocessing scheme.
+
+    The transformation is applied on each feature independently.
+    The cumulative distribution function of a feature is used to project the
+    original values. Feature values of new/unseen data that fall below
+    or above the fitted range will be mapped to the bounds of the output
+    distribution. Note that this transform is non-linear. It may distort linear
+    correlations between variables measured at the same scale but renders
+    variables measured at different scales more directly comparable.
+
+    Read more in the :ref:`User Guide <preprocessing_transformer>`.
+
+    Parameters
+    ----------
+    X : array-like, sparse matrix
+        The data to transform.
+
+    axis : int, (default=0)
+        Axis along which the quantiles are computed. If 0, transform each
+        feature, otherwise (if 1) transform each sample.
+
+    n_quantiles : int, optional (default=1000)
+        Number of quantiles to be computed. It corresponds to the number
+        of landmarks used to discretize the cumulative distribution function.
+
+    output_distribution : str, optional (default='uniform')
+        Marginal distribution for the transformed data. The choices are
+        'uniform' (default) or 'normal'.
+
+    ignore_implicit_zeros : bool, optional (default=False)
+        Only applies to sparse matrices. If True, the sparse entries of the
+        matrix are discarded to compute the quantile statistics. If False,
+        these entries are treated as zeros.
+
+    subsample : int, optional (default=1e5)
+        Maximum number of samples used to estimate the quantiles for
+        computational efficiency. Note that the subsampling procedure may
+        differ for value-identical sparse and dense matrices.
+
+    random_state : int, RandomState instance or None, optional (default=None)
+        If int, random_state is the seed used by the random number generator;
+        If RandomState instance, random_state is the random number generator;
+        If None, the random number generator is the RandomState instance used
+        by np.random. Note that this is used by subsampling and smoothing
+        noise.
+
+    copy : boolean, optional (default=False)
+        Set to False to perform inplace transformation and avoid a copy (if the
+        input is already a numpy array).
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> from sklearn.preprocessing import quantile_transform
+    >>> rng = np.random.RandomState(0)
+    >>> X = np.sort(rng.normal(loc=0.5, scale=0.25, size=(25, 1)), axis=0)
+    >>> quantile_transform(X, n_quantiles=10, random_state=0)
+    ... # doctest: +ELLIPSIS
+    array([...])
+
+    See also
+    --------
+    QuantileTransformer : Performs quantile-based scaling using the
+        ``Transformer`` API (e.g. as part of a preprocessing
+        :class:`sklearn.pipeline.Pipeline`).
+    scale : perform standardization that is faster, but less robust
+        to outliers.
+    robust_scale : perform robust standardization that removes the influence
+        of outliers but does not put outliers and inliers on the same scale.
+
+    Notes
+    -----
+    See examples/preprocessing/plot_all_scaling.py for a comparison of the
+    different scalers, transformers, and normalizers.
+
+    """
+    n = QuantileTransformer(n_quantiles=n_quantiles,
+                            output_distribution=output_distribution,
+                            subsample=subsample,
+                            ignore_implicit_zeros=ignore_implicit_zeros,
+                            random_state=random_state,
+                            copy=copy)
+    if axis == 0:
+        return n.fit_transform(X)
+    elif axis == 1:
+        return n.fit_transform(X.T).T
+    else:
+        raise ValueError("axis should be either equal to 0 or 1. Got"
+                         " axis={}".format(axis))
diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py
index 7a51049b6024288f77e40877065d040e4d322c13..af7f28f8162c6121586098831c821546596b6140 100644
--- a/sklearn/preprocessing/tests/test_data.py
+++ b/sklearn/preprocessing/tests/test_data.py
@@ -1,9 +1,9 @@
-
 # Authors:
 #
 #          Giorgio Patrini
 #
 # License: BSD 3 clause
+from __future__ import division
 
 import warnings
 import numpy as np
@@ -42,6 +42,8 @@ from sklearn.preprocessing.data import StandardScaler
 from sklearn.preprocessing.data import scale
 from sklearn.preprocessing.data import MinMaxScaler
 from sklearn.preprocessing.data import minmax_scale
+from sklearn.preprocessing.data import QuantileTransformer
+from sklearn.preprocessing.data import quantile_transform
 from sklearn.preprocessing.data import MaxAbsScaler
 from sklearn.preprocessing.data import maxabs_scale
 from sklearn.preprocessing.data import RobustScaler
@@ -141,7 +143,8 @@ def test_polynomial_feature_names():
                         'b c^2', 'c^3'], feature_names)
     # test some unicode
     poly = PolynomialFeatures(degree=1, include_bias=True).fit(X)
-    feature_names = poly.get_feature_names([u"\u0001F40D", u"\u262E", u"\u05D0"])
+    feature_names = poly.get_feature_names(
+        [u"\u0001F40D", u"\u262E", u"\u05D0"])
     assert_array_equal([u"1", u"\u0001F40D", u"\u262E", u"\u05D0"],
                        feature_names)
 
@@ -851,6 +854,328 @@ def test_robust_scaler_iris_quantiles():
     assert_array_almost_equal(q_range, 1)
 
 
+def test_quantile_transform_iris():
+    X = iris.data
+    # uniform output distribution
+    transformer = QuantileTransformer(n_quantiles=30)
+    X_trans = transformer.fit_transform(X)
+    X_trans_inv = transformer.inverse_transform(X_trans)
+    assert_array_almost_equal(X, X_trans_inv)
+    # normal output distribution
+    transformer = QuantileTransformer(n_quantiles=30,
+                                      output_distribution='normal')
+    X_trans = transformer.fit_transform(X)
+    X_trans_inv = transformer.inverse_transform(X_trans)
+    assert_array_almost_equal(X, X_trans_inv)
+    # make sure it is possible to take the inverse of a sparse matrix
+    # containing negative values; this is the case once the iris dataset
+    # is mapped to a normal output distribution
+    X_sparse = sparse.csc_matrix(X)
+    X_sparse_tran = transformer.fit_transform(X_sparse)
+    X_sparse_tran_inv = transformer.inverse_transform(X_sparse_tran)
+    assert_array_almost_equal(X_sparse.A, X_sparse_tran_inv.A)
+
+
+def test_quantile_transform_check_error():
+    X = np.transpose([[0, 25, 50, 0, 0, 0, 75, 0, 0, 100],
+                      [2, 4, 0, 0, 6, 8, 0, 10, 0, 0],
+                      [0, 0, 2.6, 4.1, 0, 0, 2.3, 0, 9.5, 0.1]])
+    X = sparse.csc_matrix(X)
+    X_neg = np.transpose([[0, 25, 50, 0, 0, 0, 75, 0, 0, 100],
+                          [-2, 4, 0, 0, 6, 8, 0, 10, 0, 0],
+                          [0, 0, 2.6, 4.1, 0, 0, 2.3, 0, 9.5, 0.1]])
+    X_neg = sparse.csc_matrix(X_neg)
+
+    assert_raises_regex(ValueError, "Invalid value for 'n_quantiles': 0.",
+                        QuantileTransformer(n_quantiles=0).fit, X)
+    assert_raises_regex(ValueError, "Invalid value for 'subsample': 0.",
+                        QuantileTransformer(subsample=0).fit, X)
+    assert_raises_regex(ValueError, "The number of quantiles cannot be"
+                        " greater than the number of samples used. Got"
+                        " 1000 quantiles and 10 samples.",
+                        QuantileTransformer(subsample=10).fit, X)
+
+    transformer = QuantileTransformer(n_quantiles=10)
+    assert_raises_regex(ValueError, "QuantileTransformer only accepts "
+                        "non-negative sparse matrices.",
+                        transformer.fit, X_neg)
+    transformer.fit(X)
+    assert_raises_regex(ValueError, "QuantileTransformer only accepts "
+                        "non-negative sparse matrices.",
+                        transformer.transform, X_neg)
+
+    X_bad_feat = np.transpose([[0, 25, 50, 0, 0, 0, 75, 0, 0, 100],
+                               [0, 0, 2.6, 4.1, 0, 0, 2.3, 0, 9.5, 0.1]])
+    assert_raises_regex(ValueError, "X does not have the same number of "
+                        "features as the previously fitted data. Got 2"
+                        " instead of 3.",
+                        transformer.transform, X_bad_feat)
+    assert_raises_regex(ValueError, "X does not have the same number of "
+                        "features as the previously fitted data. Got 2"
+                        " instead of 3.",
+                        transformer.inverse_transform, X_bad_feat)
+
+    transformer = QuantileTransformer(n_quantiles=10,
+                                      output_distribution='rnd')
+    # check that an error is raised at fit time
+    assert_raises_regex(ValueError, "'output_distribution' has to be either"
+                        " 'normal' or 'uniform'. Got 'rnd' instead.",
+                        transformer.fit, X)
+    # check that an error is raised at transform time
+    transformer.output_distribution = 'uniform'
+    transformer.fit(X)
+    X_tran = transformer.transform(X)
+    transformer.output_distribution = 'rnd'
+    assert_raises_regex(ValueError, "'output_distribution' has to be either"
+                        " 'normal' or 'uniform'. Got 'rnd' instead.",
+                        transformer.transform, X)
+    # check that an error is raised at inverse_transform time
+    assert_raises_regex(ValueError, "'output_distribution' has to be either"
+                        " 'normal' or 'uniform'. Got 'rnd' instead.",
+                        transformer.inverse_transform, X_tran)
+
+
+def test_quantile_transform_sparse_ignore_zeros():
+    X = np.array([[0, 1],
+                  [0, 0],
+                  [0, 2],
+                  [0, 2],
+                  [0, 1]])
+    X_sparse = sparse.csc_matrix(X)
+    transformer = QuantileTransformer(ignore_implicit_zeros=True,
+                                      n_quantiles=5)
+
+    # dense case -> warning raised
+    assert_warns_message(UserWarning, "'ignore_implicit_zeros' takes effect"
+                         " only with sparse matrix. This parameter has no"
+                         " effect.", transformer.fit, X)
+
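+    # with ignore_implicit_zeros, only the explicitly stored entries define
+    # the quantiles: 1 (the minimum) maps to 0, 2 (the maximum) maps to 1,
+    # and implicit zeros are left at 0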
+    X_expected = np.array([[0, 0],
+                           [0, 0],
+                           [0, 1],
+                           [0, 1],
+                           [0, 0]])
+    X_trans = transformer.fit_transform(X_sparse)
+    assert_almost_equal(X_expected, X_trans.A)
+
+    # consider the case where sparse entries are missing values and
+    # user-given zeros are to be transformed
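+    # (explicit zeros in X_data are kept by scipy when the matrix is built
+    # from (data, (row, col)) triplets, so the transformer treats them as
+    # observed values)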
+    X_data = np.array([0, 0, 1, 0, 2, 2, 1, 0, 1, 2, 0])
+    X_col = np.array([0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1])
+    X_row = np.array([0, 4, 0, 1, 2, 3, 4, 5, 6, 7, 8])
+    X_sparse = sparse.csc_matrix((X_data, (X_row, X_col)))
+    X_trans = transformer.fit_transform(X_sparse)
+    X_expected = np.array([[0., 0.5],
+                           [0., 0.],
+                           [0., 1.],
+                           [0., 1.],
+                           [0., 0.5],
+                           [0., 0.],
+                           [0., 0.5],
+                           [0., 1.],
+                           [0., 0.]])
+    assert_almost_equal(X_expected, X_trans.A)
+
+    transformer = QuantileTransformer(ignore_implicit_zeros=True,
+                                      n_quantiles=5)
+    X_data = np.array([-1, -1, 1, 0, 0, 0, 1, -1, 1])
+    X_col = np.array([0, 0, 1, 1, 1, 1, 1, 1, 1])
+    X_row = np.array([0, 4, 0, 1, 2, 3, 4, 5, 6])
+    X_sparse = sparse.csc_matrix((X_data, (X_row, X_col)))
+    X_trans = transformer.fit_transform(X_sparse)
+    X_expected = np.array([[0, 1],
+                           [0, 0.375],
+                           [0, 0.375],
+                           [0, 0.375],
+                           [0, 1],
+                           [0, 0],
+                           [0, 1]])
+    assert_almost_equal(X_expected, X_trans.A)
+    assert_almost_equal(X_sparse.A, transformer.inverse_transform(X_trans).A)
+
+    # check in conjunction with subsampling
+    transformer = QuantileTransformer(ignore_implicit_zeros=True,
+                                      n_quantiles=5,
+                                      subsample=8,
+                                      random_state=0)
+    X_trans = transformer.fit_transform(X_sparse)
+    assert_almost_equal(X_expected, X_trans.A)
+    assert_almost_equal(X_sparse.A, transformer.inverse_transform(X_trans).A)
+
+
+def test_quantile_transform_dense_toy():
+    X = np.array([[0, 2, 2.6],
+                  [25, 4, 4.1],
+                  [50, 6, 2.3],
+                  [75, 8, 9.5],
+                  [100, 10, 0.1]])
+
+    transformer = QuantileTransformer(n_quantiles=5)
+    transformer.fit(X)
+
+    # using a uniform output, each entry of X should be mapped between 0
+    # and 1 and be equally spaced
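+    # with 5 samples and n_quantiles=5, each sorted column should be exactly
+    # [0, 0.25, 0.5, 0.75, 1]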
+    X_trans = transformer.fit_transform(X)
+    X_expected = np.tile(np.linspace(0, 1, num=5), (3, 1)).T
+    assert_almost_equal(np.sort(X_trans, axis=0), X_expected)
+
+    X_test = np.array([
+        [-1, 1, 0],
+        [101, 11, 10],
+    ])
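+    # test samples outside the fitted range should be clipped to the output
+    # bounds 0 and 1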
+    X_expected = np.array([
+        [0, 0, 0],
+        [1, 1, 1],
+    ])
+    assert_array_almost_equal(transformer.transform(X_test), X_expected)
+
+    X_trans_inv = transformer.inverse_transform(X_trans)
+    assert_array_almost_equal(X, X_trans_inv)
+
+
+def test_quantile_transform_subsampling():
+    # Test that subsampling the input yields consistent results. We check
+    # that the computed quantiles are almost mapped to a [0, 1] vector where
+    # values are equally spaced. The infinity norm is checked to be smaller
+    # than a given threshold. This is repeated 5 times.
+
+    # dense support
+    n_samples = 1000000
+    n_quantiles = 1000
+    X = np.sort(np.random.sample((n_samples, 1)), axis=0)
+    ROUND = 5
+    inf_norm_arr = []
+    for random_state in range(ROUND):
+        transformer = QuantileTransformer(random_state=random_state,
+                                          n_quantiles=n_quantiles,
+                                          subsample=n_samples // 10)
+        transformer.fit(X)
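+        # the learned quantiles should approximate the evenly spaced
+        # reference quantiles since X is drawn uniformly on [0, 1]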
+        diff = (np.linspace(0, 1, n_quantiles) -
+                np.ravel(transformer.quantiles_))
+        inf_norm = np.max(np.abs(diff))
+        assert_true(inf_norm < 1e-2)
+        inf_norm_arr.append(inf_norm)
+    # each random subsampling yields a unique approximation of the expected
+    # linspace CDF
+    assert_equal(len(np.unique(inf_norm_arr)), len(inf_norm_arr))
+
+    # sparse support
+
+    # TODO: rng should be seeded once we drop support for older versions of
+    # scipy (< 0.13) that don't support seeding.
+    X = sparse.rand(n_samples, 1, density=.99, format='csc')
+    inf_norm_arr = []
+    for random_state in range(ROUND):
+        transformer = QuantileTransformer(random_state=random_state,
+                                          n_quantiles=n_quantiles,
+                                          subsample=n_samples // 10)
+        transformer.fit(X)
+        diff = (np.linspace(0, 1, n_quantiles) -
+                np.ravel(transformer.quantiles_))
+        inf_norm = np.max(np.abs(diff))
+        assert_true(inf_norm < 1e-1)
+        inf_norm_arr.append(inf_norm)
+    # each random subsampling yields a unique approximation of the expected
+    # linspace CDF
+    assert_equal(len(np.unique(inf_norm_arr)), len(inf_norm_arr))
+
+
+def test_quantile_transform_sparse_toy():
+    X = np.array([[0., 2., 0.],
+                  [25., 4., 0.],
+                  [50., 0., 2.6],
+                  [0., 0., 4.1],
+                  [0., 6., 0.],
+                  [0., 8., 0.],
+                  [75., 0., 2.3],
+                  [0., 10., 0.],
+                  [0., 0., 9.5],
+                  [100., 0., 0.1]])
+
+    X = sparse.csc_matrix(X)
+
+    transformer = QuantileTransformer(n_quantiles=10)
+    transformer.fit(X)
+
+    X_trans = transformer.fit_transform(X)
+    assert_array_almost_equal(np.min(X_trans.toarray(), axis=0), 0.)
+    assert_array_almost_equal(np.max(X_trans.toarray(), axis=0), 1.)
+
+    X_trans_inv = transformer.inverse_transform(X_trans)
+    assert_array_almost_equal(X.toarray(), X_trans_inv.toarray())
+
+    transformer_dense = QuantileTransformer(n_quantiles=10).fit(
+        X.toarray())
+
+    X_trans = transformer_dense.transform(X)
+    assert_array_almost_equal(np.min(X_trans.toarray(), axis=0), 0.)
+    assert_array_almost_equal(np.max(X_trans.toarray(), axis=0), 1.)
+
+    X_trans_inv = transformer_dense.inverse_transform(X_trans)
+    assert_array_almost_equal(X.toarray(), X_trans_inv.toarray())
+
+
+def test_quantile_transform_axis1():
+    X = np.array([[0, 25, 50, 75, 100],
+                  [2, 4, 6, 8, 10],
+                  [2.6, 4.1, 2.3, 9.5, 0.1]])
+
+    X_trans_a0 = quantile_transform(X.T, axis=0, n_quantiles=5)
+    X_trans_a1 = quantile_transform(X, axis=1, n_quantiles=5)
+    assert_array_almost_equal(X_trans_a0, X_trans_a1.T)
+
+
+def test_quantile_transform_bounds():
+    # Lower and upper bounds are manually mapped. We check that in the case
+    # of a constant feature and a binary feature, the bounds are properly
+    # mapped.
+    X_dense = np.array([[0, 0],
+                        [0, 0],
+                        [1, 0]])
+    X_sparse = sparse.csc_matrix(X_dense)
+
+    # check sparse and dense are consistent
+    X_trans = QuantileTransformer(n_quantiles=3,
+                                  random_state=0).fit_transform(X_dense)
+    assert_array_almost_equal(X_trans, X_dense)
+    X_trans_sp = QuantileTransformer(n_quantiles=3,
+                                     random_state=0).fit_transform(X_sparse)
+    assert_array_almost_equal(X_trans_sp.A, X_dense)
+    assert_array_almost_equal(X_trans, X_trans_sp.A)
+
+    # check the consistency of the bounds by fitting on one matrix
+    # and transforming another
+    X = np.array([[0, 1],
+                  [0, 0.5],
+                  [1, 0]])
+    X1 = np.array([[0, 0.1],
+                   [0, 0.5],
+                   [1, 0.1]])
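+    # X1 takes its values within the range learned on X, so the fitted
+    # mapping should leave it unchanged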
+    transformer = QuantileTransformer(n_quantiles=3).fit(X)
+    X_trans = transformer.transform(X1)
+    assert_array_almost_equal(X_trans, X1)
+
+    # check that values outside of the learned range are mapped properly
+    X = np.random.random((1000, 1))
+    transformer = QuantileTransformer()
+    transformer.fit(X)
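+    # out-of-range inputs should behave as if clipped to the training range
+    # (and, for the inverse, to the range of the reference quantiles)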
+    assert_equal(transformer.transform(-10), transformer.transform(np.min(X)))
+    assert_equal(transformer.transform(10), transformer.transform(np.max(X)))
+    assert_equal(transformer.inverse_transform(-10),
+                 transformer.inverse_transform(
+                     np.min(transformer.references_)))
+    assert_equal(transformer.inverse_transform(10),
+                 transformer.inverse_transform(
+                     np.max(transformer.references_)))
+
+
+def test_quantile_transform_and_inverse():
+    # iris dataset
+    X = iris.data
+    transformer = QuantileTransformer(n_quantiles=1000, random_state=0)
+    X_trans = transformer.fit_transform(X)
+    X_trans_inv = transformer.inverse_transform(X_trans)
+    assert_array_almost_equal(X, X_trans_inv)
+
+
 def test_robust_scaler_invalid_range():
     for range_ in [
         (-1, 90),
@@ -1641,3 +1966,12 @@ def test_fit_cold_start():
         # with a different shape, this may break the scaler unless the internal
         # state is reset
         scaler.fit_transform(X_2d)
+
+
+def test_quantile_transform_valid_axis():
+    X = np.array([[0, 25, 50, 75, 100],
+                  [2, 4, 6, 8, 10],
+                  [2.6, 4.1, 2.3, 9.5, 0.1]])
+
+    assert_raises_regex(ValueError, "axis should be either equal to 0 or 1"
+                        ". Got axis=2", quantile_transform, X.T, axis=2)