Skip to content
Snippets Groups Projects
Commit a3eb84d1 authored by Jaques Grobler's avatar Jaques Grobler Committed by Gael Varoquaux
Browse files

modified test_gmm to match API changes in gmm.py

parent da56baab
No related branches found
No related tags found
No related merge requests found
......@@ -191,17 +191,23 @@ class GMMTester():
X = g.sample(n_samples=100)
g = self.model(n_components=self.n_components,
covariance_type=self.covariance_type,
random_state=rng, min_covar=1e-1)
g.fit(X, n_iter=1, init_params=params)
random_state=rng, min_covar=1e-1,
n_iter=1, init_params=params)
g.fit(X)
# Do one training iteration at a time so we can keep track of
# the log likelihood to make sure that it increases after each
# iteration.
trainll = []
for iter in xrange(5):
g.fit(X, n_iter=1, params=params, init_params='')
g.params = params
g.init_params = ''
g.fit(X)
trainll.append(self.score(g, X))
g.fit(X, n_iter=10, params=params, init_params='') # finish fitting
g.n_iter = 10
g.init_params = ''
g.params = params
g.fit(X) # finish fitting
# Note that the log likelihood will sometimes decrease by a
# very small amount after it has more or less converged due to
......@@ -222,8 +228,9 @@ class GMMTester():
X = rng.randn(100, self.n_features)
X.T[1:] = 0
g = self.model(n_components=2, covariance_type=self.covariance_type,
random_state=rng, min_covar=1e-3)
g.fit(X, n_iter=5, init_params=params)
random_state=rng, min_covar=1e-3, n_iter=5,
init_params=params)
g.fit(X)
trainll = g.score(X)
self.assertTrue(np.sum(np.abs(trainll / 100 / X.shape[1])) < 5)
......@@ -234,8 +241,9 @@ class GMMTester():
X = rng.randn(100, 1)
#X.T[1:] = 0
g = self.model(n_components=2, covariance_type=self.covariance_type,
random_state=rng, min_covar=1e-7)
g.fit(X, n_iter=5, init_params=params)
random_state=rng, min_covar=1e-7, n_iter=5,
init_params=params)
g.fit(X)
trainll = g.score(X)
if isinstance(g, mixture.DPGMM):
self.assertTrue(np.sum(np.abs(trainll / 100)) < 5)
......@@ -271,9 +279,10 @@ def test_multiple_init():
X = rng.randn(30, 5)
X[:10] += 2
g = mixture.GMM(n_components=2, covariance_type='spherical',
random_state=rng, min_covar=1e-7)
train2 = g.fit(X, n_iter=5, n_init=5).score(X).sum()
train1 = g.fit(X, n_iter=5).score(X).sum()
random_state=rng, min_covar=1e-7, n_iter=5)
train1 = g.fit(X).score(X).sum()
g.n_init = 5
train2 = g.fit(X).score(X).sum()
assert train2 >= train1 - 1.e-2
......@@ -284,8 +293,8 @@ def test_n_parameters():
n_params = {'spherical': 13, 'diag': 21, 'tied': 26, 'full': 41}
for cv_type in ['full', 'tied', 'diag', 'spherical']:
g = mixture.GMM(n_components=n_components, covariance_type=cv_type,
random_state=rng, min_covar=1e-7)
g.fit(X, n_iter=1)
random_state=rng, min_covar=1e-7, n_iter=1)
g.fit(X)
assert g._n_parameters() == n_params[cv_type]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment