Add per-label thresholds #2366

Merged
alanakbik merged 5 commits into master from GH-2274-per-label-threshold on Aug 17, 2021

Conversation

@yosipk (Collaborator) commented on Aug 10, 2021

Closes #2274

@yosipk changed the title from "PR that adds per-label thresholds" to "Add per-label thresholds" on Aug 10, 2021
@@ -361,7 +361,7 @@ def __init__(self,

         # set up multi-label logic
         self.multi_label = multi_label
-        self.multi_label_threshold = multi_label_threshold
+        self.multi_label_threshold = {'default': multi_label_threshold} if type(multi_label_threshold) is float else multi_label_threshold

@yosipk (Collaborator, Author) commented:

This could be more robust. The check type(multi_label_threshold) is float is False for an integer, so an integer threshold such as 1 is stored as-is instead of being wrapped in a dict.
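
A minimal sketch of that failure mode, assuming a hypothetical helper wrap_threshold that mirrors the conditional expression in the diff (it is not part of the PR):

```python
def wrap_threshold(multi_label_threshold):
    # Same check as the diff: only an exact float is wrapped in a dict.
    if type(multi_label_threshold) is float:
        return {'default': multi_label_threshold}
    # Anything else, including an int, is passed through unchanged.
    return multi_label_threshold


print(wrap_threshold(0.5))  # a float is wrapped into a dict
print(wrap_threshold(1))    # an int stays a bare number
```

Any later code that indexes the threshold like a dict would then fail on the integer case.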

@yosipk (Collaborator, Author) commented:

Also, doing this only in the constructor means that if multi_label_threshold is set from somewhere else (e.g. when reading a state dict), it would still be a float rather than a dict.
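
To illustrate, here is a small sketch with a hypothetical stand-in class (Classifier is not the real model class, and the later assignment only imitates what restoring a saved value might do):

```python
class Classifier:
    """Hypothetical stand-in that normalizes only in the constructor."""

    def __init__(self, multi_label_threshold=0.5):
        t = multi_label_threshold
        self.multi_label_threshold = {'default': t} if type(t) is float else t


clf = Classifier(0.5)
wrapped = clf.multi_label_threshold   # normalized by __init__

# A later plain assignment (e.g. while restoring saved state) bypasses __init__.
clf.multi_label_threshold = 0.3
unwrapped = clf.multi_label_threshold  # still a bare float
```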

if type(self.multi_label_threshold) is not dict:
    label_threshold = self.multi_label_threshold
else:
    label_threshold = self.multi_label_threshold['default'] if label_value not in self.multi_label_threshold else self.multi_label_threshold[label_value]

@yosipk (Collaborator, Author) commented:

If multi_label_threshold has no 'default' key, this lookup fails with a KeyError for any label that is not listed explicitly. The key should be checked wherever self.multi_label_threshold is set (in the constructor, and when saving or loading the model?).
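
A rough sketch of the lookup with an up-front check; validate_thresholds and threshold_for are hypothetical names used only for illustration, and the label names are made up:

```python
def validate_thresholds(thresholds):
    # Reject dicts that cannot provide a fallback value.
    if 'default' not in thresholds:
        raise KeyError("multi_label_threshold must contain a 'default' key")
    return thresholds


def threshold_for(label_value, thresholds):
    # Per-label value if present, otherwise fall back to 'default'.
    return thresholds.get(label_value, thresholds['default'])


thresholds = validate_thresholds({'default': 0.5, 'POSITIVE': 0.8})
positive = threshold_for('POSITIVE', thresholds)  # explicit entry
other = threshold_for('NEGATIVE', thresholds)     # falls back to 'default'
```

Validating at assignment time means the lookup itself never has to handle a missing 'default' key.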

@yosipk (Collaborator, Author) commented:

One possibility is to make multi_label_threshold a property and perform these checks in the setter. Example code:

class Model:

    def __init__(self, threshold):
        self.threshold = threshold

    @property
    def threshold(self):
        return self._threshold

    @threshold.setter
    def threshold(self, x):
        if type(x) is not dict:
            raise Exception('Not a dict')
        elif 'default' not in x:
            raise Exception('default key not present')
        else:
            self._threshold = x

Comment on lines +395 to +396
        else:
            self._multi_label_threshold = {'default': x}

@yosipk (Collaborator, Author) commented:

Consider checking if x is a number, e.g.

from numbers import Number
...
if type(x) is dict:
  ...
elif isinstance(x, Number):
  self._multi_label_threshold = {'default': x}
else:
  raise Exception('The multi_label_threshold should be either a dict or a number') 
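
Putting the suggestions together, a self-contained sketch might look like the following. The class name Model is a hypothetical stand-in, and the choice of TypeError and KeyError instead of a bare Exception is an assumption, not what the PR does:

```python
from numbers import Number


class Model:
    """Hypothetical stand-in showing the validating property."""

    def __init__(self, multi_label_threshold=0.5):
        # Assignment goes through the setter below.
        self.multi_label_threshold = multi_label_threshold

    @property
    def multi_label_threshold(self):
        return self._multi_label_threshold

    @multi_label_threshold.setter
    def multi_label_threshold(self, x):
        if isinstance(x, dict):
            if 'default' not in x:
                raise KeyError("multi_label_threshold dict needs a 'default' key")
            self._multi_label_threshold = x
        elif isinstance(x, Number):
            # Covers both int and float (note that bool is also a Number).
            self._multi_label_threshold = {'default': x}
        else:
            raise TypeError('multi_label_threshold should be either a dict or a number')


model = Model(1)                   # an int is wrapped as well
model.multi_label_threshold = 0.3  # later assignments are normalized too
```

Because every assignment passes through the setter, values set outside the constructor are validated and wrapped in the same way.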

@alanakbik (Collaborator) commented:

@yosipk thanks for adding this - very helpful!

@alanakbik merged commit 133e5da into master on Aug 17, 2021
@alanakbik deleted the GH-2274-per-label-threshold branch on August 17, 2021 at 14:11

Successfully merging this pull request may close these issues.

Allow per-label threshold for multi-label classifiers