import gradio as gr
from NeuralTextGenerator import BertTextGenerator
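# NeuralTextGenerator is assumed to be a helper module bundled with this Space:
# it wraps a BERT masked-LM checkpoint and exposes a .generate() method for
# iterative text sampling, along with the underlying .tokenizer and .model.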
# BERT checkpoint fine-tuned for sentiment-conditioned tweet generation.
model_name = "JuanJoseMV/BERT_text_gen"  # "dbmdz/bert-base-italian-uncased"
en_model = BertTextGenerator(model_name)
tokenizer = en_model.tokenizer
model = en_model.model
device = model.device


def classify(sentiment):
    # Despite its name, this callback generates tweets conditioned on the
    # requested sentiment rather than classifying anything.
    parameters = {
        'n_sentences': 10,                # tweets generated per click
        'batch_size': 2,
        'avg_len': 30,                    # target average length (tokens)
        'max_len': 50,                    # hard length cap
        # 'std_len': 3,
        'generation_method': 'parallel',
        'sample': True,
        'burnin': 450,
        'max_iter': 500,
        'top_k': 100,
        # Sentiment control tokens used to seed the generation.
        'seed_text': f"[{sentiment}-0] [{sentiment}-1] [{sentiment}-2]",
        # 'verbose': True,
    }
    sents = en_model.generate(**parameters)
    gen_text = '\n'.join(sents)
    return gen_text
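
# Optional sanity check when running this file directly (assumes the checkpoint
# downloads and NeuralTextGenerator behaves as sketched above):
#   print(classify("POSITIVE"))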

demo = gr.Blocks()
with demo:
    gr.Markdown("Generate tweets conditioned on a target sentiment.")
    # Selectable sentiment control tags; POSITIVE is pre-selected.
    inputs = gr.Dropdown(choices=["POSITIVE", "NEGATIVE"], value="POSITIVE",
                         label="Sentiment to generate")
    output = gr.Textbox(label="Generated tweets")
    b1 = gr.Button("Generate")
    b1.click(classify, inputs=inputs, outputs=output)

demo.launch()
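
# Gradio's launch() also accepts share=True to expose a temporary public URL,
# which can help when debugging locally; the hosted Space does not need it.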