Skip to content

Commit

Permalink
Merge pull request #1257 from yahshibu/some_trivial_modifications
Browse files Browse the repository at this point in the history
Add staticmethod decorator
  • Loading branch information
yosipk authored Nov 28, 2019
2 parents fa3be40 + 1c16d12 commit 1d44abf
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 1 deletion.
2 changes: 2 additions & 0 deletions flair/models/sequence_tagger_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def _get_state_dict(self):
}
return model_state

@staticmethod
def _init_model_with_state_dict(state):

rnn_type = "LSTM" if not "rnn_type" in state.keys() else state["rnn_type"]
Expand Down Expand Up @@ -833,6 +834,7 @@ def _filter_empty_string(texts: List[str]) -> List[str]:
)
return filtered_texts

@staticmethod
def _fetch_model(model_name) -> str:

model_map = {}
Expand Down
1 change: 1 addition & 0 deletions flair/models/similarity_learning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ def _get_state_dict(self):
}
return model_state

@staticmethod
def _init_model_with_state_dict(state):
# The conversion from old model's constructor interface
if "input_embeddings" in state:
Expand Down
2 changes: 2 additions & 0 deletions flair/models/text_classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def _get_state_dict(self):
}
return model_state

@staticmethod
def _init_model_with_state_dict(state):

model = TextClassifier(
Expand Down Expand Up @@ -417,6 +418,7 @@ def _labels_to_indices(self, sentences: List[Sentence]):

return vec

@staticmethod
def _fetch_model(model_name) -> str:

model_map = {}
Expand Down
1 change: 1 addition & 0 deletions flair/models/text_regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def _get_state_dict(self):
}
return model_state

@staticmethod
def _init_model_with_state_dict(state):

model = TextRegressor(document_embeddings=state["document_embeddings"])
Expand Down
4 changes: 3 additions & 1 deletion flair/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,14 @@ def _get_state_dict(self):
functionality."""
pass

@staticmethod
@abstractmethod
def _init_model_with_state_dict(state):
"""Initialize the model from a state dictionary. Implementing this enables the load() and load_checkpoint()
functionality."""
pass

@staticmethod
@abstractmethod
def _fetch_model(model_name) -> str:
return model_name
Expand All @@ -71,7 +73,7 @@ def save(self, model_file: Union[str, Path]):
def load(cls, model: Union[str, Path]):
"""
Loads the model from the given file.
:param model_file: the model file
:param model: the model file
:return: the loaded text classifier model
"""
model_file = cls._fetch_model(str(model))
Expand Down

0 comments on commit 1d44abf

Please sign in to comment.