This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Embeddings from ELMo Transformer #2351

Closed
stefan-it opened this issue Jan 12, 2019 · 4 comments

@stefan-it

stefan-it commented Jan 12, 2019

Hi,

I have a question about the transformer ELMo model. I trained my own model, and a model.tar.gz was created successfully :)

With the "normal" ELMo model I would do the following to get the embeddings:

```python
import torch
from allennlp.commands.elmo import ElmoEmbedder

options_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json'
weight_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5'

cuda_device = 0 if torch.cuda.is_available() else -1

ee = ElmoEmbedder(options_file=options_file,
                  weight_file=weight_file,
                  cuda_device=cuda_device)

sentence = "It is snowing in Munich ."
tokens = sentence.split()

# embed_batch returns one numpy array per sentence, shaped (3 layers, num_tokens, dim)
embeddings = ee.embed_batch([tokens])
sentence_embeddings = embeddings[0]

for token_idx, token in enumerate(tokens):
    # Concatenate the three ELMo layers for this token into one vector of size 3 * dim
    word_embedding = torch.cat([torch.from_numpy(sentence_embeddings[layer, token_idx, :])
                                for layer in range(3)], 0)
    print(word_embedding)
```
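The loop above concatenates ELMo's three layers for each token. As a framework-free sketch of the same indexing, plus the common alternative of averaging the layers, here is a version over plain nested lists shaped (layers, tokens, dim). `concat_layers`, `average_layers`, and the toy numbers are hypothetical and not part of AllenNLP.

```python
from typing import List

# Stand-in for one entry of embed_batch(...): nested lists shaped
# (num_layers, num_tokens, dim). Real ELMo output has 3 layers and a much
# larger dim; the toy values below are made up for illustration.
Layers = List[List[List[float]]]


def concat_layers(sentence_embeddings: Layers) -> List[List[float]]:
    """One vector per token: all layers concatenated, lowest layer first."""
    num_layers = len(sentence_embeddings)
    num_tokens = len(sentence_embeddings[0])
    return [
        [value
         for layer in range(num_layers)
         for value in sentence_embeddings[layer][token_idx]]
        for token_idx in range(num_tokens)
    ]


def average_layers(sentence_embeddings: Layers) -> List[List[float]]:
    """One vector per token: element-wise mean over the layers."""
    num_layers = len(sentence_embeddings)
    return [
        [sum(values) / num_layers for values in zip(*token_vectors)]
        for token_vectors in zip(*sentence_embeddings)  # one tuple of layer vectors per token
    ]


toy = [
    [[1.0, 2.0], [3.0, 4.0]],    # layer 0: two tokens, dim 2
    [[5.0, 6.0], [7.0, 8.0]],    # layer 1
    [[9.0, 10.0], [11.0, 12.0]], # layer 2
]
print(concat_layers(toy)[0])   # [1.0, 2.0, 5.0, 6.0, 9.0, 10.0]
print(average_layers(toy)[0])  # [5.0, 6.0]
```

Concatenation triples the vector size, while averaging keeps the original dimension at the cost of discarding per-layer information.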

Could you provide a small snippet showing how to get the embeddings from a transformer-based ELMo model? 🤔


Some more questions:

  • The model.tar.gz contains config.json and weights.th. Could the weights.th file be converted to an HDF5 file so that the ElmoEmbedder can be used?

  • I guess BidirectionalLanguageModelTokenEmbedder should be used:

     embedder = BidirectionalLanguageModelTokenEmbedder(archive_file='model.tar.gz')

    It is derived from LanguageModelTokenEmbedder, but this base class does not provide any "embed" methods :(
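Since the first question hinges on what model.tar.gz actually holds, a stdlib-only sketch for peeking inside the archive may help. `list_archive` and `read_config` are hypothetical helpers; the only assumption about the archive layout is the config.json/weights.th pair mentioned above.

```python
import json
import tarfile
from typing import Any, Dict, List


def list_archive(archive_path: str) -> List[str]:
    """Return the member names stored in a model.tar.gz-style archive."""
    with tarfile.open(archive_path, "r:gz") as archive:
        return archive.getnames()


def read_config(archive_path: str) -> Dict[str, Any]:
    """Load config.json from the archive without extracting it to disk."""
    with tarfile.open(archive_path, "r:gz") as archive:
        member = archive.extractfile("config.json")  # raises KeyError if missing
        if member is None:
            raise FileNotFoundError("config.json is not a regular file in the archive")
        return json.load(member)
```

For example, `list_archive("model.tar.gz")` should include `config.json` and `weights.th`, and the `model` section returned by `read_config` shows which architecture the weights were trained with.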


Thanks ❤️

@matt-peters
Contributor

See: https://github.com/allenai/allennlp/blob/master/tutorials/how_to/training_transformer_elmo.md#using-transformer-elmo-with-existing-allennlp-models

The transformer and LSTM architectures are sufficiently different that it is not possible to convert the torch weights.th to hdf5 and use the ElmoEmbedder.

@stefan-it
Author

Thanks for the explanation :) Unfortunately, the Transformer ELMo documentation only shows how to do this via a jsonnet configuration. How can this be done programmatically using the API?

@brendan-ai2
Contributor

Hey @stefan-it, here's a code snippet:

```python
import torch
from allennlp.data.token_indexers.elmo_indexer import ELMoTokenCharactersIndexer
from allennlp.data.tokenizers.token import Token
from allennlp.modules.token_embedders.bidirectional_language_model_token_embedder import BidirectionalLanguageModelTokenEmbedder

# Path to a trained transformer ELMo archive
lm_model_file = "/home/brendanr/workbenches/calypso/sers/full__2k_samples__8k_fd__NO_SCATTER__02/model.tar.gz"

sentence = "It is raining in Seattle ."
tokens = [Token(word) for word in sentence.split()]

lm_embedder = BidirectionalLanguageModelTokenEmbedder(
    archive_file=lm_model_file,
    dropout=0.2,
    bos_eos_tokens=["<S>", "</S>"],
    remove_bos_eos=True,
    requires_grad=False
)
# Switch to inference mode so the dropout above is not applied when embedding.
lm_embedder.eval()

indexer = ELMoTokenCharactersIndexer()
vocab = lm_embedder._lm.vocab
character_indices = indexer.tokens_to_indices(tokens, vocab, "elmo")["elmo"]

# Batch of size 1, shape (1, num_tokens, 50)
indices_tensor = torch.LongTensor([character_indices])

# Embed and extract the single element from the batch.
with torch.no_grad():
    embeddings = lm_embedder(indices_tensor)[0]

for word_embedding in embeddings:
    print(word_embedding)
```
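The `indices_tensor` in this snippet has shape (batch, num_tokens, 50): each token becomes a fixed-length row of character ids. To make that layout, and the step from a batch of one to a padded batch, concrete, here is a pure-Python sketch of the encoding. It follows AllenNLP's `ELMoCharacterMapper` as understood here, but the constants and the zero-padding of missing tokens should be treated as assumptions, and `word_to_char_ids`/`batch_char_ids` are hypothetical helpers; in real code the indexer itself should be preferred.

```python
from typing import List

# Assumed constants: byte values 0-255 are characters, 258-260 mark
# begin-of-word, end-of-word and in-word padding, and every id is shifted
# by +1 at the end so that 0 stays free for padding/masking.
MAX_WORD_LENGTH = 50
BEGINNING_OF_WORD = 258
END_OF_WORD = 259
PADDING_CHARACTER = 260


def word_to_char_ids(word: str) -> List[int]:
    """Encode one token as MAX_WORD_LENGTH character ids (sketch, not AllenNLP's code)."""
    char_ids = [PADDING_CHARACTER] * MAX_WORD_LENGTH
    char_ids[0] = BEGINNING_OF_WORD
    # Leave room for the two markers: at most 48 UTF-8 bytes are kept.
    word_bytes = word.encode("utf-8", "ignore")[: MAX_WORD_LENGTH - 2]
    for position, byte in enumerate(word_bytes, start=1):
        char_ids[position] = byte
    char_ids[len(word_bytes) + 1] = END_OF_WORD
    return [c + 1 for c in char_ids]


def batch_char_ids(sentences: List[List[str]]) -> List[List[List[int]]]:
    """Pad tokenized sentences into shape (batch, max_tokens, MAX_WORD_LENGTH)."""
    max_tokens = max(len(sentence) for sentence in sentences)
    empty_word = [0] * MAX_WORD_LENGTH  # all-zero rows mark padded token positions
    return [
        [word_to_char_ids(token) for token in sentence]
        + [empty_word] * (max_tokens - len(sentence))
        for sentence in sentences
    ]


batch = batch_char_ids([["It", "is", "raining"], ["Hi"]])
print(len(batch), len(batch[0]), len(batch[0][0]))  # 2 3 50
print(batch[0][0][:5])                              # [259, 74, 117, 260, 261]
```

The resulting nested list could be wrapped with `torch.LongTensor(batch)` in the same way as `indices_tensor` above.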

Two things to note:

  1. This sidesteps our data loading and batching mechanisms for brevity. If you or other readers aren't familiar with them, we have a tutorial at https://allennlp.org/tutorials that ties together our various abstractions.

  2. The BidirectionalLanguageModelTokenEmbedder returns a scalar mix of the layers in the language model. If you want access to these layers directly, I'd refer you to the embedder code here: https://github.com/allenai/allennlp/blob/master/allennlp/modules/token_embedders/language_model_token_embedder.py#L172
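For intuition about the "scalar mix" in point 2: AllenNLP's `ScalarMix` combines the layers as γ·Σₖ sₖ·hₖ, where s = softmax(w) over per-layer weights w and γ is a scale factor. Below is a dependency-free sketch of that formula for a single token, ignoring the optional layer normalisation `ScalarMix` also supports. `scalar_mix` is a hypothetical helper; in the real module w and γ are learned parameters rather than fixed inputs.

```python
import math
from typing import List, Sequence


def scalar_mix(layers: Sequence[Sequence[float]],
               weights: Sequence[float],
               gamma: float = 1.0) -> List[float]:
    """Return gamma * sum_k softmax(weights)_k * layers[k] for one token."""
    if len(layers) != len(weights):
        raise ValueError("need exactly one weight per layer")
    # Softmax-normalise the raw weights (subtracting the max for numerical stability).
    max_w = max(weights)
    exp_w = [math.exp(w - max_w) for w in weights]
    total = sum(exp_w)
    probs = [e / total for e in exp_w]
    dim = len(layers[0])
    return [gamma * sum(p * layer[i] for p, layer in zip(probs, layers))
            for i in range(dim)]


# Equal raw weights reduce the mix to a plain average of the layers (roughly [3.0, 4.0]).
print(scalar_mix([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [0.0, 0.0, 0.0]))
```

Because the weights pass through a softmax, a trained mix can emphasise one layer without ever producing negative layer contributions.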

Let me know if you have any other questions!

@brendan-ai2 brendan-ai2 self-assigned this Jan 15, 2019
brendan-ai2 added a commit that referenced this issue Jan 16, 2019
- Improves documentation for using transformer ELMo.
- A user requested an example of how to use transformer ELMo directly.
- #2351
@stefan-it
Author

Thanks so much! It works perfectly :) If I have questions about accessing specific layers, I'll open a follow-up issue!
