File size: 4,938 Bytes
ced4316
3133b5e
d868d2e
3133b5e
 
ced4316
3133b5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ced4316
 
3133b5e
 
 
ced4316
 
 
 
 
 
 
 
 
 
 
e7eaeed
ced4316
 
 
3133b5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ced4316
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d868d2e
3133b5e
ced4316
d868d2e
 
 
ced4316
 
 
 
d868d2e
ced4316
 
 
 
 
 
 
 
 
 
 
3133b5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import json
import logging
from typing import Iterable, Optional, Sequence

import gradio as gr
from hydra.utils import instantiate
from pie_modules.document.processing import RegexPartitioner, SpansViaRelationMerger

# this is required to dynamically load the PIE models
from pie_modules.models import *  # noqa: F403
from pie_modules.taskmodules import *  # noqa: F403
from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
from pytorch_ie import Pipeline
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import (
    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
    TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
)

# this is required to dynamically load the PIE models
from pytorch_ie.models import *  # noqa: F403
from pytorch_ie.taskmodules import *  # noqa: F403

from src.utils import parse_config

logger = logging.getLogger(__name__)


def get_merger() -> SpansViaRelationMerger:
    """Build the SpansViaRelationMerger used to post-process model output.

    The merger is configured with the "parts_of_same" link relation from the
    "binary_relations" layer, creates multi-spans, combines scores via "product",
    and produces a TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions.

    Returns:
        A freshly constructed merger instance.
    """
    # input layer name -> layer name in the result document
    field_mapping = {
        "labeled_spans": "labeled_multi_spans",
        "binary_relations": "binary_relations",
        "labeled_partitions": "labeled_partitions",
    }
    merger = SpansViaRelationMerger(
        relation_layer="binary_relations",
        link_relation_label="parts_of_same",
        create_multi_spans=True,
        result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
        result_field_mapping=field_mapping,
        combine_scores_method="product",
    )
    return merger


def create_document(
    text: str, doc_id: str, split_regex: Optional[str] = None
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
    """Wrap a raw text into a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions.

    Parameters:
        text: The text to process.
        doc_id: The ID of the document.
        split_regex: Optional regular expression used by a RegexPartitioner to split the
            text into partitions. If None, the whole text becomes a single partition
            labeled "text".

    Returns:
        The processed document, with its "labeled_partitions" layer populated.
    """
    doc = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
        id=doc_id, text=text, metadata={}
    )

    if split_regex is None:
        # The model only considers text inside partitions, so cover the whole text
        # with one partition.
        whole_text = LabeledSpan(start=0, end=len(text), label="text")
        doc.labeled_partitions.append(whole_text)
        return doc

    splitter = RegexPartitioner(pattern=split_regex, partition_layer_name="labeled_partitions")
    return splitter(doc)


def create_documents(
    texts: Iterable[str], doc_ids: Iterable[str], split_regex: Optional[str] = None
) -> Sequence[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]:
    """Create one document per (text, doc_id) pair via create_document.

    Texts and IDs are paired with zip, so iteration stops at the shorter of the two.

    Parameters:
        texts: The texts to process.
        doc_ids: The IDs of the documents, in the same order as texts.
        split_regex: A regular expression pattern to use for splitting each text into
            partitions (forwarded to create_document).

    Returns:
        A list with the processed documents.
    """
    pairs = zip(texts, doc_ids)
    return [
        create_document(text=raw_text, doc_id=identifier, split_regex=split_regex)
        for raw_text, identifier in pairs
    ]


def load_argumentation_model(config_str: str, **kwargs) -> Optional[Pipeline]:
    """Instantiate an argumentation model pipeline from a YAML config string.

    The config is parsed with ``parse_config(..., format="yaml")`` and instantiated via
    Hydra's ``instantiate``. If the config targets ``...AutoPipeline.from_pretrained`` and
    has a top-level ``revision``, that revision is moved into both ``taskmodule_kwargs``
    and ``model_kwargs``.

    Parameters:
        config_str: The model config as a YAML string.
        **kwargs: Extra keyword arguments forwarded to ``hydra.utils.instantiate``.

    Returns:
        The instantiated model. Returns None if the config is empty; in that case a
        Gradio warning is shown.

    Raises:
        gr.Error: If parsing or instantiating the config fails. The original exception
            is attached as ``__cause__``.
    """
    try:
        config = parse_config(config_str, format="yaml")
        if config is None or config == {}:
            gr.Warning("Empty argumentation model config provided. No model loaded.")
            return None

        # For PIE AutoPipeline, the revision has to be passed separately to the
        # taskmodule and to the model.
        # `or ""` also covers an explicit `_target_: null`.
        target = config.get("_target_") or ""
        if target.strip().endswith("AutoPipeline.from_pretrained") and "revision" in config:
            revision = config.pop("revision")
            for sub_key in ("taskmodule_kwargs", "model_kwargs"):
                # `or {}` also covers a key that is present but empty (None) in the YAML.
                sub_kwargs = config.get(sub_key) or {}
                sub_kwargs["revision"] = revision
                config[sub_key] = sub_kwargs

        model = instantiate(config, **kwargs)
        # default=str: kwargs/config values need not be JSON-serializable. Without it,
        # a successful load would be reported as a failure below.
        gr.Info(f"Loaded argumentation model: {json.dumps({**config, **kwargs}, default=str)}")
    except Exception as e:
        raise gr.Error(f"Failed to load argumentation model: {e}") from e

    return model


def set_relation_types(
    argumentation_model: Pipeline,
    default: Optional[Sequence[str]] = None,
) -> gr.Dropdown:
    """Build a multiselect dropdown listing the relation labels of the model's taskmodule.

    Parameters:
        argumentation_model: The loaded pipeline. Its taskmodule must be a
            PointerNetworkTaskModuleForEnd2EndRE.
        default: Initially selected relation types, if any.

    Returns:
        A gr.Dropdown whose choices are the taskmodule's "binary_relations" labels.

    Raises:
        gr.Error: If the taskmodule type is not supported.
    """
    taskmodule = argumentation_model.taskmodule
    if not isinstance(taskmodule, PointerNetworkTaskModuleForEnd2EndRE):
        raise gr.Error("Unsupported taskmodule for relation types")

    return gr.Dropdown(
        choices=taskmodule.labels_per_layer["binary_relations"],
        label="Argumentative Relation Types",
        value=default,
        multiselect=True,
    )