justREE commited on
Commit
f136dda
·
verified ·
1 Parent(s): 70a4e6d

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +95 -229
src/streamlit_app.py CHANGED
@@ -1,5 +1,3 @@
1
- # app.py
2
-
3
  import io # for creating in-memory binary streams
4
  import wave # for writing WAV audio files
5
  import re # for regular expression utilities
@@ -8,303 +6,171 @@ from transformers import pipeline # Hugging Face inference pipelines
8
  from PIL import Image # Python Imaging Library for image loading
9
  import numpy as np # numerical operations, especially array handling
10
 
11
- # 1) CACHE & LOAD MODELS
12
- # Use cache_resource for models/objects that should be loaded once per session/run
13
  @st.cache_resource(show_spinner=False)
14
  def load_captioner():
15
- """Loads BLIP image-to-text model; cached so it loads only once."""
16
- # Returns: a function captioner(image: PIL.Image) -> List[Dict]
17
- # Using device="cpu" for broader compatibility. Change to "cuda" for GPU.
18
  return pipeline(
19
  "image-to-text",
20
  model="Salesforce/blip-image-captioning-base",
21
- device="cpu"
22
  )
23
 
24
  @st.cache_resource(show_spinner=False)
25
  def load_story_pipe():
26
- """Loads FLAN-T5 text-to-text model for story generation; cached once."""
27
- # Returns: a function story_pipe(prompt: str, **kwargs) -> List[Dict].
28
- # Using device="cpu" for broader compatibility. Change to "cuda" for GPU.
29
  return pipeline(
30
  "text2text-generation",
31
  model="google/flan-t5-base",
32
- device="cpu"
33
  )
34
 
35
  @st.cache_resource(show_spinner=False)
36
  def load_tts_pipe():
37
- """Loads Meta MMS-TTS text-to-speech model; cached once."""
38
- # Returns: a function tts_pipe(text: str) -> List[Dict] with "audio" and "sampling_rate".
39
- # Using device="cpu" for broader compatibility. Change to "cuda" for GPU.
40
  return pipeline(
41
  "text-to-speech",
42
  model="facebook/mms-tts-eng",
43
- device="cpu"
44
  )
45
 
46
  # 2) HELPER FUNCTIONS
47
  def sentence_case(text: str) -> str:
48
- """
49
- Splits text into sentences on .!? delimiters,
50
- capitalizes the first character of each sentence,
51
- then rejoins into a single string. Handles edge cases like leading/trailing spaces.
52
- """
53
- # Split while keeping the delimiters
54
  parts = re.split(r'([.!?])', text)
55
-
56
  out = []
57
- # Iterate through parts, taking text followed by delimiter
58
  for i in range(0, len(parts) - 1, 2):
59
- sentence = parts[i].strip() # Get the sentence text and remove surrounding whitespace
60
- delimiter = parts[i + 1] # Get the delimiter
61
- if sentence: # Only process if there's actual text
62
- # Capitalize the first letter of the cleaned sentence part
63
- formatted_sentence = sentence[0].upper() + sentence[1:]
64
- # Append the formatted sentence and its delimiter
65
- out.append(f"{formatted_sentence}{delimiter}")
66
- elif delimiter.strip(): # Handle cases where there's just a delimiter (e.g., "...")
67
- out.append(delimiter)
68
-
69
-
70
- # Handle any remaining part if the text didn't end with a delimiter
71
  if len(parts) % 2:
72
- last_part = parts[-1].strip()
73
- if last_part:
74
- # Capitalize the first letter of the last part
75
- formatted_last_part = last_part[0].upper() + last_part[1:]
76
- out.append(formatted_last_part)
77
-
78
- # Join parts and clean up potential excess spaces
79
- # Join with a space first, then split and rejoin to handle multiple spaces
80
  return " ".join(" ".join(out).split())
81
 
82
-
83
  def caption_image(img: Image.Image, captioner) -> str:
84
- """
85
- Given a PIL image and a captioner pipeline, returns a single-line caption.
86
- """
87
- # Ensure image is in RGB format, as some models might expect it
88
  if img.mode != "RGB":
89
  img = img.convert("RGB")
90
-
91
- results = captioner(img) # run model
92
- if not results:
93
- return ""
94
- # extract "generated_text" field from first result
95
- return results[0].get("generated_text", "")
96
 
