Skip to content

Commit

Permalink
Add templates, tests and logic for the exported model data approach
Browse files Browse the repository at this point in the history
  • Loading branch information
Darius Morawiec committed Nov 12, 2017
1 parent fc1e472 commit 79d846f
Show file tree
Hide file tree
Showing 21 changed files with 266 additions and 164 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,6 @@ tmp
tmp.*
Tmp.*

.ipynb_checkpoints
.ipynb_checkpoints

*.jar
3 changes: 3 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ install:
- conda env create -q -n sklearn-porter python=$TRAVIS_PYTHON_VERSION -f environment.yml
- pip install -U pip
- source activate sklearn-porter
- wget http://central.maven.org/maven2/com/google/code/gson/gson/2.8.2/gson-2.8.2.jar
- mv gson-2.8.2.jar gson.jar
- SKLEARN_PORTER_HOME=$(pwd)
before_script:
- python --version
- gcc --version
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
clf.fit(X, y)

porter = Porter(clf)
output = porter.export(embedded=True)
output = porter.export(export_data=True)
print(output)

"""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"indices": [3, -2, 3, 2, 3, -2, -2, 3, -2, -2, 2, 0, -2, -2, -2], "thresholds": [0.800000011921, -2.0, 1.75, 4.94999980927, 1.65000009537, -2.0, -2.0, 1.54999995232, -2.0, -2.0, 4.85000038147, 5.94999980927, -2.0, -2.0, -2.0], "classes": [[0.333333333333, 0.333333333333, 0.333333333333], [0.333333333333, 0.0, 0.0], [0.0, 0.333333333333, 0.333333333333], [0.0, 0.326666666667, 0.0333333333333], [0.0, 0.313333333333, 0.00666666666667], [0.0, 0.313333333333, 0.0], [0.0, 0.0, 0.00666666666667], [0.0, 0.0133333333333, 0.0266666666667], [0.0, 0.0, 0.02], [0.0, 0.0133333333333, 0.00666666666667], [0.0, 0.00666666666667, 0.3], [0.0, 0.00666666666667, 0.0133333333333], [0.0, 0.00666666666667, 0.0], [0.0, 0.0, 0.0133333333333], [0.0, 0.0, 0.286666666667]], "childrenRight": [2, -1, 10, 7, 6, -1, -1, 9, -1, -1, 14, 13, -1, -1, -1], "childrenLeft": [1, -1, 3, 4, 5, -1, -1, 8, -1, -1, 11, 12, -1, -1, -1]}, {"indices": [2, 2, -2, 3, 0, -2, -2, 1, -2, -2, -2], "thresholds": [5.14999961853, 2.45000004768, -2.0, 1.75, 4.94999980927, -2.0, -2.0, 3.15000009537, -2.0, -2.0, -2.0], "classes": [[8.3290724464e-05, 0.499957521731, 0.499959187545], [8.3290724464e-05, 0.499957521731, 2.66530318285e-05], [8.3290724464e-05, 0.0, 0.0], [0.0, 0.499957521731, 2.66530318285e-05], [0.0, 0.499955855916, 4.99744346784e-06], [0.0, 1.66581448928e-06, 1.66581448928e-06], [0.0, 0.499954190102, 3.33162897856e-06], [0.0, 1.66581448928e-06, 2.16555883606e-05], [0.0, 0.0, 1.99897738714e-05], [0.0, 1.66581448928e-06, 1.66581448928e-06], [0.0, 0.0, 0.499932534513]], "childrenRight": [10, 3, -1, 7, 6, -1, -1, 9, -1, -1, -1], "childrenLeft": [1, 2, -1, 4, 5, -1, -1, 8, -1, -1, -1]}, {"indices": [3, 2, 3, -2, -2, -2, 2, 3, 0, -2, -2, -2, -2], "thresholds": [1.54999995232, 4.94999980927, 0.800000011921, -2.0, -2.0, -2.0, 5.14999961853, 1.84999990463, 5.40000009537, -2.0, -2.0, -2.0, -2.0], "classes": [[2.67881771865e-08, 0.499919588597, 0.500080384615], [2.67881771865e-08, 0.000184731094993, 0.499696643102], 
[2.67881771865e-08, 0.000184731094993, 0.0], [2.67881771865e-08, 0.0, 0.0], [0.0, 0.000184731094993, 0.0], [0.0, 0.0, 0.499696643102], [0.0, 0.499734857502, 0.000383741512437], [0.0, 0.499734857502, 0.00022295245966], [0.0, 0.499734857502, 0.000111475694067], [0.0, 0.0, 0.000111473015249], [0.0, 0.499734857502, 2.67881771865e-09], [0.0, 0.0, 0.000111476765594], [0.0, 0.0, 0.000160789052777]], "childrenRight": [6, 5, 4, -1, -1, -1, 12, 11, 10, -1, -1, -1, -1], "childrenLeft": [1, 2, 3, -1, -1, -1, 7, 8, 9, -1, -1, -1, -1]}, {"indices": [3, 3, 2, 3, -2, -2, -2, 0, 1, -2, -2, -2, 1, -2, 2, -2, -2], "thresholds": [1.75, 1.54999995232, 4.94999980927, 0.800000011921, -2.0, -2.0, -2.0, 6.94999980927, 2.59999990463, -2.0, -2.0, -2.0, 3.15000009537, -2.0, 4.94999980927, -2.0, -2.0], "classes": [[9.25765397376e-11, 0.499136211999, 0.500863787909], [9.25765397376e-11, 0.499024872662, 0.00172782900859], [9.25765397376e-11, 6.38407213652e-07, 0.00172688816469], [9.25765397376e-11, 6.38407213652e-07, 0.0], [9.25765397376e-11, 0.0, 0.0], [0.0, 6.38407213652e-07, 0.0], [0.0, 0.0, 0.00172688816469], [0.0, 0.499024234255, 9.40843895869e-07], [0.0, 0.499024234255, 3.85236589785e-07], [0.0, 0.0, 3.85236589785e-07], [0.0, 0.499024234255, 0.0], [0.0, 0.0, 5.55607306084e-07], [0.0, 0.000111339336392, 0.4991359589], [0.0, 0.0, 0.499135573641], [0.0, 0.000111339336392, 3.85258808154e-07], [0.0, 0.000111339336392, 0.0], [0.0, 0.0, 3.85258808154e-07]], "childrenRight": [12, 7, 6, 5, -1, -1, -1, 11, 10, -1, -1, -1, 14, -1, 16, -1, -1], "childrenLeft": [1, 2, 3, 4, -1, -1, -1, 8, 9, -1, -1, -1, 13, -1, 15, -1, -1]}]
8 changes: 4 additions & 4 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Transpile trained [scikit-learn](https://github.com/scikit-learn/scikit-learn) e
<td align="left" width="40%">Classification</td>
<td align="center" width="10%">C</td>
<td align="center" width="10%">Java *</td>
<td align="center" width="10%">JavaScript</td>
<td align="center" width="10%">JS</td>
<td align="center" width="10%">Go</td>
<td align="center" width="10%">PHP</td>
<td align="center" width="10%">Ruby</td>
Expand Down Expand Up @@ -84,7 +84,7 @@ Transpile trained [scikit-learn](https://github.com/scikit-learn/scikit-learn) e
<tr>
<td><a href="http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.AdaBoostClassifier.html">sklearn.ensemble.AdaBoostClassifier</a></td>
<td align="center"><a href="examples/estimator/classifier/AdaBoostClassifier/c/basics_embedded.ipynb">✓ ᴱ</a></td>
<td align="center"><a href="examples/estimator/classifier/AdaBoostClassifier/java/basics_embedded.ipynb">✓ ᴱ</a></td>
<td align="center"><a href="examples/estimator/classifier/AdaBoostClassifier/java/basics_imported.ipynb">✓ ᴵ</a>, <a href="examples/estimator/classifier/AdaBoostClassifier/java/basics_embedded.ipynb">✓ ᴱ</a></td>
<td align="center"><a href="examples/estimator/classifier/AdaBoostClassifier/js/basics_embedded.ipynb">✓ ᴱ</a></td>
<td align="center"></td>
<td align="center"></td>
Expand Down Expand Up @@ -142,7 +142,7 @@ Transpile trained [scikit-learn](https://github.com/scikit-learn/scikit-learn) e
</tbody>
</table>

✓ = is full-featured, ᴱ = embedded model data, * = default language
✓ = is full-featured, ᴱ = with embedded model data, ᴵ = with imported model data, * = default language

## Installation

Expand Down Expand Up @@ -301,7 +301,7 @@ source activate sklearn-porter
The following compilers or interpreters are required to cover all tests:

- [GCC](https://gcc.gnu.org) (`>=4.2`)
- [Java](https://java.com) (`>=1.7`)
- [Java](https://java.com) (`>=1.6`)
- [PHP](http://www.php.net/) (`>=7`)
- [Ruby](https://www.ruby-lang.org) (`>=2.4.1`)
- [Go](https://golang.org/) (`>=1.7.4`)
Expand Down
1 change: 1 addition & 0 deletions sklearn_porter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def main():
method_name = str(args.get('method_name'))
output = porter.export(class_name=class_name,
method_name=method_name,
output=str(args.get('output')),
details=True)
except Exception as e:
sys.exit('Error: {}'.format(str(e)))
Expand Down
75 changes: 32 additions & 43 deletions sklearn_porter/estimator/classifier/AdaBoostClassifier/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# -*- coding: utf-8 -*-

import os
import json
from json import encoder
import sklearn
from sklearn_porter.estimator.classifier.Classifier import Classifier

Expand Down Expand Up @@ -81,7 +84,8 @@ def __init__(self, estimator, target_language='java',

self.estimator = estimator

def export(self, class_name, method_name, embedded=False):
def export(self, class_name, method_name,
export_data=False, export_dir='.'):
"""
Port a trained estimator to the syntax of a chosen programming language.
Expand All @@ -97,52 +101,47 @@ def export(self, class_name, method_name, embedded=False):
:return : string
The transpiled algorithm with the defined placeholders.
"""

if self.target_language in ['c']:
embedded = True

# TODO: Force the embedded mode, remove after the updates.
embedded = True

# Arguments:
self.class_name = class_name
self.method_name = method_name

# Estimator:
est = self.estimator

self.n_classes = est.n_classes_
# Basic parameters:
self.estimators = []
self.weights = []
self.n_estimators = 0
for idx in range(est.n_estimators):
weight = est.estimator_weights_[idx]
if weight > 0:
if est.estimator_weights_[idx] > 0:
self.estimators.append(est.estimators_[idx])
self.weights.append(est.estimator_weights_[idx])
self.n_estimators += 1
self.n_features = est.estimators_[idx].n_features_
self.n_classes = est.n_classes_
self.n_features = est.estimators_[0].n_features_
# Count only the estimators collected above (those with a positive
# weight); `self.estimator` is the whole scikit-learn ensemble.
self.n_estimators = len(self.estimators)

if self.target_method == 'predict':
return self.predict(embedded)

def predict(self, embedded):
"""
Transpile the predict method.

