File size: 2,716 Bytes
fa9cc42 697f4c6 2892d42 fa9cc42 2892d42 fa9cc42 437e10c fa9cc42 437e10c 2892d42 fa9cc42 437e10c 2892d42 437e10c 2892d42 437e10c 2892d42 fa9cc42 5e2bbd7 fa9cc42 5e2bbd7 fa9cc42 d2a1b04 fa9cc42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import os
import textwrap
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
import numpy as np
import torch
from lavis.models import load_model_and_preprocess
import openai
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = "cpu"

# Load the BLIP-2 (OPT 2.7B) model in eval mode together with its image
# preprocessors. The third return value is not used in this file.
model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_opt",
    model_type="pretrain_opt2.7b",
    is_eval=True,
    device=device,
)

# The OpenAI key must be present in the environment; a missing variable
# raises KeyError here, at import time, rather than on the first request.
openai.api_key = os.environ["OPENAI_API_KEY"]
# Smallest font size ever requested. Very long meme texts would otherwise
# compute a size of 0, which ImageFont.truetype rejects.
_MIN_FONT_SIZE = 12


def _measure_text(draw, text, font):
    """Return the ``(width, height)`` of *text* in pixels, measured from the draw origin.

    ``ImageDraw.textsize`` and ``FreeTypeFont.getsize`` were removed in Pillow 10,
    so ``textbbox`` is used whenever the installed Pillow provides it; the legacy
    call is kept only as a fallback for older releases.
    """
    if hasattr(draw, "textbbox"):
        # Right/bottom edges measured from (0, 0) include the offset above and to
        # the left of the glyphs, which is what the line advance needs.
        _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
        return right, bottom
    return draw.textsize(text, font=font)


def generate_caption(image):
    """Caption *image*, ask GPT-4 for a Spanish meme about it and draw the meme on the image.

    Args:
        image: PIL image coming from the Gradio input, or ``None`` when the user
            pressed the button without providing an image.

    Returns:
        A new RGB PIL image with the meme text drawn in white, horizontally
        centred, starting near the top edge; ``None`` when *image* is ``None``.
    """
    if image is None:
        return None

    # Draw on an RGB copy so the caller's image is untouched and the RGB fill
    # colour used below is always valid for the image mode.
    pil_image = image.convert("RGB")

    image = vis_processors["eval"](image).unsqueeze(0).to(device)
    caption = "\n".join(model.generate({"image": image}))

    # Use GPT-4 to generate a meme based on the caption.
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "Escribe un meme chistoso en español para una imagen a partir en la descripción dada por el usuario. No uses emojis. El meme tiene que ser corto y gracioso. El output del asistente solo debe ser el meme. Asegúrate que el meme sea tan bueno que se vuelva viral!"},
            {"role": "user", "content": caption},
        ],
        temperature=0.6,
    )
    meme_text = response.choices[0].message.content or ""
    print(meme_text)

    # Put the meme text on the image.
    draw = ImageDraw.Draw(pil_image)

    # Font size shrinks as the text gets longer. Guard against an empty text
    # (division by zero) and against sizes that round down to zero.
    max_width = int(pil_image.width * 0.9)
    text_length = max(len(meme_text), 1)
    font_size = max(int(max_width / (text_length / 2)), _MIN_FONT_SIZE)
    font = ImageFont.truetype("impact.ttf", font_size)

    # Wrap the text so each line fits in roughly 90% of the image width. The
    # number of lines is not capped. textwrap needs a width of at least 1.
    char_width = max(_measure_text(draw, "A", font)[0], 1)
    chars_per_line = max(int(max_width / char_width), 1)
    text_lines = textwrap.fill(meme_text, width=chars_per_line).splitlines()

    y = 10
    for line in text_lines:
        line_width, line_height = _measure_text(draw, line, font)
        x = (pil_image.width - line_width) // 2
        draw.text((x, y), line, fill=(255, 255, 255), font=font)
        y += line_height

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return pil_image
# Gradio UI: one image input, a button, and an image output showing the meme.
# NOTE(review): the nesting below is reconstructed from the `with` statements;
# confirm it matches the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("### Memero - Generador de Memes")
    gr.Markdown("Genera un meme en español a partir de una imagen.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Imagen", type="pil")
            btn_caption = gr.Button("Generar meme")
            # `lines` is a Textbox option, not an Image one, so it is not passed
            # here. The variable keeps its historical name even though it holds
            # an image component.
            output_text = gr.Image(label="Meme")

    btn_caption.click(
        generate_caption, inputs=[input_image], outputs=[output_text]
    )

demo.launch()