Skip to content

Commit

Permalink
Merge pull request #2492 from flairNLP/relation-models
Browse files Browse the repository at this point in the history
Add pretrained relation extraction models
  • Loading branch information
alanakbik authored Oct 30, 2021
2 parents 06a78c0 + 30d4da0 commit 46569a6
Showing 1 changed file with 24 additions and 4 deletions.
28 changes: 24 additions & 4 deletions flair/models/relation_extractor_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from pathlib import Path
from typing import List, Union, Tuple, Optional

import torch
Expand All @@ -7,6 +8,7 @@
import flair.embeddings
import flair.nn
from flair.data import DataPoint, RelationLabel, Span, Sentence
from flair.file_utils import cached_path

log = logging.getLogger("flair")

Expand All @@ -22,9 +24,9 @@ def __init__(
entity_pair_filters: List[Tuple[str, str]] = None,
pooling_operation: str = "first_last",
dropout_value: float = 0.0,
locked_dropout_value: float = 0.0,
locked_dropout_value: float = 0.1,
word_dropout_value: float = 0.0,
non_linear_decoder: Optional[int] = None,
non_linear_decoder: Optional[int] = 2048,
**classifierargs,
):
"""
Expand Down Expand Up @@ -203,13 +205,15 @@ def forward_pass(self,

relation_embeddings.append(embedding)

# stack and drop out
all_relations = torch.stack(relation_embeddings)
# stack and drop out (squeeze and unsqueeze)
all_relations = torch.stack(relation_embeddings).unsqueeze(1)

all_relations = self.dropout(all_relations)
all_relations = self.locked_dropout(all_relations)
all_relations = self.word_dropout(all_relations)

all_relations = all_relations.squeeze(1)

# send through decoder
if self.non_linear_decoder:
sentence_relation_scores = self.decoder_2(self.nonlinearity(self.decoder_1(all_relations)))
Expand Down Expand Up @@ -266,6 +270,22 @@ def _init_model_with_state_dict(state):
def label_type(self):
return self._label_type

@staticmethod
def _fetch_model(model_name) -> str:

model_map = {}

hu_path: str = "https://nlp.informatik.hu-berlin.de/resources/models"

model_map["relations-fast"] = "/".join([hu_path, "relations-fast", "relations-fast.pt"])
model_map["relations"] = "/".join([hu_path, "relations", "relations.pt"])

cache_dir = Path("models")
if model_name in model_map:
model_name = cached_path(model_map[model_name], cache_dir=cache_dir)

return model_name


def create_position_string(head: Span, tail: Span) -> str:
return f"{head.id_text} -> {tail.id_text}"

0 comments on commit 46569a6

Please sign in to comment.