
Test/mkldnn batch norm op #13084

Closed · wants to merge 183 commits.

Commits (183)
f41d6b1
add conv test
azai91 Jul 16, 2018
82859ad
remove pool type
azai91 Jul 16, 2018
43195f8
gettestinput can receive scale input
azai91 Jul 16, 2018
888fb95
create kernels and bias arrays
azai91 Jul 16, 2018
e95fef4
fix shape of kernel / bias
azai91 Jul 17, 2018
efd41aa
fix format
azai91 Jul 17, 2018
5eb5231
bias is 1dim
azai91 Jul 17, 2018
736698b
fix output shape
azai91 Jul 17, 2018
64c4ba6
fix backwards input
azai91 Jul 17, 2018
825c9ef
fix var name to backwards_ex_outputs
azai91 Jul 17, 2018
7bde209
filter inputs with diff memory dims
azai91 Jul 17, 2018
b256246
fix lint
azai91 Jul 17, 2018
7457c8e
remove extra spaces
azai91 Jul 17, 2018
78d5046
fix lint
azai91 Jul 18, 2018
96d055d
add deconv test
azai91 Jul 30, 2018
a352578
add calc deconv size
azai91 Jul 30, 2018
bacb5ee
remove bias from deconv input
azai91 Jul 31, 2018
44b7a7f
create deconv kernel
azai91 Jul 31, 2018
b7f4ce1
fix num outputs for deconv
azai91 Jul 31, 2018
4beceb4
fix lint
azai91 Jul 31, 2018
41ca71b
use random inputs for deconv
azai91 Jul 31, 2018
2207dde
can init random mkldnn array
azai91 Jul 31, 2018
b71bc6a
round scale
azai91 Jul 31, 2018
5eb999d
update for loops with size_t instead of ints
azai91 Aug 2, 2018
8a9e7b3
remove comment
azai91 Aug 3, 2018
96a099a
fix merge
azai91 Aug 3, 2018
ca4d60b
use bounded random inputs
azai91 Aug 6, 2018
72f8b19
merge from master
azai91 Aug 21, 2018
f4d02af
fix merge issue
azai91 Aug 21, 2018
7c73d7f
conv op uses filter
azai91 Aug 22, 2018
17c90dd
reorder if view
azai91 Aug 22, 2018
d08a300
reorder backwards
azai91 Aug 22, 2018
c4ac8e8
rename to out_grad
azai91 Aug 22, 2018
4b24d0c
fix lint
azai91 Aug 22, 2018
367acae
filter pooling types
azai91 Aug 23, 2018
d7d3134
merge from master
azai91 Oct 3, 2018
13e4a2d
reorder
azai91 Oct 3, 2018
95f6f43
add bias
azai91 Oct 3, 2018
0db919b
fix typo
azai91 Oct 3, 2018
bd7c6d5
fix ref
azai91 Oct 3, 2018
b4c0125
filter arrays
azai91 Oct 3, 2018
bab06ac
reorder deconv inputs
azai91 Oct 3, 2018
2ad8167
reorder deconv forward inputs
azai91 Oct 4, 2018
23984df
merge master
azai91 Oct 11, 2018
973916d
fix missing var
azai91 Oct 11, 2018
2b24f8e
remove unused var
azai91 Oct 11, 2018
2ac82d7
reorder inputs for deconv forward
azai91 Oct 11, 2018
06abf6c
remove const
azai91 Oct 11, 2018
db0d60e
avoid reorder
azai91 Oct 11, 2018
99a6efa
set bias
azai91 Oct 11, 2018
bf4b82c
fix typo
azai91 Oct 11, 2018
0b990fb
set bias with string
azai91 Oct 11, 2018
701ca34
set bias with string
azai91 Oct 11, 2018
a13727e
remove use bias
azai91 Oct 11, 2018
18f718e
add bias
azai91 Oct 11, 2018
fd97540
add bias shape
azai91 Oct 24, 2018
0917b82
cannot use reshaped non n*** format
azai91 Oct 24, 2018
5b5c502
add spatial filter
azai91 Oct 25, 2018
e20e585
fix conv
azai91 Oct 25, 2018
1553ffd
fix missing conv
azai91 Oct 25, 2018
c481872
fix input
azai91 Oct 26, 2018
4ebc459
merge from master
azai91 Oct 27, 2018
b00bd4a
fix merge
azai91 Oct 29, 2018
fdde3a9
add missing header
azai91 Oct 29, 2018
fa7f9dd
add inline
azai91 Oct 29, 2018
88d4e90
fix input
azai91 Oct 30, 2018
9173a52
add spatial filter in test
azai91 Oct 30, 2018
f0bb845
fix get test input params
azai91 Oct 30, 2018
7330875
fix get test input params
azai91 Oct 30, 2018
83c0d26
Merge branch 'master' into test/mkldnn-conv-op
azai91 Oct 30, 2018
d747aa0
fix num inputs
azai91 Oct 30, 2018
7e13aa2
fix input num of backwards
azai91 Oct 30, 2018
7258fa0
fix bias
azai91 Oct 30, 2018
7daef88
add missing bias
azai91 Oct 30, 2018
01a3971
fix output num for backwards
azai91 Oct 30, 2018
976bf24
fix num outputs for deconv
azai91 Oct 30, 2018
97f2f11
fix test input
azai91 Oct 30, 2018
13cb25e
remove comments
azai91 Oct 30, 2018
89875de
use deconv param
azai91 Oct 30, 2018
8f94e9f
use template
azai91 Oct 30, 2018
d91301e
filter out incompatible widths
azai91 Oct 30, 2018
0e0fca8
fix lint
azai91 Oct 31, 2018
a1ee96d
add bn get command
azai91 Nov 1, 2018
94f7e7c
update comment
azai91 Nov 1, 2018
8d39197
add kWriteInplace
azai91 Nov 1, 2018
0fe672b
fix lint
azai91 Nov 1, 2018
e1cd7e3
Merge branch 'master' into test/mkldnn-conv-op
azai91 Nov 6, 2018
2fcea68
reorder weights in deconv
azai91 Nov 8, 2018
5d3f32b
remove const from set_data handle
azai91 Nov 8, 2018
e0a36c9
merge from master
azai91 Nov 8, 2018
2658953
fix lint
azai91 Nov 8, 2018
68713e6
retrigger
azai91 Nov 9, 2018
3995a16
Merge branch 'master' into test/mkldnn-conv-op
azai91 Nov 12, 2018
24323f6
remove data format
azai91 Nov 12, 2018
d2428aa
retrigger
azai91 Nov 12, 2018
272fed4
retrigger
azai91 Nov 13, 2018
bfdc3d5
use check_eq
azai91 Nov 14, 2018
d9ca2a1
Merge branch 'master' into test/mkldnn-conv-op
azai91 Nov 14, 2018
5cec3ce
retrigger
azai91 Nov 14, 2018
bebe81f
merge from master
azai91 Nov 15, 2018
ce97f14
refactor deconv if/else block
azai91 Nov 19, 2018
292fff9
Merge branch 'test/mkldnn-conv-op' into test/mkldnn-batch-norm-op
azai91 Nov 20, 2018
867700b
merge master
azai91 Nov 20, 2018
7ee88c3
fix typo in fetching operator
azai91 Nov 20, 2018
9e1ac40
Merge branch 'master' into test/mkldnn-batch-norm-op
azai91 Nov 21, 2018
e33aa06
fix from merge
azai91 Nov 21, 2018
01823f8
fix op name
azai91 Nov 21, 2018
ba13532
fix op name
azai91 Nov 21, 2018
28b308a
add accept dims
azai91 Nov 21, 2018
51cf139
fix num inputs
azai91 Nov 21, 2018
7f106ba
parse params
azai91 Nov 21, 2018
2e26f72
Merge branch 'master' into test/mkldnn-batch-norm-op
azai91 Nov 21, 2018
6a3b1f9
Merge branch 'master' into test/mkldnn-batch-norm-op
azai91 Nov 26, 2018
dbfea8b
Merge branch 'master' into test/mkldnn-batch-norm-op
azai91 Nov 27, 2018
52deb18
Merge branch 'master' into test/mkldnn-batch-norm-op
azai91 Dec 3, 2018
11b0d47
use second inputs
azai91 Dec 3, 2018
ad41fa6
Merge branch 'master' into test/mkldnn-batch-norm-op
azai91 Dec 4, 2018
15320c9
create copies for inputs
azai91 Dec 4, 2018
fb6d436
fix
azai91 Dec 4, 2018
9f4c096
move copy into inner loop
azai91 Dec 4, 2018
21b8ff9
move copy into inner loop
azai91 Dec 4, 2018
f9e38ca
fix reference
azai91 Dec 4, 2018
4c2ba36
use second inputs array
azai91 Dec 4, 2018
5aae134
create second copy of fixtures
azai91 Dec 4, 2018
4d4d9b4
remove copy (temp)
azai91 Dec 4, 2018
90d27b3
do not init randomly
azai91 Dec 4, 2018
a952347
reorder gamma/beta
azai91 Dec 4, 2018
4f7bbee
fallback if arrays are mkldnn
azai91 Dec 4, 2018
de15d0d
add missing semicolon
azai91 Dec 4, 2018
27bd60d
Merge branch 'master' into test/mkldnn-batch-norm-op
azai91 Dec 5, 2018
b63ff97
make more copies of fixture
azai91 Dec 5, 2018
46e662b
refactor
azai91 Dec 5, 2018
ffdc57d
remove unused code
azai91 Dec 5, 2018
062bb25
move copy in scope
azai91 Dec 5, 2018
177dc41
fix get output
azai91 Dec 6, 2018
40fb075
add back reshape
azai91 Dec 6, 2018
4948bd5
wait to read from copies
azai91 Dec 6, 2018
4d577d0
clear array after use
azai91 Dec 6, 2018
d9b9965
fix typo
azai91 Dec 6, 2018
010551b
fix typo
azai91 Dec 6, 2018
41769b2
add back wait to read
azai91 Dec 6, 2018
94cb2f2
fix backwards inputs
azai91 Dec 6, 2018
5a4c954
copy backwards mkldnn memory
azai91 Dec 6, 2018
8c776d4
fix equality
azai91 Dec 6, 2018
3499323
fix num size
azai91 Dec 6, 2018
6095ff5
add wait for all after copy
azai91 Dec 6, 2018
6db733f
store memory in external array
azai91 Dec 6, 2018
df2e34c
use const
azai91 Dec 6, 2018
3c595e5
check if mkldnn
azai91 Dec 6, 2018
5da8b4a
write to copy
azai91 Dec 7, 2018
8dd1f10
add reshape back
azai91 Dec 7, 2018
636f581
add more inputs
azai91 Dec 7, 2018
97c7e1a
Merge branch 'master' into test/mkldnn-batch-norm-op
azai91 Dec 7, 2018
99682fb
do not use reshaped input
azai91 Dec 7, 2018
06a49c2
limit inputs / outputs
azai91 Dec 7, 2018
e6999dd
remove unused methods
azai91 Dec 7, 2018
05eeea0
remove white space
azai91 Dec 7, 2018
b587ae0
fix lint
azai91 Dec 7, 2018
2e6b2ea
fix lint
azai91 Dec 7, 2018
8f65751
fix int
azai91 Dec 7, 2018
24b1221
Merge branch 'master' into test/mkldnn-batch-norm-op
azai91 Dec 10, 2018
1a78c12
retrigger
azai91 Dec 10, 2018
9b40243
retrigger
azai91 Dec 11, 2018
11fb970
copy memory for non mkldnn
azai91 Dec 11, 2018
0e2e50e
fix
azai91 Dec 11, 2018
9099c34
cannot copy views
azai91 Dec 11, 2018
3e0c964
dont wait for all
azai91 Dec 11, 2018
2abe031
reorder invalid inputs
azai91 Dec 11, 2018
3badbca
add wait for all back in
azai91 Dec 11, 2018
8342e94
move array out of scope
azai91 Dec 11, 2018
d760c0d
copy memory from tmp
azai91 Dec 11, 2018
06d6a52
dont store mem
azai91 Dec 11, 2018
401bcc3
create separate test fixture for bn
azai91 Dec 11, 2018
dbfc67b
fix spacing
azai91 Dec 11, 2018
1a216df
add missing vars
azai91 Dec 11, 2018
cd6c9cc
revert
azai91 Dec 11, 2018
8717513
revert
azai91 Dec 11, 2018
22d2d7b
comment out mem
azai91 Dec 11, 2018
b313629
use bn helper back
azai91 Dec 11, 2018
3faeddb
remove whitespace
azai91 Dec 11, 2018
60723de
fix lint
azai91 Dec 11, 2018
b4070b3
remove copy
azai91 Dec 11, 2018
c75022b
retrigger
azai91 Dec 11, 2018
18 changes: 13 additions & 5 deletions src/operator/nn/batch_norm.cc
@@ -380,12 +380,20 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
}

#if MXNET_USE_MKLDNN == 1
-static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam &param) {
-  TShape shape = input.shape();
-  return SupportMKLDNN(input) && shape.ndim() == 4
static inline bool SupportMKLDNNBN(const std::vector<NDArray> &inputs,
                                   const BatchNormParam &param) {
  TShape shape = inputs[0].shape();
  bool params_valid = shape.ndim() == 4
      && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS
      && shape[param.axis] % 8 == 0
      && !mxnet::op::batchnorm::disable_mkl;
  bool inputs_valid = SupportMKLDNN(inputs[0]);
  for (size_t i = 1; i < inputs.size(); i++) {
    if (inputs[i].IsMKLDNNData()) {

[Review comment (Contributor)] Once inputs_valid is set to false, we need not check the remaining inputs; we can add a break here: if (!inputs_valid) break;

      inputs_valid = false;

[Review comment (Member)] inputs_valid is set to false if the input is in MKLDNN format?

    }
  }
  return params_valid && inputs_valid;
}
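For reference, a minimal sketch of the input scan with the reviewer's suggested early exit folded in; the result is identical, but the loop stops at the first MKLDNN-layout input (C++, assuming the surrounding MXNet context):

    // Sketch: SupportMKLDNNBN's input scan with the suggested break applied.
    bool inputs_valid = SupportMKLDNN(inputs[0]);
    for (size_t i = 1; i < inputs.size(); i++) {
      if (inputs[i].IsMKLDNNData()) {
        inputs_valid = false;
        break;  // one MKLDNN-format input already rules out the fast path
      }
    }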

@@ -396,7 +404,7 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
  CHECK_EQ(inputs.size(), 5U);
  const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
  // MKLDNN batchnorm only works well on the special MKLDNN layout.
- if (SupportMKLDNNBN(inputs[0], param) && inputs[0].IsMKLDNNData()) {
  if (SupportMKLDNNBN(inputs, param) && inputs[0].IsMKLDNNData()) {
    std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
    std::vector<NDArray> aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end());

@@ -420,7 +428,7 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs,

  TShape shape = inputs[0].shape();
  // MKLDNN batchnorm only works well on the special MKLDNN layout.
- if (SupportMKLDNNBN(inputs[0], param)
  if (SupportMKLDNNBN(inputs, param)
      && (inputs[3].IsMKLDNNData() || inputs[0].IsMKLDNNData())) {
    std::vector<NDArray> out_grad(1);
    std::vector<NDArray> out_data(3);
14 changes: 12 additions & 2 deletions src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -216,14 +216,24 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
  auto &fwd = GetBNForward<DType>(param, ctx, data, flags);
  const NDArray &out = out_data[batchnorm::kOut];

  auto gamma_buffer = in_data[batchnorm::kGamma];
  if (gamma_buffer.IsMKLDNNData()) {
    gamma_buffer = gamma_buffer.Reorder2Default();
  }

  auto beta_buffer = in_data[batchnorm::kBeta];
  if (beta_buffer.IsMKLDNNData()) {
    beta_buffer = beta_buffer.Reorder2Default();
  }

  // for output memory
  auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_primitive_desc());

  // mxnet will always use scale shift.
  // But if fix_gamma is true, then all scale elements will be set to 1.0f
  if (flags & use_scale_shift) {
-   const NDArray &gamma = in_data[batchnorm::kGamma];
-   const NDArray &beta = in_data[batchnorm::kBeta];
    const NDArray &gamma = gamma_buffer;
    const NDArray &beta = beta_buffer;
    CHECK_EQ(gamma.storage_type(), mxnet::kDefaultStorage);
    CHECK_EQ(beta.storage_type(), mxnet::kDefaultStorage);

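The gamma/beta handling added above is an instance of a guard-and-reorder idiom: when an array arrives in an MKLDNN-specific layout, materialize a default-layout copy before its raw data is read, presumably because the scale-shift buffer is filled element-wise from plain contiguous memory (an inference, not stated in the PR). A minimal sketch of the idiom, using the same NDArray calls as the diff:

    // Sketch: fall back to the default layout before touching raw data.
    NDArray buffer = in_data[batchnorm::kGamma];
    if (buffer.IsMKLDNNData())
      buffer = buffer.Reorder2Default();  // returns a plain-layout copy
    // buffer's data can now be read as ordinary contiguous memory.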
155 changes: 154 additions & 1 deletion tests/cpp/operator/mkldnn_operator_test.cc
@@ -347,6 +347,31 @@ OpAttrs GetDeconvBackwardOp(int kernel, int num_filters, int dim, int stride, in
  return attrs;
}

OpAttrs GetBNOp() {
  OpAttrs attrs;
  attrs.attrs.op = Op::Get("BatchNorm");
  attrs.num_inputs = 5;
  attrs.num_outputs = 3;
  attrs.accept_dims.insert(4);
  attrs.requests.insert(OpReqType::kWriteTo);
  attrs.attrs.op->attr_parser(&attrs.attrs);
  attrs.input_types = ArrayTypes::Normal |
                      ArrayTypes::MKLDNN;
  attrs.output_types = ArrayTypes::Normal |
                       ArrayTypes::MKLDNN;
  return attrs;
}

OpAttrs GetBNBackwardOp() {
  OpAttrs attrs;
  attrs.attrs.op = Op::Get("_backward_BatchNorm");
  attrs.num_inputs = 8;
  attrs.num_outputs = 3;
  attrs.attrs.op->attr_parser(&attrs.attrs);
  attrs.requests.insert(OpReqType::kWriteTo);
  return attrs;
}
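These counts encode BatchNorm's signature: the forward op takes data, gamma, beta, moving_mean, and moving_var (5 inputs) and emits the normalized output plus the batch mean and variance (3 outputs); the backward op consumes the output gradient, the saved mean/variance, and the five forward inputs (8 inputs) and produces gradients for data, gamma, and beta (3 outputs). A sketch of that wiring (the enum names are hypothetical; the ordering mirrors the mapping in TestOpExBNBackward below):

    // Sketch: BatchNorm I/O layout assumed by GetBNOp / GetBNBackwardOp.
    enum BNForwardInput  { kData, kGamma, kBeta, kMovingMean, kMovingVar };  // 5 inputs
    enum BNForwardOutput { kOut, kMean, kVar };                              // 3 outputs
    enum BNBackwardInput { kOutGrad, kSavedMean, kSavedVar,                  // forward outputs
                           kInData, kInGamma, kInBeta,
                           kInMovingMean, kInMovingVar };                    // 8 inputs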

void AssertEqual(const std::vector<NDArray *> &in_arrs,
                 const std::vector<NDArray *> &out_arrs,
                 float rtol = 1e-5, float atol = 1e-8) {
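The body of AssertEqual is collapsed in this view; the rtol/atol parameters suggest the usual combined relative/absolute tolerance test. A sketch of that check for a single element pair (an assumption about the shape of the implementation, not the PR's exact code):

    // Sketch: elementwise closeness with relative and absolute slack (needs <cmath>).
    bool IsClose(float a, float b, float rtol, float atol) {
      return std::fabs(a - b) <= atol + rtol * std::fabs(b);
    }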
@@ -710,7 +735,7 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {

  // If the array is a view, we shouldn't write data to it.
  if (in_arr.arr.IsView())
-   continue;
    continue;

  NDArrayAttrs orig(in_arr.arr.Copy(in_arr.arr.ctx()), "InPlace Copy");
  for (int i = 0; i < forward_attrs.num_inputs; i++)
@@ -735,6 +760,128 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
  }
}


void TestOpExBNBackward(const OpAttrs &forward_attrs,
                        const OpAttrs &backwards_attrs,
                        const OpReqType &req,
                        const std::vector<NDArray*> &inputs,
                        const std::vector<NDArray*> &outputs,
                        const NDArrayAttrs &in_arr,
                        const NDArrayAttrs &out_arr) {
  std::vector<NDArray*> backwards_input(backwards_attrs.num_inputs);

  std::vector<NDArray> backwards_buffer(backwards_attrs.num_outputs);
  std::vector<NDArray> backwards_buffer2(backwards_attrs.num_outputs);

  std::vector<NDArray*> backwards_outputs(backwards_attrs.num_outputs);
  std::vector<NDArray*> backwards_ex_outputs(backwards_attrs.num_outputs);
  std::vector<OpReqType> back_req(backwards_attrs.num_outputs);

  if (req == kWriteTo) {
    backwards_input[0] = outputs[0];  // output grad
    backwards_input[1] = outputs[1];  // mean
    backwards_input[2] = outputs[2];  // var
    backwards_input[3] = inputs[0];   // data
    backwards_input[4] = inputs[1];   // gamma
    backwards_input[5] = inputs[2];   // beta
    backwards_input[6] = inputs[3];   // moving mean
    backwards_input[7] = inputs[4];   // moving var

    for (size_t i = 0; i < backwards_attrs.num_outputs; i++) {
      auto tmp_output = in_arr.arr;
      backwards_buffer.emplace_back(tmp_output.Copy(Context()));
      backwards_buffer2.emplace_back(tmp_output.Copy(Context()));
      backwards_outputs[i] = &backwards_buffer.back();
      backwards_ex_outputs[i] = &backwards_buffer2.back();
      Engine::Get()->WaitForAll();
    }

    for (int i = 0; i < backwards_attrs.num_outputs; i++)
      back_req[i] = kWriteTo;

    std::cout << "Backwards: ";
    PrintVerifyMsg(out_arr, in_arr);
    Imperative::Get()->InvokeOp(
        Context(), backwards_attrs.attrs, backwards_input, backwards_outputs,
        back_req, DispatchMode::kFCompute, mxnet::OpStatePtr());
    Imperative::Get()->InvokeOp(
        Context(), backwards_attrs.attrs, backwards_input, backwards_ex_outputs,
        back_req, DispatchMode::kFComputeEx, mxnet::OpStatePtr());
    Engine::Get()->WaitForAll();
    AssertEqual(backwards_outputs, backwards_ex_outputs);
  }
}

// compares output of fcompute with fcomputex
void TestOpExBN(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
  std::vector<NDArray*> inputs(forward_attrs.num_inputs);
  std::vector<NDArray*> inputs2(forward_attrs.num_inputs);
  std::vector<NDArray> inputs_buffer(forward_attrs.num_inputs);
  std::vector<NDArray> inputs2_buffer(forward_attrs.num_inputs);
  std::vector<NDArray*> outputs(forward_attrs.num_outputs);
  std::vector<NDArray*> ex_outputs(forward_attrs.num_outputs);
  std::vector<OpReqType> req(forward_attrs.num_outputs);

  TestArrayShapes tas = GetTestArrayShapes();
  std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;

  std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(forward_attrs.input_types, false);
  std::vector<std::vector<NDArrayAttrs>> out_arrs(forward_attrs.num_outputs);
  std::vector<std::vector<NDArrayAttrs>> ex_out_arrs(forward_attrs.num_outputs);

  if (forward_attrs.requests.find(OpReqType::kWriteTo) != forward_attrs.requests.end()) {
    for (int i1 = 0; i1 < in_arrs.size(); i1++) {
      auto in_arr = in_arrs[i1];

      CHECK_NE(forward_attrs.accept_dims.size(), 0);
      if (forward_attrs.accept_dims.find(in_arr.arr.shape().ndim()) ==
          forward_attrs.accept_dims.end())
        continue;
      for (int i = 0; i < forward_attrs.num_outputs; i++) {
        out_arrs[i] =
            GetTestOutputArrays(in_arr.arr.shape(), pds, {1}, true, forward_attrs.output_types);
        ex_out_arrs[i] =
            GetTestOutputArrays(in_arr.arr.shape(), pds, {1}, true, forward_attrs.output_types);
      }
      for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
        inputs_buffer.clear();
        inputs2_buffer.clear();

        for (int i = 0; i < forward_attrs.num_inputs; i++) {
          inputs_buffer.emplace_back(in_arr.arr.Copy(Context()));
          inputs2_buffer.emplace_back(in_arr.arr.Copy(Context()));
          Engine::Get()->WaitForAll();
          inputs[i] = &inputs_buffer.back();
          inputs2[i] = &inputs2_buffer.back();
        }
        for (int i = 0; i < forward_attrs.num_outputs; i++) {
          req[i] = kWriteTo;
          outputs[i] = &out_arrs[i][output_i].arr;
          ex_outputs[i] = &ex_out_arrs[i][output_i].arr;
        }
        Imperative::Get()->set_is_training(true);

        PrintVerifyMsg(in_arr, out_arrs[0][output_i]);
        Imperative::Get()->InvokeOp(
            Context(), forward_attrs.attrs, inputs, outputs, req,
            DispatchMode::kFCompute, mxnet::OpStatePtr());
        Imperative::Get()->InvokeOp(
            Context(), forward_attrs.attrs, inputs2, ex_outputs, req,
            DispatchMode::kFComputeEx, mxnet::OpStatePtr());
        Engine::Get()->WaitForAll();
        AssertEqual(outputs, ex_outputs);

        if (!backwards_attrs.requests.empty()) {
          TestOpExBNBackward(forward_attrs, backwards_attrs, OpReqType::kWriteTo,
                             inputs, outputs, in_arr, out_arrs[0][output_i]);
        }
      }
    }
  }
}
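One design note, reflected in the commit history ("create copies for inputs", "create second copy of fixtures"): each dispatch mode gets its own deep copies of the input fixture. The likely reason, an inference from the code rather than a stated one, is that BatchNorm in training mode updates its moving-mean/moving-var inputs in place, so sharing arrays between the kFCompute and kFComputeEx runs would let the first run perturb the second. The copy idiom in isolation (variable names hypothetical):

    // Sketch: give each dispatch path an independent copy of the same fixture.
    NDArray ref_in = in_arr.arr.Copy(Context());  // consumed by the kFCompute run
    NDArray mkl_in = in_arr.arr.Copy(Context());  // consumed by the kFComputeEx run
    Engine::Get()->WaitForAll();                  // Copy is asynchronous; wait before reading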

// Computes second dimension of FC weight matrix based on input shape
uint32_t GetFCWeightDim2(const nnvm::TShape arr) {
  uint32_t dim = 1;
@@ -1204,4 +1351,10 @@ TEST(IMPERATIVE, DeconvOp) {
  }
}

TEST(IMPERATIVE, BNOp) {
  OpAttrs forward_attrs = GetBNOp();
  OpAttrs backwards_attrs = GetBNBackwardOp();
  TestOpExBN(forward_attrs, backwards_attrs);
}
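With the standard GoogleTest runner, this case can be selected by passing --gtest_filter=IMPERATIVE.BNOp to the C++ test binary; the binary's name and location depend on the build configuration.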

#endif