|
import os |
|
|
|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
from lavis.models import load_model_and_preprocess |
|
from PIL import Image, ImageDraw, ImageFont |
|
import openai |
|
|
|
# Inference device: the CUDA device when one is usable, otherwise the plain
# string "cpu" (both forms are accepted by ``.to(...)``).
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = "cpu"

# Load the "blip2_opt" / "pretrain_opt2.7b" checkpoint in eval mode together
# with its image processors; the third returned element is not used here.
model, vis_processors, _ = load_model_and_preprocess(
    device=device,
    is_eval=True,
    model_type="pretrain_opt2.7b",
    name="blip2_opt",
)

# The OpenAI key is mandatory: a missing variable raises KeyError at import.
openai.api_key = os.environ["OPENAI_API_KEY"]
|
|
|
def _fit_font(draw, text, max_width, max_height, font_path="impact.ttf"):
    """Return the largest font from *font_path* whose rendering of *text* fits.

    The rendered size grows with the font size, so the largest size that
    still fits inside ``max_width`` x ``max_height`` is found by binary
    search. The search is bounded by ``max_height``, so it always terminates.
    If nothing fits, the 1-pixel font is returned.

    Raises:
        OSError: if the font file cannot be opened.
    """

    def fits(size):
        left, top, right, bottom = draw.textbbox(
            (0, 0), text, font=ImageFont.truetype(font_path, size)
        )
        return right - left <= max_width and bottom - top <= max_height

    smallest, largest = 1, max(1, max_height)
    while smallest < largest:
        middle = (smallest + largest + 1) // 2
        if fits(middle):
            smallest = middle
        else:
            largest = middle - 1
    return ImageFont.truetype(font_path, smallest)


def generate_caption(image):
    """Caption *image*, ask the chat model for a meme line and draw it on top.

    Args:
        image: PIL image coming from the Gradio input, or ``None`` when the
            user pressed the button without uploading anything.

    Returns:
        A new RGB PIL image (same size as the upload) with the meme text
        centred horizontally along the top edge, or ``None`` if no image
        was provided.

    Raises:
        OSError: if ``impact.ttf`` cannot be loaded.
    """
    if image is None:
        return None

    source = image.convert("RGB")

    # Model input: processed tensor with a leading batch axis.
    pixels = vis_processors["eval"](source).unsqueeze(0).to(device)
    captions = model.generate({"image": pixels})
    caption = "\n".join(captions)

    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "Escribe un meme chistoso para una imagen a partir en la descripción dada por el usuario. No uses emojis. El meme tiene que ser corto y gracioso. El output del asistente solo debe ser el meme. Asegúrate que el meme sea tan bueno que se vuelva viral!"},
            {"role": "user", "content": caption},
        ],
        temperature=0.6,
    )

    meme_text = (response.choices[0].message.content or "").strip()
    print(meme_text)

    # Draw on the uploaded picture itself rather than on the model-input
    # tensor: that tensor is 4-D (batch axis from unsqueeze), so it cannot be
    # turned back into an image with a 3-axis permute, and it has presumably
    # been resized/normalised by the processor.
    meme_image = source.copy()
    if meme_text:
        draw = ImageDraw.Draw(meme_image)
        width, height = meme_image.size
        # Keep a 5% margin on each side of the text.
        font = _fit_font(draw, meme_text, int(width * 0.9), height)
        left, _top, right, _bottom = draw.textbbox((0, 0), meme_text, font=font)
        text_x = (width - (right - left)) // 2 - left
        draw.text(
            (text_x, 0),
            meme_text,
            font=font,
            fill=(255, 255, 255),
            align="center",
        )

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return meme_image
|
|
|
|
|
# Gradio front end: one image input, one button, one image output.
with gr.Blocks() as demo:
    gr.Markdown("### Memero - Generador de Memes")
    gr.Markdown("Escribe un meme en español a partir de una imagen.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Imagen", type="pil")
            btn_caption = gr.Button("Generar meme")
            # ``lines`` is a Textbox option, not an Image one, so it is not
            # passed here.
            output_image = gr.Image(label="Meme")

    # Run the caption -> meme pipeline on click and show the resulting image.
    btn_caption.click(
        generate_caption, inputs=[input_image], outputs=[output_image]
    )

demo.launch()