File size: 4,768 Bytes
376a51c
 
 
 
 
 
 
 
 
 
1818a6a
376a51c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import subprocess
import sys

# Best-effort dependency bootstrap: if any of the third-party packages cannot
# be imported, install the package set this app expects.
try:
    import gradio as gr
    import spacy
    import glirel
except Exception:
    # `Exception` instead of a bare `except:` so Ctrl-C / SystemExit are not
    # swallowed. Kept broader than ImportError on purpose: a broken install
    # can surface as an error other than ImportError at import time.
    # `sys.executable -m pip` installs into the interpreter that is running
    # this script, rather than whichever `pip` happens to be first on PATH.
    for _requirement in (
        "gradio==4.31.5",
        "spacy",
        "glirel",
        "scipy==1.10.1",
        "numpy==1.26.4",
    ):
        subprocess.run([sys.executable, "-m", "pip", "install", _requirement])

# Best-effort spaCy model bootstrap: only the small model is probed, but when
# the probe fails all three models are downloaded.
try:
    import spacy
    spacy.load("en_core_web_sm")
except Exception:
    for _spacy_model in ("en_core_web_sm", "en_core_web_md", "en_core_web_lg"):
        subprocess.run([sys.executable, "-m", "spacy", "download", _spacy_model])




from typing import Dict, Union
import gradio as gr
from glirel import GLiREL
import spacy

# Example rows for the demo. Each row is
# [input text, spaCy model name, comma-separated relation labels],
# in the same order as the inputs of `process`.
examples = [
    [
        "Amazon, founded by Jeff Bezos, is a leader in e-commerce and cloud computing. The company has also ventured into artificial intelligence and digital streaming.",
        "en_core_web_sm",
        "Founded_By, Located_In, Produces, Operates_In, Works_With, Known_For, Headquartered_In, Partnership_With, Innovates_In, Established_In",
    ],
    [
        "J.K. Rowling, the author of the Harry Potter series, has significantly impacted modern literature. Her books have been translated into numerous languages and adapted into successful films.",
        "en_core_web_sm",
        "Translated_Into, Adapted_Into, Born_In, Author_Of, Known_For, Works_With, Located_In, Writes_For, Produced_By, Published_By"
    ],
    [
        "Apple Inc. was founded by Steve Jobs, Steve Wozniak, and Ronald Wayne in April 1976. The company is headquartered in Cupertino, California.",
        "en_core_web_sm",
        "CO_FOUNDER, HEADQUARTERED_IN, FOUNDED_BY, LOCATED_IN, ESTABLISHED_IN, PARTNERSHIP_WITH, WORKS_WITH, KNOWN_FOR"
    ]

]


# Load the relation extraction model once, at import time; `extract_relations`
# reads this module-level instance.
rel_model = GLiREL.from_pretrained("jackboyla/glirel_beta")

# Loaded spaCy pipelines keyed by model name, so `spacy.load` runs at most
# once per model name in this process instead of on every request.
_NLP_CACHE = {}


# Function to perform Named Entity Recognition
def perform_ner(text, model_name):
    """Run the spaCy pipeline named ``model_name`` over ``text``.

    Args:
        text: Raw input text to analyse.
        model_name: Name passed to ``spacy.load`` (e.g. ``"en_core_web_sm"``).

    Returns:
        The object returned by calling the loaded pipeline on ``text``
        (callers read ``.ents``, ``.text`` and iterate its tokens).

    Raises:
        Whatever ``spacy.load`` raises when the model cannot be loaded; a
        failed load is not cached, so a later call retries.
    """
    nlp = _NLP_CACHE.get(model_name)
    if nlp is None:
        nlp = spacy.load(model_name)
        _NLP_CACHE[model_name] = nlp
    return nlp(text)

# Function to extract relations
def extract_relations(tokens, ner, labels):
    """Predict relations with the module-level ``rel_model`` and rank them.

    Args:
        tokens: Token strings of the input text.
        ner: Entity descriptions forwarded to the model as ``ner=``.
        labels: Candidate relation labels.

    Returns:
        A new list of the model's predictions ordered by their ``'score'``
        entry, highest first. The model is called with ``threshold=0.0``
        and ``top_k=1``.
    """
    def score_of(prediction):
        return prediction['score']

    predictions = rel_model.predict_relations(
        tokens,
        labels,
        threshold=0.0,
        ner=ner,
        top_k=1,
    )
    return sorted(predictions, key=score_of, reverse=True)

