Skip to content

Commit

Permalink
Tpetra: Fix complex build error with #1665
Browse files Browse the repository at this point in the history
@trilinos/tpetra If Scalar could be std::complex<T>, it needs to turn
into impl_scalar_type (via reinterpret_cast) before it enters the
Kokkos world.
  • Loading branch information
Mark Hoemmen committed Sep 8, 2017
1 parent d3b6fa2 commit dbd33f5
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 22 deletions.
3 changes: 2 additions & 1 deletion packages/tpetra/core/src/Tpetra_CrsMatrix_def.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8170,7 +8170,8 @@ namespace Tpetra {
numImportPacketsPerLID, constantNumPackets,
Distor, INSERT, NumSameIDs, PermuteToLIDs,
PermuteFromLIDs, N, mynnz, MyPID,
CSR_rowptr (), CSR_colind_GID (), CSR_vals (),
CSR_rowptr (), CSR_colind_GID (),
Teuchos::av_reinterpret_cast<impl_scalar_type> (CSR_vals ()),
SourcePids (), TargetPids);

/**************************************************************/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ unpackAndCombineIntoCrsArrays (
const int MyTargetPID,
const Teuchos::ArrayView<size_t>& CRS_rowptr,
const Teuchos::ArrayView<GlobalOrdinal>& CRS_colind,
const Teuchos::ArrayView<Scalar>& CRS_vals,
const Teuchos::ArrayView<typename CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node, false>::impl_scalar_type>& CRS_vals,
const Teuchos::ArrayView<const int>& SourcePids,
Teuchos::Array<int>& TargetPids);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1262,7 +1262,7 @@ unpackAndCombineIntoCrsArrays (
const int MyTargetPID,
const Teuchos::ArrayView<size_t>& CRS_rowptr,
const Teuchos::ArrayView<GlobalOrdinal>& CRS_colind,
const Teuchos::ArrayView<Scalar>& CRS_vals,
const Teuchos::ArrayView<typename CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node, false>::impl_scalar_type>& CRS_vals,
const Teuchos::ArrayView<const int>& SourcePids,
Teuchos::Array<int>& TargetPids)
{
Expand Down Expand Up @@ -1343,10 +1343,30 @@ unpackAndCombineIntoCrsArrays (
create_mirror_view_from_raw_host_array(outputDevice, CRS_colind.getRawPtr(),
CRS_colind.size(), true, "crs_colidx");

#ifdef HAVE_TPETRA_INST_COMPLEX_DOUBLE
static_assert (! std::is_same<
typename std::remove_const<
typename std::decay<
decltype (CRS_vals)
>::type::value_type
>::type,
std::complex<double> >::value,
"CRS_vals::value_type is std::complex<double>; this should never happen"
", since std::complex does not work in Kokkos::View objects.");
#endif // HAVE_TPETRA_INST_COMPLEX_DOUBLE

auto crs_vals_d =
create_mirror_view_from_raw_host_array(outputDevice, CRS_vals.getRawPtr(),
CRS_vals.size(), true, "crs_vals");

#ifdef HAVE_TPETRA_INST_COMPLEX_DOUBLE
static_assert (! std::is_same<
typename decltype (crs_vals_d)::non_const_value_type,
std::complex<double> >::value,
"crs_vals_d::non_const_value_type is std::complex<double>; this should "
"never happen, since std::complex does not work in Kokkos::View objects.");
#endif // HAVE_TPETRA_INST_COMPLEX_DOUBLE

auto src_pids_d =
create_mirror_view_from_raw_host_array(outputDevice, SourcePids.getRawPtr(),
SourcePids.size(), true, "src_pids");
Expand Down Expand Up @@ -1383,6 +1403,14 @@ unpackAndCombineIntoCrsArrays (
outArg(num_bytes_per_value));
}

#ifdef HAVE_TPETRA_INST_COMPLEX_DOUBLE
static_assert (! std::is_same<
typename decltype (crs_vals_d)::non_const_value_type,
std::complex<double> >::value,
"crs_vals_d::non_const_value_type is std::complex<double>; this should "
"never happen, since std::complex does not work in Kokkos::View objects.");
#endif // HAVE_TPETRA_INST_COMPLEX_DOUBLE

unpackAndCombineIntoCrsArraysImpl(local_matrix, local_col_map,
import_lids_d, imports_d, num_packets_per_lid_d, permute_to_lids_d,
permute_from_lids_d, crs_rowptr_d, crs_colind_d, crs_vals_d, src_pids_d,
Expand Down Expand Up @@ -1439,7 +1467,7 @@ unpackAndCombineIntoCrsArrays (
const int, \
const Teuchos::ArrayView<size_t>&, \
const Teuchos::ArrayView<GO>&, \
const Teuchos::ArrayView<ST>&, \
const Teuchos::ArrayView<CrsMatrix<ST, LO, GO, NT, false>::impl_scalar_type>&, \
const Teuchos::ArrayView<const int>&, \
Teuchos::Array<int>&); \
template size_t \
Expand Down
38 changes: 20 additions & 18 deletions packages/tpetra/core/test/ImportExport2/ImportExport2_UnitTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1846,6 +1846,7 @@ TEUCHOS_UNIT_TEST_TEMPLATE_3_DECL( Import_Util, UnpackAndCombineWithOwningPIDs,
typedef typename Tpetra::CrsMatrix<Scalar, LO, GO>::packet_type PacketType;
typedef typename MapType::device_type device_type;
typedef Tpetra::global_size_t GST;
typedef typename CrsMatrixType::impl_scalar_type IST;

RCP<const Comm<int> > Comm = getDefaultComm();
RCP<CrsMatrixType> A,B;
Expand Down Expand Up @@ -1941,24 +1942,25 @@ TEUCHOS_UNIT_TEST_TEMPLATE_3_DECL( Import_Util, UnpackAndCombineWithOwningPIDs,
Teuchos::Array<int> TargetPids;

using Tpetra::Details::unpackAndCombineIntoCrsArrays;
unpackAndCombineIntoCrsArrays<Scalar, LO, GO, Node> (*A,
Importer->getRemoteLIDs (),
imports (),
numImportPackets (),
constantNumPackets,
distor,
Tpetra::INSERT,
Importer->getNumSameIDs (),
Importer->getPermuteToLIDs (),
Importer->getPermuteFromLIDs (),
MapTarget->getNodeNumElements (),
nnz2,
MyPID,
rowptr (),
colind (),
vals (),
SourcePids (),
TargetPids);
unpackAndCombineIntoCrsArrays<Scalar, LO, GO, Node> (
*A,
Importer->getRemoteLIDs (),
imports (),
numImportPackets (),
constantNumPackets,
distor,
Tpetra::INSERT,
Importer->getNumSameIDs (),
Importer->getPermuteToLIDs (),
Importer->getPermuteFromLIDs (),
MapTarget->getNodeNumElements (),
nnz2,
MyPID,
rowptr (),
colind (),
Teuchos::av_reinterpret_cast<IST> (vals ()),
SourcePids (),
TargetPids);
// Do the comparison
Teuchos::ArrayRCP<const size_t> Browptr;
Teuchos::ArrayRCP<const LO> Bcolind;
Expand Down

0 comments on commit dbd33f5

Please sign in to comment.