diff --git a/src/operator/sequence_last-inl.h b/src/operator/sequence_last-inl.h
index f71b8cf8e37f..07c4709bafde 100644
--- a/src/operator/sequence_last-inl.h
+++ b/src/operator/sequence_last-inl.h
@@ -45,16 +45,46 @@ namespace op {
 namespace seq_last {
 enum SequenceLastOpInputs { kData, kSequenceLength };
 enum SequenceLastOpOutputs { kOut };
+enum SequenceLastOpResource { kTempSpace };
 }
 
 struct SequenceLastParam : public dmlc::Parameter<SequenceLastParam> {
   bool use_sequence_length;
+  int axis;
   DMLC_DECLARE_PARAMETER(SequenceLastParam) {
     DMLC_DECLARE_FIELD(use_sequence_length)
         .set_default(false)
         .describe(
-            "If set to true, this layer takes in an extra input parameter `sequence_length` "
+            "If set to true, this layer takes in an extra input parameter "
+            "`sequence_length` "
             "to specify variable length sequence");
+    DMLC_DECLARE_FIELD(axis).set_default(0).describe(
+        "The sequence axis. Only values of 0 and 1 are currently supported.");
+  }
+};
+
+template <int req>
+struct SequenceLastKernel {
+  template <typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in,
+                                  const DType *idx, int offset1, int offset2,
+                                  mshadow::Shape<2> oshape) {
+    const auto opos = mxnet_op::unravel(i, oshape);
+    const int seqpos = static_cast<int>(idx[opos[0]]) - 1;
+    const int ipos = seqpos * offset1 + opos[0] * offset2 + opos[1];
+    KERNEL_ASSIGN(out[i], req, in[ipos]);
+  }
+};
+
+struct SequenceLastGradKernel {
+  template <typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType *in_grad, const DType *out_grad,
+                                  const DType *idx, int offset1, int offset2,
+                                  mshadow::Shape<2> oshape) {
+    const auto opos = mxnet_op::unravel(i, oshape);
+    const int seqpos = static_cast<int>(idx[opos[0]]) - 1;
+    const int ipos = seqpos * offset1 + opos[0] * offset2 + opos[1];
+    in_grad[ipos] += out_grad[i];
   }
 };
 
@@ -63,6 +93,47 @@ class SequenceLastOp : public Operator {
  public:
  explicit SequenceLastOp(SequenceLastParam p) { this->param_ = p; }
 
+  void sequence_last(const mshadow::Tensor<xpu, 3, DType> &data,
+                     const mshadow::Tensor<xpu, 2, DType> &out,
+                     const mshadow::Tensor<xpu, 1, DType> &indices,
+                     const OpReqType req, mshadow::Stream<xpu> *const s) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+
+    int axis = param_.axis;
+    int out_size = out.size(0) * out.size(1);
+    int max_seq_len = data.size(axis);
+    int offset1 = axis ? out.size(1) : out_size;
+    int offset2 = axis ? (max_seq_len * out.size(1)) : out.size(1);
+
+    MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+      mxnet_op::Kernel<SequenceLastKernel<req_type>, xpu>::Launch(
+          s, out_size, out.dptr_, data.dptr_, indices.dptr_, offset1, offset2,
+          out.shape_);
+    });
+  }
+
+  void sequence_last_grad(const mshadow::Tensor<xpu, 3, DType> &in_grad,
+                          const mshadow::Tensor<xpu, 2, DType> &out_grad,
+                          const mshadow::Tensor<xpu, 1, DType> &indices,
+                          mshadow::Stream<xpu> *const s) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+
+    auto axis = param_.axis;
+    int batch = out_grad.size(0);
+    int rest = out_grad.size(1);
+    int out_size = batch * rest;
+
+    int max_seq_len = in_grad.size(axis);
+    int offset1 = axis ? rest : out_size;
+    int offset2 = axis ? (max_seq_len * rest) : rest;
+
+    mxnet_op::Kernel<SequenceLastGradKernel, xpu>::Launch(
+        s, out_size, in_grad.dptr_, out_grad.dptr_, indices.dptr_, offset1,
+        offset2, out_grad.shape_);
+  }
+
   virtual void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &out_data,
@@ -74,33 +145,32 @@ class SequenceLastOp : public Operator {
     CHECK_EQ(out_data.size(), 1U);
     Stream<xpu> *s = ctx.get_stream<xpu>();
 
+    // only support axis of 0 or 1 for now
+    auto axis = param_.axis;
+
     // Get any size input + output into required form
-    index_t n = in_data[seq_last::kData].size(1);
-    int max_seq_len = in_data[seq_last::kData].size(0);
-    int total_size = in_data[seq_last::kData].Size();
-    Shape<2> s2 = Shape2(n, static_cast<int>(total_size / n / max_seq_len));
-    Shape<3> s3 =
-        Shape3(max_seq_len, n, static_cast<int>(total_size / n / max_seq_len));
+    auto d0 = in_data[seq_last::kData].size(0);
+    auto d1 = in_data[seq_last::kData].size(1);
+    auto dsize = in_data[seq_last::kData].Size();
+
+    auto batch = (axis != 0) ? d0 : d1;
+    auto max_seq_len = in_data[seq_last::kData].size(axis);
+    auto rest_size = dsize / (d0 * d1);
+
     Tensor<xpu, 3, DType> data =
-        in_data[seq_last::kData].get_with_shape<xpu, 3, DType>(s3, s);
+        in_data[seq_last::kData].get_with_shape<xpu, 3, DType>(
+            Shape3(d0, d1, rest_size), s);
     Tensor<xpu, 2, DType> out =
-        out_data[seq_last::kOut].get_with_shape<xpu, 2, DType>(s2, s);
-
-    if (param_.use_sequence_length) {
-      std::vector<index_t> indices_vec(n, max_seq_len);
-      IndexTensorToVector(
-          in_data[seq_last::kSequenceLength].get<xpu, 1, DType>(s),
-          &indices_vec);
-      if (req[seq_last::kOut] == kWriteTo) out = 0.0f;
-      index_t seq_ind;
-      for (index_t i = 0; i < n; ++i) {
-        seq_ind = indices_vec[i] - 1;  // 1-indexing
-        out[i] += data[seq_ind][i];
-      }
-    } else {
-      Assign(out, req[seq_last::kOut],
-             F<mshadow_op::identity>(data[max_seq_len - 1]));
-    }
+        out_data[seq_last::kOut].get_with_shape<xpu, 2, DType>(
+            Shape2(batch, rest_size), s);
+    Tensor<xpu, 1, DType> indices =
+        param_.use_sequence_length
+            ? in_data[seq_last::kSequenceLength].get<xpu, 1, DType>(s)
+            : ctx.requested[seq_last::kTempSpace]
+                  .get_space_typed<xpu, 1, DType>(Shape1(batch), s);
+    if (!param_.use_sequence_length) indices = max_seq_len;
+
+    sequence_last(data, out, indices, req[seq_last::kOut], s);
   }
 
   virtual void Backward(const OpContext &ctx,
@@ -119,33 +189,32 @@ class SequenceLastOp : public Operator {
     if (req[seq_last::kData] == kNullOp) return;
     Stream<xpu> *s = ctx.get_stream<xpu>();
 
+    // only support axis of 0 or 1 for now
+    auto axis = param_.axis;
     // Get any size input + output into required form
-    index_t n = in_grad[seq_last::kData].size(1);
-    int max_seq_len = in_grad[seq_last::kData].size(0);
-    int total_size = in_grad[seq_last::kData].Size();
-    Shape<2> s2 = Shape2(n, static_cast<int>(total_size / n / max_seq_len));
-    Shape<3> s3 =
-        Shape3(max_seq_len, n, static_cast<int>(total_size / n / max_seq_len));
+    auto d0 = in_data[seq_last::kData].size(0);
+    auto d1 = in_data[seq_last::kData].size(1);
+    auto dsize = in_data[seq_last::kData].Size();
+
+    auto batch = (axis != 0) ? d0 : d1;
+    auto max_seq_len = in_data[seq_last::kData].size(axis);
+    auto rest_size = dsize / (d0 * d1);
 
     Tensor<xpu, 3, DType> data_grad =
-        in_grad[seq_last::kData].get_with_shape<xpu, 3, DType>(s3, s);
+        in_grad[seq_last::kData].get_with_shape<xpu, 3, DType>(
+            Shape3(d0, d1, rest_size), s);
     Tensor<xpu, 2, DType> output_grad =
-        out_grad[seq_last::kOut].get_with_shape<xpu, 2, DType>(s2, s);
+        out_grad[seq_last::kOut].get_with_shape<xpu, 2, DType>(
+            Shape2(batch, rest_size), s);
+    Tensor<xpu, 1, DType> indices =
+        param_.use_sequence_length
+            ? in_data[seq_last::kSequenceLength].get<xpu, 1, DType>(s)
+            : ctx.requested[seq_last::kTempSpace]
+                  .get_space_typed<xpu, 1, DType>(Shape1(batch), s);
 
-    // copy indices to vector
-    std::vector<index_t> indices_vec(n, max_seq_len);
-    if (param_.use_sequence_length)
-      IndexTensorToVector(
-          in_data[seq_last::kSequenceLength].get<xpu, 1, DType>(s),
-          &indices_vec);
-
-    index_t seq_ind;
     if (req[seq_last::kData] == kWriteTo) data_grad = 0.0f;
-    for (index_t i = 0; i < n; ++i) {
-      seq_ind = indices_vec[i] - 1;
-      data_grad[seq_ind][i] += output_grad[i];
-    }
+    sequence_last_grad(data_grad, output_grad, indices, s);
   }
 
  private:
@@ -183,18 +252,21 @@ class SequenceLastProp : public OperatorProperty {
     using namespace mshadow;
     CHECK_EQ(in_shape->size(), param_.use_sequence_length ? 2U : 1U)
         << "Input:[data, sequence_length]";
+    CHECK((param_.axis == 0) || (param_.axis == 1))
+        << "Current implementation expects axis to be 0 or 1.";
 
     const TShape &dshape = (*in_shape)[seq_last::kData];
     CHECK_GT(dshape.ndim(), 1U)
        << "The data array must be of rank 2 or greater.";
     // seq length vector is same as batch size
+    int sbatch = param_.axis ? dshape[0] : dshape[1];
     if (param_.use_sequence_length)
-      SHAPE_ASSIGN_CHECK(*in_shape, seq_last::kSequenceLength,
-                         Shape1(dshape[1]));
+      SHAPE_ASSIGN_CHECK(*in_shape, seq_last::kSequenceLength, Shape1(sbatch));
 
     // calculate output size
     TShape shape_o(dshape.ndim() - 1);
-    for (index_t i = 0; i < shape_o.ndim(); ++i) shape_o[i] = dshape[i + 1];
+    shape_o[0] = sbatch;
+    for (index_t i = 1; i < shape_o.ndim(); ++i) shape_o[i] = dshape[i + 1];
     const TShape &oshape = shape_o;
 
     out_shape->clear();
@@ -227,6 +299,16 @@ class SequenceLastProp : public OperatorProperty {
 
   std::string TypeString() const override { return "SequenceLast"; }
 
+  std::vector<ResourceRequest> ForwardResource(
+      const std::vector<TShape> &in_shape) const override {
+    return {ResourceRequest::kTempSpace};
+  }
+
+  std::vector<ResourceRequest> BackwardResource(
+      const std::vector<TShape> &in_shape) const override {
+    return {ResourceRequest::kTempSpace};
+  }
+
   std::vector<int> DeclareBackwardDependency(
       const std::vector<int> &out_grad, const std::vector<int> &in_data,
       const std::vector<int> &out_data) const override {
diff --git a/src/operator/sequence_mask-inl.h b/src/operator/sequence_mask-inl.h
index 7f53a0ba82d7..a34cea04965e 100644
--- a/src/operator/sequence_mask-inl.h
+++ b/src/operator/sequence_mask-inl.h
@@ -32,12 +32,11 @@
 #include
 #include
 #include
-#include
 #include
 #include
-#include "./operator_common.h"
+#include
 #include "./mshadow_op.h"
-#include "./nn/sequence_mask-inl.h"
+#include "./operator_common.h"
 
 namespace mxnet {
 namespace op {
@@ -45,19 +44,60 @@ namespace op {
 namespace seq_mask {
 enum SequenceMaskOpInputs { kData, kSequenceLength };
 enum SequenceMaskOpOutputs { kOut };
+enum SequenceMaskOpBackResource { kTempSpace };
 }
 
 struct SequenceMaskParam : public dmlc::Parameter<SequenceMaskParam> {
   bool use_sequence_length;
   float value;
+  int axis;
   DMLC_DECLARE_PARAMETER(SequenceMaskParam) {
     DMLC_DECLARE_FIELD(use_sequence_length)
         .set_default(false)
         .describe(
-            "If set to true, this layer takes in an extra input parameter `sequence_length` "
+            "If set to true, this layer takes in an extra input parameter "
+            "`sequence_length` "
             "to specify variable length sequence");
     DMLC_DECLARE_FIELD(value).set_default(0.).describe(
         "The value to be used as a mask.");
+    DMLC_DECLARE_FIELD(axis).set_default(0).describe(
+        "The sequence axis. Only values of 0 and 1 are currently supported.");
+  }
+};
+
+// (seqlen, batch, rest) case
+template <int req>
+struct SequenceMask0Kernel {
+  template <typename DType>
+  MSHADOW_XINLINE static void Map(int b, DType *in, const DType *idx,
+                                  index_t max_s_len, index_t batch_size,
+                                  index_t restsize, DType value) {
+    const index_t seqpos = static_cast<index_t>(idx[b]);
+#pragma unroll
+    for (index_t s = seqpos; s < max_s_len; ++s) {
+      index_t incr = (s * batch_size * restsize) + (b * restsize);
+#pragma unroll
+      for (index_t r = 0; r < restsize; ++r)
+        KERNEL_ASSIGN(in[incr + r], req, value);
+    }
+  }
+};
+
+// (batch, seqlen, rest) case
+template <int req>
+struct SequenceMask1Kernel {
+  template <typename DType>
+  MSHADOW_XINLINE static void Map(int b, DType *in, const DType *idx,
+                                  index_t max_s_len, index_t batch_size,
+                                  index_t restsize, DType value) {
+    const index_t seqpos = static_cast<index_t>(idx[b]);
+#pragma unroll
+    for (index_t s = seqpos; s < max_s_len; ++s) {
+      index_t incr = (b * max_s_len * restsize) + (s * restsize);
+#pragma unroll
+      for (index_t r = 0; r < restsize; ++r)
+        KERNEL_ASSIGN(in[incr + r], req, value);
+    }
   }
 };
 
@@ -66,6 +106,29 @@ class SequenceMaskOp : public Operator {
  public:
  explicit SequenceMaskOp(SequenceMaskParam p) { this->param_ = p; }
 
+  void sequence_mask(const mshadow::Tensor<xpu, 3, DType> &data,
+                     const mshadow::Tensor<xpu, 1, DType> &indices,
+                     const OpReqType req, mshadow::Stream<xpu> *const s,
+                     DType val) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+
+    index_t batch = indices.size(0);
+    index_t max_seq_len = data.size(param_.axis);
+    index_t restsize = data.size(2);
+
+    MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+      if (param_.axis == 1)
+        mxnet_op::Kernel<SequenceMask1Kernel<req_type>, xpu>::Launch(
+            s, batch, data.dptr_, indices.dptr_, max_seq_len, batch, restsize,
+            val);
+      else
+        mxnet_op::Kernel<SequenceMask0Kernel<req_type>, xpu>::Launch(
+            s, batch, data.dptr_, indices.dptr_, max_seq_len, batch, restsize,
+            val);
+    });
+  }
+
   virtual void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &out_data,
@@ -77,21 +140,23 @@ class SequenceMaskOp : public Operator {
     Stream<xpu> *s = ctx.get_stream<xpu>();
 
     // Get any size input + output into required form
-    int max_seq_len = in_data[seq_mask::kData].size(0);
-    int n = in_data[seq_mask::kData].size(1);
-    int total_size = in_data[seq_mask::kData].Size();
-    int rest_dim = static_cast<int>(total_size / n / max_seq_len);
+    auto d0 = in_data[seq_mask::kData].size(0);
+    auto d1 = in_data[seq_mask::kData].size(1);
+    auto dsize = in_data[seq_mask::kData].Size();
+    auto rest_size = dsize / (d0 * d1);
 
-    Shape<3> s3 = Shape3(max_seq_len, n, rest_dim);
+    Shape<3> s3 = Shape3(d0, d1, rest_size);
 
     Tensor<xpu, 3, DType> data =
         in_data[seq_mask::kData].get_with_shape<xpu, 3, DType>(s3, s);
     Tensor<xpu, 3, DType> out =
         out_data[seq_mask::kOut].get_with_shape<xpu, 3, DType>(s3, s);
+    // Actual implementation of masking
     Assign(out, req[seq_mask::kOut], F<mshadow_op::identity>(data));
     if (param_.use_sequence_length) {
       Tensor<xpu, 1, DType> indices =
           in_data[seq_mask::kSequenceLength].get<xpu, 1, DType>(s);
-      mxnet_op::SequenceMask(out, indices, static_cast<DType>(param_.value));
+      sequence_mask(out, indices, req[seq_mask::kOut], s,
+                    static_cast<DType>(param_.value));
     }
   }
 
@@ -109,25 +174,36 @@ class SequenceMaskOp : public Operator {
     Stream<xpu> *s = ctx.get_stream<xpu>();
 
     // Get any size input + output into required form
-    int max_seq_len = in_grad[seq_mask::kData].size(0);
-    int n = in_grad[seq_mask::kData].size(1);
-    int total_size = in_grad[seq_mask::kData].Size();
-    int rest_dim = static_cast<int>(total_size / n / max_seq_len);
-
-    Shape<3> s3 = Shape3(max_seq_len, n, rest_dim);
+    auto d0 = in_grad[seq_mask::kData].size(0);
+    auto d1 = in_grad[seq_mask::kData].size(1);
+    auto dsize = in_grad[seq_mask::kData].Size();
+    auto rest_size = dsize / (d0 * d1);
 
-    Tensor<xpu, 3, DType> data_grad =
+    Shape<3> s3 = Shape3(d0, d1, rest_size);
+    Tensor<xpu, 3, DType> data_g =
         in_grad[seq_mask::kData].get_with_shape<xpu, 3, DType>(s3, s);
-    Tensor<xpu, 3, DType> output_grad =
+    Tensor<xpu, 3, DType> out_g =
         out_grad[seq_mask::kOut].get_with_shape<xpu, 3, DType>(s3, s);
-    Assign(data_grad, req[seq_mask::kData],
-           F<mshadow_op::identity>(output_grad));
-
-    if (param_.use_sequence_length) {
+    // Actual implementation of masking
+    if (req[seq_mask::kData] == kNullOp) return;
+    if (!param_.use_sequence_length) {
+      Assign(data_g, req[seq_mask::kData], F<mshadow_op::identity>(out_g));
+    } else {
       Tensor<xpu, 1, DType> indices =
           in_data[seq_mask::kSequenceLength].get<xpu, 1, DType>(s);
-      mxnet_op::SequenceMask(data_grad, indices, DType(0));
+      if (req[seq_mask::kData] == kAddTo) {
+        Tensor<xpu, 3, DType> out_g_temp =
+            ctx.requested[seq_mask::kTempSpace].get_space_typed<xpu, 3, DType>(
+                s3, s);
+        out_g_temp = F<mshadow_op::identity>(out_g);
+        out_g = out_g_temp;
+        sequence_mask(out_g, indices, kWriteInplace, s, DType(0.));
+        Assign(data_g, kAddTo, F<mshadow_op::identity>(out_g));
+      } else {
+        Assign(data_g, req[seq_mask::kData], F<mshadow_op::identity>(out_g));
+        sequence_mask(data_g, indices, req[seq_mask::kData], s, DType(0.));
+      }
     }
   }
 
@@ -172,10 +248,13 @@ class SequenceMaskProp : public OperatorProperty {
     const TShape &dshape = (*in_shape)[seq_mask::kData];
     CHECK_GT(dshape.ndim(), 1U)
         << "The data array must be of rank 2 or greater.";
+    CHECK((param_.axis == 0) || (param_.axis == 1))
+        << "Current implementation expects axis to be 0 or 1.";
+
     // seq length vector is same as batch size
+    int sbatch = param_.axis ? dshape[0] : dshape[1];
     if (param_.use_sequence_length)
-      SHAPE_ASSIGN_CHECK(*in_shape, seq_mask::kSequenceLength,
-                         Shape1(dshape[1]));
+      SHAPE_ASSIGN_CHECK(*in_shape, seq_mask::kSequenceLength, Shape1(sbatch));
 
     const TShape &oshape = dshape;
     out_shape->clear();
@@ -222,6 +301,19 @@ class SequenceMaskProp : public OperatorProperty {
     return {ResourceRequest::kTempSpace};
   }
 
+  std::vector<std::pair<int, void *> > BackwardInplaceOption(
+      const std::vector<int> &out_grad, const std::vector<int> &in_data,
+      const std::vector<int> &out_data,
+      const std::vector<void *> &in_grad) const override {
+    return {{out_grad[seq_mask::kOut], in_grad[seq_mask::kData]}};
+  }
+
+  std::vector<std::pair<int, void *> > ForwardInplaceOption(
+      const std::vector<int> &in_data,
+      const std::vector<void *> &out_data) const override {
+    return {{in_data[seq_mask::kData], out_data[seq_mask::kOut]}};
+  }
+
   Operator *CreateOperator(Context ctx) const override {
     LOG(FATAL) << "Not Implemented.";
     return NULL;
diff --git a/src/operator/sequence_reverse-inl.h b/src/operator/sequence_reverse-inl.h
index 47154011bcbe..943ca6e933c9 100644
--- a/src/operator/sequence_reverse-inl.h
+++ b/src/operator/sequence_reverse-inl.h
@@ -51,6 +51,7 @@ enum SequenceReverseOpOutputs { kOut };
 
 struct SequenceReverseParam : public dmlc::Parameter<SequenceReverseParam> {
   bool use_sequence_length;
+  int axis;
   DMLC_DECLARE_PARAMETER(SequenceReverseParam) {
     DMLC_DECLARE_FIELD(use_sequence_length)
         .set_default(false)
@@ -58,20 +59,23 @@ struct SequenceReverseParam : public dmlc::Parameter<SequenceReverseParam> {
         "If set to true, this layer takes in an extra input parameter "
         "`sequence_length` "
         "to specify variable length sequence");
+    DMLC_DECLARE_FIELD(axis).set_default(0).describe(
+        "The sequence axis. Only 0 is currently supported.");
   }
 };
 
 struct ReverseKernel {
   template <typename DType>
-  MSHADOW_XINLINE static void Map(
-      const int i, DType *const out_data, const DType *const in_data,
-      const OpReqType req, const index_t max_seq_len, const index_t batch_size,
-      const index_t other_dim, const index_t numel, const DType *const indices
-  ) {
+  MSHADOW_XINLINE static void Map(const int i, DType *const out_data,
+                                  const DType *const in_data,
+                                  const OpReqType req,
+                                  const index_t max_seq_len,
+                                  const index_t batch_size,
+                                  const index_t other_dim, const index_t numel,
+                                  const DType *const indices) {
     for (index_t batch = 0; batch < batch_size; ++batch) {
-      const index_t num_seq = indices
-                              ? static_cast<index_t>(indices[batch])
-                              : max_seq_len;
+      const index_t num_seq =
+          indices ? static_cast<index_t>(indices[batch]) : max_seq_len;
       const index_t padded_periods = max_seq_len - num_seq;
       // padded part
       if (padded_periods > 0 && i < static_cast<int>(padded_periods)) {
@@ -130,10 +134,10 @@ class SequenceReverseOp : public Operator {
     Stream<xpu> *const s = ctx.get_stream<xpu>();
 
     // Get any size input + output into required form
-    int max_seq_len = in_data[seq_reverse::kData].size(0);
-    int n = in_data[seq_reverse::kData].size(1);
-    int total_size = in_data[seq_reverse::kData].Size();
-    int rest_dim = static_cast<int>(total_size / n / max_seq_len);
+    auto max_seq_len = in_data[seq_reverse::kData].size(0);
+    auto n = in_data[seq_reverse::kData].size(1);
+    auto total_size = in_data[seq_reverse::kData].Size();
+    auto rest_dim = static_cast<int>(total_size / n / max_seq_len);
 
     Shape<3> s3 = Shape3(max_seq_len, n, rest_dim);
     Tensor<xpu, 3, DType> data =
@@ -163,10 +167,10 @@ class SequenceReverseOp : public Operator {
     Stream<xpu> *s = ctx.get_stream<xpu>();
 
     // Get any size input + output into required form
-    int max_seq_len = in_grad[seq_reverse::kData].size(0);
-    int n = in_grad[seq_reverse::kData].size(1);
-    int total_size = in_grad[seq_reverse::kData].Size();
-    int rest_dim = static_cast<int>(total_size / n / max_seq_len);
+    auto max_seq_len = in_grad[seq_reverse::kData].size(0);
+    auto n = in_grad[seq_reverse::kData].size(1);
+    auto total_size = in_grad[seq_reverse::kData].Size();
+    auto rest_dim = static_cast<int>(total_size / n / max_seq_len);
 
     Shape<3> s3 = Shape3(max_seq_len, n, rest_dim);
 
@@ -180,7 +184,8 @@ class SequenceReverseOp : public Operator {
             ? in_data[seq_reverse::kSequenceLength].dptr<DType>()
             : nullptr;
 
-    sequence_reverse(output_grad, data_grad, req[seq_reverse::kData], indices, s);
+    sequence_reverse(output_grad, data_grad, req[seq_reverse::kData], indices,
+                     s);
   }
 
  private:
@@ -220,6 +225,7 @@ class SequenceReverseProp : public OperatorProperty {
     using namespace mshadow;
     CHECK_EQ(in_shape->size(), param_.use_sequence_length ? 2U : 1U)
         << "Input:[data, sequence_length]";
+    CHECK_EQ(param_.axis, 0) << "Current implementation expects axis to be 0.";
 
     const TShape &dshape = (*in_shape)[seq_reverse::kData];
     CHECK_GT(dshape.ndim(), 1U)
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 966a955ad746..d169a5455bb8 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -2312,39 +2312,91 @@ def test_l2_normalization():
             check_l2_normalization((nbatch, nchannel, height, width), mode)
 
 
-def sequence_mask_numpy(array, lengths, value):
+# Numpy Implementation of Sequence Ops
+def sequence_last_numpy(array, lengths, axis):
+    # create new array of dims [batch, seqlen, ...]
+    array2 = np.moveaxis(array, axis, 1)
+    dims = array2.shape
+    if lengths is None:
+        return array2[:, -1]
+    lengths = list(lengths)
+    return np.array([array2[i, int(lengths[i]) - 1] for i in range(dims[0])])
+
+
+def sequence_mask_numpy(array, lengths, axis, value):
+    if lengths is None:
+        return array
     arrayMask = array.copy()
-    shape = array.shape
-    batch = shape[1]
-    for i in range(batch):
-        arrayMask[int(lengths[i]):, i] = value
-    return arrayMask
-
-def check_sequence_mask(shape, xpu, mask_value):
+    # conform to [batch, seqlen, ...]
+    arrayMask = np.moveaxis(arrayMask, axis, 1)
+    shape = arrayMask.shape
+    lengths = list(lengths)
+    for i in range(shape[0]):
+        arrayMask[i, int(lengths[i]):] = value
+    return np.moveaxis(arrayMask, 1, axis)
+
+
+def sequence_reverse_numpy(array, lengths, axis):
+    rarray = array.copy()
+    # conform to [batch, seqlen, ...]
+    rarray = np.moveaxis(rarray, axis, 1)
+    shape = rarray.shape
+    if lengths is None:
+        lengths = [shape[1]] * shape[0]
+    lengths = list(lengths)
+    for i in range(shape[0]):
+        j = int(lengths[i])
+        rarray[i,:j] = rarray[i,:j][::-1]
+    return np.moveaxis(rarray, 1, axis)
+
+
+def check_sequence_func(ftype, mask_value=0, axis=0):
     # bind with label
+    xpu = default_context()
     X = mx.symbol.Variable('X')
     L = mx.symbol.Variable('L') # lengths
-    Y = mx.symbol.SequenceMask(data=X, use_sequence_length=True, sequence_length=L, value=mask_value)
-    x = mx.random.uniform(-1, 1, shape, ctx=mx.cpu()).copyto(xpu)
-    l = mx.nd.array(np.random.randint(1, shape[0] + 1, shape[1]), ctx=mx.cpu()).copyto(xpu)
-    # numpy result
-    np_out = sequence_mask_numpy(x.asnumpy(), l.asnumpy(), mask_value)
-    # mxnet result
-    exec1 = Y.bind(xpu, args = [x, l], grad_req={'X':'null', 'L':'null'})
-    exec1.forward()
-    out = exec1.outputs[0].asnumpy()
-    # compare numpy + mxnet
-    assert_almost_equal(out, np_out, rtol=1e-5)
-    # grad check
-    check_numeric_gradient(Y, [x.asnumpy(), l.asnumpy()], grad_nodes={'X':'write'},
-                           numeric_eps=1e-3, rtol=1e-2)
+    shapes = [(3, 4), (1, 1), (3, 4, 3, 1, 1)]
+    for seqlenQ in [True, False]:
+        for s in shapes:
+            x = mx.random.uniform(-1, 1, s, ctx=mx.cpu()).copyto(xpu)
+            batch = s[1] if (axis == 0) else s[0]
+            seqlen = s[axis]
+            l_np = np.random.randint(1, seqlen + 1, batch)
+            l = mx.nd.array(l_np, ctx=mx.cpu()).copyto(xpu)
+            if not seqlenQ:
+                l_np = None
+            args = {'data':X, 'use_sequence_length':seqlenQ, "axis":axis}
+            if seqlenQ:
+                args['sequence_length'] = L
+            if ftype == "last":
+                Y = mx.symbol.SequenceLast(**args)
+                np_out = sequence_last_numpy(x.asnumpy(), l_np, axis)
+            elif ftype == "mask":
+                args['value'] = mask_value
+                Y = mx.symbol.SequenceMask(**args)
+                np_out = sequence_mask_numpy(x.asnumpy(), l_np, axis, mask_value)
+            elif ftype == "reverse":
+                Y = mx.symbol.SequenceReverse(**args)
+                np_out = sequence_reverse_numpy(x.asnumpy(), l_np, axis)
+            fargs = [x, l] if seqlenQ else [x]
+            gargs = [x.asnumpy(), l_np] if seqlenQ else [x.asnumpy()]
+            check_symbolic_forward(Y, fargs, [np_out])
+            check_numeric_gradient(Y, gargs, grad_nodes={'X':'write'},
+                                   numeric_eps=1e-2, rtol=1e-2)
+            check_numeric_gradient(Y, gargs, grad_nodes={'X':'add'},
+                                   numeric_eps=1e-3, rtol=1e-2, atol=1E-4)
+            check_numeric_gradient(Y, gargs, grad_nodes={'X':'null'},
+                                   numeric_eps=1e-3, rtol=1e-2, atol=1E-4)
+
+
+def test_sequence_last():
+    check_sequence_func("last", axis=0)
+    check_sequence_func("last", axis=1)
+
 
 def test_sequence_mask():
-    shape1 = (4, 2, 2, 3)
-    shape2 = (1, 2, 2, 3, 1, 1)
-    check_sequence_mask(shape1, default_context(), 2.1)
-    check_sequence_mask(shape2, default_context(), 0.1)
-    check_sequence_mask((3, 4), default_context(), 0.14)
+    check_sequence_func("mask", axis = 0, mask_value=-2.3)
+    check_sequence_func("mask", axis = 1, mask_value=0.3)
 
 
 def check_sequence_reverse(xpu):
@@ -2411,7 +2463,9 @@ def test_wrapper(arr, xpu, sequence_length=None, use_sequence_length=False):
     assert_array_equal(test_wrapper(arr, xpu, sequence_length=[2, 3], use_sequence_length=True), arr3)
     assert_array_equal(test_wrapper(arr_4, xpu, sequence_length=seq_len_1, use_sequence_length=True), arr_5)
 
+
 def test_sequence_reverse():
+    check_sequence_func("reverse", axis=0)
     check_sequence_reverse(mx.cpu())
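
For reference, below is a minimal usage sketch of the new `axis` parameter, mirroring the symbol calls exercised by `check_sequence_func` above. The shapes, length values, and variable names are illustrative only and are not part of the patch; with `axis=1` the data layout is (batch, seqlen, ...) rather than the default time-major (seqlen, batch, ...).

    import mxnet as mx

    X = mx.symbol.Variable('X')
    L = mx.symbol.Variable('L')
    # axis=1: the sequence runs along dimension 1, i.e. data is (batch, seqlen, ...)
    Y = mx.symbol.SequenceLast(data=X, sequence_length=L,
                               use_sequence_length=True, axis=1)

    x = mx.random.uniform(-1, 1, (3, 4, 5))   # batch=3, seqlen=4, rest=5
    l = mx.nd.array([2, 4, 1])                # per-example lengths (1-based)
    exe = Y.bind(mx.cpu(), args=[x, l], grad_req={'X': 'null', 'L': 'null'})
    exe.forward()
    print(exe.outputs[0].shape)               # (3, 5): the element at step length-1 per example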