Commit
·
2892d42
1
Parent(s):
3fa52cc
Update app.py
Browse files
app.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1 |
import os
|
|
|
|
|
2 |
|
3 |
import gradio as gr
|
4 |
import numpy as np
|
5 |
import torch
|
6 |
from lavis.models import load_model_and_preprocess
|
7 |
-
from PIL import Image, ImageDraw, ImageFont
|
8 |
import openai
|
9 |
|
10 |
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
|
@@ -16,6 +17,7 @@ model, vis_processors, _ = load_model_and_preprocess(
|
|
16 |
openai.api_key = os.environ["OPENAI_API_KEY"]
|
17 |
|
18 |
def generate_caption(image):
|
|
|
19 |
image = vis_processors["eval"](image).unsqueeze(0).to(device)
|
20 |
caption = model.generate({"image": image})
|
21 |
|
@@ -33,24 +35,26 @@ def generate_caption(image):
|
|
33 |
meme_text = response.choices[0].message.content
|
34 |
print(meme_text)
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
meme_image = image.copy()
|
39 |
-
draw = ImageDraw.Draw(meme_image)
|
40 |
-
|
41 |
-
font_size = 1
|
42 |
-
text_width = 0
|
43 |
-
max_width = int(224 * 0.9)
|
44 |
-
while text_width < max_width:
|
45 |
-
font_size += 1
|
46 |
-
font = ImageFont.truetype("impact.ttf", font_size)
|
47 |
-
text_width, _ = draw.textsize(meme_text, font=font)
|
48 |
|
|
|
|
|
|
|
|
|
49 |
text_width, text_height = draw.textsize(meme_text, font=font)
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
meme_image = meme_image.convert('RGB')
|
56 |
|
|
|
1 |
import os
|
2 |
+
import textwrapper
|
3 |
+
from PIL import Image, ImageDraw, ImageFont
|
4 |
|
5 |
import gradio as gr
|
6 |
import numpy as np
|
7 |
import torch
|
8 |
from lavis.models import load_model_and_preprocess
|
|
|
9 |
import openai
|
10 |
|
11 |
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
|
|
|
17 |
openai.api_key = os.environ["OPENAI_API_KEY"]
|
18 |
|
19 |
def generate_caption(image):
|
20 |
+
pil_image = image.copy() # Create a copy of the input PIL image
|
21 |
image = vis_processors["eval"](image).unsqueeze(0).to(device)
|
22 |
caption = model.generate({"image": image})
|
23 |
|
|
|
35 |
meme_text = response.choices[0].message.content
|
36 |
print(meme_text)
|
37 |
|
38 |
+
#put the meme text on the image
|
39 |
+
draw = ImageDraw.Draw(pil_image)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
+
# Calculate font size
|
42 |
+
text_length = len(meme_text)
|
43 |
+
font_size = 18
|
44 |
+
font = ImageFont.truetype("impact.ttf", font_size)
|
45 |
text_width, text_height = draw.textsize(meme_text, font=font)
|
46 |
+
|
47 |
+
# Wrap the text to fit within the image width
|
48 |
+
wrapped_text = textwrap.fill(meme_text, width=int(pil_image.width / font_size))
|
49 |
+
|
50 |
+
# Calculate the position to place the text at the top and center horizontally
|
51 |
+
text_lines = wrapped_text.split('\n')
|
52 |
+
y = 10 # Adjust this value to add more or less padding from the top
|
53 |
+
for line in text_lines:
|
54 |
+
line_width, line_height = draw.textsize(line, font=font)
|
55 |
+
x = (pil_image.width - line_width) // 2
|
56 |
+
draw.text((x, y), line, fill=(255, 255, 255), font=font)
|
57 |
+
y += line_height
|
58 |
|
59 |
meme_image = meme_image.convert('RGB')
|
60 |
|