File size: 2,549 Bytes
fa9cc42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6231fde
fa9cc42
6231fde
 
fa9cc42
6231fde
 
 
 
 
 
 
fa9cc42
 
6231fde
 
fa9cc42
6231fde
fa9cc42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os

import gradio as gr
import numpy as np
import torch
from lavis.models import load_model_and_preprocess
from PIL import Image, ImageDraw, ImageFont
import openai

# Read the API key first so a missing key fails immediately instead of after
# the (slow, large) BLIP-2 download/load below.
openai.api_key = os.environ["OPENAI_API_KEY"]

# Build a torch.device on both branches so `device` always has the same type
# (the original mixed a torch.device with the plain string "cpu").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# BLIP-2 (OPT-2.7b) captioning model plus its image preprocessors; the
# "eval" preprocessor is what generate_caption feeds images through.
model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_opt", model_type="pretrain_opt2.7b", is_eval=True, device=device
)

def _caption_image(image):
    """Return the BLIP-2 caption(s) for a PIL *image*, joined with newlines.

    Uses the module-level ``model``, ``vis_processors`` and ``device``.
    """
    # The processor returns a single CHW tensor; the model expects a batch.
    pixels = vis_processors["eval"](image).unsqueeze(0).to(device)
    captions = model.generate({"image": pixels})
    return "\n".join(captions)


def _meme_text_from_caption(caption):
    """Ask GPT-4 (legacy ``openai.ChatCompletion`` API) for meme text.

    Args:
        caption: Image description used as the user message.

    Returns:
        The assistant's reply with surrounding whitespace removed.
    """
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "Escribe un meme chistoso para una imagen a partir en la descripción dada por el usuario. No uses emojis. El meme tiene que ser corto y gracioso. El output del asistente solo debe ser el meme. Asegúrate que el meme sea tan bueno que se vuelva viral!"},
            {"role": "user", "content": caption},
        ],
        temperature=0.6,
    )
    # Replies often carry leading/trailing newlines; strip them so the text
    # is measured and centered on what is actually drawn.
    return response.choices[0].message.content.strip()


def _fit_font(draw, text, max_width, max_height, font_path="impact.ttf"):
    """Return the largest TrueType font at which *text* fits the given box.

    Binary-searches font sizes in ``[1, max_height]``. Size 1 is returned
    when nothing larger fits. Raises ``OSError`` if *font_path* cannot be
    loaded.
    """
    def fits(size):
        font = ImageFont.truetype(font_path, size)
        left, top, right, bottom = draw.multiline_textbbox(
            (0, 0), text, font=font, align="center"
        )
        return right - left <= max_width and bottom - top <= max_height

    # Capping the search at max_height bounds the work even for text whose
    # ink has no width (the old grow-until-too-wide loop never ended then).
    lo, hi = 1, max(1, max_height)
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if fits(mid):
            lo = mid
        else:
            hi = mid - 1
    return ImageFont.truetype(font_path, lo)


def _draw_meme_text(image, text):
    """Return a copy of *image* with *text* drawn in white at the top.

    The text is horizontally centered and sized to at most 90% of the image
    width and a third of its height. Empty text yields an unmodified copy.
    """
    meme_image = image.copy()
    if not text:
        return meme_image

    draw = ImageDraw.Draw(meme_image)
    width, height = meme_image.size
    # Keep the caption in a top band so most of the picture stays visible.
    font = _fit_font(draw, text, int(width * 0.9), max(1, height // 3))

    left, _, right, _ = draw.multiline_textbbox(
        (0, 0), text, font=font, align="center"
    )
    # Subtract the bbox's left bearing so the visible ink, not the origin,
    # is what gets centered.
    text_x = (width - (right - left)) // 2 - left
    draw.multiline_text(
        (text_x, 0), text, font=font, fill=(255, 255, 255), align="center"
    )
    return meme_image


def generate_caption(image):
    """Turn an uploaded picture into a captioned meme image.

    Despite its name, this returns an image, not a caption. It captions
    *image* with BLIP-2, asks GPT-4 for Spanish meme text based on that
    caption, and draws the text onto a copy of the original picture.

    Args:
        image: PIL image from the Gradio ``gr.Image(type="pil")`` input, or
            ``None`` when nothing was uploaded.

    Returns:
        An RGB PIL image at the input's resolution with the meme text drawn
        on top.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Sube una imagen primero.")
    # Uploads may be RGBA / greyscale / palette; both the BLIP preprocessor
    # and the returned meme expect plain 3-channel RGB.
    image = image.convert("RGB")

    caption = _caption_image(image)
    meme_text = _meme_text_from_caption(caption)
    print(meme_text)

    # Draw on the original picture. The old code tried to rebuild it from the
    # normalized (1, 3, 224, 224) model tensor, which crashed in permute() and
    # would have produced wrong colours anyway.
    meme_image = _draw_meme_text(image, meme_text)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return meme_image


# Single-column Gradio UI: upload an image, press the button, get the meme.
with gr.Blocks() as demo:
    gr.Markdown(
        "### Memero - Generador de Memes"
    )
    gr.Markdown(
        "Escribe un meme en español a partir de una imagen."
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Imagen", type="pil")
            btn_caption = gr.Button("Generar meme")
            # `lines` is a gr.Textbox option, not a gr.Image one; newer
            # Gradio releases reject unknown keyword arguments, so it is
            # dropped here. The component shows an image, hence the name.
            output_image = gr.Image(label="Meme")

    btn_caption.click(
        generate_caption, inputs=[input_image], outputs=[output_image]
    )

demo.launch()