Merge pull request #661 from GiliGoldin/master
add option to support customized elmo embeddings
stefan-it authored Apr 16, 2019
2 parents f01683d + 524ac08 commit dea50fa
Showing 1 changed file with 20 additions and 18 deletions.
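
With this change, a self-trained ELMo model can be used by passing its AllenNLP-format options and weights files directly to the constructor. A minimal usage sketch (the file paths below are placeholders; allennlp must be installed for ELMoEmbeddings to load):

    from flair.embeddings import ELMoEmbeddings

    # hypothetical paths to a customized ELMo model in AllenNLP format
    elmo = ELMoEmbeddings(
        options_file="resources/custom_elmo_options.json",
        weight_file="resources/custom_elmo_weights.hdf5",
    )

If either file is omitted, the constructor falls back to selecting a pre-trained model by name, as before.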
38 changes: 20 additions & 18 deletions flair/embeddings.py
@@ -425,7 +425,9 @@ def __str__(self):
 class ELMoEmbeddings(TokenEmbeddings):
     """Contextual word embeddings using word-level LM, as proposed in Peters et al., 2018."""
 
-    def __init__(self, model: str = "original"):
+    def __init__(
+        self, model: str = "original", options_file: str = None, weight_file: str = None
+    ):
         super().__init__()
 
         try:
@@ -442,22 +444,23 @@ def __init__(self, model: str = "original"):
         self.name = "elmo-" + model
         self.static_embeddings = True
 
-        # the default model for ELMo is the 'original' model, which is very large
-        options_file = allennlp.commands.elmo.DEFAULT_OPTIONS_FILE
-        weight_file = allennlp.commands.elmo.DEFAULT_WEIGHT_FILE
-        # alternatively, a small, medium or portuguese model can be selected by passing the appropriate mode name
-        if model == "small":
-            options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json"
-            weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5"
-        if model == "medium":
-            options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_options.json"
-            weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5"
-        if model == "pt" or model == "portuguese":
-            options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pt/elmo_pt_options.json"
-            weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pt/elmo_pt_weights.hdf5"
-        if model == "pubmed":
-            options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_options.json"
-            weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_weights_PubMed_only.hdf5"
+        if not options_file or not weight_file:
+            # the default model for ELMo is the 'original' model, which is very large
+            options_file = allennlp.commands.elmo.DEFAULT_OPTIONS_FILE
+            weight_file = allennlp.commands.elmo.DEFAULT_WEIGHT_FILE
+            # alternatively, a small, medium or portuguese model can be selected by passing the appropriate mode name
+            if model == "small":
+                options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json"
+                weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5"
+            if model == "medium":
+                options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_options.json"
+                weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5"
+            if model == "pt" or model == "portuguese":
+                options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pt/elmo_pt_options.json"
+                weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pt/elmo_pt_weights.hdf5"
+            if model == "pubmed":
+                options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_options.json"
+                weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_weights_PubMed_only.hdf5"
 
         # put on Cuda if available
         from flair import device
@@ -492,7 +495,6 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
             sentence_embeddings = embeddings[i]
 
             for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))):
-
                 word_embedding = torch.cat(
                     [
                         torch.FloatTensor(sentence_embeddings[0, token_idx, :]),
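The named-model path is unchanged by this commit; a short sketch of ordinary usage (assuming flair and allennlp are installed):

    from flair.data import Sentence
    from flair.embeddings import ELMoEmbeddings

    # with no options_file/weight_file given, the model name selects a hosted model
    elmo = ELMoEmbeddings(model="small")

    sentence = Sentence("The grass is green .")
    elmo.embed(sentence)

    for token in sentence:
        # each token embedding is the concatenation of the ELMo layers, per the torch.cat above
        print(token.text, token.embedding.shape)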
