From ed2679757e4150cd73433e393c8565fb04efaef0 Mon Sep 17 00:00:00 2001 From: Josip Krapac Date: Tue, 10 Aug 2021 13:09:00 +0200 Subject: [PATCH 1/5] First commit for PR that adds per-label thresholds --- flair/nn/model.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/flair/nn/model.py b/flair/nn/model.py index dc7bd1c55..ab7ee2e86 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -361,7 +361,7 @@ def __init__(self, # set up multi-label logic self.multi_label = multi_label - self.multi_label_threshold = multi_label_threshold + self.multi_label_threshold = {'default': multi_label_threshold} if type(multi_label_threshold) is float else multi_label_threshold # loss weights and loss function self.weight_dict = loss_weights @@ -476,17 +476,17 @@ def predict( if len(label_candidates) > 0: if self.multi_label or multi_class_prob: - sigmoided = torch.sigmoid(scores) - s_idx = 0 - for sentence, label in zip(sentences, label_candidates): - for idx in range(sigmoided.size(1)): - if sigmoided[s_idx, idx] > self.multi_label_threshold or multi_class_prob: - label_value = self.label_dictionary.get_item_for_index(idx) - if label_value == 'O': continue - label.set_value(value=label_value, score=sigmoided[s_idx, idx].item()) + sigmoided = torch.sigmoid(scores) # size: (n_sentences, n_classes) + n_labels = sigmoided.size(1) + for s_idx, (sentence, label) in enumerate(zip(sentences, label_candidates)): + for l_idx in range(n_labels): + label_value = self.label_dictionary.get_item_for_index(l_idx) + if label_value == 'O': continue + label_threshold = self.multi_label_threshold['default'] if label_value not in self.multi_label_threshold else self.multi_label_threshold[label_value] + label_score = sigmoided[s_idx, l_idx].item() + if label_score > label_threshold or multi_class_prob: + label.set_value(value=label_value, score=label_score) sentence.add_complex_label(label_name, copy.deepcopy(label)) - s_idx += 1 - else: softmax = torch.nn.functional.softmax(scores, dim=-1) conf, idx = torch.max(softmax, dim=-1) @@ -526,9 +526,11 @@ def _get_multi_label(self, label_scores) -> List[Label]: results = list(map(lambda x: sigmoid(x), label_scores)) for idx, conf in enumerate(results): - if conf > self.multi_label_threshold: - label = self.label_dictionary.get_item_for_index(idx) - labels.append(Label(label, conf.item())) + label_value = self.label_dictionary.get_item_for_index(idx) + label_threshold = self.multi_label_threshold['default'] if label_value not in self.multi_label_threshold else self.multi_label_threshold[label_value] + label_score = conf.item() + if label_score > label_threshold: + labels.append(Label(label_value, label_score)) return labels From 1e0f783ec38fc58d591b106647d859ba7ae390b7 Mon Sep 17 00:00:00 2001 From: Josip Krapac Date: Tue, 10 Aug 2021 15:27:05 +0200 Subject: [PATCH 2/5] Refactored so it's robust to setting the threshold from outside of the constructor --- flair/nn/model.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/flair/nn/model.py b/flair/nn/model.py index ab7ee2e86..6026cfcd8 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -361,7 +361,7 @@ def __init__(self, # set up multi-label logic self.multi_label = multi_label - self.multi_label_threshold = {'default': multi_label_threshold} if type(multi_label_threshold) is float else multi_label_threshold + self.multi_label_threshold = multi_label_threshold # loss weights and loss function self.weight_dict = loss_weights @@ -400,6 +400,7 @@ def _calculate_loss(self, scores, labels): return self.loss_function(scores, labels), len(labels) + def predict( self, sentences: Union[List[Sentence], Sentence], @@ -482,7 +483,7 @@ def predict( for l_idx in range(n_labels): label_value = self.label_dictionary.get_item_for_index(l_idx) if label_value == 'O': continue - label_threshold = self.multi_label_threshold['default'] if label_value not in self.multi_label_threshold else self.multi_label_threshold[label_value] + label_threshold = self._get_label_threshold(label_value) label_score = sigmoided[s_idx, l_idx].item() if label_score > label_threshold or multi_class_prob: label.set_value(value=label_value, score=label_score) @@ -503,6 +504,14 @@ def predict( if return_loss: return overall_loss, label_count + def _get_label_threshold(self, label_value): + if type(self.multi_label_theshold) is not map: + label_threshold = self.multi_label_threshold + else: + label_threshold = self.multi_label_threshold['default'] if label_value not in self.multi_label_threshold else self.multi_label_threshold[label_value] + + return label_threshold + def _obtain_labels( self, scores: List[List[float]], predict_prob: bool = False ) -> List[List[Label]]: @@ -527,7 +536,7 @@ def _get_multi_label(self, label_scores) -> List[Label]: results = list(map(lambda x: sigmoid(x), label_scores)) for idx, conf in enumerate(results): label_value = self.label_dictionary.get_item_for_index(idx) - label_threshold = self.multi_label_threshold['default'] if label_value not in self.multi_label_threshold else self.multi_label_threshold[label_value] + label_threshold = self._get_label_threshold(label_value) label_score = conf.item() if label_score > label_threshold: labels.append(Label(label_value, label_score)) From dd5f001ffc5260379daf786b44285a8fcdc92476 Mon Sep 17 00:00:00 2001 From: Josip Krapac Date: Tue, 10 Aug 2021 17:15:20 +0200 Subject: [PATCH 3/5] Added property and property getter for multi_label_threshold --- flair/nn/model.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/flair/nn/model.py b/flair/nn/model.py index 6026cfcd8..ec9ad082a 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -381,6 +381,20 @@ def __init__(self, else: self.loss_function = torch.nn.CrossEntropyLoss(weight=self.loss_weights) + @property + def multi_label_threshold(self): + return self._multi_label_threshold + + @setter.multi_label_threshold + def multi_label_threshold(self, x): + if type(x) is dict: + if 'default' in x: + self._multi_label_threshold = x + else: + raise Exception('multi_label_threshold dict should have a "default" key') + else: + self._multi_label_threshold = {'default': x} + def forward_loss(self, sentences: Union[List[DataPoint], DataPoint]) -> torch.tensor: scores, labels = self.forward_pass(sentences) return self._calculate_loss(scores, labels) @@ -505,10 +519,9 @@ def predict( return overall_loss, label_count def _get_label_threshold(self, label_value): - if type(self.multi_label_theshold) is not map: - label_threshold = self.multi_label_threshold - else: - label_threshold = self.multi_label_threshold['default'] if label_value not in self.multi_label_threshold else self.multi_label_threshold[label_value] + label_threshold = self.multi_label_threshold['default'] + if label_value in self.multi_label_threshold: + label_threshold = self.multi_label_threshold[label_value] return label_threshold From 45e5d8ab9e426e07be3042c2fa5d1b74cd7c5fb6 Mon Sep 17 00:00:00 2001 From: Josip Krapac Date: Tue, 10 Aug 2021 17:26:31 +0200 Subject: [PATCH 4/5] Fixed typo --- flair/nn/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flair/nn/model.py b/flair/nn/model.py index ec9ad082a..5ca78d08e 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -385,7 +385,7 @@ def __init__(self, def multi_label_threshold(self): return self._multi_label_threshold - @setter.multi_label_threshold + @multi_label_threshold.setter def multi_label_threshold(self, x): if type(x) is dict: if 'default' in x: From e93de1ea674a71d4de2e22f297c9dffbf531218a Mon Sep 17 00:00:00 2001 From: Josip Krapac Date: Wed, 11 Aug 2021 09:59:19 +0200 Subject: [PATCH 5/5] Dummy commit to trigger push to PR --- flair/nn/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flair/nn/model.py b/flair/nn/model.py index 5ca78d08e..1c6012412 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -385,8 +385,8 @@ def __init__(self, def multi_label_threshold(self): return self._multi_label_threshold - @multi_label_threshold.setter - def multi_label_threshold(self, x): + @multi_label_threshold.setter + def multi_label_threshold(self, x): # setter method if type(x) is dict: if 'default' in x: self._multi_label_threshold = x