only enable query key scaling during fp16 #7946
Conversation
Signed-off-by: Gerald Shen <geshen@nvidia.com>
jenkins
Signed-off-by: Gerald Shen <geshen@nvidia.com>
jenkins
@@ -1544,6 +1544,11 @@ def build_transformer_config(self) -> TransformerConfig:

    attention_softmax_in_fp32 = False  # not currently used in NeMo unless apply_query_key_layer_scaling is True
    apply_query_key_layer_scaling = self.cfg.get('apply_query_key_layer_scaling', False)

    if apply_query_key_layer_scaling and not model_parallel_config.fp16:
        logging.warning("apply_query_key_layer_scaling is only enabled when using FP16, setting it to False")
I don't think model_parallel_config.fp16 is the right check though. That arg is for fp16 + megatron_amp_O2.
Maybe we should just check trainer.precision?
can I check against `self.torch_dtype`?

self.torch_dtype = utils_funcs.torch_dtype_from_precision(self.cfg.precision)  # Mixed precision datatype
fixed 25b7349
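The fix discussed above can be sketched as a small helper that gates query/key layer scaling on the trainer precision rather than on `model_parallel_config.fp16` (which also implies megatron_amp_O2). The function name and the accepted precision strings here are assumptions for illustration, not the actual NeMo code; the precision values follow PyTorch Lightning conventions (`16`, `"16-mixed"`, `"bf16"`, `32`).

```python
def resolve_qk_layer_scaling(apply_query_key_layer_scaling, precision):
    """Hypothetical sketch: return whether QK layer scaling should stay on.

    `precision` is assumed to be a Lightning-style value such as 16, "16",
    "16-mixed", "bf16", or 32.
    """
    fp16 = str(precision) in ("16", "16-mixed")
    if apply_query_key_layer_scaling and not fp16:
        # TE only applies QK layer scaling under FP16, so disable it here.
        print("apply_query_key_layer_scaling is only enabled when using FP16, "
              "setting it to False")
        return False
    return apply_query_key_layer_scaling
```

Checking the resolved dtype (as with `self.torch_dtype` above) is equivalent in spirit: both derive the decision from the configured precision instead of from the model-parallel config flag.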
Signed-off-by: Gerald Shen <geshen@nvidia.com>
jenkins
Signed-off-by: Gerald Shen <geshen@nvidia.com>
jenkins
Signed-off-by: Gerald Shen <geshen@nvidia.com>
jenkins
if fp16_enabled:
    os.environ["NVTE_APPLY_QK_LAYER_SCALING"] = "1"
else:
    logging.warning("apply_query_key_layer_scaling is only enabled when using FP16, setting it to False")
Should we set the env var here as well to "0" ?
good point, otherwise it will error; fixed in 366cf61
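The pattern the thread settles on can be sketched as: export `NVTE_APPLY_QK_LAYER_SCALING` as `"1"` only under FP16, and explicitly set it to `"0"` otherwise so Transformer Engine never sees a stale enabled value. The helper name is hypothetical; the environment variable is the real one referenced in the diff above.

```python
import os

def set_te_qk_layer_scaling(fp16_enabled: bool) -> None:
    # Hypothetical helper: always write the env var explicitly, in both
    # branches, so a previously exported "1" cannot leak into a non-FP16 run.
    os.environ["NVTE_APPLY_QK_LAYER_SCALING"] = "1" if fp16_enabled else "0"
```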
Signed-off-by: Gerald Shen <geshen@nvidia.com>
jenkins
LGTM. Thanks!
* only enable query key scaling during fp16
* add warning
* fixup! only enable query key scaling during fp16
* remove var from jenkens file
* fix test by setting TE var
* set to 0 if disabled

Signed-off-by: Gerald Shen <geshen@nvidia.com>
Signed-off-by: Piotr Żelasko <petezor@gmail.com>
What does this PR do?
Enable query key layer scaling only during FP16, since that is the only case in which Transformer Engine applies it: https://github.com/NVIDIA/TransformerEngine/blob/666539f36275fa9c0fbc99f9ea50f2d6e29e336f/transformer_engine/pytorch/attention.py#L940
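For context on why this scaling only matters under FP16, here is a toy illustration (not NeMo or TE code) of the query/key layer scaling trick: dividing the Q·K^T logits by the layer number keeps the intermediate matmul within FP16 range, and multiplying that factor back inside the (FP32) softmax leaves the final probabilities mathematically unchanged. The numbers below are arbitrary.

```python
import math

def softmax(xs):
    # Numerically stable softmax, as would run in FP32.
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

layer_number = 4
logits = [10.0, 20.0, 30.0]  # stand-in for Q.K^T / sqrt(d) scores

# Scale down before storing the intermediate (the FP16-safe part) ...
scaled_logits = [x / layer_number for x in logits]
# ... and scale back up inside the softmax, where compute is FP32.
plain = softmax(logits)
recovered = softmax([layer_number * s for s in scaled_logits])
```

Since bfloat16 has the same exponent range as FP32, the overflow concern does not apply there, which is why TE (and this PR) restricts the scaling to FP16.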