File size: 2,549 Bytes
fa9cc42 6231fde fa9cc42 6231fde fa9cc42 6231fde fa9cc42 6231fde fa9cc42 6231fde fa9cc42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import os
import gradio as gr
import numpy as np
import torch
from lavis.models import load_model_and_preprocess
from PIL import Image, ImageDraw, ImageFont
import openai
# Pick the compute device: a CUDA torch.device when a GPU is visible,
# otherwise the plain string "cpu".
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = "cpu"

# Load the captioning model ("blip2_opt" / "pretrain_opt2.7b") in eval mode
# together with its image preprocessors. The third return value is not used
# by this script and is discarded.
model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_opt",
    model_type="pretrain_opt2.7b",
    is_eval=True,
    device=device,
)

# The OpenAI key must be provided through the environment; a missing
# variable raises KeyError at import time.
openai.api_key = os.environ["OPENAI_API_KEY"]
def _fit_meme_font(draw, text, max_width, max_height, font_path="impact.ttf"):
    """Return the largest font at which *text* fits the box, with its bbox.

    Sizes are tried in increasing order and the search stops at the first
    size whose rendered text is wider than ``max_width`` or taller than
    ``max_height``; the previous (still fitting) size is returned. If even
    size 1 overflows, size 1 is returned.

    Args:
        draw: ``ImageDraw.Draw`` object used to measure the text.
        text: Text to measure (may contain newlines).
        max_width: Maximum allowed text width in pixels.
        max_height: Maximum allowed text height in pixels; also the upper
            bound of the font sizes tried, so the search always terminates.
        font_path: TrueType font file passed to ``ImageFont.truetype``.

    Returns:
        Tuple ``(font, bbox)`` where ``bbox`` is the
        ``(left, top, right, bottom)`` box of *text* drawn at ``(0, 0)``.

    Raises:
        OSError: If the font file cannot be loaded.
    """
    font = ImageFont.truetype(font_path, 1)
    bbox = draw.textbbox((0, 0), text, font=font)
    for size in range(2, max_height + 1):
        candidate = ImageFont.truetype(font_path, size)
        candidate_bbox = draw.textbbox((0, 0), text, font=candidate)
        width = candidate_bbox[2] - candidate_bbox[0]
        height = candidate_bbox[3] - candidate_bbox[1]
        if width > max_width or height > max_height:
            # This size overflows: keep the last size that still fitted.
            break
        font, bbox = candidate, candidate_bbox
    return font, bbox


def generate_caption(image):
    """Caption *image*, ask GPT-4 for a meme line and draw it on the image.

    The image is captioned with the module-level ``model``; the caption is
    sent to GPT-4 with a Spanish system prompt; the returned text is drawn
    in white, horizontally centred at the top of a copy of the input image,
    using the largest "impact.ttf" size that fits in 90% of the image width.

    Args:
        image: PIL image to turn into a meme, or ``None`` when no image was
            supplied.

    Returns:
        A new RGB PIL image with the meme text drawn on it (the input image
        is not modified), or ``None`` if *image* is ``None``.

    Raises:
        OSError: If "impact.ttf" cannot be loaded.
    """
    if image is None:
        # Nothing to caption (e.g. the button was pressed without an image).
        return None

    # convert() returns a copy, so drawing below never touches the input.
    meme_image = image.convert("RGB")

    # The preprocessed, batched tensor is model input only. The picture we
    # draw on is the PIL image itself: after unsqueeze(0) the tensor has an
    # extra batch axis and is not a displayable image.
    pixel_values = vis_processors["eval"](meme_image).unsqueeze(0).to(device)
    caption = "\n".join(model.generate({"image": pixel_values}))

    # use gpt-4 to generate a meme based on the caption
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "Escribe un meme chistoso para una imagen a partir en la descripción dada por el usuario. No uses emojis. El meme tiene que ser corto y gracioso. El output del asistente solo debe ser el meme. Asegúrate que el meme sea tan bueno que se vuelva viral!"},
            {"role": "user", "content": caption}
        ],
        temperature=0.6
    )
    meme_text = (response.choices[0].message.content or "").strip()
    print(meme_text)

    if meme_text:
        draw = ImageDraw.Draw(meme_image)
        max_width = int(meme_image.width * 0.9)
        font, (left, _, right, _) = _fit_meme_font(
            draw, meme_text, max_width, meme_image.height
        )
        # Centre horizontally; subtract the bbox's own left offset so the
        # visible glyphs (not the drawing origin) are centred.
        text_x = (meme_image.width - (right - left)) // 2 - left
        draw.text((text_x, 0), meme_text, font=font, fill=(255, 255, 255))

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return meme_image
# Gradio UI: one image input, a button, and an image output showing the meme.
# NOTE(review): the original indentation was lost, so the Row/Column nesting
# below is a reconstruction - confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown(
        "### Memero - Generador de Memes"
    )
    gr.Markdown(
        "Escribe un meme en español a partir de una imagen."
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Imagen", type="pil")
            btn_caption = gr.Button("Generar meme")
            # Holds the generated meme image despite its name. `lines` is a
            # Textbox option, not an Image one, so it is not passed here.
            output_text = gr.Image(label="Meme")
    # Event listeners are registered inside the Blocks context.
    btn_caption.click(
        generate_caption, inputs=[input_image], outputs=[output_text]
    )

demo.launch()