andreinigo committed on
Commit
fa9cc42
·
1 Parent(s): 0354030

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +93 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gradio as gr
4
+ import numpy as np
5
+ import torch
6
+ from lavis.models import load_model_and_preprocess
7
+ from PIL import Image, ImageDraw, ImageFont
8
+ import openai
9
+
10
# Run on the GPU when one is available. Always build a ``torch.device`` so the
# value has the same type on both the CUDA and the CPU path.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the BLIP-2 (OPT-2.7B) captioning model and its image preprocessors once
# at start-up. The third return value is not used by this app.
model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_opt",
    model_type="pretrain_opt2.7b",
    is_eval=True,
    device=device,
)

# Fail fast with a KeyError at start-up if the API key is not configured.
openai.api_key = os.environ["OPENAI_API_KEY"]
17
+
18
def _load_meme_font(size):
    """Return Pillow's built-in font, scaled to ``size`` pixels when supported."""
    try:
        # Recent Pillow releases accept a ``size`` for the bundled default font.
        return ImageFont.load_default(size=size)
    except (TypeError, OSError, ImportError):
        # Older Pillow (no ``size`` argument) or a build that cannot scale the
        # font: fall back to the fixed-size built-in font.
        return ImageFont.load_default()


def _measure_text(draw, text, font):
    """Return the rendered ``(width, height)`` in pixels of one line of text."""
    try:
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    except (AttributeError, ValueError):
        # Pillow versions where ``textbbox`` is missing or rejects the
        # built-in bitmap font still provide the legacy ``textsize``.
        return draw.textsize(text, font=font)
    return right - left, bottom - top


def _wrap_to_width(draw, text, font, max_width):
    """Greedily split ``text`` into lines that are at most ``max_width`` pixels wide.

    A single word wider than ``max_width`` is kept on a line of its own.
    """
    lines = []
    for paragraph in text.splitlines():
        current = ""
        for word in paragraph.split():
            candidate = f"{current} {word}" if current else word
            if current and _measure_text(draw, candidate, font)[0] > max_width:
                lines.append(current)
                current = word
            else:
                current = candidate
        if current:
            lines.append(current)
    return lines


def generate_caption(image):
    """Caption ``image`` with BLIP-2, turn the caption into a meme with GPT-4,
    and return the uploaded picture with the meme text drawn at the top.

    Args:
        image: The uploaded picture as a PIL image, or ``None`` when the
            button is pressed before anything was uploaded.

    Returns:
        A new RGB PIL image with the meme text rendered on it. The caller's
        image is not modified.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Sube una imagen para generar el meme.")

    # ``convert`` returns a new image, so drawing on it never touches the
    # caller's object.
    meme_image = image.convert("RGB")

    pixel_values = vis_processors["eval"](meme_image).unsqueeze(0).to(device)
    caption = "\n".join(model.generate({"image": pixel_values}))

    # The tensor is only needed for captioning; release GPU memory before the
    # (slow) network call below.
    del pixel_values
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Use GPT-4 to generate a meme based on the caption.
    # NOTE: ``openai.ChatCompletion`` is the pre-1.0 OpenAI SDK interface.
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "Escribe un meme chistoso para una imagen a partir en la descripción dada por el usuario. No uses emojis. El meme tiene que ser corto y gracioso. El output del asistente solo debe ser el meme. Asegúrate que el meme sea tan bueno que se vuelva viral!"},
            {"role": "user", "content": caption},
        ],
        temperature=0.6,
    )

    meme_text = (response.choices[0].message.content or "").strip()
    print(meme_text)

    # Draw on the uploaded picture itself. The preprocessed tensor is model
    # input (LAVIS eval processors resize and normalize it), so scaling it by
    # 255 does not reproduce the original image.
    draw = ImageDraw.Draw(meme_image)

    # Shorter text gets a larger font.
    text_length = len(meme_text)
    if text_length <= 15:
        font_size = 24
    elif text_length <= 30:
        font_size = 18
    else:
        font_size = 14
    font = _load_meme_font(font_size)

    margin = 10  # Padding from the top and side edges, in pixels.
    line_spacing = 4
    max_width = meme_image.width - 2 * margin

    y = margin
    for line in _wrap_to_width(draw, meme_text, font, max_width):
        line_width, line_height = _measure_text(draw, line, font)
        # Center each line horizontally; never start left of the image edge.
        x = max((meme_image.width - line_width) // 2, 0)
        draw.text(
            (x, y),
            line,
            fill=(255, 255, 255),
            font=font,
            # A thin dark outline keeps white text readable on bright images.
            stroke_width=1,
            stroke_fill=(0, 0, 0),
        )
        y += line_height + line_spacing

    return meme_image
73
+
74
+
75
# Gradio UI: one image input, a button, and an image output that shows the
# generated meme.
with gr.Blocks() as demo:
    gr.Markdown("### Memero - Generador de Memes")
    gr.Markdown("Escribe un meme en español a partir de una imagen.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Imagen", type="pil")
            btn_caption = gr.Button("Generar meme")
            # ``lines`` is a Textbox option and does not apply to an Image
            # component, so it is not passed here.
            output_image = gr.Image(label="Meme")

    # Event listeners must be registered inside the Blocks context.
    btn_caption.click(
        generate_caption, inputs=[input_image], outputs=[output_image]
    )

demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
torch
torchvision
salesforce-lavis
# app.py calls the pre-1.0 `openai.ChatCompletion` interface.
openai<1.0