diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 66e8212947cf015e5b7df4c5f36c683a95807ed5..34f82c30e981e4b4a17ebf3908b2f4e06703a17e 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -110,6 +110,10 @@ Enhancements
      attributes, ``n_skips_*``.
      :issue:`7914` by :user:`Michael Horrell <mthorrell>`.
 
+   - :func:`model_selection.cross_val_predict` now returns output of the
+     correct shape for all values of the argument ``method``.
+     :issue:`7863` by :user:`Aman Dalmia <dalmia>`.
+
    - Fix a bug where :class:`sklearn.feature_selection.SelectFdr` did not
      exactly implement Benjamini-Hochberg procedure. It formerly may have
      selected fewer features than it should.
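
For context, a minimal sketch of the behaviour described in the entry above;
the estimator, labels and CV choice are illustrative (they mirror the new
test added further down) and are not part of the patch:

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import KFold, cross_val_predict

    # Three classes; with a plain KFold, some training folds only see a
    # subset of them.
    X = np.arange(8).reshape(4, 2)
    y = np.array([0, 0, 1, 2])

    proba = cross_val_predict(LogisticRegression(), X, y,
                              cv=KFold(n_splits=3), method='predict_proba')
    # With this fix, proba.shape == (4, 3): one column per class, in sorted
    # label order, with zeros for classes a fold's estimator never saw.
    print(proba.shape)
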
diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py
index 91f60366f871703ac0dbcf4256a9aee0e4b3caee..ab18d9035b4d207b6b9259c68e2026c24f7629e0 100644
--- a/sklearn/model_selection/_validation.py
+++ b/sklearn/model_selection/_validation.py
@@ -28,6 +28,7 @@ from ..externals.joblib import Parallel, delayed, logger
 from ..metrics.scorer import check_scoring
 from ..exceptions import FitFailedWarning
 from ._split import check_cv
+from ..preprocessing import LabelEncoder
 
 __all__ = ['cross_val_score', 'cross_val_predict', 'permutation_test_score',
            'learning_curve', 'validation_curve']
@@ -364,7 +365,9 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
               as in '2*n_jobs'
 
     method : string, optional, default: 'predict'
-        Invokes the passed method name of the passed estimator.
+        Invokes the passed method name of the passed estimator. For
+        method='predict_proba', the columns correspond to the classes
+        in sorted order.
 
     Returns
     -------
@@ -390,6 +393,10 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
         raise AttributeError('{} not implemented in estimator'
                              .format(method))
 
+    if method in ['decision_function', 'predict_proba', 'predict_log_proba']:
+        le = LabelEncoder()
+        y = le.fit_transform(y)
+
     # We clone the estimator to make sure that all the folds are
     # independent, and that it is pickle-able.
     parallel = Parallel(n_jobs=n_jobs, verbose=verbose,
@@ -472,6 +479,14 @@ def _fit_and_predict(estimator, X, y, train, test, verbose, fit_params,
         estimator.fit(X_train, y_train, **fit_params)
     func = getattr(estimator, method)
     predictions = func(X_test)
+    if method in ['decision_function', 'predict_proba', 'predict_log_proba']:
+        n_classes = len(set(y))
+        predictions_ = np.zeros((X_test.shape[0], n_classes))
+        if method == 'decision_function' and len(estimator.classes_) == 2:
+            predictions_[:, estimator.classes_[-1]] = predictions
+        else:
+            predictions_[:, estimator.classes_] = predictions
+        predictions = predictions_
     return predictions, test
 
 
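The padding step added to _fit_and_predict above is the heart of the fix; the
standalone snippet below sketches the same numpy indexing idea on made-up fold
output (all names here are illustrative, not taken from the patch):

    import numpy as np

    # Suppose the full problem has 3 classes but this fold's estimator only
    # saw the (label-encoded) classes 0 and 2, so predict_proba returned
    # two columns.
    n_classes = 3
    fold_classes = np.array([0, 2])       # what estimator.classes_ would hold
    fold_output = np.array([[0.9, 0.1],
                            [0.3, 0.7]])  # shape (n_test_samples, 2)

    padded = np.zeros((fold_output.shape[0], n_classes))
    padded[:, fold_classes] = fold_output  # the unseen class 1 stays at 0.0
    # Every fold now contributes an (n_test_samples, n_classes) block whose
    # columns line up, so cross_val_predict can stitch them together.
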
diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py
index d1f83b469d6c8dd65a19a7039a63b22788815e29..c6ae5f3fdd18a553e6de886957c95f607f0cfefc 100644
--- a/sklearn/model_selection/tests/test_validation.py
+++ b/sklearn/model_selection/tests/test_validation.py
@@ -51,6 +51,7 @@ from sklearn.svm import SVC
 from sklearn.cluster import KMeans
 
 from sklearn.preprocessing import Imputer
+from sklearn.preprocessing import LabelEncoder
 from sklearn.pipeline import Pipeline
 
 from sklearn.externals.six.moves import cStringIO as StringIO
@@ -940,6 +941,79 @@ def test_cross_val_predict_with_method():
                                         cv=kfold)
         assert_array_almost_equal(expected_predictions, predictions)
 
+        # Test alternative representations of y
+        predictions_y1 = cross_val_predict(est, X, y + 1, method=method,
+                                           cv=kfold)
+        assert_array_equal(predictions, predictions_y1)
+
+        predictions_y2 = cross_val_predict(est, X, y - 2, method=method,
+                                           cv=kfold)
+        assert_array_equal(predictions, predictions_y2)
+
+        predictions_ystr = cross_val_predict(est, X, y.astype('str'),
+                                             method=method, cv=kfold)
+        assert_array_equal(predictions, predictions_ystr)
+
+
+def get_expected_predictions(X, y, cv, classes, est, method):
+
+    expected_predictions = np.zeros((len(y), classes))
+    func = getattr(est, method)
+
+    for train, test in cv.split(X, y):
+        est.fit(X[train], y[train])
+        expected_predictions_ = func(X[test])
+        # Build a zero block per test fold to avoid 2-dimensional indexing
+        exp_pred_test = np.zeros((len(test), classes))
+        if method == 'decision_function' and len(est.classes_) == 2:
+            exp_pred_test[:, est.classes_[-1]] = expected_predictions_
+        else:
+            exp_pred_test[:, est.classes_] = expected_predictions_
+        expected_predictions[test] = exp_pred_test
+
+    return expected_predictions
+
+
+def test_cross_val_predict_class_subset():
+
+    X = np.arange(8).reshape(4, 2)
+    y = np.array([0, 0, 1, 2])
+    classes = 3
+
+    kfold3 = KFold(n_splits=3)
+    kfold4 = KFold(n_splits=4)
+
+    le = LabelEncoder()
+
+    methods = ['decision_function', 'predict_proba', 'predict_log_proba']
+    for method in methods:
+        est = LogisticRegression()
+
+        # Test with n_splits=3
+        predictions = cross_val_predict(est, X, y, method=method,
+                                        cv=kfold3)
+
+        # Runs a naive loop (should be same as cross_val_predict):
+        expected_predictions = get_expected_predictions(X, y, kfold3, classes,
+                                                        est, method)
+        assert_array_almost_equal(expected_predictions, predictions)
+
+        # Test with n_splits=4
+        predictions = cross_val_predict(est, X, y, method=method,
+                                        cv=kfold4)
+        expected_predictions = get_expected_predictions(X, y, kfold4, classes,
+                                                        est, method)
+        assert_array_almost_equal(expected_predictions, predictions)
+
+        # Unordered labels should give the same result as encoded labels
+        y_unordered = [1, 1, -4, 6]
+        predictions = cross_val_predict(est, X, y_unordered, method=method,
+                                        cv=kfold3)
+        y_encoded = le.fit_transform(y_unordered)
+        expected_predictions = get_expected_predictions(X, y_encoded, kfold3,
+                                                        classes, est, method)
+        assert_array_almost_equal(expected_predictions, predictions)
+
 
 def test_score_memmap():
     # Ensure a scalar score of memmap type is accepted