memero / app.py
andreinigo's picture
Update app.py
5e2bbd7
raw
history blame
2.84 kB
import os
import textwrap
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
import numpy as np
import torch
from lavis.models import load_model_and_preprocess
import openai
# Pick the compute device once at import time: a CUDA torch.device when a GPU
# is visible, otherwise the plain string "cpu".
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = "cpu"

# Load the BLIP-2 (OPT-2.7b) captioning model in eval mode together with its
# image pre-processors; the third return value is not needed here.
model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_opt",
    model_type="pretrain_opt2.7b",
    is_eval=True,
    device=device,
)

# The OpenAI key must be supplied through the environment; a missing
# variable raises KeyError at startup.
openai.api_key = os.environ["OPENAI_API_KEY"]
def _overlay_text(pil_image, text, font_path="impact.ttf", font_size=18, top_padding=10):
    """Draw *text* in white, wrapped and centred, at the top of *pil_image*.

    The image is modified in place.

    Args:
        pil_image: PIL image to draw on; expected to be in RGB mode so the
            RGB fill colour is valid.
        text: Text to render. It is re-wrapped, so embedded newlines are
            treated as ordinary whitespace.
        font_path: TrueType font file passed to ``ImageFont.truetype``.
        font_size: Font size in points.
        top_padding: Vertical offset in pixels of the first line.

    Raises:
        OSError: If the font file cannot be loaded.
    """
    draw = ImageDraw.Draw(pil_image)
    font = ImageFont.truetype(font_path, font_size)

    # Rough heuristic: allow one character per ``font_size`` pixels of width.
    # Clamp to 1 because textwrap rejects a width of 0 (very narrow images).
    max_chars = max(1, pil_image.width // font_size)

    # Use the font's own metrics so every line advances by the same amount,
    # regardless of which glyphs it happens to contain.
    ascent, descent = font.getmetrics()
    line_height = ascent + descent

    y = top_padding
    for line in textwrap.wrap(text, width=max_chars):
        # ImageDraw.textsize() was removed in Pillow 10; textlength() is the
        # supported way to measure the horizontal extent of a single line.
        line_width = draw.textlength(line, font=font)
        x = int((pil_image.width - line_width) // 2)
        draw.text((x, y), line, fill=(255, 255, 255), font=font)
        y += line_height


def generate_caption(image):
    """Turn an input picture into a meme.

    The image is captioned with the BLIP-2 model, the caption is sent to
    GPT-4 with a Spanish system prompt asking for a short funny meme text,
    and that text is drawn at the top of a copy of the image.

    Args:
        image: PIL image supplied by the Gradio input component, or ``None``
            when the user has not provided one.

    Returns:
        A new RGB PIL image with the meme text drawn on it, or ``None`` when
        no input image was given.
    """
    # Triggering the action without an image passes None; return early
    # instead of failing with AttributeError.
    if image is None:
        return None

    # Work on an RGB copy: the caller's image is left untouched, and the RGB
    # fill colour used for the text is valid whatever the input mode was.
    pil_image = image.convert("RGB")

    image_tensor = vis_processors["eval"](image).unsqueeze(0).to(device)
    caption = "\n".join(model.generate({"image": image_tensor}))

    # Use GPT-4 to generate a meme based on the caption.
    # NOTE(review): openai.ChatCompletion is the pre-1.0 SDK interface —
    # confirm the pinned openai version before upgrading.
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "Escribe un meme chistoso para una imagen a partir en la descripción dada por el usuario. No uses emojis. El meme tiene que ser corto y gracioso. El output del asistente solo debe ser el meme. Asegúrate que el meme sea tan bueno que se vuelva viral!"},
            {"role": "user", "content": caption},
        ],
        temperature=0.6,
    )
    # Guard against a missing message body so the overlay always gets a str.
    meme_text = response.choices[0].message.content or ""
    print(meme_text)

    # Put the meme text on the image.
    _overlay_text(pil_image, meme_text)

    # Release cached GPU memory held by the captioning step.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return pil_image
# Gradio UI: one image input, a trigger button and one image output.
# NOTE(review): the original indentation was lost, so the nesting of the
# components inside the row/column is reconstructed — confirm the layout.
with gr.Blocks() as demo:
    gr.Markdown(
        "### Memero - Generador de Memes"
    )
    gr.Markdown(
        "Escribe un meme en español a partir de una imagen."
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Imagen", type="pil")
            btn_caption = gr.Button("Generar meme")
            # The result is an image, so the Textbox-only ``lines`` keyword
            # that used to be passed here has been dropped; gr.Image does
            # not define it.
            output_image = gr.Image(label="Meme")
    # Run the captioning + meme pipeline when the button is pressed.
    btn_caption.click(
        generate_caption, inputs=[input_image], outputs=[output_image]
    )

demo.launch()