Skip to content

Commit

Permalink
modify broadcast_like_op.cpp and add test (#8720)
Browse files Browse the repository at this point in the history
* modify broadcast_like_op.cpp and add test

* modify broadcast_like_op.cpp

* Update broadcast_like_op.cpp

Co-authored-by: Yinggang Wang <wyg19970408@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed Jul 22, 2022
1 parent 737878e commit 7664464
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 4 deletions.
24 changes: 20 additions & 4 deletions oneflow/user/ops/broadcast_like_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,25 @@ Maybe<void> GetSbpSignatures(user_op::SbpContext* ctx) {
}

// Checks whether `axis_vec` (the broadcast axes) is consistent with the input
// shape and the target ("like") shape.
//
// Two cases:
//  1. like_shape has MORE axes than in_shape (rank-expanding broadcast):
//     compare the non-1 dimensions only. After reducing the broadcast axes of
//     like_shape, every non-1 dim of in_shape must match, in order, a prefix of
//     the non-1 dims of the reduced like shape (size-1 dims broadcast freely).
//  2. Same rank: reducing the broadcast axes of like_shape must reproduce
//     in_shape exactly.
//
// Returns true when the axes are legal for broadcast_like.
bool IsAxesLegal(const AxisVector& axis_vec, const Shape& like_shape, const Shape& in_shape) {
  Shape reduced_like_shape = CreateReducedShape(like_shape, axis_vec);
  if (like_shape.NumAxes() > in_shape.NumAxes()) {
    // Collect the dims that actually constrain broadcasting (dim != 1).
    std::vector<int64_t> in_shape_vec;
    in_shape_vec.reserve(in_shape.NumAxes());
    std::vector<int64_t> like_shape_vec;
    like_shape_vec.reserve(reduced_like_shape.NumAxes());
    for (const int64_t& dim : in_shape.dim_vec()) {
      if (dim != 1) { in_shape_vec.emplace_back(dim); }
    }
    for (const int64_t& dim : reduced_like_shape.dim_vec()) {
      if (dim != 1) { like_shape_vec.emplace_back(dim); }
    }
    // The input cannot carry more constraining dims than the target provides.
    if (in_shape_vec.size() > like_shape_vec.size()) {
      return false;
    } else {
      // Every constraining input dim must match the target's, in order.
      return std::equal(in_shape_vec.begin(), in_shape_vec.end(), like_shape_vec.begin());
    }
  }
  // Equal rank: the reduced like shape must match the input shape exactly.
  return reduced_like_shape.dim_vec() == in_shape.dim_vec();
}

Maybe<void> InferTensorDesc(user_op::InferContext* ctx) {
Expand All @@ -81,7 +95,9 @@ Maybe<void> InferTensorDesc(user_op::InferContext* ctx) {
Shape* out_shape = ctx->MutOutputShape("y", 0);
Stride* out_stride = ctx->MutOutputStride("y", 0);
const AxisVector axis_vec = {broadcast_axes.begin(), broadcast_axes.end()};
CHECK_OR_RETURN(IsAxesLegal(axis_vec, like_shape, in_shape));
CHECK_OR_RETURN(IsAxesLegal(axis_vec, like_shape, in_shape))
<< Error::RuntimeError() << "Invalid input parameter: like shape:" << like_shape.ToString()
<< ", in shape:" << in_shape.ToString() << ", axis_vec size:" << axis_vec.size();
*out_shape = like_shape;
*out_stride = Stride(like_shape);
return Maybe<void>::Ok();
Expand Down
68 changes: 68 additions & 0 deletions python/oneflow/test/modules/test_broadcast_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,54 @@ def _test_broadcast_like(test_case, device):
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))


def _test_broadcast_like_one(test_case, device):
    # Broadcasting a (1, 1) input against a (1, 2, 3) "like" tensor should
    # produce an all-ones tensor with the like tensor's shape.
    target_shape = (1, 2, 3)
    dev = flow.device(device)
    input = flow.tensor(
        np.ones(shape=(1, 1), dtype=np.float32),
        dtype=flow.float32,
        device=dev,
    )
    like_tensor = flow.tensor(
        np.ones(shape=target_shape, dtype=np.float32),
        dtype=flow.float32,
        device=dev,
    )
    of_out = flow.broadcast_like(input, like_tensor)
    expected = np.ones(shape=target_shape)
    test_case.assertTrue(np.allclose(of_out.numpy(), expected, 1e-05, 1e-05))


def _test_broadcast_like_different_dim(test_case, device):
    # Rank-expanding broadcast: a (3, 1) input against a (2, 3, 4) "like"
    # tensor should yield an all-ones (2, 3, 4) result.
    target_shape = (2, 3, 4)
    dev = flow.device(device)
    input = flow.tensor(
        np.ones(shape=(3, 1), dtype=np.float32),
        dtype=flow.float32,
        device=dev,
    )
    like_tensor = flow.tensor(
        np.ones(shape=target_shape, dtype=np.float32),
        dtype=flow.float32,
        device=dev,
    )
    of_out = flow.broadcast_like(input, like_tensor)
    expected = np.ones(shape=target_shape)
    test_case.assertTrue(np.allclose(of_out.numpy(), expected, 1e-05, 1e-05))


def _test_broadcast_like_different_dim_with_input_axisvec(test_case, device):
    # Rank-expanding broadcast with explicit broadcast_axes: a (1, 5, 6)
    # input broadcast over axes (3, 4) of a (1, 5, 6, 1, 6) "like" tensor.
    target_shape = (1, 5, 6, 1, 6)
    dev = flow.device(device)
    input = flow.tensor(
        np.ones(shape=(1, 5, 6), dtype=np.float32),
        dtype=flow.float32,
        device=dev,
    )
    like_tensor = flow.tensor(
        np.ones(shape=target_shape, dtype=np.float32),
        dtype=flow.float32,
        device=dev,
    )
    of_out = flow.broadcast_like(input, like_tensor, broadcast_axes=(3, 4))
    expected = np.ones(shape=target_shape)
    test_case.assertTrue(np.allclose(of_out.numpy(), expected, 1e-05, 1e-05))


def _test_broadcast_like_3dim(test_case, device):
input = flow.tensor(
np.ones(shape=(1, 3, 2), dtype=np.float32),
Expand Down Expand Up @@ -72,6 +120,22 @@ def _test_broadcast_like_4dim(test_case, device):
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))


