Commit

Merge pull request #2888 from flairNLP/learning_rate_factor
Convenience method for learning rate factor
alanakbik authored Aug 6, 2022
2 parents b722172 + 4522185 commit a927b30
Showing 1 changed file with 20 additions and 0 deletions.
flair/trainers/trainer.py (20 additions, 0 deletions)
```diff
@@ -893,9 +893,29 @@ def fine_tune(
         mini_batch_size: int = 4,
         embeddings_storage_mode: str = "none",
         use_final_model_for_eval: bool = True,
+        decoder_lr_factor: float = 1.0,
         **trainer_args,
     ):
 
+        # If set, add a factor to the learning rate of all parameters with 'decoder' in name
+        if decoder_lr_factor != 1.0:
+            optimizer = optimizer(
+                [
+                    {
+                        "params": [param for name, param in self.model.named_parameters() if "decoder" in name],
+                        "lr": learning_rate * decoder_lr_factor,
+                    },
+                    {
+                        "params": [param for name, param in self.model.named_parameters() if "decoder" not in name],
+                        "lr": learning_rate,
+                    },
+                ]
+            )
+            log.info(
+                f"Increasing learning rate to {learning_rate * decoder_lr_factor} for the following "
+                f"parameters: {[name for name, param in self.model.named_parameters() if 'decoder' in name]}"
+            )
+
         return self.train(
             base_path=base_path,
             learning_rate=learning_rate,
```
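For illustration, here is a usage sketch of the new argument; it is not part of the commit. The dataset, embedding model, and output path are assumptions, and the example relies on flair's default classification heads naming their final linear layer `decoder`, which the substring match above picks up.

```python
# Hypothetical usage sketch for decoder_lr_factor; the corpus, embeddings,
# and base_path are assumptions for illustration, not taken from the commit.
from flair.datasets import TREC_6
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

corpus = TREC_6()
label_dict = corpus.make_label_dictionary(label_type="question_class")

embeddings = TransformerDocumentEmbeddings("distilbert-base-uncased", fine_tune=True)
classifier = TextClassifier(embeddings, label_dictionary=label_dict, label_type="question_class")

trainer = ModelTrainer(classifier, corpus)

# The classification head (named "decoder" in flair) trains at 10x the
# transformer's learning rate: 5e-4 instead of 5e-5.
trainer.fine_tune(
    "resources/classifiers/trec",  # hypothetical base_path
    learning_rate=5e-5,
    decoder_lr_factor=10.0,
)
```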

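The patch builds on standard PyTorch behavior: an optimizer can be constructed from a list of parameter groups, each carrying its own "lr" that overrides the default. A minimal standalone sketch (the toy two-layer module is an assumption for illustration):

```python
# Standalone sketch of PyTorch per-parameter-group learning rates; the toy
# module is an assumption, but the optimizer API is standard torch.optim.
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(8, 8),  # stands in for the encoder / embeddings
    torch.nn.Linear(8, 2),  # stands in for the "decoder" head
)

optimizer = torch.optim.AdamW(
    [
        {"params": model[0].parameters(), "lr": 5e-5},          # base learning rate
        {"params": model[1].parameters(), "lr": 5e-5 * 10.0},   # decoder_lr_factor=10.0
    ]
)

for group in optimizer.param_groups:
    print(group["lr"])  # 5e-05, then 0.0005
```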