# app.py
import io # for creating in-memory binary streams
import wave # for writing WAV audio files
import re # for regular expression utilities
import streamlit as st # Streamlit UI library
from transformers import pipeline # Hugging Face inference pipelines
from PIL import Image # Python Imaging Library for image loading
import numpy as np # numerical operations, especially array handling
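# How to run (a minimal sketch; the exact package set is an assumption beyond the
# imports above, e.g. the transformers pipelines also need a backend such as torch):
#   pip install streamlit transformers torch pillow numpy
#   streamlit run app.py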
# 1) CACHE & LOAD MODELS
@st.cache_resource(show_spinner=False)
def load_captioner():
    # Loads the BLIP image-to-text model; cached so it loads only once.
    # Returns: a callable captioner(image: PIL.Image) -> List[Dict].
    return pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device="cpu"  # change to "cuda" if a GPU is available
    )
@st.cache_resource(show_spinner=False)
def load_story_pipe():
    # Loads the FLAN-T5 text-to-text model for story generation; cached once.
    # Returns: a callable story_pipe(prompt: str, **kwargs) -> List[Dict].
    return pipeline(
        "text2text-generation",
        model="google/flan-t5-base",
        device="cpu"  # change to "cuda" if a GPU is available
    )
@st.cache_resource(show_spinner=False)
def load_tts_pipe():
    # Loads the Meta MMS-TTS text-to-speech model; cached once.
    # Returns: a callable tts_pipe(text: str) -> Dict (or List[Dict]) with "audio" and "sampling_rate".
    return pipeline(
        "text-to-speech",
        model="facebook/mms-tts-eng",
        device="cpu"  # change to "cuda" if a GPU is available
    )
# 2) HELPER FUNCTIONS
def sentence_case(text: str) -> str:
    # Splits text into sentences on .!? delimiters,
    # capitalizes the first character of each sentence,
    # then rejoins into a single string.
    parts = re.split(r'([.!?])', text)  # e.g. ["hello", ".", " world", "!", ""]
    out = []
    for i in range(0, len(parts) - 1, 2):
        sentence = parts[i].strip().capitalize()  # capitalize first letter
        delimiter = parts[i + 1]                  # sentence-ending punctuation
        out.append(f"{sentence}{delimiter}")
    # If trailing text without punctuation exists, capitalize and append it.
    if len(parts) % 2:
        last = parts[-1].strip().capitalize()
        if last:
            out.append(last)
    # Join the sentences and collapse any repeated whitespace into single spaces.
    return " ".join(" ".join(out).split())
def caption_image(img: Image.Image, captioner) -> str:
    # Given a PIL image and a captioner pipeline, returns a single-line caption.
    results = captioner(img)  # run the captioning model
    if not results:
        return ""
    # Extract the "generated_text" field from the first result.
    return results[0].get("generated_text", "")
def story_from_caption(caption: str, pipe) -> str:
    # Given a caption string and a text2text pipeline, returns a ~100-word story.
    prompt = f"Write a vivid, imaginative ~100-word story about this scene: {caption}"
    results = pipe(
        prompt,
        max_length=120,          # maximum length of the generated output
        min_length=80,           # minimum length of the generated output
        do_sample=True,          # enable sampling instead of greedy decoding
        top_k=100,               # sample only from the top-k most likely tokens
        top_p=0.9,               # nucleus sampling threshold
        temperature=0.7,         # sampling temperature
        repetition_penalty=1.1,  # discourage repetition
        no_repeat_ngram_size=4,  # block repeated n-grams
        early_stopping=False
    )
    raw = results[0]["generated_text"].strip()  # full generated text
    # Strip out the prompt if the model echoes it back (case-insensitive comparison).
    if raw.lower().startswith(prompt.lower()):
        raw = raw[len(prompt):].strip()
    # Trim to the last complete sentence ending in . ! or ?
    match = re.search(r'[.!?]', raw[::-1])  # find the last punctuation mark by scanning the reversed string
    if match:
        raw = raw[:len(raw) - match.start()]  # keep everything up to and including that punctuation
    elif len(raw) > 80:
        # No sentence-ending punctuation at all: cut long output to a reasonable length.
        raw = raw[:80] + "..."
    return sentence_case(raw)
def tts_bytes(text: str, tts_pipe) -> bytes:
    # Given a text string and a TTS pipeline, returns WAV-format bytes.
    # Clean up text for TTS: remove leading/trailing quotes, etc.
    cleaned_text = re.sub(r'^["\']|["\']$', '', text).strip()
    # Basic punctuation cleaning (optional, depending on the TTS model).
    cleaned_text = re.sub(r'\.{2,}', '.', cleaned_text)  # collapse runs of periods into one
    cleaned_text = cleaned_text.replace('…', '...')      # replace the ellipsis character with dots
    # Add a period if the text doesn't end with punctuation (helps the TTS model finalize).
    if cleaned_text and cleaned_text[-1] not in '.!?':
        cleaned_text += '.'
    output = tts_pipe(cleaned_text)
    # The pipeline may return a list or a single dict.
    result = output[0] if isinstance(output, list) else output
    audio_array = result["audio"]    # numpy array: (channels, samples) or (samples,)
    rate = result["sampling_rate"]   # sampling rate in Hz
    # Ensure audio_array is 2D (samples, channels) for consistent handling.
    if audio_array.ndim == 1:
        data = audio_array[:, np.newaxis]  # add a channel dimension
    else:
        data = audio_array.T               # transpose (channels, samples) -> (samples, channels)
    # Convert float32 in [-1, 1] to int16 PCM in [-32768, 32767]; clip to avoid integer overflow.
    pcm = (np.clip(data, -1.0, 1.0) * 32767).astype(np.int16)
    buffer = io.BytesIO()
    wf = wave.open(buffer, "wb")
    wf.setnchannels(data.shape[1])  # number of channels
    wf.setsampwidth(2)              # 16 bits = 2 bytes per sample
    wf.setframerate(rate)           # samples per second
    wf.writeframes(pcm.tobytes())   # write PCM data
    wf.close()
    buffer.seek(0)
    return buffer.read()  # return raw WAV bytes
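# Usage sketch (hypothetical, outside the Streamlit flow below):
#   wav = tts_bytes("A quiet morning by the lake.", load_tts_pipe())
#   open("story.wav", "wb").write(wav)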
# 3) STREAMLIT USER INTERFACE
st.set_page_config(page_title="Imagine & Narrate", page_icon="✨", layout="centered")
st.title("✨ Imagine & Narrate")
st.write("Upload any image below to see AI imagine and narrate a story about it!")

# -- Upload image widget --
uploaded = st.file_uploader(
    "Choose an image file",
    type=["jpg", "jpeg", "png"]
)
if not uploaded:
    st.info("➡️ Upload an image above to start the magic!")
    st.stop()
# Load the uploaded file into a PIL Image.
try:
    img = Image.open(uploaded)
except Exception as e:
    st.error(f"Error loading image: {e}")
    st.stop()
# -- Step 1: Display the image --
st.subheader("📸 Your Visual Input")
st.image(img, use_container_width=True)
st.divider()

# -- Step 2: Generate and display caption --
st.subheader("🧠 Generating Insights")
with st.spinner("Scanning image for key elements…"):
    captioner = load_captioner()
    raw_caption = caption_image(img, captioner)
if not raw_caption:
    st.warning("Could not generate a caption for the image.")
    st.stop()
caption = sentence_case(raw_caption)
st.markdown(f"**Identified Scene:** {caption}")
st.divider()
# -- Step 3: Generate and display story --
st.subheader("📖 Crafting a Narrative")
with st.spinner("Writing a compelling story…"):
    story_pipe = load_story_pipe()
    story = story_from_caption(caption, story_pipe)
if not story or story.strip() == '...':  # check for an empty or minimal story
    st.warning("Could not generate a meaningful story from the caption.")
    st.stop()
st.write(story)
st.divider()
# -- Step 4: Synthesize and play audio --
st.subheader("👂 Hear the Story")
with st.spinner("Synthesizing audio narration…"):
    tts_pipe = load_tts_pipe()
    try:
        audio_bytes = tts_bytes(story, tts_pipe)
        st.audio(audio_bytes, format="audio/wav")
    except Exception as e:
        st.error(f"Error generating audio: {e}")

# Celebration animation
st.balloons()