From cb1b6c4734b2989ad0492b56c326858d030f3fa2 Mon Sep 17 00:00:00 2001
From: Thomas Moreau <thomas.moreau.2010@gmail.com>
Date: Wed, 12 Jul 2017 22:56:02 +0200
Subject: [PATCH] FIX t-SNE memory usage and many other optimizer issues
 (#9032)

Use a sparse matrix representation of the neighbors.
Refactor the QuadTree implementation to avoid insertion errors.
Fix the gradient descent schedule so that the Barnes-Hut and exact solvers behave more robustly and consistently.
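
For illustration, a minimal sketch (toy data, not the estimator code) of
the sparse joint-probability construction this patch introduces:

    import numpy as np
    from scipy.sparse import csr_matrix

    # Toy input: indices of the k=2 nearest neighbors of 3 samples and
    # the matching conditional probabilities p_{j|i}.
    neighbors = np.array([[1, 2], [0, 2], [0, 1]], dtype=np.int64)
    conditional_P = np.full((3, 2), 0.5, dtype=np.float32)
    n_samples, k = neighbors.shape
    # One CSR row per sample, one stored entry per nearest neighbor.
    P = csr_matrix((conditional_P.ravel(), neighbors.ravel(),
                    range(0, n_samples * k + 1, k)),
                   shape=(n_samples, n_samples))
    P = P + P.T   # symmetrize the conditional probabilities
    P /= P.sum()  # normalize into a joint distribution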
---
 benchmarks/.gitignore                      |   4 +
 benchmarks/bench_tsne_mnist.py             | 169 +++++
 benchmarks/plot_tsne_mnist.py              |  30 +
 doc/whats_new.rst                          |  17 +
 examples/manifold/plot_t_sne_perplexity.py |  44 +-
 sklearn/manifold/_barnes_hut_tsne.pyx      | 805 +++------------------
 sklearn/manifold/_utils.pyx                |  68 +-
 sklearn/manifold/setup.py                  |   1 +
 sklearn/manifold/t_sne.py                  | 494 ++++++-------
 sklearn/manifold/tests/test_t_sne.py       | 358 +++++----
 sklearn/mixture/base.py                    |   2 +-
 sklearn/neighbors/quad_tree.pxd            | 100 +++
 sklearn/neighbors/quad_tree.pyx            | 672 +++++++++++++++++
 sklearn/neighbors/setup.py                 |   4 +
 sklearn/neighbors/tests/test_quad_tree.py  | 108 +++
 sklearn/tree/_utils.pxd                    |   4 +-
 16 files changed, 1748 insertions(+), 1132 deletions(-)
 create mode 100644 benchmarks/.gitignore
 create mode 100644 benchmarks/bench_tsne_mnist.py
 create mode 100644 benchmarks/plot_tsne_mnist.py
 create mode 100644 sklearn/neighbors/quad_tree.pxd
 create mode 100644 sklearn/neighbors/quad_tree.pyx
 create mode 100644 sklearn/neighbors/tests/test_quad_tree.py

diff --git a/benchmarks/.gitignore b/benchmarks/.gitignore
new file mode 100644
index 0000000000..2b6f7ba9c1
--- /dev/null
+++ b/benchmarks/.gitignore
@@ -0,0 +1,4 @@
+/bhtsne
+*.npy
+*.json
+/mnist_tsne_output/
diff --git a/benchmarks/bench_tsne_mnist.py b/benchmarks/bench_tsne_mnist.py
new file mode 100644
index 0000000000..26dde6aac3
--- /dev/null
+++ b/benchmarks/bench_tsne_mnist.py
@@ -0,0 +1,169 @@
+"""
+=============================
+MNIST dataset t-SNE benchmark
+=============================
+
+"""
+from __future__ import division, print_function
+
+# License: BSD 3 clause
+
+import os
+import os.path as op
+from time import time
+import numpy as np
+import json
+import argparse
+
+from sklearn.externals.joblib import Memory
+from sklearn.datasets import fetch_mldata
+from sklearn.manifold import TSNE
+from sklearn.neighbors import NearestNeighbors
+from sklearn.decomposition import PCA
+from sklearn.utils import check_array
+from sklearn.utils import shuffle as _shuffle
+
+
+LOG_DIR = "mnist_tsne_output"
+if not os.path.exists(LOG_DIR):
+    os.mkdir(LOG_DIR)
+
+
+memory = Memory(os.path.join(LOG_DIR, 'mnist_tsne_benchmark_data'),
+                mmap_mode='r')
+
+
+@memory.cache
+def load_data(dtype=np.float32, order='C', shuffle=True, seed=0):
+    """Load the data, then cache and memmap the train/test split"""
+    print("Loading dataset...")
+    data = fetch_mldata('MNIST original')
+
+    X = check_array(data['data'], dtype=dtype, order=order)
+    y = data["target"]
+
+    if shuffle:
+        X, y = _shuffle(X, y, random_state=seed)
+
+    # Normalize features
+    X /= 255
+    return X, y
+
+
+def nn_accuracy(X, X_embedded, k=1):
+    """Accuracy of the first nearest neighbor"""
+    knn = NearestNeighbors(n_neighbors=k, n_jobs=-1)
+    _, neighbors_X = knn.fit(X).kneighbors()
+    _, neighbors_X_embedded = knn.fit(X_embedded).kneighbors()
+    return np.mean(neighbors_X == neighbors_X_embedded)
+
+
+def tsne_fit_transform(model, data):
+    transformed = model.fit_transform(data)
+    return transformed, model.n_iter_
+
+
+def sanitize(filename):
+    return filename.replace("/", '-').replace(" ", "_")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser('Benchmark for t-SNE')
+    parser.add_argument('--order', type=str, default='C',
+                        help='Order of the input data')
+    parser.add_argument('--perplexity', type=float, default=30)
+    parser.add_argument('--bhtsne', action='store_true',
+                        help="if set and the reference bhtsne code is "
+                        "correctly installed, run it in the benchmark.")
+    parser.add_argument('--all', action='store_true',
+                        help="if set, run the benchmark with the whole MNIST."
+                             "dataset. Note that it will take up to 1 hour.")
+    parser.add_argument('--profile', action='store_true',
+                        help="if set, run the benchmark with a memory "
+                             "profiler.")
+    parser.add_argument('--verbose', type=int, default=0)
+    parser.add_argument('--pca-components', type=int, default=50,
+                        help="Number of principal components for "
+                             "preprocessing.")
+    args = parser.parse_args()
+
+    X, y = load_data(order=args.order)
+
+    if args.pca_components > 0:
+        t0 = time()
+        X = PCA(n_components=args.pca_components).fit_transform(X)
+        print("PCA preprocessing down to {} dimensions took {:0.3f}s"
+              .format(args.pca_components, time() - t0))
+
+    methods = []
+
+    # Put TSNE in methods
+    tsne = TSNE(n_components=2, init='pca', perplexity=args.perplexity,
+                verbose=args.verbose, n_iter=1000)
+    methods.append(("sklearn TSNE",
+                    lambda data: tsne_fit_transform(tsne, data)))
+
+    if args.bhtsne:
+        try:
+            from bhtsne.bhtsne import run_bh_tsne
+        except ImportError:
+            raise ImportError("""\
+If you want comparison with the reference implementation, build the
+binary from source (https://github.com/lvdmaaten/bhtsne) in the folder
+benchmarks/bhtsne and add an empty `__init__.py` file in the folder:
+
+$ git clone git@github.com:lvdmaaten/bhtsne.git
+$ cd bhtsne
+$ g++ sptree.cpp tsne.cpp tsne_main.cpp -o bh_tsne -O2
+$ touch __init__.py
+$ cd ..
+""")
+
+        def bhtsne(X):
+            """Wrapper for the reference lvdmaaten/bhtsne implementation."""
+            # PCA preprocessing is done elsewhere in the benchmark script
+            n_iter = -1  # TODO find a way to report the number of iterations
+            return run_bh_tsne(X, use_pca=False, perplexity=args.perplexity,
+                               verbose=args.verbose > 0), n_iter
+        methods.append(("lvdmaaten/bhtsne", bhtsne))
+
+    if args.profile:
+
+        try:
+            from memory_profiler import profile
+        except ImportError:
+            raise ImportError("To run the benchmark with `--profile`, you "
+                              "need to install `memory_profiler`. Please "
+                              "run `pip install memory_profiler`.")
+        methods = [(n, profile(m)) for n, m in methods]
+
+    data_size = [100, 500, 1000, 5000, 10000]
+    if args.all:
+        data_size.append(70000)
+
+    results = []
+    basename = op.basename(op.splitext(__file__)[0])
+    log_filename = os.path.join(LOG_DIR, basename + '.json')
+    for n in data_size:
+        X_train = X[:n]
+        y_train = y[:n]
+        n = X_train.shape[0]
+        for name, method in methods:
+            print("Fitting {} on {} samples...".format(name, n))
+            t0 = time()
+            np.save(os.path.join(LOG_DIR, 'mnist_{}_{}.npy'
+                                 .format('original', n)), X_train)
+            np.save(os.path.join(LOG_DIR, 'mnist_{}_{}.npy'
+                                 .format('original_labels', n)), y_train)
+            X_embedded, n_iter = method(X_train)
+            duration = time() - t0
+            nn_acc = nn_accuracy(X_train, X_embedded)
+            print("Fitting {} on {} samples took {:.3f}s in {:d} iterations, "
+                  "nn accuracy: {:0.3f}".format(
+                      name, n, duration, n_iter, nn_acc))
+            results.append(dict(method=name, duration=duration, n_samples=n))
+            with open(log_filename, 'w') as f:
+                json.dump(results, f)
+            method_name = sanitize(name)
+            np.save(op.join(LOG_DIR, 'mnist_{}_{}.npy'.format(method_name, n)),
+                    X_embedded)
diff --git a/benchmarks/plot_tsne_mnist.py b/benchmarks/plot_tsne_mnist.py
new file mode 100644
index 0000000000..0ffd32b3de
--- /dev/null
+++ b/benchmarks/plot_tsne_mnist.py
@@ -0,0 +1,30 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import os.path as op
+
+import argparse
+
+
+LOG_DIR = "mnist_tsne_output"
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser('Plot benchmark results for t-SNE')
+    parser.add_argument(
+        '--labels', type=str,
+        default=op.join(LOG_DIR, 'mnist_original_labels_10000.npy'),
+        help='1D integer numpy array for labels')
+    parser.add_argument(
+        '--embedding', type=str,
+        default=op.join(LOG_DIR, 'mnist_sklearn_TSNE_10000.npy'),
+        help='2D float numpy array for embedded data')
+    args = parser.parse_args()
+
+    X = np.load(args.embedding)
+    y = np.load(args.labels)
+
+    for i in np.unique(y):
+        mask = y == i
+        plt.scatter(X[mask, 0], X[mask, 1], alpha=0.2, label=int(i))
+    plt.legend(loc='best')
+    plt.show()
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 3c87d4174c..1244c4596b 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -19,6 +19,7 @@ occurs due to changes in the modelling logic (bug fixes or enhancements), or in
 random sampling procedures.
 
    * :class:`sklearn.ensemble.IsolationForest` (bug fix)
+   * :class:`sklearn.manifold.TSNE` (bug fix)
 
 Details are listed in the changelog below.
 
@@ -245,6 +246,14 @@ Enhancements
    - Speed improvements to :class:`model_selection.StratifiedShuffleSplit`.
      :issue:`5991` by :user:`Arthur Mensch <arthurmensch>` and `Joel Nothman`_.
 
+   - Memory improvements for the ``barnes_hut`` method in :class:`manifold.TSNE`.
+     :issue:`7089` by :user:`Thomas Moreau <tomMoral>` and `Olivier Grisel`_.
+
+   - Optimization schedule improvements so the results are closer to those of
+     the reference implementation
+     `lvdmaaten/bhtsne <https://github.com/lvdmaaten/bhtsne>`_ by
+     :user:`Thomas Moreau <tomMoral>` and `Olivier Grisel`_.
+
 Bug fixes
 .........
 
@@ -478,6 +487,14 @@ Bug fixes
      and :class:`linear_model.Ridge` when using ``normalize=True``
      by `Alexandre Gramfort`_.
 
+   - Fixed the implementation of :class:`manifold.TSNE`:
+      - ``early_exaggeration`` parameter had no effect and is now used for the
+        first 250 optimization iterations.
+      - Fixed the insertion error reported in :issue:`8992`.
+      - Improved the learning schedule to match the one from the reference
+        implementation `lvdmaaten/bhtsne <https://github.com/lvdmaaten/bhtsne>`_.
+     by :user:`Thomas Moreau <tomMoral>` and `Olivier Grisel`_.
+
 API changes summary
 -------------------
 
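
For context on the t-SNE entries above, a schematic sketch of the
two-phase optimization schedule (illustrative only; the helper
``kl_divergence_and_grad`` and the momentum and default values are
assumptions, not the library API):

    import numpy as np

    def two_phase_gradient_descent(P, Y, kl_divergence_and_grad,
                                   early_exaggeration=12.0,
                                   learning_rate=200.0, n_iter=1000):
        """Exaggerate P for the first 250 iterations, then optimize
        the true KL(P || Q) objective with momentum gradient descent."""
        update = np.zeros_like(Y)
        P = P * early_exaggeration           # phase 1: early exaggeration
        for it in range(n_iter):
            if it == 250:
                P = P / early_exaggeration   # phase 2: true objective
            error, grad = kl_divergence_and_grad(P, Y)
            momentum = 0.5 if it < 250 else 0.8
            update = momentum * update - learning_rate * grad
            Y = Y + update
        return Y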
diff --git a/examples/manifold/plot_t_sne_perplexity.py b/examples/manifold/plot_t_sne_perplexity.py
index 4165dac141..cc3dafc12a 100644
--- a/examples/manifold/plot_t_sne_perplexity.py
+++ b/examples/manifold/plot_t_sne_perplexity.py
@@ -14,7 +14,7 @@ perplexity values and does not always convey a meaning.
 As shown below, t-SNE for higher perplexities finds meaningful topology of
 two concentric circles, however the size and the distance of the circles varies
 slightly from the original. Contrary to the two circles dataset, the shapes
-visually diverge from S-curve topology on the S-curve dateset even for
+visually diverge from S-curve topology on the S-curve dataset even for
 larger perplexity values.
 
 For further details, "How to Use t-SNE Effectively"
@@ -28,16 +28,17 @@ those effects.
 
 print(__doc__)
 
+import numpy as np
 import matplotlib.pyplot as plt
 
 from matplotlib.ticker import NullFormatter
 from sklearn import manifold, datasets
 from time import time
 
-n_samples = 500
+n_samples = 300
 n_components = 2
-(fig, subplots) = plt.subplots(2, 5, figsize=(15, 8))
-perplexities = [5, 50, 100, 150]
+(fig, subplots) = plt.subplots(3, 5, figsize=(15, 8))
+perplexities = [5, 30, 50, 100]
 
 X, y = datasets.make_circles(n_samples=n_samples, factor=.5, noise=.05)
 
@@ -71,7 +72,7 @@ for i, perplexity in enumerate(perplexities):
 X, color = datasets.samples_generator.make_s_curve(n_samples, random_state=0)
 
 ax = subplots[1][0]
-ax.scatter(X[:, 0], X[:, 2], c=color, cmap=plt.cm.Spectral)
+ax.scatter(X[:, 0], X[:, 2], c=color, cmap=plt.cm.viridis)
 ax.xaxis.set_major_formatter(NullFormatter())
 ax.yaxis.set_major_formatter(NullFormatter())
 
@@ -86,9 +87,40 @@ for i, perplexity in enumerate(perplexities):
     print("S-curve, perplexity=%d in %.2g sec" % (perplexity, t1 - t0))
 
     ax.set_title("Perplexity=%d" % perplexity)
-    ax.scatter(Y[:, 0], Y[:, 1], c=color, cmap=plt.cm.Spectral)
+    ax.scatter(Y[:, 0], Y[:, 1], c=color, cmap=plt.cm.viridis)
     ax.xaxis.set_major_formatter(NullFormatter())
     ax.yaxis.set_major_formatter(NullFormatter())
     ax.axis('tight')
 
+
+# Another example using a 2D uniform grid
+x = np.linspace(0, 1, int(np.sqrt(n_samples)))
+xx, yy = np.meshgrid(x, x)
+X = np.hstack([
+    xx.ravel().reshape(-1, 1),
+    yy.ravel().reshape(-1, 1),
+])
+color = xx.ravel()
+ax = subplots[2][0]
+ax.scatter(X[:, 0], X[:, 1], c=color, cmap=plt.cm.viridis)
+ax.xaxis.set_major_formatter(NullFormatter())
+ax.yaxis.set_major_formatter(NullFormatter())
+
+for i, perplexity in enumerate(perplexities):
+    ax = subplots[2][i + 1]
+
+    t0 = time()
+    tsne = manifold.TSNE(n_components=n_components, init='random',
+                         random_state=0, perplexity=perplexity)
+    Y = tsne.fit_transform(X)
+    t1 = time()
+    print("uniform grid, perplexity=%d in %.2g sec" % (perplexity, t1 - t0))
+
+    ax.set_title("Perplexity=%d" % perplexity)
+    ax.scatter(Y[:, 0], Y[:, 1], c=color, cmap=plt.cm.viridis)
+    ax.xaxis.set_major_formatter(NullFormatter())
+    ax.yaxis.set_major_formatter(NullFormatter())
+    ax.axis('tight')
+
+
 plt.show()
diff --git a/sklearn/manifold/_barnes_hut_tsne.pyx b/sklearn/manifold/_barnes_hut_tsne.pyx
index 62cb036f7a..f08a2ced26 100644
--- a/sklearn/manifold/_barnes_hut_tsne.pyx
+++ b/sklearn/manifold/_barnes_hut_tsne.pyx
@@ -11,18 +11,24 @@
 from libc.stdlib cimport malloc, free
 from libc.stdio cimport printf
 from libc.math cimport sqrt, log
-cimport numpy as np
 import numpy as np
+cimport numpy as np
+
+from sklearn.neighbors import quad_tree
+from sklearn.neighbors cimport quad_tree
 
 cdef char* EMPTY_STRING = ""
 
 cdef extern from "math.h":
     float fabsf(float x) nogil
 
-# Round points differing by less than this amount
-# effectively ignoring differences near the 32bit 
-# floating point precision
-cdef float EPSILON = 1e-6
+# Smallest strictly positive value that can be represented by floating
+# point numbers for different precision levels. This is useful to avoid
+# taking the log of zero when computing the KL divergence.
+cdef float FLOAT32_TINY = np.finfo(np.float32).tiny
+
+# Useful to avoid division by zero or divergence to +inf.
+cdef float FLOAT64_EPS = np.finfo(np.float64).eps
 
 # This is effectively an ifdef statement in Cython
 # It allows us to write printf debugging lines
@@ -37,466 +43,66 @@ cdef extern from "time.h":
     double CLOCKS_PER_SEC
 
 
-cdef extern from "cblas.h":
-    float snrm2 "cblas_snrm2"(int N, float *X, int incX) nogil
-
-
-cdef struct Node:
-    # Keep track of the center of mass
-    float* barycenter
-    # If this is a leaf, the position of the point within this leaf 
-    float* leaf_point_position
-    # The number of points including all 
-    # nodes below this one
-    long cumulative_size
-    # Number of points at this node
-    long size
-    # Index of the point at this node
-    # Only defined for non-empty leaf nodes
-    long point_index
-    # level = 0 is the root node
-    # And each subdivision adds 1 to the level
-    long level
-    # Left edge of this node
-    float* left_edge
-    # The center of this node, equal to le + w/2.0
-    float* center
-    # The width of this node -- used to calculate the opening
-    # angle. Equal to width = re - le
-    float* width
-    # The value of the maximum width w
-    float max_width
-
-    # Does this node have children?
-    # Default to leaf until we add points
-    int is_leaf
-    # Array of pointers to pointers of children
-    Node **children
-    # Keep a pointer to the parent
-    Node *parent
-    # Pointer to the tree this node belongs too
-    Tree* tree
-
-cdef struct Tree:
-    # Holds a pointer to the root node
-    Node* root_node 
-    # Number of dimensions in the output
-    int n_dimensions
-    # Total number of cells
-    long n_cells
-    # Total number of points
-    long n_points
-    # Spit out diagnostic information?
-    int verbose
-    # How many cells per node? Should go as 2 ** n_dimensionss
-    int n_cell_per_node
-
-cdef Tree* init_tree(float[:] left_edge, float[:] width, int n_dimensions, 
-                     int verbose) nogil:
-    # tree is freed by free_tree
-    cdef Tree* tree = <Tree*> malloc(sizeof(Tree))
-    tree.n_dimensions = n_dimensions
-    tree.n_cells = 0
-    tree.n_points = 0
-    tree.verbose = verbose
-    tree.root_node = create_root(left_edge, width, n_dimensions)
-    tree.root_node.tree = tree
-    tree.n_cells += 1
-    tree.n_cell_per_node = 2 ** n_dimensions
-    if DEBUGFLAG:
-        printf("[t-SNE] Tree initialised. Left_edge = (%1.9e, %1.9e, %1.9e)\n",
-               left_edge[0], left_edge[1], left_edge[2])
-        printf("[t-SNE] Tree initialised. Width = (%1.9e, %1.9e, %1.9e)\n",
-                width[0], width[1], width[2])
-    return tree
-
-cdef Node* create_root(float[:] left_edge, float[:] width, int n_dimensions) nogil:
-    # Create a default root node
-    cdef int ax
-    cdef int n_cell_per_node = 2 ** n_dimensions
-    # root is freed by free_tree
-    root = <Node*> malloc(sizeof(Node))
-    root.is_leaf = 1
-    root.parent = NULL
-    root.level = 0
-    root.cumulative_size = 0
-    root.size = 0
-    root.point_index = -1
-    root.max_width = 0.0
-    root.width = <float*> malloc(sizeof(float) * n_dimensions)
-    root.left_edge = <float*> malloc(sizeof(float) * n_dimensions)
-    root.center = <float*> malloc(sizeof(float) * n_dimensions)
-    root.barycenter = <float*> malloc(sizeof(float) * n_dimensions)
-    root.leaf_point_position= <float*> malloc(sizeof(float) * n_dimensions)
-    root.children = NULL
-    for ax in range(n_dimensions):
-        root.width[ax] = width[ax]
-        root.left_edge[ax] = left_edge[ax]
-        root.center[ax] = 0.0
-        root.barycenter[ax] = 0.
-        root.leaf_point_position[ax] = -1
-    for ax in range(n_dimensions):
-        root.max_width = max(root.max_width, root.width[ax])
-    if DEBUGFLAG:
-        printf("[t-SNE] Created root node %p\n", root)
-    return root
-
-cdef Node* create_child(Node *parent, int[3] offset) nogil:
-    # Create a new child node with default parameters
-    cdef int ax
-    # these children are freed by free_recursive
-    child = <Node *> malloc(sizeof(Node))
-    child.is_leaf = 1
-    child.parent = parent
-    child.level = parent.level + 1
-    child.size = 0
-    child.cumulative_size = 0
-    child.point_index = -1
-    child.tree = parent.tree
-    child.max_width = 0.0
-    child.width = <float*> malloc(sizeof(float) * parent.tree.n_dimensions)
-    child.left_edge = <float*> malloc(sizeof(float) * parent.tree.n_dimensions)
-    child.center = <float*> malloc(sizeof(float) * parent.tree.n_dimensions)
-    child.barycenter = <float*> malloc(sizeof(float) * parent.tree.n_dimensions)
-    child.leaf_point_position = <float*> malloc(sizeof(float) * parent.tree.n_dimensions)
-    child.children = NULL
-    for ax in range(parent.tree.n_dimensions):
-        child.width[ax] = parent.width[ax] / 2.0
-        child.left_edge[ax] = parent.left_edge[ax] + offset[ax] * parent.width[ax] / 2.0
-        child.center[ax] = child.left_edge[ax] + child.width[ax] / 2.0
-        child.barycenter[ax] = 0.
-        child.leaf_point_position[ax] = -1.
-    for ax in range(parent.tree.n_dimensions):
-        child.max_width = max(child.max_width, child.width[ax])
-    child.tree.n_cells += 1
-    return child
-
-cdef Node* select_child(Node *node, float[3] pos, long index) nogil:
-    # Find which sub-node a position should go into
-    # And return the appropriate node
-    cdef int* offset = <int*> malloc(sizeof(int) * node.tree.n_dimensions)
-    cdef int ax, idx
-    cdef Node* child
-    cdef int error
-    for ax in range(node.tree.n_dimensions):
-        offset[ax] = (pos[ax] - (node.left_edge[ax] + node.width[ax] / 2.0)) > 0.
-    idx = offset2index(offset, node.tree.n_dimensions)
-    child = node.children[idx]
-    if DEBUGFLAG:
-        printf("[t-SNE] Offset [%i, %i] with LE [%f, %f]\n",
-               offset[0], offset[1], child.left_edge[0], child.left_edge[1])
-    free(offset)
-    return child
-
-
-cdef inline void index2offset(int* offset, int index, int n_dimensions) nogil:
-    # Convert a 1D index into N-D index; useful for indexing
-    # children of a quadtree, octree, N-tree
-    # Quite likely there's a fancy bitshift way of doing this
-    # since the offset is equivalent to the binary representation
-    # of the integer index
-    # We read the offset array left-to-right
-    # such that the least significat bit is on the right
-    cdef int rem, k, shift
-    for k in range(n_dimensions):
-        shift = n_dimensions -k -1
-        rem = ((index >> shift) << shift)
-        offset[k] = rem > 0
-        if DEBUGFLAG:
-            printf("i2o index %i k %i rem %i offset", index, k, rem)
-            for j in range(n_dimensions):
-                printf(" %i", offset[j])
-            printf(" n_dimensions %i\n", n_dimensions)
-        index -= rem
-
-
-cdef inline int offset2index(int* offset, int n_dimensions) nogil:
-    # Calculate the 1:1 index for a given offset array
-    # We read the offset array right-to-left
-    # such that the least significat bit is on the right
-    cdef int dim
-    cdef int index = 0
-    for dim in range(n_dimensions):
-        index += (2 ** dim) * offset[n_dimensions - dim - 1]
-        if DEBUGFLAG:
-            printf("o2i index %i dim %i            offset", index, dim)
-            for j in range(n_dimensions):
-                printf(" %i", offset[j])
-            printf(" n_dimensions %i\n", n_dimensions)
-    return index
-
-
-cdef void subdivide(Node* node) nogil:
-    # This instantiates 2**n_dimensions = n_cell_per_node nodes for the current node
-    cdef int idx = 0
-    cdef int* offset = <int*> malloc(sizeof(int) * node.tree.n_dimensions)
-    node.is_leaf = False
-    node.children = <Node**> malloc(sizeof(Node*) * node.tree.n_cell_per_node)
-    for idx in range(node.tree.n_cell_per_node):
-        index2offset(offset, idx, node.tree.n_dimensions)
-        node.children[idx] = create_child(node, offset)
-    free(offset)
-
-
-cdef int insert(Node *root, float pos[3], long point_index, long depth, long
-        duplicate_count) nogil:
-    # Introduce a new point into the tree
-    # by recursively inserting it and subdividng as necessary
-    # Carefully treat the case of identical points at the same node
-    # by increasing the root.size and tracking duplicate_count
-    cdef Node *child
-    cdef long i
-    cdef int ax
-    cdef int not_identical = 1
-    cdef int n_dimensions = root.tree.n_dimensions
-    if DEBUGFLAG:
-        printf("[t-SNE] [d=%i] Inserting pos %i [%f, %f] duplicate_count=%i "
-                "into child %p\n", depth, point_index, pos[0], pos[1],
-                duplicate_count, root)    
-    # Increment the total number points including this
-    # node and below it
-    root.cumulative_size += duplicate_count
-    # Evaluate the new center of mass, weighting the previous
-    # center of mass against the new point data
-    cdef double frac_seen = <double>(root.cumulative_size - 1) / (<double>
-            root.cumulative_size)
-    cdef double frac_new  = 1.0 / <double> root.cumulative_size
-    # Assert that duplicate_count > 0
-    if duplicate_count < 1:
-        return -1
-    # Assert that the point is inside the left & right edges
-    for ax in range(n_dimensions):
-        root.barycenter[ax] *= frac_seen
-        if (pos[ax] > (root.left_edge[ax] + root.width[ax] + EPSILON)):
-            printf("[t-SNE] Error: point (%1.9e) is above right edge of node "
-                    "(%1.9e)\n", pos[ax], root.left_edge[ax] + root.width[ax])
-            return -1
-        if (pos[ax] < root.left_edge[ax] - EPSILON):
-            printf("[t-SNE] Error: point (%1.9e) is below left edge of node "
-                   "(%1.9e)\n", pos[ax], root.left_edge[ax])
-            return -1
-    for ax in range(n_dimensions):
-        root.barycenter[ax] += pos[ax] * frac_new
-
-    # If this node is unoccupied, fill it.
-    # Otherwise, we need to insert recursively.
-    # Two insertion scenarios: 
-    # 1) Insert into this node if it is a leaf and empty
-    # 2) Subdivide this node if it is currently occupied
-    if (root.size == 0) & root.is_leaf:
-        # Root node is empty and a leaf
-        if DEBUGFLAG:
-            printf("[t-SNE] [d=%i] Inserting [%f, %f] into blank cell\n", depth,
-                   pos[0], pos[1])
-        for ax in range(n_dimensions):
-            root.leaf_point_position[ax] = pos[ax]
-        root.point_index = point_index
-        root.size = duplicate_count
-        return 0
-    else:
-        # Root node is occupied or not a leaf
-        if DEBUGFLAG:
-            printf("[t-SNE] [d=%i] Node %p is occupied or is a leaf.\n", depth,
-                    root)
-            printf("[t-SNE] [d=%i] Node %p leaf = %i. Size %i\n", depth, root,
-                    root.is_leaf, root.size)
-        if root.is_leaf & (root.size > 0):
-            # is a leaf node and is occupied
-            for ax in range(n_dimensions):
-                not_identical &= (fabsf(pos[ax] - root.leaf_point_position[ax]) < EPSILON)
-                not_identical &= (root.point_index != point_index)
-            if not_identical == 1:
-                root.size += duplicate_count
-                if DEBUGFLAG:
-                    printf("[t-SNE] Warning: [d=%i] Detected identical "
-                            "points. Returning. Leaf now has size %i\n",
-                            depth, root.size)
-                return 0
-        # If necessary, subdivide this node before
-        # descending
-        if root.is_leaf:
-            if DEBUGFLAG:
-                printf("[t-SNE] [d=%i] Subdividing this leaf node %p\n", depth,
-                        root)
-            subdivide(root)
-        # We have two points to relocate: the one previously
-        # at this node, and the new one we're attempting
-        # to insert
-        if root.size > 0:
-            child = select_child(root, root.leaf_point_position, root.point_index)
-            if DEBUGFLAG:
-                printf("[t-SNE] [d=%i] Relocating old point to node %p\n",
-                        depth, child)
-            insert(child, root.leaf_point_position, root.point_index, depth + 1, root.size)
-        # Insert the new point
-        if DEBUGFLAG:
-            printf("[t-SNE] [d=%i] Selecting node for new point\n", depth)
-        child = select_child(root, pos, point_index)
-        if root.size > 0:
-            # Remove the point from this node
-            for ax in range(n_dimensions):
-                root.leaf_point_position[ax] = -1            
-            root.size = 0
-            root.point_index = -1            
-        return insert(child, pos, point_index, depth + 1, 1)
-
-cdef int insert_many(Tree* tree, float[:,:] pos_array) nogil:
-    # Insert each data point into the tree one at a time
-    cdef long nrows = pos_array.shape[0]
-    cdef long i
-    cdef int ax
-    cdef float row[3]
-    cdef long err = 0
-    for i in range(nrows):
-        for ax in range(tree.n_dimensions):
-            row[ax] = pos_array[i, ax]
-        if DEBUGFLAG:
-            printf("[t-SNE] inserting point %i: [%f, %f]\n", i, row[0], row[1])
-        err = insert(tree.root_node, row, i, 0, 1)
-        if err != 0:
-            printf("[t-SNE] ERROR\n%s", EMPTY_STRING)
-            return err
-        tree.n_points += 1
-    return err
-
-cdef int free_tree(Tree* tree) nogil:
-    cdef int check
-    cdef long* cnt = <long*> malloc(sizeof(long) * 3)
-    for i in range(3):
-        cnt[i] = 0
-    free_recursive(tree, tree.root_node, cnt)
-    check = cnt[0] == tree.n_cells
-    check &= cnt[2] == tree.n_points
-    free(tree)
-    free(cnt)
-    return check
-
-cdef void free_post_children(Node *node) nogil:
-    free(node.width)
-    free(node.left_edge)
-    free(node.center)
-    free(node.barycenter)
-    free(node.leaf_point_position)
-    free(node)
-
-cdef void free_recursive(Tree* tree, Node *root, long* counts) nogil:
-    # Free up all of the tree nodes recursively
-    # while counting the number of nodes visited
-    # and total number of data points removed
-    cdef int idx
-    cdef Node* child
-    if not root.is_leaf:
-        for idx in range(tree.n_cell_per_node):
-            child = root.children[idx]
-            free_recursive(tree, child, counts)
-            counts[0] += 1
-            if child.is_leaf:
-                counts[1] += 1
-                if child.size > 0:
-                    counts[2] +=1
-            else:
-                free(child.children)
-
-            free_post_children(child)
-
-    if root == tree.root_node:
-        if not root.is_leaf:
-            free(root.children)
-
-        free_post_children(root)
-
-cdef long count_points(Node* root, long count) nogil:
-    # Walk through the whole tree and count the number 
-    # of points at the leaf nodes
-    if DEBUGFLAG:
-        printf("[t-SNE] Counting nodes at root node %p\n", root)
-    cdef Node* child
-    cdef int idx
-    if root.is_leaf:
-        count += root.size
-        if DEBUGFLAG : 
-            printf("[t-SNE] %p is a leaf node, no children\n", root)
-            printf("[t-SNE] %i points in node %p\n", count, root)
-        return count
-    # Otherwise, get the children
-    for idx in range(root.tree.n_cell_per_node):
-        child = root.children[idx]
-        if DEBUGFLAG:
-            printf("[t-SNE] Counting points for child %p\n", child)
-        if child.is_leaf and child.size > 0:
-            if DEBUGFLAG:
-                printf("[t-SNE] Child has size %d\n", child.size)
-            count += child.size
-        elif not child.is_leaf:
-            if DEBUGFLAG:
-                printf("[t-SNE] Child is not a leaf. Descending\n%s", EMPTY_STRING)
-            count = count_points(child, count)
-        # else case is we have an empty leaf node
-        # which happens when we create a quadtree for
-        # one point, and then the other neighboring cells
-        # don't get filled in
-    if DEBUGFLAG:
-        printf("[t-SNE] %i points in this node\n", count)
-    return count
-
-
-cdef float compute_gradient(float[:,:] val_P,
-                            float[:,:] pos_reference,
-                            np.int64_t[:,:] neighbors,
-                            float[:,:] tot_force,
-                            Node* root_node,
+cdef float compute_gradient(float[:] val_P,
+                            float[:, :] pos_reference,
+                            np.int64_t[:] neighbors,
+                            np.int64_t[:] indptr,
+                            float[:, :] tot_force,
+                            quad_tree._QuadTree qt,
                             float theta,
                             float dof,
                             long start,
                             long stop) nogil:
     # Having created the tree, calculate the gradient
     # in two components, the positive and negative forces
-    cdef long i, coord
-    cdef int ax
-    cdef long n = pos_reference.shape[0]
-    cdef int n_dimensions = root_node.tree.n_dimensions
-    if root_node.tree.verbose > 11:
-        printf("[t-SNE] Allocating %i elements in force arrays\n",
-                n * n_dimensions * 2)
-    cdef float* sum_Q = <float*> malloc(sizeof(float))
-    cdef float* neg_f = <float*> malloc(sizeof(float) * n * n_dimensions)
-    cdef float* neg_f_fast = <float*> malloc(sizeof(float) * n * n_dimensions)
-    cdef float* pos_f = <float*> malloc(sizeof(float) * n * n_dimensions)
-    cdef clock_t t1, t2
-    cdef float sQ, error
+    cdef:
+        long i, coord
+        int ax
+        long n_samples = pos_reference.shape[0]
+        int n_dimensions = qt.n_dimensions
+        double[1] sum_Q
+        clock_t t1, t2
+        float sQ, error
+
+    if qt.verbose > 11:
+        printf("[t-SNE] Allocating %li elements in force arrays\n",
+                n_samples * n_dimensions * 2)
+    cdef float* neg_f = <float*> malloc(sizeof(float) * n_samples * n_dimensions)
+    cdef float* pos_f = <float*> malloc(sizeof(float) * n_samples * n_dimensions)
 
     sum_Q[0] = 0.0
     t1 = clock()
-    compute_gradient_negative(val_P, pos_reference, neg_f, root_node, sum_Q,
+    compute_gradient_negative(pos_reference, neg_f, qt, sum_Q,
                               dof, theta, start, stop)
     t2 = clock()
-    if root_node.tree.verbose > 15:
+    if qt.verbose > 15:
         printf("[t-SNE] Computing negative gradient: %e ticks\n", ((float) (t2 - t1)))
     sQ = sum_Q[0]
     t1 = clock()
-    error = compute_gradient_positive(val_P, pos_reference, neighbors, pos_f,
-                              n_dimensions, dof, sQ, start, root_node.tree.verbose)
+    error = compute_gradient_positive(val_P, pos_reference, neighbors, indptr,
+                                      pos_f, n_dimensions, dof, sQ, start,
+                                      qt.verbose)
     t2 = clock()
-    if root_node.tree.verbose > 15:
+    if qt.verbose > 15:
         printf("[t-SNE] Computing positive gradient: %e ticks\n", ((float) (t2 - t1)))
-    for i in range(start, n):
+    for i in range(start, n_samples):
         for ax in range(n_dimensions):
             coord = i * n_dimensions + ax
-            tot_force[i, ax] = pos_f[coord] - (neg_f[coord] / sum_Q[0])
-    free(sum_Q)
+            tot_force[i, ax] = pos_f[coord] - (neg_f[coord] / sQ)
+
     free(neg_f)
-    free(neg_f_fast)
     free(pos_f)
-    return sQ
+    return error
 
 
-cdef float compute_gradient_positive(float[:,:] val_P,
-                                     float[:,:] pos_reference,
-                                     np.int64_t[:,:] neighbors,
+cdef float compute_gradient_positive(float[:] val_P,
+                                     float[:, :] pos_reference,
+                                     np.int64_t[:] neighbors,
+                                     np.int64_t[:] indptr,
                                      float* pos_f,
                                      int n_dimensions,
                                      float dof,
-                                     float sum_Q,
+                                     double sum_Q,
                                      np.int64_t start,
                                      int verbose) nogil:
     # Sum over the following expression for i not equal to j
@@ -507,33 +113,33 @@ cdef float compute_gradient_positive(float[:,:] val_P,
     cdef:
         int ax
         long i, j, k
-        long K = neighbors.shape[1]
-        long n = val_P.shape[0]
-        float[3] buff
-        float D, Q, pij
+        long n_samples = indptr.shape[0] - 1
+        float dij, qij, pij
         float C = 0.0
         float exponent = (dof + 1.0) / -2.0
-    cdef clock_t t1, t2
+        float[3] buff
+        clock_t t1, t2
+
     t1 = clock()
-    for i in range(start, n):
+    for i in range(start, n_samples):
+        # Init the gradient vector
         for ax in range(n_dimensions):
             pos_f[i * n_dimensions + ax] = 0.0
-        for k in range(K):
-            j = neighbors[i, k]
-            # we don't need to exclude the i==j case since we've 
-            # already thrown it out from the list of neighbors
-            D = 0.0
-            Q = 0.0
-            pij = val_P[i, j]
+        # Compute the positive interaction for the nearest neighbors
+        for k in range(indptr[i], indptr[i+1]):
+            j = neighbors[k]
+            dij = 0.0
+            pij = val_P[k]
             for ax in range(n_dimensions):
                 buff[ax] = pos_reference[i, ax] - pos_reference[j, ax]
-                D += buff[ax] ** 2.0  
-            Q = (((1.0 + D) / dof) ** exponent)
-            D = pij * Q
-            Q /= sum_Q
-            C += pij * log((pij + EPSILON) / (Q + EPSILON))
+                dij += buff[ax] * buff[ax]
+            qij = (((1.0 + dij) / dof) ** exponent)
+            dij = pij * qij
+            qij /= sum_Q
+            C += pij * log(max(pij, FLOAT32_TINY)
+                           / max(qij, FLOAT32_TINY))
             for ax in range(n_dimensions):
-                pos_f[i * n_dimensions + ax] += D * buff[ax]
+                pos_f[i * n_dimensions + ax] += dij * buff[ax]
     t2 = clock()
     dt = ((float) (t2 - t1))
     if verbose > 10:
@@ -541,45 +147,32 @@ cdef float compute_gradient_positive(float[:,:] val_P,
     return C
 
 
-
-cdef void compute_gradient_negative(float[:,:] val_P, 
-                                    float[:,:] pos_reference,
+cdef void compute_gradient_negative(float[:, :] pos_reference,
                                     float* neg_f,
-                                    Node *root_node,
-                                    float* sum_Q,
+                                    quad_tree._QuadTree qt,
+                                    double* sum_Q,
                                     float dof,
-                                    float theta, 
-                                    long start, 
+                                    float theta,
+                                    long start,
                                     long stop) nogil:
     if stop == -1:
-        stop = pos_reference.shape[0] 
+        stop = pos_reference.shape[0]
     cdef:
         int ax
-        long i, j
+        int n_dimensions = qt.n_dimensions
+        long i, j, idx
         long n = stop - start
-        float* force
-        float* iQ 
-        float* pos
-        float* dist2s
-        long* sizes
-        float* deltas
-        long* l
-        int n_dimensions = root_node.tree.n_dimensions
-        float qijZ, mult
-        long idx, 
         long dta = 0
         long dtb = 0
+        long offset = n_dimensions + 2
+        long* l
+        float size, dist2s, mult
+        double qijZ
+        float[1] iQ
+        float[3] force, neg_force, pos
         clock_t t1, t2, t3
-        float* neg_force
 
-    iQ = <float*> malloc(sizeof(float))
-    force = <float*> malloc(sizeof(float) * n_dimensions)
-    pos = <float*> malloc(sizeof(float) * n_dimensions)
-    dist2s = <float*> malloc(sizeof(float) * n)
-    sizes = <long*> malloc(sizeof(long) * n)
-    deltas = <float*> malloc(sizeof(float) * n * n_dimensions)
-    l = <long*> malloc(sizeof(long))
-    neg_force= <float*> malloc(sizeof(float) * n_dimensions)
+    summary = <float*> malloc(sizeof(float) * n * offset)
 
     for i in range(start, stop):
         # Clear the arrays
@@ -588,146 +181,44 @@ cdef void compute_gradient_negative(float[:,:] val_P,
             neg_force[ax] = 0.0
             pos[ax] = pos_reference[i, ax]
         iQ[0] = 0.0
-        l[0] = 0
         # Find which nodes are summarizing and collect their centers of mass
         # deltas, and sizes, into vectorized arrays
         t1 = clock()
-        compute_non_edge_forces(root_node, theta, i, pos, force, dist2s,
-                                     sizes, deltas, l)
+        idx = qt.summarize(pos, summary, theta*theta)
         t2 = clock()
         # Compute the t-SNE negative force
         # for the digits dataset, walking the tree
-        # is about 10-15x more expensive than the 
+        # is about 10-15x more expensive than the
         # following for loop
         exponent = (dof + 1.0) / -2.0
-        for j in range(l[0]):
-            qijZ = ((1.0 + dist2s[j]) / dof) ** exponent
-            sum_Q[0] += sizes[j] * qijZ
-            mult = sizes[j] * qijZ * qijZ
+        for j in range(idx // offset):
+
+            dist2s = summary[j * offset + n_dimensions]
+            size = summary[j * offset + n_dimensions + 1]
+            qijZ = ((1.0 + dist2s) / dof) ** exponent  # 1/(1+dist2s) if dof==1
+            sum_Q[0] += size * qijZ   # size of the node * q
+            mult = size * qijZ * qijZ
             for ax in range(n_dimensions):
-                idx = j * n_dimensions + ax
-                neg_force[ax] += mult * deltas[idx]
+                neg_force[ax] += mult * summary[j * offset + ax]
         t3 = clock()
         for ax in range(n_dimensions):
             neg_f[i * n_dimensions + ax] = neg_force[ax]
         dta += t2 - t1
         dtb += t3 - t2
-    if root_node.tree.verbose > 20:
-        printf("[t-SNE] Tree: %i clock ticks | ", dta)
-        printf("Force computation: %i clock ticks\n", dtb)
-    free(iQ)
-    free(force)
-    free(pos)
-    free(dist2s)
-    free(sizes)
-    free(deltas)
-    free(l)
-    free(neg_force)
+    if qt.verbose > 20:
+        printf("[t-SNE] Tree: %li clock ticks | ", dta)
+        printf("Force computation: %li clock ticks\n", dtb)
 
-
-cdef void compute_non_edge_forces(Node* node, 
-                                  float theta,
-                                  long point_index,
-                                  float* pos,
-                                  float* force,
-                                  float* dist2s,
-                                  long* sizes,
-                                  float* deltas,
-                                  long* l) nogil:
-    # Compute the t-SNE force on the point in pos given by point_index
-    cdef:
-        Node* child
-        int i, j
-        int n_dimensions = node.tree.n_dimensions
-        long idx, idx1
-        float dist_check
-    
-    # There are no points below this node if cumulative_size == 0
-    # so do not bother to calculate any force contributions
-    # Also do not compute self-interactions
-    if node.cumulative_size > 0 and not (node.is_leaf and (node.point_index ==
-        point_index)):
-        # Compute distance between node center of mass and the reference point
-        # I've tried rewriting this in terms of BLAS functions, but it's about
-        # 1.5x worse when we do so, probbaly because the vectors are small
-        idx1 = l[0] * n_dimensions
-        deltas[idx1] = pos[0] - node.barycenter[0]
-        idx = idx1
-        for i in range(1, n_dimensions):
-            idx += 1
-            deltas[idx] = pos[i] - node.barycenter[i] 
-        # do np.sqrt(np.sum(deltas**2.0))
-        dist2s[l[0]] = snrm2(n_dimensions, &deltas[idx1], 1)
-        # Check whether we can use this node as a summary
-        # It's a summary node if the angular size as measured from the point
-        # is relatively small (w.r.t. to theta) or if it is a leaf node.
-        # If it can be summarized, we use the cell center of mass 
-        # Otherwise, we go a higher level of resolution and into the leaves.
-        if node.is_leaf or ((node.max_width / dist2s[l[0]]) < theta):
-            # Compute the t-SNE force between the reference point and the
-            # current node
-            sizes[l[0]] = node.cumulative_size
-            dist2s[l[0]] = dist2s[l[0]] * dist2s[l[0]]
-            l[0] += 1
-        else:
-            # Recursively apply Barnes-Hut to child nodes
-            for idx in range(node.tree.n_cell_per_node):
-                child = node.children[idx]
-                if child.cumulative_size == 0: 
-                    continue
-                compute_non_edge_forces(child, theta,
-                        point_index, pos, force, dist2s, sizes, deltas,
-                        l)
+    # Clip sum_Q to machine epsilon to avoid division by zero
+    sum_Q[0] = max(sum_Q[0], FLOAT64_EPS)
+    free(summary)
 
 
-cdef float compute_error(float[:, :] val_P,
-                        float[:, :] pos_reference,
-                        np.int64_t[:,:] neighbors,
-                        float sum_Q,
-                        int n_dimensions,
-                        int verbose) nogil:
-    cdef int i, j, ax
-    cdef int I = neighbors.shape[0]
-    cdef int K = neighbors.shape[1]
-    cdef float pij, Q
-    cdef float C = 0.0
-    cdef clock_t t1, t2
-    cdef float dt, delta
-    t1 = clock()
-    for i in range(I):
-        for k in range(K):
-            j = neighbors[i, k]
-            pij = val_P[i, j]
-            Q = 0.0
-            for ax in range(n_dimensions):
-                delta = (pos_reference[i, ax] - pos_reference[j, ax])
-                Q += delta * delta
-            Q = (1.0 / (sum_Q + Q * sum_Q))
-            C += pij * log((pij + EPSILON) / (Q + EPSILON))
-    t2 = clock()
-    dt = ((float) (t2 - t1))
-    if verbose > 10:
-        printf("[t-SNE] Computed error=%1.4f in %1.1e ticks\n", C, dt)
-    return C
-
-
-def calculate_edge(pos_output):
-    # Make the boundaries slightly outside of the data
-    # to avoid floating point error near the edge
-    left_edge = np.min(pos_output, axis=0)
-    right_edge = np.max(pos_output, axis=0) 
-    center = (right_edge + left_edge) * 0.5
-    width = np.maximum(np.subtract(right_edge, left_edge), EPSILON)
-    # Exagerate width to avoid boundary edge
-    width = width.astype(np.float32) * 1.001
-    left_edge = center - width / 2.0
-    right_edge = center + width / 2.0
-    return left_edge, right_edge, width
-
-def gradient(float[:,:] pij_input, 
-             float[:,:] pos_output, 
-             np.int64_t[:,:] neighbors, 
-             float[:,:] forces, 
+def gradient(float[:] val_P,
+             float[:, :] pos_output,
+             np.int64_t[:] neighbors,
+             np.int64_t[:] indptr,
+             float[:, :] forces,
              float theta,
              int n_dimensions,
              int verbose,
@@ -738,108 +229,32 @@ def gradient(float[:,:] pij_input,
     # up in-place
     cdef float C
     n = pos_output.shape[0]
-    left_edge, right_edge, width = calculate_edge(pos_output)
-    assert width.itemsize == 4
-    assert pij_input.itemsize == 4
+    assert val_P.itemsize == 4
     assert pos_output.itemsize == 4
     assert forces.itemsize == 4
-    m = "Number of neighbors must be < # of points - 1"
-    assert n - 1 >= neighbors.shape[1], m
-    m = "neighbors array and pos_output shapes are incompatible"
-    assert n == neighbors.shape[0], m
     m = "Forces array and pos_output shapes are incompatible"
     assert n == forces.shape[0], m
     m = "Pij and pos_output shapes are incompatible"
-    assert n == pij_input.shape[0], m
-    m = "Pij and pos_output shapes are incompatible"
-    assert n == pij_input.shape[1], m
+    assert n == indptr.shape[0] - 1, m
     if verbose > 10:
         printf("[t-SNE] Initializing tree of n_dimensions %i\n", n_dimensions)
-    cdef Tree* qt = init_tree(left_edge, width, n_dimensions, verbose)
+    cdef quad_tree._QuadTree qt = quad_tree._QuadTree(pos_output.shape[1],
+                                                      verbose)
     if verbose > 10:
-        printf("[t-SNE] Inserting %i points\n", pos_output.shape[0])
-    err = insert_many(qt, pos_output)
-    assert err == 0, "[t-SNE] Insertion failed"
+        printf("[t-SNE] Inserting %li points\n", pos_output.shape[0])
+    qt.build_tree(pos_output)
     if verbose > 10:
         # XXX: format hack to workaround lack of `const char *` type
         # in the generated C code that triggers error with gcc 4.9
         # and -Werror=format-security
         printf("[t-SNE] Computing gradient\n%s", EMPTY_STRING)
-    sum_Q = compute_gradient(pij_input, pos_output, neighbors, forces,
-                             qt.root_node, theta, dof, skip_num_points, -1)
-    C = compute_error(pij_input, pos_output, neighbors, sum_Q, n_dimensions,
-                      verbose)
+    C = compute_gradient(val_P, pos_output, neighbors, indptr, forces,
+                         qt, theta, dof, skip_num_points, -1)
     if verbose > 10:
         # XXX: format hack to workaround lack of `const char *` type
         # in the generated C code
         # and -Werror=format-security
         printf("[t-SNE] Checking tree consistency\n%s", EMPTY_STRING)
-    cdef long count = count_points(qt.root_node, 0)
-    m = ("Tree consistency failed: unexpected number of points=%i "
-         "at root node=%i" % (count, qt.root_node.cumulative_size))
-    assert count == qt.root_node.cumulative_size, m 
     m = "Tree consistency failed: unexpected number of points on the tree"
-    assert count == qt.n_points, m
-    free_tree(qt)
+    assert qt.cells[0].cumulative_size == qt.n_points, m
     return C
-
-
-# Helper functions
-def check_quadtree(X, np.int64_t[:] counts):
-    """
-    Helper function to access quadtree functions for testing
-    """
-    
-    X = X.astype(np.float32)
-    left_edge, right_edge, width = calculate_edge(X)
-    # Initialise a tree
-    qt = init_tree(left_edge, width, 2, 2)
-    # Insert data into the tree
-    insert_many(qt, X)
-
-    cdef long count = count_points(qt.root_node, 0)
-    counts[0] = count
-    counts[1] = qt.root_node.cumulative_size
-    counts[2] = qt.n_points
-    free_tree(qt)
-    return counts
-
-
-cdef int helper_test_index2offset(int* check, int index, int n_dimensions):
-    cdef int* offset = <int*> malloc(sizeof(int) * n_dimensions)
-    cdef int error_check = 1
-    for i in range(n_dimensions):
-        offset[i] = 0
-    index2offset(offset, index, n_dimensions)
-    for i in range(n_dimensions):
-        error_check &= offset[i] == check[i]
-    free(offset)
-    return error_check
-
-
-def test_index2offset():
-    ret = 1
-    ret &= helper_test_index2offset([1, 0, 1], 5, 3) == 1
-    ret &= helper_test_index2offset([0, 0, 0], 0, 3) == 1
-    ret &= helper_test_index2offset([0, 0, 1], 1, 3) == 1
-    ret &= helper_test_index2offset([0, 1, 0], 2, 3) == 1
-    ret &= helper_test_index2offset([0, 1, 1], 3, 3) == 1
-    ret &= helper_test_index2offset([1, 0, 0], 4, 3) == 1
-    return ret
-
-
-def test_index_offset():
-    cdef int n_dimensions, idx, tidx, k
-    cdef int error_check = 1
-    cdef int* offset 
-    for n_dimensions in range(2, 10):
-        offset = <int*> malloc(sizeof(int) * n_dimensions)
-        for k in range(n_dimensions):
-            offset[k] = 0
-        for idx in range(2 ** n_dimensions):
-            index2offset(offset, idx, n_dimensions)
-            tidx = offset2index(offset, n_dimensions)
-            error_check &= tidx == idx
-            assert error_check == 1
-        free(offset)
-    return error_check
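
The rewritten positive term above iterates over the CSR arrays of P
(val_P = P.data, neighbors = P.indices, indptr = P.indptr) instead of a
dense (n_samples, k) neighbor matrix. A pure-Python sketch of that loop
(illustrative, not the compiled code):

    import numpy as np

    def positive_forces(val_P, neighbors, indptr, Y, dof=1.0):
        """Attractive t-SNE forces from a CSR joint-probability matrix."""
        n_samples, n_components = Y.shape
        pos_f = np.zeros_like(Y)
        for i in range(n_samples):
            # Only the stored nearest neighbors of sample i contribute.
            for k in range(indptr[i], indptr[i + 1]):
                j = neighbors[k]
                diff = Y[i] - Y[j]
                dij = np.dot(diff, diff)
                qij = ((1.0 + dij) / dof) ** ((dof + 1.0) / -2.0)
                pos_f[i] += val_P[k] * qij * diff
        return pos_f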
diff --git a/sklearn/manifold/_utils.pyx b/sklearn/manifold/_utils.pyx
index b85da09dba..1f51889a4f 100644
--- a/sklearn/manifold/_utils.pyx
+++ b/sklearn/manifold/_utils.pyx
@@ -12,22 +12,22 @@ cdef float PERPLEXITY_TOLERANCE = 1e-5
 
 @cython.boundscheck(False)
 cpdef np.ndarray[np.float32_t, ndim=2] _binary_search_perplexity(
-        np.ndarray[np.float32_t, ndim=2] affinities, 
-        np.ndarray[np.int64_t, ndim=2] neighbors, 
+        np.ndarray[np.float32_t, ndim=2] affinities,
+        np.ndarray[np.int64_t, ndim=2] neighbors,
         float desired_perplexity,
         int verbose):
-    """Binary search for sigmas of conditional Gaussians. 
-    
+    """Binary search for sigmas of conditional Gaussians.
+
     This approximation reduces the computational complexity from O(N^2) to
     O(uN). See the exact method '_binary_search_perplexity' for more details.
 
     Parameters
     ----------
-    affinities : array-like, shape (n_samples, n_samples)
-        Distances between training samples.
+    affinities : array-like, shape (n_samples, k)
+        Distances between training samples and their k nearest neighbors.
 
-    neighbors : array-like, shape (n_samples, K) or None
-        Each row contains the indices to the K nearest neigbors. If this
+    neighbors : array-like, shape (n_samples, k) or None
+        Each row contains the indices of the k nearest neighbors. If this
         array is None, then the perplexity is estimated over all data
         not just the nearest neighbors.
 
@@ -46,28 +46,30 @@ cpdef np.ndarray[np.float32_t, ndim=2] _binary_search_perplexity(
     cdef long n_steps = 100
 
     cdef long n_samples = affinities.shape[0]
-    # This array is later used as a 32bit array. It has multiple intermediate
-    # floating point additions that benefit from the extra precision
-    cdef np.ndarray[np.float64_t, ndim=2] P = np.zeros((n_samples, n_samples),
-                                                       dtype=np.float64)
-    # Precisions of conditional Gaussian distrubutions
+    # Precisions of conditional Gaussian distributions
     cdef float beta
     cdef float beta_min
     cdef float beta_max
     cdef float beta_sum = 0.0
-    # Now we go to log scale
+
+    # Use log scale
     cdef float desired_entropy = math.log(desired_perplexity)
     cdef float entropy_diff
 
     cdef float entropy
     cdef float sum_Pi
     cdef float sum_disti_Pi
-    cdef long i, j, k, l = 0
-    cdef long K = n_samples
+    cdef long i, j, k, l
+    cdef long n_neighbors = n_samples
     cdef int using_neighbors = neighbors is not None
 
     if using_neighbors:
-        K = neighbors.shape[1]
+        n_neighbors = neighbors.shape[1]
+
+    # This array is later used as a 32bit array. It has multiple intermediate
+    # floating point additions that benefit from the extra precision
+    cdef np.ndarray[np.float64_t, ndim=2] P = np.zeros(
+        (n_samples, n_neighbors), dtype=np.float64)
 
     for i in range(n_samples):
         beta_min = -NPY_INFINITY
@@ -79,34 +81,20 @@ cpdef np.ndarray[np.float32_t, ndim=2] _binary_search_perplexity(
             # Compute current entropy and corresponding probabilities
             # computed just over the nearest neighbors or over all data
             # if we're not using neighbors
-            if using_neighbors:
-                for k in range(K):
-                    j = neighbors[i, k]
-                    P[i, j] = math.exp(-affinities[i, j] * beta)
-            else:
-                for j in range(K):
-                    P[i, j] = math.exp(-affinities[i, j] * beta)
-            P[i, i] = 0.0
             sum_Pi = 0.0
-            if using_neighbors:
-                for k in range(K):
-                    j = neighbors[i, k]
-                    sum_Pi += P[i, j]
-            else:
-                for j in range(K):
+            for j in range(n_neighbors):
+                if j != i or using_neighbors:
+                    P[i, j] = math.exp(-affinities[i, j] * beta)
                     sum_Pi += P[i, j]
+
             if sum_Pi == 0.0:
                 sum_Pi = EPSILON_DBL
             sum_disti_Pi = 0.0
-            if using_neighbors:
-                for k in range(K):
-                    j = neighbors[i, k]
-                    P[i, j] /= sum_Pi
-                    sum_disti_Pi += affinities[i, j] * P[i, j]
-            else:
-                for j in range(K):
-                    P[i, j] /= sum_Pi
-                    sum_disti_Pi += affinities[i, j] * P[i, j]
+
+            for j in range(n_neighbors):
+                P[i, j] /= sum_Pi
+                sum_disti_Pi += affinities[i, j] * P[i, j]
+
             entropy = math.log(sum_Pi) + beta * sum_disti_Pi
             entropy_diff = entropy - desired_entropy
 
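
A NumPy sketch of the per-sample binary search implemented above: find
the precision ``beta`` of the conditional Gaussian so that the entropy
of P_i matches ``log(desired_perplexity)`` (names and the floor value
are illustrative):

    import numpy as np

    def binary_search_beta(dist_i, desired_perplexity, n_steps=100,
                           tol=1e-5):
        """Precision beta for one sample's conditional Gaussian."""
        desired_entropy = np.log(desired_perplexity)
        beta, beta_min, beta_max = 1.0, -np.inf, np.inf
        for _ in range(n_steps):
            P_i = np.exp(-dist_i * beta)
            sum_Pi = max(P_i.sum(), 1e-8)
            P_i /= sum_Pi
            entropy = np.log(sum_Pi) + beta * np.dot(dist_i, P_i)
            entropy_diff = entropy - desired_entropy
            if abs(entropy_diff) <= tol:
                break
            if entropy_diff > 0:
                # Entropy too high: sharpen the Gaussian.
                beta_min = beta
                beta = beta * 2. if beta_max == np.inf else (beta + beta_max) / 2.
            else:
                # Entropy too low: widen the Gaussian.
                beta_max = beta
                beta = beta / 2. if beta_min == -np.inf else (beta + beta_min) / 2.
        return beta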
diff --git a/sklearn/manifold/setup.py b/sklearn/manifold/setup.py
index a2562cd3c0..bec1e25eee 100644
--- a/sklearn/manifold/setup.py
+++ b/sklearn/manifold/setup.py
@@ -31,6 +31,7 @@ def configuration(parent_package="", top_path=None):
 
     return config
 
+
 if __name__ == "__main__":
     from numpy.distutils.core import setup
     setup(**configuration().todict())
diff --git a/sklearn/manifold/t_sne.py b/sklearn/manifold/t_sne.py
index 8d4056627c..163e8340f7 100644
--- a/sklearn/manifold/t_sne.py
+++ b/sklearn/manifold/t_sne.py
@@ -8,12 +8,14 @@
 # * Fast Optimization for t-SNE:
 #   http://cseweb.ucsd.edu/~lvdmaaten/workshops/nips2010/papers/vandermaaten.pdf
 
+from time import time
 import numpy as np
 from scipy import linalg
 import scipy.sparse as sp
 from scipy.spatial.distance import pdist
 from scipy.spatial.distance import squareform
-from ..neighbors import BallTree
+from scipy.sparse import csr_matrix
+from ..neighbors import NearestNeighbors
 from ..base import BaseEstimator
 from ..utils import check_array
 from ..utils import check_random_state
@@ -70,10 +72,11 @@ def _joint_probabilities_nn(distances, neighbors, desired_perplexity, verbose):
 
     Parameters
     ----------
-    distances : array, shape (n_samples * (n_samples-1) / 2,)
-        Distances of samples are stored as condensed matrices, i.e.
-        we omit the diagonal and duplicate entries and store everything
-        in a one-dimensional array.
+    distances : array, shape (n_samples, k)
+        Distances of samples to their k nearest neighbors.
+
+    neighbors : array, shape (n_samples, k)
+        Indices of the k nearest neighbors for each sample.
 
     desired_perplexity : float
         Desired perplexity of the joint probability distributions.
@@ -83,21 +86,35 @@ def _joint_probabilities_nn(distances, neighbors, desired_perplexity, verbose):
 
     Returns
     -------
-    P : array, shape (n_samples * (n_samples-1) / 2,)
-        Condensed joint probability matrix.
+    P : csr sparse matrix, shape (n_samples, n_samples)
+        Joint probability matrix, restricted to the nearest neighbors.
     """
+    t0 = time()
     # Compute conditional probabilities such that they approximately match
     # the desired perplexity
+    n_samples, k = neighbors.shape
     distances = distances.astype(np.float32, copy=False)
     neighbors = neighbors.astype(np.int64, copy=False)
     conditional_P = _utils._binary_search_perplexity(
         distances, neighbors, desired_perplexity, verbose)
-    m = "All probabilities should be finite"
-    assert np.all(np.isfinite(conditional_P)), m
-    P = conditional_P + conditional_P.T
-    sum_P = np.maximum(np.sum(P), MACHINE_EPSILON)
-    P = np.maximum(squareform(P) / sum_P, MACHINE_EPSILON)
-    assert np.all(np.abs(P) <= 1.0)
+    assert np.all(np.isfinite(conditional_P)), \
+        "All probabilities should be finite"
+
+    # Symmetrize the joint probability distribution using sparse operations
+    P = csr_matrix((conditional_P.ravel(), neighbors.ravel(),
+                    range(0, n_samples * k + 1, k)),
+                   shape=(n_samples, n_samples))
+    P = P + P.T
+
+    # Normalize the joint probability distribution
+    sum_P = np.maximum(P.sum(), MACHINE_EPSILON)
+    P /= sum_P
+
+    assert np.all(np.abs(P.data) <= 1.0)
+    if verbose >= 2:
+        duration = time() - t0
+        print("[t-SNE] Computed conditional probabilities in {:.3f}s"
+              .format(duration))
     return P
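
Because each row of `conditional_P` stores exactly k values, the CSR index
pointer is simply `range(0, n_samples * k + 1, k)`. A toy run of the
symmetrization step above (the `conditional_P` and `neighbors` values are made
up for illustration):

    import numpy as np
    from scipy.sparse import csr_matrix

    n_samples, k = 3, 2
    conditional_P = np.array([[.6, .4], [.7, .3], [.5, .5]])  # rows sum to 1
    neighbors = np.array([[1, 2], [0, 2], [0, 1]])

    P = csr_matrix((conditional_P.ravel(), neighbors.ravel(),
                    range(0, n_samples * k + 1, k)),
                   shape=(n_samples, n_samples))
    P = P + P.T    # p_ij + p_ji, still sparse
    P /= P.sum()   # normalize so that all entries sum to one
    assert np.all(np.abs(P.data) <= 1.0)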
 
 
@@ -140,24 +157,25 @@ def _kl_divergence(params, P, degrees_of_freedom, n_samples, n_components,
     X_embedded = params.reshape(n_samples, n_components)
 
     # Q is a heavy-tailed distribution: Student's t-distribution
-    n = pdist(X_embedded, "sqeuclidean")
-    n += 1.
-    n /= degrees_of_freedom
-    n **= (degrees_of_freedom + 1.0) / -2.0
-    Q = np.maximum(n / (2.0 * np.sum(n)), MACHINE_EPSILON)
+    dist = pdist(X_embedded, "sqeuclidean")
+    dist += 1.
+    dist /= degrees_of_freedom
+    dist **= (degrees_of_freedom + 1.0) / -2.0
+    Q = np.maximum(dist / (2.0 * np.sum(dist)), MACHINE_EPSILON)
 
     # Optimization trick below: np.dot(x, y) is faster than
     # np.sum(x * y) because it calls BLAS
 
     # Objective: C (Kullback-Leibler divergence of P and Q)
-    kl_divergence = 2.0 * np.dot(P, np.log(P / Q))
+    kl_divergence = 2.0 * np.dot(P, np.log(np.maximum(P, MACHINE_EPSILON) / Q))
 
     # Gradient: dC/dY
-    grad = np.ndarray((n_samples, n_components))
-    PQd = squareform((P - Q) * n)
+    # pdist always returns double precision distances. Thus we need to take
+    # care to allocate the gradient with the dtype of the embedding (params).
+    grad = np.ndarray((n_samples, n_components), dtype=params.dtype)
+    PQd = squareform((P - Q) * dist)
     for i in range(skip_num_points, n_samples):
-        np.dot(np.ravel(PQd[i], order='K'), X_embedded[i] - X_embedded,
-               out=grad[i])
+        grad[i] = np.dot(np.ravel(PQd[i], order='K'),
+                         X_embedded[i] - X_embedded)
     grad = grad.ravel()
     c = 2.0 * (degrees_of_freedom + 1.0) / degrees_of_freedom
     grad *= c
@@ -165,65 +183,8 @@ def _kl_divergence(params, P, degrees_of_freedom, n_samples, n_components,
     return kl_divergence, grad
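
To make the objective concrete, the Student-t kernel Q and the KL objective
can be evaluated on a toy embedding. This is an illustrative sketch: `eps`
stands in for MACHINE_EPSILON and the uniform condensed `P` is made up.

    import numpy as np
    from scipy.spatial.distance import pdist

    eps = np.finfo(np.double).eps
    X_embedded = np.array([[0., 0.], [1., 0.], [0., 2.]])
    degrees_of_freedom = 1.0  # max(n_components - 1, 1) with n_components=2

    dist = pdist(X_embedded, "sqeuclidean")
    dist += 1.
    dist /= degrees_of_freedom
    dist **= (degrees_of_freedom + 1.0) / -2.0
    Q = np.maximum(dist / (2.0 * np.sum(dist)), eps)

    # A condensed P summing to 0.5 corresponds to a full symmetric matrix
    # summing to 1, so the doubled dot product is a true KL divergence (>= 0).
    P = np.full_like(dist, 0.5 / dist.shape[0])
    kl = 2.0 * np.dot(P, np.log(np.maximum(P, eps) / Q))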
 
 
-def _kl_divergence_error(params, P, neighbors, degrees_of_freedom, n_samples,
-                         n_components):
-    """t-SNE objective function: the absolute error of the
-    KL divergence of p_ijs and q_ijs.
-
-    Parameters
-    ----------
-    params : array, shape (n_params,)
-        Unraveled embedding.
-
-    P : array, shape (n_samples * (n_samples-1) / 2,)
-        Condensed joint probability matrix.
-
-    neighbors : array (n_samples, K)
-        The neighbors is not actually required to calculate the
-        divergence, but is here to match the signature of the
-        gradient function
-
-    degrees_of_freedom : float
-        Degrees of freedom of the Student's-t distribution.
-
-    n_samples : int
-        Number of samples.
-
-    n_components : int
-        Dimension of the embedded space.
-
-    Returns
-    -------
-    kl_divergence : float
-        Kullback-Leibler divergence of p_ij and q_ij.
-
-    grad : array, shape (n_params,)
-        Unraveled gradient of the Kullback-Leibler divergence with respect to
-        the embedding.
-    """
-    X_embedded = params.reshape(n_samples, n_components)
-
-    # Q is a heavy-tailed distribution: Student's t-distribution
-    n = pdist(X_embedded, "sqeuclidean")
-    n += 1.
-    n /= degrees_of_freedom
-    n **= (degrees_of_freedom + 1.0) / -2.0
-    Q = np.maximum(n / (2.0 * np.sum(n)), MACHINE_EPSILON)
-
-    # Optimization trick below: np.dot(x, y) is faster than
-    # np.sum(x * y) because it calls BLAS
-
-    # Objective: C (Kullback-Leibler divergence of P and Q)
-    if len(P.shape) == 2:
-        P = squareform(P)
-    kl_divergence = 2.0 * np.dot(P, np.log(P / Q))
-
-    return kl_divergence
-
-
-def _kl_divergence_bh(params, P, neighbors, degrees_of_freedom, n_samples,
-                      n_components, angle=0.5, skip_num_points=0,
-                      verbose=False):
+def _kl_divergence_bh(params, P, degrees_of_freedom, n_samples, n_components,
+                      angle=0.5, skip_num_points=0, verbose=False):
     """t-SNE objective function: KL divergence of p_ijs and q_ijs.
 
     Uses Barnes-Hut tree methods to calculate the gradient that
@@ -234,12 +195,9 @@ def _kl_divergence_bh(params, P, neighbors, degrees_of_freedom, n_samples,
     params : array, shape (n_params,)
         Unraveled embedding.
 
-    P : array, shape (n_samples * (n_samples-1) / 2,)
-        Condensed joint probability matrix.
-
-    neighbors : int64 array, shape (n_samples, K)
-        Array with element [i, j] giving the index for the jth
-        closest neighbor to point i.
+    P : csr sparse matrix, shape (n_samples, n_samples)
+        Sparse approximate joint probability matrix, computed only for the
+        k nearest neighbors and symmetrized.
 
     degrees_of_freedom : float
         Degrees of freedom of the Student's-t distribution.
@@ -278,14 +236,13 @@ def _kl_divergence_bh(params, P, neighbors, degrees_of_freedom, n_samples,
     """
     params = params.astype(np.float32, copy=False)
     X_embedded = params.reshape(n_samples, n_components)
-    neighbors = neighbors.astype(np.int64, copy=False)
-    if len(P.shape) == 1:
-        sP = squareform(P).astype(np.float32)
-    else:
-        sP = P.astype(np.float32)
+
+    val_P = P.data.astype(np.float32, copy=False)
+    neighbors = P.indices.astype(np.int64, copy=False)
+    indptr = P.indptr.astype(np.int64, copy=False)
 
     grad = np.zeros(X_embedded.shape, dtype=np.float32)
-    error = _barnes_hut_tsne.gradient(sP, X_embedded, neighbors,
+    error = _barnes_hut_tsne.gradient(val_P, X_embedded, neighbors, indptr,
                                       grad, angle, n_components, verbose,
                                       dof=degrees_of_freedom)
     c = 2.0 * (degrees_of_freedom + 1.0) / degrees_of_freedom
@@ -295,11 +252,10 @@ def _kl_divergence_bh(params, P, neighbors, degrees_of_freedom, n_samples,
     return error, grad
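
The Cython gradient now consumes the three raw CSR buffers of P directly. For
readers less familiar with the CSR layout, the decomposition above amounts to
the following (toy matrix for illustration):

    import numpy as np
    from scipy.sparse import csr_matrix

    P = csr_matrix(np.array([[0., .3, .2],
                             [.3, 0., .1],
                             [.2, .1, 0.]]))
    val_P = P.data.astype(np.float32, copy=False)       # nonzero values
    neighbors = P.indices.astype(np.int64, copy=False)  # their column indices
    indptr = P.indptr.astype(np.int64, copy=False)

    # Row i owns the slice indptr[i]:indptr[i + 1] of val_P and neighbors.
    row = slice(indptr[1], indptr[2])
    print(neighbors[row], val_P[row])  # neighbors of sample 1: [0 2] [0.3 0.1]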
 
 
-def _gradient_descent(objective, p0, it, n_iter, objective_error=None,
-                      n_iter_check=1, n_iter_without_progress=50,
-                      momentum=0.5, learning_rate=1000.0, min_gain=0.01,
-                      min_grad_norm=1e-7, min_error_diff=1e-7, verbose=0,
-                      args=None, kwargs=None):
+def _gradient_descent(objective, p0, it, n_iter,
+                      n_iter_check=1, n_iter_without_progress=300,
+                      momentum=0.8, learning_rate=200.0, min_gain=0.01,
+                      min_grad_norm=1e-7, verbose=0, args=None, kwargs=None):
     """Batch gradient descent with momentum and individual gains.
 
     Parameters
@@ -324,21 +280,20 @@ def _gradient_descent(objective, p0, it, n_iter, objective_error=None,
         Number of iterations before evaluating the global error. If the error
         is sufficiently low, we abort the optimization.
 
-    objective_error : function or callable
-        Should return a tuple of cost and gradient for a given parameter
-        vector.
-
-    n_iter_without_progress : int, optional (default: 30)
+    n_iter_without_progress : int, optional (default: 300)
         Maximum number of iterations without progress before we abort the
         optimization.
 
-    momentum : float, within (0.0, 1.0), optional (default: 0.5)
+    momentum : float, within (0.0, 1.0), optional (default: 0.8)
         The momentum generates a weight for previous gradients that decays
         exponentially.
 
-    learning_rate : float, optional (default: 1000.0)
-        The learning rate should be extremely high for t-SNE! Values in the
-        range [100.0, 1000.0] are common.
+    learning_rate : float, optional (default: 200.0)
+        The learning rate for t-SNE is usually in the range [10.0, 1000.0]. If
+        the learning rate is too high, the data may look like a 'ball' with any
+        point approximately equidistant from its nearest neighbours. If the
+        learning rate is too low, most points may look compressed in a dense
+        cloud with few outliers.
 
     min_gain : float, optional (default: 0.01)
         Minimum individual gain for each parameter.
@@ -347,10 +302,6 @@ def _gradient_descent(objective, p0, it, n_iter, objective_error=None,
         If the gradient norm is below this threshold, the optimization will
         be aborted.
 
-    min_error_diff : float, optional (default: 1e-7)
-        If the absolute difference of two successive cost function values
-        is below this threshold, the optimization will be aborted.
-
     verbose : int, optional (default: 0)
         Verbosity level.
 
@@ -381,10 +332,11 @@ def _gradient_descent(objective, p0, it, n_iter, objective_error=None,
     gains = np.ones_like(p)
     error = np.finfo(np.float).max
     best_error = np.finfo(np.float).max
-    best_iter = 0
+    best_iter = i = it
 
+    tic = time()
     for i in range(it, n_iter):
-        new_error, grad = objective(p, *args, **kwargs)
+        error, grad = objective(p, *args, **kwargs)
         grad_norm = linalg.norm(grad)
 
         inc = update * grad < 0.0
@@ -397,14 +349,15 @@ def _gradient_descent(objective, p0, it, n_iter, objective_error=None,
         p += update
 
         if (i + 1) % n_iter_check == 0:
-            if new_error is None:
-                new_error = objective_error(p, *args)
-            error_diff = np.abs(new_error - error)
-            error = new_error
+            toc = time()
+            duration = toc - tic
+            tic = toc
 
             if verbose >= 2:
-                m = "[t-SNE] Iteration %d: error = %.7f, gradient norm = %.7f"
-                print(m % (i + 1, error, grad_norm))
+                print("[t-SNE] Iteration %d: error = %.7f,"
+                      " gradient norm = %.7f"
+                      " (%s iterations in %0.3fs)"
+                      % (i + 1, error, grad_norm, n_iter_check, duration))
 
             if error < best_error:
                 best_error = error
@@ -420,14 +373,6 @@ def _gradient_descent(objective, p0, it, n_iter, objective_error=None,
                     print("[t-SNE] Iteration %d: gradient norm %f. Finished."
                           % (i + 1, grad_norm))
                 break
-            if error_diff <= min_error_diff:
-                if verbose >= 2:
-                    m = "[t-SNE] Iteration %d: error difference %f. Finished."
-                    print(m % (i + 1, error_diff))
-                break
-
-        if new_error is not None:
-            error = new_error
 
     return p, error, i
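
A single update of this optimizer, isolated from the loop above, looks like
the sketch below. The 0.2 / 0.8 gain adjustments follow the standard t-SNE
implementations; they sit in lines elided from the hunks shown here, so treat
the exact constants as an assumption.

    import numpy as np

    def one_step(p, grad, update, gains, momentum=0.8,
                 learning_rate=200.0, min_gain=0.01):
        # Boost the gain of coordinates whose gradient flipped sign with
        # respect to the running update; dampen the others.
        inc = update * grad < 0.0
        gains[inc] += 0.2
        gains[~inc] *= 0.8
        np.clip(gains, min_gain, np.inf, out=gains)
        grad *= gains
        update = momentum * update - learning_rate * grad
        p += update
        return p, update, gains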
 
@@ -525,7 +470,7 @@ class TSNE(BaseEstimator):
         between 5 and 50. The choice is not extremely critical since t-SNE
         is quite insensitive to this parameter.
 
-    early_exaggeration : float, optional (default: 4.0)
+    early_exaggeration : float, optional (default: 12.0)
         Controls how tight natural clusters in the original space are in
         the embedded space and how much space will be between them. For
         larger values, the space between natural clusters will be larger
@@ -534,31 +479,30 @@ class TSNE(BaseEstimator):
         optimization, the early exaggeration factor or the learning rate
         might be too high.
 
-    learning_rate : float, optional (default: 1000)
-        The learning rate can be a critical parameter. It should be
-        between 100 and 1000. If the cost function increases during initial
-        optimization, the early exaggeration factor or the learning rate
-        might be too high. If the cost function gets stuck in a bad local
-        minimum increasing the learning rate helps sometimes.
+    learning_rate : float, optional (default: 200.0)
+        The learning rate for t-SNE is usually in the range [10.0, 1000.0]. If
+        the learning rate is too high, the data may look like a 'ball' with any
+        point approximately equidistant from its nearest neighbours. If the
+        learning rate is too low, most points may look compressed in a dense
+        cloud with few outliers. If the cost function gets stuck in a bad local
+        minimum increasing the learning rate may help.
 
     n_iter : int, optional (default: 1000)
         Maximum number of iterations for the optimization. Should be at
-        least 200.
+        least 250.
 
-    n_iter_without_progress : int, optional (default: 30)
-        Only used if method='exact'
+    n_iter_without_progress : int, optional (default: 300)
         Maximum number of iterations without progress before we abort the
-        optimization. If method='barnes_hut' this parameter is fixed to
-        a value of 30 and cannot be changed.
+        optimization, used after 250 initial iterations with early
+        exaggeration. Note that progress is only checked every 50 iterations so
+        this value is rounded to the next multiple of 50.
 
         .. versionadded:: 0.17
            parameter *n_iter_without_progress* to control stopping criteria.
 
     min_grad_norm : float, optional (default: 1e-7)
-        Only used if method='exact'
         If the gradient norm is below this threshold, the optimization will
-        be aborted. If method='barnes_hut' this parameter is fixed to a value
-        of 1e-3 and cannot be changed.
+        be stopped.
 
     metric : string or callable, optional
         The metric to use when calculating distance between instances in a
@@ -609,7 +553,6 @@ class TSNE(BaseEstimator):
         in the range of 0.2 - 0.8. Angle less than 0.2 has quickly increasing
         computation time and angle greater 0.8 has quickly increasing error.
 
-
     Attributes
     ----------
     embedding_ : array-like, shape (n_samples, n_components)
@@ -627,13 +570,9 @@ class TSNE(BaseEstimator):
     >>> import numpy as np
     >>> from sklearn.manifold import TSNE
     >>> X = np.array([[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 1]])
-    >>> model = TSNE(n_components=2, random_state=0)
-    >>> np.set_printoptions(suppress=True)
-    >>> model.fit_transform(X) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
-    array([[ 0.00017619,  0.00004014],
-           [ 0.00010268,  0.00020546],
-           [ 0.00018298, -0.00008335],
-           [ 0.00009501, -0.00001388]])
+    >>> X_embedded = TSNE(n_components=2).fit_transform(X)
+    >>> X_embedded.shape
+    (4, 2)
 
     References
     ----------
@@ -648,17 +587,17 @@ class TSNE(BaseEstimator):
         Journal of Machine Learning Research 15(Oct):3221-3245, 2014.
         http://lvdmaaten.github.io/publications/papers/JMLR_2014.pdf
     """
+    # Control the number of exploration iterations with early_exaggeration on
+    _EXPLORATION_N_ITER = 250
+
+    # Control the number of iterations between progress checks
+    _N_ITER_CHECK = 50
 
     def __init__(self, n_components=2, perplexity=30.0,
-                 early_exaggeration=4.0, learning_rate=1000.0, n_iter=1000,
-                 n_iter_without_progress=30, min_grad_norm=1e-7,
+                 early_exaggeration=12.0, learning_rate=200.0, n_iter=1000,
+                 n_iter_without_progress=300, min_grad_norm=1e-7,
                  metric="euclidean", init="random", verbose=0,
                  random_state=None, method='barnes_hut', angle=0.5):
-        if not ((isinstance(init, string_types) and
-                init in ["pca", "random"]) or
-                isinstance(init, np.ndarray)):
-            msg = "'init' must be 'pca', 'random', or a numpy array"
-            raise ValueError(msg)
         self.n_components = n_components
         self.perplexity = perplexity
         self.early_exaggeration = early_exaggeration
@@ -699,6 +638,16 @@ class TSNE(BaseEstimator):
             raise ValueError("'method' must be 'barnes_hut' or 'exact'")
         if self.angle < 0.0 or self.angle > 1.0:
             raise ValueError("'angle' must be between 0.0 - 1.0")
+        if self.metric == "precomputed":
+            if isinstance(self.init, string_types) and self.init == 'pca':
+                raise ValueError("The parameter init=\"pca\" cannot be "
+                                 "used with metric=\"precomputed\".")
+            if X.shape[0] != X.shape[1]:
+                raise ValueError("X should be a square distance matrix")
+            if np.any(X < 0):
+                raise ValueError("All distances should be positive, the "
+                                 "precomputed distances given as X are not "
+                                 "correct")
         if self.method == 'barnes_hut' and sp.issparse(X):
             raise TypeError('A sparse matrix was passed, but dense '
                             'data is required for method="barnes_hut". Use '
@@ -708,84 +657,115 @@ class TSNE(BaseEstimator):
                             'reduction techniques (e.g. TruncatedSVD)')
         else:
             X = check_array(X, accept_sparse=['csr', 'csc', 'coo'],
-                            dtype=np.float64)
+                            dtype=[np.float32, np.float64])
+        if self.method == 'barnes_hut' and self.n_components > 3:
+            raise ValueError("'n_components' should be less than 4 for the "
+                             "barnes_hut algorithm as it relies on "
+                             "quad-tree or oct-tree.")
         random_state = check_random_state(self.random_state)
 
         if self.early_exaggeration < 1.0:
-            raise ValueError("early_exaggeration must be at least 1, but is "
-                             "%f" % self.early_exaggeration)
+            raise ValueError("early_exaggeration must be at least 1, but is {}"
+                             .format(self.early_exaggeration))
 
-        if self.n_iter < 200:
-            raise ValueError("n_iter should be at least 200")
+        if self.n_iter < 250:
+            raise ValueError("n_iter should be at least 250")
 
-        if self.metric == "precomputed":
-            if isinstance(self.init, string_types) and self.init == 'pca':
-                raise ValueError("The parameter init=\"pca\" cannot be used "
-                                 "with metric=\"precomputed\".")
-            if X.shape[0] != X.shape[1]:
-                raise ValueError("X should be a square distance matrix")
-            distances = X
-        else:
-            if self.verbose:
-                print("[t-SNE] Computing pairwise distances...")
+        n_samples = X.shape[0]
 
-            if self.metric == "euclidean":
-                distances = pairwise_distances(X, metric=self.metric,
-                                               squared=True)
+        neighbors_nn = None
+        if self.method == "exact":
+            # Retrieve the distance matrix, either using the precomputed one or
+            # computing it.
+            if self.metric == "precomputed":
+                distances = X
             else:
-                distances = pairwise_distances(X, metric=self.metric)
+                if self.verbose:
+                    print("[t-SNE] Computing pairwise distances...")
 
-        if not np.all(distances >= 0):
-            raise ValueError("All distances should be positive, either "
-                             "the metric or precomputed distances given "
-                             "as X are not correct")
+                if self.metric == "euclidean":
+                    distances = pairwise_distances(X, metric=self.metric,
+                                                   squared=True)
+                else:
+                    distances = pairwise_distances(X, metric=self.metric)
 
-        # Degrees of freedom of the Student's t-distribution. The suggestion
-        # degrees_of_freedom = n_components - 1 comes from
-        # "Learning a Parametric Embedding by Preserving Local Structure"
-        # Laurens van der Maaten, 2009.
-        degrees_of_freedom = max(self.n_components - 1.0, 1)
-        n_samples = X.shape[0]
-        # the number of nearest neighbors to find
-        k = min(n_samples - 1, int(3. * self.perplexity + 1))
+                if np.any(distances < 0):
+                    raise ValueError("All distances should be positive, the "
+                                     "metric given is not correct")
+
+            # compute the joint probability distribution for the input space
+            P = _joint_probabilities(distances, self.perplexity, self.verbose)
+            assert np.all(np.isfinite(P)), "All probabilities should be finite"
+            assert np.all(P >= 0), "All probabilities should be non-negative"
+            assert np.all(P <= 1), ("All probabilities should be less "
+                                    "than or equal to one")
+
+        else:
+            # Compute the number of nearest neighbors to find.
+            # LvdM uses 3 * perplexity as the number of neighbors.
+            # In the event that we have a very small number of points,
+            # set the number of neighbors to n - 1.
+            k = min(n_samples - 1, int(3. * self.perplexity + 1))
 
-        neighbors_nn = None
-        if self.method == 'barnes_hut':
             if self.verbose:
-                print("[t-SNE] Computing %i nearest neighbors..." % k)
-            if self.metric == 'precomputed':
-                # Use the precomputed distances to find
-                # the k nearest neighbors and their distances
-                neighbors_nn = np.argsort(distances, axis=1)[:, :k]
-            else:
-                # Find the nearest neighbors for every point
-                bt = BallTree(X)
-                # LvdM uses 3 * perplexity as the number of neighbors
-                # And we add one to not count the data point itself
-                # In the event that we have very small # of points
-                # set the neighbors to n - 1
-                distances_nn, neighbors_nn = bt.query(X, k=k + 1)
-                neighbors_nn = neighbors_nn[:, 1:]
-            P = _joint_probabilities_nn(distances, neighbors_nn,
+                print("[t-SNE] Computing {} nearest neighbors...".format(k))
+
+            # Find the nearest neighbors for every point
+            neighbors_method = 'ball_tree'
+            if self.metric == 'precomputed':
+                neighbors_method = 'brute'
+            knn = NearestNeighbors(algorithm=neighbors_method, n_neighbors=k,
+                                   metric=self.metric)
+            t0 = time()
+            knn.fit(X)
+            duration = time() - t0
+            if self.verbose:
+                print("[t-SNE] Indexed {} samples in {:.3f}s...".format(
+                    n_samples, duration))
+
+            t0 = time()
+            distances_nn, neighbors_nn = knn.kneighbors(
+                None, n_neighbors=k)
+            duration = time() - t0
+            if self.verbose:
+                print("[t-SNE] Computed neighbors for {} samples in {:.3f}s..."
+                      .format(n_samples, duration))
+
+            # Free the memory used by the ball_tree
+            del knn
+
+            if self.metric == "euclidean":
+                # knn returns the euclidean distance but we need it squared
+                # to be consistent with the 'exact' method. Note that the
+                # method was derived using the euclidean metric in the input
+                # space. We are unsure of the implications of using a
+                # different metric.
+                distances_nn **= 2
+
+            # compute the joint probability distribution for the input space
+            P = _joint_probabilities_nn(distances_nn, neighbors_nn,
                                         self.perplexity, self.verbose)
-        else:
-            P = _joint_probabilities(distances, self.perplexity, self.verbose)
-        assert np.all(np.isfinite(P)), "All probabilities should be finite"
-        assert np.all(P >= 0), "All probabilities should be zero or positive"
-        assert np.all(P <= 1), ("All probabilities should be less "
-                                "or then equal to one")
 
         if isinstance(self.init, np.ndarray):
             X_embedded = self.init
         elif self.init == 'pca':
             pca = PCA(n_components=self.n_components, svd_solver='randomized',
                       random_state=random_state)
-            X_embedded = pca.fit_transform(X)
+            X_embedded = pca.fit_transform(X).astype(np.float32, copy=False)
         elif self.init == 'random':
-            X_embedded = None
+            # The embedding is initialized with iid samples from Gaussians with
+            # standard deviation 1e-4.
+            X_embedded = 1e-4 * random_state.randn(
+                n_samples, self.n_components).astype(np.float32)
         else:
-            raise ValueError("Unsupported initialization scheme: %s"
-                             % self.init)
+            raise ValueError("'init' must be 'pca', 'random', or "
+                             "a numpy array")
+
+        # Degrees of freedom of the Student's t-distribution. The suggestion
+        # degrees_of_freedom = n_components - 1 comes from
+        # "Learning a Parametric Embedding by Preserving Local Structure"
+        # Laurens van der Maaten, 2009.
+        degrees_of_freedom = max(self.n_components - 1.0, 1)
 
         return self._tsne(P, degrees_of_freedom, n_samples, random_state,
                           X_embedded=X_embedded,
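
As a quick check of the neighbor budget used by this path, the default
perplexity of 30 gives 91 neighbors per sample, which is what keeps the sparse
P at O(k * n_samples) entries instead of O(n_samples ** 2):

    >>> n_samples, perplexity = 10000, 30.0
    >>> min(n_samples - 1, int(3. * perplexity + 1))
    91
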
@@ -798,75 +778,59 @@ class TSNE(BaseEstimator):
     def n_iter_final(self):
         return self.n_iter_
 
-    def _tsne(self, P, degrees_of_freedom, n_samples, random_state,
-              X_embedded=None, neighbors=None, skip_num_points=0):
+    def _tsne(self, P, degrees_of_freedom, n_samples, random_state, X_embedded,
+              neighbors=None, skip_num_points=0):
         """Runs t-SNE."""
         # t-SNE minimizes the Kullback-Leibler divergence of the Gaussians P
         # and the Student's t-distributions Q. The optimization algorithm that
-        # we use is batch gradient descent with three stages:
-        # * early exaggeration with momentum 0.5
-        # * early exaggeration with momentum 0.8
-        # * final optimization with momentum 0.8
-        # The embedding is initialized with iid samples from Gaussians with
-        # standard deviation 1e-4.
-
-        if X_embedded is None:
-            # Initialize embedding randomly
-            X_embedded = 1e-4 * random_state.randn(n_samples,
-                                                   self.n_components)
+        # we use is batch gradient descent with two stages:
+        # * initial optimization with early exaggeration and momentum at 0.5
+        # * final optimization with momentum at 0.8
         params = X_embedded.ravel()
 
-        opt_args = {"n_iter": 50, "momentum": 0.5, "it": 0,
-                    "learning_rate": self.learning_rate,
-                    "n_iter_without_progress": self.n_iter_without_progress,
-                    "verbose": self.verbose, "n_iter_check": 25,
-                    "kwargs": dict(skip_num_points=skip_num_points)}
+        opt_args = {
+            "it": 0,
+            "n_iter_check": self._N_ITER_CHECK,
+            "min_grad_norm": self.min_grad_norm,
+            "learning_rate": self.learning_rate,
+            "verbose": self.verbose,
+            "kwargs": dict(skip_num_points=skip_num_points),
+            "args": [P, degrees_of_freedom, n_samples, self.n_components],
+            "n_iter_without_progress": self._EXPLORATION_N_ITER,
+            "n_iter": self._EXPLORATION_N_ITER,
+            "momentum": 0.5,
+        }
         if self.method == 'barnes_hut':
-            m = "Must provide an array of neighbors to use Barnes-Hut"
-            assert neighbors is not None, m
             obj_func = _kl_divergence_bh
-            objective_error = _kl_divergence_error
-            sP = squareform(P).astype(np.float32)
-            neighbors = neighbors.astype(np.int64)
-            args = [sP, neighbors, degrees_of_freedom, n_samples,
-                    self.n_components]
-            opt_args['args'] = args
-            opt_args['min_grad_norm'] = 1e-3
-            opt_args['n_iter_without_progress'] = 30
-            # Don't always calculate the cost since that calculation
-            # can be nearly as expensive as the gradient
-            opt_args['objective_error'] = objective_error
             opt_args['kwargs']['angle'] = self.angle
+            # Repeat verbose argument for _kl_divergence_bh
             opt_args['kwargs']['verbose'] = self.verbose
         else:
             obj_func = _kl_divergence
-            opt_args['args'] = [P, degrees_of_freedom, n_samples,
-                                self.n_components]
-            opt_args['min_error_diff'] = 0.0
-            opt_args['min_grad_norm'] = self.min_grad_norm
 
-        # Early exaggeration
+        # Learning schedule (part 1): do 250 iterations with lower momentum but
+        # higher learning rate controlled via the early exaggeration parameter
         P *= self.early_exaggeration
-
-        params, kl_divergence, it = _gradient_descent(obj_func, params,
-                                                      **opt_args)
-        opt_args['n_iter'] = 100
-        opt_args['momentum'] = 0.8
-        opt_args['it'] = it + 1
         params, kl_divergence, it = _gradient_descent(obj_func, params,
                                                       **opt_args)
         if self.verbose:
             print("[t-SNE] KL divergence after %d iterations with early "
                   "exaggeration: %f" % (it + 1, kl_divergence))
-        # Save the final number of iterations
-        self.n_iter_ = it
 
-        # Final optimization
+        # Learning schedule (part 2): disable early exaggeration and finish
+        # optimization with a higher momentum at 0.8
         P /= self.early_exaggeration
-        opt_args['n_iter'] = self.n_iter
-        opt_args['it'] = it + 1
-        params, kl_divergence, it = _gradient_descent(obj_func, params,
-                                                      **opt_args)
+        remaining = self.n_iter - self._EXPLORATION_N_ITER
+        if it < self._EXPLORATION_N_ITER or remaining > 0:
+            opt_args['n_iter'] = self.n_iter
+            opt_args['it'] = it + 1
+            opt_args['momentum'] = 0.8
+            opt_args['n_iter_without_progress'] = self.n_iter_without_progress
+            params, kl_divergence, it = _gradient_descent(obj_func, params,
+                                                          **opt_args)
+
+        # Save the final number of iterations
+        self.n_iter_ = it
 
         if self.verbose:
             print("[t-SNE] Error after %d iterations: %f"
diff --git a/sklearn/manifold/tests/test_t_sne.py b/sklearn/manifold/tests/test_t_sne.py
index 52c056a5ad..2311b48ee2 100644
--- a/sklearn/manifold/tests/test_t_sne.py
+++ b/sklearn/manifold/tests/test_t_sne.py
@@ -4,12 +4,14 @@ import numpy as np
 import scipy.sparse as sp
 
 from sklearn.neighbors import BallTree
+from sklearn.neighbors import NearestNeighbors
 from sklearn.utils.testing import assert_less_equal
 from sklearn.utils.testing import assert_equal
 from sklearn.utils.testing import assert_almost_equal
 from sklearn.utils.testing import assert_array_equal
 from sklearn.utils.testing import assert_array_almost_equal
 from sklearn.utils.testing import assert_less
+from sklearn.utils.testing import assert_greater
 from sklearn.utils.testing import assert_raises_regexp
 from sklearn.utils.testing import assert_in
 from sklearn.utils.testing import skip_if_32bit
@@ -30,6 +32,14 @@ from scipy.spatial.distance import squareform
 from sklearn.metrics.pairwise import pairwise_distances
 
 
+x = np.linspace(0, 1, 10)
+xx, yy = np.meshgrid(x, x)
+X_2d_grid = np.hstack([
+    xx.ravel().reshape(-1, 1),
+    yy.ravel().reshape(-1, 1),
+])
+
+
 def test_gradient_descent_stops():
     # Test stopping conditions of gradient descent.
     class ObjectiveSmallGradient:
@@ -50,7 +60,7 @@ def test_gradient_descent_stops():
         _, error, it = _gradient_descent(
             ObjectiveSmallGradient(), np.zeros(1), 0, n_iter=100,
             n_iter_without_progress=100, momentum=0.0, learning_rate=0.0,
-            min_gain=0.0, min_grad_norm=1e-5, min_error_diff=0.0, verbose=2)
+            min_gain=0.0, min_grad_norm=1e-5, verbose=2)
     finally:
         out = sys.stdout.getvalue()
         sys.stdout.close()
@@ -59,22 +69,6 @@ def test_gradient_descent_stops():
     assert_equal(it, 0)
     assert("gradient norm" in out)
 
-    # Error difference
-    old_stdout = sys.stdout
-    sys.stdout = StringIO()
-    try:
-        _, error, it = _gradient_descent(
-            ObjectiveSmallGradient(), np.zeros(1), 0, n_iter=100,
-            n_iter_without_progress=100, momentum=0.0, learning_rate=0.0,
-            min_gain=0.0, min_grad_norm=0.0, min_error_diff=0.2, verbose=2)
-    finally:
-        out = sys.stdout.getvalue()
-        sys.stdout.close()
-        sys.stdout = old_stdout
-    assert_equal(error, 0.9)
-    assert_equal(it, 1)
-    assert("error difference" in out)
-
     # Maximum number of iterations without improvement
     old_stdout = sys.stdout
     sys.stdout = StringIO()
@@ -82,7 +76,7 @@ def test_gradient_descent_stops():
         _, error, it = _gradient_descent(
             flat_function, np.zeros(1), 0, n_iter=100,
             n_iter_without_progress=10, momentum=0.0, learning_rate=0.0,
-            min_gain=0.0, min_grad_norm=0.0, min_error_diff=-1.0, verbose=2)
+            min_gain=0.0, min_grad_norm=0.0, verbose=2)
     finally:
         out = sys.stdout.getvalue()
         sys.stdout.close()
@@ -98,7 +92,7 @@ def test_gradient_descent_stops():
         _, error, it = _gradient_descent(
             ObjectiveSmallGradient(), np.zeros(1), 0, n_iter=11,
             n_iter_without_progress=100, momentum=0.0, learning_rate=0.0,
-            min_gain=0.0, min_grad_norm=0.0, min_error_diff=0.0, verbose=2)
+            min_gain=0.0, min_grad_norm=0.0, verbose=2)
     finally:
         out = sys.stdout.getvalue()
         sys.stdout.close()
@@ -140,20 +134,26 @@ def test_binary_search_neighbors():
 
     # Test that when we use all the neighbors the results are identical
     k = n_samples
-    neighbors_nn = np.argsort(distances, axis=1)[:, :k].astype(np.int64)
-    P2 = _binary_search_perplexity(distances, neighbors_nn,
+    neighbors_nn = np.argsort(distances, axis=1)[:, 1:k].astype(np.int64)
+    distances_nn = np.array([distances[k, neighbors_nn[k]]
+                            for k in range(n_samples)])
+    P2 = _binary_search_perplexity(distances_nn, neighbors_nn,
                                    desired_perplexity, verbose=0)
-    assert_array_almost_equal(P1, P2, decimal=4)
+    P_nn = np.array([P1[k, neighbors_nn[k]] for k in range(n_samples)])
+    assert_array_almost_equal(P_nn, P2, decimal=4)
 
     # Test that the highest P_ij are the same when few neighbors are used
-    for k in np.linspace(80, n_samples, 10):
+    for k in np.linspace(80, n_samples, 5):
         k = int(k)
         topn = k * 10  # check the top 10 *k entries out of k * k entries
         neighbors_nn = np.argsort(distances, axis=1)[:, :k].astype(np.int64)
-        P2k = _binary_search_perplexity(distances, neighbors_nn,
+        distances_nn = np.array([distances[k, neighbors_nn[k]]
+                                for k in range(n_samples)])
+        P2k = _binary_search_perplexity(distances_nn, neighbors_nn,
                                         desired_perplexity, verbose=0)
         idx = np.argsort(P1.ravel())[::-1]
         P1top = P1.ravel()[idx][:topn]
+        idx = np.argsort(P2k.ravel())[::-1]
         P2top = P2k.ravel()[idx][:topn]
         assert_array_almost_equal(P1top, P2top, decimal=2)
 
@@ -175,6 +175,8 @@ def test_binary_perplexity_stability():
         P = _binary_search_perplexity(distances.copy(), neighbors_nn.copy(),
                                       3, verbose=0)
         P1 = _joint_probabilities_nn(distances, neighbors_nn, 3, verbose=0)
+        # Convert the sparse matrix to a dense one for testing
+        P1 = P1.toarray()
         if last_P is None:
             last_P = P
             last_P1 = P1
@@ -193,9 +195,9 @@ def test_gradient():
     alpha = 1.0
 
     distances = random_state.randn(n_samples, n_features).astype(np.float32)
-    distances = distances.dot(distances.T)
+    distances = np.abs(distances.dot(distances.T))
     np.fill_diagonal(distances, 0.0)
-    X_embedded = random_state.randn(n_samples, n_components)
+    X_embedded = random_state.randn(n_samples, n_components).astype(np.float32)
 
     P = _joint_probabilities(distances, desired_perplexity=25.0,
                              verbose=0)
@@ -233,21 +235,16 @@ def test_trustworthiness():
 def test_preserve_trustworthiness_approximately():
     # Nearest neighbors should be preserved approximately.
     random_state = check_random_state(0)
-    # The Barnes-Hut approximation uses a different method to estimate
-    # P_ij using only a number of nearest neighbors instead of all
-    # points (so that k = 3 * perplexity). As a result we set the
-    # perplexity=5, so that the number of neighbors is 5%.
     n_components = 2
     methods = ['exact', 'barnes_hut']
-    X = random_state.randn(100, n_components).astype(np.float32)
+    X = random_state.randn(50, n_components).astype(np.float32)
     for init in ('random', 'pca'):
         for method in methods:
-            tsne = TSNE(n_components=n_components, perplexity=50,
-                        learning_rate=100.0, init=init, random_state=0,
+            tsne = TSNE(n_components=n_components, init=init, random_state=0,
                         method=method)
             X_embedded = tsne.fit_transform(X)
-            T = trustworthiness(X, X_embedded, n_neighbors=1)
-            assert_almost_equal(T, 1.0, decimal=1)
+            t = trustworthiness(X, X_embedded, n_neighbors=1)
+            assert_greater(t, 0.9)
 
 
 def test_optimization_minimizes_kl_divergence():
@@ -255,7 +252,7 @@ def test_optimization_minimizes_kl_divergence():
     random_state = check_random_state(0)
     X, _ = make_blobs(n_features=3, random_state=random_state)
     kl_divergences = []
-    for n_iter in [200, 250, 300]:
+    for n_iter in [250, 300, 350]:
         tsne = TSNE(n_components=2, perplexity=10, learning_rate=100.0,
                     n_iter=n_iter, random_state=0)
         tsne.fit_transform(X)
@@ -280,13 +277,16 @@ def test_fit_csr_matrix():
 def test_preserve_trustworthiness_approximately_with_precomputed_distances():
     # Nearest neighbors should be preserved approximately.
     random_state = check_random_state(0)
-    X = random_state.randn(100, 2)
-    D = squareform(pdist(X), "sqeuclidean")
-    tsne = TSNE(n_components=2, perplexity=2, learning_rate=100.0,
-                metric="precomputed", random_state=0, verbose=0)
-    X_embedded = tsne.fit_transform(D)
-    assert_almost_equal(trustworthiness(D, X_embedded, n_neighbors=1,
-                                        precomputed=True), 1.0, decimal=1)
+    for i in range(3):
+        X = random_state.randn(100, 2)
+        D = squareform(pdist(X), "sqeuclidean")
+        tsne = TSNE(n_components=2, perplexity=2, learning_rate=100.0,
+                    early_exaggeration=2.0, metric="precomputed",
+                    random_state=i, verbose=0)
+        X_embedded = tsne.fit_transform(D)
+        t = trustworthiness(D, X_embedded, n_neighbors=1,
+                            precomputed=True)
+        assert t > .95
 
 
 def test_early_exaggeration_too_small():
@@ -310,10 +310,32 @@ def test_non_square_precomputed_distances():
                          tsne.fit_transform, np.array([[0.0], [1.0]]))
 
 
+def test_non_positive_precomputed_distances():
+    # Precomputed distance matrices must be positive.
+    bad_dist = np.array([[0., -1.], [1., 0.]])
+    for method in ['barnes_hut', 'exact']:
+        tsne = TSNE(metric="precomputed", method=method)
+        assert_raises_regexp(ValueError, "All distances .*precomputed.*",
+                             tsne.fit_transform, bad_dist)
+
+
+def test_non_positive_computed_distances():
+    # Computed distance matrices must be positive.
+    def metric(x, y):
+        return -1
+
+    tsne = TSNE(metric=metric, method='exact')
+    X = np.array([[0.0, 0.0], [1.0, 1.0]])
+    assert_raises_regexp(ValueError, "All distances .*metric given.*",
+                         tsne.fit_transform, X)
+
+
 def test_init_not_available():
     # 'init' must be 'pca', 'random', or numpy array.
+    tsne = TSNE(init="not available")
     m = "'init' must be 'pca', 'random', or a numpy array"
-    assert_raises_regexp(ValueError, m, TSNE, init="not available")
+    assert_raises_regexp(ValueError, m, tsne.fit_transform,
+                         np.array([[0.0], [1.0]]))
 
 
 def test_init_ndarray():
@@ -332,10 +354,29 @@ def test_init_ndarray_precomputed():
 
 def test_distance_not_available():
     # 'metric' must be valid.
-    tsne = TSNE(metric="not available")
+    tsne = TSNE(metric="not available", method='exact')
     assert_raises_regexp(ValueError, "Unknown metric not available.*",
                          tsne.fit_transform, np.array([[0.0], [1.0]]))
 
+    tsne = TSNE(metric="not available", method='barnes_hut')
+    assert_raises_regexp(ValueError, "Metric 'not available' not valid.*",
+                         tsne.fit_transform, np.array([[0.0], [1.0]]))
+
+
+def test_method_not_available():
+    # 'method' must be 'barnes_hut' or 'exact'
+    tsne = TSNE(method='not available')
+    assert_raises_regexp(ValueError, "'method' must be 'barnes_hut' or ",
+                         tsne.fit_transform, np.array([[0.0], [1.0]]))
+
+
+def test_angle_out_of_range_checks():
+    # check the angle parameter range
+    for angle in [-1, -1e-6, 1 + 1e-6, 2]:
+        tsne = TSNE(angle=angle)
+        assert_raises_regexp(ValueError, "'angle' must be between 0.0 - 1.0",
+                             tsne.fit_transform, np.array([[0.0], [1.0]]))
+
 
 def test_pca_initialization_not_compatible_with_precomputed_kernel():
     # Precomputed distance matrices must be square matrices.
@@ -345,6 +386,48 @@ def test_pca_initialization_not_compatible_with_precomputed_kernel():
                          tsne.fit_transform, np.array([[0.0], [1.0]]))
 
 
+def test_n_components_range():
+    # barnes_hut method should only be used with n_components <= 3
+    tsne = TSNE(n_components=4, method="barnes_hut")
+    assert_raises_regexp(ValueError, "'n_components' should be .*",
+                         tsne.fit_transform, np.array([[0.0], [1.0]]))
+
+
+def test_early_exaggeration_used():
+    # check that the ``early_exaggeration`` parameter has an effect
+    random_state = check_random_state(0)
+    n_components = 2
+    methods = ['exact', 'barnes_hut']
+    X = random_state.randn(25, n_components).astype(np.float32)
+    for method in methods:
+        tsne = TSNE(n_components=n_components, perplexity=1,
+                    learning_rate=100.0, init="pca", random_state=0,
+                    method=method, early_exaggeration=1.0)
+        X_embedded1 = tsne.fit_transform(X)
+        tsne = TSNE(n_components=n_components, perplexity=1,
+                    learning_rate=100.0, init="pca", random_state=0,
+                    method=method, early_exaggeration=10.0)
+        X_embedded2 = tsne.fit_transform(X)
+
+        assert not np.allclose(X_embedded1, X_embedded2)
+
+
+def test_n_iter_used():
+    # check that the ``n_iter`` parameter has an effect
+    random_state = check_random_state(0)
+    n_components = 2
+    methods = ['exact', 'barnes_hut']
+    X = random_state.randn(25, n_components).astype(np.float32)
+    for method in methods:
+        for n_iter in [251, 500]:
+            tsne = TSNE(n_components=n_components, perplexity=1,
+                        learning_rate=0.5, init="random", random_state=0,
+                        method=method, early_exaggeration=1.0, n_iter=n_iter)
+            tsne.fit_transform(X)
+
+            assert tsne.n_iter_ == n_iter - 1
+
+
 def test_answer_gradient_two_points():
     # Test the tree with only a single set of children.
     #
@@ -418,7 +501,13 @@ def _run_answer_test(pos_input, pos_output, neighbors, grad_output,
     pij_input = squareform(pij_input).astype(np.float32)
     grad_bh = np.zeros(pos_output.shape, dtype=np.float32)
 
-    _barnes_hut_tsne.gradient(pij_input, pos_output, neighbors,
+    from scipy.sparse import csr_matrix
+    P = csr_matrix(pij_input)
+
+    neighbors = P.indices.astype(np.int64)
+    indptr = P.indptr.astype(np.int64)
+
+    _barnes_hut_tsne.gradient(P.data, pos_output, neighbors, indptr,
                               grad_bh, 0.5, 2, 1, skip_num_points=0)
     assert_array_almost_equal(grad_bh, grad_output, decimal=4)
 
@@ -439,12 +528,10 @@ def test_verbose():
         sys.stdout = old_stdout
 
     assert("[t-SNE]" in out)
-    assert("Computing pairwise distances" in out)
+    assert("nearest neighbors..." in out)
     assert("Computed conditional probabilities" in out)
     assert("Mean sigma" in out)
-    assert("Finished" in out)
     assert("early exaggeration" in out)
-    assert("Finished" in out)
 
 
 def test_chebyshev_metric():
@@ -481,10 +568,15 @@ def test_64bit():
     methods = ['barnes_hut', 'exact']
     for method in methods:
         for dt in [np.float32, np.float64]:
-            X = random_state.randn(100, 2).astype(dt)
+            X = random_state.randn(50, 2).astype(dt)
             tsne = TSNE(n_components=2, perplexity=2, learning_rate=100.0,
-                        random_state=0, method=method)
-            tsne.fit_transform(X)
+                        random_state=0, method=method, verbose=0)
+            X_embedded = tsne.fit_transform(X)
+            effective_type = X_embedded.dtype
+
+            # tsne cython code is only single precision, so the output will
+            # always be single precision, irrespective of the input dtype
+            assert effective_type == np.float32
 
 
 def test_barnes_hut_angle():
@@ -499,91 +591,57 @@ def test_barnes_hut_angle():
         random_state = check_random_state(0)
         distances = random_state.randn(n_samples, n_features)
         distances = distances.astype(np.float32)
-        distances = distances.dot(distances.T)
+        distances = abs(distances.dot(distances.T))
         np.fill_diagonal(distances, 0.0)
         params = random_state.randn(n_samples, n_components)
-        P = _joint_probabilities(distances, perplexity, False)
-        kl, gradex = _kl_divergence(params, P, degrees_of_freedom, n_samples,
-                                    n_components)
+        P = _joint_probabilities(distances, perplexity, verbose=0)
+        kl_exact, grad_exact = _kl_divergence(params, P, degrees_of_freedom,
+                                              n_samples, n_components)
 
         k = n_samples - 1
         bt = BallTree(distances)
         distances_nn, neighbors_nn = bt.query(distances, k=k + 1)
         neighbors_nn = neighbors_nn[:, 1:]
-        Pbh = _joint_probabilities_nn(distances, neighbors_nn,
-                                      perplexity, False)
-        kl, gradbh = _kl_divergence_bh(params, Pbh, neighbors_nn,
-                                       degrees_of_freedom, n_samples,
-                                       n_components, angle=angle,
-                                       skip_num_points=0, verbose=False)
-        assert_array_almost_equal(Pbh, P, decimal=5)
-        assert_array_almost_equal(gradex, gradbh, decimal=5)
-
-
-def test_quadtree_similar_point():
-    # Introduce a point into a quad tree where a similar point already exists.
-    # Test will hang if it doesn't complete.
-    Xs = []
-
-    # check the case where points are actually different
-    Xs.append(np.array([[1, 2], [3, 4]], dtype=np.float32))
-    # check the case where points are the same on X axis
-    Xs.append(np.array([[1.0, 2.0], [1.0, 3.0]], dtype=np.float32))
-    # check the case where points are arbitrarily close on X axis
-    Xs.append(np.array([[1.00001, 2.0], [1.00002, 3.0]], dtype=np.float32))
-    # check the case where points are the same on Y axis
-    Xs.append(np.array([[1.0, 2.0], [3.0, 2.0]], dtype=np.float32))
-    # check the case where points are arbitrarily close on Y axis
-    Xs.append(np.array([[1.0, 2.00001], [3.0, 2.00002]], dtype=np.float32))
-    # check the case where points are arbitrarily close on both axes
-    Xs.append(np.array([[1.00001, 2.00001], [1.00002, 2.00002]],
-              dtype=np.float32))
-
-    # check the case where points are arbitrarily close on both axes
-    # close to machine epsilon - x axis
-    Xs.append(np.array([[1, 0.0003817754041], [2, 0.0003817753750]],
-              dtype=np.float32))
-
-    # check the case where points are arbitrarily close on both axes
-    # close to machine epsilon - y axis
-    Xs.append(np.array([[0.0003817754041, 1.0], [0.0003817753750, 2.0]],
-              dtype=np.float32))
-
-    for X in Xs:
-        counts = np.zeros(3, dtype='int64')
-        _barnes_hut_tsne.check_quadtree(X, counts)
-        m = "Tree consistency failed: unexpected number of points at root node"
-        assert_equal(counts[0], counts[1], m)
-        m = "Tree consistency failed: unexpected number of points on the tree"
-        assert_equal(counts[0], counts[2], m)
-
-
-def test_index_offset():
-    # Make sure translating between 1D and N-D indices are preserved
-    assert_equal(_barnes_hut_tsne.test_index2offset(), 1)
-    assert_equal(_barnes_hut_tsne.test_index_offset(), 1)
+        distances_nn = np.array([distances[i, neighbors_nn[i]]
+                                 for i in range(n_samples)])
+        assert np.all(distances[0, neighbors_nn[0]] == distances_nn[0]),\
+            abs(distances[0, neighbors_nn[0]] - distances_nn[0])
+        P_bh = _joint_probabilities_nn(distances_nn, neighbors_nn,
+                                       perplexity, verbose=0)
+        kl_bh, grad_bh = _kl_divergence_bh(params, P_bh, degrees_of_freedom,
+                                           n_samples, n_components,
+                                           angle=angle, skip_num_points=0,
+                                           verbose=0)
+
+        P = squareform(P)
+        P_bh = P_bh.toarray()
+        assert_array_almost_equal(P_bh, P, decimal=5)
+        assert_almost_equal(kl_exact, kl_bh, decimal=3)
 
 
 @skip_if_32bit
 def test_n_iter_without_progress():
     # Use a dummy negative n_iter_without_progress and check output on stdout
     random_state = check_random_state(0)
-    X = random_state.randn(100, 2)
-    tsne = TSNE(n_iter_without_progress=-1, verbose=2,
-                random_state=1, method='exact')
-
-    old_stdout = sys.stdout
-    sys.stdout = StringIO()
-    try:
-        tsne.fit_transform(X)
-    finally:
-        out = sys.stdout.getvalue()
-        sys.stdout.close()
-        sys.stdout = old_stdout
+    X = random_state.randn(100, 10)
+    for method in ["barnes_hut", "exact"]:
+        tsne = TSNE(n_iter_without_progress=-1, verbose=2, learning_rate=1e8,
+                    random_state=0, method=method, n_iter=351, init="random")
+        tsne._N_ITER_CHECK = 1
+        tsne._EXPLORATION_N_ITER = 0
+
+        old_stdout = sys.stdout
+        sys.stdout = StringIO()
+        try:
+            tsne.fit_transform(X)
+        finally:
+            out = sys.stdout.getvalue()
+            sys.stdout.close()
+            sys.stdout = old_stdout
 
-    # The output needs to contain the value of n_iter_without_progress
-    assert_in("did not make any progress during the "
-              "last -1 episodes. Finished.", out)
+        # The output needs to contain the value of n_iter_without_progress
+        assert_in("did not make any progress during the "
+                  "last -1 episodes. Finished.", out)
 
 
 def test_min_grad_norm():
@@ -616,7 +674,7 @@ def test_min_grad_norm():
         start_grad_norm = line.find('gradient norm')
         if start_grad_norm >= 0:
             line = line[start_grad_norm:]
-            line = line.replace('gradient norm = ', '')
+            line = line.replace('gradient norm = ', '').split(' ')[0]
             gradient_norm_values.append(float(line))
 
     # Compute how often the gradient norm is smaller than min_grad_norm
@@ -654,3 +712,55 @@ def test_accessible_kl_divergence():
                 error, _, _ = error.partition(',')
                 break
     assert_almost_equal(tsne.kl_divergence_, float(error), decimal=5)
+
+
+def check_uniform_grid(method, seeds=[0, 1, 2], n_iter=1000):
+    """Make sure that TSNE can approximately recover a uniform 2D grid"""
+    for seed in seeds:
+        tsne = TSNE(n_components=2, init='random', random_state=seed,
+                    perplexity=10, n_iter=n_iter, method=method)
+        Y = tsne.fit_transform(X_2d_grid)
+
+        # Ensure that the convergence criterion has been triggered
+        assert tsne.n_iter_ < n_iter
+
+        # Ensure that the resulting embedding leads to approximately
+        # uniformly spaced points: the distance to the closest neighbors
+        # should be non-zero and approximately constant.
+        nn = NearestNeighbors(n_neighbors=1).fit(Y)
+        dist_to_nn = nn.kneighbors(return_distance=True)[0].ravel()
+        assert dist_to_nn.min() > 0.1
+
+        smallest_to_mean = dist_to_nn.min() / np.mean(dist_to_nn)
+        largest_to_mean = dist_to_nn.max() / np.mean(dist_to_nn)
+
+        try_name = "{}_{}".format(method, seed)
+        assert_greater(smallest_to_mean, .5, msg=try_name)
+        assert_less(largest_to_mean, 2, msg=try_name)
+
+
+def test_uniform_grid():
+    for method in ['barnes_hut', 'exact']:
+        yield check_uniform_grid, method
+
+
+def test_bh_match_exact():
+    # check that the ``barnes_hut`` method matches the exact one when
+    # ``angle = 0`` and ``perplexity > n_samples / 3``
+    random_state = check_random_state(0)
+    n_features = 10
+    X = random_state.randn(30, n_features).astype(np.float32)
+    X_embeddeds = {}
+    n_iter = {}
+    for method in ['exact', 'barnes_hut']:
+        tsne = TSNE(n_components=2, method=method, learning_rate=1.0,
+                    init="random", random_state=0, n_iter=251,
+                    perplexity=30.0, angle=0)
+        # Kill the early_exaggeration
+        tsne._EXPLORATION_N_ITER = 0
+        X_embeddings[method] = tsne.fit_transform(X)
+        n_iter[method] = tsne.n_iter_
+
+    assert n_iter['exact'] == n_iter['barnes_hut']
+    assert_array_almost_equal(X_embeddings['exact'],
+                              X_embeddings['barnes_hut'], decimal=3)
diff --git a/sklearn/mixture/base.py b/sklearn/mixture/base.py
index e88b00cd32..88cb62623e 100644
--- a/sklearn/mixture/base.py
+++ b/sklearn/mixture/base.py
@@ -351,7 +351,7 @@ class BaseMixture(six.with_metaclass(ABCMeta, DensityMixin, BaseEstimator)):
         Returns
         -------
         resp : array, shape (n_samples, n_components)
-            Returns the probability each Gaussian (state) in
+            Returns the probability of each Gaussian (state) in
             the model given each sample.
         """
         self._check_is_fitted()
diff --git a/sklearn/neighbors/quad_tree.pxd b/sklearn/neighbors/quad_tree.pxd
new file mode 100644
index 0000000000..0dc4bd3fe5
--- /dev/null
+++ b/sklearn/neighbors/quad_tree.pxd
@@ -0,0 +1,100 @@
+# cython: boundscheck=False
+# cython: wraparound=False
+# cython: cdivision=True
+# Author: Thomas Moreau <thomas.moreau.2010@gmail.com>
+# Author: Olivier Grisel <olivier.grisel@ensta.fr>
+
+# See quad_tree.pyx for details.
+
+import numpy as np
+cimport numpy as np
+
+ctypedef np.npy_float32 DTYPE_t          # Type of X
+ctypedef np.npy_intp SIZE_t              # Type for indices and counters
+ctypedef np.npy_int32 INT32_t            # Signed 32 bit integer
+ctypedef np.npy_uint32 UINT32_t          # Unsigned 32 bit integer
+
+# This is effectively an ifdef statement in Cython
+# It allows us to write printf debugging lines
+# and remove them at compile time
+cdef enum:
+    DEBUGFLAG = 0
+
+cdef float EPSILON = 1e-6
+
+# XXX: Careful not to change the order of the arguments. It is important to
+# keep is_leaf and max_width consecutive, as this avoids padding by the
+# compiler and keeps the struct size consistent for both the C and numpy data
+# structures.
+cdef struct Cell:
+    # Base storage structure for cells in a QuadTree object
+
+    # Tree structure
+    SIZE_t parent              # Parent cell of this cell
+    SIZE_t[8] children         # Array pointing to the children of this cell
+
+    # Cell description
+    SIZE_t cell_id             # Id of the cell in the cells array in the Tree
+    SIZE_t point_index         # Index of the point at this cell (only defined
+                               # in a non-empty leaf)
+    bint is_leaf               # Does this cell have children?
+    DTYPE_t squared_max_width  # Squared value of the maximum width w
+    SIZE_t depth               # Depth of the cell in the tree
+    SIZE_t cumulative_size     # Number of points included in the subtree with
+                               # this cell as a root.
+
+    # Internal constants
+    DTYPE_t[3] center          # Store the center for quick split of cells
+    DTYPE_t[3] barycenter      # Keep track of the center of mass of the cell
+
+    # Cell boundaries
+    DTYPE_t[3] min_bounds      # Lower boundaries of this cell (inclusive)
+    DTYPE_t[3] max_bounds      # Upper boundaries of this cell (exclusive)
+
+
+cdef class _QuadTree:
+    # The QuadTree object is a quad tree structure constructed by recursively
+    # inserting points in the tree and splitting cells in 4 so that each
+    # leaf cell contains at most one point.
+    # This structure also handles 3D data, stored in trees with 8 children
+    # per node.
+
+    # Parameters of the tree
+    cdef public int n_dimensions         # Number of dimensions in X
+    cdef public int verbose              # Verbosity of the output
+    cdef SIZE_t n_cells_per_cell         # Number of children per node. (2 ** n_dimensions)
+
+    # Tree inner structure
+    cdef public SIZE_t max_depth         # Max depth of the tree
+    cdef public SIZE_t cell_count        # Counter for node IDs
+    cdef public SIZE_t capacity          # Capacity of tree, in terms of nodes
+    cdef public SIZE_t n_points          # Total number of points
+    cdef Cell* cells                     # Array of nodes
+
+    # Point insertion methods
+    cdef int insert_point(self, DTYPE_t[3] point, SIZE_t point_index,
+                          SIZE_t cell_id=*) nogil except -1
+    cdef SIZE_t _insert_point_in_new_child(self, DTYPE_t[3] point, Cell* cell,
+                                           SIZE_t point_index, SIZE_t size=*
+                                           ) nogil
+    cdef SIZE_t _select_child(self, DTYPE_t[3] point, Cell* cell) nogil
+    cdef bint _is_duplicate(self, DTYPE_t[3] point1, DTYPE_t[3] point2) nogil
+
+    # Create a summary of the Tree compared to a query point
+    cdef long summarize(self, DTYPE_t[3] point, DTYPE_t* results,
+                        float squared_theta=*, SIZE_t cell_id=*, long idx=*
+                        ) nogil
+
+    # Internal cell initialization methods
+    cdef void _init_cell(self, Cell* cell, SIZE_t parent, SIZE_t depth) nogil
+    cdef void _init_root(self, DTYPE_t[3] min_bounds, DTYPE_t[3] max_bounds
+                         ) nogil
+
+    # Private methods
+    cdef int _check_point_in_cell(self, DTYPE_t[3] point, Cell* cell
+                                  ) nogil except -1
+
+    # Private array manipulation to manage the ``cells`` array
+    cdef int _resize(self, SIZE_t capacity) nogil except -1
+    cdef int _resize_c(self, SIZE_t capacity=*) nogil except -1
+    cdef int _get_cell(self, DTYPE_t[3] point, SIZE_t cell_id=*) nogil except -1
+    cdef np.ndarray _get_cell_ndarray(self)
diff --git a/sklearn/neighbors/quad_tree.pyx b/sklearn/neighbors/quad_tree.pyx
new file mode 100644
index 0000000000..b2cdaac84c
--- /dev/null
+++ b/sklearn/neighbors/quad_tree.pyx
@@ -0,0 +1,672 @@
+# cython: boundscheck=False
+# cython: wraparound=False
+# cython: cdivision=True
+# Author: Thomas Moreau <thomas.moreau.2010@gmail.com>
+# Author: Olivier Grisel <olivier.grisel@ensta.fr>
+
+
+from cpython cimport Py_INCREF, PyObject
+
+from libc.stdlib cimport malloc, free
+from libc.string cimport memcpy
+from libc.stdio cimport printf
+
+from sklearn.tree._utils cimport safe_realloc, sizet_ptr_to_ndarray
+from ..utils import check_array
+
+import numpy as np
+cimport numpy as np
+np.import_array()
+
+cdef extern from "math.h":
+    float fabsf(float x) nogil
+
+cdef extern from "numpy/arrayobject.h":
+    object PyArray_NewFromDescr(object subtype, np.dtype descr,
+                                int nd, np.npy_intp* dims,
+                                np.npy_intp* strides,
+                                void* data, int flags, object obj)
+
+
+# XXX using (size_t)(-1) is ugly, but SIZE_MAX is not available in C89
+# (i.e., older MSVC).
+cdef SIZE_t DEFAULT = <SIZE_t>(-1)
+
+
+# Repeat struct definition for numpy
+CELL_DTYPE = np.dtype({
+    'names': ['parent', 'children', 'cell_id', 'point_index', 'is_leaf',
+              'max_width', 'depth', 'cumulative_size', 'center', 'barycenter',
+              'min_bounds', 'max_bounds'],
+    'formats': [np.intp, (np.intp, 8), np.intp, np.intp, np.int32, np.float32,
+                np.intp, np.intp, (np.float32, 3), (np.float32, 3),
+                (np.float32, 3), (np.float32, 3)],
+    'offsets': [
+        <Py_ssize_t> &(<Cell*> NULL).parent,
+        <Py_ssize_t> &(<Cell*> NULL).children,
+        <Py_ssize_t> &(<Cell*> NULL).cell_id,
+        <Py_ssize_t> &(<Cell*> NULL).point_index,
+        <Py_ssize_t> &(<Cell*> NULL).is_leaf,
+        <Py_ssize_t> &(<Cell*> NULL).squared_max_width,
+        <Py_ssize_t> &(<Cell*> NULL).depth,
+        <Py_ssize_t> &(<Cell*> NULL).cumulative_size,
+        <Py_ssize_t> &(<Cell*> NULL).center,
+        <Py_ssize_t> &(<Cell*> NULL).barycenter,
+        <Py_ssize_t> &(<Cell*> NULL).min_bounds,
+        <Py_ssize_t> &(<Cell*> NULL).max_bounds,
+    ]
+})
+
+assert CELL_DTYPE.itemsize == sizeof(Cell)
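+
+# The assert above checks that CELL_DTYPE mirrors the C struct layout
+# exactly, which is what lets ``_get_cell_ndarray`` expose ``self.cells`` to
+# numpy without copying. A quick interactive sanity check (illustrative
+# only):
+#
+#     >>> from sklearn.neighbors.quad_tree import CELL_DTYPE
+#     >>> CELL_DTYPE.names[0], CELL_DTYPE.names[-1]
+#     ('parent', 'max_bounds')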
+
+
+cdef class _QuadTree:
+    """Array-based representation of a QuadTree.
+
+    This class currently supports indexing 2D data (regular QuadTree) as well
+    as 3D data (OcTree). It is planned to split the 2 implementations using
+    `Cython.Tempita` to save some memory in the QuadTree case.
+
+    Note that this code is currently used internally only by the Barnes-Hut
+    method in `sklearn.manifold.TSNE`. It is planned to be refactored and
+    generalized in the future to be compatible with the nearest neighbors API
+    of `sklearn.neighbors` for 2D and 3D data.
+    """
+    def __cinit__(self, int n_dimensions, int verbose):
+        """Constructor."""
+        # Parameters of the tree
+        self.n_dimensions = n_dimensions
+        self.verbose = verbose
+        self.n_cells_per_cell = 2 ** self.n_dimensions
+
+        # Inner structures
+        self.max_depth = 0
+        self.cell_count = 0
+        self.capacity = 0
+        self.n_points = 0
+        self.cells = NULL
+
+    def __dealloc__(self):
+        """Destructor."""
+        # Free all inner structures
+        free(self.cells)
+
+    property cumulative_size:
+        def __get__(self):
+            return self._get_cell_ndarray()['cumulative_size'][:self.cell_count]
+
+    property leafs:
+        def __get__(self):
+            return self._get_cell_ndarray()['is_leaf'][:self.cell_count]
+
+    def build_tree(self, X):
+        """Build a tree from an arary of points X."""
+        cdef:
+            int i
+            DTYPE_t[3] pt
+            DTYPE_t[3] min_bounds, max_bounds
+
+        # validate X and prepare for query
+        # X = check_array(X, dtype=DTYPE_t, order='C')
+        n_samples = X.shape[0]
+
+        capacity = 100
+        self._resize(capacity)
+        m = np.min(X, axis=0)
+        M = np.max(X, axis=0)
+        # Scale the maximum to get all points strictly inside the tree
+        # bounding box. The formula below handles three cases: positive,
+        # negative and near-zero maxima.
+        M = np.maximum(M * (1. + 1e-3 * np.sign(M)), M + 1e-3)
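+        # For instance (illustrative values): an axis maximum of 10. becomes
+        # 10.01, -10. becomes -9.99 and 0. becomes 0.001, so every point
+        # satisfies min_bounds[i] <= point[i] < max_bounds[i].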
+        for i in range(self.n_dimensions):
+            min_bounds[i] = m[i]
+            max_bounds[i] = M[i]
+
+            if self.verbose > 10:
+                printf("[QuadTree] bounding box axis %i : [%f, %f]\n",
+                       i, min_bounds[i], max_bounds[i])
+
+        # Create the initial node with boundaries from the dataset
+        self._init_root(min_bounds, max_bounds)
+
+        for i in range(n_samples):
+            for j in range(self.n_dimensions):
+                pt[j] = X[i, j]
+            self.insert_point(pt, i)
+
+        # Shrink the cells array to reduce memory usage
+        self._resize(capacity=self.cell_count)
+
+    cdef int insert_point(self, DTYPE_t[3] point, SIZE_t point_index,
+                          SIZE_t cell_id=0) nogil except -1:
+        """Insert a point in the QuadTree."""
+        cdef int ax
+        cdef DTYPE_t n_frac
+        cdef SIZE_t selected_child
+        cdef Cell* cell = &self.cells[cell_id]
+        cdef SIZE_t n_point = cell.cumulative_size
+
+        if self.verbose > 10:
+            printf("[QuadTree] Inserting depth %li\n", cell.depth)
+
+        # Assert that the point is in the right range
+        if DEBUGFLAG:
+            self._check_point_in_cell(point, cell)
+
+        # If the cell is an empty leaf, insert the point in it
+        if cell.cumulative_size == 0:
+            cell.cumulative_size = 1
+            self.n_points += 1
+            for i in range(self.n_dimensions):
+                cell.barycenter[i] = point[i]
+            cell.point_index = point_index
+            if self.verbose > 10:
+                printf("[QuadTree] inserted point %li in cell %li\n",
+                       point_index, cell_id)
+            return cell_id
+
+        # If the cell is not a leaf, update cell internals and
+        # recurse in selected child
+        if not cell.is_leaf:
+            for ax in range(self.n_dimensions):
+                # barycenter update using a weighted mean
+                cell.barycenter[ax] = (
+                    n_point * cell.barycenter[ax] + point[ax]) / (n_point + 1)
+
+            # Increase the size of the subtree starting from this cell
+            cell.cumulative_size += 1
+
+            # Insert child in the correct subtree
+            selected_child = self._select_child(point, cell)
+            if self.verbose > 49:
+                printf("[QuadTree] selected child %li\n", selected_child)
+            if selected_child == -1:
+                self.n_points += 1
+                return self._insert_point_in_new_child(point, cell, point_index)
+            return self.insert_point(point, point_index, selected_child)
+
+        # Finally, if the cell is a leaf with a point already inserted,
+        # split the cell in n_cells_per_cell if the point is not a duplicate.
+        # If it is a duplicate, increase the size of the leaf and return.
+        if self._is_duplicate(point, cell.barycenter):
+            if self.verbose > 10:
+                printf("[QuadTree] found a duplicate!\n")
+            cell.cumulative_size += 1
+            self.n_points += 1
+            return cell_id
+
+        # In a leaf, the barycenter corresponds to the single point it
+        # contains.
+        self._insert_point_in_new_child(cell.barycenter, cell, cell.point_index,
+                                        cell.cumulative_size)
+        return self.insert_point(point, point_index, cell_id)
+
+    # XXX: This operation is not thread safe
+    cdef SIZE_t _insert_point_in_new_child(self, DTYPE_t[3] point, Cell* cell,
+                                          SIZE_t point_index, SIZE_t size=1
+                                          ) nogil:
+        """Create a child of cell which will contain point."""
+
+        # Local variable definition
+        cdef:
+            SIZE_t cell_id, cell_child_id, parent_id
+            DTYPE_t[3] save_point
+            DTYPE_t width
+            Cell* child
+            int i
+
+        # If the maximal capacity of the Tree has been reached, double the
+        # capacity. We need to save the current cell id and the current point
+        # to retrieve them after the reallocation, as the cells array may move
+        # in memory.
+        if self.cell_count + 1 > self.capacity:
+            parent_id = cell.cell_id
+            for i in range(self.n_dimensions):
+                save_point[i] = point[i]
+            self._resize(DEFAULT)
+            cell = &self.cells[parent_id]
+            point = save_point
+
+        # Get an empty cell and initialize it
+        cell_id = self.cell_count
+        self.cell_count += 1
+        child = &self.cells[cell_id]
+
+        self._init_cell(child, cell.cell_id, cell.depth + 1)
+        child.cell_id = cell_id
+
+        # Set the cell as an inner cell of the Tree
+        cell.is_leaf = False
+        cell.point_index = -1
+
+        # Set the correct boundary for the cell, store the point in the cell
+        # and compute its index in the children array.
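+        # The child index packs one bit per axis, most significant axis
+        # first: in 2D, for instance, a point with x >= center_x and
+        # y < center_y goes to child 0b10 = 2 (illustrative example).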
+        cell_child_id = 0
+        for i in range(self.n_dimensions):
+            cell_child_id *= 2
+            if point[i] >= cell.center[i]:
+                cell_child_id += 1
+                child.min_bounds[i] = cell.center[i]
+                child.max_bounds[i] = cell.max_bounds[i]
+            else:
+                child.min_bounds[i] = cell.min_bounds[i]
+                child.max_bounds[i] = cell.center[i]
+            child.center[i] = (child.min_bounds[i] + child.max_bounds[i]) / 2.
+            width = child.max_bounds[i] - child.min_bounds[i]
+
+            child.barycenter[i] = point[i]
+            child.squared_max_width = max(child.squared_max_width, width*width)
+
+        # Store the point info and the size to account for duplicated points
+        child.point_index = point_index
+        child.cumulative_size = size
+
+        # Store the child cell in the correct place in children
+        cell.children[cell_child_id] = child.cell_id
+
+        if DEBUGFLAG:
+            # Assert that the point is in the right range
+            self._check_point_in_cell(point, child)
+        if self.verbose > 10:
+            printf("[QuadTree] inserted point %li in new child %li\n",
+                   point_index, cell_id)
+
+        return cell_id
+
+
+    cdef bint _is_duplicate(self, DTYPE_t[3] point1, DTYPE_t[3] point2) nogil:
+        """Check if the two given points are equals."""
+        cdef int i
+        cdef bint res = True
+        for i in range(self.n_dimensions):
+            # Use EPSILON to avoid numerical errors that would otherwise make
+            # the tree grow indefinitely
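+            # (e.g. two float32 points whose coordinates all differ by less
+            # than EPSILON = 1e-6 are treated as one duplicated point;
+            # illustrative values)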
+            res &= fabsf(point1[i] - point2[i]) <= EPSILON
+        return res
+
+
+    cdef SIZE_t _select_child(self, DTYPE_t[3] point, Cell* cell) nogil:
+        """Select the child of cell which contains the given query point."""
+        cdef:
+            int i
+            SIZE_t selected_child = 0
+
+        for i in range(self.n_dimensions):
+            # Select the correct child cell to insert the point by comparing
+            # it to the boundaries of the cells using the precomputed center.
+            selected_child *= 2
+            if point[i] >= cell.center[i]:
+                selected_child += 1
+        return cell.children[selected_child]
+
+    cdef void _init_cell(self, Cell* cell, SIZE_t parent, SIZE_t depth) nogil:
+        """Initialize a cell structure with some constants."""
+        cell.parent = parent
+        cell.is_leaf = True
+        cell.depth = depth
+        cell.squared_max_width = 0
+        cell.cumulative_size = 0
+        for i in range(self.n_cells_per_cell):
+            cell.children[i] = DEFAULT
+
+    cdef void _init_root(self, DTYPE_t[3] min_bounds, DTYPE_t[3] max_bounds
+                         ) nogil:
+        """Initialize the root node with the given space boundaries"""
+        cdef:
+            int i
+            DTYPE_t width
+            Cell* root = &self.cells[0]
+
+        self._init_cell(root, -1, 0)
+        for i in range(self.n_dimensions):
+            root.min_bounds[i] = min_bounds[i]
+            root.max_bounds[i] = max_bounds[i]
+            root.center[i] = (max_bounds[i] + min_bounds[i]) / 2.
+            width = max_bounds[i] - min_bounds[i]
+            root.squared_max_width = max(root.squared_max_width, width*width)
+        root.cell_id = 0
+
+        self.cell_count += 1
+
+    cdef int _check_point_in_cell(self, DTYPE_t[3] point, Cell* cell
+                                  ) nogil except -1:
+        """Check that the given point is in the cell boundaries."""
+
+        if self.verbose >= 50:
+            if self.n_dimensions == 3:
+                printf("[QuadTree] Checking point (%f, %f, %f) in cell %li "
+                        "([%f/%f, %f/%f, %f/%f], size %li)\n",
+                        point[0], point[1], point[2], cell.cell_id,
+                        cell.min_bounds[0], cell.max_bounds[0], cell.min_bounds[1],
+                        cell.max_bounds[1], cell.min_bounds[2], cell.max_bounds[2],
+                        cell.cumulative_size)
+            else:
+                printf("[QuadTree] Checking point (%f, %f) in cell %li "
+                        "([%f/%f, %f/%f], size %li)\n",
+                        point[0], point[1], cell.cell_id, cell.min_bounds[0],
+                        cell.max_bounds[0], cell.min_bounds[1],
+                        cell.max_bounds[1], cell.cumulative_size)
+
+        for i in range(self.n_dimensions):
+            if (cell.min_bounds[i] > point[i] or
+                    cell.max_bounds[i] <= point[i]):
+                with gil:
+                    msg = "[QuadTree] InsertionError: point out of cell "
+                    msg += "boundary.\nAxis %li: cell [%f, %f]; point %f\n"
+
+                    msg %= i, cell.min_bounds[i], cell.max_bounds[i], point[i]
+                    raise ValueError(msg)
+
+    def _check_coherence(self):
+        """Check the coherence of the cells of the tree.
+
+        Check that the info stored in each cell is compatible with the info
+        stored in descendant and sibling cells. Raise a ValueError if this
+        fails.
+        """
+        for cell in self.cells[:self.cell_count]:
+            # Check that the barycenter of the inserted points is within the
+            # cell boundaries
+            self._check_point_in_cell(cell.barycenter, &cell)
+
+            if not cell.is_leaf:
+                # Compute the number of points in the children and compare it
+                # with the cell's cumulative_size.
+                n_points = 0
+                for idx in range(self.n_cells_per_cell):
+                    child_id = cell.children[idx]
+                    if child_id != -1:
+                        child = self.cells[child_id]
+                        n_points += child.cumulative_size
+                        assert child.cell_id == child_id, (
+                            "Cell id not correctly initiliazed.")
+                if n_points != cell.cumulative_size:
+                    raise ValueError(
+                        "Cell {} is incoherent. Size={} but found {} points "
+                        "in children. ({})"
+                        .format(cell.cell_id, cell.cumulative_size,
+                                n_points, cell.children))
+
+        # Make sure that the number of points in the tree corresponds to the
+        # cumulative size of the root cell.
+        if self.n_points != self.cells[0].cumulative_size:
+            raise ValueError(
+                "QuadTree is incoherent. Size={} but found {} points "
+                "in children."
+                .format(self.n_points, self.cells[0].cumulative_size))
+
+    cdef long summarize(self, DTYPE_t[3] point, DTYPE_t* results,
+                        float squared_theta=.5, SIZE_t cell_id=0, long idx=0
+                        ) nogil:
+        """Summarize the tree compared to a query point.
+
+        Input arguments
+        ---------------
+        point : array (n_dimensions)
+             query point to construct the summary.
+        cell_id : integer, optional (default: 0)
+            current cell of the tree summarized. This should be set to 0 for
+            external calls.
+        idx : integer, optional (default: 0)
+            current index in the result array. This should be set to 0 for
+            external calls.
+        squared_theta : float, optional (default: .5)
+            threshold to decide whether the node is sufficiently far
+            from the query point to be a good summary. The formula is such that
+            the node is a summary if
+                node_width^2 / dist_node_point^2 < squared_theta.
+            Note that the argument should be passed as theta^2 to avoid
+            computing square roots of the distances.
+
+        Output arguments
+        ----------------
+        results : array (n_samples * (n_dimensions+2))
+            result will contain a summary of the tree information compared to
+            the query point:
+            - results[idx:idx+n_dimensions] contains the coordinate-wise
+                difference between the query point and the summary cell idx.
+                This is useful in t-SNE to compute the negative forces.
+            - results[idx+n_dimensions] contains the squared euclidean
+                distance to the summary cell idx.
+            - results[idx+n_dimensions+1] contains the number of points of the
+                tree contained in the summary cell idx.
+
+        Returns
+        -------
+        idx : integer
+            number of elements in the results array.
+        """
+        cdef:
+            int i, idx_d = idx + self.n_dimensions
+            bint duplicate = True
+            Cell* cell = &self.cells[cell_id]
+
+        results[idx_d] = 0.
+        for i in range(self.n_dimensions):
+            results[idx + i] = point[i] - cell.barycenter[i]
+            results[idx_d] += results[idx + i] * results[idx + i]
+            duplicate &= fabsf(results[idx + i]) <= EPSILON
+
+        # Do not compute self interactions
+        if duplicate and cell.is_leaf:
+            return idx
+
+        # Check whether we can use this node as a summary.
+        # It is a summary node if the angular size as measured from the point
+        # is relatively small (w.r.t. theta) or if it is a leaf node.
+        # If it can be summarized, we use the cell center of mass.
+        # Otherwise, we go to a higher level of resolution and recurse into
+        # the children cells.
+        if cell.is_leaf or (
+                (cell.squared_max_width / results[idx_d]) < squared_theta):
+            results[idx_d + 1] = <DTYPE_t> cell.cumulative_size
+            return idx + self.n_dimensions + 2
+
+        else:
+            # Recursively compute the summary in nodes
+            for c in range(self.n_cells_per_cell):
+                child_id = cell.children[c]
+                if child_id != -1:
+                    idx = self.summarize(point, results, squared_theta,
+                                         child_id, idx)
+
+        return idx
+
+    def get_cell(self, point):
+        """return the id of the cell containing the query point or raise 
+        ValueError if the point is not in the tree
+        """
+        cdef DTYPE_t[3] query_pt
+        cdef int i
+
+        assert len(point) == self.n_dimensions, (
+            "Query point should be a point in dimension {}."
+            .format(self.n_dimensions))
+
+        for i in range(self.n_dimensions):
+            query_pt[i] = point[i]
+
+        return self._get_cell(query_pt, 0)
+
+    cdef int _get_cell(self, DTYPE_t[3] point, SIZE_t cell_id=0
+                       ) nogil except -1:
+        """guts of get_cell.
+        
+        Return the id of the cell containing the query point or raise ValueError
+        if the point is not in the tree"""
+        cdef:
+            SIZE_t selected_child
+            Cell* cell = &self.cells[cell_id]
+
+        if cell.is_leaf:
+            if self._is_duplicate(cell.barycenter, point):
+                if self.verbose > 99:
+                    printf("[QuadTree] Found point in cell: %li\n",
+                           cell.cell_id)
+                return cell_id
+            with gil:
+                raise ValueError("Query point not in the Tree.")
+
+        selected_child = self._select_child(point, cell)
+        if selected_child > 0:
+            if self.verbose > 99:
+                printf("[QuadTree] Selected_child: %li\n", selected_child)
+            return self._get_cell(point, selected_child)
+        with gil:
+            raise ValueError("Query point not in the Tree.")
+
+    # Pickling primitives
+
+    def __reduce__(self):
+        """Reduce re-implementation, for pickling."""
+        return (_QuadTree, (self.n_dimensions, self.verbose),
+                           self.__getstate__())
+
+    def __getstate__(self):
+        """Getstate re-implementation, for pickling."""
+        d = {}
+        # capacity is inferred during the __setstate__ using the cells array
+        d["max_depth"] = self.max_depth
+        d["cell_count"] = self.cell_count
+        d["capacity"] = self.capacity
+        d["n_points"] = self.n_points
+        d["cells"] = self._get_cell_ndarray()
+        return d
+
+    def __setstate__(self, d):
+        """Setstate re-implementation, for unpickling."""
+        self.max_depth = d["max_depth"]
+        self.cell_count = d["cell_count"]
+        self.capacity = d["capacity"]
+        self.n_points = d["n_points"]
+
+        if 'cells' not in d:
+            raise ValueError('You have loaded a Tree version which '
+                             'cannot be imported')
+
+        cell_ndarray = d['cells']
+
+        if (cell_ndarray.ndim != 1 or
+                cell_ndarray.dtype != CELL_DTYPE or
+                not cell_ndarray.flags.c_contiguous):
+            raise ValueError('Did not recognise loaded array layout')
+
+        self.capacity = cell_ndarray.shape[0]
+        if self._resize_c(self.capacity) != 0:
+            raise MemoryError("resizing tree to %d" % self.capacity)
+
+        cells = memcpy(self.cells, (<np.ndarray> cell_ndarray).data,
+                       self.capacity * sizeof(Cell))
+
+
+    # Array manipulation methods, to convert it to numpy or to resize
+    # self.cells array
+
+    cdef np.ndarray _get_cell_ndarray(self):
+        """Wraps nodes as a NumPy struct array.
+
+        The array keeps a reference to this Tree, which manages the underlying
+        memory. Individual fields are publicly accessible as properties of the
+        Tree.
+        """
+        cdef np.npy_intp shape[1]
+        shape[0] = <np.npy_intp> self.cell_count
+        cdef np.npy_intp strides[1]
+        strides[0] = sizeof(Cell)
+        cdef np.ndarray arr
+        Py_INCREF(CELL_DTYPE)
+        arr = PyArray_NewFromDescr(np.ndarray, CELL_DTYPE, 1, shape,
+                                   strides, <void*> self.cells,
+                                   np.NPY_DEFAULT, None)
+        Py_INCREF(self)
+        arr.base = <PyObject*> self
+        return arr
+
+    cdef int _resize(self, SIZE_t capacity) nogil except -1:
+        """Resize all inner arrays to `capacity`, if `capacity` == -1, then
+           double the size of the inner arrays.
+
+        Returns -1 in case of failure to allocate memory (and raises
+        MemoryError) or 0 otherwise.
+        """
+        if self._resize_c(capacity) != 0:
+            # Acquire gil only if we need to raise
+            with gil:
+                raise MemoryError()
+
+    cdef int _resize_c(self, SIZE_t capacity=DEFAULT) nogil except -1:
+        """Guts of _resize
+
+        Returns -1 in case of failure to allocate memory (and raises
+        MemoryError) or 0 otherwise.
+        """
+        if capacity == self.capacity and self.cells != NULL:
+            return 0
+
+        if capacity == DEFAULT:
+            if self.capacity == 0:
+                capacity = 9  # default initial capacity (root + 8 children)
+            else:
+                capacity = 2 * self.capacity
+
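+        # e.g. starting from the default capacity, successive doublings give
+        # 9, 18, 36, 72, ... cells (illustrative growth pattern)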
+        safe_realloc(&self.cells, capacity)
+
+        # if capacity smaller than cell_count, adjust the counter
+        if capacity < self.cell_count:
+            self.cell_count = capacity
+
+        self.capacity = capacity
+        return 0
+
+    @staticmethod
+    def test_summarize():
+
+        cdef:
+            DTYPE_t[3] query_pt
+            float* summary
+            int i, n_samples, n_dimensions
+
+        n_dimensions = 2
+        n_samples = 4
+        angle = 0.9
+        offset = n_dimensions + 2
+        X = np.array([[-10., -10.], [9., 10.], [10., 9.], [10., 10.]])
+
+        n_dimensions = X.shape[1]
+        qt = _QuadTree(n_dimensions, verbose=0)
+        qt.build_tree(X)
+
+        summary = <float*> malloc(sizeof(float) * n_samples * 4)
+
+        for i in range(n_dimensions):
+            query_pt[i] = X[0, i]
+
+        # Summary should contain only 1 node with size 3 and the distance to
+        # the barycenter of X[1:]
+        idx = qt.summarize(query_pt, summary, angle * angle)
+
+        node_dist = summary[n_dimensions]
+        node_size = summary[n_dimensions + 1]
+
+        barycenter = X[1:].mean(axis=0)
+        ds2c = ((X[0] - barycenter) ** 2).sum()
+
+        assert idx == offset
+        assert node_size == 3, "summary size = {}".format(node_size)
+        assert np.isclose(node_dist, ds2c)
+
+        # Summary should contain all 3 nodes with size 1 and the distance to
+        # each point in X[1:] for ``angle=0``
+        idx = qt.summarize(query_pt, summary, 0)
+
+        assert idx == 3 * offset
+        for i in range(3):
+            node_dist = summary[i * offset + n_dimensions]
+            node_size = summary[i * offset + n_dimensions + 1]
+
+            ds2c = ((X[0] - X[i + 1]) ** 2).sum()
+
+            assert node_size == 1, "summary size = {}".format(node_size)
+            assert np.isclose(node_dist, ds2c)
diff --git a/sklearn/neighbors/setup.py b/sklearn/neighbors/setup.py
index 1180b8c365..8b1ad7bac9 100644
--- a/sklearn/neighbors/setup.py
+++ b/sklearn/neighbors/setup.py
@@ -31,6 +31,10 @@ def configuration(parent_package='', top_path=None):
                          sources=['typedefs.pyx'],
                          include_dirs=[numpy.get_include()],
                          libraries=libraries)
+    config.add_extension("quad_tree",
+                         sources=["quad_tree.pyx"],
+                         include_dirs=[numpy.get_include()],
+                         libraries=libraries)
 
     config.add_subpackage('tests')
 
diff --git a/sklearn/neighbors/tests/test_quad_tree.py b/sklearn/neighbors/tests/test_quad_tree.py
new file mode 100644
index 0000000000..6cfa4bcc56
--- /dev/null
+++ b/sklearn/neighbors/tests/test_quad_tree.py
@@ -0,0 +1,108 @@
+import pickle
+import numpy as np
+from sklearn.neighbors.quad_tree import _QuadTree
+from sklearn.utils import check_random_state
+
+
+def test_quadtree_boundary_computation():
+    # Insert points into a quad tree with boundaries that are not trivial to
+    # compute.
+    Xs = []
+
+    # check a random case
+    Xs.append(np.array([[-1, 1], [-4, -1]], dtype=np.float32))
+    # check the case where only 0 are inserted
+    Xs.append(np.array([[0, 0], [0, 0]], dtype=np.float32))
+    # check the case where only negative are inserted
+    Xs.append(np.array([[-1, -2], [-4, 0]], dtype=np.float32))
+    # check the case where only small numbers are inserted
+    Xs.append(np.array([[-1e-6, 1e-6], [-4e-6, -1e-6]], dtype=np.float32))
+
+    for X in Xs:
+        tree = _QuadTree(n_dimensions=2, verbose=0)
+        tree.build_tree(X)
+        tree._check_coherence()
+
+
+def test_quadtree_similar_point():
+    # Introduce a point into a quad tree where a similar point already exists.
+    # If the insertion logic is broken, this test hangs instead of failing.
+    Xs = []
+
+    # check the case where points are actually different
+    Xs.append(np.array([[1, 2], [3, 4]], dtype=np.float32))
+    # check the case where points are the same on X axis
+    Xs.append(np.array([[1.0, 2.0], [1.0, 3.0]], dtype=np.float32))
+    # check the case where points are arbitrarily close on X axis
+    Xs.append(np.array([[1.00001, 2.0], [1.00002, 3.0]], dtype=np.float32))
+    # check the case where points are the same on Y axis
+    Xs.append(np.array([[1.0, 2.0], [3.0, 2.0]], dtype=np.float32))
+    # check the case where points are arbitrarily close on Y axis
+    Xs.append(np.array([[1.0, 2.00001], [3.0, 2.00002]], dtype=np.float32))
+    # check the case where points are arbitrarily close on both axes
+    Xs.append(np.array([[1.00001, 2.00001], [1.00002, 2.00002]],
+              dtype=np.float32))
+
+    # check the case where points are arbitrarily close on both axes
+    # close to machine epsilon - x axis
+    Xs.append(np.array([[1, 0.0003817754041], [2, 0.0003817753750]],
+              dtype=np.float32))
+
+    # check the case where points are arbitrarily close on both axes
+    # close to machine epsilon - y axis
+    Xs.append(np.array([[0.0003817754041, 1.0], [0.0003817753750, 2.0]],
+              dtype=np.float32))
+
+    for X in Xs:
+        tree = _QuadTree(n_dimensions=2, verbose=0)
+        tree.build_tree(X)
+        tree._check_coherence()
+
+
+def test_quad_tree_pickle():
+    rng = check_random_state(0)
+
+    for n_dimensions in (2, 3):
+        X = rng.random_sample((10, n_dimensions))
+
+        tree = _QuadTree(n_dimensions=n_dimensions, verbose=0)
+        tree.build_tree(X)
+
+        def check_pickle_protocol(protocol):
+            s = pickle.dumps(tree, protocol=protocol)
+            bt2 = pickle.loads(s)
+
+            for x in X:
+                cell_x_tree = tree.get_cell(x)
+                cell_x_bt2 = bt2.get_cell(x)
+                assert cell_x_tree == cell_x_bt2
+
+        for protocol in (0, 1, 2):
+            yield check_pickle_protocol, protocol
+
+
+def test_qt_insert_duplicate():
+    rng = check_random_state(0)
+
+    def check_insert_duplicate(n_dimensions=2):
+
+        X = rng.random_sample((10, n_dimensions))
+        Xd = np.r_[X, X[:5]]
+        tree = _QuadTree(n_dimensions=n_dimensions, verbose=0)
+        tree.build_tree(Xd)
+
+        cumulative_size = tree.cumulative_size
+        leafs = tree.leafs
+
+        # Assert that the first 5 points are indeed duplicated and that the
+        # following ones are single-point leaves
+        for i, x in enumerate(X):
+            cell_id = tree.get_cell(x)
+            assert leafs[cell_id]
+            assert cumulative_size[cell_id] == 1 + (i < 5)
+
+    for n_dimensions in (2, 3):
+        yield check_insert_duplicate, n_dimensions
+
+
+def test_summarize():
+    _QuadTree.test_summarize()
diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd
index 017888ab41..04806ade18 100644
--- a/sklearn/tree/_utils.pxd
+++ b/sklearn/tree/_utils.pxd
@@ -10,7 +10,8 @@
 
 import numpy as np
 cimport numpy as np
-from _tree cimport Node 
+from _tree cimport Node
+from sklearn.neighbors.quad_tree cimport Cell
 
 ctypedef np.npy_float32 DTYPE_t          # Type of X
 ctypedef np.npy_float64 DOUBLE_t         # Type of y, sample_weight
@@ -39,6 +40,7 @@ ctypedef fused realloc_ptr:
     (DOUBLE_t*)
     (DOUBLE_t**)
     (Node*)
+    (Cell*)
     (Node**)
     (StackRecord*)
     (PriorityHeapRecord*)
-- 
GitLab