Returns
-------
:return : string
The transpiled predict method as string.
"""
if embedded:
# Exported data:
if export_data and os.path.isdir(export_dir):
model_data = []
for est in self.estimators:
model_data.append({
'childrenLeft': est.tree_.children_left.tolist(),
'childrenRight': est.tree_.children_right.tolist(),
'thresholds': est.tree_.threshold.tolist(),
'classes': [e[0] for e in est.tree_.value.tolist()],
'indices': est.tree_.feature.tolist()
})
encoder.FLOAT_REPR = lambda o: self.repr(o)
path = os.path.join(export_dir, 'data.json')
with open(path, 'w') as fp:
json.dump(model_data, fp)

temp_class = self.temp('exported.class')
return temp_class.format(class_name=self.class_name,
method_name=self.method_name)

# Deep embedded data:
method = self.create_method_embedded()
out = self.create_class_embedded(method)
return out

# return self.create_class(self.create_method())
out = self.create_class()
return out
return self.create_class_embedded(method)

def create_branches(self, left_nodes, right_nodes, threshold,
value, features, node, depth, init=False):
Expand Down Expand Up @@ -291,13 +290,3 @@ def create_class_embedded(self, method):
method_name=self.method_name, method=method,
n_features=self.n_features)
return out

def create_class(self):

# {left_childs}
# {right_childs}
# {thresholds}
# {indices}
# {classes}

return '-'
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.lang.reflect.Type;
import java.util.List;
import java.util.Scanner;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;


/*
 * Template of an AdaBoost classifier which loads its decision trees from an
 * exported JSON file. The doubled braces are escapes of the template engine;
 * the class and method names are filled in when the template is rendered.
 */
class {class_name} {{

    /*
     * One decision tree stored as parallel arrays indexed by the node id.
     * The field names have to match the keys of the exported JSON data.
     * The class is static, because Gson can not reliably instantiate
     * inner classes which need a reference to an enclosing instance.
     */
    private static class Tree {{
        private int[] childrenLeft;
        private int[] childrenRight;
        private double[] thresholds;
        private int[] indices;
        private double[][] classes;

        /*
         * Walks down the tree, starting at the passed node, and returns
         * the class distribution of the reached leaf. A threshold of -2
         * marks a leaf. The returned array is the internal one of the
         * tree, so callers have to copy it before they modify it.
         */
        private double[] predict (double[] features, int node) {{
            if (this.thresholds[node] != -2) {{
                if (features[this.indices[node]] <= this.thresholds[node]) {{
                    return this.predict(features, this.childrenLeft[node]);
                }} else {{
                    return this.predict(features, this.childrenRight[node]);
                }}
            }}
            return this.classes[node];
        }}

        /* Starts the traversal at the root node. */
        private double[] predict (double[] features) {{
            return this.predict(features, 0);
        }}
    }}

    private List<Tree> forest;
    private int nClasses;
    private int nEstimators;

    /*
     * Reads the whole JSON file and deserializes it into a list of trees.
     *
     * @param file path to the exported model data
     * @throws FileNotFoundException if the file can not be opened
     */
    public {class_name} (String file) throws FileNotFoundException {{
        String jsonStr;
        Scanner scanner = new Scanner(new File(file));
        try {{
            // The delimiter matches the end of the input, so the first
            // token is the complete content of the file:
            jsonStr = scanner.useDelimiter("\\Z").next();
        }} finally {{
            // Release the file handle, even if reading fails:
            scanner.close();
        }}
        Gson gson = new Gson();
        Type listType = new TypeToken<List<Tree>>(){{}}.getType();
        this.forest = gson.fromJson(jsonStr, listType);
        this.nEstimators = this.forest.size();
        this.nClasses = this.forest.get(0).classes[0].length;
    }}

    /* Returns the index of the first maximum of the passed values. */
    private int findMax(double[] nums) {{
        int index = 0;
        for (int i = 0; i < nums.length; i++) {{
            index = nums[i] > nums[index] ? i : index;
        }}
        return index;
    }}

    /*
     * Predicts the class index of the passed feature vector.
     *
     * @param features the feature values of a single sample
     * @return the index of the class with the highest accumulated score
     */
    public int {method_name}(double[] features) {{
        double[][] preds = new double[this.nEstimators][];
        double normalizer, sum;
        int i, j;

        for (i = 0; i < this.nEstimators; i++) {{
            // Copy the leaf values, because they are transformed in place
            // below and the model data has to stay untouched for the
            // next prediction:
            preds[i] = this.forest.get(i).predict(features, 0).clone();
        }}
        for (i = 0; i < this.nEstimators; i++) {{
            // Normalize the leaf values to probabilities:
            normalizer = 0.;
            for (j = 0; j < this.nClasses; j++) {{
                normalizer += preds[i][j];
            }}
            if (normalizer == 0.) {{
                normalizer = 1.;
            }}
            for (j = 0; j < this.nClasses; j++) {{
                preds[i][j] = preds[i][j] / normalizer;
                // Clip at the machine epsilon to avoid the logarithm of zero:
                if (preds[i][j] <= 2.2204460492503131e-16) {{
                    preds[i][j] = 2.2204460492503131e-16;
                }}
                preds[i][j] = Math.log(preds[i][j]);
            }}
            // Center the log probabilities and scale them:
            sum = 0.;
            for (j = 0; j < this.nClasses; j++) {{
                sum += preds[i][j];
            }}
            for (j = 0; j < this.nClasses; j++) {{
                preds[i][j] = (this.nClasses - 1) * (preds[i][j] - (1. / this.nClasses) * sum);
            }}
        }}
        // Accumulate the scores of all estimators per class:
        double[] classes = new double[this.nClasses];
        for (i = 0; i < this.nEstimators; i++) {{
            for (j = 0; j < this.nClasses; j++) {{
                classes[j] += preds[i][j];
            }}
        }}

        return this.findMax(classes);
    }}

    /*
     * Command line entry point. The first argument is the path to the
     * JSON model data, all following arguments are the feature values.
     * Nothing happens if the first argument is missing or does not end
     * with the JSON file extension.
     */
    public static void main(String[] args) throws IOException {{
        if (args.length > 0 && args[0].endsWith(".json")) {{

            // Features:
            double[] features = new double[args.length-1];
            for (int i = 1, l = args.length; i < l; i++) {{
                features[i-1] = Double.parseDouble(args[i]);
            }}

            // Estimator with parameters:
            String modelData = args[0];
            {class_name} clf = new {class_name}(modelData);

            // Prediction:
            int prediction = clf.{method_name}(features);
            System.out.println(prediction);

        }}
    }}
}}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
-
Loading

0 comments on commit 79d846f

Please sign in to comment.