-
Notifications
You must be signed in to change notification settings - Fork 170
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add templates, tests and logic for the exported model data approach
- Loading branch information
Darius Morawiec
committed
Nov 12, 2017
1 parent
fc1e472
commit 79d846f
Showing
21 changed files
with
266 additions
and
164 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,4 +31,6 @@ tmp | |
tmp.* | ||
Tmp.* | ||
|
||
.ipynb_checkpoints | ||
.ipynb_checkpoints | ||
|
||
*.jar |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
1 change: 1 addition & 0 deletions
1
examples/estimator/classifier/AdaBoostClassifier/java/data.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
[{"indices": [3, -2, 3, 2, 3, -2, -2, 3, -2, -2, 2, 0, -2, -2, -2], "thresholds": [0.800000011921, -2.0, 1.75, 4.94999980927, 1.65000009537, -2.0, -2.0, 1.54999995232, -2.0, -2.0, 4.85000038147, 5.94999980927, -2.0, -2.0, -2.0], "classes": [[0.333333333333, 0.333333333333, 0.333333333333], [0.333333333333, 0.0, 0.0], [0.0, 0.333333333333, 0.333333333333], [0.0, 0.326666666667, 0.0333333333333], [0.0, 0.313333333333, 0.00666666666667], [0.0, 0.313333333333, 0.0], [0.0, 0.0, 0.00666666666667], [0.0, 0.0133333333333, 0.0266666666667], [0.0, 0.0, 0.02], [0.0, 0.0133333333333, 0.00666666666667], [0.0, 0.00666666666667, 0.3], [0.0, 0.00666666666667, 0.0133333333333], [0.0, 0.00666666666667, 0.0], [0.0, 0.0, 0.0133333333333], [0.0, 0.0, 0.286666666667]], "childrenRight": [2, -1, 10, 7, 6, -1, -1, 9, -1, -1, 14, 13, -1, -1, -1], "childrenLeft": [1, -1, 3, 4, 5, -1, -1, 8, -1, -1, 11, 12, -1, -1, -1]}, {"indices": [2, 2, -2, 3, 0, -2, -2, 1, -2, -2, -2], "thresholds": [5.14999961853, 2.45000004768, -2.0, 1.75, 4.94999980927, -2.0, -2.0, 3.15000009537, -2.0, -2.0, -2.0], "classes": [[8.3290724464e-05, 0.499957521731, 0.499959187545], [8.3290724464e-05, 0.499957521731, 2.66530318285e-05], [8.3290724464e-05, 0.0, 0.0], [0.0, 0.499957521731, 2.66530318285e-05], [0.0, 0.499955855916, 4.99744346784e-06], [0.0, 1.66581448928e-06, 1.66581448928e-06], [0.0, 0.499954190102, 3.33162897856e-06], [0.0, 1.66581448928e-06, 2.16555883606e-05], [0.0, 0.0, 1.99897738714e-05], [0.0, 1.66581448928e-06, 1.66581448928e-06], [0.0, 0.0, 0.499932534513]], "childrenRight": [10, 3, -1, 7, 6, -1, -1, 9, -1, -1, -1], "childrenLeft": [1, 2, -1, 4, 5, -1, -1, 8, -1, -1, -1]}, {"indices": [3, 2, 3, -2, -2, -2, 2, 3, 0, -2, -2, -2, -2], "thresholds": [1.54999995232, 4.94999980927, 0.800000011921, -2.0, -2.0, -2.0, 5.14999961853, 1.84999990463, 5.40000009537, -2.0, -2.0, -2.0, -2.0], "classes": [[2.67881771865e-08, 0.499919588597, 0.500080384615], [2.67881771865e-08, 0.000184731094993, 0.499696643102], [2.67881771865e-08, 0.000184731094993, 0.0], [2.67881771865e-08, 0.0, 0.0], [0.0, 0.000184731094993, 0.0], [0.0, 0.0, 0.499696643102], [0.0, 0.499734857502, 0.000383741512437], [0.0, 0.499734857502, 0.00022295245966], [0.0, 0.499734857502, 0.000111475694067], [0.0, 0.0, 0.000111473015249], [0.0, 0.499734857502, 2.67881771865e-09], [0.0, 0.0, 0.000111476765594], [0.0, 0.0, 0.000160789052777]], "childrenRight": [6, 5, 4, -1, -1, -1, 12, 11, 10, -1, -1, -1, -1], "childrenLeft": [1, 2, 3, -1, -1, -1, 7, 8, 9, -1, -1, -1, -1]}, {"indices": [3, 3, 2, 3, -2, -2, -2, 0, 1, -2, -2, -2, 1, -2, 2, -2, -2], "thresholds": [1.75, 1.54999995232, 4.94999980927, 0.800000011921, -2.0, -2.0, -2.0, 6.94999980927, 2.59999990463, -2.0, -2.0, -2.0, 3.15000009537, -2.0, 4.94999980927, -2.0, -2.0], "classes": [[9.25765397376e-11, 0.499136211999, 0.500863787909], [9.25765397376e-11, 0.499024872662, 0.00172782900859], [9.25765397376e-11, 6.38407213652e-07, 0.00172688816469], [9.25765397376e-11, 6.38407213652e-07, 0.0], [9.25765397376e-11, 0.0, 0.0], [0.0, 6.38407213652e-07, 0.0], [0.0, 0.0, 0.00172688816469], [0.0, 0.499024234255, 9.40843895869e-07], [0.0, 0.499024234255, 3.85236589785e-07], [0.0, 0.0, 3.85236589785e-07], [0.0, 0.499024234255, 0.0], [0.0, 0.0, 5.55607306084e-07], [0.0, 0.000111339336392, 0.4991359589], [0.0, 0.0, 0.499135573641], [0.0, 0.000111339336392, 3.85258808154e-07], [0.0, 0.000111339336392, 0.0], [0.0, 0.0, 3.85258808154e-07]], "childrenRight": [12, 7, 6, 5, -1, -1, -1, 11, 10, -1, -1, -1, 14, -1, 16, -1, -1], "childrenLeft": [1, 2, 3, 4, -1, -1, -1, 8, 9, -1, -1, -1, 13, -1, 15, -1, -1]}] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
116 changes: 116 additions & 0 deletions
116
sklearn_porter/estimator/classifier/AdaBoostClassifier/templates/java/exported.class.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import java.io.File; | ||
import java.io.FileNotFoundException; | ||
import java.io.IOException; | ||
import java.lang.reflect.Type; | ||
import java.util.List; | ||
import java.util.Scanner; | ||
import com.google.gson.Gson; | ||
import com.google.gson.reflect.TypeToken; | ||
|
||
|
||
class {class_name} {{ | ||
|
||
private class Tree {{ | ||
private int[] childrenLeft; | ||
private int[] childrenRight; | ||
private double[] thresholds; | ||
private int[] indices; | ||
private double[][] classes; | ||
|
||
private double[] predict (double[] features, int node) {{ | ||
if (this.thresholds[node] != -2) {{ | ||
if (features[this.indices[node]] <= this.thresholds[node]) {{ | ||
return this.predict(features, this.childrenLeft[node]); | ||
}} else {{ | ||
return this.predict(features, this.childrenRight[node]); | ||
}} | ||
}} | ||
return this.classes[node]; | ||
}} | ||
private double[] predict (double[] features) {{ | ||
return this.predict(features, 0); | ||
}} | ||
}} | ||
|
||
private List<Tree> forest; | ||
private int nClasses; | ||
private int nEstimators; | ||
|
||
public {class_name} (String file) throws FileNotFoundException {{ | ||
String jsonStr = new Scanner(new File(file)).useDelimiter("\\Z").next(); | ||
Gson gson = new Gson(); | ||
Type listType = new TypeToken<List<Tree>>(){{}}.getType(); | ||
this.forest = gson.fromJson(jsonStr, listType); | ||
this.nEstimators = this.forest.size(); | ||
this.nClasses = this.forest.get(0).classes[0].length; | ||
}} | ||
|
||
private int findMax(double[] nums) {{ | ||
int index = 0; | ||
for (int i = 0; i < nums.length; i++) {{ | ||
index = nums[i] > nums[index] ? i : index; | ||
}} | ||
return index; | ||
}} | ||
|
||
public int {method_name}(double[] features) {{ | ||
double[][] preds = new double[this.nEstimators][this.nClasses]; | ||
double normalizer, sum; | ||
int i, j; | ||
|
||
for (i = 0; i < this.nEstimators; i++) {{ | ||
preds[i] = this.forest.get(i).predict(features, 0); | ||
}} | ||
for (i = 0; i < this.nEstimators; i++) {{ | ||
normalizer = 0.; | ||
for (j = 0; j < this.nClasses; j++) {{ | ||
normalizer += preds[i][j]; | ||
}} | ||
if (normalizer == 0.) {{ | ||
normalizer = 1.; | ||
}} | ||
for (j = 0; j < this.nClasses; j++) {{ | ||
preds[i][j] = preds[i][j] / normalizer; | ||
if (preds[i][j] <= 2.2204460492503131e-16) {{ | ||
preds[i][j] = 2.2204460492503131e-16; | ||
}} | ||
preds[i][j] = Math.log(preds[i][j]); | ||
}} | ||
sum = 0.; | ||
for (j = 0; j < this.nClasses; j++) {{ | ||
sum += preds[i][j]; | ||
}} | ||
for (j = 0; j < this.nClasses; j++) {{ | ||
preds[i][j] = (this.nClasses - 1) * (preds[i][j] - (1. / this.nClasses) * sum); | ||
}} | ||
}} | ||
double[] classes = new double[this.nClasses]; | ||
for (i = 0; i < this.nEstimators; i++) {{ | ||
for (j = 0; j < this.nClasses; j++) {{ | ||
classes[j] += preds[i][j]; | ||
}} | ||
}} | ||
|
||
return this.findMax(classes); | ||
}} | ||
|
||
public static void main(String[] args) throws IOException {{ | ||
if (args[0].endsWith(".json")) {{ | ||
|
||
// Features: | ||
double[] features = new double[args.length-1]; | ||
for (int i = 1, l = args.length; i < l; i++) {{ | ||
features[i-1] = Double.parseDouble(args[i]); | ||
}} | ||
|
||
// Estimator with parameters: | ||
String modelData = args[0]; | ||
{class_name} clf = new {class_name}(modelData); | ||
|
||
// Prediction: | ||
int prediction = clf.{method_name}(features); | ||
System.out.println(prediction); | ||
|
||
}} | ||
}} | ||
}} |
87 changes: 0 additions & 87 deletions
87
sklearn_porter/estimator/classifier/AdaBoostClassifier/templates/js/class.txt
This file was deleted.
Oops, something went wrong.
1 change: 1 addition & 0 deletions
1
sklearn_porter/estimator/classifier/AdaBoostClassifier/templates/js/exported.class.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
- |
Oops, something went wrong.