|
import os |
|
import textwrap |
|
from PIL import Image, ImageDraw, ImageFont |
|
|
|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
from lavis.models import load_model_and_preprocess |
|
import openai |
|
|
|
# Pick the compute device: CUDA when PyTorch reports it as available,
# otherwise the plain string "cpu".
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = "cpu"

# Load the BLIP-2 OPT-2.7B captioning model in eval mode together with its
# image preprocessors. The third return value is not used by this script.
model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_opt",
    model_type="pretrain_opt2.7b",
    is_eval=True,
    device=device,
)

# The OpenAI key is read from the environment; a missing variable raises
# KeyError here, at import time, rather than on the first request.
openai.api_key = os.environ["OPENAI_API_KEY"]
|
|
|
# System prompt sent to the chat model (Spanish, user-facing behavior).
_MEME_SYSTEM_PROMPT = (
    "Escribe un meme chistoso en español a partir de la descripción de una "
    "imagen dada por el usuario. No uses emojis, ni comillas, ni saltos de "
    "línea. No es necesario que empieces con 'cuando'. El output del "
    "asistente solo debe ser el texto del meme. Debe ser corto pero chistoso."
)


def _describe_image(pil_image):
    """Return the captioning model's description of *pil_image* as one string."""
    # unsqueeze(0) adds the leading dimension before the tensor is moved to
    # the configured device.
    image_tensor = vis_processors["eval"](pil_image).unsqueeze(0).to(device)
    captions = model.generate({"image": image_tensor})
    return "\n".join(captions)


def _write_meme_text(caption):
    """Ask the chat model for a short Spanish meme based on *caption*."""
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": _MEME_SYSTEM_PROMPT},
            {"role": "user", "content": caption},
        ],
        temperature=0.6,
    )
    return response.choices[0].message.content.strip()


def _draw_meme_text(pil_image, meme_text, font_size=60, top_margin=10):
    """Draw *meme_text* in white, centered line by line, at the top of *pil_image*.

    The image is modified in place. ``ImageFont.truetype`` raises ``OSError``
    when ``impact.ttf`` cannot be found.
    """
    draw = ImageDraw.Draw(pil_image)
    font = ImageFont.truetype("impact.ttf", font_size)

    # Estimate how many characters fit on a line from the mean rendered width
    # of a sample of capital letters.
    alphabet = "ABCEMOPQRSTWXZ"
    total_char_width = sum(draw.textlength(char, font=font) for char in alphabet)
    average_char_width = total_char_width / len(alphabet)

    # textwrap rejects width <= 0, which would happen for an image narrower
    # than one average glyph; always allow at least one character per line.
    chars_per_line = max(1, int(pil_image.width / average_char_width))

    y = top_margin
    for line in textwrap.wrap(meme_text, width=chars_per_line):
        line_width = draw.textlength(line, font=font)
        _, line_height = font.getmask(line).size
        x = (pil_image.width - line_width) // 2

        draw.text((x, y), line, fill=(255, 255, 255), font=font)
        # Advance by the rendered height plus 10% spacing.
        y += line_height + int(line_height * 0.1)


def generate_caption(image):
    """Turn an uploaded picture into a meme.

    The image is captioned by the vision model, the caption is sent to the
    chat model to obtain the meme text, and that text is drawn on a copy of
    the image.

    Args:
        image: PIL image supplied by the Gradio input component, or ``None``
            when the user has not uploaded anything.

    Returns:
        A new RGB PIL image with the meme text drawn on it; the input image
        is left untouched.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Sube una imagen antes de generar el meme.")

    # convert() always returns a new image, so the caller's object is never
    # drawn on, and an RGB image is guaranteed for the RGB tuple fill below.
    pil_image = image.convert("RGB")

    try:
        caption = _describe_image(pil_image)
    finally:
        # Release cached GPU memory even when captioning fails.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    meme_text = _write_meme_text(caption)
    print(meme_text)

    _draw_meme_text(pil_image, meme_text)
    return pil_image
|
|
|
|
|
# Gradio UI: one column with the input image, the trigger button and the
# generated meme. Labels are user-facing and intentionally in Spanish.
with gr.Blocks() as demo:
    gr.Markdown("### Memero - Generador de Memes")
    gr.Markdown("Genera un meme en español a partir de una imagen.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Imagen", type="pil")
            btn_caption = gr.Button("Generar meme")
            # `lines` is a Textbox option, not an Image one, so it is not
            # passed here.
            output_image = gr.Image(label="Meme")

    # Clicking the button runs the full caption -> meme text -> render pipeline.
    btn_caption.click(
        generate_caption,
        inputs=[input_image],
        outputs=[output_image],
    )

demo.launch()