andreinigo commited on
Commit
2892d42
·
1 Parent(s): 3fa52cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -17
app.py CHANGED
@@ -1,10 +1,11 @@
1
  import os
 
 
2
 
3
  import gradio as gr
4
  import numpy as np
5
  import torch
6
  from lavis.models import load_model_and_preprocess
7
- from PIL import Image, ImageDraw, ImageFont
8
  import openai
9
 
10
  device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
@@ -16,6 +17,7 @@ model, vis_processors, _ = load_model_and_preprocess(
16
  openai.api_key = os.environ["OPENAI_API_KEY"]
17
 
18
  def generate_caption(image):
 
19
  image = vis_processors["eval"](image).unsqueeze(0).to(device)
20
  caption = model.generate({"image": image})
21
 
@@ -33,24 +35,26 @@ def generate_caption(image):
33
  meme_text = response.choices[0].message.content
34
  print(meme_text)
35
 
36
- image = Image.fromarray((image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)).convert('RGB')
37
-
38
- meme_image = image.copy()
39
- draw = ImageDraw.Draw(meme_image)
40
-
41
- font_size = 1
42
- text_width = 0
43
- max_width = int(224 * 0.9)
44
- while text_width < max_width:
45
- font_size += 1
46
- font = ImageFont.truetype("impact.ttf", font_size)
47
- text_width, _ = draw.textsize(meme_text, font=font)
48
 
 
 
 
 
49
  text_width, text_height = draw.textsize(meme_text, font=font)
50
- text_x = (224 - text_width) // 2
51
- text_y = (224 - text_height) // 2
52
-
53
- draw.text((text_x, 0), meme_text, font=font, fill=(255, 255, 255))
 
 
 
 
 
 
 
 
54
 
55
  meme_image = meme_image.convert('RGB')
56
 
 
1
  import os
2
+ import textwrapper
3
+ from PIL import Image, ImageDraw, ImageFont
4
 
5
  import gradio as gr
6
  import numpy as np
7
  import torch
8
  from lavis.models import load_model_and_preprocess
 
9
  import openai
10
 
11
  device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
 
17
  openai.api_key = os.environ["OPENAI_API_KEY"]
18
 
19
  def generate_caption(image):
20
+ pil_image = image.copy() # Create a copy of the input PIL image
21
  image = vis_processors["eval"](image).unsqueeze(0).to(device)
22
  caption = model.generate({"image": image})
23
 
 
35
  meme_text = response.choices[0].message.content
36
  print(meme_text)
37
 
38
+ #put the meme text on the image
39
+ draw = ImageDraw.Draw(pil_image)
 
 
 
 
 
 
 
 
 
 
40
 
41
+ # Calculate font size
42
+ text_length = len(meme_text)
43
+ font_size = 18
44
+ font = ImageFont.truetype("impact.ttf", font_size)
45
  text_width, text_height = draw.textsize(meme_text, font=font)
46
+
47
+ # Wrap the text to fit within the image width
48
+ wrapped_text = textwrap.fill(meme_text, width=int(pil_image.width / font_size))
49
+
50
+ # Calculate the position to place the text at the top and center horizontally
51
+ text_lines = wrapped_text.split('\n')
52
+ y = 10 # Adjust this value to add more or less padding from the top
53
+ for line in text_lines:
54
+ line_width, line_height = draw.textsize(line, font=font)
55
+ x = (pil_image.width - line_width) // 2
56
+ draw.text((x, y), line, fill=(255, 255, 255), font=font)
57
+ y += line_height
58
 
59
  meme_image = meme_image.convert('RGB')
60