diff --git a/doc/modules/label_propagation.rst b/doc/modules/label_propagation.rst
index eddc34b7a8c7c18c927cb6f500998e6b81e4048d..1aba742723f01754feb1fc035892c86ecebe3082 100644
--- a/doc/modules/label_propagation.rst
+++ b/doc/modules/label_propagation.rst
@@ -52,8 +52,8 @@ differ in modifications to the similarity matrix that graph and the
 clamping effect on the label distributions.
 Clamping allows the algorithm to change the weight of the true ground labeled
 data to some degree. The :class:`LabelPropagation` algorithm performs hard
-clamping of input labels, which means :math:`\alpha=1`. This clamping factor
-can be relaxed, to say :math:`\alpha=0.8`, which means that we will always
+clamping of input labels, which means :math:`\alpha=0`. This clamping factor
+can be relaxed, to say :math:`\alpha=0.2`, which means that we will always
 retain 80 percent of our original label distribution, but the algorithm gets to
 change its confidence of the distribution within 20 percent.
 
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index a9601419c9edd2edda768fae4c732e6e88f09919..73fa6dcee8b060246afe12f973f022f12eacad4f 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -448,7 +448,16 @@ Bug fixes
      in :class:`decomposition.PCA`,
      :class:`decomposition.RandomizedPCA` and
      :class:`decomposition.IncrementalPCA`.
-     :issue:`9105` by `Hanmin Qin <https://github.com/qinhanmin2014>`_. 
+     :issue:`9105` by `Hanmin Qin <https://github.com/qinhanmin2014>`_.
+
+   - Fix :class:`semi_supervised.BaseLabelPropagation` to correctly implement
+     ``LabelPropagation`` and ``LabelSpreading`` as done in the referenced
+     papers. :class:`semi_supervised.LabelPropagation` now always does hard
+     clamping. Its ``alpha`` parameter has no effect and is
+     deprecated to be removed in 0.21. :issue:`6727` :issue:`3550` issue:`5770`
+     by :user:`Andre Ambrosio Boechat <boechat107>`, :user:`Utkarsh Upadhyay
+     <musically-ut>`, and `Joel Nothman`_.
+
 
 API changes summary
 -------------------
diff --git a/examples/semi_supervised/plot_label_propagation_structure.py b/examples/semi_supervised/plot_label_propagation_structure.py
index 7cc15d73f1b891c207b032e597483e944d32a5cb..95f19ec108e820b9ed4d6292b50616a27388a2fb 100644
--- a/examples/semi_supervised/plot_label_propagation_structure.py
+++ b/examples/semi_supervised/plot_label_propagation_structure.py
@@ -30,7 +30,7 @@ labels[-1] = inner
 
 # #############################################################################
 # Learn with LabelSpreading
-label_spread = label_propagation.LabelSpreading(kernel='knn', alpha=1.0)
+label_spread = label_propagation.LabelSpreading(kernel='knn', alpha=0.2)
 label_spread.fit(X, labels)
 
 # #############################################################################
diff --git a/sklearn/semi_supervised/label_propagation.py b/sklearn/semi_supervised/label_propagation.py
index 1759b2c1d7572a3e133c1fc1731a98d6b2e0f3cc..ab0dd64bf81ea7fde43a07ce1bd8f1a74157a727 100644
--- a/sklearn/semi_supervised/label_propagation.py
+++ b/sklearn/semi_supervised/label_propagation.py
@@ -14,11 +14,12 @@ For more information see the references below.
 Model Features
 --------------
 Label clamping:
-  The algorithm tries to learn distributions of labels over the dataset. In the
-  "Hard Clamp" mode, the true ground labels are never allowed to change. They
-  are clamped into position. In the "Soft Clamp" mode, they are allowed some
-  wiggle room, but some alpha of their original value will always be retained.
-  Hard clamp is the same as soft clamping with alpha set to 1.
+  The algorithm tries to learn distributions of labels over the dataset given
+  label assignments over an initial subset. In one variant, the algorithm does
+  not allow for any errors in the initial assignment (hard-clamping) while
+  in another variant, the algorithm allows for some wiggle room for the initial
+  assignments, allowing them to change by a fraction alpha in each iteration
+  (soft-clamping).
 
 Kernel:
   A function which projects a vector into some higher dimensional space. This
@@ -55,6 +56,7 @@ Non-Parametric Function Induction in Semi-Supervised Learning. AISTAT 2005
 # License: BSD
 from abc import ABCMeta, abstractmethod
 
+import warnings
 import numpy as np
 from scipy import sparse
 
