Skip to content

Commit

Permalink
Merge pull request #446 from CogStack/load-cdb-cls-method
Browse files Browse the repository at this point in the history
CU-8694pey4u: extract cdb load to cls method
  • Loading branch information
tomolopolis authored May 30, 2024
2 parents 20d2bce + c6f0658 commit db7259a
Showing 1 changed file with 45 additions and 11 deletions.
56 changes: 45 additions & 11 deletions medcat/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,12 +390,7 @@ def load_model_pack(cls,
model_pack_path = cls.attempt_unpack(zip_path)

# Load the CDB
cdb_path = os.path.join(model_pack_path, "cdb.dat")
nr_of_jsons_expected = len(SPECIALITY_NAMES) - len(ONE2MANY)
has_jsons = len(glob.glob(os.path.join(model_pack_path, '*.json'))) >= nr_of_jsons_expected
json_path = model_pack_path if has_jsons else None
logger.info('Loading model pack with %s', 'JSON format' if json_path else 'dill format')
cdb = CDB.load(cdb_path, json_path)
cdb: CDB = cls.load_cdb(model_pack_path)

# load config
config_path = os.path.join(model_pack_path, "config.json")
Expand All @@ -422,11 +417,9 @@ def load_model_pack(cls,
addl_ner.append(trf)

# Find metacat models in the model_pack
meta_paths = [os.path.join(model_pack_path, path) for path in os.listdir(model_pack_path) if path.startswith('meta_')] if load_meta_models else []
meta_cats = []
for meta_path in meta_paths:
meta_cats.append(MetaCAT.load(save_dir_path=meta_path,
config_dict=meta_cat_config_dict))
meta_cats: List[MetaCAT] = []
if load_meta_models:
meta_cats = [mc[1] for mc in cls.load_meta_cats(model_pack_path)]

# Find Rel models in model_pack
rel_paths = [os.path.join(model_pack_path, path) for path in os.listdir(model_pack_path) if path.startswith('rel_')] if load_rel_models else []
Expand All @@ -439,6 +432,47 @@ def load_model_pack(cls,

return cat

@classmethod
def load_cdb(cls, model_pack_path: str) -> CDB:
"""
Loads the concept database from the provided model pack path
Args:
model_pack_path (str): path to model pack, zip or dir.
Returns:
CDB: The loaded concept database
"""
cdb_path = os.path.join(model_pack_path, "cdb.dat")
nr_of_jsons_expected = len(SPECIALITY_NAMES) - len(ONE2MANY)
has_jsons = len(glob.glob(os.path.join(model_pack_path, '*.json'))) >= nr_of_jsons_expected
json_path = model_pack_path if has_jsons else None
logger.info('Loading model pack with %s', 'JSON format' if json_path else 'dill format')
cdb = CDB.load(cdb_path, json_path)
return cdb

@classmethod
def load_meta_cats(cls, model_pack_path: str, meta_cat_config_dict: Optional[Dict] = None) -> List[Tuple[str, MetaCAT]]:
"""
Args:
model_pack_path (str): path to model pack, zip or dir.
meta_cat_config_dict (Optional[Dict]):
A config dict that will overwrite existing configs in meta_cat.
e.g. meta_cat_config_dict = {'general': {'device': 'cpu'}}.
Defaults to None.
Returns:
List[Tuple(str, MetaCAT)]: list of pairs of meta cat model names (i.e. the task name) and the MetaCAT models.
"""
meta_paths = [os.path.join(model_pack_path, path)
for path in os.listdir(model_pack_path) if path.startswith('meta_')]
meta_cats = []
for meta_path in meta_paths:
meta_cats.append(MetaCAT.load(save_dir_path=meta_path,
config_dict=meta_cat_config_dict))
return list(zip(meta_paths, meta_cats))

def __call__(self, text: Optional[str], do_train: bool = False) -> Optional[Doc]:
"""Push the text through the pipeline.
Expand Down

0 comments on commit db7259a

Please sign in to comment.