Merge pull request #3481 from janpf/triplets
Adds DataTriples which act just like DataPairs
alanakbik authored Jul 2, 2024
2 parents a2e0ba5 + 21eaad2 commit a852bff
Showing 5 changed files with 452 additions and 1 deletion.
45 changes: 45 additions & 0 deletions flair/data.py
@@ -495,6 +495,7 @@ def to_dict(self) -> Dict[str, typing.Any]:

DT = typing.TypeVar("DT", bound=DataPoint)
DT2 = typing.TypeVar("DT2", bound=DataPoint)
DT3 = typing.TypeVar("DT3", bound=DataPoint)


class _PartOfSentence(DataPoint, ABC):
@@ -1258,6 +1259,50 @@ def text(self):
TextPair = DataPair[Sentence, Sentence]


class DataTriple(DataPoint, typing.Generic[DT, DT2, DT3]):
def __init__(self, first: DT, second: DT2, third: DT3):
super().__init__()
self.first = first
self.second = second
self.third = third

def to(self, device: str, pin_memory: bool = False):
self.first.to(device, pin_memory)
self.second.to(device, pin_memory)
self.third.to(device, pin_memory)

def clear_embeddings(self, embedding_names: Optional[List[str]] = None):
self.first.clear_embeddings(embedding_names)
self.second.clear_embeddings(embedding_names)
self.third.clear_embeddings(embedding_names)

@property
def embedding(self):
return torch.cat([self.first.embedding, self.second.embedding, self.third.embedding])

def __len__(self):
return len(self.first) + len(self.second) + len(self.third)

@property
def unlabeled_identifier(self):
return f"DataTriple: '{self.first.unlabeled_identifier}' + '{self.second.unlabeled_identifier}' + '{self.third.unlabeled_identifier}'"

@property
def start_position(self) -> int:
return self.first.start_position

@property
def end_position(self) -> int:
return self.first.end_position

@property
def text(self):
return self.first.text + " || " + self.second.text + " || " + self.third.text


TextTriple = DataTriple[Sentence, Sentence, Sentence]


class Image(DataPoint):
def __init__(self, data=None, imageURL=None) -> None:
super().__init__()
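As a usage sketch (not part of the diff): the new container mirrors the DataPair API shown above, so a TextTriple is built from three Sentence objects; the example sentences below are invented.

from flair.data import Sentence, TextTriple

# TextTriple is DataTriple[Sentence, Sentence, Sentence], so it is
# constructed from three Sentence objects, just as TextPair is from two.
triple = TextTriple(
    Sentence("An anchor sentence."),
    Sentence("A sentence similar to the anchor."),
    Sentence("An unrelated sentence."),
)

print(len(triple))  # total token count across all three sentences
print(triple.text)  # the three texts joined with ' || '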
4 changes: 4 additions & 0 deletions flair/datasets/__init__.py
@@ -247,6 +247,8 @@
SUPERGLUE_RTE,
DataPairCorpus,
DataPairDataset,
DataTripleCorpus,
DataTripleDataset,
OpusParallelCorpus,
ParallelTextCorpus,
ParallelTextDataset,
@@ -529,6 +531,8 @@
"SUPERGLUE_RTE",
"DataPairCorpus",
"DataPairDataset",
"DataTripleCorpus",
"DataTripleDataset",
"OpusParallelCorpus",
"ParallelTextCorpus",
"ParallelTextDataset",
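With these exports in place, both new classes can be imported directly from the package; a quick check (not part of the diff):

from flair.datasets import DataTripleCorpus, DataTripleDataset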
268 changes: 267 additions & 1 deletion flair/datasets/text_text.py
@@ -4,7 +4,16 @@
from typing import List, Optional, Union

import flair
from flair.data import Corpus, DataPair, FlairDataset, Sentence, TextPair, _iter_dataset
from flair.data import (
Corpus,
DataPair,
DataTriple,
FlairDataset,
Sentence,
TextPair,
TextTriple,
_iter_dataset,
)
from flair.datasets.base import find_train_dev_test_files
from flair.file_utils import cached_path, unpack_file, unzip_file

@@ -435,6 +444,263 @@ def __getitem__(self, index: int = 0) -> DataPair:
return self._make_data_pair(self.first_elements[index], self.second_elements[index])


class DataTripleCorpus(Corpus):
def __init__(
self,
data_folder: Union[str, Path],
columns: List[int] = [0, 1, 2, 3],
train_file=None,
test_file=None,
dev_file=None,
use_tokenizer: bool = True,
max_tokens_per_doc=-1,
max_chars_per_doc=-1,
in_memory: bool = True,
label_type: Optional[str] = None,
autofind_splits=True,
sample_missing_splits: bool = True,
skip_first_line: bool = False,
separator: str = "\t",
encoding: str = "utf-8",
):
r"""Corpus for tasks involving triples of sentences or paragraphs.
The data files are expected to be in column format where each line has a column
for the first sentence/paragraph, the second sentence/paragraph, the third sentence/paragraph and the labels, respectively. The columns must be separated by a given separator (default: '\t').
:param data_folder: base folder with the task data
:param columns: List that indicates the columns for the first sentence (first entry in the list),
the second sentence (second entry), the third sentence (third entry), and label (last entry).
default = [0,1,2,3]
:param train_file: the name of the train file
:param test_file: the name of the test file; if None, test data is sampled from train (if sample_missing_splits is true)
:param dev_file: the name of the dev file; if None, dev data is sampled from train (if sample_missing_splits is true)
:param use_tokenizer: Whether or not to use the in-built tokenizer
:param max_tokens_per_doc: If set, shortens sentences to this maximum number of tokens
:param max_chars_per_doc: If set, shortens sentences to this maximum number of characters
:param in_memory: If True, data will be saved in a list of flair.data.DataTriple objects; otherwise we use lists of plain strings, which need less space
:param label_type: Name of the label of the data triples
:param autofind_splits: If True, train/test/dev files will be automatically identified in the given data_folder
:param sample_missing_splits: If True, a missing train/test/dev file will be sampled from the available data
:param skip_first_line: If True, the first line of data files will be ignored
:param separator: Separator between columns in data files
:param encoding: Encoding of data files
:return: a Corpus with annotated train, dev, and test data
"""
# find train, dev, and test files if not specified
dev_file, test_file, train_file = find_train_dev_test_files(
data_folder,
dev_file,
test_file,
train_file,
autofind_splits=autofind_splits,
)

# create DataTripleDataset for train, test, and dev files, if they are given

train = (
DataTripleDataset(
train_file,
columns=columns,
use_tokenizer=use_tokenizer,
max_tokens_per_doc=max_tokens_per_doc,
max_chars_per_doc=max_chars_per_doc,
in_memory=in_memory,
label_type=label_type,
skip_first_line=skip_first_line,
separator=separator,
encoding=encoding,
)
if train_file is not None
else None
)

test = (
DataTripleDataset(
test_file,
columns=columns,
use_tokenizer=use_tokenizer,
max_tokens_per_doc=max_tokens_per_doc,
max_chars_per_doc=max_chars_per_doc,
in_memory=in_memory,
label_type=label_type,
skip_first_line=skip_first_line,
separator=separator,
encoding=encoding,
)
if test_file is not None
else None
)

dev = (
DataTripleDataset(
dev_file,
columns=columns,
use_tokenizer=use_tokenizer,
max_tokens_per_doc=max_tokens_per_doc,
max_chars_per_doc=max_chars_per_doc,
in_memory=in_memory,
label_type=label_type,
skip_first_line=skip_first_line,
separator=separator,
encoding=encoding,
)
if dev_file is not None
else None
)

super().__init__(
train,
dev,
test,
sample_missing_splits=sample_missing_splits,
name=str(data_folder),
)


class DataTripleDataset(FlairDataset):
def __init__(
self,
path_to_data: Union[str, Path],
columns: List[int] = [0, 1, 2, 3],
max_tokens_per_doc=-1,
max_chars_per_doc=-1,
use_tokenizer=True,
in_memory: bool = True,
label_type: Optional[str] = None,
skip_first_line: bool = False,
separator: str = "\t",
encoding: str = "utf-8",
label: bool = True,
):
r"""Creates a Dataset for triples of sentences/paragraphs.
The file needs to be in a column format,
where each line has a column for the first sentence/paragraph, the second sentence/paragraph, the third sentence/paragraph and the label
seperated by e.g. '\t' (just like in the glue RTE-dataset https://gluebenchmark.com/tasks) .
For each data triple we create a flair.data.DataTriple object.
:param path_to_data: path to the data file
:param columns: list of integers that indicate the respective columns. The first entry is the column
for the first sentence, the second for the second sentence, the third for the third sentence,
and the fourth for the label. Default [0, 1, 2, 3]
:param max_tokens_per_doc: If set, shortens sentences to this maximum number of tokens
:param max_chars_per_doc: If set, shortens sentences to this maximum number of characters
:param use_tokenizer: Whether or not to use the in-built tokenizer
:param in_memory: If True, data will be saved in a list of flair.data.DataTriple objects; otherwise we use lists of plain strings, which need less space
:param label_type: Name of the label of the data triples
:param skip_first_line: If True, the first line of the data file will be ignored
:param separator: Separator between columns in the data file
:param encoding: Encoding of the data file
:param label: If False, the dataset expects unlabeled data
"""
path_to_data = Path(path_to_data)

# stop if the file does not exist
assert path_to_data.exists()

self.in_memory = in_memory

self.use_tokenizer = use_tokenizer

self.max_tokens_per_doc = max_tokens_per_doc

self.label = label

assert label_type is not None
self.label_type = label_type

self.total_data_count: int = 0

if self.in_memory:
self.data_triples: List[DataTriple] = []
else:
self.first_elements: List[str] = []
self.second_elements: List[str] = []
self.third_elements: List[str] = []
self.labels: List[Optional[str]] = []

with open(str(path_to_data), encoding=encoding) as source_file:
source_line = source_file.readline()

if skip_first_line:
source_line = source_file.readline()

while source_line:
source_line_list = source_line.strip().split(separator)

first_element = source_line_list[columns[0]]
second_element = source_line_list[columns[1]]
third_element = source_line_list[columns[2]]

if self.label:
triple_label: Optional[str] = source_line_list[columns[3]]
else:
triple_label = None

if max_chars_per_doc > 0:
first_element = first_element[:max_chars_per_doc]
second_element = second_element[:max_chars_per_doc]
third_element = third_element[:max_chars_per_doc]

if self.in_memory:
data_triple = self._make_data_triple(first_element, second_element, third_element, triple_label)
self.data_triples.append(data_triple)
else:
self.first_elements.append(first_element)
self.second_elements.append(second_element)
self.third_elements.append(third_element)
if self.label:
self.labels.append(triple_label)

self.total_data_count += 1

source_line = source_file.readline()

# create a DataTriple object from strings
def _make_data_triple(
self, first_element: str, second_element: str, third_element: str, label: Optional[str] = None
):
first_sentence = Sentence(first_element, use_tokenizer=self.use_tokenizer)
second_sentence = Sentence(second_element, use_tokenizer=self.use_tokenizer)
third_sentence = Sentence(third_element, use_tokenizer=self.use_tokenizer)

if self.max_tokens_per_doc > 0:
first_sentence.tokens = first_sentence.tokens[: self.max_tokens_per_doc]
second_sentence.tokens = second_sentence.tokens[: self.max_tokens_per_doc]
third_sentence.tokens = third_sentence.tokens[: self.max_tokens_per_doc]

data_triple = TextTriple(first_sentence, second_sentence, third_sentence)

if label:
data_triple.add_label(typename=self.label_type, value=label)

return data_triple

def is_in_memory(self) -> bool:
return self.in_memory

def __len__(self):
return self.total_data_count

# if in_memory is True we return a DataTriple, otherwise we create one from the lists of strings
def __getitem__(self, index: int = 0) -> DataTriple:
if self.in_memory:
return self.data_triples[index]
elif self.label:
return self._make_data_triple(
self.first_elements[index],
self.second_elements[index],
self.third_elements[index],
self.labels[index],
)
else:
return self._make_data_triple(
self.first_elements[index], self.second_elements[index], self.third_elements[index]
)


class GLUE_RTE(DataPairCorpus):
def __init__(
self,
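A sketch of loading the new corpus (not part of the diff; the folder name, file name, and label type below are invented for illustration). The expected input is one example per line: first text, second text, third text, and label, separated by tabs.

from flair.datasets import DataTripleCorpus

# Assumes a folder 'data/my_triples' containing e.g. 'train.txt' with
# tab-separated lines of the form: <first>\t<second>\t<third>\t<label>.
# Missing dev/test splits are sampled from train (sample_missing_splits=True).
corpus = DataTripleCorpus(
    "data/my_triples",
    columns=[0, 1, 2, 3],     # column indices: the three texts, then the label
    label_type="similarity",  # hypothetical label name; label_type is required
    skip_first_line=True,     # set if the file has a header row
)

print(corpus)  # prints the train/dev/test split sizes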
2 changes: 2 additions & 0 deletions flair/models/__init__.py
@@ -14,6 +14,7 @@
from .tars_model import FewshotClassifier, TARSClassifier, TARSTagger
from .text_classification_model import TextClassifier
from .text_regression_model import TextRegressor
from .triple_classification_model import TextTripleClassifier
from .word_tagger_model import TokenClassifier, WordTagger

__all__ = [
@@ -22,6 +23,7 @@
"LanguageModel",
"Lemmatizer",
"TextPairClassifier",
"TextTripleClassifier",
"TextPairRegressor",
"RelationClassifier",
"RelationExtractor",
134 changes: 134 additions & 0 deletions flair/models/triple_classification_model.py (diff not loaded on this page)
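Since the diff for the new model file did not load here, the following is only a hedged sketch: it assumes TextTripleClassifier mirrors TextPairClassifier's constructor (embeddings, label_dictionary, label_type); the embedding model and all hyperparameters are likewise assumptions.

from flair.datasets import DataTripleCorpus
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextTripleClassifier
from flair.trainers import ModelTrainer

corpus = DataTripleCorpus("data/my_triples", label_type="similarity")

# Assumption: TextTripleClassifier accepts the same core arguments
# as TextPairClassifier.
classifier = TextTripleClassifier(
    embeddings=TransformerDocumentEmbeddings("distilbert-base-uncased"),
    label_dictionary=corpus.make_label_dictionary(label_type="similarity"),
    label_type="similarity",
)

trainer = ModelTrainer(classifier, corpus)
trainer.train("resources/triple-classifier", max_epochs=3)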
