Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CU-8693n892x environment/dependency snapshots #438

Merged
merged 15 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions install_requires.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
'numpy>=1.22.0,<1.26.0' # 1.22.0 is first to support python 3.11; post 1.26.0 there's issues with scipy
'pandas>=1.4.2' # first to support 3.11
'gensim>=4.3.0,<5.0.0' # 5.3.0 is first to support 3.11; avoid major version bump
'spacy>=3.6.0,<4.0.0' # Some later model packs (e.g HPO) are made with 3.6.0 spacy model; avoid major version bump
'scipy~=1.9.2' # 1.9.2 is first to support 3.11
'transformers>=4.34.0,<5.0.0' # avoid major version bump
'accelerate>=0.23.0' # required by Trainer class in de-id
'torch>=1.13.0,<3.0.0' # 1.13 is first to support 3.11; 2.1.2 has been compatible, but avoid major 3.0.0 for now
'tqdm>=4.27'
'scikit-learn>=1.1.3,<2.0.0' # 1.1.3 is first to supporrt 3.11; avoid major version bump
'dill>=0.3.6,<1.0.0' # stuff saved in 0.3.6/0.3.7 is not always compatible with 0.3.4/0.3.5; avoid major bump
'datasets>=2.2.2,<3.0.0' # avoid major bump
'jsonpickle>=2.0.0' # allow later versions, tested with 3.0.0
'psutil>=5.8.0'
# 0.70.12 uses older version of dill (i.e less than 0.3.5) which is required for datasets
'multiprocess~=0.70.12' # 0.70.14 seemed to work just fine
'aiofiles>=0.8.0' # allow later versions, tested with 22.1.0
'ipywidgets>=7.6.5' # allow later versions, tested with 0.8.0
'xxhash>=3.0.0' # allow later versions, tested with 3.1.0
'blis>=0.7.5' # allow later versions, tested with 0.7.9
'click>=8.0.4' # allow later versions, tested with 8.1.3
'pydantic>=1.10.0,<2.0' # for spacy compatibility; avoid 2.0 due to breaking changes
"humanfriendly~=10.0" # for human readable file / RAM sizes
"peft>=0.8.2"
7 changes: 7 additions & 0 deletions medcat/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from medcat.utils.decorators import deprecated
from medcat.ner.transformers_ner import TransformersNER
from medcat.utils.saving.serializer import SPECIALITY_NAMES, ONE2MANY
from medcat.utils.saving.envsnapshot import get_environment_info, ENV_SNAPSHOT_FILE_NAME
from medcat.stats.stats import get_stats
from medcat.utils.filters import set_project_filters

Expand Down Expand Up @@ -318,6 +319,12 @@ def create_model_pack(self, save_dir_path: str, model_pack_name: str = DEFAULT_M
with open(model_card_path, 'w') as f:
json.dump(self.get_model_card(as_dict=True), f, indent=2)

# add a dependency snapshot
env_info = get_environment_info()
env_info_path = os.path.join(save_dir_path, ENV_SNAPSHOT_FILE_NAME)
with open(env_info_path, 'w') as f:
json.dump(env_info, f)

# Zip everything
shutil.make_archive(os.path.join(_save_dir_path, model_pack_name), 'zip', root_dir=save_dir_path)

Expand Down
60 changes: 60 additions & 0 deletions medcat/utils/saving/envsnapshot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import List, Dict, Any, Set

import os
import re
import pkg_resources
import platform


ENV_SNAPSHOT_FILE_NAME = "environment_snapshot.json"
SETUP_PY_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "setup.py"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both of these module level vars are not needed anymore

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed

SETUP_PY_REGEX = re.compile("install_requires=\[([\s\S]*?)\]")


def get_direct_dependencies() -> Set[str]:
"""Get the set of direct dependeny names.

The current implementation reads install_requires.txt for dependenceies,
removes comments, whitespace, quotes; removes the versions and returns
the names as a set.

Returns:
Set[str]: The set of direct dependeny names.
"""
with open("install_requires.txt") as f:
# read every line, strip quotes and comments
dep_lines = [line.split("#")[0].replace("'", "").replace('"', "").strip() for line in f.readlines()]
# remove comment-only (or empty) lines
deps = [dep for dep in dep_lines if dep]
return set(re.split("[<=>~]", dep)[0] for dep in deps)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

install directly from a git commit does this work?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added @ separator.



def get_installed_packages() -> List[List[str]]:
"""Get the installed packages and their versions.

Returns:
List[List[str]]: List of lists. Each item contains of a dependency name and version.
"""
direct_deps = get_direct_dependencies()
installed_packages = []
for package in pkg_resources.working_set:
if package.project_name not in direct_deps:
continue
Comment on lines +54 to +55
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO why not save transitive dependencies in the snapshot as well, together with direct dependencies? It would be easier for users like me to create a virtual environment with high fidelity to the original training environment before starting to use the model.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a few issues with listing all installed packages:

  • The list of dependencies may change depending on hardware (i.e nvidia or cuda related packages)
  • The version of dependencies may change depending for different python versions
  • Some environments may have unrelated packages

With that said, that was my original idea as well. If we're able to identify the exact environment (down to each individual package) then we should be able to exactly replicate behaviour across environments.

But the main reason to go with the current approach is the fact that downstream / transitive dependencies aren't (or at least shouldn't be) our responsibility. Our direct dependencies should know what they can and cannot work with.
There are, of course, situations where that expectation fails. Especially since we depend on quite a few older versions of other projects.

