Skip to content

Commit

Permalink
apache#13385 [Clojure] - Turn examples into integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hellonico committed Dec 10, 2018
1 parent f2ca66f commit e6f0888
Show file tree
Hide file tree
Showing 24 changed files with 830 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
;;

(ns cnn-text-classification.classifier
(:require [cnn-text-classification.data-helper :as data-helper]
(:require [clojure.java.io :as io]
[clojure.java.shell :refer [sh]]
[cnn-text-classification.data-helper :as data-helper]
[org.apache.clojure-mxnet.eval-metric :as eval-metric]
[org.apache.clojure-mxnet.io :as mx-io]
[org.apache.clojure-mxnet.module :as m]
Expand All @@ -26,12 +28,18 @@
[org.apache.clojure-mxnet.context :as context])
(:gen-class))

(def data-dir "data/")
(def mr-dataset-path "data/mr-data") ;; the MR polarity dataset path
(def glove-file-path "data/glove/glove.6B.50d.txt")
(def num-filter 100)
(def num-label 2)
(def dropout 0.5)



(when-not (.exists (io/file (str data-dir)))
(do (println "Retrieving data for cnn text classification...") (sh "./get_data.sh")))

(defn shuffle-data [test-num {:keys [data label sentence-count sentence-size embedding-size]}]
(println "Shuffling the data and splitting into training and test sets")
(println {:sentence-count sentence-count
Expand Down Expand Up @@ -103,10 +111,10 @@
;;; omit max-examples if you want to run all the examples in the movie review dataset
;; to limit mem consumption set to something like 1000 and adjust test size to 100
(println "Running with context devices of" devs)
(train-convnet {:devs [(context/cpu)] :embedding-size 50 :batch-size 10 :test-size 100 :num-epoch 10 :max-examples 1000})
(train-convnet {:devs devs :embedding-size 50 :batch-size 10 :test-size 100 :num-epoch 10 :max-examples 1000})
;; runs all the examples
#_(train-convnet {:embedding-size 50 :batch-size 100 :test-size 1000 :num-epoch 10})))

(comment
(train-convnet {:devs [(context/cpu)] :embedding-size 50 :batch-size 10 :test-size 100 :num-epoch 10 :max-examples 1000}))
(train-convnet {:devs devs :embedding-size 50 :batch-size 10 :test-size 100 :num-epoch 10 :max-examples 1000}))

Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns cnn-text-classification.classifier-test
(:require
[clojure.test :refer :all]
[org.apache.clojure-mxnet.module :as module]
[org.apache.clojure-mxnet.ndarray :as ndarray]
[org.apache.clojure-mxnet.util :as util]
[org.apache.clojure-mxnet.context :as context]
[cnn-text-classification.classifier :as classifier]))

;
; The one and unique classifier test
;
(deftest classifier-test
(let [train
(classifier/train-convnet
{:devs [(context/default-context)]
:embedding-size 50
:batch-size 10
:test-size 100
:num-epoch 1
:max-examples 1000})]
(is (= ["data"] (util/scala-vector->vec (module/data-names train))))
(is (= 20 (count (ndarray/->vec (-> train module/outputs first first)))))))
;(prn (util/scala-vector->vec (data-shapes train)))
;(prn (util/scala-vector->vec (label-shapes train)))
;(prn (output-names train))
;(prn (output-shapes train))
6 changes: 4 additions & 2 deletions contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@

(save-img-diff i n calc-diff))))

(defn train [devs]
(defn train
([devs] (train devs num-epoch))
([devs num-epoch]
(let [mod-d (-> (m/module (discriminator) {:contexts devs :data-names ["data"] :label-names ["label"]})
(m/bind {:data-shapes (mx-io/provide-data-desc mnist-iter)
:label-shapes (mx-io/provide-label-desc mnist-iter)
Expand Down Expand Up @@ -203,7 +205,7 @@
(save-img-gout i n (ndarray/copy (ffirst out-g)))
(save-img-data i n batch)
(calc-diff i n (ffirst diff-d)))
(inc n)))))))
(inc n))))))))

(defn -main [& args]
(let [[dev dev-num] args
Expand Down
25 changes: 25 additions & 0 deletions contrib/clojure-package/examples/gan/test/gan/gan_test.clj
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns gan.gan_test
(:require
[gan.gan-mnist :refer :all]
[org.apache.clojure-mxnet.context :as context]
[clojure.test :refer :all]))

(deftest check-pdf
(train [(context/cpu)] 1))
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
(def batch-size 10) ;; the batch size
(def optimizer (optimizer/sgd {:learning-rate 0.01 :momentum 0.0}))
(def eval-metric (eval-metric/accuracy))
(def num-epoch 5) ;; the number of training epochs
(def num-epoch 1) ;; the number of training epochs
(def kvstore "local") ;; the kvstore type
;;; Note to run distributed you might need to complile the engine with an option set
(def role "worker") ;; scheduler/server/worker
Expand Down Expand Up @@ -82,7 +82,9 @@
(sym/fully-connected "fc3" {:data data :num-hidden 10})
(sym/softmax-output "softmax" {:data data})))

