Skip to content

Commit

Permalink
Allow fitting single column data
Browse files Browse the repository at this point in the history
  • Loading branch information
aherbert committed Mar 5, 2024
1 parent f805657 commit 7708824
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public class MultivariateNormalMixtureExpectationMaximization {
* @throws DimensionMismatchException if rows of data have different numbers
* of columns
* @throws NumberIsTooSmallException if the number of columns in the data is
* less than 2
* less than 1
*/
public MultivariateNormalMixtureExpectationMaximization(double[][] data)
throws NotStrictlyPositiveException,
Expand All @@ -99,9 +99,9 @@ public MultivariateNormalMixtureExpectationMaximization(double[][] data)
throw new DimensionMismatchException(data[i].length,
data[0].length);
}
if (data[i].length < 2) {
if (data[i].length < 1) {
throw new NumberIsTooSmallException(LocalizedFormats.NUMBER_TOO_SMALL,
data[i].length, 2, true);
data[i].length, 1, true);
}
this.data[i] = Arrays.copyOf(data[i], data[i].length);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.commons.math4.legacy.distribution.fitting;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.commons.math4.legacy.distribution.MixtureMultivariateNormalDistribution;
Expand Down Expand Up @@ -241,6 +242,53 @@ public void testFit() {
}
}

/**
 * Tests that the fit can be performed on data with a single dimension,
 * using only the first column of the test samples.
 *
 * <p>Reference values were obtained by fitting the same column using
 * Matlab R2023b (Update 6): {@code GMModel = fitgmdist(X,2);}
 */
@Test
public void testFit1() {
    // Use only the first column of the test data
    final double[][] data = Arrays.stream(getTestSamples())
        .map(x -> new double[] {x[0]})
        .toArray(double[][]::new);

    // NegativeLogLikelihood (CM code uses the positive log-likelihood
    // divided by the number of observations)
    final double correctLogLikelihood = -2.512197016873482e+02 / data.length;
    // ComponentProportion
    final double[] correctWeights = new double[] {0.240510201974078, 0.759489798025922};
    // Since data has 1 dimension the means and covariances are single values
    // mu
    final double[] correctMeans = new double[] {-1.736139126623031, 3.899886984922886};
    // Sigma
    final double[] correctCov = new double[] {1.371327786710623, 5.254286022455004};

    final MultivariateNormalMixtureExpectationMaximization fitter
        = new MultivariateNormalMixtureExpectationMaximization(data);

    final MixtureMultivariateNormalDistribution initialMix
        = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
    fitter.fit(initialMix);
    final MixtureMultivariateNormalDistribution fittedMix = fitter.getFittedModel();
    final List<Pair<Double, MultivariateNormalDistribution>> components = fittedMix.getComponents();

    final double relError = 0.05;
    Assert.assertEquals(correctLogLikelihood,
                        fitter.getLogLikelihood(),
                        Math.abs(correctLogLikelihood) * relError);

    // Without this check the loop below would pass vacuously if the fitted
    // model returned fewer components than there are reference values.
    Assert.assertEquals(correctWeights.length, components.size());

    int i = 0;
    for (Pair<Double, MultivariateNormalDistribution> component : components) {
        final double weight = component.getFirst();
        final MultivariateNormalDistribution mvn = component.getSecond();
        final double[] mean = mvn.getMeans();
        final RealMatrix covMat = mvn.getCovariances();
        // Tolerances are relative to the magnitude of each reference value;
        // the means may be negative so their absolute value is used.
        Assert.assertEquals(correctWeights[i], weight, correctWeights[i] * relError);
        Assert.assertEquals(correctMeans[i], mean[0], Math.abs(correctMeans[i]) * relError);
        Assert.assertEquals(correctCov[i], covMat.getEntry(0, 0), correctCov[i] * relError);
        i++;
    }
}

private double[][] getTestSamples() {
// generated using R Mixtools rmvnorm with mean vectors [-1.5, 2] and
// [4, 8.2]
Expand Down

0 comments on commit 7708824

Please sign in to comment.