Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-1202] Change Builder class into a better way (#13159)
Browse files Browse the repository at this point in the history
* applying changes for Builder functions

* simplify the code structure

* update docgen

* follow Naveen's suggestion

* apply comments to Param

* clean up param build

* change on the comments

* add one description line
  • Loading branch information
lanking520 authored and nswamy committed Nov 12, 2018
1 parent 149ea17 commit 3664a7c
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -385,17 +385,3 @@ class NDArray(val nd : org.apache.mxnet.NDArray ) {
override def equals(obj: Any): Boolean = nd.equals(obj)
override def hashCode(): Int = nd.hashCode
}

object NDArrayFuncReturn {
implicit def toNDFuncReturn(javaFunReturn : NDArrayFuncReturn)
: org.apache.mxnet.NDArrayFuncReturn = javaFunReturn.ndFuncReturn
implicit def toJavaNDFuncReturn(ndFuncReturn : org.apache.mxnet.NDArrayFuncReturn)
: NDArrayFuncReturn = new NDArrayFuncReturn(ndFuncReturn)
}

private[mxnet] class NDArrayFuncReturn(val ndFuncReturn : org.apache.mxnet.NDArrayFuncReturn) {
def head : NDArray = ndFuncReturn.head
def get : NDArray = ndFuncReturn.get
def apply(i : Int) : NDArray = ndFuncReturn.apply(i)
// TODO: Add JavaNDArray operational stuff
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

import org.junit.Test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.mxnet.javaapi.NDArrayBase.*;

import static org.junit.Assert.assertTrue;

Expand Down Expand Up @@ -71,15 +71,15 @@ public void testGenerated(){
NDArray$ NDArray = NDArray$.MODULE$;
float[] arr = new float[]{1.0f, 2.0f, 3.0f};
NDArray nd = new NDArray(arr, new Shape(new int[]{3}), new Context("cpu", 0));
float result = NDArray.norm(nd).invoke().get().toArray()[0];
float result = NDArray.norm(NDArray.new normParam(nd))[0].toArray()[0];
float cal = 0.0f;
for (float ele : arr) {
cal += ele * ele;
}
cal = (float) Math.sqrt(cal);
assertTrue(Math.abs(result - cal) < 1e-5);
NDArray dotResult = new NDArray(new float[]{0}, new Shape(new int[]{1}), new Context("cpu", 0));
NDArray.dot(nd, nd).setout(dotResult).invoke().get();
NDArray.dot(NDArray.new dotParam(nd, nd).setOut(dotResult));
assertTrue(Arrays.equals(dotResult.toArray(), new float[]{14.0f}));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,7 @@ private[mxnet] object APIDocGenerator{
val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_"))
.filterNot(ele => notGenerated.contains(ele.name))
.map(absClassFunction => {
val scalaDoc = generateAPIDocFromBackend(absClassFunction)
val defBody = generateJavaAPISignature(absClassFunction)
s"$scalaDoc\n$defBody"
generateJavaAPISignature(absClassFunction)
})
val packageName = "NDArrayBase"
val packageDef = "package org.apache.mxnet.javaapi"
Expand Down Expand Up @@ -203,27 +201,61 @@ private[mxnet] object APIDocGenerator{
}

def generateJavaAPISignature(func : absClassFunction) : String = {
val useParamObject = func.listOfArgs.count(arg => arg.isOptional) >= 2
var argDef = ListBuffer[String]()
var classDef = ListBuffer[String]()
var requiredParam = ListBuffer[String]()
func.listOfArgs.foreach(absClassArg => {
val currArgName = safetyNameCheck(absClassArg.argName)
// scalastyle:off
if (absClassArg.isOptional) {
classDef += s"def set${absClassArg.argName}(${absClassArg.argName} : ${absClassArg.argType}) : ${func.name}BuilderBase"
if (absClassArg.isOptional && useParamObject) {
classDef +=
s"""private var $currArgName: ${absClassArg.argType} = null
|/**
| * @param $currArgName\t\t${absClassArg.argDesc}
| */
|def set${currArgName.capitalize}($currArgName : ${absClassArg.argType}): ${func.name}Param = {
| this.$currArgName = $currArgName
| this
| }""".stripMargin
}
else {
requiredParam += s" * @param $currArgName\t\t${absClassArg.argDesc}"
argDef += s"$currArgName : ${absClassArg.argType}"
}
classDef += s"def get${currArgName.capitalize}() = this.$currArgName"
// scalastyle:on
})
classDef += s"def setout(out : NDArray) : ${func.name}BuilderBase"
classDef += s"def invoke() : org.apache.mxnet.javaapi.NDArrayFuncReturn"
val experimentalTag = "@Experimental"
// scalastyle:off
var finalStr = s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : ${func.name}BuilderBase\n"
// scalastyle:on
finalStr += s"abstract class ${func.name}BuilderBase {\n ${classDef.mkString("\n ")}\n}"
finalStr
val returnType = "Array[NDArray]"
val scalaDoc = generateAPIDocFromBackend(func)
val scalaDocNoParam = generateAPIDocFromBackend(func, false)
if(useParamObject) {
classDef +=
s"""private var out : org.apache.mxnet.NDArray = null
|def setOut(out : NDArray) : ${func.name}Param = {
| this.out = out
| this
| }
| def getOut() = this.out
| """.stripMargin
s"""$scalaDocNoParam
| $experimentalTag
| def ${func.name}(po: ${func.name}Param) : $returnType
| /**
| * This Param Object is specifically used for ${func.name}
| ${requiredParam.mkString("\n")}
| */
| class ${func.name}Param(${argDef.mkString(",")}) {
| ${classDef.mkString("\n ")}
| }""".stripMargin
} else {
argDef += "out : NDArray"
s"""$scalaDoc
|$experimentalTag
| def ${func.name}(${argDef.mkString(", ")}) : $returnType
| """.stripMargin
}
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,14 @@ private[mxnet] object JavaNDArrayMacro {

newNDArrayFunctions.foreach { ndarrayfunction =>

val useParamObject = ndarrayfunction.listOfArgs.count(arg => arg.isOptional) >= 2
// Construct argument field with all required args
var argDef = ListBuffer[String]()
// Construct Optional Arg
var OptionArgDef = ListBuffer[String]()
// Construct function Implementation field (e.g norm)
var impl = ListBuffer[String]()
impl += "val map = scala.collection.mutable.Map[String, Any]()"
// scalastyle:off
impl += "val args= scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray]"
// scalastyle:on
// Construct Class Implementation (e.g normBuilder)
var classImpl = ListBuffer[String]()
impl +=
"val args= scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray]"
ndarrayfunction.listOfArgs.foreach({ ndarrayArg =>
// var is a special word used to define variable in Scala,
// need to changed to something else in order to make it work
Expand All @@ -88,55 +84,56 @@ private[mxnet] object JavaNDArrayMacro {
case "type" => "typeOf"
case _ => ndarrayArg.argName
}
if (ndarrayArg.isOptional) {
OptionArgDef += s"private var $currArgName : ${ndarrayArg.argType} = null"
val tempDef = s"def set$currArgName($currArgName : ${ndarrayArg.argType})"
val tempImpl = s"this.$currArgName = $currArgName\nthis"
classImpl += s"$tempDef = {$tempImpl}"
} else {
argDef += s"$currArgName : ${ndarrayArg.argType}"
}
if (useParamObject) currArgName = s"po.get${currArgName.capitalize}()"
argDef += s"$currArgName : ${ndarrayArg.argType}"
// NDArray arg implementation
val returnType = "org.apache.mxnet.javaapi.NDArray"
val base =
if (ndarrayArg.argType.equals(returnType)) {
s"args += this.$currArgName"
s"args += $currArgName"
} else if (ndarrayArg.argType.equals(s"Array[$returnType]")){
s"this.$currArgName.foreach(args+=_)"
s"$currArgName.foreach(args+=_)"
} else {
"map(\"" + ndarrayArg.argName + "\") = this." + currArgName
"map(\"" + ndarrayArg.argName + "\") = " + currArgName
}
impl.append(
if (ndarrayArg.isOptional) s"if (this.$currArgName != null) $base"
if (ndarrayArg.isOptional) s"if ($currArgName != null) $base"
else base
)
})
// add default out parameter
classImpl +=
"def setout(out : org.apache.mxnet.javaapi.NDArray) = {this.out = out\nthis}"
impl += "if (this.out != null) map(\"out\") = this.out"
OptionArgDef += "private var out : org.apache.mxnet.NDArray = null"
val returnType = "org.apache.mxnet.javaapi.NDArrayFuncReturn"
argDef += s"out: org.apache.mxnet.javaapi.NDArray"
if (useParamObject) {
impl += "if (po.getOut() != null) map(\"out\") = po.getOut()"
} else {
impl += "if (out != null) map(\"out\") = out"
}
val returnType = "Array[org.apache.mxnet.javaapi.NDArray]"
// scalastyle:off
// Combine and build the function string
impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", args.toSeq, map.toMap)"
val classDef = s"class ${ndarrayfunction.name}Builder(${argDef.mkString(",")}) extends ${ndarrayfunction.name}BuilderBase"
val classBody = s"${OptionArgDef.mkString("\n")}\n${classImpl.mkString("\n")}\ndef invoke() : $returnType = {${impl.mkString("\n")}}"
val classFinal = s"$classDef {$classBody}"
val functionDef = s"def ${ndarrayfunction.name} (${argDef.mkString(",")})"
val functionBody = s"new ${ndarrayfunction.name}Builder(${argDef.map(_.split(":")(0)).mkString(",")})"
val functionFinal = s"$functionDef : ${ndarrayfunction.name}BuilderBase = $functionBody"
// scalastyle:on
functionDefs += c.parse(functionFinal).asInstanceOf[DefDef]
classDefs += c.parse(classFinal).asInstanceOf[ClassDef]
impl += "val finalArr = org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" +
ndarrayfunction.name + "\", args.toSeq, map.toMap).arr"
impl += "finalArr.map(ele => new NDArray(ele))"
if (useParamObject) {
val funcDef =
s"""def ${ndarrayfunction.name}(po: ${ndarrayfunction.name}Param): $returnType = {
| ${impl.mkString("\n")}
| }""".stripMargin
functionDefs += c.parse(funcDef).asInstanceOf[DefDef]
} else {
val funcDef =
s"""def ${ndarrayfunction.name}(${argDef.mkString(",")}): $returnType = {
| ${impl.mkString("\n")}
| }""".stripMargin
functionDefs += c.parse(funcDef).asInstanceOf[DefDef]
}
}

structGeneration(c)(functionDefs.toList, classDefs.toList, annottees : _*)
structGeneration(c)(functionDefs.toList, annottees : _*)
}

private def structGeneration(c: blackbox.Context)
(funcDef : List[c.universe.DefDef],
classDef : List[c.universe.ClassDef],
annottees: c.Expr[Any]*)
: c.Expr[Any] = {
import c.universe._
Expand All @@ -146,15 +143,15 @@ private[mxnet] object JavaNDArrayMacro {
case ClassDef(mods, name, something, template) =>
val q = template match {
case Template(superMaybe, emptyValDef, defs) =>
Template(superMaybe, emptyValDef, defs ++ funcDef ++ classDef)
Template(superMaybe, emptyValDef, defs ++ funcDef)
case ex =>
throw new IllegalArgumentException(s"Invalid template: $ex")
}
ClassDef(mods, name, something, q)
case ModuleDef(mods, name, template) =>
val q = template match {
case Template(superMaybe, emptyValDef, defs) =>
Template(superMaybe, emptyValDef, defs ++ funcDef ++ classDef)
Template(superMaybe, emptyValDef, defs ++ funcDef)
case ex =>
throw new IllegalArgumentException(s"Invalid template: $ex")
}
Expand Down

0 comments on commit 3664a7c

Please sign in to comment.