From 0dc22798e72cca233e62c0fa37408fed5d4d2394 Mon Sep 17 00:00:00 2001
From: Utkarsh Upadhyay <mail@musicallyut.in>
Date: Tue, 4 Jul 2017 00:20:52 +0200
Subject: [PATCH] [MRG+1] Fix semi_supervised (#9239)

* Files for my dev environment with Docker

* Fixing label clamping (alpha=0 for hard clamping)

* Deprecating alpha, fixing its value to zero

* Correct way to deprecate alpha for LabelPropagation

The previous way was breaking the test
sklearn.tests.test_common.test_all_estimators

* Detailed info for LabelSpreading's alpha parameter

Based on the original paper.

* Minor changes in the deprecation message

* Improving "deprecated" doc string and raising DeprecationWarning

* Using a local "alpha" in "fit" to deprecate LabelPropagation's alpha

This solution isn't great, but it sets the correct value for alpha
without violating the restrictions imposed by the tests.

* Removal of my development files

* Using sphinx's "deprecated" tag (jnothman's suggestion)

* Deprecation warning: stating that alpha's value will be ignored

* Use __init__ with alpha=None

* Update what's new

* Try to fix RuntimeWarning in test_alpha_deprecation

* DOC Indent deprecation details

* DOC wording

* Update docs

* Change to the one true implementation.

* Add sanity-checked impl. of Label{Propagation,Spreading}

* Raise ValueError if alpha is invalid in LabelSpreading.

* Add a normalizing step before clamping to LabelPropagation.

* Fix flake8 errors.

* Remove duplicate imports.

* DOC Update What's New.

* Specify alpha's value in the error.

* Tidy up tests.

Add a test and add references, where needed.

* Add comment to non-regression test.

* Fix documentation.

* Move check for alpha into fit from __init__.

* Fix corner case of LabelSpreading with alpha=None.

* alpha -> self.variant

* Make Whats_new more explicit.

* Simplify impl. of Label{Propagation,Spreading}.

* variant -> _variant.
---
 doc/modules/label_propagation.rst             |  4 +-
 doc/whats_new.rst                             | 11 ++-
 .../plot_label_propagation_structure.py       |  2 +-
 sklearn/semi_supervised/label_propagation.py  | 76 ++++++++++++----
 .../tests/test_label_propagation.py           | 86 +++++++++++++++++++
 5 files changed, 160 insertions(+), 19 deletions(-)

diff --git a/doc/modules/label_propagation.rst b/doc/modules/label_propagation.rst
index eddc34b7a8..1aba742723 100644
--- a/doc/modules/label_propagation.rst
+++ b/doc/modules/label_propagation.rst
@@ -52,8 +52,8 @@ differ in modifications to the similarity matrix that graph and the
 clamping effect on the label distributions.
 Clamping allows the algorithm to change the weight of the ground-truth labeled
 data to some degree. The :class:`LabelPropagation` algorithm performs hard
-clamping of input labels, which means :math:`\alpha=1`. This clamping factor
-can be relaxed, to say :math:`\alpha=0.8`, which means that we will always
+clamping of input labels, which means :math:`\alpha=0`. This clamping factor
+can be relaxed, to say :math:`\alpha=0.2`, which means that we will always
 retain 80 percent of our original label distribution, but the algorithm gets to
 change its confidence of the distribution within 20 percent.
 
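For reference, the soft-clamping update that this factor controls can be
sketched, in the notation of Zhou et al. (2004) (with :math:`S` the normalized
graph and :math:`Y^{(0)}` the initial label matrix), as

.. math::

    Y^{(t+1)} = \alpha S Y^{(t)} + (1 - \alpha) Y^{(0)}

so that :math:`\alpha=0.2` keeps 80 percent of the initial label distribution
at every step and takes the remaining 20 percent from the propagated neighbor
information.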
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index a9601419c9..73fa6dcee8 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -448,7 +448,16 @@ Bug fixes
      in :class:`decomposition.PCA`,
      :class:`decomposition.RandomizedPCA` and
      :class:`decomposition.IncrementalPCA`.
-     :issue:`9105` by `Hanmin Qin <https://github.com/qinhanmin2014>`_. 
+     :issue:`9105` by `Hanmin Qin <https://github.com/qinhanmin2014>`_.
+
+   - Fix :class:`semi_supervised.BaseLabelPropagation` to correctly implement
+     ``LabelPropagation`` and ``LabelSpreading`` as done in the referenced
+     papers. :class:`semi_supervised.LabelPropagation` now always does hard
+     clamping. Its ``alpha`` parameter has no effect and is
+     deprecated to be removed in 0.21. :issue:`6727` :issue:`3550` :issue:`5770`
+     by :user:`Andre Ambrosio Boechat <boechat107>`, :user:`Utkarsh Upadhyay
+     <musically-ut>`, and `Joel Nothman`_.
+
 
 API changes summary
 -------------------
diff --git a/examples/semi_supervised/plot_label_propagation_structure.py b/examples/semi_supervised/plot_label_propagation_structure.py
index 7cc15d73f1..95f19ec108 100644
--- a/examples/semi_supervised/plot_label_propagation_structure.py
+++ b/examples/semi_supervised/plot_label_propagation_structure.py
@@ -30,7 +30,7 @@ labels[-1] = inner
 
 # #############################################################################
 # Learn with LabelSpreading
-label_spread = label_propagation.LabelSpreading(kernel='knn', alpha=1.0)
+label_spread = label_propagation.LabelSpreading(kernel='knn', alpha=0.2)
 label_spread.fit(X, labels)
 
 # #############################################################################
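The example above must change because, with this patch, ``LabelSpreading``
validates ``alpha`` in ``fit`` and only accepts values strictly inside (0, 1),
so the old ``alpha=1.0`` would now raise. A minimal sketch of the new
behaviour, on illustrative toy data::

    import numpy as np
    from sklearn.semi_supervised import label_propagation

    X = np.array([[1., 0.], [0., 1.], [1., 2.5]])  # toy points, illustration only
    y = np.array([0, 1, -1])                       # -1 marks the unlabeled sample

    try:
        label_propagation.LabelSpreading(alpha=1.0).fit(X, y)
    except ValueError as err:
        # alpha=1.0 is invalid: it must be inside the open interval (0, 1)
        print(err)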
diff --git a/sklearn/semi_supervised/label_propagation.py b/sklearn/semi_supervised/label_propagation.py
index 1759b2c1d7..ab0dd64bf8 100644
--- a/sklearn/semi_supervised/label_propagation.py
+++ b/sklearn/semi_supervised/label_propagation.py
@@ -14,11 +14,12 @@ For more information see the references below.
 Model Features
 --------------
 Label clamping:
-  The algorithm tries to learn distributions of labels over the dataset. In the
-  "Hard Clamp" mode, the true ground labels are never allowed to change. They
-  are clamped into position. In the "Soft Clamp" mode, they are allowed some
-  wiggle room, but some alpha of their original value will always be retained.
-  Hard clamp is the same as soft clamping with alpha set to 1.
+  The algorithm tries to learn distributions of labels over the dataset given
+  label assignments over an initial subset. In one variant, the algorithm does
+  not allow for any errors in the initial assignment (hard-clamping), while
+  in the other variant it allows the initial assignments some wiggle room,
+  letting them change by a fraction alpha in each iteration
+  (soft-clamping).
 
 Kernel:
   A function which projects a vector into some higher dimensional space. This
@@ -55,6 +56,7 @@ Non-Parametric Function Induction in Semi-Supervised Learning. AISTAT 2005
 # License: BSD
 from abc import ABCMeta, abstractmethod
 
+import warnings
 import numpy as np
 from scipy import sparse
 
@@ -239,10 +241,13 @@ class BaseLabelPropagation(six.with_metaclass(ABCMeta, BaseEstimator,
 
         n_samples, n_classes = len(y), len(classes)
 
+        alpha = self.alpha
+        if self._variant == 'spreading' and \
+                (alpha is None or alpha <= 0.0 or alpha >= 1.0):
+            raise ValueError('alpha=%s is invalid: it must be inside '
+                             'the open interval (0, 1)' % alpha)
         y = np.asarray(y)
         unlabeled = y == -1
-        clamp_weights = np.ones((n_samples, 1))
-        clamp_weights[unlabeled, 0] = self.alpha
 
         # initialize distributions
         self.label_distributions_ = np.zeros((n_samples, n_classes))
@@ -250,13 +255,17 @@ class BaseLabelPropagation(six.with_metaclass(ABCMeta, BaseEstimator,
             self.label_distributions_[y == label, classes == label] = 1
 
         y_static = np.copy(self.label_distributions_)
-        if self.alpha > 0.:
-            y_static *= 1 - self.alpha
-        y_static[unlabeled] = 0
+        if self._variant == 'propagation':
+            # LabelPropagation
+            y_static[unlabeled] = 0
+        else:
+            # LabelSpreading
+            y_static *= 1 - alpha
 
         l_previous = np.zeros((self.X_.shape[0], n_classes))
 
         remaining_iter = self.max_iter
+        unlabeled = unlabeled[:, np.newaxis]
         if sparse.isspmatrix(graph_matrix):
             graph_matrix = graph_matrix.tocsr()
         while (_not_converged(self.label_distributions_, l_previous, self.tol)
@@ -264,13 +273,23 @@ class BaseLabelPropagation(six.with_metaclass(ABCMeta, BaseEstimator,
             l_previous = self.label_distributions_
             self.label_distributions_ = safe_sparse_dot(
                 graph_matrix, self.label_distributions_)
-            # clamp
-            self.label_distributions_ = np.multiply(
-                clamp_weights, self.label_distributions_) + y_static
+
+            if self._variant == 'propagation':
+                normalizer = np.sum(
+                    self.label_distributions_, axis=1)[:, np.newaxis]
+                self.label_distributions_ /= normalizer
+                self.label_distributions_ = np.where(unlabeled,
+                                                     self.label_distributions_,
+                                                     y_static)
+            else:
+                # clamp
+                self.label_distributions_ = np.multiply(
+                    alpha, self.label_distributions_) + y_static
             remaining_iter -= 1
 
         normalizer = np.sum(self.label_distributions_, axis=1)[:, np.newaxis]
         self.label_distributions_ /= normalizer
+
         # set the transduction item
         transduction = self.classes_[np.argmax(self.label_distributions_,
                                                axis=1)]
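Stripped of the diff context, the loop above amounts to the following
self-contained sketch; all names here are illustrative rather than part of the
library API (``graph`` is the row-normalized affinity matrix, ``y0`` the
one-hot initial labels with unlabeled rows all zero, ``unlabeled`` a boolean
mask)::

    import numpy as np

    def _propagate(graph, y0, unlabeled, variant, alpha=0.2, max_iter=30):
        # Sketch of the patched iteration for both variants.
        y_dist = y0.copy()
        if variant == 'propagation':
            y_static = y0.copy()          # hard-clamp target: the true labels
        else:                             # 'spreading'
            y_static = (1 - alpha) * y0   # soft-clamp residual
        unl = unlabeled[:, np.newaxis]
        for _ in range(max_iter):
            y_dist = graph.dot(y_dist)
            if variant == 'propagation':
                y_dist /= y_dist.sum(axis=1, keepdims=True)   # row-normalize
                y_dist = np.where(unl, y_dist, y_static)      # reset labeled rows
            else:
                y_dist = alpha * y_dist + y_static            # blend with y0
        return y_dist / y_dist.sum(axis=1, keepdims=True)

The real implementation additionally stops early once the change between
iterations falls below ``tol``; the sketch runs a fixed number of steps for
brevity.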
@@ -299,7 +318,11 @@ class LabelPropagation(BaseLabelPropagation):
         Parameter for knn kernel
 
     alpha : float
-        Clamping factor
+        Clamping factor.
+
+        .. deprecated:: 0.19
+            This parameter will be removed in 0.21.
+            'alpha' is fixed to zero in 'LabelPropagation'.
 
     max_iter : float
         Change maximum number of iterations allowed
@@ -350,6 +373,14 @@ class LabelPropagation(BaseLabelPropagation):
     LabelSpreading : Alternate label propagation strategy more robust to noise
     """
 
+    _variant = 'propagation'
+
+    def __init__(self, kernel='rbf', gamma=20, n_neighbors=7,
+                 alpha=None, max_iter=30, tol=1e-3, n_jobs=1):
+        super(LabelPropagation, self).__init__(
+            kernel=kernel, gamma=gamma, n_neighbors=n_neighbors, alpha=alpha,
+            max_iter=max_iter, tol=tol, n_jobs=n_jobs)
+
     def _build_graph(self):
         """Matrix representing a fully connected graph between each sample
 
@@ -366,6 +397,15 @@ class LabelPropagation(BaseLabelPropagation):
             affinity_matrix /= normalizer[:, np.newaxis]
         return affinity_matrix
 
+    def fit(self, X, y):
+        if self.alpha is not None:
+            warnings.warn(
+                "alpha is deprecated since 0.19 and will be removed in 0.21.",
+                DeprecationWarning
+            )
+            self.alpha = None
+        return super(LabelPropagation, self).fit(X, y)
+
 
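With the ``fit`` override above, passing any ``alpha`` to ``LabelPropagation``
is still accepted for backward compatibility, but it warns and the value is
ignored. A quick illustrative check, again on toy data::

    import warnings
    import numpy as np
    from sklearn.semi_supervised import label_propagation

    X = np.array([[1., 0.], [0., 1.], [1., 2.5]])
    y = np.array([0, 1, -1])

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        label_propagation.LabelPropagation(alpha=0.5).fit(X, y)

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)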
 class LabelSpreading(BaseLabelPropagation):
     """LabelSpreading model for semi-supervised learning
@@ -391,7 +431,11 @@ class LabelSpreading(BaseLabelPropagation):
       parameter for knn kernel
 
     alpha : float
-      clamping factor
+      Clamping factor. A value in (0, 1) that specifies the relative amount
+      that an instance should adopt the information from its neighbors as
+      opposed to its initial label.
+      alpha=0 means keeping the initial label information; alpha=1 means
+      replacing all initial information.
 
     max_iter : float
       maximum number of iterations allowed
@@ -446,6 +490,8 @@ class LabelSpreading(BaseLabelPropagation):
     LabelPropagation : Unregularized graph based semi-supervised learning
     """
 
+    _variant = 'spreading'
+
     def __init__(self, kernel='rbf', gamma=20, n_neighbors=7, alpha=0.2,
                  max_iter=30, tol=1e-3, n_jobs=1):
 
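With ``_variant`` set on each subclass, the base class now dispatches between
the two update rules. Typical use of ``LabelSpreading`` with a valid clamping
factor then looks like this (toy data again, for illustration)::

    import numpy as np
    from sklearn.semi_supervised import label_propagation

    X = np.array([[1., 0.], [0., 1.], [1., 2.5]])
    y = np.array([0, 1, -1])              # -1 marks the unlabeled point

    model = label_propagation.LabelSpreading(alpha=0.2).fit(X, y)
    print(model.transduction_)            # inferred label for every sample
    print(model.predict([[1., 2.]]))      # predictions for new points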
diff --git a/sklearn/semi_supervised/tests/test_label_propagation.py b/sklearn/semi_supervised/tests/test_label_propagation.py
index 81e7dd028b..3d5bd21a89 100644
--- a/sklearn/semi_supervised/tests/test_label_propagation.py
+++ b/sklearn/semi_supervised/tests/test_label_propagation.py
@@ -3,8 +3,12 @@
 import numpy as np
 
 from sklearn.utils.testing import assert_equal
+from sklearn.utils.testing import assert_warns
+from sklearn.utils.testing import assert_raises
+from sklearn.utils.testing import assert_no_warnings
 from sklearn.semi_supervised import label_propagation
 from sklearn.metrics.pairwise import rbf_kernel
+from sklearn.datasets import make_classification
 from numpy.testing import assert_array_almost_equal
 from numpy.testing import assert_array_equal
 
@@ -59,3 +63,85 @@ def test_predict_proba():
         clf = estimator(**parameters).fit(samples, labels)
         assert_array_almost_equal(clf.predict_proba([[1., 1.]]),
                                   np.array([[0.5, 0.5]]))
+
+
+def test_alpha_deprecation():
+    X, y = make_classification(n_samples=100)
+    y[::3] = -1
+
+    lp_default = label_propagation.LabelPropagation(kernel='rbf', gamma=0.1)
+    lp_default_y = assert_no_warnings(lp_default.fit, X, y).transduction_
+
+    lp_0 = label_propagation.LabelPropagation(alpha=0, kernel='rbf', gamma=0.1)
+    lp_0_y = assert_warns(DeprecationWarning, lp_0.fit, X, y).transduction_
+
+    assert_array_equal(lp_default_y, lp_0_y)
+
+
+def test_label_spreading_closed_form():
+    n_classes = 2
+    X, y = make_classification(n_classes=n_classes, n_samples=200,
+                               random_state=0)
+    y[::3] = -1
+    clf = label_propagation.LabelSpreading().fit(X, y)
+    # adopting notation from Zhou et al (2004):
+    S = clf._build_graph()
+    Y = np.zeros((len(y), n_classes + 1))
+    Y[np.arange(len(y)), y] = 1
+    Y = Y[:, :-1]
+    for alpha in [0.1, 0.3, 0.5, 0.7, 0.9]:
+        expected = np.dot(np.linalg.inv(np.eye(len(S)) - alpha * S), Y)
+        expected /= expected.sum(axis=1)[:, np.newaxis]
+        clf = label_propagation.LabelSpreading(max_iter=10000, alpha=alpha)
+        clf.fit(X, y)
+        assert_array_almost_equal(expected, clf.label_distributions_, 4)
+
+
+def test_label_propagation_closed_form():
+    n_classes = 2
+    X, y = make_classification(n_classes=n_classes, n_samples=200,
+                               random_state=0)
+    y[::3] = -1
+    Y = np.zeros((len(y), n_classes + 1))
+    Y[np.arange(len(y)), y] = 1
+    unlabelled_idx = Y[:, (-1,)].nonzero()[0]
+    labelled_idx = (Y[:, (-1,)] == 0).nonzero()[0]
+
+    clf = label_propagation.LabelPropagation(max_iter=10000,
+                                             gamma=0.1).fit(X, y)
+    # adopting notation from Zhu et al 2002
+    T_bar = clf._build_graph()
+    Tuu = T_bar[np.meshgrid(unlabelled_idx, unlabelled_idx, indexing='ij')]
+    Tul = T_bar[np.meshgrid(unlabelled_idx, labelled_idx, indexing='ij')]
+    Y = Y[:, :-1]
+    Y_l = Y[labelled_idx, :]
+    Y_u = np.dot(np.dot(np.linalg.inv(np.eye(Tuu.shape[0]) - Tuu), Tul), Y_l)
+
+    expected = Y.copy()
+    expected[unlabelled_idx, :] = Y_u
+    expected /= expected.sum(axis=1)[:, np.newaxis]
+
+    assert_array_almost_equal(expected, clf.label_distributions_, 4)
+
+
+def test_valid_alpha():
+    n_classes = 2
+    X, y = make_classification(n_classes=n_classes, n_samples=200,
+                               random_state=0)
+    for alpha in [-0.1, 0, 1, 1.1, None]:
+        assert_raises(ValueError,
+                      lambda **kwargs:
+                      label_propagation.LabelSpreading(**kwargs).fit(X, y),
+                      alpha=alpha)
+
+
+def test_convergence_speed():
+    # This is a non-regression test for #5774
+    X = np.array([[1., 0.], [0., 1.], [1., 2.5]])
+    y = np.array([0, 1, -1])
+    mdl = label_propagation.LabelSpreading(kernel='rbf', max_iter=5000)
+    mdl.fit(X, y)
+
+    # this should converge quickly:
+    assert mdl.n_iter_ < 10
+    assert_array_equal(mdl.predict(X), [0, 1, 1])
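For the record, the closed forms these tests check against, as reconstructed
from the papers they cite (Zhou et al. (2004) for spreading, Zhu et al. (2002)
for propagation), are

.. math::

    Y^{*} \propto (I - \alpha S)^{-1} Y^{(0)}
    \qquad\text{and}\qquad
    Y_{u}^{*} = (I - T_{uu})^{-1} T_{ul} Y_{l}

where the spreading fixed point is taken up to row normalization (which
absorbs the usual :math:`(1 - \alpha)` factor), and :math:`T_{uu}`,
:math:`T_{ul}` are the unlabeled-unlabeled and unlabeled-labeled blocks of the
normalized transition matrix.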
-- 
GitLab