|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gradio as gr |
|
from NeuralTextGenerator import BertTextGenerator |
|
|
|
|
|
|
|
|
|
# Hugging Face Hub id of the fine-tuned BERT generator used by this demo.
model_name = "JuanJoseMV/BERT_text_gen"

# Load the generator once at import time so every UI request reuses the same
# instance instead of reloading weights per call.
en_model = BertTextGenerator(model_name)

# Convenience handles on the generator's underlying tokenizer/model; they are
# not referenced by the UI code visible in this file but are kept as
# module-level names in case other code imports them.
tokenizer = en_model.tokenizer

model = en_model.model

device = model.device
|
|
|
def classify(sentiment): |
|
parameters = {'n_sentences': 10, |
|
'batch_size': 2, |
|
'avg_len':30, |
|
'max_len':50, |
|
|
|
'generation_method':'parallel', |
|
'sample': True, |
|
'burnin': 450, |
|
'max_iter': 500, |
|
'top_k': 100, |
|
'seed_text': f"[{sentiment}-0] [{sentiment}-1] [{sentiment}-2]", |
|
|
|
} |
|
sents = en_model.generate(**parameters) |
|
gen_text = '\n'.join(sents) |
|
|
|
return gen_text |
|
|
|
# Build the UI: a sentiment picker, an output textbox, and a button that runs
# ``classify`` on the selected sentiment and shows its result.
with gr.Blocks() as demo:

    gr.Markdown()

    # ``choices`` lists the selectable options; ``value`` is only the default
    # selection. Passing the list as ``value`` would leave the dropdown with no
    # options and send the whole list to ``classify``.
    inputs = gr.Dropdown(
        choices=["POSITIVE", "NEGATIVE"],
        value="POSITIVE",
        label="Sentiment to generate",
    )

    output = gr.Textbox(label="Generated tweet")

    b1 = gr.Button("Generate")

    b1.click(classify, inputs=inputs, outputs=output)


demo.launch()