From 6f6c37f81c89812b938189aa469d035dc1bd829d Mon Sep 17 00:00:00 2001
From: pommedeterresautee
Date: Thu, 19 Sep 2019 11:20:39 +0200
Subject: [PATCH 1/4] another reduce of number of concatenations

---
 flair/models/language_model.py        |  4 ++-
 flair/models/sequence_tagger_model.py | 36 ++++++++++++++-------------
 2 files changed, 22 insertions(+), 18 deletions(-)

diff --git a/flair/models/language_model.py b/flair/models/language_model.py
index 5388e116d..3c506b6f3 100644
--- a/flair/models/language_model.py
+++ b/flair/models/language_model.py
@@ -144,7 +144,9 @@ def get_representation(
                 ]
                 sequences_as_char_indices.append(char_indices)
 
-            t = torch.tensor(sequences_as_char_indices, dtype=torch.long).to(device=flair.device, non_blocking=True)
+            t = torch.tensor(sequences_as_char_indices, dtype=torch.long).to(
+                device=flair.device, non_blocking=True
+            )
             batches.append(t)
 
         output_parts = []
diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py
index c15926840..6f78f7ed7 100644
--- a/flair/models/sequence_tagger_model.py
+++ b/flair/models/sequence_tagger_model.py
@@ -452,30 +452,32 @@ def forward(self, sentences: List[Sentence]):
         lengths: List[int] = [len(sentence.tokens) for sentence in sentences]
         longest_token_sequence_in_batch: int = max(lengths)
 
-        # initialize zero-padded word embeddings tensor
-        sentence_tensor = torch.zeros(
-            [
-                len(sentences),
-                longest_token_sequence_in_batch,
-                self.embeddings.embedding_length,
-            ],
+        pre_allocated_zero_tensor = t = torch.zeros(
+            self.embeddings.embedding_length * longest_token_sequence_in_batch,
             dtype=torch.float,
             device=flair.device,
         )
-
+        all_embs = list()
         for s_id, sentence in enumerate(sentences):
-            all_embs = list()
-
-            for index_token, token in enumerate(sentence):
+            for token in sentence:
                 embs = token.get_each_embedding()
-                if not all_embs:
-                    all_embs = [list() for _ in range(len(embs))]
                 for index_emb, emb in enumerate(embs):
-                    all_embs[index_emb].append(emb)
+                    all_embs.append(emb)
+            nb_padding_token = longest_token_sequence_in_batch - len(sentence)
+
+            if nb_padding_token > 0:
+                t = pre_allocated_zero_tensor[
+                    : self.embeddings.embedding_length * nb_padding_token
+                ]
+                all_embs.append(t)
 
-            concat_word_emb = [torch.stack(embs) for embs in all_embs]
-            concat_sentence_emb = torch.cat(concat_word_emb, dim=1)
-            sentence_tensor[s_id][: len(sentence)] = concat_sentence_emb
+        sentence_tensor = torch.cat(all_embs).view(
+            [
+                len(sentences),
+                longest_token_sequence_in_batch,
+                self.embeddings.embedding_length,
+            ]
+        )
 
         # --------------------------------------------------------------------
         # FF PART

From 777b0b125a9192468187249f00771ec1c4f7e610 Mon Sep 17 00:00:00 2001
From: pommedeterresautee
Date: Thu, 19 Sep 2019 11:22:47 +0200
Subject: [PATCH 2/4] small fix

---
 flair/models/sequence_tagger_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py
index 6f78f7ed7..deb8ce8e7 100644
--- a/flair/models/sequence_tagger_model.py
+++ b/flair/models/sequence_tagger_model.py
@@ -452,7 +452,7 @@ def forward(self, sentences: List[Sentence]):
         lengths: List[int] = [len(sentence.tokens) for sentence in sentences]
         longest_token_sequence_in_batch: int = max(lengths)
 
-        pre_allocated_zero_tensor = t = torch.zeros(
+        pre_allocated_zero_tensor = torch.zeros(
             self.embeddings.embedding_length * longest_token_sequence_in_batch,
             dtype=torch.float,
             device=flair.device,
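Taken together, PATCH 1 and the PATCH 2 fix (which removes the stray "t =" left in the chained assignment) change the batching strategy: instead of allocating a zero-initialized (batch, max_len, emb_dim) tensor and copying each sentence into it, the forward pass collects one flat list of per-token embedding vectors, pads each sentence with a slice of a single pre-allocated zero buffer, and materializes the batch with one torch.cat plus a view. Below is a minimal standalone sketch of that technique; the names and sizes (emb_dim, max_len, the random "sentences") are illustrative stand-ins, not flair's API:

    import torch

    emb_dim, max_len = 4, 5                  # hypothetical embedding size / longest sentence
    sentences = [torch.randn(3, emb_dim),    # a batch of two "sentences",
                 torch.randn(5, emb_dim)]    # already embedded token by token

    # One zero buffer sized for the worst case; every short sentence borrows
    # a slice of it instead of allocating and zero-filling its own padding.
    pre_allocated_zero_tensor = torch.zeros(emb_dim * max_len)

    all_embs = []
    for sentence in sentences:
        all_embs += [token for token in sentence]   # one 1-D vector per token
        nb_padding_tokens = max_len - len(sentence)
        if nb_padding_tokens > 0:
            all_embs.append(pre_allocated_zero_tensor[: emb_dim * nb_padding_tokens])

    # A single concatenation replaces many small per-sentence copies.
    sentence_tensor = torch.cat(all_embs).view(len(sentences), max_len, emb_dim)
    print(sentence_tensor.shape)  # torch.Size([2, 5, 4])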
From 1ae385cf7867282957ded69171fed70d04dbb6b5 Mon Sep 17 00:00:00 2001
From: pommedeterresautee
Date: Thu, 19 Sep 2019 11:34:22 +0200
Subject: [PATCH 3/4] cleaning

---
 flair/models/sequence_tagger_model.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py
index deb8ce8e7..07f5b5df5 100644
--- a/flair/models/sequence_tagger_model.py
+++ b/flair/models/sequence_tagger_model.py
@@ -458,16 +458,16 @@ def forward(self, sentences: List[Sentence]):
             device=flair.device,
         )
         all_embs = list()
-        for s_id, sentence in enumerate(sentences):
+        for sentence in sentences:
             for token in sentence:
                 embs = token.get_each_embedding()
-                for index_emb, emb in enumerate(embs):
+                for emb in embs:
                     all_embs.append(emb)
-            nb_padding_token = longest_token_sequence_in_batch - len(sentence)
+            nb_padding_tokens = longest_token_sequence_in_batch - len(sentence)
 
-            if nb_padding_token > 0:
+            if nb_padding_tokens > 0:
                 t = pre_allocated_zero_tensor[
-                    : self.embeddings.embedding_length * nb_padding_token
+                    : self.embeddings.embedding_length * nb_padding_tokens
                 ]
                 all_embs.append(t)

From 9d7bc4dd7ce760322a00762ba9f7c57dabd65959 Mon Sep 17 00:00:00 2001
From: pommedeterresautee
Date: Thu, 19 Sep 2019 11:39:46 +0200
Subject: [PATCH 4/4] cleaning

---
 flair/models/sequence_tagger_model.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py
index 07f5b5df5..dc515f9bf 100644
--- a/flair/models/sequence_tagger_model.py
+++ b/flair/models/sequence_tagger_model.py
@@ -457,12 +457,10 @@ def forward(self, sentences: List[Sentence]):
             dtype=torch.float,
             device=flair.device,
         )
+
         all_embs = list()
         for sentence in sentences:
-            for token in sentence:
-                embs = token.get_each_embedding()
-                for emb in embs:
-                    all_embs.append(emb)
+            all_embs += [emb for token in sentence for emb in token.get_each_embedding()]
             nb_padding_tokens = longest_token_sequence_in_batch - len(sentence)
 
             if nb_padding_tokens > 0:
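A closing note on why the shared buffer is cheap: slicing pre_allocated_zero_tensor returns a view over the same storage, so no per-sentence allocation or zero-fill happens; the one real copy is deferred to the final torch.cat. A small illustrative check (the buffer size here is arbitrary):

    import torch

    buf = torch.zeros(12)   # stands in for pre_allocated_zero_tensor
    pad_a = buf[:8]         # padding view for a short sentence
    pad_b = buf[:4]         # padding view for a longer one

    # Both slices alias the start of the same storage: no new allocation,
    # no repeated zero-filling for each padded sentence in the batch.
    assert pad_a.data_ptr() == buf.data_ptr() == pad_b.data_ptr()

The flip side is that the buffer must stay read-only: every padded sentence borrows the same zeros, so writing through any slice before the torch.cat would corrupt them all. After the concatenation the data has been copied, and sentence_tensor can be mutated freely.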