Skip to content

Commit 33626a7

Browse files
committed
improve training for Seq2SeqRelationExtractor
1 parent cba71f4 commit 33626a7

2 files changed

Lines changed: 41 additions & 26 deletions

File tree

renard/pipeline/preconfigured.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def relational_pipeline(
128128
"""
129129
from renard.pipeline.tokenization import NLTKTokenizer
130130
from renard.pipeline.character_unification import GraphRulesCharacterUnifier
131-
from renard.pipeline.relation_extraction import GenerativeRelationExtractor
131+
from renard.pipeline.relation_extraction import Seq2SeqRelationExtractor
132132
from renard.pipeline.graph_extraction import RelationalGraphExtractor
133133

134134
tokenizer_kwargs = tokenizer_kwargs or {}
@@ -142,7 +142,7 @@ def relational_pipeline(
142142
NLTKTokenizer(**tokenizer_kwargs),
143143
ner_step(**ner_kwargs),
144144
GraphRulesCharacterUnifier(**character_unifier_kwargs),
145-
GenerativeRelationExtractor(**relation_extractor_kwargs),
145+
Seq2SeqRelationExtractor(**relation_extractor_kwargs),
146146
RelationalGraphExtractor(**graph_extractor_kwargs),
147147
],
148148
**pipeline_kwargs,

renard/pipeline/relation_extraction.py

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from typing import Any, Union, Optional, Literal
22
import ast, re
33
import functools as ft
4-
from datasets import load_dataset, Dataset as HFDataset
4+
from dataclasses import dataclass
5+
from datasets import load_dataset, Dataset as HFDataset, DatasetDict as HFDatasetDict
56
import torch
67
from transformers import (
78
AutoModelForSeq2SeqLM,
@@ -23,9 +24,20 @@
2324
from renard.utils import make_vocab
2425
from sklearn.metrics import precision_recall_fscore_support
2526

26-
#: (subject, relation, object)
27+
#: (subject, predicate, object)
2728
Relation = tuple[Character, str, Character]
2829

30+
ARF_VALID_NOVELS = {
31+
"Blue Jackets: The Log of the Teaser",
32+
"Nightmare Abbey",
33+
"The White Chief of the Caffres",
34+
}
35+
ARF_TEST_NOVELS = {
36+
"Molly Brown's Freshman Days",
37+
"Ancient Rome: The Lives of Great Men",
38+
"The White Chief of the Caffres",
39+
}
40+
2941

3042
def _load_ARF_line(example: dict, tokenizer: PreTrainedTokenizerFast) -> BatchEncoding:
3143
relations = ast.literal_eval(example["relations"] or "[]")
@@ -37,7 +49,7 @@ def format_rel(rel: dict) -> str:
3749

3850
text = example["chunk"] or ""
3951
batch = tokenizer(
40-
tokenizer.bos_token + GenerativeRelationExtractor.task_prompt(text),
52+
tokenizer.bos_token + Seq2SeqRelationExtractor.task_prompt(text),
4153
text_target=labels + tokenizer.eos_token,
4254
add_special_tokens=False,
4355
)
@@ -46,7 +58,7 @@ def format_rel(rel: dict) -> str:
4658
return batch
4759

4860

49-
def load_ARF_dataset(tokenizer: PreTrainedTokenizerFast) -> HFDataset:
61+
def load_ARF_dataset(tokenizer: PreTrainedTokenizerFast) -> HFDatasetDict:
5062
"""
5163
Load the Artificial Relationships in Fiction dataset
5264
(https://huggingface.co/datasets/Despina/project_gutenberg) by
@@ -57,8 +69,15 @@ def load_ARF_dataset(tokenizer: PreTrainedTokenizerFast) -> HFDataset:
5769
"synthetic_relations_in_fiction_books",
5870
split="train",
5971
)
60-
dataset = dataset.train_test_split(test_size=0.001)
61-
return dataset.map(ft.partial(_load_ARF_line, tokenizer=tokenizer))
72+
73+
dataset = dataset.map(ft.partial(_load_ARF_line, tokenizer=tokenizer))
74+
75+
ARF_TRAIN_NOVELS = set(dataset["title"]) - (ARF_VALID_NOVELS | ARF_TEST_NOVELS)
76+
train = dataset.filter(lambda example: example["title"] in ARF_TRAIN_NOVELS)
77+
valid = dataset.filter(lambda example: example["title"] in ARF_VALID_NOVELS)
78+
test = dataset.filter(lambda example: example["title"] in ARF_TEST_NOVELS)
79+
80+
return HFDatasetDict({"train": train, "valid": valid, "test": test}) # type: ignore
6281

6382

6483
def _triple_precision_recall_f1(
@@ -108,25 +127,27 @@ def train_model_on_ARF(
108127

109128
dataset = load_ARF_dataset(tokenizer)
110129

111-
def compute_metrics(eval_preds: EvalPrediction) -> dict[str, float]:
130+
def compute_metrics(eval_preds) -> dict[str, float]:
112131
eval_preds.label_ids[eval_preds.label_ids == -100] = pad_token_i
132+
eval_preds.predictions[eval_preds.predictions == -100] = pad_token_i
113133

114134
labels_str = tokenizer.batch_decode(
115135
eval_preds.label_ids, skip_special_tokens=True
116136
)
117-
labels = list(map(GenerativeRelationExtractor.parse_text_relations, labels_str))
137+
labels = list(map(Seq2SeqRelationExtractor.parse_text_relations, labels_str))
118138

119-
pred_ids = eval_preds.predictions[0].argmax(axis=-1)
120-
preds_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
121-
preds = list(map(GenerativeRelationExtractor.parse_text_relations, preds_str))
139+
preds_str = tokenizer.batch_decode(
140+
eval_preds.predictions, skip_special_tokens=True
141+
)
142+
preds = list(map(Seq2SeqRelationExtractor.parse_text_relations, preds_str))
122143

123144
return _triple_precision_recall_f1(labels, preds)
124145

125146
trainer = Trainer(
126147
model,
127148
targs,
128149
train_dataset=dataset["train"],
129-
eval_dataset=dataset["test"],
150+
eval_dataset=dataset["valid"],
130151
data_collator=DataCollatorForSeq2Seq(tokenizer, model),
131152
compute_metrics=compute_metrics,
132153
)
@@ -135,7 +156,7 @@ def compute_metrics(eval_preds: EvalPrediction) -> dict[str, float]:
135156
return model
136157

137158

138-
class GenerativeRelationExtractor(PipelineStep):
159+
class Seq2SeqRelationExtractor(PipelineStep):
139160
"""
140161
141162
.. warning::
@@ -151,9 +172,7 @@ def __init__(
151172
batch_size: int = 1,
152173
device: Literal["cpu", "cuda", "auto"] = "auto",
153174
):
154-
self.model = (
155-
GenerativeRelationExtractor.DEFAULT_MODEL if model is None else model
156-
)
175+
self.model = Seq2SeqRelationExtractor.DEFAULT_MODEL if model is None else model
157176
self.hf_pipeline = None
158177
self.batch_size = batch_size
159178
if device == "auto":
@@ -180,7 +199,7 @@ def __call__(
180199
# chunk as in the ARF dataset
181200
dataset = HFDataset.from_list(
182201
[
183-
{"text": GenerativeRelationExtractor.task_prompt(" ".join(sent))}
202+
{"text": Seq2SeqRelationExtractor.task_prompt(" ".join(sent))}
184203
for sent in sentences
185204
]
186205
)
@@ -190,17 +209,13 @@ def __call__(
190209
):
191210
text_relations = out[0]["generated_text"]
192211

193-
raw_triples = GenerativeRelationExtractor.parse_text_relations(
194-
text_relations
195-
)
212+
raw_triples = Seq2SeqRelationExtractor.parse_text_relations(text_relations)
196213
triples = []
197214
for subj, rel, obj in raw_triples:
198-
subj_char = GenerativeRelationExtractor.identify_character(
215+
subj_char = Seq2SeqRelationExtractor.identify_character(
199216
subj, characters
200217
)
201-
obj_char = GenerativeRelationExtractor.identify_character(
202-
obj, characters
203-
)
218+
obj_char = Seq2SeqRelationExtractor.identify_character(obj, characters)
204219
if subj_char is None or obj_char is None or subj_char == obj_char:
205220
continue
206221
triples.append((subj_char, rel, obj_char))

0 commit comments

Comments (0)