# memero / app.py
# Author: andreinigo — last commit 697f4c6 ("Update app.py"), 2.85 kB.
import os
import textwrap
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
import numpy as np
import torch
from lavis.models import load_model_and_preprocess
import openai
# Run BLIP-2 on the GPU when one is available, otherwise on the CPU.
# Always a torch.device (previously the CPU branch was the bare string "cpu"),
# so `device` has one consistent type wherever it is used.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# BLIP-2 (OPT-2.7b, "pretrain" weights) captioning model and its image
# preprocessors, loaded once at import time; is_eval=True requests inference
# mode. The text processors are not needed and are discarded.
model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_opt", model_type="pretrain_opt2.7b", is_eval=True, device=device
)
# Fail fast at startup (KeyError) if the OpenAI key is not configured.
openai.api_key = os.environ["OPENAI_API_KEY"]
def generate_caption(image):
    """Turn an uploaded picture into a captioned meme.

    The picture is described by the BLIP-2 model loaded at module level.
    That description is sent to GPT-4 (with a Spanish system prompt) to get
    short meme text. The text is then drawn in white, wrapped and centred,
    near the top of an RGB copy of the picture. The input image itself is
    never modified.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the button is clicked without uploading anything.

    Returns:
        A new RGB PIL image with the meme text drawn on it, or ``None`` when
        no image was provided (this leaves the output component empty).
    """
    if image is None:
        return None
    try:
        caption = _caption_image(image)
        meme_text = _meme_text_from_caption(caption)
        print(meme_text)
        # convert() returns a new image, so this is also the defensive copy.
        # Converting *before* drawing guarantees the RGB fill colour used by
        # _draw_top_text is valid for the image mode (e.g. palette/greyscale).
        meme_image = image.convert("RGB")
        _draw_top_text(meme_image, meme_text)
    finally:
        # Release cached GPU memory even if the OpenAI call or drawing failed.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return meme_image


def _caption_image(image):
    """Return BLIP-2's description(s) of *image*, joined one per line."""
    # Preprocess, add a batch dimension and move to the model's device.
    pixels = vis_processors["eval"](image).unsqueeze(0).to(device)
    # model.generate yields caption strings; join them into one prompt.
    return "\n".join(model.generate({"image": pixels}))


def _meme_text_from_caption(caption):
    """Ask GPT-4 for short meme text matching *caption*.

    Uses the legacy ``openai.ChatCompletion`` interface (openai < 1.0).
    """
    # System prompt (Spanish): write a short, funny, emoji-free meme for the
    # described image; the reply must contain only the meme.
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "Escribe un meme chistoso para una imagen a partir en la descripción dada por el usuario. No uses emojis. El meme tiene que ser corto y gracioso. El output del asistente solo debe ser el meme. Asegúrate que el meme sea tan bueno que se vuelva viral!"},
            {"role": "user", "content": caption},
        ],
        temperature=0.6,
    )
    return response.choices[0].message.content


def _draw_top_text(canvas, text, font_path="impact.ttf", font_size=18, top_margin=10):
    """Draw *text* in white, word-wrapped and centred, at the top of *canvas* (in place)."""
    draw = ImageDraw.Draw(canvas)
    # NOTE(review): presumably impact.ttf ships with the app; truetype raises
    # OSError if it cannot be found.
    font = ImageFont.truetype(font_path, font_size)
    # Rough characters-per-line estimate. Clamp to 1 because textwrap rejects
    # width 0, which would happen for images narrower than one font size.
    chars_per_line = max(1, canvas.width // font_size)
    # Uniform line height from the font metrics. textsize() no longer exists
    # in Pillow >= 10.
    ascent, descent = font.getmetrics()
    line_height = ascent + descent
    y = top_margin
    for line in textwrap.wrap(text, width=chars_per_line):
        line_width = draw.textlength(line, font=font)
        x = int((canvas.width - line_width) // 2)
        draw.text((x, y), line, fill=(255, 255, 255), font=font)
        y += line_height
# Single-column Gradio UI: upload an image, click the button, get the meme.
# Header text (Spanish): "Memero - Meme Generator" /
# "Writes a meme in Spanish from an image."
with gr.Blocks() as demo:
    gr.Markdown(
        "### Memero - Generador de Memes"
    )
    gr.Markdown(
        "Escribe un meme en español a partir de una imagen."
    )
    with gr.Row():
        with gr.Column():
            # type="pil" makes Gradio hand generate_caption a PIL image.
            input_image = gr.Image(label="Imagen", type="pil")
            btn_caption = gr.Button("Generar meme")
            # The output is an image. `lines` is a Textbox option that
            # gr.Image does not accept, so it is not passed here.
            output_image = gr.Image(label="Meme")
            btn_caption.click(
                generate_caption, inputs=[input_image], outputs=[output_image]
            )
demo.launch()