import streamlit as st
from transformers import pipeline
from PIL import Image
import io, textwrap, numpy as np, soundfile as sf
import logging
# Set up logging for debugging
logging.basicConfig(level=logging.INFO)
# ------------------ Streamlit Page Configuration ------------------
st.set_page_config(
    page_title="Picture to Story Magic",
    page_icon="🦄",
    layout="centered"
)

# ------------------ Custom CSS for a Colorful Background ------------------
st.markdown(
    """
    <style>
    body {
        background-color: #FDEBD0;
    }
    </style>
    """,
    unsafe_allow_html=True
)

# ------------------ Playful Header for Young Users ------------------
st.markdown(
    """
    <h1 style='text-align: center; color: #ff66cc;'>Picture to Story Magic!</h1>
    <p style='text-align: center; font-size: 24px;'>
        Hi little artist! Upload your picture and let us create a fun story just for you! 🎉
    </p>
    """,
    unsafe_allow_html=True
)

# ------------------ Lazy Model Loading ------------------
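# Note: each pipeline() call downloads its model weights from the Hugging Face Hub on
# first use and caches them locally, so the very first run of the app can take a while.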
def load_models():
    """
    Lazy-load the required pipelines and store them in session state.
    """
    if "captioner" not in st.session_state:
        st.session_state.captioner = pipeline(
            "image-to-text",
            model="Salesforce/blip-image-captioning-large"
        )
    if "storyer" not in st.session_state:
        try:
            st.session_state.storyer = pipeline(
                "text-generation",
                model="aspis/gpt2-genre-story-generation"
            )
        except Exception as e:
            logging.warning(f"Failed to load aspis/gpt2-genre-story-generation: {e}. Falling back to gpt2.")
            st.session_state.storyer = pipeline(
                "text-generation",
                model="gpt2"
            )
    if "tts" not in st.session_state:
        st.session_state.tts = pipeline(
            "text-to-speech",
            model="facebook/mms-tts-eng"
        )

# ------------------ Caching Functions ------------------
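# st.cache_data memoizes each function on its arguments (the raw image bytes, the caption
# text, the story text), so re-running with the same input reuses earlier results instead
# of calling the models again.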
@st.cache_data(show_spinner=False)
def get_caption(image_bytes):
    """
    Converts image bytes into a lower resolution image (maximum 256x256)
    and generates a caption.
    """
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    image.thumbnail((256, 256))
    caption = st.session_state.captioner(image)[0]["generated_text"]
    return caption

@st.cache_data(show_spinner=False)
def get_story(caption):
    """
    Generates a humorous and engaging children's story based on the caption.
    Uses a simplified prompt and robust output parsing.
    """
    prompt = (
        f"Create a funny, warm children's story (50-100 words) for ages 3-10 based on: {caption}. "
        f"Use third-person narrative, as if playfully describing the scene."
    )
    try:
        result = st.session_state.storyer(
            prompt,
            max_new_tokens=150,  # Increased to allow more room for the story
            do_sample=True,
            temperature=0.8,  # Slightly higher for creativity
            top_p=0.9,
            return_full_text=False
        )
        logging.info(f"Story generation raw result: {result}")
        # Extract the generated text
        raw_story = result[0].get("generated_text", "").strip()
        # If there is no meaningful output, use a fallback story
        if not raw_story or len(raw_story.split()) < 10:
            logging.warning("Generated story too short or empty. Using fallback.")
            raw_story = (
                f"Once upon a time, in a land of {caption}, a silly squirrel named Sammy "
                f"found a shiny treasure! He danced with joy, but oh no! It was a magic acorn! "
                f"It grew into a giant tree, and Sammy climbed to the top, giggling all the way. "
                f"The tree sang funny songs, and all the animals joined in for a big party!"
            )
        # Truncate to 100 words
        words = raw_story.split()
        story = " ".join(words[:100])
        return story
    except Exception as e:
        logging.error(f"Story generation failed: {e}")
        # Fallback story in case of errors
        return (
            f"Once upon a time, in a land of {caption}, a silly squirrel named Sammy "
            f"found a shiny treasure! He danced with joy, but oh no! It was a magic acorn! "
            f"It grew into a giant tree, and Sammy climbed to the top, giggling all the way."
        )

@st.cache_data(show_spinner=False)
def get_audio(story):
    """
    Converts the generated story text into audio.
    Splits the text into 300-character chunks so each TTS call stays short,
    then concatenates the results into a single WAV buffer.
    """
    chunks = textwrap.wrap(story, width=300)
    audio_chunks = []
    for chunk in chunks:
        try:
            output = st.session_state.tts(chunk)
            if isinstance(output, list):
                output = output[0]
            if "audio" in output:
                audio_array = np.array(output["audio"]).squeeze()
                audio_chunks.append(audio_array)
        except Exception as e:
            logging.warning(f"TTS failed for a chunk: {e}")
            continue
    if not audio_chunks:
        # Fall back to one second of silence at the model's sampling rate
        sr = st.session_state.tts.model.config.sampling_rate
        audio = np.zeros(sr, dtype=np.float32)
    else:
        audio = np.concatenate(audio_chunks)
    buffer = io.BytesIO()
    sf.write(buffer, audio, st.session_state.tts.model.config.sampling_rate, format="WAV")
    buffer.seek(0)
    return buffer

# ------------------ Main App Logic ------------------
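# Flow: upload a picture -> caption it (BLIP) -> turn the caption into a story (GPT-2)
# -> read the story aloud (MMS-TTS).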
uploaded_file = st.file_uploader("Choose a Picture...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        load_models()  # Ensure models are loaded
        image_bytes = uploaded_file.getvalue()
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        st.image(image, caption="Your Amazing Picture!", use_column_width=True)
        st.markdown("<h3 style='text-align: center;'>Ready for your story?</h3>", unsafe_allow_html=True)
        if st.button("Story, Please!"):
            with st.spinner("Generating caption..."):
                caption = get_caption(image_bytes)
            st.markdown("<h3 style='text-align: center;'>Caption:</h3>", unsafe_allow_html=True)
            st.write(caption)
            with st.spinner("Generating story..."):
                story = get_story(caption)
            st.markdown("<h3 style='text-align: center;'>Your Story:</h3>", unsafe_allow_html=True)
            if not story.strip():
                st.write("No story was generated. Please try again.")
            else:
                st.write(story)
            with st.spinner("Generating audio..."):
                audio_buffer = get_audio(story)
            st.audio(audio_buffer, format="audio/wav", start_time=0)
            st.markdown(
                "<p style='text-align: center; font-weight: bold;'>Enjoy your magical story! 🎶</p>",
                unsafe_allow_html=True
            )
    except Exception as e:
        st.error("Oops! Something went wrong. Please try a different picture or check the file format!")
        st.error(f"Error details: {e}")