Skip to content

Commit

Permalink
CU-8695pvhfe fix usage monitoring for multiprocessing (#488)
Browse files Browse the repository at this point in the history
* CU-8695pvhfe: Rename a test class

* CU-8695pvhfe: Add tests for multiprocessig usage monitoring

* CU-8695pvhfe: Fix usage monitor for multiprocessig.

When using CAT.multiprocessing_batch_char_size (CAT._multiprocessing_batch and CAT._mp_cons internally), flush the usage monitor at the end of multiprocessing method.
When using CAT.get_entities_multi_texts or CAT.multiprocessing_batch_docs_size (uses the former internally), add logging of usage to output

* CU-8695pvhfe: Fix remaining issues with usage monitor for multiprocessig.

Avoid checking length of (potentially) non-existent strings. Avoid early iteration of generator.
  • Loading branch information
mart-r authored Sep 17, 2024
1 parent 394e17b commit eb912d6
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 6 deletions.
23 changes: 22 additions & 1 deletion medcat/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,11 +1127,29 @@ def get_entities_multi_texts(self,
self.pipe.set_error_handler(self._pipe_error_handler)
try:
texts_ = self._get_trimmed_texts(texts)
if self.config.general.usage_monitor.enabled:
input_lengths: List[Tuple[int, int]] = []
for orig_text, trimmed_text in zip(texts, texts_):
if orig_text is None or trimmed_text is None:
l1, l2 = 0, 0
else:
l1 = len(orig_text)
l2 = len(trimmed_text)
input_lengths.append((l1, l2))
docs = self.pipe.batch_multi_process(texts_, n_process, batch_size)

for doc in tqdm(docs, total=len(texts_)):
for doc_nr, doc in tqdm(enumerate(docs), total=len(texts_)):
doc = None if doc.text.strip() == '' else doc
out.append(self._doc_to_out(doc, only_cui, addl_info, out_with_text=True))
if self.config.general.usage_monitor.enabled:
l1, l2 = input_lengths[doc_nr]
if doc is None:
nents = 0
elif self.config.general.show_nested_entities:
nents = len(doc._.ents) # type: ignore
else:
nents = len(doc.ents) # type: ignore
self.usage_monitor.log_inference(l1, l2, nents)

# Currently spaCy cannot mark which pieces of texts failed within the pipe so be this workaround,
# which also assumes texts are different from each others.
Expand Down Expand Up @@ -1637,6 +1655,9 @@ def _mp_cons(self, in_q: Queue, out_list: List, min_free_memory: float,
logger.warning("PID: %s failed one document in _mp_cons, running will continue normally. \n" +
"Document length in chars: %s, and ID: %s", pid, len(str(text)), i_text)
logger.warning(str(e))
if self.config.general.usage_monitor.enabled:
# NOTE: This is in another process, so need to explicitly flush
self.usage_monitor._flush_logs()
sleep(2)

def _add_nested_ent(self, doc: Doc, _ents: List[Span], _ent: Union[Dict, Span]) -> None:
Expand Down
47 changes: 43 additions & 4 deletions tests/test_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import os
import sys
import time
from typing import Callable
from functools import partial
import unittest
from unittest.mock import mock_open, patch
import tempfile
Expand Down Expand Up @@ -595,18 +597,55 @@ def test_get_entities_gets_monitored(self,
contents = f.readline()
self.assertTrue(contents)

def assert_gets_usage_monitored(self, data_processor: Callable[[None], None], exp_logs: int = 1):
# clear usage monitor buffer
self.undertest.usage_monitor.log_buffer.clear()
data_processor()
file = self.undertest.usage_monitor.log_file
if os.path.exists(file):
with open(file) as f:
content = f.readlines()
content += self.undertest.usage_monitor.log_buffer
else:
content = self.undertest.usage_monitor.log_buffer
self.assertTrue(content)
self.assertEqual(len(content), exp_logs)

def test_get_entities_logs_usage(self,
text="The dog is sitting outside the house."):
# clear usage monitor buffer
self.undertest.usage_monitor.log_buffer.clear()
self.undertest.get_entities(text)
self.assertTrue(self.undertest.usage_monitor.log_buffer)
self.assertEqual(len(self.undertest.usage_monitor.log_buffer), 1)
self.assert_gets_usage_monitored(partial(self.undertest.get_entities, text), 1)
line = self.undertest.usage_monitor.log_buffer[0]
# the 1st element is the input text length
input_text_length = line.split(",")[1]
self.assertEqual(str(len(text)), input_text_length)

TEXT4MP_USAGE = [
("ID1", "Text with house and dog one"),
("ID2", "Text with house and dog two"),
("ID3", "Text with house and dog three"),
("ID4", "Text with house and dog four"),
("ID5", "Text with house and dog five"),
("ID6", "Text with house and dog siz"),
("ID7", "Text with house and dog seven"),
("ID8", "Text with house and dog eight"),
]

def test_mp_batch_char_size_logs_usage(self):
all_text = self.TEXT4MP_USAGE
proc = partial(self.undertest.multiprocessing_batch_char_size, all_text, nproc=2)
self.assert_gets_usage_monitored(proc, len(all_text))

def test_mp_get_multi_texts_logs_usage(self):
all_text = self.TEXT4MP_USAGE
proc = partial(self.undertest.get_entities_multi_texts, all_text, n_process=2)
self.assert_gets_usage_monitored(proc, len(all_text))

def test_mp_batch_docs_size_logs_usage(self):
all_text = self.TEXT4MP_USAGE
proc = partial(self.undertest.multiprocessing_batch_docs_size, all_text, nproc=2)
self.assert_gets_usage_monitored(proc, len(all_text))

def test_simple_hashing_is_faster(self):
self.undertest.config.general.simple_hash = False
st = time.perf_counter()
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/test_usage_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_some_in_file(self):
self.assertEqual(len(lines), self.expected_in_file)


class UMT(UsageMonitorBaseTests):
class UsageMonitoringAutoTests(UsageMonitorBaseTests):
ENABLED_DICT = {
"MEDCAT_USAGE_LOGS": "True",
"MEDCAT_USAGE_LOGS_LOCATION": "."
Expand Down

0 comments on commit eb912d6

Please sign in to comment.