This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

use ResourceScope in Model/Trainer/FeedForward.scala #12882

Merged · 3 commits · Oct 23, 2018
152 changes: 97 additions & 55 deletions scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
@@ -17,9 +17,10 @@

package org.apache.mxnet

import org.apache.mxnet.Base.CPtrAddress
import org.apache.mxnet.io.NDArrayIter
import org.apache.mxnet.optimizer.SGD
import org.slf4j.{LoggerFactory, Logger}
import org.slf4j.{Logger, LoggerFactory}

import scala.collection.mutable.ListBuffer

@@ -55,7 +56,7 @@ class FeedForward private(
argParams: Map[String, NDArray],
auxParams: Map[String, NDArray],
private val allowExtraParams: Boolean,
val beginEpoch: Int) {
val beginEpoch: Int) extends NativeResource {

val logger: Logger = LoggerFactory.getLogger(classOf[FeedForward])
private var argumentChecked = false
@@ -126,6 +127,8 @@
}

// Initialize weight parameters and auxiliary states
// The NDArrays associated with _argParams and _auxParams are not disposed here; instead,
// they are passed to an outer scope if one is available.
private def initParams(inputShapes: Map[String, Shape], overwrite: Boolean = false)
: (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) = {
val (argShapes, _, auxShapes) = symbol.inferShape(inputShapes)
@@ -137,16 +140,26 @@
val paramNameShapes = (argNames zip argShapes).filter { case (name, _) =>
paramNames.contains(name)
}
val argParams = paramNameShapes.map { case (name, shape) =>
(name, NDArray.zeros(shape))
val argParams = paramNameShapes.map { case (name, shape) => {
val param = NDArray.zeros(shape)
val curScope = ResourceScope.getCurrentScope()
if (curScope.isDefined) curScope.get.moveToOuterScope(param)
(name, param)
}
}.toMap
val auxParams = (auxNames zip auxShapes).map { case (name, shape) =>
(name, NDArray.zeros(shape))

val auxParams = (auxNames zip auxShapes).map { case (name, shape) => {
val param = NDArray.zeros(shape)
val curScope = ResourceScope.getCurrentScope()
if (curScope.isDefined) curScope.get.moveToOuterScope(param)
(name, param)
}
}.toMap

for ((k, v) <- argParams) {
if (_argParams != null && _argParams.contains(k) && (!overwrite)) {
argParams(k).set(_argParams(k))

} else {
initializer(k, v)
}
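The pattern above is the core of this change: each parameter NDArray is allocated in the current scope and then immediately re-homed with moveToOuterScope so that it outlives the scope that created it. Below is a minimal sketch of that pattern; the map, names, and shape are illustrative only, and it assumes code sitting inside the org.apache.mxnet package, since getCurrentScope and NativeResource are package-private.

import scala.collection.mutable

// Long-lived parameter store, standing in for _argParams.
val params = mutable.Map.empty[String, NDArray]

ResourceScope.using() {
  val w = NDArray.zeros(Shape(4, 4))                         // owned by the current scope
  val curScope = ResourceScope.getCurrentScope()
  if (curScope.isDefined) curScope.get.moveToOuterScope(w)   // let w outlive this block
  params("weight") = w                                       // safe: w is no longer owned here
}                                                            // any other temporaries die here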
@@ -277,13 +290,15 @@
def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric, kvStoreType: String,
epochEndCallback: EpochEndCallback, batchEndCallback: BatchEndCallback,
logger: Logger, workLoadList: Seq[Float]): Unit = {
// init params first to allow the kv store to use _argParams to decide its type
initSymbolParams(trainData)
// create kvstore
val (kvStore, updateOnKVStore) = Model.createKVStore(kvStoreType, ctx.length, _argParams)
fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore,
epochEndCallback, batchEndCallback, logger, workLoadList)
kvStore.foreach(_.dispose())
ResourceScope.using() {
// init params first to allow the kv store to use _argParams to decide its type
initSymbolParams(trainData)
// create kvstore
val (kvStore, updateOnKVStore) = Model.createKVStore(kvStoreType, ctx.length, _argParams)
fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore,
epochEndCallback, batchEndCallback, logger, workLoadList)
// kvStore.foreach(_.dispose())
@andrewfayres (Contributor) commented on Oct 19, 2018:
Is there a reason to keep this comment?

Member Author replied:
Are you asking about the comment? No, I'll remove it.
}
}

def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric,
@@ -313,11 +328,13 @@
batchEndCallback: BatchEndCallback, logger: Logger,
workLoadList: Seq[Float]): Unit = {
// init params first to allow the kv store to use _argParams to decide its type
initSymbolParams(trainData)
// create kvstore
val (kvStore, updateOnKVStore) = Model.createKVStore(kv)
fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore,
epochEndCallback, batchEndCallback, logger, workLoadList)
ResourceScope.using() {
initSymbolParams(trainData)
// create kvstore
val (kvStore, updateOnKVStore) = Model.createKVStore(kv)
fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore,
epochEndCallback, batchEndCallback, logger, workLoadList)
}
}
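Both fit overloads are now wrapped in ResourceScope.using(), so NDArrays and the KVStore allocated while fitting are released as soon as the block exits instead of lingering until GC. A rough sketch of the behaviour the wrapper provides (the arrays below are stand-ins, not code from this PR):

ResourceScope.using() {
  val grad   = NDArray.ones(Shape(64, 64))   // stand-in for a gradient buffer
  val update = grad * 0.01f                  // stand-in for an optimizer update
  println(update.shape)                      // return Unit so nothing is promoted out
}                                            // grad and update are disposed right here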

def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric,
@@ -352,44 +369,49 @@
batchEndCallback: BatchEndCallback = null, logger: Logger = FeedForward.logger,
workLoadList: Seq[Float] = null): Unit = {
require(evalMetric != null, "evalMetric cannot be null")
val (argNames, paramNames, auxNames) = initSymbolParams(trainData)

// init optimizer
val batchSizeMultiplier = kvStore.map { kv =>
if (kv.`type` == "dist_sync") {
kv.numWorkers
} else {
1
}
}
val batchSize = trainData.batchSize * batchSizeMultiplier.getOrElse(1)
this.optimizer.setArgNames(argNames)
this.optimizer.setRescaleGrad(1f / batchSize)
this.optimizer.setSymbol(this.symbol)
val paramIdx2Name =
if (updateOnKVStore) {
paramNames.zipWithIndex.map { case (name, idx) => idx -> name }.toMap
} else {
paramNames.zipWithIndex.flatMap { case (name, idx) =>
(0 until ctx.length).map(k => (idx * ctx.length + k) -> name).toMap
}.toMap
// TODO: https://issues.apache.org/jira/browse/MXNET-1171
// This leaks memory: initSymbolParams -> initParams has already been called, which allocates
// NDArrays in argParams and auxParams, and calling it again here overwrites them.
// The PhantomRef should take care of releasing the originals when GC runs, but we have to
// wait for that GC call to happen.
val (argNames, paramNames, auxNames) = initSymbolParams(trainData)

// init optimizer
val batchSizeMultiplier = kvStore.map { kv =>
if (kv.`type` == "dist_sync") {
kv.numWorkers
} else {
1
}
}
this.optimizer.setIdx2Name(paramIdx2Name)

logger.debug("Start training on multi-device")
Model.trainMultiDevice(
symbol, ctx, argNames, paramNames, auxNames,
_argParams, _auxParams,
this.beginEpoch, this.numEpoch,
this.epochSize, this.optimizer,
kvStore, updateOnKVStore,
trainData = trainData, evalData = Option(evalData),
evalMetric = evalMetric,
epochEndCallback = Option(epochEndCallback),
batchEndCallback = Option(batchEndCallback),
workLoadList = workLoadList,
monitor = monitor,
symGen = symGen)
val batchSize = trainData.batchSize * batchSizeMultiplier.getOrElse(1)
this.optimizer.setArgNames(argNames)
this.optimizer.setRescaleGrad(1f / batchSize)
this.optimizer.setSymbol(this.symbol)
val paramIdx2Name =
if (updateOnKVStore) {
paramNames.zipWithIndex.map { case (name, idx) => idx -> name }.toMap
} else {
paramNames.zipWithIndex.flatMap { case (name, idx) =>
(0 until ctx.length).map(k => (idx * ctx.length + k) -> name).toMap
}.toMap
}
this.optimizer.setIdx2Name(paramIdx2Name)

logger.debug("Start training on multi-device")
Model.trainMultiDevice(
symbol, ctx, argNames, paramNames, auxNames,
_argParams, _auxParams,
this.beginEpoch, this.numEpoch,
this.epochSize, this.optimizer,
kvStore, updateOnKVStore,
trainData = trainData, evalData = Option(evalData),
evalMetric = evalMetric,
epochEndCallback = Option(epochEndCallback),
batchEndCallback = Option(batchEndCallback),
workLoadList = workLoadList,
monitor = monitor,
symGen = symGen)
}

/**
@@ -416,9 +438,29 @@
def serialize(): Array[Byte] = {
Model.serialize(this.symbol, getArgParams, getAuxParams)
}

// Hack to make FeedForward work with ResourceScope and
// automatically release _argParams and _auxParams.
override def nativeAddress: CPtrAddress = hashCode()

override def nativeDeAllocator: CPtrAddress => Int = FeedForward.doNothingDeAllocator

override val ref: NativeResourceRef = super.register()

override val bytesAllocated: Long = 0L

override def dispose(): Unit = {
if (!super.isDisposed) {
_argParams.foreach { case (_, param) => param.dispose() }
_auxParams.foreach { case (_, param) => param.dispose() }
}
}
}
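The overrides above are a deliberate trick: FeedForward has no native handle of its own, so it reports a fake address and a no-op deallocator purely so that ResourceScope and the PhantomRef machinery will call dispose() and release the parameter NDArrays it holds. A sketch of the same trick applied to a hypothetical parameter holder follows; the class and field names are invented, and it only compiles inside the org.apache.mxnet package because NativeResource is package-private.

import org.apache.mxnet.Base.CPtrAddress

// Hypothetical wrapper that owns NDArrays but has no native handle of its own.
class ParamBundle(val params: Map[String, NDArray]) extends NativeResource {
  override def nativeAddress: CPtrAddress = hashCode()           // fake handle
  override def nativeDeAllocator: CPtrAddress => Int = _ => 0    // nothing to free natively
  override val ref: NativeResourceRef = super.register()         // enrol with the current scope
  override val bytesAllocated: Long = 0L
  override def dispose(): Unit = {
    if (!super.isDisposed) params.values.foreach(_.dispose())    // free the wrapped NDArrays
  }
}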

object FeedForward {

private def doNothingDeAllocator(dummy: CPtrAddress): Int = 0

private val logger: Logger = LoggerFactory.getLogger(classOf[FeedForward])
// Check if name is a data argument.
private def isDataArg(name: String): Boolean = {
scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala
@@ -46,7 +46,8 @@ private[mxnet] trait NativeResource
*/
def nativeDeAllocator: (CPtrAddress => Int)

/** Call NativeResource.register to get the reference
/**
* Call NativeResource.register to get the reference
*/
val ref: NativeResourceRef

@@ -56,6 +57,7 @@
// intentionally making it a val, so it gets evaluated when defined
val bytesAllocated: Long

// this is set and unset by [[ResourceScope.add]] and [[ResourceScope.remove]]
private[mxnet] var scope: Option[ResourceScope] = None

@volatile private var disposed = false
@@ -69,11 +71,11 @@
* using PhantomReference
*/
def register(): NativeResourceRef = {
scope = ResourceScope.getCurrentScope()
val scope = ResourceScope.getCurrentScope()
if (scope.isDefined) scope.get.add(this)

NativeResource.totalBytesAllocated.getAndAdd(bytesAllocated)
// register with PhantomRef tracking to release incase the objects go
// register with PhantomRef tracking to release in case the objects go
// out of reference within the scope but are held for a long time
NativeResourceRef.register(this, nativeDeAllocator)
}
scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
@@ -58,6 +58,7 @@ class ResourceScope extends AutoCloseable {
*/
def add(resource: NativeResource): Unit = {
resourceQ.+=(resource)
resource.scope = Some(this)
}

/**
@@ -67,7 +68,21 @@
*/
def remove(resource: NativeResource): Unit = {
resourceQ.-=(resource)
resource.scope = None
}

/**
* Removes a resource from the current scope and moves it to the outer scope, if one exists
* @param resource Resource to be moved to an outer scope
*/
def moveToOuterScope(resource: NativeResource): Unit = {
val prevScope: Option[ResourceScope] = ResourceScope.getPrevScope()
if (prevScope.isDefined) {
this.remove(resource)
prevScope.get.add(resource)
} else this.remove(resource)
}

}

object ResourceScope {
@@ -92,32 +107,22 @@

val curScope = if (scope != null) scope else new ResourceScope()

val prevScope: Option[ResourceScope] = ResourceScope.getPrevScope()

@inline def resourceInGeneric(g: scala.collection.Iterable[_]) = {
g.foreach( n =>
n match {
case nRes: NativeResource => {
removeAndAddToPrevScope(nRes)
curScope.moveToOuterScope(nRes)
}
case kv: scala.Tuple2[_, _] => {
if (kv._1.isInstanceOf[NativeResource]) removeAndAddToPrevScope(
if (kv._1.isInstanceOf[NativeResource]) curScope.moveToOuterScope(
kv._1.asInstanceOf[NativeResource])
if (kv._2.isInstanceOf[NativeResource]) removeAndAddToPrevScope(
if (kv._2.isInstanceOf[NativeResource]) curScope.moveToOuterScope(
kv._2.asInstanceOf[NativeResource])
}
}
)
}

@inline def removeAndAddToPrevScope(r: NativeResource) = {
curScope.remove(r)
if (prevScope.isDefined) {
prevScope.get.add(r)
r.scope = prevScope
}
}

@inline def safeAddSuppressed(t: Throwable, suppressed: Throwable): Unit = {
if (!t.isInstanceOf[ControlThrowable]) t.addSuppressed(suppressed)
}
@@ -129,8 +134,8 @@
ret match {
// don't de-allocate if returning any collection that contains NativeResource.
case resInGeneric: scala.collection.Iterable[_] => resourceInGeneric(resInGeneric)
case nRes: NativeResource => removeAndAddToPrevScope(nRes)
case ndRet: NDArrayFuncReturn => ndRet.arr.foreach( nd => removeAndAddToPrevScope(nd) )
case nRes: NativeResource => curScope.moveToOuterScope(nRes)
case ndRet: NDArrayFuncReturn => ndRet.arr.foreach( nd => curScope.moveToOuterScope(nd) )
case _ => // do nothing
}
ret
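The ret match above is what makes nested scopes composable: a NativeResource returned from the block (alone, inside a collection, or inside an NDArrayFuncReturn) is handed to the enclosing scope via moveToOuterScope instead of being disposed. A small sketch of that contract, with arbitrary shapes and scalars:

val doubled: NDArray = ResourceScope.using() {
  val base = ResourceScope.using() {
    NDArray.ones(Shape(2, 3))   // created in the inner scope, promoted on return
  }
  base * 2f                      // base is still alive here; the product is promoted in turn
}
// doubled stays usable after both scopes close; base is freed with the outer scope.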