def _test_broadcast_like_empty_axisvec(test_case, device):
    # No broadcast_axes given: a 1-element (1,) input against a (2, 3, 4)
    # "like" tensor should still broadcast to an all-ones (2, 3, 4) result.
    target_shape = (2, 3, 4)
    dev = flow.device(device)
    input = flow.tensor(
        np.ones(shape=(1), dtype=np.float32),
        dtype=flow.float32,
        device=dev,
    )
    like_tensor = flow.tensor(
        np.ones(shape=target_shape, dtype=np.float32),
        dtype=flow.float32,
        device=dev,
    )
    of_out = flow.broadcast_like(input, like_tensor)
    expected = np.ones(shape=target_shape)
    test_case.assertTrue(np.allclose(of_out.numpy(), expected, 1e-05, 1e-05))


def _test_broadcast_like_backward(test_case, device):
input = flow.tensor(
np.ones(shape=(3, 1, 1), dtype=np.float32),
Expand All @@ -98,8 +162,12 @@ def test_broadcast_like(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_broadcast_like,
_test_broadcast_like_one,
_test_broadcast_like_different_dim,
_test_broadcast_like_different_dim_with_input_axisvec,
_test_broadcast_like_3dim,
_test_broadcast_like_4dim,
_test_broadcast_like_empty_axisvec,
_test_broadcast_like_backward,
]
arg_dict["device"] = ["cpu", "cuda"]
Expand Down

0 comments on commit 7664464

Please sign in to comment.