From d9d71da6442fd874e313245a9a9815eacb1926c8 Mon Sep 17 00:00:00 2001 From: Daniel Baston Date: Tue, 6 Dec 2022 22:10:51 -0500 Subject: [PATCH] Add functions to interrupt processing in a specific thread/context - Add Interrupt::requestForCurrentThread (C API: GEOS_interruptThread) to request interruption of the current thread only. - Add Interrupt::registerThreadCallback to register an interruption callback for the current thread only. - Add GEOSContext_setInterruptCallback_r to associate an interruption callback with a context handle. The callback will be registered for the current thread each time an _r function is called with the specified handle. --- capi/geos_c.cpp | 6 ++ capi/geos_c.h.in | 47 ++++++++- capi/geos_ts_c.cpp | 94 +++++++++++++---- include/geos/util/Interrupt.h | 30 +++++- src/util/Interrupt.cpp | 33 +++++- tests/unit/CMakeLists.txt | 1 + tests/unit/capi/GEOSInterruptTest.cpp | 62 +++++++++++- tests/unit/util/InterruptTest.cpp | 140 ++++++++++++++++++++++++++ 8 files changed, 381 insertions(+), 32 deletions(-) create mode 100644 tests/unit/util/InterruptTest.cpp diff --git a/capi/geos_c.cpp b/capi/geos_c.cpp index 6ef22fcb31..97841e6757 100644 --- a/capi/geos_c.cpp +++ b/capi/geos_c.cpp @@ -122,6 +122,12 @@ extern "C" { geos::util::Interrupt::request(); } + void + GEOS_interruptThread() + { + geos::util::Interrupt::requestForCurrentThread(); + } + void GEOS_interruptCancel() { diff --git a/capi/geos_c.h.in b/capi/geos_c.h.in index 8748af15fc..e0b570751c 100644 --- a/capi/geos_c.h.in +++ b/capi/geos_c.h.in @@ -296,8 +296,8 @@ typedef int (*GEOSTransformXYCallback)( /* ========== Interruption ========== */ /** -* Callback function for use in interruption. The callback will be invoked _before_ checking for -* interruption, so can be used to request it. +* Callback function for use in interruption. The callback will be invoked at each +* possible interruption point and can be used to request interruption. * * \see GEOS_interruptRegisterCallback * \see GEOS_interruptRequest @@ -306,19 +306,56 @@ typedef int (*GEOSTransformXYCallback)( typedef void (GEOSInterruptCallback)(void); /** -* Register a function to be called when processing is interrupted. +* Callback function for use in interruption. The callback will be invoked at each +* possible interruption point and can be used to request interruption. +* +* \see GEOS_interruptRegisterThreadCallback +* \see GEOS_interruptThread +*/ +typedef void (GEOSInterruptThreadCallback)(void*); + +/** +* Register a function to be called when a possible interruption point is reached +* on any thread. The function may be used to request interruption. +* * \param cb Callback function to invoke -* \return the previously configured callback +* \return the previously registered callback, or NULL * \see GEOSInterruptCallback +* \see GEOSContext_setInterruptCallback_r */ extern GEOSInterruptCallback GEOS_DLL *GEOS_interruptRegisterCallback( GEOSInterruptCallback* cb); /** -* Request safe interruption of operations +* Register a function to be called when a possible interruption point is reached +* in code executed in the specified context. The function can interrupt the +* thread if desired by calling GEOS_interruptThread. +* +* \param extHandle the context returned by \ref GEOS_init_r. +* \param cb Callback function to invoke +* \param userData optional data to be pe provided as argument to callback +* \return the previously registered callback, or NULL +* \see GEOSInterruptThreadCallback +*/ +extern GEOSInterruptThreadCallback GEOS_DLL *GEOSContext_setInterruptCallback_r( + GEOSContextHandle_t extHandle, + GEOSInterruptThreadCallback* cb, + void* userData); + +/** +* Request safe interruption of operations. The next thread to check for an +* interrupt will be interrupted. To request interruption of a specific thread, +* instead call GEOS_interruptThread() from a callback executed by that thread. */ extern void GEOS_DLL GEOS_interruptRequest(void); +/** +* Request safe interruption of operations in the current thread. This function +* should be called from a callback registered by GEOS_interruptRegisterThreadCallback() +* or GEOS_interruptRegisterCallback(). +*/ +extern void GEOS_DLL GEOS_interruptThread(void); + /** * Cancel a pending interruption request */ diff --git a/capi/geos_ts_c.cpp b/capi/geos_ts_c.cpp index 3e3ce7780a..0dadaa6bb6 100644 --- a/capi/geos_ts_c.cpp +++ b/capi/geos_ts_c.cpp @@ -204,6 +204,8 @@ typedef struct GEOSContextHandle_HS { void* noticeData; GEOSMessageHandler errorMessageOld; GEOSMessageHandler_r errorMessageNew; + GEOSInterruptThreadCallback* interrupt_cb; + void* interrupt_cb_data; void* errorData; uint8_t WKBOutputDims; int WKBByteOrder; @@ -218,6 +220,8 @@ typedef struct GEOSContextHandle_HS { noticeData(nullptr), errorMessageOld(nullptr), errorMessageNew(nullptr), + interrupt_cb(nullptr), + interrupt_cb_data(nullptr), errorData(nullptr), point2d(nullptr) { @@ -275,6 +279,15 @@ typedef struct GEOSContextHandle_HS { return f; } + GEOSInterruptThreadCallback* + setInterruptHandler(GEOSInterruptThreadCallback* cb, void* userData) + { + auto old = interrupt_cb; + interrupt_cb = cb; + interrupt_cb_data = userData; + return old; + } + void NOTICE_MESSAGE(const char *fmt, ...) { @@ -375,12 +388,37 @@ gstrdup(std::string const& str) return gstrdup_s(str.c_str(), str.size()); } +struct InterruptManager { + InterruptManager(GEOSContextHandle_t handle) : + cb(handle->interrupt_cb), + cb_data(handle->interrupt_cb_data) { + if (cb) { + geos::util::Interrupt::registerThreadCallback(cb, cb_data); + } + } + + ~InterruptManager() { + if (cb != nullptr) { + geos::util::Interrupt::registerThreadCallback(nullptr, nullptr); + } + } + + GEOSInterruptThreadCallback* cb; + void* cb_data; +}; + +struct NotInterruptible { + NotInterruptible(GEOSContextHandle_t handle) { + (void) handle; + } +}; + } // namespace anonymous // Execute a lambda, using the given context handle to process errors. // Return errval on error. // Errval should be of the type returned by f, unless f returns a bool in which case we promote to char. -template +template inline auto execute( GEOSContextHandle_t extHandle, typename std::conditional()()),bool>::value, @@ -396,6 +434,8 @@ inline auto execute( return errval; } + InterruptManagerType ic(handle); + try { return f(); } catch (const std::exception& e) { @@ -409,7 +449,7 @@ inline auto execute( // Execute a lambda, using the given context handle to process errors. // Return nullptr on error. -template()())>::value, std::nullptr_t>::type = nullptr> +template()())>::value, std::nullptr_t>::type = nullptr> inline auto execute(GEOSContextHandle_t extHandle, F&& f) -> decltype(f()) { if (extHandle == nullptr) { return nullptr; @@ -420,6 +460,8 @@ inline auto execute(GEOSContextHandle_t extHandle, F&& f) -> decltype(f()) { return nullptr; } + InterruptManagerType ic(handle); + try { return f(); } catch (const std::exception& e) { @@ -433,9 +475,14 @@ inline auto execute(GEOSContextHandle_t extHandle, F&& f) -> decltype(f()) { // Execute a lambda, using the given context handle to process errors. // No return value. -template()())>::value, std::nullptr_t>::type = nullptr> +template()())>::value, std::nullptr_t>::type = nullptr> inline void execute(GEOSContextHandle_t extHandle, F&& f) { GEOSContextHandleInternal_t* handle = reinterpret_cast(extHandle); + + if (handle != nullptr) { + InterruptManagerType ic(handle); + } + try { f(); } catch (const std::exception& e) { @@ -514,6 +561,17 @@ extern "C" { return handle->setErrorHandler(ef, userData); } + GEOSInterruptThreadCallback* + GEOSContext_setInterruptCallback_r(GEOSContextHandle_t extHandle, GEOSInterruptThreadCallback* cb, void* userData) + { + GEOSContextHandleInternal_t* handle = reinterpret_cast(extHandle); + if(0 == handle->initialized) { + return nullptr; + } + + return handle->setInterruptHandler(cb, userData); + } + void finishGEOS_r(GEOSContextHandle_t extHandle) { @@ -879,7 +937,7 @@ extern "C" { int GEOSArea_r(GEOSContextHandle_t extHandle, const Geometry* g, double* area) { - return execute(extHandle, 0, [&]() { + return execute(extHandle, 0, [&]() { *area = g->getArea(); return 1; }); @@ -888,7 +946,7 @@ extern "C" { int GEOSLength_r(GEOSContextHandle_t extHandle, const Geometry* g, double* length) { - return execute(extHandle, 0, [&]() { + return execute(extHandle, 0, [&]() { *length = g->getLength(); return 1; }); @@ -1640,7 +1698,7 @@ extern "C" { int GEOSGetNumInteriorRings_r(GEOSContextHandle_t extHandle, const Geometry* g1) { - return execute(extHandle, -1, [&]() { + return execute(extHandle, -1, [&]() { const Polygon* p = dynamic_cast(g1); if(!p) { throw IllegalArgumentException("Argument is not a Polygon"); @@ -1654,7 +1712,7 @@ extern "C" { int GEOSGetNumGeometries_r(GEOSContextHandle_t extHandle, const Geometry* g1) { - return execute(extHandle, -1, [&]() { + return execute(extHandle, -1, [&]() { return static_cast(g1->getNumGeometries()); }); } @@ -1667,7 +1725,7 @@ extern "C" { const Geometry* GEOSGetGeometryN_r(GEOSContextHandle_t extHandle, const Geometry* g1, int n) { - return execute(extHandle, [&]() { + return execute(extHandle, [&]() { if(n < 0) { throw IllegalArgumentException("Index must be non-negative."); } @@ -1856,7 +1914,7 @@ extern "C" { const Geometry* GEOSGetExteriorRing_r(GEOSContextHandle_t extHandle, const Geometry* g1) { - return execute(extHandle, [&]() { + return execute(extHandle, [&]() { const Polygon* p = dynamic_cast(g1); if(!p) { throw IllegalArgumentException("Invalid argument (must be a Polygon)"); @@ -1872,7 +1930,7 @@ extern "C" { const Geometry* GEOSGetInteriorRingN_r(GEOSContextHandle_t extHandle, const Geometry* g1, int n) { - return execute(extHandle, [&]() { + return execute(extHandle, [&]() { const Polygon* p = dynamic_cast(g1); if(!p) { throw IllegalArgumentException("Invalid argument (must be a Polygon)"); @@ -2573,7 +2631,7 @@ extern "C" { GEOSCoordSeq_setOrdinate_r(GEOSContextHandle_t extHandle, CoordinateSequence* cs, unsigned int idx, unsigned int dim, double val) { - return execute(extHandle, 0, [&]() { + return execute(extHandle, 0, [&]() { cs->setOrdinate(idx, dim, val); return 1; }); @@ -2600,7 +2658,7 @@ extern "C" { int GEOSCoordSeq_setXY_r(GEOSContextHandle_t extHandle, CoordinateSequence* cs, unsigned int idx, double x, double y) { - return execute(extHandle, 0, [&]() { + return execute(extHandle, 0, [&]() { cs->setAt(CoordinateXY{x, y}, idx); return 1; }); @@ -2609,7 +2667,7 @@ extern "C" { int GEOSCoordSeq_setXYZ_r(GEOSContextHandle_t extHandle, CoordinateSequence* cs, unsigned int idx, double x, double y, double z) { - return execute(extHandle, 0, [&]() { + return execute(extHandle, 0, [&]() { cs->setAt(Coordinate{x, y, z}, idx); return 1; }); @@ -2627,7 +2685,7 @@ extern "C" { GEOSCoordSeq_getOrdinate_r(GEOSContextHandle_t extHandle, const CoordinateSequence* cs, unsigned int idx, unsigned int dim, double* val) { - return execute(extHandle, 0, [&]() { + return execute(extHandle, 0, [&]() { *val = cs->getOrdinate(idx, dim); return 1; }); @@ -2654,7 +2712,7 @@ extern "C" { int GEOSCoordSeq_getXY_r(GEOSContextHandle_t extHandle, const CoordinateSequence* cs, unsigned int idx, double* x, double* y) { - return execute(extHandle, 0, [&]() { + return execute(extHandle, 0, [&]() { auto& c = cs->getAt(idx); *x = c.x; *y = c.y; @@ -2665,7 +2723,7 @@ extern "C" { int GEOSCoordSeq_getXYZ_r(GEOSContextHandle_t extHandle, const CoordinateSequence* cs, unsigned int idx, double* x, double* y, double* z) { - return execute(extHandle, 0, [&]() { + return execute(extHandle, 0, [&]() { auto& c = cs->getAt(idx); *x = c.x; *y = c.y; @@ -2677,7 +2735,7 @@ extern "C" { int GEOSCoordSeq_getSize_r(GEOSContextHandle_t extHandle, const CoordinateSequence* cs, unsigned int* size) { - return execute(extHandle, 0, [&]() { + return execute(extHandle, 0, [&]() { const std::size_t sz = cs->getSize(); *size = static_cast(sz); return 1; @@ -2687,7 +2745,7 @@ extern "C" { int GEOSCoordSeq_getDimensions_r(GEOSContextHandle_t extHandle, const CoordinateSequence* cs, unsigned int* dims) { - return execute(extHandle, 0, [&]() { + return execute(extHandle, 0, [&]() { const std::size_t dim = cs->getDimension(); *dims = static_cast(dim); diff --git a/include/geos/util/Interrupt.h b/include/geos/util/Interrupt.h index e52386d410..d2d20cd54e 100644 --- a/include/geos/util/Interrupt.h +++ b/include/geos/util/Interrupt.h @@ -27,15 +27,25 @@ class GEOS_DLL Interrupt { public: typedef void (Callback)(void); + typedef void (ThreadCallback)(void*); /** * Request interruption of operations * * Operations will be terminated by a GEOSInterrupt - * exception at first occasion. + * exception at first occasion, by the first thread + * to check for an interrupt request. */ static void request(); + /** + * Request interruption of operations in the current thread + * + * Operations in the current thread will be terminated by + * a GEOSInterrupt at first occasion. + */ + static void requestForCurrentThread(); + /** Cancel a pending interruption request */ static void cancel(); @@ -43,17 +53,29 @@ class GEOS_DLL Interrupt { static bool check(); /** \brief - * Register a callback that will be invoked + * Register a callback that will be invoked by all threads * before checking for interruption requests. * * NOTE that interruption request checking may happen - * frequently so any callback would better be quick. + * frequently so the callback should execute quickly. * * The callback can be used to call Interrupt::request() - * + * or Interrupt::requestForCurrentThread(). */ static Callback* registerCallback(Callback* cb); + /** \brief + * Register a callback that will be invoked the current thread + * before checking for interruption requests. + * + * NOTE that interruption request checking may happen + * frequently so the callback should execute quickly. + * + * The callback can be used to call Interrupt::request() + * or Interrupt::requestForCurrentThread(). + */ + static ThreadCallback* registerThreadCallback(ThreadCallback* cb, void* data); + /** * Invoke the callback, if any. Process pending interruption, if any. * diff --git a/src/util/Interrupt.cpp b/src/util/Interrupt.cpp index 0bc988221b..ea3d1691de 100644 --- a/src/util/Interrupt.cpp +++ b/src/util/Interrupt.cpp @@ -16,10 +16,16 @@ #include // for inheritance namespace { -/* Could these be portably stored in thread-specific space ? */ + +// Callback and request status for interruption of any single thread bool requested = false; +thread_local bool requested_for_thread = false; +// Callback and request status for interruption of a the current thread geos::util::Interrupt::Callback* callback = nullptr; +thread_local geos::util::Interrupt::ThreadCallback* callback_thread = nullptr; +thread_local void* callback_thread_data = nullptr; + } namespace geos { @@ -37,16 +43,23 @@ Interrupt::request() requested = true; } +void +Interrupt::requestForCurrentThread() +{ + requested_for_thread = true; +} + void Interrupt::cancel() { requested = false; + requested_for_thread = false; } bool Interrupt::check() { - return requested; + return requested || requested_for_thread; } Interrupt::Callback* @@ -57,14 +70,25 @@ Interrupt::registerCallback(Interrupt::Callback* cb) return prev; } +Interrupt::ThreadCallback* +Interrupt::registerThreadCallback(ThreadCallback* cb, void* data) +{ + ThreadCallback* prev = callback_thread; + callback_thread = cb; + callback_thread_data = data; + return prev; +} + void Interrupt::process() { if(callback) { (*callback)(); } - if(requested) { - requested = false; + if(callback_thread) { + (*callback_thread)(callback_thread_data); + } + if(check()) { interrupt(); } } @@ -74,6 +98,7 @@ void Interrupt::interrupt() { requested = false; + requested_for_thread = false; throw InterruptedException(); } diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index f29a52cf77..697c5fa9bb 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -44,6 +44,7 @@ foreach(_testfile ${_testfiles}) string(CONCAT _testname "geos::" ${_testname}) endif() add_test(NAME unit-${_cmake_testname} COMMAND test_geos_unit ${_testname}) + set_tests_properties(unit-${_cmake_testname} PROPERTIES TIMEOUT 30) endforeach() # Run all the unit tests in one go, for faster memory checking diff --git a/tests/unit/capi/GEOSInterruptTest.cpp b/tests/unit/capi/GEOSInterruptTest.cpp index 0bd0301143..0fb6720820 100644 --- a/tests/unit/capi/GEOSInterruptTest.cpp +++ b/tests/unit/capi/GEOSInterruptTest.cpp @@ -4,11 +4,13 @@ #include // geos #include +#include // std #include #include #include #include +#include namespace tut { // @@ -18,6 +20,7 @@ namespace tut { // Common data used in test cases. struct test_capiinterrupt_data { static int numcalls; + static int maxcalls; static GEOSInterruptCallback* nextcb; static void @@ -56,9 +59,18 @@ struct test_capiinterrupt_data { } } + static void + interruptAfterMaxCalls(void* data) + { + if (++*static_cast(data) >= maxcalls) { + GEOS_interruptThread(); + } + } + }; int test_capiinterrupt_data::numcalls = 0; +int test_capiinterrupt_data::maxcalls = 0; GEOSInterruptCallback* test_capiinterrupt_data::nextcb = nullptr; typedef test_group group; @@ -103,7 +115,7 @@ void object::test<1> finishGEOS(); } -/// Test interrupt callback being called XXX +/// Test interrupt callback being called template<> template<> void object::test<2> @@ -221,5 +233,53 @@ void object::test<5> } +// Test callback is thread-local +template<> +template<> +void object::test<6> +() +{ + using geos::util::Interrupt; + + maxcalls = 3; + int calls_1 = 0; + int calls_2 = 0; + + GEOSContextHandle_t h1 = initGEOS_r(notice, notice); + GEOSContextHandle_t h2 = initGEOS_r(notice, notice); + + GEOSContext_setInterruptCallback_r(h1, interruptAfterMaxCalls, &calls_1); + GEOSContext_setInterruptCallback_r(h2, interruptAfterMaxCalls, &calls_2); + + // get previously registered callback and verify there was none + ensure(Interrupt::registerThreadCallback(nullptr, nullptr) == nullptr); + + auto buffer = [](GEOSContextHandle_t handle) { + GEOSWKTReader* reader = GEOSWKTReader_create_r(handle); + GEOSGeometry* geom1 = GEOSWKTReader_read_r(handle, reader, "LINESTRING (0 0, 1 0)"); + GEOSGeometry* geom2 = GEOSBuffer_r(handle, geom1, 1, 8); + + GEOSGeom_destroy_r(handle, geom2); + GEOSGeom_destroy_r(handle, geom1); + GEOSWKTReader_destroy_r(handle, reader); + }; + + std::thread t1(buffer, h1); + std::thread t2(buffer, h2); + + t1.join(); + t2.join(); + + ensure_equals(calls_1, maxcalls); + ensure_equals(calls_2, maxcalls); + + // get previously registered callback and verify there was none + ensure(Interrupt::registerThreadCallback(nullptr, nullptr) == nullptr); + + finishGEOS_r(h1); + finishGEOS_r(h2); +} + + } // namespace tut diff --git a/tests/unit/util/InterruptTest.cpp b/tests/unit/util/InterruptTest.cpp new file mode 100644 index 0000000000..5b5586a4af --- /dev/null +++ b/tests/unit/util/InterruptTest.cpp @@ -0,0 +1,140 @@ +// tut +#include +// geos +#include +// std +#include +#include +#include + +using geos::util::Interrupt; + +namespace tut { +// +// Test Group +// + +// Common data used in test cases. +struct test_interrupt_data { + static void workForever() { + try { + std::cerr << "Started " << std::this_thread::get_id() << "." << std::endl; + while (true) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + GEOS_CHECK_FOR_INTERRUPTS(); + } + } catch (const std::exception&) { + std::cerr << "Interrupted " << std::this_thread::get_id() << "." << std::endl; + return; + } + } + + static void interruptNow() { + Interrupt::request(); + } + + static std::map* toInterrupt; + + static void interruptIfRequested() { + if (toInterrupt == nullptr) { + return; + } + + auto it = toInterrupt->find(std::this_thread::get_id()); + if (it != toInterrupt->end() && it->second) { + it->second = false; + Interrupt::requestForCurrentThread(); + } + } +}; + +std::map* test_interrupt_data::toInterrupt = nullptr; + +typedef test_group group; +typedef group::object object; + +group test_interrupt_group("geos::util::Interrupt"); + +// +// Test Cases +// + + +// Interrupt worker thread via global request from from main thead +template<> +template<> +void object::test<1> +() +{ + std::thread t(workForever); + Interrupt::request(); + + t.join(); +} + +// Interrupt worker thread via thread-specific request from worker thread using a callback +template<> +template<> +void object::test<2> +() +{ + Interrupt::registerCallback(interruptIfRequested); + + std::thread t1(workForever); + std::thread t2(workForever); + + // Create map and add entries before exposing it to the interrupt + // callback that will be acessed from multiple threads. It's OK + // for multiple threads to modify entries in the map but not for + // multiple threads to create entries. + std::map shouldInterrupt; + shouldInterrupt[t1.get_id()] = false; + shouldInterrupt[t2.get_id()] = false; + toInterrupt = &shouldInterrupt; + + shouldInterrupt[t2.get_id()] = true; + t2.join(); + + shouldInterrupt[t1.get_id()] = true; + t1.join(); +} + +// Register separate callbacks for each thread. Each callback will +// request interruption of itself only. +template<> +template<> +void object::test<3> +() +{ + bool interrupt1 = false; + int numCalls2 = 0; + + auto cb1 = ([](void* data) { + if (*static_cast(data)) { + Interrupt::requestForCurrentThread(); + } + }); + + auto cb2 = ([](void* data) { + if (++*static_cast(data) > 5) { + Interrupt::requestForCurrentThread(); + } + }); + + + std::thread t1([&cb1, &interrupt1]() { + Interrupt::registerThreadCallback(cb1, &interrupt1); + }); + + std::thread t2([&cb2, &numCalls2]() { + Interrupt::registerThreadCallback(cb2, &numCalls2); + }); + + t2.join(); + + interrupt1 = true; + t1.join(); +} + +} // namespace tut +