update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
d868d2e
verified
import json | |
import logging | |
from typing import Iterable, Optional, Sequence | |
import gradio as gr | |
from hydra.utils import instantiate | |
from pie_modules.document.processing import RegexPartitioner, SpansViaRelationMerger | |
# this is required to dynamically load the PIE models | |
from pie_modules.models import * # noqa: F403 | |
from pie_modules.taskmodules import * # noqa: F403 | |
from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE | |
from pytorch_ie import Pipeline | |
from pytorch_ie.annotations import LabeledSpan | |
from pytorch_ie.documents import ( | |
TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions, | |
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, | |
) | |
# this is required to dynamically load the PIE models | |
from pytorch_ie.models import * # noqa: F403 | |
from pytorch_ie.taskmodules import * # noqa: F403 | |
from src.utils import parse_config | |
logger = logging.getLogger(__name__) | |
def get_merger() -> SpansViaRelationMerger:
    """Build the SpansViaRelationMerger used by this module.

    The merger is configured to read relations from the "binary_relations" layer, to treat
    relations labeled "parts_of_same" as links, to create multi spans, and to combine scores
    with the "product" method.

    Returns:
        The configured merger; its result document type is
        TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions.
    """
    # layer names on the input document -> layer names on the result document
    layer_mapping = {
        "labeled_spans": "labeled_multi_spans",
        "binary_relations": "binary_relations",
        "labeled_partitions": "labeled_partitions",
    }
    merger = SpansViaRelationMerger(
        relation_layer="binary_relations",
        link_relation_label="parts_of_same",
        create_multi_spans=True,
        result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
        result_field_mapping=layer_mapping,
        combine_scores_method="product",
    )
    return merger
def create_document(
    text: str, doc_id: str, split_regex: Optional[str] = None
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
    """Wrap raw text in a document and give it partitions.

    Parameters:
        text: The text to process.
        doc_id: The ID assigned to the new document.
        split_regex: Optional regular expression pattern handed to a RegexPartitioner that
            fills the "labeled_partitions" layer. When None, a single partition labeled
            "text" spanning the whole text is added instead.

    Returns:
        The document returned by the partitioner, or the freshly created document with its
        single whole-text partition.
    """
    doc = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
        id=doc_id, text=text, metadata={}
    )

    if split_regex is None:
        # the model only considers text in partitions, so cover the whole text with one
        whole_text = LabeledSpan(start=0, end=len(text), label="text")
        doc.labeled_partitions.append(whole_text)
        return doc

    split_by_regex = RegexPartitioner(
        pattern=split_regex, partition_layer_name="labeled_partitions"
    )
    return split_by_regex(doc)
def create_documents(
    texts: Iterable[str], doc_ids: Iterable[str], split_regex: Optional[str] = None
) -> Sequence[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]:
    """Create one document per (text, doc_id) pair via create_document.

    Parameters:
        texts: The texts to process.
        doc_ids: The IDs of the documents, paired with the texts by position.
        split_regex: A regular expression pattern forwarded unchanged to create_document
            for every text.

    Returns:
        A list with the created documents, in input order. The texts and IDs are paired
        with zip, so pairing stops at the shorter of the two iterables.
    """
    documents = []
    for current_text, current_id in zip(texts, doc_ids):
        documents.append(
            create_document(text=current_text, doc_id=current_id, split_regex=split_regex)
        )
    return documents
def load_argumentation_model(config_str: str, **kwargs) -> Optional[Pipeline]:
    """Instantiate an argumentation model from a YAML config string.

    Parameters:
        config_str: The config in YAML format. It is parsed with parse_config and the
            result is passed to hydra's instantiate.
        **kwargs: Additional keyword arguments forwarded to instantiate. They are also
            included in the info message shown after loading.

    Returns:
        The instantiated object, or None if the parsed config is None or an empty dict
        (a gr.Warning is issued in that case).

    Raises:
        gr.Error: If parsing the config or instantiating the model fails. The original
            exception is attached as the cause.
    """
    try:
        config = parse_config(config_str, format="yaml")
        if config is None or config == {}:
            gr.Warning("Empty argumentation model config provided. No model loaded.")
            return None

        # for PIE AutoPipeline, we need to handle the revision separately for
        # the taskmodule and the model
        if (
            config.get("_target_", "").strip().endswith("AutoPipeline.from_pretrained")
            and "revision" in config
        ):
            revision = config.pop("revision")
            for kwargs_key in ("taskmodule_kwargs", "model_kwargs"):
                config.setdefault(kwargs_key, {})["revision"] = revision

        model = instantiate(config, **kwargs)

        # default=str keeps this purely informational message from turning a successful
        # load into a failure when config or kwargs hold values JSON cannot encode
        loaded_settings = json.dumps({**config, **kwargs}, default=str)
        gr.Info(f"Loaded argumentation model: {loaded_settings}")
    except Exception as e:
        # chain explicitly so the root cause stays visible in the traceback
        raise gr.Error(f"Failed to load argumentation model: {e}") from e

    return model
def set_relation_types(
    argumentation_model: Pipeline,
    default: Optional[Sequence[str]] = None,
) -> gr.Dropdown:
    """Create the multi-select dropdown for argumentative relation types.

    Parameters:
        argumentation_model: The pipeline whose taskmodule provides the relation labels.
        default: The value(s) preselected in the dropdown.

    Returns:
        A gr.Dropdown whose choices are the taskmodule's labels for the
        "binary_relations" layer.

    Raises:
        gr.Error: If the taskmodule is not a PointerNetworkTaskModuleForEnd2EndRE.
    """
    taskmodule = argumentation_model.taskmodule
    if not isinstance(taskmodule, PointerNetworkTaskModuleForEnd2EndRE):
        raise gr.Error("Unsupported taskmodule for relation types")

    available_relation_types = taskmodule.labels_per_layer["binary_relations"]
    return gr.Dropdown(
        choices=available_relation_types,
        label="Argumentative Relation Types",
        value=default,
        multiselect=True,
    )