From a8306d46a672344bd590943269dd889ccd142291 Mon Sep 17 00:00:00 2001
From: hongkahjun <khong008@e.ntu.edu.sg>
Date: Wed, 5 Jul 2017 14:48:30 +0200
Subject: [PATCH] [MRG+1] NMF speed-up for beta_loss = 0 (#9277)

---
 sklearn/decomposition/nmf.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/sklearn/decomposition/nmf.py b/sklearn/decomposition/nmf.py
index 72a52f802a..47eb42496f 100644
--- a/sklearn/decomposition/nmf.py
+++ b/sklearn/decomposition/nmf.py
@@ -545,6 +545,13 @@ def _multiplicative_update_w(X, W, H, beta_loss, l1_reg_W, l2_reg_W, gamma,
 
         if beta_loss == 1:
             np.divide(X_data, WH_safe_X_data, out=WH_safe_X_data)
+        elif beta_loss == 0:
+            # speeds up computation time
+            # refer to /numpy/numpy/issues/9363
+            WH_safe_X_data **= -1
+            WH_safe_X_data **= 2
+            # element-wise multiplication
+            WH_safe_X_data *= X_data
         else:
             WH_safe_X_data **= beta_loss - 2
             # element-wise multiplication
@@ -619,6 +626,13 @@ def _multiplicative_update_h(X, W, H, beta_loss, l1_reg_H, l2_reg_H, gamma):
 
         if beta_loss == 1:
             np.divide(X_data, WH_safe_X_data, out=WH_safe_X_data)
+        elif beta_loss == 0:
+            # speeds up computation time
+            # refer to /numpy/numpy/issues/9363
+            WH_safe_X_data **= -1
+            WH_safe_X_data **= 2
+            # element-wise multiplication
+            WH_safe_X_data *= X_data
         else:
             WH_safe_X_data **= beta_loss - 2
             # element-wise multiplication
@@ -1167,6 +1181,7 @@ class NMF(BaseEstimator, TransformerMixin):
     Fevotte, C., & Idier, J. (2011). Algorithms for nonnegative matrix
     factorization with the beta-divergence. Neural Computation, 23(9).
     """
+
     def __init__(self, n_components=None, init=None, solver='cd',
                  beta_loss='frobenius', tol=1e-4, max_iter=200,
                  random_state=None, alpha=0., l1_ratio=0., verbose=0,
-- 
GitLab