@@ -239,10 +241,13 @@ class BaseLabelPropagation(six.with_metaclass(ABCMeta, BaseEstimator,
 
         n_samples, n_classes = len(y), len(classes)
 
+        alpha = self.alpha
+        if self._variant == 'spreading' and \
+                (alpha is None or alpha <= 0.0 or alpha >= 1.0):
+            raise ValueError('alpha=%s is invalid: it must be inside '
+                             'the open interval (0, 1)' % alpha)
         y = np.asarray(y)
         unlabeled = y == -1
-        clamp_weights = np.ones((n_samples, 1))
-        clamp_weights[unlabeled, 0] = self.alpha
 
         # initialize distributions
         self.label_distributions_ = np.zeros((n_samples, n_classes))
@@ -250,13 +255,17 @@ class BaseLabelPropagation(six.with_metaclass(ABCMeta, BaseEstimator,
             self.label_distributions_[y == label, classes == label] = 1
 
         y_static = np.copy(self.label_distributions_)
-        if self.alpha > 0.:
-            y_static *= 1 - self.alpha
-        y_static[unlabeled] = 0
+        if self._variant == 'propagation':
+            # LabelPropagation
+            y_static[unlabeled] = 0
+        else:
+            # LabelSpreading
+            y_static *= 1 - alpha
 
         l_previous = np.zeros((self.X_.shape[0], n_classes))
 
         remaining_iter = self.max_iter
+        unlabeled = unlabeled[:, np.newaxis]
         if sparse.isspmatrix(graph_matrix):
             graph_matrix = graph_matrix.tocsr()
         while (_not_converged(self.label_distributions_, l_previous, self.tol)
@@ -264,13 +273,23 @@ class BaseLabelPropagation(six.with_metaclass(ABCMeta, BaseEstimator,
             l_previous = self.label_distributions_
             self.label_distributions_ = safe_sparse_dot(
                 graph_matrix, self.label_distributions_)
-            # clamp
-            self.label_distributions_ = np.multiply(
-                clamp_weights, self.label_distributions_) + y_static
+
+            if self._variant == 'propagation':
+                normalizer = np.sum(
+                    self.label_distributions_, axis=1)[:, np.newaxis]
+                self.label_distributions_ /= normalizer
+                self.label_distributions_ = np.where(unlabeled,
+                                                     self.label_distributions_,
+                                                     y_static)
+            else:
+                # clamp
+                self.label_distributions_ = np.multiply(
+                    alpha, self.label_distributions_) + y_static
             remaining_iter -= 1
 
         normalizer = np.sum(self.label_distributions_, axis=1)[:, np.newaxis]
         self.label_distributions_ /= normalizer
+
         # set the transduction item
         transduction = self.classes_[np.argmax(self.label_distributions_,
                                                axis=1)]
@@ -299,7 +318,11 @@ class LabelPropagation(BaseLabelPropagation):
         Parameter for knn kernel
 
     alpha : float
-        Clamping factor
+        Clamping factor.
+
+        .. deprecated:: 0.19
+            This parameter will be removed in 0.21.
+            'alpha' is fixed to zero in 'LabelPropagation'.
 
     max_iter : float
         Change maximum number of iterations allowed
@@ -350,6 +373,14 @@ class LabelPropagation(BaseLabelPropagation):
     LabelSpreading : Alternate label propagation strategy more robust to noise
     """
 
+    _variant = 'propagation'
+
+    def __init__(self, kernel='rbf', gamma=20, n_neighbors=7,
+                 alpha=None, max_iter=30, tol=1e-3, n_jobs=1):
+        super(LabelPropagation, self).__init__(
+            kernel=kernel, gamma=gamma, n_neighbors=n_neighbors, alpha=alpha,
+            max_iter=max_iter, tol=tol, n_jobs=n_jobs)
+
     def _build_graph(self):
         """Matrix representing a fully connected graph between each sample
 
@@ -366,6 +397,15 @@ class LabelPropagation(BaseLabelPropagation):
             affinity_matrix /= normalizer[:, np.newaxis]
         return affinity_matrix
 
+    def fit(self, X, y):
+        if self.alpha is not None:
+            warnings.warn(
+                "alpha is deprecated since 0.19 and will be removed in 0.21.",
+                DeprecationWarning
+            )
+            self.alpha = None
+        return super(LabelPropagation, self).fit(X, y)
+
 
 class LabelSpreading(BaseLabelPropagation):
     """LabelSpreading model for semi-supervised learning
@@ -391,7 +431,11 @@ class LabelSpreading(BaseLabelPropagation):
       parameter for knn kernel
 
     alpha : float
-      clamping factor
+      Clamping factor. A value in [0, 1] that specifies the relative amount
+      that an instance should adopt the information from its neighbors as
+      opposed to its initial label.
+      alpha=0 means keeping the initial label information; alpha=1 means
+      replacing all initial information.
 
     max_iter : float
       maximum number of iterations allowed
@@ -446,6 +490,8 @@ class LabelSpreading(BaseLabelPropagation):
     LabelPropagation : Unregularized graph based semi-supervised learning
     """
 
+    _variant = 'spreading'
+
     def __init__(self, kernel='rbf', gamma=20, n_neighbors=7, alpha=0.2,
                  max_iter=30, tol=1e-3, n_jobs=1):
 
diff --git a/sklearn/semi_supervised/tests/test_label_propagation.py b/sklearn/semi_supervised/tests/test_label_propagation.py
index 81e7dd028bf5d56c8c5eed42e5410150ab22ebf8..3d5bd21a89110c5bf5f81684d96e814a56badd33 100644
--- a/sklearn/semi_supervised/tests/test_label_propagation.py
+++ b/sklearn/semi_supervised/tests/test_label_propagation.py
@@ -3,8 +3,12 @@
 import numpy as np
 
 from sklearn.utils.testing import assert_equal
+from sklearn.utils.testing import assert_warns
+from sklearn.utils.testing import assert_raises
+from sklearn.utils.testing import assert_no_warnings
 from sklearn.semi_supervised import label_propagation
 from sklearn.metrics.pairwise import rbf_kernel
+from sklearn.datasets import make_classification
 from numpy.testing import assert_array_almost_equal
 from numpy.testing import assert_array_equal
 
@@ -59,3 +63,85 @@ def test_predict_proba():
         clf = estimator(**parameters).fit(samples, labels)
         assert_array_almost_equal(clf.predict_proba([[1., 1.]]),
                                   np.array([[0.5, 0.5]]))
+
+
+def test_alpha_deprecation():
+    X, y = make_classification(n_samples=100)
+    y[::3] = -1
+
+    lp_default = label_propagation.LabelPropagation(kernel='rbf', gamma=0.1)
+    lp_default_y = assert_no_warnings(lp_default.fit, X, y).transduction_
+
+    lp_0 = label_propagation.LabelPropagation(alpha=0, kernel='rbf', gamma=0.1)
+    lp_0_y = assert_warns(DeprecationWarning, lp_0.fit, X, y).transduction_
+
+    assert_array_equal(lp_default_y, lp_0_y)
+
+
+def test_label_spreading_closed_form():
+    n_classes = 2
+    X, y = make_classification(n_classes=n_classes, n_samples=200,
+                               random_state=0)
+    y[::3] = -1
+    clf = label_propagation.LabelSpreading().fit(X, y)
+    # adopting notation from Zhou et al (2004):
+    S = clf._build_graph()
+    Y = np.zeros((len(y), n_classes + 1))
+    Y[np.arange(len(y)), y] = 1
+    Y = Y[:, :-1]
+    for alpha in [0.1, 0.3, 0.5, 0.7, 0.9]:
+        expected = np.dot(np.linalg.inv(np.eye(len(S)) - alpha * S), Y)
+        expected /= expected.sum(axis=1)[:, np.newaxis]
+        clf = label_propagation.LabelSpreading(max_iter=10000, alpha=alpha)
+        clf.fit(X, y)
+        assert_array_almost_equal(expected, clf.label_distributions_, 4)
+
+
+def test_label_propagation_closed_form():
+    n_classes = 2
+    X, y = make_classification(n_classes=n_classes, n_samples=200,
+                               random_state=0)
+    y[::3] = -1
+    Y = np.zeros((len(y), n_classes + 1))
+    Y[np.arange(len(y)), y] = 1
+    unlabelled_idx = Y[:, (-1,)].nonzero()[0]
+    labelled_idx = (Y[:, (-1,)] == 0).nonzero()[0]
+
+    clf = label_propagation.LabelPropagation(max_iter=10000,
+                                             gamma=0.1).fit(X, y)
+    # adopting notation from Zhu et al 2002
+    T_bar = clf._build_graph()
+    Tuu = T_bar[np.meshgrid(unlabelled_idx, unlabelled_idx, indexing='ij')]
+    Tul = T_bar[np.meshgrid(unlabelled_idx, labelled_idx, indexing='ij')]
+    Y = Y[:, :-1]
+    Y_l = Y[labelled_idx, :]
+    Y_u = np.dot(np.dot(np.linalg.inv(np.eye(Tuu.shape[0]) - Tuu), Tul), Y_l)
+
+    expected = Y.copy()
+    expected[unlabelled_idx, :] = Y_u
+    expected /= expected.sum(axis=1)[:, np.newaxis]
+
+    assert_array_almost_equal(expected, clf.label_distributions_, 4)
+
+
+def test_valid_alpha():
+    n_classes = 2
+    X, y = make_classification(n_classes=n_classes, n_samples=200,
+                               random_state=0)
+    for alpha in [-0.1, 0, 1, 1.1, None]:
+        assert_raises(ValueError,
+                      lambda **kwargs:
+                      label_propagation.LabelSpreading(**kwargs).fit(X, y),
+                      alpha=alpha)
+
+
+def test_convergence_speed():
+    # This is a non-regression test for #5774
+    X = np.array([[1., 0.], [0., 1.], [1., 2.5]])
+    y = np.array([0, 1, -1])
+    mdl = label_propagation.LabelSpreading(kernel='rbf', max_iter=5000)
+    mdl.fit(X, y)
+
+    # this should converge quickly:
+    assert mdl.n_iter_ < 10
+    assert_array_equal(mdl.predict(X), [0, 1, 1])