Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cu 8693n892x dependency snapshots importlib #35

Closed
wants to merge 11 commits into from
Closed
8 changes: 1 addition & 7 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,9 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r requirements-dev.txt
- name: Check types
run: |
python -m mypy --follow-imports=normal medcat
- name: Lint
run: |
flake8 medcat
- name: Test
run: |
timeout 17m python -m unittest discover
python -m unittest tests.utils.saving.test_envsnapshot

publish-to-test-pypi:

Expand Down
7 changes: 7 additions & 0 deletions medcat/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from medcat.utils.decorators import deprecated
from medcat.ner.transformers_ner import TransformersNER
from medcat.utils.saving.serializer import SPECIALITY_NAMES, ONE2MANY
from medcat.utils.saving.envsnapshot import get_environment_info, ENV_SNAPSHOT_FILE_NAME
from medcat.stats.stats import get_stats
from medcat.utils.filters import set_project_filters

Expand Down Expand Up @@ -318,6 +319,12 @@ def create_model_pack(self, save_dir_path: str, model_pack_name: str = DEFAULT_M
with open(model_card_path, 'w') as f:
json.dump(self.get_model_card(as_dict=True), f, indent=2)

# add a dependency snapshot
env_info = get_environment_info()
env_info_path = os.path.join(save_dir_path, ENV_SNAPSHOT_FILE_NAME)
with open(env_info_path, 'w') as f:
json.dump(env_info, f)

# Zip everything
shutil.make_archive(os.path.join(_save_dir_path, model_pack_name), 'zip', root_dir=save_dir_path)

Expand Down
60 changes: 60 additions & 0 deletions medcat/utils/saving/envsnapshot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import List, Dict, Any, Set

import re
import pkg_resources
import platform
from importlib_metadata import distribution


# File name under which the environment snapshot is stored as JSON.
ENV_SNAPSHOT_FILE_NAME = "environment_snapshot.json"


def get_direct_dependencies() -> Set[str]:
    """Get the set of direct dependency names.

    The current implementation uses importlib_metadata to read the
    ``Requires-Dist`` entries of the installed top-level package and keeps
    only the distribution name at the start of each entry. Whatever follows
    the name (extras, version specifiers, environment markers) is dropped.

    Raises:
        ValueError: In the unlikely event that the dependencies are unable to be obtained.

    Returns:
        Set[str]: The set of direct dependency names.
    """
    # The first component of this module's package is used as the distribution name.
    package_name = (__package__ or "").split(".")[0]
    if not package_name:
        raise ValueError("Unable to identify dependencies")
    dist = distribution(package_name)
    deps = dist.metadata.get_all('Requires-Dist')
    if not deps:
        raise ValueError("Unable to identify dependencies")
    # A requirement string starts with the project name; splitting on
    # "<=>~" alone leaves residue for forms such as "pkg (>=1.0)",
    # "pkg!=1.0", "pkg[extra]>=1" or "pkg; python_version < '3.8'".
    name_pattern = re.compile(r"\s*([A-Za-z0-9][A-Za-z0-9._-]*)")
    names: Set[str] = set()
    for dep in deps:
        match = name_pattern.match(dep)
        if match:
            names.add(match.group(1))
        else:
            # Unexpected format: fall back to the plain version split.
            names.add(re.split("[<=>~]", dep)[0])
    return names


def _normalise_project_name(name: str) -> str:
    """Return the lower-cased name with runs of '-', '_' and '.' collapsed to '-'."""
    return re.sub(r"[-_.]+", "-", name).lower()


def get_installed_packages() -> List[List[str]]:
    """Get the installed direct dependencies and their versions.

    Only packages from the working set that are direct dependencies
    (as reported by `get_direct_dependencies`) are included. Names are
    compared in normalised form so that differences in case or in the
    separator used ('-', '_', '.') do not hide an installed dependency.

    Returns:
        List[List[str]]: List of lists. Each item consists of a dependency name and version.
    """
    direct_deps = {_normalise_project_name(dep) for dep in get_direct_dependencies()}
    return [
        [package.project_name, package.version]
        for package in pkg_resources.working_set
        if _normalise_project_name(package.project_name) in direct_deps
    ]


def get_environment_info() -> Dict[str, Any]:
    """Collect a snapshot of the current environment.

    The snapshot holds the installed direct dependencies with their versions,
    the OS description, the CPU architecture and the python version.

    Returns:
        Dict[str, Any]: The snapshot, keyed by "dependencies", "os",
            "cpu_architecture" and "python_version".
    """
    snapshot: Dict[str, Any] = {}
    snapshot["dependencies"] = get_installed_packages()
    snapshot["os"] = platform.platform()
    snapshot["cpu_architecture"] = platform.machine()
    snapshot["python_version"] = platform.python_version()
    return snapshot
133 changes: 133 additions & 0 deletions tests/utils/saving/test_envsnapshot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from typing import Any
import platform
import os
import tempfile
import json
import zipfile
import re

from medcat.cat import CAT
from medcat.utils.saving import envsnapshot

import unittest


def list_zip_contents(zip_file_path):
    """Return the names of all members stored in the zip archive at the given path."""
    with zipfile.ZipFile(zip_file_path, mode='r') as archive:
        member_names = archive.namelist()
    return member_names



ENV_SNAPSHOT_FILE_NAME = "environment_snapshot.json"
# setup.py is expected three directory levels above this test module.
SETUP_PY_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "setup.py"))
# Captures the text between the brackets of ``install_requires=[...]``.
# A raw string keeps "\[", "\s" and "\S" from being read as (invalid)
# string escapes, which would trigger a warning.
SETUP_PY_REGEX = re.compile(r"install_requires=\[([\s\S]*?)\]")


