Pratap2002 commited on
Commit
f17dc8b
·
verified ·
1 Parent(s): 3ea6f67

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +192 -187
app.py CHANGED
@@ -1,188 +1,193 @@
1
- import os
2
- import streamlit as st
3
- import requests
4
- from PIL import Image, ImageDraw, ImageFont
5
- import io
6
- import base64
7
- import easyocr
8
- import numpy as np
9
- import cv2
10
-
11
- # Set up logging
12
- import logging
13
- logging.basicConfig(level=logging.DEBUG)
14
- logger = logging.getLogger(__name__)
15
-
16
- # Hugging Face API setup
17
- API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
18
-
19
- headers = {"Authorization": f"Bearer {HF_TOKEN}"}
20
-
21
- # Initialize EasyOCR reader
22
- reader = easyocr.Reader(['en'])
23
-
24
- def query(payload):
25
- try:
26
- response = requests.post(API_URL, headers=headers, json=payload)
27
- response.raise_for_status()
28
-
29
- logger.debug(f"API response status code: {response.status_code}")
30
- logger.debug(f"API response headers: {response.headers}")
31
-
32
- content_type = response.headers.get('Content-Type', '')
33
- if 'application/json' in content_type:
34
- return response.json()
35
- elif 'image' in content_type:
36
- return response.content
37
- else:
38
- logger.error(f"Unexpected content type: {content_type}")
39
- st.error(f"Unexpected content type: {content_type}")
40
- return None
41
- except requests.exceptions.RequestException as e:
42
- logger.error(f"Request failed: {str(e)}")
43
- st.error(f"Request failed: {str(e)}")
44
- return None
45
-
46
- def increase_image_quality(image, scale_factor):
47
- width, height = image.size
48
- new_size = (width * scale_factor, height * scale_factor)
49
- return image.resize(new_size, Image.LANCZOS)
50
-
51
- def extract_text_from_image(image):
52
- img_array = np.array(image)
53
- results = reader.readtext(img_array)
54
- return ' '.join([result[1] for result in results])
55
-
56
- def remove_text_from_image(image, text_to_remove):
57
- img_array = np.array(image)
58
- results = reader.readtext(img_array)
59
-
60
- for (bbox, text, prob) in results:
61
- if text_to_remove.lower() in text.lower():
62
- top_left = tuple(map(int, bbox[0]))
63
- bottom_right = tuple(map(int, bbox[2]))
64
-
65
- # Convert image to OpenCV format
66
- img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
67
-
68
- # Create a mask for inpainting
69
- mask = np.zeros(img_cv.shape[:2], dtype=np.uint8)
70
- cv2.rectangle(mask, top_left, bottom_right, (255, 255, 255), -1)
71
-
72
- # Perform inpainting
73
- inpainted = cv2.inpaint(img_cv, mask, 3, cv2.INPAINT_TELEA)
74
-
75
- # Convert back to PIL Image
76
- image = Image.fromarray(cv2.cvtColor(inpainted, cv2.COLOR_BGR2RGB))
77
-
78
- return image, top_left, (bottom_right[0] - top_left[0], bottom_right[1] - top_left[1])
79
-
80
- logger.warning(f"Text '{text_to_remove}' not found in the image.")
81
- return image, None, None
82
-
83
- def add_text_to_image(image, text, font_size=40, font_color="#FFFFFF", position=None, size=None):
84
- draw = ImageDraw.Draw(image)
85
- try:
86
- font = ImageFont.truetype("Roboto-Bold.ttf", font_size)
87
- except IOError:
88
- logger.warning("Roboto-Bold font not found, using default font")
89
- font = ImageFont.load_default()
90
-
91
- img_width, img_height = image.size
92
- if position is None or size is None:
93
- # Calculate the center position if no position is provided
94
- bbox = font.getbbox(text)
95
- text_width = bbox[2] - bbox[0]
96
- text_height = bbox[3] - bbox[1]
97
- position = ((img_width - text_width) // 2, (img_height - text_height) // 2)
98
- size = (text_width, text_height)
99
-
100
- # Adjust font size to fit within the given size
101
- while font.getbbox(text)[2] - font.getbbox(text)[0] > size[0] or font.getbbox(text)[3] - font.getbbox(text)[1] > size[1]:
102
- font_size -= 1
103
- font = ImageFont.truetype("Roboto-Bold.ttf", font_size)
104
-
105
- # Use the exact position of the removed text
106
- logger.debug(f"Adding text at position: {position}")
107
- draw.text(position, text, font=font, fill=font_color)
108
- return image
109
-
110
- def main():
111
- st.title("Poster Generator and Editor")
112
-
113
- # Image Generation
114
- st.header("Generate Poster")
115
- poster_type = st.selectbox("Poster Type", ["Fashion", "Movie", "Event", "Advertisement", "Other"])
116
- prompt = st.text_area("Prompt")
117
- num_images = st.number_input("Number of Images", min_value=1, max_value=5, value=1)
118
- quality_factor = st.number_input("Quality Factor", min_value=1, max_value=4, value=1)
119
-
120
- if st.button("Generate Images"):
121
- if poster_type == "Other":
122
- full_prompt = f"A colorful poster with the following elements: {prompt}"
123
- else:
124
- full_prompt = f"A colorful {poster_type.lower()} poster with the following elements: {prompt}"
125
-
126
- generated_images = []
127
- for i in range(num_images):
128
- with st.spinner(f"Generating image {i+1}..."):
129
- logger.info(f"Generating image {i+1} with prompt: {full_prompt}")
130
- response = query({"inputs": full_prompt})
131
-
132
- if isinstance(response, bytes):
133
- image = Image.open(io.BytesIO(response))
134
- if quality_factor > 1:
135
- image = increase_image_quality(image, quality_factor)
136
- generated_images.append(image)
137
- else:
138
- st.error("Failed to generate image")
139
-
140
- # Display generated images
141
- for i, img in enumerate(generated_images):
142
- st.image(img, caption=f"Generated Poster {i+1}", use_column_width=True)
143
-
144
- # Save image to session state for editing
145
- img_byte_arr = io.BytesIO()
146
- img.save(img_byte_arr, format='PNG')
147
- img_byte_arr = img_byte_arr.getvalue()
148
- st.session_state[f'image_{i}'] = img_byte_arr
149
-
150
- # Image Editing
151
- st.header("Edit Poster")
152
- image_to_edit = st.selectbox("Select Image to Edit", [f"Generated Poster {i+1}" for i in range(len(st.session_state.keys()))])
153
-
154
- if image_to_edit:
155
- image_index = int(image_to_edit.split()[-1]) - 1
156
- img_bytes = st.session_state[f'image_{image_index}']
157
- img = Image.open(io.BytesIO(img_bytes))
158
- st.image(img, caption="Current Image", use_column_width=True)
159
-
160
- text_to_remove = st.text_input("Text to Remove")
161
- new_text = st.text_input("New Text")
162
- font_size = st.number_input("Font Size", min_value=1, max_value=100, value=40)
163
- font_color = st.color_picker("Font Color", "#FFFFFF")
164
-
165
- if st.button("Apply Changes"):
166
- position = None
167
- size = None
168
- if text_to_remove:
169
- img, position, size = remove_text_from_image(img, text_to_remove)
170
-
171
- if new_text:
172
- img = add_text_to_image(img, new_text, font_size, font_color, position, size)
173
-
174
- st.image(img, caption="Edited Image", use_column_width=True)
175
-
176
- # Save edited image for download
177
- img_byte_arr = io.BytesIO()
178
- img.save(img_byte_arr, format='PNG')
179
- img_byte_arr = img_byte_arr.getvalue()
180
- st.download_button(
181
- label="Download Edited Image",
182
- data=img_byte_arr,
183
- file_name="edited_poster.png",
184
- mime="image/png"
185
- )
186
-
187
- if __name__ == "__main__":
 
 
 
 
 
188
  main()
 
1
+ import os
2
+ import streamlit as st
3
+ import requests
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ import io
6
+ import base64
7
+ import easyocr
8
+ import numpy as np
9
+ import cv2
10
+ from dotenv import load_dotenv
11
+
12
+ load_dotenv()
13
+ HF_TOKEN = os.getenv("HF_TOKEN")
14
+
15
+
16
+ # Set up logging
17
+ import logging
18
+ logging.basicConfig(level=logging.DEBUG)
19
+ logger = logging.getLogger(__name__)
20
+
21
+ # Hugging Face API setup
22
+ API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
23
+
24
+ headers = {"Authorization": f"Bearer {HF_TOKEN}"}
25
+
26
+ # Initialize EasyOCR reader
27
+ reader = easyocr.Reader(['en'])
28
+
29
+ def query(payload):
30
+ try:
31
+ response = requests.post(API_URL, headers=headers, json=payload)
32
+ response.raise_for_status()
33
+
34
+ logger.debug(f"API response status code: {response.status_code}")
35
+ logger.debug(f"API response headers: {response.headers}")
36
+
37
+ content_type = response.headers.get('Content-Type', '')
38
+ if 'application/json' in content_type:
39
+ return response.json()
40
+ elif 'image' in content_type:
41
+ return response.content
42
+ else:
43
+ logger.error(f"Unexpected content type: {content_type}")
44
+ st.error(f"Unexpected content type: {content_type}")
45
+ return None
46
+ except requests.exceptions.RequestException as e:
47
+ logger.error(f"Request failed: {str(e)}")
48
+ st.error(f"Request failed: {str(e)}")
49
+ return None
50
+
51
+ def increase_image_quality(image, scale_factor):
52
+ width, height = image.size
53
+ new_size = (width * scale_factor, height * scale_factor)
54
+ return image.resize(new_size, Image.LANCZOS)
55
+
56
+ def extract_text_from_image(image):
57
+ img_array = np.array(image)
58
+ results = reader.readtext(img_array)
59
+ return ' '.join([result[1] for result in results])
60
+
61
+ def remove_text_from_image(image, text_to_remove):
62
+ img_array = np.array(image)
63
+ results = reader.readtext(img_array)
64
+
65
+ for (bbox, text, prob) in results:
66
+ if text_to_remove.lower() in text.lower():
67
+ top_left = tuple(map(int, bbox[0]))
68
+ bottom_right = tuple(map(int, bbox[2]))
69
+
70
+ # Convert image to OpenCV format
71
+ img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
72
+
73
+ # Create a mask for inpainting
74
+ mask = np.zeros(img_cv.shape[:2], dtype=np.uint8)
75
+ cv2.rectangle(mask, top_left, bottom_right, (255, 255, 255), -1)
76
+
77
+ # Perform inpainting
78
+ inpainted = cv2.inpaint(img_cv, mask, 3, cv2.INPAINT_TELEA)
79
+
80
+ # Convert back to PIL Image
81
+ image = Image.fromarray(cv2.cvtColor(inpainted, cv2.COLOR_BGR2RGB))
82
+
83
+ return image, top_left, (bottom_right[0] - top_left[0], bottom_right[1] - top_left[1])
84
+
85
+ logger.warning(f"Text '{text_to_remove}' not found in the image.")
86
+ return image, None, None
87
+
88
+ def add_text_to_image(image, text, font_size=40, font_color="#FFFFFF", position=None, size=None):
89
+ draw = ImageDraw.Draw(image)
90
+ try:
91
+ font = ImageFont.truetype("Roboto-Bold.ttf", font_size)
92
+ except IOError:
93
+ logger.warning("Roboto-Bold font not found, using default font")
94
+ font = ImageFont.load_default()
95
+
96
+ img_width, img_height = image.size
97
+ if position is None or size is None:
98
+ # Calculate the center position if no position is provided
99
+ bbox = font.getbbox(text)
100
+ text_width = bbox[2] - bbox[0]
101
+ text_height = bbox[3] - bbox[1]
102
+ position = ((img_width - text_width) // 2, (img_height - text_height) // 2)
103
+ size = (text_width, text_height)
104
+
105
+ # Adjust font size to fit within the given size
106
+ while font.getbbox(text)[2] - font.getbbox(text)[0] > size[0] or font.getbbox(text)[3] - font.getbbox(text)[1] > size[1]:
107
+ font_size -= 1
108
+ font = ImageFont.truetype("Roboto-Bold.ttf", font_size)
109
+
110
+ # Use the exact position of the removed text
111
+ logger.debug(f"Adding text at position: {position}")
112
+ draw.text(position, text, font=font, fill=font_color)
113
+ return image
114
+
115
+ def main():
116
+ st.title("Poster Generator and Editor")
117
+
118
+ # Image Generation
119
+ st.header("Generate Poster")
120
+ poster_type = st.selectbox("Poster Type", ["Fashion", "Movie", "Event", "Advertisement", "Other"])
121
+ prompt = st.text_area("Prompt")
122
+ num_images = st.number_input("Number of Images", min_value=1, max_value=5, value=1)
123
+ quality_factor = st.number_input("Quality Factor", min_value=1, max_value=4, value=1)
124
+
125
+ if st.button("Generate Images"):
126
+ if poster_type == "Other":
127
+ full_prompt = f"A colorful poster with the following elements: {prompt}"
128
+ else:
129
+ full_prompt = f"A colorful {poster_type.lower()} poster with the following elements: {prompt}"
130
+
131
+ generated_images = []
132
+ for i in range(num_images):
133
+ with st.spinner(f"Generating image {i+1}..."):
134
+ logger.info(f"Generating image {i+1} with prompt: {full_prompt}")
135
+ response = query({"inputs": full_prompt})
136
+
137
+ if isinstance(response, bytes):
138
+ image = Image.open(io.BytesIO(response))
139
+ if quality_factor > 1:
140
+ image = increase_image_quality(image, quality_factor)
141
+ generated_images.append(image)
142
+ else:
143
+ st.error("Failed to generate image")
144
+
145
+ # Display generated images
146
+ for i, img in enumerate(generated_images):
147
+ st.image(img, caption=f"Generated Poster {i+1}", use_column_width=True)
148
+
149
+ # Save image to session state for editing
150
+ img_byte_arr = io.BytesIO()
151
+ img.save(img_byte_arr, format='PNG')
152
+ img_byte_arr = img_byte_arr.getvalue()
153
+ st.session_state[f'image_{i}'] = img_byte_arr
154
+
155
+ # Image Editing
156
+ st.header("Edit Poster")
157
+ image_to_edit = st.selectbox("Select Image to Edit", [f"Generated Poster {i+1}" for i in range(len(st.session_state.keys()))])
158
+
159
+ if image_to_edit:
160
+ image_index = int(image_to_edit.split()[-1]) - 1
161
+ img_bytes = st.session_state[f'image_{image_index}']
162
+ img = Image.open(io.BytesIO(img_bytes))
163
+ st.image(img, caption="Current Image", use_column_width=True)
164
+
165
+ text_to_remove = st.text_input("Text to Remove")
166
+ new_text = st.text_input("New Text")
167
+ font_size = st.number_input("Font Size", min_value=1, max_value=100, value=40)
168
+ font_color = st.color_picker("Font Color", "#FFFFFF")
169
+
170
+ if st.button("Apply Changes"):
171
+ position = None
172
+ size = None
173
+ if text_to_remove:
174
+ img, position, size = remove_text_from_image(img, text_to_remove)
175
+
176
+ if new_text:
177
+ img = add_text_to_image(img, new_text, font_size, font_color, position, size)
178
+
179
+ st.image(img, caption="Edited Image", use_column_width=True)
180
+
181
+ # Save edited image for download
182
+ img_byte_arr = io.BytesIO()
183
+ img.save(img_byte_arr, format='PNG')
184
+ img_byte_arr = img_byte_arr.getvalue()
185
+ st.download_button(
186
+ label="Download Edited Image",
187
+ data=img_byte_arr,
188
+ file_name="edited_poster.png",
189
+ mime="image/png"
190
+ )
191
+
192
+ if __name__ == "__main__":
193
  main()