justREE commited on
Commit
70a4e6d
·
verified ·
1 Parent(s): 321b768

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +187 -85
src/streamlit_app.py CHANGED
@@ -9,68 +9,85 @@ from PIL import Image # Python Imaging Library for image loading
9
  import numpy as np # numerical operations, especially array handling
10
 
11
  # 1) CACHE & LOAD MODELS
 
12
  @st.cache_resource(show_spinner=False)
13
  def load_captioner():
14
- # Loads BLIP image-to-text model; cached so it loads only once.
15
- # Returns: a function captioner(image: PIL.Image) -> List[Dict],
 
16
  return pipeline(
17
  "image-to-text",
18
  model="Salesforce/blip-image-captioning-base",
19
- device="cpu" # Can change to "cuda" if GPU is available
20
  )
21
 
22
  @st.cache_resource(show_spinner=False)
23
  def load_story_pipe():
24
- # Loads FLAN-T5 text-to-text model for story generation; cached once.
25
  # Returns: a function story_pipe(prompt: str, **kwargs) -> List[Dict].
 
26
  return pipeline(
27
  "text2text-generation",
28
  model="google/flan-t5-base",
29
- device="cpu" # Can change to "cuda" if GPU is available
30
  )
31
 
32
  @st.cache_resource(show_spinner=False)
33
  def load_tts_pipe():
34
- # Loads Meta MMS-TTS text-to-speech model; cached once.
35
  # Returns: a function tts_pipe(text: str) -> List[Dict] with "audio" and "sampling_rate".
 
36
  return pipeline(
37
  "text-to-speech",
38
  model="facebook/mms-tts-eng",
39
- device="cpu" # Can change to "cuda" if GPU is available
40
  )
41
 
42
  # 2) HELPER FUNCTIONS
43
  def sentence_case(text: str) -> str:
44
- # Splits text into sentences on .!? delimiters,
45
- # capitalizes the first character of each sentence,
46
- # then rejoins into a single string.
47
- parts = re.split(r'([.!?])', text) # ["hello", ".", " world", "!"]
 
 
 
 
48
  out = []
 
49
  for i in range(0, len(parts) - 1, 2):
50
- sentence = parts[i].strip().capitalize() # capitalize first letter
51
- delimiter = parts[i + 1] # punctuation
52
- # Ensure a space before the sentence if it wasn't the very first part
53
- if out and not sentence.startswith(' ') and out[-1][-1] not in '.!?':
54
- out.append(f" {sentence}{delimiter}")
55
- else:
56
- out.append(f"{sentence}{delimiter}")
 
 
 
57
 
58
- # If trailing text without punctuation exists, capitalize and append it.
59
  if len(parts) % 2:
60
- last = parts[-1].strip().capitalize()
61
- if last:
62
- # Ensure a space before if needed
63
- if out and not last.startswith(' ') and out[-1][-1] not in '.!?':
64
- out.append(f" {last}")
65
- else:
66
- out.append(last)
67
-
68
- # Clean up potential multiple spaces resulting from split/join
69
  return " ".join(" ".join(out).split())
70
 
71
 
72
  def caption_image(img: Image.Image, captioner) -> str:
73
- # Given a PIL image and a captioner pipeline, returns a single-line caption.
 
 
 
 
 
 
74
  results = captioner(img) # run model
75
  if not results:
76
  return ""
@@ -78,50 +95,86 @@ def caption_image(img: Image.Image, captioner) -> str:
78
  return results[0].get("generated_text", "")
79
 
80
  def story_from_caption(caption: str, pipe) -> str:
81
- # Given a caption string and a text2text pipeline, returns a ~100-word story.
 
 
 
 
 
82
  prompt = f"Write a vivid, imaginative ~100-word story about this scene: {caption}"
 
 
 
83
  results = pipe(
84
  prompt,
85
  max_length=120, # increased max length slightly
86
- min_length=80, # minimum generated tokens
87
- do_sample=True, # enable sampling
88
  top_k=100, # sample from top_k tokens
89
  top_p=0.9, # nucleus sampling threshold
90
- temperature=0.7, # sampling temperature
91
  repetition_penalty=1.1, # discourage repetition
92
  no_repeat_ngram_size=4, # block repeated n-grams
93
  early_stopping=False
94
  )
95
  raw = results[0]["generated_text"].strip() # full generated text
 
96
  # strip out the prompt if it echoes back - make comparison case-insensitive
97
- if raw.lower().startswith(prompt.lower()):
98
- raw = raw[len(prompt):].strip()
 
 
 
 
99
 
100
  # trim to last complete sentence ending in . ! or ?