def get_direct_dependencies_regex() -> set:
    """Read the direct dependency names from ``install_requires`` in setup.py.

    The text of setup.py is searched with SETUP_PY_REGEX, the captured list
    body is parsed as a list literal and any version information is removed
    from each entry.

    Raises:
        FileNotFoundError: If setup.py does not exist at SETUP_PY_PATH.
        ValueError: If no ``install_requires`` list, or more than one, is found.

    Returns:
        set: The dependency names without version information.
    """
    # Local import: nothing else in this module needs ``ast``.
    import ast

    if not os.path.exists(SETUP_PY_PATH):
        raise FileNotFoundError(f"{SETUP_PY_PATH} does not exist.")
    with open(SETUP_PY_PATH) as f:
        setup_py_code = f.read()
    found = SETUP_PY_REGEX.findall(setup_py_code)
    if not found:
        raise ValueError("Did not find install requirements in setup.py")
    if len(found) > 1:
        raise ValueError("Ambiguous install requirements in setup.py")
    deps_str = found[0]
    # Parse the list of dependencies (including potential version pins).
    # literal_eval only accepts literals, so text read from the file can
    # never be executed as code the way it could with eval.
    deps: list = ast.literal_eval("[" + deps_str + "]")
    # remove versions where applicable
    return {re.split("[<=>~]", dep)[0] for dep in deps}


class DirectDependenciesTests(unittest.TestCase):
    """Checks on the names returned by ``envsnapshot.get_direct_dependencies``."""

    def setUp(self) -> None:
        # Obtained anew for every test method.
        self.direct_deps = envsnapshot.get_direct_dependencies()

    def test_nonempty(self):
        self.assertTrue(self.direct_deps)

    def test_does_not_contain_versions(self, version_starters: str = '<=>~'):
        # Every (dependency, version character) combination gets its own sub-test.
        combos = ((dep, vs) for dep in self.direct_deps for vs in version_starters)
        for dep, vs in combos:
            with self.subTest(f"DEP '{dep}' check for '{vs}'"):
                self.assertNotIn(vs, dep)

    def test_deps_are_installed_packages(self):
        # ``require`` raises when the named distribution is not available.
        require = envsnapshot.pkg_resources.require
        for dep in self.direct_deps:
            with self.subTest(f"Has '{dep}'"):
                require(dep)

    def test_deps_are_same_as_per_regex(self):
        # The names parsed from setup.py must match the ones from the metadata.
        self.assertEqual(get_direct_dependencies_regex(), self.direct_deps)