97
  def story_from_caption(caption: str, pipe) -> str:
98
- """
99
- Given a caption string and a text2text pipeline, returns a ~100-word story.
100
- """
101
  if not caption:
102
  return "Could not generate a story without a caption."
103
-
104
- prompt = f"Write a vivid, imaginative ~100-word story about this scene: {caption}"
105
- # Add a directive for slightly more coherence
106
- prompt += "\n\nWrite a creative and descriptive short story."
107
-
108
  results = pipe(
109
  prompt,
110
- max_length=120, # increased max length slightly
111
- min_length=60, # reduced min length slightly for robustness
112
- do_sample=True, # enable sampling for creativity
113
- top_k=100, # sample from top_k tokens
114
- top_p=0.9, # nucleus sampling threshold
115
- temperature=0.8, # slightly increased temperature for more randomness
116
- repetition_penalty=1.1, # discourage repetition
117
- no_repeat_ngram_size=4, # block repeated n-grams
118
  early_stopping=False
119
  )
120
- raw = results[0]["generated_text"].strip() # full generated text
121
-
122
- # strip out the prompt if it echoes back - make comparison case-insensitive
123
- # Check if the generated text starts with a substantial part of the prompt
124
- prompt_check_length = min(len(prompt) // 2, 50) # Check against first half or 50 chars
125
- if raw.lower().startswith(prompt.lower()[:prompt_check_length]):
126
- # Attempt to remove the echoed prompt more robustly
127
- raw = re.sub(re.escape(prompt), '', raw, count=1, flags=re.IGNORECASE).strip()
128
-
129
-
130
- # trim to last complete sentence ending in . ! or ?
131
- # Search for the first punctuation from the end of the string
132
- match = re.search(r'[.!?]', raw[::-1])
133
- if match:
134
- # Trim the string at the position of the found punctuation
135
- raw = raw[:len(raw) - match.start()]
136
- elif len(raw) > 80: # If no punctuation found and story is long, trim and add ellipsis
137
- raw = raw[:raw.rfind(' ') if raw.rfind(' ') != -1 and raw.rfind(' ') > 60 else 80] + "..."
138
- elif len(raw) < 20: # If the story is very short and has no punctuation
139
- raw += "..." # Add ellipsis to indicate it might be incomplete
140
-
141
-
142
  return sentence_case(raw)
143
 
144
-
145
  def tts_bytes(text: str, tts_pipe) -> bytes:
146
- """
147
- Given a text string and a tts pipeline, returns WAV-format bytes.
148
- Cleans text for better TTS performance and handles audio data conversion.
149
- """
150
  if not text:
151
- return b"" # Return empty bytes if no text
152
-
153
- # Clean up text for TTS - remove leading/trailing quotes, extra whitespace
154
- cleaned_text = re.sub(r'^["\']|["\']$', '', text).strip()
155
- # Replace multiple periods, handle ellipsis character
156
- cleaned_text = re.sub(r'\.{2,}', '.', cleaned_text)
157
- cleaned_text = cleaned_text.replace('…', '...')
158
- # Ensure text ends with punctuation for better natural speech flow
159
- if cleaned_text and cleaned_text[-1] not in '.!?':
160
- cleaned_text += '.'
161
- # Remove excessive internal whitespace
162
- cleaned_text = " ".join(cleaned_text.split())
163
-
164
- if not cleaned_text:
165
- return b"" # Return empty bytes if cleaning results in empty string
166
-
167
-
168
- output = tts_pipe(cleaned_text)
169
- # pipeline may return list or single dict
170
  result = output[0] if isinstance(output, list) else output
171
-
172
- audio_array = result.get("audio") # numpy array: (channels, samples) or (samples,)
173
- rate = result.get("sampling_rate") # sampling rate integer
174
-
175
  if audio_array is None or rate is None:
176
- st.error("TTS pipeline did not return expected audio data.")
177
  return b""
178
-
179
- # ensure audio_array is 2D (samples, channels) for consistent handling
180
  if audio_array.ndim == 1:
181
- data = audio_array[:, np.newaxis] # add channel dimension
182
  else:
183
- data = audio_array.T # transpose from (channels, samples) to (samples, channels)
184
-
185
-
186
- # convert float32 [-1..1] to int16 PCM [-32768..32767]
187
  pcm = (data * 32767).astype(np.int16)
188
-
189
- buffer = io.BytesIO()
190
- wf = wave.open(buffer, "wb")
191
- try:
192
- wf.setnchannels(data.shape[1] if data.ndim == 2 else 1) # set number of channels
193
- wf.setsampwidth(2) # 16 bits = 2 bytes
194
- wf.setframerate(rate) # samples per second
195
- wf.writeframes(pcm.tobytes()) # write PCM data
196
- finally:
197
- wf.close() # Ensure the wave file object is closed
198
-
199
- buffer.seek(0)
200
- return buffer.read() # return raw WAV bytes
201
 
202
  # 3) STREAMLIT USER INTERFACE
203
- # --- Page Config ---
204
- st.set_page_config(page_title="Imagine & Narrate", page_icon="✨", layout="centered")
205
 
206
- # --- Title and Intro ---
207
- st.title(" Imagine & Narrate")
208
- st.write("Upload any image below to see AI imagine and narrate a story about it!")
209
 
210
- # --- File Uploader ---
211
- uploaded = st.file_uploader(
212
  "Choose an image file",
213
- type=["jpg", "jpeg", "png"] # Specify allowed types
214
- # Add an optional help text
215
- # help="Supported formats: JPG, JPEG, PNG."
216
  )
 
 
217
 
218
- # --- Handle No Upload ---
219
- if not uploaded:
220
  st.info("➡️ Upload an image above to start the magic!")
221
- st.stop() # Halt execution until file is uploaded
222
 
223
- # --- Image Loading ---
224
- # Use st.status for a nicer progress/status display during potentially slow steps
225
- with st.status("Loading image...", expanded=True) as status:
226
- try:
227
- status.update(label="Opening image file...", state="running")
228
- img = Image.open(uploaded)
229
- status.update(label="Image loaded successfully!", state="complete", expanded=False)
230
- except Exception as e:
231
- status.update(label=f"Error loading image: {e}", state="error")
232
- st.error(f"Could not load the image. Please try a different file. Error: {e}")
233
- st.stop() # Stop if image loading fails
234
 
235
- # --- Display Image ---
236
  st.subheader("📸 Your Visual Input")
237
- st.image(img, use_container_width=True, caption=uploaded.name) # Add caption with filename
238
  st.divider()
239
 
240
- # --- Step 2: Generate Caption ---
241
- st.subheader("🧠 Generating Insights")
242
- # Using st.status again for the pipeline steps
243
- with st.status("Scanning image for key elements…", expanded=True) as status:
244
- try:
245
- status.update(label="Running image captioning model...", state="running")
246
- captioner = load_captioner()
247
- raw_caption = caption_image(img, captioner)
248
-
249
- if not raw_caption:
250
- status.update(label="Image analysis failed.", state="error")
251
- st.warning("Could not generate a caption for the image.")
252
- st.stop()
253
-
254
- caption = sentence_case(raw_caption)
255
- status.update(label="Image analyzed, caption generated!", state="complete", expanded=False)
256
-
257
- except Exception as e:
258
- status.update(label=f"Error during image analysis: {e}", state="error")
259
- st.error(f"An error occurred during image analysis: {e}")
260
  st.stop()
261
-
262
-
263
  st.markdown(f"**Identified Scene:** {caption}")
264
  st.divider()
265
 
266
- # --- Step 3: Generate Story ---
267
- st.subheader("📖 Crafting a Narrative")
268
- with st.status("Writing a compelling story", expanded=True) as status:
269
- try:
270
- status.update(label="Running story generation model...", state="running")
271
- story_pipe = load_story_pipe()
272
- story = story_from_caption(caption, story_pipe)
273
-
274
- if not story or story.strip() in ['.', '..', '...']: # Check for empty or minimal story
275
- status.update(label="Story generation failed.", state="error")
276
- st.warning("Could not generate a meaningful story from the caption.")
277
- st.stop()
278
-
279
- status.update(label="Story crafted!", state="complete", expanded=False)
280
-
281
- except Exception as e:
282
- status.update(label=f"Error during story generation: {e}", state="error")
283
- st.error(f"An error occurred during story generation: {e}")
284
  st.stop()
285
-
286
  st.write(story)
287
  st.divider()
288
 
289
- # --- Step 4: Synthesize Audio ---
290
  st.subheader("👂 Hear the Story")
291
- with st.status("Synthesizing audio narration…", expanded=True) as status:
292
- try:
293
- status.update(label="Running text-to-speech model...", state="running")
294
- tts_pipe = load_tts_pipe()
295
- audio_bytes = tts_bytes(story, tts_pipe)
296
-
297
- if not audio_bytes:
298
- status.update(label="Audio generation failed.", state="error")
299
- st.warning("Could not generate audio for the story.")
300
- else:
301
- status.update(label="Audio generated!", state="complete", expanded=False)
302
- st.audio(audio_bytes, format="audio/wav")
303
-
304
- except Exception as e:
305
- status.update(label=f"Error during audio synthesis: {e}", state="error")
306
- st.error(f"An error occurred during audio synthesis: {e}")
307
-
308
-
309
- # --- Celebration ---
310
  st.balloons()
 
 
 
1
  import io # for creating in-memory binary streams
2
  import wave # for writing WAV audio files
3
  import re # for regular expression utilities
 
6
  from PIL import Image # Python Imaging Library for image loading
7
  import numpy as np # numerical operations, especially array handling
8
 
9
+ # 1) CACHE & LOAD MODELS (CPU only)
 
10
  @st.cache_resource(show_spinner=False)
11
  def load_captioner():
 
 
 
12
  return pipeline(
13
  "image-to-text",
14
  model="Salesforce/blip-image-captioning-base",
15
+ device=-1 # force CPU
16
  )
17
 
18
  @st.cache_resource(show_spinner=False)
19
  def load_story_pipe():
 
 
 
20
  return pipeline(
21
  "text2text-generation",
22
  model="google/flan-t5-base",
23
+ device=-1 # force CPU
24
  )
25
 
26
  @st.cache_resource(show_spinner=False)
27
  def load_tts_pipe():
 
 
 
28
  return pipeline(
29
  "text-to-speech",
30
  model="facebook/mms-tts-eng",
31
+ device=-1 # force CPU
32
  )
33
 
34
  # 2) HELPER FUNCTIONS
35
  def sentence_case(text: str) -> str:
 
 
 
 
 
 
36
  parts = re.split(r'([.!?])', text)
 
37
  out = []
 
38
  for i in range(0, len(parts) - 1, 2):
39
+ sentence = parts[i].strip()
40
+ delimiter = parts[i + 1]
41
+ if sentence:
42
+ formatted = sentence[0].upper() + sentence[1:]
43
+ out.append(f"{formatted}{delimiter}")
 
 
 
 
 
 
 
44
  if len(parts) % 2:
45
+ last = parts[-1].strip()
46
+ if last:
47
+ formatted = last[0].upper() + last[1:]
48
+ out.append(formatted)
 
 
 
 
49
  return " ".join(" ".join(out).split())
50
 
 
51
  def caption_image(img: Image.Image, captioner) -> str:
 
 
 
 
52
  if img.mode != "RGB":
53
  img = img.convert("RGB")
54
+ results = captioner(img)
55
+ return (results[0].get("generated_text", "") if results else "")
 
 
 
 
56
 
57
  def story_from_caption(caption: str, pipe) -> str:
 
 
 
58
  if not caption:
59
  return "Could not generate a story without a caption."
60
+ prompt = f"Write a vivid, imaginative ~100-word story about this scene: {caption}\n\nWrite a creative and descriptive short story."
 
 
 
 
61
  results = pipe(
62
  prompt,
63
+ max_length=120,
64
+ min_length=60,
65
+ do_sample=True,
66
+ top_k=100,
67
+ top_p=0.9,
68
+ temperature=0.8,
69
+ repetition_penalty=1.1,
70
+ no_repeat_ngram_size=4,
71
  early_stopping=False
72
  )
73
+ raw = results[0]["generated_text"].strip()
74
+ # Remove prompt echo if present
75
+ raw = re.sub(re.escape(prompt), "", raw, flags=re.IGNORECASE).strip()
76
+ # Trim to last full sentence
77
+ idx = max(raw.rfind("."), raw.rfind("!"), raw.rfind("?"))
78
+ if idx != -1:
79
+ raw = raw[:idx+1]
80
+ elif len(raw) > 80:
81
+ raw = raw[:raw.rfind(" ") if raw.rfind(" ") > 60 else 80] + "..."
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  return sentence_case(raw)
83
 
 
84
  def tts_bytes(text: str, tts_pipe) -> bytes:
 
 
 
 
85
  if not text:
86
+ return b""
87
+ cleaned = re.sub(r'^["\']|["\']$', '', text).strip()
88
+ cleaned = re.sub(r'\.{2,}', '.', cleaned).replace('…', '...')
89
+ if cleaned[-1] not in ".!?":
90
+ cleaned += "."
91
+ cleaned = " ".join(cleaned.split())
92
+ output = tts_pipe(cleaned)
 
 
 
 
 
 
 
 
 
 
 
 
93
  result = output[0] if isinstance(output, list) else output
94
+ audio_array = result.get("audio")
95
+ rate = result.get("sampling_rate")
 
 
96
  if audio_array is None or rate is None:
 
97
  return b""
 
 
98
  if audio_array.ndim == 1:
99
+ data = audio_array[:, np.newaxis]
100
  else:
101
+ data = audio_array.T
 
 
 
102
  pcm = (data * 32767).astype(np.int16)
103
+ buf = io.BytesIO()
104
+ wf = wave.open(buf, "wb")
105
+ wf.setnchannels(data.shape[1])
106
+ wf.setsampwidth(2)
107
+ wf.setframerate(rate)
108
+ wf.writeframes(pcm.tobytes())
109
+ wf.close()
110
+ buf.seek(0)
111
+ return buf.read()
 
 
 
 
112
 
113
  # 3) STREAMLIT USER INTERFACE
114
+ st.set_page_config(page_title="✨ Imagine & Narrate", page_icon="✨", layout="centered")
 
115
 
116
+ # Persist upload across reruns
117
+ if "uploaded_file" not in st.session_state:
118
+ st.session_state.uploaded_file = None
119
 
120
+ new_upload = st.file_uploader(
 
121
  "Choose an image file",
122
+ type=["jpg", "jpeg", "png"]
 
 
123
  )
124
+ if new_upload is not None:
125
+ st.session_state.uploaded_file = new_upload
126
 
127
+ if st.session_state.uploaded_file is None:
128
+ st.title("✨ Imagine & Narrate")
129
  st.info("➡️ Upload an image above to start the magic!")
130
+ st.stop()
131
 
132
+ uploaded = st.session_state.uploaded_file
133
+ try:
134
+ img = Image.open(uploaded)
135
+ except Exception as e:
136
+ st.error(f"Could not load the image: {e}")
137
+ st.stop()
 
 
 
 
 
138
 
139
+ st.title("✨ Imagine & Narrate")
140
  st.subheader("📸 Your Visual Input")
141
+ st.image(img, caption=uploaded.name, use_container_width=True)
142
  st.divider()
143
 
144
+ # Step 1: Generate Caption
145
+ st.subheader("🧠 Generating Caption")
146
+ with st.spinner("Analyzing image..."):
147
+ captioner = load_captioner()
148
+ raw_caption = caption_image(img, captioner)
149
+ if not raw_caption:
150
+ st.error("Failed to generate caption.")
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  st.stop()
152
+ caption = sentence_case(raw_caption)
 
153
  st.markdown(f"**Identified Scene:** {caption}")
154
  st.divider()
155
 
156
+ # Step 2: Generate Story
157
+ st.subheader("📖 Crafting a Story")
158
+ with st.spinner("Writing story..."):
159
+ story_pipe = load_story_pipe()
160
+ story = story_from_caption(caption, story_pipe)
161
+ if not story or story.strip() in {".", "..", "..."}:
162
+ st.error("Failed to generate story.")
 
 
 
 
 
 
 
 
 
 
 
163
  st.stop()
 
164
  st.write(story)
165
  st.divider()
166
 
167
+ # Step 3: Synthesize Audio
168
  st.subheader("👂 Hear the Story")
169
+ with st.spinner("Synthesizing audio..."):
170
+ tts_pipe = load_tts_pipe()
171
+ audio_bytes = tts_bytes(story, tts_pipe)
172
+ if not audio_bytes:
173
+ st.warning("Audio generation failed.")
174
+ else:
175
+ st.audio(audio_bytes, format="audio/wav")
 
 
 
 
 
 
 
 
 
 
 
 
176
  st.balloons()