101
- match = re.search(r'[.!?]', raw[::-1]) # Search for the first punctuation from the end
 
102
  if match:
103
- raw = raw[:len(raw) - match.start()] # Trim at that position
104
- elif len(raw) > 80: # If no punctuation found but story is long, trim to a reasonable length
105
- raw = raw[:80] + "..."
 
 
 
 
106
 
107
  return sentence_case(raw)
108
 
 
109
  def tts_bytes(text: str, tts_pipe) -> bytes:
110
- # Given a text string and a tts pipeline, returns WAV-format bytes.
111
- # Clean up text for TTS - remove leading/trailing quotes, etc.
 
 
 
 
 
 
112
  cleaned_text = re.sub(r'^["\']|["\']$', '', text).strip()
113
- # Basic punctuation cleaning (optional, depending on TTS model)
114
- cleaned_text = re.sub(r'\.{2,}', '.', cleaned_text) # Replace multiple periods with one
115
- cleaned_text = cleaned_text.replace('…', '...') # Replace ellipsis char with dots
116
- # Add a period if the text doesn't end with punctuation (helps TTS model finalize)
117
  if cleaned_text and cleaned_text[-1] not in '.!?':
118
  cleaned_text += '.'
 
 
 
 
 
 
119
 
120
  output = tts_pipe(cleaned_text)
121
  # pipeline may return list or single dict
122
  result = output[0] if isinstance(output, list) else output
123
- audio_array = result["audio"] # numpy array: (channels, samples) or (samples,)
124
- rate = result["sampling_rate"] # sampling rate integer
 
 
 
 
 
125
 
126
  # ensure audio_array is 2D (samples, channels) for consistent handling
127
  if audio_array.ndim == 1:
@@ -135,74 +188,123 @@ def tts_bytes(text: str, tts_pipe) -> bytes:
135
 
136
  buffer = io.BytesIO()
137
  wf = wave.open(buffer, "wb")
138
- wf.setnchannels(data.shape[1]) # number of channels
139
- wf.setsampwidth(2) # 16 bits = 2 bytes
140
- wf.setframerate(rate) # samples per second
141
- wf.writeframes(pcm.tobytes()) # write PCM data
142
- wf.close()
 
 
 
143
  buffer.seek(0)
144
  return buffer.read() # return raw WAV bytes
145
 
146
  # 3) STREAMLIT USER INTERFACE
 
147
  st.set_page_config(page_title="Imagine & Narrate", page_icon="✨", layout="centered")
 
 
148
  st.title("✨ Imagine & Narrate")
149
  st.write("Upload any image below to see AI imagine and narrate a story about it!")
150
 
151
- # -- Upload image widget --
152
  uploaded = st.file_uploader(
153
  "Choose an image file",
154
- type=["jpg", "jpeg", "png"]
 
 
155
  )
 
 
156
  if not uploaded:
157
  st.info("➡️ Upload an image above to start the magic!")
158
- st.stop()
159
-
160
- # Load the uploaded file into a PIL Image
161
- try:
162
- img = Image.open(uploaded)
163
- except Exception as e:
164
- st.error(f"Error loading image: {e}")
165
- st.stop()
166
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
- # -- Step 1: Display the image --
169
  st.subheader("📸 Your Visual Input")
170
- st.image(img, use_container_width=True)
171
  st.divider()
172
 
173
- # -- Step 2: Generate and display caption --
174
  st.subheader("🧠 Generating Insights")
175
- with st.spinner("Scanning image for key elements…"):
176
- captioner = load_captioner()
177
- raw_caption = caption_image(img, captioner)
178
- if not raw_caption:
179
- st.warning("Could not generate a caption for the image.")
180
- st.stop()
181
- caption = sentence_case(raw_caption)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  st.markdown(f"**Identified Scene:** {caption}")
183
  st.divider()
184
 
185
- # -- Step 3: Generate and display story --
186
  st.subheader("📖 Crafting a Narrative")
187
- with st.spinner("Writing a compelling story…"):
188
- story_pipe = load_story_pipe()
189
- story = story_from_caption(caption, story_pipe)
190
- if not story or story.strip() == '...': # Check for empty or minimal story
191
- st.warning("Could not generate a meaningful story from the caption.")
192
- st.stop()
 
 
 
 
 
 
 
 
 
 
 
 
193
  st.write(story)
194
  st.divider()
195
 
196
- # -- Step 4: Synthesize and play audio --
197
  st.subheader("👂 Hear the Story")
198
- with st.spinner("Synthesizing audio narration…"):
199
- tts_pipe = load_tts_pipe()
200
  try:
 
 
