# Author: ArneBinder
# Updated from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
# (commit d868d2e, verified)
import json
import logging
from typing import Iterable, Optional, Sequence
import gradio as gr
from hydra.utils import instantiate
from pie_modules.document.processing import RegexPartitioner, SpansViaRelationMerger
# this is required to dynamically load the PIE models
from pie_modules.models import * # noqa: F403
from pie_modules.taskmodules import * # noqa: F403
from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
from pytorch_ie import Pipeline
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import (
TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
)
# this is required to dynamically load the PIE models
from pytorch_ie.models import * # noqa: F403
from pytorch_ie.taskmodules import * # noqa: F403
from src.utils import parse_config
logger = logging.getLogger(__name__)
def get_merger() -> SpansViaRelationMerger:
    """Build the merger that fuses spans connected via "parts_of_same" relations.

    The returned SpansViaRelationMerger reads the "binary_relations" layer,
    creates multi-spans, and emits documents of type
    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions.
    Scores of merged spans are combined by multiplication.
    """
    # map source annotation layers to their counterparts in the result document type
    field_mapping = {
        "labeled_spans": "labeled_multi_spans",
        "binary_relations": "binary_relations",
        "labeled_partitions": "labeled_partitions",
    }
    return SpansViaRelationMerger(
        relation_layer="binary_relations",
        link_relation_label="parts_of_same",
        create_multi_spans=True,
        result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
        result_field_mapping=field_mapping,
        combine_scores_method="product",
    )
def create_document(
    text: str, doc_id: str, split_regex: Optional[str] = None
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
    """Wrap raw text into a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions.

    Parameters:
        text: The text to process.
        doc_id: The ID of the document.
        split_regex: A regular expression pattern to use for splitting the text into partitions.

    Returns:
        The processed document.
    """
    doc = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
        id=doc_id, text=text, metadata={}
    )
    if split_regex is None:
        # no splitting requested: cover the whole text with a single partition,
        # because the model only considers text inside partitions
        doc.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
    else:
        doc = RegexPartitioner(
            pattern=split_regex, partition_layer_name="labeled_partitions"
        )(doc)
    return doc
def create_documents(
    texts: Iterable[str], doc_ids: Iterable[str], split_regex: Optional[str] = None
) -> Sequence[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]:
    """Build one TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions per input text.

    Texts and document IDs are paired up positionally (zip semantics: iteration
    stops at the shorter of the two iterables).

    Parameters:
        texts: The texts to process.
        doc_ids: The IDs of the documents.
        split_regex: A regular expression pattern to use for splitting the text into partitions.

    Returns:
        The processed documents.
    """
    return [
        create_document(text=content, doc_id=identifier, split_regex=split_regex)
        for content, identifier in zip(texts, doc_ids)
    ]
def load_argumentation_model(config_str: str, **kwargs) -> Optional[Pipeline]:
    """Instantiate an argumentation model pipeline from a YAML config string.

    Parameters:
        config_str: YAML config with a Hydra "_target_" entry describing the pipeline.
        kwargs: Additional keyword arguments forwarded to hydra's instantiate().

    Returns:
        The instantiated pipeline, or None if the config is empty.

    Raises:
        gr.Error: If parsing or instantiation fails (original exception chained).
    """
    try:
        config = parse_config(config_str, format="yaml")
        if config is None or config == {}:
            gr.Warning("Empty argumentation model config provided. No model loaded.")
            return None
        # for PIE AutoPipeline, we need to handle the revision separately for
        # the taskmodule and the model
        if (
            config.get("_target_", "").strip().endswith("AutoPipeline.from_pretrained")
            and "revision" in config
        ):
            revision = config.pop("revision")
            if "taskmodule_kwargs" not in config:
                config["taskmodule_kwargs"] = {}
            config["taskmodule_kwargs"]["revision"] = revision
            if "model_kwargs" not in config:
                config["model_kwargs"] = {}
            config["model_kwargs"]["revision"] = revision
        model = instantiate(config, **kwargs)
        # default=str: kwargs may contain non-JSON-serializable values; this is
        # only an informational message, so stringifying them is acceptable and
        # avoids turning a successful load into a spurious failure
        gr.Info(f"Loaded argumentation model: {json.dumps({**config, **kwargs}, default=str)}")
    except Exception as e:
        # chain the cause so the original traceback is preserved for debugging
        raise gr.Error(f"Failed to load argumentation model: {e}") from e
    return model
def set_relation_types(
    argumentation_model: Pipeline,
    default: Optional[Sequence[str]] = None,
) -> gr.Dropdown:
    """Create a multi-select dropdown holding the model's argumentative relation types.

    The choices are read from the taskmodule's "binary_relations" label layer.

    Raises:
        gr.Error: If the taskmodule is not a PointerNetworkTaskModuleForEnd2EndRE.
    """
    taskmodule = argumentation_model.taskmodule
    # guard clause: only the pointer-network taskmodule exposes labels_per_layer
    if not isinstance(taskmodule, PointerNetworkTaskModuleForEnd2EndRE):
        raise gr.Error("Unsupported taskmodule for relation types")
    return gr.Dropdown(
        choices=taskmodule.labels_per_layer["binary_relations"],
        label="Argumentative Relation Types",
        value=default,
        multiselect=True,
    )