Automatic mixed precision #934
Conversation
@vinhngx This is great, thanks for adding this! Could you also include APEX support for the language model trainer?

Thanks @alanakbik. I've added APEX to the language model trainer.

Awesome, thanks!
I tried it in combination with […].

Without:
2019-07-29 20:42:27,531 epoch 1 - iter 0/795 - loss 55.33607864 throughput (samples/sec): 2900.75
2019-07-29 20:43:08,202 epoch 1 - iter 79/795 - loss 9.06459931 throughput (samples/sec): 31.19
2019-07-29 20:43:50,289 epoch 1 - iter 158/795 - loss 7.42229333 throughput (samples/sec): 30.12
2019-07-29 20:44:31,381 epoch 1 - iter 237/795 - loss 6.46287990 throughput (samples/sec): 30.85
2019-07-29 20:45:12,541 epoch 1 - iter 316/795 - loss 5.83835078 throughput (samples/sec): 30.80
2019-07-29 20:45:53,771 epoch 1 - iter 395/795 - loss 5.42233314 throughput (samples/sec): 30.75
2019-07-29 20:46:34,585 epoch 1 - iter 474/795 - loss 5.05916398 throughput (samples/sec): 31.06
2019-07-29 20:47:16,937 epoch 1 - iter 553/795 - loss 4.78533588 throughput (samples/sec): 29.93
2019-07-29 20:47:59,127 epoch 1 - iter 632/795 - loss 4.53341620 throughput (samples/sec): 30.05
2019-07-29 20:48:40,761 epoch 1 - iter 711/795 - loss 4.35730412 throughput (samples/sec): 30.45
2019-07-29 20:49:22,217 epoch 1 - iter 790/795 - loss 4.17682579 throughput (samples/sec): 30.58
2019-07-29 20:49:23,789 ----------------------------------------------------------------------------------------------------
2019-07-29 20:49:23,789 EPOCH 1 done: loss 4.1919 - lr 0.1000
2019-07-29 20:51:24,836 DEV : loss 2.9106078147888184 - score 0.6502
2019-07-29 20:51:24,872 BAD EPOCHS (no improvement): 0
2019-07-29 20:51:25,456 ----------------------------------------------------------------------------------------------------
2019-07-29 20:51:25,563 epoch 2 - iter 0/795 - loss 0.98591888 throughput (samples/sec): 12115.87
2019-07-29 20:51:35,162 epoch 2 - iter 79/795 - loss 2.51360194 throughput (samples/sec): 133.39
2019-07-29 20:51:44,908 epoch 2 - iter 158/795 - loss 2.54102277 throughput (samples/sec): 131.35
2019-07-29 20:51:54,648 epoch 2 - iter 237/795 - loss 2.55147406 throughput (samples/sec): 131.44
2019-07-29 20:52:04,199 epoch 2 - iter 316/795 - loss 2.45608709 throughput (samples/sec): 134.08
2019-07-29 20:52:13,655 epoch 2 - iter 395/795 - loss 2.43405817 throughput (samples/sec): 135.43
2019-07-29 20:52:23,213 epoch 2 - iter 474/795 - loss 2.42034678 throughput (samples/sec): 133.95
2019-07-29 20:52:33,684 epoch 2 - iter 553/795 - loss 2.36586307 throughput (samples/sec): 122.13
2019-07-29 20:52:44,983 epoch 2 - iter 632/795 - loss 2.32865828 throughput (samples/sec): 113.07
2019-07-29 20:52:55,985 epoch 2 - iter 711/795 - loss 2.28737589 throughput (samples/sec): 116.18
2019-07-29 20:53:07,091 epoch 2 - iter 790/795 - loss 2.28597844 throughput (samples/sec): 115.07
2019-07-29 20:53:07,747 ----------------------------------------------------------------------------------------------------

With:
2019-07-29 23:40:03,869 epoch 1 - iter 0/795 - loss 40.28230286 throughput (samples/sec): 2979.66
2019-07-29 23:40:49,085 epoch 1 - iter 79/795 - loss 10.62385516 throughput (samples/sec): 28.04
2019-07-29 23:41:32,755 epoch 1 - iter 158/795 - loss 7.94562467 throughput (samples/sec): 29.03
2019-07-29 23:42:15,033 epoch 1 - iter 237/795 - loss 6.78094378 throughput (samples/sec): 29.98
2019-07-29 23:42:59,963 epoch 1 - iter 316/795 - loss 6.05153041 throughput (samples/sec): 28.21
2019-07-29 23:43:42,590 epoch 1 - iter 395/795 - loss 5.49845262 throughput (samples/sec): 29.74
2019-07-29 23:44:25,641 epoch 1 - iter 474/795 - loss 5.16456086 throughput (samples/sec): 29.48
2019-07-29 23:45:09,427 epoch 1 - iter 553/795 - loss 4.82567009 throughput (samples/sec): 28.95
2019-07-29 23:45:52,538 epoch 1 - iter 632/795 - loss 4.57211298 throughput (samples/sec): 29.40
2019-07-29 23:46:33,286 epoch 1 - iter 711/795 - loss 4.38901180 throughput (samples/sec): 31.11
2019-07-29 23:47:17,371 epoch 1 - iter 790/795 - loss 4.20550554 throughput (samples/sec): 28.75
2019-07-29 23:47:19,562 ----------------------------------------------------------------------------------------------------
2019-07-29 23:47:19,562 EPOCH 1 done: loss 4.2002 - lr 0.1000
2019-07-29 23:49:26,519 DEV : loss 3.2141149044036865 - score 0.5982
2019-07-29 23:49:26,558 BAD EPOCHS (no improvement): 0
2019-07-29 23:49:27,139 ----------------------------------------------------------------------------------------------------
2019-07-29 23:49:27,282 epoch 2 - iter 0/795 - loss 1.12042511 throughput (samples/sec): 8974.31
2019-07-29 23:49:37,825 epoch 2 - iter 79/795 - loss 2.54959387 throughput (samples/sec): 121.26
2019-07-29 23:49:48,250 epoch 2 - iter 158/795 - loss 2.49878981 throughput (samples/sec): 122.71
2019-07-29 23:49:58,878 epoch 2 - iter 237/795 - loss 2.43002224 throughput (samples/sec): 120.30
2019-07-29 23:50:09,504 epoch 2 - iter 316/795 - loss 2.40728975 throughput (samples/sec): 120.34
2019-07-29 23:50:20,038 epoch 2 - iter 395/795 - loss 2.34507697 throughput (samples/sec): 121.37
2019-07-29 23:50:30,534 epoch 2 - iter 474/795 - loss 2.32682362 throughput (samples/sec): 121.85
2019-07-29 23:50:41,203 epoch 2 - iter 553/795 - loss 2.30398093 throughput (samples/sec): 119.88
2019-07-29 23:50:51,679 epoch 2 - iter 632/795 - loss 2.27872584 throughput (samples/sec): 122.10
2019-07-29 23:51:02,250 epoch 2 - iter 711/795 - loss 2.25251733 throughput (samples/sec): 120.96
2019-07-29 23:51:13,189 epoch 2 - iter 790/795 - loss 2.22839135 throughput (samples/sec): 116.84
2019-07-29 23:51:13,751 ----------------------------------------------------------------------------------------------------
@stefan-it that is strange. Do you have a setup to test a language model training task? LM training is the only task for which we currently have 100% GPU usage.
@stefan-it There is some overhead incurred with mixed precision training, so models with low arithmetic intensity, i.e. a low ratio of arithmetic operations to bytes accessed, might not necessarily benefit. For NVIDIA GPUs with Tensor Cores, as a general performance guideline, dimensions (such as batch size, input size, output size, and channel counts) should be powers of two if under 256, or otherwise multiples of 8. Large models such as BERT pre-training have seen up to 4x speed-ups. I'm curious to see which Flair models will benefit most :)
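To make the arithmetic-intensity argument concrete, here is a small back-of-the-envelope helper (illustrative only, not part of flair or apex; it counts the FLOPs and bytes of a single fp16 matrix multiply, ignoring caching and kernel details):

```python
def gemm_arithmetic_intensity(m, k, n, bytes_per_elem=2):
    """FLOPs per byte moved for C = A @ B with A (m,k), B (k,n), fp16 storage."""
    flops = 2 * m * k * n  # one multiply + one add per multiply-accumulate
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)  # read A and B, write C
    return flops / bytes_moved

# A large hidden layer is compute-bound and can keep Tensor Cores busy...
print(gemm_arithmetic_intensity(256, 2048, 2048))  # → 204.8
# ...while a small layer is memory-bound, so fp16 math helps much less.
print(gemm_arithmetic_intensity(256, 128, 128))    # → 51.2
```

This matches the observation later in the thread that the 2048-state LM speeds up far more than the 128-state one.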
I've tested APEX for training a character LM over the 1-billion-word corpus. It's a 512-state, 1-layer RNN with a dictionary size of 280, a sequence length of 256 and a mini-batch size of 256.

Without APEX:
| end of split 1 / 99 | epoch 1 | time: 54.22s | valid loss 1.61 | valid ppl 5.01 | learning rate 20.0000
| end of split 2 / 99 | epoch 1 | time: 53.45s | valid loss 1.40 | valid ppl 4.06 | learning rate 20.0000
| end of split 3 / 99 | epoch 1 | time: 53.72s | valid loss 1.54 | valid ppl 4.68 | learning rate 20.0000
| end of split 4 / 99 | epoch 1 | time: 53.24s | valid loss 1.29 | valid ppl 3.63 | learning rate 20.0000
| end of split 5 / 99 | epoch 1 | time: 53.53s | valid loss 1.26 | valid ppl 3.52 | learning rate 20.0000
| end of split 6 / 99 | epoch 1 | time: 53.71s | valid loss 1.24 | valid ppl 3.46 | learning rate 20.0000
| end of split 7 / 99 | epoch 1 | time: 53.62s | valid loss 1.23 | valid ppl 3.41 | learning rate 20.0000
| end of split 8 / 99 | epoch 1 | time: 53.76s | valid loss 1.21 | valid ppl 3.36 | learning rate 20.0000
| end of split 9 / 99 | epoch 1 | time: 52.98s | valid loss 1.21 | valid ppl 3.37 | learning rate 20.0000
| end of split 10 / 99 | epoch 1 | time: 53.71s | valid loss 1.19 | valid ppl 3.30 | learning rate 20.0000
| end of split 11 / 99 | epoch 1 | time: 53.77s | valid loss 1.21 | valid ppl 3.36 | learning rate 20.0000
| end of split 12 / 99 | epoch 1 | time: 53.75s | valid loss 1.19 | valid ppl 3.28 | learning rate 20.0000
| end of split 13 / 99 | epoch 1 | time: 53.85s | valid loss 1.18 | valid ppl 3.24 | learning rate 20.0000
| end of split 14 / 99 | epoch 1 | time: 53.36s | valid loss 1.17 | valid ppl 3.22 | learning rate 20.0000
| end of split 15 / 99 | epoch 1 | time: 53.76s | valid loss 1.17 | valid ppl 3.21 | learning rate 20.0000
| end of split 16 / 99 | epoch 1 | time: 53.79s | valid loss 1.17 | valid ppl 3.21 | learning rate 20.0000
| end of split 17 / 99 | epoch 1 | time: 53.82s | valid loss 1.18 | valid ppl 3.27 | learning rate 20.0000
| end of split 18 / 99 | epoch 1 | time: 53.87s | valid loss 1.16 | valid ppl 3.18 | learning rate 20.0000
| end of split 19 / 99 | epoch 1 | time: 53.76s | valid loss 1.16 | valid ppl 3.20 | learning rate 20.0000
| end of split 20 / 99 | epoch 1 | time: 54.16s | valid loss 1.16 | valid ppl 3.18 | learning rate 20.0000
| end of split 21 / 99 | epoch 1 | time: 53.21s | valid loss 1.15 | valid ppl 3.17 | learning rate 20.0000
| end of split 22 / 99 | epoch 1 | time: 53.66s | valid loss 1.15 | valid ppl 3.14 | learning rate 20.0000
| end of split 23 / 99 | epoch 1 | time: 53.96s | valid loss 1.15 | valid ppl 3.15 | learning rate 20.0000
| end of split 24 / 99 | epoch 1 | time: 53.62s | valid loss 1.17 | valid ppl 3.23 | learning rate 20.0000
| end of split 25 / 99 | epoch 1 | time: 53.81s | valid loss 1.14 | valid ppl 3.12 | learning rate 20.0000
| end of split 26 / 99 | epoch 1 | time: 53.67s | valid loss 1.14 | valid ppl 3.12 | learning rate 20.0000
| end of split 27 / 99 | epoch 1 | time: 53.88s | valid loss 1.14 | valid ppl 3.12 | learning rate 20.0000
| end of split 28 / 99 | epoch 1 | time: 53.54s | valid loss 1.15 | valid ppl 3.17 | learning rate 20.0000
| end of split 29 / 99 | epoch 1 | time: 53.41s | valid loss 1.13 | valid ppl 3.10 | learning rate 20.0000
| end of split 30 / 99 | epoch 1 | time: 53.80s | valid loss 1.13 | valid ppl 3.10 | learning rate 20.0000

With APEX:
| end of split 1 / 99 | epoch 1 | time: 35.92s | valid loss 1.63 | valid ppl 5.12 | learning rate 20.0000
| end of split 2 / 99 | epoch 1 | time: 36.27s | valid loss 1.41 | valid ppl 4.09 | learning rate 20.0000
| end of split 3 / 99 | epoch 1 | time: 34.98s | valid loss 1.56 | valid ppl 4.75 | learning rate 20.0000
| end of split 4 / 99 | epoch 1 | time: 35.17s | valid loss 1.28 | valid ppl 3.61 | learning rate 20.0000
| end of split 5 / 99 | epoch 1 | time: 35.14s | valid loss 1.26 | valid ppl 3.53 | learning rate 20.0000
| end of split 6 / 99 | epoch 1 | time: 35.93s | valid loss 1.24 | valid ppl 3.45 | learning rate 20.0000
| end of split 7 / 99 | epoch 1 | time: 35.41s | valid loss 1.22 | valid ppl 3.40 | learning rate 20.0000
| end of split 8 / 99 | epoch 1 | time: 34.83s | valid loss 1.22 | valid ppl 3.38 | learning rate 20.0000
| end of split 9 / 99 | epoch 1 | time: 35.57s | valid loss 1.21 | valid ppl 3.36 | learning rate 20.0000
| end of split 10 / 99 | epoch 1 | time: 36.25s | valid loss 1.20 | valid ppl 3.31 | learning rate 20.0000
| end of split 11 / 99 | epoch 1 | time: 36.07s | valid loss 1.22 | valid ppl 3.38 | learning rate 20.0000
| end of split 12 / 99 | epoch 1 | time: 35.28s | valid loss 1.19 | valid ppl 3.30 | learning rate 20.0000
| end of split 13 / 99 | epoch 1 | time: 35.93s | valid loss 1.18 | valid ppl 3.26 | learning rate 20.0000
| end of split 14 / 99 | epoch 1 | time: 35.78s | valid loss 1.17 | valid ppl 3.23 | learning rate 20.0000
| end of split 15 / 99 | epoch 1 | time: 34.98s | valid loss 1.17 | valid ppl 3.22 | learning rate 20.0000
| end of split 16 / 99 | epoch 1 | time: 36.06s | valid loss 1.17 | valid ppl 3.22 | learning rate 20.0000
| end of split 17 / 99 | epoch 1 | time: 35.85s | valid loss 1.19 | valid ppl 3.28 | learning rate 20.0000
| end of split 18 / 99 | epoch 1 | time: 35.33s | valid loss 1.16 | valid ppl 3.19 | learning rate 20.0000
| end of split 19 / 99 | epoch 1 | time: 34.77s | valid loss 1.16 | valid ppl 3.20 | learning rate 20.0000
| end of split 20 / 99 | epoch 1 | time: 36.08s | valid loss 1.16 | valid ppl 3.19 | learning rate 20.0000
| end of split 21 / 99 | epoch 1 | time: 35.72s | valid loss 1.15 | valid ppl 3.17 | learning rate 20.0000
| end of split 22 / 99 | epoch 1 | time: 35.86s | valid loss 1.15 | valid ppl 3.16 | learning rate 20.0000
| end of split 23 / 99 | epoch 1 | time: 36.13s | valid loss 1.15 | valid ppl 3.16 | learning rate 20.0000
| end of split 24 / 99 | epoch 1 | time: 35.09s | valid loss 1.18 | valid ppl 3.27 | learning rate 20.0000
| end of split 25 / 99 | epoch 1 | time: 36.16s | valid loss 1.15 | valid ppl 3.15 | learning rate 20.0000
| end of split 26 / 99 | epoch 1 | time: 34.85s | valid loss 1.14 | valid ppl 3.13 | learning rate 20.0000
| end of split 27 / 99 | epoch 1 | time: 35.22s | valid loss 1.14 | valid ppl 3.14 | learning rate 20.0000
| end of split 28 / 99 | epoch 1 | time: 35.74s | valid loss 1.15 | valid ppl 3.17 | learning rate 20.0000
| end of split 29 / 99 | epoch 1 | time: 35.24s | valid loss 1.14 | valid ppl 3.11 | learning rate 20.0000
| end of split 30 / 99 | epoch 1 | time: 36.01s | valid loss 1.14 | valid ppl 3.12 | learning rate 20.0000

So we are seeing the average time per split reduced from ~54 seconds to ~35 seconds, which is a nice speed-up! I previously tested a smaller language model with only 128 states and there was no difference. I could imagine that the difference is more pronounced if we train larger LMs. Our typical LM has 2048 states, so the gains may be even larger there.
And another update: when training a large LM with 2048 states, we are seeing over 3x speedup, from more than 400 seconds per split down to about 130 seconds per split! :)

Without APEX:
| end of split 1 / 99 | epoch 1 | time: 408.14s | valid loss 1.67 | valid ppl 5.29 | learning rate 20.0000
| end of split 2 / 99 | epoch 1 | time: 406.09s | valid loss 1.38 | valid ppl 3.96 | learning rate 20.0000
| end of split 3 / 99 | epoch 1 | time: 403.50s | valid loss 1.52 | valid ppl 4.58 | learning rate 20.0000
| end of split 4 / 99 | epoch 1 | time: 403.65s | valid loss 1.22 | valid ppl 3.40 | learning rate 20.0000
| end of split 5 / 99 | epoch 1 | time: 404.23s | valid loss 1.18 | valid ppl 3.25 | learning rate 20.0000
| end of split 6 / 99 | epoch 1 | time: 402.83s | valid loss 1.15 | valid ppl 3.17 | learning rate 20.0000
| end of split 7 / 99 | epoch 1 | time: 404.56s | valid loss 1.12 | valid ppl 3.08 | learning rate 20.0000
| end of split 8 / 99 | epoch 1 | time: 406.90s | valid loss 1.11 | valid ppl 3.03 | learning rate 20.0000
| end of split 9 / 99 | epoch 1 | time: 403.44s | valid loss 1.10 | valid ppl 3.01 | learning rate 20.0000
| end of split 10 / 99 | epoch 1 | time: 404.39s | valid loss 1.09 | valid ppl 2.97 | learning rate 20.0000
| end of split 11 / 99 | epoch 1 | time: 403.61s | valid loss 1.13 | valid ppl 3.09 | learning rate 20.0000
| end of split 12 / 99 | epoch 1 | time: 403.89s | valid loss 1.08 | valid ppl 2.95 | learning rate 20.0000
| end of split 13 / 99 | epoch 1 | time: 402.65s | valid loss 1.06 | valid ppl 2.88 | learning rate 20.0000
| end of split 14 / 99 | epoch 1 | time: 404.43s | valid loss 1.05 | valid ppl 2.85 | learning rate 20.0000

With APEX:
| end of split 1 / 99 | epoch 1 | time: 127.69s | valid loss 1.67 | valid ppl 5.30 | learning rate 20.0000
| end of split 2 / 99 | epoch 1 | time: 128.90s | valid loss 1.37 | valid ppl 3.93 | learning rate 20.0000
| end of split 3 / 99 | epoch 1 | time: 128.36s | valid loss 1.50 | valid ppl 4.49 | learning rate 20.0000
| end of split 4 / 99 | epoch 1 | time: 128.64s | valid loss 1.22 | valid ppl 3.37 | learning rate 20.0000
| end of split 5 / 99 | epoch 1 | time: 129.29s | valid loss 1.18 | valid ppl 3.26 | learning rate 20.0000
| end of split 6 / 99 | epoch 1 | time: 129.61s | valid loss 1.16 | valid ppl 3.20 | learning rate 20.0000
| end of split 7 / 99 | epoch 1 | time: 127.88s | valid loss 1.13 | valid ppl 3.10 | learning rate 20.0000
| end of split 8 / 99 | epoch 1 | time: 128.66s | valid loss 1.11 | valid ppl 3.05 | learning rate 20.0000
| end of split 9 / 99 | epoch 1 | time: 127.56s | valid loss 1.11 | valid ppl 3.05 | learning rate 20.0000
| end of split 10 / 99 | epoch 1 | time: 129.15s | valid loss 1.08 | valid ppl 2.96 | learning rate 20.0000
| end of split 11 / 99 | epoch 1 | time: 128.45s | valid loss 1.13 | valid ppl 3.10 | learning rate 20.0000
| end of split 12 / 99 | epoch 1 | time: 128.99s | valid loss 1.07 | valid ppl 2.93 | learning rate 20.0000
| end of split 13 / 99 | epoch 1 | time: 129.57s | valid loss 1.06 | valid ppl 2.87 | learning rate 20.0000
| end of split 14 / 99 | epoch 1 | time: 128.62s | valid loss 1.05 | valid ppl 2.85 | learning rate 20.0000

Will merge the PR!
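For reference, the speed-ups implied by the timings in this thread (mean times eyeballed from the logs, so the exact figures are approximate):

```python
# Approximate mean seconds per split, read off the logs above.
small_fp32, small_amp = 53.6, 35.6    # 512-state LM
large_fp32, large_amp = 404.3, 128.7  # 2048-state LM

print(f"small LM (512 states): {small_fp32 / small_amp:.2f}x")   # ~1.5x
print(f"large LM (2048 states): {large_fp32 / large_amp:.2f}x")  # ~3.1x
```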
👍
@vinhngx in […]
Good point, I'll put in a PR. |
GH-934: renamed amp variable
In response to #324, this PR makes use of APEX (https://github.com/NVIDIA/apex) to provide automatic mixed precision training to Flair.
Automatic mixed precision training makes use of both FP32 and FP16 precisions where appropriate. FP16 operations can leverage the Tensor cores on NVIDIA GPUs (Volta, Turing or newer architectures) for much improved throughput.
Automatic mixed precision training can be enabled via an appropriate flag passed to the trainer:
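A minimal sketch of what this might look like from user code. The `use_amp` parameter name follows this PR and may differ across flair versions; `tagger` and `corpus` are assumed to be an already-constructed flair model and corpus, and a CUDA GPU with apex installed is required:

```python
from flair.trainers import ModelTrainer

trainer = ModelTrainer(tagger, corpus)
trainer.train(
    "resources/taggers/example",
    learning_rate=0.1,
    mini_batch_size=32,
    use_amp=True,  # enables APEX automatic mixed precision
)
```

Under the hood, apex wraps the model and optimizer (via `amp.initialize`) and scales the loss before the backward pass, so no other changes to user code are needed.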
How mixed precision works
Mixed precision is the use of both float16 and float32 data types when training a model.
Performing arithmetic operations in float16 takes advantage of the performance gains of using specialized processing units such as the Tensor cores on NVIDIA GPUs. Due to the smaller representable range of float16, performing the entire training with float16 data type can result in underflow of the gradients, leading to convergence or model quality issues.
However, performing only select arithmetic operations in float16 results in performance gains when using compatible hardware accelerators, decreasing training time and reducing memory usage, typically without sacrificing model performance.
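The gradient-underflow problem and the standard loss-scaling fix can be sketched with plain numpy float16 (a toy illustration of the mechanism, not flair/apex code):

```python
import numpy as np

# A tiny gradient value such as a backward pass might produce.
g = 1e-8

# Stored directly in float16 it underflows: the smallest float16
# subnormal is about 6e-8, so g rounds to zero and the update is lost.
assert np.float16(g) == 0.0

# Loss scaling: multiply the loss (and hence every gradient) by a
# factor S before the backward pass, keep the scaled float16 value,
# then unscale in float32 before the optimizer step.
S = 1024.0
g_fp16 = np.float16(g * S)                 # 1.024e-5, representable in float16
g_recovered = float(np.float32(g_fp16) / S)
assert abs(g_recovered - g) / g < 0.01     # gradient survives to ~0.1%
```

Dynamic loss scaling, as used by apex, adjusts S automatically: it grows S when gradients are healthy and backs off when an overflow is detected.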
To learn more about mixed precision and how it works:
Overview of Automatic Mixed Precision for Deep Learning
NVIDIA Mixed Precision Training Documentation
NVIDIA Deep Learning Performance Guide