diff --git a/scikits/learn/metrics/metrics.py b/scikits/learn/metrics/metrics.py index e5d2b8bf9909cd745ed95fdfbebbf47242b3a72f..dd653c245eba1212bd768175e39939d70047f629 100644 --- a/scikits/learn/metrics/metrics.py +++ b/scikits/learn/metrics/metrics.py @@ -118,16 +118,39 @@ def roc_curve(y_true, y_score): y_score = y_score.ravel() thresholds = np.sort(np.unique(y_score))[::-1] - n_thresholds = thresholds.size - tpr = np.empty(n_thresholds) # True positive rate - fpr = np.empty(n_thresholds) # False positive rate n_pos = float(np.sum(y_true == classes[1])) # nb of true positive n_neg = float(np.sum(y_true == classes[0])) # nb of true negative - for i, t in enumerate(thresholds): - tpr[i] = np.sum(y_true[y_score >= t] == classes[1]) / n_pos - fpr[i] = np.sum(y_true[y_score >= t] == classes[0]) / n_neg + thresholds = np.unique(y_score) + neg_value, pos_value = classes[0], classes[1] + + tpr = np.empty(thresholds.size) # True positive rate + fpr = np.empty(thresholds.size) # False positive rate + + # Buid tpr/fpr vector + dpos = dneg = sum_pos = sum_neg = idx = 0 + + sorted_signal = sorted(zip(y_score, y_true), reverse=True) + last_input = sorted_signal[0][0] + for each, value in sorted_signal: + if each == last_input: + if value == pos_value: + dpos += 1 + else: + dneg += 1 + else: + tpr[idx] = (sum_pos + dpos) / n_pos + fpr[idx] = (sum_neg + dneg) / n_neg + sum_pos += dpos + sum_neg += dneg + dpos = 1 if value == pos_value else 0 + dneg = 1 if value == neg_value else 0 + idx += 1 + last_input = each + else: + tpr[-1] = (sum_pos + dpos) / n_pos + fpr[-1] = (sum_neg + dneg) / n_neg # hard decisions, add (0,0) if fpr.shape[0] == 2: @@ -137,6 +160,7 @@ def roc_curve(y_true, y_score): elif fpr.shape[0] == 1: fpr = np.array([0.0, fpr[0], 1.0]) tpr = np.array([0.0, tpr[0], 1.0]) + return fpr, tpr, thresholds