So all in all, I think it would be worth it to add something to this in the future (as part of another PR). Whether it's all installed packages or some subset of it would need to be figured out as part of that.
But I think this PR will still allow as to get at least 90% of the way there in terms of being able to validate our environment.

installed_packages.append([package.project_name, package.version])
return installed_packages


def get_environment_info() -> Dict[str, Any]:
"""Get the current environment information.

This includes dependency versions, the OS, the CPU architecture and the python version.

Returns:
Dict[str, Any]: _description_
"""
return {
"dependencies": get_installed_packages(),
"os": platform.platform(),
"cpu_architecture": platform.machine(),
"python_version": platform.python_version()
}
34 changes: 8 additions & 26 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@
long_description = fh.read()


with open("install_requires.txt") as f:
# read every line, strip quotes and comments
dep_lines = [l.split("#")[0].replace("'", "").replace('"', "").strip() for l in f.readlines()]
# remove comment-only (or empty) lines
install_requires = [dep for dep in dep_lines if dep]


setuptools.setup(
name="medcat",
setup_requires=["setuptools_scm"],
Expand All @@ -17,32 +24,7 @@
packages=['medcat', 'medcat.utils', 'medcat.preprocessing', 'medcat.ner', 'medcat.linking', 'medcat.datasets',
'medcat.tokenizers', 'medcat.utils.meta_cat', 'medcat.pipeline', 'medcat.utils.ner', 'medcat.utils.relation_extraction',
'medcat.utils.saving', 'medcat.utils.regression', 'medcat.stats'],
install_requires=[
'numpy>=1.22.0,<1.26.0', # 1.22.0 is first to support python 3.11; post 1.26.0 there's issues with scipy
'pandas>=1.4.2', # first to support 3.11
'gensim>=4.3.0,<5.0.0', # 5.3.0 is first to support 3.11; avoid major version bump
'spacy>=3.6.0,<4.0.0', # Some later model packs (e.g HPO) are made with 3.6.0 spacy model; avoid major version bump
'scipy~=1.9.2', # 1.9.2 is first to support 3.11
'transformers>=4.34.0,<5.0.0', # avoid major version bump
'accelerate>=0.23.0', # required by Trainer class in de-id
'torch>=1.13.0,<3.0.0', # 1.13 is first to support 3.11; 2.1.2 has been compatible, but avoid major 3.0.0 for now
'tqdm>=4.27',
'scikit-learn>=1.1.3,<2.0.0', # 1.1.3 is first to supporrt 3.11; avoid major version bump
'dill>=0.3.6,<1.0.0', # stuff saved in 0.3.6/0.3.7 is not always compatible with 0.3.4/0.3.5; avoid major bump
'datasets>=2.2.2,<3.0.0', # avoid major bump
'jsonpickle>=2.0.0', # allow later versions, tested with 3.0.0
'psutil>=5.8.0',
# 0.70.12 uses older version of dill (i.e less than 0.3.5) which is required for datasets
'multiprocess~=0.70.12', # 0.70.14 seemed to work just fine
'aiofiles>=0.8.0', # allow later versions, tested with 22.1.0
'ipywidgets>=7.6.5', # allow later versions, tested with 0.8.0
'xxhash>=3.0.0', # allow later versions, tested with 3.1.0
'blis>=0.7.5', # allow later versions, tested with 0.7.9
'click>=8.0.4', # allow later versions, tested with 8.1.3
'pydantic>=1.10.0,<2.0', # for spacy compatibility; avoid 2.0 due to breaking changes
"humanfriendly~=10.0", # for human readable file / RAM sizes
"peft>=0.8.2", # allow later versions, tested with 0.10.0
],
install_requires=install_requires,
classifiers=[
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
Expand Down
105 changes: 105 additions & 0 deletions tests/utils/saving/test_envsnapshot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from typing import Any
import platform
import os
import tempfile
import json
import zipfile

from medcat.cat import CAT
from medcat.utils.saving import envsnapshot

import unittest


def list_zip_contents(zip_file_path):
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
return zip_ref.namelist()


class DirectDependenciesTests(unittest.TestCase):

def setUp(self) -> None:
self.direct_deps = envsnapshot.get_direct_dependencies()

def test_nonempty(self):
self.assertTrue(self.direct_deps)

def test_does_not_contain_versions(self, version_starters: str = '<=>~'):
for dep in self.direct_deps:
for vs in version_starters:
with self.subTest(f"DEP '{dep}' check for '{vs}'"):
self.assertNotIn(vs, dep)

def test_deps_are_installed_packages(self):
for dep in self.direct_deps:
with self.subTest(f"Has '{dep}'"):
envsnapshot.pkg_resources.require(dep)


class EnvSnapshotAloneTests(unittest.TestCase):

def setUp(self) -> None:
self.env_info = envsnapshot.get_environment_info()

def test_info_is_dict(self):
self.assertIsInstance(self.env_info, dict)

def test_info_is_not_empty(self):
self.assertTrue(self.env_info)

def assert_has_target(self, target: str, expected: Any):
self.assertIn(target, self.env_info)
py_ver = self.env_info[target]
self.assertEqual(py_ver, expected)

def test_has_os(self):
self.assert_has_target("os", platform.platform())

def test_has_py_ver(self):
self.assert_has_target("python_version", platform.python_version())

def test_has_cpu_arch(self):
self.assert_has_target("cpu_architecture", platform.machine())

def test_has_dependencies(self, name: str = "dependencies"):
# NOTE: just making sure it's a anon-empty list
self.assertIn(name, self.env_info)
deps = self.env_info[name]
self.assertTrue(deps)

def test_all_direct_dependencies_are_installed(self):
deps = self.env_info['dependencies']
direct_deps = envsnapshot.get_direct_dependencies()
self.assertEqual(len(deps), len(direct_deps))


CAT_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "..", "examples")
ENV_SNAPSHOT_FILE_NAME = envsnapshot.ENV_SNAPSHOT_FILE_NAME


class EnvSnapshotInCATTests(unittest.TestCase):
expected_env = envsnapshot.get_environment_info()

@classmethod
def setUpClass(cls) -> None:
cls.cat = CAT.load_model_pack(CAT_PATH)
cls._temp_dir = tempfile.TemporaryDirectory()
mpn = cls.cat.create_model_pack(cls._temp_dir.name)
cls.cat_folder = os.path.join(cls._temp_dir.name, mpn)
cls.envrion_file_path = os.path.join(cls.cat_folder, ENV_SNAPSHOT_FILE_NAME)

def test_has_environment(self):
self.assertTrue(os.path.exists(self.envrion_file_path))

def test_eviron_saved(self):
with open(self.envrion_file_path) as f:
saved_info: dict = json.load(f)
self.assertEqual(saved_info.keys(), self.expected_env.keys())
for k in saved_info:
with self.subTest(k):
v1, v2 = saved_info[k], self.expected_env[k]
self.assertEqual(v1, v2)

def test_zip_has_env_snapshot(self):
filenames = list_zip_contents(self.cat_folder + ".zip")
self.assertIn(ENV_SNAPSHOT_FILE_NAME, filenames)
4 changes: 3 additions & 1 deletion tests/utils/saving/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from medcat.vocab import Vocab

from medcat.utils.saving.serializer import JsonSetSerializer, CDBSerializer, SPECIALITY_NAMES, ONE2MANY
from medcat.utils.saving.envsnapshot import ENV_SNAPSHOT_FILE_NAME

import medcat.utils.saving.coding as _

Expand Down Expand Up @@ -60,6 +61,7 @@ class ModelCreationTests(unittest.TestCase):
json_model_pack = tempfile.TemporaryDirectory()
EXAMPLES = os.path.join(os.path.dirname(
os.path.realpath(__file__)), "..", "..", "..", "examples")
EXCEPTIONAL_JSONS = ['model_card.json', ENV_SNAPSHOT_FILE_NAME]

@classmethod
def setUpClass(cls) -> None:
Expand Down Expand Up @@ -95,7 +97,7 @@ def test_dill_to_json(self):
SPECIALITY_NAMES) - len(ONE2MANY))
for json in jsons:
with self.subTest(f'JSON {json}'):
if json.endswith('model_card.json'):
if any(json.endswith(exception) for exception in self.EXCEPTIONAL_JSONS):
continue # ignore model card here
if any(name in json for name in ONE2MANY):
# ignore cui2many and name2many
Expand Down
Loading