Refactor XLA and TPU checks across codebase #14550

Merged: 44 commits, Oct 4, 2022 (the diff below shows changes from 34 commits)

Commits
b212a48  Refactor XLADeviceUtils (carmocca, Sep 18, 2022)
2995937  Merge branch 'master' into refactor/xla-utils (carmocca, Sep 19, 2022)
f71705e  Merge branch 'master' into refactor/xla-utils (carmocca, Sep 27, 2022)
6a40317  Fix? (carmocca, Sep 27, 2022)
74c536a  Circular dependencies (carmocca, Sep 28, 2022)
f17360f  Fix Lite tests (carmocca, Sep 28, 2022)
376605a  PL test fixes (carmocca, Sep 28, 2022)
3cf69e2  Merge branch 'master' into refactor/xla-utils (carmocca, Sep 28, 2022)
dcd2f3c  Include lite (carmocca, Sep 28, 2022)
2ed8471  More cleanup (carmocca, Sep 28, 2022)
afa7c52  Updating installation (carmocca, Sep 28, 2022)
b721390  Try to avoid deprecation warning (carmocca, Sep 28, 2022)
4d1ee98  Undo change (carmocca, Sep 28, 2022)
18a1a3a  Undo change (carmocca, Sep 28, 2022)
a101345  Im dumb (carmocca, Sep 28, 2022)
ffe0de4  Fix App tests (#14922) (otaj, Sep 28, 2022)
ff64129  Merge branch 'master' into refactor/xla-utils (carmocca, Sep 28, 2022)
5b3018f  Reorder (carmocca, Sep 28, 2022)
ed8819e  Fix expected exception (carmocca, Sep 28, 2022)
778b9c1  Fix 2 (carmocca, Sep 28, 2022)
3d0ff27  Clean env vars (carmocca, Sep 28, 2022)
ca1116c  Merge branch 'master' into refactor/xla-utils (carmocca, Oct 3, 2022)
0feac92  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 3, 2022)
cc0cb12  add import (carmocca, Oct 3, 2022)
5a9fcde  Fix (carmocca, Oct 3, 2022)
9dab948  Fixes (carmocca, Oct 3, 2022)
ed6909d  Missed one (carmocca, Oct 3, 2022)
6f08ded  silly pycharm (carmocca, Oct 3, 2022)
b6743ad  Mark fork tests as standalone (carmocca, Oct 3, 2022)
2c9739f  Fixes (carmocca, Oct 4, 2022)
c0f2e0b  Fix TPU test (carmocca, Oct 4, 2022)
e90eeaa  Patch device in conftest (carmocca, Oct 4, 2022)
51f1704  TPU fixes (carmocca, Oct 4, 2022)
6ca517f  Update monkeypatch (carmocca, Oct 4, 2022)
eb87e0d  TPU fixes (carmocca, Oct 4, 2022)
5be521c  Fix TPU tests (carmocca, Oct 4, 2022)
d52c8e1  Fixed TPU (carmocca, Oct 4, 2022)
6ef17b4  Fix TPU (carmocca, Oct 4, 2022)
9b117d5  REVERT ME - let all tpu tests run (carmocca, Oct 4, 2022)
d53bc2f  Fix TPU (carmocca, Oct 4, 2022)
0995832  Revert "REVERT ME - let all tpu tests run" (carmocca, Oct 4, 2022)
9133293  Merge branch 'master' into refactor/xla-utils (tchaton, Oct 4, 2022)
ca9a6e8  Merge branch 'master' into refactor/xla-utils (carmocca, Oct 4, 2022)
d296b7b  Fix test on pt1.9 (carmocca, Oct 4, 2022)

5 changes: 5 additions & 0 deletions dockers/tpu-tests/tpu_test_cases.jsonnet
@@ -39,6 +39,11 @@ local tputests = base.BaseTest {
echo $KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS
export XRT_TPU_CONFIG="tpu_worker;0;${KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS:7}"

echo "--- Sanity check TPU availability ---"
python -c "from lightning_lite.accelerators import TPUAccelerator; assert TPUAccelerator.is_available()"
python -c "from pytorch_lightning.accelerators import TPUAccelerator; assert TPUAccelerator.is_available()"
echo "Sanity check passed!"

echo "--- Running Lite tests ---"
cd tests/tests_lite
PL_RUN_TPU_TESTS=1 coverage run --source=lightning_lite -m pytest -vv --durations=0 ./
2 changes: 0 additions & 2 deletions docs/source-lit/conf.py
@@ -406,8 +406,6 @@ def find_source():
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.utilities import (
_APEX_AVAILABLE,
_XLA_AVAILABLE,
_TPU_AVAILABLE,
_TORCHVISION_AVAILABLE,
_TORCH_GREATER_EQUAL_1_10,
)
2 changes: 1 addition & 1 deletion docs/source-pytorch/accelerators/mps_basic.rst
@@ -57,7 +57,7 @@ If Lightning can't detect the Apple Silicon hardware, it will raise this excepti

.. code::

MisconfigurationException: MPSAccelerator can not run on your system since the accelerator is not available.
MisconfigurationException: `MPSAccelerator` can not run on your system since the accelerator is not available.

If you are seeing this despite running on an ARM-enabled Mac, the most likely cause is that your Python is being emulated and thinks it is running on an Intel CPU.
To solve this, re-install your python executable (and if using environment managers like conda, you have to reinstall these as well) by downloading the Apple M1/M2 build (not Intel!), for example `here <https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links>`_.
2 changes: 0 additions & 2 deletions docs/source-pytorch/conf.py
@@ -394,8 +394,6 @@ def package_list_from_file(file):
from pytorch_lightning.cli import _JSONARGPARSE_SIGNATURES_AVAILABLE as _JSONARGPARSE_AVAILABLE
from pytorch_lightning.utilities import (
_APEX_AVAILABLE,
_XLA_AVAILABLE,
_TPU_AVAILABLE,
_TORCHVISION_AVAILABLE,
_TORCH_GREATER_EQUAL_1_10,
)
75 changes: 72 additions & 3 deletions src/lightning_lite/accelerators/tpu.py
@@ -11,18 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Union
import functools
import queue as q
import traceback
from multiprocessing import Process, Queue
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from lightning_utilities.core.imports import RequirementCache

from lightning_lite.accelerators.accelerator import Accelerator
from lightning_lite.utilities.device_parser import _check_data_type
from lightning_lite.utilities.imports import _TPU_AVAILABLE


class TPUAccelerator(Accelerator):
"""Accelerator for TPU devices."""

def __init__(self, *args: Any, **kwargs: Any) -> None:
if not _XLA_AVAILABLE:
raise ModuleNotFoundError(str(_XLA_AVAILABLE))
super().__init__(*args, **kwargs)

def setup_device(self, device: torch.device) -> None:
pass

@@ -47,8 +56,10 @@ def auto_device_count() -> int:
return 8

@staticmethod
@functools.lru_cache(maxsize=1)
def is_available() -> bool:
return _TPU_AVAILABLE
# check `_XLA_AVAILABLE` again to avoid launching processes
return bool(_XLA_AVAILABLE) and _is_device_tpu()

@classmethod
def register_accelerators(cls, accelerator_registry: Dict) -> None:
@@ -59,6 +70,64 @@ def register_accelerators(cls, accelerator_registry: Dict) -> None:
)


# define TPU availability timeout in seconds
TPU_CHECK_TIMEOUT = 60


def _inner_f(queue: Queue, func: Callable, *args: Any, **kwargs: Any) -> None: # pragma: no cover
try:
queue.put(func(*args, **kwargs))
except Exception:
traceback.print_exc()
queue.put(None)


def _multi_process(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Union[bool, Any]:
queue: Queue = Queue()
proc = Process(target=_inner_f, args=(queue, func, *args), kwargs=kwargs)
proc.start()
proc.join(TPU_CHECK_TIMEOUT)
try:
return queue.get_nowait()
except q.Empty:
traceback.print_exc()
return False

return wrapper


@_multi_process
def _is_device_tpu() -> bool:
"""Check if TPU devices are available. Runs XLA device check within a separate process.

Return:
A boolean value indicating if TPU devices are available
"""
if not _XLA_AVAILABLE:
return False
import torch_xla.core.xla_model as xm

# For the TPU Pod training process, for example, if we have
# TPU v3-32 with 4 VMs, the world size would be 4 and as
# we would have to use `torch_xla.distributed.xla_dist` for
# multiple VMs and TPU_CONFIG won't be available, running
# `xm.get_xla_supported_devices("TPU")` won't be possible.
return (xm.xrt_world_size() > 1) or bool(xm.get_xla_supported_devices("TPU"))


_XLA_AVAILABLE = RequirementCache("torch_xla")


def tpu_distributed() -> bool:
if not TPUAccelerator.is_available():
return False
import torch_xla.core.xla_model as xm

return xm.xrt_world_size() > 1


def parse_tpu_cores(tpu_cores: Optional[Union[int, str, List[int]]]) -> Optional[Union[int, List[int]]]:
"""
Parses the tpu_cores given in the format as accepted by the
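
Note on the availability check above: it combines three pieces. `RequirementCache("torch_xla")` records whether the package is importable, and its string form carries the error message reused in `ModuleNotFoundError`; the actual device probe runs in a child process with a timeout so a hung XLA runtime cannot block the caller; and `functools.lru_cache(maxsize=1)` ensures the probe runs at most once per interpreter. A minimal standalone sketch of the subprocess-with-timeout idea, with illustrative names (`device_available`, `CHECK_TIMEOUT`, `_probe`) that are not part of this PR:

import functools
import queue as q
from multiprocessing import Process, Queue
from typing import Callable

CHECK_TIMEOUT = 60  # seconds to wait for the probe before giving up


def _probe(out: Queue) -> None:
    # A real probe would import torch_xla here and query the runtime;
    # this stand-in simply reports success.
    out.put(True)


def _run_with_timeout(target: Callable, timeout: int = CHECK_TIMEOUT) -> bool:
    out: Queue = Queue()
    proc = Process(target=target, args=(out,))
    proc.start()
    proc.join(timeout)  # returns after `timeout` seconds even if the child hangs
    try:
        return bool(out.get_nowait())
    except q.Empty:
        return False


@functools.lru_cache(maxsize=1)  # cache the result so the child process is launched only once
def device_available() -> bool:
    return _run_with_timeout(_probe)


if __name__ == "__main__":
    print(device_available())  # a second call returns the cached value
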
17 changes: 10 additions & 7 deletions src/lightning_lite/connector.py
@@ -55,7 +55,7 @@
from lightning_lite.strategies.ddp_spawn import _DDP_FORK_ALIASES
from lightning_lite.utilities import _StrategyType, rank_zero_info, rank_zero_warn
from lightning_lite.utilities.device_parser import determine_root_gpu_device
from lightning_lite.utilities.imports import _HPU_AVAILABLE, _IPU_AVAILABLE, _IS_INTERACTIVE, _TPU_AVAILABLE
from lightning_lite.utilities.imports import _HPU_AVAILABLE, _IPU_AVAILABLE, _IS_INTERACTIVE

_PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO]
_PLUGIN_INPUT = Union[_PLUGIN, str]
@@ -301,7 +301,7 @@ def _check_device_config_and_set_final_flags(
def _choose_auto_accelerator(self) -> str:
"""Choose the accelerator type (str) based on availability when ``accelerator='auto'``."""
if self._accelerator_flag == "auto":
if _TPU_AVAILABLE:
if TPUAccelerator.is_available():
return "tpu"
if _IPU_AVAILABLE:
return "ipu"
@@ -328,23 +328,26 @@ def _set_parallel_devices_and_init_accelerator(self) -> None:
else:
assert self._accelerator_flag is not None
self.accelerator = ACCELERATOR_REGISTRY.get(self._accelerator_flag)
accelerator_cls = self.accelerator.__class__

if not self.accelerator.is_available():
if not accelerator_cls.is_available():
available_accelerator = [
acc_str for acc_str in self._registered_accelerators if ACCELERATOR_REGISTRY.get(acc_str).is_available()
acc_str
for acc_str in self._registered_accelerators
if ACCELERATOR_REGISTRY[acc_str]["accelerator"].is_available()
]
raise RuntimeError(
f"{self.accelerator.__class__.__qualname__} can not run on your system"
f"`{accelerator_cls.__qualname__}` can not run on your system"
" since the accelerator is not available. The following accelerator(s)"
" is available and can be passed into `accelerator` argument of"
f" `Lite`: {available_accelerator}."
)

self._set_devices_flag_if_auto_passed()

self._devices_flag = self.accelerator.parse_devices(self._devices_flag)
self._devices_flag = accelerator_cls.parse_devices(self._devices_flag)
if not self._parallel_devices:
self._parallel_devices = self.accelerator.get_parallel_devices(self._devices_flag)
self._parallel_devices = accelerator_cls.get_parallel_devices(self._devices_flag)

def _set_devices_flag_if_auto_passed(self) -> None:
if self._devices_flag == "auto" or self._devices_flag is None:
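
Note: the connector now queries availability through the accelerator classes at call time rather than through the import-time `_TPU_AVAILABLE` constant. A simplified sketch of the `accelerator="auto"` branch shown above, keeping only the TPU check; the CPU fallback is assumed here, and the real method goes on to check IPU and the other accelerators:

from lightning_lite.accelerators import TPUAccelerator


def choose_auto_accelerator() -> str:
    # Nothing XLA-related is imported unless this check actually runs.
    if TPUAccelerator.is_available():
        return "tpu"
    # ... remaining accelerator checks elided ...
    return "cpu"
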
26 changes: 20 additions & 6 deletions src/lightning_lite/plugins/environments/xla_environment.py
@@ -13,13 +13,10 @@
# limitations under the License.
import logging
import os
from typing import Any

from lightning_lite.accelerators.tpu import _XLA_AVAILABLE, TPUAccelerator
from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
from lightning_lite.utilities.imports import _TPU_AVAILABLE

if _TPU_AVAILABLE:
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm

log = logging.getLogger(__name__)

@@ -31,36 +28,53 @@ class XLAEnvironment(ClusterEnvironment):
`here <https://github.com/pytorch/xla/blob/master/torch_xla/core/xla_env_vars.py>`_.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
if not _XLA_AVAILABLE:
raise ModuleNotFoundError(str(_XLA_AVAILABLE))
super().__init__(*args, **kwargs)

@property
def creates_processes_externally(self) -> bool:
return False

@property
def main_address(self) -> str:
import torch_xla.core.xla_env_vars as xenv

return os.environ[xenv.TPU_MESH_CTLER_ADDR]

@property
def main_port(self) -> int:
import torch_xla.core.xla_env_vars as xenv

return int(os.environ[xenv.TPU_MESH_CTLER_PORT])

@staticmethod
def detect() -> bool:
return _TPU_AVAILABLE
return TPUAccelerator.is_available()

def world_size(self) -> int:
import torch_xla.core.xla_model as xm

return xm.xrt_world_size()

def set_world_size(self, size: int) -> None:
log.debug("XLAEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

def global_rank(self) -> int:
import torch_xla.core.xla_model as xm

return xm.get_ordinal()

def set_global_rank(self, rank: int) -> None:
log.debug("XLAEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

def local_rank(self) -> int:
import torch_xla.core.xla_model as xm

return xm.get_local_ordinal()

def node_rank(self) -> int:
import torch_xla.core.xla_env_vars as xenv

return int(os.environ.get(xenv.HOST_ORDINAL, 0))
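
Note: the module-level `if _TPU_AVAILABLE: import torch_xla...` guard is gone; each property imports `torch_xla` only when it is actually called, and the constructor fails fast when the requirement is missing. A condensed sketch of that lazy-import pattern (the class name is illustrative, not from the PR):

from lightning_lite.accelerators.tpu import _XLA_AVAILABLE


class LazyXLAInfo:
    def __init__(self) -> None:
        # Fail at construction time with the message produced by RequirementCache("torch_xla").
        if not _XLA_AVAILABLE:
            raise ModuleNotFoundError(str(_XLA_AVAILABLE))

    def world_size(self) -> int:
        # Imported here, not at module import time, so importing this module never needs torch_xla.
        import torch_xla.core.xla_model as xm

        return xm.xrt_world_size()
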
13 changes: 9 additions & 4 deletions src/lightning_lite/plugins/io/xla_plugin.py
@@ -16,21 +16,24 @@

from lightning_utilities.core.apply_func import apply_to_collection

from lightning_lite.accelerators.tpu import _XLA_AVAILABLE
from lightning_lite.plugins.io.torch_plugin import TorchCheckpointIO
from lightning_lite.utilities.cloud_io import get_filesystem
from lightning_lite.utilities.imports import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE
from lightning_lite.utilities.imports import _OMEGACONF_AVAILABLE
from lightning_lite.utilities.types import _PATH

if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm

if _OMEGACONF_AVAILABLE:
from omegaconf import DictConfig, ListConfig, OmegaConf


class XLACheckpointIO(TorchCheckpointIO):
"""CheckpointIO that utilizes :func:`xm.save` to save checkpoints for TPU training strategies."""

def __init__(self, *args: Any, **kwargs: Any) -> None:
if not _XLA_AVAILABLE:
raise ModuleNotFoundError(str(_XLA_AVAILABLE))
super().__init__(*args, **kwargs)

def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.

@@ -55,4 +58,6 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio
# Ref: https://github.com/pytorch/xla/issues/2773
if _OMEGACONF_AVAILABLE:
checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
import torch_xla.core.xla_model as xm

xm.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, path)
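
Note: a short usage sketch of the effect of the new constructor guard. On a machine without `torch_xla`, instantiation now fails immediately instead of at the first `xm.save` call; the try/except below is illustrative, not from the PR:

from lightning_lite.plugins.io.xla_plugin import XLACheckpointIO

try:
    checkpoint_io = XLACheckpointIO()
except ModuleNotFoundError as err:
    # Raised at construction time with the RequirementCache("torch_xla") message.
    print(f"XLA checkpointing unavailable: {err}")
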
11 changes: 5 additions & 6 deletions src/lightning_lite/strategies/launchers/xla.py
@@ -17,15 +17,10 @@

from torch.multiprocessing import get_context

from lightning_lite.accelerators.tpu import _XLA_AVAILABLE
from lightning_lite.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
from lightning_lite.utilities import _TPU_AVAILABLE
from lightning_lite.utilities.apply_func import move_data_to_device

if _TPU_AVAILABLE:
import torch_xla.distributed.xla_multiprocessing as xmp
else:
xmp = None

if TYPE_CHECKING:
from lightning_lite.strategies import XLAStrategy

@@ -47,6 +42,8 @@ class _XLALauncher(_MultiProcessingLauncher):
"""

def __init__(self, strategy: "XLAStrategy") -> None:
if not _XLA_AVAILABLE:
raise ModuleNotFoundError(str(_XLA_AVAILABLE))
super().__init__(strategy=strategy, start_method="fork")

@property
@@ -66,6 +63,8 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
"""
context = get_context(self._start_method)
return_queue = context.SimpleQueue()
import torch_xla.distributed.xla_multiprocessing as xmp

xmp.spawn(
self._wrapping_function,
args=(function, args, kwargs, return_queue),