
Commit 7268670

Support true 16-bit precision with deepspeed (#17576)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
awaelchli and pre-commit-ci[bot] authored May 12, 2023
1 parent cbc536a commit 7268670
Showing 4 changed files with 87 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-tests-fabric.yml
@@ -138,7 +138,7 @@ jobs:
         # NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003
         run: |
           python -m coverage run --source ${{ env.COVERAGE_SCOPE }} \
-            -m pytest -v --timeout=30 --durations=50 --random-order-seed=44
+            -m pytest -v --timeout=30 --durations=50
       - name: Statistics
         if: success()
30 changes: 25 additions & 5 deletions src/lightning/fabric/plugins/precision/deepspeed.py
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Literal, TYPE_CHECKING
+from contextlib import contextmanager
+from typing import Any, Generator, Literal, TYPE_CHECKING

 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
+from torch.nn import Module
 from typing_extensions import get_args

 from lightning.fabric.plugins.precision.precision import Precision
@@ -28,7 +30,7 @@
 if _DEEPSPEED_AVAILABLE:  # type: ignore[has-type]
     import deepspeed

-_PRECISION_INPUT = Literal["32-true", "16-mixed", "bf16-mixed"]
+_PRECISION_INPUT = Literal["32-true", "16-true", "bf16-true", "16-mixed", "bf16-mixed"]


 class DeepSpeedPrecision(Precision):
@@ -51,11 +53,29 @@ def __init__(self, precision: _PRECISION_INPUT) -> None:
         )
         self.precision = precision

-        precision_to_type = {"bf16-mixed": torch.bfloat16, "16-mixed": torch.float16, "32-true": torch.float32}
-        self._desired_input_dtype = precision_to_type[self.precision]
+        precision_to_type = {
+            "bf16-mixed": torch.bfloat16,
+            "16-mixed": torch.float16,
+            "bf16-true": torch.bfloat16,
+            "16-true": torch.float16,
+            "32-true": torch.float32,
+        }
+        self._desired_dtype = precision_to_type[self.precision]
+
+    def convert_module(self, module: Module) -> Module:
+        if "true" in self.precision:
+            return module.to(dtype=self._desired_dtype)
+        return module
+
+    @contextmanager
+    def init_context(self) -> Generator[None, None, None]:
+        default_dtype = torch.get_default_dtype()
+        torch.set_default_dtype(self._desired_dtype if "true" in self.precision else default_dtype)
+        yield
+        torch.set_default_dtype(default_dtype)

     def convert_input(self, data: Any) -> Any:
-        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)
+        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype)

     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
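For orientation (not part of the diff): a minimal sketch of how the two new hooks behave, assuming a build of lightning that includes this change. init_context() temporarily swaps torch's default floating-point dtype so that parameters are created directly in the target precision, while convert_module() casts an existing module only for the "-true" settings.

    import torch

    from lightning.fabric.plugins.precision.deepspeed import DeepSpeedPrecision

    precision = DeepSpeedPrecision(precision="16-true")

    # Parameters allocated inside init_context() are created in float16,
    # because the plugin swaps torch's default dtype for the duration.
    with precision.init_context():
        model = torch.nn.Linear(4, 4)
    assert model.weight.dtype == torch.float16

    # Modules created outside the context can be cast after the fact.
    # For the mixed settings ("16-mixed", "bf16-mixed"), convert_module() is a no-op.
    model = precision.convert_module(torch.nn.Linear(4, 4))
    assert model.weight.dtype == torch.float16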
12 changes: 8 additions & 4 deletions src/lightning/fabric/strategies/deepspeed.py
@@ -350,9 +350,13 @@ def init_sharded_context(self) -> Generator[None, None, None]:
         if self.zero_stage_3:
             assert self._config_initialized

-            if self.precision.precision == "16-mixed":
+            # Note: For the mixed settings '16-mixed' and 'bf16-mixed', we shouldn't convert the weights to half
+            # precision, but we are keeping the 'bug' for backward compatibility.
+            # TODO: This can be properly implemented once https://github.com/Lightning-AI/lightning/issues/17581
+            # gets resolved
+            if self.precision.precision in ("16-mixed", "16-true"):
                 dtype = torch.float16
-            elif self.precision.precision == "bf16-mixed":
+            elif self.precision.precision in ("bf16-mixed", "bf16-true"):
                 dtype = torch.bfloat16
             else:
                 dtype = torch.float32
@@ -624,7 +628,7 @@ def _format_config(self) -> None:

     def _format_precision_config(self) -> None:
         assert isinstance(self.config, dict)
-        if self.precision.precision == "16-mixed":
+        if self.precision.precision in ("16-mixed", "16-true"):
             if "fp16" not in self.config:
                 # FP16 is a DeepSpeed standalone AMP implementation
                 rank_zero_info("Enabling DeepSpeed FP16.")
@@ -636,7 +640,7 @@ def _format_precision_config(self) -> None:
                     "hysteresis": self.hysteresis,
                     "min_loss_scale": self.min_loss_scale,
                 }
-        elif "bf16" not in self.config and self.precision.precision == "bf16-mixed":
+        elif "bf16" not in self.config and self.precision.precision in ("bf16-mixed", "bf16-true"):
             rank_zero_info("Enabling DeepSpeed BF16.")
             self.config["bf16"] = {"enabled": True}

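For reference (not part of the diff): with the _format_precision_config() change above, "16-true" now emits the same standalone DeepSpeed FP16 section as "16-mixed". A hypothetical sketch of the resulting config fragment; the field values come from DeepSpeedStrategy's loss-scale arguments, and the numbers shown are DeepSpeed's documented defaults, not taken from this commit.

    # Hypothetical config section produced for precision="16-true".
    config = {
        "fp16": {
            "enabled": True,
            "loss_scale": 0,  # 0 selects dynamic loss scaling in DeepSpeed
            "initial_scale_power": 16,
            "loss_scale_window": 1000,
            "hysteresis": 2,  # visible in the hunk above
            "min_loss_scale": 1,  # visible in the hunk above
        }
    }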
53 changes: 53 additions & 0 deletions tests/tests_fabric/plugins/precision/test_deepspeed.py
@@ -15,6 +15,7 @@
 from unittest.mock import Mock

 import pytest
+import torch

 from lightning.fabric.plugins.precision.deepspeed import DeepSpeedPrecision
 from lightning.fabric.utilities.types import Steppable
@@ -49,3 +50,55 @@ def test_deepspeed_precision_optimizer_step():
     optimizer = model = Mock()
     precision.optimizer_step(optimizer, lr_kwargs={})
     model.step.assert_called_once_with(lr_kwargs={})
+
+
+@pytest.mark.parametrize(
+    ("precision", "expected_dtype"),
+    [
+        ("32-true", torch.float32),
+        ("bf16-mixed", torch.bfloat16),
+        ("16-mixed", torch.float16),
+        ("bf16-true", torch.bfloat16),
+        ("16-true", torch.float16),
+    ],
+)
+def test_selected_dtype(precision, expected_dtype):
+    plugin = DeepSpeedPrecision(precision=precision)
+    assert plugin.precision == precision
+    assert plugin._desired_dtype == expected_dtype
+
+
+@pytest.mark.parametrize(
+    ("precision", "expected_dtype"),
+    [
+        ("32-true", torch.float32),
+        ("bf16-mixed", torch.float32),
+        ("16-mixed", torch.float32),
+        ("bf16-true", torch.bfloat16),
+        ("16-true", torch.float16),
+    ],
+)
+def test_module_init_context(precision, expected_dtype):
+    plugin = DeepSpeedPrecision(precision=precision)
+    with plugin.init_context():
+        model = torch.nn.Linear(2, 2)
+        assert torch.get_default_dtype() == expected_dtype
+    assert model.weight.dtype == expected_dtype
+
+
+@pytest.mark.parametrize(
+    ("precision", "expected_dtype"),
+    [
+        ("32-true", torch.float32),
+        ("bf16-mixed", torch.float32),
+        ("16-mixed", torch.float32),
+        ("bf16-true", torch.bfloat16),
+        ("16-true", torch.float16),
+    ],
+)
+def test_convert_module(precision, expected_dtype):
+    precision = DeepSpeedPrecision(precision=precision)
+    module = torch.nn.Linear(2, 2)
+    assert module.weight.dtype == module.bias.dtype == torch.float32
+    module = precision.convert_module(module)
+    assert module.weight.dtype == module.bias.dtype == expected_dtype
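End to end, this commit lets user code request true half precision with DeepSpeed. A hedged usage sketch, assuming a CUDA machine with deepspeed installed and a lightning build containing this change:

    from lightning.fabric import Fabric

    # Before this commit, "16-true" was rejected for DeepSpeed; now it casts
    # weights and inputs to float16 instead of relying on AMP alone.
    fabric = Fabric(strategy="deepspeed", accelerator="cuda", devices=2, precision="16-true")
    fabric.launch()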
