GH-873: update TransformerXLEmbeddings (now compatible with pytorch-transformers)

stefan-it committed Jul 17, 2019
1 parent 313c4ec commit 2626652
Showing 1 changed file with 19 additions and 11 deletions: flair/embeddings.py
@@ -873,18 +873,22 @@ def __str__(self):
 
 
 class TransformerXLEmbeddings(TokenEmbeddings):
-    def __init__(self, model: str = "transfo-xl-wt103"):
+    def __init__(self, model: str = "transfo-xl-wt103", layers: str = "-1"):
         """Transformer-XL embeddings, as proposed in Dai et al., 2019.
         :param model: name of Transformer-XL model
+        :param layers: comma-separated list of layers
         """
         super().__init__()
 
         if model not in TRANSFORMER_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys():
             raise ValueError("Provided Transformer-XL model is not available.")
 
         self.tokenizer = TransfoXLTokenizer.from_pretrained(model)
-        self.model = TransfoXLModel.from_pretrained(model)
+        self.model = TransfoXLModel.from_pretrained(
+            pretrained_model_name_or_path=model, output_hidden_states=True
+        )
         self.name = model
+        self.layers: List[int] = [int(layer) for layer in layers.split(",")]
         self.static_embeddings = True
 
         dummy_sentence: Sentence = Sentence()
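The first hunk switches model construction to the pytorch-transformers API and requests all hidden states. Below is a minimal sketch of what that flag changes; the three-element return tuple is inferred from the `_, _, hidden_states` unpacking in the second hunk, so treat the exact tuple layout as an assumption about the library rather than documented behavior.

```python
import torch
from pytorch_transformers import TransfoXLModel, TransfoXLTokenizer

# With output_hidden_states=True, the forward pass returns a third element
# holding one hidden-state tensor per layer (inferred from the unpacking
# in the second hunk below; the exact tuple layout is an assumption).
tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
model = TransfoXLModel.from_pretrained(
    pretrained_model_name_or_path="transfo-xl-wt103", output_hidden_states=True
)

token_ids = torch.tensor([tokenizer.convert_tokens_to_ids(["Berlin"])])
with torch.no_grad():
    last_hidden, mems, hidden_states = model(token_ids)

print(len(hidden_states))       # number of layers exposed by the model
print(hidden_states[-1].shape)  # hidden state of the last layer
```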
@@ -904,18 +908,22 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
 
         with torch.no_grad():
             for sentence in sentences:
-                token_strings = [token.text for token in sentence.tokens]
-                indexed_tokens = self.tokenizer.convert_tokens_to_ids(token_strings)
+                for token in sentence.tokens:
+                    token_text = token.text
+                    indexed_token = self.tokenizer.convert_tokens_to_ids([token_text])
+                    token_tensor = torch.tensor([indexed_token])
+                    token_tensor = token_tensor.to(flair.device)
 
-                tokens_tensor = torch.tensor([indexed_tokens])
-                tokens_tensor = tokens_tensor.to(flair.device)
+                    _, _, hidden_states = self.model(token_tensor)
 
-                hidden_states, _ = self.model(tokens_tensor)
+                    token_embeddings = []
 
-                for token, token_idx in zip(
-                    sentence.tokens, range(len(sentence.tokens))
-                ):
-                    token.set_embedding(self.name, hidden_states[0][token_idx])
+                    for layer in self.layers:
+                        current_embedding = hidden_states[layer][0][0]
+                        token_embeddings.append(current_embedding)
+
+                    final_token_embedding = torch.cat(token_embeddings)
+                    token.set_embedding(self.name, final_token_embedding)
 
         return sentences
 
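The second hunk now embeds each token with its own forward pass and concatenates the hidden state of every layer listed in `layers`. A hedged usage sketch of the updated class follows; the sentence text and the `layers="-1,-2"` choice are illustrative only.

```python
from flair.data import Sentence
from flair.embeddings import TransformerXLEmbeddings

# layers="-1,-2" concatenates the last two hidden layers per token, so each
# embedding is twice the model's hidden size; "-1" (the default) keeps just
# the last layer.
embeddings = TransformerXLEmbeddings(model="transfo-xl-wt103", layers="-1,-2")

sentence = Sentence("I love Berlin .")
embeddings.embed(sentence)

for token in sentence:
    print(token.text, token.get_embedding().shape)
```

Note one design trade-off visible in the diff: the model now runs once per token rather than once per sentence, so each embedding is computed without surrounding sentence context.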

