Commit fa9cc42 · Parent: 0354030 — "Upload 2 files"
Files changed: app.py (+93 lines), requirements.txt (+4 lines)
app.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os

import gradio as gr
import numpy as np
import torch
from lavis.models import load_model_and_preprocess
from PIL import Image, ImageDraw, ImageFont
import openai

# FIX: use a torch.device on both branches — the original mixed
# torch.device("cuda") with the bare string "cpu".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load BLIP-2 (OPT 2.7B) once at module import; this is heavy, so it is
# deliberately module-level rather than per-request.
model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_opt", model_type="pretrain_opt2.7b", is_eval=True, device=device
)

# Read the API key from the environment; KeyError at startup is preferable
# to a failure on the first user request.
openai.api_key = os.environ["OPENAI_API_KEY"]
|
17 |
+
|
18 |
+
def generate_caption(image):
    """Caption *image* with BLIP-2, turn the caption into a Spanish meme via
    GPT-4, and render the meme text onto the image.

    Args:
        image: input picture as a PIL.Image (Gradio ``type="pil"``).

    Returns:
        PIL.Image.Image: RGB copy of the preprocessed image with the meme
        text drawn centered at the top.
    """
    # Preprocess to a (1, C, H, W) tensor on the model device.
    image = vis_processors["eval"](image).unsqueeze(0).to(device)
    caption = model.generate({"image": image})

    # model.generate returns a list of strings; join into a single prompt.
    caption = "\n".join(caption)

    # Use GPT-4 to write a short, funny Spanish meme from the caption.
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "Escribe un meme chistoso para una imagen a partir en la descripción dada por el usuario. No uses emojis. El meme tiene que ser corto y gracioso. El output del asistente solo debe ser el meme. Asegúrate que el meme sea tan bueno que se vuelva viral!"},
            {"role": "user", "content": caption}
        ],
        temperature=0.6
    )

    meme_text = response.choices[0].message.content
    print(meme_text)

    # NOTE(review): the "eval" processor typically normalizes with mean/std,
    # so multiplying by 255 here likely distorts colors — confirm against the
    # processor's transform. Kept as-is to preserve the current output.
    pil_image = Image.fromarray(
        (image.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    )

    meme_image = _draw_meme_text(pil_image, meme_text)

    # Release cached GPU memory between requests.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return meme_image


def _draw_meme_text(pil_image, meme_text):
    """Return an RGB copy of *pil_image* with *meme_text* centered at the top."""
    meme_image = pil_image.copy().convert('RGBA')
    draw = ImageDraw.Draw(meme_image)

    # Scale the font down for longer texts. BUG FIX: the original computed
    # this size but never applied it — ImageFont.load_default() ignored it.
    text_length = len(meme_text)
    if text_length <= 15:
        font_size = 24
    elif text_length <= 30:
        font_size = 18
    else:
        font_size = 14
    try:
        # Pillow >= 10.1 supports sizing the built-in font.
        font = ImageFont.load_default(size=font_size)
    except TypeError:
        font = ImageFont.load_default()

    # BUG FIX: ImageDraw.textsize() was removed in Pillow 10;
    # textbbox() is the documented replacement.
    left, top, right, bottom = draw.textbbox((0, 0), meme_text, font=font)
    text_width = right - left

    # Center horizontally; small fixed padding from the top.
    x = (meme_image.width - text_width) // 2
    y = 10  # Adjust this value to add more or less padding from the top

    draw.text((x, y), meme_text, fill=(255, 255, 255), font=font)

    # Convert back to RGB for display/saving.
    return meme_image.convert('RGB')
|
73 |
+
|
74 |
+
|
75 |
+
with gr.Blocks() as demo:
    gr.Markdown("### Memero - Generador de Memes")
    gr.Markdown("Escribe un meme en español a partir de una imagen.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Imagen", type="pil")
            btn_caption = gr.Button("Generar meme")
            # BUG FIX: gr.Image has no `lines` parameter (that belongs to
            # gr.Textbox); passing it raises TypeError on current Gradio.
            output_image = gr.Image(label="Meme")

    btn_caption.click(
        generate_caption, inputs=[input_image], outputs=[output_image]
    )

demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
torch
torchvision
salesforce-lavis
openai
# Missing runtime dependencies imported by app.py:
gradio
numpy
Pillow
|