-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add per-label thresholds #2366
Add per-label thresholds #2366
Conversation
flair/nn/model.py
Outdated
@@ -361,7 +361,7 @@ def __init__(self, | |||
|
|||
# set up multi-label logic | |||
self.multi_label = multi_label | |||
self.multi_label_threshold = multi_label_threshold | |||
self.multi_label_threshold = {'default': multi_label_threshold} if type(multi_label_threshold) is float else multi_label_threshold |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This could be more robust, i.e. this particular line would break if user sets an integer as multi_label_threshold
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, doing this only in constructor means if the multi_label_threshold
is set from somewhere else (e.g. from reading state dict), it would still be float rather than a dict.
flair/nn/model.py
Outdated
if type(self.multi_label_theshold) is not map: | ||
label_threshold = self.multi_label_threshold | ||
else: | ||
label_threshold = self.multi_label_threshold['default'] if label_value not in self.multi_label_threshold else self.multi_label_threshold[label_value] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the multi_label_threshold
does not have key default
this fails. This should be checked anywhere where self.multi_label_threshold
is set (in constructor, when saving or loading the model?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Possibility is to have multi_label_threshold
as a property, and check for this stuff in setter. Example code:
class Model:
def __init__(self, threshold):
self.threshold = threshold
@property
def threshold(self):
return self._threshold
@threshold.setter
def threshold(self, x):
if type(x) is not dict:
raise Exception('Not a dict')
elif 'default' not in x:
raise Exception('default key not present')
else:
self._threshold = x
else: | ||
self._multi_label_threshold = {'default': x} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider checking if x
is a number, e.g.
from numbers import Number
...
if type(x) is dict:
...
elif isinstance(x, Number):
self._multi_label_threshold = {'default': x}
else:
raise Exception('The multi_label_threshold should be either a dict or a number')
@yosipk thanks for adding this - very helpful! |
Closes #2274