Skip to content
Snippets Groups Projects
Commit b54ecb6e authored by Andreas Mueller's avatar Andreas Mueller
Browse files

TST small improvement of test for sample weight in svm

parent c2575435
Branches
Tags
No related merge requests found
......@@ -9,6 +9,8 @@ is proportional to its weight.
The sample weighting rescales the C parameter, which means that the classifier
puts more emphasis on getting these points right. The effect might often be
subtle.
To emphasize the effect here, we particularly weight outliers, making the
deformation of the decision boundary very visible.
"""
print(__doc__)
......@@ -39,8 +41,9 @@ X = np.r_[np.random.randn(10, 2) + [1, 1], np.random.randn(10, 2)]
Y = [1] * 10 + [-1] * 10
sample_weight_last_ten = abs(np.random.randn(len(X)))
sample_weight_constant = np.ones(len(X))
# and assign a bigger weight to the last 5 samples
# and bigger weights to some outliers
sample_weight_last_ten[15:] *= 5
sample_weight_last_ten[9] *= 15
# for reference, first fit without sample weights
......
......@@ -104,7 +104,7 @@ class BaseLibSVM(six.with_metaclass(ABCMeta, BaseEstimator)):
Training vectors, where n_samples is the number of samples
and n_features is the number of features.
y : array-like, shape (n_samples)
y : array-like, shape (n_samples,)
Target values (class labels in classification, real numbers in
regression)
......@@ -268,7 +268,7 @@ class BaseLibSVM(six.with_metaclass(ABCMeta, BaseEstimator)):
Returns
-------
y_pred : array, shape (n_samples)
y_pred : array, shape (n_samples,)
"""
X = self._validate_for_predict(X)
predict = self._sparse_predict if self._sparse else self._dense_predict
......
......@@ -340,6 +340,14 @@ def test_sample_weights():
clf.fit(X, Y, sample_weight=sample_weight)
assert_array_equal(clf.predict(X[2]), [2.])
# test that rescaling all samples is the same as changing C
clf = svm.SVC()
clf.fit(X, Y)
dual_coef_no_weight = clf.dual_coef_
clf.set_params(C=100)
clf.fit(X, Y, sample_weight=np.repeat(0.01, len(X)))
assert_array_almost_equal(dual_coef_no_weight, clf.dual_coef_)
def test_auto_weight():
"""Test class weights for imbalanced data"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment