diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d446160c9..100df35ea 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -40,7 +40,8 @@ jobs: second_half_nl=$(echo "$all_files" | tail -n +$(($midpoint + 1))) timeout 25m python -m unittest ${first_half_nl[@]} timeout 25m python -m unittest ${second_half_nl[@]} - + - name: Regression + run: source tests/resources/regression/run_regression.sh - name: Get the latest release version id: get_latest_release uses: actions/github-script@v6 diff --git a/configs/default_regression_tests.yml b/configs/default_regression_tests.yml index 4608b8fba..899e8c4bf 100644 --- a/configs/default_regression_tests.yml +++ b/configs/default_regression_tests.yml @@ -1,79 +1,142 @@ -# # Example of some test cases -# # They will try to cover as many possible use cases as possible -# # The idea is that the CUI corresponding to the name is expected to be -# # obtained by MedCAT -# # Only the 'filters' under 'targeting' and the 'phrases' under -# # the test case are the two required sections, the rest is optional -# -# test-case-name-1: # name of this test case -# targeting: # info regarding targets of this test case -# strategy: "ALL" # the strategy for dealing with the filters below -# # so "ALL" means the targets need to match all the below filters -# # and "ANY" means that the targets need to match at least one of the filters -# # if only one type of target it specified, this is irrelevant -# # the default value is "ALL" if not specified -# prefname-only: False # set to True if only prefered names should be checked (defaults to False) -# targfiltersets: # the filters for this specific test case -# # there has to be one type of target, but multiple can be specified -# # if multiple types are target, the strategy defined above is taken into affect -# # each type can specify one or multiple values -# # this example shows has one values -# # the next example (below) will have multiple values -# 
type_id: "0123" # type_id or type_ids -# cui: "01230" # the target CUI (or list of CUIS) -# name: "name0" # the target names -# # all specified names need to exist within the CDB -# phrases: "The quick brown %s jumped over the lazy cat" # the phrases to go through -# # for each phrases, '%s' is replaced -# # by each name that is to be tested -# test-case-name-2: # name of this test case -# targeting: -# filters: -# type_id: # multiple target type IDs -# - "123" -# - "223" -# cui: # multiple target CUI -# - "1234" -# - "2234" -# name: # multiple names -# - "name1" -# - "name2" -# cui_and_children: # an example with CUI and children -# cui: '111' # the CUI (or CUIs) -# depth: 2 # and the depth of children -# phrases: -# - "The %s was measured" -# - "The %s was not measured" -# -# # The following example was (rather arbitrarily) created and should work for -# # the included SNOMED models -test-case-1: - targeting: - strategy: "ALL" - filters: - type_id: "2680757" - phrases: - - "The %s was measured" +# this is an example test case +# it is based on SNOMED-CT +test-case-1: # The (somewhat) arbitrary name of the test case + targeting: # the description of the replacement targets in the phrase(s) + placeholders: # the placeholders to replace in the phrase(s) + # Note that only 1 concept will be tested for at one time. + # So if the prhase(s) has/have more than 1 placeholder, the + # rest of them will be substitued in without care for whether + # or how accurately the model is able to recognise them. + # For the concepts that are not under test at a given time + # the "first" name is used (because the implementation has + # names in a set, there is possibility for run-to-run variance + # because of different names being used). + # + # There are 2 modes for the placeholders: + # 1. any-combination: false + # In this mode, only the concepts in the same position + # in the various lists are used in conjunction to oneanother. 
+ # Though this also means that it is expected that all of the + # placeholders have the same number of CUIs to use. + # Assuming each of the N placeholders defines M replacement + # cuis, this approach produces M*N cases. + # 2. any-combination: true + # In this mode, any combination of the replacement CUIs is + # allowed. This means that quite a few different combinations + # will be generated and used. It also means that different + # placeholders can have different number of concepts suitable + # for them. + # Assuming each of the N placeholders defines M replacement + # cuis, this approach produces N * N^M (where `^` is power) + # cases. But for a more complicated set up (i.e. where different + # placeholders have a different number of swappable CUIs) + # this calculation is not as straightforward. + # + # NOTE: The above description does not take into account different + # number of names associated with different concepts. For each + # of the "primary" concepts, each possible name is attempted. + - placeholder: '[DISORDER]' # the placeholder that will be substituted in the phrase(s) + cuis: ['4473006', # Intracerebral hemorrhage + '85189001', # Acute appendicitis + '186738001', # vestibular neuritis + '186738001', # vestibular neuritis + ] + - placeholder: '[FINDING1]' + cuis: ['162300006', # unilateral headache + '21522001', # abdominal pain + '103298005', # severe vertigo + '103298005', # severe vertigo + ] + prefname-only: false # this is an optional keyword for each placeholder + # if set to true, only the preferred name will be used for + # this concept.
Otherwise, all names will be used as + # different sub-cases + - placeholder: '[FINDING2]' + cuis: ['409668002', # photophobia + '422587007', # nausea + '422587007', # nausea + '422587007', # nausea + ] + - placeholder: '[FINDING3]' + cuis: ['2228002', # scintillating scotoma + '386661006', # fever + '81756001', # horizontal nystagmus + '81756001', # horizontal nystagmus + ] + - placeholder: '[NEGFINDING]' + cuis: ['386661006', # fever + '62315008', # diarrhea + '15188001', # hearing loss + '60862001', # tinnitus + ] + any-combination: false # if set to false, same length of CUIs is expected + # for each placeholder and only a combination is used + phrases: # The list of phrases + - > + Description: [DISORDER] + + CC: [FINDING1] on presentation; then developed [FINDING3] + + HX: On the day of presentation, this 32 y/o RHM suddenly developed [FINDING1] and [FINDING2]. + Four hours later he experienced sudden [FINDING3] lasting two hours. + There were no other associated symptoms except for the [FINDING1] and [FINDING2]. + He denied [NEGFINDING]. 
test-case-2: targeting: - filters: - type_id: "9090192" - phrases: - - "Patient presented with %s" - - "No %s was present" -test-case-3: - targeting: - filters: - type_id: "67667581" - phrases: - - "The patient has been diagnosed with %s" - - "There are no signs of %s" -test-case-4: - targeting: - strategy: "ALL" - filters: - cui_and_children: - cui: "364075005" # 'heart rate' - depth: 4 # and children 4 deep + placeholders: + - placeholder: '[FINDING1]' + cuis: ['49727002', # cough + '29857009', # chest pain + '21522001', # abdominal pain + '57676002', # joint pain + '25064002', # headache + '271807003', # fever + '162397003', # hematuria (blood in urine) + '271757001', # fatigue + '386661006', # weight loss + '62315008', # dysuria (painful urination) + ] + - placeholder: '[FINDING2]' + cuis: ['267036007', # shortness of breath + '68962001', # palpatations + '422587007', # nausea + '182888003', # swelling + '404640003', # dizziness + '422400008', # sore throat + '267036007', # shortness of breath + '267064002', # night sweats + '162607003', # back pain + '267102003', # urinary frequency + ] + - placeholder: '[DISORDER]' + cuis: ['195967001', # asthma + '194828000', # angina pectoris + '25374005', # gastroenteritis + '69896004', # rheumatoid arthritis + '37796009', # migraine + '186747009', # influenza + '106063007', # urinary tract infection + '444814009', # chronic fatigue syndrome + '95281007', # tuberculosis + '431855005', # cystitis + ] + any-combination: false phrases: - - "The patient's %s was 82 bps" + - > + The patient presents with [FINDING1] and [FINDING2]. These findings are suggestive of [DISORDER]. + Further diagnostic evaluation and investigations are required to confirm the diagnosis. + - > + The patient reports [FINDING1] and has also been experiencing [FINDING2]. These symptoms are consistent with a clinical presentation of [DISORDER]. + Further assessment and diagnostic tests are required to establish the underlying cause. 
+ - > + Upon evaluation, the patient exhibits [FINDING1] along with [FINDING2]. This combination of findings raises suspicion for [DISORDER]. + Comprehensive diagnostic workup is advised to confirm the diagnosis and plan appropriate management. + - > + During the consultation, the patient described [FINDING1] and noted a recent history of [FINDING2]. These clinical features are suggestive of [DISORDER]. + Further investigation is necessary to verify the diagnosis and rule out other potential causes. + - > + The patient's symptoms include [FINDING1] and [FINDING2], which are commonly associated with [DISORDER]. + It is recommended that additional diagnostic procedures be performed to confirm this working diagnosis. + - > + The clinical presentation of [FINDING1] and [FINDING2] is indicative of [DISORDER]. + To ensure accurate diagnosis, further clinical evaluation and diagnostic tests are required. diff --git a/medcat/utils/regression/README.md b/medcat/utils/regression/README.md new file mode 100644 index 000000000..e422210c0 --- /dev/null +++ b/medcat/utils/regression/README.md @@ -0,0 +1,111 @@ +# Regression with MedCAT + +We often end up creating new models when a new version of an ontology (e.g SNOMED-CT) comes out. +However, it is not always clear whether the new model is comparable to the old one. +To solve this, we've developed a regression suite system. + +The idea is that we can define a small set of patient records with different placeholders for different findings or disorders, or anything in the ontology, really. +And we can then specify the concepts we think should fit in this patient record. + +An example patient record with placeholders (the simple one from the default regression suite): +``` +The patient presents with [FINDING1] and [FINDING2]. These findings are suggestive of [DISORDER]. +Further diagnostic evaluation and investigations are required to confirm the diagnosis. 
+``` +As we can see, there are three different placeholders in here: `[FINDING1]`, `[FINDING2]`, and `[DISORDER]`. +Each can be replaced with a specific name of a specific concept. +For instance, we've specified the following: + - `[FINDING1]` -> '49727002' (cough) + - `[FINDING2]` -> '267036007' (shortness of breath) + - `[DISORDER]` -> '195967001' (asthma) + +So with these swapped into the original patient record we get: +``` +The patient presents with cough and shortness of breath. These findings are suggestive of asthma. +Further diagnostic evaluation and investigations are required to confirm the diagnosis. +``` + +# Using regression suite + +The easiest way to use the regression suite is to use the built-in endpoint: +``` +python -m medcat.utils.regression.regression_checker [regression suite YAML] +``` +While you need to specify a model pack, you do not need to specify a regression suite since the default one can be used instead. + +This will first read the regression suite from the YAML, then load the model pack, and finally run the regression suite. +
The output can look like this +Output on the 2024-06 SNOMED-CT model on the first case in the default regression suite. + +``` +$ python -m medcat.utils.regression.regression_checker models/Snomed2024-06-gstt-trained_ae5b08e0fb5310b2.zip +Loading RegressionChecker from yaml: configs/default_regression_tests.yml +Loading model pack from file: models/Snomed2024-06-gstt-trained_ae5b08e0fb5310b2.zip +Checking the current status +100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:10<00:00, 1.96it/s] +A total of 1 parts were kept track of within the group "ALL". +And a total of 756 (sub)cases were checked. +At the strictness level of Strictness.NORMAL (allowing ['FOUND_ANY_CHILD', 'BIGGER_SPAN_LEFT', 'SMALLER_SPAN', 'PARTIAL_OVERLAP', 'BIGGER_SPAN_BOTH', 'BIGGER_SPAN_RIGHT', 'FOUND_CHILD_PARTIAL', 'IDENTICAL']): +The number of total successful (sub) cases: 737 (97.49%) +The number of total failing (sub) cases : 19 ( 2.51%) +IDENTICAL : 730 (96.56%) +SMALLER_SPAN : 2 ( 0.26%) +FOUND_ANY_CHILD : 5 ( 0.66%) +FAIL : 19 ( 2.51%) + Tested 'test-case-1' for a total of 756 cases: + IDENTICAL : 730 (96.56%) + SMALLER_SPAN : 2 ( 0.26%) + FOUND_ANY_CHILD : 5 ( 0.66%) + FAIL : 19 ( 2.51%) + Examples at Strictness.STRICTEST strictness + With phrase: 'Description: Acute appendicitis\nCC: abdo [277 chars] d Nausea. He denied Diarrhea.\n' + FOUND_ANY_CHILD for placeholder [FINDING1] with CUI '21522001' and name 'abdominal colic' + With phrase: 'Description: Acute appendicitis\nCC: [FIN [273 chars] d Nausea. He denied Diarrhea.\n' + SMALLER_SPAN for placeholder [FINDING1] with CUI '21522001' and name 'abdomen colic' + With phrase: 'Description: Acute appendicitis\nCC: abdo [273 chars] d Nausea. He denied Diarrhea.\n' + SMALLER_SPAN for placeholder [FINDING1] with CUI '21522001' and name 'abdomen colic' + With phrase: 'Description: Acute appendicitis\nCC: abdo [293 chars] d Nausea. 
He denied Diarrhea.\n' + FOUND_ANY_CHILD for placeholder [FINDING1] with CUI '21522001' and name 'abdominal colic finding' + With phrase: 'Description: Acute appendicitis\nCC: [FIN [271 chars] d Nausea. He denied Diarrhea.\n' + FAIL for placeholder [FINDING1] with CUI '21522001' and name 'abdomen pain' + With phrase: 'Description: Acute appendicitis\nCC: [FIN [271 chars] d Nausea. He denied Diarrhea.\n' + FAIL for placeholder [FINDING1] with CUI '21522001' and name 'colicky pain' + With phrase: 'Description: Acute appendicitis\nCC: coli [271 chars] d Nausea. He denied Diarrhea.\n' + FAIL for placeholder [FINDING1] with CUI '21522001' and name 'colicky pain' + With phrase: 'Description: Acute appendicitis\nCC: coli [271 chars] d Nausea. He denied Diarrhea.\n' + FAIL for placeholder [FINDING1] with CUI '21522001' and name 'colicky pain' + With phrase: 'Description: Acute appendicitis\nCC: Abdo [291 chars] d Nausea. He denied Diarrhea.\n' + FAIL for placeholder [FINDING3] with CUI '386661006' and name 'hyperthermia' + With phrase: 'Description: Acute appendicitis\nCC: Abdo [295 chars] d Nausea. He denied Diarrhea.\n' + FAIL for placeholder [FINDING3] with CUI '386661006' and name 'high temperature' + With phrase: 'Description: Acute appendicitis\nCC: Abdo [295 chars] d Nausea. He denied Diarrhea.\n' + FAIL for placeholder [FINDING3] with CUI '386661006' and name 'high temperature' + With phrase: 'Description: Migraine with aura\nCC: Unil [340 chars] obia. He denied [NEGFINDING].\n' + FAIL for placeholder [NEGFINDING] with CUI '386661006' and name 'hyperthermia' + FAIL for placeholder [NEGFINDING] with CUI '386661006' and name 'high temperature' + With phrase: 'Description: Acute appendicitis\nCC: Abdo [283 chars] usea. 
He denied [NEGFINDING].\n' + FAIL for placeholder [NEGFINDING] with CUI '62315008' and name 'loose stools' + FAIL for placeholder [NEGFINDING] with CUI '62315008' and name 'watery stool' + FAIL for placeholder [NEGFINDING] with CUI '62315008' and name 'loose bowel movement' + FOUND_ANY_CHILD for placeholder [NEGFINDING] with CUI '62315008' and name 'diarrhea symptom' + FAIL for placeholder [NEGFINDING] with CUI '62315008' and name 'loose bowel motion' + FAIL for placeholder [NEGFINDING] with CUI '62315008' and name 'loose bowel motions' + FAIL for placeholder [NEGFINDING] with CUI '62315008' and name 'loose stool' + FOUND_ANY_CHILD for placeholder [NEGFINDING] with CUI '62315008' and name 'diarrhea symptoms' + FOUND_ANY_CHILD for placeholder [NEGFINDING] with CUI '62315008' and name 'diarrhea symptom finding' + FAIL for placeholder [NEGFINDING] with CUI '62315008' and name 'watery stools' + With phrase: 'Description: Epidemic vertigo\nCC: Severe [311 chars] usea. He denied [NEGFINDING].\n' + FAIL for placeholder [NEGFINDING] with CUI '15188001' and name 'decreased hearing' + FAIL for placeholder [NEGFINDING] with CUI '15188001' and name 'decreased hearing finding' + FAIL for placeholder [NEGFINDING] with CUI '60862001' and name 'ringing in ear' +``` + +
+ +## The regression suite format + +The format has some documentation in the default (`config/default_regression_tests.yml`). +One should refer to those for now. + + diff --git a/medcat/utils/regression/category_separation.py b/medcat/utils/regression/category_separation.py deleted file mode 100644 index 883879390..000000000 --- a/medcat/utils/regression/category_separation.py +++ /dev/null @@ -1,533 +0,0 @@ -from abc import ABC, abstractmethod -from enum import auto, Enum -import os -from typing import Any, List, Dict, Optional, Set -import yaml -import string -import random -import logging - -import pydantic - -from medcat.utils.regression.checking import RegressionChecker, RegressionCase, FilterType, TypedFilter, MetaData - - -logger = logging.getLogger(__name__) - - -class CategoryDescription(pydantic.BaseModel): - """A descriptor for a category. - - Args: - target_cuis (Set[str]): The set of target CUIs - target_names (Set[str]): The set of target names - target_tuis (Set[str]): The set of target type IDs - anything_goes (bool): Matches any CUI/NAME/TUI. Defaults to False - """ - target_cuis: Set[str] - target_names: Set[str] - target_tuis: Set[str] - allow_everything: bool = False - - def _get_required_filter(self, case: RegressionCase, target_filter: FilterType) -> Optional[TypedFilter]: - for filter in case.filters: - if filter.type == target_filter: - return filter - return None - - def _has_specific_from(self, case: RegressionCase, targets: Set[str], target_filter: FilterType): - if self.allow_everything: - return True - filter = self._get_required_filter(case, target_filter) - if filter is None: - return False # No such filter - for val in filter.values: - if val in targets: - return True - return False - - def has_cui_from(self, case: RegressionCase) -> bool: - """Check if the description has a CUI from the specified regression case. 
- - Args: - case (RegressionCase): The regression case to check - - Returns: - bool: True if the description has a CUI from the regression case - """ - return (self._has_specific_from(case, self.target_cuis, FilterType.CUI) or - self._has_specific_from(case, self.target_cuis, FilterType.CUI_AND_CHILDREN)) - - def has_name_from(self, case: RegressionCase) -> bool: - """Check if the description has a name from the specified regression case. - - Args: - case (RegressionCase): The regression case to check - - Returns: - bool: True if the description has a name from the regression case - """ - return self._has_specific_from(case, self.target_names, FilterType.NAME) - - def has_tui_from(self, case: RegressionCase) -> bool: - """Check if the description has a target ID/TUI from the specified regression case. - - Args: - case (RegressionCase): The regression case to check - - Returns: - bool: True if the description has a target ID/TUI from the regression case - """ - return self._has_specific_from(case, self.target_tuis, FilterType.TYPE_ID) - - def __hash__(self) -> int: - return hash((tuple(self.target_cuis), tuple(self.target_names), tuple(self.target_tuis))) - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, CategoryDescription): - return False - return (self.target_cuis == other.target_cuis - and self.target_names == other.target_names - and self.target_tuis == other.target_tuis) - - @classmethod - def anything_goes(cls) -> 'CategoryDescription': - s: Set[str] = set() - return CategoryDescription(target_cuis=s, target_tuis=s, target_names=s, allow_everything=True) - - -class Category(ABC): - """The category base class. - - A category defines which regression cases fit in it. - - Args: - name (str): The name of the category - """ - - def __init__(self, name: str) -> None: - self.name = name - - @abstractmethod - def fits(self, case: RegressionCase) -> bool: - """Check if a particular regression case fits in this category. 
- - Args: - case (RegressionCase): The regression case. - - Returns: - bool: Whether the case is in this category. - """ - - -class AllPartsCategory(Category): - """Represents a category which only fits a regression case if it matches all parts of category description. - - That is, in order for a regression case to match, it would need to match a CUI, a name and a TUI - specified in the category description. - - Args: - name (str): The name of the category - descr (CategoryDescription): The description of the category - """ - - def __init__(self, name: str, descr: CategoryDescription) -> None: - super().__init__(name) - self.description = descr - - def fits(self, case: RegressionCase) -> bool: - return (self.description.has_cui_from(case) and self.description.has_name_from(case) - and self.description.has_tui_from(case)) - - def __eq__(self, __o: object) -> bool: - if not isinstance(__o, AllPartsCategory): - return False - return __o.description == self.description - - def __hash__(self) -> int: - return hash((self.__class__.__name__, self.description)) - - def __str__(self) -> str: - return f"AllPartsCategory with: {self.description}" - - def __repr__(self) -> str: - return f"<{str(self)}>" - - -class AnyPartOfCategory(Category): - """Represents a category which fits a regression case that matches any part of its category desription. - - That is, any case that matches either a CUI, a name or a TUI within the category description, will fit. 
- - Args: - name (str): The name of the category - descr (CategoryDescription): The description of the category - """ - - def __init__(self, name: str, descr: CategoryDescription) -> None: - super().__init__(name) - self.description = descr - - def fits(self, case: RegressionCase) -> bool: - return (self.description.has_cui_from(case) or self.description.has_name_from(case) - or self.description.has_tui_from(case)) - - def __eq__(self, __o: object) -> bool: - if not isinstance(__o, AnyPartOfCategory): - return False - return __o.description == self.description - - def __hash__(self) -> int: - return hash((self.__class__.__name__, self.description)) - - def __str__(self) -> str: - return f"AnyPartOfCategory with: {self.description}" - - def __repr__(self) -> str: - return f"<{str(self)}>" - - -class SeparationObserver: - """Keeps track of which case is separate into which category/categories. - - It also keeps track of which cases have been observed as separated and - into which category. - """ - - def __init__(self) -> None: - self.reset() - - def observe(self, case: RegressionCase, category: Category) -> None: - """Observe the specified regression case in the specified category. - - Args: - case (RegressionCase): The regression case to observe - category (Category): The category to link the case tos - """ - if category not in self.separated: - self.separated[category] = set() - self.separated[category].add(case) - if case not in self.cases: - self.cases[case] = set() - self.cases[case].add(category) - - def has_observed(self, case: RegressionCase) -> bool: - """Check if the case has already been observed. 
- - Args: - case (RegressionCase): The case to check - - Returns: - bool: True if the case had been observed, False otherwise - """ - return case in self.cases - - def reset(self) -> None: - """Allows resetting the state of the observer.""" - self.separated: Dict[Category, Set[RegressionCase]] = {} - self.cases: Dict[RegressionCase, Set[Category]] = {} - - -class StrategyType(Enum): - """Describes the types of strategies one can can employ for strategy.""" - FIRST = auto - ALL = auto - - -class SeparatorStrategy(ABC): - """The strategy according to which the separation takes place. - - The separation strategy relies on the mutable separation observer instance. - """ - - def __init__(self, observer: SeparationObserver) -> None: - self.observer = observer - - @abstractmethod - def can_separate(self, case: RegressionCase) -> bool: - """Check if the separator strategy can separate the specified regression case - - Args: - case (RegressionCase): The regression case to check - - Returns: - bool: True if the strategy allows separation, False otherwise - """ - - @abstractmethod - def separate(self, case: RegressionCase, category: Category) -> None: - """Separate the regression case - - Args: - case (RegressionCase): The regression case to separate - category (Category): The category to separate to - """ - - def reset(self) -> None: - """Allows resetting the state of the separator strategy.""" - self.observer.reset() - - -class SeparateToFirst(SeparatorStrategy): - """Separator strategy that separates each case to its first match. - - That is to say, any subsequently matching categories are ignored. - This means that no regression case gets duplicated. - It also means that the number of cases in all categories will be the - same as the initial number of cases. 
- """ - - def can_separate(self, case: RegressionCase) -> bool: - return not self.observer.has_observed(case) - - def separate(self, case: RegressionCase, category: Category) -> None: - if self.observer.has_observed(case): - raise ValueError(f"Case {case} has already been observed") - self.observer.observe(case, category) - - -class SeparateToAll(SeparatorStrategy): - """A separator strateg that allows separation to all matching categories. - - This means that when one regression case fits into multiple categories, - it will be saved in each such category. I.e the some cases may be - duplicated. - """ - - def can_separate(self, case: RegressionCase) -> bool: - return True - - def separate(self, case: RegressionCase, category: Category) -> None: - self.observer.observe(case, category) - - -def get_random_str(length=8): - return ''.join(random.choices(string.ascii_letters, k=length)) - - -class RegressionCheckerSeparator(pydantic.BaseModel): - """Regression checker separtor. - - It is able to separate cases in a regression checker - into multiple different sets of regression cases - based on the given list of categories and the specified - strategy. - - Args: - categories(List[Category]): The categories to separate into - strategy(SeparatorStrategy): The strategy for separation - overflow_category(bool): Whether to use an overflow category for cases that don't fit in other categoreis. Defaults to False. - """ - - categories: List[Category] - strategy: SeparatorStrategy - overflow_category: bool = False - - class Config: - arbitrary_types_allowed = True - - def _attempt_category_for(self, cat: Category, case: RegressionCase): - if cat.fits(case) and self.strategy.can_separate(case): - self.strategy.separate(case, cat) - - def find_categories_for(self, case: RegressionCase): - """Find the categories for a specific regression case - - Args: - case (RegressionCase): The regression case to check - - Raises: - ValueError: If no category found. 
- """ - for cat in self.categories: - self._attempt_category_for(cat, case) - if not self.strategy.observer.has_observed(case) and self.overflow_category: - anything_goes = AnyPartOfCategory( - f'overflow-{get_random_str()}', descr=CategoryDescription.anything_goes()) - self.categories.append(anything_goes) - self._attempt_category_for(anything_goes, case) - logger.info( - "Created overflow category since not all cases fit in specified categories") - logger.info("The overflow category is named: %s", - anything_goes.name) - if not self.strategy.observer.has_observed(case): - raise ValueError("Anything-goes category should be sufficient") - - def separate(self, checker: RegressionChecker) -> None: - """Separate the specified regression checker into multiple sets of cases. - - Each case may be associated with either no, one, or multiple categories. - The specifics depends on `allow_overflow` and `strategy`. - - Args: - checker(RegressionChecker): The input regression checker - """ - for case in checker.cases: - self.find_categories_for(case) - - def save(self, prefix: str, metadata: MetaData, overwrite: bool = False) -> None: - """Save the results of the separation in different files. - - This needs to be called after the `separate` method has been called. - - Each separated category (that has any cases registered to it) will - be saved in a separate file with the specified predix and the category name. - - Args: - prefix (str): The prefix for the saved file(s) - metadata (MetaData): The metadata for the regression suite - overwrite (bool): Whether to overwrite file(s) if/when needed. Defaults to False. 
- - Raises: - ValueError: If the method is called before separation or no separtion was done - ValueError: If a file already exists and is not allowed to be overwritten - """ - if not self.strategy.observer.separated: # empty - raise ValueError("Need to do separation before saving!") - for category, cases in self.strategy.observer.separated.items(): - rc = RegressionChecker(list(cases), metadata=metadata) - yaml_str = rc.to_yaml() - yaml_file_name = f"{prefix}_{category.name}.yml" - if not overwrite and os.path.exists(yaml_file_name): - raise ValueError(f"File already exists: {yaml_file_name}. " - "Pass overwrite=True to overwrite") - logger.info("Writing %d cases to %s", len(cases), yaml_file_name) - with open(yaml_file_name, 'w') as f: - f.write(yaml_str) - - -def get_strategy(strategy_type: StrategyType) -> SeparatorStrategy: - """Get the separator strategy from the strategy type. - - Args: - strategy_type (StrategyType): The type of strategy - - Raises: - ValueError: If an unknown strategy is provided - - Returns: - SeparatorStrategy: The resulting separator strategys - """ - observer = SeparationObserver() - if strategy_type == StrategyType.FIRST: - return SeparateToFirst(observer) - elif strategy_type == StrategyType.ALL: - return SeparateToAll(observer) - else: - raise ValueError(f"Unknown strategy type {strategy_type}") - - -def get_separator(categories: List[Category], strategy_type: StrategyType, - overflow_category: bool = False) -> RegressionCheckerSeparator: - """Get the regression checker separator for the list of categories and the specified strategy. - - Args: - categories (List[Category]): The list of categories to include - strategy_type (StrategyType): The strategy for separation - overflow_category (bool): Whether to use an overflow category for items that don't go in other categories. Defaults to False. 
- - Returns: - RegressionCheckerSeparator: The resulting separator - """ - strategy = get_strategy(strategy_type) - return RegressionCheckerSeparator(categories=categories, strategy=strategy, overflow_category=overflow_category) - - -def get_description(cat_description: dict) -> CategoryDescription: - """Get the description from its dict representation. - - The dict is expected to have the following keys: - 'cuis', 'tuis', and 'names' - Each one should have a list of strings as their values. - - Args: - cat_description (dict): The dict representation - - Returns: - CategoryDescription: The resulting category description - """ - cuis = set(cat_description['cuis']) - names = set(cat_description['names']) - tuis = set(cat_description['tuis']) - return CategoryDescription(target_cuis=cuis, target_names=names, target_tuis=tuis) - - -def get_category(cat_name: str, cat_description: dict) -> Category: - """Get the category of the specified name from the dict. - - The dict is expected to be in the form: - type: # either any or all - cuis: [] # list of CUIs in category - names: [] # list of names in category - tuis: [] # list of type IDs in category - - Args: - cat_name (str): The name of the category - cat_description (dict): The dict describing the category - - Raises: - ValueError: If an unknown type is specified. - - Returns: - Category: The resulting category - """ - description = get_description(cat_description) - cat_type = cat_description['type'] - if cat_type.lower() in ('any', 'anyparts', 'anypartsof'): - return AnyPartOfCategory(cat_name, description) - elif cat_type.lower() in ('all', 'allparts'): - return AllPartsCategory(cat_name, description) - else: - raise ValueError( - f"Unknown category type: {cat_type} for category '{cat_name}'") - - -def read_categories(yaml_file: str) -> List[Category]: - """Read categories from a YAML file. - - The yaml is assumed to be in the format: - categories: - category-name: - type: - cuis: [, , ...] - names: [, , ...] 
- tuis: [, , ...] - other-category-name: - ... # and so on - - Args: - yaml_file (str): The yaml file location - - Returns: - List[Category]: The resulting categories - """ - with open(yaml_file) as f: - d = yaml.safe_load(f) - cat_part = d['categories'] - return [get_category(cat_name, cat_part[cat_name]) for cat_name in cat_part] - - -def separate_categories(category_yaml: str, strategy_type: StrategyType, - regression_suite_yaml: str, target_file_prefix: str, overwrite: bool = False, - overflow_category: bool = False) -> None: - """Separate categories based on simple input. - - The categories are read from the provided file and - the regression suite from its corresponding yaml. - The separated regression suites are saved in accordance - to the defined prefix. - - Args: - category_yaml (str): The name of the YAML file describing the categories - strategy_type (StrategyType): The strategy for separation - regression_suite_yaml (str): The regression suite YAML - target_file_prefix (str): The target file prefix - overwrite (bool): Whether to overwrite file(s) if/when needed. Defaults to False. - overflow_category (bool): Whether to use an overflow category for items that don't go in other categories. Defaults to False. - """ - separator = get_separator(read_categories( - category_yaml), strategy_type, overflow_category) - checker = RegressionChecker.from_yaml(regression_suite_yaml) - separator.separate(checker) - metadata = checker.metadata # TODO - allow using different metadata? 
- separator.save(target_file_prefix, metadata, overwrite=overwrite) diff --git a/medcat/utils/regression/category_separator.py b/medcat/utils/regression/category_separator.py deleted file mode 100644 index e5e1495e2..000000000 --- a/medcat/utils/regression/category_separator.py +++ /dev/null @@ -1,56 +0,0 @@ -import argparse -import logging -from pathlib import Path - -from medcat.utils.regression.category_separation import separate_categories, StrategyType - - -logger = logging.getLogger(__name__) - - -def _prepare_args() -> argparse.Namespace: - """Prepares command line arguments to be used in main(). - - Returns: - argparse.Namespace: The argument namespace. - """ - parser = argparse.ArgumentParser() - # category_yaml: str, strategy_type: StrategyType, - # regression_suite_yaml: str, target_file_prefix: str, overwrite: bool = False - parser.add_argument( - 'categories', help='The categories YAML file', type=Path) - parser.add_argument('regressionsuite', - help='The regression suite YAML file', type=Path) - parser.add_argument( - 'targetprefix', help='The target YAML file prefix', type=Path) - parser.add_argument( - '--strategy', help='The strategy to be used for separation (FIRST or ALL)', - default='ALL', type=str) - parser.add_argument('--silent', '-s', help='Make the operation silent (i.e ignore console output)', - action='store_true') - parser.add_argument('--verbose', '-debug', help='Enable debug/verbose mode', - action='store_true') - parser.add_argument( - '--overwrite', help='Overwrite the target file if it exists', action='store_true') - parser.add_argument( - '--overflow', help='Allow using overflow category', action='store_true') - return parser.parse_args() - - -def main(): - """Runs the category separation according to command line arguments.""" - args = _prepare_args() - if not args.silent: - logger.addHandler(logging.StreamHandler()) - logger.setLevel('INFO') - if args.verbose: - from category_separation import logger as checking_logger - 
checking_logger.addHandler(logging.StreamHandler()) - checking_logger.setLevel('DEBUG') - strategy = StrategyType[args.strategy.upper()] - separate_categories(args.categories, strategy, args.regressionsuite, - args.targetprefix, overwrite=args.overwrite, overflow_category=args.overflow) - - -if __name__ == "__main__": - main() diff --git a/medcat/utils/regression/checking.py b/medcat/utils/regression/checking.py index d219ce9cb..d3c425583 100644 --- a/medcat/utils/regression/checking.py +++ b/medcat/utils/regression/checking.py @@ -1,159 +1,107 @@ -from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, cast +from typing import Any, Dict, Iterator, List, Tuple, Optional import yaml +import json import logging import tqdm import datetime +import os from pydantic import BaseModel, Field from medcat.cat import CAT -from medcat.utils.regression.targeting import CUIWithChildFilter, FilterOptions, FilterType, TypedFilter, TranslationLayer, FilterStrategy - -from medcat.utils.regression.results import FailDescriptor, MultiDescriptor, ResultDescriptor +from medcat.utils.regression.targeting import TranslationLayer, OptionSet +from medcat.utils.regression.targeting import FinalTarget, TargetedPhraseChanger +from medcat.utils.regression.utils import partial_substitute, MedCATTrainerExportConverter +from medcat.utils.regression.results import MultiDescriptor, ResultDescriptor, Finding logger = logging.getLogger(__name__) class RegressionCase(BaseModel): - """A regression case that has a name, defines options, filters and phrases.s + """A regression case that has a name, defines options, filters and phrases. 
""" name: str - options: FilterOptions - filters: List[TypedFilter] + options: OptionSet phrases: List[str] report: ResultDescriptor - def get_all_targets(self, in_set: Iterator[Tuple[str, str]], translation: TranslationLayer) -> Iterator[Tuple[str, str]]: - """Get all applicable targets for this regression case - - Args: - in_set (Iterator[Tuple[str, str]]): The input generator / iterator - translation (TranslationLayer): The translation layer - - Yields: - Iterator[Tuple[str, str]]: The output generator - """ - if len(self.filters) == 1: - yield from self.filters[0].get_applicable_targets(translation, in_set) - return - if self.options.strategy == FilterStrategy.ANY: - for filter in self.filters: - yield from filter.get_applicable_targets(translation, in_set) - elif self.options.strategy == FilterStrategy.ALL: - cur_gen = in_set - for filter in self.filters: - cur_gen = filter.get_applicable_targets(translation, cur_gen) - yield from cur_gen - - def check_specific_for_phrase(self, cat: CAT, cui: str, name: str, phrase: str, - translation: TranslationLayer) -> bool: + def check_specific_for_phrase(self, cat: CAT, target: FinalTarget, + translation: TranslationLayer) -> Tuple[Finding, Optional[str]]: """Checks whether the specific target along with the specified phrase is able to be identified using the specified model. Args: cat (CAT): The model - cui (str): The target CUI - name (str): The target name - phrase (str): The phrase to check + target (FinalTarget): The final target configuration translation (TranslationLayer): The translation layer + Raises: + MalformedRegressionCaseException: If there are too many placeholders in phrase. 
+ Returns: - bool: Whether or not the target was correctly identified + Tuple[Finding, Optional[str]]: The nature to which the target was (or wasn't) identified """ - res = cat.get_entities(phrase % name, only_cui=False) + phrase, cui, name, placeholder = target.final_phrase, target.cui, target.name, target.placeholder + nr_of_placeholders = phrase.count(placeholder) + if nr_of_placeholders != 1: + raise MalformedRegressionCaseException(f"Got {nr_of_placeholders} placeholders " + f"({placeholder}) (expected 1) for phrase: " + + phrase) + ph_start = phrase.find(placeholder) + res = cat.get_entities(phrase.replace(placeholder, name), only_cui=False) ents = res['entities'] - found_cuis = [ents[nr]['cui'] for nr in ents] - success = cui in found_cuis - fail_reason: Optional[FailDescriptor] - if success: + finding = Finding.determine(cui, ph_start, ph_start + len(name), + translation, ents) + if finding is Finding.IDENTICAL: logger.debug( 'Matched test case %s in phrase "%s"', (cui, name), phrase) - fail_reason = None else: - fail_reason = FailDescriptor.get_reason_for(cui, name, res, - translation) + found_cuis = [ents[nr]['cui'] for nr in ents] found_names = [ents[nr]['source_value'] for nr in ents] cuis_names = ', '.join([f'{fcui}|{fname}' for fcui, fname in zip(found_cuis, found_names)]) logger.debug( - 'FAILED to match (%s) test case %s in phrase "%s", ' - 'found the following CUIS/names: %s', fail_reason, (cui, name), phrase, cuis_names) - self.report.report(cui, name, phrase, - success, fail_reason) - return success - - def _get_all_cuis_names_types(self) -> Tuple[Set[str], Set[str], Set[str]]: - cuis = set() - names = set() - types = set() - for filt in self.filters: - if filt.type == FilterType.CUI: - cuis.update(filt.values) - elif filt.type == FilterType.CUI_AND_CHILDREN: - cuis.update(cast(CUIWithChildFilter, filt).delegate.values) - if filt.type == FilterType.NAME: - names.update(filt.values) - if filt.type == FilterType.TYPE_ID: - types.update(filt.values) 
- return cuis, names, types - - def get_all_subcases(self, translation: TranslationLayer) -> Iterator[Tuple[str, str, str]]: - """Get all subcases for this case. - That is, all combinations of targets with their appropriate phrases. + 'FAILED to (fully) match (%s) test case %s in phrase "%s", ' + 'found the following CUIS/names: %s', finding, (cui, name), phrase, cuis_names) + self.report.report(target, finding) + return finding - Args: - translation (TranslationLayer): The translation layer + def estimate_num_of_diff_subcases(self) -> int: + return len(self.phrases) * self.options.estimate_num_of_subcases() - Yields: - Iterator[Tuple[str, str, str]]: The generator for the target info and the phrase - """ - cntr = 0 - for cui, name in self.get_all_targets(translation.all_targets(*self._get_all_cuis_names_types()), translation): - for phrase in self.phrases: - cntr += 1 - yield cui, name, phrase - if not cntr: - for cui, name in self._get_specific_cui_and_name(): - for phrase in self.phrases: - yield cui, name, phrase - - def _get_specific_cui_and_name(self) -> Iterator[Tuple[str, str]]: - if len(self.filters) != 2: - return - if self.options.strategy != FilterStrategy.ALL: - return - f1, f2 = self.filters - if f1.type == FilterType.NAME and f2.type == FilterType.CUI: - name_filter, cui_filter = f1, f2 - elif f2.type == FilterType.NAME and f1.type == FilterType.CUI: - name_filter, cui_filter = f2, f1 - else: - return - # There should only ever be one for the ALL strategty - # because otherwise a match is impossible - for name in name_filter.values: - for cui in cui_filter.values: - yield cui, name + def get_distinct_cases(self, translation: TranslationLayer) -> Iterator[Iterator[FinalTarget]]: + """Gets the various distinct sub-case iterators. - def check_case(self, cat: CAT, translation: TranslationLayer) -> Tuple[int, int]: - """Check the regression case against a model. - I.e check all its applicable targets. 
+        The sub-cases are those that can be determined without the translation layer.
+        However, the translation layer is included here since it streamlines the operation.
 
         Args:
-            cat (CAT): The CAT instance
-            translation (TranslationLayer): The translation layer
+            translation (TranslationLayer): The translation layer.
 
-        Returns:
-            Tuple[int, int]: Number of successes and number of failures
+        Yields:
+            Iterator[Iterator[FinalTarget]]: The iterator of iterators of different sub-cases.
         """
-        success = 0
-        fail = 0
-        for cui, name, phrase in self.get_all_subcases(translation):
-            if self.check_specific_for_phrase(cat, cui, name, phrase, translation):
-                success += 1
-            else:
-                fail += 1
-        return success, fail
+        # for each phrase and for each placeholder based option
+        for changer in self.options.get_preprocessors_and_targets(translation):
+            for phrase in self.phrases:
+                yield self._get_subcases(phrase, changer, translation)
+
+    def _get_subcases(self, phrase: str, changer: TargetedPhraseChanger,
+                      translation: TranslationLayer) -> Iterator[FinalTarget]:
+        cui, placeholder = changer.cui, changer.placeholder
+        changed_phrase = changer.changer(phrase)
+        for name in translation.get_names_of(cui, changer.onlyprefnames):
+            num_of_phs = changed_phrase.count(placeholder)
+            if num_of_phs == 1:
+                yield FinalTarget(placeholder=placeholder,
+                                  cui=cui, name=name,
+                                  final_phrase=changed_phrase)
+                continue
+            for cntr in range(num_of_phs):
+                final_phrase = partial_substitute(changed_phrase, placeholder, name, cntr)
+                yield FinalTarget(placeholder=placeholder,
+                                  cui=cui, name=name,
+                                  final_phrase=final_phrase)
 
     def to_dict(self) -> dict:
         """Converts the RegressionCase to a dict for serialisation.
@@ -163,9 +111,6 @@ def to_dict(self) -> dict: """ d: Dict[str, Any] = {'phrases': list(self.phrases)} targeting = self.options.to_dict() - targeting['filters'] = {} - for filt in self.filters: - targeting['filters'].update(filt.to_dict()) d['targeting'] = targeting return d @@ -173,47 +118,38 @@ def to_dict(self) -> dict: def from_dict(cls, name: str, in_dict: dict) -> 'RegressionCase': """Construct the regression case from a dict. - The expected stucture: + The expected structure: { 'targeting': { - 'strategy': 'ALL', # optional - 'prefname-only': 'false', # optional - 'filters': { - : , # possibly multiple - } + [ + 'placeholder': '[DIAGNOSIS]' # the placeholder to be replaced + 'cuis': ['cui1', 'cui2'] + 'prefname-only': 'false', # optional + ] }, 'phrases': ['phrase %s'] # possible multiple } - Parsing the different parts of are delegated to - other methods within the relevant classes. - Delegators include: FilterOptions, TypedFilter - Args: name (str): The name of the case in_dict (dict): The dict describing the case Raises: ValueError: If the input dict does not have the 'targeting' section - ValueError: If the 'targeting' section does not have a 'filters' section ValueError: If there are no phrases defined Returns: - RegressionCase: The constructed regression case + RegressionCase: The constructed regression cases. 
""" # set up targeting if 'targeting' not in in_dict: raise ValueError('Input dict should define targeting') targeting_section = in_dict['targeting'] # set up options - options = FilterOptions.from_dict(targeting_section) - if 'filters' not in targeting_section: - raise ValueError( - 'Input dict should have define targets section under targeting') - # set up targets - parsed_filters: List[TypedFilter] = TypedFilter.from_dict( - targeting_section['filters']) - # set up test phrases + options = OptionSet.from_dict(targeting_section) + # all_cases: List['RegressionCase'] = [] + # for option in options: + # # set up test phrases if 'phrases' not in in_dict: raise ValueError('Input dict should defined phrases') phrases = in_dict['phrases'] @@ -221,7 +157,7 @@ def from_dict(cls, name: str, in_dict: dict) -> 'RegressionCase': phrases = [phrases] # just one defined if not phrases: raise ValueError('Need at least one target phrase') - return RegressionCase(name=name, options=options, filters=parsed_filters, + return RegressionCase(name=name, options=options, phrases=phrases, report=ResultDescriptor(name=name)) def __hash__(self) -> int: @@ -244,7 +180,7 @@ def get_ontology_and_version(model_card: dict) -> Tuple[str, str]: That is, unless the specified location does not exist in the model card, in which case 'Unknown' is returned. - The ontology is assumed to be descibed at: + The ontology is assumed to be described at: model_card['Source Ontology'][0] (or model_card['Source Ontology'] if it's a string instead of a list) The ontology version is read from: @@ -282,7 +218,7 @@ def get_ontology_and_version(model_card: dict) -> Tuple[str, str]: class MetaData(BaseModel): - """The metadat for the regression suite. + """The metadata for the regression suite. This should define which ontology (e.g UMLS or SNOMED) as well as which version was used when generating the regression suite. 
@@ -321,10 +257,10 @@ def unknown(self) -> 'MetaData': def fix_np_float64(d: dict) -> None: - """Fix numpy.float64 in dictrionary for yaml saving purposes. + """Fix numpy.float64 in dictionary for yaml saving purposes. These types of objects are unable to be cleanly serialized using yaml. - So we need to conver them to the corresponding floats. + So we need to convert them to the corresponding floats. The changes will be made within the dictionary itself as well as dictionaries within, recursively. @@ -340,7 +276,7 @@ def fix_np_float64(d: dict) -> None: fix_np_float64(v) -class RegressionChecker: +class RegressionSuite: """The regression checker. This is used to check a bunch of regression cases at once against a model. @@ -350,52 +286,69 @@ class RegressionChecker: use_report (bool): Whether or not to use the report functionality (defaults to False) """ - def __init__(self, cases: List[RegressionCase], metadata: MetaData) -> None: + def __init__(self, cases: List[RegressionCase], metadata: MetaData, name: str) -> None: self.cases: List[RegressionCase] = cases - self.report = MultiDescriptor(name='ALL') # TODO - allow setting names + self.report = MultiDescriptor(name=name) self.metadata = metadata for case in self.cases: self.report.parts.append(case.report) - def get_all_subcases(self, translation: TranslationLayer) -> Iterator[Tuple[RegressionCase, str, str, str]]: - """Get all subcases (i.e regssion case, target info and phrase) for this checker. + def get_all_distinct_cases(self, translation: TranslationLayer + ) -> Iterator[Tuple[RegressionCase, Iterator[FinalTarget]]]: + """Gets all the distinct cases for this regression suite. + + While distinct cases can be determined without the translation layer, + including it here simplifies the process. Args: - translation (TranslationLayer): The translation layer + translation (TranslationLayer): The translation layer. 
         Yields:
-            Iterator[Tuple[RegressionCase, str, str, str]]: The generator for all the cases
+            Iterator[Tuple[RegressionCase, Iterator[FinalTarget]]]: The generator of the
+                regression case along with its corresponding sub-cases.
         """
-        for case in self.cases:
-            for cui, name, phrase in case.get_all_subcases(translation):
-                yield case, cui, name, phrase
+        for regr_case in self.cases:
+            for subcase in regr_case.get_distinct_cases(translation):
+                yield regr_case, subcase
+
+    def estimate_total_distinct_cases(self) -> int:
+        return sum(rc.estimate_num_of_diff_subcases() for rc in self.cases)
 
-    def check_model(self, cat: CAT, translation: TranslationLayer,
-                    total: Optional[int] = None) -> MultiDescriptor:
+    def iter_subcases(self, translation: TranslationLayer,
+                      show_progress: bool = True,
+                      ) -> Iterator[Tuple[RegressionCase, FinalTarget]]:
+        """Iterate over all the sub-cases.
+
+        Each sub-case presents a unique target (phrase, concept, name) on
+        the corresponding regression case.
+
+        Args:
+            translation (TranslationLayer): The translation layer.
+            show_progress (bool): Whether to show progress. Defaults to True.
+
+        Yields:
+            Iterator[Tuple[RegressionCase, FinalTarget]]: The generator of the
+            regression case along with each of the final target sub-cases.
+ """ + total = self.estimate_total_distinct_cases() + for (regr_case, subcase) in tqdm.tqdm(self.get_all_distinct_cases(translation), + total=total, disable=not show_progress): + for target in subcase: + yield regr_case, target + + def check_model(self, cat: CAT, translation: TranslationLayer) -> MultiDescriptor: """Checks model and generates a report Args: cat (CAT): The model to check against translation (TranslationLayer): The translation layer - total (Optional[int]): The total number of (sub)cases expected (for a progress bar) Returns: MultiDescriptor: A report description """ - successes, fails = 0, 0 - if total is not None: - for case, ti, phrase in tqdm.tqdm(self.get_all_subcases(translation), total=total): - if case.check_specific_for_phrase(cat, ti, phrase, translation): - successes += 1 - else: - fails += 1 - else: - for case in tqdm.tqdm(self.cases): - for cui, name, phrase in case.get_all_subcases(translation): - if case.check_specific_for_phrase(cat, cui, name, phrase, translation): - successes += 1 - else: - fails += 1 + for regr_case, target in self.iter_subcases(translation, True): + # NOTE: the finding is reported in the per-case report + regr_case.check_specific_for_phrase(cat, target, translation) return self.report def __str__(self) -> str: @@ -428,12 +381,12 @@ def to_yaml(self) -> str: def __eq__(self, other: object) -> bool: # only checks cases - if not isinstance(other, RegressionChecker): + if not isinstance(other, RegressionSuite): return False return self.cases == other.cases @classmethod - def from_dict(cls, in_dict: dict) -> 'RegressionChecker': + def from_dict(cls, in_dict: dict, name: str) -> 'RegressionSuite': """Construct a RegressionChecker from a dict. Most of the parsing is handled in RegressionChecker.from_dict. @@ -441,26 +394,27 @@ def from_dict(cls, in_dict: dict) -> 'RegressionChecker': and each value describes a RegressionCase. Args: - in_dict (dict): The input dict + in_dict (dict): The input dict. 
+ name (str): The name of the regression suite. Returns: RegressionChecker: The built regression checker """ - cases = [] + cases: List[RegressionCase] = [] for case_name, details in in_dict.items(): if case_name == 'meta': continue # ignore metadata - case = RegressionCase.from_dict(case_name, details) - cases.append(case) + add_case = RegressionCase.from_dict(case_name, details) + cases.append(add_case) if 'meta' not in in_dict: logger.warn("Loading regression suite without any meta data") metadata = MetaData.unknown() else: metadata = MetaData.parse_obj(in_dict['meta']) - return RegressionChecker(cases=cases, metadata=metadata) + return RegressionSuite(cases=cases, metadata=metadata, name=name) @classmethod - def from_yaml(cls, file_name: str) -> 'RegressionChecker': + def from_yaml(cls, file_name: str) -> 'RegressionSuite': """Constructs a RegressionChcker from a YAML file. The from_dict method is used for the construction from the dict. @@ -473,4 +427,17 @@ def from_yaml(cls, file_name: str) -> 'RegressionChecker': """ with open(file_name) as f: data = yaml.safe_load(f) - return RegressionChecker.from_dict(data) + return RegressionSuite.from_dict(data, name=os.path.basename(file_name)) + + @classmethod + def from_mct_export(cls, file_name: str) -> 'RegressionSuite': + with open(file_name) as f: + data = json.load(f) + converted = MedCATTrainerExportConverter(data).convert() + return RegressionSuite.from_dict(converted, name=os.path.basename(file_name)) + + +class MalformedRegressionCaseException(ValueError): + + def __init__(self, *args: object) -> None: + super().__init__(*args) diff --git a/medcat/utils/regression/converting.py b/medcat/utils/regression/converting.py deleted file mode 100644 index 41d214703..000000000 --- a/medcat/utils/regression/converting.py +++ /dev/null @@ -1,226 +0,0 @@ -import json -import logging -from abc import ABC, abstractmethod -import re -from typing import List, Optional, Set -import tqdm - -from 
medcat.utils.regression.checking import RegressionCase, RegressionChecker, MetaData -from medcat.utils.regression.results import ResultDescriptor -from medcat.utils.regression.targeting import FilterOptions, FilterStrategy, FilterType, TypedFilter - - -logger = logging.getLogger(__name__) - - -class ContextSelector(ABC): - """Describes how the context of a concept is found. - A sub-class should be used as this one has no implementation. - """ - - def _splitter(self, text: str) -> List[str]: - text = re.sub(' +', ' ', text) # remove duplicate spaces - # remove 1-letter words that are not a valid character - return [word for word in text.split() if ( - len(word) > 1 or re.match(r'\w', word))] - - def make_replace_safe(self, text: str) -> str: - """Make the text replace-safe. - That is, wrap all '%' as '%%' so that the `text % replacement` syntax - can be used for an inserted part (and that part only). - - Args: - text (str): The text to use - - Returns: - str: The replace-safe text - """ - return text.replace(r'%', r'%%') - - @abstractmethod - def get_context(self, text: str, start: int, end: int, leave_concept: bool = False) -> str: - """Get the context of a concept within a larger body of text. - The concept is specifiedb by its start and end indices. - - Args: - text (str): The larger text - start (int): The starting index - end (int): The ending index - leave_concept (bool): Whether to leave the concept or replace it by '%s'. Defaults to False - - Returns: - str: The select contexts - """ - pass # should be overwritten by subclass - - -class PerWordContextSelector(ContextSelector): - """Context selector that selects a number of words - from either side of the concept, regardless of punctuation. 
- - Args: - words_before (int): Number of words to select from before concept - words_after (int): Number of words to select from after concepts - """ - - def __init__(self, words_before: int, words_after: int) -> None: - self.words_before = words_before - self.words_after = words_after - - def get_context(self, text: str, start: int, end: int, leave_concept: bool = False) -> str: - words_before = self._splitter(text[:start]) - words_after = self._splitter(text[end:]) - if leave_concept: - concept = text[start:end] - else: - concept = '%s' - before = ' '.join(words_before[-self.words_before:]) - before = self.make_replace_safe(before) - after = ' '.join(words_after[:self.words_after]) - after = self.make_replace_safe(after) - return f'{before} {concept} {after}' - - -class PerSentenceSelector(ContextSelector): - """Context selector that selects a sentence as context. - Sentences are said to end with either ".", "?" or "!". - """ - stoppers = r'\.+|\?+|!+' - - def get_context(self, text: str, start: int, end: int, leave_concept: bool = False) -> str: - text_before = text[:start] - r_last_stopper = re.search(self.stoppers, text_before[::-1]) - if r_last_stopper: - last_stopper = len(text_before) - r_last_stopper.start() - context_before = text_before[last_stopper:] - else: # concept in first sentence - context_before = text_before - text_after = text[end:] - first_stopper = re.search(self.stoppers, text_after) - if first_stopper: - context_after = text_after[:first_stopper.start()] - else: # concept in last sentence - context_after = text_after - if leave_concept: - concept = text[start: end] - else: - concept = '%s' - context_before = self.make_replace_safe(context_before) - context_after = self.make_replace_safe(context_after) - return (context_before + concept + context_after).strip() - - -class UniqueNamePreserver: - """Used to preserver unique names in a set - """ - - def __init__(self) -> None: - self.unique_names: Set[str] = set() - - def name2nrgen(self, 
name: str, nr: int) -> str: - """The method to generate name and copy-number combinations. - - Args: - name (str): The base name - nr (int): The number of the copy - - Returns: - str: The combined name - """ - return f'{name}-{nr}' - - def get_unique_name(self, orig_name: str, dupe_nr: int = 0) -> str: - """Get the unique name of dupe number (at least) as high as specified. - - Args: - orig_name (str): The original / base name - dupe_nr (int): The number of the copy to start from. Defaults to 0. - - Returns: - str: The unique name - """ - if dupe_nr == 0: - cur_name = orig_name - else: - cur_name = self.name2nrgen(orig_name, dupe_nr) - if cur_name not in self.unique_names: - self.unique_names.add(cur_name) - return cur_name - return self.get_unique_name(orig_name, dupe_nr + 1) - - -def get_matching_case(cases: List[RegressionCase], filters: List[TypedFilter]) -> Optional[RegressionCase]: - """Get a case that matches a set of filters (if one exists) from within a list. - - Args: - cases (List[RegressionCase]): The list to look in - filters (List[TypedFilter]): The filters to compare to - - Returns: - Optional[RegressionCase]: The regression case (if found) or None - """ - for case in cases: - if case.filters == filters: - return case - return None - - -def medcat_export_json_to_regression_yml(mct_export_file: str, - cont_sel: ContextSelector = PerSentenceSelector(), - model_card: Optional[dict] = None) -> str: - """Extract regression test cases from a MedCATtrainer export yaml. - This is done based on the context selector specified. - - Args: - mct_export_file (str): The MCT export file path - cont_sel (ContextSelector): The context selector. Defaults to PerSentenceSelector(). 
- model_card (Optional[dict]): The optional model card for generating metadata - - Returns: - str: Extracted regression cases in YAML form - """ - with open(mct_export_file) as f: - data = json.load(f) - fo = FilterOptions(strategy=FilterStrategy.ALL, onlyprefnames=False) - test_cases: List[RegressionCase] = [] - unique_names = UniqueNamePreserver() - for project in tqdm.tqdm(data['projects']): - proj_name = project['name'] - docs = project['documents'] - for doc in tqdm.tqdm(docs): - text = doc['text'] - for ann in tqdm.tqdm(doc['annotations']): - target_name = ann['value'] - target_cui = ann['cui'] - start, end = ann['start'], ann['end'] - in_text_name = text[start: end] - if target_name != in_text_name: - logging.warn('Could not convert annotation since the text was not ' - f' equal to the name, ignoring:\n{ann}') - break - name_filt = TypedFilter(type=FilterType.NAME, - values=[target_name, ]) - cui_filt = TypedFilter(type=FilterType.CUI, - values=[target_cui, ]) - context = cont_sel.get_context(text, start, end) - phrase = context - case_name = unique_names.get_unique_name(f'{proj_name.replace(" ", "-")}-' - f'{target_name.replace(" ", "~")}') - cur_filters = [name_filt, cui_filt] - added_to_existing = False - for prev_rc in test_cases: - if prev_rc.filters == cur_filters: - prev_rc.phrases.append(phrase) - added_to_existing = True - if not added_to_existing: - rc = RegressionCase(name=case_name, options=fo, - filters=cur_filters, phrases=[ - phrase, ], - report=ResultDescriptor(name=case_name)) - test_cases.append(rc) - if model_card: - metadata = MetaData.from_modelcard(model_card) - else: - metadata = MetaData.unknown() - checker = RegressionChecker(cases=test_cases, metadata=metadata) - return checker.to_yaml() diff --git a/medcat/utils/regression/editing.py b/medcat/utils/regression/editing.py deleted file mode 100644 index 563c75a0d..000000000 --- a/medcat/utils/regression/editing.py +++ /dev/null @@ -1,106 +0,0 @@ -import logging -from pathlib import 
Path -from typing import Optional -import yaml - -from medcat.utils.regression.converting import UniqueNamePreserver, get_matching_case -from medcat.utils.regression.checking import RegressionChecker - -logger = logging.getLogger(__name__) - - -def combine_dicts(base_dict: dict, add_dict: dict, in_place: bool = False, ignore_identicals: bool = True) -> dict: - """Combine two dictionaries that define RegressionCheckers. - - The idea is to combine them into one that defines cases from both. - - If two cases have identical filters, their phrases are combined. - - If an additional case has the same name as one in the base dict, - its name is changed before adding it. - - Args: - base_dict (dict): The base dict to which we shall add - add_dict (dict): The additional dict - in_place (bool): Whether or not to modify the existing (base) dict. Defaults to False. - ignore_identicals (bool): Whether to ignore identical cases (otherwise they get duplicated). Defaults to True. - - Returns: - dict: The combined dict - """ - base = RegressionChecker.from_dict(base_dict) - add = RegressionChecker.from_dict(add_dict) - name_preserver = UniqueNamePreserver() - name_preserver.unique_names = { - base_case.name for base_case in base.cases} - for case in add.cases: - existing = get_matching_case(base.cases, case.filters) - if existing: - if ignore_identicals and existing == case: - logger.warning( - 'Found two identical case: %s and %s in base and addon', existing, case) - continue - logging.info( - 'Found existing case (%s), adding phrases: %s', existing, case.phrases) - existing.phrases.extend(case.phrases) - continue - new_name = name_preserver.get_unique_name(case.name) - if new_name != case.name: - logging.info('Renaming case from "%s" to "%s"', - case.name, new_name) - case.name = new_name - logging.info('Adding new case %s', case) - base.cases.append(case) - new_dict = base.to_dict() - if in_place: - base_dict.clear() - base_dict.update(new_dict) - return base_dict - else: - 
return new_dict - - -def combine_contents(base_yaml: str, add_yaml: str, ignore_identicals: bool = True) -> str: - """Combined the contents of two yaml strings that describe RegressionCheckers. - - This method simply loads in teh yamls and uses the `combine_dicts` method. - - Args: - base_yaml (str): The yaml of the base checker - add_yaml (str): The yaml of the additional checker - ignore_identicals (bool): Whether or not to ignore identical cases. Defaults to True. - - Returns: - str: The combined yaml contents - """ - base_dict = yaml.safe_load(base_yaml) - add_dict = yaml.safe_load(add_yaml) - combined_dict = combine_dicts( - base_dict, add_dict, in_place=True, ignore_identicals=ignore_identicals) - return yaml.safe_dump(combined_dict) - - -def combine_yamls(base_file: str, add_file: str, new_file: Optional[str] = None, ignore_identicals: bool = True) -> str: - """Combined the contents of two yaml files that describe RegressionCheckers. - - This method simply reads the data and uses the `combined_contents` method. - - The results are saved into the new_file (if specified) or to the base_file otherwise. - - Args: - base_file (str): The base file - add_file (str): The additional file - new_file (Optional[str]): The new file name. Defaults to None. - ignore_identicals (bool): Whether or not to ignore identical cases. Defaults to True. 
- - Returns: - str: The new file name - """ - base_yaml = Path(base_file).read_text() - add_yaml = Path(add_file).read_text() - combined_yaml = combine_contents( - base_yaml, add_yaml, ignore_identicals=ignore_identicals) - if new_file is None: - new_file = base_file # overwrite base - Path(new_file).write_text(combined_yaml) - return new_file diff --git a/medcat/utils/regression/mct_converter.py b/medcat/utils/regression/mct_converter.py deleted file mode 100644 index b161180e1..000000000 --- a/medcat/utils/regression/mct_converter.py +++ /dev/null @@ -1,84 +0,0 @@ -import argparse -import logging -import os -from pathlib import Path -from typing import List, Optional -import json - -from medcat.cat import CAT -from medcat.utils.regression.converting import ContextSelector, PerSentenceSelector, PerWordContextSelector, medcat_export_json_to_regression_yml - - -logger = logging.getLogger(__name__) - - -def get_model_card_from_file(model_card_file: str) -> dict: - with open(model_card_file) as f: - return json.load(f) - - -def get_model_card_from_model(model_zip: str) -> dict: - logger.info(f"Loading model from {model_zip} to find the model card - this may take a while") - cat = CAT.load_model_pack(model_zip) - return cat.get_model_card(as_dict=True) - - -def main(mct_export: str, target: str, overwrite: bool = False, - words: Optional[List[int]] = None, model_card_file: Optional[str] = None, - model_file: Optional[str] = None) -> None: - if not overwrite and os.path.isfile(target): - raise ValueError("Not able to overwrite an existingfile, " - "pass '--overwrite' to force an overwrite") - logger.info( - "Starting to convert export JSON to YAML from file %s", mct_export) - cont_sel: ContextSelector - if not words: - cont_sel = PerSentenceSelector() - else: - cont_sel = PerWordContextSelector(*words) - if model_card_file: - model_card = get_model_card_from_file(model_card_file) - elif model_file: - model_card = get_model_card_from_model(model_file) - else: - 
logger.warn("Creating regression suite with no model-card / metadata") - logger.warn("Please consider passing --modelcard or") - logger.warn("--model to find the model card associated with the regression suite") - logger.warn("This will help better understand where and how the regression suite was generated") - model_card = None - yaml = medcat_export_json_to_regression_yml(mct_export, cont_sel=cont_sel, model_card=model_card) - logger.debug("Conversion successful") - logger.info("Saving writing data to %s", target) - with open(target, 'w') as f: - f.write(yaml) - logger.debug("Done saving") - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument( - 'file', help='The MedCATtrainer export file', type=Path) - parser.add_argument('target', help='The Target YAML file', type=Path) - parser.add_argument( - '--modelcard', help='The ModelCard json file', type=Path) - parser.add_argument( - '--model', help='The Model to read model card from', type=Path) - parser.add_argument('--silent', '-s', help='Make the operation silent (i.e ignore console output)', - action='store_true') - parser.add_argument('--verbose', '-debug', help='Enable debug/verbose mode', - action='store_true') - parser.add_argument( - '--overwrite', help='Overwrite the target file if it exists', action='store_true') - parser.add_argument( - '--words', help='Select the number of words to select from before and after the concept', - nargs=2, type=int) - args = parser.parse_args() - if not args.silent: - logger.addHandler(logging.StreamHandler()) - logger.setLevel('INFO') - if args.verbose: - from checking import logger as checking_logger - checking_logger.addHandler(logging.StreamHandler()) - checking_logger.setLevel('DEBUG') - main(args.file, args.target, overwrite=args.overwrite, words=args.words, - model_card_file=args.modelcard, model_file=args.model) diff --git a/medcat/utils/regression/regression_checker.py b/medcat/utils/regression/regression_checker.py index 
98361858b..5cb743734 100644 --- a/medcat/utils/regression/regression_checker.py +++ b/medcat/utils/regression/regression_checker.py @@ -6,33 +6,79 @@ from typing import Optional from medcat.cat import CAT -from medcat.utils.regression.checking import RegressionChecker, TranslationLayer +from medcat.utils.regression.checking import RegressionSuite, TranslationLayer +from medcat.utils.regression.results import Strictness, Finding, STRICTNESS_MATRIX logger = logging.getLogger(__name__) +DEFAULT_TEST_SUITE_PATH = Path('configs', 'default_regression_tests.yml') + + +def show_description(): + logger.info('The various findings and their descriptions:') + logger.info('') + logger.info('Class description:') + logger.info('') + logger.info(Finding.__doc__.replace("\n ", "\n")) + logger.info('') + for f in Finding: + logger.info('%s :', f.name) + logger.info(f.__doc__.replace("\n ", "\n")) + logger.info('') + logger.info('The strictnesses we have available:') + logger.info('') + for strictness in Strictness: + allows = [s.name for s in STRICTNESS_MATRIX[strictness]] + logger.info('%s: allows %s', strictness.name, allows) + logger.info('') + logger.info('NOTE: When using --example-strictness, anything described above ' + 'will be omitted from examples (since the are considered correct)') + + def main(model_pack_dir: Path, test_suite_file: Path, - total: Optional[int] = None, phrases: bool = False, hide_empty: bool = False, - hide_failures: bool = False, + examples_strictness_str: str = 'STRICTEST', jsonpath: Optional[Path] = None, overwrite: bool = False, - jsonindent: Optional[int] = None) -> None: + jsonindent: Optional[int] = None, + strictness_str: str = 'NORMAL', + max_phrase_length: int = 80, + use_mct_export: bool = False, + mct_export_yaml_path: Optional[str] = None, + only_mct_export_conversion: bool = False, + only_describe: bool = False, + require_fully_correct: bool = False) -> None: """Check test suite against the specifeid model pack. 
Args: model_pack_dir (Path): The path to the model pack test_suite_file (Path): The path to the test suite YAML - total (Optional[int]): The total number of (sub)cases to be tested (for progress bar) phrases (bool): Whether to show per-phrase information in a report hide_empty (bool): Whether to hide empty cases in a report - hide_failures (bool): Whether to hide failures in a report + examples_strictness_str (str): The example strictness string. Defaults to STRICTEST. + NOTE: If you set this to 'None', examples will be omitted. jsonpath (Optional[Path]): The json path to save the report to (if specified) overwrite (bool): Whether to overwrite the file if it exists. Defaults to False jsonindent (int): The indentation for json objects. Defaults to 0 + strictness_str (str): The strictness name. Defaults to NORMAL. + max_phrase_length (int): The maximum phrase length in examples. Defaults to 80. + use_mct_export (bool): Whether to use a MedCATtrainer export as input. Defaults to False. + mct_export_yaml_path (str): The (optional) path the converted MCT export should be saved as YAML at. + If not set (or None), the MCT export is not saved in YAML format. Defaults to None. + only_mct_export_conversion (bool): Whether to only deal with the MCT export conversion. + I.e exit when MCT export conversion is done. Defaults to False. + only_describe (bool): Whether to only describe the finding options and exit. + Defaults to False. + require_fully_correct (bool): Whether all cases are required to be correct. + If set to True, an exit-status of 1 is returned unless all (sub)cases are correct. + Defaults to False. Raises: ValueError: If unable to overwrite file or folder does not exist. 
""" + if only_describe: + show_description() + return if jsonpath and jsonpath.exists() and not overwrite: # check before doing anything so as to not waste time on the tests raise ValueError( @@ -41,17 +87,40 @@ def main(model_pack_dir: Path, test_suite_file: Path, raise ValueError( f'Need to specify a file in an existing directory, folder not found: {str(jsonpath)}') logger.info('Loading RegressionChecker from yaml: %s', test_suite_file) - rc = RegressionChecker.from_yaml(str(test_suite_file)) + if not use_mct_export: + rc = RegressionSuite.from_yaml(str(test_suite_file)) + else: + rc = RegressionSuite.from_mct_export(str(test_suite_file)) + if mct_export_yaml_path: + logger.info('Writing MCT export in YAML to %s', str(mct_export_yaml_path)) + with open(mct_export_yaml_path, 'w') as f: + f.write(rc.to_yaml()) + if only_mct_export_conversion: + logger.info("Done with conversion - exiting") + return logger.info('Loading model pack from file: %s', model_pack_dir) cat: CAT = CAT.load_model_pack(str(model_pack_dir)) logger.info('Checking the current status') - res = rc.check_model(cat, TranslationLayer.from_CDB(cat.cdb), total=total) + res = rc.check_model(cat, TranslationLayer.from_CDB(cat.cdb)) + strictness = Strictness[strictness_str] + if examples_strictness_str in ("None", "N/A"): + examples_strictness = None + else: + examples_strictness = Strictness[examples_strictness_str] if jsonpath: logger.info('Writing to %s', str(jsonpath)) - jsonpath.write_text(json.dumps(res.dict(), indent=jsonindent)) + jsonpath.write_text(json.dumps(res.dict(strictness=examples_strictness), + indent=jsonindent)) else: logger.info(res.get_report(phrases_separately=phrases, - hide_empty=hide_empty, show_failures=not hide_failures)) + hide_empty=hide_empty, examples_strictness=examples_strictness, + strictness=strictness, phrase_max_len=max_phrase_length)) + if require_fully_correct: + total, success = res.calculate_report(phrases_separately=phrases, + hide_empty=hide_empty, 
examples_strictness=examples_strictness, + strictness=strictness, phrase_max_len=max_phrase_length)[:2] + if total != success: + exit(1) if __name__ == '__main__': @@ -59,30 +128,47 @@ def main(model_pack_dir: Path, test_suite_file: Path, parser.add_argument('modelpack', help='The model pack against which to check', type=Path) parser.add_argument('test_suite', help='YAML formatted file containing the regression test suite' - 'The default value (and exampe) is at `configs/default_regression_tests.yml`', - default=Path( - 'configs', 'default_regression_tests.yml'), - nargs='?', - type=Path) + f'The default value (and example) is at `{DEFAULT_TEST_SUITE_PATH}`', + default=DEFAULT_TEST_SUITE_PATH, nargs='?', type=Path) parser.add_argument('--silent', '-s', help='Make the operation silent (i.e ignore console output)', action='store_true') parser.add_argument('--verbose', '-debug', help='Enable debug/verbose mode', action='store_true') - parser.add_argument('--total', '-t', help='Set the total number of (sub)cases that will be tested. ' - 'This will enable using a progress bar. ' - 'If unknown, a large-ish number might still be beneficial to show progress.', type=int, default=None) parser.add_argument('--phrases', '-p', help='Include per-phrase information in report', action='store_true') parser.add_argument('--noempty', help='Hide empty cases in report', action='store_true') - parser.add_argument('--hidefailures', help='Hide failed cases in report', - action='store_true') + parser.add_argument('--example-strictness', help='The strictness of examples. Set to None to disable. ' + 'This defaults to STRICTEST to show all non-identical examples. 
', + choices=[strictness.name for strictness in Strictness] + ["None"], + default=Strictness.STRICTEST.name) parser.add_argument('--jsonfile', help='Save report to a json file', type=Path) parser.add_argument('--overwrite', help='Whether to overwrite save file', action='store_true') parser.add_argument('--jsonindent', help='The json indent', type=int, default=None) + parser.add_argument('--strictness', help='The strictness to consider success.', + choices=[strictness.name for strictness in Strictness], + default=Strictness.NORMAL.name) + parser.add_argument('--max-phrase-length', help='The maximum phrase length in examples.', + type=int, default=80) + parser.add_argument('--from-mct-export', help='Whether to load the regression suite from ' + 'a MedCATtrainer export (.json) instead of a YAML format (default).', + action='store_true') + parser.add_argument('--mct-export-yaml', help='The YAML file path to safe a convert MCT ' + 'export as. Only useful alongside `--from-mct-export` option and an ' + 'MCT export passed as the test suite.', + type=str, default=None) + parser.add_argument('--only-conversion', help='Whether to load only deal with the MCT export ' + 'conversion. Only useful alongside `--from-mct-export` and `--mct-export-yaml`', + action='store_true') + parser.add_argument('--only-describe', help='Only describe the various findings and exit.', + action='store_true') + parser.add_argument('--require-fully-correct', help='Require the regression test to be fully correct. ' + 'If set, a non-zero exit status is returned unless all cases are successful (100%). 
' + 'This can be useful for (e.g) CI workflow integration.', + action='store_true') args = parser.parse_args() if not args.silent: logger.addHandler(logging.StreamHandler()) @@ -91,6 +177,10 @@ def main(model_pack_dir: Path, test_suite_file: Path, from medcat.utils.regression import logger as regr_logger regr_logger.setLevel('DEBUG') regr_logger.addHandler(logging.StreamHandler()) - main(args.modelpack, args.test_suite, total=args.total, - phrases=args.phrases, hide_empty=args.noempty, hide_failures=args.hidefailures, - jsonpath=args.jsonfile, overwrite=args.overwrite, jsonindent=args.jsonindent) + main(args.modelpack, args.test_suite, + phrases=args.phrases, hide_empty=args.noempty, examples_strictness_str=args.example_strictness, + jsonpath=args.jsonfile, overwrite=args.overwrite, jsonindent=args.jsonindent, + strictness_str=args.strictness, max_phrase_length=args.max_phrase_length, + use_mct_export=args.from_mct_export, mct_export_yaml_path=args.mct_export_yaml, + only_mct_export_conversion=args.only_conversion, only_describe=args.only_describe, + require_fully_correct=args.require_fully_correct) diff --git a/medcat/utils/regression/results.py b/medcat/utils/regression/results.py index 639d889a7..421ec217a 100644 --- a/medcat/utils/regression/results.py +++ b/medcat/utils/regression/results.py @@ -1,128 +1,361 @@ -from enum import Enum -from typing import Callable, Dict, List, Optional, Tuple, cast +from enum import Enum, auto +from typing import Dict, List, Optional, Any, Set, Iterable, Tuple +import json import pydantic -from medcat.utils.regression.targeting import TranslationLayer - - -class FailReason(str, Enum): - CONCEPT_NOT_ANNOTATED = 'CONCEPT_NOT_ANNOTATED' - """The concept was not annotated by the model""" - INCORRECT_CUI_FOUND = 'INCORRECT_CUI_FOUND' - """A different CUI with the same name was found""" - INCORRECT_SPAN_BIG = 'INCORRECT_SPAN_BIG' - """The concept was a part of an annotation made by the model""" - INCORRECT_SPAN_SMALL = 
'INCORRECT_SPAN_SMALL' - """Only a part of the concept was annotated""" - CUI_NOT_FOUND = 'CUI_NOT_FOUND' - """The CUI was not found in the context database""" - CUI_PARENT_FOUND = 'CUI_PARENT_FOUND' - """The CUI annotated was the parent of the concept""" - CUI_CHILD_FOUND = 'CUI_CHILD_FOUND' - """The CUI annotated was a child of the concept""" - NAME_NOT_FOUND = 'NAME_NOT_FOUND' - """The name specified was not found in the context database""" - UNKNOWN = 'UNKNOWN' - """Unknown reason for failure""" - - -class FailDescriptor(pydantic.BaseModel): - cui: str - name: str - reason: FailReason - extra: str = '' +from medcat.utils.regression.targeting import TranslationLayer, FinalTarget +from medcat.utils.regression.utils import limit_str_len, add_doc_strings_to_enum + + +class Finding(Enum): + """Describes whether or how the finding verified. + + The idea is that we know where we expect the entity to be recognised + and the enum constants describe how the recognition compared to the + expectation. + + In essence, we want to know the relative positions of the two pairs of + numbers (character numbers): + - Expected Start, Expected End + - Recognised Start, Recognised End + + We can model this as 4 numbers on the number line. And we want to know + their position relative to each other. + For example, if the expected positions are marked with * and recognised + positions with #, we may have something like: + ___*__#_______#*______________ + Which would indicate that there is a partial, but smaller span recognised. + """ + # same CUIs + IDENTICAL = auto() + """The CUI and the span recognised are identical to what was expected.""" + BIGGER_SPAN_RIGHT = auto() + """The CUI is the same, but the recognised span is longer on the right. + + If we use the notation from the class doc string, e.g: + _*#__*__#""" + BIGGER_SPAN_LEFT = auto() + """The CUI is the same, but the recognised span is longer on the left. 
+ + If we use the notation from the class doc string, e.g: + _#_*__*#_""" + BIGGER_SPAN_BOTH = auto() + """The CUI is the same, but the recognised span is longer on both sides. + + If we use the notation from the class doc string, e.g: + _#__*__*__#_""" + SMALLER_SPAN = auto() + """The CUI is the same, but the recognised span is smaller. + + If we use the notation from the class doc string, e.g: + _*_#_#_*_ (neither start nor end match) + _*#_#_*__ (start matches, but end is before expected) + _*__#_#*_ (end matches, but start is after expected)""" + PARTIAL_OVERLAP = auto() + """The CUI is the same, but the span overlaps partially. + + If we use the notation from the class doc string, e.g: + _*_#__*_#_ (starts between expected start and end, but ends beyond) + _#_*_#_*__ (start before expected start, but ends between expected start and end)""" + # slightly different CUIs + FOUND_DIR_PARENT = auto() + """The recognised CUI is a parent of the expected CUI but the span is an exact match.""" + FOUND_DIR_GRANDPARENT = auto() + """The recognised CUI is a grandparent of the expected CUI but the span is an exact match.""" + FOUND_ANY_CHILD = auto() + """The recognised CUI is a child of the expected CUI but the span is an exact match.""" + FOUND_CHILD_PARTIAL = auto() + """The recognised CUI is a child yet the match is only partial (smaller/bigger/partial).""" + FOUND_OTHER = auto() + """Found another CUI in the same span.""" + FAIL = auto() + """The concept was not recognised in any meaningful way.""" + + def has_correct_cui(self) -> bool: + """Whether the finding found the correct concept. + + Returns: + bool: Whether the correct concept was found. 
+ """ + return self in ( + Finding.IDENTICAL, Finding.BIGGER_SPAN_RIGHT, Finding.BIGGER_SPAN_LEFT, + Finding.BIGGER_SPAN_BOTH, Finding.SMALLER_SPAN, Finding.PARTIAL_OVERLAP + ) @classmethod - def get_reason_for(cls, cui: str, name: str, res: dict, translation: TranslationLayer) -> 'FailDescriptor': - """Get the fail reason for the failure of finding the specifeid CUI and name - where the resulting entities are presented. + def determine(cls, exp_cui: str, exp_start: int, exp_end: int, + tl: TranslationLayer, found_entities: Dict[str, Dict[str, Any]], + strict_only: bool = False, + check_children: bool = True, check_parent: bool = True, + check_grandparent: bool = True + ) -> Tuple['Finding', Optional[str]]: + """Determine the finding type based on the input Args: - cui (str): The cui that was expected - name (str): The name that was expected - res (dict): The entities that were annotated - translation (TranslationLayer): The translation layer + exp_cui (str): Expected CUI. + exp_start (int): Expected span start. + exp_end (int): Expected span end. + tl (TranslationLayer): The translation layer. + found_entities (Dict[str, Dict[str, Any]]): The entities found by the model. + strict_only (bool): Whether to use a strict-only mode (either identical or fail). Defaults to False. + check_children (bool): Whether to check the children. Defaults to True. + check_parent (bool): Whether to check for parent(s). Defaults to True. + check_grandparent (bool): Whether to check for grandparent(s). Defaults to True. Returns: - FailDescriptor: The corresponding fail descriptor + Tuple['Finding', Optional[str]]: The type of finding determined, and the alternative. 
""" - def format_matching(matches: List[Tuple[str, str]]) -> str: - return 'Found: ' + ', '.join(f'{mcui}|{mname}' for mcui, mname in matches) - fail_reason: FailReason = FailReason.UNKNOWN # should never remain unknown - extra: str = '' - if cui not in translation.cui2names: - fail_reason = FailReason.CUI_NOT_FOUND - elif name not in translation.name2cuis: - fail_reason = FailReason.NAME_NOT_FOUND - extra = f'Names for concept: {translation.cui2names[cui]}' - else: - ents = res['entities'] - found_cuis = [ents[nr]['cui'] for nr in ents] - found_names = [ents[nr]['source_value'] for nr in ents] - found_children = translation.get_children_of(found_cuis, cui) - found_parents = translation.get_parents_of(found_cuis, cui) - if found_children: - fail_reason = FailReason.CUI_CHILD_FOUND - w_name = [(ccui, found_names[found_cuis.index(ccui)]) - for ccui in found_children] - extra = format_matching(w_name) - elif found_parents: - fail_reason = FailReason.CUI_PARENT_FOUND - w_name = [(ccui, found_names[found_cuis.index(ccui)]) - for ccui in found_parents] - extra = format_matching(w_name) - else: - found_cuis_names = list(zip(found_cuis, found_names)) - - def get_matching(condition: Callable[[str, str], bool]): - return [(found_cui, found_name) - for found_cui, found_name in found_cuis_names - if condition(found_cui, found_name)] - name = name.lower() - same_names = get_matching( - lambda _, fname: fname.lower() == name) - bigger_span = get_matching( - lambda _, fname: name in fname.lower()) - smaller_span = get_matching( - lambda _, fname: fname.lower() in name) - if same_names: - extra = format_matching(same_names) - fail_reason = FailReason.INCORRECT_CUI_FOUND - elif bigger_span: - extra = format_matching(bigger_span) - fail_reason = FailReason.INCORRECT_SPAN_BIG - elif smaller_span: - extra = format_matching(smaller_span) - fail_reason = FailReason.INCORRECT_SPAN_SMALL - else: - fail_reason = FailReason.CONCEPT_NOT_ANNOTATED - return FailDescriptor(cui=cui, name=name, 
reason=fail_reason, extra=extra) + return FindingDeterminer(exp_cui, exp_start, exp_end, + tl, found_entities, strict_only, + check_children, check_parent, check_grandparent).determine() + + +# NOTE: add doc strings to enum constants +add_doc_strings_to_enum(Finding) + + +class FindingDeterminer: + """A helper class to determine the type of finding. + + This is mostly useful to split the responsibilities of + looking at children/parents as well as to keep track of + the already-checked children to avoid infinite recursion + (which could happen in - e.g - a SNOMED model). + + Args: + exp_cui (str): The expected CUI. + exp_start (int): The expected span start. + exp_end (int): The expected span end. + tl (TranslationLayer): The translation layer. + found_entities (Dict[str, Dict[str, Any]]): The entities found by the model. + strict_only (bool): Whether to use strict-only mode (either identical or fail). Defaults to False. + check_children (bool): Whether or not to check the children. Defaults to True. + check_parent (bool): Whether to check for parent(s). Defaults to True. + check_grandparent (bool): Whether to check for granparent(s). Defaults to True. + """ + + def __init__(self, exp_cui: str, exp_start: int, exp_end: int, + tl: TranslationLayer, found_entities: Dict[str, Dict[str, Any]], + strict_only: bool = False, + check_children: bool = True, check_parent: bool = True, + check_grandparent: bool = True,) -> None: + self.exp_cui = exp_cui + self.exp_start = exp_start + self.exp_end = exp_end + self.tl = tl + self.found_entities = found_entities + self.strict_only = strict_only + self.check_children = check_children + self.check_parent = check_parent + self.check_grandparent = check_grandparent + # helper for children to avoid infinite recursion + self._checked_children: Set[str] = set() + + def _determine_raw(self, start: int, end: int) -> Optional[Finding]: + """Determines the raw SPAN-ONLY finding. + + I.e this assumes the concept is appropriate. 
+ It will return None if there is no overlapping span. + + Args: + start (int): The start of the span. + end (int): The end of the span. + + Raises: + MalformedFinding: If the start is greater than the end. + MalformedFinding: If the expected start is greater than the expected end. + + Returns: + Optional[Finding]: The finding, if a match is found. + """ + if end < start: + raise MalformedFinding(f"The end ({end}) is smaller than the start ({start})") + elif self.exp_end < self.exp_start: + raise MalformedFinding(f"The expected end ({self.exp_end}) is " + f"smaller than the expected start ({self.exp_start})") + if self.strict_only: + if start == self.exp_start and end == self.exp_end: + return Finding.IDENTICAL + return None + if start < self.exp_start: + if end < self.exp_start: + return None + elif end < self.exp_end: + return Finding.PARTIAL_OVERLAP # TODO - distinguish[overlap]? + elif end == self.exp_end: + return Finding.BIGGER_SPAN_LEFT + return Finding.BIGGER_SPAN_BOTH + elif start == self.exp_start: + if end < self.exp_end: + return Finding.SMALLER_SPAN # TODO - distinguish[smaller]? + elif end == self.exp_end: + return Finding.IDENTICAL + return Finding.BIGGER_SPAN_RIGHT + elif start > self.exp_start and start <= self.exp_end: + if end < self.exp_end: + return Finding.SMALLER_SPAN # TODO - distinguish[smaller]? + elif end == self.exp_end: + return Finding.SMALLER_SPAN # TODO - distinguish[smaller]? + return Finding.PARTIAL_OVERLAP # TODO - distinguish[overlap]? 
+ # if start > exp_end -> no match + return None + + def _get_strict(self) -> Optional[Finding]: + if not self.found_entities: + return Finding.FAIL + for entity in self.found_entities.values(): + start, end, cui = entity['start'], entity['end'], entity['cui'] + if cui == self.exp_cui: + raw_find = self._determine_raw(start, end) + if raw_find: + return raw_find + if self.strict_only: + return Finding.FAIL + return None + + def _check_parents(self) -> Optional[Tuple[Finding, Optional[str]]]: + parents = self.tl.get_direct_parents(self.exp_cui) + for parent in parents: + finding, wcui = Finding.determine(parent, self.exp_start, self.exp_end, + self.tl, + self.found_entities, + check_children=False, + check_parent=self.check_grandparent, + check_grandparent=False) + if finding is Finding.IDENTICAL: + return Finding.FOUND_DIR_PARENT, parent + if finding is Finding.FOUND_DIR_PARENT: + return Finding.FOUND_DIR_GRANDPARENT, wcui + return None + + def _check_children(self) -> Optional[Tuple[Finding, Optional[str]]]: + children = self.tl.get_direct_children(self.exp_cui) + for child in children: + finding, wcui = Finding.determine(child, self.exp_start, self.exp_end, + self.tl, + self.found_entities, + check_children=True, + check_parent=False, + check_grandparent=False) + if finding in (Finding.IDENTICAL, Finding.FOUND_ANY_CHILD): + alt_cui = child if finding == Finding.IDENTICAL else wcui + return Finding.FOUND_ANY_CHILD, alt_cui + elif finding.has_correct_cui(): + # i.e a partial match with same CUI + return Finding.FOUND_CHILD_PARTIAL, child + elif finding is Finding.FOUND_CHILD_PARTIAL: + return finding, wcui + self._checked_children.add(child) + return None + + def _descr_cui(self, cui: Optional[str]) -> Optional[str]: + if cui is None: + return None + return f"{cui} ({self.tl.get_preferred_name(cui)})" + + def _find_diff_cui(self) -> Optional[Tuple[Finding, str]]: + for entity in self.found_entities.values(): + start, end, cui = entity['start'], entity['end'], 
entity['cui'] + if start == self.exp_start and end == self.exp_end: + return Finding.FOUND_OTHER, cui + return None + + def determine(self) -> Tuple[Finding, Optional[str]]: + """Determine the finding based on the given information. + + First, the strict check is done (either identical or not). + Then, parents are checked (if required). + After that, children are checked (if required). + + Returns: + Tuple[Finding, Optional[str]]: The appropriate finding, and the alternative (if applicable). + """ + finding, cui = self._determine() + # NOTE: the point of this wrapper method is to add the preferred name + # to the CUI in one place and one place only + return finding, self._descr_cui(cui) + + def _determine(self) -> Tuple[Finding, Optional[str]]: + finding = self._get_strict() + if finding is not None: + return finding, None + if self.check_parent: + fpar = self._check_parents() + if fpar is not None: + return fpar + if self.check_children: + self._checked_children.add(self.exp_cui) + fch = self._check_children() + if fch is not None: + return fch + fdcui = self._find_diff_cui() + return fdcui or (Finding.FAIL, None) + + +class Strictness(Enum): + """The total strictness on which to judge the results.""" + STRICTEST = auto() + """The strictest option which only allows identical findings.""" + STRICT = auto() + """A strict option which allows identical or children.""" + NORMAL = auto() + """Normal strictness also allows partial overlaps on target concept and children.""" + LENIENT = auto() + """Lenient stictness also allows parents and grandparents.""" + ANYTHING = auto() + """Anything stricness allows ANY finding. 
+ + This would generally only be relevant when disabling examples + for results descriptors.""" + + +STRICTNESS_MATRIX: Dict[Strictness, Set[Finding]] = { + Strictness.STRICTEST: {Finding.IDENTICAL}, + Strictness.STRICT: {Finding.IDENTICAL, Finding.FOUND_ANY_CHILD}, + Strictness.NORMAL: { + Finding.IDENTICAL, Finding.FOUND_ANY_CHILD, Finding.FOUND_CHILD_PARTIAL, + Finding.BIGGER_SPAN_RIGHT, Finding.BIGGER_SPAN_LEFT, + Finding.BIGGER_SPAN_BOTH, + Finding.SMALLER_SPAN, Finding.PARTIAL_OVERLAP + }, + Strictness.LENIENT: { + Finding.IDENTICAL, Finding.FOUND_ANY_CHILD, + Finding.BIGGER_SPAN_RIGHT, Finding.BIGGER_SPAN_LEFT, + Finding.BIGGER_SPAN_BOTH, + Finding.SMALLER_SPAN, Finding.PARTIAL_OVERLAP, + Finding.FOUND_DIR_PARENT, Finding.FOUND_DIR_GRANDPARENT, + }, + Strictness.ANYTHING: set(Finding), +} class SingleResultDescriptor(pydantic.BaseModel): + """The result descriptor. + + This class is responsible for keeping track of all the + findings (i.e how many were found to be identical) as + well as the examples of the finding on a per-target + basis for further analysis. + """ name: str """The name of the part that was checked""" - success: int = 0 - """Number of successes""" - fail: int = 0 - """Number of failures""" - failures: List[FailDescriptor] = [] + findings: Dict[Finding, int] = {} """The description of failures""" + examples: List[Tuple[FinalTarget, Tuple[Finding, Optional[str]]]] = [] + """The examples of non-perfect alignment.""" - def report_success(self, cui: str, name: str, success: bool, fail_reason: Optional[FailDescriptor]) -> None: - """Report a test case and its successfulness + def report_success(self, target: FinalTarget, found: Tuple[Finding, Optional[str]]) -> None: + """Report a test case and its successfulness. 
Args: - cui (str): The CUI being checked - name (str): The name being checked - success (bool): Whether or not the check was successful - fail_reason (Optional[FailDescriptor]): The reason for the failure (if applicable) + target (FinalTarget): The target configuration + found (Tuple[Finding, Optional[str]]): Whether or not the check was successful """ - if success: - self.success += 1 - else: - self.fail += 1 - self.failures.append(cast(FailDescriptor, fail_reason)) + finding, _ = found + if finding not in self.findings: + self.findings[finding] = 0 + self.findings[finding] += 1 + self.examples.append((target, found)) def get_report(self) -> str: """Get the report associated with this descriptor @@ -130,31 +363,99 @@ def get_report(self) -> str: Returns: str: The report string """ - total = self.success + self.fail - return f"""Tested "{self.name}" for a total of {total} cases: - Success: {self.success:10d} ({100 * self.success / total if total > 0 else 0}%) - Failure: {self.fail:10d} ({100 * self.fail / total if total > 0 else 0}%)""" + total = sum(self.findings.values()) + ret_vals = [f"Tested '{self.name}' for a total of {total} cases:"] + ret_vals.extend([ + f"{f.name:24s}:{self.findings[f]:10d} ({100 * self.findings[f] / total if total > 0 else 0:5.2f}%)" + # NOTE iterating over Finding so the order is the same as in the enum + for f in Finding if f in self.findings + ]) + return "\n".join(ret_vals) + + def dict(self, **kwargs) -> dict: + if 'strictness' in kwargs: + kwargs = kwargs.copy() # so if used elsewhere, keeps the kwarg + strict_raw = kwargs.pop('strictness') + if isinstance(strict_raw, Strictness): + strictness = strict_raw + elif isinstance(strict_raw, str): + strictness = Strictness[strict_raw] + else: + raise ValueError(f"Unknown stircntess specified: {strict_raw}") + else: + strictness = Strictness.NORMAL + # avoid serialising multiple times + if 'exclude' in kwargs and kwargs['exclude'] is not None: + exclude: set = kwargs['exclude'] + else: + 
exclude = set() + kwargs['exclude'] = exclude + exclude.update(('findings', 'examples')) + serialized_dict = { + key.name: value for key, value in self.findings.items() + } + serialized_examples = [ + (ft.dict(**kwargs), (f[0].name, f[1])) for ft, f in self.examples + # only count if NOT in strictness matrix (i.e 'failures') + if f[0] not in STRICTNESS_MATRIX[strictness] + ] + model_dict = super().dict(**kwargs) + model_dict['findings'] = serialized_dict + model_dict['examples'] = serialized_examples + return model_dict + + def json(self, **kwargs) -> str: + d = self.dict(**kwargs) + return json.dumps(d) class ResultDescriptor(SingleResultDescriptor): + """The overarching result descriptor that handles multiple phrases. + + This class keeps track of the results on a per-phrase basis and + can be used to get the overall report and/or iterate over examples. + """ per_phrase_results: Dict[str, SingleResultDescriptor] = {} - def report(self, cui: str, name: str, phrase: str, success: bool, fail_reason: Optional[FailDescriptor]) -> None: + def report(self, target: FinalTarget, finding: Tuple[Finding, Optional[str]]) -> None: """Report a test case and its successfulness Args: - cui (str): The CUI being checked - name (str): The name being checked - phrase (str): The phrase being checked - success (bool): Whether or not the check was successful - fail_reason (Optional[FailDescriptor]): The reason for the failure (if applicable) + target (FinalTarget): The final targe configuration + finding (Tuple[Finding, Optional[str]]): To what extent the concept was recognised """ - super().report_success(cui, name, success, fail_reason) + phrase = target.final_phrase + super().report_success(target, finding) if phrase not in self.per_phrase_results: self.per_phrase_results[phrase] = SingleResultDescriptor( name=phrase) - self.per_phrase_results[phrase].report_success( - cui, name, success, fail_reason) + self.per_phrase_results[phrase].report_success(target, finding) + + def 
iter_examples(self, strictness_threshold: Strictness + ) -> Iterable[Tuple[FinalTarget, Tuple[Finding, Optional[str]]]]: + """Iterate suitable examples. + + The strictness threshold at which to include examples. + + Any finding that is assumed to be "correct enough" according to + the strictness matrix for this threshold will be withheld from + examples. + + In simpler terms, if the finding is NOT in the strictness matrix + for this strictness, the example is recorded. + + NOTE: To disable example keeping, set the threshold to Strictness.ANYTHING. + + Args: + strictness_threshold (Strictness): The strictness threshold. + + Yields: + Iterable[Tuple[FinalTarget, Tuple[Finding, Optional[str]]]]: The placeholder, phrase, finding, CUI, and name. + """ + for srd in self.per_phrase_results.values(): + for target, finding in srd.examples: + if finding[0] not in STRICTNESS_MATRIX[strictness_threshold]: + yield target, finding def get_report(self, phrases_separately: bool = False) -> str: """Get the report associated with this descriptor @@ -172,79 +473,221 @@ def get_report(self, phrases_separately: bool = False) -> str: for srd in self.per_phrase_results.values()]) return sr + '\n\t\t' + children.replace('\n', '\n\t\t') + def dict(self, **kwargs) -> dict: + if 'exclude' in kwargs and kwargs['exclude'] is not None: + exclude: set = kwargs['exclude'] + else: + exclude = set() + kwargs['exclude'] = exclude + # NOTE: ignoring here so that examples are only present in the per phrase part + exclude.update(('examples', 'per_phrase_results')) + d = super().dict(**kwargs) + if 'examples' in d: + # NOTE: I don't really know why, but the examples still + # seem to be a part of the resulting dict, so I need + # to explicitly remove them + del d['examples'] + # NOTE: need to propagate here manually so the strictness keyword + # makes sense and doesn't cause issues due being to unexpected keyword + per_phrase_results = { + phrase: res.dict(**kwargs) for phrase, res in 
self.per_phrase_results.items() + } + d['per_phrase_results'] = per_phrase_results + return d + class MultiDescriptor(pydantic.BaseModel): + """The descriptor of results over multiple different results (parts). + + The idea is that this would likely be used with a regression suite + and it would incorporate all the different regression cases it describes. + """ name: str """The name of the collection being checked""" parts: List[ResultDescriptor] = [] """The parts kept track of""" @property - def success(self) -> int: - """The total number of successes. + def findings(self) -> Dict[Finding, int]: + """The total findings. Returns: - int: The total number of sucesses. + Dict[Finding, int]: The total number of successes. """ - return sum(part.success for part in self.parts) + totals: Dict[Finding, int] = {} + for part in self.parts: + for f, val in part.findings.items(): + if f not in totals: + totals[f] = val + else: + totals[f] += val + return totals - @property - def fail(self) -> int: - """The total number of failures. + def iter_examples(self, strictness_threshold: Strictness + ) -> Iterable[Tuple[FinalTarget, Tuple[Finding, Optional[str]]]]: + """Iterate over all relevant examples. - Returns: - int: The total number of failures. + Only examples that are not in the strictness matrix for the specified + threshold will be used. + + Args: + strictness_threshold (Strictness): The threshold of avoidance. 
+ + Yields: + Iterable[Tuple[FinalTarget, Tuple[Finding, Optional[str]]]]: The examples """ - return sum(part.fail for part in self.parts) + for descr in self.parts: + yield from descr.iter_examples(strictness_threshold=strictness_threshold) - def get_report(self, phrases_separately: bool, - hide_empty: bool = False, show_failures: bool = True) -> str: - """Get the report associated with this descriptor + def _get_part_report(self, part: ResultDescriptor, allowed_findings: Set[Finding], + total_findings: Dict[Finding, int], + hide_empty: bool, + examples_strictness: Optional[Strictness], + phrases_separately: bool, + phrase_max_len: int, + ) -> Tuple[str, int, int, int]: + if hide_empty and len(part.findings) == 0: + return '', 0, 0, 0 + total_total, total_s, total_f = 0, 0, 0 + for f, val in part.findings.items(): + if f not in total_findings: + total_findings[f] = val + else: + total_findings[f] += val + total_total += val + if f in allowed_findings: + total_s += val + else: + total_f += val + cur_add = '\t' + \ + part.get_report(phrases_separately=phrases_separately).replace( + '\n', '\n\t\t') + if examples_strictness is not None: + latest_phrase = '' + for target, found in part.iter_examples(strictness_threshold=examples_strictness): + finding, ocui = found + if latest_phrase == '': + # add header only if there's failures to include + cur_add += f"\n\t\tExamples at {examples_strictness} strictness" + if latest_phrase != target.final_phrase: + short_phrase = limit_str_len(target.final_phrase, max_length=phrase_max_len, + keep_front=phrase_max_len // 2, + keep_rear=phrase_max_len // 2 - 10) + cur_add += f"\n\t\tWith phrase: {repr(short_phrase)}" + latest_phrase = target.final_phrase + found_cui_descr = f' [{ocui}]' if ocui else '' + cur_add += (f'\n\t\t\t{finding.name}{found_cui_descr} for ' + f'placeholder {target.placeholder} ' + f'with CUI {repr(target.cui)} and name {repr(target.name)}') + return cur_add, total_total, total_s, total_f + + def 
calculate_report(self, phrases_separately: bool = False, + hide_empty: bool = False, + examples_strictness: Optional[Strictness] = Strictness.STRICTEST, + strictness: Strictness = Strictness.NORMAL, + phrase_max_len: int = 80) -> Tuple[int, int, int, str, int]: + """Calculate some of the major parts of the report. Args: phrases_separately (bool): Whether to include per-phrase information hide_empty (bool): Whether to hide empty cases - show_failures (bool): Whether to show failures + examples_strictness (Optional[Strictness.STRICTEST]): What level of strictness to show for examples. + Set to None to disable examples. Defaults to Strictness.STRICTEST. + strictness (Strictness): The strictness of the success / fail overview. + Defaults to Strictness.NORMAL. + phrase_max_len (int): The maximum length of the phrase in examples. Defaults to 80. Returns: - str: The report string + Tuple[int, int, int, int, str]: The total number of examples, the total successes, the total failures, + the delegated part, and the number of empty """ del_out = [] # delegation - all_failures: List[FailDescriptor] = [] + total_findings: Dict[Finding, int] = {} total_s, total_f = 0, 0 + allowed_findings = STRICTNESS_MATRIX[strictness] + total_total = 0 nr_of_empty = 0 for part in self.parts: - total_s += part.success - total_f += part.fail - if hide_empty and part.success == part.fail == 0: + (cur_add, total_total_add, + total_s_add, total_f_add) = self._get_part_report( + part, allowed_findings, total_findings, hide_empty, + # NOTE: using STRICTEST strictness for examples means + # that all but IDENTICAL examples will be shown + examples_strictness, phrases_separately, phrase_max_len) + if hide_empty and total_total_add == 0: nr_of_empty += 1 - continue - cur_add = '\t' + \ - part.get_report(phrases_separately=phrases_separately).replace( - '\n', '\n\t\t') - del_out.append(cur_add) - all_failures.extend(part.failures) - total_total = total_s + total_f - delegated = '\n\t'.join(del_out) + 
else: + total_total += total_total_add + total_s += total_s_add + total_f += total_f_add + del_out.append(cur_add) + delegated = '\n'.join(del_out) + return total_total, total_s, total_f, delegated, nr_of_empty + + def get_report(self, phrases_separately: bool, + hide_empty: bool = False, + examples_strictness: Optional[Strictness] = Strictness.STRICTEST, + strictness: Strictness = Strictness.NORMAL, + phrase_max_len: int = 80) -> str: + """Get the report associated with this descriptor + + Args: + phrases_separately (bool): Whether to include per-phrase information + hide_empty (bool): Whether to hide empty cases + examples_strictness (Optional[Strictness.STRICTEST]): What level of strictness to show for examples. + Set to None to disable examples. Defaults to Strictness.STRICTEST. + strictness (Strictness): The strictness of the success / fail overview. + Defaults to Strictness.NORMAL. + phrase_max_len (int): The maximum length of the phrase in examples. Defaults to 80. + + Returns: + str: The report string + """ + (total_total, total_s, total_f, + delegated, nr_of_empty) = self.calculate_report(phrases_separately=phrases_separately, + hide_empty=hide_empty, + examples_strictness=examples_strictness, + strictness=strictness, + phrase_max_len=phrase_max_len) empty_text = '' + allowed_findings = STRICTNESS_MATRIX[strictness] if hide_empty: empty_text = f' A total of {nr_of_empty} cases did not match any CUIs and/or names.' 
- failures = '' - if show_failures and all_failures: - failure_types = {} - for fd in all_failures: - if fd.reason not in failure_types: - failure_types[fd.reason] = 0 - failure_types[fd.reason] += 1 - failures = '\nFailures:\n' + \ - '\n'.join( - [f'{ft}: {occurances}' for ft, occurances in failure_types.items()]) - failures += '\nDetailed:\n' + '\n'.join( - [f'CUI: {repr(descriptor.cui)}, name: {repr(descriptor.name)}, ' - f'reason: {descriptor.reason}{" (%s)"%descriptor.extra if descriptor.extra else ""}' - for descriptor in all_failures]) - return f"""A total of {len(self.parts)} parts were kept track of within the group "{self.name}". -And a total of {total_total} (sub)cases were checked.{empty_text} - Total success: {total_s:10d} ({100 * total_s / total_total if total_total > 0 else 0}%) - Total failure: {total_f:10d} ({100 * total_f / total_total if total_total > 0 else 0}%) - {delegated}{failures}""" + ret_vals = [f"""A total of {len(self.parts)} parts were kept track of within the group "{self.name}". 
+And a total of {total_total} (sub)cases were checked.{empty_text}"""] + allowed_fingings_str = [f.name for f in allowed_findings] + ret_vals.extend([ + f"At the strictness level of {strictness} (allowing {allowed_fingings_str}):", + f"The number of total successful (sub) cases: {total_s} " + f"({100 * total_s/total_total if total_total > 0 else 0:5.2f}%)", + f"The number of total failing (sub) cases : {total_f} " + f"({100 * total_f/total_total if total_total > 0 else 0:5.2f}%)" + ]) + ret_vals.extend([ + f"{f.name:24s}:{self.findings[f]:10d} " + f"({100 * self.findings[f] / total_total if total_total > 0 else 0:5.2f}%)" + # NOTE iterating over Finding so the order is the same as in the enum + for f in Finding if f in self.findings + ]) + return "\n".join(ret_vals) + f"\n{delegated}" + + def dict(self, **kwargs) -> dict: + if 'strictness' in kwargs: + strict_raw = kwargs.pop('strictness') + if isinstance(strict_raw, Strictness): + strictness = strict_raw + elif isinstance(strict_raw, str): + strictness = Strictness[strict_raw] + else: + raise ValueError(f"Unknown stircntess specified: {strict_raw}") + else: + strictness = Strictness.NORMAL + out_dict = super().dict(exclude={'parts'}, **kwargs) + out_dict['parts'] = [part.dict(strictness=strictness) for part in self.parts] + return out_dict + + +class MalformedFinding(ValueError): + + def __init__(self, *args: object) -> None: + super().__init__(*args) diff --git a/medcat/utils/regression/suite_editor.py b/medcat/utils/regression/suite_editor.py deleted file mode 100644 index 228af4827..000000000 --- a/medcat/utils/regression/suite_editor.py +++ /dev/null @@ -1,47 +0,0 @@ -import argparse -import logging -from pathlib import Path -from typing import Optional - -from medcat.utils.regression.editing import combine_yamls - - -logger = logging.getLogger(__name__) - - -def main(base_file: str, add_file: str, new_file: Optional[str] = None, - ignore_identicals: bool = True) -> None: - logger.info( - "Starting to add to 
%s from %s", base_file, add_file) - res_file = combine_yamls(base_file, add_file, new_file=new_file, - ignore_identicals=ignore_identicals) - logger.debug("Combination successful") - logger.info("Saved combined data to %s", res_file) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument( - 'file', help='The base regression YAML file', type=Path) - parser.add_argument( - 'add_file', help='The additional regression YAML file', type=Path) - parser.add_argument('--newfile', help='The target file for the combination ' - '(otherwise, the base file is used)', type=Path, required=False) - parser.add_argument('--silent', '-s', help='Make the operation silent (i.e ignore console output)', - action='store_true') - parser.add_argument('--verbose', '-debug', help='Enable debug/verbose mode', - action='store_true') - parser.add_argument( - '--include-identicals', - help='Write down identical cases (they are only written down once by default)', - action='store_true') - args = parser.parse_args() - if not args.silent: - logger.addHandler(logging.StreamHandler()) - logger.setLevel('INFO') - if args.verbose: - from checking import logger as checking_logger - checking_logger.addHandler(logging.StreamHandler()) - checking_logger.setLevel('DEBUG') - main(args.file, args.add_file, new_file=args.newfile, - ignore_identicals=not args.include_identicals) diff --git a/medcat/utils/regression/targeting.py b/medcat/utils/regression/targeting.py index 516c1ec63..8acd12f3a 100644 --- a/medcat/utils/regression/targeting.py +++ b/medcat/utils/regression/targeting.py @@ -1,13 +1,12 @@ -from enum import Enum import logging -from typing import Dict, Iterable, Iterator, List, Set, Any, Tuple, Union +from typing import Dict, Iterable, Iterator, List, Set, Tuple, Any +from functools import lru_cache +from itertools import product from pydantic import BaseModel from medcat.cdb import CDB -from medcat.utils.regression.utils import loosely_match_enum - logger = 
logging.getLogger(__name__) @@ -31,10 +30,15 @@ class TranslationLayer: """ def __init__(self, cui2names: Dict[str, Set[str]], name2cuis: Dict[str, List[str]], - cui2type_ids: Dict[str, Set[str]], cui2children: Dict[str, Set[str]]) -> None: + cui2type_ids: Dict[str, Set[str]], cui2children: Dict[str, Set[str]], + cui2preferred_names: Dict[str, str], + separator: str, whitespace: str = ' ') -> None: self.cui2names = cui2names self.name2cuis = name2cuis self.cui2type_ids = cui2type_ids + self.cui2preferred_names = cui2preferred_names + self.separator = separator + self.whitespace = whitespace self.type_id2cuis: Dict[str, Set[str]] = {} for cui, type_ids in self.cui2type_ids.items(): for type_id in type_ids: @@ -46,52 +50,96 @@ def __init__(self, cui2names: Dict[str, Set[str]], name2cuis: Dict[str, List[str if cui not in cui2children: self.cui2children[cui] = set() - def targets_for(self, cui: str) -> Iterator[Tuple[str, str]]: - for name in self.cui2names[cui]: - yield cui, name + def get_names_of(self, cui: str, only_prefnames: bool) -> List[str]: + """Get the preprocessed names of a CUI. + + This method preporcesses the names by replacing the separator (generally `~`) + with the appropriate whitespace (` `). - def all_targets(self, all_cuis: Set[str], all_names: Set[str], all_types: Set[str]) -> Iterator[Tuple[str, str]]: - """Get a generator of all target information objects. - This is the starting point for checking cases. + If the concept is not in the underlying CDB, an empty list is returned. Args: - all_cuis (Set[str]): The set of all CUIs to be queried - all_names (Set[str]): The set of all names to be queried - all_types (Set[str]): The set of all type IDs to be queried + cui (str): The concept in question. + only_prefnames (bool): Whether to only return a preferred name. - Yields: - Iterator[Tuple[str, str]]: The iterator of the target info + Returns: + List[str]: The list of names. 
+ """ + if only_prefnames: + return [self.get_preferred_name(cui).replace(self.separator, self.whitespace)] + return [name.replace(self.separator, self.whitespace) + for name in self.cui2names.get(cui, [])] + + def get_preferred_name(self, cui: str) -> str: + """Get the preferred name of a concept. + + If no preferred name is found, the random 'first' name is selected. + + Args: + cui (str): The concept ID. + + Returns: + str: The preferred name. + """ + pref_name = self.cui2preferred_names.get(cui, None) + if pref_name is None: + logger.warning("CUI %s does not have a preferred name. " + "Using a random 'first' name of all the names", cui) + return self.get_first_name(cui) + return pref_name + + def get_first_name(self, cui: str) -> str: + """Get the preprocessed (potentially) arbitrarily first name of the given concept. + + If the concept does not exist, the CUI itself is returned. + + PS: The "first" name may not be consistent across runs since it relies on set order. + + Args: + cui (str): The concept ID. + + Returns: + str: The first name. + """ + for name in self.cui2names.get(cui, [cui]): + return name.replace(self.separator, self.whitespace) + return cui + + def get_direct_children(self, cui: str) -> List[str]: + """Get the direct children of a concept. + + This means only the children, but not grandchildren. + + If the underlying CDB doesn't list children for this CUI, an empty list is returned. + + Args: + cui (str): The concept in question. + + Returns: + List[str]: The (potentially empty) list of direct children. + """ + return list(self.cui2children.get(cui, [])) + + @lru_cache(maxsize=10_000) + def get_direct_parents(self, cui: str) -> List[str]: + """Get the direct parent(s) of a concept. + + PS: This method can be quite a CPU heavy one since it relies + on running through all the parent-children relationships + since the child->parent(s) relationship isn't normally + kept track of. 
+ + Args: + cui (str): _description_ + + Returns: + List[str]: _description_ """ - for cui in all_cuis: - if cui not in self.cui2names: - logger.warning('CUI not found in translation layer: %s', cui) - continue - for name in self.cui2names[cui]: - yield cui, name - for name in all_names: - if name not in self.name2cuis: - logger.warning('Name not found in translation layer: %s', name) - continue - for cui in self.name2cuis[name]: - if cui in all_cuis: - continue # this cui-name pair should already have been yielded above - yield cui, name - for type_id in all_types: - if type_id not in self.type_id2cuis: - logger.warning( - 'Type ID not found in translation layer: %s', type_id) - continue - for cui in self.type_id2cuis[type_id]: - if cui in all_cuis: - continue # should have been yielded above - if cui not in self.cui2names: - logger.warning( - 'CUI not found in translation layer: %s', cui) - continue - for name in self.cui2names[cui]: - if name in all_names: - continue # should have been yielded above - yield cui, name + parents = [] + for pot_parent, children in self.cui2children.items(): + if cui in children: + parents.append(pot_parent) + return parents def get_children_of(self, found_cuis: Iterable[str], cui: str, depth: int = 1) -> List[str]: """Get the children of the specifeid CUI in the listed CUIs (if they exist). @@ -117,32 +165,6 @@ def get_children_of(self, found_cuis: Iterable[str], cui: str, depth: int = 1) - found_cuis, child, depth - 1)) return found_children - def get_parents_of(self, found_cuis: Iterable[str], cui: str, depth: int = 1) -> List[str]: - """Get the parents of the specifeid CUI in the listed CUIs (if they exist). - - If needed, higher order parents (i.e grandparents) can be queries for. - - This uses the `get_children_of` method intenrnally. - That is, if any of the found CUIs have the specified CUI as a child of - the specified depth, the found CUIs have a parent of the specified depth. 
- - Args: - found_cuis (Iterable[str]): The list of CUIs to look in - cui (str): The target child CUI - depth (int): The depth to carry out the search for - - Returns: - List[str]: The list of parents found - """ - found_parents = [] - for found_cui in found_cuis: - if self.get_children_of({cui}, found_cui, depth=depth): - # TODO - the intermediate results may get lost here - # i.e if found_cui is grandparent of the specified one, - # the direct parent is not listed - found_parents.append(found_cui) - return found_parents - @classmethod def from_CDB(cls, cdb: CDB) -> 'TranslationLayer': """Construct a TranslationLayer object from a context database (CDB). @@ -162,256 +184,252 @@ def from_CDB(cls, cdb: CDB) -> 'TranslationLayer': parent2child = {} else: parent2child = cdb.addl_info['pt2ch'] - return TranslationLayer(cdb.cui2names, cdb.name2cuis, cdb.cui2type_ids, parent2child) - - -class FilterStrategy(Enum): - """Describes the filter strategy. - I.e whether to match all or any - of the filters specified. - """ - ALL = 1 - """Specified that all filters must be satisfied""" - ANY = 2 - """Specified that any of the filters must be satisfied""" - - @classmethod - def match_str(cls, name: str) -> 'FilterStrategy': - """Find a loose string match. + return TranslationLayer(cdb.cui2names, cdb.name2cuis, cdb.cui2type_ids, parent2child, + cui2preferred_names=cdb.cui2preferred_name, + separator=cdb.config.general.separator) - Args: - name (str): The name of the enum - - Returns: - FilterStrategy: The matched FilterStrategy - """ - return loosely_match_enum(FilterStrategy, name) - -class FilterType(Enum): - """The types of targets that can be specified +class TargetPlaceholder(BaseModel): + """A class describing the options for a specific placeholder. 
""" - TYPE_ID = 1 - """Filters by specified type_ids""" - CUI = 2 - """Filters by specified CUIs""" - NAME = 3 - """Filters by specified names""" - CUI_AND_CHILDREN = 4 - """Filter by CUI but also allow children, up to a specified distance""" - - @classmethod - def match_str(cls, name: str) -> 'FilterType': - """Case insensitive matching for FilterType + placeholder: str + target_cuis: List[str] + onlyprefnames: bool = False - Args: - name (str): The naeme to be matched - Returns: - FilterType: The matched FilterType - """ - return loosely_match_enum(FilterType, name) +class PhraseChanger(BaseModel): + """The phrase changer. - -class TypedFilter(BaseModel): - """A filter with multiple values to filter against. + This is class used as a preprocessor for phrases with multiple placeholders. + It allows swapping in the rest of the placeholders while leaving in the one + that's being tested for. """ - type: FilterType - values: List[str] - - def get_applicable_targets(self, translation: TranslationLayer, in_gen: Iterator[Tuple[str, str]]) -> Iterator[Tuple[str, str]]: - """Get all applicable targets for this filter - - Args: - translation (TranslationLayer): The translation layer - in_gen (Iterator[Tuple[str, str]]): The input generator / iterator + preprocess_placeholders: List[Tuple[str, str]] - Yields: - Iterator[Tuple[str, str]]: The output generator - """ - if self.type == FilterType.CUI or self.type == FilterType.CUI_AND_CHILDREN: - for cui, name in in_gen: - if cui in self.values: - yield cui, name - if self.type == FilterType.NAME: - for cui, name in in_gen: - if name in self.values: - yield cui, name - if self.type == FilterType.TYPE_ID: - for cui, name in in_gen: - if cui in translation.cui2type_ids: - tids = translation.cui2type_ids[cui] - else: - tids = set() - for tid in tids: - if tid in self.values: - yield cui, name - break + def __call__(self, phrase: str) -> str: + for placeholder, replacement in self.preprocess_placeholders: + phrase = 
phrase.replace(placeholder, replacement) + return phrase @classmethod - def one_from_input(cls, target_type: str, vals: Union[str, list, dict]) -> 'TypedFilter': - """Get one typed filter from the input target type and values. - The values can either a be a string for a single target, - a list of strings for multiple targets, or - a dict in some more complicated cases (i.e CUI_AND_CHILDREN). + def empty(cls) -> 'PhraseChanger': + """Gets the empty phrase changer. - Args: - target_type (str): The target type as string - vals (Union[str, list, dict]): The values - - Raises: - ValueError: If the values are malformed + That is a phrase changer that makes no changes to the phrase. Returns: - TypedFilter: The parsed filter + PhraseChanger: The empty phrase changer. """ - t_type: FilterType = FilterType.match_str(target_type) - filt: TypedFilter - if isinstance(vals, dict): - if t_type != FilterType.CUI_AND_CHILDREN: - # currently only applicable for CUI_AND_CHILDREN case - raise ValueError(f'Misconfigured config for {target_type}, ' - 'expected either a value or a list of values ' - 'for this type of filter') - depth = vals['depth'] - delegate = cls.one_from_input(target_type, vals['cui']) - if t_type is FilterType.CUI_AND_CHILDREN: - filt = CUIWithChildFilter( - type=t_type, delegate=delegate, depth=depth) - else: - if isinstance(vals, str): - vals = [vals, ] - filt = TypedFilter(type=t_type, values=vals) - return filt + return cls(preprocess_placeholders=[]) - def to_dict(self) -> dict: - """Convert the TypedFilter to a dict to be serialised. - Returns: - dict: The dict representation - """ - return {self.type.name: self.values} +class TargetedPhraseChanger(BaseModel): + """The target phrase changer. - @staticmethod - def list_to_dicts(filters: List['TypedFilter']) -> List[dict]: - """Create a list of dicts from list of TypedFilters. + It includes the phrase changer (for preprocessing) along with + the relevant concept and the placeholder it will replace. 
+ """ + changer: PhraseChanger + placeholder: str + cui: str + onlyprefnames: bool - Args: - filters (List[TypedFilter]): The list of typed filters - Returns: - List[dict]: The list of dicts - """ - return [filt.to_dict() for filt in filters] +class FinalTarget(BaseModel): + """The final target. - @staticmethod - def list_to_dict(filters: List['TypedFilter']) -> dict: - """Create a single dict from the list of TypedFilters. + This involves the final phrase (which (potentially) has other + placeholder replaced in it), the placeholder to be replaced, + and the CUI and specific name being used. + """ + placeholder: str + cui: str + name: str + final_phrase: str - Args: - filters (List[TypedFilter]): The list of typed filters - Returns: - dict: The dict - """ - d = {} - for filt_dict in TypedFilter.list_to_dicts(filters): - d.update(filt_dict) - return d +class OptionSet(BaseModel): + """The targeting option set. + + This describes all the target placeholders and concepts needed. + """ + options: List[TargetPlaceholder] + allow_any_combinations: bool = False @classmethod - def from_dict(cls, input: Dict[str, Any]) -> List['TypedFilter']: - """Construct a list of TypedFilter from a dict. + def from_dict(cls, section: Dict[str, Any]) -> 'OptionSet': + """Construct a OptionSet instance from a dict. The assumed structure is: - {: } - or - {: [, ]} - There can be multiple filter types defined. + { + 'placeholders': [ + { + 'placeholder': , + 'cuis': , + 'prefname-only': 'true' + }, ], + 'any-combination': + } + + The prefname-only key is optional. Args: - input (Dict[str, Any]): The input dict. 
+ section (Dict[str, Any]): The dict to parse + + Raises: + ProblematicOptionSetException: If incorrect number of CUIs when not allowing any combination + ProblematicOptionSetException: If placeholders not a list + ProblematicOptionSetException: If multiple placehodlers with same place holder Returns: - List[TypedFilter]: The list of constructed TypedFilter + OptionSet: The resulting OptionSet """ - parsed_targets: List[TypedFilter] = [] - for target_type, vals in input.items(): - filt = cls.one_from_input(target_type, vals) - parsed_targets.append(filt) - return parsed_targets - - -class FilterOptions(BaseModel): - """A class describing the options for the filters - """ - strategy: FilterStrategy - onlyprefnames: bool = False + options: List['TargetPlaceholder'] = [] + allow_any_in = section.get('any-combination', 'false') + if isinstance(allow_any_in, str): + allow_any_combinations = allow_any_in.lower() == 'true' + elif isinstance(allow_any_in, bool): + allow_any_combinations = allow_any_in + else: + raise ProblematicOptionSetException(f"Unknown 'any-combination' value: {allow_any_in}") + if 'placeholders' not in section: + raise ProblematicOptionSetException("Misconfigured - no placeholders") + section_placeholders = section['placeholders'] + if not isinstance(section_placeholders, list): + raise ProblematicOptionSetException("Misconfigured - placehodlers not a list " + f"({section_placeholders})") + used_ph = set() + for part in section_placeholders: + placeholder = part['placeholder'] + if not isinstance(placeholder, str): + raise ProblematicOptionSetException(f"Unknown placeholder of type {type(placeholder)}. " + "Expected a string. Perhaps you need to surrong the " + "placeholder with single quotes (') in the yaml? 
" + f"Received: {placeholder}") + if placeholder in used_ph: + raise ProblematicOptionSetException("Misconfigured - multiple identical placeholders") + used_ph.add(placeholder) + target_cuis: List[str] = part['cuis'] + if not isinstance(target_cuis, list): + raise ProblematicOptionSetException( + f"Target CUIs not a list ({type(target_cuis)}): {repr(target_cuis)}") + if 'prefname-only' in part: + opn = part['prefname-only'] + if isinstance(opn, bool): + onlyprefnames = opn + else: + onlyprefnames = str(opn).lower() == 'true' + else: + onlyprefnames = False + option = TargetPlaceholder(placeholder=placeholder, target_cuis=target_cuis, + onlyprefnames=onlyprefnames) + options.append(option) + if not options: + raise ProblematicOptionSetException("Misconfigured - 0 placeholders found (empty list)") + if not allow_any_combinations: + # NOTE: need to have same number of target_cuis for each placeholder + # NOTE: there needs to be at least on option / placeholder anyway + nr_of_cuis = [len(opt.target_cuis) for opt in options] + if not all(nr == nr_of_cuis[0] for nr in nr_of_cuis): + raise ProblematicOptionSetException( + f"Unequal number of cuis when any-combination: false: {nr_of_cuis}. " + "When any-combination: false the number of CUIs for each placeholder " + "should be equal.") + return OptionSet(options=options, allow_any_combinations=allow_any_combinations) def to_dict(self) -> dict: - """Convert the FilterOptions to a dict. + """Convert the OptionSet to a dict. Returns: dict: The dict representation """ - return {'strategy': self.strategy.name, 'prefname-only': str(self.onlyprefnames)} - - @classmethod - def from_dict(cls, section: Dict[str, str]) -> 'FilterOptions': - """Construct a FilterOptions instance from a dict. - - The assumed structure is: - {'strategy': <'all' or 'any'>, - 'prefname-only': 'true'} - - Both strategy and prefname-only are optional. 
- - Args: - section (Dict[str, str]): The dict to parse + placeholders = [ + { + 'placeholder': opt.placeholder, + 'cuis': opt.target_cuis, + 'prefname-only': str(opt.onlyprefnames), + } + for opt in self.options + ] + return {'placeholders': placeholders, 'any-combination': str(self.allow_any_combinations)} + + def _get_all_combinations(self, cur_opts: TargetPlaceholder, other_opts: List[TargetPlaceholder], + translation: TranslationLayer) -> Iterator[Tuple[PhraseChanger, str]]: + per_ph_nr_of_opts = [len(opt.target_cuis) for opt in other_opts] + if self.allow_any_combinations: + # for each option with N target CUIs use 0, ..., N-1 + for choosers in product(*[range(n) for n in per_ph_nr_of_opts]): + # NOTE: using the 0th name for target CUI + placeholders = [(opt.placeholder, translation.get_preferred_name(opt.target_cuis[cui_nr])) + for opt, cui_nr in zip(other_opts, choosers)] + for target_cui in cur_opts.target_cuis: + yield PhraseChanger(preprocess_placeholders=placeholders), target_cui + else: + nr_of_opts = len(cur_opts.target_cuis) + for cui_nr in range(nr_of_opts): + placeholders = [ + # NOTE: using the 0th name for the target CUI + (opt.placeholder, translation.get_preferred_name(opt.target_cuis[cui_nr])) + for opt in other_opts + ] + yield PhraseChanger(preprocess_placeholders=placeholders), cur_opts.target_cuis[cui_nr] + + def estimate_num_of_subcases(self) -> int: + """Get the number of distinct subcases. + + This includes ones that can be calculated without the knowledge of the + underlying CDB. I.e it doesn't care for the number of names involved per CUI + but only takes into account what is described in the option set itself. + + If any combination is allowed, then the answer is the combination of + the number of target concepts per option. + If any combination is not allowed, then the answer is simply the number + of target concepts for an option (they should all have the same number). 
        Returns:
-            FilterOptions: The resulting FilterOptions
+            int: The estimated number of distinct subcases described by this option set.
        """
-        if 'strategy' in section:
-            strategy = FilterStrategy.match_str(section['strategy'])
-        else:
-            strategy = FilterStrategy.ALL  # default
-        if 'prefname-only' in section:
-            onlyprefnames = section['prefname-only'].lower() == 'true'
+        num_of_opts = len(self.options)
+        if self.allow_any_combinations:
+            total_cases = 1
+            for cur_opt in self.options:
+                total_cases *= len(cur_opt.target_cuis)
        else:
-            onlyprefnames = False
-        return FilterOptions(strategy=strategy, onlyprefnames=onlyprefnames)
+            total_cases = len(self.options[0].target_cuis)
+        return num_of_opts * total_cases

-
-class CUIWithChildFilter(TypedFilter):
-    delegate: TypedFilter
-    depth: int
-    values: List[str] = []  # overwrite TypedFilter
-
-    def get_applicable_targets(self, translation: TranslationLayer, in_gen: Iterator[Tuple[str, str]]) -> Iterator[Tuple[str, str]]:
-        """Get all applicable targets for this filter
+    def get_preprocessors_and_targets(self, translation: TranslationLayer
+                                      ) -> Iterator[TargetedPhraseChanger]:
+        """Get the targeted phrase changers.

        Args:
-            translation (TranslationLayer): The translation layer
-            in_gen (Iterator[Tuple[str, str]]): The input generator / iterator
+            translation (TranslationLayer): The translation layer.

        Yields:
-            Iterator[Tuple[str, str]]: The output generator
-            """
-        for cui, name in self.delegate.get_applicable_targets(translation, in_gen):
-            yield cui, name
-            yield from self.get_children_of(translation, cui, cur_depth=1)
-
-    def get_children_of(self, translation: TranslationLayer, cui: str, cur_depth: int) -> Iterator[Tuple[str, str]]:
-        for child in translation.cui2children[cui]:
-            yield from translation.targets_for(child)
-            if cur_depth < self.depth:
-                yield from self.get_children_of(translation, child, cur_depth=cur_depth + 1)
-
-    def to_dict(self) -> dict:
-        """Convert this CUIWithChildFilter to a dict.
- - Returns: - dict: The dict representation + Iterator[TargetedPhraseChanger]: Thetarget phrase changers. """ - return {self.type.name: {'depth': self.depth, 'cui': self.delegate.values}} + num_of_opts = len(self.options) + if num_of_opts == 1: + # NOTE: when there's only 1 option, the other option doesn't work + # since it has nothing to iterate over regarding 'other' options + opt = self.options[0] + for target_cui in opt.target_cuis: + yield TargetedPhraseChanger(changer=PhraseChanger.empty(), + placeholder=opt.placeholder, + cui=target_cui, + onlyprefnames=opt.onlyprefnames) + return + for opt_nr in range(num_of_opts): + other_opts = list(self.options) + cur_opt = other_opts.pop(opt_nr) + for changer, target_cui in self._get_all_combinations(cur_opt, other_opts, translation): + yield TargetedPhraseChanger(changer=changer, + placeholder=cur_opt.placeholder, + cui=target_cui, + onlyprefnames=cur_opt.onlyprefnames) + + +class ProblematicOptionSetException(ValueError): + + def __init__(self, *args: object) -> None: + super().__init__(*args) diff --git a/medcat/utils/regression/utils.py b/medcat/utils/regression/utils.py index 90a343783..3d630bec3 100644 --- a/medcat/utils/regression/utils.py +++ b/medcat/utils/regression/utils.py @@ -1,36 +1,225 @@ +from typing import Iterator, Tuple, List, Dict, Any, Type + +import ast +import inspect from enum import Enum -from typing import Type, TypeVar, cast +from medcat.stats.mctexport import MedCATTrainerExport, MedCATTrainerExportDocument + + +# this placheolder will be temporarily put in the +# phrases when dealing with one that has multiple +# of the same placeholder in it +_TEMP_MULTI_PLACEHOLDER = '###===PlaceHolder===###' -ENUM = TypeVar('ENUM', bound=Enum) +def partial_substitute(phrase: str, placeholder: str, name: str, nr: int) -> str: + """Substitute all but 1 of the many placeholders present in the given phrase. 
-def loosely_match_enum(e_type: Type[ENUM], name: str) -> ENUM: - """Loosely (i.e case-insensitively) match enum names. + First, the first `nr` placeholders are replaced. + Then the next (1) placeholder is replaced with a temporary one + After that, the rest of the placeholders are replaced. + And finally, the temporary placeholder is returned back to its original form. + + Example: + If we've got `phrase = "some [PH] and [PH] we [PH]"` + `placeholder = "[PH]"`, and `name = 'NAME'`, + we'd get the following based on the number `nr`: + 0: "some [PH] and NAME we NAME" + 1: "some NAME and [PH] we NAME" + 2: "some NAME and NAME we [PH]" Args: - e_type (Type[Enum]): The type of enum to use - name (str): The case-insensitive name + phrase (str): The phrase in question. + placeholder (str): The placeholder to replace. + name (str): The name to replace the placeholder for. + nr (int): The number of the target to keep. Raises: - _key_err: KeyError if the key is unable to be loosely matched + IncompatiblePhraseException: If the number of placeholders in the phrase + is 1 or the number to be kept is too high; or the phrase has the + temporary placeholder. Returns: - ENUM: The enum constant that was found + str: The partially substituted phrase. + """ + num_of_placeholder = phrase.count(placeholder) + if nr >= num_of_placeholder or num_of_placeholder == 1: + # NOTE: in cae of 1, this makes no sense + raise IncompatiblePhraseException( + f"The phrase ({repr(phrase)}) has {num_of_placeholder} " + f"placeholders, but the {nr}th placeholder was requested to be " + "swapped!") + # replace stuff before the specific one + phrase = phrase.replace(placeholder, name, nr) + if _TEMP_MULTI_PLACEHOLDER in phrase: + # if the temporary placeholder is already in text, the following would fail + # unexpectedly + raise IncompatiblePhraseException( + f"Regression phrase with multiple placeholders ({placeholder}) " + f"has the temporary placeholder: {repr(_TEMP_MULTI_PLACEHOLDER)}. 
" + f"This means that the partial substitution of all but the {nr}th " + "placeholder failed") + # replace the target with temporary placeholder + phrase = phrase.replace(placeholder, _TEMP_MULTI_PLACEHOLDER, 1) + # replace the rest of the placeholder + phrase = phrase.replace(placeholder, name) + # set back the one needed placeholder + phrase = phrase.replace(_TEMP_MULTI_PLACEHOLDER, placeholder) + return phrase + + +class IncompatiblePhraseException(ValueError): + + def __init__(self, *args: object) -> None: + super().__init__(*args) + + +def limit_str_len(input_str: str, + max_length: int = 40, + keep_front: int = 20, + keep_rear: int = 10) -> str: + """Limits the length of a string. + + If the length of the string is less than or equal to `max_length`, the same + string is returned. + If it's longer, the first `keep_front` are kept, then the number of chars + is included in brackets (e.g `" [123 chars] "`), and finally the last + `keep_rear` characters are included. + + Args: + input_str (str): The input (potentially) long string. + max_length (int): The maximum number of characters at which + the string will remain unchanged. Defaults to 40. + keep_front (int): How many starting characters to keep. Defaults to 20. + keep_rear (int): How many ending characters to keep. Defaults to 10. + + Returns: + str: _description_ + """ + if len(input_str) <= max_length: + return input_str + part1 = input_str[:keep_front] + part2 = input_str[-keep_rear:] + hidden_chars = len(input_str) - len(part1) - len(part2) + return f"{part1} [{hidden_chars} chars] {part2}" + + +class MedCATTrainerExportConverter: + """Used to convert an MCT export to the format required for regression.""" + # NOTE: the first placeholder will use the CUI, the 2nd the order of + # the annotation. 
This is required so that placeholders with the + # samme concept don't have the same name + TEMP_PLACEHOLDER = "##[SWAPME-{}-{}]##" + + def __init__(self, mct_export: MedCATTrainerExport, + use_only_existing_name: bool = False) -> None: + self.mct_export = mct_export + self.use_only_existing_name = use_only_existing_name + + def _get_placeholder(self, cui: str, nr: int) -> str: + return self.TEMP_PLACEHOLDER.format(cui, nr) + + def convert(self) -> dict: + """Converts the MedCATtrainer export into regression suite dict. + + I.e this should producce a dict in the same format as one read + from a regression suite YAML. + + Returns: + dict: The Regression-suite compatible dict. + """ + converted: Dict[str, dict] = {} + for phrase, case_name, anns in self._iter_docs(): + regr_case: Dict[str, Any] = { + 'targeting': { + 'placeholders': [ + { + # NOTE: this is just and example. + # it will be wiped/overwritten later + 'placeholders': "TODO", + 'cuis': ['CUI1'] + } + ], + 'any-combination': False, + }, + 'phrases': [] # will be filled later + } + placeholders: List[Dict[str, Any]] = [] + # NOTE: the iteration is done from later annotations + # so I can replace using the locations + for ann_nr, (start, end, cui, _) in enumerate(anns): + ph = self._get_placeholder(cui, ann_nr) + phrase = phrase[:start] + ph + phrase[end:] + placeholders.append({ + 'placeholder': ph, 'cuis': [cui, ] + }) + # update at the very end, when changed + regr_case['phrases'] = [phrase] + regr_case['targeting']['placeholders'] = placeholders + converted[case_name] = regr_case + return converted + + def _iter_docs(self) -> Iterator[Tuple[str, str, Iterator[Tuple[int, int, str, str]]]]: + for project in self.mct_export['projects']: + project_id = project['id'] + project_name = project['name'] + for doc in project['documents']: + doc_id = doc['id'] + text = doc['text'] + yield text, f"{project_id}_{project_name}_{doc_id}", self._iter_anns_backwards(doc) + + def _iter_anns_backwards(self, doc: 
MedCATTrainerExportDocument) -> Iterator[Tuple[int, int, str, str]]: + # NOTE: doing so backwards so that I can replace them one by one using the start/end, + # starting from the end of the phrase + for ann in doc['annotations'][::-1]: + yield ann['start'], ann['end'], ann['cui'], ann['value'] + + +def get_class_level_docstrings(cls: Type) -> List[str]: + """This is a helper method to get all the class level doc strings. + + This is designed to be used alongside and by the `add_doc_strings_to_enum` method. + + Args: + cls (Type): The class in question. + + Returns: + List[str]: All class-level docstrings (including the class docstring if it exists). + """ + source_code = inspect.getsource(cls) + tree = ast.parse(source_code) + docstrings: List[str] = [] + # walk the tree + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + for elem in node.body: + if isinstance(elem, ast.Expr) and isinstance(elem.value, ast.Constant): + # If it's an expression node containing a constant, extract the string + docstrings.append(elem.value.s) + return docstrings + + +def add_doc_strings_to_enum(cls: Type[Enum]) -> None: + """Add doc strings to Enum as they are described in code right below each constant. + + The way python works means that the doc strings defined after an Enum constant do not + get stored with the constant. When accessing the doc string of an Enum constant, the + doc string of the class is returned instead. + + So what this method does is gets the doc strings by traversing the abstract syntax tree. + + While there would be easier ways to accomplish this, they would require the doc strings + for the Enum constant to be further from the constants themselves. + + If the class itself has a doc string, it is omitted. Otherwise the Enum constants are + given the doc strings in the order in which they appear. + + Args: + cls (Type[Enum]): The Enum class to do this for. 
""" - _key_err = None - try: - return cast(ENUM, e_type[name]) - except KeyError as key_err: - _key_err = key_err - name = name.lower() - try: - return cast(ENUM, e_type[name]) - except KeyError: - pass - name = name.upper() - try: - return cast(ENUM, e_type[name]) - except KeyError: - pass - raise _key_err + docstrings = get_class_level_docstrings(cls) + if cls.__doc__ == docstrings[0]: + docstrings = docstrings[1:] + for ev, ds in zip(cls, docstrings): + ev.__doc__ = ds diff --git a/tests/resources/regression/creation/cat_creation.py b/tests/resources/regression/creation/cat_creation.py new file mode 100644 index 000000000..c37af9430 --- /dev/null +++ b/tests/resources/regression/creation/cat_creation.py @@ -0,0 +1,65 @@ +import os +import sys +import pandas as pd + +from medcat.vocab import Vocab +from medcat.config import Config +from medcat.cdb_maker import CDBMaker +from medcat.cdb import CDB +from medcat.cat import CAT + + +vi = sys.version_info +PY_VER = f"{vi.major}.{vi.minor}" + + +# paths +VOCAB_DATA_PATH = os.path.join( + os.path.dirname(__file__), 'vocab_data.txt' + # os.path.dirname(__file__), 'vocab_data_auto.txt' +) +CDB_PREPROCESSED_PATH = os.path.join( + os.path.dirname(__file__), 'preprocessed4cdb.txt' +) +SELF_SUPERVISED_DATA_PATH = os.path.join( + os.path.dirname(__file__), 'selfsupervised_data.txt' +) +SUPERVISED_DATA_PATH = os.path.join( + os.path.dirname(__file__), 'supervised_mct_export.json' +) +SAVE_PATH = os.path.dirname(__file__) +SAVE_NAME = f"simple_model4test-{PY_VER}" + +# vocab + +vocab = Vocab() +vocab.add_words(VOCAB_DATA_PATH) + +# CDB +config = Config() + +maker = CDBMaker(config) + +cdb: CDB = maker.prepare_csvs([CDB_PREPROCESSED_PATH]) + +# CAT +cat = CAT(cdb, vocab) + +# training +# self-supervised +data = pd.read_csv(SELF_SUPERVISED_DATA_PATH) +cat.train(data.text.values) + +print("[sst] cui2count_train", cat.cdb.cui2count_train) + +# supervised + +cat.train_supervised_from_json(SUPERVISED_DATA_PATH) + +print("[sup] 
cui2count_train", cat.cdb.cui2count_train) + +# save +mpn = cat.create_model_pack(SAVE_PATH, model_pack_name=SAVE_NAME) +full_path = os.path.join(SAVE_PATH, mpn) +print("Saved to") +print(full_path) diff --git a/tests/resources/regression/creation/preprocessed4cdb.txt b/tests/resources/regression/creation/preprocessed4cdb.txt new file mode 100644 index 000000000..113805b2c --- /dev/null +++ b/tests/resources/regression/creation/preprocessed4cdb.txt @@ -0,0 +1,11 @@ +cui,name +C01,kidney failure +C01,loss of kidney function +C02,diabetes +C02,diabetes mellitus +C03,fever +C03,high temperature +C04,seizure +C04,fittest +C05,healthy +C05,fittest \ No newline at end of file diff --git a/tests/resources/regression/creation/selfsupervised_data.txt b/tests/resources/regression/creation/selfsupervised_data.txt new file mode 100644 index 000000000..382a38040 --- /dev/null +++ b/tests/resources/regression/creation/selfsupervised_data.txt @@ -0,0 +1,10 @@ +id,text +FD0,"Patient presented with severe diabetes and had also been diagnosed with acute kidney failure. +Prior to visit the patient had also complained about a light fever" +FD1,"50yo RHM with light fever admitted to hospital. +Tests conducted and acute kidney failure discovered. +Tests also show signes of severe diabetes, though there are no other symptoms." +FD2,"102yo LHF presented with acute seizure after long day of work. +No further complications" +FD3,"Patient is a healthy male in their 20s. +No health complications were noted." 
\ No newline at end of file diff --git a/tests/resources/regression/creation/supervised_mct_export.json b/tests/resources/regression/creation/supervised_mct_export.json new file mode 100644 index 000000000..d6cdcd828 --- /dev/null +++ b/tests/resources/regression/creation/supervised_mct_export.json @@ -0,0 +1,136 @@ +{ + "projects": [ + { + "cuis": "", + "documents": [ + { + "annotations": [ + { + "cui": "C01", + "start": 38, + "end": 52, + "value": "kidney failure" + }, + { + "cui": "C01", + "start": 122, + "end": 145, + "value": "loss of kidney function" + }, + { + "cui": "C02", + "start": 192, + "end": 200, + "value": "diabetes" + }, + { + "cui": "C02", + "start": 279, + "end": 296, + "value": "diabetes mellitus" + }, + { + "cui": "C03", + "start": 390, + "end": 395, + "value": "fever" + }, + { + "cui": "C03", + "start": 454, + "end": 470, + "value": "high temperature" + } + ], + "id": "ID-0", + "last_modified": "2024-08-21", + "name": "Doc#0", + "text": "Patient had been diagnosed with acute kidney failure the week before. The current complaint was related to the same acute loss of kidney function as the diagnosis. The patient also has severe diabetes even though they have never consumed any sugar. The prior diagnosis of severe diabetes mellitus was confirmed by doctor. Due to the previous issues, patient had been suffering from a light fever all day. They took some paracetamol but still had a light high temperature afterwards." + }, + { + "annotations": [ + { + "cui": "C04", + "start": 20, + "end": 27, + "value": "seizure" + }, + { + "cui": "C04", + "start": 81, + "end": 87, + "value": "fittest" + } + ], + "id": "ID-1", + "last_modified": "2024-08-21", + "name": "Doc#1", + "text": "Patient had a acute seizure during visit with GP. This is the first time a minor fittest was observed for this patient. 
" + }, + { + "annotations": [ + { + "cui": "C05", + "start": 26, + "end": 33, + "value": "healthy" + }, + { + "cui": "C05", + "start": 84, + "end": 91, + "value": "fittest" + } + ], + "id": "ID-2", + "last_modified": "2024-08-21", + "name": "Doc#2", + "text": "The patient is considered healthy as per tests run. The patient would be considered fittest according to any standard known." + }, + { + "annotations": [ + { + "cui": "C04", + "start": 24, + "end": 31, + "value": "seizure" + }, + { + "cui": "C04", + "start": 65, + "end": 72, + "value": "fittest" + } + ], + "id": "ID-3", + "last_modified": "2024-08-21", + "name": "Doc#3", + "text": "The patient has a minor seizure every day. The presence of daily fittest is extremely problematic." + }, + { + "annotations": [ + { + "cui": "C05", + "start": 16, + "end": 23, + "value": "healthy" + }, + { + "cui": "C05", + "start": 111, + "end": 118, + "value": "fittest" + } + ], + "id": "ID-3", + "last_modified": "2024-08-21", + "name": "Doc#4", + "text": "The RHS male is healthy as considered by all available tests. There are no indications that the patient is not fittest." 
+ } + ], + "id": "Project#0", + "name": "Project-0", + "tuis": "" + } + ] +} diff --git a/tests/resources/regression/creation/vocab_data.txt b/tests/resources/regression/creation/vocab_data.txt new file mode 100644 index 000000000..036e651a8 --- /dev/null +++ b/tests/resources/regression/creation/vocab_data.txt @@ -0,0 +1,18 @@ +severe 10000 1.0 0 0 1 0 0 0 +minor 10000 -1.0 0 0 1 0 0 0 +acute 6500 0 1.0 0 1 0 0 0 +chronic 6500 0 -1.0 0 0 1 0 0 0 +heavy 4000 0 0 1.0 1 0 0 0 +light 4000 0 0 -1.0 1 0 0 0 +considered 1000 0.1 -0.2 0 0 0.9 0 0 +with 20000 0 0 0 0 0 0.8 0 +of 22000 0 0 0 0 0 1 0 +to 19000 0 0 0 0 0 0.9 0 +were 12000 0 0 0 0 0.95 0 0 +was 11000 0 0 0 0 0.94 0 0 +is 12000 0 0 0 0 1 0 0 +are 12000 0 0 0 0 1.1 0 0 +has 11000 0 0 0 0 0.98 0 0 +presence 1000 0 0 0 0 0 0 0.4 +indication 500 0 0 0 0 0 0 0.3 +time 450 0 0 0 0 0 0 0.1 \ No newline at end of file diff --git a/tests/resources/regression/run_regression.sh b/tests/resources/regression/run_regression.sh new file mode 100644 index 000000000..0e9911434 --- /dev/null +++ b/tests/resources/regression/run_regression.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +# exit immediately upon non-zero exit status +set -e + +# create and train model and capture output +# this will create a model pack based on some data included within the tests/resources/regression/creation/ folder, +# it will then train on some self-supervised as well as supervised training data and save the model. 
+output=$(python tests/resources/regression/creation/cat_creation.py)
+# make sure the user sees the output
+echo "$output"
+
+# extract the last line of the output which contains the full model path
+model_path=$(echo "$output" | tail -n 1)
+# NOTE: this file should be tagged with the python version we're using
+
+# run the regression_checker with the captured file path
+# if any of the regression cases fail, this will return a non-zero exit status
+python -m medcat.utils.regression.regression_checker \
+    "$model_path" \
+    tests/resources/regression/testing/test_model_regresssion.yml \
+    --strictness STRICTEST \
+    --require-fully-correct
+
+# Step 4: Clean up the generated file
+rm -rf "$model_path"*
diff --git a/tests/resources/regression/testing/test_model_regresssion.yml b/tests/resources/regression/testing/test_model_regresssion.yml
new file mode 100644
index 000000000..ad1da300d
--- /dev/null
+++ b/tests/resources/regression/testing/test_model_regresssion.yml
@@ -0,0 +1,63 @@
+# this is only meant for the test "model pack" in the examples folder
+unambiguous-works: # this uses the exact same context that was used during training
+  targeting:
+    placeholders:
+      - placeholder: '[CONCEPT1]'
+        cuis: [
+          'C01', # kidney failure
+        ]
+      - placeholder: '[CONCEPT2]'
+        cuis: [
+          'C02', # diabetes
+        ]
+      - placeholder: '[CONCEPT3]'
+        cuis: [
+          'C03', # fever
+        ]
+  phrases: # The list of phrases
+    - Man was diagnosed with severe [CONCEPT1] and acute [CONCEPT2] and presented with a light [CONCEPT3]
+unambiguous-works-rnd: # these use the random word that one of the concepts WAS trained for
+  targeting:
+    placeholders:
+      - placeholder: '[CONCEPT]'
+        cuis: [
+          'C01', # kidney failure
+          'C02', # diabetes
+          'C03', # fever
+        ]
+  phrases: # The list of phrases
+    - Patient was diagnosed with severe [CONCEPT].
+    - Patient was diagnosed with acute [CONCEPT].
+    - Patient presented with light [CONCEPT].
+unambiguous-works-rnd-reverse: # these use the OPPOSITE random word that one of the concepts WAS trained for
+  targeting:
+    placeholders:
+      - placeholder: '[CONCEPT]'
+        cuis: [
+          'C01', # kidney failure
+          'C02', # diabetes
+          'C03', # fever
+        ]
+  phrases: # The list of phrases
+    - Patient was diagnosed with minor [CONCEPT].
+    - Patient was diagnosed with chronic [CONCEPT].
+    - Patient presented with heavy [CONCEPT].
+ambiguous-works-trained-1: # Uses AMBIGUOUS concepts in the trained context
+  targeting:
+    placeholders:
+      - placeholder: '[CONCEPT]'
+        cuis: [
+          'C04', # seizure/fit
+        ]
+  phrases: # The list of phrases
+    - Patient presented with acute [CONCEPT].
+    - Patient had a minor [CONCEPT] during visit.
+ambiguous-works-trained-2: # Uses AMBIGUOUS concepts in the trained context
+  targeting:
+    placeholders:
+      - placeholder: '[CONCEPT]'
+        cuis: [
+          'C05', # healthy/fit
+        ]
+  phrases: # The list of phrases
+    - Patient is a 50yo RHM considered [CONCEPT].
diff --git a/tests/utils/regression/test_checking.py b/tests/utils/regression/test_checking.py
index 16a349d48..06ca933de 100644
--- a/tests/utils/regression/test_checking.py
+++ b/tests/utils/regression/test_checking.py
@@ -1,22 +1,22 @@
-
+import os
+import json
 import unittest
-from medcat.utils.regression.targeting import FilterType, FilterStrategy, FilterOptions
-from medcat.utils.regression.targeting import TypedFilter, TranslationLayer
-from medcat.utils.regression.checking import RegressionChecker, RegressionCase
-
-DICT_WITH_CUI = {'cui': '123'}
-DICT_WITH_MULTI_CUI = {'cui': ['111', '101']}
-DICT_WITH_NAME = {'name': 'a name'}
-DICT_WITH_MULTI_NAME = {'name': ['one name', 'two name']}
-DICT_WITH_TYPE_ID = {'type_id': '443'}
-DICT_WITH_MULTI_TYPE_ID = {'type_id': ['987', '789']}
-# from python 3.6 the following _should_ remember the order of the dict items
-# which should mean that the orders in the tests are correct
-DICT_WITH_MIX_1 = dict(DICT_WITH_CUI, **DICT_WITH_NAME)
-DICT_WITH_MIX_2 = 
dict(DICT_WITH_NAME, **DICT_WITH_MULTI_TYPE_ID) -DICT_WITH_MIX_3 = dict(DICT_WITH_MULTI_NAME, **DICT_WITH_MULTI_TYPE_ID) -DICT_WITH_MIX_4 = dict(DICT_WITH_MIX_3, **DICT_WITH_MULTI_CUI) +from medcat.config import Config +from medcat.utils.regression.targeting import OptionSet, FinalTarget +from medcat.utils.regression.targeting import TranslationLayer +from medcat.utils.regression.checking import RegressionSuite, RegressionCase, MetaData +from medcat.utils.regression.results import Finding, ResultDescriptor, Strictness + +EXAMPLE_CUI = '123' +COMPLEX_PLACEHOLDERS = [ + {'placeholder': "[PH1]", + 'cuis': ['cui1', 'cui2']}, + {'placeholder': "[PH2]", + 'cuis': ['cui3', 'cui4']}, + {'placeholder': "[PH3]", + 'cuis': ['cui1', 'cui3']}, +] EXAMPLE_INFOS = [ @@ -60,6 +60,8 @@ def __init__(self, *infos) -> None: else: self.name2cuis[name] = set([cui]) pt2ch.update(dict((cui, set()) for cui in self.cui2names)) + self.cui2preferred_name = {c_cui: list(names)[0] for c_cui, names in self.cui2names.items()} + self.config = Config() class FakeCat: @@ -72,7 +74,8 @@ def get_entities(self, text, only_cui=True) -> dict: cuis = list(self.tl.name2cuis[text]) if only_cui: return {'entities': dict((i, cui) for i, cui in enumerate(cuis))} - return {'entities': dict((i, {'cui': cui, 'source_value': text}) for i, cui in enumerate(cuis))} + return {'entities': dict((i, {'cui': cui, 'source_value': text, 'start': 0, 'end': 4}) + for i, cui in enumerate(cuis))} return {} @@ -91,147 +94,17 @@ def test_TranslationLayer_works_from_non_empty_fake_CDB(self): def test_gets_all_targets(self): fakeCDB = FakeCDB(*EXAMPLE_INFOS) tl = TranslationLayer.from_CDB(fakeCDB) - targets = list(tl.all_targets([ei[0] for ei in EXAMPLE_INFOS], [ - ei[1] for ei in EXAMPLE_INFOS], [ei[2] for ei in EXAMPLE_INFOS])) + targets = [name for ei in EXAMPLE_INFOS for name in tl.get_names_of(ei[0], False)] self.assertEqual(len(targets), len(EXAMPLE_INFOS)) -_CUI = 'C123' -_NAME = 'NAMEof123' -_TYPE_ID = '-1' -_D = {'cui': 
_CUI} -_tts = TypedFilter.from_dict(_D) -_cui2names = {_CUI: [_NAME, ]} -_name2cuis = {_NAME: [_CUI, ]} -_cui2type_ids = {_CUI: [_TYPE_ID, ]} -_cui2children = {} -_tl = TranslationLayer(cui2names=_cui2names, name2cuis=_name2cuis, - cui2type_ids=_cui2type_ids, cui2children=_cui2children) - - -class TestTypedFilter(unittest.TestCase): - - def test_has_correct_target_type(self): - target_types = [FilterType.CUI, FilterType.NAME, FilterType.TYPE_ID] - for target_type in target_types: - with self.subTest(f'With target type {target_type}'): - tt = TypedFilter(type=target_type, values=[]) - self.assertEqual(tt.type, target_type) - - def check_is_correct_target(self, in_dict: dict, *types, test_with_upper_case=True): - tts = TypedFilter.from_dict(in_dict) - # should have the correct number of elements - self.assertEqual(len(tts), len(types)) - for (the_type, single_multi), tt in zip(types, tts): - with self.subTest(f'With type {the_type} and {single_multi}'): - self.assertIsInstance(tt, single_multi) - self.assertEqual(tt.type, the_type) - if test_with_upper_case: # also test upper case - upper_case_dict = dict((key.upper(), val) - for key, val in in_dict.items()) - self.check_is_correct_target( - upper_case_dict, *types, test_with_upper_case=False) - - def test_constructs_SingleTarget_from_dict_with_single_cui(self): - self.check_is_correct_target( - DICT_WITH_CUI, (FilterType.CUI, TypedFilter)) - - def test_constructs_MultiTarget_from_dict_with_multiple_cuis(self): - self.check_is_correct_target( - DICT_WITH_MULTI_CUI, (FilterType.CUI, TypedFilter)) - - def test_constructs_SingleTarget_from_dict_with_single_name(self): - self.check_is_correct_target( - DICT_WITH_NAME, (FilterType.NAME, TypedFilter)) - - def test_constructs_MultiTarget_from_dict_with_multiple_names(self): - self.check_is_correct_target( - DICT_WITH_MULTI_NAME, (FilterType.NAME, TypedFilter)) - - def test_constructs_SingleTarget_from_dict_with_single_type_id(self): - self.check_is_correct_target( - 
DICT_WITH_TYPE_ID, (FilterType.TYPE_ID, TypedFilter)) - - def test_constructs_MultiTarget_from_dict_with_multiple_type_ids(self): - self.check_is_correct_target( - DICT_WITH_MULTI_TYPE_ID, (FilterType.TYPE_ID, TypedFilter)) - - def test_constructs_correct_list_of_types_1(self): - self.check_is_correct_target(DICT_WITH_MIX_1, ( - FilterType.CUI, TypedFilter), (FilterType.NAME, TypedFilter)) - - def test_constructs_correct_list_of_types_2(self): - self.check_is_correct_target(DICT_WITH_MIX_2, ( - FilterType.NAME, TypedFilter), (FilterType.TYPE_ID, TypedFilter)) - - def test_constructs_correct_list_of_types_3(self): - self.check_is_correct_target(DICT_WITH_MIX_3, ( - FilterType.NAME, TypedFilter), (FilterType.TYPE_ID, TypedFilter)) - - def test_constructs_correct_list_of_types_4(self): - self.check_is_correct_target(DICT_WITH_MIX_4, ( - FilterType.NAME, TypedFilter), (FilterType.TYPE_ID, TypedFilter), (FilterType.CUI, TypedFilter)) - - def test_get_applicable_targets_gets_target(self): - self.assertEqual(len(_tts), 1) - tt = _tts[0] - targets = list(tt.get_applicable_targets(_tl, _tl.all_targets([ei[0] for ei in EXAMPLE_INFOS], [ - ei[1] for ei in EXAMPLE_INFOS], [ei[2] for ei in EXAMPLE_INFOS]))) - self.assertEqual(len(targets), 1) - cui, name = targets[0] - self.assertEqual(name, _NAME) - self.assertEqual(cui, _CUI) - - def test_get_applicable_targets_gets_target_from_many(self): - # add noise to existing translations - cui2names = dict( - _cui2names, **dict((f'{cui}rnd', f'{name}sss') for cui, name in _cui2names.items())) - name2cuis = dict( - _name2cuis, **dict((f'{name}sss', f'{cui}123') for cui, name in _name2cuis.items())) - cui2type_ids = dict( - _cui2type_ids, **dict((f'{cui}123', 'typeid') for cui in _cui2type_ids)) - cui2children = {} - tl = TranslationLayer(cui2names=cui2names, name2cuis=name2cuis, - cui2type_ids=cui2type_ids, cui2children=cui2children) - self.assertEqual(len(_tts), 1) - tt = _tts[0] - targets = list(tt.get_applicable_targets(tl, 
tl.all_targets([ei[0] for ei in EXAMPLE_INFOS], [ - ei[1] for ei in EXAMPLE_INFOS], [ei[2] for ei in EXAMPLE_INFOS]))) - self.assertEqual(len(targets), 1) - cui, name = targets[0] - self.assertEqual(name, _NAME) - self.assertEqual(cui, _CUI) - - -class TestFilterOptions(unittest.TestCase): - - def test_loads_from_dict(self): - D = {'strategy': 'all'} - opts = FilterOptions.from_dict(D) - self.assertIsInstance(opts, FilterOptions) - self.assertEqual(opts.strategy, FilterStrategy.ALL) - - def test_loads_from_dict_defaults_not_pref_only(self): - D = dict() - opts = FilterOptions.from_dict(D) - self.assertIsInstance(opts, FilterOptions) - self.assertFalse(opts.onlyprefnames) - - def test_loads_from_empty_dict_w_default(self): - D = dict() - opts = FilterOptions.from_dict(D) - self.assertIsInstance(opts, FilterOptions) - self.assertEqual(opts.strategy, FilterStrategy.ALL) - - def test_loads_from_dict_with_onlypref(self): - D = {'prefname-only': 'True'} - opts = FilterOptions.from_dict(D) - self.assertIsInstance(opts, FilterOptions) - self.assertTrue(opts.onlyprefnames) - - class TestRegressionCase(unittest.TestCase): - D_MIN = {'targeting': {'filters': DICT_WITH_CUI}, + D_MIN = {'targeting': { + 'placeholders': [ + { + 'placeholder': '%s', + 'cuis': [EXAMPLE_CUI], + }]}, 'phrases': ['The phrase %s works']} def _create_copy(self, d): @@ -251,8 +124,8 @@ def test_loads_from_min_dict(self): D = self.min_d rc: RegressionCase = RegressionCase.from_dict(NAME, D) self.assertIsInstance(rc, RegressionCase) - self.assertEqual(len(rc.filters), 1) - self.assertIsInstance(rc.options, FilterOptions) + self.assertEqual(len(rc.options.options), 1) + self.assertIsInstance(rc.options, OptionSet) self.assertEqual(len(rc.phrases), 1) def test_fails_dict_no_targets_1(self): @@ -265,7 +138,7 @@ def test_fails_dict_no_targets_1(self): def test_fails_dict_no_targets_2(self): NAME = 'NAME2' D = self.min_d - D['targeting'].pop('filters') + D['targeting'].pop('placeholders') with 
self.assertRaises(ValueError): RegressionCase.from_dict(NAME, D) @@ -283,151 +156,168 @@ def test_fails_with_no_phrases_2(self): with self.assertRaises(ValueError): RegressionCase.from_dict(NAME, D) - D_COMPLEX = {'targeting': dict({'filters': DICT_WITH_MIX_4}, **{'strategy': 'any', - 'prefname-only': 'true'}), 'phrases': ['The phrase %s works', 'ALL %s phrases']} + D_COMPLEX = {'targeting': {'placeholders': COMPLEX_PLACEHOLDERS}, + 'phrases': ['The phrase %s works', 'ALL %s phrases']} def test_loads_from_complex_dict(self): NAME = 'NAMEC' D = self.complex_d rc: RegressionCase = RegressionCase.from_dict(NAME, D) self.assertIsInstance(rc, RegressionCase) - self.assertEqual(len(rc.filters), 3) - self.assertIsInstance(rc.options, FilterOptions) + self.assertEqual(len(rc.options.options), 3) + self.assertIsInstance(rc.options, OptionSet) self.assertEqual(len(rc.phrases), 2) TARGET_CUI = 'C123' - D_SPECIFIC_CASE = {'targeting': {'filters': { - 'cui': [TARGET_CUI, ]}}, 'phrases': ['%s']} # should just find the name itself + D_SPECIFIC_CASE = {'targeting': {'placeholders': [{ + 'placeholder': '%s', + 'cuis': [TARGET_CUI, ]} + ]}, 'phrases': ['%s']} # should just find the name itself - def test_specific_case_CUI(self): + +class TestRegressionCaseCheckModel(unittest.TestCase): + EXPECT_MANUAL_SUCCESS = 0 + EXPECT_FAIL = 0 + FAIL_FINDINGS = (Finding.FAIL, Finding.FOUND_OTHER) + + @classmethod + def setUpClass(cls) -> None: NAME = 'NAMESC' - tl = TranslationLayer.from_CDB(FakeCDB(*EXAMPLE_INFOS)) + cls.tl = TranslationLayer.from_CDB(FakeCDB(*EXAMPLE_INFOS)) D = TestRegressionCase.D_SPECIFIC_CASE rc: RegressionCase = RegressionCase.from_dict(NAME, D) - success, fail = rc.check_case(FakeCat(tl), tl) - self.assertEqual(fail, 0) - self.assertEqual(success, len( - tl.cui2names[TestRegressionCase.TARGET_CUI])) - - TARGET_NAME = 'N223' - D_SPECIFIC_CASE_NAME = {'targeting': {'filters': { - 'name': TARGET_NAME}}, 'phrases': ['%s']} + regr_checker = RegressionSuite([rc], 
MetaData.unknown(), name="TEST SUITE 2") + cls.res = regr_checker.check_model(FakeCat(cls.tl), cls.tl) - def test_specific_case_NAME(self): - NAME = 'NAMESC2' - tl = TranslationLayer.from_CDB(FakeCDB(*EXAMPLE_INFOS)) - D = TestRegressionCase.D_SPECIFIC_CASE_NAME - rc: RegressionCase = RegressionCase.from_dict(NAME, D) - success, fail = rc.check_case(FakeCat(tl), tl) - self.assertEqual(fail, 0) + def test_specific_case_CUI(self): + fail = self.get_manual_fail() + success = self.get_manual_success() + self.assertEqual(fail, self.EXPECT_FAIL) self.assertEqual(success, len( - tl.name2cuis[TestRegressionCase.TARGET_NAME])) - - TARGET_TYPE = 'T1' - D_SPECIFIC_CASE_TYPE_ID = {'targeting': {'filters': { - 'type_id': TARGET_TYPE}}, 'phrases': ['%s']} - - def test_specific_case_TYPE_ID(self): - NAME = 'NAMESC3' - tl = TranslationLayer.from_CDB(FakeCDB(*EXAMPLE_INFOS)) - D = TestRegressionCase.D_SPECIFIC_CASE_TYPE_ID - rc: RegressionCase = RegressionCase.from_dict(NAME, D) - success, fail = rc.check_case(FakeCat(tl), tl) - self.assertEqual(fail, 0) - self.assertEqual(success, len(EXAMPLE_TYPE_T1_CUI)) - - PARENT_CUI = 'C123' - CHILD_CUI = 'C124' - D_PARENT_W_CHILDREN = {'targeting': {'filters': { - 'cui_and_children': {'cui': PARENT_CUI, 'depth': 1}}}, - 'phrases': ['%s']} - PT2CHILD = {PARENT_CUI: set([CHILD_CUI])} - - def test_cui_and_children_finds_child(self): - NAME = 'NAMEpt2ch' - cdb = FakeCDB(*EXAMPLE_INFOS) - cdb.addl_info['pt2ch'].update(self.PT2CHILD) - tl = TranslationLayer.from_CDB(cdb) - D = self.D_PARENT_W_CHILDREN - rc: RegressionCase = RegressionCase.from_dict(NAME, D) - success, fail = rc.check_case(FakeCat(tl), tl) - self.assertEqual(fail, 0) - expected = len(cdb.cui2names[self.PARENT_CUI]) + \ - len(cdb.cui2names[self.CHILD_CUI]) - self.assertEqual(success, expected) - - P_CUI = 'C123' - C_CUI1 = 'C124' - C_CUI2 = 'C223' - C_CUI1_C1 = 'C224' - C_CUI1_C1_C1 = 'C323' - C_CUI1_C1_C1_C1 = 'C324' - D_MULIT_CHILD_1 = {'targeting': {'filters': { - 
'cui_and_children': {'cui': P_CUI, 'depth': 2}}}, - 'phrases': ['%s']} - PT2CHILD_M1 = {P_CUI: set([C_CUI1, C_CUI2]), - C_CUI1: set([C_CUI1_C1]), - C_CUI1_C1: set([C_CUI1_C1_C1]), - C_CUI1_C1_C1: set([C_CUI1_C1_C1_C1])} - - def test_cui_and_children_finds_children_depth_2(self): - NAME = 'NAMEpt2ch' - cdb = FakeCDB(*EXAMPLE_INFOS) - cdb.addl_info['pt2ch'].update(self.PT2CHILD_M1) - tl = TranslationLayer.from_CDB(cdb) - D = self.D_MULIT_CHILD_1 - rc: RegressionCase = RegressionCase.from_dict(NAME, D) - success, fail = rc.check_case(FakeCat(tl), tl) - self.assertEqual(fail, 0) - expected = len(cdb.cui2names[self.P_CUI]) - # children - for child in tl.cui2children[self.P_CUI]: - expected += len(cdb.cui2names[child]) - # children of children - for child2 in tl.cui2children[child]: - expected += len(cdb.cui2names[child2]) - self.assertEqual(success, expected) - - D_MULIT_CHILD_2 = {'targeting': {'filters': { - 'cui_and_children': {'cui': P_CUI, 'depth': 3}}}, - 'phrases': ['%s']} - - def test_cui_and_children_finds_children_depth_3(self): - NAME = 'NAMEpt2ch' - cdb = FakeCDB(*EXAMPLE_INFOS) - cdb.addl_info['pt2ch'].update(self.PT2CHILD_M1) - tl = TranslationLayer.from_CDB(cdb) - D = self.D_MULIT_CHILD_2 - rc: RegressionCase = RegressionCase.from_dict(NAME, D) - success, fail = rc.check_case(FakeCat(tl), tl) - self.assertEqual(fail, 0) - expected = len(cdb.cui2names[self.P_CUI]) - # children - for child in tl.cui2children[self.P_CUI]: - expected += len(cdb.cui2names[child]) - # children of children - for child2 in tl.cui2children[child]: - expected += len(cdb.cui2names[child2]) - # children of children of children - for child3 in tl.cui2children[child2]: - expected += len(cdb.cui2names[child3]) - self.assertEqual(success, expected) - - def test_gets_with_ANY_strategy(self): - NAME = 'ANYNAME' - tl = TranslationLayer.from_CDB(FakeCDB(*EXAMPLE_INFOS)) - D = {'targeting': {'strategy': 'any', 'filters': { - 'cui': ['C123', 'C124'], 'name': ['N223', 'N224']}}, 'phrases': 
['%s']} - rc: RegressionCase = RegressionCase.from_dict(NAME, D) - success, fail = rc.check_case(FakeCat(tl), tl) - self.assertEqual(fail, 0) - expected = sum([len(tl.cui2children[cui]) for cui in D['targeting'] - ['filters']['cui']]) + len(D['targeting']['filters']['name']) - self.assertEqual(success, expected) + self.tl.cui2names[TestRegressionCase.TARGET_CUI]) + + self.EXPECT_MANUAL_SUCCESS # NOTE: manually added parts / success + ) + + def test_success_correct(self): + manual = self.get_manual_success() + report = self.res.calculate_report(strictness=Strictness.LENIENT) + self.assertEqual(report[1], manual) + + def test_fail_correct(self): + manual = self.get_manual_fail() + report = self.res.calculate_report(strictness=Strictness.LENIENT) + self.assertEqual(report[2], manual) + + def get_manual_success(self) -> int: + return sum(v for f, v in self.res.findings.items() if f not in self.FAIL_FINDINGS) + + def get_manual_fail(self) -> int: + return sum(v for f, v in self.res.findings.items() if f in self.FAIL_FINDINGS) + + +class TestRegressionCaseCheckModelJson(TestRegressionCaseCheckModel): + # that is, anything but fail or FIND_OTHER + EXPECT_MANUAL_SUCCESS = 3 + EXPECT_FAIL = 1 + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + # add a non-perfect example to show in the below + cls.res.parts[0].report(FinalTarget(placeholder='PH', cui='CUI_PARENT', + name='NAME_PARENT', + final_phrase="FINAL PHRASE"), + (Finding.FOUND_ANY_CHILD, 'CHILD')) + # add another part + added_part = ResultDescriptor(name="NAME#2") + cls.res.parts.append(added_part) + added_part.report(target=FinalTarget(placeholder='PH1', cui='CUI-CORRECT', name='NAME-correct', + final_phrase='FINAL PHRASE'), finding=(Finding.IDENTICAL, None)) + added_part.report(target=FinalTarget(placeholder='PH2', cui='CUI-PARENT', name='CHILD NAME', + final_phrase='FINAL PHRASE'), finding=(Finding.FOUND_ANY_CHILD, 'CUI=child')) + added_part.report(target=FinalTarget(placeholder='PH5', 
cui='CUI-PARENT', name='OTHER NAME', + final_phrase='FINAL PHRASE'), finding=(Finding.FOUND_OTHER, 'CUI=OTHER')) + + def test_result_is_json_serialisable(self): + rd = self.res.dict() + s = json.dumps(rd) + self.assertIsInstance(s, str) + + def test_result_is_json_serialisable_pydantic(self): + s = self.res.json() + self.assertIsInstance(s, str) + + def test_can_use_strictness(self): + e1 = [ + example for part in self.res.dict(strictness=Strictness.STRICTEST)['parts'] + for per_phrase in part['per_phrase_results'].values() + for example in per_phrase['examples'] + ] + e2 = [ + example for part in self.res.dict(strictness=Strictness.LENIENT)['parts'] + for per_phrase in part['per_phrase_results'].values() + for example in per_phrase['examples'] + ] + self.assertGreater(len(e1), len(e2)) + + def test_dict_includes_all_parts(self): + d_parts = self.res.dict()['parts'] + self.assertEqual(len(self.res.parts), len(d_parts)) class TestRegressionChecker(unittest.TestCase): - - def test_reads_default(self, yaml_file='configs/default_regression_tests.yml'): - rc = RegressionChecker.from_yaml(yaml_file) - self.assertIsInstance(rc, RegressionChecker) + YAML_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "..", + "configs", "default_regression_tests.yml") + MCT_EXPORT_PATH = os.path.join(os.path.dirname(__file__), '..', '..', + 'resources', 'medcat_trainer_export.json') + + @classmethod + def setUpClass(cls) -> None: + cls.rc = RegressionSuite.from_yaml(cls.YAML_PATH) + + def test_reads_correctly(self): + self.assertIsInstance(self.rc, RegressionSuite) + + def test_has_cases(self): + self.assertGreater(len(self.rc.cases), 0) + + +class TestRegressionCheckerFromMCTExport(TestRegressionChecker): + + @classmethod + def setUpClass(cls) -> None: + cls.rc = RegressionSuite.from_mct_export(cls.MCT_EXPORT_PATH) + + +class MultiPlaceholderTests(unittest.TestCase): + THE_DICT = { + "mulit-placeholder-case": { + 'targeting': { + 'placeholders': [ + { + 'placeholder': 
'[CONCEPT]', + 'cuis': ['C123', 'C124'] + # either has 1 name + } + ] + }, + 'phrases': [ + "This [CONCEPT] has mulitple [CONCEPT] instances of [CONCEPT]" + # 3 instances + ] + } + } + EXPECTED_CASES = 2 * 1 * 3 # 2 CUIs, 1 name each, 3 placeholders + FAKE_CDB = FakeCDB(*EXAMPLE_INFOS) + TL = TranslationLayer.from_CDB(FAKE_CDB) + + @classmethod + def setUpClass(cls) -> None: + cls.rc = RegressionSuite.from_dict(cls.THE_DICT, name="TEST SUITE 1") + + def test_reads_successfully(self): + self.assertIsInstance(self.rc, RegressionSuite) + + def test_gets_cases(self): + cases = list(self.rc.iter_subcases(self.TL)) + self.assertEqual(len(cases), self.EXPECTED_CASES) diff --git a/tests/utils/regression/test_editing.py b/tests/utils/regression/test_editing.py deleted file mode 100644 index 813b396b3..000000000 --- a/tests/utils/regression/test_editing.py +++ /dev/null @@ -1,92 +0,0 @@ - -import unittest -import yaml - -from medcat.utils.regression.checking import RegressionChecker - -from medcat.utils.regression.editing import combine_contents - - -class TestCombining(unittest.TestCase): - tests1 = """ -test-case-1: - targeting: - filters: - NAME: tcn1 - phrases: - - Some %s phrase - """.strip() - tests1_cp = """ -test-case-1: - targeting: - filters: - NAME: tcn1-cp - phrases: - - Some %s phrase - """.strip() - tests2 = """ -test-case-2: - targeting: - filters: - NAME: tcn2 - phrases: - - Some %s phrase - """.strip() - tests2_cp = """ -test-case-2: - targeting: - filters: - NAME: tcn2-cp - phrases: - - Some %s phrase - """.strip() - - def assert_simple_combination(self, one: str, two: str, combined: str = None, - expect_addition: bool = True, - check_str_len: bool = True, - ignore_identicals: bool = True) -> str: - if not combined: - combined = combine_contents( - one, two, ignore_identicals=ignore_identicals) - self.assertIsInstance(combined, str) - c1 = RegressionChecker.from_dict(yaml.safe_load(one)) - c2 = RegressionChecker.from_dict(yaml.safe_load(two)) - cc = 
RegressionChecker.from_dict(yaml.safe_load(combined)) - nc1, nc2, ncc = len(c1.cases), len(c2.cases), len(cc.cases) - if expect_addition: - self.assertEqual(ncc, nc1 + nc2) - else: - # total must be greater or equal than the max - self.assertGreaterEqual(ncc, max(nc1, nc2)) - if check_str_len: - self.assertGreater(len(combined), len(one)) - self.assertGreater(len(combined), len(two)) - # print(f'From\n{one}\nand\n{two}\nto\n{combined}') - # account for a newline in the middle - if expect_addition: - self.assertGreaterEqual(len(combined), len(one) + len(two)) - return combined - - def test_combining_makes_longer_yaml(self): - self.assert_simple_combination(self.tests1, self.tests2) - - def test_combinig_renames_similar_case(self): - self.assert_simple_combination(self.tests1, self.tests1_cp) - - def test_combining_combines(self): - # print('\n\nin adding new case\n\n') - combined = self.assert_simple_combination( - self.tests1, self.tests1, expect_addition=False, ignore_identicals=False) - cc = RegressionChecker.from_dict(yaml.safe_load(combined)) - # print('\n\nEND test_combining_combines') - self.assertEqual(len(cc.cases), 1) - self.assertEqual(len(cc.cases[0].phrases), 2) - - def test_combining_no_combine_when_ignoring_identicals(self): - # print('\n\nin adding new case\n\n') - combined = self.assert_simple_combination( - self.tests1, self.tests1, expect_addition=False, ignore_identicals=True) - cc = RegressionChecker.from_dict(yaml.safe_load(combined)) - # print('\n\nEND test_combining_combines') - self.assertEqual(len(cc.cases), 1) - self.assertEqual(len(cc.cases[0].phrases), 1) diff --git a/tests/utils/regression/test_mct_2_yml.py b/tests/utils/regression/test_mct_2_yml.py deleted file mode 100644 index 85c8f4aa1..000000000 --- a/tests/utils/regression/test_mct_2_yml.py +++ /dev/null @@ -1,218 +0,0 @@ -import json -import re -import unittest -import yaml - -from medcat.utils.regression.checking import RegressionChecker - -from 
medcat.utils.regression.converting import PerSentenceSelector, PerWordContextSelector, UniqueNamePreserver, medcat_export_json_to_regression_yml -from medcat.utils.regression.targeting import FilterType - - -class FakeTranslationLayer: - - def __init__(self, mct_export: dict) -> None: - self.mct_export = json.loads(mct_export) - - def all_targets(self, *args, **kwargs): # -> Iterator[str, str]: - for project in self.mct_export['projects']: - for doc in project['documents']: - for ann in doc['annotations']: - yield ann['cui'], ann['value'] - - -class TestUniqueNames(unittest.TestCase): - - def test_UniqueNamePreserver_first_same(self, name='some name'): - unp = UniqueNamePreserver() - uname = unp.get_unique_name(name) - self.assertEqual(name, uname) - - def test_UniqueNamePreserver_second_different(self, name='some name'): - unp = UniqueNamePreserver() - _ = unp.get_unique_name(name) - uname2 = unp.get_unique_name(name) - self.assertNotEqual(name, uname2) - - def test_UniqueNamePreserver_second_starts_with_name(self, name='some name'): - unp = UniqueNamePreserver() - _ = unp.get_unique_name(name) - uname2 = unp.get_unique_name(name) - self.assertIn(name, uname2) - self.assertTrue(uname2.startswith(name)) - - -class TestConversion(unittest.TestCase): - def_file_name = 'tests/resources/medcat_trainer_export.json' - _converted_yaml = None - _mct_export = None - - @property - def converted_yaml(self): - if not self._converted_yaml: - self._converted_yaml = medcat_export_json_to_regression_yml( - self.def_file_name) - return self._converted_yaml - - @property - def mct_export(self): - if not self._mct_export: - with open(self.def_file_name, 'r') as f: - self._mct_export = f.read() - return self._mct_export - - def test_conversion_default_gets_str(self): - self.assertIsInstance(self.converted_yaml, str) - self.assertGreater(len(self.converted_yaml), 0) - - def test_conversion_default_gets_yml(self): - d = yaml.safe_load(self.converted_yaml) - self.assertIsInstance(d, 
dict) - self.assertGreater(len(d), 0) - - def test_conversion_valid_regression_checker(self): - d = yaml.safe_load(self.converted_yaml) - checker = RegressionChecker.from_dict(d) - self.assertIsInstance(checker, RegressionChecker) - - def test_conversion_filters_for_names(self): - d = yaml.safe_load(self.converted_yaml) - checker = RegressionChecker.from_dict(d) - for case in checker.cases: - with self.subTest(f'Case {case}'): - self.assertTrue( - any(filt.type == FilterType.NAME for filt in case.filters)) - - def test_conversion_filters_for_cuis(self): - d = yaml.safe_load(self.converted_yaml) - checker = RegressionChecker.from_dict(d) - for case in checker.cases: - with self.subTest(f'Case {case}'): - self.assertTrue( - any(filt.type == FilterType.CUI for filt in case.filters)) - - def test_correct_number_of_cases(self): - checker = RegressionChecker.from_dict( - yaml.safe_load(self.converted_yaml)) - expected = self.mct_export.count('"cui":') - total_cases = 0 - for case in checker.cases: - total_cases += len(case.phrases) - self.assertEqual(total_cases, expected) - - def test_cases_have_1_replacement_part(self): - checker = RegressionChecker.from_dict( - yaml.safe_load(self.converted_yaml)) - for case, cui, name, phrase in checker.get_all_subcases(FakeTranslationLayer(self.mct_export)): - with self.subTest(f'With phrase {phrase} and {case} and {(cui, name)}'): - replaced = phrase % 'something' - self.assertIsInstance(replaced, str) - - -class TestSelectors(unittest.TestCase): - words_before = 2 - words_after = 3 - - def test_ContextSelector_able_to_remove_extra_percent(self, text='Some 1% and #TEST# ' - 'then 2%-3%, or 5%', find='#TEST#'): - found = re.search(find, text) - start, end = found.start(), found.end() - sel = PerSentenceSelector() - context = sel.get_context(text, start, end) - replaced = context % find - self.assertIsInstance(replaced, str) - self.assertIn(find, replaced) - - def test_ContextSelector_removes_precentage_example(self, text=',HISTORY 
OF PRESENT ILLNESS:, ' - 'A 48-year-old African-American male with a history of ' - 'coronary artery disease, COPD, congestive heart failure ' - 'with EF of 20%-25%, hypertension, renal insufficiency, ' - 'and recurrent episodes of hypertensive emergency, ' - 'admitted secondary to shortness of breath and ' - 'productive cough', find='episodes'): - self.test_ContextSelector_able_to_remove_extra_percent( - text=text, find=find) - - def test_PerWordContext_contains_concept(self, text='some random text with #TEST# stuff and' - ' then some more text', find='#TEST#'): - found = re.search(find, text) - start, end = found.start(), found.end() - pwcs = PerWordContextSelector(self.words_before, self.words_after) - context = pwcs.get_context(text, start, end, leave_concept=True) - self.assertIn(find, context) - - def test_PerWordContextSelector_selects_words_both_sides_plenty(self, - text='with some text here #TEST# and some text after', - find='#TEST#'): - found = re.search(find, text) - start, end = found.start(), found.end() - pwcs = PerWordContextSelector(self.words_before, self.words_after) - context = pwcs.get_context(text, start, end, leave_concept=True) - expected_words = self.words_before + \ - self.words_after + 1 # 1 for the word to be found - nr_of_original_words = len(text.split()) - nr_of_words_in_context = len(context.split()) - self.assertLessEqual(nr_of_words_in_context, nr_of_original_words) - self.assertEqual(nr_of_words_in_context, expected_words) - return context - - def test_PerWordContextSelector_selects_words_both_sides_short(self, - text='one #TEST# each', - find='#TEST#'): - found = re.search(find, text) - start, end = found.start(), found.end() - pwcs = PerWordContextSelector(self.words_before, self.words_after) - context = pwcs.get_context(text, start, end, leave_concept=True) - nr_of_original_words = len(text.split()) - expected_words = nr_of_original_words # all - nr_of_words_in_context = len(context.split()) - 
self.assertEqual(nr_of_words_in_context, expected_words) - - def test_PerWordContextSelector_no_care_sentences(self, - text='sentence ends. #TEST# here. ' - 'And more stuff', - find='#TEST#'): - context = self.test_PerWordContextSelector_selects_words_both_sides_plenty( - text, find) - self.assertIn('.', context) - - def test_PerSentenceSelector_contains_concept(self, text='other sentence ends.' - ' some random text with #TEST# stuff and' - ' then sentence ends.' - ' some more text', find='#TEST#'): - found = re.search(find, text) - start, end = found.start(), found.end() - psc = PerSentenceSelector() - context = psc.get_context(text, start, end, leave_concept=True) - self.assertIn(find, context) - - def test_PerSentenceSelector_selects_sentence_ends_long(self, text='Prev sent. Now #TEST# sentence that ends with a lot of words.' - 'And then there is more sentences. And more.', find='#TEST#'): - found = re.search(find, text) - start, end = found.start(), found.end() - psc = PerSentenceSelector() - context = psc.get_context(text, start, end, leave_concept=True) - self.assertIsNone(re.search(psc.stoppers, context)) - self.assertLessEqual(len(context), len(text)) - man_found = text[text.rfind( - '.', 0, start) + 1: text.find('.', end)].strip() - self.assertEqual(context, man_found) - - def test_PerSentenceSelector_selects_first_sent(self, text='First #TEST# sentence. That ends early.' - 'And then there is more sentences. And more.', find='#TEST#'): - found = re.search(find, text) - start, end = found.start(), found.end() - psc = PerSentenceSelector() - context = psc.get_context(text, start, end, leave_concept=True) - self.assertIn(context, text) - self.assertTrue(text.startswith(context)) - - def test_PerSentenceSelector_selects_last_sent(self, text='Firs there are sentences.' - 'And then there are more. 
Finally, we have #TEST# word', - find='#TEST#'): - found = re.search(find, text) - start, end = found.start(), found.end() - psc = PerSentenceSelector() - context = psc.get_context(text, start, end, leave_concept=True) - self.assertIn(context, text) - self.assertTrue(text.endswith(context)) diff --git a/tests/utils/regression/test_results.py b/tests/utils/regression/test_results.py index 115463e69..d8e27054a 100644 --- a/tests/utils/regression/test_results.py +++ b/tests/utils/regression/test_results.py @@ -1,119 +1,303 @@ +from typing import Optional import unittest +from copy import deepcopy +import json from medcat.utils.regression.targeting import TranslationLayer -from medcat.utils.regression.results import FailDescriptor, FailReason +from medcat.utils.regression.results import Finding, MalformedFinding +from medcat.utils.regression.results import FindingDeterminer +from medcat.utils.regression.results import SingleResultDescriptor +from medcat.utils.regression.targeting import FinalTarget +from .test_checking import FakeCDB -class TestFailReason(unittest.TestCase): - cui2names = { - 'cui1': set(['name-cui1-1', 'name-cui1-2']), - 'cui2': set(['name-cui2-1', 'name-cui2-2']), - 'cui3': set(['name-cui3-1', 'name-cui3-2', 'name-cui3-3']), - 'cui4': set(['name-cui4-1', ]), - } - # only works if one name corresponds to one CUI - name2cuis = dict([(name, set([cui])) - for cui, names in cui2names.items() for name in names]) - cui2type_ids = { - 'cui1': set(['T1', ]), - 'cui2': set(['T1', ]), - 'cui3': set(['T2', ]), - 'cui4': set(['T4', ]) + +def _determine_raw_helper(exp_start: int, exp_end: int, + start: int, end: int, + strict_only: bool = False) -> Optional[Finding]: + return FindingDeterminer("NO_MATTER", exp_start, exp_end, None, + None, strict_only=strict_only)._determine_raw(start, end) + + +class FindingRawTests(unittest.TestCase): + EXAMPLES = [ + # (exp start, exp end, start, end), expected finding + # start < exp_start + ((10, 15, 0, 1), None), + ((10, 15, 
0, 11), Finding.PARTIAL_OVERLAP), + ((10, 15, 0, 15), Finding.BIGGER_SPAN_LEFT), + ((10, 15, 0, 25), Finding.BIGGER_SPAN_BOTH), + # start == exp_start + ((10, 15, 10, 12), Finding.SMALLER_SPAN), + ((10, 15, 10, 15), Finding.IDENTICAL), + ((10, 15, 10, 25), Finding.BIGGER_SPAN_RIGHT), + # exp_start < start < exp_end + ((10, 15, 12, 13), Finding.SMALLER_SPAN), + ((10, 15, 12, 15), Finding.SMALLER_SPAN), + ((10, 15, 12, 25), Finding.PARTIAL_OVERLAP), + # exp_start >= end_end + ((10, 15, 20, 25), None), + ] + + def test_finds_correctly(self): + for args, expected in self.EXAMPLES: + with self.subTest(f"With args {args}"): + found = _determine_raw_helper(*args) + self.assertEqual(found, expected) + + def test_exception_when_improper_start_end(self): + with self.assertRaises(MalformedFinding): + _determine_raw_helper(0, 1, 10, 0) + + def test_exception_when_improper_expected_start_end(self): + with self.assertRaises(MalformedFinding): + _determine_raw_helper(10, 1, 0, 1) + + +def _get_example_ent(cui: str = "CUI1", start: int = 10, end: int = 15): + return {"cui": cui, + "start": start, + "end": end} + + +def _get_example_kwargs(cui: str = "CUI1", + exp_start: int = 10, exp_end: int = 15): + return { + "exp_cui": cui, + "exp_start": exp_start, + "exp_end": exp_end, + "check_children": True, + "check_parent": True, + "check_grandparent": True + } + + +class FindingFromEntsTests(unittest.TestCase): + EXAMPLES = [ + # identical + ({**_get_example_kwargs(), + "found_entities": {0: _get_example_ent()}}, Finding.IDENTICAL), + # bigger span + ({**_get_example_kwargs(), + "found_entities": {0: _get_example_ent(start=5)}, + }, Finding.BIGGER_SPAN_LEFT), + ({**_get_example_kwargs(), + "found_entities": {0: _get_example_ent(end=25)}, + }, Finding.BIGGER_SPAN_RIGHT), + ({**_get_example_kwargs(), + "found_entities": {0: _get_example_ent(start=5, end=25)}, + }, Finding.BIGGER_SPAN_BOTH), + # smaller span + ({**_get_example_kwargs(), + "found_entities": {0: _get_example_ent(end=13)}, + 
}, Finding.SMALLER_SPAN), + ({**_get_example_kwargs(), + "found_entities": {0: _get_example_ent(end=13)}, + }, Finding.SMALLER_SPAN), + ({**_get_example_kwargs(), + "found_entities": {0: _get_example_ent(start=11, end=13)}, + }, Finding.SMALLER_SPAN), + # overlapping span + ({**_get_example_kwargs(), + "found_entities": {0: _get_example_ent(start=5, end=12)}, + }, Finding.PARTIAL_OVERLAP), + ({**_get_example_kwargs(), + "found_entities": {0: _get_example_ent(start=12, end=122)}, + }, Finding.PARTIAL_OVERLAP), + # identical with some noise start + ({**_get_example_kwargs(), + "found_entities": {0: _get_example_ent(start=0, end=5), + 1: _get_example_ent()}, + }, Finding.IDENTICAL), + # identical with some noise end + ({**_get_example_kwargs(), + "found_entities": {0: _get_example_ent(), + 1: _get_example_ent(start=20, end=25)}, + }, Finding.IDENTICAL), + # identical with some noise both sides + ({**_get_example_kwargs(), + "found_entities": {0: _get_example_ent(start=0, end=5), + 1: _get_example_ent(), + 2: _get_example_ent(start=20, end=25)}, + }, Finding.IDENTICAL), + # start from example 12 + # FAILURES + ({**_get_example_kwargs(), + "found_entities": {0: _get_example_ent(cui="CUI2")}, + }, Finding.FOUND_OTHER), + ({**_get_example_kwargs(), + "found_entities": {0: _get_example_ent(start=0, end=5)}, + }, Finding.FAIL), + ({**_get_example_kwargs(), + "found_entities": {0: _get_example_ent(start=20, end=25)}, + }, Finding.FAIL), + ({**_get_example_kwargs(), + "found_entities": {}, + }, Finding.FAIL), + ] + NR_OF_EXAMPLES = len(EXAMPLES) + TL = TranslationLayer.from_CDB(FakeCDB()) + + def test_finds_examples(self): + self.assertEqual(len(self.EXAMPLES), self.NR_OF_EXAMPLES) + for nr, (ekwargs, expected) in enumerate(self.EXAMPLES): + with self.subTest(f"With [{nr}] kwargs {ekwargs}"): + found, _ = Finding.determine(tl=self.TL, **ekwargs) + self.assertEqual(found, expected) + + +EXAMPLE_INFOS = [ + ['CGP', 'NGP', 'T1'], # the grandparent + # CUI, NAME, TYPE_ID + 
['C123', 'N123', 'T1'], + ['C124', 'N124', 'T1'], + ['C223', 'N223', 'T2'], + ['C224', 'N224', 'T2'], + # non-unique name + ['C323', 'N123', 'T3'], + ['C324', 'N124', 'T3'], +] + + +class FindingFromEntsWithChildrenTests(unittest.TestCase): + FAKE_CDB = FakeCDB(*EXAMPLE_INFOS) + TL = TranslationLayer.from_CDB(FAKE_CDB) + THE_GRANPARENT = 'CGP' + THE_PARENT = "C123" + THE_CHILD = "C124" + PT2CHILD = { + THE_GRANPARENT: {THE_PARENT}, + THE_PARENT: {THE_CHILD} } - cui2children = {} # none for now - tl = TranslationLayer(cui2names, name2cuis, cui2type_ids, cui2children) - - def test_cui_not_found(self, cui='cui-100', name='random n4m3'): - fr = FailDescriptor.get_reason_for(cui, name, {}, self.tl) - self.assertIs(fr.reason, FailReason.CUI_NOT_FOUND) - - def test_cui_name_found(self, cui='cui1', name='random n4m3-not-there'): - fr = FailDescriptor.get_reason_for(cui, name, {}, self.tl) - self.assertIs(fr.reason, FailReason.NAME_NOT_FOUND) - - -class TestFailReasonWithResultAndChildren(TestFailReason): - res_w_cui1 = {'entities': { - # cui1 - 1: {'source_value': list(TestFailReason.cui2names['cui1'])[0], 'cui': 'cui1'}, - }} - res_w_cui2 = {'entities': { - # cui2 - 1: {'source_value': list(TestFailReason.cui2names['cui2'])[0], 'cui': 'cui2'}, - }} - res_w_both = {'entities': { - # cui1 - 1: {'source_value': list(TestFailReason.cui2names['cui1'])[0], 'cui': 'cui1'}, - # cui2 - 2: {'source_value': list(TestFailReason.cui2names['cui2'])[0], 'cui': 'cui2'}, - }} - cui2children = {'cui1': set(['cui2'])} - tl = TranslationLayer(TestFailReason.cui2names, TestFailReason.name2cuis, - TestFailReason.cui2type_ids, cui2children) - - def test_found_child(self, cui='cui1', name='name-cui1-2'): - fr = FailDescriptor.get_reason_for(cui, name, self.res_w_cui2, self.tl) - self.assertIs(fr.reason, FailReason.CUI_CHILD_FOUND) - - def test_found_parent(self, cui='cui2', name='name-cui2-1'): - fr = FailDescriptor.get_reason_for(cui, name, self.res_w_cui1, self.tl) - self.assertIs(fr.reason, 
FailReason.CUI_PARENT_FOUND) - - -class TestFailReasonWithSpanningConcepts(unittest.TestCase): - cui2names = { - 'cui1': ('shallow', 'shallow2'), - 'cui1.1': ('broader shallow', 'broader shallow2'), - 'cui1.1.1': ('even broader shallow', 'even broader shallow2'), - 'cui2': ('name-2', ), + CHILD_MAPPED_EXACT_SPAN = {**_get_example_kwargs(cui=THE_PARENT), + "found_entities": {0: _get_example_ent(cui=THE_CHILD)}} + GRANDCHILD_MAPPED_EXACT_SPAN = { + **_get_example_kwargs(cui=THE_GRANPARENT), + "found_entities": {0: _get_example_ent(cui=THE_CHILD)}} + CHILD_MAPPED_PARTIAL_SAPN1 = {**_get_example_kwargs(cui=THE_PARENT), + "found_entities": {0: _get_example_ent(cui=THE_CHILD, start=5, end=14)}} + CHILD_MAPPED_PARTIAL_SAPN2 = {**_get_example_kwargs(cui=THE_PARENT), + "found_entities": {0: _get_example_ent(cui=THE_CHILD, start=5, end=15)}} + CHILD_MAPPED_PARTIAL_SAPN3 = {**_get_example_kwargs(cui=THE_PARENT), + "found_entities": {0: _get_example_ent(cui=THE_CHILD, start=5, end=20)}} + CHILD_MAPPED_PARTIAL_SAPN4 = {**_get_example_kwargs(cui=THE_PARENT), + "found_entities": {0: _get_example_ent(cui=THE_CHILD, start=10, end=14)}} + CHILD_MAPPED_PARTIAL_SAPN5 = {**_get_example_kwargs(cui=THE_PARENT), + "found_entities": {0: _get_example_ent(cui=THE_CHILD, start=10, end=20)}} + CHILD_MAPPED_PARTIAL_SAPN6 = {**_get_example_kwargs(cui=THE_PARENT), + "found_entities": {0: _get_example_ent(cui=THE_CHILD, start=11, end=14)}} + CHILD_MAPPED_PARTIAL_SAPN7 = {**_get_example_kwargs(cui=THE_PARENT), + "found_entities": {0: _get_example_ent(cui=THE_CHILD, start=11, end=15)}} + CHILD_MAPPED_PARTIAL_SAPN8 = {**_get_example_kwargs(cui=THE_PARENT), + "found_entities": {0: _get_example_ent(cui=THE_CHILD, start=11, end=20)}} + PARTIAL_CHILDREN = [ + CHILD_MAPPED_PARTIAL_SAPN1, CHILD_MAPPED_PARTIAL_SAPN2, CHILD_MAPPED_PARTIAL_SAPN3, + CHILD_MAPPED_PARTIAL_SAPN4, CHILD_MAPPED_PARTIAL_SAPN5, CHILD_MAPPED_PARTIAL_SAPN6, + CHILD_MAPPED_PARTIAL_SAPN7, CHILD_MAPPED_PARTIAL_SAPN8 + ] + 
PARTIAL_GRANDCHILDREN = [ + {**d, "exp_cui": 'CGP'} for d in deepcopy(PARTIAL_CHILDREN)] + PARENT_MAPPED_EXACT_SPAN = { + **_get_example_kwargs(cui=THE_CHILD), + "found_entities": {0: _get_example_ent(cui=THE_PARENT)} } - # only works if one name corresponds to one CUI - name2cuis = dict([(name, set([cui])) - for cui, names in cui2names.items() for name in names]) - cui2type_ids = { - 'cui1': set(['T1', ]), - 'cui1.1': set(['T1', ]), - 'cui1.1.1': set(['T1', ]) + GRANDPARENT_MAPPED_EXACT_SPAN = { + **_get_example_kwargs(cui=THE_CHILD), + "found_entities": {0: _get_example_ent(cui=THE_GRANPARENT)} } - cui2children = {} # none for now - tl = TranslationLayer(cui2names, name2cuis, cui2type_ids, cui2children) - - res_w_cui1 = {'entities': { - # cui1 - 1: {'source_value': list(cui2names['cui1'])[0], 'cui': 'cui1'}, - }} - - res_w_cui11 = {'entities': { - # cui1.1 - 1: {'source_value': list(cui2names['cui1.1'])[0], 'cui': 'cui1.1'}, - }} - - res_w_cui111 = {'entities': { - # cui1.1.1 - 1: {'source_value': list(cui2names['cui1.1.1'])[0], 'cui': 'cui1.1.1'}, - }} - res_w_all = {'entities': dict([(nr, d['entities'][1]) for nr, d in enumerate([ - res_w_cui1, res_w_cui11, res_w_cui111])])} - - def test_gets_incorrect_span_big(self, cui='cui1', name='shallow'): - fr = FailDescriptor.get_reason_for( - cui, name, self.res_w_cui11, self.tl) - self.assertIs(fr.reason, FailReason.INCORRECT_SPAN_BIG) - - def test_gets_incorrect_span_bigger(self, cui='cui1', name='shallow'): - fr = FailDescriptor.get_reason_for( - cui, name, self.res_w_cui111, self.tl) - self.assertIs(fr.reason, FailReason.INCORRECT_SPAN_BIG) - - def test_gets_incorrect_span_small(self, cui='cui1.1', name='broader shallow'): - fr = FailDescriptor.get_reason_for(cui, name, self.res_w_cui1, self.tl) - self.assertIs(fr.reason, FailReason.INCORRECT_SPAN_SMALL) # HERE - - def test_gets_incorrect_span_smaller(self, cui='cui1.1.1', name='even broader shallow'): - fr = FailDescriptor.get_reason_for(cui, name, 
self.res_w_cui1, self.tl) - self.assertIs(fr.reason, FailReason.INCORRECT_SPAN_SMALL) # and HERE - - def test_gets_not_annotated(self, cui='cui2', name='name-2'): - fr = FailDescriptor.get_reason_for(cui, name, self.res_w_all, self.tl) - self.assertIs(fr.reason, FailReason.CONCEPT_NOT_ANNOTATED) + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + cls.FAKE_CDB.addl_info['pt2ch'].update(cls.PT2CHILD) + + def test_finds_child_exact_span(self): + finding, optcui = Finding.determine(tl=self.TL, **self.CHILD_MAPPED_EXACT_SPAN) + self.assertIs(finding, Finding.FOUND_ANY_CHILD) + self.assertIsNotNone(optcui) + self.assertTrue(optcui.startswith(self.THE_CHILD)) + + def test_finds_grandchild_exact_span(self): + finding, optcui = Finding.determine(tl=self.TL, **self.GRANDCHILD_MAPPED_EXACT_SPAN) + self.assertIs(finding, Finding.FOUND_ANY_CHILD) + self.assertIsNotNone(optcui) + self.assertTrue(optcui.startswith(self.THE_CHILD)) + + def test_finds_child_partial_span(self): + for nr, ekwargs in enumerate(self.PARTIAL_CHILDREN): + with self.subTest(f"{nr}: {ekwargs}"): + finding, optcui = Finding.determine(tl=self.TL, **ekwargs) + self.assertIs(finding, Finding.FOUND_CHILD_PARTIAL) + self.assertIsNotNone(optcui) + self.assertTrue(optcui.startswith(self.THE_CHILD)) + + def test_finds_grandchild_partial_span(self): + for nr, ekwargs in enumerate(self.PARTIAL_GRANDCHILDREN): + with self.subTest(f"{nr}: {ekwargs}"): + finding, optcui = Finding.determine(tl=self.TL, **ekwargs) + self.assertIs(finding, Finding.FOUND_CHILD_PARTIAL) + self.assertIsNotNone(optcui) + self.assertTrue(optcui.startswith(self.THE_CHILD)) + + def test_finds_parent_exact_span(self): + finding, parcui = Finding.determine(tl=self.TL, **self.PARENT_MAPPED_EXACT_SPAN) + self.assertIs(finding, Finding.FOUND_DIR_PARENT) + self.assertTrue(parcui.startswith(self.THE_PARENT)) # NOTE: also has the preferred name + + def test_finds_grandparent_exact_span(self): + finding, parcui = 
Finding.determine(tl=self.TL, **self.GRANDPARENT_MAPPED_EXACT_SPAN) + self.assertIs(finding, Finding.FOUND_DIR_GRANDPARENT) + self.assertTrue(parcui.startswith(self.THE_GRANPARENT)) # NOTE: also has the preferred name + + +class FindingFromEntsStrictTests(FindingFromEntsTests): + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + cls.EXAMPLES = [ + ({**e_kwargs, 'strict_only': True}, e_exp) for e_kwargs, e_exp in cls.EXAMPLES.copy() + if e_exp in (Finding.IDENTICAL, Finding.FAIL) + ] + cls.NR_OF_EXAMPLES = len(cls.EXAMPLES) + cls.FAIL_EXAMPLES = [ + ({**e_kwargs, 'strict_only': True}, e_exp) for e_kwargs, e_exp in cls.EXAMPLES.copy() + if e_exp not in (Finding.IDENTICAL, Finding.FAIL) + ] + + def test_fails_on_non_identical_or_fail_in_strict_mode(self): + for nr, (ekwargs, _) in enumerate(self.FAIL_EXAMPLES): + with self.subTest(f"With [{nr}] kwargs {ekwargs}"): + found, optcui = Finding.determine(tl=self.TL, **ekwargs) + self.assertEqual(found, Finding.FAIL) + self.assertIsNotNone(optcui) + + +class SingleResultDescriptorSerialisationTests(unittest.TestCase): + + @classmethod + def setUpClass(cls) -> None: + e1 = (FinalTarget(placeholder='$', cui='CUI1', name='NAME1', final_phrase='FINAL PHRASE'), + (Finding.FOUND_OTHER, 'OTHER CUI')) + e2 = (FinalTarget(placeholder='$', cui='CUIP', name='PARENT', final_phrase='FINAL PHRASE'), + (Finding.FOUND_ANY_CHILD, 'CUI_C (CHILD)')) + findings = {Finding.FOUND_OTHER: 1, Finding.FOUND_ANY_CHILD: 1} + cls.rd = SingleResultDescriptor(name="RANDOM_NAME", findings=findings, + examples=[e1, e2]) + + def test_can_json_dump_pydantic(self): + s = self.rd.json() + self.assertIsInstance(s, str) + + def test_can_json_dump_json(self): + s = json.dumps(self.rd.dict()) + self.assertIsInstance(s, str) + + def test_can_use_strictness_for_dump(self): + d_strictest = self.rd.dict(strictness='STRICTEST') + e_strictest = d_strictest['examples'] + # this should have more examples + d_lenient = self.rd.dict(strictness='NORMAL') 
+ e_normal = d_lenient['examples'] + self.assertGreater(len(e_strictest), len(e_normal)) diff --git a/tests/utils/regression/test_separation.py b/tests/utils/regression/test_separation.py deleted file mode 100644 index d3c16f0b0..000000000 --- a/tests/utils/regression/test_separation.py +++ /dev/null @@ -1,533 +0,0 @@ -import os -from typing import Iterator, cast -import yaml -from functools import lru_cache -import tempfile - -from medcat.utils.regression.checking import RegressionCase, ResultDescriptor, FilterOptions, FilterStrategy, TypedFilter, FilterType -from medcat.utils.regression.checking import RegressionChecker -from medcat.utils.regression.converting import medcat_export_json_to_regression_yml -from medcat.utils.regression.category_separation import CategoryDescription, Category, AllPartsCategory, AnyPartOfCategory -from medcat.utils.regression.category_separation import SeparationObserver, SeparateToFirst, SeparateToAll, read_categories -from medcat.utils.regression.category_separation import RegressionCheckerSeparator, separate_categories, StrategyType -from medcat.utils.regression.editing import combine_yamls - -import unittest - - -class CategoryDescriptionTests(unittest.TestCase): - CUIS = ['c123', 'c111'] - NAMES = ['NAME1', 'NAME9'] - TUIS = ['T-1', 'T-10'] - - def setUp(self) -> None: - self.cd = CategoryDescription( - target_cuis=set(self.CUIS), target_names=set(self.NAMES), target_tuis=set(self.TUIS)) - self.anything = CategoryDescription.anything_goes() - - def test_initiates(self): - self.assertIsNotNone(self.cd) - - def get_case_for(self, cui=None, name=None, tui=None) -> RegressionCase: - cname = f'TEMPNAME={cui}-{name}-{tui}' - cphrase = 'does not matter %s' - fo = FilterOptions(strategy=FilterStrategy.ANY) - if cui: - ft = FilterType.CUI - value = cui - elif name: - ft = FilterType.NAME - value = name - elif tui: - ft = FilterType.TYPE_ID - value = tui - else: - raise ValueError( - f"Unknown filter for CUI: {cui} NAME: {name} and TUI: 
{tui}") - cfilter = TypedFilter(type=ft, values=[value]) - return RegressionCase(name=cname, options=fo, filters=[cfilter], phrases=[cphrase], report=ResultDescriptor(name=cname)) - - def helper_recognizes(self, items: list, case_kw: str, method: callable): - for item in items: - with self.subTest(f'With item {item}, testing {case_kw} and {method} for RECOGNIZES'): - self.assertTrue(method(self.get_case_for(**{case_kw: item}))) - - def helper_does_not_recognize(self, items: list, case_kw: str, method: callable): - for item in items: - with self.subTest(f'With item {item}, testing {case_kw} and {method} for NOT RECOGNIZES'): - self.assertFalse(method(self.get_case_for(**{case_kw: item}))) - - def test_recognizes_CUIS(self): - self.helper_recognizes(self.CUIS, 'cui', self.cd.has_cui_from) - - def test_does_NOT_recognize_wrong_CUIS(self): - self.helper_does_not_recognize(self.NAMES, 'cui', self.cd.has_cui_from) - - def test_recognizes_NAMES(self): - self.helper_recognizes(self.NAMES, 'name', self.cd.has_name_from) - - def test_does_NOT_recognize_wrong_NAMES(self): - self.helper_does_not_recognize( - self.CUIS, 'name', self.cd.has_name_from) - - def test_recognizes_TUIS(self): - self.helper_recognizes(self.TUIS, 'tui', self.cd.has_tui_from) - - def test_does_NOT_recognize_wrong_TUIS(self): - self.helper_does_not_recognize(self.NAMES, 'tui', self.cd.has_tui_from) - - def test_anythong_goes_recognizes_anything_cui4cui(self): - self.helper_recognizes(self.CUIS, 'cui', self.anything.has_cui_from) - - def test_anythong_goes_recognizes_anything_tui4cui(self): - self.helper_recognizes(self.TUIS, 'cui', self.anything.has_cui_from) - - def test_anythong_goes_recognizes_anything_name4cui(self): - self.helper_recognizes(self.NAMES, 'cui', self.anything.has_cui_from) - - def test_anythong_goes_recognizes_anything_tui4tui(self): - self.helper_recognizes(self.TUIS, 'tui', self.anything.has_tui_from) - - def test_anythong_goes_recognizes_anything_cui4tui(self): - 
self.helper_recognizes(self.CUIS, 'tui', self.anything.has_tui_from) - - def test_anythong_goes_recognizes_anything_name4tui(self): - self.helper_recognizes(self.NAMES, 'tui', self.anything.has_tui_from) - - def test_anythong_goes_recognizes_anything_name4name(self): - self.helper_recognizes(self.NAMES, 'name', self.anything.has_name_from) - - def test_anythong_goes_recognizes_anything_cui4name(self): - self.helper_recognizes(self.CUIS, 'name', self.anything.has_name_from) - - def test_anythong_goes_recognizes_anything_tui4name(self): - self.helper_recognizes(self.TUIS, 'name', self.anything.has_name_from) - - -def get_case(cui, tui, name): - if cui: - cui_filter = TypedFilter(type=FilterType.CUI, values=[cui]) - else: - cui_filter = None - if tui: - tui_filter = TypedFilter(type=FilterType.TYPE_ID, values=[tui]) - else: - tui_filter = None - if name: - name_filter = TypedFilter(type=FilterType.NAME, values=[name]) - else: - name_filter = None - fo = FilterOptions(strategy=FilterStrategy.ALL) - cphrase = 'Phrase does not matter %s' - filters = [cui_filter, tui_filter, name_filter] - filters = [f for f in filters if f is not None] - return RegressionCase(name=f'rc w/ cui: {cui}, tui: {tui}, name: {name}', options=fo, - filters=filters, phrases=[cphrase], report=ResultDescriptor(name='TestRD')) - - -class AllPartsCategoryTests(unittest.TestCase): - - def setUp(self) -> None: - cdt = CategoryDescriptionTests() - cdt.setUp() - self.cat = AllPartsCategory('ALL=parts', cdt.cd) - - def test_initializes(self): - self.assertIsNotNone(self.cat) - - def test_recognizes_correct(self): - for cui in CategoryDescriptionTests.CUIS: - for tui in CategoryDescriptionTests.TUIS: - for name in CategoryDescriptionTests.NAMES: - with self.subTest(f'cui: {cui}, tui: {tui}, name: {name}'): - case = get_case(cui, tui, name) - self.assertTrue(self.cat.fits(case)) - - def helper_does_NOT_recognize_one_at_time_3in1(self, items: list): - for item in items: - with self.subTest(f'ITEM: {item} (as 
CUI, TUI, and name)'): - case = get_case(item, item, item) - self.assertFalse(self.cat.fits(case)) - - def test_does_NOT_recognize_one_at_time_CUI_3in1(self): - self.helper_does_NOT_recognize_one_at_time_3in1( - CategoryDescriptionTests.CUIS) - - def test_does_NOT_recognize_one_at_time_NAME_3in1(self): - self.helper_does_NOT_recognize_one_at_time_3in1( - CategoryDescriptionTests.NAMES) - - def test_does_NOT_recognize_one_at_time_TUI_3in1(self): - self.helper_does_NOT_recognize_one_at_time_3in1( - CategoryDescriptionTests.TUIS) - - def helper_does_NOT_recognize_one_at_time_just1(self, items: list, order: int): - args = [None, None, None] - for item in items: - with self.subTest(f'ITEM: {item} (as CUI, TUI, OR name)'): - args[order] = item - case = get_case(*args) - self.assertFalse(self.cat.fits(case)) - - def test_does_NOT_recognize_one_at_time_CUI_just1(self): - self.helper_does_NOT_recognize_one_at_time_just1( - CategoryDescriptionTests.CUIS, 0) - - def test_does_NOT_recognize_one_at_time_CUI_just1_wrong_type1(self): - self.helper_does_NOT_recognize_one_at_time_just1( - CategoryDescriptionTests.CUIS, 1) - - def test_does_NOT_recognize_one_at_time_CUI_just1_wrong_type2(self): - self.helper_does_NOT_recognize_one_at_time_just1( - CategoryDescriptionTests.CUIS, 2) - - def test_does_NOT_recognize_one_at_time_NAME_just1(self): - self.helper_does_NOT_recognize_one_at_time_just1( - CategoryDescriptionTests.NAMES, 2) - - def test_does_NOT_recognize_one_at_time_NAME_just1_wrong_type1(self): - self.helper_does_NOT_recognize_one_at_time_just1( - CategoryDescriptionTests.NAMES, 0) - - def test_does_NOT_recognize_one_at_time_NAME_just1_wrong_type2(self): - self.helper_does_NOT_recognize_one_at_time_just1( - CategoryDescriptionTests.NAMES, 1) - - def test_does_NOT_recognize_one_at_time_TUI_just1(self): - self.helper_does_NOT_recognize_one_at_time_just1( - CategoryDescriptionTests.TUIS, 1) - - def test_does_NOT_recognize_one_at_time_TUI_just1_wrong_type1(self): - 
self.helper_does_NOT_recognize_one_at_time_just1( - CategoryDescriptionTests.TUIS, 0) - - def test_does_NOT_recognize_one_at_time_TUI_just1_wrong_type2(self): - self.helper_does_NOT_recognize_one_at_time_just1( - CategoryDescriptionTests.TUIS, 2) - - -class AnyPartOfCategoryTests(unittest.TestCase): - - def setUp(self) -> None: - cdt = CategoryDescriptionTests() - cdt.setUp() - self.cat = AnyPartOfCategory('ANY=parts', cdt.cd) - - def test_init(self): - self.assertIsNotNone(self.cat) - - def helper_recognizes_any(self, items: list, order: int): - args = [None, None, None] - for item in items: - with self.subTest(f'Testing {item} as {["CUI", "TUI", "NAME"][order]}'): - args[order] = item - case = get_case(*args) - self.assertTrue(self.cat.fits(case)) - - def helper_recognizes_any_2(self, items1: list, order1: int, items2, order2: int): - args = [None, None, None] - for item1 in items1: - for item2 in items2: - with self.subTest(f'Testing {item1} and {item2} as {["CUI", "TUI", "NAME"][order1]} and ' - '{["CUI", "TUI", "NAME"][order2]}, respectively'): - args[order1] = item1 - args[order2] = item2 - case = get_case(*args) - self.assertTrue(self.cat.fits(case)) - - def test_recognizes_any_cui(self): - self.helper_recognizes_any(CategoryDescriptionTests.CUIS, 0) - - def test_recognizes_any_tui(self): - self.helper_recognizes_any(CategoryDescriptionTests.TUIS, 1) - - def test_recognizes_any_name(self): - self.helper_recognizes_any(CategoryDescriptionTests.NAMES, 2) - - def test_recognizes_combinations_of_2(self): - self.helper_recognizes_any_2( - CategoryDescriptionTests.CUIS, 0, CategoryDescriptionTests.TUIS, 1) - self.helper_recognizes_any_2( - CategoryDescriptionTests.CUIS, 0, CategoryDescriptionTests.NAMES, 2) - self.helper_recognizes_any_2( - CategoryDescriptionTests.TUIS, 1, CategoryDescriptionTests.NAMES, 2) - - def test_recognizes_combinations_of_3(self): - for cui in CategoryDescriptionTests.CUIS: - for tui in CategoryDescriptionTests.TUIS: - for name in 
CategoryDescriptionTests.NAMES: - with self.subTest(f'cui: {cui}, tui: {tui}, name: {name}'): - case = get_case(cui, tui, name) - self.assertTrue(self.cat.fits(case)) - - -def get_all_cases() -> Iterator[RegressionCase]: - for cui in CategoryDescriptionTests.CUIS: - for tui in CategoryDescriptionTests.TUIS: - for name in CategoryDescriptionTests.NAMES: - all_args = [cui, tui, name] - # unique combinations of 1 or 2 None's - for nr in range(1, 2**3): # ignore 0 - cur_args = [(arg if (nr >> i) & 1 else None) - for i, arg in enumerate(all_args)] - yield get_case(*cur_args) - - -class SeparationObserverTests(unittest.TestCase): - - def setUp(self) -> None: - self.observer = SeparationObserver() - apct = AnyPartOfCategoryTests() - apct.setUp() - self.cat = apct.cat - - def test_init(self): - self.assertIsNotNone(self.observer) - - def test_new_has_not_observed(self): - for case in get_all_cases(): - with self.subTest(f'CASE: {case}'): - self.assertFalse(self.observer.has_observed(case)) - - def test_observes(self): - for case in get_all_cases(): - with self.subTest(f'CASE: {case}'): - self.observer.observe(case, category=self.cat) - self.assertTrue(self.observer.has_observed(case)) - - -TEST_CATEGORIES_FILE = os.path.join( - 'tests', 'resources', 'test_categories.yml') - - -def get_all_categories() -> Iterator[Category]: - return read_categories(TEST_CATEGORIES_FILE) - - -class SeparateToFirstTests(unittest.TestCase): - - def setUp(self) -> None: - sot = SeparationObserverTests() - sot.setUp() - self.strat = SeparateToFirst(observer=sot.observer) - - def test_init(self): - self.assertIsNotNone(self.strat) - - def test_new_can_separates(self): - for case in get_all_cases(): - with self.subTest(f'CASE: {case}'): - self.assertTrue(self.strat.can_separate(case)) - - def test_separates_cases_with_cat_cui(self): - for cat in get_all_categories(): - cat = cast(AnyPartOfCategory, cat) - for cui in cat.description.target_cuis: - case = get_case(cui, None, None) - with 
self.subTest(f'CASE: {case} and CATEGORY {cat}'): - self.assertTrue(self.strat.can_separate(case)) - self.strat.separate(case, cat) - self.assertTrue(self.strat.observer.has_observed(case)) - - def test_can_not_separate_cases_with_after_initial_separation(self): - self.test_separates_cases_with_cat_cui() # do initial separation - for cat in get_all_categories(): - cat = cast(AnyPartOfCategory, cat) - for cui in cat.description.target_cuis: - case = get_case(cui, None, None) - with self.subTest(f'CASE: {case} and CATEGORY {cat}'): - self.assertFalse(self.strat.can_separate(case)) - - -class SeparateToAllTests(unittest.TestCase): - - def setUp(self) -> None: - sot = SeparationObserverTests() - sot.setUp() - self.strat = SeparateToAll(observer=sot.observer) - - def test_init(self): - self.assertIsNotNone(self.strat) - - def test_new_can_separates(self): - for case in get_all_cases(): - with self.subTest(f'CASE: {case}'): - self.assertTrue(self.strat.can_separate(case)) - - def test_separates_cases_with_cat_cui(self): - for cat in get_all_categories(): - cat = cast(AnyPartOfCategory, cat) - for cui in cat.description.target_cuis: - case = get_case(cui, None, None) - with self.subTest(f'CASE: {case} and CATEGORY {cat}'): - self.assertTrue(self.strat.can_separate(case)) - self.strat.separate(case, cat) - self.assertTrue(self.strat.observer.has_observed(case)) - - def test_can_separate_cases_with_after_initial_separation(self): - self.test_separates_cases_with_cat_cui() # do initial separation - for cat in get_all_categories(): - cat = cast(AnyPartOfCategory, cat) - for cui in cat.description.target_cuis: - case = get_case(cui, None, None) - with self.subTest(f'CASE: {case} and CATEGORY {cat}'): - self.assertTrue(self.strat.can_separate(case)) - - -TEST_MCT_EXPORT_JSON_FILE = os.path.join("tests", "resources", - "medcat_trainer_export.json") - - -@lru_cache -def get_real_checker() -> RegressionChecker: - yaml_str = 
medcat_export_json_to_regression_yml(TEST_MCT_EXPORT_JSON_FILE) - d = yaml.safe_load(yaml_str) - return RegressionChecker.from_dict(d) - - -@lru_cache -def get_all_real_cases() -> Iterator[RegressionCase]: - rc = get_real_checker() - for case in rc.cases: - yield case - - -class RegressionCheckerSeparator_toFirst_Tests(unittest.TestCase): - - def setUp(self) -> None: - observer = SeparationObserver() - strat = SeparateToFirst(observer) - self.separator = RegressionCheckerSeparator( - categories=list(get_all_categories()), strategy=strat) - - def test_init(self): - self.assertIsNotNone(self.separator) - - def test_finds_categories(self): - for case in get_all_real_cases(): - with self.subTest(f'CASE: {case} and {self.separator}'): - self.separator.find_categories_for(case) - self.assertTrue( - self.separator.strategy.observer.has_observed(case)) - - def test_nr_of_cases_remains_same(self): - nr_of_total_cases = len(list(get_all_real_cases())) - separated_cases = 0 - self.test_finds_categories() - for cases in self.separator.strategy.observer.separated.values(): - separated_cases += len(cases) - self.assertEqual(nr_of_total_cases, separated_cases) - - -class RegressionCheckerSeparator_toAll_Tests(unittest.TestCase): - - def setUp(self) -> None: - stat = SeparateToAllTests() - stat.setUp() - self.separator = RegressionCheckerSeparator( - categories=list(get_all_categories()), strategy=stat.strat) - - def test_init(self): - self.assertIsNotNone(self.separator) - - def test_finds_categories(self): - for case in get_all_real_cases(): - with self.subTest(f'CASE: {case}'): - self.separator.find_categories_for(case) - self.assertTrue( - self.separator.strategy.observer.has_observed(case)) - - def test_nr_of_cases_remains_same_or_greater(self): - nr_of_total_cases = len(list(get_all_real_cases())) - separated_cases = 0 - self.test_finds_categories() - for cases in self.separator.strategy.observer.separated.values(): - separated_cases += len(cases) - 
self.assertGreaterEqual(nr_of_total_cases, separated_cases) - - -def get_applicable_files_in(folder: str, avoid_basename_start: str = 'converted') -> list: - orig_list = os.listdir(folder) - return [os.path.join(folder, fn) for fn in orig_list - if fn.endswith(".yml") and not fn.startswith(avoid_basename_start)] - - -class FullSeparationTests(unittest.TestCase): - - def save_copy_with_one_fewer_category(self): - self.one_fewer_categories_file = os.path.join( - self.other_temp_folder.name, 'one_fewer_categories.yml') - with open(TEST_CATEGORIES_FILE) as f: - d = yaml.safe_load(f) - categories = d['categories'] - to_remove = list(categories.keys())[0] - del categories[to_remove] - yaml_str = yaml.safe_dump(d) - with open(self.one_fewer_categories_file, 'w') as f: - f.write(yaml_str) - - def setUp(self) -> None: - # new temporary folders for new tests, just in case - self.target_prefix_file = tempfile.TemporaryDirectory() - self.other_temp_folder = tempfile.TemporaryDirectory() - self.rc = get_real_checker() - self.regr_yaml_file = os.path.join( - self.target_prefix_file.name, "converted_regr.yml") - yaml_str = self.rc.to_yaml() - with open(self.regr_yaml_file, 'w') as f: - f.write(yaml_str) - self.save_copy_with_one_fewer_category() - - def tearDown(self) -> None: - self.target_prefix_file.cleanup() - self.other_temp_folder.cleanup() - - def join_back_up(self) -> RegressionChecker: - files = get_applicable_files_in(self.target_prefix_file.name) - f0 = files[0] - f_new = os.path.join(self.target_prefix_file.name, "join-back-1.yml") - for f1 in files[1:]: - combine_yamls(f0, f1, f_new) - f0 = f_new - return RegressionChecker.from_yaml(f_new) - - def test_separations_work_alone(self): - prefix = os.path.join(self.target_prefix_file.name, 'split-') - separate_categories(TEST_CATEGORIES_FILE, - StrategyType.FIRST, self.regr_yaml_file, prefix) - files = get_applicable_files_in(self.target_prefix_file.name) - for f in files: - with self.subTest(f'With {f}'): - rc = 
RegressionChecker.from_yaml(f) - self.assertIsNotNone(rc) - - def test_separations_combined_same(self): - prefix = os.path.join(self.target_prefix_file.name, 'split-') - separate_categories(TEST_CATEGORIES_FILE, - StrategyType.FIRST, self.regr_yaml_file, prefix) - rc = self.join_back_up() - self.assertEqual(self.rc, rc) - - def test_something_lost_if_not_fit(self): - prefix = os.path.join(self.target_prefix_file.name, 'split-') - separate_categories(self.one_fewer_categories_file, - StrategyType.FIRST, self.regr_yaml_file, prefix, overflow_category=False) - files = get_applicable_files_in(self.target_prefix_file.name) - self.assertFalse(any('overflow-' in f for f in files)) - rc = self.join_back_up() - self.assertLess(len(rc.cases), len(self.rc.cases)) - self.assertNotEqual(rc, self.rc) - - def test_something_written_in_overflow(self): - prefix = os.path.join(self.target_prefix_file.name, 'split-') - separate_categories(self.one_fewer_categories_file, - StrategyType.FIRST, self.regr_yaml_file, prefix, overflow_category=True) - files = get_applicable_files_in(self.target_prefix_file.name) - self.assertTrue(any('overflow-' in f for f in files)) - - def test_something_NOT_lost_if_use_overflow(self): - prefix = os.path.join(self.target_prefix_file.name, 'split-') - separate_categories(self.one_fewer_categories_file, - StrategyType.FIRST, self.regr_yaml_file, prefix, overflow_category=True) - rc = self.join_back_up() - self.assertEqual(self.rc, rc) diff --git a/tests/utils/regression/test_serialisation.py b/tests/utils/regression/test_serialisation.py deleted file mode 100644 index f20bb89fc..000000000 --- a/tests/utils/regression/test_serialisation.py +++ /dev/null @@ -1,109 +0,0 @@ - -import unittest -from medcat.utils.regression.results import ResultDescriptor - -from medcat.utils.regression.targeting import CUIWithChildFilter, FilterOptions, FilterStrategy, FilterType -from medcat.utils.regression.targeting import TypedFilter -from medcat.utils.regression.checking 
import RegressionChecker, RegressionCase, MetaData - - -class TestSerialisation(unittest.TestCase): - - def test_TypedFilter_serialises(self, ft=FilterType.NAME, vals=['FNAME1', 'FNAME2']): - tf = TypedFilter(type=ft, values=vals) - self.assertIsInstance(tf.to_dict(), dict) - - def test_TypedFilter_deserialises(self, ft=FilterType.NAME, vals=['FNAME-1', 'FNAME=2']): - tf = TypedFilter(type=ft, values=vals) - tf2 = TypedFilter.from_dict(tf.to_dict())[0] - self.assertIsInstance(tf2, TypedFilter) - - def test_TypedFilter_deserialises_to_one(self, ft=FilterType.NAME, vals=['FNAME-1', 'FNAME=2']): - tf = TypedFilter(type=ft, values=vals) - l = TypedFilter.from_dict(tf.to_dict()) - self.assertEqual(len(l), 1) - - def test_TypedFilter_deserialises_to_same(self, ft=FilterType.NAME, vals=['FNAME-1', 'FNAME=2']): - tf = TypedFilter(type=ft, values=vals) - tf2 = TypedFilter.from_dict(tf.to_dict())[0] - self.assertEqual(tf, tf2) - - def test_CUIWithChildFilter_deserialises_to_same(self, cui='the-cui', depth=5): - delegate = TypedFilter(type=FilterType.CUI_AND_CHILDREN, values=[cui]) - tf = CUIWithChildFilter( - type=FilterType.CUI_AND_CHILDREN, depth=depth, delegate=delegate) - tf2 = TypedFilter.from_dict(tf.to_dict())[0] - self.assertIsInstance(tf2, CUIWithChildFilter) - self.assertEqual(tf, tf2) - - def test_multiple_TypedFilter_serialise(self, ft1=FilterType.NAME, ft2=FilterType.CUI, vals1=['NAMEFILTER1'], vals2=['CUI1']): - tf1 = TypedFilter(type=ft1, values=vals1) - tf2 = TypedFilter(type=ft2, values=vals2) - dicts = TypedFilter.list_to_dicts([tf1, tf2]) - self.assertIsInstance(dicts, list) - self.assertEqual(len(dicts), 2) - for d in dicts: - with self.subTest(f'Assuming dict: {d}'): - self.assertIsInstance(d, dict) - - def test_multiple_TypedFilter_serialise_into(self, ft1=FilterType.NAME, ft2=FilterType.CUI, vals1=['NAMEFILTER1'], vals2=['CUI1']): - tf1 = TypedFilter(type=ft1, values=vals1) - tf2 = TypedFilter(type=ft2, values=vals2) - dicts = 
TypedFilter.list_to_dicts([tf1, tf2]) - self.assertIsInstance(dicts, list) - - def test_multiple_TypedFilter_deserialise(self, ft1=FilterType.NAME, ft2=FilterType.CUI, vals1=['NAMEFILTER1'], vals2=['CUI1']): - tf1 = TypedFilter(type=ft1, values=vals1) - tf2 = TypedFilter(type=ft2, values=vals2) - tf1_cp, tf2_cp = TypedFilter.from_dict( - TypedFilter.list_to_dict([tf1, tf2])) - self.assertIsInstance(tf1_cp, TypedFilter) - self.assertIsInstance(tf2_cp, TypedFilter) - - def test_multiple_TypedFilter_deserialise_to_same(self, ft1=FilterType.NAME, ft2=FilterType.CUI, vals1=['NAMEFILTER1'], vals2=['CUI1']): - tf1 = TypedFilter(type=ft1, values=vals1) - tf2 = TypedFilter(type=ft2, values=vals2) - the_dict = TypedFilter.list_to_dict([tf1, tf2]) - self.assertIsInstance(the_dict, dict) - tf1_cp, tf2_cp = TypedFilter.from_dict(the_dict) - self.assertEqual(tf1, tf1_cp) - self.assertEqual(tf2, tf2_cp) - - def test_RegressionCase_serialises(self, name='the-name', options=FilterOptions(strategy=FilterStrategy.ALL), - filters=[TypedFilter( - type=FilterType.NAME, values=['nom1', 'nom2'])], - phrases=['the %s phrase']): - rc = RegressionCase(name=name, options=options, - filters=filters, phrases=phrases, report=ResultDescriptor(name=name)) - self.assertIsInstance(rc.to_dict(), dict) - - def test_RegressionCase_deserialises_to_same(self, name='the-name', options=FilterOptions(strategy=FilterStrategy.ANY), - filters=[TypedFilter( - type=FilterType.NAME, values=['nom1', 'nom2'])], - phrases=['the %s phrase']): - rc = RegressionCase(name=name, options=options, - filters=filters, phrases=phrases, report=ResultDescriptor(name=name)) - rc2 = RegressionCase.from_dict(name, rc.to_dict()) - self.assertIsInstance(rc2, RegressionCase) - self.assertEqual(rc, rc2) - - def test_RegressionChecker_serialises(self, name='the-name', options=FilterOptions(strategy=FilterStrategy.ALL), - filters=[TypedFilter( - type=FilterType.NAME, values=['nom1', 'nom2'])], - phrases=['the %s phrase']): - rc = 
RegressionCase(name=name, options=options, - filters=filters, phrases=phrases, report=ResultDescriptor(name=name)) - checker = RegressionChecker(cases=[rc], metadata=MetaData.unknown()) - self.assertIsInstance(checker.to_dict(), dict) - - def test_RegressionChecker_deserialises_to_same(self, name='the-name', options=FilterOptions(strategy=FilterStrategy.ANY), - filters=[TypedFilter( - type=FilterType.NAME, values=['nom1', 'nom2'])], - phrases=['the %s phrase']): - rc = RegressionCase(name=name, options=options, - filters=filters, phrases=phrases, report=ResultDescriptor(name=name)) - checker = RegressionChecker(cases=[rc], metadata=MetaData.unknown()) - checker2 = RegressionChecker.from_dict(checker.to_dict()) - self.assertIsInstance(checker2, RegressionChecker) - rc.__eq__ - self.assertEqual(checker, checker2) diff --git a/tests/utils/regression/test_targeting.py b/tests/utils/regression/test_targeting.py new file mode 100644 index 000000000..8a6885f37 --- /dev/null +++ b/tests/utils/regression/test_targeting.py @@ -0,0 +1,215 @@ +from typing import Optional, List +from unittest import TestCase + +from medcat.config import Config +from medcat.utils.regression import targeting + +from collections import defaultdict +from copy import deepcopy + + +class FakeCDB: + + def __init__(self, def_name: str, def_cui: str, pt2ch: Optional[dict] = None) -> None: + self.cui2names = defaultdict(lambda: {def_name}) + self.name2cuis = defaultdict(lambda: {def_cui}) + self.cui2type_ids = {} # NOTE: shouldn't be needed + if pt2ch is None: + pt2ch = {} + self.addl_info = {'pt2ch': pt2ch} + self.config = Config() + + def copy(self) -> 'FakeCDB': + cui2names = deepcopy(self.cui2names) + name2cuis = deepcopy(self.name2cuis) + addl_info = deepcopy(self.addl_info) + copy = FakeCDB(cui2names[None], name2cuis[None]) + copy.cui2names = cui2names + copy.name2cuis = name2cuis + copy.addl_info = addl_info + return copy + + @property + def cui2preferred_name(self) -> dict: + return {cui: 
list(names)[0] for cui, names in self.cui2names.items()} + + +class OptionSetTests(TestCase): + OPTIONSET_SIMPLE = { + 'placeholders': [ + { + 'placeholder': '%s', + 'cuis': ['CUI1'] + } + ] + } + OPTIONSET_MULTI = { + 'placeholders': [ + { + 'placeholder': '%s', + 'cuis': ['CUI1'] + }, + { + 'placeholder': '[PH1]', + 'cuis': ['CUI2'] + }, + ] + } + ALL_WORKING = [OPTIONSET_SIMPLE, OPTIONSET_MULTI] + OPTIONSET_MULTI_SAMES = { + 'placeholders': OPTIONSET_SIMPLE['placeholders'] * 2 + } + OPTIONSET_0_PH = {'placeholders': []} + OPTIONSET_NO_PH = {'SomeJunk': [{'KEYS': 'VALUES'}]} + EXPECTED_TARGETS = [ + (OPTIONSET_SIMPLE, 1), + (OPTIONSET_MULTI, 2) + ] + ALL_ALL = ALL_WORKING + [OPTIONSET_MULTI_SAMES, OPTIONSET_0_PH, OPTIONSET_NO_PH] + cdb = FakeCDB('NAME', 'CUI1') + + @classmethod + def discover_cuis_for(cls, d: dict) -> list: + all_cuis = [] + phs = d.get('placeholders', []) + for ph in phs: + all_cuis.extend(ph.get('cuis', [])) + return all_cuis + + + @classmethod + def discover_all_used_cuis(cls) -> list: + all_cuis = [] + for d in cls.ALL_ALL: + all_cuis.extend(cls.discover_cuis_for(d)) + return all_cuis + + @classmethod + def setUpClass(cls) -> None: + # add name per CUI + for cui in cls.discover_all_used_cuis(): + cls.cdb.cui2names[cui] = {f'cui-{cui}-name'} + cls.tl = targeting.TranslationLayer.from_CDB(cls.cdb) + + def test_create_from_dict_simple(self): + os = targeting.OptionSet.from_dict(self.OPTIONSET_SIMPLE) + self.assertIsInstance(os, targeting.OptionSet) + + def test_create_from_dict_multi(self): + os = targeting.OptionSet.from_dict(self.OPTIONSET_MULTI) + self.assertIsInstance(os, targeting.OptionSet) + + def test_creation_fails_with_same_placeholders(self): + with self.assertRaises(targeting.ProblematicOptionSetException): + targeting.OptionSet.from_dict(self.OPTIONSET_MULTI_SAMES) + + def test_creation_fails_no_placeholders(self): + with self.assertRaises(targeting.ProblematicOptionSetException): + 
targeting.OptionSet.from_dict(self.OPTIONSET_NO_PH) + + def test_creation_fails_0_placeholders(self): + with self.assertRaises(targeting.ProblematicOptionSetException): + targeting.OptionSet.from_dict(self.OPTIONSET_0_PH) + + def test_get_placeholders(self): + for nr, target in enumerate(self.ALL_WORKING): + with self.subTest(f'Target nr {nr}'): + os = targeting.OptionSet.from_dict(target) + self.assertEqual(len(os.options), len(target['placeholders'])) + + def test_uses_default_allow_any(self): + _def_value = targeting.OptionSet(options=[]).allow_any_combinations + for nr, target in enumerate(self.ALL_WORKING): + with self.subTest(f'Target nr {nr}'): + os = targeting.OptionSet.from_dict(target) + self.assertEqual(os.allow_any_combinations, _def_value) + + def test_gets_correct(self): + for nr, (d, num_of_targets) in enumerate(self.EXPECTED_TARGETS): + with self.subTest(f"Part: {nr}"): + os = targeting.OptionSet.from_dict(d) + targets = list(os.get_preprocessors_and_targets(self.tl)) + self.assertEqual(len(targets), num_of_targets) + + +class OnePerNameOptionSetTests(TestCase): + SIMPLE = OptionSetTests.OPTIONSET_SIMPLE + MULTI_PLACEHOLDER_MULTI_CUI_ONLY_ORDERED = { + 'placeholders': [ + { + 'placeholder': '[PH1]', + 'cuis': ['CUI_11', 'CUI_12'] + }, + { + 'placeholder': '[PH2]', + 'cuis': ['CUI_21', 'CUI_22'] + }, + { + 'placeholder': '[PH3]', + 'cuis': ['CUI_31', 'CUI_32'] + } + ], + 'any-combination': False + } + MULTI_PLACEHOLDER_MULTI_CUI_ANY_COMB = {**MULTI_PLACEHOLDER_MULTI_CUI_ONLY_ORDERED, + 'any-combination': True} + MULTI_PLACEHOLDER_MULTI_CUI_ONLY_ORDERED_BROKEN = deepcopy(MULTI_PLACEHOLDER_MULTI_CUI_ONLY_ORDERED) + + @classmethod + def setUpClass(cls) -> None: + # remove a CUI so it's broken + cls.MULTI_PLACEHOLDER_MULTI_CUI_ONLY_ORDERED_BROKEN['placeholders'][0]['cuis'] = ['CUI11'] + cuis = OptionSetTests.discover_cuis_for(cls.SIMPLE) + cdb = FakeCDB('NAME', 'CUI1') + total_names_simple = 0 + for cui in cuis: + 
cdb.cui2names[cui].add(f"CUi-name-2-={cui}") + total_names_simple += len(cdb.cui2names[cui]) + cls.cdb = cdb + cls.tl = targeting.TranslationLayer.from_CDB(cls.cdb) + cls.total_names_simple = total_names_simple + for cui in OptionSetTests.discover_cuis_for(cls.MULTI_PLACEHOLDER_MULTI_CUI_ONLY_ORDERED): + cdb.cui2names[cui] = {f'CUI=name-4-{cui}'} + + def test_uneven_multi_fails(self): + with self.assertRaises(targeting.ProblematicOptionSetException): + targeting.OptionSet.from_dict(self.MULTI_PLACEHOLDER_MULTI_CUI_ONLY_ORDERED_BROKEN) + + def test_even_builds(self): + os = targeting.OptionSet.from_dict(self.MULTI_PLACEHOLDER_MULTI_CUI_ONLY_ORDERED) + self.assertIsInstance(os, targeting.OptionSet) + self.assertFalse(os.allow_any_combinations) + + def test_any_order_builds(self): + os = targeting.OptionSet.from_dict(self.MULTI_PLACEHOLDER_MULTI_CUI_ANY_COMB) + self.assertIsInstance(os, targeting.OptionSet) + self.assertTrue(os.allow_any_combinations) + + def test_even_has_a_few_targets(self): + os = targeting.OptionSet.from_dict(self.MULTI_PLACEHOLDER_MULTI_CUI_ONLY_ORDERED) + targets = list(os.get_preprocessors_and_targets(self.tl)) + # 2 for each of the 3 PRIMARY options + self.assertEqual(len(targets), 2 * 3) + + def assert_all_unique(self, targets: List[tuple]): + for nr1, ctarget in enumerate(targets): + for nr2, other in enumerate(targets[nr1 + 1:]): + with self.subTest(f"{nr1}x{nr2}"): + self.assertNotEqual(ctarget, other) + self.assertTrue(any(cpart != opart for cpart, opart in zip(ctarget, other))) + + def test_even_has_unique_targets(self): + os = targeting.OptionSet.from_dict(self.MULTI_PLACEHOLDER_MULTI_CUI_ONLY_ORDERED) + targets = list(os.get_preprocessors_and_targets(self.tl)) + self.assert_all_unique(targets) + + def test_any_order_has_many_targets(self): + os = targeting.OptionSet.from_dict(self.MULTI_PLACEHOLDER_MULTI_CUI_ANY_COMB) + targets = list(os.get_preprocessors_and_targets(self.tl)) + # for each of the 3 PRIMARY options, the combinations of 
all + self.assertEqual(len(targets), 3 * 2 ** 3) + + def test_any_order_has_unique_targets(self): + os = targeting.OptionSet.from_dict(self.MULTI_PLACEHOLDER_MULTI_CUI_ANY_COMB) + targets = list(os.get_preprocessors_and_targets(self.tl)) + self.assert_all_unique(targets) diff --git a/tests/utils/regression/test_utils.py b/tests/utils/regression/test_utils.py new file mode 100644 index 000000000..fa50af074 --- /dev/null +++ b/tests/utils/regression/test_utils.py @@ -0,0 +1,183 @@ +from functools import partial +import os +from json import load as load_json +from enum import Enum, auto + +from unittest import TestCase + +from medcat.utils.regression import utils +from medcat.utils.regression.checking import RegressionSuite + + +class PartialSubstituationTests(TestCase): + TEXT1 = "This [PH1] has one placeholder" + PH1 = "PH1" + REPLACEMENT1 = "" + + def test_fails_with_1_ph(self): + with self.assertRaises(utils.IncompatiblePhraseException): + utils.partial_substitute(self.TEXT1, self.PH1, self.REPLACEMENT1, 0) + + TEXT2 = "This [PH1] has [PH1] multiple (2) placeholders" + + def assert_is_correct_for_regr(self, text: str, placeholder: str): + # should leave a placeholder in + self.assertIn(placeholder, text) + # and only 1 + self.assertEqual(text.count(placeholder), 1) + + def assert_has_replaced_and_is_suitable(self, text: str, placeholder: str, replacement: str, + repl_count: int): + self.assert_is_correct_for_regr(text, placeholder) + self.assertIn(replacement, text) + self.assertEqual(text.count(replacement), repl_count) + + def test_works_with_2_ph_0th(self): + text = utils.partial_substitute(self.TEXT2, self.PH1, self.REPLACEMENT1, 0) + self.assert_has_replaced_and_is_suitable(text, self.PH1, self.REPLACEMENT1, 1) + + def test_works_with_2_ph_1st(self): + text = utils.partial_substitute(self.TEXT2, self.PH1, self.REPLACEMENT1, 1) + self.assert_has_replaced_and_is_suitable(text, self.PH1, self.REPLACEMENT1, 1) + + def test_fails_if_too_high_a_change_nr(self): + 
with self.assertRaises(utils.IncompatiblePhraseException): + utils.partial_substitute(self.TEXT1, self.PH1, self.REPLACEMENT1, 2) + + TEXT3 = "No [PH1] is [PH1] safe [PH1] eh" + + def test_work_with_3_ph(self): + for nr in range(self.TEXT3.count(self.PH1)): + with self.subTest(f"Placeholder #{nr}"): + text = utils.partial_substitute(self.TEXT3, self.PH1, self.REPLACEMENT1, nr) + self.assert_has_replaced_and_is_suitable(text, self.PH1, self.REPLACEMENT1, 2) + + def test_all_possibilities_are_similar(self): + texts = [utils.partial_substitute(self.TEXT3, self.PH1, self.REPLACEMENT1, nr) + for nr in range(self.TEXT3.count(self.PH1))] + # they should all have the same length + lengths = [len(t) for t in texts] + self.assertTrue(all(cl == lengths[0] for cl in lengths)) + # they should all have the same character composition + # i.e they should compose of the same exact characters + char_compos = [set(t) for t in texts] + self.assertTrue(all(cchars == char_compos[0] for cchars in char_compos)) + # and there should be the same amount for each as well + char_counts = [{c: t.count(c) for c in char_compos[0]} for t in texts] + self.assertTrue(all(cchars == char_counts[0] for cchars in char_counts)) + + +class StringLengthLimiterTests(TestCase): + short_str = "short str" + max_len = 25 + keep_front = max_len // 2 - 3 + keep_rear = max_len // 2 - 3 + long_str = " ".join([short_str] * 10) + limiter = partial(utils.limit_str_len, max_length=max_len, + keep_front=keep_front, keep_rear=keep_rear) + + @classmethod + def setUpClass(cls) -> None: + cls.got_short = cls.limiter(cls.short_str) + cls.got_long = cls.limiter(cls.long_str) + + def test_leaves_short(self): + self.assertEqual(self.short_str, self.got_short) + + def test_changes_long(self): + self.assertNotEqual(self.long_str, self.got_long) + + def test_long_gets_shorter(self): + self.assertGreater(self.long_str, self.got_long) + + def test_long_includes_chars(self, chars: str = 'chars'): + self.assertNotIn(chars, 
self.long_str) + self.assertIn(chars, self.got_long) + + def test_keeps_max_length(self): + s = self.got_long[:self.max_len] + self.assertEqual(s, self.limiter(s)) + + def test_does_not_keep_1_longer_than_max_lenght(self): + s = self.got_long[:self.max_len + 1] + self.assertNotEqual(s, self.limiter(s)) + + +class MCTExportConverterTests(TestCase): + MCT_EXPORT_PATH = os.path.join(os.path.dirname(__file__), '..', '..', + 'resources', 'medcat_trainer_export.json') + + @classmethod + def setUpClass(cls) -> None: + with open(cls.MCT_EXPORT_PATH) as f: + cls.mct_export = load_json(f) + cls.converter = utils.MedCATTrainerExportConverter(cls.mct_export) + cls.converted = cls.converter.convert() + cls.rc = RegressionSuite.from_dict(cls.converted, name="TEST SUITE 3") + + def test_converted_is_dict(self): + self.assertIsInstance(self.converted, dict) + + def test_converted_can_build(self): + self.assertIsInstance(self.rc, RegressionSuite) + + def test_converted_is_nonempty(self): + self.assertGreater(len(self.rc.cases), 0) + self.assertGreater(self.rc.estimate_total_distinct_cases(), 0) + + +class MyE1(Enum): + """Has class doc-string""" + A1 = auto() + """A1 doc string""" + A2 = auto() + """A2 doc string""" + + +class MyE2(Enum): # no class-level doc string + A1 = auto() + """A1 doc string""" + A2 = auto() + """A2 doc string""" + + +class MyE3(Enum): # this will not be changed + """The CLASS-specific doc string""" + A1 = auto() + """A1 doc string""" + A2 = auto() + """A2 doc string""" + + +class EnumDocStringCapturingClass(TestCase): + + @classmethod + def get_doc_string(cls, cnst: Enum) -> str: + # NOTE: this assumes the doc strings are built in this format + return cnst.name + " doc string" + + @classmethod + def setUpClass(cls) -> None: + utils.add_doc_strings_to_enum(MyE1) + utils.add_doc_strings_to_enum(MyE2) + + def assert_has_doc_strings(self, cls): + for ec in cls: + with self.subTest(str(ec)): + self.assertEqual(ec.__doc__, self.get_doc_string(ec)) + + def 
test_class_w_class_docstring_gets_doc_strings(self): + self.assert_has_doc_strings(MyE1) + + def test_class_wo_class_docstring_gets_doc_strings(self): + self.assert_has_doc_strings(MyE2) + + def test_unchanged_does_not_have_correct_doc_strings(self): + for ec in MyE3: + with self.subTest(str(ec)): + self.assertNotEqual(ec.__doc__, self.get_doc_string(ec)) + + def test_unchanged_has_class_doc_Strings(self): + for ec in MyE3: + with self.subTest(str(ec)): + self.assertEqual(ec.__doc__, MyE3.__doc__)