(defn start [devs]
(defn start
([devs] (start devs num-epoch))
([devs _num-epoch]
(when scheduler-host
(println "Initing PS enviornments with " envs)
(kvstore-server/init envs))
Expand All @@ -94,14 +96,18 @@
(do
(println "Starting Training of MNIST ....")
(println "Running with context devices of" devs)
(let [mod (m/module (get-symbol) {:contexts devs})]
(m/fit mod {:train-data train-data
(let [_mod (m/module (get-symbol) {:contexts devs})]
(m/fit _mod {:train-data train-data
:eval-data test-data
:num-epoch num-epoch
:num-epoch _num-epoch
:fit-params (m/fit-params {:kvstore kvstore
:optimizer optimizer
:eval-metric eval-metric})}))
(println "Finish fit"))))
:eval-metric eval-metric})})
(println "Finish fit")
_mod
)

))))

(defn -main [& args]
(let [[dev dev-num] args
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns imclassification.train-mnist-test
(:require
[clojure.test :refer :all]
[clojure.java.io :as io]
[org.apache.clojure-mxnet.context :as context]
[org.apache.clojure-mxnet.module :as module]
[imclassification.train-mnist :as mnist]))

(defn- file-to-filtered-seq [file]
(->>
file
(io/file)
(io/reader)
(line-seq)
(filter #(not (clojure.string/includes? #"mxnet_version" %)))))

(deftest mnist-two-epochs-test
(module/save-checkpoint (mnist/start [(context/cpu)] 2) {:prefix "target/test" :epoch 2})
(is (=
(file-to-filtered-seq "test/test-symbol.json.ref")
(file-to-filtered-seq "target/test-symbol.json"))))
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
{
"nodes": [
{
"op": "null",
"name": "data",
"inputs": []
},
{
"op": "null",
"name": "fc1_weight",
"attrs": {"num_hidden": "128"},
"inputs": []
},
{
"op": "null",
"name": "fc1_bias",
"attrs": {"num_hidden": "128"},
"inputs": []
},
{
"op": "FullyConnected",
"name": "fc1",
"attrs": {"num_hidden": "128"},
"inputs": [[0, 0, 0], [1, 0, 0], [2, 0, 0]]
},
{
"op": "Activation",
"name": "relu1",
"attrs": {"act_type": "relu"},
"inputs": [[3, 0, 0]]
},
{
"op": "null",
"name": "fc2_weight",
"attrs": {"num_hidden": "64"},
"inputs": []
},
{
"op": "null",
"name": "fc2_bias",
"attrs": {"num_hidden": "64"},
"inputs": []
},
{
"op": "FullyConnected",
"name": "fc2",
"attrs": {"num_hidden": "64"},
"inputs": [[4, 0, 0], [5, 0, 0], [6, 0, 0]]
},
{
"op": "Activation",
"name": "relu2",
"attrs": {"act_type": "relu"},
"inputs": [[7, 0, 0]]
},
{
"op": "null",
"name": "fc3_weight",
"attrs": {"num_hidden": "10"},
"inputs": []
},
{
"op": "null",
"name": "fc3_bias",
"attrs": {"num_hidden": "10"},
"inputs": []
},
{
"op": "FullyConnected",
"name": "fc3",
"attrs": {"num_hidden": "10"},
"inputs": [[8, 0, 0], [9, 0, 0], [10, 0, 0]]
},
{
"op": "null",
"name": "softmax_label",
"inputs": []
},
{
"op": "SoftmaxOutput",
"name": "softmax",
"inputs": [[11, 0, 0], [12, 0, 0]]
}
],
"arg_nodes": [0, 1, 2, 5, 6, 9, 10, 12],
"node_row_ptr": [
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14
],
"heads": [[13, 0, 0]],
"attrs": {"mxnet_version": ["int", 10400]}
}
29 changes: 29 additions & 0 deletions contrib/clojure-package/examples/module/test/mnist_mlp_test.clj
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;
(ns mnist-mlp-test
(:require
[mnist-mlp :refer :all]
[org.apache.clojure-mxnet.context :as context]
[clojure.test :refer :all]))

(deftest run-those-tests
(let [devs [(context/cpu)]]
(run-intermediate-level-api :devs devs)
(run-intermediate-level-api :devs devs :load-model-epoch (dec num-epoch))
(run-high-level-api devs)
(run-prediction-iterator-api devs)
(run-predication-and-calc-accuracy-manually devs)))
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns multi_label_test
(:require
[multi-label.core :as label]
[clojure.java.io :as io]
[org.apache.clojure-mxnet.context :as context]
[clojure.test :refer :all]))

(deftest run-multi-label
(label/train [(context/cpu)]))
Loading

0 comments on commit e6f0888

Please sign in to comment.