201
  audio_bytes = tts_bytes(story, tts_pipe)
202
- st.audio(audio_bytes, format="audio/wav")
 
 
 
 
 
 
 
203
  except Exception as e:
204
- st.error(f"Error generating audio: {e}")
 
205
 
206
 
207
- # Celebration animation
208
  st.balloons()
 
9
  import numpy as np # numerical operations, especially array handling
10
 
11
  # 1) CACHE & LOAD MODELS
12
+ # Use cache_resource for models/objects that should be loaded once per session/run
13
  @st.cache_resource(show_spinner=False)
14
  def load_captioner():
15
+ """Loads BLIP image-to-text model; cached so it loads only once."""
16
+ # Returns: a function captioner(image: PIL.Image) -> List[Dict]
17
+ # Using device="cpu" for broader compatibility. Change to "cuda" for GPU.
18
  return pipeline(
19
  "image-to-text",
20
  model="Salesforce/blip-image-captioning-base",
21
+ device="cpu"
22
  )
23
 
24
  @st.cache_resource(show_spinner=False)
25
  def load_story_pipe():
26
+ """Loads FLAN-T5 text-to-text model for story generation; cached once."""
27
  # Returns: a function story_pipe(prompt: str, **kwargs) -> List[Dict].
28
+ # Using device="cpu" for broader compatibility. Change to "cuda" for GPU.
29
  return pipeline(
30
  "text2text-generation",
31
  model="google/flan-t5-base",
32
+ device="cpu"
33
  )
34
 
35
  @st.cache_resource(show_spinner=False)
36
  def load_tts_pipe():
37
+ """Loads Meta MMS-TTS text-to-speech model; cached once."""
38
  # Returns: a function tts_pipe(text: str) -> List[Dict] with "audio" and "sampling_rate".
39
+ # Using device="cpu" for broader compatibility. Change to "cuda" for GPU.
40
  return pipeline(
41
  "text-to-speech",
42
  model="facebook/mms-tts-eng",
43
+ device="cpu"
44
  )
45
 
46
  # 2) HELPER FUNCTIONS
47
  def sentence_case(text: str) -> str:
48
+ """
49
+ Splits text into sentences on .!? delimiters,
50
+ capitalizes the first character of each sentence,
51
+ then rejoins into a single string. Handles edge cases like leading/trailing spaces.
52
+ """
53
+ # Split while keeping the delimiters
54
+ parts = re.split(r'([.!?])', text)
55
+
56
  out = []
57
+ # Iterate through parts, taking text followed by delimiter
58
  for i in range(0, len(parts) - 1, 2):
59
+ sentence = parts[i].strip() # Get the sentence text and remove surrounding whitespace
60
+ delimiter = parts[i + 1] # Get the delimiter
61
+ if sentence: # Only process if there's actual text
62
+ # Capitalize the first letter of the cleaned sentence part
63
+ formatted_sentence = sentence[0].upper() + sentence[1:]
64
+ # Append the formatted sentence and its delimiter
65
+ out.append(f"{formatted_sentence}{delimiter}")
66
+ elif delimiter.strip(): # Handle cases where there's just a delimiter (e.g., "...")
67
+ out.append(delimiter)
68
+
69
 
70
+ # Handle any remaining part if the text didn't end with a delimiter
71
  if len(parts) % 2:
72
+ last_part = parts[-1].strip()
73
+ if last_part:
74
+ # Capitalize the first letter of the last part
75
+ formatted_last_part = last_part[0].upper() + last_part[1:]
76
+ out.append(formatted_last_part)
77
+
78
+ # Join parts and clean up potential excess spaces
79
+ # Join with a space first, then split and rejoin to handle multiple spaces
 
80
  return " ".join(" ".join(out).split())
81
 
82
 
83
  def caption_image(img: Image.Image, captioner) -> str:
84
+ """
85
+ Given a PIL image and a captioner pipeline, returns a single-line caption.
86
+ """
87
+ # Ensure image is in RGB format, as some models might expect it
88
+ if img.mode != "RGB":
89
+ img = img.convert("RGB")
90
+
91
  results = captioner(img) # run model
92
  if not results:
93
  return ""
 
95
  return results[0].get("generated_text", "")
96
 
97
  def story_from_caption(caption: str, pipe) -> str:
98
+ """
99
+ Given a caption string and a text2text pipeline, returns a ~100-word story.
100
+ """
101
+ if not caption:
102
+ return "Could not generate a story without a caption."
103
+
104
  prompt = f"Write a vivid, imaginative ~100-word story about this scene: {caption}"
