File size: 2,745 Bytes
fa9cc42
697f4c6
2892d42
fa9cc42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2892d42
fa9cc42
 
 
 
 
 
 
 
437e10c
fa9cc42
 
 
 
 
 
 
 
437e10c
2892d42
fa9cc42
40895e5
 
2892d42
437e10c
91bb5aa
 
 
437e10c
91bb5aa
 
437e10c
 
2892d42
2d292a8
 
2892d42
40895e5
2892d42
40895e5
fa9cc42
5e2bbd7
fa9cc42
 
 
 
5e2bbd7
fa9cc42
 
 
 
 
 
 
d2a1b04
fa9cc42
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import textwrap
from PIL import Image, ImageDraw, ImageFont

import gradio as gr
import numpy as np
import torch
from lavis.models import load_model_and_preprocess
import openai

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = "cpu"

# Load the "blip2_opt" / "pretrain_opt2.7b" model in eval mode together with
# its image preprocessors. The third returned value is not used in this file.
model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_opt",
    model_type="pretrain_opt2.7b",
    is_eval=True,
    device=device,
)

# The OpenAI key must be present in the environment; a missing variable
# raises KeyError at import time.
openai.api_key = os.environ["OPENAI_API_KEY"]

def _draw_meme_text(pil_image, meme_text, font_path="impact.ttf", font_size=60):
    """Draw *meme_text* in white, horizontally centred, from the top of *pil_image*.

    The image is modified in place.

    Args:
        pil_image: PIL image to draw on. It should be in RGB mode so the RGB
            fill tuple used below applies directly.
        meme_text: Text to render. Empty text draws nothing.
        font_path: TrueType font file handed to ``ImageFont.truetype``.
        font_size: Requested font size handed to ``ImageFont.truetype``.
    """
    draw = ImageDraw.Draw(pil_image)
    font = ImageFont.truetype(font_path, font_size)

    # Wrap at half the text length, which aims for roughly two lines (word
    # boundaries can still produce more). textwrap.wrap rejects a width of 0,
    # so clamp to 1 for empty or one-character text.
    wrap_width = max(1, len(meme_text) // 2)
    text_lines = textwrap.wrap(meme_text, width=wrap_width)

    # getmetrics() returns (ascent, descent); a full line is their sum.
    # abs() guards against a descent reported as a negative offset.
    ascent, descent = font.getmetrics()
    line_height = ascent + abs(descent)
    # Advance one line plus 10% extra spacing between lines.
    line_step = line_height + int(line_height * 0.1)

    y = 10  # top margin in pixels
    for line in text_lines:
        x = (pil_image.width - draw.textlength(line, font=font)) // 2
        draw.text((x, y), line, fill=(255, 255, 255), font=font)
        y += line_step


def generate_caption(image):
    """Caption *image*, turn the caption into a Spanish meme, and draw it on the image.

    Steps: caption the image with the module-level ``model``, send the caption
    to GPT-4 through ``openai.ChatCompletion`` to obtain the meme text, print
    that text, and render it on an RGB copy of the input.

    Args:
        image: PIL image supplied by the UI, or ``None`` when nothing was
            uploaded.

    Returns:
        A new RGB PIL image with the meme text drawn on it, or ``None`` when
        *image* is ``None``. The input image itself is not modified.
    """
    if image is None:
        # Nothing uploaded: return no output instead of failing on None.
        return None

    # convert() returns a new image, so the caller's image stays untouched.
    # Converting up front gives both the captioner and the drawing code a
    # 3-channel image.
    pil_image = image.convert("RGB")

    # Preprocess, add a leading batch dimension, and move to the model device.
    image_tensor = vis_processors["eval"](pil_image).unsqueeze(0).to(device)
    caption = "\n".join(model.generate({"image": image_tensor}))

    # Use GPT-4 to generate a meme based on the caption.
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "Escribe un meme chistoso en español para una imagen a partir en la descripción dada por el usuario. No uses emojis. El meme tiene que ser corto y gracioso. El output del asistente solo debe ser el meme. Asegúrate que el meme sea tan bueno que se vuelva viral!"},
            {"role": "user", "content": caption},
        ],
        temperature=0.6,
    )
    meme_text = response.choices[0].message.content
    print(meme_text)

    _draw_meme_text(pil_image, meme_text)

    # Release cached GPU memory between requests.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return pil_image


# Single-column UI: an image upload, a trigger button, and the generated meme.
with gr.Blocks() as demo:
    gr.Markdown("### Memero - Generador de Memes")
    gr.Markdown("Genera un meme en español a partir de una imagen.")

    with gr.Row():
        with gr.Column():
            # type="pil" makes Gradio hand generate_caption a PIL image.
            input_image = gr.Image(label="Imagen", type="pil")
            btn_caption = gr.Button("Generar meme")
            # `lines` is a Textbox option, not an Image one, so it is not
            # passed here.
            output_image = gr.Image(label="Meme")

    # Clicking the button runs the captioning + meme pipeline on the upload.
    btn_caption.click(
        generate_caption, inputs=[input_image], outputs=[output_image]
    )

demo.launch()