# memero / app.py — Hugging Face Space by andreinigo
# Commit 40895e5 ("Update app.py"), 2.75 kB
import os
import textwrap
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
import numpy as np
import torch
from lavis.models import load_model_and_preprocess
import openai
# Run BLIP-2 on the GPU when one is available. Both branches are wrapped in
# torch.device so `device` always has the same type (previously the CPU
# fallback was the plain string "cpu").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the BLIP-2 captioner (OPT-2.7B variant) in eval mode together with its
# image processors. The third return value (presumably the text processors)
# is not used by this app.
model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_opt", model_type="pretrain_opt2.7b", is_eval=True, device=device
)

# Module-level key used by openai.ChatCompletion in generate_caption
# (pre-1.0 openai SDK style API — confirm the pinned version). A missing
# environment variable raises KeyError here, at import time.
openai.api_key = os.environ["OPENAI_API_KEY"]
def _wrap_meme_text(meme_text):
    """Split meme text into lines about half the text's length long.

    This aims for roughly two lines. Returns an empty list for ``None`` or
    blank text. The wrap width is clamped to at least 1 because
    ``textwrap.wrap`` raises ``ValueError`` for width 0, which
    ``len(text) // 2`` produces for 0- or 1-character text.
    """
    text = (meme_text or "").strip()
    if not text:
        return []
    return textwrap.wrap(text, width=max(1, len(text) // 2))


def _fit_font(draw, lines, max_width, max_size=60, min_size=12, font_path="impact.ttf"):
    """Return a TrueType font, at most ``max_size`` pt, whose lines fit ``max_width``.

    Starts at ``max_size`` and steps down by 4 pt until every line in
    ``lines`` is at most ``max_width`` pixels wide, or until ``min_size`` is
    reached. At ``min_size`` very long lines may still overflow.
    """
    size = max_size
    font = ImageFont.truetype(font_path, size)
    while size > min_size and any(
        draw.textlength(line, font=font) > max_width for line in lines
    ):
        size = max(min_size, size - 4)
        font = ImageFont.truetype(font_path, size)
    return font


def generate_caption(image):
    """Turn an uploaded image into a Spanish meme.

    Captions ``image`` with BLIP-2 and asks GPT-4 for a Spanish meme based on
    that caption. The meme text is drawn, centred, near the top of an RGB
    copy of the image.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the button is clicked before an image is uploaded.

    Returns:
        A new RGB ``PIL.Image`` with the meme text drawn on it. The input
        image is not modified.

    Raises:
        gr.Error: if no image was provided (Gradio shows it to the user).
    """
    if image is None:
        raise gr.Error("Sube una imagen antes de generar el meme.")

    # convert() always returns a new image, so the caller's image is never
    # drawn on. Converting before drawing (rather than after) guarantees the
    # RGB fill colour below matches the image mode. BLIP is fed the same
    # 3-channel image (assumed to be what its processor expects — verify for
    # RGBA/greyscale uploads).
    pil_image = image.convert("RGB")
    image_tensor = vis_processors["eval"](pil_image).unsqueeze(0).to(device)
    # The generated captions are joined with newlines into a single prompt.
    caption = "\n".join(model.generate({"image": image_tensor}))

    # Ask GPT-4 to write the meme from the caption.
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "Escribe un meme chistoso en español para una imagen a partir en la descripción dada por el usuario. No uses emojis. El meme tiene que ser corto y gracioso. El output del asistente solo debe ser el meme. Asegúrate que el meme sea tan bueno que se vuelva viral!"},
            {"role": "user", "content": caption},
        ],
        temperature=0.6,
    )
    meme_text = response.choices[0].message.content
    print(meme_text)

    draw = ImageDraw.Draw(pil_image)
    lines = _wrap_meme_text(meme_text)
    # 10 px side margins, matching the 10 px offset from the top edge.
    margin = 10
    # Shrink from 60 pt as needed so the text stays inside the image width.
    font = _fit_font(draw, lines, pil_image.width - 2 * margin)
    # getmetrics() returns (ascent, descent), and a line needs both. Using the
    # descent alone made successive lines overlap.
    ascent, descent = font.getmetrics()
    line_height = ascent + descent
    line_step = line_height + int(line_height * 0.1)  # 10% extra leading

    y = margin
    for line in lines:
        line_width = draw.textlength(line, font=font)
        x = (pil_image.width - line_width) // 2  # centre horizontally
        draw.text((x, y), line, fill=(255, 255, 255), font=font)
        y += line_step

    # Free cached, unused GPU memory between requests.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return pil_image
# Gradio UI: the image input and button sit in a column, with the generated
# meme shown next to them.
with gr.Blocks() as demo:
    gr.Markdown("### Memero - Generador de Memes")
    gr.Markdown("Genera un meme en español a partir de una imagen.")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Imagen", type="pil")
            btn_caption = gr.Button("Generar meme")
        # `lines` is a Textbox option, not a gr.Image parameter, so it is not
        # passed here.
        output_image = gr.Image(label="Meme")
    btn_caption.click(
        generate_caption, inputs=[input_image], outputs=[output_image]
    )

demo.launch()