Proper way to do contrastive learning with DDP & PT-Lightning #14390
-
I want to use DDP and experiment with contrastive losses. Since DDP processes each subset of the data independently, negative examples that could increase the contrastive power cannot be taken into account under automatic optimization. Suppose I am training with 2 GPUs and each GPU sees a mini-batch of size 4. This loses the signal between (x1, x5), (x1, x6), (x1, x7), etc., since x1-x4 are on GPU 1 and x5-x8 are on GPU 2. What is the recommended method to account for this in PT-Lightning? After computing the loss, I'm unclear on the mechanics of distributing a loss computed on rank 0 back to all of the GPUs so that the gradients are synced. Is this something that happens automatically under the hood, or do I need to do something with manual optimization?
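For concreteness, a minimal sketch of the naive per-rank setup described above (`self.encoder` and `self.contrastive_loss` are placeholder names, not from any particular codebase):

```python
def training_step(self, batch, batch_idx):
    # Under DDP, `batch` is only this rank's shard: x1-x4 on GPU 1,
    # x5-x8 on GPU 2.
    embeddings = self.encoder(batch)  # shape (4, dim) per rank
    # Negatives come only from this rank's 4 samples, so cross-rank
    # pairs like (x1, x5) or (x1, x6) never contribute to the loss.
    return self.contrastive_loss(embeddings)
```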
Replies: 5 comments 7 replies
-
@awaelchli @williamFalcon This is, in fact, the same tutorial I wanted for syncing stuff using `all_gather`.
-
@kkarrancsu You are definitely on the right track here. In the LightningModule, you have this method for gathering a tensor from all processes:

```python
tensors_from_all = self.all_gather(my_tensor)
```

What you want is to back-propagate through this all_gather function, and this is possible if you set

```python
tensors_from_all = self.all_gather(my_tensor, sync_grads=True)
```

In your case, your `training_step` method could look something like this:

```python
def training_step(self, batch, batch_idx):
    outputs = self(batch)
    ...
    all_outputs = self.all_gather(outputs, sync_grads=True)
    loss = contrastive_loss_fn(all_outputs, ...)
    return loss
```
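To make this concrete end to end, here is a minimal self-contained sketch in the same spirit: a SimCLR-style module computing an InfoNCE-style loss over the gathered embeddings. The toy encoder, projection sizes, and loss details are illustrative assumptions, not part of the answer above:

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class ContrastiveModule(pl.LightningModule):
    """Toy SimCLR-style module; encoder and loss are illustrative."""

    def __init__(self, dim: int = 128, temperature: float = 0.1):
        super().__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(784, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, dim),
        )
        self.temperature = temperature

    def training_step(self, batch, batch_idx):
        view1, view2 = batch  # two augmented views of the same samples
        z1 = F.normalize(self.encoder(view1), dim=-1)
        z2 = F.normalize(self.encoder(view2), dim=-1)

        # Gather from every rank; sync_grads=True keeps the gathered
        # tensors attached to the autograd graph on each rank.
        z1_all = self.all_gather(z1, sync_grads=True)
        z2_all = self.all_gather(z2, sync_grads=True)
        if z1_all.dim() == 3:  # (world_size, B, dim) under DDP
            z1_all = z1_all.flatten(0, 1)
            z2_all = z2_all.flatten(0, 1)

        # InfoNCE-style loss: row i of z1_all matches row i of z2_all;
        # every other row serves as a negative.
        logits = z1_all @ z2_all.t() / self.temperature
        targets = torch.arange(logits.size(0), device=logits.device)
        loss = F.cross_entropy(logits, targets)
        self.log("train_loss", loss, sync_dist=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```

Because every rank ends up with the same gathered tensors and therefore computes the identical full-batch loss, nothing has to be broadcast from rank 0; DDP's backward pass averages the parameter gradients across ranks as usual.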
-
Follow-up question: to do contrastive learning with DP, I believe that we need to implement the loss computation in the `training_step_end` hook. Does this mean we should have flags in our `LightningModule` that change the behavior of `training_step` depending on the strategy? Here is a stub:
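A minimal sketch of such a stub, assuming an illustrative `use_dp` flag (set to match the chosen Trainer strategy) and reusing `contrastive_loss_fn` from the answer above as a placeholder; `training_step_end` is the Lightning 1.x hook that receives the merged per-GPU outputs under DP:

```python
import pytorch_lightning as pl


class ContrastiveDPModule(pl.LightningModule):
    def __init__(self, use_dp: bool = False):
        super().__init__()
        self.use_dp = use_dp  # illustrative flag, set from the strategy

    def training_step(self, batch, batch_idx):
        outputs = self(batch)
        if self.use_dp:
            # Under DP, return the per-GPU outputs and defer the loss to
            # training_step_end, which sees the outputs from all GPUs.
            return {"outputs": outputs}
        # Under DDP, gather across ranks here instead.
        all_outputs = self.all_gather(outputs, sync_grads=True)
        return contrastive_loss_fn(all_outputs)

    def training_step_end(self, step_output):
        if self.use_dp:
            # Under DP, step_output["outputs"] holds the outputs of every
            # GPU's training_step, merged along the batch dimension.
            return contrastive_loss_fn(step_output["outputs"])
        return step_output  # DDP path: step_output is already the loss
```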
Two follow-up questions:
-
Hi! Thank you for the answer @awaelchli. I also want to use contrastive learning with DDP, and have an additional question on the code you mentioned here. By this I mean: should we have an additional condition in the example, such as computing the loss only on rank zero?

In such a case, would using sync_grads=True mean the computed loss would not need to be distributed/broadcast back to all workers, and all the optimizers would have access to the corresponding gradients automatically for their step? When computing the loss only on rank zero, I am unsure how the other workers would get their gradients. Thank you in advance!
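For what it's worth, a sketch of the mechanics as I understand them, reusing the pattern from the answer above (`contrastive_loss_fn` remains a placeholder): every rank receives the full gathered tensor from `all_gather`, so each rank can compute the identical loss locally and no rank-zero condition or broadcast is needed.

```python
def training_step(self, batch, batch_idx):
    outputs = self(batch)
    # With sync_grads=True, every rank holds the full gathered tensor
    # and computes the same loss; no `if self.trainer.global_rank == 0:`
    # gate and no broadcast of the loss are required.
    all_outputs = self.all_gather(outputs, sync_grads=True)
    loss = contrastive_loss_fn(all_outputs)
    # During backward, DDP averages the parameter gradients across
    # ranks, so every optimizer step sees synchronized gradients.
    return loss
```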
-
Excellent! However, how about doing contrastive learning with