Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flair Regression #564

Merged
merged 14 commits into from
Apr 16, 2019
4 changes: 4 additions & 0 deletions flair/data_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ class NLPTask(Enum):
TREC_6 = "trec-6"
TREC_50 = "trec-50"

# text regression format
REGRESSION = 'regression'


class NLPTaskDataFetcher:
@staticmethod
Expand Down Expand Up @@ -210,6 +213,7 @@ def load_corpus(
NLPTask.AG_NEWS.value,
NLPTask.TREC_6.value,
NLPTask.TREC_50.value,
NLPTask.REGRESSION.value,
]:
use_tokenizer: bool = False if task in [
NLPTask.TREC_6.value,
Expand Down
59 changes: 59 additions & 0 deletions flair/models/text_regression_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import flair
import torch
import torch.nn as nn
from typing import List, Union
from flair.training_utils import clear_embeddings
from flair.data import Sentence, Label
import logging

log = logging.getLogger('flair')

class TextRegressor(flair.models.TextClassifier):

def __init__(self,
document_embeddings: flair.embeddings.DocumentEmbeddings,
label_dictionary: flair.data.Dictionary,
multi_label: bool):

super(TextRegressor, self).__init__(document_embeddings=document_embeddings, label_dictionary=flair.data.Dictionary(), multi_label=multi_label)

log.info('Using REGRESSION - experimental')

self.loss_function = nn.MSELoss()

def _labels_to_indices(self, sentences: List[Sentence]):
indices = [
torch.FloatTensor([float(label.value) for label in sentence.labels])
for sentence in sentences
]

vec = torch.cat(indices, 0)
if torch.cuda.is_available():
vec = vec.cuda()

return vec

def forward_labels_and_loss(self, sentences: Union[Sentence, List[Sentence]]) -> (List[List[float]], torch.tensor):
scores = self.forward(sentences)
loss = self._calculate_loss(scores, sentences)
return scores, loss

def predict(self, sentences: Union[Sentence, List[Sentence]], mini_batch_size: int = 32) -> List[Sentence]:

with torch.no_grad():
if type(sentences) is Sentence:
sentences = [sentences]

filtered_sentences = self._filter_empty_sentences(sentences)

batches = [filtered_sentences[x:x + mini_batch_size] for x in range(0, len(filtered_sentences), mini_batch_size)]

for batch in batches:
scores = self.forward(batch)

for (sentence, score) in zip(batch, scores.tolist()):
sentence.labels = [Label(value=str(score[0]))]

clear_embeddings(batch)

return sentences
147 changes: 147 additions & 0 deletions flair/trainers/trainer_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import flair
import torch
import torch.nn as nn

from typing import List, Union
from flair.training_utils import MetricRegression, EvaluationMetric, clear_embeddings, log_line
from flair.models.text_regression_model import TextRegressor
from flair.data import Sentence, Label
from pathlib import Path
import logging

log = logging.getLogger('flair')

class RegressorTrainer(flair.trainers.ModelTrainer):

def train(self,
base_path: Union[Path, str],
evaluation_metric: EvaluationMetric = EvaluationMetric.MEAN_SQUARED_ERROR,
learning_rate: float = 0.1,
mini_batch_size: int = 32,
eval_mini_batch_size: int = None,
max_epochs: int = 100,
anneal_factor: float = 0.5,
patience: int = 3,
anneal_against_train_loss: bool = True,
train_with_dev: bool = False,
monitor_train: bool = False,
embeddings_in_memory: bool = True,
checkpoint: bool = False,
save_final_model: bool = True,
anneal_with_restarts: bool = False,
test_mode: bool = False,
param_selection_mode: bool = False,
**kwargs
) -> dict:

return super(RegressorTrainer, self).train(
base_path=base_path,
evaluation_metric=evaluation_metric,
learning_rate=learning_rate,
mini_batch_size=mini_batch_size,
eval_mini_batch_size=eval_mini_batch_size,
max_epochs=max_epochs,
anneal_factor=anneal_factor,
patience=patience,
anneal_against_train_loss=anneal_against_train_loss,
train_with_dev=train_with_dev,
monitor_train=monitor_train,
embeddings_in_memory=embeddings_in_memory,
checkpoint=checkpoint,
save_final_model=save_final_model,
anneal_with_restarts=anneal_with_restarts,
test_mode=test_mode,
param_selection_mode=param_selection_mode)

@staticmethod
def _evaluate_text_regressor(model: flair.nn.Model,
sentences: List[Sentence],
eval_mini_batch_size: int = 32,
embeddings_in_memory: bool = False,
out_path: Path = None) -> (dict, float):

with torch.no_grad():
eval_loss = 0

batches = [sentences[x:x + eval_mini_batch_size] for x in
range(0, len(sentences), eval_mini_batch_size)]

metric = MetricRegression('Evaluation')

lines: List[str] = []
for batch in batches:

scores, loss = model.forward_labels_and_loss(batch)

true_values = []
for sentence in batch:
for label in sentence.labels:
true_values.append(float(label.value))

results = []
for score in scores:
if type(score[0]) is Label:
results.append(float(score[0].score))
else:
results.append(float(score[0]))

clear_embeddings(batch, also_clear_word_embeddings=not embeddings_in_memory)

eval_loss += loss

metric.true.extend(true_values)
metric.pred.extend(results)

eval_loss /= len(sentences)

##TODO: not saving lines yet
if out_path is not None:
with open(out_path, "w", encoding='utf-8') as outfile:
outfile.write(''.join(lines))

return metric, eval_loss


def _calculate_evaluation_results_for(self,
dataset_name: str,
dataset: List[Sentence],
evaluation_metric: EvaluationMetric,
embeddings_in_memory: bool,
eval_mini_batch_size: int,
out_path: Path = None):

metric, loss = RegressorTrainer._evaluate_text_regressor(self.model, dataset, eval_mini_batch_size=eval_mini_batch_size,
embeddings_in_memory=embeddings_in_memory, out_path=out_path)

mse = metric.mean_squared_error()
mae = metric.mean_absolute_error()

log.info(f'{dataset_name:<5}: loss {loss:.8f} - mse {mse:.4f} - mae {mae:.4f}')

return metric, loss

def final_test(self,
base_path: Path,
embeddings_in_memory: bool,
evaluation_metric: EvaluationMetric,
eval_mini_batch_size: int):

log_line(log)
log.info('Testing using best model ...')

self.model.eval()

if (base_path / 'best-model.pt').exists():
self.model = TextRegressor.load_from_file(base_path / 'best-model.pt')

test_metric, test_loss = self._evaluate_text_regressor(self.model, self.corpus.test, eval_mini_batch_size=eval_mini_batch_size,
embeddings_in_memory=embeddings_in_memory)

log.info(f'AVG: mse: {test_metric.mean_squared_error():.4f} - '
f'mae: {test_metric.mean_absolute_error():.4f} - '
f'pearson: {test_metric.pearsonr():.4f} - '
f'spearman: {test_metric.spearmanr():.4f}')

log_line(log)

return test_metric.mean_squared_error()
64 changes: 60 additions & 4 deletions flair/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from typing import List
from flair.data import Dictionary, Sentence
from functools import reduce
from sklearn.metrics import mean_squared_error, mean_absolute_error
from scipy.stats import pearsonr, spearmanr


class Metric(object):
Expand Down Expand Up @@ -171,11 +173,65 @@ def __str__(self):
return "\n".join(all_lines)


class MetricRegression(object):

def __init__(self, name):
self.name = name

self.true = []
self.pred = []

def mean_squared_error(self):
return mean_squared_error(self.true, self.pred)

def mean_absolute_error(self):
return mean_absolute_error(self.true, self.pred)

def pearsonr(self):
return pearsonr(self.true, self.pred)[0]

def spearmanr(self):
return spearmanr(self.true, self.pred)[0]

## dummy return to fulfill trainer.train() needs
def micro_avg_f_score(self):
return self.mean_squared_error()

def to_tsv(self):
return '{}\t{}\t{}\t{}'.format(
self.mean_squared_error(),
self.mean_absolute_error(),
self.pearsonr(),
self.spearmanr(),
)

@staticmethod
def tsv_header(prefix=None):
if prefix:
return '{0}_MEAN_SQUARED_ERROR\t{0}_MEAN_ABSOLUTE_ERROR\t{0}_PEARSON\t{0}_SPEARMAN'.format(
prefix)

return 'MEAN_SQUARED_ERROR\tMEAN_ABSOLUTE_ERROR\tPEARSON\tSPEARMAN'

@staticmethod
def to_empty_tsv():
return '\t_\t_\t_\t_'

def __str__(self):
line = 'mean squared error: {0:.4f} - mean absolute error: {1:.4f} - pearson: {2:.4f} - spearman: {3:.4f}'.format(
self.mean_squared_error(),
self.mean_absolute_error(),
self.pearsonr(),
self.spearmanr())
return line


class EvaluationMetric(Enum):
MICRO_ACCURACY = "micro-average accuracy"
MICRO_F1_SCORE = "micro-average f1-score"
MACRO_ACCURACY = "macro-average accuracy"
MACRO_F1_SCORE = "macro-average f1-score"
MICRO_ACCURACY = 'micro-average accuracy'
MICRO_F1_SCORE = 'micro-average f1-score'
MACRO_ACCURACY = 'macro-average accuracy'
MACRO_F1_SCORE = 'macro-average f1-score'
MEAN_SQUARED_ERROR = 'mean squared error'


class WeightExtractor(object):
Expand Down
14 changes: 14 additions & 0 deletions tests/resources/tasks/regression/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
## REGRESSION

Data is taken from [here](http://saifmohammad.com/WebPages/EmotionIntensity-SharedTask.html).

The dataset contains a collection of tweets with joy intensity value.
We took the joy dataset and converted it to the expected format of our data fetcher:
```
__label__<joy_intensity> <text>
```

#### Publication About the Dataset

* Emotion Intensities in Tweets. Saif M. Mohammad and Felipe Bravo-Marquez. In Proceedings of the sixth joint conference on lexical and computational semantics (*Sem), August 2017, Vancouver, Canada.
* WASSA-2017 Shared Task on Emotion Intensity. Saif M. Mohammad and Felipe Bravo-Marquez. In Proceedings of the EMNLP 2017 Workshop on Computational Approaches to Subjectivity, Sentiment, and Social Media (WASSA), September 2017, Copenhagen, Denmark.
Loading