import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM
import spacy
from spacy.matcher import Matcher

device = 'cpu'

# Load the GIT image-captioning processor and captioning model, plus a small
# spaCy English pipeline used to parse the generated captions.
processor = AutoProcessor.from_pretrained("microsoft/git-base")
model = AutoModelForCausalLM.from_pretrained("nkasmanoff/git-planet").to(device)

nlp = spacy.load('en_core_web_sm')

def predict(image, max_length=64, device='cpu'):
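    """Caption the input image with the GIT model, then reduce the caption to a
    head-relation-tail knowledge triplet formatted as 'head'---relation--->'tail'.
    """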
    pixel_values = processor(images=image, return_tensors="pt").to(device).pixel_values
    generated_ids = model.generate(pixel_values=pixel_values, max_length=max_length)
    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    relation = get_relation(generated_caption)

    entity_pair = get_entities(generated_caption)

    knowledge_triplet = f"'{entity_pair[0]}'---{relation}--->'{entity_pair[1]}'"

    return knowledge_triplet


def get_entities(sent):
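    """Extract a (subject, object) entity pair from a sentence using the spaCy
    dependency parse. Compound tokens and modifiers seen before a subject/object
    head are prepended to it. Returns [ent1, ent2]; either entry may be empty if
    no subject/object is found.
    """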
    ## chunk 1
    ent1 = ""
    ent2 = ""

    prv_tok_dep = ""  # dependency tag of previous token in the sentence
    prv_tok_text = ""  # previous token in the sentence

    prefix = ""
    modifier = ""

    #############################################################

    for tok in nlp(sent):
        ## chunk 2
        # if token is a punctuation mark then move on to the next token
        if tok.dep_ != "punct":
            # check: token is a compound word or not
            if tok.dep_ == "compound":
                prefix = tok.text
                # if the previous word was also a 'compound', combine it with the current word
                if prv_tok_dep == "compound":
                    prefix = prv_tok_text + " " + tok.text

            # check: token is a modifier or not
            if tok.dep_.endswith("mod"):
                modifier = tok.text
                # if the previous word was a 'compound', combine it with the current word
                if prv_tok_dep == "compound":
                    modifier = prv_tok_text + " " + tok.text

            ## chunk 3
            if "subj" in tok.dep_:
                ent1 = modifier + " " + prefix + " " + tok.text
                prefix = ""
                modifier = ""
                prv_tok_dep = ""
                prv_tok_text = ""

            ## chunk 4
            if "obj" in tok.dep_:
                ent2 = modifier + " " + prefix + " " + tok.text

            ## chunk 5  
            # update variables
            prv_tok_dep = tok.dep_
            prv_tok_text = tok.text
    #############################################################

    return [ent1.strip(), ent2.strip()]




def get_relation(sent):
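    """Return the relation phrase of a sentence: the ROOT token optionally followed
    by a preposition, agent, or adjective, found with a spaCy Matcher. The last
    matching span is returned.
    """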

    doc = nlp(sent)

    # Matcher class object 
    matcher = Matcher(nlp.vocab)

    # define the pattern
    pattern = [{'DEP': 'ROOT'},
               {'DEP': 'prep', 'OP': "?"},
               {'DEP': 'agent', 'OP': "?"},
               {'POS': 'ADJ', 'OP': "?"}]

    matcher.add('matching_pattern', patterns=[pattern])
    matches = matcher(doc)
    k = len(matches) - 1

    span = doc[matches[k][1]:matches[k][2]] 

    return span.text
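
# Illustrative composition (hypothetical caption; actual output depends on the
# caption the model generates and on the spaCy parse): for a caption like
# "the cat sat on the mat", get_relation() would return a span around the ROOT
# verb (e.g. "sat on"), get_entities() the subject/object heads (e.g. "cat",
# "mat"), and predict() would format them as 'cat'---sat on--->'mat'.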



# Gradio UI: image in, knowledge triplet out.
image_input = gr.Image(label="Please upload an image", type='pil')
triplet_output = gr.Textbox(label="Knowledge triplet")


title = "Satellite Image Knowledge Extraction"
description = "Provide an image taken from above, and receive back a corresponding head-relation-tail triplet that can be used to form a knowledge graph."

interface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=triplet_output,
    title=title,
    description=description,
)
interface.launch(debug=True)