Skip to content

Commit

Permalink
apache#13385 [Clojure] - Turn examples into integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hellonico committed Dec 7, 2018
1 parent 2d27160 commit c9894f0
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
;;

(ns cnn-text-classification.classifier
(:require [cnn-text-classification.data-helper :as data-helper]
(:require [clojure.java.io :as io]
[clojure.java.shell :refer [sh]]
[cnn-text-classification.data-helper :as data-helper]
[org.apache.clojure-mxnet.eval-metric :as eval-metric]
[org.apache.clojure-mxnet.io :as mx-io]
[org.apache.clojure-mxnet.module :as m]
Expand All @@ -26,12 +28,18 @@
[org.apache.clojure-mxnet.context :as context])
(:gen-class))

(def data-dir "data/")
(def mr-dataset-path "data/mr-data") ;; the MR polarity dataset path
(def glove-file-path "data/glove/glove.6B.50d.txt")
(def num-filter 100)
(def num-label 2)
(def dropout 0.5)



(when-not (.exists (io/file (str data-dir)))
(do (println "Retrieving data for cnn text classification...") (sh "./get_data.sh")))

(defn shuffle-data [test-num {:keys [data label sentence-count sentence-size embedding-size]}]
(println "Shuffling the data and splitting into training and test sets")
(println {:sentence-count sentence-count
Expand Down Expand Up @@ -103,10 +111,10 @@
;;; omit max-examples if you want to run all the examples in the movie review dataset
;; to limit mem consumption set to something like 1000 and adjust test size to 100
(println "Running with context devices of" devs)
(train-convnet {:devs [(context/cpu)] :embedding-size 50 :batch-size 10 :test-size 100 :num-epoch 10 :max-examples 1000})
(train-convnet {:devs devs :embedding-size 50 :batch-size 10 :test-size 100 :num-epoch 10 :max-examples 1000})
;; runs all the examples
#_(train-convnet {:embedding-size 50 :batch-size 100 :test-size 1000 :num-epoch 10})))

(comment
(train-convnet {:devs [(context/cpu)] :embedding-size 50 :batch-size 10 :test-size 100 :num-epoch 10 :max-examples 1000}))
(train-convnet {:devs devs :embedding-size 50 :batch-size 10 :test-size 100 :num-epoch 10 :max-examples 1000}))

Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns cnn-text-classification.classifier-test
(:require
[clojure.test :refer :all]
[org.apache.clojure-mxnet.module :as module]
[org.apache.clojure-mxnet.ndarray :as ndarray]
[org.apache.clojure-mxnet.util :as util]
[org.apache.clojure-mxnet.context :as context]
[cnn-text-classification.classifier :as classifier]))

;
; The one and unique classifier test
;
(deftest classifier-test
(let [train
(classifier/train-convnet
{:devs [(context/default-context)]
:embedding-size 50
:batch-size 10
:test-size 100
:num-epoch 1
:max-examples 1000})]
(is (= ["data"] (util/scala-vector->vec (module/data-names train))))
(is (= 20 (count (ndarray/->vec (-> train module/outputs first first)))))))
;(prn (util/scala-vector->vec (data-shapes train)))
;(prn (util/scala-vector->vec (label-shapes train)))
;(prn (output-names train))
;(prn (output-shapes train))
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
(def content-weight 5) ;; the weight for the content image
(def blur-radius 1) ;; the blur filter radius
(def output-dir "output")
(def lr 10) ;; the learning rate
(def lr 10.0) ;; the learning rate
(def tv-weight 0.01) ;; the magnitude on the tv loss
(def num-epochs 1000)
(def num-channels 3)
Expand Down Expand Up @@ -157,9 +157,10 @@
out (ndarray/* out tv-weight)]
(sym/bind out ctx {"img" img "kernel" kernel}))))

(defn train [devs]

(let [dev (first devs)
(defn train
([devs] (train devs 20))
([devs n-epochs]
(let [dev (first devs)
content-np (preprocess-content-image content-image max-long-edge)
content-np-shape (mx-shape/->vec (ndarray/shape content-np))
style-np (preprocess-style-image style-image content-np-shape)
Expand Down Expand Up @@ -212,7 +213,7 @@
tv-grad-executor (get-tv-grad-executor img dev tv-weight)
eps 0.0
e 0]
(doseq [i (range 20)]
(doseq [i (range n-epochs)]
(ndarray/set (:data model-executor) img)
(-> (:executor model-executor)
(executor/forward)
Expand All @@ -237,8 +238,10 @@
(println "Epoch " i "relative change " eps)
(when (zero? (mod i 2))
(save-image (ndarray/copy img) (str output-dir "/out_" i ".png") blur-radius true)))

(ndarray/set old-img img))))
(ndarray/set old-img img))
; (save-image (ndarray/copy img) (str output-dir "/final.png") 0 false)
; (postprocess-image img)
)))

(defn -main [& args]
;;; Note this only works on cpu right now
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns neural-style.vgg-19-test
(:require
[clojure.test :refer :all]
[mikera.image.core :as img]
[clojure.java.io :as io]
[org.apache.clojure-mxnet.ndarray :as ndarray]
[org.apache.clojure-mxnet.context :as context]
[neural-style.core :as neural]))

(defn pic-to-ndarray-vec[path]
(-> path
img/load-image
neural/image->ndarray
ndarray/->vec))

(defn last-modified-check[x]
(let [t (- (System/currentTimeMillis) (.lastModified x)) ]
(if (> 10000 t) ; 10 seconds
x
(throw (Exception. (str "Generated File Too Old: (" t " ms) [" x "]"))))))

(defn latest-pic-to-ndarray-vec[folder]
(->> folder
io/as-file
(.listFiles)
(sort-by #(.lastModified %))
last
(last-modified-check)
(.getPath)
pic-to-ndarray-vec))

;
; The one and unique classifier test
;
(deftest vgg-19-test
(neural/train [(context/cpu)] 3)
(is (=
(latest-pic-to-ndarray-vec "output")
(pic-to-ndarray-vec "test/ref_out_2.png"))))
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit c9894f0

Please sign in to comment.