Skip to content

Commit

Permalink
GH-440 add correlation metrics for regression
Browse files Browse the repository at this point in the history
  • Loading branch information
heukirne committed Mar 6, 2019
1 parent 58724c4 commit 385e876
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion flair/trainers/trainer_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from typing import List, Union
from sklearn.metrics import mean_squared_error, mean_absolute_error
from scipy.stats import pearsonr, spearmanr
from flair.training_utils import Metric, EvaluationMetric, clear_embeddings, log_line
from flair.models.text_regression_model import TextRegressor
from flair.data import Sentence, Label
Expand Down Expand Up @@ -50,6 +51,9 @@ def _evaluate_text_regressor(model: flair.nn.Model,

metric['mae'] = mean_absolute_error(results, true_values)
metric['mse'] = mean_squared_error(results, true_values)
metric['pearson'] = pearsonr(results, true_values)[0]
metric['spearman'] = spearmanr(results, true_values)[0]


eval_loss /= len(sentences)

Expand Down Expand Up @@ -93,8 +97,10 @@ def final_test(self,

mae = test_metric['mae']
mse = test_metric['mse']
pearson = test_metric['pearson']
spearman = test_metric['spearman']

log.info(f'AVG: mse {mse} - mae {mae}')
log.info(f'AVG: mse {mse} - mae {mae} - pearson {pearson} - spearman {spearman}')

log_line(log)

Expand Down

0 comments on commit 385e876

Please sign in to comment.