Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from openprompt.plms import load_plm | |
| from openprompt import PromptDataLoader | |
| from openprompt.prompts import ManualVerbalizer | |
| from openprompt.prompts import ManualTemplate | |
| from openprompt.data_utils import InputExample | |
| from openprompt import PromptForClassification | |
def sentiment_analysis(sentence, template, model_name, positive, neutral, negative):
    """Classify the sentiment of ``sentence`` with a prompt-based model.

    Args:
        sentence: Text to classify; it fills the ``[SENTENCE]`` slot.
        template: Prompt text in which ``[SENTENCE]`` marks where the input
            goes and ``[MASK]`` marks the position the model predicts.
        model_name: Checkpoint name; must be one of the keys of ``type_dic``
            below.
        positive: Whitespace-separated label words for the positive class.
        neutral: Whitespace-separated label words for the neutral class.
        negative: Whitespace-separated label words for the negative class.

    Returns:
        One of ``'positive'``, ``'neutral'`` or ``'negative'``.

    Raises:
        KeyError: If ``model_name`` is not a supported checkpoint.
        ValueError: If any class has no label words.
    """
    # Maps each supported checkpoint to the model family name that
    # ``load_plm`` takes as its first argument.
    type_dic = {
        "CCCC/ARCH_tuned_bert": "bert",
        "bert-base-uncased": "bert",
        "roberta-base": "roberta",
        "yiyanghkust/finbert-pretrain": "bert",
        "facebook/opt-125m": "opt",
        "facebook/opt-350m": "opt",
    }
    # Validate cheap inputs first so a bad request fails with a readable
    # message before any model is loaded.
    if model_name not in type_dic:
        raise KeyError(
            f"Unsupported model {model_name!r}; expected one of: "
            + ", ".join(type_dic)
        )

    classes = ['positive', 'neutral', 'negative']
    # ``split()`` without an argument collapses runs of whitespace, so
    # doubled or trailing spaces never produce empty-string label words.
    label_words = {
        "positive": (positive or "").split(),
        "neutral": (neutral or "").split(),
        "negative": (negative or "").split(),
    }
    missing = [name for name in classes if not label_words[name]]
    if missing:
        raise ValueError("No label words given for: " + ", ".join(missing))

    # Translate the user-facing tokens into OpenPrompt's template syntax.
    template = template.replace('[SENTENCE]', '{"placeholder":"text_a"}')
    template = template.replace('[MASK]', '{"mask"}')

    # The label is a dummy value: only the prediction is used.
    testdata = [InputExample(guid=0, text_a=sentence, label=0)]
    plm, tokenizer, _, WrapperClass = load_plm(type_dic[model_name], model_name)

    promptTemplate = ManualTemplate(
        text=template,
        tokenizer=tokenizer,
    )
    promptVerbalizer = ManualVerbalizer(
        classes=classes,
        label_words=label_words,
        tokenizer=tokenizer,
    )
    test_dataloader = PromptDataLoader(
        dataset=testdata,
        tokenizer=tokenizer,
        template=promptTemplate,
        tokenizer_wrapper_class=WrapperClass,
        batch_size=1,
        max_seq_length=512,
    )
    prompt_model = PromptForClassification(
        plm=plm,
        template=promptTemplate,
        verbalizer=promptVerbalizer,
        freeze_plm=False,  # whether or not to freeze the pretrained language model
    )

    # Inference only: switch to eval mode and skip gradient bookkeeping.
    prompt_model.eval()
    with torch.no_grad():
        # The dataset holds exactly one example, so there is one batch.
        inputs = next(iter(test_dataloader))
        logits = prompt_model(inputs)
    return classes[torch.argmax(logits, dim=-1)[0].item()]
# Gradio UI: one input widget per ``sentiment_analysis`` parameter, in the
# same order as the function signature.
demo = gr.Interface(
    fn=sentiment_analysis,
    inputs=[
        gr.Textbox(placeholder="Enter sentence here.", label="sentence"),
        gr.Textbox(
            placeholder="Your template must have a [SENTENCE] token and a [MASK] token.",
            label="template",
        ),
        # The selected value is passed verbatim as ``model_name`` and looked
        # up in ``type_dic`` inside ``sentiment_analysis``, so every choice
        # must be a key of that mapping.
        gr.Radio(
            choices=[
                "CCCC/ARCH_tuned_bert",
                "bert-base-uncased",
                "roberta-base",
                "yiyanghkust/finbert-pretrain",
                "facebook/opt-125m",
                "facebook/opt-350m",
            ],
            label="model choics",
        ),
        gr.Textbox(placeholder="Separate words with Spaces.", label="positive label words"),
        gr.Textbox(placeholder="Separate words with Spaces.", label="neutral label words"),
        gr.Textbox(placeholder="Separate words with Spaces.", label="negative label words"),
    ],
    outputs="text",
)
demo.launch()