Skip to content

Commit

Permalink
Merge pull request #2895 from helpmefindaname/bf/resume_epochs_warning
Browse files Browse the repository at this point in the history
warn if resuming with too low max_epochs & ' additional_epochs'  parameter
  • Loading branch information
alanakbik authored Aug 10, 2022
2 parents 7d8cac1 + 1b7a002 commit afac9d2
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
9 changes: 9 additions & 0 deletions flair/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ def train(
training_parameters[parameter] = local_variables[parameter]
model_card["training_parameters"] = training_parameters

if epoch >= max_epochs:
log.warning(f"Starting at epoch {epoch + 1}/{max_epochs}. No training will be done.")

# add model card to model
self.model.model_card = model_card
assert self.corpus.train
Expand Down Expand Up @@ -859,6 +862,7 @@ def train(
def resume(
self,
model: Model,
additional_epochs: Optional[int] = None,
**trainer_args,
):

Expand All @@ -879,6 +883,11 @@ def resume(
kwargs = args_used_to_train_model["kwargs"]
del args_used_to_train_model["kwargs"]

if additional_epochs is not None:
args_used_to_train_model["max_epochs"] = (
args_used_to_train_model.pop("epoch", kwargs.pop("epoch", 0)) + additional_epochs
)

# resume training with these parameters
self.train(**args_used_to_train_model, **kwargs)

Expand Down
33 changes: 33 additions & 0 deletions tests/test_sequence_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,39 @@ def test_train_resume_tagger(results_base_path, tasks_base_path):
del trainer


@pytest.mark.integration
def test_train_resume_tagger_with_additional_epochs(results_base_path, tasks_base_path):

corpus_1 = flair.datasets.ColumnCorpus(data_folder=tasks_base_path / "fashion", column_format={0: "text", 3: "ner"})
corpus_2 = flair.datasets.NER_GERMAN_GERMEVAL(base_path=tasks_base_path).downsample(0.1)

corpus = MultiCorpus([corpus_1, corpus_2])
tag_dictionary = corpus.make_label_dictionary("ner", add_unk=False)

model: SequenceTagger = SequenceTagger(
hidden_size=64,
embeddings=turian_embeddings,
tag_dictionary=tag_dictionary,
tag_type="ner",
use_crf=False,
)

# train model for 2 epochs
trainer = ModelTrainer(model, corpus)
trainer.train(results_base_path, max_epochs=1, shuffle=False, checkpoint=True)

del model

# load the checkpoint model and train until epoch 4
checkpoint_model = SequenceTagger.load(results_base_path / "checkpoint.pt")
trainer.resume(model=checkpoint_model, additional_epochs=1)

assert checkpoint_model.model_card["training_parameters"]["max_epochs"] == 2

# clean up results directory
del trainer


@pytest.mark.integration
def test_find_learning_rate(results_base_path, tasks_base_path):
corpus = flair.datasets.ColumnCorpus(data_folder=tasks_base_path / "fashion", column_format={0: "text", 3: "ner"})
Expand Down

0 comments on commit afac9d2

Please sign in to comment.