Merge pull request #1130 from pommedeterresautee/concat_reduce
Another strong reduction of concatenation work for a small optimization (-4% inference time)
yosipk authored Sep 19, 2019
2 parents 64dabf7 + 9d7bc4d commit 8d65cc6
Showing 2 changed files with 22 additions and 20 deletions.
flair/models/language_model.py (4 changes: 3 additions & 1 deletion)
@@ -144,7 +144,9 @@ def get_representation(
                 ]
 
             sequences_as_char_indices.append(char_indices)
-        t = torch.tensor(sequences_as_char_indices, dtype=torch.long).to(device=flair.device, non_blocking=True)
+        t = torch.tensor(sequences_as_char_indices, dtype=torch.long).to(
+            device=flair.device, non_blocking=True
+        )
         batches.append(t)
 
         output_parts = []
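The language_model.py hunk only reflows the transfer line, but the non_blocking=True flag it keeps is worth a note. Below is a minimal sketch (not part of the commit; names and shapes are illustrative) of when that flag actually yields an asynchronous host-to-device copy:

    import torch

    # Illustrative stand-in for the character-index batch built above.
    char_indices = torch.randint(0, 275, (32, 128), dtype=torch.long)

    if torch.cuda.is_available():
        # non_blocking=True lets the copy overlap with host-side work, but
        # only when the source tensor sits in pinned (page-locked) memory;
        # for ordinary pageable tensors PyTorch falls back to a blocking copy.
        pinned = char_indices.pin_memory()
        on_gpu = pinned.to(device="cuda", non_blocking=True)

Since torch.tensor() allocates pageable memory, the flag in the diff above acts as a best-effort hint rather than a guaranteed overlap.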
flair/models/sequence_tagger_model.py (38 changes: 19 additions & 19 deletions)
@@ -452,31 +452,31 @@ def forward(self, sentences: List[Sentence]):
         lengths: List[int] = [len(sentence.tokens) for sentence in sentences]
         longest_token_sequence_in_batch: int = max(lengths)
 
         # initialize zero-padded word embeddings tensor
-        sentence_tensor = torch.zeros(
+        pre_allocated_zero_tensor = torch.zeros(
+            self.embeddings.embedding_length * longest_token_sequence_in_batch,
+            dtype=torch.float,
+            device=flair.device,
+        )
+
+        all_embs = list()
+        for sentence in sentences:
+            all_embs += [emb for token in sentence for emb in token.get_each_embedding()]
+            nb_padding_tokens = longest_token_sequence_in_batch - len(sentence)
+
+            if nb_padding_tokens > 0:
+                t = pre_allocated_zero_tensor[
+                    : self.embeddings.embedding_length * nb_padding_tokens
+                ]
+                all_embs.append(t)
+
+        sentence_tensor = torch.cat(all_embs).view(
             [
                 len(sentences),
                 longest_token_sequence_in_batch,
                 self.embeddings.embedding_length,
-            ],
-            dtype=torch.float,
-            device=flair.device,
+            ]
         )
 
-        for s_id, sentence in enumerate(sentences):
-            all_embs = list()
-
-            for index_token, token in enumerate(sentence):
-                embs = token.get_each_embedding()
-                if not all_embs:
-                    all_embs = [list() for _ in range(len(embs))]
-                for index_emb, emb in enumerate(embs):
-                    all_embs[index_emb].append(emb)
-
-            concat_word_emb = [torch.stack(embs) for embs in all_embs]
-            concat_sentence_emb = torch.cat(concat_word_emb, dim=1)
-            sentence_tensor[s_id][: len(sentence)] = concat_sentence_emb
-
         # --------------------------------------------------------------------
         # FF PART
         # --------------------------------------------------------------------
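A minimal sketch (not flair code; sizes are made up) of the padding strategy the new forward() uses: each short sentence is padded with a slice, i.e. a view, of one shared zero tensor, so the batch is assembled by a single torch.cat instead of a torch.stack plus torch.cat per sentence:

    import torch

    emb_len = 4               # embedding length per token (illustrative)
    sent_lengths = [3, 5, 2]  # tokens per sentence (illustrative)
    longest = max(sent_lengths)

    # One zero buffer, big enough to pad even an empty sentence.
    pre_allocated_zeros = torch.zeros(emb_len * longest)

    all_embs = []
    for n_tokens in sent_lengths:
        # Stand-in for the real token embeddings of one sentence.
        all_embs.append(torch.randn(n_tokens * emb_len))
        n_pad = longest - n_tokens
        if n_pad > 0:
            # A view into the shared buffer: no new memory is allocated here.
            all_embs.append(pre_allocated_zeros[: emb_len * n_pad])

    # One cat + view replaces a per-sentence stack/cat and a copy into a
    # zero-initialized batch tensor.
    sentence_tensor = torch.cat(all_embs).view(len(sent_lengths), longest, emb_len)
    print(sentence_tensor.shape)  # torch.Size([3, 5, 4])

The zero padding is thus created once per batch rather than once per sentence, and the only full-batch allocation is the final torch.cat; the commit message credits this pattern with the roughly 4% inference-time win.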