105
+ # Add a directive for slightly more coherence
106
+ prompt += "\n\nWrite a creative and descriptive short story."
107
+
108
  results = pipe(
109
  prompt,
110
  max_length=120, # increased max length slightly
111
+ min_length=60, # reduced min length slightly for robustness
112
+ do_sample=True, # enable sampling for creativity
113
  top_k=100, # sample from top_k tokens
114
  top_p=0.9, # nucleus sampling threshold
115
+ temperature=0.8, # slightly increased temperature for more randomness
116
  repetition_penalty=1.1, # discourage repetition
117
  no_repeat_ngram_size=4, # block repeated n-grams
118
  early_stopping=False
119
  )
120
  raw = results[0]["generated_text"].strip() # full generated text
121
+
122
  # strip out the prompt if it echoes back - make comparison case-insensitive
123
+ # Check if the generated text starts with a substantial part of the prompt
124
+ prompt_check_length = min(len(prompt) // 2, 50) # Check against first half or 50 chars
125
+ if raw.lower().startswith(prompt.lower()[:prompt_check_length]):
126
+ # Attempt to remove the echoed prompt more robustly
127
+ raw = re.sub(re.escape(prompt), '', raw, count=1, flags=re.IGNORECASE).strip()
128
+
129
 
130
  # trim to last complete sentence ending in . ! or ?
131
+ # Search for the first punctuation from the end of the string
132
+ match = re.search(r'[.!?]', raw[::-1])
133
  if match:
134
+ # Trim the string at the position of the found punctuation
135
+ raw = raw[:len(raw) - match.start()]
136
+ elif len(raw) > 80: # If no punctuation found and story is long, trim and add ellipsis
137
+ raw = raw[:raw.rfind(' ') if raw.rfind(' ') != -1 and raw.rfind(' ') > 60 else 80] + "..."
138
+ elif len(raw) < 20: # If the story is very short and has no punctuation
139
+ raw += "..." # Add ellipsis to indicate it might be incomplete
140
+
141
 
142
  return sentence_case(raw)
143
 
144
+
145
  def tts_bytes(text: str, tts_pipe) -> bytes:
146
+ """
147
+ Given a text string and a tts pipeline, returns WAV-format bytes.
148
+ Cleans text for better TTS performance and handles audio data conversion.
149
+ """
150
+ if not text:
151
+ return b"" # Return empty bytes if no text
152
+
153
+ # Clean up text for TTS - remove leading/trailing quotes, extra whitespace
154
  cleaned_text = re.sub(r'^["\']|["\']$', '', text).strip()
155
+ # Replace multiple periods, handle ellipsis character
156
+ cleaned_text = re.sub(r'\.{2,}', '.', cleaned_text)
157
+ cleaned_text = cleaned_text.replace('…', '...')
158
+ # Ensure text ends with punctuation for better natural speech flow
159
  if cleaned_text and cleaned_text[-1] not in '.!?':
160
  cleaned_text += '.'
161
+ # Remove excessive internal whitespace
162
+ cleaned_text = " ".join(cleaned_text.split())
163
+
164
+ if not cleaned_text:
165
+ return b"" # Return empty bytes if cleaning results in empty string
166
+
167
 
168
  output = tts_pipe(cleaned_text)
169
  # pipeline may return list or single dict
170
  result = output[0] if isinstance(output, list) else output
171
+
172
+ audio_array = result.get("audio") # numpy array: (channels, samples) or (samples,)
173
+ rate = result.get("sampling_rate") # sampling rate integer
174
+
175
+ if audio_array is None or rate is None:
176
+ st.error("TTS pipeline did not return expected audio data.")
177
+ return b""
178
 
179
  # ensure audio_array is 2D (samples, channels) for consistent handling
180
  if audio_array.ndim == 1:
 
188
 
189
  buffer = io.BytesIO()
190
  wf = wave.open(buffer, "wb")
191
+ try:
192
+ wf.setnchannels(data.shape[1] if data.ndim == 2 else 1) # set number of channels
193
+ wf.setsampwidth(2) # 16 bits = 2 bytes
194
+ wf.setframerate(rate) # samples per second
195
+ wf.writeframes(pcm.tobytes()) # write PCM data
196
+ finally:
197
+ wf.close() # Ensure the wave file object is closed
198
+
199
  buffer.seek(0)
200
  return buffer.read() # return raw WAV bytes
201
 
202
  # 3) STREAMLIT USER INTERFACE
203
+ # --- Page Config ---
204
  st.set_page_config(page_title="Imagine & Narrate", page_icon="✨", layout="centered")
205
+
206
+ # --- Title and Intro ---
207
  st.title("✨ Imagine & Narrate")
208
  st.write("Upload any image below to see AI imagine and narrate a story about it!")
209
 
210
+ # --- File Uploader ---
211
  uploaded = st.file_uploader(
212
  "Choose an image file",
213
+ type=["jpg", "jpeg", "png"] # Specify allowed types
214
+ # Add an optional help text
215
+ # help="Supported formats: JPG, JPEG, PNG."
216
  )
217
+
218
+ # --- Handle No Upload ---
219
  if not uploaded:
220
  st.info("➡️ Upload an image above to start the magic!")
221
+ st.stop() # Halt execution until file is uploaded
 
 
 
 
 
 
 
222
 
223
+ # --- Image Loading ---
224
+ # Use st.status for a nicer progress/status display during potentially slow steps
225
+ with st.status("Loading image...", expanded=True) as status:
226
+ try:
227
+ status.update(label="Opening image file...", state="running")
228
+ img = Image.open(uploaded)
229
+ status.update(label="Image loaded successfully!", state="complete", expanded=False)
230
+ except Exception as e:
231
+ status.update(label=f"Error loading image: {e}", state="error")
232
+ st.error(f"Could not load the image. Please try a different file. Error: {e}")
233
+ st.stop() # Stop if image loading fails
234
 
235
+ # --- Display Image ---
236
  st.subheader("📸 Your Visual Input")
237
+ st.image(img, use_container_width=True, caption=uploaded.name) # Add caption with filename
238
  st.divider()
239
 
240
+ # --- Step 2: Generate Caption ---
241
  st.subheader("🧠 Generating Insights")
242
+ # Using st.status again for the pipeline steps
243
+ with st.status("Scanning image for key elements…", expanded=True) as status:
244
+ try:
245
+ status.update(label="Running image captioning model...", state="running")
246
+ captioner = load_captioner()
247
+ raw_caption = caption_image(img, captioner)
248
+
249
+ if not raw_caption:
250
+ status.update(label="Image analysis failed.", state="error")
251
+ st.warning("Could not generate a caption for the image.")
252
+ st.stop()
253
+
254
+ caption = sentence_case(raw_caption)
255
+ status.update(label="Image analyzed, caption generated!", state="complete", expanded=False)
256
+
257
+ except Exception as e:
258
+ status.update(label=f"Error during image analysis: {e}", state="error")
259
+ st.error(f"An error occurred during image analysis: {e}")
260
+ st.stop()
261
+
262
+
263
  st.markdown(f"**Identified Scene:** {caption}")
264
  st.divider()
265
 
266
+ # --- Step 3: Generate Story ---
267
  st.subheader("📖 Crafting a Narrative")
268
+ with st.status("Writing a compelling story…", expanded=True) as status:
269
+ try:
270
+ status.update(label="Running story generation model...", state="running")
271
+ story_pipe = load_story_pipe()
272
+ story = story_from_caption(caption, story_pipe)
273
+
274
+ if not story or story.strip() in ['.', '..', '...']: # Check for empty or minimal story
275
+ status.update(label="Story generation failed.", state="error")
276
+ st.warning("Could not generate a meaningful story from the caption.")
277
+ st.stop()
278
+
279
+ status.update(label="Story crafted!", state="complete", expanded=False)
280
+
281
+ except Exception as e:
282
+ status.update(label=f"Error during story generation: {e}", state="error")
283
+ st.error(f"An error occurred during story generation: {e}")
284
+ st.stop()
285
+
286
  st.write(story)
287
  st.divider()
288
 
289
+ # --- Step 4: Synthesize Audio ---
290
  st.subheader("👂 Hear the Story")
291
+ with st.status("Synthesizing audio narration…", expanded=True) as status:
 
292
  try:
293
+ status.update(label="Running text-to-speech model...", state="running")
294
+ tts_pipe = load_tts_pipe()
295
  audio_bytes = tts_bytes(story, tts_pipe)
296
+
297
+ if not audio_bytes:
298
+ status.update(label="Audio generation failed.", state="error")
299
+ st.warning("Could not generate audio for the story.")
300
+ else:
301
+ status.update(label="Audio generated!", state="complete", expanded=False)
302
+ st.audio(audio_bytes, format="audio/wav")
303
+
304
  except Exception as e:
305
+ status.update(label=f"Error during audio synthesis: {e}", state="error")
306
+ st.error(f"An error occurred during audio synthesis: {e}")
307
 
308
 
309
+ # --- Celebration ---
310
  st.balloons()