genius / app.py
beyond's picture
Create app.py
44213d9
raw
history blame
758 Bytes
import gradio as gr
from transformers import pipeline
# The English generator is left disabled here; only the Chinese model is loaded.
# pipeline_en = pipeline(task="text2text-generation", model="beyond/genius-large")

# Chinese text2text-generation pipeline, built once at import time and read
# by predict_zh on every call.
pipeline_zh = pipeline(
    model="beyond/genius-base-chinese",
    task="text2text-generation",
)
def predict_en(sketch):
    """Generate English text from a sketch with the ``beyond/genius-large`` model.

    A module-level ``pipeline_en`` is used when one exists. When it does not
    (no eager definition is active in this file), the pipeline is built on
    the first call and cached in the module global, instead of the call
    failing with ``NameError``.

    Args:
        sketch: Input text handed unchanged to the text2text-generation
            pipeline.

    Returns:
        The ``'generated_text'`` field of the first result returned by the
        pipeline.
    """
    global pipeline_en
    try:
        generator = pipeline_en
    except NameError:
        # Lazy construction: the English model is only loaded if this
        # function is actually called, and only once.
        generator = pipeline_en = pipeline(
            task="text2text-generation", model="beyond/genius-large"
        )
    results = generator(sketch, num_beams=3, do_sample=True, max_length=200)
    return results[0]['generated_text']
def predict_zh(sketch):
    """Generate text from a sketch using the module-level ``pipeline_zh``.

    Args:
        sketch: Input text handed unchanged to the pipeline.

    Returns:
        The ``'generated_text'`` field of the first result returned by the
        pipeline.
    """
    # Beam search with sampling, 3 beams, output capped at max_length=200.
    candidates = pipeline_zh(sketch, num_beams=3, do_sample=True, max_length=200)
    best = candidates[0]
    return best['generated_text']
# Build and serve the web demo: a multi-line sketch box in, generated text out.
# The output is a Textbox rather than a classification Label because
# predict_zh returns a single generated string, not class confidences
# (num_top_classes has no meaning for a bare string).
gr.Interface(
    predict_zh,
    inputs=gr.inputs.Textbox(lines=7, placeholder='Input your sketch', label='Input'),
    outputs=gr.outputs.Textbox(label='Output'),
    title="Sketch-based Text Generation",
).launch()