 from typing import Any, Union, Optional, Literal
 import ast, re
 import functools as ft
-from datasets import load_dataset, Dataset as HFDataset
+from dataclasses import dataclass
+from datasets import load_dataset, Dataset as HFDataset, DatasetDict as HFDatasetDict
 import torch
 from transformers import (
     AutoModelForSeq2SeqLM,
…
 from renard.utils import make_vocab
 from sklearn.metrics import precision_recall_fscore_support
 
-#: (subject, relation, object)
+#: (subject, predicate, object)
 Relation = tuple[Character, str, Character]
 
+ARF_VALID_NOVELS = {
+    "Blue Jackets: The Log of the Teaser",
+    "Nightmare Abbey",
+    "The White Chief of the Caffres",
+}
+ARF_TEST_NOVELS = {
+    "Molly Brown's Freshman Days",
+    "Ancient Rome: The Lives of Great Men",
+    "The White Chief of the Caffres",
+}
+
 
 def _load_ARF_line(example: dict, tokenizer: PreTrainedTokenizerFast) -> BatchEncoding:
     relations = ast.literal_eval(example["relations"] or "[]")
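The `relations` column is stored as a Python-literal string, and `ast.literal_eval` is the safe way to parse it (it evaluates literals only and never executes code, unlike `eval`). A minimal sketch with a made-up record; the inner dict keys (`subject`, `predicate`, `object`) are assumptions, only `relations` and `chunk` appear in the diff:

```python
import ast

# Hypothetical ARF-style record: "relations" holds a serialized Python literal.
example = {
    "chunk": "Elizabeth visited Darcy at Pemberley.",
    "relations": "[{'subject': 'Elizabeth', 'predicate': 'visits', 'object': 'Darcy'}]",
}

# `or "[]"` guards against None/empty rows, as in _load_ARF_line above.
relations = ast.literal_eval(example["relations"] or "[]")
assert relations[0]["predicate"] == "visits"
```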
@@ -37,7 +49,7 @@ def format_rel(rel: dict) -> str:
 
     text = example["chunk"] or ""
     batch = tokenizer(
-        tokenizer.bos_token + GenerativeRelationExtractor.task_prompt(text),
+        tokenizer.bos_token + Seq2SeqRelationExtractor.task_prompt(text),
         text_target=labels + tokenizer.eos_token,
         add_special_tokens=False,
     )
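This tokenizer call follows the usual seq2seq convention: the source side gets the BOS token plus the task prompt, while `text_target` produces the `labels` ids for the serialized relations terminated by EOS. A quick way to inspect what lands in each field; the `t5-small` checkpoint and the prompt string are placeholders (note T5 tokenizers have no `bos_token`, so it is omitted here):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("t5-small")  # any seq2seq tokenizer works
batch = tok("extract relations: some text", text_target="(A, loves, B)")
print(batch["input_ids"])  # encoder input ids
print(batch["labels"])     # decoder target ids
```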
@@ -46,7 +58,7 @@ def format_rel(rel: dict) -> str:
     return batch
 
 
-def load_ARF_dataset(tokenizer: PreTrainedTokenizerFast) -> HFDataset:
+def load_ARF_dataset(tokenizer: PreTrainedTokenizerFast) -> HFDatasetDict:
     """
     Load the Artificial Relationships in Fiction dataset
     (https://huggingface.co/datasets/Despina/project_gutenberg) by
@@ -57,8 +69,15 @@ def load_ARF_dataset(tokenizer: PreTrainedTokenizerFast) -> HFDataset:
         "synthetic_relations_in_fiction_books",
         split="train",
     )
-    dataset = dataset.train_test_split(test_size=0.001)
-    return dataset.map(ft.partial(_load_ARF_line, tokenizer=tokenizer))
+
+    dataset = dataset.map(ft.partial(_load_ARF_line, tokenizer=tokenizer))
+
+    ARF_TRAIN_NOVELS = set(dataset["title"]) - (ARF_VALID_NOVELS | ARF_TEST_NOVELS)
+    train = dataset.filter(lambda example: example["title"] in ARF_TRAIN_NOVELS)
+    valid = dataset.filter(lambda example: example["title"] in ARF_VALID_NOVELS)
+    test = dataset.filter(lambda example: example["title"] in ARF_TEST_NOVELS)
+
+    return HFDatasetDict({"train": train, "valid": valid, "test": test})  # type: ignore
 
 
 def _triple_precision_recall_f1(
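Splitting by novel title rather than with `train_test_split` keeps every chunk of a given book in a single split, avoiding near-duplicate leakage between train and eval. A usage sketch; the tokenizer checkpoint is a placeholder:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # placeholder checkpoint
dataset = load_ARF_dataset(tokenizer)

# No novel should appear in both train and valid.
train_titles = set(dataset["train"]["title"])
valid_titles = set(dataset["valid"]["title"])
assert not train_titles & valid_titles
```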
@@ -108,25 +127,27 @@ def train_model_on_ARF(
 
     dataset = load_ARF_dataset(tokenizer)
 
-    def compute_metrics(eval_preds: EvalPrediction) -> dict[str, float]:
+    def compute_metrics(eval_preds) -> dict[str, float]:
         eval_preds.label_ids[eval_preds.label_ids == -100] = pad_token_i
+        eval_preds.predictions[eval_preds.predictions == -100] = pad_token_i
 
         labels_str = tokenizer.batch_decode(
             eval_preds.label_ids, skip_special_tokens=True
         )
-        labels = list(map(GenerativeRelationExtractor.parse_text_relations, labels_str))
+        labels = list(map(Seq2SeqRelationExtractor.parse_text_relations, labels_str))
 
-        pred_ids = eval_preds.predictions[0].argmax(axis=-1)
-        preds_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
-        preds = list(map(GenerativeRelationExtractor.parse_text_relations, preds_str))
+        preds_str = tokenizer.batch_decode(
+            eval_preds.predictions, skip_special_tokens=True
+        )
+        preds = list(map(Seq2SeqRelationExtractor.parse_text_relations, preds_str))
 
         return _triple_precision_recall_f1(labels, preds)
 
     trainer = Trainer(
         model,
         targs,
         train_dataset=dataset["train"],
-        eval_dataset=dataset["test"],
+        eval_dataset=dataset["valid"],
         data_collator=DataCollatorForSeq2Seq(tokenizer, model),
         compute_metrics=compute_metrics,
     )
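Two details in `compute_metrics` are easy to miss. First, `-100` is the ignore index used for padded positions, and `batch_decode` cannot handle it, hence the replacement with the pad id on both labels and predictions. Second, decoding `eval_preds.predictions` directly (instead of the old `argmax` over logits) only makes sense if it already contains generated token ids; with the stock `Trainer` that usually means pairing it with a `preprocess_logits_for_metrics` hook, or switching to `Seq2SeqTrainer` with generation enabled. A hedged sketch of the latter option, not what this commit does:

```python
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

targs = Seq2SeqTrainingArguments(
    output_dir="./arf-model",        # placeholder path
    predict_with_generate=True,      # eval_preds.predictions holds generated ids
    per_device_eval_batch_size=4,
)
# Seq2SeqTrainer is a drop-in replacement accepting the same arguments as Trainer.
```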
@@ -135,7 +156,7 @@ def compute_metrics(eval_preds: EvalPrediction) -> dict[str, float]:
     return model
 
 
-class GenerativeRelationExtractor(PipelineStep):
+class Seq2SeqRelationExtractor(PipelineStep):
     """
 
     .. warning::
@@ -151,9 +172,7 @@ def __init__(
         batch_size: int = 1,
         device: Literal["cpu", "cuda", "auto"] = "auto",
     ):
-        self.model = (
-            GenerativeRelationExtractor.DEFAULT_MODEL if model is None else model
-        )
+        self.model = Seq2SeqRelationExtractor.DEFAULT_MODEL if model is None else model
         self.hf_pipeline = None
         self.batch_size = batch_size
         if device == "auto":
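The body of the `device == "auto"` branch falls outside this hunk; given the `torch` import, it presumably follows the conventional resolution:

```python
import torch

# Typical "auto" resolution: prefer GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
```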
@@ -180,7 +199,7 @@ def __call__(
         # chunk as in the ARF dataset
         dataset = HFDataset.from_list(
             [
-                {"text": GenerativeRelationExtractor.task_prompt(" ".join(sent))}
+                {"text": Seq2SeqRelationExtractor.task_prompt(" ".join(sent))}
                 for sent in sentences
             ]
         )
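Wrapping the per-sentence prompts in an `HFDataset` lets the Hugging Face pipeline stream and batch them internally rather than being called once per sentence. The idiomatic pattern uses `KeyDataset`; a standalone sketch, with the model name and prompt text as placeholders (the `"text"` column matches the dataset built above):

```python
from datasets import Dataset as HFDataset
from transformers import pipeline
from transformers.pipelines.pt_utils import KeyDataset

gen = pipeline("text2text-generation", model="t5-small")  # placeholder model
prompts = HFDataset.from_list([{"text": "extract relations: Anne met Gilbert."}])

# KeyDataset exposes one column so the pipeline can batch over it.
for out in gen(KeyDataset(prompts, "text"), batch_size=2):
    print(out[0]["generated_text"])
```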
@@ -190,17 +209,13 @@ def __call__(
         ):
             text_relations = out[0]["generated_text"]
 
-            raw_triples = GenerativeRelationExtractor.parse_text_relations(
-                text_relations
-            )
+            raw_triples = Seq2SeqRelationExtractor.parse_text_relations(text_relations)
             triples = []
             for subj, rel, obj in raw_triples:
-                subj_char = GenerativeRelationExtractor.identify_character(
+                subj_char = Seq2SeqRelationExtractor.identify_character(
                     subj, characters
                 )
-                obj_char = GenerativeRelationExtractor.identify_character(
-                    obj, characters
-                )
+                obj_char = Seq2SeqRelationExtractor.identify_character(obj, characters)
                 if subj_char is None or obj_char is None or subj_char == obj_char:
                     continue
                 triples.append((subj_char, rel, obj_char))
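Since `Seq2SeqRelationExtractor` subclasses `PipelineStep`, it should compose with Renard's other steps, which supply the `sentences` and `characters` its `__call__` consumes. A sketch of how the renamed step might slot into a pipeline; the import paths follow Renard's documented quickstart, but the exact composition here is an assumption:

```python
from renard.pipeline import Pipeline
from renard.pipeline.tokenization import NLTKTokenizer
from renard.pipeline.ner import NLTKNamedEntityRecognizer
from renard.pipeline.character_unification import GraphRulesCharacterUnifier

pipeline = Pipeline(
    [
        NLTKTokenizer(),                       # produces sentences
        NLTKNamedEntityRecognizer(),
        GraphRulesCharacterUnifier(),          # produces characters
        Seq2SeqRelationExtractor(batch_size=4),
    ]
)

text = open("novel.txt").read()  # placeholder input
out = pipeline(text)
```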