Skip to content

Commit

Permalink
Automatically fix implicit None for mypy using official tooling
Browse files Browse the repository at this point in the history
By running:
pipx run no_implicit_optional flair
pipx run no_implicit_optional tests

Following:
https://github.com/hauntsaninja/no_implicit_optional
  • Loading branch information
Lingepumpe committed Apr 21, 2023
1 parent 2e3851f commit 8677c62
Show file tree
Hide file tree
Showing 29 changed files with 335 additions and 315 deletions.
20 changes: 10 additions & 10 deletions flair/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def to(self, device: str, pin_memory: bool = False):
else:
self._embeddings[name] = vector.to(device, non_blocking=True)

def clear_embeddings(self, embedding_names: List[str] = None):
def clear_embeddings(self, embedding_names: Optional[List[str]] = None):
if embedding_names is None:
self._embeddings = {}
else:
Expand Down Expand Up @@ -352,12 +352,12 @@ def remove_labels(self, typename: str):
if typename in self.annotation_layers.keys():
del self.annotation_layers[typename]

def get_label(self, label_type: str = None, zero_tag_value="O"):
def get_label(self, label_type: Optional[str] = None, zero_tag_value="O"):
if len(self.get_labels(label_type)) == 0:
return Label(self, zero_tag_value)
return self.get_labels(label_type)[0]

def get_labels(self, typename: str = None):
def get_labels(self, typename: Optional[str] = None):
if typename is None:
return self.labels

Expand Down Expand Up @@ -472,7 +472,7 @@ class Token(_PartOfSentence):
def __init__(
self,
text: str,
head_id: int = None,
head_id: Optional[int] = None,
whitespace_after: int = 1,
start_position: int = 0,
sentence=None,
Expand Down Expand Up @@ -682,7 +682,7 @@ def __init__(
self,
text: Union[str, List[str], List[Token]],
use_tokenizer: Union[bool, Tokenizer] = True,
language_code: str = None,
language_code: Optional[str] = None,
start_position: int = 0,
):
"""Class to hold all metadata related to a text.
Expand Down Expand Up @@ -831,7 +831,7 @@ def to(self, device: str, pin_memory: bool = False):
for token in self:
token.to(device, pin_memory)

def clear_embeddings(self, embedding_names: List[str] = None):
def clear_embeddings(self, embedding_names: Optional[List[str]] = None):
super().clear_embeddings(embedding_names)

# clear token embeddings
Expand Down Expand Up @@ -946,7 +946,7 @@ def to_original_text(self) -> str:
[t.text + t.whitespace_after * " " for t in self.tokens]
).strip()

def to_dict(self, tag_type: str = None):
def to_dict(self, tag_type: Optional[str] = None):
labels = []

if tag_type:
Expand Down Expand Up @@ -1100,7 +1100,7 @@ def set_context_for_sentences(cls, sentences: List["Sentence"]) -> None:
previous_sentence._next_sentence = sentence
previous_sentence = sentence

def get_labels(self, label_type: str = None):
def get_labels(self, label_type: Optional[str] = None):
# if no label type is specified, return all labels
if label_type is None:
return sorted(self.labels)
Expand Down Expand Up @@ -1138,7 +1138,7 @@ def to(self, device: str, pin_memory: bool = False):
self.first.to(device, pin_memory)
self.second.to(device, pin_memory)

def clear_embeddings(self, embedding_names: List[str] = None):
def clear_embeddings(self, embedding_names: Optional[List[str]] = None):
self.first.clear_embeddings(embedding_names)
self.second.clear_embeddings(embedding_names)

Expand Down Expand Up @@ -1362,7 +1362,7 @@ def _downsample_to_proportion(dataset: Dataset, proportion: float):
splits = randomly_split_into_two_datasets(dataset, sampled_size)
return splits[0]

def obtain_statistics(self, label_type: str = None, pretty_print: bool = True) -> Union[dict, str]:
def obtain_statistics(self, label_type: Optional[str] = None, pretty_print: bool = True) -> Union[dict, str]:
"""Print statistics about the class distribution and sentence sizes.
only labels of sentences are taken into account
Expand Down
4 changes: 2 additions & 2 deletions flair/datasets/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from abc import abstractmethod
from pathlib import Path
from typing import Generic, List, Union
from typing import Generic, List, Optional, Union

import torch.utils.data.dataloader
from deprecated import deprecated
Expand Down Expand Up @@ -110,7 +110,7 @@ def __init__(
database: str,
collection: str,
text_field: str,
categories_field: List[str] = None,
categories_field: Optional[List[str]] = None,
max_tokens_per_doc: int = -1,
max_chars_per_doc: int = -1,
tokenizer: Tokenizer = SegtokTokenizer(),
Expand Down
Loading

0 comments on commit 8677c62

Please sign in to comment.