Pratap2002 commited on
Commit
d8a9fb0
·
verified ·
1 Parent(s): c59606b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -192
app.py CHANGED
@@ -1,193 +1,165 @@
1
- import os
2
- from dotenv import load_dotenv
3
- import streamlit as st
4
- import requests
5
- from PIL import Image, ImageDraw, ImageFont
6
- import io
7
- import base64
8
- import easyocr
9
- import numpy as np
10
- import cv2
11
-
12
- # Load environment variables
13
- load_dotenv()
14
-
15
- # Set up logging
16
- import logging
17
- logging.basicConfig(level=logging.DEBUG)
18
- logger = logging.getLogger(__name__)
19
-
20
- # Hugging Face API setup
21
- API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
22
-
23
- HF_TOKEN = os.getenv("HF_TOKEN")
24
- headers = {"Authorization": f"Bearer {HF_TOKEN}"}
25
-
26
- # Initialize EasyOCR reader
27
- reader = easyocr.Reader(['en'])
28
-
29
- def query(payload):
30
- try:
31
- response = requests.post(API_URL, headers=headers, json=payload)
32
- response.raise_for_status()
33
-
34
- logger.debug(f"API response status code: {response.status_code}")
35
- logger.debug(f"API response headers: {response.headers}")
36
-
37
- content_type = response.headers.get('Content-Type', '')
38
- if 'application/json' in content_type:
39
- return response.json()
40
- elif 'image' in content_type:
41
- return response.content
42
- else:
43
- logger.error(f"Unexpected content type: {content_type}")
44
- st.error(f"Unexpected content type: {content_type}")
45
- return None
46
- except requests.exceptions.RequestException as e:
47
- logger.error(f"Request failed: {str(e)}")
48
- st.error(f"Request failed: {str(e)}")
49
- return None
50
-
51
- def increase_image_quality(image, scale_factor):
52
- width, height = image.size
53
- new_size = (width * scale_factor, height * scale_factor)
54
- return image.resize(new_size, Image.LANCZOS)
55
-
56
- def extract_text_from_image(image):
57
- img_array = np.array(image)
58
- results = reader.readtext(img_array)
59
- return ' '.join([result[1] for result in results])
60
-
61
- def remove_text_from_image(image, text_to_remove):
62
- img_array = np.array(image)
63
- results = reader.readtext(img_array)
64
-
65
- for (bbox, text, prob) in results:
66
- if text_to_remove.lower() in text.lower():
67
- top_left = tuple(map(int, bbox[0]))
68
- bottom_right = tuple(map(int, bbox[2]))
69
-
70
- # Convert image to OpenCV format
71
- img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
72
-
73
- # Create a mask for inpainting
74
- mask = np.zeros(img_cv.shape[:2], dtype=np.uint8)
75
- cv2.rectangle(mask, top_left, bottom_right, (255, 255, 255), -1)
76
-
77
- # Perform inpainting
78
- inpainted = cv2.inpaint(img_cv, mask, 3, cv2.INPAINT_TELEA)
79
-
80
- # Convert back to PIL Image
81
- image = Image.fromarray(cv2.cvtColor(inpainted, cv2.COLOR_BGR2RGB))
82
-
83
- return image, top_left, (bottom_right[0] - top_left[0], bottom_right[1] - top_left[1])
84
-
85
- logger.warning(f"Text '{text_to_remove}' not found in the image.")
86
- return image, None, None
87
-
88
- def add_text_to_image(image, text, font_size=40, font_color="#FFFFFF", position=None, size=None):
89
- draw = ImageDraw.Draw(image)
90
- try:
91
- font = ImageFont.truetype("Roboto-Bold.ttf", font_size)
92
- except IOError:
93
- logger.warning("Roboto-Bold font not found, using default font")
94
- font = ImageFont.load_default()
95
-
96
- img_width, img_height = image.size
97
- if position is None or size is None:
98
- # Calculate the center position if no position is provided
99
- bbox = font.getbbox(text)
100
- text_width = bbox[2] - bbox[0]
101
- text_height = bbox[3] - bbox[1]
102
- position = ((img_width - text_width) // 2, (img_height - text_height) // 2)
103
- size = (text_width, text_height)
104
-
105
- # Adjust font size to fit within the given size
106
- while font.getbbox(text)[2] - font.getbbox(text)[0] > size[0] or font.getbbox(text)[3] - font.getbbox(text)[1] > size[1]:
107
- font_size -= 1
108
- font = ImageFont.truetype("Roboto-Bold.ttf", font_size)
109
-
110
- # Use the exact position of the removed text
111
- logger.debug(f"Adding text at position: {position}")
112
- draw.text(position, text, font=font, fill=font_color)
113
- return image
114
-
115
- def main():
116
- st.title("Poster Generator and Editor")
117
-
118
- # Image Generation
119
- st.header("Generate Poster")
120
- poster_type = st.selectbox("Poster Type", ["Fashion", "Movie", "Event", "Advertisement", "Other"])
121
- prompt = st.text_area("Prompt")
122
- num_images = st.number_input("Number of Images", min_value=1, max_value=5, value=1)
123
- quality_factor = st.number_input("Quality Factor", min_value=1, max_value=4, value=1)
124
-
125
- if st.button("Generate Images"):
126
- if poster_type == "Other":
127
- full_prompt = f"A colorful poster with the following elements: {prompt}"
128
- else:
129
- full_prompt = f"A colorful {poster_type.lower()} poster with the following elements: {prompt}"
130
-
131
- generated_images = []
132
- for i in range(num_images):
133
- with st.spinner(f"Generating image {i+1}..."):
134
- logger.info(f"Generating image {i+1} with prompt: {full_prompt}")
135
- response = query({"inputs": full_prompt})
136
-
137
- if isinstance(response, bytes):
138
- image = Image.open(io.BytesIO(response))
139
- if quality_factor > 1:
140
- image = increase_image_quality(image, quality_factor)
141
- generated_images.append(image)
142
- else:
143
- st.error("Failed to generate image")
144
-
145
- # Display generated images
146
- for i, img in enumerate(generated_images):
147
- st.image(img, caption=f"Generated Poster {i+1}", use_column_width=True)
148
-
149
- # Save image to session state for editing
150
- img_byte_arr = io.BytesIO()
151
- img.save(img_byte_arr, format='PNG')
152
- img_byte_arr = img_byte_arr.getvalue()
153
- st.session_state[f'image_{i}'] = img_byte_arr
154
-
155
- # Image Editing
156
- st.header("Edit Poster")
157
- image_to_edit = st.selectbox("Select Image to Edit", [f"Generated Poster {i+1}" for i in range(len(st.session_state.keys()))])
158
-
159
- if image_to_edit:
160
- image_index = int(image_to_edit.split()[-1]) - 1
161
- img_bytes = st.session_state[f'image_{image_index}']
162
- img = Image.open(io.BytesIO(img_bytes))
163
- st.image(img, caption="Current Image", use_column_width=True)
164
-
165
- text_to_remove = st.text_input("Text to Remove")
166
- new_text = st.text_input("New Text")
167
- font_size = st.number_input("Font Size", min_value=1, max_value=100, value=40)
168
- font_color = st.color_picker("Font Color", "#FFFFFF")
169
-
170
- if st.button("Apply Changes"):
171
- position = None
172
- size = None
173
- if text_to_remove:
174
- img, position, size = remove_text_from_image(img, text_to_remove)
175
-
176
- if new_text:
177
- img = add_text_to_image(img, new_text, font_size, font_color, position, size)
178
-
179
- st.image(img, caption="Edited Image", use_column_width=True)
180
-
181
- # Save edited image for download
182
- img_byte_arr = io.BytesIO()
183
- img.save(img_byte_arr, format='PNG')
184
- img_byte_arr = img_byte_arr.getvalue()
185
- st.download_button(
186
- label="Download Edited Image",
187
- data=img_byte_arr,
188
- file_name="edited_poster.png",
189
- mime="image/png"
190
- )
191
-
192
- if __name__ == "__main__":
193
  main()
 
1
+ import streamlit as st
2
+ from PIL import Image, ImageDraw, ImageFont
3
+ import io
4
+ import cv2
5
+ import numpy as np
6
+ import easyocr
7
+ import os
8
+ from dotenv import load_dotenv
9
+ import requests
10
+ import logging
11
+
12
+ # Load environment variables
13
+ load_dotenv()
14
+
15
+ # Set up logging
16
+ logging.basicConfig(level=logging.DEBUG)
17
+ logger = logging.getLogger(__name__)
18
+
19
+ # Hugging Face API setup
20
+ API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
21
+ HF_TOKEN = os.getenv("HF_TOKEN")
22
+ headers = {"Authorization": f"Bearer {HF_TOKEN}"}
23
+
24
+ def load_image(image_file):
25
+ img = Image.open(image_file)
26
+ return img
27
+
28
+ def detect_text(image):
29
+ reader = easyocr.Reader(['en'])
30
+ img_array = np.array(image)
31
+ results = reader.readtext(img_array)
32
+ return [(text, box) for (box, text, _) in results]
33
+
34
+ def replace_text_in_image(image, text_to_replace, new_text):
35
+ img_array = np.array(image)
36
+ for (text, box) in detect_text(image):
37
+ if text == text_to_replace:
38
+ x, y, w, h = int(box[0][0]), int(box[0][1]), int(box[2][0] - box[0][0]), int(box[2][1] - box[0][1])
39
+ mask = np.zeros(img_array.shape[:2], dtype=np.uint8)
40
+ cv2.rectangle(mask, (x, y), (x+w, y+h), 255, -1)
41
+ img_array = cv2.inpaint(img_array, mask, 3, cv2.INPAINT_TELEA)
42
+ image = Image.fromarray(img_array)
43
+ draw = ImageDraw.Draw(image)
44
+ font = ImageFont.truetype("arial.ttf", 40)
45
+ draw.text((x, y), new_text, font=font, fill="#000000")
46
+ return image
47
+ return Image.fromarray(img_array)
48
+
49
+ def query(payload):
50
+ try:
51
+ response = requests.post(API_URL, headers=headers, json=payload)
52
+ response.raise_for_status()
53
+
54
+ logger.debug(f"API response status code: {response.status_code}")
55
+ logger.debug(f"API response headers: {response.headers}")
56
+
57
+ content_type = response.headers.get('Content-Type', '')
58
+ if 'application/json' in content_type:
59
+ return response.json()
60
+ elif 'image' in content_type:
61
+ return response.content
62
+ else:
63
+ logger.error(f"Unexpected content type: {content_type}")
64
+ st.error(f"Unexpected content type: {content_type}")
65
+ return None
66
+ except requests.exceptions.RequestException as e:
67
+ logger.error(f"Request failed: {str(e)}")
68
+ st.error(f"Request failed: {str(e)}")
69
+ return None
70
+
71
+ def increase_image_quality(image, scale_factor):
72
+ width, height = image.size
73
+ new_size = (width * scale_factor, height * scale_factor)
74
+ return image.resize(new_size, Image.LANCZOS)
75
+
76
+ def image_text_replacer():
77
+ st.header("Image Text Replacer")
78
+
79
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
80
+ if uploaded_file is not None:
81
+ image = load_image(uploaded_file)
82
+ st.image(image, caption='Uploaded Image', use_column_width=True)
83
+
84
+ text_results = detect_text(image)
85
+
86
+ st.subheader("Detected Text:")
87
+
88
+ edited_image = image.copy()
89
+
90
+ for i, (text, box) in enumerate(text_results):
91
+ if text.strip(): # Only process non-empty text
92
+ st.text(f"{i+1}. {text}")
93
+
94
+ new_text = st.text_input(f"Enter new text for '{text}':", value=text, key=f"new_text_{i}")
95
+
96
+ if st.button(f"Replace '{text}'", key=f"replace_{i}"):
97
+ edited_image = replace_text_in_image(edited_image, text, new_text)
98
+ st.image(edited_image, caption='Edited Image', use_column_width=True)
99
+
100
+ # Provide download option for the edited image
101
+ buf = io.BytesIO()
102
+ edited_image.save(buf, format="PNG")
103
+ byte_im = buf.getvalue()
104
+ st.download_button(
105
+ label="Download edited image",
106
+ data=byte_im,
107
+ file_name="edited_image.png",
108
+ mime="image/png"
109
+ )
110
+
111
+ def poster_generator():
112
+ st.header("Generate Poster")
113
+ poster_type = st.selectbox("Poster Type", ["Fashion", "Movie", "Event", "Advertisement", "Other"])
114
+ prompt = st.text_area("Prompt")
115
+ num_images = st.number_input("Number of Images", min_value=1, max_value=5, value=1)
116
+ quality_factor = st.number_input("Quality Factor", min_value=1, max_value=4, value=1)
117
+
118
+ if st.button("Generate Images"):
119
+ if poster_type == "Other":
120
+ full_prompt = f"A colorful poster with the following elements: {prompt}"
121
+ else:
122
+ full_prompt = f"A colorful {poster_type.lower()} poster with the following elements: {prompt}"
123
+
124
+ generated_images = []
125
+ for i in range(num_images):
126
+ with st.spinner(f"Generating image {i+1}..."):
127
+ logger.info(f"Generating image {i+1} with prompt: {full_prompt}")
128
+ response = query({"inputs": full_prompt})
129
+
130
+ if isinstance(response, bytes):
131
+ image = Image.open(io.BytesIO(response))
132
+ if quality_factor > 1:
133
+ image = increase_image_quality(image, quality_factor)
134
+ generated_images.append(image)
135
+ else:
136
+ st.error("Failed to generate image")
137
+
138
+ # Display generated images
139
+ for i, img in enumerate(generated_images):
140
+ st.image(img, caption=f"Generated Poster {i+1}", use_column_width=True)
141
+
142
+ # Provide download option for the generated image
143
+ buf = io.BytesIO()
144
+ img.save(buf, format="PNG")
145
+ byte_im = buf.getvalue()
146
+ st.download_button(
147
+ label=f"Download generated poster {i+1}",
148
+ data=byte_im,
149
+ file_name=f"generated_poster_{i+1}.png",
150
+ mime="image/png"
151
+ )
152
+
153
+ def main():
154
+ st.title("Image Processing App")
155
+
156
+ app_mode = st.sidebar.selectbox("Choose the app mode",
157
+ ["Image Text Replacer", "Poster Generator"])
158
+
159
+ if app_mode == "Image Text Replacer":
160
+ image_text_replacer()
161
+ elif app_mode == "Poster Generator":
162
+ poster_generator()
163
+
164
+ if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  main()