diff --git a/src/operator/contrib/index_copy-inl.h b/src/operator/contrib/index_copy-inl.h
new file mode 100644
index 000000000000..b97138a88f97
--- /dev/null
+++ b/src/operator/contrib/index_copy-inl.h
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file index_copy-inl.h
+ * \brief implementation of index_copy tensor operation
+ */
+
+#ifndef MXNET_OPERATOR_CONTRIB_INDEX_COPY_INL_H_
+#define MXNET_OPERATOR_CONTRIB_INDEX_COPY_INL_H_
+
+// NOTE(review): the four angle-bracket includes below were stripped by the
+// patch-transport mangling and are reconstructed per MXNet convention —
+// confirm against the upstream file.
+#include <mxnet/operator_util.h>
+#include <vector>
+#include <limits>
+#include <algorithm>
+#include "../elemwise_op_common.h"
+#include "../mshadow_op.h"
+#include "../mxnet_op.h"
+
+namespace mxnet {
+namespace op {
+
+// Forward kernel: one thread per index entry; copies row i of new_tensor
+// into row index[i] of out_tensor (rows are `dim` elements wide).
+template<int req>
+struct index_copy_forward {
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i,
+                                  int dim,
+                                  IType* index,
+                                  DType* new_tensor,
+                                  DType* out_tensor) {
+    DType* out_ptr = out_tensor + static_cast<int>(index[i]) * dim;
+    DType* new_ptr = new_tensor + i * dim;
+    for (int idx = 0; idx < dim; ++idx) {
+      KERNEL_ASSIGN(out_ptr[idx], req, new_ptr[idx]);
+    }
+  }
+};
+
+// Forward pass: out = old_tensor with rows index[i] overwritten by
+// new_tensor's rows. inputs = {old_tensor, index_vector, new_tensor}.
+template<typename xpu>
+void IndexCopyForward(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<TBlob>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 3U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob& out = outputs[0];
+  const TBlob& original_tensor = inputs[0];
+  const TBlob& idx_vector = inputs[1];
+  const TBlob& copied_tensor = inputs[2];
+  // Row width: total elements of new_tensor divided by the number of indices.
+  int dim = inputs[2].Size() / inputs[1].Size();
+  // copy original tensor to output
+  mxnet_op::copy(s, out, original_tensor);
+  // index copy
+  MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH(idx_vector.type_flag_, IType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+        mxnet_op::Kernel<index_copy_forward<req_type>, xpu>::Launch(s,
+                                            idx_vector.Size(), dim,
+                                            idx_vector.dptr<IType>(),
+                                            copied_tensor.dptr<DType>(),
+                                            out.dptr<DType>());
+      });
+    });
+  });
+}
+
+// Backward kernel: one thread per element of out_grad. If element i lies in
+// a row that was index-copied, its gradient flows to in_grad_2 (new_tensor's
+// grad); otherwise it flows to in_grad_1 (old_tensor's grad).
+// NOTE(review): the scan over index_size makes this O(out_size * index_size).
+template<int req>
+struct index_copy_backward {
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i,
+                                  int dim,
+                                  int index_size,
+                                  DType* out_grad,
+                                  IType* index,
+                                  DType* in_grad_1,
+                                  DType* in_grad_2) {
+    // Copy to in_grad_2
+    for (int p = 0; p < index_size; ++p) {
+      int idx = static_cast<int>(index[p]);
+      if (i >= idx*dim && i < (idx+1)*dim) {
+        int offset = i - idx*dim;
+        KERNEL_ASSIGN(in_grad_2[p*dim+offset], req, out_grad[i]);
+        return;
+      }
+    }
+    // Copy to in_grad_1
+    KERNEL_ASSIGN(in_grad_1[i], req, out_grad[i]);
+  }
+};
+
+// Backward pass. inputs = {out_grad, old_tensor, index, new_tensor};
+// outputs = {grad_old_tensor, grad_index (unused), grad_new_tensor}.
+template<typename xpu>
+void IndexCopyBackward(const nnvm::NodeAttrs& attrs,
+                       const OpContext& ctx,
+                       const std::vector<TBlob>& inputs,
+                       const std::vector<OpReqType>& req,
+                       const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 4U);
+  CHECK_EQ(outputs.size(), 3U);
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob& out_grad = inputs[0];
+  const TBlob& index = inputs[2];
+  const TBlob& in_grad_1 = outputs[0];
+  const TBlob& in_grad_2 = outputs[2];
+  int dim = inputs[3].Size() / inputs[2].Size();
+  int index_size = inputs[2].Size();
+  // index_copy_backward
+  MSHADOW_TYPE_SWITCH(out_grad.type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH(index.type_flag_, IType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+        mxnet_op::Kernel<index_copy_backward<req_type>, xpu>::Launch(s,
+                                          out_grad.Size(),
+                                          dim, index_size,
+                                          out_grad.dptr<DType>(),
+                                          index.dptr<IType>(),
+                                          in_grad_1.dptr<DType>(),
+                                          in_grad_2.dptr<DType>());
+      });
+    });
+  });
+}
+
+// Shape inference: output shape == old_tensor shape; validates that index is
+// a vector matching new_tensor's first dim and that new_tensor fits inside
+// old_tensor on dim 0 and matches it on all other dims.
+inline bool IndexCopyShape(const nnvm::NodeAttrs& attrs,
+                           std::vector<TShape> *in_attrs,
+                           std::vector<TShape> *out_attrs) {
+  // inputs[0]: original tensor
+  // inputs[1]: index vector
+  // inputs[2]: copied tensor
+  CHECK_EQ(in_attrs->size(), 3U);
+  // outputs[0]: a new tensor
+  CHECK_EQ(out_attrs->size(), 1U);
+  // inputs[1] must be a vector
+  CHECK_EQ(in_attrs->at(1).ndim(), 1);
+  // Shape matching
+  CHECK_EQ(in_attrs->at(0).ndim(), in_attrs->at(2).ndim());
+  for (size_t i = 0; i < in_attrs->at(0).ndim(); ++i) {
+    if (i == 0) {
+      CHECK_GE(in_attrs->at(0)[i], in_attrs->at(2)[i]);
+    } else {
+      CHECK_EQ(in_attrs->at(0)[i], in_attrs->at(2)[i]);
+    }
+  }
+  // The length of the first dim of the copied tensor
+  // must equal the size of the index vector
+  CHECK_EQ(in_attrs->at(1)[0], in_attrs->at(2)[0]);
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+  SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+  return out_attrs->at(0).ndim() != 0U &&
+         out_attrs->at(0).Size() != 0U;
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_CONTRIB_INDEX_COPY_INL_H_
diff --git a/src/operator/contrib/index_copy.cc b/src/operator/contrib/index_copy.cc
new file mode 100644
index 000000000000..07067a3f993b
--- /dev/null
+++ b/src/operator/contrib/index_copy.cc
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file index_copy.cc
+ * \brief CPU registration of the index_copy operator
+ */
+#include "./index_copy-inl.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_contrib_index_copy)
+.describe(R"code(Copies the elements of a `new_tensor` into the `old_tensor` by
+selecting the indices in the order given in `index`. The output will be a new tensor
+containing the rest of the elements of the old tensor and the copied elements of the new tensor.
+For example, if `index[i] == j`, then the `i`th row of `new_tensor` is copied to the
+`j`th row of output.
+
+The `index` must be a vector and it must have the same size as the `0`th dimension of
+`new_tensor`. Also, the `0`th dimension of `old_tensor` must `>=` the `0`th dimension of
+`new_tensor`, or an error will be raised.
+
+Examples::
+
+x = mx.nd.zeros((5,3))
+t = mx.nd.array([[1,2,3],[4,5,6],[7,8,9]])
+index = mx.nd.array([0,4,2])
+
+mx.nd.contrib.index_copy(x, index, t)
+
+[[1. 2. 3.]
+ [0. 0. 0.]
+ [7. 8. 9.]
+ [0. 0. 0.]
+ [4. 5. 6.]]
+<NDArray 5x3 @cpu(0)>
+
+)code" ADD_FILELINE)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr<nnvm::FInferShape>("FInferShape", IndexCopyShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_contrib_backward_index_copy"})
+.set_attr<FCompute>("FCompute", IndexCopyForward<cpu>)
+.add_argument("old_tensor", "NDArray-or-Symbol", "Old tensor")
+.add_argument("index_vector", "NDArray-or-Symbol", "Index vector")
+.add_argument("new_tensor", "NDArray-or-Symbol", "New tensor to be copied");
+
+NNVM_REGISTER_OP(_contrib_backward_index_copy)
+.set_num_inputs(4)
+.set_num_outputs(3)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute", IndexCopyBackward<cpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/contrib/index_copy.cu b/src/operator/contrib/index_copy.cu
new file mode 100644
index 000000000000..dc416114b04d
--- /dev/null
+++ b/src/operator/contrib/index_copy.cu
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file index_copy.cu
+ * \brief GPU registration of the index_copy operator
+ */
+#include "./index_copy-inl.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_contrib_index_copy)
+.set_attr<FCompute>("FCompute", IndexCopyForward<gpu>);
+
+NNVM_REGISTER_OP(_contrib_backward_index_copy)
+.set_attr<FCompute>("FCompute", IndexCopyBackward<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index a7f484e81b38..a1eccf761f0e 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4639,6 +4639,29 @@ def test_quantization_op():
     assert same(qa.asnumpy(), qa_real.asnumpy())
     assert same(a_.asnumpy(), a_real.asnumpy())
 
+@with_seed()
+def test_index_copy():
+    x = mx.nd.zeros((5,3))
+    t = mx.nd.array([[1,2,3],[4,5,6],[7,8,9]])
+    index = mx.nd.array([0,4,2])
+
+    x.attach_grad()
+    t.attach_grad()
+    index.attach_grad()
+
+    with mx.autograd.record():
+        out = mx.nd.contrib.index_copy(x, index, t)
+    out.backward()
+
+    tensor = mx.nd.array([[1,2,3],[0,0,0],[7,8,9],[0,0,0],[4,5,6]])
+    x_grad = mx.nd.array([[0,0,0],[1,1,1],[0,0,0],[1,1,1],[0,0,0]])
+    t_grad = mx.nd.array([[1,1,1],[1,1,1],[1,1,1]])
+    index_grad = mx.nd.array([0,0,0])
+
+    assert same(out.asnumpy(), tensor.asnumpy())
+    assert same(x.grad.asnumpy(), x_grad.asnumpy())
+    assert same(t.grad.asnumpy(), t_grad.asnumpy())
+    assert same(index.grad.asnumpy(), index_grad.asnumpy())
 
 @with_seed()
 def test_div_sqrt_dim():