From b68c5df241fec0b654d919066dca0663d909bde1 Mon Sep 17 00:00:00 2001 From: Darius Morawiec Date: Sat, 2 Dec 2017 15:53:51 +0100 Subject: [PATCH] Add new templates, examples and tests for the GaussianNB classifier --- .../classifier/GaussianNB/java/basics.ipynb | 86 ++++- .../GaussianNB/java/basics_imported.ipynb | 339 ++++++++++++++++++ .../GaussianNB/java/basics_imported.py | 87 +++++ readme.md | 2 +- .../classifier/GaussianNB/__init__.py | 44 ++- .../templates/java/exported.class.txt | 67 ++++ .../java/{class.txt => separated.class.txt} | 0 ...edict.txt => separated.method.predict.txt} | 0 .../js/{class.txt => separated.class.txt} | 0 ...edict.txt => separated.method.predict.txt} | 0 .../estimator/classifier/SVC/__init__.py | 4 +- .../GaussianNB/GaussianNBJavaTest.py | 3 +- 12 files changed, 616 insertions(+), 16 deletions(-) create mode 100644 examples/estimator/classifier/GaussianNB/java/basics_imported.ipynb create mode 100644 examples/estimator/classifier/GaussianNB/java/basics_imported.py create mode 100644 sklearn_porter/estimator/classifier/GaussianNB/templates/java/exported.class.txt rename sklearn_porter/estimator/classifier/GaussianNB/templates/java/{class.txt => separated.class.txt} (100%) rename sklearn_porter/estimator/classifier/GaussianNB/templates/java/{method.predict.txt => separated.method.predict.txt} (100%) rename sklearn_porter/estimator/classifier/GaussianNB/templates/js/{class.txt => separated.class.txt} (100%) rename sklearn_porter/estimator/classifier/GaussianNB/templates/js/{method.predict.txt => separated.method.predict.txt} (100%) diff --git a/examples/estimator/classifier/GaussianNB/java/basics.ipynb b/examples/estimator/classifier/GaussianNB/java/basics.ipynb index b17b7e2e..c756f070 100644 --- a/examples/estimator/classifier/GaussianNB/java/basics.ipynb +++ b/examples/estimator/classifier/GaussianNB/java/basics.ipynb @@ -82,7 +82,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": { "scrolled": false }, @@ -137,9 +137,9 @@ " }\n", "\n", " // Parameters:\n", - " double[] priors = {0.33333333333333331, 0.33333333333333331, 0.33333333333333331};\n", - " double[][] sigmas = {{0.12176400309242481, 0.14227600309242491, 0.029504003092424898, 0.011264003092424885}, {0.26110400309242499, 0.096500003092424902, 0.21640000309242502, 0.038324003092424869}, {0.39625600309242481, 0.10192400309242496, 0.29849600309242508, 0.073924003092424875}};\n", - " double[][] thetas = {{5.0059999999999993, 3.4180000000000006, 1.464, 0.24399999999999991}, {5.9359999999999999, 2.7700000000000005, 4.2599999999999998, 1.3259999999999998}, {6.5879999999999983, 2.9739999999999998, 5.5519999999999996, 2.0259999999999998}};\n", + " double[] priors = {0.333333333333, 0.333333333333, 0.333333333333};\n", + " double[][] sigmas = {{0.121764003092, 0.142276003092, 0.0295040030924, 0.0112640030924}, {0.261104003092, 0.0965000030924, 0.216400003092, 0.0383240030924}, {0.396256003092, 0.101924003092, 0.298496003092, 0.0739240030924}};\n", + " double[][] thetas = {{5.006, 3.418, 1.464, 0.244}, {5.936, 2.77, 4.26, 1.326}, {6.588, 2.974, 5.552, 2.026}};\n", "\n", " // Prediction:\n", " GaussianNB clf = new GaussianNB(priors, sigmas, thetas);\n", @@ -148,11 +148,15 @@ "\n", " }\n", " }\n", - "}\n" + "}\n", + "CPU times: user 1.06 ms, sys: 777 µs, total: 1.83 ms\n", + "Wall time: 1.15 ms\n" ] } ], "source": [ + "%%time\n", + "\n", "from sklearn_porter import Porter\n", "\n", "porter = Porter(clf)\n", @@ -160,6 +164,78 @@ "\n", "print(output)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run classification in Java:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Save the transpiled estimator:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "with open('GaussianNB.java', 'w') as f:\n", + " f.write(output)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compiling:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "%%bash\n", + "\n", + "javac -cp . GaussianNB.java" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Prediction:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2\n" + ] + } + ], + "source": [ + "%%bash\n", + "\n", + "java -cp . GaussianNB 1 2 3 4" + ] } ], "metadata": { diff --git a/examples/estimator/classifier/GaussianNB/java/basics_imported.ipynb b/examples/estimator/classifier/GaussianNB/java/basics_imported.ipynb new file mode 100644 index 00000000..134217a6 --- /dev/null +++ b/examples/estimator/classifier/GaussianNB/java/basics_imported.ipynb @@ -0,0 +1,339 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# sklearn-porter\n", + "\n", + "Repository: https://github.com/nok/sklearn-porter\n", + "\n", + "## GaussianNB\n", + "\n", + "Documentation: [sklearn.naive_bayes.GaussianNB](http://scikit-learn.org/stable/modules/generated/sklearn.naive_bayes.GaussianNB.html)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Loading data:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "((150, 4), (150,))\n" + ] + } + ], + "source": [ + "from sklearn.datasets import load_iris\n", + "\n", + "iris_data = load_iris()\n", + "X = iris_data.data\n", + "y = iris_data.target\n", + "\n", + "print(X.shape, y.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Train classifier:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "GaussianNB(priors=None)" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from sklearn.naive_bayes import GaussianNB\n", + "\n", + "clf = GaussianNB()\n", + "clf.fit(X, y)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Transpile classifier:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "import java.io.File;\n", + "import java.io.FileNotFoundException;\n", + "import java.util.*;\n", + "import com.google.gson.Gson;\n", + "\n", + "\n", + "class GaussianNB {\n", + "\n", + " private class Classifier {\n", + " private double[] priors;\n", + " private double[][] sigmas;\n", + " private double[][] thetas;\n", + " }\n", + "\n", + " private Classifier clf;\n", + "\n", + " public GaussianNB(String file) throws FileNotFoundException {\n", + " String jsonStr = new Scanner(new File(file)).useDelimiter(\"\\\\Z\").next();\n", + " this.clf = new Gson().fromJson(jsonStr, Classifier.class);\n", + " }\n", + "\n", + " public int predict(double[] features) {\n", + " double[] likelihoods = new double[this.clf.sigmas.length];\n", + "\n", + " for (int i = 0, il = this.clf.sigmas.length; i < il; i++) {\n", + " double sum = 0.;\n", + " for (int j = 0, jl = this.clf.sigmas[0].length; j < jl; j++) {\n", + " sum += Math.log(2. * Math.PI * this.clf.sigmas[i][j]);\n", + " }\n", + " double nij = -0.5 * sum;\n", + " sum = 0.;\n", + " for (int j = 0, jl = this.clf.sigmas[0].length; j < jl; j++) {\n", + " sum += Math.pow(features[j] - this.clf.thetas[i][j], 2.) / this.clf.sigmas[i][j];\n", + " }\n", + " nij -= 0.5 * sum;\n", + " likelihoods[i] = Math.log(this.clf.priors[i]) + nij;\n", + " }\n", + "\n", + " int classIdx = 0;\n", + " for (int i = 0, l = likelihoods.length; i < l; i++) {\n", + " classIdx = likelihoods[i] > likelihoods[classIdx] ? i : classIdx;\n", + " }\n", + " return classIdx;\n", + " }\n", + "\n", + " public static void main(String[] args) throws FileNotFoundException {\n", + " if (args.length > 0 && args[0].endsWith(\".json\")) {\n", + "\n", + " // Features:\n", + " double[] features = new double[args.length-1];\n", + " for (int i = 1, l = args.length; i < l; i++) {\n", + " features[i - 1] = Double.parseDouble(args[i]);\n", + " }\n", + "\n", + " // Parameters:\n", + " String modelData = args[0];\n", + "\n", + " // Estimators:\n", + " GaussianNB clf = new GaussianNB(modelData);\n", + "\n", + " // Prediction:\n", + " int prediction = clf.predict(features);\n", + " System.out.println(prediction);\n", + "\n", + " }\n", + " }\n", + "}\n", + "CPU times: user 1.31 ms, sys: 1.16 ms, total: 2.48 ms\n", + "Wall time: 2.26 ms\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "from sklearn_porter import Porter\n", + "\n", + "porter = Porter(clf)\n", + "output = porter.export(export_data=True)\n", + "\n", + "print(output)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Parameters:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\"priors\": [0.333333333333, 0.333333333333, 0.333333333333], \"sigmas\": [[0.121764003092, 0.142276003092, 0.0295040030924, 0.0112640030924], [0.261104003092, 0.0965000030924, 0.216400003092, 0.0383240030924], [0.396256003092, 0.101924003092, 0.298496003092, 0.0739240030924]], \"thetas\": [[5.006, 3.418, 1.464, 0.244], [5.936, 2.77, 4.26, 1.326], [6.588, 2.974, 5.552, 2.026]]}" + ] + } + ], + "source": [ + "%%bash\n", + "\n", + "cat data.json" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "hideOutput": false + }, + "source": [ + "### Run classification in Java:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Save the transpiled estimator:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "with open('GaussianNB.java', 'w') as f:\n", + " f.write(output)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Download the dependencies:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "--2017-12-02 15:31:18-- http://central.maven.org/maven2/com/google/code/gson/gson/2.8.2/gson-2.8.2.jar\n", + "Resolving central.maven.org... 151.101.36.209\n", + "Connecting to central.maven.org|151.101.36.209|:80... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 232932 (227K) [application/java-archive]\n", + "Saving to: 'gson-2.8.2.jar'\n", + "\n", + " 0K .......... .......... .......... .......... .......... 21% 1.93M 0s\n", + " 50K .......... .......... .......... .......... .......... 43% 3.07M 0s\n", + " 100K .......... .......... .......... .......... .......... 65% 3.27M 0s\n", + " 150K .......... .......... .......... .......... .......... 87% 3.56M 0s\n", + " 200K .......... .......... ....... 100% 2.18M=0.08s\n", + "\n", + "2017-12-02 15:31:18 (2.70 MB/s) - 'gson-2.8.2.jar' saved [232932/232932]\n", + "\n" + ] + } + ], + "source": [ + "%%bash\n", + "\n", + "wget http://central.maven.org/maven2/com/google/code/gson/gson/2.8.2/gson-2.8.2.jar" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compiling:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "%%bash\n", + "\n", + "javac -cp .:gson-2.8.2.jar GaussianNB.java" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Prediction:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2\n" + ] + } + ], + "source": [ + "%%bash\n", + "\n", + "java -cp .:gson-2.8.2.jar GaussianNB data.json 1 2 3 4" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/estimator/classifier/GaussianNB/java/basics_imported.py b/examples/estimator/classifier/GaussianNB/java/basics_imported.py new file mode 100644 index 00000000..33e31cca --- /dev/null +++ b/examples/estimator/classifier/GaussianNB/java/basics_imported.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- + +from sklearn.datasets import load_iris +from sklearn.naive_bayes import GaussianNB +from sklearn_porter import Porter + + +iris_data = load_iris() +X = iris_data.data +y = iris_data.target + +clf = GaussianNB() +clf.fit(X, y) + +porter = Porter(clf) +output = porter.export(export_data=True) +print(output) + +""" +import java.io.File; +import java.io.FileNotFoundException; +import java.util.*; +import com.google.gson.Gson; + + +class GaussianNB { + + private class Classifier { + private double[] priors; + private double[][] sigmas; + private double[][] thetas; + } + + private Classifier clf; + + public GaussianNB(String file) throws FileNotFoundException { + String jsonStr = new Scanner(new File(file)).useDelimiter("\\Z").next(); + this.clf = new Gson().fromJson(jsonStr, Classifier.class); + } + + public int predict(double[] features) { + double[] likelihoods = new double[this.clf.sigmas.length]; + + for (int i = 0, il = this.clf.sigmas.length; i < il; i++) { + double sum = 0.; + for (int j = 0, jl = this.clf.sigmas[0].length; j < jl; j++) { + sum += Math.log(2. * Math.PI * this.clf.sigmas[i][j]); + } + double nij = -0.5 * sum; + sum = 0.; + for (int j = 0, jl = this.clf.sigmas[0].length; j < jl; j++) { + sum += Math.pow(features[j] - this.clf.thetas[i][j], 2.) / this.clf.sigmas[i][j]; + } + nij -= 0.5 * sum; + likelihoods[i] = Math.log(this.clf.priors[i]) + nij; + } + + int classIdx = 0; + for (int i = 0, l = likelihoods.length; i < l; i++) { + classIdx = likelihoods[i] > likelihoods[classIdx] ? i : classIdx; + } + return classIdx; + } + + public static void main(String[] args) throws FileNotFoundException { + if (args.length > 0 && args[0].endsWith(".json")) { + + // Features: + double[] features = new double[args.length-1]; + for (int i = 1, l = args.length; i < l; i++) { + features[i - 1] = Double.parseDouble(args[i]); + } + + // Parameters: + String modelData = args[0]; + + // Estimators: + GaussianNB clf = new GaussianNB(modelData); + + // Prediction: + int prediction = clf.predict(features); + System.out.println(prediction); + + } + } +} +""" diff --git a/readme.md b/readme.md index 254ec77c..11951d7c 100644 --- a/readme.md +++ b/readme.md @@ -101,7 +101,7 @@ Transpile trained [scikit-learn](https://github.com/scikit-learn/scikit-learn) e naive_bayes.GaussianNB - + , ✓ ᴵ diff --git a/sklearn_porter/estimator/classifier/GaussianNB/__init__.py b/sklearn_porter/estimator/classifier/GaussianNB/__init__.py index 28a95ec1..3c2a923a 100644 --- a/sklearn_porter/estimator/classifier/GaussianNB/__init__.py +++ b/sklearn_porter/estimator/classifier/GaussianNB/__init__.py @@ -1,5 +1,9 @@ # -*- coding: utf-8 -*- +import os +import json +from json import encoder + from sklearn_porter.estimator.classifier.Classifier import Classifier @@ -52,7 +56,9 @@ def __init__(self, estimator, target_language='java', target_method=target_method, **kwargs) self.estimator = estimator - def export(self, class_name, method_name, **kwargs): + def export(self, class_name, method_name, + export_data=False, export_dir='.', + **kwargs): """ Port a trained estimator to the syntax of a chosen programming language. @@ -111,9 +117,14 @@ def export(self, class_name, method_name, **kwargs): values=thetas) if self.target_method == 'predict': - return self.predict() - - def predict(self): + # Exported: + if export_data and os.path.isdir(export_dir): + self.export_data(export_dir) + return self.predict('exported') + # Separated: + return self.predict('separated') + + def predict(self, temp_type): """ Transpile the predict method. @@ -122,7 +133,26 @@ def predict(self): :return : string The transpiled predict method as string. """ - return self.create_class(self.create_method()) + # Exported: + if temp_type == 'exported': + temp = self.temp('exported.class') + return temp.format(class_name=self.class_name, + method_name=self.method_name) + + # Separated + method = self.create_method() + return self.create_class(method) + + def export_data(self, export_dir): + model_data = { + 'priors': self.estimator.class_prior_.tolist(), + 'sigmas': self.estimator.sigma_.tolist(), + 'thetas': self.estimator.theta_.tolist() + } + encoder.FLOAT_REPR = lambda o: self.repr(o) + path = os.path.join(export_dir, 'data.json') + with open(path, 'w') as fp: + json.dump(model_data, fp) def create_method(self): """ @@ -133,7 +163,7 @@ def create_method(self): :return out : string The built method as string. """ - temp_method = self.temp('method.predict', n_indents=1, skipping=True) + temp_method = self.temp('separated.method.predict', n_indents=1, skipping=True) out = temp_method.format(**self.__dict__) return out @@ -147,6 +177,6 @@ def create_class(self, method): The built class as string. """ self.__dict__.update(dict(method=method)) - temp_class = self.temp('class') + temp_class = self.temp('separated.class') out = temp_class.format(**self.__dict__) return out diff --git a/sklearn_porter/estimator/classifier/GaussianNB/templates/java/exported.class.txt b/sklearn_porter/estimator/classifier/GaussianNB/templates/java/exported.class.txt new file mode 100644 index 00000000..2112e680 --- /dev/null +++ b/sklearn_porter/estimator/classifier/GaussianNB/templates/java/exported.class.txt @@ -0,0 +1,67 @@ +import java.io.File; +import java.io.FileNotFoundException; +import java.util.*; +import com.google.gson.Gson; + + +class {class_name} {{ + + private class Classifier {{ + private double[] priors; + private double[][] sigmas; + private double[][] thetas; + }} + + private Classifier clf; + + public {class_name}(String file) throws FileNotFoundException {{ + String jsonStr = new Scanner(new File(file)).useDelimiter("\\Z").next(); + this.clf = new Gson().fromJson(jsonStr, Classifier.class); + }} + + public int {method_name}(double[] features) {{ + double[] likelihoods = new double[this.clf.sigmas.length]; + + for (int i = 0, il = this.clf.sigmas.length; i < il; i++) {{ + double sum = 0.; + for (int j = 0, jl = this.clf.sigmas[0].length; j < jl; j++) {{ + sum += Math.log(2. * Math.PI * this.clf.sigmas[i][j]); + }} + double nij = -0.5 * sum; + sum = 0.; + for (int j = 0, jl = this.clf.sigmas[0].length; j < jl; j++) {{ + sum += Math.pow(features[j] - this.clf.thetas[i][j], 2.) / this.clf.sigmas[i][j]; + }} + nij -= 0.5 * sum; + likelihoods[i] = Math.log(this.clf.priors[i]) + nij; + }} + + int classIdx = 0; + for (int i = 0, l = likelihoods.length; i < l; i++) {{ + classIdx = likelihoods[i] > likelihoods[classIdx] ? i : classIdx; + }} + return classIdx; + }} + + public static void main(String[] args) throws FileNotFoundException {{ + if (args.length > 0 && args[0].endsWith(".json")) {{ + + // Features: + double[] features = new double[args.length-1]; + for (int i = 1, l = args.length; i < l; i++) {{ + features[i - 1] = Double.parseDouble(args[i]); + }} + + // Parameters: + String modelData = args[0]; + + // Estimators: + {class_name} clf = new {class_name}(modelData); + + // Prediction: + int prediction = clf.{method_name}(features); + System.out.println(prediction); + + }} + }} +}} \ No newline at end of file diff --git a/sklearn_porter/estimator/classifier/GaussianNB/templates/java/class.txt b/sklearn_porter/estimator/classifier/GaussianNB/templates/java/separated.class.txt similarity index 100% rename from sklearn_porter/estimator/classifier/GaussianNB/templates/java/class.txt rename to sklearn_porter/estimator/classifier/GaussianNB/templates/java/separated.class.txt diff --git a/sklearn_porter/estimator/classifier/GaussianNB/templates/java/method.predict.txt b/sklearn_porter/estimator/classifier/GaussianNB/templates/java/separated.method.predict.txt similarity index 100% rename from sklearn_porter/estimator/classifier/GaussianNB/templates/java/method.predict.txt rename to sklearn_porter/estimator/classifier/GaussianNB/templates/java/separated.method.predict.txt diff --git a/sklearn_porter/estimator/classifier/GaussianNB/templates/js/class.txt b/sklearn_porter/estimator/classifier/GaussianNB/templates/js/separated.class.txt similarity index 100% rename from sklearn_porter/estimator/classifier/GaussianNB/templates/js/class.txt rename to sklearn_porter/estimator/classifier/GaussianNB/templates/js/separated.class.txt diff --git a/sklearn_porter/estimator/classifier/GaussianNB/templates/js/method.predict.txt b/sklearn_porter/estimator/classifier/GaussianNB/templates/js/separated.method.predict.txt similarity index 100% rename from sklearn_porter/estimator/classifier/GaussianNB/templates/js/method.predict.txt rename to sklearn_porter/estimator/classifier/GaussianNB/templates/js/separated.method.predict.txt diff --git a/sklearn_porter/estimator/classifier/SVC/__init__.py b/sklearn_porter/estimator/classifier/SVC/__init__.py index 8a0bb2da..39db9ed5 100644 --- a/sklearn_porter/estimator/classifier/SVC/__init__.py +++ b/sklearn_porter/estimator/classifier/SVC/__init__.py @@ -4,7 +4,8 @@ import json from json import encoder import types -from ..Classifier import Classifier + +from sklearn_porter.estimator.classifier.Classifier import Classifier class SVC(Classifier): @@ -177,7 +178,6 @@ def export(self, class_name, method_name, self.coef0 = self.repr(self.params['coef0']) self.degree = self.repr(self.params['degree']) - if self.target_method == 'predict': # Exported: if export_data and os.path.isdir(export_dir): diff --git a/tests/estimator/classifier/GaussianNB/GaussianNBJavaTest.py b/tests/estimator/classifier/GaussianNB/GaussianNBJavaTest.py index 2c6becaf..7b5784f4 100644 --- a/tests/estimator/classifier/GaussianNB/GaussianNBJavaTest.py +++ b/tests/estimator/classifier/GaussianNB/GaussianNBJavaTest.py @@ -5,10 +5,11 @@ from sklearn.naive_bayes import GaussianNB from tests.estimator.classifier.Classifier import Classifier +from tests.estimator.classifier.ExportedData import ExportedData from tests.language.Java import Java -class GaussianNBJavaTest(Java, Classifier, TestCase): +class GaussianNBJavaTest(Java, Classifier, ExportedData, TestCase): def setUp(self): super(GaussianNBJavaTest, self).setUp()