Skip to content

Commit

Permalink
Removed nlhs argument left from Matlab.
Browse files Browse the repository at this point in the history
  • Loading branch information
matinraayai committed May 8, 2022
1 parent 61cc994 commit 3b1c320
Showing 1 changed file with 50 additions and 43 deletions.
93 changes: 50 additions & 43 deletions src/pymcx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,17 @@ namespace py = pybind11;
#define RAND_WORD_LEN 4 /**< number of Words per RNG state */
#endif

float *detps = nullptr; //! buffer to receive data from cfg.detphotons field
int dimdetps[2] = {0, 0}; //! dimensions of the cfg.detphotons array
int seedbyte = 0;
float *det_ps = nullptr; //! buffer to receive data from cfg.detphotons field
int dim_det_ps[2] = {0, 0}; //! dimensions of the cfg.detphotons array
int seed_byte = 0;

#define GET_SCALAR_FIELD(src, dst, prop, type) if (src.contains(#prop)) {dst.prop = py::reinterpret_borrow<type>(src[#prop]); std::cout << #prop << ": " << (float) dst.prop << std::endl;}
/**
* Macro to find and extract a scalar property from a source Python dictionary configuration and assign it in a destination
* MCX Config. The scalar is cast to the python type before assignment.
*/
#define GET_SCALAR_FIELD(src_pydict, dst_mcx_config, property, py_type) if ((src_pydict).contains(#property))\
{(dst_mcx_config).property = py::reinterpret_borrow<py_type>((src_pydict)[#property]);\
std::cout << #property << ": " << (float) (dst_mcx_config).property << std::endl;}

#define GET_VEC3_FIELD(src, dst, prop, type) if (src.contains(#prop)) {auto list = py::reinterpret_borrow<py::list>(src[#prop]);\
dst.prop = {list[0].cast<type>(), list[1].cast<type>(), list[2].cast<type>()};\
Expand All @@ -85,7 +91,7 @@ void parseVolume(const py::dict &userCfg, Config &mcxConfig) {
auto volumeHandle = userCfg["vol"];
// Free the volume
if (mcxConfig.vol) free(mcxConfig.vol);
unsigned int dim_xyz;
unsigned int dim_xyz = 0;
// Data type-specific logic
if (py::array_t<int8_t, py::array::c_style>::check_(volumeHandle)) {
auto fStyleVolume = py::array_t<int8_t, py::array::f_style>::ensure(volumeHandle);
Expand Down Expand Up @@ -123,10 +129,9 @@ void parseVolume(const py::dict &userCfg, Config &mcxConfig) {
dim_xyz = mcxConfig.dim.x * mcxConfig.dim.y * mcxConfig.dim.z;
mcxConfig.vol = static_cast<unsigned int *>(malloc(dim_xyz * sizeof(unsigned int)));
if (i == 1) {
unsigned int dim_xyz = mcxConfig.dim.x * mcxConfig.dim.y * mcxConfig.dim.z;
if (buffer.shape.at(0) == 3) {
mcxConfig.mediabyte = MEDIA_2LABEL_MIX;
unsigned short *val = (unsigned short *) buffer.ptr;
auto *val = (unsigned short *) buffer.ptr;
union {
unsigned short h[2];
unsigned char c[4];
Expand Down Expand Up @@ -345,10 +350,10 @@ void parseVolume(const py::dict &userCfg, Config &mcxConfig) {
mcxConfig.vol[i] = static_cast<double *>(buffer.ptr)[i];
}
else
throw py::value_error("Invalid data type for vol array.");
throw py::type_error("Invalid data type for vol array.");
}

void parseConfig(const py::dict &userCfg, Config &mcxConfig) {
void parse_config(const py::dict &userCfg, Config &mcxConfig) {
mcx_initcfg(&mcxConfig);

mcxConfig.flog = stderr;
Expand Down Expand Up @@ -385,7 +390,6 @@ void parseConfig(const py::dict &userCfg, Config &mcxConfig) {
GET_SCALAR_FIELD(userCfg, mcxConfig, gscatter, py::int_);
GET_SCALAR_FIELD(userCfg, mcxConfig, srcnum, py::int_);
GET_SCALAR_FIELD(userCfg, mcxConfig, omega, py::float_);
GET_SCALAR_FIELD(userCfg, mcxConfig, issave2pt, py::int_);
GET_SCALAR_FIELD(userCfg, mcxConfig, lambda, py::float_);
GET_VEC3_FIELD(userCfg, mcxConfig, srcpos, float);
GET_VEC34_FIELD(userCfg, mcxConfig, srcdir, float);
Expand Down Expand Up @@ -553,7 +557,7 @@ void parseConfig(const py::dict &userCfg, Config &mcxConfig) {
else {
auto fStyleArray = py::array_t<uint8_t, py::array::f_style | py::array::forcecast>::ensure(seedValue);
auto bufferInfo = fStyleArray.request();
seedbyte = bufferInfo.shape.at(0);
seed_byte = bufferInfo.shape.at(0);
if (bufferInfo.shape.at(0) != sizeof(float) * RAND_WORD_LEN)
throw py::value_error("the row number of cfg.seed does not match RNG seed byte-length");
mcxConfig.replay.seed = malloc(bufferInfo.size);
Expand Down Expand Up @@ -592,13 +596,18 @@ void parseConfig(const py::dict &userCfg, Config &mcxConfig) {
for (int i = 0; i < bufferInfo.size; i++)
mcxConfig.workload[i] = static_cast<float *>(bufferInfo.ptr)[i];
}
// Output arguments parsing
GET_SCALAR_FIELD(userCfg, mcxConfig, issave2pt, py::int_);
GET_SCALAR_FIELD(userCfg, mcxConfig, issavedet, py::int_);
GET_SCALAR_FIELD(userCfg, mcxConfig, issaveseed, py::int_);

// Flush the std::cout and std::cerr to avoid
std::cout.flush();
std::cerr.flush();
}


py::dict pyMcxInterface(const py::dict &userCfg) {
py::dict pyMcxInterface(const py::dict &user_cfg) {
unsigned int partial_data, hostdetreclen;
Config mcx_config; /* mcx_config: structure to store all simulation parameters */
GPUInfo *gpu_info = nullptr; /** gpuInfo: structure to store GPU information */
Expand All @@ -611,35 +620,35 @@ py::dict pyMcxInterface(const py::dict &userCfg) {
/*
* To start an MCX simulation, we first create a simulation configuration and set all elements to its default settings.
*/
parseConfig(userCfg, mcx_config);
parse_config(user_cfg, mcx_config);

/** The next step, we identify gpu number and query all GPU info */
if (!(active_dev = mcx_list_gpu(&mcx_config, &gpu_info))) {
mcx_error(-1, "No GPU device found\n", __FILE__, __LINE__);
}
detps = nullptr;
det_ps = nullptr;

mcx_flush(&mcx_config);

/*
* Number of output arguments has to be explicitly specified, unlike Matlab.
*/
if (!userCfg.contains("nlhs"))
throw py::value_error("Number of output arguments must be specified.");
if (!py::int_::check_(userCfg["nlhs"]))
throw py::value_error("Number of output arguments must be int.");
int nlhs = py::int_(userCfg["nlhs"]);
if (nlhs < 0)
throw py::value_error("Number of output arguments must be greater than zero.");
// if (!user_cfg.contains("nlhs"))
// throw py::value_error("Number of output arguments must be specified.");
// if (!py::int_::check_(user_cfg["nlhs"]))
// throw py::value_error("Number of output arguments must be int.");
// int nlhs = py::int_(user_cfg["nlhs"]);
// if (nlhs < 0)
// throw py::value_error("Number of output arguments must be greater than zero.");

/** Overwrite the output flags using the number of output present */
if (nlhs < 1)
mcx_config.issave2pt =
0; /** issave2pt default is 1, but allow users to manually disable, auto disable only if there is no output */
mcx_config.issavedet = nlhs >= 2 ? 1 : 0; /** save detected photon data to the 2nd output if present */
mcx_config.issaveseed = nlhs >= 4 ? 1 : 0; /** save detected photon seeds to the 4th output if present */
// if (nlhs < 1)
// mcx_config.issave2pt =
// 0; /** issave2pt default is 1, but allow users to manually disable, auto disable only if there is no output */
// mcx_config.issavedet = nlhs >= 2 ? 1 : 0; /** save detected photon data to the 2nd output if present */
// mcx_config.issaveseed = nlhs >= 4 ? 1 : 0; /** save detected photon seeds to the 4th output if present */
/** Validate all input fields, and warn incompatible inputs */
validate_config(&mcx_config, detps, dimdetps, seedbyte, [](const char *msg) { throw py::value_error(msg); });
validate_config(&mcx_config, det_ps, dim_det_ps, seed_byte, [](const char *msg) { throw py::value_error(msg); });

partial_data =
(mcx_config.medianum - 1) * (SAVE_NSCAT(mcx_config.savedetflag) + SAVE_PPATH(mcx_config.savedetflag) +
Expand All @@ -653,7 +662,7 @@ py::dict pyMcxInterface(const py::dict &userCfg) {
}

/** Initialize all buffers necessary to store the output variables */
if (nlhs >= 1) {
if (mcx_config.issave2pt == 1) {
int fieldlen =
static_cast<int>(mcx_config.dim.x) * static_cast<int>(mcx_config.dim.y) * static_cast<int>(mcx_config.dim.z) *
(int) ((mcx_config.tend - mcx_config.tstart) / mcx_config.tstep + 0.5) * mcx_config.srcnum;
Expand All @@ -663,13 +672,13 @@ py::dict pyMcxInterface(const py::dict &userCfg) {
fieldlen *= 2;
mcx_config.exportfield = (float *) calloc(fieldlen, sizeof(float));
}
if (nlhs >= 2) {
if (mcx_config.issavedet == 1) {
mcx_config.exportdetected = (float *) malloc(hostdetreclen * mcx_config.maxdetphoton * sizeof(float));
}
if (nlhs >= 4) {
if (mcx_config.issaveseed == 1) {
mcx_config.seeddata = malloc(mcx_config.maxdetphoton * sizeof(float) * RAND_WORD_LEN);
}
if (nlhs >= 5) {
if (mcx_config.debuglevel != 0) {
mcx_config.exportdebugdata = (float *) malloc(mcx_config.maxjumpdebug * sizeof(float) * MCX_DEBUG_REC_LEN);
mcx_config.debuglevel |= MCX_DEBUG_MOVE;
}
Expand Down Expand Up @@ -706,7 +715,7 @@ py::dict pyMcxInterface(const py::dict &userCfg) {
fielddim[5] = 1;

/** if 5th output presents, output the photon trajectory data */
if (nlhs >= 5) {
if (mcx_config.debuglevel != 0) {
fielddim[0] = MCX_DEBUG_REC_LEN;
fielddim[1] = mcx_config.debugdatalen; // his.savedphoton is for one repetition, should correct
fielddim[2] = 0;
Expand All @@ -720,7 +729,7 @@ py::dict pyMcxInterface(const py::dict &userCfg) {
output["photontraj"] = photonTrajData;
}
/** if the 4th output presents, output the detected photon seeds */
if (nlhs >= 4) {
if (mcx_config.issaveseed == 1) {
fielddim[0] = (mcx_config.issaveseed > 0) * RAND_WORD_LEN * sizeof(float);
fielddim[1] = mcx_config.detectedcount; // his.savedphoton is for one repetition, should correct
fielddim[2] = 0;
Expand All @@ -732,7 +741,7 @@ py::dict pyMcxInterface(const py::dict &userCfg) {
output["detectedseeds"] = detectedSeeds;
}
/** if the 3rd output presents, output the detector-masked medium volume, similar to the --dumpmask flag */
if (nlhs >= 3) {
if (user_cfg.contains("dumpmask") && py::reinterpret_borrow<py::bool_>(user_cfg["dumpmask"]).cast<bool>()) {
fielddim[0] = mcx_config.dim.x;
fielddim[1] = mcx_config.dim.y;
fielddim[2] = mcx_config.dim.z;
Expand All @@ -745,7 +754,7 @@ py::dict pyMcxInterface(const py::dict &userCfg) {
}
}
/** if the 2nd output presents, output the detected photon partialpath data */
if (nlhs >= 2) {
if (mcx_config.issavedet == 1) {
fielddim[0] = hostdetreclen;
fielddim[1] = mcx_config.detectedcount;
fielddim[2] = 0;
Expand All @@ -760,7 +769,7 @@ py::dict pyMcxInterface(const py::dict &userCfg) {
mcx_config.exportdetected = NULL;
}
/** if the 1st output presents, output the fluence/energy-deposit volume data */
if (nlhs >= 1) {
if (mcx_config.issave2pt) {
int fieldlen;
fielddim[0] = mcx_config.srcnum * mcx_config.dim.x;
fielddim[1] = mcx_config.dim.y;
Expand Down Expand Up @@ -791,11 +800,9 @@ py::dict pyMcxInterface(const py::dict &userCfg) {
}
output["dref"] = drefArray;
}
if (mcx_config.issave2pt) {
auto data = py::array_t<float, py::array::f_style>(arrayDims);
memcpy(data.mutable_data(), mcx_config.exportfield, fieldlen * sizeof(float));
output["data"] = data;
}
auto data = py::array_t<float, py::array::f_style>(arrayDims);
memcpy(data.mutable_data(), mcx_config.exportfield, fieldlen * sizeof(float));
output["data"] = data;
free(mcx_config.exportfield);
mcx_config.exportfield = nullptr;
output["runtime"] = mcx_config.runtime;
Expand Down Expand Up @@ -830,8 +837,8 @@ py::dict pyMcxInterface(const py::dict &userCfg) {
}

/** Clear up simulation data structures by calling the destructors */
if (detps)
free(detps);
if (det_ps)
free(det_ps);
mcx_cleargpuinfo(&gpu_info);
mcx_clearcfg(&mcx_config);
// return a pointer to the MCX output, wrapped in a std::vector
Expand Down

0 comments on commit 3b1c320

Please sign in to comment.