File size: 1,596 Bytes
9b9d9ed
f61705b
9b9d9ed
 
f61705b
9b9d9ed
f61705b
9b9d9ed
 
 
f61705b
9b9d9ed
f61705b
9b9d9ed
f61705b
9b9d9ed
 
 
f61705b
9b9d9ed
 
 
 
 
f61705b
 
9b9d9ed
 
 
f61705b
 
9b9d9ed
f61705b
9b9d9ed
 
 
 
 
 
7365157
d782de7
 
7365157
9b9d9ed
f61705b
 
 
7365157
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# + tags=["hide_inp"]

# Markdown blurb for the demo, handed to ``show(..., description=desc)`` below
# (presumably rendered as the header of the Gradio page -- confirm in minichain).
# NOTE: never put a dollar-sign character in this text. The ``code=`` argument
# to ``show`` is built by splitting this very file on that character and taking
# piece [1]; an extra one up here would shift which section is displayed.
desc = """
### Named Entity Recognition

Chain that does named entity recognition with arbitrary labels. [[Code](https://github.com/srush/MiniChain/blob/main/examples/ner.py)]

(Adapted from [promptify](https://github.com/promptslab/Promptify/blob/main/promptify/prompts/nlp/templates/ner.jinja)).
"""
# -

# $

from minichain import prompt, show, OpenAI

@prompt(OpenAI(), template_file = "ner.pmpt.tpl", parser="json")
def ner_extract(model, inputs=None, **kwargs):
    """Forward the NER template variables to ``model`` and return its reply.

    The decorator configures the ``ner.pmpt.tpl`` template and a ``"json"``
    parser; presumably ``model`` renders the template with these variables
    and returns the parsed JSON (a list of ``{"T": label, "E": entity}``
    records, judging by how ``team_describe`` consumes it -- confirm).

    Args:
        model: callable injected by the ``@prompt`` decorator.
        inputs: optional mapping of template variables (``text_input``,
            ``labels``, ``domain``) passed positionally, as ``ner`` does.
        **kwargs: the same variables passed as keyword arguments; when both
            forms are given, keyword values override ``inputs``.

    Returns:
        Whatever ``model`` returns for the assembled variables.
    """
    # The original ``**kwargs``-only signature could not accept the single
    # positional dict that ``ner`` passes; accept both call styles.
    if inputs is None:
        return model(kwargs)
    if not kwargs:
        # Pass the caller's object through untouched (no copy/conversion).
        return model(inputs)
    return model({**inputs, **kwargs})

@prompt(OpenAI())
def team_describe(model, inp):
    """Ask the model to describe every entity labelled ``"Team"`` in *inp*.

    Args:
        model: callable injected by the ``@prompt`` decorator; it is called
            once with the finished query string.
        inp: iterable of entity records indexed by ``"T"`` (label) and
            ``"E"`` (entity text) -- presumably the parsed output of
            ``ner_extract``.

    Returns:
        The result of calling ``model`` on the query.
    """
    prefix = "Can you describe these basketball teams? "
    team_names = (record["E"] for record in inp if record["T"] == "Team")
    return model(prefix + " ".join(team_names))


def ner(text_input, labels, domain):
    """Chain entity extraction into a description of the extracted teams.

    Args:
        text_input: passage to tag.
        labels: entity labels to look for (the demo passes a string such
            as ``"Team, Date"``).
        domain: domain hint for the prompt (the demo passes ``"Sports"``).

    Returns:
        The result of ``team_describe`` applied to ``ner_extract``'s result.
    """
    template_vars = {"text_input": text_input, "labels": labels, "domain": domain}
    entities = ner_extract(template_vars)
    return team_describe(entities)


# $

def _demo_code(path="ner.py", marker="$"):
    """Return the code section of *path* delimited by the first two *marker*s.

    The section is stripped of surrounding whitespace and of the ``#`` that
    starts the closing ``# <marker>`` line. The file is opened with ``with``
    so the handle is closed (the former inline ``open(...).read()`` leaked it).
    The path stays relative to the working directory, like the
    ``template_file`` used by ``ner_extract``.
    """
    with open(path, "r", encoding="utf-8") as source_file:
        source = source_file.read()
    return source.split(marker)[1].strip().strip("#").strip()


# Build the demo app: the full chain, an example input row, the markdown
# description, both prompts for step-by-step display, and this file's code.
gradio = show(ner,
              examples=[["An NBA playoff pairing a year ago, the 76ers (39-20) meet the Miami Heat (32-29) for the first time this season on Monday night at home.", "Team, Date", "Sports"]],
              description=desc,
              subprompts=[ner_extract, team_describe],
              code=_demo_code(),
              )

if __name__ == "__main__":
    gradio.launch()


# View prompt examples.

# + tags=["hide_inp"]
# NERPrompt().show(
#     {
#         "input": "I went to New York",
#         "domain": "Travel",
#         "labels": ["City"]
#     },
#     '[{"T": "City", "E": "New York"}]',
# )
# # -

# # View log.

# minichain.show_log("ner.log")