class EnvSnapshotAloneTests(unittest.TestCase):
    """Checks on the dict returned by ``envsnapshot.get_environment_info``."""

    def setUp(self) -> None:
        self.env_info = envsnapshot.get_environment_info()

    def test_info_is_dict(self):
        self.assertIsInstance(self.env_info, dict)

    def test_info_is_not_empty(self):
        self.assertTrue(self.env_info)

    def assert_has_target(self, target: str, expected: Any):
        """Assert that the snapshot has the key `target` with the value `expected`."""
        self.assertIn(target, self.env_info)
        self.assertEqual(self.env_info[target], expected)

    def test_has_os(self):
        self.assert_has_target("os", platform.platform())

    def test_has_py_ver(self):
        self.assert_has_target("python_version", platform.python_version())

    def test_has_cpu_arch(self):
        self.assert_has_target("cpu_architecture", platform.machine())

    def test_has_dependencies(self, name: str = "dependencies"):
        # NOTE: just making sure the entry exists and is not empty
        self.assertIn(name, self.env_info)
        self.assertTrue(self.env_info[name])

    def test_all_direct_dependencies_are_installed(self):
        # One snapshot entry is expected per direct dependency.
        self.assertEqual(
            len(self.env_info['dependencies']),
            len(envsnapshot.get_direct_dependencies()),
        )


# The "examples" folder sits three directory levels above this test module.
_THIS_DIR = os.path.dirname(__file__)
CAT_PATH = os.path.join(_THIS_DIR, "..", "..", "..", "examples")
# Use the file name exactly as the module under test defines it.
ENV_SNAPSHOT_FILE_NAME = envsnapshot.ENV_SNAPSHOT_FILE_NAME


class EnvSnapshotInCATTests(unittest.TestCase):
    """Checks that creating a model pack stores the environment snapshot.

    A CAT is loaded from CAT_PATH once per class and a model pack is created
    from it in a temporary directory. The tests then look at the snapshot
    file in the created folder and at the contents of the zip next to it.
    """
    # Computed in this process, so the saved snapshot is expected to match it.
    expected_env = envsnapshot.get_environment_info()

    @classmethod
    def setUpClass(cls) -> None:
        cls.cat = CAT.load_model_pack(CAT_PATH)
        cls._temp_dir = tempfile.TemporaryDirectory()
        try:
            mpn = cls.cat.create_model_pack(cls._temp_dir.name)
        except Exception:
            # tearDownClass is not run when setUpClass fails, so clean up here.
            cls._temp_dir.cleanup()
            raise
        cls.cat_folder = os.path.join(cls._temp_dir.name, mpn)
        cls.envrion_file_path = os.path.join(cls.cat_folder, ENV_SNAPSHOT_FILE_NAME)

    @classmethod
    def tearDownClass(cls) -> None:
        # Remove the temporary directory explicitly instead of leaving it
        # to be removed whenever the object happens to be finalised.
        cls._temp_dir.cleanup()

    def test_has_environment(self):
        self.assertTrue(os.path.exists(self.envrion_file_path))

    def test_eviron_saved(self):
        with open(self.envrion_file_path) as f:
            saved_info: dict = json.load(f)
        self.assertEqual(saved_info.keys(), self.expected_env.keys())
        # Compare key by key so that a mismatch is reported per entry.
        for k in saved_info:
            with self.subTest(k):
                v1, v2 = saved_info[k], self.expected_env[k]
                self.assertEqual(v1, v2)

    def test_zip_has_env_snapshot(self):
        filenames = list_zip_contents(self.cat_folder + ".zip")
        self.assertIn(ENV_SNAPSHOT_FILE_NAME, filenames)
4 changes: 3 additions & 1 deletion tests/utils/saving/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from medcat.vocab import Vocab

from medcat.utils.saving.serializer import JsonSetSerializer, CDBSerializer, SPECIALITY_NAMES, ONE2MANY
from medcat.utils.saving.envsnapshot import ENV_SNAPSHOT_FILE_NAME

import medcat.utils.saving.coding as _

Expand Down Expand Up @@ -60,6 +61,7 @@ class ModelCreationTests(unittest.TestCase):
json_model_pack = tempfile.TemporaryDirectory()
EXAMPLES = os.path.join(os.path.dirname(
os.path.realpath(__file__)), "..", "..", "..", "examples")
EXCEPTIONAL_JSONS = ['model_card.json', ENV_SNAPSHOT_FILE_NAME]

@classmethod
def setUpClass(cls) -> None:
Expand Down Expand Up @@ -95,7 +97,7 @@ def test_dill_to_json(self):
SPECIALITY_NAMES) - len(ONE2MANY))
for json in jsons:
with self.subTest(f'JSON {json}'):
if json.endswith('model_card.json'):
if any(json.endswith(exception) for exception in self.EXCEPTIONAL_JSONS):
continue # ignore model card here
if any(name in json for name in ONE2MANY):
# ignore cui2many and name2many
Expand Down
Loading