diff --git a/benchmarks/bench_isolation_forest.py b/benchmarks/bench_isolation_forest.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cfd484a3ab44a8fa7944fbb384a6ee7afadaf1c
--- /dev/null
+++ b/benchmarks/bench_isolation_forest.py
@@ -0,0 +1,108 @@
+"""
+==========================================
+IsolationForest benchmark
+==========================================
+
+A test of IsolationForest on classical anomaly detection datasets.
+
+"""
+print(__doc__)
+
+from time import time
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.ensemble import IsolationForest
+from sklearn.metrics import roc_curve, auc
+from sklearn.datasets import fetch_kddcup99, fetch_covtype, fetch_mldata
+from sklearn.preprocessing import LabelBinarizer
+from sklearn.utils import shuffle as sh
+
+np.random.seed(1)
+
+
+datasets = ['http']  # 'smtp', 'SA', 'SF', 'shuttle', 'forestcover'
+
+for dat in datasets:
+    # loading and vectorization
+    print('loading data')
+    if dat in ['http', 'smtp', 'SA', 'SF']:
+        dataset = fetch_kddcup99(subset=dat, shuffle=True, percent10=True)
+        X = dataset.data
+        y = dataset.target
+
+    if dat == 'shuttle':
+        dataset = fetch_mldata('shuttle')
+        X = dataset.data
+        y = dataset.target
+        X, y = sh(X, y)
+        # we remove data with label 4
+        # normal data are then those of class 1
+        s = (y != 4)
+        X = X[s, :]
+        y = y[s]
+        y = (y != 1).astype(int)
+
+    if dat == 'forestcover':
+        dataset = fetch_covtype(shuffle=True)
+        X = dataset.data
+        y = dataset.target
+        # normal data are those with attribute 2
+        # abnormal those with attribute 4
+        s = (y == 2) + (y == 4)
+        X = X[s, :]
+        y = y[s]
+        y = (y != 2).astype(int)
+
+    print('vectorizing data')
+
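+    # one-hot encode the categorical columns (protocol_type for SF;
+    # protocol_type, service and flag for SA) with LabelBinarizer so that
+    # X becomes fully numeric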
+    if dat == 'SF':
+        lb = LabelBinarizer()
+        lb.fit(X[:, 1])
+        x1 = lb.transform(X[:, 1])
+        X = np.c_[X[:, :1], x1, X[:, 2:]]
+        y = (y != 'normal.').astype(int)
+
+    if dat == 'SA':
+        lb = LabelBinarizer()
+        lb.fit(X[:, 1])
+        x1 = lb.transform(X[:, 1])
+        lb.fit(X[:, 2])
+        x2 = lb.transform(X[:, 2])
+        lb.fit(X[:, 3])
+        x3 = lb.transform(X[:, 3])
+        X = np.c_[X[:, :1], x1, x2, x3, X[:, 4:]]
+        y = (y != 'normal.').astype(int)
+
+    if dat == 'http' or dat == 'smtp':
+        y = (y != 'normal.').astype(int)
+
+    n_samples, n_features = np.shape(X)
+    n_samples_train = n_samples // 2
+    n_samples_test = n_samples - n_samples_train
+
+    X = X.astype(float)
+    X_train = X[:n_samples_train, :]
+    X_test = X[n_samples_train:, :]
+    y_train = y[:n_samples_train]
+    y_test = y[n_samples_train:]
+
+    print('IsolationForest processing...')
+    model = IsolationForest(bootstrap=True, n_jobs=-1)
+    tstart = time()
+    model.fit(X_train)
+    fit_time = time() - tstart
+    tstart = time()
+
+    scoring = model.predict(X_test)  # the lower, the more normal
+    predict_time = time() - tstart
+    fpr, tpr, thresholds = roc_curve(y_test, scoring)
+    AUC = auc(fpr, tpr)
+    plt.plot(fpr, tpr, lw=1,
+             label=('ROC for %s (area = %0.3f, train-time: %0.2fs, '
+                    'test-time: %0.2fs)' % (dat, AUC, fit_time, predict_time)))
+
+plt.xlim([-0.05, 1.05])
+plt.ylim([-0.05, 1.05])
+plt.xlabel('False Positive Rate')
+plt.ylabel('True Positive Rate')
+plt.title('Receiver operating characteristic')
+plt.legend(loc="lower right")
+plt.show()
diff --git a/doc/datasets/kddcup99.rst b/doc/datasets/kddcup99.rst
new file mode 100644
index 0000000000000000000000000000000000000000..fadc41c85c3be5e99cc67653212401d211e9a15a
--- /dev/null
+++ b/doc/datasets/kddcup99.rst
@@ -0,0 +1,36 @@
+
+.. _kddcup99:
+
+KDD Cup 99 dataset
+==================
+
+The KDD Cup '99 dataset was created by processing the tcpdump portions
+of the 1998 DARPA Intrusion Detection System (IDS) Evaluation dataset,
+created by MIT Lincoln Lab. The artificial data (described on the `dataset's
+homepage <http://kdd.ics.uci.edu/databases/kddcup99/kddcup99.html>`_) was
+generated using a closed network and hand-injected attacks to produce a
+large number of different types of attack with normal activity in the
+background. As the initial goal was to produce a large training set for
+supervised learning algorithms, there is a large proportion (80.1%) of
+abnormal data, which is unrealistic in the real world and inappropriate for
+unsupervised anomaly detection, which aims at detecting 'abnormal' data, i.e.:
+
+1) qualitatively different from normal data,
+2) in large minority among the observations.
+
+We thus transform the KDD data set into two different data sets: SA and SF.
+
+- SA is obtained by simply selecting all the normal data, plus a small
+  proportion of abnormal data so as to give an anomaly proportion of 1%.
+
+- SF is obtained as in [2] by simply picking up the data whose attribute
+  logged_in is positive, thus focusing on the intrusion attack, which gives
+  a proportion of 0.3% of attacks.
+
+- http and smtp are two subsets of SF corresponding to the third feature
+  being equal to 'http' (resp. to 'smtp').
+
+:func:`sklearn.datasets.fetch_kddcup99` will load the kddcup99 dataset;
+it returns a dictionary-like object
+with the feature matrix in the ``data`` member
+and the target values in ``target``.
+The dataset will be downloaded from the web if necessary.
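+
+A minimal usage sketch (the first call downloads and caches the data, which
+may take some time; ``subset`` and ``percent10`` follow the signature of the
+loader)::
+
+    from sklearn.datasets import fetch_kddcup99
+
+    data = fetch_kddcup99(subset='http', shuffle=True, percent10=True)
+    X, y = data.data, data.target  # y holds 'normal.' or the attack name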
diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 17842e70a3d6867505fbbd8b7ab8b8358c69e739..e3172df56597502de551bcb8d688428b30b9a04c 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -221,6 +221,7 @@ Loaders
    datasets.fetch_olivetti_faces
    datasets.fetch_california_housing
    datasets.fetch_covtype
+   datasets.fetch_kddcup99
    datasets.fetch_rcv1
    datasets.load_mlcomp
    datasets.load_sample_image
@@ -351,6 +352,7 @@ Samples generator
    ensemble.ExtraTreesRegressor
    ensemble.GradientBoostingClassifier
    ensemble.GradientBoostingRegressor
+   ensemble.IsolationForest
    ensemble.RandomForestClassifier
    ensemble.RandomTreesEmbedding
    ensemble.RandomForestRegressor
diff --git a/doc/modules/outlier_detection.rst b/doc/modules/outlier_detection.rst
index a99758989e195cc098db10afc46d96644564a52d..d2a26f779829d8b9642990bd07aeeeaa25ab2df8 100644
--- a/doc/modules/outlier_detection.rst
+++ b/doc/modules/outlier_detection.rst
@@ -192,4 +192,45 @@ multiple modes.
      an outlier detection method) and a covariance-based outlier
      detection with :class:`covariance.MinCovDet`.
 
+Isolation Forest
+----------------------------
+
+One efficient way of performing outlier detection in high-dimensional datasets
+is to use random forests.
+:class:`ensemble.IsolationForest` 'isolates' observations by randomly
+selecting a feature and then randomly selecting a split value between the
+maximum and minimum values of the selected feature.
+
+Since recursive partitioning can be represented by a tree structure, the
+number of splits required to isolate a sample is equivalent to the path
+length from the root node to the terminating node.
+
+This path length, averaged over a forest of such random trees, is a
+measure of abnormality and our decision function.
+
+Random partitioning produces noticeably shorter paths for anomalies.
+Hence, when a forest of random trees collectively produces shorter path
+lengths for particular samples, they are highly likely to be anomalies.
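+
+Concretely, the score computed for a sample :math:`x`, using trees built on
+:math:`n` sub-sampled points, follows [LTZ2008]_:
+
+.. math:: s(x, n) = 2^{-E(h(x)) / c(n)}
+
+where :math:`E(h(x))` is the average path length of :math:`x` over the trees
+and :math:`c(n)` is the average path length of an unsuccessful search in a
+binary search tree with :math:`n` nodes, used as a normalization constant.
+Scores close to 1 thus indicate anomalies, while scores well below 1
+correspond to regular observations.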
+
+This strategy is illustrated below.
+
+.. figure:: ../auto_examples/ensemble/images/plot_isolation_forest_001.png
+   :target: ../auto_examples/ensemble/plot_isolation_forest.html
+   :align: center
+   :scale: 75%
+
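+A minimal usage sketch (the parameter values are illustrative, not tuned)::
+
+    import numpy as np
+    from sklearn.ensemble import IsolationForest
+
+    rng = np.random.RandomState(42)
+    X_train = 0.3 * rng.randn(100, 2)
+    clf = IsolationForest(max_samples=64, random_state=rng).fit(X_train)
+    scores = clf.predict(X_train)                # the lower, the more normal
+    normality = clf.decision_function(X_train)   # the higher, the more normal
+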
+.. topic:: Examples:
 
+   * See :ref:`example_ensemble_plot_isolation_forest.py` for
+     an illustration of the use of IsolationForest.
+
+   * See :ref:`example_covariance_plot_outlier_detection.py` for a
+     comparison of :class:`ensemble.IsolationForest` with
+     :class:`svm.OneClassSVM` (tuned to perform like an outlier detection
+     method) and a covariance-based outlier detection with
+     :class:`covariance.MinCovDet`.
+
+.. topic:: References:
+
+    .. [LTZ2008] Liu, Fei Tony, Ting, Kai Ming and Zhou, Zhi-Hua. "Isolation
+           forest." Data Mining, 2008. ICDM'08. Eighth IEEE International
+           Conference on.
diff --git a/examples/covariance/plot_outlier_detection.py b/examples/covariance/plot_outlier_detection.py
index fefa666fe00f405ec00285714ae18ed32ece54c1..0ac61a5fd15d761e2ea35912f100af096b5961a5 100644
--- a/examples/covariance/plot_outlier_detection.py
+++ b/examples/covariance/plot_outlier_detection.py
@@ -3,7 +3,7 @@
 Outlier detection with several methods.
 ==========================================
 
-When the amount of contamination is known, this example illustrates two
+When the amount of contamination is known, this example illustrates three
 different ways of performing :ref:`outlier_detection`:
 
 - based on a robust estimator of covariance, which is assuming that the
@@ -14,6 +14,10 @@ different ways of performing :ref:`outlier_detection`:
   data set, hence performing better when the data is strongly
   non-Gaussian, i.e. with two well-separated clusters;
 
+- using the Isolation Forest algorithm, which is based on random forests and
+  hence better suited to large-dimensional settings, even though it also
+  performs quite well in the low-dimensional examples below.
+
 The ground truth about inliers and outliers is given by the points colors
 while the orange-filled area indicates which points are reported as inliers
 by each method.
@@ -32,6 +36,9 @@ from scipy import stats
 
 from sklearn import svm
 from sklearn.covariance import EllipticEnvelope
+from sklearn.ensemble import IsolationForest
+
+rng = np.random.RandomState(42)
 
 # Example settings
 n_samples = 200
@@ -42,7 +49,8 @@ clusters_separation = [0, 1, 2]
 classifiers = {
     "One-Class SVM": svm.OneClassSVM(nu=0.95 * outliers_fraction + 0.05,
                                      kernel="rbf", gamma=0.1),
-    "robust covariance estimator": EllipticEnvelope(contamination=.1)}
+    "robust covariance estimator": EllipticEnvelope(contamination=.1),
+    "Isolation Forest": IsolationForest(max_samples=n_samples, random_state=rng)}
 
 # Compare given classifiers under given settings
 xx, yy = np.meshgrid(np.linspace(-7, 7, 500), np.linspace(-7, 7, 500))
@@ -61,7 +69,7 @@ for i, offset in enumerate(clusters_separation):
     # Add outliers
     X = np.r_[X, np.random.uniform(low=-6, high=6, size=(n_outliers, 2))]
 
-    # Fit the model with the One-Class SVM
+    # Fit the model
     plt.figure(figsize=(10, 5))
     for i, (clf_name, clf) in enumerate(classifiers.items()):
         # fit the data and tag outliers
@@ -74,7 +82,7 @@ for i, offset in enumerate(clusters_separation):
         # plot the levels lines and the points
         Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
         Z = Z.reshape(xx.shape)
-        subplot = plt.subplot(1, 2, i + 1)
+        subplot = plt.subplot(1, 3, i + 1)
         subplot.set_title("Outlier detection")
         subplot.contourf(xx, yy, Z, levels=np.linspace(Z.min(), threshold, 7),
                          cmap=plt.cm.Blues_r)
diff --git a/examples/ensemble/plot_isolation_forest.py b/examples/ensemble/plot_isolation_forest.py
new file mode 100644
index 0000000000000000000000000000000000000000..5af38fe40b7d0aea5e86056f34c0e0267074ff99
--- /dev/null
+++ b/examples/ensemble/plot_isolation_forest.py
@@ -0,0 +1,69 @@
+"""
+==========================================
+IsolationForest example
+==========================================
+
+An example using IsolationForest for anomaly detection.
+
+IsolationForest 'isolates' the observations by randomly selecting a feature
+and then randomly selecting a split value between the maximum and minimum
+values of the selected feature.
+
+Since recursive partitioning can be represented by a tree structure, the
+number of splits required to isolate a sample is equivalent to the path
+length from the root node to the terminating node.
+
+This path length, averaged over a forest of such random trees, is a measure
+of abnormality and our decision function.
+
+Random partitioning produces noticeably shorter paths for anomalies.
+Hence, when a forest of random trees collectively produces shorter path
+lengths for particular samples, they are highly likely to be anomalies.
+
+.. [1] Liu, Fei Tony, Ting, Kai Ming and Zhou, Zhi-Hua. "Isolation forest."
+    Data Mining, 2008. ICDM'08. Eighth IEEE International Conference on.
+
+"""
+print(__doc__)
+
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.ensemble import IsolationForest
+
+rng = np.random.RandomState(42)
+
+# Generate train data
+X = 0.3 * rng.randn(100, 2)
+X_train = np.r_[X + 2, X - 2]
+# Generate some regular novel observations
+X = 0.3 * rng.randn(20, 2)
+X_test = np.r_[X + 2, X - 2]
+# Generate some abnormal novel observations
+X_outliers = rng.uniform(low=-4, high=4, size=(20, 2))
+
+# fit the model
+clf = IsolationForest(max_samples=100, random_state=rng)
+clf.fit(X_train)
+y_pred_train = clf.predict(X_train)
+y_pred_test = clf.predict(X_test)
+y_pred_outliers = clf.predict(X_outliers)
+
+# plot the decision function and the samples
+xx, yy = np.meshgrid(np.linspace(-5, 5, 50), np.linspace(-5, 5, 50))
+Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
+Z = Z.reshape(xx.shape)
+
+plt.title("IsolationForest")
+plt.contourf(xx, yy, Z, cmap=plt.cm.Blues_r)
+
+b1 = plt.scatter(X_train[:, 0], X_train[:, 1], c='white')
+b2 = plt.scatter(X_test[:, 0], X_test[:, 1], c='green')
+c = plt.scatter(X_outliers[:, 0], X_outliers[:, 1], c='red')
+plt.axis('tight')
+plt.xlim((-5, 5))
+plt.ylim((-5, 5))
+plt.legend([b1, b2, c],
+           ["training observations",
+            "new regular observations", "new abnormal observations"],
+           loc="upper left")
+plt.show()
diff --git a/sklearn/datasets/__init__.py b/sklearn/datasets/__init__.py
index 4997d97e0fd946197b2116caf1f95c244649c691..0a8cfc62df537347bb3a3cb03adc7885111cfecc 100644
--- a/sklearn/datasets/__init__.py
+++ b/sklearn/datasets/__init__.py
@@ -16,6 +16,7 @@ from .base import clear_data_home
 from .base import load_sample_images
 from .base import load_sample_image
 from .covtype import fetch_covtype
+from .kddcup99 import fetch_kddcup99
 from .mlcomp import load_mlcomp
 from .lfw import load_lfw_pairs
 from .lfw import load_lfw_people
@@ -65,6 +66,7 @@ __all__ = ['clear_data_home',
            'fetch_california_housing',
            'fetch_covtype',
            'fetch_rcv1',
+           'fetch_kddcup99',
            'get_data_home',
            'load_boston',
            'load_diabetes',
diff --git a/sklearn/datasets/kddcup99.py b/sklearn/datasets/kddcup99.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e7696f68c2812406b085e7d752d666d51ed6cbd
--- /dev/null
+++ b/sklearn/datasets/kddcup99.py
@@ -0,0 +1,355 @@
+"""KDDCUP 99 dataset.
+
+A classic dataset for anomaly detection.
+
+The dataset page is available from the UCI Machine Learning Repository
+
+https://archive.ics.uci.edu/ml/machine-learning-databases/kddcup99-mld/kddcup.data.gz
+
+"""
+
+import sys
+import errno
+from gzip import GzipFile
+from io import BytesIO
+import logging
+import os
+from os.path import exists, join
+try:
+    from urllib2 import urlopen
+except ImportError:
+    from urllib.request import urlopen
+
+import numpy as np
+
+from .base import get_data_home
+from .base import Bunch
+from ..externals import joblib
+from ..utils import check_random_state
+from ..utils import shuffle as shuffle_method
+
+
+URL10 = ('http://archive.ics.uci.edu/ml/'
+         'machine-learning-databases/kddcup99-mld/kddcup.data_10_percent.gz')
+
+URL = ('http://archive.ics.uci.edu/ml/'
+       'machine-learning-databases/kddcup99-mld/kddcup.data.gz')
+
+
+logger = logging.getLogger()
+
+
+def fetch_kddcup99(subset=None, shuffle=False, random_state=None,
+                   percent10=False):
+    """Load and return the kddcup 99 dataset (regression).
+
+    The KDD Cup '99 dataset was created by processing the tcpdump portions
+    of the 1998 DARPA Intrusion Detection System (IDS) Evaluation dataset,
+    created by MIT Lincoln Lab [1]. The artificial data was generated using
+    a closed network and hand-injected attacks to produce a large number of
+    different types of attack with normal activity in the background.
+    As the initial goal was to produce a large training set for supervised
+    learning algorithms, there is a large proportion (80.1%) of abnormal
+    data, which is unrealistic in the real world and inappropriate for
+    unsupervised anomaly detection, which aims at detecting 'abnormal' data,
+    i.e.
+
+    1) qualitatively different from normal data.
+
+    2) in large minority among the observations.
+
+    We thus transform the KDD data set into two different data sets: SA and SF.
+
+    - SA is obtained by simply selecting all the normal data, plus a small
+      proportion of abnormal data so as to give an anomaly proportion of 1%.
+
+    - SF is obtained as in [2] by simply picking up the data whose attribute
+      logged_in is positive, thus focusing on the intrusion attack, which
+      gives a proportion of 0.3% of attacks.
+
+    - http and smtp are two subsets of SF corresponding to the third feature
+      being equal to 'http' (resp. to 'smtp').
+
+
+    General KDD structure :
+
+    ================      ==========================================
+    Samples total         4898431
+    Dimensionality        41
+    Features              discrete (int) or continuous (float)
+    Targets               str, 'normal.' or name of the anomaly type
+    ================      ==========================================
+
+    SA structure :
+
+    ================      ==========================================
+    Samples total         976158
+    Dimensionality        41
+    Features              discrete (int) or continuous (float)
+    Targets               str, 'normal.' or name of the anomaly type
+    ================      ==========================================
+
+    SF structure :
+
+    ================      ==========================================
+    Samples total         699691
+    Dimensionality        40
+    Features              discrete (int) or continuous (float)
+    Targets               str, 'normal.' or name of the anomaly type
+    ================      ==========================================
+
+    http structure :
+
+    ================      ==========================================
+    Samples total         619052
+    Dimensionality        39
+    Features              discrete (int) or continuous (float)
+    Targets               str, 'normal.' or name of the anomaly type
+    ================      ==========================================
+
+    smtp structure :
+
+    ================      ==========================================
+    Samples total         95373
+    Dimensionality        39
+    Features              discrete (int) or continuous (float)
+    Targets               str, 'normal.' or name of the anomaly type
+    ================      ==========================================
+
+    Parameters
+    ----------
+    subset : None, 'SA', 'SF', 'http', 'smtp'
+        To return the corresponding classical subsets of kddcup 99.
+        If None, return the entire kddcup 99 dataset.
+
+    random_state : int, RandomState instance or None, optional (default=None)
+        Random state for shuffling the dataset.
+        If int, random_state is the seed used by the random number generator;
+        If RandomState instance, random_state is the random number generator;
+        If None, the random number generator is the RandomState instance used
+        by `np.random`.
+
+    shuffle : bool, default=False
+        Whether to shuffle dataset.
+
+    percent10 : bool, default=False
+        Whether to load only 10 percent of the data.
+
+    Returns
+    -------
+    data : Bunch
+        Dictionary-like object, the interesting attributes are:
+        'data', the data to learn and 'target', the labels ('normal.' or the
+        name of the attack type) for each sample.
+
+
+    References
+    ----------
+    .. [1] Analysis and Results of the 1999 DARPA Off-Line Intrusion
+           Detection Evaluation Richard Lippmann, Joshua W. Haines,
+           David J. Fried, Jonathan Korba, Kumar Das
+
+    .. [2] A Geometric Framework for Unsupervised Anomaly Detection: Detecting
+           Intrusions in Unlabeled Data (2002) by Eleazar Eskin, Andrew Arnold,
+           Michael Prerau, Leonid Portnoy, Sal Stolfo
+    """
+    kddcup99 = _fetch_brute_kddcup99(shuffle=shuffle, percent10=percent10)
+
+    data = kddcup99.data
+    target = kddcup99.target
+
+    if subset == 'SA':
+        s = target == 'normal.'
+        t = np.logical_not(s)
+        normal_samples = data[s, :]
+        normal_targets = target[s]
+        abnormal_samples = data[t, :]
+        abnormal_targets = target[t]
+
+        n_samples_abnormal = abnormal_samples.shape[0]
+        # selected abnormal samples:
+        random_state = check_random_state(random_state)
+        r = random_state.randint(0, n_samples_abnormal, 3377)
+        abnormal_samples = abnormal_samples[r]
+        abnormal_targets = abnormal_targets[r]
+
+        data = np.r_[normal_samples, abnormal_samples]
+        target = np.r_[normal_targets, abnormal_targets]
+
+    if subset == 'SF' or subset == 'http' or subset == 'smtp':
+        # select all samples with positive logged_in attribute:
+        s = data[:, 11] == 1
+        data = np.c_[data[s, :11], data[s, 12:]]
+        target = target[s]
+
+        data[:, 0] = np.log((data[:, 0] + 0.1).astype(float))
+        data[:, 4] = np.log((data[:, 4] + 0.1).astype(float))
+        data[:, 5] = np.log((data[:, 5] + 0.1).astype(float))
+
+        if subset == 'http':
+            s = data[:, 2] == 'http'
+            data = data[s]
+            target = target[s]
+            data = np.c_[data[:, 0], data[:, 4], data[:, 5]]
+
+        if subset == 'smtp':
+            s = data[:, 2] == 'smtp'
+            data = data[s]
+            target = target[s]
+            data = np.c_[data[:, 0], data[:, 4], data[:, 5]]
+
+        if subset == 'SF':
+            data = np.c_[data[:, 0], data[:, 2], data[:, 4], data[:, 5]]
+
+    return Bunch(data=data, target=target)
+
+
+def _fetch_brute_kddcup99(subset=None, data_home=None,
+                          download_if_missing=True, random_state=None,
+                          shuffle=False, percent10=False):
+
+    """Load the kddcup99 dataset, downloading it if necessary.
+
+    Parameters
+    ----------
+    subset : None, 'SA', 'SF', 'http', 'smtp'
+        To return the corresponding classical subsets of kddcup 99.
+        If None, return the entire kddcup 99 dataset.
+
+    data_home : string, optional
+        Specify another download and cache folder for the datasets. By default
+        all scikit learn data is stored in '~/scikit_learn_data' subfolders.
+
+    download_if_missing : boolean, default=True
+        If False, raise a IOError if the data is not locally available
+        instead of trying to download the data from the source site.
+
+    random_state : int, RandomState instance or None, optional (default=None)
+        Random state for shuffling the dataset.
+        If int, random_state is the seed used by the random number generator;
+        If RandomState instance, random_state is the random number generator;
+        If None, the random number generator is the RandomState instance used
+        by `np.random`.
+
+    shuffle : bool, default=False
+        Whether to shuffle dataset.
+
+    percent10 : bool, default=False
+        Whether to load only 10 percent of the data.
+
+    Returns
+    -------
+    dataset : dict-like object with the following attributes:
+        dataset.data : numpy array of shape (494021, 41) (with percent10=True)
+            Each row corresponds to the 41 features in the dataset.
+        dataset.target : numpy array of shape (494021,) (with percent10=True)
+            Each value corresponds to one of the 21 attack types or to the
+            label 'normal.'.
+        dataset.DESCR : string
+            Description of the kddcup99 dataset.
+
+    """
+
+    data_home = get_data_home(data_home=data_home)
+    if sys.version_info[0] == 3:
+        # The zlib compression format use by joblib is not compatible when
+        # switching from Python 2 to Python 3, let us use a separate folder
+        # under Python 3:
+        dir_suffix = "-py3"
+    else:
+        # Backward compat for Python 2 users
+        dir_suffix = ""
+    if percent10:
+        kddcup_dir = join(data_home, "kddcup99_10" + dir_suffix)
+    else:
+        kddcup_dir = join(data_home, "kddcup99" + dir_suffix)
+    samples_path = join(kddcup_dir, "samples")
+    targets_path = join(kddcup_dir, "targets")
+    available = exists(samples_path)
+
+    if download_if_missing and not available:
+        _mkdirp(kddcup_dir)
+        URL_ = URL10 if percent10 else URL
+        logger.warning("Downloading %s" % URL_)
+        f = BytesIO(urlopen(URL_).read())
+
+        dt = [('duration', int),
+              ('protocol_type', 'S4'),
+              ('service', 'S11'),
+              ('flag', 'S6'),
+              ('src_bytes', int),
+              ('dst_bytes', int),
+              ('land', int),
+              ('wrong_fragment', int),
+              ('urgent', int),
+              ('hot', int),
+              ('num_failed_logins', int),
+              ('logged_in', int),
+              ('num_compromised', int),
+              ('root_shell', int),
+              ('su_attempted', int),
+              ('num_root', int),
+              ('num_file_creations', int),
+              ('num_shells', int),
+              ('num_access_files', int),
+              ('num_outbound_cmds', int),
+              ('is_host_login', int),
+              ('is_guest_login', int),
+              ('count', int),
+              ('srv_count', int),
+              ('serror_rate', float),
+              ('srv_serror_rate', float),
+              ('rerror_rate', float),
+              ('srv_rerror_rate', float),
+              ('same_srv_rate', float),
+              ('diff_srv_rate', float),
+              ('srv_diff_host_rate', float),
+              ('dst_host_count', int),
+              ('dst_host_srv_count', int),
+              ('dst_host_same_srv_rate', float),
+              ('dst_host_diff_srv_rate', float),
+              ('dst_host_same_src_port_rate', float),
+              ('dst_host_srv_diff_host_rate', float),
+              ('dst_host_serror_rate', float),
+              ('dst_host_srv_serror_rate', float),
+              ('dst_host_rerror_rate', float),
+              ('dst_host_srv_rerror_rate', float),
+              ('labels', 'S16')]
+        DT = np.dtype(dt)
+
+        file_ = GzipFile(fileobj=f, mode='r')
+        Xy = []
+        for line in file_.readlines():
+            line = line.decode()  # GzipFile yields bytes; decode for Python 3
+            Xy.append(line.replace('\n', '').split(','))
+        file_.close()
+        print('extraction done')
+        Xy = np.asarray(Xy, dtype=object)
+        for j in range(42):
+            Xy[:, j] = Xy[:, j].astype(DT[j])
+
+        X = Xy[:, :-1]
+        y = Xy[:, -1]
+        # XXX bug when compress!=0:
+        # (error: 'Incorrect data length while decompressing[...] the file
+        #  could be corrupted.')
+
+        joblib.dump(X, samples_path, compress=0)
+        joblib.dump(y, targets_path, compress=0)
+
+    try:
+        X, y
+    except NameError:
+        X = joblib.load(samples_path)
+        y = joblib.load(targets_path)
+
+    if shuffle:
+        X, y = shuffle_method(X, y, random_state=random_state)
+
+    return Bunch(data=X, target=y, DESCR=__doc__)
+
+
+def _mkdirp(d):
+    """Ensure directory d exists (like mkdir -p on Unix)
+    No guarantee that the directory is writable.
+    """
+    try:
+        os.makedirs(d)
+    except OSError as e:
+        if e.errno != errno.EEXIST:
+            raise
diff --git a/sklearn/ensemble/__init__.py b/sklearn/ensemble/__init__.py
index d2e0a1496f92d68e115bd6710c7c79adff649e44..5586a9e1e1fbaa390895225a86407c3606487a48 100644
--- a/sklearn/ensemble/__init__.py
+++ b/sklearn/ensemble/__init__.py
@@ -1,6 +1,6 @@
 """
 The :mod:`sklearn.ensemble` module includes ensemble-based methods for
-classification and regression.
+classification, regression and anomaly detection.
 """
 
 from .base import BaseEnsemble
@@ -11,6 +11,7 @@ from .forest import ExtraTreesClassifier
 from .forest import ExtraTreesRegressor
 from .bagging import BaggingClassifier
 from .bagging import BaggingRegressor
+from .iforest import IsolationForest
 from .weight_boosting import AdaBoostClassifier
 from .weight_boosting import AdaBoostRegressor
 from .gradient_boosting import GradientBoostingClassifier
@@ -27,7 +28,7 @@ __all__ = ["BaseEnsemble",
            "RandomForestClassifier", "RandomForestRegressor",
            "RandomTreesEmbedding", "ExtraTreesClassifier",
            "ExtraTreesRegressor", "BaggingClassifier",
-           "BaggingRegressor", "GradientBoostingClassifier",
+           "BaggingRegressor", "IsolationForest", "GradientBoostingClassifier",
            "GradientBoostingRegressor", "AdaBoostClassifier",
            "AdaBoostRegressor", "VotingClassifier",
            "bagging", "forest", "gradient_boosting",
diff --git a/sklearn/ensemble/bagging.py b/sklearn/ensemble/bagging.py
index c69d31ef25c28e988ec331168c4e480926799d9a..f9849d33389fb844f809ad01af46be72e77ad02a 100644
--- a/sklearn/ensemble/bagging.py
+++ b/sklearn/ensemble/bagging.py
@@ -34,11 +34,10 @@ MAX_INT = np.iinfo(np.int32).max
 
 
 def _parallel_build_estimators(n_estimators, ensemble, X, y, sample_weight,
-                               seeds, verbose):
+                               max_samples, seeds, verbose):
     """Private function used to build a batch of estimators within a job."""
     # Retrieve settings
     n_samples, n_features = X.shape
-    max_samples = ensemble.max_samples
     max_features = ensemble.max_features
 
     if (not isinstance(max_samples, (numbers.Integral, np.integer)) and
@@ -244,6 +243,35 @@ class BaseBagging(with_metaclass(ABCMeta, BaseEnsemble)):
             Note that this is supported only if the base estimator supports
             sample weighting.
 
+        Returns
+        -------
+        self : object
+            Returns self.
+        """
+        return self._fit(X, y, self.max_samples, sample_weight)
+
+    def _fit(self, X, y, max_samples, sample_weight=None):
+        """Build a Bagging ensemble of estimators from the training
+           set (X, y).
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix} of shape = [n_samples, n_features]
+            The training input samples. Sparse matrices are accepted only if
+            they are supported by the base estimator.
+
+        y : array-like, shape = [n_samples]
+            The target values (class labels in classification, real numbers in
+            regression).
+
+        max_samples : int or float
+            Argument to use instead of self.max_samples.
+
+        sample_weight : array-like, shape = [n_samples] or None
+            Sample weights. If None, then samples are equally weighted.
+            Note that this is supported only if the base estimator supports
+            sample weighting.
+
         Returns
         -------
         self : object
@@ -261,9 +289,8 @@ class BaseBagging(with_metaclass(ABCMeta, BaseEnsemble)):
         # Check parameters
         self._validate_estimator()
 
-        if isinstance(self.max_samples, (numbers.Integral, np.integer)):
-            max_samples = self.max_samples
-        else:  # float
+        # if max_samples is float:
+        if not isinstance(max_samples, (numbers.Integral, np.integer)):
             max_samples = int(self.max_samples * X.shape[0])
 
         if not (0 < max_samples <= X.shape[0]):
@@ -324,6 +351,7 @@ class BaseBagging(with_metaclass(ABCMeta, BaseEnsemble)):
                 X,
                 y,
                 sample_weight,
+                max_samples,
                 seeds[starts[i]:starts[i + 1]],
                 verbose=self.verbose)
             for i in range(n_jobs))
diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py
index d4eea7b371069c035c1f267924f9cc9b7617e23e..db4d259892f48c79678e05febe655a11b191ffa2 100644
--- a/sklearn/ensemble/forest.py
+++ b/sklearn/ensemble/forest.py
@@ -45,7 +45,6 @@ import warnings
 from warnings import warn
 
 from abc import ABCMeta, abstractmethod
-
 import numpy as np
 from scipy.sparse import issparse
 from scipy.sparse import hstack as sparse_hstack
diff --git a/sklearn/ensemble/iforest.py b/sklearn/ensemble/iforest.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b0e2dda52a6770dc3d5535f743d6d5fe8ecadfb
--- /dev/null
+++ b/sklearn/ensemble/iforest.py
@@ -0,0 +1,274 @@
+# Authors: Nicolas Goix <nicolas.goix@telecom-paristech.fr>
+#          Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
+# License: BSD 3 clause
+
+from __future__ import division
+
+import numbers
+import numpy as np
+from warnings import warn
+
+from scipy.sparse import issparse
+
+from ..externals.joblib import Parallel, delayed
+from ..tree import ExtraTreeRegressor
+from ..utils import check_random_state, check_array
+
+from .bagging import BaseBagging
+from .forest import _parallel_helper
+from .base import _partition_estimators
+
+__all__ = ["IsolationForest"]
+
+
+class IsolationForest(BaseBagging):
+    """Isolation Forest Algorithm
+
+    Return the anomaly score of each sample with the IsolationForest
+    algorithm.
+
+    IsolationForest 'isolates' the observations by randomly selecting a
+    feature and then randomly selecting a split value between the maximum
+    and minimum values of the selected feature.
+
+    Since recursive partitioning can be represented by a tree structure, the
+    number of splits required to isolate a point is equivalent to the path
+    length from the root node to the terminating node.
+
+    This path length, averaged over a forest of such random trees, is a
+    measure of abnormality and our decision function.
+
+    Random partitioning produces noticeably shorter paths for anomalies.
+    Hence, when a forest of random trees collectively produces shorter path
+    lengths for particular points, they are highly likely to be anomalies.
+
+
+    Parameters
+    ----------
+    n_estimators : int, optional (default=100)
+        The number of base estimators in the ensemble.
+
+    max_samples : int or float, optional (default=256)
+        The number of samples to draw from X to train each base estimator.
+            - If int, then draw `max_samples` samples.
+            - If float, then draw `max_samples * X.shape[0]` samples.
+        If max_samples is larger than the number of samples provided,
+        all samples will be used for all trees (no sampling).
+
+    max_features : int or float, optional (default=1.0)
+        The number of features to draw from X to train each base estimator.
+            - If int, then draw `max_features` features.
+            - If float, then draw `max_features * X.shape[1]` features.
+
+    bootstrap : boolean, optional (default=False)
+        Whether samples are drawn with replacement.
+
+    n_jobs : integer, optional (default=1)
+        The number of jobs to run in parallel for both `fit` and `predict`.
+        If -1, then the number of jobs is set to the number of cores.
+
+    random_state : int, RandomState instance or None, optional (default=None)
+        If int, random_state is the seed used by the random number generator;
+        If RandomState instance, random_state is the random number generator;
+        If None, the random number generator is the RandomState instance used
+        by `np.random`.
+
+    verbose : int, optional (default=0)
+        Controls the verbosity of the tree building process.
+
+
+    Attributes
+    ----------
+    estimators_ : list of ExtraTreeRegressor
+        The collection of fitted sub-estimators.
+
+    estimators_samples_ : list of arrays
+        The subset of drawn samples (i.e., the in-bag samples) for each base
+        estimator.
+
+    max_samples_ : integer
+        The actual number of samples
+
+    References
+    ----------
+    .. [1] Liu, Fei Tony, Ting, Kai Ming and Zhou, Zhi-Hua. "Isolation forest."
+           Data Mining, 2008. ICDM'08. Eighth IEEE International Conference on.
+    .. [2] Liu, Fei Tony, Ting, Kai Ming and Zhou, Zhi-Hua. "Isolation-based
+           anomaly detection." ACM Transactions on Knowledge Discovery from
+           Data (TKDD) 6.1 (2012): 3.
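+
+    Examples
+    --------
+    A minimal fitting/scoring sketch (the values in ``X`` are illustrative):
+
+    >>> from sklearn.ensemble import IsolationForest
+    >>> X = [[-1.1], [0.3], [0.5], [100.]]
+    >>> clf = IsolationForest(max_samples=4, random_state=0).fit(X)
+    >>> scores = clf.predict(X)  # the lower, the more normal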
+    """
+
+    def __init__(self,
+                 n_estimators=100,
+                 max_samples=256,
+                 max_features=1.,
+                 bootstrap=False,
+                 n_jobs=1,
+                 random_state=None,
+                 verbose=0):
+        super(IsolationForest, self).__init__(
+            base_estimator=ExtraTreeRegressor(
+                max_depth=int(np.ceil(np.log2(max(max_samples, 2)))),
+                max_features=1,
+                splitter='random',
+                random_state=random_state),
+            # max_features above has no link with self.max_features
+            bootstrap=bootstrap,
+            bootstrap_features=False,
+            n_estimators=n_estimators,
+            max_samples=max_samples,
+            max_features=max_features,
+            n_jobs=n_jobs,
+            random_state=random_state,
+            verbose=verbose)
+
+    def _set_oob_score(self, X, y):
+        raise NotImplementedError("OOB score not supported by iforest")
+
+    def fit(self, X, y=None, sample_weight=None):
+        """Fit estimator.
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix, shape (n_samples, n_features)
+            The input samples. Use ``dtype=np.float32`` for maximum
+            efficiency. Sparse matrices are also supported, use sparse
+            ``csc_matrix`` for maximum efficiency.
+
+        Returns
+        -------
+        self : object
+            Returns self.
+        """
+        # ensure_2d=False because there are actually unit tests checking that
+        # we fail for 1d data.
+        X = check_array(X, accept_sparse=['csc'], ensure_2d=False)
+        if issparse(X):
+            # Pre-sort indices to avoid that each individual tree of the
+            # ensemble sorts the indices.
+            X.sort_indices()
+
+        rnd = check_random_state(self.random_state)
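+        # the tree machinery requires a target; a random uniform target is
+        # used so that the fully random trees can be grown without encoding
+        # any supervised structure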
+        y = rnd.uniform(size=X.shape[0])
+
+        # ensure that max_samples is in [1, n_samples]:
+        max_samples = self.max_samples
+        n_samples = X.shape[0]
+        if max_samples > n_samples:
+            warn("max_samples (%s) is greater than the "
+                 "total number of samples (%s). max_samples "
+                 "will be set to n_samples for estimation."
+                 % (self.max_samples, n_samples))
+            max_samples = n_samples
+
+        super(IsolationForest, self)._fit(X, y, max_samples,
+                                          sample_weight=sample_weight)
+        return self
+
+    def predict(self, X):
+        """Predict anomaly score of X with the IsolationForest algorithm.
+
+        The anomaly score of an input sample is computed as
+        the mean anomaly scores of the trees in the forest.
+
+        The measure of normality of an observation given a tree is the depth
+        of the leaf containing this observation, which is equivalent to
+        the number of splits required to isolate this point. When a leaf
+        contains n_left observations, the average path length of an isolation
+        tree with n_left samples is added to the depth.
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix of shape (n_samples, n_features)
+            The input samples. Internally, it will be converted to
+            ``dtype=np.float32`` and if a sparse matrix is provided
+            to a sparse ``csr_matrix``.
+
+        Returns
+        -------
+        scores : array of shape (n_samples,)
+            The anomaly score of the input samples.
+            The lower, the more normal.
+        """
+        # code structure from ForestClassifier/predict_proba
+        # Check data
+        X = self.estimators_[0]._validate_X_predict(X, check_input=True)
+        n_samples = X.shape[0]
+
+        n_samples_leaf = np.zeros((n_samples, self.n_estimators), order="f")
+        depths = np.zeros((n_samples, self.n_estimators), order="f")
+
+        for i, tree in enumerate(self.estimators_):
+            leaves_index = tree.apply(X)
+            node_indicator = tree.decision_path(X)
+            n_samples_leaf[:, i] = tree.tree_.n_node_samples[leaves_index]
+            depths[:, i] = np.asarray(node_indicator.sum(axis=1)).reshape(-1) - 1
+
+        depths += _average_path_length(n_samples_leaf)
+
+        if not isinstance(self.max_samples, (numbers.Integral, np.integer)):
+            max_samples = int(self.max_samples * X.shape[0])
+        else:
+            max_samples = self.max_samples
+
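+        # following Liu et al. (2008), normalize the average depth by the
+        # expected path length c(max_samples) of a tree grown on max_samples
+        # points: scores close to 1 indicate anomalies, smaller values inliers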
+        scores = 2 ** (-depths.mean(axis=1) / _average_path_length(max_samples))
+
+        return scores
+
+    def decision_function(self, X):
+        """Average of the decision functions of the base classifiers.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix}, shape (n_samples, n_features)
+            The training input samples. Sparse matrices are accepted only if
+            they are supported by the base estimator.
+
+        Returns
+        -------
+        score : array, shape (n_samples,)
+            The decision function of the input samples.
+
+        """
+        # minus as bigger is better (here less abnormal):
+        return - self.predict(X)
+
+
+def _average_path_length(n_samples_leaf):
+    """ The average path length in a n_samples iTree, which is equal to
+    the average path length of an unsuccessful BST search since the
+    latter has the same structure as an isolation tree.
+    Parameters
+    ----------
+    n_samples_leaf : array-like of shape (n_samples, n_estimators), or int.
+        The number of training samples in each test sample leaf, for
+        each estimators.
+    
+    Returns
+    -------
+    average_path_length : array, same shape as n_samples_leaf
+
+    """
+    if isinstance(n_samples_leaf, (numbers.Integral, np.integer)):
+        if n_samples_leaf <= 1:
+            return 1.
+        else:
+            return 2. * (np.log(n_samples_leaf) + 0.5772156649) - 2. * (
+                n_samples_leaf - 1.) / n_samples_leaf
+
+    else:
+
+        n_samples_leaf_shape = n_samples_leaf.shape
+        n_samples_leaf = n_samples_leaf.reshape((1, -1))
+        average_path_length = np.zeros(n_samples_leaf.shape)
+
+        mask = (n_samples_leaf <= 1)
+        not_mask = np.logical_not(mask)
+
+        average_path_length[mask] = 1.
+        average_path_length[not_mask] = 2. * (
+            np.log(n_samples_leaf[not_mask]) + 0.5772156649) - 2. * (
+                n_samples_leaf[not_mask] - 1.) / n_samples_leaf[not_mask]
+
+        return average_path_length.reshape(n_samples_leaf_shape)
diff --git a/sklearn/ensemble/tests/test_iforest.py b/sklearn/ensemble/tests/test_iforest.py
new file mode 100644
index 0000000000000000000000000000000000000000..694f3af7842d508f6de6eac8536228462dcce13f
--- /dev/null
+++ b/sklearn/ensemble/tests/test_iforest.py
@@ -0,0 +1,161 @@
+
+"""
+Testing for Isolation Forest algorithm (sklearn.ensemble.iforest).
+"""
+
+# Authors: Nicolas Goix <nicolas.goix@telecom-paristech.fr>
+#          Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
+# License: BSD 3 clause
+
+import numpy as np
+
+from sklearn.utils.testing import assert_array_equal
+from sklearn.utils.testing import assert_array_almost_equal
+from sklearn.utils.testing import assert_raises
+from sklearn.utils.testing import assert_warns
+from sklearn.utils.testing import assert_greater
+from sklearn.utils.testing import ignore_warnings
+
+from sklearn.grid_search import ParameterGrid
+from sklearn.ensemble import IsolationForest
+from sklearn.cross_validation import train_test_split
+from sklearn.datasets import load_boston, load_iris
+from sklearn.utils import check_random_state
+from sklearn.metrics import roc_auc_score
+
+from scipy.sparse import csc_matrix, csr_matrix
+
+rng = check_random_state(0)
+
+# load the iris dataset
+# and randomly permute it
+iris = load_iris()
+perm = rng.permutation(iris.target.size)
+iris.data = iris.data[perm]
+iris.target = iris.target[perm]
+
+# also load the boston dataset
+# and randomly permute it
+boston = load_boston()
+perm = rng.permutation(boston.target.size)
+boston.data = boston.data[perm]
+boston.target = boston.target[perm]
+
+
+def test_iforest():
+    """Check Isolation Forest for various parameter settings."""
+    X_train = np.array([[0, 1], [1, 2]])
+    X_test = np.array([[2, 1], [1, 1]])
+
+    grid = ParameterGrid({"n_estimators": [3],
+                          "max_samples": [0.5, 1.0, 3],
+                          "bootstrap": [True, False]})
+
+    with ignore_warnings():
+        for params in grid:
+            IsolationForest(random_state=rng,
+                            **params).fit(X_train).predict(X_test)
+
+
+def test_iforest_sparse():
+    """Check IForest for various parameter settings on sparse input."""
+    rng = check_random_state(0)
+    X_train, X_test, y_train, y_test = train_test_split(boston.data[:50],
+                                                        boston.target[:50],
+                                                        random_state=rng)
+    grid = ParameterGrid({"max_samples": [0.5, 1.0],
+                          "bootstrap": [True, False]})
+
+    for sparse_format in [csc_matrix, csr_matrix]:
+        X_train_sparse = sparse_format(X_train)
+        X_test_sparse = sparse_format(X_test)
+
+        for params in grid:
+            # Trained on sparse format
+            sparse_classifier = IsolationForest(
+                random_state=1, **params).fit(X_train_sparse)
+            sparse_results = sparse_classifier.predict(X_test_sparse)
+
+            # Trained on dense format
+            dense_results = IsolationForest(
+                random_state=1, **params).fit(X_train).predict(X_test)
+
+            assert_array_equal(sparse_results, dense_results)
+
+
+def test_iforest_error():
+    """Test that it gives proper exception on deficient input."""
+    X = iris.data
+
+    # Test max_samples
+    assert_raises(ValueError,
+                  IsolationForest(max_samples=-1).fit, X)
+    assert_raises(ValueError,
+                  IsolationForest(max_samples=0.0).fit, X)
+    assert_raises(ValueError,
+                  IsolationForest(max_samples=2.0).fit, X)
+    assert_warns(UserWarning,
+                 IsolationForest(max_samples=1000).fit, X)
+    # cannot check for string values
+
+
+def test_iforest_parallel_regression():
+    """Check parallel regression."""
+    rng = check_random_state(0)
+
+    X_train, X_test, y_train, y_test = train_test_split(boston.data,
+                                                        boston.target,
+                                                        random_state=rng)
+
+    ensemble = IsolationForest(n_jobs=3,
+                               random_state=0).fit(X_train)
+
+    ensemble.set_params(n_jobs=1)
+    y1 = ensemble.predict(X_test)
+    ensemble.set_params(n_jobs=2)
+    y2 = ensemble.predict(X_test)
+    assert_array_almost_equal(y1, y2)
+
+    ensemble = IsolationForest(n_jobs=1,
+                               random_state=0).fit(X_train)
+
+    y3 = ensemble.predict(X_test)
+    assert_array_almost_equal(y1, y3)
+
+
+def test_iforest_performance():
+    """Test Isolation Forest performs well"""
+
+    # Generate train/test data
+    rng = check_random_state(2)
+    X = 0.3 * rng.randn(120, 2)
+    X_train = X[:100]
+
+    # Generate some abnormal novel observations
+    X_outliers = rng.uniform(low=-4, high=4, size=(20, 2))
+    X_test = np.r_[X[100:], X_outliers]
+    y_test = np.array([0] * 20 + [1] * 20)
+
+    # fit the model
+    clf = IsolationForest(max_samples=100, random_state=rng).fit(X_train)
+
+    # predict scores (the lower, the more normal)
+    y_pred = clf.predict(X_test)
+
+    # check that the scores separate inliers from outliers well (high AUC)
+    assert_greater(roc_auc_score(y_test, y_pred), 0.98)
+
+
+def test_iforest_works():
+    # toy sample (the last two samples are outliers)
+    X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [6, 3], [-4, 7]]
+
+    # Test IsolationForest
+    clf = IsolationForest(random_state=rng)
+    clf.fit(X)
+    pred = clf.predict(X)
+
+    # assert that the outliers receive the largest anomaly scores:
+    assert_greater(np.min(pred[-2:]), np.max(pred[:-2]))