def format_ner(text, ner):
    """Build the ``{"text": ..., "entities": [...]}`` dict shown in the NER output.

    Args:
        text: The full input text.
        ner: Either spaCy ``Span`` objects, or a sequence of
            ``[start, end, label, word]`` entries. May be empty or ``None``.

    Returns:
        A dict with the original ``text`` and one entry per entity holding
        ``entity`` (label), ``word``, ``start``, ``end`` and a fixed
        ``score`` of 0.
    """
    # Text without any recognised entity yields an empty sequence; indexing
    # ner[0] below would raise IndexError, so answer with no entities.
    if not ner:
        return {"text": text, "entities": []}

    if isinstance(ner[0], spacy.tokens.Span):
        # if ner is spacy entities; otherwise we assume the format is correct
        ner = [[ent.start_char, ent.end_char, ent.label_, ent.text] for ent in ner]

    return {
        "text": text,
        "entities": [
            {
                "entity": label,
                "word": word,
                "start": start,
                "end": end,
                "score": 0,
            }
            for start, end, label, word in (entity[:4] for entity in ner)
        ],
    }

# Gradio Interface
def process(text, model_name, labels):
    """Run NER and relation extraction for one request.

    Args:
        text: Input text to analyse.
        model_name: spaCy model name forwarded to ``perform_ner``.
        labels: Comma-separated relation labels. Whitespace around each
            label is stripped and empty entries are dropped.

    Returns:
        A ``(formatted_ner, formatted_rel)`` tuple: the dict produced by
        ``format_ner`` and a string with one
        ``head --> label --> tail | score`` line per predicted relation.
    """
    doc = perform_ner(text, model_name)

    if not doc.ents:
        # No entities means no entity pairs to relate, so return empty
        # results without calling the relation model.
        return {"text": doc.text, "entities": []}, ""

    tokens = [token.text for token in doc]
    # Token-level spans with an inclusive end index (spaCy's `Span.end` is
    # exclusive) — presumably the format `predict_relations` expects; confirm
    # against the GLiREL documentation.
    ner = [[ent.start, ent.end - 1, ent.label_, ent.text] for ent in doc.ents]

    # "a, b" must become ["a", "b"], not ["a", " b"]: the label strings are
    # passed to the model as-is and echoed back in the output.
    labels = [label.strip() for label in labels.split(',') if label.strip()]

    relations = extract_relations(tokens, ner, labels)
    print(relations)

    formatted_ner = format_ner(doc.text, doc.ents)
    formatted_rel = "".join(
        f"{item['head_text']} --> {item['label']} --> {item['tail_text']} \t\t| score: {item['score']}\n"
        for item in relations
    )
    return formatted_ner, formatted_rel

# Gradio App Layout
# Components are created inside the `gr.Blocks` context in the order below:
# header, three inputs, two outputs, the trigger button, then the examples.
with gr.Blocks() as demo:

    gr.Markdown("# 🕵️‍♀️GLiREL: Zero-Shot Relation Extraction")
    gr.Markdown("GitHub: https://github.com/jackboyla/GLiREL")

    # Inputs, in the same order as the parameters of `process`.
    text_input = gr.Textbox(label="Input Text", value="Jack lives in London but he was born in Mongolia.")
    model_name_input = gr.Dropdown(choices=["en_core_web_sm", "en_core_web_md", "en_core_web_lg"], label="NER Model", value="en_core_web_sm")
    labels_input = gr.Textbox(label="Relation Labels (comma-separated)", value="country of origin, licensed to broadcast to, father, followed by, characters")

    # Outputs, in the same order as the tuple returned by `process`.
    ner_output = gr.HighlightedText(label="NER")
    rel_output = gr.Textbox(label="Relation Extraction Results")

    # Clicking the button calls `process` with the three inputs and writes
    # its two return values to the two outputs.
    extract_button = gr.Button("Extract Relations")
    extract_button.click(
        fn=process,
        inputs=[text_input, model_name_input, labels_input],
        outputs=[ner_output, rel_output]
    )

    # NOTE(review): this rebinds the module-level name `examples` from the
    # list of example rows to the `gr.Examples` component. It works because
    # the list is read as the first argument before the assignment happens,
    # but the list is no longer reachable under that name afterwards.
    # NOTE(review): `cache_examples=True` presumably makes Gradio run
    # `process` on the example rows to cache their outputs — confirm when
    # that happens against the Gradio docs.
    examples = gr.Examples(
        examples,
        fn=process,
        inputs=[text_input, model_name_input, labels_input],
        outputs=[ner_output, rel_output],
        cache_examples=True,
    )


# Launch the app on port 9989 only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(server_port=9989)