Skip to content

Commit

Permalink
GH-440 add new evaluation metrics
Browse files Browse the repository at this point in the history
add mean squared error as default for regression
  • Loading branch information
heukirne committed Mar 6, 2019
1 parent 385e876 commit 3129192
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
40 changes: 40 additions & 0 deletions flair/trainers/trainer_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,46 @@

class RegressorTrainer(flair.trainers.ModelTrainer):

def train(self,
base_path: Union[Path, str],
evaluation_metric: EvaluationMetric = EvaluationMetric.MEAN_SQUARED_ERROR,
learning_rate: float = 0.1,
mini_batch_size: int = 32,
eval_mini_batch_size: int = None,
max_epochs: int = 100,
anneal_factor: float = 0.5,
patience: int = 3,
anneal_against_train_loss: bool = True,
train_with_dev: bool = False,
monitor_train: bool = False,
embeddings_in_memory: bool = True,
checkpoint: bool = False,
save_final_model: bool = True,
anneal_with_restarts: bool = False,
test_mode: bool = False,
param_selection_mode: bool = False,
**kwargs
) -> dict:

return super(RegressorTrainer, self).train(
base_path=base_path,
evaluation_metric=evaluation_metric,
learning_rate=learning_rate,
mini_batch_size=mini_batch_size,
eval_mini_batch_size=eval_mini_batch_size,
max_epochs=max_epochs,
anneal_factor=anneal_factor,
patience=patience,
anneal_against_train_loss=anneal_against_train_loss,
train_with_dev=train_with_dev,
monitor_train=monitor_train,
embeddings_in_memory=embeddings_in_memory,
checkpoint=checkpoint,
save_final_model=save_final_model,
anneal_with_restarts=anneal_with_restarts,
test_mode=test_mode,
param_selection_mode=param_selection_mode)

@staticmethod
def _evaluate_text_regressor(model: flair.nn.Model,
sentences: List[Sentence],
Expand Down
1 change: 1 addition & 0 deletions flair/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ class EvaluationMetric(Enum):
MICRO_F1_SCORE = 'micro-average f1-score'
MACRO_ACCURACY = 'macro-average accuracy'
MACRO_F1_SCORE = 'macro-average f1-score'
MEAN_SQUARED_ERROR = 'mean squared error'


class WeightExtractor(object):
Expand Down

0 comments on commit 3129192

Please sign in to comment.